feat: Phase 6 conversations — start, append, restart, stream, history

Conversations API with full CRUD: start, append, restart (+ stream
variants), get, list, delete, history, messages. Discriminated unions
for entries (6 types) and streaming events (10 types). EventStream
wraps Stream[json.RawMessage] with typed event dispatch.
This commit is contained in:
2026-03-05 19:53:41 +01:00
parent 9778dd6a8e
commit 8c655893fc
8 changed files with 1318 additions and 0 deletions

View File

@@ -0,0 +1,182 @@
package conversation
import (
"encoding/json"
"fmt"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
)
// HandoffExecution controls tool call execution.
type HandoffExecution string
const (
HandoffClient HandoffExecution = "client"
HandoffServer HandoffExecution = "server"
)
// Tool represents a conversation tool.
type Tool struct {
Type string `json:"type"`
Function *chat.Function `json:"function,omitempty"`
LibraryIDs []string `json:"library_ids,omitempty"`
ToolConfiguration *ToolConfig `json:"tool_configuration,omitempty"`
}
// ToolConfig configures tool behavior.
type ToolConfig struct {
Exclude []string `json:"exclude,omitempty"`
Include []string `json:"include,omitempty"`
RequiresConfirmation []string `json:"requires_confirmation,omitempty"`
}
// CompletionArgs holds optional completion parameters.
type CompletionArgs struct {
Stop []string `json:"stop,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
RandomSeed *int `json:"random_seed,omitempty"`
Prediction *chat.Prediction `json:"prediction,omitempty"`
ResponseFormat *chat.ResponseFormat `json:"response_format,omitempty"`
ToolChoice *chat.ToolChoiceMode `json:"tool_choice,omitempty"`
}
// ToolCallConfirmation confirms or denies a pending tool call.
type ToolCallConfirmation struct {
ToolCallID string `json:"tool_call_id"`
Confirmation string `json:"confirmation"` // "allow" or "deny"
}
// Inputs represents conversation inputs (text string or entry array).
type Inputs struct {
text *string
entries []Entry
}
// TextInputs creates Inputs from a plain text string.
func TextInputs(s string) Inputs { return Inputs{text: &s} }
// EntryInputs creates Inputs from entry objects.
func EntryInputs(entries ...Entry) Inputs { return Inputs{entries: entries} }
func (i Inputs) MarshalJSON() ([]byte, error) {
if i.text != nil {
return json.Marshal(*i.text)
}
return json.Marshal(i.entries)
}
// UsageInfo contains conversation token usage.
type UsageInfo struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
ConnectorTokens *int `json:"connector_tokens,omitempty"`
Connectors map[string]int `json:"connectors,omitempty"`
}
// Response is the response from starting, appending, or restarting a conversation.
type Response struct {
Object string `json:"object"`
ConversationID string `json:"conversation_id"`
Outputs []Entry `json:"-"`
Usage UsageInfo `json:"usage"`
Guardrails json.RawMessage `json:"guardrails,omitempty"`
}
func (r *Response) UnmarshalJSON(data []byte) error {
type alias Response
var raw struct {
*alias
Outputs []json.RawMessage `json:"outputs"`
}
raw.alias = (*alias)(r)
if err := json.Unmarshal(data, &raw); err != nil {
return fmt.Errorf("mistral: unmarshal conversation response: %w", err)
}
r.Outputs = make([]Entry, len(raw.Outputs))
for i, o := range raw.Outputs {
entry, err := UnmarshalEntry(o)
if err != nil {
return err
}
r.Outputs[i] = entry
}
return nil
}
// Conversation represents conversation metadata.
type Conversation struct {
Object string `json:"object"`
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
Model string `json:"model,omitempty"`
AgentID string `json:"agent_id,omitempty"`
AgentVersion json.RawMessage `json:"agent_version,omitempty"`
Name *string `json:"name,omitempty"`
Description *string `json:"description,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
Instructions *string `json:"instructions,omitempty"`
Tools []Tool `json:"tools,omitempty"`
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
}
// History is the response from getting conversation history.
type History struct {
Object string `json:"object"`
ConversationID string `json:"conversation_id"`
Entries []Entry `json:"-"`
}
func (h *History) UnmarshalJSON(data []byte) error {
type alias History
var raw struct {
*alias
Entries []json.RawMessage `json:"entries"`
}
raw.alias = (*alias)(h)
if err := json.Unmarshal(data, &raw); err != nil {
return fmt.Errorf("mistral: unmarshal conversation history: %w", err)
}
h.Entries = make([]Entry, len(raw.Entries))
for i, e := range raw.Entries {
entry, err := UnmarshalEntry(e)
if err != nil {
return err
}
h.Entries[i] = entry
}
return nil
}
// Messages is the response from getting conversation messages.
type Messages struct {
Object string `json:"object"`
ConversationID string `json:"conversation_id"`
Messages []Entry `json:"-"`
}
func (m *Messages) UnmarshalJSON(data []byte) error {
type alias Messages
var raw struct {
*alias
Messages []json.RawMessage `json:"messages"`
}
raw.alias = (*alias)(m)
if err := json.Unmarshal(data, &raw); err != nil {
return fmt.Errorf("mistral: unmarshal conversation messages: %w", err)
}
m.Messages = make([]Entry, len(raw.Messages))
for i, msg := range raw.Messages {
entry, err := UnmarshalEntry(msg)
if err != nil {
return err
}
m.Messages[i] = entry
}
return nil
}

