feat(classifier): Wave A — TaskClassifier interface + HeuristicClassifier

- internal/router/classifier.go: TaskClassifier interface with
  Classify(ctx, prompt, history) signature. HeuristicClassifier wraps
  the existing ClassifyTask() with zero behavior change.

- engine.Config.Classifier: injectable TaskClassifier; nil defaults
  to HeuristicClassifier. Engine.classify() helper handles nil + error
  fallback transparently.

- loop.go: all four router.ClassifyTask() call sites replaced with
  e.classify(ctx, prompt). SLMClassifier slots in without further
  changes to the engine.
This commit is contained in:
2026-05-07 16:11:20 +02:00
parent 0b1392cf6b
commit 8b2202e8ec
5 changed files with 210 additions and 16 deletions
+28 -12
View File
@@ -6,30 +6,31 @@ import (
"log/slog"
gnomactx "somegit.dev/Owlibou/gnoma/internal/context"
"somegit.dev/Owlibou/gnoma/internal/hook"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/permission"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/router"
"somegit.dev/Owlibou/gnoma/internal/security"
"somegit.dev/Owlibou/gnoma/internal/tool"
"somegit.dev/Owlibou/gnoma/internal/hook"
"somegit.dev/Owlibou/gnoma/internal/tool/persist"
)
// Config holds engine configuration.
type Config struct {
Provider provider.Provider // direct provider (used if Router is nil)
Router *router.Router // nil = use Provider directly
Provider provider.Provider // direct provider (used if Router is nil)
Router *router.Router // nil = use Provider directly
Classifier router.TaskClassifier // nil = HeuristicClassifier
Tools *tool.Registry
Firewall *security.Firewall // nil = no scanning
Permissions *permission.Checker // nil = allow all
Context *gnomactx.Window // nil = no compaction
System string // system prompt
Model string // override model (empty = provider default)
Temperature *float64 // nil = provider default
MaxTurns int // safety limit on tool loops (0 = unlimited)
Store *persist.Store // nil = no result persistence
Hooks *hook.Dispatcher // nil = no hooks
Firewall *security.Firewall // nil = no scanning
Permissions *permission.Checker // nil = allow all
Context *gnomactx.Window // nil = no compaction
System string // system prompt
Model string // override model (empty = provider default)
Temperature *float64 // nil = provider default
MaxTurns int // safety limit on tool loops (0 = unlimited)
Store *persist.Store // nil = no result persistence
Hooks *hook.Dispatcher // nil = no hooks
Logger *slog.Logger
}
@@ -228,6 +229,21 @@ func (e *Engine) SetActivatedTools(tools map[string]bool) {
e.activatedTools = tools
}
// classify returns a Task for the given prompt using the configured classifier.
// Falls back to HeuristicClassifier if none is configured or if classification fails.
func (e *Engine) classify(ctx context.Context, prompt string) router.Task {
cls := e.cfg.Classifier
if cls == nil {
cls = router.HeuristicClassifier{}
}
task, err := cls.Classify(ctx, prompt, e.history)
if err != nil {
e.logger.Debug("classifier error, falling back to heuristic", "error", err)
return router.ClassifyTask(prompt)
}
return task
}
// Reset clears conversation history and usage.
func (e *Engine) Reset() {
e.history = nil
+88
View File
@@ -579,6 +579,94 @@ func TestSubmit_CumulativeUsage(t *testing.T) {
}
}
// spyClassifier records calls and delegates to HeuristicClassifier.
type spyClassifier struct {
calls int
result *router.Task // when non-nil, return this instead of heuristic result
}
func (s *spyClassifier) Classify(ctx context.Context, prompt string, history []message.Message) (router.Task, error) {
s.calls++
if s.result != nil {
return *s.result, nil
}
return router.HeuristicClassifier{}.Classify(ctx, prompt, history)
}
func TestSubmit_UsesInjectedClassifier(t *testing.T) {
rtr := router.New(router.Config{})
armID := router.NewArmID("test", "mock-model")
mp := &mockProvider{
name: "test",
streams: []stream.Stream{
newEventStream(message.StopEndTurn, "mock-model",
stream.Event{Type: stream.EventTextDelta, Text: "ok"},
),
},
}
rtr.RegisterArm(&router.Arm{
ID: armID,
Provider: mp,
ModelName: "mock-model",
Capabilities: provider.Capabilities{ToolUse: true},
})
rtr.ForceArm(armID)
spy := &spyClassifier{}
e, err := New(Config{
Provider: mp,
Router: rtr,
Tools: tool.NewRegistry(),
Classifier: spy,
})
if err != nil {
t.Fatalf("New: %v", err)
}
if _, err := e.Submit(context.Background(), "implement a parser", nil); err != nil {
t.Fatalf("Submit: %v", err)
}
if spy.calls == 0 {
t.Error("expected Classify to be called at least once, got 0 calls")
}
}
func TestSubmit_NilClassifierFallsBackToHeuristic(t *testing.T) {
rtr := router.New(router.Config{})
armID := router.NewArmID("test", "mock-model")
mp := &mockProvider{
name: "test",
streams: []stream.Stream{
newEventStream(message.StopEndTurn, "mock-model",
stream.Event{Type: stream.EventTextDelta, Text: "ok"},
),
},
}
rtr.RegisterArm(&router.Arm{
ID: armID,
Provider: mp,
ModelName: "mock-model",
Capabilities: provider.Capabilities{ToolUse: true},
})
rtr.ForceArm(armID)
// No Classifier set — should not panic, should use heuristic
e, err := New(Config{
Provider: mp,
Router: rtr,
Tools: tool.NewRegistry(),
})
if err != nil {
t.Fatalf("New: %v", err)
}
_, err = e.Submit(context.Background(), "debug the server crash", nil)
if err != nil {
t.Fatalf("Submit with nil Classifier: %v", err)
}
}
func TestSubmit_ReportsOutcomeToRouter(t *testing.T) {
rtr := router.New(router.Config{})
armID := router.NewArmID("test", "mock-model")
+4 -4
View File
@@ -97,7 +97,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
break
}
}
task := router.ClassifyTask(prompt)
task := e.classify(ctx, prompt)
if e.cfg.Context != nil {
task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt))
} else {
@@ -151,7 +151,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
break
}
}
task := router.ClassifyTask(prompt)
task := e.classify(ctx, prompt)
if e.cfg.Context != nil {
task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt))
} else {
@@ -376,7 +376,7 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request {
break
}
}
if router.ClassifyTask(prompt).Type == router.TaskOrchestration {
if e.classify(ctx, prompt).Type == router.TaskOrchestration {
req.SystemPrompt = coordinatorPrompt() + "\n\n" + req.SystemPrompt
}
}
@@ -596,7 +596,7 @@ func (e *Engine) handleRequestTooLarge(ctx context.Context, origErr error, req p
break
}
}
task := router.ClassifyTask(prompt)
task := e.classify(ctx, prompt)
if e.cfg.Context != nil {
task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt))
} else {
+22
View File
@@ -0,0 +1,22 @@
package router
import (
"context"
"somegit.dev/Owlibou/gnoma/internal/message"
)
// TaskClassifier classifies a user prompt into a Task for routing decisions.
// The history slice provides prior conversation context; implementations may
// ignore it (HeuristicClassifier) or use it for richer inference (SLMClassifier).
type TaskClassifier interface {
Classify(ctx context.Context, prompt string, history []message.Message) (Task, error)
}
// HeuristicClassifier is the default classifier. It wraps the keyword-based
// ClassifyTask function and ignores conversation history.
type HeuristicClassifier struct{}
func (HeuristicClassifier) Classify(_ context.Context, prompt string, _ []message.Message) (Task, error) {
return ClassifyTask(prompt), nil
}
+68
View File
@@ -0,0 +1,68 @@
package router
import (
"context"
"testing"
"somegit.dev/Owlibou/gnoma/internal/message"
)
// TestHeuristicClassifier_ParityWithClassifyTask verifies that
// HeuristicClassifier.Classify produces identical results to ClassifyTask.
func TestHeuristicClassifier_ParityWithClassifyTask(t *testing.T) {
prompts := []string{
"debug the failing test",
"explain how generics work",
"implement a new HTTP handler",
"refactor the auth middleware",
"security audit the login flow",
"write unit tests for the parser",
"scaffold a new service",
"plan the migration strategy",
"orchestrate the deployment pipeline",
"review the pull request",
}
cls := HeuristicClassifier{}
ctx := context.Background()
var noHistory []message.Message
for _, p := range prompts {
want := ClassifyTask(p)
got, err := cls.Classify(ctx, p, noHistory)
if err != nil {
t.Errorf("Classify(%q) unexpected error: %v", p, err)
continue
}
if got.Type != want.Type {
t.Errorf("Classify(%q).Type = %s, want %s", p, got.Type, want.Type)
}
if got.ComplexityScore != want.ComplexityScore {
t.Errorf("Classify(%q).ComplexityScore = %v, want %v", p, got.ComplexityScore, want.ComplexityScore)
}
if got.RequiresTools != want.RequiresTools {
t.Errorf("Classify(%q).RequiresTools = %v, want %v", p, got.RequiresTools, want.RequiresTools)
}
if got.Priority != want.Priority {
t.Errorf("Classify(%q).Priority = %v, want %v", p, got.Priority, want.Priority)
}
}
}
// TestHeuristicClassifier_IgnoresHistory verifies that history has no effect
// on the heuristic classifier (it operates only on the prompt).
func TestHeuristicClassifier_IgnoresHistory(t *testing.T) {
cls := HeuristicClassifier{}
ctx := context.Background()
prompt := "implement a binary search function"
withoutHistory, _ := cls.Classify(ctx, prompt, nil)
withHistory, _ := cls.Classify(ctx, prompt, []message.Message{
{Role: message.RoleUser, Content: []message.Content{{Type: message.ContentText, Text: "previous message"}}},
})
if withoutHistory.Type != withHistory.Type {
t.Errorf("history should not affect HeuristicClassifier: got %s vs %s",
withoutHistory.Type, withHistory.Type)
}
}