From b36ef564ab74fe30a0462082575475105c0e5840 Mon Sep 17 00:00:00 2001 From: vikingowl Date: Tue, 19 May 2026 16:18:17 +0200 Subject: [PATCH] fix(engine): guard mutable state with a mutex Engine.history, usage, activatedTools, modelCaps, turnOpts, and cfg.Provider/Model are now mutated and read under e.mu. The lock is released across blocking provider.Stream calls so external setters (SetProvider, SetHistory, InjectMessage, etc.) can interleave. History() now returns a copy. Snapshot helpers (latestUserPrompt, historySnapshot, snapshotTurnOpts, etc.) replace the unsynchronised reads scattered through runLoop and buildRequest. Closes audit finding H4. Adds a race regression test that fails under -race before the fix and passes after. --- internal/engine/engine.go | 138 +++++++++++++++++++++++++++++++---- internal/engine/loop.go | 84 +++++++++------------ internal/engine/race_test.go | 77 +++++++++++++++++++ 3 files changed, 235 insertions(+), 64 deletions(-) create mode 100644 internal/engine/race_test.go diff --git a/internal/engine/engine.go b/internal/engine/engine.go index 55e292d..25c42b6 100644 --- a/internal/engine/engine.go +++ b/internal/engine/engine.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "sync" gnomactx "somegit.dev/Owlibou/gnoma/internal/context" "somegit.dev/Owlibou/gnoma/internal/hook" @@ -59,21 +60,22 @@ type TurnOptions struct { } // Engine orchestrates the conversation. +// +// Mutable state (history, usage, activatedTools, modelCaps, turnOpts, and the +// hot fields of cfg — Provider/Model) is guarded by mu. The lock is released +// across blocking provider.Stream calls so external setters can interleave. type Engine struct { + mu sync.Mutex cfg Config history []message.Message usage message.Usage logger *slog.Logger - // Cached model capabilities, resolved lazily modelCaps *provider.Capabilities - modelCapsFor string // model ID the cached caps are for + modelCapsFor string - // Deferred tool loading: tools with ShouldDefer() are excluded until - // the model requests them. Activated on first use. activatedTools map[string]bool - // Per-turn options, set for the duration of SubmitWithOptions. turnOpts TurnOptions } @@ -145,18 +147,20 @@ func New(cfg Config) (*Engine, error) { // resolveCapabilities returns the capabilities for the active model. // Caches the result — re-resolves if the model changes. func (e *Engine) resolveCapabilities(ctx context.Context) *provider.Capabilities { + e.mu.Lock() model := e.cfg.Model if model == "" { model = e.cfg.Provider.DefaultModel() } - - // Return cached if same model if e.modelCaps != nil && e.modelCapsFor == model { - return e.modelCaps + caps := e.modelCaps + e.mu.Unlock() + return caps } + prov := e.cfg.Provider + e.mu.Unlock() - // Query provider for model list - models, err := e.cfg.Provider.Models(ctx) + models, err := prov.Models(ctx) if err != nil { e.logger.Debug("failed to fetch model capabilities", "error", err) return nil @@ -164,9 +168,12 @@ func (e *Engine) resolveCapabilities(ctx context.Context) *provider.Capabilities for _, m := range models { if m.ID == model { + e.mu.Lock() e.modelCaps = &m.Capabilities e.modelCapsFor = model - return e.modelCaps + caps := e.modelCaps + e.mu.Unlock() + return caps } } @@ -174,9 +181,13 @@ func (e *Engine) resolveCapabilities(ctx context.Context) *provider.Capabilities return nil } -// History returns the full conversation. +// History returns a snapshot copy of the conversation. func (e *Engine) History() []message.Message { - return e.history + e.mu.Lock() + defer e.mu.Unlock() + out := make([]message.Message, len(e.history)) + copy(out, e.history) + return out } // ContextWindow returns the context window (may be nil). @@ -188,7 +199,9 @@ func (e *Engine) ContextWindow() *gnomactx.Window { // Used for system notifications (permission mode changes, incognito toggles) that // the model should see as context in subsequent turns. func (e *Engine) InjectMessage(msg message.Message) { + e.mu.Lock() e.history = append(e.history, msg) + e.mu.Unlock() if e.cfg.Context != nil { e.cfg.Context.AppendMessage(msg) } @@ -196,23 +209,31 @@ func (e *Engine) InjectMessage(msg message.Message) { // Usage returns cumulative token usage. func (e *Engine) Usage() message.Usage { + e.mu.Lock() + defer e.mu.Unlock() return e.usage } // SetProvider swaps the active provider (for dynamic switching). func (e *Engine) SetProvider(p provider.Provider) { + e.mu.Lock() e.cfg.Provider = p + e.mu.Unlock() } // SetModel changes the model within the current provider. func (e *Engine) SetModel(model string) { + e.mu.Lock() e.cfg.Model = model + e.mu.Unlock() } // SetHistory replaces the conversation history (for session restore). // Also syncs the context window and re-estimates the tracker's token count. func (e *Engine) SetHistory(msgs []message.Message) { + e.mu.Lock() e.history = msgs + e.mu.Unlock() if e.cfg.Context != nil { e.cfg.Context.SetMessages(msgs) e.cfg.Context.Tracker().Set(e.cfg.Context.Tracker().CountMessages(msgs)) @@ -221,12 +242,16 @@ func (e *Engine) SetHistory(msgs []message.Message) { // SetUsage sets cumulative token usage (for session restore). func (e *Engine) SetUsage(u message.Usage) { + e.mu.Lock() e.usage = u + e.mu.Unlock() } // SetActivatedTools restores the set of activated deferred tools (for session restore). func (e *Engine) SetActivatedTools(tools map[string]bool) { + e.mu.Lock() e.activatedTools = tools + e.mu.Unlock() } // classify returns a Task for the given prompt using the configured classifier. @@ -236,7 +261,11 @@ func (e *Engine) classify(ctx context.Context, prompt string) router.Task { if cls == nil { cls = router.HeuristicClassifier{} } - task, err := cls.Classify(ctx, prompt, e.history) + e.mu.Lock() + histSnap := make([]message.Message, len(e.history)) + copy(histSnap, e.history) + e.mu.Unlock() + task, err := cls.Classify(ctx, prompt, histSnap) if err != nil { e.logger.Debug("classifier error, falling back to heuristic", "error", err) return router.ClassifyTask(prompt) @@ -244,12 +273,91 @@ func (e *Engine) classify(ctx context.Context, prompt string) router.Task { return task } +// latestUserPrompt returns the text of the most recent user message. +func (e *Engine) latestUserPrompt() string { + e.mu.Lock() + defer e.mu.Unlock() + for i := len(e.history) - 1; i >= 0; i-- { + if e.history[i].Role == message.RoleUser { + return e.history[i].TextContent() + } + } + return "" +} + +// historySnapshot returns a copy of the current history slice. +func (e *Engine) historySnapshot() []message.Message { + e.mu.Lock() + defer e.mu.Unlock() + out := make([]message.Message, len(e.history)) + copy(out, e.history) + return out +} + +// appendHistory appends a message under the lock. +func (e *Engine) appendHistory(msg message.Message) { + e.mu.Lock() + e.history = append(e.history, msg) + e.mu.Unlock() +} + +// replaceHistory swaps the history slice (used after context compaction). +func (e *Engine) replaceHistory(msgs []message.Message) { + e.mu.Lock() + e.history = msgs + e.mu.Unlock() +} + +// addUsage accumulates token usage. +func (e *Engine) addUsage(u message.Usage) { + e.mu.Lock() + e.usage.Add(u) + e.mu.Unlock() +} + +// activeProvider returns the current provider under lock. +func (e *Engine) activeProvider() provider.Provider { + e.mu.Lock() + defer e.mu.Unlock() + return e.cfg.Provider +} + +// activeModel returns the configured model name under lock. +func (e *Engine) activeModel() string { + e.mu.Lock() + defer e.mu.Unlock() + return e.cfg.Model +} + +// snapshotTurnOpts returns a copy of the current per-turn options. +func (e *Engine) snapshotTurnOpts() TurnOptions { + e.mu.Lock() + defer e.mu.Unlock() + return e.turnOpts +} + +// markToolActivated records that a deferred tool has been requested. +func (e *Engine) markToolActivated(name string) { + e.mu.Lock() + e.activatedTools[name] = true + e.mu.Unlock() +} + +// isToolActivated reports whether a deferred tool has been activated. +func (e *Engine) isToolActivated(name string) bool { + e.mu.Lock() + defer e.mu.Unlock() + return e.activatedTools[name] +} + // Reset clears conversation history and usage. func (e *Engine) Reset() { + e.mu.Lock() e.history = nil e.usage = message.Usage{} + e.activatedTools = make(map[string]bool) + e.mu.Unlock() if e.cfg.Context != nil { e.cfg.Context.Reset() } - e.activatedTools = make(map[string]bool) } diff --git a/internal/engine/loop.go b/internal/engine/loop.go index 95ebe91..25effb2 100644 --- a/internal/engine/loop.go +++ b/internal/engine/loop.go @@ -28,11 +28,17 @@ func (e *Engine) Submit(ctx context.Context, input string, cb Callback) (*Turn, // SubmitWithOptions is like Submit but applies per-turn overrides (e.g. ToolChoice). func (e *Engine) SubmitWithOptions(ctx context.Context, input string, opts TurnOptions, cb Callback) (*Turn, error) { + e.mu.Lock() e.turnOpts = opts - defer func() { e.turnOpts = TurnOptions{} }() - userMsg := message.NewUserText(input) e.history = append(e.history, userMsg) + e.mu.Unlock() + defer func() { + e.mu.Lock() + e.turnOpts = TurnOptions{} + e.mu.Unlock() + }() + if e.cfg.Context != nil { e.cfg.Context.AppendMessage(userMsg) } @@ -42,7 +48,9 @@ func (e *Engine) SubmitWithOptions(ctx context.Context, input string, opts TurnO // SubmitMessages is like Submit but accepts pre-built messages. func (e *Engine) SubmitMessages(ctx context.Context, msgs []message.Message, cb Callback) (*Turn, error) { + e.mu.Lock() e.history = append(e.history, msgs...) + e.mu.Unlock() if e.cfg.Context != nil { for _, m := range msgs { e.cfg.Context.AppendMessage(m) @@ -89,14 +97,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { var decision router.RoutingDecision if e.cfg.Router != nil { - // Classify task from the latest user message - prompt := "" - for i := len(e.history) - 1; i >= 0; i-- { - if e.history[i].Role == message.RoleUser { - prompt = e.history[i].TextContent() - break - } - } + prompt := e.latestUserPrompt() task := e.classify(ctx, prompt) if e.cfg.Context != nil { task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt)) @@ -131,14 +132,15 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { } } } else { + prov := e.activeProvider() e.logger.Debug("streaming request", - "provider", e.cfg.Provider.Name(), + "provider", prov.Name(), "model", req.Model, "messages", len(req.Messages), "tools", len(req.Tools), "round", turn.Rounds, ) - s, err = e.cfg.Provider.Stream(ctx, req) + s, err = prov.Stream(ctx, req) } if err != nil { var failedArms []router.ArmID @@ -161,13 +163,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { // Retry on transient errors (429, 5xx) with exponential backoff s, err = e.retryOnTransient(ctx, err, skipDelay, func() (stream.Stream, error) { if e.cfg.Router != nil { - prompt := "" - for i := len(e.history) - 1; i >= 0; i-- { - if e.history[i].Role == message.RoleUser { - prompt = e.history[i].TextContent() - break - } - } + prompt := e.latestUserPrompt() task := e.classify(ctx, prompt) if e.cfg.Context != nil { task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt)) @@ -192,7 +188,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { } return s, err } - return e.cfg.Provider.Stream(ctx, req) + return e.activeProvider().Stream(ctx, req) }) if err != nil { // Try reactive compaction on 413 (request too large) @@ -266,8 +262,8 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { turn.Usage.Add(resp.Usage) turn.Messages = append(turn.Messages, resp.Message) - e.history = append(e.history, resp.Message) - e.usage.Add(resp.Usage) + e.appendHistory(resp.Message) + e.addUsage(resp.Usage) // Track in context window and check for compaction if e.cfg.Context != nil { @@ -282,8 +278,9 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { if compacted, err := e.cfg.Context.CompactIfNeeded(); err != nil { e.logger.Error("context compaction failed", "error", err) } else if compacted { - e.history = e.cfg.Context.Messages() - e.logger.Info("context compacted", "messages", len(e.history)) + compactedMsgs := e.cfg.Context.Messages() + e.replaceHistory(compactedMsgs) + e.logger.Info("context compacted", "messages", len(compactedMsgs)) } } @@ -304,7 +301,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { // Model hit its output token budget mid-response. Inject a continue prompt // and re-query so the response is completed rather than silently truncated. contMsg := message.NewUserText("Continue from where you left off.") - e.history = append(e.history, contMsg) + e.appendHistory(contMsg) if e.cfg.Context != nil { e.cfg.Context.AppendMessage(contMsg) } @@ -319,7 +316,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { } toolMsg := message.NewToolResults(results...) turn.Messages = append(turn.Messages, toolMsg) - e.history = append(e.history, toolMsg) + e.appendHistory(toolMsg) if e.cfg.Context != nil { e.cfg.Context.AppendMessage(toolMsg) } @@ -336,7 +333,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) { func (e *Engine) buildRequest(ctx context.Context) provider.Request { // Use AllMessages (prefix + history) if context window manages prefix docs - messages := e.history + messages := e.historySnapshot() if e.cfg.Context != nil { messages = e.cfg.Context.AllMessages() } @@ -351,11 +348,12 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request { systemPrompt = e.cfg.Firewall.ScanSystemPrompt(systemPrompt) } + turnOpts := e.snapshotTurnOpts() req := provider.Request{ - Model: e.cfg.Model, + Model: e.activeModel(), SystemPrompt: systemPrompt, Messages: messages, - ToolChoice: e.turnOpts.ToolChoice, + ToolChoice: turnOpts.ToolChoice, Temperature: e.cfg.Temperature, } @@ -371,10 +369,10 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request { includeTools = caps == nil || caps.ToolUse } if includeTools { - allowed := e.turnOpts.AllowedTools + allowed := turnOpts.AllowedTools for _, t := range e.cfg.Tools.All() { // Skip deferred tools until the model requests them - if dt, ok := t.(tool.DeferrableTool); ok && dt.ShouldDefer() && !e.activatedTools[t.Name()] { + if dt, ok := t.(tool.DeferrableTool); ok && dt.ShouldDefer() && !e.isToolActivated(t.Name()) { continue } // Filter to allowed tools when a restrict list is set @@ -399,13 +397,7 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request { // Inject coordinator guidance for orchestration tasks if e.cfg.Router != nil { - prompt := "" - for i := len(e.history) - 1; i >= 0; i-- { - if e.history[i].Role == message.RoleUser { - prompt = e.history[i].TextContent() - break - } - } + prompt := e.latestUserPrompt() if e.classify(ctx, prompt).Type == router.TaskOrchestration { req.SystemPrompt = coordinatorPrompt() + "\n\n" + req.SystemPrompt } @@ -453,7 +445,7 @@ func (e *Engine) executeTools(ctx context.Context, calls []message.ToolCall, cb if ok { // Activate deferred tools on first use if dt, isDeferrable := t.(tool.DeferrableTool); isDeferrable && dt.ShouldDefer() { - e.activatedTools[call.Name] = true + e.markToolActivated(call.Name) } } if !ok { @@ -536,7 +528,7 @@ func (e *Engine) executeSingleTool(ctx context.Context, call message.ToolCall, t } // Path restriction: deny bash and validate fs tool paths against AllowedPaths. - if denied, blocked := checkPathRestriction(call, t, args, e.turnOpts.AllowedPaths); blocked { + if denied, blocked := checkPathRestriction(call, t, args, e.snapshotTurnOpts().AllowedPaths); blocked { return denied } @@ -615,17 +607,11 @@ func (e *Engine) handleRequestTooLarge(ctx context.Context, origErr error, req p return nil, origErr } - e.history = e.cfg.Context.Messages() + e.replaceHistory(e.cfg.Context.Messages()) req = e.buildRequest(ctx) if e.cfg.Router != nil { - prompt := "" - for i := len(e.history) - 1; i >= 0; i-- { - if e.history[i].Role == message.RoleUser { - prompt = e.history[i].TextContent() - break - } - } + prompt := e.latestUserPrompt() task := e.classify(ctx, prompt) if e.cfg.Context != nil { task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt)) @@ -635,7 +621,7 @@ func (e *Engine) handleRequestTooLarge(ctx context.Context, origErr error, req p s, _, err := e.cfg.Router.Stream(ctx, task, req) return s, err } - return e.cfg.Provider.Stream(ctx, req) + return e.activeProvider().Stream(ctx, req) } // retryOnTransient retries the stream call on 429/5xx with exponential backoff. diff --git a/internal/engine/race_test.go b/internal/engine/race_test.go new file mode 100644 index 0000000..c045885 --- /dev/null +++ b/internal/engine/race_test.go @@ -0,0 +1,77 @@ +package engine + +import ( + "context" + "sync" + "testing" + + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/stream" + "somegit.dev/Owlibou/gnoma/internal/tool" +) + +// blockingStream emits one text delta, then blocks Next() until release is closed, +// then emits the stop event. Lets a test interleave Submit with concurrent setters. +type blockingStream struct { + release chan struct{} + emitted bool + released bool + stopReason message.StopReason + model string +} + +func newBlockingStream(release chan struct{}, model string) *blockingStream { + return &blockingStream{release: release, model: model, stopReason: message.StopEndTurn} +} + +func (s *blockingStream) Next() bool { + if !s.emitted { + s.emitted = true + return true + } + if !s.released { + <-s.release + s.released = true + return true + } + return false +} + +func (s *blockingStream) Current() stream.Event { + if s.released { + return stream.Event{Type: stream.EventTextDelta, StopReason: s.stopReason, Model: s.model} + } + return stream.Event{Type: stream.EventTextDelta, Text: "hi", Model: s.model} +} + +func (s *blockingStream) Err() error { return nil } +func (s *blockingStream) Close() error { return nil } + +func TestEngine_ConcurrentSubmitAndSetters(t *testing.T) { + release := make(chan struct{}) + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{newBlockingStream(release, "mock-model")}, + } + e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()}) + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + _, _ = e.Submit(context.Background(), "go", nil) + }() + + go func() { + defer wg.Done() + for i := 0; i < 100; i++ { + e.InjectMessage(message.NewUserText("noise")) + _ = e.History() + _ = e.Usage() + } + close(release) + }() + + wg.Wait() +}