feat: Phase 6 conversations — start, append, restart, stream, history
Conversations API with full CRUD: start, append, restart (+ stream variants), get, list, delete, history, messages. Discriminated unions for entries (6 types) and streaming events (10 types). EventStream wraps Stream[json.RawMessage] with typed event dispatch.
This commit is contained in:
182
conversation/conversation.go
Normal file
182
conversation/conversation.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
// HandoffExecution controls tool call execution.
|
||||
type HandoffExecution string
|
||||
|
||||
const (
|
||||
HandoffClient HandoffExecution = "client"
|
||||
HandoffServer HandoffExecution = "server"
|
||||
)
|
||||
|
||||
// Tool represents a conversation tool.
|
||||
type Tool struct {
|
||||
Type string `json:"type"`
|
||||
Function *chat.Function `json:"function,omitempty"`
|
||||
LibraryIDs []string `json:"library_ids,omitempty"`
|
||||
ToolConfiguration *ToolConfig `json:"tool_configuration,omitempty"`
|
||||
}
|
||||
|
||||
// ToolConfig configures tool behavior.
|
||||
type ToolConfig struct {
|
||||
Exclude []string `json:"exclude,omitempty"`
|
||||
Include []string `json:"include,omitempty"`
|
||||
RequiresConfirmation []string `json:"requires_confirmation,omitempty"`
|
||||
}
|
||||
|
||||
// CompletionArgs holds optional completion parameters.
|
||||
type CompletionArgs struct {
|
||||
Stop []string `json:"stop,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
MaxTokens *int `json:"max_tokens,omitempty"`
|
||||
RandomSeed *int `json:"random_seed,omitempty"`
|
||||
Prediction *chat.Prediction `json:"prediction,omitempty"`
|
||||
ResponseFormat *chat.ResponseFormat `json:"response_format,omitempty"`
|
||||
ToolChoice *chat.ToolChoiceMode `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
// ToolCallConfirmation confirms or denies a pending tool call.
|
||||
type ToolCallConfirmation struct {
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
Confirmation string `json:"confirmation"` // "allow" or "deny"
|
||||
}
|
||||
|
||||
// Inputs represents conversation inputs (text string or entry array).
|
||||
type Inputs struct {
|
||||
text *string
|
||||
entries []Entry
|
||||
}
|
||||
|
||||
// TextInputs creates Inputs from a plain text string.
|
||||
func TextInputs(s string) Inputs { return Inputs{text: &s} }
|
||||
|
||||
// EntryInputs creates Inputs from entry objects.
|
||||
func EntryInputs(entries ...Entry) Inputs { return Inputs{entries: entries} }
|
||||
|
||||
func (i Inputs) MarshalJSON() ([]byte, error) {
|
||||
if i.text != nil {
|
||||
return json.Marshal(*i.text)
|
||||
}
|
||||
return json.Marshal(i.entries)
|
||||
}
|
||||
|
||||
// UsageInfo contains conversation token usage.
|
||||
type UsageInfo struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
ConnectorTokens *int `json:"connector_tokens,omitempty"`
|
||||
Connectors map[string]int `json:"connectors,omitempty"`
|
||||
}
|
||||
|
||||
// Response is the response from starting, appending, or restarting a conversation.
|
||||
type Response struct {
|
||||
Object string `json:"object"`
|
||||
ConversationID string `json:"conversation_id"`
|
||||
Outputs []Entry `json:"-"`
|
||||
Usage UsageInfo `json:"usage"`
|
||||
Guardrails json.RawMessage `json:"guardrails,omitempty"`
|
||||
}
|
||||
|
||||
func (r *Response) UnmarshalJSON(data []byte) error {
|
||||
type alias Response
|
||||
var raw struct {
|
||||
*alias
|
||||
Outputs []json.RawMessage `json:"outputs"`
|
||||
}
|
||||
raw.alias = (*alias)(r)
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return fmt.Errorf("mistral: unmarshal conversation response: %w", err)
|
||||
}
|
||||
r.Outputs = make([]Entry, len(raw.Outputs))
|
||||
for i, o := range raw.Outputs {
|
||||
entry, err := UnmarshalEntry(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Outputs[i] = entry
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Conversation represents conversation metadata.
|
||||
type Conversation struct {
|
||||
Object string `json:"object"`
|
||||
ID string `json:"id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
Model string `json:"model,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
AgentVersion json.RawMessage `json:"agent_version,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
Instructions *string `json:"instructions,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
|
||||
}
|
||||
|
||||
// History is the response from getting conversation history.
|
||||
type History struct {
|
||||
Object string `json:"object"`
|
||||
ConversationID string `json:"conversation_id"`
|
||||
Entries []Entry `json:"-"`
|
||||
}
|
||||
|
||||
func (h *History) UnmarshalJSON(data []byte) error {
|
||||
type alias History
|
||||
var raw struct {
|
||||
*alias
|
||||
Entries []json.RawMessage `json:"entries"`
|
||||
}
|
||||
raw.alias = (*alias)(h)
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return fmt.Errorf("mistral: unmarshal conversation history: %w", err)
|
||||
}
|
||||
h.Entries = make([]Entry, len(raw.Entries))
|
||||
for i, e := range raw.Entries {
|
||||
entry, err := UnmarshalEntry(e)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.Entries[i] = entry
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Messages is the response from getting conversation messages.
|
||||
type Messages struct {
|
||||
Object string `json:"object"`
|
||||
ConversationID string `json:"conversation_id"`
|
||||
Messages []Entry `json:"-"`
|
||||
}
|
||||
|
||||
func (m *Messages) UnmarshalJSON(data []byte) error {
|
||||
type alias Messages
|
||||
var raw struct {
|
||||
*alias
|
||||
Messages []json.RawMessage `json:"messages"`
|
||||
}
|
||||
raw.alias = (*alias)(m)
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
return fmt.Errorf("mistral: unmarshal conversation messages: %w", err)
|
||||
}
|
||||
m.Messages = make([]Entry, len(raw.Messages))
|
||||
for i, msg := range raw.Messages {
|
||||
entry, err := UnmarshalEntry(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Messages[i] = entry
|
||||
}
|
||||
return nil
|
||||
}
|
||||
161
conversation/entry.go
Normal file
161
conversation/entry.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Entry is a sealed interface for conversation history entries.
|
||||
type Entry interface {
|
||||
entryType() string
|
||||
}
|
||||
|
||||
// MessageInputEntry represents a user or assistant input message.
|
||||
type MessageInputEntry struct {
|
||||
Object string `json:"object"`
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
CompletedAt *string `json:"completed_at,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
Prefix bool `json:"prefix,omitempty"`
|
||||
}
|
||||
|
||||
func (*MessageInputEntry) entryType() string { return "message.input" }
|
||||
|
||||
// MessageOutputEntry represents an assistant output message.
|
||||
type MessageOutputEntry struct {
|
||||
Object string `json:"object"`
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
CompletedAt *string `json:"completed_at,omitempty"`
|
||||
Role string `json:"role"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
AgentID *string `json:"agent_id,omitempty"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
}
|
||||
|
||||
func (*MessageOutputEntry) entryType() string { return "message.output" }
|
||||
|
||||
// FunctionCallEntry represents a function call by the model.
|
||||
type FunctionCallEntry struct {
|
||||
Object string `json:"object"`
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
CompletedAt *string `json:"completed_at,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
Name string `json:"name"`
|
||||
Arguments json.RawMessage `json:"arguments"`
|
||||
ConfirmationStatus *string `json:"confirmation_status,omitempty"`
|
||||
AgentID *string `json:"agent_id,omitempty"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
}
|
||||
|
||||
func (*FunctionCallEntry) entryType() string { return "function.call" }
|
||||
|
||||
// FunctionResultEntry represents a function result provided by the client.
|
||||
type FunctionResultEntry struct {
|
||||
Object string `json:"object"`
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
CompletedAt *string `json:"completed_at,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
Result string `json:"result"`
|
||||
}
|
||||
|
||||
func (*FunctionResultEntry) entryType() string { return "function.result" }
|
||||
|
||||
// ToolExecutionEntry represents a built-in tool execution.
|
||||
type ToolExecutionEntry struct {
|
||||
Object string `json:"object"`
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
CompletedAt *string `json:"completed_at,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
Info map[string]any `json:"info,omitempty"`
|
||||
AgentID *string `json:"agent_id,omitempty"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
}
|
||||
|
||||
func (*ToolExecutionEntry) entryType() string { return "tool.execution" }
|
||||
|
||||
// AgentHandoffEntry represents an agent-to-agent handoff.
|
||||
type AgentHandoffEntry struct {
|
||||
Object string `json:"object"`
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
CompletedAt *string `json:"completed_at,omitempty"`
|
||||
PreviousAgentID string `json:"previous_agent_id"`
|
||||
PreviousAgentName string `json:"previous_agent_name"`
|
||||
NextAgentID string `json:"next_agent_id"`
|
||||
NextAgentName string `json:"next_agent_name"`
|
||||
}
|
||||
|
||||
func (*AgentHandoffEntry) entryType() string { return "agent.handoff" }
|
||||
|
||||
// UnmarshalEntry dispatches JSON to the concrete Entry type
|
||||
// based on the "type" discriminator field.
|
||||
func UnmarshalEntry(data []byte) (Entry, error) {
|
||||
var probe struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &probe); err != nil {
|
||||
return nil, fmt.Errorf("mistral: unmarshal entry: %w", err)
|
||||
}
|
||||
switch probe.Type {
|
||||
case "message.input":
|
||||
var e MessageInputEntry
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case "message.output":
|
||||
var e MessageOutputEntry
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case "function.call":
|
||||
var e FunctionCallEntry
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case "function.result":
|
||||
var e FunctionResultEntry
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case "tool.execution":
|
||||
var e ToolExecutionEntry
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case "agent.handoff":
|
||||
var e AgentHandoffEntry
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
default:
|
||||
return nil, fmt.Errorf("mistral: unknown entry type: %q", probe.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// TextContent extracts text from a raw content field.
|
||||
// Handles both string content and chunk arrays (extracts text chunks).
|
||||
func TextContent(raw json.RawMessage) string {
|
||||
if len(raw) == 0 {
|
||||
return ""
|
||||
}
|
||||
var s string
|
||||
if json.Unmarshal(raw, &s) == nil {
|
||||
return s
|
||||
}
|
||||
var chunks []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
if json.Unmarshal(raw, &chunks) == nil {
|
||||
var sb strings.Builder
|
||||
for _, ch := range chunks {
|
||||
if ch.Type == "text" {
|
||||
sb.WriteString(ch.Text)
|
||||
}
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
145
conversation/entry_test.go
Normal file
145
conversation/entry_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUnmarshalEntry_MessageInput(t *testing.T) {
|
||||
data := []byte(`{"object":"entry","id":"e1","type":"message.input","created_at":"2024-01-01T00:00:00Z","role":"user","content":"Hello"}`)
|
||||
entry, err := UnmarshalEntry(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
e, ok := entry.(*MessageInputEntry)
|
||||
if !ok {
|
||||
t.Fatalf("expected *MessageInputEntry, got %T", entry)
|
||||
}
|
||||
if e.ID != "e1" {
|
||||
t.Errorf("got id %q", e.ID)
|
||||
}
|
||||
if e.Role != "user" {
|
||||
t.Errorf("got role %q", e.Role)
|
||||
}
|
||||
if TextContent(e.Content) != "Hello" {
|
||||
t.Errorf("got content %q", TextContent(e.Content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEntry_MessageOutput(t *testing.T) {
|
||||
data := []byte(`{"object":"entry","id":"e2","type":"message.output","created_at":"2024-01-01T00:00:00Z","role":"assistant","content":"Hi there!"}`)
|
||||
entry, err := UnmarshalEntry(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
e, ok := entry.(*MessageOutputEntry)
|
||||
if !ok {
|
||||
t.Fatalf("expected *MessageOutputEntry, got %T", entry)
|
||||
}
|
||||
if TextContent(e.Content) != "Hi there!" {
|
||||
t.Errorf("got content %q", TextContent(e.Content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEntry_FunctionCall(t *testing.T) {
|
||||
data := []byte(`{"object":"entry","id":"e3","type":"function.call","created_at":"2024-01-01T00:00:00Z","tool_call_id":"tc1","name":"get_weather","arguments":"{\"city\":\"Paris\"}"}`)
|
||||
entry, err := UnmarshalEntry(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
e, ok := entry.(*FunctionCallEntry)
|
||||
if !ok {
|
||||
t.Fatalf("expected *FunctionCallEntry, got %T", entry)
|
||||
}
|
||||
if e.Name != "get_weather" {
|
||||
t.Errorf("got name %q", e.Name)
|
||||
}
|
||||
if e.ToolCallID != "tc1" {
|
||||
t.Errorf("got tool_call_id %q", e.ToolCallID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEntry_FunctionResult(t *testing.T) {
|
||||
data := []byte(`{"object":"entry","id":"e4","type":"function.result","created_at":"2024-01-01T00:00:00Z","tool_call_id":"tc1","result":"{\"temp\":22}"}`)
|
||||
entry, err := UnmarshalEntry(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
e, ok := entry.(*FunctionResultEntry)
|
||||
if !ok {
|
||||
t.Fatalf("expected *FunctionResultEntry, got %T", entry)
|
||||
}
|
||||
if e.Result != `{"temp":22}` {
|
||||
t.Errorf("got result %q", e.Result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEntry_ToolExecution(t *testing.T) {
|
||||
data := []byte(`{"object":"entry","id":"e5","type":"tool.execution","created_at":"2024-01-01T00:00:00Z","name":"web_search","arguments":"query"}`)
|
||||
entry, err := UnmarshalEntry(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
e, ok := entry.(*ToolExecutionEntry)
|
||||
if !ok {
|
||||
t.Fatalf("expected *ToolExecutionEntry, got %T", entry)
|
||||
}
|
||||
if e.Name != "web_search" {
|
||||
t.Errorf("got name %q", e.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEntry_AgentHandoff(t *testing.T) {
|
||||
data := []byte(`{"object":"entry","id":"e6","type":"agent.handoff","created_at":"2024-01-01T00:00:00Z","previous_agent_id":"a1","previous_agent_name":"Agent A","next_agent_id":"a2","next_agent_name":"Agent B"}`)
|
||||
entry, err := UnmarshalEntry(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
e, ok := entry.(*AgentHandoffEntry)
|
||||
if !ok {
|
||||
t.Fatalf("expected *AgentHandoffEntry, got %T", entry)
|
||||
}
|
||||
if e.PreviousAgentName != "Agent A" {
|
||||
t.Errorf("got prev %q", e.PreviousAgentName)
|
||||
}
|
||||
if e.NextAgentName != "Agent B" {
|
||||
t.Errorf("got next %q", e.NextAgentName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEntry_Unknown(t *testing.T) {
|
||||
_, err := UnmarshalEntry([]byte(`{"type":"unknown.type"}`))
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextContent_String(t *testing.T) {
|
||||
raw := json.RawMessage(`"Hello world"`)
|
||||
if TextContent(raw) != "Hello world" {
|
||||
t.Errorf("got %q", TextContent(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextContent_ChunkArray(t *testing.T) {
|
||||
raw := json.RawMessage(`[{"type":"text","text":"Hello "},{"type":"text","text":"world"}]`)
|
||||
if TextContent(raw) != "Hello world" {
|
||||
t.Errorf("got %q", TextContent(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextContent_Empty(t *testing.T) {
|
||||
if TextContent(nil) != "" {
|
||||
t.Error("expected empty for nil")
|
||||
}
|
||||
if TextContent(json.RawMessage{}) != "" {
|
||||
t.Error("expected empty for empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextContent_MixedChunks(t *testing.T) {
|
||||
raw := json.RawMessage(`[{"type":"text","text":"Hello"},{"type":"tool_reference","tool":"web_search","title":"Result"},{"type":"text","text":" world"}]`)
|
||||
if TextContent(raw) != "Hello world" {
|
||||
t.Errorf("got %q", TextContent(raw))
|
||||
}
|
||||
}
|
||||
177
conversation/event.go
Normal file
177
conversation/event.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Event is a sealed interface for conversation streaming events.
|
||||
type Event interface {
|
||||
eventType() string
|
||||
}
|
||||
|
||||
// ResponseStartedEvent signals the start of a conversation response.
|
||||
type ResponseStartedEvent struct {
|
||||
Type string `json:"type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ConversationID string `json:"conversation_id"`
|
||||
}
|
||||
|
||||
func (*ResponseStartedEvent) eventType() string { return "conversation.response.started" }
|
||||
|
||||
// ResponseDoneEvent signals the completion of a conversation response.
|
||||
type ResponseDoneEvent struct {
|
||||
Type string `json:"type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Usage UsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
func (*ResponseDoneEvent) eventType() string { return "conversation.response.done" }
|
||||
|
||||
// ResponseErrorEvent signals an error during conversation processing.
|
||||
type ResponseErrorEvent struct {
|
||||
Type string `json:"type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
}
|
||||
|
||||
func (*ResponseErrorEvent) eventType() string { return "conversation.response.error" }
|
||||
|
||||
// MessageOutputEvent contains a delta of assistant message output.
|
||||
type MessageOutputEvent struct {
|
||||
Type string `json:"type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
OutputIndex int `json:"output_index"`
|
||||
ID string `json:"id"`
|
||||
ContentIndex int `json:"content_index"`
|
||||
Content json.RawMessage `json:"content"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
AgentID *string `json:"agent_id,omitempty"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
func (*MessageOutputEvent) eventType() string { return "message.output.delta" }
|
||||
|
||||
// ToolExecutionStartedEvent signals the start of a tool execution.
|
||||
type ToolExecutionStartedEvent struct {
|
||||
Type string `json:"type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
OutputIndex int `json:"output_index"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
AgentID *string `json:"agent_id,omitempty"`
|
||||
}
|
||||
|
||||
func (*ToolExecutionStartedEvent) eventType() string { return "tool.execution.started" }
|
||||
|
||||
// ToolExecutionDeltaEvent contains a delta of tool execution output.
|
||||
type ToolExecutionDeltaEvent struct {
|
||||
Type string `json:"type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
OutputIndex int `json:"output_index"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
func (*ToolExecutionDeltaEvent) eventType() string { return "tool.execution.delta" }
|
||||
|
||||
// ToolExecutionDoneEvent signals the completion of a tool execution.
|
||||
type ToolExecutionDoneEvent struct {
|
||||
Type string `json:"type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
OutputIndex int `json:"output_index"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Info map[string]any `json:"info,omitempty"`
|
||||
}
|
||||
|
||||
func (*ToolExecutionDoneEvent) eventType() string { return "tool.execution.done" }
|
||||
|
||||
// FunctionCallEvent contains a delta of a function call.
|
||||
type FunctionCallEvent struct {
|
||||
Type string `json:"type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
OutputIndex int `json:"output_index"`
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
ToolCallID string `json:"tool_call_id"`
|
||||
Arguments string `json:"arguments"`
|
||||
ConfirmationStatus *string `json:"confirmation_status,omitempty"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
AgentID *string `json:"agent_id,omitempty"`
|
||||
}
|
||||
|
||||
func (*FunctionCallEvent) eventType() string { return "function.call.delta" }
|
||||
|
||||
// AgentHandoffStartedEvent signals the start of an agent handoff.
|
||||
type AgentHandoffStartedEvent struct {
|
||||
Type string `json:"type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
OutputIndex int `json:"output_index"`
|
||||
ID string `json:"id"`
|
||||
PreviousAgentID string `json:"previous_agent_id"`
|
||||
PreviousAgentName string `json:"previous_agent_name"`
|
||||
}
|
||||
|
||||
func (*AgentHandoffStartedEvent) eventType() string { return "agent.handoff.started" }
|
||||
|
||||
// AgentHandoffDoneEvent signals the completion of an agent handoff.
|
||||
type AgentHandoffDoneEvent struct {
|
||||
Type string `json:"type"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
OutputIndex int `json:"output_index"`
|
||||
ID string `json:"id"`
|
||||
NextAgentID string `json:"next_agent_id"`
|
||||
NextAgentName string `json:"next_agent_name"`
|
||||
}
|
||||
|
||||
func (*AgentHandoffDoneEvent) eventType() string { return "agent.handoff.done" }
|
||||
|
||||
// UnmarshalEvent dispatches JSON to the concrete Event type
|
||||
// based on the "type" discriminator field.
|
||||
func UnmarshalEvent(data []byte) (Event, error) {
|
||||
var probe struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &probe); err != nil {
|
||||
return nil, fmt.Errorf("mistral: unmarshal event: %w", err)
|
||||
}
|
||||
switch probe.Type {
|
||||
case "conversation.response.started":
|
||||
var e ResponseStartedEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case "conversation.response.done":
|
||||
var e ResponseDoneEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case "conversation.response.error":
|
||||
var e ResponseErrorEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case "message.output.delta":
|
||||
var e MessageOutputEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case "tool.execution.started":
|
||||
var e ToolExecutionStartedEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case "tool.execution.delta":
|
||||
var e ToolExecutionDeltaEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case "tool.execution.done":
|
||||
var e ToolExecutionDoneEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case "function.call.delta":
|
||||
var e FunctionCallEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case "agent.handoff.started":
|
||||
var e AgentHandoffStartedEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case "agent.handoff.done":
|
||||
var e AgentHandoffDoneEvent
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
default:
|
||||
return nil, fmt.Errorf("mistral: unknown event type: %q", probe.Type)
|
||||
}
|
||||
}
|
||||
163
conversation/event_test.go
Normal file
163
conversation/event_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package conversation
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUnmarshalEvent_ResponseStarted(t *testing.T) {
|
||||
data := []byte(`{"type":"conversation.response.started","created_at":"2024-01-01T00:00:00Z","conversation_id":"conv-123"}`)
|
||||
event, err := UnmarshalEvent(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
e, ok := event.(*ResponseStartedEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected *ResponseStartedEvent, got %T", event)
|
||||
}
|
||||
if e.ConversationID != "conv-123" {
|
||||
t.Errorf("got %q", e.ConversationID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEvent_ResponseDone(t *testing.T) {
|
||||
data := []byte(`{"type":"conversation.response.done","created_at":"2024-01-01T00:00:00Z","usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}`)
|
||||
event, err := UnmarshalEvent(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
e, ok := event.(*ResponseDoneEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected *ResponseDoneEvent, got %T", event)
|
||||
}
|
||||
if e.Usage.TotalTokens != 15 {
|
||||
t.Errorf("got total_tokens %d", e.Usage.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEvent_ResponseError(t *testing.T) {
|
||||
data := []byte(`{"type":"conversation.response.error","created_at":"2024-01-01T00:00:00Z","message":"error occurred","code":500}`)
|
||||
event, err := UnmarshalEvent(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
e, ok := event.(*ResponseErrorEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected *ResponseErrorEvent, got %T", event)
|
||||
}
|
||||
if e.Message != "error occurred" {
|
||||
t.Errorf("got %q", e.Message)
|
||||
}
|
||||
if e.Code != 500 {
|
||||
t.Errorf("got code %d", e.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEvent_MessageOutput(t *testing.T) {
|
||||
data := []byte(`{"type":"message.output.delta","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"m1","content_index":0,"content":"Hello","role":"assistant"}`)
|
||||
event, err := UnmarshalEvent(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
e, ok := event.(*MessageOutputEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected *MessageOutputEvent, got %T", event)
|
||||
}
|
||||
if e.ID != "m1" {
|
||||
t.Errorf("got id %q", e.ID)
|
||||
}
|
||||
if TextContent(e.Content) != "Hello" {
|
||||
t.Errorf("got content %q", TextContent(e.Content))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEvent_FunctionCall(t *testing.T) {
|
||||
data := []byte(`{"type":"function.call.delta","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"fc1","name":"search","tool_call_id":"tc1","arguments":"{\"q\":\"test\"}"}`)
|
||||
event, err := UnmarshalEvent(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
e, ok := event.(*FunctionCallEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected *FunctionCallEvent, got %T", event)
|
||||
}
|
||||
if e.Name != "search" {
|
||||
t.Errorf("got name %q", e.Name)
|
||||
}
|
||||
if e.ToolCallID != "tc1" {
|
||||
t.Errorf("got tool_call_id %q", e.ToolCallID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEvent_ToolExecutionStarted(t *testing.T) {
|
||||
data := []byte(`{"type":"tool.execution.started","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"te1","name":"web_search","arguments":"query"}`)
|
||||
event, err := UnmarshalEvent(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
e, ok := event.(*ToolExecutionStartedEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected *ToolExecutionStartedEvent, got %T", event)
|
||||
}
|
||||
if e.Name != "web_search" {
|
||||
t.Errorf("got %q", e.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEvent_ToolExecutionDone(t *testing.T) {
|
||||
data := []byte(`{"type":"tool.execution.done","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"te1","name":"web_search"}`)
|
||||
event, err := UnmarshalEvent(data)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, ok := event.(*ToolExecutionDoneEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected *ToolExecutionDoneEvent, got %T", event)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEvent_AgentHandoff(t *testing.T) {
|
||||
started := []byte(`{"type":"agent.handoff.started","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"h1","previous_agent_id":"a1","previous_agent_name":"A"}`)
|
||||
event, err := UnmarshalEvent(started)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
hs, ok := event.(*AgentHandoffStartedEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected *AgentHandoffStartedEvent, got %T", event)
|
||||
}
|
||||
if hs.PreviousAgentID != "a1" {
|
||||
t.Errorf("got %q", hs.PreviousAgentID)
|
||||
}
|
||||
|
||||
done := []byte(`{"type":"agent.handoff.done","created_at":"2024-01-01T00:00:00Z","output_index":0,"id":"h1","next_agent_id":"a2","next_agent_name":"B"}`)
|
||||
event, err = UnmarshalEvent(done)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
hd, ok := event.(*AgentHandoffDoneEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected *AgentHandoffDoneEvent, got %T", event)
|
||||
}
|
||||
if hd.NextAgentID != "a2" {
|
||||
t.Errorf("got %q", hd.NextAgentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnmarshalEvent_Unknown(t *testing.T) {
|
||||
_, err := UnmarshalEvent([]byte(`{"type":"unknown.event"}`))
|
||||
if err == nil {
|
||||
t.Error("expected error for unknown type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInputs_TextMarshal(t *testing.T) {
|
||||
inputs := TextInputs("Hello")
|
||||
data, err := json.Marshal(inputs)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(data) != `"Hello"` {
|
||||
t.Errorf("got %s", data)
|
||||
}
|
||||
}
|
||||
81
conversation/request.go
Normal file
81
conversation/request.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package conversation
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// StartRequest starts a new conversation.
|
||||
type StartRequest struct {
|
||||
Inputs Inputs `json:"inputs"`
|
||||
Model string `json:"model,omitempty"`
|
||||
AgentID string `json:"agent_id,omitempty"`
|
||||
AgentVersion json.RawMessage `json:"agent_version,omitempty"`
|
||||
Instructions *string `json:"instructions,omitempty"`
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
HandoffExecution *HandoffExecution `json:"handoff_execution,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
stream bool
|
||||
}
|
||||
|
||||
func (r *StartRequest) SetStream(v bool) { r.stream = v }
|
||||
|
||||
func (r *StartRequest) MarshalJSON() ([]byte, error) {
|
||||
type Alias StartRequest
|
||||
return json.Marshal(&struct {
|
||||
Stream bool `json:"stream"`
|
||||
*Alias
|
||||
}{
|
||||
Stream: r.stream,
|
||||
Alias: (*Alias)(r),
|
||||
})
|
||||
}
|
||||
|
||||
// AppendRequest appends to an existing conversation.
|
||||
type AppendRequest struct {
|
||||
Inputs Inputs `json:"inputs"`
|
||||
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
HandoffExecution *HandoffExecution `json:"handoff_execution,omitempty"`
|
||||
ToolConfirmations []ToolCallConfirmation `json:"tool_confirmations,omitempty"`
|
||||
stream bool
|
||||
}
|
||||
|
||||
func (r *AppendRequest) SetStream(v bool) { r.stream = v }
|
||||
|
||||
func (r *AppendRequest) MarshalJSON() ([]byte, error) {
|
||||
type Alias AppendRequest
|
||||
return json.Marshal(&struct {
|
||||
Stream bool `json:"stream"`
|
||||
*Alias
|
||||
}{
|
||||
Stream: r.stream,
|
||||
Alias: (*Alias)(r),
|
||||
})
|
||||
}
|
||||
|
||||
// RestartRequest restarts a conversation from a specific entry.
|
||||
type RestartRequest struct {
|
||||
Inputs Inputs `json:"inputs"`
|
||||
FromEntryID string `json:"from_entry_id"`
|
||||
CompletionArgs *CompletionArgs `json:"completion_args,omitempty"`
|
||||
Store *bool `json:"store,omitempty"`
|
||||
HandoffExecution *HandoffExecution `json:"handoff_execution,omitempty"`
|
||||
AgentVersion json.RawMessage `json:"agent_version,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
stream bool
|
||||
}
|
||||
|
||||
func (r *RestartRequest) SetStream(v bool) { r.stream = v }
|
||||
|
||||
func (r *RestartRequest) MarshalJSON() ([]byte, error) {
|
||||
type Alias RestartRequest
|
||||
return json.Marshal(&struct {
|
||||
Stream bool `json:"stream"`
|
||||
*Alias
|
||||
}{
|
||||
Stream: r.stream,
|
||||
Alias: (*Alias)(r),
|
||||
})
|
||||
}
|
||||
160
conversations.go
Normal file
160
conversations.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/conversation"
|
||||
)
|
||||
|
||||
// StartConversation creates and starts a new conversation.
|
||||
func (c *Client) StartConversation(ctx context.Context, req *conversation.StartRequest) (*conversation.Response, error) {
|
||||
var resp conversation.Response
|
||||
if err := c.doJSON(ctx, "POST", "/v1/conversations", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// StartConversationStream creates a conversation and returns a stream of events.
|
||||
func (c *Client) StartConversationStream(ctx context.Context, req *conversation.StartRequest) (*EventStream, error) {
|
||||
req.SetStream(true)
|
||||
resp, err := c.doStream(ctx, "POST", "/v1/conversations", req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newEventStream(resp.Body), nil
|
||||
}
|
||||
|
||||
// AppendConversation appends inputs to an existing conversation.
|
||||
func (c *Client) AppendConversation(ctx context.Context, conversationID string, req *conversation.AppendRequest) (*conversation.Response, error) {
|
||||
var resp conversation.Response
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/conversations/%s", conversationID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// AppendConversationStream appends to a conversation and returns a stream of events.
|
||||
func (c *Client) AppendConversationStream(ctx context.Context, conversationID string, req *conversation.AppendRequest) (*EventStream, error) {
|
||||
req.SetStream(true)
|
||||
resp, err := c.doStream(ctx, "POST", fmt.Sprintf("/v1/conversations/%s", conversationID), req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newEventStream(resp.Body), nil
|
||||
}
|
||||
|
||||
// RestartConversation restarts a conversation from a specific entry.
|
||||
func (c *Client) RestartConversation(ctx context.Context, conversationID string, req *conversation.RestartRequest) (*conversation.Response, error) {
|
||||
var resp conversation.Response
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/conversations/%s/restart", conversationID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// RestartConversationStream restarts a conversation and returns a stream of events.
|
||||
func (c *Client) RestartConversationStream(ctx context.Context, conversationID string, req *conversation.RestartRequest) (*EventStream, error) {
|
||||
req.SetStream(true)
|
||||
resp, err := c.doStream(ctx, "POST", fmt.Sprintf("/v1/conversations/%s/restart", conversationID), req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newEventStream(resp.Body), nil
|
||||
}
|
||||
|
||||
// GetConversation retrieves conversation metadata.
|
||||
func (c *Client) GetConversation(ctx context.Context, conversationID string) (*conversation.Conversation, error) {
|
||||
var resp conversation.Conversation
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/conversations/%s", conversationID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListConversations lists conversations with optional pagination.
|
||||
func (c *Client) ListConversations(ctx context.Context) ([]conversation.Conversation, error) {
|
||||
var resp []conversation.Conversation
|
||||
if err := c.doJSON(ctx, "GET", "/v1/conversations", nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// DeleteConversation deletes a conversation.
|
||||
func (c *Client) DeleteConversation(ctx context.Context, conversationID string) error {
|
||||
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/conversations/%s", conversationID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return parseAPIError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConversationHistory returns the full history of a conversation.
|
||||
func (c *Client) GetConversationHistory(ctx context.Context, conversationID string) (*conversation.History, error) {
|
||||
var resp conversation.History
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/conversations/%s/history", conversationID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetConversationMessages returns the messages of a conversation.
|
||||
func (c *Client) GetConversationMessages(ctx context.Context, conversationID string) (*conversation.Messages, error) {
|
||||
var resp conversation.Messages
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/conversations/%s/messages", conversationID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// EventStream wraps the generic Stream to provide typed conversation events.
|
||||
type EventStream struct {
|
||||
stream *Stream[json.RawMessage]
|
||||
event conversation.Event
|
||||
err error
|
||||
}
|
||||
|
||||
func newEventStream(body readCloser) *EventStream {
|
||||
return &EventStream{
|
||||
stream: newStream[json.RawMessage](body),
|
||||
}
|
||||
}
|
||||
|
||||
// Next advances to the next event. Returns false when done or on error.
|
||||
func (s *EventStream) Next() bool {
|
||||
if s.err != nil {
|
||||
return false
|
||||
}
|
||||
if !s.stream.Next() {
|
||||
s.err = s.stream.Err()
|
||||
return false
|
||||
}
|
||||
event, err := conversation.UnmarshalEvent(s.stream.Current())
|
||||
if err != nil {
|
||||
s.err = err
|
||||
return false
|
||||
}
|
||||
s.event = event
|
||||
return true
|
||||
}
|
||||
|
||||
// Current returns the most recently read event.
|
||||
func (s *EventStream) Current() conversation.Event { return s.event }
|
||||
|
||||
// Err returns any error encountered during streaming.
|
||||
func (s *EventStream) Err() error { return s.err }
|
||||
|
||||
// Close releases the underlying connection.
|
||||
func (s *EventStream) Close() error { return s.stream.Close() }
|
||||
|
||||
type readCloser = interface {
|
||||
Read(p []byte) (n int, err error)
|
||||
Close() error
|
||||
}
|
||||
249
conversations_test.go
Normal file
249
conversations_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/conversation"
|
||||
)
|
||||
|
||||
func TestStartConversation_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/conversations" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("got method %s", r.Method)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["inputs"] != "Hello" {
|
||||
t.Errorf("expected inputs=Hello, got %v", body["inputs"])
|
||||
}
|
||||
if body["model"] != "mistral-small-latest" {
|
||||
t.Errorf("expected model, got %v", body["model"])
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"object": "conversation.response",
|
||||
"conversation_id": "conv-123",
|
||||
"outputs": []map[string]any{{
|
||||
"object": "entry", "id": "e1", "type": "message.output",
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
"role": "assistant", "content": "Hello! How can I help?",
|
||||
}},
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": 10, "completion_tokens": 8, "total_tokens": 18,
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.StartConversation(context.Background(), &conversation.StartRequest{
|
||||
Inputs: conversation.TextInputs("Hello"),
|
||||
Model: "mistral-small-latest",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ConversationID != "conv-123" {
|
||||
t.Errorf("got conv id %q", resp.ConversationID)
|
||||
}
|
||||
if len(resp.Outputs) != 1 {
|
||||
t.Fatalf("got %d outputs", len(resp.Outputs))
|
||||
}
|
||||
out, ok := resp.Outputs[0].(*conversation.MessageOutputEntry)
|
||||
if !ok {
|
||||
t.Fatalf("expected *MessageOutputEntry, got %T", resp.Outputs[0])
|
||||
}
|
||||
if conversation.TextContent(out.Content) != "Hello! How can I help?" {
|
||||
t.Errorf("got %q", conversation.TextContent(out.Content))
|
||||
}
|
||||
if resp.Usage.TotalTokens != 18 {
|
||||
t.Errorf("got total_tokens %d", resp.Usage.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendConversation_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/conversations/conv-123" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"object": "conversation.response", "conversation_id": "conv-123",
|
||||
"outputs": []map[string]any{{
|
||||
"object": "entry", "id": "e2", "type": "message.output",
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
"role": "assistant", "content": "The weather is sunny.",
|
||||
}},
|
||||
"usage": map[string]any{"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.AppendConversation(context.Background(), "conv-123", &conversation.AppendRequest{
|
||||
Inputs: conversation.TextInputs("What's the weather?"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ConversationID != "conv-123" {
|
||||
t.Errorf("got %q", resp.ConversationID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetConversation_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/conversations/conv-123" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
if r.Method != "GET" {
|
||||
t.Errorf("got method %s", r.Method)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"object": "conversation", "id": "conv-123",
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
"updated_at": "2024-01-01T00:01:00Z",
|
||||
"model": "mistral-small-latest",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
conv, err := client.GetConversation(context.Background(), "conv-123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if conv.ID != "conv-123" {
|
||||
t.Errorf("got id %q", conv.ID)
|
||||
}
|
||||
if conv.Model != "mistral-small-latest" {
|
||||
t.Errorf("got model %q", conv.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteConversation_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "DELETE" {
|
||||
t.Errorf("got method %s", r.Method)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
err := client.DeleteConversation(context.Background(), "conv-123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetConversationHistory_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/conversations/conv-123/history" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"object": "conversation.history", "conversation_id": "conv-123",
|
||||
"entries": []map[string]any{
|
||||
{"object": "entry", "id": "e1", "type": "message.input", "created_at": "2024-01-01T00:00:00Z", "role": "user", "content": "Hi"},
|
||||
{"object": "entry", "id": "e2", "type": "message.output", "created_at": "2024-01-01T00:00:01Z", "role": "assistant", "content": "Hello!"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
history, err := client.GetConversationHistory(context.Background(), "conv-123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(history.Entries) != 2 {
|
||||
t.Fatalf("got %d entries", len(history.Entries))
|
||||
}
|
||||
if _, ok := history.Entries[0].(*conversation.MessageInputEntry); !ok {
|
||||
t.Errorf("expected *MessageInputEntry, got %T", history.Entries[0])
|
||||
}
|
||||
if _, ok := history.Entries[1].(*conversation.MessageOutputEntry); !ok {
|
||||
t.Errorf("expected *MessageOutputEntry, got %T", history.Entries[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartConversationStream_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["stream"] != true {
|
||||
t.Errorf("expected stream=true")
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
events := []map[string]any{
|
||||
{"type": "conversation.response.started", "created_at": "2024-01-01T00:00:00Z", "conversation_id": "conv-456"},
|
||||
{"type": "message.output.delta", "created_at": "2024-01-01T00:00:00Z", "output_index": 0, "id": "m1", "content_index": 0, "content": "Hello", "role": "assistant"},
|
||||
{"type": "message.output.delta", "created_at": "2024-01-01T00:00:00Z", "output_index": 0, "id": "m1", "content_index": 0, "content": " world!", "role": "assistant"},
|
||||
{"type": "conversation.response.done", "created_at": "2024-01-01T00:00:01Z", "usage": map[string]any{"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}},
|
||||
}
|
||||
for _, ev := range events {
|
||||
data, _ := json.Marshal(ev)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
fmt.Fprint(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
stream, err := client.StartConversationStream(context.Background(), &conversation.StartRequest{
|
||||
Inputs: conversation.TextInputs("Hi"),
|
||||
Model: "mistral-small-latest",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
var events []conversation.Event
|
||||
for stream.Next() {
|
||||
events = append(events, stream.Current())
|
||||
}
|
||||
if stream.Err() != nil {
|
||||
t.Fatal(stream.Err())
|
||||
}
|
||||
if len(events) != 4 {
|
||||
t.Fatalf("got %d events, want 4", len(events))
|
||||
}
|
||||
|
||||
started, ok := events[0].(*conversation.ResponseStartedEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected *ResponseStartedEvent, got %T", events[0])
|
||||
}
|
||||
if started.ConversationID != "conv-456" {
|
||||
t.Errorf("got conv id %q", started.ConversationID)
|
||||
}
|
||||
|
||||
msg, ok := events[1].(*conversation.MessageOutputEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected *MessageOutputEvent, got %T", events[1])
|
||||
}
|
||||
if conversation.TextContent(msg.Content) != "Hello" {
|
||||
t.Errorf("got %q", conversation.TextContent(msg.Content))
|
||||
}
|
||||
|
||||
done, ok := events[3].(*conversation.ResponseDoneEvent)
|
||||
if !ok {
|
||||
t.Fatalf("expected *ResponseDoneEvent, got %T", events[3])
|
||||
}
|
||||
if done.Usage.TotalTokens != 7 {
|
||||
t.Errorf("got total_tokens %d", done.Usage.TotalTokens)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user