From 4b8ca4be5dbe48da8f98238149a201a3dac91ac7 Mon Sep 17 00:00:00 2001 From: vikingowl Date: Thu, 5 Mar 2026 19:36:49 +0100 Subject: [PATCH] =?UTF-8?q?feat:=20Phase=203=20core=20completions=20?= =?UTF-8?q?=E2=80=94=20FIM,=20Agents,=20Embeddings?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add FIM, Agents, and Embedding endpoints: - fim/request.go: FIMCompletionRequest (prompt/suffix model) - agents/request.go: AgentsCompletionRequest (agent_id + messages) - embedding/embedding.go: Request/Response/Data types with dtype/encoding - FIMComplete, FIMCompleteStream, AgentsComplete, AgentsCompleteStream, CreateEmbeddings service methods - All reuse chat.CompletionResponse/CompletionChunk for responses - 11 new httptest-based tests --- agents/request.go | 45 +++++++++++ agents_complete.go | 27 +++++++ agents_complete_test.go | 164 +++++++++++++++++++++++++++++++++++++ embedding/embedding.go | 48 +++++++++++ embeddings.go | 16 ++++ embeddings_test.go | 158 ++++++++++++++++++++++++++++++++++++ fim/request.go | 32 ++++++++ fim_complete.go | 27 +++++++ fim_complete_test.go | 175 ++++++++++++++++++++++++++++++++++++++++ 9 files changed, 692 insertions(+) create mode 100644 agents/request.go create mode 100644 agents_complete.go create mode 100644 agents_complete_test.go create mode 100644 embedding/embedding.go create mode 100644 embeddings.go create mode 100644 embeddings_test.go create mode 100644 fim/request.go create mode 100644 fim_complete.go create mode 100644 fim_complete_test.go diff --git a/agents/request.go b/agents/request.go new file mode 100644 index 0000000..1fae492 --- /dev/null +++ b/agents/request.go @@ -0,0 +1,45 @@ +package agents + +import ( + "encoding/json" + + "somegit.dev/vikingowl/mistral-go-sdk/chat" +) + +// CompletionRequest represents an agents completion request. +type CompletionRequest struct { + AgentID string `json:"agent_id"` + Messages []chat.Message `json:"-"` + MaxTokens *int `json:"max_tokens,omitempty"` + Stop []string `json:"stop,omitempty"` + RandomSeed *int `json:"random_seed,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + ResponseFormat *chat.ResponseFormat `json:"response_format,omitempty"` + Tools []chat.Tool `json:"tools,omitempty"` + ToolChoice *chat.ToolChoice `json:"tool_choice,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + N *int `json:"n,omitempty"` + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + Prediction *chat.Prediction `json:"prediction,omitempty"` + PromptMode *chat.PromptMode `json:"prompt_mode,omitempty"` + stream bool +} + +// SetStream is used internally to set the stream field. +func (r *CompletionRequest) SetStream(v bool) { r.stream = v } + +func (r *CompletionRequest) MarshalJSON() ([]byte, error) { + type Alias CompletionRequest + return json.Marshal(&struct { + Messages []chat.Message `json:"messages"` + Stream bool `json:"stream"` + *Alias + }{ + Messages: r.Messages, + Stream: r.stream, + Alias: (*Alias)(r), + }) +} diff --git a/agents_complete.go b/agents_complete.go new file mode 100644 index 0000000..37191ca --- /dev/null +++ b/agents_complete.go @@ -0,0 +1,27 @@ +package mistral + +import ( + "context" + + "somegit.dev/vikingowl/mistral-go-sdk/agents" + "somegit.dev/vikingowl/mistral-go-sdk/chat" +) + +// AgentsComplete sends an agents completion request. 
+func (c *Client) AgentsComplete(ctx context.Context, req *agents.CompletionRequest) (*chat.CompletionResponse, error) { + var resp chat.CompletionResponse + if err := c.doJSON(ctx, "POST", "/v1/agents/completions", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// AgentsCompleteStream sends an agents request and returns a stream of chunks. +func (c *Client) AgentsCompleteStream(ctx context.Context, req *agents.CompletionRequest) (*Stream[chat.CompletionChunk], error) { + req.SetStream(true) + resp, err := c.doStream(ctx, "POST", "/v1/agents/completions", req) + if err != nil { + return nil, err + } + return newStream[chat.CompletionChunk](resp.Body), nil +} diff --git a/agents_complete_test.go b/agents_complete_test.go new file mode 100644 index 0000000..cbdaf95 --- /dev/null +++ b/agents_complete_test.go @@ -0,0 +1,164 @@ +package mistral + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "somegit.dev/vikingowl/mistral-go-sdk/agents" + "somegit.dev/vikingowl/mistral-go-sdk/chat" +) + +func TestAgentsComplete_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/agents/completions" { + t.Errorf("expected /v1/agents/completions, got %s", r.URL.Path) + } + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["agent_id"] != "agent-123" { + t.Errorf("expected agent_id=agent-123, got %v", body["agent_id"]) + } + msgs := body["messages"].([]any) + if len(msgs) != 1 { + t.Errorf("expected 1 message, got %d", len(msgs)) + } + if body["stream"] != false { + t.Errorf("expected stream=false") + } + + json.NewEncoder(w).Encode(map[string]any{ + "id": "agent-resp-1", "object": "chat.completion", + "model": "mistral-large-latest", "created": 1234567890, + "choices": []map[string]any{{ + "index": 0, + "message": map[string]any{"role": "assistant", "content": "Agent response"}, + "finish_reason": "stop", + }}, + "usage": map[string]any{"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.AgentsComplete(context.Background(), &agents.CompletionRequest{ + AgentID: "agent-123", + Messages: []chat.Message{ + &chat.UserMessage{Content: chat.TextContent("Hello agent")}, + }, + }) + if err != nil { + t.Fatal(err) + } + if resp.ID != "agent-resp-1" { + t.Errorf("got id %q", resp.ID) + } + if resp.Choices[0].Message.Content.String() != "Agent response" { + t.Errorf("got content %q", resp.Choices[0].Message.Content.String()) + } +} + +func TestAgentsComplete_WithTools(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + tools := body["tools"].([]any) + if len(tools) != 1 { + t.Errorf("expected 1 tool, got %d", len(tools)) + } + + json.NewEncoder(w).Encode(map[string]any{ + "id": "a2", "object": "chat.completion", + "model": "m", "created": 0, + "choices": []map[string]any{{ + "index": 0, + "message": map[string]any{ + "role": "assistant", "content": nil, + "tool_calls": []map[string]any{{ + "id": "tc1", "type": "function", + "function": map[string]any{"name": "search", "arguments": `{"q":"test"}`}, + }}, + }, + "finish_reason": "tool_calls", + }}, + "usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + }) + })) + defer server.Close() + + client := 
NewClient("key", WithBaseURL(server.URL)) + resp, err := client.AgentsComplete(context.Background(), &agents.CompletionRequest{ + AgentID: "agent-456", + Messages: []chat.Message{ + &chat.UserMessage{Content: chat.TextContent("Search for test")}, + }, + Tools: []chat.Tool{{ + Type: "function", + Function: chat.Function{ + Name: "search", + Parameters: map[string]any{"type": "object", "properties": map[string]any{"q": map[string]any{"type": "string"}}}, + }, + }}, + }) + if err != nil { + t.Fatal(err) + } + if len(resp.Choices[0].Message.ToolCalls) != 1 { + t.Fatalf("expected 1 tool call") + } + if resp.Choices[0].Message.ToolCalls[0].Function.Name != "search" { + t.Errorf("got function %q", resp.Choices[0].Message.ToolCalls[0].Function.Name) + } +} + +func TestAgentsCompleteStream_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["stream"] != true { + t.Errorf("expected stream=true") + } + + w.Header().Set("Content-Type", "text/event-stream") + flusher, _ := w.(http.Flusher) + for _, word := range []string{"Hello", " from", " agent"} { + chunk := chat.CompletionChunk{ + ID: "ac", Model: "m", + Choices: []chat.CompletionStreamChoice{{ + Index: 0, + Delta: chat.DeltaMessage{Content: chat.TextContent(word)}, + }}, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + } + fmt.Fprint(w, "data: [DONE]\n\n") + flusher.Flush() + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + stream, err := client.AgentsCompleteStream(context.Background(), &agents.CompletionRequest{ + AgentID: "agent-789", + Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}}, + }) + if err != nil { + t.Fatal(err) + } + defer stream.Close() + + var count int + for stream.Next() { + count++ + } + if stream.Err() != nil { + t.Fatal(stream.Err()) + } + if count != 3 { + t.Errorf("got %d chunks, want 3", count) + } +} diff --git a/embedding/embedding.go b/embedding/embedding.go new file mode 100644 index 0000000..ecd6b54 --- /dev/null +++ b/embedding/embedding.go @@ -0,0 +1,48 @@ +package embedding + +import "somegit.dev/vikingowl/mistral-go-sdk/chat" + +// Dtype specifies the data type of output embeddings. +type Dtype string + +const ( + DtypeFloat Dtype = "float" + DtypeInt8 Dtype = "int8" + DtypeUint8 Dtype = "uint8" + DtypeBinary Dtype = "binary" + DtypeUbinary Dtype = "ubinary" +) + +// EncodingFormat specifies the format of embeddings in the response. +type EncodingFormat string + +const ( + EncodingFormatFloat EncodingFormat = "float" + EncodingFormatBase64 EncodingFormat = "base64" +) + +// Request represents an embedding request. +type Request struct { + Model string `json:"model"` + Input []string `json:"input"` + OutputDimension *int `json:"output_dimension,omitempty"` + OutputDtype *Dtype `json:"output_dtype,omitempty"` + EncodingFormat *EncodingFormat `json:"encoding_format,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// Response represents an embedding response. +type Response struct { + ID string `json:"id"` + Object string `json:"object"` + Model string `json:"model"` + Usage chat.UsageInfo `json:"usage"` + Data []Data `json:"data"` +} + +// Data represents a single embedding result. 
+type Data struct { + Object string `json:"object"` + Embedding []float64 `json:"embedding"` + Index int `json:"index"` +} diff --git a/embeddings.go b/embeddings.go new file mode 100644 index 0000000..c7f61d8 --- /dev/null +++ b/embeddings.go @@ -0,0 +1,16 @@ +package mistral + +import ( + "context" + + "somegit.dev/vikingowl/mistral-go-sdk/embedding" +) + +// CreateEmbeddings sends an embedding request and returns the response. +func (c *Client) CreateEmbeddings(ctx context.Context, req *embedding.Request) (*embedding.Response, error) { + var resp embedding.Response + if err := c.doJSON(ctx, "POST", "/v1/embeddings", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} diff --git a/embeddings_test.go b/embeddings_test.go new file mode 100644 index 0000000..43c5b87 --- /dev/null +++ b/embeddings_test.go @@ -0,0 +1,158 @@ +package mistral + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "somegit.dev/vikingowl/mistral-go-sdk/embedding" +) + +func TestCreateEmbeddings_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/embeddings" { + t.Errorf("expected /v1/embeddings, got %s", r.URL.Path) + } + if r.Method != "POST" { + t.Errorf("expected POST, got %s", r.Method) + } + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["model"] != "mistral-embed" { + t.Errorf("expected model=mistral-embed, got %v", body["model"]) + } + inputs := body["input"].([]any) + if len(inputs) != 2 { + t.Errorf("expected 2 inputs, got %d", len(inputs)) + } + + json.NewEncoder(w).Encode(map[string]any{ + "id": "emb-1", + "object": "list", + "model": "mistral-embed", + "usage": map[string]any{"prompt_tokens": 12, "completion_tokens": 0, "total_tokens": 12}, + "data": []map[string]any{ + {"object": "embedding", "embedding": []float64{0.1, 0.2, 0.3}, "index": 0}, + {"object": "embedding", "embedding": []float64{0.4, 0.5, 0.6}, "index": 1}, + }, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.CreateEmbeddings(context.Background(), &embedding.Request{ + Model: "mistral-embed", + Input: []string{"Hello world", "Goodbye world"}, + }) + if err != nil { + t.Fatal(err) + } + if resp.ID != "emb-1" { + t.Errorf("got id %q", resp.ID) + } + if len(resp.Data) != 2 { + t.Fatalf("got %d embeddings, want 2", len(resp.Data)) + } + if resp.Data[0].Index != 0 { + t.Errorf("got index %d", resp.Data[0].Index) + } + if len(resp.Data[0].Embedding) != 3 { + t.Fatalf("got %d dims, want 3", len(resp.Data[0].Embedding)) + } + if resp.Data[0].Embedding[0] != 0.1 { + t.Errorf("got embedding[0]=%f", resp.Data[0].Embedding[0]) + } + if resp.Data[1].Embedding[2] != 0.6 { + t.Errorf("got embedding[2]=%f", resp.Data[1].Embedding[2]) + } + if resp.Usage.PromptTokens != 12 { + t.Errorf("got prompt_tokens=%d", resp.Usage.PromptTokens) + } +} + +func TestCreateEmbeddings_SingleInput(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + inputs := body["input"].([]any) + if len(inputs) != 1 { + t.Errorf("expected 1 input, got %d", len(inputs)) + } + + json.NewEncoder(w).Encode(map[string]any{ + "id": "emb-2", "object": "list", "model": "mistral-embed", + "usage": map[string]any{"prompt_tokens": 5, "completion_tokens": 0, "total_tokens": 5}, + "data": []map[string]any{ + {"object": "embedding", 
"embedding": []float64{0.1, 0.2}, "index": 0}, + }, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.CreateEmbeddings(context.Background(), &embedding.Request{ + Model: "mistral-embed", + Input: []string{"Just one"}, + }) + if err != nil { + t.Fatal(err) + } + if len(resp.Data) != 1 { + t.Errorf("got %d, want 1", len(resp.Data)) + } +} + +func TestCreateEmbeddings_WithOptions(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["output_dimension"] != float64(256) { + t.Errorf("expected output_dimension=256, got %v", body["output_dimension"]) + } + if body["output_dtype"] != "int8" { + t.Errorf("expected output_dtype=int8, got %v", body["output_dtype"]) + } + + json.NewEncoder(w).Encode(map[string]any{ + "id": "emb-3", "object": "list", "model": "m", + "usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + "data": []map[string]any{{"object": "embedding", "embedding": []float64{1, 2}, "index": 0}}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + dim := 256 + dtype := embedding.DtypeInt8 + _, err := client.CreateEmbeddings(context.Background(), &embedding.Request{ + Model: "m", + Input: []string{"test"}, + OutputDimension: &dim, + OutputDtype: &dtype, + }) + if err != nil { + t.Fatal(err) + } +} + +func TestCreateEmbeddings_APIError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTooManyRequests) + json.NewEncoder(w).Encode(map[string]any{"message": "rate limited"}) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + _, err := client.CreateEmbeddings(context.Background(), &embedding.Request{ + Model: "m", + Input: []string{"test"}, + }) + if err == nil { + t.Fatal("expected error") + } + if !IsRateLimit(err) { + t.Errorf("expected rate limit, got: %v", err) + } +} diff --git a/fim/request.go b/fim/request.go new file mode 100644 index 0000000..3040d8c --- /dev/null +++ b/fim/request.go @@ -0,0 +1,32 @@ +package fim + +import "encoding/json" + +// CompletionRequest represents a Fill-In-the-Middle completion request. +type CompletionRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Suffix *string `json:"suffix,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + MinTokens *int `json:"min_tokens,omitempty"` + Stop []string `json:"stop,omitempty"` + RandomSeed *int `json:"random_seed,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` + stream bool +} + +// SetStream is used internally to set the stream field. +func (r *CompletionRequest) SetStream(v bool) { r.stream = v } + +func (r *CompletionRequest) MarshalJSON() ([]byte, error) { + type Alias CompletionRequest + return json.Marshal(&struct { + Stream bool `json:"stream"` + *Alias + }{ + Stream: r.stream, + Alias: (*Alias)(r), + }) +} diff --git a/fim_complete.go b/fim_complete.go new file mode 100644 index 0000000..7dd9a3c --- /dev/null +++ b/fim_complete.go @@ -0,0 +1,27 @@ +package mistral + +import ( + "context" + + "somegit.dev/vikingowl/mistral-go-sdk/chat" + "somegit.dev/vikingowl/mistral-go-sdk/fim" +) + +// FIMComplete sends a Fill-In-the-Middle completion request. 
+func (c *Client) FIMComplete(ctx context.Context, req *fim.CompletionRequest) (*chat.CompletionResponse, error) { + var resp chat.CompletionResponse + if err := c.doJSON(ctx, "POST", "/v1/fim/completions", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// FIMCompleteStream sends a FIM request and returns a stream of chunks. +func (c *Client) FIMCompleteStream(ctx context.Context, req *fim.CompletionRequest) (*Stream[chat.CompletionChunk], error) { + req.SetStream(true) + resp, err := c.doStream(ctx, "POST", "/v1/fim/completions", req) + if err != nil { + return nil, err + } + return newStream[chat.CompletionChunk](resp.Body), nil +} diff --git a/fim_complete_test.go b/fim_complete_test.go new file mode 100644 index 0000000..6945eec --- /dev/null +++ b/fim_complete_test.go @@ -0,0 +1,175 @@ +package mistral + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "somegit.dev/vikingowl/mistral-go-sdk/chat" + "somegit.dev/vikingowl/mistral-go-sdk/fim" +) + +func TestFIMComplete_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/fim/completions" { + t.Errorf("expected /v1/fim/completions, got %s", r.URL.Path) + } + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["prompt"] != "def add(a, b):" { + t.Errorf("expected prompt, got %v", body["prompt"]) + } + if body["suffix"] != "return result" { + t.Errorf("expected suffix, got %v", body["suffix"]) + } + if body["model"] != "codestral-latest" { + t.Errorf("expected model codestral-latest, got %v", body["model"]) + } + if body["stream"] != false { + t.Errorf("expected stream=false, got %v", body["stream"]) + } + + json.NewEncoder(w).Encode(map[string]any{ + "id": "fim-1", "object": "chat.completion", + "model": "codestral-latest", "created": 1234567890, + "choices": []map[string]any{{ + "index": 0, + "message": map[string]any{"role": "assistant", "content": "\n result = a + b\n "}, + "finish_reason": "stop", + }}, + "usage": map[string]any{"prompt_tokens": 15, "completion_tokens": 10, "total_tokens": 25}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + suffix := "return result" + resp, err := client.FIMComplete(context.Background(), &fim.CompletionRequest{ + Model: "codestral-latest", + Prompt: "def add(a, b):", + Suffix: &suffix, + }) + if err != nil { + t.Fatal(err) + } + if resp.ID != "fim-1" { + t.Errorf("got id %q", resp.ID) + } + if resp.Choices[0].Message.Content.String() != "\n result = a + b\n " { + t.Errorf("got content %q", resp.Choices[0].Message.Content.String()) + } +} + +func TestFIMComplete_WithParams(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["temperature"] != 0.2 { + t.Errorf("expected temperature=0.2, got %v", body["temperature"]) + } + if body["max_tokens"] != float64(50) { + t.Errorf("expected max_tokens=50, got %v", body["max_tokens"]) + } + if body["min_tokens"] != float64(10) { + t.Errorf("expected min_tokens=10, got %v", body["min_tokens"]) + } + + json.NewEncoder(w).Encode(map[string]any{ + "id": "fim-2", "object": "chat.completion", + "model": "codestral-latest", "created": 0, + "choices": []map[string]any{{ + "index": 0, "message": map[string]any{"role": "assistant", "content": "code"}, + "finish_reason": "length", + }}, + "usage": 
map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + temp := 0.2 + maxTok := 50 + minTok := 10 + _, err := client.FIMComplete(context.Background(), &fim.CompletionRequest{ + Model: "codestral-latest", + Prompt: "fn main() {", + Temperature: &temp, + MaxTokens: &maxTok, + MinTokens: &minTok, + }) + if err != nil { + t.Fatal(err) + } +} + +func TestFIMCompleteStream_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["stream"] != true { + t.Errorf("expected stream=true") + } + + w.Header().Set("Content-Type", "text/event-stream") + flusher, _ := w.(http.Flusher) + for _, content := range []string{"\n ", "result = a + b", "\n "} { + chunk := chat.CompletionChunk{ + ID: "fc", Model: "codestral-latest", + Choices: []chat.CompletionStreamChoice{{ + Index: 0, + Delta: chat.DeltaMessage{Content: chat.TextContent(content)}, + }}, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + } + fmt.Fprint(w, "data: [DONE]\n\n") + flusher.Flush() + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + stream, err := client.FIMCompleteStream(context.Background(), &fim.CompletionRequest{ + Model: "codestral-latest", + Prompt: "def add(a, b):", + }) + if err != nil { + t.Fatal(err) + } + defer stream.Close() + + var count int + for stream.Next() { + count++ + } + if stream.Err() != nil { + t.Fatal(stream.Err()) + } + if count != 3 { + t.Errorf("got %d chunks, want 3", count) + } +} + +func TestFIMComplete_APIError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + json.NewEncoder(w).Encode(map[string]any{"message": "model not found"}) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + _, err := client.FIMComplete(context.Background(), &fim.CompletionRequest{ + Model: "bad-model", + Prompt: "code", + }) + if err == nil { + t.Fatal("expected error") + } + if !IsNotFound(err) { + t.Errorf("expected not found, got: %v", err) + } +}