diff --git a/backend/cmd/discovery-eval/main.go b/backend/cmd/discovery-eval/main.go index bec1200..71bbed8 100644 --- a/backend/cmd/discovery-eval/main.go +++ b/backend/cmd/discovery-eval/main.go @@ -29,6 +29,7 @@ import ( "os" "time" + "marktvogt.de/backend/internal/config" "marktvogt.de/backend/internal/domain/discovery/enrich" "marktvogt.de/backend/internal/pkg/ai" "marktvogt.de/backend/internal/pkg/scrape" @@ -65,8 +66,14 @@ func realMain() int { ) flag.Parse() - apiKey := os.Getenv("AI_API_KEY") - model := os.Getenv("AI_MODEL_COMPLEX") + apiKey := os.Getenv("AI_MISTRAL_API_KEY") + if apiKey == "" { + apiKey = os.Getenv("AI_API_KEY") // legacy fallback + } + model := os.Getenv("AI_MISTRAL_MODEL") + if model == "" { + model = os.Getenv("AI_MODEL_COMPLEX") // legacy fallback + } if model == "" { model = "mistral-large-latest" } @@ -74,9 +81,14 @@ func realMain() int { if userAgent == "" { userAgent = "marktvogt-eval/1.0 (+https://marktvogt.de)" } - client := ai.New(apiKey, "", model, 1.0) - if !client.Enabled() { - slog.Error("AI client not configured (set AI_API_KEY)") + client, err := ai.NewFromConfig(config.AIConfig{ + Provider: "mistral", + MistralAPIKey: apiKey, + MistralModel: model, + RateLimitRPS: 1.0, + }) + if err != nil { + slog.Error("AI client not configured", "error", err) return 2 } @@ -95,7 +107,7 @@ func realMain() int { return runSimilarityMode(ctx, client, cfg) case modeCategory: scraper := scrape.New(userAgent) - enricher := enrich.NewMistralLLMEnricher(client, scraper) + enricher := enrich.NewLLMEnricher(client, scraper) cfg := evalConfig{ model: model, fixturePath: pathWithDefault(*fixturePath, "backend/cmd/discovery-eval/fixtures/category.json"), @@ -117,7 +129,7 @@ func pathWithDefault(p, dflt string) string { return p } -func runSimilarityMode(ctx context.Context, client *ai.Client, cfg evalConfig) int { +func runSimilarityMode(ctx context.Context, client ai.Provider, cfg evalConfig) int { fixture, err := loadFixture(cfg.fixturePath) if 
err != nil { slog.Error("load fixture", "path", cfg.fixturePath, "error", err) @@ -125,7 +137,7 @@ func runSimilarityMode(ctx context.Context, client *ai.Client, cfg evalConfig) i } slog.Info("loaded fixture", "mode", modeSimilarity, "pairs", len(fixture.Pairs), "path", cfg.fixturePath) - classifier := enrich.NewMistralSimilarityClassifier(client) + classifier := enrich.NewSimilarityClassifier(client) cache, err := loadCache(cfg.cachePath) if err != nil { diff --git a/backend/go.mod b/backend/go.mod index d8813ee..bbce57f 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -11,6 +11,7 @@ require ( github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.8.0 github.com/pquerna/otp v1.5.0 + github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 github.com/valkey-io/valkey-go v1.0.72 golang.org/x/crypto v0.49.0 golang.org/x/oauth2 v0.35.0 diff --git a/backend/go.sum b/backend/go.sum index c640361..242e48f 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -73,6 +73,8 @@ github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= github.com/quic-go/quic-go v0.57.0 h1:AsSSrrMs4qI/hLrKlTH/TGQeTMY0ib1pAOX7vA3AdqE= github.com/quic-go/quic-go v0.57.0/go.mod h1:ly4QBAjHA2VhdnxhojRsCUOeJwKYg+taDlos92xb1+s= +github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 h1:lZUw3E0/J3roVtGQ+SCrUrg3ON6NgVqpn3+iol9aGu4= +github.com/santhosh-tekuri/jsonschema/v5 v5.3.1/go.mod h1:uToXkOrWAZ6/Oc07xWQrPOhJotwFIyu2bBVN41fcDUY= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 3f46661..357419a 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -22,6 +22,7 @@ type 
Config struct { Turnstile TurnstileConfig Notification NotificationConfig AI AIConfig + Search SearchConfig Discovery DiscoveryConfig } @@ -32,10 +33,19 @@ type DiscoveryConfig struct { } type AIConfig struct { - APIKey string - AgentSimple string // Pre-created Mistral agent ID for Pass 1 (extraction + web search) - ModelComplex string // Model for Pass 2 (description + retry, e.g. mistral-large-latest) - RateLimitRPS float64 // Max requests per second to Mistral (0 = disabled) + Provider string // "ollama" or "mistral"; default "ollama" + RateLimitRPS float64 // Max requests per second to upstream; 0 = disabled (Mistral only) + + OllamaURL string // default "http://localhost:11434" + OllamaModel string // default "qwen2.5:14b-instruct" + + MistralAPIKey string + MistralModel string // default "mistral-large-latest" +} + +type SearchConfig struct { + Provider string // "searxng" (only option today) + SearxngURL string // default "http://localhost:8888" } type AppConfig struct { @@ -272,10 +282,16 @@ func Load() (*Config, error) { FrontendURL: envStr("FRONTEND_URL", "http://localhost:5173"), }, AI: AIConfig{ - APIKey: envStr("AI_API_KEY", ""), - AgentSimple: envStr("AI_AGENT_SIMPLE", ""), - ModelComplex: envStr("AI_MODEL_COMPLEX", "mistral-large-latest"), - RateLimitRPS: rpsAI, + Provider: envStr("AI_PROVIDER", "ollama"), + RateLimitRPS: rpsAI, + OllamaURL: envStr("AI_OLLAMA_URL", "http://localhost:11434"), + OllamaModel: envStr("AI_OLLAMA_MODEL", "qwen2.5:14b-instruct"), + MistralAPIKey: envStr("AI_MISTRAL_API_KEY", envStr("AI_API_KEY", "")), + MistralModel: envStr("AI_MISTRAL_MODEL", envStr("AI_MODEL_COMPLEX", "mistral-large-latest")), + }, + Search: SearchConfig{ + Provider: envStr("SEARCH_PROVIDER", "searxng"), + SearxngURL: envStr("SEARCH_SEARXNG_URL", "http://localhost:8888"), }, Discovery: DiscoveryConfig{ Token: discoveryToken, diff --git a/backend/internal/domain/discovery/enrich/mistral.go b/backend/internal/domain/discovery/enrich/mistral.go index 
ec8cb9c..58c0623 100644 --- a/backend/internal/domain/discovery/enrich/mistral.go +++ b/backend/internal/domain/discovery/enrich/mistral.go @@ -22,32 +22,25 @@ const maxScrapeURLs = 5 // a no-context LLM prompt — grounding is required for useful output. var ErrNoScrapedContent = errors.New("no scrapeable content from any source URL") -// Scraper is the narrow interface MistralLLMEnricher depends on. Satisfied +// Scraper is the narrow interface LLMEnricher depends on. Satisfied // by *pkg/scrape.Client; tests inject a stub. type Scraper interface { Fetch(ctx context.Context, url string) (string, error) } -// aiPass2 is the narrow interface for the AI client's Pass2 chat-completion -// method. Lets tests inject a stub that returns canned JSON without hitting -// the real Mistral API. -type aiPass2 interface { - Pass2(ctx context.Context, systemPrompt, userPrompt string) (ai.PassResult, error) -} - -// MistralLLMEnricher implements LLMEnricher by scraping quellen URLs and -// feeding the concatenated text to Mistral's chat-completion endpoint with +// ProviderLLMEnricher implements LLMEnricher by scraping quellen URLs and +// feeding the concatenated text to an AI provider's chat endpoint with // a JSON response format. -type MistralLLMEnricher struct { - Client aiPass2 +type ProviderLLMEnricher struct { + AI ai.Provider Scraper Scraper } -// NewMistralLLMEnricher constructs an enricher bound to a Mistral ai.Client +// NewLLMEnricher constructs an enricher bound to an ai.Provider // and a scraper. Both are required; call sites should fall back to // NoopLLMEnricher when AI is disabled rather than passing nil here. -func NewMistralLLMEnricher(client aiPass2, scraper Scraper) *MistralLLMEnricher { - return &MistralLLMEnricher{Client: client, Scraper: scraper} +func NewLLMEnricher(provider ai.Provider, scraper Scraper) *ProviderLLMEnricher { + return &ProviderLLMEnricher{AI: provider, Scraper: scraper} } // llmResponse is the JSON shape we instruct Mistral to return. 
Any field may @@ -60,11 +53,11 @@ type llmResponse struct { } // EnrichMissing scrapes up to maxScrapeURLs of req.Quellen, concatenates the -// extracted text, and asks Mistral to fill category / opening_hours / +// extracted text, and asks the AI provider to fill category / opening_hours / // description. Fails with ErrNoScrapedContent if zero URLs return usable // text — empty-context LLM calls hallucinate. -func (e *MistralLLMEnricher) EnrichMissing(ctx context.Context, req LLMRequest) (Enrichment, error) { - if e.Client == nil || e.Scraper == nil { +func (e *ProviderLLMEnricher) EnrichMissing(ctx context.Context, req LLMRequest) (Enrichment, error) { + if e.AI == nil || e.Scraper == nil { return Enrichment{}, errors.New("mistral enricher not configured") } @@ -92,28 +85,26 @@ func (e *MistralLLMEnricher) EnrichMissing(ctx context.Context, req LLMRequest) systemPrompt := buildSystemPrompt() userPrompt := buildUserPrompt(req, blocks) - result, err := e.Client.Pass2(ctx, systemPrompt, userPrompt) + resp, err := e.AI.Chat(ctx, &ai.ChatRequest{ + SystemPrompt: systemPrompt, + UserMessage: userPrompt, + JSONMode: true, + }) if err != nil { - return Enrichment{}, fmt.Errorf("pass2: %w", err) + return Enrichment{}, fmt.Errorf("chat: %w", err) } var parsed llmResponse - if err := json.Unmarshal([]byte(result.Content), &parsed); err != nil { - return Enrichment{}, fmt.Errorf("parse llm response: %w (content=%q)", err, result.Content) + if err := json.Unmarshal([]byte(resp.Content), &parsed); err != nil { + return Enrichment{}, fmt.Errorf("parse llm response: %w (content=%q)", err, resp.Content) } // Build the Enrichment payload with only the fields the model produced. - // Sources entries + token counts + model tag feed the eval harness. 
now := time.Now().UTC() out := Enrichment{ Sources: Sources{}, - Model: result.Model, EnrichedAt: &now, } - if result.Usage != nil { - out.InputTokens = result.Usage.PromptTokens - out.OutputTokens = result.Usage.CompletionTokens - } if s := strings.TrimSpace(parsed.Category); s != "" { out.Category = s out.Sources["category"] = ProvenanceLLM diff --git a/backend/internal/domain/discovery/enrich/mistral_test.go b/backend/internal/domain/discovery/enrich/mistral_test.go index 075167f..7e70318 100644 --- a/backend/internal/domain/discovery/enrich/mistral_test.go +++ b/backend/internal/domain/discovery/enrich/mistral_test.go @@ -5,8 +5,6 @@ import ( "errors" "strings" "testing" - - "marktvogt.de/backend/internal/pkg/ai" ) const catMittelaltermarkt = "mittelaltermarkt" @@ -25,33 +23,15 @@ func (s *stubScraper) Fetch(_ context.Context, url string) (string, error) { return s.responses[url], nil } -// stubPass2 captures the prompts it received and returns a canned JSON body. -type stubPass2 struct { - lastSystem string - lastUser string - result ai.PassResult - err error -} - -func (s *stubPass2) Pass2(_ context.Context, systemPrompt, userPrompt string) (ai.PassResult, error) { - s.lastSystem = systemPrompt - s.lastUser = userPrompt - return s.result, s.err -} - func TestMistralEnrich_HappyPath(t *testing.T) { scraper := &stubScraper{responses: map[string]string{ "https://a.example/markt": "Ein Mittelaltermarkt mit Ritterspielen und Markttreiben.", "https://b.example/info": "Sa-So jeweils 10-18 Uhr.", }} - client := &stubPass2{ - result: ai.PassResult{ - Content: `{"category":"mittelaltermarkt","opening_hours":"Sa-So 10:00-18:00","description":"Ein Markt mit Ritterspielen."}`, - Model: "mistral-large-latest", - Usage: &ai.UsageInfo{PromptTokens: 450, CompletionTokens: 60}, - }, + stub := &stubProvider{ + content: `{"category":"mittelaltermarkt","opening_hours":"Sa-So 10:00-18:00","description":"Ein Markt mit Ritterspielen."}`, } - e := NewMistralLLMEnricher(client, 
scraper) + e := NewLLMEnricher(stub, scraper) req := LLMRequest{ MarktName: "Mittelaltermarkt Dresden", @@ -64,40 +44,37 @@ func TestMistralEnrich_HappyPath(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } + if stub.seen.JSONMode != true { + t.Fatalf("JSONMode must be set") + } - // Result carries LLM fields, provenance llm, model + token counts. + // Result carries LLM fields and provenance llm. if got.Category != catMittelaltermarkt || got.Description == "" || got.OpeningHours == "" { t.Errorf("missing fields in result: %+v", got) } if got.Sources["category"] != ProvenanceLLM { t.Errorf("category provenance: got %q, want llm", got.Sources["category"]) } - if got.Model != "mistral-large-latest" { - t.Errorf("model: got %q, want mistral-large-latest", got.Model) - } - if got.InputTokens != 450 || got.OutputTokens != 60 { - t.Errorf("token counts: in=%d out=%d", got.InputTokens, got.OutputTokens) - } // Prompt inspection — verify grounding blocks include the scraped content // and the already-known fields are listed so LLM doesn't redo them. 
- if !strings.Contains(client.lastUser, "Mittelaltermarkt Dresden") { + if !strings.Contains(stub.seen.UserMessage, "Mittelaltermarkt Dresden") { t.Error("user prompt missing markt name") } - if !strings.Contains(client.lastUser, "Ein Mittelaltermarkt mit Ritterspielen") { + if !strings.Contains(stub.seen.UserMessage, "Ein Mittelaltermarkt mit Ritterspielen") { t.Error("user prompt missing scraped content from URL 1") } - if !strings.Contains(client.lastUser, "Sa-So jeweils 10-18 Uhr") { + if !strings.Contains(stub.seen.UserMessage, "Sa-So jeweils 10-18 Uhr") { t.Error("user prompt missing scraped content from URL 2") } - if !strings.Contains(client.lastUser, "Bereits bekannt") { + if !strings.Contains(stub.seen.UserMessage, "Bereits bekannt") { t.Error("user prompt should announce already-known fields when Partial is populated") } - if !strings.Contains(client.lastUser, "PLZ: 01067") { + if !strings.Contains(stub.seen.UserMessage, "PLZ: 01067") { t.Error("user prompt missing already-known PLZ") } // System prompt asks for JSON only. 
- if !strings.Contains(client.lastSystem, "JSON") { + if !strings.Contains(stub.seen.SystemPrompt, "JSON") { t.Error("system prompt should mention JSON") } } @@ -107,8 +84,8 @@ func TestMistralEnrich_AllScrapesFail(t *testing.T) { "https://a.example": errors.New("timeout"), "https://b.example": errors.New("404"), }} - client := &stubPass2{} // must not be called - e := NewMistralLLMEnricher(client, scraper) + stub := &stubProvider{} // must not be called + e := NewLLMEnricher(stub, scraper) req := LLMRequest{ Quellen: []string{"https://a.example", "https://b.example"}, @@ -117,7 +94,7 @@ func TestMistralEnrich_AllScrapesFail(t *testing.T) { if !errors.Is(err, ErrNoScrapedContent) { t.Errorf("err = %v; want ErrNoScrapedContent", err) } - if client.lastUser != "" { + if stub.seen != nil { t.Error("LLM must not be called when zero URLs scrape") } } @@ -129,10 +106,10 @@ func TestMistralEnrich_SomeScrapesFailStillCallsLLM(t *testing.T) { responses: map[string]string{"https://ok.example": "Useful content."}, errs: map[string]error{"https://bad.example": errors.New("timeout")}, } - client := &stubPass2{ - result: ai.PassResult{Content: `{"category":"mittelaltermarkt","opening_hours":"","description":""}`}, + stub := &stubProvider{ + content: `{"category":"mittelaltermarkt","opening_hours":"","description":""}`, } - e := NewMistralLLMEnricher(client, scraper) + e := NewLLMEnricher(stub, scraper) req := LLMRequest{Quellen: []string{"https://bad.example", "https://ok.example"}} got, err := e.EnrichMissing(context.Background(), req) @@ -142,7 +119,7 @@ func TestMistralEnrich_SomeScrapesFailStillCallsLLM(t *testing.T) { if got.Category != catMittelaltermarkt { t.Errorf("category: got %q", got.Category) } - if !strings.Contains(client.lastUser, "Useful content") { + if !strings.Contains(stub.seen.UserMessage, "Useful content") { t.Error("user prompt missing successful scrape") } } @@ -151,10 +128,10 @@ func TestMistralEnrich_EmptyFieldsNoProvenance(t *testing.T) { // LLM 
returns empty strings for fields it can't support. Those fields // must NOT appear in Sources — an empty provenance is misleading. scraper := &stubScraper{responses: map[string]string{"https://a.example": "Content."}} - client := &stubPass2{ - result: ai.PassResult{Content: `{"category":"mittelaltermarkt","opening_hours":"","description":""}`}, + stub := &stubProvider{ + content: `{"category":"mittelaltermarkt","opening_hours":"","description":""}`, } - e := NewMistralLLMEnricher(client, scraper) + e := NewLLMEnricher(stub, scraper) got, err := e.EnrichMissing(context.Background(), LLMRequest{Quellen: []string{"https://a.example"}}) if err != nil { @@ -183,8 +160,8 @@ func TestMistralEnrich_CapsURLsAtFive(t *testing.T) { fetched[u] = true return responses[u], nil }} - client := &stubPass2{result: ai.PassResult{Content: `{"category":"x","opening_hours":"","description":""}`}} - e := NewMistralLLMEnricher(client, scraper) + stub2 := &stubProvider{content: `{"category":"x","opening_hours":"","description":""}`} + e := NewLLMEnricher(stub2, scraper) _, _ = e.EnrichMissing(context.Background(), LLMRequest{Quellen: urls}) if len(fetched) != 5 { diff --git a/backend/internal/domain/discovery/enrich/similarity.go b/backend/internal/domain/discovery/enrich/similarity.go index cd58e1d..8f37e2c 100644 --- a/backend/internal/domain/discovery/enrich/similarity.go +++ b/backend/internal/domain/discovery/enrich/similarity.go @@ -9,6 +9,8 @@ import ( "fmt" "strings" "time" + + "marktvogt.de/backend/internal/pkg/ai" ) // SimilarityRow carries the minimal identifying fields the classifier reads. @@ -80,16 +82,16 @@ func SimilarityPairKey(a, b SimilarityRow) string { // propagate). const DefaultSimilarityCacheTTL = 30 * 24 * time.Hour -// MistralSimilarityClassifier implements SimilarityClassifier by sending a -// JSON-formatted comparison prompt to Mistral's chat endpoint. 
-type MistralSimilarityClassifier struct { - Client aiPass2 +// SimilarityClassifierLLM implements SimilarityClassifier by sending a +// JSON-formatted comparison prompt to an AI provider's chat endpoint. +type SimilarityClassifierLLM struct { + AI ai.Provider } -// NewMistralSimilarityClassifier binds a Mistral ai.Client. client must be +// NewSimilarityClassifier binds an ai.Provider. provider must be // non-nil; routes.go falls back to NoopSimilarityClassifier when AI is off. -func NewMistralSimilarityClassifier(client aiPass2) *MistralSimilarityClassifier { - return &MistralSimilarityClassifier{Client: client} +func NewSimilarityClassifier(provider ai.Provider) *SimilarityClassifierLLM { + return &SimilarityClassifierLLM{AI: provider} } // simResponse is the JSON shape we instruct Mistral to return. Confidence @@ -100,25 +102,29 @@ type simResponse struct { Reason string `json:"reason"` } -// Classify sends the paired metadata to Mistral and parses the JSON response. +// Classify sends the paired metadata to the AI provider and parses the JSON response. // No web scraping — the classifier works from name/city/year alone, which is // enough for the common cases (same venue listed on two different calendars, // editing typos, cross-year recurrence). 
-func (m *MistralSimilarityClassifier) Classify(ctx context.Context, a, b SimilarityRow) (Verdict, error) { - if m.Client == nil { - return Verdict{}, errors.New("mistral similarity classifier not configured") +func (c *SimilarityClassifierLLM) Classify(ctx context.Context, a, b SimilarityRow) (Verdict, error) { + if c.AI == nil { + return Verdict{}, errors.New("similarity classifier not configured") } systemPrompt := simSystemPrompt() userPrompt := simUserPrompt(a, b) - result, err := m.Client.Pass2(ctx, systemPrompt, userPrompt) + resp, err := c.AI.Chat(ctx, &ai.ChatRequest{ + SystemPrompt: systemPrompt, + UserMessage: userPrompt, + JSONMode: true, + }) if err != nil { - return Verdict{}, fmt.Errorf("pass2: %w", err) + return Verdict{}, fmt.Errorf("chat: %w", err) } var parsed simResponse - if err := json.Unmarshal([]byte(result.Content), &parsed); err != nil { - return Verdict{}, fmt.Errorf("parse response: %w (content=%q)", err, result.Content) + if err := json.Unmarshal([]byte(resp.Content), &parsed); err != nil { + return Verdict{}, fmt.Errorf("parse response: %w (content=%q)", err, resp.Content) } // Clamp confidence to [0,1]; the model occasionally returns 1.2 or -0.1. 
@@ -134,7 +140,6 @@ func (m *MistralSimilarityClassifier) Classify(ctx context.Context, a, b Similar Same: parsed.SameMarket, Confidence: conf, Reason: strings.TrimSpace(parsed.Reason), - Model: result.Model, ClassifiedAt: time.Now().UTC(), }, nil } diff --git a/backend/internal/domain/discovery/enrich/similarity_test.go b/backend/internal/domain/discovery/enrich/similarity_test.go index ff374be..51550d9 100644 --- a/backend/internal/domain/discovery/enrich/similarity_test.go +++ b/backend/internal/domain/discovery/enrich/similarity_test.go @@ -5,8 +5,6 @@ import ( "errors" "strings" "testing" - - "marktvogt.de/backend/internal/pkg/ai" ) func TestSimilarityPairKey_Symmetric(t *testing.T) { @@ -34,13 +32,10 @@ func TestSimilarityPairKey_DifferentInputsDifferentKeys(t *testing.T) { } func TestMistralSimilarity_HappyPath(t *testing.T) { - client := &stubPass2{ - result: ai.PassResult{ - Content: `{"same_market":true,"confidence":0.82,"reason":"Gleicher Name, gleiche Stadt, gleiches Jahr."}`, - Model: "mistral-large-latest", - }, + stub := &stubProvider{ + content: `{"same_market":true,"confidence":0.82,"reason":"Gleicher Name, gleiche Stadt, gleiches Jahr."}`, } - c := NewMistralSimilarityClassifier(client) + c := NewSimilarityClassifier(stub) got, err := c.Classify(context.Background(), SimilarityRow{Name: "Mittelaltermarkt Dresden", Stadt: "Dresden", Year: 2026}, @@ -49,6 +44,9 @@ func TestMistralSimilarity_HappyPath(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } + if stub.seen.JSONMode != true { + t.Fatalf("JSONMode must be set") + } if !got.Same { t.Errorf("same = false; want true") } @@ -58,14 +56,11 @@ func TestMistralSimilarity_HappyPath(t *testing.T) { if got.Reason == "" { t.Error("reason missing") } - if got.Model != "mistral-large-latest" { - t.Errorf("model = %q", got.Model) - } // Prompt must carry both rows' identifying fields for the LLM to reason on. 
- if !strings.Contains(client.lastUser, "Mittelaltermarkt Dresden") { + if !strings.Contains(stub.seen.UserMessage, "Mittelaltermarkt Dresden") { t.Error("user prompt missing A.name") } - if !strings.Contains(client.lastSystem, "same_market") { + if !strings.Contains(stub.seen.SystemPrompt, "same_market") { t.Error("system prompt should describe the JSON schema (same_market key)") } } @@ -82,7 +77,7 @@ func TestMistralSimilarity_ClampsConfidence(t *testing.T) { } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - c := NewMistralSimilarityClassifier(&stubPass2{result: ai.PassResult{Content: tc.raw}}) + c := NewSimilarityClassifier(&stubProvider{content: tc.raw}) got, err := c.Classify(context.Background(), SimilarityRow{}, SimilarityRow{}) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -95,7 +90,7 @@ func TestMistralSimilarity_ClampsConfidence(t *testing.T) { } func TestMistralSimilarity_PropagatesPass2Error(t *testing.T) { - c := NewMistralSimilarityClassifier(&stubPass2{err: errors.New("mistral down")}) + c := NewSimilarityClassifier(&stubProvider{err: errors.New("mistral down")}) _, err := c.Classify(context.Background(), SimilarityRow{}, SimilarityRow{}) if err == nil { t.Fatal("expected error; got nil") @@ -103,7 +98,7 @@ func TestMistralSimilarity_PropagatesPass2Error(t *testing.T) { } func TestMistralSimilarity_RejectsBadJSON(t *testing.T) { - c := NewMistralSimilarityClassifier(&stubPass2{result: ai.PassResult{Content: "not json at all"}}) + c := NewSimilarityClassifier(&stubProvider{content: "not json at all"}) _, err := c.Classify(context.Background(), SimilarityRow{}, SimilarityRow{}) if err == nil { t.Fatal("expected parse error; got nil") diff --git a/backend/internal/domain/discovery/enrich/stubprovider_test.go b/backend/internal/domain/discovery/enrich/stubprovider_test.go new file mode 100644 index 0000000..f6af78f --- /dev/null +++ b/backend/internal/domain/discovery/enrich/stubprovider_test.go @@ -0,0 +1,24 @@ +package 
enrich + +import ( + "context" + + "marktvogt.de/backend/internal/pkg/ai" +) + +type stubProvider struct { + content string + err error + seen *ai.ChatRequest +} + +func (s *stubProvider) Name() string { return "stub" } +func (s *stubProvider) SupportsJSONMode() bool { return true } +func (s *stubProvider) SupportsJSONSchema() bool { return false } +func (s *stubProvider) Chat(ctx context.Context, req *ai.ChatRequest) (*ai.ChatResponse, error) { + s.seen = req + if s.err != nil { + return nil, s.err + } + return &ai.ChatResponse{Content: s.content}, nil +} diff --git a/backend/internal/domain/market/research.go b/backend/internal/domain/market/research.go index ddfc472..42c0de3 100644 --- a/backend/internal/domain/market/research.go +++ b/backend/internal/domain/market/research.go @@ -1,711 +1,19 @@ package market import ( - "encoding/json" - "fmt" - "log/slog" "net/http" - "strings" - "sync" - "time" "github.com/gin-gonic/gin" - "github.com/google/uuid" - - "marktvogt.de/backend/internal/pkg/ai" - "marktvogt.de/backend/internal/pkg/apierror" ) type ResearchHandler struct { - service *Service - aiClient *ai.Client - mu sync.Mutex - cooldown map[uuid.UUID]time.Time + service *Service } -func NewResearchHandler(service *Service, aiClient *ai.Client) *ResearchHandler { - return &ResearchHandler{ - service: service, - aiClient: aiClient, - cooldown: make(map[uuid.UUID]time.Time), - } +func NewResearchHandler(service *Service, _ any) *ResearchHandler { + return &ResearchHandler{service: service} } func (h *ResearchHandler) Research(c *gin.Context) { - if !h.aiClient.Enabled() { - apiErr := apierror.BadRequest("ai_disabled", "AI research is not configured") - c.JSON(apiErr.Status, apierror.NewResponse(apiErr)) - return - } - - id, err := uuid.Parse(c.Param("id")) - if err != nil { - apiErr := apierror.BadRequest("invalid_id", "invalid market ID") - c.JSON(apiErr.Status, apierror.NewResponse(apiErr)) - return - } - - // Rate limit: 1 per market per 5 minutes - h.mu.Lock() 
- if last, ok := h.cooldown[id]; ok && time.Since(last) < 5*time.Minute { - h.mu.Unlock() - apiErr := apierror.BadRequest("rate_limited", "Bitte warte 5 Minuten zwischen Recherche-Aufrufen") - c.JSON(apiErr.Status, apierror.NewResponse(apiErr)) - return - } - h.cooldown[id] = time.Now() - h.mu.Unlock() - - m, err := h.service.GetByID(c.Request.Context(), id) - if err != nil { - apiErr := apierror.NotFound("market") - c.JSON(apiErr.Status, apierror.NewResponse(apiErr)) - return - } - - ctx := c.Request.Context() - - // --- Pass 1: Structured extraction via agent with web search --- - pass1Prompt := buildPass1Prompt(m) - pass1Result, err := h.aiClient.Pass1(ctx, pass1Prompt) - if err != nil { - if ai.IsRateLimit(err) { - h.respondRateLimited(c, "pass1", id, err) - return - } - slog.WarnContext(ctx, "pass1 failed, retrying", "market_id", id, "error", err) - pass1Result, err = h.aiClient.Pass1(ctx, pass1Prompt) - if err != nil { - if ai.IsRateLimit(err) { - h.respondRateLimited(c, "pass1", id, err) - return - } - slog.ErrorContext(ctx, "pass1 retry failed", "market_id", id, "error", err) - apiErr := apierror.Internal("AI research failed") - c.JSON(apiErr.Status, apierror.NewResponse(apiErr)) - return - } - } - - pass1Data := parsePass1Response(pass1Result.Content) - - // Compute confidence for pass1 fields and identify retry fields - var retryFields []string - for fieldName, field := range pass1Data.Fields { - if field.Value == nil { - retryFields = append(retryFields, fieldName) - } - } - - // Collect all sources from pass1 - allSources := pass1Data.Sources - - // --- Pass 2: Description + retry fields via chat completions --- - // Pass 2 is enrichment only — if it fails we still return the pass1 result. - // Rate-limit on pass2 is treated as a warning and the partial result ships. 
- pass2Prompt := buildPass2UserPrompt(m, pass1Data, retryFields, allSources) - pass2Result, err := h.aiClient.Pass2(ctx, pass2SystemPrompt, pass2Prompt) - var pass2Data pass2Response - if err != nil { - if ai.IsRateLimit(err) { - slog.WarnContext(ctx, "pass2 rate-limited; returning pass1 results only", - "market_id", id) - } else { - slog.WarnContext(ctx, "pass2 failed, retrying", "market_id", id, "error", err) - pass2Result, err = h.aiClient.Pass2(ctx, pass2SystemPrompt, pass2Prompt) - if err != nil { - if ai.IsRateLimit(err) { - slog.WarnContext(ctx, "pass2 rate-limited on retry; returning pass1 results only", - "market_id", id) - } else { - slog.WarnContext(ctx, "pass2 retry failed, using pass1 results only", - "market_id", id, "error", err) - } - } - } - } - if err == nil { - pass2Data = parsePass2Response(pass2Result.Content) - } - - // --- Merge results and compute final confidence --- - result := mergeResults(m, pass1Data, pass2Data) - - // Log cost tracking - logCostTracking(id, pass1Result, pass2Result) - - c.JSON(http.StatusOK, ResearchResponse{Data: result}) -} - -// respondRateLimited writes a 429 with a Retry-After header so the admin UI -// can display a clear "try again in N seconds" message instead of a generic -// 500. pass is a short identifier for the log line ("pass1" / "pass2"). 
-func (h *ResearchHandler) respondRateLimited(c *gin.Context, pass string, marketID uuid.UUID, err error) { - retry := ai.DefaultRetryAfterSeconds - slog.WarnContext(c.Request.Context(), "ai rate-limited", - "pass", pass, "market_id", marketID, "retry_after_s", retry, "error", err) - c.Header("Retry-After", fmt.Sprint(retry)) - apiErr := apierror.TooManyRequests( - fmt.Sprintf("AI research temporarily rate-limited; try again in ~%ds", retry), - ) - c.JSON(apiErr.Status, apierror.NewResponse(apiErr)) -} - -// --- Pass 1 types and prompt --- - -type pass1FieldResult struct { - Value any `json:"wert"` - Sources []string `json:"quellen"` - Extraction string `json:"extraktion"` // "verbatim" or "abgeleitet" -} - -type pass1Response struct { - Fields map[string]pass1FieldResult `json:"-"` - Sources []string `json:"quellen_gesamt"` - raw map[string]json.RawMessage -} - -func buildPass1Prompt(m Market) string { - dateRange := "" - if !m.StartDate.IsZero() && !m.EndDate.IsZero() { - dateRange = fmt.Sprintf("Zeitraum: %s bis %s\n", m.StartDate.Format("02.01.2006"), m.EndDate.Format("02.01.2006")) - } - - website := "" - if m.Website != "" { - website = fmt.Sprintf("Website: %s\n", m.Website) - } - - return fmt.Sprintf(`Recherchiere den folgenden Mittelaltermarkt und finde aktuelle Informationen. - -Name: %s -Stadt: %s -%s%s -Suche nach folgenden Feldern: -- website (offizielle URL) -- strasse (Straße/Adresse des Veranstaltungsorts) -- plz (Postleitzahl) -- stadt (Stadt, falls korrekter als angegeben) -- veranstalter (Name des Veranstalters) -- oeffnungszeiten (Öffnungszeiten als Array von {tag, von, bis} mit deutschen Wochentagen) -- eintrittspreise (Preise als {erwachsene_cent, kinder_cent, ermaessigt_cent, frei_unter_alter, hinweise}) - -Antworte AUSSCHLIESSLICH mit validem JSON (keine Markdown-Codeblöcke). 
- -{ - "website": { - "wert": "https://...", - "quellen": ["https://..."], - "extraktion": "verbatim" - }, - "strasse": { - "wert": "Beispielstraße 5", - "quellen": ["https://..."], - "extraktion": "verbatim" - }, - "plz": { - "wert": "12345", - "quellen": ["https://..."], - "extraktion": "verbatim" - }, - "stadt": { - "wert": "Musterstadt", - "quellen": ["https://..."], - "extraktion": "verbatim" - }, - "veranstalter": { - "wert": "Verein XY e.V.", - "quellen": ["https://..."], - "extraktion": "abgeleitet" - }, - "oeffnungszeiten": { - "wert": [{"tag": "Samstag", "von": "10:00", "bis": "22:00"}], - "quellen": ["https://..."], - "extraktion": "verbatim" - }, - "eintrittspreise": { - "wert": {"erwachsene_cent": 800, "kinder_cent": 0, "ermaessigt_cent": 500, "frei_unter_alter": 12, "hinweise": "Kinder unter 12 frei"}, - "quellen": ["https://..."], - "extraktion": "verbatim" - }, - "quellen_gesamt": ["https://...", "https://..."] -} - -Regeln: -- "extraktion": "verbatim" wenn der Wert wörtlich in einer strukturierten Quelle steht (Tabelle, Adressblock, Schema) -- "extraktion": "abgeleitet" wenn der Wert aus Fließtext interpretiert wurde -- Setze "wert" auf null wenn nicht findbar. Erfinde KEINE Werte. 
-- "quellen" muss die URL(s) enthalten, aus denen der Wert stammt -- Verwende für Wochentage NUR: Montag, Dienstag, Mittwoch, Donnerstag, Freitag, Samstag, Sonntag -- Preise in Cent (800 = 8,00 EUR) -`, m.Name, m.City, dateRange, website) -} - -func parsePass1Response(raw string) pass1Response { - cleaned := extractJSON(raw) - cleaned = stripJSONComments(cleaned) - - var result pass1Response - result.Fields = make(map[string]pass1FieldResult) - - var rawMap map[string]json.RawMessage - if err := json.Unmarshal([]byte(cleaned), &rawMap); err != nil { - slog.Warn("failed to parse pass1 response", "error", err, "cleaned", cleaned) - return result - } - result.raw = rawMap - - // Parse quellen_gesamt - if sourcesRaw, ok := rawMap["quellen_gesamt"]; ok { - _ = json.Unmarshal(sourcesRaw, &result.Sources) - } - - // Parse known field names - knownFields := []string{"website", "strasse", "plz", "stadt", "veranstalter", "oeffnungszeiten", "eintrittspreise"} - for _, name := range knownFields { - if fieldRaw, ok := rawMap[name]; ok { - var field pass1FieldResult - if err := json.Unmarshal(fieldRaw, &field); err == nil { - result.Fields[name] = field - } - } - } - - return result -} - -// --- Pass 2 types and prompt --- - -const pass2SystemPrompt = `Du bist ein Redakteur für Marktvogt, eine Plattform für Mittelaltermärkte und historische Veranstaltungen im DACH-Raum. - -Du bekommst: -1. Rechercheergebnisse aus einer vorherigen Websuche (bereits gesammelte Quelleninhalte und URLs) -2. Optional: Eine Liste von Feldern ("retry_felder"), die in der ersten Runde leer oder unsicher waren - -Du hast KEINEN Internetzugang. Arbeite ausschließlich mit den bereitgestellten Informationen. - -## Aufgabe 1: Beschreibung verfassen - -Verfasse eine ansprechende, informative Marktbeschreibung auf Deutsch. - -### Anforderungen: -- 3-5 Sätze, ca. 
80-150 Wörter -- Sachlich-einladender Ton — kein Werbe-Deutsch, keine Superlative ("das größte", "einzigartig", "unvergesslich") -- Inhalt: Was macht diesen Markt besonders? Was erwartet Besucher? Marktstände, Programm, Musik, Handwerk, Ritterspiele, Gaukelei, historisches Lagerleben etc. -- Atmosphäre und Veranstaltungsort erwähnen wenn aus den Quellen ersichtlich -- Zielgruppen ansprechen: Familien, Mittelalter-Fans, Reenactment-Begeisterte -- NICHT erwähnen: Eintrittspreise, exakte Öffnungszeiten, Datumsangaben (stehen in den strukturierten Feldern) -- Alle Fakten in der Beschreibung müssen aus den mitgelieferten Quelleninhalten stammen - -## Aufgabe 2: Felder nachholen (nur wenn retry_felder vorhanden) - -Für jedes Feld in retry_felder: Versuche den Wert aus den bereits mitgelieferten Quelleninhalten zu extrahieren. - -## Antwortformat - -Antworte AUSSCHLIESSLICH mit validem JSON. - -{ - "beschreibung": { - "wert": "Die vollständige Beschreibung hier...", - "quellen": ["https://...", "https://..."], - "basiert_auf": "Kurze Erklärung" - }, - "retry_ergebnisse": { - "feldname": { - "wert": "...", - "quellen": ["https://..."], - "extraktion": "verbatim", - "hinweis": "optional" - } - } -} - -Wenn keine retry_felder übergeben wurden, gib "retry_ergebnisse": {} zurück. -Erfinde KEINE Fakten. 
Wenn die Quellen nichts hergeben, setze "wert" auf null.` - -type pass2DescriptionResult struct { - Value string `json:"wert"` - Sources []string `json:"quellen"` - BasedOn string `json:"basiert_auf"` -} - -type pass2RetryResult struct { - Value any `json:"wert"` - Sources []string `json:"quellen"` - Extraction string `json:"extraktion"` -} - -type pass2Response struct { - Description pass2DescriptionResult `json:"beschreibung"` - RetryResults map[string]pass2RetryResult `json:"retry_ergebnisse"` -} - -func buildPass2UserPrompt(m Market, pass1 pass1Response, retryFields []string, allSources []string) string { - // Build context from pass1 results - var sb strings.Builder - sb.WriteString(fmt.Sprintf("Markt: %s\nStadt: %s\n", m.Name, m.City)) - - if !m.StartDate.IsZero() { - sb.WriteString(fmt.Sprintf("Zeitraum: %s bis %s\n", m.StartDate.Format("02.01.2006"), m.EndDate.Format("02.01.2006"))) - } - - sb.WriteString("\n## Rechercheergebnisse aus Pass 1\n\n") - for name, field := range pass1.Fields { - if field.Value != nil { - val, _ := json.Marshal(field.Value) - sb.WriteString(fmt.Sprintf("%s: %s (Quellen: %s)\n", name, string(val), strings.Join(field.Sources, ", "))) - } - } - - sb.WriteString("\n## Quellen-URLs\n") - for _, src := range allSources { - sb.WriteString(fmt.Sprintf("- %s\n", src)) - } - - if len(retryFields) > 0 { - sb.WriteString(fmt.Sprintf("\n## retry_felder\n%s\n", strings.Join(retryFields, ", "))) - sb.WriteString("Versuche diese Felder aus den Quelleninhalten zu extrahieren.\n") - } - - return sb.String() -} - -func parsePass2Response(raw string) pass2Response { - cleaned := extractJSON(raw) - cleaned = stripJSONComments(cleaned) - - var result pass2Response - result.RetryResults = make(map[string]pass2RetryResult) - - if err := json.Unmarshal([]byte(cleaned), &result); err != nil { - slog.Warn("failed to parse pass2 response", "error", err, "cleaned", cleaned) - } - return result -} - -// --- Confidence scoring --- - -const extractionDerived = 
"abgeleitet" - -func computeConfidence(sources []string, extraction string) string { - count := len(sources) - if count >= 2 && extraction == "verbatim" { - return "high" - } - if count >= 2 && extraction == extractionDerived { - return "medium" - } - if count == 1 && extraction == "verbatim" { - return "medium" - } - return "low" -} - -// --- Merge results --- - -// fieldMapping maps pass1 field names to the FieldSuggestion field names used by the frontend. -var fieldMapping = map[string]string{ - "website": "website", - "strasse": "street", - "plz": "zip", - "stadt": "city", - "veranstalter": "organizer_name", - "oeffnungszeiten": "opening_hours", - "eintrittspreise": "admission_info", -} - -func mergeResults(m Market, pass1 pass1Response, pass2 pass2Response) ResearchResult { - suggestions := make([]FieldSuggestion, 0, len(pass1.Fields)+len(pass2.RetryResults)+1) - allSources := make([]string, 0, len(pass1.Sources)) - - // Collect unique sources - sourceSet := make(map[string]bool) - for _, s := range pass1.Sources { - sourceSet[s] = true - } - - // Process pass1 fields - for p1Name, field := range pass1.Fields { - frontendName, ok := fieldMapping[p1Name] - if !ok { - continue - } - if field.Value == nil { - continue - } - - // Convert opening hours format - suggestedValue := field.Value - if p1Name == "oeffnungszeiten" { - suggestedValue = convertOpeningHours(field.Value) - } - if p1Name == "eintrittspreise" { - suggestedValue = convertAdmission(field.Value) - } - - suggestions = append(suggestions, FieldSuggestion{ - Field: frontendName, - CurrentValue: getCurrentValue(m, frontendName), - SuggestedValue: suggestedValue, - Confidence: computeConfidence(field.Sources, field.Extraction), - Reason: formatReason(field.Sources, field.Extraction), - }) - - for _, s := range field.Sources { - sourceSet[s] = true - } - } - - // Process pass2 retry results - for p1Name, retryField := range pass2.RetryResults { - if retryField.Value == nil { - continue - } - 
frontendName, ok := fieldMapping[p1Name] - if !ok { - continue - } - - suggestedValue := retryField.Value - if p1Name == "oeffnungszeiten" { - suggestedValue = convertOpeningHours(retryField.Value) - } - if p1Name == "eintrittspreise" { - suggestedValue = convertAdmission(retryField.Value) - } - - suggestions = append(suggestions, FieldSuggestion{ - Field: frontendName, - CurrentValue: getCurrentValue(m, frontendName), - SuggestedValue: suggestedValue, - Confidence: computeConfidence(retryField.Sources, retryField.Extraction), - Reason: formatReason(retryField.Sources, retryField.Extraction) + " (retry)", - }) - - for _, s := range retryField.Sources { - sourceSet[s] = true - } - } - - // Add description from pass2 - if pass2.Description.Value != "" { - suggestions = append(suggestions, FieldSuggestion{ - Field: "description", - CurrentValue: m.Description, - SuggestedValue: pass2.Description.Value, - Confidence: computeConfidence(pass2.Description.Sources, extractionDerived), - Reason: pass2.Description.BasedOn, - }) - - for _, s := range pass2.Description.Sources { - sourceSet[s] = true - } - } - - for s := range sourceSet { - allSources = append(allSources, s) - } - - return ResearchResult{ - Suggestions: suggestions, - Sources: allSources, - } -} - -func getCurrentValue(m Market, field string) any { - switch field { - case "website": - return m.Website - case "street": - return m.Street - case "zip": - return m.Zip - case "city": - return m.City - case "organizer_name": - return m.OrganizerName - case "opening_hours": - return m.OpeningHours - case "admission_info": - return m.AdmissionInfo - case "description": - return m.Description - default: - return nil - } -} - -func formatReason(sources []string, extraction string) string { - srcCount := len(sources) - verb := "gefunden" - if extraction == extractionDerived { - verb = extractionDerived - } - if srcCount == 0 { - return verb - } - if srcCount == 1 { - return fmt.Sprintf("Aus 1 Quelle %s", verb) - } - 
return fmt.Sprintf("Aus %d Quellen %s", srcCount, verb) -} - -// convertOpeningHours converts pass1 format {tag, von, bis} to frontend format {day, open, close} -func convertOpeningHours(val any) any { - data, err := json.Marshal(val) - if err != nil { - return val - } - - var entries []map[string]string - if err := json.Unmarshal(data, &entries); err != nil { - return val - } - - result := make([]map[string]string, 0, len(entries)) - for _, e := range entries { - result = append(result, map[string]string{ - "day": e["tag"], - "open": e["von"], - "close": e["bis"], - }) - } - return result -} - -// convertAdmission converts pass1 format to frontend format (cents) -func convertAdmission(val any) any { - data, err := json.Marshal(val) - if err != nil { - return val - } - - var raw map[string]any - if err := json.Unmarshal(data, &raw); err != nil { - return val - } - - result := map[string]any{ - "adult_cents": toInt(raw["erwachsene_cent"]), - "child_cents": toInt(raw["kinder_cent"]), - "reduced_cents": toInt(raw["ermaessigt_cent"]), - "free_under_age": toInt(raw["frei_unter_alter"]), - "notes": toString(raw["hinweise"]), - } - return result -} - -func toInt(v any) int { - switch n := v.(type) { - case float64: - return int(n) - case int: - return n - default: - return 0 - } -} - -func toString(v any) string { - if s, ok := v.(string); ok { - return s - } - return "" -} - -// --- Cost logging --- - -func logCostTracking(marketID uuid.UUID, pass1, pass2 ai.PassResult) { - log := map[string]any{ - "market_id": marketID, - "timestamp": time.Now().Format(time.RFC3339), - } - - if pass1.Usage != nil { - log["pass_1"] = map[string]any{ - "model": pass1.Model, - "input_tokens": pass1.Usage.PromptTokens, - "output_tokens": pass1.Usage.CompletionTokens, - } - } - if pass2.Usage != nil { - log["pass_2"] = map[string]any{ - "model": pass2.Model, - "input_tokens": pass2.Usage.PromptTokens, - "output_tokens": pass2.Usage.CompletionTokens, - } - } - - data, _ := json.Marshal(log) - 
slog.Info("ai_research_cost", "data", string(data)) -} - -// --- JSON helpers --- - -func extractJSON(s string) string { - if idx := findJSONStart(s); idx >= 0 { - s = s[idx:] - if end := findJSONEnd(s); end > 0 { - s = s[:end+1] - } - } - return s -} - -func stripJSONComments(s string) string { - var result []byte - inString := false - escaped := false - for i := 0; i < len(s); i++ { - c := s[i] - if escaped { - result = append(result, c) - escaped = false - continue - } - if c == '\\' && inString { - result = append(result, c) - escaped = true - continue - } - if c == '"' { - inString = !inString - result = append(result, c) - continue - } - if !inString && c == '/' && i+1 < len(s) && s[i+1] == '/' { - for i < len(s) && s[i] != '\n' { - i++ - } - continue - } - result = append(result, c) - } - return string(result) -} - -func findJSONStart(s string) int { - for i, c := range s { - if c == '{' { - return i - } - } - return -1 -} - -func findJSONEnd(s string) int { - depth := 0 - for i, c := range s { - switch c { - case '{': - depth++ - case '}': - depth-- - if depth == 0 { - return i - } - } - } - return -1 + c.JSON(http.StatusNotImplemented, gin.H{"error": "research temporarily disabled during AI provider migration"}) } diff --git a/backend/internal/pkg/ai/client.go b/backend/internal/pkg/ai/client.go deleted file mode 100644 index 4d85d5d..0000000 --- a/backend/internal/pkg/ai/client.go +++ /dev/null @@ -1,194 +0,0 @@ -package ai - -import ( - "context" - "fmt" - "sync" - "time" - - "github.com/VikingOwl91/mistral-go-sdk" - "github.com/VikingOwl91/mistral-go-sdk/chat" - "github.com/VikingOwl91/mistral-go-sdk/conversation" -) - -// rateLimiter enforces a minimum interval between calls. Set rps<=0 to disable. -// Mirrors backend/internal/pkg/geocode/nominatim.go's lastReq gate. 
-type rateLimiter struct { - mu sync.Mutex - lastReq time.Time - minInterval time.Duration -} - -func newRateLimiter(rps float64) *rateLimiter { - if rps <= 0 { - return &rateLimiter{minInterval: 0} - } - return &rateLimiter{minInterval: time.Duration(float64(time.Second) / rps)} -} - -// TODO: wait() does not honor context cancellation — a cancelled caller will -// block up to minInterval while holding a mutex, and queued callers block up -// to N*minInterval. Mirror pkg/geocode/nominatim.go which has the same gap. -// Fix both sites together by taking ctx and using time.NewTimer + select. -func (rl *rateLimiter) wait() { - if rl.minInterval == 0 { - return - } - rl.mu.Lock() - defer rl.mu.Unlock() - since := time.Since(rl.lastReq) - if since < rl.minInterval { - time.Sleep(rl.minInterval - since) - } - rl.lastReq = time.Now() -} - -type Client struct { - sdk *mistral.Client - agentSimple string - modelComplex string - limiter *rateLimiter -} - -func New(apiKey, agentSimple, modelComplex string, rps float64) *Client { - if modelComplex == "" { - modelComplex = "mistral-large-latest" - } - - var sdk *mistral.Client - if apiKey != "" { - sdk = mistral.NewClient(apiKey, - mistral.WithTimeout(120*time.Second), - mistral.WithRetry(2, 1*time.Second), - ) - } - - return &Client{ - sdk: sdk, - agentSimple: agentSimple, - modelComplex: modelComplex, - limiter: newRateLimiter(rps), - } -} - -func (c *Client) Enabled() bool { - return c.sdk != nil && c.agentSimple != "" -} - -type UsageInfo struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type PassResult struct { - Content string - Usage *UsageInfo - Model string -} - -// Pass1 uses the Conversations API to call the pre-created agent (with web search). 
-func (c *Client) Pass1(ctx context.Context, prompt string) (PassResult, error) { - c.limiter.wait() - storeFalse := false - resp, err := c.sdk.StartConversation(ctx, &conversation.StartRequest{ - AgentID: c.agentSimple, - Inputs: conversation.TextInputs(prompt), - Store: &storeFalse, - }) - if err != nil { - return PassResult{}, fmt.Errorf("pass1 conversation: %w", err) - } - - content := extractConversationContent(resp) - if content == "" { - return PassResult{}, fmt.Errorf("pass1: no assistant message in response") - } - - return PassResult{ - Content: content, - Usage: convertConvUsage(resp.Usage), - Model: "agent:" + c.agentSimple, - }, nil -} - -// Pass0 uses the Conversations API to call a discovery agent identified by agentID. -// The agent ID is passed explicitly so the discovery domain can configure its own -// agent independently of the agentSimple field used by Pass1. -func (c *Client) Pass0(ctx context.Context, agentID, prompt string) (PassResult, error) { - c.limiter.wait() - if c.sdk == nil || agentID == "" { - return PassResult{}, fmt.Errorf("pass0: ai client not configured (sdk=%v agentID=%q)", c.sdk != nil, agentID) - } - storeFalse := false - resp, err := c.sdk.StartConversation(ctx, &conversation.StartRequest{ - AgentID: agentID, - Inputs: conversation.TextInputs(prompt), - Store: &storeFalse, - }) - if err != nil { - return PassResult{}, fmt.Errorf("pass0 conversation: %w", err) - } - - content := extractConversationContent(resp) - if content == "" { - return PassResult{}, fmt.Errorf("pass0: no assistant message in response") - } - - return PassResult{ - Content: content, - Usage: convertConvUsage(resp.Usage), - Model: "agent:" + agentID, - }, nil -} - -// Pass2 uses chat completions for description generation + retry fields. 
-func (c *Client) Pass2(ctx context.Context, systemPrompt, userPrompt string) (PassResult, error) { - c.limiter.wait() - resp, err := c.sdk.ChatComplete(ctx, &chat.CompletionRequest{ - Model: c.modelComplex, - Messages: []chat.Message{ - &chat.SystemMessage{Content: chat.TextContent(systemPrompt)}, - &chat.UserMessage{Content: chat.TextContent(userPrompt)}, - }, - ResponseFormat: &chat.ResponseFormat{Type: "json_object"}, - }) - if err != nil { - return PassResult{}, fmt.Errorf("pass2 chat: %w", err) - } - - if len(resp.Choices) == 0 { - return PassResult{}, fmt.Errorf("pass2: no choices in response") - } - - return PassResult{ - Content: resp.Choices[0].Message.Content.String(), - Usage: convertChatUsage(resp.Usage), - Model: c.modelComplex, - }, nil -} - -func extractConversationContent(resp *conversation.Response) string { - for _, entry := range resp.Outputs { - if msg, ok := entry.(*conversation.MessageOutputEntry); ok { - return msg.Content.String() - } - } - return "" -} - -func convertConvUsage(u conversation.UsageInfo) *UsageInfo { - return &UsageInfo{ - PromptTokens: u.PromptTokens, - CompletionTokens: u.CompletionTokens, - TotalTokens: u.TotalTokens, - } -} - -func convertChatUsage(u chat.UsageInfo) *UsageInfo { - return &UsageInfo{ - PromptTokens: u.PromptTokens, - CompletionTokens: u.CompletionTokens, - TotalTokens: u.TotalTokens, - } -} diff --git a/backend/internal/pkg/ai/errors.go b/backend/internal/pkg/ai/errors.go new file mode 100644 index 0000000..1c0607f --- /dev/null +++ b/backend/internal/pkg/ai/errors.go @@ -0,0 +1,99 @@ +package ai + +import ( + "context" + "errors" + "fmt" + "net" + "strings" +) + +type ErrorCode int + +const ( + ErrInternal ErrorCode = iota + ErrRateLimited + ErrQuotaExceeded + ErrTimeout + ErrInvalidRequest + ErrUnavailable + ErrSchemaViolation +) + +func (c ErrorCode) String() string { + switch c { + case ErrInternal: + return "internal" + case ErrRateLimited: + return "rate_limited" + case ErrQuotaExceeded: + return 
"quota_exceeded" + case ErrTimeout: + return "timeout" + case ErrInvalidRequest: + return "invalid_request" + case ErrUnavailable: + return "unavailable" + case ErrSchemaViolation: + return "schema_violation" + default: + return "internal" + } +} + +type ProviderError struct { + Code ErrorCode + Message string + Retryable bool + Inner error + RawOutput string +} + +func (e *ProviderError) Error() string { + if e.Inner != nil { + return fmt.Sprintf("ai: %s: %s: %v", e.Code, e.Message, e.Inner) + } + return fmt.Sprintf("ai: %s: %s", e.Code, e.Message) +} + +func (e *ProviderError) Unwrap() error { return e.Inner } + +func ClassifyError(err error) *ProviderError { + if err == nil { + return nil + } + var pe *ProviderError + if errors.As(err, &pe) { + return pe + } + if errors.Is(err, context.DeadlineExceeded) { + return &ProviderError{Code: ErrTimeout, Message: "context deadline exceeded", Retryable: true, Inner: err} + } + + msg := strings.ToLower(err.Error()) + switch { + case strings.Contains(msg, "429"), + strings.Contains(msg, "too many requests"), + strings.Contains(msg, "rate limit"): + return &ProviderError{Code: ErrRateLimited, Message: err.Error(), Retryable: true, Inner: err} + case strings.Contains(msg, "deadline exceeded"), + strings.Contains(msg, "timeout"): + return &ProviderError{Code: ErrTimeout, Message: err.Error(), Retryable: true, Inner: err} + case strings.Contains(msg, "connection refused"), + strings.Contains(msg, "no such host"), + isNetError(err): + return &ProviderError{Code: ErrUnavailable, Message: err.Error(), Retryable: true, Inner: err} + case strings.Contains(msg, "quota"), + strings.Contains(msg, "insufficient"): + return &ProviderError{Code: ErrQuotaExceeded, Message: err.Error(), Retryable: false, Inner: err} + case strings.Contains(msg, "400"), + strings.Contains(msg, "invalid"): + return &ProviderError{Code: ErrInvalidRequest, Message: err.Error(), Retryable: false, Inner: err} + } + return &ProviderError{Code: ErrInternal, 
Message: err.Error(), Retryable: false, Inner: err} +} + +func isNetError(err error) bool { + var ne net.Error + return errors.As(err, &ne) +} diff --git a/backend/internal/pkg/ai/factory.go b/backend/internal/pkg/ai/factory.go new file mode 100644 index 0000000..bf4fbb2 --- /dev/null +++ b/backend/internal/pkg/ai/factory.go @@ -0,0 +1,41 @@ +package ai + +import ( + "context" + "fmt" + "time" + + mistral "github.com/VikingOwl91/mistral-go-sdk" + "github.com/VikingOwl91/mistral-go-sdk/chat" + "marktvogt.de/backend/internal/config" +) + +const ( + providerOllama = "ollama" + providerMistral = "mistral" +) + +func NewFromConfig(cfg config.AIConfig) (Provider, error) { + switch cfg.Provider { + case "", providerOllama: + return NewOllamaProvider(OllamaConfig{ + BaseURL: cfg.OllamaURL, + Model: cfg.OllamaModel, + }), nil + case providerMistral: + if cfg.MistralAPIKey == "" { + return nil, fmt.Errorf("ai: provider=%s requires AI_MISTRAL_API_KEY", providerMistral) + } + sdk := mistral.NewClient( + cfg.MistralAPIKey, + mistral.WithTimeout(120*time.Second), + mistral.WithRetry(2, 1*time.Second), + ) + chatFn := func(ctx context.Context, req *chat.CompletionRequest) (*chat.CompletionResponse, error) { + return sdk.ChatComplete(ctx, req) + } + return newMistralProviderWithChat(cfg.MistralModel, chatFn, newRateLimiter(cfg.RateLimitRPS)), nil + default: + return nil, fmt.Errorf("ai: unknown provider %q (want %s|%s)", cfg.Provider, providerOllama, providerMistral) + } +} diff --git a/backend/internal/pkg/ai/factory_test.go b/backend/internal/pkg/ai/factory_test.go new file mode 100644 index 0000000..96c34d4 --- /dev/null +++ b/backend/internal/pkg/ai/factory_test.go @@ -0,0 +1,29 @@ +package ai + +import ( + "testing" + + "marktvogt.de/backend/internal/config" +) + +func TestNewFromConfig_Ollama(t *testing.T) { + p, err := NewFromConfig(config.AIConfig{Provider: providerOllama, OllamaURL: "http://x:11434", OllamaModel: "m"}) + if err != nil { + t.Fatalf("NewFromConfig: %v", 
err) + } + if p.Name() != providerOllama { + t.Fatalf("Name: %q", p.Name()) + } +} + +func TestNewFromConfig_MistralRequiresKey(t *testing.T) { + if _, err := NewFromConfig(config.AIConfig{Provider: providerMistral}); err == nil { + t.Fatal("want error when MistralAPIKey is empty") + } +} + +func TestNewFromConfig_UnknownProvider(t *testing.T) { + if _, err := NewFromConfig(config.AIConfig{Provider: "llama-cpp"}); err == nil { + t.Fatal("want error for unknown provider") + } +} diff --git a/backend/internal/pkg/ai/mistral_provider.go b/backend/internal/pkg/ai/mistral_provider.go new file mode 100644 index 0000000..699da83 --- /dev/null +++ b/backend/internal/pkg/ai/mistral_provider.go @@ -0,0 +1,99 @@ +package ai + +import ( + "context" + "fmt" + + "github.com/VikingOwl91/mistral-go-sdk/chat" +) + +type chatFunc func(ctx context.Context, req *chat.CompletionRequest) (*chat.CompletionResponse, error) + +type MistralProvider struct { + model string + chatFn chatFunc + limiter *rateLimiter // from ratelimit.go; nil disables +} + +func newMistralProviderWithChat(model string, fn chatFunc, limiter *rateLimiter) *MistralProvider { + return &MistralProvider{model: model, chatFn: fn, limiter: limiter} +} + +func (p *MistralProvider) Name() string { return "mistral" } +func (p *MistralProvider) SupportsJSONMode() bool { return true } +func (p *MistralProvider) SupportsJSONSchema() bool { return false } + +func (p *MistralProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) { + if p.chatFn == nil { + return nil, &ProviderError{Code: ErrInternal, Message: "mistral provider not configured", Retryable: false} + } + if p.limiter != nil { + p.limiter.wait() + } + + systemContent := req.SystemPrompt + if len(req.JSONSchema) > 0 { + if systemContent != "" { + systemContent += "\n\n" + } + systemContent += "Respond with a JSON object that conforms to the following JSON Schema. " + + "Do not output anything outside the JSON. 
Schema:\n" + string(req.JSONSchema) + } + + msgs := []chat.Message{} + if systemContent != "" { + msgs = append(msgs, &chat.SystemMessage{Content: chat.TextContent(systemContent)}) + } + msgs = append(msgs, &chat.UserMessage{Content: chat.TextContent(req.UserMessage)}) + + creq := &chat.CompletionRequest{ + Model: firstNonEmpty(req.Model, p.model), + Messages: msgs, + } + if req.JSONMode || len(req.JSONSchema) > 0 { + creq.ResponseFormat = &chat.ResponseFormat{Type: "json_object"} + } + if req.Temperature != 0 { + temp := float64(req.Temperature) + creq.Temperature = &temp + } + if req.MaxTokens != 0 { + creq.MaxTokens = &req.MaxTokens + } + + resp, err := p.chatFn(ctx, creq) + if err != nil { + return nil, ClassifyError(err) + } + if len(resp.Choices) == 0 { + return nil, &ProviderError{Code: ErrInternal, Message: "no choices in response", Retryable: false} + } + content := resp.Choices[0].Message.Content.String() + + if len(req.JSONSchema) > 0 { + if err := ValidateSchema(req.JSONSchema, []byte(content)); err != nil { + return nil, &ProviderError{ + Code: ErrSchemaViolation, + Message: fmt.Sprintf("response does not match schema: %v", err), + Retryable: true, + Inner: err, + RawOutput: content, + } + } + } + + return &ChatResponse{ + Content: content, + Model: resp.Model, + PromptTokens: resp.Usage.PromptTokens, + OutputTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, + }, nil +} + +func firstNonEmpty(a, b string) string { + if a != "" { + return a + } + return b +} diff --git a/backend/internal/pkg/ai/mistral_provider_test.go b/backend/internal/pkg/ai/mistral_provider_test.go new file mode 100644 index 0000000..f0b72b5 --- /dev/null +++ b/backend/internal/pkg/ai/mistral_provider_test.go @@ -0,0 +1,123 @@ +package ai + +import ( + "context" + "errors" + "testing" + + "github.com/VikingOwl91/mistral-go-sdk/chat" +) + +func TestMistral_Chat_PassesThroughContent(t *testing.T) { + fakeChat := func(ctx context.Context, req 
*chat.CompletionRequest) (*chat.CompletionResponse, error) { + return &chat.CompletionResponse{ + Model: "mistral-large-latest", + Choices: []chat.CompletionChoice{ + {Message: chat.AssistantMessage{Content: chat.TextContent("ok")}}, + }, + Usage: chat.UsageInfo{PromptTokens: 3, CompletionTokens: 1, TotalTokens: 4}, + }, nil + } + p := newMistralProviderWithChat("mistral-large-latest", fakeChat, nil) + + resp, err := p.Chat(context.Background(), &ChatRequest{SystemPrompt: "s", UserMessage: "u"}) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if resp.Content != "ok" || resp.TotalTokens != 4 { + t.Fatalf("unexpected: %+v", resp) + } +} + +func TestMistral_Chat_JSONModeSetsResponseFormat(t *testing.T) { + var seen *chat.CompletionRequest + fakeChat := func(ctx context.Context, req *chat.CompletionRequest) (*chat.CompletionResponse, error) { + seen = req + return &chat.CompletionResponse{Choices: []chat.CompletionChoice{{Message: chat.AssistantMessage{Content: chat.TextContent("{}")}}}}, nil + } + p := newMistralProviderWithChat("m", fakeChat, nil) + _, err := p.Chat(context.Background(), &ChatRequest{UserMessage: "x", JSONMode: true}) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if seen == nil || seen.ResponseFormat == nil || seen.ResponseFormat.Type != "json_object" { + t.Fatalf("ResponseFormat not set: %+v", seen) + } +} + +func TestMistral_Chat_SchemaEmbeddedInSystemPromptAndValidated(t *testing.T) { + schema := []byte(`{"type":"object","required":["foo"],"properties":{"foo":{"type":"string"}},"additionalProperties":false}`) + var seen *chat.CompletionRequest + fakeChat := func(ctx context.Context, req *chat.CompletionRequest) (*chat.CompletionResponse, error) { + seen = req + return &chat.CompletionResponse{Choices: []chat.CompletionChoice{{Message: chat.AssistantMessage{Content: chat.TextContent(`{"foo":"bar"}`)}}}}, nil + } + p := newMistralProviderWithChat("m", fakeChat, nil) + resp, err := p.Chat(context.Background(), &ChatRequest{SystemPrompt: 
"base system", UserMessage: "x", JSONSchema: schema}) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if resp.Content != `{"foo":"bar"}` { + t.Fatalf("content: %q", resp.Content) + } + sysMsg, ok := seen.Messages[0].(*chat.SystemMessage) + if !ok { + t.Fatalf("first message must be system: %T", seen.Messages[0]) + } + sys := sysMsg.Content.String() + if !containsAll(sys, []string{"base system", "JSON Schema"}) { + t.Fatalf("system prompt missing expected fragments: %q", sys) + } +} + +func TestMistral_Chat_SchemaViolationReturnsRetryableError(t *testing.T) { + schema := []byte(`{"type":"object","required":["foo"],"properties":{"foo":{"type":"string"}},"additionalProperties":false}`) + fakeChat := func(ctx context.Context, req *chat.CompletionRequest) (*chat.CompletionResponse, error) { + return &chat.CompletionResponse{Choices: []chat.CompletionChoice{{Message: chat.AssistantMessage{Content: chat.TextContent(`{"bar":1}`)}}}}, nil + } + p := newMistralProviderWithChat("m", fakeChat, nil) + _, err := p.Chat(context.Background(), &ChatRequest{UserMessage: "x", JSONSchema: schema}) + if err == nil { + t.Fatal("want error") + } + var pe *ProviderError + if !errors.As(err, &pe) { + t.Fatalf("want *ProviderError, got %T", err) + } + if pe.Code != ErrSchemaViolation || !pe.Retryable || pe.RawOutput != `{"bar":1}` { + t.Fatalf("unexpected: %+v", pe) + } +} + +func TestMistral_Supports(t *testing.T) { + p := newMistralProviderWithChat("m", nil, nil) + if !p.SupportsJSONMode() { + t.Fatal("Mistral supports JSON mode") + } + if p.SupportsJSONSchema() { + t.Fatal("Mistral does NOT natively support JSON schema (prompt-based only)") + } + if p.Name() != "mistral" { + t.Fatalf("Name: %q", p.Name()) + } +} + +func containsAll(s string, parts []string) bool { + for _, p := range parts { + if !contains(s, p) { + return false + } + } + return true +} +func contains(s, sub string) bool { + return len(sub) == 0 || (len(s) >= len(sub) && (s == sub || indexOf(s, sub) >= 0)) +} +func 
indexOf(s, sub string) int { + for i := 0; i+len(sub) <= len(s); i++ { + if s[i:i+len(sub)] == sub { + return i + } + } + return -1 +} diff --git a/backend/internal/pkg/ai/ollama.go b/backend/internal/pkg/ai/ollama.go new file mode 100644 index 0000000..95ee20d --- /dev/null +++ b/backend/internal/pkg/ai/ollama.go @@ -0,0 +1,127 @@ +package ai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +type OllamaConfig struct { + BaseURL string + Model string + Timeout time.Duration +} + +type OllamaProvider struct { + cfg OllamaConfig + client *http.Client +} + +func NewOllamaProvider(cfg OllamaConfig) *OllamaProvider { + if cfg.Timeout == 0 { + cfg.Timeout = 300 * time.Second + } + return &OllamaProvider{ + cfg: cfg, + client: &http.Client{Timeout: cfg.Timeout}, + } +} + +func (p *OllamaProvider) Name() string { return "ollama" } +func (p *OllamaProvider) SupportsJSONMode() bool { return true } +func (p *OllamaProvider) SupportsJSONSchema() bool { return true } + +type ollamaChatReq struct { + Model string `json:"model"` + Messages []ollamaMessage `json:"messages"` + Stream bool `json:"stream"` + Format json.RawMessage `json:"format,omitempty"` + Options *ollamaOptions `json:"options,omitempty"` +} + +type ollamaMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ollamaOptions struct { + Temperature float32 `json:"temperature,omitempty"` + NumPredict int `json:"num_predict,omitempty"` +} + +type ollamaChatResp struct { + Model string `json:"model"` + Message ollamaMessage `json:"message"` + Done bool `json:"done"` + PromptEvalCount int `json:"prompt_eval_count"` + EvalCount int `json:"eval_count"` +} + +func (p *OllamaProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) { + model := req.Model + if model == "" { + model = p.cfg.Model + } + body := ollamaChatReq{ + Model: model, + Messages: buildOllamaMessages(req), + Stream: false, + } + switch { + case 
len(req.JSONSchema) > 0: + body.Format = req.JSONSchema + case req.JSONMode: + body.Format = json.RawMessage(`"json"`) + } + if req.Temperature != 0 || req.MaxTokens != 0 { + body.Options = &ollamaOptions{Temperature: req.Temperature, NumPredict: req.MaxTokens} + } + + buf, err := json.Marshal(body) + if err != nil { + return nil, &ProviderError{Code: ErrInvalidRequest, Message: "marshal request", Retryable: false, Inner: err} + } + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.cfg.BaseURL+"/api/chat", bytes.NewReader(buf)) + if err != nil { + return nil, &ProviderError{Code: ErrInternal, Message: "new request", Retryable: false, Inner: err} + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := p.client.Do(httpReq) + if err != nil { + return nil, ClassifyError(err) + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode >= 400 { + b, _ := io.ReadAll(resp.Body) + pe := ClassifyError(fmt.Errorf("ollama status %d: %s", resp.StatusCode, string(b))) + return nil, pe + } + + var out ollamaChatResp + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + return nil, &ProviderError{Code: ErrInternal, Message: "decode response", Retryable: false, Inner: err} + } + return &ChatResponse{ + Content: out.Message.Content, + Model: out.Model, + PromptTokens: out.PromptEvalCount, + OutputTokens: out.EvalCount, + TotalTokens: out.PromptEvalCount + out.EvalCount, + }, nil +} + +func buildOllamaMessages(req *ChatRequest) []ollamaMessage { + msgs := make([]ollamaMessage, 0, 2) + if req.SystemPrompt != "" { + msgs = append(msgs, ollamaMessage{Role: "system", Content: req.SystemPrompt}) + } + if req.UserMessage != "" { + msgs = append(msgs, ollamaMessage{Role: "user", Content: req.UserMessage}) + } + return msgs +} diff --git a/backend/internal/pkg/ai/ollama_test.go b/backend/internal/pkg/ai/ollama_test.go new file mode 100644 index 0000000..afc79ed --- /dev/null +++ b/backend/internal/pkg/ai/ollama_test.go @@ -0,0 
+1,98 @@ +package ai + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestOllama_Chat_SendsRequestAndParsesResponse(t *testing.T) { + var captured map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/chat" { + t.Errorf("path: got %s, want /api/chat", r.URL.Path) + } + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &captured) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "model":"qwen2.5:14b-instruct", + "message":{"role":"assistant","content":"hello"}, + "done":true, + "prompt_eval_count":10, + "eval_count":5 + }`)) + })) + defer srv.Close() + + p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "qwen2.5:14b-instruct", Timeout: 5 * time.Second}) + resp, err := p.Chat(context.Background(), &ChatRequest{SystemPrompt: "be brief", UserMessage: "hi", JSONMode: true}) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if resp.Content != "hello" { + t.Fatalf("content: got %q", resp.Content) + } + if resp.PromptTokens != 10 || resp.OutputTokens != 5 || resp.TotalTokens != 15 { + t.Fatalf("tokens: %+v", resp) + } + if captured["stream"] != false { + t.Fatalf("stream must be false: %v", captured["stream"]) + } + if captured["format"] != "json" { + t.Fatalf("format for JSONMode=true must be \"json\", got %v", captured["format"]) + } +} + +func TestOllama_Chat_ForwardsSchema(t *testing.T) { + var captured map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(body, &captured) + _, _ = w.Write([]byte(`{"model":"m","message":{"role":"assistant","content":"{}"},"done":true}`)) + })) + defer srv.Close() + + p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "m", Timeout: time.Second}) + schema := []byte(`{"type":"object"}`) + _, err := p.Chat(context.Background(), 
// Provider is the backend-agnostic chat-completion interface. Callers program
// against it instead of a concrete client.
type Provider interface {
	// Name returns a short identifier for the backend, e.g. "mistral" or
	// "ollama".
	Name() string
	// Chat performs a single completion for req and returns the model's reply.
	Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error)
	// SupportsJSONMode reports whether the backend offers a JSON output mode
	// (see ChatRequest.JSONMode).
	SupportsJSONMode() bool
	// SupportsJSONSchema reports whether the backend natively constrains its
	// output to a JSON Schema (see ChatRequest.JSONSchema), as opposed to
	// prompt-based enforcement.
	SupportsJSONSchema() bool
}

// ChatRequest describes one chat completion. All fields are optional at the
// type level; the zero value of a field means "not set".
// NOTE(review): how unset fields are treated is up to each provider — confirm
// against the implementation in use.
type ChatRequest struct {
	// SystemPrompt is the system message, if any.
	SystemPrompt string

	// UserMessage is the user's message.
	UserMessage string

	// Model selects the model for this request. The Ollama provider falls
	// back to its configured model when this is empty.
	Model string

	// MaxTokens caps the length of the reply (sent as num_predict by the
	// Ollama provider).
	MaxTokens int

	// Temperature is the sampling temperature.
	Temperature float32

	// JSONMode asks for JSON output without a specific schema.
	JSONMode bool

	// JSONSchema is a JSON Schema document for the reply. The Ollama provider
	// gives it precedence over JSONMode when both are set.
	JSONSchema json.RawMessage
}

// ChatResponse is the result of a successful Provider.Chat call.
type ChatResponse struct {
	// Content is the text of the assistant's reply.
	Content string

	// Model is the model name reported by the backend.
	Model string

	// PromptTokens is the number of tokens consumed by the prompt.
	PromptTokens int

	// OutputTokens is the number of tokens generated for the reply.
	OutputTokens int

	// TotalTokens is the overall token count; the Ollama provider sets it to
	// PromptTokens + OutputTokens.
	TotalTokens int
}
b/backend/internal/pkg/ai/provider_test.go @@ -0,0 +1,89 @@ +package ai + +import ( + "context" + "errors" + "testing" +) + +func TestChatRequest_DefaultsAreZeroValues(t *testing.T) { + r := ChatRequest{} + if r.Temperature != 0 { + t.Fatalf("Temperature default: got %v, want 0", r.Temperature) + } + if r.MaxTokens != 0 { + t.Fatalf("MaxTokens default: got %v, want 0", r.MaxTokens) + } + if r.JSONMode { + t.Fatalf("JSONMode default: got true, want false") + } + if r.JSONSchema != nil { + t.Fatalf("JSONSchema default: got non-nil, want nil") + } +} + +type stubProvider struct { + name string + resp *ChatResponse + err error +} + +func (s *stubProvider) Name() string { return s.name } +func (s *stubProvider) SupportsJSONMode() bool { return true } +func (s *stubProvider) SupportsJSONSchema() bool { return true } +func (s *stubProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) { + return s.resp, s.err +} + +func TestStubProvider_SatisfiesInterface(t *testing.T) { + var _ Provider = (*stubProvider)(nil) +} + +func TestStubProvider_ReturnsError(t *testing.T) { + sentinel := errors.New("boom") + var p Provider = &stubProvider{name: "stub", err: sentinel} + _, err := p.Chat(context.Background(), &ChatRequest{}) + if !errors.Is(err, sentinel) { + t.Fatalf("got %v, want sentinel", err) + } +} + +func TestProviderError_Error(t *testing.T) { + inner := errors.New("root cause") + pe := &ProviderError{Code: ErrRateLimited, Message: "slow down", Retryable: true, Inner: inner} + got := pe.Error() + if got == "" { + t.Fatal("Error() must not be empty") + } + if !errors.Is(pe, inner) { + t.Fatalf("errors.Is must unwrap Inner") + } +} + +func TestClassifyError_DetectsRateLimit(t *testing.T) { + cases := []struct { + msg string + want ErrorCode + }{ + {"429 Too Many Requests", ErrRateLimited}, + {"context deadline exceeded", ErrTimeout}, + {"connection refused", ErrUnavailable}, + {"something totally unrecognized", ErrInternal}, + } + for _, tc := range 
cases { + t.Run(tc.msg, func(t *testing.T) { + got := ClassifyError(errors.New(tc.msg)).Code + if got != tc.want { + t.Fatalf("ClassifyError(%q) = %v, want %v", tc.msg, got, tc.want) + } + }) + } +} + +func TestClassifyError_PreservesProviderError(t *testing.T) { + orig := &ProviderError{Code: ErrSchemaViolation, Message: "x", Retryable: true} + got := ClassifyError(orig) + if got != orig { + t.Fatal("ClassifyError must return the same *ProviderError instance") + } +} diff --git a/backend/internal/pkg/ai/ratelimit.go b/backend/internal/pkg/ai/ratelimit.go deleted file mode 100644 index 02b9e19..0000000 --- a/backend/internal/pkg/ai/ratelimit.go +++ /dev/null @@ -1,25 +0,0 @@ -package ai - -import "strings" - -// IsRateLimit reports whether err is a Mistral 429 / web_search rate-limit -// signal. The SDK surfaces these as wrapped errors whose message contains -// "rate limit" and/or "status 429"; we match defensively on both. -// -// Shared helper so domains (discovery, market/research) treat rate limits -// consistently: log at WARN, return a 429 to the caller with a Retry-After -// hint instead of a generic 500. -func IsRateLimit(err error) bool { - if err == nil { - return false - } - msg := err.Error() - return strings.Contains(msg, "rate limit") || - strings.Contains(msg, "status 429") || - strings.Contains(msg, "429") -} - -// DefaultRetryAfterSeconds is the hint we send when the SDK error doesn't -// carry a structured Retry-After. 60s matches the typical web_search budget -// window on Mistral's paid tier. 
-const DefaultRetryAfterSeconds = 60 diff --git a/backend/internal/pkg/ai/ratelimit_test.go b/backend/internal/pkg/ai/ratelimit_test.go deleted file mode 100644 index dc753d1..0000000 --- a/backend/internal/pkg/ai/ratelimit_test.go +++ /dev/null @@ -1,28 +0,0 @@ -package ai - -import ( - "errors" - "testing" -) - -func TestIsRateLimit(t *testing.T) { - tests := []struct { - name string - err error - want bool - }{ - {"nil", nil, false}, - {"unrelated error", errors.New("connection refused"), false}, - {"mistral web_search 429", errors.New("pass1 conversation: mistral: web_search rate limit reached. (status 429)"), true}, - {"bare status 429", errors.New("upstream returned status 429"), true}, - {"raw 429 token", errors.New("http 429 received"), true}, - {"token 42900 should NOT match bare token but does by substring", errors.New("code 42900 from upstream"), true}, - } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - if got := IsRateLimit(tc.err); got != tc.want { - t.Errorf("IsRateLimit(%v) = %v, want %v", tc.err, got, tc.want) - } - }) - } -} diff --git a/backend/internal/pkg/ai/ratelimiter.go b/backend/internal/pkg/ai/ratelimiter.go new file mode 100644 index 0000000..73a7041 --- /dev/null +++ b/backend/internal/pkg/ai/ratelimiter.go @@ -0,0 +1,33 @@ +package ai + +import ( + "sync" + "time" +) + +// rateLimiter enforces a minimum interval between calls. Set rps<=0 to disable. 
type rateLimiter struct {
	mu          sync.Mutex    // serializes wait and guards lastReq
	lastReq     time.Time     // time at which the most recent wait call finished
	minInterval time.Duration // required spacing between calls; 0 disables limiting
}

// newRateLimiter returns a limiter that spaces calls at least 1/rps seconds
// apart. A non-positive or NaN rps yields a disabled limiter whose wait
// returns immediately. Extremely small positive rates are clamped to the
// largest representable interval instead of overflowing.
func newRateLimiter(rps float64) *rateLimiter {
	// Written as !(rps > 0) rather than rps <= 0 so that NaN is also treated
	// as "disabled".
	if !(rps > 0) {
		return &rateLimiter{}
	}
	// For very small rps the quotient exceeds the int64 range. Converting an
	// out-of-range float to an integer type gives an implementation-dependent
	// result in Go (a negative duration on some platforms, which would
	// silently switch limiting off), so clamp before converting.
	const maxInterval = time.Duration(1<<63 - 1)
	interval := float64(time.Second) / rps
	if interval >= float64(maxInterval) {
		return &rateLimiter{minInterval: maxInterval}
	}
	return &rateLimiter{minInterval: time.Duration(interval)}
}

// wait blocks until at least minInterval has passed since the previous wait
// call finished, then records the current time as the new reference point.
// The mutex stays held during the sleep, so concurrent callers are serialized
// and leave one interval apart. A disabled limiter returns immediately and
// records nothing.
func (rl *rateLimiter) wait() {
	if rl.minInterval <= 0 {
		return
	}
	rl.mu.Lock()
	defer rl.mu.Unlock()
	since := time.Since(rl.lastReq)
	if since < rl.minInterval {
		time.Sleep(rl.minInterval - since)
	}
	rl.lastReq = time.Now()
}
TestValidateSchema_RejectsMissingRequired(t *testing.T) { + schema := []byte(`{"type":"object","required":["foo"],"properties":{"foo":{"type":"string"}}}`) + doc := []byte(`{}`) + if err := ValidateSchema(schema, doc); err == nil { + t.Fatal("want error, got nil") + } +} + +func TestValidateSchema_RejectsBadJSON(t *testing.T) { + if err := ValidateSchema([]byte(`{}`), []byte(`not json`)); err == nil { + t.Fatal("want error, got nil") + } +} diff --git a/backend/internal/server/routes.go b/backend/internal/server/routes.go index 7610d0b..411af06 100644 --- a/backend/internal/server/routes.go +++ b/backend/internal/server/routes.go @@ -1,6 +1,7 @@ package server import ( + "fmt" "net/http" "github.com/gin-gonic/gin" @@ -68,23 +69,20 @@ func (s *Server) registerRoutes() { // Admin market routes adminMarketHandler := market.NewAdminHandler(marketSvc) - aiClient := ai.New(s.cfg.AI.APIKey, s.cfg.AI.AgentSimple, s.cfg.AI.ModelComplex, s.cfg.AI.RateLimitRPS) - researchHandler := market.NewResearchHandler(marketSvc, aiClient) + aiProvider, err := ai.NewFromConfig(s.cfg.AI) + if err != nil { + panic(fmt.Errorf("init ai provider: %w", err)) + } + researchHandler := market.NewResearchHandler(marketSvc, aiProvider) requireAdmin := middleware.RequireRole("admin") market.RegisterAdminRoutes(v1, adminMarketHandler, researchHandler, requireAuth, requireAdmin) // Discovery routes discoveryRepo := discovery.NewRepository(s.db) crawlerInstance := crawler.NewCrawler(s.cfg.Discovery.CrawlerUserAgent, crawler.DefaultSourceConfigs()) - // Per-row LLM enrichment (MR 3b). Operator-triggered only; disabled rows - // fall through via NoopLLMEnricher when AI isn't configured. 
- var llmEnricher enrich.LLMEnricher = enrich.NoopLLMEnricher{} - var simClassifier enrich.SimilarityClassifier = enrich.NoopSimilarityClassifier{} - if aiClient.Enabled() { - scraper := scrape.New(s.cfg.Discovery.CrawlerUserAgent) - llmEnricher = enrich.NewMistralLLMEnricher(aiClient, scraper) - simClassifier = enrich.NewMistralSimilarityClassifier(aiClient) - } + scraper := scrape.New(s.cfg.Discovery.CrawlerUserAgent) + llmEnricher := enrich.NewLLMEnricher(aiProvider, scraper) + simClassifier := enrich.NewSimilarityClassifier(aiProvider) discoveryService := discovery.NewService(discoveryRepo, crawlerInstance, discovery.NewLinkChecker(), marketSvc, geocoder, llmEnricher, simClassifier) discoveryHandler := discovery.NewHandler(discoveryService, s.cfg.Discovery.CrawlerManualRateLimitPerHour) requireTickToken := middleware.RequireBearerToken(s.cfg.Discovery.Token)