feat(ai): add process-wide 1 req/s rate limiter to Mistral client
This commit is contained in:
@@ -24,9 +24,11 @@ type Config struct {
|
||||
}
|
||||
|
||||
type AIConfig struct {
|
||||
APIKey string
|
||||
AgentSimple string // Pre-created Mistral agent ID for Pass 1 (extraction + web search)
|
||||
ModelComplex string // Model for Pass 2 (description + retry, e.g. mistral-large-latest)
|
||||
APIKey string
|
||||
AgentSimple string // Pre-created Mistral agent ID for Pass 1 (extraction + web search)
|
||||
AgentDiscovery string // Agent ID for discovery pipeline (Task 7)
|
||||
ModelComplex string // Model for Pass 2 (description + retry, e.g. mistral-large-latest)
|
||||
RateLimitRPS float64 // Max requests per second to Mistral (0 = disabled)
|
||||
}
|
||||
|
||||
type AppConfig struct {
|
||||
@@ -169,6 +171,11 @@ func Load() (*Config, error) {
|
||||
return nil, fmt.Errorf("SMTP_PORT: %w", err)
|
||||
}
|
||||
|
||||
rpsAI, err := envFloat("AI_RATE_LIMIT_RPS", 1.0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("AI_RATE_LIMIT_RPS: %w", err)
|
||||
}
|
||||
|
||||
jwtSecret := envStr("JWT_SECRET", "")
|
||||
if jwtSecret == "" {
|
||||
return nil, fmt.Errorf("JWT_SECRET is required")
|
||||
@@ -248,9 +255,11 @@ func Load() (*Config, error) {
|
||||
FrontendURL: envStr("FRONTEND_URL", "http://localhost:5173"),
|
||||
},
|
||||
AI: AIConfig{
|
||||
APIKey: envStr("AI_API_KEY", ""),
|
||||
AgentSimple: envStr("AI_AGENT_SIMPLE", ""),
|
||||
ModelComplex: envStr("AI_MODEL_COMPLEX", "mistral-large-latest"),
|
||||
APIKey: envStr("AI_API_KEY", ""),
|
||||
AgentSimple: envStr("AI_AGENT_SIMPLE", ""),
|
||||
AgentDiscovery: envStr("AI_AGENT_DISCOVERY", ""),
|
||||
ModelComplex: envStr("AI_MODEL_COMPLEX", "mistral-large-latest"),
|
||||
RateLimitRPS: rpsAI,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package ai
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk"
|
||||
@@ -10,13 +11,42 @@ import (
|
||||
"github.com/VikingOwl91/mistral-go-sdk/conversation"
|
||||
)
|
||||
|
||||
// rateLimiter enforces a minimum interval between calls. Set rps<=0 to disable.
|
||||
// Mirrors backend/internal/pkg/geocode/nominatim.go's lastReq gate.
|
||||
type rateLimiter struct {
|
||||
mu sync.Mutex
|
||||
lastReq time.Time
|
||||
minInterval time.Duration
|
||||
}
|
||||
|
||||
func newRateLimiter(rps float64) *rateLimiter {
|
||||
if rps <= 0 {
|
||||
return &rateLimiter{minInterval: 0}
|
||||
}
|
||||
return &rateLimiter{minInterval: time.Duration(float64(time.Second) / rps)}
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) wait() {
|
||||
if rl.minInterval == 0 {
|
||||
return
|
||||
}
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
since := time.Since(rl.lastReq)
|
||||
if since < rl.minInterval {
|
||||
time.Sleep(rl.minInterval - since)
|
||||
}
|
||||
rl.lastReq = time.Now()
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
sdk *mistral.Client
|
||||
agentSimple string
|
||||
modelComplex string
|
||||
limiter *rateLimiter
|
||||
}
|
||||
|
||||
func New(apiKey, agentSimple, modelComplex string) *Client {
|
||||
func New(apiKey, agentSimple, modelComplex string, rps float64) *Client {
|
||||
if modelComplex == "" {
|
||||
modelComplex = "mistral-large-latest"
|
||||
}
|
||||
@@ -33,6 +63,7 @@ func New(apiKey, agentSimple, modelComplex string) *Client {
|
||||
sdk: sdk,
|
||||
agentSimple: agentSimple,
|
||||
modelComplex: modelComplex,
|
||||
limiter: newRateLimiter(rps),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +85,7 @@ type PassResult struct {
|
||||
|
||||
// Pass1 uses the Conversations API to call the pre-created agent (with web search).
|
||||
func (c *Client) Pass1(ctx context.Context, prompt string) (PassResult, error) {
|
||||
c.limiter.wait()
|
||||
storeFalse := false
|
||||
resp, err := c.sdk.StartConversation(ctx, &conversation.StartRequest{
|
||||
AgentID: c.agentSimple,
|
||||
@@ -78,6 +110,7 @@ func (c *Client) Pass1(ctx context.Context, prompt string) (PassResult, error) {
|
||||
|
||||
// Pass2 uses chat completions for description generation + retry fields.
|
||||
func (c *Client) Pass2(ctx context.Context, systemPrompt, userPrompt string) (PassResult, error) {
|
||||
c.limiter.wait()
|
||||
resp, err := c.sdk.ChatComplete(ctx, &chat.CompletionRequest{
|
||||
Model: c.modelComplex,
|
||||
Messages: []chat.Message{
|
||||
|
||||
55
backend/internal/pkg/ai/rate_limiter_test.go
Normal file
55
backend/internal/pkg/ai/rate_limiter_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRateLimiterSerializesCalls(t *testing.T) {
|
||||
rl := newRateLimiter(2.0) // 2 req/s → minInterval 500ms
|
||||
var (
|
||||
mu sync.Mutex
|
||||
times []time.Time
|
||||
)
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 3; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
rl.wait()
|
||||
mu.Lock()
|
||||
times = append(times, time.Now())
|
||||
mu.Unlock()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// Sort times; gaps between consecutive must be >= 500ms - small tolerance.
|
||||
sortTimes(times)
|
||||
if gap := times[1].Sub(times[0]); gap < 450*time.Millisecond {
|
||||
t.Errorf("gap[0->1] = %v, want >= 450ms", gap)
|
||||
}
|
||||
if gap := times[2].Sub(times[1]); gap < 450*time.Millisecond {
|
||||
t.Errorf("gap[1->2] = %v, want >= 450ms", gap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterDisabledWhenRPSZero(t *testing.T) {
|
||||
rl := newRateLimiter(0) // disabled
|
||||
start := time.Now()
|
||||
for i := 0; i < 5; i++ {
|
||||
rl.wait()
|
||||
}
|
||||
if elapsed := time.Since(start); elapsed > 50*time.Millisecond {
|
||||
t.Errorf("expected no throttling when rps=0, elapsed %v", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func sortTimes(ts []time.Time) {
|
||||
for i := 1; i < len(ts); i++ {
|
||||
for j := i; j > 0 && ts[j].Before(ts[j-1]); j-- {
|
||||
ts[j], ts[j-1] = ts[j-1], ts[j]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -64,7 +64,7 @@ func (s *Server) registerRoutes() {
|
||||
|
||||
// Admin market routes
|
||||
adminMarketHandler := market.NewAdminHandler(marketSvc)
|
||||
aiClient := ai.New(s.cfg.AI.APIKey, s.cfg.AI.AgentSimple, s.cfg.AI.ModelComplex)
|
||||
aiClient := ai.New(s.cfg.AI.APIKey, s.cfg.AI.AgentSimple, s.cfg.AI.ModelComplex, s.cfg.AI.RateLimitRPS)
|
||||
researchHandler := market.NewResearchHandler(marketSvc, aiClient)
|
||||
requireAdmin := middleware.RequireRole("admin")
|
||||
market.RegisterAdminRoutes(v1, adminMarketHandler, researchHandler, requireAuth, requireAdmin)
|
||||
|
||||
Reference in New Issue
Block a user