use agent_core::{create_event_channel, run_agent_loop_streaming, AgentEvent, ToolContext}; use async_trait::async_trait; use futures_util::stream; use llm_core::{ ChatMessage, ChatOptions, LlmError, StreamChunk, LlmProvider, Tool, ToolCallDelta, }; use permissions::{Mode, PermissionManager}; use std::pin::Pin; /// Mock LLM provider for testing streaming struct MockStreamingProvider { responses: Vec, } enum MockResponse { /// Text-only response (no tool calls) Text(Vec), // Chunks of text /// Tool call response ToolCall { text_chunks: Vec, tool_id: String, tool_name: String, tool_args: String, }, } #[async_trait] impl LlmProvider for MockStreamingProvider { fn name(&self) -> &str { "mock" } fn model(&self) -> &str { "mock-model" } async fn chat_stream( &self, messages: &[ChatMessage], _options: &ChatOptions, _tools: Option<&[Tool]>, ) -> Result> + Send>>, LlmError> { // Determine which response to use based on message count let response_idx = (messages.len() / 2).min(self.responses.len() - 1); let response = &self.responses[response_idx]; let chunks: Vec> = match response { MockResponse::Text(text_chunks) => text_chunks .iter() .map(|text| { Ok(StreamChunk { content: Some(text.clone()), tool_calls: None, done: false, usage: None, }) }) .collect(), MockResponse::ToolCall { text_chunks, tool_id, tool_name, tool_args, } => { let mut result = vec![]; // First emit text chunks for text in text_chunks { result.push(Ok(StreamChunk { content: Some(text.clone()), tool_calls: None, done: false, usage: None, })); } // Then emit tool call in chunks result.push(Ok(StreamChunk { content: None, tool_calls: Some(vec![ToolCallDelta { index: 0, id: Some(tool_id.clone()), function_name: Some(tool_name.clone()), arguments_delta: None, }]), done: false, usage: None, })); // Emit args in chunks for chunk in tool_args.chars().collect::>().chunks(5) { result.push(Ok(StreamChunk { content: None, tool_calls: Some(vec![ToolCallDelta { index: 0, id: None, function_name: None, arguments_delta: Some(chunk.iter().collect()), }]), done: false, usage: None, })); } result } }; Ok(Box::pin(stream::iter(chunks))) } } #[tokio::test] async fn test_streaming_text_only() { let provider = MockStreamingProvider { responses: vec![MockResponse::Text(vec![ "Hello".to_string(), " ".to_string(), "world".to_string(), "!".to_string(), ])], }; let perms = PermissionManager::new(Mode::Plan); let ctx = ToolContext::default(); let (tx, mut rx) = create_event_channel(); // Spawn the agent loop let handle = tokio::spawn(async move { run_agent_loop_streaming( &provider, "Say hello", &ChatOptions::default(), &perms, &ctx, tx, ) .await }); // Collect events let mut text_deltas = vec![]; let mut done_response = None; while let Some(event) = rx.recv().await { match event { AgentEvent::TextDelta(text) => { text_deltas.push(text); } AgentEvent::Done { final_response } => { done_response = Some(final_response); break; } AgentEvent::Error(e) => { panic!("Unexpected error: {}", e); } _ => {} } } // Wait for agent loop to complete let result = handle.await.unwrap(); assert!(result.is_ok()); // Verify events assert_eq!(text_deltas, vec!["Hello", " ", "world", "!"]); assert_eq!(done_response, Some("Hello world!".to_string())); assert_eq!(result.unwrap(), "Hello world!"); } #[tokio::test] async fn test_streaming_with_tool_call() { let provider = MockStreamingProvider { responses: vec![ MockResponse::ToolCall { text_chunks: vec!["Let me ".to_string(), "check...".to_string()], tool_id: "call_123".to_string(), tool_name: "glob".to_string(), tool_args: r#"{"pattern":"*.rs"}"#.to_string(), }, MockResponse::Text(vec!["Found ".to_string(), "the files!".to_string()]), ], }; let perms = PermissionManager::new(Mode::Plan); let ctx = ToolContext::default(); let (tx, mut rx) = create_event_channel(); // Spawn the agent loop let handle = tokio::spawn(async move { run_agent_loop_streaming( &provider, "Find Rust files", &ChatOptions::default(), &perms, &ctx, tx, ) .await }); // Collect events let mut text_deltas = vec![]; let mut tool_starts = vec![]; let mut tool_outputs = vec![]; let mut tool_ends = vec![]; while let Some(event) = rx.recv().await { match event { AgentEvent::TextDelta(text) => { text_deltas.push(text); } AgentEvent::ToolStart { tool_name, tool_id, } => { tool_starts.push((tool_name, tool_id)); } AgentEvent::ToolOutput { tool_id, content, is_error, } => { tool_outputs.push((tool_id, content, is_error)); } AgentEvent::ToolEnd { tool_id, success } => { tool_ends.push((tool_id, success)); } AgentEvent::Done { .. } => { break; } AgentEvent::Error(e) => { panic!("Unexpected error: {}", e); } } } // Wait for agent loop to complete let result = handle.await.unwrap(); assert!(result.is_ok()); // Verify we got text deltas from both responses assert!(text_deltas.contains(&"Let me ".to_string())); assert!(text_deltas.contains(&"check...".to_string())); assert!(text_deltas.contains(&"Found ".to_string())); assert!(text_deltas.contains(&"the files!".to_string())); // Verify tool events assert_eq!(tool_starts.len(), 1); assert_eq!(tool_starts[0].0, "glob"); assert_eq!(tool_starts[0].1, "call_123"); assert_eq!(tool_outputs.len(), 1); assert_eq!(tool_outputs[0].0, "call_123"); assert!(!tool_outputs[0].2); // not an error assert_eq!(tool_ends.len(), 1); assert_eq!(tool_ends[0].0, "call_123"); assert!(tool_ends[0].1); // success } #[tokio::test] async fn test_channel_creation() { let (tx, mut rx) = create_event_channel(); // Test that channel works tx.send(AgentEvent::TextDelta("test".to_string())) .await .unwrap(); let event = rx.recv().await.unwrap(); match event { AgentEvent::TextDelta(text) => assert_eq!(text, "test"), _ => panic!("Wrong event type"), } }