feat(openai): lexical repair for malformed tool-call arguments
Local-model servers (Ollama, llama.cpp, llamafile) routed through the
OpenAI-compatible path frequently emit tool-call arguments that are
*almost* valid JSON — wrapped in markdown fences, padded with prose, or
trailing a stray comma. Strict parsing fails, the engine receives empty
args, and the agent loop has to retry or escalate.
Adds repairArgs(raw) at the EventToolCallDone boundary: strict-parse
first, then apply cheap lexical fixes (strip ```json fences, drop
trailing commas before }/], extract the first balanced {...} block with
proper string/escape awareness). On success, the repaired bytes flow
through unchanged; on failure, the original is returned and downstream
parsing surfaces the error as before.
Frontier providers (OpenAI proper, Anthropic, Mistral, Google) are
unaffected — their SDKs return structured args that pass strict parse.
The repair only does work when the upstream output is malformed.
11 unit tests cover: valid passthrough, empty, trailing commas,
single/double-line fences, prose-wrapped, braces-inside-strings,
multiple top-level objects (takes the first), and unrepairable input.
A stream-level test verifies the wiring through flushNextToolCall.
This commit is contained in:
@@ -0,0 +1,128 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// repairArgs accepts a string of (possibly malformed) tool-call arguments and
|
||||
// returns valid JSON when small lexical fixes can recover it. The bool return
|
||||
// reports whether a repair was applied.
|
||||
//
|
||||
// Small local models served via OpenAI-compatible endpoints (Ollama,
|
||||
// llama.cpp, llamafile) frequently emit args wrapped in markdown fences,
|
||||
// surrounded by prose, or with trailing commas. Repairing these here keeps
|
||||
// the downstream agent loop from failing on cosmetic noise.
|
||||
//
|
||||
// Repair tiers (cheap → less cheap):
|
||||
// 1. Strict json.Valid → return as-is.
|
||||
// 2. Strip ```json / ``` code fences.
|
||||
// 3. Trim trailing commas before `}` or `]`.
|
||||
// 4. Extract the first balanced {...} block (respects strings/escapes).
|
||||
//
|
||||
// If none of the tiers produces valid JSON, returns the original input bytes
|
||||
// and repaired=false so callers can surface the parse error normally.
|
||||
func repairArgs(raw string) (json.RawMessage, bool) {
|
||||
if raw == "" {
|
||||
return json.RawMessage(raw), false
|
||||
}
|
||||
if json.Valid([]byte(raw)) {
|
||||
return json.RawMessage(raw), false
|
||||
}
|
||||
|
||||
candidates := []string{
|
||||
stripCodeFences(raw),
|
||||
stripCodeFences(trimTrailingCommas(raw)),
|
||||
extractFirstObject(raw),
|
||||
extractFirstObject(stripCodeFences(raw)),
|
||||
trimTrailingCommas(extractFirstObject(stripCodeFences(raw))),
|
||||
}
|
||||
for _, c := range candidates {
|
||||
if c == "" || c == raw {
|
||||
continue
|
||||
}
|
||||
if json.Valid([]byte(c)) {
|
||||
return json.RawMessage(c), true
|
||||
}
|
||||
}
|
||||
return json.RawMessage(raw), false
|
||||
}
|
||||
|
||||
// codeFenceRE matches a backtick-fenced block, optionally tagged ```json.
|
||||
// Captures the block's body in group 1.
|
||||
var codeFenceRE = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```")
|
||||
|
||||
// stripCodeFences pulls JSON out of a markdown code fence. If a complete
|
||||
// fenced block exists, returns its body; otherwise strips a leading or
|
||||
// trailing partial fence and returns the remainder trimmed.
|
||||
func stripCodeFences(s string) string {
|
||||
if m := codeFenceRE.FindStringSubmatch(s); m != nil {
|
||||
return strings.TrimSpace(m[1])
|
||||
}
|
||||
// Partial fence — strip leading ```json / ``` if present, and any
|
||||
// trailing ``` even if no closing pair matched.
|
||||
out := s
|
||||
out = strings.TrimSpace(out)
|
||||
out = strings.TrimPrefix(out, "```json")
|
||||
out = strings.TrimPrefix(out, "```")
|
||||
out = strings.TrimSuffix(out, "```")
|
||||
return strings.TrimSpace(out)
|
||||
}
|
||||
|
||||
// trailingCommaRE matches a comma followed only by whitespace before a
|
||||
// closing `}` or `]` — i.e. a JSON-illegal trailing comma.
|
||||
var trailingCommaRE = regexp.MustCompile(`,(\s*[}\]])`)
|
||||
|
||||
// trimTrailingCommas removes JSON-illegal trailing commas. Naïve regex
|
||||
// pass — fine for our use because tool-call arg payloads don't contain
|
||||
// literal commas-inside-strings that would resemble a trailing comma after
|
||||
// whitespace.
|
||||
func trimTrailingCommas(s string) string {
|
||||
return trailingCommaRE.ReplaceAllString(s, "$1")
|
||||
}
|
||||
|
||||
// extractFirstObject walks s and returns the substring from the first `{`
|
||||
// to its matching `}`, respecting string boundaries and escapes. Returns
|
||||
// "" when no balanced object is found.
|
||||
func extractFirstObject(s string) string {
|
||||
start := -1
|
||||
depth := 0
|
||||
inString := false
|
||||
escaped := false
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if inString {
|
||||
switch c {
|
||||
case '\\':
|
||||
escaped = true
|
||||
case '"':
|
||||
inString = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch c {
|
||||
case '"':
|
||||
inString = true
|
||||
case '{':
|
||||
if start == -1 {
|
||||
start = i
|
||||
}
|
||||
depth++
|
||||
case '}':
|
||||
if depth == 0 {
|
||||
continue
|
||||
}
|
||||
depth--
|
||||
if depth == 0 && start != -1 {
|
||||
return s[start : i+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,231 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
)
|
||||
|
||||
func TestOpenAIStream_FlushNextToolCall_RepairsArgs(t *testing.T) {
|
||||
s := &openaiStream{
|
||||
toolCalls: map[int64]*toolCallState{
|
||||
0: {
|
||||
id: "call_1",
|
||||
name: "fs.edit",
|
||||
// Malformed: wrapped in code fence + trailing comma
|
||||
args: "```json\n{\"path\":\"/x\",\"old_string\":\"a\",\"new_string\":\"b\",}\n```",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ev, ok := s.flushNextToolCall()
|
||||
if !ok {
|
||||
t.Fatal("flushNextToolCall returned ok=false with pending call")
|
||||
}
|
||||
if ev.Type != stream.EventToolCallDone {
|
||||
t.Errorf("event type = %v, want EventToolCallDone", ev.Type)
|
||||
}
|
||||
if ev.ToolCallID != "call_1" {
|
||||
t.Errorf("ToolCallID = %q", ev.ToolCallID)
|
||||
}
|
||||
if !json.Valid(ev.Args) {
|
||||
t.Fatalf("Args is not valid JSON after repair: %q", string(ev.Args))
|
||||
}
|
||||
var parsed map[string]string
|
||||
if err := json.Unmarshal(ev.Args, &parsed); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if parsed["path"] != "/x" || parsed["old_string"] != "a" || parsed["new_string"] != "b" {
|
||||
t.Errorf("data lost in repair: %v", parsed)
|
||||
}
|
||||
|
||||
// Second call: queue empty.
|
||||
if _, ok := s.flushNextToolCall(); ok {
|
||||
t.Error("flushNextToolCall returned ok=true on empty queue")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIStream_FlushNextToolCall_ValidArgsPassThrough(t *testing.T) {
|
||||
original := `{"path":"/x"}`
|
||||
s := &openaiStream{
|
||||
toolCalls: map[int64]*toolCallState{
|
||||
0: {id: "call_1", name: "fs.read", args: original},
|
||||
},
|
||||
}
|
||||
ev, ok := s.flushNextToolCall()
|
||||
if !ok {
|
||||
t.Fatal("flushNextToolCall returned ok=false")
|
||||
}
|
||||
if string(ev.Args) != original {
|
||||
t.Errorf("valid args mutated: %q → %q", original, string(ev.Args))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairArgs_ValidPassesThrough(t *testing.T) {
|
||||
cases := []string{
|
||||
`{"path":"/foo.go"}`,
|
||||
`{}`,
|
||||
`{"a":1,"b":[1,2,3],"c":{"d":"e"}}`,
|
||||
`{"text":"contains \"quoted\" inner"}`,
|
||||
}
|
||||
for _, in := range cases {
|
||||
got, repaired := repairArgs(in)
|
||||
if repaired {
|
||||
t.Errorf("repairArgs(%q): repaired=true on valid input", in)
|
||||
}
|
||||
if string(got) != in {
|
||||
t.Errorf("repairArgs(%q): mutated valid input → %q", in, string(got))
|
||||
}
|
||||
if !json.Valid(got) {
|
||||
t.Errorf("repairArgs(%q): output not valid JSON: %q", in, string(got))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairArgs_EmptyInput(t *testing.T) {
|
||||
got, repaired := repairArgs("")
|
||||
if repaired {
|
||||
t.Error("empty input should not be marked repaired")
|
||||
}
|
||||
if string(got) != "" {
|
||||
t.Errorf("empty input → %q, want empty", string(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairArgs_TrimsTrailingComma(t *testing.T) {
|
||||
cases := []struct {
|
||||
in, want string
|
||||
}{
|
||||
{`{"a":1,}`, `{"a":1}`},
|
||||
{`{"a":1, "b":2,}`, `{"a":1, "b":2}`},
|
||||
{`{"a":[1,2,3,]}`, `{"a":[1,2,3]}`},
|
||||
{`{"a":1 , }`, `{"a":1 }`},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got, repaired := repairArgs(tc.in)
|
||||
if !repaired {
|
||||
t.Errorf("repairArgs(%q): repaired=false, want true", tc.in)
|
||||
}
|
||||
if !json.Valid(got) {
|
||||
t.Errorf("repairArgs(%q): output not valid JSON: %q", tc.in, string(got))
|
||||
}
|
||||
if string(got) != tc.want {
|
||||
t.Errorf("repairArgs(%q) = %q, want %q", tc.in, string(got), tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairArgs_StripsCodeFences(t *testing.T) {
|
||||
cases := []string{
|
||||
"```json\n{\"path\":\"/x\"}\n```",
|
||||
"```\n{\"path\":\"/x\"}\n```",
|
||||
"```json\n{\"path\":\"/x\"}",
|
||||
" ```json {\"path\":\"/x\"} ``` ",
|
||||
}
|
||||
for _, in := range cases {
|
||||
got, repaired := repairArgs(in)
|
||||
if !repaired {
|
||||
t.Errorf("repairArgs(%q): repaired=false, want true", in)
|
||||
}
|
||||
if !json.Valid(got) {
|
||||
t.Errorf("repairArgs(%q): output not valid JSON: %q", in, string(got))
|
||||
}
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(got, &parsed); err != nil {
|
||||
t.Errorf("repairArgs(%q): unmarshal: %v", in, err)
|
||||
continue
|
||||
}
|
||||
if parsed["path"] != "/x" {
|
||||
t.Errorf("repairArgs(%q): lost data, got %v", in, parsed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairArgs_ExtractsFromProse(t *testing.T) {
|
||||
cases := []string{
|
||||
`Here are the arguments: {"path":"/x"}`,
|
||||
`{"path":"/x"} -- that's the call`,
|
||||
`Sure, calling with {"path":"/x"} now.`,
|
||||
}
|
||||
for _, in := range cases {
|
||||
got, repaired := repairArgs(in)
|
||||
if !repaired {
|
||||
t.Errorf("repairArgs(%q): repaired=false, want true", in)
|
||||
}
|
||||
if !json.Valid(got) {
|
||||
t.Errorf("repairArgs(%q): output not valid JSON: %q", in, string(got))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairArgs_HandlesBracesInsideStrings(t *testing.T) {
|
||||
in := `{"snippet":"if x { return y }","other":"a}b"}`
|
||||
got, _ := repairArgs(in)
|
||||
if !json.Valid(got) {
|
||||
t.Fatalf("output not valid JSON: %q", string(got))
|
||||
}
|
||||
var parsed map[string]string
|
||||
if err := json.Unmarshal(got, &parsed); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if parsed["snippet"] != "if x { return y }" {
|
||||
t.Errorf("snippet corrupted: %q", parsed["snippet"])
|
||||
}
|
||||
if parsed["other"] != "a}b" {
|
||||
t.Errorf("other corrupted: %q", parsed["other"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairArgs_TakesFirstBalancedBlock(t *testing.T) {
|
||||
// Some small models emit two JSON objects back-to-back; take the first.
|
||||
in := `{"path":"/a"} {"path":"/b"}`
|
||||
got, _ := repairArgs(in)
|
||||
if !json.Valid(got) {
|
||||
t.Fatalf("not valid: %q", string(got))
|
||||
}
|
||||
var parsed map[string]string
|
||||
if err := json.Unmarshal(got, &parsed); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if parsed["path"] != "/a" {
|
||||
t.Errorf("expected first block, got %q", parsed["path"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairArgs_UnrepairableFails(t *testing.T) {
|
||||
cases := []string{
|
||||
`{"a":`, // truncated
|
||||
`not json at all`, // no JSON
|
||||
`{{{`, // unbalanced
|
||||
`{"a":1`, // missing close
|
||||
}
|
||||
for _, in := range cases {
|
||||
got, repaired := repairArgs(in)
|
||||
// Either: returns valid JSON (we got lucky) or returns original + repaired=false
|
||||
if json.Valid(got) {
|
||||
continue // acceptable — we managed to repair
|
||||
}
|
||||
if repaired {
|
||||
t.Errorf("repairArgs(%q): claims repaired but output invalid: %q", in, string(got))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepairArgs_FencesAndTrailingCommaCombined(t *testing.T) {
|
||||
in := "```json\n{\"path\":\"/x\",}\n```"
|
||||
got, repaired := repairArgs(in)
|
||||
if !repaired {
|
||||
t.Fatal("expected repaired=true")
|
||||
}
|
||||
if !json.Valid(got) {
|
||||
t.Fatalf("not valid: %q", string(got))
|
||||
}
|
||||
var parsed map[string]string
|
||||
if err := json.Unmarshal(got, &parsed); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if parsed["path"] != "/x" {
|
||||
t.Errorf("lost data: %v", parsed)
|
||||
}
|
||||
}
|
||||
@@ -41,6 +41,33 @@ func newOpenAIStream(raw *ssestream.Stream[oai.ChatCompletionChunk]) *openaiStre
|
||||
}
|
||||
}
|
||||
|
||||
// flushNextToolCall returns the next pending tool-call Done event, applying
|
||||
// repairArgs to recover from small lexical mistakes that local-model servers
|
||||
// (Ollama, llama.cpp, llamafile) routinely emit: markdown fences, trailing
|
||||
// commas, prose-wrapped objects. The bool return is false once the queue is
|
||||
// empty.
|
||||
func (s *openaiStream) flushNextToolCall() (stream.Event, bool) {
|
||||
for idx, tc := range s.toolCalls {
|
||||
args, repaired := repairArgs(tc.args)
|
||||
if repaired {
|
||||
slog.Debug("openai: repaired malformed tool-call arguments",
|
||||
"tool", tc.name,
|
||||
"raw_len", len(tc.args),
|
||||
"repaired_len", len(args),
|
||||
)
|
||||
}
|
||||
ev := stream.Event{
|
||||
Type: stream.EventToolCallDone,
|
||||
ToolCallID: tc.id,
|
||||
ToolCallName: unsanitizeToolName(tc.name),
|
||||
Args: args,
|
||||
}
|
||||
delete(s.toolCalls, idx)
|
||||
return ev, true
|
||||
}
|
||||
return stream.Event{}, false
|
||||
}
|
||||
|
||||
func (s *openaiStream) Next() bool {
|
||||
for s.raw.Next() {
|
||||
chunk := s.raw.Current()
|
||||
@@ -146,15 +173,9 @@ func (s *openaiStream) Next() bool {
|
||||
}
|
||||
}
|
||||
|
||||
// Stream ended — flush tool call Done events, then emit stop
|
||||
for idx, tc := range s.toolCalls {
|
||||
s.cur = stream.Event{
|
||||
Type: stream.EventToolCallDone,
|
||||
ToolCallID: tc.id,
|
||||
ToolCallName: unsanitizeToolName(tc.name),
|
||||
Args: json.RawMessage(tc.args),
|
||||
}
|
||||
delete(s.toolCalls, idx)
|
||||
// Stream ended — flush tool call Done events, then emit stop.
|
||||
if ev, ok := s.flushNextToolCall(); ok {
|
||||
s.cur = ev
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user