feat: Phase 3 core completions — FIM, Agents, Embeddings

Add FIM, Agents, and Embedding endpoints:
- fim/request.go: FIMCompletionRequest (prompt/suffix model)
- agents/request.go: AgentsCompletionRequest (agent_id + messages)
- embedding/embedding.go: Request/Response/Data types with dtype/encoding
- FIMComplete, FIMCompleteStream, AgentsComplete, AgentsCompleteStream,
  CreateEmbeddings service methods
- All reuse chat.CompletionResponse/CompletionChunk for responses
- 11 new httptest-based tests
This commit is contained in:
2026-03-05 19:36:49 +01:00
parent 9b453ca62a
commit 4b8ca4be5d
9 changed files with 692 additions and 0 deletions

45
agents/request.go Normal file
View File

@@ -0,0 +1,45 @@
package agents
import (
"encoding/json"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
)
// CompletionRequest represents an agents completion request.
//
// Messages is tagged json:"-" because serialization is handled entirely by
// MarshalJSON, which re-emits it under the "messages" key together with the
// unexported stream flag.
type CompletionRequest struct {
	// AgentID identifies the server-side agent to run the completion against.
	AgentID string `json:"agent_id"`
	// Messages is the conversation history; marshaled via MarshalJSON, not tags.
	Messages []chat.Message `json:"-"`
	MaxTokens *int `json:"max_tokens,omitempty"`
	Stop []string `json:"stop,omitempty"`
	RandomSeed *int `json:"random_seed,omitempty"`
	Temperature *float64 `json:"temperature,omitempty"`
	TopP *float64 `json:"top_p,omitempty"`
	ResponseFormat *chat.ResponseFormat `json:"response_format,omitempty"`
	Tools []chat.Tool `json:"tools,omitempty"`
	ToolChoice *chat.ToolChoice `json:"tool_choice,omitempty"`
	PresencePenalty *float64 `json:"presence_penalty,omitempty"`
	FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
	N *int `json:"n,omitempty"`
	ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
	Metadata map[string]any `json:"metadata,omitempty"`
	Prediction *chat.Prediction `json:"prediction,omitempty"`
	PromptMode *chat.PromptMode `json:"prompt_mode,omitempty"`
	// stream is unexported so callers cannot set it directly; the client's
	// streaming method flips it via SetStream before sending.
	stream bool
}
// SetStream is used internally to set the unexported stream field; the
// client's streaming entry points call it just before issuing the request.
func (r *CompletionRequest) SetStream(v bool) {
	r.stream = v
}

// MarshalJSON serializes the request, injecting the messages slice and the
// unexported stream flag that plain struct tags cannot express.
func (r *CompletionRequest) MarshalJSON() ([]byte, error) {
	// plain has the same layout but none of the methods, so marshaling it
	// does not recurse back into this function.
	type plain CompletionRequest
	envelope := struct {
		Messages []chat.Message `json:"messages"`
		Stream   bool           `json:"stream"`
		*plain
	}{
		Messages: r.Messages,
		Stream:   r.stream,
		plain:    (*plain)(r),
	}
	return json.Marshal(&envelope)
}

27
agents_complete.go Normal file
View File

@@ -0,0 +1,27 @@
package mistral
import (
"context"
"somegit.dev/vikingowl/mistral-go-sdk/agents"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
)
// AgentsComplete sends an agents completion request and returns the decoded
// completion response. The response shape is shared with chat completions.
func (c *Client) AgentsComplete(ctx context.Context, req *agents.CompletionRequest) (*chat.CompletionResponse, error) {
	out := new(chat.CompletionResponse)
	err := c.doJSON(ctx, "POST", "/v1/agents/completions", req, out)
	if err != nil {
		return nil, err
	}
	return out, nil
}
// AgentsCompleteStream sends an agents request and returns a stream of chunks.
//
// NOTE: it mutates req by setting its stream flag to true before sending.
func (c *Client) AgentsCompleteStream(ctx context.Context, req *agents.CompletionRequest) (*Stream[chat.CompletionChunk], error) {
	req.SetStream(true)
	httpResp, err := c.doStream(ctx, "POST", "/v1/agents/completions", req)
	if err != nil {
		return nil, err
	}
	// The stream takes ownership of the response body and closes it.
	return newStream[chat.CompletionChunk](httpResp.Body), nil
}

164
agents_complete_test.go Normal file
View File

