diff --git a/internal/provider/anthropic/models_test.go b/internal/provider/anthropic/models_test.go new file mode 100644 index 0000000..71d8062 --- /dev/null +++ b/internal/provider/anthropic/models_test.go @@ -0,0 +1,114 @@ +package anthropic + +import ( + "context" + "testing" + + "somegit.dev/Owlibou/gnoma/internal/provider" +) + +func TestModels_Fallback(t *testing.T) { + // Test with invalid API key - should fall back to hardcoded list + cfg := provider.ProviderConfig{ + APIKey: "invalid-key", + BaseURL: "https://api.anthropic.com/v1", + } + p, err := New(cfg) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + models, err := p.Models(context.Background()) + if err != nil { + t.Fatalf("Models() error = %v", err) + } + + // Should return fallback models + if len(models) == 0 { + t.Fatal("Models() returned empty list, expected fallback models") + } + + // Check that we have the expected fallback models + modelIDs := make(map[string]bool) + for _, m := range models { + modelIDs[m.ID] = true + } + + // Verify some expected models are present + expectedModels := []string{"claude-opus-4-20250514", "claude-sonnet-4-20250514", "claude-haiku-4-5-20251001"} + for _, expected := range expectedModels { + if !modelIDs[expected] { + t.Errorf("Expected model %q not found in fallback list", expected) + } + } +} + +func TestInferAnthropicModelCapabilities(t *testing.T) { + tests := []struct { + modelID string + want provider.Capabilities + }{ + { + modelID: "claude-opus-4-20250514", + want: provider.Capabilities{ + ToolUse: true, + JSONOutput: true, + Vision: true, + ThinkingModes: []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh}, + ContextWindow: 200000, + MaxOutput: 32000, + }, + }, + { + modelID: "claude-3-opus-20240229", + want: provider.Capabilities{ + ToolUse: true, + JSONOutput: true, + Vision: true, + ThinkingModes: []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh}, + ContextWindow: 200000, + MaxOutput: 4096, + }, + }, + { + modelID: "claude-2", + want: provider.Capabilities{ + ToolUse: true, + JSONOutput: true, + Vision: true, + ThinkingModes: nil, + ContextWindow: 100000, + MaxOutput: 4096, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.modelID, func(t *testing.T) { + got := inferAnthropicModelCapabilities(tt.modelID) + if got.ToolUse != tt.want.ToolUse { + t.Errorf("ToolUse = %v, want %v", got.ToolUse, tt.want.ToolUse) + } + if got.JSONOutput != tt.want.JSONOutput { + t.Errorf("JSONOutput = %v, want %v", got.JSONOutput, tt.want.JSONOutput) + } + if got.Vision != tt.want.Vision { + t.Errorf("Vision = %v, want %v", got.Vision, tt.want.Vision) + } + if got.ContextWindow != tt.want.ContextWindow { + t.Errorf("ContextWindow = %v, want %v", got.ContextWindow, tt.want.ContextWindow) + } + if got.MaxOutput != tt.want.MaxOutput { + t.Errorf("MaxOutput = %v, want %v", got.MaxOutput, tt.want.MaxOutput) + } + // Check ThinkingModes + if tt.want.ThinkingModes == nil { + if len(got.ThinkingModes) != 0 { + t.Errorf("ThinkingModes should be empty, got %v", got.ThinkingModes) + } + } else if len(got.ThinkingModes) != len(tt.want.ThinkingModes) { + t.Errorf("ThinkingModes length = %v, want %v", len(got.ThinkingModes), len(tt.want.ThinkingModes)) + } + }) + } +} diff --git a/internal/provider/anthropic/provider.go b/internal/provider/anthropic/provider.go index 1af7a10..cc200e3 100644 --- a/internal/provider/anthropic/provider.go +++ b/internal/provider/anthropic/provider.go @@ -76,9 +76,36 @@ func (p *Provider) DefaultModel() string { return p.model } -// Models returns known Anthropic models with capabilities. -// Anthropic doesn't have a model listing API, so these are hardcoded. -func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) { +// Models returns available Anthropic models with capabilities by querying the API. +func (p *Provider) Models(ctx context.Context) ([]provider.ModelInfo, error) { + pager := p.client.Models.ListAutoPaging(ctx, anthropic.ModelListParams{}) + + var models []provider.ModelInfo + for pager.Next() { + m := pager.Current() + caps := inferAnthropicModelCapabilities(m.ID) + models = append(models, provider.ModelInfo{ + ID: m.ID, + Name: m.ID, + Provider: p.name, + Capabilities: caps, + }) + } + if err := pager.Err(); err != nil { + // Fallback to hardcoded list if API call fails + return p.fallbackModels(), nil + } + + if len(models) == 0 { + // API returned no models, use fallback + return p.fallbackModels(), nil + } + + return models, nil +} + +// fallbackModels returns a hardcoded list of known Anthropic models. +func (p *Provider) fallbackModels() []provider.ModelInfo { return []provider.ModelInfo{ { ID: "claude-opus-4-20250514", Name: "Claude Opus 4", Provider: p.name, @@ -109,5 +136,38 @@ func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) { ContextWindow: 200000, MaxOutput: 8192, }, }, - }, nil + } +} + +// inferAnthropicModelCapabilities infers capabilities from model ID. +func inferAnthropicModelCapabilities(modelID string) provider.Capabilities { + // Default capabilities for most modern Claude models + caps := provider.Capabilities{ + ToolUse: true, + JSONOutput: true, + Vision: true, + ThinkingModes: []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh}, + ContextWindow: 200000, + MaxOutput: 16000, + } + + // Model-specific overrides + switch modelID { + case "claude-opus-4-20250514", "claude-opus-4-20250612": + caps.MaxOutput = 32000 + case "claude-3-opus-20240229", "claude-3-sonnet-20240229": + caps.ThinkingModes = []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh} + caps.ContextWindow = 200000 + caps.MaxOutput = 4096 + case "claude-3-haiku-20240307": + caps.ThinkingModes = []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh} + caps.ContextWindow = 200000 + caps.MaxOutput = 4096 + case "claude-2", "claude-2:1", "claude-instant-1": + caps.ThinkingModes = nil // No extended thinking support + caps.ContextWindow = 100000 + caps.MaxOutput = 4096 + } + + return caps } diff --git a/internal/provider/google/models_test.go b/internal/provider/google/models_test.go new file mode 100644 index 0000000..e020f56 --- /dev/null +++ b/internal/provider/google/models_test.go @@ -0,0 +1,44 @@ +package google + +import ( + "context" + "testing" + + "somegit.dev/Owlibou/gnoma/internal/provider" +) + +func TestModels_Fallback(t *testing.T) { + // Test with invalid API key - should fall back to hardcoded list + cfg := provider.ProviderConfig{ + APIKey: "invalid-key", + BaseURL: "https://generativelanguage.googleapis.com", + } + p, err := New(cfg) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + models, err := p.Models(context.Background()) + if err != nil { + t.Fatalf("Models() error = %v", err) + } + + // Should return fallback models + if len(models) == 0 { + t.Fatal("Models() returned empty list, expected fallback models") + } + + // Check that we have the expected fallback models + modelIDs := make(map[string]bool) + for _, m := range models { + modelIDs[m.ID] = true + } + + // Verify some expected models are present + expectedModels := []string{"gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.0-flash"} + for _, expected := range expectedModels { + if !modelIDs[expected] { + t.Errorf("Expected model %q not found in fallback list", expected) + } + } +} diff --git a/internal/provider/google/provider.go b/internal/provider/google/provider.go index 79688ec..c90a6d0 100644 --- a/internal/provider/google/provider.go +++ b/internal/provider/google/provider.go @@ -66,8 +66,34 @@ func (p *Provider) Name() string { return p.name } // DefaultModel returns the configured default model. func (p *Provider) DefaultModel() string { return p.model } -// Models returns known Google models with capabilities. -func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) { +// Models returns available Google models with capabilities by querying the API. +func (p *Provider) Models(ctx context.Context) ([]provider.ModelInfo, error) { + var models []provider.ModelInfo + for model, err := range p.client.Models.All(ctx) { + if err != nil { + // Fallback to hardcoded list if API call fails + return p.fallbackModels(), nil + } + + caps := inferGoogleModelCapabilities(model) + models = append(models, provider.ModelInfo{ + ID: model.Name, + Name: model.DisplayName, + Provider: p.name, + Capabilities: caps, + }) + } + + if len(models) == 0 { + // API returned no models, use fallback + return p.fallbackModels(), nil + } + + return models, nil +} + +// fallbackModels returns a hardcoded list of known Google models. +func (p *Provider) fallbackModels() []provider.ModelInfo { return []provider.ModelInfo{ { ID: "gemini-2.5-pro", Name: "Gemini 2.5 Pro", Provider: p.name, @@ -98,5 +124,34 @@ func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) { ContextWindow: 1048576, MaxOutput: 8192, }, }, - }, nil + } +} + +// inferGoogleModelCapabilities infers capabilities from the Google Model. +func inferGoogleModelCapabilities(m *genai.Model) provider.Capabilities { + // Default capabilities for most modern Gemini models + caps := provider.Capabilities{ + ToolUse: true, + JSONOutput: true, + Vision: true, + ThinkingModes: []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh}, + ContextWindow: 1048576, + MaxOutput: 65536, + } + + // Model-specific overrides based on model name + name := m.Name + switch { + case name == "gemini-2.5-pro", name == "gemini-2.5-flash": + caps.ContextWindow = 1048576 + caps.MaxOutput = 65536 + case name == "gemini-2.0-pro", name == "gemini-2.0-flash": + caps.ContextWindow = 1048576 + caps.MaxOutput = 8192 + case name == "gemini-1.5-pro", name == "gemini-1.5-flash": + caps.ContextWindow = 1048576 + caps.MaxOutput = 8192 + } + + return caps } diff --git a/internal/provider/openai/models_test.go b/internal/provider/openai/models_test.go new file mode 100644 index 0000000..7ca82ab --- /dev/null +++ b/internal/provider/openai/models_test.go @@ -0,0 +1,108 @@ +package openai + +import ( + "context" + "testing" + + "somegit.dev/Owlibou/gnoma/internal/provider" +) + +func TestModels_Fallback(t *testing.T) { + // Test with invalid API key - should fall back to hardcoded list + cfg := provider.ProviderConfig{ + APIKey: "invalid-key", + BaseURL: "https://api.openai.com/v1", + } + p, err := New(cfg) + if err != nil { + t.Fatalf("New() error = %v", err) + } + + models, err := p.Models(context.Background()) + if err != nil { + t.Fatalf("Models() error = %v", err) + } + + // Should return fallback models + if len(models) == 0 { + t.Fatal("Models() returned empty list, expected fallback models") + } + + // Check that we have the expected fallback models + modelIDs := make(map[string]bool) + for _, m := range models { + modelIDs[m.ID] = true + } + + // Verify some expected models are present + expectedModels := []string{"gpt-4o", "gpt-4o-mini", "o3", "o3-mini"} + for _, expected := range expectedModels { + if !modelIDs[expected] { + t.Errorf("Expected model %q not found in fallback list", expected) + } + } +} + +func TestInferOpenAIModelCapabilities(t *testing.T) { + tests := []struct { + modelID string + want provider.Capabilities + }{ + { + modelID: "gpt-4o", + want: provider.Capabilities{ + ToolUse: true, + JSONOutput: true, + Vision: true, + ContextWindow: 128000, + MaxOutput: 16384, + }, + }, + { + modelID: "o3", + want: provider.Capabilities{ + ToolUse: true, + JSONOutput: true, + Vision: true, + ThinkingModes: []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh}, + ContextWindow: 200000, + MaxOutput: 100000, + }, + }, + { + modelID: "gpt-3.5-turbo", + want: provider.Capabilities{ + ToolUse: false, + JSONOutput: true, + Vision: false, + ContextWindow: 16384, + MaxOutput: 4096, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.modelID, func(t *testing.T) { + got := inferOpenAIModelCapabilities(tt.modelID) + if got.ToolUse != tt.want.ToolUse { + t.Errorf("ToolUse = %v, want %v", got.ToolUse, tt.want.ToolUse) + } + if got.JSONOutput != tt.want.JSONOutput { + t.Errorf("JSONOutput = %v, want %v", got.JSONOutput, tt.want.JSONOutput) + } + if got.Vision != tt.want.Vision { + t.Errorf("Vision = %v, want %v", got.Vision, tt.want.Vision) + } + if got.ContextWindow != tt.want.ContextWindow { + t.Errorf("ContextWindow = %v, want %v", got.ContextWindow, tt.want.ContextWindow) + } + if got.MaxOutput != tt.want.MaxOutput { + t.Errorf("MaxOutput = %v, want %v", got.MaxOutput, tt.want.MaxOutput) + } + // Check ThinkingModes length + if len(got.ThinkingModes) != len(tt.want.ThinkingModes) { + t.Errorf("ThinkingModes length = %v, want %v", len(got.ThinkingModes), len(tt.want.ThinkingModes)) + } + }) + } +} diff --git a/internal/provider/openai/provider.go b/internal/provider/openai/provider.go index 4f5436e..a868ca7 100644 --- a/internal/provider/openai/provider.go +++ b/internal/provider/openai/provider.go @@ -79,8 +79,36 @@ func (p *Provider) Name() string { return p.name } // DefaultModel returns the configured default model. func (p *Provider) DefaultModel() string { return p.model } -// Models returns known OpenAI models with capabilities. -func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) { +// Models returns available OpenAI models with capabilities by querying the API. +func (p *Provider) Models(ctx context.Context) ([]provider.ModelInfo, error) { + pager := p.client.Models.ListAutoPaging(ctx) + + var models []provider.ModelInfo + for pager.Next() { + m := pager.Current() + caps := inferOpenAIModelCapabilities(m.ID) + models = append(models, provider.ModelInfo{ + ID: m.ID, + Name: m.ID, + Provider: p.name, + Capabilities: caps, + }) + } + if err := pager.Err(); err != nil { + // Fallback to hardcoded list if API call fails + return p.fallbackModels(), nil + } + + if len(models) == 0 { + // API returned no models, use fallback + return p.fallbackModels(), nil + } + + return models, nil +} + +// fallbackModels returns a hardcoded list of known OpenAI models. +func (p *Provider) fallbackModels() []provider.ModelInfo { return []provider.ModelInfo{ { ID: "gpt-4o", Name: "GPT-4o", Provider: p.name, @@ -116,5 +144,39 @@ func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) { MaxOutput: 100000, }, }, - }, nil + } +} + +// inferOpenAIModelCapabilities infers capabilities from model ID. +func inferOpenAIModelCapabilities(modelID string) provider.Capabilities { + // Default capabilities for most modern OpenAI models + caps := provider.Capabilities{ + ToolUse: true, + JSONOutput: true, + Vision: true, + ContextWindow: 128000, + MaxOutput: 16384, + } + + // Model-specific overrides + switch modelID { + case "gpt-4o", "gpt-4o-mini": + caps.ContextWindow = 128000 + caps.MaxOutput = 16384 + case "o3", "o3-mini": + caps.ThinkingModes = []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh} + caps.ContextWindow = 200000 + caps.MaxOutput = 100000 + case "gpt-4", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613": + caps.Vision = false + caps.ContextWindow = 8192 + caps.MaxOutput = 8192 + case "gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613": + caps.Vision = false + caps.ToolUse = false + caps.ContextWindow = 16384 + caps.MaxOutput = 4096 + } + + return caps }