feat(ai): pluggable provider interface, Ollama + Mistral impls, migrate Pass2 sites

Replaces the Mistral-only ai.Client with an ai.Provider interface backed by
Ollama and Mistral implementations. Migrates enrichment + similarity callers
to ai.Provider.Chat. Research endpoint returns 501 until commit 2 reinstates
it on the new orchestrator.
This commit is contained in:
2026-04-24 16:35:18 +02:00
parent 2adb4882c7
commit 24e072b63d
26 changed files with 982 additions and 1078 deletions

View File

@@ -29,6 +29,7 @@ import (
"os"
"time"
"marktvogt.de/backend/internal/config"
"marktvogt.de/backend/internal/domain/discovery/enrich"
"marktvogt.de/backend/internal/pkg/ai"
"marktvogt.de/backend/internal/pkg/scrape"
@@ -65,8 +66,14 @@ func realMain() int {
)
flag.Parse()
apiKey := os.Getenv("AI_API_KEY")
model := os.Getenv("AI_MODEL_COMPLEX")
apiKey := os.Getenv("AI_MISTRAL_API_KEY")
if apiKey == "" {
apiKey = os.Getenv("AI_API_KEY") // legacy fallback
}
model := os.Getenv("AI_MISTRAL_MODEL")
if model == "" {
model = os.Getenv("AI_MODEL_COMPLEX") // legacy fallback
}
if model == "" {
model = "mistral-large-latest"
}
@@ -74,9 +81,14 @@ func realMain() int {
if userAgent == "" {
userAgent = "marktvogt-eval/1.0 (+https://marktvogt.de)"
}
client := ai.New(apiKey, "", model, 1.0)
if !client.Enabled() {
slog.Error("AI client not configured (set AI_API_KEY)")
client, err := ai.NewFromConfig(config.AIConfig{
Provider: "mistral",
MistralAPIKey: apiKey,
MistralModel: model,
RateLimitRPS: 1.0,
})
if err != nil {
slog.Error("AI client not configured", "error", err)
return 2
}
@@ -95,7 +107,7 @@ func realMain() int {
return runSimilarityMode(ctx, client, cfg)
case modeCategory:
scraper := scrape.New(userAgent)
enricher := enrich.NewMistralLLMEnricher(client, scraper)
enricher := enrich.NewLLMEnricher(client, scraper)
cfg := evalConfig{
model: model,
fixturePath: pathWithDefault(*fixturePath, "backend/cmd/discovery-eval/fixtures/category.json"),
@@ -117,7 +129,7 @@ func pathWithDefault(p, dflt string) string {
return p
}
func runSimilarityMode(ctx context.Context, client *ai.Client, cfg evalConfig) int {
func runSimilarityMode(ctx context.Context, client ai.Provider, cfg evalConfig) int {
fixture, err := loadFixture(cfg.fixturePath)
if err != nil {
slog.Error("load fixture", "path", cfg.fixturePath, "error", err)
@@ -125,7 +137,7 @@ func runSimilarityMode(ctx context.Context, client *ai.Client, cfg evalConfig) i
}
slog.Info("loaded fixture", "mode", modeSimilarity, "pairs", len(fixture.Pairs), "path", cfg.fixturePath)
classifier := enrich.NewMistralSimilarityClassifier(client)
classifier := enrich.NewSimilarityClassifier(client)
cache, err := loadCache(cfg.cachePath)
if err != nil {

View File

@@ -11,6 +11,7 @@ require (
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v5 v5.8.0
github.com/pquerna/otp v1.5.0
github.com/santhosh-tekuri/jsonschema/v5 v5.3.1
github.com/valkey-io/valkey-go v1.0.72
golang.org/x/crypto v0.49.0
golang.org/x/oauth2 v0.35.0

View File

@@ -73,6 +73,8 @@ github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
github.com/quic-go/quic-go v0.57.0 h1:AsSSrrMs4qI/hLrKlTH/TGQeTMY0ib1pAOX7vA3AdqE=
github.com/quic-go/quic-go v0.57.0/go.mod h1:ly4QBAjHA2VhdnxhojRsCUOeJwKYg+taDlos92xb1+s=
github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 h1:lZUw3E0/J3roVtGQ+SCrUrg3ON6NgVqpn3+iol9aGu4=
github.com/santhosh-tekuri/jsonschema/v5 v5.3.1/go.mod h1:uToXkOrWAZ6/Oc07xWQrPOhJotwFIyu2bBVN41fcDUY=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=

View File

@@ -22,6 +22,7 @@ type Config struct {
Turnstile TurnstileConfig
Notification NotificationConfig
AI AIConfig
Search SearchConfig
Discovery DiscoveryConfig
}
@@ -32,10 +33,19 @@ type DiscoveryConfig struct {
}
type AIConfig struct {
APIKey string
AgentSimple string // Pre-created Mistral agent ID for Pass 1 (extraction + web search)
ModelComplex string // Model for Pass 2 (description + retry, e.g. mistral-large-latest)
RateLimitRPS float64 // Max requests per second to Mistral (0 = disabled)
Provider string // "ollama" or "mistral"; default "ollama"
RateLimitRPS float64 // Max requests per second to upstream; 0 = disabled (Mistral only)
OllamaURL string // default "http://localhost:11434"
OllamaModel string // default "qwen2.5:14b-instruct"
MistralAPIKey string
MistralModel string // default "mistral-large-latest"
}
type SearchConfig struct {
Provider string // "searxng" (only option today)
SearxngURL string // default "http://localhost:8888"
}
type AppConfig struct {
@@ -272,10 +282,16 @@ func Load() (*Config, error) {
FrontendURL: envStr("FRONTEND_URL", "http://localhost:5173"),
},
AI: AIConfig{
APIKey: envStr("AI_API_KEY", ""),
AgentSimple: envStr("AI_AGENT_SIMPLE", ""),
ModelComplex: envStr("AI_MODEL_COMPLEX", "mistral-large-latest"),
RateLimitRPS: rpsAI,
Provider: envStr("AI_PROVIDER", "ollama"),
RateLimitRPS: rpsAI,
OllamaURL: envStr("AI_OLLAMA_URL", "http://localhost:11434"),
OllamaModel: envStr("AI_OLLAMA_MODEL", "qwen2.5:14b-instruct"),
MistralAPIKey: envStr("AI_MISTRAL_API_KEY", envStr("AI_API_KEY", "")),
MistralModel: envStr("AI_MISTRAL_MODEL", envStr("AI_MODEL_COMPLEX", "mistral-large-latest")),
},
Search: SearchConfig{
Provider: envStr("SEARCH_PROVIDER", "searxng"),
SearxngURL: envStr("SEARCH_SEARXNG_URL", "http://localhost:8888"),
},
Discovery: DiscoveryConfig{
Token: discoveryToken,

View File

@@ -22,32 +22,25 @@ const maxScrapeURLs = 5
// a no-context LLM prompt — grounding is required for useful output.
var ErrNoScrapedContent = errors.New("no scrapeable content from any source URL")
// Scraper is the narrow interface MistralLLMEnricher depends on. Satisfied
// Scraper is the narrow interface LLMEnricher depends on. Satisfied
// by *pkg/scrape.Client; tests inject a stub.
type Scraper interface {
Fetch(ctx context.Context, url string) (string, error)
}
// aiPass2 is the narrow interface for the AI client's Pass2 chat-completion
// method. Lets tests inject a stub that returns canned JSON without hitting
// the real Mistral API.
type aiPass2 interface {
Pass2(ctx context.Context, systemPrompt, userPrompt string) (ai.PassResult, error)
}
// MistralLLMEnricher implements LLMEnricher by scraping quellen URLs and
// feeding the concatenated text to Mistral's chat-completion endpoint with
// ProviderLLMEnricher implements LLMEnricher by scraping quellen URLs and
// feeding the concatenated text to an AI provider's chat endpoint with
// a JSON response format.
type MistralLLMEnricher struct {
Client aiPass2
type ProviderLLMEnricher struct {
AI ai.Provider
Scraper Scraper
}
// NewMistralLLMEnricher constructs an enricher bound to a Mistral ai.Client
// NewLLMEnricher constructs an enricher bound to an ai.Provider
// and a scraper. Both are required; call sites should fall back to
// NoopLLMEnricher when AI is disabled rather than passing nil here.
func NewMistralLLMEnricher(client aiPass2, scraper Scraper) *MistralLLMEnricher {
return &MistralLLMEnricher{Client: client, Scraper: scraper}
func NewLLMEnricher(provider ai.Provider, scraper Scraper) *ProviderLLMEnricher {
return &ProviderLLMEnricher{AI: provider, Scraper: scraper}
}
// llmResponse is the JSON shape we instruct Mistral to return. Any field may
@@ -60,11 +53,11 @@ type llmResponse struct {
}
// EnrichMissing scrapes up to maxScrapeURLs of req.Quellen, concatenates the
// extracted text, and asks Mistral to fill category / opening_hours /
// extracted text, and asks the AI provider to fill category / opening_hours /
// description. Fails with ErrNoScrapedContent if zero URLs return usable
// text — empty-context LLM calls hallucinate.
func (e *MistralLLMEnricher) EnrichMissing(ctx context.Context, req LLMRequest) (Enrichment, error) {
if e.Client == nil || e.Scraper == nil {
func (e *ProviderLLMEnricher) EnrichMissing(ctx context.Context, req LLMRequest) (Enrichment, error) {
if e.AI == nil || e.Scraper == nil {
return Enrichment{}, errors.New("mistral enricher not configured")
}
@@ -92,28 +85,26 @@ func (e *MistralLLMEnricher) EnrichMissing(ctx context.Context, req LLMRequest)
systemPrompt := buildSystemPrompt()
userPrompt := buildUserPrompt(req, blocks)
result, err := e.Client.Pass2(ctx, systemPrompt, userPrompt)
resp, err := e.AI.Chat(ctx, &ai.ChatRequest{
SystemPrompt: systemPrompt,
UserMessage: userPrompt,
JSONMode: true,
})
if err != nil {
return Enrichment{}, fmt.Errorf("pass2: %w", err)
return Enrichment{}, fmt.Errorf("chat: %w", err)
}
var parsed llmResponse
if err := json.Unmarshal([]byte(result.Content), &parsed); err != nil {
return Enrichment{}, fmt.Errorf("parse llm response: %w (content=%q)", err, result.Content)
if err := json.Unmarshal([]byte(resp.Content), &parsed); err != nil {
return Enrichment{}, fmt.Errorf("parse llm response: %w (content=%q)", err, resp.Content)
}
// Build the Enrichment payload with only the fields the model produced.
// Sources entries + token counts + model tag feed the eval harness.
now := time.Now().UTC()
out := Enrichment{
Sources: Sources{},
Model: result.Model,
EnrichedAt: &now,
}
if result.Usage != nil {
out.InputTokens = result.Usage.PromptTokens
out.OutputTokens = result.Usage.CompletionTokens
}
if s := strings.TrimSpace(parsed.Category); s != "" {
out.Category = s
out.Sources["category"] = ProvenanceLLM

View File

@@ -5,8 +5,6 @@ import (
"errors"
"strings"
"testing"
"marktvogt.de/backend/internal/pkg/ai"
)
const catMittelaltermarkt = "mittelaltermarkt"
@@ -25,33 +23,15 @@ func (s *stubScraper) Fetch(_ context.Context, url string) (string, error) {
return s.responses[url], nil
}
// stubPass2 captures the prompts it received and returns a canned JSON body.
type stubPass2 struct {
lastSystem string
lastUser string
result ai.PassResult
err error
}
func (s *stubPass2) Pass2(_ context.Context, systemPrompt, userPrompt string) (ai.PassResult, error) {
s.lastSystem = systemPrompt
s.lastUser = userPrompt
return s.result, s.err
}
func TestMistralEnrich_HappyPath(t *testing.T) {
scraper := &stubScraper{responses: map[string]string{
"https://a.example/markt": "Ein Mittelaltermarkt mit Ritterspielen und Markttreiben.",
"https://b.example/info": "Sa-So jeweils 10-18 Uhr.",
}}
client := &stubPass2{
result: ai.PassResult{
Content: `{"category":"mittelaltermarkt","opening_hours":"Sa-So 10:00-18:00","description":"Ein Markt mit Ritterspielen."}`,
Model: "mistral-large-latest",
Usage: &ai.UsageInfo{PromptTokens: 450, CompletionTokens: 60},
},
stub := &stubProvider{
content: `{"category":"mittelaltermarkt","opening_hours":"Sa-So 10:00-18:00","description":"Ein Markt mit Ritterspielen."}`,
}
e := NewMistralLLMEnricher(client, scraper)
e := NewLLMEnricher(stub, scraper)
req := LLMRequest{
MarktName: "Mittelaltermarkt Dresden",
@@ -64,40 +44,37 @@ func TestMistralEnrich_HappyPath(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if stub.seen.JSONMode != true {
t.Fatalf("JSONMode must be set")
}
// Result carries LLM fields, provenance llm, model + token counts.
// Result carries LLM fields and provenance llm.
if got.Category != catMittelaltermarkt || got.Description == "" || got.OpeningHours == "" {
t.Errorf("missing fields in result: %+v", got)
}
if got.Sources["category"] != ProvenanceLLM {
t.Errorf("category provenance: got %q, want llm", got.Sources["category"])
}
if got.Model != "mistral-large-latest" {
t.Errorf("model: got %q, want mistral-large-latest", got.Model)
}
if got.InputTokens != 450 || got.OutputTokens != 60 {
t.Errorf("token counts: in=%d out=%d", got.InputTokens, got.OutputTokens)
}
// Prompt inspection — verify grounding blocks include the scraped content
// and the already-known fields are listed so LLM doesn't redo them.
if !strings.Contains(client.lastUser, "Mittelaltermarkt Dresden") {
if !strings.Contains(stub.seen.UserMessage, "Mittelaltermarkt Dresden") {
t.Error("user prompt missing markt name")
}
if !strings.Contains(client.lastUser, "Ein Mittelaltermarkt mit Ritterspielen") {
if !strings.Contains(stub.seen.UserMessage, "Ein Mittelaltermarkt mit Ritterspielen") {
t.Error("user prompt missing scraped content from URL 1")
}
if !strings.Contains(client.lastUser, "Sa-So jeweils 10-18 Uhr") {
if !strings.Contains(stub.seen.UserMessage, "Sa-So jeweils 10-18 Uhr") {
t.Error("user prompt missing scraped content from URL 2")
}
if !strings.Contains(client.lastUser, "Bereits bekannt") {
if !strings.Contains(stub.seen.UserMessage, "Bereits bekannt") {
t.Error("user prompt should announce already-known fields when Partial is populated")
}
if !strings.Contains(client.lastUser, "PLZ: 01067") {
if !strings.Contains(stub.seen.UserMessage, "PLZ: 01067") {
t.Error("user prompt missing already-known PLZ")
}
// System prompt asks for JSON only.
if !strings.Contains(client.lastSystem, "JSON") {
if !strings.Contains(stub.seen.SystemPrompt, "JSON") {
t.Error("system prompt should mention JSON")
}
}
@@ -107,8 +84,8 @@ func TestMistralEnrich_AllScrapesFail(t *testing.T) {
"https://a.example": errors.New("timeout"),
"https://b.example": errors.New("404"),
}}
client := &stubPass2{} // must not be called
e := NewMistralLLMEnricher(client, scraper)
stub := &stubProvider{} // must not be called
e := NewLLMEnricher(stub, scraper)
req := LLMRequest{
Quellen: []string{"https://a.example", "https://b.example"},
@@ -117,7 +94,7 @@ func TestMistralEnrich_AllScrapesFail(t *testing.T) {
if !errors.Is(err, ErrNoScrapedContent) {
t.Errorf("err = %v; want ErrNoScrapedContent", err)
}
if client.lastUser != "" {
if stub.seen != nil {
t.Error("LLM must not be called when zero URLs scrape")
}
}
@@ -129,10 +106,10 @@ func TestMistralEnrich_SomeScrapesFailStillCallsLLM(t *testing.T) {
responses: map[string]string{"https://ok.example": "Useful content."},
errs: map[string]error{"https://bad.example": errors.New("timeout")},
}
client := &stubPass2{
result: ai.PassResult{Content: `{"category":"mittelaltermarkt","opening_hours":"","description":""}`},
stub := &stubProvider{
content: `{"category":"mittelaltermarkt","opening_hours":"","description":""}`,
}
e := NewMistralLLMEnricher(client, scraper)
e := NewLLMEnricher(stub, scraper)
req := LLMRequest{Quellen: []string{"https://bad.example", "https://ok.example"}}
got, err := e.EnrichMissing(context.Background(), req)
@@ -142,7 +119,7 @@ func TestMistralEnrich_SomeScrapesFailStillCallsLLM(t *testing.T) {
if got.Category != catMittelaltermarkt {
t.Errorf("category: got %q", got.Category)
}
if !strings.Contains(client.lastUser, "Useful content") {
if !strings.Contains(stub.seen.UserMessage, "Useful content") {
t.Error("user prompt missing successful scrape")
}
}
@@ -151,10 +128,10 @@ func TestMistralEnrich_EmptyFieldsNoProvenance(t *testing.T) {
// LLM returns empty strings for fields it can't support. Those fields
// must NOT appear in Sources — an empty provenance is misleading.
scraper := &stubScraper{responses: map[string]string{"https://a.example": "Content."}}
client := &stubPass2{
result: ai.PassResult{Content: `{"category":"mittelaltermarkt","opening_hours":"","description":""}`},
stub := &stubProvider{
content: `{"category":"mittelaltermarkt","opening_hours":"","description":""}`,
}
e := NewMistralLLMEnricher(client, scraper)
e := NewLLMEnricher(stub, scraper)
got, err := e.EnrichMissing(context.Background(), LLMRequest{Quellen: []string{"https://a.example"}})
if err != nil {
@@ -183,8 +160,8 @@ func TestMistralEnrich_CapsURLsAtFive(t *testing.T) {
fetched[u] = true
return responses[u], nil
}}
client := &stubPass2{result: ai.PassResult{Content: `{"category":"x","opening_hours":"","description":""}`}}
e := NewMistralLLMEnricher(client, scraper)
stub2 := &stubProvider{content: `{"category":"x","opening_hours":"","description":""}`}
e := NewLLMEnricher(stub2, scraper)
_, _ = e.EnrichMissing(context.Background(), LLMRequest{Quellen: urls})
if len(fetched) != 5 {

View File

@@ -9,6 +9,8 @@ import (
"fmt"
"strings"
"time"
"marktvogt.de/backend/internal/pkg/ai"
)
// SimilarityRow carries the minimal identifying fields the classifier reads.
@@ -80,16 +82,16 @@ func SimilarityPairKey(a, b SimilarityRow) string {
// propagate).
const DefaultSimilarityCacheTTL = 30 * 24 * time.Hour
// MistralSimilarityClassifier implements SimilarityClassifier by sending a
// JSON-formatted comparison prompt to Mistral's chat endpoint.
type MistralSimilarityClassifier struct {
Client aiPass2
// SimilarityClassifierLLM implements SimilarityClassifier by sending a
// JSON-formatted comparison prompt to an AI provider's chat endpoint.
type SimilarityClassifierLLM struct {
AI ai.Provider
}
// NewMistralSimilarityClassifier binds a Mistral ai.Client. client must be
// NewSimilarityClassifier binds an ai.Provider. provider must be
// non-nil; routes.go falls back to NoopSimilarityClassifier when AI is off.
func NewMistralSimilarityClassifier(client aiPass2) *MistralSimilarityClassifier {
return &MistralSimilarityClassifier{Client: client}
func NewSimilarityClassifier(provider ai.Provider) *SimilarityClassifierLLM {
return &SimilarityClassifierLLM{AI: provider}
}
// simResponse is the JSON shape we instruct Mistral to return. Confidence
@@ -100,25 +102,29 @@ type simResponse struct {
Reason string `json:"reason"`
}
// Classify sends the paired metadata to Mistral and parses the JSON response.
// Classify sends the paired metadata to the AI provider and parses the JSON response.
// No web scraping — the classifier works from name/city/year alone, which is
// enough for the common cases (same venue listed on two different calendars,
// editing typos, cross-year recurrence).
func (m *MistralSimilarityClassifier) Classify(ctx context.Context, a, b SimilarityRow) (Verdict, error) {
if m.Client == nil {
return Verdict{}, errors.New("mistral similarity classifier not configured")
func (c *SimilarityClassifierLLM) Classify(ctx context.Context, a, b SimilarityRow) (Verdict, error) {
if c.AI == nil {
return Verdict{}, errors.New("similarity classifier not configured")
}
systemPrompt := simSystemPrompt()
userPrompt := simUserPrompt(a, b)
result, err := m.Client.Pass2(ctx, systemPrompt, userPrompt)
resp, err := c.AI.Chat(ctx, &ai.ChatRequest{
SystemPrompt: systemPrompt,
UserMessage: userPrompt,
JSONMode: true,
})
if err != nil {
return Verdict{}, fmt.Errorf("pass2: %w", err)
return Verdict{}, fmt.Errorf("chat: %w", err)
}
var parsed simResponse
if err := json.Unmarshal([]byte(result.Content), &parsed); err != nil {
return Verdict{}, fmt.Errorf("parse response: %w (content=%q)", err, result.Content)
if err := json.Unmarshal([]byte(resp.Content), &parsed); err != nil {
return Verdict{}, fmt.Errorf("parse response: %w (content=%q)", err, resp.Content)
}
// Clamp confidence to [0,1]; the model occasionally returns 1.2 or -0.1.
@@ -134,7 +140,6 @@ func (m *MistralSimilarityClassifier) Classify(ctx context.Context, a, b Similar
Same: parsed.SameMarket,
Confidence: conf,
Reason: strings.TrimSpace(parsed.Reason),
Model: result.Model,
ClassifiedAt: time.Now().UTC(),
}, nil
}

View File

@@ -5,8 +5,6 @@ import (
"errors"
"strings"
"testing"
"marktvogt.de/backend/internal/pkg/ai"
)
func TestSimilarityPairKey_Symmetric(t *testing.T) {
@@ -34,13 +32,10 @@ func TestSimilarityPairKey_DifferentInputsDifferentKeys(t *testing.T) {
}
func TestMistralSimilarity_HappyPath(t *testing.T) {
client := &stubPass2{
result: ai.PassResult{
Content: `{"same_market":true,"confidence":0.82,"reason":"Gleicher Name, gleiche Stadt, gleiches Jahr."}`,
Model: "mistral-large-latest",
},
stub := &stubProvider{
content: `{"same_market":true,"confidence":0.82,"reason":"Gleicher Name, gleiche Stadt, gleiches Jahr."}`,
}
c := NewMistralSimilarityClassifier(client)
c := NewSimilarityClassifier(stub)
got, err := c.Classify(context.Background(),
SimilarityRow{Name: "Mittelaltermarkt Dresden", Stadt: "Dresden", Year: 2026},
@@ -49,6 +44,9 @@ func TestMistralSimilarity_HappyPath(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if stub.seen.JSONMode != true {
t.Fatalf("JSONMode must be set")
}
if !got.Same {
t.Errorf("same = false; want true")
}
@@ -58,14 +56,11 @@ func TestMistralSimilarity_HappyPath(t *testing.T) {
if got.Reason == "" {
t.Error("reason missing")
}
if got.Model != "mistral-large-latest" {
t.Errorf("model = %q", got.Model)
}
// Prompt must carry both rows' identifying fields for the LLM to reason on.
if !strings.Contains(client.lastUser, "Mittelaltermarkt Dresden") {
if !strings.Contains(stub.seen.UserMessage, "Mittelaltermarkt Dresden") {
t.Error("user prompt missing A.name")
}
if !strings.Contains(client.lastSystem, "same_market") {
if !strings.Contains(stub.seen.SystemPrompt, "same_market") {
t.Error("system prompt should describe the JSON schema (same_market key)")
}
}
@@ -82,7 +77,7 @@ func TestMistralSimilarity_ClampsConfidence(t *testing.T) {
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
c := NewMistralSimilarityClassifier(&stubPass2{result: ai.PassResult{Content: tc.raw}})
c := NewSimilarityClassifier(&stubProvider{content: tc.raw})
got, err := c.Classify(context.Background(), SimilarityRow{}, SimilarityRow{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
@@ -95,7 +90,7 @@ func TestMistralSimilarity_ClampsConfidence(t *testing.T) {
}
func TestMistralSimilarity_PropagatesPass2Error(t *testing.T) {
c := NewMistralSimilarityClassifier(&stubPass2{err: errors.New("mistral down")})
c := NewSimilarityClassifier(&stubProvider{err: errors.New("mistral down")})
_, err := c.Classify(context.Background(), SimilarityRow{}, SimilarityRow{})
if err == nil {
t.Fatal("expected error; got nil")
@@ -103,7 +98,7 @@ func TestMistralSimilarity_PropagatesPass2Error(t *testing.T) {
}
func TestMistralSimilarity_RejectsBadJSON(t *testing.T) {
c := NewMistralSimilarityClassifier(&stubPass2{result: ai.PassResult{Content: "not json at all"}})
c := NewSimilarityClassifier(&stubProvider{content: "not json at all"})
_, err := c.Classify(context.Background(), SimilarityRow{}, SimilarityRow{})
if err == nil {
t.Fatal("expected parse error; got nil")

View File

@@ -0,0 +1,24 @@
package enrich
import (
"context"
"marktvogt.de/backend/internal/pkg/ai"
)
// stubProvider is a test double satisfying ai.Provider. Chat records the
// request it received (in seen) and returns either err or a canned body.
type stubProvider struct {
content string // returned verbatim as ChatResponse.Content
err error // when non-nil, Chat fails with this error instead
seen *ai.ChatRequest // last request passed to Chat; nil until Chat is called
}
// Name identifies the provider in logs; the stub is always "stub".
func (s *stubProvider) Name() string { return "stub" }
// SupportsJSONMode reports true so callers that set ChatRequest.JSONMode
// can be exercised against the stub.
func (s *stubProvider) SupportsJSONMode() bool { return true }
// SupportsJSONSchema reports false: the stub does not emulate
// schema-constrained output.
func (s *stubProvider) SupportsJSONSchema() bool { return false }
// Chat records req so tests can inspect the prompts afterwards, then either
// fails with the configured error or returns the canned content.
func (s *stubProvider) Chat(ctx context.Context, req *ai.ChatRequest) (*ai.ChatResponse, error) {
	s.seen = req
	if err := s.err; err != nil {
		return nil, err
	}
	resp := &ai.ChatResponse{Content: s.content}
	return resp, nil
}

View File

@@ -1,711 +1,19 @@
package market
import (
"encoding/json"
"fmt"
"log/slog"
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"marktvogt.de/backend/internal/pkg/ai"
"marktvogt.de/backend/internal/pkg/apierror"
)
type ResearchHandler struct {
service *Service
aiClient *ai.Client
mu sync.Mutex
cooldown map[uuid.UUID]time.Time
service *Service
}
func NewResearchHandler(service *Service, aiClient *ai.Client) *ResearchHandler {
return &ResearchHandler{
service: service,
aiClient: aiClient,
cooldown: make(map[uuid.UUID]time.Time),
}
func NewResearchHandler(service *Service, _ any) *ResearchHandler {
return &ResearchHandler{service: service}
}
// Research runs the two-pass AI research flow for one market: Pass 1 does
// structured field extraction via an agent with web search; Pass 2 writes the
// description and retries fields Pass 1 left empty. Pass 2 is best-effort —
// if it fails, the Pass-1 result still ships.
func (h *ResearchHandler) Research(c *gin.Context) {
// AI is optional infrastructure; reject early when not configured.
if !h.aiClient.Enabled() {
apiErr := apierror.BadRequest("ai_disabled", "AI research is not configured")
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
return
}
id, err := uuid.Parse(c.Param("id"))
if err != nil {
apiErr := apierror.BadRequest("invalid_id", "invalid market ID")
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
return
}
// Rate limit: 1 per market per 5 minutes
// NOTE(review): the slot is consumed even when the run fails below, and
// the cooldown map is never pruned — confirm both are intentional.
h.mu.Lock()
if last, ok := h.cooldown[id]; ok && time.Since(last) < 5*time.Minute {
h.mu.Unlock()
apiErr := apierror.BadRequest("rate_limited", "Bitte warte 5 Minuten zwischen Recherche-Aufrufen")
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
return
}
h.cooldown[id] = time.Now()
h.mu.Unlock()
m, err := h.service.GetByID(c.Request.Context(), id)
if err != nil {
apiErr := apierror.NotFound("market")
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
return
}
ctx := c.Request.Context()
// --- Pass 1: Structured extraction via agent with web search ---
// One immediate retry on non-rate-limit failure; rate limits become a 429.
pass1Prompt := buildPass1Prompt(m)
pass1Result, err := h.aiClient.Pass1(ctx, pass1Prompt)
if err != nil {
if ai.IsRateLimit(err) {
h.respondRateLimited(c, "pass1", id, err)
return
}
slog.WarnContext(ctx, "pass1 failed, retrying", "market_id", id, "error", err)
pass1Result, err = h.aiClient.Pass1(ctx, pass1Prompt)
if err != nil {
if ai.IsRateLimit(err) {
h.respondRateLimited(c, "pass1", id, err)
return
}
slog.ErrorContext(ctx, "pass1 retry failed", "market_id", id, "error", err)
apiErr := apierror.Internal("AI research failed")
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
return
}
}
pass1Data := parsePass1Response(pass1Result.Content)
// Compute confidence for pass1 fields and identify retry fields
// (a nil Value means Pass 1 could not find the field — hand it to Pass 2).
var retryFields []string
for fieldName, field := range pass1Data.Fields {
if field.Value == nil {
retryFields = append(retryFields, fieldName)
}
}
// Collect all sources from pass1
allSources := pass1Data.Sources
// --- Pass 2: Description + retry fields via chat completions ---
// Pass 2 is enrichment only — if it fails we still return the pass1 result.
// Rate-limit on pass2 is treated as a warning and the partial result ships.
pass2Prompt := buildPass2UserPrompt(m, pass1Data, retryFields, allSources)
pass2Result, err := h.aiClient.Pass2(ctx, pass2SystemPrompt, pass2Prompt)
var pass2Data pass2Response
if err != nil {
if ai.IsRateLimit(err) {
slog.WarnContext(ctx, "pass2 rate-limited; returning pass1 results only",
"market_id", id)
} else {
slog.WarnContext(ctx, "pass2 failed, retrying", "market_id", id, "error", err)
pass2Result, err = h.aiClient.Pass2(ctx, pass2SystemPrompt, pass2Prompt)
if err != nil {
if ai.IsRateLimit(err) {
slog.WarnContext(ctx, "pass2 rate-limited on retry; returning pass1 results only",
"market_id", id)
} else {
slog.WarnContext(ctx, "pass2 retry failed, using pass1 results only",
"market_id", id, "error", err)
}
}
}
}
// err == nil here means pass2 (first attempt or the retry) succeeded.
if err == nil {
pass2Data = parsePass2Response(pass2Result.Content)
}
// --- Merge results and compute final confidence ---
result := mergeResults(m, pass1Data, pass2Data)
// Log cost tracking
logCostTracking(id, pass1Result, pass2Result)
c.JSON(http.StatusOK, ResearchResponse{Data: result})
}
// respondRateLimited writes a 429 with a Retry-After header so the admin UI
// can display a clear "try again in N seconds" message instead of a generic
// 500. pass is a short identifier for the log line ("pass1" / "pass2").
func (h *ResearchHandler) respondRateLimited(c *gin.Context, pass string, marketID uuid.UUID, err error) {
	retryAfter := ai.DefaultRetryAfterSeconds
	slog.WarnContext(c.Request.Context(), "ai rate-limited",
		"pass", pass, "market_id", marketID, "retry_after_s", retryAfter, "error", err)
	// The header must be set before the JSON body is written.
	c.Header("Retry-After", fmt.Sprint(retryAfter))
	msg := fmt.Sprintf("AI research temporarily rate-limited; try again in ~%ds", retryAfter)
	apiErr := apierror.TooManyRequests(msg)
	c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
}
// --- Pass 1 types and prompt ---
// pass1FieldResult is one field as returned by the Pass-1 agent: the value,
// the URL(s) it was taken from, and how it was obtained.
type pass1FieldResult struct {
Value any `json:"wert"`
Sources []string `json:"quellen"`
Extraction string `json:"extraktion"` // "verbatim" or "abgeleitet"
}
// pass1Response is the decoded Pass-1 payload: known fields keyed by name
// (populated from the raw map by parsePass1Response), the aggregate source
// list, and the raw JSON kept for later inspection.
type pass1Response struct {
Fields map[string]pass1FieldResult `json:"-"`
Sources []string `json:"quellen_gesamt"`
raw map[string]json.RawMessage
}
// buildPass1Prompt renders the German research prompt for Pass 1. It embeds
// the market's name and city (plus date range and website when known) and
// instructs the agent to answer with strict JSON carrying per-field values,
// source URLs, and an extraction tag ("verbatim" vs. "abgeleitet").
func buildPass1Prompt(m Market) string {
// Optional context lines: only emitted when the data exists, so the
// prompt never shows an empty "Zeitraum:" or "Website:" label.
dateRange := ""
if !m.StartDate.IsZero() && !m.EndDate.IsZero() {
dateRange = fmt.Sprintf("Zeitraum: %s bis %s\n", m.StartDate.Format("02.01.2006"), m.EndDate.Format("02.01.2006"))
}
website := ""
if m.Website != "" {
website = fmt.Sprintf("Website: %s\n", m.Website)
}
return fmt.Sprintf(`Recherchiere den folgenden Mittelaltermarkt und finde aktuelle Informationen.
Name: %s
Stadt: %s
%s%s
Suche nach folgenden Feldern:
- website (offizielle URL)
- strasse (Straße/Adresse des Veranstaltungsorts)
- plz (Postleitzahl)
- stadt (Stadt, falls korrekter als angegeben)
- veranstalter (Name des Veranstalters)
- oeffnungszeiten (Öffnungszeiten als Array von {tag, von, bis} mit deutschen Wochentagen)
- eintrittspreise (Preise als {erwachsene_cent, kinder_cent, ermaessigt_cent, frei_unter_alter, hinweise})
Antworte AUSSCHLIESSLICH mit validem JSON (keine Markdown-Codeblöcke).
{
"website": {
"wert": "https://...",
"quellen": ["https://..."],
"extraktion": "verbatim"
},
"strasse": {
"wert": "Beispielstraße 5",
"quellen": ["https://..."],
"extraktion": "verbatim"
},
"plz": {
"wert": "12345",
"quellen": ["https://..."],
"extraktion": "verbatim"
},
"stadt": {
"wert": "Musterstadt",
"quellen": ["https://..."],
"extraktion": "verbatim"
},
"veranstalter": {
"wert": "Verein XY e.V.",
"quellen": ["https://..."],
"extraktion": "abgeleitet"
},
"oeffnungszeiten": {
"wert": [{"tag": "Samstag", "von": "10:00", "bis": "22:00"}],
"quellen": ["https://..."],
"extraktion": "verbatim"
},
"eintrittspreise": {
"wert": {"erwachsene_cent": 800, "kinder_cent": 0, "ermaessigt_cent": 500, "frei_unter_alter": 12, "hinweise": "Kinder unter 12 frei"},
"quellen": ["https://..."],
"extraktion": "verbatim"
},
"quellen_gesamt": ["https://...", "https://..."]
}
Regeln:
- "extraktion": "verbatim" wenn der Wert wörtlich in einer strukturierten Quelle steht (Tabelle, Adressblock, Schema)
- "extraktion": "abgeleitet" wenn der Wert aus Fließtext interpretiert wurde
- Setze "wert" auf null wenn nicht findbar. Erfinde KEINE Werte.
- "quellen" muss die URL(s) enthalten, aus denen der Wert stammt
- Verwende für Wochentage NUR: Montag, Dienstag, Mittwoch, Donnerstag, Freitag, Samstag, Sonntag
- Preise in Cent (800 = 8,00 EUR)
`, m.Name, m.City, dateRange, website)
}
// parsePass1Response decodes the Pass-1 JSON payload into a pass1Response.
// Malformed JSON is logged and yields an empty result (with a non-nil Fields
// map) rather than an error — Pass-1 output is treated as best-effort.
func parsePass1Response(raw string) pass1Response {
	out := pass1Response{Fields: make(map[string]pass1FieldResult)}

	// Normalize the model output via the shared helpers before decoding.
	cleaned := stripJSONComments(extractJSON(raw))

	var payload map[string]json.RawMessage
	if err := json.Unmarshal([]byte(cleaned), &payload); err != nil {
		slog.Warn("failed to parse pass1 response", "error", err, "cleaned", cleaned)
		return out
	}
	out.raw = payload

	// Aggregate source list; a missing or malformed entry is tolerated.
	if src, ok := payload["quellen_gesamt"]; ok {
		_ = json.Unmarshal(src, &out.Sources)
	}

	// Decode each known field; entries that fail to decode are dropped.
	for _, name := range []string{"website", "strasse", "plz", "stadt", "veranstalter", "oeffnungszeiten", "eintrittspreise"} {
		body, ok := payload[name]
		if !ok {
			continue
		}
		var field pass1FieldResult
		if err := json.Unmarshal(body, &field); err == nil {
			out.Fields[name] = field
		}
	}
	return out
}
// --- Pass 2 types and prompt ---

// pass2SystemPrompt instructs the model to (1) write a German market
// description from already-collected source material and (2) re-extract any
// fields listed under "retry_felder", answering strictly as a JSON object.
// The text is sent to the model verbatim — do not reformat or translate it.
const pass2SystemPrompt = `Du bist ein Redakteur für Marktvogt, eine Plattform für Mittelaltermärkte und historische Veranstaltungen im DACH-Raum.
Du bekommst:
1. Rechercheergebnisse aus einer vorherigen Websuche (bereits gesammelte Quelleninhalte und URLs)
2. Optional: Eine Liste von Feldern ("retry_felder"), die in der ersten Runde leer oder unsicher waren
Du hast KEINEN Internetzugang. Arbeite ausschließlich mit den bereitgestellten Informationen.
## Aufgabe 1: Beschreibung verfassen
Verfasse eine ansprechende, informative Marktbeschreibung auf Deutsch.
### Anforderungen:
- 3-5 Sätze, ca. 80-150 Wörter
- Sachlich-einladender Ton — kein Werbe-Deutsch, keine Superlative ("das größte", "einzigartig", "unvergesslich")
- Inhalt: Was macht diesen Markt besonders? Was erwartet Besucher? Marktstände, Programm, Musik, Handwerk, Ritterspiele, Gaukelei, historisches Lagerleben etc.
- Atmosphäre und Veranstaltungsort erwähnen wenn aus den Quellen ersichtlich
- Zielgruppen ansprechen: Familien, Mittelalter-Fans, Reenactment-Begeisterte
- NICHT erwähnen: Eintrittspreise, exakte Öffnungszeiten, Datumsangaben (stehen in den strukturierten Feldern)
- Alle Fakten in der Beschreibung müssen aus den mitgelieferten Quelleninhalten stammen
## Aufgabe 2: Felder nachholen (nur wenn retry_felder vorhanden)
Für jedes Feld in retry_felder: Versuche den Wert aus den bereits mitgelieferten Quelleninhalten zu extrahieren.
## Antwortformat
Antworte AUSSCHLIESSLICH mit validem JSON.
{
"beschreibung": {
"wert": "Die vollständige Beschreibung hier...",
"quellen": ["https://...", "https://..."],
"basiert_auf": "Kurze Erklärung"
},
"retry_ergebnisse": {
"feldname": {
"wert": "...",
"quellen": ["https://..."],
"extraktion": "verbatim",
"hinweis": "optional"
}
}
}
Wenn keine retry_felder übergeben wurden, gib "retry_ergebnisse": {} zurück.
Erfinde KEINE Fakten. Wenn die Quellen nichts hergeben, setze "wert" auf null.`
// pass2DescriptionResult is the "beschreibung" object of the pass-2 reply.
type pass2DescriptionResult struct {
	Value   string   `json:"wert"`
	Sources []string `json:"quellen"`
	BasedOn string   `json:"basiert_auf"`
}

// pass2RetryResult is one entry of "retry_ergebnisse". The prompt also
// allows an optional "hinweis" key, which is not decoded here (ignored).
type pass2RetryResult struct {
	Value      any      `json:"wert"`
	Sources    []string `json:"quellen"`
	Extraction string   `json:"extraktion"` // "verbatim" or "abgeleitet"
}

// pass2Response is the full parsed pass-2 model reply.
type pass2Response struct {
	Description  pass2DescriptionResult      `json:"beschreibung"`
	RetryResults map[string]pass2RetryResult `json:"retry_ergebnisse"`
}
// buildPass2UserPrompt assembles the pass-2 user message: market header,
// the pass-1 findings, all known source URLs, and (optionally) the list of
// fields to retry.
//
// NOTE(review): pass1.Fields is ranged over directly, so field order in the
// prompt is nondeterministic across calls — confirm this is acceptable for
// evals/caching.
func buildPass2UserPrompt(m Market, pass1 pass1Response, retryFields []string, allSources []string) string {
	var b strings.Builder
	fmt.Fprintf(&b, "Markt: %s\nStadt: %s\n", m.Name, m.City)
	if !m.StartDate.IsZero() {
		fmt.Fprintf(&b, "Zeitraum: %s bis %s\n", m.StartDate.Format("02.01.2006"), m.EndDate.Format("02.01.2006"))
	}
	b.WriteString("\n## Rechercheergebnisse aus Pass 1\n\n")
	for name, field := range pass1.Fields {
		if field.Value == nil {
			continue
		}
		encoded, _ := json.Marshal(field.Value)
		fmt.Fprintf(&b, "%s: %s (Quellen: %s)\n", name, encoded, strings.Join(field.Sources, ", "))
	}
	b.WriteString("\n## Quellen-URLs\n")
	for _, src := range allSources {
		fmt.Fprintf(&b, "- %s\n", src)
	}
	if len(retryFields) > 0 {
		fmt.Fprintf(&b, "\n## retry_felder\n%s\n", strings.Join(retryFields, ", "))
		b.WriteString("Versuche diese Felder aus den Quelleninhalten zu extrahieren.\n")
	}
	return b.String()
}
// parsePass2Response decodes the raw model output for pass 2. On malformed
// JSON it logs a warning and returns a zero-value response whose
// RetryResults map is still non-nil.
func parsePass2Response(raw string) pass2Response {
	payload := stripJSONComments(extractJSON(raw))
	result := pass2Response{RetryResults: make(map[string]pass2RetryResult)}
	if err := json.Unmarshal([]byte(payload), &result); err != nil {
		slog.Warn("failed to parse pass2 response", "error", err, "cleaned", payload)
	}
	return result
}
// --- Confidence scoring ---

// Extraction modes reported by the model ("extraktion" field). The derived
// constant already existed; verbatim is named too so the comparisons below
// and in formatReason stay consistent.
const (
	extractionDerived  = "abgeleitet"
	extractionVerbatim = "verbatim"
)

// computeConfidence maps source count and extraction mode onto the frontend
// confidence bucket:
//
//	>= 2 sources, verbatim -> "high"
//	>= 2 sources, derived  -> "medium"
//	== 1 source,  verbatim -> "medium"
//	anything else          -> "low"
func computeConfidence(sources []string, extraction string) string {
	n := len(sources)
	switch {
	case n >= 2 && extraction == extractionVerbatim:
		return "high"
	case n >= 2 && extraction == extractionDerived,
		n == 1 && extraction == extractionVerbatim:
		return "medium"
	default:
		return "low"
	}
}
// --- Merge results ---

// fieldMapping maps pass1 field names to the FieldSuggestion field names used by the frontend.
// Keys must match the field names requested in the pass-1 prompt; pass1
// fields without a mapping are silently dropped by mergeResults.
var fieldMapping = map[string]string{
	"website":         "website",
	"strasse":         "street",
	"plz":             "zip",
	"stadt":           "city",
	"veranstalter":    "organizer_name",
	"oeffnungszeiten": "opening_hours",
	"eintrittspreise": "admission_info",
}
// convertSuggestedValue translates pass1 wire values into the frontend
// representation for the fields that need a format conversion; everything
// else passes through unchanged.
func convertSuggestedValue(p1Name string, val any) any {
	switch p1Name {
	case "oeffnungszeiten":
		return convertOpeningHours(val)
	case "eintrittspreise":
		return convertAdmission(val)
	default:
		return val
	}
}

// mergeResults combines pass-1 extraction and pass-2 description/retry
// output into a single ResearchResult: one FieldSuggestion per usable field
// plus a de-duplicated list of every cited source URL.
//
// NOTE(review): allSources is collected from map iteration, so its order is
// nondeterministic — confirm the frontend does not rely on source order.
func mergeResults(m Market, pass1 pass1Response, pass2 pass2Response) ResearchResult {
	suggestions := make([]FieldSuggestion, 0, len(pass1.Fields)+len(pass2.RetryResults)+1)
	// Unique set of every URL cited anywhere in the two passes.
	sourceSet := make(map[string]bool)
	for _, s := range pass1.Sources {
		sourceSet[s] = true
	}
	// Pass 1: structured field extraction.
	for p1Name, field := range pass1.Fields {
		frontendName, ok := fieldMapping[p1Name]
		if !ok || field.Value == nil {
			continue
		}
		suggestions = append(suggestions, FieldSuggestion{
			Field:          frontendName,
			CurrentValue:   getCurrentValue(m, frontendName),
			SuggestedValue: convertSuggestedValue(p1Name, field.Value),
			Confidence:     computeConfidence(field.Sources, field.Extraction),
			Reason:         formatReason(field.Sources, field.Extraction),
		})
		for _, s := range field.Sources {
			sourceSet[s] = true
		}
	}
	// Pass 2: retried fields — same mapping and conversion, tagged "(retry)".
	for p1Name, retryField := range pass2.RetryResults {
		frontendName, ok := fieldMapping[p1Name]
		if !ok || retryField.Value == nil {
			continue
		}
		suggestions = append(suggestions, FieldSuggestion{
			Field:          frontendName,
			CurrentValue:   getCurrentValue(m, frontendName),
			SuggestedValue: convertSuggestedValue(p1Name, retryField.Value),
			Confidence:     computeConfidence(retryField.Sources, retryField.Extraction),
			Reason:         formatReason(retryField.Sources, retryField.Extraction) + " (retry)",
		})
		for _, s := range retryField.Sources {
			sourceSet[s] = true
		}
	}
	// Pass 2: generated description (always scored as derived text).
	if pass2.Description.Value != "" {
		suggestions = append(suggestions, FieldSuggestion{
			Field:          "description",
			CurrentValue:   m.Description,
			SuggestedValue: pass2.Description.Value,
			Confidence:     computeConfidence(pass2.Description.Sources, extractionDerived),
			Reason:         pass2.Description.BasedOn,
		})
		for _, s := range pass2.Description.Sources {
			sourceSet[s] = true
		}
	}
	allSources := make([]string, 0, len(sourceSet))
	for s := range sourceSet {
		allSources = append(allSources, s)
	}
	return ResearchResult{
		Suggestions: suggestions,
		Sources:     allSources,
	}
}
// getCurrentValue returns the market's existing value for a frontend field
// name so the UI can show current vs. suggested; unknown field names yield
// nil.
func getCurrentValue(m Market, field string) any {
	switch field {
	case "website":
		return m.Website
	case "street":
		return m.Street
	case "zip":
		return m.Zip
	case "city":
		return m.City
	case "organizer_name":
		return m.OrganizerName
	case "opening_hours":
		return m.OpeningHours
	case "admission_info":
		return m.AdmissionInfo
	case "description":
		return m.Description
	default:
		return nil
	}
}
// formatReason builds the human-readable provenance string shown next to a
// suggestion, e.g. "Aus 2 Quellen gefunden" or "Aus 1 Quelle abgeleitet".
// With zero sources only the verb is returned.
func formatReason(sources []string, extraction string) string {
	verb := "gefunden"
	if extraction == extractionDerived {
		verb = extractionDerived
	}
	switch n := len(sources); {
	case n == 0:
		return verb
	case n == 1:
		return "Aus 1 Quelle " + verb
	default:
		return fmt.Sprintf("Aus %d Quellen %s", n, verb)
	}
}
// convertOpeningHours translates the pass1 wire format
// [{"tag","von","bis"}] into the frontend format [{"day","open","close"}].
// Any value that does not round-trip as a list of string maps is returned
// unchanged.
func convertOpeningHours(val any) any {
	encoded, err := json.Marshal(val)
	if err != nil {
		return val
	}
	var entries []map[string]string
	if json.Unmarshal(encoded, &entries) != nil {
		return val
	}
	converted := make([]map[string]string, len(entries))
	for i, entry := range entries {
		converted[i] = map[string]string{
			"day":   entry["tag"],
			"open":  entry["von"],
			"close": entry["bis"],
		}
	}
	return converted
}
// convertAdmission translates the pass1 admission object (German keys,
// values in cents) into the frontend format. Values that are not numbers
// become 0; a missing/non-string "hinweise" becomes "". Anything that does
// not round-trip as a JSON object is returned unchanged.
func convertAdmission(val any) any {
	encoded, err := json.Marshal(val)
	if err != nil {
		return val
	}
	var fields map[string]any
	if json.Unmarshal(encoded, &fields) != nil {
		return val
	}
	// Local equivalent of the package-level toInt helper.
	asInt := func(v any) int {
		switch n := v.(type) {
		case float64:
			return int(n)
		case int:
			return n
		}
		return 0
	}
	notes, _ := fields["hinweise"].(string)
	return map[string]any{
		"adult_cents":    asInt(fields["erwachsene_cent"]),
		"child_cents":    asInt(fields["kinder_cent"]),
		"reduced_cents":  asInt(fields["ermaessigt_cent"]),
		"free_under_age": asInt(fields["frei_unter_alter"]),
		"notes":          notes,
	}
}
// toInt coerces a decoded-JSON number (float64) or a plain int to int;
// everything else yields 0.
func toInt(v any) int {
	if f, ok := v.(float64); ok {
		return int(f)
	}
	if i, ok := v.(int); ok {
		return i
	}
	return 0
}
// toString returns v if it is a string, otherwise "".
func toString(v any) string {
	s, _ := v.(string)
	return s
}
// --- Cost logging ---

// logCostTracking emits one structured "ai_research_cost" log line with the
// model and token usage of both research passes, for offline cost analysis.
// Passes without usage info are omitted from the payload.
func logCostTracking(marketID uuid.UUID, pass1, pass2 ai.PassResult) {
	passEntry := func(p ai.PassResult) map[string]any {
		return map[string]any{
			"model":         p.Model,
			"input_tokens":  p.Usage.PromptTokens,
			"output_tokens": p.Usage.CompletionTokens,
		}
	}
	payload := map[string]any{
		"market_id": marketID,
		"timestamp": time.Now().Format(time.RFC3339),
	}
	if pass1.Usage != nil {
		payload["pass_1"] = passEntry(pass1)
	}
	if pass2.Usage != nil {
		payload["pass_2"] = passEntry(pass2)
	}
	encoded, _ := json.Marshal(payload)
	slog.Info("ai_research_cost", "data", string(encoded))
}
// --- JSON helpers ---
func extractJSON(s string) string {
if idx := findJSONStart(s); idx >= 0 {
s = s[idx:]
if end := findJSONEnd(s); end > 0 {
s = s[:end+1]
}
}
return s
}
// stripJSONComments removes //-style line comments from almost-JSON text
// while leaving string contents (including escaped quotes and URLs with
// slashes) untouched. The newline terminating a comment is dropped along
// with it, which is harmless for JSON. Block comments (/* */) are not
// handled.
func stripJSONComments(s string) string {
	var out strings.Builder
	out.Grow(len(s))
	inStr, esc := false, false
	for i := 0; i < len(s); i++ {
		ch := s[i]
		switch {
		case esc:
			// Byte following a backslash inside a string: copy verbatim.
			out.WriteByte(ch)
			esc = false
		case inStr && ch == '\\':
			out.WriteByte(ch)
			esc = true
		case ch == '"':
			inStr = !inStr
			out.WriteByte(ch)
		case !inStr && ch == '/' && i+1 < len(s) && s[i+1] == '/':
			// Skip to end of line; the outer i++ then also skips the '\n'.
			for i < len(s) && s[i] != '\n' {
				i++
			}
		default:
			out.WriteByte(ch)
		}
	}
	return out.String()
}
// findJSONStart returns the byte index of the first '{' in s, or -1 when s
// contains none. ('{' is ASCII, so IndexByte matches the previous
// rune-loop's byte index exactly.)
func findJSONStart(s string) int {
	return strings.IndexByte(s, '{')
}
// findJSONEnd returns the byte index of the '}' that closes the first
// top-level brace-delimited object in s, or -1 if the braces never balance.
//
// NOTE(review): braces are counted without tracking JSON string state, so a
// '{' or '}' inside a string value skews the depth and can return an early
// index — confirm model output never contains such values, or add in-string
// tracking like stripJSONComments does.
func findJSONEnd(s string) int {
	depth := 0
	for i, c := range s {
		switch c {
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return i
			}
		}
	}
	return -1
c.JSON(http.StatusNotImplemented, gin.H{"error": "research temporarily disabled during AI provider migration"})
}

View File

@@ -1,194 +0,0 @@
package ai
import (
"context"
"fmt"
"sync"
"time"
"github.com/VikingOwl91/mistral-go-sdk"
"github.com/VikingOwl91/mistral-go-sdk/chat"
"github.com/VikingOwl91/mistral-go-sdk/conversation"
)
// rateLimiter enforces a minimum interval between calls. Set rps<=0 to disable.
// Mirrors backend/internal/pkg/geocode/nominatim.go's lastReq gate.
type rateLimiter struct {
	mu          sync.Mutex
	lastReq     time.Time     // timestamp of the most recently reserved slot
	minInterval time.Duration // 0 disables limiting
}

// newRateLimiter builds a limiter spacing calls at most rps per second;
// rps <= 0 disables limiting entirely.
func newRateLimiter(rps float64) *rateLimiter {
	rl := &rateLimiter{}
	if rps > 0 {
		rl.minInterval = time.Duration(float64(time.Second) / rps)
	}
	return rl
}

// TODO: wait() still does not honor context cancellation — fix together with
// pkg/geocode/nominatim.go by taking ctx and using time.NewTimer + select.
//
// wait blocks until this caller's reserved slot arrives. The slot is
// reserved under the mutex but the sleep happens OUTSIDE it, so a sleeping
// caller no longer serializes every queued caller on the lock (previously
// the mutex was held through time.Sleep, blocking others up to
// N*minInterval).
func (rl *rateLimiter) wait() {
	if rl.minInterval == 0 {
		return
	}
	rl.mu.Lock()
	now := time.Now()
	next := rl.lastReq.Add(rl.minInterval)
	if next.Before(now) {
		next = now
	}
	rl.lastReq = next
	sleep := next.Sub(now)
	rl.mu.Unlock()
	if sleep > 0 {
		time.Sleep(sleep)
	}
}
// Client wraps the Mistral SDK with a pre-created agent for pass 1 and a
// chat model for pass 2, plus client-side rate limiting.
type Client struct {
	sdk          *mistral.Client // nil when no API key was provided
	agentSimple  string          // Conversations-API agent ID used by Pass1
	modelComplex string          // chat model used by Pass2
	limiter      *rateLimiter
}

// New builds a Client. An empty apiKey leaves the SDK nil (client disabled,
// see Enabled); an empty modelComplex defaults to "mistral-large-latest".
// The SDK is configured with a 120s timeout and 2 retries at 1s.
func New(apiKey, agentSimple, modelComplex string, rps float64) *Client {
	if modelComplex == "" {
		modelComplex = "mistral-large-latest"
	}
	var sdk *mistral.Client
	if apiKey != "" {
		sdk = mistral.NewClient(apiKey,
			mistral.WithTimeout(120*time.Second),
			mistral.WithRetry(2, 1*time.Second),
		)
	}
	return &Client{
		sdk:          sdk,
		agentSimple:  agentSimple,
		modelComplex: modelComplex,
		limiter:      newRateLimiter(rps),
	}
}

// Enabled reports whether the client can make API calls (both an API key
// and a pass-1 agent are configured).
func (c *Client) Enabled() bool {
	return c.sdk != nil && c.agentSimple != ""
}

// UsageInfo is the token accounting attached to each pass result.
type UsageInfo struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// PassResult is the outcome of one model pass.
type PassResult struct {
	Content string     // assistant text
	Usage   *UsageInfo // token usage; callers nil-check before use
	Model   string     // model name or "agent:<id>" that produced Content
}
// Pass1 uses the Conversations API to call the pre-created agent (with web
// search) and returns the assistant's text plus token usage.
func (c *Client) Pass1(ctx context.Context, prompt string) (PassResult, error) {
	// Guard before touching the SDK: mirrors Pass0, and prevents a
	// nil-pointer dereference when the client was built without an API key
	// (callers may forget to check Enabled()).
	if c.sdk == nil || c.agentSimple == "" {
		return PassResult{}, fmt.Errorf("pass1: ai client not configured (sdk=%v agent=%q)", c.sdk != nil, c.agentSimple)
	}
	c.limiter.wait()
	storeFalse := false
	resp, err := c.sdk.StartConversation(ctx, &conversation.StartRequest{
		AgentID: c.agentSimple,
		Inputs:  conversation.TextInputs(prompt),
		Store:   &storeFalse, // do not persist the conversation server-side
	})
	if err != nil {
		return PassResult{}, fmt.Errorf("pass1 conversation: %w", err)
	}
	content := extractConversationContent(resp)
	if content == "" {
		return PassResult{}, fmt.Errorf("pass1: no assistant message in response")
	}
	return PassResult{
		Content: content,
		Usage:   convertConvUsage(resp.Usage),
		Model:   "agent:" + c.agentSimple,
	}, nil
}
// Pass0 uses the Conversations API to call a discovery agent identified by agentID.
// The agent ID is passed explicitly so the discovery domain can configure its own
// agent independently of the agentSimple field used by Pass1.
func (c *Client) Pass0(ctx context.Context, agentID, prompt string) (PassResult, error) {
	// Validate configuration BEFORE waiting on the limiter — a doomed call
	// should not consume a rate-limit slot (previously the wait came first).
	if c.sdk == nil || agentID == "" {
		return PassResult{}, fmt.Errorf("pass0: ai client not configured (sdk=%v agentID=%q)", c.sdk != nil, agentID)
	}
	c.limiter.wait()
	storeFalse := false
	resp, err := c.sdk.StartConversation(ctx, &conversation.StartRequest{
		AgentID: agentID,
		Inputs:  conversation.TextInputs(prompt),
		Store:   &storeFalse, // do not persist the conversation server-side
	})
	if err != nil {
		return PassResult{}, fmt.Errorf("pass0 conversation: %w", err)
	}
	content := extractConversationContent(resp)
	if content == "" {
		return PassResult{}, fmt.Errorf("pass0: no assistant message in response")
	}
	return PassResult{
		Content: content,
		Usage:   convertConvUsage(resp.Usage),
		Model:   "agent:" + agentID,
	}, nil
}
// Pass2 uses chat completions for description generation + retry fields,
// forcing a JSON-object response format.
func (c *Client) Pass2(ctx context.Context, systemPrompt, userPrompt string) (PassResult, error) {
	// Guard before touching the SDK: consistent with Pass0/Pass1, and
	// prevents a nil-pointer dereference when no API key was configured.
	if c.sdk == nil {
		return PassResult{}, fmt.Errorf("pass2: ai client not configured")
	}
	c.limiter.wait()
	resp, err := c.sdk.ChatComplete(ctx, &chat.CompletionRequest{
		Model: c.modelComplex,
		Messages: []chat.Message{
			&chat.SystemMessage{Content: chat.TextContent(systemPrompt)},
			&chat.UserMessage{Content: chat.TextContent(userPrompt)},
		},
		ResponseFormat: &chat.ResponseFormat{Type: "json_object"},
	})
	if err != nil {
		return PassResult{}, fmt.Errorf("pass2 chat: %w", err)
	}
	if len(resp.Choices) == 0 {
		return PassResult{}, fmt.Errorf("pass2: no choices in response")
	}
	return PassResult{
		Content: resp.Choices[0].Message.Content.String(),
		Usage:   convertChatUsage(resp.Usage),
		Model:   c.modelComplex,
	}, nil
}
// extractConversationContent returns the text of the first message output
// entry in a conversation response, or "" when none exists.
func extractConversationContent(resp *conversation.Response) string {
	for _, entry := range resp.Outputs {
		if msg, ok := entry.(*conversation.MessageOutputEntry); ok {
			return msg.Content.String()
		}
	}
	return ""
}

// convertConvUsage maps Conversations-API usage counters to UsageInfo.
func convertConvUsage(u conversation.UsageInfo) *UsageInfo {
	return &UsageInfo{
		PromptTokens:     u.PromptTokens,
		CompletionTokens: u.CompletionTokens,
		TotalTokens:      u.TotalTokens,
	}
}

// convertChatUsage maps chat-completions usage counters to UsageInfo.
func convertChatUsage(u chat.UsageInfo) *UsageInfo {
	return &UsageInfo{
		PromptTokens:     u.PromptTokens,
		CompletionTokens: u.CompletionTokens,
		TotalTokens:      u.TotalTokens,
	}
}

View File

@@ -0,0 +1,99 @@
package ai
import (
"context"
"errors"
"fmt"
"net"
"strings"
)
// ErrorCode classifies provider failures so callers can make retry
// decisions without string-matching error messages.
type ErrorCode int

const (
	ErrInternal ErrorCode = iota
	ErrRateLimited
	ErrQuotaExceeded
	ErrTimeout
	ErrInvalidRequest
	ErrUnavailable
	ErrSchemaViolation
)

// errorCodeNames is indexed by ErrorCode; keep in sync with the const block.
var errorCodeNames = [...]string{
	ErrInternal:        "internal",
	ErrRateLimited:     "rate_limited",
	ErrQuotaExceeded:   "quota_exceeded",
	ErrTimeout:         "timeout",
	ErrInvalidRequest:  "invalid_request",
	ErrUnavailable:     "unavailable",
	ErrSchemaViolation: "schema_violation",
}

// String returns the snake_case label for c; out-of-range values report
// "internal".
func (c ErrorCode) String() string {
	if c < 0 || int(c) >= len(errorCodeNames) {
		return "internal"
	}
	return errorCodeNames[c]
}
// ProviderError is the structured error returned by Provider
// implementations and by ClassifyError.
type ProviderError struct {
	Code      ErrorCode // coarse failure class; retry decisions key off this
	Message   string    // human-readable detail
	Retryable bool      // true when a retry (with backoff) may succeed
	Inner     error     // wrapped cause, if any
	RawOutput string    // raw model output, e.g. populated on schema violations
}

// Error implements the error interface; the inner cause is appended when set.
func (e *ProviderError) Error() string {
	if e.Inner != nil {
		return fmt.Sprintf("ai: %s: %s: %v", e.Code, e.Message, e.Inner)
	}
	return fmt.Sprintf("ai: %s: %s", e.Code, e.Message)
}

// Unwrap exposes the wrapped cause to errors.Is / errors.As.
func (e *ProviderError) Unwrap() error { return e.Inner }
// ClassifyError maps an arbitrary error onto a *ProviderError. An existing
// ProviderError anywhere in the chain passes through unchanged; a context
// deadline becomes a retryable timeout; everything else is classified by
// substring heuristics on the lowercased message — so CASE ORDER below is
// significant (e.g. a message containing both "429" and "timeout" is
// classified as rate-limited).
//
// NOTE(review): context.Canceled is not special-cased and falls through to
// ErrInternal/non-retryable, and the bare "invalid" substring is broad and
// may misclassify unrelated messages — confirm both are acceptable.
func ClassifyError(err error) *ProviderError {
	if err == nil {
		return nil
	}
	var pe *ProviderError
	if errors.As(err, &pe) {
		return pe
	}
	if errors.Is(err, context.DeadlineExceeded) {
		return &ProviderError{Code: ErrTimeout, Message: "context deadline exceeded", Retryable: true, Inner: err}
	}
	msg := strings.ToLower(err.Error())
	switch {
	case strings.Contains(msg, "429"),
		strings.Contains(msg, "too many requests"),
		strings.Contains(msg, "rate limit"):
		return &ProviderError{Code: ErrRateLimited, Message: err.Error(), Retryable: true, Inner: err}
	case strings.Contains(msg, "deadline exceeded"),
		strings.Contains(msg, "timeout"):
		return &ProviderError{Code: ErrTimeout, Message: err.Error(), Retryable: true, Inner: err}
	case strings.Contains(msg, "connection refused"),
		strings.Contains(msg, "no such host"),
		isNetError(err):
		return &ProviderError{Code: ErrUnavailable, Message: err.Error(), Retryable: true, Inner: err}
	case strings.Contains(msg, "quota"),
		strings.Contains(msg, "insufficient"):
		return &ProviderError{Code: ErrQuotaExceeded, Message: err.Error(), Retryable: false, Inner: err}
	case strings.Contains(msg, "400"),
		strings.Contains(msg, "invalid"):
		return &ProviderError{Code: ErrInvalidRequest, Message: err.Error(), Retryable: false, Inner: err}
	}
	// Unrecognized errors default to non-retryable internal failures.
	return &ProviderError{Code: ErrInternal, Message: err.Error(), Retryable: false, Inner: err}
}
// isNetError reports whether err (anywhere in its chain) is a net.Error.
func isNetError(err error) bool {
	var ne net.Error
	return errors.As(err, &ne)
}

View File

@@ -0,0 +1,41 @@
package ai
import (
"context"
"fmt"
"time"
mistral "github.com/VikingOwl91/mistral-go-sdk"
"github.com/VikingOwl91/mistral-go-sdk/chat"
"marktvogt.de/backend/internal/config"
)
// Provider identifiers accepted in config.
const (
	providerOllama  = "ollama"
	providerMistral = "mistral"
)

// NewFromConfig builds the configured Provider. An empty provider name
// defaults to Ollama. The Mistral path requires an API key and configures
// the SDK with a 120s timeout, 2 retries, and client-side rate limiting.
//
// NOTE(review): the Ollama branch does not validate OllamaURL — an empty
// value would produce requests against a relative URL; confirm config
// supplies a default upstream.
func NewFromConfig(cfg config.AIConfig) (Provider, error) {
	switch cfg.Provider {
	case "", providerOllama:
		return NewOllamaProvider(OllamaConfig{
			BaseURL: cfg.OllamaURL,
			Model:   cfg.OllamaModel,
		}), nil
	case providerMistral:
		if cfg.MistralAPIKey == "" {
			return nil, fmt.Errorf("ai: provider=%s requires AI_MISTRAL_API_KEY", providerMistral)
		}
		sdk := mistral.NewClient(
			cfg.MistralAPIKey,
			mistral.WithTimeout(120*time.Second),
			mistral.WithRetry(2, 1*time.Second),
		)
		// Adapt the SDK call to the chatFunc seam used by the provider (and
		// by its tests, which inject fakes here).
		chatFn := func(ctx context.Context, req *chat.CompletionRequest) (*chat.CompletionResponse, error) {
			return sdk.ChatComplete(ctx, req)
		}
		return newMistralProviderWithChat(cfg.MistralModel, chatFn, newRateLimiter(cfg.RateLimitRPS)), nil
	default:
		return nil, fmt.Errorf("ai: unknown provider %q (want %s|%s)", cfg.Provider, providerOllama, providerMistral)
	}
}

View File

@@ -0,0 +1,29 @@
package ai
import (
"testing"
"marktvogt.de/backend/internal/config"
)
func TestNewFromConfig_Ollama(t *testing.T) {
p, err := NewFromConfig(config.AIConfig{Provider: providerOllama, OllamaURL: "http://x:11434", OllamaModel: "m"})
if err != nil {
t.Fatalf("NewFromConfig: %v", err)
}
if p.Name() != providerOllama {
t.Fatalf("Name: %q", p.Name())
}
}
func TestNewFromConfig_MistralRequiresKey(t *testing.T) {
if _, err := NewFromConfig(config.AIConfig{Provider: providerMistral}); err == nil {
t.Fatal("want error when MistralAPIKey is empty")
}
}
func TestNewFromConfig_UnknownProvider(t *testing.T) {
if _, err := NewFromConfig(config.AIConfig{Provider: "llama-cpp"}); err == nil {
t.Fatal("want error for unknown provider")
}
}

View File

@@ -0,0 +1,99 @@
package ai
import (
"context"
"fmt"
"github.com/VikingOwl91/mistral-go-sdk/chat"
)
// chatFunc abstracts the Mistral SDK's chat call so tests can inject fakes.
type chatFunc func(ctx context.Context, req *chat.CompletionRequest) (*chat.CompletionResponse, error)

// MistralProvider implements Provider on top of the Mistral chat API.
type MistralProvider struct {
	model   string      // default model when ChatRequest.Model is empty
	chatFn  chatFunc    // injected SDK call; nil means "not configured"
	limiter *rateLimiter // from ratelimit.go; nil disables
}

// newMistralProviderWithChat wires a provider with an explicit chat
// function (used by the factory and by tests).
func newMistralProviderWithChat(model string, fn chatFunc, limiter *rateLimiter) *MistralProvider {
	return &MistralProvider{model: model, chatFn: fn, limiter: limiter}
}

// Name identifies the provider in config and logs.
func (p *MistralProvider) Name() string { return "mistral" }

// SupportsJSONMode: the API accepts response_format "json_object".
func (p *MistralProvider) SupportsJSONMode() bool { return true }

// SupportsJSONSchema: no native schema enforcement; Chat embeds the schema
// in the system prompt instead.
func (p *MistralProvider) SupportsJSONSchema() bool { return false }
// Chat implements Provider via the Mistral chat-completions API.
//
// Because there is no native schema mode (SupportsJSONSchema is false), a
// JSONSchema is embedded into the system prompt and the reply is validated
// locally with ValidateSchema; a mismatch returns a retryable
// ErrSchemaViolation carrying the raw output. JSONMode or any schema
// requests response_format "json_object".
//
// NOTE(review): Temperature==0 and MaxTokens==0 are treated as "unset", so
// an explicit zero cannot be requested; also creq.MaxTokens takes the
// address of the caller's req.MaxTokens field — confirm the SDK does not
// retain it past the call.
func (p *MistralProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
	if p.chatFn == nil {
		return nil, &ProviderError{Code: ErrInternal, Message: "mistral provider not configured", Retryable: false}
	}
	if p.limiter != nil {
		p.limiter.wait()
	}
	// Embed the schema instruction after the caller's system prompt.
	systemContent := req.SystemPrompt
	if len(req.JSONSchema) > 0 {
		if systemContent != "" {
			systemContent += "\n\n"
		}
		systemContent += "Respond with a JSON object that conforms to the following JSON Schema. " +
			"Do not output anything outside the JSON. Schema:\n" + string(req.JSONSchema)
	}
	msgs := []chat.Message{}
	if systemContent != "" {
		msgs = append(msgs, &chat.SystemMessage{Content: chat.TextContent(systemContent)})
	}
	msgs = append(msgs, &chat.UserMessage{Content: chat.TextContent(req.UserMessage)})
	creq := &chat.CompletionRequest{
		Model:    firstNonEmpty(req.Model, p.model),
		Messages: msgs,
	}
	if req.JSONMode || len(req.JSONSchema) > 0 {
		creq.ResponseFormat = &chat.ResponseFormat{Type: "json_object"}
	}
	if req.Temperature != 0 {
		temp := float64(req.Temperature)
		creq.Temperature = &temp
	}
	if req.MaxTokens != 0 {
		creq.MaxTokens = &req.MaxTokens
	}
	resp, err := p.chatFn(ctx, creq)
	if err != nil {
		return nil, ClassifyError(err)
	}
	if len(resp.Choices) == 0 {
		return nil, &ProviderError{Code: ErrInternal, Message: "no choices in response", Retryable: false}
	}
	content := resp.Choices[0].Message.Content.String()
	// Local schema validation; failures are retryable and keep the raw
	// output so callers can log or re-prompt.
	if len(req.JSONSchema) > 0 {
		if err := ValidateSchema(req.JSONSchema, []byte(content)); err != nil {
			return nil, &ProviderError{
				Code:      ErrSchemaViolation,
				Message:   fmt.Sprintf("response does not match schema: %v", err),
				Retryable: true,
				Inner:     err,
				RawOutput: content,
			}
		}
	}
	return &ChatResponse{
		Content:      content,
		Model:        resp.Model,
		PromptTokens: resp.Usage.PromptTokens,
		OutputTokens: resp.Usage.CompletionTokens,
		TotalTokens:  resp.Usage.TotalTokens,
	}, nil
}
// firstNonEmpty returns a unless it is the empty string, in which case b.
func firstNonEmpty(a, b string) string {
	if a == "" {
		return b
	}
	return a
}

View File

@@ -0,0 +1,123 @@
package ai
import (
"context"
"errors"
"testing"
"github.com/VikingOwl91/mistral-go-sdk/chat"
)
func TestMistral_Chat_PassesThroughContent(t *testing.T) {
fakeChat := func(ctx context.Context, req *chat.CompletionRequest) (*chat.CompletionResponse, error) {
return &chat.CompletionResponse{
Model: "mistral-large-latest",
Choices: []chat.CompletionChoice{
{Message: chat.AssistantMessage{Content: chat.TextContent("ok")}},
},
Usage: chat.UsageInfo{PromptTokens: 3, CompletionTokens: 1, TotalTokens: 4},
}, nil
}
p := newMistralProviderWithChat("mistral-large-latest", fakeChat, nil)
resp, err := p.Chat(context.Background(), &ChatRequest{SystemPrompt: "s", UserMessage: "u"})
if err != nil {
t.Fatalf("Chat: %v", err)
}
if resp.Content != "ok" || resp.TotalTokens != 4 {
t.Fatalf("unexpected: %+v", resp)
}
}
func TestMistral_Chat_JSONModeSetsResponseFormat(t *testing.T) {
var seen *chat.CompletionRequest
fakeChat := func(ctx context.Context, req *chat.CompletionRequest) (*chat.CompletionResponse, error) {
seen = req
return &chat.CompletionResponse{Choices: []chat.CompletionChoice{{Message: chat.AssistantMessage{Content: chat.TextContent("{}")}}}}, nil
}
p := newMistralProviderWithChat("m", fakeChat, nil)
_, err := p.Chat(context.Background(), &ChatRequest{UserMessage: "x", JSONMode: true})
if err != nil {
t.Fatalf("Chat: %v", err)
}
if seen == nil || seen.ResponseFormat == nil || seen.ResponseFormat.Type != "json_object" {
t.Fatalf("ResponseFormat not set: %+v", seen)
}
}
func TestMistral_Chat_SchemaEmbeddedInSystemPromptAndValidated(t *testing.T) {
schema := []byte(`{"type":"object","required":["foo"],"properties":{"foo":{"type":"string"}},"additionalProperties":false}`)
var seen *chat.CompletionRequest
fakeChat := func(ctx context.Context, req *chat.CompletionRequest) (*chat.CompletionResponse, error) {
seen = req
return &chat.CompletionResponse{Choices: []chat.CompletionChoice{{Message: chat.AssistantMessage{Content: chat.TextContent(`{"foo":"bar"}`)}}}}, nil
}
p := newMistralProviderWithChat("m", fakeChat, nil)
resp, err := p.Chat(context.Background(), &ChatRequest{SystemPrompt: "base system", UserMessage: "x", JSONSchema: schema})
if err != nil {
t.Fatalf("Chat: %v", err)
}
if resp.Content != `{"foo":"bar"}` {
t.Fatalf("content: %q", resp.Content)
}
sysMsg, ok := seen.Messages[0].(*chat.SystemMessage)
if !ok {
t.Fatalf("first message must be system: %T", seen.Messages[0])
}
sys := sysMsg.Content.String()
if !containsAll(sys, []string{"base system", "JSON Schema"}) {
t.Fatalf("system prompt missing expected fragments: %q", sys)
}
}
func TestMistral_Chat_SchemaViolationReturnsRetryableError(t *testing.T) {
schema := []byte(`{"type":"object","required":["foo"],"properties":{"foo":{"type":"string"}},"additionalProperties":false}`)
fakeChat := func(ctx context.Context, req *chat.CompletionRequest) (*chat.CompletionResponse, error) {
return &chat.CompletionResponse{Choices: []chat.CompletionChoice{{Message: chat.AssistantMessage{Content: chat.TextContent(`{"bar":1}`)}}}}, nil
}
p := newMistralProviderWithChat("m", fakeChat, nil)
_, err := p.Chat(context.Background(), &ChatRequest{UserMessage: "x", JSONSchema: schema})
if err == nil {
t.Fatal("want error")
}
var pe *ProviderError
if !errors.As(err, &pe) {
t.Fatalf("want *ProviderError, got %T", err)
}
if pe.Code != ErrSchemaViolation || !pe.Retryable || pe.RawOutput != `{"bar":1}` {
t.Fatalf("unexpected: %+v", pe)
}
}
func TestMistral_Supports(t *testing.T) {
p := newMistralProviderWithChat("m", nil, nil)
if !p.SupportsJSONMode() {
t.Fatal("Mistral supports JSON mode")
}
if p.SupportsJSONSchema() {
t.Fatal("Mistral does NOT natively support JSON schema (prompt-based only)")
}
if p.Name() != "mistral" {
t.Fatalf("Name: %q", p.Name())
}
}
// containsAll reports whether s contains every element of parts.
func containsAll(s string, parts []string) bool {
	for _, p := range parts {
		if !contains(s, p) {
			return false
		}
	}
	return true
}

// contains reports whether sub occurs in s. indexOf already handles the
// empty-substring, equal-strings, and sub-longer-than-s cases, so the
// previous special-casing was redundant. (strings.Contains would do, but
// these helpers deliberately avoid extra imports in the test file.)
func contains(s, sub string) bool {
	return indexOf(s, sub) >= 0
}

// indexOf returns the byte index of the first occurrence of sub in s, or
// -1. Naive O(len(s)*len(sub)) scan — fine for short test strings.
func indexOf(s, sub string) int {
	for i := 0; i+len(sub) <= len(s); i++ {
		if s[i:i+len(sub)] == sub {
			return i
		}
	}
	return -1
}

View File

@@ -0,0 +1,127 @@
package ai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
// OllamaConfig configures an OllamaProvider.
type OllamaConfig struct {
	BaseURL string        // server root, e.g. "http://localhost:11434"
	Model   string        // default model when ChatRequest.Model is empty
	Timeout time.Duration // HTTP client timeout; 0 selects the default
}

// OllamaProvider talks to an Ollama server's /api/chat endpoint.
type OllamaProvider struct {
	cfg    OllamaConfig
	client *http.Client
}

// NewOllamaProvider builds a provider. A zero Timeout defaults to five
// minutes — local models can be slow to load and generate.
func NewOllamaProvider(cfg OllamaConfig) *OllamaProvider {
	if cfg.Timeout == 0 {
		cfg.Timeout = 5 * time.Minute
	}
	return &OllamaProvider{cfg: cfg, client: &http.Client{Timeout: cfg.Timeout}}
}

// Name identifies the provider in config and logs.
func (p *OllamaProvider) Name() string { return "ollama" }

// SupportsJSONMode: Ollama accepts format "json".
func (p *OllamaProvider) SupportsJSONMode() bool { return true }

// SupportsJSONSchema: Ollama accepts a JSON schema as the format field.
func (p *OllamaProvider) SupportsJSONSchema() bool { return true }
// ollamaChatReq is the JSON body for POST /api/chat (non-streaming).
type ollamaChatReq struct {
	Model    string          `json:"model"`
	Messages []ollamaMessage `json:"messages"`
	Stream   bool            `json:"stream"`
	Format   json.RawMessage `json:"format,omitempty"` // the string "json" or a JSON-schema object
	Options  *ollamaOptions  `json:"options,omitempty"`
}

// ollamaMessage is one chat turn ("system", "user", or "assistant").
type ollamaMessage struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// ollamaOptions carries sampling parameters; zero values are omitted.
type ollamaOptions struct {
	Temperature float32 `json:"temperature,omitempty"`
	NumPredict  int     `json:"num_predict,omitempty"`
}

// ollamaChatResp is the non-streaming reply; the eval counts are mapped to
// prompt/output token counts by Chat.
type ollamaChatResp struct {
	Model           string        `json:"model"`
	Message         ollamaMessage `json:"message"`
	Done            bool          `json:"done"`
	PromptEvalCount int           `json:"prompt_eval_count"`
	EvalCount       int           `json:"eval_count"`
}
// Chat sends one non-streaming completion to POST {BaseURL}/api/chat and
// maps the reply into a ChatResponse.
//
// JSON control: a JSONSchema is forwarded verbatim as Ollama's "format"
// field (native enforcement); otherwise JSONMode sets format to the string
// "json". Temperature/MaxTokens are only sent when non-zero, so an explicit
// zero cannot be requested.
func (p *OllamaProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
	// Per-request model override falls back to the configured default.
	model := req.Model
	if model == "" {
		model = p.cfg.Model
	}
	body := ollamaChatReq{
		Model:    model,
		Messages: buildOllamaMessages(req),
		Stream:   false,
	}
	switch {
	case len(req.JSONSchema) > 0:
		body.Format = req.JSONSchema
	case req.JSONMode:
		body.Format = json.RawMessage(`"json"`)
	}
	if req.Temperature != 0 || req.MaxTokens != 0 {
		body.Options = &ollamaOptions{Temperature: req.Temperature, NumPredict: req.MaxTokens}
	}
	buf, err := json.Marshal(body)
	if err != nil {
		return nil, &ProviderError{Code: ErrInvalidRequest, Message: "marshal request", Retryable: false, Inner: err}
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.cfg.BaseURL+"/api/chat", bytes.NewReader(buf))
	if err != nil {
		return nil, &ProviderError{Code: ErrInternal, Message: "new request", Retryable: false, Inner: err}
	}
	httpReq.Header.Set("Content-Type", "application/json")
	resp, err := p.client.Do(httpReq)
	if err != nil {
		// Transport failures (refused, DNS, timeout) get classified.
		return nil, ClassifyError(err)
	}
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode >= 400 {
		// Include the response body in the message so ClassifyError's
		// substring heuristics can see it.
		b, _ := io.ReadAll(resp.Body)
		pe := ClassifyError(fmt.Errorf("ollama status %d: %s", resp.StatusCode, string(b)))
		return nil, pe
	}
	var out ollamaChatResp
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return nil, &ProviderError{Code: ErrInternal, Message: "decode response", Retryable: false, Inner: err}
	}
	return &ChatResponse{
		Content:      out.Message.Content,
		Model:        out.Model,
		PromptTokens: out.PromptEvalCount,
		OutputTokens: out.EvalCount,
		TotalTokens:  out.PromptEvalCount + out.EvalCount,
	}, nil
}
// buildOllamaMessages converts a ChatRequest into Ollama wire messages,
// omitting empty system/user entries; the result is always a non-nil slice.
func buildOllamaMessages(req *ChatRequest) []ollamaMessage {
	msgs := make([]ollamaMessage, 0, 2)
	for _, m := range []ollamaMessage{
		{Role: "system", Content: req.SystemPrompt},
		{Role: "user", Content: req.UserMessage},
	} {
		if m.Content != "" {
			msgs = append(msgs, m)
		}
	}
	return msgs
}

View File

@@ -0,0 +1,98 @@
package ai
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
)
func TestOllama_Chat_SendsRequestAndParsesResponse(t *testing.T) {
var captured map[string]any
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/chat" {
t.Errorf("path: got %s, want /api/chat", r.URL.Path)
}
body, _ := io.ReadAll(r.Body)
_ = json.Unmarshal(body, &captured)
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{
"model":"qwen2.5:14b-instruct",
"message":{"role":"assistant","content":"hello"},
"done":true,
"prompt_eval_count":10,
"eval_count":5
}`))
}))
defer srv.Close()
p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "qwen2.5:14b-instruct", Timeout: 5 * time.Second})
resp, err := p.Chat(context.Background(), &ChatRequest{SystemPrompt: "be brief", UserMessage: "hi", JSONMode: true})
if err != nil {
t.Fatalf("Chat: %v", err)
}
if resp.Content != "hello" {
t.Fatalf("content: got %q", resp.Content)
}
if resp.PromptTokens != 10 || resp.OutputTokens != 5 || resp.TotalTokens != 15 {
t.Fatalf("tokens: %+v", resp)
}
if captured["stream"] != false {
t.Fatalf("stream must be false: %v", captured["stream"])
}
if captured["format"] != "json" {
t.Fatalf("format for JSONMode=true must be \"json\", got %v", captured["format"])
}
}
func TestOllama_Chat_ForwardsSchema(t *testing.T) {
var captured map[string]any
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
_ = json.Unmarshal(body, &captured)
_, _ = w.Write([]byte(`{"model":"m","message":{"role":"assistant","content":"{}"},"done":true}`))
}))
defer srv.Close()
p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "m", Timeout: time.Second})
schema := []byte(`{"type":"object"}`)
_, err := p.Chat(context.Background(), &ChatRequest{UserMessage: "x", JSONSchema: schema})
if err != nil {
t.Fatalf("Chat: %v", err)
}
fmtField, ok := captured["format"].(map[string]any)
if !ok {
t.Fatalf("format must be an object when JSONSchema set: %v", captured["format"])
}
if fmtField["type"] != "object" {
t.Fatalf("schema not forwarded: %v", fmtField)
}
}
// TestOllama_Chat_Unavailable points the provider at a dead port and checks
// that the failure classifies as a retryable Unavailable/Timeout error.
func TestOllama_Chat_Unavailable(t *testing.T) {
	provider := NewOllamaProvider(OllamaConfig{BaseURL: "http://127.0.0.1:1", Timeout: 100 * time.Millisecond})
	_, err := provider.Chat(context.Background(), &ChatRequest{UserMessage: "x"})
	if err == nil {
		t.Fatal("want error, got nil")
	}

	classified := ClassifyError(err)
	switch classified.Code {
	case ErrUnavailable, ErrTimeout:
		// expected: connection refused or dial timeout, depending on the OS.
	default:
		t.Fatalf("expected Unavailable or Timeout, got %v", classified.Code)
	}
	if !classified.Retryable {
		t.Fatal("must be retryable")
	}
}
// TestOllama_Supports checks the static capability flags and provider name.
func TestOllama_Supports(t *testing.T) {
	provider := NewOllamaProvider(OllamaConfig{BaseURL: "x"})
	if !(provider.SupportsJSONMode() && provider.SupportsJSONSchema()) {
		t.Fatal("Ollama supports both")
	}
	if got := provider.Name(); got != "ollama" {
		t.Fatalf("Name: %q", got)
	}
}

View File

@@ -0,0 +1,31 @@
package ai
import (
"context"
"encoding/json"
)
// Provider is the pluggable LLM backend abstraction. Implementations
// (e.g. Ollama, Mistral) translate a ChatRequest into one provider-specific
// chat completion call.
type Provider interface {
	// Name identifies the implementation, e.g. "ollama".
	Name() string
	// Chat performs a single chat completion and returns the normalized result.
	Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error)
	// SupportsJSONMode reports whether the backend can force JSON-only output.
	SupportsJSONMode() bool
	// SupportsJSONSchema reports whether the backend can enforce a JSON schema
	// on the output.
	SupportsJSONSchema() bool
}
// ChatRequest describes a single chat completion call. Zero values are the
// defaults: no JSON constraint, no schema, zero temperature/max-tokens
// (provider-side defaults apply — confirm per implementation).
type ChatRequest struct {
	SystemPrompt string
	UserMessage  string
	// Model, when non-empty, presumably overrides the provider's configured
	// model — verify against the implementations.
	Model       string
	MaxTokens   int
	Temperature float32
	// JSONMode asks the provider to emit valid JSON (Ollama: format "json").
	JSONMode bool
	// JSONSchema, when set, is forwarded to the provider as a structured
	// output schema; it takes an object form distinct from JSONMode.
	JSONSchema json.RawMessage
}
// ChatResponse is the provider-normalized result of a Chat call.
type ChatResponse struct {
	// Content is the assistant message text.
	Content string
	// Model is the model name as reported by the provider.
	Model string
	// PromptTokens and OutputTokens are the usage counts reported by the
	// provider; TotalTokens is their sum.
	PromptTokens int
	OutputTokens int
	TotalTokens  int
}

View File

@@ -0,0 +1,89 @@
package ai
import (
"context"
"errors"
"testing"
)
// TestChatRequest_DefaultsAreZeroValues pins the contract that a zero
// ChatRequest means "no constraints, provider defaults".
func TestChatRequest_DefaultsAreZeroValues(t *testing.T) {
	var req ChatRequest
	if req.Temperature != 0 {
		t.Fatalf("Temperature default: got %v, want 0", req.Temperature)
	}
	if req.MaxTokens != 0 {
		t.Fatalf("MaxTokens default: got %v, want 0", req.MaxTokens)
	}
	if req.JSONMode {
		t.Fatal("JSONMode default: got true, want false")
	}
	if req.JSONSchema != nil {
		t.Fatal("JSONSchema default: got non-nil, want nil")
	}
}
// stubProvider is a canned Provider used to exercise the interface contract
// in tests: it returns the configured response/error unconditionally.
type stubProvider struct {
	name string
	resp *ChatResponse
	err  error
}

func (s *stubProvider) Name() string { return s.name }
func (s *stubProvider) SupportsJSONMode() bool { return true }
func (s *stubProvider) SupportsJSONSchema() bool { return true }

// Chat ignores ctx and req and returns the stubbed (resp, err) pair.
func (s *stubProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
	return s.resp, s.err
}
// TestStubProvider_SatisfiesInterface is a compile-time assertion that
// *stubProvider implements Provider.
func TestStubProvider_SatisfiesInterface(t *testing.T) {
	var p Provider = (*stubProvider)(nil)
	_ = p
}
// TestStubProvider_ReturnsError checks that errors configured on the stub
// propagate unchanged through the Provider interface.
func TestStubProvider_ReturnsError(t *testing.T) {
	sentinel := errors.New("boom")
	var provider Provider = &stubProvider{name: "stub", err: sentinel}
	if _, err := provider.Chat(context.Background(), &ChatRequest{}); !errors.Is(err, sentinel) {
		t.Fatalf("got %v, want sentinel", err)
	}
}
// TestProviderError_Error checks that ProviderError renders a non-empty
// message and unwraps to its Inner error via errors.Is.
func TestProviderError_Error(t *testing.T) {
	inner := errors.New("root cause")
	pe := &ProviderError{Code: ErrRateLimited, Message: "slow down", Retryable: true, Inner: inner}
	if msg := pe.Error(); msg == "" {
		t.Fatal("Error() must not be empty")
	}
	if !errors.Is(pe, inner) {
		t.Fatalf("errors.Is must unwrap Inner")
	}
}
// TestClassifyError_DetectsRateLimit maps representative error messages to
// the ErrorCode ClassifyError must assign them.
func TestClassifyError_DetectsRateLimit(t *testing.T) {
	expectations := map[string]ErrorCode{
		"429 Too Many Requests":          ErrRateLimited,
		"context deadline exceeded":      ErrTimeout,
		"connection refused":             ErrUnavailable,
		"something totally unrecognized": ErrInternal,
	}
	for msg, want := range expectations {
		msg, want := msg, want
		t.Run(msg, func(t *testing.T) {
			if got := ClassifyError(errors.New(msg)).Code; got != want {
				t.Fatalf("ClassifyError(%q) = %v, want %v", msg, got, want)
			}
		})
	}
}
// TestClassifyError_PreservesProviderError checks that an already-classified
// *ProviderError passes through ClassifyError untouched (same pointer).
func TestClassifyError_PreservesProviderError(t *testing.T) {
	orig := &ProviderError{Code: ErrSchemaViolation, Message: "x", Retryable: true}
	if ClassifyError(orig) != orig {
		t.Fatal("ClassifyError must return the same *ProviderError instance")
	}
}

View File

@@ -1,25 +0,0 @@
package ai
import "strings"
// IsRateLimit reports whether err looks like a Mistral 429 / web_search
// rate-limit signal. The SDK surfaces these as wrapped errors whose message
// contains "rate limit" and/or "status 429"; we match defensively on both.
//
// Shared helper so domains (discovery, market/research) treat rate limits
// consistently: log at WARN, return a 429 to the caller with a Retry-After
// hint instead of a generic 500.
//
// NOTE(review): the bare "429" marker matches by substring, so unrelated
// tokens such as "42900" also report true (the unit test pins this) —
// confirm whether a token-boundary match is wanted instead.
func IsRateLimit(err error) bool {
	if err == nil {
		return false
	}
	msg := err.Error()
	for _, marker := range []string{"rate limit", "status 429", "429"} {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}

// DefaultRetryAfterSeconds is the hint we send when the SDK error doesn't
// carry a structured Retry-After. 60s matches the typical web_search budget
// window on Mistral's paid tier.
const DefaultRetryAfterSeconds = 60

View File

@@ -1,28 +0,0 @@
package ai
import (
"errors"
"testing"
)
// TestIsRateLimit table-tests the rate-limit heuristic, including the known
// substring quirk on the bare "429" marker.
func TestIsRateLimit(t *testing.T) {
	cases := []struct {
		name string
		err  error
		want bool
	}{
		{"nil", nil, false},
		{"unrelated error", errors.New("connection refused"), false},
		{"mistral web_search 429", errors.New("pass1 conversation: mistral: web_search rate limit reached. (status 429)"), true},
		{"bare status 429", errors.New("upstream returned status 429"), true},
		{"raw 429 token", errors.New("http 429 received"), true},
		{"token 42900 should NOT match bare token but does by substring", errors.New("code 42900 from upstream"), true},
	}
	for _, c := range cases {
		c := c
		t.Run(c.name, func(t *testing.T) {
			got := IsRateLimit(c.err)
			if got != c.want {
				t.Errorf("IsRateLimit(%v) = %v, want %v", c.err, got, c.want)
			}
		})
	}
}

View File

@@ -0,0 +1,33 @@
package ai
import (
"sync"
"time"
)
// rateLimiter enforces a minimum interval between calls. Set rps<=0 to disable.
// Callers block in wait() while holding the mutex, so concurrent callers are
// serialized and each is spaced minInterval apart.
type rateLimiter struct {
	mu          sync.Mutex
	lastReq     time.Time
	minInterval time.Duration
}

// newRateLimiter builds a limiter allowing at most rps calls per second;
// rps<=0 yields a no-op limiter (minInterval stays 0).
func newRateLimiter(rps float64) *rateLimiter {
	rl := &rateLimiter{}
	if rps > 0 {
		rl.minInterval = time.Duration(float64(time.Second) / rps)
	}
	return rl
}

// wait blocks until at least minInterval has elapsed since the previous call,
// then records the new call time. A zero minInterval makes it a no-op.
func (rl *rateLimiter) wait() {
	if rl.minInterval == 0 {
		return
	}
	rl.mu.Lock()
	defer rl.mu.Unlock()
	if pause := rl.minInterval - time.Since(rl.lastReq); pause > 0 {
		time.Sleep(pause)
	}
	rl.lastReq = time.Now()
}

View File

@@ -0,0 +1,28 @@
package ai
import (
"bytes"
"encoding/json"
"fmt"
"github.com/santhosh-tekuri/jsonschema/v5"
)
// ValidateSchema checks docJSON against the JSON Schema in schemaJSON.
// It returns a wrapped error when the schema fails to load or compile,
// when docJSON is not valid JSON, or when the document violates the schema;
// nil means the document conforms.
func ValidateSchema(schemaJSON, docJSON []byte) error {
	compiler := jsonschema.NewCompiler()
	if err := compiler.AddResource("schema.json", bytes.NewReader(schemaJSON)); err != nil {
		return fmt.Errorf("add schema: %w", err)
	}
	schema, err := compiler.Compile("schema.json")
	if err != nil {
		return fmt.Errorf("compile schema: %w", err)
	}

	// Decode into a generic value: the validator works on decoded JSON,
	// not raw bytes.
	var document any
	if err := json.Unmarshal(docJSON, &document); err != nil {
		return fmt.Errorf("parse doc: %w", err)
	}

	if err := schema.Validate(document); err != nil {
		return fmt.Errorf("validate: %w", err)
	}
	return nil
}

View File

@@ -0,0 +1,25 @@
package ai
import "testing"
// TestValidateSchema_AcceptsMatchingDoc: a conforming document validates cleanly.
func TestValidateSchema_AcceptsMatchingDoc(t *testing.T) {
	schema := []byte(`{"type":"object","required":["foo"],"properties":{"foo":{"type":"string"}},"additionalProperties":false}`)
	document := []byte(`{"foo":"bar"}`)
	err := ValidateSchema(schema, document)
	if err != nil {
		t.Fatalf("want nil, got %v", err)
	}
}
// TestValidateSchema_RejectsMissingRequired: omitting a required property fails.
func TestValidateSchema_RejectsMissingRequired(t *testing.T) {
	schema := []byte(`{"type":"object","required":["foo"],"properties":{"foo":{"type":"string"}}}`)
	empty := []byte(`{}`)
	if ValidateSchema(schema, empty) == nil {
		t.Fatal("want error, got nil")
	}
}
// TestValidateSchema_RejectsBadJSON: a non-JSON document is rejected.
func TestValidateSchema_RejectsBadJSON(t *testing.T) {
	err := ValidateSchema([]byte(`{}`), []byte(`not json`))
	if err == nil {
		t.Fatal("want error, got nil")
	}
}

View File

@@ -1,6 +1,7 @@
package server
import (
"fmt"
"net/http"
"github.com/gin-gonic/gin"
@@ -68,23 +69,20 @@ func (s *Server) registerRoutes() {
// Admin market routes
adminMarketHandler := market.NewAdminHandler(marketSvc)
aiClient := ai.New(s.cfg.AI.APIKey, s.cfg.AI.AgentSimple, s.cfg.AI.ModelComplex, s.cfg.AI.RateLimitRPS)
researchHandler := market.NewResearchHandler(marketSvc, aiClient)
aiProvider, err := ai.NewFromConfig(s.cfg.AI)
if err != nil {
panic(fmt.Errorf("init ai provider: %w", err))
}
researchHandler := market.NewResearchHandler(marketSvc, aiProvider)
requireAdmin := middleware.RequireRole("admin")
market.RegisterAdminRoutes(v1, adminMarketHandler, researchHandler, requireAuth, requireAdmin)
// Discovery routes
discoveryRepo := discovery.NewRepository(s.db)
crawlerInstance := crawler.NewCrawler(s.cfg.Discovery.CrawlerUserAgent, crawler.DefaultSourceConfigs())
// Per-row LLM enrichment (MR 3b). Operator-triggered only; disabled rows
// fall through via NoopLLMEnricher when AI isn't configured.
var llmEnricher enrich.LLMEnricher = enrich.NoopLLMEnricher{}
var simClassifier enrich.SimilarityClassifier = enrich.NoopSimilarityClassifier{}
if aiClient.Enabled() {
scraper := scrape.New(s.cfg.Discovery.CrawlerUserAgent)
llmEnricher = enrich.NewMistralLLMEnricher(aiClient, scraper)
simClassifier = enrich.NewMistralSimilarityClassifier(aiClient)
}
scraper := scrape.New(s.cfg.Discovery.CrawlerUserAgent)
llmEnricher := enrich.NewLLMEnricher(aiProvider, scraper)
simClassifier := enrich.NewSimilarityClassifier(aiProvider)
discoveryService := discovery.NewService(discoveryRepo, crawlerInstance, discovery.NewLinkChecker(), marketSvc, geocoder, llmEnricher, simClassifier)
discoveryHandler := discovery.NewHandler(discoveryService, s.cfg.Discovery.CrawlerManualRateLimitPerHour)
requireTickToken := middleware.RequireBearerToken(s.cfg.Discovery.Token)