@@ -0,0 +1,164 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/agents"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
)
// TestAgentsComplete_Success checks that AgentsComplete posts to
// /v1/agents/completions with stream=false and decodes the response.
func TestAgentsComplete_Success(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/agents/completions" {
			t.Errorf("expected /v1/agents/completions, got %s", r.URL.Path)
		}
		var body map[string]any
		// The handler runs on a server goroutine, so use t.Errorf (never
		// t.Fatalf) and return early when the body is unusable.
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("decode request body: %v", err)
			return
		}
		if body["agent_id"] != "agent-123" {
			t.Errorf("expected agent_id=agent-123, got %v", body["agent_id"])
		}
		// Checked type assertion: a panic here would kill the server goroutine
		// and surface as an opaque client-side error.
		msgs, ok := body["messages"].([]any)
		if !ok {
			t.Errorf("expected messages array, got %T", body["messages"])
		} else if len(msgs) != 1 {
			t.Errorf("expected 1 message, got %d", len(msgs))
		}
		if body["stream"] != false {
			t.Errorf("expected stream=false")
		}
		err := json.NewEncoder(w).Encode(map[string]any{
			"id": "agent-resp-1", "object": "chat.completion",
			"model": "mistral-large-latest", "created": 1234567890,
			"choices": []map[string]any{{
				"index":         0,
				"message":       map[string]any{"role": "assistant", "content": "Agent response"},
				"finish_reason": "stop",
			}},
			"usage": map[string]any{"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25},
		})
		if err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()
	client := NewClient("key", WithBaseURL(server.URL))
	resp, err := client.AgentsComplete(context.Background(), &agents.CompletionRequest{
		AgentID: "agent-123",
		Messages: []chat.Message{
			&chat.UserMessage{Content: chat.TextContent("Hello agent")},
		},
	})
	if err != nil {
		t.Fatal(err)
	}
	if resp.ID != "agent-resp-1" {
		t.Errorf("got id %q", resp.ID)
	}
	if got := resp.Choices[0].Message.Content.String(); got != "Agent response" {
		t.Errorf("got content %q", got)
	}
}
// TestAgentsComplete_WithTools checks that tool definitions are forwarded in
// the request body and tool calls are decoded from the response.
func TestAgentsComplete_WithTools(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var body map[string]any
		// Handler goroutine: report with t.Errorf and bail out instead of
		// panicking on a malformed body.
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("decode request body: %v", err)
			return
		}
		tools, ok := body["tools"].([]any)
		if !ok {
			t.Errorf("expected tools array, got %T", body["tools"])
		} else if len(tools) != 1 {
			t.Errorf("expected 1 tool, got %d", len(tools))
		}
		err := json.NewEncoder(w).Encode(map[string]any{
			"id": "a2", "object": "chat.completion",
			"model": "m", "created": 0,
			"choices": []map[string]any{{
				"index": 0,
				"message": map[string]any{
					"role": "assistant", "content": nil,
					"tool_calls": []map[string]any{{
						"id": "tc1", "type": "function",
						"function": map[string]any{"name": "search", "arguments": `{"q":"test"}`},
					}},
				},
				"finish_reason": "tool_calls",
			}},
			"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
		})
		if err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()
	client := NewClient("key", WithBaseURL(server.URL))
	resp, err := client.AgentsComplete(context.Background(), &agents.CompletionRequest{
		AgentID: "agent-456",
		Messages: []chat.Message{
			&chat.UserMessage{Content: chat.TextContent("Search for test")},
		},
		Tools: []chat.Tool{{
			Type: "function",
			Function: chat.Function{
				Name:       "search",
				Parameters: map[string]any{"type": "object", "properties": map[string]any{"q": map[string]any{"type": "string"}}},
			},
		}},
	})
	if err != nil {
		t.Fatal(err)
	}
	if len(resp.Choices[0].Message.ToolCalls) != 1 {
		t.Fatalf("expected 1 tool call")
	}
	if resp.Choices[0].Message.ToolCalls[0].Function.Name != "search" {
		t.Errorf("got function %q", resp.Choices[0].Message.ToolCalls[0].Function.Name)
	}
}
// TestAgentsCompleteStream_Success checks that AgentsCompleteStream sets
// stream=true and yields one chunk per SSE data event.
func TestAgentsCompleteStream_Success(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var body map[string]any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("decode request body: %v", err)
			return
		}
		if body["stream"] != true {
			t.Errorf("expected stream=true")
		}
		w.Header().Set("Content-Type", "text/event-stream")
		// Checked assertion: the previous `flusher, _ :=` form would nil-deref
		// if the writer ever stopped implementing http.Flusher.
		flusher, ok := w.(http.Flusher)
		if !ok {
			t.Errorf("response writer does not implement http.Flusher")
			return
		}
		for _, word := range []string{"Hello", " from", " agent"} {
			chunk := chat.CompletionChunk{
				ID: "ac", Model: "m",
				Choices: []chat.CompletionStreamChoice{{
					Index: 0,
					Delta: chat.DeltaMessage{Content: chat.TextContent(word)},
				}},
			}
			data, err := json.Marshal(chunk)
			if err != nil {
				t.Errorf("marshal chunk: %v", err)
				return
			}
			fmt.Fprintf(w, "data: %s\n\n", data)
			flusher.Flush()
		}
		fmt.Fprint(w, "data: [DONE]\n\n")
		flusher.Flush()
	}))
	defer server.Close()
	client := NewClient("key", WithBaseURL(server.URL))
	stream, err := client.AgentsCompleteStream(context.Background(), &agents.CompletionRequest{
		AgentID:  "agent-789",
		Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
	})
	if err != nil {
		t.Fatal(err)
	}
	defer stream.Close()
	var count int
	for stream.Next() {
		count++
	}
	if stream.Err() != nil {
		t.Fatal(stream.Err())
	}
	if count != 3 {
		t.Errorf("got %d chunks, want 3", count)
	}
}

