From bf05a5866bae0698ed2ce662e65a70528dc85d2e Mon Sep 17 00:00:00 2001 From: vikingowl Date: Tue, 19 May 2026 17:59:05 +0200 Subject: [PATCH] feat(openai): lexical repair for malformed tool-call arguments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Local-model servers (Ollama, llama.cpp, llamafile) routed through the OpenAI-compatible path frequently emit tool-call arguments that are *almost* valid JSON — wrapped in markdown fences, padded with prose, or trailing a stray comma. Strict parsing fails, the engine receives empty args, and the agent loop has to retry or escalate. Adds repairArgs(raw) at the EventToolCallDone boundary: strict-parse first, then apply cheap lexical fixes (strip ```json fences, drop trailing commas before }/], extract the first balanced {...} block with proper string/escape awareness). On success, the repaired bytes flow through unchanged; on failure, the original is returned and downstream parsing surfaces the error as before. Frontier providers (OpenAI proper, Anthropic, Mistral, Google) are unaffected — their SDKs return structured args that pass strict parse. The repair only does work when the upstream output is malformed. 11 unit tests cover: valid passthrough, empty, trailing commas, single/double-line fences, prose-wrapped, braces-inside-strings, multiple top-level objects (takes the first), and unrepairable input. A stream-level test verifies the wiring through flushNextToolCall. 
---
 internal/provider/openai/repair.go      | 169 ++++++++++++++++++
 internal/provider/openai/repair_test.go | 231 ++++++++++++++++++++++++
 internal/provider/openai/stream.go      |  39 +++-
 3 files changed, 430 insertions(+), 9 deletions(-)
 create mode 100644 internal/provider/openai/repair.go
 create mode 100644 internal/provider/openai/repair_test.go

diff --git a/internal/provider/openai/repair.go b/internal/provider/openai/repair.go
new file mode 100644
index 0000000..8ff439d
--- /dev/null
+++ b/internal/provider/openai/repair.go
@@ -0,0 +1,169 @@
+package openai
+
+import (
+	"encoding/json"
+	"regexp"
+	"strings"
+)
+
+// repairArgs accepts a string of (possibly malformed) tool-call arguments and
+// returns valid JSON when small lexical fixes can recover it. The bool return
+// reports whether a repair was applied.
+//
+// Small local models served via OpenAI-compatible endpoints (Ollama,
+// llama.cpp, llamafile) frequently emit args wrapped in markdown fences,
+// surrounded by prose, or with trailing commas. Repairing these here keeps
+// the downstream agent loop from failing on cosmetic noise.
+//
+// Repair steps, after a strict json.Valid check that returns valid input
+// untouched:
+//   - Trim trailing commas before `}` or `]`.
+//   - Strip ```json / ``` code fences.
+//   - Extract the first balanced {...} block (respects strings/escapes).
+//
+// Each step is tried on its own before it is combined with the others, and
+// the first candidate that parses wins. A payload that only needs a trailing
+// comma removed is therefore never run through fence stripping, which would
+// mangle string values that themselves contain markdown fences.
+//
+// If no candidate is valid JSON, returns the original input bytes and
+// repaired=false so callers can surface the parse error normally.
+func repairArgs(raw string) (json.RawMessage, bool) {
+	if raw == "" || json.Valid([]byte(raw)) {
+		return json.RawMessage(raw), false
+	}
+
+	unfenced := stripCodeFences(raw)
+	object := extractFirstObject(raw)
+	unfencedObject := extractFirstObject(unfenced)
+
+	// Ordered from least to most invasive.
+	candidates := []string{
+		trimTrailingCommas(raw),
+		unfenced,
+		trimTrailingCommas(unfenced),
+		object,
+		trimTrailingCommas(object),
+		unfencedObject,
+		trimTrailingCommas(unfencedObject),
+	}
+	for _, c := range candidates {
+		if c == "" || c == raw {
+			continue
+		}
+		if json.Valid([]byte(c)) {
+			return json.RawMessage(c), true
+		}
+	}
+	return json.RawMessage(raw), false
+}
+
+// codeFenceRE matches a backtick-fenced block, optionally tagged ```json.
+// Captures the block's body in group 1.
+var codeFenceRE = regexp.MustCompile("(?s)```(?:json)?\\s*(.*?)\\s*```")
+
+// stripCodeFences pulls JSON out of a markdown code fence. If a complete
+// fenced block exists, returns its body; otherwise strips a leading or
+// trailing partial fence and returns the remainder trimmed.
+func stripCodeFences(s string) string {
+	if m := codeFenceRE.FindStringSubmatch(s); m != nil {
+		return strings.TrimSpace(m[1])
+	}
+	// Partial fence — strip leading ```json / ``` if present, and any
+	// trailing ``` even if no closing pair matched.
+	out := strings.TrimSpace(s)
+	out = strings.TrimPrefix(out, "```json")
+	out = strings.TrimPrefix(out, "```")
+	out = strings.TrimSuffix(out, "```")
+	return strings.TrimSpace(out)
+}
+
+// trimTrailingCommas removes JSON-illegal trailing commas: a comma followed
+// only by whitespace before a closing `}` or `]`. Commas inside string
+// literals are left alone, so a value such as "a,}" survives byte for byte
+// instead of being silently rewritten to "a}".
+func trimTrailingCommas(s string) string {
+	var b strings.Builder
+	b.Grow(len(s))
+	inString, escaped := false, false
+	for i := 0; i < len(s); i++ {
+		c := s[i]
+		switch {
+		case escaped:
+			escaped = false
+		case inString:
+			switch c {
+			case '\\':
+				escaped = true
+			case '"':
+				inString = false
+			}
+		case c == '"':
+			inString = true
+		case c == ',' && closerFollows(s[i+1:]):
+			continue // drop the illegal comma
+		}
+		b.WriteByte(c)
+	}
+	return b.String()
+}
+
+// closerFollows reports whether the first non-whitespace byte of s is a
+// closing `}` or `]`.
+func closerFollows(s string) bool {
+	rest := strings.TrimLeft(s, " \t\n\f\r")
+	return rest != "" && (rest[0] == '}' || rest[0] == ']')
+}
+
+// extractFirstObject walks s and returns the substring from the first `{`
+// to its matching `}`, respecting string boundaries and escapes. Returns
+// "" when no balanced object is found.
+//
+// The first pass treats double quotes in leading prose as string delimiters,
+// so a quoted "{...}" in prose is skipped. If that pass finds nothing (for
+// example because a stray quote in the prose swallowed the object), the walk
+// is retried from the first `{`, ignoring any quotes that precede it.
+func extractFirstObject(s string) string {
+	if obj := scanObject(s); obj != "" {
+		return obj
+	}
+	if i := strings.IndexByte(s, '{'); i > 0 {
+		return scanObject(s[i:])
+	}
+	return ""
+}
+
+// scanObject returns the first brace-balanced {...} block in s, tracking
+// JSON string and escape state from the first byte of s. Returns "" when no
+// block closes.
+func scanObject(s string) string {
+	start, depth := -1, 0
+	inString, escaped := false, false
+	for i := 0; i < len(s); i++ {
+		c := s[i]
+		switch {
+		case escaped:
+			escaped = false
+		case inString:
+			switch c {
+			case '\\':
+				escaped = true
+			case '"':
+				inString = false
+			}
+		case c == '"':
+			inString = true
+		case c == '{':
+			if depth == 0 {
+				start = i
+			}
+			depth++
+		case c == '}' && depth > 0:
+			depth--
+			if depth == 0 {
+				return s[start : i+1]
+			}
+		}
+	}
+	return ""
+}
diff --git a/internal/provider/openai/repair_test.go b/internal/provider/openai/repair_test.go
new file mode 100644
index 0000000..f8a2f87
--- /dev/null
+++ b/internal/provider/openai/repair_test.go
@@ -0,0 +1,231 @@
+package openai
+
+import (
+	"encoding/json"
+	"testing"
+
+	"somegit.dev/Owlibou/gnoma/internal/stream"
+)
+
+func TestOpenAIStream_FlushNextToolCall_RepairsArgs(t *testing.T) {
+	s := &openaiStream{
+		toolCalls: map[int64]*toolCallState{
+			0: {
+				id:   "call_1",
+				name: "fs.edit",
+				// Malformed: wrapped in code fence + trailing comma
+				args: "```json\n{\"path\":\"/x\",\"old_string\":\"a\",\"new_string\":\"b\",}\n```",
+			},
+		},
+	}
+
+	ev, ok := s.flushNextToolCall()
+	if !ok {
+		t.Fatal("flushNextToolCall returned ok=false with pending call")
+	}
+	if ev.Type != stream.EventToolCallDone {
+		t.Errorf("event type = %v, want EventToolCallDone", ev.Type)
+	}
+	if ev.ToolCallID != "call_1" {
+		t.Errorf("ToolCallID = %q", ev.ToolCallID)
+	}
+	if !json.Valid(ev.Args) {
+		t.Fatalf("Args is not valid JSON after repair: %q", string(ev.Args))
+	}
+	var parsed map[string]string
+	if err := json.Unmarshal(ev.Args, &parsed); err != nil
{ + t.Fatalf("unmarshal: %v", err) + } + if parsed["path"] != "/x" || parsed["old_string"] != "a" || parsed["new_string"] != "b" { + t.Errorf("data lost in repair: %v", parsed) + } + + // Second call: queue empty. + if _, ok := s.flushNextToolCall(); ok { + t.Error("flushNextToolCall returned ok=true on empty queue") + } +} + +func TestOpenAIStream_FlushNextToolCall_ValidArgsPassThrough(t *testing.T) { + original := `{"path":"/x"}` + s := &openaiStream{ + toolCalls: map[int64]*toolCallState{ + 0: {id: "call_1", name: "fs.read", args: original}, + }, + } + ev, ok := s.flushNextToolCall() + if !ok { + t.Fatal("flushNextToolCall returned ok=false") + } + if string(ev.Args) != original { + t.Errorf("valid args mutated: %q → %q", original, string(ev.Args)) + } +} + +func TestRepairArgs_ValidPassesThrough(t *testing.T) { + cases := []string{ + `{"path":"/foo.go"}`, + `{}`, + `{"a":1,"b":[1,2,3],"c":{"d":"e"}}`, + `{"text":"contains \"quoted\" inner"}`, + } + for _, in := range cases { + got, repaired := repairArgs(in) + if repaired { + t.Errorf("repairArgs(%q): repaired=true on valid input", in) + } + if string(got) != in { + t.Errorf("repairArgs(%q): mutated valid input → %q", in, string(got)) + } + if !json.Valid(got) { + t.Errorf("repairArgs(%q): output not valid JSON: %q", in, string(got)) + } + } +} + +func TestRepairArgs_EmptyInput(t *testing.T) { + got, repaired := repairArgs("") + if repaired { + t.Error("empty input should not be marked repaired") + } + if string(got) != "" { + t.Errorf("empty input → %q, want empty", string(got)) + } +} + +func TestRepairArgs_TrimsTrailingComma(t *testing.T) { + cases := []struct { + in, want string + }{ + {`{"a":1,}`, `{"a":1}`}, + {`{"a":1, "b":2,}`, `{"a":1, "b":2}`}, + {`{"a":[1,2,3,]}`, `{"a":[1,2,3]}`}, + {`{"a":1 , }`, `{"a":1 }`}, + } + for _, tc := range cases { + got, repaired := repairArgs(tc.in) + if !repaired { + t.Errorf("repairArgs(%q): repaired=false, want true", tc.in) + } + if !json.Valid(got) { + 
t.Errorf("repairArgs(%q): output not valid JSON: %q", tc.in, string(got)) + } + if string(got) != tc.want { + t.Errorf("repairArgs(%q) = %q, want %q", tc.in, string(got), tc.want) + } + } +} + +func TestRepairArgs_StripsCodeFences(t *testing.T) { + cases := []string{ + "```json\n{\"path\":\"/x\"}\n```", + "```\n{\"path\":\"/x\"}\n```", + "```json\n{\"path\":\"/x\"}", + " ```json {\"path\":\"/x\"} ``` ", + } + for _, in := range cases { + got, repaired := repairArgs(in) + if !repaired { + t.Errorf("repairArgs(%q): repaired=false, want true", in) + } + if !json.Valid(got) { + t.Errorf("repairArgs(%q): output not valid JSON: %q", in, string(got)) + } + var parsed map[string]any + if err := json.Unmarshal(got, &parsed); err != nil { + t.Errorf("repairArgs(%q): unmarshal: %v", in, err) + continue + } + if parsed["path"] != "/x" { + t.Errorf("repairArgs(%q): lost data, got %v", in, parsed) + } + } +} + +func TestRepairArgs_ExtractsFromProse(t *testing.T) { + cases := []string{ + `Here are the arguments: {"path":"/x"}`, + `{"path":"/x"} -- that's the call`, + `Sure, calling with {"path":"/x"} now.`, + } + for _, in := range cases { + got, repaired := repairArgs(in) + if !repaired { + t.Errorf("repairArgs(%q): repaired=false, want true", in) + } + if !json.Valid(got) { + t.Errorf("repairArgs(%q): output not valid JSON: %q", in, string(got)) + } + } +} + +func TestRepairArgs_HandlesBracesInsideStrings(t *testing.T) { + in := `{"snippet":"if x { return y }","other":"a}b"}` + got, _ := repairArgs(in) + if !json.Valid(got) { + t.Fatalf("output not valid JSON: %q", string(got)) + } + var parsed map[string]string + if err := json.Unmarshal(got, &parsed); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if parsed["snippet"] != "if x { return y }" { + t.Errorf("snippet corrupted: %q", parsed["snippet"]) + } + if parsed["other"] != "a}b" { + t.Errorf("other corrupted: %q", parsed["other"]) + } +} + +func TestRepairArgs_TakesFirstBalancedBlock(t *testing.T) { + // Some small 
models emit two JSON objects back-to-back; take the first. + in := `{"path":"/a"} {"path":"/b"}` + got, _ := repairArgs(in) + if !json.Valid(got) { + t.Fatalf("not valid: %q", string(got)) + } + var parsed map[string]string + if err := json.Unmarshal(got, &parsed); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if parsed["path"] != "/a" { + t.Errorf("expected first block, got %q", parsed["path"]) + } +} + +func TestRepairArgs_UnrepairableFails(t *testing.T) { + cases := []string{ + `{"a":`, // truncated + `not json at all`, // no JSON + `{{{`, // unbalanced + `{"a":1`, // missing close + } + for _, in := range cases { + got, repaired := repairArgs(in) + // Either: returns valid JSON (we got lucky) or returns original + repaired=false + if json.Valid(got) { + continue // acceptable — we managed to repair + } + if repaired { + t.Errorf("repairArgs(%q): claims repaired but output invalid: %q", in, string(got)) + } + } +} + +func TestRepairArgs_FencesAndTrailingCommaCombined(t *testing.T) { + in := "```json\n{\"path\":\"/x\",}\n```" + got, repaired := repairArgs(in) + if !repaired { + t.Fatal("expected repaired=true") + } + if !json.Valid(got) { + t.Fatalf("not valid: %q", string(got)) + } + var parsed map[string]string + if err := json.Unmarshal(got, &parsed); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if parsed["path"] != "/x" { + t.Errorf("lost data: %v", parsed) + } +} diff --git a/internal/provider/openai/stream.go b/internal/provider/openai/stream.go index 086849f..7b4c33b 100644 --- a/internal/provider/openai/stream.go +++ b/internal/provider/openai/stream.go @@ -41,6 +41,33 @@ func newOpenAIStream(raw *ssestream.Stream[oai.ChatCompletionChunk]) *openaiStre } } +// flushNextToolCall returns the next pending tool-call Done event, applying +// repairArgs to recover from small lexical mistakes that local-model servers +// (Ollama, llama.cpp, llamafile) routinely emit: markdown fences, trailing +// commas, prose-wrapped objects. 
+// The bool return is false once the queue is empty.
+//
+// Pending calls are flushed lowest index first. Returning whichever entry a
+// range over the map yields first would emit parallel tool calls in a
+// different order from run to run, because map iteration order is
+// unspecified.
+func (s *openaiStream) flushNextToolCall() (stream.Event, bool) {
+	var (
+		next  int64
+		found bool
+	)
+	for idx := range s.toolCalls {
+		if !found || idx < next {
+			next, found = idx, true
+		}
+	}
+	if !found {
+		return stream.Event{}, false
+	}
+	tc := s.toolCalls[next]
+	delete(s.toolCalls, next)
+
+	args, repaired := repairArgs(tc.args)
+	if repaired {
+		slog.Debug("openai: repaired malformed tool-call arguments",
+			"tool", tc.name,
+			"raw_len", len(tc.args),
+			"repaired_len", len(args),
+		)
+	}
+	return stream.Event{
+		Type:         stream.EventToolCallDone,
+		ToolCallID:   tc.id,
+		ToolCallName: unsanitizeToolName(tc.name),
+		Args:         args,
+	}, true
+}
+
 func (s *openaiStream) Next() bool {
 	for s.raw.Next() {
 		chunk := s.raw.Current()
@@ -146,15 +173,9 @@ func (s *openaiStream) Next() bool {
 		}
 	}
 
-	// Stream ended — flush tool call Done events, then emit stop
-	for idx, tc := range s.toolCalls {
-		s.cur = stream.Event{
-			Type:         stream.EventToolCallDone,
-			ToolCallID:   tc.id,
-			ToolCallName: unsanitizeToolName(tc.name),
-			Args:         json.RawMessage(tc.args),
-		}
-		delete(s.toolCalls, idx)
+	// Stream ended — flush tool call Done events, then emit stop.
+	if ev, ok := s.flushNextToolCall(); ok {
+		s.cur = ev
 		return true
 	}