feat(engine): early-stop detection for runaway agent loops
Adds three lightweight per-turn detectors that fire corrective user messages back into the conversation when the model goes off the rails: - RepetitionDetector: sliding-window scan over streamed text deltas; trips when a 50/80/120-char pattern repeats >= 3 times in the trailing 200 chars. Breaks the active stream and injects a correction. - PatchFailureTracker: per-path counter for fs.edit/fs.write failures; trips on the 4th consecutive failure and steers the model to fs.write rather than another fs.edit on the same path. Success decrements with a floor of 0; paths are isolated. - DetectGreeting: narrow allowlist for "how can I help" style replies; only consulted after a round that used tools, so first-turn greetings don't false-positive. Detector state is per-turn (declared locally in runLoop), single-goroutine use. Corrective messages are appended as user-role text to both engine history and the context window. Telemetry: each trigger logs at INFO with round + path where applicable. Covered by 12 unit tests for the primitives and 5 loop-level integration tests that drive the full agentic loop via the existing eventStream mock.
This commit is contained in:
@@ -0,0 +1,205 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Tuning defaults shared by the early-stop detectors. The values follow
// smallcode's reference implementation, adapted to our streaming shape.
const (
	// defaultRepetitionWindow is how many trailing chars of the stream are
	// inspected.
	defaultRepetitionWindow = 200
	// defaultRepetitionThreshold is the minimum number of times a pattern
	// must repeat before the detector fires.
	defaultRepetitionThreshold = 3
	// defaultMaxPatchFailures is the per-path failure count at which the
	// loop escalates.
	defaultMaxPatchFailures = 4
)

