diff --git a/workflows.go b/workflows.go new file mode 100644 index 0000000..5d08d1d --- /dev/null +++ b/workflows.go @@ -0,0 +1,168 @@ +package mistral + +import ( + "context" + "fmt" + "net/url" + "strconv" + + "somegit.dev/vikingowl/mistral-go-sdk/workflow" +) + +// ListWorkflows lists workflows. +func (c *Client) ListWorkflows(ctx context.Context, params *workflow.WorkflowListParams) (*workflow.WorkflowListResponse, error) { + path := "/v1/workflows" + if params != nil { + q := url.Values{} + if params.ActiveOnly != nil { + q.Set("active_only", strconv.FormatBool(*params.ActiveOnly)) + } + if params.IncludeShared != nil { + q.Set("include_shared", strconv.FormatBool(*params.IncludeShared)) + } + if params.AvailableInChatAssistant != nil { + q.Set("available_in_chat_assistant", strconv.FormatBool(*params.AvailableInChatAssistant)) + } + if params.Archived != nil { + q.Set("archived", strconv.FormatBool(*params.Archived)) + } + if params.Cursor != nil { + q.Set("cursor", *params.Cursor) + } + if params.Limit != nil { + q.Set("limit", strconv.Itoa(*params.Limit)) + } + if encoded := q.Encode(); encoded != "" { + path += "?" + encoded + } + } + var resp workflow.WorkflowListResponse + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// GetWorkflow retrieves a workflow by identifier. +func (c *Client) GetWorkflow(ctx context.Context, workflowIdentifier string) (*workflow.Workflow, error) { + var resp workflow.Workflow + if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/workflows/%s", workflowIdentifier), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// UpdateWorkflow updates a workflow. +func (c *Client) UpdateWorkflow(ctx context.Context, workflowIdentifier string, req *workflow.WorkflowUpdateRequest) (*workflow.Workflow, error) { + var resp workflow.Workflow + if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/workflows/%s", workflowIdentifier), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// ArchiveWorkflow archives a workflow. +func (c *Client) ArchiveWorkflow(ctx context.Context, workflowIdentifier string) (*workflow.WorkflowArchiveResponse, error) { + var resp workflow.WorkflowArchiveResponse + if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/workflows/%s/archive", workflowIdentifier), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// UnarchiveWorkflow unarchives a workflow. +func (c *Client) UnarchiveWorkflow(ctx context.Context, workflowIdentifier string) (*workflow.WorkflowArchiveResponse, error) { + var resp workflow.WorkflowArchiveResponse + if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/workflows/%s/unarchive", workflowIdentifier), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// ExecuteWorkflow executes a workflow. +func (c *Client) ExecuteWorkflow(ctx context.Context, workflowIdentifier string, req *workflow.ExecutionRequest) (*workflow.ExecutionResponse, error) { + var resp workflow.ExecutionResponse + if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/%s/execute", workflowIdentifier), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// ListWorkflowRegistrations lists workflow registrations. +func (c *Client) ListWorkflowRegistrations(ctx context.Context, params *workflow.RegistrationListParams) (*workflow.RegistrationListResponse, error) { + path := "/v1/workflows/registrations" + if params != nil { + q := url.Values{} + if params.WorkflowID != nil { + q.Set("workflow_id", *params.WorkflowID) + } + if params.TaskQueue != nil { + q.Set("task_queue", *params.TaskQueue) + } + if params.ActiveOnly != nil { + q.Set("active_only", strconv.FormatBool(*params.ActiveOnly)) + } + if params.IncludeShared != nil { + q.Set("include_shared", strconv.FormatBool(*params.IncludeShared)) + } + if params.WorkflowSearch != nil { + q.Set("workflow_search", *params.WorkflowSearch) + } + if params.Archived != nil { + q.Set("archived", strconv.FormatBool(*params.Archived)) + } + if params.WithWorkflow != nil { + q.Set("with_workflow", strconv.FormatBool(*params.WithWorkflow)) + } + if params.AvailableInChatAssistant != nil { + q.Set("available_in_chat_assistant", strconv.FormatBool(*params.AvailableInChatAssistant)) + } + if params.Limit != nil { + q.Set("limit", strconv.Itoa(*params.Limit)) + } + if params.Cursor != nil { + q.Set("cursor", *params.Cursor) + } + if encoded := q.Encode(); encoded != "" { + path += "?" + encoded + } + } + var resp workflow.RegistrationListResponse + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// GetWorkflowRegistration retrieves a workflow registration by ID. +func (c *Client) GetWorkflowRegistration(ctx context.Context, registrationID string, params *workflow.RegistrationGetParams) (*workflow.Registration, error) { + path := fmt.Sprintf("/v1/workflows/registrations/%s", registrationID) + if params != nil { + q := url.Values{} + if params.WithWorkflow != nil { + q.Set("with_workflow", strconv.FormatBool(*params.WithWorkflow)) + } + if params.IncludeShared != nil { + q.Set("include_shared", strconv.FormatBool(*params.IncludeShared)) + } + if encoded := q.Encode(); encoded != "" { + path += "?" + encoded + } + } + var resp workflow.Registration + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// ExecuteWorkflowRegistration executes a workflow via its registration. +// +// Deprecated: Use ExecuteWorkflow instead. This method will be removed in a future release. +func (c *Client) ExecuteWorkflowRegistration(ctx context.Context, registrationID string, req *workflow.ExecutionRequest) (*workflow.ExecutionResponse, error) { + var resp workflow.ExecutionResponse + if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/registrations/%s/execute", registrationID), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} diff --git a/workflows_test.go b/workflows_test.go new file mode 100644 index 0000000..b52aa44 --- /dev/null +++ b/workflows_test.go @@ -0,0 +1,190 @@ +package mistral + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "somegit.dev/vikingowl/mistral-go-sdk/workflow" +) + +func TestListWorkflows_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + t.Errorf("got method %s", r.Method) + } + if r.URL.Path != "/v1/workflows" { + t.Errorf("got path %s", r.URL.Path) + } + if r.URL.Query().Get("limit") != "10" { + t.Errorf("got limit %q", r.URL.Query().Get("limit")) + } + if r.URL.Query().Get("active_only") != "true" { + t.Errorf("got active_only %q", r.URL.Query().Get("active_only")) + } + json.NewEncoder(w).Encode(map[string]any{ + "workflows": []map[string]any{ + {"id": "wf-1", "name": "my-flow", "owner_id": "u1", "workspace_id": "ws1", "created_at": "2026-01-01", "updated_at": "2026-01-01"}, + }, + "next_cursor": "cur-abc", + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + active := true + limit := 10 + resp, err := client.ListWorkflows(context.Background(), &workflow.WorkflowListParams{ + ActiveOnly: &active, + Limit: &limit, + }) + if err != nil { + t.Fatal(err) + } + if len(resp.Workflows) != 1 { + t.Fatalf("got %d workflows", len(resp.Workflows)) + } + if resp.Workflows[0].ID != "wf-1" { + t.Errorf("got id %q", resp.Workflows[0].ID) + } + if resp.NextCursor == nil || *resp.NextCursor != "cur-abc" { + t.Errorf("got cursor %v", resp.NextCursor) + } +} + +func TestGetWorkflow_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/workflows/wf-1" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "id": "wf-1", "name": "my-flow", "owner_id": "u1", "workspace_id": "ws1", + "created_at": "2026-01-01", "updated_at": "2026-01-01", + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + wf, err := client.GetWorkflow(context.Background(), "wf-1") + if err != nil { + t.Fatal(err) + } + if wf.Name != "my-flow" { + t.Errorf("got name %q", wf.Name) + } +} + +func TestUpdateWorkflow_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "PUT" { + t.Errorf("got method %s", r.Method) + } + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["display_name"] != "New Name" { + t.Errorf("got display_name %v", body["display_name"]) + } + json.NewEncoder(w).Encode(map[string]any{ + "id": "wf-1", "name": "my-flow", "display_name": "New Name", + "owner_id": "u1", "workspace_id": "ws1", + "created_at": "2026-01-01", "updated_at": "2026-01-02", + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + name := "New Name" + wf, err := client.UpdateWorkflow(context.Background(), "wf-1", &workflow.WorkflowUpdateRequest{ + DisplayName: &name, + }) + if err != nil { + t.Fatal(err) + } + if wf.DisplayName == nil || *wf.DisplayName != "New Name" { + t.Errorf("got display_name %v", wf.DisplayName) + } +} + +func TestArchiveWorkflow_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "PUT" { + t.Errorf("got method %s", r.Method) + } + if r.URL.Path != "/v1/workflows/wf-1/archive" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{"id": "wf-1", "archived": true}) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.ArchiveWorkflow(context.Background(), "wf-1") + if err != nil { + t.Fatal(err) + } + if !resp.Archived { + t.Error("expected archived=true") + } +} + +func TestExecuteWorkflow_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("got method %s", r.Method) + } + if r.URL.Path != "/v1/workflows/wf-1/execute" { + t.Errorf("got path %s", r.URL.Path) + } + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + input, _ := body["input"].(map[string]any) + if input["prompt"] != "hello" { + t.Errorf("got input %v", body["input"]) + } + json.NewEncoder(w).Encode(map[string]any{ + "workflow_name": "my-flow", "execution_id": "exec-1", + "root_execution_id": "exec-1", "status": "RUNNING", + "start_time": "2026-01-01T00:00:00Z", + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.ExecuteWorkflow(context.Background(), "wf-1", &workflow.ExecutionRequest{ + Input: map[string]any{"prompt": "hello"}, + }) + if err != nil { + t.Fatal(err) + } + if resp.ExecutionID != "exec-1" { + t.Errorf("got execution_id %q", resp.ExecutionID) + } + if resp.Status != workflow.ExecutionRunning { + t.Errorf("got status %q", resp.Status) + } +} + +func TestListWorkflowRegistrations_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/workflows/registrations" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "registrations": []map[string]any{ + {"id": "reg-1", "workflow_id": "wf-1", "task_queue": "default", "created_at": "2026-01-01", "updated_at": "2026-01-01"}, + }, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.ListWorkflowRegistrations(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + if len(resp.Registrations) != 1 { + t.Fatalf("got %d registrations", len(resp.Registrations)) + } +}