From c6b13f7cc817d7d4c690a43ecaf8085591c90fae Mon Sep 17 00:00:00 2001 From: vikingowl Date: Fri, 3 Apr 2026 15:12:12 +0200 Subject: [PATCH] feat: add session interface with channel-based local implementation Session interface decouples UI from engine via channels: - Send(input) starts agentic turn in background goroutine - Events() returns channel for streaming events - TurnResult() returns completed Turn after drain - Cancel() propagates context cancellation - Status() reports state, provider, model, token usage, turn count Local implementation: engine runs on dedicated goroutine per turn, events pushed to buffered channel (64), context cancellation propagated. 5 tests. --- internal/session/local.go | 126 ++++++++++++++++ internal/session/session.go | 64 ++++++++ internal/session/session_test.go | 250 +++++++++++++++++++++++++++++++ 3 files changed, 440 insertions(+) create mode 100644 internal/session/local.go create mode 100644 internal/session/session.go create mode 100644 internal/session/session_test.go diff --git a/internal/session/local.go b/internal/session/local.go new file mode 100644 index 0000000..5ea0f0e --- /dev/null +++ b/internal/session/local.go @@ -0,0 +1,126 @@ +package session + +import ( + "context" + "fmt" + "sync" + + "somegit.dev/Owlibou/gnoma/internal/engine" + "somegit.dev/Owlibou/gnoma/internal/stream" +) + +// Local implements Session using goroutines and channels within the same process. +type Local struct { + mu sync.Mutex + + eng *engine.Engine + state SessionState + events chan stream.Event + + // Current turn context + cancel context.CancelFunc + turn *engine.Turn + err error + + // Stats + provider string + model string + turnCount int +} + +// NewLocal creates a channel-based in-process session. +func NewLocal(eng *engine.Engine, providerName, model string) *Local { + return &Local{ + eng: eng, + state: StateIdle, + provider: providerName, + model: model, + } +} + +func (s *Local) Send(input string) error { + s.mu.Lock() + if s.state != StateIdle { + s.mu.Unlock() + return fmt.Errorf("session not idle (state: %s)", s.state) + } + + s.state = StateStreaming + s.events = make(chan stream.Event, 64) + s.turn = nil + s.err = nil + + ctx, cancel := context.WithCancel(context.Background()) + s.cancel = cancel + s.turnCount++ + s.mu.Unlock() + + // Run engine in background goroutine + go func() { + cb := func(evt stream.Event) { + select { + case s.events <- evt: + case <-ctx.Done(): + } + } + + turn, err := s.eng.Submit(ctx, input, cb) + + s.mu.Lock() + s.turn = turn + s.err = err + if err != nil && ctx.Err() != nil { + s.state = StateCancelled + } else if err != nil { + s.state = StateError + } else { + s.state = StateIdle + } + s.mu.Unlock() + + close(s.events) + }() + + return nil +} + +func (s *Local) Events() <-chan stream.Event { + s.mu.Lock() + defer s.mu.Unlock() + return s.events +} + +func (s *Local) TurnResult() (*engine.Turn, error) { + s.mu.Lock() + defer s.mu.Unlock() + return s.turn, s.err +} + +func (s *Local) Cancel() { + s.mu.Lock() + defer s.mu.Unlock() + if s.cancel != nil { + s.cancel() + } +} + +func (s *Local) Close() error { + s.Cancel() + s.mu.Lock() + defer s.mu.Unlock() + s.state = StateClosed + return nil +} + +func (s *Local) Status() Status { + s.mu.Lock() + defer s.mu.Unlock() + + return Status{ + State: s.state, + Provider: s.provider, + Model: s.model, + TokensUsed: s.eng.Usage().TotalTokens(), + TurnCount: s.turnCount, + } +} diff --git a/internal/session/session.go b/internal/session/session.go new file mode 100644 index 0000000..6eed511 --- /dev/null +++ b/internal/session/session.go @@ -0,0 +1,64 @@ +package session + +import ( + "somegit.dev/Owlibou/gnoma/internal/engine" + "somegit.dev/Owlibou/gnoma/internal/stream" +) + +// SessionState tracks the current state of a session. +type SessionState int + +const ( + StateIdle SessionState = iota + StateStreaming + StateToolExec + StateCancelled + StateError + StateClosed +) + +func (s SessionState) String() string { + switch s { + case StateIdle: + return "idle" + case StateStreaming: + return "streaming" + case StateToolExec: + return "tool_exec" + case StateCancelled: + return "cancelled" + case StateError: + return "error" + case StateClosed: + return "closed" + default: + return "unknown" + } +} + +// Status holds observable session state. +type Status struct { + State SessionState + Provider string + Model string + TokensUsed int64 + TurnCount int +} + +// Session is the boundary between UI and engine. +// All communication is via channels. No shared mutable state. +type Session interface { + // Send submits user input and begins an agentic turn. + Send(input string) error + // Events returns the channel that receives streaming events. + // A new channel is created per Send(). Closed when the turn completes. + Events() <-chan stream.Event + // TurnResult returns the completed Turn after Events() is drained. + TurnResult() (*engine.Turn, error) + // Cancel aborts the current turn. + Cancel() + // Close shuts down the session. + Close() error + // Status returns current session state. + Status() Status +} diff --git a/internal/session/session_test.go b/internal/session/session_test.go new file mode 100644 index 0000000..da0a61c --- /dev/null +++ b/internal/session/session_test.go @@ -0,0 +1,250 @@ +package session + +import ( + "context" + "encoding/json" + "fmt" + "testing" + "time" + + "somegit.dev/Owlibou/gnoma/internal/engine" + "somegit.dev/Owlibou/gnoma/internal/message" + "somegit.dev/Owlibou/gnoma/internal/provider" + "somegit.dev/Owlibou/gnoma/internal/stream" + "somegit.dev/Owlibou/gnoma/internal/tool" +) + +// --- Mock Provider --- + +type mockProvider struct { + name string + calls int + streams []stream.Stream +} + +func (m *mockProvider) Name() string { return m.name } +func (m *mockProvider) DefaultModel() string { return "mock-model" } +func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) { + return nil, nil +} +func (m *mockProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) { + if m.calls >= len(m.streams) { + return nil, fmt.Errorf("no more streams") + } + s := m.streams[m.calls] + m.calls++ + return s, nil +} + +type eventStream struct { + events []stream.Event + idx int +} + +func newEventStream(stopReason message.StopReason, events ...stream.Event) *eventStream { + events = append(events, stream.Event{Type: stream.EventTextDelta, StopReason: stopReason}) + return &eventStream{events: events} +} + +func (s *eventStream) Next() bool { s.idx++; return s.idx <= len(s.events) } +func (s *eventStream) Current() stream.Event { return s.events[s.idx-1] } +func (s *eventStream) Err() error { return nil } +func (s *eventStream) Close() error { return nil } + +// --- Tests --- + +func TestLocal_SendAndReceive(t *testing.T) { + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + newEventStream(message.StopEndTurn, + stream.Event{Type: stream.EventTextDelta, Text: "Hello "}, + stream.Event{Type: stream.EventTextDelta, Text: "world!"}, + ), + }, + } + + eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()}) + sess := NewLocal(eng, "test", "mock-model") + + // Initial state + status := sess.Status() + if status.State != StateIdle { + t.Errorf("initial state = %s, want idle", status.State) + } + + // Send + if err := sess.Send("hello"); err != nil { + t.Fatalf("Send: %v", err) + } + + // Collect events + var texts []string + for evt := range sess.Events() { + if evt.Type == stream.EventTextDelta && evt.Text != "" { + texts = append(texts, evt.Text) + } + } + + if len(texts) == 0 { + t.Error("should receive text events") + } + + // Turn result + turn, err := sess.TurnResult() + if err != nil { + t.Fatalf("TurnResult: %v", err) + } + if turn == nil { + t.Fatal("turn should not be nil") + } + + // Back to idle + status = sess.Status() + if status.State != StateIdle { + t.Errorf("state after turn = %s, want idle", status.State) + } + if status.TurnCount != 1 { + t.Errorf("TurnCount = %d, want 1", status.TurnCount) + } +} + +func TestLocal_SendWhileBusy(t *testing.T) { + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + newEventStream(message.StopEndTurn, + stream.Event{Type: stream.EventTextDelta, Text: "slow..."}, + ), + }, + } + + eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()}) + sess := NewLocal(eng, "test", "model") + + sess.Send("first") + + // Try to send while still processing + err := sess.Send("second") + if err == nil { + t.Error("should error when sending while busy") + } + + // Drain events to let first turn complete + for range sess.Events() { + } +} + +func TestLocal_Cancel(t *testing.T) { + // Create a slow stream with many events + events := make([]stream.Event, 100) + for i := range events { + events[i] = stream.Event{Type: stream.EventTextDelta, Text: "x"} + } + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{&slowStream{events: events}}, + } + + eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()}) + sess := NewLocal(eng, "test", "model") + + sess.Send("slow task") + + // Read a few events then cancel + evts := sess.Events() + <-evts // wait for first event + sess.Cancel() + + // Drain remaining + for range evts { + } + + // Should be cancelled or error (context.Canceled wraps to error) + status := sess.Status() + if status.State != StateCancelled && status.State != StateError && status.State != StateIdle { + t.Errorf("state after cancel = %s, want cancelled/error/idle", status.State) + } +} + +func TestLocal_Close(t *testing.T) { + mp := &mockProvider{name: "test"} + eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()}) + sess := NewLocal(eng, "test", "model") + + if err := sess.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + + status := sess.Status() + if status.State != StateClosed { + t.Errorf("state after close = %s, want closed", status.State) + } +} + +func TestLocal_StatusTracking(t *testing.T) { + mp := &mockProvider{ + name: "test", + streams: []stream.Stream{ + newEventStream(message.StopEndTurn, + stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 50}}, + stream.Event{Type: stream.EventTextDelta, Text: "ok"}, + ), + newEventStream(message.StopEndTurn, + stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 200, OutputTokens: 80}}, + stream.Event{Type: stream.EventTextDelta, Text: "ok"}, + ), + }, + } + + eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()}) + sess := NewLocal(eng, "test", "mock-model") + + // Turn 1 + sess.Send("one") + for range sess.Events() { + } + + // Turn 2 + sess.Send("two") + for range sess.Events() { + } + + status := sess.Status() + if status.TurnCount != 2 { + t.Errorf("TurnCount = %d, want 2", status.TurnCount) + } + if status.TokensUsed != 430 { // 100+50+200+80 + t.Errorf("TokensUsed = %d, want 430", status.TokensUsed) + } + if status.Provider != "test" { + t.Errorf("Provider = %q", status.Provider) + } + if status.Model != "mock-model" { + t.Errorf("Model = %q", status.Model) + } +} + +// slowStream produces events slowly then stops. +type slowStream struct { + events []stream.Event + idx int +} + +func (s *slowStream) Next() bool { + if s.idx >= len(s.events) { + return false + } + time.Sleep(50 * time.Millisecond) + s.idx++ + return true +} +func (s *slowStream) Current() stream.Event { return s.events[s.idx-1] } +func (s *slowStream) Err() error { return nil } +func (s *slowStream) Close() error { return nil } + +// Ensure Local implements Session interface +var _ Session = (*Local)(nil) + +// Suppress unused import +var _ = json.Marshal