diff --git a/go.mod b/go.mod index e0181b3..df75dce 100644 --- a/go.mod +++ b/go.mod @@ -16,7 +16,7 @@ require ( cloud.google.com/go/auth v0.9.3 // indirect cloud.google.com/go/compute/metadata v0.5.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect - github.com/google/go-cmp v0.6.0 // indirect + github.com/google/go-cmp v0.7.0 // indirect github.com/google/s2a-go v0.1.8 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect github.com/gorilla/websocket v1.5.3 // indirect @@ -27,9 +27,10 @@ require ( go.opencensus.io v0.24.0 // indirect golang.org/x/crypto v0.40.0 // indirect golang.org/x/net v0.41.0 // indirect - golang.org/x/sync v0.16.0 // indirect - golang.org/x/sys v0.34.0 // indirect + golang.org/x/sync v0.17.0 // indirect + golang.org/x/sys v0.42.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect google.golang.org/grpc v1.66.2 // indirect google.golang.org/protobuf v1.34.2 // indirect + mvdan.cc/sh/v3 v3.13.0 // indirect ) diff --git a/go.sum b/go.sum index 6d5877f..6a712e5 100644 --- a/go.sum +++ b/go.sum @@ -45,6 +45,8 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM= github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -99,12 +101,16 @@ golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= @@ -150,3 +156,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +mvdan.cc/sh/v3 v3.13.0 h1:dSfq/MVsY4w0Vsi6Lbs0IcQquMVqLdKLESAOZjuHdLg= +mvdan.cc/sh/v3 v3.13.0/go.mod h1:KV1GByGPc/Ho0X1E6Uz9euhsIQEj4hwyKnodLlFLoDM= diff --git a/internal/permission/checker.go b/internal/permission/checker.go new file mode 100644 index 0000000..f69273f --- /dev/null +++ b/internal/permission/checker.go @@ -0,0 +1,196 @@ +package permission + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" +) + +var ErrDenied = errors.New("permission denied") + +// PromptFunc asks the user to approve/deny a tool call. +// Returns true if approved. +type PromptFunc func(ctx context.Context, toolName string, args json.RawMessage) (bool, error) + +// ToolInfo provides tool metadata for permission decisions. +type ToolInfo struct { + Name string + IsReadOnly bool + IsDestructive bool +} + +// Checker evaluates tool permissions using the 7-step decision flow. +// +// Decision flow (from CC, adapted): +// 1. Rule-based deny gates (BEFORE mode — even bypass can't override) +// 2. Tool-specific safety checks (.env, .git, credentials) +// 3. Mode-based bypass +// 4. Rule-based allow +// 5. Mode-specific behavior +// 6. Prompt user if needed +type Checker struct { + mode Mode + rules []Rule + promptFn PromptFunc + + // Safety patterns — always checked, even in bypass mode + safetyDenyPatterns []string +} + +func NewChecker(mode Mode, rules []Rule, promptFn PromptFunc) *Checker { + return &Checker{ + mode: mode, + rules: rules, + promptFn: promptFn, + safetyDenyPatterns: []string{ + ".env", ".git/", "credentials", "id_rsa", "id_ed25519", + ".ssh/", ".gnupg/", ".aws/credentials", + }, + } +} + +// SetMode changes the active permission mode. +func (c *Checker) SetMode(mode Mode) { + c.mode = mode +} + +// Mode returns the current permission mode. +func (c *Checker) Mode() Mode { + return c.mode +} + +// Check evaluates whether a tool call is permitted. +// Returns nil if allowed, ErrDenied if denied. +func (c *Checker) Check(ctx context.Context, info ToolInfo, args json.RawMessage) error { + // Step 1: Rule-based deny gates (bypass-immune) + if c.matchesRule(info.Name, args, ActionDeny) { + return fmt.Errorf("%w: deny rule matched for %s", ErrDenied, info.Name) + } + + // Step 2: Safety checks (bypass-immune) + if err := c.safetyCheck(info.Name, args); err != nil { + return err + } + + // For compound bash commands, check each subcommand + if info.Name == "bash" { + if err := c.checkCompoundCommand(ctx, info, args); err != nil { + return err + } + } + + // Step 3: Mode-based bypass + if c.mode == ModeBypass { + return nil + } + + // Step 4: Rule-based allow + if c.matchesRule(info.Name, args, ActionAllow) { + return nil + } + + // Step 5: Mode-specific behavior + switch c.mode { + case ModeDeny: + return fmt.Errorf("%w: deny mode, no allow rule for %s", ErrDenied, info.Name) + + case ModePlan: + if !info.IsReadOnly { + return fmt.Errorf("%w: plan mode, %s is not read-only", ErrDenied, info.Name) + } + return nil + + case ModeAcceptEdits: + // Auto-allow file reads and edits, prompt for bash/destructive + if info.IsReadOnly { + return nil + } + if strings.HasPrefix(info.Name, "fs.") && !info.IsDestructive { + return nil + } + // Fall through to prompt + + case ModeAuto: + // Auto-allow read-only tools + if info.IsReadOnly { + return nil + } + // Fall through to prompt for write tools + + case ModeDefault: + // Always prompt + } + + // Step 6: Prompt user + return c.prompt(ctx, info.Name, args) +} + +func (c *Checker) matchesRule(toolName string, args json.RawMessage, action Action) bool { + for _, rule := range c.rules { + if rule.Action != action { + continue + } + if !rule.Matches(toolName) { + continue + } + // If rule has a pattern, check it against serialized args + if rule.Pattern != "" { + if !strings.Contains(string(args), rule.Pattern) { + continue + } + } + return true + } + return false +} + +func (c *Checker) safetyCheck(toolName string, args json.RawMessage) error { + argsStr := string(args) + for _, pattern := range c.safetyDenyPatterns { + if strings.Contains(argsStr, pattern) { + return fmt.Errorf("%w: safety check blocked access to %q via %s", ErrDenied, pattern, toolName) + } + } + return nil +} + +func (c *Checker) checkCompoundCommand(ctx context.Context, info ToolInfo, args json.RawMessage) error { + var bashArgs struct { + Command string `json:"command"` + } + if err := json.Unmarshal(args, &bashArgs); err != nil || bashArgs.Command == "" { + return nil + } + + subcommands := SplitCompoundCommand(bashArgs.Command) + if len(subcommands) <= 1 { + return nil // single command, handled by main flow + } + + // Check each subcommand — deny from any subcommand denies the whole compound + for _, sub := range subcommands { + subArgs, _ := json.Marshal(map[string]string{"command": sub}) + if c.matchesRule("bash", subArgs, ActionDeny) { + return fmt.Errorf("%w: deny rule matched subcommand %q", ErrDenied, sub) + } + } + return nil +} + +func (c *Checker) prompt(ctx context.Context, toolName string, args json.RawMessage) error { + if c.promptFn == nil { + // No prompt function — deny by default + return fmt.Errorf("%w: no prompt handler for %s", ErrDenied, toolName) + } + + approved, err := c.promptFn(ctx, toolName, args) + if err != nil { + return fmt.Errorf("permission prompt: %w", err) + } + if !approved { + return fmt.Errorf("%w: user denied %s", ErrDenied, toolName) + } + return nil +} diff --git a/internal/permission/mode.go b/internal/permission/mode.go new file mode 100644 index 0000000..318567f --- /dev/null +++ b/internal/permission/mode.go @@ -0,0 +1,29 @@ +package permission + +// Mode controls the overall permission behavior. +type Mode string + +const ( + // ModeDefault prompts the user for each tool invocation. + ModeDefault Mode = "default" + // ModeAcceptEdits auto-allows file edits + reads, prompts for bash/destructive. + ModeAcceptEdits Mode = "accept_edits" + // ModeBypass allows everything without prompting. + ModeBypass Mode = "bypass" + // ModeDeny denies everything unless an explicit allow rule matches. + ModeDeny Mode = "deny" + // ModePlan allows only read-only tools, blocks all writes. + ModePlan Mode = "plan" + // ModeAuto uses task type + tool risk scoring to decide. + // Low-risk read-only tools auto-allow, everything else prompts. + ModeAuto Mode = "auto" +) + +// Valid returns true if the mode is recognized. +func (m Mode) Valid() bool { + switch m { + case ModeDefault, ModeAcceptEdits, ModeBypass, ModeDeny, ModePlan, ModeAuto: + return true + } + return false +} diff --git a/internal/permission/permission_test.go b/internal/permission/permission_test.go new file mode 100644 index 0000000..6dc9591 --- /dev/null +++ b/internal/permission/permission_test.go @@ -0,0 +1,235 @@ +package permission + +import ( + "context" + "encoding/json" + "errors" + "testing" +) + +func TestMode_Valid(t *testing.T) { + valid := []Mode{ModeDefault, ModeAcceptEdits, ModeBypass, ModeDeny, ModePlan, ModeAuto} + for _, m := range valid { + if !m.Valid() { + t.Errorf("mode %q should be valid", m) + } + } + if Mode("bogus").Valid() { + t.Error("bogus mode should be invalid") + } +} + +func TestChecker_BypassMode(t *testing.T) { + c := NewChecker(ModeBypass, nil, nil) + + err := c.Check(context.Background(), ToolInfo{Name: "bash", IsDestructive: true}, json.RawMessage(`{"command":"rm -rf /"}`)) + if err != nil { + t.Errorf("bypass mode should allow everything, got: %v", err) + } +} + +func TestChecker_BypassDenyRuleImmune(t *testing.T) { + rules := []Rule{{Tool: "bash", Pattern: "rm -rf", Action: ActionDeny}} + c := NewChecker(ModeBypass, rules, nil) + + err := c.Check(context.Background(), ToolInfo{Name: "bash"}, json.RawMessage(`{"command":"rm -rf /"}`)) + if err == nil { + t.Error("deny rules should override bypass mode") + } +} + +func TestChecker_DenyMode(t *testing.T) { + c := NewChecker(ModeDeny, nil, nil) + + err := c.Check(context.Background(), ToolInfo{Name: "bash"}, json.RawMessage(`{}`)) + if !errors.Is(err, ErrDenied) { + t.Error("deny mode should deny without allow rules") + } +} + +func TestChecker_DenyModeWithAllowRule(t *testing.T) { + rules := []Rule{{Tool: "fs.*", Action: ActionAllow}} + c := NewChecker(ModeDeny, rules, nil) + + // Allowed by rule + err := c.Check(context.Background(), ToolInfo{Name: "fs.read", IsReadOnly: true}, json.RawMessage(`{}`)) + if err != nil { + t.Errorf("should allow fs.read via rule: %v", err) + } + + // Not allowed — no matching rule + err = c.Check(context.Background(), ToolInfo{Name: "bash"}, json.RawMessage(`{}`)) + if !errors.Is(err, ErrDenied) { + t.Error("bash should be denied without allow rule") + } +} + +func TestChecker_PlanMode(t *testing.T) { + c := NewChecker(ModePlan, nil, nil) + + // Read-only allowed + err := c.Check(context.Background(), ToolInfo{Name: "fs.read", IsReadOnly: true}, json.RawMessage(`{}`)) + if err != nil { + t.Errorf("plan mode should allow read-only: %v", err) + } + + // Write denied + err = c.Check(context.Background(), ToolInfo{Name: "fs.write"}, json.RawMessage(`{}`)) + if !errors.Is(err, ErrDenied) { + t.Error("plan mode should deny writes") + } + + // Bash denied + err = c.Check(context.Background(), ToolInfo{Name: "bash"}, json.RawMessage(`{}`)) + if !errors.Is(err, ErrDenied) { + t.Error("plan mode should deny bash") + } +} + +func TestChecker_AcceptEditsMode(t *testing.T) { + c := NewChecker(ModeAcceptEdits, nil, func(_ context.Context, _ string, _ json.RawMessage) (bool, error) { + return false, nil // deny prompt + }) + + // Read-only allowed + err := c.Check(context.Background(), ToolInfo{Name: "fs.read", IsReadOnly: true}, json.RawMessage(`{}`)) + if err != nil { + t.Errorf("should allow read-only: %v", err) + } + + // File edits allowed + err = c.Check(context.Background(), ToolInfo{Name: "fs.write"}, json.RawMessage(`{}`)) + if err != nil { + t.Errorf("should allow fs.write in acceptEdits: %v", err) + } + + // Bash requires prompt — denied since our prompt returns false + err = c.Check(context.Background(), ToolInfo{Name: "bash"}, json.RawMessage(`{}`)) + if !errors.Is(err, ErrDenied) { + t.Error("bash should go through prompt in acceptEdits mode") + } +} + +func TestChecker_AutoMode(t *testing.T) { + c := NewChecker(ModeAuto, nil, func(_ context.Context, _ string, _ json.RawMessage) (bool, error) { + return true, nil // approve prompt + }) + + // Read-only auto-allowed + err := c.Check(context.Background(), ToolInfo{Name: "fs.grep", IsReadOnly: true}, json.RawMessage(`{}`)) + if err != nil { + t.Errorf("auto mode should auto-allow read-only: %v", err) + } + + // Write goes to prompt — approved + err = c.Check(context.Background(), ToolInfo{Name: "bash"}, json.RawMessage(`{}`)) + if err != nil { + t.Errorf("auto mode should prompt for write, prompt approved: %v", err) + } +} + +func TestChecker_DefaultMode_Prompts(t *testing.T) { + prompted := false + c := NewChecker(ModeDefault, nil, func(_ context.Context, name string, _ json.RawMessage) (bool, error) { + prompted = true + return true, nil + }) + + err := c.Check(context.Background(), ToolInfo{Name: "fs.read", IsReadOnly: true}, json.RawMessage(`{}`)) + if err != nil { + t.Errorf("should allow after prompt: %v", err) + } + if !prompted { + t.Error("default mode should always prompt") + } +} + +func TestChecker_SafetyCheck(t *testing.T) { + // Safety checks are bypass-immune + c := NewChecker(ModeBypass, nil, nil) + + tests := []struct { + name string + args string + }{ + {"env file", `{"path":".env"}`}, + {"git dir", `{"path":".git/config"}`}, + {"ssh key", `{"path":"id_rsa"}`}, + {"aws creds", `{"path":".aws/credentials"}`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := c.Check(context.Background(), ToolInfo{Name: "fs.read"}, json.RawMessage(tt.args)) + if !errors.Is(err, ErrDenied) { + t.Errorf("safety check should block: %v", err) + } + }) + } +} + +func TestChecker_CompoundCommand(t *testing.T) { + rules := []Rule{{Tool: "bash", Pattern: "rm", Action: ActionDeny}} + c := NewChecker(ModeBypass, rules, nil) + + // Single safe command — allowed + err := c.Check(context.Background(), ToolInfo{Name: "bash"}, json.RawMessage(`{"command":"echo hello"}`)) + if err != nil { + t.Errorf("single safe command should be allowed: %v", err) + } + + // Compound with denied subcommand + err = c.Check(context.Background(), ToolInfo{Name: "bash"}, json.RawMessage(`{"command":"echo hello && rm -rf /"}`)) + if !errors.Is(err, ErrDenied) { + t.Error("compound with denied subcommand should be denied") + } +} + +func TestSplitCompoundCommand(t *testing.T) { + tests := []struct { + cmd string + want int + }{ + {"echo hello", 1}, + {"echo hello && echo world", 2}, + {"echo a; echo b; echo c", 3}, + {"echo hello | grep h", 1}, // pipe is one statement + {"cd src && make && make test", 3}, + } + for _, tt := range tests { + parts := SplitCompoundCommand(tt.cmd) + if len(parts) != tt.want { + t.Errorf("SplitCompoundCommand(%q) = %d parts %v, want %d", tt.cmd, len(parts), parts, tt.want) + } + } +} + +func TestRule_Matches(t *testing.T) { + tests := []struct { + rule Rule + tool string + want bool + }{ + {Rule{Tool: "bash"}, "bash", true}, + {Rule{Tool: "bash"}, "fs.read", false}, + {Rule{Tool: "fs.*"}, "fs.read", true}, + {Rule{Tool: "fs.*"}, "fs.write", true}, + {Rule{Tool: "fs.*"}, "bash", false}, + {Rule{Tool: "*"}, "anything", true}, + } + for _, tt := range tests { + if got := tt.rule.Matches(tt.tool); got != tt.want { + t.Errorf("Rule{%q}.Matches(%q) = %v, want %v", tt.rule.Tool, tt.tool, got, tt.want) + } + } +} + +func TestChecker_SetMode(t *testing.T) { + c := NewChecker(ModeDefault, nil, nil) + if c.Mode() != ModeDefault { + t.Errorf("initial mode should be default") + } + c.SetMode(ModePlan) + if c.Mode() != ModePlan { + t.Errorf("mode should be plan after SetMode") + } +} diff --git a/internal/permission/rule.go b/internal/permission/rule.go new file mode 100644 index 0000000..5a71854 --- /dev/null +++ b/internal/permission/rule.go @@ -0,0 +1,78 @@ +package permission + +import ( + "path/filepath" + "strings" + + "mvdan.cc/sh/v3/syntax" +) + +// Action is the decision for a permission rule. +type Action string + +const ( + ActionAllow Action = "allow" + ActionDeny Action = "deny" +) + +// Rule defines a single permission rule. +type Rule struct { + Tool string `toml:"tool"` // glob pattern: "bash", "fs.*", "*" + Pattern string `toml:"pattern"` // optional: argument pattern + Action Action `toml:"action"` +} + +// Matches returns true if the rule matches the given tool name. +func (r Rule) Matches(toolName string) bool { + matched, _ := filepath.Match(r.Tool, toolName) + return matched +} + +// SplitCompoundCommand decomposes a shell command into individual simple commands +// using a proper POSIX shell parser (mvdan.cc/sh). Recursively walks BinaryCmd +// nodes (&&, ||) and statement lists (;). +func SplitCompoundCommand(cmd string) []string { + reader := strings.NewReader(cmd) + parser := syntax.NewParser(syntax.KeepComments(false)) + + file, err := parser.Parse(reader, "") + if err != nil { + return []string{cmd} + } + + var commands []string + printer := syntax.NewPrinter() + + for _, stmt := range file.Stmts { + extractCommands(stmt.Cmd, printer, &commands) + } + + if len(commands) == 0 { + return []string{cmd} + } + return commands +} + +func extractCommands(node syntax.Command, printer *syntax.Printer, out *[]string) { + if node == nil { + return + } + // Only split on && and || (logical operators), not pipes + if bin, ok := node.(*syntax.BinaryCmd); ok { + if bin.Op == syntax.AndStmt || bin.Op == syntax.OrStmt { + if bin.X != nil { + extractCommands(bin.X.Cmd, printer, out) + } + if bin.Y != nil { + extractCommands(bin.Y.Cmd, printer, out) + } + return + } + } + // Everything else (simple command, pipe, subshell) — print as one unit + var b strings.Builder + printer.Print(&b, node) + if s := strings.TrimSpace(b.String()); s != "" { + *out = append(*out, s) + } +}