feat: add session interface with channel-based local implementation
Session interface decouples UI from engine via channels: - Send(input) starts agentic turn in background goroutine - Events() returns channel for streaming events - TurnResult() returns completed Turn after drain - Cancel() propagates context cancellation - Status() reports state, provider, model, token usage, turn count Local implementation: engine runs on dedicated goroutine per turn, events pushed to buffered channel (64), context cancellation propagated. 5 tests.
This commit is contained in:
126
internal/session/local.go
Normal file
126
internal/session/local.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/engine"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
)
|
||||
|
||||
// Local implements Session using goroutines and channels within the same process.
|
||||
type Local struct {
|
||||
mu sync.Mutex
|
||||
|
||||
eng *engine.Engine
|
||||
state SessionState
|
||||
events chan stream.Event
|
||||
|
||||
// Current turn context
|
||||
cancel context.CancelFunc
|
||||
turn *engine.Turn
|
||||
err error
|
||||
|
||||
// Stats
|
||||
provider string
|
||||
model string
|
||||
turnCount int
|
||||
}
|
||||
|
||||
// NewLocal creates a channel-based in-process session.
|
||||
func NewLocal(eng *engine.Engine, providerName, model string) *Local {
|
||||
return &Local{
|
||||
eng: eng,
|
||||
state: StateIdle,
|
||||
provider: providerName,
|
||||
model: model,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Local) Send(input string) error {
|
||||
s.mu.Lock()
|
||||
if s.state != StateIdle {
|
||||
s.mu.Unlock()
|
||||
return fmt.Errorf("session not idle (state: %s)", s.state)
|
||||
}
|
||||
|
||||
s.state = StateStreaming
|
||||
s.events = make(chan stream.Event, 64)
|
||||
s.turn = nil
|
||||
s.err = nil
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s.cancel = cancel
|
||||
s.turnCount++
|
||||
s.mu.Unlock()
|
||||
|
||||
// Run engine in background goroutine
|
||||
go func() {
|
||||
cb := func(evt stream.Event) {
|
||||
select {
|
||||
case s.events <- evt:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
turn, err := s.eng.Submit(ctx, input, cb)
|
||||
|
||||
s.mu.Lock()
|
||||
s.turn = turn
|
||||
s.err = err
|
||||
if err != nil && ctx.Err() != nil {
|
||||
s.state = StateCancelled
|
||||
} else if err != nil {
|
||||
s.state = StateError
|
||||
} else {
|
||||
s.state = StateIdle
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
close(s.events)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Local) Events() <-chan stream.Event {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.events
|
||||
}
|
||||
|
||||
func (s *Local) TurnResult() (*engine.Turn, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.turn, s.err
|
||||
}
|
||||
|
||||
func (s *Local) Cancel() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Local) Close() error {
|
||||
s.Cancel()
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.state = StateClosed
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Local) Status() Status {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
return Status{
|
||||
State: s.state,
|
||||
Provider: s.provider,
|
||||
Model: s.model,
|
||||
TokensUsed: s.eng.Usage().TotalTokens(),
|
||||
TurnCount: s.turnCount,
|
||||
}
|
||||
}
|
||||
64
internal/session/session.go
Normal file
64
internal/session/session.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"somegit.dev/Owlibou/gnoma/internal/engine"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
)
|
||||
|
||||
// SessionState tracks the current state of a session.
|
||||
type SessionState int
|
||||
|
||||
const (
|
||||
StateIdle SessionState = iota
|
||||
StateStreaming
|
||||
StateToolExec
|
||||
StateCancelled
|
||||
StateError
|
||||
StateClosed
|
||||
)
|
||||
|
||||
func (s SessionState) String() string {
|
||||
switch s {
|
||||
case StateIdle:
|
||||
return "idle"
|
||||
case StateStreaming:
|
||||
return "streaming"
|
||||
case StateToolExec:
|
||||
return "tool_exec"
|
||||
case StateCancelled:
|
||||
return "cancelled"
|
||||
case StateError:
|
||||
return "error"
|
||||
case StateClosed:
|
||||
return "closed"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// Status holds observable session state.
|
||||
type Status struct {
|
||||
State SessionState
|
||||
Provider string
|
||||
Model string
|
||||
TokensUsed int64
|
||||
TurnCount int
|
||||
}
|
||||
|
||||
// Session is the boundary between UI and engine.
|
||||
// All communication is via channels. No shared mutable state.
|
||||
type Session interface {
|
||||
// Send submits user input and begins an agentic turn.
|
||||
Send(input string) error
|
||||
// Events returns the channel that receives streaming events.
|
||||
// A new channel is created per Send(). Closed when the turn completes.
|
||||
Events() <-chan stream.Event
|
||||
// TurnResult returns the completed Turn after Events() is drained.
|
||||
TurnResult() (*engine.Turn, error)
|
||||
// Cancel aborts the current turn.
|
||||
Cancel()
|
||||
// Close shuts down the session.
|
||||
Close() error
|
||||
// Status returns current session state.
|
||||
Status() Status
|
||||
}
|
||||
250
internal/session/session_test.go
Normal file
250
internal/session/session_test.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"somegit.dev/Owlibou/gnoma/internal/engine"
|
||||
"somegit.dev/Owlibou/gnoma/internal/message"
|
||||
"somegit.dev/Owlibou/gnoma/internal/provider"
|
||||
"somegit.dev/Owlibou/gnoma/internal/stream"
|
||||
"somegit.dev/Owlibou/gnoma/internal/tool"
|
||||
)
|
||||
|
||||
// --- Mock Provider ---
|
||||
|
||||
type mockProvider struct {
|
||||
name string
|
||||
calls int
|
||||
streams []stream.Stream
|
||||
}
|
||||
|
||||
func (m *mockProvider) Name() string { return m.name }
|
||||
func (m *mockProvider) DefaultModel() string { return "mock-model" }
|
||||
func (m *mockProvider) Models(_ context.Context) ([]provider.ModelInfo, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockProvider) Stream(_ context.Context, _ provider.Request) (stream.Stream, error) {
|
||||
if m.calls >= len(m.streams) {
|
||||
return nil, fmt.Errorf("no more streams")
|
||||
}
|
||||
s := m.streams[m.calls]
|
||||
m.calls++
|
||||
return s, nil
|
||||
}
|
||||
|
||||
type eventStream struct {
|
||||
events []stream.Event
|
||||
idx int
|
||||
}
|
||||
|
||||
func newEventStream(stopReason message.StopReason, events ...stream.Event) *eventStream {
|
||||
events = append(events, stream.Event{Type: stream.EventTextDelta, StopReason: stopReason})
|
||||
return &eventStream{events: events}
|
||||
}
|
||||
|
||||
func (s *eventStream) Next() bool { s.idx++; return s.idx <= len(s.events) }
|
||||
func (s *eventStream) Current() stream.Event { return s.events[s.idx-1] }
|
||||
func (s *eventStream) Err() error { return nil }
|
||||
func (s *eventStream) Close() error { return nil }
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
func TestLocal_SendAndReceive(t *testing.T) {
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
newEventStream(message.StopEndTurn,
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "Hello "},
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "world!"},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||
sess := NewLocal(eng, "test", "mock-model")
|
||||
|
||||
// Initial state
|
||||
status := sess.Status()
|
||||
if status.State != StateIdle {
|
||||
t.Errorf("initial state = %s, want idle", status.State)
|
||||
}
|
||||
|
||||
// Send
|
||||
if err := sess.Send("hello"); err != nil {
|
||||
t.Fatalf("Send: %v", err)
|
||||
}
|
||||
|
||||
// Collect events
|
||||
var texts []string
|
||||
for evt := range sess.Events() {
|
||||
if evt.Type == stream.EventTextDelta && evt.Text != "" {
|
||||
texts = append(texts, evt.Text)
|
||||
}
|
||||
}
|
||||
|
||||
if len(texts) == 0 {
|
||||
t.Error("should receive text events")
|
||||
}
|
||||
|
||||
// Turn result
|
||||
turn, err := sess.TurnResult()
|
||||
if err != nil {
|
||||
t.Fatalf("TurnResult: %v", err)
|
||||
}
|
||||
if turn == nil {
|
||||
t.Fatal("turn should not be nil")
|
||||
}
|
||||
|
||||
// Back to idle
|
||||
status = sess.Status()
|
||||
if status.State != StateIdle {
|
||||
t.Errorf("state after turn = %s, want idle", status.State)
|
||||
}
|
||||
if status.TurnCount != 1 {
|
||||
t.Errorf("TurnCount = %d, want 1", status.TurnCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocal_SendWhileBusy(t *testing.T) {
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
newEventStream(message.StopEndTurn,
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "slow..."},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||
sess := NewLocal(eng, "test", "model")
|
||||
|
||||
sess.Send("first")
|
||||
|
||||
// Try to send while still processing
|
||||
err := sess.Send("second")
|
||||
if err == nil {
|
||||
t.Error("should error when sending while busy")
|
||||
}
|
||||
|
||||
// Drain events to let first turn complete
|
||||
for range sess.Events() {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocal_Cancel(t *testing.T) {
|
||||
// Create a slow stream with many events
|
||||
events := make([]stream.Event, 100)
|
||||
for i := range events {
|
||||
events[i] = stream.Event{Type: stream.EventTextDelta, Text: "x"}
|
||||
}
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{&slowStream{events: events}},
|
||||
}
|
||||
|
||||
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||
sess := NewLocal(eng, "test", "model")
|
||||
|
||||
sess.Send("slow task")
|
||||
|
||||
// Read a few events then cancel
|
||||
evts := sess.Events()
|
||||
<-evts // wait for first event
|
||||
sess.Cancel()
|
||||
|
||||
// Drain remaining
|
||||
for range evts {
|
||||
}
|
||||
|
||||
// Should be cancelled or error (context.Canceled wraps to error)
|
||||
status := sess.Status()
|
||||
if status.State != StateCancelled && status.State != StateError && status.State != StateIdle {
|
||||
t.Errorf("state after cancel = %s, want cancelled/error/idle", status.State)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocal_Close(t *testing.T) {
|
||||
mp := &mockProvider{name: "test"}
|
||||
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||
sess := NewLocal(eng, "test", "model")
|
||||
|
||||
if err := sess.Close(); err != nil {
|
||||
t.Fatalf("Close: %v", err)
|
||||
}
|
||||
|
||||
status := sess.Status()
|
||||
if status.State != StateClosed {
|
||||
t.Errorf("state after close = %s, want closed", status.State)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocal_StatusTracking(t *testing.T) {
|
||||
mp := &mockProvider{
|
||||
name: "test",
|
||||
streams: []stream.Stream{
|
||||
newEventStream(message.StopEndTurn,
|
||||
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 100, OutputTokens: 50}},
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "ok"},
|
||||
),
|
||||
newEventStream(message.StopEndTurn,
|
||||
stream.Event{Type: stream.EventUsage, Usage: &message.Usage{InputTokens: 200, OutputTokens: 80}},
|
||||
stream.Event{Type: stream.EventTextDelta, Text: "ok"},
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
eng, _ := engine.New(engine.Config{Provider: mp, Tools: tool.NewRegistry()})
|
||||
sess := NewLocal(eng, "test", "mock-model")
|
||||
|
||||
// Turn 1
|
||||
sess.Send("one")
|
||||
for range sess.Events() {
|
||||
}
|
||||
|
||||
// Turn 2
|
||||
sess.Send("two")
|
||||
for range sess.Events() {
|
||||
}
|
||||
|
||||
status := sess.Status()
|
||||
if status.TurnCount != 2 {
|
||||
t.Errorf("TurnCount = %d, want 2", status.TurnCount)
|
||||
}
|
||||
if status.TokensUsed != 430 { // 100+50+200+80
|
||||
t.Errorf("TokensUsed = %d, want 430", status.TokensUsed)
|
||||
}
|
||||
if status.Provider != "test" {
|
||||
t.Errorf("Provider = %q", status.Provider)
|
||||
}
|
||||
if status.Model != "mock-model" {
|
||||
t.Errorf("Model = %q", status.Model)
|
||||
}
|
||||
}
|
||||
|
||||
// slowStream produces events slowly then stops.
|
||||
type slowStream struct {
|
||||
events []stream.Event
|
||||
idx int
|
||||
}
|
||||
|
||||
func (s *slowStream) Next() bool {
|
||||
if s.idx >= len(s.events) {
|
||||
return false
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
s.idx++
|
||||
return true
|
||||
}
|
||||
func (s *slowStream) Current() stream.Event { return s.events[s.idx-1] }
|
||||
func (s *slowStream) Err() error { return nil }
|
||||
func (s *slowStream) Close() error { return nil }
|
||||
|
||||
// Ensure Local implements Session interface
|
||||
var _ Session = (*Local)(nil)
|
||||
|
||||
// Suppress unused import
|
||||
var _ = json.Marshal
|
||||
Reference in New Issue
Block a user