feat: Phase 2 streaming — SSE parser, Stream[T], ChatCompleteStream

Add streaming infrastructure:
- SSE line parser handling multi-line data, comments, [DONE] sentinel
- Generic Stream[T] pull-based iterator (no goroutines, no channel leaks)
- doStream() HTTP helper for streaming endpoints
- ChatCompleteStream() method
- 28 new tests: SSE edge cases, iterator behavior, httptest integration
This commit is contained in:
2026-03-05 19:33:07 +01:00
parent 9c85f64140
commit 9b453ca62a
7 changed files with 730 additions and 0 deletions

View File

@@ -14,3 +14,14 @@ func (c *Client) ChatComplete(ctx context.Context, req *chat.CompletionRequest)
}
return &resp, nil
}
// ChatCompleteStream issues a streaming chat completion call and hands back
// an iterator over the completion chunks as they arrive.
//
// SetStream(true) is called on req itself (not on a copy) before the request
// is sent. An error from the HTTP helper is returned unchanged together with
// a nil stream. On success the returned stream wraps the response body, so
// the caller must call Close() on it when finished.
func (c *Client) ChatCompleteStream(ctx context.Context, req *chat.CompletionRequest) (*Stream[chat.CompletionChunk], error) {
	req.SetStream(true)

	httpResp, err := c.doStream(ctx, "POST", "/v1/chat/completions", req)
	if err == nil {
		return newStream[chat.CompletionChunk](httpResp.Body), nil
	}
	return nil, err
}

229
chat_stream_test.go Normal file
View File

