feat: v1.1.0 — sync with upstream Python SDK v2.1.3

Add Connectors, Audio Speech/Voices, Audio Realtime types,
and Observability (beta). 41 new service methods, 116 total.

Breaking: ListModels and UploadFile signatures changed
(pass nil for previous behavior).
This commit is contained in:
2026-03-24 09:07:03 +01:00
parent b1f0fc4907
commit aa5c53c407
40 changed files with 2906 additions and 18 deletions

View File

@@ -1,3 +1,37 @@
## v1.1.0 — 2026-03-24
Upstream sync with Python SDK v2.1.3. Adds Connectors, Audio Speech/Voices, and Observability (beta).
### Breaking Changes
- **`ListModels`** signature changed from `(ctx)` to `(ctx, *model.ListParams)`.
Pass `nil` for previous behavior. The new `ListParams` supports `Provider` and
`Model` query filters.
- **`UploadFile`** signature changed from `(ctx, filename, reader, purpose)` to
`(ctx, filename, reader, *file.UploadParams)`. The new `UploadParams` struct
holds `Purpose`, `Expiry`, and `Visibility` fields.
### Added
- **`ReasoningEffort`** field on `chat.CompletionRequest` and
`agents.CompletionRequest` — controls reasoning effort (`"none"`, `"high"`).
- **Connectors API** (new `connector/` package) — `CreateConnector`,
`ListConnectors`, `GetConnector`, `UpdateConnector`, `DeleteConnector`,
`GetConnectorAuthURL`, `ListConnectorTools`, `CallConnectorTool`.
- **Audio Speech (TTS)** — `Speech`, `SpeechStream` with `SpeechStream` typed
wrapper, `SpeechOutputFormat` enum (pcm/wav/mp3/flac/opus).
- **Audio Voices** — `ListVoices`, `CreateVoice`, `GetVoice`, `UpdateVoice`,
`DeleteVoice`, `GetVoiceSampleAudio`.
- **Audio Realtime types** — `AudioEncoding`, `AudioFormat`, `RealtimeSession`,
and WebSocket message types in `audio/realtime.go`. No WebSocket client yet
(would require adding a dependency).
- **Observability API** (new `observability/` package, beta) — campaigns,
chat completion events, judges, datasets, records, and import tasks.
33 service methods total.
- **`file.Visibility`** enum — `shared_global`, `shared_org`,
`shared_workspace`, `private`.
- **`model.ListParams`** — filter models by `Provider` and `Model`.
## v1.0.0 — 2026-03-17
Stable release. Tracks upstream Python SDK v2.0.4.

89
CLAUDE.md Normal file
View File

