diff --git a/CHANGELOG.md b/CHANGELOG.md index 9abb7ed..fc52d95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,14 +1,61 @@ -## Unreleased +## v1.4.0 — 2026-04-28 + +Spec/SDK alignment pass after upstream OpenAPI moved to v1.0.0 and +Python SDK shipped v2.3.0..v2.4.3. RAG ingestion-pipeline beta surface +(Python v2.4.3) intentionally deferred until the dust settles upstream. + +### Added + +- **`Client.GetWorkflowWorkerInfo`** — restores the + `GET /v1/workflows/workers/whoami` endpoint that was removed in v1.3.0. + The endpoint is still in the spec and is needed by callers running + custom workers that connect their own scheduler. + (`workflow.WorkerInfo` type.) +- **Observability fields API** — three GETs missing since the + observability surface was first added: + - `Client.GetChatCompletionFields` (`/v1/observability/chat-completion-fields`) + - `Client.GetChatCompletionFieldOptions` (`…/{field}/options?operator=…`) + - `Client.GetChatCompletionFieldOptionsCounts` (`…/{field}/options-counts`) + - new types: `observability.BaseFieldDefinition`, `FieldGroup`, + `ChatCompletionFields`, `ChatCompletionFieldOptions`, + `FieldOptionCountsRequest`, `FieldOptionCounts`, `FieldOptionCountItem`, + plus `FieldType` and `FieldOperator` typed enums. +- **Workflow payload encoding constants** — `workflow.EncodedPayloadOption` + with `EncodedPayloadOffloaded`, `EncodedPayloadEncrypted`, + `EncodedPayloadEncryptedPartial`. Wire-compatible refinement of the + pre-existing `[]string` field on `NetworkEncodedInput`. + (Mirrors Python SDK v2.4.0.) +- **Workflow ↔ connector integration** (Python SDK v2.4.2): + - `workflow.ConnectorSlot`, `ConnectorBindings`, `ConnectorExtensions`, + `WorkflowExtensions` types. + - `workflow.BuildConnectorExtensions(slots …)` helper that produces the + nested map expected at `ExecutionRequest.Extensions["mistralai"]`. + - `workflow.ConnectorAuthTaskState` + `ConnectorAuthStatus` constants + for parsing payloads emitted by the `connector-auth` custom task event. + - New `Extensions map[string]any` field on `workflow.ExecutionRequest`. +- **HITL (human-in-the-loop) confirmation constants** — typed values + alongside the pre-existing `conversation.ToolCallConfirmation` and + `tool_confirmations` field: + - `conversation.Confirmation` with `ConfirmationAllow` / `ConfirmationDeny` + for the reply side. + - `ConfirmationStatusPending` / `ConfirmationStatusAllowed` / + `ConfirmationStatusDenied` for `FunctionCallEvent.ConfirmationStatus` + and `FunctionCallEntry.ConfirmationStatus` (already present as + untyped strings). ### Changed +- `workflow.NetworkEncodedInput.EncodingOptions` is now + `[]EncodedPayloadOption` (string-typed alias). JSON wire format + unchanged; existing call sites that passed `[]string{"offloaded"}` + need to switch to `[]workflow.EncodedPayloadOption{workflow.EncodedPayloadOffloaded}` + or the typed constants directly. - Tracking upstream Mistral OpenAPI spec **v1.0.0** (was v0.1.104). - No SDK surface change: the only spec delta in this window was the - removal of OCR confidence-score fields - (`OCRPageObject.confidence_scores`, `OCRRequest.confidence_scores_granularity`, + Only spec delta in this window was the removal of OCR confidence-score + fields (`OCRPageObject.confidence_scores`, + `OCRRequest.confidence_scores_granularity`, `OCRTableObject.word_confidence_scores`, plus the `OCRConfidenceScore` - and `OCRPageConfidenceScores` schemas), none of which were exposed by - this SDK. + and `OCRPageConfidenceScores` schemas), none of which this SDK exposed. ### Fixed (CI) diff --git a/README.md b/README.md index 5561c52..f4a9147 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ The most complete Go client for the [Mistral AI API](https://docs.mistral.ai/). **Zero dependencies.** The entire SDK — including tests — uses only the Go standard library. No `go.sum`, no transitive dependency tree to audit, no version conflicts, no supply chain risk. -**Full API coverage.** 165 methods across every Mistral endpoint — including Workflows, Connectors, Audio Speech/Voices, Conversations, Agents CRUD, Libraries, OCR, Observability, Fine-tuning, and Batch Jobs. No other Go SDK covers Workflows, Conversations, Connectors, or Observability. +**Full API coverage.** 169 methods across every Mistral endpoint — including Workflows (with worker introspection and connector bindings), Connectors, Audio Speech/Voices, Conversations (including human-in-the-loop tool confirmations), Agents CRUD, Libraries, OCR, Observability (events, judges, datasets, campaigns, fields), Fine-tuning, and Batch Jobs. No other Go SDK covers Workflows, Conversations, Connectors, or Observability. **Typed streaming.** A generic pull-based `Stream[T]` iterator — no channels, no goroutines, no leaks. Just `Next()` / `Current()` / `Err()` / `Close()`. @@ -19,7 +19,7 @@ The most complete Go client for the [Mistral AI API](https://docs.mistral.ai/). **Hand-written, not generated.** Idiomatic Go with sealed interfaces, discriminated unions, and functional options — not a Speakeasy/OpenAPI auto-gen dump with `any` everywhere. -**Test-driven.** 284 tests with race detection clean. Every endpoint tested against mock servers; integration tests against the real API. +**Test-driven.** 297 tests with race detection clean. Every endpoint tested against mock servers; integration tests against the real API. ## Install @@ -132,7 +132,7 @@ for stream.Next() { ## API Coverage -165 public methods on `Client`, grouped by domain: +169 public methods on `Client`, grouped by domain: | Domain | Methods | |--------|---------| @@ -156,6 +156,7 @@ for stream.Next() { | **Classification** | `Classify`, `ClassifyChat` | | **Observability (campaigns)** | `CreateCampaign`, `ListCampaigns`, `GetCampaign`, `DeleteCampaign`, `GetCampaignStatus`, `ListCampaignEvents` | | **Observability (events)** | `SearchChatCompletionEvents`, `SearchChatCompletionEventIDs`, `GetChatCompletionEvent`, `GetSimilarChatCompletionEvents`, `JudgeChatCompletionEvent` | +| **Observability (fields)** | `GetChatCompletionFields`, `GetChatCompletionFieldOptions`, `GetChatCompletionFieldOptionsCounts` | | **Observability (judges)** | `CreateJudge`, `ListJudges`, `GetJudge`, `UpdateJudge`, `DeleteJudge`, `JudgeConversation` | | **Observability (datasets)** | `CreateDataset`, `ListDatasets`, `GetDataset`, `UpdateDataset`, `DeleteDataset`, `ExportDatasetToJSONL`, `ListDatasetRecords`, `CreateDatasetRecord`, `GetDatasetRecord`, `UpdateDatasetRecordPayload`, `UpdateDatasetRecordProperties`, `DeleteDatasetRecord`, `BulkDeleteDatasetRecords`, `JudgeDatasetRecord`, `ImportDatasetFromCampaign`, `ImportDatasetFromExplorer`, `ImportDatasetFromFile`, `ImportDatasetFromPlayground`, `ImportDatasetFromDataset`, `ListDatasetTasks`, `GetDatasetTask` | | **Workflows (CRUD)** | `ListWorkflows`, `GetWorkflow`, `UpdateWorkflow`, `ArchiveWorkflow`, `UnarchiveWorkflow`, `ExecuteWorkflow`, `ExecuteWorkflowAndWait` | @@ -167,6 +168,7 @@ for stream.Next() { | **Workflows (metrics)** | `GetWorkflowMetrics` | | **Workflows (runs)** | `ListWorkflowRuns`, `GetWorkflowRun`, `GetWorkflowRunHistory` | | **Workflows (schedules)** | `ListWorkflowSchedules`, `ScheduleWorkflow`, `UnscheduleWorkflow` | +| **Workflows (workers)** | `GetWorkflowWorkerInfo` | ## Comparison @@ -239,13 +241,14 @@ This SDK tracks the [official Mistral OpenAPI spec](https://github.com/mistralai The [Mistral Python SDK](https://github.com/mistralai/client-python) is used as a secondary reference for implementation patterns. -| SDK Version | Upstream Python SDK | -|-------------|---------------------| -| v1.3.0 | v2.3.0 | -| v1.2.1 | v2.2.0 | -| v1.2.0 | v2.2.0 | -| v1.1.0 | v2.1.3 | -| v1.0.0 | v2.0.4 | +| SDK Version | Upstream Python SDK | Upstream OpenAPI | +|-------------|---------------------|------------------| +| v1.4.0 | v2.4.3 (excl. RAG ingestion-pipeline beta) | v1.0.0 | +| v1.3.0 | v2.3.0 | v0.1.104 | +| v1.2.1 | v2.2.0 | — | +| v1.2.0 | v2.2.0 | — | +| v1.1.0 | v2.1.3 | — | +| v1.0.0 | v2.0.4 | — | ## License diff --git a/conversation/conversation.go b/conversation/conversation.go index 985fad9..f934c32 100644 --- a/conversation/conversation.go +++ b/conversation/conversation.go @@ -44,10 +44,30 @@ type CompletionArgs struct { ToolChoice *chat.ToolChoiceMode `json:"tool_choice,omitempty"` } +// Confirmation is a client decision on a pending tool call. +type Confirmation string + +const ( + ConfirmationAllow Confirmation = "allow" + ConfirmationDeny Confirmation = "deny" +) + +// ConfirmationStatus values appear on FunctionCallEvent.ConfirmationStatus +// and FunctionCallEntry.ConfirmationStatus, reporting where in the +// human-in-the-loop flow a tool call currently sits. +const ( + ConfirmationStatusPending = "pending" + ConfirmationStatusAllowed = "allowed" + ConfirmationStatusDenied = "denied" +) + // ToolCallConfirmation confirms or denies a pending tool call. +// +// Send a slice of these on AppendRequest.ToolConfirmations after receiving +// a function call event whose ConfirmationStatus is "pending". type ToolCallConfirmation struct { ToolCallID string `json:"tool_call_id"` - Confirmation string `json:"confirmation"` // "allow" or "deny" + Confirmation string `json:"confirmation"` // use ConfirmationAllow / ConfirmationDeny } // Inputs represents conversation inputs (text string or entry array). diff --git a/observability/field.go b/observability/field.go new file mode 100644 index 0000000..785503a --- /dev/null +++ b/observability/field.go @@ -0,0 +1,84 @@ +package observability + +// FieldType identifies the data type of a chat-completion-event field. +type FieldType string + +const ( + FieldTypeEnum FieldType = "ENUM" + FieldTypeText FieldType = "TEXT" + FieldTypeInt FieldType = "INT" + FieldTypeFloat FieldType = "FLOAT" + FieldTypeBool FieldType = "BOOL" + FieldTypeTimestamp FieldType = "TIMESTAMP" + FieldTypeArray FieldType = "ARRAY" +) + +// FieldOperator is a filter operator supported on observability fields. +type FieldOperator string + +const ( + FieldOperatorLT FieldOperator = "lt" + FieldOperatorLTE FieldOperator = "lte" + FieldOperatorGT FieldOperator = "gt" + FieldOperatorGTE FieldOperator = "gte" + FieldOperatorStartsWith FieldOperator = "startswith" + FieldOperatorIStartsWith FieldOperator = "istartswith" + FieldOperatorEndsWith FieldOperator = "endswith" + FieldOperatorIEndsWith FieldOperator = "iendswith" + FieldOperatorContains FieldOperator = "contains" + FieldOperatorIContains FieldOperator = "icontains" + FieldOperatorMatches FieldOperator = "matches" + FieldOperatorNotContains FieldOperator = "notcontains" + FieldOperatorINotContain FieldOperator = "inotcontains" + FieldOperatorEq FieldOperator = "eq" + FieldOperatorNeq FieldOperator = "neq" + FieldOperatorIsNull FieldOperator = "isnull" + FieldOperatorIncludes FieldOperator = "includes" + FieldOperatorExcludes FieldOperator = "excludes" + FieldOperatorLenEq FieldOperator = "len_eq" +) + +// BaseFieldDefinition describes a searchable chat-completion-event field. +type BaseFieldDefinition struct { + Name string `json:"name"` + Label string `json:"label"` + Type FieldType `json:"type"` + Group *string `json:"group,omitempty"` + SupportedOperators []FieldOperator `json:"supported_operators"` +} + +// FieldGroup groups related field definitions for UI display. +type FieldGroup struct { + Name string `json:"name"` + Label string `json:"label"` +} + +// ChatCompletionFields is the response of GET /v1/observability/chat-completion-fields. +type ChatCompletionFields struct { + FieldDefinitions []BaseFieldDefinition `json:"field_definitions"` + FieldGroups []FieldGroup `json:"field_groups"` +} + +// ChatCompletionFieldOptions is the response of +// GET /v1/observability/chat-completion-fields/{field_name}/options. +// +// Each option may be a string, bool, or null — preserved as raw any. +type ChatCompletionFieldOptions struct { + Options []any `json:"options"` +} + +// FieldOptionCountsRequest is the body of POST options-counts. +type FieldOptionCountsRequest struct { + FilterParams *FilterPayload `json:"filter_params,omitempty"` +} + +// FieldOptionCountItem pairs a field value with how many events have it. +type FieldOptionCountItem struct { + Value string `json:"value"` + Count int `json:"count"` +} + +// FieldOptionCounts is the response of POST options-counts. +type FieldOptionCounts struct { + Counts []FieldOptionCountItem `json:"counts"` +} diff --git a/observability_fields.go b/observability_fields.go new file mode 100644 index 0000000..47cf041 --- /dev/null +++ b/observability_fields.go @@ -0,0 +1,46 @@ +package mistral + +import ( + "context" + "fmt" + "net/url" + + "github.com/VikingOwl91/mistral-go-sdk/observability" +) + +// GetChatCompletionFields returns the searchable field definitions and groups +// for chat-completion observability events. +func (c *Client) GetChatCompletionFields(ctx context.Context) (*observability.ChatCompletionFields, error) { + var resp observability.ChatCompletionFields + if err := c.doJSON(ctx, "GET", "/v1/observability/chat-completion-fields", nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// GetChatCompletionFieldOptions returns the distinct values seen for the given +// field, filtered by the requested operator. +func (c *Client) GetChatCompletionFieldOptions(ctx context.Context, fieldName string, operator observability.FieldOperator) (*observability.ChatCompletionFieldOptions, error) { + q := url.Values{} + q.Set("operator", string(operator)) + path := fmt.Sprintf("/v1/observability/chat-completion-fields/%s/options?%s", url.PathEscape(fieldName), q.Encode()) + var resp observability.ChatCompletionFieldOptions + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// GetChatCompletionFieldOptionsCounts returns per-value event counts for the +// given field, optionally filtered by the supplied filter payload. +func (c *Client) GetChatCompletionFieldOptionsCounts(ctx context.Context, fieldName string, req *observability.FieldOptionCountsRequest) (*observability.FieldOptionCounts, error) { + if req == nil { + req = &observability.FieldOptionCountsRequest{} + } + path := fmt.Sprintf("/v1/observability/chat-completion-fields/%s/options-counts", url.PathEscape(fieldName)) + var resp observability.FieldOptionCounts + if err := c.doJSON(ctx, "POST", path, req, &resp); err != nil { + return nil, err + } + return &resp, nil +} diff --git a/observability_fields_test.go b/observability_fields_test.go new file mode 100644 index 0000000..4225920 --- /dev/null +++ b/observability_fields_test.go @@ -0,0 +1,100 @@ +package mistral + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/VikingOwl91/mistral-go-sdk/observability" +) + +func TestGetChatCompletionFields_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" || r.URL.Path != "/v1/observability/chat-completion-fields" { + t.Errorf("unexpected %s %s", r.Method, r.URL.Path) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "field_definitions": []map[string]any{ + { + "name": "model", + "label": "Model", + "type": "ENUM", + "supported_operators": []string{"eq", "neq", "includes"}, + }, + }, + "field_groups": []map[string]any{ + {"name": "request", "label": "Request"}, + }, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.GetChatCompletionFields(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(resp.FieldDefinitions) != 1 { + t.Fatalf("got %d field definitions", len(resp.FieldDefinitions)) + } + def := resp.FieldDefinitions[0] + if def.Name != "model" || def.Type != observability.FieldTypeEnum { + t.Errorf("unexpected field def: %+v", def) + } + if len(def.SupportedOperators) != 3 { + t.Errorf("got %d operators", len(def.SupportedOperators)) + } + if len(resp.FieldGroups) != 1 || resp.FieldGroups[0].Name != "request" { + t.Errorf("unexpected groups: %+v", resp.FieldGroups) + } +} + +func TestGetChatCompletionFieldOptions_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/observability/chat-completion-fields/model/options" { + t.Errorf("got path %s", r.URL.Path) + } + if got := r.URL.Query().Get("operator"); got != "eq" { + t.Errorf("got operator=%q want eq", got) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "options": []any{"mistral-small-latest", "mistral-large-latest", nil, true}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.GetChatCompletionFieldOptions(context.Background(), "model", observability.FieldOperatorEq) + if err != nil { + t.Fatal(err) + } + if len(resp.Options) != 4 { + t.Fatalf("got %d options", len(resp.Options)) + } +} + +func TestGetChatCompletionFieldOptionsCounts_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" || r.URL.Path != "/v1/observability/chat-completion-fields/model/options-counts" { + t.Errorf("unexpected %s %s", r.Method, r.URL.Path) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "counts": []map[string]any{ + {"value": "mistral-small-latest", "count": 42}, + {"value": "mistral-large-latest", "count": 17}, + }, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.GetChatCompletionFieldOptionsCounts(context.Background(), "model", &observability.FieldOptionCountsRequest{}) + if err != nil { + t.Fatal(err) + } + if len(resp.Counts) != 2 || resp.Counts[0].Count != 42 { + t.Errorf("unexpected counts: %+v", resp.Counts) + } +} diff --git a/workflow/connectors.go b/workflow/connectors.go new file mode 100644 index 0000000..a9f831b --- /dev/null +++ b/workflow/connectors.go @@ -0,0 +1,61 @@ +package workflow + +// ConnectorSlot declares a connector dependency for a workflow execution. +// +// Pass a slice of slots to BuildConnectorExtensions to produce the +// nested map expected on ExecutionRequest.Extensions. +type ConnectorSlot struct { + ConnectorName string `json:"connector_name"` + CredentialsName *string `json:"credentials_name,omitempty"` +} + +// ConnectorBindings is the bindings list inside ConnectorExtensions. +type ConnectorBindings struct { + Bindings []ConnectorSlot `json:"bindings"` +} + +// ConnectorExtensions is the value of the "mistralai" key in workflow extensions. +type ConnectorExtensions struct { + Connectors ConnectorBindings `json:"connectors"` +} + +// WorkflowExtensions is the top-level shape of the extensions field +// expected by the workflow execute endpoint when binding connectors. +type WorkflowExtensions struct { + Mistralai ConnectorExtensions `json:"mistralai"` +} + +// BuildConnectorExtensions returns the value to set on +// ExecutionRequest.Extensions for the given connector slots. +// +// The result is a map[string]any so callers can merge in additional +// extension keys without colliding with the connector wire shape. +func BuildConnectorExtensions(slots ...ConnectorSlot) map[string]any { + return map[string]any{ + "mistralai": ConnectorExtensions{ + Connectors: ConnectorBindings{Bindings: slots}, + }, + } +} + +// ConnectorAuthStatus is the state of an OAuth flow emitted by a +// connector-auth custom task event. +type ConnectorAuthStatus string + +const ( + ConnectorAuthWaitingForAuth ConnectorAuthStatus = "waiting_for_auth" + ConnectorAuthConnected ConnectorAuthStatus = "connected" + ConnectorAuthAccessDenied ConnectorAuthStatus = "access_denied" + ConnectorAuthTimedOut ConnectorAuthStatus = "timed_out" + ConnectorAuthError ConnectorAuthStatus = "error" +) + +// ConnectorAuthTaskState is the payload of a custom task event of type +// "connector-auth", emitted while a workflow waits for OAuth completion. +type ConnectorAuthTaskState struct { + ConnectorName string `json:"connector_name"` + ConnectorID string `json:"connector_id"` + Status ConnectorAuthStatus `json:"status"` + AuthURL *string `json:"auth_url,omitempty"` + Message *string `json:"message,omitempty"` +} diff --git a/workflow/execution.go b/workflow/execution.go index ae20e76..8cee4d1 100644 --- a/workflow/execution.go +++ b/workflow/execution.go @@ -25,6 +25,9 @@ type ExecutionRequest struct { TimeoutSeconds *float64 `json:"timeout_seconds,omitempty"` CustomTracingAttributes map[string]string `json:"custom_tracing_attributes,omitempty"` DeploymentName *string `json:"deployment_name,omitempty"` + // Extensions carries plugin-specific data such as connector bindings. + // Use BuildConnectorExtensions to construct the standard connector shape. + Extensions map[string]any `json:"extensions,omitempty"` } // ExecutionResponse is the response from a workflow execution. @@ -40,11 +43,20 @@ type ExecutionResponse struct { TotalDurationMs *int `json:"total_duration_ms,omitempty"` } +// EncodedPayloadOption identifies how a workflow payload was encoded. +type EncodedPayloadOption string + +const ( + EncodedPayloadOffloaded EncodedPayloadOption = "offloaded" + EncodedPayloadEncrypted EncodedPayloadOption = "encrypted" + EncodedPayloadEncryptedPartial EncodedPayloadOption = "encrypted-partial" +) + // NetworkEncodedInput holds a base64-encoded payload for workflow input. type NetworkEncodedInput struct { - B64Payload string `json:"b64payload"` - EncodingOptions []string `json:"encoding_options,omitempty"` - Empty bool `json:"empty,omitempty"` + B64Payload string `json:"b64payload"` + EncodingOptions []EncodedPayloadOption `json:"encoding_options,omitempty"` + Empty bool `json:"empty,omitempty"` } // SignalInvocationBody is the request body for signaling a workflow execution. diff --git a/workflow/worker.go b/workflow/worker.go new file mode 100644 index 0000000..67a82e5 --- /dev/null +++ b/workflow/worker.go @@ -0,0 +1,12 @@ +package workflow + +// WorkerInfo describes the worker scheduler the SDK is connected to. +// +// Returned by GET /v1/workflows/workers/whoami. Useful when running custom +// workers that need to know which scheduler / namespace to connect to. +// For managed deployments, prefer Registration.DeploymentID. +type WorkerInfo struct { + SchedulerURL string `json:"scheduler_url"` + Namespace string `json:"namespace"` + TLS bool `json:"tls"` +} diff --git a/workflows_extensions_test.go b/workflows_extensions_test.go new file mode 100644 index 0000000..7180078 --- /dev/null +++ b/workflows_extensions_test.go @@ -0,0 +1,106 @@ +package mistral + +import ( + "encoding/json" + "testing" + + "github.com/VikingOwl91/mistral-go-sdk/conversation" + "github.com/VikingOwl91/mistral-go-sdk/workflow" +) + +func TestNetworkEncodedInput_EncodingOptions(t *testing.T) { + in := workflow.NetworkEncodedInput{ + B64Payload: "eyJrIjoidiJ9", + EncodingOptions: []workflow.EncodedPayloadOption{workflow.EncodedPayloadOffloaded, workflow.EncodedPayloadEncrypted}, + } + b, err := json.Marshal(in) + if err != nil { + t.Fatal(err) + } + var got map[string]any + if err := json.Unmarshal(b, &got); err != nil { + t.Fatal(err) + } + opts, ok := got["encoding_options"].([]any) + if !ok || len(opts) != 2 { + t.Fatalf("unexpected encoding_options: %v", got["encoding_options"]) + } + if opts[0] != "offloaded" || opts[1] != "encrypted" { + t.Errorf("got %v, want [offloaded encrypted]", opts) + } +} + +func TestBuildConnectorExtensions_WireShape(t *testing.T) { + creds := "work-account" + ext := workflow.BuildConnectorExtensions( + workflow.ConnectorSlot{ConnectorName: "gmail"}, + workflow.ConnectorSlot{ConnectorName: "notion", CredentialsName: &creds}, + ) + b, err := json.Marshal(ext) + if err != nil { + t.Fatal(err) + } + want := `{"mistralai":{"connectors":{"bindings":[{"connector_name":"gmail"},{"connector_name":"notion","credentials_name":"work-account"}]}}}` + if string(b) != want { + t.Errorf("\nwant %s\ngot %s", want, string(b)) + } +} + +func TestExecutionRequest_Extensions(t *testing.T) { + req := workflow.ExecutionRequest{ + Extensions: workflow.BuildConnectorExtensions(workflow.ConnectorSlot{ConnectorName: "gmail"}), + } + b, err := json.Marshal(req) + if err != nil { + t.Fatal(err) + } + var got map[string]any + if err := json.Unmarshal(b, &got); err != nil { + t.Fatal(err) + } + if _, ok := got["extensions"]; !ok { + t.Fatalf("expected extensions key in marshalled request: %s", string(b)) + } +} + +func TestConnectorAuthTaskState_Roundtrip(t *testing.T) { + authURL := "https://oauth.example.com/authorize" + in := workflow.ConnectorAuthTaskState{ + ConnectorName: "gmail", + ConnectorID: "conn-1", + Status: workflow.ConnectorAuthWaitingForAuth, + AuthURL: &authURL, + } + b, err := json.Marshal(in) + if err != nil { + t.Fatal(err) + } + var out workflow.ConnectorAuthTaskState + if err := json.Unmarshal(b, &out); err != nil { + t.Fatal(err) + } + if out.Status != workflow.ConnectorAuthWaitingForAuth { + t.Errorf("got status %q", out.Status) + } + if out.AuthURL == nil || *out.AuthURL != authURL { + t.Errorf("auth_url roundtrip failed") + } +} + +func TestConfirmationConstants_WireValues(t *testing.T) { + // Reply constants. + c := conversation.ToolCallConfirmation{ + ToolCallID: "call_1", + Confirmation: string(conversation.ConfirmationAllow), + } + b, _ := json.Marshal(c) + if string(b) != `{"tool_call_id":"call_1","confirmation":"allow"}` { + t.Errorf("got %s", string(b)) + } + // Inbound status constants. + if conversation.ConfirmationStatusPending != "pending" || + conversation.ConfirmationStatusAllowed != "allowed" || + conversation.ConfirmationStatusDenied != "denied" { + t.Errorf("unexpected confirmation status constants") + } +} diff --git a/workflows_workers.go b/workflows_workers.go new file mode 100644 index 0000000..ecae6ad --- /dev/null +++ b/workflows_workers.go @@ -0,0 +1,20 @@ +package mistral + +import ( + "context" + + "github.com/VikingOwl91/mistral-go-sdk/workflow" +) + +// GetWorkflowWorkerInfo returns the scheduler URL, namespace, and TLS setting +// the API expects custom workers to connect with. +// +// Most callers using managed deployments do not need this — see +// Registration.DeploymentID. It is exposed for users running custom workers. +func (c *Client) GetWorkflowWorkerInfo(ctx context.Context) (*workflow.WorkerInfo, error) { + var resp workflow.WorkerInfo + if err := c.doJSON(ctx, "GET", "/v1/workflows/workers/whoami", nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} diff --git a/workflows_workers_test.go b/workflows_workers_test.go new file mode 100644 index 0000000..a24e3f9 --- /dev/null +++ b/workflows_workers_test.go @@ -0,0 +1,52 @@ +package mistral + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestGetWorkflowWorkerInfo_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" || r.URL.Path != "/v1/workflows/workers/whoami" { + t.Errorf("unexpected %s %s", r.Method, r.URL.Path) + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "scheduler_url": "scheduler.example.com:7233", + "namespace": "tenant-2", + "tls": true, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + info, err := client.GetWorkflowWorkerInfo(context.Background()) + if err != nil { + t.Fatal(err) + } + if info.SchedulerURL != "scheduler.example.com:7233" || info.Namespace != "tenant-2" || !info.TLS { + t.Errorf("unexpected info: %+v", info) + } +} + +func TestGetWorkflowWorkerInfo_TLSDefault(t *testing.T) { + // Server omits the tls field; the SDK should default it to false. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "scheduler_url": "s", + "namespace": "n", + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + info, err := client.GetWorkflowWorkerInfo(context.Background()) + if err != nil { + t.Fatal(err) + } + if info.TLS { + t.Errorf("expected default tls=false, got true") + } +}