diff --git a/chat_complete.go b/chat_complete.go index 63aac5c..2f93134 100644 --- a/chat_complete.go +++ b/chat_complete.go @@ -14,3 +14,14 @@ func (c *Client) ChatComplete(ctx context.Context, req *chat.CompletionRequest) } return &resp, nil } + +// ChatCompleteStream sends a chat completion request and returns a stream +// of completion chunks. The caller must call Close() on the returned stream. +func (c *Client) ChatCompleteStream(ctx context.Context, req *chat.CompletionRequest) (*Stream[chat.CompletionChunk], error) { + req.SetStream(true) + resp, err := c.doStream(ctx, "POST", "/v1/chat/completions", req) + if err != nil { + return nil, err + } + return newStream[chat.CompletionChunk](resp.Body), nil +} diff --git a/chat_stream_test.go b/chat_stream_test.go new file mode 100644 index 0000000..1e3b5de --- /dev/null +++ b/chat_stream_test.go @@ -0,0 +1,229 @@ +package mistral + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "somegit.dev/vikingowl/mistral-go-sdk/chat" +) + +func TestChatCompleteStream_Success(t *testing.T) { + chunks := []chat.CompletionChunk{ + { + ID: "chunk-1", + Model: "mistral-small-latest", + Choices: []chat.CompletionStreamChoice{{ + Index: 0, + Delta: chat.DeltaMessage{Role: "assistant"}, + }}, + }, + { + ID: "chunk-2", + Model: "mistral-small-latest", + Choices: []chat.CompletionStreamChoice{{ + Index: 0, + Delta: chat.DeltaMessage{Content: chat.TextContent("Hello")}, + }}, + }, + { + ID: "chunk-3", + Model: "mistral-small-latest", + Choices: []chat.CompletionStreamChoice{{ + Index: 0, + Delta: chat.DeltaMessage{Content: chat.TextContent(" world!")}, + }}, + }, + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body map[string]any + json.NewDecoder(r.Body).Decode(&body) + if body["stream"] != true { + t.Errorf("expected stream=true, got %v", body["stream"]) + } + + w.Header().Set("Content-Type", "text/event-stream") + flusher, _ := w.(http.Flusher) + for _, chunk := range chunks { + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + } + fmt.Fprint(w, "data: [DONE]\n\n") + flusher.Flush() + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + stream, err := client.ChatCompleteStream(context.Background(), &chat.CompletionRequest{ + Model: "mistral-small-latest", + Messages: []chat.Message{ + &chat.UserMessage{Content: chat.TextContent("Hi")}, + }, + }) + if err != nil { + t.Fatal(err) + } + defer stream.Close() + + var received []chat.CompletionChunk + for stream.Next() { + received = append(received, stream.Current()) + } + if stream.Err() != nil { + t.Fatal(stream.Err()) + } + if len(received) != 3 { + t.Fatalf("got %d chunks, want 3", len(received)) + } + if received[0].Choices[0].Delta.Role != "assistant" { + t.Errorf("expected first chunk role=assistant") + } + if received[1].Choices[0].Delta.Content.String() != "Hello" { + t.Errorf("got %q", received[1].Choices[0].Delta.Content.String()) + } + if received[2].Choices[0].Delta.Content.String() != " world!" { + t.Errorf("got %q", received[2].Choices[0].Delta.Content.String()) + } +} + +func TestChatCompleteStream_CollectContent(t *testing.T) { + words := []string{"The", " quick", " brown", " fox"} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + flusher, _ := w.(http.Flusher) + for _, word := range words { + chunk := chat.CompletionChunk{ + ID: "c", + Model: "m", + Choices: []chat.CompletionStreamChoice{{ + Index: 0, + Delta: chat.DeltaMessage{Content: chat.TextContent(word)}, + }}, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + flusher.Flush() + } + stop := "stop" + final := chat.CompletionChunk{ + ID: "c", + Model: "m", + Choices: []chat.CompletionStreamChoice{{ + Index: 0, + Delta: chat.DeltaMessage{}, + FinishReason: &stop, + }}, + Usage: &chat.UsageInfo{PromptTokens: 5, CompletionTokens: 4, TotalTokens: 9}, + } + data, _ := json.Marshal(final) + fmt.Fprintf(w, "data: %s\n\n", data) + fmt.Fprint(w, "data: [DONE]\n\n") + flusher.Flush() + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + stream, err := client.ChatCompleteStream(context.Background(), &chat.CompletionRequest{ + Model: "m", + Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}}, + }) + if err != nil { + t.Fatal(err) + } + defer stream.Close() + + var sb strings.Builder + var lastChunk chat.CompletionChunk + for stream.Next() { + lastChunk = stream.Current() + if len(lastChunk.Choices) > 0 { + sb.WriteString(lastChunk.Choices[0].Delta.Content.String()) + } + } + if stream.Err() != nil { + t.Fatal(stream.Err()) + } + if sb.String() != "The quick brown fox" { + t.Errorf("got %q", sb.String()) + } + if lastChunk.Usage == nil { + t.Fatal("expected usage in final chunk") + } + if lastChunk.Usage.TotalTokens != 9 { + t.Errorf("got total_tokens=%d", lastChunk.Usage.TotalTokens) + } +} + +func TestChatCompleteStream_WithToolCalls(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + flusher, _ := w.(http.Flusher) + + chunk := chat.CompletionChunk{ + ID: "c", + Model: "m", + Choices: []chat.CompletionStreamChoice{{ + Index: 0, + Delta: chat.DeltaMessage{ + ToolCalls: []chat.ToolCall{{ + ID: "call_1", + Type: "function", + Function: chat.FunctionCall{Name: "get_weather", Arguments: `{"city":"Paris"}`}, + }}, + }, + }}, + } + data, _ := json.Marshal(chunk) + fmt.Fprintf(w, "data: %s\n\n", data) + fmt.Fprint(w, "data: [DONE]\n\n") + flusher.Flush() + })) + defer server.Close() + + client := NewClient("key", WithBaseURL(server.URL)) + stream, err := client.ChatCompleteStream(context.Background(), &chat.CompletionRequest{ + Model: "m", + Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Weather?")}}, + }) + if err != nil { + t.Fatal(err) + } + defer stream.Close() + + if !stream.Next() { + t.Fatalf("expected chunk, err: %v", stream.Err()) + } + tc := stream.Current().Choices[0].Delta.ToolCalls + if len(tc) != 1 || tc[0].Function.Name != "get_weather" { + t.Errorf("got tool calls %+v", tc) + } +} + +func TestChatCompleteStream_APIError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(map[string]any{ + "message": "invalid key", + "type": "auth_error", + }) + })) + defer server.Close() + + client := NewClient("bad", WithBaseURL(server.URL)) + _, err := client.ChatCompleteStream(context.Background(), &chat.CompletionRequest{ + Model: "m", + Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}}, + }) + if err == nil { + t.Fatal("expected error") + } + if !IsAuth(err) { + t.Errorf("expected auth error, got: %v", err) + } +} diff --git a/request.go b/request.go index 2f73aa4..71acf0f 100644 --- a/request.go +++ b/request.go @@ -51,6 +51,26 @@ func (c *Client) doJSON(ctx context.Context, method, path string, reqBody, respB return nil } +func (c *Client) doStream(ctx context.Context, method, path string, reqBody any) (*http.Response, error) { + var body io.Reader + if reqBody != nil { + data, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("mistral: marshal request: %w", err) + } + body = bytes.NewReader(data) + } + resp, err := c.do(ctx, method, path, body) + if err != nil { + return nil, err + } + if resp.StatusCode >= 400 { + defer resp.Body.Close() + return nil, parseAPIError(resp) + } + return resp, nil +} + func parseAPIError(resp *http.Response) error { body, err := io.ReadAll(resp.Body) if err != nil { diff --git a/sse.go b/sse.go new file mode 100644 index 0000000..8f48553 --- /dev/null +++ b/sse.go @@ -0,0 +1,75 @@ +package mistral + +import ( + "bufio" + "bytes" + "io" +) + +// sseEvent represents a single Server-Sent Event. +type sseEvent struct { + Event string + Data []byte +} + +// isDone returns true if this event signals end-of-stream. +func (e *sseEvent) isDone() bool { + return string(bytes.TrimSpace(e.Data)) == "[DONE]" +} + +// sseReader reads Server-Sent Events from an io.Reader. +type sseReader struct { + scanner *bufio.Scanner +} + +func newSSEReader(r io.Reader) *sseReader { + return &sseReader{scanner: bufio.NewScanner(r)} +} + +// next reads the next SSE event. Returns nil, nil at EOF. +func (r *sseReader) next() (*sseEvent, error) { + var event sseEvent + var hasData bool + + for r.scanner.Scan() { + line := r.scanner.Bytes() + + // Blank line = end of event + if len(line) == 0 { + if hasData { + return &event, nil + } + continue + } + + // Skip comments + if line[0] == ':' { + continue + } + + field, value, _ := bytes.Cut(line, []byte(":")) + // Strip single leading space from value per SSE spec + value = bytes.TrimPrefix(value, []byte(" ")) + + switch string(field) { + case "event": + event.Event = string(value) + case "data": + if hasData { + event.Data = append(event.Data, '\n') + } + event.Data = append(event.Data, value...) + hasData = true + } + } + + if err := r.scanner.Err(); err != nil { + return nil, err + } + + // Final event without trailing blank line + if hasData { + return &event, nil + } + return nil, nil +} diff --git a/sse_test.go b/sse_test.go new file mode 100644 index 0000000..4168c76 --- /dev/null +++ b/sse_test.go @@ -0,0 +1,182 @@ +package mistral + +import ( + "strings" + "testing" +) + +func TestSSEReader_SingleEvent(t *testing.T) { + input := "data: {\"id\":\"1\"}\n\n" + r := newSSEReader(strings.NewReader(input)) + ev, err := r.next() + if err != nil { + t.Fatal(err) + } + if ev == nil { + t.Fatal("expected event") + } + if string(ev.Data) != `{"id":"1"}` { + t.Errorf("got data %q", ev.Data) + } +} + +func TestSSEReader_MultipleEvents(t *testing.T) { + input := "data: first\n\ndata: second\n\n" + r := newSSEReader(strings.NewReader(input)) + + ev1, err := r.next() + if err != nil { + t.Fatal(err) + } + if string(ev1.Data) != "first" { + t.Errorf("got %q, want %q", ev1.Data, "first") + } + + ev2, err := r.next() + if err != nil { + t.Fatal(err) + } + if string(ev2.Data) != "second" { + t.Errorf("got %q, want %q", ev2.Data, "second") + } + + ev3, err := r.next() + if err != nil { + t.Fatal(err) + } + if ev3 != nil { + t.Errorf("expected nil at EOF, got %+v", ev3) + } +} + +func TestSSEReader_MultiLineData(t *testing.T) { + input := "data: line1\ndata: line2\ndata: line3\n\n" + r := newSSEReader(strings.NewReader(input)) + ev, err := r.next() + if err != nil { + t.Fatal(err) + } + want := "line1\nline2\nline3" + if string(ev.Data) != want { + t.Errorf("got %q, want %q", ev.Data, want) + } +} + +func TestSSEReader_EventField(t *testing.T) { + input := "event: completion\ndata: {\"id\":\"1\"}\n\n" + r := newSSEReader(strings.NewReader(input)) + ev, err := r.next() + if err != nil { + t.Fatal(err) + } + if ev.Event != "completion" { + t.Errorf("got event %q, want %q", ev.Event, "completion") + } + if string(ev.Data) != `{"id":"1"}` { + t.Errorf("got data %q", ev.Data) + } +} + +func TestSSEReader_SkipsComments(t *testing.T) { + input := ": this is a comment\ndata: hello\n\n" + r := newSSEReader(strings.NewReader(input)) + ev, err := r.next() + if err != nil { + t.Fatal(err) + } + if string(ev.Data) != "hello" { + t.Errorf("got %q, want %q", ev.Data, "hello") + } +} + +func TestSSEReader_Done(t *testing.T) { + input := "data: {\"id\":\"1\"}\n\ndata: [DONE]\n\n" + r := newSSEReader(strings.NewReader(input)) + + ev1, err := r.next() + if err != nil { + t.Fatal(err) + } + if ev1.isDone() { + t.Error("first event should not be done") + } + + ev2, err := r.next() + if err != nil { + t.Fatal(err) + } + if !ev2.isDone() { + t.Error("second event should be done") + } +} + +func TestSSEReader_DoneWithWhitespace(t *testing.T) { + ev := &sseEvent{Data: []byte(" [DONE] ")} + if !ev.isDone() { + t.Error("should detect [DONE] with whitespace") + } +} + +func TestSSEReader_EmptyStream(t *testing.T) { + r := newSSEReader(strings.NewReader("")) + ev, err := r.next() + if err != nil { + t.Fatal(err) + } + if ev != nil { + t.Errorf("expected nil for empty stream, got %+v", ev) + } +} + +func TestSSEReader_OnlyComments(t *testing.T) { + input := ": comment1\n: comment2\n\n" + r := newSSEReader(strings.NewReader(input)) + ev, err := r.next() + if err != nil { + t.Fatal(err) + } + if ev != nil { + t.Errorf("expected nil, got %+v", ev) + } +} + +func TestSSEReader_NoTrailingNewline(t *testing.T) { + input := "data: hello" + r := newSSEReader(strings.NewReader(input)) + ev, err := r.next() + if err != nil { + t.Fatal(err) + } + if ev == nil { + t.Fatal("expected event for data without trailing blank line") + } + if string(ev.Data) != "hello" { + t.Errorf("got %q, want %q", ev.Data, "hello") + } +} + +func TestSSEReader_DataNoSpace(t *testing.T) { + input := "data:{\"compact\":true}\n\n" + r := newSSEReader(strings.NewReader(input)) + ev, err := r.next() + if err != nil { + t.Fatal(err) + } + if string(ev.Data) != `{"compact":true}` { + t.Errorf("got %q", ev.Data) + } +} + +func TestSSEReader_MultipleBlankLines(t *testing.T) { + input := "data: first\n\n\n\ndata: second\n\n" + r := newSSEReader(strings.NewReader(input)) + + ev1, _ := r.next() + if string(ev1.Data) != "first" { + t.Errorf("got %q", ev1.Data) + } + ev2, _ := r.next() + if string(ev2.Data) != "second" { + t.Errorf("got %q", ev2.Data) + } +} diff --git a/stream.go b/stream.go new file mode 100644 index 0000000..97ab649 --- /dev/null +++ b/stream.go @@ -0,0 +1,72 @@ +package mistral + +import ( + "encoding/json" + "fmt" + "io" +) + +// Stream is a generic iterator for streaming API responses. +// Use Next() to advance, Current() to read the value, Err() for errors, +// and Close() when done. +type Stream[T any] struct { + reader *sseReader + closer io.Closer + current T + err error + done bool +} + +func newStream[T any](body io.ReadCloser) *Stream[T] { + return &Stream[T]{ + reader: newSSEReader(body), + closer: body, + } +} + +// Next advances to the next event. Returns false when the stream +// is exhausted or an error occurs. +func (s *Stream[T]) Next() bool { + if s.done || s.err != nil { + return false + } + for { + event, err := s.reader.next() + if err != nil { + s.err = fmt.Errorf("mistral: read stream: %w", err) + return false + } + if event == nil { + s.done = true + return false + } + if event.isDone() { + s.done = true + return false + } + + var v T + if err := json.Unmarshal(event.Data, &v); err != nil { + s.err = fmt.Errorf("mistral: decode stream event: %w", err) + return false + } + s.current = v + return true + } +} + +// Current returns the most recently read value. +// Only valid after Next() returns true. +func (s *Stream[T]) Current() T { + return s.current +} + +// Err returns any error encountered during streaming. +func (s *Stream[T]) Err() error { + return s.err +} + +// Close releases the underlying HTTP response body. +func (s *Stream[T]) Close() error { + return s.closer.Close() +} diff --git a/stream_test.go b/stream_test.go new file mode 100644 index 0000000..21fc5fc --- /dev/null +++ b/stream_test.go @@ -0,0 +1,141 @@ +package mistral + +import ( + "io" + "strings" + "testing" +) + +type testChunk struct { + ID string `json:"id"` + Content string `json:"content"` +} + +func newTestStream(sse string) *Stream[testChunk] { + body := io.NopCloser(strings.NewReader(sse)) + return newStream[testChunk](body) +} + +func TestStream_SingleChunk(t *testing.T) { + input := "data: {\"id\":\"1\",\"content\":\"hello\"}\n\ndata: [DONE]\n\n" + s := newTestStream(input) + defer s.Close() + + if !s.Next() { + t.Fatalf("expected Next() to return true, err: %v", s.Err()) + } + chunk := s.Current() + if chunk.ID != "1" || chunk.Content != "hello" { + t.Errorf("got %+v", chunk) + } + if s.Next() { + t.Error("expected Next() to return false after [DONE]") + } + if s.Err() != nil { + t.Errorf("unexpected error: %v", s.Err()) + } +} + +func TestStream_MultipleChunks(t *testing.T) { + input := "data: {\"id\":\"1\",\"content\":\"a\"}\n\ndata: {\"id\":\"2\",\"content\":\"b\"}\n\ndata: {\"id\":\"3\",\"content\":\"c\"}\n\ndata: [DONE]\n\n" + s := newTestStream(input) + defer s.Close() + + var chunks []testChunk + for s.Next() { + chunks = append(chunks, s.Current()) + } + if s.Err() != nil { + t.Fatal(s.Err()) + } + if len(chunks) != 3 { + t.Fatalf("got %d chunks, want 3", len(chunks)) + } + if chunks[0].Content != "a" || chunks[1].Content != "b" || chunks[2].Content != "c" { + t.Errorf("got %+v", chunks) + } +} + +func TestStream_EmptyStream(t *testing.T) { + s := newTestStream("data: [DONE]\n\n") + defer s.Close() + + if s.Next() { + t.Error("expected no chunks before [DONE]") + } + if s.Err() != nil { + t.Errorf("unexpected error: %v", s.Err()) + } +} + +func TestStream_InvalidJSON(t *testing.T) { + input := "data: not-json\n\n" + s := newTestStream(input) + defer s.Close() + + if s.Next() { + t.Error("expected Next() to return false for invalid JSON") + } + if s.Err() == nil { + t.Error("expected error for invalid JSON") + } +} + +func TestStream_NextAfterDone(t *testing.T) { + input := "data: {\"id\":\"1\",\"content\":\"x\"}\n\ndata: [DONE]\n\n" + s := newTestStream(input) + defer s.Close() + + s.Next() // consume first chunk + s.Next() // hits [DONE] + + // Calling Next() again should still return false + if s.Next() { + t.Error("expected false after stream is done") + } +} + +func TestStream_NextAfterError(t *testing.T) { + input := "data: bad\n\n" + s := newTestStream(input) + defer s.Close() + + s.Next() // triggers error + + // Calling Next() again should still return false + if s.Next() { + t.Error("expected false after error") + } +} + +func TestStream_WithComments(t *testing.T) { + input := ": keep-alive\ndata: {\"id\":\"1\",\"content\":\"ok\"}\n\n: ping\ndata: [DONE]\n\n" + s := newTestStream(input) + defer s.Close() + + if !s.Next() { + t.Fatalf("expected chunk, err: %v", s.Err()) + } + if s.Current().Content != "ok" { + t.Errorf("got %q", s.Current().Content) + } + if s.Next() { + t.Error("expected done after [DONE]") + } +} + +func TestStream_EOFWithoutDone(t *testing.T) { + input := "data: {\"id\":\"1\",\"content\":\"x\"}\n\n" + s := newTestStream(input) + defer s.Close() + + if !s.Next() { + t.Fatalf("expected chunk, err: %v", s.Err()) + } + if s.Next() { + t.Error("expected false at EOF") + } + if s.Err() != nil { + t.Errorf("expected no error at clean EOF, got: %v", s.Err()) + } +}