48
embedding/embedding.go Normal file
View File

@@ -0,0 +1,48 @@
package embedding
import "somegit.dev/vikingowl/mistral-go-sdk/chat"
// Dtype specifies the data type of output embeddings.
type Dtype string

// Supported output dtypes for embeddings.
const (
	DtypeFloat Dtype = "float"
	DtypeInt8 Dtype = "int8"
	DtypeUint8 Dtype = "uint8"
	DtypeBinary Dtype = "binary"
	DtypeUbinary Dtype = "ubinary"
)

// EncodingFormat specifies the format of embeddings in the response.
type EncodingFormat string

// Supported encoding formats for embeddings.
const (
	EncodingFormatFloat EncodingFormat = "float"
	EncodingFormatBase64 EncodingFormat = "base64"
)

// Request represents an embedding request.
type Request struct {
	// Model is the embedding model ID, e.g. "mistral-embed".
	Model string `json:"model"`
	// Input is the list of texts to embed.
	Input []string `json:"input"`
	// OutputDimension optionally truncates embeddings to this many dimensions.
	OutputDimension *int `json:"output_dimension,omitempty"`
	OutputDtype *Dtype `json:"output_dtype,omitempty"`
	EncodingFormat *EncodingFormat `json:"encoding_format,omitempty"`
	Metadata map[string]any `json:"metadata,omitempty"`
}

// Response represents an embedding response.
type Response struct {
	ID string `json:"id"`
	Object string `json:"object"`
	Model string `json:"model"`
	Usage chat.UsageInfo `json:"usage"`
	// Data holds one entry per input text; Data[i].Index maps back to Input[i].
	Data []Data `json:"data"`
}

// Data represents a single embedding result.
type Data struct {
	Object string `json:"object"`
	// Embedding is decoded as float64s; NOTE(review): base64 encoding_format
	// responses presumably need separate handling — confirm against the API.
	Embedding []float64 `json:"embedding"`
	// Index is the position of the corresponding input in the request.
	Index int `json:"index"`
}

16
embeddings.go Normal file
View File

@@ -0,0 +1,16 @@
package mistral
import (
"context"
"somegit.dev/vikingowl/mistral-go-sdk/embedding"
)
// CreateEmbeddings sends an embedding request and returns the response.
func (c *Client) CreateEmbeddings(ctx context.Context, req *embedding.Request) (*embedding.Response, error) {
	out := new(embedding.Response)
	err := c.doJSON(ctx, "POST", "/v1/embeddings", req, out)
	if err != nil {
		return nil, err
	}
	return out, nil
}

158
embeddings_test.go Normal file
View File

