gnoma/internal/slm/classifier.go
vikingowl a14fe8b504 feat(slm): pluggable backends + trivial-prompt routing
The SLM had two intended jobs — classify every prompt and execute the
small ones itself — but in practice three independent gates kept it
out of nearly all real work:

  1. llamafile cold-start shut the SLM out of pipe-mode runs, which
     always finished before the 15 s health check did
  2. ClassifyTask defaulted RequiresTools=true, excluding the SLM arm
     (ToolUse=false) from 9/10 task types
  3. armTier hard-coded CLI agents > local > API, so even when the SLM
     arm was feasible a CLI agent won

Each gate is addressed below. The result is an SLM that actually does
its job — small stuff stays local, complex stuff routes up — gated by
arm capability rather than by accidents of the boot order.

Backend layer (the bigger change)

The original implementation hard-coded llamafile. That's fine if you
have nothing else, but most users with a local model setup already run
Ollama or llama.cpp. The new factory at internal/slm/backend.go picks
between:

  - ollama (any local Ollama daemon)
  - llamacpp (any llama.cpp server)
  - llamafile (gnoma-managed, current behaviour)
  - openaicompat (LM Studio, vLLM, remote API)
  - auto (probes in order, picks first reachable)
  - disabled

[slm].backend in config.toml selects which. Documented in
docs/slm-backends.md with copy-paste presets for each. The factory
probes the underlying model's actual capabilities (Ollama /api/show,
llama.cpp /props) and sets the SLM arm's ToolUse accordingly — so the
arm picks up simple file-read style tasks on tool-capable models and
stays knowledge-only on completion-only models.
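Here is a minimal sketch of turning a capability probe into the arm's ToolUse flag. It assumes the model-info response carries a `capabilities` array of strings that includes `"tools"` for tool-capable models, which is how recent Ollama `/api/show` responses are documented. The type and function names are hypothetical. A llama.cpp `/props` response would need its own parser, which is not shown here.

```go
package main

import (
	"encoding/json"
	"fmt"
	"slices"
)

// showResponse models only the field this sketch needs; the assumption is
// that the daemon reports capabilities as a list of strings such as "tools".
type showResponse struct {
	Capabilities []string `json:"capabilities"`
}

// toolUseFromShow decides the SLM arm's ToolUse flag from a model-info body.
// Unparseable bodies, or bodies without the "tools" capability, yield false,
// so the arm stays knowledge-only unless tool support is positively reported.
func toolUseFromShow(body []byte) bool {
	var sr showResponse
	if err := json.Unmarshal(body, &sr); err != nil {
		return false
	}
	return slices.Contains(sr.Capabilities, "tools")
}

func main() {
	fmt.Println(toolUseFromShow([]byte(`{"capabilities":["completion","tools"]}`))) // true
	fmt.Println(toolUseFromShow([]byte(`{"capabilities":["completion"]}`)))          // false
	fmt.Println(toolUseFromShow([]byte(`not json`)))                                 // false
}
```

Failing closed (false) matches the commit's intent: an unknown model is treated as completion-only.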

Trivial-prompt heuristic (Gate 2)

ClassifyTask now flips RequiresTools=false for short, low-complexity
prompts whose task type doesn't imply existing code (Explain,
Generation, Boilerplate). Tool-needing tokens (read, write, run, test,
file, …) keep RequiresTools=true even when the prompt is brief.
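The rule reads as "default to tools, flip off only when every condition holds". Here is a minimal sketch under stated assumptions. The word-count and complexity thresholds are placeholders, since the commit gives no numbers. The token list holds only the five words the commit names; the elided remainder stays unknown. Matching is whole-word.

```go
package main

import (
	"fmt"
	"strings"
)

// Hypothetical thresholds: the commit says "short, low-complexity" without
// giving numbers, so these values are placeholders.
const (
	maxTrivialWords      = 20
	maxTrivialComplexity = 0.3
)

// toolTokens holds the tool-implying words named in the commit (others elided).
var toolTokens = []string{"read", "write", "run", "test", "file"}

// knowledgeOnlyTypes are task types that don't imply existing code.
var knowledgeOnlyTypes = map[string]bool{"Explain": true, "Generation": true, "Boilerplate": true}

// requiresTools returns false only for short, simple, knowledge-only prompts
// that contain none of the tool tokens as a whole word; otherwise true.
func requiresTools(prompt, taskType string, complexity float64) bool {
	words := strings.Fields(strings.ToLower(prompt))
	if len(words) > maxTrivialWords || complexity > maxTrivialComplexity || !knowledgeOnlyTypes[taskType] {
		return true
	}
	for _, w := range words {
		w = strings.Trim(w, ".,;:!?\"'`()")
		for _, tok := range toolTokens {
			if w == tok {
				return true
			}
		}
	}
	return false
}

func main() {
	fmt.Println(requiresTools("What is a goroutine?", "Explain", 0.1))                // false
	fmt.Println(requiresTools("Read the config file and explain it", "Explain", 0.2)) // true
	fmt.Println(requiresTools("Why does this panic?", "Debug", 0.2))                  // true
}
```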

Complexity-aware tier ordering (Gate 3)

armTier takes a Task and returns tier 0 for arms whose MaxComplexity
ceiling fits the task. CLI agents drop to tier 1, local to 2, API to 3.
For trivial tasks the SLM arm wins; for complex tasks the SLM falls
out of the feasible set (MaxComplexity exclusion) and the original
ordering reasserts.

Eager boot with user-facing wait (Gate 1)

Removed the original goroutine-only path. SLM startup now blocks
synchronously inside the factory; for llamafile that means up to
[slm].startup_timeout (default 5 s) of waiting on the first
invocation, with "Starting SLM…" → "SLM ready (backend, model, tools,
boot=N)" / "SLM unavailable: …" messages on stderr. Ollama / llamacpp
backends boot instantly because the daemon is already running.

waitHealthy() now respects the caller's context deadline instead of
its old hardcoded 15 s ceiling.

Classifier reliability

Classifier timeout bumped 2 s → 5 s for thinking-mode models like
Qwen3-distilled Tiny3.5. System prompt includes /no_think directive
for the same family. These help but don't eliminate small-model
JSON-contract failures — see the docs section on picking a model.

Probe + telemetry surfaces

gnoma slm status now prints the configured backend + model + a live
probe result (✓/✗) instead of just the llamafile manifest state.

`gnoma router stats` already (from the previous commit) shows the
classifier-source mix; with this change the mix finally moves from
"always heuristic" to slm / slm_fallback / heuristic shares that
reflect real SLM activity.

Tests

  - 9 new backend-factory tests (httptest-backed Ollama probe, error
    paths, auto-detection, capability flags)
  - Tier-ordering tests cover the new "specialised small arm wins
    trivial task" path
  - Trivial-prompt heuristic tested for both halves (knowledge-only
    flips RequiresTools=false; debug/file/run keeps it true)

Deletes the dead SLMManager field from the TUI Config — it was
declared but never read.
2026-05-19 18:53:32 +02:00

package slm

import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"strings"
	"time"

	"somegit.dev/Owlibou/gnoma/internal/message"
	"somegit.dev/Owlibou/gnoma/internal/provider"
	"somegit.dev/Owlibou/gnoma/internal/router"
	"somegit.dev/Owlibou/gnoma/internal/stream"
)

// defaultClassifyTimeout is 5 s to accommodate thinking-mode models like
// Qwen3 distillations (Tiny3.5) that emit reasoning tokens before output.
// Non-thinking models complete in well under 1 s.
const defaultClassifyTimeout = 5 * time.Second

const classifySystemPrompt = `Classify the following coding request. /no_think
Respond with JSON only, no other text, no reasoning, no thinking tags.
Format: {"task_type": "<type>", "complexity": <0.0-1.0>, "requires_tools": <true|false>}
Task types: Debug, Explain, Generation, Refactor, UnitTest, Boilerplate, Planning, Orchestration, SecurityReview, Review
Complexity guide:
0.0-0.3: boilerplate, trivial edits, simple lookups, short explanations
0.4-0.6: new functions, refactors, unit tests, moderate analysis
0.7-1.0: architectural changes, multi-file edits, security review, planning`

type classifyResponse struct {
	TaskType      string  `json:"task_type"`
	Complexity    float64 `json:"complexity"`
	RequiresTools bool    `json:"requires_tools"`
}

// Classifier implements router.TaskClassifier using an SLM served by the
// configured backend (Ollama, llama.cpp, llamafile, or an OpenAI-compatible
// server). On timeout or parse failure it falls back to router.HeuristicClassifier.
type Classifier struct {
	provider provider.Provider
	model    string
	timeout  time.Duration
	logger   *slog.Logger
}

// NewClassifier creates a Classifier. model is the model name passed to the provider
// (llamafile ignores it but openaicompat requires a non-empty value).
func NewClassifier(p provider.Provider, model string, logger *slog.Logger) *Classifier {
	if logger == nil {
		logger = slog.Default()
	}
	return &Classifier{
		provider: p,
		model:    model,
		timeout:  defaultClassifyTimeout,
		logger:   logger,
	}
}

// Classify calls the SLM and overlays the three SLM-authoritative fields
// (Type, ComplexityScore, RequiresTools) onto a heuristic baseline Task.
// This ensures Priority, EstimatedTokens, and RequiredEffort are always set.
func (c *Classifier) Classify(ctx context.Context, prompt string, history []message.Message) (router.Task, error) {
	tctx, cancel := context.WithTimeout(ctx, c.timeout)
	defer cancel()

	resp, err := c.callSLM(tctx, prompt)
	if err != nil {
		c.logger.Debug("slm classify fallback", "error", err)
		t, ferr := router.HeuristicClassifier{}.Classify(ctx, prompt, history)
		t.ClassifierSource = router.ClassifierSLMFallback
		return t, ferr
	}

	// Start from the heuristic baseline so Priority/EstimatedTokens/RequiredEffort are set.
	task := router.ClassifyTask(prompt)
	task.Type = router.ParseTaskType(resp.TaskType)
	task.ComplexityScore = resp.Complexity
	task.RequiresTools = resp.RequiresTools
	task.ClassifierSource = router.ClassifierSLM
	return task, nil
}

func (c *Classifier) callSLM(ctx context.Context, prompt string) (*classifyResponse, error) {
	req := provider.Request{
		Model:        c.model,
		SystemPrompt: classifySystemPrompt,
		Messages: []message.Message{
			{
				Role:    message.RoleUser,
				Content: []message.Content{{Type: message.ContentText, Text: prompt}},
			},
		},
	}

	strm, err := c.provider.Stream(ctx, req)
	if err != nil {
		return nil, fmt.Errorf("stream: %w", err)
	}
	defer func() { _ = strm.Close() }()

	var sb strings.Builder
	for strm.Next() {
		ev := strm.Current()
		if ev.Type == stream.EventTextDelta {
			sb.WriteString(ev.Text)
		}
	}
	if err := strm.Err(); err != nil {
		return nil, fmt.Errorf("stream error: %w", err)
	}

	text := extractJSON(sb.String())
	var resp classifyResponse
	if err := json.Unmarshal([]byte(text), &resp); err != nil {
		return nil, fmt.Errorf("parse %q: %w", text, err)
	}
	return &resp, nil
}

// extractJSON pulls the first {...} substring from s, stripping markdown fences if present.
func extractJSON(s string) string {
	s = strings.TrimSpace(s)

	// Strip ```json ... ``` fences.
	if strings.HasPrefix(s, "```") {
		end := strings.LastIndex(s, "```")
		if end > 3 {
			inner := s[3:end]
			inner = strings.TrimPrefix(inner, "json")
			s = strings.TrimSpace(inner)
		}
	}

	// Extract first balanced {...} block. The scan counts raw braces and does
	// not skip braces inside JSON string values; that is enough for the flat
	// classifyResponse schema.
	start := strings.IndexByte(s, '{')
	if start < 0 {
		return s
	}
	depth := 0
	for i := start; i < len(s); i++ {
		switch s[i] {
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return s[start : i+1]
			}
		}
	}
	return s[start:]
}