126cc58cbf
- Extract readJSONFile + writeJSONAtomic in cache.go; category cache reuses them (saveCategoryCache is one line, loadCategoryCache uses the standard load-or-empty shape). - Drop dead errMsg param from scoreCategoryResult (always ""). - Wrap writeCategoryReport errors with context for consistency. - Wrap runSimilarityMode / runCategoryMode's 5 per-mode flags into an evalConfig struct so params don't drift. - Promote validModes to a package-level var. - Remove redundant cache = new...() fallback after load* (both load helpers already return a non-nil empty cache on error). - Strip narrating / diff-referencing comments per CLAUDE.md; keep the one genuine WHY on normalizeCategory (divergence from normalize.Name). Net -54 lines across 4 files; go build + go vet + tests green.
292 lines
7.4 KiB
Go
292 lines
7.4 KiB
Go
package main
|
||
|
||
import (
|
||
"context"
|
||
"crypto/sha256"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"log/slog"
|
||
"os"
|
||
"sort"
|
||
"strings"
|
||
|
||
"marktvogt.de/backend/internal/domain/discovery/enrich"
|
||
)
|
||
|
||
type CategoryFixture struct {
|
||
Rows []CategoryRow `json:"rows"`
|
||
}
|
||
|
||
// CategoryRow is a single fixture case: the market fields used to build the
// enrichment request and the cache key, plus the category label expected back.
type CategoryRow struct {
	// Market identity. MarktName, Stadt and Year feed categoryCacheKey;
	// MarktName, Stadt, Bundesland and Land are copied into the LLM request.
	MarktName  string `json:"markt_name"`
	Stadt      string `json:"stadt"`
	Bundesland string `json:"bundesland,omitempty"`
	Land       string `json:"land,omitempty"`
	Year       int    `json:"year,omitempty"`

	// Quellen is handed to the enrichment request unchanged.
	Quellen []string `json:"quellen"`

	// ExpectedCategory is the label a prediction is compared against, after
	// normalizeCategory has been applied to both sides.
	ExpectedCategory string `json:"expected_category"`

	// Note is free-form fixture commentary; nothing in this file reads it.
	Note string `json:"note,omitempty"`
}
|
||
|
||
func loadCategoryFixture(path string) (*CategoryFixture, error) {
|
||
data, err := os.ReadFile(path)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("read fixture: %w", err)
|
||
}
|
||
var f CategoryFixture
|
||
if err := json.Unmarshal(data, &f); err != nil {
|
||
return nil, fmt.Errorf("parse fixture: %w", err)
|
||
}
|
||
if len(f.Rows) == 0 {
|
||
return nil, fmt.Errorf("fixture has no rows")
|
||
}
|
||
return &f, nil
|
||
}
|
||
|
||
type CategoryCache struct {
|
||
Entries map[string]enrich.Enrichment `json:"entries"`
|
||
}
|
||
|
||
func newCategoryCache() *CategoryCache {
|
||
return &CategoryCache{Entries: map[string]enrich.Enrichment{}}
|
||
}
|
||
|
||
func categoryCacheKey(r CategoryRow, model string) string {
|
||
raw := fmt.Sprintf("%s|%s|%d|%s",
|
||
strings.ToLower(r.MarktName), strings.ToLower(r.Stadt), r.Year, model)
|
||
sum := sha256.Sum256([]byte(raw))
|
||
return hex.EncodeToString(sum[:])
|
||
}
|
||
|
||
func (c *CategoryCache) Get(r CategoryRow, model string) (enrich.Enrichment, bool) {
|
||
v, ok := c.Entries[categoryCacheKey(r, model)]
|
||
return v, ok
|
||
}
|
||
|
||
func (c *CategoryCache) Put(r CategoryRow, model string, v enrich.Enrichment) {
|
||
c.Entries[categoryCacheKey(r, model)] = v
|
||
}
|
||
|
||
func loadCategoryCache(path string) (*CategoryCache, error) {
|
||
c := newCategoryCache()
|
||
exists, err := readJSONFile(path, c)
|
||
if err != nil {
|
||
return newCategoryCache(), fmt.Errorf("parse cache (starting empty): %w", err)
|
||
}
|
||
if !exists {
|
||
return c, nil
|
||
}
|
||
if c.Entries == nil {
|
||
c.Entries = map[string]enrich.Enrichment{}
|
||
}
|
||
return c, nil
|
||
}
|
||
|
||
func saveCategoryCache(path string, c *CategoryCache) error {
|
||
return writeJSONAtomic(path, ".cat-cache-*.tmp", c)
|
||
}
|
||
|
||
// categoryUmlautFolder rewrites German umlauts and ß to plain ASCII. It is
// built once at package scope: a strings.Replacer is safe for concurrent use,
// so there is no need to construct a new one on every normalizeCategory call.
var categoryUmlautFolder = strings.NewReplacer(
	"ä", "a",
	"ö", "o",
	"ü", "u",
	"ß", "ss",
)