@@ -0,0 +1,158 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/embedding"
)
// TestCreateEmbeddings_Success checks request shape (path, method, body) and
// that the response data, indices and usage are decoded correctly.
func TestCreateEmbeddings_Success(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/embeddings" {
			t.Errorf("expected /v1/embeddings, got %s", r.URL.Path)
		}
		if r.Method != http.MethodPost {
			t.Errorf("expected POST, got %s", r.Method)
		}
		var body map[string]any
		// Handler goroutine: t.Errorf only, and bail out on a bad body so the
		// assertions below don't fire spuriously on nil values.
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("decode request body: %v", err)
			return
		}
		if body["model"] != "mistral-embed" {
			t.Errorf("expected model=mistral-embed, got %v", body["model"])
		}
		inputs, ok := body["input"].([]any)
		if !ok {
			t.Errorf("expected input array, got %T", body["input"])
		} else if len(inputs) != 2 {
			t.Errorf("expected 2 inputs, got %d", len(inputs))
		}
		err := json.NewEncoder(w).Encode(map[string]any{
			"id":     "emb-1",
			"object": "list",
			"model":  "mistral-embed",
			"usage":  map[string]any{"prompt_tokens": 12, "completion_tokens": 0, "total_tokens": 12},
			"data": []map[string]any{
				{"object": "embedding", "embedding": []float64{0.1, 0.2, 0.3}, "index": 0},
				{"object": "embedding", "embedding": []float64{0.4, 0.5, 0.6}, "index": 1},
			},
		})
		if err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()
	client := NewClient("key", WithBaseURL(server.URL))
	resp, err := client.CreateEmbeddings(context.Background(), &embedding.Request{
		Model: "mistral-embed",
		Input: []string{"Hello world", "Goodbye world"},
	})
	if err != nil {
		t.Fatal(err)
	}
	if resp.ID != "emb-1" {
		t.Errorf("got id %q", resp.ID)
	}
	if len(resp.Data) != 2 {
		t.Fatalf("got %d embeddings, want 2", len(resp.Data))
	}
	if resp.Data[0].Index != 0 {
		t.Errorf("got index %d", resp.Data[0].Index)
	}
	if len(resp.Data[0].Embedding) != 3 {
		t.Fatalf("got %d dims, want 3", len(resp.Data[0].Embedding))
	}
	if resp.Data[0].Embedding[0] != 0.1 {
		t.Errorf("got embedding[0]=%f", resp.Data[0].Embedding[0])
	}
	if resp.Data[1].Embedding[2] != 0.6 {
		t.Errorf("got embedding[2]=%f", resp.Data[1].Embedding[2])
	}
	if resp.Usage.PromptTokens != 12 {
		t.Errorf("got prompt_tokens=%d", resp.Usage.PromptTokens)
	}
}
// TestCreateEmbeddings_SingleInput checks the single-element input case.
func TestCreateEmbeddings_SingleInput(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var body map[string]any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("decode request body: %v", err)
			return
		}
		inputs, ok := body["input"].([]any)
		if !ok {
			t.Errorf("expected input array, got %T", body["input"])
		} else if len(inputs) != 1 {
			t.Errorf("expected 1 input, got %d", len(inputs))
		}
		err := json.NewEncoder(w).Encode(map[string]any{
			"id": "emb-2", "object": "list", "model": "mistral-embed",
			"usage": map[string]any{"prompt_tokens": 5, "completion_tokens": 0, "total_tokens": 5},
			"data": []map[string]any{
				{"object": "embedding", "embedding": []float64{0.1, 0.2}, "index": 0},
			},
		})
		if err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()
	client := NewClient("key", WithBaseURL(server.URL))
	resp, err := client.CreateEmbeddings(context.Background(), &embedding.Request{
		Model: "mistral-embed",
		Input: []string{"Just one"},
	})
	if err != nil {
		t.Fatal(err)
	}
	if len(resp.Data) != 1 {
		t.Errorf("got %d, want 1", len(resp.Data))
	}
}
// TestCreateEmbeddings_WithOptions checks that optional output_dimension and
// output_dtype fields are serialized into the request body.
func TestCreateEmbeddings_WithOptions(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var body map[string]any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("decode request body: %v", err)
			return
		}
		// JSON numbers decode to float64 in a map[string]any.
		if body["output_dimension"] != float64(256) {
			t.Errorf("expected output_dimension=256, got %v", body["output_dimension"])
		}
		if body["output_dtype"] != "int8" {
			t.Errorf("expected output_dtype=int8, got %v", body["output_dtype"])
		}
		err := json.NewEncoder(w).Encode(map[string]any{
			"id": "emb-3", "object": "list", "model": "m",
			"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
			"data":  []map[string]any{{"object": "embedding", "embedding": []float64{1, 2}, "index": 0}},
		})
		if err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()
	client := NewClient("key", WithBaseURL(server.URL))
	dim := 256
	dtype := embedding.DtypeInt8
	_, err := client.CreateEmbeddings(context.Background(), &embedding.Request{
		Model:           "m",
		Input:           []string{"test"},
		OutputDimension: &dim,
		OutputDtype:     &dtype,
	})
	if err != nil {
		t.Fatal(err)
	}
}
// TestCreateEmbeddings_APIError checks that a 429 response surfaces as a
// rate-limit error.
func TestCreateEmbeddings_APIError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusTooManyRequests)
		if err := json.NewEncoder(w).Encode(map[string]any{"message": "rate limited"}); err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()
	client := NewClient("key", WithBaseURL(server.URL))
	_, err := client.CreateEmbeddings(context.Background(), &embedding.Request{
		Model: "m",
		Input: []string{"test"},
	})
	if err == nil {
		t.Fatal("expected error")
	}
	if !IsRateLimit(err) {
		t.Errorf("expected rate limit, got: %v", err)
	}
}

