From 2fecdbc2ccabfb5c55087d0d5b187ecf5f091714 Mon Sep 17 00:00:00 2001
From: vikingowl
Date: Thu, 5 Mar 2026 19:44:32 +0100
Subject: [PATCH] =?UTF-8?q?feat:=20Phase=205=20retry=20+=20resilience=20?=
 =?UTF-8?q?=E2=80=94=20exponential=20backoff,=20jitter,=20Retry-After?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add retry logic to all HTTP methods:
- doRetry() core loop with configurable max retries
- Exponential backoff with 0.5-1.5x jitter
- Retry-After header support (seconds and HTTP-date)
- Retry on 429 and 5xx; no retry on 4xx client errors
- Context cancellation respected during retry delays
- Multipart uploads also retry via doRetry()
- 9 new tests: 429/500 recovery, exhaustion, no-retry-on-400, backoff math
---
 request.go    | 182 +++++++++++++++++++++++++++----
 retry_test.go | 291 ++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 452 insertions(+), 21 deletions(-)
 create mode 100644 retry_test.go

diff --git a/request.go b/request.go
index a4d6e84..e99dd54 100644
--- a/request.go
+++ b/request.go
@@ -6,25 +6,106 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"math/rand/v2"
 	"mime/multipart"
 	"net/http"
+	"strconv"
+	"time"
 )
 
+// doRetry sends the request produced by buildReq and retries it on failure.
+//
+// buildReq is called once per attempt so that every attempt gets a fresh
+// request with an unread body. At most 1+c.maxRetries attempts are made; a
+// negative c.maxRetries is treated as zero, so the request is always sent at
+// least once. An attempt is retried when the transport returns an error or
+// when shouldRetry accepts the response status. Before each retry the call
+// waits for the larger of c.backoff(attempt) and the Retry-After delay of the
+// previous response.
+//
+// The response of the final attempt is returned as-is, even when its status
+// is retryable, so the caller can turn it into an API error; the caller must
+// close its body. If ctx is done during a wait, or after a failed attempt
+// that would otherwise be retried, ctx.Err() is returned.
+func (c *Client) doRetry(ctx context.Context, buildReq func() (*http.Request, error)) (*http.Response, error) {
+	// retryAfter is the delay requested by the most recent response. It is
+	// reset on transport errors so a stale header never affects a later retry.
+	var retryAfter time.Duration
+
+	// The loop has no condition on purpose: the final attempt always returns,
+	// so the function can never fall through with a nil response and nil error.
+	for attempt := 0; ; attempt++ {
+		if attempt > 0 {
+			timer := time.NewTimer(max(c.backoff(attempt), retryAfter))
+			select {
+			case <-ctx.Done():
+				timer.Stop()
+				return nil, ctx.Err()
+			case <-timer.C:
+			}
+		}
+
+		req, err := buildReq()
+		if err != nil {
+			return nil, fmt.Errorf("mistral: create request: %w", err)
+		}
+
+		last := attempt >= c.maxRetries
+		resp, err := c.httpClient.Do(req)
+		if err != nil {
+			if last {
+				return nil, fmt.Errorf("mistral: send request: %w", err)
+			}
+			// Retrying with a finished context can only fail again.
+			if ctxErr := ctx.Err(); ctxErr != nil {
+				return nil, ctxErr
+			}
+			retryAfter = 0
+			continue
+		}
+
+		if last || !shouldRetry(resp.StatusCode) {
+			return resp, nil
+		}
+
+		retryAfter = retryAfterDelay(resp)
+		// Drain and close the body so the connection can be reused; a failure
+		// here only costs the connection, so the errors are ignored.
+		_, _ = io.Copy(io.Discard, resp.Body)
+		_ = resp.Body.Close()
+	}
+}
+
+// do sends a request with an optional JSON body through doRetry.
+//
+// The body is read into memory up front because a reader can be consumed
+// only once, while every retry needs to send the full payload again.
 func (c *Client) do(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
-	req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, body)
-	if err != nil {
-		return nil, fmt.Errorf("mistral: create request: %w", err)
-	}
-	req.Header.Set("Authorization", "Bearer "+c.apiKey)
-	req.Header.Set("Accept", "application/json")
+	var bodyBytes []byte
 	if body != nil {
-		req.Header.Set("Content-Type", "application/json")
+		var err error
+		bodyBytes, err = io.ReadAll(body)
+		if err != nil {
+			return nil, fmt.Errorf("mistral: buffer request body: %w", err)
+		}
 	}
