From 03d3a5d016c5d8d30e157031a9168ed47a8a50d7 Mon Sep 17 00:00:00 2001 From: vikingowl Date: Tue, 7 Apr 2026 01:02:55 +0200 Subject: [PATCH] =?UTF-8?q?feat:=20engine=20hook=20integration=20=E2=80=94?= =?UTF-8?q?=20PreToolUse,=20PostToolUse,=20Stop?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/engine/engine.go | 2 + internal/engine/hook_integration_test.go | 346 +++++++++++++++++++++++ internal/engine/loop.go | 34 ++- internal/hook/agent.go | 30 +- internal/hook/agent_test.go | 93 ++---- internal/hook/dispatcher.go | 38 ++- internal/hook/payload.go | 15 + 7 files changed, 469 insertions(+), 89 deletions(-) create mode 100644 internal/engine/hook_integration_test.go diff --git a/internal/engine/engine.go b/internal/engine/engine.go index b4560e6..f2ad98d 100644 --- a/internal/engine/engine.go +++ b/internal/engine/engine.go @@ -12,6 +12,7 @@ import ( "somegit.dev/Owlibou/gnoma/internal/router" "somegit.dev/Owlibou/gnoma/internal/security" "somegit.dev/Owlibou/gnoma/internal/tool" + "somegit.dev/Owlibou/gnoma/internal/hook" "somegit.dev/Owlibou/gnoma/internal/tool/persist" ) @@ -27,6 +28,7 @@ type Config struct { Model string // override model (empty = provider default) MaxTurns int // safety limit on tool loops (0 = unlimited) Store *persist.Store // nil = no result persistence + Hooks *hook.Dispatcher // nil = no hooks Logger *slog.Logger } diff --git a/internal/engine/hook_integration_test.go b/internal/engine/hook_integration_test.go new file mode 100644 index 0000000..bd0ecbc --- /dev/null +++ b/internal/engine/hook_integration_test.go @@ -0,0 +1,346 @@ +package engine + +import ( + "context" + "encoding/json" + "strings" + "testing" + + "somegit.dev/Owlibou/gnoma/internal/hook" + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/stream" + "somegit.dev/Owlibou/gnoma/internal/tool" +) + +// --- test executors --- + +type blockingExecutor struct{} + +func (b *blockingExecutor) Execute(_ context.Context, _ []byte) (hook.HookResult, error) { + return hook.HookResult{Action: hook.Deny}, nil +} + +type allowingExecutor struct{} + +func (a *allowingExecutor) Execute(_ context.Context, _ []byte) (hook.HookResult, error) { + return hook.HookResult{Action: hook.Allow}, nil +} + +// argTransformExecutor replaces the "args" field in the payload. +type argTransformExecutor struct{ newArgs json.RawMessage } + +func (t *argTransformExecutor) Execute(_ context.Context, payload []byte) (hook.HookResult, error) { + out, _ := json.Marshal(map[string]any{ + "tool": hook.ExtractToolName(payload), + "args": t.newArgs, + }) + return hook.HookResult{Action: hook.Allow, Output: out}, nil +} + +// resultTransformExecutor replaces the tool output. +type resultTransformExecutor struct{ newOutput string } + +func (r *resultTransformExecutor) Execute(_ context.Context, _ []byte) (hook.HookResult, error) { + out, _ := json.Marshal(map[string]any{"output": r.newOutput}) + return hook.HookResult{Action: hook.Allow, Output: out}, nil +} + +// recordingExecutor records whether it was called and the payload. +type recordingExecutor struct { + called bool + payload []byte +} + +func (r *recordingExecutor) Execute(_ context.Context, payload []byte) (hook.HookResult, error) { + r.called = true + r.payload = append([]byte(nil), payload...) + return hook.HookResult{Action: hook.Allow}, nil +} + +// --- helpers --- + +func hookDispatcher(event hook.EventType, ex hook.Executor) *hook.Dispatcher { + def := hook.HookDef{Name: "test", Event: event, Command: hook.CommandTypeShell, Exec: "x"} + d := &hook.Dispatcher{} + d.SetChain(event, []hook.Handler{hook.NewHandler(def, ex)}) + return d +} + +// toolCallStream builds a stream that emits a single tool call then stops. +func toolCallStream(callID, toolName, args string, stopReason message.StopReason, model string) stream.Stream { + events := []stream.Event{ + {Type: stream.EventToolCallDone, ToolCallID: callID, ToolCallName: toolName, Args: json.RawMessage(args)}, + {Type: stream.EventTextDelta, StopReason: stopReason, Model: model}, + } + return &eventStream{events: events} +} + +// --- tests --- + +func TestHook_NilDispatcher_NoChange(t *testing.T) { + mp := &mockProvider{ + streams: []stream.Stream{ + newEventStream(message.StopEndTurn, "m", + stream.Event{Type: stream.EventTextDelta, Text: "hello"}, + ), + }, + } + eng, err := New(Config{Provider: mp, Tools: tool.NewRegistry()}) + if err != nil { + t.Fatal(err) + } + turn, err := eng.Submit(context.Background(), "hi", nil) + if err != nil { + t.Fatal(err) + } + if turn.Rounds != 1 { + t.Errorf("rounds = %d, want 1", turn.Rounds) + } +} + +func TestHook_PreToolUse_Deny(t *testing.T) { + executed := false + reg := tool.NewRegistry() + reg.Register(&mockTool{ + name: "bash", + execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) { + executed = true + return tool.Result{Output: "should not run"}, nil + }, + }) + + mp := &mockProvider{ + streams: []stream.Stream{ + toolCallStream("c1", "bash", `{"command":"rm -rf /"}`, message.StopToolUse, "m"), + newEventStream(message.StopEndTurn, "m", + stream.Event{Type: stream.EventTextDelta, Text: "ok"}, + ), + }, + } + + eng, _ := New(Config{ + Provider: mp, + Tools: reg, + Hooks: hookDispatcher(hook.PreToolUse, &blockingExecutor{}), + }) + eng.Submit(context.Background(), "run", nil) + + if executed { + t.Error("tool was executed despite PreToolUse deny") + } +} + +func TestHook_PreToolUse_Allow(t *testing.T) { + executed := false + reg := tool.NewRegistry() + reg.Register(&mockTool{ + name: "bash", + execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) { + executed = true + return tool.Result{Output: "ran"}, nil + }, + }) + + mp := &mockProvider{ + streams: []stream.Stream{ + toolCallStream("c1", "bash", `{}`, message.StopToolUse, "m"), + newEventStream(message.StopEndTurn, "m", + stream.Event{Type: stream.EventTextDelta, Text: "ok"}, + ), + }, + } + + eng, _ := New(Config{ + Provider: mp, + Tools: reg, + Hooks: hookDispatcher(hook.PreToolUse, &allowingExecutor{}), + }) + eng.Submit(context.Background(), "run", nil) + + if !executed { + t.Error("tool was not executed despite PreToolUse allow") + } +} + +func TestHook_PreToolUse_DenyMessage(t *testing.T) { + reg := tool.NewRegistry() + reg.Register(&mockTool{ + name: "bash", + execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) { + return tool.Result{Output: "should not run"}, nil + }, + }) + + mp := &mockProvider{ + streams: []stream.Stream{ + toolCallStream("c1", "bash", `{}`, message.StopToolUse, "m"), + newEventStream(message.StopEndTurn, "m", + stream.Event{Type: stream.EventTextDelta, Text: "ok"}, + ), + }, + } + + eng, _ := New(Config{ + Provider: mp, + Tools: reg, + Hooks: hookDispatcher(hook.PreToolUse, &blockingExecutor{}), + }) + eng.Submit(context.Background(), "run", nil) + + for _, msg := range eng.History() { + for _, c := range msg.Content { + if c.Type == message.ContentToolResult && c.ToolResult != nil { + if !strings.HasPrefix(c.ToolResult.Content, "denied by hook") { + t.Errorf("denied result = %q, want prefix 'denied by hook'", c.ToolResult.Content) + } + return + } + } + } + t.Error("no tool result found in history") +} + +func TestHook_PreToolUse_Transform(t *testing.T) { + var receivedArgs json.RawMessage + reg := tool.NewRegistry() + reg.Register(&mockTool{ + name: "bash", + execFn: func(_ context.Context, args json.RawMessage) (tool.Result, error) { + receivedArgs = args + return tool.Result{Output: "ok"}, nil + }, + }) + + mp := &mockProvider{ + streams: []stream.Stream{ + toolCallStream("c1", "bash", `{"command":"original"}`, message.StopToolUse, "m"), + newEventStream(message.StopEndTurn, "m", + stream.Event{Type: stream.EventTextDelta, Text: "done"}, + ), + }, + } + + eng, _ := New(Config{ + Provider: mp, + Tools: reg, + Hooks: hookDispatcher(hook.PreToolUse, + &argTransformExecutor{newArgs: json.RawMessage(`{"command":"safe-replacement"}`)}), + }) + eng.Submit(context.Background(), "run", nil) + + var got map[string]string + json.Unmarshal(receivedArgs, &got) + if got["command"] != "safe-replacement" { + t.Errorf("tool args = %s, want safe-replacement", receivedArgs) + } +} + +func TestHook_PostToolUse_Transform(t *testing.T) { + reg := tool.NewRegistry() + reg.Register(&mockTool{ + name: "bash", + execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) { + return tool.Result{Output: "original output"}, nil + }, + }) + + mp := &mockProvider{ + streams: []stream.Stream{ + toolCallStream("c1", "bash", `{}`, message.StopToolUse, "m"), + newEventStream(message.StopEndTurn, "m", + stream.Event{Type: stream.EventTextDelta, Text: "done"}, + ), + }, + } + + eng, _ := New(Config{ + Provider: mp, + Tools: reg, + Hooks: hookDispatcher(hook.PostToolUse, + &resultTransformExecutor{newOutput: "transformed output"}), + }) + eng.Submit(context.Background(), "run", nil) + + for _, msg := range eng.History() { + for _, c := range msg.Content { + if c.Type == message.ContentToolResult && c.ToolResult != nil { + if c.ToolResult.Content != "transformed output" { + t.Errorf("tool result = %q, want 'transformed output'", c.ToolResult.Content) + } + return + } + } + } + t.Error("no tool result found in history") +} + +func TestHook_PostToolUse_DenyTreatedAsSkip(t *testing.T) { + reg := tool.NewRegistry() + reg.Register(&mockTool{ + name: "bash", + execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) { + return tool.Result{Output: "tool ran"}, nil + }, + }) + + mp := &mockProvider{ + streams: []stream.Stream{ + toolCallStream("c1", "bash", `{}`, message.StopToolUse, "m"), + newEventStream(message.StopEndTurn, "m", + stream.Event{Type: stream.EventTextDelta, Text: "done"}, + ), + }, + } + + eng, _ := New(Config{ + Provider: mp, + Tools: reg, + Hooks: hookDispatcher(hook.PostToolUse, &blockingExecutor{}), + }) + turn, err := eng.Submit(context.Background(), "run", nil) + if err != nil { + t.Fatal(err) + } + // 2 rounds = tool call + end turn, confirming the result reached the LLM. + if turn.Rounds != 2 { + t.Errorf("rounds = %d, want 2 (result reached LLM despite PostToolUse deny)", turn.Rounds) + } +} + +func TestHook_Stop_MaxTurns(t *testing.T) { + // Stop hook fires when MaxTurns is exceeded. + stopRecorder := &recordingExecutor{} + reg := tool.NewRegistry() + reg.Register(&mockTool{ + name: "bash", + execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) { + return tool.Result{Output: "ok"}, nil + }, + }) + + mp := &mockProvider{ + streams: []stream.Stream{ + // Round 1: tool call → will loop to round 2 + toolCallStream("c1", "bash", `{}`, message.StopToolUse, "m"), + // Round 2: MaxTurns=1 check triggers before this, so it's never consumed + }, + } + + d := &hook.Dispatcher{} + d.SetChain(hook.Stop, []hook.Handler{ + hook.NewHandler( + hook.HookDef{Name: "stop-rec", Event: hook.Stop, Command: hook.CommandTypeShell, Exec: "x"}, + stopRecorder, + ), + }) + + eng, _ := New(Config{Provider: mp, Tools: reg, Hooks: d, MaxTurns: 1}) + _, err := eng.Submit(context.Background(), "run", nil) + // MaxTurns exceeded returns an error + if err == nil { + t.Fatal("expected error for MaxTurns exceeded") + } + if !stopRecorder.called { + t.Error("Stop hook was not fired on MaxTurns exceeded") + } +} diff --git a/internal/engine/loop.go b/internal/engine/loop.go index 9d78a54..728512d 100644 --- a/internal/engine/loop.go +++ b/internal/engine/loop.go @@ -8,6 +8,7 @@ import ( "time" gnomactx "somegit.dev/Owlibou/gnoma/internal/context" + "somegit.dev/Owlibou/gnoma/internal/hook" "somegit.dev/Owlibou/gnoma/internal/message" "somegit.dev/Owlibou/gnoma/internal/permission" "somegit.dev/Owlibou/gnoma/internal/provider" @@ -55,6 +56,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { for { turn.Rounds++ if e.cfg.MaxTurns > 0 && turn.Rounds > e.cfg.MaxTurns { + e.cfg.Hooks.Fire(hook.Stop, hook.MarshalStopPayload("max_turns")) //nolint:errcheck return turn, fmt.Errorf("safety limit: %d rounds exceeded", e.cfg.MaxTurns) } @@ -227,6 +229,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { // Decide next action switch resp.StopReason { case message.StopEndTurn, message.StopSequence: + e.cfg.Hooks.Fire(hook.Stop, hook.MarshalStopPayload("end_turn")) //nolint:errcheck return turn, nil case message.StopMaxTokens: @@ -254,6 +257,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { default: // Unknown stop reason or empty — treat as end of turn + e.cfg.Hooks.Fire(hook.Stop, hook.MarshalStopPayload("unknown")) //nolint:errcheck return turn, nil } } @@ -411,9 +415,26 @@ func (e *Engine) executeSingleTool(ctx context.Context, call message.ToolCall, t } } + // PreToolUse hook: can deny execution or transform args. + args := call.Arguments + if e.cfg.Hooks != nil { + payload := hook.MarshalPreToolPayload(call.Name, args) + transformed, action, _ := e.cfg.Hooks.Fire(hook.PreToolUse, payload) + if action == hook.Deny { + return message.ToolResult{ + ToolCallID: call.ID, + Content: "denied by hook", + IsError: true, + } + } + if newArgs := hook.ExtractTransformedArgs(transformed); newArgs != nil { + args = newArgs + } + } + e.logger.Debug("executing tool", "name", call.Name, "id", call.ID) - result, err := t.Execute(ctx, call.Arguments) + result, err := t.Execute(ctx, args) if err != nil { e.logger.Error("tool execution failed", "name", call.Name, "error", err) return message.ToolResult{ @@ -423,8 +444,17 @@ func (e *Engine) executeSingleTool(ctx context.Context, call message.ToolCall, t } } - // Scan tool result through firewall + // PostToolUse hook: can transform result (Deny treated as Skip). output := result.Output + if e.cfg.Hooks != nil { + payload := hook.MarshalPostToolPayload(call.Name, args, output, result.Metadata) + transformed, _, _ := e.cfg.Hooks.Fire(hook.PostToolUse, payload) + if s := hook.ExtractTransformedOutput(transformed); s != "" { + output = s + } + } + + // Scan tool result through firewall if e.cfg.Firewall != nil { output = e.cfg.Firewall.ScanToolResult(output) } diff --git a/internal/hook/agent.go b/internal/hook/agent.go index 504d5ae..c7a7003 100644 --- a/internal/hook/agent.go +++ b/internal/hook/agent.go @@ -4,25 +4,21 @@ import ( "context" "fmt" "time" - - "somegit.dev/Owlibou/gnoma/internal/elf" - "somegit.dev/Owlibou/gnoma/internal/router" ) -// ElfSpawner is the minimal interface AgentExecutor needs from elf.Manager. -type ElfSpawner interface { - Spawn(ctx context.Context, taskType router.TaskType, prompt, systemPrompt string, maxTurns int) (elf.Elf, error) -} +// ElfSpawnFn spawns an elf with the given prompt and returns its output text. +// This is satisfied by a closure wrapping elf.Manager.Spawn in main.go. +type ElfSpawnFn func(ctx context.Context, prompt string) (output string, err error) // AgentExecutor spawns an elf and parses ALLOW/DENY from its output. type AgentExecutor struct { def HookDef - spawner ElfSpawner + spawnFn ElfSpawnFn } // NewAgentExecutor constructs an AgentExecutor. -func NewAgentExecutor(def HookDef, spawner ElfSpawner) *AgentExecutor { - return &AgentExecutor{def: def, spawner: spawner} +func NewAgentExecutor(def HookDef, spawnFn ElfSpawnFn) *AgentExecutor { + return &AgentExecutor{def: def, spawnFn: spawnFn} } // Execute renders the hook template, spawns an elf, waits for its result, @@ -35,19 +31,13 @@ func (a *AgentExecutor) Execute(ctx context.Context, payload []byte) (HookResult } start := time.Now() - e, err := a.spawner.Spawn(ctx, router.TaskReview, prompt, "", 5) - if err != nil { - return HookResult{}, fmt.Errorf("hook %q: spawn elf: %w", a.def.Name, err) - } - - result := e.Wait() + output, err := a.spawnFn(ctx, prompt) duration := time.Since(start) - - if result.Error != nil { - return HookResult{Duration: duration}, fmt.Errorf("hook %q: elf failed: %w", a.def.Name, result.Error) + if err != nil { + return HookResult{Duration: duration}, fmt.Errorf("hook %q: elf failed: %w", a.def.Name, err) } - action := parseDecision(result.Output) + action := parseDecision(output) return HookResult{ Action: action, Duration: duration, diff --git a/internal/hook/agent_test.go b/internal/hook/agent_test.go index bb75797..03dcbd7 100644 --- a/internal/hook/agent_test.go +++ b/internal/hook/agent_test.go @@ -5,45 +5,32 @@ import ( "errors" "testing" "time" - - "somegit.dev/Owlibou/gnoma/internal/elf" - "somegit.dev/Owlibou/gnoma/internal/router" - "somegit.dev/Owlibou/gnoma/internal/stream" ) -// mockElfSpawner satisfies ElfSpawner. Records calls and returns configurable results. -type mockElfSpawner struct { - result elf.Result - err error - // Captures - lastPrompt string - lastTask router.TaskType -} - -func (m *mockElfSpawner) Spawn(ctx context.Context, taskType router.TaskType, prompt, systemPrompt string, maxTurns int) (elf.Elf, error) { - m.lastPrompt = prompt - m.lastTask = taskType - if m.err != nil { - return nil, m.err +func spawnFnOK(output string) ElfSpawnFn { + return func(_ context.Context, _ string) (string, error) { + return output, nil } - return &immediateElf{result: m.result}, nil } -// immediateElf returns a pre-computed result immediately. -type immediateElf struct { - result elf.Result +func spawnFnErr(err error) ElfSpawnFn { + return func(_ context.Context, _ string) (string, error) { + return "", err + } } -func (e *immediateElf) ID() string { return "test-elf" } -func (e *immediateElf) Status() elf.Status { return e.result.Status } -func (e *immediateElf) Events() <-chan stream.Event { return nil } -func (e *immediateElf) Wait() elf.Result { return e.result } -func (e *immediateElf) Cancel() {} +func capturingSpawnFn(output string) (ElfSpawnFn, *string) { + captured := new(string) + fn := func(_ context.Context, prompt string) (string, error) { + *captured = prompt + return output, nil + } + return fn, captured +} func TestAgentExecutor_OutputALLOW(t *testing.T) { def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypeAgent, Exec: "Review this tool call."} - spawner := &mockElfSpawner{result: elf.Result{Output: "After analysis, ALLOW this.", Status: elf.StatusCompleted}} - ex := NewAgentExecutor(def, spawner) + ex := NewAgentExecutor(def, spawnFnOK("After analysis, ALLOW this.")) result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil)) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -55,8 +42,7 @@ func TestAgentExecutor_OutputALLOW(t *testing.T) { func TestAgentExecutor_OutputDENY(t *testing.T) { def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypeAgent, Exec: "Review this."} - spawner := &mockElfSpawner{result: elf.Result{Output: "This is dangerous. DENY.", Status: elf.StatusCompleted}} - ex := NewAgentExecutor(def, spawner) + ex := NewAgentExecutor(def, spawnFnOK("This is dangerous. DENY.")) result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil)) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -68,8 +54,7 @@ func TestAgentExecutor_OutputDENY(t *testing.T) { func TestAgentExecutor_OutputNoMatch_Skip(t *testing.T) { def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypeAgent, Exec: "Review this."} - spawner := &mockElfSpawner{result: elf.Result{Output: "I'm unsure.", Status: elf.StatusCompleted}} - ex := NewAgentExecutor(def, spawner) + ex := NewAgentExecutor(def, spawnFnOK("I'm unsure.")) result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil)) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -79,24 +64,9 @@ func TestAgentExecutor_OutputNoMatch_Skip(t *testing.T) { } } -func TestAgentExecutor_ElfFailure_Error(t *testing.T) { - def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypeAgent, Exec: "Review."} - spawner := &mockElfSpawner{result: elf.Result{ - Output: "", - Status: elf.StatusFailed, - Error: errors.New("elf crashed"), - }} - ex := NewAgentExecutor(def, spawner) - _, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil)) - if err == nil { - t.Error("expected error for failed elf") - } -} - func TestAgentExecutor_SpawnError(t *testing.T) { def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypeAgent, Exec: "Review."} - spawner := &mockElfSpawner{err: errors.New("no arms available")} - ex := NewAgentExecutor(def, spawner) + ex := NewAgentExecutor(def, spawnFnErr(errors.New("no arms available"))) _, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil)) if err == nil { t.Error("expected error when spawn fails") @@ -110,30 +80,23 @@ func TestAgentExecutor_TemplateRendered(t *testing.T) { Command: CommandTypeAgent, Exec: "Tool={{.Tool}} Event={{.Event}}", } - spawner := &mockElfSpawner{result: elf.Result{Output: "ALLOW", Status: elf.StatusCompleted}} - ex := NewAgentExecutor(def, spawner) + fn, captured := capturingSpawnFn("ALLOW") + ex := NewAgentExecutor(def, fn) ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil)) - if spawner.lastPrompt != "Tool=bash Event=pre_tool_use" { - t.Errorf("prompt = %q", spawner.lastPrompt) + if *captured != "Tool=bash Event=pre_tool_use" { + t.Errorf("prompt = %q", *captured) } } func TestAgentExecutor_Duration(t *testing.T) { def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypeAgent, Exec: "Review."} - spawner := &mockElfSpawner{result: elf.Result{Output: "ALLOW", Status: elf.StatusCompleted, Duration: 100 * time.Millisecond}} - ex := NewAgentExecutor(def, spawner) + fn := func(_ context.Context, _ string) (string, error) { + time.Sleep(1 * time.Millisecond) + return "ALLOW", nil + } + ex := NewAgentExecutor(def, fn) result, _ := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil)) if result.Duration <= 0 { t.Error("expected Duration > 0") } } - -func TestAgentExecutor_TaskTypeIsReview(t *testing.T) { - def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypeAgent, Exec: "Review."} - spawner := &mockElfSpawner{result: elf.Result{Output: "ALLOW", Status: elf.StatusCompleted}} - ex := NewAgentExecutor(def, spawner) - ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil)) - if spawner.lastTask != router.TaskReview { - t.Errorf("task type = %v, want TaskReview", spawner.lastTask) - } -} diff --git a/internal/hook/dispatcher.go b/internal/hook/dispatcher.go index 0fddbd0..68d88d7 100644 --- a/internal/hook/dispatcher.go +++ b/internal/hook/dispatcher.go @@ -13,9 +13,23 @@ type Dispatcher struct { logger *slog.Logger } +// SetChain replaces the handler chain for an event. Primarily for testing. +func (d *Dispatcher) SetChain(event EventType, handlers []Handler) { + if d.chains == nil { + d.chains = make(map[EventType][]Handler) + } + d.chains[event] = handlers +} + +// NewHandler constructs a Handler from a definition and executor. +func NewHandler(def HookDef, ex Executor) Handler { + return Handler{def: def, executor: ex} +} + // NewDispatcher validates defs, constructs the appropriate executor per // CommandType, and groups handlers by EventType. -func NewDispatcher(defs []HookDef, logger *slog.Logger, executorFn func(HookDef) (Executor, error)) (*Dispatcher, error) { +// streamer and spawnFn may be nil if no prompt/agent hooks are configured. +func NewDispatcher(defs []HookDef, streamer Streamer, spawnFn ElfSpawnFn, logger *slog.Logger) (*Dispatcher, error) { if logger == nil { logger = slog.Default() } @@ -27,7 +41,7 @@ func NewDispatcher(defs []HookDef, logger *slog.Logger, executorFn func(HookDef) if err := def.Validate(); err != nil { return nil, fmt.Errorf("hook.NewDispatcher: %w", err) } - ex, err := executorFn(def) + ex, err := buildExecutor(def, streamer, spawnFn) if err != nil { return nil, fmt.Errorf("hook.NewDispatcher: building executor for %q: %w", def.Name, err) } @@ -36,6 +50,26 @@ func NewDispatcher(defs []HookDef, logger *slog.Logger, executorFn func(HookDef) return d, nil } +// buildExecutor constructs the right Executor for a HookDef. +func buildExecutor(def HookDef, streamer Streamer, spawnFn ElfSpawnFn) (Executor, error) { + switch def.Command { + case CommandTypeShell: + return NewCommandExecutor(def), nil + case CommandTypePrompt: + if streamer == nil { + return nil, fmt.Errorf("prompt hook %q requires a Streamer (no router configured)", def.Name) + } + return NewPromptExecutor(def, streamer), nil + case CommandTypeAgent: + if spawnFn == nil { + return nil, fmt.Errorf("agent hook %q requires an ElfSpawnFn (no elf manager configured)", def.Name) + } + return NewAgentExecutor(def, spawnFn), nil + default: + return nil, fmt.Errorf("unknown command type %v", def.Command) + } +} + // Fire runs all handlers registered for event, in order. // Returns the (possibly transformed) payload, the aggregate Action, and the first error. // Safe to call on a nil *Dispatcher — returns (payload, Allow, nil). diff --git a/internal/hook/payload.go b/internal/hook/payload.go index 27a40a4..388e29b 100644 --- a/internal/hook/payload.go +++ b/internal/hook/payload.go @@ -136,6 +136,21 @@ func parseActionString(s string) (Action, error) { } } +// ExtractTransformedArgs extracts the "args" field from a transformed PreToolUse payload. +// Returns nil if the field is absent or the payload is malformed. +func ExtractTransformedArgs(payload []byte) json.RawMessage { + if payload == nil { + return nil + } + var v struct { + Args json.RawMessage `json:"args"` + } + if err := json.Unmarshal(payload, &v); err != nil { + return nil + } + return v.Args +} + // ExtractTransformedOutput extracts the "output" string from a PostToolUse // transformed payload. Returns "" if the payload is nil or malformed. func ExtractTransformedOutput(transformed json.RawMessage) string {