From 6b3c673cd020f45693fd811f7ed8904b6468bfd8 Mon Sep 17 00:00:00 2001 From: vikingowl Date: Sat, 25 Apr 2026 12:42:53 +0200 Subject: [PATCH 1/3] feat(ai): tighter Gemini model filter with per-model pricing - Replace ListModelNames with ListModels returning ModelInfo structs - Name-based filter: require gemini- prefix, drop tuned models, block EOL 2.0 family, TTS/image/live/audio/robotics/embedding, Gemma/Imagen/Veo - Static pricing table with longest-prefix match; stable vs preview flag - Settings handler validates SetModel against allowed list (degrade-open) - Frontend dropdown shows input/output price per 1M tokens + Preview tag - Table-driven unit tests for filter, sort order and pricing lookup --- backend/internal/domain/settings/handler.go | 34 ++-- backend/internal/pkg/ai/gemini.go | 103 +++++++++-- backend/internal/pkg/ai/gemini_test.go | 170 ++++++++++++++++++ backend/internal/pkg/ai/provider.go | 2 +- web/src/lib/api/types.ts | 12 +- .../routes/admin/einstellungen/+page.svelte | 6 +- 6 files changed, 303 insertions(+), 24 deletions(-) create mode 100644 backend/internal/pkg/ai/gemini_test.go diff --git a/backend/internal/domain/settings/handler.go b/backend/internal/domain/settings/handler.go index ca6f444..949c4ae 100644 --- a/backend/internal/domain/settings/handler.go +++ b/backend/internal/domain/settings/handler.go @@ -12,14 +12,14 @@ import ( // AIStatus is the response payload for GET /admin/settings/ai. type AIStatus struct { - Provider string `json:"provider"` - Connected bool `json:"connected"` - Model string `json:"model"` - Models []string `json:"models"` - APIKeyFingerprint string `json:"api_key_fingerprint,omitempty"` - GroundingEnabled bool `json:"grounding_enabled"` - GroundingQuota int `json:"grounding_quota"` - Usage UsageSummary `json:"usage"` + Provider string `json:"provider"` + Connected bool `json:"connected"` + Model string `json:"model"` + Models []ai.ModelInfo `json:"models"` + APIKeyFingerprint string `json:"api_key_fingerprint,omitempty"` + GroundingEnabled bool `json:"grounding_enabled"` + GroundingQuota int `json:"grounding_quota"` + Usage UsageSummary `json:"usage"` } type UsageSummary struct { @@ -42,10 +42,10 @@ func NewHandler(provider *ai.GeminiProvider, store *Store, usageRepo *UsageRepo) func (h *Handler) GetAI(c *gin.Context) { ctx := c.Request.Context() - models, err := h.provider.ListModelNames(ctx) + models, err := h.provider.ListModels(ctx) connected := err == nil if models == nil { - models = []string{} + models = []ai.ModelInfo{} } // Fingerprint: last 4 chars of stored key (if any) @@ -85,6 +85,20 @@ func (h *Handler) SetModel(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"}) return } + // Validate against allowed list; degrade open if the list call fails (e.g. network blip). + if allowed, err := h.provider.ListModels(ctx); err == nil { + found := false + for _, m := range allowed { + if m.Name == req.Model { + found = true + break + } + } + if !found { + c.JSON(http.StatusBadRequest, gin.H{"error": "model not in allowed list"}) + return + } + } userID := callerID(c) if err := h.store.SetModel(ctx, req.Model, userID); err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save model"}) diff --git a/backend/internal/pkg/ai/gemini.go b/backend/internal/pkg/ai/gemini.go index fa4e1eb..d0a6f1b 100644 --- a/backend/internal/pkg/ai/gemini.go +++ b/backend/internal/pkg/ai/gemini.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "sort" "strings" "sync" "time" @@ -11,6 +12,95 @@ import ( "google.golang.org/genai" ) +// ModelInfo describes a Gemini model available for use. +type ModelInfo struct { + Name string `json:"name"` + DisplayName string `json:"display_name"` + Stable bool `json:"stable"` + Thinking bool `json:"thinking"` + InputTokenLimit int32 `json:"input_token_limit"` + InputUSDPerM float64 `json:"input_usd_per_m"` + OutputUSDPerM float64 `json:"output_usd_per_m"` +} + +// geminiPricing maps model name prefixes to $/1M token rates (≤200k tier). +// Source: https://ai.google.dev/gemini-api/docs/pricing — update manually when prices change. +var geminiPricing = map[string]struct{ in, out float64 }{ + "gemini-3.1-pro": {2.00, 12.00}, + "gemini-3.1-flash-lite": {0.25, 1.50}, + "gemini-3-flash": {0.50, 3.00}, + "gemini-2.5-pro": {1.25, 10.00}, + "gemini-2.5-flash-lite": {0.10, 0.40}, + "gemini-2.5-flash": {0.30, 2.50}, +} + +// priceFor returns the $/1M input and output token cost for the given model name. +// Uses longest-prefix match against geminiPricing; returns (0, 0) if unknown. +func priceFor(name string) (in, out float64) { + best := "" + for prefix, p := range geminiPricing { + if strings.HasPrefix(name, prefix) && len(prefix) > len(best) { + best = prefix + in, out = p.in, p.out + } + } + return +} + +// filterCompatibleModels selects models that work with our request shape: +// generateContent + systemInstruction + responseSchema + optional googleSearchRetrieval. +// +// We rely on name-based filtering rather than SupportedActions because the Gemini +// public API omits supportedGenerationMethods for stable text models, leaving +// SupportedActions empty even for fully compatible models like gemini-2.5-flash. +func filterCompatibleModels(items []*genai.Model) []ModelInfo { + blockedSubstrings := []string{ + "-tts", "-image", "-native-audio", "-live", + "-computer-use", "-robotics", "-embedding", + } + out := make([]ModelInfo, 0, len(items)) + for _, m := range items { + if m.TunedModelInfo != nil { + continue + } + name := strings.TrimPrefix(m.Name, "models/") + if !strings.HasPrefix(name, "gemini-") { + continue + } + if strings.HasPrefix(name, "gemini-2.0-") { + continue + } + blocked := false + for _, sub := range blockedSubstrings { + if strings.Contains(name, sub) { + blocked = true + break + } + } + if blocked { + continue + } + in, outP := priceFor(name) + out = append(out, ModelInfo{ + Name: name, + DisplayName: m.DisplayName, + Stable: !strings.Contains(name, "-preview"), + Thinking: m.Thinking, + InputTokenLimit: m.InputTokenLimit, + InputUSDPerM: in, + OutputUSDPerM: outP, + }) + } + // Stable models first; within each group sort by name descending (newer families first). + sort.Slice(out, func(i, j int) bool { + if out[i].Stable != out[j].Stable { + return out[i].Stable + } + return out[i].Name > out[j].Name + }) + return out +} + // Gemini API pricing (as of 2026-04). Refresh constants when pricing changes. // https://ai.google.dev/gemini-api/docs/pricing const ( @@ -91,7 +181,7 @@ func (p *GeminiProvider) SetModel(model string) { p.model = model } -func (p *GeminiProvider) ListModelNames(ctx context.Context) ([]string, error) { +func (p *GeminiProvider) ListModels(ctx context.Context) ([]ModelInfo, error) { p.mu.RLock() client := p.client p.mu.RUnlock() @@ -102,16 +192,7 @@ func (p *GeminiProvider) ListModelNames(ctx context.Context) ([]string, error) { if err != nil { return nil, fmt.Errorf("gemini: list models: %w", err) } - var names []string - for _, m := range resp.Items { - for _, action := range m.SupportedActions { - if action == "generateContent" { - names = append(names, strings.TrimPrefix(m.Name, "models/")) - break - } - } - } - return names, nil + return filterCompatibleModels(resp.Items), nil } func (p *GeminiProvider) Chat(ctx context.Context, req *ChatRequest) (*ChatResponse, error) { diff --git a/backend/internal/pkg/ai/gemini_test.go b/backend/internal/pkg/ai/gemini_test.go new file mode 100644 index 0000000..e83861b --- /dev/null +++ b/backend/internal/pkg/ai/gemini_test.go @@ -0,0 +1,170 @@ +package ai + +import ( + "strings" + "testing" + + "google.golang.org/genai" +) + +func makeModel(name string, tuned, thinking bool) *genai.Model { + m := &genai.Model{ + Name: "models/" + name, + DisplayName: name, + Thinking: thinking, + InputTokenLimit: 32000, + } + if tuned { + m.TunedModelInfo = &genai.TunedModelInfo{} + } + return m +} + +func TestFilterCompatibleModels_KeepsGeminiTextFamilies(t *testing.T) { + // SupportedActions intentionally nil — stable text models don't have it populated. + kept := []*genai.Model{ + makeModel("gemini-2.5-pro", false, true), + makeModel("gemini-2.5-flash", false, false), + makeModel("gemini-2.5-flash-lite", false, false), + makeModel("gemini-3-flash-preview-04-2026", false, false), + makeModel("gemini-3.1-pro-preview-04-2026", false, true), + makeModel("gemini-3.1-flash-lite-preview-04-2026", false, false), + } + got := filterCompatibleModels(kept) + if len(got) != len(kept) { + t.Errorf("want %d models, got %d: %v", len(kept), len(got), modelNames(got)) + } +} + +func TestFilterCompatibleModels_DropsExcludedFamilies(t *testing.T) { + // SupportedActions is intentionally ignored — the Gemini public API omits + // supportedGenerationMethods for stable text models. Name-based filtering is + // the reliable gate; these cases verify every blocked category by name. + cases := []struct { + name string + model *genai.Model + }{ + {"tts", makeModel("gemini-2.5-flash-preview-tts", false, false)}, + {"pro tts", makeModel("gemini-2.5-pro-preview-tts", false, false)}, + {"image", makeModel("gemini-2.5-flash-image", false, false)}, + {"native audio", makeModel("gemini-2.5-flash-native-audio-preview-12-2025", false, false)}, + {"live", makeModel("gemini-2.5-flash-live-preview", false, false)}, + {"computer use", makeModel("gemini-2.5-computer-use-preview-10-2025", false, false)}, + {"robotics", makeModel("gemini-robotics-er-1.6-preview", false, false)}, + {"embedding", makeModel("gemini-embedding-001", false, false)}, + {"gemma", makeModel("gemma-3-27b-it", false, false)}, + {"gemma nano", makeModel("gemma-3n-e4b-it", false, false)}, + {"deep research", makeModel("deep-research-preview-04-2026", false, false)}, + {"deep research max", makeModel("deep-research-max-preview-04-2026", false, false)}, + {"imagen", makeModel("imagen-3.0-generate-001", false, false)}, + {"veo", makeModel("veo-3.1-generate-preview", false, false)}, + {"lyria", makeModel("lyria-realtime-exp", false, false)}, + {"learnlm", makeModel("learnlm-2.0-flash-experimental", false, false)}, + {"gemini 2.0 flash eol", makeModel("gemini-2.0-flash", false, false)}, + {"gemini 2.0 flash lite eol", makeModel("gemini-2.0-flash-lite", false, false)}, + {"tuned model", makeModel("gemini-2.5-flash", true, false)}, + } + for _, tc := range cases { + got := filterCompatibleModels([]*genai.Model{tc.model}) + if len(got) != 0 { + t.Errorf("case %q: want 0 models, got %d: %v", tc.name, len(got), modelNames(got)) + } + } +} + +func TestFilterCompatibleModels_StableField(t *testing.T) { + items := []*genai.Model{ + makeModel("gemini-2.5-flash", false, false), + makeModel("gemini-3-flash-preview-04-2026", false, false), + } + got := filterCompatibleModels(items) + if len(got) != 2 { + t.Fatalf("want 2, got %d", len(got)) + } + for _, m := range got { + expectStable := !strings.Contains(m.Name, "-preview") + if m.Stable != expectStable { + t.Errorf("model %q: Stable=%v, want %v", m.Name, m.Stable, expectStable) + } + } +} + +func TestFilterCompatibleModels_ThinkingField(t *testing.T) { + items := []*genai.Model{ + makeModel("gemini-2.5-pro", false, true), + makeModel("gemini-2.5-flash", false, false), + } + got := filterCompatibleModels(items) + if len(got) != 2 { + t.Fatalf("want 2, got %d", len(got)) + } + // find by name + for _, m := range got { + if m.Name == "gemini-2.5-pro" && !m.Thinking { + t.Errorf("gemini-2.5-pro: want Thinking=true") + } + if m.Name == "gemini-2.5-flash" && m.Thinking { + t.Errorf("gemini-2.5-flash: want Thinking=false") + } + } +} + +func TestFilterCompatibleModels_SortStableFirst(t *testing.T) { + items := []*genai.Model{ + makeModel("gemini-3-flash-preview-04-2026", false, false), + makeModel("gemini-2.5-pro", false, true), + makeModel("gemini-3.1-pro-preview-04-2026", false, true), + makeModel("gemini-2.5-flash-lite", false, false), + } + got := filterCompatibleModels(items) + if len(got) != 4 { + t.Fatalf("want 4, got %d", len(got)) + } + // First two must be stable + if !got[0].Stable || !got[1].Stable { + t.Errorf("first two should be stable, got %v %v", got[0].Name, got[1].Name) + } + // Last two must be preview + if got[2].Stable || got[3].Stable { + t.Errorf("last two should be preview, got %v %v", got[2].Name, got[3].Name) + } +} + +func TestPriceFor_KnownFamilies(t *testing.T) { + cases := []struct { + name string + wantIn float64 + wantOut float64 + }{ + {"gemini-2.5-flash-lite", 0.10, 0.40}, + {"gemini-2.5-flash-lite-preview-05-2026", 0.10, 0.40}, + {"gemini-2.5-flash", 0.30, 2.50}, + {"gemini-2.5-flash-preview-04-2026", 0.30, 2.50}, + {"gemini-2.5-pro", 1.25, 10.00}, + {"gemini-2.5-pro-preview-06-2026", 1.25, 10.00}, + {"gemini-3-flash-preview-04-2026", 0.50, 3.00}, + {"gemini-3.1-pro-preview-04-2026", 2.00, 12.00}, + {"gemini-3.1-flash-lite-preview-04-2026", 0.25, 1.50}, + } + for _, tc := range cases { + in, out := priceFor(tc.name) + if in != tc.wantIn || out != tc.wantOut { + t.Errorf("priceFor(%q): got (%v, %v), want (%v, %v)", tc.name, in, out, tc.wantIn, tc.wantOut) + } + } +} + +func TestPriceFor_UnknownReturnsZero(t *testing.T) { + in, out := priceFor("gemini-99-experimental-unknown") + if in != 0 || out != 0 { + t.Errorf("want (0, 0), got (%v, %v)", in, out) + } +} + +func modelNames(ms []ModelInfo) []string { + names := make([]string, len(ms)) + for i, m := range ms { + names[i] = m.Name + } + return names +} diff --git a/backend/internal/pkg/ai/provider.go b/backend/internal/pkg/ai/provider.go index 3aca71b..c35c210 100644 --- a/backend/internal/pkg/ai/provider.go +++ b/backend/internal/pkg/ai/provider.go @@ -16,7 +16,7 @@ type Provider interface { type ModelSelector interface { Model() string SetModel(string) - ListModelNames(ctx context.Context) ([]string, error) + ListModels(ctx context.Context) ([]ModelInfo, error) BaseURL() string } diff --git a/web/src/lib/api/types.ts b/web/src/lib/api/types.ts index 6b034ac..32ab8b6 100644 --- a/web/src/lib/api/types.ts +++ b/web/src/lib/api/types.ts @@ -206,11 +206,21 @@ export interface AIUsageEvent { error?: string; } +export interface AIModelInfo { + name: string; + display_name: string; + stable: boolean; + thinking: boolean; + input_token_limit: number; + input_usd_per_m: number; + output_usd_per_m: number; +} + export interface AIStatus { provider: string; connected: boolean; model: string; - models: string[]; + models: AIModelInfo[]; api_key_fingerprint?: string; grounding_enabled: boolean; grounding_quota: number; diff --git a/web/src/routes/admin/einstellungen/+page.svelte b/web/src/routes/admin/einstellungen/+page.svelte index 544ddd7..66c7753 100644 --- a/web/src/routes/admin/einstellungen/+page.svelte +++ b/web/src/routes/admin/einstellungen/+page.svelte @@ -183,7 +183,11 @@ class="focus:border-primary-500 focus:ring-primary-500 rounded-md border border-stone-300 bg-white px-3 py-2 text-sm text-stone-900 shadow-sm focus:ring-1 focus:outline-none dark:border-stone-600 dark:bg-stone-800 dark:text-stone-100" > {#each data.ai.models as model} - + {/each}