@@ -0,0 +1,229 @@
package mistral
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"somegit.dev/vikingowl/mistral-go-sdk/chat"
)
// TestChatCompleteStream_Success serves three chunks followed by the [DONE]
// sentinel from a test server and checks that the request body carries
// stream=true and that the client yields the three chunks in order.
func TestChatCompleteStream_Success(t *testing.T) {
	sent := []chat.CompletionChunk{
		{
			ID:    "chunk-1",
			Model: "mistral-small-latest",
			Choices: []chat.CompletionStreamChoice{{
				Index: 0,
				Delta: chat.DeltaMessage{Role: "assistant"},
			}},
		},
		{
			ID:    "chunk-2",
			Model: "mistral-small-latest",
			Choices: []chat.CompletionStreamChoice{{
				Index: 0,
				Delta: chat.DeltaMessage{Content: chat.TextContent("Hello")},
			}},
		},
		{
			ID:    "chunk-3",
			Model: "mistral-small-latest",
			Choices: []chat.CompletionStreamChoice{{
				Index: 0,
				Delta: chat.DeltaMessage{Content: chat.TextContent(" world!")},
			}},
		},
	}

	handler := func(w http.ResponseWriter, r *http.Request) {
		// The client is expected to have switched the request into streaming mode.
		var payload map[string]any
		json.NewDecoder(r.Body).Decode(&payload)
		if got := payload["stream"]; got != true {
			t.Errorf("expected stream=true, got %v", got)
		}

		w.Header().Set("Content-Type", "text/event-stream")
		flusher, _ := w.(http.Flusher)
		for i := range sent {
			encoded, _ := json.Marshal(sent[i])
			fmt.Fprintf(w, "data: %s\n\n", encoded)
			flusher.Flush()
		}
		fmt.Fprint(w, "data: [DONE]\n\n")
		flusher.Flush()
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	request := &chat.CompletionRequest{
		Model: "mistral-small-latest",
		Messages: []chat.Message{
			&chat.UserMessage{Content: chat.TextContent("Hi")},
		},
	}
	stream, err := NewClient("key", WithBaseURL(srv.URL)).ChatCompleteStream(context.Background(), request)
	if err != nil {
		t.Fatal(err)
	}
	defer stream.Close()

	var received []chat.CompletionChunk
	for stream.Next() {
		received = append(received, stream.Current())
	}
	if streamErr := stream.Err(); streamErr != nil {
		t.Fatal(streamErr)
	}
	if n := len(received); n != 3 {
		t.Fatalf("got %d chunks, want 3", n)
	}

	if received[0].Choices[0].Delta.Role != "assistant" {
		t.Errorf("expected first chunk role=assistant")
	}
	// Chunks 1 and 2 carry the text deltas.
	for i, want := range []string{"Hello", " world!"} {
		if got := received[i+1].Choices[0].Delta.Content.String(); got != want {
			t.Errorf("got %q", got)
		}
	}
}
// TestChatCompleteStream_CollectContent streams one chunk per word plus a
// final chunk carrying a finish reason and usage, then checks that the
// concatenated deltas form the full sentence and that the usage of the last
// chunk is exposed to the caller.
func TestChatCompleteStream_CollectContent(t *testing.T) {
	words := []string{"The", " quick", " brown", " fox"}

	handler := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		flusher, _ := w.(http.Flusher)

		for i := range words {
			piece := chat.CompletionChunk{
				ID:    "c",
				Model: "m",
				Choices: []chat.CompletionStreamChoice{{
					Index: 0,
					Delta: chat.DeltaMessage{Content: chat.TextContent(words[i])},
				}},
			}
			encoded, _ := json.Marshal(piece)
			fmt.Fprintf(w, "data: %s\n\n", encoded)
			flusher.Flush()
		}

		// Closing chunk: empty delta, finish reason and token usage.
		stop := "stop"
		last := chat.CompletionChunk{
			ID:    "c",
			Model: "m",
			Choices: []chat.CompletionStreamChoice{{
				Index:        0,
				Delta:        chat.DeltaMessage{},
				FinishReason: &stop,
			}},
			Usage: &chat.UsageInfo{PromptTokens: 5, CompletionTokens: 4, TotalTokens: 9},
		}
		encoded, _ := json.Marshal(last)
		fmt.Fprintf(w, "data: %s\n\n", encoded)
		fmt.Fprint(w, "data: [DONE]\n\n")
		flusher.Flush()
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	request := &chat.CompletionRequest{
		Model:    "m",
		Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
	}
	stream, err := NewClient("key", WithBaseURL(srv.URL)).ChatCompleteStream(context.Background(), request)
	if err != nil {
		t.Fatal(err)
	}
	defer stream.Close()

	var text strings.Builder
	var lastChunk chat.CompletionChunk
	for stream.Next() {
		lastChunk = stream.Current()
		if len(lastChunk.Choices) == 0 {
			continue
		}
		text.WriteString(lastChunk.Choices[0].Delta.Content.String())
	}
	if streamErr := stream.Err(); streamErr != nil {
		t.Fatal(streamErr)
	}

	if got := text.String(); got != "The quick brown fox" {
		t.Errorf("got %q", got)
	}
	usage := lastChunk.Usage
	if usage == nil {
		t.Fatal("expected usage in final chunk")
	}
	if usage.TotalTokens != 9 {
		t.Errorf("got total_tokens=%d", usage.TotalTokens)
	}
}
// TestChatCompleteStream_WithToolCalls streams a single chunk whose delta
// contains one tool call and checks that the call reaches the caller with
// its function name intact.
func TestChatCompleteStream_WithToolCalls(t *testing.T) {
	handler := func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		flusher, _ := w.(http.Flusher)

		call := chat.ToolCall{
			ID:       "call_1",
			Type:     "function",
			Function: chat.FunctionCall{Name: "get_weather", Arguments: `{"city":"Paris"}`},
		}
		payload := chat.CompletionChunk{
			ID:    "c",
			Model: "m",
			Choices: []chat.CompletionStreamChoice{{
				Index: 0,
				Delta: chat.DeltaMessage{
					ToolCalls: []chat.ToolCall{call},
				},
			}},
		}
		encoded, _ := json.Marshal(payload)
		fmt.Fprintf(w, "data: %s\n\n", encoded)
		fmt.Fprint(w, "data: [DONE]\n\n")
		flusher.Flush()
	}
	srv := httptest.NewServer(http.HandlerFunc(handler))
	defer srv.Close()

	request := &chat.CompletionRequest{
		Model:    "m",
		Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Weather?")}},
	}
	stream, err := NewClient("key", WithBaseURL(srv.URL)).ChatCompleteStream(context.Background(), request)
	if err != nil {
		t.Fatal(err)
	}
	defer stream.Close()

	if ok := stream.Next(); !ok {
		t.Fatalf("expected chunk, err: %v", stream.Err())
	}
	calls := stream.Current().Choices[0].Delta.ToolCalls
	if !(len(calls) == 1 && calls[0].Function.Name == "get_weather") {
		t.Errorf("got tool calls %+v", calls)
	}
}
func TestChatCompleteStream_APIError(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(map[string]any{
"message": "invalid key",
"type": "auth_error",
})
}))
defer server.Close()
client := NewClient("bad", WithBaseURL(server.URL))
_, err := client.ChatCompleteStream(context.Background(), &chat.CompletionRequest{
Model: "m",
Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
})
if err == nil {
t.Fatal("expected error")
}
if !IsAuth(err) {
t.Errorf("expected auth error, got: %v", err)
}
}