32
fim/request.go Normal file
View File

@@ -0,0 +1,32 @@
package fim
import "encoding/json"
// CompletionRequest represents a Fill-In-the-Middle completion request.
type CompletionRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Suffix *string `json:"suffix,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
MinTokens *int `json:"min_tokens,omitempty"`
Stop []string `json:"stop,omitempty"`
RandomSeed *int `json:"random_seed,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
stream bool
}
// SetStream is used internally to set the stream field.
func (r *CompletionRequest) SetStream(v bool) { r.stream = v }
func (r *CompletionRequest) MarshalJSON() ([]byte, error) {
type Alias CompletionRequest
return json.Marshal(&struct {
Stream bool `json:"stream"`
*Alias
}{
Stream: r.stream,
Alias: (*Alias)(r),
})
}

27
fim_complete.go Normal file
View File

@@ -0,0 +1,27 @@
package mistral
import (
"context"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
"somegit.dev/vikingowl/mistral-go-sdk/fim"
)
// FIMComplete sends a Fill-In-the-Middle completion request and returns the
// decoded completion response (shared with chat completions).
func (c *Client) FIMComplete(ctx context.Context, req *fim.CompletionRequest) (*chat.CompletionResponse, error) {
	out := new(chat.CompletionResponse)
	err := c.doJSON(ctx, "POST", "/v1/fim/completions", req, out)
	if err != nil {
		return nil, err
	}
	return out, nil
}
// FIMCompleteStream sends a FIM request and returns a stream of chunks.
//
// NOTE: it mutates req by setting its stream flag to true before sending.
func (c *Client) FIMCompleteStream(ctx context.Context, req *fim.CompletionRequest) (*Stream[chat.CompletionChunk], error) {
	req.SetStream(true)
	httpResp, err := c.doStream(ctx, "POST", "/v1/fim/completions", req)
	if err != nil {
		return nil, err
	}
	// The stream takes ownership of the response body and closes it.
	return newStream[chat.CompletionChunk](httpResp.Body), nil
}

175
fim_complete_test.go Normal file
View File

