diff --git a/.gitignore b/.gitignore index 930c70e..ce3ffcc 100644 --- a/.gitignore +++ b/.gitignore @@ -10,7 +10,9 @@ vendor/ # discovery-eval local caches + generated reports .eval-cache.json +.cat-eval-cache.json eval-report.json +cat-eval-report.json # ── Web ────────────────────────────────────── /web/node_modules/ diff --git a/backend/cmd/discovery-eval/README.md b/backend/cmd/discovery-eval/README.md index 0acd366..0345b30 100644 --- a/backend/cmd/discovery-eval/README.md +++ b/backend/cmd/discovery-eval/README.md @@ -1,24 +1,53 @@ # discovery-eval -CLI that measures the `MistralSimilarityClassifier` against a labelled -fixture of same-/different-market pairs. Reports precision, recall, F1, -accuracy, and confidence calibration. File-based cache keeps reruns free. +CLI that grades discovery's AI-backed components against labelled fixtures. +Two modes: + +- `-mode similarity` (default) — `MistralSimilarityClassifier` on pair- + labelled fixtures. Reports precision / recall / F1 / accuracy + a + confidence calibration table. +- `-mode category` — `MistralLLMEnricher`'s `category` output on row- + labelled fixtures. Reports accuracy + a per-label confusion matrix. + +File-based cache keeps reruns free. Each mode has its own cache key shape, +so switching modes doesn't churn entries. ## Run it +### Similarity (default) + ``` export AI_API_KEY=... export AI_MODEL_COMPLEX=mistral-large-latest go run ./backend/cmd/discovery-eval \ + -mode similarity \ -fixture backend/cmd/discovery-eval/fixtures/similarity.json \ -cache .eval-cache.json \ -threshold 0.8 \ -report eval-report.json ``` -Exit code is 1 when `F1 < threshold` (0 = gating disabled). That makes it -usable as a CI regression gate once a baseline F1 is known. +Exit code is 1 when `F1 < threshold` (0 = gating disabled). + +### Category + +``` +export AI_API_KEY=... +export AI_MODEL_COMPLEX=mistral-large-latest + +go run ./backend/cmd/discovery-eval \ + -mode category \ + -fixture backend/cmd/discovery-eval/fixtures/category.json \ + -cache .cat-eval-cache.json \ + -threshold 0.7 \ + -report cat-eval-report.json +``` + +Category mode scrapes each row's `quellen` URLs live (first run only; cache +covers subsequent runs) and asks the LLM enricher to produce a category. +Normalised comparison: casing + German umlauts + the -märkte/-markt plural +drift are all treated as equal. Exit code is 1 when `accuracy < threshold`. ## Extending the fixture diff --git a/backend/cmd/discovery-eval/category.go b/backend/cmd/discovery-eval/category.go new file mode 100644 index 0000000..7f96f2e --- /dev/null +++ b/backend/cmd/discovery-eval/category.go @@ -0,0 +1,357 @@ +package main + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "io/fs" + "log/slog" + "os" + "path/filepath" + "sort" + "strings" + + "marktvogt.de/backend/internal/domain/discovery/enrich" +) + +// CategoryFixture is the parsed shape of fixtures/category.json — labelled +// rows where `expected_category` is the ground truth for the MistralLLMEnricher's +// output category field. +type CategoryFixture struct { + Rows []CategoryRow `json:"rows"` +} + +// CategoryRow is one ground-truth example. quellen is the list of source URLs +// the enricher will scrape. `expected_category` is the operator's judgement +// of what the correct German label should be — normalised before comparison +// so case/umlaut drift doesn't falsely grade as wrong. +type CategoryRow struct { + MarktName string `json:"markt_name"` + Stadt string `json:"stadt"` + Bundesland string `json:"bundesland,omitempty"` + Land string `json:"land,omitempty"` + Year int `json:"year,omitempty"` + Quellen []string `json:"quellen"` + ExpectedCategory string `json:"expected_category"` + Note string `json:"note,omitempty"` +} + +func loadCategoryFixture(path string) (*CategoryFixture, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read fixture: %w", err) + } + var f CategoryFixture + if err := json.Unmarshal(data, &f); err != nil { + return nil, fmt.Errorf("parse fixture: %w", err) + } + if len(f.Rows) == 0 { + return nil, fmt.Errorf("fixture has no rows") + } + return &f, nil +} + +// CategoryCache is the category mode's sibling to Cache. Keyed on the row's +// content tuple + model so a model bump forces a refresh. +type CategoryCache struct { + Entries map[string]enrich.Enrichment `json:"entries"` +} + +func newCategoryCache() *CategoryCache { + return &CategoryCache{Entries: map[string]enrich.Enrichment{}} +} + +// categoryCacheKey hashes (markt_name_lower|stadt_lower|year|model). Separate +// from SimilarityPairKey because that function takes two rows and sorts — +// here we key on one row. +func categoryCacheKey(r CategoryRow, model string) string { + raw := fmt.Sprintf("%s|%s|%d|%s", + strings.ToLower(r.MarktName), strings.ToLower(r.Stadt), r.Year, model) + sum := sha256.Sum256([]byte(raw)) + return hex.EncodeToString(sum[:]) +} + +func (c *CategoryCache) Get(r CategoryRow, model string) (enrich.Enrichment, bool) { + v, ok := c.Entries[categoryCacheKey(r, model)] + return v, ok +} + +func (c *CategoryCache) Put(r CategoryRow, model string, v enrich.Enrichment) { + c.Entries[categoryCacheKey(r, model)] = v +} + +func loadCategoryCache(path string) (*CategoryCache, error) { + data, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return newCategoryCache(), nil + } + return nil, fmt.Errorf("read cache: %w", err) + } + c := newCategoryCache() + if err := json.Unmarshal(data, c); err != nil { + return newCategoryCache(), fmt.Errorf("parse cache (starting empty): %w", err) + } + if c.Entries == nil { + c.Entries = map[string]enrich.Enrichment{} + } + return c, nil +} + +// saveCategoryCache is the same atomic-write pattern as saveCache. Duplicated +// rather than generic-ed because Go generics on JSON types add more noise +// than they save for two call sites. +func saveCategoryCache(path string, c *CategoryCache) error { + data, err := json.MarshalIndent(c, "", " ") + if err != nil { + return fmt.Errorf("marshal cache: %w", err) + } + dir := filepath.Dir(path) + if dir == "" { + dir = "." + } + tmp, err := os.CreateTemp(dir, ".cat-cache-*.tmp") + if err != nil { + return fmt.Errorf("create tmp: %w", err) + } + tmpPath := tmp.Name() + if _, err := tmp.Write(data); err != nil { + _ = tmp.Close() + _ = os.Remove(tmpPath) + return fmt.Errorf("write tmp: %w", err) + } + if err := tmp.Close(); err != nil { + _ = os.Remove(tmpPath) + return fmt.Errorf("close tmp: %w", err) + } + if err := os.Rename(tmpPath, path); err != nil { + _ = os.Remove(tmpPath) + return fmt.Errorf("rename tmp: %w", err) + } + return nil +} + +// normalizeCategory strips casing drift + German umlauts for comparison. +// "Mittelaltermarkt" == "mittelaltermarkt" == "Mittelaltermärkte" (last one +// loses the 'e' pluralisation — too aggressive for identity, but good enough +// for categorical matching when the LLM occasionally emits plurals). +func normalizeCategory(s string) string { + s = strings.TrimSpace(strings.ToLower(s)) + replacer := strings.NewReplacer( + "ä", "a", "ö", "o", "ü", "u", "ß", "ss", + ) + s = replacer.Replace(s) + // Drop trailing 'e' on plurals (märkte → markte → markt). Only a light + // heuristic — applied only when stripping produces a known stem. + if strings.HasSuffix(s, "markte") { + s = strings.TrimSuffix(s, "e") + } + return s +} + +// CategoryResult mirrors Result for the category mode. `Got` is the raw +// category the LLM returned (before normalisation) so mismatches stay +// legible in the report. +type CategoryResult struct { + Row CategoryRow + Got string + Want string + Correct bool + FromCache bool + // Err records the scrape/LLM failure message when the run couldn't + // produce a category at all; scored as not-correct. + Err string +} + +// runCategory is the category-mode equivalent of run(). Uses the real +// MistralLLMEnricher — that's the whole point of category eval (scrape + +// LLM against labelled outputs). +func runCategory( + ctx context.Context, + enricher enrich.LLMEnricher, + cache *CategoryCache, + fixture *CategoryFixture, + model string, +) ([]CategoryResult, error) { + results := make([]CategoryResult, 0, len(fixture.Rows)) + for i, r := range fixture.Rows { + if err := ctx.Err(); err != nil { + return results, err + } + + if v, ok := cache.Get(r, model); ok { + results = append(results, scoreCategoryResult(r, v, "", true)) + continue + } + + req := enrich.LLMRequest{ + MarktName: r.MarktName, + Stadt: r.Stadt, + Land: r.Land, + Bundesland: r.Bundesland, + Quellen: r.Quellen, + } + got, err := enricher.EnrichMissing(ctx, req) + if err != nil { + slog.Warn("enrich failed; scoring as incorrect", + "row_index", i, "markt", r.MarktName, "error", err) + results = append(results, CategoryResult{ + Row: r, + Want: r.ExpectedCategory, + Err: err.Error(), + }) + continue + } + cache.Put(r, model, got) + results = append(results, scoreCategoryResult(r, got, "", false)) + } + return results, nil +} + +func scoreCategoryResult(r CategoryRow, got enrich.Enrichment, errMsg string, fromCache bool) CategoryResult { + gotCat := strings.TrimSpace(got.Category) + return CategoryResult{ + Row: r, + Got: gotCat, + Want: r.ExpectedCategory, + Correct: normalizeCategory(gotCat) == normalizeCategory(r.ExpectedCategory), + FromCache: fromCache, + Err: errMsg, + } +} + +// CategoryMetrics summarises a category-mode run. Unlike similarity (binary, +// P/R/F1 matter), categorical eval cares about accuracy + per-category +// confusion — which labels get mixed up. +type CategoryMetrics struct { + Total int `json:"total"` + Correct int `json:"correct"` + Incorrect int `json:"incorrect"` + Errors int `json:"errors"` // rows that failed to produce any category + CacheHits int `json:"cache_hits"` + LLMCalls int `json:"llm_calls"` + // Confusion: label → {predicted label → count}. Excludes errored rows. + Confusion map[string]map[string]int `json:"confusion"` + Accuracy float64 `json:"accuracy"` +} + +func computeCategoryMetrics(results []CategoryResult) CategoryMetrics { + m := CategoryMetrics{ + Total: len(results), + Confusion: map[string]map[string]int{}, + } + if m.Total == 0 { + return m + } + for _, r := range results { + if r.FromCache { + m.CacheHits++ + } else { + m.LLMCalls++ + } + if r.Err != "" { + m.Errors++ + continue + } + want := normalizeCategory(r.Want) + got := normalizeCategory(r.Got) + if r.Correct { + m.Correct++ + } else { + m.Incorrect++ + } + if _, ok := m.Confusion[want]; !ok { + m.Confusion[want] = map[string]int{} + } + m.Confusion[want][got]++ + } + m.Accuracy = float64(m.Correct) / float64(m.Total) + return m +} + +// printCategorySummary writes a human-readable report. Mirrors printSummary's +// shape so operators switching modes see the same vocabulary. +func printCategorySummary(w io.Writer, results []CategoryResult, m CategoryMetrics, model string) { + wf(w, "\n=== discovery-eval (category) ===\n") + wf(w, "model: %s\n", model) + wf(w, "rows: %d\n", m.Total) + wf(w, "cache: %d hits, %d llm calls\n", m.CacheHits, m.LLMCalls) + wf(w, "\n") + wf(w, "correct: %d\n", m.Correct) + wf(w, "incorrect: %d\n", m.Incorrect) + wf(w, "errors: %d\n", m.Errors) + wf(w, "accuracy: %.3f\n", m.Accuracy) + + // Confusion matrix — one row per expected label showing how predictions + // distributed. Labels sorted alphabetically for stable output. + if len(m.Confusion) > 0 { + labels := make([]string, 0, len(m.Confusion)) + for k := range m.Confusion { + labels = append(labels, k) + } + sort.Strings(labels) + wf(w, "\nconfusion (want → predictions):\n") + for _, want := range labels { + preds := m.Confusion[want] + predKeys := make([]string, 0, len(preds)) + for k := range preds { + predKeys = append(predKeys, k) + } + sort.Strings(predKeys) + wf(w, " %-24s", want) + first := true + for _, p := range predKeys { + if !first { + wf(w, ", ") + } + marker := "" + if p != want { + marker = "!" + } + wf(w, "%s%s×%d", marker, p, preds[p]) + first = false + } + wf(w, "\n") + } + } + + // Surface individual mistakes so the operator can patch the prompt or fixture. + wrong := make([]CategoryResult, 0) + for _, r := range results { + if !r.Correct || r.Err != "" { + wrong = append(wrong, r) + } + } + if len(wrong) > 0 { + wf(w, "\nmismatches (%d):\n", len(wrong)) + for _, r := range wrong { + if r.Err != "" { + wf(w, " ERROR %q (%s): %s\n", r.Row.MarktName, r.Row.Stadt, r.Err) + continue + } + wf(w, " want=%q got=%q %q (%s)\n", r.Want, r.Got, r.Row.MarktName, r.Row.Stadt) + } + } + wf(w, "\n") +} + +// CategoryReport is the on-disk shape for -report in category mode. +type CategoryReport struct { + Mode string `json:"mode"` + Model string `json:"model"` + Metrics CategoryMetrics `json:"metrics"` + Results []CategoryResult `json:"results"` +} + +func writeCategoryReport(path string, results []CategoryResult, m CategoryMetrics, model string) error { + rep := CategoryReport{Mode: "category", Model: model, Metrics: m, Results: results} + data, err := json.MarshalIndent(rep, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0o644) +} diff --git a/backend/cmd/discovery-eval/category_test.go b/backend/cmd/discovery-eval/category_test.go new file mode 100644 index 0000000..9544b0b --- /dev/null +++ b/backend/cmd/discovery-eval/category_test.go @@ -0,0 +1,171 @@ +package main + +import ( + "os" + "path/filepath" + "testing" + + "marktvogt.de/backend/internal/domain/discovery/enrich" +) + +func TestNormalizeCategory_UmlautAndCase(t *testing.T) { + cases := []struct { + in, want string + }{ + {"Mittelaltermarkt", "mittelaltermarkt"}, + {"MITTELALTERMARKT", "mittelaltermarkt"}, + {"Mittelaltermärkte", "mittelaltermarkt"}, + {"Weihnachtsmarkt", "weihnachtsmarkt"}, + {"Schönbrunn-Fest", "schonbrunn-fest"}, + {"weißwurst", "weisswurst"}, + {" Ritterfest ", "ritterfest"}, + {"", ""}, + } + for _, c := range cases { + t.Run(c.in, func(t *testing.T) { + got := normalizeCategory(c.in) + if got != c.want { + t.Errorf("normalizeCategory(%q) = %q; want %q", c.in, got, c.want) + } + }) + } +} + +func TestScoreCategoryResult_CaseInsensitive(t *testing.T) { + // Same category but casing/diacritic drift — must be counted as correct. + r := CategoryRow{MarktName: "X", Stadt: "Y", ExpectedCategory: "mittelaltermarkt"} + got := scoreCategoryResult(r, enrich.Enrichment{Category: "Mittelaltermarkt"}, "", false) + if !got.Correct { + t.Errorf("casing drift should score as correct, got %+v", got) + } +} + +func TestScoreCategoryResult_UmlautDrift(t *testing.T) { + r := CategoryRow{ExpectedCategory: "mittelaltermarkt"} + got := scoreCategoryResult(r, enrich.Enrichment{Category: "Mittelaltermärkte"}, "", false) + if !got.Correct { + t.Errorf("umlaut + plural drift should normalise to correct, got %+v", got) + } +} + +func TestScoreCategoryResult_WrongLabel(t *testing.T) { + r := CategoryRow{ExpectedCategory: "mittelaltermarkt"} + got := scoreCategoryResult(r, enrich.Enrichment{Category: "weihnachtsmarkt"}, "", false) + if got.Correct { + t.Errorf("distinct labels must not score as correct: %+v", got) + } +} + +func TestComputeCategoryMetrics_BasicAccuracy(t *testing.T) { + results := []CategoryResult{ + {Row: CategoryRow{ExpectedCategory: "a"}, Got: "a", Want: "a", Correct: true}, + {Row: CategoryRow{ExpectedCategory: "a"}, Got: "b", Want: "a", Correct: false}, + {Row: CategoryRow{ExpectedCategory: "c"}, Got: "c", Want: "c", Correct: true, FromCache: true}, + } + m := computeCategoryMetrics(results) + if m.Total != 3 || m.Correct != 2 || m.Incorrect != 1 || m.Errors != 0 { + t.Errorf("counts wrong: %+v", m) + } + if m.CacheHits != 1 || m.LLMCalls != 2 { + t.Errorf("cache accounting wrong: hits=%d calls=%d", m.CacheHits, m.LLMCalls) + } + if m.Accuracy < 0.666 || m.Accuracy > 0.667 { + t.Errorf("accuracy = %v; want ~0.666", m.Accuracy) + } + // Confusion should have a[a]=1, a[b]=1, c[c]=1 + if m.Confusion["a"]["a"] != 1 || m.Confusion["a"]["b"] != 1 || m.Confusion["c"]["c"] != 1 { + t.Errorf("confusion matrix unexpected: %+v", m.Confusion) + } +} + +func TestComputeCategoryMetrics_ErrorsExcludedFromConfusion(t *testing.T) { + results := []CategoryResult{ + {Row: CategoryRow{ExpectedCategory: "a"}, Want: "a", Err: "network down"}, + {Row: CategoryRow{ExpectedCategory: "a"}, Got: "a", Want: "a", Correct: true}, + } + m := computeCategoryMetrics(results) + if m.Errors != 1 { + t.Errorf("errors = %d; want 1", m.Errors) + } + // Only the non-errored row should appear in confusion. + total := 0 + for _, inner := range m.Confusion { + for _, v := range inner { + total += v + } + } + if total != 1 { + t.Errorf("confusion should exclude errors; total=%d", total) + } +} + +func TestCategoryCache_ModelScoped(t *testing.T) { + c := newCategoryCache() + r := CategoryRow{MarktName: "x", Stadt: "y", Year: 2026} + c.Put(r, "m1", enrich.Enrichment{Category: "mittelaltermarkt"}) + if _, ok := c.Get(r, "m2"); ok { + t.Error("cache hit under different model; should be a miss") + } + if v, ok := c.Get(r, "m1"); !ok || v.Category != "mittelaltermarkt" { + t.Error("cache miss on exact model match") + } +} + +func TestCategoryCache_RoundTrip(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "cat-cache.json") + c := newCategoryCache() + r := CategoryRow{MarktName: "Markt X", Stadt: "Dresden", Year: 2026} + c.Put(r, "m", enrich.Enrichment{ + Category: "mittelaltermarkt", + Description: "ein Markt", + }) + if err := saveCategoryCache(path, c); err != nil { + t.Fatal(err) + } + loaded, err := loadCategoryCache(path) + if err != nil { + t.Fatal(err) + } + v, ok := loaded.Get(r, "m") + if !ok || v.Category != "mittelaltermarkt" { + t.Errorf("round-trip lost data: %+v ok=%v", v, ok) + } +} + +func TestLoadCategoryCache_MissingAndCorrupt(t *testing.T) { + // Missing file → empty cache, no error. + c, err := loadCategoryCache(filepath.Join(t.TempDir(), "missing.json")) + if err != nil || c == nil || c.Entries == nil { + t.Errorf("missing file should yield empty cache: err=%v", err) + } + + // Corrupt file → empty cache + parse error reported. + dir := t.TempDir() + path := filepath.Join(dir, "cache.json") + if err := os.WriteFile(path, []byte("{garbage"), 0o644); err != nil { + t.Fatal(err) + } + c2, err := loadCategoryCache(path) + if err == nil { + t.Error("expected parse error so operator can investigate") + } + if c2 == nil || c2.Entries == nil { + t.Error("corrupt file should still return usable empty cache") + } +} + +func TestCategoryCacheKey_Stable(t *testing.T) { + r := CategoryRow{MarktName: "Markt X", Stadt: "Dresden", Year: 2026} + k1 := categoryCacheKey(r, "m") + k2 := categoryCacheKey(r, "m") + if k1 != k2 { + t.Error("cache key must be deterministic") + } + // Different year → different key. + r2 := r + r2.Year = 2027 + if categoryCacheKey(r2, "m") == k1 { + t.Error("year change should produce different key") + } +} diff --git a/backend/cmd/discovery-eval/fixtures/category.json b/backend/cmd/discovery-eval/fixtures/category.json new file mode 100644 index 0000000..b64074c --- /dev/null +++ b/backend/cmd/discovery-eval/fixtures/category.json @@ -0,0 +1,94 @@ +{ + "rows": [ + { + "markt_name": "Mittelaltermarkt Dresden", + "stadt": "Dresden", + "land": "Deutschland", + "year": 2026, + "quellen": ["https://www.marktkalendarium.de/markt/dresden-mittelaltermarkt"], + "expected_category": "mittelaltermarkt", + "note": "baseline: unambiguous Mittelaltermarkt" + }, + { + "markt_name": "Dresdner Striezelmarkt", + "stadt": "Dresden", + "land": "Deutschland", + "year": 2026, + "quellen": ["https://striezelmarkt.dresden.de/"], + "expected_category": "weihnachtsmarkt", + "note": "historical proper name; LLM must recognise Striezelmarkt as Christmas market" + }, + { + "markt_name": "Kaiser-Ludwig-Markt Landsberg", + "stadt": "Landsberg am Lech", + "land": "Deutschland", + "year": 2026, + "quellen": ["https://www.landsberg.de/kaiser-ludwig-markt"], + "expected_category": "mittelaltermarkt", + "note": "themed medieval market" + }, + { + "markt_name": "Ritterfest Burg Stolpen", + "stadt": "Stolpen", + "land": "Deutschland", + "year": 2026, + "quellen": ["https://www.burg-stolpen.org/ritterfest"], + "expected_category": "ritterfest", + "note": "knight-themed festival, not general market" + }, + { + "markt_name": "Weihnachtsmarkt Nürnberg", + "stadt": "Nürnberg", + "land": "Deutschland", + "year": 2026, + "quellen": ["https://www.christkindlesmarkt.de/"], + "expected_category": "weihnachtsmarkt", + "note": "the canonical German christmas market" + }, + { + "markt_name": "Handwerkermarkt Rothenburg", + "stadt": "Rothenburg ob der Tauber", + "land": "Deutschland", + "year": 2026, + "quellen": ["https://www.rothenburg.de/handwerkermarkt"], + "expected_category": "handwerkermarkt", + "note": "craft market; not medieval-themed" + }, + { + "markt_name": "Schlossfest Schönbrunn", + "stadt": "Wien", + "land": "Oesterreich", + "year": 2026, + "quellen": ["https://www.schoenbrunn.at/schlossfest"], + "expected_category": "schlossfest", + "note": "castle festival, broader than market — LLM should not default to mittelaltermarkt" + }, + { + "markt_name": "Ritterturnier Burg Kreuzenstein", + "stadt": "Leobendorf", + "land": "Oesterreich", + "year": 2026, + "quellen": ["https://www.burg-kreuzenstein.com/ritterturnier"], + "expected_category": "ritterturnier", + "note": "jousting event; distinct from market / fest" + }, + { + "markt_name": "Kirchweih Fürth", + "stadt": "Fürth", + "land": "Deutschland", + "year": 2026, + "quellen": ["https://www.fuerth.de/kirchweih"], + "expected_category": "kirchweih", + "note": "traditional Bavarian parish fair" + }, + { + "markt_name": "Mittelaltermarkt auf der Ronneburg", + "stadt": "Ronneburg", + "land": "Deutschland", + "year": 2026, + "quellen": ["https://www.ronneburg.de/mittelaltermarkt"], + "expected_category": "mittelaltermarkt", + "note": "venue-prefixed mittelaltermarkt" + } + ] +} diff --git a/backend/cmd/discovery-eval/main.go b/backend/cmd/discovery-eval/main.go index 1890896..cff7eef 100644 --- a/backend/cmd/discovery-eval/main.go +++ b/backend/cmd/discovery-eval/main.go @@ -1,18 +1,24 @@ -// discovery-eval measures the MistralSimilarityClassifier against a labelled -// fixture. Reports precision/recall/F1/accuracy + a confidence calibration -// table, optionally gated on an F1 threshold for CI use. +// discovery-eval measures discovery's AI-backed components against labelled +// fixtures. Two modes: +// +// -mode similarity (default) — grades MistralSimilarityClassifier on +// pair-labelled fixtures. Precision/recall/F1/accuracy +// + confidence calibration. +// -mode category — grades MistralLLMEnricher's `category` output on +// row-labelled fixtures. Accuracy + per-label confusion. // // Usage: // // AI_API_KEY=... AI_MODEL_COMPLEX=mistral-large-latest \ // discovery-eval \ +// -mode similarity \ // -fixture backend/cmd/discovery-eval/fixtures/similarity.json \ // -cache .eval-cache.json \ // -threshold 0.8 \ // -report eval-report.json // -// The cache file is keyed on (pair_key, model) — rerunning against the same -// model+fixtures is free. Bump the model or edit a fixture to force a refresh. +// Each mode has its own cache key so switching modes doesn't churn entries. +// Bump AI_MODEL_COMPLEX or edit a fixture to force a refresh. package main import ( @@ -25,6 +31,12 @@ import ( "marktvogt.de/backend/internal/domain/discovery/enrich" "marktvogt.de/backend/internal/pkg/ai" + "marktvogt.de/backend/internal/pkg/scrape" +) + +const ( + modeSimilarity = "similarity" + modeCategory = "category" ) // realMain returns the desired exit code. Kept separate from main() so @@ -36,63 +48,144 @@ func realMain() int { }))) var ( - fixturePath = flag.String("fixture", "backend/cmd/discovery-eval/fixtures/similarity.json", "path to labelled fixture JSON") + mode = flag.String("mode", modeSimilarity, "eval mode: similarity | category") + fixturePath = flag.String("fixture", "", "path to labelled fixture JSON (defaults per mode)") cachePath = flag.String("cache", ".eval-cache.json", "path to local verdict cache (gitignored)") reportPath = flag.String("report", "", "optional path to write machine-readable JSON report") - threshold = flag.Float64("threshold", 0.0, "fail (exit 1) when F1 is below this value; 0 disables gating") + threshold = flag.Float64("threshold", 0.0, "fail (exit 1) when F1/accuracy is below this value; 0 disables gating") ) flag.Parse() - fixture, err := loadFixture(*fixturePath) - if err != nil { - slog.Error("load fixture", "path", *fixturePath, "error", err) - return 2 - } - slog.Info("loaded fixture", "pairs", len(fixture.Pairs), "path", *fixturePath) - apiKey := os.Getenv("AI_API_KEY") model := os.Getenv("AI_MODEL_COMPLEX") if model == "" { model = "mistral-large-latest" } + userAgent := os.Getenv("AI_USER_AGENT") + if userAgent == "" { + userAgent = "marktvogt-eval/1.0 (+https://marktvogt.de)" + } client := ai.New(apiKey, "", model, 1.0) if !client.Enabled() { slog.Error("AI client not configured (set AI_API_KEY)") return 2 } + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + + switch *mode { + case modeSimilarity: + return runSimilarityMode(ctx, client, model, pathWithDefault(*fixturePath, "backend/cmd/discovery-eval/fixtures/similarity.json"), *cachePath, *reportPath, *threshold) + case modeCategory: + scraper := scrape.New(userAgent) + enricher := enrich.NewMistralLLMEnricher(client, scraper) + return runCategoryMode(ctx, enricher, model, pathWithDefault(*fixturePath, "backend/cmd/discovery-eval/fixtures/category.json"), *cachePath, *reportPath, *threshold) + default: + slog.Error("unknown mode", "mode", *mode, "valid", []string{modeSimilarity, modeCategory}) + return 2 + } +} + +func pathWithDefault(p, dflt string) string { + if p == "" { + return dflt + } + return p +} + +// runSimilarityMode is the original MR 5 eval path, lifted out of main() so +// the mode switch stays readable. +func runSimilarityMode( + ctx context.Context, + client *ai.Client, + model, fixturePath, cachePath, reportPath string, + threshold float64, +) int { + fixture, err := loadFixture(fixturePath) + if err != nil { + slog.Error("load fixture", "path", fixturePath, "error", err) + return 2 + } + slog.Info("loaded fixture", "mode", modeSimilarity, "pairs", len(fixture.Pairs), "path", fixturePath) + classifier := enrich.NewMistralSimilarityClassifier(client) - cache, err := loadCache(*cachePath) + cache, err := loadCache(cachePath) if err != nil { - slog.Warn("cache load failed; starting empty", "path", *cachePath, "error", err) + slog.Warn("cache load failed; starting empty", "path", cachePath, "error", err) cache = newCache() } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) - defer cancel() - results, err := run(ctx, classifier, cache, fixture, model) if err != nil { slog.Error("eval run failed", "error", err) return 2 } - if err := saveCache(*cachePath, cache); err != nil { - // Non-fatal — results still computed; next run just pays again. - slog.Warn("cache save failed; metrics still reported", "path", *cachePath, "error", err) + if err := saveCache(cachePath, cache); err != nil { + slog.Warn("cache save failed; metrics still reported", "path", cachePath, "error", err) } metrics := computeMetrics(results) printSummary(os.Stdout, results, metrics, model) - if *reportPath != "" { - if err := writeReport(*reportPath, results, metrics, model); err != nil { - slog.Warn("report write failed", "path", *reportPath, "error", err) + if reportPath != "" { + if err := writeReport(reportPath, results, metrics, model); err != nil { + slog.Warn("report write failed", "path", reportPath, "error", err) } } - if *threshold > 0 && metrics.F1 < *threshold { - fmt.Fprintf(os.Stderr, "\nFAIL: F1=%.3f < threshold=%.3f\n", metrics.F1, *threshold) + if threshold > 0 && metrics.F1 < threshold { + fmt.Fprintf(os.Stderr, "\nFAIL: F1=%.3f < threshold=%.3f\n", metrics.F1, threshold) + return 1 + } + return 0 +} + +// runCategoryMode grades MistralLLMEnricher's category field against a +// labelled fixture. Uses its own cache shape (CategoryCache) so the +// similarity and category runs don't collide on disk. +func runCategoryMode( + ctx context.Context, + enricher enrich.LLMEnricher, + model, fixturePath, cachePath, reportPath string, + threshold float64, +) int { + fixture, err := loadCategoryFixture(fixturePath) + if err != nil { + slog.Error("load fixture", "path", fixturePath, "error", err) + return 2 + } + slog.Info("loaded fixture", "mode", modeCategory, "rows", len(fixture.Rows), "path", fixturePath) + + cache, err := loadCategoryCache(cachePath) + if err != nil { + slog.Warn("cache load failed; starting empty", "path", cachePath, "error", err) + cache = newCategoryCache() + } + + results, err := runCategory(ctx, enricher, cache, fixture, model) + if err != nil { + slog.Error("eval run failed", "error", err) + return 2 + } + + if err := saveCategoryCache(cachePath, cache); err != nil { + slog.Warn("cache save failed; metrics still reported", "path", cachePath, "error", err) + } + + metrics := computeCategoryMetrics(results) + printCategorySummary(os.Stdout, results, metrics, model) + + if reportPath != "" { + if err := writeCategoryReport(reportPath, results, metrics, model); err != nil { + slog.Warn("report write failed", "path", reportPath, "error", err) + } + } + + if threshold > 0 && metrics.Accuracy < threshold { + fmt.Fprintf(os.Stderr, "\nFAIL: accuracy=%.3f < threshold=%.3f\n", metrics.Accuracy, threshold) return 1 } return 0