feat: Phase 3 core completions — FIM, Agents, Embeddings
Add FIM, Agents, and Embedding endpoints:
- fim/request.go: FIMCompletionRequest (prompt/suffix model)
- agents/request.go: AgentsCompletionRequest (agent_id + messages)
- embedding/embedding.go: Request/Response/Data types with dtype/encoding
- FIMComplete, FIMCompleteStream, AgentsComplete, AgentsCompleteStream, CreateEmbeddings service methods
- All reuse chat.CompletionResponse/CompletionChunk for responses
- 11 new httptest-based tests
This commit is contained in:
45
agents/request.go
Normal file
45
agents/request.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package agents
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
// CompletionRequest represents an agents completion request.
//
// Messages is tagged `json:"-"` and the stream field is unexported, so
// neither is emitted by the default encoder; MarshalJSON injects both into
// the payload as the "messages" and "stream" keys.
type CompletionRequest struct {
	// AgentID identifies the pre-configured agent handling the completion.
	AgentID string `json:"agent_id"`
	// Messages is the conversation history; serialized by MarshalJSON, not
	// by the struct tags (see `json:"-"`).
	Messages []chat.Message `json:"-"`
	MaxTokens *int `json:"max_tokens,omitempty"`
	Stop []string `json:"stop,omitempty"`
	RandomSeed *int `json:"random_seed,omitempty"`
	Temperature *float64 `json:"temperature,omitempty"`
	TopP *float64 `json:"top_p,omitempty"`
	ResponseFormat *chat.ResponseFormat `json:"response_format,omitempty"`
	Tools []chat.Tool `json:"tools,omitempty"`
	ToolChoice *chat.ToolChoice `json:"tool_choice,omitempty"`
	PresencePenalty *float64 `json:"presence_penalty,omitempty"`
	FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
	N *int `json:"n,omitempty"`
	ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"`
	Metadata map[string]any `json:"metadata,omitempty"`
	Prediction *chat.Prediction `json:"prediction,omitempty"`
	PromptMode *chat.PromptMode `json:"prompt_mode,omitempty"`
	// stream is set via SetStream by the streaming service method and
	// emitted as the "stream" key by MarshalJSON.
	stream bool
}
|
||||
|
||||
// SetStream is used internally to set the stream field.
|
||||
func (r *CompletionRequest) SetStream(v bool) { r.stream = v }
|
||||
|
||||
func (r *CompletionRequest) MarshalJSON() ([]byte, error) {
|
||||
type Alias CompletionRequest
|
||||
return json.Marshal(&struct {
|
||||
Messages []chat.Message `json:"messages"`
|
||||
Stream bool `json:"stream"`
|
||||
*Alias
|
||||
}{
|
||||
Messages: r.Messages,
|
||||
Stream: r.stream,
|
||||
Alias: (*Alias)(r),
|
||||
})
|
||||
}
|
||||
27
agents_complete.go
Normal file
27
agents_complete.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/agents"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
// AgentsComplete sends an agents completion request.
|
||||
func (c *Client) AgentsComplete(ctx context.Context, req *agents.CompletionRequest) (*chat.CompletionResponse, error) {
|
||||
var resp chat.CompletionResponse
|
||||
if err := c.doJSON(ctx, "POST", "/v1/agents/completions", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// AgentsCompleteStream sends an agents request and returns a stream of chunks.
|
||||
func (c *Client) AgentsCompleteStream(ctx context.Context, req *agents.CompletionRequest) (*Stream[chat.CompletionChunk], error) {
|
||||
req.SetStream(true)
|
||||
resp, err := c.doStream(ctx, "POST", "/v1/agents/completions", req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newStream[chat.CompletionChunk](resp.Body), nil
|
||||
}
|
||||
164
agents_complete_test.go
Normal file
164
agents_complete_test.go
Normal file
@@ -0,0 +1,164 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/agents"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
func TestAgentsComplete_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/agents/completions" {
|
||||
t.Errorf("expected /v1/agents/completions, got %s", r.URL.Path)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["agent_id"] != "agent-123" {
|
||||
t.Errorf("expected agent_id=agent-123, got %v", body["agent_id"])
|
||||
}
|
||||
msgs := body["messages"].([]any)
|
||||
if len(msgs) != 1 {
|
||||
t.Errorf("expected 1 message, got %d", len(msgs))
|
||||
}
|
||||
if body["stream"] != false {
|
||||
t.Errorf("expected stream=false")
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "agent-resp-1", "object": "chat.completion",
|
||||
"model": "mistral-large-latest", "created": 1234567890,
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"message": map[string]any{"role": "assistant", "content": "Agent response"},
|
||||
"finish_reason": "stop",
|
||||
}},
|
||||
"usage": map[string]any{"prompt_tokens": 20, "completion_tokens": 5, "total_tokens": 25},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.AgentsComplete(context.Background(), &agents.CompletionRequest{
|
||||
AgentID: "agent-123",
|
||||
Messages: []chat.Message{
|
||||
&chat.UserMessage{Content: chat.TextContent("Hello agent")},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "agent-resp-1" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
if resp.Choices[0].Message.Content.String() != "Agent response" {
|
||||
t.Errorf("got content %q", resp.Choices[0].Message.Content.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentsComplete_WithTools(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
tools := body["tools"].([]any)
|
||||
if len(tools) != 1 {
|
||||
t.Errorf("expected 1 tool, got %d", len(tools))
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "a2", "object": "chat.completion",
|
||||
"model": "m", "created": 0,
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"message": map[string]any{
|
||||
"role": "assistant", "content": nil,
|
||||
"tool_calls": []map[string]any{{
|
||||
"id": "tc1", "type": "function",
|
||||
"function": map[string]any{"name": "search", "arguments": `{"q":"test"}`},
|
||||
}},
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
}},
|
||||
"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.AgentsComplete(context.Background(), &agents.CompletionRequest{
|
||||
AgentID: "agent-456",
|
||||
Messages: []chat.Message{
|
||||
&chat.UserMessage{Content: chat.TextContent("Search for test")},
|
||||
},
|
||||
Tools: []chat.Tool{{
|
||||
Type: "function",
|
||||
Function: chat.Function{
|
||||
Name: "search",
|
||||
Parameters: map[string]any{"type": "object", "properties": map[string]any{"q": map[string]any{"type": "string"}}},
|
||||
},
|
||||
}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(resp.Choices[0].Message.ToolCalls) != 1 {
|
||||
t.Fatalf("expected 1 tool call")
|
||||
}
|
||||
if resp.Choices[0].Message.ToolCalls[0].Function.Name != "search" {
|
||||
t.Errorf("got function %q", resp.Choices[0].Message.ToolCalls[0].Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentsCompleteStream_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["stream"] != true {
|
||||
t.Errorf("expected stream=true")
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, _ := w.(http.Flusher)
|
||||
for _, word := range []string{"Hello", " from", " agent"} {
|
||||
chunk := chat.CompletionChunk{
|
||||
ID: "ac", Model: "m",
|
||||
Choices: []chat.CompletionStreamChoice{{
|
||||
Index: 0,
|
||||
Delta: chat.DeltaMessage{Content: chat.TextContent(word)},
|
||||
}},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
fmt.Fprint(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
stream, err := client.AgentsCompleteStream(context.Background(), &agents.CompletionRequest{
|
||||
AgentID: "agent-789",
|
||||
Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
var count int
|
||||
for stream.Next() {
|
||||
count++
|
||||
}
|
||||
if stream.Err() != nil {
|
||||
t.Fatal(stream.Err())
|
||||
}
|
||||
if count != 3 {
|
||||
t.Errorf("got %d chunks, want 3", count)
|
||||
}
|
||||
}
|
||||
48
embedding/embedding.go
Normal file
48
embedding/embedding.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package embedding
|
||||
|
||||
import "somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
|
||||
// Dtype specifies the data type of output embeddings.
type Dtype string

// Output dtype values accepted by the embeddings API.
const (
	DtypeFloat   Dtype = "float"
	DtypeInt8    Dtype = "int8"
	DtypeUint8   Dtype = "uint8"
	DtypeBinary  Dtype = "binary"
	DtypeUbinary Dtype = "ubinary"
)

// EncodingFormat specifies the format of embeddings in the response.
type EncodingFormat string

// Supported encoding formats.
const (
	EncodingFormatFloat  EncodingFormat = "float"
	EncodingFormatBase64 EncodingFormat = "base64"
)

// Request represents an embedding request.
type Request struct {
	// Model is the embedding model ID (e.g. "mistral-embed").
	Model string `json:"model"`
	// Input holds the texts to embed, one embedding per element.
	Input []string `json:"input"`
	// OutputDimension optionally requests a specific embedding dimension.
	OutputDimension *int `json:"output_dimension,omitempty"`
	// OutputDtype optionally selects one of the Dtype constants.
	OutputDtype *Dtype `json:"output_dtype,omitempty"`
	// EncodingFormat optionally selects one of the EncodingFormat constants.
	EncodingFormat *EncodingFormat `json:"encoding_format,omitempty"`
	Metadata       map[string]any  `json:"metadata,omitempty"`
}

// Response represents an embedding response.
type Response struct {
	ID     string `json:"id"`
	Object string `json:"object"`
	Model  string `json:"model"`
	// Usage reuses the chat package's token-accounting type.
	Usage chat.UsageInfo `json:"usage"`
	// Data holds one entry per input; each entry carries the Index of the
	// input it corresponds to.
	Data []Data `json:"data"`
}

// Data represents a single embedding result.
type Data struct {
	Object string `json:"object"`
	// Embedding is the vector for the input at position Index.
	Embedding []float64 `json:"embedding"`
	// Index is the position of the corresponding input in Request.Input.
	Index int `json:"index"`
}
|
||||
16
embeddings.go
Normal file
16
embeddings.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/embedding"
|
||||
)
|
||||
|
||||
// CreateEmbeddings sends an embedding request and returns the response.
|
||||
func (c *Client) CreateEmbeddings(ctx context.Context, req *embedding.Request) (*embedding.Response, error) {
|
||||
var resp embedding.Response
|
||||
if err := c.doJSON(ctx, "POST", "/v1/embeddings", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
158
embeddings_test.go
Normal file
158
embeddings_test.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/embedding"
|
||||
)
|
||||
|
||||
func TestCreateEmbeddings_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/embeddings" {
|
||||
t.Errorf("expected /v1/embeddings, got %s", r.URL.Path)
|
||||
}
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["model"] != "mistral-embed" {
|
||||
t.Errorf("expected model=mistral-embed, got %v", body["model"])
|
||||
}
|
||||
inputs := body["input"].([]any)
|
||||
if len(inputs) != 2 {
|
||||
t.Errorf("expected 2 inputs, got %d", len(inputs))
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "emb-1",
|
||||
"object": "list",
|
||||
"model": "mistral-embed",
|
||||
"usage": map[string]any{"prompt_tokens": 12, "completion_tokens": 0, "total_tokens": 12},
|
||||
"data": []map[string]any{
|
||||
{"object": "embedding", "embedding": []float64{0.1, 0.2, 0.3}, "index": 0},
|
||||
{"object": "embedding", "embedding": []float64{0.4, 0.5, 0.6}, "index": 1},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.CreateEmbeddings(context.Background(), &embedding.Request{
|
||||
Model: "mistral-embed",
|
||||
Input: []string{"Hello world", "Goodbye world"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "emb-1" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
if len(resp.Data) != 2 {
|
||||
t.Fatalf("got %d embeddings, want 2", len(resp.Data))
|
||||
}
|
||||
if resp.Data[0].Index != 0 {
|
||||
t.Errorf("got index %d", resp.Data[0].Index)
|
||||
}
|
||||
if len(resp.Data[0].Embedding) != 3 {
|
||||
t.Fatalf("got %d dims, want 3", len(resp.Data[0].Embedding))
|
||||
}
|
||||
if resp.Data[0].Embedding[0] != 0.1 {
|
||||
t.Errorf("got embedding[0]=%f", resp.Data[0].Embedding[0])
|
||||
}
|
||||
if resp.Data[1].Embedding[2] != 0.6 {
|
||||
t.Errorf("got embedding[2]=%f", resp.Data[1].Embedding[2])
|
||||
}
|
||||
if resp.Usage.PromptTokens != 12 {
|
||||
t.Errorf("got prompt_tokens=%d", resp.Usage.PromptTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateEmbeddings_SingleInput(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
inputs := body["input"].([]any)
|
||||
if len(inputs) != 1 {
|
||||
t.Errorf("expected 1 input, got %d", len(inputs))
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "emb-2", "object": "list", "model": "mistral-embed",
|
||||
"usage": map[string]any{"prompt_tokens": 5, "completion_tokens": 0, "total_tokens": 5},
|
||||
"data": []map[string]any{
|
||||
{"object": "embedding", "embedding": []float64{0.1, 0.2}, "index": 0},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.CreateEmbeddings(context.Background(), &embedding.Request{
|
||||
Model: "mistral-embed",
|
||||
Input: []string{"Just one"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(resp.Data) != 1 {
|
||||
t.Errorf("got %d, want 1", len(resp.Data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateEmbeddings_WithOptions(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["output_dimension"] != float64(256) {
|
||||
t.Errorf("expected output_dimension=256, got %v", body["output_dimension"])
|
||||
}
|
||||
if body["output_dtype"] != "int8" {
|
||||
t.Errorf("expected output_dtype=int8, got %v", body["output_dtype"])
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "emb-3", "object": "list", "model": "m",
|
||||
"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
"data": []map[string]any{{"object": "embedding", "embedding": []float64{1, 2}, "index": 0}},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
dim := 256
|
||||
dtype := embedding.DtypeInt8
|
||||
_, err := client.CreateEmbeddings(context.Background(), &embedding.Request{
|
||||
Model: "m",
|
||||
Input: []string{"test"},
|
||||
OutputDimension: &dim,
|
||||
OutputDtype: &dtype,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateEmbeddings_APIError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
json.NewEncoder(w).Encode(map[string]any{"message": "rate limited"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
_, err := client.CreateEmbeddings(context.Background(), &embedding.Request{
|
||||
Model: "m",
|
||||
Input: []string{"test"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !IsRateLimit(err) {
|
||||
t.Errorf("expected rate limit, got: %v", err)
|
||||
}
|
||||
}
|
||||
32
fim/request.go
Normal file
32
fim/request.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package fim
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// CompletionRequest represents a Fill-In-the-Middle completion request.
//
// The unexported stream field is not emitted by the struct tags; MarshalJSON
// injects it into the payload as the "stream" key.
type CompletionRequest struct {
	// Model is the ID of the model to use (e.g. "codestral-latest").
	Model string `json:"model"`
	// Prompt is the text preceding the span to be filled in.
	Prompt string `json:"prompt"`
	// Suffix, when set, is the text following the span to be filled in.
	Suffix *string `json:"suffix,omitempty"`
	Temperature *float64 `json:"temperature,omitempty"`
	TopP *float64 `json:"top_p,omitempty"`
	MaxTokens *int `json:"max_tokens,omitempty"`
	MinTokens *int `json:"min_tokens,omitempty"`
	Stop []string `json:"stop,omitempty"`
	RandomSeed *int `json:"random_seed,omitempty"`
	Metadata map[string]any `json:"metadata,omitempty"`
	// stream is set via SetStream by the streaming service method and
	// emitted as the "stream" key by MarshalJSON.
	stream bool
}
|
||||
|
||||
// SetStream is used internally to set the stream field.
|
||||
func (r *CompletionRequest) SetStream(v bool) { r.stream = v }
|
||||
|
||||
func (r *CompletionRequest) MarshalJSON() ([]byte, error) {
|
||||
type Alias CompletionRequest
|
||||
return json.Marshal(&struct {
|
||||
Stream bool `json:"stream"`
|
||||
*Alias
|
||||
}{
|
||||
Stream: r.stream,
|
||||
Alias: (*Alias)(r),
|
||||
})
|
||||
}
|
||||
27
fim_complete.go
Normal file
27
fim_complete.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/fim"
|
||||
)
|
||||
|
||||
// FIMComplete sends a Fill-In-the-Middle completion request.
|
||||
func (c *Client) FIMComplete(ctx context.Context, req *fim.CompletionRequest) (*chat.CompletionResponse, error) {
|
||||
var resp chat.CompletionResponse
|
||||
if err := c.doJSON(ctx, "POST", "/v1/fim/completions", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// FIMCompleteStream sends a FIM request and returns a stream of chunks.
|
||||
func (c *Client) FIMCompleteStream(ctx context.Context, req *fim.CompletionRequest) (*Stream[chat.CompletionChunk], error) {
|
||||
req.SetStream(true)
|
||||
resp, err := c.doStream(ctx, "POST", "/v1/fim/completions", req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newStream[chat.CompletionChunk](resp.Body), nil
|
||||
}
|
||||
175
fim_complete_test.go
Normal file
175
fim_complete_test.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/fim"
|
||||
)
|
||||
|
||||
func TestFIMComplete_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/fim/completions" {
|
||||
t.Errorf("expected /v1/fim/completions, got %s", r.URL.Path)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["prompt"] != "def add(a, b):" {
|
||||
t.Errorf("expected prompt, got %v", body["prompt"])
|
||||
}
|
||||
if body["suffix"] != "return result" {
|
||||
t.Errorf("expected suffix, got %v", body["suffix"])
|
||||
}
|
||||
if body["model"] != "codestral-latest" {
|
||||
t.Errorf("expected model codestral-latest, got %v", body["model"])
|
||||
}
|
||||
if body["stream"] != false {
|
||||
t.Errorf("expected stream=false, got %v", body["stream"])
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "fim-1", "object": "chat.completion",
|
||||
"model": "codestral-latest", "created": 1234567890,
|
||||
"choices": []map[string]any{{
|
||||
"index": 0,
|
||||
"message": map[string]any{"role": "assistant", "content": "\n result = a + b\n "},
|
||||
"finish_reason": "stop",
|
||||
}},
|
||||
"usage": map[string]any{"prompt_tokens": 15, "completion_tokens": 10, "total_tokens": 25},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
suffix := "return result"
|
||||
resp, err := client.FIMComplete(context.Background(), &fim.CompletionRequest{
|
||||
Model: "codestral-latest",
|
||||
Prompt: "def add(a, b):",
|
||||
Suffix: &suffix,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "fim-1" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
if resp.Choices[0].Message.Content.String() != "\n result = a + b\n " {
|
||||
t.Errorf("got content %q", resp.Choices[0].Message.Content.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIMComplete_WithParams(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["temperature"] != 0.2 {
|
||||
t.Errorf("expected temperature=0.2, got %v", body["temperature"])
|
||||
}
|
||||
if body["max_tokens"] != float64(50) {
|
||||
t.Errorf("expected max_tokens=50, got %v", body["max_tokens"])
|
||||
}
|
||||
if body["min_tokens"] != float64(10) {
|
||||
t.Errorf("expected min_tokens=10, got %v", body["min_tokens"])
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "fim-2", "object": "chat.completion",
|
||||
"model": "codestral-latest", "created": 0,
|
||||
"choices": []map[string]any{{
|
||||
"index": 0, "message": map[string]any{"role": "assistant", "content": "code"},
|
||||
"finish_reason": "length",
|
||||
}},
|
||||
"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
temp := 0.2
|
||||
maxTok := 50
|
||||
minTok := 10
|
||||
_, err := client.FIMComplete(context.Background(), &fim.CompletionRequest{
|
||||
Model: "codestral-latest",
|
||||
Prompt: "fn main() {",
|
||||
Temperature: &temp,
|
||||
MaxTokens: &maxTok,
|
||||
MinTokens: &minTok,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIMCompleteStream_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["stream"] != true {
|
||||
t.Errorf("expected stream=true")
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, _ := w.(http.Flusher)
|
||||
for _, content := range []string{"\n ", "result = a + b", "\n "} {
|
||||
chunk := chat.CompletionChunk{
|
||||
ID: "fc", Model: "codestral-latest",
|
||||
Choices: []chat.CompletionStreamChoice{{
|
||||
Index: 0,
|
||||
Delta: chat.DeltaMessage{Content: chat.TextContent(content)},
|
||||
}},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
fmt.Fprint(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
stream, err := client.FIMCompleteStream(context.Background(), &fim.CompletionRequest{
|
||||
Model: "codestral-latest",
|
||||
Prompt: "def add(a, b):",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
var count int
|
||||
for stream.Next() {
|
||||
count++
|
||||
}
|
||||
if stream.Err() != nil {
|
||||
t.Fatal(stream.Err())
|
||||
}
|
||||
if count != 3 {
|
||||
t.Errorf("got %d chunks, want 3", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFIMComplete_APIError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
json.NewEncoder(w).Encode(map[string]any{"message": "model not found"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
_, err := client.FIMComplete(context.Background(), &fim.CompletionRequest{
|
||||
Model: "bad-model",
|
||||
Prompt: "code",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !IsNotFound(err) {
|
||||
t.Errorf("expected not found, got: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user