feat: add workflows CRUD and registration service methods
This commit is contained in:
168
workflows.go
Normal file
168
workflows.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/workflow"
|
||||
)
|
||||
|
||||
// ListWorkflows lists workflows.
|
||||
func (c *Client) ListWorkflows(ctx context.Context, params *workflow.WorkflowListParams) (*workflow.WorkflowListResponse, error) {
|
||||
path := "/v1/workflows"
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.ActiveOnly != nil {
|
||||
q.Set("active_only", strconv.FormatBool(*params.ActiveOnly))
|
||||
}
|
||||
if params.IncludeShared != nil {
|
||||
q.Set("include_shared", strconv.FormatBool(*params.IncludeShared))
|
||||
}
|
||||
if params.AvailableInChatAssistant != nil {
|
||||
q.Set("available_in_chat_assistant", strconv.FormatBool(*params.AvailableInChatAssistant))
|
||||
}
|
||||
if params.Archived != nil {
|
||||
q.Set("archived", strconv.FormatBool(*params.Archived))
|
||||
}
|
||||
if params.Cursor != nil {
|
||||
q.Set("cursor", *params.Cursor)
|
||||
}
|
||||
if params.Limit != nil {
|
||||
q.Set("limit", strconv.Itoa(*params.Limit))
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp workflow.WorkflowListResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetWorkflow retrieves a workflow by identifier.
|
||||
func (c *Client) GetWorkflow(ctx context.Context, workflowIdentifier string) (*workflow.Workflow, error) {
|
||||
var resp workflow.Workflow
|
||||
if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/workflows/%s", workflowIdentifier), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UpdateWorkflow updates a workflow.
|
||||
func (c *Client) UpdateWorkflow(ctx context.Context, workflowIdentifier string, req *workflow.WorkflowUpdateRequest) (*workflow.Workflow, error) {
|
||||
var resp workflow.Workflow
|
||||
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/workflows/%s", workflowIdentifier), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ArchiveWorkflow archives a workflow.
|
||||
func (c *Client) ArchiveWorkflow(ctx context.Context, workflowIdentifier string) (*workflow.WorkflowArchiveResponse, error) {
|
||||
var resp workflow.WorkflowArchiveResponse
|
||||
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/workflows/%s/archive", workflowIdentifier), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// UnarchiveWorkflow unarchives a workflow.
|
||||
func (c *Client) UnarchiveWorkflow(ctx context.Context, workflowIdentifier string) (*workflow.WorkflowArchiveResponse, error) {
|
||||
var resp workflow.WorkflowArchiveResponse
|
||||
if err := c.doJSON(ctx, "PUT", fmt.Sprintf("/v1/workflows/%s/unarchive", workflowIdentifier), nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ExecuteWorkflow executes a workflow.
|
||||
func (c *Client) ExecuteWorkflow(ctx context.Context, workflowIdentifier string, req *workflow.ExecutionRequest) (*workflow.ExecutionResponse, error) {
|
||||
var resp workflow.ExecutionResponse
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/%s/execute", workflowIdentifier), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ListWorkflowRegistrations lists workflow registrations.
|
||||
func (c *Client) ListWorkflowRegistrations(ctx context.Context, params *workflow.RegistrationListParams) (*workflow.RegistrationListResponse, error) {
|
||||
path := "/v1/workflows/registrations"
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.WorkflowID != nil {
|
||||
q.Set("workflow_id", *params.WorkflowID)
|
||||
}
|
||||
if params.TaskQueue != nil {
|
||||
q.Set("task_queue", *params.TaskQueue)
|
||||
}
|
||||
if params.ActiveOnly != nil {
|
||||
q.Set("active_only", strconv.FormatBool(*params.ActiveOnly))
|
||||
}
|
||||
if params.IncludeShared != nil {
|
||||
q.Set("include_shared", strconv.FormatBool(*params.IncludeShared))
|
||||
}
|
||||
if params.WorkflowSearch != nil {
|
||||
q.Set("workflow_search", *params.WorkflowSearch)
|
||||
}
|
||||
if params.Archived != nil {
|
||||
q.Set("archived", strconv.FormatBool(*params.Archived))
|
||||
}
|
||||
if params.WithWorkflow != nil {
|
||||
q.Set("with_workflow", strconv.FormatBool(*params.WithWorkflow))
|
||||
}
|
||||
if params.AvailableInChatAssistant != nil {
|
||||
q.Set("available_in_chat_assistant", strconv.FormatBool(*params.AvailableInChatAssistant))
|
||||
}
|
||||
if params.Limit != nil {
|
||||
q.Set("limit", strconv.Itoa(*params.Limit))
|
||||
}
|
||||
if params.Cursor != nil {
|
||||
q.Set("cursor", *params.Cursor)
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp workflow.RegistrationListResponse
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// GetWorkflowRegistration retrieves a workflow registration by ID.
|
||||
func (c *Client) GetWorkflowRegistration(ctx context.Context, registrationID string, params *workflow.RegistrationGetParams) (*workflow.Registration, error) {
|
||||
path := fmt.Sprintf("/v1/workflows/registrations/%s", registrationID)
|
||||
if params != nil {
|
||||
q := url.Values{}
|
||||
if params.WithWorkflow != nil {
|
||||
q.Set("with_workflow", strconv.FormatBool(*params.WithWorkflow))
|
||||
}
|
||||
if params.IncludeShared != nil {
|
||||
q.Set("include_shared", strconv.FormatBool(*params.IncludeShared))
|
||||
}
|
||||
if encoded := q.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
}
|
||||
var resp workflow.Registration
|
||||
if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ExecuteWorkflowRegistration executes a workflow via its registration.
|
||||
//
|
||||
// Deprecated: Use ExecuteWorkflow instead. This method will be removed in a future release.
|
||||
func (c *Client) ExecuteWorkflowRegistration(ctx context.Context, registrationID string, req *workflow.ExecutionRequest) (*workflow.ExecutionResponse, error) {
|
||||
var resp workflow.ExecutionResponse
|
||||
if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/registrations/%s/execute", registrationID), req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
190
workflows_test.go
Normal file
190
workflows_test.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/workflow"
|
||||
)
|
||||
|
||||
func TestListWorkflows_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
t.Errorf("got method %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/v1/workflows" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
if r.URL.Query().Get("limit") != "10" {
|
||||
t.Errorf("got limit %q", r.URL.Query().Get("limit"))
|
||||
}
|
||||
if r.URL.Query().Get("active_only") != "true" {
|
||||
t.Errorf("got active_only %q", r.URL.Query().Get("active_only"))
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"workflows": []map[string]any{
|
||||
{"id": "wf-1", "name": "my-flow", "owner_id": "u1", "workspace_id": "ws1", "created_at": "2026-01-01", "updated_at": "2026-01-01"},
|
||||
},
|
||||
"next_cursor": "cur-abc",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
active := true
|
||||
limit := 10
|
||||
resp, err := client.ListWorkflows(context.Background(), &workflow.WorkflowListParams{
|
||||
ActiveOnly: &active,
|
||||
Limit: &limit,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(resp.Workflows) != 1 {
|
||||
t.Fatalf("got %d workflows", len(resp.Workflows))
|
||||
}
|
||||
if resp.Workflows[0].ID != "wf-1" {
|
||||
t.Errorf("got id %q", resp.Workflows[0].ID)
|
||||
}
|
||||
if resp.NextCursor == nil || *resp.NextCursor != "cur-abc" {
|
||||
t.Errorf("got cursor %v", resp.NextCursor)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetWorkflow_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/workflows/wf-1" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "wf-1", "name": "my-flow", "owner_id": "u1", "workspace_id": "ws1",
|
||||
"created_at": "2026-01-01", "updated_at": "2026-01-01",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
wf, err := client.GetWorkflow(context.Background(), "wf-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if wf.Name != "my-flow" {
|
||||
t.Errorf("got name %q", wf.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateWorkflow_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "PUT" {
|
||||
t.Errorf("got method %s", r.Method)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["display_name"] != "New Name" {
|
||||
t.Errorf("got display_name %v", body["display_name"])
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"id": "wf-1", "name": "my-flow", "display_name": "New Name",
|
||||
"owner_id": "u1", "workspace_id": "ws1",
|
||||
"created_at": "2026-01-01", "updated_at": "2026-01-02",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
name := "New Name"
|
||||
wf, err := client.UpdateWorkflow(context.Background(), "wf-1", &workflow.WorkflowUpdateRequest{
|
||||
DisplayName: &name,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if wf.DisplayName == nil || *wf.DisplayName != "New Name" {
|
||||
t.Errorf("got display_name %v", wf.DisplayName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestArchiveWorkflow_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "PUT" {
|
||||
t.Errorf("got method %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/v1/workflows/wf-1/archive" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{"id": "wf-1", "archived": true})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ArchiveWorkflow(context.Background(), "wf-1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !resp.Archived {
|
||||
t.Error("expected archived=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteWorkflow_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
t.Errorf("got method %s", r.Method)
|
||||
}
|
||||
if r.URL.Path != "/v1/workflows/wf-1/execute" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
input, _ := body["input"].(map[string]any)
|
||||
if input["prompt"] != "hello" {
|
||||
t.Errorf("got input %v", body["input"])
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"workflow_name": "my-flow", "execution_id": "exec-1",
|
||||
"root_execution_id": "exec-1", "status": "RUNNING",
|
||||
"start_time": "2026-01-01T00:00:00Z",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ExecuteWorkflow(context.Background(), "wf-1", &workflow.ExecutionRequest{
|
||||
Input: map[string]any{"prompt": "hello"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp.ExecutionID != "exec-1" {
|
||||
t.Errorf("got execution_id %q", resp.ExecutionID)
|
||||
}
|
||||
if resp.Status != workflow.ExecutionRunning {
|
||||
t.Errorf("got status %q", resp.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListWorkflowRegistrations_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/workflows/registrations" {
|
||||
t.Errorf("got path %s", r.URL.Path)
|
||||
}
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"registrations": []map[string]any{
|
||||
{"id": "reg-1", "workflow_id": "wf-1", "task_queue": "default", "created_at": "2026-01-01", "updated_at": "2026-01-01"},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
resp, err := client.ListWorkflowRegistrations(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(resp.Registrations) != 1 {
|
||||
t.Fatalf("got %d registrations", len(resp.Registrations))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user