diff --git a/cmd/gnoma/main.go b/cmd/gnoma/main.go index 9fe2a56..a6bcebf 100644 --- a/cmd/gnoma/main.go +++ b/cmd/gnoma/main.go @@ -18,6 +18,7 @@ import ( "somegit.dev/Owlibou/gnoma/internal/engine" "somegit.dev/Owlibou/gnoma/internal/hook" "somegit.dev/Owlibou/gnoma/internal/skill" + "somegit.dev/Owlibou/gnoma/internal/slm" "somegit.dev/Owlibou/gnoma/internal/tool/persist" gnomacfg "somegit.dev/Owlibou/gnoma/internal/config" gnomactx "somegit.dev/Owlibou/gnoma/internal/context" @@ -131,6 +132,11 @@ func main() { *permMode = cfg.Permission.Mode } + // SLM subcommands: `gnoma slm setup` / `gnoma slm status` + if cliArgs := flag.Args(); len(cliArgs) > 0 && cliArgs[0] == "slm" { + os.Exit(runSLMCommand(cliArgs[1:], cfg, logger)) + } + // Resolve API key: CLI flag → config → env vars knownProviders := map[string]bool{ "mistral": true, "anthropic": true, "openai": true, @@ -551,10 +557,48 @@ func main() { logger.Debug("prefix token baseline set", "tokens", prefixTokens) } + // Wire SLM: start llamafile, register arm, inject classifier (opt-in). + var slmMgr *slm.Manager + var engineClassifier router.TaskClassifier + if cfg.SLM.Enabled { + slmDataDir := cfg.SLM.DataDir + if slmDataDir == "" { + slmDataDir = slm.DefaultDataDir() + } + slmMgr = slm.New(slm.Config{DataDir: slmDataDir, ModelURL: cfg.SLM.ModelURL}, logger) + if slmMgr.IsSetUp() { + slmBaseURL, startErr := slmMgr.Start(context.Background()) + if startErr != nil { + logger.Warn("failed to start SLM; falling back to heuristic classifier", "error", startErr) + } else { + slmProv, provErr := openaicompat.NewLlamafile(provider.ProviderConfig{ + BaseURL: slmBaseURL + "/v1", + }) + if provErr != nil { + logger.Warn("failed to create SLM provider", "error", provErr) + } else { + engineClassifier = slm.NewClassifier(slmProv, "default", logger) + rtr.RegisterArm(&router.Arm{ + ID: "slm/llamafile", + Provider: slmProv, + ModelName: "default", + IsLocal: true, + MaxComplexity: 0.3, + Capabilities: provider.Capabilities{ToolUse: false}, + }) + logger.Info("SLM ready", "url", slmBaseURL) + } + } + } else { + logger.Warn("SLM enabled but not set up; run: gnoma slm setup") + } + } + // Create engine eng, err := engine.New(engine.Config{ Provider: prov, Router: rtr, + Classifier: engineClassifier, Tools: reg, Firewall: fw, Permissions: permChecker, @@ -729,13 +773,14 @@ func main() { Permissions: permChecker, Router: rtr, ElfManager: elfMgr, + SLMManager: slmMgr, PermCh: permCh, PermReqCh: permReqCh, ElfProgress: elfProgressCh, SessionStore: sessStore, StartWithResumePicker: openResumePicker, Skills: skillReg, - PluginInfos: buildPluginInfos(discoveredPlugins, enabledSet), + PluginInfos: buildPluginInfos(discoveredPlugins, enabledSet), Version: buildVersion, ModelUpdateCh: modelUpdateCh, }) @@ -972,6 +1017,89 @@ func buildPluginInfos(plugins []plugin.Plugin, enabledSet map[string]bool) []tui return infos } +// runSLMCommand handles `gnoma slm `. +// Returns an exit code. +func runSLMCommand(args []string, cfg *gnomacfg.Config, logger *slog.Logger) int { + if len(args) == 0 { + fmt.Fprintln(os.Stderr, "usage: gnoma slm ") + fmt.Fprintln(os.Stderr, "commands:") + fmt.Fprintln(os.Stderr, " setup download and verify the llamafile model") + fmt.Fprintln(os.Stderr, " status show current setup state") + return 1 + } + + dataDir := cfg.SLM.DataDir + if dataDir == "" { + dataDir = slm.DefaultDataDir() + } + mgr := slm.New(slm.Config{DataDir: dataDir, ModelURL: cfg.SLM.ModelURL}, logger) + + switch args[0] { + case "setup": + if cfg.SLM.ModelURL == "" { + fmt.Fprintln(os.Stderr, "error: [slm] model_url must be set in config before running setup") + fmt.Fprintln(os.Stderr, "") + fmt.Fprintln(os.Stderr, "Example (~/.config/gnoma/config.toml):") + fmt.Fprintln(os.Stderr, " [slm]") + fmt.Fprintln(os.Stderr, ` model_url = "https://huggingface.co/mozilla-ai/TinyLlama-1.1B-Chat-v1.0-llamafile/resolve/main/TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile"`) + return 1 + } + fmt.Printf("downloading llamafile from %s\n", cfg.SLM.ModelURL) + err := mgr.Setup(context.Background(), func(downloaded, total int64) { + if total > 0 { + pct := float64(downloaded) / float64(total) * 100 + fmt.Printf("\r %.1f%% (%s / %s) ", pct, humanBytes(downloaded), humanBytes(total)) + } else { + fmt.Printf("\r %s downloaded ", humanBytes(downloaded)) + } + }) + fmt.Println() + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + return 1 + } + fmt.Printf("SLM ready at: %s\n", dataDir) + fmt.Println("Enable in config:") + fmt.Println(" [slm]") + fmt.Println(" enabled = true") + return 0 + + case "status": + status := mgr.Status() + fmt.Printf("slm status: %s\n", status) + if mf := mgr.Manifest(); mf != nil { + fmt.Printf(" file: %s\n", mf.FilePath) + fmt.Printf(" size: %s\n", humanBytes(mf.Size)) + fmt.Printf(" sha256: %s\n", mf.SHA256[:16]+"...") + fmt.Printf(" setup: %s\n", mf.SetupAt.Format("2006-01-02 15:04 UTC")) + } + if status == slm.StatusNotSetUp { + fmt.Println(" run: gnoma slm setup") + } else if status == slm.StatusMissing { + fmt.Println(" file is missing; run: gnoma slm setup") + } + return 0 + + default: + fmt.Fprintf(os.Stderr, "unknown slm command: %s\n", args[0]) + return 1 + } +} + +// humanBytes formats a byte count as a human-readable string. +func humanBytes(n int64) string { + const unit = 1024 + if n < unit { + return fmt.Sprintf("%d B", n) + } + div, exp := int64(unit), 0 + for n2 := n / unit; n2 >= unit; n2 /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %ciB", float64(n)/float64(div), "KMGTPE"[exp]) +} + // resolveEnabledPlugins determines which plugins are enabled based on config. // If Enabled is empty, all plugins are enabled by default (opt-out via Disabled). // If Enabled is non-empty, only listed plugins are enabled (opt-in). diff --git a/docs/superpowers/plans/2026-05-07-gnoma-roadmap.md b/docs/superpowers/plans/2026-05-07-gnoma-roadmap.md index 8216032..7760fad 100644 --- a/docs/superpowers/plans/2026-05-07-gnoma-roadmap.md +++ b/docs/superpowers/plans/2026-05-07-gnoma-roadmap.md @@ -51,67 +51,44 @@ Bash tool flags `passwd foo` and offers takeover. ## Phase 3: SLM Task Classifier -Add an optional SLM-driven task classifier behind the existing `TaskClassifier` interface. The SLM -calls Ollama HTTP via the existing `openaicompat` provider — zero new dependencies, no CGO, no -daemon management. +Add an optional SLM-driven task classifier and low-complexity executor behind the `TaskClassifier` +interface. Uses llamafile (single-file download, OpenAI-compatible HTTP) instead of Ollama. +Zero new Go dependencies; the model binary is downloaded separately on opt-in. -**Context:** `gemma-integration-analysis.md` describes how gemini-cli implements this using -LiteRT-LM (a Node.js daemon + PID files). Those specifics do not apply here. The Go approach is -simpler: Ollama HTTP + structured JSON output + hard timeout + heuristic fallback. +**Implementation note (diverges from original plan):** Original plan used Ollama HTTP with +`router.slm_model` config key. Pivoted to llamafile after discussion: user downloads a specific +model file once (`gnoma slm setup`), gnoma manages the subprocess lifetime. Requires no external +daemon or package manager. Config section is `[slm]` not `[router]`. -### Interface +### Architecture -```go -// internal/router/classifier.go -type TaskClassifier interface { - Classify(ctx context.Context, input string, history []message.Message) (Task, error) -} - -type HeuristicClassifier struct{} // default — wraps existing ClassifyTask() -type SLMClassifier struct { - provider provider.Provider // openaicompat pointing at Ollama - model string - timeout time.Duration // default 2s -} -``` - -`SLMClassifier.Classify` sends a structured prompt with the Complexity Rubric (adapted from -gemma-integration-analysis.md) and expects a JSON response: - -```json -{"task_type": "Generation", "complexity": 0.4, "requires_tools": true} -``` - -On timeout or parse failure, it falls back to `HeuristicClassifier`. - -### Complexity Rubric (prompt fragment) - -``` -Classify this coding request. Respond with JSON only. -Complexity 0.0–0.3: boilerplate, trivial edits, simple lookups -Complexity 0.4–0.6: moderate — new functions, refactors, unit tests -Complexity 0.7–1.0: architectural, multi-file, security review, planning -``` +- `internal/slm/` — Manager (download, subprocess lifecycle, health check), Classifier +- `internal/router/` — `TaskClassifier` interface, `HeuristicClassifier`, `ParseTaskType` +- `Arm.MaxComplexity` — SLM arm capped at 0.3; excluded from complex tasks by `filterFeasible` ### Config ```toml -[router] -slm_model = "" # empty = disabled (HeuristicClassifier used) - # e.g. "gemma3:1b" — must be available in Ollama +[slm] +enabled = true +model_url = "https://huggingface.co/mozilla-ai/TinyLlama-1.1B-Chat-v1.0-llamafile/resolve/main/TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile" +data_dir = "" # empty = ~/.local/share/gnoma/slm ``` ### Tasks -- [ ] `TaskClassifier` interface in `internal/router/classifier.go` -- [ ] `HeuristicClassifier` wraps existing `ClassifyTask()` (zero behavior change) -- [ ] `SLMClassifier`: Ollama HTTP via openaicompat, JSON parse, 2s timeout + fallback -- [ ] Complexity Rubric prompt (task type + complexity float + requires_tools bool) -- [ ] Config key `router.slm_model`; wire into router construction in `cmd/gnoma/main.go` -- [ ] Tests: `HeuristicClassifier` behavior unchanged; `SLMClassifier` fallback on timeout; - `SLMClassifier` correct parse on valid JSON response +- [x] `TaskClassifier` interface in `internal/router/classifier.go` +- [x] `HeuristicClassifier` wraps existing `ClassifyTask()` (zero behavior change) +- [x] `internal/slm/` — Manager, Manifest, download, subprocess lifecycle (Wave B) +- [x] `slm.Classifier`: openaicompat pointing at llamafile, JSON parse, 2s timeout + fallback +- [x] `ParseTaskType` in `internal/router/task.go` +- [x] `Arm.MaxComplexity` + `filterFeasible` ceiling +- [x] `[slm]` config section in `internal/config/` +- [x] `gnoma slm setup` / `gnoma slm status` CLI subcommands +- [x] SLM arm registered with `MaxComplexity = 0.3` in `cmd/gnoma/main.go` +- [x] TUI `/config` shows SLM status -**Dependencies:** existing `internal/provider/openaicompat` — no new deps. +**Dependencies:** existing `internal/provider/openaicompat` — no new Go deps. **Exit criteria:** `gnoma` with `slm_model = "gemma3:1b"` routes using SLM classification. Without config key, behavior is identical to today. @@ -218,8 +195,8 @@ arches: amd64/arm64). The binary is a single static executable with zero runtime | Item | Reason | |------|--------| | `.gnoma/tmp/` local temp directory | `persist.Store` already uses `/tmp/gnoma-/`; adding `.gnoma/tmp/` adds complexity (cleanup, gitignore, collision avoidance) for no benefit | -| LiteRT-LM / CGO SLM runtime | `CGO_ENABLED=0` (goreleaser constraint). Go approach: Ollama HTTP via existing openaicompat | -| Daemon/PID file management for SLM | Node.js-specific pattern from gemma-integration-analysis.md; not applicable to this Go binary | +| LiteRT-LM / CGO SLM runtime | `CGO_ENABLED=0` (goreleaser constraint). Go approach: llamafile subprocess via existing openaicompat | +| Ollama-based SLM classifier | Pivoted to llamafile: single-file download, no external daemon, user-controlled opt-in | | PTY via `go-pty` library | Requires CGO. Replaced by `tea.ExecProcess` (already in go.mod, no CGO) | --- diff --git a/internal/config/config.go b/internal/config/config.go index a44e695..bd78f61 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -10,11 +10,28 @@ type Config struct { RateLimits RateLimitSection `toml:"rate_limits"` Security SecuritySection `toml:"security"` Session SessionSection `toml:"session"` + SLM SLMSection `toml:"slm"` Hooks []HookConfig `toml:"hooks"` MCPServers []MCPServerConfig `toml:"mcp_servers"` Plugins PluginsSection `toml:"plugins"` } +// SLMSection configures the optional small language model for task classification +// and low-complexity task execution. +// +// Example config: +// +// [slm] +// enabled = true +// model_url = "https://huggingface.co/mozilla-ai/TinyLlama-1.1B-Chat-v1.0-llamafile/resolve/main/TinyLlama-1.1B-Chat-v1.0.Q5_K_M.llamafile" +// +// Run `gnoma slm setup` to download and verify the model before enabling. +type SLMSection struct { + Enabled bool `toml:"enabled"` + ModelURL string `toml:"model_url"` + DataDir string `toml:"data_dir"` // empty = XDG default (~/.local/share/gnoma/slm) +} + // MCPServerConfig defines an MCP server to start and connect to. // // Example: diff --git a/internal/provider/openaicompat/provider.go b/internal/provider/openaicompat/provider.go index 9b7274f..b27a90a 100644 --- a/internal/provider/openaicompat/provider.go +++ b/internal/provider/openaicompat/provider.go @@ -47,3 +47,18 @@ func NewLlamaCpp(cfg provider.ProviderConfig) (provider.Provider, error) { } return oaiprov.New(cfg) } + +// NewLlamafile creates a provider for a llamafile process. +// BaseURL must include /v1, e.g. "http://127.0.0.1:8080/v1". +func NewLlamafile(cfg provider.ProviderConfig) (provider.Provider, error) { + if cfg.APIKey == "" { + cfg.APIKey = "llamafile" // llamafile doesn't require a real key + } + if cfg.Model == "" { + cfg.Model = "default" // llamafile ignores the model field + } + if cfg.MaxRetries == nil { + cfg.MaxRetries = intPtr(0) + } + return oaiprov.New(cfg) +} diff --git a/internal/router/arm.go b/internal/router/arm.go index 8e35ac7..c5b79d6 100644 --- a/internal/router/arm.go +++ b/internal/router/arm.go @@ -22,6 +22,10 @@ type Arm struct { Capabilities provider.Capabilities Pools []*LimitPool + // MaxComplexity is a hard ceiling on task complexity this arm will accept. + // Zero means no ceiling (default for all existing arms). + MaxComplexity float64 + // Cost per 1k tokens (EUR, estimated) CostPer1kInput float64 CostPer1kOutput float64 diff --git a/internal/router/router_test.go b/internal/router/router_test.go index 6191396..0efe844 100644 --- a/internal/router/router_test.go +++ b/internal/router/router_test.go @@ -657,3 +657,52 @@ func TestRouter_AllDisabled_ReturnsError(t *testing.T) { } } +func TestFilterFeasible_MaxComplexity(t *testing.T) { + slmArm := &Arm{ + ID: "slm/tiny", + IsLocal: true, + MaxComplexity: 0.3, + Capabilities: provider.Capabilities{ToolUse: false}, + } + apiArm := &Arm{ + ID: "api/big", + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, + } + + // Low-complexity task: SLM arm passes the ceiling. + lowTask := Task{Type: TaskBoilerplate, ComplexityScore: 0.2} + got := filterFeasible([]*Arm{slmArm, apiArm}, lowTask) + found := false + for _, a := range got { + if a.ID == "slm/tiny" { + found = true + } + } + if !found { + t.Error("slm arm should pass filterFeasible for low-complexity task") + } + + // High-complexity task: SLM arm must be excluded. + highTask := Task{Type: TaskPlanning, ComplexityScore: 0.8, RequiresTools: false} + got = filterFeasible([]*Arm{slmArm, apiArm}, highTask) + for _, a := range got { + if a.ID == "slm/tiny" { + t.Error("slm arm should be excluded for high-complexity task") + } + } +} + +func TestFilterFeasible_MaxComplexity_Zero_MeansNoLimit(t *testing.T) { + // MaxComplexity == 0 means "no ceiling" — existing arms are unaffected. + arm := &Arm{ + ID: "api/arm", + MaxComplexity: 0, // zero = no ceiling + Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 200000}, + } + task := Task{Type: TaskOrchestration, ComplexityScore: 0.99} + got := filterFeasible([]*Arm{arm}, task) + if len(got) == 0 { + t.Error("arm with MaxComplexity=0 should never be excluded by complexity ceiling") + } +} + diff --git a/internal/router/selector.go b/internal/router/selector.go index a6d86ac..abadafc 100644 --- a/internal/router/selector.go +++ b/internal/router/selector.go @@ -179,6 +179,11 @@ func filterFeasible(arms []*Arm, task Task) []*Arm { var belowQuality []*Arm // passed tool+pool but scored below minimum quality for _, arm := range arms { + // Complexity ceiling: zero means no ceiling (preserves behavior for all existing arms). + if arm.MaxComplexity > 0 && task.ComplexityScore > arm.MaxComplexity { + continue + } + // Must support tools if task requires them if task.RequiresTools && !arm.SupportsTools() { continue diff --git a/internal/router/task.go b/internal/router/task.go index 29743a0..09fcf55 100644 --- a/internal/router/task.go +++ b/internal/router/task.go @@ -241,3 +241,32 @@ func estimateComplexity(prompt string) float64 { } return score } + +// ParseTaskType converts a string from an SLM JSON response to a TaskType. +// Matching is case-insensitive. Unknown strings fall back to TaskGeneration. +func ParseTaskType(s string) TaskType { + switch strings.ToLower(strings.ReplaceAll(s, "_", "")) { + case "debug": + return TaskDebug + case "explain": + return TaskExplain + case "generation": + return TaskGeneration + case "refactor": + return TaskRefactor + case "unittest": + return TaskUnitTest + case "boilerplate": + return TaskBoilerplate + case "planning": + return TaskPlanning + case "orchestration": + return TaskOrchestration + case "securityreview": + return TaskSecurityReview + case "review": + return TaskReview + default: + return TaskGeneration + } +} diff --git a/internal/router/task_test.go b/internal/router/task_test.go new file mode 100644 index 0000000..b67c083 --- /dev/null +++ b/internal/router/task_test.go @@ -0,0 +1,44 @@ +package router + +import "testing" + +func TestParseTaskType(t *testing.T) { + cases := []struct { + input string + want TaskType + }{ + {"Debug", TaskDebug}, + {"debug", TaskDebug}, + {"DEBUG", TaskDebug}, + {"Explain", TaskExplain}, + {"explain", TaskExplain}, + {"Generation", TaskGeneration}, + {"generation", TaskGeneration}, + {"Refactor", TaskRefactor}, + {"refactor", TaskRefactor}, + {"UnitTest", TaskUnitTest}, + {"unit_test", TaskUnitTest}, + {"unitTest", TaskUnitTest}, + {"Boilerplate", TaskBoilerplate}, + {"boilerplate", TaskBoilerplate}, + {"Planning", TaskPlanning}, + {"planning", TaskPlanning}, + {"Orchestration", TaskOrchestration}, + {"orchestration", TaskOrchestration}, + {"SecurityReview", TaskSecurityReview}, + {"security_review", TaskSecurityReview}, + {"Review", TaskReview}, + {"review", TaskReview}, + // unknown falls back to TaskGeneration + {"", TaskGeneration}, + {"unknown", TaskGeneration}, + {"gibberish", TaskGeneration}, + } + + for _, tc := range cases { + got := ParseTaskType(tc.input) + if got != tc.want { + t.Errorf("ParseTaskType(%q) = %s, want %s", tc.input, got, tc.want) + } + } +} diff --git a/internal/slm/classifier.go b/internal/slm/classifier.go new file mode 100644 index 0000000..1d3eba8 --- /dev/null +++ b/internal/slm/classifier.go @@ -0,0 +1,148 @@ +package slm + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "strings" + "time" + + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/Owlibou/gnoma/internal/router" + "somegit.dev/Owlibou/gnoma/internal/stream" +) + +const defaultClassifyTimeout = 2 * time.Second + +const classifySystemPrompt = `Classify the following coding request. Respond with JSON only, no other text. +Format: {"task_type": "", "complexity": <0.0-1.0>, "requires_tools": } + +Task types: Debug, Explain, Generation, Refactor, UnitTest, Boilerplate, Planning, Orchestration, SecurityReview, Review + +Complexity guide: +0.0–0.3: boilerplate, trivial edits, simple lookups, short explanations +0.4–0.6: new functions, refactors, unit tests, moderate analysis +0.7–1.0: architectural changes, multi-file edits, security review, planning` + +type classifyResponse struct { + TaskType string `json:"task_type"` + Complexity float64 `json:"complexity"` + RequiresTools bool `json:"requires_tools"` +} + +// Classifier implements router.TaskClassifier using a llamafile-hosted SLM. +// On timeout or parse failure it falls back to router.HeuristicClassifier. +type Classifier struct { + provider provider.Provider + model string + timeout time.Duration + logger *slog.Logger +} + +// NewClassifier creates a Classifier. model is the model name passed to the provider +// (llamafile ignores it but openaicompat requires a non-empty value). +func NewClassifier(p provider.Provider, model string, logger *slog.Logger) *Classifier { + if logger == nil { + logger = slog.Default() + } + return &Classifier{ + provider: p, + model: model, + timeout: defaultClassifyTimeout, + logger: logger, + } +} + +// Classify calls the SLM and overlays the three SLM-authoritative fields +// (Type, ComplexityScore, RequiresTools) onto a heuristic baseline Task. +// This ensures Priority, EstimatedTokens, and RequiredEffort are always set. +func (c *Classifier) Classify(ctx context.Context, prompt string, history []message.Message) (router.Task, error) { + tctx, cancel := context.WithTimeout(ctx, c.timeout) + defer cancel() + + resp, err := c.callSLM(tctx, prompt) + if err != nil { + c.logger.Debug("slm classify fallback", "error", err) + return router.HeuristicClassifier{}.Classify(ctx, prompt, history) + } + + // Start from the heuristic baseline so Priority/EstimatedTokens/RequiredEffort are set. + task := router.ClassifyTask(prompt) + task.Type = router.ParseTaskType(resp.TaskType) + task.ComplexityScore = resp.Complexity + task.RequiresTools = resp.RequiresTools + return task, nil +} + +func (c *Classifier) callSLM(ctx context.Context, prompt string) (*classifyResponse, error) { + req := provider.Request{ + Model: c.model, + SystemPrompt: classifySystemPrompt, + Messages: []message.Message{ + { + Role: message.RoleUser, + Content: []message.Content{{Type: message.ContentText, Text: prompt}}, + }, + }, + } + + strm, err := c.provider.Stream(ctx, req) + if err != nil { + return nil, fmt.Errorf("stream: %w", err) + } + defer strm.Close() + + var sb strings.Builder + for strm.Next() { + ev := strm.Current() + if ev.Type == stream.EventTextDelta { + sb.WriteString(ev.Text) + } + } + if err := strm.Err(); err != nil { + return nil, fmt.Errorf("stream error: %w", err) + } + + text := extractJSON(sb.String()) + var resp classifyResponse + if err := json.Unmarshal([]byte(text), &resp); err != nil { + return nil, fmt.Errorf("parse %q: %w", text, err) + } + return &resp, nil +} + +// extractJSON pulls the first {...} substring from s, stripping markdown fences if present. +func extractJSON(s string) string { + s = strings.TrimSpace(s) + + // Strip ```json ... ``` fences. + if strings.HasPrefix(s, "```") { + end := strings.LastIndex(s, "```") + if end > 3 { + inner := s[3:end] + inner = strings.TrimPrefix(inner, "json") + s = strings.TrimSpace(inner) + } + } + + // Extract first balanced {...} block. + start := strings.IndexByte(s, '{') + if start < 0 { + return s + } + depth := 0 + for i := start; i < len(s); i++ { + switch s[i] { + case '{': + depth++ + case '}': + depth-- + if depth == 0 { + return s[start : i+1] + } + } + } + return s[start:] +} diff --git a/internal/slm/classifier_test.go b/internal/slm/classifier_test.go new file mode 100644 index 0000000..39765f8 --- /dev/null +++ b/internal/slm/classifier_test.go @@ -0,0 +1,174 @@ +package slm + +import ( + "context" + "errors" + "testing" + "time" + + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/Owlibou/gnoma/internal/router" + "somegit.dev/Owlibou/gnoma/internal/stream" +) + +// mockProvider implements provider.Provider for classifier tests. +type mockProvider struct { + text string + delay time.Duration + err error +} + +func (m *mockProvider) Name() string { return "mock" } +func (m *mockProvider) DefaultModel() string { return "default" } +func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) { + return nil, nil +} +func (m *mockProvider) Stream(ctx context.Context, _ provider.Request) (stream.Stream, error) { + if m.delay > 0 { + select { + case <-time.After(m.delay): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + if m.err != nil { + return nil, m.err + } + return &mockStream{events: []stream.Event{ + {Type: stream.EventTextDelta, Text: m.text}, + }}, nil +} + +type mockStream struct { + events []stream.Event + idx int +} + +func (s *mockStream) Next() bool { s.idx++; return s.idx <= len(s.events) } +func (s *mockStream) Current() stream.Event { return s.events[s.idx-1] } +func (s *mockStream) Err() error { return nil } +func (s *mockStream) Close() error { return nil } + +func TestClassifier_HappyPath(t *testing.T) { + p := &mockProvider{text: `{"task_type":"Debug","complexity":0.25,"requires_tools":false}`} + cls := NewClassifier(p, "default", nil) + + task, err := cls.Classify(context.Background(), "fix the failing test", nil) + if err != nil { + t.Fatalf("Classify: %v", err) + } + if task.Type != router.TaskDebug { + t.Errorf("Type = %s, want Debug", task.Type) + } + if task.ComplexityScore != 0.25 { + t.Errorf("ComplexityScore = %v, want 0.25", task.ComplexityScore) + } + if task.RequiresTools != false { + t.Errorf("RequiresTools = true, want false") + } +} + +func TestClassifier_BlendHeuristic(t *testing.T) { + // SLM returns one type; other Task fields should come from heuristic. + p := &mockProvider{text: `{"task_type":"Boilerplate","complexity":0.1,"requires_tools":false}`} + cls := NewClassifier(p, "default", nil) + + task, err := cls.Classify(context.Background(), "scaffold a new HTTP handler", nil) + if err != nil { + t.Fatalf("Classify: %v", err) + } + if task.Type != router.TaskBoilerplate { + t.Errorf("Type = %s, want Boilerplate", task.Type) + } + // Priority must come from the heuristic baseline (PriorityNormal = 1, not zero). + if task.Priority < router.PriorityNormal { + t.Errorf("Priority = %v, want at least PriorityNormal from heuristic baseline", task.Priority) + } +} + +func TestClassifier_FallbackOnBadJSON(t *testing.T) { + p := &mockProvider{text: "I cannot classify that."} + cls := NewClassifier(p, "default", nil) + + // Should not error — falls back to heuristic. + task, err := cls.Classify(context.Background(), "write unit tests for the parser", nil) + if err != nil { + t.Fatalf("Classify should not error on bad JSON: %v", err) + } + // Heuristic would return UnitTest for "write unit tests". + if task.Type != router.TaskUnitTest { + t.Errorf("heuristic fallback: Type = %s, want UnitTest", task.Type) + } +} + +func TestClassifier_FallbackOnProviderError(t *testing.T) { + p := &mockProvider{err: errors.New("connection refused")} + cls := NewClassifier(p, "default", nil) + + task, err := cls.Classify(context.Background(), "explain how generics work", nil) + if err != nil { + t.Fatalf("Classify should not error on provider error: %v", err) + } + // Heuristic fallback: "explain" → TaskExplain + if task.Type != router.TaskExplain { + t.Errorf("heuristic fallback: Type = %s, want Explain", task.Type) + } +} + +func TestClassifier_FallbackOnTimeout(t *testing.T) { + p := &mockProvider{delay: 500 * time.Millisecond} + cls := NewClassifier(p, "default", nil) + cls.timeout = 50 * time.Millisecond // force timeout + + task, err := cls.Classify(context.Background(), "debug the failing test", nil) + if err != nil { + t.Fatalf("Classify should not error on timeout: %v", err) + } + // Falls back to heuristic: "debug" → TaskDebug + if task.Type != router.TaskDebug { + t.Errorf("heuristic fallback: Type = %s, want Debug", task.Type) + } +} + +func TestClassifier_FenceStripping(t *testing.T) { + fenced := "```json\n{\"task_type\":\"Refactor\",\"complexity\":0.5,\"requires_tools\":true}\n```" + p := &mockProvider{text: fenced} + cls := NewClassifier(p, "default", nil) + + task, err := cls.Classify(context.Background(), "refactor the auth middleware", nil) + if err != nil { + t.Fatalf("Classify: %v", err) + } + if task.Type != router.TaskRefactor { + t.Errorf("Type = %s, want Refactor", task.Type) + } +} + +func TestClassifier_UnknownTaskType_FallsBackToHeuristic(t *testing.T) { + p := &mockProvider{text: `{"task_type":"FooBar","complexity":0.3,"requires_tools":false}`} + cls := NewClassifier(p, "default", nil) + + task, err := cls.Classify(context.Background(), "implement a binary search function", nil) + if err != nil { + t.Fatalf("Classify: %v", err) + } + // "implement" → heuristic should give Generation or Boilerplate; SLM gave FooBar → Generation fallback + _ = task // just verify no panic and no error +} + +func TestClassifier_ContextPassedToHistory(t *testing.T) { + p := &mockProvider{text: `{"task_type":"Explain","complexity":0.2,"requires_tools":false}`} + cls := NewClassifier(p, "default", nil) + + history := []message.Message{ + {Role: message.RoleUser, Content: []message.Content{{Type: message.ContentText, Text: "prior"}}}, + } + task, err := cls.Classify(context.Background(), "explain this code", history) + if err != nil { + t.Fatalf("Classify: %v", err) + } + if task.Type != router.TaskExplain { + t.Errorf("Type = %s, want Explain", task.Type) + } +} diff --git a/internal/slm/manager.go b/internal/slm/manager.go index 091de66..ef7af96 100644 --- a/internal/slm/manager.go +++ b/internal/slm/manager.go @@ -17,6 +17,18 @@ import ( const pidFile = "llamafile.pid" +// DefaultDataDir returns the platform default SLM data directory. +// Follows XDG Base Directory Specification: $XDG_DATA_HOME/gnoma/slm, +// falling back to ~/.local/share/gnoma/slm. +func DefaultDataDir() string { + dir := os.Getenv("XDG_DATA_HOME") + if dir == "" { + home, _ := os.UserHomeDir() + dir = filepath.Join(home, ".local", "share") + } + return filepath.Join(dir, "gnoma", "slm") +} + // Status describes the setup state of the SLM. type Status int @@ -180,6 +192,15 @@ func (m *Manager) BaseURL() string { return fmt.Sprintf("http://127.0.0.1:%d", m.port) } +// Manifest returns the on-disk manifest if present, or nil. +func (m *Manager) Manifest() *Manifest { + mf, err := readManifest(m.cfg.DataDir) + if err != nil { + return nil + } + return mf +} + func (m *Manager) pidPath() string { return filepath.Join(m.cfg.DataDir, pidFile) } diff --git a/internal/tui/app.go b/internal/tui/app.go index d20f4fd..46dbdee 100644 --- a/internal/tui/app.go +++ b/internal/tui/app.go @@ -19,6 +19,7 @@ import ( gnomacfg "somegit.dev/Owlibou/gnoma/internal/config" "somegit.dev/Owlibou/gnoma/internal/elf" "somegit.dev/Owlibou/gnoma/internal/skill" + "somegit.dev/Owlibou/gnoma/internal/slm" "somegit.dev/Owlibou/gnoma/internal/engine" "somegit.dev/Owlibou/gnoma/internal/message" "somegit.dev/Owlibou/gnoma/internal/permission" @@ -61,6 +62,7 @@ type Config struct { Permissions *permission.Checker // for mode switching Router *router.Router // for model listing ElfManager *elf.Manager // for CancelAll on escape/quit + SLMManager *slm.Manager // nil = SLM not configured PermCh chan bool // TUI → engine: y/n response PermReqCh <-chan PermReqMsg // engine → TUI: tool requesting approval ElfProgress <-chan elf.Progress // elf → TUI: structured progress updates @@ -877,16 +879,32 @@ func (m Model) handleCommand(cmd string) (tea.Model, tea.Cmd) { status := m.session.Status() var b strings.Builder b.WriteString("Current configuration:\n") - fmt.Fprintf(&b, " provider: %s\n", status.Provider) - fmt.Fprintf(&b, " model: %s\n", status.Model) + fmt.Fprintf(&b, " provider: %s\n", status.Provider) + fmt.Fprintf(&b, " model: %s\n", status.Model) if m.config.Permissions != nil { fmt.Fprintf(&b, " permission: %s\n", m.config.Permissions.Mode()) } - fmt.Fprintf(&b, " incognito: %v\n", m.incognito) - fmt.Fprintf(&b, " cwd: %s\n", m.cwd) + fmt.Fprintf(&b, " incognito: %v\n", m.incognito) + fmt.Fprintf(&b, " cwd: %s\n", m.cwd) if m.gitBranch != "" { fmt.Fprintf(&b, " git branch: %s\n", m.gitBranch) } + if m.config.SLMManager != nil { + slmStat := m.config.SLMManager.Status() + switch slmStat { + case slm.StatusReady: + url := m.config.SLMManager.BaseURL() + if url != "" { + fmt.Fprintf(&b, " slm: ready (running at %s)\n", url) + } else { + b.WriteString(" slm: ready (not started)\n") + } + case slm.StatusMissing: + b.WriteString(" slm: file missing — run: gnoma slm setup\n") + default: + b.WriteString(" slm: not set up — run: gnoma slm setup\n") + } + } b.WriteString("\nConfig files: ~/.config/gnoma/config.toml, .gnoma/config.toml") b.WriteString("\nEdit: /config set ") m.messages = append(m.messages, chatMessage{role: "system", content: b.String()})