feat: v1.1.0 — sync with upstream Python SDK v2.1.3
Add Connectors, Audio Speech/Voices, Audio Realtime types, and Observability (beta). 41 new service methods, 116 total. Breaking: ListModels and UploadFile signatures changed (pass nil for previous behavior).
This commit is contained in:
34
CHANGELOG.md
34
CHANGELOG.md
@@ -1,3 +1,37 @@
|
||||
## v1.1.0 — 2026-03-24
|
||||
|
||||
Upstream sync with Python SDK v2.1.3. Adds Connectors, Audio Speech/Voices, and Observability (beta).
|
||||
|
||||
### Breaking Changes
|
||||
|
||||
- **`ListModels`** signature changed from `(ctx)` to `(ctx, *model.ListParams)`.
|
||||
Pass `nil` for previous behavior. The new `ListParams` supports `Provider` and
|
||||
`Model` query filters.
|
||||
- **`UploadFile`** signature changed from `(ctx, filename, reader, purpose)` to
|
||||
`(ctx, filename, reader, *file.UploadParams)`. The new `UploadParams` struct
|
||||
holds `Purpose`, `Expiry`, and `Visibility` fields.
|
||||
|
||||
### Added
|
||||
|
||||
- **`ReasoningEffort`** field on `chat.CompletionRequest` and
|
||||
`agents.CompletionRequest` — controls reasoning effort (`"none"`, `"high"`).
|
||||
- **Connectors API** (new `connector/` package) — `CreateConnector`,
|
||||
`ListConnectors`, `GetConnector`, `UpdateConnector`, `DeleteConnector`,
|
||||
`GetConnectorAuthURL`, `ListConnectorTools`, `CallConnectorTool`.
|
||||
- **Audio Speech (TTS)** — `Speech`, `SpeechStream` with `SpeechStream` typed
|
||||
wrapper, `SpeechOutputFormat` enum (pcm/wav/mp3/flac/opus).
|
||||
- **Audio Voices** — `ListVoices`, `CreateVoice`, `GetVoice`, `UpdateVoice`,
|
||||
`DeleteVoice`, `GetVoiceSampleAudio`.
|
||||
- **Audio Realtime types** — `AudioEncoding`, `AudioFormat`, `RealtimeSession`,
|
||||
and WebSocket message types in `audio/realtime.go`. No WebSocket client yet
|
||||
(would require adding a dependency).
|
||||
- **Observability API** (new `observability/` package, beta) — campaigns,
|
||||
chat completion events, judges, datasets, records, and import tasks.
|
||||
33 service methods total.
|
||||
- **`file.Visibility`** enum — `shared_global`, `shared_org`,
|
||||
`shared_workspace`, `private`.
|
||||
- **`model.ListParams`** — filter models by `Provider` and `Model`.
|
||||
|
||||
## v1.0.0 — 2026-03-17
|
||||
|
||||
Stable release. Tracks upstream Python SDK v2.0.4.
|
||||
|
||||
89
CLAUDE.md
Normal file
89
CLAUDE.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Project
|
||||
|
||||
Idiomatic Go SDK for the Mistral AI API. Module path: `somegit.dev/vikingowl/mistral-go-sdk`. Requires Go 1.26+. Zero external dependencies (stdlib only). Tracks the upstream [Mistral Python SDK](https://github.com/mistralai/client-python) as reference for API surface and type definitions.
|
||||
|
||||
## Repository layout
|
||||
|
||||
- **Working directory**: `mistral-go-sdk/` — the Go SDK source. All development happens here.
|
||||
- **`../client-python/`**: Clone of the upstream Mistral Python SDK. Read-only reference — pull/update it when checking for upstream API changes, but never modify it.
|
||||
|
||||
## Commands
|
||||
|
||||
```bash
|
||||
# Run all unit tests
|
||||
go test ./...
|
||||
|
||||
# Run a single test
|
||||
go test -run TestChatComplete_Success
|
||||
|
||||
# Run integration tests (requires MISTRAL_API_KEY env var)
|
||||
go test -tags=integration ./...
|
||||
|
||||
# Vet and build
|
||||
go vet ./...
|
||||
go build ./...
|
||||
```
|
||||
|
||||
No Makefile, linter config, or code generation tooling — standard `go test` / `go vet` / `go build`.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Two-layer design: types in sub-packages, methods on `*Client`
|
||||
|
||||
Sub-packages (`chat/`, `agents/`, `conversation/`, `embedding/`, `model/`, `file/`, `finetune/`, `batch/`, `ocr/`, `audio/`, `library/`, `moderation/`, `classification/`, `fim/`) are **types-only** — they define request/response structs and enums but contain no HTTP logic. All service methods live on `*Client` in the root package, prefix-namespaced by domain (e.g. `ChatComplete`, `AgentsComplete`, `CreateFineTuningJob`, `UploadFile`).
|
||||
|
||||
### HTTP internals (request.go)
|
||||
|
||||
All HTTP flows route through a small set of unexported helpers on `*Client`:
|
||||
- `do()` — raw HTTP with auth headers + retry
|
||||
- `doJSON()` — JSON marshal request → `do()` → unmarshal response
|
||||
- `doStream()` — JSON request → raw `*http.Response` for SSE
|
||||
- `doMultipart()` / `doMultipartStream()` — multipart file upload variants
|
||||
- `doRetry()` — retry loop with exponential backoff + jitter + `Retry-After` parsing
|
||||
|
||||
### Streaming
|
||||
|
||||
Generic `Stream[T]` type wraps SSE (`sseReader`) with `Next()`/`Current()`/`Err()`/`Close()` iterator pattern. Typed wrappers `EventStream` (conversations) and `AudioStream` (transcription) unmarshal `json.RawMessage` into domain-specific event types.
|
||||
|
||||
### Sealed interfaces for discriminated unions
|
||||
|
||||
Polymorphic API types use **sealed interfaces** with unexported marker methods:
|
||||
- `chat.Message` (marker: `isMessage()`) — `SystemMessage`, `UserMessage`, `AssistantMessage`, `ToolMessage`
|
||||
- `chat.ContentChunk` (marker: `contentChunk()`) — `TextChunk`, `ImageURLChunk`, `DocumentURLChunk`, `FileChunk`, `ReferenceChunk`, `ThinkChunk`, `AudioChunk`, `ToolReferenceChunk`, `ToolFileChunk`
|
||||
- `agents.AgentTool` (marker: `agentToolType()`) — `FunctionTool`, `WebSearchTool`, `CodeInterpreterTool`, `ConnectorTool`, etc.
|
||||
- `conversation.Event` — conversation streaming events
|
||||
|
||||
Each has an `Unknown*` variant so the SDK doesn't break on new API types. Each has a `Unmarshal*` dispatch function that probes a `type`/`role` discriminator field.
|
||||
|
||||
### Custom JSON patterns
|
||||
|
||||
Several types require non-trivial marshal/unmarshal:
|
||||
- **Type alias trick** — `type alias T` inside `MarshalJSON` to avoid infinite recursion when injecting a `type`/`role` discriminator field.
|
||||
- **`json:"-"` + custom MarshalJSON** — `CompletionRequest.Messages` (and `stream`) are excluded from default marshaling and injected via custom `MarshalJSON`.
|
||||
- **Union types** — `Content` handles `string | null | []ContentChunk`; `ToolChoice` handles `string | object`; `ImageURL` handles `string | object`; `FunctionCall.Arguments` handles `string | object`; `ReferenceID` handles `int | string` with type preservation.
|
||||
- **Probe struct pattern** — `Unmarshal*` functions decode only the discriminator field first, then dispatch to the concrete type.
|
||||
|
||||
### Shared types in `chat/`
|
||||
|
||||
`GuardrailConfig`, `ModerationLLMV1Config`, `ModerationLLMV2Config` live in `chat/` because it's the base types package imported by both `agents/` and `conversation/`. This avoids import cycles.
|
||||
|
||||
### Error handling
|
||||
|
||||
`APIError` in `error.go` with sentinel checkers: `IsNotFound()`, `IsRateLimit()`, `IsAuth()`. All use `errors.As` for unwrapping.
|
||||
|
||||
## Testing patterns
|
||||
|
||||
- Unit tests use `httptest.NewServer` with inline handlers to mock the Mistral API. Client is pointed at the test server via `WithBaseURL(server.URL)`.
|
||||
- Integration tests are behind `//go:build integration` build tag and require `MISTRAL_API_KEY`.
|
||||
- Tests use stdlib `testing` only — no third-party test frameworks.
|
||||
|
||||
## Adding a new API endpoint
|
||||
|
||||
1. Define request/response types in the appropriate sub-package (or create a new one with a `doc.go`).
|
||||
2. Add a method on `*Client` in the root package. Use `doJSON` for standard request/response, `doStream` for SSE, `doMultipart` for file uploads.
|
||||
3. Add unit tests with `httptest.NewServer`.
|
||||
4. If the endpoint supports streaming, return `*Stream[T]` and call `EnableStream()` on the request before sending.
|
||||
18
README.md
18
README.md
@@ -11,7 +11,7 @@ The most complete Go client for the [Mistral AI API](https://docs.mistral.ai/).
|
||||
|
||||
**Zero dependencies.** The entire SDK — including tests — uses only the Go standard library. No `go.sum`, no transitive dependency tree to audit, no version conflicts, no supply chain risk.
|
||||
|
||||
**Full API coverage.** 75 methods across every Mistral endpoint — including Conversations, Agents CRUD, Libraries, OCR, Audio, Fine-tuning, and Batch Jobs. No other Go SDK covers Conversations or Agents.
|
||||
**Full API coverage.** 116 methods across every Mistral endpoint — including Connectors, Audio Speech/Voices, Conversations, Agents CRUD, Libraries, OCR, Observability, Fine-tuning, and Batch Jobs. No other Go SDK covers Conversations, Connectors, or Observability.
|
||||
|
||||
**Typed streaming.** A generic pull-based `Stream[T]` iterator — no channels, no goroutines, no leaks. Just `Next()` / `Current()` / `Err()` / `Close()`.
|
||||
|
||||
@@ -132,7 +132,7 @@ for stream.Next() {
|
||||
|
||||
## API Coverage
|
||||
|
||||
75 public methods on `Client`, grouped by domain:
|
||||
116 public methods on `Client`, grouped by domain:
|
||||
|
||||
| Domain | Methods |
|
||||
|--------|---------|
|
||||
@@ -140,6 +140,7 @@ for stream.Next() {
|
||||
| **FIM** | `FIMComplete`, `FIMCompleteStream` |
|
||||
| **Agents (completions)** | `AgentsComplete`, `AgentsCompleteStream` |
|
||||
| **Agents (CRUD)** | `CreateAgent`, `ListAgents`, `GetAgent`, `UpdateAgent`, `DeleteAgent`, `UpdateAgentVersion`, `ListAgentVersions`, `GetAgentVersion`, `SetAgentAlias`, `ListAgentAliases`, `DeleteAgentAlias` |
|
||||
| **Connectors** | `CreateConnector`, `ListConnectors`, `GetConnector`, `UpdateConnector`, `DeleteConnector`, `GetConnectorAuthURL`, `ListConnectorTools`, `CallConnectorTool` |
|
||||
| **Conversations** | `StartConversation`, `StartConversationStream`, `AppendConversation`, `AppendConversationStream`, `RestartConversation`, `RestartConversationStream`, `GetConversation`, `ListConversations`, `DeleteConversation`, `GetConversationHistory`, `GetConversationMessages` |
|
||||
| **Models** | `ListModels`, `GetModel`, `DeleteModel` |
|
||||
| **Files** | `UploadFile`, `ListFiles`, `GetFile`, `DeleteFile`, `GetFileContent`, `GetFileURL` |
|
||||
@@ -147,10 +148,16 @@ for stream.Next() {
|
||||
| **Fine-tuning** | `CreateFineTuningJob`, `ListFineTuningJobs`, `GetFineTuningJob`, `CancelFineTuningJob`, `StartFineTuningJob`, `UpdateFineTunedModel`, `ArchiveFineTunedModel`, `UnarchiveFineTunedModel` |
|
||||
| **Batch** | `CreateBatchJob`, `ListBatchJobs`, `GetBatchJob`, `CancelBatchJob` |
|
||||
| **OCR** | `OCR` |
|
||||
| **Audio** | `Transcribe`, `TranscribeStream` |
|
||||
| **Audio (transcription)** | `Transcribe`, `TranscribeStream` |
|
||||
| **Audio (speech)** | `Speech`, `SpeechStream` |
|
||||
| **Audio (voices)** | `ListVoices`, `CreateVoice`, `GetVoice`, `UpdateVoice`, `DeleteVoice`, `GetVoiceSampleAudio` |
|
||||
| **Libraries** | `CreateLibrary`, `ListLibraries`, `GetLibrary`, `UpdateLibrary`, `DeleteLibrary`, `UploadDocument`, `ListDocuments`, `GetDocument`, `UpdateDocument`, `DeleteDocument`, `GetDocumentTextContent`, `GetDocumentStatus`, `GetDocumentSignedURL`, `GetDocumentExtractedTextSignedURL`, `ReprocessDocument`, `ListLibrarySharing`, `ShareLibrary`, `UnshareLibrary` |
|
||||
| **Moderation** | `Moderate`, `ModerateChat` |
|
||||
| **Classification** | `Classify`, `ClassifyChat` |
|
||||
| **Observability (campaigns)** | `CreateCampaign`, `ListCampaigns`, `GetCampaign`, `DeleteCampaign`, `GetCampaignStatus`, `ListCampaignEvents` |
|
||||
| **Observability (events)** | `SearchChatCompletionEvents`, `SearchChatCompletionEventIDs`, `GetChatCompletionEvent`, `GetSimilarChatCompletionEvents`, `JudgeChatCompletionEvent` |
|
||||
| **Observability (judges)** | `CreateJudge`, `ListJudges`, `GetJudge`, `UpdateJudge`, `DeleteJudge`, `JudgeConversation` |
|
||||
| **Observability (datasets)** | `CreateDataset`, `ListDatasets`, `GetDataset`, `UpdateDataset`, `DeleteDataset`, `ExportDatasetToJSONL`, `ListDatasetRecords`, `CreateDatasetRecord`, `GetDatasetRecord`, `UpdateDatasetRecordPayload`, `UpdateDatasetRecordProperties`, `DeleteDatasetRecord`, `BulkDeleteDatasetRecords`, `JudgeDatasetRecord`, `ImportDatasetFromCampaign`, `ImportDatasetFromExplorer`, `ImportDatasetFromFile`, `ImportDatasetFromPlayground`, `ImportDatasetFromDataset`, `ListDatasetTasks`, `GetDatasetTask` |
|
||||
|
||||
## Comparison
|
||||
|
||||
@@ -163,11 +170,13 @@ There is no official Go SDK from Mistral AI (only Python and TypeScript). The ma
|
||||
| Embeddings | Yes | Yes | Yes | Yes |
|
||||
| Tool calling | Yes | No | No | No |
|
||||
| Agents (completions + CRUD) | Yes | No | No | No |
|
||||
| Connectors (MCP) | Yes | No | No | No |
|
||||
| Conversations API | Yes | No | No | No |
|
||||
| Libraries / Documents | Yes | No | No | No |
|
||||
| Fine-tuning / Batch | Yes | No | No | No |
|
||||
| OCR | Yes | No | No | Yes |
|
||||
| Audio transcription | Yes | No | No | No |
|
||||
| Audio (transcription + TTS + voices) | Yes | No | No | No |
|
||||
| Observability (beta) | Yes | No | No | No |
|
||||
| Moderation / Classification | Yes | No | No | No |
|
||||
| Vision (multimodal) | Yes | No | No | Yes |
|
||||
| Zero dependencies | Yes | test-only (testify) | test-only (testify) | test-only (testify) |
|
||||
@@ -221,6 +230,7 @@ as its upstream reference for API surface and type definitions.
|
||||
|
||||
| SDK Version | Upstream Python SDK |
|
||||
|-------------|---------------------|
|
||||
| v1.1.0 | v2.1.3 |
|
||||
| v1.0.0 | v2.0.4 |
|
||||
|
||||
## License
|
||||
|
||||
@@ -26,6 +26,7 @@ type CompletionRequest struct {
|
||||
Prediction *chat.Prediction `json:"prediction,omitempty"`
|
||||
PromptMode *chat.PromptMode `json:"prompt_mode,omitempty"`
|
||||
Guardrails []chat.GuardrailConfig `json:"guardrails,omitempty"`
|
||||
ReasoningEffort *chat.ReasoningEffort `json:"reasoning_effort,omitempty"`
|
||||
stream bool
|
||||
}
|
||||
|
||||
|
||||
@@ -114,6 +114,39 @@ func TestAgentsComplete_WithTools(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentsComplete_ReasoningEffort(t *testing.T) {
|
||||
effort := chat.ReasoningEffortHigh
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["reasoning_effort"] != "high" {
|
||||
t.Errorf("expected reasoning_effort=high, got %v", body["reasoning_effort"])
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "a-re", "object": "chat.completion",
|
||||
"model": "m", "created": 0,
|
||||
"choices": []map[string]any{{
|
||||
"index": 0, "message": map[string]any{"role": "assistant", "content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
}},
|
||||
"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
_, err := client.AgentsComplete(context.Background(), &agents.CompletionRequest{
|
||||
AgentID: "agent-1",
|
||||
Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
|
||||
ReasoningEffort: &effort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAgentsCompleteStream_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]any
|
||||
|
||||
22
audio/doc.go
22
audio/doc.go
@@ -1,5 +1,25 @@
|
||||
// Package audio provides types for the Mistral audio transcription API.
|
||||
// Package audio provides types for the Mistral audio APIs.
|
||||
//
|
||||
// # Transcription
|
||||
//
|
||||
// [TranscriptionRequest] and [TranscriptionResponse] handle speech-to-text.
|
||||
// Streaming transcription returns typed [StreamEvent] values via a sealed
|
||||
// interface dispatched by the "type" field.
|
||||
//
|
||||
// # Speech (TTS)
|
||||
//
|
||||
// [SpeechRequest] and [SpeechResponse] handle text-to-speech.
|
||||
// Streaming speech returns typed [SpeechStreamEvent] values
|
||||
// ([SpeechAudioDelta] and [SpeechDone]).
|
||||
//
|
||||
// # Voices
|
||||
//
|
||||
// [VoiceResponse], [VoiceCreateRequest], and [VoiceUpdateRequest] manage
|
||||
// custom voices for speech synthesis.
|
||||
//
|
||||
// # Realtime
|
||||
//
|
||||
// Realtime transcription types ([AudioEncoding], [AudioFormat],
|
||||
// [RealtimeSession], and WebSocket message types) are defined here.
|
||||
// The WebSocket client is not yet implemented.
|
||||
package audio
|
||||
|
||||
75
audio/realtime.go
Normal file
75
audio/realtime.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package audio
|
||||
|
||||
// AudioEncoding is the encoding format for realtime audio streams.
|
||||
type AudioEncoding string
|
||||
|
||||
const (
|
||||
EncodingPCMS16LE AudioEncoding = "pcm_s16le"
|
||||
EncodingPCMS32LE AudioEncoding = "pcm_s32le"
|
||||
EncodingPCMF16LE AudioEncoding = "pcm_f16le"
|
||||
EncodingPCMF32LE AudioEncoding = "pcm_f32le"
|
||||
EncodingPCMMulaw AudioEncoding = "pcm_mulaw"
|
||||
EncodingPCMAlaw AudioEncoding = "pcm_alaw"
|
||||
)
|
||||
|
||||
// AudioFormat describes the encoding and sample rate for realtime audio.
|
||||
type AudioFormat struct {
|
||||
Encoding AudioEncoding `json:"encoding"`
|
||||
SampleRate int `json:"sample_rate"`
|
||||
}
|
||||
|
||||
// RealtimeSession describes a realtime transcription session.
|
||||
type RealtimeSession struct {
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
AudioFormat AudioFormat `json:"audio_format"`
|
||||
TargetStreamingDelayMs *int `json:"target_streaming_delay_ms,omitempty"`
|
||||
}
|
||||
|
||||
// RealtimeSessionUpdate is sent to update session parameters.
|
||||
// Parameters can only be changed before audio transmission starts.
|
||||
type RealtimeSessionUpdate struct {
|
||||
AudioFormat *AudioFormat `json:"audio_format,omitempty"`
|
||||
TargetStreamingDelayMs *int `json:"target_streaming_delay_ms,omitempty"`
|
||||
}
|
||||
|
||||
// InputAudioAppend sends a chunk of audio data.
|
||||
// Audio is base64-encoded (max 262144 bytes decoded).
|
||||
type InputAudioAppend struct {
|
||||
Type string `json:"type"` // "input_audio.append"
|
||||
Audio string `json:"audio"`
|
||||
}
|
||||
|
||||
// InputAudioFlush flushes the audio buffer.
|
||||
type InputAudioFlush struct {
|
||||
Type string `json:"type"` // "input_audio.flush"
|
||||
}
|
||||
|
||||
// InputAudioEnd signals the end of audio input.
|
||||
type InputAudioEnd struct {
|
||||
Type string `json:"type"` // "input_audio.end"
|
||||
}
|
||||
|
||||
// RealtimeSessionCreated is received when a session is created.
|
||||
type RealtimeSessionCreated struct {
|
||||
Type string `json:"type"` // "session.created"
|
||||
Session RealtimeSession `json:"session"`
|
||||
}
|
||||
|
||||
// RealtimeSessionUpdated is received when a session is updated.
|
||||
type RealtimeSessionUpdated struct {
|
||||
Type string `json:"type"` // "session.updated"
|
||||
Session RealtimeSession `json:"session"`
|
||||
}
|
||||
|
||||
// RealtimeErrorDetail describes a realtime error.
|
||||
type RealtimeErrorDetail struct {
|
||||
Message string `json:"message"`
|
||||
Code int `json:"code"`
|
||||
}
|
||||
|
||||
// RealtimeError is received on error.
|
||||
type RealtimeError struct {
|
||||
Type string `json:"type"` // "error"
|
||||
Error RealtimeErrorDetail `json:"error"`
|
||||
}
|
||||
88
audio/speech.go
Normal file
88
audio/speech.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package audio
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// SpeechOutputFormat is the output audio format for speech synthesis.
|
||||
type SpeechOutputFormat string
|
||||
|
||||
const (
|
||||
SpeechFormatPCM SpeechOutputFormat = "pcm"
|
||||
SpeechFormatWAV SpeechOutputFormat = "wav"
|
||||
SpeechFormatMP3 SpeechOutputFormat = "mp3"
|
||||
SpeechFormatFLAC SpeechOutputFormat = "flac"
|
||||
SpeechFormatOpus SpeechOutputFormat = "opus"
|
||||
)
|
||||
|
||||
// SpeechRequest represents a text-to-speech request.
|
||||
type SpeechRequest struct {
|
||||
Input string `json:"input"`
|
||||
Model string `json:"model"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
VoiceID *string `json:"voice_id,omitempty"`
|
||||
RefAudio *string `json:"ref_audio,omitempty"`
|
||||
ResponseFormat *SpeechOutputFormat `json:"response_format,omitempty"`
|
||||
stream bool
|
||||
}
|
||||
|
||||
// EnableStream is used internally to enable streaming.
|
||||
func (r *SpeechRequest) EnableStream() { r.stream = true }
|
||||
|
||||
func (r *SpeechRequest) MarshalJSON() ([]byte, error) {
|
||||
type Alias SpeechRequest
|
||||
return json.Marshal(&struct {
|
||||
Stream bool `json:"stream"`
|
||||
*Alias
|
||||
}{
|
||||
Stream: r.stream,
|
||||
Alias: (*Alias)(r),
|
||||
})
|
||||
}
|
||||
|
||||
// SpeechResponse is the response from a non-streaming speech request.
|
||||
type SpeechResponse struct {
|
||||
AudioData string `json:"audio_data"`
|
||||
}
|
||||
|
||||
// SpeechStreamEvent is a sealed interface for speech streaming events.
|
||||
type SpeechStreamEvent interface {
|
||||
speechStreamEvent()
|
||||
}
|
||||
|
||||
// SpeechAudioDelta contains a chunk of audio data during streaming.
|
||||
type SpeechAudioDelta struct {
|
||||
Type string `json:"type"`
|
||||
AudioData string `json:"audio_data"`
|
||||
}
|
||||
|
||||
func (*SpeechAudioDelta) speechStreamEvent() {}
|
||||
|
||||
// SpeechDone is emitted when speech synthesis is complete.
|
||||
type SpeechDone struct {
|
||||
Type string `json:"type"`
|
||||
Usage UsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
func (*SpeechDone) speechStreamEvent() {}
|
||||
|
||||
// UnmarshalSpeechStreamEvent dispatches a raw JSON event to the correct type.
|
||||
func UnmarshalSpeechStreamEvent(data []byte) (SpeechStreamEvent, error) {
|
||||
var probe struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &probe); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch probe.Type {
|
||||
case "speech.audio.delta":
|
||||
var e SpeechAudioDelta
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
case "speech.audio.done":
|
||||
var e SpeechDone
|
||||
return &e, json.Unmarshal(data, &e)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown speech stream event type: %q", probe.Type)
|
||||
}
|
||||
}
|
||||
48
audio/voice.go
Normal file
48
audio/voice.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package audio
|
||||
|
||||
// VoiceResponse represents a voice entity.
|
||||
type VoiceResponse struct {
|
||||
Name string `json:"name"`
|
||||
ID string `json:"id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UserID *string `json:"user_id,omitempty"`
|
||||
Slug *string `json:"slug,omitempty"`
|
||||
Languages []string `json:"languages,omitempty"`
|
||||
Gender *string `json:"gender,omitempty"`
|
||||
Age *int `json:"age,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Color *string `json:"color,omitempty"`
|
||||
RetentionNotice *int `json:"retention_notice,omitempty"`
|
||||
}
|
||||
|
||||
// VoiceCreateRequest creates a custom voice.
|
||||
type VoiceCreateRequest struct {
|
||||
Name string `json:"name"`
|
||||
SampleAudio string `json:"sample_audio"`
|
||||
Slug *string `json:"slug,omitempty"`
|
||||
Languages []string `json:"languages,omitempty"`
|
||||
Gender *string `json:"gender,omitempty"`
|
||||
Age *int `json:"age,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Color *string `json:"color,omitempty"`
|
||||
RetentionNotice *int `json:"retention_notice,omitempty"`
|
||||
SampleFilename *string `json:"sample_filename,omitempty"`
|
||||
}
|
||||
|
||||
// VoiceUpdateRequest updates a voice.
|
||||
type VoiceUpdateRequest struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
Languages []string `json:"languages,omitempty"`
|
||||
Gender *string `json:"gender,omitempty"`
|
||||
Age *int `json:"age,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
// VoiceListResponse is the response from listing voices.
|
||||
type VoiceListResponse struct {
|
||||
Items []VoiceResponse `json:"items"`
|
||||
Total int `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
}
|
||||
124
audio_api.go
124
audio_api.go
@@ -3,7 +3,9 @@ package mistral
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/audio"
|
||||
)
|
||||
@@ -95,3 +97,125 @@ func (s *AudioStream) Err() error { return s.err }
|
||||
|
||||
// Close releases the underlying connection.
|
||||
func (s *AudioStream) Close() error { return s.stream.Close() }
|
||||
|
||||
// Speech sends a text-to-speech request and returns the full response.
|
||||
func (c *Client) Speech(ctx context.Context, req *audio.SpeechRequest) (*audio.SpeechResponse, error) {
|
||||
var resp audio.SpeechResponse
|
||||
if err := c.doJSON(ctx, "POST", "/v1/audio/speech", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// SpeechStream sends a text-to-speech request and returns a stream of audio events.
|
||||
func (c *Client) SpeechStream(ctx context.Context, req *audio.SpeechRequest) (*SpeechStream, error) {
|
||||
req.EnableStream()
|
||||
resp, err := c.doStream(ctx, "POST", "/v1/audio/speech", req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newSpeechStream(resp.Body), nil
|
||||
}
|
||||
|
||||
// SpeechStream wraps the generic Stream for speech streaming events.
|
||||
type SpeechStream struct {
|
||||
stream *Stream[json.RawMessage]
|
||||
event audio.SpeechStreamEvent
|
||||
err error
|
||||
}
|
||||
|
||||
func newSpeechStream(body readCloser) *SpeechStream {
|
||||
return &SpeechStream{
|
||||
stream: newStream[json.RawMessage](body),
|
||||
}
|
||||
}
|
||||
|
||||
// Next advances to the next event. Returns false when done or on error.
|
||||
func (s *SpeechStream) Next() bool {
|
||||
if s.err != nil {
|
||||
return false
|
||||
}
|
||||
if !s.stream.Next() {
|
||||
s.err = s.stream.Err()
|
||||
return false
|
||||
}
|
||||
event, err := audio.UnmarshalSpeechStreamEvent(s.stream.Current())
|
||||
if err != nil {
|
||||
s.err = err
|
||||
return false
|
||||
}
|
||||
s.event = event
|
||||
return true
|
||||
}
|
||||
|
||||
// Current returns the most recently read event.
|
||||
func (s *SpeechStream) Current() audio.SpeechStreamEvent { return s.event }
|
||||
|
||||
// Err returns any error encountered during streaming.
|
||||
func (s *SpeechStream) Err() error { return s.err }
|
||||
|
||||
// Close releases the underlying connection.
|
||||
func (s *SpeechStream) Close() error { return s.stream.Close() }
|
||||
|
||||
// ListVoices returns available voices.
|
||||
func (c *Client) ListVoices(ctx context.Context) (*audio.VoiceListResponse, error) {
|
||||
var resp audio.VoiceListResponse
|
||||
if err := c.doJSON(ctx, "GET", "/v1/audio/voices", nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// CreateVoice creates a custom voice.
|
||||
func (c *Client) CreateVoice(ctx context.Context, req *audio.VoiceCreateRequest) (*audio.VoiceResponse, error) {
|
||||
var resp audio.VoiceResponse
|
||||
if err := c.doJSON(ctx, "POST", "/v1/audio/voices", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetVoice retrieves a voice by ID.
|
||||
func (c *Client) GetVoice(ctx context.Context, voiceID string) (*audio.VoiceResponse, error) {
|
||||
var resp audio.VoiceResponse
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/audio/voices/%s", voiceID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateVoice updates a voice.
|
||||
func (c *Client) UpdateVoice(ctx context.Context, voiceID string, req *audio.VoiceUpdateRequest) (*audio.VoiceResponse, error) {
|
||||
var resp audio.VoiceResponse
|
||||
if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/audio/voices/%s", voiceID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteVoice deletes a voice.
|
||||
func (c *Client) DeleteVoice(ctx context.Context, voiceID string) error {
|
||||
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/audio/voices/%s", voiceID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return parseAPIError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetVoiceSampleAudio retrieves the sample audio for a voice.
|
||||
// Returns the raw HTTP response; the caller must close the body.
|
||||
func (c *Client) GetVoiceSampleAudio(ctx context.Context, voiceID string) (*http.Response, error) {
|
||||
resp, err := c.do(ctx, "GET", fmt.Sprintf("/v1/audio/voices/%s/sample", voiceID), nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
defer resp.Body.Close()
|
||||
return nil, parseAPIError(resp)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
217
audio_speech_test.go
Normal file
217
audio_speech_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/audio"
|
||||
)
|
||||
|
||||
func TestSpeech_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/v1/audio/speech" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["input"] != "Hello world" {
|
||||
t.Errorf("got input %v", body["input"])
|
||||
}
|
||||
if body["stream"] != false {
|
||||
t.Errorf("expected stream=false")
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"audio_data": "base64audiodata==",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.Speech(context.Background(), &audio.SpeechRequest{
|
||||
Input: "Hello world",
|
||||
Model: "mistral-speech",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.AudioData != "base64audiodata==" {
|
||||
t.Errorf("got audio_data %q", resp.AudioData)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpeechStream_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["stream"] != true {
|
||||
t.Errorf("expected stream=true")
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
delta, _ := json.Marshal(map[string]any{
|
||||
"type": "speech.audio.delta", "audio_data": "chunk1==",
|
||||
})
|
||||
fmt.Fprintf(w, "data: %s\n\n", delta)
|
||||
flusher.Flush()
|
||||
|
||||
done, _ := json.Marshal(map[string]any{
|
||||
"type": "speech.audio.done",
|
||||
"usage": map[string]any{
|
||||
"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15,
|
||||
},
|
||||
})
|
||||
fmt.Fprintf(w, "data: %s\n\n", done)
|
||||
flusher.Flush()
|
||||
|
||||
fmt.Fprint(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
stream, err := client.SpeechStream(context.Background(), &audio.SpeechRequest{
|
||||
Input: "Hi",
|
||||
Model: "mistral-speech",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
var events int
|
||||
for stream.Next() {
|
||||
events++
|
||||
}
|
||||
if err := stream.Err(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if events != 2 {
|
||||
t.Errorf("got %d events, want 2", events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListVoices_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/audio/voices" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"items": []map[string]any{
|
||||
{"id": "v1", "name": "Default", "created_at": "2025-01-01"},
|
||||
},
|
||||
"total": 1, "page": 1, "page_size": 10, "total_pages": 1,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ListVoices(context.Background())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(resp.Items) != 1 {
|
||||
t.Fatalf("got %d voices", len(resp.Items))
|
||||
}
|
||||
if resp.Items[0].ID != "v1" {
|
||||
t.Errorf("got id %q", resp.Items[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateVoice_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["name"] != "MyVoice" {
|
||||
t.Errorf("got name %v", body["name"])
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "v2", "name": "MyVoice", "created_at": "2025-01-01",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.CreateVoice(context.Background(), &audio.VoiceCreateRequest{
|
||||
Name: "MyVoice",
|
||||
SampleAudio: "base64audio==",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "v2" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetVoice_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/audio/voices/v1" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "v1", "name": "Default", "created_at": "2025-01-01",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetVoice(context.Background(), "v1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Name != "Default" {
|
||||
t.Errorf("got name %q", resp.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateVoice_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "PATCH" {
|
||||
t.Errorf("expected PATCH, got %s", r.Method)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "v1", "name": "Renamed", "created_at": "2025-01-01",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
name := "Renamed"
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.UpdateVoice(context.Background(), "v1", &audio.VoiceUpdateRequest{
|
||||
Name: &name,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Name != "Renamed" {
|
||||
t.Errorf("got name %q", resp.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteVoice_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "DELETE" {
|
||||
t.Errorf("expected DELETE, got %s", r.Method)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
err := client.DeleteVoice(context.Background(), "v1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,14 @@ const (
|
||||
PromptModeReasoning PromptMode = "reasoning"
|
||||
)
|
||||
|
||||
// ReasoningEffort controls the amount of reasoning effort the model uses.
|
||||
type ReasoningEffort string
|
||||
|
||||
const (
|
||||
ReasoningEffortNone ReasoningEffort = "none"
|
||||
ReasoningEffortHigh ReasoningEffort = "high"
|
||||
)
|
||||
|
||||
// Prediction provides expected completion content for optimization.
|
||||
type Prediction struct {
|
||||
Type string `json:"type"`
|
||||
@@ -36,6 +44,7 @@ type CompletionRequest struct {
|
||||
Prediction *Prediction `json:"prediction,omitempty"`
|
||||
PromptMode *PromptMode `json:"prompt_mode,omitempty"`
|
||||
Guardrails []GuardrailConfig `json:"guardrails,omitempty"`
|
||||
ReasoningEffort *ReasoningEffort `json:"reasoning_effort,omitempty"`
|
||||
stream bool
|
||||
}
|
||||
|
||||
|
||||
@@ -350,6 +350,41 @@ func TestChatComplete_RequestBody(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatComplete_ReasoningEffort(t *testing.T) {
|
||||
effort := chat.ReasoningEffortHigh
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
bodyBytes, _ := io.ReadAll(r.Body)
|
||||
var body map[string]any
|
||||
json.Unmarshal(bodyBytes, &body)
|
||||
|
||||
if body["reasoning_effort"] != "high" {
|
||||
t.Errorf("expected reasoning_effort=high, got %v", body["reasoning_effort"])
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "chat-re", "object": "chat.completion",
|
||||
"model": "m", "created": 0,
|
||||
"choices": []map[string]any{{
|
||||
"index": 0, "message": map[string]any{"role": "assistant", "content": "ok"},
|
||||
"finish_reason": "stop",
|
||||
}},
|
||||
"usage": map[string]any{"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
_, err := client.ChatComplete(context.Background(), &chat.CompletionRequest{
|
||||
Model: "m",
|
||||
Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
|
||||
ReasoningEffort: &effort,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatComplete_ContextCanceled(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Never responds — context should cancel first
|
||||
|
||||
100
connector/connector.go
Normal file
100
connector/connector.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package connector
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// Visibility controls who can see a connector or tool.
|
||||
type Visibility string
|
||||
|
||||
const (
|
||||
VisibilitySharedGlobal Visibility = "shared_global"
|
||||
VisibilitySharedOrg Visibility = "shared_org"
|
||||
VisibilitySharedWorkspace Visibility = "shared_workspace"
|
||||
VisibilityPrivate Visibility = "private"
|
||||
)
|
||||
|
||||
// AuthData holds OAuth2 client credentials for a connector.
|
||||
type AuthData struct {
|
||||
ClientID string `json:"client_id"`
|
||||
ClientSecret string `json:"client_secret"`
|
||||
}
|
||||
|
||||
// Connector represents a registered MCP connector.
|
||||
type Connector struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ModifiedAt string `json:"modified_at"`
|
||||
Server *string `json:"server,omitempty"`
|
||||
AuthType *string `json:"auth_type,omitempty"`
|
||||
}
|
||||
|
||||
// CreateRequest creates a new connector.
|
||||
type CreateRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Server string `json:"server"`
|
||||
IconURL *string `json:"icon_url,omitempty"`
|
||||
Visibility *Visibility `json:"visibility,omitempty"`
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
AuthData *AuthData `json:"auth_data,omitempty"`
|
||||
SystemPrompt *string `json:"system_prompt,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateRequest updates an existing connector.
|
||||
type UpdateRequest struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
IconURL *string `json:"icon_url,omitempty"`
|
||||
SystemPrompt *string `json:"system_prompt,omitempty"`
|
||||
Server *string `json:"server,omitempty"`
|
||||
Headers map[string]string `json:"headers,omitempty"`
|
||||
AuthData *AuthData `json:"auth_data,omitempty"`
|
||||
ConnectionConfig map[string]any `json:"connection_config,omitempty"`
|
||||
ConnectionSecrets map[string]any `json:"connection_secrets,omitempty"`
|
||||
}
|
||||
|
||||
// AuthURLResponse is the response from getting a connector's OAuth URL.
|
||||
type AuthURLResponse struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
TTL int `json:"ttl"`
|
||||
}
|
||||
|
||||
// CallToolRequest is the request body for calling a connector tool.
|
||||
type CallToolRequest struct {
|
||||
Arguments map[string]any `json:"arguments,omitempty"`
|
||||
}
|
||||
|
||||
// CallToolResponse is the response from calling a connector tool.
|
||||
// Content is left as raw JSON because the upstream API returns a union
|
||||
// of 5 content types (text, image, audio, resource link, embedded resource).
|
||||
type CallToolResponse struct {
|
||||
Content json.RawMessage `json:"content"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// Tool represents a tool exposed by a connector.
|
||||
type Tool struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
Visibility Visibility `json:"visibility,omitempty"`
|
||||
CreatedAt string `json:"created_at,omitempty"`
|
||||
ModifiedAt string `json:"modified_at,omitempty"`
|
||||
SystemPrompt *string `json:"system_prompt,omitempty"`
|
||||
JsonSchema map[string]any `json:"jsonschema,omitempty"`
|
||||
Active *bool `json:"active,omitempty"`
|
||||
}
|
||||
|
||||
// ListParams holds query parameters for listing connectors.
|
||||
type ListParams struct {
|
||||
Page *int
|
||||
PageSize *int
|
||||
}
|
||||
|
||||
// ListToolsParams holds query parameters for listing connector tools.
|
||||
type ListToolsParams struct {
|
||||
Page *int
|
||||
PageSize *int
|
||||
Refresh *bool
|
||||
}
|
||||
6
connector/doc.go
Normal file
6
connector/doc.go
Normal file
@@ -0,0 +1,6 @@
|
||||
// Package connector provides types for the Mistral connectors API.
|
||||
//
|
||||
// Connectors represent MCP (Model Context Protocol) server integrations.
|
||||
// Use [CreateRequest] to register a new connector, then use tools
|
||||
// discovered via the list-tools endpoint in chat or agent completions.
|
||||
package connector
|
||||
119
connectors.go
Normal file
119
connectors.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/connector"
|
||||
)
|
||||
|
||||
// CreateConnector registers a new MCP connector.
|
||||
func (c *Client) CreateConnector(ctx context.Context, req *connector.CreateRequest) (*connector.Connector, error) {
|
||||
var resp connector.Connector
|
||||
if err := c.doJSON(ctx, "POST", "/v1/connectors", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListConnectors returns all connectors.
|
||||
func (c *Client) ListConnectors(ctx context.Context, params *connector.ListParams) ([]connector.Connector, error) {
|
||||
path := "/v1/connectors"
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp []connector.Connector
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// GetConnector retrieves a connector by ID or name.
|
||||
func (c *Client) GetConnector(ctx context.Context, idOrName string) (*connector.Connector, error) {
|
||||
var resp connector.Connector
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/connectors/%s", idOrName), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateConnector updates an existing connector.
|
||||
func (c *Client) UpdateConnector(ctx context.Context, idOrName string, req *connector.UpdateRequest) (*connector.Connector, error) {
|
||||
var resp connector.Connector
|
||||
if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/connectors/%s", idOrName), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteConnector deletes a connector.
|
||||
func (c *Client) DeleteConnector(ctx context.Context, idOrName string) error {
|
||||
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/connectors/%s", idOrName), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return parseAPIError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConnectorAuthURL returns the OAuth2 authorization URL for a connector.
|
||||
func (c *Client) GetConnectorAuthURL(ctx context.Context, idOrName string, appReturnURL *string) (*connector.AuthURLResponse, error) {
|
||||
path := fmt.Sprintf("/v1/connectors/%s/auth_url", idOrName)
|
||||
if appReturnURL != nil {
|
||||
path += "?app_return_url=" + url.QueryEscape(*appReturnURL)
|
||||
}
|
||||
var resp connector.AuthURLResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListConnectorTools lists tools exposed by a connector.
|
||||
func (c *Client) ListConnectorTools(ctx context.Context, idOrName string, params *connector.ListToolsParams) ([]connector.Tool, error) {
|
||||
path := fmt.Sprintf("/v1/connectors/%s/tools", idOrName)
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if params.Refresh != nil {
|
||||
q.Set("refresh", strconv.FormatBool(*params.Refresh))
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp []connector.Tool
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// CallConnectorTool invokes a tool on a connector.
|
||||
func (c *Client) CallConnectorTool(ctx context.Context, idOrName, toolName string, req *connector.CallToolRequest) (*connector.CallToolResponse, error) {
|
||||
var resp connector.CallToolResponse
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/connectors/%s/tools/%s/call", idOrName, toolName), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
217
connectors_test.go
Normal file
217
connectors_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/connector"
|
||||
)
|
||||
|
||||
func TestCreateConnector_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/v1/connectors" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["name"] != "my_connector" {
|
||||
t.Errorf("got name %v", body["name"])
|
||||
}
|
||||
if body["server"] != "https://mcp.example.com" {
|
||||
t.Errorf("got server %v", body["server"])
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "conn-1", "name": "my_connector",
|
||||
"description": "test", "created_at": "2025-01-01",
|
||||
"modified_at": "2025-01-01", "server": "https://mcp.example.com",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.CreateConnector(context.Background(), &connector.CreateRequest{
|
||||
Name: "my_connector",
|
||||
Description: "test",
|
||||
Server: "https://mcp.example.com",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "conn-1" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListConnectors_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
t.Errorf("expected GET, got %s", r.Method)
|
||||
}
|
||||
json.NewEncoder(w).Encode([]map[string]any{
|
||||
{"id": "c1", "name": "conn1", "description": "d1", "created_at": "t", "modified_at": "t"},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
list, err := client.ListConnectors(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(list) != 1 {
|
||||
t.Fatalf("got %d connectors", len(list))
|
||||
}
|
||||
if list[0].ID != "c1" {
|
||||
t.Errorf("got id %q", list[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetConnector_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/connectors/my_conn" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "c1", "name": "my_conn", "description": "d",
|
||||
"created_at": "t", "modified_at": "t",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
c, err := client.GetConnector(context.Background(), "my_conn")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if c.Name != "my_conn" {
|
||||
t.Errorf("got name %q", c.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateConnector_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "PATCH" {
|
||||
t.Errorf("expected PATCH, got %s", r.Method)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "c1", "name": "updated", "description": "new desc",
|
||||
"created_at": "t", "modified_at": "t",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
name := "updated"
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.UpdateConnector(context.Background(), "c1", &connector.UpdateRequest{
|
||||
Name: &name,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Name != "updated" {
|
||||
t.Errorf("got name %q", resp.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteConnector_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "DELETE" {
|
||||
t.Errorf("expected DELETE, got %s", r.Method)
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
err := client.DeleteConnector(context.Background(), "c1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetConnectorAuthURL_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/connectors/c1/auth_url" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"auth_url": "https://oauth.example.com/authorize",
|
||||
"ttl": 3600,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetConnectorAuthURL(context.Background(), "c1", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.AuthURL != "https://oauth.example.com/authorize" {
|
||||
t.Errorf("got auth_url %q", resp.AuthURL)
|
||||
}
|
||||
if resp.TTL != 3600 {
|
||||
t.Errorf("got ttl %d", resp.TTL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListConnectorTools_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/connectors/c1/tools" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode([]map[string]any{
|
||||
{"id": "t1", "name": "search", "description": "search the web"},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
tools, err := client.ListConnectorTools(context.Background(), "c1", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(tools) != 1 {
|
||||
t.Fatalf("got %d tools", len(tools))
|
||||
}
|
||||
if tools[0].Name != "search" {
|
||||
t.Errorf("got name %q", tools[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallConnectorTool_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/v1/connectors/c1/tools/search/call" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
args := body["arguments"].(map[string]any)
|
||||
if args["query"] != "hello" {
|
||||
t.Errorf("got query %v", args["query"])
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"content": []map[string]any{{"type": "text", "text": "result"}},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.CallConnectorTool(context.Background(), "c1", "search", &connector.CallToolRequest{
|
||||
Arguments: map[string]any{"query": "hello"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Content == nil {
|
||||
t.Error("expected non-nil content")
|
||||
}
|
||||
}
|
||||
6
doc.go
6
doc.go
@@ -39,9 +39,9 @@
|
||||
// # Sub-packages
|
||||
//
|
||||
// Types are organized into sub-packages by domain: [chat], [agents],
|
||||
// [conversation], [embedding], [model], [file], [finetune], [batch],
|
||||
// [ocr], [audio], [library], [moderation], [classification], and [fim].
|
||||
// All service methods live directly on [Client].
|
||||
// [connector], [conversation], [embedding], [model], [file], [finetune],
|
||||
// [batch], [ocr], [audio], [library], [moderation], [classification],
|
||||
// [fim], and [observability]. All service methods live directly on [Client].
|
||||
//
|
||||
// # Reference
|
||||
//
|
||||
|
||||
17
file/file.go
17
file/file.go
@@ -64,6 +64,23 @@ type SignedURL struct {
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// Visibility controls who can see a file.
|
||||
type Visibility string
|
||||
|
||||
const (
|
||||
VisibilitySharedGlobal Visibility = "shared_global"
|
||||
VisibilitySharedOrg Visibility = "shared_org"
|
||||
VisibilitySharedWorkspace Visibility = "shared_workspace"
|
||||
VisibilityPrivate Visibility = "private"
|
||||
)
|
||||
|
||||
// UploadParams holds parameters for uploading a file.
|
||||
type UploadParams struct {
|
||||
Purpose Purpose
|
||||
Expiry *int
|
||||
Visibility *Visibility
|
||||
}
|
||||
|
||||
// ListParams holds optional parameters for listing files.
|
||||
type ListParams struct {
|
||||
Page *int
|
||||
|
||||
14
files.go
14
files.go
@@ -12,10 +12,18 @@ import (
|
||||
)
|
||||
|
||||
// UploadFile uploads a file for use with fine-tuning, batch, or OCR.
|
||||
func (c *Client) UploadFile(ctx context.Context, filename string, r io.Reader, purpose file.Purpose) (*file.File, error) {
|
||||
func (c *Client) UploadFile(ctx context.Context, filename string, r io.Reader, params *file.UploadParams) (*file.File, error) {
|
||||
fields := map[string]string{}
|
||||
if purpose != "" {
|
||||
fields["purpose"] = string(purpose)
|
||||
if params != nil {
|
||||
if params.Purpose != "" {
|
||||
fields["purpose"] = string(params.Purpose)
|
||||
}
|
||||
if params.Expiry != nil {
|
||||
fields["expiry"] = strconv.Itoa(*params.Expiry)
|
||||
}
|
||||
if params.Visibility != nil {
|
||||
fields["visibility"] = string(*params.Visibility)
|
||||
}
|
||||
}
|
||||
var resp file.File
|
||||
if err := c.doMultipart(ctx, "/v1/files", filename, r, fields, &resp); err != nil {
|
||||
|
||||
@@ -58,7 +58,7 @@ func TestUploadFile_Success(t *testing.T) {
|
||||
context.Background(),
|
||||
"train.jsonl",
|
||||
strings.NewReader(`{"text":"hello"}`),
|
||||
file.PurposeFineTune,
|
||||
&file.UploadParams{Purpose: file.PurposeFineTune},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -74,6 +74,43 @@ func TestUploadFile_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadFile_WithExpiryAndVisibility(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseMultipartForm(10 << 20); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if r.FormValue("purpose") != "fine-tune" {
|
||||
t.Errorf("got purpose %q", r.FormValue("purpose"))
|
||||
}
|
||||
if r.FormValue("expiry") != "48" {
|
||||
t.Errorf("expected expiry=48, got %q", r.FormValue("expiry"))
|
||||
}
|
||||
if r.FormValue("visibility") != "private" {
|
||||
t.Errorf("expected visibility=private, got %q", r.FormValue("visibility"))
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "file-ev", "object": "file", "bytes": 10,
|
||||
"created_at": 1, "filename": "data.jsonl",
|
||||
"purpose": "fine-tune", "sample_type": "instruct",
|
||||
"source": "upload",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
expiry := 48
|
||||
vis := file.VisibilityPrivate
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
_, err := client.UploadFile(context.Background(), "data.jsonl", strings.NewReader("{}"), &file.UploadParams{
|
||||
Purpose: file.PurposeFineTune,
|
||||
Expiry: &expiry,
|
||||
Visibility: &vis,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListFiles_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
@@ -260,7 +297,7 @@ func TestUploadFile_Error(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
_, err := client.UploadFile(context.Background(), "bad.txt", strings.NewReader(""), file.PurposeFineTune)
|
||||
_, err := client.UploadFile(context.Background(), "bad.txt", strings.NewReader(""), &file.UploadParams{Purpose: file.PurposeFineTune})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ func integrationClient(t *testing.T) *Client {
|
||||
func TestIntegration_ListModels(t *testing.T) {
|
||||
client := integrationClient(t)
|
||||
|
||||
resp, err := client.ListModels(context.Background())
|
||||
resp, err := client.ListModels(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
)
|
||||
|
||||
// Version is the SDK version string.
|
||||
const Version = "1.0.0"
|
||||
const Version = "1.1.0"
|
||||
|
||||
const (
|
||||
defaultBaseURL = "https://api.mistral.ai"
|
||||
|
||||
@@ -52,3 +52,9 @@ type DeleteModelOut struct {
|
||||
Object string `json:"object"`
|
||||
Deleted bool `json:"deleted"`
|
||||
}
|
||||
|
||||
// ListParams holds optional parameters for listing models.
|
||||
type ListParams struct {
|
||||
Provider *string
|
||||
Model *string
|
||||
}
|
||||
|
||||
18
models.go
18
models.go
@@ -2,14 +2,28 @@ package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/model"
|
||||
)
|
||||
|
||||
// ListModels returns a list of available models.
|
||||
func (c *Client) ListModels(ctx context.Context) (*model.ModelList, error) {
|
||||
func (c *Client) ListModels(ctx context.Context, params *model.ListParams) (*model.ModelList, error) {
|
||||
path := "/v1/models"
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.Provider != nil {
|
||||
q.Set("provider", *params.Provider)
|
||||
}
|
||||
if params.Model != nil {
|
||||
q.Set("model", *params.Model)
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp model.ModelList
|
||||
if err := c.doJSON(ctx, "GET", "/v1/models", nil, &resp); err != nil {
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/model"
|
||||
)
|
||||
|
||||
func TestListModels_Success(t *testing.T) {
|
||||
@@ -45,7 +47,7 @@ func TestListModels_Success(t *testing.T) {
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
list, err := client.ListModels(context.Background())
|
||||
list, err := client.ListModels(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -82,6 +84,33 @@ func TestListModels_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListModels_WithParams(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Query().Get("provider") != "mistralai" {
|
||||
t.Errorf("expected provider=mistralai, got %q", r.URL.Query().Get("provider"))
|
||||
}
|
||||
if r.URL.Query().Get("model") != "mistral-small" {
|
||||
t.Errorf("expected model=mistral-small, got %q", r.URL.Query().Get("model"))
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"object": "list",
|
||||
"data": []map[string]any{},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := "mistralai"
|
||||
modelName := "mistral-small"
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
_, err := client.ListModels(context.Background(), &model.ListParams{
|
||||
Provider: &provider,
|
||||
Model: &modelName,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModel_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/models/mistral-small-latest" {
|
||||
|
||||
46
observability/campaign.go
Normal file
46
observability/campaign.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package observability
|
||||
|
||||
// Campaign represents an observability campaign.
|
||||
type Campaign struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
DeletedAt *string `json:"deleted_at,omitempty"`
|
||||
Name string `json:"name"`
|
||||
OwnerID string `json:"owner_id"`
|
||||
WorkspaceID string `json:"workspace_id"`
|
||||
Description string `json:"description"`
|
||||
MaxNbEvents int `json:"max_nb_events"`
|
||||
SearchParams FilterPayload `json:"search_params"`
|
||||
Judge Judge `json:"judge"`
|
||||
}
|
||||
|
||||
// CreateCampaignRequest creates a new campaign.
|
||||
type CreateCampaignRequest struct {
|
||||
SearchParams FilterPayload `json:"search_params"`
|
||||
JudgeID string `json:"judge_id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
MaxNbEvents int `json:"max_nb_events"`
|
||||
}
|
||||
|
||||
// CampaignStatusResponse is the response for campaign status.
|
||||
type CampaignStatusResponse struct {
|
||||
Status TaskStatus `json:"status"`
|
||||
}
|
||||
|
||||
// ListCampaignsResponse is the response from listing campaigns.
|
||||
type ListCampaignsResponse struct {
|
||||
Count int `json:"count"`
|
||||
Results []Campaign `json:"results,omitempty"`
|
||||
Next *string `json:"next,omitempty"`
|
||||
Previous *string `json:"previous,omitempty"`
|
||||
}
|
||||
|
||||
// ListCampaignEventsResponse is the response from listing campaign events.
|
||||
type ListCampaignEventsResponse struct {
|
||||
Count int `json:"count"`
|
||||
Results []ChatCompletionEventPreview `json:"results,omitempty"`
|
||||
Next *string `json:"next,omitempty"`
|
||||
Previous *string `json:"previous,omitempty"`
|
||||
}
|
||||
156
observability/dataset.go
Normal file
156
observability/dataset.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package observability
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// ConversationSource indicates how a dataset record was created.
|
||||
type ConversationSource string
|
||||
|
||||
const (
|
||||
SourceExplorer ConversationSource = "EXPLORER"
|
||||
SourceUploadedFile ConversationSource = "UPLOADED_FILE"
|
||||
SourceDirectInput ConversationSource = "DIRECT_INPUT"
|
||||
SourcePlayground ConversationSource = "PLAYGROUND"
|
||||
)
|
||||
|
||||
// Dataset represents a dataset entity.
|
||||
type Dataset struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
DeletedAt *string `json:"deleted_at,omitempty"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
OwnerID string `json:"owner_id"`
|
||||
WorkspaceID string `json:"workspace_id"`
|
||||
}
|
||||
|
||||
// CreateDatasetRequest creates a new dataset.
|
||||
type CreateDatasetRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// UpdateDatasetRequest updates a dataset.
|
||||
type UpdateDatasetRequest struct {
|
||||
Name *string `json:"name,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
}
|
||||
|
||||
// DatasetRecord is a single record in a dataset.
|
||||
type DatasetRecord struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
DeletedAt *string `json:"deleted_at,omitempty"`
|
||||
DatasetID string `json:"dataset_id"`
|
||||
Payload ConversationPayload `json:"payload"`
|
||||
Properties map[string]any `json:"properties,omitempty"`
|
||||
Source ConversationSource `json:"source"`
|
||||
}
|
||||
|
||||
// ConversationPayload holds the messages for a dataset record.
|
||||
type ConversationPayload struct {
|
||||
Messages []map[string]any `json:"messages"`
|
||||
}
|
||||
|
||||
// CreateRecordRequest creates a new dataset record.
|
||||
type CreateRecordRequest struct {
|
||||
Payload ConversationPayload `json:"payload"`
|
||||
Properties map[string]any `json:"properties"`
|
||||
}
|
||||
|
||||
// UpdateRecordPayloadRequest updates a record's payload.
|
||||
type UpdateRecordPayloadRequest struct {
|
||||
Payload ConversationPayload `json:"payload"`
|
||||
}
|
||||
|
||||
// UpdateRecordPropertiesRequest updates a record's properties.
|
||||
type UpdateRecordPropertiesRequest struct {
|
||||
Properties map[string]any `json:"properties"`
|
||||
}
|
||||
|
||||
// BulkDeleteRecordsRequest deletes multiple records.
|
||||
type BulkDeleteRecordsRequest struct {
|
||||
DatasetRecordIDs []string `json:"dataset_record_ids"`
|
||||
}
|
||||
|
||||
// JudgeRecordRequest judges a dataset record.
|
||||
type JudgeRecordRequest struct {
|
||||
JudgeDefinition CreateJudgeRequest `json:"judge_definition"`
|
||||
}
|
||||
|
||||
// DatasetImportTask tracks an async import operation.
|
||||
type DatasetImportTask struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
DeletedAt *string `json:"deleted_at,omitempty"`
|
||||
CreatorID string `json:"creator_id"`
|
||||
DatasetID string `json:"dataset_id"`
|
||||
WorkspaceID string `json:"workspace_id"`
|
||||
Status TaskStatus `json:"status"`
|
||||
Progress *int `json:"progress,omitempty"`
|
||||
Message *string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// ExportDatasetResponse is the response from exporting a dataset.
|
||||
type ExportDatasetResponse struct {
|
||||
FileURL string `json:"file_url"`
|
||||
}
|
||||
|
||||
// Import request types.
|
||||
|
||||
// ImportFromCampaignRequest imports records from a campaign.
|
||||
type ImportFromCampaignRequest struct {
|
||||
CampaignID string `json:"campaign_id"`
|
||||
}
|
||||
|
||||
// ImportFromExplorerRequest imports records from explorer events.
|
||||
type ImportFromExplorerRequest struct {
|
||||
CompletionEventIDs []string `json:"completion_event_ids"`
|
||||
}
|
||||
|
||||
// ImportFromFileRequest imports records from a file.
|
||||
type ImportFromFileRequest struct {
|
||||
FileID string `json:"file_id"`
|
||||
}
|
||||
|
||||
// ImportFromPlaygroundRequest imports records from playground conversations.
|
||||
type ImportFromPlaygroundRequest struct {
|
||||
ConversationIDs []string `json:"conversation_ids"`
|
||||
}
|
||||
|
||||
// ImportFromDatasetRequest imports records from another dataset.
|
||||
type ImportFromDatasetRequest struct {
|
||||
DatasetRecordIDs []string `json:"dataset_record_ids"`
|
||||
}
|
||||
|
||||
// List response types.
|
||||
|
||||
// ListDatasetsResponse is the response from listing datasets.
|
||||
type ListDatasetsResponse struct {
|
||||
Count int `json:"count"`
|
||||
Results []Dataset `json:"results,omitempty"`
|
||||
Next *string `json:"next,omitempty"`
|
||||
Previous *string `json:"previous,omitempty"`
|
||||
}
|
||||
|
||||
// ListRecordsResponse is the response from listing dataset records.
|
||||
type ListRecordsResponse struct {
|
||||
Count int `json:"count"`
|
||||
Results []DatasetRecord `json:"results,omitempty"`
|
||||
Next *string `json:"next,omitempty"`
|
||||
Previous *string `json:"previous,omitempty"`
|
||||
}
|
||||
|
||||
// ListTasksResponse is the response from listing import tasks.
|
||||
type ListTasksResponse struct {
|
||||
Count int `json:"count"`
|
||||
Results []DatasetImportTask `json:"results,omitempty"`
|
||||
Next *string `json:"next,omitempty"`
|
||||
Previous *string `json:"previous,omitempty"`
|
||||
}
|
||||
|
||||
// JudgeResultResponse is the raw response from judging operations.
|
||||
// The shape depends on the judge type (classification or regression).
|
||||
type JudgeResultResponse json.RawMessage
|
||||
5
observability/doc.go
Normal file
5
observability/doc.go
Normal file
@@ -0,0 +1,5 @@
|
||||
// Package observability provides types for the Mistral observability API (beta).
|
||||
//
|
||||
// This includes campaigns, chat completion events, judges, and datasets
|
||||
// for monitoring and evaluating model behavior.
|
||||
package observability
|
||||
70
observability/event.go
Normal file
70
observability/event.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package observability
|
||||
|
||||
// ChatCompletionEvent is a full chat completion event.
|
||||
type ChatCompletionEvent struct {
|
||||
EventID string `json:"event_id"`
|
||||
CorrelationID string `json:"correlation_id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ExtraFields map[string]any `json:"extra_fields,omitempty"`
|
||||
NbInputTokens int `json:"nb_input_tokens"`
|
||||
NbOutputTokens int `json:"nb_output_tokens"`
|
||||
EnabledTools []map[string]any `json:"enabled_tools,omitempty"`
|
||||
RequestMessages []map[string]any `json:"request_messages,omitempty"`
|
||||
ResponseMessages []map[string]any `json:"response_messages,omitempty"`
|
||||
NbMessages int `json:"nb_messages"`
|
||||
ChatTranscriptionEvents []ChatTranscriptionEvent `json:"chat_transcription_events,omitempty"`
|
||||
}
|
||||
|
||||
// ChatCompletionEventPreview is a summary of a chat completion event.
|
||||
type ChatCompletionEventPreview struct {
|
||||
EventID string `json:"event_id"`
|
||||
CorrelationID string `json:"correlation_id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
ExtraFields map[string]any `json:"extra_fields,omitempty"`
|
||||
NbInputTokens int `json:"nb_input_tokens"`
|
||||
NbOutputTokens int `json:"nb_output_tokens"`
|
||||
}
|
||||
|
||||
// ChatTranscriptionEvent is an audio transcription within a chat event.
|
||||
type ChatTranscriptionEvent struct {
|
||||
AudioURL string `json:"audio_url"`
|
||||
Model string `json:"model"`
|
||||
ResponseMessage map[string]any `json:"response_message"`
|
||||
}
|
||||
|
||||
// SearchEventsRequest is the request body for searching chat completion events.
|
||||
type SearchEventsRequest struct {
|
||||
SearchParams FilterPayload `json:"search_params"`
|
||||
ExtraFields []string `json:"extra_fields,omitempty"`
|
||||
}
|
||||
|
||||
// SearchEventsResponse is the response from searching events.
|
||||
type SearchEventsResponse struct {
|
||||
Results []ChatCompletionEventPreview `json:"results,omitempty"`
|
||||
Next *string `json:"next,omitempty"`
|
||||
Cursor *string `json:"cursor,omitempty"`
|
||||
}
|
||||
|
||||
// SearchEventIDsRequest is the request body for searching event IDs.
|
||||
type SearchEventIDsRequest struct {
|
||||
SearchParams FilterPayload `json:"search_params"`
|
||||
ExtraFields []string `json:"extra_fields,omitempty"`
|
||||
}
|
||||
|
||||
// SearchEventIDsResponse is the response from searching event IDs.
|
||||
type SearchEventIDsResponse struct {
|
||||
CompletionEventIDs []string `json:"completion_event_ids"`
|
||||
}
|
||||
|
||||
// JudgeEventRequest is the request body for judging a chat completion event.
|
||||
type JudgeEventRequest struct {
|
||||
JudgeDefinition CreateJudgeRequest `json:"judge_definition"`
|
||||
}
|
||||
|
||||
// SimilarEventsResponse is the response from fetching similar events.
|
||||
type SimilarEventsResponse struct {
|
||||
Count int `json:"count"`
|
||||
Results []ChatCompletionEventPreview `json:"results,omitempty"`
|
||||
Next *string `json:"next,omitempty"`
|
||||
Previous *string `json:"previous,omitempty"`
|
||||
}
|
||||
75
observability/filter.go
Normal file
75
observability/filter.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package observability
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// Op is a filter comparison operator.
|
||||
type Op string
|
||||
|
||||
const (
|
||||
OpLt Op = "lt"
|
||||
OpLte Op = "lte"
|
||||
OpGt Op = "gt"
|
||||
OpGte Op = "gte"
|
||||
OpEq Op = "eq"
|
||||
OpNeq Op = "neq"
|
||||
OpIsNull Op = "isnull"
|
||||
OpStartsWith Op = "startswith"
|
||||
OpIStartsWith Op = "istartswith"
|
||||
OpEndsWith Op = "endswith"
|
||||
OpIEndsWith Op = "iendswith"
|
||||
OpContains Op = "contains"
|
||||
OpIContains Op = "icontains"
|
||||
OpMatches Op = "matches"
|
||||
OpNotContains Op = "notcontains"
|
||||
OpINotContains Op = "inotcontains"
|
||||
OpIncludes Op = "includes"
|
||||
OpExcludes Op = "excludes"
|
||||
OpLenEq Op = "len_eq"
|
||||
)
|
||||
|
||||
// FilterCondition is a single filter comparison.
|
||||
type FilterCondition struct {
|
||||
Field string `json:"field"`
|
||||
Op Op `json:"op"`
|
||||
Value any `json:"value"`
|
||||
}
|
||||
|
||||
// FilterGroup combines filters with AND/OR logic.
|
||||
// The JSON keys are uppercase "AND" / "OR".
|
||||
type FilterGroup struct {
|
||||
AND []json.RawMessage `json:"AND,omitempty"`
|
||||
OR []json.RawMessage `json:"OR,omitempty"`
|
||||
}
|
||||
|
||||
// FilterPayload wraps the top-level filter for search operations.
|
||||
// Filters can be a FilterGroup or a FilterCondition.
|
||||
type FilterPayload struct {
|
||||
Filters json.RawMessage `json:"filters,omitempty"`
|
||||
}
|
||||
|
||||
// TaskStatus is the status of an async task.
|
||||
type TaskStatus string
|
||||
|
||||
const (
|
||||
TaskStatusRunning TaskStatus = "RUNNING"
|
||||
TaskStatusCompleted TaskStatus = "COMPLETED"
|
||||
TaskStatusFailed TaskStatus = "FAILED"
|
||||
TaskStatusCanceled TaskStatus = "CANCELED"
|
||||
TaskStatusTerminated TaskStatus = "TERMINATED"
|
||||
TaskStatusContinuedAsNew TaskStatus = "CONTINUED_AS_NEW"
|
||||
TaskStatusTimedOut TaskStatus = "TIMED_OUT"
|
||||
TaskStatusUnknown TaskStatus = "UNKNOWN"
|
||||
)
|
||||
|
||||
// PaginationParams holds common pagination query parameters.
|
||||
type PaginationParams struct {
|
||||
Page *int
|
||||
PageSize *int
|
||||
}
|
||||
|
||||
// SearchParams holds common search query parameters.
|
||||
type SearchParams struct {
|
||||
Page *int
|
||||
PageSize *int
|
||||
Q *string
|
||||
}
|
||||
114
observability/judge.go
Normal file
114
observability/judge.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package observability
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// JudgeOutputType identifies the kind of judge output.
|
||||
type JudgeOutputType string
|
||||
|
||||
const (
|
||||
JudgeOutputClassification JudgeOutputType = "CLASSIFICATION"
|
||||
JudgeOutputRegression JudgeOutputType = "REGRESSION"
|
||||
)
|
||||
|
||||
// JudgeOutput is a sealed interface for judge output configurations.
|
||||
type JudgeOutput interface {
|
||||
judgeOutputType() JudgeOutputType
|
||||
}
|
||||
|
||||
// ClassificationOutput configures a classification judge.
|
||||
type ClassificationOutput struct {
|
||||
Type JudgeOutputType `json:"type"`
|
||||
Options []ClassificationOption `json:"options"`
|
||||
}
|
||||
|
||||
func (*ClassificationOutput) judgeOutputType() JudgeOutputType { return JudgeOutputClassification }
|
||||
|
||||
// ClassificationOption is a single option for a classification judge.
|
||||
type ClassificationOption struct {
|
||||
Value string `json:"value"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// RegressionOutput configures a regression judge.
|
||||
type RegressionOutput struct {
|
||||
Type JudgeOutputType `json:"type"`
|
||||
MinDescription string `json:"min_description"`
|
||||
MaxDescription string `json:"max_description"`
|
||||
Min *float64 `json:"min,omitempty"`
|
||||
Max *float64 `json:"max,omitempty"`
|
||||
}
|
||||
|
||||
func (*RegressionOutput) judgeOutputType() JudgeOutputType { return JudgeOutputRegression }
|
||||
|
||||
// UnmarshalJudgeOutput dispatches to the concrete JudgeOutput type.
|
||||
func UnmarshalJudgeOutput(data []byte) (JudgeOutput, error) {
|
||||
var probe struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if err := json.Unmarshal(data, &probe); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal judge output: %w", err)
|
||||
}
|
||||
switch JudgeOutputType(probe.Type) {
|
||||
case JudgeOutputClassification:
|
||||
var o ClassificationOutput
|
||||
return &o, json.Unmarshal(data, &o)
|
||||
case JudgeOutputRegression:
|
||||
var o RegressionOutput
|
||||
return &o, json.Unmarshal(data, &o)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown judge output type: %q", probe.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// Judge represents a judge entity.
|
||||
type Judge struct {
|
||||
ID string `json:"id"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
DeletedAt *string `json:"deleted_at,omitempty"`
|
||||
OwnerID string `json:"owner_id"`
|
||||
WorkspaceID string `json:"workspace_id"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
ModelName string `json:"model_name"`
|
||||
Output json.RawMessage `json:"output"`
|
||||
Instructions string `json:"instructions"`
|
||||
Tools []string `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
// CreateJudgeRequest creates a new judge.
|
||||
type CreateJudgeRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
ModelName string `json:"model_name"`
|
||||
Output json.RawMessage `json:"output"`
|
||||
Instructions string `json:"instructions"`
|
||||
Tools []string `json:"tools"`
|
||||
}
|
||||
|
||||
// UpdateJudgeRequest updates a judge.
|
||||
type UpdateJudgeRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
ModelName string `json:"model_name"`
|
||||
Output json.RawMessage `json:"output"`
|
||||
Instructions string `json:"instructions"`
|
||||
Tools []string `json:"tools"`
|
||||
}
|
||||
|
||||
// JudgeConversationRequest is the request for live-judging a conversation.
|
||||
type JudgeConversationRequest struct {
|
||||
Messages []map[string]any `json:"messages"`
|
||||
Properties map[string]any `json:"properties,omitempty"`
|
||||
}
|
||||
|
||||
// ListJudgesResponse is the response from listing judges.
|
||||
type ListJudgesResponse struct {
|
||||
Count int `json:"count"`
|
||||
Results []Judge `json:"results,omitempty"`
|
||||
Next *string `json:"next,omitempty"`
|
||||
Previous *string `json:"previous,omitempty"`
|
||||
}
|
||||
97
observability_campaigns.go
Normal file
97
observability_campaigns.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/observability"
|
||||
)
|
||||
|
||||
// CreateCampaign creates a new observability campaign.
|
||||
func (c *Client) CreateCampaign(ctx context.Context, req *observability.CreateCampaignRequest) (*observability.Campaign, error) {
|
||||
var resp observability.Campaign
|
||||
if err := c.doJSON(ctx, "POST", "/v1/observability/campaigns", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListCampaigns lists observability campaigns.
|
||||
func (c *Client) ListCampaigns(ctx context.Context, params *observability.SearchParams) (*observability.ListCampaignsResponse, error) {
|
||||
path := "/v1/observability/campaigns"
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if params.Q != nil {
|
||||
q.Set("q", *params.Q)
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp observability.ListCampaignsResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetCampaign retrieves a campaign by ID.
|
||||
func (c *Client) GetCampaign(ctx context.Context, campaignID string) (*observability.Campaign, error) {
|
||||
var resp observability.Campaign
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/campaigns/%s", campaignID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteCampaign deletes a campaign.
|
||||
func (c *Client) DeleteCampaign(ctx context.Context, campaignID string) error {
|
||||
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/campaigns/%s", campaignID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return parseAPIError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCampaignStatus retrieves the status of a campaign.
|
||||
func (c *Client) GetCampaignStatus(ctx context.Context, campaignID string) (*observability.CampaignStatusResponse, error) {
|
||||
var resp observability.CampaignStatusResponse
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/campaigns/%s/status", campaignID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListCampaignEvents lists events selected by a campaign.
|
||||
func (c *Client) ListCampaignEvents(ctx context.Context, campaignID string, params *observability.PaginationParams) (*observability.ListCampaignEventsResponse, error) {
|
||||
path := fmt.Sprintf("/v1/observability/campaigns/%s/selected-events", campaignID)
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp observability.ListCampaignEventsResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
148
observability_campaigns_test.go
Normal file
148
observability_campaigns_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/observability"
|
||||
)
|
||||
|
||||
func TestCreateCampaign_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" || r.URL.Path != "/v1/observability/campaigns" {
|
||||
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["name"] != "test-campaign" {
|
||||
t.Errorf("got name %v", body["name"])
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "camp-1", "name": "test-campaign", "description": "d",
|
||||
"created_at": "t", "updated_at": "t", "owner_id": "o",
|
||||
"workspace_id": "w", "max_nb_events": 100,
|
||||
"search_params": map[string]any{}, "judge": map[string]any{"id": "j1"},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.CreateCampaign(context.Background(), &observability.CreateCampaignRequest{
|
||||
Name: "test-campaign",
|
||||
Description: "d",
|
||||
JudgeID: "j1",
|
||||
MaxNbEvents: 100,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "camp-1" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListCampaigns_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/campaigns" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"count": 1,
|
||||
"results": []map[string]any{{"id": "c1", "name": "c", "description": "d", "created_at": "t", "updated_at": "t", "owner_id": "o", "workspace_id": "w", "max_nb_events": 10, "search_params": map[string]any{}, "judge": map[string]any{"id": "j"}}},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ListCampaigns(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Count != 1 {
|
||||
t.Errorf("got count %d", resp.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCampaign_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/campaigns/camp-1" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "camp-1", "name": "c", "description": "d",
|
||||
"created_at": "t", "updated_at": "t", "owner_id": "o",
|
||||
"workspace_id": "w", "max_nb_events": 10,
|
||||
"search_params": map[string]any{}, "judge": map[string]any{"id": "j"},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetCampaign(context.Background(), "camp-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "camp-1" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteCampaign_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "DELETE" {
|
||||
t.Errorf("expected DELETE")
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
if err := client.DeleteCampaign(context.Background(), "camp-1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCampaignStatus_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/campaigns/camp-1/status" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{"status": "COMPLETED"})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetCampaignStatus(context.Background(), "camp-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Status != observability.TaskStatusCompleted {
|
||||
t.Errorf("got status %q", resp.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListCampaignEvents_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/campaigns/camp-1/selected-events" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"count": 0,
|
||||
"results": []any{},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ListCampaignEvents(context.Background(), "camp-1", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Count != 0 {
|
||||
t.Errorf("got count %d", resp.Count)
|
||||
}
|
||||
}
|
||||
252
observability_datasets.go
Normal file
252
observability_datasets.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/observability"
|
||||
)
|
||||
|
||||
// CreateDataset creates a new observability dataset.
|
||||
func (c *Client) CreateDataset(ctx context.Context, req *observability.CreateDatasetRequest) (*observability.Dataset, error) {
|
||||
var resp observability.Dataset
|
||||
if err := c.doJSON(ctx, "POST", "/v1/observability/datasets", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListDatasets lists observability datasets.
|
||||
func (c *Client) ListDatasets(ctx context.Context, params *observability.SearchParams) (*observability.ListDatasetsResponse, error) {
|
||||
path := "/v1/observability/datasets"
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if params.Q != nil {
|
||||
q.Set("q", *params.Q)
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp observability.ListDatasetsResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetDataset retrieves a dataset by ID.
|
||||
func (c *Client) GetDataset(ctx context.Context, datasetID string) (*observability.Dataset, error) {
|
||||
var resp observability.Dataset
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateDataset updates a dataset.
|
||||
func (c *Client) UpdateDataset(ctx context.Context, datasetID string, req *observability.UpdateDatasetRequest) (*observability.Dataset, error) {
|
||||
var resp observability.Dataset
|
||||
if err := c.doJSON(ctx, "PATCH", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteDataset deletes a dataset.
|
||||
func (c *Client) DeleteDataset(ctx context.Context, datasetID string) error {
|
||||
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/datasets/%s", datasetID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return parseAPIError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExportDatasetToJSONL exports a dataset to JSONL format.
|
||||
func (c *Client) ExportDatasetToJSONL(ctx context.Context, datasetID string) (*observability.ExportDatasetResponse, error) {
|
||||
var resp observability.ExportDatasetResponse
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s/exports/to-jsonl", datasetID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// Dataset records
|
||||
|
||||
// ListDatasetRecords lists records in a dataset.
|
||||
func (c *Client) ListDatasetRecords(ctx context.Context, datasetID string, params *observability.PaginationParams) (*observability.ListRecordsResponse, error) {
|
||||
path := fmt.Sprintf("/v1/observability/datasets/%s/records", datasetID)
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp observability.ListRecordsResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// CreateDatasetRecord creates a record in a dataset.
|
||||
func (c *Client) CreateDatasetRecord(ctx context.Context, datasetID string, req *observability.CreateRecordRequest) (*observability.DatasetRecord, error) {
|
||||
var resp observability.DatasetRecord
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/records", datasetID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetDatasetRecord retrieves a dataset record by ID.
|
||||
func (c *Client) GetDatasetRecord(ctx context.Context, recordID string) (*observability.DatasetRecord, error) {
|
||||
var resp observability.DatasetRecord
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/dataset-records/%s", recordID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateDatasetRecordPayload updates a record's payload.
|
||||
func (c *Client) UpdateDatasetRecordPayload(ctx context.Context, recordID string, req *observability.UpdateRecordPayloadRequest) (*observability.DatasetRecord, error) {
|
||||
var resp observability.DatasetRecord
|
||||
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/dataset-records/%s/payload", recordID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateDatasetRecordProperties updates a record's properties.
|
||||
func (c *Client) UpdateDatasetRecordProperties(ctx context.Context, recordID string, req *observability.UpdateRecordPropertiesRequest) (*observability.DatasetRecord, error) {
|
||||
var resp observability.DatasetRecord
|
||||
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/dataset-records/%s/properties", recordID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteDatasetRecord deletes a dataset record.
|
||||
func (c *Client) DeleteDatasetRecord(ctx context.Context, recordID string) error {
|
||||
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/dataset-records/%s", recordID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return parseAPIError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BulkDeleteDatasetRecords deletes multiple dataset records.
|
||||
func (c *Client) BulkDeleteDatasetRecords(ctx context.Context, req *observability.BulkDeleteRecordsRequest) error {
|
||||
return c.doJSON(ctx, "POST", "/v1/observability/dataset-records/bulk-delete", req, nil)
|
||||
}
|
||||
|
||||
// JudgeDatasetRecord judges a dataset record.
|
||||
func (c *Client) JudgeDatasetRecord(ctx context.Context, recordID string, req *observability.JudgeRecordRequest) (json.RawMessage, error) {
|
||||
var resp json.RawMessage
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/dataset-records/%s/live-judging", recordID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Import operations
|
||||
|
||||
// ImportDatasetFromCampaign imports records from a campaign.
|
||||
func (c *Client) ImportDatasetFromCampaign(ctx context.Context, datasetID string, req *observability.ImportFromCampaignRequest) (*observability.DatasetImportTask, error) {
|
||||
var resp observability.DatasetImportTask
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-campaign", datasetID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ImportDatasetFromExplorer imports records from explorer events.
|
||||
func (c *Client) ImportDatasetFromExplorer(ctx context.Context, datasetID string, req *observability.ImportFromExplorerRequest) (*observability.DatasetImportTask, error) {
|
||||
var resp observability.DatasetImportTask
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-explorer", datasetID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ImportDatasetFromFile imports records from a file.
|
||||
func (c *Client) ImportDatasetFromFile(ctx context.Context, datasetID string, req *observability.ImportFromFileRequest) (*observability.DatasetImportTask, error) {
|
||||
var resp observability.DatasetImportTask
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-file", datasetID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ImportDatasetFromPlayground imports records from playground conversations.
|
||||
func (c *Client) ImportDatasetFromPlayground(ctx context.Context, datasetID string, req *observability.ImportFromPlaygroundRequest) (*observability.DatasetImportTask, error) {
|
||||
var resp observability.DatasetImportTask
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-playground", datasetID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ImportDatasetFromDataset imports records from another dataset.
|
||||
func (c *Client) ImportDatasetFromDataset(ctx context.Context, datasetID string, req *observability.ImportFromDatasetRequest) (*observability.DatasetImportTask, error) {
|
||||
var resp observability.DatasetImportTask
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/datasets/%s/imports/from-dataset", datasetID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// Tasks
|
||||
|
||||
// ListDatasetTasks lists import tasks for a dataset.
|
||||
func (c *Client) ListDatasetTasks(ctx context.Context, datasetID string, params *observability.PaginationParams) (*observability.ListTasksResponse, error) {
|
||||
path := fmt.Sprintf("/v1/observability/datasets/%s/tasks", datasetID)
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp observability.ListTasksResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetDatasetTask retrieves an import task by ID.
|
||||
func (c *Client) GetDatasetTask(ctx context.Context, datasetID, taskID string) (*observability.DatasetImportTask, error) {
|
||||
var resp observability.DatasetImportTask
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/datasets/%s/tasks/%s", datasetID, taskID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
211
observability_datasets_test.go
Normal file
211
observability_datasets_test.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/observability"
|
||||
)
|
||||
|
||||
func datasetJSON() map[string]any {
|
||||
return map[string]any{
|
||||
"id": "ds-1", "created_at": "t", "updated_at": "t",
|
||||
"name": "test-ds", "description": "d",
|
||||
"owner_id": "o", "workspace_id": "w",
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateDataset_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" || r.URL.Path != "/v1/observability/datasets" {
|
||||
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(datasetJSON())
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.CreateDataset(context.Background(), &observability.CreateDatasetRequest{
|
||||
Name: "test-ds",
|
||||
Description: "d",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "ds-1" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDatasets_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"count": 1,
|
||||
"results": []any{datasetJSON()},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ListDatasets(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Count != 1 {
|
||||
t.Errorf("got count %d", resp.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDataset_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/datasets/ds-1" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(datasetJSON())
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetDataset(context.Background(), "ds-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Name != "test-ds" {
|
||||
t.Errorf("got name %q", resp.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteDataset_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
if err := client.DeleteDataset(context.Background(), "ds-1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateDatasetRecord_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" || r.URL.Path != "/v1/observability/datasets/ds-1/records" {
|
||||
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "rec-1", "created_at": "t", "updated_at": "t",
|
||||
"dataset_id": "ds-1", "source": "DIRECT_INPUT",
|
||||
"payload": map[string]any{"messages": []any{}},
|
||||
"properties": map[string]any{},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.CreateDatasetRecord(context.Background(), "ds-1", &observability.CreateRecordRequest{
|
||||
Payload: observability.ConversationPayload{Messages: []map[string]any{}},
|
||||
Properties: map[string]any{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "rec-1" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListDatasetRecords_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/datasets/ds-1/records" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"count": 0, "results": []any{},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ListDatasetRecords(context.Background(), "ds-1", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Count != 0 {
|
||||
t.Errorf("got count %d", resp.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportDatasetFromCampaign_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/datasets/ds-1/imports/from-campaign" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "task-1", "created_at": "t", "updated_at": "t",
|
||||
"creator_id": "u", "dataset_id": "ds-1", "workspace_id": "w",
|
||||
"status": "RUNNING",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ImportDatasetFromCampaign(context.Background(), "ds-1", &observability.ImportFromCampaignRequest{
|
||||
CampaignID: "camp-1",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Status != observability.TaskStatusRunning {
|
||||
t.Errorf("got status %q", resp.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportDatasetToJSONL_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/datasets/ds-1/exports/to-jsonl" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"file_url": "https://storage.example.com/export.jsonl",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ExportDatasetToJSONL(context.Background(), "ds-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.FileURL != "https://storage.example.com/export.jsonl" {
|
||||
t.Errorf("got file_url %q", resp.FileURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDatasetTask_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/datasets/ds-1/tasks/task-1" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "task-1", "created_at": "t", "updated_at": "t",
|
||||
"creator_id": "u", "dataset_id": "ds-1", "workspace_id": "w",
|
||||
"status": "COMPLETED",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetDatasetTask(context.Background(), "ds-1", "task-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Status != observability.TaskStatusCompleted {
|
||||
t.Errorf("got status %q", resp.Status)
|
||||
}
|
||||
}
|
||||
69
observability_events.go
Normal file
69
observability_events.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/observability"
|
||||
)
|
||||
|
||||
// SearchChatCompletionEvents searches for chat completion events.
|
||||
func (c *Client) SearchChatCompletionEvents(ctx context.Context, req *observability.SearchEventsRequest) (*observability.SearchEventsResponse, error) {
|
||||
var resp observability.SearchEventsResponse
|
||||
if err := c.doJSON(ctx, "POST", "/v1/observability/chat-completion-events/search", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// SearchChatCompletionEventIDs searches for chat completion event IDs.
|
||||
func (c *Client) SearchChatCompletionEventIDs(ctx context.Context, req *observability.SearchEventIDsRequest) (*observability.SearchEventIDsResponse, error) {
|
||||
var resp observability.SearchEventIDsResponse
|
||||
if err := c.doJSON(ctx, "POST", "/v1/observability/chat-completion-events/search-ids", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetChatCompletionEvent retrieves a chat completion event by ID.
|
||||
func (c *Client) GetChatCompletionEvent(ctx context.Context, eventID string) (*observability.ChatCompletionEvent, error) {
|
||||
var resp observability.ChatCompletionEvent
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/chat-completion-events/%s", eventID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetSimilarChatCompletionEvents retrieves events similar to a given event.
|
||||
func (c *Client) GetSimilarChatCompletionEvents(ctx context.Context, eventID string, params *observability.PaginationParams) (*observability.SimilarEventsResponse, error) {
|
||||
path := fmt.Sprintf("/v1/observability/chat-completion-events/%s/similar-events", eventID)
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp observability.SimilarEventsResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// JudgeChatCompletionEvent judges a chat completion event.
|
||||
func (c *Client) JudgeChatCompletionEvent(ctx context.Context, eventID string, req *observability.JudgeEventRequest) (json.RawMessage, error) {
|
||||
var resp json.RawMessage
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/chat-completion-events/%s/live-judging", eventID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
101
observability_events_test.go
Normal file
101
observability_events_test.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/observability"
|
||||
)
|
||||
|
||||
func TestSearchChatCompletionEvents_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" || r.URL.Path != "/v1/observability/chat-completion-events/search" {
|
||||
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"results": []map[string]any{
|
||||
{"event_id": "ev-1", "correlation_id": "c1", "created_at": "t", "nb_input_tokens": 10, "nb_output_tokens": 5},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.SearchChatCompletionEvents(context.Background(), &observability.SearchEventsRequest{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(resp.Results) != 1 {
|
||||
t.Fatalf("got %d results", len(resp.Results))
|
||||
}
|
||||
if resp.Results[0].EventID != "ev-1" {
|
||||
t.Errorf("got event_id %q", resp.Results[0].EventID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchChatCompletionEventIDs_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/chat-completion-events/search-ids" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"completion_event_ids": []string{"ev-1", "ev-2"},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.SearchChatCompletionEventIDs(context.Background(), &observability.SearchEventIDsRequest{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(resp.CompletionEventIDs) != 2 {
|
||||
t.Errorf("got %d ids", len(resp.CompletionEventIDs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetChatCompletionEvent_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/chat-completion-events/ev-1" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"event_id": "ev-1", "correlation_id": "c1", "created_at": "t",
|
||||
"nb_input_tokens": 10, "nb_output_tokens": 5, "nb_messages": 2,
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetChatCompletionEvent(context.Background(), "ev-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.EventID != "ev-1" {
|
||||
t.Errorf("got event_id %q", resp.EventID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSimilarChatCompletionEvents_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/chat-completion-events/ev-1/similar-events" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"count": 0, "results": []any{},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetSimilarChatCompletionEvents(context.Background(), "ev-1", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Count != 0 {
|
||||
t.Errorf("got count %d", resp.Count)
|
||||
}
|
||||
}
|
||||
85
observability_judges.go
Normal file
85
observability_judges.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/observability"
|
||||
)
|
||||
|
||||
// CreateJudge creates a new observability judge.
|
||||
func (c *Client) CreateJudge(ctx context.Context, req *observability.CreateJudgeRequest) (*observability.Judge, error) {
|
||||
var resp observability.Judge
|
||||
if err := c.doJSON(ctx, "POST", "/v1/observability/judges", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListJudges lists observability judges.
|
||||
func (c *Client) ListJudges(ctx context.Context, params *observability.SearchParams) (*observability.ListJudgesResponse, error) {
|
||||
path := "/v1/observability/judges"
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.PageSize != nil {
|
||||
q.Set("page_size", strconv.Itoa(*params.PageSize))
|
||||
}
|
||||
if params.Page != nil {
|
||||
q.Set("page", strconv.Itoa(*params.Page))
|
||||
}
|
||||
if params.Q != nil {
|
||||
q.Set("q", *params.Q)
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp observability.ListJudgesResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetJudge retrieves a judge by ID.
|
||||
func (c *Client) GetJudge(ctx context.Context, judgeID string) (*observability.Judge, error) {
|
||||
var resp observability.Judge
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/observability/judges/%s", judgeID), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateJudge updates a judge.
|
||||
func (c *Client) UpdateJudge(ctx context.Context, judgeID string, req *observability.UpdateJudgeRequest) (*observability.Judge, error) {
|
||||
var resp observability.Judge
|
||||
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/observability/judges/%s", judgeID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// DeleteJudge deletes a judge.
|
||||
func (c *Client) DeleteJudge(ctx context.Context, judgeID string) error {
|
||||
resp, err := c.do(ctx, "DELETE", fmt.Sprintf("/v1/observability/judges/%s", judgeID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode >= 400 {
|
||||
return parseAPIError(resp)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// JudgeConversation performs live judging on a conversation.
|
||||
func (c *Client) JudgeConversation(ctx context.Context, judgeID string, req *observability.JudgeConversationRequest) (json.RawMessage, error) {
|
||||
var resp json.RawMessage
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/observability/judges/%s/live-judging", judgeID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
123
observability_judges_test.go
Normal file
123
observability_judges_test.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/observability"
|
||||
)
|
||||
|
||||
func judgeJSON() map[string]any {
|
||||
return map[string]any{
|
||||
"id": "j1", "created_at": "t", "updated_at": "t",
|
||||
"owner_id": "o", "workspace_id": "w", "name": "quality",
|
||||
"description": "d", "model_name": "m", "instructions": "i",
|
||||
"output": map[string]any{"type": "CLASSIFICATION", "options": []any{}},
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateJudge_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" || r.URL.Path != "/v1/observability/judges" {
|
||||
t.Errorf("unexpected %s %s", r.Method, r.URL.Path)
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(judgeJSON())
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.CreateJudge(context.Background(), &observability.CreateJudgeRequest{
|
||||
Name: "quality",
|
||||
Description: "d",
|
||||
ModelName: "m",
|
||||
Instructions: "i",
|
||||
Tools: []string{},
|
||||
Output: json.RawMessage(`{"type":"CLASSIFICATION","options":[]}`),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ID != "j1" {
|
||||
t.Errorf("got id %q", resp.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListJudges_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"count": 1,
|
||||
"results": []any{judgeJSON()},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ListJudges(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Count != 1 {
|
||||
t.Errorf("got count %d", resp.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetJudge_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/observability/judges/j1" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(judgeJSON())
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.GetJudge(context.Background(), "j1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.Name != "quality" {
|
||||
t.Errorf("got name %q", resp.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateJudge_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "PUT" {
|
||||
t.Errorf("expected PUT, got %s", r.Method)
|
||||
}
|
||||
json.NewEncoder(w).Encode(judgeJSON())
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
_, err := client.UpdateJudge(context.Background(), "j1", &observability.UpdateJudgeRequest{
|
||||
Name: "quality",
|
||||
Description: "d",
|
||||
ModelName: "m",
|
||||
Instructions: "i",
|
||||
Tools: []string{},
|
||||
Output: json.RawMessage(`{"type":"CLASSIFICATION","options":[]}`),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteJudge_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "DELETE" {
|
||||
t.Errorf("expected DELETE")
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
if err := client.DeleteJudge(context.Background(), "j1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user