feat(ai): pluggable provider interface, Ollama + Mistral impls, migrate Pass2 sites
Replaces the Mistral-only ai.Client with an ai.Provider interface backed by Ollama and Mistral implementations. Migrates enrichment + similarity callers to ai.Provider.Chat. Research endpoint returns 501 until commit 2 reinstates it on the new orchestrator.
This commit is contained in:
@@ -29,6 +29,7 @@ import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"marktvogt.de/backend/internal/config"
|
||||
"marktvogt.de/backend/internal/domain/discovery/enrich"
|
||||
"marktvogt.de/backend/internal/pkg/ai"
|
||||
"marktvogt.de/backend/internal/pkg/scrape"
|
||||
@@ -65,8 +66,14 @@ func realMain() int {
|
||||
)
|
||||
flag.Parse()
|
||||
|
||||
apiKey := os.Getenv("AI_API_KEY")
|
||||
model := os.Getenv("AI_MODEL_COMPLEX")
|
||||
apiKey := os.Getenv("AI_MISTRAL_API_KEY")
|
||||
if apiKey == "" {
|
||||
apiKey = os.Getenv("AI_API_KEY") // legacy fallback
|
||||
}
|
||||
model := os.Getenv("AI_MISTRAL_MODEL")
|
||||
if model == "" {
|
||||
model = os.Getenv("AI_MODEL_COMPLEX") // legacy fallback
|
||||
}
|
||||
if model == "" {
|
||||
model = "mistral-large-latest"
|
||||
}
|
||||
@@ -74,9 +81,14 @@ func realMain() int {
|
||||
if userAgent == "" {
|
||||
userAgent = "marktvogt-eval/1.0 (+https://marktvogt.de)"
|
||||
}
|
||||
client := ai.New(apiKey, "", model, 1.0)
|
||||
if !client.Enabled() {
|
||||
slog.Error("AI client not configured (set AI_API_KEY)")
|
||||
client, err := ai.NewFromConfig(config.AIConfig{
|
||||
Provider: "mistral",
|
||||
MistralAPIKey: apiKey,
|
||||
MistralModel: model,
|
||||
RateLimitRPS: 1.0,
|
||||
})
|
||||
if err != nil {
|
||||
slog.Error("AI client not configured", "error", err)
|
||||
return 2
|
||||
}
|
||||
|
||||
@@ -95,7 +107,7 @@ func realMain() int {
|
||||
return runSimilarityMode(ctx, client, cfg)
|
||||
case modeCategory:
|
||||
scraper := scrape.New(userAgent)
|
||||
enricher := enrich.NewMistralLLMEnricher(client, scraper)
|
||||
enricher := enrich.NewLLMEnricher(client, scraper)
|
||||
cfg := evalConfig{
|
||||
model: model,
|
||||
fixturePath: pathWithDefault(*fixturePath, "backend/cmd/discovery-eval/fixtures/category.json"),
|
||||
@@ -117,7 +129,7 @@ func pathWithDefault(p, dflt string) string {
|
||||
return p
|
||||
}
|
||||
|
||||
func runSimilarityMode(ctx context.Context, client *ai.Client, cfg evalConfig) int {
|
||||
func runSimilarityMode(ctx context.Context, client ai.Provider, cfg evalConfig) int {
|
||||
fixture, err := loadFixture(cfg.fixturePath)
|
||||
if err != nil {
|
||||
slog.Error("load fixture", "path", cfg.fixturePath, "error", err)
|
||||
@@ -125,7 +137,7 @@ func runSimilarityMode(ctx context.Context, client *ai.Client, cfg evalConfig) i
|
||||
}
|
||||
slog.Info("loaded fixture", "mode", modeSimilarity, "pairs", len(fixture.Pairs), "path", cfg.fixturePath)
|
||||
|
||||
classifier := enrich.NewMistralSimilarityClassifier(client)
|
||||
classifier := enrich.NewSimilarityClassifier(client)
|
||||
|
||||
cache, err := loadCache(cfg.cachePath)
|
||||
if err != nil {
|
||||
|
||||
@@ -11,6 +11,7 @@ require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.8.0
|
||||
github.com/pquerna/otp v1.5.0
|
||||
github.com/santhosh-tekuri/jsonschema/v5 v5.3.1
|
||||
github.com/valkey-io/valkey-go v1.0.72
|
||||
golang.org/x/crypto v0.49.0
|
||||
golang.org/x/oauth2 v0.35.0
|
||||
|
||||
@@ -73,6 +73,8 @@ github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
|
||||
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
|
||||
github.com/quic-go/quic-go v0.57.0 h1:AsSSrrMs4qI/hLrKlTH/TGQeTMY0ib1pAOX7vA3AdqE=
|
||||
github.com/quic-go/quic-go v0.57.0/go.mod h1:ly4QBAjHA2VhdnxhojRsCUOeJwKYg+taDlos92xb1+s=
|
||||
github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 h1:lZUw3E0/J3roVtGQ+SCrUrg3ON6NgVqpn3+iol9aGu4=
|
||||
github.com/santhosh-tekuri/jsonschema/v5 v5.3.1/go.mod h1:uToXkOrWAZ6/Oc07xWQrPOhJotwFIyu2bBVN41fcDUY=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
|
||||
@@ -22,6 +22,7 @@ type Config struct {
|
||||
Turnstile TurnstileConfig
|
||||
Notification NotificationConfig
|
||||
AI AIConfig
|
||||
Search SearchConfig
|
||||
Discovery DiscoveryConfig
|
||||
}
|
||||
|
||||
@@ -32,10 +33,19 @@ type DiscoveryConfig struct {
|
||||
}
|
||||
|
||||
type AIConfig struct {
|
||||
APIKey string
|
||||
AgentSimple string // Pre-created Mistral agent ID for Pass 1 (extraction + web search)
|
||||
ModelComplex string // Model for Pass 2 (description + retry, e.g. mistral-large-latest)
|
||||
RateLimitRPS float64 // Max requests per second to Mistral (0 = disabled)
|
||||
Provider string // "ollama" or "mistral"; default "ollama"
|
||||
RateLimitRPS float64 // Max requests per second to upstream; 0 = disabled (Mistral only)
|
||||
|
||||
OllamaURL string // default "http://localhost:11434"
|
||||
OllamaModel string // default "qwen2.5:14b-instruct"
|
||||
|
||||
MistralAPIKey string
|
||||
MistralModel string // default "mistral-large-latest"
|
||||
}
|
||||
|
||||
type SearchConfig struct {
|
||||
Provider string // "searxng" (only option today)
|
||||
SearxngURL string // default "http://localhost:8888"
|
||||
}
|
||||
|
||||
type AppConfig struct {
|
||||
@@ -272,10 +282,16 @@ func Load() (*Config, error) {
|
||||
FrontendURL: envStr("FRONTEND_URL", "http://localhost:5173"),
|
||||
},
|
||||
AI: AIConfig{
|
||||
APIKey: envStr("AI_API_KEY", ""),
|
||||
AgentSimple: envStr("AI_AGENT_SIMPLE", ""),
|
||||
ModelComplex: envStr("AI_MODEL_COMPLEX", "mistral-large-latest"),
|
||||
RateLimitRPS: rpsAI,
|
||||
Provider: envStr("AI_PROVIDER", "ollama"),
|
||||
RateLimitRPS: rpsAI,
|
||||
OllamaURL: envStr("AI_OLLAMA_URL", "http://localhost:11434"),
|
||||
OllamaModel: envStr("AI_OLLAMA_MODEL", "qwen2.5:14b-instruct"),
|
||||
MistralAPIKey: envStr("AI_MISTRAL_API_KEY", envStr("AI_API_KEY", "")),
|
||||
MistralModel: envStr("AI_MISTRAL_MODEL", envStr("AI_MODEL_COMPLEX", "mistral-large-latest")),
|
||||
},
|
||||
Search: SearchConfig{
|
||||
Provider: envStr("SEARCH_PROVIDER", "searxng"),
|
||||
SearxngURL: envStr("SEARCH_SEARXNG_URL", "http://localhost:8888"),
|
||||
},
|
||||
Discovery: DiscoveryConfig{
|
||||
Token: discoveryToken,
|
||||
|
||||
@@ -22,32 +22,25 @@ const maxScrapeURLs = 5
|
||||
// a no-context LLM prompt — grounding is required for useful output.
|
||||
var ErrNoScrapedContent = errors.New("no scrapeable content from any source URL")
|
||||
|
||||
// Scraper is the narrow interface MistralLLMEnricher depends on. Satisfied
|
||||
// Scraper is the narrow interface LLMEnricher depends on. Satisfied
|
||||
// by *pkg/scrape.Client; tests inject a stub.
|
||||
type Scraper interface {
|
||||
Fetch(ctx context.Context, url string) (string, error)
|
||||
}
|
||||
|
||||
// aiPass2 is the narrow interface for the AI client's Pass2 chat-completion
|
||||
// method. Lets tests inject a stub that returns canned JSON without hitting
|
||||
// the real Mistral API.
|
||||
type aiPass2 interface {
|
||||
Pass2(ctx context.Context, systemPrompt, userPrompt string) (ai.PassResult, error)
|
||||
}
|
||||
|
||||
// MistralLLMEnricher implements LLMEnricher by scraping quellen URLs and
|
||||
// feeding the concatenated text to Mistral's chat-completion endpoint with
|
||||
// ProviderLLMEnricher implements LLMEnricher by scraping quellen URLs and
|
||||
// feeding the concatenated text to an AI provider's chat endpoint with
|
||||
// a JSON response format.
|
||||
type MistralLLMEnricher struct {
|
||||
Client aiPass2
|
||||
type ProviderLLMEnricher struct {
|
||||
AI ai.Provider
|
||||
Scraper Scraper
|
||||
}
|
||||
|
||||
// NewMistralLLMEnricher constructs an enricher bound to a Mistral ai.Client
|
||||
// NewLLMEnricher constructs an enricher bound to an ai.Provider
|
||||
// and a scraper. Both are required; call sites should fall back to
|
||||
// NoopLLMEnricher when AI is disabled rather than passing nil here.
|
||||
func NewMistralLLMEnricher(client aiPass2, scraper Scraper) *MistralLLMEnricher {
|
||||
return &MistralLLMEnricher{Client: client, Scraper: scraper}
|
||||
func NewLLMEnricher(provider ai.Provider, scraper Scraper) *ProviderLLMEnricher {
|
||||
return &ProviderLLMEnricher{AI: provider, Scraper: scraper}
|
||||
}
|
||||
|
||||
// llmResponse is the JSON shape we instruct Mistral to return. Any field may
|
||||
@@ -60,11 +53,11 @@ type llmResponse struct {
|
||||
}
|
||||
|
||||
// EnrichMissing scrapes up to maxScrapeURLs of req.Quellen, concatenates the
|
||||
// extracted text, and asks Mistral to fill category / opening_hours /
|
||||
// extracted text, and asks the AI provider to fill category / opening_hours /
|
||||
// description. Fails with ErrNoScrapedContent if zero URLs return usable
|
||||
// text — empty-context LLM calls hallucinate.
|
||||
func (e *MistralLLMEnricher) EnrichMissing(ctx context.Context, req LLMRequest) (Enrichment, error) {
|
||||
if e.Client == nil || e.Scraper == nil {
|
||||
func (e *ProviderLLMEnricher) EnrichMissing(ctx context.Context, req LLMRequest) (Enrichment, error) {
|
||||
if e.AI == nil || e.Scraper == nil {
|
||||
return Enrichment{}, errors.New("mistral enricher not configured")
|
||||
}
|
||||
|
||||
@@ -92,28 +85,26 @@ func (e *MistralLLMEnricher) EnrichMissing(ctx context.Context, req LLMRequest)
|
||||
systemPrompt := buildSystemPrompt()
|
||||
userPrompt := buildUserPrompt(req, blocks)
|
||||
|
||||
result, err := e.Client.Pass2(ctx, systemPrompt, userPrompt)
|
||||
resp, err := e.AI.Chat(ctx, &ai.ChatRequest{
|
||||
SystemPrompt: systemPrompt,
|
||||
UserMessage: userPrompt,
|
||||
JSONMode: true,
|
||||
})
|
||||
if err != nil {
|
||||
return Enrichment{}, fmt.Errorf("pass2: %w", err)
|
||||
return Enrichment{}, fmt.Errorf("chat: %w", err)
|
||||
}
|
||||
|
||||
var parsed llmResponse
|
||||
if err := json.Unmarshal([]byte(result.Content), &parsed); err != nil {
|
||||
return Enrichment{}, fmt.Errorf("parse llm response: %w (content=%q)", err, result.Content)
|
||||
if err := json.Unmarshal([]byte(resp.Content), &parsed); err != nil {
|
||||
return Enrichment{}, fmt.Errorf("parse llm response: %w (content=%q)", err, resp.Content)
|
||||
}
|
||||
|
||||
// Build the Enrichment payload with only the fields the model produced.
|
||||
// Sources entries + token counts + model tag feed the eval harness.
|
||||
now := time.Now().UTC()
|
||||
out := Enrichment{
|
||||
Sources: Sources{},
|
||||
Model: result.Model,
|
||||
EnrichedAt: &now,
|
||||
}
|
||||
if result.Usage != nil {
|
||||
out.InputTokens = result.Usage.PromptTokens
|
||||
out.OutputTokens = result.Usage.CompletionTokens
|
||||
}
|
||||
if s := strings.TrimSpace(parsed.Category); s != "" {
|
||||
out.Category = s
|
||||
out.Sources["category"] = ProvenanceLLM
|
||||
|
||||
@@ -5,8 +5,6 @@ import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"marktvogt.de/backend/internal/pkg/ai"
|
||||
)
|
||||
|
||||
const catMittelaltermarkt = "mittelaltermarkt"
|
||||
@@ -25,33 +23,15 @@ func (s *stubScraper) Fetch(_ context.Context, url string) (string, error) {
|
||||
return s.responses[url], nil
|
||||
}
|
||||
|
||||
// stubPass2 captures the prompts it received and returns a canned JSON body.
|
||||
type stubPass2 struct {
|
||||
lastSystem string
|
||||
lastUser string
|
||||
result ai.PassResult
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubPass2) Pass2(_ context.Context, systemPrompt, userPrompt string) (ai.PassResult, error) {
|
||||
s.lastSystem = systemPrompt
|
||||
s.lastUser = userPrompt
|
||||
return s.result, s.err
|
||||
}
|
||||
|
||||
func TestMistralEnrich_HappyPath(t *testing.T) {
|
||||
scraper := &stubScraper{responses: map[string]string{
|
||||
"https://a.example/markt": "Ein Mittelaltermarkt mit Ritterspielen und Markttreiben.",
|
||||
"https://b.example/info": "Sa-So jeweils 10-18 Uhr.",
|
||||
}}
|
||||
client := &stubPass2{
|
||||
result: ai.PassResult{
|
||||
Content: `{"category":"mittelaltermarkt","opening_hours":"Sa-So 10:00-18:00","description":"Ein Markt mit Ritterspielen."}`,
|
||||
Model: "mistral-large-latest",
|
||||
Usage: &ai.UsageInfo{PromptTokens: 450, CompletionTokens: 60},
|
||||
},
|
||||
stub := &stubProvider{
|
||||
content: `{"category":"mittelaltermarkt","opening_hours":"Sa-So 10:00-18:00","description":"Ein Markt mit Ritterspielen."}`,
|
||||
}
|
||||
e := NewMistralLLMEnricher(client, scraper)
|
||||
e := NewLLMEnricher(stub, scraper)
|
||||
|
||||
req := LLMRequest{
|
||||
MarktName: "Mittelaltermarkt Dresden",
|
||||
@@ -64,40 +44,37 @@ func TestMistralEnrich_HappyPath(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if stub.seen.JSONMode != true {
|
||||
t.Fatalf("JSONMode must be set")
|
||||
}
|
||||
|
||||
// Result carries LLM fields, provenance llm, model + token counts.
|
||||
// Result carries LLM fields and provenance llm.
|
||||
if got.Category != catMittelaltermarkt || got.Description == "" || got.OpeningHours == "" {
|
||||
t.Errorf("missing fields in result: %+v", got)
|
||||
}
|
||||
if got.Sources["category"] != ProvenanceLLM {
|
||||
t.Errorf("category provenance: got %q, want llm", got.Sources["category"])
|
||||
}
|
||||
if got.Model != "mistral-large-latest" {
|
||||
t.Errorf("model: got %q, want mistral-large-latest", got.Model)
|
||||
}
|
||||
if got.InputTokens != 450 || got.OutputTokens != 60 {
|
||||
t.Errorf("token counts: in=%d out=%d", got.InputTokens, got.OutputTokens)
|
||||
}
|
||||
|
||||
// Prompt inspection — verify grounding blocks include the scraped content
|
||||
// and the already-known fields are listed so LLM doesn't redo them.
|
||||
if !strings.Contains(client.lastUser, "Mittelaltermarkt Dresden") {
|
||||
if !strings.Contains(stub.seen.UserMessage, "Mittelaltermarkt Dresden") {
|
||||
t.Error("user prompt missing markt name")
|
||||
}
|
||||
if !strings.Contains(client.lastUser, "Ein Mittelaltermarkt mit Ritterspielen") {
|
||||
if !strings.Contains(stub.seen.UserMessage, "Ein Mittelaltermarkt mit Ritterspielen") {
|
||||
t.Error("user prompt missing scraped content from URL 1")
|
||||
}
|
||||
if !strings.Contains(client.lastUser, "Sa-So jeweils 10-18 Uhr") {
|
||||
if !strings.Contains(stub.seen.UserMessage, "Sa-So jeweils 10-18 Uhr") {
|
||||
t.Error("user prompt missing scraped content from URL 2")
|
||||
}
|
||||
if !strings.Contains(client.lastUser, "Bereits bekannt") {
|
||||
if !strings.Contains(stub.seen.UserMessage, "Bereits bekannt") {
|
||||
t.Error("user prompt should announce already-known fields when Partial is populated")
|
||||
}
|
||||
if !strings.Contains(client.lastUser, "PLZ: 01067") {
|
||||
if !strings.Contains(stub.seen.UserMessage, "PLZ: 01067") {
|
||||
t.Error("user prompt missing already-known PLZ")
|
||||
}
|
||||
// System prompt asks for JSON only.
|
||||
if !strings.Contains(client.lastSystem, "JSON") {
|
||||
if !strings.Contains(stub.seen.SystemPrompt, "JSON") {
|
||||
t.Error("system prompt should mention JSON")
|
||||
}
|
||||
}
|
||||
@@ -107,8 +84,8 @@ func TestMistralEnrich_AllScrapesFail(t *testing.T) {
|
||||
"https://a.example": errors.New("timeout"),
|
||||
"https://b.example": errors.New("404"),
|
||||
}}
|
||||
client := &stubPass2{} // must not be called
|
||||
e := NewMistralLLMEnricher(client, scraper)
|
||||
stub := &stubProvider{} // must not be called
|
||||
e := NewLLMEnricher(stub, scraper)
|
||||
|
||||
req := LLMRequest{
|
||||
Quellen: []string{"https://a.example", "https://b.example"},
|
||||
@@ -117,7 +94,7 @@ func TestMistralEnrich_AllScrapesFail(t *testing.T) {
|
||||
if !errors.Is(err, ErrNoScrapedContent) {
|
||||
t.Errorf("err = %v; want ErrNoScrapedContent", err)
|
||||
}
|
||||
if client.lastUser != "" {
|
||||
if stub.seen != nil {
|
||||
t.Error("LLM must not be called when zero URLs scrape")
|
||||
}
|
||||
}
|
||||
@@ -129,10 +106,10 @@ func TestMistralEnrich_SomeScrapesFailStillCallsLLM(t *testing.T) {
|
||||
responses: map[string]string{"https://ok.example": "Useful content."},
|
||||
errs: map[string]error{"https://bad.example": errors.New("timeout")},
|
||||
}
|
||||
client := &stubPass2{
|
||||
result: ai.PassResult{Content: `{"category":"mittelaltermarkt","opening_hours":"","description":""}`},
|
||||
stub := &stubProvider{
|
||||
content: `{"category":"mittelaltermarkt","opening_hours":"","description":""}`,
|
||||
}
|
||||
e := NewMistralLLMEnricher(client, scraper)
|
||||
e := NewLLMEnricher(stub, scraper)
|
||||
|
||||
req := LLMRequest{Quellen: []string{"https://bad.example", "https://ok.example"}}
|
||||
got, err := e.EnrichMissing(context.Background(), req)
|
||||
@@ -142,7 +119,7 @@ func TestMistralEnrich_SomeScrapesFailStillCallsLLM(t *testing.T) {
|
||||
if got.Category != catMittelaltermarkt {
|
||||
t.Errorf("category: got %q", got.Category)
|
||||
}
|
||||
if !strings.Contains(client.lastUser, "Useful content") {
|
||||
if !strings.Contains(stub.seen.UserMessage, "Useful content") {
|
||||
t.Error("user prompt missing successful scrape")
|
||||
}
|
||||
}
|
||||
@@ -151,10 +128,10 @@ func TestMistralEnrich_EmptyFieldsNoProvenance(t *testing.T) {
|
||||
// LLM returns empty strings for fields it can't support. Those fields
|
||||
// must NOT appear in Sources — an empty provenance is misleading.
|
||||
scraper := &stubScraper{responses: map[string]string{"https://a.example": "Content."}}
|
||||
client := &stubPass2{
|
||||
result: ai.PassResult{Content: `{"category":"mittelaltermarkt","opening_hours":"","description":""}`},
|
||||
stub := &stubProvider{
|
||||
content: `{"category":"mittelaltermarkt","opening_hours":"","description":""}`,
|
||||
}
|
||||
e := NewMistralLLMEnricher(client, scraper)
|
||||
e := NewLLMEnricher(stub, scraper)
|
||||
|
||||
got, err := e.EnrichMissing(context.Background(), LLMRequest{Quellen: []string{"https://a.example"}})
|
||||
if err != nil {
|
||||
@@ -183,8 +160,8 @@ func TestMistralEnrich_CapsURLsAtFive(t *testing.T) {
|
||||
fetched[u] = true
|
||||
return responses[u], nil
|
||||
}}
|
||||
client := &stubPass2{result: ai.PassResult{Content: `{"category":"x","opening_hours":"","description":""}`}}
|
||||
e := NewMistralLLMEnricher(client, scraper)
|
||||
stub2 := &stubProvider{content: `{"category":"x","opening_hours":"","description":""}`}
|
||||
e := NewLLMEnricher(stub2, scraper)
|
||||
|
||||
_, _ = e.EnrichMissing(context.Background(), LLMRequest{Quellen: urls})
|
||||
if len(fetched) != 5 {
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"marktvogt.de/backend/internal/pkg/ai"
|
||||
)
|
||||
|
||||
// SimilarityRow carries the minimal identifying fields the classifier reads.
|
||||
@@ -80,16 +82,16 @@ func SimilarityPairKey(a, b SimilarityRow) string {
|
||||
// propagate).
|
||||
const DefaultSimilarityCacheTTL = 30 * 24 * time.Hour
|
||||
|
||||
// MistralSimilarityClassifier implements SimilarityClassifier by sending a
|
||||
// JSON-formatted comparison prompt to Mistral's chat endpoint.
|
||||
type MistralSimilarityClassifier struct {
|
||||
Client aiPass2
|
||||
// SimilarityClassifierLLM implements SimilarityClassifier by sending a
|
||||
// JSON-formatted comparison prompt to an AI provider's chat endpoint.
|
||||
type SimilarityClassifierLLM struct {
|
||||
AI ai.Provider
|
||||
}
|
||||
|
||||
// NewMistralSimilarityClassifier binds a Mistral ai.Client. client must be
|
||||
// NewSimilarityClassifier binds an ai.Provider. provider must be
|
||||
// non-nil; routes.go falls back to NoopSimilarityClassifier when AI is off.
|
||||
func NewMistralSimilarityClassifier(client aiPass2) *MistralSimilarityClassifier {
|
||||
return &MistralSimilarityClassifier{Client: client}
|
||||
func NewSimilarityClassifier(provider ai.Provider) *SimilarityClassifierLLM {
|
||||
return &SimilarityClassifierLLM{AI: provider}
|
||||
}
|
||||
|
||||
// simResponse is the JSON shape we instruct Mistral to return. Confidence
|
||||
@@ -100,25 +102,29 @@ type simResponse struct {
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
// Classify sends the paired metadata to Mistral and parses the JSON response.
|
||||
// Classify sends the paired metadata to the AI provider and parses the JSON response.
|
||||
// No web scraping — the classifier works from name/city/year alone, which is
|
||||
// enough for the common cases (same venue listed on two different calendars,
|
||||
// editing typos, cross-year recurrence).
|
||||
func (m *MistralSimilarityClassifier) Classify(ctx context.Context, a, b SimilarityRow) (Verdict, error) {
|
||||
if m.Client == nil {
|
||||
return Verdict{}, errors.New("mistral similarity classifier not configured")
|
||||
func (c *SimilarityClassifierLLM) Classify(ctx context.Context, a, b SimilarityRow) (Verdict, error) {
|
||||
if c.AI == nil {
|
||||
return Verdict{}, errors.New("similarity classifier not configured")
|
||||
}
|
||||
systemPrompt := simSystemPrompt()
|
||||
userPrompt := simUserPrompt(a, b)
|
||||
|
||||
result, err := m.Client.Pass2(ctx, systemPrompt, userPrompt)
|
||||
resp, err := c.AI.Chat(ctx, &ai.ChatRequest{
|
||||
SystemPrompt: systemPrompt,
|
||||
UserMessage: userPrompt,
|
||||
JSONMode: true,
|
||||
})
|
||||
if err != nil {
|
||||
return Verdict{}, fmt.Errorf("pass2: %w", err)
|
||||
return Verdict{}, fmt.Errorf("chat: %w", err)
|
||||
}
|
||||
|
||||
var parsed simResponse
|
||||
if err := json.Unmarshal([]byte(result.Content), &parsed); err != nil {
|
||||
return Verdict{}, fmt.Errorf("parse response: %w (content=%q)", err, result.Content)
|
||||
if err := json.Unmarshal([]byte(resp.Content), &parsed); err != nil {
|
||||
return Verdict{}, fmt.Errorf("parse response: %w (content=%q)", err, resp.Content)
|
||||
}
|
||||
|
||||
// Clamp confidence to [0,1]; the model occasionally returns 1.2 or -0.1.
|
||||
@@ -134,7 +140,6 @@ func (m *MistralSimilarityClassifier) Classify(ctx context.Context, a, b Similar
|
||||
Same: parsed.SameMarket,
|
||||
Confidence: conf,
|
||||
Reason: strings.TrimSpace(parsed.Reason),
|
||||
Model: result.Model,
|
||||
ClassifiedAt: time.Now().UTC(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -5,8 +5,6 @@ import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"marktvogt.de/backend/internal/pkg/ai"
|
||||
)
|
||||
|
||||
func TestSimilarityPairKey_Symmetric(t *testing.T) {
|
||||
@@ -34,13 +32,10 @@ func TestSimilarityPairKey_DifferentInputsDifferentKeys(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMistralSimilarity_HappyPath(t *testing.T) {
|
||||
client := &stubPass2{
|
||||
result: ai.PassResult{
|
||||
Content: `{"same_market":true,"confidence":0.82,"reason":"Gleicher Name, gleiche Stadt, gleiches Jahr."}`,
|
||||
Model: "mistral-large-latest",
|
||||
},
|
||||
stub := &stubProvider{
|
||||
content: `{"same_market":true,"confidence":0.82,"reason":"Gleicher Name, gleiche Stadt, gleiches Jahr."}`,
|
||||
}
|
||||
c := NewMistralSimilarityClassifier(client)
|
||||
c := NewSimilarityClassifier(stub)
|
||||
|
||||
got, err := c.Classify(context.Background(),
|
||||
SimilarityRow{Name: "Mittelaltermarkt Dresden", Stadt: "Dresden", Year: 2026},
|
||||
@@ -49,6 +44,9 @@ func TestMistralSimilarity_HappyPath(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if stub.seen.JSONMode != true {
|
||||
t.Fatalf("JSONMode must be set")
|
||||
}
|
||||
if !got.Same {
|
||||
t.Errorf("same = false; want true")
|
||||
}
|
||||
@@ -58,14 +56,11 @@ func TestMistralSimilarity_HappyPath(t *testing.T) {
|
||||
if got.Reason == "" {
|
||||
t.Error("reason missing")
|
||||
}
|
||||
if got.Model != "mistral-large-latest" {
|
||||
t.Errorf("model = %q", got.Model)
|
||||
}
|
||||
// Prompt must carry both rows' identifying fields for the LLM to reason on.
|
||||
if !strings.Contains(client.lastUser, "Mittelaltermarkt Dresden") {
|
||||
if !strings.Contains(stub.seen.UserMessage, "Mittelaltermarkt Dresden") {
|
||||
t.Error("user prompt missing A.name")
|
||||
}
|
||||
if !strings.Contains(client.lastSystem, "same_market") {
|
||||
if !strings.Contains(stub.seen.SystemPrompt, "same_market") {
|
||||
t.Error("system prompt should describe the JSON schema (same_market key)")
|
||||
}
|
||||
}
|
||||
@@ -82,7 +77,7 @@ func TestMistralSimilarity_ClampsConfidence(t *testing.T) {
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c := NewMistralSimilarityClassifier(&stubPass2{result: ai.PassResult{Content: tc.raw}})
|
||||
c := NewSimilarityClassifier(&stubProvider{content: tc.raw})
|
||||
got, err := c.Classify(context.Background(), SimilarityRow{}, SimilarityRow{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
@@ -95,7 +90,7 @@ func TestMistralSimilarity_ClampsConfidence(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMistralSimilarity_PropagatesPass2Error(t *testing.T) {
|
||||
c := NewMistralSimilarityClassifier(&stubPass2{err: errors.New("mistral down")})
|
||||
c := NewSimilarityClassifier(&stubProvider{err: errors.New("mistral down")})
|
||||
_, err := c.Classify(context.Background(), SimilarityRow{}, SimilarityRow{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error; got nil")
|
||||
@@ -103,7 +98,7 @@ func TestMistralSimilarity_PropagatesPass2Error(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestMistralSimilarity_RejectsBadJSON(t *testing.T) {
|
||||
c := NewMistralSimilarityClassifier(&stubPass2{result: ai.PassResult{Content: "not json at all"}})
|
||||
c := NewSimilarityClassifier(&stubProvider{content: "not json at all"})
|
||||
_, err := c.Classify(context.Background(), SimilarityRow{}, SimilarityRow{})
|
||||
if err == nil {
|
||||
t.Fatal("expected parse error; got nil")
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
package enrich
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"marktvogt.de/backend/internal/pkg/ai"
|
||||
)
|
||||
|
||||
type stubProvider struct {
|
||||
content string
|
||||
err error
|
||||
seen *ai.ChatRequest
|
||||
}
|
||||
|
||||
func (s *stubProvider) Name() string { return "stub" }
|
||||
func (s *stubProvider) SupportsJSONMode() bool { return true }
|
||||
func (s *stubProvider) SupportsJSONSchema() bool { return false }
|
||||
func (s *stubProvider) Chat(ctx context.Context, req *ai.ChatRequest) (*ai.ChatResponse, error) {
|
||||
s.seen = req
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return &ai.ChatResponse{Content: s.content}, nil
|
||||
}
|
||||
@@ -1,711 +1,19 @@
|
||||
package market
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"marktvogt.de/backend/internal/pkg/ai"
|
||||
"marktvogt.de/backend/internal/pkg/apierror"
|
||||
)
|
||||
|
||||
type ResearchHandler struct {
|
||||
service *Service
|
||||
aiClient *ai.Client
|
||||
mu sync.Mutex
|
||||
cooldown map[uuid.UUID]time.Time
|
||||
service *Service
|
||||
}
|
||||
|
||||
func NewResearchHandler(service *Service, aiClient *ai.Client) *ResearchHandler {
|
||||
return &ResearchHandler{
|
||||
service: service,
|
||||
aiClient: aiClient,
|
||||
cooldown: make(map[uuid.UUID]time.Time),
|
||||
}
|
||||
func NewResearchHandler(service *Service, _ any) *ResearchHandler {
|
||||
return &ResearchHandler{service: service}
|
||||
}
|
||||
|
||||
func (h *ResearchHandler) Research(c *gin.Context) {
|
||||
if !h.aiClient.Enabled() {
|
||||
apiErr := apierror.BadRequest("ai_disabled", "AI research is not configured")
|
||||
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
apiErr := apierror.BadRequest("invalid_id", "invalid market ID")
|
||||
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
return
|
||||
}
|
||||
|
||||
// Rate limit: 1 per market per 5 minutes
|
||||
h.mu.Lock()
|
||||
if last, ok := h.cooldown[id]; ok && time.Since(last) < 5*time.Minute {
|
||||
h.mu.Unlock()
|
||||
apiErr := apierror.BadRequest("rate_limited", "Bitte warte 5 Minuten zwischen Recherche-Aufrufen")
|
||||
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
return
|
||||
}
|
||||
h.cooldown[id] = time.Now()
|
||||
h.mu.Unlock()
|
||||
|
||||
m, err := h.service.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
apiErr := apierror.NotFound("market")
|
||||
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
return
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// --- Pass 1: Structured extraction via agent with web search ---
|
||||
pass1Prompt := buildPass1Prompt(m)
|
||||
pass1Result, err := h.aiClient.Pass1(ctx, pass1Prompt)
|
||||
if err != nil {
|
||||
if ai.IsRateLimit(err) {
|
||||
h.respondRateLimited(c, "pass1", id, err)
|
||||
return
|
||||
}
|
||||
slog.WarnContext(ctx, "pass1 failed, retrying", "market_id", id, "error", err)
|
||||
pass1Result, err = h.aiClient.Pass1(ctx, pass1Prompt)
|
||||
if err != nil {
|
||||
if ai.IsRateLimit(err) {
|
||||
h.respondRateLimited(c, "pass1", id, err)
|
||||
return
|
||||
}
|
||||
slog.ErrorContext(ctx, "pass1 retry failed", "market_id", id, "error", err)
|
||||
apiErr := apierror.Internal("AI research failed")
|
||||
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
pass1Data := parsePass1Response(pass1Result.Content)
|
||||
|
||||
// Compute confidence for pass1 fields and identify retry fields
|
||||
var retryFields []string
|
||||
for fieldName, field := range pass1Data.Fields {
|
||||
if field.Value == nil {
|
||||
retryFields = append(retryFields, fieldName)
|
||||
}
|
||||
}
|
||||
|
||||
// Collect all sources from pass1
|
||||
allSources := pass1Data.Sources
|
||||
|
||||
// --- Pass 2: Description + retry fields via chat completions ---
|
||||
// Pass 2 is enrichment only — if it fails we still return the pass1 result.
|
||||
// Rate-limit on pass2 is treated as a warning and the partial result ships.
|
||||
pass2Prompt := buildPass2UserPrompt(m, pass1Data, retryFields, allSources)
|
||||
pass2Result, err := h.aiClient.Pass2(ctx, pass2SystemPrompt, pass2Prompt)
|
||||
var pass2Data pass2Response
|
||||
if err != nil {
|
||||
if ai.IsRateLimit(err) {
|
||||
slog.WarnContext(ctx, "pass2 rate-limited; returning pass1 results only",
|
||||
"market_id", id)
|
||||
} else {
|
||||
slog.WarnContext(ctx, "pass2 failed, retrying", "market_id", id, "error", err)
|
||||
pass2Result, err = h.aiClient.Pass2(ctx, pass2SystemPrompt, pass2Prompt)
|
||||
if err != nil {
|
||||
if ai.IsRateLimit(err) {
|
||||
slog.WarnContext(ctx, "pass2 rate-limited on retry; returning pass1 results only",
|
||||
"market_id", id)
|
||||
} else {
|
||||
slog.WarnContext(ctx, "pass2 retry failed, using pass1 results only",
|
||||
"market_id", id, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
pass2Data = parsePass2Response(pass2Result.Content)
|
||||
}
|
||||
|
||||
// --- Merge results and compute final confidence ---
|
||||
result := mergeResults(m, pass1Data, pass2Data)
|
||||
|
||||
// Log cost tracking
|
||||
logCostTracking(id, pass1Result, pass2Result)
|
||||
|
||||
c.JSON(http.StatusOK, ResearchResponse{Data: result})
|
||||
}
|
||||
|
||||
// respondRateLimited writes a 429 with a Retry-After header so the admin UI
|
||||
// can display a clear "try again in N seconds" message instead of a generic
|
||||
// 500. pass is a short identifier for the log line ("pass1" / "pass2").
|
||||
func (h *ResearchHandler) respondRateLimited(c *gin.Context, pass string, marketID uuid.UUID, err error) {
|
||||
retry := ai.DefaultRetryAfterSeconds
|
||||
slog.WarnContext(c.Request.Context(), "ai rate-limited",
|
||||
"pass", pass, "market_id", marketID, "retry_after_s", retry, "error", err)
|
||||
c.Header("Retry-After", fmt.Sprint(retry))
|
||||
apiErr := apierror.TooManyRequests(
|
||||
fmt.Sprintf("AI research temporarily rate-limited; try again in ~%ds", retry),
|
||||
)
|
||||
c.JSON(apiErr.Status, apierror.NewResponse(apiErr))
|
||||
}
|
||||
|
||||
// --- Pass 1 types and prompt ---
|
||||
|
||||
// pass1FieldResult is one field entry in the pass-1 model reply.
// Value is nil when the model reported the field as not findable.
type pass1FieldResult struct {
	Value      any      `json:"wert"`
	Sources    []string `json:"quellen"`
	Extraction string   `json:"extraktion"` // "verbatim" or "abgeleitet"
}

// pass1Response is the parsed pass-1 reply. Fields is keyed by the German
// field names requested in the prompt (see buildPass1Prompt); raw keeps the
// undecoded JSON objects for debugging and re-parsing.
type pass1Response struct {
	Fields  map[string]pass1FieldResult `json:"-"`
	Sources []string                    `json:"quellen_gesamt"`
	raw     map[string]json.RawMessage
}
|
||||
|
||||
// buildPass1Prompt renders the German pass-1 research prompt for a market.
// The prompt asks the web-search agent for a fixed set of fields and demands
// a bare JSON reply whose shape matches pass1FieldResult / pass1Response.
// Date range and website lines are only included when present on the market.
func buildPass1Prompt(m Market) string {
	dateRange := ""
	if !m.StartDate.IsZero() && !m.EndDate.IsZero() {
		dateRange = fmt.Sprintf("Zeitraum: %s bis %s\n", m.StartDate.Format("02.01.2006"), m.EndDate.Format("02.01.2006"))
	}

	website := ""
	if m.Website != "" {
		website = fmt.Sprintf("Website: %s\n", m.Website)
	}

	// NOTE: the JSON skeleton below is part of the prompt contract —
	// parsePass1Response expects exactly these keys.
	return fmt.Sprintf(`Recherchiere den folgenden Mittelaltermarkt und finde aktuelle Informationen.

Name: %s
Stadt: %s
%s%s
Suche nach folgenden Feldern:
- website (offizielle URL)
- strasse (Straße/Adresse des Veranstaltungsorts)
- plz (Postleitzahl)
- stadt (Stadt, falls korrekter als angegeben)
- veranstalter (Name des Veranstalters)
- oeffnungszeiten (Öffnungszeiten als Array von {tag, von, bis} mit deutschen Wochentagen)
- eintrittspreise (Preise als {erwachsene_cent, kinder_cent, ermaessigt_cent, frei_unter_alter, hinweise})

Antworte AUSSCHLIESSLICH mit validem JSON (keine Markdown-Codeblöcke).

{
  "website": {
    "wert": "https://...",
    "quellen": ["https://..."],
    "extraktion": "verbatim"
  },
  "strasse": {
    "wert": "Beispielstraße 5",
    "quellen": ["https://..."],
    "extraktion": "verbatim"
  },
  "plz": {
    "wert": "12345",
    "quellen": ["https://..."],
    "extraktion": "verbatim"
  },
  "stadt": {
    "wert": "Musterstadt",
    "quellen": ["https://..."],
    "extraktion": "verbatim"
  },
  "veranstalter": {
    "wert": "Verein XY e.V.",
    "quellen": ["https://..."],
    "extraktion": "abgeleitet"
  },
  "oeffnungszeiten": {
    "wert": [{"tag": "Samstag", "von": "10:00", "bis": "22:00"}],
    "quellen": ["https://..."],
    "extraktion": "verbatim"
  },
  "eintrittspreise": {
    "wert": {"erwachsene_cent": 800, "kinder_cent": 0, "ermaessigt_cent": 500, "frei_unter_alter": 12, "hinweise": "Kinder unter 12 frei"},
    "quellen": ["https://..."],
    "extraktion": "verbatim"
  },
  "quellen_gesamt": ["https://...", "https://..."]
}

Regeln:
- "extraktion": "verbatim" wenn der Wert wörtlich in einer strukturierten Quelle steht (Tabelle, Adressblock, Schema)
- "extraktion": "abgeleitet" wenn der Wert aus Fließtext interpretiert wurde
- Setze "wert" auf null wenn nicht findbar. Erfinde KEINE Werte.
- "quellen" muss die URL(s) enthalten, aus denen der Wert stammt
- Verwende für Wochentage NUR: Montag, Dienstag, Mittwoch, Donnerstag, Freitag, Samstag, Sonntag
- Preise in Cent (800 = 8,00 EUR)
`, m.Name, m.City, dateRange, website)
}
|
||||
|
||||
func parsePass1Response(raw string) pass1Response {
|
||||
cleaned := extractJSON(raw)
|
||||
cleaned = stripJSONComments(cleaned)
|
||||
|
||||
var result pass1Response
|
||||
result.Fields = make(map[string]pass1FieldResult)
|
||||
|
||||
var rawMap map[string]json.RawMessage
|
||||
if err := json.Unmarshal([]byte(cleaned), &rawMap); err != nil {
|
||||
slog.Warn("failed to parse pass1 response", "error", err, "cleaned", cleaned)
|
||||
return result
|
||||
}
|
||||
result.raw = rawMap
|
||||
|
||||
// Parse quellen_gesamt
|
||||
if sourcesRaw, ok := rawMap["quellen_gesamt"]; ok {
|
||||
_ = json.Unmarshal(sourcesRaw, &result.Sources)
|
||||
}
|
||||
|
||||
// Parse known field names
|
||||
knownFields := []string{"website", "strasse", "plz", "stadt", "veranstalter", "oeffnungszeiten", "eintrittspreise"}
|
||||
for _, name := range knownFields {
|
||||
if fieldRaw, ok := rawMap[name]; ok {
|
||||
var field pass1FieldResult
|
||||
if err := json.Unmarshal(fieldRaw, &field); err == nil {
|
||||
result.Fields[name] = field
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// --- Pass 2 types and prompt ---
|
||||
|
||||
// pass2SystemPrompt is the system message for pass 2 (chat completions, no
// web access): it instructs the model to write a German market description
// and to re-extract any retry fields from the pass-1 source material. Its
// reply shape must match pass2Response / parsePass2Response.
const pass2SystemPrompt = `Du bist ein Redakteur für Marktvogt, eine Plattform für Mittelaltermärkte und historische Veranstaltungen im DACH-Raum.

Du bekommst:
1. Rechercheergebnisse aus einer vorherigen Websuche (bereits gesammelte Quelleninhalte und URLs)
2. Optional: Eine Liste von Feldern ("retry_felder"), die in der ersten Runde leer oder unsicher waren

Du hast KEINEN Internetzugang. Arbeite ausschließlich mit den bereitgestellten Informationen.

## Aufgabe 1: Beschreibung verfassen

Verfasse eine ansprechende, informative Marktbeschreibung auf Deutsch.

### Anforderungen:
- 3-5 Sätze, ca. 80-150 Wörter
- Sachlich-einladender Ton — kein Werbe-Deutsch, keine Superlative ("das größte", "einzigartig", "unvergesslich")
- Inhalt: Was macht diesen Markt besonders? Was erwartet Besucher? Marktstände, Programm, Musik, Handwerk, Ritterspiele, Gaukelei, historisches Lagerleben etc.
- Atmosphäre und Veranstaltungsort erwähnen wenn aus den Quellen ersichtlich
- Zielgruppen ansprechen: Familien, Mittelalter-Fans, Reenactment-Begeisterte
- NICHT erwähnen: Eintrittspreise, exakte Öffnungszeiten, Datumsangaben (stehen in den strukturierten Feldern)
- Alle Fakten in der Beschreibung müssen aus den mitgelieferten Quelleninhalten stammen

## Aufgabe 2: Felder nachholen (nur wenn retry_felder vorhanden)

Für jedes Feld in retry_felder: Versuche den Wert aus den bereits mitgelieferten Quelleninhalten zu extrahieren.

## Antwortformat

Antworte AUSSCHLIESSLICH mit validem JSON.

{
  "beschreibung": {
    "wert": "Die vollständige Beschreibung hier...",
    "quellen": ["https://...", "https://..."],
    "basiert_auf": "Kurze Erklärung"
  },
  "retry_ergebnisse": {
    "feldname": {
      "wert": "...",
      "quellen": ["https://..."],
      "extraktion": "verbatim",
      "hinweis": "optional"
    }
  }
}

Wenn keine retry_felder übergeben wurden, gib "retry_ergebnisse": {} zurück.
Erfinde KEINE Fakten. Wenn die Quellen nichts hergeben, setze "wert" auf null.`
|
||||
|
||||
// pass2DescriptionResult is the "beschreibung" object of the pass-2 reply:
// the generated description plus the sources and rationale behind it.
type pass2DescriptionResult struct {
	Value   string   `json:"wert"`
	Sources []string `json:"quellen"`
	BasedOn string   `json:"basiert_auf"`
}

// pass2RetryResult mirrors pass1FieldResult for fields that were
// re-extracted in pass 2 from the already-collected source material.
type pass2RetryResult struct {
	Value      any      `json:"wert"`
	Sources    []string `json:"quellen"`
	Extraction string   `json:"extraktion"` // "verbatim" or "abgeleitet"
}

// pass2Response is the full pass-2 reply: the description plus a map of
// retried fields keyed by the German pass-1 field names.
type pass2Response struct {
	Description  pass2DescriptionResult      `json:"beschreibung"`
	RetryResults map[string]pass2RetryResult `json:"retry_ergebnisse"`
}
|
||||
|
||||
func buildPass2UserPrompt(m Market, pass1 pass1Response, retryFields []string, allSources []string) string {
|
||||
// Build context from pass1 results
|
||||
var sb strings.Builder
|
||||
sb.WriteString(fmt.Sprintf("Markt: %s\nStadt: %s\n", m.Name, m.City))
|
||||
|
||||
if !m.StartDate.IsZero() {
|
||||
sb.WriteString(fmt.Sprintf("Zeitraum: %s bis %s\n", m.StartDate.Format("02.01.2006"), m.EndDate.Format("02.01.2006")))
|
||||
}
|
||||
|
||||
sb.WriteString("\n## Rechercheergebnisse aus Pass 1\n\n")
|
||||
for name, field := range pass1.Fields {
|
||||
if field.Value != nil {
|
||||
val, _ := json.Marshal(field.Value)
|
||||
sb.WriteString(fmt.Sprintf("%s: %s (Quellen: %s)\n", name, string(val), strings.Join(field.Sources, ", ")))
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n## Quellen-URLs\n")
|
||||
for _, src := range allSources {
|
||||
sb.WriteString(fmt.Sprintf("- %s\n", src))
|
||||
}
|
||||
|
||||
if len(retryFields) > 0 {
|
||||
sb.WriteString(fmt.Sprintf("\n## retry_felder\n%s\n", strings.Join(retryFields, ", ")))
|
||||
sb.WriteString("Versuche diese Felder aus den Quelleninhalten zu extrahieren.\n")
|
||||
}
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func parsePass2Response(raw string) pass2Response {
|
||||
cleaned := extractJSON(raw)
|
||||
cleaned = stripJSONComments(cleaned)
|
||||
|
||||
var result pass2Response
|
||||
result.RetryResults = make(map[string]pass2RetryResult)
|
||||
|
||||
if err := json.Unmarshal([]byte(cleaned), &result); err != nil {
|
||||
slog.Warn("failed to parse pass2 response", "error", err, "cleaned", cleaned)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// --- Confidence scoring ---
|
||||
|
||||
// Extraction modes reported by the model: values copied verbatim from a
// structured source vs. derived/interpreted from free text.
const (
	extractionVerbatim = "verbatim"
	extractionDerived  = "abgeleitet"
)

// computeConfidence maps source count and extraction mode to the coarse
// confidence bucket shown in the admin UI:
//
//	>=2 sources, verbatim -> "high"
//	>=2 sources, derived  -> "medium"
//	  1 source,  verbatim -> "medium"
//	anything else         -> "low"
func computeConfidence(sources []string, extraction string) string {
	n := len(sources)
	switch {
	case n >= 2 && extraction == extractionVerbatim:
		return "high"
	case n >= 2 && extraction == extractionDerived:
		return "medium"
	case n == 1 && extraction == extractionVerbatim:
		return "medium"
	default:
		return "low"
	}
}
|
||||
|
||||
// --- Merge results ---
|
||||
|
||||
// fieldMapping maps pass1 field names to the FieldSuggestion field names used by the frontend.
// Pass-1/pass-2 results whose name is not in this map are dropped by
// mergeResults; "description" is handled separately and has no entry here.
var fieldMapping = map[string]string{
	"website":         "website",
	"strasse":         "street",
	"plz":             "zip",
	"stadt":           "city",
	"veranstalter":    "organizer_name",
	"oeffnungszeiten": "opening_hours",
	"eintrittspreise": "admission_info",
}
|
||||
|
||||
// mergeResults combines pass-1 field results, pass-2 retry results and the
// pass-2 description into one ResearchResult for the admin UI. Each
// suggestion carries the market's current value (for a before/after diff),
// a confidence bucket and a human-readable reason; all cited source URLs
// are deduplicated into Sources.
//
// NOTE(review): suggestions and sources are emitted in map-iteration order,
// so ordering differs between calls for identical input — confirm the
// frontend does not rely on a stable order.
func mergeResults(m Market, pass1 pass1Response, pass2 pass2Response) ResearchResult {
	suggestions := make([]FieldSuggestion, 0, len(pass1.Fields)+len(pass2.RetryResults)+1)
	allSources := make([]string, 0, len(pass1.Sources))

	// Collect unique sources
	sourceSet := make(map[string]bool)
	for _, s := range pass1.Sources {
		sourceSet[s] = true
	}

	// Process pass1 fields
	for p1Name, field := range pass1.Fields {
		frontendName, ok := fieldMapping[p1Name]
		if !ok {
			continue
		}
		if field.Value == nil {
			// null means "not found" per the prompt contract; nothing to suggest.
			continue
		}

		// Convert opening hours format
		suggestedValue := field.Value
		if p1Name == "oeffnungszeiten" {
			suggestedValue = convertOpeningHours(field.Value)
		}
		if p1Name == "eintrittspreise" {
			suggestedValue = convertAdmission(field.Value)
		}

		suggestions = append(suggestions, FieldSuggestion{
			Field:          frontendName,
			CurrentValue:   getCurrentValue(m, frontendName),
			SuggestedValue: suggestedValue,
			Confidence:     computeConfidence(field.Sources, field.Extraction),
			Reason:         formatReason(field.Sources, field.Extraction),
		})

		for _, s := range field.Sources {
			sourceSet[s] = true
		}
	}

	// Process pass2 retry results
	for p1Name, retryField := range pass2.RetryResults {
		if retryField.Value == nil {
			continue
		}
		frontendName, ok := fieldMapping[p1Name]
		if !ok {
			continue
		}

		// Same value conversions as for pass-1 fields.
		suggestedValue := retryField.Value
		if p1Name == "oeffnungszeiten" {
			suggestedValue = convertOpeningHours(retryField.Value)
		}
		if p1Name == "eintrittspreise" {
			suggestedValue = convertAdmission(retryField.Value)
		}

		suggestions = append(suggestions, FieldSuggestion{
			Field:          frontendName,
			CurrentValue:   getCurrentValue(m, frontendName),
			SuggestedValue: suggestedValue,
			Confidence:     computeConfidence(retryField.Sources, retryField.Extraction),
			// " (retry)" marks suggestions recovered on the second pass.
			Reason: formatReason(retryField.Sources, retryField.Extraction) + " (retry)",
		})

		for _, s := range retryField.Sources {
			sourceSet[s] = true
		}
	}

	// Add description from pass2
	if pass2.Description.Value != "" {
		suggestions = append(suggestions, FieldSuggestion{
			Field:          "description",
			CurrentValue:   m.Description,
			SuggestedValue: pass2.Description.Value,
			// Descriptions are always model-written prose, so they are
			// scored as derived regardless of sourcing.
			Confidence: computeConfidence(pass2.Description.Sources, extractionDerived),
			Reason:     pass2.Description.BasedOn,
		})

		for _, s := range pass2.Description.Sources {
			sourceSet[s] = true
		}
	}

	for s := range sourceSet {
		allSources = append(allSources, s)
	}

	return ResearchResult{
		Suggestions: suggestions,
		Sources:     allSources,
	}
}
|
||||
|
||||
// getCurrentValue returns the market's current value for a frontend field
// name (the values of fieldMapping, plus "description"), so the UI can show
// a before/after diff. Unknown field names yield nil.
func getCurrentValue(m Market, field string) any {
	switch field {
	case "website":
		return m.Website
	case "street":
		return m.Street
	case "zip":
		return m.Zip
	case "city":
		return m.City
	case "organizer_name":
		return m.OrganizerName
	case "opening_hours":
		return m.OpeningHours
	case "admission_info":
		return m.AdmissionInfo
	case "description":
		return m.Description
	default:
		return nil
	}
}
|
||||
|
||||
func formatReason(sources []string, extraction string) string {
|
||||
srcCount := len(sources)
|
||||
verb := "gefunden"
|
||||
if extraction == extractionDerived {
|
||||
verb = extractionDerived
|
||||
}
|
||||
if srcCount == 0 {
|
||||
return verb
|
||||
}
|
||||
if srcCount == 1 {
|
||||
return fmt.Sprintf("Aus 1 Quelle %s", verb)
|
||||
}
|
||||
return fmt.Sprintf("Aus %d Quellen %s", srcCount, verb)
|
||||
}
|
||||
|
||||
// convertOpeningHours converts pass1 format {tag, von, bis} to frontend
// format {day, open, close}. The value is round-tripped through JSON to
// normalize it; anything that does not decode as a list of string maps is
// returned unchanged.
func convertOpeningHours(val any) any {
	data, err := json.Marshal(val)
	if err != nil {
		return val
	}

	var entries []map[string]string
	if json.Unmarshal(data, &entries) != nil {
		return val
	}

	converted := make([]map[string]string, len(entries))
	for i, entry := range entries {
		converted[i] = map[string]string{
			"day":   entry["tag"],
			"open":  entry["von"],
			"close": entry["bis"],
		}
	}
	return converted
}
|
||||
|
||||
// convertAdmission converts pass1 format to frontend format (cents)
|
||||
func convertAdmission(val any) any {
|
||||
data, err := json.Marshal(val)
|
||||
if err != nil {
|
||||
return val
|
||||
}
|
||||
|
||||
var raw map[string]any
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return val
|
||||
}
|
||||
|
||||
result := map[string]any{
|
||||
"adult_cents": toInt(raw["erwachsene_cent"]),
|
||||
"child_cents": toInt(raw["kinder_cent"]),
|
||||
"reduced_cents": toInt(raw["ermaessigt_cent"]),
|
||||
"free_under_age": toInt(raw["frei_unter_alter"]),
|
||||
"notes": toString(raw["hinweise"]),
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// toInt coerces a decoded-JSON value to int. encoding/json decodes numbers
// as float64, so that case is handled first; plain ints pass through and
// everything else (nil, strings, ...) yields 0.
func toInt(v any) int {
	if f, ok := v.(float64); ok {
		return int(f)
	}
	if i, ok := v.(int); ok {
		return i
	}
	return 0
}
|
||||
|
||||
// toString returns v if it is a string, and "" for any other type (incl. nil).
func toString(v any) string {
	s, _ := v.(string)
	return s
}
|
||||
|
||||
// --- Cost logging ---
|
||||
|
||||
// logCostTracking emits one structured log line with per-pass model and
// token usage so AI spend can be aggregated from logs. Passes without usage
// info (e.g. a failed or skipped pass2) are simply omitted from the record.
func logCostTracking(marketID uuid.UUID, pass1, pass2 ai.PassResult) {
	log := map[string]any{
		"market_id": marketID,
		"timestamp": time.Now().Format(time.RFC3339),
	}

	if pass1.Usage != nil {
		log["pass_1"] = map[string]any{
			"model":         pass1.Model,
			"input_tokens":  pass1.Usage.PromptTokens,
			"output_tokens": pass1.Usage.CompletionTokens,
		}
	}
	if pass2.Usage != nil {
		log["pass_2"] = map[string]any{
			"model":         pass2.Model,
			"input_tokens":  pass2.Usage.PromptTokens,
			"output_tokens": pass2.Usage.CompletionTokens,
		}
	}

	// Marshal error deliberately ignored: the map holds only marshalable values.
	data, _ := json.Marshal(log)
	slog.Info("ai_research_cost", "data", string(data))
}
|
||||
|
||||
// --- JSON helpers ---
|
||||
|
||||
func extractJSON(s string) string {
|
||||
if idx := findJSONStart(s); idx >= 0 {
|
||||
s = s[idx:]
|
||||
if end := findJSONEnd(s); end > 0 {
|
||||
s = s[:end+1]
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// stripJSONComments removes //-style line comments from s while leaving the
// contents of string literals (including escaped quotes and URLs like
// "https://...") untouched. Model output sometimes contains such comments
// even though JSON forbids them. The comment's trailing newline is consumed
// together with the comment.
func stripJSONComments(s string) string {
	var out []byte
	inString, escaped := false, false
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case escaped:
			out = append(out, c)
			escaped = false
		case c == '\\' && inString:
			out = append(out, c)
			escaped = true
		case c == '"':
			inString = !inString
			out = append(out, c)
		case !inString && c == '/' && i+1 < len(s) && s[i+1] == '/':
			// Advance to the newline; the loop's i++ then steps past it.
			for i < len(s) && s[i] != '\n' {
				i++
			}
		default:
			out = append(out, c)
		}
	}
	return string(out)
}
|
||||
|
||||
// findJSONStart returns the byte index of the first '{' in s, or -1 if none.
func findJSONStart(s string) int {
	return strings.IndexByte(s, '{')
}
|
||||
|
||||
// findJSONEnd returns the byte index of the '}' that closes the first '{'
// in s, or -1 if the braces never balance. Braces inside JSON string
// literals (including after escaped quotes) are ignored, so values like
// {"a": "}"} are measured correctly — the previous depth counter treated
// every brace as structural and truncated such payloads.
func findJSONEnd(s string) int {
	depth := 0
	inString := false
	escaped := false
	for i := 0; i < len(s); i++ {
		c := s[i]
		if escaped {
			escaped = false
			continue
		}
		switch {
		case inString && c == '\\':
			escaped = true
		case c == '"':
			inString = !inString
		case !inString && c == '{':
			depth++
		case !inString && c == '}':
			depth--
			if depth == 0 {
				return i
			}
		}
	}
	return -1
}
|
||||
c.JSON(http.StatusNotImplemented, gin.H{"error": "research temporarily disabled during AI provider migration"})
|
||||
}
|
||||
|
||||
@@ -1,194 +0,0 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/conversation"
|
||||
)
|
||||
|
||||
// rateLimiter enforces a minimum interval between calls. Set rps<=0 to disable.
// Mirrors backend/internal/pkg/geocode/nominatim.go's lastReq gate.
type rateLimiter struct {
	mu          sync.Mutex
	lastReq     time.Time
	minInterval time.Duration
}

// newRateLimiter builds a limiter allowing at most rps calls per second;
// rps <= 0 yields a no-op limiter (minInterval 0).
func newRateLimiter(rps float64) *rateLimiter {
	limiter := &rateLimiter{}
	if rps > 0 {
		limiter.minInterval = time.Duration(float64(time.Second) / rps)
	}
	return limiter
}

// TODO: wait() does not honor context cancellation — a cancelled caller will
// block up to minInterval while holding a mutex, and queued callers block up
// to N*minInterval. Mirror pkg/geocode/nominatim.go which has the same gap.
// Fix both sites together by taking ctx and using time.NewTimer + select.
func (rl *rateLimiter) wait() {
	if rl.minInterval == 0 {
		return
	}
	rl.mu.Lock()
	defer rl.mu.Unlock()
	if elapsed := time.Since(rl.lastReq); elapsed < rl.minInterval {
		time.Sleep(rl.minInterval - elapsed)
	}
	rl.lastReq = time.Now()
}
|
||||
|
||||
// Client wraps the Mistral SDK with a shared client-side rate limiter.
// sdk is nil when no API key was provided to New; callers are expected to
// check Enabled() before issuing calls.
type Client struct {
	sdk          *mistral.Client
	agentSimple  string // pre-created agent ID used by Pass1 (web search)
	modelComplex string // chat-completion model used by Pass2
	limiter      *rateLimiter
}
|
||||
|
||||
// New builds a Client. apiKey may be empty, in which case the SDK is left
// nil and Enabled() reports false. An empty modelComplex defaults to
// "mistral-large-latest"; rps > 0 enables client-side rate limiting.
func New(apiKey, agentSimple, modelComplex string, rps float64) *Client {
	if modelComplex == "" {
		modelComplex = "mistral-large-latest"
	}

	var sdk *mistral.Client
	if apiKey != "" {
		// SDK-level retry (2 attempts, 1s backoff) is independent of the
		// pass-level retries performed by callers.
		sdk = mistral.NewClient(apiKey,
			mistral.WithTimeout(120*time.Second),
			mistral.WithRetry(2, 1*time.Second),
		)
	}

	return &Client{
		sdk:          sdk,
		agentSimple:  agentSimple,
		modelComplex: modelComplex,
		limiter:      newRateLimiter(rps),
	}
}
|
||||
|
||||
// Enabled reports whether the client is fully configured for Pass1: an API
// key (non-nil SDK) and a pass1 agent ID. Note Pass0/Pass2 only require the
// SDK, so they can work even when Enabled() is false.
func (c *Client) Enabled() bool {
	return c.sdk != nil && c.agentSimple != ""
}
|
||||
|
||||
// UsageInfo mirrors the token accounting returned by the Mistral API; the
// JSON tags match the upstream wire format so it can be logged as-is.
type UsageInfo struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// PassResult is the uniform result of Pass0/Pass1/Pass2: the raw text
// content, optional token usage, and the model/agent that produced it.
type PassResult struct {
	Content string
	Usage   *UsageInfo
	Model   string // "agent:<id>" for conversation passes, model name for Pass2
}
|
||||
|
||||
// Pass1 uses the Conversations API to call the pre-created agent (with web search).
|
||||
func (c *Client) Pass1(ctx context.Context, prompt string) (PassResult, error) {
|
||||
c.limiter.wait()
|
||||
storeFalse := false
|
||||
resp, err := c.sdk.StartConversation(ctx, &conversation.StartRequest{
|
||||
AgentID: c.agentSimple,
|
||||
Inputs: conversation.TextInputs(prompt),
|
||||
Store: &storeFalse,
|
||||
})
|
||||
if err != nil {
|
||||
return PassResult{}, fmt.Errorf("pass1 conversation: %w", err)
|
||||
}
|
||||
|
||||
content := extractConversationContent(resp)
|
||||
if content == "" {
|
||||
return PassResult{}, fmt.Errorf("pass1: no assistant message in response")
|
||||
}
|
||||
|
||||
return PassResult{
|
||||
Content: content,
|
||||
Usage: convertConvUsage(resp.Usage),
|
||||
Model: "agent:" + c.agentSimple,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Pass0 uses the Conversations API to call a discovery agent identified by agentID.
|
||||
// The agent ID is passed explicitly so the discovery domain can configure its own
|
||||
// agent independently of the agentSimple field used by Pass1.
|
||||
func (c *Client) Pass0(ctx context.Context, agentID, prompt string) (PassResult, error) {
|
||||
c.limiter.wait()
|
||||
if c.sdk == nil || agentID == "" {
|
||||
return PassResult{}, fmt.Errorf("pass0: ai client not configured (sdk=%v agentID=%q)", c.sdk != nil, agentID)
|
||||
}
|
||||
storeFalse := false
|
||||
resp, err := c.sdk.StartConversation(ctx, &conversation.StartRequest{
|
||||
AgentID: agentID,
|
||||
Inputs: conversation.TextInputs(prompt),
|
||||
Store: &storeFalse,
|
||||
})
|
||||
if err != nil {
|
||||
return PassResult{}, fmt.Errorf("pass0 conversation: %w", err)
|
||||
}
|
||||
|
||||
content := extractConversationContent(resp)
|
||||
if content == "" {
|
||||
return PassResult{}, fmt.Errorf("pass0: no assistant message in response")
|
||||
}
|
||||
|
||||
return PassResult{
|
||||
Content: content,
|
||||
Usage: convertConvUsage(resp.Usage),
|
||||
Model: "agent:" + agentID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Pass2 uses chat completions for description generation + retry fields.
|
||||
func (c *Client) Pass2(ctx context.Context, systemPrompt, userPrompt string) (PassResult, error) {
|
||||
c.limiter.wait()
|
||||
resp, err := c.sdk.ChatComplete(ctx, &chat.CompletionRequest{
|
||||
Model: c.modelComplex,
|
||||
Messages: []chat.Message{
|
||||
&chat.SystemMessage{Content: chat.TextContent(systemPrompt)},
|
||||
&chat.UserMessage{Content: chat.TextContent(userPrompt)},
|
||||
},
|
||||
ResponseFormat: &chat.ResponseFormat{Type: "json_object"},
|
||||
})
|
||||
if err != nil {
|
||||
return PassResult{}, fmt.Errorf("pass2 chat: %w", err)
|
||||
}
|
||||
|
||||
if len(resp.Choices) == 0 {
|
||||
return PassResult{}, fmt.Errorf("pass2: no choices in response")
|
||||
}
|
||||
|
||||
return PassResult{
|
||||
Content: resp.Choices[0].Message.Content.String(),
|
||||
Usage: convertChatUsage(resp.Usage),
|
||||
Model: c.modelComplex,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// extractConversationContent returns the text of the first message entry in
// the conversation outputs, or "" when no message entry is present (e.g.
// tool-only outputs).
func extractConversationContent(resp *conversation.Response) string {
	for _, entry := range resp.Outputs {
		if msg, ok := entry.(*conversation.MessageOutputEntry); ok {
			return msg.Content.String()
		}
	}
	return ""
}
|
||||
|
||||
// convertConvUsage copies the Conversations API usage struct into the
// package-local UsageInfo used by PassResult.
func convertConvUsage(u conversation.UsageInfo) *UsageInfo {
	return &UsageInfo{
		PromptTokens:     u.PromptTokens,
		CompletionTokens: u.CompletionTokens,
		TotalTokens:      u.TotalTokens,
	}
}
|
||||
|
||||
// convertChatUsage copies the chat-completions usage struct into the
// package-local UsageInfo used by PassResult.
func convertChatUsage(u chat.UsageInfo) *UsageInfo {
	return &UsageInfo{
		PromptTokens:     u.PromptTokens,
		CompletionTokens: u.CompletionTokens,
		TotalTokens:      u.TotalTokens,
	}
}
|
||||
99
backend/internal/pkg/ai/errors.go
Normal file
99
backend/internal/pkg/ai/errors.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ErrorCode classifies provider failures into coarse buckets that callers
// can branch on without knowing provider-specific error formats.
type ErrorCode int

const (
	ErrInternal ErrorCode = iota
	ErrRateLimited
	ErrQuotaExceeded
	ErrTimeout
	ErrInvalidRequest
	ErrUnavailable
	ErrSchemaViolation
)

// errorCodeNames is indexed by ErrorCode; keep in sync with the const block.
var errorCodeNames = [...]string{
	ErrInternal:        "internal",
	ErrRateLimited:     "rate_limited",
	ErrQuotaExceeded:   "quota_exceeded",
	ErrTimeout:         "timeout",
	ErrInvalidRequest:  "invalid_request",
	ErrUnavailable:     "unavailable",
	ErrSchemaViolation: "schema_violation",
}

// String returns a stable snake_case label for logging; out-of-range codes
// fall back to "internal".
func (c ErrorCode) String() string {
	if c >= 0 && int(c) < len(errorCodeNames) {
		return errorCodeNames[c]
	}
	return "internal"
}

// ProviderError is the uniform error type surfaced by AI providers: a coarse
// code, a human-readable message, a retryability hint, the wrapped cause,
// and (for schema violations) the raw model output.
type ProviderError struct {
	Code      ErrorCode
	Message   string
	Retryable bool
	Inner     error
	RawOutput string
}

// Error renders "ai: <code>: <message>" with the inner error appended when set.
func (e *ProviderError) Error() string {
	base := fmt.Sprintf("ai: %s: %s", e.Code, e.Message)
	if e.Inner == nil {
		return base
	}
	return fmt.Sprintf("%s: %v", base, e.Inner)
}

// Unwrap exposes the wrapped cause for errors.Is / errors.As.
func (e *ProviderError) Unwrap() error { return e.Inner }
|
||||
|
||||
// ClassifyError normalizes an arbitrary error into a *ProviderError. An
// existing *ProviderError anywhere in the chain is returned unchanged;
// otherwise the error is bucketed heuristically by message content.
//
// The case order matters: rate-limit markers win over timeout markers,
// which win over generic network errors, so a message matching several
// patterns gets the most specific code.
func ClassifyError(err error) *ProviderError {
	if err == nil {
		return nil
	}
	var pe *ProviderError
	if errors.As(err, &pe) {
		return pe
	}
	// Deadline errors are detected structurally before falling back to
	// substring heuristics.
	if errors.Is(err, context.DeadlineExceeded) {
		return &ProviderError{Code: ErrTimeout, Message: "context deadline exceeded", Retryable: true, Inner: err}
	}

	msg := strings.ToLower(err.Error())
	switch {
	case strings.Contains(msg, "429"),
		strings.Contains(msg, "too many requests"),
		strings.Contains(msg, "rate limit"):
		return &ProviderError{Code: ErrRateLimited, Message: err.Error(), Retryable: true, Inner: err}
	case strings.Contains(msg, "deadline exceeded"),
		strings.Contains(msg, "timeout"):
		return &ProviderError{Code: ErrTimeout, Message: err.Error(), Retryable: true, Inner: err}
	case strings.Contains(msg, "connection refused"),
		strings.Contains(msg, "no such host"),
		isNetError(err):
		return &ProviderError{Code: ErrUnavailable, Message: err.Error(), Retryable: true, Inner: err}
	case strings.Contains(msg, "quota"),
		strings.Contains(msg, "insufficient"):
		return &ProviderError{Code: ErrQuotaExceeded, Message: err.Error(), Retryable: false, Inner: err}
	// NOTE(review): substring matches on "400"/"invalid" can misfire on
	// messages that merely contain those tokens (e.g. a count of 400) —
	// confirm this is acceptable for the providers in use.
	case strings.Contains(msg, "400"),
		strings.Contains(msg, "invalid"):
		return &ProviderError{Code: ErrInvalidRequest, Message: err.Error(), Retryable: false, Inner: err}
	}
	return &ProviderError{Code: ErrInternal, Message: err.Error(), Retryable: false, Inner: err}
}
|
||||
|
||||
// isNetError reports whether err (or anything it wraps) implements net.Error.
func isNetError(err error) bool {
	var ne net.Error
	return errors.As(err, &ne)
}
|
||||
41
backend/internal/pkg/ai/factory.go
Normal file
41
backend/internal/pkg/ai/factory.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
mistral "github.com/VikingOwl91/mistral-go-sdk"
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
"marktvogt.de/backend/internal/config"
|
||||
)
|
||||
|
||||
// Supported values of config.AIConfig.Provider.
const (
	providerOllama  = "ollama"
	providerMistral = "mistral"
)

// NewFromConfig builds the Provider selected by cfg.Provider. An empty
// provider defaults to Ollama (local, key-less); "mistral" additionally
// requires an API key. Unknown provider names are rejected with an error.
func NewFromConfig(cfg config.AIConfig) (Provider, error) {
	switch cfg.Provider {
	case "", providerOllama:
		return NewOllamaProvider(OllamaConfig{
			BaseURL: cfg.OllamaURL,
			Model:   cfg.OllamaModel,
		}), nil
	case providerMistral:
		if cfg.MistralAPIKey == "" {
			return nil, fmt.Errorf("ai: provider=%s requires AI_MISTRAL_API_KEY", providerMistral)
		}
		sdk := mistral.NewClient(
			cfg.MistralAPIKey,
			mistral.WithTimeout(120*time.Second),
			mistral.WithRetry(2, 1*time.Second),
		)
		// Wrap the SDK call in a chatFunc so MistralProvider stays testable
		// without a real client.
		chatFn := func(ctx context.Context, req *chat.CompletionRequest) (*chat.CompletionResponse, error) {
			return sdk.ChatComplete(ctx, req)
		}
		return newMistralProviderWithChat(cfg.MistralModel, chatFn, newRateLimiter(cfg.RateLimitRPS)), nil
	default:
		return nil, fmt.Errorf("ai: unknown provider %q (want %s|%s)", cfg.Provider, providerOllama, providerMistral)
	}
}
|
||||
29
backend/internal/pkg/ai/factory_test.go
Normal file
29
backend/internal/pkg/ai/factory_test.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"marktvogt.de/backend/internal/config"
|
||||
)
|
||||
|
||||
// TestNewFromConfig_Ollama verifies the factory returns an Ollama-named
// provider for an explicit "ollama" configuration.
func TestNewFromConfig_Ollama(t *testing.T) {
	p, err := NewFromConfig(config.AIConfig{Provider: providerOllama, OllamaURL: "http://x:11434", OllamaModel: "m"})
	if err != nil {
		t.Fatalf("NewFromConfig: %v", err)
	}
	if p.Name() != providerOllama {
		t.Fatalf("Name: %q", p.Name())
	}
}
|
||||
|
||||
// TestNewFromConfig_MistralRequiresKey verifies that selecting the Mistral
// provider without an API key fails fast instead of returning a broken client.
func TestNewFromConfig_MistralRequiresKey(t *testing.T) {
	if _, err := NewFromConfig(config.AIConfig{Provider: providerMistral}); err == nil {
		t.Fatal("want error when MistralAPIKey is empty")
	}
}
|
||||
|
||||
// TestNewFromConfig_UnknownProvider verifies that unrecognized provider
// names are rejected rather than silently falling back to a default.
func TestNewFromConfig_UnknownProvider(t *testing.T) {
	if _, err := NewFromConfig(config.AIConfig{Provider: "llama-cpp"}); err == nil {
		t.Fatal("want error for unknown provider")
	}
}
|
||||
99
backend/internal/pkg/ai/mistral_provider.go
Normal file
99
backend/internal/pkg/ai/mistral_provider.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
// chatFunc abstracts the Mistral chat-completion call so tests can inject a
// fake instead of a real SDK client.
type chatFunc func(ctx context.Context, req *chat.CompletionRequest) (*chat.CompletionResponse, error)

// MistralProvider implements the Provider interface on top of the Mistral
// chat-completions API.
type MistralProvider struct {
	model   string   // default model when the request does not set one
	chatFn  chatFunc // injected transport; nil means "not configured"
	limiter *rateLimiter // from ratelimit.go; nil disables
}
|
||||
|
||||
// newMistralProviderWithChat wires a MistralProvider with an explicit chat
// function and limiter; used by the factory and by tests with a fake chatFn.
func newMistralProviderWithChat(model string, fn chatFunc, limiter *rateLimiter) *MistralProvider {
	return &MistralProvider{model: model, chatFn: fn, limiter: limiter}
}
|
||||
|
||||
// Name identifies the provider in logs and tests.
func (p *MistralProvider) Name() string { return "mistral" }

// SupportsJSONMode reports native json_object response-format support.
func (p *MistralProvider) SupportsJSONMode() bool { return true }

// SupportsJSONSchema is false: schema conformance is emulated by injecting
// the schema into the system prompt and validating the reply (see Chat).
func (p *MistralProvider) SupportsJSONSchema() bool { return false }
|
||||
|
||||
func (p *MistralProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
|
||||
if p.chatFn == nil {
|
||||
return nil, &ProviderError{Code: ErrInternal, Message: "mistral provider not configured", Retryable: false}
|
||||
}
|
||||
if p.limiter != nil {
|
||||
p.limiter.wait()
|
||||
}
|
||||
|
||||
systemContent := req.SystemPrompt
|
||||
if len(req.JSONSchema) > 0 {
|
||||
if systemContent != "" {
|
||||
systemContent += "\n\n"
|
||||
}
|
||||
systemContent += "Respond with a JSON object that conforms to the following JSON Schema. " +
|
||||
"Do not output anything outside the JSON. Schema:\n" + string(req.JSONSchema)
|
||||
}
|
||||
|
||||
msgs := []chat.Message{}
|
||||
if systemContent != "" {
|
||||
msgs = append(msgs, &chat.SystemMessage{Content: chat.TextContent(systemContent)})
|
||||
}
|
||||
msgs = append(msgs, &chat.UserMessage{Content: chat.TextContent(req.UserMessage)})
|
||||
|
||||
creq := &chat.CompletionRequest{
|
||||
Model: firstNonEmpty(req.Model, p.model),
|
||||
Messages: msgs,
|
||||
}
|
||||
if req.JSONMode || len(req.JSONSchema) > 0 {
|
||||
creq.ResponseFormat = &chat.ResponseFormat{Type: "json_object"}
|
||||
}
|
||||
if req.Temperature != 0 {
|
||||
temp := float64(req.Temperature)
|
||||
creq.Temperature = &temp
|
||||
}
|
||||
if req.MaxTokens != 0 {
|
||||
creq.MaxTokens = &req.MaxTokens
|
||||
}
|
||||
|
||||
resp, err := p.chatFn(ctx, creq)
|
||||
if err != nil {
|
||||
return nil, ClassifyError(err)
|
||||
}
|
||||
if len(resp.Choices) == 0 {
|
||||
return nil, &ProviderError{Code: ErrInternal, Message: "no choices in response", Retryable: false}
|
||||
}
|
||||
content := resp.Choices[0].Message.Content.String()
|
||||
|
||||
if len(req.JSONSchema) > 0 {
|
||||
if err := ValidateSchema(req.JSONSchema, []byte(content)); err != nil {
|
||||
return nil, &ProviderError{
|
||||
Code: ErrSchemaViolation,
|
||||
Message: fmt.Sprintf("response does not match schema: %v", err),
|
||||
Retryable: true,
|
||||
Inner: err,
|
||||
RawOutput: content,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &ChatResponse{
|
||||
Content: content,
|
||||
Model: resp.Model,
|
||||
PromptTokens: resp.Usage.PromptTokens,
|
||||
OutputTokens: resp.Usage.CompletionTokens,
|
||||
TotalTokens: resp.Usage.TotalTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func firstNonEmpty(a, b string) string {
|
||||
if a != "" {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
123
backend/internal/pkg/ai/mistral_provider_test.go
Normal file
123
backend/internal/pkg/ai/mistral_provider_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/VikingOwl91/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
func TestMistral_Chat_PassesThroughContent(t *testing.T) {
|
||||
fakeChat := func(ctx context.Context, req *chat.CompletionRequest) (*chat.CompletionResponse, error) {
|
||||
return &chat.CompletionResponse{
|
||||
Model: "mistral-large-latest",
|
||||
Choices: []chat.CompletionChoice{
|
||||
{Message: chat.AssistantMessage{Content: chat.TextContent("ok")}},
|
||||
},
|
||||
Usage: chat.UsageInfo{PromptTokens: 3, CompletionTokens: 1, TotalTokens: 4},
|
||||
}, nil
|
||||
}
|
||||
p := newMistralProviderWithChat("mistral-large-latest", fakeChat, nil)
|
||||
|
||||
resp, err := p.Chat(context.Background(), &ChatRequest{SystemPrompt: "s", UserMessage: "u"})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat: %v", err)
|
||||
}
|
||||
if resp.Content != "ok" || resp.TotalTokens != 4 {
|
||||
t.Fatalf("unexpected: %+v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMistral_Chat_JSONModeSetsResponseFormat(t *testing.T) {
|
||||
var seen *chat.CompletionRequest
|
||||
fakeChat := func(ctx context.Context, req *chat.CompletionRequest) (*chat.CompletionResponse, error) {
|
||||
seen = req
|
||||
return &chat.CompletionResponse{Choices: []chat.CompletionChoice{{Message: chat.AssistantMessage{Content: chat.TextContent("{}")}}}}, nil
|
||||
}
|
||||
p := newMistralProviderWithChat("m", fakeChat, nil)
|
||||
_, err := p.Chat(context.Background(), &ChatRequest{UserMessage: "x", JSONMode: true})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat: %v", err)
|
||||
}
|
||||
if seen == nil || seen.ResponseFormat == nil || seen.ResponseFormat.Type != "json_object" {
|
||||
t.Fatalf("ResponseFormat not set: %+v", seen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMistral_Chat_SchemaEmbeddedInSystemPromptAndValidated(t *testing.T) {
|
||||
schema := []byte(`{"type":"object","required":["foo"],"properties":{"foo":{"type":"string"}},"additionalProperties":false}`)
|
||||
var seen *chat.CompletionRequest
|
||||
fakeChat := func(ctx context.Context, req *chat.CompletionRequest) (*chat.CompletionResponse, error) {
|
||||
seen = req
|
||||
return &chat.CompletionResponse{Choices: []chat.CompletionChoice{{Message: chat.AssistantMessage{Content: chat.TextContent(`{"foo":"bar"}`)}}}}, nil
|
||||
}
|
||||
p := newMistralProviderWithChat("m", fakeChat, nil)
|
||||
resp, err := p.Chat(context.Background(), &ChatRequest{SystemPrompt: "base system", UserMessage: "x", JSONSchema: schema})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat: %v", err)
|
||||
}
|
||||
if resp.Content != `{"foo":"bar"}` {
|
||||
t.Fatalf("content: %q", resp.Content)
|
||||
}
|
||||
sysMsg, ok := seen.Messages[0].(*chat.SystemMessage)
|
||||
if !ok {
|
||||
t.Fatalf("first message must be system: %T", seen.Messages[0])
|
||||
}
|
||||
sys := sysMsg.Content.String()
|
||||
if !containsAll(sys, []string{"base system", "JSON Schema"}) {
|
||||
t.Fatalf("system prompt missing expected fragments: %q", sys)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMistral_Chat_SchemaViolationReturnsRetryableError(t *testing.T) {
|
||||
schema := []byte(`{"type":"object","required":["foo"],"properties":{"foo":{"type":"string"}},"additionalProperties":false}`)
|
||||
fakeChat := func(ctx context.Context, req *chat.CompletionRequest) (*chat.CompletionResponse, error) {
|
||||
return &chat.CompletionResponse{Choices: []chat.CompletionChoice{{Message: chat.AssistantMessage{Content: chat.TextContent(`{"bar":1}`)}}}}, nil
|
||||
}
|
||||
p := newMistralProviderWithChat("m", fakeChat, nil)
|
||||
_, err := p.Chat(context.Background(), &ChatRequest{UserMessage: "x", JSONSchema: schema})
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
var pe *ProviderError
|
||||
if !errors.As(err, &pe) {
|
||||
t.Fatalf("want *ProviderError, got %T", err)
|
||||
}
|
||||
if pe.Code != ErrSchemaViolation || !pe.Retryable || pe.RawOutput != `{"bar":1}` {
|
||||
t.Fatalf("unexpected: %+v", pe)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMistral_Supports(t *testing.T) {
|
||||
p := newMistralProviderWithChat("m", nil, nil)
|
||||
if !p.SupportsJSONMode() {
|
||||
t.Fatal("Mistral supports JSON mode")
|
||||
}
|
||||
if p.SupportsJSONSchema() {
|
||||
t.Fatal("Mistral does NOT natively support JSON schema (prompt-based only)")
|
||||
}
|
||||
if p.Name() != "mistral" {
|
||||
t.Fatalf("Name: %q", p.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func containsAll(s string, parts []string) bool {
|
||||
for _, p := range parts {
|
||||
if !contains(s, p) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
func contains(s, sub string) bool {
|
||||
return len(sub) == 0 || (len(s) >= len(sub) && (s == sub || indexOf(s, sub) >= 0))
|
||||
}
|
||||
func indexOf(s, sub string) int {
|
||||
for i := 0; i+len(sub) <= len(s); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
127
backend/internal/pkg/ai/ollama.go
Normal file
127
backend/internal/pkg/ai/ollama.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OllamaConfig struct {
|
||||
BaseURL string
|
||||
Model string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
type OllamaProvider struct {
|
||||
cfg OllamaConfig
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func NewOllamaProvider(cfg OllamaConfig) *OllamaProvider {
|
||||
if cfg.Timeout == 0 {
|
||||
cfg.Timeout = 300 * time.Second
|
||||
}
|
||||
return &OllamaProvider{
|
||||
cfg: cfg,
|
||||
client: &http.Client{Timeout: cfg.Timeout},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OllamaProvider) Name() string { return "ollama" }
|
||||
func (p *OllamaProvider) SupportsJSONMode() bool { return true }
|
||||
func (p *OllamaProvider) SupportsJSONSchema() bool { return true }
|
||||
|
||||
type ollamaChatReq struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ollamaMessage `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
Format json.RawMessage `json:"format,omitempty"`
|
||||
Options *ollamaOptions `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
type ollamaMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ollamaOptions struct {
|
||||
Temperature float32 `json:"temperature,omitempty"`
|
||||
NumPredict int `json:"num_predict,omitempty"`
|
||||
}
|
||||
|
||||
type ollamaChatResp struct {
|
||||
Model string `json:"model"`
|
||||
Message ollamaMessage `json:"message"`
|
||||
Done bool `json:"done"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
}
|
||||
|
||||
func (p *OllamaProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
|
||||
model := req.Model
|
||||
if model == "" {
|
||||
model = p.cfg.Model
|
||||
}
|
||||
body := ollamaChatReq{
|
||||
Model: model,
|
||||
Messages: buildOllamaMessages(req),
|
||||
Stream: false,
|
||||
}
|
||||
switch {
|
||||
case len(req.JSONSchema) > 0:
|
||||
body.Format = req.JSONSchema
|
||||
case req.JSONMode:
|
||||
body.Format = json.RawMessage(`"json"`)
|
||||
}
|
||||
if req.Temperature != 0 || req.MaxTokens != 0 {
|
||||
body.Options = &ollamaOptions{Temperature: req.Temperature, NumPredict: req.MaxTokens}
|
||||
}
|
||||
|
||||
buf, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, &ProviderError{Code: ErrInvalidRequest, Message: "marshal request", Retryable: false, Inner: err}
|
||||
}
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, p.cfg.BaseURL+"/api/chat", bytes.NewReader(buf))
|
||||
if err != nil {
|
||||
return nil, &ProviderError{Code: ErrInternal, Message: "new request", Retryable: false, Inner: err}
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := p.client.Do(httpReq)
|
||||
if err != nil {
|
||||
return nil, ClassifyError(err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
if resp.StatusCode >= 400 {
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
pe := ClassifyError(fmt.Errorf("ollama status %d: %s", resp.StatusCode, string(b)))
|
||||
return nil, pe
|
||||
}
|
||||
|
||||
var out ollamaChatResp
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
return nil, &ProviderError{Code: ErrInternal, Message: "decode response", Retryable: false, Inner: err}
|
||||
}
|
||||
return &ChatResponse{
|
||||
Content: out.Message.Content,
|
||||
Model: out.Model,
|
||||
PromptTokens: out.PromptEvalCount,
|
||||
OutputTokens: out.EvalCount,
|
||||
TotalTokens: out.PromptEvalCount + out.EvalCount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildOllamaMessages(req *ChatRequest) []ollamaMessage {
|
||||
msgs := make([]ollamaMessage, 0, 2)
|
||||
if req.SystemPrompt != "" {
|
||||
msgs = append(msgs, ollamaMessage{Role: "system", Content: req.SystemPrompt})
|
||||
}
|
||||
if req.UserMessage != "" {
|
||||
msgs = append(msgs, ollamaMessage{Role: "user", Content: req.UserMessage})
|
||||
}
|
||||
return msgs
|
||||
}
|
||||
98
backend/internal/pkg/ai/ollama_test.go
Normal file
98
backend/internal/pkg/ai/ollama_test.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestOllama_Chat_SendsRequestAndParsesResponse(t *testing.T) {
|
||||
var captured map[string]any
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/api/chat" {
|
||||
t.Errorf("path: got %s, want /api/chat", r.URL.Path)
|
||||
}
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(body, &captured)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`{
|
||||
"model":"qwen2.5:14b-instruct",
|
||||
"message":{"role":"assistant","content":"hello"},
|
||||
"done":true,
|
||||
"prompt_eval_count":10,
|
||||
"eval_count":5
|
||||
}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "qwen2.5:14b-instruct", Timeout: 5 * time.Second})
|
||||
resp, err := p.Chat(context.Background(), &ChatRequest{SystemPrompt: "be brief", UserMessage: "hi", JSONMode: true})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat: %v", err)
|
||||
}
|
||||
if resp.Content != "hello" {
|
||||
t.Fatalf("content: got %q", resp.Content)
|
||||
}
|
||||
if resp.PromptTokens != 10 || resp.OutputTokens != 5 || resp.TotalTokens != 15 {
|
||||
t.Fatalf("tokens: %+v", resp)
|
||||
}
|
||||
if captured["stream"] != false {
|
||||
t.Fatalf("stream must be false: %v", captured["stream"])
|
||||
}
|
||||
if captured["format"] != "json" {
|
||||
t.Fatalf("format for JSONMode=true must be \"json\", got %v", captured["format"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestOllama_Chat_ForwardsSchema(t *testing.T) {
|
||||
var captured map[string]any
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(body, &captured)
|
||||
_, _ = w.Write([]byte(`{"model":"m","message":{"role":"assistant","content":"{}"},"done":true}`))
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
p := NewOllamaProvider(OllamaConfig{BaseURL: srv.URL, Model: "m", Timeout: time.Second})
|
||||
schema := []byte(`{"type":"object"}`)
|
||||
_, err := p.Chat(context.Background(), &ChatRequest{UserMessage: "x", JSONSchema: schema})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat: %v", err)
|
||||
}
|
||||
fmtField, ok := captured["format"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("format must be an object when JSONSchema set: %v", captured["format"])
|
||||
}
|
||||
if fmtField["type"] != "object" {
|
||||
t.Fatalf("schema not forwarded: %v", fmtField)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOllama_Chat_Unavailable(t *testing.T) {
|
||||
p := NewOllamaProvider(OllamaConfig{BaseURL: "http://127.0.0.1:1", Timeout: 100 * time.Millisecond})
|
||||
_, err := p.Chat(context.Background(), &ChatRequest{UserMessage: "x"})
|
||||
if err == nil {
|
||||
t.Fatal("want error, got nil")
|
||||
}
|
||||
pe := ClassifyError(err)
|
||||
if pe.Code != ErrUnavailable && pe.Code != ErrTimeout {
|
||||
t.Fatalf("expected Unavailable or Timeout, got %v", pe.Code)
|
||||
}
|
||||
if !pe.Retryable {
|
||||
t.Fatal("must be retryable")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOllama_Supports(t *testing.T) {
|
||||
p := NewOllamaProvider(OllamaConfig{BaseURL: "x"})
|
||||
if !p.SupportsJSONMode() || !p.SupportsJSONSchema() {
|
||||
t.Fatal("Ollama supports both")
|
||||
}
|
||||
if p.Name() != "ollama" {
|
||||
t.Fatalf("Name: %q", p.Name())
|
||||
}
|
||||
}
|
||||
31
backend/internal/pkg/ai/provider.go
Normal file
31
backend/internal/pkg/ai/provider.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
type Provider interface {
|
||||
Name() string
|
||||
Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error)
|
||||
SupportsJSONMode() bool
|
||||
SupportsJSONSchema() bool
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
SystemPrompt string
|
||||
UserMessage string
|
||||
Model string
|
||||
MaxTokens int
|
||||
Temperature float32
|
||||
JSONMode bool
|
||||
JSONSchema json.RawMessage
|
||||
}
|
||||
|
||||
type ChatResponse struct {
|
||||
Content string
|
||||
Model string
|
||||
PromptTokens int
|
||||
OutputTokens int
|
||||
TotalTokens int
|
||||
}
|
||||
89
backend/internal/pkg/ai/provider_test.go
Normal file
89
backend/internal/pkg/ai/provider_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestChatRequest_DefaultsAreZeroValues(t *testing.T) {
|
||||
r := ChatRequest{}
|
||||
if r.Temperature != 0 {
|
||||
t.Fatalf("Temperature default: got %v, want 0", r.Temperature)
|
||||
}
|
||||
if r.MaxTokens != 0 {
|
||||
t.Fatalf("MaxTokens default: got %v, want 0", r.MaxTokens)
|
||||
}
|
||||
if r.JSONMode {
|
||||
t.Fatalf("JSONMode default: got true, want false")
|
||||
}
|
||||
if r.JSONSchema != nil {
|
||||
t.Fatalf("JSONSchema default: got non-nil, want nil")
|
||||
}
|
||||
}
|
||||
|
||||
type stubProvider struct {
|
||||
name string
|
||||
resp *ChatResponse
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *stubProvider) Name() string { return s.name }
|
||||
func (s *stubProvider) SupportsJSONMode() bool { return true }
|
||||
func (s *stubProvider) SupportsJSONSchema() bool { return true }
|
||||
func (s *stubProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) {
|
||||
return s.resp, s.err
|
||||
}
|
||||
|
||||
func TestStubProvider_SatisfiesInterface(t *testing.T) {
|
||||
var _ Provider = (*stubProvider)(nil)
|
||||
}
|
||||
|
||||
func TestStubProvider_ReturnsError(t *testing.T) {
|
||||
sentinel := errors.New("boom")
|
||||
var p Provider = &stubProvider{name: "stub", err: sentinel}
|
||||
_, err := p.Chat(context.Background(), &ChatRequest{})
|
||||
if !errors.Is(err, sentinel) {
|
||||
t.Fatalf("got %v, want sentinel", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderError_Error(t *testing.T) {
|
||||
inner := errors.New("root cause")
|
||||
pe := &ProviderError{Code: ErrRateLimited, Message: "slow down", Retryable: true, Inner: inner}
|
||||
got := pe.Error()
|
||||
if got == "" {
|
||||
t.Fatal("Error() must not be empty")
|
||||
}
|
||||
if !errors.Is(pe, inner) {
|
||||
t.Fatalf("errors.Is must unwrap Inner")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_DetectsRateLimit(t *testing.T) {
|
||||
cases := []struct {
|
||||
msg string
|
||||
want ErrorCode
|
||||
}{
|
||||
{"429 Too Many Requests", ErrRateLimited},
|
||||
{"context deadline exceeded", ErrTimeout},
|
||||
{"connection refused", ErrUnavailable},
|
||||
{"something totally unrecognized", ErrInternal},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.msg, func(t *testing.T) {
|
||||
got := ClassifyError(errors.New(tc.msg)).Code
|
||||
if got != tc.want {
|
||||
t.Fatalf("ClassifyError(%q) = %v, want %v", tc.msg, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyError_PreservesProviderError(t *testing.T) {
|
||||
orig := &ProviderError{Code: ErrSchemaViolation, Message: "x", Retryable: true}
|
||||
got := ClassifyError(orig)
|
||||
if got != orig {
|
||||
t.Fatal("ClassifyError must return the same *ProviderError instance")
|
||||
}
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
package ai
|
||||
|
||||
import "strings"
|
||||
|
||||
// IsRateLimit reports whether err is a Mistral 429 / web_search rate-limit
|
||||
// signal. The SDK surfaces these as wrapped errors whose message contains
|
||||
// "rate limit" and/or "status 429"; we match defensively on both.
|
||||
//
|
||||
// Shared helper so domains (discovery, market/research) treat rate limits
|
||||
// consistently: log at WARN, return a 429 to the caller with a Retry-After
|
||||
// hint instead of a generic 500.
|
||||
func IsRateLimit(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := err.Error()
|
||||
return strings.Contains(msg, "rate limit") ||
|
||||
strings.Contains(msg, "status 429") ||
|
||||
strings.Contains(msg, "429")
|
||||
}
|
||||
|
||||
// DefaultRetryAfterSeconds is the hint we send when the SDK error doesn't
|
||||
// carry a structured Retry-After. 60s matches the typical web_search budget
|
||||
// window on Mistral's paid tier.
|
||||
const DefaultRetryAfterSeconds = 60
|
||||
@@ -1,28 +0,0 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsRateLimit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
want bool
|
||||
}{
|
||||
{"nil", nil, false},
|
||||
{"unrelated error", errors.New("connection refused"), false},
|
||||
{"mistral web_search 429", errors.New("pass1 conversation: mistral: web_search rate limit reached. (status 429)"), true},
|
||||
{"bare status 429", errors.New("upstream returned status 429"), true},
|
||||
{"raw 429 token", errors.New("http 429 received"), true},
|
||||
{"token 42900 should NOT match bare token but does by substring", errors.New("code 42900 from upstream"), true},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := IsRateLimit(tc.err); got != tc.want {
|
||||
t.Errorf("IsRateLimit(%v) = %v, want %v", tc.err, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
33
backend/internal/pkg/ai/ratelimiter.go
Normal file
33
backend/internal/pkg/ai/ratelimiter.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// rateLimiter enforces a minimum interval between calls. Set rps<=0 to disable.
|
||||
type rateLimiter struct {
|
||||
mu sync.Mutex
|
||||
lastReq time.Time
|
||||
minInterval time.Duration
|
||||
}
|
||||
|
||||
func newRateLimiter(rps float64) *rateLimiter {
|
||||
if rps <= 0 {
|
||||
return &rateLimiter{minInterval: 0}
|
||||
}
|
||||
return &rateLimiter{minInterval: time.Duration(float64(time.Second) / rps)}
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) wait() {
|
||||
if rl.minInterval == 0 {
|
||||
return
|
||||
}
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
since := time.Since(rl.lastReq)
|
||||
if since < rl.minInterval {
|
||||
time.Sleep(rl.minInterval - since)
|
||||
}
|
||||
rl.lastReq = time.Now()
|
||||
}
|
||||
28
backend/internal/pkg/ai/schema.go
Normal file
28
backend/internal/pkg/ai/schema.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/santhosh-tekuri/jsonschema/v5"
|
||||
)
|
||||
|
||||
func ValidateSchema(schemaJSON, docJSON []byte) error {
|
||||
compiler := jsonschema.NewCompiler()
|
||||
if err := compiler.AddResource("schema.json", bytes.NewReader(schemaJSON)); err != nil {
|
||||
return fmt.Errorf("add schema: %w", err)
|
||||
}
|
||||
sch, err := compiler.Compile("schema.json")
|
||||
if err != nil {
|
||||
return fmt.Errorf("compile schema: %w", err)
|
||||
}
|
||||
var doc any
|
||||
if err := json.Unmarshal(docJSON, &doc); err != nil {
|
||||
return fmt.Errorf("parse doc: %w", err)
|
||||
}
|
||||
if err := sch.Validate(doc); err != nil {
|
||||
return fmt.Errorf("validate: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
25
backend/internal/pkg/ai/schema_test.go
Normal file
25
backend/internal/pkg/ai/schema_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package ai
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestValidateSchema_AcceptsMatchingDoc(t *testing.T) {
|
||||
schema := []byte(`{"type":"object","required":["foo"],"properties":{"foo":{"type":"string"}},"additionalProperties":false}`)
|
||||
doc := []byte(`{"foo":"bar"}`)
|
||||
if err := ValidateSchema(schema, doc); err != nil {
|
||||
t.Fatalf("want nil, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSchema_RejectsMissingRequired(t *testing.T) {
|
||||
schema := []byte(`{"type":"object","required":["foo"],"properties":{"foo":{"type":"string"}}}`)
|
||||
doc := []byte(`{}`)
|
||||
if err := ValidateSchema(schema, doc); err == nil {
|
||||
t.Fatal("want error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateSchema_RejectsBadJSON(t *testing.T) {
|
||||
if err := ValidateSchema([]byte(`{}`), []byte(`not json`)); err == nil {
|
||||
t.Fatal("want error, got nil")
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -68,23 +69,20 @@ func (s *Server) registerRoutes() {
|
||||
|
||||
// Admin market routes
|
||||
adminMarketHandler := market.NewAdminHandler(marketSvc)
|
||||
aiClient := ai.New(s.cfg.AI.APIKey, s.cfg.AI.AgentSimple, s.cfg.AI.ModelComplex, s.cfg.AI.RateLimitRPS)
|
||||
researchHandler := market.NewResearchHandler(marketSvc, aiClient)
|
||||
aiProvider, err := ai.NewFromConfig(s.cfg.AI)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("init ai provider: %w", err))
|
||||
}
|
||||
researchHandler := market.NewResearchHandler(marketSvc, aiProvider)
|
||||
requireAdmin := middleware.RequireRole("admin")
|
||||
market.RegisterAdminRoutes(v1, adminMarketHandler, researchHandler, requireAuth, requireAdmin)
|
||||
|
||||
// Discovery routes
|
||||
discoveryRepo := discovery.NewRepository(s.db)
|
||||
crawlerInstance := crawler.NewCrawler(s.cfg.Discovery.CrawlerUserAgent, crawler.DefaultSourceConfigs())
|
||||
// Per-row LLM enrichment (MR 3b). Operator-triggered only; disabled rows
|
||||
// fall through via NoopLLMEnricher when AI isn't configured.
|
||||
var llmEnricher enrich.LLMEnricher = enrich.NoopLLMEnricher{}
|
||||
var simClassifier enrich.SimilarityClassifier = enrich.NoopSimilarityClassifier{}
|
||||
if aiClient.Enabled() {
|
||||
scraper := scrape.New(s.cfg.Discovery.CrawlerUserAgent)
|
||||
llmEnricher = enrich.NewMistralLLMEnricher(aiClient, scraper)
|
||||
simClassifier = enrich.NewMistralSimilarityClassifier(aiClient)
|
||||
}
|
||||
scraper := scrape.New(s.cfg.Discovery.CrawlerUserAgent)
|
||||
llmEnricher := enrich.NewLLMEnricher(aiProvider, scraper)
|
||||
simClassifier := enrich.NewSimilarityClassifier(aiProvider)
|
||||
discoveryService := discovery.NewService(discoveryRepo, crawlerInstance, discovery.NewLinkChecker(), marketSvc, geocoder, llmEnricher, simClassifier)
|
||||
discoveryHandler := discovery.NewHandler(discoveryService, s.cfg.Discovery.CrawlerManualRateLimitPerHour)
|
||||
requireTickToken := middleware.RequireBearerToken(s.cfg.Discovery.Token)
|
||||
|
||||
Reference in New Issue
Block a user