Files
mistral-go-sdk/embeddings_test.go
vikingowl 07505b333f feat: Phase 3 core completions — FIM, Agents, Embeddings
Add FIM, Agents, and Embedding endpoints:
- fim/request.go: FIMCompletionRequest (prompt/suffix model)
- agents/request.go: AgentsCompletionRequest (agent_id + messages)
- embedding/embedding.go: Request/Response/Data types with dtype/encoding
- FIMComplete, FIMCompleteStream, AgentsComplete, AgentsCompleteStream,
  CreateEmbeddings service methods
- All reuse chat.CompletionResponse/CompletionChunk for responses
- 11 new httptest-based tests
2026-03-05 19:36:49 +01:00

159 lines
4.6 KiB
Go

package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/embedding"
)
// TestCreateEmbeddings_Success verifies that CreateEmbeddings sends a POST to
// /v1/embeddings carrying the model and both inputs, and that the id, the
// embedding vectors and the usage from the stubbed response are decoded into
// the returned value.
func TestCreateEmbeddings_Success(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/embeddings" {
			t.Errorf("expected /v1/embeddings, got %s", r.URL.Path)
		}
		if r.Method != http.MethodPost {
			t.Errorf("expected POST, got %s", r.Method)
		}
		var body map[string]any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			// Report the root cause instead of continuing with a nil map.
			t.Errorf("decoding request body: %v", err)
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		if body["model"] != "mistral-embed" {
			t.Errorf("expected model=mistral-embed, got %v", body["model"])
		}
		// Use the comma-ok form: an unchecked assertion would panic inside the
		// server goroutine and surface only as an opaque connection error.
		inputs, ok := body["input"].([]any)
		if !ok {
			t.Errorf("expected input to be an array, got %T", body["input"])
		} else if len(inputs) != 2 {
			t.Errorf("expected 2 inputs, got %d", len(inputs))
		}
		err := json.NewEncoder(w).Encode(map[string]any{
			"id":     "emb-1",
			"object": "list",
			"model":  "mistral-embed",
			"usage":  map[string]any{"prompt_tokens": 12, "completion_tokens": 0, "total_tokens": 12},
			"data": []map[string]any{
				{"object": "embedding", "embedding": []float64{0.1, 0.2, 0.3}, "index": 0},
				{"object": "embedding", "embedding": []float64{0.4, 0.5, 0.6}, "index": 1},
			},
		})
		if err != nil {
			t.Errorf("encoding response: %v", err)
		}
	}))
	defer server.Close()

	client := NewClient("key", WithBaseURL(server.URL))
	resp, err := client.CreateEmbeddings(context.Background(), &embedding.Request{
		Model: "mistral-embed",
		Input: []string{"Hello world", "Goodbye world"},
	})
	if err != nil {
		t.Fatal(err)
	}
	if resp.ID != "emb-1" {
		t.Errorf("got id %q", resp.ID)
	}
	// Fatal on length mismatches so the index expressions below cannot panic.
	if len(resp.Data) != 2 {
		t.Fatalf("got %d embeddings, want 2", len(resp.Data))
	}
	if resp.Data[0].Index != 0 {
		t.Errorf("got index %d", resp.Data[0].Index)
	}
	if len(resp.Data[0].Embedding) != 3 {
		t.Fatalf("got %d dims, want 3", len(resp.Data[0].Embedding))
	}
	if resp.Data[0].Embedding[0] != 0.1 {
		t.Errorf("got embedding[0]=%f", resp.Data[0].Embedding[0])
	}
	if resp.Data[1].Embedding[2] != 0.6 {
		t.Errorf("got embedding[2]=%f", resp.Data[1].Embedding[2])
	}
	if resp.Usage.PromptTokens != 12 {
		t.Errorf("got prompt_tokens=%d", resp.Usage.PromptTokens)
	}
}
// TestCreateEmbeddings_SingleInput verifies that a request with one input
// string is sent as a one-element "input" array and that the single embedding
// in the stubbed response is returned.
func TestCreateEmbeddings_SingleInput(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var body map[string]any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			// Report the root cause instead of continuing with a nil map.
			t.Errorf("decoding request body: %v", err)
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		// Use the comma-ok form: an unchecked assertion would panic inside the
		// server goroutine and surface only as an opaque connection error.
		inputs, ok := body["input"].([]any)
		if !ok {
			t.Errorf("expected input to be an array, got %T", body["input"])
		} else if len(inputs) != 1 {
			t.Errorf("expected 1 input, got %d", len(inputs))
		}
		err := json.NewEncoder(w).Encode(map[string]any{
			"id": "emb-2", "object": "list", "model": "mistral-embed",
			"usage": map[string]any{"prompt_tokens": 5, "completion_tokens": 0, "total_tokens": 5},
			"data": []map[string]any{
				{"object": "embedding", "embedding": []float64{0.1, 0.2}, "index": 0},
			},
		})
		if err != nil {
			t.Errorf("encoding response: %v", err)
		}
	}))
	defer server.Close()

	client := NewClient("key", WithBaseURL(server.URL))
	resp, err := client.CreateEmbeddings(context.Background(), &embedding.Request{
		Model: "mistral-embed",
		Input: []string{"Just one"},
	})
	if err != nil {
		t.Fatal(err)
	}
	if len(resp.Data) != 1 {
		t.Errorf("got %d, want 1", len(resp.Data))
	}
}
// TestCreateEmbeddings_WithOptions verifies that the optional OutputDimension
// and OutputDtype request fields are serialized as "output_dimension" and
// "output_dtype" in the JSON body.
func TestCreateEmbeddings_WithOptions(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var body map[string]any
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			// Without this check a malformed body would only show up as
			// misleading "got <nil>" failures from the field checks below.
			t.Errorf("decoding request body: %v", err)
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		// encoding/json decodes JSON numbers into float64 when the target is
		// an interface value, hence the float64 comparison.
		if body["output_dimension"] != float64(256) {
			t.Errorf("expected output_dimension=256, got %v", body["output_dimension"])
		}
		if body["output_dtype"] != "int8" {
			t.Errorf("expected output_dtype=int8, got %v", body["output_dtype"])
		}
		err := json.NewEncoder(w).Encode(map[string]any{
			"id": "emb-3", "object": "list", "model": "m",
			"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
			"data":  []map[string]any{{"object": "embedding", "embedding": []float64{1, 2}, "index": 0}},
		})
		if err != nil {
			t.Errorf("encoding response: %v", err)
		}
	}))
	defer server.Close()

	client := NewClient("key", WithBaseURL(server.URL))
	// The option fields are set through pointers, so take addresses of locals.
	dim := 256
	dtype := embedding.DtypeInt8
	_, err := client.CreateEmbeddings(context.Background(), &embedding.Request{
		Model:           "m",
		Input:           []string{"test"},
		OutputDimension: &dim,
		OutputDtype:     &dtype,
	})
	if err != nil {
		t.Fatal(err)
	}
}
// TestCreateEmbeddings_APIError checks that a 429 Too Many Requests response
// makes CreateEmbeddings return an error that IsRateLimit reports as a rate
// limit.
func TestCreateEmbeddings_APIError(t *testing.T) {
	// Stub handler: always answers 429 with a JSON message body.
	rateLimited := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusTooManyRequests)
		payload := map[string]any{"message": "rate limited"}
		json.NewEncoder(w).Encode(payload)
	}
	srv := httptest.NewServer(http.HandlerFunc(rateLimited))
	defer srv.Close()

	c := NewClient("key", WithBaseURL(srv.URL))
	req := &embedding.Request{
		Model: "m",
		Input: []string{"test"},
	}
	_, callErr := c.CreateEmbeddings(context.Background(), req)

	switch {
	case callErr == nil:
		t.Fatal("expected error")
	case !IsRateLimit(callErr):
		t.Errorf("expected rate limit, got: %v", callErr)
	}
}