diff --git a/CHANGELOG.md b/CHANGELOG.md index c9f55f0..05a018c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,37 @@ +## v1.1.0 — 2026-03-24 + +Upstream sync with Python SDK v2.1.3. Adds Connectors, Audio Speech/Voices, and Observability (beta). + +### Breaking Changes + +- **`ListModels`** signature changed from `(ctx)` to `(ctx, *model.ListParams)`. + Pass `nil` for previous behavior. The new `ListParams` supports `Provider` and + `Model` query filters. +- **`UploadFile`** signature changed from `(ctx, filename, reader, purpose)` to + `(ctx, filename, reader, *file.UploadParams)`. The new `UploadParams` struct + holds `Purpose`, `Expiry`, and `Visibility` fields. + +### Added + +- **`ReasoningEffort`** field on `chat.CompletionRequest` and + `agents.CompletionRequest` — controls reasoning effort (`"none"`, `"high"`). +- **Connectors API** (new `connector/` package) — `CreateConnector`, + `ListConnectors`, `GetConnector`, `UpdateConnector`, `DeleteConnector`, + `GetConnectorAuthURL`, `ListConnectorTools`, `CallConnectorTool`. +- **Audio Speech (TTS)** — `Speech`, `SpeechStream` with `SpeechStream` typed + wrapper, `SpeechOutputFormat` enum (pcm/wav/mp3/flac/opus). +- **Audio Voices** — `ListVoices`, `CreateVoice`, `GetVoice`, `UpdateVoice`, + `DeleteVoice`, `GetVoiceSampleAudio`. +- **Audio Realtime types** — `AudioEncoding`, `AudioFormat`, `RealtimeSession`, + and WebSocket message types in `audio/realtime.go`. No WebSocket client yet + (would require adding a dependency). +- **Observability API** (new `observability/` package, beta) — campaigns, + chat completion events, judges, datasets, records, and import tasks. + 33 service methods total. +- **`file.Visibility`** enum — `shared_global`, `shared_org`, + `shared_workspace`, `private`. +- **`model.ListParams`** — filter models by `Provider` and `Model`. + ## v1.0.0 — 2026-03-17 Stable release. Tracks upstream Python SDK v2.0.4. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..a575e27 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,89 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project + +Idiomatic Go SDK for the Mistral AI API. Module path: `somegit.dev/vikingowl/mistral-go-sdk`. Requires Go 1.26+. Zero external dependencies (stdlib only). Tracks the upstream [Mistral Python SDK](https://github.com/mistralai/client-python) as reference for API surface and type definitions. + +## Repository layout + +- **Working directory**: `mistral-go-sdk/` — the Go SDK source. All development happens here. +- **`../client-python/`**: Clone of the upstream Mistral Python SDK. Read-only reference — pull/update it when checking for upstream API changes, but never modify it. + +## Commands + +```bash +# Run all unit tests +go test ./... + +# Run a single test +go test -run TestChatComplete_Success + +# Run integration tests (requires MISTRAL_API_KEY env var) +go test -tags=integration ./... + +# Vet and build +go vet ./... +go build ./... +``` + +No Makefile, linter config, or code generation tooling — standard `go test` / `go vet` / `go build`. + +## Architecture + +### Two-layer design: types in sub-packages, methods on `*Client` + +Sub-packages (`chat/`, `agents/`, `conversation/`, `embedding/`, `model/`, `file/`, `finetune/`, `batch/`, `ocr/`, `audio/`, `library/`, `moderation/`, `classification/`, `fim/`) are **types-only** — they define request/response structs and enums but contain no HTTP logic. All service methods live on `*Client` in the root package, prefix-namespaced by domain (e.g. `ChatComplete`, `AgentsComplete`, `CreateFineTuningJob`, `UploadFile`). + +### HTTP internals (request.go) + +All HTTP flows route through a small set of unexported helpers on `*Client`: +- `do()` — raw HTTP with auth headers + retry +- `doJSON()` — JSON marshal request → `do()` → unmarshal response +- `doStream()` — JSON request → raw `*http.Response` for SSE +- `doMultipart()` / `doMultipartStream()` — multipart file upload variants +- `doRetry()` — retry loop with exponential backoff + jitter + `Retry-After` parsing + +### Streaming + +Generic `Stream[T]` type wraps SSE (`sseReader`) with `Next()`/`Current()`/`Err()`/`Close()` iterator pattern. Typed wrappers `EventStream` (conversations) and `AudioStream` (transcription) unmarshal `json.RawMessage` into domain-specific event types. + +### Sealed interfaces for discriminated unions + +Polymorphic API types use **sealed interfaces** with unexported marker methods: +- `chat.Message` (marker: `isMessage()`) — `SystemMessage`, `UserMessage`, `AssistantMessage`, `ToolMessage` +- `chat.ContentChunk` (marker: `contentChunk()`) — `TextChunk`, `ImageURLChunk`, `DocumentURLChunk`, `FileChunk`, `ReferenceChunk`, `ThinkChunk`, `AudioChunk`, `ToolReferenceChunk`, `ToolFileChunk` +- `agents.AgentTool` (marker: `agentToolType()`) — `FunctionTool`, `WebSearchTool`, `CodeInterpreterTool`, `ConnectorTool`, etc. +- `conversation.Event` — conversation streaming events + +Each has an `Unknown*` variant so the SDK doesn't break on new API types. Each has a `Unmarshal*` dispatch function that probes a `type`/`role` discriminator field. + +### Custom JSON patterns + +Several types require non-trivial marshal/unmarshal: +- **Type alias trick** — `type alias T` inside `MarshalJSON` to avoid infinite recursion when injecting a `type`/`role` discriminator field. +- **`json:"-"` + custom MarshalJSON** — `CompletionRequest.Messages` (and `stream`) are excluded from default marshaling and injected via custom `MarshalJSON`. +- **Union types** — `Content` handles `string | null | []ContentChunk`; `ToolChoice` handles `string | object`; `ImageURL` handles `string | object`; `FunctionCall.Arguments` handles `string | object`; `ReferenceID` handles `int | string` with type preservation. +- **Probe struct pattern** — `Unmarshal*` functions decode only the discriminator field first, then dispatch to the concrete type. + +### Shared types in `chat/` + +`GuardrailConfig`, `ModerationLLMV1Config`, `ModerationLLMV2Config` live in `chat/` because it's the base types package imported by both `agents/` and `conversation/`. This avoids import cycles. + +### Error handling + +`APIError` in `error.go` with sentinel checkers: `IsNotFound()`, `IsRateLimit()`, `IsAuth()`. All use `errors.As` for unwrapping. + +## Testing patterns + +- Unit tests use `httptest.NewServer` with inline handlers to mock the Mistral API. Client is pointed at the test server via `WithBaseURL(server.URL)`. +- Integration tests are behind `//go:build integration` build tag and require `MISTRAL_API_KEY`. +- Tests use stdlib `testing` only — no third-party test frameworks. + +## Adding a new API endpoint + +1. Define request/response types in the appropriate sub-package (or create a new one with a `doc.go`). +2. Add a method on `*Client` in the root package. Use `doJSON` for standard request/response, `doStream` for SSE, `doMultipart` for file uploads. +3. Add unit tests with `httptest.NewServer`. +4. If the endpoint supports streaming, return `*Stream[T]` and call `EnableStream()` on the request before sending. diff --git a/README.md b/README.md index 6b6e7f6..b84d83d 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ The most complete Go client for the [Mistral AI API](https://docs.mistral.ai/). **Zero dependencies.** The entire SDK — including tests — uses only the Go standard library. No `go.sum`, no transitive dependency tree to audit, no version conflicts, no supply chain risk. -**Full API coverage.** 75 methods across every Mistral endpoint — including Conversations, Agents CRUD, Libraries, OCR, Audio, Fine-tuning, and Batch Jobs. No other Go SDK covers Conversations or Agents. +**Full API coverage.** 116 methods across every Mistral endpoint — including Connectors, Audio Speech/Voices, Conversations, Agents CRUD, Libraries, OCR, Observability, Fine-tuning, and Batch Jobs. No other Go SDK covers Conversations, Connectors, or Observability. **Typed streaming.** A generic pull-based `Stream[T]` iterator — no channels, no goroutines, no leaks. Just `Next()` / `Current()` / `Err()` / `Close()`. @@ -132,7 +132,7 @@ for stream.Next() { ## API Coverage -75 public methods on `Client`, grouped by domain: +116 public methods on `Client`, grouped by domain: | Domain | Methods | |--------|---------| @@ -140,6 +140,7 @@ for stream.Next() { | **FIM** | `FIMComplete`, `FIMCompleteStream` | | **Agents (completions)** | `AgentsComplete`, `AgentsCompleteStream` | | **Agents (CRUD)** | `CreateAgent`, `ListAgents`, `GetAgent`, `UpdateAgent`, `DeleteAgent`, `UpdateAgentVersion`, `ListAgentVersions`, `GetAgentVersion`, `SetAgentAlias`, `ListAgentAliases`, `DeleteAgentAlias` | +| **Connectors** | `CreateConnector`, `ListConnectors`, `GetConnector`, `UpdateConnector`, `DeleteConnector`, `GetConnectorAuthURL`, `ListConnectorTools`, `CallConnectorTool` | | **Conversations** | `StartConversation`, `StartConversationStream`, `AppendConversation`, `AppendConversationStream`, `RestartConversation`, `RestartConversationStream`, `GetConversation`, `ListConversations`, `DeleteConversation`, `GetConversationHistory`, `GetConversationMessages` | | **Models** | `ListModels`, `GetModel`, `DeleteModel` | | **Files** | `UploadFile`, `ListFiles`, `GetFile`, `DeleteFile`, `GetFileContent`, `GetFileURL` | @@ -147,10 +148,16 @@ for stream.Next() { | **Fine-tuning** | `CreateFineTuningJob`, `ListFineTuningJobs`, `GetFineTuningJob`, `CancelFineTuningJob`, `StartFineTuningJob`, `UpdateFineTunedModel`, `ArchiveFineTunedModel`, `UnarchiveFineTunedModel` | | **Batch** | `CreateBatchJob`, `ListBatchJobs`, `GetBatchJob`, `CancelBatchJob` | | **OCR** | `OCR` | -| **Audio** | `Transcribe`, `TranscribeStream` | +| **Audio (transcription)** | `Transcribe`, `TranscribeStream` | +| **Audio (speech)** | `Speech`, `SpeechStream` | +| **Audio (voices)** | `ListVoices`, `CreateVoice`, `GetVoice`, `UpdateVoice`, `DeleteVoice`, `GetVoiceSampleAudio` | | **Libraries** | `CreateLibrary`, `ListLibraries`, `GetLibrary`, `UpdateLibrary`, `DeleteLibrary`, `UploadDocument`, `ListDocuments`, `GetDocument`, `UpdateDocument`, `DeleteDocument`, `GetDocumentTextContent`, `GetDocumentStatus`, `GetDocumentSignedURL`, `GetDocumentExtractedTextSignedURL`, `ReprocessDocument`, `ListLibrarySharing`, `ShareLibrary`, `UnshareLibrary` | | **Moderation** | `Moderate`, `ModerateChat` | | **Classification** | `Classify`, `ClassifyChat` | +| **Observability (campaigns)** | `CreateCampaign`, `ListCampaigns`, `GetCampaign`, `DeleteCampaign`, `GetCampaignStatus`, `ListCampaignEvents` | +| **Observability (events)** | `SearchChatCompletionEvents`, `SearchChatCompletionEventIDs`, `GetChatCompletionEvent`, `GetSimilarChatCompletionEvents`, `JudgeChatCompletionEvent` | +| **Observability (judges)** | `CreateJudge`, `ListJudges`, `GetJudge`, `UpdateJudge`, `DeleteJudge`, `JudgeConversation` | +| **Observability (datasets)** | `CreateDataset`, `ListDatasets`, `GetDataset`, `UpdateDataset`, `DeleteDataset`, `ExportDatasetToJSONL`, `ListDatasetRecords`, `CreateDatasetRecord`, `GetDatasetRecord`, `UpdateDatasetRecordPayload`, `UpdateDatasetRecordProperties`, `DeleteDatasetRecord`, `BulkDeleteDatasetRecords`, `JudgeDatasetRecord`, `ImportDatasetFromCampaign`, `ImportDatasetFromExplorer`, `ImportDatasetFromFile`, `ImportDatasetFromPlayground`, `ImportDatasetFromDataset`, `ListDatasetTasks`, `GetDatasetTask` | ## Comparison @@ -163,11 +170,13 @@ There is no official Go SDK from Mistral AI (only Python and TypeScript). The ma | Embeddings | Yes | Yes | Yes | Yes | | Tool calling | Yes | No | No | No | | Agents (completions + CRUD) | Yes | No | No | No | +| Connectors (MCP) | Yes | No | No | No | | Conversations API | Yes | No | No | No | | Libraries / Documents | Yes | No | No | No | | Fine-tuning / Batch | Yes | No | No | No | | OCR | Yes | No | No | Yes | -| Audio transcription | Yes | No | No | No | +| Audio (transcription + TTS + voices) | Yes | No | No | No | +| Observability (beta) | Yes | No | No | No | | Moderation / Classification | Yes | No | No | No | | Vision (multimodal) | Yes | No | No | Yes | | Zero dependencies | Yes | test-only (testify) | test-only (testify) | test-only (testify) | @@ -221,6 +230,7 @@ as its upstream reference for API surface and type definitions. | SDK Version | Upstream Python SDK | |-------------|---------------------| +| v1.1.0 | v2.1.3 | | v1.0.0 | v2.0.4 | ## License diff --git a/agents/request.go b/agents/request.go index 735fd35..bba570d 100644 --- a/agents/request.go +++ b/agents/request.go @@ -26,6 +26,7 @@ type CompletionRequest struct { Prediction *chat.Prediction `json:"prediction,omitempty"` PromptMode *chat.PromptMode `json:"prompt_mode,omitempty"` Guardrails []chat.GuardrailConfig `json:"guardrails,omitempty"` + ReasoningEffort *chat.ReasoningEffort `json:"reasoning_effort,omitempty"` stream bool } diff --git a/agents_complete_test.go b/agents_complete_test.go index cbdaf95..2f347ea 100644 --- a/agents_complete_test.go +++ b/agents_complete_test.go @@ -114,6 +114,39 @@ func TestAgentsComplete_WithTools(t *testing.T) { } } +func TestAgentsComplete_ReasoningEffort(t *testing.T) { + effort := chat.ReasoningEffortHigh + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["reasoning_effort"] != "high" { + t.Errorf("expected reasoning_effort=high, got %v", body["reasoning_effort"]) + } + + json.NewEncoder(w).Encode(map[string]any{ + "id": "a-re", "object": "chat.completion", + "model": "m", "created": 0, + "choices": []map[string]any{{ + "index": 0, "message": map[string]any{"role": "assistant", "content": "ok"}, + "finish_reason": "stop", + }}, + "usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + _, err := client.AgentsComplete(context.Background(), &agents.CompletionRequest{ + AgentID: "agent-1", + Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}}, + ReasoningEffort: &effort, + }) + if err != nil { + t.Fatal(err) + } +} + func TestAgentsCompleteStream_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var body map[string]any diff --git a/audio/doc.go b/audio/doc.go index 01211b8..d4a310d 100644 --- a/audio/doc.go +++ b/audio/doc.go @@ -1,5 +1,25 @@ -// Package audio provides types for the Mistral audio transcription API. +// Package audio provides types for the Mistral audio APIs. // +// # Transcription +// +// [TranscriptionRequest] and [TranscriptionResponse] handle speech-to-text. // Streaming transcription returns typed [StreamEvent] values via a sealed // interface dispatched by the "type" field. +// +// # Speech (TTS) +// +// [SpeechRequest] and [SpeechResponse] handle text-to-speech. +// Streaming speech returns typed [SpeechStreamEvent] values +// ([SpeechAudioDelta] and [SpeechDone]). +// +// # Voices +// +// [VoiceResponse], [VoiceCreateRequest], and [VoiceUpdateRequest] manage +// custom voices for speech synthesis. +// +// # Realtime +// +// Realtime transcription types ([AudioEncoding], [AudioFormat], +// [RealtimeSession], and WebSocket message types) are defined here. +// The WebSocket client is not yet implemented. package audio diff --git a/audio/realtime.go b/audio/realtime.go new file mode 100644 index 0000000..a023fd8 --- /dev/null +++ b/audio/realtime.go @@ -0,0 +1,75 @@ +package audio + +// AudioEncoding is the encoding format for realtime audio streams. +type AudioEncoding string + +const ( + EncodingPCMS16LE AudioEncoding = "pcm_s16le" + EncodingPCMS32LE AudioEncoding = "pcm_s32le" + EncodingPCMF16LE AudioEncoding = "pcm_f16le" + EncodingPCMF32LE AudioEncoding = "pcm_f32le" + EncodingPCMMulaw AudioEncoding = "pcm_mulaw" + EncodingPCMAlaw AudioEncoding = "pcm_alaw" +) + +// AudioFormat describes the encoding and sample rate for realtime audio. +type AudioFormat struct { + Encoding AudioEncoding `json:"encoding"` + SampleRate int `json:"sample_rate"` +} + +// RealtimeSession describes a realtime transcription session. +type RealtimeSession struct { + RequestID string `json:"request_id"` + Model string `json:"model"` + AudioFormat AudioFormat `json:"audio_format"` + TargetStreamingDelayMs *int `json:"target_streaming_delay_ms,omitempty"` +} + +// RealtimeSessionUpdate is sent to update session parameters. +// Parameters can only be changed before audio transmission starts. +type RealtimeSessionUpdate struct { + AudioFormat *AudioFormat `json:"audio_format,omitempty"` + TargetStreamingDelayMs *int `json:"target_streaming_delay_ms,omitempty"` +} + +// InputAudioAppend sends a chunk of audio data. +// Audio is base64-encoded (max 262144 bytes decoded). +type InputAudioAppend struct { + Type string `json:"type"` // "input_audio.append" + Audio string `json:"audio"` +} + +// InputAudioFlush flushes the audio buffer. +type InputAudioFlush struct { + Type string `json:"type"` // "input_audio.flush" +} + +// InputAudioEnd signals the end of audio input. +type InputAudioEnd struct { + Type string `json:"type"` // "input_audio.end" +} + +// RealtimeSessionCreated is received when a session is created. +type RealtimeSessionCreated struct { + Type string `json:"type"` // "session.created" + Session RealtimeSession `json:"session"` +} + +// RealtimeSessionUpdated is received when a session is updated. +type RealtimeSessionUpdated struct { + Type string `json:"type"` // "session.updated" + Session RealtimeSession `json:"session"` +} + +// RealtimeErrorDetail describes a realtime error. +type RealtimeErrorDetail struct { + Message string `json:"message"` + Code int `json:"code"` +} + +// RealtimeError is received on error. +type RealtimeError struct { + Type string `json:"type"` // "error" + Error RealtimeErrorDetail `json:"error"` +} diff --git a/audio/speech.go b/audio/speech.go new file mode 100644 index 0000000..6ec99df --- /dev/null +++ b/audio/speech.go @@ -0,0 +1,88 @@ +package audio + +import ( + "encoding/json" + "fmt" +) + +// SpeechOutputFormat is the output audio format for speech synthesis. +type SpeechOutputFormat string + +const ( + SpeechFormatPCM SpeechOutputFormat = "pcm" + SpeechFormatWAV SpeechOutputFormat = "wav" + SpeechFormatMP3 SpeechOutputFormat = "mp3" + SpeechFormatFLAC SpeechOutputFormat = "flac" + SpeechFormatOpus SpeechOutputFormat = "opus" +) + +// SpeechRequest represents a text-to-speech request. +type SpeechRequest struct { + Input string `json:"input"` + Model string `json:"model"` + Metadata map[string]any `json:"metadata,omitempty"` + VoiceID *string `json:"voice_id,omitempty"` + RefAudio *string `json:"ref_audio,omitempty"` + ResponseFormat *SpeechOutputFormat `json:"response_format,omitempty"` + stream bool +} + +// EnableStream is used internally to enable streaming. +func (r *SpeechRequest) EnableStream() { r.stream = true } + +func (r *SpeechRequest) MarshalJSON() ([]byte, error) { + type Alias SpeechRequest + return json.Marshal(&struct { + Stream bool `json:"stream"` + *Alias + }{ + Stream: r.stream, + Alias: (*Alias)(r), + }) +} + +// SpeechResponse is the response from a non-streaming speech request. +type SpeechResponse struct { + AudioData string `json:"audio_data"` +} + +// SpeechStreamEvent is a sealed interface for speech streaming events. +type SpeechStreamEvent interface { + speechStreamEvent() +} + +// SpeechAudioDelta contains a chunk of audio data during streaming. +type SpeechAudioDelta struct { + Type string `json:"type"` + AudioData string `json:"audio_data"` +} + +func (*SpeechAudioDelta) speechStreamEvent() {} + +// SpeechDone is emitted when speech synthesis is complete. +type SpeechDone struct { + Type string `json:"type"` + Usage UsageInfo `json:"usage"` +} + +func (*SpeechDone) speechStreamEvent() {} + +// UnmarshalSpeechStreamEvent dispatches a raw JSON event to the correct type. +func UnmarshalSpeechStreamEvent(data []byte) (SpeechStreamEvent, error) { + var probe struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &probe); err != nil { + return nil, err + } + switch probe.Type { + case "speech.audio.delta": + var e SpeechAudioDelta + return &e, json.Unmarshal(data, &e) + case "speech.audio.done": + var e SpeechDone + return &e, json.Unmarshal(data, &e) + default: + return nil, fmt.Errorf("unknown speech stream event type: %q", probe.Type) + } +} diff --git a/audio/voice.go b/audio/voice.go new file mode 100644 index 0000000..0001bd1 --- /dev/null +++ b/audio/voice.go @@ -0,0 +1,48 @@ +package audio + +// VoiceResponse represents a voice entity. +type VoiceResponse struct { + Name string `json:"name"` + ID string `json:"id"` + CreatedAt string `json:"created_at"` + UserID *string `json:"user_id,omitempty"` + Slug *string `json:"slug,omitempty"` + Languages []string `json:"languages,omitempty"` + Gender *string `json:"gender,omitempty"` + Age *int `json:"age,omitempty"` + Tags []string `json:"tags,omitempty"` + Color *string `json:"color,omitempty"` + RetentionNotice *int `json:"retention_notice,omitempty"` +} + +// VoiceCreateRequest creates a custom voice. +type VoiceCreateRequest struct { + Name string `json:"name"` + SampleAudio string `json:"sample_audio"` + Slug *string `json:"slug,omitempty"` + Languages []string `json:"languages,omitempty"` + Gender *string `json:"gender,omitempty"` + Age *int `json:"age,omitempty"` + Tags []string `json:"tags,omitempty"` + Color *string `json:"color,omitempty"` + RetentionNotice *int `json:"retention_notice,omitempty"` + SampleFilename *string `json:"sample_filename,omitempty"` +} + +// VoiceUpdateRequest updates a voice. +type VoiceUpdateRequest struct { + Name *string `json:"name,omitempty"` + Languages []string `json:"languages,omitempty"` + Gender *string `json:"gender,omitempty"` + Age *int `json:"age,omitempty"` + Tags []string `json:"tags,omitempty"` +} + +// VoiceListResponse is the response from listing voices. +type VoiceListResponse struct { + Items []VoiceResponse `json:"items"` + Total int `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` + TotalPages int `json:"total_pages"` +} diff --git a/audio_api.go b/audio_api.go index e601861..e1ef28b 100644 --- a/audio_api.go +++ b/audio_api.go @@ -3,7 +3,9 @@ package mistral import ( "context" "encoding/json" + "fmt" "io" + "net/http" "somegit.dev/vikingowl/mistral-go-sdk/audio" ) @@ -95,3 +97,125 @@ func (s *AudioStream) Err() error { return s.err } // Close releases the underlying connection. func (s *AudioStream) Close() error { return s.stream.Close() } + +// Speech sends a text-to-speech request and returns the full response. +func (c *Client) Speech(ctx context.Context, req *audio.SpeechRequest) (*audio.SpeechResponse, error) { + var resp audio.SpeechResponse + if err := c.doJSON(ctx, "POST", "/v1/audio/speech", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// SpeechStream sends a text-to-speech request and returns a stream of audio events. +func (c *Client) SpeechStream(ctx context.Context, req *audio.SpeechRequest) (*SpeechStream, error) { + req.EnableStream() + resp, err := c.doStream(ctx, "POST", "/v1/audio/speech", req) + if err != nil { + return nil, err + } + return newSpeechStream(resp.Body), nil +} + +// SpeechStream wraps the generic Stream for speech streaming events. +type SpeechStream struct { + stream *Stream[json.RawMessage] + event audio.SpeechStreamEvent + err error +} + +func newSpeechStream(body readCloser) *SpeechStream { + return &SpeechStream{ + stream: newStream[json.RawMessage](body), + } +} + +// Next advances to the next event. Returns false when done or on error. +func (s *SpeechStream) Next() bool { + if s.err != nil { + return false + } + if !s.stream.Next() { + s.err = s.stream.Err() + return false + } + event, err := audio.UnmarshalSpeechStreamEvent(s.stream.Current()) + if err != nil { + s.err = err + return false + } + s.event = event + return true +} + +// Current returns the most recently read event. +func (s *SpeechStream) Current() audio.SpeechStreamEvent { return s.event } + +// Err returns any error encountered during streaming. +func (s *SpeechStream) Err() error { return s.err } + +// Close releases the underlying connection. +func (s *SpeechStream) Close() error { return s.stream.Close() } + +// ListVoices returns available voices. +func (c *Client) ListVoices(ctx context.Context) (*audio.VoiceListResponse, error) { + var resp audio.VoiceListResponse + if err := c.doJSON(ctx, "GET", "/v1/audio/voices", nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// CreateVoice creates a custom voice. +func (c *Client) CreateVoice(ctx context.Context, req *audio.VoiceCreateRequest) (*audio.VoiceResponse, error) { + var resp audio.VoiceResponse + if err := c.doJSON(ctx, "POST", "/v1/audio/voices", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// GetVoice retrieves a voice by ID. +func (c *Client) GetVoice(ctx context.Context, voiceID string) (*audio.VoiceResponse, error) { + var resp audio.VoiceResponse + if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/audio/voices/%s", voiceID), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// UpdateVoice updates a voice. +func (c *Client) UpdateVoice(ctx context.Context, voiceID string, req *audio.VoiceUpdateRequest) (*audio.VoiceResponse, error) { + var resp audio.VoiceResponse + if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/audio/voices/%s", voiceID), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// DeleteVoice deletes a voice. +func (c *Client) DeleteVoice(ctx context.Context, voiceID string) error { + resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/audio/voices/%s", voiceID), nil) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode >= 400 { + return parseAPIError(resp) + } + return nil +} + +// GetVoiceSampleAudio retrieves the sample audio for a voice. +// Returns the raw HTTP response; the caller must close the body. +func (c *Client) GetVoiceSampleAudio(ctx context.Context, voiceID string) (*http.Response, error) { + resp, err := c.do(ctx, "GET", fmt.Sprintf("/v1/audio/voices/%s/sample", voiceID), nil) + if err != nil { + return nil, err + } + if resp.StatusCode >= 400 { + defer resp.Body.Close() + return nil, parseAPIError(resp) + } + return resp, nil +} diff --git a/audio_speech_test.go b/audio_speech_test.go new file mode 100644 index 0000000..5d4ce89 --- /dev/null +++ b/audio_speech_test.go @@ -0,0 +1,217 @@ +package mistral + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "somegit.dev/vikingowl/mistral-go-sdk/audio" +) + +func TestSpeech_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("expected POST, got %s", r.Method) + } + if r.URL.Path != "/v1/audio/speech" { + t.Errorf("got path %s", r.URL.Path) + } + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["input"] != "Hello world" { + t.Errorf("got input %v", body["input"]) + } + if body["stream"] != false { + t.Errorf("expected stream=false") + } + json.NewEncoder(w).Encode(map[string]any{ + "audio_data": "base64audiodata==", + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.Speech(context.Background(), &audio.SpeechRequest{ + Input: "Hello world", + Model: "mistral-speech", + }) + if err != nil { + t.Fatal(err) + } + if resp.AudioData != "base64audiodata==" { + t.Errorf("got audio_data %q", resp.AudioData) + } +} + +func TestSpeechStream_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["stream"] != true { + t.Errorf("expected stream=true") + } + + w.Header().Set("Content-Type", "text/event-stream") + flusher, _ := w.(http.Flusher) + + delta, _ := json.Marshal(map[string]any{ + "type": "speech.audio.delta", "audio_data": "chunk1==", + }) + fmt.Fprintf(w, "data: %s\n\n", delta) + flusher.Flush() + + done, _ := json.Marshal(map[string]any{ + "type": "speech.audio.done", + "usage": map[string]any{ + "prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15, + }, + }) + fmt.Fprintf(w, "data: %s\n\n", done) + flusher.Flush() + + fmt.Fprint(w, "data: [DONE]\n\n") + flusher.Flush() + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + stream, err := client.SpeechStream(context.Background(), &audio.SpeechRequest{ + Input: "Hi", + Model: "mistral-speech", + }) + if err != nil { + t.Fatal(err) + } + defer stream.Close() + + var events int + for stream.Next() { + events++ + } + if err := stream.Err(); err != nil { + t.Fatal(err) + } + if events != 2 { + t.Errorf("got %d events, want 2", events) + } +} + +func TestListVoices_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/audio/voices" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "items": []map[string]any{ + {"id": "v1", "name": "Default", "created_at": "2025-01-01"}, + }, + "total": 1, "page": 1, "page_size": 10, "total_pages": 1, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.ListVoices(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(resp.Items) != 1 { + t.Fatalf("got %d voices", len(resp.Items)) + } + if resp.Items[0].ID != "v1" { + t.Errorf("got id %q", resp.Items[0].ID) + } +} + +func TestCreateVoice_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("expected POST, got %s", r.Method) + } + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["name"] != "MyVoice" { + t.Errorf("got name %v", body["name"]) + } + json.NewEncoder(w).Encode(map[string]any{ + "id": "v2", "name": "MyVoice", "created_at": "2025-01-01", + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.CreateVoice(context.Background(), &audio.VoiceCreateRequest{ + Name: "MyVoice", + SampleAudio: "base64audio==", + }) + if err != nil { + t.Fatal(err) + } + if resp.ID != "v2" { + t.Errorf("got id %q", resp.ID) + } +} + +func TestGetVoice_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/audio/voices/v1" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "id": "v1", "name": "Default", "created_at": "2025-01-01", + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.GetVoice(context.Background(), "v1") + if err != nil { + t.Fatal(err) + } + if resp.Name != "Default" { + t.Errorf("got name %q", resp.Name) + } +} + +func TestUpdateVoice_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "PATCH" { + t.Errorf("expected PATCH, got %s", r.Method) + } + json.NewEncoder(w).Encode(map[string]any{ + "id": "v1", "name": "Renamed", "created_at": "2025-01-01", + }) + })) + defer server.Close() + + name := "Renamed" + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.UpdateVoice(context.Background(), "v1", &audio.VoiceUpdateRequest{ + Name: &name, + }) + if err != nil { + t.Fatal(err) + } + if resp.Name != "Renamed" { + t.Errorf("got name %q", resp.Name) + } +} + +func TestDeleteVoice_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "DELETE" { + t.Errorf("expected DELETE, got %s", r.Method) + } + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + err := client.DeleteVoice(context.Background(), "v1") + if err != nil { + t.Fatal(err) + } +} diff --git a/chat/request.go b/chat/request.go index e090e5c..8708965 100644 --- a/chat/request.go +++ b/chat/request.go @@ -9,6 +9,14 @@ const ( PromptModeReasoning PromptMode = "reasoning" ) +// ReasoningEffort controls the amount of reasoning effort the model uses. +type ReasoningEffort string + +const ( + ReasoningEffortNone ReasoningEffort = "none" + ReasoningEffortHigh ReasoningEffort = "high" +) + // Prediction provides expected completion content for optimization. type Prediction struct { Type string `json:"type"` @@ -36,6 +44,7 @@ type CompletionRequest struct { Prediction *Prediction `json:"prediction,omitempty"` PromptMode *PromptMode `json:"prompt_mode,omitempty"` Guardrails []GuardrailConfig `json:"guardrails,omitempty"` + ReasoningEffort *ReasoningEffort `json:"reasoning_effort,omitempty"` stream bool } diff --git a/chat_complete_test.go b/chat_complete_test.go index c35d771..ac8994d 100644 --- a/chat_complete_test.go +++ b/chat_complete_test.go @@ -350,6 +350,41 @@ func TestChatComplete_RequestBody(t *testing.T) { } } +func TestChatComplete_ReasoningEffort(t *testing.T) { + effort := chat.ReasoningEffortHigh + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + bodyBytes, _ := io.ReadAll(r.Body) + var body map[string]any + json.Unmarshal(bodyBytes, &body) + + if body["reasoning_effort"] != "high" { + t.Errorf("expected reasoning_effort=high, got %v", body["reasoning_effort"]) + } + + json.NewEncoder(w).Encode(map[string]any{ + "id": "chat-re", "object": "chat.completion", + "model": "m", "created": 0, + "choices": []map[string]any{{ + "index": 0, "message": map[string]any{"role": "assistant", "content": "ok"}, + "finish_reason": "stop", + }}, + "usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + _, err := client.ChatComplete(context.Background(), &chat.CompletionRequest{ + Model: "m", + Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}}, + ReasoningEffort: &effort, + }) + if err != nil { + t.Fatal(err) + } +} + func TestChatComplete_ContextCanceled(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Never responds — context should cancel first diff --git a/connector/connector.go b/connector/connector.go new file mode 100644 index 0000000..15216be --- /dev/null +++ b/connector/connector.go @@ -0,0 +1,100 @@ +package connector + +import "encoding/json" + +// Visibility controls who can see a connector or tool. +type Visibility string + +const ( + VisibilitySharedGlobal Visibility = "shared_global" + VisibilitySharedOrg Visibility = "shared_org" + VisibilitySharedWorkspace Visibility = "shared_workspace" + VisibilityPrivate Visibility = "private" +) + +// AuthData holds OAuth2 client credentials for a connector. +type AuthData struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` +} + +// Connector represents a registered MCP connector. +type Connector struct { + ID string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + CreatedAt string `json:"created_at"` + ModifiedAt string `json:"modified_at"` + Server *string `json:"server,omitempty"` + AuthType *string `json:"auth_type,omitempty"` +} + +// CreateRequest creates a new connector. +type CreateRequest struct { + Name string `json:"name"` + Description string `json:"description"` + Server string `json:"server"` + IconURL *string `json:"icon_url,omitempty"` + Visibility *Visibility `json:"visibility,omitempty"` + Headers map[string]string `json:"headers,omitempty"` + AuthData *AuthData `json:"auth_data,omitempty"` + SystemPrompt *string `json:"system_prompt,omitempty"` +} + +// UpdateRequest updates an existing connector. +type UpdateRequest struct { + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` + IconURL *string `json:"icon_url,omitempty"` + SystemPrompt *string `json:"system_prompt,omitempty"` + Server *string `json:"server,omitempty"` + Headers map[string]string `json:"headers,omitempty"` + AuthData *AuthData `json:"auth_data,omitempty"` + ConnectionConfig map[string]any `json:"connection_config,omitempty"` + ConnectionSecrets map[string]any `json:"connection_secrets,omitempty"` +} + +// AuthURLResponse is the response from getting a connector's OAuth URL. +type AuthURLResponse struct { + AuthURL string `json:"auth_url"` + TTL int `json:"ttl"` +} + +// CallToolRequest is the request body for calling a connector tool. +type CallToolRequest struct { + Arguments map[string]any `json:"arguments,omitempty"` +} + +// CallToolResponse is the response from calling a connector tool. +// Content is left as raw JSON because the upstream API returns a union +// of 5 content types (text, image, audio, resource link, embedded resource). +type CallToolResponse struct { + Content json.RawMessage `json:"content"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// Tool represents a tool exposed by a connector. +type Tool struct { + ID string `json:"id"` + Name string `json:"name"` + Description *string `json:"description,omitempty"` + Visibility Visibility `json:"visibility,omitempty"` + CreatedAt string `json:"created_at,omitempty"` + ModifiedAt string `json:"modified_at,omitempty"` + SystemPrompt *string `json:"system_prompt,omitempty"` + JsonSchema map[string]any `json:"jsonschema,omitempty"` + Active *bool `json:"active,omitempty"` +} + +// ListParams holds query parameters for listing connectors. +type ListParams struct { + Page *int + PageSize *int +} + +// ListToolsParams holds query parameters for listing connector tools. +type ListToolsParams struct { + Page *int + PageSize *int + Refresh *bool +} diff --git a/connector/doc.go b/connector/doc.go new file mode 100644 index 0000000..55f2ccf --- /dev/null +++ b/connector/doc.go @@ -0,0 +1,6 @@ +// Package connector provides types for the Mistral connectors API. +// +// Connectors represent MCP (Model Context Protocol) server integrations. +// Use [CreateRequest] to register a new connector, then use tools +// discovered via the list-tools endpoint in chat or agent completions. +package connector diff --git a/connectors.go b/connectors.go new file mode 100644 index 0000000..9600562 --- /dev/null +++ b/connectors.go @@ -0,0 +1,119 @@ +package mistral + +import ( + "context" + "fmt" + "net/url" + "strconv" + + "somegit.dev/vikingowl/mistral-go-sdk/connector" +) + +// CreateConnector registers a new MCP connector. +func (c *Client) CreateConnector(ctx context.Context, req *connector.CreateRequest) (*connector.Connector, error) { + var resp connector.Connector + if err := c.doJSON(ctx, "POST", "/v1/connectors", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// ListConnectors returns all connectors. +func (c *Client) ListConnectors(ctx context.Context, params *connector.ListParams) ([]connector.Connector, error) { + path := "/v1/connectors" + if params != nil { + q := url.Values{} + if params.Page != nil { + q.Set("page", strconv.Itoa(*params.Page)) + } + if params.PageSize != nil { + q.Set("page_size", strconv.Itoa(*params.PageSize)) + } + if encoded := q.Encode(); encoded != "" { + path += "?" + encoded + } + } + var resp []connector.Connector + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + return resp, nil +} + +// GetConnector retrieves a connector by ID or name. +func (c *Client) GetConnector(ctx context.Context, idOrName string) (*connector.Connector, error) { + var resp connector.Connector + if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/connectors/%s", idOrName), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// UpdateConnector updates an existing connector. +func (c *Client) UpdateConnector(ctx context.Context, idOrName string, req *connector.UpdateRequest) (*connector.Connector, error) { + var resp connector.Connector + if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/connectors/%s", idOrName), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// DeleteConnector deletes a connector. +func (c *Client) DeleteConnector(ctx context.Context, idOrName string) error { + resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/connectors/%s", idOrName), nil) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode >= 400 { + return parseAPIError(resp) + } + return nil +} + +// GetConnectorAuthURL returns the OAuth2 authorization URL for a connector. +func (c *Client) GetConnectorAuthURL(ctx context.Context, idOrName string, appReturnURL *string) (*connector.AuthURLResponse, error) { + path := fmt.Sprintf("/v1/connectors/%s/auth_url", idOrName) + if appReturnURL != nil { + path += "?app_return_url=" + url.QueryEscape(*appReturnURL) + } + var resp connector.AuthURLResponse + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// ListConnectorTools lists tools exposed by a connector. +func (c *Client) ListConnectorTools(ctx context.Context, idOrName string, params *connector.ListToolsParams) ([]connector.Tool, error) { + path := fmt.Sprintf("/v1/connectors/%s/tools", idOrName) + if params != nil { + q := url.Values{} + if params.Page != nil { + q.Set("page", strconv.Itoa(*params.Page)) + } + if params.PageSize != nil { + q.Set("page_size", strconv.Itoa(*params.PageSize)) + } + if params.Refresh != nil { + q.Set("refresh", strconv.FormatBool(*params.Refresh)) + } + if encoded := q.Encode(); encoded != "" { + path += "?" + encoded + } + } + var resp []connector.Tool + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + return resp, nil +} + +// CallConnectorTool invokes a tool on a connector. +func (c *Client) CallConnectorTool(ctx context.Context, idOrName, toolName string, req *connector.CallToolRequest) (*connector.CallToolResponse, error) { + var resp connector.CallToolResponse + if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/connectors/%s/tools/%s/call", idOrName, toolName), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} diff --git a/connectors_test.go b/connectors_test.go new file mode 100644 index 0000000..7e30c58 --- /dev/null +++ b/connectors_test.go @@ -0,0 +1,217 @@ +package mistral + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "somegit.dev/vikingowl/mistral-go-sdk/connector" +) + +func TestCreateConnector_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("expected POST, got %s", r.Method) + } + if r.URL.Path != "/v1/connectors" { + t.Errorf("got path %s", r.URL.Path) + } + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["name"] != "my_connector" { + t.Errorf("got name %v", body["name"]) + } + if body["server"] != "https://mcp.example.com" { + t.Errorf("got server %v", body["server"]) + } + json.NewEncoder(w).Encode(map[string]any{ + "id": "conn-1", "name": "my_connector", + "description": "test", "created_at": "2025-01-01", + "modified_at": "2025-01-01", "server": "https://mcp.example.com", + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.CreateConnector(context.Background(), &connector.CreateRequest{ + Name: "my_connector", + Description: "test", + Server: "https://mcp.example.com", + }) + if err != nil { + t.Fatal(err) + } + if resp.ID != "conn-1" { + t.Errorf("got id %q", resp.ID) + } +} + +func TestListConnectors_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + t.Errorf("expected GET, got %s", r.Method) + } + json.NewEncoder(w).Encode([]map[string]any{ + {"id": "c1", "name": "conn1", "description": "d1", "created_at": "t", "modified_at": "t"}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + list, err := client.ListConnectors(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + if len(list) != 1 { + t.Fatalf("got %d connectors", len(list)) + } + if list[0].ID != "c1" { + t.Errorf("got id %q", list[0].ID) + } +} + +func TestGetConnector_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/connectors/my_conn" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "id": "c1", "name": "my_conn", "description": "d", + "created_at": "t", "modified_at": "t", + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + c, err := client.GetConnector(context.Background(), "my_conn") + if err != nil { + t.Fatal(err) + } + if c.Name != "my_conn" { + t.Errorf("got name %q", c.Name) + } +} + +func TestUpdateConnector_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "PATCH" { + t.Errorf("expected PATCH, got %s", r.Method) + } + json.NewEncoder(w).Encode(map[string]any{ + "id": "c1", "name": "updated", "description": "new desc", + "created_at": "t", "modified_at": "t", + }) + })) + defer server.Close() + + name := "updated" + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.UpdateConnector(context.Background(), "c1", &connector.UpdateRequest{ + Name: &name, + }) + if err != nil { + t.Fatal(err) + } + if resp.Name != "updated" { + t.Errorf("got name %q", resp.Name) + } +} + +func TestDeleteConnector_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "DELETE" { + t.Errorf("expected DELETE, got %s", r.Method) + } + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + err := client.DeleteConnector(context.Background(), "c1") + if err != nil { + t.Fatal(err) + } +} + +func TestGetConnectorAuthURL_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/connectors/c1/auth_url" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "auth_url": "https://oauth.example.com/authorize", + "ttl": 3600, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.GetConnectorAuthURL(context.Background(), "c1", nil) + if err != nil { + t.Fatal(err) + } + if resp.AuthURL != "https://oauth.example.com/authorize" { + t.Errorf("got auth_url %q", resp.AuthURL) + } + if resp.TTL != 3600 { + t.Errorf("got ttl %d", resp.TTL) + } +} + +func TestListConnectorTools_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/connectors/c1/tools" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode([]map[string]any{ + {"id": "t1", "name": "search", "description": "search the web"}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + tools, err := client.ListConnectorTools(context.Background(), "c1", nil) + if err != nil { + t.Fatal(err) + } + if len(tools) != 1 { + t.Fatalf("got %d tools", len(tools)) + } + if tools[0].Name != "search" { + t.Errorf("got name %q", tools[0].Name) + } +} + +func TestCallConnectorTool_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("expected POST, got %s", r.Method) + } + if r.URL.Path != "/v1/connectors/c1/tools/search/call" { + t.Errorf("got path %s", r.URL.Path) + } + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + args := body["arguments"].(map[string]any) + if args["query"] != "hello" { + t.Errorf("got query %v", args["query"]) + } + json.NewEncoder(w).Encode(map[string]any{ + "content": []map[string]any{{"type": "text", "text": "result"}}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.CallConnectorTool(context.Background(), "c1", "search", &connector.CallToolRequest{ + Arguments: map[string]any{"query": "hello"}, + }) + if err != nil { + t.Fatal(err) + } + if resp.Content == nil { + t.Error("expected non-nil content") + } +} diff --git a/doc.go b/doc.go index 2668765..59b52de 100644 --- a/doc.go +++ b/doc.go @@ -39,9 +39,9 @@ // # Sub-packages // // Types are organized into sub-packages by domain: [chat], [agents], -// [conversation], [embedding], [model], [file], [finetune], [batch], -// [ocr], [audio], [library], [moderation], [classification], and [fim]. -// All service methods live directly on [Client]. +// [connector], [conversation], [embedding], [model], [file], [finetune], +// [batch], [ocr], [audio], [library], [moderation], [classification], +// [fim], and [observability]. All service methods live directly on [Client]. // // # Reference // diff --git a/file/file.go b/file/file.go index fa5497a..a6a6091 100644 --- a/file/file.go +++ b/file/file.go @@ -64,6 +64,23 @@ type SignedURL struct { URL string `json:"url"` } +// Visibility controls who can see a file. +type Visibility string + +const ( + VisibilitySharedGlobal Visibility = "shared_global" + VisibilitySharedOrg Visibility = "shared_org" + VisibilitySharedWorkspace Visibility = "shared_workspace" + VisibilityPrivate Visibility = "private" +) + +// UploadParams holds parameters for uploading a file. +type UploadParams struct { + Purpose Purpose + Expiry *int + Visibility *Visibility +} + // ListParams holds optional parameters for listing files. type ListParams struct { Page *int diff --git a/files.go b/files.go index 5f39f1b..4a0ec59 100644 --- a/files.go +++ b/files.go @@ -12,10 +12,18 @@ import ( ) // UploadFile uploads a file for use with fine-tuning, batch, or OCR. -func (c *Client) UploadFile(ctx context.Context, filename string, r io.Reader, purpose file.Purpose) (*file.File, error) { +func (c *Client) UploadFile(ctx context.Context, filename string, r io.Reader, params *file.UploadParams) (*file.File, error) { fields := map[string]string{} - if purpose != "" { - fields["purpose"] = string(purpose) + if params != nil { + if params.Purpose != "" { + fields["purpose"] = string(params.Purpose) + } + if params.Expiry != nil { + fields["expiry"] = strconv.Itoa(*params.Expiry) + } + if params.Visibility != nil { + fields["visibility"] = string(*params.Visibility) + } } var resp file.File if err := c.doMultipart(ctx, "/v1/files", filename, r, fields, &resp); err != nil { diff --git a/files_test.go b/files_test.go index d172311..b30c25f 100644 --- a/files_test.go +++ b/files_test.go @@ -58,7 +58,7 @@ func TestUploadFile_Success(t *testing.T) { context.Background(), "train.jsonl", strings.NewReader(`{"text":"hello"}`), - file.PurposeFineTune, + &file.UploadParams{Purpose: file.PurposeFineTune}, ) if err != nil { t.Fatal(err) @@ -74,6 +74,43 @@ func TestUploadFile_Success(t *testing.T) { } } +func TestUploadFile_WithExpiryAndVisibility(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseMultipartForm(10 << 20); err != nil { + t.Fatal(err) + } + if r.FormValue("purpose") != "fine-tune" { + t.Errorf("got purpose %q", r.FormValue("purpose")) + } + if r.FormValue("expiry") != "48" { + t.Errorf("expected expiry=48, got %q", r.FormValue("expiry")) + } + if r.FormValue("visibility") != "private" { + t.Errorf("expected visibility=private, got %q", r.FormValue("visibility")) + } + + json.NewEncoder(w).Encode(map[string]any{ + "id": "file-ev", "object": "file", "bytes": 10, + "created_at": 1, "filename": "data.jsonl", + "purpose": "fine-tune", "sample_type": "instruct", + "source": "upload", + }) + })) + defer server.Close() + + expiry := 48 + vis := file.VisibilityPrivate + client := NewClient("key", WithBaseURL(server.URL)) + _, err := client.UploadFile(context.Background(), "data.jsonl", strings.NewReader("{}"), &file.UploadParams{ + Purpose: file.PurposeFineTune, + Expiry: &expiry, + Visibility: &vis, + }) + if err != nil { + t.Fatal(err) + } +} + func TestListFiles_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { @@ -260,7 +297,7 @@ func TestUploadFile_Error(t *testing.T) { defer server.Close() client := NewClient("key", WithBaseURL(server.URL)) - _, err := client.UploadFile(context.Background(), "bad.txt", strings.NewReader(""), file.PurposeFineTune) + _, err := client.UploadFile(context.Background(), "bad.txt", strings.NewReader(""), &file.UploadParams{Purpose: file.PurposeFineTune}) if err == nil { t.Fatal("expected error") } diff --git a/integration_test.go b/integration_test.go index 995877e..ee90266 100644 --- a/integration_test.go +++ b/integration_test.go @@ -24,7 +24,7 @@ func integrationClient(t *testing.T) *Client { func TestIntegration_ListModels(t *testing.T) { client := integrationClient(t) - resp, err := client.ListModels(context.Background()) + resp, err := client.ListModels(context.Background(), nil) if err != nil { t.Fatal(err) } diff --git a/mistral.go b/mistral.go index d55bc5a..03d6234 100644 --- a/mistral.go +++ b/mistral.go @@ -6,7 +6,7 @@ import ( ) // Version is the SDK version string. -const Version = "1.0.0" +const Version = "1.1.0" const ( defaultBaseURL = "https://api.mistral.ai" diff --git a/model/model.go b/model/model.go index ae20b50..d5f02c6 100644 --- a/model/model.go +++ b/model/model.go @@ -52,3 +52,9 @@ type DeleteModelOut struct { Object string `json:"object"` Deleted bool `json:"deleted"` } + +// ListParams holds optional parameters for listing models. +type ListParams struct { + Provider *string + Model *string +} diff --git a/models.go b/models.go index 3de3265..c1868be 100644 --- a/models.go +++ b/models.go @@ -2,14 +2,28 @@ package mistral import ( "context" + "net/url" "somegit.dev/vikingowl/mistral-go-sdk/model" ) // ListModels returns a list of available models. -func (c *Client) ListModels(ctx context.Context) (*model.ModelList, error) { +func (c *Client) ListModels(ctx context.Context, params *model.ListParams) (*model.ModelList, error) { + path := "/v1/models" + if params != nil { + q := url.Values{} + if params.Provider != nil { + q.Set("provider", *params.Provider) + } + if params.Model != nil { + q.Set("model", *params.Model) + } + if encoded := q.Encode(); encoded != "" { + path += "?" + encoded + } + } var resp model.ModelList - if err := c.doJSON(ctx, "GET", "/v1/models", nil, &resp); err != nil { + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { return nil, err } return &resp, nil diff --git a/models_test.go b/models_test.go index 9198242..cdc389d 100644 --- a/models_test.go +++ b/models_test.go @@ -6,6 +6,8 @@ import ( "net/http" "net/http/httptest" "testing" + + "somegit.dev/vikingowl/mistral-go-sdk/model" ) func TestListModels_Success(t *testing.T) { @@ -45,7 +47,7 @@ func TestListModels_Success(t *testing.T) { defer server.Close() client := NewClient("key", WithBaseURL(server.URL)) - list, err := client.ListModels(context.Background()) + list, err := client.ListModels(context.Background(), nil) if err != nil { t.Fatal(err) } @@ -82,6 +84,33 @@ func TestListModels_Success(t *testing.T) { } } +func TestListModels_WithParams(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("provider") != "mistralai" { + t.Errorf("expected provider=mistralai, got %q", r.URL.Query().Get("provider")) + } + if r.URL.Query().Get("model") != "mistral-small" { + t.Errorf("expected model=mistral-small, got %q", r.URL.Query().Get("model")) + } + json.NewEncoder(w).Encode(map[string]any{ + "object": "list", + "data": []map[string]any{}, + }) + })) + defer server.Close() + + provider := "mistralai" + modelName := "mistral-small" + client := NewClient("key", WithBaseURL(server.URL)) + _, err := client.ListModels(context.Background(), &model.ListParams{ + Provider: &provider, + Model: &modelName, + }) + if err != nil { + t.Fatal(err) + } +} + func TestGetModel_Success(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/models/mistral-small-latest" { diff --git a/observability/campaign.go b/observability/campaign.go new file mode 100644 index 0000000..620f0f8 --- /dev/null +++ b/observability/campaign.go @@ -0,0 +1,46 @@ +package observability + +// Campaign represents an observability campaign. +type Campaign struct { + ID string `json:"id"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + DeletedAt *string `json:"deleted_at,omitempty"` + Name string `json:"name"` + OwnerID string `json:"owner_id"` + WorkspaceID string `json:"workspace_id"` + Description string `json:"description"` + MaxNbEvents int `json:"max_nb_events"` + SearchParams FilterPayload `json:"search_params"` + Judge Judge `json:"judge"` +} + +// CreateCampaignRequest creates a new campaign. +type CreateCampaignRequest struct { + SearchParams FilterPayload `json:"search_params"` + JudgeID string `json:"judge_id"` + Name string `json:"name"` + Description string `json:"description"` + MaxNbEvents int `json:"max_nb_events"` +} + +// CampaignStatusResponse is the response for campaign status. +type CampaignStatusResponse struct { + Status TaskStatus `json:"status"` +} + +// ListCampaignsResponse is the response from listing campaigns. +type ListCampaignsResponse struct { + Count int `json:"count"` + Results []Campaign `json:"results,omitempty"` + Next *string `json:"next,omitempty"` + Previous *string `json:"previous,omitempty"` +} + +// ListCampaignEventsResponse is the response from listing campaign events. +type ListCampaignEventsResponse struct { + Count int `json:"count"` + Results []ChatCompletionEventPreview `json:"results,omitempty"` + Next *string `json:"next,omitempty"` + Previous *string `json:"previous,omitempty"` +} diff --git a/observability/dataset.go b/observability/dataset.go new file mode 100644 index 0000000..56001eb --- /dev/null +++ b/observability/dataset.go @@ -0,0 +1,156 @@ +package observability + +import "encoding/json" + +// ConversationSource indicates how a dataset record was created. +type ConversationSource string + +const ( + SourceExplorer ConversationSource = "EXPLORER" + SourceUploadedFile ConversationSource = "UPLOADED_FILE" + SourceDirectInput ConversationSource = "DIRECT_INPUT" + SourcePlayground ConversationSource = "PLAYGROUND" +) + +// Dataset represents a dataset entity. +type Dataset struct { + ID string `json:"id"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + DeletedAt *string `json:"deleted_at,omitempty"` + Name string `json:"name"` + Description string `json:"description"` + OwnerID string `json:"owner_id"` + WorkspaceID string `json:"workspace_id"` +} + +// CreateDatasetRequest creates a new dataset. +type CreateDatasetRequest struct { + Name string `json:"name"` + Description string `json:"description"` +} + +// UpdateDatasetRequest updates a dataset. +type UpdateDatasetRequest struct { + Name *string `json:"name,omitempty"` + Description *string `json:"description,omitempty"` +} + +// DatasetRecord is a single record in a dataset. +type DatasetRecord struct { + ID string `json:"id"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + DeletedAt *string `json:"deleted_at,omitempty"` + DatasetID string `json:"dataset_id"` + Payload ConversationPayload `json:"payload"` + Properties map[string]any `json:"properties,omitempty"` + Source ConversationSource `json:"source"` +} + +// ConversationPayload holds the messages for a dataset record. +type ConversationPayload struct { + Messages []map[string]any `json:"messages"` +} + +// CreateRecordRequest creates a new dataset record. +type CreateRecordRequest struct { + Payload ConversationPayload `json:"payload"` + Properties map[string]any `json:"properties"` +} + +// UpdateRecordPayloadRequest updates a record's payload. +type UpdateRecordPayloadRequest struct { + Payload ConversationPayload `json:"payload"` +} + +// UpdateRecordPropertiesRequest updates a record's properties. +type UpdateRecordPropertiesRequest struct { + Properties map[string]any `json:"properties"` +} + +// BulkDeleteRecordsRequest deletes multiple records. +type BulkDeleteRecordsRequest struct { + DatasetRecordIDs []string `json:"dataset_record_ids"` +} + +// JudgeRecordRequest judges a dataset record. +type JudgeRecordRequest struct { + JudgeDefinition CreateJudgeRequest `json:"judge_definition"` +} + +// DatasetImportTask tracks an async import operation. +type DatasetImportTask struct { + ID string `json:"id"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + DeletedAt *string `json:"deleted_at,omitempty"` + CreatorID string `json:"creator_id"` + DatasetID string `json:"dataset_id"` + WorkspaceID string `json:"workspace_id"` + Status TaskStatus `json:"status"` + Progress *int `json:"progress,omitempty"` + Message *string `json:"message,omitempty"` +} + +// ExportDatasetResponse is the response from exporting a dataset. +type ExportDatasetResponse struct { + FileURL string `json:"file_url"` +} + +// Import request types. + +// ImportFromCampaignRequest imports records from a campaign. +type ImportFromCampaignRequest struct { + CampaignID string `json:"campaign_id"` +} + +// ImportFromExplorerRequest imports records from explorer events. +type ImportFromExplorerRequest struct { + CompletionEventIDs []string `json:"completion_event_ids"` +} + +// ImportFromFileRequest imports records from a file. +type ImportFromFileRequest struct { + FileID string `json:"file_id"` +} + +// ImportFromPlaygroundRequest imports records from playground conversations. +type ImportFromPlaygroundRequest struct { + ConversationIDs []string `json:"conversation_ids"` +} + +// ImportFromDatasetRequest imports records from another dataset. +type ImportFromDatasetRequest struct { + DatasetRecordIDs []string `json:"dataset_record_ids"` +} + +// List response types. + +// ListDatasetsResponse is the response from listing datasets. +type ListDatasetsResponse struct { + Count int `json:"count"` + Results []Dataset `json:"results,omitempty"` + Next *string `json:"next,omitempty"` + Previous *string `json:"previous,omitempty"` +} + +// ListRecordsResponse is the response from listing dataset records. +type ListRecordsResponse struct { + Count int `json:"count"` + Results []DatasetRecord `json:"results,omitempty"` + Next *string `json:"next,omitempty"` + Previous *string `json:"previous,omitempty"` +} + +// ListTasksResponse is the response from listing import tasks. +type ListTasksResponse struct { + Count int `json:"count"` + Results []DatasetImportTask `json:"results,omitempty"` + Next *string `json:"next,omitempty"` + Previous *string `json:"previous,omitempty"` +} + +// JudgeResultResponse is the raw response from judging operations. +// The shape depends on the judge type (classification or regression). +type JudgeResultResponse json.RawMessage diff --git a/observability/doc.go b/observability/doc.go new file mode 100644 index 0000000..ea8c99c --- /dev/null +++ b/observability/doc.go @@ -0,0 +1,5 @@ +// Package observability provides types for the Mistral observability API (beta). +// +// This includes campaigns, chat completion events, judges, and datasets +// for monitoring and evaluating model behavior. +package observability diff --git a/observability/event.go b/observability/event.go new file mode 100644 index 0000000..df8f9c4 --- /dev/null +++ b/observability/event.go @@ -0,0 +1,70 @@ +package observability + +// ChatCompletionEvent is a full chat completion event. +type ChatCompletionEvent struct { + EventID string `json:"event_id"` + CorrelationID string `json:"correlation_id"` + CreatedAt string `json:"created_at"` + ExtraFields map[string]any `json:"extra_fields,omitempty"` + NbInputTokens int `json:"nb_input_tokens"` + NbOutputTokens int `json:"nb_output_tokens"` + EnabledTools []map[string]any `json:"enabled_tools,omitempty"` + RequestMessages []map[string]any `json:"request_messages,omitempty"` + ResponseMessages []map[string]any `json:"response_messages,omitempty"` + NbMessages int `json:"nb_messages"` + ChatTranscriptionEvents []ChatTranscriptionEvent `json:"chat_transcription_events,omitempty"` +} + +// ChatCompletionEventPreview is a summary of a chat completion event. +type ChatCompletionEventPreview struct { + EventID string `json:"event_id"` + CorrelationID string `json:"correlation_id"` + CreatedAt string `json:"created_at"` + ExtraFields map[string]any `json:"extra_fields,omitempty"` + NbInputTokens int `json:"nb_input_tokens"` + NbOutputTokens int `json:"nb_output_tokens"` +} + +// ChatTranscriptionEvent is an audio transcription within a chat event. +type ChatTranscriptionEvent struct { + AudioURL string `json:"audio_url"` + Model string `json:"model"` + ResponseMessage map[string]any `json:"response_message"` +} + +// SearchEventsRequest is the request body for searching chat completion events. +type SearchEventsRequest struct { + SearchParams FilterPayload `json:"search_params"` + ExtraFields []string `json:"extra_fields,omitempty"` +} + +// SearchEventsResponse is the response from searching events. +type SearchEventsResponse struct { + Results []ChatCompletionEventPreview `json:"results,omitempty"` + Next *string `json:"next,omitempty"` + Cursor *string `json:"cursor,omitempty"` +} + +// SearchEventIDsRequest is the request body for searching event IDs. +type SearchEventIDsRequest struct { + SearchParams FilterPayload `json:"search_params"` + ExtraFields []string `json:"extra_fields,omitempty"` +} + +// SearchEventIDsResponse is the response from searching event IDs. +type SearchEventIDsResponse struct { + CompletionEventIDs []string `json:"completion_event_ids"` +} + +// JudgeEventRequest is the request body for judging a chat completion event. +type JudgeEventRequest struct { + JudgeDefinition CreateJudgeRequest `json:"judge_definition"` +} + +// SimilarEventsResponse is the response from fetching similar events. +type SimilarEventsResponse struct { + Count int `json:"count"` + Results []ChatCompletionEventPreview `json:"results,omitempty"` + Next *string `json:"next,omitempty"` + Previous *string `json:"previous,omitempty"` +} diff --git a/observability/filter.go b/observability/filter.go new file mode 100644 index 0000000..5123cc7 --- /dev/null +++ b/observability/filter.go @@ -0,0 +1,75 @@ +package observability + +import "encoding/json" + +// Op is a filter comparison operator. +type Op string + +const ( + OpLt Op = "lt" + OpLte Op = "lte" + OpGt Op = "gt" + OpGte Op = "gte" + OpEq Op = "eq" + OpNeq Op = "neq" + OpIsNull Op = "isnull" + OpStartsWith Op = "startswith" + OpIStartsWith Op = "istartswith" + OpEndsWith Op = "endswith" + OpIEndsWith Op = "iendswith" + OpContains Op = "contains" + OpIContains Op = "icontains" + OpMatches Op = "matches" + OpNotContains Op = "notcontains" + OpINotContains Op = "inotcontains" + OpIncludes Op = "includes" + OpExcludes Op = "excludes" + OpLenEq Op = "len_eq" +) + +// FilterCondition is a single filter comparison. +type FilterCondition struct { + Field string `json:"field"` + Op Op `json:"op"` + Value any `json:"value"` +} + +// FilterGroup combines filters with AND/OR logic. +// The JSON keys are uppercase "AND" / "OR". +type FilterGroup struct { + AND []json.RawMessage `json:"AND,omitempty"` + OR []json.RawMessage `json:"OR,omitempty"` +} + +// FilterPayload wraps the top-level filter for search operations. +// Filters can be a FilterGroup or a FilterCondition. +type FilterPayload struct { + Filters json.RawMessage `json:"filters,omitempty"` +} + +// TaskStatus is the status of an async task. +type TaskStatus string + +const ( + TaskStatusRunning TaskStatus = "RUNNING" + TaskStatusCompleted TaskStatus = "COMPLETED" + TaskStatusFailed TaskStatus = "FAILED" + TaskStatusCanceled TaskStatus = "CANCELED" + TaskStatusTerminated TaskStatus = "TERMINATED" + TaskStatusContinuedAsNew TaskStatus = "CONTINUED_AS_NEW" + TaskStatusTimedOut TaskStatus = "TIMED_OUT" + TaskStatusUnknown TaskStatus = "UNKNOWN" +) + +// PaginationParams holds common pagination query parameters. +type PaginationParams struct { + Page *int + PageSize *int +} + +// SearchParams holds common search query parameters. +type SearchParams struct { + Page *int + PageSize *int + Q *string +} diff --git a/observability/judge.go b/observability/judge.go new file mode 100644 index 0000000..474a232 --- /dev/null +++ b/observability/judge.go @@ -0,0 +1,114 @@ +package observability + +import ( + "encoding/json" + "fmt" +) + +// JudgeOutputType identifies the kind of judge output. +type JudgeOutputType string + +const ( + JudgeOutputClassification JudgeOutputType = "CLASSIFICATION" + JudgeOutputRegression JudgeOutputType = "REGRESSION" +) + +// JudgeOutput is a sealed interface for judge output configurations. +type JudgeOutput interface { + judgeOutputType() JudgeOutputType +} + +// ClassificationOutput configures a classification judge. +type ClassificationOutput struct { + Type JudgeOutputType `json:"type"` + Options []ClassificationOption `json:"options"` +} + +func (*ClassificationOutput) judgeOutputType() JudgeOutputType { return JudgeOutputClassification } + +// ClassificationOption is a single option for a classification judge. +type ClassificationOption struct { + Value string `json:"value"` + Description string `json:"description"` +} + +// RegressionOutput configures a regression judge. +type RegressionOutput struct { + Type JudgeOutputType `json:"type"` + MinDescription string `json:"min_description"` + MaxDescription string `json:"max_description"` + Min *float64 `json:"min,omitempty"` + Max *float64 `json:"max,omitempty"` +} + +func (*RegressionOutput) judgeOutputType() JudgeOutputType { return JudgeOutputRegression } + +// UnmarshalJudgeOutput dispatches to the concrete JudgeOutput type. +func UnmarshalJudgeOutput(data []byte) (JudgeOutput, error) { + var probe struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &probe); err != nil { + return nil, fmt.Errorf("unmarshal judge output: %w", err) + } + switch JudgeOutputType(probe.Type) { + case JudgeOutputClassification: + var o ClassificationOutput + return &o, json.Unmarshal(data, &o) + case JudgeOutputRegression: + var o RegressionOutput + return &o, json.Unmarshal(data, &o) + default: + return nil, fmt.Errorf("unknown judge output type: %q", probe.Type) + } +} + +// Judge represents a judge entity. +type Judge struct { + ID string `json:"id"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + DeletedAt *string `json:"deleted_at,omitempty"` + OwnerID string `json:"owner_id"` + WorkspaceID string `json:"workspace_id"` + Name string `json:"name"` + Description string `json:"description"` + ModelName string `json:"model_name"` + Output json.RawMessage `json:"output"` + Instructions string `json:"instructions"` + Tools []string `json:"tools,omitempty"` +} + +// CreateJudgeRequest creates a new judge. +type CreateJudgeRequest struct { + Name string `json:"name"` + Description string `json:"description"` + ModelName string `json:"model_name"` + Output json.RawMessage `json:"output"` + Instructions string `json:"instructions"` + Tools []string `json:"tools"` +} + +// UpdateJudgeRequest updates a judge. +type UpdateJudgeRequest struct { + Name string `json:"name"` + Description string `json:"description"` + ModelName string `json:"model_name"` + Output json.RawMessage `json:"output"` + Instructions string `json:"instructions"` + Tools []string `json:"tools"` +} + +// JudgeConversationRequest is the request for live-judging a conversation. +type JudgeConversationRequest struct { + Messages []map[string]any `json:"messages"` + Properties map[string]any `json:"properties,omitempty"` +} + +// ListJudgesResponse is the response from listing judges. +type ListJudgesResponse struct { + Count int `json:"count"` + Results []Judge `json:"results,omitempty"` + Next *string `json:"next,omitempty"` + Previous *string `json:"previous,omitempty"` +} diff --git a/observability_campaigns.go b/observability_campaigns.go new file mode 100644 index 0000000..28ab679 --- /dev/null +++ b/observability_campaigns.go @@ -0,0 +1,97 @@ +package mistral + +import ( + "context" + "fmt" + "net/url" + "strconv" + + "somegit.dev/vikingowl/mistral-go-sdk/observability" +) + +// CreateCampaign creates a new observability campaign. +func (c *Client) CreateCampaign(ctx context.Context, req *observability.CreateCampaignRequest) (*observability.Campaign, error) { + var resp observability.Campaign + if err := c.doJSON(ctx, "POST", "/v1/observability/campaigns", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// ListCampaigns lists observability campaigns. +func (c *Client) ListCampaigns(ctx context.Context, params *observability.SearchParams) (*observability.ListCampaignsResponse, error) { + path := "/v1/observability/campaigns" + if params != nil { + q := url.Values{} + if params.PageSize != nil { + q.Set("page_size", strconv.Itoa(*params.PageSize)) + } + if params.Page != nil { + q.Set("page", strconv.Itoa(*params.Page)) + } + if params.Q != nil { + q.Set("q", *params.Q) + } + if encoded := q.Encode(); encoded != "" { + path += "?" + encoded + } + } + var resp observability.ListCampaignsResponse + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// GetCampaign retrieves a campaign by ID. +func (c *Client) GetCampaign(ctx context.Context, campaignID string) (*observability.Campaign, error) { + var resp observability.Campaign + if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/campaigns/%s", campaignID), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// DeleteCampaign deletes a campaign. +func (c *Client) DeleteCampaign(ctx context.Context, campaignID string) error { + resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/campaigns/%s", campaignID), nil) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode >= 400 { + return parseAPIError(resp) + } + return nil +} + +// GetCampaignStatus retrieves the status of a campaign. +func (c *Client) GetCampaignStatus(ctx context.Context, campaignID string) (*observability.CampaignStatusResponse, error) { + var resp observability.CampaignStatusResponse + if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/campaigns/%s/status", campaignID), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// ListCampaignEvents lists events selected by a campaign. +func (c *Client) ListCampaignEvents(ctx context.Context, campaignID string, params *observability.PaginationParams) (*observability.ListCampaignEventsResponse, error) { + path := fmt.Sprintf("/v1/observability/campaigns/%s/selected-events", campaignID) + if params != nil { + q := url.Values{} + if params.PageSize != nil { + q.Set("page_size", strconv.Itoa(*params.PageSize)) + } + if params.Page != nil { + q.Set("page", strconv.Itoa(*params.Page)) + } + if encoded := q.Encode(); encoded != "" { + path += "?" + encoded + } + } + var resp observability.ListCampaignEventsResponse + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} diff --git a/observability_campaigns_test.go b/observability_campaigns_test.go new file mode 100644 index 0000000..1a9f12b --- /dev/null +++ b/observability_campaigns_test.go @@ -0,0 +1,148 @@ +package mistral + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "somegit.dev/vikingowl/mistral-go-sdk/observability" +) + +func TestCreateCampaign_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" || r.URL.Path != "/v1/observability/campaigns" { + t.Errorf("unexpected %s %s", r.Method, r.URL.Path) + } + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["name"] != "test-campaign" { + t.Errorf("got name %v", body["name"]) + } + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(map[string]any{ + "id": "camp-1", "name": "test-campaign", "description": "d", + "created_at": "t", "updated_at": "t", "owner_id": "o", + "workspace_id": "w", "max_nb_events": 100, + "search_params": map[string]any{}, "judge": map[string]any{"id": "j1"}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.CreateCampaign(context.Background(), &observability.CreateCampaignRequest{ + Name: "test-campaign", + Description: "d", + JudgeID: "j1", + MaxNbEvents: 100, + }) + if err != nil { + t.Fatal(err) + } + if resp.ID != "camp-1" { + t.Errorf("got id %q", resp.ID) + } +} + +func TestListCampaigns_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/observability/campaigns" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "count": 1, + "results": []map[string]any{{"id": "c1", "name": "c", "description": "d", "created_at": "t", "updated_at": "t", "owner_id": "o", "workspace_id": "w", "max_nb_events": 10, "search_params": map[string]any{}, "judge": map[string]any{"id": "j"}}}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.ListCampaigns(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + if resp.Count != 1 { + t.Errorf("got count %d", resp.Count) + } +} + +func TestGetCampaign_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/observability/campaigns/camp-1" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "id": "camp-1", "name": "c", "description": "d", + "created_at": "t", "updated_at": "t", "owner_id": "o", + "workspace_id": "w", "max_nb_events": 10, + "search_params": map[string]any{}, "judge": map[string]any{"id": "j"}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.GetCampaign(context.Background(), "camp-1") + if err != nil { + t.Fatal(err) + } + if resp.ID != "camp-1" { + t.Errorf("got id %q", resp.ID) + } +} + +func TestDeleteCampaign_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "DELETE" { + t.Errorf("expected DELETE") + } + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + if err := client.DeleteCampaign(context.Background(), "camp-1"); err != nil { + t.Fatal(err) + } +} + +func TestGetCampaignStatus_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/observability/campaigns/camp-1/status" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{"status": "COMPLETED"}) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.GetCampaignStatus(context.Background(), "camp-1") + if err != nil { + t.Fatal(err) + } + if resp.Status != observability.TaskStatusCompleted { + t.Errorf("got status %q", resp.Status) + } +} + +func TestListCampaignEvents_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/observability/campaigns/camp-1/selected-events" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "count": 0, + "results": []any{}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.ListCampaignEvents(context.Background(), "camp-1", nil) + if err != nil { + t.Fatal(err) + } + if resp.Count != 0 { + t.Errorf("got count %d", resp.Count) + } +} diff --git a/observability_datasets.go b/observability_datasets.go new file mode 100644 index 0000000..fb999da --- /dev/null +++ b/observability_datasets.go @@ -0,0 +1,252 @@ +package mistral + +import ( + "context" + "encoding/json" + "fmt" + "net/url" + "strconv" + + "somegit.dev/vikingowl/mistral-go-sdk/observability" +) + +// CreateDataset creates a new observability dataset. +func (c *Client) CreateDataset(ctx context.Context, req *observability.CreateDatasetRequest) (*observability.Dataset, error) { + var resp observability.Dataset + if err := c.doJSON(ctx, "POST", "/v1/observability/datasets", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// ListDatasets lists observability datasets. +func (c *Client) ListDatasets(ctx context.Context, params *observability.SearchParams) (*observability.ListDatasetsResponse, error) { + path := "/v1/observability/datasets" + if params != nil { + q := url.Values{} + if params.PageSize != nil { + q.Set("page_size", strconv.Itoa(*params.PageSize)) + } + if params.Page != nil { + q.Set("page", strconv.Itoa(*params.Page)) + } + if params.Q != nil { + q.Set("q", *params.Q) + } + if encoded := q.Encode(); encoded != "" { + path += "?" + encoded + } + } + var resp observability.ListDatasetsResponse + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// GetDataset retrieves a dataset by ID. +func (c *Client) GetDataset(ctx context.Context, datasetID string) (*observability.Dataset, error) { + var resp observability.Dataset + if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// UpdateDataset updates a dataset. +func (c *Client) UpdateDataset(ctx context.Context, datasetID string, req *observability.UpdateDatasetRequest) (*observability.Dataset, error) { + var resp observability.Dataset + if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// DeleteDataset deletes a dataset. +func (c *Client) DeleteDataset(ctx context.Context, datasetID string) error { + resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), nil) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode >= 400 { + return parseAPIError(resp) + } + return nil +} + +// ExportDatasetToJSONL exports a dataset to JSONL format. +func (c *Client) ExportDatasetToJSONL(ctx context.Context, datasetID string) (*observability.ExportDatasetResponse, error) { + var resp observability.ExportDatasetResponse + if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s/exports/to-jsonl", datasetID), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// Dataset records + +// ListDatasetRecords lists records in a dataset. +func (c *Client) ListDatasetRecords(ctx context.Context, datasetID string, params *observability.PaginationParams) (*observability.ListRecordsResponse, error) { + path := fmt.Sprintf("/v1/observability/datasets/%s/records", datasetID) + if params != nil { + q := url.Values{} + if params.PageSize != nil { + q.Set("page_size", strconv.Itoa(*params.PageSize)) + } + if params.Page != nil { + q.Set("page", strconv.Itoa(*params.Page)) + } + if encoded := q.Encode(); encoded != "" { + path += "?" + encoded + } + } + var resp observability.ListRecordsResponse + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// CreateDatasetRecord creates a record in a dataset. +func (c *Client) CreateDatasetRecord(ctx context.Context, datasetID string, req *observability.CreateRecordRequest) (*observability.DatasetRecord, error) { + var resp observability.DatasetRecord + if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/records", datasetID), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// GetDatasetRecord retrieves a dataset record by ID. +func (c *Client) GetDatasetRecord(ctx context.Context, recordID string) (*observability.DatasetRecord, error) { + var resp observability.DatasetRecord + if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/dataset-records/%s", recordID), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// UpdateDatasetRecordPayload updates a record's payload. +func (c *Client) UpdateDatasetRecordPayload(ctx context.Context, recordID string, req *observability.UpdateRecordPayloadRequest) (*observability.DatasetRecord, error) { + var resp observability.DatasetRecord + if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/dataset-records/%s/payload", recordID), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// UpdateDatasetRecordProperties updates a record's properties. +func (c *Client) UpdateDatasetRecordProperties(ctx context.Context, recordID string, req *observability.UpdateRecordPropertiesRequest) (*observability.DatasetRecord, error) { + var resp observability.DatasetRecord + if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/dataset-records/%s/properties", recordID), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// DeleteDatasetRecord deletes a dataset record. +func (c *Client) DeleteDatasetRecord(ctx context.Context, recordID string) error { + resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/dataset-records/%s", recordID), nil) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode >= 400 { + return parseAPIError(resp) + } + return nil +} + +// BulkDeleteDatasetRecords deletes multiple dataset records. +func (c *Client) BulkDeleteDatasetRecords(ctx context.Context, req *observability.BulkDeleteRecordsRequest) error { + return c.doJSON(ctx, "POST", "/v1/observability/dataset-records/bulk-delete", req, nil) +} + +// JudgeDatasetRecord judges a dataset record. +func (c *Client) JudgeDatasetRecord(ctx context.Context, recordID string, req *observability.JudgeRecordRequest) (json.RawMessage, error) { + var resp json.RawMessage + if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/dataset-records/%s/live-judging", recordID), req, &resp); err != nil { + return nil, err + } + return resp, nil +} + +// Import operations + +// ImportDatasetFromCampaign imports records from a campaign. +func (c *Client) ImportDatasetFromCampaign(ctx context.Context, datasetID string, req *observability.ImportFromCampaignRequest) (*observability.DatasetImportTask, error) { + var resp observability.DatasetImportTask + if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-campaign", datasetID), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// ImportDatasetFromExplorer imports records from explorer events. +func (c *Client) ImportDatasetFromExplorer(ctx context.Context, datasetID string, req *observability.ImportFromExplorerRequest) (*observability.DatasetImportTask, error) { + var resp observability.DatasetImportTask + if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-explorer", datasetID), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// ImportDatasetFromFile imports records from a file. +func (c *Client) ImportDatasetFromFile(ctx context.Context, datasetID string, req *observability.ImportFromFileRequest) (*observability.DatasetImportTask, error) { + var resp observability.DatasetImportTask + if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-file", datasetID), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// ImportDatasetFromPlayground imports records from playground conversations. +func (c *Client) ImportDatasetFromPlayground(ctx context.Context, datasetID string, req *observability.ImportFromPlaygroundRequest) (*observability.DatasetImportTask, error) { + var resp observability.DatasetImportTask + if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-playground", datasetID), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// ImportDatasetFromDataset imports records from another dataset. +func (c *Client) ImportDatasetFromDataset(ctx context.Context, datasetID string, req *observability.ImportFromDatasetRequest) (*observability.DatasetImportTask, error) { + var resp observability.DatasetImportTask + if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-dataset", datasetID), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// Tasks + +// ListDatasetTasks lists import tasks for a dataset. +func (c *Client) ListDatasetTasks(ctx context.Context, datasetID string, params *observability.PaginationParams) (*observability.ListTasksResponse, error) { + path := fmt.Sprintf("/v1/observability/datasets/%s/tasks", datasetID) + if params != nil { + q := url.Values{} + if params.PageSize != nil { + q.Set("page_size", strconv.Itoa(*params.PageSize)) + } + if params.Page != nil { + q.Set("page", strconv.Itoa(*params.Page)) + } + if encoded := q.Encode(); encoded != "" { + path += "?" + encoded + } + } + var resp observability.ListTasksResponse + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// GetDatasetTask retrieves an import task by ID. +func (c *Client) GetDatasetTask(ctx context.Context, datasetID, taskID string) (*observability.DatasetImportTask, error) { + var resp observability.DatasetImportTask + if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s/tasks/%s", datasetID, taskID), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} diff --git a/observability_datasets_test.go b/observability_datasets_test.go new file mode 100644 index 0000000..190901d --- /dev/null +++ b/observability_datasets_test.go @@ -0,0 +1,211 @@ +package mistral + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "somegit.dev/vikingowl/mistral-go-sdk/observability" +) + +func datasetJSON() map[string]any { + return map[string]any{ + "id": "ds-1", "created_at": "t", "updated_at": "t", + "name": "test-ds", "description": "d", + "owner_id": "o", "workspace_id": "w", + } +} + +func TestCreateDataset_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" || r.URL.Path != "/v1/observability/datasets" { + t.Errorf("unexpected %s %s", r.Method, r.URL.Path) + } + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(datasetJSON()) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.CreateDataset(context.Background(), &observability.CreateDatasetRequest{ + Name: "test-ds", + Description: "d", + }) + if err != nil { + t.Fatal(err) + } + if resp.ID != "ds-1" { + t.Errorf("got id %q", resp.ID) + } +} + +func TestListDatasets_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]any{ + "count": 1, + "results": []any{datasetJSON()}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.ListDatasets(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + if resp.Count != 1 { + t.Errorf("got count %d", resp.Count) + } +} + +func TestGetDataset_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/observability/datasets/ds-1" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(datasetJSON()) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.GetDataset(context.Background(), "ds-1") + if err != nil { + t.Fatal(err) + } + if resp.Name != "test-ds" { + t.Errorf("got name %q", resp.Name) + } +} + +func TestDeleteDataset_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + if err := client.DeleteDataset(context.Background(), "ds-1"); err != nil { + t.Fatal(err) + } +} + +func TestCreateDatasetRecord_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" || r.URL.Path != "/v1/observability/datasets/ds-1/records" { + t.Errorf("unexpected %s %s", r.Method, r.URL.Path) + } + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(map[string]any{ + "id": "rec-1", "created_at": "t", "updated_at": "t", + "dataset_id": "ds-1", "source": "DIRECT_INPUT", + "payload": map[string]any{"messages": []any{}}, + "properties": map[string]any{}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.CreateDatasetRecord(context.Background(), "ds-1", &observability.CreateRecordRequest{ + Payload: observability.ConversationPayload{Messages: []map[string]any{}}, + Properties: map[string]any{}, + }) + if err != nil { + t.Fatal(err) + } + if resp.ID != "rec-1" { + t.Errorf("got id %q", resp.ID) + } +} + +func TestListDatasetRecords_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/observability/datasets/ds-1/records" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "count": 0, "results": []any{}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.ListDatasetRecords(context.Background(), "ds-1", nil) + if err != nil { + t.Fatal(err) + } + if resp.Count != 0 { + t.Errorf("got count %d", resp.Count) + } +} + +func TestImportDatasetFromCampaign_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/observability/datasets/ds-1/imports/from-campaign" { + t.Errorf("got path %s", r.URL.Path) + } + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(map[string]any{ + "id": "task-1", "created_at": "t", "updated_at": "t", + "creator_id": "u", "dataset_id": "ds-1", "workspace_id": "w", + "status": "RUNNING", + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.ImportDatasetFromCampaign(context.Background(), "ds-1", &observability.ImportFromCampaignRequest{ + CampaignID: "camp-1", + }) + if err != nil { + t.Fatal(err) + } + if resp.Status != observability.TaskStatusRunning { + t.Errorf("got status %q", resp.Status) + } +} + +func TestExportDatasetToJSONL_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/observability/datasets/ds-1/exports/to-jsonl" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "file_url": "https://storage.example.com/export.jsonl", + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.ExportDatasetToJSONL(context.Background(), "ds-1") + if err != nil { + t.Fatal(err) + } + if resp.FileURL != "https://storage.example.com/export.jsonl" { + t.Errorf("got file_url %q", resp.FileURL) + } +} + +func TestGetDatasetTask_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/observability/datasets/ds-1/tasks/task-1" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "id": "task-1", "created_at": "t", "updated_at": "t", + "creator_id": "u", "dataset_id": "ds-1", "workspace_id": "w", + "status": "COMPLETED", + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.GetDatasetTask(context.Background(), "ds-1", "task-1") + if err != nil { + t.Fatal(err) + } + if resp.Status != observability.TaskStatusCompleted { + t.Errorf("got status %q", resp.Status) + } +} diff --git a/observability_events.go b/observability_events.go new file mode 100644 index 0000000..0afdb4a --- /dev/null +++ b/observability_events.go @@ -0,0 +1,69 @@ +package mistral + +import ( + "context" + "encoding/json" + "fmt" + "net/url" + "strconv" + + "somegit.dev/vikingowl/mistral-go-sdk/observability" +) + +// SearchChatCompletionEvents searches for chat completion events. +func (c *Client) SearchChatCompletionEvents(ctx context.Context, req *observability.SearchEventsRequest) (*observability.SearchEventsResponse, error) { + var resp observability.SearchEventsResponse + if err := c.doJSON(ctx, "POST", "/v1/observability/chat-completion-events/search", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// SearchChatCompletionEventIDs searches for chat completion event IDs. +func (c *Client) SearchChatCompletionEventIDs(ctx context.Context, req *observability.SearchEventIDsRequest) (*observability.SearchEventIDsResponse, error) { + var resp observability.SearchEventIDsResponse + if err := c.doJSON(ctx, "POST", "/v1/observability/chat-completion-events/search-ids", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// GetChatCompletionEvent retrieves a chat completion event by ID. +func (c *Client) GetChatCompletionEvent(ctx context.Context, eventID string) (*observability.ChatCompletionEvent, error) { + var resp observability.ChatCompletionEvent + if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/chat-completion-events/%s", eventID), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// GetSimilarChatCompletionEvents retrieves events similar to a given event. +func (c *Client) GetSimilarChatCompletionEvents(ctx context.Context, eventID string, params *observability.PaginationParams) (*observability.SimilarEventsResponse, error) { + path := fmt.Sprintf("/v1/observability/chat-completion-events/%s/similar-events", eventID) + if params != nil { + q := url.Values{} + if params.PageSize != nil { + q.Set("page_size", strconv.Itoa(*params.PageSize)) + } + if params.Page != nil { + q.Set("page", strconv.Itoa(*params.Page)) + } + if encoded := q.Encode(); encoded != "" { + path += "?" + encoded + } + } + var resp observability.SimilarEventsResponse + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// JudgeChatCompletionEvent judges a chat completion event. +func (c *Client) JudgeChatCompletionEvent(ctx context.Context, eventID string, req *observability.JudgeEventRequest) (json.RawMessage, error) { + var resp json.RawMessage + if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/chat-completion-events/%s/live-judging", eventID), req, &resp); err != nil { + return nil, err + } + return resp, nil +} diff --git a/observability_events_test.go b/observability_events_test.go new file mode 100644 index 0000000..9c17bbe --- /dev/null +++ b/observability_events_test.go @@ -0,0 +1,101 @@ +package mistral + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "somegit.dev/vikingowl/mistral-go-sdk/observability" +) + +func TestSearchChatCompletionEvents_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" || r.URL.Path != "/v1/observability/chat-completion-events/search" { + t.Errorf("unexpected %s %s", r.Method, r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "results": []map[string]any{ + {"event_id": "ev-1", "correlation_id": "c1", "created_at": "t", "nb_input_tokens": 10, "nb_output_tokens": 5}, + }, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.SearchChatCompletionEvents(context.Background(), &observability.SearchEventsRequest{}) + if err != nil { + t.Fatal(err) + } + if len(resp.Results) != 1 { + t.Fatalf("got %d results", len(resp.Results)) + } + if resp.Results[0].EventID != "ev-1" { + t.Errorf("got event_id %q", resp.Results[0].EventID) + } +} + +func TestSearchChatCompletionEventIDs_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/observability/chat-completion-events/search-ids" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "completion_event_ids": []string{"ev-1", "ev-2"}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.SearchChatCompletionEventIDs(context.Background(), &observability.SearchEventIDsRequest{}) + if err != nil { + t.Fatal(err) + } + if len(resp.CompletionEventIDs) != 2 { + t.Errorf("got %d ids", len(resp.CompletionEventIDs)) + } +} + +func TestGetChatCompletionEvent_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/observability/chat-completion-events/ev-1" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "event_id": "ev-1", "correlation_id": "c1", "created_at": "t", + "nb_input_tokens": 10, "nb_output_tokens": 5, "nb_messages": 2, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.GetChatCompletionEvent(context.Background(), "ev-1") + if err != nil { + t.Fatal(err) + } + if resp.EventID != "ev-1" { + t.Errorf("got event_id %q", resp.EventID) + } +} + +func TestGetSimilarChatCompletionEvents_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/observability/chat-completion-events/ev-1/similar-events" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "count": 0, "results": []any{}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.GetSimilarChatCompletionEvents(context.Background(), "ev-1", nil) + if err != nil { + t.Fatal(err) + } + if resp.Count != 0 { + t.Errorf("got count %d", resp.Count) + } +} diff --git a/observability_judges.go b/observability_judges.go new file mode 100644 index 0000000..053baeb --- /dev/null +++ b/observability_judges.go @@ -0,0 +1,85 @@ +package mistral + +import ( + "context" + "encoding/json" + "fmt" + "net/url" + "strconv" + + "somegit.dev/vikingowl/mistral-go-sdk/observability" +) + +// CreateJudge creates a new observability judge. +func (c *Client) CreateJudge(ctx context.Context, req *observability.CreateJudgeRequest) (*observability.Judge, error) { + var resp observability.Judge + if err := c.doJSON(ctx, "POST", "/v1/observability/judges", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// ListJudges lists observability judges. +func (c *Client) ListJudges(ctx context.Context, params *observability.SearchParams) (*observability.ListJudgesResponse, error) { + path := "/v1/observability/judges" + if params != nil { + q := url.Values{} + if params.PageSize != nil { + q.Set("page_size", strconv.Itoa(*params.PageSize)) + } + if params.Page != nil { + q.Set("page", strconv.Itoa(*params.Page)) + } + if params.Q != nil { + q.Set("q", *params.Q) + } + if encoded := q.Encode(); encoded != "" { + path += "?" + encoded + } + } + var resp observability.ListJudgesResponse + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// GetJudge retrieves a judge by ID. +func (c *Client) GetJudge(ctx context.Context, judgeID string) (*observability.Judge, error) { + var resp observability.Judge + if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/judges/%s", judgeID), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// UpdateJudge updates a judge. +func (c *Client) UpdateJudge(ctx context.Context, judgeID string, req *observability.UpdateJudgeRequest) (*observability.Judge, error) { + var resp observability.Judge + if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/judges/%s", judgeID), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// DeleteJudge deletes a judge. +func (c *Client) DeleteJudge(ctx context.Context, judgeID string) error { + resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/judges/%s", judgeID), nil) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode >= 400 { + return parseAPIError(resp) + } + return nil +} + +// JudgeConversation performs live judging on a conversation. +func (c *Client) JudgeConversation(ctx context.Context, judgeID string, req *observability.JudgeConversationRequest) (json.RawMessage, error) { + var resp json.RawMessage + if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/judges/%s/live-judging", judgeID), req, &resp); err != nil { + return nil, err + } + return resp, nil +} diff --git a/observability_judges_test.go b/observability_judges_test.go new file mode 100644 index 0000000..44e1290 --- /dev/null +++ b/observability_judges_test.go @@ -0,0 +1,123 @@ +package mistral + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "somegit.dev/vikingowl/mistral-go-sdk/observability" +) + +func judgeJSON() map[string]any { + return map[string]any{ + "id": "j1", "created_at": "t", "updated_at": "t", + "owner_id": "o", "workspace_id": "w", "name": "quality", + "description": "d", "model_name": "m", "instructions": "i", + "output": map[string]any{"type": "CLASSIFICATION", "options": []any{}}, + } +} + +func TestCreateJudge_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" || r.URL.Path != "/v1/observability/judges" { + t.Errorf("unexpected %s %s", r.Method, r.URL.Path) + } + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(judgeJSON()) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.CreateJudge(context.Background(), &observability.CreateJudgeRequest{ + Name: "quality", + Description: "d", + ModelName: "m", + Instructions: "i", + Tools: []string{}, + Output: json.RawMessage(`{"type":"CLASSIFICATION","options":[]}`), + }) + if err != nil { + t.Fatal(err) + } + if resp.ID != "j1" { + t.Errorf("got id %q", resp.ID) + } +} + +func TestListJudges_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(map[string]any{ + "count": 1, + "results": []any{judgeJSON()}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.ListJudges(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + if resp.Count != 1 { + t.Errorf("got count %d", resp.Count) + } +} + +func TestGetJudge_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/observability/judges/j1" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(judgeJSON()) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.GetJudge(context.Background(), "j1") + if err != nil { + t.Fatal(err) + } + if resp.Name != "quality" { + t.Errorf("got name %q", resp.Name) + } +} + +func TestUpdateJudge_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "PUT" { + t.Errorf("expected PUT, got %s", r.Method) + } + json.NewEncoder(w).Encode(judgeJSON()) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + _, err := client.UpdateJudge(context.Background(), "j1", &observability.UpdateJudgeRequest{ + Name: "quality", + Description: "d", + ModelName: "m", + Instructions: "i", + Tools: []string{}, + Output: json.RawMessage(`{"type":"CLASSIFICATION","options":[]}`), + }) + if err != nil { + t.Fatal(err) + } +} + +func TestDeleteJudge_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "DELETE" { + t.Errorf("expected DELETE") + } + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + if err := client.DeleteJudge(context.Background(), "j1"); err != nil { + t.Fatal(err) + } +}