// normalizeCategory turns a category label into a forgiving comparison key.
// The input is lower-cased, surrounding whitespace is trimmed, umlauts are
// folded (ä→a, ö→o, ü→u, ß→ss), and a plural ending in "-märkte" is reduced to
// the singular "-markt".
//
// This deliberately diverges from discovery/normalize.Name, which maps ä→ae
// for identity dedup. Category labels are free-form enough that maximum
// forgiveness is wanted here.
func normalizeCategory(s string) string {
	s = strings.TrimSpace(strings.ToLower(s))
	s = categoryUmlautFolder.Replace(s)
	// After folding, "märkte" has become "markte"; dropping the final "e"
	// yields the singular form.
	if strings.HasSuffix(s, "markte") {
		s = strings.TrimSuffix(s, "e")
	}
	return s
}
|
||
|
||
type CategoryResult struct {
|
||
Row CategoryRow
|
||
Got string
|
||
Want string
|
||
Correct bool
|
||
FromCache bool
|
||
Err string
|
||
}
|
||
|
||
func runCategory(
|
||
ctx context.Context,
|
||
enricher enrich.LLMEnricher,
|
||
cache *CategoryCache,
|
||
fixture *CategoryFixture,
|
||
model string,
|
||
) ([]CategoryResult, error) {
|
||
results := make([]CategoryResult, 0, len(fixture.Rows))
|
||
for i, r := range fixture.Rows {
|
||
if err := ctx.Err(); err != nil {
|
||
return results, err
|
||
}
|
||
|
||
if v, ok := cache.Get(r, model); ok {
|
||
results = append(results, scoreCategoryResult(r, v, true))
|
||
continue
|
||
}
|
||
|
||
req := enrich.LLMRequest{
|
||
MarktName: r.MarktName,
|
||
Stadt: r.Stadt,
|
||
Land: r.Land,
|
||
Bundesland: r.Bundesland,
|
||
Quellen: r.Quellen,
|
||
}
|
||
got, err := enricher.EnrichMissing(ctx, req)
|
||
if err != nil {
|
||
slog.Warn("enrich failed; scoring as incorrect",
|
||
"row_index", i, "markt", r.MarktName, "error", err)
|
||
results = append(results, CategoryResult{
|
||
Row: r,
|
||
Want: r.ExpectedCategory,
|
||
Err: err.Error(),
|
||
})
|
||
continue
|
||
}
|
||
cache.Put(r, model, got)
|
||
results = append(results, scoreCategoryResult(r, got, false))
|
||
}
|
||
return results, nil
|
||
}
|
||
|
||
func scoreCategoryResult(r CategoryRow, got enrich.Enrichment, fromCache bool) CategoryResult {
|
||
gotCat := strings.TrimSpace(got.Category)
|
||
return CategoryResult{
|
||
Row: r,
|
||
Got: gotCat,
|
||
Want: r.ExpectedCategory,
|
||
Correct: normalizeCategory(gotCat) == normalizeCategory(r.ExpectedCategory),
|
||
FromCache: fromCache,
|
||
}
|
||
}
|
||
|
||
// CategoryMetrics is the aggregate summary of an evaluation run, as filled in
// by computeCategoryMetrics.
type CategoryMetrics struct {
	// Total is the number of results, failed enricher calls included.
	Total int `json:"total"`

	// Correct and Incorrect count the results without an error, split by
	// whether the prediction matched. Errors counts the results that carry
	// an error. The three add up to Total.
	Correct   int `json:"correct"`
	Incorrect int `json:"incorrect"`
	Errors    int `json:"errors"`

	// CacheHits counts results served from the cache; LLMCalls counts all
	// others, including calls that failed.
	CacheHits int `json:"cache_hits"`
	LLMCalls  int `json:"llm_calls"`

	// Confusion maps normalized expected label -> normalized predicted
	// label -> count, over the results without an error.
	Confusion map[string]map[string]int `json:"confusion"`

	// Accuracy is Correct divided by Total (0 when there are no results),
	// so errored rows lower it.
	Accuracy float64 `json:"accuracy"`
}
|
||
|
||
func computeCategoryMetrics(results []CategoryResult) CategoryMetrics {
|
||
m := CategoryMetrics{
|
||
Total: len(results),
|
||
Confusion: map[string]map[string]int{},
|
||
}
|
||
if m.Total == 0 {
|
||
return m
|
||
}
|
||
for _, r := range results {
|
||
if r.FromCache {
|
||
m.CacheHits++
|
||
} else {
|
||
m.LLMCalls++
|
||
}
|
||
if r.Err != "" {
|
||
m.Errors++
|
||
continue
|
||
}
|
||
want := normalizeCategory(r.Want)
|
||
got := normalizeCategory(r.Got)
|
||
if r.Correct {
|
||
m.Correct++
|
||
} else {
|
||
m.Incorrect++
|
||
}
|
||
if _, ok := m.Confusion[want]; !ok {
|
||
m.Confusion[want] = map[string]int{}
|
||
}
|
||
m.Confusion[want][got]++
|
||
}
|
||
m.Accuracy = float64(m.Correct) / float64(m.Total)
|
||
return m
|
||
}
|
||
|
||
func printCategorySummary(w io.Writer, results []CategoryResult, m CategoryMetrics, model string) {
|
||
wf(w, "\n=== discovery-eval (category) ===\n")
|
||
wf(w, "model: %s\n", model)
|
||
wf(w, "rows: %d\n", m.Total)
|
||
wf(w, "cache: %d hits, %d llm calls\n", m.CacheHits, m.LLMCalls)
|
||
wf(w, "\n")
|
||
wf(w, "correct: %d\n", m.Correct)
|
||
wf(w, "incorrect: %d\n", m.Incorrect)
|
||
wf(w, "errors: %d\n", m.Errors)
|
||
wf(w, "accuracy: %.3f\n", m.Accuracy)
|
||
|
||
if len(m.Confusion) > 0 {
|
||
labels := make([]string, 0, len(m.Confusion))
|
||
for k := range m.Confusion {
|
||
labels = append(labels, k)
|
||
}
|
||
sort.Strings(labels)
|
||
wf(w, "\nconfusion (want → predictions):\n")
|
||
for _, want := range labels {
|
||
preds := m.Confusion[want]
|
||
predKeys := make([]string, 0, len(preds))
|
||
for k := range preds {
|
||
predKeys = append(predKeys, k)
|
||
}
|
||
sort.Strings(predKeys)
|
||
wf(w, " %-24s", want)
|
||
first := true
|
||
for _, p := range predKeys {
|
||
if !first {
|
||
wf(w, ", ")
|
||
}
|
||
marker := ""
|
||
if p != want {
|
||
marker = "!"
|
||
}
|
||
wf(w, "%s%s×%d", marker, p, preds[p])
|
||
first = false
|
||
}
|
||
wf(w, "\n")
|
||
}
|
||
}
|
||
|
||
var wrong []CategoryResult
|
||
for _, r := range results {
|
||
if !r.Correct || r.Err != "" {
|
||
wrong = append(wrong, r)
|
||
}
|
||
}
|
||
if len(wrong) > 0 {
|
||
wf(w, "\nmismatches (%d):\n", len(wrong))
|
||
for _, r := range wrong {
|
||
if r.Err != "" {
|
||
wf(w, " ERROR %q (%s): %s\n", r.Row.MarktName, r.Row.Stadt, r.Err)
|
||
continue
|
||
}
|
||
wf(w, " want=%q got=%q %q (%s)\n", r.Want, r.Got, r.Row.MarktName, r.Row.Stadt)
|
||
}
|
||
}
|
||
wf(w, "\n")
|
||
}
|
||
|
||
type CategoryReport struct {
|
||
Mode string `json:"mode"`
|
||
Model string `json:"model"`
|
||
Metrics CategoryMetrics `json:"metrics"`
|
||
Results []CategoryResult `json:"results"`
|
||
}
|
||
|
||
func writeCategoryReport(path string, results []CategoryResult, m CategoryMetrics, model string) error {
|
||
rep := CategoryReport{Mode: "category", Model: model, Metrics: m, Results: results}
|
||
data, err := json.MarshalIndent(rep, "", " ")
|
||
if err != nil {
|
||
return fmt.Errorf("marshal report: %w", err)
|
||
}
|
||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||
return fmt.Errorf("write report %s: %w", path, err)
|
||
}
|
||
return nil
|
||
}
|