161
conversation/entry.go Normal file
View File

@@ -0,0 +1,161 @@
package conversation
import (
"encoding/json"
"fmt"
"strings"
)
// Entry is a sealed interface for conversation history entries.
type Entry interface {
entryType() string
}
// MessageInputEntry represents a user or assistant input message.
type MessageInputEntry struct {
Object string `json:"object"`
ID string `json:"id"`
Type string `json:"type"`
CreatedAt string `json:"created_at"`
CompletedAt *string `json:"completed_at,omitempty"`
Role string `json:"role"`
Content json.RawMessage `json:"content"`
Prefix bool `json:"prefix,omitempty"`
}
func (*MessageInputEntry) entryType() string { return "message.input" }
// MessageOutputEntry represents an assistant output message.
type MessageOutputEntry struct {
Object string `json:"object"`
ID string `json:"id"`
Type string `json:"type"`
CreatedAt string `json:"created_at"`
CompletedAt *string `json:"completed_at,omitempty"`
Role string `json:"role"`
Content json.RawMessage `json:"content"`
AgentID *string `json:"agent_id,omitempty"`
Model *string `json:"model,omitempty"`
}
func (*MessageOutputEntry) entryType() string { return "message.output" }
// FunctionCallEntry represents a function call by the model.
type FunctionCallEntry struct {
Object string `json:"object"`
ID string `json:"id"`
Type string `json:"type"`
CreatedAt string `json:"created_at"`
CompletedAt *string `json:"completed_at,omitempty"`
ToolCallID string `json:"tool_call_id"`
Name string `json:"name"`
Arguments json.RawMessage `json:"arguments"`
ConfirmationStatus *string `json:"confirmation_status,omitempty"`
AgentID *string `json:"agent_id,omitempty"`
Model *string `json:"model,omitempty"`
}
func (*FunctionCallEntry) entryType() string { return "function.call" }
// FunctionResultEntry represents a function result provided by the client.
type FunctionResultEntry struct {
Object string `json:"object"`
ID string `json:"id"`
Type string `json:"type"`
CreatedAt string `json:"created_at"`
CompletedAt *string `json:"completed_at,omitempty"`
ToolCallID string `json:"tool_call_id"`
Result string `json:"result"`
}
func (*FunctionResultEntry) entryType() string { return "function.result" }
// ToolExecutionEntry represents a built-in tool execution.
type ToolExecutionEntry struct {
Object string `json:"object"`
ID string `json:"id"`
Type string `json:"type"`
CreatedAt string `json:"created_at"`
CompletedAt *string `json:"completed_at,omitempty"`
Name string `json:"name"`
Arguments string `json:"arguments"`
Info map[string]any `json:"info,omitempty"`
AgentID *string `json:"agent_id,omitempty"`
Model *string `json:"model,omitempty"`
}
func (*ToolExecutionEntry) entryType() string { return "tool.execution" }
// AgentHandoffEntry represents an agent-to-agent handoff.
type AgentHandoffEntry struct {
Object string `json:"object"`
ID string `json:"id"`
Type string `json:"type"`
CreatedAt string `json:"created_at"`
CompletedAt *string `json:"completed_at,omitempty"`
PreviousAgentID string `json:"previous_agent_id"`
PreviousAgentName string `json:"previous_agent_name"`
NextAgentID string `json:"next_agent_id"`
NextAgentName string `json:"next_agent_name"`
}
func (*AgentHandoffEntry) entryType() string { return "agent.handoff" }
// UnmarshalEntry dispatches JSON to the concrete Entry type
// based on the "type" discriminator field.
func UnmarshalEntry(data []byte) (Entry, error) {
var probe struct {
Type string `json:"type"`
}
if err := json.Unmarshal(data, &probe); err != nil {
return nil, fmt.Errorf("mistral: unmarshal entry: %w", err)
}
switch probe.Type {
case "message.input":
var e MessageInputEntry
return &e, json.Unmarshal(data, &e)
case "message.output":
var e MessageOutputEntry
return &e, json.Unmarshal(data, &e)
case "function.call":
var e FunctionCallEntry
return &e, json.Unmarshal(data, &e)
case "function.result":
var e FunctionResultEntry
return &e, json.Unmarshal(data, &e)
case "tool.execution":
var e ToolExecutionEntry
return &e, json.Unmarshal(data, &e)
case "agent.handoff":
var e AgentHandoffEntry
return &e, json.Unmarshal(data, &e)
default:
return nil, fmt.Errorf("mistral: unknown entry type: %q", probe.Type)
}
}
// TextContent extracts text from a raw content field.
// Handles both string content and chunk arrays (extracts text chunks).
func TextContent(raw json.RawMessage) string {
if len(raw) == 0 {
return ""
}
var s string
if json.Unmarshal(raw, &s) == nil {
return s
}
var chunks []struct {
Type string `json:"type"`
Text string `json:"text"`
}
if json.Unmarshal(raw, &chunks) == nil {
var sb strings.Builder
for _, ch := range chunks {
if ch.Type == "text" {
sb.WriteString(ch.Text)
}
}
return sb.String()
}
return ""
}

