use std::{any::Any, collections::HashMap, sync::Arc}; use async_trait::async_trait; use futures::StreamExt; use owlen_core::{ Config, Error, Mode, Provider, config::McpMode, consent::ConsentScope, mcp::{ McpClient, McpToolCall, McpToolDescriptor, McpToolResponse, failover::{FailoverMcpClient, ServerEntry}, }, session::{ControllerEvent, SessionController, SessionOutcome}, storage::StorageManager, types::{ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role, ToolCall}, ui::NoOpUiController, }; use tempfile::tempdir; use tokio::sync::mpsc; struct StreamingToolProvider; #[async_trait] impl Provider for StreamingToolProvider { fn name(&self) -> &str { "mock-streaming-provider" } async fn list_models(&self) -> owlen_core::Result> { Ok(vec![ModelInfo { id: "mock-model".into(), name: "Mock Model".into(), description: Some("A mock model that emits tool calls".into()), provider: self.name().into(), context_window: Some(4096), capabilities: vec!["chat".into(), "tools".into()], supports_tools: true, }]) } async fn send_prompt(&self, _request: ChatRequest) -> owlen_core::Result { let mut message = Message::assistant("tool-call".to_string()); message.tool_calls = Some(vec![ToolCall { id: "call-1".to_string(), name: "resources/write".to_string(), arguments: serde_json::json!({"path": "README.md", "content": "hello"}), }]); Ok(ChatResponse { message, usage: None, is_streaming: false, is_final: true, }) } async fn stream_prompt( &self, _request: ChatRequest, ) -> owlen_core::Result { let mut first_chunk = Message::assistant( "Thought: need to update README.\nAction: resources/write".to_string(), ); first_chunk.tool_calls = Some(vec![ToolCall { id: "call-1".to_string(), name: "resources/write".to_string(), arguments: serde_json::json!({"path": "README.md", "content": "hello"}), }]); let chunk = ChatResponse { message: first_chunk, usage: None, is_streaming: true, is_final: false, }; Ok(Box::pin(futures::stream::iter(vec![Ok(chunk)]))) } async fn health_check(&self) -> owlen_core::Result<()> { Ok(()) } fn as_any(&self) -> &(dyn Any + Send + Sync) { self } } fn tool_descriptor() -> McpToolDescriptor { McpToolDescriptor { name: "web_search".to_string(), description: "search".to_string(), input_schema: serde_json::json!({"type": "object"}), requires_network: true, requires_filesystem: vec![], } } struct TimeoutClient; #[async_trait] impl McpClient for TimeoutClient { async fn list_tools(&self) -> owlen_core::Result> { Ok(vec![tool_descriptor()]) } async fn call_tool(&self, _call: McpToolCall) -> owlen_core::Result { Err(Error::Network( "timeout while contacting remote web search endpoint".into(), )) } } #[derive(Clone)] struct CachedResponseClient { response: Arc, } impl CachedResponseClient { fn new() -> Self { let mut metadata = HashMap::new(); metadata.insert("source".to_string(), "cache".to_string()); metadata.insert("cached".to_string(), "true".to_string()); let response = McpToolResponse { name: "web_search".to_string(), success: true, output: serde_json::json!({ "query": "rust", "results": [ {"title": "Rust Programming Language", "url": "https://www.rust-lang.org"} ], "note": "cached result" }), metadata, duration_ms: 0, }; Self { response: Arc::new(response), } } } #[async_trait] impl McpClient for CachedResponseClient { async fn list_tools(&self) -> owlen_core::Result> { Ok(vec![tool_descriptor()]) } async fn call_tool(&self, _call: McpToolCall) -> owlen_core::Result { Ok((*self.response).clone()) } } #[tokio::test(flavor = "multi_thread")] async fn streaming_file_write_consent_denied_returns_resolution() { let temp_dir = tempdir().expect("temp dir"); let storage = StorageManager::with_database_path(temp_dir.path().join("owlen-tests.db")) .await .expect("storage"); let mut config = Config::default(); config.general.enable_streaming = true; config.privacy.encrypt_local_data = false; config.privacy.require_consent_per_session = true; config.general.default_model = Some("mock-model".into()); config.mcp.mode = McpMode::LocalOnly; config .refresh_mcp_servers(None) .expect("refresh MCP servers"); let provider: Arc = Arc::new(StreamingToolProvider); let ui = Arc::new(NoOpUiController); let (event_tx, mut event_rx) = mpsc::unbounded_channel::(); let mut session = SessionController::new( provider, config, Arc::new(storage), ui, true, Some(event_tx), ) .await .expect("session controller"); session .set_operating_mode(Mode::Code) .await .expect("code mode"); let outcome = session .send_message( "Please write to README".to_string(), ChatParameters { stream: true, ..Default::default() }, ) .await .expect("send message"); let (response_id, mut stream) = if let SessionOutcome::Streaming { response_id, stream, } = outcome { (response_id, stream) } else { panic!("expected streaming outcome"); }; session .mark_stream_placeholder(response_id, "▌") .expect("placeholder"); let chunk = stream .next() .await .expect("stream chunk") .expect("chunk result"); session .apply_stream_chunk(response_id, &chunk) .expect("apply chunk"); let tool_calls = session .check_streaming_tool_calls(response_id) .expect("tool calls"); assert_eq!(tool_calls.len(), 1); assert_eq!(tool_calls[0].name, "resources/write"); let event = event_rx.recv().await.expect("controller event"); let request_id = match event { ControllerEvent::ToolRequested { request_id, tool_name, data_types, endpoints, .. } => { assert_eq!(tool_name, "resources/write"); assert!(data_types.iter().any(|t| t.contains("file"))); assert!(endpoints.iter().any(|e| e.contains("filesystem"))); request_id } }; let resolution = session .resolve_tool_consent(request_id, ConsentScope::Denied) .expect("resolution"); assert_eq!(resolution.scope, ConsentScope::Denied); assert_eq!(resolution.tool_name, "resources/write"); assert_eq!(resolution.tool_calls.len(), tool_calls.len()); let err = session .resolve_tool_consent(request_id, ConsentScope::Denied) .expect_err("second resolution should fail"); matches!(err, Error::InvalidInput(_)); let conversation = session.conversation().clone(); let assistant = conversation .messages .iter() .find(|message| message.role == Role::Assistant) .expect("assistant message present"); assert!( assistant .tool_calls .as_ref() .and_then(|calls| calls.first()) .is_some_and(|call| call.name == "resources/write"), "stream chunk should capture the tool call on the assistant message" ); } #[tokio::test] async fn web_tool_timeout_fails_over_to_cached_result() { let primary: Arc = Arc::new(TimeoutClient); let cached = CachedResponseClient::new(); let backup: Arc = Arc::new(cached.clone()); let client = FailoverMcpClient::with_servers(vec![ ServerEntry::new("primary".into(), primary, 1), ServerEntry::new("cache".into(), backup, 2), ]); let call = McpToolCall { name: "web_search".to_string(), arguments: serde_json::json!({ "query": "rust", "max_results": 3 }), }; let response = client.call_tool(call.clone()).await.expect("fallback"); assert_eq!(response.name, "web_search"); assert_eq!( response.metadata.get("source").map(String::as_str), Some("cache") ); assert_eq!( response.output.get("note").and_then(|value| value.as_str()), Some("cached result") ); let statuses = client.get_server_status().await; assert!(statuses.iter().any(|(name, health)| name == "primary" && !matches!(health, owlen_core::mcp::failover::ServerHealth::Healthy))); assert!(statuses.iter().any(|(name, health)| name == "cache" && matches!(health, owlen_core::mcp::failover::ServerHealth::Healthy))); }