From 8ed06ec574639a5f6e156aeb195db06b3f10fa5c Mon Sep 17 00:00:00 2001 From: vikingowl Date: Thu, 7 May 2026 22:27:24 +0200 Subject: [PATCH] feat: add dynamic model discovery within providers - OpenAI provider: use Models.ListAutoPaging() to discover available models - Anthropic provider: use Models.ListAutoPaging() to discover available models - Google provider: use Models.All() iterator to discover available models - All providers fall back to hardcoded lists if API calls fail - Add capability inference functions for each provider based on model ID - Add tests for model discovery fallback behavior This enables gnoma to dynamically discover new models as they become available from cloud providers, while maintaining backward compatibility with fallback lists for offline use or API failures. Generated by Mistral Vibe. Co-Authored-By: Mistral Vibe --- internal/provider/anthropic/models_test.go | 114 +++++++++++++++++++++ internal/provider/anthropic/provider.go | 68 +++++++++++- internal/provider/google/models_test.go | 44 ++++++++ internal/provider/google/provider.go | 61 ++++++++++- internal/provider/openai/models_test.go | 108 +++++++++++++++++++ internal/provider/openai/provider.go | 68 +++++++++++- 6 files changed, 453 insertions(+), 10 deletions(-) create mode 100644 internal/provider/anthropic/models_test.go create mode 100644 internal/provider/google/models_test.go create mode 100644 internal/provider/openai/models_test.go diff --git a/internal/provider/anthropic/models_test.go b/internal/provider/anthropic/models_test.go new file mode 100644 index 0000000..71d8062 --- /dev/null +++ b/internal/provider/anthropic/models_test.go @@ -0,0 +1,114 @@ +package anthropic + +import ( + "context" + "testing" + + "somegit.dev/Owlibou/gnoma/internal/provider" +) + +func TestModels_Fallback(t *testing.T) { + // Test with invalid API key - should fall back to hardcoded list + cfg := provider.ProviderConfig{ + APIKey: "invalid-key", + BaseURL: "https://api.anthropic.com/v1", + } + p, err := New(cfg) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + models, err := p.Models(context.Background()) + if err != nil { + t.Fatalf("Models() error = %v", err) + } + + // Should return fallback models + if len(models) == 0 { + t.Fatal("Models() returned empty list, expected fallback models") + } + + // Check that we have the expected fallback models + modelIDs := make(map[string]bool) + for _, m := range models { + modelIDs[m.ID] = true + } + + // Verify some expected models are present + expectedModels := []string{"claude-opus-4-20250514", "claude-sonnet-4-20250514", "claude-haiku-4-5-20251001"} + for _, expected := range expectedModels { + if !modelIDs[expected] { + t.Errorf("Expected model %q not found in fallback list", expected) + } + } +} + +func TestInferAnthropicModelCapabilities(t *testing.T) { + tests := []struct { + modelID string + want provider.Capabilities + }{ + { + modelID: "claude-opus-4-20250514", + want: provider.Capabilities{ + ToolUse: true, + JSONOutput: true, + Vision: true, + ThinkingModes: []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh}, + ContextWindow: 200000, + MaxOutput: 32000, + }, + }, + { + modelID: "claude-3-opus-20240229", + want: provider.Capabilities{ + ToolUse: true, + JSONOutput: true, + Vision: true, + ThinkingModes: []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh}, + ContextWindow: 200000, + MaxOutput: 4096, + }, + }, + { + modelID: "claude-2", + want: provider.Capabilities{ + ToolUse: true, + JSONOutput: true, + Vision: true, + ThinkingModes: nil, + ContextWindow: 100000, + MaxOutput: 4096, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.modelID, func(t *testing.T) { + got := inferAnthropicModelCapabilities(tt.modelID) + if got.ToolUse != tt.want.ToolUse { + t.Errorf("ToolUse = %v, want %v", got.ToolUse, tt.want.ToolUse) + } + if got.JSONOutput != tt.want.JSONOutput { + t.Errorf("JSONOutput = %v, want %v", got.JSONOutput, tt.want.JSONOutput) + } + if got.Vision != tt.want.Vision { + t.Errorf("Vision = %v, want %v", got.Vision, tt.want.Vision) + } + if got.ContextWindow != tt.want.ContextWindow { + t.Errorf("ContextWindow = %v, want %v", got.ContextWindow, tt.want.ContextWindow) + } + if got.MaxOutput != tt.want.MaxOutput { + t.Errorf("MaxOutput = %v, want %v", got.MaxOutput, tt.want.MaxOutput) + } + // Check ThinkingModes + if tt.want.ThinkingModes == nil { + if len(got.ThinkingModes) != 0 { + t.Errorf("ThinkingModes should be empty, got %v", got.ThinkingModes) + } + } else if len(got.ThinkingModes) != len(tt.want.ThinkingModes) { + t.Errorf("ThinkingModes length = %v, want %v", len(got.ThinkingModes), len(tt.want.ThinkingModes)) + } + }) + } +} diff --git a/internal/provider/anthropic/provider.go b/internal/provider/anthropic/provider.go index 1af7a10..cc200e3 100644 --- a/internal/provider/anthropic/provider.go +++ b/internal/provider/anthropic/provider.go @@ -76,9 +76,36 @@ func (p *Provider) DefaultModel() string { return p.model } -// Models returns known Anthropic models with capabilities. -// Anthropic doesn't have a model listing API, so these are hardcoded. -func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) { +// Models returns available Anthropic models with capabilities by querying the API. +func (p *Provider) Models(ctx context.Context) ([]provider.ModelInfo, error) { + pager := p.client.Models.ListAutoPaging(ctx, anthropic.ModelListParams{}) + + var models []provider.ModelInfo + for pager.Next() { + m := pager.Current() + caps := inferAnthropicModelCapabilities(m.ID) + models = append(models, provider.ModelInfo{ + ID: m.ID, + Name: m.ID, + Provider: p.name, + Capabilities: caps, + }) + } + if err := pager.Err(); err != nil { + // Fallback to hardcoded list if API call fails + return p.fallbackModels(), nil + } + + if len(models) == 0 { + // API returned no models, use fallback + return p.fallbackModels(), nil + } + + return models, nil +} + +// fallbackModels returns a hardcoded list of known Anthropic models. +func (p *Provider) fallbackModels() []provider.ModelInfo { return []provider.ModelInfo{ { ID: "claude-opus-4-20250514", Name: "Claude Opus 4", Provider: p.name, @@ -109,5 +136,38 @@ func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) { ContextWindow: 200000, MaxOutput: 8192, }, }, - }, nil + } +} + +// inferAnthropicModelCapabilities infers capabilities from model ID. +func inferAnthropicModelCapabilities(modelID string) provider.Capabilities { + // Default capabilities for most modern Claude models + caps := provider.Capabilities{ + ToolUse: true, + JSONOutput: true, + Vision: true, + ThinkingModes: []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh}, + ContextWindow: 200000, + MaxOutput: 16000, + } + + // Model-specific overrides + switch modelID { + case "claude-opus-4-20250514", "claude-opus-4-20250612": + caps.MaxOutput = 32000 + case "claude-3-opus-20240229", "claude-3-sonnet-20240229": + caps.ThinkingModes = []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh} + caps.ContextWindow = 200000 + caps.MaxOutput = 4096 + case "claude-3-haiku-20240307": + caps.ThinkingModes = []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh} + caps.ContextWindow = 200000 + caps.MaxOutput = 4096 + case "claude-2", "claude-2:1", "claude-instant-1": + caps.ThinkingModes = nil // No extended thinking support + caps.ContextWindow = 100000 + caps.MaxOutput = 4096 + } + + return caps } diff --git a/internal/provider/google/models_test.go b/internal/provider/google/models_test.go new file mode 100644 index 0000000..e020f56 --- /dev/null +++ b/internal/provider/google/models_test.go @@ -0,0 +1,44 @@ +package google + +import ( + "context" + "testing" + + "somegit.dev/Owlibou/gnoma/internal/provider" +) + +func TestModels_Fallback(t *testing.T) { + // Test with invalid API key - should fall back to hardcoded list + cfg := provider.ProviderConfig{ + APIKey: "invalid-key", + BaseURL: "https://generativelanguage.googleapis.com", + } + p, err := New(cfg) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + models, err := p.Models(context.Background()) + if err != nil { + t.Fatalf("Models() error = %v", err) + } + + // Should return fallback models + if len(models) == 0 { + t.Fatal("Models() returned empty list, expected fallback models") + } + + // Check that we have the expected fallback models + modelIDs := make(map[string]bool) + for _, m := range models { + modelIDs[m.ID] = true + } + + // Verify some expected models are present + expectedModels := []string{"gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.0-flash"} + for _, expected := range expectedModels { + if !modelIDs[expected] { + t.Errorf("Expected model %q not found in fallback list", expected) + } + } +} diff --git a/internal/provider/google/provider.go b/internal/provider/google/provider.go index 79688ec..c90a6d0 100644 --- a/internal/provider/google/provider.go +++ b/internal/provider/google/provider.go @@ -66,8 +66,34 @@ func (p *Provider) Name() string { return p.name } // DefaultModel returns the configured default model. func (p *Provider) DefaultModel() string { return p.model } -// Models returns known Google models with capabilities. -func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) { +// Models returns available Google models with capabilities by querying the API. +func (p *Provider) Models(ctx context.Context) ([]provider.ModelInfo, error) { + var models []provider.ModelInfo + for model, err := range p.client.Models.All(ctx) { + if err != nil { + // Fallback to hardcoded list if API call fails + return p.fallbackModels(), nil + } + + caps := inferGoogleModelCapabilities(model) + models = append(models, provider.ModelInfo{ + ID: model.Name, + Name: model.DisplayName, + Provider: p.name, + Capabilities: caps, + }) + } + + if len(models) == 0 { + // API returned no models, use fallback + return p.fallbackModels(), nil + } + + return models, nil +} + +// fallbackModels returns a hardcoded list of known Google models. +func (p *Provider) fallbackModels() []provider.ModelInfo { return []provider.ModelInfo{ { ID: "gemini-2.5-pro", Name: "Gemini 2.5 Pro", Provider: p.name, @@ -98,5 +124,34 @@ func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) { ContextWindow: 1048576, MaxOutput: 8192, }, }, - }, nil + } +} + +// inferGoogleModelCapabilities infers capabilities from the Google Model. +func inferGoogleModelCapabilities(m *genai.Model) provider.Capabilities { + // Default capabilities for most modern Gemini models + caps := provider.Capabilities{ + ToolUse: true, + JSONOutput: true, + Vision: true, + ThinkingModes: []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh}, + ContextWindow: 1048576, + MaxOutput: 65536, + } + + // Model-specific overrides based on model name + name := m.Name + switch { + case name == "gemini-2.5-pro", name == "gemini-2.5-flash": + caps.ContextWindow = 1048576 + caps.MaxOutput = 65536 + case name == "gemini-2.0-pro", name == "gemini-2.0-flash": + caps.ContextWindow = 1048576 + caps.MaxOutput = 8192 + case name == "gemini-1.5-pro", name == "gemini-1.5-flash": + caps.ContextWindow = 1048576 + caps.MaxOutput = 8192 + } + + return caps } diff --git a/internal/provider/openai/models_test.go b/internal/provider/openai/models_test.go new file mode 100644 index 0000000..7ca82ab --- /dev/null +++ b/internal/provider/openai/models_test.go @@ -0,0 +1,108 @@ +package openai + +import ( + "context" + "testing" + + "somegit.dev/Owlibou/gnoma/internal/provider" +) + +func TestModels_Fallback(t *testing.T) { + // Test with invalid API key - should fall back to hardcoded list + cfg := provider.ProviderConfig{ + APIKey: "invalid-key", + BaseURL: "https://api.openai.com/v1", + } + p, err := New(cfg) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + models, err := p.Models(context.Background()) + if err != nil { + t.Fatalf("Models() error = %v", err) + } + + // Should return fallback models + if len(models) == 0 { + t.Fatal("Models() returned empty list, expected fallback models") + } + + // Check that we have the expected fallback models + modelIDs := make(map[string]bool) + for _, m := range models { + modelIDs[m.ID] = true + } + + // Verify some expected models are present + expectedModels := []string{"gpt-4o", "gpt-4o-mini", "o3", "o3-mini"} + for _, expected := range expectedModels { + if !modelIDs[expected] { + t.Errorf("Expected model %q not found in fallback list", expected) + } + } +} + +func TestInferOpenAIModelCapabilities(t *testing.T) { + tests := []struct { + modelID string + want provider.Capabilities + }{ + { + modelID: "gpt-4o", + want: provider.Capabilities{ + ToolUse: true, + JSONOutput: true, + Vision: true, + ContextWindow: 128000, + MaxOutput: 16384, + }, + }, + { + modelID: "o3", + want: provider.Capabilities{ + ToolUse: true, + JSONOutput: true, + Vision: true, + ThinkingModes: []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh}, + ContextWindow: 200000, + MaxOutput: 100000, + }, + }, + { + modelID: "gpt-3.5-turbo", + want: provider.Capabilities{ + ToolUse: false, + JSONOutput: true, + Vision: false, + ContextWindow: 16384, + MaxOutput: 4096, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.modelID, func(t *testing.T) { + got := inferOpenAIModelCapabilities(tt.modelID) + if got.ToolUse != tt.want.ToolUse { + t.Errorf("ToolUse = %v, want %v", got.ToolUse, tt.want.ToolUse) + } + if got.JSONOutput != tt.want.JSONOutput { + t.Errorf("JSONOutput = %v, want %v", got.JSONOutput, tt.want.JSONOutput) + } + if got.Vision != tt.want.Vision { + t.Errorf("Vision = %v, want %v", got.Vision, tt.want.Vision) + } + if got.ContextWindow != tt.want.ContextWindow { + t.Errorf("ContextWindow = %v, want %v", got.ContextWindow, tt.want.ContextWindow) + } + if got.MaxOutput != tt.want.MaxOutput { + t.Errorf("MaxOutput = %v, want %v", got.MaxOutput, tt.want.MaxOutput) + } + // Check ThinkingModes length + if len(got.ThinkingModes) != len(tt.want.ThinkingModes) { + t.Errorf("ThinkingModes length = %v, want %v", len(got.ThinkingModes), len(tt.want.ThinkingModes)) + } + }) + } +} diff --git a/internal/provider/openai/provider.go b/internal/provider/openai/provider.go index 4f5436e..a868ca7 100644 --- a/internal/provider/openai/provider.go +++ b/internal/provider/openai/provider.go @@ -79,8 +79,36 @@ func (p *Provider) Name() string { return p.name } // DefaultModel returns the configured default model. func (p *Provider) DefaultModel() string { return p.model } -// Models returns known OpenAI models with capabilities. -func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) { +// Models returns available OpenAI models with capabilities by querying the API. +func (p *Provider) Models(ctx context.Context) ([]provider.ModelInfo, error) { + pager := p.client.Models.ListAutoPaging(ctx) + + var models []provider.ModelInfo + for pager.Next() { + m := pager.Current() + caps := inferOpenAIModelCapabilities(m.ID) + models = append(models, provider.ModelInfo{ + ID: m.ID, + Name: m.ID, + Provider: p.name, + Capabilities: caps, + }) + } + if err := pager.Err(); err != nil { + // Fallback to hardcoded list if API call fails + return p.fallbackModels(), nil + } + + if len(models) == 0 { + // API returned no models, use fallback + return p.fallbackModels(), nil + } + + return models, nil +} + +// fallbackModels returns a hardcoded list of known OpenAI models. +func (p *Provider) fallbackModels() []provider.ModelInfo { return []provider.ModelInfo{ { ID: "gpt-4o", Name: "GPT-4o", Provider: p.name, @@ -116,5 +144,39 @@ func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) { MaxOutput: 100000, }, }, - }, nil + } +} + +// inferOpenAIModelCapabilities infers capabilities from model ID. +func inferOpenAIModelCapabilities(modelID string) provider.Capabilities { + // Default capabilities for most modern OpenAI models + caps := provider.Capabilities{ + ToolUse: true, + JSONOutput: true, + Vision: true, + ContextWindow: 128000, + MaxOutput: 16384, + } + + // Model-specific overrides + switch modelID { + case "gpt-4o", "gpt-4o-mini": + caps.ContextWindow = 128000 + caps.MaxOutput = 16384 + case "o3", "o3-mini": + caps.ThinkingModes = []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh} + caps.ContextWindow = 200000 + caps.MaxOutput = 100000 + case "gpt-4", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613": + caps.Vision = false + caps.ContextWindow = 8192 + caps.MaxOutput = 8192 + case "gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613": + caps.Vision = false + caps.ToolUse = false + caps.ContextWindow = 16384 + caps.MaxOutput = 4096 + } + + return caps }