View File

@@ -51,6 +51,26 @@ func (c *Client) doJSON(ctx context.Context, method, path string, reqBody, respB
return nil
}
// doStream performs an HTTP call whose response body is left open for the
// caller to consume incrementally.
//
// A non-nil reqBody is JSON-encoded and handed to c.do as the request body;
// for a nil reqBody a nil reader is passed. A response with status code 400
// or above is converted into an error by parseAPIError and its body is closed
// before returning. Any other response is returned as-is, and closing
// resp.Body is then the caller's job.
func (c *Client) doStream(ctx context.Context, method, path string, reqBody any) (*http.Response, error) {
	var payload io.Reader
	if reqBody != nil {
		encoded, err := json.Marshal(reqBody)
		if err != nil {
			return nil, fmt.Errorf("mistral: marshal request: %w", err)
		}
		payload = bytes.NewReader(encoded)
	}

	resp, err := c.do(ctx, method, path, payload)
	switch {
	case err != nil:
		return nil, err
	case resp.StatusCode < 400:
		return resp, nil
	}

	// Error status: the body is only needed to build the error value.
	defer resp.Body.Close()
	return nil, parseAPIError(resp)
}
func parseAPIError(resp *http.Response) error {
body, err := io.ReadAll(resp.Body)
if err != nil {

75
sse.go Normal file
View File

@@ -0,0 +1,75 @@
package mistral
import (
"bufio"
"bytes"
"io"
)
// sseEvent represents a single Server-Sent Event.
type sseEvent struct {
Event string
Data []byte
}
// isDone returns true if this event signals end-of-stream.
func (e *sseEvent) isDone() bool {
return string(bytes.TrimSpace(e.Data)) == "[DONE]"
}
// sseReader reads Server-Sent Events from an io.Reader.
type sseReader struct {
scanner *bufio.Scanner
}
func newSSEReader(r io.Reader) *sseReader {
return &sseReader{scanner: bufio.NewScanner(r)}
}
// next reads the next SSE event. Returns nil, nil at EOF.
func (r *sseReader) next() (*sseEvent, error) {
var event sseEvent
var hasData bool
for r.scanner.Scan() {
line := r.scanner.Bytes()
// Blank line = end of event
if len(line) == 0 {
if hasData {
return &event, nil
}
continue
}
// Skip comments
if line[0] == ':' {
continue
}
field, value, _ := bytes.Cut(line, []byte(":"))
// Strip single leading space from value per SSE spec
value = bytes.TrimPrefix(value, []byte(" "))
switch string(field) {
case "event":
event.Event = string(value)
case "data":
if hasData {
event.Data = append(event.Data, '\n')
}
event.Data = append(event.Data, value...)
hasData = true
}
}
if err := r.scanner.Err(); err != nil {
return nil, err
}
// Final event without trailing blank line
if hasData {
return &event, nil
}
return nil, nil
}

182
sse_test.go Normal file
View File

@@ -0,0 +1,182 @@
package mistral
import (
"strings"
"testing"
)
// TestSSEReader_SingleEvent checks that one data line terminated by a blank
// line is returned as one event carrying that payload.
func TestSSEReader_SingleEvent(t *testing.T) {
	reader := newSSEReader(strings.NewReader("data: {\"id\":\"1\"}\n\n"))

	event, err := reader.next()
	switch {
	case err != nil:
		t.Fatal(err)
	case event == nil:
		t.Fatal("expected event")
	}
	if got := string(event.Data); got != `{"id":"1"}` {
		t.Errorf("got data %q", event.Data)
	}
}
// TestSSEReader_MultipleEvents checks that two consecutive events are
// returned in order and that the following read reports EOF as nil, nil.
func TestSSEReader_MultipleEvents(t *testing.T) {
	reader := newSSEReader(strings.NewReader("data: first\n\ndata: second\n\n"))

	for _, want := range []string{"first", "second"} {
		event, err := reader.next()
		if err != nil {
			t.Fatal(err)
		}
		if string(event.Data) != want {
			t.Errorf("got %q, want %q", event.Data, want)
		}
	}

	// The input is exhausted now.
	trailing, err := reader.next()
	if err != nil {
		t.Fatal(err)
	}
	if trailing != nil {
		t.Errorf("expected nil at EOF, got %+v", trailing)
	}
}
// TestSSEReader_MultiLineData checks that several data lines within one
// event are joined with newline characters.
func TestSSEReader_MultiLineData(t *testing.T) {
	const want = "line1\nline2\nline3"
	reader := newSSEReader(strings.NewReader("data: line1\ndata: line2\ndata: line3\n\n"))

	event, err := reader.next()
	if err != nil {
		t.Fatal(err)
	}
	if got := string(event.Data); got != want {
		t.Errorf("got %q, want %q", event.Data, want)
	}
}
// TestSSEReader_EventField checks that an "event" line is captured alongside
// the data of the same event.
func TestSSEReader_EventField(t *testing.T) {
	reader := newSSEReader(strings.NewReader("event: completion\ndata: {\"id\":\"1\"}\n\n"))

	event, err := reader.next()
	if err != nil {
		t.Fatal(err)
	}
	if name := event.Event; name != "completion" {
		t.Errorf("got event %q, want %q", name, "completion")
	}
	if payload := string(event.Data); payload != `{"id":"1"}` {
		t.Errorf("got data %q", event.Data)
	}
}
// TestSSEReader_SkipsComments checks that a comment line preceding the data
// does not affect the returned event.
func TestSSEReader_SkipsComments(t *testing.T) {
	const want = "hello"
	reader := newSSEReader(strings.NewReader(": this is a comment\ndata: hello\n\n"))

	event, err := reader.next()
	if err != nil {
		t.Fatal(err)
	}
	if got := string(event.Data); got != want {
		t.Errorf("got %q, want %q", event.Data, want)
	}
}
// TestSSEReader_Done checks that only the event carrying the [DONE] sentinel
// is reported as done.
func TestSSEReader_Done(t *testing.T) {
	reader := newSSEReader(strings.NewReader("data: {\"id\":\"1\"}\n\ndata: [DONE]\n\n"))

	payloadEvent, err := reader.next()
	if err != nil {
		t.Fatal(err)
	}
	if payloadEvent.isDone() {
		t.Error("first event should not be done")
	}

	sentinel, err := reader.next()
	if err != nil {
		t.Fatal(err)
	}
	if done := sentinel.isDone(); !done {
		t.Error("second event should be done")
	}
}
// TestSSEReader_DoneWithWhitespace checks that whitespace around the [DONE]
// sentinel does not prevent its detection.
func TestSSEReader_DoneWithWhitespace(t *testing.T) {
	padded := sseEvent{Data: []byte(" [DONE] ")}
	if done := padded.isDone(); !done {
		t.Error("should detect [DONE] with whitespace")
	}
}
// TestSSEReader_EmptyStream checks that reading from empty input yields
// neither an event nor an error.
func TestSSEReader_EmptyStream(t *testing.T) {
	event, err := newSSEReader(strings.NewReader("")).next()
	if err != nil {
		t.Fatal(err)
	}
	if event != nil {
		t.Errorf("expected nil for empty stream, got %+v", event)
	}
}
// TestSSEReader_OnlyComments checks that input consisting solely of comment
// lines produces no event.
func TestSSEReader_OnlyComments(t *testing.T) {
	reader := newSSEReader(strings.NewReader(": comment1\n: comment2\n\n"))

	event, err := reader.next()
	if err != nil {
		t.Fatal(err)
	}
	if event != nil {
		t.Errorf("expected nil, got %+v", event)
	}
}
// TestSSEReader_NoTrailingNewline checks that data at the very end of the
// input is still returned as an event when no blank line terminates it.
func TestSSEReader_NoTrailingNewline(t *testing.T) {
	const want = "hello"
	reader := newSSEReader(strings.NewReader("data: hello"))

	event, err := reader.next()
	switch {
	case err != nil:
		t.Fatal(err)
	case event == nil:
		t.Fatal("expected event for data without trailing blank line")
	}
	if got := string(event.Data); got != want {
		t.Errorf("got %q, want %q", event.Data, want)
	}
}
// TestSSEReader_DataNoSpace checks that a data field written without a space
// after the colon is parsed with its value unchanged.
func TestSSEReader_DataNoSpace(t *testing.T) {
	reader := newSSEReader(strings.NewReader("data:{\"compact\":true}\n\n"))

	event, err := reader.next()
	if err != nil {
		t.Fatal(err)
	}
	if got := string(event.Data); got != `{"compact":true}` {
		t.Errorf("got %q", event.Data)
	}
}
// TestSSEReader_MultipleBlankLines checks that extra blank lines between two
// events are skipped rather than producing empty events.
//
// Errors and nil events are checked explicitly so that a regression is
// reported as a test failure instead of a nil-pointer panic.
func TestSSEReader_MultipleBlankLines(t *testing.T) {
	input := "data: first\n\n\n\ndata: second\n\n"
	r := newSSEReader(strings.NewReader(input))

	for _, want := range []string{"first", "second"} {
		ev, err := r.next()
		if err != nil {
			t.Fatal(err)
		}
		if ev == nil {
			t.Fatalf("expected event with data %q, got nil", want)
		}
		if string(ev.Data) != want {
			t.Errorf("got %q", ev.Data)
		}
	}
}

72
stream.go Normal file
View File

@@ -0,0 +1,72 @@
package mistral
import (
"encoding/json"
"fmt"
"io"
)
// Stream is a generic pull-style iterator over the decoded events of a
// streaming API response.
//
// Typical use: loop on Next(), read each value with Current(), check Err()
// once the loop ends, and call Close() when done.
type Stream[T any] struct {
	reader  *sseReader // yields the raw SSE events
	closer  io.Closer  // closed by Close
	current T          // last value decoded by Next
	err     error      // first read or decode failure
	done    bool       // set at end of input or on the [DONE] sentinel
}

// newStream builds a Stream that parses SSE events from body and closes body
// when Close is called.
func newStream[T any](body io.ReadCloser) *Stream[T] {
	s := new(Stream[T])
	s.reader = newSSEReader(body)
	s.closer = body
	return s
}

// Next advances to the next event and reports whether a value is available
// through Current(). It returns false once the stream is exhausted (end of
// input or the [DONE] sentinel) or after a read or decode error; in the error
// case Err() returns the cause. Once it has returned false for either reason
// it keeps returning false.
func (s *Stream[T]) Next() bool {
	if s.err != nil || s.done {
		return false
	}

	ev, readErr := s.reader.next()
	switch {
	case readErr != nil:
		s.err = fmt.Errorf("mistral: read stream: %w", readErr)
		return false
	case ev == nil || ev.isDone():
		// End of input and the explicit sentinel both finish the stream cleanly.
		s.done = true
		return false
	}

	var decoded T
	if decodeErr := json.Unmarshal(ev.Data, &decoded); decodeErr != nil {
		s.err = fmt.Errorf("mistral: decode stream event: %w", decodeErr)
		return false
	}
	s.current = decoded
	return true
}

// Current returns the value decoded by the most recent successful Next().
// It is only meaningful after Next() has returned true.
func (s *Stream[T]) Current() T {
	return s.current
}

// Err returns the error that stopped iteration, or nil if none occurred.
func (s *Stream[T]) Err() error {
	return s.err
}

// Close closes the body the stream was created from and returns the result
// of that Close call.
func (s *Stream[T]) Close() error {
	return s.closer.Close()
}

141
stream_test.go Normal file
View File

@@ -0,0 +1,141 @@
package mistral
import (
"io"
"strings"
"testing"
)
// testChunk is the minimal JSON payload decoded by the stream tests.
type testChunk struct {
	ID      string `json:"id"`
	Content string `json:"content"`
}

// newTestStream returns a Stream[testChunk] that reads the given raw SSE text
// through a no-op closer.
func newTestStream(sse string) *Stream[testChunk] {
	return newStream[testChunk](io.NopCloser(strings.NewReader(sse)))
}
// TestStream_SingleChunk checks that one chunk is decoded, that the [DONE]
// sentinel ends iteration, and that no error is recorded.
func TestStream_SingleChunk(t *testing.T) {
	stream := newTestStream("data: {\"id\":\"1\",\"content\":\"hello\"}\n\ndata: [DONE]\n\n")
	defer stream.Close()

	if ok := stream.Next(); !ok {
		t.Fatalf("expected Next() to return true, err: %v", stream.Err())
	}
	if got := stream.Current(); got != (testChunk{ID: "1", Content: "hello"}) {
		t.Errorf("got %+v", got)
	}

	if more := stream.Next(); more {
		t.Error("expected Next() to return false after [DONE]")
	}
	if err := stream.Err(); err != nil {
		t.Errorf("unexpected error: %v", err)
	}
}
// TestStream_MultipleChunks checks that three chunks followed by [DONE] are
// yielded in order.
func TestStream_MultipleChunks(t *testing.T) {
	stream := newTestStream("data: {\"id\":\"1\",\"content\":\"a\"}\n\ndata: {\"id\":\"2\",\"content\":\"b\"}\n\ndata: {\"id\":\"3\",\"content\":\"c\"}\n\ndata: [DONE]\n\n")
	defer stream.Close()

	var got []testChunk
	for stream.Next() {
		got = append(got, stream.Current())
	}
	if err := stream.Err(); err != nil {
		t.Fatal(err)
	}
	if n := len(got); n != 3 {
		t.Fatalf("got %d chunks, want 3", n)
	}

	// Report all chunks once if any of them carries unexpected content.
	mismatch := false
	for i, want := range []string{"a", "b", "c"} {
		if got[i].Content != want {
			mismatch = true
			break
		}
	}
	if mismatch {
		t.Errorf("got %+v", got)
	}
}
// TestStream_EmptyStream checks that a stream consisting only of the [DONE]
// sentinel yields no chunk and no error.
func TestStream_EmptyStream(t *testing.T) {
	stream := newTestStream("data: [DONE]\n\n")
	defer stream.Close()

	if got := stream.Next(); got {
		t.Error("expected no chunks before [DONE]")
	}
	if err := stream.Err(); err != nil {
		t.Errorf("unexpected error: %v", err)
	}
}
// TestStream_InvalidJSON checks that an event whose data is not valid JSON
// stops iteration and is reported through Err().
func TestStream_InvalidJSON(t *testing.T) {
	stream := newTestStream("data: not-json\n\n")
	defer stream.Close()

	if advanced := stream.Next(); advanced {
		t.Error("expected Next() to return false for invalid JSON")
	}
	if err := stream.Err(); err == nil {
		t.Error("expected error for invalid JSON")
	}
}
// TestStream_NextAfterDone checks that Next() keeps returning false once the
// [DONE] sentinel has been reached.
func TestStream_NextAfterDone(t *testing.T) {
	stream := newTestStream("data: {\"id\":\"1\",\"content\":\"x\"}\n\ndata: [DONE]\n\n")
	defer stream.Close()

	// First call consumes the chunk, second call reaches [DONE].
	stream.Next()
	stream.Next()

	// A further call must still report false.
	if again := stream.Next(); again {
		t.Error("expected false after stream is done")
	}
}
// TestStream_NextAfterError checks that Next() keeps returning false after a
// decode error has occurred.
func TestStream_NextAfterError(t *testing.T) {
	stream := newTestStream("data: bad\n\n")
	defer stream.Close()

	// This call fails on the invalid payload.
	stream.Next()

	// A further call must still report false.
	if again := stream.Next(); again {
		t.Error("expected false after error")
	}
}
// TestStream_WithComments checks that comment lines interleaved with events
// do not disturb decoding or [DONE] detection.
func TestStream_WithComments(t *testing.T) {
	stream := newTestStream(": keep-alive\ndata: {\"id\":\"1\",\"content\":\"ok\"}\n\n: ping\ndata: [DONE]\n\n")
	defer stream.Close()

	if ok := stream.Next(); !ok {
		t.Fatalf("expected chunk, err: %v", stream.Err())
	}
	if content := stream.Current().Content; content != "ok" {
		t.Errorf("got %q", content)
	}
	if more := stream.Next(); more {
		t.Error("expected done after [DONE]")
	}
}
// TestStream_EOFWithoutDone checks that input ending without the [DONE]
// sentinel finishes iteration without recording an error.
func TestStream_EOFWithoutDone(t *testing.T) {
	stream := newTestStream("data: {\"id\":\"1\",\"content\":\"x\"}\n\n")
	defer stream.Close()

	if ok := stream.Next(); !ok {
		t.Fatalf("expected chunk, err: %v", stream.Err())
	}
	if more := stream.Next(); more {
		t.Error("expected false at EOF")
	}
	if err := stream.Err(); err != nil {
		t.Errorf("expected no error at clean EOF, got: %v", err)
	}
}