145
conversation/entry_test.go Normal file
View File

@@ -0,0 +1,145 @@
package conversation
import (
"encoding/json"
"testing"
)
func TestUnmarshalEntry_MessageInput(t *testing.T) {
data := []byte(`{"object":"entry","id":"e1","type":"message.input","created_at":"2024-01-01T00:00:00Z","role":"user","content":"Hello"}`)
entry, err := UnmarshalEntry(data)
if err != nil {
t.Fatal(err)
}
e, ok := entry.(*MessageInputEntry)
if !ok {
t.Fatalf("expected *MessageInputEntry, got %T", entry)
}
if e.ID != "e1" {
t.Errorf("got id %q", e.ID)
}
if e.Role != "user" {
t.Errorf("got role %q", e.Role)
}
if TextContent(e.Content) != "Hello" {
t.Errorf("got content %q", TextContent(e.Content))
}
}
func TestUnmarshalEntry_MessageOutput(t *testing.T) {
data := []byte(`{"object":"entry","id":"e2","type":"message.output","created_at":"2024-01-01T00:00:00Z","role":"assistant","content":"Hi there!"}`)
entry, err := UnmarshalEntry(data)
if err != nil {
t.Fatal(err)
}
e, ok := entry.(*MessageOutputEntry)
if !ok {
t.Fatalf("expected *MessageOutputEntry, got %T", entry)
}
if TextContent(e.Content) != "Hi there!" {
t.Errorf("got content %q", TextContent(e.Content))
}
}
func TestUnmarshalEntry_FunctionCall(t *testing.T) {
data := []byte(`{"object":"entry","id":"e3","type":"function.call","created_at":"2024-01-01T00:00:00Z","tool_call_id":"tc1","name":"get_weather","arguments":"{\"city\":\"Paris\"}"}`)
entry, err := UnmarshalEntry(data)
if err != nil {
t.Fatal(err)
}
e, ok := entry.(*FunctionCallEntry)
if !ok {
t.Fatalf("expected *FunctionCallEntry, got %T", entry)
}
if e.Name != "get_weather" {
t.Errorf("got name %q", e.Name)
}
if e.ToolCallID != "tc1" {
t.Errorf("got tool_call_id %q", e.ToolCallID)
}
}
func TestUnmarshalEntry_FunctionResult(t *testing.T) {
data := []byte(`{"object":"entry","id":"e4","type":"function.result","created_at":"2024-01-01T00:00:00Z","tool_call_id":"tc1","result":"{\"temp\":22}"}`)
entry, err := UnmarshalEntry(data)
if err != nil {
t.Fatal(err)
}
e, ok := entry.(*FunctionResultEntry)
if !ok {
t.Fatalf("expected *FunctionResultEntry, got %T", entry)
}
if e.Result != `{"temp":22}` {
t.Errorf("got result %q", e.Result)
}
}
func TestUnmarshalEntry_ToolExecution(t *testing.T) {
data := []byte(`{"object":"entry","id":"e5","type":"tool.execution","created_at":"2024-01-01T00:00:00Z","name":"web_search","arguments":"query"}`)
entry, err := UnmarshalEntry(data)
if err != nil {
t.Fatal(err)
}
e, ok := entry.(*ToolExecutionEntry)
if !ok {
t.Fatalf("expected *ToolExecutionEntry, got %T", entry)
}
if e.Name != "web_search" {
t.Errorf("got name %q", e.Name)
}
}
func TestUnmarshalEntry_AgentHandoff(t *testing.T) {
data := []byte(`{"object":"entry","id":"e6","type":"agent.handoff","created_at":"2024-01-01T00:00:00Z","previous_agent_id":"a1","previous_agent_name":"Agent A","next_agent_id":"a2","next_agent_name":"Agent B"}`)
entry, err := UnmarshalEntry(data)
if err != nil {
t.Fatal(err)
}
e, ok := entry.(*AgentHandoffEntry)
if !ok {
t.Fatalf("expected *AgentHandoffEntry, got %T", entry)
}
if e.PreviousAgentName != "Agent A" {
t.Errorf("got prev %q", e.PreviousAgentName)
}
if e.NextAgentName != "Agent B" {
t.Errorf("got next %q", e.NextAgentName)
}
}
func TestUnmarshalEntry_Unknown(t *testing.T) {
_, err := UnmarshalEntry([]byte(`{"type":"unknown.type"}`))
if err == nil {
t.Error("expected error for unknown type")
}
}
func TestTextContent_String(t *testing.T) {
raw := json.RawMessage(`"Hello world"`)
if TextContent(raw) != "Hello world" {
t.Errorf("got %q", TextContent(raw))
}
}
func TestTextContent_ChunkArray(t *testing.T) {
raw := json.RawMessage(`[{"type":"text","text":"Hello "},{"type":"text","text":"world"}]`)
if TextContent(raw) != "Hello world" {
t.Errorf("got %q", TextContent(raw))
}
}
func TestTextContent_Empty(t *testing.T) {
if TextContent(nil) != "" {
t.Error("expected empty for nil")
}
if TextContent(json.RawMessage{}) != "" {
t.Error("expected empty for empty")
}
}
func TestTextContent_MixedChunks(t *testing.T) {
raw := json.RawMessage(`[{"type":"text","text":"Hello"},{"type":"tool_reference","tool":"web_search","title":"Result"},{"type":"text","text":" world"}]`)
if TextContent(raw) != "Hello world" {
t.Errorf("got %q", TextContent(raw))
}
}

