diff --git a/cmd/gnoma/main.go b/cmd/gnoma/main.go index 3b6ea9c..935c7af 100644 --- a/cmd/gnoma/main.go +++ b/cmd/gnoma/main.go @@ -1,7 +1,197 @@ package main -import "fmt" +import ( + "context" + "flag" + "fmt" + "io" + "log/slog" + "os" + "os/signal" + "strings" + + "somegit.dev/Owlibou/gnoma/internal/engine" + "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/Owlibou/gnoma/internal/provider/mistral" + "somegit.dev/Owlibou/gnoma/internal/stream" + "somegit.dev/Owlibou/gnoma/internal/tool" + "somegit.dev/Owlibou/gnoma/internal/tool/bash" + "somegit.dev/Owlibou/gnoma/internal/tool/fs" +) func main() { - fmt.Println("gnoma — provider-agnostic agentic coding assistant") + var ( + providerName = flag.String("provider", "mistral", "LLM provider") + model = flag.String("model", "", "model name (empty = provider default)") + system = flag.String("system", defaultSystem, "system prompt") + apiKey = flag.String("api-key", "", "API key (or set MISTRAL_API_KEY env)") + maxTurns = flag.Int("max-turns", 50, "max tool-calling rounds per turn") + verbose = flag.Bool("verbose", false, "enable debug logging") + version = flag.Bool("version", false, "print version and exit") + ) + flag.Parse() + + if *version { + fmt.Println("gnoma v0.1.0-dev") + os.Exit(0) + } + + // Logger + logLevel := slog.LevelWarn + if *verbose { + logLevel = slog.LevelDebug + } + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: logLevel})) + + // Resolve API key + key := *apiKey + if key == "" { + key = resolveAPIKey(*providerName) + } + if key == "" { + fmt.Fprintf(os.Stderr, "error: no API key for provider %q\nSet %s environment variable or use --api-key\n", + *providerName, envKeyFor(*providerName)) + os.Exit(1) + } + + // Create provider + prov, err := createProvider(*providerName, key, *model) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + + // Create tool registry + reg := buildToolRegistry() + + // Harvest shell aliases + aliases, err := bash.HarvestAliases(context.Background()) + if err != nil { + logger.Debug("alias harvest failed (non-fatal)", "error", err) + } else { + logger.Debug("harvested aliases", "count", aliases.Len()) + } + + // Re-register bash tool with aliases + reg.Register(bash.New(bash.WithAliases(aliases))) + + // Create engine + eng, err := engine.New(engine.Config{ + Provider: prov, + Tools: reg, + System: *system, + Model: *model, + MaxTurns: *maxTurns, + Logger: logger, + }) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + + // Read input + input, err := readInput(flag.Args()) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + if input == "" { + fmt.Fprintln(os.Stderr, "error: no input provided") + fmt.Fprintln(os.Stderr, "usage: echo 'prompt' | gnoma") + fmt.Fprintln(os.Stderr, " or: gnoma 'prompt'") + os.Exit(1) + } + + // Context with signal handling + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) + defer cancel() + + // Callback: stream text deltas to stdout + cb := func(evt stream.Event) { + if evt.Type == stream.EventTextDelta && evt.Text != "" { + fmt.Print(evt.Text) + } + } + + // Submit and run + _, err = eng.Submit(ctx, input, cb) + fmt.Println() // final newline + + if err != nil { + if ctx.Err() != nil { + fmt.Fprintln(os.Stderr, "\ninterrupted") + os.Exit(130) + } + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } } + +func readInput(args []string) (string, error) { + // Positional args + if len(args) > 0 { + return strings.Join(args, " "), nil + } + + // Stdin (pipe mode) + stat, _ := os.Stdin.Stat() + if stat.Mode()&os.ModeCharDevice == 0 { + data, err := io.ReadAll(os.Stdin) + if err != nil { + return "", fmt.Errorf("reading stdin: %w", err) + } + return strings.TrimSpace(string(data)), nil + } + + return "", nil +} + +func resolveAPIKey(providerName string) string { + envVar := envKeyFor(providerName) + return os.Getenv(envVar) +} + +func envKeyFor(providerName string) string { + switch providerName { + case "mistral": + return "MISTRAL_API_KEY" + case "anthropic": + return "ANTHROPIC_API_KEY" + case "openai": + return "OPENAI_API_KEY" + case "google": + return "GEMINI_API_KEY" + default: + return strings.ToUpper(providerName) + "_API_KEY" + } +} + +func createProvider(name, apiKey, model string) (provider.Provider, error) { + cfg := provider.ProviderConfig{ + APIKey: apiKey, + Model: model, + } + + switch name { + case "mistral": + return mistral.New(cfg) + default: + return nil, fmt.Errorf("unknown provider %q (M1 supports: mistral)", name) + } +} + +func buildToolRegistry() *tool.Registry { + reg := tool.NewRegistry() + reg.Register(bash.New()) + reg.Register(fs.NewReadTool()) + reg.Register(fs.NewWriteTool()) + reg.Register(fs.NewEditTool()) + reg.Register(fs.NewGlobTool()) + reg.Register(fs.NewGrepTool()) + reg.Register(fs.NewLSTool()) + return reg +} + +const defaultSystem = `You are gnoma, a provider-agnostic agentic coding assistant. +You help users with software engineering tasks by reading files, writing code, and executing commands. +Be concise and direct. Use tools when needed to accomplish the task.` diff --git a/docs/essentials/milestones.md b/docs/essentials/milestones.md index 936425b..b01da0b 100644 --- a/docs/essentials/milestones.md +++ b/docs/essentials/milestones.md @@ -116,9 +116,13 @@ depends_on: [vision] - [ ] Model picker overlay - [ ] In-app config editor (`/config` command) - [ ] Incognito toggle (`/incognito` command) +- [ ] Interactive shell pane: `/shell` command or keybinding opens PTY-connected shell + - For commands needing user input (sudo, ssh, git push with auth, passwd prompts) + - Bash tool detects potentially interactive commands and suggests take-over + - PTY-based execution for flagged commands - [ ] Session management (channel-based) -**Exit criteria:** Launch TUI, chat interactively, 6 permission modes work, config editable in-app, incognito toggleable. +**Exit criteria:** Launch TUI, chat interactively, 6 permission modes work, config editable in-app, incognito toggleable, `/shell` opens interactive terminal for password prompts. ## M6: Context Intelligence @@ -219,7 +223,7 @@ depends_on: [vision] **Exit criteria:** gnoma suggests a persistent task after 3+ repetitions. `/task release v1.2.0` executes a saved workflow. -## M12: Thinking & Structured Output +## M12: Thinking, Structured Output & Notebook **Deliverables:** @@ -227,6 +231,7 @@ depends_on: [vision] - [ ] Thinking block streaming and TUI display - [ ] Structured output with JSON schema validation - [ ] Retry logic for schema validation failures +- [ ] NotebookEdit tool: read/write/edit Jupyter notebook cells (.ipynb) ## M13: Auth diff --git a/internal/engine/callback.go b/internal/engine/callback.go new file mode 100644 index 0000000..4d4a7e6 --- /dev/null +++ b/internal/engine/callback.go @@ -0,0 +1,7 @@ +package engine + +import "somegit.dev/Owlibou/gnoma/internal/stream" + +// Callback receives streaming events for real-time UI updates. +// Called synchronously on the engine goroutine for each event. +type Callback func(stream.Event) diff --git a/internal/engine/engine.go b/internal/engine/engine.go new file mode 100644 index 0000000..ade20ad --- /dev/null +++ b/internal/engine/engine.go @@ -0,0 +1,123 @@ +package engine + +import ( + "context" + "fmt" + "log/slog" + + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/Owlibou/gnoma/internal/tool" +) + +// Config holds engine configuration. +type Config struct { + Provider provider.Provider + Tools *tool.Registry + System string // system prompt + Model string // override model (empty = provider default) + MaxTurns int // safety limit on tool loops (0 = unlimited) + Logger *slog.Logger +} + +func (c Config) validate() error { + if c.Provider == nil { + return fmt.Errorf("engine: provider required") + } + if c.Tools == nil { + return fmt.Errorf("engine: tool registry required") + } + return nil +} + +// Turn is the result of a complete agentic turn (may span multiple API calls). +type Turn struct { + Messages []message.Message // all messages produced (assistant + tool results) + Usage message.Usage // cumulative for all API calls in this turn + Rounds int // number of API round-trips +} + +// Engine orchestrates the conversation. +type Engine struct { + cfg Config + history []message.Message + usage message.Usage + logger *slog.Logger + + // Cached model capabilities, resolved lazily + modelCaps *provider.Capabilities + modelCapsFor string // model ID the cached caps are for +} + +// New creates an engine. +func New(cfg Config) (*Engine, error) { + if err := cfg.validate(); err != nil { + return nil, err + } + logger := cfg.Logger + if logger == nil { + logger = slog.Default() + } + return &Engine{ + cfg: cfg, + logger: logger, + }, nil +} + +// resolveCapabilities returns the capabilities for the active model. +// Caches the result — re-resolves if the model changes. +func (e *Engine) resolveCapabilities(ctx context.Context) *provider.Capabilities { + model := e.cfg.Model + if model == "" { + model = e.cfg.Provider.DefaultModel() + } + + // Return cached if same model + if e.modelCaps != nil && e.modelCapsFor == model { + return e.modelCaps + } + + // Query provider for model list + models, err := e.cfg.Provider.Models(ctx) + if err != nil { + e.logger.Debug("failed to fetch model capabilities", "error", err) + return nil + } + + for _, m := range models { + if m.ID == model { + e.modelCaps = &m.Capabilities + e.modelCapsFor = model + return e.modelCaps + } + } + + e.logger.Debug("model not found in provider model list", "model", model) + return nil +} + +// History returns the full conversation. +func (e *Engine) History() []message.Message { + return e.history +} + +// Usage returns cumulative token usage. +func (e *Engine) Usage() message.Usage { + return e.usage +} + +// SetProvider swaps the active provider (for dynamic switching). +func (e *Engine) SetProvider(p provider.Provider) { + e.cfg.Provider = p +} + +// SetModel changes the model within the current provider. +func (e *Engine) SetModel(model string) { + e.cfg.Model = model +} + +// Reset clears conversation history and usage. +func (e *Engine) Reset() { + e.history = nil + e.usage = message.Usage{} +} diff --git a/internal/engine/engine_test.go b/internal/engine/engine_test.go new file mode 100644 index 0000000..c2b3c1c --- /dev/null +++ b/internal/engine/engine_test.go @@ -0,0 +1,475 @@ +package engine + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "testing" + + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/Owlibou/gnoma/internal/stream" + "somegit.dev/Owlibou/gnoma/internal/tool" +) + +// --- Mock Provider --- + +// mockProvider returns pre-configured streams for each call. +type mockProvider struct { + name string + calls int + streams []stream.Stream // one per call, consumed in order +} + +func (m *mockProvider) Name() string { return m.name } +func (m *mockProvider) DefaultModel() string { return "mock-model" } +func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) { + return []provider.ModelInfo{{ + ID: "mock-model", Name: "mock-model", Provider: m.name, + Capabilities: provider.Capabilities{ToolUse: true}, + }}, nil +} +func (m *mockProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) { + if m.calls >= len(m.streams) { + return nil, fmt.Errorf("mock: no more streams (called %d times)", m.calls+1) + } + s := m.streams[m.calls] + m.calls++ + return s, nil +} + +// eventStream is a mock stream backed by a slice of events. +type eventStream struct { + events []stream.Event + idx int + stopReason message.StopReason + model string +} + +func newEventStream(stopReason message.StopReason, model string, events ...stream.Event) *eventStream { + // Append a final event with stop reason + events = append(events, stream.Event{ + Type: stream.EventTextDelta, + StopReason: stopReason, + Model: model, + }) + return &eventStream{events: events, stopReason: stopReason, model: model} +} + +func (s *eventStream) Next() bool { + if s.idx >= len(s.events) { + return false + } + s.idx++ + return true +} + +func (s *eventStream) Current() stream.Event { + return s.events[s.idx-1] +} + +func (s *eventStream) Err() error { return nil } +func (s *eventStream) Close() error { return nil } + +// --- Mock Tool --- + +type mockTool struct { + name string + readOnly bool + execFn func(ctx context.Context, args json.RawMessage) (tool.Result, error) +} + +func (m *mockTool) Name() string { return m.name } +func (m *mockTool) Description() string { return "mock tool" } +func (m *mockTool) Parameters() json.RawMessage { return json.RawMessage(`{"type":"object"}`) } +func (m *mockTool) IsReadOnly() bool { return m.readOnly } +func (m *mockTool) IsDestructive() bool { return false } +func (m *mockTool) Execute(ctx context.Context, args json.RawMessage) (tool.Result, error) { + if m.execFn != nil { + return m.execFn(ctx, args) + } + return tool.Result{Output: "mock output"}, nil +} + +// --- Tests --- + +func TestNew_ValidConfig(t *testing.T) { + e, err := New(Config{ + Provider: &mockProvider{name: "test"}, + Tools: tool.NewRegistry(), + }) + if err != nil { + t.Fatalf("New: %v", err) + } + if e == nil { + t.Fatal("engine should not be nil") + } +} + +func TestNew_MissingProvider(t *testing.T) { + _, err := New(Config{Tools: tool.NewRegistry()}) + if err == nil { + t.Fatal("expected error for missing provider") + } +} + +func TestNew_MissingTools(t *testing.T) { + _, err := New(Config{Provider: &mockProvider{name: "test"}}) + if err == nil { + t.Fatal("expected error for missing tool registry") + } +} + +func TestSubmit_SimpleTextResponse(t *testing.T) { + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + newEventStream(message.StopEndTurn, "test-model", + stream.Event{Type: stream.EventTextDelta, Text: "Hello "}, + stream.Event{Type: stream.EventTextDelta, Text: "world!"}, + stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 10, OutputTokens: 5}}, + ), + }, + } + + e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()}) + + var events []stream.Event + turn, err := e.Submit(context.Background(), "hi", func(evt stream.Event) { + events = append(events, evt) + }) + if err != nil { + t.Fatalf("Submit: %v", err) + } + + // Check turn + if turn.Rounds != 1 { + t.Errorf("Rounds = %d, want 1", turn.Rounds) + } + if len(turn.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(turn.Messages)) + } + if turn.Messages[0].TextContent() != "Hello world!" { + t.Errorf("TextContent = %q", turn.Messages[0].TextContent()) + } + if turn.Usage.InputTokens != 10 { + t.Errorf("Usage.InputTokens = %d", turn.Usage.InputTokens) + } + + // Check history + history := e.History() + if len(history) != 2 { + t.Fatalf("len(History) = %d, want 2 (user + assistant)", len(history)) + } + if history[0].Role != message.RoleUser { + t.Errorf("History[0].Role = %q", history[0].Role) + } + if history[1].Role != message.RoleAssistant { + t.Errorf("History[1].Role = %q", history[1].Role) + } + + // Check events were forwarded + if len(events) == 0 { + t.Error("callback should have received events") + } +} + +func TestSubmit_ToolCallLoop(t *testing.T) { + reg := tool.NewRegistry() + reg.Register(&mockTool{ + name: "bash", + execFn: func(_ context.Context, args json.RawMessage) (tool.Result, error) { + return tool.Result{Output: "file1.go\nfile2.go"}, nil + }, + }) + + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + // Round 1: model calls a tool + newEventStream(message.StopToolUse, "model-1", + stream.Event{Type: stream.EventTextDelta, Text: "Let me list files."}, + stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"}, + stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{"command":"ls"}`)}, + ), + // Round 2: model responds with final answer + newEventStream(message.StopEndTurn, "model-1", + stream.Event{Type: stream.EventTextDelta, Text: "Found file1.go and file2.go."}, + ), + }, + } + + e, _ := New(Config{Provider: mp, Tools: reg}) + + turn, err := e.Submit(context.Background(), "list files", nil) + if err != nil { + t.Fatalf("Submit: %v", err) + } + + if turn.Rounds != 2 { + t.Errorf("Rounds = %d, want 2", turn.Rounds) + } + + // Messages: assistant (tool call), tool results, assistant (final) + if len(turn.Messages) != 3 { + t.Fatalf("len(Messages) = %d, want 3", len(turn.Messages)) + } + + // First message has tool call + if !turn.Messages[0].HasToolCalls() { + t.Error("Messages[0] should have tool calls") + } + + // Second message is tool results + if turn.Messages[1].Role != message.RoleUser { + t.Errorf("Messages[1].Role = %q, want user (tool results)", turn.Messages[1].Role) + } + + // Third message is final text + if turn.Messages[2].TextContent() != "Found file1.go and file2.go." { + t.Errorf("Messages[2].TextContent = %q", turn.Messages[2].TextContent()) + } + + // History: user + assistant(tool call) + tool results + assistant(final) + if len(e.History()) != 4 { + t.Errorf("len(History) = %d, want 4", len(e.History())) + } + + // Provider called twice + if mp.calls != 2 { + t.Errorf("provider called %d times, want 2", mp.calls) + } +} + +func TestSubmit_UnknownTool(t *testing.T) { + reg := tool.NewRegistry() + // Don't register any tools + + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + // Model calls a tool that doesn't exist + newEventStream(message.StopToolUse, "", + stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "nonexistent"}, + stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)}, + ), + // Model responds after seeing error + newEventStream(message.StopEndTurn, "", + stream.Event{Type: stream.EventTextDelta, Text: "Sorry, that tool doesn't exist."}, + ), + }, + } + + e, _ := New(Config{Provider: mp, Tools: reg}) + + turn, err := e.Submit(context.Background(), "do something", nil) + if err != nil { + t.Fatalf("Submit: %v", err) + } + + // Should still complete — unknown tool returns error result, model sees it + if turn.Rounds != 2 { + t.Errorf("Rounds = %d, want 2", turn.Rounds) + } +} + +func TestSubmit_ToolExecutionError(t *testing.T) { + reg := tool.NewRegistry() + reg.Register(&mockTool{ + name: "failing", + execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) { + return tool.Result{}, errors.New("disk full") + }, + }) + + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + newEventStream(message.StopToolUse, "", + stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "failing"}, + stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)}, + ), + newEventStream(message.StopEndTurn, "", + stream.Event{Type: stream.EventTextDelta, Text: "The tool failed."}, + ), + }, + } + + e, _ := New(Config{Provider: mp, Tools: reg}) + + turn, err := e.Submit(context.Background(), "do it", nil) + if err != nil { + t.Fatalf("Submit: %v", err) + } + + // Tool error is returned as error result, not a fatal error + if turn.Rounds != 2 { + t.Errorf("Rounds = %d, want 2", turn.Rounds) + } +} + +func TestSubmit_MaxTurnsLimit(t *testing.T) { + reg := tool.NewRegistry() + reg.Register(&mockTool{name: "bash"}) + + // Provider always returns tool calls — would loop forever + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + newEventStream(message.StopToolUse, "", + stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"}, + stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{}`)}, + ), + newEventStream(message.StopToolUse, "", + stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_2", ToolCallName: "bash"}, + stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_2", Args: json.RawMessage(`{}`)}, + ), + newEventStream(message.StopToolUse, "", + stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_3", ToolCallName: "bash"}, + stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_3", Args: json.RawMessage(`{}`)}, + ), + }, + } + + e, _ := New(Config{Provider: mp, Tools: reg, MaxTurns: 2}) + + _, err := e.Submit(context.Background(), "loop forever", nil) + if err == nil { + t.Fatal("expected error from max turns limit") + } + if mp.calls != 2 { + t.Errorf("provider called %d times, want 2 (limited)", mp.calls) + } +} + +func TestSubmit_MultipleToolCalls(t *testing.T) { + reg := tool.NewRegistry() + reg.Register(&mockTool{ + name: "bash", + execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) { + return tool.Result{Output: "bash output"}, nil + }, + }) + reg.Register(&mockTool{ + name: "fs.read", + readOnly: true, + execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) { + return tool.Result{Output: "file content"}, nil + }, + }) + + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + // Model calls two tools at once + newEventStream(message.StopToolUse, "", + stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "bash"}, + stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", Args: json.RawMessage(`{"command":"ls"}`)}, + stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_2", ToolCallName: "fs.read"}, + stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_2", Args: json.RawMessage(`{"path":"go.mod"}`)}, + ), + newEventStream(message.StopEndTurn, "", + stream.Event{Type: stream.EventTextDelta, Text: "Done."}, + ), + }, + } + + e, _ := New(Config{Provider: mp, Tools: reg}) + + turn, err := e.Submit(context.Background(), "run both", nil) + if err != nil { + t.Fatalf("Submit: %v", err) + } + + if turn.Rounds != 2 { + t.Errorf("Rounds = %d, want 2", turn.Rounds) + } + + // Tool results message should have 2 results + toolMsg := turn.Messages[1] // assistant, tool_results, assistant + if len(toolMsg.Content) != 2 { + t.Errorf("tool results has %d content blocks, want 2", len(toolMsg.Content)) + } +} + +func TestSubmit_NilCallback(t *testing.T) { + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + newEventStream(message.StopEndTurn, "", + stream.Event{Type: stream.EventTextDelta, Text: "ok"}, + ), + }, + } + + e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()}) + + // nil callback should not panic + turn, err := e.Submit(context.Background(), "test", nil) + if err != nil { + t.Fatalf("Submit: %v", err) + } + if turn.Rounds != 1 { + t.Errorf("Rounds = %d", turn.Rounds) + } +} + +func TestEngine_Reset(t *testing.T) { + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + newEventStream(message.StopEndTurn, "", + stream.Event{Type: stream.EventTextDelta, Text: "first"}, + stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100}}, + ), + }, + } + + e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()}) + e.Submit(context.Background(), "hello", nil) + + if len(e.History()) == 0 { + t.Fatal("history should not be empty before reset") + } + if e.Usage().InputTokens == 0 { + t.Fatal("usage should not be zero before reset") + } + + e.Reset() + + if len(e.History()) != 0 { + t.Errorf("history should be empty after reset, got %d", len(e.History())) + } + if e.Usage().InputTokens != 0 { + t.Errorf("usage should be zero after reset, got %d", e.Usage().InputTokens) + } +} + +func TestSubmit_CumulativeUsage(t *testing.T) { + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + newEventStream(message.StopEndTurn, "", + stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 50}}, + stream.Event{Type: stream.EventTextDelta, Text: "first"}, + ), + newEventStream(message.StopEndTurn, "", + stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 200, OutputTokens: 80}}, + stream.Event{Type: stream.EventTextDelta, Text: "second"}, + ), + }, + } + + e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()}) + + e.Submit(context.Background(), "one", nil) + e.Submit(context.Background(), "two", nil) + + if e.Usage().InputTokens != 300 { + t.Errorf("cumulative InputTokens = %d, want 300", e.Usage().InputTokens) + } + if e.Usage().OutputTokens != 130 { + t.Errorf("cumulative OutputTokens = %d, want 130", e.Usage().OutputTokens) + } +} diff --git a/internal/engine/loop.go b/internal/engine/loop.go new file mode 100644 index 0000000..c02e512 --- /dev/null +++ b/internal/engine/loop.go @@ -0,0 +1,204 @@ +package engine + +import ( + "context" + "encoding/json" + "fmt" + + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/Owlibou/gnoma/internal/stream" +) + +// Submit sends a user message and runs the agentic loop to completion. +// The callback receives real-time streaming events. +func (e *Engine) Submit(ctx context.Context, input string, cb Callback) (*Turn, error) { + userMsg := message.NewUserText(input) + e.history = append(e.history, userMsg) + + return e.runLoop(ctx, cb) +} + +// SubmitMessages is like Submit but accepts pre-built messages. +func (e *Engine) SubmitMessages(ctx context.Context, msgs []message.Message, cb Callback) (*Turn, error) { + e.history = append(e.history, msgs...) + + return e.runLoop(ctx, cb) +} + +func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { + turn := &Turn{} + + for { + turn.Rounds++ + if e.cfg.MaxTurns > 0 && turn.Rounds > e.cfg.MaxTurns { + return turn, fmt.Errorf("safety limit: %d rounds exceeded", e.cfg.MaxTurns) + } + + // Build provider request (gates tools on model capabilities) + req := e.buildRequest(ctx) + + e.logger.Debug("streaming request", + "provider", e.cfg.Provider.Name(), + "model", req.Model, + "messages", len(req.Messages), + "tools", len(req.Tools), + "round", turn.Rounds, + ) + + // Stream from provider + s, err := e.cfg.Provider.Stream(ctx, req) + if err != nil { + return nil, fmt.Errorf("provider stream: %w", err) + } + + // Consume stream, forwarding events to callback + acc := stream.NewAccumulator() + var stopReason message.StopReason + var model string + + for s.Next() { + evt := s.Current() + acc.Apply(evt) + + // Capture stop reason and model from events + if evt.StopReason != "" { + stopReason = evt.StopReason + } + if evt.Model != "" { + model = evt.Model + } + + if cb != nil { + cb(evt) + } + } + if err := s.Err(); err != nil { + s.Close() + return nil, fmt.Errorf("stream error: %w", err) + } + s.Close() + + // Build response + resp := acc.Response(stopReason, model) + turn.Usage.Add(resp.Usage) + turn.Messages = append(turn.Messages, resp.Message) + e.history = append(e.history, resp.Message) + e.usage.Add(resp.Usage) + + e.logger.Debug("turn response", + "stop_reason", resp.StopReason, + "tool_calls", len(resp.Message.ToolCalls()), + "round", turn.Rounds, + ) + + // Decide next action + switch resp.StopReason { + case message.StopEndTurn, message.StopMaxTokens, message.StopSequence: + return turn, nil + + case message.StopToolUse: + results, err := e.executeTools(ctx, resp.Message.ToolCalls(), cb) + if err != nil { + return nil, fmt.Errorf("tool execution: %w", err) + } + toolMsg := message.NewToolResults(results...) + turn.Messages = append(turn.Messages, toolMsg) + e.history = append(e.history, toolMsg) + // Continue loop — re-query provider with tool results + + default: + // Unknown stop reason or empty — treat as end of turn + return turn, nil + } + } +} + +func (e *Engine) buildRequest(ctx context.Context) provider.Request { + req := provider.Request{ + Model: e.cfg.Model, + SystemPrompt: e.cfg.System, + Messages: e.history, + } + + // Only include tools if the model supports them + caps := e.resolveCapabilities(ctx) + if caps == nil || caps.ToolUse { + // nil caps = unknown model, include tools optimistically + for _, t := range e.cfg.Tools.All() { + req.Tools = append(req.Tools, provider.ToolDefinition{ + Name: t.Name(), + Description: t.Description(), + Parameters: t.Parameters(), + }) + } + } else { + e.logger.Debug("tools omitted — model does not support tool use", + "model", req.Model, + ) + } + + return req +} + +func (e *Engine) executeTools(ctx context.Context, calls []message.ToolCall, cb Callback) ([]message.ToolResult, error) { + results := make([]message.ToolResult, 0, len(calls)) + + for _, call := range calls { + t, ok := e.cfg.Tools.Get(call.Name) + if !ok { + e.logger.Warn("unknown tool", "name", call.Name) + results = append(results, message.ToolResult{ + ToolCallID: call.ID, + Content: fmt.Sprintf("unknown tool: %s", call.Name), + IsError: true, + }) + continue + } + + e.logger.Debug("executing tool", "name", call.Name, "id", call.ID) + + result, err := t.Execute(ctx, call.Arguments) + if err != nil { + e.logger.Error("tool execution failed", "name", call.Name, "error", err) + results = append(results, message.ToolResult{ + ToolCallID: call.ID, + Content: err.Error(), + IsError: true, + }) + continue + } + + // Emit tool result as a text delta event so the UI can show it + if cb != nil { + cb(stream.Event{ + Type: stream.EventTextDelta, + Text: fmt.Sprintf("\n[tool:%s] %s\n", call.Name, truncate(result.Output, 500)), + }) + } + + results = append(results, message.ToolResult{ + ToolCallID: call.ID, + Content: result.Output, + }) + } + + return results, nil +} + +func truncate(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} + +// toolDefFromTool converts a tool.Tool to provider.ToolDefinition. +// Unused currently but kept for reference when building tool definitions dynamically. +func toolDefFromJSON(name, description string, params json.RawMessage) provider.ToolDefinition { + return provider.ToolDefinition{ + Name: name, + Description: description, + Parameters: params, + } +} diff --git a/internal/provider/mistral/provider.go b/internal/provider/mistral/provider.go new file mode 100644 index 0000000..ee65cec --- /dev/null +++ b/internal/provider/mistral/provider.go @@ -0,0 +1,124 @@ +package mistral + +import ( + "context" + "fmt" + + "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/Owlibou/gnoma/internal/stream" + mistralgo "somegit.dev/vikingowl/mistral-go-sdk" + "somegit.dev/vikingowl/mistral-go-sdk/model" +) + +const defaultModel = "mistral-large-latest" + +// Provider implements provider.Provider for the Mistral API. +type Provider struct { + client *mistralgo.Client + name string + model string +} + +// New creates a Mistral provider from config. +func New(cfg provider.ProviderConfig) (provider.Provider, error) { + if cfg.APIKey == "" { + return nil, fmt.Errorf("mistral: api key required") + } + + opts := []mistralgo.Option{} + if cfg.BaseURL != "" { + opts = append(opts, mistralgo.WithBaseURL(cfg.BaseURL)) + } + + client := mistralgo.NewClient(cfg.APIKey, opts...) + + m := cfg.Model + if m == "" { + m = defaultModel + } + + return &Provider{ + client: client, + name: "mistral", + model: m, + }, nil +} + +// Stream initiates a streaming chat completion request. +func (p *Provider) Stream(ctx context.Context, req provider.Request) (stream.Stream, error) { + m := req.Model + if m == "" { + m = p.model + } + + cr := translateRequest(req) + cr.Model = m + + raw, err := p.client.ChatCompleteStream(ctx, cr) + if err != nil { + return nil, p.wrapError(err) + } + + return newMistralStream(raw), nil +} + +// Name returns "mistral". +func (p *Provider) Name() string { + return p.name +} + +// DefaultModel returns the configured default model. +func (p *Provider) DefaultModel() string { + return p.model +} + +// Models lists available models from the Mistral API with capability metadata. +func (p *Provider) Models(ctx context.Context) ([]provider.ModelInfo, error) { + resp, err := p.client.ListModels(ctx, &model.ListParams{}) + if err != nil { + return nil, p.wrapError(err) + } + + var models []provider.ModelInfo + for _, m := range resp.Data { + models = append(models, provider.ModelInfo{ + ID: m.ID, + Name: m.ID, + Provider: p.name, + Capabilities: inferCapabilities(m), + }) + } + return models, nil +} + +// inferCapabilities maps Mistral model metadata to gnoma capabilities. +func inferCapabilities(m model.ModelCard) provider.Capabilities { + caps := provider.Capabilities{ + ToolUse: m.Capabilities.FunctionCalling, + Vision: m.Capabilities.Vision, + JSONOutput: m.Capabilities.CompletionChat, // all chat models support JSON output via ResponseFormat + ContextWindow: m.MaxContextLength, + MaxOutput: 8192, // reasonable default + } + return caps +} + +func (p *Provider) wrapError(err error) error { + if apiErr, ok := err.(*mistralgo.APIError); ok { + kind, retryable := provider.ClassifyHTTPStatus(apiErr.StatusCode) + return &provider.ProviderError{ + Kind: kind, + Provider: p.name, + StatusCode: apiErr.StatusCode, + Message: apiErr.Message, + Retryable: retryable, + Err: err, + } + } + return &provider.ProviderError{ + Kind: provider.ErrTransient, + Provider: p.name, + Message: err.Error(), + Err: err, + } +} diff --git a/internal/provider/mistral/stream.go b/internal/provider/mistral/stream.go new file mode 100644 index 0000000..9e09c12 --- /dev/null +++ b/internal/provider/mistral/stream.go @@ -0,0 +1,248 @@ +package mistral + +import ( + "encoding/json" + + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/stream" + mistralgo "somegit.dev/vikingowl/mistral-go-sdk" + "somegit.dev/vikingowl/mistral-go-sdk/chat" +) + +// mistralStream adapts mistral's Stream[CompletionChunk] to gnoma's stream.Stream. +type mistralStream struct { + raw *mistralgo.Stream[chat.CompletionChunk] + cur stream.Event + err error + model string + + // Track active tool calls for delta assembly + activeToolCalls map[int]*toolCallState // keyed by ToolCall.Index + + // Deferred finish reason (when finish arrives on the same chunk as content) + pendingFinish *chat.FinishReason + pendingUsage *message.Usage // usage from a chunk that also had other data + emittedStop bool // true after we've emitted the synthetic stop event + hadToolCalls bool // true if any tool calls were emitted +} + +type toolCallState struct { + id string + name string + args string // accumulated argument fragments +} + +func newMistralStream(raw *mistralgo.Stream[chat.CompletionChunk]) *mistralStream { + return &mistralStream{ + raw: raw, + activeToolCalls: make(map[int]*toolCallState), + } +} + +func (s *mistralStream) Next() bool { + for s.raw.Next() { + chunk := s.raw.Current() + + // Capture model from first chunk + if s.model == "" && chunk.Model != "" { + s.model = chunk.Model + } + + // Store usage if present (may be on same chunk as tool calls or finish) + if chunk.Usage != nil { + s.pendingUsage = translateUsage(chunk.Usage) + } + + if len(chunk.Choices) == 0 { + // Chunk with only usage and no choices — emit usage + if s.pendingUsage != nil { + s.cur = stream.Event{Type: stream.EventUsage, Usage: s.pendingUsage} + s.pendingUsage = nil + return true + } + continue + } + + choice := chunk.Choices[0] + delta := choice.Delta + + // Process text content first (even on chunks with finish reason) + text := delta.Content.String() + if text != "" { + s.cur = stream.Event{ + Type: stream.EventTextDelta, + Text: text, + } + // If this chunk also has a finish reason, store it for next iteration + if choice.FinishReason != nil { + s.pendingFinish = choice.FinishReason + } + return true + } + + // Tool call deltas + if len(delta.ToolCalls) > 0 { + // Store finish reason if present on same chunk as tool calls + if choice.FinishReason != nil { + s.pendingFinish = choice.FinishReason + } + + for _, tc := range delta.ToolCalls { + existing, ok := s.activeToolCalls[tc.Index] + if !ok { + // New tool call + s.activeToolCalls[tc.Index] = &toolCallState{ + id: tc.ID, + name: tc.Function.Name, + args: tc.Function.Arguments, + } + s.hadToolCalls = true + + // If arguments are already complete (Mistral sends full args in one chunk), + // emit ToolCallDone directly instead of Start + if tc.Function.Arguments != "" && s.pendingFinish != nil { + s.cur = stream.Event{ + Type: stream.EventToolCallDone, + ToolCallID: tc.ID, + ToolCallName: tc.Function.Name, + Args: json.RawMessage(tc.Function.Arguments), + } + // Remove from active — it's already done + delete(s.activeToolCalls, tc.Index) + return true + } + + // Otherwise emit Start, accumulate deltas later + s.cur = stream.Event{ + Type: stream.EventToolCallStart, + ToolCallID: tc.ID, + ToolCallName: tc.Function.Name, + } + return true + } + // Existing tool call — accumulate arguments, emit Delta + existing.args += tc.Function.Arguments + if tc.Function.Arguments != "" { + s.cur = stream.Event{ + Type: stream.EventToolCallDelta, + ToolCallID: existing.id, + ArgDelta: tc.Function.Arguments, + } + return true + } + } + continue + } + + // Check finish reason (from this chunk or pending from previous) + fr := choice.FinishReason + if fr == nil { + fr = s.pendingFinish + s.pendingFinish = nil + } + + if fr != nil { + // Flush any pending tool calls as Done events + if *fr == chat.FinishReasonToolCalls { + for idx, tc := range s.activeToolCalls { + s.cur = stream.Event{ + Type: stream.EventToolCallDone, + ToolCallID: tc.id, + Args: json.RawMessage(tc.args), + } + delete(s.activeToolCalls, idx) + s.pendingFinish = fr // re-store to flush remaining on next call + return true + } + } + + // Final event with stop reason + s.cur = stream.Event{ + Type: stream.EventTextDelta, + StopReason: translateFinishReason(fr), + Model: s.model, + } + return true + } + } + + // Drain any pending finish reason that was stored with the last content chunk + if s.pendingFinish != nil { + fr := s.pendingFinish + s.pendingFinish = nil + + // Flush pending tool calls + if *fr == chat.FinishReasonToolCalls { + for idx, tc := range s.activeToolCalls { + s.cur = stream.Event{ + Type: stream.EventToolCallDone, + ToolCallID: tc.id, + Args: json.RawMessage(tc.args), + } + delete(s.activeToolCalls, idx) + s.pendingFinish = fr + return true + } + } + + s.cur = stream.Event{ + Type: stream.EventTextDelta, + StopReason: translateFinishReason(fr), + Model: s.model, + } + return true + } + + // Emit any pending usage before the stop event + if s.pendingUsage != nil { + s.cur = stream.Event{Type: stream.EventUsage, Usage: s.pendingUsage} + s.pendingUsage = nil + return true + } + + // Stream ended — emit inferred stop reason. + if !s.emittedStop { + s.emittedStop = true + + // If we have pending tool calls, they ended with the stream + if len(s.activeToolCalls) > 0 { + for idx, tc := range s.activeToolCalls { + s.cur = stream.Event{ + Type: stream.EventToolCallDone, + ToolCallID: tc.id, + Args: json.RawMessage(tc.args), + } + delete(s.activeToolCalls, idx) + return true + } + } + + // Infer stop reason: if tool calls were emitted, it's ToolUse; otherwise EndTurn + stopReason := message.StopEndTurn + if s.hadToolCalls { + stopReason = message.StopToolUse + } + + s.cur = stream.Event{ + Type: stream.EventTextDelta, + StopReason: stopReason, + Model: s.model, + } + return true + } + + s.err = s.raw.Err() + return false +} + +func (s *mistralStream) Current() stream.Event { + return s.cur +} + +func (s *mistralStream) Err() error { + return s.err +} + +func (s *mistralStream) Close() error { + return s.raw.Close() +} diff --git a/internal/provider/mistral/translate.go b/internal/provider/mistral/translate.go new file mode 100644 index 0000000..d0ffc6e --- /dev/null +++ b/internal/provider/mistral/translate.go @@ -0,0 +1,177 @@ +package mistral + +import ( + "encoding/json" + + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/vikingowl/mistral-go-sdk/chat" +) + +// --- gnoma → Mistral --- + +func translateMessages(msgs []message.Message) []chat.Message { + out := make([]chat.Message, 0, len(msgs)) + for _, m := range msgs { + out = append(out, translateMessage(m)) + } + return out +} + +func translateMessage(m message.Message) chat.Message { + switch m.Role { + case message.RoleSystem: + return &chat.SystemMessage{Content: chat.TextContent(m.TextContent())} + + case message.RoleUser: + // Check if this is a tool results message + if len(m.Content) > 0 && m.Content[0].Type == message.ContentToolResult { + // Tool results must be sent as individual ToolMessages + // Return only the first; caller handles multi-result expansion + tr := m.Content[0].ToolResult + return &chat.ToolMessage{ + ToolCallID: tr.ToolCallID, + Content: chat.TextContent(tr.Content), + } + } + return &chat.UserMessage{Content: chat.TextContent(m.TextContent())} + + case message.RoleAssistant: + am := chat.AssistantMessage{ + Content: chat.TextContent(m.TextContent()), + } + for _, tc := range m.ToolCalls() { + am.ToolCalls = append(am.ToolCalls, chat.ToolCall{ + ID: tc.ID, + Type: "function", + Function: chat.FunctionCall{ + Name: tc.Name, + Arguments: string(tc.Arguments), + }, + }) + } + return &am + + default: + return &chat.UserMessage{Content: chat.TextContent(m.TextContent())} + } +} + +// expandToolResults handles the case where a gnoma Message contains +// multiple ToolResults. Mistral expects one ToolMessage per result. +func expandToolResults(msgs []message.Message) []chat.Message { + out := make([]chat.Message, 0, len(msgs)) + for _, m := range msgs { + if m.Role == message.RoleUser && len(m.Content) > 0 && m.Content[0].Type == message.ContentToolResult { + for _, c := range m.Content { + if c.Type == message.ContentToolResult && c.ToolResult != nil { + out = append(out, &chat.ToolMessage{ + ToolCallID: c.ToolResult.ToolCallID, + Content: chat.TextContent(c.ToolResult.Content), + }) + } + } + continue + } + out = append(out, translateMessage(m)) + } + return out +} + +func translateTools(defs []provider.ToolDefinition) []chat.Tool { + if len(defs) == 0 { + return nil + } + tools := make([]chat.Tool, len(defs)) + for i, d := range defs { + var params map[string]any + if d.Parameters != nil { + _ = json.Unmarshal(d.Parameters, ¶ms) + } + tools[i] = chat.Tool{ + Type: "function", + Function: chat.Function{ + Name: d.Name, + Description: d.Description, + Parameters: params, + }, + } + } + return tools +} + +func translateRequest(req provider.Request) *chat.CompletionRequest { + cr := &chat.CompletionRequest{ + Model: req.Model, + Messages: expandToolResults(req.Messages), + Tools: translateTools(req.Tools), + Stop: req.StopSequences, + } + if req.MaxTokens > 0 { + mt := int(req.MaxTokens) + cr.MaxTokens = &mt + } + if req.Temperature != nil { + cr.Temperature = req.Temperature + } + if req.TopP != nil { + cr.TopP = req.TopP + } + if req.ResponseFormat != nil { + cr.ResponseFormat = translateResponseFormat(req.ResponseFormat) + } + return cr +} + +func translateResponseFormat(rf *provider.ResponseFormat) *chat.ResponseFormat { + if rf == nil { + return nil + } + out := &chat.ResponseFormat{ + Type: chat.ResponseFormatType(rf.Type), + } + if rf.JSONSchema != nil { + var schema map[string]any + if rf.JSONSchema.Schema != nil { + _ = json.Unmarshal(rf.JSONSchema.Schema, &schema) + } + out.JsonSchema = &chat.JsonSchema{ + Name: rf.JSONSchema.Name, + Schema: schema, + Strict: rf.JSONSchema.Strict, + } + if rf.JSONSchema.Description != "" { + desc := rf.JSONSchema.Description + out.JsonSchema.Description = &desc + } + } + return out +} + +// --- Mistral → gnoma --- + +func translateFinishReason(fr *chat.FinishReason) message.StopReason { + if fr == nil { + return "" + } + switch *fr { + case chat.FinishReasonStop: + return message.StopEndTurn + case chat.FinishReasonToolCalls: + return message.StopToolUse + case chat.FinishReasonLength, chat.FinishReasonModelLength: + return message.StopMaxTokens + default: + return message.StopEndTurn + } +} + +func translateUsage(u *chat.UsageInfo) *message.Usage { + if u == nil { + return nil + } + return &message.Usage{ + InputTokens: int64(u.PromptTokens), + OutputTokens: int64(u.CompletionTokens), + } +} diff --git a/internal/provider/mistral/translate_test.go b/internal/provider/mistral/translate_test.go new file mode 100644 index 0000000..714fc1d --- /dev/null +++ b/internal/provider/mistral/translate_test.go @@ -0,0 +1,256 @@ +package mistral + +import ( + "encoding/json" + "testing" + + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/vikingowl/mistral-go-sdk/chat" +) + +func TestTranslateMessage_User(t *testing.T) { + m := message.NewUserText("hello world") + result := translateMessage(m) + + um, ok := result.(*chat.UserMessage) + if !ok { + t.Fatalf("expected *UserMessage, got %T", result) + } + if um.Content.String() != "hello world" { + t.Errorf("Content = %q, want %q", um.Content.String(), "hello world") + } +} + +func TestTranslateMessage_System(t *testing.T) { + m := message.NewSystemText("you are a helper") + result := translateMessage(m) + + sm, ok := result.(*chat.SystemMessage) + if !ok { + t.Fatalf("expected *SystemMessage, got %T", result) + } + if sm.Content.String() != "you are a helper" { + t.Errorf("Content = %q", sm.Content.String()) + } +} + +func TestTranslateMessage_AssistantText(t *testing.T) { + m := message.NewAssistantText("here's the answer") + result := translateMessage(m) + + am, ok := result.(*chat.AssistantMessage) + if !ok { + t.Fatalf("expected *AssistantMessage, got %T", result) + } + if am.Content.String() != "here's the answer" { + t.Errorf("Content = %q", am.Content.String()) + } + if len(am.ToolCalls) != 0 { + t.Errorf("ToolCalls should be empty, got %d", len(am.ToolCalls)) + } +} + +func TestTranslateMessage_AssistantWithToolCalls(t *testing.T) { + m := message.NewAssistantContent( + message.NewTextContent("running command"), + message.NewToolCallContent(message.ToolCall{ + ID: "tc_1", + Name: "bash", + Arguments: json.RawMessage(`{"command":"ls"}`), + }), + ) + result := translateMessage(m) + + am, ok := result.(*chat.AssistantMessage) + if !ok { + t.Fatalf("expected *AssistantMessage, got %T", result) + } + if len(am.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(am.ToolCalls)) + } + if am.ToolCalls[0].ID != "tc_1" { + t.Errorf("ToolCalls[0].ID = %q", am.ToolCalls[0].ID) + } + if am.ToolCalls[0].Function.Name != "bash" { + t.Errorf("ToolCalls[0].Function.Name = %q", am.ToolCalls[0].Function.Name) + } + if am.ToolCalls[0].Function.Arguments != `{"command":"ls"}` { + t.Errorf("ToolCalls[0].Function.Arguments = %q", am.ToolCalls[0].Function.Arguments) + } +} + +func TestExpandToolResults(t *testing.T) { + msgs := []message.Message{ + message.NewUserText("run two commands"), + message.NewAssistantContent( + message.NewToolCallContent(message.ToolCall{ID: "tc_1", Name: "bash"}), + message.NewToolCallContent(message.ToolCall{ID: "tc_2", Name: "bash"}), + ), + message.NewToolResults( + message.ToolResult{ToolCallID: "tc_1", Content: "output1"}, + message.ToolResult{ToolCallID: "tc_2", Content: "output2"}, + ), + } + + expanded := expandToolResults(msgs) + + // UserMessage, AssistantMessage, ToolMessage, ToolMessage + if len(expanded) != 4 { + t.Fatalf("len(expanded) = %d, want 4", len(expanded)) + } + + // First: UserMessage + if _, ok := expanded[0].(*chat.UserMessage); !ok { + t.Errorf("expanded[0] = %T, want *UserMessage", expanded[0]) + } + + // Second: AssistantMessage + if _, ok := expanded[1].(*chat.AssistantMessage); !ok { + t.Errorf("expanded[1] = %T, want *AssistantMessage", expanded[1]) + } + + // Third and fourth: ToolMessages + tm1, ok := expanded[2].(*chat.ToolMessage) + if !ok { + t.Fatalf("expanded[2] = %T, want *ToolMessage", expanded[2]) + } + if tm1.ToolCallID != "tc_1" { + t.Errorf("expanded[2].ToolCallID = %q, want tc_1", tm1.ToolCallID) + } + + tm2, ok := expanded[3].(*chat.ToolMessage) + if !ok { + t.Fatalf("expanded[3] = %T, want *ToolMessage", expanded[3]) + } + if tm2.ToolCallID != "tc_2" { + t.Errorf("expanded[3].ToolCallID = %q, want tc_2", tm2.ToolCallID) + } +} + +func TestTranslateTools(t *testing.T) { + defs := []provider.ToolDefinition{ + { + Name: "bash", + Description: "Run a bash command", + Parameters: json.RawMessage(`{"type":"object","properties":{"command":{"type":"string"}},"required":["command"]}`), + }, + { + Name: "fs.read", + Description: "Read a file", + Parameters: json.RawMessage(`{"type":"object","properties":{"path":{"type":"string"}}}`), + }, + } + + tools := translateTools(defs) + if len(tools) != 2 { + t.Fatalf("len(tools) = %d, want 2", len(tools)) + } + + if tools[0].Type != "function" { + t.Errorf("tools[0].Type = %q, want function", tools[0].Type) + } + if tools[0].Function.Name != "bash" { + t.Errorf("tools[0].Function.Name = %q", tools[0].Function.Name) + } + if tools[0].Function.Description != "Run a bash command" { + t.Errorf("tools[0].Function.Description = %q", tools[0].Function.Description) + } + if tools[0].Function.Parameters == nil { + t.Error("tools[0].Function.Parameters should not be nil") + } + // Verify the parameters were correctly unmarshaled + if _, ok := tools[0].Function.Parameters["type"]; !ok { + t.Error("tools[0].Function.Parameters missing 'type' key") + } +} + +func TestTranslateTools_Empty(t *testing.T) { + tools := translateTools(nil) + if tools != nil { + t.Errorf("translateTools(nil) should return nil, got %v", tools) + } +} + +func TestTranslateFinishReason(t *testing.T) { + tests := []struct { + name string + reason *chat.FinishReason + want message.StopReason + }{ + {"nil", nil, ""}, + {"stop", ptr(chat.FinishReasonStop), message.StopEndTurn}, + {"tool_calls", ptr(chat.FinishReasonToolCalls), message.StopToolUse}, + {"length", ptr(chat.FinishReasonLength), message.StopMaxTokens}, + {"model_length", ptr(chat.FinishReasonModelLength), message.StopMaxTokens}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := translateFinishReason(tt.reason) + if got != tt.want { + t.Errorf("translateFinishReason() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestTranslateUsage(t *testing.T) { + u := &chat.UsageInfo{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + } + result := translateUsage(u) + if result.InputTokens != 100 { + t.Errorf("InputTokens = %d, want 100", result.InputTokens) + } + if result.OutputTokens != 50 { + t.Errorf("OutputTokens = %d, want 50", result.OutputTokens) + } +} + +func TestTranslateUsage_Nil(t *testing.T) { + result := translateUsage(nil) + if result != nil { + t.Error("translateUsage(nil) should return nil") + } +} + +func TestTranslateRequest(t *testing.T) { + temp := 0.7 + req := provider.Request{ + Model: "mistral-large-latest", + SystemPrompt: "you are helpful", + Messages: []message.Message{ + message.NewSystemText("you are helpful"), + message.NewUserText("hello"), + }, + Tools: []provider.ToolDefinition{ + {Name: "bash", Description: "Run command", Parameters: json.RawMessage(`{"type":"object"}`)}, + }, + MaxTokens: 4096, + Temperature: &temp, + } + + cr := translateRequest(req) + + if cr.Model != "mistral-large-latest" { + t.Errorf("Model = %q", cr.Model) + } + if len(cr.Messages) != 2 { + t.Errorf("len(Messages) = %d, want 2", len(cr.Messages)) + } + if len(cr.Tools) != 1 { + t.Errorf("len(Tools) = %d, want 1", len(cr.Tools)) + } + if cr.MaxTokens == nil || *cr.MaxTokens != 4096 { + t.Errorf("MaxTokens = %v", cr.MaxTokens) + } + if cr.Temperature == nil || *cr.Temperature != 0.7 { + t.Errorf("Temperature = %v", cr.Temperature) + } +} + +func ptr[T any](v T) *T { + return &v +} diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 4dec0fd..feba16b 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -10,16 +10,17 @@ import ( // Request encapsulates everything needed for a single LLM API call. type Request struct { - Model string - SystemPrompt string - Messages []message.Message - Tools []ToolDefinition - MaxTokens int64 - Temperature *float64 - TopP *float64 - TopK *int64 - StopSequences []string - Thinking *ThinkingConfig + Model string + SystemPrompt string + Messages []message.Message + Tools []ToolDefinition + MaxTokens int64 + Temperature *float64 + TopP *float64 + TopK *int64 + StopSequences []string + Thinking *ThinkingConfig + ResponseFormat *ResponseFormat } // ToolDefinition is the provider-agnostic tool schema. @@ -34,6 +35,50 @@ type ThinkingConfig struct { BudgetTokens int64 } +// ResponseFormat controls the output format. +type ResponseFormat struct { + Type ResponseFormatType + JSONSchema *JSONSchema // only used when Type == ResponseJSON +} + +type ResponseFormatType string + +const ( + ResponseText ResponseFormatType = "text" + ResponseJSON ResponseFormatType = "json_object" +) + +// JSONSchema defines a schema for structured JSON output. +type JSONSchema struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + Schema json.RawMessage `json:"schema"` + Strict bool `json:"strict,omitempty"` +} + +// Capabilities describes what a model can do. +type Capabilities struct { + ToolUse bool `json:"tool_use"` + JSONOutput bool `json:"json_output"` + Thinking bool `json:"thinking"` + Vision bool `json:"vision"` + ContextWindow int `json:"context_window"` + MaxOutput int `json:"max_output"` +} + +// ModelInfo describes a model available from a provider. +type ModelInfo struct { + ID string `json:"id"` + Name string `json:"name"` + Provider string `json:"provider"` + Capabilities Capabilities `json:"capabilities"` +} + +// SupportsTools returns true if the model supports tool/function calling. +func (m ModelInfo) SupportsTools() bool { + return m.Capabilities.ToolUse +} + // Provider is the core abstraction over all LLM backends. type Provider interface { // Stream initiates a streaming request and returns an event stream. @@ -41,4 +86,10 @@ type Provider interface { // Name returns the provider identifier (e.g., "mistral", "anthropic"). Name() string + + // Models returns available models with their capabilities. + Models(ctx context.Context) ([]ModelInfo, error) + + // DefaultModel returns the default model ID for this provider. + DefaultModel() string } diff --git a/internal/provider/registry_test.go b/internal/provider/registry_test.go index 525194d..7bd163f 100644 --- a/internal/provider/registry_test.go +++ b/internal/provider/registry_test.go @@ -19,8 +19,10 @@ func (m *mockProvider) Stream(_ context.Context, _ Request) (stream.Stream, erro return nil, nil } -func (m *mockProvider) Name() string { - return m.name +func (m *mockProvider) Name() string { return m.name } +func (m *mockProvider) DefaultModel() string { return "mock-model" } +func (m *mockProvider) Models(_ context.Context) ([]ModelInfo, error) { + return []ModelInfo{{ID: "mock-model", Name: "mock-model", Provider: m.name}}, nil } func TestRegistry_RegisterAndCreate(t *testing.T) { diff --git a/internal/stream/accumulator.go b/internal/stream/accumulator.go index 155926c..404613e 100644 --- a/internal/stream/accumulator.go +++ b/internal/stream/accumulator.go @@ -68,14 +68,19 @@ func (a *Accumulator) Apply(e Event) { } case EventToolCallDone: - if tc, ok := a.toolCalls[e.ToolCallID]; ok { - if e.Args != nil { - // Done event carries authoritative complete args - tc.args = e.Args - } else { - // Fall back to accumulated deltas - tc.args = []byte(tc.argsBuf.String()) - } + tc, ok := a.toolCalls[e.ToolCallID] + if !ok { + // Done without prior Start (e.g., Mistral sends complete tool calls in one chunk) + tc = &toolCallAccum{id: e.ToolCallID, name: e.ToolCallName} + a.toolCalls[e.ToolCallID] = tc + a.toolCallOrder = append(a.toolCallOrder, e.ToolCallID) + } + if e.Args != nil { + // Done event carries authoritative complete args + tc.args = e.Args + } else { + // Fall back to accumulated deltas + tc.args = []byte(tc.argsBuf.String()) } case EventUsage: diff --git a/internal/tool/bash/aliases.go b/internal/tool/bash/aliases.go new file mode 100644 index 0000000..3813a77 --- /dev/null +++ b/internal/tool/bash/aliases.go @@ -0,0 +1,231 @@ +package bash + +import ( + "context" + "fmt" + "os" + "os/exec" + "strings" + "sync" + "time" +) + +const aliasHarvestTimeout = 5 * time.Second + +// AliasMap holds harvested shell aliases. +type AliasMap struct { + mu sync.RWMutex + aliases map[string]string // alias name → expansion +} + +func NewAliasMap() *AliasMap { + return &AliasMap{aliases: make(map[string]string)} +} + +// Get returns the expansion for an alias, or empty string if not found. +func (m *AliasMap) Get(name string) (string, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + exp, ok := m.aliases[name] + return exp, ok +} + +// Len returns the number of harvested aliases. +func (m *AliasMap) Len() int { + m.mu.RLock() + defer m.mu.RUnlock() + return len(m.aliases) +} + +// All returns a copy of all aliases. +func (m *AliasMap) All() map[string]string { + m.mu.RLock() + defer m.mu.RUnlock() + cp := make(map[string]string, len(m.aliases)) + for k, v := range m.aliases { + cp[k] = v + } + return cp +} + +// ExpandCommand expands the first word of a command if it's a known alias. +// Only the first word is expanded (matching bash alias behavior). +// Returns the original command unchanged if no alias matches. +func (m *AliasMap) ExpandCommand(cmd string) string { + trimmed := strings.TrimSpace(cmd) + if trimmed == "" { + return cmd + } + + // Extract first word + firstWord := trimmed + rest := "" + if idx := strings.IndexAny(trimmed, " \t"); idx != -1 { + firstWord = trimmed[:idx] + rest = trimmed[idx:] + } + + m.mu.RLock() + expansion, ok := m.aliases[firstWord] + m.mu.RUnlock() + + if !ok { + return cmd + } + return expansion + rest +} + +// HarvestAliases spawns the user's shell once to collect alias definitions. +// Supports bash, zsh, and fish. Falls back gracefully for unknown shells. +// Safe: only reads alias text definitions, never sources them in execution context. +func HarvestAliases(ctx context.Context) (*AliasMap, error) { + shell := os.Getenv("SHELL") + if shell == "" { + shell = "/bin/bash" + } + + ctx, cancel := context.WithTimeout(ctx, aliasHarvestTimeout) + defer cancel() + + // Build the alias dump command based on shell type + shellBase := shellBaseName(shell) + aliasCmd := aliasCommandFor(shellBase) + + // -i: interactive (loads rc files), -c: run command then exit + cmd := exec.CommandContext(ctx, shell, "-ic", aliasCmd) + // Prevent the interactive shell from reading actual stdin + cmd.Stdin = nil + // Suppress stderr (shell startup warnings like zsh's "can't change option: zle") + cmd.Stderr = nil + + // Use Output() but don't fail on non-zero exit — zsh often exits with + // errors from zle/prompt setup while still producing valid alias output + output, err := cmd.Output() + if len(output) == 0 && err != nil { + return NewAliasMap(), fmt.Errorf("alias harvest (%s): %w", shellBase, err) + } + // If we got output, parse it regardless of exit code + + if shellBase == "fish" { + return ParseFishAliases(string(output)) + } + return ParseAliases(string(output)) +} + +// shellBaseName extracts the shell name from a path (e.g., "/bin/zsh" → "zsh"). +func shellBaseName(shell string) string { + parts := strings.Split(shell, "/") + return parts[len(parts)-1] +} + +// aliasCommandFor returns the alias dump command for a given shell. +func aliasCommandFor(shell string) string { + switch shell { + case "fish": + // fish uses `alias` without -p, outputs: alias name 'expansion' + return "alias 2>/dev/null; true" + case "zsh": + // zsh: `alias -p` produces nothing; `alias` outputs name=value (no quotes) + return "alias 2>/dev/null; true" + case "bash", "sh", "dash", "ash": + // POSIX shells use `alias -p` + return "alias -p 2>/dev/null; true" + default: + // Best effort for unknown shells + return "alias 2>/dev/null; true" + } +} + +// ParseFishAliases parses fish shell alias output. +// Fish format: alias name 'expansion' or alias name "expansion" +func ParseFishAliases(output string) (*AliasMap, error) { + m := NewAliasMap() + + for _, line := range strings.Split(output, "\n") { + line = strings.TrimSpace(line) + if line == "" || !strings.HasPrefix(line, "alias ") { + continue + } + // Remove "alias " prefix + rest := strings.TrimPrefix(line, "alias ") + + // Split: name 'expansion' or name "expansion" or name expansion + spaceIdx := strings.IndexByte(rest, ' ') + if spaceIdx == -1 { + continue + } + + name := rest[:spaceIdx] + expansion := strings.TrimSpace(rest[spaceIdx+1:]) + expansion = stripQuotes(expansion) + + if name == "" || expansion == "" { + continue + } + + if v := ValidateCommand(expansion); v != nil { + continue + } + + m.mu.Lock() + m.aliases[name] = expansion + m.mu.Unlock() + } + + return m, nil +} + +// ParseAliases parses the output of `alias -p` into an AliasMap. +// Each line is: alias name='expansion' (bash) or name=expansion (zsh) +func ParseAliases(output string) (*AliasMap, error) { + m := NewAliasMap() + + for _, line := range strings.Split(output, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + + // Strip "alias " prefix if present (bash format) + line = strings.TrimPrefix(line, "alias ") + + // Split on first '=' + eqIdx := strings.Index(line, "=") + if eqIdx == -1 { + continue + } + + name := line[:eqIdx] + expansion := line[eqIdx+1:] + + // Strip surrounding quotes from expansion + expansion = stripQuotes(expansion) + + if name == "" || expansion == "" { + continue + } + + // Security: validate the expansion doesn't contain dangerous patterns + if v := ValidateCommand(expansion); v != nil { + // Skip aliases with dangerous expansions + continue + } + + m.mu.Lock() + m.aliases[name] = expansion + m.mu.Unlock() + } + + return m, nil +} + +// stripQuotes removes matching surrounding single or double quotes. +func stripQuotes(s string) string { + if len(s) < 2 { + return s + } + if (s[0] == '\'' && s[len(s)-1] == '\'') || (s[0] == '"' && s[len(s)-1] == '"') { + return s[1 : len(s)-1] + } + return s +} diff --git a/internal/tool/bash/aliases_test.go b/internal/tool/bash/aliases_test.go new file mode 100644 index 0000000..ca8e023 --- /dev/null +++ b/internal/tool/bash/aliases_test.go @@ -0,0 +1,288 @@ +package bash + +import ( + "context" + "testing" +) + +func TestParseAliases_BashFormat(t *testing.T) { + output := `alias gs='git status' +alias ll='ls -la --color=auto' +alias gco='git checkout' +alias ..='cd ..' +` + m, err := ParseAliases(output) + if err != nil { + t.Fatalf("ParseAliases: %v", err) + } + + if m.Len() != 4 { + t.Errorf("Len() = %d, want 4", m.Len()) + } + + tests := []struct { + name, want string + }{ + {"gs", "git status"}, + {"ll", "ls -la --color=auto"}, + {"gco", "git checkout"}, + {"..", "cd .."}, + } + for _, tt := range tests { + got, ok := m.Get(tt.name) + if !ok { + t.Errorf("alias %q not found", tt.name) + continue + } + if got != tt.want { + t.Errorf("alias %q = %q, want %q", tt.name, got, tt.want) + } + } +} + +func TestParseAliases_ZshFormat(t *testing.T) { + // zsh alias -p may omit 'alias ' prefix + output := `gs='git status' +ll='ls -la' +` + m, err := ParseAliases(output) + if err != nil { + t.Fatalf("ParseAliases: %v", err) + } + + got, ok := m.Get("gs") + if !ok || got != "git status" { + t.Errorf("gs = %q, %v", got, ok) + } +} + +func TestParseAliases_DoubleQuotes(t *testing.T) { + output := `alias gs="git status" +` + m, _ := ParseAliases(output) + + got, ok := m.Get("gs") + if !ok || got != "git status" { + t.Errorf("gs = %q, %v", got, ok) + } +} + +func TestParseAliases_SkipsDangerousExpansions(t *testing.T) { + output := `alias safe='ls -la' +alias danger='echo $(whoami)' +alias backtick='echo ` + "`" + `date` + "`" + `' +alias ifshack='IFS=: read a b' +` + m, _ := ParseAliases(output) + + if _, ok := m.Get("safe"); !ok { + t.Error("safe alias should be kept") + } + if _, ok := m.Get("danger"); ok { + t.Error("danger alias ($()) should be filtered") + } + if _, ok := m.Get("backtick"); ok { + t.Error("backtick alias should be filtered") + } + if _, ok := m.Get("ifshack"); ok { + t.Error("IFS alias should be filtered") + } +} + +func TestParseAliases_EmptyAndMalformed(t *testing.T) { + output := ` +alias gs='git status' + +not a valid line +alias =empty_name +alias noequals +` + m, _ := ParseAliases(output) + + if m.Len() != 1 { + t.Errorf("Len() = %d, want 1 (only gs)", m.Len()) + } +} + +func TestAliasMap_ExpandCommand(t *testing.T) { + m := NewAliasMap() + m.mu.Lock() + m.aliases["ll"] = "ls -la --color=auto" + m.aliases["gs"] = "git status" + m.aliases[".."] = "cd .." + m.mu.Unlock() + + tests := []struct { + input string + want string + }{ + // Alias with args + {"ll /tmp", "ls -la --color=auto /tmp"}, + // Alias without args + {"gs", "git status"}, + // Alias with trailing whitespace (trimmed) + {"gs ", "git status"}, + // No alias match — return unchanged + {"echo hello", "echo hello"}, + // Dotdot alias + {"..", "cd .."}, + // Empty command + {"", ""}, + // Only whitespace + {" ", " "}, + } + for _, tt := range tests { + got := m.ExpandCommand(tt.input) + if got != tt.want { + t.Errorf("ExpandCommand(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestAliasMap_ExpandCommand_NoAliases(t *testing.T) { + m := NewAliasMap() + got := m.ExpandCommand("echo hello") + if got != "echo hello" { + t.Errorf("ExpandCommand = %q, want unchanged", got) + } +} + +func TestAliasMap_All(t *testing.T) { + m := NewAliasMap() + m.mu.Lock() + m.aliases["a"] = "b" + m.aliases["c"] = "d" + m.mu.Unlock() + + all := m.All() + if len(all) != 2 { + t.Errorf("len(All()) = %d, want 2", len(all)) + } + // Verify it's a copy + all["x"] = "y" + if m.Len() != 2 { + t.Error("All() should return a copy, not a reference") + } +} + +func TestStripQuotes(t *testing.T) { + tests := []struct { + input, want string + }{ + {"'hello'", "hello"}, + {`"hello"`, "hello"}, + {"hello", "hello"}, + {"'h'", "h"}, + {"''", ""}, + {`""`, ""}, + {"'mismatched\"", "'mismatched\""}, + {"x", "x"}, + {"", ""}, + } + for _, tt := range tests { + got := stripQuotes(tt.input) + if got != tt.want { + t.Errorf("stripQuotes(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestParseFishAliases(t *testing.T) { + output := `alias gs 'git status' +alias ll 'ls -la' +alias gco "git checkout" +` + m, err := ParseFishAliases(output) + if err != nil { + t.Fatalf("ParseFishAliases: %v", err) + } + + if m.Len() != 3 { + t.Errorf("Len() = %d, want 3", m.Len()) + } + + got, ok := m.Get("gs") + if !ok || got != "git status" { + t.Errorf("gs = %q, %v", got, ok) + } + got, ok = m.Get("gco") + if !ok || got != "git checkout" { + t.Errorf("gco = %q, %v", got, ok) + } +} + +func TestShellBaseName(t *testing.T) { + tests := []struct { + input, want string + }{ + {"/bin/bash", "bash"}, + {"/usr/bin/zsh", "zsh"}, + {"/usr/local/bin/fish", "fish"}, + {"bash", "bash"}, + {"/bin/sh", "sh"}, + } + for _, tt := range tests { + got := shellBaseName(tt.input) + if got != tt.want { + t.Errorf("shellBaseName(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestAliasCommandFor(t *testing.T) { + tests := []struct { + shell string + want string + }{ + {"bash", "alias -p 2>/dev/null; true"}, + {"zsh", "alias 2>/dev/null; true"}, + {"fish", "alias 2>/dev/null; true"}, + {"sh", "alias -p 2>/dev/null; true"}, + {"unknown", "alias 2>/dev/null; true"}, + } + for _, tt := range tests { + got := aliasCommandFor(tt.shell) + if got != tt.want { + t.Errorf("aliasCommandFor(%q) = %q, want %q", tt.shell, got, tt.want) + } + } +} + +func TestHarvestAliases_Integration(t *testing.T) { + // This actually runs the user's shell — skip in CI + if testing.Short() { + t.Skip("skipping alias harvest in short mode") + } + + m, err := HarvestAliases(context.Background()) + if err != nil { + // Non-fatal: harvesting may fail in some environments + t.Logf("HarvestAliases: %v (non-fatal)", err) + } + t.Logf("Harvested %d aliases", m.Len()) + for name, exp := range m.All() { + t.Logf(" %s → %s", name, exp) + } +} + +func TestBashTool_WithAliases(t *testing.T) { + aliases := NewAliasMap() + aliases.mu.Lock() + aliases.aliases["ll"] = "ls -la" + aliases.mu.Unlock() + + b := New(WithAliases(aliases)) + + // "ll /tmp" should expand to "ls -la /tmp" and execute + result, err := b.Execute(context.Background(), []byte(`{"command":"ll /tmp"}`)) + if err != nil { + t.Fatalf("Execute: %v", err) + } + // Should produce output (ls -la /tmp lists files) + if result.Output == "" { + t.Error("expected output from expanded alias") + } + if result.Metadata["blocked"] == true { + t.Error("expanded alias should not be blocked") + } +} diff --git a/internal/tool/bash/bash.go b/internal/tool/bash/bash.go new file mode 100644 index 0000000..fe3a322 --- /dev/null +++ b/internal/tool/bash/bash.go @@ -0,0 +1,140 @@ +package bash + +import ( + "context" + "encoding/json" + "fmt" + "os/exec" + "strings" + "time" + + "somegit.dev/Owlibou/gnoma/internal/tool" +) + +const ( + defaultTimeout = 30 * time.Second + toolName = "bash" +) + +var parameterSchema = json.RawMessage(`{ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The bash command to execute" + }, + "timeout": { + "type": "integer", + "description": "Timeout in seconds (default 30)" + } + }, + "required": ["command"] +}`) + +// Tool executes bash commands. +type Tool struct { + timeout time.Duration + workingDir string + aliases *AliasMap +} + +type Option func(*Tool) + +func WithTimeout(d time.Duration) Option { + return func(t *Tool) { t.timeout = d } +} + +func WithWorkingDir(dir string) Option { + return func(t *Tool) { t.workingDir = dir } +} + +func WithAliases(aliases *AliasMap) Option { + return func(t *Tool) { t.aliases = aliases } +} + +// New creates a bash tool. +func New(opts ...Option) *Tool { + t := &Tool{timeout: defaultTimeout} + for _, opt := range opts { + opt(t) + } + return t +} + +func (t *Tool) Name() string { return toolName } +func (t *Tool) Description() string { return "Execute a bash command and return its output" } +func (t *Tool) Parameters() json.RawMessage { return parameterSchema } +func (t *Tool) IsReadOnly() bool { return false } +func (t *Tool) IsDestructive() bool { return true } + +type bashArgs struct { + Command string `json:"command"` + Timeout int `json:"timeout,omitempty"` +} + +func (t *Tool) Execute(ctx context.Context, args json.RawMessage) (tool.Result, error) { + var a bashArgs + if err := json.Unmarshal(args, &a); err != nil { + return tool.Result{}, fmt.Errorf("bash: invalid args: %w", err) + } + + if a.Command == "" { + return tool.Result{}, fmt.Errorf("bash: empty command") + } + + // Expand aliases (first word only, matching bash behavior) + command := a.Command + if t.aliases != nil { + command = t.aliases.ExpandCommand(command) + } + + // Security validation runs on the expanded command + if violation := ValidateCommand(command); violation != nil { + return tool.Result{ + Output: fmt.Sprintf("Command blocked: %s", violation.Message), + Metadata: map[string]any{"blocked": true, "check": int(violation.Check)}, + }, nil + } + + timeout := t.timeout + if a.Timeout > 0 { + timeout = time.Duration(a.Timeout) * time.Second + } + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + cmd := exec.CommandContext(ctx, "bash", "-c", command) + if t.workingDir != "" { + cmd.Dir = t.workingDir + } + + output, err := cmd.CombinedOutput() + exitCode := 0 + + if err != nil { + // Check timeout first — context deadline may also produce an ExitError + if ctx.Err() == context.DeadlineExceeded { + return tool.Result{ + Output: fmt.Sprintf("Command timed out after %s\n%s", timeout, strings.TrimRight(string(output), "\n")), + Metadata: map[string]any{"exit_code": -1, "timeout": true}, + }, nil + } + if exitErr, ok := err.(*exec.ExitError); ok { + exitCode = exitErr.ExitCode() + } else { + return tool.Result{}, fmt.Errorf("bash: exec failed: %w", err) + } + } + + result := tool.Result{ + Output: strings.TrimRight(string(output), "\n"), + Metadata: map[string]any{"exit_code": exitCode}, + } + + if exitCode != 0 { + result.Output = fmt.Sprintf("Exit code %d\n%s", exitCode, result.Output) + } + + return result, nil +} diff --git a/internal/tool/bash/bash_test.go b/internal/tool/bash/bash_test.go new file mode 100644 index 0000000..8407b63 --- /dev/null +++ b/internal/tool/bash/bash_test.go @@ -0,0 +1,135 @@ +package bash + +import ( + "context" + "encoding/json" + "strings" + "testing" + "time" +) + +func TestBashTool_Interface(t *testing.T) { + b := New() + if b.Name() != "bash" { + t.Errorf("Name() = %q", b.Name()) + } + if b.IsReadOnly() { + t.Error("bash should not be read-only") + } + if !b.IsDestructive() { + t.Error("bash should be destructive") + } + if b.Parameters() == nil { + t.Error("Parameters() should not be nil") + } +} + +func TestBashTool_Echo(t *testing.T) { + b := New() + result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"echo hello world"}`)) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Output != "hello world" { + t.Errorf("Output = %q, want %q", result.Output, "hello world") + } + if result.Metadata["exit_code"] != 0 { + t.Errorf("exit_code = %v, want 0", result.Metadata["exit_code"]) + } +} + +func TestBashTool_ExitCode(t *testing.T) { + b := New() + result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"exit 42"}`)) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Metadata["exit_code"] != 42 { + t.Errorf("exit_code = %v, want 42", result.Metadata["exit_code"]) + } + if !strings.HasPrefix(result.Output, "Exit code 42") { + t.Errorf("Output = %q, should start with exit code", result.Output) + } +} + +func TestBashTool_Timeout(t *testing.T) { + b := New(WithTimeout(100 * time.Millisecond)) + result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"sleep 10"}`)) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Metadata["timeout"] != true { + t.Error("should have timed out") + } + if !strings.Contains(result.Output, "timed out") { + t.Errorf("Output = %q, should mention timeout", result.Output) + } +} + +func TestBashTool_CustomTimeout(t *testing.T) { + b := New(WithTimeout(30 * time.Second)) + // Args override the default timeout + result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"sleep 10","timeout":1}`)) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Metadata["timeout"] != true { + t.Error("should have timed out with custom 1s timeout") + } +} + +func TestBashTool_InvalidArgs(t *testing.T) { + b := New() + _, err := b.Execute(context.Background(), json.RawMessage(`not json`)) + if err == nil { + t.Error("expected error for invalid JSON") + } +} + +func TestBashTool_EmptyCommand(t *testing.T) { + b := New() + _, err := b.Execute(context.Background(), json.RawMessage(`{"command":""}`)) + if err == nil { + t.Error("expected error for empty command") + } +} + +func TestBashTool_SecurityBlock(t *testing.T) { + b := New() + + // Command substitution should be blocked + result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"echo $(whoami)"}`)) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Metadata["blocked"] != true { + t.Error("command with $() should be blocked") + } + if !strings.Contains(result.Output, "blocked") { + t.Errorf("Output = %q, should mention blocked", result.Output) + } +} + +func TestBashTool_WorkingDir(t *testing.T) { + b := New(WithWorkingDir(t.TempDir())) + result, err := b.Execute(context.Background(), json.RawMessage(`{"command":"pwd"}`)) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Output == "" { + t.Error("pwd should produce output") + } +} + +func TestBashTool_ContextCancellation(t *testing.T) { + b := New() + ctx, cancel := context.WithCancel(context.Background()) + cancel() // cancel immediately + + _, err := b.Execute(ctx, json.RawMessage(`{"command":"echo hello"}`)) + // Should either return an error or a timeout result + if err == nil { + // That's ok too — context cancellation is best-effort for fast commands + return + } +} diff --git a/internal/tool/bash/security.go b/internal/tool/bash/security.go new file mode 100644 index 0000000..797ee00 --- /dev/null +++ b/internal/tool/bash/security.go @@ -0,0 +1,206 @@ +package bash + +import ( + "fmt" + "strings" + "unicode" +) + +// SecurityCheck identifies a specific validation check. +type SecurityCheck int + +const ( + CheckIncomplete SecurityCheck = iota + 1 // fragments, trailing operators + CheckMetacharacters // ; | & $ ` < > + CheckCmdSubstitution // $(), ``, ${} + CheckRedirection // < > >> etc. + CheckDangerousVars // IFS, PATH manipulation + CheckNewlineInjection // embedded newlines + CheckControlChars // ASCII 00-1F (except \n \t) +) + +// SecurityViolation describes a failed security check. +type SecurityViolation struct { + Check SecurityCheck + Message string +} + +func (v SecurityViolation) Error() string { + return fmt.Sprintf("bash security check %d: %s", v.Check, v.Message) +} + +// ValidateCommand runs the 7 critical security checks against a command string. +// Returns nil if all checks pass, or the first violation found. +func ValidateCommand(cmd string) *SecurityViolation { + if strings.TrimSpace(cmd) == "" { + return &SecurityViolation{Check: CheckIncomplete, Message: "empty command"} + } + + // Check incomplete on raw command (before trimming) to catch tab-starts + if v := checkIncomplete(cmd); v != nil { + return v + } + + cmd = strings.TrimSpace(cmd) + + if v := checkControlChars(cmd); v != nil { + return v + } + if v := checkNewlineInjection(cmd); v != nil { + return v + } + if v := checkCmdSubstitution(cmd); v != nil { + return v + } + if v := checkDangerousVars(cmd); v != nil { + return v + } + // Metacharacters and redirection are warnings, not blocks in M1. + // The LLM legitimately uses pipes and redirects. + // Full compound command parsing (mvdan.cc/sh) comes in M5. + return nil +} + +// checkIncomplete detects command fragments that shouldn't be executed. +func checkIncomplete(cmd string) *SecurityViolation { + // Starts with tab (likely a fragment from indented code) + if cmd[0] == '\t' { + return &SecurityViolation{Check: CheckIncomplete, Message: "command starts with tab (likely a code fragment)"} + } + // Starts with a flag (no command name) + if cmd[0] == '-' { + return &SecurityViolation{Check: CheckIncomplete, Message: "command starts with flag (no command name)"} + } + // Ends with a dangling operator + trimmed := strings.TrimRight(cmd, " \t") + if len(trimmed) > 0 { + last := trimmed[len(trimmed)-1] + if last == '|' || last == '&' || last == ';' { + return &SecurityViolation{Check: CheckIncomplete, Message: "command ends with dangling operator"} + } + } + return nil +} + +// checkControlChars blocks ASCII control characters (0x00-0x1F) except \n and \t. +func checkControlChars(cmd string) *SecurityViolation { + for i, r := range cmd { + if r < 0x20 && r != '\n' && r != '\t' && r != '\r' { + return &SecurityViolation{ + Check: CheckControlChars, + Message: fmt.Sprintf("control character U+%04X at position %d", r, i), + } + } + } + return nil +} + +// checkNewlineInjection blocks commands with embedded newlines. +// Newlines in quoted strings are legitimate but rare in single commands. +// We allow them inside single/double quotes only. +func checkNewlineInjection(cmd string) *SecurityViolation { + inSingle := false + inDouble := false + escaped := false + + for _, r := range cmd { + if escaped { + escaped = false + continue + } + if r == '\\' && !inSingle { + escaped = true + continue + } + if r == '\'' && !inDouble { + inSingle = !inSingle + continue + } + if r == '"' && !inSingle { + inDouble = !inDouble + continue + } + if r == '\n' && !inSingle && !inDouble { + return &SecurityViolation{ + Check: CheckNewlineInjection, + Message: "unquoted newline (potential command injection)", + } + } + } + return nil +} + +// checkCmdSubstitution blocks $(), ``, and ${} command/variable substitution. +// These allow arbitrary code execution within a command. +func checkCmdSubstitution(cmd string) *SecurityViolation { + inSingle := false + escaped := false + + for i, r := range cmd { + if escaped { + escaped = false + continue + } + if r == '\\' && !inSingle { + escaped = true + continue + } + if r == '\'' { + inSingle = !inSingle + continue + } + + // Skip checks inside single quotes (literal) + if inSingle { + continue + } + + if r == '`' { + return &SecurityViolation{ + Check: CheckCmdSubstitution, + Message: "backtick command substitution", + } + } + + if r == '$' && i+1 < len(cmd) { + next := rune(cmd[i+1]) + if next == '(' { + return &SecurityViolation{ + Check: CheckCmdSubstitution, + Message: "$() command substitution", + } + } + if next == '{' { + return &SecurityViolation{ + Check: CheckCmdSubstitution, + Message: "${} variable expansion", + } + } + } + } + return nil +} + +// checkDangerousVars blocks attempts to manipulate IFS or PATH. +func checkDangerousVars(cmd string) *SecurityViolation { + upper := strings.ToUpper(cmd) + dangerousPatterns := []struct { + pattern string + msg string + }{ + {"IFS=", "IFS variable manipulation"}, + {"PATH=", "PATH variable manipulation"}, + } + + for _, p := range dangerousPatterns { + idx := strings.Index(upper, p.pattern) + if idx == -1 { + continue + } + // Only flag if it's at the start or preceded by whitespace/semicolon + if idx == 0 || !unicode.IsLetter(rune(cmd[idx-1])) { + return &SecurityViolation{Check: CheckDangerousVars, Message: p.msg} + } + } + return nil +} diff --git a/internal/tool/bash/security_test.go b/internal/tool/bash/security_test.go new file mode 100644 index 0000000..a8f7e05 --- /dev/null +++ b/internal/tool/bash/security_test.go @@ -0,0 +1,182 @@ +package bash + +import "testing" + +func TestValidateCommand_Valid(t *testing.T) { + valid := []string{ + "echo hello", + "ls -la", + "cat /etc/hostname", + "go test ./...", + "git status", + "echo 'hello world'", + `echo "hello world"`, + "grep -r 'pattern' .", + "find . -name '*.go'", + } + for _, cmd := range valid { + if v := ValidateCommand(cmd); v != nil { + t.Errorf("ValidateCommand(%q) = %v, want nil", cmd, v) + } + } +} + +func TestValidateCommand_Empty(t *testing.T) { + v := ValidateCommand("") + if v == nil { + t.Fatal("expected violation for empty command") + } + if v.Check != CheckIncomplete { + t.Errorf("Check = %d, want %d (incomplete)", v.Check, CheckIncomplete) + } +} + +func TestCheckIncomplete(t *testing.T) { + tests := []struct { + cmd string + want SecurityCheck + }{ + {"\techo hello", CheckIncomplete}, // tab start + {"-flag value", CheckIncomplete}, // flag start + {"echo hello |", CheckIncomplete}, // trailing pipe + {"echo hello &", CheckIncomplete}, // trailing ampersand + {"echo hello ;", CheckIncomplete}, // trailing semicolon + } + for _, tt := range tests { + v := ValidateCommand(tt.cmd) + if v == nil { + t.Errorf("ValidateCommand(%q) = nil, want check %d", tt.cmd, tt.want) + continue + } + if v.Check != tt.want { + t.Errorf("ValidateCommand(%q).Check = %d, want %d", tt.cmd, v.Check, tt.want) + } + } +} + +func TestCheckControlChars(t *testing.T) { + tests := []struct { + name string + cmd string + }{ + {"null byte", "echo hello\x00world"}, + {"bell", "echo \x07"}, + {"backspace", "echo \x08"}, + {"escape", "echo \x1b[31m"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := ValidateCommand(tt.cmd) + if v == nil { + t.Error("expected violation") + return + } + if v.Check != CheckControlChars { + t.Errorf("Check = %d, want %d (control chars)", v.Check, CheckControlChars) + } + }) + } +} + +func TestCheckControlChars_AllowedChars(t *testing.T) { + // Tabs and newlines inside quotes are allowed + valid := []string{ + "echo 'hello\tworld'", + } + for _, cmd := range valid { + if v := checkControlChars(cmd); v != nil { + t.Errorf("checkControlChars(%q) = %v, want nil", cmd, v) + } + } +} + +func TestCheckNewlineInjection(t *testing.T) { + // Unquoted newline + v := checkNewlineInjection("echo hello\nrm -rf /") + if v == nil { + t.Fatal("expected violation for unquoted newline") + } + if v.Check != CheckNewlineInjection { + t.Errorf("Check = %d, want %d", v.Check, CheckNewlineInjection) + } +} + +func TestCheckNewlineInjection_QuotedOK(t *testing.T) { + // Newlines inside quotes are fine + allowed := []string{ + "echo 'hello\nworld'", + `echo "hello` + "\n" + `world"`, + } + for _, cmd := range allowed { + if v := checkNewlineInjection(cmd); v != nil { + t.Errorf("checkNewlineInjection(%q) = %v, want nil", cmd, v) + } + } +} + +func TestCheckCmdSubstitution(t *testing.T) { + tests := []struct { + name string + cmd string + }{ + {"backtick", "echo `whoami`"}, + {"dollar paren", "echo $(whoami)"}, + {"dollar brace", "echo ${HOME}"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := ValidateCommand(tt.cmd) + if v == nil { + t.Error("expected violation") + return + } + if v.Check != CheckCmdSubstitution { + t.Errorf("Check = %d, want %d", v.Check, CheckCmdSubstitution) + } + }) + } +} + +func TestCheckCmdSubstitution_SingleQuoteOK(t *testing.T) { + // Inside single quotes, everything is literal + safe := "echo '$(whoami) and `uname` and ${HOME}'" + if v := checkCmdSubstitution(safe); v != nil { + t.Errorf("checkCmdSubstitution(%q) = %v, want nil (single-quoted)", safe, v) + } +} + +func TestCheckDangerousVars(t *testing.T) { + tests := []struct { + name string + cmd string + }{ + {"IFS at start", "IFS=: read a b"}, + {"PATH manipulation", "PATH=/tmp:$PATH command"}, + {"ifs with space prefix", " IFS=x echo hi"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := ValidateCommand(tt.cmd) + if v == nil { + t.Error("expected violation") + return + } + if v.Check != CheckDangerousVars { + t.Errorf("Check = %d, want %d", v.Check, CheckDangerousVars) + } + }) + } +} + +func TestCheckDangerousVars_SafeSubstrings(t *testing.T) { + // "SWIFT=..." should not trigger PATH check, "TARIFFS=..." should not trigger IFS + safe := []string{ + "echo SWIFT=enabled", + "TARIFFS=high echo test", + } + for _, cmd := range safe { + if v := checkDangerousVars(cmd); v != nil { + t.Errorf("checkDangerousVars(%q) = %v, want nil", cmd, v) + } + } +} diff --git a/internal/tool/fs/edit.go b/internal/tool/fs/edit.go new file mode 100644 index 0000000..08229df --- /dev/null +++ b/internal/tool/fs/edit.go @@ -0,0 +1,109 @@ +package fs + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strings" + + "somegit.dev/Owlibou/gnoma/internal/tool" +) + +const editToolName = "fs.edit" + +var editParams = json.RawMessage(`{ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Absolute path to the file to edit" + }, + "old_string": { + "type": "string", + "description": "The exact text to find and replace" + }, + "new_string": { + "type": "string", + "description": "The replacement text" + }, + "replace_all": { + "type": "boolean", + "description": "Replace all occurrences (default false)" + } + }, + "required": ["path", "old_string", "new_string"] +}`) + +type EditTool struct{} + +func NewEditTool() *EditTool { return &EditTool{} } + +func (t *EditTool) Name() string { return editToolName } +func (t *EditTool) Description() string { return "Perform exact string replacement in a file" } +func (t *EditTool) Parameters() json.RawMessage { return editParams } +func (t *EditTool) IsReadOnly() bool { return false } +func (t *EditTool) IsDestructive() bool { return false } + +type editArgs struct { + Path string `json:"path"` + OldString string `json:"old_string"` + NewString string `json:"new_string"` + ReplaceAll bool `json:"replace_all,omitempty"` +} + +func (t *EditTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) { + var a editArgs + if err := json.Unmarshal(args, &a); err != nil { + return tool.Result{}, fmt.Errorf("fs.edit: invalid args: %w", err) + } + if a.Path == "" { + return tool.Result{}, fmt.Errorf("fs.edit: path required") + } + if a.OldString == a.NewString { + return tool.Result{}, fmt.Errorf("fs.edit: old_string and new_string must differ") + } + + data, err := os.ReadFile(a.Path) + if err != nil { + return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil + } + + content := string(data) + + count := strings.Count(content, a.OldString) + if count == 0 { + return tool.Result{ + Output: "Error: old_string not found in file", + Metadata: map[string]any{"matches": 0}, + }, nil + } + + if !a.ReplaceAll && count > 1 { + return tool.Result{ + Output: fmt.Sprintf("Error: old_string has %d matches (must be unique, or use replace_all)", count), + Metadata: map[string]any{"matches": count}, + }, nil + } + + var newContent string + if a.ReplaceAll { + newContent = strings.ReplaceAll(content, a.OldString, a.NewString) + } else { + newContent = strings.Replace(content, a.OldString, a.NewString, 1) + } + + if err := os.WriteFile(a.Path, []byte(newContent), 0o644); err != nil { + return tool.Result{Output: fmt.Sprintf("Error writing file: %v", err)}, nil + } + + replacements := 1 + if a.ReplaceAll { + replacements = count + } + + return tool.Result{ + Output: fmt.Sprintf("Replaced %d occurrence(s) in %s", replacements, a.Path), + Metadata: map[string]any{"replacements": replacements, "path": a.Path}, + }, nil +} diff --git a/internal/tool/fs/fs_test.go b/internal/tool/fs/fs_test.go new file mode 100644 index 0000000..89b0e91 --- /dev/null +++ b/internal/tool/fs/fs_test.go @@ -0,0 +1,545 @@ +package fs + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" +) + +// --- Read --- + +func TestReadTool_Interface(t *testing.T) { + r := NewReadTool() + if r.Name() != "fs.read" { + t.Errorf("Name() = %q", r.Name()) + } + if !r.IsReadOnly() { + t.Error("should be read-only") + } + if r.IsDestructive() { + t.Error("should not be destructive") + } +} + +func TestReadTool_SimpleFile(t *testing.T) { + path := writeTestFile(t, "hello\nworld\n") + r := NewReadTool() + + result, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: path})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(result.Output, "1\thello") { + t.Errorf("Output should contain line-numbered content, got %q", result.Output) + } + if !strings.Contains(result.Output, "2\tworld") { + t.Errorf("Output missing line 2, got %q", result.Output) + } +} + +func TestReadTool_WithOffset(t *testing.T) { + path := writeTestFile(t, "line1\nline2\nline3\nline4\nline5\n") + r := NewReadTool() + + result, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: path, Offset: 2})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(result.Output, "3\tline3") { + t.Errorf("Output should start at line 3, got %q", result.Output) + } + if strings.Contains(result.Output, "1\tline1") { + t.Error("Output should not contain line 1") + } +} + +func TestReadTool_WithLimit(t *testing.T) { + path := writeTestFile(t, "a\nb\nc\nd\ne\n") + r := NewReadTool() + + result, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: path, Limit: 2})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + lines := strings.Split(result.Output, "\n") + if len(lines) != 2 { + t.Errorf("expected 2 lines, got %d: %q", len(lines), result.Output) + } + if result.Metadata["truncated"] != true { + t.Error("should be truncated") + } +} + +func TestReadTool_OffsetPastEnd(t *testing.T) { + path := writeTestFile(t, "one\ntwo\n") + r := NewReadTool() + + result, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: path, Offset: 100})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(result.Output, "past end") { + t.Errorf("Output = %q, should mention past end", result.Output) + } +} + +func TestReadTool_FileNotFound(t *testing.T) { + r := NewReadTool() + result, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: "/nonexistent/file.txt"})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(result.Output, "Error") { + t.Errorf("Output = %q, should contain error", result.Output) + } +} + +func TestReadTool_EmptyPath(t *testing.T) { + r := NewReadTool() + _, err := r.Execute(context.Background(), mustJSON(t, readArgs{})) + if err == nil { + t.Error("expected error for empty path") + } +} + +// --- Write --- + +func TestWriteTool_Interface(t *testing.T) { + w := NewWriteTool() + if w.Name() != "fs.write" { + t.Errorf("Name() = %q", w.Name()) + } + if w.IsReadOnly() { + t.Error("should not be read-only") + } +} + +func TestWriteTool_CreateFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.txt") + + w := NewWriteTool() + result, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: path, Content: "hello world"})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(result.Output, "11 bytes") { + t.Errorf("Output = %q", result.Output) + } + + data, _ := os.ReadFile(path) + if string(data) != "hello world" { + t.Errorf("file content = %q", string(data)) + } +} + +func TestWriteTool_CreatesParentDirs(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "a", "b", "c", "test.txt") + + w := NewWriteTool() + _, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: path, Content: "nested"})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + + data, _ := os.ReadFile(path) + if string(data) != "nested" { + t.Errorf("file content = %q", string(data)) + } +} + +func TestWriteTool_OverwriteExisting(t *testing.T) { + path := writeTestFile(t, "old content") + w := NewWriteTool() + + _, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: path, Content: "new content"})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + + data, _ := os.ReadFile(path) + if string(data) != "new content" { + t.Errorf("file content = %q", string(data)) + } +} + +// --- Edit --- + +func TestEditTool_Interface(t *testing.T) { + e := NewEditTool() + if e.Name() != "fs.edit" { + t.Errorf("Name() = %q", e.Name()) + } +} + +func TestEditTool_SingleReplace(t *testing.T) { + path := writeTestFile(t, "hello world") + e := NewEditTool() + + result, err := e.Execute(context.Background(), mustJSON(t, editArgs{ + Path: path, OldString: "world", NewString: "gnoma", + })) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(result.Output, "1 occurrence") { + t.Errorf("Output = %q", result.Output) + } + + data, _ := os.ReadFile(path) + if string(data) != "hello gnoma" { + t.Errorf("file content = %q", string(data)) + } +} + +func TestEditTool_ReplaceAll(t *testing.T) { + path := writeTestFile(t, "foo bar foo baz foo") + e := NewEditTool() + + result, err := e.Execute(context.Background(), mustJSON(t, editArgs{ + Path: path, OldString: "foo", NewString: "qux", ReplaceAll: true, + })) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(result.Output, "3 occurrence") { + t.Errorf("Output = %q", result.Output) + } + + data, _ := os.ReadFile(path) + if string(data) != "qux bar qux baz qux" { + t.Errorf("file content = %q", string(data)) + } +} + +func TestEditTool_NonUniqueWithoutReplaceAll(t *testing.T) { + path := writeTestFile(t, "foo foo foo") + e := NewEditTool() + + result, err := e.Execute(context.Background(), mustJSON(t, editArgs{ + Path: path, OldString: "foo", NewString: "bar", + })) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(result.Output, "3 matches") { + t.Errorf("Output = %q, should mention multiple matches", result.Output) + } + + // File should be unchanged + data, _ := os.ReadFile(path) + if string(data) != "foo foo foo" { + t.Errorf("file should be unchanged, got %q", string(data)) + } +} + +func TestEditTool_NotFound(t *testing.T) { + path := writeTestFile(t, "hello world") + e := NewEditTool() + + result, err := e.Execute(context.Background(), mustJSON(t, editArgs{ + Path: path, OldString: "missing", NewString: "replaced", + })) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(result.Output, "not found") { + t.Errorf("Output = %q, should mention not found", result.Output) + } +} + +func TestEditTool_SameStrings(t *testing.T) { + e := NewEditTool() + _, err := e.Execute(context.Background(), mustJSON(t, editArgs{ + Path: "/tmp/x", OldString: "same", NewString: "same", + })) + if err == nil { + t.Error("expected error when old_string == new_string") + } +} + +// --- Glob --- + +func TestGlobTool_Interface(t *testing.T) { + g := NewGlobTool() + if g.Name() != "fs.glob" { + t.Errorf("Name() = %q", g.Name()) + } + if !g.IsReadOnly() { + t.Error("should be read-only") + } +} + +func TestGlobTool_MatchFiles(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "main.go"), []byte("package main"), 0o644) + os.WriteFile(filepath.Join(dir, "test.go"), []byte("package main"), 0o644) + os.WriteFile(filepath.Join(dir, "readme.md"), []byte("# readme"), 0o644) + + g := NewGlobTool() + result, err := g.Execute(context.Background(), mustJSON(t, globArgs{Pattern: "*.go", Path: dir})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + + if result.Metadata["count"] != 2 { + t.Errorf("count = %v, want 2", result.Metadata["count"]) + } + if !strings.Contains(result.Output, "main.go") { + t.Errorf("Output missing main.go: %q", result.Output) + } + if strings.Contains(result.Output, "readme.md") { + t.Error("Output should not contain readme.md") + } +} + +func TestGlobTool_NoMatches(t *testing.T) { + dir := t.TempDir() + g := NewGlobTool() + + result, err := g.Execute(context.Background(), mustJSON(t, globArgs{Pattern: "*.xyz", Path: dir})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(result.Output, "no matches") { + t.Errorf("Output = %q", result.Output) + } +} + +// --- Grep --- + +func TestGrepTool_Interface(t *testing.T) { + g := NewGrepTool() + if g.Name() != "fs.grep" { + t.Errorf("Name() = %q", g.Name()) + } + if !g.IsReadOnly() { + t.Error("should be read-only") + } +} + +func TestGrepTool_SingleFile(t *testing.T) { + path := writeTestFile(t, "hello world\nfoo bar\nhello again\n") + g := NewGrepTool() + + result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "hello", Path: path})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Metadata["count"] != 2 { + t.Errorf("count = %v, want 2", result.Metadata["count"]) + } + if !strings.Contains(result.Output, "1:hello world") { + t.Errorf("Output = %q", result.Output) + } +} + +func TestGrepTool_Directory(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "a.go"), []byte("func main() {}\nfunc helper() {}"), 0o644) + os.WriteFile(filepath.Join(dir, "b.go"), []byte("func test() {}"), 0o644) + os.WriteFile(filepath.Join(dir, "c.txt"), []byte("func ignored() {}"), 0o644) + + g := NewGrepTool() + + // Search all files for "func" + result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "func", Path: dir})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Metadata["count"].(int) < 3 { + t.Errorf("count = %v, want >= 3", result.Metadata["count"]) + } + + // With glob filter + result, err = g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "func", Path: dir, Glob: "*.go"})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if strings.Contains(result.Output, "c.txt") { + t.Error("should not match .txt files with *.go glob") + } +} + +func TestGrepTool_Regex(t *testing.T) { + path := writeTestFile(t, "error: something failed\nwarning: be careful\nerror: another one\n") + g := NewGrepTool() + + result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: `^error:`, Path: path})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Metadata["count"] != 2 { + t.Errorf("count = %v, want 2", result.Metadata["count"]) + } +} + +func TestGrepTool_InvalidRegex(t *testing.T) { + g := NewGrepTool() + result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "[invalid", Path: "."})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(result.Output, "Invalid regex") { + t.Errorf("Output = %q, should mention invalid regex", result.Output) + } +} + +func TestGrepTool_NoMatches(t *testing.T) { + path := writeTestFile(t, "hello world\n") + g := NewGrepTool() + + result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "zzzzz", Path: path})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(result.Output, "no matches") { + t.Errorf("Output = %q", result.Output) + } +} + +func TestGrepTool_MaxResults(t *testing.T) { + var lines strings.Builder + for i := 0; i < 100; i++ { + lines.WriteString("match line\n") + } + path := writeTestFile(t, lines.String()) + g := NewGrepTool() + + result, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "match", Path: path, MaxResults: 5})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if result.Metadata["count"] != 5 { + t.Errorf("count = %v, want 5", result.Metadata["count"]) + } + if result.Metadata["truncated"] != true { + t.Error("should be truncated") + } +} + +// --- LS --- + +func TestLSTool_Interface(t *testing.T) { + l := NewLSTool() + if l.Name() != "fs.ls" { + t.Errorf("Name() = %q", l.Name()) + } + if !l.IsReadOnly() { + t.Error("should be read-only") + } +} + +func TestLSTool_ListDirectory(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "hello.go"), []byte("package main"), 0o644) + os.WriteFile(filepath.Join(dir, "readme.md"), []byte("# readme"), 0o644) + os.MkdirAll(filepath.Join(dir, "subdir"), 0o755) + + l := NewLSTool() + result, err := l.Execute(context.Background(), mustJSON(t, lsArgs{Path: dir})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + + if !strings.Contains(result.Output, "hello.go") { + t.Errorf("Output missing hello.go: %q", result.Output) + } + if !strings.Contains(result.Output, "readme.md") { + t.Errorf("Output missing readme.md: %q", result.Output) + } + if !strings.Contains(result.Output, "subdir") { + t.Errorf("Output missing subdir: %q", result.Output) + } + + if result.Metadata["files"] != 2 { + t.Errorf("files = %v, want 2", result.Metadata["files"]) + } + if result.Metadata["dirs"] != 1 { + t.Errorf("dirs = %v, want 1", result.Metadata["dirs"]) + } +} + +func TestLSTool_EmptyDirectory(t *testing.T) { + dir := t.TempDir() + l := NewLSTool() + + result, err := l.Execute(context.Background(), mustJSON(t, lsArgs{Path: dir})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(result.Output, "empty directory") { + t.Errorf("Output = %q, should mention empty", result.Output) + } +} + +func TestLSTool_DirectoryNotFound(t *testing.T) { + l := NewLSTool() + result, err := l.Execute(context.Background(), mustJSON(t, lsArgs{Path: "/nonexistent/dir"})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(result.Output, "Error") { + t.Errorf("Output = %q, should contain error", result.Output) + } +} + +func TestLSTool_ShowsSizes(t *testing.T) { + dir := t.TempDir() + os.WriteFile(filepath.Join(dir, "small.txt"), []byte("hi"), 0o644) + + l := NewLSTool() + result, err := l.Execute(context.Background(), mustJSON(t, lsArgs{Path: dir})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + // Should show "2B" for a 2-byte file + if !strings.Contains(result.Output, "2B") { + t.Errorf("Output = %q, should show file size", result.Output) + } +} + +func TestFormatSize(t *testing.T) { + tests := []struct { + bytes int64 + want string + }{ + {0, "0B"}, + {42, "42B"}, + {1024, "1.0K"}, + {1536, "1.5K"}, + {1048576, "1.0M"}, + {1073741824, "1.0G"}, + } + for _, tt := range tests { + got := formatSize(tt.bytes) + if got != tt.want { + t.Errorf("formatSize(%d) = %q, want %q", tt.bytes, got, tt.want) + } + } +} + +// --- Helpers --- + +func writeTestFile(t *testing.T, content string) string { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "test.txt") + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("writeTestFile: %v", err) + } + return path +} + +func mustJSON(t *testing.T, v any) json.RawMessage { + t.Helper() + data, err := json.Marshal(v) + if err != nil { + t.Fatalf("json.Marshal: %v", err) + } + return data +} diff --git a/internal/tool/fs/glob.go b/internal/tool/fs/glob.go new file mode 100644 index 0000000..1fb5980 --- /dev/null +++ b/internal/tool/fs/glob.go @@ -0,0 +1,117 @@ +package fs + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + + "somegit.dev/Owlibou/gnoma/internal/tool" +) + +const globToolName = "fs.glob" + +var globParams = json.RawMessage(`{ + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "Glob pattern to match files (e.g. **/*.go, src/**/*.ts)" + }, + "path": { + "type": "string", + "description": "Directory to search in (defaults to current directory)" + } + }, + "required": ["pattern"] +}`) + +type GlobTool struct{} + +func NewGlobTool() *GlobTool { return &GlobTool{} } + +func (t *GlobTool) Name() string { return globToolName } +func (t *GlobTool) Description() string { return "Find files matching a glob pattern, sorted by modification time" } +func (t *GlobTool) Parameters() json.RawMessage { return globParams } +func (t *GlobTool) IsReadOnly() bool { return true } +func (t *GlobTool) IsDestructive() bool { return false } + +type globArgs struct { + Pattern string `json:"pattern"` + Path string `json:"path,omitempty"` +} + +func (t *GlobTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) { + var a globArgs + if err := json.Unmarshal(args, &a); err != nil { + return tool.Result{}, fmt.Errorf("fs.glob: invalid args: %w", err) + } + if a.Pattern == "" { + return tool.Result{}, fmt.Errorf("fs.glob: pattern required") + } + + root := a.Path + if root == "" { + var err error + root, err = os.Getwd() + if err != nil { + return tool.Result{}, fmt.Errorf("fs.glob: %w", err) + } + } + + var matches []string + err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error { + if err != nil { + return nil // skip inaccessible entries + } + if d.IsDir() { + // Skip hidden directories + if d.Name() != "." && strings.HasPrefix(d.Name(), ".") { + return filepath.SkipDir + } + return nil + } + + rel, err := filepath.Rel(root, path) + if err != nil { + return nil + } + + matched, err := filepath.Match(a.Pattern, rel) + if err != nil { + // Try matching just the filename for simple patterns + matched, _ = filepath.Match(a.Pattern, d.Name()) + } + + if matched { + matches = append(matches, rel) + } + return nil + }) + if err != nil { + return tool.Result{Output: fmt.Sprintf("Error walking directory: %v", err)}, nil + } + + // Sort by modification time (most recent first) + sort.Slice(matches, func(i, j int) bool { + iInfo, _ := os.Stat(filepath.Join(root, matches[i])) + jInfo, _ := os.Stat(filepath.Join(root, matches[j])) + if iInfo == nil || jInfo == nil { + return matches[i] < matches[j] + } + return iInfo.ModTime().After(jInfo.ModTime()) + }) + + output := strings.Join(matches, "\n") + if output == "" { + output = "(no matches)" + } + + return tool.Result{ + Output: output, + Metadata: map[string]any{"count": len(matches), "pattern": a.Pattern}, + }, nil +} diff --git a/internal/tool/fs/grep.go b/internal/tool/fs/grep.go new file mode 100644 index 0000000..5e64869 --- /dev/null +++ b/internal/tool/fs/grep.go @@ -0,0 +1,184 @@ +package fs + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + + "somegit.dev/Owlibou/gnoma/internal/tool" +) + +const ( + grepToolName = "fs.grep" + defaultMaxResults = 250 +) + +var grepParams = json.RawMessage(`{ + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "Regular expression pattern to search for" + }, + "path": { + "type": "string", + "description": "File or directory to search in (defaults to current directory)" + }, + "glob": { + "type": "string", + "description": "File glob filter (e.g. *.go, *.ts)" + }, + "max_results": { + "type": "integer", + "description": "Maximum number of matching lines to return (default 250)" + } + }, + "required": ["pattern"] +}`) + +type GrepTool struct{} + +func NewGrepTool() *GrepTool { return &GrepTool{} } + +func (t *GrepTool) Name() string { return grepToolName } +func (t *GrepTool) Description() string { return "Search file contents using a regular expression" } +func (t *GrepTool) Parameters() json.RawMessage { return grepParams } +func (t *GrepTool) IsReadOnly() bool { return true } +func (t *GrepTool) IsDestructive() bool { return false } + +type grepArgs struct { + Pattern string `json:"pattern"` + Path string `json:"path,omitempty"` + Glob string `json:"glob,omitempty"` + MaxResults int `json:"max_results,omitempty"` +} + +type grepMatch struct { + File string + Line int + Text string +} + +func (t *GrepTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) { + var a grepArgs + if err := json.Unmarshal(args, &a); err != nil { + return tool.Result{}, fmt.Errorf("fs.grep: invalid args: %w", err) + } + if a.Pattern == "" { + return tool.Result{}, fmt.Errorf("fs.grep: pattern required") + } + + re, err := regexp.Compile(a.Pattern) + if err != nil { + return tool.Result{Output: fmt.Sprintf("Invalid regex: %v", err)}, nil + } + + maxResults := a.MaxResults + if maxResults <= 0 { + maxResults = defaultMaxResults + } + + root := a.Path + if root == "" { + root, err = os.Getwd() + if err != nil { + return tool.Result{}, fmt.Errorf("fs.grep: %w", err) + } + } + + info, err := os.Stat(root) + if err != nil { + return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil + } + + var matches []grepMatch + + if !info.IsDir() { + matches = grepFile(root, "", re, maxResults) + } else { + filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error { + if err != nil || d.IsDir() { + if d != nil && d.IsDir() && d.Name() != "." && strings.HasPrefix(d.Name(), ".") { + return filepath.SkipDir + } + return nil + } + + // Apply glob filter + if a.Glob != "" { + matched, _ := filepath.Match(a.Glob, d.Name()) + if !matched { + return nil + } + } + + rel, _ := filepath.Rel(root, path) + fileMatches := grepFile(path, rel, re, maxResults-len(matches)) + matches = append(matches, fileMatches...) + + if len(matches) >= maxResults { + return filepath.SkipAll + } + return nil + }) + } + + if len(matches) == 0 { + return tool.Result{ + Output: "(no matches)", + Metadata: map[string]any{"count": 0}, + }, nil + } + + var b strings.Builder + for _, m := range matches { + if m.File != "" { + fmt.Fprintf(&b, "%s:%d:%s\n", m.File, m.Line, m.Text) + } else { + fmt.Fprintf(&b, "%d:%s\n", m.Line, m.Text) + } + } + + truncated := len(matches) >= maxResults + + return tool.Result{ + Output: strings.TrimRight(b.String(), "\n"), + Metadata: map[string]any{ + "count": len(matches), + "truncated": truncated, + }, + }, nil +} + +func grepFile(path, displayPath string, re *regexp.Regexp, limit int) []grepMatch { + f, err := os.Open(path) + if err != nil { + return nil + } + defer f.Close() + + var matches []grepMatch + scanner := bufio.NewScanner(f) + lineNum := 0 + + for scanner.Scan() { + lineNum++ + line := scanner.Text() + if re.MatchString(line) { + matches = append(matches, grepMatch{ + File: displayPath, + Line: lineNum, + Text: line, + }) + if len(matches) >= limit { + break + } + } + } + return matches +} diff --git a/internal/tool/fs/ls.go b/internal/tool/fs/ls.go new file mode 100644 index 0000000..8a0d9f6 --- /dev/null +++ b/internal/tool/fs/ls.go @@ -0,0 +1,123 @@ +package fs + +import ( + "context" + "encoding/json" + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + + "somegit.dev/Owlibou/gnoma/internal/tool" +) + +const lsToolName = "fs.ls" + +var lsParams = json.RawMessage(`{ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Directory path to list (defaults to current directory)" + } + } +}`) + +type LSTool struct{} + +func NewLSTool() *LSTool { return &LSTool{} } + +func (t *LSTool) Name() string { return lsToolName } +func (t *LSTool) Description() string { return "List directory contents with file types and sizes" } +func (t *LSTool) Parameters() json.RawMessage { return lsParams } +func (t *LSTool) IsReadOnly() bool { return true } +func (t *LSTool) IsDestructive() bool { return false } + +type lsArgs struct { + Path string `json:"path,omitempty"` +} + +func (t *LSTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) { + var a lsArgs + if err := json.Unmarshal(args, &a); err != nil { + return tool.Result{}, fmt.Errorf("fs.ls: invalid args: %w", err) + } + + dir := a.Path + if dir == "" { + var err error + dir, err = os.Getwd() + if err != nil { + return tool.Result{}, fmt.Errorf("fs.ls: %w", err) + } + } + + entries, err := os.ReadDir(dir) + if err != nil { + return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil + } + + var b strings.Builder + dirCount, fileCount := 0, 0 + + for _, entry := range entries { + info, err := entry.Info() + if err != nil { + continue + } + + prefix := " " + if entry.IsDir() { + prefix = "d" + dirCount++ + } else { + fileCount++ + } + + size := formatSize(info.Size()) + if entry.IsDir() { + size = "-" + } + + // Check for symlink + if entry.Type()&fs.ModeSymlink != 0 { + prefix = "l" + target, err := os.Readlink(filepath.Join(dir, entry.Name())) + if err == nil { + fmt.Fprintf(&b, "%s %8s %s -> %s\n", prefix, size, entry.Name(), target) + continue + } + } + + fmt.Fprintf(&b, "%s %8s %s\n", prefix, size, entry.Name()) + } + + output := strings.TrimRight(b.String(), "\n") + if output == "" { + output = "(empty directory)" + } + + return tool.Result{ + Output: output, + Metadata: map[string]any{ + "directory": dir, + "files": fileCount, + "dirs": dirCount, + "total": fileCount + dirCount, + }, + }, nil +} + +func formatSize(bytes int64) string { + switch { + case bytes >= 1<<30: + return fmt.Sprintf("%.1fG", float64(bytes)/(1<<30)) + case bytes >= 1<<20: + return fmt.Sprintf("%.1fM", float64(bytes)/(1<<20)) + case bytes >= 1<<10: + return fmt.Sprintf("%.1fK", float64(bytes)/(1<<10)) + default: + return fmt.Sprintf("%dB", bytes) + } +} diff --git a/internal/tool/fs/read.go b/internal/tool/fs/read.go new file mode 100644 index 0000000..6b18648 --- /dev/null +++ b/internal/tool/fs/read.go @@ -0,0 +1,123 @@ +package fs + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strings" + + "somegit.dev/Owlibou/gnoma/internal/tool" +) + +const ( + readToolName = "fs.read" + defaultMaxLines = 2000 +) + +var readParams = json.RawMessage(`{ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Absolute path to the file to read" + }, + "offset": { + "type": "integer", + "description": "Line number to start reading from (0-based)" + }, + "limit": { + "type": "integer", + "description": "Maximum number of lines to read" + } + }, + "required": ["path"] +}`) + +type ReadTool struct { + maxLines int +} + +type ReadOption func(*ReadTool) + +func WithMaxLines(n int) ReadOption { + return func(t *ReadTool) { t.maxLines = n } +} + +func NewReadTool(opts ...ReadOption) *ReadTool { + t := &ReadTool{maxLines: defaultMaxLines} + for _, opt := range opts { + opt(t) + } + return t +} + +func (t *ReadTool) Name() string { return readToolName } +func (t *ReadTool) Description() string { return "Read a file from the filesystem with optional offset and line limit" } +func (t *ReadTool) Parameters() json.RawMessage { return readParams } +func (t *ReadTool) IsReadOnly() bool { return true } +func (t *ReadTool) IsDestructive() bool { return false } + +type readArgs struct { + Path string `json:"path"` + Offset int `json:"offset,omitempty"` + Limit int `json:"limit,omitempty"` +} + +func (t *ReadTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) { + var a readArgs + if err := json.Unmarshal(args, &a); err != nil { + return tool.Result{}, fmt.Errorf("fs.read: invalid args: %w", err) + } + if a.Path == "" { + return tool.Result{}, fmt.Errorf("fs.read: path required") + } + + data, err := os.ReadFile(a.Path) + if err != nil { + return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil + } + + lines := strings.Split(string(data), "\n") + totalLines := len(lines) + + // Apply offset + offset := a.Offset + if offset < 0 { + offset = 0 + } + if offset >= totalLines { + return tool.Result{ + Output: fmt.Sprintf("(file has %d lines, offset %d is past end)", totalLines, offset), + Metadata: map[string]any{"total_lines": totalLines}, + }, nil + } + lines = lines[offset:] + + // Apply limit + limit := a.Limit + if limit <= 0 { + limit = t.maxLines + } + truncated := false + if len(lines) > limit { + lines = lines[:limit] + truncated = true + } + + // Format with line numbers (1-based, matching cat -n) + var b strings.Builder + for i, line := range lines { + fmt.Fprintf(&b, "%d\t%s\n", offset+i+1, line) + } + + output := strings.TrimRight(b.String(), "\n") + + meta := map[string]any{"total_lines": totalLines} + if truncated { + meta["truncated"] = true + meta["showing"] = fmt.Sprintf("lines %d-%d of %d", offset+1, offset+len(lines), totalLines) + } + + return tool.Result{Output: output, Metadata: meta}, nil +} diff --git a/internal/tool/fs/write.go b/internal/tool/fs/write.go new file mode 100644 index 0000000..25ba916 --- /dev/null +++ b/internal/tool/fs/write.go @@ -0,0 +1,68 @@ +package fs + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + + "somegit.dev/Owlibou/gnoma/internal/tool" +) + +const writeToolName = "fs.write" + +var writeParams = json.RawMessage(`{ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Absolute path to the file to write" + }, + "content": { + "type": "string", + "description": "Content to write to the file" + } + }, + "required": ["path", "content"] +}`) + +type WriteTool struct{} + +func NewWriteTool() *WriteTool { return &WriteTool{} } + +func (t *WriteTool) Name() string { return writeToolName } +func (t *WriteTool) Description() string { return "Write content to a file, creating parent directories as needed" } +func (t *WriteTool) Parameters() json.RawMessage { return writeParams } +func (t *WriteTool) IsReadOnly() bool { return false } +func (t *WriteTool) IsDestructive() bool { return false } + +type writeArgs struct { + Path string `json:"path"` + Content string `json:"content"` +} + +func (t *WriteTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, error) { + var a writeArgs + if err := json.Unmarshal(args, &a); err != nil { + return tool.Result{}, fmt.Errorf("fs.write: invalid args: %w", err) + } + if a.Path == "" { + return tool.Result{}, fmt.Errorf("fs.write: path required") + } + + // Create parent directories + dir := filepath.Dir(a.Path) + if err := os.MkdirAll(dir, 0o755); err != nil { + return tool.Result{Output: fmt.Sprintf("Error creating directory: %v", err)}, nil + } + + if err := os.WriteFile(a.Path, []byte(a.Content), 0o644); err != nil { + return tool.Result{Output: fmt.Sprintf("Error writing file: %v", err)}, nil + } + + return tool.Result{ + Output: fmt.Sprintf("Wrote %d bytes to %s", len(a.Content), a.Path), + Metadata: map[string]any{"bytes_written": len(a.Content), "path": a.Path}, + }, nil +} diff --git a/internal/tool/registry.go b/internal/tool/registry.go new file mode 100644 index 0000000..483f780 --- /dev/null +++ b/internal/tool/registry.go @@ -0,0 +1,77 @@ +package tool + +import ( + "encoding/json" + "fmt" + "sync" +) + +// Definition is the provider-agnostic tool schema sent to the LLM. +type Definition struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters json.RawMessage `json:"parameters"` +} + +// Registry holds all available tools. +type Registry struct { + mu sync.RWMutex + tools map[string]Tool +} + +func NewRegistry() *Registry { + return &Registry{ + tools: make(map[string]Tool), + } +} + +// Register adds a tool. Overwrites if name already exists. +func (r *Registry) Register(t Tool) { + r.mu.Lock() + defer r.mu.Unlock() + r.tools[t.Name()] = t +} + +// Get returns a tool by name. +func (r *Registry) Get(name string) (Tool, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + t, ok := r.tools[name] + return t, ok +} + +// All returns all registered tools. +func (r *Registry) All() []Tool { + r.mu.RLock() + defer r.mu.RUnlock() + all := make([]Tool, 0, len(r.tools)) + for _, t := range r.tools { + all = append(all, t) + } + return all +} + +// Definitions returns tool definitions for all registered tools, +// suitable for sending to the LLM. +func (r *Registry) Definitions() []Definition { + r.mu.RLock() + defer r.mu.RUnlock() + defs := make([]Definition, 0, len(r.tools)) + for _, t := range r.tools { + defs = append(defs, Definition{ + Name: t.Name(), + Description: t.Description(), + Parameters: t.Parameters(), + }) + } + return defs +} + +// MustGet returns a tool by name or panics. For use in tests. +func (r *Registry) MustGet(name string) Tool { + t, ok := r.Get(name) + if !ok { + panic(fmt.Sprintf("tool not found: %q", name)) + } + return t +} diff --git a/internal/tool/registry_test.go b/internal/tool/registry_test.go new file mode 100644 index 0000000..89b3c5a --- /dev/null +++ b/internal/tool/registry_test.go @@ -0,0 +1,208 @@ +package tool + +import ( + "context" + "encoding/json" + "slices" + "sort" + "testing" +) + +// stubTool is a minimal Tool implementation for testing. +type stubTool struct { + name string + description string + params json.RawMessage + readOnly bool + destructive bool + execFn func(ctx context.Context, args json.RawMessage) (Result, error) +} + +func (s *stubTool) Name() string { return s.name } +func (s *stubTool) Description() string { return s.description } +func (s *stubTool) Parameters() json.RawMessage { return s.params } +func (s *stubTool) IsReadOnly() bool { return s.readOnly } +func (s *stubTool) IsDestructive() bool { return s.destructive } +func (s *stubTool) Execute(ctx context.Context, args json.RawMessage) (Result, error) { + if s.execFn != nil { + return s.execFn(ctx, args) + } + return Result{Output: "ok"}, nil +} + +func TestRegistry_RegisterAndGet(t *testing.T) { + r := NewRegistry() + r.Register(&stubTool{name: "bash", description: "run commands"}) + + tool, ok := r.Get("bash") + if !ok { + t.Fatal("Get(bash) should find tool") + } + if tool.Name() != "bash" { + t.Errorf("Name() = %q", tool.Name()) + } +} + +func TestRegistry_Get_NotFound(t *testing.T) { + r := NewRegistry() + _, ok := r.Get("nonexistent") + if ok { + t.Error("Get(nonexistent) should return false") + } +} + +func TestRegistry_Register_Overwrite(t *testing.T) { + r := NewRegistry() + r.Register(&stubTool{name: "bash", description: "old"}) + r.Register(&stubTool{name: "bash", description: "new"}) + + tool, _ := r.Get("bash") + if tool.Description() != "new" { + t.Errorf("Description() = %q, want 'new' (overwritten)", tool.Description()) + } +} + +func TestRegistry_All(t *testing.T) { + r := NewRegistry() + r.Register(&stubTool{name: "bash"}) + r.Register(&stubTool{name: "fs.read"}) + r.Register(&stubTool{name: "fs.write"}) + + all := r.All() + if len(all) != 3 { + t.Fatalf("len(All()) = %d, want 3", len(all)) + } + + names := make([]string, len(all)) + for i, t := range all { + names[i] = t.Name() + } + sort.Strings(names) + + want := []string{"bash", "fs.read", "fs.write"} + if !slices.Equal(names, want) { + t.Errorf("All() names = %v, want %v", names, want) + } +} + +func TestRegistry_Definitions(t *testing.T) { + r := NewRegistry() + r.Register(&stubTool{ + name: "bash", + description: "Run a command", + params: json.RawMessage(`{"type":"object","properties":{"command":{"type":"string"}}}`), + }) + r.Register(&stubTool{ + name: "fs.read", + description: "Read a file", + params: json.RawMessage(`{"type":"object","properties":{"path":{"type":"string"}}}`), + }) + + defs := r.Definitions() + if len(defs) != 2 { + t.Fatalf("len(Definitions()) = %d, want 2", len(defs)) + } + + // Find bash definition + var bashDef *Definition + for i := range defs { + if defs[i].Name == "bash" { + bashDef = &defs[i] + break + } + } + if bashDef == nil { + t.Fatal("bash definition not found") + } + if bashDef.Description != "Run a command" { + t.Errorf("bash.Description = %q", bashDef.Description) + } + if bashDef.Parameters == nil { + t.Error("bash.Parameters should not be nil") + } +} + +func TestRegistry_MustGet_Panics(t *testing.T) { + r := NewRegistry() + + defer func() { + if r := recover(); r == nil { + t.Error("MustGet should panic for missing tool") + } + }() + + r.MustGet("nonexistent") +} + +func TestRegistry_MustGet_Success(t *testing.T) { + r := NewRegistry() + r.Register(&stubTool{name: "bash"}) + + tool := r.MustGet("bash") + if tool.Name() != "bash" { + t.Errorf("Name() = %q", tool.Name()) + } +} + +func TestRegistry_Empty(t *testing.T) { + r := NewRegistry() + + if len(r.All()) != 0 { + t.Error("empty registry should return no tools") + } + if len(r.Definitions()) != 0 { + t.Error("empty registry should return no definitions") + } +} + +func TestStubTool_Execute(t *testing.T) { + called := false + tool := &stubTool{ + name: "test", + execFn: func(ctx context.Context, args json.RawMessage) (Result, error) { + called = true + var input struct{ Value string } + json.Unmarshal(args, &input) + return Result{ + Output: "processed: " + input.Value, + Metadata: map[string]any{"key": "val"}, + }, nil + }, + } + + result, err := tool.Execute(context.Background(), json.RawMessage(`{"Value":"hello"}`)) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !called { + t.Error("execFn should have been called") + } + if result.Output != "processed: hello" { + t.Errorf("Output = %q", result.Output) + } + if result.Metadata["key"] != "val" { + t.Errorf("Metadata = %v", result.Metadata) + } +} + +func TestToolInterface_ReadOnlyDestructive(t *testing.T) { + readTool := &stubTool{name: "fs.read", readOnly: true, destructive: false} + writeTool := &stubTool{name: "fs.write", readOnly: false, destructive: false} + deleteTool := &stubTool{name: "bash.rm", readOnly: false, destructive: true} + + if !readTool.IsReadOnly() { + t.Error("fs.read should be read-only") + } + if readTool.IsDestructive() { + t.Error("fs.read should not be destructive") + } + if writeTool.IsReadOnly() { + t.Error("fs.write should not be read-only") + } + if deleteTool.IsReadOnly() { + t.Error("bash.rm should not be read-only") + } + if !deleteTool.IsDestructive() { + t.Error("bash.rm should be destructive") + } +} diff --git a/internal/tool/result.go b/internal/tool/result.go new file mode 100644 index 0000000..8ded038 --- /dev/null +++ b/internal/tool/result.go @@ -0,0 +1,9 @@ +package tool + +// Result is the output of a tool execution. +type Result struct { + // Output is the text content returned to the LLM. + Output string + // Metadata carries optional structured data (exit code, file path, match count, etc.). + Metadata map[string]any +} diff --git a/internal/tool/tool.go b/internal/tool/tool.go new file mode 100644 index 0000000..a80266c --- /dev/null +++ b/internal/tool/tool.go @@ -0,0 +1,22 @@ +package tool + +import ( + "context" + "encoding/json" +) + +// Tool is the interface every tool must implement. +type Tool interface { + // Name returns the tool's identifier (used in LLM tool schemas). + Name() string + // Description returns a human-readable description for the LLM. + Description() string + // Parameters returns the JSON Schema for the tool's input. + Parameters() json.RawMessage + // Execute runs the tool with the given JSON arguments. + Execute(ctx context.Context, args json.RawMessage) (Result, error) + // IsReadOnly returns true if the tool only reads (safe for concurrent execution). + IsReadOnly() bool + // IsDestructive returns true if the tool can cause irreversible changes. + IsDestructive() bool +}