@@ -0,0 +1,89 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project
Idiomatic Go SDK for the Mistral AI API. Module path: `somegit.dev/vikingowl/mistral-go-sdk`. Requires Go 1.26+. Zero external dependencies (stdlib only). Tracks the upstream [Mistral Python SDK](https://github.com/mistralai/client-python) as reference for API surface and type definitions.
## Repository layout
- **Working directory**: `mistral-go-sdk/` — the Go SDK source. All development happens here.
- **`../client-python/`**: Clone of the upstream Mistral Python SDK. Read-only reference — pull/update it when checking for upstream API changes, but never modify it.
## Commands
```bash
# Run all unit tests
go test ./...
# Run a single test
go test -run TestChatComplete_Success
# Run integration tests (requires MISTRAL_API_KEY env var)
go test -tags=integration ./...
# Vet and build
go vet ./...
go build ./...
```
No Makefile, linter config, or code generation tooling — standard `go test` / `go vet` / `go build`.
## Architecture
### Two-layer design: types in sub-packages, methods on `*Client`
Sub-packages (`chat/`, `agents/`, `conversation/`, `embedding/`, `model/`, `file/`, `finetune/`, `batch/`, `ocr/`, `audio/`, `library/`, `moderation/`, `classification/`, `fim/`) are **types-only** — they define request/response structs and enums but contain no HTTP logic. All service methods live on `*Client` in the root package, prefix-namespaced by domain (e.g. `ChatComplete`, `AgentsComplete`, `CreateFineTuningJob`, `UploadFile`).
### HTTP internals (request.go)
All HTTP flows route through a small set of unexported helpers on `*Client`:
- `do()` — raw HTTP with auth headers + retry
- `doJSON()` — JSON marshal request → `do()` → unmarshal response
- `doStream()` — JSON request → raw `*http.Response` for SSE
- `doMultipart()` / `doMultipartStream()` — multipart file upload variants
- `doRetry()` — retry loop with exponential backoff + jitter + `Retry-After` parsing
### Streaming
Generic `Stream[T]` type wraps SSE (`sseReader`) with `Next()`/`Current()`/`Err()`/`Close()` iterator pattern. Typed wrappers `EventStream` (conversations) and `AudioStream` (transcription) unmarshal `json.RawMessage` into domain-specific event types.
### Sealed interfaces for discriminated unions
Polymorphic API types use **sealed interfaces** with unexported marker methods:
- `chat.Message` (marker: `isMessage()`) — `SystemMessage`, `UserMessage`, `AssistantMessage`, `ToolMessage`
- `chat.ContentChunk` (marker: `contentChunk()`) — `TextChunk`, `ImageURLChunk`, `DocumentURLChunk`, `FileChunk`, `ReferenceChunk`, `ThinkChunk`, `AudioChunk`, `ToolReferenceChunk`, `ToolFileChunk`
- `agents.AgentTool` (marker: `agentToolType()`) — `FunctionTool`, `WebSearchTool`, `CodeInterpreterTool`, `ConnectorTool`, etc.
- `conversation.Event` — conversation streaming events
Each has an `Unknown*` variant so the SDK doesn't break on new API types. Each has a `Unmarshal*` dispatch function that probes a `type`/`role` discriminator field.
### Custom JSON patterns
Several types require non-trivial marshal/unmarshal:
- **Type alias trick** — `type alias T` inside `MarshalJSON` to avoid infinite recursion when injecting a `type`/`role` discriminator field.
- **`json:"-"` + custom MarshalJSON** — `CompletionRequest.Messages` (and `stream`) are excluded from default marshaling and injected via custom `MarshalJSON`.
- **Union types** — `Content` handles `string | null | []ContentChunk`; `ToolChoice` handles `string | object`; `ImageURL` handles `string | object`; `FunctionCall.Arguments` handles `string | object`; `ReferenceID` handles `int | string` with type preservation.
- **Probe struct pattern** — `Unmarshal*` functions decode only the discriminator field first, then dispatch to the concrete type.
### Shared types in `chat/`
`GuardrailConfig`, `ModerationLLMV1Config`, `ModerationLLMV2Config` live in `chat/` because it's the base types package imported by both `agents/` and `conversation/`. This avoids import cycles.
### Error handling
`APIError` in `error.go` with sentinel checkers: `IsNotFound()`, `IsRateLimit()`, `IsAuth()`. All use `errors.As` for unwrapping.
## Testing patterns
- Unit tests use `httptest.NewServer` with inline handlers to mock the Mistral API. Client is pointed at the test server via `WithBaseURL(server.URL)`.
- Integration tests are behind `//go:build integration` build tag and require `MISTRAL_API_KEY`.
- Tests use stdlib `testing` only — no third-party test frameworks.
## Adding a new API endpoint
1. Define request/response types in the appropriate sub-package (or create a new one with a `doc.go`).
2. Add a method on `*Client` in the root package. Use `doJSON` for standard request/response, `doStream` for SSE, `doMultipart` for file uploads.
3. Add unit tests with `httptest.NewServer`.
4. If the endpoint supports streaming, return `*Stream[T]` and call `EnableStream()` on the request before sending.

View File

@@ -11,7 +11,7 @@ The most complete Go client for the [Mistral AI API](https://docs.mistral.ai/).
**Zero dependencies.** The entire SDK — including tests — uses only the Go standard library. No `go.sum`, no transitive dependency tree to audit, no version conflicts, no supply chain risk.
**Full API coverage.** 75 methods across every Mistral endpoint — including Conversations, Agents CRUD, Libraries, OCR, Audio, Fine-tuning, and Batch Jobs. No other Go SDK covers Conversations or Agents.
**Full API coverage.** 116 methods across every Mistral endpoint — including Connectors, Audio Speech/Voices, Conversations, Agents CRUD, Libraries, OCR, Observability, Fine-tuning, and Batch Jobs. No other Go SDK covers Conversations, Connectors, or Observability.
**Typed streaming.** A generic pull-based `Stream[T]` iterator — no channels, no goroutines, no leaks. Just `Next()` / `Current()` / `Err()` / `Close()`.
@@ -132,7 +132,7 @@ for stream.Next() {
## API Coverage
75 public methods on `Client`, grouped by domain:
116 public methods on `Client`, grouped by domain:
| Domain | Methods |
|--------|---------|
@@ -140,6 +140,7 @@ for stream.Next() {
| **FIM** | `FIMComplete`, `FIMCompleteStream` |
| **Agents (completions)** | `AgentsComplete`, `AgentsCompleteStream` |
| **Agents (CRUD)** | `CreateAgent`, `ListAgents`, `GetAgent`, `UpdateAgent`, `DeleteAgent`, `UpdateAgentVersion`, `ListAgentVersions`, `GetAgentVersion`, `SetAgentAlias`, `ListAgentAliases`, `DeleteAgentAlias` |
| **Connectors** | `CreateConnector`, `ListConnectors`, `GetConnector`, `UpdateConnector`, `DeleteConnector`, `GetConnectorAuthURL`, `ListConnectorTools`, `CallConnectorTool` |
| **Conversations** | `StartConversation`, `StartConversationStream`, `AppendConversation`, `AppendConversationStream`, `RestartConversation`, `RestartConversationStream`, `GetConversation`, `ListConversations`, `DeleteConversation`, `GetConversationHistory`, `GetConversationMessages` |
| **Models** | `ListModels`, `GetModel`, `DeleteModel` |
| **Files** | `UploadFile`, `ListFiles`, `GetFile`, `DeleteFile`, `GetFileContent`, `GetFileURL` |
@@ -147,10 +148,16 @@ for stream.Next() {
| **Fine-tuning** | `CreateFineTuningJob`, `ListFineTuningJobs`, `GetFineTuningJob`, `CancelFineTuningJob`, `StartFineTuningJob`, `UpdateFineTunedModel`, `ArchiveFineTunedModel`, `UnarchiveFineTunedModel` |
| **Batch** | `CreateBatchJob`, `ListBatchJobs`, `GetBatchJob`, `CancelBatchJob` |
| **OCR** | `OCR` |
| **Audio** | `Transcribe`, `TranscribeStream` |
| **Audio (transcription)** | `Transcribe`, `TranscribeStream` |
| **Audio (speech)** | `Speech`, `SpeechStream` |
| **Audio (voices)** | `ListVoices`, `CreateVoice`, `GetVoice`, `UpdateVoice`, `DeleteVoice`, `GetVoiceSampleAudio` |
| **Libraries** | `CreateLibrary`, `ListLibraries`, `GetLibrary`, `UpdateLibrary`, `DeleteLibrary`, `UploadDocument`, `ListDocuments`, `GetDocument`, `UpdateDocument`, `DeleteDocument`, `GetDocumentTextContent`, `GetDocumentStatus`, `GetDocumentSignedURL`, `GetDocumentExtractedTextSignedURL`, `ReprocessDocument`, `ListLibrarySharing`, `ShareLibrary`, `UnshareLibrary` |
| **Moderation** | `Moderate`, `ModerateChat` |
| **Classification** | `Classify`, `ClassifyChat` |
| **Observability (campaigns)** | `CreateCampaign`, `ListCampaigns`, `GetCampaign`, `DeleteCampaign`, `GetCampaignStatus`, `ListCampaignEvents` |
| **Observability (events)** | `SearchChatCompletionEvents`, `SearchChatCompletionEventIDs`, `GetChatCompletionEvent`, `GetSimilarChatCompletionEvents`, `JudgeChatCompletionEvent` |
| **Observability (judges)** | `CreateJudge`, `ListJudges`, `GetJudge`, `UpdateJudge`, `DeleteJudge`, `JudgeConversation` |
| **Observability (datasets)** | `CreateDataset`, `ListDatasets`, `GetDataset`, `UpdateDataset`, `DeleteDataset`, `ExportDatasetToJSONL`, `ListDatasetRecords`, `CreateDatasetRecord`, `GetDatasetRecord`, `UpdateDatasetRecordPayload`, `UpdateDatasetRecordProperties`, `DeleteDatasetRecord`, `BulkDeleteDatasetRecords`, `JudgeDatasetRecord`, `ImportDatasetFromCampaign`, `ImportDatasetFromExplorer`, `ImportDatasetFromFile`, `ImportDatasetFromPlayground`, `ImportDatasetFromDataset`, `ListDatasetTasks`, `GetDatasetTask` |
## Comparison
@@ -163,11 +170,13 @@ There is no official Go SDK from Mistral AI (only Python and TypeScript). The ma
| Embeddings | Yes | Yes | Yes | Yes |
| Tool calling | Yes | No | No | No |
| Agents (completions + CRUD) | Yes | No | No | No |
| Connectors (MCP) | Yes | No | No | No |
| Conversations API | Yes | No | No | No |
| Libraries / Documents | Yes | No | No | No |
| Fine-tuning / Batch | Yes | No | No | No |
| OCR | Yes | No | No | Yes |
| Audio transcription | Yes | No | No | No |
| Audio (transcription + TTS + voices) | Yes | No | No | No |
| Observability (beta) | Yes | No | No | No |
| Moderation / Classification | Yes | No | No | No |
| Vision (multimodal) | Yes | No | No | Yes |
| Zero dependencies | Yes | test-only (testify) | test-only (testify) | test-only (testify) |
@@ -221,6 +230,7 @@ as its upstream reference for API surface and type definitions.
| SDK Version | Upstream Python SDK |
|-------------|---------------------|
| v1.1.0 | v2.1.3 |
| v1.0.0 | v2.0.4 |
## License

View File

@@ -26,6 +26,7 @@ type CompletionRequest struct {
Prediction *chat.Prediction `json:"prediction,omitempty"`
PromptMode *chat.PromptMode `json:"prompt_mode,omitempty"`
Guardrails []chat.GuardrailConfig `json:"guardrails,omitempty"`
ReasoningEffort *chat.ReasoningEffort `json:"reasoning_effort,omitempty"`
stream bool
}

View File

@@ -114,6 +114,39 @@ func TestAgentsComplete_WithTools(t *testing.T) {
}
}
func TestAgentsComplete_ReasoningEffort(t *testing.T) {
effort := chat.ReasoningEffortHigh
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["reasoning_effort"] != "high" {
t.Errorf("expected reasoning_effort=high, got %v", body["reasoning_effort"])
}
json.NewEncoder(w).Encode(map[string]any{
"id": "a-re", "object": "chat.completion",
"model": "m", "created": 0,
"choices": []map[string]any{{
"index": 0, "message": map[string]any{"role": "assistant", "content": "ok"},
"finish_reason": "stop",
}},
"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.AgentsComplete(context.Background(), &agents.CompletionRequest{
AgentID: "agent-1",
Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
ReasoningEffort: &effort,
})
if err != nil {
t.Fatal(err)
}
}
func TestAgentsCompleteStream_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body map[string]any

View File

@@ -1,5 +1,25 @@
// Package audio provides types for the Mistral audio transcription API.
// Package audio provides types for the Mistral audio APIs.
//
// # Transcription
//
// [TranscriptionRequest] and [TranscriptionResponse] handle speech-to-text.
// Streaming transcription returns typed [StreamEvent] values via a sealed
// interface dispatched by the "type" field.
//
// # Speech (TTS)
//
// [SpeechRequest] and [SpeechResponse] handle text-to-speech.
// Streaming speech returns typed [SpeechStreamEvent] values
// ([SpeechAudioDelta] and [SpeechDone]).
//
// # Voices
//
// [VoiceResponse], [VoiceCreateRequest], and [VoiceUpdateRequest] manage
// custom voices for speech synthesis.
//
// # Realtime
//
// Realtime transcription types ([AudioEncoding], [AudioFormat],
// [RealtimeSession], and WebSocket message types) are defined here.
// The WebSocket client is not yet implemented.
package audio

75
audio/realtime.go Normal file
View File

@@ -0,0 +1,75 @@
package audio
// AudioEncoding is the encoding format for realtime audio streams.
type AudioEncoding string
const (
EncodingPCMS16LE AudioEncoding = "pcm_s16le"
EncodingPCMS32LE AudioEncoding = "pcm_s32le"
EncodingPCMF16LE AudioEncoding = "pcm_f16le"
EncodingPCMF32LE AudioEncoding = "pcm_f32le"
EncodingPCMMulaw AudioEncoding = "pcm_mulaw"
EncodingPCMAlaw AudioEncoding = "pcm_alaw"
)
// AudioFormat describes the encoding and sample rate for realtime audio.
type AudioFormat struct {
Encoding AudioEncoding `json:"encoding"`
SampleRate int `json:"sample_rate"`
}
// RealtimeSession describes a realtime transcription session.
type RealtimeSession struct {
RequestID string `json:"request_id"`
Model string `json:"model"`
AudioFormat AudioFormat `json:"audio_format"`
TargetStreamingDelayMs *int `json:"target_streaming_delay_ms,omitempty"`
}
// RealtimeSessionUpdate is sent to update session parameters.
// Parameters can only be changed before audio transmission starts.
type RealtimeSessionUpdate struct {
AudioFormat *AudioFormat `json:"audio_format,omitempty"`
TargetStreamingDelayMs *int `json:"target_streaming_delay_ms,omitempty"`
}
// InputAudioAppend sends a chunk of audio data.
// Audio is base64-encoded (max 262144 bytes decoded).
type InputAudioAppend struct {
Type string `json:"type"` // "input_audio.append"
Audio string `json:"audio"`
}
// InputAudioFlush flushes the audio buffer.
type InputAudioFlush struct {
Type string `json:"type"` // "input_audio.flush"
}
// InputAudioEnd signals the end of audio input.
type InputAudioEnd struct {
Type string `json:"type"` // "input_audio.end"
}
// RealtimeSessionCreated is received when a session is created.
type RealtimeSessionCreated struct {
Type string `json:"type"` // "session.created"
Session RealtimeSession `json:"session"`
}
// RealtimeSessionUpdated is received when a session is updated.
type RealtimeSessionUpdated struct {
Type string `json:"type"` // "session.updated"
Session RealtimeSession `json:"session"`
}
// RealtimeErrorDetail describes a realtime error.
type RealtimeErrorDetail struct {
Message string `json:"message"`
Code int `json:"code"`
}
// RealtimeError is received on error.
type RealtimeError struct {
Type string `json:"type"` // "error"
Error RealtimeErrorDetail `json:"error"`
}

88
audio/speech.go Normal file
View File

@@ -0,0 +1,88 @@
package audio
import (
"encoding/json"
"fmt"
)
// SpeechOutputFormat is the output audio format for speech synthesis.
type SpeechOutputFormat string
const (
SpeechFormatPCM SpeechOutputFormat = "pcm"
SpeechFormatWAV SpeechOutputFormat = "wav"
SpeechFormatMP3 SpeechOutputFormat = "mp3"
SpeechFormatFLAC SpeechOutputFormat = "flac"
SpeechFormatOpus SpeechOutputFormat = "opus"
)
// SpeechRequest represents a text-to-speech request.
type SpeechRequest struct {
Input string `json:"input"`
Model string `json:"model"`
Metadata map[string]any `json:"metadata,omitempty"`
VoiceID *string `json:"voice_id,omitempty"`
RefAudio *string `json:"ref_audio,omitempty"`
ResponseFormat *SpeechOutputFormat `json:"response_format,omitempty"`
stream bool
}
// EnableStream is used internally to enable streaming.
func (r *SpeechRequest) EnableStream() { r.stream = true }
func (r *SpeechRequest) MarshalJSON() ([]byte, error) {
type Alias SpeechRequest
return json.Marshal(&struct {
Stream bool `json:"stream"`
*Alias
}{
Stream: r.stream,
Alias: (*Alias)(r),
})
}
// SpeechResponse is the response from a non-streaming speech request.
type SpeechResponse struct {
AudioData string `json:"audio_data"`
}
// SpeechStreamEvent is a sealed interface for speech streaming events.
type SpeechStreamEvent interface {
speechStreamEvent()
}
// SpeechAudioDelta contains a chunk of audio data during streaming.
type SpeechAudioDelta struct {
Type string `json:"type"`
AudioData string `json:"audio_data"`
}
func (*SpeechAudioDelta) speechStreamEvent() {}
// SpeechDone is emitted when speech synthesis is complete.
type SpeechDone struct {
Type string `json:"type"`
Usage UsageInfo `json:"usage"`
}
func (*SpeechDone) speechStreamEvent() {}
// UnmarshalSpeechStreamEvent dispatches a raw JSON event to the correct type.
func UnmarshalSpeechStreamEvent(data []byte) (SpeechStreamEvent, error) {
var probe struct {
Type string `json:"type"`
}
if err := json.Unmarshal(data, &probe); err != nil {
return nil, err
}
switch probe.Type {
case "speech.audio.delta":
var e SpeechAudioDelta
return &e, json.Unmarshal(data, &e)
case "speech.audio.done":
var e SpeechDone
return &e, json.Unmarshal(data, &e)
default:
return nil, fmt.Errorf("unknown speech stream event type: %q", probe.Type)
}
}

48
audio/voice.go Normal file
View File

@@ -0,0 +1,48 @@
package audio
// VoiceResponse represents a voice entity.
type VoiceResponse struct {
Name string `json:"name"`
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UserID *string `json:"user_id,omitempty"`
Slug *string `json:"slug,omitempty"`
Languages []string `json:"languages,omitempty"`
Gender *string `json:"gender,omitempty"`
Age *int `json:"age,omitempty"`
Tags []string `json:"tags,omitempty"`
Color *string `json:"color,omitempty"`
RetentionNotice *int `json:"retention_notice,omitempty"`
}
// VoiceCreateRequest creates a custom voice.
type VoiceCreateRequest struct {
Name string `json:"name"`
SampleAudio string `json:"sample_audio"`
Slug *string `json:"slug,omitempty"`
Languages []string `json:"languages,omitempty"`
Gender *string `json:"gender,omitempty"`
Age *int `json:"age,omitempty"`
Tags []string `json:"tags,omitempty"`
Color *string `json:"color,omitempty"`
RetentionNotice *int `json:"retention_notice,omitempty"`
SampleFilename *string `json:"sample_filename,omitempty"`
}
// VoiceUpdateRequest updates a voice.
type VoiceUpdateRequest struct {
Name *string `json:"name,omitempty"`
Languages []string `json:"languages,omitempty"`
Gender *string `json:"gender,omitempty"`
Age *int `json:"age,omitempty"`
Tags []string `json:"tags,omitempty"`
}
// VoiceListResponse is the response from listing voices.
type VoiceListResponse struct {
Items []VoiceResponse `json:"items"`
Total int `json:"total"`
Page int `json:"page"`
PageSize int `json:"page_size"`
TotalPages int `json:"total_pages"`
}

View File

@@ -3,7 +3,9 @@ package mistral
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"somegit.dev/vikingowl/mistral-go-sdk/audio"
)
@@ -95,3 +97,125 @@ func (s *AudioStream) Err() error { return s.err }
// Close releases the underlying connection.
func (s *AudioStream) Close() error { return s.stream.Close() }
// Speech sends a text-to-speech request and returns the full response.
func (c *Client) Speech(ctx context.Context, req *audio.SpeechRequest) (*audio.SpeechResponse, error) {
var resp audio.SpeechResponse
if err := c.doJSON(ctx, "POST", "/v1/audio/speech", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// SpeechStream sends a text-to-speech request and returns a stream of audio events.
func (c *Client) SpeechStream(ctx context.Context, req *audio.SpeechRequest) (*SpeechStream, error) {
req.EnableStream()
resp, err := c.doStream(ctx, "POST", "/v1/audio/speech", req)
if err != nil {
return nil, err
}
return newSpeechStream(resp.Body), nil
}
// SpeechStream wraps the generic Stream for speech streaming events.
type SpeechStream struct {
stream *Stream[json.RawMessage]
event audio.SpeechStreamEvent
err error
}
func newSpeechStream(body readCloser) *SpeechStream {
return &SpeechStream{
stream: newStream[json.RawMessage](body),
}
}
// Next advances to the next event. Returns false when done or on error.
func (s *SpeechStream) Next() bool {
if s.err != nil {
return false
}
if !s.stream.Next() {
s.err = s.stream.Err()
return false
}
event, err := audio.UnmarshalSpeechStreamEvent(s.stream.Current())
if err != nil {
s.err = err
return false
}
s.event = event
return true
}
// Current returns the most recently read event.
func (s *SpeechStream) Current() audio.SpeechStreamEvent { return s.event }
// Err returns any error encountered during streaming.
func (s *SpeechStream) Err() error { return s.err }
// Close releases the underlying connection.
func (s *SpeechStream) Close() error { return s.stream.Close() }
// ListVoices returns available voices.
func (c *Client) ListVoices(ctx context.Context) (*audio.VoiceListResponse, error) {
var resp audio.VoiceListResponse
if err := c.doJSON(ctx, "GET", "/v1/audio/voices", nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// CreateVoice creates a custom voice.
func (c *Client) CreateVoice(ctx context.Context, req *audio.VoiceCreateRequest) (*audio.VoiceResponse, error) {
var resp audio.VoiceResponse
if err := c.doJSON(ctx, "POST", "/v1/audio/voices", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetVoice retrieves a voice by ID.
func (c *Client) GetVoice(ctx context.Context, voiceID string) (*audio.VoiceResponse, error) {
var resp audio.VoiceResponse
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/audio/voices/%s", voiceID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateVoice updates a voice.
func (c *Client) UpdateVoice(ctx context.Context, voiceID string, req *audio.VoiceUpdateRequest) (*audio.VoiceResponse, error) {
var resp audio.VoiceResponse
if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/audio/voices/%s", voiceID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteVoice deletes a voice.
func (c *Client) DeleteVoice(ctx context.Context, voiceID string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/audio/voices/%s", voiceID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// GetVoiceSampleAudio retrieves the sample audio for a voice.
// Returns the raw HTTP response; the caller must close the body.
func (c *Client) GetVoiceSampleAudio(ctx context.Context, voiceID string) (*http.Response, error) {
resp, err := c.do(ctx, "GET", fmt.Sprintf("/v1/audio/voices/%s/sample", voiceID), nil)
if err != nil {
return nil, err
}
if resp.StatusCode >= 400 {
defer resp.Body.Close()
return nil, parseAPIError(resp)
}
return resp, nil
}

217
audio_speech_test.go Normal file
View File

@@ -0,0 +1,217 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/audio"
)
func TestSpeech_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("expected POST, got %s", r.Method)
}
if r.URL.Path != "/v1/audio/speech" {
t.Errorf("got path %s", r.URL.Path)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["input"] != "Hello world" {
t.Errorf("got input %v", body["input"])
}
if body["stream"] != false {
t.Errorf("expected stream=false")
}
json.NewEncoder(w).Encode(map[string]any{
"audio_data": "base64audiodata==",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.Speech(context.Background(), &audio.SpeechRequest{
Input: "Hello world",
Model: "mistral-speech",
})
if err != nil {
t.Fatal(err)
}
if resp.AudioData != "base64audiodata==" {
t.Errorf("got audio_data %q", resp.AudioData)
}
}
func TestSpeechStream_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["stream"] != true {
t.Errorf("expected stream=true")
}
w.Header().Set("Content-Type", "text/event-stream")
flusher, _ := w.(http.Flusher)
delta, _ := json.Marshal(map[string]any{
"type": "speech.audio.delta", "audio_data": "chunk1==",
})
fmt.Fprintf(w, "data: %s\n\n", delta)
flusher.Flush()
done, _ := json.Marshal(map[string]any{
"type": "speech.audio.done",
"usage": map[string]any{
"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15,
},
})
fmt.Fprintf(w, "data: %s\n\n", done)
flusher.Flush()
fmt.Fprint(w, "data: [DONE]\n\n")
flusher.Flush()
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
stream, err := client.SpeechStream(context.Background(), &audio.SpeechRequest{
Input: "Hi",
Model: "mistral-speech",
})
if err != nil {
t.Fatal(err)
}
defer stream.Close()
var events int
for stream.Next() {
events++
}
if err := stream.Err(); err != nil {
t.Fatal(err)
}
if events != 2 {
t.Errorf("got %d events, want 2", events)
}
}
func TestListVoices_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/audio/voices" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"items": []map[string]any{
{"id": "v1", "name": "Default", "created_at": "2025-01-01"},
},
"total": 1, "page": 1, "page_size": 10, "total_pages": 1,
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListVoices(context.Background())
if err != nil {
t.Fatal(err)
}
if len(resp.Items) != 1 {
t.Fatalf("got %d voices", len(resp.Items))
}
if resp.Items[0].ID != "v1" {
t.Errorf("got id %q", resp.Items[0].ID)
}
}
func TestCreateVoice_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("expected POST, got %s", r.Method)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["name"] != "MyVoice" {
t.Errorf("got name %v", body["name"])
}
json.NewEncoder(w).Encode(map[string]any{
"id": "v2", "name": "MyVoice", "created_at": "2025-01-01",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateVoice(context.Background(), &audio.VoiceCreateRequest{
Name: "MyVoice",
SampleAudio: "base64audio==",
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "v2" {
t.Errorf("got id %q", resp.ID)
}
}
func TestGetVoice_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/audio/voices/v1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "v1", "name": "Default", "created_at": "2025-01-01",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetVoice(context.Background(), "v1")
if err != nil {
t.Fatal(err)
}
if resp.Name != "Default" {
t.Errorf("got name %q", resp.Name)
}
}
func TestUpdateVoice_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PATCH" {
t.Errorf("expected PATCH, got %s", r.Method)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "v1", "name": "Renamed", "created_at": "2025-01-01",
})
}))
defer server.Close()
name := "Renamed"
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.UpdateVoice(context.Background(), "v1", &audio.VoiceUpdateRequest{
Name: &name,
})
if err != nil {
t.Fatal(err)
}
if resp.Name != "Renamed" {
t.Errorf("got name %q", resp.Name)
}
}
func TestDeleteVoice_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "DELETE" {
t.Errorf("expected DELETE, got %s", r.Method)
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
err := client.DeleteVoice(context.Background(), "v1")
if err != nil {
t.Fatal(err)
}
}

View File

@@ -9,6 +9,14 @@ const (
PromptModeReasoning PromptMode = "reasoning"
)
// ReasoningEffort controls the amount of reasoning effort the model uses.
type ReasoningEffort string
const (
ReasoningEffortNone ReasoningEffort = "none"
ReasoningEffortHigh ReasoningEffort = "high"
)
// Prediction provides expected completion content for optimization.
type Prediction struct {
Type string `json:"type"`
@@ -36,6 +44,7 @@ type CompletionRequest struct {
Prediction *Prediction `json:"prediction,omitempty"`
PromptMode *PromptMode `json:"prompt_mode,omitempty"`
Guardrails []GuardrailConfig `json:"guardrails,omitempty"`
ReasoningEffort *ReasoningEffort `json:"reasoning_effort,omitempty"`
stream bool
}

View File

@@ -350,6 +350,41 @@ func TestChatComplete_RequestBody(t *testing.T) {
}
}
func TestChatComplete_ReasoningEffort(t *testing.T) {
effort := chat.ReasoningEffortHigh
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
bodyBytes, _ := io.ReadAll(r.Body)
var body map[string]any
json.Unmarshal(bodyBytes, &body)
if body["reasoning_effort"] != "high" {
t.Errorf("expected reasoning_effort=high, got %v", body["reasoning_effort"])
}
json.NewEncoder(w).Encode(map[string]any{
"id": "chat-re", "object": "chat.completion",
"model": "m", "created": 0,
"choices": []map[string]any{{
"index": 0, "message": map[string]any{"role": "assistant", "content": "ok"},
"finish_reason": "stop",
}},
"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.ChatComplete(context.Background(), &chat.CompletionRequest{
Model: "m",
Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
ReasoningEffort: &effort,
})
if err != nil {
t.Fatal(err)
}
}
func TestChatComplete_ContextCanceled(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Never responds — context should cancel first

100
connector/connector.go Normal file
View File

@@ -0,0 +1,100 @@
package connector
import "encoding/json"
// Visibility controls who can see a connector or tool.
type Visibility string
const (
VisibilitySharedGlobal Visibility = "shared_global"
VisibilitySharedOrg Visibility = "shared_org"
VisibilitySharedWorkspace Visibility = "shared_workspace"
VisibilityPrivate Visibility = "private"
)
// AuthData holds OAuth2 client credentials for a connector.
type AuthData struct {
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
}
// Connector represents a registered MCP connector.
type Connector struct {
ID string `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
CreatedAt string `json:"created_at"`
ModifiedAt string `json:"modified_at"`
Server *string `json:"server,omitempty"`
AuthType *string `json:"auth_type,omitempty"`
}
// CreateRequest creates a new connector.
type CreateRequest struct {
Name string `json:"name"`
Description string `json:"description"`
Server string `json:"server"`
IconURL *string `json:"icon_url,omitempty"`
Visibility *Visibility `json:"visibility,omitempty"`
Headers map[string]string `json:"headers,omitempty"`
AuthData *AuthData `json:"auth_data,omitempty"`
SystemPrompt *string `json:"system_prompt,omitempty"`
}
// UpdateRequest updates an existing connector.
type UpdateRequest struct {
Name *string `json:"name,omitempty"`
Description *string `json:"description,omitempty"`
IconURL *string `json:"icon_url,omitempty"`
SystemPrompt *string `json:"system_prompt,omitempty"`
Server *string `json:"server,omitempty"`
Headers map[string]string `json:"headers,omitempty"`
AuthData *AuthData `json:"auth_data,omitempty"`
ConnectionConfig map[string]any `json:"connection_config,omitempty"`
ConnectionSecrets map[string]any `json:"connection_secrets,omitempty"`
}
// AuthURLResponse is the response from getting a connector's OAuth URL.
type AuthURLResponse struct {
AuthURL string `json:"auth_url"`
TTL int `json:"ttl"`
}
// CallToolRequest is the request body for calling a connector tool.
type CallToolRequest struct {
Arguments map[string]any `json:"arguments,omitempty"`
}
// CallToolResponse is the response from calling a connector tool.
// Content is left as raw JSON because the upstream API returns a union
// of 5 content types (text, image, audio, resource link, embedded resource).
type CallToolResponse struct {
Content json.RawMessage `json:"content"`
Metadata map[string]any `json:"metadata,omitempty"`
}
// Tool represents a tool exposed by a connector.
type Tool struct {
ID string `json:"id"`
Name string `json:"name"`
Description *string `json:"description,omitempty"`
Visibility Visibility `json:"visibility,omitempty"`
CreatedAt string `json:"created_at,omitempty"`
ModifiedAt string `json:"modified_at,omitempty"`
SystemPrompt *string `json:"system_prompt,omitempty"`
JsonSchema map[string]any `json:"jsonschema,omitempty"`
Active *bool `json:"active,omitempty"`
}
// ListParams holds query parameters for listing connectors.
type ListParams struct {
Page *int
PageSize *int
}
// ListToolsParams holds query parameters for listing connector tools.
type ListToolsParams struct {
Page *int
PageSize *int
Refresh *bool
}

6
connector/doc.go Normal file
View File

@@ -0,0 +1,6 @@
// Package connector provides types for the Mistral connectors API.
//
// Connectors represent MCP (Model Context Protocol) server integrations.
// Use [CreateRequest] to register a new connector, then use tools
// discovered via the list-tools endpoint in chat or agent completions.
package connector

119
connectors.go Normal file
View File

@@ -0,0 +1,119 @@
package mistral
import (
"context"
"fmt"
"net/url"
"strconv"
"somegit.dev/vikingowl/mistral-go-sdk/connector"
)
// CreateConnector registers a new MCP connector.
func (c *Client) CreateConnector(ctx context.Context, req *connector.CreateRequest) (*connector.Connector, error) {
var resp connector.Connector
if err := c.doJSON(ctx, "POST", "/v1/connectors", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListConnectors returns all connectors.
func (c *Client) ListConnectors(ctx context.Context, params *connector.ListParams) ([]connector.Connector, error) {
path := "/v1/connectors"
if params != nil {
q := url.Values{}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp []connector.Connector
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return resp, nil
}
// GetConnector retrieves a connector by ID or name.
func (c *Client) GetConnector(ctx context.Context, idOrName string) (*connector.Connector, error) {
var resp connector.Connector
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/connectors/%s", idOrName), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateConnector updates an existing connector.
func (c *Client) UpdateConnector(ctx context.Context, idOrName string, req *connector.UpdateRequest) (*connector.Connector, error) {
var resp connector.Connector
if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/connectors/%s", idOrName), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteConnector deletes a connector.
func (c *Client) DeleteConnector(ctx context.Context, idOrName string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/connectors/%s", idOrName), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// GetConnectorAuthURL returns the OAuth2 authorization URL for a connector.
func (c *Client) GetConnectorAuthURL(ctx context.Context, idOrName string, appReturnURL *string) (*connector.AuthURLResponse, error) {
path := fmt.Sprintf("/v1/connectors/%s/auth_url", idOrName)
if appReturnURL != nil {
path += "?app_return_url=" + url.QueryEscape(*appReturnURL)
}
var resp connector.AuthURLResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListConnectorTools lists tools exposed by a connector.
func (c *Client) ListConnectorTools(ctx context.Context, idOrName string, params *connector.ListToolsParams) ([]connector.Tool, error) {
path := fmt.Sprintf("/v1/connectors/%s/tools", idOrName)
if params != nil {
q := url.Values{}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Refresh != nil {
q.Set("refresh", strconv.FormatBool(*params.Refresh))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp []connector.Tool
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return resp, nil
}
// CallConnectorTool invokes a tool on a connector.
func (c *Client) CallConnectorTool(ctx context.Context, idOrName, toolName string, req *connector.CallToolRequest) (*connector.CallToolResponse, error) {
var resp connector.CallToolResponse
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/connectors/%s/tools/%s/call", idOrName, toolName), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}

217
connectors_test.go Normal file
View File

@@ -0,0 +1,217 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/connector"
)
func TestCreateConnector_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("expected POST, got %s", r.Method)
}
if r.URL.Path != "/v1/connectors" {
t.Errorf("got path %s", r.URL.Path)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["name"] != "my_connector" {
t.Errorf("got name %v", body["name"])
}
if body["server"] != "https://mcp.example.com" {
t.Errorf("got server %v", body["server"])
}
json.NewEncoder(w).Encode(map[string]any{
"id": "conn-1", "name": "my_connector",
"description": "test", "created_at": "2025-01-01",
"modified_at": "2025-01-01", "server": "https://mcp.example.com",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateConnector(context.Background(), &connector.CreateRequest{
Name: "my_connector",
Description: "test",
Server: "https://mcp.example.com",
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "conn-1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestListConnectors_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
t.Errorf("expected GET, got %s", r.Method)
}
json.NewEncoder(w).Encode([]map[string]any{
{"id": "c1", "name": "conn1", "description": "d1", "created_at": "t", "modified_at": "t"},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
list, err := client.ListConnectors(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if len(list) != 1 {
t.Fatalf("got %d connectors", len(list))
}
if list[0].ID != "c1" {
t.Errorf("got id %q", list[0].ID)
}
}
func TestGetConnector_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/connectors/my_conn" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "c1", "name": "my_conn", "description": "d",
"created_at": "t", "modified_at": "t",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
c, err := client.GetConnector(context.Background(), "my_conn")
if err != nil {
t.Fatal(err)
}
if c.Name != "my_conn" {
t.Errorf("got name %q", c.Name)
}
}
func TestUpdateConnector_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PATCH" {
t.Errorf("expected PATCH, got %s", r.Method)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "c1", "name": "updated", "description": "new desc",
"created_at": "t", "modified_at": "t",
})
}))
defer server.Close()
name := "updated"
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.UpdateConnector(context.Background(), "c1", &connector.UpdateRequest{
Name: &name,
})
if err != nil {
t.Fatal(err)
}
if resp.Name != "updated" {
t.Errorf("got name %q", resp.Name)
}
}
func TestDeleteConnector_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "DELETE" {
t.Errorf("expected DELETE, got %s", r.Method)
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
err := client.DeleteConnector(context.Background(), "c1")
if err != nil {
t.Fatal(err)
}
}
func TestGetConnectorAuthURL_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/connectors/c1/auth_url" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"auth_url": "https://oauth.example.com/authorize",
"ttl": 3600,
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetConnectorAuthURL(context.Background(), "c1", nil)
if err != nil {
t.Fatal(err)
}
if resp.AuthURL != "https://oauth.example.com/authorize" {
t.Errorf("got auth_url %q", resp.AuthURL)
}
if resp.TTL != 3600 {
t.Errorf("got ttl %d", resp.TTL)
}
}
func TestListConnectorTools_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/connectors/c1/tools" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode([]map[string]any{
{"id": "t1", "name": "search", "description": "search the web"},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
tools, err := client.ListConnectorTools(context.Background(), "c1", nil)
if err != nil {
t.Fatal(err)
}
if len(tools) != 1 {
t.Fatalf("got %d tools", len(tools))
}
if tools[0].Name != "search" {
t.Errorf("got name %q", tools[0].Name)
}
}
func TestCallConnectorTool_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Errorf("expected POST, got %s", r.Method)
}
if r.URL.Path != "/v1/connectors/c1/tools/search/call" {
t.Errorf("got path %s", r.URL.Path)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
args := body["arguments"].(map[string]any)
if args["query"] != "hello" {
t.Errorf("got query %v", args["query"])
}
json.NewEncoder(w).Encode(map[string]any{
"content": []map[string]any{{"type": "text", "text": "result"}},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CallConnectorTool(context.Background(), "c1", "search", &connector.CallToolRequest{
Arguments: map[string]any{"query": "hello"},
})
if err != nil {
t.Fatal(err)
}
if resp.Content == nil {
t.Error("expected non-nil content")
}
}

6
doc.go
View File

@@ -39,9 +39,9 @@
// # Sub-packages
//
// Types are organized into sub-packages by domain: [chat], [agents],
// [conversation], [embedding], [model], [file], [finetune], [batch],
// [ocr], [audio], [library], [moderation], [classification], and [fim].
// All service methods live directly on [Client].
// [connector], [conversation], [embedding], [model], [file], [finetune],
// [batch], [ocr], [audio], [library], [moderation], [classification],
// [fim], and [observability]. All service methods live directly on [Client].
//
// # Reference
//

View File

@@ -64,6 +64,23 @@ type SignedURL struct {
URL string `json:"url"`
}
// Visibility controls who can see a file.
type Visibility string
const (
VisibilitySharedGlobal Visibility = "shared_global"
VisibilitySharedOrg Visibility = "shared_org"
VisibilitySharedWorkspace Visibility = "shared_workspace"
VisibilityPrivate Visibility = "private"
)
// UploadParams holds parameters for uploading a file.
type UploadParams struct {
Purpose Purpose
Expiry *int
Visibility *Visibility
}
// ListParams holds optional parameters for listing files.
type ListParams struct {
Page *int

View File

@@ -12,10 +12,18 @@ import (
)
// UploadFile uploads a file for use with fine-tuning, batch, or OCR.
func (c *Client) UploadFile(ctx context.Context, filename string, r io.Reader, purpose file.Purpose) (*file.File, error) {
func (c *Client) UploadFile(ctx context.Context, filename string, r io.Reader, params *file.UploadParams) (*file.File, error) {
fields := map[string]string{}
if purpose != "" {
fields["purpose"] = string(purpose)
if params != nil {
if params.Purpose != "" {
fields["purpose"] = string(params.Purpose)
}
if params.Expiry != nil {
fields["expiry"] = strconv.Itoa(*params.Expiry)
}
if params.Visibility != nil {
fields["visibility"] = string(*params.Visibility)
}
}
var resp file.File
if err := c.doMultipart(ctx, "/v1/files", filename, r, fields, &resp); err != nil {

View File

@@ -58,7 +58,7 @@ func TestUploadFile_Success(t *testing.T) {
context.Background(),
"train.jsonl",
strings.NewReader(`{"text":"hello"}`),
file.PurposeFineTune,
&file.UploadParams{Purpose: file.PurposeFineTune},
)
if err != nil {
t.Fatal(err)
@@ -74,6 +74,43 @@ func TestUploadFile_Success(t *testing.T) {
}
}
func TestUploadFile_WithExpiryAndVisibility(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseMultipartForm(10 << 20); err != nil {
t.Fatal(err)
}
if r.FormValue("purpose") != "fine-tune" {
t.Errorf("got purpose %q", r.FormValue("purpose"))
}
if r.FormValue("expiry") != "48" {
t.Errorf("expected expiry=48, got %q", r.FormValue("expiry"))
}
if r.FormValue("visibility") != "private" {
t.Errorf("expected visibility=private, got %q", r.FormValue("visibility"))
}
json.NewEncoder(w).Encode(map[string]any{
"id": "file-ev", "object": "file", "bytes": 10,
"created_at": 1, "filename": "data.jsonl",
"purpose": "fine-tune", "sample_type": "instruct",
"source": "upload",
})
}))
defer server.Close()
expiry := 48
vis := file.VisibilityPrivate
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.UploadFile(context.Background(), "data.jsonl", strings.NewReader("{}"), &file.UploadParams{
Purpose: file.PurposeFineTune,
Expiry: &expiry,
Visibility: &vis,
})
if err != nil {
t.Fatal(err)
}
}
func TestListFiles_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
@@ -260,7 +297,7 @@ func TestUploadFile_Error(t *testing.T) {
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.UploadFile(context.Background(), "bad.txt", strings.NewReader(""), file.PurposeFineTune)
_, err := client.UploadFile(context.Background(), "bad.txt", strings.NewReader(""), &file.UploadParams{Purpose: file.PurposeFineTune})
if err == nil {
t.Fatal("expected error")
}

View File

@@ -24,7 +24,7 @@ func integrationClient(t *testing.T) *Client {
func TestIntegration_ListModels(t *testing.T) {
client := integrationClient(t)
resp, err := client.ListModels(context.Background())
resp, err := client.ListModels(context.Background(), nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -6,7 +6,7 @@ import (
)
// Version is the SDK version string.
const Version = "1.0.0"
const Version = "1.1.0"
const (
defaultBaseURL = "https://api.mistral.ai"

View File

@@ -52,3 +52,9 @@ type DeleteModelOut struct {
Object string `json:"object"`
Deleted bool `json:"deleted"`
}
// ListParams holds optional parameters for listing models.
type ListParams struct {
Provider *string
Model *string
}

View File

@@ -2,14 +2,28 @@ package mistral
import (
"context"
"net/url"
"somegit.dev/vikingowl/mistral-go-sdk/model"
)
// ListModels returns a list of available models.
func (c *Client) ListModels(ctx context.Context) (*model.ModelList, error) {
func (c *Client) ListModels(ctx context.Context, params *model.ListParams) (*model.ModelList, error) {
path := "/v1/models"
if params != nil {
q := url.Values{}
if params.Provider != nil {
q.Set("provider", *params.Provider)
}
if params.Model != nil {
q.Set("model", *params.Model)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp model.ModelList
if err := c.doJSON(ctx, "GET", "/v1/models", nil, &resp); err != nil {
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil

View File

@@ -6,6 +6,8 @@ import (
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/model"
)
func TestListModels_Success(t *testing.T) {
@@ -45,7 +47,7 @@ func TestListModels_Success(t *testing.T) {
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
list, err := client.ListModels(context.Background())
list, err := client.ListModels(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
@@ -82,6 +84,33 @@ func TestListModels_Success(t *testing.T) {
}
}
func TestListModels_WithParams(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("provider") != "mistralai" {
t.Errorf("expected provider=mistralai, got %q", r.URL.Query().Get("provider"))
}
if r.URL.Query().Get("model") != "mistral-small" {
t.Errorf("expected model=mistral-small, got %q", r.URL.Query().Get("model"))
}
json.NewEncoder(w).Encode(map[string]any{
"object": "list",
"data": []map[string]any{},
})
}))
defer server.Close()
provider := "mistralai"
modelName := "mistral-small"
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.ListModels(context.Background(), &model.ListParams{
Provider: &provider,
Model: &modelName,
})
if err != nil {
t.Fatal(err)
}
}
func TestGetModel_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/models/mistral-small-latest" {

46
observability/campaign.go Normal file
View File

@@ -0,0 +1,46 @@
package observability
// Campaign represents an observability campaign.
type Campaign struct {
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
DeletedAt *string `json:"deleted_at,omitempty"`
Name string `json:"name"`
OwnerID string `json:"owner_id"`
WorkspaceID string `json:"workspace_id"`
Description string `json:"description"`
MaxNbEvents int `json:"max_nb_events"`
SearchParams FilterPayload `json:"search_params"`
Judge Judge `json:"judge"`
}
// CreateCampaignRequest creates a new campaign.
type CreateCampaignRequest struct {
SearchParams FilterPayload `json:"search_params"`
JudgeID string `json:"judge_id"`
Name string `json:"name"`
Description string `json:"description"`
MaxNbEvents int `json:"max_nb_events"`
}
// CampaignStatusResponse is the response for campaign status.
type CampaignStatusResponse struct {
Status TaskStatus `json:"status"`
}
// ListCampaignsResponse is the response from listing campaigns.
type ListCampaignsResponse struct {
Count int `json:"count"`
Results []Campaign `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}
// ListCampaignEventsResponse is the response from listing campaign events.
type ListCampaignEventsResponse struct {
Count int `json:"count"`
Results []ChatCompletionEventPreview `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}

156
observability/dataset.go Normal file
View File

@@ -0,0 +1,156 @@
package observability
import "encoding/json"
// ConversationSource indicates how a dataset record was created.
type ConversationSource string
const (
SourceExplorer ConversationSource = "EXPLORER"
SourceUploadedFile ConversationSource = "UPLOADED_FILE"
SourceDirectInput ConversationSource = "DIRECT_INPUT"
SourcePlayground ConversationSource = "PLAYGROUND"
)
// Dataset represents a dataset entity.
type Dataset struct {
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
DeletedAt *string `json:"deleted_at,omitempty"`
Name string `json:"name"`
Description string `json:"description"`
OwnerID string `json:"owner_id"`
WorkspaceID string `json:"workspace_id"`
}
// CreateDatasetRequest creates a new dataset.
type CreateDatasetRequest struct {
Name string `json:"name"`
Description string `json:"description"`
}
// UpdateDatasetRequest updates a dataset.
type UpdateDatasetRequest struct {
Name *string `json:"name,omitempty"`
Description *string `json:"description,omitempty"`
}
// DatasetRecord is a single record in a dataset.
type DatasetRecord struct {
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
DeletedAt *string `json:"deleted_at,omitempty"`
DatasetID string `json:"dataset_id"`
Payload ConversationPayload `json:"payload"`
Properties map[string]any `json:"properties,omitempty"`
Source ConversationSource `json:"source"`
}
// ConversationPayload holds the messages for a dataset record.
type ConversationPayload struct {
Messages []map[string]any `json:"messages"`
}
// CreateRecordRequest creates a new dataset record.
type CreateRecordRequest struct {
Payload ConversationPayload `json:"payload"`
Properties map[string]any `json:"properties"`
}
// UpdateRecordPayloadRequest updates a record's payload.
type UpdateRecordPayloadRequest struct {
Payload ConversationPayload `json:"payload"`
}
// UpdateRecordPropertiesRequest updates a record's properties.
type UpdateRecordPropertiesRequest struct {
Properties map[string]any `json:"properties"`
}
// BulkDeleteRecordsRequest deletes multiple records.
type BulkDeleteRecordsRequest struct {
DatasetRecordIDs []string `json:"dataset_record_ids"`
}
// JudgeRecordRequest judges a dataset record.
type JudgeRecordRequest struct {
JudgeDefinition CreateJudgeRequest `json:"judge_definition"`
}
// DatasetImportTask tracks an async import operation.
type DatasetImportTask struct {
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
DeletedAt *string `json:"deleted_at,omitempty"`
CreatorID string `json:"creator_id"`
DatasetID string `json:"dataset_id"`
WorkspaceID string `json:"workspace_id"`
Status TaskStatus `json:"status"`
Progress *int `json:"progress,omitempty"`
Message *string `json:"message,omitempty"`
}
// ExportDatasetResponse is the response from exporting a dataset.
type ExportDatasetResponse struct {
FileURL string `json:"file_url"`
}
// Import request types.
// ImportFromCampaignRequest imports records from a campaign.
type ImportFromCampaignRequest struct {
CampaignID string `json:"campaign_id"`
}
// ImportFromExplorerRequest imports records from explorer events.
type ImportFromExplorerRequest struct {
CompletionEventIDs []string `json:"completion_event_ids"`
}
// ImportFromFileRequest imports records from a file.
type ImportFromFileRequest struct {
FileID string `json:"file_id"`
}
// ImportFromPlaygroundRequest imports records from playground conversations.
type ImportFromPlaygroundRequest struct {
ConversationIDs []string `json:"conversation_ids"`
}
// ImportFromDatasetRequest imports records from another dataset.
type ImportFromDatasetRequest struct {
DatasetRecordIDs []string `json:"dataset_record_ids"`
}
// List response types.
// ListDatasetsResponse is the response from listing datasets.
type ListDatasetsResponse struct {
Count int `json:"count"`
Results []Dataset `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}
// ListRecordsResponse is the response from listing dataset records.
type ListRecordsResponse struct {
Count int `json:"count"`
Results []DatasetRecord `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}
// ListTasksResponse is the response from listing import tasks.
type ListTasksResponse struct {
Count int `json:"count"`
Results []DatasetImportTask `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}
// JudgeResultResponse is the raw response from judging operations.
// The shape depends on the judge type (classification or regression).
type JudgeResultResponse json.RawMessage

5
observability/doc.go Normal file
View File

@@ -0,0 +1,5 @@
// Package observability provides types for the Mistral observability API (beta).
//
// This includes campaigns, chat completion events, judges, and datasets
// for monitoring and evaluating model behavior.
package observability

70
observability/event.go Normal file
View File

@@ -0,0 +1,70 @@
package observability
// ChatCompletionEvent is a full chat completion event.
type ChatCompletionEvent struct {
EventID string `json:"event_id"`
CorrelationID string `json:"correlation_id"`
CreatedAt string `json:"created_at"`
ExtraFields map[string]any `json:"extra_fields,omitempty"`
NbInputTokens int `json:"nb_input_tokens"`
NbOutputTokens int `json:"nb_output_tokens"`
EnabledTools []map[string]any `json:"enabled_tools,omitempty"`
RequestMessages []map[string]any `json:"request_messages,omitempty"`
ResponseMessages []map[string]any `json:"response_messages,omitempty"`
NbMessages int `json:"nb_messages"`
ChatTranscriptionEvents []ChatTranscriptionEvent `json:"chat_transcription_events,omitempty"`
}
// ChatCompletionEventPreview is a summary of a chat completion event.
type ChatCompletionEventPreview struct {
EventID string `json:"event_id"`
CorrelationID string `json:"correlation_id"`
CreatedAt string `json:"created_at"`
ExtraFields map[string]any `json:"extra_fields,omitempty"`
NbInputTokens int `json:"nb_input_tokens"`
NbOutputTokens int `json:"nb_output_tokens"`
}
// ChatTranscriptionEvent is an audio transcription within a chat event.
type ChatTranscriptionEvent struct {
AudioURL string `json:"audio_url"`
Model string `json:"model"`
ResponseMessage map[string]any `json:"response_message"`
}
// SearchEventsRequest is the request body for searching chat completion events.
type SearchEventsRequest struct {
SearchParams FilterPayload `json:"search_params"`
ExtraFields []string `json:"extra_fields,omitempty"`
}
// SearchEventsResponse is the response from searching events.
type SearchEventsResponse struct {
Results []ChatCompletionEventPreview `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Cursor *string `json:"cursor,omitempty"`
}
// SearchEventIDsRequest is the request body for searching event IDs.
type SearchEventIDsRequest struct {
SearchParams FilterPayload `json:"search_params"`
ExtraFields []string `json:"extra_fields,omitempty"`
}
// SearchEventIDsResponse is the response from searching event IDs.
type SearchEventIDsResponse struct {
CompletionEventIDs []string `json:"completion_event_ids"`
}
// JudgeEventRequest is the request body for judging a chat completion event.
type JudgeEventRequest struct {
JudgeDefinition CreateJudgeRequest `json:"judge_definition"`
}
// SimilarEventsResponse is the response from fetching similar events.
type SimilarEventsResponse struct {
Count int `json:"count"`
Results []ChatCompletionEventPreview `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}

75
observability/filter.go Normal file
View File

@@ -0,0 +1,75 @@
package observability
import "encoding/json"
// Op is a filter comparison operator.
type Op string
const (
OpLt Op = "lt"
OpLte Op = "lte"
OpGt Op = "gt"
OpGte Op = "gte"
OpEq Op = "eq"
OpNeq Op = "neq"
OpIsNull Op = "isnull"
OpStartsWith Op = "startswith"
OpIStartsWith Op = "istartswith"
OpEndsWith Op = "endswith"
OpIEndsWith Op = "iendswith"
OpContains Op = "contains"
OpIContains Op = "icontains"
OpMatches Op = "matches"
OpNotContains Op = "notcontains"
OpINotContains Op = "inotcontains"
OpIncludes Op = "includes"
OpExcludes Op = "excludes"
OpLenEq Op = "len_eq"
)
// FilterCondition is a single filter comparison.
type FilterCondition struct {
Field string `json:"field"`
Op Op `json:"op"`
Value any `json:"value"`
}
// FilterGroup combines filters with AND/OR logic.
// The JSON keys are uppercase "AND" / "OR".
type FilterGroup struct {
AND []json.RawMessage `json:"AND,omitempty"`
OR []json.RawMessage `json:"OR,omitempty"`
}
// FilterPayload wraps the top-level filter for search operations.
// Filters can be a FilterGroup or a FilterCondition.
type FilterPayload struct {
Filters json.RawMessage `json:"filters,omitempty"`
}
// TaskStatus is the status of an async task.
type TaskStatus string
const (
TaskStatusRunning TaskStatus = "RUNNING"
TaskStatusCompleted TaskStatus = "COMPLETED"
TaskStatusFailed TaskStatus = "FAILED"
TaskStatusCanceled TaskStatus = "CANCELED"
TaskStatusTerminated TaskStatus = "TERMINATED"
TaskStatusContinuedAsNew TaskStatus = "CONTINUED_AS_NEW"
TaskStatusTimedOut TaskStatus = "TIMED_OUT"
TaskStatusUnknown TaskStatus = "UNKNOWN"
)
// PaginationParams holds common pagination query parameters.
type PaginationParams struct {
Page *int
PageSize *int
}
// SearchParams holds common search query parameters.
type SearchParams struct {
Page *int
PageSize *int
Q *string
}

114
observability/judge.go Normal file
View File

@@ -0,0 +1,114 @@
package observability
import (
"encoding/json"
"fmt"
)
// JudgeOutputType identifies the kind of judge output.
type JudgeOutputType string
const (
JudgeOutputClassification JudgeOutputType = "CLASSIFICATION"
JudgeOutputRegression JudgeOutputType = "REGRESSION"
)
// JudgeOutput is a sealed interface for judge output configurations.
type JudgeOutput interface {
judgeOutputType() JudgeOutputType
}
// ClassificationOutput configures a classification judge.
type ClassificationOutput struct {
Type JudgeOutputType `json:"type"`
Options []ClassificationOption `json:"options"`
}
func (*ClassificationOutput) judgeOutputType() JudgeOutputType { return JudgeOutputClassification }
// ClassificationOption is a single option for a classification judge.
type ClassificationOption struct {
Value string `json:"value"`
Description string `json:"description"`
}
// RegressionOutput configures a regression judge.
type RegressionOutput struct {
Type JudgeOutputType `json:"type"`
MinDescription string `json:"min_description"`
MaxDescription string `json:"max_description"`
Min *float64 `json:"min,omitempty"`
Max *float64 `json:"max,omitempty"`
}
func (*RegressionOutput) judgeOutputType() JudgeOutputType { return JudgeOutputRegression }
// UnmarshalJudgeOutput dispatches to the concrete JudgeOutput type.
func UnmarshalJudgeOutput(data []byte) (JudgeOutput, error) {
var probe struct {
Type string `json:"type"`
}
if err := json.Unmarshal(data, &probe); err != nil {
return nil, fmt.Errorf("unmarshal judge output: %w", err)
}
switch JudgeOutputType(probe.Type) {
case JudgeOutputClassification:
var o ClassificationOutput
return &o, json.Unmarshal(data, &o)
case JudgeOutputRegression:
var o RegressionOutput
return &o, json.Unmarshal(data, &o)
default:
return nil, fmt.Errorf("unknown judge output type: %q", probe.Type)
}
}
// Judge represents a judge entity.
type Judge struct {
ID string `json:"id"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
DeletedAt *string `json:"deleted_at,omitempty"`
OwnerID string `json:"owner_id"`
WorkspaceID string `json:"workspace_id"`
Name string `json:"name"`
Description string `json:"description"`
ModelName string `json:"model_name"`
Output json.RawMessage `json:"output"`
Instructions string `json:"instructions"`
Tools []string `json:"tools,omitempty"`
}
// CreateJudgeRequest creates a new judge.
type CreateJudgeRequest struct {
Name string `json:"name"`
Description string `json:"description"`
ModelName string `json:"model_name"`
Output json.RawMessage `json:"output"`
Instructions string `json:"instructions"`
Tools []string `json:"tools"`
}
// UpdateJudgeRequest updates a judge.
type UpdateJudgeRequest struct {
Name string `json:"name"`
Description string `json:"description"`
ModelName string `json:"model_name"`
Output json.RawMessage `json:"output"`
Instructions string `json:"instructions"`
Tools []string `json:"tools"`
}
// JudgeConversationRequest is the request for live-judging a conversation.
type JudgeConversationRequest struct {
Messages []map[string]any `json:"messages"`
Properties map[string]any `json:"properties,omitempty"`
}
// ListJudgesResponse is the response from listing judges.
type ListJudgesResponse struct {
Count int `json:"count"`
Results []Judge `json:"results,omitempty"`
Next *string `json:"next,omitempty"`
Previous *string `json:"previous,omitempty"`
}

View File

@@ -0,0 +1,97 @@
package mistral
import (
"context"
"fmt"
"net/url"
"strconv"
"somegit.dev/vikingowl/mistral-go-sdk/observability"
)
// CreateCampaign creates a new observability campaign.
func (c *Client) CreateCampaign(ctx context.Context, req *observability.CreateCampaignRequest) (*observability.Campaign, error) {
var resp observability.Campaign
if err := c.doJSON(ctx, "POST", "/v1/observability/campaigns", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListCampaigns lists observability campaigns.
func (c *Client) ListCampaigns(ctx context.Context, params *observability.SearchParams) (*observability.ListCampaignsResponse, error) {
path := "/v1/observability/campaigns"
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if params.Q != nil {
q.Set("q", *params.Q)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListCampaignsResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetCampaign retrieves a campaign by ID.
func (c *Client) GetCampaign(ctx context.Context, campaignID string) (*observability.Campaign, error) {
var resp observability.Campaign
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/campaigns/%s", campaignID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteCampaign deletes a campaign.
func (c *Client) DeleteCampaign(ctx context.Context, campaignID string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/campaigns/%s", campaignID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// GetCampaignStatus retrieves the status of a campaign.
func (c *Client) GetCampaignStatus(ctx context.Context, campaignID string) (*observability.CampaignStatusResponse, error) {
var resp observability.CampaignStatusResponse
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/campaigns/%s/status", campaignID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListCampaignEvents lists events selected by a campaign.
func (c *Client) ListCampaignEvents(ctx context.Context, campaignID string, params *observability.PaginationParams) (*observability.ListCampaignEventsResponse, error) {
path := fmt.Sprintf("/v1/observability/campaigns/%s/selected-events", campaignID)
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListCampaignEventsResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}

View File

@@ -0,0 +1,148 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/observability"
)
func TestCreateCampaign_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" || r.URL.Path != "/v1/observability/campaigns" {
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
var body map[string]any
json.NewDecoder(r.Body).Decode(&body)
if body["name"] != "test-campaign" {
t.Errorf("got name %v", body["name"])
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]any{
"id": "camp-1", "name": "test-campaign", "description": "d",
"created_at": "t", "updated_at": "t", "owner_id": "o",
"workspace_id": "w", "max_nb_events": 100,
"search_params": map[string]any{}, "judge": map[string]any{"id": "j1"},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateCampaign(context.Background(), &observability.CreateCampaignRequest{
Name: "test-campaign",
Description: "d",
JudgeID: "j1",
MaxNbEvents: 100,
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "camp-1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestListCampaigns_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/campaigns" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"count": 1,
"results": []map[string]any{{"id": "c1", "name": "c", "description": "d", "created_at": "t", "updated_at": "t", "owner_id": "o", "workspace_id": "w", "max_nb_events": 10, "search_params": map[string]any{}, "judge": map[string]any{"id": "j"}}},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListCampaigns(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 1 {
t.Errorf("got count %d", resp.Count)
}
}
func TestGetCampaign_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/campaigns/camp-1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "camp-1", "name": "c", "description": "d",
"created_at": "t", "updated_at": "t", "owner_id": "o",
"workspace_id": "w", "max_nb_events": 10,
"search_params": map[string]any{}, "judge": map[string]any{"id": "j"},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetCampaign(context.Background(), "camp-1")
if err != nil {
t.Fatal(err)
}
if resp.ID != "camp-1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestDeleteCampaign_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "DELETE" {
t.Errorf("expected DELETE")
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
if err := client.DeleteCampaign(context.Background(), "camp-1"); err != nil {
t.Fatal(err)
}
}
func TestGetCampaignStatus_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/campaigns/camp-1/status" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{"status": "COMPLETED"})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetCampaignStatus(context.Background(), "camp-1")
if err != nil {
t.Fatal(err)
}
if resp.Status != observability.TaskStatusCompleted {
t.Errorf("got status %q", resp.Status)
}
}
func TestListCampaignEvents_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/campaigns/camp-1/selected-events" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"count": 0,
"results": []any{},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListCampaignEvents(context.Background(), "camp-1", nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 0 {
t.Errorf("got count %d", resp.Count)
}
}

252
observability_datasets.go Normal file
View File

@@ -0,0 +1,252 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/url"
"strconv"
"somegit.dev/vikingowl/mistral-go-sdk/observability"
)
// CreateDataset creates a new observability dataset.
func (c *Client) CreateDataset(ctx context.Context, req *observability.CreateDatasetRequest) (*observability.Dataset, error) {
var resp observability.Dataset
if err := c.doJSON(ctx, "POST", "/v1/observability/datasets", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListDatasets lists observability datasets.
func (c *Client) ListDatasets(ctx context.Context, params *observability.SearchParams) (*observability.ListDatasetsResponse, error) {
path := "/v1/observability/datasets"
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if params.Q != nil {
q.Set("q", *params.Q)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListDatasetsResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetDataset retrieves a dataset by ID.
func (c *Client) GetDataset(ctx context.Context, datasetID string) (*observability.Dataset, error) {
var resp observability.Dataset
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateDataset updates a dataset.
func (c *Client) UpdateDataset(ctx context.Context, datasetID string, req *observability.UpdateDatasetRequest) (*observability.Dataset, error) {
var resp observability.Dataset
if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteDataset deletes a dataset.
func (c *Client) DeleteDataset(ctx context.Context, datasetID string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// ExportDatasetToJSONL exports a dataset to JSONL format.
func (c *Client) ExportDatasetToJSONL(ctx context.Context, datasetID string) (*observability.ExportDatasetResponse, error) {
var resp observability.ExportDatasetResponse
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s/exports/to-jsonl", datasetID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// Dataset records
// ListDatasetRecords lists records in a dataset.
func (c *Client) ListDatasetRecords(ctx context.Context, datasetID string, params *observability.PaginationParams) (*observability.ListRecordsResponse, error) {
path := fmt.Sprintf("/v1/observability/datasets/%s/records", datasetID)
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListRecordsResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// CreateDatasetRecord creates a record in a dataset.
func (c *Client) CreateDatasetRecord(ctx context.Context, datasetID string, req *observability.CreateRecordRequest) (*observability.DatasetRecord, error) {
var resp observability.DatasetRecord
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/records", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetDatasetRecord retrieves a dataset record by ID.
func (c *Client) GetDatasetRecord(ctx context.Context, recordID string) (*observability.DatasetRecord, error) {
var resp observability.DatasetRecord
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/dataset-records/%s", recordID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateDatasetRecordPayload updates a record's payload.
func (c *Client) UpdateDatasetRecordPayload(ctx context.Context, recordID string, req *observability.UpdateRecordPayloadRequest) (*observability.DatasetRecord, error) {
var resp observability.DatasetRecord
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/dataset-records/%s/payload", recordID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateDatasetRecordProperties updates a record's properties.
func (c *Client) UpdateDatasetRecordProperties(ctx context.Context, recordID string, req *observability.UpdateRecordPropertiesRequest) (*observability.DatasetRecord, error) {
var resp observability.DatasetRecord
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/dataset-records/%s/properties", recordID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteDatasetRecord deletes a dataset record.
func (c *Client) DeleteDatasetRecord(ctx context.Context, recordID string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/dataset-records/%s", recordID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// BulkDeleteDatasetRecords deletes multiple dataset records.
func (c *Client) BulkDeleteDatasetRecords(ctx context.Context, req *observability.BulkDeleteRecordsRequest) error {
return c.doJSON(ctx, "POST", "/v1/observability/dataset-records/bulk-delete", req, nil)
}
// JudgeDatasetRecord judges a dataset record.
func (c *Client) JudgeDatasetRecord(ctx context.Context, recordID string, req *observability.JudgeRecordRequest) (json.RawMessage, error) {
var resp json.RawMessage
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/dataset-records/%s/live-judging", recordID), req, &resp); err != nil {
return nil, err
}
return resp, nil
}
// Import operations
// ImportDatasetFromCampaign imports records from a campaign.
func (c *Client) ImportDatasetFromCampaign(ctx context.Context, datasetID string, req *observability.ImportFromCampaignRequest) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-campaign", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ImportDatasetFromExplorer imports records from explorer events.
func (c *Client) ImportDatasetFromExplorer(ctx context.Context, datasetID string, req *observability.ImportFromExplorerRequest) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-explorer", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ImportDatasetFromFile imports records from a file.
func (c *Client) ImportDatasetFromFile(ctx context.Context, datasetID string, req *observability.ImportFromFileRequest) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-file", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ImportDatasetFromPlayground imports records from playground conversations.
func (c *Client) ImportDatasetFromPlayground(ctx context.Context, datasetID string, req *observability.ImportFromPlaygroundRequest) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-playground", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ImportDatasetFromDataset imports records from another dataset.
func (c *Client) ImportDatasetFromDataset(ctx context.Context, datasetID string, req *observability.ImportFromDatasetRequest) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-dataset", datasetID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// Tasks
// ListDatasetTasks lists import tasks for a dataset.
func (c *Client) ListDatasetTasks(ctx context.Context, datasetID string, params *observability.PaginationParams) (*observability.ListTasksResponse, error) {
path := fmt.Sprintf("/v1/observability/datasets/%s/tasks", datasetID)
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListTasksResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetDatasetTask retrieves an import task by ID.
func (c *Client) GetDatasetTask(ctx context.Context, datasetID, taskID string) (*observability.DatasetImportTask, error) {
var resp observability.DatasetImportTask
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s/tasks/%s", datasetID, taskID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}

View File

@@ -0,0 +1,211 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/observability"
)
func datasetJSON() map[string]any {
return map[string]any{
"id": "ds-1", "created_at": "t", "updated_at": "t",
"name": "test-ds", "description": "d",
"owner_id": "o", "workspace_id": "w",
}
}
func TestCreateDataset_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" || r.URL.Path != "/v1/observability/datasets" {
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(datasetJSON())
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateDataset(context.Background(), &observability.CreateDatasetRequest{
Name: "test-ds",
Description: "d",
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "ds-1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestListDatasets_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]any{
"count": 1,
"results": []any{datasetJSON()},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListDatasets(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 1 {
t.Errorf("got count %d", resp.Count)
}
}
func TestGetDataset_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/datasets/ds-1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(datasetJSON())
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetDataset(context.Background(), "ds-1")
if err != nil {
t.Fatal(err)
}
if resp.Name != "test-ds" {
t.Errorf("got name %q", resp.Name)
}
}
func TestDeleteDataset_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
if err := client.DeleteDataset(context.Background(), "ds-1"); err != nil {
t.Fatal(err)
}
}
func TestCreateDatasetRecord_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" || r.URL.Path != "/v1/observability/datasets/ds-1/records" {
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]any{
"id": "rec-1", "created_at": "t", "updated_at": "t",
"dataset_id": "ds-1", "source": "DIRECT_INPUT",
"payload": map[string]any{"messages": []any{}},
"properties": map[string]any{},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateDatasetRecord(context.Background(), "ds-1", &observability.CreateRecordRequest{
Payload: observability.ConversationPayload{Messages: []map[string]any{}},
Properties: map[string]any{},
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "rec-1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestListDatasetRecords_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/datasets/ds-1/records" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"count": 0, "results": []any{},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListDatasetRecords(context.Background(), "ds-1", nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 0 {
t.Errorf("got count %d", resp.Count)
}
}
func TestImportDatasetFromCampaign_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/datasets/ds-1/imports/from-campaign" {
t.Errorf("got path %s", r.URL.Path)
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]any{
"id": "task-1", "created_at": "t", "updated_at": "t",
"creator_id": "u", "dataset_id": "ds-1", "workspace_id": "w",
"status": "RUNNING",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ImportDatasetFromCampaign(context.Background(), "ds-1", &observability.ImportFromCampaignRequest{
CampaignID: "camp-1",
})
if err != nil {
t.Fatal(err)
}
if resp.Status != observability.TaskStatusRunning {
t.Errorf("got status %q", resp.Status)
}
}
func TestExportDatasetToJSONL_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/datasets/ds-1/exports/to-jsonl" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"file_url": "https://storage.example.com/export.jsonl",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ExportDatasetToJSONL(context.Background(), "ds-1")
if err != nil {
t.Fatal(err)
}
if resp.FileURL != "https://storage.example.com/export.jsonl" {
t.Errorf("got file_url %q", resp.FileURL)
}
}
func TestGetDatasetTask_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/datasets/ds-1/tasks/task-1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"id": "task-1", "created_at": "t", "updated_at": "t",
"creator_id": "u", "dataset_id": "ds-1", "workspace_id": "w",
"status": "COMPLETED",
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetDatasetTask(context.Background(), "ds-1", "task-1")
if err != nil {
t.Fatal(err)
}
if resp.Status != observability.TaskStatusCompleted {
t.Errorf("got status %q", resp.Status)
}
}

69
observability_events.go Normal file
View File

@@ -0,0 +1,69 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/url"
"strconv"
"somegit.dev/vikingowl/mistral-go-sdk/observability"
)
// SearchChatCompletionEvents searches for chat completion events.
func (c *Client) SearchChatCompletionEvents(ctx context.Context, req *observability.SearchEventsRequest) (*observability.SearchEventsResponse, error) {
var resp observability.SearchEventsResponse
if err := c.doJSON(ctx, "POST", "/v1/observability/chat-completion-events/search", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// SearchChatCompletionEventIDs searches for chat completion event IDs.
func (c *Client) SearchChatCompletionEventIDs(ctx context.Context, req *observability.SearchEventIDsRequest) (*observability.SearchEventIDsResponse, error) {
var resp observability.SearchEventIDsResponse
if err := c.doJSON(ctx, "POST", "/v1/observability/chat-completion-events/search-ids", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetChatCompletionEvent retrieves a chat completion event by ID.
func (c *Client) GetChatCompletionEvent(ctx context.Context, eventID string) (*observability.ChatCompletionEvent, error) {
var resp observability.ChatCompletionEvent
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/chat-completion-events/%s", eventID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetSimilarChatCompletionEvents retrieves events similar to a given event.
func (c *Client) GetSimilarChatCompletionEvents(ctx context.Context, eventID string, params *observability.PaginationParams) (*observability.SimilarEventsResponse, error) {
path := fmt.Sprintf("/v1/observability/chat-completion-events/%s/similar-events", eventID)
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.SimilarEventsResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// JudgeChatCompletionEvent judges a chat completion event.
func (c *Client) JudgeChatCompletionEvent(ctx context.Context, eventID string, req *observability.JudgeEventRequest) (json.RawMessage, error) {
var resp json.RawMessage
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/chat-completion-events/%s/live-judging", eventID), req, &resp); err != nil {
return nil, err
}
return resp, nil
}

View File

@@ -0,0 +1,101 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/observability"
)
func TestSearchChatCompletionEvents_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" || r.URL.Path != "/v1/observability/chat-completion-events/search" {
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"results": []map[string]any{
{"event_id": "ev-1", "correlation_id": "c1", "created_at": "t", "nb_input_tokens": 10, "nb_output_tokens": 5},
},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.SearchChatCompletionEvents(context.Background(), &observability.SearchEventsRequest{})
if err != nil {
t.Fatal(err)
}
if len(resp.Results) != 1 {
t.Fatalf("got %d results", len(resp.Results))
}
if resp.Results[0].EventID != "ev-1" {
t.Errorf("got event_id %q", resp.Results[0].EventID)
}
}
func TestSearchChatCompletionEventIDs_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/chat-completion-events/search-ids" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"completion_event_ids": []string{"ev-1", "ev-2"},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.SearchChatCompletionEventIDs(context.Background(), &observability.SearchEventIDsRequest{})
if err != nil {
t.Fatal(err)
}
if len(resp.CompletionEventIDs) != 2 {
t.Errorf("got %d ids", len(resp.CompletionEventIDs))
}
}
func TestGetChatCompletionEvent_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/chat-completion-events/ev-1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"event_id": "ev-1", "correlation_id": "c1", "created_at": "t",
"nb_input_tokens": 10, "nb_output_tokens": 5, "nb_messages": 2,
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetChatCompletionEvent(context.Background(), "ev-1")
if err != nil {
t.Fatal(err)
}
if resp.EventID != "ev-1" {
t.Errorf("got event_id %q", resp.EventID)
}
}
func TestGetSimilarChatCompletionEvents_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/chat-completion-events/ev-1/similar-events" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(map[string]any{
"count": 0, "results": []any{},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetSimilarChatCompletionEvents(context.Background(), "ev-1", nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 0 {
t.Errorf("got count %d", resp.Count)
}
}

85
observability_judges.go Normal file
View File

@@ -0,0 +1,85 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/url"
"strconv"
"somegit.dev/vikingowl/mistral-go-sdk/observability"
)
// CreateJudge creates a new observability judge.
func (c *Client) CreateJudge(ctx context.Context, req *observability.CreateJudgeRequest) (*observability.Judge, error) {
var resp observability.Judge
if err := c.doJSON(ctx, "POST", "/v1/observability/judges", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ListJudges lists observability judges.
func (c *Client) ListJudges(ctx context.Context, params *observability.SearchParams) (*observability.ListJudgesResponse, error) {
path := "/v1/observability/judges"
if params != nil {
q := url.Values{}
if params.PageSize != nil {
q.Set("page_size", strconv.Itoa(*params.PageSize))
}
if params.Page != nil {
q.Set("page", strconv.Itoa(*params.Page))
}
if params.Q != nil {
q.Set("q", *params.Q)
}
if encoded := q.Encode(); encoded != "" {
path += "?" + encoded
}
}
var resp observability.ListJudgesResponse
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// GetJudge retrieves a judge by ID.
func (c *Client) GetJudge(ctx context.Context, judgeID string) (*observability.Judge, error) {
var resp observability.Judge
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/judges/%s", judgeID), nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// UpdateJudge updates a judge.
func (c *Client) UpdateJudge(ctx context.Context, judgeID string, req *observability.UpdateJudgeRequest) (*observability.Judge, error) {
var resp observability.Judge
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/judges/%s", judgeID), req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// DeleteJudge deletes a judge.
func (c *Client) DeleteJudge(ctx context.Context, judgeID string) error {
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/judges/%s", judgeID), nil)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode >= 400 {
return parseAPIError(resp)
}
return nil
}
// JudgeConversation performs live judging on a conversation.
func (c *Client) JudgeConversation(ctx context.Context, judgeID string, req *observability.JudgeConversationRequest) (json.RawMessage, error) {
var resp json.RawMessage
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/judges/%s/live-judging", judgeID), req, &resp); err != nil {
return nil, err
}
return resp, nil
}

View File

@@ -0,0 +1,123 @@
package mistral
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/observability"
)
func judgeJSON() map[string]any {
return map[string]any{
"id": "j1", "created_at": "t", "updated_at": "t",
"owner_id": "o", "workspace_id": "w", "name": "quality",
"description": "d", "model_name": "m", "instructions": "i",
"output": map[string]any{"type": "CLASSIFICATION", "options": []any{}},
}
}
func TestCreateJudge_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" || r.URL.Path != "/v1/observability/judges" {
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
}
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(judgeJSON())
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.CreateJudge(context.Background(), &observability.CreateJudgeRequest{
Name: "quality",
Description: "d",
ModelName: "m",
Instructions: "i",
Tools: []string{},
Output: json.RawMessage(`{"type":"CLASSIFICATION","options":[]}`),
})
if err != nil {
t.Fatal(err)
}
if resp.ID != "j1" {
t.Errorf("got id %q", resp.ID)
}
}
func TestListJudges_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
json.NewEncoder(w).Encode(map[string]any{
"count": 1,
"results": []any{judgeJSON()},
})
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.ListJudges(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
if resp.Count != 1 {
t.Errorf("got count %d", resp.Count)
}
}
func TestGetJudge_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/observability/judges/j1" {
t.Errorf("got path %s", r.URL.Path)
}
json.NewEncoder(w).Encode(judgeJSON())
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
resp, err := client.GetJudge(context.Background(), "j1")
if err != nil {
t.Fatal(err)
}
if resp.Name != "quality" {
t.Errorf("got name %q", resp.Name)
}
}
func TestUpdateJudge_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "PUT" {
t.Errorf("expected PUT, got %s", r.Method)
}
json.NewEncoder(w).Encode(judgeJSON())
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
_, err := client.UpdateJudge(context.Background(), "j1", &observability.UpdateJudgeRequest{
Name: "quality",
Description: "d",
ModelName: "m",
Instructions: "i",
Tools: []string{},
Output: json.RawMessage(`{"type":"CLASSIFICATION","options":[]}`),
})
if err != nil {
t.Fatal(err)
}
}
func TestDeleteJudge_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "DELETE" {
t.Errorf("expected DELETE")
}
w.WriteHeader(http.StatusNoContent)
}))
defer server.Close()
client := NewClient("key", WithBaseURL(server.URL))
if err := client.DeleteJudge(context.Background(), "j1"); err != nil {
t.Fatal(err)
}
}