177
conversation/event.go Normal file
View File

@@ -0,0 +1,177 @@
package conversation
import (
"encoding/json"
"fmt"
)
// Event is a sealed interface for conversation streaming events.
type Event interface {
eventType() string
}
// ResponseStartedEvent signals the start of a conversation response.
type ResponseStartedEvent struct {
Type string `json:"type"`
CreatedAt string `json:"created_at"`
ConversationID string `json:"conversation_id"`
}
func (*ResponseStartedEvent) eventType() string { return "conversation.response.started" }
// ResponseDoneEvent signals the completion of a conversation response.
type ResponseDoneEvent struct {
Type string `json:"type"`
CreatedAt string `json:"created_at"`
Usage UsageInfo `json:"usage"`
}
func (*ResponseDoneEvent) eventType() string { return "conversation.response.done" }
// ResponseErrorEvent signals an error during conversation processing.
type ResponseErrorEvent struct {
Type string `json:"type"`
CreatedAt string `json:"created_at"`
Message string `json:"message"`
Code int `json:"code"`
}
func (*ResponseErrorEvent) eventType() string { return "conversation.response.error" }
// MessageOutputEvent contains a delta of assistant message output.
type MessageOutputEvent struct {
Type string `json:"type"`
CreatedAt string `json:"created_at"`
OutputIndex int `json:"output_index"`
ID string `json:"id"`
ContentIndex int `json:"content_index"`
Content json.RawMessage `json:"content"`
Model *string `json:"model,omitempty"`
AgentID *string `json:"agent_id,omitempty"`
Role string `json:"role"`
}
func (*MessageOutputEvent) eventType() string { return "message.output.delta" }
// ToolExecutionStartedEvent signals the start of a tool execution.
type ToolExecutionStartedEvent struct {
Type string `json:"type"`
CreatedAt string `json:"created_at"`
OutputIndex int `json:"output_index"`
ID string `json:"id"`
Name string `json:"name"`
Arguments string `json:"arguments"`
Model *string `json:"model,omitempty"`
AgentID *string `json:"agent_id,omitempty"`
}
func (*ToolExecutionStartedEvent) eventType() string { return "tool.execution.started" }
// ToolExecutionDeltaEvent contains a delta of tool execution output.
type ToolExecutionDeltaEvent struct {
Type string `json:"type"`
CreatedAt string `json:"created_at"`
OutputIndex int `json:"output_index"`
ID string `json:"id"`
Name string `json:"name"`
Arguments string `json:"arguments"`
}
func (*ToolExecutionDeltaEvent) eventType() string { return "tool.execution.delta" }
// ToolExecutionDoneEvent signals the completion of a tool execution.
type ToolExecutionDoneEvent struct {
Type string `json:"type"`
CreatedAt string `json:"created_at"`
OutputIndex int `json:"output_index"`
ID string `json:"id"`
Name string `json:"name"`
Info map[string]any `json:"info,omitempty"`
}
func (*ToolExecutionDoneEvent) eventType() string { return "tool.execution.done" }
// FunctionCallEvent contains a delta of a function call.
type FunctionCallEvent struct {
Type string `json:"type"`
CreatedAt string `json:"created_at"`
OutputIndex int `json:"output_index"`
ID string `json:"id"`
Name string `json:"name"`
ToolCallID string `json:"tool_call_id"`
Arguments string `json:"arguments"`
ConfirmationStatus *string `json:"confirmation_status,omitempty"`
Model *string `json:"model,omitempty"`
AgentID *string `json:"agent_id,omitempty"`
}
func (*FunctionCallEvent) eventType() string { return "function.call.delta" }
// AgentHandoffStartedEvent signals the start of an agent handoff.
type AgentHandoffStartedEvent struct {
Type string `json:"type"`
CreatedAt string `json:"created_at"`
OutputIndex int `json:"output_index"`
ID string `json:"id"`
PreviousAgentID string `json:"previous_agent_id"`
PreviousAgentName string `json:"previous_agent_name"`
}
func (*AgentHandoffStartedEvent) eventType() string { return "agent.handoff.started" }
// AgentHandoffDoneEvent signals the completion of an agent handoff.
type AgentHandoffDoneEvent struct {
Type string `json:"type"`
CreatedAt string `json:"created_at"`
OutputIndex int `json:"output_index"`
ID string `json:"id"`
NextAgentID string `json:"next_agent_id"`
NextAgentName string `json:"next_agent_name"`
}
func (*AgentHandoffDoneEvent) eventType() string { return "agent.handoff.done" }
// UnmarshalEvent dispatches JSON to the concrete Event type
// based on the "type" discriminator field.
func UnmarshalEvent(data []byte) (Event, error) {
var probe struct {
Type string `json:"type"`
}
if err := json.Unmarshal(data, &probe); err != nil {
return nil, fmt.Errorf("mistral: unmarshal event: %w", err)
}
switch probe.Type {
case "conversation.response.started":
var e ResponseStartedEvent
return &e, json.Unmarshal(data, &e)
case "conversation.response.done":
var e ResponseDoneEvent
return &e, json.Unmarshal(data, &e)
case "conversation.response.error":
var e ResponseErrorEvent
return &e, json.Unmarshal(data, &e)
case "message.output.delta":
var e MessageOutputEvent
return &e, json.Unmarshal(data, &e)
case "tool.execution.started":
var e ToolExecutionStartedEvent
return &e, json.Unmarshal(data, &e)
case "tool.execution.delta":
var e ToolExecutionDeltaEvent
return &e, json.Unmarshal(data, &e)
case "tool.execution.done":
var e ToolExecutionDoneEvent
return &e, json.Unmarshal(data, &e)
case "function.call.delta":
var e FunctionCallEvent
return &e, json.Unmarshal(data, &e)
case "agent.handoff.started":
var e AgentHandoffStartedEvent
return &e, json.Unmarshal(data, &e)
case "agent.handoff.done":
var e AgentHandoffDoneEvent
return &e, json.Unmarshal(data, &e)
default:
return nil, fmt.Errorf("mistral: unknown event type: %q", probe.Type)
}
}

