Files
mistral-go-sdk/request.go
vikingowl 2b980e14b3 fix: post-review fixes — metadata, unknown types, typed tools, API polish
1. Add README, LICENSE (MIT), .gitignore, Makefile, CHANGELOG
2. Add Version constant and User-Agent header to all requests
3. Rename SetStream to EnableStream (narrower API surface)
4. Fix FinishReason in CompletionStreamChoice to use typed *FinishReason
5. Type conversation entry Content as chat.Content instead of json.RawMessage
6. Graceful unknown type handling — UnknownEntry, UnknownEvent,
   UnknownChunk, UnknownMessage, UnknownAgentTool all return data
   instead of erroring on unrecognized discriminator values
7. Type agent tools with AgentTool sealed interface + UnmarshalAgentTool
8. Add pagination params to ListConversations and ListLibraries
9. Move openapi.yaml to docs/openapi.yaml
2026-03-05 20:51:24 +01:00

300 lines
7.5 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package mistral
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"math/rand/v2"
"mime/multipart"
"net/http"
"strconv"
"time"
)
// doRetry executes an HTTP request with retry logic.
//
// buildReq is called once per attempt so every attempt sends a fresh
// *http.Request (and therefore an unread body). Transport errors and
// responses for which shouldRetry reports true are retried, up to
// c.maxRetries additional attempts. A negative c.maxRetries is treated as 0,
// so at least one attempt is always made.
//
// Before each retry it waits for c.backoff(attempt), or for the previous
// response's Retry-After delay when that is longer. It returns ctx.Err() if
// ctx is done first. A transport error observed after ctx is done is not
// retried. Bodies of responses that are going to be retried are drained and
// closed here. The returned response (a success, a non-retryable status, or
// the final retryable one) is left open and must be closed by the caller.
func (c *Client) doRetry(ctx context.Context, buildReq func() (*http.Request, error)) (*http.Response, error) {
	retries := c.maxRetries
	if retries < 0 {
		retries = 0
	}
	maxAttempts := 1 + retries

	// lastResp is the previous retryable response. Its body is already
	// closed; only its headers are consulted (for Retry-After).
	var lastResp *http.Response
	for attempt := 0; attempt < maxAttempts; attempt++ {
		last := attempt == maxAttempts-1
		if attempt > 0 {
			delay := c.backoff(attempt)
			if lastResp != nil {
				if ra := retryAfterDelay(lastResp); ra > delay {
					delay = ra
				}
			}
			timer := time.NewTimer(delay)
			select {
			case <-ctx.Done():
				timer.Stop()
				return nil, ctx.Err()
			case <-timer.C:
			}
		}

		req, err := buildReq()
		if err != nil {
			return nil, fmt.Errorf("mistral: create request: %w", err)
		}
		resp, err := c.httpClient.Do(req)
		if err != nil {
			if last {
				return nil, fmt.Errorf("mistral: send request: %w", err)
			}
			// Every later attempt would fail the same way once ctx is done.
			if ctx.Err() != nil {
				return nil, ctx.Err()
			}
			// No response this time, so an older Retry-After no longer applies.
			lastResp = nil
			continue
		}
		if last || !shouldRetry(resp.StatusCode) {
			return resp, nil
		}
		// Best-effort drain so the transport can reuse the connection.
		_, _ = io.Copy(io.Discard, resp.Body)
		resp.Body.Close()
		lastResp = resp
	}
	// Unreachable: maxAttempts >= 1 and the final attempt always returns above.
	return nil, fmt.Errorf("mistral: send request: no attempts made")
}
// do sends an authenticated request for c.baseURL+path through doRetry.
//
// A non-nil body is read fully into memory up front so that every retry
// attempt can replay identical bytes. Requests carrying such a body are
// labelled with a JSON Content-Type. The Authorization, Accept and
// User-Agent headers are set on every attempt. The caller owns the returned
// response and must close its body.
func (c *Client) do(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
	var payload []byte
	if body != nil {
		buffered, err := io.ReadAll(body)
		if err != nil {
			return nil, fmt.Errorf("mistral: buffer request body: %w", err)
		}
		payload = buffered
	}
	hasBody := payload != nil

	build := func() (*http.Request, error) {
		// Keep reqBody an untyped nil interface when there is no payload so
		// the request is created without a body.
		var reqBody io.Reader
		if hasBody {
			reqBody = bytes.NewReader(payload)
		}
		req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, reqBody)
		if err != nil {
			return nil, err
		}
		hdr := req.Header
		hdr.Set("Authorization", "Bearer "+c.apiKey)
		hdr.Set("Accept", "application/json")
		hdr.Set("User-Agent", "mistral-go-sdk/"+Version)
		if hasBody {
			hdr.Set("Content-Type", "application/json")
		}
		return req, nil
	}
	return c.doRetry(ctx, build)
}
// doJSON sends reqBody (JSON-encoded when non-nil) to path and decodes the
// JSON response into respBody.
//
// respBody may be nil to ignore the response payload. A 204 No Content reply
// is accepted without decoding, because it carries no body by definition.
// Responses with status >= 400 are returned as the error produced by
// parseAPIError. Before the body is closed, a bounded amount of unread bytes
// (such as the trailing newline after the JSON value) is drained so the
// underlying connection can be reused.
func (c *Client) doJSON(ctx context.Context, method, path string, reqBody, respBody any) error {
	// Upper bound on how much unread response data is discarded before close.
	const maxDrain = 64 << 10

	var body io.Reader
	if reqBody != nil {
		data, err := json.Marshal(reqBody)
		if err != nil {
			return fmt.Errorf("mistral: marshal request: %w", err)
		}
		body = bytes.NewReader(data)
	}
	resp, err := c.do(ctx, method, path, body)
	if err != nil {
		return err
	}
	defer func() {
		_, _ = io.CopyN(io.Discard, resp.Body, maxDrain)
		resp.Body.Close()
	}()

	if resp.StatusCode >= 400 {
		return parseAPIError(resp)
	}
	if respBody == nil || resp.StatusCode == http.StatusNoContent {
		return nil
	}
	if err := json.NewDecoder(resp.Body).Decode(respBody); err != nil {
		return fmt.Errorf("mistral: decode response: %w", err)
	}
	return nil
}
// doStream sends reqBody (JSON-encoded when non-nil) and hands back the raw
// response so the caller can consume a streamed body incrementally.
//
// On success the caller must close resp.Body. A status >= 400 is converted
// into an error via parseAPIError, and the body is closed here.
func (c *Client) doStream(ctx context.Context, method, path string, reqBody any) (*http.Response, error) {
	var payload io.Reader
	if reqBody != nil {
		encoded, err := json.Marshal(reqBody)
		if err != nil {
			return nil, fmt.Errorf("mistral: marshal request: %w", err)
		}
		payload = bytes.NewReader(encoded)
	}

	resp, err := c.do(ctx, method, path, payload)
	switch {
	case err != nil:
		return nil, err
	case resp.StatusCode >= 400:
		defer resp.Body.Close()
		return nil, parseAPIError(resp)
	default:
		return resp, nil
	}
}
// doMultipart uploads file, together with any extra form fields, as a
// multipart/form-data POST to path and decodes the JSON reply into respBody.
//
// The whole form is encoded into memory first so every retry attempt resends
// identical bytes. respBody may be nil to ignore the reply payload. A status
// >= 400 is returned as the error produced by parseAPIError.
func (c *Client) doMultipart(ctx context.Context, path string, filename string, file io.Reader, fields map[string]string, respBody any) error {
	var form bytes.Buffer
	mw := multipart.NewWriter(&form)
	fw, err := mw.CreateFormFile("file", filename)
	if err != nil {
		return fmt.Errorf("mistral: create form file: %w", err)
	}
	if _, err = io.Copy(fw, file); err != nil {
		return fmt.Errorf("mistral: copy file data: %w", err)
	}
	for name, value := range fields {
		if err = mw.WriteField(name, value); err != nil {
			return fmt.Errorf("mistral: write field %s: %w", name, err)
		}
	}
	if err = mw.Close(); err != nil {
		return fmt.Errorf("mistral: close multipart: %w", err)
	}
	payload, contentType := form.Bytes(), mw.FormDataContentType()

	resp, err := c.doRetry(ctx, func() (*http.Request, error) {
		req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+path, bytes.NewReader(payload))
		if reqErr != nil {
			return nil, reqErr
		}
		hdr := req.Header
		hdr.Set("Authorization", "Bearer "+c.apiKey)
		hdr.Set("Accept", "application/json")
		hdr.Set("Content-Type", contentType)
		hdr.Set("User-Agent", "mistral-go-sdk/"+Version)
		return req, nil
	})
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 400 {
		return parseAPIError(resp)
	}
	if respBody == nil {
		return nil
	}
	if err = json.NewDecoder(resp.Body).Decode(respBody); err != nil {
		return fmt.Errorf("mistral: decode response: %w", err)
	}
	return nil
}
// doMultipartStream uploads file, together with any extra form fields, as a
// multipart/form-data POST to path while asking for a text/event-stream reply.
//
// The form is encoded into memory first so each retry resends identical
// bytes. On success the raw response is returned and the caller must close
// its body. A status >= 400 is turned into an error via parseAPIError, and
// the body is closed here.
func (c *Client) doMultipartStream(ctx context.Context, path string, filename string, file io.Reader, fields map[string]string) (*http.Response, error) {
	var form bytes.Buffer
	mw := multipart.NewWriter(&form)
	fw, err := mw.CreateFormFile("file", filename)
	if err != nil {
		return nil, fmt.Errorf("mistral: create form file: %w", err)
	}
	if _, err = io.Copy(fw, file); err != nil {
		return nil, fmt.Errorf("mistral: copy file data: %w", err)
	}
	for name, value := range fields {
		if err = mw.WriteField(name, value); err != nil {
			return nil, fmt.Errorf("mistral: write field %s: %w", name, err)
		}
	}
	if err = mw.Close(); err != nil {
		return nil, fmt.Errorf("mistral: close multipart: %w", err)
	}
	payload, contentType := form.Bytes(), mw.FormDataContentType()

	resp, err := c.doRetry(ctx, func() (*http.Request, error) {
		req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+path, bytes.NewReader(payload))
		if reqErr != nil {
			return nil, reqErr
		}
		hdr := req.Header
		hdr.Set("Authorization", "Bearer "+c.apiKey)
		hdr.Set("Accept", "text/event-stream")
		hdr.Set("Content-Type", contentType)
		hdr.Set("User-Agent", "mistral-go-sdk/"+Version)
		return req, nil
	})
	if err != nil {
		return nil, err
	}
	if resp.StatusCode < 400 {
		return resp, nil
	}
	defer resp.Body.Close()
	return nil, parseAPIError(resp)
}
// backoff computes the retry delay with exponential backoff and jitter.
//
// attempt is the 1-based retry number. The base delay is
// c.retryDelay * 2^(attempt-1), multiplied by a random jitter factor in
// [0.5, 1.5). Doubling saturates instead of overflowing, and the jittered
// result is clamped to the largest time.Duration, so a large retry count can
// never wrap around to a zero or negative delay. It returns 0 when
// c.retryDelay <= 0 or attempt < 1.
func (c *Client) backoff(attempt int) time.Duration {
	const maxDuration = time.Duration(1<<63 - 1)
	if c.retryDelay <= 0 || attempt < 1 {
		return 0
	}
	delay := c.retryDelay
	for i := 1; i < attempt; i++ {
		if delay > maxDuration/2 {
			delay = maxDuration
			break
		}
		delay *= 2
	}
	jittered := float64(delay) * (0.5 + rand.Float64()) // 0.5-1.5x
	// Converting a float64 >= 2^63 to int64 is implementation-defined in Go.
	if jittered >= float64(maxDuration) {
		return maxDuration
	}
	return time.Duration(jittered)
}
// shouldRetry reports whether a response with the given HTTP status code is
// worth retrying: 429 Too Many Requests, or any 5xx-or-higher server status.
func shouldRetry(statusCode int) bool {
	switch {
	case statusCode == http.StatusTooManyRequests:
		return true
	case statusCode >= http.StatusInternalServerError:
		return true
	default:
		return false
	}
}
// retryAfterDelay parses the Retry-After header.
//
// Both forms defined for the header are accepted: a number of seconds
// ("120") or an HTTP-date. It returns 0 in these cases:
//   - the header is absent or unparsable (including second counts that do
//     not fit in an int64);
//   - the value is zero or negative;
//   - the value names a time that has already passed.
//
// Second counts too large for a time.Duration saturate at the maximum
// Duration instead of overflowing into a wrong (possibly negative) value.
func retryAfterDelay(resp *http.Response) time.Duration {
	const maxDuration = time.Duration(1<<63 - 1)
	header := resp.Header.Get("Retry-After")
	if header == "" {
		return 0
	}
	if secs, err := strconv.ParseInt(header, 10, 64); err == nil {
		switch {
		case secs <= 0:
			return 0
		case secs > int64(maxDuration/time.Second):
			return maxDuration
		default:
			return time.Duration(secs) * time.Second
		}
	}
	if t, err := http.ParseTime(header); err == nil {
		if d := time.Until(t); d > 0 {
			return d
		}
	}
	return 0
}
// parseAPIError converts a failed (status >= 400) response into an *APIError.
//
// The body is read in full. When it is a JSON object, the "message" field
// (falling back to "detail") becomes the error message, and "type", "param"
// and "code" are copied over. Each field is decoded independently:
//   - string values are unquoted;
//   - null is treated as absent;
//   - any other JSON value is kept as its raw JSON text. Examples are a list
//     of validation errors in "detail" or a numeric "code".
//
// A field with an unexpected type therefore no longer causes the whole
// envelope to be discarded. If no message can be extracted, the raw body is
// used. An empty or whitespace-only body falls back to the standard HTTP
// status text. This function does not close resp.Body.
func parseAPIError(resp *http.Response) error {
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return &APIError{
			StatusCode: resp.StatusCode,
			Message:    fmt.Sprintf("failed to read error response: %v", err),
		}
	}
	apiErr := &APIError{StatusCode: resp.StatusCode, Message: string(body)}
	if len(bytes.TrimSpace(body)) == 0 {
		apiErr.Message = http.StatusText(resp.StatusCode)
		return apiErr
	}

	var envelope struct {
		Detail  json.RawMessage `json:"detail"`
		Message json.RawMessage `json:"message"`
		Type    json.RawMessage `json:"type"`
		Param   json.RawMessage `json:"param"`
		Code    json.RawMessage `json:"code"`
	}
	if err := json.Unmarshal(body, &envelope); err != nil {
		// Not a JSON object (plain text, an HTML proxy page, ...): keep the raw body.
		return apiErr
	}

	// text renders one envelope field. JSON strings are unquoted. A missing
	// field or null gives "" (unmarshalling null into a string is a no-op).
	// Any other JSON value is returned verbatim.
	text := func(raw json.RawMessage) string {
		if len(raw) == 0 {
			return ""
		}
		var s string
		if err := json.Unmarshal(raw, &s); err == nil {
			return s
		}
		return string(raw)
	}

	msg := text(envelope.Message)
	if msg == "" {
		msg = text(envelope.Detail)
	}
	if msg != "" {
		apiErr.Message = msg
	}
	apiErr.Type = text(envelope.Type)
	apiErr.Param = text(envelope.Param)
	apiErr.Code = text(envelope.Code)
	return apiErr
}