From c95261d7479dbbf1a23f7d74436d143ae562c435 Mon Sep 17 00:00:00 2001 From: vikingowl Date: Sat, 18 Apr 2026 07:00:35 +0200 Subject: [PATCH] feat(ai): add process-wide 1 req/s rate limiter to Mistral client --- backend/internal/config/config.go | 21 +++++--- backend/internal/pkg/ai/client.go | 35 ++++++++++++- backend/internal/pkg/ai/rate_limiter_test.go | 55 ++++++++++++++++++++ backend/internal/server/routes.go | 2 +- 4 files changed, 105 insertions(+), 8 deletions(-) create mode 100644 backend/internal/pkg/ai/rate_limiter_test.go diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 7f4c362..b77aa61 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -24,9 +24,11 @@ type Config struct { } type AIConfig struct { - APIKey string - AgentSimple string // Pre-created Mistral agent ID for Pass 1 (extraction + web search) - ModelComplex string // Model for Pass 2 (description + retry, e.g. mistral-large-latest) + APIKey string + AgentSimple string // Pre-created Mistral agent ID for Pass 1 (extraction + web search) + AgentDiscovery string // Agent ID for discovery pipeline (Task 7) + ModelComplex string // Model for Pass 2 (description + retry, e.g. mistral-large-latest) + RateLimitRPS float64 // Max requests per second to Mistral (0 = disabled) } type AppConfig struct { @@ -169,6 +171,11 @@ func Load() (*Config, error) { return nil, fmt.Errorf("SMTP_PORT: %w", err) } + rpsAI, err := envFloat("AI_RATE_LIMIT_RPS", 1.0) + if err != nil { + return nil, fmt.Errorf("AI_RATE_LIMIT_RPS: %w", err) + } + jwtSecret := envStr("JWT_SECRET", "") if jwtSecret == "" { return nil, fmt.Errorf("JWT_SECRET is required") @@ -248,9 +255,11 @@ func Load() (*Config, error) { FrontendURL: envStr("FRONTEND_URL", "http://localhost:5173"), }, AI: AIConfig{ - APIKey: envStr("AI_API_KEY", ""), - AgentSimple: envStr("AI_AGENT_SIMPLE", ""), - ModelComplex: envStr("AI_MODEL_COMPLEX", "mistral-large-latest"), + APIKey: envStr("AI_API_KEY", ""), + AgentSimple: envStr("AI_AGENT_SIMPLE", ""), + AgentDiscovery: envStr("AI_AGENT_DISCOVERY", ""), + ModelComplex: envStr("AI_MODEL_COMPLEX", "mistral-large-latest"), + RateLimitRPS: rpsAI, }, }, nil } diff --git a/backend/internal/pkg/ai/client.go b/backend/internal/pkg/ai/client.go index f78da70..fe17eab 100644 --- a/backend/internal/pkg/ai/client.go +++ b/backend/internal/pkg/ai/client.go @@ -3,6 +3,7 @@ package ai import ( "context" "fmt" + "sync" "time" "github.com/VikingOwl91/mistral-go-sdk" @@ -10,13 +11,42 @@ import ( "github.com/VikingOwl91/mistral-go-sdk/conversation" ) +// rateLimiter enforces a minimum interval between calls. Set rps<=0 to disable. +// Mirrors backend/internal/pkg/geocode/nominatim.go's lastReq gate. +type rateLimiter struct { + mu sync.Mutex + lastReq time.Time + minInterval time.Duration +} + +func newRateLimiter(rps float64) *rateLimiter { + if rps <= 0 { + return &rateLimiter{minInterval: 0} + } + return &rateLimiter{minInterval: time.Duration(float64(time.Second) / rps)} +} + +func (rl *rateLimiter) wait() { + if rl.minInterval == 0 { + return + } + rl.mu.Lock() + defer rl.mu.Unlock() + since := time.Since(rl.lastReq) + if since < rl.minInterval { + time.Sleep(rl.minInterval - since) + } + rl.lastReq = time.Now() +} + type Client struct { sdk *mistral.Client agentSimple string modelComplex string + limiter *rateLimiter } -func New(apiKey, agentSimple, modelComplex string) *Client { +func New(apiKey, agentSimple, modelComplex string, rps float64) *Client { if modelComplex == "" { modelComplex = "mistral-large-latest" } @@ -33,6 +63,7 @@ func New(apiKey, agentSimple, modelComplex string) *Client { sdk: sdk, agentSimple: agentSimple, modelComplex: modelComplex, + limiter: newRateLimiter(rps), } } @@ -54,6 +85,7 @@ type PassResult struct { // Pass1 uses the Conversations API to call the pre-created agent (with web search). func (c *Client) Pass1(ctx context.Context, prompt string) (PassResult, error) { + c.limiter.wait() storeFalse := false resp, err := c.sdk.StartConversation(ctx, &conversation.StartRequest{ AgentID: c.agentSimple, @@ -78,6 +110,7 @@ func (c *Client) Pass1(ctx context.Context, prompt string) (PassResult, error) { // Pass2 uses chat completions for description generation + retry fields. func (c *Client) Pass2(ctx context.Context, systemPrompt, userPrompt string) (PassResult, error) { + c.limiter.wait() resp, err := c.sdk.ChatComplete(ctx, &chat.CompletionRequest{ Model: c.modelComplex, Messages: []chat.Message{ diff --git a/backend/internal/pkg/ai/rate_limiter_test.go b/backend/internal/pkg/ai/rate_limiter_test.go new file mode 100644 index 0000000..328f1e2 --- /dev/null +++ b/backend/internal/pkg/ai/rate_limiter_test.go @@ -0,0 +1,55 @@ +package ai + +import ( + "sync" + "testing" + "time" +) + +func TestRateLimiterSerializesCalls(t *testing.T) { + rl := newRateLimiter(2.0) // 2 req/s → minInterval 500ms + var ( + mu sync.Mutex + times []time.Time + ) + var wg sync.WaitGroup + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + rl.wait() + mu.Lock() + times = append(times, time.Now()) + mu.Unlock() + }() + } + wg.Wait() + + // Sort times; gaps between consecutive must be >= 500ms - small tolerance. + sortTimes(times) + if gap := times[1].Sub(times[0]); gap < 450*time.Millisecond { + t.Errorf("gap[0->1] = %v, want >= 450ms", gap) + } + if gap := times[2].Sub(times[1]); gap < 450*time.Millisecond { + t.Errorf("gap[1->2] = %v, want >= 450ms", gap) + } +} + +func TestRateLimiterDisabledWhenRPSZero(t *testing.T) { + rl := newRateLimiter(0) // disabled + start := time.Now() + for i := 0; i < 5; i++ { + rl.wait() + } + if elapsed := time.Since(start); elapsed > 50*time.Millisecond { + t.Errorf("expected no throttling when rps=0, elapsed %v", elapsed) + } +} + +func sortTimes(ts []time.Time) { + for i := 1; i < len(ts); i++ { + for j := i; j > 0 && ts[j].Before(ts[j-1]); j-- { + ts[j], ts[j-1] = ts[j-1], ts[j] + } + } +} diff --git a/backend/internal/server/routes.go b/backend/internal/server/routes.go index 9d54155..f3764b9 100644 --- a/backend/internal/server/routes.go +++ b/backend/internal/server/routes.go @@ -64,7 +64,7 @@ func (s *Server) registerRoutes() { // Admin market routes adminMarketHandler := market.NewAdminHandler(marketSvc) - aiClient := ai.New(s.cfg.AI.APIKey, s.cfg.AI.AgentSimple, s.cfg.AI.ModelComplex) + aiClient := ai.New(s.cfg.AI.APIKey, s.cfg.AI.AgentSimple, s.cfg.AI.ModelComplex, s.cfg.AI.RateLimitRPS) researchHandler := market.NewResearchHandler(marketSvc, aiClient) requireAdmin := middleware.RequireRole("admin") market.RegisterAdminRoutes(v1, adminMarketHandler, researchHandler, requireAuth, requireAdmin)