163
conversation/event_test.go Normal file
View File

@@ -0,0 +1,163 @@
package conversation
import (
"encoding/json"
"testing"
)
func TestUnmarshalEvent_ResponseStarted(t *testing.T) {
data := []byte(`{"type":"conversation.response.started","created_at":"2024-01-01T00:00:00Z","conversation_id":"conv-123"}`)
event, err := UnmarshalEvent(data)
if err != nil {
t.Fatal(err)
}
e, ok := event.(*ResponseStartedEvent)
if !ok {
t.Fatalf("expected *ResponseStartedEvent, got %T", event)
}
if e.ConversationID != "conv-123" {
t.Errorf("got %q", e.ConversationID)
}
}
func TestUnmarshalEvent_ResponseDone(t *testing.T) {
data := []byte(`{"type":"conversation.response.done","created_at":"2024-01-01T00:00:00Z","usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}`)
event, err := UnmarshalEvent(data)
if err != nil {
t.Fatal(err)
}
e, ok := event.(*ResponseDoneEvent)
if !ok {
t.Fatalf("expected *ResponseDoneEvent, got %T", event)
}
if e.Usage.TotalTokens != 15 {
t.Errorf("got total_tokens %d", e.Usage.TotalTokens)
}
}
func TestUnmarshalEvent_ResponseError(t *testing.T) {
data := []byte(`{"type":"conversation.response.error","created_at":"2024-01-01T00:00:00Z","message":"error occurred","code":500}`)
event, err := UnmarshalEvent(data)
if err != nil {
t.Fatal(err)
}
e, ok := event.(*ResponseErrorEvent)
if !ok {
t.Fatalf("expected *ResponseErrorEvent, got %T", event)
}
if e.Message != "error occurred" {
t.Errorf("got %q", e.Message)
}
if e.Code != 500 {
t.Errorf("got code %d", e.Code)
}
}
func TestUnmarshalEvent_MessageOutput(t *testing.T) {
data := []byte(`{"type":"message.output.delta","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"m1","content_index":0,"content":"Hello","role":"assistant"}`)
event, err := UnmarshalEvent(data)
if err != nil {
t.Fatal(err)
}
e, ok := event.(*MessageOutputEvent)
if !ok {
t.Fatalf("expected *MessageOutputEvent, got %T", event)
}
if e.ID != "m1" {
t.Errorf("got id %q", e.ID)
}
if TextContent(e.Content) != "Hello" {
t.Errorf("got content %q", TextContent(e.Content))
}
}
func TestUnmarshalEvent_FunctionCall(t *testing.T) {
data := []byte(`{"type":"function.call.delta","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"fc1","name":"search","tool_call_id":"tc1","arguments":"{\"q\":\"test\"}"}`)
event, err := UnmarshalEvent(data)
if err != nil {
t.Fatal(err)
}
e, ok := event.(*FunctionCallEvent)
if !ok {
t.Fatalf("expected *FunctionCallEvent, got %T", event)
}
if e.Name != "search" {
t.Errorf("got name %q", e.Name)
}
if e.ToolCallID != "tc1" {
t.Errorf("got tool_call_id %q", e.ToolCallID)
}
}
func TestUnmarshalEvent_ToolExecutionStarted(t *testing.T) {
data := []byte(`{"type":"tool.execution.started","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"te1","name":"web_search","arguments":"query"}`)
event, err := UnmarshalEvent(data)
if err != nil {
t.Fatal(err)
}
e, ok := event.(*ToolExecutionStartedEvent)
if !ok {
t.Fatalf("expected *ToolExecutionStartedEvent, got %T", event)
}
if e.Name != "web_search" {
t.Errorf("got %q", e.Name)
}
}
func TestUnmarshalEvent_ToolExecutionDone(t *testing.T) {
data := []byte(`{"type":"tool.execution.done","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"te1","name":"web_search"}`)
event, err := UnmarshalEvent(data)
if err != nil {
t.Fatal(err)
}
_, ok := event.(*ToolExecutionDoneEvent)
if !ok {
t.Fatalf("expected *ToolExecutionDoneEvent, got %T", event)
}
}
func TestUnmarshalEvent_AgentHandoff(t *testing.T) {
started := []byte(`{"type":"agent.handoff.started","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"h1","previous_agent_id":"a1","previous_agent_name":"A"}`)
event, err := UnmarshalEvent(started)
if err != nil {
t.Fatal(err)
}
hs, ok := event.(*AgentHandoffStartedEvent)
if !ok {
t.Fatalf("expected *AgentHandoffStartedEvent, got %T", event)
}
if hs.PreviousAgentID != "a1" {
t.Errorf("got %q", hs.PreviousAgentID)
}
done := []byte(`{"type":"agent.handoff.done","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"h1","next_agent_id":"a2","next_agent_name":"B"}`)
event, err = UnmarshalEvent(done)
if err != nil {
t.Fatal(err)
}
hd, ok := event.(*AgentHandoffDoneEvent)
if !ok {
t.Fatalf("expected *AgentHandoffDoneEvent, got %T", event)
}
if hd.NextAgentID != "a2" {
t.Errorf("got %q", hd.NextAgentID)
}
}
func TestUnmarshalEvent_Unknown(t *testing.T) {
_, err := UnmarshalEvent([]byte(`{"type":"unknown.event"}`))
if err == nil {
t.Error("expected error for unknown type")
}
}
func TestInputs_TextMarshal(t *testing.T) {
inputs := TextInputs("Hello")
data, err := json.Marshal(inputs)
if err != nil {
t.Fatal(err)
}
if string(data) != `"Hello"` {
t.Errorf("got %s", data)
}
}

