diff --git a/internal/engine/coordinator_test.go b/internal/engine/coordinator_test.go index 4f0e08c..9993e92 100644 --- a/internal/engine/coordinator_test.go +++ b/internal/engine/coordinator_test.go @@ -7,16 +7,22 @@ import ( "somegit.dev/Owlibou/gnoma/internal/router" ) -func TestCoordinatorSystemPrompt_InjectedForOrchestration(t *testing.T) { +func TestCoordinatorPrompt_RequiredToolsMentioned(t *testing.T) { prompt := coordinatorPrompt() - if !strings.Contains(prompt, "spawn_elfs") { - t.Error("coordinator prompt must mention spawn_elfs") + for _, required := range []string{"spawn_elfs", "list_results", "read_result"} { + if !strings.Contains(prompt, required) { + t.Errorf("coordinator prompt must mention %q", required) + } } - if !strings.Contains(prompt, "list_results") { - t.Error("coordinator prompt must mention list_results") - } - if !strings.Contains(prompt, "read_result") { - t.Error("coordinator prompt must mention read_result") +} + +func TestCoordinatorPrompt_GuidanceContent(t *testing.T) { + prompt := strings.ToLower(coordinatorPrompt()) + // Prompt must instruct parallel dispatch, serial writes, and synthesis. + for _, required := range []string{"parallel", "serial", "synthesize"} { + if !strings.Contains(prompt, required) { + t.Errorf("coordinator prompt must contain guidance keyword %q", required) + } } } @@ -24,19 +30,29 @@ func TestShouldInjectCoordinatorPrompt(t *testing.T) { cases := []struct { prompt string want bool + note string }{ - {"orchestrate the migration", true}, - {"coordinate the refactor", true}, - {"dispatch tasks to elfs", true}, - {"fix the bug in main.go", false}, - {"explain this function", false}, - {"write unit tests for auth", false}, + // True positives: explicit orchestration intent + {"orchestrate the migration", true, "explicit orchestrat keyword"}, + {"coordinate the deployment", true, "explicit coordinate keyword"}, + {"dispatch tasks to elfs", true, "explicit dispatch keyword"}, + {"fan out this work to 5 elfs", true, "fan out keyword"}, + {"split this into subtasks", true, "subtask keyword"}, + {"delegate to worker elfs", true, "delegate to keyword"}, + // False positives: operational task types must gate first + {"fix the bug in main.go", false, "debug gates before orchestration"}, + {"explain this function", false, "explain gates before orchestration"}, + {"write unit tests for auth", false, "test gates before orchestration"}, + {"review the orchestration layer", false, "review gates even with orchestrat substring"}, + {"refactor the pipeline dispatch", false, "refactor gates even with dispatch substring"}, + {"debug the dispatch table", false, "debug gates even with dispatch substring"}, } for _, c := range cases { task := router.ClassifyTask(c.prompt) got := task.Type == router.TaskOrchestration if got != c.want { - t.Errorf("prompt %q: want orchestration=%v, got %v (type=%s)", c.prompt, c.want, got, task.Type) + t.Errorf("prompt %q (%s): want orchestration=%v, got %v (type=%s)", + c.prompt, c.note, c.want, got, task.Type) } } } diff --git a/internal/router/feedback_test.go b/internal/router/feedback_test.go index bab4802..a0e6e4d 100644 --- a/internal/router/feedback_test.go +++ b/internal/router/feedback_test.go @@ -3,6 +3,7 @@ package router_test import ( "testing" + "somegit.dev/Owlibou/gnoma/internal/provider" "somegit.dev/Owlibou/gnoma/internal/router" ) @@ -56,3 +57,59 @@ func TestQualityTracker_ConcurrentSafe(t *testing.T) { t.Errorf("invalid score after concurrent writes: %f", score) } } + +func TestQualityTracker_InfluencesArmSelection(t *testing.T) { + // After enough observations, the arm with a higher quality history should + // be preferred by Router.Select() over an identically-heuristic arm. + caps := provider.Capabilities{ToolUse: true} + armA := &router.Arm{ID: "test/arm-a", ModelName: "arm-a", Capabilities: caps} + armB := &router.Arm{ID: "test/arm-b", ModelName: "arm-b", Capabilities: caps} + + r := router.New(router.Config{}) + r.RegisterArm(armA) + r.RegisterArm(armB) + + // Record 5 successes for A, 5 failures for B — enough to exceed minObservations=3. + task := router.Task{Type: router.TaskGeneration, RequiresTools: true, Priority: router.PriorityNormal} + for range 5 { + r.ReportOutcome(router.Outcome{ArmID: "test/arm-a", TaskType: router.TaskGeneration, Success: true}) + r.ReportOutcome(router.Outcome{ArmID: "test/arm-b", TaskType: router.TaskGeneration, Success: false}) + } + + decision := r.Select(task) + if decision.Error != nil { + t.Fatalf("Select: %v", decision.Error) + } + defer decision.Rollback() + + if decision.Arm.ID != "test/arm-a" { + t.Errorf("expected arm-a (high quality history) to be selected, got %s", decision.Arm.ID) + } +} + +func TestQualityTracker_InsufficientDataFallsBackToHeuristic(t *testing.T) { + // Below minObservations (3), Quality() returns hasData=false and routing + // must still succeed (falls back to heuristic scoring). + caps := provider.Capabilities{ToolUse: true} + arm := &router.Arm{ID: "test/arm-x", ModelName: "arm-x", Capabilities: caps} + + r := router.New(router.Config{}) + r.RegisterArm(arm) + + // Only 1 observation — below the minimum. + r.ReportOutcome(router.Outcome{ArmID: "test/arm-x", TaskType: router.TaskGeneration, Success: true}) + + qt := r.QualityTracker() + _, hasData := qt.Quality("test/arm-x", router.TaskGeneration) + if hasData { + t.Error("expected no usable data below minObservations") + } + + // Router.Select must still succeed despite no quality data. + task := router.Task{Type: router.TaskGeneration, RequiresTools: true} + decision := r.Select(task) + if decision.Error != nil { + t.Errorf("Select should succeed via heuristic fallback: %v", decision.Error) + } + decision.Rollback() +} diff --git a/internal/tool/agent/agent.go b/internal/tool/agent/agent.go index d157a39..766996e 100644 --- a/internal/tool/agent/agent.go +++ b/internal/tool/agent/agent.go @@ -213,12 +213,7 @@ func (t *Tool) Execute(ctx context.Context, args json.RawMessage) (tool.Result, if result.Output != "" { // Truncate elf output to avoid flooding parent context. // The parent LLM gets enough to summarize; full text stays in the elf. - output := result.Output - const maxOutputChars = 2000 - if len(output) > maxOutputChars { - output = output[:maxOutputChars] + fmt.Sprintf("\n\n[truncated — full output was %d chars]", len(result.Output)) - } - b.WriteString(output) + b.WriteString(truncateOutput(result.Output, maxOutputChars)) } return tool.Result{ diff --git a/internal/tool/agent/agent_test.go b/internal/tool/agent/agent_test.go index 161c501..9293dfa 100644 --- a/internal/tool/agent/agent_test.go +++ b/internal/tool/agent/agent_test.go @@ -1,6 +1,8 @@ package agent import ( + "encoding/json" + "strings" "testing" "somegit.dev/Owlibou/gnoma/internal/router" @@ -50,3 +52,108 @@ func TestParseTaskType_AutoClassifiesWhenNoHint(t *testing.T) { } } } + +// --- Tool interface tests --- + +func TestAgentTool_Interface(t *testing.T) { + tool := New(nil, nil) + + if tool.Name() != "agent" { + t.Errorf("Name() = %q, want %q", tool.Name(), "agent") + } + if tool.Description() == "" { + t.Error("Description() must be non-empty") + } + if !json.Valid(tool.Parameters()) { + t.Error("Parameters() must be valid JSON") + } + if !tool.IsReadOnly() { + t.Error("IsReadOnly() must be true — agent spawning is non-destructive to parent context") + } + if tool.IsDestructive() { + t.Error("IsDestructive() must be false") + } +} + +func TestBatchTool_Interface(t *testing.T) { + tool := NewBatch(nil, nil) + + if tool.Name() != "spawn_elfs" { + t.Errorf("Name() = %q, want %q", tool.Name(), "spawn_elfs") + } + if tool.Description() == "" { + t.Error("Description() must be non-empty") + } + if !json.Valid(tool.Parameters()) { + t.Error("Parameters() must be valid JSON") + } + if !tool.IsReadOnly() { + t.Error("IsReadOnly() must be true") + } + if tool.IsDestructive() { + t.Error("IsDestructive() must be false") + } +} + +func TestListResultsTool_Interface(t *testing.T) { + tool := NewListResultsTool(nil) + + if tool.Name() != "list_results" { + t.Errorf("Name() = %q, want %q", tool.Name(), "list_results") + } + if !json.Valid(tool.Parameters()) { + t.Error("Parameters() must be valid JSON") + } + if !tool.IsReadOnly() { + t.Error("IsReadOnly() must be true") + } +} + +func TestReadResultTool_Interface(t *testing.T) { + tool := NewReadResultTool(nil) + + if tool.Name() != "read_result" { + t.Errorf("Name() = %q, want %q", tool.Name(), "read_result") + } + if !json.Valid(tool.Parameters()) { + t.Error("Parameters() must be valid JSON") + } + if !tool.IsReadOnly() { + t.Error("IsReadOnly() must be true") + } +} + +// --- Truncation tests --- + +func TestTruncateOutput(t *testing.T) { + tests := []struct { + name string + input string + max int + truncated bool // true = expect truncation; false = expect unchanged + want string // only checked when !truncated + }{ + {name: "short text unchanged", input: "hello", max: 2000, want: "hello"}, + {name: "empty string unchanged", input: "", max: 2000, want: ""}, + {name: "exact max unchanged", input: strings.Repeat("x", 2000), max: 2000, want: strings.Repeat("x", 2000)}, + {name: "over max is truncated", input: strings.Repeat("y", 5000), max: 2000, truncated: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := truncateOutput(tt.input, tt.max) + if !tt.truncated { + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + return + } + // Truncated: result must start with first max chars and include notice. + if !strings.HasPrefix(got, strings.Repeat("y", tt.max)) { + t.Error("truncated result should start with first max chars of input") + } + if !strings.Contains(got, "[truncated") { + t.Error("truncated result must contain '[truncated' notice") + } + }) + } +} diff --git a/internal/tool/agent/batch.go b/internal/tool/agent/batch.go index 87f846f..1472aeb 100644 --- a/internal/tool/agent/batch.go +++ b/internal/tool/agent/batch.go @@ -222,12 +222,7 @@ func (t *BatchTool) Execute(ctx context.Context, args json.RawMessage) (tool.Res fmt.Fprintf(&b, "Error: %v\n", r.Error) } if r.Output != "" { - output := r.Output - const maxOutputChars = 2000 - if len(output) > maxOutputChars { - output = output[:maxOutputChars] + fmt.Sprintf("\n\n[truncated — full output was %d chars]", len(r.Output)) - } - b.WriteString(output) + b.WriteString(truncateOutput(r.Output, maxOutputChars)) } b.WriteString("\n\n") } diff --git a/internal/tool/agent/format.go b/internal/tool/agent/format.go new file mode 100644 index 0000000..da1fa25 --- /dev/null +++ b/internal/tool/agent/format.go @@ -0,0 +1,13 @@ +package agent + +import "fmt" + +const maxOutputChars = 2000 + +// truncateOutput truncates output to max characters, appending a note with the original length. +func truncateOutput(output string, max int) string { + if len(output) <= max { + return output + } + return output[:max] + fmt.Sprintf("\n\n[truncated — full output was %d chars]", len(output)) +}