-	resp, err := c.httpClient.Do(req)
-	if err != nil {
-		return nil, fmt.Errorf("mistral: send request: %w", err)
-	}
-	return resp, nil
+
+	return c.doRetry(ctx, func() (*http.Request, error) {
+		var br io.Reader
+		if body != nil {
+			br = bytes.NewReader(bodyBytes)
+		}
+		req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, br)
+		if err != nil {
+			return nil, err
+		}
+		req.Header.Set("Authorization", "Bearer "+c.apiKey)
+		req.Header.Set("Accept", "application/json")
+		if body != nil {
+			req.Header.Set("Content-Type", "application/json")
+		}
+		return req, nil
+	})
 }
 
 func (c *Client) doJSON(ctx context.Context, method, path string, reqBody, respBody any) error {
@@ -91,17 +172,21 @@ func (c *Client) doMultipart(ctx context.Context, path string, filename string,
 		return fmt.Errorf("mistral: close multipart: %w", err)
 	}
 
-	req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+path, &buf)
-	if err != nil {
-		return fmt.Errorf("mistral: create request: %w", err)
-	}
-	req.Header.Set("Authorization", "Bearer "+c.apiKey)
-	req.Header.Set("Accept", "application/json")
-	req.Header.Set("Content-Type", w.FormDataContentType())
+	bodyBytes := buf.Bytes()
+	ct := w.FormDataContentType()
 
