a9213ec382
- slm.Classifier: openaicompat → llamafile, 2s timeout + heuristic fallback, heuristic baseline blended so Priority/RequiredEffort are never zeroed, extractJSON strips markdown fences from small-model responses - router.ParseTaskType: case-insensitive string → TaskType, unknown → TaskGeneration - router.Arm.MaxComplexity: zero = no ceiling (preserves existing arm behavior); filterFeasible excludes arms when task.ComplexityScore > MaxComplexity - config.SLMSection: [slm] enabled / model_url / data_dir - openaicompat.NewLlamafile: no API key, model = "default", no retries - slm.Manager: DefaultDataDir() (XDG), Manifest() accessor - cmd/gnoma: `gnoma slm setup` / `gnoma slm status` subcommands; SLM arm registered with MaxComplexity=0.3 when enabled + set up - tui: /config shows slm status (ready/missing/not set up + base URL if running) - docs: roadmap updated to reflect llamafile pivot from Ollama
149 lines
4.1 KiB
Go
149 lines
4.1 KiB
Go
package slm
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"log/slog"
|
||
"strings"
|
||
"time"
|
||
|
||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||
"somegit.dev/Owlibou/gnoma/internal/router"
|
||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||
)
|
||
|
||
// Classification defaults shared by every Classifier.
const (
	// defaultClassifyTimeout is the per-call deadline that NewClassifier
	// installs; Classify applies it to each SLM request.
	defaultClassifyTimeout = 2 * time.Second

	// classifySystemPrompt is the system prompt sent with every
	// classification request. It asks the model for a single JSON object
	// whose keys match the json tags on classifyResponse.
	classifySystemPrompt = `Classify the following coding request. Respond with JSON only, no other text.
Format: {"task_type": "<type>", "complexity": <0.0-1.0>, "requires_tools": <true|false>}

Task types: Debug, Explain, Generation, Refactor, UnitTest, Boilerplate, Planning, Orchestration, SecurityReview, Review

Complexity guide:
0.0–0.3: boilerplate, trivial edits, simple lookups, short explanations
0.4–0.6: new functions, refactors, unit tests, moderate analysis
0.7–1.0: architectural changes, multi-file edits, security review, planning`
)
|
||
|
||
// classifyResponse is the decoded form of the JSON object that the
// classification prompt asks the model to emit.
type classifyResponse struct {
	// TaskType is the task-type name exactly as the model wrote it.
	TaskType string `json:"task_type"`
	// Complexity is the model's complexity estimate; the prompt asks for a
	// value between 0.0 and 1.0.
	Complexity float64 `json:"complexity"`
	// RequiresTools carries the model's "requires_tools" answer.
	RequiresTools bool `json:"requires_tools"`
}
|
||
|
||
// Classifier implements router.TaskClassifier using a llamafile-hosted SLM.
|
||
// On timeout or parse failure it falls back to router.HeuristicClassifier.
|
||
type Classifier struct {
|
||
provider provider.Provider
|
||
model string
|
||
timeout time.Duration
|
||
logger *slog.Logger
|
||
}
|
||
|
||
// NewClassifier creates a Classifier. model is the model name passed to the provider
|
||
// (llamafile ignores it but openaicompat requires a non-empty value).
|
||
func NewClassifier(p provider.Provider, model string, logger *slog.Logger) *Classifier {
|
||
if logger == nil {
|
||
logger = slog.Default()
|
||
}
|
||
return &Classifier{
|
||
provider: p,
|
||
model: model,
|
||
timeout: defaultClassifyTimeout,
|
||
logger: logger,
|
||
}
|
||
}
|
||
|
||
// Classify calls the SLM and overlays the three SLM-authoritative fields
|
||
// (Type, ComplexityScore, RequiresTools) onto a heuristic baseline Task.
|
||
// This ensures Priority, EstimatedTokens, and RequiredEffort are always set.
|
||
func (c *Classifier) Classify(ctx context.Context, prompt string, history []message.Message) (router.Task, error) {
|
||
tctx, cancel := context.WithTimeout(ctx, c.timeout)
|
||
defer cancel()
|
||
|
||
resp, err := c.callSLM(tctx, prompt)
|
||
if err != nil {
|
||
c.logger.Debug("slm classify fallback", "error", err)
|
||
return router.HeuristicClassifier{}.Classify(ctx, prompt, history)
|
||
}
|
||
|
||
// Start from the heuristic baseline so Priority/EstimatedTokens/RequiredEffort are set.
|
||
task := router.ClassifyTask(prompt)
|
||
task.Type = router.ParseTaskType(resp.TaskType)
|
||
task.ComplexityScore = resp.Complexity
|
||
task.RequiresTools = resp.RequiresTools
|
||
return task, nil
|
||
}
|
||
|
||
func (c *Classifier) callSLM(ctx context.Context, prompt string) (*classifyResponse, error) {
|
||
req := provider.Request{
|
||
Model: c.model,
|
||
SystemPrompt: classifySystemPrompt,
|
||
Messages: []message.Message{
|
||
{
|
||
Role: message.RoleUser,
|
||
Content: []message.Content{{Type: message.ContentText, Text: prompt}},
|
||
},
|
||
},
|
||
}
|
||
|
||
strm, err := c.provider.Stream(ctx, req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("stream: %w", err)
|
||
}
|
||
defer strm.Close()
|
||
|
||
var sb strings.Builder
|
||
for strm.Next() {
|
||
ev := strm.Current()
|
||
if ev.Type == stream.EventTextDelta {
|
||
sb.WriteString(ev.Text)
|
||
}
|
||
}
|
||
if err := strm.Err(); err != nil {
|
||
return nil, fmt.Errorf("stream error: %w", err)
|
||
}
|
||
|
||
text := extractJSON(sb.String())
|
||
var resp classifyResponse
|
||
if err := json.Unmarshal([]byte(text), &resp); err != nil {
|
||
return nil, fmt.Errorf("parse %q: %w", text, err)
|
||
}
|
||
return &resp, nil
|
||
}
|
||
|
||
// extractJSON returns the first balanced {...} object found in s.
//
// s is trimmed first. If it starts with a markdown code fence and has a
// closing fence, the fences and an optional "json" language tag are removed.
// The scan then starts at the first '{' and tracks nesting depth, skipping
// braces that appear inside double-quoted JSON strings (honouring backslash
// escapes) so that a value such as "a}b" does not end the object early.
//
// If s contains no '{', the trimmed (and de-fenced) text is returned as is.
// If the object is never closed, everything from the first '{' onward is
// returned, leaving it to the JSON decoder to report the error.
func extractJSON(s string) string {
	s = strings.TrimSpace(s)

	// Strip ```json ... ``` fences.
	if strings.HasPrefix(s, "```") {
		// end > 3 guarantees the closing fence is not the opening one.
		if end := strings.LastIndex(s, "```"); end > 3 {
			inner := strings.TrimPrefix(s[3:end], "json")
			s = strings.TrimSpace(inner)
		}
	}

	start := strings.IndexByte(s, '{')
	if start < 0 {
		return s
	}

	depth := 0
	inString := false // inside a double-quoted JSON string
	escaped := false  // previous byte was a backslash inside a string
	for i := start; i < len(s); i++ {
		ch := s[i]
		if inString {
			switch {
			case escaped:
				escaped = false
			case ch == '\\':
				escaped = true
			case ch == '"':
				inString = false
			}
			continue
		}
		switch ch {
		case '"':
			inString = true
		case '{':
			depth++
		case '}':
			depth--
			if depth == 0 {
				return s[start : i+1]
			}
		}
	}
	return s[start:]
}
|