From c7b7fe98ecc04230f0c759976a6a998ffa646518 Mon Sep 17 00:00:00 2001 From: vikingowl Date: Sat, 18 Oct 2025 07:15:12 +0200 Subject: [PATCH] =?UTF-8?q?feat(session):=20implement=20streaming=20state?= =?UTF-8?q?=20with=20text=20delta=20and=20tool=E2=80=91call=20diff=20handl?= =?UTF-8?q?ing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Introduce `StreamingMessageState` to track full text, last tool calls, and completion. - Add `StreamDiff`, `TextDelta`, and `TextDeltaKind` for describing incremental changes. - SessionController now maintains a `stream_states` map keyed by response IDs. - `apply_stream_chunk` uses the new state to emit append/replace text deltas and tool‑call updates, handling final chunks and cleanup. - `Conversation` gains `set_stream_content` to replace streaming content and manage metadata. - Ensure stream state is cleared on cancel, conversation reset, and controller clear. --- crates/owlen-core/src/conversation.rs | 46 +++++- crates/owlen-core/src/session.rs | 209 +++++++++++++++++++++++++- 2 files changed, 247 insertions(+), 8 deletions(-) diff --git a/crates/owlen-core/src/conversation.rs b/crates/owlen-core/src/conversation.rs index 4c1bf40..3768678 100644 --- a/crates/owlen-core/src/conversation.rs +++ b/crates/owlen-core/src/conversation.rs @@ -190,6 +190,46 @@ impl ConversationManager { Ok(()) } + /// Replace the current streaming content for a message. + pub fn set_stream_content( + &mut self, + message_id: Uuid, + content: impl Into, + is_final: bool, + ) -> Result<()> { + let index = self + .message_index + .get(&message_id) + .copied() + .ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?; + + let conversation = self.active_mut(); + if let Some(message) = conversation.messages.get_mut(index) { + message.content = content.into(); + message.metadata.remove(PLACEHOLDER_FLAG); + message.timestamp = std::time::SystemTime::now(); + let millis = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_millis() as u64; + message.metadata.insert( + LAST_CHUNK_TS.to_string(), + Value::Number(Number::from(millis)), + ); + + if is_final { + message + .metadata + .insert(STREAMING_FLAG.to_string(), Value::Bool(false)); + self.streaming.remove(&message_id); + } else if let Some(info) = self.streaming.get_mut(&message_id) { + info.last_update = Instant::now(); + } + } + + Ok(()) + } + /// Set placeholder text for a streaming message pub fn set_stream_placeholder( &mut self, @@ -254,7 +294,11 @@ impl ConversationManager { .ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?; if let Some(message) = self.active_mut().messages.get_mut(index) { - message.tool_calls = Some(tool_calls); + if tool_calls.is_empty() { + message.tool_calls = None; + } else { + message.tool_calls = Some(tool_calls); + } } Ok(()) diff --git a/crates/owlen-core/src/session.rs b/crates/owlen-core/src/session.rs index 72fa624..6d5d830 100644 --- a/crates/owlen-core/src/session.rs +++ b/crates/owlen-core/src/session.rs @@ -66,6 +66,101 @@ struct PendingToolRequest { tool_calls: Vec, } +#[derive(Debug, Default)] +struct StreamingMessageState { + full_text: String, + last_tool_calls: Option>, + finished: bool, +} + +#[derive(Debug)] +struct StreamDiff { + text: Option, + tool_calls: Option>, +} + +#[derive(Debug)] +struct TextDelta { + content: String, + mode: TextDeltaKind, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TextDeltaKind { + Append, + Replace, +} + +impl StreamingMessageState { + fn new() -> Self { + Self::default() + } + + fn ingest(&mut self, chunk: &ChatResponse) -> StreamDiff { + if self.finished { + return StreamDiff { + text: None, + tool_calls: None, + }; + } + + let mut text_delta = None; + let incoming = chunk.message.content.clone(); + + if incoming != self.full_text { + if incoming.starts_with(&self.full_text) { + let delta = incoming[self.full_text.len()..].to_string(); + if !delta.is_empty() { + text_delta = Some(TextDelta { + content: delta, + mode: TextDeltaKind::Append, + }); + } + } else { + text_delta = Some(TextDelta { + content: incoming.clone(), + mode: TextDeltaKind::Replace, + }); + } + self.full_text = incoming; + } + + let mut tool_delta = None; + if let Some(tool_calls) = chunk.message.tool_calls.clone() { + if tool_calls.is_empty() { + let previously_had_calls = self + .last_tool_calls + .as_ref() + .map(|prev| !prev.is_empty()) + .unwrap_or(false); + if previously_had_calls { + tool_delta = Some(Vec::new()); + } + self.last_tool_calls = None; + } else { + let is_new = self + .last_tool_calls + .as_ref() + .map(|prev| prev != &tool_calls) + .unwrap_or(true); + if is_new { + tool_delta = Some(tool_calls.clone()); + } + self.last_tool_calls = Some(tool_calls); + } + } + + StreamDiff { + text: text_delta, + tool_calls: tool_delta, + } + } + + fn mark_finished(&mut self) { + self.finished = true; + } +} + #[derive(Debug, Clone)] pub struct ToolConsentResolution { pub request_id: Uuid, @@ -144,6 +239,7 @@ pub struct SessionController { missing_oauth_servers: Vec, event_tx: Option>, pending_tool_requests: HashMap, + stream_states: HashMap, } async fn build_tools( @@ -471,6 +567,7 @@ impl SessionController { missing_oauth_servers, event_tx, pending_tool_requests: HashMap::new(), + stream_states: HashMap::new(), }) } @@ -1229,6 +1326,8 @@ impl SessionController { match self.provider.stream_prompt(request).await { Ok(stream) => { let response_id = self.conversation.start_streaming_response(); + self.stream_states + .insert(response_id, StreamingMessageState::new()); Ok(SessionOutcome::Streaming { response_id, stream, @@ -1248,14 +1347,43 @@ impl SessionController { } pub fn apply_stream_chunk(&mut self, message_id: Uuid, chunk: &ChatResponse) -> Result<()> { - if chunk.message.has_tool_calls() { - self.conversation.set_tool_calls_on_message( - message_id, - chunk.message.tool_calls.clone().unwrap_or_default(), - )?; + let state = self.stream_states.entry(message_id).or_default(); + + let diff = state.ingest(chunk); + + if let Some(text_delta) = diff.text { + match text_delta.mode { + TextDeltaKind::Append => { + self.conversation.append_stream_chunk( + message_id, + &text_delta.content, + chunk.is_final, + )?; + } + TextDeltaKind::Replace => { + self.conversation.set_stream_content( + message_id, + text_delta.content, + chunk.is_final, + )?; + } + } + } else if chunk.is_final { + self.conversation + .append_stream_chunk(message_id, "", true)?; } - self.conversation - .append_stream_chunk(message_id, &chunk.message.content, chunk.is_final) + + if let Some(tool_calls) = diff.tool_calls { + self.conversation + .set_tool_calls_on_message(message_id, tool_calls)?; + } + + if chunk.is_final { + state.mark_finished(); + self.stream_states.remove(&message_id); + } + + Ok(()) } pub fn check_streaming_tool_calls(&mut self, message_id: Uuid) -> Option> { @@ -1339,6 +1467,7 @@ impl SessionController { } pub fn cancel_stream(&mut self, message_id: Uuid, notice: &str) -> Result<()> { + self.stream_states.remove(&message_id); self.conversation .cancel_stream(message_id, notice.to_string()) } @@ -1376,10 +1505,12 @@ impl SessionController { pub fn start_new_conversation(&mut self, model: Option, name: Option) { self.conversation.start_new(model, name); + self.stream_states.clear(); } pub fn clear(&mut self) { self.conversation.clear(); + self.stream_states.clear(); } pub async fn generate_conversation_description(&self) -> Result { @@ -1403,6 +1534,29 @@ mod tests { use std::sync::Arc; use tempfile::tempdir; + fn make_response( + text: &str, + tool_calls: Option>, + is_final: bool, + ) -> ChatResponse { + let mut message = Message::assistant(text.to_string()); + message.tool_calls = tool_calls; + ChatResponse { + message, + usage: None, + is_streaming: true, + is_final, + } + } + + fn make_tool_call(id: &str, name: &str) -> ToolCall { + ToolCall { + id: id.to_string(), + name: name.to_string(), + arguments: serde_json::json!({}), + } + } + const SERVER_NAME: &str = "oauth-test"; fn build_oauth_config(server: &MockServer) -> McpOAuthConfig { @@ -1441,6 +1595,47 @@ mod tests { config } + #[test] + fn streaming_state_tracks_text_deltas() { + let mut state = StreamingMessageState::new(); + + let diff = state.ingest(&make_response("Hello", None, false)); + let first = diff.text.expect("text diff"); + assert_eq!(first.content, "Hello"); + assert_eq!(first.mode, TextDeltaKind::Append); + + let diff = state.ingest(&make_response("Hello world", None, false)); + let second = diff.text.expect("second diff"); + assert_eq!(second.content, " world"); + assert_eq!(second.mode, TextDeltaKind::Append); + + let diff = state.ingest(&make_response("Hi", None, false)); + let third = diff.text.expect("third diff"); + assert_eq!(third.content, "Hi"); + assert_eq!(third.mode, TextDeltaKind::Replace); + } + + #[test] + fn streaming_state_detects_tool_call_changes() { + let mut state = StreamingMessageState::new(); + let tool = make_tool_call("call-1", "web.search"); + + let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false)); + let calls = diff.tool_calls.expect("initial tool call"); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "web.search"); + + let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false)); + assert!( + diff.tool_calls.is_none(), + "duplicate tool call should not emit" + ); + + let diff = state.ingest(&make_response("", Some(vec![]), false)); + let cleared = diff.tool_calls.expect("clearing tool calls"); + assert!(cleared.is_empty()); + } + async fn build_session(server: &MockServer) -> (SessionController, tempfile::TempDir) { unsafe { std::env::set_var("OWLEN_MASTER_PASSWORD", "test-password");