81
conversation/request.go Normal file
View File

@@ -0,0 +1,81 @@
package conversation
import "encoding/json"
// StartRequest starts a new conversation.
type StartRequest struct {
Inputs Inputs `json:"inputs"`
Model string `json:"model,omitempty"`
AgentID string `json:"agent_id,omitempty"`
AgentVersion json.RawMessage `json:"agent_version,omitempty"`
Instructions *string `json:"instructions,omitempty"`
Tools []Tool `json:"tools,omitempty"`
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
Store *bool `json:"store,omitempty"`
HandoffExecution *HandoffExecution `json:"handoff_execution,omitempty"`
Name *string `json:"name,omitempty"`
Description *string `json:"description,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
stream bool
}
func (r *StartRequest) SetStream(v bool) { r.stream = v }
func (r *StartRequest) MarshalJSON() ([]byte, error) {
type Alias StartRequest
return json.Marshal(&struct {
Stream bool `json:"stream"`
*Alias
}{
Stream: r.stream,
Alias: (*Alias)(r),
})
}
// AppendRequest appends to an existing conversation.
type AppendRequest struct {
Inputs Inputs `json:"inputs"`
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
Store *bool `json:"store,omitempty"`
HandoffExecution *HandoffExecution `json:"handoff_execution,omitempty"`
ToolConfirmations []ToolCallConfirmation `json:"tool_confirmations,omitempty"`
stream bool
}
func (r *AppendRequest) SetStream(v bool) { r.stream = v }
func (r *AppendRequest) MarshalJSON() ([]byte, error) {
type Alias AppendRequest
return json.Marshal(&struct {
Stream bool `json:"stream"`
*Alias
}{
Stream: r.stream,
Alias: (*Alias)(r),
})
}
// RestartRequest restarts a conversation from a specific entry.
type RestartRequest struct {
Inputs Inputs `json:"inputs"`
FromEntryID string `json:"from_entry_id"`
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
Store *bool `json:"store,omitempty"`
HandoffExecution *HandoffExecution `json:"handoff_execution,omitempty"`
AgentVersion json.RawMessage `json:"agent_version,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
stream bool
}
func (r *RestartRequest) SetStream(v bool) { r.stream = v }
func (r *RestartRequest) MarshalJSON() ([]byte, error) {
type Alias RestartRequest
return json.Marshal(&struct {
Stream bool `json:"stream"`
*Alias
}{
Stream: r.stream,
Alias: (*Alias)(r),
})
}

160
conversations.go Normal file
View File

