diff --git a/cmd/gnoma/main.go b/cmd/gnoma/main.go index 8809506..59f6d43 100644 --- a/cmd/gnoma/main.go +++ b/cmd/gnoma/main.go @@ -12,6 +12,7 @@ import ( "somegit.dev/Owlibou/gnoma/internal/engine" "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/Owlibou/gnoma/internal/router" "somegit.dev/Owlibou/gnoma/internal/security" anthropicprov "somegit.dev/Owlibou/gnoma/internal/provider/anthropic" "somegit.dev/Owlibou/gnoma/internal/provider/mistral" @@ -81,6 +82,23 @@ func main() { // Re-register bash tool with aliases reg.Register(bash.New(bash.WithAliases(aliases))) + // Create router and register the provider as a single arm + // (M4 foundation: one provider from CLI. Multi-provider routing comes with config.) + rtr := router.New(router.Config{Logger: logger}) + armModel := *model + if armModel == "" { + armModel = prov.DefaultModel() + } + armID := router.NewArmID(*providerName, armModel) + rtr.RegisterArm(&router.Arm{ + ID: armID, + Provider: prov, + ModelName: armModel, + IsLocal: localProviders[*providerName], + Capabilities: provider.Capabilities{ToolUse: true}, // trust CLI provider + }) + rtr.ForceArm(armID) + // Create firewall fw := security.NewFirewall(security.FirewallConfig{ ScanOutgoing: true, @@ -92,6 +110,7 @@ func main() { // Create engine eng, err := engine.New(engine.Config{ Provider: prov, + Router: rtr, Tools: reg, Firewall: fw, System: *system, diff --git a/internal/engine/engine.go b/internal/engine/engine.go index 8899da5..d42eb46 100644 --- a/internal/engine/engine.go +++ b/internal/engine/engine.go @@ -7,13 +7,15 @@ import ( "somegit.dev/Owlibou/gnoma/internal/message" "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/Owlibou/gnoma/internal/router" "somegit.dev/Owlibou/gnoma/internal/security" "somegit.dev/Owlibou/gnoma/internal/tool" ) // Config holds engine configuration. 
type Config struct { - Provider provider.Provider + Provider provider.Provider // direct provider (used if Router is nil) + Router *router.Router // nil = use Provider directly Tools *tool.Registry Firewall *security.Firewall // nil = no scanning System string // system prompt diff --git a/internal/engine/loop.go b/internal/engine/loop.go index a9a5378..9745277 100644 --- a/internal/engine/loop.go +++ b/internal/engine/loop.go @@ -7,6 +7,7 @@ import ( "somegit.dev/Owlibou/gnoma/internal/message" "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/Owlibou/gnoma/internal/router" "somegit.dev/Owlibou/gnoma/internal/stream" ) @@ -38,16 +39,50 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { // Build provider request (gates tools on model capabilities) req := e.buildRequest(ctx) - e.logger.Debug("streaming request", - "provider", e.cfg.Provider.Name(), - "model", req.Model, - "messages", len(req.Messages), - "tools", len(req.Tools), - "round", turn.Rounds, - ) + // Route and stream + var s stream.Stream + var err error - // Stream from provider - s, err := e.cfg.Provider.Stream(ctx, req) + if e.cfg.Router != nil { + // Classify task from the latest user message + prompt := "" + for i := len(e.history) - 1; i >= 0; i-- { + if e.history[i].Role == message.RoleUser { + prompt = e.history[i].TextContent() + break + } + } + task := router.ClassifyTask(prompt) + task.EstimatedTokens = 4000 // rough default + + e.logger.Debug("routing request", + "task_type", task.Type, + "complexity", task.ComplexityScore, + "round", turn.Rounds, + ) + + var arm *router.Arm + s, arm, err = e.cfg.Router.Stream(ctx, task, req) + if arm != nil { + e.logger.Debug("streaming request", + "provider", arm.Provider.Name(), + "model", arm.ModelName, + "arm", arm.ID, + "messages", len(req.Messages), + "tools", len(req.Tools), + "round", turn.Rounds, + ) + } + } else { + e.logger.Debug("streaming request", + "provider", e.cfg.Provider.Name(), + "model", req.Model, + 
"messages", len(req.Messages), + "tools", len(req.Tools), + "round", turn.Rounds, + ) + s, err = e.cfg.Provider.Stream(ctx, req) + } if err != nil { return nil, fmt.Errorf("provider stream: %w", err) } diff --git a/internal/router/arm.go b/internal/router/arm.go new file mode 100644 index 0000000..d28f183 --- /dev/null +++ b/internal/router/arm.go @@ -0,0 +1,47 @@ +package router + +import ( + "somegit.dev/Owlibou/gnoma/internal/provider" +) + +// ArmID uniquely identifies a model+provider pair. +type ArmID string + +// Arm represents a provider+model pair available for routing. +type Arm struct { + ID ArmID + Provider provider.Provider + ModelName string + IsLocal bool + Capabilities provider.Capabilities + Pools []*LimitPool + + // Cost per 1k tokens (EUR, estimated) + CostPer1kInput float64 + CostPer1kOutput float64 +} + +// NewArmID creates an arm ID from provider name and model. +func NewArmID(providerName, model string) ArmID { + return ArmID(providerName + "/" + model) +} + +// EstimateCost returns estimated cost in EUR for a task. +func (a *Arm) EstimateCost(estimatedTokens int) float64 { + // Rough estimate: 60% input, 40% output + inputTokens := float64(estimatedTokens) * 0.6 + outputTokens := float64(estimatedTokens) * 0.4 + return (inputTokens/1000)*a.CostPer1kInput + (outputTokens/1000)*a.CostPer1kOutput +} + +// SupportsTools returns true if this arm's model supports function calling. +func (a *Arm) SupportsTools() bool { + return a.Capabilities.ToolUse +} + +// ArmPerf holds live performance metrics for an arm. 
+type ArmPerf struct { + TTFT_P50_ms float64 // time to first token, p50 + TTFT_P95_ms float64 // time to first token, p95 + ToksPerSec float64 // tokens per second throughput +} diff --git a/internal/router/pool.go b/internal/router/pool.go new file mode 100644 index 0000000..725b75c --- /dev/null +++ b/internal/router/pool.go @@ -0,0 +1,170 @@ +package router + +import ( + "math" + "sync" + "time" +) + +// PoolKind identifies the type of resource a pool tracks. +type PoolKind int + +const ( + PoolRPM PoolKind = iota // requests per minute + PoolRPD // requests per day + PoolTPD // tokens per day + PoolCostEUR // monetary cost cap + PoolCustom // arbitrary units +) + +// LimitPool tracks a shared resource budget that arms draw from. +type LimitPool struct { + mu sync.Mutex + + ID string + Kind PoolKind + TotalLimit float64 + Used float64 + Reserved float64 // optimistically reserved for in-flight requests + ResetPeriod time.Duration + ResetAt time.Time + + // Per-arm consumption rates (units per 1k tokens or per request) + ArmRates map[ArmID]float64 + + // Scarcity curve aggressiveness. k=2 gentle, k=4 aggressive hoarding. + ScarcityK float64 +} + +// RemainingFraction returns the fraction of budget still available. +func (p *LimitPool) RemainingFraction() float64 { + p.mu.Lock() + defer p.mu.Unlock() + if p.TotalLimit <= 0 { + return 0 + } + return 1.0 - (p.Used+p.Reserved)/p.TotalLimit +} + +// ScarcityMultiplier returns a cost inflation factor based on remaining budget. +// As resources deplete, the multiplier increases, making the arm more expensive. 
+func (p *LimitPool) ScarcityMultiplier() float64 { + p.mu.Lock() + defer p.mu.Unlock() + return p.scarcityMultiplierLocked() +} + +func (p *LimitPool) scarcityMultiplierLocked() float64 { + if p.TotalLimit <= 0 { + return math.Inf(1) + } + + f := 1.0 - (p.Used+p.Reserved)/p.TotalLimit + if f <= 0 { + return math.Inf(1) // exhausted + } + + // Use-it-or-lose-it: if reset is imminent and headroom exists, discount + hoursToReset := time.Until(p.ResetAt).Hours() + if !p.ResetAt.IsZero() && hoursToReset > 0 && hoursToReset < 1.0 && f > 0.3 { + return 0.5 + } + + k := p.ScarcityK + if k <= 0 { + k = 2.0 // gentle default + } + return 1.0 / math.Pow(f, k) +} + +// Exhausted returns true if the pool has no remaining capacity. +func (p *LimitPool) Exhausted() bool { + return p.RemainingFraction() <= 0 +} + +// CanAfford returns true if the pool can cover the projected consumption. +func (p *LimitPool) CanAfford(armID ArmID, estimatedTokens int) bool { + p.mu.Lock() + defer p.mu.Unlock() + + rate := p.ArmRates[armID] + if rate == 0 { + return true // no rate defined = no limit + } + projected := rate * float64(estimatedTokens) / 1000.0 + available := p.TotalLimit - p.Used - p.Reserved + return projected <= available +} + +// Reservation represents an optimistic resource reservation. +type Reservation struct { + pool *LimitPool + armID ArmID + projected float64 + committed bool +} + +// Reserve creates an optimistic reservation. Call Commit() with actual usage +// on completion, or Rollback() on failure. 
+func (p *LimitPool) Reserve(armID ArmID, estimatedTokens int) (*Reservation, bool) { + p.mu.Lock() + defer p.mu.Unlock() + + rate := p.ArmRates[armID] + if rate == 0 { + return &Reservation{pool: p}, true // no limit + } + + projected := rate * float64(estimatedTokens) / 1000.0 + available := p.TotalLimit - p.Used - p.Reserved + if projected > available { + return nil, false + } + + p.Reserved += projected + return &Reservation{ + pool: p, + armID: armID, + projected: projected, + }, true +} + +// Commit finalizes the reservation with actual consumption. +func (r *Reservation) Commit(actualTokens int) { + if r.committed || r.pool == nil { + return + } + r.committed = true + r.pool.mu.Lock() + defer r.pool.mu.Unlock() + + rate := r.pool.ArmRates[r.armID] + actual := rate * float64(actualTokens) / 1000.0 + + r.pool.Reserved -= r.projected + r.pool.Used += actual +} + +// Rollback releases the reservation without consumption. +func (r *Reservation) Rollback() { + if r.committed || r.pool == nil || r.projected == 0 { + return + } + r.committed = true + r.pool.mu.Lock() + defer r.pool.mu.Unlock() + + r.pool.Reserved -= r.projected +} + +// CheckReset resets usage if the reset period has elapsed. +func (p *LimitPool) CheckReset() { + p.mu.Lock() + defer p.mu.Unlock() + + if !p.ResetAt.IsZero() && time.Now().After(p.ResetAt) { + p.Used = 0 + p.Reserved = 0 + p.ResetAt = p.ResetAt.Add(p.ResetPeriod) + } +} diff --git a/internal/router/router.go b/internal/router/router.go new file mode 100644 index 0000000..c2283a1 --- /dev/null +++ b/internal/router/router.go @@ -0,0 +1,160 @@ +package router + +import ( + "context" + "fmt" + "log/slog" + "sync" + + "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/Owlibou/gnoma/internal/stream" +) + +// Router selects the best arm for a given task. +// M4: heuristic selection. M9: bandit learning. 
+type Router struct { + mu sync.RWMutex + arms map[ArmID]*Arm + logger *slog.Logger + + // Optional: force a specific arm (--provider flag override) + forcedArm ArmID +} + +type Config struct { + Logger *slog.Logger +} + +func New(cfg Config) *Router { + logger := cfg.Logger + if logger == nil { + logger = slog.Default() + } + return &Router{ + arms: make(map[ArmID]*Arm), + logger: logger, + } +} + +// RegisterArm adds an arm to the router. +func (r *Router) RegisterArm(arm *Arm) { + r.mu.Lock() + defer r.mu.Unlock() + r.arms[arm.ID] = arm + r.logger.Debug("arm registered", "id", arm.ID, "local", arm.IsLocal, "tools", arm.SupportsTools()) +} + +// ForceArm overrides routing to always select a specific arm. +// Used for --provider CLI flag. +func (r *Router) ForceArm(id ArmID) { + r.mu.Lock() + defer r.mu.Unlock() + r.forcedArm = id +} + +// Select picks the best arm for the given task. +func (r *Router) Select(task Task) RoutingDecision { + r.mu.RLock() + defer r.mu.RUnlock() + + // If an arm is forced, use it directly + if r.forcedArm != "" { + arm, ok := r.arms[r.forcedArm] + if !ok { + return RoutingDecision{Error: fmt.Errorf("forced arm %q not found", r.forcedArm)} + } + return RoutingDecision{Strategy: StrategySingleArm, Arm: arm} + } + + // Collect all arms + allArms := make([]*Arm, 0, len(r.arms)) + for _, arm := range r.arms { + allArms = append(allArms, arm) + } + + if len(allArms) == 0 { + return RoutingDecision{Error: fmt.Errorf("no arms registered")} + } + + // Filter to feasible arms + feasible := filterFeasible(allArms, task) + if len(feasible) == 0 { + return RoutingDecision{Error: fmt.Errorf("no feasible arm for task type %s", task.Type)} + } + + // Select best + best := selectBest(feasible, task) + if best == nil { + return RoutingDecision{Error: fmt.Errorf("selection failed")} + } + + r.logger.Debug("arm selected", + "arm", best.ID, + "task_type", task.Type, + "complexity", task.ComplexityScore, + ) + + return RoutingDecision{Strategy: 
StrategySingleArm, Arm: best} +} + +// Arms returns all registered arms. +func (r *Router) Arms() []*Arm { + r.mu.RLock() + defer r.mu.RUnlock() + arms := make([]*Arm, 0, len(r.arms)) + for _, a := range r.arms { + arms = append(arms, a) + } + return arms +} + +// RegisterProvider registers all models from a provider as arms. +func (r *Router) RegisterProvider(ctx context.Context, prov provider.Provider, isLocal bool, costs map[string][2]float64) { + models, err := prov.Models(ctx) + if err != nil { + r.logger.Debug("failed to list models", "provider", prov.Name(), "error", err) + // Register at least the default model + id := NewArmID(prov.Name(), prov.DefaultModel()) + r.RegisterArm(&Arm{ + ID: id, + Provider: prov, + ModelName: prov.DefaultModel(), + IsLocal: isLocal, + Capabilities: provider.Capabilities{ToolUse: true}, // optimistic + }) + return + } + + for _, m := range models { + id := NewArmID(prov.Name(), m.ID) + arm := &Arm{ + ID: id, + Provider: prov, + ModelName: m.ID, + IsLocal: isLocal, + Capabilities: m.Capabilities, + } + if c, ok := costs[m.ID]; ok { + arm.CostPer1kInput = c[0] + arm.CostPer1kOutput = c[1] + } + r.RegisterArm(arm) + } +} + +// Stream is a convenience that selects an arm and streams from it. 
+func (r *Router) Stream(ctx context.Context, task Task, req provider.Request) (stream.Stream, *Arm, error) { + decision := r.Select(task) + if decision.Error != nil { + return nil, nil, decision.Error + } + + arm := decision.Arm + req.Model = arm.ModelName + + s, err := arm.Provider.Stream(ctx, req) + if err != nil { + return nil, arm, err + } + return s, arm, nil +} diff --git a/internal/router/router_test.go b/internal/router/router_test.go new file mode 100644 index 0000000..1f5cf9a --- /dev/null +++ b/internal/router/router_test.go @@ -0,0 +1,305 @@ +package router + +import ( + "math" + "testing" + "time" + + "somegit.dev/Owlibou/gnoma/internal/provider" +) + +// --- Task Classification --- + +func TestClassifyTask(t *testing.T) { + tests := []struct { + prompt string + want TaskType + }{ + {"fix the bug in auth module", TaskDebug}, + {"review this pull request", TaskReview}, + {"refactor the database layer", TaskRefactor}, + {"write unit tests for the handler", TaskUnitTest}, + {"explain how the router works", TaskExplain}, + {"create a new REST endpoint", TaskGeneration}, + {"plan the migration strategy", TaskPlanning}, + {"audit security of the API", TaskSecurityReview}, + {"scaffold a new service", TaskBoilerplate}, + {"coordinate the deployment pipeline", TaskOrchestration}, + {"hello", TaskGeneration}, // default + } + for _, tt := range tests { + task := ClassifyTask(tt.prompt) + if task.Type != tt.want { + t.Errorf("ClassifyTask(%q).Type = %s, want %s", tt.prompt, task.Type, tt.want) + } + } +} + +func TestClassifyTask_RequiresTools(t *testing.T) { + // Explain tasks don't require tools + task := ClassifyTask("explain how generics work") + if task.RequiresTools { + t.Error("explain task should not require tools") + } + + // Debug tasks require tools + task = ClassifyTask("debug the failing test") + if !task.RequiresTools { + t.Error("debug task should require tools") + } +} + +func TestTaskValueScore(t *testing.T) { + low := Task{Type: 
TaskBoilerplate, Priority: PriorityLow} + high := Task{Type: TaskSecurityReview, Priority: PriorityCritical} + + if low.ValueScore() >= high.ValueScore() { + t.Errorf("low priority boilerplate (%f) should score less than critical security (%f)", + low.ValueScore(), high.ValueScore()) + } +} + +func TestEstimateComplexity(t *testing.T) { + simple := estimateComplexity("rename the variable") + complex := estimateComplexity("design and implement a distributed caching system with migration support and integration testing across multiple environments") + + if simple >= complex { + t.Errorf("simple (%f) should be less than complex (%f)", simple, complex) + } +} + +// --- Arm --- + +func TestArmEstimateCost(t *testing.T) { + arm := &Arm{ + CostPer1kInput: 0.003, + CostPer1kOutput: 0.015, + } + cost := arm.EstimateCost(10000) + // 6000 input tokens * 0.003/1k + 4000 output tokens * 0.015/1k + // = 0.018 + 0.060 = 0.078 + if cost < 0.07 || cost > 0.09 { + t.Errorf("EstimateCost(10000) = %f, want ~0.078", cost) + } +} + +func TestArmEstimateCost_Free(t *testing.T) { + arm := &Arm{} // local model, zero cost + cost := arm.EstimateCost(10000) + if cost != 0 { + t.Errorf("free model should have zero cost, got %f", cost) + } +} + +// --- Pool --- + +func TestLimitPool_RemainingFraction(t *testing.T) { + p := &LimitPool{TotalLimit: 100, Used: 30, Reserved: 20} + f := p.RemainingFraction() + if f != 0.5 { + t.Errorf("RemainingFraction = %f, want 0.5", f) + } +} + +func TestLimitPool_ScarcityMultiplier(t *testing.T) { + // Half remaining, k=2: 1/0.5^2 = 4 + p := &LimitPool{TotalLimit: 100, Used: 50, ScarcityK: 2} + m := p.ScarcityMultiplier() + if m < 3.9 || m > 4.1 { + t.Errorf("ScarcityMultiplier = %f, want ~4.0", m) + } +} + +func TestLimitPool_ScarcityMultiplier_Exhausted(t *testing.T) { + p := &LimitPool{TotalLimit: 100, Used: 100} + m := p.ScarcityMultiplier() + if !math.IsInf(m, 1) { + t.Errorf("exhausted pool should return +Inf, got %f", m) + } +} + +func 
TestLimitPool_ScarcityMultiplier_UseItOrLoseIt(t *testing.T) { + p := &LimitPool{ + TotalLimit: 100, Used: 30, // 70% remaining + ScarcityK: 2, + ResetAt: time.Now().Add(30 * time.Minute), // reset in 30 min + } + m := p.ScarcityMultiplier() + if m != 0.5 { + t.Errorf("use-it-or-lose-it discount: ScarcityMultiplier = %f, want 0.5", m) + } +} + +func TestLimitPool_ReserveAndCommit(t *testing.T) { + p := &LimitPool{ + TotalLimit: 1000, + ArmRates: map[ArmID]float64{"test/model": 5.0}, // 5 units per 1k tokens + ScarcityK: 2, + } + + res, ok := p.Reserve("test/model", 10000) // 5 * 10 = 50 units + if !ok { + t.Fatal("reservation should succeed") + } + if p.Reserved != 50 { + t.Errorf("Reserved = %f, want 50", p.Reserved) + } + + res.Commit(8000) // actual: 5 * 8 = 40 units + if p.Used != 40 { + t.Errorf("Used = %f, want 40", p.Used) + } + if p.Reserved != 0 { + t.Errorf("Reserved = %f, want 0 after commit", p.Reserved) + } +} + +func TestLimitPool_ReserveExhausted(t *testing.T) { + p := &LimitPool{ + TotalLimit: 100, + Used: 90, + ArmRates: map[ArmID]float64{"test/model": 5.0}, + ScarcityK: 2, + } + + _, ok := p.Reserve("test/model", 10000) // needs 50, only 10 available + if ok { + t.Error("reservation should fail when exhausted") + } +} + +func TestLimitPool_Rollback(t *testing.T) { + p := &LimitPool{ + TotalLimit: 1000, + ArmRates: map[ArmID]float64{"test/model": 5.0}, + ScarcityK: 2, + } + + res, _ := p.Reserve("test/model", 10000) + if p.Reserved != 50 { + t.Fatalf("Reserved = %f, want 50", p.Reserved) + } + + res.Rollback() + if p.Reserved != 0 { + t.Errorf("Reserved = %f, want 0 after rollback", p.Reserved) + } + if p.Used != 0 { + t.Errorf("Used = %f, want 0 after rollback", p.Used) + } +} + +func TestLimitPool_CheckReset(t *testing.T) { + p := &LimitPool{ + TotalLimit: 1000, + Used: 500, + Reserved: 100, + ResetPeriod: time.Hour, + ResetAt: time.Now().Add(-time.Minute), // already passed + ScarcityK: 2, + } + + p.CheckReset() + if p.Used != 0 { + 
t.Errorf("Used = %f after reset, want 0", p.Used) + } + if p.Reserved != 0 { + t.Errorf("Reserved = %f after reset, want 0", p.Reserved) + } +} + +// --- Selector --- + +func TestSelectBest_PrefersToolSupport(t *testing.T) { + withTools := &Arm{ + ID: "a/with-tools", ModelName: "with-tools", + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 128000}, + } + withoutTools := &Arm{ + ID: "b/no-tools", ModelName: "no-tools", + Capabilities: provider.Capabilities{ToolUse: false, ContextWindow: 128000}, + } + + task := Task{Type: TaskGeneration, RequiresTools: true, Priority: PriorityNormal} + best := selectBest([]*Arm{withoutTools, withTools}, task) + + if best.ID != "a/with-tools" { + t.Errorf("should prefer arm with tool support, got %s", best.ID) + } +} + +func TestSelectBest_PrefersThinkingForPlanning(t *testing.T) { + thinking := &Arm{ + ID: "a/thinking", ModelName: "thinking", + Capabilities: provider.Capabilities{ToolUse: true, Thinking: true, ContextWindow: 200000}, + CostPer1kInput: 0.01, CostPer1kOutput: 0.05, + } + noThinking := &Arm{ + ID: "b/basic", ModelName: "basic", + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 128000}, + CostPer1kInput: 0.01, CostPer1kOutput: 0.05, + } + + task := Task{Type: TaskPlanning, RequiresTools: true, Priority: PriorityNormal, EstimatedTokens: 5000} + best := selectBest([]*Arm{noThinking, thinking}, task) + + if best.ID != "a/thinking" { + t.Errorf("should prefer thinking model for planning, got %s", best.ID) + } +} + +func TestFilterFeasible_ExcludesExhausted(t *testing.T) { + pool := &LimitPool{ + TotalLimit: 100, + Used: 100, // exhausted + ArmRates: map[ArmID]float64{"a/model": 1.0}, + ScarcityK: 2, + } + arm := &Arm{ + ID: "a/model", ModelName: "model", + Capabilities: provider.Capabilities{ToolUse: true}, + Pools: []*LimitPool{pool}, + } + + task := Task{Type: TaskGeneration, RequiresTools: true, EstimatedTokens: 1000} + feasible := filterFeasible([]*Arm{arm}, task) + + if 
len(feasible) != 0 { + t.Error("exhausted arm should not be feasible") + } +} + +// --- Router --- + +func TestRouter_SelectForced(t *testing.T) { + r := New(Config{}) + r.RegisterArm(&Arm{ID: "a/model1", Capabilities: provider.Capabilities{ToolUse: true}}) + r.RegisterArm(&Arm{ID: "b/model2", Capabilities: provider.Capabilities{ToolUse: true}}) + + r.ForceArm("b/model2") + + decision := r.Select(Task{Type: TaskGeneration}) + if decision.Error != nil { + t.Fatalf("Select: %v", decision.Error) + } + if decision.Arm.ID != "b/model2" { + t.Errorf("forced arm should be selected, got %s", decision.Arm.ID) + } +} + +func TestRouter_SelectNoArms(t *testing.T) { + r := New(Config{}) + decision := r.Select(Task{Type: TaskGeneration}) + if decision.Error == nil { + t.Error("should error with no arms") + } +} + +func TestRouter_SelectForcedNotFound(t *testing.T) { + r := New(Config{}) + r.ForceArm("nonexistent/model") + decision := r.Select(Task{Type: TaskGeneration}) + if decision.Error == nil { + t.Error("should error when forced arm not found") + } +} diff --git a/internal/router/selector.go b/internal/router/selector.go new file mode 100644 index 0000000..65c4ae1 --- /dev/null +++ b/internal/router/selector.go @@ -0,0 +1,167 @@ +package router + +import ( + "math" +) + +// Strategy identifies how a task should be executed. +type Strategy int + +const ( + StrategySingleArm Strategy = iota + // Future (M9): StrategyCascade, StrategyParallelEnsemble, StrategyMultiRound +) + +// RoutingDecision is the result of arm selection. +type RoutingDecision struct { + Strategy Strategy + Arm *Arm // primary arm + Error error +} + +// selectBest picks the highest-scoring feasible arm using heuristic scoring. +// No bandit learning — that's M9. Just smart defaults based on model size, +// locality, task type, cost, and pool scarcity. 
+func selectBest(arms []*Arm, task Task) *Arm { + if len(arms) == 0 { + return nil + } + + var best *Arm + bestScore := math.Inf(-1) + + for _, arm := range arms { + score := scoreArm(arm, task) + if score > bestScore { + bestScore = score + best = arm + } + } + + return best +} + +// scoreArm computes a heuristic quality/cost score for an arm. +// Score = (quality × value) / effective_cost +func scoreArm(arm *Arm, task Task) float64 { + quality := heuristicQuality(arm, task) + value := task.ValueScore() + cost := effectiveCost(arm, task) + + if cost <= 0 { + cost = 0.001 // prevent division by zero for free local models + } + + return (quality * value) / cost +} + +// heuristicQuality estimates arm quality without historical data. +func heuristicQuality(arm *Arm, task Task) float64 { + score := 0.5 // base + + // Larger context window = better for complex tasks + if arm.Capabilities.ContextWindow >= 100000 { + score += 0.1 + } + if arm.Capabilities.ContextWindow >= 200000 { + score += 0.05 + } + + // Thinking capability valuable for planning/orchestration/security + if arm.Capabilities.Thinking { + switch task.Type { + case TaskPlanning, TaskOrchestration, TaskSecurityReview: + score += 0.2 + case TaskDebug, TaskRefactor: + score += 0.1 + } + } + + // Tool support required — arm without tools gets heavy penalty + if task.RequiresTools && !arm.SupportsTools() { + score *= 0.1 + } + + // Local models get a small boost (no network latency, privacy) + if arm.IsLocal { + score += 0.05 + } + + // Complexity adjustment — complex tasks penalize small/local models + if task.ComplexityScore > 0.7 && arm.IsLocal { + score *= 0.7 + } + + // Clamp + if score > 1.0 { + score = 1.0 + } + if score < 0.0 { + score = 0.0 + } + return score +} + +// effectiveCost returns the base cost inflated by pool scarcity. 
+func effectiveCost(arm *Arm, task Task) float64 { + base := arm.EstimateCost(task.EstimatedTokens) + if base <= 0 { + base = 0.001 // local models are ~free but not zero for scoring + } + + // Apply maximum scarcity multiplier across all pools + maxMultiplier := 1.0 + for _, pool := range arm.Pools { + m := pool.ScarcityMultiplier() + if m > maxMultiplier { + maxMultiplier = m + } + } + + return base * maxMultiplier +} + +// filterFeasible returns arms that can handle the task (tools, pool capacity). +func filterFeasible(arms []*Arm, task Task) []*Arm { + var feasible []*Arm + for _, arm := range arms { + // Must support tools if task requires them + if task.RequiresTools && !arm.SupportsTools() { + continue + } + + // Check all pools have capacity + poolsOK := true + for _, pool := range arm.Pools { + pool.CheckReset() + if !pool.CanAfford(arm.ID, task.EstimatedTokens) { + poolsOK = false + break + } + } + if !poolsOK { + continue + } + + feasible = append(feasible, arm) + } + + // If no arm with tools is feasible but task requires them, + // fall back to any available arm (tool-less is better than nothing) + if len(feasible) == 0 && task.RequiresTools { + for _, arm := range arms { + poolsOK := true + for _, pool := range arm.Pools { + if !pool.CanAfford(arm.ID, task.EstimatedTokens) { + poolsOK = false + break + } + } + if poolsOK { + feasible = append(feasible, arm) + } + } + } + + return feasible +} diff --git a/internal/router/task.go b/internal/router/task.go new file mode 100644 index 0000000..c1320b7 --- /dev/null +++ b/internal/router/task.go @@ -0,0 +1,199 @@ +package router + +import ( + "fmt" + "strings" +) + +// TaskType classifies a task for routing purposes. 
type TaskType int

const (
	TaskBoilerplate    TaskType = iota // simple scaffolding, templates
	TaskGeneration                     // new code creation
	TaskRefactor                       // restructuring existing code
	TaskReview                         // code review, analysis
	TaskUnitTest                       // writing tests
	TaskPlanning                       // architecture, design
	TaskOrchestration                  // multi-step coordination
	TaskSecurityReview                 // security-focused analysis
	TaskDebug                          // finding and fixing bugs
	TaskExplain                        // explaining code or concepts
)

// String returns the snake_case label of the task type, or "unknown(N)" for
// a value outside the declared constants.
func (t TaskType) String() string {
	switch t {
	case TaskBoilerplate:
		return "boilerplate"
	case TaskGeneration:
		return "generation"
	case TaskRefactor:
		return "refactor"
	case TaskReview:
		return "review"
	case TaskUnitTest:
		return "unit_test"
	case TaskPlanning:
		return "planning"
	case TaskOrchestration:
		return "orchestration"
	case TaskSecurityReview:
		return "security_review"
	case TaskDebug:
		return "debug"
	case TaskExplain:
		return "explain"
	default:
		return fmt.Sprintf("unknown(%d)", int(t))
	}
}

// Priority indicates task importance for routing decisions.
type Priority int

const (
	PriorityLow Priority = iota
	PriorityNormal
	PriorityHigh
	PriorityCritical
)

// Task represents a classified unit of work for routing.
type Task struct {
	Type            TaskType
	Priority        Priority
	EstimatedTokens int
	RequiresTools   bool
	ComplexityScore float64 // 0-1
}

// priorityBaseValue is the value contributed by each priority level.
// It lives at package level so ValueScore does not rebuild it on every call.
var priorityBaseValue = map[Priority]float64{
	PriorityLow:      0.5,
	PriorityNormal:   1.0,
	PriorityHigh:     2.0,
	PriorityCritical: 5.0,
}

// ValueScore computes a routing value based on priority and type: the
// priority's base value times the task type's multiplier.
//
// A priority or type with no table entry contributes a neutral 1.0 instead
// of silently zeroing the score, which would make every arm tie at 0.
func (t Task) ValueScore() float64 {
	base, ok := priorityBaseValue[t.Priority]
	if !ok {
		base = 1.0
	}
	mult, ok := taskTypeMultiplier[t.Type]
	if !ok {
		mult = 1.0
	}
	return base * mult
}

// taskTypeMultiplier weights each task type's value relative to plain
// generation (1.0).
var taskTypeMultiplier = map[TaskType]float64{
	TaskBoilerplate:    0.6,
	TaskGeneration:     1.0,
	TaskRefactor:       0.9,
	TaskReview:         1.1,
	TaskUnitTest:       0.8,
	TaskPlanning:       1.4,
	TaskOrchestration:  1.5,
	TaskSecurityReview: 2.0,
	TaskDebug:          1.2,
	TaskExplain:        0.7,
}

// QualityThreshold defines minimum acceptable quality for a task type.
type QualityThreshold struct {
	Minimum    float64 // below → output is harmful, never accept
	Acceptable float64 // good enough
	Target     float64 // ideal
}

// DefaultThresholds lists the quality thresholds per task type.
var DefaultThresholds = map[TaskType]QualityThreshold{
	TaskBoilerplate:    {0.50, 0.70, 0.80},
	TaskGeneration:     {0.60, 0.75, 0.88},
	TaskRefactor:       {0.65, 0.78, 0.90},
	TaskReview:         {0.70, 0.82, 0.92},
	TaskUnitTest:       {0.60, 0.75, 0.85},
	TaskPlanning:       {0.75, 0.88, 0.95},
	TaskOrchestration:  {0.80, 0.90, 0.96},
	TaskSecurityReview: {0.88, 0.94, 0.99},
	TaskDebug:          {0.65, 0.80, 0.90},
	TaskExplain:        {0.55, 0.72, 0.85},
}

// ClassifyTask infers a TaskType from the user's prompt using keyword heuristics.
//
// Matching is case-insensitive substring search and the first matching case
// wins. The returned task has PriorityNormal unless the type raises it,
// requires tools unless it is an explain task, and carries a complexity
// estimate from estimateComplexity. EstimatedTokens is left at zero.
func ClassifyTask(prompt string) Task {
	lower := strings.ToLower(prompt)

	task := Task{
		Priority:      PriorityNormal,
		RequiresTools: true, // assume tools needed by default
	}

	// Check for task type keywords (order matters — more specific first)
	switch {
	case containsAny(lower, "security", "vulnerability", "cve", "owasp", "xss", "injection", "audit security"):
		task.Type = TaskSecurityReview
		task.Priority = PriorityHigh
	case containsAny(lower, "plan", "architect", "design", "strategy", "roadmap"):
		task.Type = TaskPlanning
	case containsAny(lower, "orchestrat", "coordinate", "dispatch", "pipeline"):
		task.Type = TaskOrchestration
		task.Priority = PriorityHigh
	case containsAny(lower, "debug", "fix", "troubleshoot", "not working", "error", "crash", "failing", "bug"):
		task.Type = TaskDebug
	case containsAny(lower, "review", "check", "analyze", "audit", "inspect"):
		task.Type = TaskReview
	case containsAny(lower, "refactor", "restructure", "reorganize", "clean up", "simplify"):
		task.Type = TaskRefactor
	case containsAny(lower, "test", "spec", "coverage", "assert"):
		task.Type = TaskUnitTest
	case containsAny(lower, "explain", "what is", "how does", "describe", "tell me about"):
		task.Type = TaskExplain
		task.RequiresTools = false
	// Boilerplate must be tested before generation: its keywords are the
	// more specific ones, and prompts such as "create a boilerplate
	// template" would otherwise be swallowed by "create".
	case containsAny(lower, "scaffold", "boilerplate", "template", "stub", "skeleton"):
		task.Type = TaskBoilerplate
	case containsAny(lower, "create", "implement", "build", "add", "write", "generate", "make"):
		task.Type = TaskGeneration
	default:
		task.Type = TaskGeneration // default
	}

	// Estimate complexity from prompt length and keywords
	task.ComplexityScore = estimateComplexity(lower)

	return task
}

// containsAny reports whether s contains at least one of keywords as a
// substring.
func containsAny(s string, keywords ...string) bool {
	for _, kw := range keywords {
		if strings.Contains(s, kw) {
			return true
		}
	}
	return false
}

// estimateComplexity scores a prompt's complexity in [0, 1]. The prompt is
// matched as given; ClassifyTask passes it lowercased.
//
// The score is the word count divided by 200, plus 0.15 for each
// "complex" keyword present and minus 0.15 for each "simple" keyword
// present, clamped to the range.
func estimateComplexity(prompt string) float64 {
	score := 0.0

	// Length contributes to complexity
	words := len(strings.Fields(prompt))
	score += float64(words) / 200.0 // normalize: 200 words = 1.0

	// Complexity keywords
	complexKeywords := []string{"implement", "design", "architect", "system", "integration", "migrate", "optimize"}
	for _, kw := range complexKeywords {
		if strings.Contains(prompt, kw) {
			score += 0.15
		}
	}

	// Simple keywords reduce complexity
	simpleKeywords := []string{"rename", "format", "add field", "change name", "typo", "simple"}
	for _, kw := range simpleKeywords {
		if strings.Contains(prompt, kw) {
			score -= 0.15
		}
	}

	// Clamp to [0, 1]
	if score < 0 {
		score = 0
	}
	if score > 1 {
		score = 1
	}
	return score
}