diff --git a/workflow/event.go b/workflow/event.go index 1594b07..c3793eb 100644 --- a/workflow/event.go +++ b/workflow/event.go @@ -325,7 +325,10 @@ func UnmarshalEvent(data []byte) (Event, error) { // StreamPayload is a single SSE payload from the workflow event stream. type StreamPayload struct { - Data json.RawMessage `json:"data"` + Stream string `json:"stream"` + Data json.RawMessage `json:"data"` + WorkflowContext StreamWorkflowContext `json:"workflow_context"` + BrokerSequence int64 `json:"broker_sequence"` } // StreamWorkflowContext holds context for a workflow event stream. diff --git a/workflows_executions.go b/workflows_executions.go new file mode 100644 index 0000000..f4eb980 --- /dev/null +++ b/workflows_executions.go @@ -0,0 +1,255 @@ +package mistral + +import ( + "context" + "encoding/json" + "fmt" + "net/url" + "strconv" + "time" + + "somegit.dev/vikingowl/mistral-go-sdk/workflow" +) + +// GetWorkflowExecution retrieves a workflow execution by ID. +func (c *Client) GetWorkflowExecution(ctx context.Context, executionID string) (*workflow.ExecutionResponse, error) { + var resp workflow.ExecutionResponse + if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/workflows/executions/%s", executionID), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// GetWorkflowExecutionHistory retrieves the history of a workflow execution. +func (c *Client) GetWorkflowExecutionHistory(ctx context.Context, executionID string, decodePayloads *bool) (json.RawMessage, error) { + path := fmt.Sprintf("/v1/workflows/executions/%s/history", executionID) + if decodePayloads != nil { + path += "?decode_payloads=" + strconv.FormatBool(*decodePayloads) + } + var resp json.RawMessage + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + return resp, nil +} + +// StreamWorkflowExecution streams events for a workflow execution via SSE. +func (c *Client) StreamWorkflowExecution(ctx context.Context, executionID string, params *workflow.StreamParams) (*WorkflowEventStream, error) { + path := fmt.Sprintf("/v1/workflows/executions/%s/stream", executionID) + if params != nil { + q := url.Values{} + if params.EventSource != nil { + q.Set("event_source", string(*params.EventSource)) + } + if params.LastEventID != nil { + q.Set("last_event_id", *params.LastEventID) + } + if encoded := q.Encode(); encoded != "" { + path += "?" + encoded + } + } + resp, err := c.doStream(ctx, "GET", path, nil) + if err != nil { + return nil, err + } + return newWorkflowEventStream(resp.Body), nil +} + +// SignalWorkflowExecution sends a signal to a workflow execution. +func (c *Client) SignalWorkflowExecution(ctx context.Context, executionID string, req *workflow.SignalInvocationBody) (*workflow.SignalResponse, error) { + var resp workflow.SignalResponse + if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/signals", executionID), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// QueryWorkflowExecution queries a workflow execution. +func (c *Client) QueryWorkflowExecution(ctx context.Context, executionID string, req *workflow.QueryInvocationBody) (*workflow.QueryResponse, error) { + var resp workflow.QueryResponse + if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/queries", executionID), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// UpdateWorkflowExecution sends an update to a workflow execution. +func (c *Client) UpdateWorkflowExecution(ctx context.Context, executionID string, req *workflow.UpdateInvocationBody) (*workflow.UpdateResponse, error) { + var resp workflow.UpdateResponse + if err := c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/updates", executionID), req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// TerminateWorkflowExecution terminates a workflow execution. +func (c *Client) TerminateWorkflowExecution(ctx context.Context, executionID string) error { + resp, err := c.do(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/terminate", executionID), nil) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode >= 400 { + return parseAPIError(resp) + } + return nil +} + +// CancelWorkflowExecution cancels a workflow execution. +func (c *Client) CancelWorkflowExecution(ctx context.Context, executionID string) error { + resp, err := c.do(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/cancel", executionID), nil) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode >= 400 { + return parseAPIError(resp) + } + return nil +} + +// ResetWorkflowExecution resets a workflow execution to a specific event. +func (c *Client) ResetWorkflowExecution(ctx context.Context, executionID string, req *workflow.ResetInvocationBody) error { + return c.doJSON(ctx, "POST", fmt.Sprintf("/v1/workflows/executions/%s/reset", executionID), req, nil) +} + +// BatchCancelWorkflowExecutions cancels multiple workflow executions. +func (c *Client) BatchCancelWorkflowExecutions(ctx context.Context, req *workflow.BatchExecutionBody) (*workflow.BatchExecutionResponse, error) { + var resp workflow.BatchExecutionResponse + if err := c.doJSON(ctx, "POST", "/v1/workflows/executions/cancel", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// BatchTerminateWorkflowExecutions terminates multiple workflow executions. +func (c *Client) BatchTerminateWorkflowExecutions(ctx context.Context, req *workflow.BatchExecutionBody) (*workflow.BatchExecutionResponse, error) { + var resp workflow.BatchExecutionResponse + if err := c.doJSON(ctx, "POST", "/v1/workflows/executions/terminate", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// GetWorkflowExecutionTraceOTel retrieves the OpenTelemetry trace for a workflow execution. +func (c *Client) GetWorkflowExecutionTraceOTel(ctx context.Context, executionID string) (*workflow.TraceOTelResponse, error) { + var resp workflow.TraceOTelResponse + if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/workflows/executions/%s/trace/otel", executionID), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// GetWorkflowExecutionTraceSummary retrieves the trace summary for a workflow execution. +func (c *Client) GetWorkflowExecutionTraceSummary(ctx context.Context, executionID string) (*workflow.TraceSummaryResponse, error) { + var resp workflow.TraceSummaryResponse + if err := c.doJSON(ctx, "GET", fmt.Sprintf("/v1/workflows/executions/%s/trace/summary", executionID), nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// GetWorkflowExecutionTraceEvents retrieves the trace events for a workflow execution. +func (c *Client) GetWorkflowExecutionTraceEvents(ctx context.Context, executionID string, params *workflow.TraceEventsParams) (*workflow.TraceEventsResponse, error) { + path := fmt.Sprintf("/v1/workflows/executions/%s/trace/events", executionID) + if params != nil { + q := url.Values{} + if params.MergeSameIDEvents != nil { + q.Set("merge_same_id_events", strconv.FormatBool(*params.MergeSameIDEvents)) + } + if params.IncludeInternalEvents != nil { + q.Set("include_internal_events", strconv.FormatBool(*params.IncludeInternalEvents)) + } + if encoded := q.Encode(); encoded != "" { + path += "?" + encoded + } + } + var resp workflow.TraceEventsResponse + if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// WorkflowEventStream wraps the generic Stream to provide typed workflow events +// with StreamPayload envelope metadata. +type WorkflowEventStream struct { + stream *Stream[json.RawMessage] + event workflow.Event + payload *workflow.StreamPayload + err error +} + +func newWorkflowEventStream(body readCloser) *WorkflowEventStream { + return &WorkflowEventStream{ + stream: newStream[json.RawMessage](body), + } +} + +// Next advances to the next event. Returns false when done or on error. +func (s *WorkflowEventStream) Next() bool { + if s.err != nil { + return false + } + if !s.stream.Next() { + s.err = s.stream.Err() + return false + } + var payload workflow.StreamPayload + if err := json.Unmarshal(s.stream.Current(), &payload); err != nil { + s.err = fmt.Errorf("mistral: decode workflow stream payload: %w", err) + return false + } + event, err := workflow.UnmarshalEvent(payload.Data) + if err != nil { + s.err = err + return false + } + s.event = event + s.payload = &payload + return true +} + +// Current returns the most recently read workflow event. +func (s *WorkflowEventStream) Current() workflow.Event { return s.event } + +// CurrentPayload returns the full StreamPayload envelope of the current event. +func (s *WorkflowEventStream) CurrentPayload() *workflow.StreamPayload { return s.payload } + +// Err returns any error encountered during streaming. +func (s *WorkflowEventStream) Err() error { return s.err } + +// Close releases the underlying connection. +func (s *WorkflowEventStream) Close() error { return s.stream.Close() } + +// ExecuteWorkflowAndWait executes a workflow and polls until completion. +func (c *Client) ExecuteWorkflowAndWait(ctx context.Context, workflowIdentifier string, req *workflow.ExecutionRequest) (*workflow.ExecutionResponse, error) { + execResp, err := c.ExecuteWorkflow(ctx, workflowIdentifier, req) + if err != nil { + return nil, err + } + for { + if isTerminal(execResp.Status) { + return execResp, nil + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(500 * time.Millisecond): + } + execResp, err = c.GetWorkflowExecution(ctx, execResp.ExecutionID) + if err != nil { + return nil, err + } + } +} + +func isTerminal(s workflow.ExecutionStatus) bool { + switch s { + case workflow.ExecutionCompleted, workflow.ExecutionFailed, + workflow.ExecutionCanceled, workflow.ExecutionTerminated, + workflow.ExecutionTimedOut: + return true + } + return false +} diff --git a/workflows_executions_test.go b/workflows_executions_test.go new file mode 100644 index 0000000..5bbdb6f --- /dev/null +++ b/workflows_executions_test.go @@ -0,0 +1,266 @@ +package mistral + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "somegit.dev/vikingowl/mistral-go-sdk/workflow" +) + +func TestGetWorkflowExecution_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/workflows/executions/exec-1" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "workflow_name": "my-flow", "execution_id": "exec-1", + "root_execution_id": "exec-1", "status": "COMPLETED", + "start_time": "2026-01-01T00:00:00Z", + "end_time": "2026-01-01T00:01:00Z", + "result": map[string]any{"answer": 42}, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.GetWorkflowExecution(context.Background(), "exec-1") + if err != nil { + t.Fatal(err) + } + if resp.Status != workflow.ExecutionCompleted { + t.Errorf("got status %q", resp.Status) + } +} + +func TestSignalWorkflowExecution_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("got method %s", r.Method) + } + if r.URL.Path != "/v1/workflows/executions/exec-1/signals" { + t.Errorf("got path %s", r.URL.Path) + } + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["name"] != "approval" { + t.Errorf("got name %v", body["name"]) + } + w.WriteHeader(http.StatusAccepted) + json.NewEncoder(w).Encode(map[string]any{"message": "Signal accepted"}) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.SignalWorkflowExecution(context.Background(), "exec-1", &workflow.SignalInvocationBody{ + Name: "approval", + Input: map[string]any{"approved": true}, + }) + if err != nil { + t.Fatal(err) + } + if resp.Message != "Signal accepted" { + t.Errorf("got message %q", resp.Message) + } +} + +func TestTerminateWorkflowExecution_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("got method %s", r.Method) + } + if r.URL.Path != "/v1/workflows/executions/exec-1/terminate" { + t.Errorf("got path %s", r.URL.Path) + } + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + err := client.TerminateWorkflowExecution(context.Background(), "exec-1") + if err != nil { + t.Fatal(err) + } +} + +func TestBatchCancelWorkflowExecutions_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("got method %s", r.Method) + } + if r.URL.Path != "/v1/workflows/executions/cancel" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "results": map[string]any{ + "exec-1": map[string]any{"status": "success"}, + "exec-2": map[string]any{"status": "failure", "error": "not found"}, + }, + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.BatchCancelWorkflowExecutions(context.Background(), &workflow.BatchExecutionBody{ + ExecutionIDs: []string{"exec-1", "exec-2"}, + }) + if err != nil { + t.Fatal(err) + } + if resp.Results["exec-1"].Status != "success" { + t.Errorf("got exec-1 status %q", resp.Results["exec-1"].Status) + } + if resp.Results["exec-2"].Error == nil || *resp.Results["exec-2"].Error != "not found" { + t.Errorf("expected exec-2 error") + } +} + +func TestStreamWorkflowExecution_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + t.Errorf("got method %s", r.Method) + } + if r.URL.Path != "/v1/workflows/executions/exec-1/stream" { + t.Errorf("got path %s", r.URL.Path) + } + w.Header().Set("Content-Type", "text/event-stream") + flusher, _ := w.(http.Flusher) + + payloads := []map[string]any{ + { + "stream": "events", + "data": map[string]any{ + "event_id": "evt-1", "event_timestamp": 1711929600000000000, + "root_workflow_exec_id": "exec-1", "parent_workflow_exec_id": nil, + "workflow_exec_id": "exec-1", "workflow_run_id": "run-1", + "workflow_name": "my-flow", "event_type": "WORKFLOW_EXECUTION_STARTED", + "attributes": map[string]any{}, + }, + "workflow_context": map[string]any{ + "namespace": "default", "workflow_name": "my-flow", "workflow_exec_id": "exec-1", + }, + "broker_sequence": 1, + }, + { + "stream": "events", + "data": map[string]any{ + "event_id": "evt-2", "event_timestamp": 1711929601000000000, + "root_workflow_exec_id": "exec-1", "parent_workflow_exec_id": nil, + "workflow_exec_id": "exec-1", "workflow_run_id": "run-1", + "workflow_name": "my-flow", "event_type": "WORKFLOW_EXECUTION_COMPLETED", + "attributes": map[string]any{"result": map[string]any{"value": 42, "type": "json"}}, + }, + "workflow_context": map[string]any{ + "namespace": "default", "workflow_name": "my-flow", "workflow_exec_id": "exec-1", + }, + "broker_sequence": 2, + }, + } + for _, p := range payloads { + data, _ := json.Marshal(p) + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + } + fmt.Fprint(w, "data: [DONE]\n\n") + flusher.Flush() + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + stream, err := client.StreamWorkflowExecution(context.Background(), "exec-1", nil) + if err != nil { + t.Fatal(err) + } + defer stream.Close() + + var events []workflow.Event + var lastPayload *workflow.StreamPayload + for stream.Next() { + events = append(events, stream.Current()) + lastPayload = stream.CurrentPayload() + } + if stream.Err() != nil { + t.Fatal(stream.Err()) + } + if len(events) != 2 { + t.Fatalf("got %d events, want 2", len(events)) + } + if _, ok := events[0].(*workflow.WorkflowExecutionStartedEvent); !ok { + t.Errorf("expected *WorkflowExecutionStartedEvent, got %T", events[0]) + } + if _, ok := events[1].(*workflow.WorkflowExecutionCompletedEvent); !ok { + t.Errorf("expected *WorkflowExecutionCompletedEvent, got %T", events[1]) + } + if lastPayload.WorkflowContext.WorkflowName != "my-flow" { + t.Errorf("got workflow context name %q", lastPayload.WorkflowContext.WorkflowName) + } +} + +func TestGetWorkflowExecutionTraceOTel_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/workflows/executions/exec-1/trace/otel" { + t.Errorf("got path %s", r.URL.Path) + } + json.NewEncoder(w).Encode(map[string]any{ + "workflow_name": "my-flow", "execution_id": "exec-1", + "root_execution_id": "exec-1", "status": "COMPLETED", + "start_time": "2026-01-01T00:00:00Z", "data_source": "temporal", + }) + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.GetWorkflowExecutionTraceOTel(context.Background(), "exec-1") + if err != nil { + t.Fatal(err) + } + if resp.DataSource != "temporal" { + t.Errorf("got data_source %q", resp.DataSource) + } +} + +func TestExecuteWorkflowAndWait_Success(t *testing.T) { + calls := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == "POST" && r.URL.Path == "/v1/workflows/wf-1/execute": + json.NewEncoder(w).Encode(map[string]any{ + "workflow_name": "my-flow", "execution_id": "exec-1", + "root_execution_id": "exec-1", "status": "RUNNING", + "start_time": "2026-01-01T00:00:00Z", + }) + case r.Method == "GET" && r.URL.Path == "/v1/workflows/executions/exec-1": + calls++ + status := "RUNNING" + if calls >= 2 { + status = "COMPLETED" + } + resp := map[string]any{ + "workflow_name": "my-flow", "execution_id": "exec-1", + "root_execution_id": "exec-1", "status": status, + "start_time": "2026-01-01T00:00:00Z", + } + if status == "COMPLETED" { + resp["result"] = map[string]any{"answer": 42} + } + json.NewEncoder(w).Encode(resp) + default: + t.Errorf("unexpected %s %s", r.Method, r.URL.Path) + } + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + resp, err := client.ExecuteWorkflowAndWait(context.Background(), "wf-1", &workflow.ExecutionRequest{ + Input: map[string]any{"prompt": "hello"}, + }) + if err != nil { + t.Fatal(err) + } + if resp.Status != workflow.ExecutionCompleted { + t.Errorf("got status %q", resp.Status) + } +}