From 1aa1d83e9e29b569ae3134a7d8b5140b277efd07 Mon Sep 17 00:00:00 2001 From: vikingowl Date: Tue, 7 Apr 2026 00:53:53 +0200 Subject: [PATCH] =?UTF-8?q?feat:=20PromptExecutor=20=E2=80=94=20LLM-based?= =?UTF-8?q?=20hook=20evaluation=20via=20router?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/hook/prompt.go | 135 +++++++++++++++++++++ internal/hook/prompt_test.go | 223 +++++++++++++++++++++++++++++++++++ 2 files changed, 358 insertions(+) create mode 100644 internal/hook/prompt.go create mode 100644 internal/hook/prompt_test.go diff --git a/internal/hook/prompt.go b/internal/hook/prompt.go new file mode 100644 index 0000000..2603619 --- /dev/null +++ b/internal/hook/prompt.go @@ -0,0 +1,135 @@ +package hook + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "strings" + "text/template" + "time" + + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/stream" +) + +// TemplateData holds the variables available in hook prompt templates as {{.Event}}, {{.Tool}}, {{.Args}} and {{.Result}}. +type TemplateData struct { + Event string + Tool string + Args string + Result string +} + +// renderTemplate parses tmpl as a text/template named "hook" and executes it against data, returning the rendered text; parse and execute failures come back as wrapped errors. +func renderTemplate(tmpl string, data TemplateData) (string, error) { + t, err := template.New("hook").Parse(tmpl) + if err != nil { + return "", fmt.Errorf("hook: template parse error: %w", err) + } + var buf bytes.Buffer + if err := t.Execute(&buf, data); err != nil { + return "", fmt.Errorf("hook: template execute error: %w", err) + } + return buf.String(), nil +} + +// parseDecision scans text for the first case-insensitive occurrence of +// "ALLOW" or "DENY". Returns Skip if neither is found. 
+func parseDecision(text string) Action { + upper := strings.ToUpper(text) + ai := strings.Index(upper, "ALLOW") + di := strings.Index(upper, "DENY") + switch { + case ai >= 0 && (di < 0 || ai < di): + return Allow + case di >= 0: + return Deny + default: + return Skip + } +} + +// Streamer is the minimal interface PromptExecutor needs from the router. +// *router.Router satisfies this interface via an adapter in main.go. +type Streamer interface { + Stream(ctx context.Context, prompt string) (stream.Stream, error) +} + +// PromptExecutor sends a templated prompt to an LLM and parses ALLOW/DENY +// from the response. +type PromptExecutor struct { + def HookDef + streamer Streamer +} + +// NewPromptExecutor constructs a PromptExecutor. +func NewPromptExecutor(def HookDef, streamer Streamer) *PromptExecutor { + return &PromptExecutor{def: def, streamer: streamer} +} + +// Execute renders the template, sends the prompt, and parses the response. +func (p *PromptExecutor) Execute(ctx context.Context, payload []byte) (HookResult, error) { + data := templateDataFromPayload(payload, p.def.Event) + prompt, err := renderTemplate(p.def.Exec, data) + if err != nil { + return HookResult{}, fmt.Errorf("hook %q: %w", p.def.Name, err) + } + + start := time.Now() + s, err := p.streamer.Stream(ctx, prompt) + if err != nil { + return HookResult{}, fmt.Errorf("hook %q: stream error: %w", p.def.Name, err) + } + defer s.Close() + + acc := stream.NewAccumulator() + var stopReason message.StopReason + var model string + for s.Next() { + evt := s.Current() + acc.Apply(evt) + if evt.StopReason != "" { + stopReason = evt.StopReason + model = evt.Model + } + } + if err := s.Err(); err != nil { + return HookResult{}, fmt.Errorf("hook %q: stream error: %w", p.def.Name, err) + } + + resp := acc.Response(stopReason, model) + text := resp.Message.TextContent() + + action := parseDecision(text) + return HookResult{ + Action: action, + Output: []byte(text), + Duration: time.Since(start), + }, nil +} 
+ +// templateDataFromPayload builds TemplateData from a hook payload. Event is always set; Tool, Args and Result are filled in only for PreToolUse and PostToolUse events. +func templateDataFromPayload(payload []byte, event EventType) TemplateData { + data := TemplateData{Event: event.String()} + if event == PreToolUse || event == PostToolUse { + data.Tool = ExtractToolName(payload) + data.Args = extractRawField(payload, "args") + data.Result = extractRawField(payload, "result") + } + return data +} + +// extractRawField returns the raw, still JSON-encoded value of a top-level field of payload. +// Returns "" if the field is absent or payload cannot be unmarshalled into a JSON object. +func extractRawField(payload []byte, field string) string { + var v map[string]json.RawMessage + if err := json.Unmarshal(payload, &v); err != nil { + return "" + } + raw, ok := v[field] + if !ok { + return "" + } + return string(raw) +} diff --git a/internal/hook/prompt_test.go b/internal/hook/prompt_test.go new file mode 100644 index 0000000..98e7065 --- /dev/null +++ b/internal/hook/prompt_test.go @@ -0,0 +1,223 @@ +package hook + +import ( + "context" + "errors" + "testing" + + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/stream" +) + +// mockStreamer implements Streamer by returning its pre-built stream and error as-is; the prompt is ignored. +type mockStreamer struct { + s stream.Stream + err error +} + +func (m *mockStreamer) Stream(_ context.Context, prompt string) (stream.Stream, error) { + return m.s, m.err +} + +// textStream returns a stream of two events: a text delta carrying text, then a final event that only sets the stop reason. 
+func textStream(text string) stream.Stream { + events := []stream.Event{ + {Type: stream.EventTextDelta, Text: text}, + {Type: stream.EventTextDelta, StopReason: message.StopEndTurn}, + } + return &sliceStream{events: events} +} + +type sliceStream struct { + events []stream.Event + idx int +} + +func (s *sliceStream) Next() bool { s.idx++; return s.idx <= len(s.events) } +func (s *sliceStream) Current() stream.Event { return s.events[s.idx-1] } +func (s *sliceStream) Err() error { return nil } +func (s *sliceStream) Close() error { return nil } + +// --- Template rendering tests --- + +func TestRenderTemplate_EventVar(t *testing.T) { + got, err := renderTemplate("event is {{.Event}}", TemplateData{Event: "pre_tool_use"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "event is pre_tool_use" { + t.Errorf("got %q", got) + } +} + +func TestRenderTemplate_AllVars(t *testing.T) { + tmpl := "{{.Event}} {{.Tool}} {{.Args}} {{.Result}}" + data := TemplateData{Event: "pre_tool_use", Tool: "bash", Args: `{"cmd":"ls"}`, Result: ""} + got, err := renderTemplate(tmpl, data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != `pre_tool_use bash {"cmd":"ls"} ` { + t.Errorf("got %q", got) + } +} + +func TestRenderTemplate_NonToolEvent_EmptyToolFields(t *testing.T) { + tmpl := "[{{.Tool}}][{{.Args}}][{{.Result}}]" + data := TemplateData{Event: "session_start"} // Tool/Args/Result are zero values + got, err := renderTemplate(tmpl, data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != "[][][]" { + t.Errorf("got %q", got) + } +} + +func TestRenderTemplate_InvalidTemplate(t *testing.T) { + _, err := renderTemplate("{{.Unknown field", TemplateData{}) + if err == nil { + t.Error("expected error for invalid template") + } +} + +// --- parseDecision tests --- + +func TestParseDecision_ALLOW(t *testing.T) { + if got := parseDecision("The action is ALLOW."); got != Allow { + t.Errorf("got %v, want Allow", got) + } 
+} + +func TestParseDecision_DENY(t *testing.T) { + if got := parseDecision("I must DENY this request."); got != Deny { + t.Errorf("got %v, want Deny", got) + } +} + +func TestParseDecision_NoMatch(t *testing.T) { + if got := parseDecision("I don't know."); got != Skip { + t.Errorf("got %v, want Skip", got) + } +} + +func TestParseDecision_CaseInsensitive(t *testing.T) { + cases := []struct { + text string + want Action + }{ + {"allow", Allow}, + {"Allow", Allow}, + {"ALLOW", Allow}, + {"deny", Deny}, + {"Deny", Deny}, + {"DENY", Deny}, + } + for _, tt := range cases { + if got := parseDecision(tt.text); got != tt.want { + t.Errorf("parseDecision(%q) = %v, want %v", tt.text, got, tt.want) + } + } +} + +func TestParseDecision_FirstMatchWins(t *testing.T) { + // "DENY" appears before "ALLOW" → Deny + if got := parseDecision("I will DENY this, not ALLOW."); got != Deny { + t.Errorf("got %v, want Deny (first match)", got) + } +} + +// --- PromptExecutor tests --- + +func TestPromptExecutor_ResponseALLOW(t *testing.T) { + def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypePrompt, Exec: "Is this safe? ALLOW or DENY."} + ex := NewPromptExecutor(def, &mockStreamer{s: textStream("This is safe. ALLOW.")}) + result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Action != Allow { + t.Errorf("action = %v, want Allow", result.Action) + } +} + +func TestPromptExecutor_ResponseDENY(t *testing.T) { + def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypePrompt, Exec: "Is this safe? ALLOW or DENY."} + ex := NewPromptExecutor(def, &mockStreamer{s: textStream("This is dangerous. 
DENY.")}) + result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Action != Deny { + t.Errorf("action = %v, want Deny", result.Action) + } +} + +func TestPromptExecutor_ResponseNoMatch_Skip(t *testing.T) { + def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypePrompt, Exec: "Review this."} + ex := NewPromptExecutor(def, &mockStreamer{s: textStream("I'm not sure what to do.")}) + result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result.Action != Skip { + t.Errorf("action = %v, want Skip", result.Action) + } +} + +func TestPromptExecutor_TemplateRendered(t *testing.T) { + // Verify template vars are substituted — use a streamer that captures the prompt. + var capturedPrompt string + capturingStreamer := &capturingStreamer{response: "ALLOW"} + def := HookDef{ + Name: "test", + Event: PreToolUse, + Command: CommandTypePrompt, + Exec: "Tool={{.Tool}} Event={{.Event}}", + } + ex := NewPromptExecutor(def, capturingStreamer) + ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil)) + capturedPrompt = capturingStreamer.prompt + if capturedPrompt == "" { + t.Fatal("prompt not captured") + } + if capturedPrompt != "Tool=bash Event=pre_tool_use" { + t.Errorf("prompt = %q", capturedPrompt) + } +} + +func TestPromptExecutor_OutputIsFullResponse(t *testing.T) { + def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypePrompt, Exec: "Review."} + response := "After analysis, ALLOW this operation." 
+ ex := NewPromptExecutor(def, &mockStreamer{s: textStream(response)}) + result, _ := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil)) + if result.Error != nil { + t.Fatalf("unexpected error: %v", result.Error) + } + // Output field carries the full LLM response text (for observability) + if string(result.Output) != response { + t.Errorf("Output = %q, want %q", result.Output, response) + } +} + +func TestPromptExecutor_StreamerError(t *testing.T) { + def := HookDef{Name: "test", Event: PreToolUse, Command: CommandTypePrompt, Exec: "Review."} + ex := NewPromptExecutor(def, &mockStreamer{err: errors.New("provider unavailable")}) + result, err := ex.Execute(context.Background(), MarshalPreToolPayload("bash", nil)) + if err == nil { + t.Fatal("expected error") + } + // fail_open=false (default) → Deny on error; but error is returned, caller (Dispatcher) applies policy + _ = result +} + +// capturingStreamer records the prompt it was called with. +type capturingStreamer struct { + prompt string + response string +} + +func (c *capturingStreamer) Stream(_ context.Context, prompt string) (stream.Stream, error) { + c.prompt = prompt + return textStream(c.response), nil +}