feat(ai): add process-wide 1 req/s rate limiter to Mistral client

This commit is contained in:
2026-04-18 07:00:35 +02:00
parent aa965d292a
commit c95261d747
4 changed files with 105 additions and 8 deletions

View File

@@ -24,9 +24,11 @@ type Config struct {
}
// AIConfig groups the settings for the Mistral integration: the API key, the
// agent IDs and model name used by the AI passes, and the request rate limit.
type AIConfig struct {
APIKey string
AgentSimple string // Pre-created Mistral agent ID for Pass 1 (extraction + web search)
ModelComplex string // Model for Pass 2 (description + retry, e.g. mistral-large-latest)
APIKey string
AgentSimple string // Pre-created Mistral agent ID for Pass 1 (extraction + web search)
AgentDiscovery string // Agent ID for discovery pipeline (Task 7)
ModelComplex string // Model for Pass 2 (description + retry, e.g. mistral-large-latest)
RateLimitRPS float64 // Max requests per second to Mistral (0 = disabled)
}
type AppConfig struct {
@@ -169,6 +171,11 @@ func Load() (*Config, error) {
return nil, fmt.Errorf("SMTP_PORT: %w", err)
}
rpsAI, err := envFloat("AI_RATE_LIMIT_RPS", 1.0)
if err != nil {
return nil, fmt.Errorf("AI_RATE_LIMIT_RPS: %w", err)
}
jwtSecret := envStr("JWT_SECRET", "")
if jwtSecret == "" {
return nil, fmt.Errorf("JWT_SECRET is required")
@@ -248,9 +255,11 @@ func Load() (*Config, error) {
FrontendURL: envStr("FRONTEND_URL", "http://localhost:5173"),
},
AI: AIConfig{
APIKey: envStr("AI_API_KEY", ""),
AgentSimple: envStr("AI_AGENT_SIMPLE", ""),
ModelComplex: envStr("AI_MODEL_COMPLEX", "mistral-large-latest"),
APIKey: envStr("AI_API_KEY", ""),
AgentSimple: envStr("AI_AGENT_SIMPLE", ""),
AgentDiscovery: envStr("AI_AGENT_DISCOVERY", ""),
ModelComplex: envStr("AI_MODEL_COMPLEX", "mistral-large-latest"),
RateLimitRPS: rpsAI,
},
}, nil
}

View File