// defaultRepetitionSizes lists the pattern lengths the repetition detector
// is configured with by default.
var defaultRepetitionSizes = []int{50, 80, 120}
|
||||
|
||||
// RepetitionDetector watches a stream's text deltas for a fixed-size pattern
|
||||
// that recurs ≥ threshold times within the trailing window. Detects the
|
||||
// "model lost the plot and is now repeating itself" failure mode.
|
||||
//
|
||||
// Single-goroutine use only — the loop drives it from the stream consume path.
|
||||
type RepetitionDetector struct {
|
||||
windowChars int
|
||||
threshold int
|
||||
sizes []int
|
||||
buf strings.Builder
|
||||
}
|
||||
|
||||
func NewRepetitionDetector() *RepetitionDetector {
|
||||
return &RepetitionDetector{
|
||||
windowChars: defaultRepetitionWindow,
|
||||
threshold: defaultRepetitionThreshold,
|
||||
sizes: defaultRepetitionSizes,
|
||||
}
|
||||
}
|
||||
|
||||
// Feed appends streamed text to the buffer and returns true when a repetition
|
||||
// pattern is detected. Once triggered, the caller is expected to act on the
|
||||
// signal and call Reset before reusing the detector.
|
||||
func (d *RepetitionDetector) Feed(text string) bool {
|
||||
if text == "" {
|
||||
return false
|
||||
}
|
||||
d.buf.WriteString(text)
|
||||
|
||||
// Trim the buffer to bound memory. Keep twice the window so we always
|
||||
// have a stable trailing slice to scan.
|
||||
if d.buf.Len() > d.windowChars*4 {
|
||||
s := d.buf.String()
|
||||
keep := s[len(s)-d.windowChars*2:]
|
||||
d.buf.Reset()
|
||||
d.buf.WriteString(keep)
|
||||
}
|
||||
|
||||
s := d.buf.String()
|
||||
// We need at least one window's worth of data for the smallest pattern
|
||||
// to recur threshold times.
|
||||
if len(s) < d.sizes[0]*d.threshold {
|
||||
return false
|
||||
}
|
||||
tail := s
|
||||
if len(tail) > d.windowChars {
|
||||
tail = tail[len(tail)-d.windowChars:]
|
||||
}
|
||||
|
||||
for _, size := range d.sizes {
|
||||
if len(tail) < size*d.threshold {
|
||||
continue
|
||||
}
|
||||
pattern := tail[:size]
|
||||
count := 0
|
||||
for i := 0; i+size <= len(tail); {
|
||||
if tail[i:i+size] == pattern {
|
||||
count++
|
||||
if count >= d.threshold {
|
||||
return true
|
||||
}
|
||||
i += size
|
||||
continue
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Reset clears the accumulated buffer. Call at the start of a new turn.
|
||||
func (d *RepetitionDetector) Reset() {
|
||||
d.buf.Reset()
|
||||
}
|
||||
|
||||
// PatchFailureTracker counts consecutive write/edit failures per file path
|
||||
// within a turn. Triggers when a single path crosses the configured threshold,
|
||||
// at which point the loop should steer the model away from further patches
|
||||
// against that path.
|
||||
type PatchFailureTracker struct {
|
||||
maxFailures int
|
||||
failures map[string]int
|
||||
}
|
||||
|
||||
func NewPatchFailureTracker() *PatchFailureTracker {
|
||||
return &PatchFailureTracker{
|
||||
maxFailures: defaultMaxPatchFailures,
|
||||
failures: make(map[string]int),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordFailure increments the failure count for path and returns true when
|
||||
// the threshold has just been reached. After triggering, the path's counter
|
||||
// is reset so subsequent failures don't re-fire the signal until they
|
||||
// re-accumulate.
|
||||
func (t *PatchFailureTracker) RecordFailure(path string) bool {
|
||||
if path == "" {
|
||||
return false
|
||||
}
|
||||
t.failures[path]++
|
||||
if t.failures[path] >= t.maxFailures {
|
||||
delete(t.failures, path)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RecordSuccess decrements the failure count for path with a floor of 0.
|
||||
// A run of successful edits should let the path recover, but we don't fully
|
||||
// reset on a single success — a path that fails three times then succeeds
|
||||
// once is still a suspicious target.
|
||||
func (t *PatchFailureTracker) RecordSuccess(path string) {
|
||||
if path == "" {
|
||||
return
|
||||
}
|
||||
if n := t.failures[path]; n > 0 {
|
||||
t.failures[path] = n - 1
|
||||
if t.failures[path] == 0 {
|
||||
delete(t.failures, path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset clears all per-path counters. Call at the start of a new turn.
|
||||
func (t *PatchFailureTracker) Reset() {
|
||||
t.failures = make(map[string]int)
|
||||
}
|
||||
|
||||
// greetingMarkers are case-folded substrings that indicate the model has
// dropped its task context and reverted to an opening-of-conversation reply.
// Kept deliberately narrow — we only want to fire on responses that look
// like the start of a new chat, not on any polite phrasing. Markers use the
// ASCII apostrophe; DetectGreeting normalizes the input to match.
var greetingMarkers = []string{
	"how can i help",
	"how can i assist",
	"what would you like",
	"what can i do for you",
	"i'm ready to",
	"hi there",
}

// DetectGreeting reports whether text looks like a greeting/reset response.
// Stateless. The loop should only consult this after a round that contained
// tool calls — a greeting at the start of a turn is fine.
//
// Matching is case-insensitive, and the typographic apostrophe (U+2019) is
// treated as an ASCII apostrophe so "I’m ready to" matches the same marker as
// "I'm ready to". Texts shorter than 10 bytes never match.
func DetectGreeting(text string) bool {
	if len(text) < 10 {
		return false
	}
	lc := strings.ToLower(strings.ReplaceAll(text, "\u2019", "'"))
	for _, m := range greetingMarkers {
		if strings.Contains(lc, m) {
			return true
		}
	}
	return false
}
|
||||
|
||||
// Corrective injections returned to the model when a detector fires. These
|
||||
// are appended as user messages before the next round so the model sees a
|
||||
// concrete instruction rather than a system reset.
|
||||
|
||||
// RepetitionInjection returns the corrective message injected after the
// repetition detector fires.
func RepetitionInjection() string {
	const msg = "[system] Your output is repeating itself in a loop. " +
		"Stop. " +
		"Take a different approach, or state explicitly what is blocking you and why the current strategy is not converging."
	return msg
}
|
||||
|
||||
// PatchSpiralInjection returns the corrective message injected when one file
// has accumulated too many failed fs.edit attempts. It names path twice and
// steers the model toward a full fs.write rewrite instead of another patch.
func PatchSpiralInjection(path string) string {
	// %[1]s reuses the single path argument at both positions.
	const tmpl = "[system] You have failed to edit %[1]s several times. Stop using fs.edit on this file. " +
		"Instead: 1) read the current file with fs.read, 2) decide what the file should contain in full, " +
		"3) rewrite it with fs.write. Do not attempt another fs.edit on %[1]s."
	return fmt.Sprintf(tmpl, path)
}
|
||||
|
||||
// GreetingInjection returns the corrective message injected when the model
// emits a greeting mid-task (context loss).
func GreetingInjection() string {
	const msg = "[system] You produced a greeting instead of continuing the task. Look at the conversation above — there is work in progress. " +
		"Resume where you left off. Do not restart the conversation."
	return msg
}
|
||||
@@ -0,0 +1,227 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// failingEditStream emits a single fs.edit tool call against the given path.
|
||||
func failingEditStream(callID, path string) stream.Stream {
|
||||
args := []byte(`{"path":"` + path + `","old_string":"foo","new_string":"bar"}`)
|
||||
return newEventStream(message.StopToolUse, "test-model",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "Patching the file."},
|
||||
stream.Event{Type: stream.EventToolCallStart, ToolCallID: callID, ToolCallName: "fs.edit"},
|
||||
stream.Event{Type: stream.EventToolCallDone, ToolCallID: callID, ToolCallName: "fs.edit", Args: json.RawMessage(args)},
|
||||
)
|
||||
}
|
||||
|
||||
func TestEarlyStop_PatchSpiral_InjectsCorrection(t *testing.T) {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&mockTool{
|
||||
name: "fs.edit",
|
||||
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
|
||||
return tool.Result{}, errors.New("old_string not found")
|
||||
},
|
||||
})
|
||||
|
||||
// Four failed edits on the same path, then a final acknowledgement.
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
failingEditStream("tc_1", "/work/main.go"),
|
||||
failingEditStream("tc_2", "/work/main.go"),
|
||||
failingEditStream("tc_3", "/work/main.go"),
|
||||
failingEditStream("tc_4", "/work/main.go"),
|
||||
newEventStream(message.StopEndTurn, "test-model",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "Understood, switching to a full rewrite."},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
e, _ := New(Config{Provider: mp, Tools: reg})
|
||||
_, err := e.Submit(context.Background(), "fix the bug", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Submit: %v", err)
|
||||
}
|
||||
|
||||
// Walk history for the corrective injection.
|
||||
var foundSpiral bool
|
||||
for _, m := range e.History() {
|
||||
if m.Role != message.RoleUser {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(m.TextContent(), "/work/main.go") &&
|
||||
strings.Contains(strings.ToLower(m.TextContent()), "fs.write") {
|
||||
foundSpiral = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundSpiral {
|
||||
t.Fatal("expected patch-spiral corrective message in history, not found")
|
||||
}
|
||||
|
||||
if mp.calls != 5 {
|
||||
t.Errorf("provider calls = %d, want 5 (4 failing edits + 1 ack)", mp.calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEarlyStop_PatchSpiral_PerPathIsolation(t *testing.T) {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&mockTool{
|
||||
name: "fs.edit",
|
||||
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
|
||||
return tool.Result{}, errors.New("old_string not found")
|
||||
},
|
||||
})
|
||||
|
||||
// Failures alternate between two paths; neither reaches the threshold.
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
failingEditStream("tc_1", "/work/a.go"),
|
||||
failingEditStream("tc_2", "/work/b.go"),
|
||||
failingEditStream("tc_3", "/work/a.go"),
|
||||
failingEditStream("tc_4", "/work/b.go"),
|
||||
newEventStream(message.StopEndTurn, "test-model",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "Giving up."},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
e, _ := New(Config{Provider: mp, Tools: reg})
|
||||
_, err := e.Submit(context.Background(), "edit two files", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Submit: %v", err)
|
||||
}
|
||||
|
||||
for _, m := range e.History() {
|
||||
if m.Role == message.RoleUser && strings.Contains(strings.ToLower(m.TextContent()), "fs.write") {
|
||||
t.Fatal("patch-spiral injection fired despite per-path failures below threshold")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEarlyStop_GreetingRegression_InjectsCorrection(t *testing.T) {
|
||||
reg := tool.NewRegistry()
|
||||
reg.Register(&mockTool{
|
||||
name: "fs.read",
|
||||
readOnly: true,
|
||||
execFn: func(_ context.Context, _ json.RawMessage) (tool.Result, error) {
|
||||
return tool.Result{Output: "package main"}, nil
|
||||
},
|
||||
})
|
||||
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
// Round 1: legitimate tool call
|
||||
newEventStream(message.StopToolUse, "test-model",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "Reading file."},
|
||||
stream.Event{Type: stream.EventToolCallStart, ToolCallID: "tc_1", ToolCallName: "fs.read"},
|
||||
stream.Event{Type: stream.EventToolCallDone, ToolCallID: "tc_1", ToolCallName: "fs.read", Args: json.RawMessage(`{"path":"/x.go"}`)},
|
||||
),
|
||||
// Round 2: model loses context, emits a greeting
|
||||
newEventStream(message.StopEndTurn, "test-model",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "Hello! How can I help you today?"},
|
||||
),
|
||||
// Round 3: model resumes after correction
|
||||
newEventStream(message.StopEndTurn, "test-model",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "Sorry — continuing. The file is a Go package."},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
e, _ := New(Config{Provider: mp, Tools: reg})
|
||||
_, err := e.Submit(context.Background(), "inspect /x.go", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Submit: %v", err)
|
||||
}
|
||||
|
||||
var foundGreeting bool
|
||||
for _, m := range e.History() {
|
||||
if m.Role == message.RoleUser && strings.Contains(m.TextContent(), "greeting instead of continuing") {
|
||||
foundGreeting = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundGreeting {
|
||||
t.Fatal("expected greeting-regression corrective message in history")
|
||||
}
|
||||
if mp.calls != 3 {
|
||||
t.Errorf("provider calls = %d, want 3", mp.calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEarlyStop_NoFalsePositive_GreetingOnFirstTurn(t *testing.T) {
|
||||
// A greeting on the very first round (no prior tool calls) is fine — the
|
||||
// detector should only fire after a round that used tools.
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
newEventStream(message.StopEndTurn, "test-model",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "Hi there! How can I help you?"},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||
_, err := e.Submit(context.Background(), "hello", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Submit: %v", err)
|
||||
}
|
||||
|
||||
for _, m := range e.History() {
|
||||
if m.Role == message.RoleUser && strings.Contains(m.TextContent(), "greeting instead of continuing") {
|
||||
t.Fatal("greeting detector fired on first-round greeting")
|
||||
}
|
||||
}
|
||||
if mp.calls != 1 {
|
||||
t.Errorf("provider calls = %d, want 1", mp.calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEarlyStop_Repetition_BreaksAndCorrects(t *testing.T) {
|
||||
// Round 1: a stream that repeats a phrase enough to trip the detector.
|
||||
phrase := "I will read the file and then carefully apply the edit. "
|
||||
repeatEvents := make([]stream.Event, 0, 8)
|
||||
for range 8 {
|
||||
repeatEvents = append(repeatEvents, stream.Event{Type: stream.EventTextDelta, Text: phrase})
|
||||
}
|
||||
round1 := newEventStream(message.StopEndTurn, "test-model", repeatEvents...)
|
||||
|
||||
round2 := newEventStream(message.StopEndTurn, "test-model",
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "Acknowledged — taking a different approach."},
|
||||
)
|
||||
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{round1, round2},
|
||||
}
|
||||
|
||||
e, _ := New(Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||
_, err := e.Submit(context.Background(), "do something", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Submit: %v", err)
|
||||
}
|
||||
|
||||
var foundRep bool
|
||||
for _, m := range e.History() {
|
||||
if m.Role == message.RoleUser && strings.Contains(m.TextContent(), "repeating itself in a loop") {
|
||||
foundRep = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundRep {
|
||||
t.Fatal("expected repetition corrective message in history")
|
||||
}
|
||||
if mp.calls != 2 {
|
||||
t.Errorf("provider calls = %d, want 2", mp.calls)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,170 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRepetitionDetector_NoRepetition(t *testing.T) {
|
||||
d := NewRepetitionDetector()
|
||||
chunks := []string{
|
||||
"Let me think about this step by step. ",
|
||||
"First, I need to read the file. ",
|
||||
"Then I will identify the section that needs to change. ",
|
||||
"After that, I can apply the edit and verify the result. ",
|
||||
}
|
||||
for i, c := range chunks {
|
||||
if d.Feed(c) {
|
||||
t.Fatalf("unexpected repetition at chunk %d (%q)", i, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepetitionDetector_TriggersOnRepeatedPattern(t *testing.T) {
|
||||
d := NewRepetitionDetector()
|
||||
// 60-char phrase repeated 5x → 300 chars total. Falls inside the
|
||||
// 200-char window with multiple repeats.
|
||||
phrase := "I need to read the file and then apply the edit carefully. "
|
||||
triggered := false
|
||||
for range 5 {
|
||||
if d.Feed(phrase) {
|
||||
triggered = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !triggered {
|
||||
t.Fatal("expected repetition detection on repeated phrase")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepetitionDetector_DoesNotTriggerOnNaturalText(t *testing.T) {
|
||||
d := NewRepetitionDetector()
|
||||
// Natural code-review-style text with repeated bigrams but distinct sentences.
|
||||
text := strings.Repeat(
|
||||
"The function returns an error when the path is invalid or unreadable. ",
|
||||
1) +
|
||||
"It accepts a context and a byte slice as arguments. " +
|
||||
"The implementation walks the slice and validates each entry. " +
|
||||
"On success it produces a structured result containing the parsed value. " +
|
||||
"This avoids reallocating the underlying buffer on each call. " +
|
||||
"Tests cover the happy path and the malformed input cases. " +
|
||||
"Future work may add streaming support for very large inputs."
|
||||
if d.Feed(text) {
|
||||
t.Fatal("false positive on natural prose")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepetitionDetector_Reset(t *testing.T) {
|
||||
d := NewRepetitionDetector()
|
||||
phrase := "loop loop loop loop loop loop loop loop loop loop loop loop "
|
||||
for range 5 {
|
||||
d.Feed(phrase)
|
||||
}
|
||||
d.Reset()
|
||||
// After reset, a small amount of new text must not trigger.
|
||||
if d.Feed("hello world") {
|
||||
t.Fatal("detector triggered immediately after reset on fresh input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchFailureTracker_TriggersAtThreshold(t *testing.T) {
|
||||
tr := NewPatchFailureTracker()
|
||||
for i := range 3 {
|
||||
if tr.RecordFailure("/foo.go") {
|
||||
t.Fatalf("triggered too early at attempt %d", i+1)
|
||||
}
|
||||
}
|
||||
if !tr.RecordFailure("/foo.go") {
|
||||
t.Fatal("should trigger on 4th failure")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchFailureTracker_TriggerResetsPath(t *testing.T) {
|
||||
tr := NewPatchFailureTracker()
|
||||
for range 4 {
|
||||
tr.RecordFailure("/foo.go")
|
||||
}
|
||||
// Subsequent failure on same path starts fresh (we already escalated).
|
||||
if tr.RecordFailure("/foo.go") {
|
||||
t.Fatal("should not re-trigger immediately after escalation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchFailureTracker_SuccessDecrements(t *testing.T) {
|
||||
tr := NewPatchFailureTracker()
|
||||
tr.RecordFailure("/foo.go")
|
||||
tr.RecordFailure("/foo.go")
|
||||
tr.RecordSuccess("/foo.go") // back to 1
|
||||
if tr.RecordFailure("/foo.go") {
|
||||
t.Fatal("triggered at 2 after decrement")
|
||||
}
|
||||
if tr.RecordFailure("/foo.go") {
|
||||
t.Fatal("triggered at 3 after decrement")
|
||||
}
|
||||
if !tr.RecordFailure("/foo.go") {
|
||||
t.Fatal("should trigger at 4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchFailureTracker_PerPathIsolation(t *testing.T) {
|
||||
tr := NewPatchFailureTracker()
|
||||
for range 4 {
|
||||
tr.RecordFailure("/foo.go")
|
||||
}
|
||||
// /bar.go must be unaffected.
|
||||
for i := range 3 {
|
||||
if tr.RecordFailure("/bar.go") {
|
||||
t.Fatalf("/bar.go triggered too early at attempt %d", i+1)
|
||||
}
|
||||
}
|
||||
if !tr.RecordFailure("/bar.go") {
|
||||
t.Fatal("/bar.go should trigger at 4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchFailureTracker_Reset(t *testing.T) {
|
||||
tr := NewPatchFailureTracker()
|
||||
tr.RecordFailure("/foo.go")
|
||||
tr.RecordFailure("/foo.go")
|
||||
tr.RecordFailure("/foo.go")
|
||||
tr.Reset()
|
||||
for i := range 3 {
|
||||
if tr.RecordFailure("/foo.go") {
|
||||
t.Fatalf("triggered too early after Reset at attempt %d", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectGreeting(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
text string
|
||||
want bool
|
||||
}{
|
||||
{"how can I help", "How can I help you today?", true},
|
||||
{"what would you like", "What would you like to do?", true},
|
||||
{"hello ready", "Hello! I'm ready to assist.", true},
|
||||
{"hi there", "Hi there! What can I do for you?", true},
|
||||
{"task progress", "I've updated the file as requested.", false},
|
||||
{"code reference", "The function in foo.go returns an error.", false},
|
||||
{"empty", "", false},
|
||||
{"single word", "Done.", false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := DetectGreeting(tc.text); got != tc.want {
|
||||
t.Fatalf("DetectGreeting(%q) = %v, want %v", tc.text, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPatchSpiralInjection_NamesPath(t *testing.T) {
|
||||
msg := PatchSpiralInjection("/work/foo.go")
|
||||
if !strings.Contains(msg, "/work/foo.go") {
|
||||
t.Fatalf("injection missing path: %q", msg)
|
||||
}
|
||||
if !strings.Contains(strings.ToLower(msg), "fs.write") {
|
||||
t.Fatalf("injection should steer toward fs.write rewrite: %q", msg)
|
||||
}
|
||||
}
|
||||
+105
-1
@@ -2,6 +2,7 @@ package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
@@ -66,6 +67,11 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
var lastArmID router.ArmID
|
||||
var lastTaskType router.TaskType
|
||||
|
||||
// Early-stop detectors — per-turn scope, single-goroutine use.
|
||||
repetitionDet := NewRepetitionDetector()
|
||||
patchFails := NewPatchFailureTracker()
|
||||
priorRoundHadToolCalls := false
|
||||
|
||||
reportOutcome := func(err error) {
|
||||
if e.cfg.Router == nil || lastArmID == "" {
|
||||
return
|
||||
@@ -210,6 +216,7 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
|
||||
streamStart := time.Now()
|
||||
var firstTokenAt time.Time
|
||||
repetitionTripped := false
|
||||
|
||||
for s.Next() {
|
||||
evt := s.Current()
|
||||
@@ -220,6 +227,20 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
firstTokenAt = time.Now()
|
||||
}
|
||||
|
||||
// Feed text deltas to the repetition detector. On trigger, stop
|
||||
// consuming further events — the partial response is committed
|
||||
// to history below and a corrective message is injected.
|
||||
if evt.Type == stream.EventTextDelta && evt.Text != "" {
|
||||
if repetitionDet.Feed(evt.Text) {
|
||||
repetitionTripped = true
|
||||
e.logger.Info("early-stop: repetition loop detected", "round", turn.Rounds)
|
||||
if cb != nil {
|
||||
cb(evt)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Capture stop reason and model from events
|
||||
if evt.StopReason != "" {
|
||||
stopReason = evt.StopReason
|
||||
@@ -294,6 +315,21 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
"round", turn.Rounds,
|
||||
)
|
||||
|
||||
// Repetition loop — inject correction and re-query.
|
||||
if repetitionTripped {
|
||||
e.injectCorrective(RepetitionInjection())
|
||||
continue
|
||||
}
|
||||
|
||||
// Greeting regression — only meaningful after a round that used tools.
|
||||
if priorRoundHadToolCalls && !resp.Message.HasToolCalls() {
|
||||
if DetectGreeting(resp.Message.TextContent()) {
|
||||
e.logger.Info("early-stop: greeting regression detected", "round", turn.Rounds)
|
||||
e.injectCorrective(GreetingInjection())
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Decide next action
|
||||
switch resp.StopReason {
|
||||
case message.StopEndTurn, message.StopSequence:
|
||||
@@ -312,7 +348,8 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
// Continue loop — next round will resume generation
|
||||
|
||||
case message.StopToolUse:
|
||||
results, err := e.executeTools(ctx, resp.Message.ToolCalls(), cb)
|
||||
calls := resp.Message.ToolCalls()
|
||||
results, err := e.executeTools(ctx, calls, cb)
|
||||
if err != nil {
|
||||
toolErr := fmt.Errorf("tool execution: %w", err)
|
||||
reportOutcome(toolErr)
|
||||
@@ -324,6 +361,15 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
if e.cfg.Context != nil {
|
||||
e.cfg.Context.AppendMessage(toolMsg)
|
||||
}
|
||||
|
||||
// Track patch failures per file; trigger an escalation if a
|
||||
// single path crosses the threshold.
|
||||
if spiralPath := e.recordPatchOutcomes(calls, results, patchFails); spiralPath != "" {
|
||||
e.logger.Info("early-stop: patch spiral detected", "path", spiralPath, "round", turn.Rounds)
|
||||
e.injectCorrective(PatchSpiralInjection(spiralPath))
|
||||
}
|
||||
|
||||
priorRoundHadToolCalls = true
|
||||
// Continue loop — re-query provider with tool results
|
||||
|
||||
default:
|
||||
@@ -335,6 +381,64 @@ func (e *Engine) runLoop(ctx context.Context, cb Callback) (*Turn, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// injectCorrective appends a user-role corrective message to history and the
|
||||
// context window. Used by the early-stop detectors to steer the model on the
|
||||
// next round.
|
||||
func (e *Engine) injectCorrective(text string) {
|
||||
msg := message.NewUserText(text)
|
||||
e.appendHistory(msg)
|
||||
if e.cfg.Context != nil {
|
||||
e.cfg.Context.AppendMessage(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// recordPatchOutcomes walks fs.edit/fs.write tool calls and feeds their
|
||||
// success/failure into the tracker. Returns the first path that crossed the
|
||||
// patch-spiral threshold on this round, or "" if none did.
|
||||
func (e *Engine) recordPatchOutcomes(calls []message.ToolCall, results []message.ToolResult, tr *PatchFailureTracker) string {
|
||||
if len(calls) == 0 || len(results) == 0 {
|
||||
return ""
|
||||
}
|
||||
resByID := make(map[string]*message.ToolResult, len(results))
|
||||
for i := range results {
|
||||
resByID[results[i].ToolCallID] = &results[i]
|
||||
}
|
||||
var spiralPath string
|
||||
for _, call := range calls {
|
||||
if call.Name != "fs.edit" && call.Name != "fs.write" {
|
||||
continue
|
||||
}
|
||||
res, ok := resByID[call.ID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
path := extractPatchPath(call.Arguments)
|
||||
if path == "" {
|
||||
continue
|
||||
}
|
||||
if res.IsError {
|
||||
if tr.RecordFailure(path) && spiralPath == "" {
|
||||
spiralPath = path
|
||||
}
|
||||
} else {
|
||||
tr.RecordSuccess(path)
|
||||
}
|
||||
}
|
||||
return spiralPath
|
||||
}
|
||||
|
||||
// extractPatchPath decodes args as a JSON object and returns its "path"
// field. It returns "" when args cannot be decoded or carry no path, which
// callers treat as "skip this call".
func extractPatchPath(args json.RawMessage) string {
	var parsed struct {
		Path string `json:"path"`
	}
	if json.Unmarshal(args, &parsed) != nil {
		return ""
	}
	return parsed.Path
}
|
||||
|
||||
func (e *Engine) buildRequest(ctx context.Context) provider.Request {
|
||||
// Use AllMessages (prefix + history) if context window manages prefix docs
|
||||
messages := e.historySnapshot()
|
||||
|
||||
Reference in New Issue
Block a user