fix(engine): guard mutable state with a mutex
Engine.history, usage, activatedTools, modelCaps, turnOpts, and cfg.Provider/Model are now mutated and read under e.mu. The lock is released across blocking provider.Stream calls so external setters (SetProvider, SetHistory, InjectMessage, etc.) can interleave. History() now returns a copy. Snapshot helpers (latestUserPrompt, historySnapshot, snapshotTurnOpts, etc.) replace the unsynchronised reads scattered through runLoop and buildRequest. Closes audit finding H4. Adds a race regression test that fails under -race before the fix and passes after.
This commit is contained in:
+123
-15
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"sync"
|
||||
|
||||
gnomactx "somegit.dev/Owlibou/gnoma/internal/context"
|
||||
"somegit.dev/Owlibou/gnoma/internal/hook"
|
||||
@@ -59,21 +60,22 @@ type TurnOptions struct {
|
||||
}
|
||||
|
||||
// Engine orchestrates the conversation.
|
||||
//
|
||||
// Mutable state (history, usage, activatedTools, modelCaps, turnOpts, and the
|
||||
// hot fields of cfg — Provider/Model) is guarded by mu. The lock is released
|
||||
// across blocking provider.Stream calls so external setters can interleave.
|
||||
type Engine struct {
|
||||
mu sync.Mutex
|
||||
cfg Config
|
||||
history []message.Message
|
||||
usage message.Usage
|
||||
logger *slog.Logger
|
||||
|
||||
// Cached model capabilities, resolved lazily
|
||||
modelCaps *provider.Capabilities
|
||||
modelCapsFor string // model ID the cached caps are for
|
||||
modelCapsFor string
|
||||
|
||||
// Deferred tool loading: tools with ShouldDefer() are excluded until
|
||||
// the model requests them. Activated on first use.
|
||||
activatedTools map[string]bool
|
||||
|
||||
// Per-turn options, set for the duration of SubmitWithOptions.
|
||||
turnOpts TurnOptions
|
||||
}
|
||||
|
||||
@@ -145,18 +147,20 @@ func New(cfg Config) (*Engine, error) {
|
||||
// resolveCapabilities returns the capabilities for the active model.
|
||||
// Caches the result — re-resolves if the model changes.
|
||||
func (e *Engine) resolveCapabilities(ctx context.Context) *provider.Capabilities {
|
||||
e.mu.Lock()
|
||||
model := e.cfg.Model
|
||||
if model == "" {
|
||||
model = e.cfg.Provider.DefaultModel()
|
||||
}
|
||||
|
||||
// Return cached if same model
|
||||
if e.modelCaps != nil && e.modelCapsFor == model {
|
||||
return e.modelCaps
|
||||
caps := e.modelCaps
|
||||
e.mu.Unlock()
|
||||
return caps
|
||||
}
|
||||
prov := e.cfg.Provider
|
||||
e.mu.Unlock()
|
||||
|
||||
// Query provider for model list
|
||||
models, err := e.cfg.Provider.Models(ctx)
|
||||
models, err := prov.Models(ctx)
|
||||
if err != nil {
|
||||
e.logger.Debug("failed to fetch model capabilities", "error", err)
|
||||
return nil
|
||||
@@ -164,9 +168,12 @@ func (e *Engine) resolveCapabilities(ctx context.Context) *provider.Capabilities
|
||||
|
||||
for _, m := range models {
|
||||
if m.ID == model {
|
||||
e.mu.Lock()
|
||||
e.modelCaps = &m.Capabilities
|
||||
e.modelCapsFor = model
|
||||
return e.modelCaps
|
||||
caps := e.modelCaps
|
||||
e.mu.Unlock()
|
||||
return caps
|
||||
}
|
||||
}
|
||||
|
||||
@@ -174,9 +181,13 @@ func (e *Engine) resolveCapabilities(ctx context.Context) *provider.Capabilities
|
||||
return nil
|
||||
}
|
||||
|
||||
// History returns the full conversation.
|
||||
// History returns a snapshot copy of the conversation.
|
||||
func (e *Engine) History() []message.Message {
|
||||
return e.history
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([]message.Message, len(e.history))
|
||||
copy(out, e.history)
|
||||
return out
|
||||
}
|
||||
|
||||
// ContextWindow returns the context window (may be nil).
|
||||
@@ -188,7 +199,9 @@ func (e *Engine) ContextWindow() *gnomactx.Window {
|
||||
// Used for system notifications (permission mode changes, incognito toggles) that
|
||||
// the model should see as context in subsequent turns.
|
||||
func (e *Engine) InjectMessage(msg message.Message) {
|
||||
e.mu.Lock()
|
||||
e.history = append(e.history, msg)
|
||||
e.mu.Unlock()
|
||||
if e.cfg.Context != nil {
|
||||
e.cfg.Context.AppendMessage(msg)
|
||||
}
|
||||
@@ -196,23 +209,31 @@ func (e *Engine) InjectMessage(msg message.Message) {
|
||||
|
||||
// Usage returns cumulative token usage.
|
||||
func (e *Engine) Usage() message.Usage {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.usage
|
||||
}
|
||||
|
||||
// SetProvider swaps the active provider (for dynamic switching).
|
||||
func (e *Engine) SetProvider(p provider.Provider) {
|
||||
e.mu.Lock()
|
||||
e.cfg.Provider = p
|
||||
e.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetModel changes the model within the current provider.
|
||||
func (e *Engine) SetModel(model string) {
|
||||
e.mu.Lock()
|
||||
e.cfg.Model = model
|
||||
e.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetHistory replaces the conversation history (for session restore).
|
||||
// Also syncs the context window and re-estimates the tracker's token count.
|
||||
func (e *Engine) SetHistory(msgs []message.Message) {
|
||||
e.mu.Lock()
|
||||
e.history = msgs
|
||||
e.mu.Unlock()
|
||||
if e.cfg.Context != nil {
|
||||
e.cfg.Context.SetMessages(msgs)
|
||||
e.cfg.Context.Tracker().Set(e.cfg.Context.Tracker().CountMessages(msgs))
|
||||
@@ -221,12 +242,16 @@ func (e *Engine) SetHistory(msgs []message.Message) {
|
||||
|
||||
// SetUsage sets cumulative token usage (for session restore).
|
||||
func (e *Engine) SetUsage(u message.Usage) {
|
||||
e.mu.Lock()
|
||||
e.usage = u
|
||||
e.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetActivatedTools restores the set of activated deferred tools (for session restore).
|
||||
func (e *Engine) SetActivatedTools(tools map[string]bool) {
|
||||
e.mu.Lock()
|
||||
e.activatedTools = tools
|
||||
e.mu.Unlock()
|
||||
}
|
||||
|
||||
// classify returns a Task for the given prompt using the configured classifier.
|
||||
@@ -236,7 +261,11 @@ func (e *Engine) classify(ctx context.Context, prompt string) router.Task {
|
||||
if cls == nil {
|
||||
cls = router.HeuristicClassifier{}
|
||||
}
|
||||
task, err := cls.Classify(ctx, prompt, e.history)
|
||||
e.mu.Lock()
|
||||
histSnap := make([]message.Message, len(e.history))
|
||||
copy(histSnap, e.history)
|
||||
e.mu.Unlock()
|
||||
task, err := cls.Classify(ctx, prompt, histSnap)
|
||||
if err != nil {
|
||||
e.logger.Debug("classifier error, falling back to heuristic", "error", err)
|
||||
return router.ClassifyTask(prompt)
|
||||
@@ -244,12 +273,91 @@ func (e *Engine) classify(ctx context.Context, prompt string) router.Task {
|
||||
return task
|
||||
}
|
||||
|
||||
// latestUserPrompt returns the text of the most recent user message.
|
||||
func (e *Engine) latestUserPrompt() string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
for i := len(e.history) - 1; i >= 0; i-- {
|
||||
if e.history[i].Role == message.RoleUser {
|
||||
return e.history[i].TextContent()
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// historySnapshot returns a copy of the current history slice.
|
||||
func (e *Engine) historySnapshot() []message.Message {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([]message.Message, len(e.history))
|
||||
copy(out, e.history)
|
||||
return out
|
||||
}
|
||||
|
||||
// appendHistory appends a message under the lock.
|
||||
func (e *Engine) appendHistory(msg message.Message) {
|
||||
e.mu.Lock()
|
||||
e.history = append(e.history, msg)
|
||||
e.mu.Unlock()
|
||||
}
|
||||
|
||||
// replaceHistory swaps the history slice (used after context compaction).
|
||||
func (e *Engine) replaceHistory(msgs []message.Message) {
|
||||
e.mu.Lock()
|
||||
e.history = msgs
|
||||
e.mu.Unlock()
|
||||
}
|
||||
|
||||
// addUsage accumulates token usage.
|
||||
func (e *Engine) addUsage(u message.Usage) {
|
||||
e.mu.Lock()
|
||||
e.usage.Add(u)
|
||||
e.mu.Unlock()
|
||||
}
|
||||
|
||||
// activeProvider returns the current provider under lock.
|
||||
func (e *Engine) activeProvider() provider.Provider {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.cfg.Provider
|
||||
}
|
||||
|
||||
// activeModel returns the configured model name under lock.
|
||||
func (e *Engine) activeModel() string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.cfg.Model
|
||||
}
|
||||
|
||||
// snapshotTurnOpts returns a copy of the current per-turn options.
|
||||
func (e *Engine) snapshotTurnOpts() TurnOptions {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.turnOpts
|
||||
}
|
||||
|
||||
// markToolActivated records that a deferred tool has been requested.
|
||||
func (e *Engine) markToolActivated(name string) {
|
||||
e.mu.Lock()
|
||||
e.activatedTools[name] = true
|
||||
e.mu.Unlock()
|
||||
}
|
||||
|
||||
// isToolActivated reports whether a deferred tool has been activated.
|
||||
func (e *Engine) isToolActivated(name string) bool {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.activatedTools[name]
|
||||
}
|
||||
|
||||
// Reset clears conversation history and usage.
|
||||
func (e *Engine) Reset() {
|
||||
e.mu.Lock()
|
||||
e.history = nil
|
||||
e.usage = message.Usage{}
|
||||
e.activatedTools = make(map[string]bool)
|
||||
e.mu.Unlock()
|
||||
if e.cfg.Context != nil {
|
||||
e.cfg.Context.Reset()
|
||||
}
|
||||
e.activatedTools = make(map[string]bool)
|
||||
}
|
||||
|
||||
+35
-49
@@ -28,11 +28,17 @@ func (e *Engine) Submit(ctx context.Context, input string, cb Callback) (*Turn,
|
||||
|
||||
// SubmitWithOptions is like Submit but applies per-turn overrides (e.g. ToolChoice).
|
||||
func (e *Engine) SubmitWithOptions(ctx context.Context, input string, opts TurnOptions, cb Callback) (*Turn, error) {
|
||||
e.mu.Lock()
|
||||
e.turnOpts = opts
|
||||
defer func() { e.turnOpts = TurnOptions{} }()
|
||||
|
||||
userMsg := message.NewUserText(input)
|
||||
e.history = append(e.history, userMsg)
|
||||
e.mu.Unlock()
|
||||
defer func() {
|
||||
e.mu.Lock()
|
||||
e.turnOpts = TurnOptions{}
|
||||
e.mu.Unlock()
|
||||
}()
|
||||
|
||||
if e.cfg.Context != nil {
|
||||
e.cfg.Context.AppendMessage(userMsg)
|
||||
}
|
||||
@@ -42,7 +48,9 @@ func (e *Engine) SubmitWithOptions(ctx context.Context, input string, opts TurnO
|
||||
|
||||
// SubmitMessages is like Submit but accepts pre-built messages.
|
||||
func (e *Engine) SubmitMessages(ctx context.Context, msgs []message.Message, cb Callback) (*Turn, error) {
|
||||
e.mu.Lock()
|
||||
e.history = append(e.history, msgs...)
|
||||
e.mu.Unlock()
|
||||
if e.cfg.Context != nil {
|
||||
for _, m := range msgs {
|
||||
e.cfg.Context.AppendMessage(m)
|
||||
@@ -89,14 +97,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
var decision router.RoutingDecision
|
||||
|
||||
if e.cfg.Router != nil {
|
||||
// Classify task from the latest user message
|
||||
prompt := ""
|
||||
for i := len(e.history) - 1; i >= 0; i-- {
|
||||
if e.history[i].Role == message.RoleUser {
|
||||
prompt = e.history[i].TextContent()
|
||||
break
|
||||
}
|
||||
}
|
||||
prompt := e.latestUserPrompt()
|
||||
task := e.classify(ctx, prompt)
|
||||
if e.cfg.Context != nil {
|
||||
task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt))
|
||||
@@ -131,14 +132,15 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
prov := e.activeProvider()
|
||||
e.logger.Debug("streaming request",
|
||||
"provider", e.cfg.Provider.Name(),
|
||||
"provider", prov.Name(),
|
||||
"model", req.Model,
|
||||
"messages", len(req.Messages),
|
||||
"tools", len(req.Tools),
|
||||
"round", turn.Rounds,
|
||||
)
|
||||
s, err = e.cfg.Provider.Stream(ctx, req)
|
||||
s, err = prov.Stream(ctx, req)
|
||||
}
|
||||
if err != nil {
|
||||
var failedArms []router.ArmID
|
||||
@@ -161,13 +163,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
// Retry on transient errors (429, 5xx) with exponential backoff
|
||||
s, err = e.retryOnTransient(ctx, err, skipDelay, func() (stream.Stream, error) {
|
||||
if e.cfg.Router != nil {
|
||||
prompt := ""
|
||||
for i := len(e.history) - 1; i >= 0; i-- {
|
||||
if e.history[i].Role == message.RoleUser {
|
||||
prompt = e.history[i].TextContent()
|
||||
break
|
||||
}
|
||||
}
|
||||
prompt := e.latestUserPrompt()
|
||||
task := e.classify(ctx, prompt)
|
||||
if e.cfg.Context != nil {
|
||||
task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt))
|
||||
@@ -192,7 +188,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
}
|
||||
return s, err
|
||||
}
|
||||
return e.cfg.Provider.Stream(ctx, req)
|
||||
return e.activeProvider().Stream(ctx, req)
|
||||
})
|
||||
if err != nil {
|
||||
// Try reactive compaction on 413 (request too large)
|
||||
@@ -266,8 +262,8 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
|
||||
turn.Usage.Add(resp.Usage)
|
||||
turn.Messages = append(turn.Messages, resp.Message)
|
||||
e.history = append(e.history, resp.Message)
|
||||
e.usage.Add(resp.Usage)
|
||||
e.appendHistory(resp.Message)
|
||||
e.addUsage(resp.Usage)
|
||||
|
||||
// Track in context window and check for compaction
|
||||
if e.cfg.Context != nil {
|
||||
@@ -282,8 +278,9 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
if compacted, err := e.cfg.Context.CompactIfNeeded(); err != nil {
|
||||
e.logger.Error("context compaction failed", "error", err)
|
||||
} else if compacted {
|
||||
e.history = e.cfg.Context.Messages()
|
||||
e.logger.Info("context compacted", "messages", len(e.history))
|
||||
compactedMsgs := e.cfg.Context.Messages()
|
||||
e.replaceHistory(compactedMsgs)
|
||||
e.logger.Info("context compacted", "messages", len(compactedMsgs))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -304,7 +301,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
// Model hit its output token budget mid-response. Inject a continue prompt
|
||||
// and re-query so the response is completed rather than silently truncated.
|
||||
contMsg := message.NewUserText("Continue from where you left off.")
|
||||
e.history = append(e.history, contMsg)
|
||||
e.appendHistory(contMsg)
|
||||
if e.cfg.Context != nil {
|
||||
e.cfg.Context.AppendMessage(contMsg)
|
||||
}
|
||||
@@ -319,7 +316,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
}
|
||||
toolMsg := message.NewToolResults(results...)
|
||||
turn.Messages = append(turn.Messages, toolMsg)
|
||||
e.history = append(e.history, toolMsg)
|
||||
e.appendHistory(toolMsg)
|
||||
if e.cfg.Context != nil {
|
||||
e.cfg.Context.AppendMessage(toolMsg)
|
||||
}
|
||||
@@ -336,7 +333,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
|
||||
func (e *Engine) buildRequest(ctx context.Context) provider.Request {
|
||||
// Use AllMessages (prefix + history) if context window manages prefix docs
|
||||
messages := e.history
|
||||
messages := e.historySnapshot()
|
||||
if e.cfg.Context != nil {
|
||||
messages = e.cfg.Context.AllMessages()
|
||||
}
|
||||
@@ -351,11 +348,12 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request {
|
||||
systemPrompt = e.cfg.Firewall.ScanSystemPrompt(systemPrompt)
|
||||
}
|
||||
|
||||
turnOpts := e.snapshotTurnOpts()
|
||||
req := provider.Request{
|
||||
Model: e.cfg.Model,
|
||||
Model: e.activeModel(),
|
||||
SystemPrompt: systemPrompt,
|
||||
Messages: messages,
|
||||
ToolChoice: e.turnOpts.ToolChoice,
|
||||
ToolChoice: turnOpts.ToolChoice,
|
||||
Temperature: e.cfg.Temperature,
|
||||
}
|
||||
|
||||
@@ -371,10 +369,10 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request {
|
||||
includeTools = caps == nil || caps.ToolUse
|
||||
}
|
||||
if includeTools {
|
||||
allowed := e.turnOpts.AllowedTools
|
||||
allowed := turnOpts.AllowedTools
|
||||
for _, t := range e.cfg.Tools.All() {
|
||||
// Skip deferred tools until the model requests them
|
||||
if dt, ok := t.(tool.DeferrableTool); ok && dt.ShouldDefer() && !e.activatedTools[t.Name()] {
|
||||
if dt, ok := t.(tool.DeferrableTool); ok && dt.ShouldDefer() && !e.isToolActivated(t.Name()) {
|
||||
continue
|
||||
}
|
||||
// Filter to allowed tools when a restrict list is set
|
||||
@@ -399,13 +397,7 @@ func (e *Engine) buildRequest(ctx context.Context) provider.Request {
|
||||
|
||||
// Inject coordinator guidance for orchestration tasks
|
||||
if e.cfg.Router != nil {
|
||||
prompt := ""
|
||||
for i := len(e.history) - 1; i >= 0; i-- {
|
||||
if e.history[i].Role == message.RoleUser {
|
||||
prompt = e.history[i].TextContent()
|
||||
break
|
||||
}
|
||||
}
|
||||
prompt := e.latestUserPrompt()
|
||||
if e.classify(ctx, prompt).Type == router.TaskOrchestration {
|
||||
req.SystemPrompt = coordinatorPrompt() + "\n\n" + req.SystemPrompt
|
||||
}
|
||||
@@ -453,7 +445,7 @@ func (e *Engine) executeTools(ctx context.Context, calls []message.ToolCall, cb
|
||||
if ok {
|
||||
// Activate deferred tools on first use
|
||||
if dt, isDeferrable := t.(tool.DeferrableTool); isDeferrable && dt.ShouldDefer() {
|
||||
e.activatedTools[call.Name] = true
|
||||
e.markToolActivated(call.Name)
|
||||
}
|
||||
}
|
||||
if !ok {
|
||||
@@ -536,7 +528,7 @@ func (e *Engine) executeSingleTool(ctx context.Context, call message.ToolCall, t
|
||||
}
|
||||
|
||||
// Path restriction: deny bash and validate fs tool paths against AllowedPaths.
|
||||
if denied, blocked := checkPathRestriction(call, t, args, e.turnOpts.AllowedPaths); blocked {
|
||||
if denied, blocked := checkPathRestriction(call, t, args, e.snapshotTurnOpts().AllowedPaths); blocked {
|
||||
return denied
|
||||
}
|
||||
|
||||
@@ -615,17 +607,11 @@ func (e *Engine) handleRequestTooLarge(ctx context.Context, origErr error, req p
|
||||
return nil, origErr
|
||||
}
|
||||
|
||||
e.history = e.cfg.Context.Messages()
|
||||
e.replaceHistory(e.cfg.Context.Messages())
|
||||
req = e.buildRequest(ctx)
|
||||
|
||||
if e.cfg.Router != nil {
|
||||
prompt := ""
|
||||
for i := len(e.history) - 1; i >= 0; i-- {
|
||||
if e.history[i].Role == message.RoleUser {
|
||||
prompt = e.history[i].TextContent()
|
||||
break
|
||||
}
|
||||
}
|
||||
prompt := e.latestUserPrompt()
|
||||
task := e.classify(ctx, prompt)
|
||||
if e.cfg.Context != nil {
|
||||
task.EstimatedTokens = int(e.cfg.Context.Tracker().CountTokens(prompt))
|
||||
@@ -635,7 +621,7 @@ func (e *Engine) handleRequestTooLarge(ctx context.Context, origErr error, req p
|
||||
s, _, err := e.cfg.Router.Stream(ctx, task, req)
|
||||
return s, err
|
||||
}
|
||||
return e.cfg.Provider.Stream(ctx, req)
|
||||
return e.activeProvider().Stream(ctx, req)
|
||||
}
|
||||
|
||||
// retryOnTransient retries the stream call on 429/5xx with exponential backoff.
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// blockingStream emits one text delta, then blocks Next() until release is closed,
|
||||
// then emits the stop event. Lets a test interleave Submit with concurrent setters.
|
||||
type blockingStream struct {
|
||||
release chan struct{}
|
||||
emitted bool
|
||||
released bool
|
||||
stopReason message.StopReason
|
||||
model string
|
||||
}
|
||||
|
||||
func newBlockingStream(release chan struct{}, model string) *blockingStream {
|
||||
return &blockingStream{release: release, model: model, stopReason: message.StopEndTurn}
|
||||
}
|
||||
|
||||
func (s *blockingStream) Next() bool {
|
||||
if !s.emitted {
|
||||
s.emitted = true
|
||||
return true
|
||||
}
|
||||
if !s.released {
|
||||
<-s.release
|
||||
s.released = true
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *blockingStream) Current() stream.Event {
|
||||
if s.released {
|
||||
return stream.Event{Type: stream.EventTextDelta, StopReason: s.stopReason, Model: s.model}
|
||||
}
|
||||
return stream.Event{Type: stream.EventTextDelta, Text: "hi", Model: s.model}
|
||||
}
|
||||
|
||||
func (s *blockingStream) Err() error { return nil }
|
||||
func (s *blockingStream) Close() error { return nil }
|
||||
|
||||
func TestEngine_ConcurrentSubmitAndSetters(t *testing.T) {
|
||||
release := make(chan struct{})
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{newBlockingStream(release, "mock-model")},
|
||||
}
|
||||
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = e.Submit(context.Background(), "go", nil)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 100; i++ {
|
||||
e.InjectMessage(message.NewUserText("noise"))
|
||||
_ = e.History()
|
||||
_ = e.Usage()
|
||||
}
|
||||
close(release)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
Reference in New Issue
Block a user