diff --git a/conversation/conversation.go b/conversation/conversation.go new file mode 100644 index 0000000..9d02b24 --- /dev/null +++ b/conversation/conversation.go @@ -0,0 +1,182 @@ +package conversation + +import ( + "encoding/json" + "fmt" + + "somegit.dev/vikingowl/mistral-go-sdk/chat" +) + +// HandoffExecution controls tool call execution. +type HandoffExecution string + +const ( + HandoffClient HandoffExecution = "client" + HandoffServer HandoffExecution = "server" +) + +// Tool represents a conversation tool. +type Tool struct { + Type string `json:"type"` + Function *chat.Function `json:"function,omitempty"` + LibraryIDs []string `json:"library_ids,omitempty"` + ToolConfiguration *ToolConfig `json:"tool_configuration,omitempty"` +} + +// ToolConfig configures tool behavior. +type ToolConfig struct { + Exclude []string `json:"exclude,omitempty"` + Include []string `json:"include,omitempty"` + RequiresConfirmation []string `json:"requires_confirmation,omitempty"` +} + +// CompletionArgs holds optional completion parameters. +type CompletionArgs struct { + Stop []string `json:"stop,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + RandomSeed *int `json:"random_seed,omitempty"` + Prediction *chat.Prediction `json:"prediction,omitempty"` + ResponseFormat *chat.ResponseFormat `json:"response_format,omitempty"` + ToolChoice *chat.ToolChoiceMode `json:"tool_choice,omitempty"` +} + +// ToolCallConfirmation confirms or denies a pending tool call. +type ToolCallConfirmation struct { + ToolCallID string `json:"tool_call_id"` + Confirmation string `json:"confirmation"` // "allow" or "deny" +} + +// Inputs represents conversation inputs (text string or entry array). +type Inputs struct { + text *string + entries []Entry +} + +// TextInputs creates Inputs from a plain text string. +func TextInputs(s string) Inputs { return Inputs{text: &s} } + +// EntryInputs creates Inputs from entry objects. +func EntryInputs(entries ...Entry) Inputs { return Inputs{entries: entries} } + +func (i Inputs) MarshalJSON() ([]byte, error) { + if i.text != nil { + return json.Marshal(*i.text) + } + return json.Marshal(i.entries) +} + +// UsageInfo contains conversation token usage. +type UsageInfo struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + ConnectorTokens *int `json:"connector_tokens,omitempty"` + Connectors map[string]int `json:"connectors,omitempty"` +} + +// Response is the response from starting, appending, or restarting a conversation. +type Response struct { + Object string `json:"object"` + ConversationID string `json:"conversation_id"` + Outputs []Entry `json:"-"` + Usage UsageInfo `json:"usage"` + Guardrails json.RawMessage `json:"guardrails,omitempty"` +} + +func (r *Response) UnmarshalJSON(data []byte) error { + type alias Response + var raw struct { + *alias + Outputs []json.RawMessage `json:"outputs"` + } + raw.alias = (*alias)(r) + if err := json.Unmarshal(data, &raw); err != nil { + return fmt.Errorf("mistral: unmarshal conversation response: %w", err) + } + r.Outputs = make([]Entry, len(raw.Outputs)) + for i, o := range raw.Outputs { + entry, err := UnmarshalEntry(o) + if err != nil { + return err + } + r.Outputs[i] = entry + } + return nil +} + +// Conversation represents conversation metadata. +type Conversation struct { + Object string `json:"object"` + ID string `json:"id"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + Model string `json:"model,omitempty"` + AgentID string `json:"agent_id,omitempty"` + AgentVersion json.RawMessage `json:"agent_version,omitempty"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + CompletionArgs *CompletionArgs `json:"completion_args,omitempty"` +} + +// History is the response from getting conversation history. +type History struct { + Object string `json:"object"` + ConversationID string `json:"conversation_id"` + Entries []Entry `json:"-"` +} + +func (h *History) UnmarshalJSON(data []byte) error { + type alias History + var raw struct { + *alias + Entries []json.RawMessage `json:"entries"` + } + raw.alias = (*alias)(h) + if err := json.Unmarshal(data, &raw); err != nil { + return fmt.Errorf("mistral: unmarshal conversation history: %w", err) + } + h.Entries = make([]Entry, len(raw.Entries)) + for i, e := range raw.Entries { + entry, err := UnmarshalEntry(e) + if err != nil { + return err + } + h.Entries[i] = entry + } + return nil +} + +// Messages is the response from getting conversation messages. +type Messages struct { + Object string `json:"object"` + ConversationID string `json:"conversation_id"` + Messages []Entry `json:"-"` +} + +func (m *Messages) UnmarshalJSON(data []byte) error { + type alias Messages + var raw struct { + *alias + Messages []json.RawMessage `json:"messages"` + } + raw.alias = (*alias)(m) + if err := json.Unmarshal(data, &raw); err != nil { + return fmt.Errorf("mistral: unmarshal conversation messages: %w", err) + } + m.Messages = make([]Entry, len(raw.Messages)) + for i, msg := range raw.Messages { + entry, err := UnmarshalEntry(msg) + if err != nil { + return err + } + m.Messages[i] = entry + } + return nil +} diff --git a/conversation/entry.go b/conversation/entry.go new file mode 100644 index 0000000..eb7939c --- /dev/null +++ b/conversation/entry.go @@ -0,0 +1,161 @@ +package conversation + +import ( + "encoding/json" + "fmt" + "strings" +) + +// Entry is a sealed interface for conversation history entries. +type Entry interface { + entryType() string +} + +// MessageInputEntry represents a user or assistant input message. +type MessageInputEntry struct { + Object string `json:"object"` + ID string `json:"id"` + Type string `json:"type"` + CreatedAt string `json:"created_at"` + CompletedAt *string `json:"completed_at,omitempty"` + Role string `json:"role"` + Content json.RawMessage `json:"content"` + Prefix bool `json:"prefix,omitempty"` +} + +func (*MessageInputEntry) entryType() string { return "message.input" } + +// MessageOutputEntry represents an assistant output message. +type MessageOutputEntry struct { + Object string `json:"object"` + ID string `json:"id"` + Type string `json:"type"` + CreatedAt string `json:"created_at"` + CompletedAt *string `json:"completed_at,omitempty"` + Role string `json:"role"` + Content json.RawMessage `json:"content"` + AgentID *string `json:"agent_id,omitempty"` + Model *string `json:"model,omitempty"` +} + +func (*MessageOutputEntry) entryType() string { return "message.output" } + +// FunctionCallEntry represents a function call by the model. +type FunctionCallEntry struct { + Object string `json:"object"` + ID string `json:"id"` + Type string `json:"type"` + CreatedAt string `json:"created_at"` + CompletedAt *string `json:"completed_at,omitempty"` + ToolCallID string `json:"tool_call_id"` + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` + ConfirmationStatus *string `json:"confirmation_status,omitempty"` + AgentID *string `json:"agent_id,omitempty"` + Model *string `json:"model,omitempty"` +} + +func (*FunctionCallEntry) entryType() string { return "function.call" } + +// FunctionResultEntry represents a function result provided by the client. +type FunctionResultEntry struct { + Object string `json:"object"` + ID string `json:"id"` + Type string `json:"type"` + CreatedAt string `json:"created_at"` + CompletedAt *string `json:"completed_at,omitempty"` + ToolCallID string `json:"tool_call_id"` + Result string `json:"result"` +} + +func (*FunctionResultEntry) entryType() string { return "function.result" } + +// ToolExecutionEntry represents a built-in tool execution. +type ToolExecutionEntry struct { + Object string `json:"object"` + ID string `json:"id"` + Type string `json:"type"` + CreatedAt string `json:"created_at"` + CompletedAt *string `json:"completed_at,omitempty"` + Name string `json:"name"` + Arguments string `json:"arguments"` + Info map[string]any `json:"info,omitempty"` + AgentID *string `json:"agent_id,omitempty"` + Model *string `json:"model,omitempty"` +} + +func (*ToolExecutionEntry) entryType() string { return "tool.execution" } + +// AgentHandoffEntry represents an agent-to-agent handoff. +type AgentHandoffEntry struct { + Object string `json:"object"` + ID string `json:"id"` + Type string `json:"type"` + CreatedAt string `json:"created_at"` + CompletedAt *string `json:"completed_at,omitempty"` + PreviousAgentID string `json:"previous_agent_id"` + PreviousAgentName string `json:"previous_agent_name"` + NextAgentID string `json:"next_agent_id"` + NextAgentName string `json:"next_agent_name"` +} + +func (*AgentHandoffEntry) entryType() string { return "agent.handoff" } + +// UnmarshalEntry dispatches JSON to the concrete Entry type +// based on the "type" discriminator field. +func UnmarshalEntry(data []byte) (Entry, error) { + var probe struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &probe); err != nil { + return nil, fmt.Errorf("mistral: unmarshal entry: %w", err) + } + switch probe.Type { + case "message.input": + var e MessageInputEntry + return &e, json.Unmarshal(data, &e) + case "message.output": + var e MessageOutputEntry + return &e, json.Unmarshal(data, &e) + case "function.call": + var e FunctionCallEntry + return &e, json.Unmarshal(data, &e) + case "function.result": + var e FunctionResultEntry + return &e, json.Unmarshal(data, &e) + case "tool.execution": + var e ToolExecutionEntry + return &e, json.Unmarshal(data, &e) + case "agent.handoff": + var e AgentHandoffEntry + return &e, json.Unmarshal(data, &e) + default: + return nil, fmt.Errorf("mistral: unknown entry type: %q", probe.Type) + } +} + +// TextContent extracts text from a raw content field. +// Handles both string content and chunk arrays (extracts text chunks). +func TextContent(raw json.RawMessage) string { + if len(raw) == 0 { + return "" + } + var s string + if json.Unmarshal(raw, &s) == nil { + return s + } + var chunks []struct { + Type string `json:"type"` + Text string `json:"text"` + } + if json.Unmarshal(raw, &chunks) == nil { + var sb strings.Builder + for _, ch := range chunks { + if ch.Type == "text" { + sb.WriteString(ch.Text) + } + } + return sb.String() + } + return "" +} diff --git a/conversation/entry_test.go b/conversation/entry_test.go new file mode 100644 index 0000000..26be87c --- /dev/null +++ b/conversation/entry_test.go @@ -0,0 +1,145 @@ +package conversation + +import ( + "encoding/json" + "testing" +) + +func TestUnmarshalEntry_MessageInput(t *testing.T) { + data := []byte(`{"object":"entry","id":"e1","type":"message.input","created_at":"2024-01-01T00:00:00Z","role":"user","content":"Hello"}`) + entry, err := UnmarshalEntry(data) + if err != nil { + t.Fatal(err) + } + e, ok := entry.(*MessageInputEntry) + if !ok { + t.Fatalf("expected *MessageInputEntry, got %T", entry) + } + if e.ID != "e1" { + t.Errorf("got id %q", e.ID) + } + if e.Role != "user" { + t.Errorf("got role %q", e.Role) + } + if TextContent(e.Content) != "Hello" { + t.Errorf("got content %q", TextContent(e.Content)) + } +} + +func TestUnmarshalEntry_MessageOutput(t *testing.T) { + data := []byte(`{"object":"entry","id":"e2","type":"message.output","created_at":"2024-01-01T00:00:00Z","role":"assistant","content":"Hi there!"}`) + entry, err := UnmarshalEntry(data) + if err != nil { + t.Fatal(err) + } + e, ok := entry.(*MessageOutputEntry) + if !ok { + t.Fatalf("expected *MessageOutputEntry, got %T", entry) + } + if TextContent(e.Content) != "Hi there!" { + t.Errorf("got content %q", TextContent(e.Content)) + } +} + +func TestUnmarshalEntry_FunctionCall(t *testing.T) { + data := []byte(`{"object":"entry","id":"e3","type":"function.call","created_at":"2024-01-01T00:00:00Z","tool_call_id":"tc1","name":"get_weather","arguments":"{\"city\":\"Paris\"}"}`) + entry, err := UnmarshalEntry(data) + if err != nil { + t.Fatal(err) + } + e, ok := entry.(*FunctionCallEntry) + if !ok { + t.Fatalf("expected *FunctionCallEntry, got %T", entry) + } + if e.Name != "get_weather" { + t.Errorf("got name %q", e.Name) + } + if e.ToolCallID != "tc1" { + t.Errorf("got tool_call_id %q", e.ToolCallID) + } +} + +func TestUnmarshalEntry_FunctionResult(t *testing.T) { + data := []byte(`{"object":"entry","id":"e4","type":"function.result","created_at":"2024-01-01T00:00:00Z","tool_call_id":"tc1","result":"{\"temp\":22}"}`) + entry, err := UnmarshalEntry(data) + if err != nil { + t.Fatal(err) + } + e, ok := entry.(*FunctionResultEntry) + if !ok { + t.Fatalf("expected *FunctionResultEntry, got %T", entry) + } + if e.Result != `{"temp":22}` { + t.Errorf("got result %q", e.Result) + } +} + +func TestUnmarshalEntry_ToolExecution(t *testing.T) { + data := []byte(`{"object":"entry","id":"e5","type":"tool.execution","created_at":"2024-01-01T00:00:00Z","name":"web_search","arguments":"query"}`) + entry, err := UnmarshalEntry(data) + if err != nil { + t.Fatal(err) + } + e, ok := entry.(*ToolExecutionEntry) + if !ok { + t.Fatalf("expected *ToolExecutionEntry, got %T", entry) + } + if e.Name != "web_search" { + t.Errorf("got name %q", e.Name) + } +} + +func TestUnmarshalEntry_AgentHandoff(t *testing.T) { + data := []byte(`{"object":"entry","id":"e6","type":"agent.handoff","created_at":"2024-01-01T00:00:00Z","previous_agent_id":"a1","previous_agent_name":"Agent A","next_agent_id":"a2","next_agent_name":"Agent B"}`) + entry, err := UnmarshalEntry(data) + if err != nil { + t.Fatal(err) + } + e, ok := entry.(*AgentHandoffEntry) + if !ok { + t.Fatalf("expected *AgentHandoffEntry, got %T", entry) + } + if e.PreviousAgentName != "Agent A" { + t.Errorf("got prev %q", e.PreviousAgentName) + } + if e.NextAgentName != "Agent B" { + t.Errorf("got next %q", e.NextAgentName) + } +} + +func TestUnmarshalEntry_Unknown(t *testing.T) { + _, err := UnmarshalEntry([]byte(`{"type":"unknown.type"}`)) + if err == nil { + t.Error("expected error for unknown type") + } +} + +func TestTextContent_String(t *testing.T) { + raw := json.RawMessage(`"Hello world"`) + if TextContent(raw) != "Hello world" { + t.Errorf("got %q", TextContent(raw)) + } +} + +func TestTextContent_ChunkArray(t *testing.T) { + raw := json.RawMessage(`[{"type":"text","text":"Hello "},{"type":"text","text":"world"}]`) + if TextContent(raw) != "Hello world" { + t.Errorf("got %q", TextContent(raw)) + } +} + +func TestTextContent_Empty(t *testing.T) { + if TextContent(nil) != "" { + t.Error("expected empty for nil") + } + if TextContent(json.RawMessage{}) != "" { + t.Error("expected empty for empty") + } +} + +func TestTextContent_MixedChunks(t *testing.T) { + raw := json.RawMessage(`[{"type":"text","text":"Hello"},{"type":"tool_reference","tool":"web_search","title":"Result"},{"type":"text","text":" world"}]`) + if TextContent(raw) != "Hello world" { + t.Errorf("got %q", TextContent(raw)) + } +} diff --git a/conversation/event.go b/conversation/event.go new file mode 100644 index 0000000..330820d --- /dev/null +++ b/conversation/event.go @@ -0,0 +1,177 @@ +package conversation + +import ( + "encoding/json" + "fmt" +) + +// Event is a sealed interface for conversation streaming events. +type Event interface { + eventType() string +} + +// ResponseStartedEvent signals the start of a conversation response. +type ResponseStartedEvent struct { + Type string `json:"type"` + CreatedAt string `json:"created_at"` + ConversationID string `json:"conversation_id"` +} + +func (*ResponseStartedEvent) eventType() string { return "conversation.response.started" } + +// ResponseDoneEvent signals the completion of a conversation response. +type ResponseDoneEvent struct { + Type string `json:"type"` + CreatedAt string `json:"created_at"` + Usage UsageInfo `json:"usage"` +} + +func (*ResponseDoneEvent) eventType() string { return "conversation.response.done" } + +// ResponseErrorEvent signals an error during conversation processing. +type ResponseErrorEvent struct { + Type string `json:"type"` + CreatedAt string `json:"created_at"` + Message string `json:"message"` + Code int `json:"code"` +} + +func (*ResponseErrorEvent) eventType() string { return "conversation.response.error" } + +// MessageOutputEvent contains a delta of assistant message output. +type MessageOutputEvent struct { + Type string `json:"type"` + CreatedAt string `json:"created_at"` + OutputIndex int `json:"output_index"` + ID string `json:"id"` + ContentIndex int `json:"content_index"` + Content json.RawMessage `json:"content"` + Model *string `json:"model,omitempty"` + AgentID *string `json:"agent_id,omitempty"` + Role string `json:"role"` +} + +func (*MessageOutputEvent) eventType() string { return "message.output.delta" } + +// ToolExecutionStartedEvent signals the start of a tool execution. +type ToolExecutionStartedEvent struct { + Type string `json:"type"` + CreatedAt string `json:"created_at"` + OutputIndex int `json:"output_index"` + ID string `json:"id"` + Name string `json:"name"` + Arguments string `json:"arguments"` + Model *string `json:"model,omitempty"` + AgentID *string `json:"agent_id,omitempty"` +} + +func (*ToolExecutionStartedEvent) eventType() string { return "tool.execution.started" } + +// ToolExecutionDeltaEvent contains a delta of tool execution output. +type ToolExecutionDeltaEvent struct { + Type string `json:"type"` + CreatedAt string `json:"created_at"` + OutputIndex int `json:"output_index"` + ID string `json:"id"` + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +func (*ToolExecutionDeltaEvent) eventType() string { return "tool.execution.delta" } + +// ToolExecutionDoneEvent signals the completion of a tool execution. +type ToolExecutionDoneEvent struct { + Type string `json:"type"` + CreatedAt string `json:"created_at"` + OutputIndex int `json:"output_index"` + ID string `json:"id"` + Name string `json:"name"` + Info map[string]any `json:"info,omitempty"` +} + +func (*ToolExecutionDoneEvent) eventType() string { return "tool.execution.done" } + +// FunctionCallEvent contains a delta of a function call. +type FunctionCallEvent struct { + Type string `json:"type"` + CreatedAt string `json:"created_at"` + OutputIndex int `json:"output_index"` + ID string `json:"id"` + Name string `json:"name"` + ToolCallID string `json:"tool_call_id"` + Arguments string `json:"arguments"` + ConfirmationStatus *string `json:"confirmation_status,omitempty"` + Model *string `json:"model,omitempty"` + AgentID *string `json:"agent_id,omitempty"` +} + +func (*FunctionCallEvent) eventType() string { return "function.call.delta" } + +// AgentHandoffStartedEvent signals the start of an agent handoff. +type AgentHandoffStartedEvent struct { + Type string `json:"type"` + CreatedAt string `json:"created_at"` + OutputIndex int `json:"output_index"` + ID string `json:"id"` + PreviousAgentID string `json:"previous_agent_id"` + PreviousAgentName string `json:"previous_agent_name"` +} + +func (*AgentHandoffStartedEvent) eventType() string { return "agent.handoff.started" } + +// AgentHandoffDoneEvent signals the completion of an agent handoff. +type AgentHandoffDoneEvent struct { + Type string `json:"type"` + CreatedAt string `json:"created_at"` + OutputIndex int `json:"output_index"` + ID string `json:"id"` + NextAgentID string `json:"next_agent_id"` + NextAgentName string `json:"next_agent_name"` +} + +func (*AgentHandoffDoneEvent) eventType() string { return "agent.handoff.done" } + +// UnmarshalEvent dispatches JSON to the concrete Event type +// based on the "type" discriminator field. +func UnmarshalEvent(data []byte) (Event, error) { + var probe struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &probe); err != nil { + return nil, fmt.Errorf("mistral: unmarshal event: %w", err) + } + switch probe.Type { + case "conversation.response.started": + var e ResponseStartedEvent + return &e, json.Unmarshal(data, &e) + case "conversation.response.done": + var e ResponseDoneEvent + return &e, json.Unmarshal(data, &e) + case "conversation.response.error": + var e ResponseErrorEvent + return &e, json.Unmarshal(data, &e) + case "message.output.delta": + var e MessageOutputEvent + return &e, json.Unmarshal(data, &e) + case "tool.execution.started": + var e ToolExecutionStartedEvent + return &e, json.Unmarshal(data, &e) + case "tool.execution.delta": + var e ToolExecutionDeltaEvent + return &e, json.Unmarshal(data, &e) + case "tool.execution.done": + var e ToolExecutionDoneEvent + return &e, json.Unmarshal(data, &e) + case "function.call.delta": + var e FunctionCallEvent + return &e, json.Unmarshal(data, &e) + case "agent.handoff.started": + var e AgentHandoffStartedEvent + return &e, json.Unmarshal(data, &e) + case "agent.handoff.done": + var e AgentHandoffDoneEvent + return &e, json.Unmarshal(data, &e) + default: + return nil, fmt.Errorf("mistral: unknown event type: %q", probe.Type) + } +} diff --git a/conversation/event_test.go b/conversation/event_test.go new file mode 100644 index 0000000..7a9c69b --- /dev/null +++ b/conversation/event_test.go @@ -0,0 +1,163 @@ +package conversation + +import ( + "encoding/json" + "testing" +) + +func TestUnmarshalEvent_ResponseStarted(t *testing.T) { + data := []byte(`{"type":"conversation.response.started","created_at":"2024-01-01T00:00:00Z","conversation_id":"conv-123"}`) + event, err := UnmarshalEvent(data) + if err != nil { + t.Fatal(err) + } + e, ok := event.(*ResponseStartedEvent) + if !ok { + t.Fatalf("expected *ResponseStartedEvent, got %T", event) + } + if e.ConversationID != "conv-123" { + t.Errorf("got %q", e.ConversationID) + } +} + +func TestUnmarshalEvent_ResponseDone(t *testing.T) { + data := []byte(`{"type":"conversation.response.done","created_at":"2024-01-01T00:00:00Z","usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}`) + event, err := UnmarshalEvent(data) + if err != nil { + t.Fatal(err) + } + e, ok := event.(*ResponseDoneEvent) + if !ok { + t.Fatalf("expected *ResponseDoneEvent, got %T", event) + } + if e.Usage.TotalTokens != 15 { + t.Errorf("got total_tokens %d", e.Usage.TotalTokens) + } +} + +func TestUnmarshalEvent_ResponseError(t *testing.T) { + data := []byte(`{"type":"conversation.response.error","created_at":"2024-01-01T00:00:00Z","message":"error occurred","code":500}`) + event, err := UnmarshalEvent(data) + if err != nil { + t.Fatal(err) + } + e, ok := event.(*ResponseErrorEvent) + if !ok { + t.Fatalf("expected *ResponseErrorEvent, got %T", event) + } + if e.Message != "error occurred" { + t.Errorf("got %q", e.Message) + } + if e.Code != 500 { + t.Errorf("got code %d", e.Code) + } +} + +func TestUnmarshalEvent_MessageOutput(t *testing.T) { + data := []byte(`{"type":"message.output.delta","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"m1","content_index":0,"content":"Hello","role":"assistant"}`) + event, err := UnmarshalEvent(data) + if err != nil { + t.Fatal(err) + } + e, ok := event.(*MessageOutputEvent) + if !ok { + t.Fatalf("expected *MessageOutputEvent, got %T", event) + } + if e.ID != "m1" { + t.Errorf("got id %q", e.ID) + } + if TextContent(e.Content) != "Hello" { + t.Errorf("got content %q", TextContent(e.Content)) + } +} + +func TestUnmarshalEvent_FunctionCall(t *testing.T) { + data := []byte(`{"type":"function.call.delta","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"fc1","name":"search","tool_call_id":"tc1","arguments":"{\"q\":\"test\"}"}`) + event, err := UnmarshalEvent(data) + if err != nil { + t.Fatal(err) + } + e, ok := event.(*FunctionCallEvent) + if !ok { + t.Fatalf("expected *FunctionCallEvent, got %T", event) + } + if e.Name != "search" { + t.Errorf("got name %q", e.Name) + } + if e.ToolCallID != "tc1" { + t.Errorf("got tool_call_id %q", e.ToolCallID) + } +} + +func TestUnmarshalEvent_ToolExecutionStarted(t *testing.T) { + data := []byte(`{"type":"tool.execution.started","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"te1","name":"web_search","arguments":"query"}`) + event, err := UnmarshalEvent(data) + if err != nil { + t.Fatal(err) + } + e, ok := event.(*ToolExecutionStartedEvent) + if !ok { + t.Fatalf("expected *ToolExecutionStartedEvent, got %T", event) + } + if e.Name != "web_search" { + t.Errorf("got %q", e.Name) + } +} + +func TestUnmarshalEvent_ToolExecutionDone(t *testing.T) { + data := []byte(`{"type":"tool.execution.done","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"te1","name":"web_search"}`) + event, err := UnmarshalEvent(data) + if err != nil { + t.Fatal(err) + } + _, ok := event.(*ToolExecutionDoneEvent) + if !ok { + t.Fatalf("expected *ToolExecutionDoneEvent, got %T", event) + } +} + +func TestUnmarshalEvent_AgentHandoff(t *testing.T) { + started := []byte(`{"type":"agent.handoff.started","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"h1","previous_agent_id":"a1","previous_agent_name":"A"}`) + event, err := UnmarshalEvent(started) + if err != nil { + t.Fatal(err) + } + hs, ok := event.(*AgentHandoffStartedEvent) + if !ok { + t.Fatalf("expected *AgentHandoffStartedEvent, got %T", event) + } + if hs.PreviousAgentID != "a1" { + t.Errorf("got %q", hs.PreviousAgentID) + } + + done := []byte(`{"type":"agent.handoff.done","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"h1","next_agent_id":"a2","next_agent_name":"B"}`) + event, err = UnmarshalEvent(done) + if err != nil { + t.Fatal(err) + } + hd, ok := event.(*AgentHandoffDoneEvent) + if !ok { + t.Fatalf("expected *AgentHandoffDoneEvent, got %T", event) + } + if hd.NextAgentID != "a2" { + t.Errorf("got %q", hd.NextAgentID) + } +} + +func TestUnmarshalEvent_Unknown(t *testing.T) { + _, err := UnmarshalEvent([]byte(`{"type":"unknown.event"}`)) + if err == nil { + t.Error("expected error for unknown type") + } +} + +func TestInputs_TextMarshal(t *testing.T) { + inputs := TextInputs("Hello") + data, err := json.Marshal(inputs) + if err != nil { + t.Fatal(err) + } + if string(data) != `"Hello"` { + t.Errorf("got %s", data) + } +} diff --git a/conversation/request.go b/conversation/request.go new file mode 100644 index 0000000..1fe9389 --- /dev/null +++ b/conversation/request.go @@ -0,0 +1,81 @@ +package conversation + +import "encoding/json" + +// StartRequest starts a new conversation. +type StartRequest struct { + Inputs Inputs `json:"inputs"` + Model string `json:"model,omitempty"` + AgentID string `json:"agent_id,omitempty"` + AgentVersion json.RawMessage `json:"agent_version,omitempty"` + Instructions *string `json:"instructions,omitempty"` + Tools []Tool `json:"tools,omitempty"` + CompletionArgs *CompletionArgs `json:"completion_args,omitempty"` + Store *bool `json:"store,omitempty"` + HandoffExecution *HandoffExecution `json:"handoff_execution,omitempty"` + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + stream bool +} + +func (r *StartRequest) SetStream(v bool) { r.stream = v } + +func (r *StartRequest) MarshalJSON() ([]byte, error) { + type Alias StartRequest + return json.Marshal(&struct { + Stream bool `json:"stream"` + *Alias + }{ + Stream: r.stream, + Alias: (*Alias)(r), + }) +} + +// AppendRequest appends to an existing conversation. +type AppendRequest struct { + Inputs Inputs `json:"inputs"` + CompletionArgs *CompletionArgs `json:"completion_args,omitempty"` + Store *bool `json:"store,omitempty"` + HandoffExecution *HandoffExecution `json:"handoff_execution,omitempty"` + ToolConfirmations []ToolCallConfirmation `json:"tool_confirmations,omitempty"` + stream bool +} + +func (r *AppendRequest) SetStream(v bool) { r.stream = v } + +func (r *AppendRequest) MarshalJSON() ([]byte, error) { + type Alias AppendRequest + return json.Marshal(&struct { + Stream bool `json:"stream"` + *Alias + }{ + Stream: r.stream, + Alias: (*Alias)(r), + }) +} + +// RestartRequest restarts a conversation from a specific entry. +type RestartRequest struct { + Inputs Inputs `json:"inputs"` + FromEntryID string `json:"from_entry_id"` + CompletionArgs *CompletionArgs `json:"completion_args,omitempty"` + Store *bool `json:"store,omitempty"` + HandoffExecution *HandoffExecution `json:"handoff_execution,omitempty"` + AgentVersion json.RawMessage `json:"agent_version,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + stream bool +} + +func (r *RestartRequest) SetStream(v bool) { r.stream = v } + +func (r *RestartRequest) MarshalJSON() ([]byte, error) { + type Alias RestartRequest + return json.Marshal(&struct { + Stream bool `json:"stream"` + *Alias + }{ + Stream: r.stream, + Alias: (*Alias)(r), + }) +} diff --git a/conversations.go b/conversations.go new file mode 100644 index 0000000..1ae5836 --- /dev/null +++ b/conversations.go @@ -0,0 +1,160 @@ +package mistral + +import ( + "context" + "encoding/json" + "fmt" + + "somegit.dev/vikingowl/mistral-go-sdk/conversation" +) + +// StartConversation creates and starts a new conversation. +func (c *Client) StartConversation(ctx context.Context, req *conversation.StartRequest) (*conversation.Response, error) { + var resp conversation.Response + if err := c.doJSON(ctx, "POST", "/v1/conversations", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// StartConversationStream creates a conversation and returns a stream of events. +func (c *Client) StartConversationStream(ctx context.Context, req *conversation.StartRequest) (*EventStream, error) { + req.SetStream(true) + resp, err := c.doStream(ctx, "POST", "/v1/conversations", req) + if err != nil { + return nil, err + } + return newEventStream(resp.Body), nil +} + +// AppendConversation appends inputs to an existing conversation. +func (c *Client) AppendConversation(ctx context.Context, conversationID string, req *conversation.AppendRequest) (*conversation.Response, error) { + var resp conversation.Response + if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/conversations/%s", conversationID), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// AppendConversationStream appends to a conversation and returns a stream of events. +func (c *Client) AppendConversationStream(ctx context.Context, conversationID string, req *conversation.AppendRequest) (*EventStream, error) { + req.SetStream(true) + resp, err := c.doStream(ctx, "POST", fmt.Sprintf("/v1/conversations/%s", conversationID), req) + if err != nil { + return nil, err + } + return newEventStream(resp.Body), nil +} + +// RestartConversation restarts a conversation from a specific entry. +func (c *Client) RestartConversation(ctx context.Context, conversationID string, req *conversation.RestartRequest) (*conversation.Response, error) { + var resp conversation.Response + if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/conversations/%s/restart", conversationID), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// RestartConversationStream restarts a conversation and returns a stream of events. +func (c *Client) RestartConversationStream(ctx context.Context, conversationID string, req *conversation.RestartRequest) (*EventStream, error) { + req.SetStream(true) + resp, err := c.doStream(ctx, "POST", fmt.Sprintf("/v1/conversations/%s/restart", conversationID), req) + if err != nil { + return nil, err + } + return newEventStream(resp.Body), nil +} + +// GetConversation retrieves conversation metadata. +func (c *Client) GetConversation(ctx context.Context, conversationID string) (*conversation.Conversation, error) { + var resp conversation.Conversation + if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/conversations/%s", conversationID), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// ListConversations lists conversations with optional pagination. +func (c *Client) ListConversations(ctx context.Context) ([]conversation.Conversation, error) { + var resp []conversation.Conversation + if err := c.doJSON(ctx, "GET", "/v1/conversations", nil, &resp); err != nil { + return nil, err + } + return resp, nil +} + +// DeleteConversation deletes a conversation. +func (c *Client) DeleteConversation(ctx context.Context, conversationID string) error { + resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/conversations/%s", conversationID), nil) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode >= 400 { + return parseAPIError(resp) + } + return nil +} + +// GetConversationHistory returns the full history of a conversation. +func (c *Client) GetConversationHistory(ctx context.Context, conversationID string) (*conversation.History, error) { + var resp conversation.History + if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/conversations/%s/history", conversationID), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// GetConversationMessages returns the messages of a conversation. +func (c *Client) GetConversationMessages(ctx context.Context, conversationID string) (*conversation.Messages, error) { + var resp conversation.Messages + if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/conversations/%s/messages", conversationID), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// EventStream wraps the generic Stream to provide typed conversation events. +type EventStream struct { + stream *Stream[json.RawMessage] + event conversation.Event + err error +} + +func newEventStream(body readCloser) *EventStream { + return &EventStream{ + stream: newStream[json.RawMessage](body), + } +} + +// Next advances to the next event. Returns false when done or on error. +func (s *EventStream) Next() bool { + if s.err != nil { + return false + } + if !s.stream.Next() { + s.err = s.stream.Err() + return false + } + event, err := conversation.UnmarshalEvent(s.stream.Current()) + if err != nil { + s.err = err + return false + } + s.event = event + return true +} + +// Current returns the most recently read event. +func (s *EventStream) Current() conversation.Event { return s.event } + +// Err returns any error encountered during streaming. +func (s *EventStream) Err() error { return s.err } + +// Close releases the underlying connection. +func (s *EventStream) Close() error { return s.stream.Close() } + +type readCloser = interface { + Read(p []byte) (n int, err error) + Close() error +} diff --git a/conversations_test.go b/conversations_test.go new file mode 100644 index 0000000..738c742 --- /dev/null +++ b/conversations_test.go @@ -0,0 +1,249 @@ +package mistral + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "somegit.dev/vikingowl/mistral-go-sdk/conversation" +) + +func TestStartConversation_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/conversations" { + t.Errorf("got path %s", r.URL.Path) + } + if r.Method != "POST" { + t.Errorf("got method %s", r.Method) + } + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["inputs"] != "Hello" { + t.Errorf("expected inputs=Hello, got %v", body["inputs"]) + } + if body["model"] != "mistral-small-latest" { + t.Errorf("expected model, got %v", body["model"]) + } + + json.NewEncoder(w).Encode(map[string]any{ + "object": "conversation.response", + "conversation_id": "conv-123", + "outputs": []map[string]any{{ + "object": "entry", "id": "e1", "type": "message.output", + "created_at": "2024-01-01T00:00:00Z", + "role": "assistant", "content": "Hello! How can I help?", + }}, + "usage": map[string]any{ + "prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18, + }, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.StartConversation(context.Background(), &conversation.StartRequest{ + Inputs: conversation.TextInputs("Hello"), + Model: "mistral-small-latest", + }) + if err != nil { + t.Fatal(err) + } + if resp.ConversationID != "conv-123" { + t.Errorf("got conv id %q", resp.ConversationID) + } + if len(resp.Outputs) != 1 { + t.Fatalf("got %d outputs", len(resp.Outputs)) + } + out, ok := resp.Outputs[0].(*conversation.MessageOutputEntry) + if !ok { + t.Fatalf("expected *MessageOutputEntry, got %T", resp.Outputs[0]) + } + if conversation.TextContent(out.Content) != "Hello! How can I help?" { + t.Errorf("got %q", conversation.TextContent(out.Content)) + } + if resp.Usage.TotalTokens != 18 { + t.Errorf("got total_tokens %d", resp.Usage.TotalTokens) + } +} + +func TestAppendConversation_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/conversations/conv-123" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "object": "conversation.response", "conversation_id": "conv-123", + "outputs": []map[string]any{{ + "object": "entry", "id": "e2", "type": "message.output", + "created_at": "2024-01-01T00:00:00Z", + "role": "assistant", "content": "The weather is sunny.", + }}, + "usage": map[string]any{"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.AppendConversation(context.Background(), "conv-123", &conversation.AppendRequest{ + Inputs: conversation.TextInputs("What's the weather?"), + }) + if err != nil { + t.Fatal(err) + } + if resp.ConversationID != "conv-123" { + t.Errorf("got %q", resp.ConversationID) + } +} + +func TestGetConversation_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/conversations/conv-123" { + t.Errorf("got path %s", r.URL.Path) + } + if r.Method != "GET" { + t.Errorf("got method %s", r.Method) + } + json.NewEncoder(w).Encode(map[string]any{ + "object": "conversation", "id": "conv-123", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:01:00Z", + "model": "mistral-small-latest", + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + conv, err := client.GetConversation(context.Background(), "conv-123") + if err != nil { + t.Fatal(err) + } + if conv.ID != "conv-123" { + t.Errorf("got id %q", conv.ID) + } + if conv.Model != "mistral-small-latest" { + t.Errorf("got model %q", conv.Model) + } +} + +func TestDeleteConversation_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "DELETE" { + t.Errorf("got method %s", r.Method) + } + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + err := client.DeleteConversation(context.Background(), "conv-123") + if err != nil { + t.Fatal(err) + } +} + +func TestGetConversationHistory_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/conversations/conv-123/history" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "object": "conversation.history", "conversation_id": "conv-123", + "entries": []map[string]any{ + {"object": "entry", "id": "e1", "type": "message.input", "created_at": "2024-01-01T00:00:00Z", "role": "user", "content": "Hi"}, + {"object": "entry", "id": "e2", "type": "message.output", "created_at": "2024-01-01T00:00:01Z", "role": "assistant", "content": "Hello!"}, + }, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + history, err := client.GetConversationHistory(context.Background(), "conv-123") + if err != nil { + t.Fatal(err) + } + if len(history.Entries) != 2 { + t.Fatalf("got %d entries", len(history.Entries)) + } + if _, ok := history.Entries[0].(*conversation.MessageInputEntry); !ok { + t.Errorf("expected *MessageInputEntry, got %T", history.Entries[0]) + } + if _, ok := history.Entries[1].(*conversation.MessageOutputEntry); !ok { + t.Errorf("expected *MessageOutputEntry, got %T", history.Entries[1]) + } +} + +func TestStartConversationStream_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["stream"] != true { + t.Errorf("expected stream=true") + } + + w.Header().Set("Content-Type", "text/event-stream") + flusher, _ := w.(http.Flusher) + + events := []map[string]any{ + {"type": "conversation.response.started", "created_at": "2024-01-01T00:00:00Z", "conversation_id": "conv-456"}, + {"type": "message.output.delta", "created_at": "2024-01-01T00:00:00Z", "output_index": 0, "id": "m1", "content_index": 0, "content": "Hello", "role": "assistant"}, + {"type": "message.output.delta", "created_at": "2024-01-01T00:00:00Z", "output_index": 0, "id": "m1", "content_index": 0, "content": " world!", "role": "assistant"}, + {"type": "conversation.response.done", "created_at": "2024-01-01T00:00:01Z", "usage": map[string]any{"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}}, + } + for _, ev := range events { + data, _ := json.Marshal(ev) + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + } + fmt.Fprint(w, "data: [DONE]\n\n") + flusher.Flush() + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + stream, err := client.StartConversationStream(context.Background(), &conversation.StartRequest{ + Inputs: conversation.TextInputs("Hi"), + Model: "mistral-small-latest", + }) + if err != nil { + t.Fatal(err) + } + defer stream.Close() + + var events []conversation.Event + for stream.Next() { + events = append(events, stream.Current()) + } + if stream.Err() != nil { + t.Fatal(stream.Err()) + } + if len(events) != 4 { + t.Fatalf("got %d events, want 4", len(events)) + } + + started, ok := events[0].(*conversation.ResponseStartedEvent) + if !ok { + t.Fatalf("expected *ResponseStartedEvent, got %T", events[0]) + } + if started.ConversationID != "conv-456" { + t.Errorf("got conv id %q", started.ConversationID) + } + + msg, ok := events[1].(*conversation.MessageOutputEvent) + if !ok { + t.Fatalf("expected *MessageOutputEvent, got %T", events[1]) + } + if conversation.TextContent(msg.Content) != "Hello" { + t.Errorf("got %q", conversation.TextContent(msg.Content)) + } + + done, ok := events[3].(*conversation.ResponseDoneEvent) + if !ok { + t.Fatalf("expected *ResponseDoneEvent, got %T", events[3]) + } + if done.Usage.TotalTokens != 7 { + t.Errorf("got total_tokens %d", done.Usage.TotalTokens) + } +}