feat(session): implement streaming state with text delta and tool‑call diff handling

- Introduce `StreamingMessageState` to track full text, last tool calls, and completion.
- Add `StreamDiff`, `TextDelta`, and `TextDeltaKind` for describing incremental changes.
- SessionController now maintains a `stream_states` map keyed by response IDs.
- `apply_stream_chunk` uses the new state to emit append/replace text deltas and tool‑call updates, handling final chunks and cleanup.
- `Conversation` gains `set_stream_content` to replace streaming content and manage metadata.
- Ensure stream state is cleared on cancel, conversation reset, and controller clear.
This commit is contained in:
2025-10-18 07:15:12 +02:00
parent 4820a6706f
commit c7b7fe98ec
2 changed files with 247 additions and 8 deletions

View File

@@ -190,6 +190,46 @@ impl ConversationManager {
Ok(()) Ok(())
} }
/// Replace the current streaming content for a message.
pub fn set_stream_content(
&mut self,
message_id: Uuid,
content: impl Into<String>,
is_final: bool,
) -> Result<()> {
let index = self
.message_index
.get(&message_id)
.copied()
.ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?;
let conversation = self.active_mut();
if let Some(message) = conversation.messages.get_mut(index) {
message.content = content.into();
message.metadata.remove(PLACEHOLDER_FLAG);
message.timestamp = std::time::SystemTime::now();
let millis = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis() as u64;
message.metadata.insert(
LAST_CHUNK_TS.to_string(),
Value::Number(Number::from(millis)),
);
if is_final {
message
.metadata
.insert(STREAMING_FLAG.to_string(), Value::Bool(false));
self.streaming.remove(&message_id);
} else if let Some(info) = self.streaming.get_mut(&message_id) {
info.last_update = Instant::now();
}
}
Ok(())
}
/// Set placeholder text for a streaming message /// Set placeholder text for a streaming message
pub fn set_stream_placeholder( pub fn set_stream_placeholder(
&mut self, &mut self,
@@ -254,7 +294,11 @@ impl ConversationManager {
.ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?; .ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?;
if let Some(message) = self.active_mut().messages.get_mut(index) { if let Some(message) = self.active_mut().messages.get_mut(index) {
message.tool_calls = Some(tool_calls); if tool_calls.is_empty() {
message.tool_calls = None;
} else {
message.tool_calls = Some(tool_calls);
}
} }
Ok(()) Ok(())

View File

@@ -66,6 +66,101 @@ struct PendingToolRequest {
tool_calls: Vec<ToolCall>, tool_calls: Vec<ToolCall>,
} }
#[derive(Debug, Default)]
struct StreamingMessageState {
full_text: String,
last_tool_calls: Option<Vec<ToolCall>>,
finished: bool,
}
#[derive(Debug)]
struct StreamDiff {
text: Option<TextDelta>,
tool_calls: Option<Vec<ToolCall>>,
}
#[derive(Debug)]
struct TextDelta {
content: String,
mode: TextDeltaKind,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TextDeltaKind {
Append,
Replace,
}
impl StreamingMessageState {
fn new() -> Self {
Self::default()
}
fn ingest(&mut self, chunk: &ChatResponse) -> StreamDiff {
if self.finished {
return StreamDiff {
text: None,
tool_calls: None,
};
}
let mut text_delta = None;
let incoming = chunk.message.content.clone();
if incoming != self.full_text {
if incoming.starts_with(&self.full_text) {
let delta = incoming[self.full_text.len()..].to_string();
if !delta.is_empty() {
text_delta = Some(TextDelta {
content: delta,
mode: TextDeltaKind::Append,
});
}
} else {
text_delta = Some(TextDelta {
content: incoming.clone(),
mode: TextDeltaKind::Replace,
});
}
self.full_text = incoming;
}
let mut tool_delta = None;
if let Some(tool_calls) = chunk.message.tool_calls.clone() {
if tool_calls.is_empty() {
let previously_had_calls = self
.last_tool_calls
.as_ref()
.map(|prev| !prev.is_empty())
.unwrap_or(false);
if previously_had_calls {
tool_delta = Some(Vec::new());
}
self.last_tool_calls = None;
} else {
let is_new = self
.last_tool_calls
.as_ref()
.map(|prev| prev != &tool_calls)
.unwrap_or(true);
if is_new {
tool_delta = Some(tool_calls.clone());
}
self.last_tool_calls = Some(tool_calls);
}
}
StreamDiff {
text: text_delta,
tool_calls: tool_delta,
}
}
fn mark_finished(&mut self) {
self.finished = true;
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ToolConsentResolution { pub struct ToolConsentResolution {
pub request_id: Uuid, pub request_id: Uuid,
@@ -144,6 +239,7 @@ pub struct SessionController {
missing_oauth_servers: Vec<String>, missing_oauth_servers: Vec<String>,
event_tx: Option<UnboundedSender<ControllerEvent>>, event_tx: Option<UnboundedSender<ControllerEvent>>,
pending_tool_requests: HashMap<Uuid, PendingToolRequest>, pending_tool_requests: HashMap<Uuid, PendingToolRequest>,
stream_states: HashMap<Uuid, StreamingMessageState>,
} }
async fn build_tools( async fn build_tools(
@@ -471,6 +567,7 @@ impl SessionController {
missing_oauth_servers, missing_oauth_servers,
event_tx, event_tx,
pending_tool_requests: HashMap::new(), pending_tool_requests: HashMap::new(),
stream_states: HashMap::new(),
}) })
} }
@@ -1229,6 +1326,8 @@ impl SessionController {
match self.provider.stream_prompt(request).await { match self.provider.stream_prompt(request).await {
Ok(stream) => { Ok(stream) => {
let response_id = self.conversation.start_streaming_response(); let response_id = self.conversation.start_streaming_response();
self.stream_states
.insert(response_id, StreamingMessageState::new());
Ok(SessionOutcome::Streaming { Ok(SessionOutcome::Streaming {
response_id, response_id,
stream, stream,
@@ -1248,14 +1347,43 @@ impl SessionController {
} }
pub fn apply_stream_chunk(&mut self, message_id: Uuid, chunk: &ChatResponse) -> Result<()> { pub fn apply_stream_chunk(&mut self, message_id: Uuid, chunk: &ChatResponse) -> Result<()> {
if chunk.message.has_tool_calls() { let state = self.stream_states.entry(message_id).or_default();
self.conversation.set_tool_calls_on_message(
message_id, let diff = state.ingest(chunk);
chunk.message.tool_calls.clone().unwrap_or_default(),
)?; if let Some(text_delta) = diff.text {
match text_delta.mode {
TextDeltaKind::Append => {
self.conversation.append_stream_chunk(
message_id,
&text_delta.content,
chunk.is_final,
)?;
}
TextDeltaKind::Replace => {
self.conversation.set_stream_content(
message_id,
text_delta.content,
chunk.is_final,
)?;
}
}
} else if chunk.is_final {
self.conversation
.append_stream_chunk(message_id, "", true)?;
} }
self.conversation
.append_stream_chunk(message_id, &chunk.message.content, chunk.is_final) if let Some(tool_calls) = diff.tool_calls {
self.conversation
.set_tool_calls_on_message(message_id, tool_calls)?;
}
if chunk.is_final {
state.mark_finished();
self.stream_states.remove(&message_id);
}
Ok(())
} }
pub fn check_streaming_tool_calls(&mut self, message_id: Uuid) -> Option<Vec<ToolCall>> { pub fn check_streaming_tool_calls(&mut self, message_id: Uuid) -> Option<Vec<ToolCall>> {
@@ -1339,6 +1467,7 @@ impl SessionController {
} }
pub fn cancel_stream(&mut self, message_id: Uuid, notice: &str) -> Result<()> { pub fn cancel_stream(&mut self, message_id: Uuid, notice: &str) -> Result<()> {
self.stream_states.remove(&message_id);
self.conversation self.conversation
.cancel_stream(message_id, notice.to_string()) .cancel_stream(message_id, notice.to_string())
} }
@@ -1376,10 +1505,12 @@ impl SessionController {
pub fn start_new_conversation(&mut self, model: Option<String>, name: Option<String>) { pub fn start_new_conversation(&mut self, model: Option<String>, name: Option<String>) {
self.conversation.start_new(model, name); self.conversation.start_new(model, name);
self.stream_states.clear();
} }
pub fn clear(&mut self) { pub fn clear(&mut self) {
self.conversation.clear(); self.conversation.clear();
self.stream_states.clear();
} }
pub async fn generate_conversation_description(&self) -> Result<String> { pub async fn generate_conversation_description(&self) -> Result<String> {
@@ -1403,6 +1534,29 @@ mod tests {
use std::sync::Arc; use std::sync::Arc;
use tempfile::tempdir; use tempfile::tempdir;
fn make_response(
text: &str,
tool_calls: Option<Vec<ToolCall>>,
is_final: bool,
) -> ChatResponse {
let mut message = Message::assistant(text.to_string());
message.tool_calls = tool_calls;
ChatResponse {
message,
usage: None,
is_streaming: true,
is_final,
}
}
fn make_tool_call(id: &str, name: &str) -> ToolCall {
ToolCall {
id: id.to_string(),
name: name.to_string(),
arguments: serde_json::json!({}),
}
}
const SERVER_NAME: &str = "oauth-test"; const SERVER_NAME: &str = "oauth-test";
fn build_oauth_config(server: &MockServer) -> McpOAuthConfig { fn build_oauth_config(server: &MockServer) -> McpOAuthConfig {
@@ -1441,6 +1595,47 @@ mod tests {
config config
} }
#[test]
fn streaming_state_tracks_text_deltas() {
let mut state = StreamingMessageState::new();
let diff = state.ingest(&make_response("Hello", None, false));
let first = diff.text.expect("text diff");
assert_eq!(first.content, "Hello");
assert_eq!(first.mode, TextDeltaKind::Append);
let diff = state.ingest(&make_response("Hello world", None, false));
let second = diff.text.expect("second diff");
assert_eq!(second.content, " world");
assert_eq!(second.mode, TextDeltaKind::Append);
let diff = state.ingest(&make_response("Hi", None, false));
let third = diff.text.expect("third diff");
assert_eq!(third.content, "Hi");
assert_eq!(third.mode, TextDeltaKind::Replace);
}
#[test]
fn streaming_state_detects_tool_call_changes() {
let mut state = StreamingMessageState::new();
let tool = make_tool_call("call-1", "web.search");
let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
let calls = diff.tool_calls.expect("initial tool call");
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].name, "web.search");
let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
assert!(
diff.tool_calls.is_none(),
"duplicate tool call should not emit"
);
let diff = state.ingest(&make_response("", Some(vec![]), false));
let cleared = diff.tool_calls.expect("clearing tool calls");
assert!(cleared.is_empty());
}
async fn build_session(server: &MockServer) -> (SessionController, tempfile::TempDir) { async fn build_session(server: &MockServer) -> (SessionController, tempfile::TempDir) {
unsafe { unsafe {
std::env::set_var("OWLEN_MASTER_PASSWORD", "test-password"); std::env::set_var("OWLEN_MASTER_PASSWORD", "test-password");