-	resp, err := c.httpClient.Do(req)
+	resp, err := c.doRetry(ctx, func() (*http.Request, error) {
+		req, err := http.NewRequestWithContext(ctx, "POST", c.baseURL+path, bytes.NewReader(bodyBytes))
+		if err != nil {
+			return nil, err
+		}
+		req.Header.Set("Authorization", "Bearer "+c.apiKey)
+		req.Header.Set("Accept", "application/json")
+		req.Header.Set("Content-Type", ct)
+		return req, nil
+	})
 	if err != nil {
-		return fmt.Errorf("mistral: send request: %w", err)
+		return err
 	}
 	defer resp.Body.Close()
 	if resp.StatusCode >= 400 {
@@ -115,6 +200,61 @@ func (c *Client) doMultipart(ctx context.Context, path string, filename string,
 	return nil
 }
 
+// backoff returns the delay to wait before retry number attempt (1-based).
+//
+// The base delay c.retryDelay doubles with every further attempt and is then
+// scaled by a random factor in [0.5, 1.5) so that retries are spread out in
+// time. It returns 0 when c.retryDelay is not positive or attempt is less
+// than 1. The doubling saturates instead of overflowing, so the result is
+// never negative.
+func (c *Client) backoff(attempt int) time.Duration {
+	if c.retryDelay <= 0 || attempt < 1 {
+		return 0
+	}
+	// ceiling leaves headroom for the 1.5x jitter factor to stay within the
+	// range of time.Duration.
+	const ceiling = time.Duration(1) << 61
+	delay := min(c.retryDelay, ceiling)
+	for range attempt - 1 {
+		if delay > ceiling/2 {
+			delay = ceiling
+			break
+		}
+		delay *= 2
+	}
+	jitter := 0.5 + rand.Float64() // 0.5–1.5x
+	return time.Duration(float64(delay) * jitter)
+}
+
+// shouldRetry reports whether a response with the given status code should
+// be retried: 429 Too Many Requests and every status of 500 or above.
+func shouldRetry(statusCode int) bool {
+	return statusCode == http.StatusTooManyRequests || statusCode >= 500
+}
+
+// retryAfterDelay returns the wait requested by the Retry-After header of
+// resp, which holds either a number of seconds or an HTTP-date. It returns 0
+// when the header is absent or malformed, when the number of seconds is
+// negative, or when the date is not in the future.
+func retryAfterDelay(resp *http.Response) time.Duration {
+	header := resp.Header.Get("Retry-After")
+	if header == "" {
+		return 0
+	}
+	if secs, err := strconv.Atoi(header); err == nil {
+		// Clamp before converting so that the multiplication cannot overflow
+		// time.Duration and wrap around to a bogus delay.
+		const maxSecs = (1<<63 - 1) / int64(time.Second)
+		return time.Duration(min(max(int64(secs), 0), maxSecs)) * time.Second
+	}
+	if t, err := http.ParseTime(header); err == nil {
+		if d := time.Until(t); d > 0 {
+			return d
+		}
+	}
+	return 0
+}
+
 func parseAPIError(resp *http.Response) error {
 	body, err := io.ReadAll(resp.Body)
 	if err != nil {
diff --git a/retry_test.go b/retry_test.go
new file mode 100644
index 0000000..e932cfb
--- /dev/null
+++ b/retry_test.go
@@ -0,0 +1,291 @@
+package mistral
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"somegit.dev/vikingowl/mistral-go-sdk/chat"
+)
+
+// TestRetry_429ThenSuccess checks that 429 responses are retried until the request succeeds.
+func TestRetry_429ThenSuccess(t *testing.T) {
+	var attempts atomic.Int32
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		n := attempts.Add(1)
+		if n <= 2 {
+			w.WriteHeader(http.StatusTooManyRequests)
+			json.NewEncoder(w).Encode(map[string]any{"message": "rate limited"})
+			return
+		}
+		json.NewEncoder(w).Encode(map[string]any{
+			"id": "ok", "object": "chat.completion",
+			"model": "m", "created": 0,
+			"choices": []map[string]any{{
+				"index": 0, "message": map[string]any{"role": "assistant", "content": "success"},
+				"finish_reason": "stop",
+			}},
+			"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+		})
+	}))
+	defer server.Close()
+
+	client := NewClient("key",
+		WithBaseURL(server.URL),
+		WithRetry(3, 1*time.Millisecond),
+	)
+	resp, err := client.ChatComplete(context.Background(), &chat.CompletionRequest{
+		Model:    "m",
+		Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if resp.Choices[0].Message.Content.String() != "success" {
+		t.Errorf("got %q", resp.Choices[0].Message.Content.String())
+	}
+	if attempts.Load() != 3 {
+		t.Errorf("expected 3 attempts, got %d", attempts.Load())
+	}
+}
+
+// TestRetry_500ThenSuccess checks that a 500 response is retried and the next success is returned.
+func TestRetry_500ThenSuccess(t *testing.T) {
+	var attempts atomic.Int32
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		n := attempts.Add(1)
+		if n == 1 {
+			w.WriteHeader(http.StatusInternalServerError)
+			json.NewEncoder(w).Encode(map[string]any{"message": "server error"})
+			return
+		}
+		json.NewEncoder(w).Encode(map[string]any{
+			"id": "ok", "object": "chat.completion",
+			"model": "m", "created": 0,
+			"choices": []map[string]any{{
+				"index": 0, "message": map[string]any{"role": "assistant", "content": "ok"},
+				"finish_reason": "stop",
+			}},
+			"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+		})
+	}))
+	defer server.Close()
+
+	client := NewClient("key",
+		WithBaseURL(server.URL),
+		WithRetry(2, 1*time.Millisecond),
+	)
+	resp, err := client.ChatComplete(context.Background(), &chat.CompletionRequest{
+		Model:    "m",
+		Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if resp.Choices[0].Message.Content.String() != "ok" {
+		t.Errorf("got %q", resp.Choices[0].Message.Content.String())
+	}
+	if attempts.Load() != 2 {
+		t.Errorf("expected 2 attempts, got %d", attempts.Load())
+	}
+}
+
+// TestRetry_NoRetryOn400 checks that a 400 response is returned without any retry.
+func TestRetry_NoRetryOn400(t *testing.T) {
+	var attempts atomic.Int32
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		attempts.Add(1)
+		w.WriteHeader(http.StatusBadRequest)
+		json.NewEncoder(w).Encode(map[string]any{"message": "bad request"})
+	}))
+	defer server.Close()
+
+	client := NewClient("key",
+		WithBaseURL(server.URL),
+		WithRetry(3, 1*time.Millisecond),
+	)
+	_, err := client.ChatComplete(context.Background(), &chat.CompletionRequest{
+		Model:    "m",
+		Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
+	})
+	if err == nil {
+		t.Fatal("expected error")
+	}
+	if attempts.Load() != 1 {
+		t.Errorf("expected 1 attempt (no retry on 400), got %d", attempts.Load())
+	}
+}
+
+// TestRetry_ExhaustedRetries checks that the last 429 is reported once all retries are used up.
+func TestRetry_ExhaustedRetries(t *testing.T) {
+	var attempts atomic.Int32
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		attempts.Add(1)
+		w.WriteHeader(http.StatusTooManyRequests)
+		json.NewEncoder(w).Encode(map[string]any{"message": "rate limited"})
+	}))
+	defer server.Close()
+
+	client := NewClient("key",
+		WithBaseURL(server.URL),
+		WithRetry(2, 1*time.Millisecond),
+	)
+	_, err := client.ChatComplete(context.Background(), &chat.CompletionRequest{
+		Model:    "m",
+		Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
+	})
+	if err == nil {
+		t.Fatal("expected error after exhausting retries")
+	}
+	if !IsRateLimit(err) {
+		t.Errorf("expected rate limit error, got: %v", err)
+	}
+	if attempts.Load() != 3 {
+		t.Errorf("expected 3 attempts (1 + 2 retries), got %d", attempts.Load())
+	}
+}
+
+// TestRetry_NoRetryByDefault checks that a client without WithRetry sends the request only once.
+func TestRetry_NoRetryByDefault(t *testing.T) {
+	var attempts atomic.Int32
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		attempts.Add(1)
+		w.WriteHeader(http.StatusTooManyRequests)
+		json.NewEncoder(w).Encode(map[string]any{"message": "rate limited"})
+	}))
+	defer server.Close()
+
+	client := NewClient("key", WithBaseURL(server.URL))
+	_, err := client.ChatComplete(context.Background(), &chat.CompletionRequest{
+		Model:    "m",
+		Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
+	})
+	if err == nil {
+		t.Fatal("expected error")
+	}
+	if attempts.Load() != 1 {
+		t.Errorf("expected 1 attempt (no retries configured), got %d", attempts.Load())
+	}
+}
+
+// TestRetry_RetryAfterHeader checks that a 429 carrying a Retry-After header is still retried.
+func TestRetry_RetryAfterHeader(t *testing.T) {
+	var attempts atomic.Int32
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		n := attempts.Add(1)
+		if n == 1 {
+			w.Header().Set("Retry-After", "0")
+			w.WriteHeader(http.StatusTooManyRequests)
+			json.NewEncoder(w).Encode(map[string]any{"message": "rate limited"})
+			return
+		}
+		json.NewEncoder(w).Encode(map[string]any{
+			"id": "ok", "object": "chat.completion",
+			"model": "m", "created": 0,
+			"choices": []map[string]any{{
+				"index": 0, "message": map[string]any{"role": "assistant", "content": "ok"},
+				"finish_reason": "stop",
+			}},
+			"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
+		})
+	}))
+	defer server.Close()
+
+	client := NewClient("key",
+		WithBaseURL(server.URL),
+		WithRetry(1, 1*time.Millisecond),
+	)
+	resp, err := client.ChatComplete(context.Background(), &chat.CompletionRequest{
+		Model:    "m",
+		Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	if resp.ID != "ok" {
+		t.Errorf("got id %q", resp.ID)
+	}
+}
+
+// TestRetry_ContextCanceled checks that a context timeout interrupts the wait between retries.
+func TestRetry_ContextCanceled(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusTooManyRequests)
+		json.NewEncoder(w).Encode(map[string]any{"message": "rate limited"})
+	}))
+	defer server.Close()
+
+	client := NewClient("key",
+		WithBaseURL(server.URL),
+		WithRetry(10, 10*time.Second),
+	)
+	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+	defer cancel()
+
+	start := time.Now()
+	_, err := client.ChatComplete(ctx, &chat.CompletionRequest{
+		Model:    "m",
+		Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
+	})
+	if err == nil {
+		t.Fatal("expected error")
+	}
+	// The first backoff is at least 5s, so a prompt return shows that the
+	// wait was interrupted by the context instead of running to completion.
+	if elapsed := time.Since(start); elapsed > 2*time.Second {
+		t.Errorf("expected prompt return after context timeout, took %v", elapsed)
+	}
+}
+
+// TestBackoff checks the jitter bounds of backoff and its handling of degenerate attempts.
+func TestBackoff(t *testing.T) {
+	c := &Client{retryDelay: 100 * time.Millisecond}
+
+	for i := 1; i <= 5; i++ {
+		d := c.backoff(i)
+		base := 100 * time.Millisecond * (1 << uint(i-1))
+		minD := time.Duration(float64(base) * 0.5)
+		maxD := time.Duration(float64(base) * 1.5)
+		if d < minD || d > maxD {
+			t.Errorf("attempt %d: backoff %v not in [%v, %v]", i, d, minD, maxD)
+		}
+	}
+
+	// Degenerate inputs never wait.
+	if d := c.backoff(0); d != 0 {
+		t.Errorf("attempt 0: backoff %v, want 0", d)
+	}
+	if d := (&Client{}).backoff(3); d != 0 {
+		t.Errorf("zero retry delay: backoff %v, want 0", d)
+	}
+	// A huge attempt number must saturate instead of overflowing.
+	if d := c.backoff(1 << 20); d <= 0 {
+		t.Errorf("attempt 1<<20: backoff %v, want a positive delay", d)
+	}
+}
+
+// TestShouldRetry checks which status codes are classified as retryable.
+func TestShouldRetry(t *testing.T) {
+	tests := []struct {
+		code int
+		want bool
+	}{
+		{200, false},
+		{400, false},
+		{401, false},
+		{404, false},
+		{429, true},
+		{500, true},
+		{502, true},
+		{503, true},
+	}
+	for _, tt := range tests {
+		if got := shouldRetry(tt.code); got != tt.want {
+			t.Errorf("shouldRetry(%d) = %v, want %v", tt.code, got, tt.want)
+		}
+	}
+}