feat(discovery): add Pass 0 agent client + parser (Mistral)
This commit is contained in:
115
backend/internal/domain/discovery/agent_client.go
Normal file
115
backend/internal/domain/discovery/agent_client.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package discovery
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"marktvogt.de/backend/internal/pkg/ai"
|
||||
)
|
||||
|
||||
// AgentClient wraps the Mistral Pass 0 agent for discovery.
|
||||
type AgentClient struct {
|
||||
ai *ai.Client
|
||||
agentID string
|
||||
}
|
||||
|
||||
func NewAgentClient(aiClient *ai.Client, agentID string) *AgentClient {
|
||||
return &AgentClient{ai: aiClient, agentID: agentID}
|
||||
}
|
||||
|
||||
func (c *AgentClient) Enabled() bool {
|
||||
return c.ai != nil && c.agentID != ""
|
||||
}
|
||||
|
||||
// Discover runs Pass 0 for the given bucket. The agent's full instructions
|
||||
// are set in the Mistral console (see spec §6.2). We only inject the bucket
|
||||
// parameters here.
|
||||
func (c *AgentClient) Discover(ctx context.Context, b Bucket) (Pass0Response, error) {
|
||||
if !c.Enabled() {
|
||||
return Pass0Response{}, fmt.Errorf("discovery agent not configured")
|
||||
}
|
||||
prompt := fmt.Sprintf(
|
||||
"Bucket:\nland: %s\nregion: %s\njahr_monat: %s\n\nFinde alle Maerkte in diesem Bucket und antworte im vorgegebenen JSON-Format.",
|
||||
b.Land, b.Region, b.YearMonth,
|
||||
)
|
||||
result, err := c.ai.Pass0(ctx, c.agentID, prompt)
|
||||
if err != nil {
|
||||
return Pass0Response{}, fmt.Errorf("mistral pass0: %w", err)
|
||||
}
|
||||
return parsePass0Response(result.Content)
|
||||
}
|
||||
|
||||
func parsePass0Response(raw string) (Pass0Response, error) {
|
||||
cleaned := extractJSON(raw)
|
||||
cleaned = stripJSONComments(cleaned)
|
||||
var out Pass0Response
|
||||
if err := json.Unmarshal([]byte(cleaned), &out); err != nil {
|
||||
return Pass0Response{}, fmt.Errorf("unmarshal pass0: %w (raw first 500: %q)", err, truncate(raw, 500))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// --- JSON helpers (independent copy; logic mirrors domain/market/research.go).
|
||||
// Do not import from the market package — keeping packages decoupled.
|
||||
|
||||
func extractJSON(s string) string {
|
||||
start := strings.IndexByte(s, '{')
|
||||
if start < 0 {
|
||||
return s
|
||||
}
|
||||
s = s[start:]
|
||||
depth := 0
|
||||
for i := 0; i < len(s); i++ {
|
||||
switch s[i] {
|
||||
case '{':
|
||||
depth++
|
||||
case '}':
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return s[:i+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func stripJSONComments(s string) string {
|
||||
var result []byte
|
||||
inString := false
|
||||
escaped := false
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
if escaped {
|
||||
result = append(result, c)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if c == '\\' && inString {
|
||||
result = append(result, c)
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if c == '"' {
|
||||
inString = !inString
|
||||
result = append(result, c)
|
||||
continue
|
||||
}
|
||||
if !inString && c == '/' && i+1 < len(s) && s[i+1] == '/' {
|
||||
for i < len(s) && s[i] != '\n' {
|
||||
i++
|
||||
}
|
||||
continue
|
||||
}
|
||||
result = append(result, c)
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func truncate(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return s[:n]
|
||||
}
|
||||
74
backend/internal/domain/discovery/agent_client_test.go
Normal file
74
backend/internal/domain/discovery/agent_client_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package discovery
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParsePass0_Valid(t *testing.T) {
|
||||
raw := `{
|
||||
"bucket": {"land": "Deutschland", "region": "Bayern", "jahr_monat": "2026-09"},
|
||||
"recherche_datum": "2026-04-18",
|
||||
"quellen_gesamt": ["https://a.example", "https://b.example"],
|
||||
"maerkte": [
|
||||
{
|
||||
"markt_name": "Mittelaltermarkt Trostberg",
|
||||
"stadt": "Trostberg",
|
||||
"bundesland": "Bayern",
|
||||
"start_datum": "2026-09-12",
|
||||
"end_datum": "2026-09-14",
|
||||
"website": "https://trostberg.example",
|
||||
"quellen": ["https://a.example"],
|
||||
"extraktion": "verbatim",
|
||||
"hinweis": null
|
||||
}
|
||||
]
|
||||
}`
|
||||
got, err := parsePass0Response(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("parse err: %v", err)
|
||||
}
|
||||
if got.Bucket.Region != "Bayern" {
|
||||
t.Errorf("region = %q, want Bayern", got.Bucket.Region)
|
||||
}
|
||||
if len(got.Maerkte) != 1 || got.Maerkte[0].MarktName != "Mittelaltermarkt Trostberg" {
|
||||
t.Errorf("unexpected markets: %+v", got.Maerkte)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePass0_WithCommentsAndTrailingText(t *testing.T) {
|
||||
raw := `Here is the JSON:
|
||||
{
|
||||
"bucket": {"land": "Deutschland", "region": "Bayern", "jahr_monat": "2026-09"},
|
||||
// a comment the agent added
|
||||
"recherche_datum": "2026-04-18",
|
||||
"quellen_gesamt": [],
|
||||
"maerkte": []
|
||||
}
|
||||
end.`
|
||||
got, err := parsePass0Response(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("parse err: %v", err)
|
||||
}
|
||||
if got.Bucket.Region != "Bayern" || len(got.Maerkte) != 0 {
|
||||
t.Errorf("unexpected: %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePass0_Malformed(t *testing.T) {
|
||||
raw := `not JSON at all`
|
||||
if _, err := parsePass0Response(raw); err == nil {
|
||||
t.Error("expected error on non-JSON input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePass0_EmptyMaerkte(t *testing.T) {
|
||||
raw := `{"bucket":{"land":"Deutschland","region":"Bayern","jahr_monat":"2026-09"},"recherche_datum":"","quellen_gesamt":[],"maerkte":[]}`
|
||||
got, err := parsePass0Response(raw)
|
||||
if err != nil {
|
||||
t.Fatalf("parse err: %v", err)
|
||||
}
|
||||
if got.Maerkte == nil {
|
||||
got.Maerkte = []Pass0Market{} // nil vs empty is fine
|
||||
}
|
||||
if len(got.Maerkte) != 0 {
|
||||
t.Errorf("expected empty, got %+v", got.Maerkte)
|
||||
}
|
||||
}
|
||||
@@ -112,6 +112,36 @@ func (c *Client) Pass1(ctx context.Context, prompt string) (PassResult, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Pass0 uses the Conversations API to call a discovery agent identified by agentID.
|
||||
// The agent ID is passed explicitly so the discovery domain can configure its own
|
||||
// agent independently of the agentSimple field used by Pass1.
|
||||
func (c *Client) Pass0(ctx context.Context, agentID, prompt string) (PassResult, error) {
|
||||
c.limiter.wait()
|
||||
if c.sdk == nil || agentID == "" {
|
||||
return PassResult{}, fmt.Errorf("pass0: ai client not configured (sdk=%v agentID=%q)", c.sdk != nil, agentID)
|
||||
}
|
||||
storeFalse := false
|
||||
resp, err := c.sdk.StartConversation(ctx, &conversation.StartRequest{
|
||||
AgentID: agentID,
|
||||
Inputs: conversation.TextInputs(prompt),
|
||||
Store: &storeFalse,
|
||||
})
|
||||
if err != nil {
|
||||
return PassResult{}, fmt.Errorf("pass0 conversation: %w", err)
|
||||
}
|
||||
|
||||
content := extractConversationContent(resp)
|
||||
if content == "" {
|
||||
return PassResult{}, fmt.Errorf("pass0: no assistant message in response")
|
||||
}
|
||||
|
||||
return PassResult{
|
||||
Content: content,
|
||||
Usage: convertConvUsage(resp.Usage),
|
||||
Model: "agent:" + agentID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Pass2 uses chat completions for description generation + retry fields.
|
||||
func (c *Client) Pass2(ctx context.Context, systemPrompt, userPrompt string) (PassResult, error) {
|
||||
c.limiter.wait()
|
||||
|
||||
Reference in New Issue
Block a user