feat: Phase 2 streaming — SSE parser, Stream[T], ChatCompleteStream
Add streaming infrastructure: - SSE line parser handling multi-line data, comments, [DONE] sentinel - Generic Stream[T] pull-based iterator (no goroutines, no channel leaks) - doStream() HTTP helper for streaming endpoints - ChatCompleteStream() method - 28 new tests: SSE edge cases, iterator behavior, httptest integration
This commit is contained in:
@@ -14,3 +14,14 @@ func (c *Client) ChatComplete(ctx context.Context, req *chat.CompletionRequest)
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// ChatCompleteStream sends a chat completion request and returns a stream
|
||||
// of completion chunks. The caller must call Close() on the returned stream.
|
||||
func (c *Client) ChatCompleteStream(ctx context.Context, req *chat.CompletionRequest) (*Stream[chat.CompletionChunk], error) {
|
||||
req.SetStream(true)
|
||||
resp, err := c.doStream(ctx, "POST", "/v1/chat/completions", req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newStream[chat.CompletionChunk](resp.Body), nil
|
||||
}
|
||||
|
||||
229
chat_stream_test.go
Normal file
229
chat_stream_test.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"somegit.dev/vikingowl/mistral-go-sdk/chat"
|
||||
)
|
||||
|
||||
func TestChatCompleteStream_Success(t *testing.T) {
|
||||
chunks := []chat.CompletionChunk{
|
||||
{
|
||||
ID: "chunk-1",
|
||||
Model: "mistral-small-latest",
|
||||
Choices: []chat.CompletionStreamChoice{{
|
||||
Index: 0,
|
||||
Delta: chat.DeltaMessage{Role: "assistant"},
|
||||
}},
|
||||
},
|
||||
{
|
||||
ID: "chunk-2",
|
||||
Model: "mistral-small-latest",
|
||||
Choices: []chat.CompletionStreamChoice{{
|
||||
Index: 0,
|
||||
Delta: chat.DeltaMessage{Content: chat.TextContent("Hello")},
|
||||
}},
|
||||
},
|
||||
{
|
||||
ID: "chunk-3",
|
||||
Model: "mistral-small-latest",
|
||||
Choices: []chat.CompletionStreamChoice{{
|
||||
Index: 0,
|
||||
Delta: chat.DeltaMessage{Content: chat.TextContent(" world!")},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var body map[string]any
|
||||
json.NewDecoder(r.Body).Decode(&body)
|
||||
if body["stream"] != true {
|
||||
t.Errorf("expected stream=true, got %v", body["stream"])
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, _ := w.(http.Flusher)
|
||||
for _, chunk := range chunks {
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
fmt.Fprint(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
stream, err := client.ChatCompleteStream(context.Background(), &chat.CompletionRequest{
|
||||
Model: "mistral-small-latest",
|
||||
Messages: []chat.Message{
|
||||
&chat.UserMessage{Content: chat.TextContent("Hi")},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
var received []chat.CompletionChunk
|
||||
for stream.Next() {
|
||||
received = append(received, stream.Current())
|
||||
}
|
||||
if stream.Err() != nil {
|
||||
t.Fatal(stream.Err())
|
||||
}
|
||||
if len(received) != 3 {
|
||||
t.Fatalf("got %d chunks, want 3", len(received))
|
||||
}
|
||||
if received[0].Choices[0].Delta.Role != "assistant" {
|
||||
t.Errorf("expected first chunk role=assistant")
|
||||
}
|
||||
if received[1].Choices[0].Delta.Content.String() != "Hello" {
|
||||
t.Errorf("got %q", received[1].Choices[0].Delta.Content.String())
|
||||
}
|
||||
if received[2].Choices[0].Delta.Content.String() != " world!" {
|
||||
t.Errorf("got %q", received[2].Choices[0].Delta.Content.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompleteStream_CollectContent(t *testing.T) {
|
||||
words := []string{"The", " quick", " brown", " fox"}
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, _ := w.(http.Flusher)
|
||||
for _, word := range words {
|
||||
chunk := chat.CompletionChunk{
|
||||
ID: "c",
|
||||
Model: "m",
|
||||
Choices: []chat.CompletionStreamChoice{{
|
||||
Index: 0,
|
||||
Delta: chat.DeltaMessage{Content: chat.TextContent(word)},
|
||||
}},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
flusher.Flush()
|
||||
}
|
||||
stop := "stop"
|
||||
final := chat.CompletionChunk{
|
||||
ID: "c",
|
||||
Model: "m",
|
||||
Choices: []chat.CompletionStreamChoice{{
|
||||
Index: 0,
|
||||
Delta: chat.DeltaMessage{},
|
||||
FinishReason: &stop,
|
||||
}},
|
||||
Usage: &chat.UsageInfo{PromptTokens: 5, CompletionTokens: 4, TotalTokens: 9},
|
||||
}
|
||||
data, _ := json.Marshal(final)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
fmt.Fprint(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
stream, err := client.ChatCompleteStream(context.Background(), &chat.CompletionRequest{
|
||||
Model: "m",
|
||||
Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
var sb strings.Builder
|
||||
var lastChunk chat.CompletionChunk
|
||||
for stream.Next() {
|
||||
lastChunk = stream.Current()
|
||||
if len(lastChunk.Choices) > 0 {
|
||||
sb.WriteString(lastChunk.Choices[0].Delta.Content.String())
|
||||
}
|
||||
}
|
||||
if stream.Err() != nil {
|
||||
t.Fatal(stream.Err())
|
||||
}
|
||||
if sb.String() != "The quick brown fox" {
|
||||
t.Errorf("got %q", sb.String())
|
||||
}
|
||||
if lastChunk.Usage == nil {
|
||||
t.Fatal("expected usage in final chunk")
|
||||
}
|
||||
if lastChunk.Usage.TotalTokens != 9 {
|
||||
t.Errorf("got total_tokens=%d", lastChunk.Usage.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompleteStream_WithToolCalls(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
flusher, _ := w.(http.Flusher)
|
||||
|
||||
chunk := chat.CompletionChunk{
|
||||
ID: "c",
|
||||
Model: "m",
|
||||
Choices: []chat.CompletionStreamChoice{{
|
||||
Index: 0,
|
||||
Delta: chat.DeltaMessage{
|
||||
ToolCalls: []chat.ToolCall{{
|
||||
ID: "call_1",
|
||||
Type: "function",
|
||||
Function: chat.FunctionCall{Name: "get_weather", Arguments: `{"city":"Paris"}`},
|
||||
}},
|
||||
},
|
||||
}},
|
||||
}
|
||||
data, _ := json.Marshal(chunk)
|
||||
fmt.Fprintf(w, "data: %s\n\n", data)
|
||||
fmt.Fprint(w, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("key", WithBaseURL(server.URL))
|
||||
stream, err := client.ChatCompleteStream(context.Background(), &chat.CompletionRequest{
|
||||
Model: "m",
|
||||
Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Weather?")}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
if !stream.Next() {
|
||||
t.Fatalf("expected chunk, err: %v", stream.Err())
|
||||
}
|
||||
tc := stream.Current().Choices[0].Delta.ToolCalls
|
||||
if len(tc) != 1 || tc[0].Function.Name != "get_weather" {
|
||||
t.Errorf("got tool calls %+v", tc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatCompleteStream_APIError(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"message": "invalid key",
|
||||
"type": "auth_error",
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewClient("bad", WithBaseURL(server.URL))
|
||||
_, err := client.ChatCompleteStream(context.Background(), &chat.CompletionRequest{
|
||||
Model: "m",
|
||||
Messages: []chat.Message{&chat.UserMessage{Content: chat.TextContent("Hi")}},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if !IsAuth(err) {
|
||||
t.Errorf("expected auth error, got: %v", err)
|
||||
}
|
||||
}
|
||||
20
request.go
20
request.go
@@ -51,6 +51,26 @@ func (c *Client) doJSON(ctx context.Context, method, path string, reqBody, respB
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) doStream(ctx context.Context, method, path string, reqBody any) (*http.Response, error) {
|
||||
var body io.Reader
|
||||
if reqBody != nil {
|
||||
data, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mistral: marshal request: %w", err)
|
||||
}
|
||||
body = bytes.NewReader(data)
|
||||
}
|
||||
resp, err := c.do(ctx, method, path, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
defer resp.Body.Close()
|
||||
return nil, parseAPIError(resp)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func parseAPIError(resp *http.Response) error {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
|
||||
75
sse.go
Normal file
75
sse.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
)
|
||||
|
||||
// sseEvent represents a single Server-Sent Event.
|
||||
type sseEvent struct {
|
||||
Event string
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// isDone returns true if this event signals end-of-stream.
|
||||
func (e *sseEvent) isDone() bool {
|
||||
return string(bytes.TrimSpace(e.Data)) == "[DONE]"
|
||||
}
|
||||
|
||||
// sseReader reads Server-Sent Events from an io.Reader.
|
||||
type sseReader struct {
|
||||
scanner *bufio.Scanner
|
||||
}
|
||||
|
||||
func newSSEReader(r io.Reader) *sseReader {
|
||||
return &sseReader{scanner: bufio.NewScanner(r)}
|
||||
}
|
||||
|
||||
// next reads the next SSE event. Returns nil, nil at EOF.
|
||||
func (r *sseReader) next() (*sseEvent, error) {
|
||||
var event sseEvent
|
||||
var hasData bool
|
||||
|
||||
for r.scanner.Scan() {
|
||||
line := r.scanner.Bytes()
|
||||
|
||||
// Blank line = end of event
|
||||
if len(line) == 0 {
|
||||
if hasData {
|
||||
return &event, nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip comments
|
||||
if line[0] == ':' {
|
||||
continue
|
||||
}
|
||||
|
||||
field, value, _ := bytes.Cut(line, []byte(":"))
|
||||
// Strip single leading space from value per SSE spec
|
||||
value = bytes.TrimPrefix(value, []byte(" "))
|
||||
|
||||
switch string(field) {
|
||||
case "event":
|
||||
event.Event = string(value)
|
||||
case "data":
|
||||
if hasData {
|
||||
event.Data = append(event.Data, '\n')
|
||||
}
|
||||
event.Data = append(event.Data, value...)
|
||||
hasData = true
|
||||
}
|
||||
}
|
||||
|
||||
if err := r.scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Final event without trailing blank line
|
||||
if hasData {
|
||||
return &event, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
182
sse_test.go
Normal file
182
sse_test.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSSEReader_SingleEvent(t *testing.T) {
|
||||
input := "data: {\"id\":\"1\"}\n\n"
|
||||
r := newSSEReader(strings.NewReader(input))
|
||||
ev, err := r.next()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ev == nil {
|
||||
t.Fatal("expected event")
|
||||
}
|
||||
if string(ev.Data) != `{"id":"1"}` {
|
||||
t.Errorf("got data %q", ev.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEReader_MultipleEvents(t *testing.T) {
|
||||
input := "data: first\n\ndata: second\n\n"
|
||||
r := newSSEReader(strings.NewReader(input))
|
||||
|
||||
ev1, err := r.next()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(ev1.Data) != "first" {
|
||||
t.Errorf("got %q, want %q", ev1.Data, "first")
|
||||
}
|
||||
|
||||
ev2, err := r.next()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(ev2.Data) != "second" {
|
||||
t.Errorf("got %q, want %q", ev2.Data, "second")
|
||||
}
|
||||
|
||||
ev3, err := r.next()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ev3 != nil {
|
||||
t.Errorf("expected nil at EOF, got %+v", ev3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEReader_MultiLineData(t *testing.T) {
|
||||
input := "data: line1\ndata: line2\ndata: line3\n\n"
|
||||
r := newSSEReader(strings.NewReader(input))
|
||||
ev, err := r.next()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want := "line1\nline2\nline3"
|
||||
if string(ev.Data) != want {
|
||||
t.Errorf("got %q, want %q", ev.Data, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEReader_EventField(t *testing.T) {
|
||||
input := "event: completion\ndata: {\"id\":\"1\"}\n\n"
|
||||
r := newSSEReader(strings.NewReader(input))
|
||||
ev, err := r.next()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ev.Event != "completion" {
|
||||
t.Errorf("got event %q, want %q", ev.Event, "completion")
|
||||
}
|
||||
if string(ev.Data) != `{"id":"1"}` {
|
||||
t.Errorf("got data %q", ev.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEReader_SkipsComments(t *testing.T) {
|
||||
input := ": this is a comment\ndata: hello\n\n"
|
||||
r := newSSEReader(strings.NewReader(input))
|
||||
ev, err := r.next()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(ev.Data) != "hello" {
|
||||
t.Errorf("got %q, want %q", ev.Data, "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEReader_Done(t *testing.T) {
|
||||
input := "data: {\"id\":\"1\"}\n\ndata: [DONE]\n\n"
|
||||
r := newSSEReader(strings.NewReader(input))
|
||||
|
||||
ev1, err := r.next()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ev1.isDone() {
|
||||
t.Error("first event should not be done")
|
||||
}
|
||||
|
||||
ev2, err := r.next()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !ev2.isDone() {
|
||||
t.Error("second event should be done")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEReader_DoneWithWhitespace(t *testing.T) {
|
||||
ev := &sseEvent{Data: []byte(" [DONE] ")}
|
||||
if !ev.isDone() {
|
||||
t.Error("should detect [DONE] with whitespace")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEReader_EmptyStream(t *testing.T) {
|
||||
r := newSSEReader(strings.NewReader(""))
|
||||
ev, err := r.next()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ev != nil {
|
||||
t.Errorf("expected nil for empty stream, got %+v", ev)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEReader_OnlyComments(t *testing.T) {
|
||||
input := ": comment1\n: comment2\n\n"
|
||||
r := newSSEReader(strings.NewReader(input))
|
||||
ev, err := r.next()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ev != nil {
|
||||
t.Errorf("expected nil, got %+v", ev)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEReader_NoTrailingNewline(t *testing.T) {
|
||||
input := "data: hello"
|
||||
r := newSSEReader(strings.NewReader(input))
|
||||
ev, err := r.next()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ev == nil {
|
||||
t.Fatal("expected event for data without trailing blank line")
|
||||
}
|
||||
if string(ev.Data) != "hello" {
|
||||
t.Errorf("got %q, want %q", ev.Data, "hello")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEReader_DataNoSpace(t *testing.T) {
|
||||
input := "data:{\"compact\":true}\n\n"
|
||||
r := newSSEReader(strings.NewReader(input))
|
||||
ev, err := r.next()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(ev.Data) != `{"compact":true}` {
|
||||
t.Errorf("got %q", ev.Data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSEReader_MultipleBlankLines(t *testing.T) {
|
||||
input := "data: first\n\n\n\ndata: second\n\n"
|
||||
r := newSSEReader(strings.NewReader(input))
|
||||
|
||||
ev1, _ := r.next()
|
||||
if string(ev1.Data) != "first" {
|
||||
t.Errorf("got %q", ev1.Data)
|
||||
}
|
||||
ev2, _ := r.next()
|
||||
if string(ev2.Data) != "second" {
|
||||
t.Errorf("got %q", ev2.Data)
|
||||
}
|
||||
}
|
||||
72
stream.go
Normal file
72
stream.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Stream is a generic iterator for streaming API responses.
|
||||
// Use Next() to advance, Current() to read the value, Err() for errors,
|
||||
// and Close() when done.
|
||||
type Stream[T any] struct {
|
||||
reader *sseReader
|
||||
closer io.Closer
|
||||
current T
|
||||
err error
|
||||
done bool
|
||||
}
|
||||
|
||||
func newStream[T any](body io.ReadCloser) *Stream[T] {
|
||||
return &Stream[T]{
|
||||
reader: newSSEReader(body),
|
||||
closer: body,
|
||||
}
|
||||
}
|
||||
|
||||
// Next advances to the next event. Returns false when the stream
|
||||
// is exhausted or an error occurs.
|
||||
func (s *Stream[T]) Next() bool {
|
||||
if s.done || s.err != nil {
|
||||
return false
|
||||
}
|
||||
for {
|
||||
event, err := s.reader.next()
|
||||
if err != nil {
|
||||
s.err = fmt.Errorf("mistral: read stream: %w", err)
|
||||
return false
|
||||
}
|
||||
if event == nil {
|
||||
s.done = true
|
||||
return false
|
||||
}
|
||||
if event.isDone() {
|
||||
s.done = true
|
||||
return false
|
||||
}
|
||||
|
||||
var v T
|
||||
if err := json.Unmarshal(event.Data, &v); err != nil {
|
||||
s.err = fmt.Errorf("mistral: decode stream event: %w", err)
|
||||
return false
|
||||
}
|
||||
s.current = v
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Current returns the most recently read value.
|
||||
// Only valid after Next() returns true.
|
||||
func (s *Stream[T]) Current() T {
|
||||
return s.current
|
||||
}
|
||||
|
||||
// Err returns any error encountered during streaming.
|
||||
func (s *Stream[T]) Err() error {
|
||||
return s.err
|
||||
}
|
||||
|
||||
// Close releases the underlying HTTP response body.
|
||||
func (s *Stream[T]) Close() error {
|
||||
return s.closer.Close()
|
||||
}
|
||||
141
stream_test.go
Normal file
141
stream_test.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package mistral
|
||||
|
||||
import (
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testChunk struct {
|
||||
ID string `json:"id"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func newTestStream(sse string) *Stream[testChunk] {
|
||||
body := io.NopCloser(strings.NewReader(sse))
|
||||
return newStream[testChunk](body)
|
||||
}
|
||||
|
||||
func TestStream_SingleChunk(t *testing.T) {
|
||||
input := "data: {\"id\":\"1\",\"content\":\"hello\"}\n\ndata: [DONE]\n\n"
|
||||
s := newTestStream(input)
|
||||
defer s.Close()
|
||||
|
||||
if !s.Next() {
|
||||
t.Fatalf("expected Next() to return true, err: %v", s.Err())
|
||||
}
|
||||
chunk := s.Current()
|
||||
if chunk.ID != "1" || chunk.Content != "hello" {
|
||||
t.Errorf("got %+v", chunk)
|
||||
}
|
||||
if s.Next() {
|
||||
t.Error("expected Next() to return false after [DONE]")
|
||||
}
|
||||
if s.Err() != nil {
|
||||
t.Errorf("unexpected error: %v", s.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream_MultipleChunks(t *testing.T) {
|
||||
input := "data: {\"id\":\"1\",\"content\":\"a\"}\n\ndata: {\"id\":\"2\",\"content\":\"b\"}\n\ndata: {\"id\":\"3\",\"content\":\"c\"}\n\ndata: [DONE]\n\n"
|
||||
s := newTestStream(input)
|
||||
defer s.Close()
|
||||
|
||||
var chunks []testChunk
|
||||
for s.Next() {
|
||||
chunks = append(chunks, s.Current())
|
||||
}
|
||||
if s.Err() != nil {
|
||||
t.Fatal(s.Err())
|
||||
}
|
||||
if len(chunks) != 3 {
|
||||
t.Fatalf("got %d chunks, want 3", len(chunks))
|
||||
}
|
||||
if chunks[0].Content != "a" || chunks[1].Content != "b" || chunks[2].Content != "c" {
|
||||
t.Errorf("got %+v", chunks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream_EmptyStream(t *testing.T) {
|
||||
s := newTestStream("data: [DONE]\n\n")
|
||||
defer s.Close()
|
||||
|
||||
if s.Next() {
|
||||
t.Error("expected no chunks before [DONE]")
|
||||
}
|
||||
if s.Err() != nil {
|
||||
t.Errorf("unexpected error: %v", s.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream_InvalidJSON(t *testing.T) {
|
||||
input := "data: not-json\n\n"
|
||||
s := newTestStream(input)
|
||||
defer s.Close()
|
||||
|
||||
if s.Next() {
|
||||
t.Error("expected Next() to return false for invalid JSON")
|
||||
}
|
||||
if s.Err() == nil {
|
||||
t.Error("expected error for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream_NextAfterDone(t *testing.T) {
|
||||
input := "data: {\"id\":\"1\",\"content\":\"x\"}\n\ndata: [DONE]\n\n"
|
||||
s := newTestStream(input)
|
||||
defer s.Close()
|
||||
|
||||
s.Next() // consume first chunk
|
||||
s.Next() // hits [DONE]
|
||||
|
||||
// Calling Next() again should still return false
|
||||
if s.Next() {
|
||||
t.Error("expected false after stream is done")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream_NextAfterError(t *testing.T) {
|
||||
input := "data: bad\n\n"
|
||||
s := newTestStream(input)
|
||||
defer s.Close()
|
||||
|
||||
s.Next() // triggers error
|
||||
|
||||
// Calling Next() again should still return false
|
||||
if s.Next() {
|
||||
t.Error("expected false after error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream_WithComments(t *testing.T) {
|
||||
input := ": keep-alive\ndata: {\"id\":\"1\",\"content\":\"ok\"}\n\n: ping\ndata: [DONE]\n\n"
|
||||
s := newTestStream(input)
|
||||
defer s.Close()
|
||||
|
||||
if !s.Next() {
|
||||
t.Fatalf("expected chunk, err: %v", s.Err())
|
||||
}
|
||||
if s.Current().Content != "ok" {
|
||||
t.Errorf("got %q", s.Current().Content)
|
||||
}
|
||||
if s.Next() {
|
||||
t.Error("expected done after [DONE]")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStream_EOFWithoutDone(t *testing.T) {
|
||||
input := "data: {\"id\":\"1\",\"content\":\"x\"}\n\n"
|
||||
s := newTestStream(input)
|
||||
defer s.Close()
|
||||
|
||||
if !s.Next() {
|
||||
t.Fatalf("expected chunk, err: %v", s.Err())
|
||||
}
|
||||
if s.Next() {
|
||||
t.Error("expected false at EOF")
|
||||
}
|
||||
if s.Err() != nil {
|
||||
t.Errorf("expected no error at clean EOF, got: %v", s.Err())
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user