Add streaming infrastructure:

- SSE line parser handling multi-line data, comments, and the [DONE] sentinel
- Generic Stream[T] pull-based iterator (no goroutines, no channel leaks)
- doStream() HTTP helper for streaming endpoints
- ChatCompleteStream() method
- 28 new tests: SSE edge cases, iterator behavior, httptest integration
110 lines
2.6 KiB
Go
110 lines
2.6 KiB
Go
package mistral
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
)
|
|
|
|
func (c *Client) do(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
|
|
req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, body)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("mistral: create request: %w", err)
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+c.apiKey)
|
|
req.Header.Set("Accept", "application/json")
|
|
if body != nil {
|
|
req.Header.Set("Content-Type", "application/json")
|
|
}
|
|
resp, err := c.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("mistral: send request: %w", err)
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
func (c *Client) doJSON(ctx context.Context, method, path string, reqBody, respBody any) error {
|
|
var body io.Reader
|
|
if reqBody != nil {
|
|
data, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return fmt.Errorf("mistral: marshal request: %w", err)
|
|
}
|
|
body = bytes.NewReader(data)
|
|
}
|
|
resp, err := c.do(ctx, method, path, body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode >= 400 {
|
|
return parseAPIError(resp)
|
|
}
|
|
if respBody != nil {
|
|
if err := json.NewDecoder(resp.Body).Decode(respBody); err != nil {
|
|
return fmt.Errorf("mistral: decode response: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Client) doStream(ctx context.Context, method, path string, reqBody any) (*http.Response, error) {
|
|
var body io.Reader
|
|
if reqBody != nil {
|
|
data, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("mistral: marshal request: %w", err)
|
|
}
|
|
body = bytes.NewReader(data)
|
|
}
|
|
resp, err := c.do(ctx, method, path, body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resp.StatusCode >= 400 {
|
|
defer resp.Body.Close()
|
|
return nil, parseAPIError(resp)
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
func parseAPIError(resp *http.Response) error {
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return &APIError{
|
|
StatusCode: resp.StatusCode,
|
|
Message: fmt.Sprintf("failed to read error response: %v", err),
|
|
}
|
|
}
|
|
var envelope struct {
|
|
Detail string `json:"detail"`
|
|
Message string `json:"message"`
|
|
Type string `json:"type"`
|
|
Param string `json:"param"`
|
|
Code string `json:"code"`
|
|
}
|
|
if err := json.Unmarshal(body, &envelope); err == nil {
|
|
msg := envelope.Message
|
|
if msg == "" {
|
|
msg = envelope.Detail
|
|
}
|
|
if msg == "" {
|
|
msg = string(body)
|
|
}
|
|
return &APIError{
|
|
StatusCode: resp.StatusCode,
|
|
Type: envelope.Type,
|
|
Message: msg,
|
|
Param: envelope.Param,
|
|
Code: envelope.Code,
|
|
}
|
|
}
|
|
return &APIError{
|
|
StatusCode: resp.StatusCode,
|
|
Message: string(body),
|
|
}
|
|
}
|