@@ -3,6 +3,7 @@ package ai
import (
"context"
"fmt"
"sync"
"time"
"github.com/VikingOwl91/mistral-go-sdk"
@@ -10,13 +11,42 @@ import (
"github.com/VikingOwl91/mistral-go-sdk/conversation"
)
// rateLimiter enforces a minimum interval between calls. Set rps<=0 to disable.
// Mirrors backend/internal/pkg/geocode/nominatim.go's lastReq gate.
//
// A nil *rateLimiter and a limiter whose minInterval is not positive are both
// treated as disabled: wait returns immediately for them.
type rateLimiter struct {
	// mu serializes callers of wait and guards lastReq.
	mu sync.Mutex
	// lastReq is the instant at which the most recent caller was released.
	lastReq time.Time
	// minInterval is the minimum spacing between released callers.
	minInterval time.Duration
}

// newRateLimiter builds a limiter that releases at most rps calls per second.
//
// A non-positive or NaN rps yields a disabled limiter. A positive rps so small
// that the resulting interval does not fit in a time.Duration is clamped to
// the largest representable interval instead of being converted blindly.
func newRateLimiter(rps float64) *rateLimiter {
	const maxInterval = time.Duration(1<<63 - 1)

	// Written as !(rps > 0) so that NaN, which fails every comparison, is
	// rejected together with zero and negative rates.
	if !(rps > 0) {
		return &rateLimiter{}
	}

	interval := float64(time.Second) / rps
	if interval >= float64(maxInterval) {
		// Converting an out-of-range float to an integer type gives an
		// implementation-dependent result (possibly negative, which would
		// silently disable throttling), so saturate explicitly.
		return &rateLimiter{minInterval: maxInterval}
	}
	return &rateLimiter{minInterval: time.Duration(interval)}
}

// wait blocks until at least minInterval has elapsed since the previous caller
// was released, then records the current time as the new release instant.
//
// The mutex is held while sleeping, so concurrent callers queue behind each
// other and are released one interval apart. wait takes no context and cannot
// be interrupted once it has started sleeping.
func (rl *rateLimiter) wait() {
	// A nil limiter (e.g. a Client assembled without New) or a non-positive
	// interval means throttling is off.
	if rl == nil || rl.minInterval <= 0 {
		return
	}

	rl.mu.Lock()
	defer rl.mu.Unlock()

	if elapsed := time.Since(rl.lastReq); elapsed < rl.minInterval {
		time.Sleep(rl.minInterval - elapsed)
	}
	rl.lastReq = time.Now()
}
// Client bundles everything the AI passes need to talk to Mistral: the SDK
// handle, the identifiers selecting which agent or model to call, and the
// limiter consulted before each request.
type Client struct {
	// sdk is the underlying Mistral SDK client that performs the requests.
	sdk *mistral.Client
	// agentSimple is sent as the agent ID by Pass1; modelComplex is sent as
	// the model name by Pass2.
	agentSimple, modelComplex string
	// limiter is waited on at the start of Pass1 and Pass2.
	limiter *rateLimiter
}
func New(apiKey, agentSimple, modelComplex string) *Client {
func New(apiKey, agentSimple, modelComplex string, rps float64) *Client {
if modelComplex == "" {
modelComplex = "mistral-large-latest"
}
@@ -33,6 +63,7 @@ func New(apiKey, agentSimple, modelComplex string) *Client {
sdk: sdk,
agentSimple: agentSimple,
modelComplex: modelComplex,
limiter: newRateLimiter(rps),
}
}
@@ -54,6 +85,7 @@ type PassResult struct {
// Pass1 uses the Conversations API to call the pre-created agent (with web search).
func (c *Client) Pass1(ctx context.Context, prompt string) (PassResult, error) {
c.limiter.wait()
storeFalse := false
resp, err := c.sdk.StartConversation(ctx, &conversation.StartRequest{
AgentID: c.agentSimple,
@@ -78,6 +110,7 @@ func (c *Client) Pass1(ctx context.Context, prompt string) (PassResult, error) {
// Pass2 uses chat completions for description generation + retry fields.
func (c *Client) Pass2(ctx context.Context, systemPrompt, userPrompt string) (PassResult, error) {
c.limiter.wait()
resp, err := c.sdk.ChatComplete(ctx, &chat.CompletionRequest{
Model: c.modelComplex,
Messages: []chat.Message{

View File

@@ -0,0 +1,55 @@
package ai
import (
	"sort"
	"sync"
	"testing"
	"time"
)
// TestRateLimiterSerializesCalls starts three goroutines that all call wait on
// a 2 req/s limiter at once and checks that the instants at which they are
// released are spaced roughly one interval apart.
func TestRateLimiterSerializesCalls(t *testing.T) {
	const callers = 3
	limiter := newRateLimiter(2.0) // 2 req/s → minInterval 500ms

	// Buffered to the number of callers so no goroutine blocks on reporting.
	released := make(chan time.Time, callers)

	var group sync.WaitGroup
	group.Add(callers)
	for n := 0; n < callers; n++ {
		go func() {
			defer group.Done()
			limiter.wait()
			released <- time.Now()
		}()
	}
	group.Wait()
	close(released)

	stamps := make([]time.Time, 0, callers)
	for stamp := range released {
		stamps = append(stamps, stamp)
	}

	// Goroutines report in no particular order, so sort before measuring.
	// Consecutive gaps must be >= 500ms minus a small tolerance.
	sortTimes(stamps)
	firstGap := stamps[1].Sub(stamps[0])
	secondGap := stamps[2].Sub(stamps[1])
	if firstGap < 450*time.Millisecond {
		t.Errorf("gap[0->1] = %v, want >= 450ms", firstGap)
	}
	if secondGap < 450*time.Millisecond {
		t.Errorf("gap[1->2] = %v, want >= 450ms", secondGap)
	}
}
// TestRateLimiterDisabledWhenRPSZero checks that a limiter built with rps=0
// lets a burst of calls through without any noticeable delay.
func TestRateLimiterDisabledWhenRPSZero(t *testing.T) {
	limiter := newRateLimiter(0) // disabled

	began := time.Now()
	for remaining := 5; remaining > 0; remaining-- {
		limiter.wait()
	}
	elapsed := time.Since(began)

	// Five unthrottled calls should finish almost instantly; 50ms is slack.
	if elapsed <= 50*time.Millisecond {
		return
	}
	t.Errorf("expected no throttling when rps=0, elapsed %v", elapsed)
}
// sortTimes reorders ts in place so that earlier instants come first.
// A nil or single-element slice is left as is.
func sortTimes(ts []time.Time) {
	sort.Slice(ts, func(i, j int) bool {
		return ts[i].Before(ts[j])
	})
}

View File

@@ -64,7 +64,7 @@ func (s *Server) registerRoutes() {
// Admin market routes
adminMarketHandler := market.NewAdminHandler(marketSvc)
aiClient := ai.New(s.cfg.AI.APIKey, s.cfg.AI.AgentSimple, s.cfg.AI.ModelComplex)
aiClient := ai.New(s.cfg.AI.APIKey, s.cfg.AI.AgentSimple, s.cfg.AI.ModelComplex, s.cfg.AI.RateLimitRPS)
researchHandler := market.NewResearchHandler(marketSvc, aiClient)
requireAdmin := middleware.RequireRole("admin")
market.RegisterAdminRoutes(v1, adminMarketHandler, researchHandler, requireAuth, requireAdmin)