feat(session): implement streaming state with text delta and tool‑call diff handling
- Introduce `StreamingMessageState` to track full text, last tool calls, and completion. - Add `StreamDiff`, `TextDelta`, and `TextDeltaKind` for describing incremental changes. - SessionController now maintains a `stream_states` map keyed by response IDs. - `apply_stream_chunk` uses the new state to emit append/replace text deltas and tool‑call updates, handling final chunks and cleanup. - `Conversation` gains `set_stream_content` to replace streaming content and manage metadata. - Ensure stream state is cleared on cancel, conversation reset, and controller clear.
This commit is contained in:
@@ -190,6 +190,46 @@ impl ConversationManager {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Replace the current streaming content for a message.
|
||||
pub fn set_stream_content(
|
||||
&mut self,
|
||||
message_id: Uuid,
|
||||
content: impl Into<String>,
|
||||
is_final: bool,
|
||||
) -> Result<()> {
|
||||
let index = self
|
||||
.message_index
|
||||
.get(&message_id)
|
||||
.copied()
|
||||
.ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?;
|
||||
|
||||
let conversation = self.active_mut();
|
||||
if let Some(message) = conversation.messages.get_mut(index) {
|
||||
message.content = content.into();
|
||||
message.metadata.remove(PLACEHOLDER_FLAG);
|
||||
message.timestamp = std::time::SystemTime::now();
|
||||
let millis = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64;
|
||||
message.metadata.insert(
|
||||
LAST_CHUNK_TS.to_string(),
|
||||
Value::Number(Number::from(millis)),
|
||||
);
|
||||
|
||||
if is_final {
|
||||
message
|
||||
.metadata
|
||||
.insert(STREAMING_FLAG.to_string(), Value::Bool(false));
|
||||
self.streaming.remove(&message_id);
|
||||
} else if let Some(info) = self.streaming.get_mut(&message_id) {
|
||||
info.last_update = Instant::now();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set placeholder text for a streaming message
|
||||
pub fn set_stream_placeholder(
|
||||
&mut self,
|
||||
@@ -254,8 +294,12 @@ impl ConversationManager {
|
||||
.ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?;
|
||||
|
||||
if let Some(message) = self.active_mut().messages.get_mut(index) {
|
||||
if tool_calls.is_empty() {
|
||||
message.tool_calls = None;
|
||||
} else {
|
||||
message.tool_calls = Some(tool_calls);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -66,6 +66,101 @@ struct PendingToolRequest {
|
||||
tool_calls: Vec<ToolCall>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct StreamingMessageState {
|
||||
full_text: String,
|
||||
last_tool_calls: Option<Vec<ToolCall>>,
|
||||
finished: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct StreamDiff {
|
||||
text: Option<TextDelta>,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TextDelta {
|
||||
content: String,
|
||||
mode: TextDeltaKind,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum TextDeltaKind {
|
||||
Append,
|
||||
Replace,
|
||||
}
|
||||
|
||||
impl StreamingMessageState {
|
||||
fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
fn ingest(&mut self, chunk: &ChatResponse) -> StreamDiff {
|
||||
if self.finished {
|
||||
return StreamDiff {
|
||||
text: None,
|
||||
tool_calls: None,
|
||||
};
|
||||
}
|
||||
|
||||
let mut text_delta = None;
|
||||
let incoming = chunk.message.content.clone();
|
||||
|
||||
if incoming != self.full_text {
|
||||
if incoming.starts_with(&self.full_text) {
|
||||
let delta = incoming[self.full_text.len()..].to_string();
|
||||
if !delta.is_empty() {
|
||||
text_delta = Some(TextDelta {
|
||||
content: delta,
|
||||
mode: TextDeltaKind::Append,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
text_delta = Some(TextDelta {
|
||||
content: incoming.clone(),
|
||||
mode: TextDeltaKind::Replace,
|
||||
});
|
||||
}
|
||||
self.full_text = incoming;
|
||||
}
|
||||
|
||||
let mut tool_delta = None;
|
||||
if let Some(tool_calls) = chunk.message.tool_calls.clone() {
|
||||
if tool_calls.is_empty() {
|
||||
let previously_had_calls = self
|
||||
.last_tool_calls
|
||||
.as_ref()
|
||||
.map(|prev| !prev.is_empty())
|
||||
.unwrap_or(false);
|
||||
if previously_had_calls {
|
||||
tool_delta = Some(Vec::new());
|
||||
}
|
||||
self.last_tool_calls = None;
|
||||
} else {
|
||||
let is_new = self
|
||||
.last_tool_calls
|
||||
.as_ref()
|
||||
.map(|prev| prev != &tool_calls)
|
||||
.unwrap_or(true);
|
||||
if is_new {
|
||||
tool_delta = Some(tool_calls.clone());
|
||||
}
|
||||
self.last_tool_calls = Some(tool_calls);
|
||||
}
|
||||
}
|
||||
|
||||
StreamDiff {
|
||||
text: text_delta,
|
||||
tool_calls: tool_delta,
|
||||
}
|
||||
}
|
||||
|
||||
fn mark_finished(&mut self) {
|
||||
self.finished = true;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolConsentResolution {
|
||||
pub request_id: Uuid,
|
||||
@@ -144,6 +239,7 @@ pub struct SessionController {
|
||||
missing_oauth_servers: Vec<String>,
|
||||
event_tx: Option<UnboundedSender<ControllerEvent>>,
|
||||
pending_tool_requests: HashMap<Uuid, PendingToolRequest>,
|
||||
stream_states: HashMap<Uuid, StreamingMessageState>,
|
||||
}
|
||||
|
||||
async fn build_tools(
|
||||
@@ -471,6 +567,7 @@ impl SessionController {
|
||||
missing_oauth_servers,
|
||||
event_tx,
|
||||
pending_tool_requests: HashMap::new(),
|
||||
stream_states: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1229,6 +1326,8 @@ impl SessionController {
|
||||
match self.provider.stream_prompt(request).await {
|
||||
Ok(stream) => {
|
||||
let response_id = self.conversation.start_streaming_response();
|
||||
self.stream_states
|
||||
.insert(response_id, StreamingMessageState::new());
|
||||
Ok(SessionOutcome::Streaming {
|
||||
response_id,
|
||||
stream,
|
||||
@@ -1248,14 +1347,43 @@ impl SessionController {
|
||||
}
|
||||
|
||||
pub fn apply_stream_chunk(&mut self, message_id: Uuid, chunk: &ChatResponse) -> Result<()> {
|
||||
if chunk.message.has_tool_calls() {
|
||||
self.conversation.set_tool_calls_on_message(
|
||||
let state = self.stream_states.entry(message_id).or_default();
|
||||
|
||||
let diff = state.ingest(chunk);
|
||||
|
||||
if let Some(text_delta) = diff.text {
|
||||
match text_delta.mode {
|
||||
TextDeltaKind::Append => {
|
||||
self.conversation.append_stream_chunk(
|
||||
message_id,
|
||||
chunk.message.tool_calls.clone().unwrap_or_default(),
|
||||
&text_delta.content,
|
||||
chunk.is_final,
|
||||
)?;
|
||||
}
|
||||
TextDeltaKind::Replace => {
|
||||
self.conversation.set_stream_content(
|
||||
message_id,
|
||||
text_delta.content,
|
||||
chunk.is_final,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
} else if chunk.is_final {
|
||||
self.conversation
|
||||
.append_stream_chunk(message_id, &chunk.message.content, chunk.is_final)
|
||||
.append_stream_chunk(message_id, "", true)?;
|
||||
}
|
||||
|
||||
if let Some(tool_calls) = diff.tool_calls {
|
||||
self.conversation
|
||||
.set_tool_calls_on_message(message_id, tool_calls)?;
|
||||
}
|
||||
|
||||
if chunk.is_final {
|
||||
state.mark_finished();
|
||||
self.stream_states.remove(&message_id);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn check_streaming_tool_calls(&mut self, message_id: Uuid) -> Option<Vec<ToolCall>> {
|
||||
@@ -1339,6 +1467,7 @@ impl SessionController {
|
||||
}
|
||||
|
||||
pub fn cancel_stream(&mut self, message_id: Uuid, notice: &str) -> Result<()> {
|
||||
self.stream_states.remove(&message_id);
|
||||
self.conversation
|
||||
.cancel_stream(message_id, notice.to_string())
|
||||
}
|
||||
@@ -1376,10 +1505,12 @@ impl SessionController {
|
||||
|
||||
pub fn start_new_conversation(&mut self, model: Option<String>, name: Option<String>) {
|
||||
self.conversation.start_new(model, name);
|
||||
self.stream_states.clear();
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.conversation.clear();
|
||||
self.stream_states.clear();
|
||||
}
|
||||
|
||||
pub async fn generate_conversation_description(&self) -> Result<String> {
|
||||
@@ -1403,6 +1534,29 @@ mod tests {
|
||||
use std::sync::Arc;
|
||||
use tempfile::tempdir;
|
||||
|
||||
fn make_response(
|
||||
text: &str,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
is_final: bool,
|
||||
) -> ChatResponse {
|
||||
let mut message = Message::assistant(text.to_string());
|
||||
message.tool_calls = tool_calls;
|
||||
ChatResponse {
|
||||
message,
|
||||
usage: None,
|
||||
is_streaming: true,
|
||||
is_final,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_tool_call(id: &str, name: &str) -> ToolCall {
|
||||
ToolCall {
|
||||
id: id.to_string(),
|
||||
name: name.to_string(),
|
||||
arguments: serde_json::json!({}),
|
||||
}
|
||||
}
|
||||
|
||||
const SERVER_NAME: &str = "oauth-test";
|
||||
|
||||
fn build_oauth_config(server: &MockServer) -> McpOAuthConfig {
|
||||
@@ -1441,6 +1595,47 @@ mod tests {
|
||||
config
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn streaming_state_tracks_text_deltas() {
|
||||
let mut state = StreamingMessageState::new();
|
||||
|
||||
let diff = state.ingest(&make_response("Hello", None, false));
|
||||
let first = diff.text.expect("text diff");
|
||||
assert_eq!(first.content, "Hello");
|
||||
assert_eq!(first.mode, TextDeltaKind::Append);
|
||||
|
||||
let diff = state.ingest(&make_response("Hello world", None, false));
|
||||
let second = diff.text.expect("second diff");
|
||||
assert_eq!(second.content, " world");
|
||||
assert_eq!(second.mode, TextDeltaKind::Append);
|
||||
|
||||
let diff = state.ingest(&make_response("Hi", None, false));
|
||||
let third = diff.text.expect("third diff");
|
||||
assert_eq!(third.content, "Hi");
|
||||
assert_eq!(third.mode, TextDeltaKind::Replace);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn streaming_state_detects_tool_call_changes() {
|
||||
let mut state = StreamingMessageState::new();
|
||||
let tool = make_tool_call("call-1", "web.search");
|
||||
|
||||
let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
|
||||
let calls = diff.tool_calls.expect("initial tool call");
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].name, "web.search");
|
||||
|
||||
let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
|
||||
assert!(
|
||||
diff.tool_calls.is_none(),
|
||||
"duplicate tool call should not emit"
|
||||
);
|
||||
|
||||
let diff = state.ingest(&make_response("", Some(vec![]), false));
|
||||
let cleared = diff.tool_calls.expect("clearing tool calls");
|
||||
assert!(cleared.is_empty());
|
||||
}
|
||||
|
||||
async fn build_session(server: &MockServer) -> (SessionController, tempfile::TempDir) {
|
||||
unsafe {
|
||||
std::env::set_var("OWLEN_MASTER_PASSWORD", "test-password");
|
||||
|
||||
Reference in New Issue
Block a user