feat(session): implement streaming state with text delta and tool‑call diff handling
- Introduce `StreamingMessageState` to track full text, last tool calls, and completion. - Add `StreamDiff`, `TextDelta`, and `TextDeltaKind` for describing incremental changes. - SessionController now maintains a `stream_states` map keyed by response IDs. - `apply_stream_chunk` uses the new state to emit append/replace text deltas and tool‑call updates, handling final chunks and cleanup. - `Conversation` gains `set_stream_content` to replace streaming content and manage metadata. - Ensure stream state is cleared on cancel, conversation reset, and controller clear.
This commit is contained in:
@@ -190,6 +190,46 @@ impl ConversationManager {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Replace the current streaming content for a message.
|
||||||
|
pub fn set_stream_content(
|
||||||
|
&mut self,
|
||||||
|
message_id: Uuid,
|
||||||
|
content: impl Into<String>,
|
||||||
|
is_final: bool,
|
||||||
|
) -> Result<()> {
|
||||||
|
let index = self
|
||||||
|
.message_index
|
||||||
|
.get(&message_id)
|
||||||
|
.copied()
|
||||||
|
.ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?;
|
||||||
|
|
||||||
|
let conversation = self.active_mut();
|
||||||
|
if let Some(message) = conversation.messages.get_mut(index) {
|
||||||
|
message.content = content.into();
|
||||||
|
message.metadata.remove(PLACEHOLDER_FLAG);
|
||||||
|
message.timestamp = std::time::SystemTime::now();
|
||||||
|
let millis = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_millis() as u64;
|
||||||
|
message.metadata.insert(
|
||||||
|
LAST_CHUNK_TS.to_string(),
|
||||||
|
Value::Number(Number::from(millis)),
|
||||||
|
);
|
||||||
|
|
||||||
|
if is_final {
|
||||||
|
message
|
||||||
|
.metadata
|
||||||
|
.insert(STREAMING_FLAG.to_string(), Value::Bool(false));
|
||||||
|
self.streaming.remove(&message_id);
|
||||||
|
} else if let Some(info) = self.streaming.get_mut(&message_id) {
|
||||||
|
info.last_update = Instant::now();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Set placeholder text for a streaming message
|
/// Set placeholder text for a streaming message
|
||||||
pub fn set_stream_placeholder(
|
pub fn set_stream_placeholder(
|
||||||
&mut self,
|
&mut self,
|
||||||
@@ -254,7 +294,11 @@ impl ConversationManager {
|
|||||||
.ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?;
|
.ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?;
|
||||||
|
|
||||||
if let Some(message) = self.active_mut().messages.get_mut(index) {
|
if let Some(message) = self.active_mut().messages.get_mut(index) {
|
||||||
message.tool_calls = Some(tool_calls);
|
if tool_calls.is_empty() {
|
||||||
|
message.tool_calls = None;
|
||||||
|
} else {
|
||||||
|
message.tool_calls = Some(tool_calls);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
@@ -66,6 +66,101 @@ struct PendingToolRequest {
|
|||||||
tool_calls: Vec<ToolCall>,
|
tool_calls: Vec<ToolCall>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
struct StreamingMessageState {
|
||||||
|
full_text: String,
|
||||||
|
last_tool_calls: Option<Vec<ToolCall>>,
|
||||||
|
finished: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct StreamDiff {
|
||||||
|
text: Option<TextDelta>,
|
||||||
|
tool_calls: Option<Vec<ToolCall>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct TextDelta {
|
||||||
|
content: String,
|
||||||
|
mode: TextDeltaKind,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
enum TextDeltaKind {
|
||||||
|
Append,
|
||||||
|
Replace,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StreamingMessageState {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ingest(&mut self, chunk: &ChatResponse) -> StreamDiff {
|
||||||
|
if self.finished {
|
||||||
|
return StreamDiff {
|
||||||
|
text: None,
|
||||||
|
tool_calls: None,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut text_delta = None;
|
||||||
|
let incoming = chunk.message.content.clone();
|
||||||
|
|
||||||
|
if incoming != self.full_text {
|
||||||
|
if incoming.starts_with(&self.full_text) {
|
||||||
|
let delta = incoming[self.full_text.len()..].to_string();
|
||||||
|
if !delta.is_empty() {
|
||||||
|
text_delta = Some(TextDelta {
|
||||||
|
content: delta,
|
||||||
|
mode: TextDeltaKind::Append,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
text_delta = Some(TextDelta {
|
||||||
|
content: incoming.clone(),
|
||||||
|
mode: TextDeltaKind::Replace,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
self.full_text = incoming;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut tool_delta = None;
|
||||||
|
if let Some(tool_calls) = chunk.message.tool_calls.clone() {
|
||||||
|
if tool_calls.is_empty() {
|
||||||
|
let previously_had_calls = self
|
||||||
|
.last_tool_calls
|
||||||
|
.as_ref()
|
||||||
|
.map(|prev| !prev.is_empty())
|
||||||
|
.unwrap_or(false);
|
||||||
|
if previously_had_calls {
|
||||||
|
tool_delta = Some(Vec::new());
|
||||||
|
}
|
||||||
|
self.last_tool_calls = None;
|
||||||
|
} else {
|
||||||
|
let is_new = self
|
||||||
|
.last_tool_calls
|
||||||
|
.as_ref()
|
||||||
|
.map(|prev| prev != &tool_calls)
|
||||||
|
.unwrap_or(true);
|
||||||
|
if is_new {
|
||||||
|
tool_delta = Some(tool_calls.clone());
|
||||||
|
}
|
||||||
|
self.last_tool_calls = Some(tool_calls);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
StreamDiff {
|
||||||
|
text: text_delta,
|
||||||
|
tool_calls: tool_delta,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mark_finished(&mut self) {
|
||||||
|
self.finished = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct ToolConsentResolution {
|
pub struct ToolConsentResolution {
|
||||||
pub request_id: Uuid,
|
pub request_id: Uuid,
|
||||||
@@ -144,6 +239,7 @@ pub struct SessionController {
|
|||||||
missing_oauth_servers: Vec<String>,
|
missing_oauth_servers: Vec<String>,
|
||||||
event_tx: Option<UnboundedSender<ControllerEvent>>,
|
event_tx: Option<UnboundedSender<ControllerEvent>>,
|
||||||
pending_tool_requests: HashMap<Uuid, PendingToolRequest>,
|
pending_tool_requests: HashMap<Uuid, PendingToolRequest>,
|
||||||
|
stream_states: HashMap<Uuid, StreamingMessageState>,
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn build_tools(
|
async fn build_tools(
|
||||||
@@ -471,6 +567,7 @@ impl SessionController {
|
|||||||
missing_oauth_servers,
|
missing_oauth_servers,
|
||||||
event_tx,
|
event_tx,
|
||||||
pending_tool_requests: HashMap::new(),
|
pending_tool_requests: HashMap::new(),
|
||||||
|
stream_states: HashMap::new(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1229,6 +1326,8 @@ impl SessionController {
|
|||||||
match self.provider.stream_prompt(request).await {
|
match self.provider.stream_prompt(request).await {
|
||||||
Ok(stream) => {
|
Ok(stream) => {
|
||||||
let response_id = self.conversation.start_streaming_response();
|
let response_id = self.conversation.start_streaming_response();
|
||||||
|
self.stream_states
|
||||||
|
.insert(response_id, StreamingMessageState::new());
|
||||||
Ok(SessionOutcome::Streaming {
|
Ok(SessionOutcome::Streaming {
|
||||||
response_id,
|
response_id,
|
||||||
stream,
|
stream,
|
||||||
@@ -1248,14 +1347,43 @@ impl SessionController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn apply_stream_chunk(&mut self, message_id: Uuid, chunk: &ChatResponse) -> Result<()> {
|
pub fn apply_stream_chunk(&mut self, message_id: Uuid, chunk: &ChatResponse) -> Result<()> {
|
||||||
if chunk.message.has_tool_calls() {
|
let state = self.stream_states.entry(message_id).or_default();
|
||||||
self.conversation.set_tool_calls_on_message(
|
|
||||||
message_id,
|
let diff = state.ingest(chunk);
|
||||||
chunk.message.tool_calls.clone().unwrap_or_default(),
|
|
||||||
)?;
|
if let Some(text_delta) = diff.text {
|
||||||
|
match text_delta.mode {
|
||||||
|
TextDeltaKind::Append => {
|
||||||
|
self.conversation.append_stream_chunk(
|
||||||
|
message_id,
|
||||||
|
&text_delta.content,
|
||||||
|
chunk.is_final,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
TextDeltaKind::Replace => {
|
||||||
|
self.conversation.set_stream_content(
|
||||||
|
message_id,
|
||||||
|
text_delta.content,
|
||||||
|
chunk.is_final,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if chunk.is_final {
|
||||||
|
self.conversation
|
||||||
|
.append_stream_chunk(message_id, "", true)?;
|
||||||
}
|
}
|
||||||
self.conversation
|
|
||||||
.append_stream_chunk(message_id, &chunk.message.content, chunk.is_final)
|
if let Some(tool_calls) = diff.tool_calls {
|
||||||
|
self.conversation
|
||||||
|
.set_tool_calls_on_message(message_id, tool_calls)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if chunk.is_final {
|
||||||
|
state.mark_finished();
|
||||||
|
self.stream_states.remove(&message_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn check_streaming_tool_calls(&mut self, message_id: Uuid) -> Option<Vec<ToolCall>> {
|
pub fn check_streaming_tool_calls(&mut self, message_id: Uuid) -> Option<Vec<ToolCall>> {
|
||||||
@@ -1339,6 +1467,7 @@ impl SessionController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn cancel_stream(&mut self, message_id: Uuid, notice: &str) -> Result<()> {
|
pub fn cancel_stream(&mut self, message_id: Uuid, notice: &str) -> Result<()> {
|
||||||
|
self.stream_states.remove(&message_id);
|
||||||
self.conversation
|
self.conversation
|
||||||
.cancel_stream(message_id, notice.to_string())
|
.cancel_stream(message_id, notice.to_string())
|
||||||
}
|
}
|
||||||
@@ -1376,10 +1505,12 @@ impl SessionController {
|
|||||||
|
|
||||||
pub fn start_new_conversation(&mut self, model: Option<String>, name: Option<String>) {
|
pub fn start_new_conversation(&mut self, model: Option<String>, name: Option<String>) {
|
||||||
self.conversation.start_new(model, name);
|
self.conversation.start_new(model, name);
|
||||||
|
self.stream_states.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn clear(&mut self) {
|
pub fn clear(&mut self) {
|
||||||
self.conversation.clear();
|
self.conversation.clear();
|
||||||
|
self.stream_states.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn generate_conversation_description(&self) -> Result<String> {
|
pub async fn generate_conversation_description(&self) -> Result<String> {
|
||||||
@@ -1403,6 +1534,29 @@ mod tests {
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tempfile::tempdir;
|
use tempfile::tempdir;
|
||||||
|
|
||||||
|
fn make_response(
|
||||||
|
text: &str,
|
||||||
|
tool_calls: Option<Vec<ToolCall>>,
|
||||||
|
is_final: bool,
|
||||||
|
) -> ChatResponse {
|
||||||
|
let mut message = Message::assistant(text.to_string());
|
||||||
|
message.tool_calls = tool_calls;
|
||||||
|
ChatResponse {
|
||||||
|
message,
|
||||||
|
usage: None,
|
||||||
|
is_streaming: true,
|
||||||
|
is_final,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_tool_call(id: &str, name: &str) -> ToolCall {
|
||||||
|
ToolCall {
|
||||||
|
id: id.to_string(),
|
||||||
|
name: name.to_string(),
|
||||||
|
arguments: serde_json::json!({}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const SERVER_NAME: &str = "oauth-test";
|
const SERVER_NAME: &str = "oauth-test";
|
||||||
|
|
||||||
fn build_oauth_config(server: &MockServer) -> McpOAuthConfig {
|
fn build_oauth_config(server: &MockServer) -> McpOAuthConfig {
|
||||||
@@ -1441,6 +1595,47 @@ mod tests {
|
|||||||
config
|
config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn streaming_state_tracks_text_deltas() {
|
||||||
|
let mut state = StreamingMessageState::new();
|
||||||
|
|
||||||
|
let diff = state.ingest(&make_response("Hello", None, false));
|
||||||
|
let first = diff.text.expect("text diff");
|
||||||
|
assert_eq!(first.content, "Hello");
|
||||||
|
assert_eq!(first.mode, TextDeltaKind::Append);
|
||||||
|
|
||||||
|
let diff = state.ingest(&make_response("Hello world", None, false));
|
||||||
|
let second = diff.text.expect("second diff");
|
||||||
|
assert_eq!(second.content, " world");
|
||||||
|
assert_eq!(second.mode, TextDeltaKind::Append);
|
||||||
|
|
||||||
|
let diff = state.ingest(&make_response("Hi", None, false));
|
||||||
|
let third = diff.text.expect("third diff");
|
||||||
|
assert_eq!(third.content, "Hi");
|
||||||
|
assert_eq!(third.mode, TextDeltaKind::Replace);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn streaming_state_detects_tool_call_changes() {
|
||||||
|
let mut state = StreamingMessageState::new();
|
||||||
|
let tool = make_tool_call("call-1", "web.search");
|
||||||
|
|
||||||
|
let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
|
||||||
|
let calls = diff.tool_calls.expect("initial tool call");
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
assert_eq!(calls[0].name, "web.search");
|
||||||
|
|
||||||
|
let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
|
||||||
|
assert!(
|
||||||
|
diff.tool_calls.is_none(),
|
||||||
|
"duplicate tool call should not emit"
|
||||||
|
);
|
||||||
|
|
||||||
|
let diff = state.ingest(&make_response("", Some(vec![]), false));
|
||||||
|
let cleared = diff.tool_calls.expect("clearing tool calls");
|
||||||
|
assert!(cleared.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
async fn build_session(server: &MockServer) -> (SessionController, tempfile::TempDir) {
|
async fn build_session(server: &MockServer) -> (SessionController, tempfile::TempDir) {
|
||||||
unsafe {
|
unsafe {
|
||||||
std::env::set_var("OWLEN_MASTER_PASSWORD", "test-password");
|
std::env::set_var("OWLEN_MASTER_PASSWORD", "test-password");
|
||||||
|
|||||||
Reference in New Issue
Block a user