feat: add dynamic model discovery within providers
- OpenAI provider: use Models.ListAutoPaging() to discover available models - Anthropic provider: use Models.ListAutoPaging() to discover available models - Google provider: use Models.All() iterator to discover available models - All providers fall back to hardcoded lists if API calls fail - Add capability inference functions for each provider based on model ID - Add tests for model discovery fallback behavior This enables gnoma to dynamically discover new models as they become available from cloud providers, while maintaining backward compatibility with fallback lists for offline use or API failures. Generated by Mistral Vibe. Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
This commit is contained in:
@@ -0,0 +1,114 @@
|
||||
package anthropic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
)
|
||||
|
||||
func TestModels_Fallback(t *testing.T) {
|
||||
// Test with invalid API key - should fall back to hardcoded list
|
||||
cfg := provider.ProviderConfig{
|
||||
APIKey: "invalid-key",
|
||||
BaseURL: "https://api.anthropic.com/v1",
|
||||
}
|
||||
p, err := New(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
models, err := p.Models(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Models() error = %v", err)
|
||||
}
|
||||
|
||||
// Should return fallback models
|
||||
if len(models) == 0 {
|
||||
t.Fatal("Models() returned empty list, expected fallback models")
|
||||
}
|
||||
|
||||
// Check that we have the expected fallback models
|
||||
modelIDs := make(map[string]bool)
|
||||
for _, m := range models {
|
||||
modelIDs[m.ID] = true
|
||||
}
|
||||
|
||||
// Verify some expected models are present
|
||||
expectedModels := []string{"claude-opus-4-20250514", "claude-sonnet-4-20250514", "claude-haiku-4-5-20251001"}
|
||||
for _, expected := range expectedModels {
|
||||
if !modelIDs[expected] {
|
||||
t.Errorf("Expected model %q not found in fallback list", expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferAnthropicModelCapabilities(t *testing.T) {
|
||||
tests := []struct {
|
||||
modelID string
|
||||
want provider.Capabilities
|
||||
}{
|
||||
{
|
||||
modelID: "claude-opus-4-20250514",
|
||||
want: provider.Capabilities{
|
||||
ToolUse: true,
|
||||
JSONOutput: true,
|
||||
Vision: true,
|
||||
ThinkingModes: []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh},
|
||||
ContextWindow: 200000,
|
||||
MaxOutput: 32000,
|
||||
},
|
||||
},
|
||||
{
|
||||
modelID: "claude-3-opus-20240229",
|
||||
want: provider.Capabilities{
|
||||
ToolUse: true,
|
||||
JSONOutput: true,
|
||||
Vision: true,
|
||||
ThinkingModes: []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh},
|
||||
ContextWindow: 200000,
|
||||
MaxOutput: 4096,
|
||||
},
|
||||
},
|
||||
{
|
||||
modelID: "claude-2",
|
||||
want: provider.Capabilities{
|
||||
ToolUse: true,
|
||||
JSONOutput: true,
|
||||
Vision: true,
|
||||
ThinkingModes: nil,
|
||||
ContextWindow: 100000,
|
||||
MaxOutput: 4096,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.modelID, func(t *testing.T) {
|
||||
got := inferAnthropicModelCapabilities(tt.modelID)
|
||||
if got.ToolUse != tt.want.ToolUse {
|
||||
t.Errorf("ToolUse = %v, want %v", got.ToolUse, tt.want.ToolUse)
|
||||
}
|
||||
if got.JSONOutput != tt.want.JSONOutput {
|
||||
t.Errorf("JSONOutput = %v, want %v", got.JSONOutput, tt.want.JSONOutput)
|
||||
}
|
||||
if got.Vision != tt.want.Vision {
|
||||
t.Errorf("Vision = %v, want %v", got.Vision, tt.want.Vision)
|
||||
}
|
||||
if got.ContextWindow != tt.want.ContextWindow {
|
||||
t.Errorf("ContextWindow = %v, want %v", got.ContextWindow, tt.want.ContextWindow)
|
||||
}
|
||||
if got.MaxOutput != tt.want.MaxOutput {
|
||||
t.Errorf("MaxOutput = %v, want %v", got.MaxOutput, tt.want.MaxOutput)
|
||||
}
|
||||
// Check ThinkingModes
|
||||
if tt.want.ThinkingModes == nil {
|
||||
if len(got.ThinkingModes) != 0 {
|
||||
t.Errorf("ThinkingModes should be empty, got %v", got.ThinkingModes)
|
||||
}
|
||||
} else if len(got.ThinkingModes) != len(tt.want.ThinkingModes) {
|
||||
t.Errorf("ThinkingModes length = %v, want %v", len(got.ThinkingModes), len(tt.want.ThinkingModes))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -76,9 +76,36 @@ func (p *Provider) DefaultModel() string {
|
||||
return p.model
|
||||
}
|
||||
|
||||
// Models returns known Anthropic models with capabilities.
|
||||
// Anthropic doesn't have a model listing API, so these are hardcoded.
|
||||
func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
||||
// Models returns available Anthropic models with capabilities by querying the API.
|
||||
func (p *Provider) Models(ctx context.Context) ([]provider.ModelInfo, error) {
|
||||
pager := p.client.Models.ListAutoPaging(ctx, anthropic.ModelListParams{})
|
||||
|
||||
var models []provider.ModelInfo
|
||||
for pager.Next() {
|
||||
m := pager.Current()
|
||||
caps := inferAnthropicModelCapabilities(m.ID)
|
||||
models = append(models, provider.ModelInfo{
|
||||
ID: m.ID,
|
||||
Name: m.ID,
|
||||
Provider: p.name,
|
||||
Capabilities: caps,
|
||||
})
|
||||
}
|
||||
if err := pager.Err(); err != nil {
|
||||
// Fallback to hardcoded list if API call fails
|
||||
return p.fallbackModels(), nil
|
||||
}
|
||||
|
||||
if len(models) == 0 {
|
||||
// API returned no models, use fallback
|
||||
return p.fallbackModels(), nil
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// fallbackModels returns a hardcoded list of known Anthropic models.
|
||||
func (p *Provider) fallbackModels() []provider.ModelInfo {
|
||||
return []provider.ModelInfo{
|
||||
{
|
||||
ID: "claude-opus-4-20250514", Name: "Claude Opus 4", Provider: p.name,
|
||||
@@ -109,5 +136,38 @@ func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
||||
ContextWindow: 200000, MaxOutput: 8192,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// inferAnthropicModelCapabilities infers capabilities from model ID.
|
||||
func inferAnthropicModelCapabilities(modelID string) provider.Capabilities {
|
||||
// Default capabilities for most modern Claude models
|
||||
caps := provider.Capabilities{
|
||||
ToolUse: true,
|
||||
JSONOutput: true,
|
||||
Vision: true,
|
||||
ThinkingModes: []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh},
|
||||
ContextWindow: 200000,
|
||||
MaxOutput: 16000,
|
||||
}
|
||||
|
||||
// Model-specific overrides
|
||||
switch modelID {
|
||||
case "claude-opus-4-20250514", "claude-opus-4-20250612":
|
||||
caps.MaxOutput = 32000
|
||||
case "claude-3-opus-20240229", "claude-3-sonnet-20240229":
|
||||
caps.ThinkingModes = []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh}
|
||||
caps.ContextWindow = 200000
|
||||
caps.MaxOutput = 4096
|
||||
case "claude-3-haiku-20240307":
|
||||
caps.ThinkingModes = []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh}
|
||||
caps.ContextWindow = 200000
|
||||
caps.MaxOutput = 4096
|
||||
case "claude-2", "claude-2:1", "claude-instant-1":
|
||||
caps.ThinkingModes = nil // No extended thinking support
|
||||
caps.ContextWindow = 100000
|
||||
caps.MaxOutput = 4096
|
||||
}
|
||||
|
||||
return caps
|
||||
}
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
package google
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
)
|
||||
|
||||
func TestModels_Fallback(t *testing.T) {
|
||||
// Test with invalid API key - should fall back to hardcoded list
|
||||
cfg := provider.ProviderConfig{
|
||||
APIKey: "invalid-key",
|
||||
BaseURL: "https://generativelanguage.googleapis.com",
|
||||
}
|
||||
p, err := New(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
models, err := p.Models(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Models() error = %v", err)
|
||||
}
|
||||
|
||||
// Should return fallback models
|
||||
if len(models) == 0 {
|
||||
t.Fatal("Models() returned empty list, expected fallback models")
|
||||
}
|
||||
|
||||
// Check that we have the expected fallback models
|
||||
modelIDs := make(map[string]bool)
|
||||
for _, m := range models {
|
||||
modelIDs[m.ID] = true
|
||||
}
|
||||
|
||||
// Verify some expected models are present
|
||||
expectedModels := []string{"gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.0-flash"}
|
||||
for _, expected := range expectedModels {
|
||||
if !modelIDs[expected] {
|
||||
t.Errorf("Expected model %q not found in fallback list", expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -66,8 +66,34 @@ func (p *Provider) Name() string { return p.name }
|
||||
// DefaultModel returns the configured default model.
|
||||
func (p *Provider) DefaultModel() string { return p.model }
|
||||
|
||||
// Models returns known Google models with capabilities.
|
||||
func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
||||
// Models returns available Google models with capabilities by querying the API.
|
||||
func (p *Provider) Models(ctx context.Context) ([]provider.ModelInfo, error) {
|
||||
var models []provider.ModelInfo
|
||||
for model, err := range p.client.Models.All(ctx) {
|
||||
if err != nil {
|
||||
// Fallback to hardcoded list if API call fails
|
||||
return p.fallbackModels(), nil
|
||||
}
|
||||
|
||||
caps := inferGoogleModelCapabilities(model)
|
||||
models = append(models, provider.ModelInfo{
|
||||
ID: model.Name,
|
||||
Name: model.DisplayName,
|
||||
Provider: p.name,
|
||||
Capabilities: caps,
|
||||
})
|
||||
}
|
||||
|
||||
if len(models) == 0 {
|
||||
// API returned no models, use fallback
|
||||
return p.fallbackModels(), nil
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// fallbackModels returns a hardcoded list of known Google models.
|
||||
func (p *Provider) fallbackModels() []provider.ModelInfo {
|
||||
return []provider.ModelInfo{
|
||||
{
|
||||
ID: "gemini-2.5-pro", Name: "Gemini 2.5 Pro", Provider: p.name,
|
||||
@@ -98,5 +124,34 @@ func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
||||
ContextWindow: 1048576, MaxOutput: 8192,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// inferGoogleModelCapabilities infers capabilities from the Google Model.
|
||||
func inferGoogleModelCapabilities(m *genai.Model) provider.Capabilities {
|
||||
// Default capabilities for most modern Gemini models
|
||||
caps := provider.Capabilities{
|
||||
ToolUse: true,
|
||||
JSONOutput: true,
|
||||
Vision: true,
|
||||
ThinkingModes: []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh},
|
||||
ContextWindow: 1048576,
|
||||
MaxOutput: 65536,
|
||||
}
|
||||
|
||||
// Model-specific overrides based on model name
|
||||
name := m.Name
|
||||
switch {
|
||||
case name == "gemini-2.5-pro", name == "gemini-2.5-flash":
|
||||
caps.ContextWindow = 1048576
|
||||
caps.MaxOutput = 65536
|
||||
case name == "gemini-2.0-pro", name == "gemini-2.0-flash":
|
||||
caps.ContextWindow = 1048576
|
||||
caps.MaxOutput = 8192
|
||||
case name == "gemini-1.5-pro", name == "gemini-1.5-flash":
|
||||
caps.ContextWindow = 1048576
|
||||
caps.MaxOutput = 8192
|
||||
}
|
||||
|
||||
return caps
|
||||
}
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
)
|
||||
|
||||
func TestModels_Fallback(t *testing.T) {
|
||||
// Test with invalid API key - should fall back to hardcoded list
|
||||
cfg := provider.ProviderConfig{
|
||||
APIKey: "invalid-key",
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
}
|
||||
p, err := New(cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("New() error = %v", err)
|
||||
}
|
||||
|
||||
models, err := p.Models(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Models() error = %v", err)
|
||||
}
|
||||
|
||||
// Should return fallback models
|
||||
if len(models) == 0 {
|
||||
t.Fatal("Models() returned empty list, expected fallback models")
|
||||
}
|
||||
|
||||
// Check that we have the expected fallback models
|
||||
modelIDs := make(map[string]bool)
|
||||
for _, m := range models {
|
||||
modelIDs[m.ID] = true
|
||||
}
|
||||
|
||||
// Verify some expected models are present
|
||||
expectedModels := []string{"gpt-4o", "gpt-4o-mini", "o3", "o3-mini"}
|
||||
for _, expected := range expectedModels {
|
||||
if !modelIDs[expected] {
|
||||
t.Errorf("Expected model %q not found in fallback list", expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInferOpenAIModelCapabilities(t *testing.T) {
|
||||
tests := []struct {
|
||||
modelID string
|
||||
want provider.Capabilities
|
||||
}{
|
||||
{
|
||||
modelID: "gpt-4o",
|
||||
want: provider.Capabilities{
|
||||
ToolUse: true,
|
||||
JSONOutput: true,
|
||||
Vision: true,
|
||||
ContextWindow: 128000,
|
||||
MaxOutput: 16384,
|
||||
},
|
||||
},
|
||||
{
|
||||
modelID: "o3",
|
||||
want: provider.Capabilities{
|
||||
ToolUse: true,
|
||||
JSONOutput: true,
|
||||
Vision: true,
|
||||
ThinkingModes: []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh},
|
||||
ContextWindow: 200000,
|
||||
MaxOutput: 100000,
|
||||
},
|
||||
},
|
||||
{
|
||||
modelID: "gpt-3.5-turbo",
|
||||
want: provider.Capabilities{
|
||||
ToolUse: false,
|
||||
JSONOutput: true,
|
||||
Vision: false,
|
||||
ContextWindow: 16384,
|
||||
MaxOutput: 4096,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.modelID, func(t *testing.T) {
|
||||
got := inferOpenAIModelCapabilities(tt.modelID)
|
||||
if got.ToolUse != tt.want.ToolUse {
|
||||
t.Errorf("ToolUse = %v, want %v", got.ToolUse, tt.want.ToolUse)
|
||||
}
|
||||
if got.JSONOutput != tt.want.JSONOutput {
|
||||
t.Errorf("JSONOutput = %v, want %v", got.JSONOutput, tt.want.JSONOutput)
|
||||
}
|
||||
if got.Vision != tt.want.Vision {
|
||||
t.Errorf("Vision = %v, want %v", got.Vision, tt.want.Vision)
|
||||
}
|
||||
if got.ContextWindow != tt.want.ContextWindow {
|
||||
t.Errorf("ContextWindow = %v, want %v", got.ContextWindow, tt.want.ContextWindow)
|
||||
}
|
||||
if got.MaxOutput != tt.want.MaxOutput {
|
||||
t.Errorf("MaxOutput = %v, want %v", got.MaxOutput, tt.want.MaxOutput)
|
||||
}
|
||||
// Check ThinkingModes length
|
||||
if len(got.ThinkingModes) != len(tt.want.ThinkingModes) {
|
||||
t.Errorf("ThinkingModes length = %v, want %v", len(got.ThinkingModes), len(tt.want.ThinkingModes))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -79,8 +79,36 @@ func (p *Provider) Name() string { return p.name }
|
||||
// DefaultModel returns the configured default model.
|
||||
func (p *Provider) DefaultModel() string { return p.model }
|
||||
|
||||
// Models returns known OpenAI models with capabilities.
|
||||
func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
||||
// Models returns available OpenAI models with capabilities by querying the API.
|
||||
func (p *Provider) Models(ctx context.Context) ([]provider.ModelInfo, error) {
|
||||
pager := p.client.Models.ListAutoPaging(ctx)
|
||||
|
||||
var models []provider.ModelInfo
|
||||
for pager.Next() {
|
||||
m := pager.Current()
|
||||
caps := inferOpenAIModelCapabilities(m.ID)
|
||||
models = append(models, provider.ModelInfo{
|
||||
ID: m.ID,
|
||||
Name: m.ID,
|
||||
Provider: p.name,
|
||||
Capabilities: caps,
|
||||
})
|
||||
}
|
||||
if err := pager.Err(); err != nil {
|
||||
// Fallback to hardcoded list if API call fails
|
||||
return p.fallbackModels(), nil
|
||||
}
|
||||
|
||||
if len(models) == 0 {
|
||||
// API returned no models, use fallback
|
||||
return p.fallbackModels(), nil
|
||||
}
|
||||
|
||||
return models, nil
|
||||
}
|
||||
|
||||
// fallbackModels returns a hardcoded list of known OpenAI models.
|
||||
func (p *Provider) fallbackModels() []provider.ModelInfo {
|
||||
return []provider.ModelInfo{
|
||||
{
|
||||
ID: "gpt-4o", Name: "GPT-4o", Provider: p.name,
|
||||
@@ -116,5 +144,39 @@ func (p *Provider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
||||
MaxOutput: 100000,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// inferOpenAIModelCapabilities infers capabilities from model ID.
|
||||
func inferOpenAIModelCapabilities(modelID string) provider.Capabilities {
|
||||
// Default capabilities for most modern OpenAI models
|
||||
caps := provider.Capabilities{
|
||||
ToolUse: true,
|
||||
JSONOutput: true,
|
||||
Vision: true,
|
||||
ContextWindow: 128000,
|
||||
MaxOutput: 16384,
|
||||
}
|
||||
|
||||
// Model-specific overrides
|
||||
switch modelID {
|
||||
case "gpt-4o", "gpt-4o-mini":
|
||||
caps.ContextWindow = 128000
|
||||
caps.MaxOutput = 16384
|
||||
case "o3", "o3-mini":
|
||||
caps.ThinkingModes = []provider.EffortLevel{provider.EffortLow, provider.EffortMedium, provider.EffortHigh}
|
||||
caps.ContextWindow = 200000
|
||||
caps.MaxOutput = 100000
|
||||
case "gpt-4", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613":
|
||||
caps.Vision = false
|
||||
caps.ContextWindow = 8192
|
||||
caps.MaxOutput = 8192
|
||||
case "gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613":
|
||||
caps.Vision = false
|
||||
caps.ToolUse = false
|
||||
caps.ContextWindow = 16384
|
||||
caps.MaxOutput = 4096
|
||||
}
|
||||
|
||||
return caps
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user