@@ -0,0 +1,175 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
"somegit.dev/vikingowl/mistral-go-sdk/fim"
)
// TestFIMComplete_Success checks that FIMComplete posts prompt/suffix/model
// with stream=false and decodes the completion response.
func TestFIMComplete_Success(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/fim/completions" {
			t.Errorf("expected /v1/fim/completions, got %s", r.URL.Path)
		}
		var body map[string]any
		// Handler goroutine: t.Errorf only, and return early on a bad body.
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("decode request body: %v", err)
			return
		}
		if body["prompt"] != "def add(a, b):" {
			t.Errorf("expected prompt, got %v", body["prompt"])
		}
		if body["suffix"] != "return result" {
			t.Errorf("expected suffix, got %v", body["suffix"])
		}
		if body["model"] != "codestral-latest" {
			t.Errorf("expected model codestral-latest, got %v", body["model"])
		}
		if body["stream"] != false {
			t.Errorf("expected stream=false, got %v", body["stream"])
		}
		err := json.NewEncoder(w).Encode(map[string]any{
			"id": "fim-1", "object": "chat.completion",
			"model": "codestral-latest", "created": 1234567890,
			"choices": []map[string]any{{
				"index":         0,
				"message":       map[string]any{"role": "assistant", "content": "\n    result = a + b\n    "},
				"finish_reason": "stop",
			}},
			"usage": map[string]any{"prompt_tokens": 15, "completion_tokens": 10, "total_tokens": 25},
		})
		if err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()
	client := NewClient("key", WithBaseURL(server.URL))
	suffix := "return result"
	resp, err := client.FIMComplete(context.Background(), &fim.CompletionRequest{
		Model:  "codestral-latest",
		Prompt: "def add(a, b):",
		Suffix: &suffix,
	})
	if err != nil {
		t.Fatal(err)
	}
	if resp.ID != "fim-1" {
		t.Errorf("got id %q", resp.ID)
	}
	if got := resp.Choices[0].Message.Content.String(); got != "\n    result = a + b\n    " {
		t.Errorf("got content %q", got)
	}
}
// TestFIMComplete_WithParams checks that optional sampling/length parameters
// are serialized into the request body.
func TestFIMComplete_WithParams(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var body map[string]any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("decode request body: %v", err)
			return
		}
		// JSON numbers decode to float64 in a map[string]any.
		if body["temperature"] != 0.2 {
			t.Errorf("expected temperature=0.2, got %v", body["temperature"])
		}
		if body["max_tokens"] != float64(50) {
			t.Errorf("expected max_tokens=50, got %v", body["max_tokens"])
		}
		if body["min_tokens"] != float64(10) {
			t.Errorf("expected min_tokens=10, got %v", body["min_tokens"])
		}
		err := json.NewEncoder(w).Encode(map[string]any{
			"id": "fim-2", "object": "chat.completion",
			"model": "codestral-latest", "created": 0,
			"choices": []map[string]any{{
				"index": 0, "message": map[string]any{"role": "assistant", "content": "code"},
				"finish_reason": "length",
			}},
			"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
		})
		if err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()
	client := NewClient("key", WithBaseURL(server.URL))
	temp := 0.2
	maxTok := 50
	minTok := 10
	_, err := client.FIMComplete(context.Background(), &fim.CompletionRequest{
		Model:       "codestral-latest",
		Prompt:      "fn main() {",
		Temperature: &temp,
		MaxTokens:   &maxTok,
		MinTokens:   &minTok,
	})
	if err != nil {
		t.Fatal(err)
	}
}
// TestFIMCompleteStream_Success checks that FIMCompleteStream sets
// stream=true and yields one chunk per SSE data event.
func TestFIMCompleteStream_Success(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var body map[string]any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			t.Errorf("decode request body: %v", err)
			return
		}
		if body["stream"] != true {
			t.Errorf("expected stream=true")
		}
		w.Header().Set("Content-Type", "text/event-stream")
		// Checked assertion: avoid a nil-deref panic if the writer does not
		// implement http.Flusher.
		flusher, ok := w.(http.Flusher)
		if !ok {
			t.Errorf("response writer does not implement http.Flusher")
			return
		}
		for _, content := range []string{"\n    ", "result = a + b", "\n    "} {
			chunk := chat.CompletionChunk{
				ID: "fc", Model: "codestral-latest",
				Choices: []chat.CompletionStreamChoice{{
					Index: 0,
					Delta: chat.DeltaMessage{Content: chat.TextContent(content)},
				}},
			}
			data, err := json.Marshal(chunk)
			if err != nil {
				t.Errorf("marshal chunk: %v", err)
				return
			}
			fmt.Fprintf(w, "data: %s\n\n", data)
			flusher.Flush()
		}
		fmt.Fprint(w, "data: [DONE]\n\n")
		flusher.Flush()
	}))
	defer server.Close()
	client := NewClient("key", WithBaseURL(server.URL))
	stream, err := client.FIMCompleteStream(context.Background(), &fim.CompletionRequest{
		Model:  "codestral-latest",
		Prompt: "def add(a, b):",
	})
	if err != nil {
		t.Fatal(err)
	}
	defer stream.Close()
	var count int
	for stream.Next() {
		count++
	}
	if stream.Err() != nil {
		t.Fatal(stream.Err())
	}
	if count != 3 {
		t.Errorf("got %d chunks, want 3", count)
	}
}
// TestFIMComplete_APIError checks that a 404 response surfaces as a
// not-found error.
func TestFIMComplete_APIError(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
		if err := json.NewEncoder(w).Encode(map[string]any{"message": "model not found"}); err != nil {
			t.Errorf("encode response: %v", err)
		}
	}))
	defer server.Close()
	client := NewClient("key", WithBaseURL(server.URL))
	_, err := client.FIMComplete(context.Background(), &fim.CompletionRequest{
		Model:  "bad-model",
		Prompt: "code",
	})
	if err == nil {
		t.Fatal("expected error")
	}
	if !IsNotFound(err) {
		t.Errorf("expected not found, got: %v", err)
	}
}