fix(engine): guard mutable state with a mutex

Engine.history, usage, activatedTools, modelCaps, turnOpts, and
cfg.Provider/Model are now mutated and read under e.mu. The lock is
released across blocking provider.Stream calls so external setters
(SetProvider, SetHistory, InjectMessage, etc.) can interleave.

History() now returns a copy. Snapshot helpers (latestUserPrompt,
historySnapshot, snapshotTurnOpts, etc.) replace the unsynchronised
reads scattered through runLoop and buildRequest.

Closes audit finding H4. Adds a race regression test that fails under
-race before the fix and passes after.
This commit is contained in:
2026-05-19 16:18:17 +02:00
parent 153a7e3cf9
commit b36ef564ab
3 changed files with 235 additions and 64 deletions
+123 -15
View File
@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log/slog"
"sync"
gnomactx "somegit.dev/Owlibou/gnoma/internal/context"
"somegit.dev/Owlibou/gnoma/internal/hook"
@@ -59,21 +60,22 @@ type TurnOptions struct {
}
// Engine orchestrates the conversation.
//
// Mutable state (history, usage, activatedTools, modelCaps, turnOpts, and the
// hot fields of cfg — Provider/Model) is guarded by mu. The lock is released
// across blocking provider.Stream calls so external setters can interleave.
type Engine struct {
mu sync.Mutex
cfg Config
history []message.Message
usage message.Usage
logger *slog.Logger
// Cached model capabilities, resolved lazily
modelCaps *provider.Capabilities
modelCapsFor string // model ID the cached caps are for
modelCapsFor string
// Deferred tool loading: tools with ShouldDefer() are excluded until
// the model requests them. Activated on first use.
activatedTools map[string]bool
// Per-turn options, set for the duration of SubmitWithOptions.
turnOpts TurnOptions
}
@@ -145,18 +147,20 @@ func New(cfg Config) (*Engine, error) {
// resolveCapabilities returns the capabilities for the active model.
// Caches the result — re-resolves if the model changes.
func (e *Engine) resolveCapabilities(ctx context.Context) *provider.Capabilities {
e.mu.Lock()
model := e.cfg.Model
if model == "" {
model = e.cfg.Provider.DefaultModel()
}
// Return cached if same model
if e.modelCaps != nil && e.modelCapsFor == model {
return e.modelCaps
caps := e.modelCaps
e.mu.Unlock()
return caps
}
prov := e.cfg.Provider
e.mu.Unlock()
// Query provider for model list
models, err := e.cfg.Provider.Models(ctx)
models, err := prov.Models(ctx)
if err != nil {
e.logger.Debug("failed to fetch model capabilities", "error", err)
return nil
@@ -164,9 +168,12 @@ func (e *Engine) resolveCapabilities(ctx context.Context) *provider.Capabilities
for _, m := range models {
if m.ID == model {
e.mu.Lock()
e.modelCaps = &m.Capabilities
e.modelCapsFor = model
return e.modelCaps
caps := e.modelCaps
e.mu.Unlock()
return caps
}
}
@@ -174,9 +181,13 @@ func (e *Engine) resolveCapabilities(ctx context.Context) *provider.Capabilities
return nil
}
// History returns the full conversation.
// History returns a snapshot copy of the conversation.
func (e *Engine) History() []message.Message {
return e.history
e.mu.Lock()
defer e.mu.Unlock()
out := make([]message.Message, len(e.history))
copy(out, e.history)
return out
}
// ContextWindow returns the context window (may be nil).
@@ -188,7 +199,9 @@ func (e *Engine) ContextWindow() *gnomactx.Window {
// Used for system notifications (permission mode changes, incognito toggles) that
// the model should see as context in subsequent turns.
func (e *Engine) InjectMessage(msg message.Message) {
e.mu.Lock()
e.history = append(e.history, msg)
e.mu.Unlock()
if e.cfg.Context != nil {
e.cfg.Context.AppendMessage(msg)
}
@@ -196,23 +209,31 @@ func (e *Engine) InjectMessage(msg message.Message) {
// Usage returns cumulative token usage.
func (e *Engine) Usage() message.Usage {
e.mu.Lock()
defer e.mu.Unlock()
return e.usage
}
// SetProvider swaps the active provider (for dynamic switching).
func (e *Engine) SetProvider(p provider.Provider) {
e.mu.Lock()
e.cfg.Provider = p
e.mu.Unlock()
}
// SetModel changes the model within the current provider.
func (e *Engine) SetModel(model string) {
e.mu.Lock()
e.cfg.Model = model
e.mu.Unlock()
}
// SetHistory replaces the conversation history (for session restore).
// Also syncs the context window and re-estimates the tracker's token count.
func (e *Engine) SetHistory(msgs []message.Message) {
e.mu.Lock()
e.history = msgs
e.mu.Unlock()
if e.cfg.Context != nil {
e.cfg.Context.SetMessages(msgs)
e.cfg.Context.Tracker().Set(e.cfg.Context.Tracker().CountMessages(msgs))
@@ -221,12 +242,16 @@ func (e *Engine) SetHistory(msgs []message.Message) {
// SetUsage sets cumulative token usage (for session restore).
func (e *Engine) SetUsage(u message.Usage) {
e.mu.Lock()
e.usage = u
e.mu.Unlock()
}
// SetActivatedTools restores the set of activated deferred tools (for session restore).
func (e *Engine) SetActivatedTools(tools map[string]bool) {
e.mu.Lock()
e.activatedTools = tools
e.mu.Unlock()
}
// classify returns a Task for the given prompt using the configured classifier.
@@ -236,7 +261,11 @@ func (e *Engine) classify(ctx context.Context, prompt string) router.Task {
if cls == nil {
cls = router.HeuristicClassifier{}
}
task, err := cls.Classify(ctx, prompt, e.history)
e.mu.Lock()
histSnap := make([]message.Message, len(e.history))
copy(histSnap, e.history)
e.mu.Unlock()
task, err := cls.Classify(ctx, prompt, histSnap)
if err != nil {
e.logger.Debug("classifier error, falling back to heuristic", "error", err)
return router.ClassifyTask(prompt)
@@ -244,12 +273,91 @@ func (e *Engine) classify(ctx context.Context, prompt string) router.Task {
return task
}
// latestUserPrompt returns the text of the most recent user message.
func (e *Engine) latestUserPrompt() string {
e.mu.Lock()
defer e.mu.Unlock()
for i := len(e.history) - 1; i >= 0; i-- {
if e.history[i].Role == message.RoleUser {
return e.history[i].TextContent()
}
}
return ""
}
// historySnapshot returns a copy of the current history slice.
func (e *Engine) historySnapshot() []message.Message {
e.mu.Lock()
defer e.mu.Unlock()
out := make([]message.Message, len(e.history))
copy(out, e.history)
return out
}
// appendHistory appends a message under the lock.
func (e *Engine) appendHistory(msg message.Message) {
e.mu.Lock()
e.history = append(e.history, msg)
e.mu.Unlock()
}
// replaceHistory swaps the history slice (used after context compaction).
func (e *Engine) replaceHistory(msgs []message.Message) {
e.mu.Lock()
e.history = msgs
e.mu.Unlock()
}
// addUsage accumulates token usage.
func (e *Engine) addUsage(u message.Usage) {
e.mu.Lock()
e.usage.Add(u)
e.mu.Unlock()
}
// activeProvider returns the current provider under lock.
func (e *Engine) activeProvider() provider.Provider {
e.mu.Lock()
defer e.mu.Unlock()
return e.cfg.Provider
}
// activeModel returns the configured model name under lock.
func (e *Engine) activeModel() string {
e.mu.Lock()
defer e.mu.Unlock()
return e.cfg.Model
}
// snapshotTurnOpts returns a copy of the current per-turn options.
func (e *Engine) snapshotTurnOpts() TurnOptions {
e.mu.Lock()
defer e.mu.Unlock()
return e.turnOpts
}
// markToolActivated records that a deferred tool has been requested.
func (e *Engine) markToolActivated(name string) {
e.mu.Lock()
e.activatedTools[name] = true
e.mu.Unlock()
}
// isToolActivated reports whether a deferred tool has been activated.
func (e *Engine) isToolActivated(name string) bool {
e.mu.Lock()
defer e.mu.Unlock()
return e.activatedTools[name]
}
// Reset clears conversation history and usage.
func (e *Engine) Reset() {
e.mu.Lock()
e.history = nil
e.usage = message.Usage{}
e.activatedTools = make(map[string]bool)
e.mu.Unlock()
if e.cfg.Context != nil {
e.cfg.Context.Reset()
}
e.activatedTools = make(map[string]bool)
}
+35 -49
View File
@@ -28,11 +28,17 @@ func (e *Engine) Submit(ctx context.Context, input string, cb Callback) (*Turn,
// SubmitWithOptions is like Submit but applies per-turn overrides (e.g. ToolChoice).
func (e *Engine) SubmitWithOptions(ctx context.Context, input string, opts TurnOptions, cb Callback) (*Turn, error) {
e.mu.Lock()
e.turnOpts = opts
defer func() { e.turnOpts = TurnOptions{} }()
userMsg := message.NewUserText(input)
e.history = append(e.history, userMsg)
e.mu.Unlock()
defer func() {
e.mu.Lock()
e.turnOpts = TurnOptions{}
e.mu.Unlock()
}()
if e.cfg.Context != nil {
e.cfg.Context.AppendMessage(userMsg)
}
@@ -42,7 +48,9 @@ func (e *Engine) SubmitWithOptions(ctx context.Context, input string, opts TurnO
// SubmitMessages is like Submit but accepts pre-built messages.
func (e *Engine) SubmitMessages(ctx context.Context, msgs []message.Message, cb Callback) (*Turn, error) {
e.mu.Lock()
e.history = append(e.history, msgs...)
e.mu.Unlock()
if e.cfg.Context != nil {
for _, m := range msgs {
e.cfg.Context.AppendMessage(m)
@@ -89,14 +97,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
var decision router.RoutingDecision
if e.cfg.Router != nil {
// Classify task from the latest user message
prompt := ""
for i := len(e.history) - 1; i >= 0; i-- {
if e.history[i].Role == message.RoleUser {
prompt = e.history[i].TextContent()
break
}
}
prompt := e.latestUserPrompt()
task := e.classify(ctx, prompt)
if e.cfg.Context != nil {
task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt))
@@ -131,14 +132,15 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
}
}
} else {
prov := e.activeProvider()
e.logger.Debug("streaming request",
"provider", e.cfg.Provider.Name(),
"provider", prov.Name(),
"model", req.Model,
"messages", len(req.Messages),
"tools", len(req.Tools),
"round", turn.Rounds,
)
s, err = e.cfg.Provider.Stream(ctx, req)
s, err = prov.Stream(ctx, req)
}
if err != nil {
var failedArms []router.ArmID
@@ -161,13 +163,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
// Retry on transient errors (429, 5xx) with exponential backoff
s, err = e.retryOnTransient(ctx, err, skipDelay, func() (stream.Stream, error) {
if e.cfg.Router != nil {
prompt := ""
for i := len(e.history) - 1; i >= 0; i-- {
if e.history[i].Role == message.RoleUser {
prompt = e.history[i].TextContent()
break
}
}
prompt := e.latestUserPrompt()
task := e.classify(ctx, prompt)
if e.cfg.Context != nil {
task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt))
@@ -192,7 +188,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
}
return s, err
}
return e.cfg.Provider.Stream(ctx, req)
return e.activeProvider().Stream(ctx, req)
})
if err != nil {
// Try reactive compaction on 413 (request too large)
@@ -266,8 +262,8 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
turn.Usage.Add(resp.Usage)
turn.Messages = append(turn.Messages, resp.Message)
e.history = append(e.history, resp.Message)
e.usage.Add(resp.Usage)
e.appendHistory(resp.Message)
e.addUsage(resp.Usage)
// Track in context window and check for compaction
if e.cfg.Context != nil {
@@ -282,8 +278,9 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
if compacted, err := e.cfg.Context.CompactIfNeeded(); err != nil {
e.logger.Error("context compaction failed", "error", err)
} else if compacted {
e.history = e.cfg.Context.Messages()
e.logger.Info("context compacted", "messages", len(e.history))
compactedMsgs := e.cfg.Context.Messages()
e.replaceHistory(compactedMsgs)
e.logger.Info("context compacted", "messages", len(compactedMsgs))
}
}
@@ -304,7 +301,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
// Model hit its output token budget mid-response. Inject a continue prompt
// and re-query so the response is completed rather than silently truncated.
contMsg := message.NewUserText("Continue from where you left off.")
e.history = append(e.history, contMsg)
e.appendHistory(contMsg)
if e.cfg.Context != nil {
e.cfg.Context.AppendMessage(contMsg)
}
@@ -319,7 +316,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
}
toolMsg := message.NewToolResults(results...)
turn.Messages = append(turn.Messages, toolMsg)
e.history = append(e.history, toolMsg)
e.appendHistory(toolMsg)
if e.cfg.Context != nil {
e.cfg.Context.AppendMessage(toolMsg)
}
@@ -336,7 +333,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
func (e *Engine) buildRequest(ctx context.Context) provider.Request {
// Use AllMessages (prefix + history) if context window manages prefix docs
messages := e.history
messages := e.historySnapshot()
if e.cfg.Context != nil {
messages = e.cfg.Context.AllMessages()
}
@@ -351,11 +348,12 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request {
systemPrompt = e.cfg.Firewall.ScanSystemPrompt(systemPrompt)
}
turnOpts := e.snapshotTurnOpts()
req := provider.Request{
Model: e.cfg.Model,
Model: e.activeModel(),
SystemPrompt: systemPrompt,
Messages: messages,
ToolChoice: e.turnOpts.ToolChoice,
ToolChoice: turnOpts.ToolChoice,
Temperature: e.cfg.Temperature,
}
@@ -371,10 +369,10 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request {
includeTools = caps == nil || caps.ToolUse
}
if includeTools {
allowed := e.turnOpts.AllowedTools
allowed := turnOpts.AllowedTools
for _, t := range e.cfg.Tools.All() {
// Skip deferred tools until the model requests them
if dt, ok := t.(tool.DeferrableTool); ok && dt.ShouldDefer() && !e.activatedTools[t.Name()] {
if dt, ok := t.(tool.DeferrableTool); ok && dt.ShouldDefer() && !e.isToolActivated(t.Name()) {
continue
}
// Filter to allowed tools when a restrict list is set
@@ -399,13 +397,7 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request {
// Inject coordinator guidance for orchestration tasks
if e.cfg.Router != nil {
prompt := ""
for i := len(e.history) - 1; i >= 0; i-- {
if e.history[i].Role == message.RoleUser {
prompt = e.history[i].TextContent()
break
}
}
prompt := e.latestUserPrompt()
if e.classify(ctx, prompt).Type == router.TaskOrchestration {
req.SystemPrompt = coordinatorPrompt() + "\n\n" + req.SystemPrompt
}
@@ -453,7 +445,7 @@ func (e *Engine) executeTools(ctx context.Context, calls []message.ToolCall, cb
if ok {
// Activate deferred tools on first use
if dt, isDeferrable := t.(tool.DeferrableTool); isDeferrable && dt.ShouldDefer() {
e.activatedTools[call.Name] = true
e.markToolActivated(call.Name)
}
}
if !ok {
@@ -536,7 +528,7 @@ func (e *Engine) executeSingleTool(ctx context.Context, call message.ToolCall, t
}
// Path restriction: deny bash and validate fs tool paths against AllowedPaths.
if denied, blocked := checkPathRestriction(call, t, args, e.turnOpts.AllowedPaths); blocked {
if denied, blocked := checkPathRestriction(call, t, args, e.snapshotTurnOpts().AllowedPaths); blocked {
return denied
}
@@ -615,17 +607,11 @@ func (e *Engine) handleRequestTooLarge(ctx context.Context, origErr error, req p
return nil, origErr
}
e.history = e.cfg.Context.Messages()
e.replaceHistory(e.cfg.Context.Messages())
req = e.buildRequest(ctx)
if e.cfg.Router != nil {
prompt := ""
for i := len(e.history) - 1; i >= 0; i-- {
if e.history[i].Role == message.RoleUser {
prompt = e.history[i].TextContent()
break
}
}
prompt := e.latestUserPrompt()
task := e.classify(ctx, prompt)
if e.cfg.Context != nil {
task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt))
@@ -635,7 +621,7 @@ func (e *Engine) handleRequestTooLarge(ctx context.Context, origErr error, req p
s, _, err := e.cfg.Router.Stream(ctx, task, req)
return s, err
}
return e.cfg.Provider.Stream(ctx, req)
return e.activeProvider().Stream(ctx, req)
}
// retryOnTransient retries the stream call on 429/5xx with exponential backoff.
+77
View File
@@ -0,0 +1,77 @@
package engine
import (
"context"
"sync"
"testing"
"somegit.dev/Owlibou/gnoma/internal/message"
"somegit.dev/Owlibou/gnoma/internal/stream"
"somegit.dev/Owlibou/gnoma/internal/tool"
)
// blockingStream emits one text delta, then blocks Next() until release is closed,
// then emits the stop event. Lets a test interleave Submit with concurrent setters.
type blockingStream struct {
release chan struct{}
emitted bool
released bool
stopReason message.StopReason
model string
}
func newBlockingStream(release chan struct{}, model string) *blockingStream {
return &blockingStream{release: release, model: model, stopReason: message.StopEndTurn}
}
func (s *blockingStream) Next() bool {
if !s.emitted {
s.emitted = true
return true
}
if !s.released {
<-s.release
s.released = true
return true
}
return false
}
func (s *blockingStream) Current() stream.Event {
if s.released {
return stream.Event{Type: stream.EventTextDelta, StopReason: s.stopReason, Model: s.model}
}
return stream.Event{Type: stream.EventTextDelta, Text: "hi", Model: s.model}
}
func (s *blockingStream) Err() error { return nil }
func (s *blockingStream) Close() error { return nil }
func TestEngine_ConcurrentSubmitAndSetters(t *testing.T) {
release := make(chan struct{})
mp := &mockProvider{
name: "test",
streams: []stream.Stream{newBlockingStream(release, "mock-model")},
}
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
_, _ = e.Submit(context.Background(), "go", nil)
}()
go func() {
defer wg.Done()
for i := 0; i < 100; i++ {
e.InjectMessage(message.NewUserText("noise"))
_ = e.History()
_ = e.Usage()
}
close(release)
}()
wg.Wait()
}