@@ -0,0 +1,160 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"somegit.dev/vikingowl/mistral-go-sdk/conversation"
)
// StartConversation creates and starts a new conversation.
func (c *Client) StartConversation(ctx context.Context, req *conversation.StartRequest) (*conversation.Response, error) {
var resp conversation.Response
if err := c.doJSON(ctx, "POST", "/v1/conversations", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// StartConversationStream creates a conversation and returns a stream of events.
func (c *Client) StartConversationStream(ctx context.Context, req *conversation.StartRequest) (*EventStream, error) {
req.SetStream(true)
resp, err := c.doStream(ctx, "POST", "/v1/conversations", req)
if err != nil {
return nil, err
}
return newEventStream(resp.Body), nil
}
// AppendConversation appends inputs to an existing conversation.
func (c *Client) AppendConversation(ctx context.Context, conversationID string, req *conversation.AppendRequest) (*conversation.Response, error) {
var resp conversation.Response
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/conversations/%s", conversationID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// AppendConversationStream appends to a conversation and returns a stream of events.
func (c *Client) AppendConversationStream(ctx context.Context, conversationID string, req *conversation.AppendRequest) (*EventStream, error) {
req.SetStream(true)
resp, err := c.doStream(ctx, "POST", fmt.Sprintf("/v1/conversations/%s", conversationID), req)
if err != nil {
return nil, err
}
return newEventStream(resp.Body), nil
}
// RestartConversation restarts a conversation from a specific entry.
func (c *Client) RestartConversation(ctx context.Context, conversationID string, req *conversation.RestartRequest) (*conversation.Response, error) {
var resp conversation.Response
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/conversations/%s/restart", conversationID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// RestartConversationStream restarts a conversation and returns a stream of events.
func (c *Client) RestartConversationStream(ctx context.Context, conversationID string, req *conversation.RestartRequest) (*EventStream, error) {
req.SetStream(true)
resp, err := c.doStream(ctx, "POST", fmt.Sprintf("/v1/conversations/%s/restart", conversationID), req)
if err != nil {
return nil, err
}
return newEventStream(resp.Body), nil
}
// GetConversation retrieves conversation metadata.
func (c *Client) GetConversation(ctx context.Context, conversationID string) (*conversation.Conversation, error) {
var resp conversation.Conversation
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/conversations/%s", conversationID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListConversations lists conversations with optional pagination.
func (c *Client) ListConversations(ctx context.Context) ([]conversation.Conversation, error) {
var resp []conversation.Conversation
if err := c.doJSON(ctx, "GET", "/v1/conversations", nil, &resp); err != nil {
return nil, err
}
return resp, nil
}
// DeleteConversation deletes a conversation.
func (c *Client) DeleteConversation(ctx context.Context, conversationID string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/conversations/%s", conversationID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// GetConversationHistory returns the full history of a conversation.
func (c *Client) GetConversationHistory(ctx context.Context, conversationID string) (*conversation.History, error) {
var resp conversation.History
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/conversations/%s/history", conversationID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetConversationMessages returns the messages of a conversation.
func (c *Client) GetConversationMessages(ctx context.Context, conversationID string) (*conversation.Messages, error) {
var resp conversation.Messages
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/conversations/%s/messages", conversationID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// EventStream wraps the generic Stream to provide typed conversation events.
type EventStream struct {
stream *Stream[json.RawMessage]
event conversation.Event
err error
}
func newEventStream(body readCloser) *EventStream {
return &EventStream{
stream: newStream[json.RawMessage](body),
}
}
// Next advances to the next event. Returns false when done or on error.
func (s *EventStream) Next() bool {
if s.err != nil {
return false
}
if !s.stream.Next() {
s.err = s.stream.Err()
return false
}
event, err := conversation.UnmarshalEvent(s.stream.Current())
if err != nil {
s.err = err
return false
}
s.event = event
return true
}
// Current returns the most recently read event.
func (s *EventStream) Current() conversation.Event { return s.event }
// Err returns any error encountered during streaming.
func (s *EventStream) Err() error { return s.err }
// Close releases the underlying connection.
func (s *EventStream) Close() error { return s.stream.Close() }
type readCloser = interface {
Read(p []byte) (n int, err error)
Close() error
}

249
conversations_test.go Normal file
View File

@@ -0,0 +1,249 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/conversation"
)
func TestStartConversation_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/conversations" {
t.Errorf("got path %s", r.URL.Path)
}
if r.Method != "POST" {
t.Errorf("got method %s", r.Method)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["inputs"] != "Hello" {
t.Errorf("expected inputs=Hello, got %v", body["inputs"])
}
if body["model"] != "mistral-small-latest" {
t.Errorf("expected model, got %v", body["model"])
}
json.NewEncoder(w).Encode(map[string]any{
"object": "conversation.response",
"conversation_id": "conv-123",
"outputs": []map[string]any{{
"object": "entry", "id": "e1", "type": "message.output",
"created_at": "2024-01-01T00:00:00Z",
"role": "assistant", "content": "Hello! How can I help?",
}},
"usage": map[string]any{
"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18,
},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.StartConversation(context.Background(), &conversation.StartRequest{
Inputs: conversation.TextInputs("Hello"),
Model: "mistral-small-latest",
})
if err != nil {
t.Fatal(err)
}
if resp.ConversationID != "conv-123" {
t.Errorf("got conv id %q", resp.ConversationID)
}
if len(resp.Outputs) != 1 {
t.Fatalf("got %d outputs", len(resp.Outputs))
}
out, ok := resp.Outputs[0].(*conversation.MessageOutputEntry)
if !ok {
t.Fatalf("expected *MessageOutputEntry, got %T", resp.Outputs[0])
}
if conversation.TextContent(out.Content) != "Hello! How can I help?" {
t.Errorf("got %q", conversation.TextContent(out.Content))
}
if resp.Usage.TotalTokens != 18 {
t.Errorf("got total_tokens %d", resp.Usage.TotalTokens)
}
}
func TestAppendConversation_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/conversations/conv-123" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"object": "conversation.response", "conversation_id": "conv-123",
"outputs": []map[string]any{{
"object": "entry", "id": "e2", "type": "message.output",
"created_at": "2024-01-01T00:00:00Z",
"role": "assistant", "content": "The weather is sunny.",
}},
"usage": map[string]any{"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.AppendConversation(context.Background(), "conv-123", &conversation.AppendRequest{
Inputs: conversation.TextInputs("What's the weather?"),
})
if err != nil {
t.Fatal(err)
}
if resp.ConversationID != "conv-123" {
t.Errorf("got %q", resp.ConversationID)
}
}
func TestGetConversation_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/conversations/conv-123" {
t.Errorf("got path %s", r.URL.Path)
}
if r.Method != "GET" {
t.Errorf("got method %s", r.Method)
}
json.NewEncoder(w).Encode(map[string]any{
"object": "conversation", "id": "conv-123",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:01:00Z",
"model": "mistral-small-latest",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
conv, err := client.GetConversation(context.Background(), "conv-123")
if err != nil {
t.Fatal(err)
}
if conv.ID != "conv-123" {
t.Errorf("got id %q", conv.ID)
}
if conv.Model != "mistral-small-latest" {
t.Errorf("got model %q", conv.Model)
}
}
func TestDeleteConversation_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "DELETE" {
t.Errorf("got method %s", r.Method)
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
err := client.DeleteConversation(context.Background(), "conv-123")
if err != nil {
t.Fatal(err)
}
}
func TestGetConversationHistory_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/conversations/conv-123/history" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"object": "conversation.history", "conversation_id": "conv-123",
"entries": []map[string]any{
{"object": "entry", "id": "e1", "type": "message.input", "created_at": "2024-01-01T00:00:00Z", "role": "user", "content": "Hi"},
{"object": "entry", "id": "e2", "type": "message.output", "created_at": "2024-01-01T00:00:01Z", "role": "assistant", "content": "Hello!"},
},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
history, err := client.GetConversationHistory(context.Background(), "conv-123")
if err != nil {
t.Fatal(err)
}
if len(history.Entries) != 2 {
t.Fatalf("got %d entries", len(history.Entries))
}
if _, ok := history.Entries[0].(*conversation.MessageInputEntry); !ok {
t.Errorf("expected *MessageInputEntry, got %T", history.Entries[0])
}
if _, ok := history.Entries[1].(*conversation.MessageOutputEntry); !ok {
t.Errorf("expected *MessageOutputEntry, got %T", history.Entries[1])
}
}
func TestStartConversationStream_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["stream"] != true {
t.Errorf("expected stream=true")
}
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
events := []map[string]any{
{"type": "conversation.response.started", "created_at": "2024-01-01T00:00:00Z", "conversation_id": "conv-456"},
{"type": "message.output.delta", "created_at": "2024-01-01T00:00:00Z", "output_index": 0, "id": "m1", "content_index": 0, "content": "Hello", "role": "assistant"},
{"type": "message.output.delta", "created_at": "2024-01-01T00:00:00Z", "output_index": 0, "id": "m1", "content_index": 0, "content": " world!", "role": "assistant"},
{"type": "conversation.response.done", "created_at": "2024-01-01T00:00:01Z", "usage": map[string]any{"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}},
}
for _, ev := range events {
data, _ := json.Marshal(ev)
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
}
fmt.Fprint(w, "data: [DONE]\n\n")
flusher.Flush()
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
stream, err := client.StartConversationStream(context.Background(), &conversation.StartRequest{
Inputs: conversation.TextInputs("Hi"),
Model: "mistral-small-latest",
})
if err != nil {
t.Fatal(err)
}
defer stream.Close()
var events []conversation.Event
for stream.Next() {
events = append(events, stream.Current())
}
if stream.Err() != nil {
t.Fatal(stream.Err())
}
if len(events) != 4 {
t.Fatalf("got %d events, want 4", len(events))
}
started, ok := events[0].(*conversation.ResponseStartedEvent)
if !ok {
t.Fatalf("expected *ResponseStartedEvent, got %T", events[0])
}
if started.ConversationID != "conv-456" {
t.Errorf("got conv id %q", started.ConversationID)
}
msg, ok := events[1].(*conversation.MessageOutputEvent)
if !ok {
t.Fatalf("expected *MessageOutputEvent, got %T", events[1])
}
if conversation.TextContent(msg.Content) != "Hello" {
t.Errorf("got %q", conversation.TextContent(msg.Content))
}
done, ok := events[3].(*conversation.ResponseDoneEvent)
if !ok {
t.Fatalf("expected *ResponseDoneEvent, got %T", events[3])
}
if done.Usage.TotalTokens != 7 {
t.Errorf("got total_tokens %d", done.Usage.TotalTokens)
}
}