diff --git a/cmd/gnoma/main.go b/cmd/gnoma/main.go index daa4135..5aa82e1 100644 --- a/cmd/gnoma/main.go +++ b/cmd/gnoma/main.go @@ -229,10 +229,23 @@ func main() { } } + cwd, cwdErr := os.Getwd() + if cwdErr != nil { + fmt.Fprintf(os.Stderr, "error: cannot resolve working directory: %v\n", cwdErr) + os.Exit(1) + } + fsGuard, err := fs.NewGuard(cwd) + if err != nil { + fmt.Fprintf(os.Stderr, "error: workspace guard: %v\n", err) + os.Exit(1) + } + // Create tool registry - reg := buildToolRegistry() + reg := buildToolRegistry(fsGuard) if cfg.Tools.MaxFileSize > 0 { - reg.Register(fs.NewWriteTool(fs.WithMaxFileSize(cfg.Tools.MaxFileSize))) + w := fs.NewWriteTool(fs.WithMaxFileSize(cfg.Tools.MaxFileSize)) + w.SetGuard(fsGuard) + reg.Register(w) } // Harvest aliases, inventory, CLI agents, and local models in parallel. @@ -991,15 +1004,34 @@ func createProvider(name, apiKey, model, baseURL string) (provider.Provider, err } } -func buildToolRegistry() *tool.Registry { +func buildToolRegistry(guard *fs.Guard) *tool.Registry { reg := tool.NewRegistry() reg.Register(bash.New()) - reg.Register(fs.NewReadTool()) - reg.Register(fs.NewWriteTool()) - reg.Register(fs.NewEditTool()) - reg.Register(fs.NewGlobTool()) - reg.Register(fs.NewGrepTool()) - reg.Register(fs.NewLSTool()) + + read := fs.NewReadTool() + read.SetGuard(guard) + reg.Register(read) + + write := fs.NewWriteTool() + write.SetGuard(guard) + reg.Register(write) + + edit := fs.NewEditTool() + edit.SetGuard(guard) + reg.Register(edit) + + glob := fs.NewGlobTool() + glob.SetGuard(guard) + reg.Register(glob) + + grep := fs.NewGrepTool() + grep.SetGuard(guard) + reg.Register(grep) + + ls := fs.NewLSTool() + ls.SetGuard(guard) + reg.Register(ls) + return reg } diff --git a/internal/tool/fs/edit.go b/internal/tool/fs/edit.go index c98fc4f..fb6d49d 100644 --- a/internal/tool/fs/edit.go +++ b/internal/tool/fs/edit.go @@ -35,10 +35,14 @@ var editParams = json.RawMessage(`{ "required": ["path", "old_string", "new_string"] }`) -type EditTool struct{} +type EditTool struct { + guard *Guard +} func NewEditTool() *EditTool { return &EditTool{} } +func (t *EditTool) SetGuard(g *Guard) { t.guard = g } + func (t *EditTool) Name() string { return editToolName } func (t *EditTool) Description() string { return "Perform exact string replacement in a file" } func (t *EditTool) Parameters() json.RawMessage { return editParams } @@ -72,7 +76,16 @@ func (t *EditTool) Execute(_ context.Context, args json.RawMessage) (tool.Result return tool.Result{}, fmt.Errorf("fs.edit: old_string and new_string must differ") } - data, err := os.ReadFile(a.Path) + path := a.Path + if t.guard != nil { + resolved, err := t.guard.ResolveRead(path) + if err != nil { + return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil + } + path = resolved + } + + data, err := os.ReadFile(path) if err != nil { return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil } @@ -101,7 +114,7 @@ func (t *EditTool) Execute(_ context.Context, args json.RawMessage) (tool.Result newContent = strings.Replace(content, a.OldString, a.NewString, 1) } - if err := os.WriteFile(a.Path, []byte(newContent), 0o644); err != nil { + if err := os.WriteFile(path, []byte(newContent), 0o644); err != nil { return tool.Result{Output: fmt.Sprintf("Error writing file: %v", err)}, nil } @@ -111,11 +124,11 @@ func (t *EditTool) Execute(_ context.Context, args json.RawMessage) (tool.Result } // Generate diff-style output with context - diff := buildEditDiff(content, a.OldString, a.NewString, a.Path, replacements) + diff := buildEditDiff(content, a.OldString, a.NewString, path, replacements) return tool.Result{ Output: diff, - Metadata: map[string]any{"replacements": replacements, "path": a.Path}, + Metadata: map[string]any{"replacements": replacements, "path": path}, }, nil } diff --git a/internal/tool/fs/glob.go b/internal/tool/fs/glob.go index 5c3e7ff..baa0c5a 100644 --- a/internal/tool/fs/glob.go +++ b/internal/tool/fs/glob.go @@ -30,10 +30,14 @@ var globParams = json.RawMessage(`{ "required": ["pattern"] }`) -type GlobTool struct{} +type GlobTool struct { + guard *Guard +} func NewGlobTool() *GlobTool { return &GlobTool{} } +func (t *GlobTool) SetGuard(g *Guard) { t.guard = g } + func (t *GlobTool) Name() string { return globToolName } func (t *GlobTool) Description() string { return "Find files matching a glob pattern, sorted by modification time" } func (t *GlobTool) Parameters() json.RawMessage { return globParams } @@ -64,12 +68,23 @@ func (t *GlobTool) Execute(_ context.Context, args json.RawMessage) (tool.Result root := a.Path if root == "" { - var err error - root, err = os.Getwd() - if err != nil { - return tool.Result{}, fmt.Errorf("fs.glob: %w", err) + if t.guard != nil { + root = t.guard.Roots()[0] + } else { + var err error + root, err = os.Getwd() + if err != nil { + return tool.Result{}, fmt.Errorf("fs.glob: %w", err) + } } } + if t.guard != nil { + resolved, err := t.guard.ResolveRead(root) + if err != nil { + return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil + } + root = resolved + } var matches []string err := filepath.WalkDir(root, func(path string, d os.DirEntry, err error) error { diff --git a/internal/tool/fs/grep.go b/internal/tool/fs/grep.go index f07a171..ed993fa 100644 --- a/internal/tool/fs/grep.go +++ b/internal/tool/fs/grep.go @@ -41,10 +41,14 @@ var grepParams = json.RawMessage(`{ "required": ["pattern"] }`) -type GrepTool struct{} +type GrepTool struct { + guard *Guard +} func NewGrepTool() *GrepTool { return &GrepTool{} } +func (t *GrepTool) SetGuard(g *Guard) { t.guard = g } + func (t *GrepTool) Name() string { return grepToolName } func (t *GrepTool) Description() string { return "Search file contents using a regular expression" } func (t *GrepTool) Parameters() json.RawMessage { return grepParams } @@ -93,11 +97,22 @@ func (t *GrepTool) Execute(_ context.Context, args json.RawMessage) (tool.Result root := a.Path if root == "" { - root, err = os.Getwd() - if err != nil { - return tool.Result{}, fmt.Errorf("fs.grep: %w", err) + if t.guard != nil { + root = t.guard.Roots()[0] + } else { + root, err = os.Getwd() + if err != nil { + return tool.Result{}, fmt.Errorf("fs.grep: %w", err) + } } } + if t.guard != nil { + resolved, err := t.guard.ResolveRead(root) + if err != nil { + return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil + } + root = resolved + } info, err := os.Stat(root) if err != nil { diff --git a/internal/tool/fs/guard.go b/internal/tool/fs/guard.go new file mode 100644 index 0000000..6afdbc7 --- /dev/null +++ b/internal/tool/fs/guard.go @@ -0,0 +1,123 @@ +package fs + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" +) + +var ErrOutsideWorkspace = errors.New("path outside workspace") + +type Guard struct { + roots []string +} + +func NewGuard(roots ...string) (*Guard, error) { + if len(roots) == 0 { + return nil, errors.New("guard requires at least one root") + } + resolved := make([]string, 0, len(roots)) + for _, r := range roots { + if !filepath.IsAbs(r) { + return nil, fmt.Errorf("guard root %q must be absolute", r) + } + canonical, err := filepath.EvalSymlinks(r) + if err != nil { + return nil, fmt.Errorf("guard root %q: %w", r, err) + } + info, err := os.Stat(canonical) + if err != nil { + return nil, fmt.Errorf("guard root %q: %w", r, err) + } + if !info.IsDir() { + return nil, fmt.Errorf("guard root %q is not a directory", r) + } + resolved = append(resolved, filepath.Clean(canonical)) + } + return &Guard{roots: resolved}, nil +} + +func (g *Guard) Roots() []string { + out := make([]string, len(g.roots)) + copy(out, g.roots) + return out +} + +func (g *Guard) ResolveRead(path string) (string, error) { + abs, err := g.absolutise(path) + if err != nil { + return "", err + } + canonical, err := filepath.EvalSymlinks(abs) + if err != nil { + return "", fmt.Errorf("resolve %q: %w", path, err) + } + if !g.contains(canonical) { + return "", fmt.Errorf("%w: %s", ErrOutsideWorkspace, path) + } + return canonical, nil +} + +// ResolveWrite canonicalises the deepest existing ancestor so a symlinked +// parent escaping the workspace is rejected even when the leaf doesn't exist. +func (g *Guard) ResolveWrite(path string) (string, error) { + abs, err := g.absolutise(path) + if err != nil { + return "", err + } + + ancestor := abs + tail := "" + for { + if _, err := os.Lstat(ancestor); err == nil { + break + } + parent := filepath.Dir(ancestor) + if parent == ancestor { + return "", fmt.Errorf("resolve %q: no existing ancestor", path) + } + tail = filepath.Join(filepath.Base(ancestor), tail) + ancestor = parent + } + + canonicalAncestor, err := filepath.EvalSymlinks(ancestor) + if err != nil { + return "", fmt.Errorf("resolve ancestor of %q: %w", path, err) + } + resolved := canonicalAncestor + if tail != "" { + resolved = filepath.Join(canonicalAncestor, tail) + } + if !g.contains(resolved) { + return "", fmt.Errorf("%w: %s", ErrOutsideWorkspace, path) + } + return resolved, nil +} + +// absolutise anchors relative paths against the first root rather than process +// cwd, which may drift over the lifetime of the agent. +func (g *Guard) absolutise(path string) (string, error) { + if path == "" { + return "", errors.New("empty path") + } + if filepath.IsAbs(path) { + return filepath.Clean(path), nil + } + return filepath.Clean(filepath.Join(g.roots[0], path)), nil +} + +// contains uses a separator boundary so "/ws-evil" is not considered inside "/ws". +func (g *Guard) contains(canonical string) bool { + for _, root := range g.roots { + if canonical == root { + return true + } + prefix := root + string(filepath.Separator) + if strings.HasPrefix(canonical, prefix) { + return true + } + } + return false +} diff --git a/internal/tool/fs/guard_test.go b/internal/tool/fs/guard_test.go new file mode 100644 index 0000000..ac90b7b --- /dev/null +++ b/internal/tool/fs/guard_test.go @@ -0,0 +1,232 @@ +package fs + +import ( + "errors" + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +func TestNewGuard_RejectsEmptyRoots(t *testing.T) { + if _, err := NewGuard(); err == nil { + t.Fatal("NewGuard() with no roots should error") + } +} + +func TestNewGuard_RejectsRelativeRoot(t *testing.T) { + if _, err := NewGuard("relative/path"); err == nil { + t.Fatal("NewGuard with relative root should error") + } +} + +func TestNewGuard_RejectsNonexistentRoot(t *testing.T) { + if _, err := NewGuard("/definitely/does/not/exist/anywhere"); err == nil { + t.Fatal("NewGuard with nonexistent root should error") + } +} + +func TestGuard_ResolveInsideRoot(t *testing.T) { + root := t.TempDir() + g := mustGuard(t, root) + + path := filepath.Join(root, "file.txt") + if err := os.WriteFile(path, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + + got, err := g.ResolveRead(path) + if err != nil { + t.Fatalf("ResolveRead inside root: %v", err) + } + if got != path { + t.Errorf("got %q, want %q", got, path) + } +} + +func TestGuard_ResolveReadOutsideRootDenied(t *testing.T) { + root := t.TempDir() + outside := t.TempDir() + g := mustGuard(t, root) + + outsidePath := filepath.Join(outside, "secret") + if err := os.WriteFile(outsidePath, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + + _, err := g.ResolveRead(outsidePath) + if !errors.Is(err, ErrOutsideWorkspace) { + t.Fatalf("want ErrOutsideWorkspace, got %v", err) + } +} + +func TestGuard_ResolveReadRelativeEscape(t *testing.T) { + root := t.TempDir() + g := mustGuard(t, root) + + // Relative path with ../../../ should resolve relative to first root and + // escape it; guard must deny. + _, err := g.ResolveRead("../../../etc/passwd") + if !errors.Is(err, ErrOutsideWorkspace) { + t.Fatalf("want ErrOutsideWorkspace, got %v", err) + } +} + +func TestGuard_ResolveReadSymlinkEscapeDenied(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("symlink semantics differ on Windows") + } + root := t.TempDir() + outside := t.TempDir() + g := mustGuard(t, root) + + target := filepath.Join(outside, "target") + if err := os.WriteFile(target, []byte("secret"), 0o644); err != nil { + t.Fatal(err) + } + link := filepath.Join(root, "link") + if err := os.Symlink(target, link); err != nil { + t.Fatal(err) + } + + _, err := g.ResolveRead(link) + if !errors.Is(err, ErrOutsideWorkspace) { + t.Fatalf("symlink escaping root should be denied; got %v", err) + } +} + +func TestGuard_ResolveReadSymlinkWithinRootAllowed(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("symlink semantics differ on Windows") + } + root := t.TempDir() + g := mustGuard(t, root) + + target := filepath.Join(root, "target") + if err := os.WriteFile(target, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + link := filepath.Join(root, "link") + if err := os.Symlink(target, link); err != nil { + t.Fatal(err) + } + + got, err := g.ResolveRead(link) + if err != nil { + t.Fatalf("symlink inside root: %v", err) + } + // Canonical form should be the target, not the link. + if !strings.HasPrefix(got, root) { + t.Errorf("canonical %q should be inside root %q", got, root) + } +} + +func TestGuard_ResolveWriteNewFileAllowed(t *testing.T) { + root := t.TempDir() + g := mustGuard(t, root) + + newFile := filepath.Join(root, "newdir", "newfile.txt") + got, err := g.ResolveWrite(newFile) + if err != nil { + t.Fatalf("ResolveWrite to new path inside root: %v", err) + } + if !strings.HasPrefix(got, root) { + t.Errorf("canonical %q should be inside root %q", got, root) + } +} + +func TestGuard_ResolveWriteOutsideRootDenied(t *testing.T) { + root := t.TempDir() + outside := t.TempDir() + g := mustGuard(t, root) + + _, err := g.ResolveWrite(filepath.Join(outside, "evil.txt")) + if !errors.Is(err, ErrOutsideWorkspace) { + t.Fatalf("want ErrOutsideWorkspace, got %v", err) + } +} + +func TestGuard_ResolveWriteViaSymlinkedParentDenied(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("symlink semantics differ on Windows") + } + root := t.TempDir() + outside := t.TempDir() + g := mustGuard(t, root) + + // Create a symlink inside root whose target is outside. + linkedDir := filepath.Join(root, "escape") + if err := os.Symlink(outside, linkedDir); err != nil { + t.Fatal(err) + } + + // Writing under the symlinked dir lands outside root. + _, err := g.ResolveWrite(filepath.Join(linkedDir, "evil.txt")) + if !errors.Is(err, ErrOutsideWorkspace) { + t.Fatalf("write via symlinked parent should be denied; got %v", err) + } +} + +func TestGuard_MultipleRoots(t *testing.T) { + rootA := t.TempDir() + rootB := t.TempDir() + g := mustGuard(t, rootA, rootB) + + a := filepath.Join(rootA, "file") + if err := os.WriteFile(a, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + b := filepath.Join(rootB, "file") + if err := os.WriteFile(b, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + + if _, err := g.ResolveRead(a); err != nil { + t.Errorf("rootA: %v", err) + } + if _, err := g.ResolveRead(b); err != nil { + t.Errorf("rootB: %v", err) + } +} + +func TestGuard_RootBoundaryNotPrefixMatch(t *testing.T) { + // Catch the classic bug: /foo/bar must NOT be considered inside /foo/ba. + parent := t.TempDir() + rootShort := filepath.Join(parent, "ws") + rootLongName := filepath.Join(parent, "ws-evil") + for _, d := range []string{rootShort, rootLongName} { + if err := os.MkdirAll(d, 0o755); err != nil { + t.Fatal(err) + } + } + g := mustGuard(t, rootShort) + + evil := filepath.Join(rootLongName, "secret") + if err := os.WriteFile(evil, []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + + _, err := g.ResolveRead(evil) + if !errors.Is(err, ErrOutsideWorkspace) { + t.Fatalf("ws-evil/secret should not be considered inside ws; got %v", err) + } +} + +func TestGuard_RootItselfAllowed(t *testing.T) { + root := t.TempDir() + g := mustGuard(t, root) + + if _, err := g.ResolveRead(root); err != nil { + t.Errorf("root path itself should be allowed: %v", err) + } +} + +func mustGuard(t *testing.T, roots ...string) *Guard { + t.Helper() + g, err := NewGuard(roots...) + if err != nil { + t.Fatalf("NewGuard: %v", err) + } + return g +} diff --git a/internal/tool/fs/guard_tools_test.go b/internal/tool/fs/guard_tools_test.go new file mode 100644 index 0000000..3139d7a --- /dev/null +++ b/internal/tool/fs/guard_tools_test.go @@ -0,0 +1,184 @@ +package fs + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +// These tests exercise each fs tool with a Guard installed, verifying that +// paths outside the workspace are rejected at the tool boundary. + +func TestReadTool_GuardDeniesOutsideRoot(t *testing.T) { + root := t.TempDir() + outside := t.TempDir() + outsideFile := filepath.Join(outside, "secret") + if err := os.WriteFile(outsideFile, []byte("secret"), 0o644); err != nil { + t.Fatal(err) + } + + r := NewReadTool() + r.SetGuard(mustGuard(t, root)) + + res, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: outsideFile})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(res.Output, "path outside workspace") { + t.Errorf("expected workspace error, got %q", res.Output) + } +} + +func TestReadTool_GuardAllowsInsideRoot(t *testing.T) { + root := t.TempDir() + inside := filepath.Join(root, "ok.txt") + if err := os.WriteFile(inside, []byte("hi"), 0o644); err != nil { + t.Fatal(err) + } + + r := NewReadTool() + r.SetGuard(mustGuard(t, root)) + + res, err := r.Execute(context.Background(), mustJSON(t, readArgs{Path: inside})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(res.Output, "hi") { + t.Errorf("expected file content, got %q", res.Output) + } +} + +func TestWriteTool_GuardDeniesOutsideRoot(t *testing.T) { + root := t.TempDir() + outside := t.TempDir() + target := filepath.Join(outside, "evil.txt") + + w := NewWriteTool() + w.SetGuard(mustGuard(t, root)) + + res, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: target, Content: "x"})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(res.Output, "path outside workspace") { + t.Errorf("expected workspace error, got %q", res.Output) + } + if _, err := os.Stat(target); err == nil { + t.Errorf("file was written despite guard: %s", target) + } +} + +func TestWriteTool_GuardAllowsInsideRoot(t *testing.T) { + root := t.TempDir() + target := filepath.Join(root, "sub", "ok.txt") + + w := NewWriteTool() + w.SetGuard(mustGuard(t, root)) + + res, err := w.Execute(context.Background(), mustJSON(t, writeArgs{Path: target, Content: "hi"})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(res.Output, "Wrote") { + t.Errorf("expected write confirmation, got %q", res.Output) + } + if _, err := os.Stat(target); err != nil { + t.Errorf("file missing after write: %v", err) + } +} + +func TestEditTool_GuardDeniesOutsideRoot(t *testing.T) { + root := t.TempDir() + outside := t.TempDir() + target := filepath.Join(outside, "f.txt") + if err := os.WriteFile(target, []byte("hello"), 0o644); err != nil { + t.Fatal(err) + } + + e := NewEditTool() + e.SetGuard(mustGuard(t, root)) + + res, err := e.Execute(context.Background(), mustJSON(t, editArgs{Path: target, OldString: "hello", NewString: "hi"})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(res.Output, "path outside workspace") { + t.Errorf("expected workspace error, got %q", res.Output) + } + // File must remain unchanged. + data, _ := os.ReadFile(target) + if string(data) != "hello" { + t.Errorf("file mutated despite guard: %q", string(data)) + } +} + +func TestLSTool_GuardDeniesOutsideRoot(t *testing.T) { + root := t.TempDir() + outside := t.TempDir() + + l := NewLSTool() + l.SetGuard(mustGuard(t, root)) + + res, err := l.Execute(context.Background(), mustJSON(t, lsArgs{Path: outside})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(res.Output, "path outside workspace") { + t.Errorf("expected workspace error, got %q", res.Output) + } +} + +func TestLSTool_GuardEmptyPathDefaultsToRoot(t *testing.T) { + root := t.TempDir() + if err := os.WriteFile(filepath.Join(root, "marker.txt"), []byte("x"), 0o644); err != nil { + t.Fatal(err) + } + + l := NewLSTool() + l.SetGuard(mustGuard(t, root)) + + res, err := l.Execute(context.Background(), mustJSON(t, lsArgs{})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(res.Output, "marker.txt") { + t.Errorf("expected to list root contents, got %q", res.Output) + } +} + +func TestGlobTool_GuardDeniesOutsideRoot(t *testing.T) { + root := t.TempDir() + outside := t.TempDir() + + g := NewGlobTool() + g.SetGuard(mustGuard(t, root)) + + res, err := g.Execute(context.Background(), mustJSON(t, globArgs{Pattern: "*", Path: outside})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(res.Output, "path outside workspace") { + t.Errorf("expected workspace error, got %q", res.Output) + } +} + +func TestGrepTool_GuardDeniesOutsideRoot(t *testing.T) { + root := t.TempDir() + outside := t.TempDir() + if err := os.WriteFile(filepath.Join(outside, "f.txt"), []byte("needle"), 0o644); err != nil { + t.Fatal(err) + } + + g := NewGrepTool() + g.SetGuard(mustGuard(t, root)) + + res, err := g.Execute(context.Background(), mustJSON(t, grepArgs{Pattern: "needle", Path: outside})) + if err != nil { + t.Fatalf("Execute: %v", err) + } + if !strings.Contains(res.Output, "path outside workspace") { + t.Errorf("expected workspace error, got %q", res.Output) + } +} diff --git a/internal/tool/fs/ls.go b/internal/tool/fs/ls.go index 8ef9a4d..dc71954 100644 --- a/internal/tool/fs/ls.go +++ b/internal/tool/fs/ls.go @@ -24,10 +24,14 @@ var lsParams = json.RawMessage(`{ } }`) -type LSTool struct{} +type LSTool struct { + guard *Guard +} func NewLSTool() *LSTool { return &LSTool{} } +func (t *LSTool) SetGuard(g *Guard) { t.guard = g } + func (t *LSTool) Name() string { return lsToolName } func (t *LSTool) Description() string { return "List directory contents with file types and sizes" } func (t *LSTool) Parameters() json.RawMessage { return lsParams } @@ -54,12 +58,23 @@ func (t *LSTool) Execute(_ context.Context, args json.RawMessage) (tool.Result, dir := a.Path if dir == "" { - var err error - dir, err = os.Getwd() - if err != nil { - return tool.Result{}, fmt.Errorf("fs.ls: %w", err) + if t.guard != nil { + dir = t.guard.Roots()[0] + } else { + var err error + dir, err = os.Getwd() + if err != nil { + return tool.Result{}, fmt.Errorf("fs.ls: %w", err) + } } } + if t.guard != nil { + resolved, err := t.guard.ResolveRead(dir) + if err != nil { + return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil + } + dir = resolved + } entries, err := os.ReadDir(dir) if err != nil { diff --git a/internal/tool/fs/read.go b/internal/tool/fs/read.go index e9b757b..2e64bb8 100644 --- a/internal/tool/fs/read.go +++ b/internal/tool/fs/read.go @@ -36,8 +36,11 @@ var readParams = json.RawMessage(`{ type ReadTool struct { maxLines int + guard *Guard } +func (t *ReadTool) SetGuard(g *Guard) { t.guard = g } + type ReadOption func(*ReadTool) func WithMaxLines(n int) ReadOption { @@ -81,7 +84,16 @@ func (t *ReadTool) Execute(_ context.Context, args json.RawMessage) (tool.Result return tool.Result{}, fmt.Errorf("fs.read: path required") } - data, err := os.ReadFile(a.Path) + path := a.Path + if t.guard != nil { + resolved, err := t.guard.ResolveRead(path) + if err != nil { + return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil + } + path = resolved + } + + data, err := os.ReadFile(path) if err != nil { return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil } diff --git a/internal/tool/fs/write.go b/internal/tool/fs/write.go index 85e8d4a..255920b 100644 --- a/internal/tool/fs/write.go +++ b/internal/tool/fs/write.go @@ -36,8 +36,11 @@ func WithMaxFileSize(n int64) WriteOption { type WriteTool struct { maxFileSize int64 + guard *Guard } +func (t *WriteTool) SetGuard(g *Guard) { t.guard = g } + func NewWriteTool(opts ...WriteOption) *WriteTool { t := &WriteTool{} for _, opt := range opts { @@ -78,18 +81,27 @@ func (t *WriteTool) Execute(_ context.Context, args json.RawMessage) (tool.Resul return tool.Result{Output: fmt.Sprintf("Error: content too large (%d bytes, limit %d bytes)", len(a.Content), t.maxFileSize)}, nil } + path := a.Path + if t.guard != nil { + resolved, err := t.guard.ResolveWrite(path) + if err != nil { + return tool.Result{Output: fmt.Sprintf("Error: %v", err)}, nil + } + path = resolved + } + // Create parent directories - dir := filepath.Dir(a.Path) + dir := filepath.Dir(path) if err := os.MkdirAll(dir, 0o755); err != nil { return tool.Result{Output: fmt.Sprintf("Error creating directory: %v", err)}, nil } - if err := os.WriteFile(a.Path, []byte(a.Content), 0o644); err != nil { + if err := os.WriteFile(path, []byte(a.Content), 0o644); err != nil { return tool.Result{Output: fmt.Sprintf("Error writing file: %v", err)}, nil } return tool.Result{ - Output: fmt.Sprintf("Wrote %d bytes to %s", len(a.Content), a.Path), - Metadata: map[string]any{"bytes_written": len(a.Content), "path": a.Path}, + Output: fmt.Sprintf("Wrote %d bytes to %s", len(a.Content), path), + Metadata: map[string]any{"bytes_written": len(a.Content), "path": path}, }, nil }