use std::{
    any::Any,
    collections::HashMap,
    sync::{Arc, Mutex},
};

use anyhow::anyhow;
use async_trait::async_trait;
use futures::StreamExt;
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
use owlen_core::{
    Config, Error, Mode, Provider,
    config::McpMode,
    consent::ConsentScope,
    mcp::{
        McpClient, McpToolCall, McpToolDescriptor, McpToolResponse,
        failover::{FailoverMcpClient, ServerEntry},
    },
    session::{ControllerEvent, SessionController, SessionOutcome},
    storage::StorageManager,
    types::{ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role, ToolCall},
    ui::NoOpUiController,
};
use tempfile::tempdir;
use tokio::sync::mpsc;

/// Mock provider that advertises tool support and streams a single chunk
/// carrying a `resources_write` tool call.
struct StreamingToolProvider;

#[async_trait]
impl Provider for StreamingToolProvider {
    fn name(&self) -> &str {
        "mock-streaming-provider"
    }

    async fn list_models(&self) -> owlen_core::Result<Vec<ModelInfo>> {
        Ok(vec![ModelInfo {
            id: "mock-model".into(),
            name: "Mock Model".into(),
            description: Some("A mock model that emits tool calls".into()),
            provider: self.name().into(),
            context_window: Some(4096),
            capabilities: vec!["chat".into(), "tools".into()],
            supports_tools: true,
        }])
    }

    async fn send_prompt(&self, _request: ChatRequest) -> owlen_core::Result<ChatResponse> {
        let mut message = Message::assistant("tool-call".to_string());
        message.tool_calls = Some(vec![ToolCall {
            id: "call-1".to_string(),
            name: "resources_write".to_string(),
            arguments: serde_json::json!({"path": "README.md", "content": "hello"}),
        }]);
        Ok(ChatResponse {
            message,
            usage: None,
            is_streaming: false,
            is_final: true,
        })
    }

    // NOTE: the generic return type was lost during extraction; a boxed stream
    // of `owlen_core::Result<ChatResponse>` is assumed here to match the chunks
    // produced below and the way the tests consume them.
    async fn stream_prompt(
        &self,
        _request: ChatRequest,
    ) -> owlen_core::Result<futures::stream::BoxStream<'static, owlen_core::Result<ChatResponse>>>
    {
        let mut first_chunk = Message::assistant(
            "Thought: need to update README.\nAction: resources/write".to_string(),
        );
        first_chunk.tool_calls = Some(vec![ToolCall {
            id: "call-1".to_string(),
            name: "resources_write".to_string(),
            arguments: serde_json::json!({"path": "README.md", "content": "hello"}),
        }]);
        let chunk = ChatResponse {
            message: first_chunk,
            usage: None,
            is_streaming: true,
            is_final: false,
        };
        Ok(Box::pin(futures::stream::iter(vec![Ok(chunk)])))
    }

    async fn health_check(&self) -> owlen_core::Result<()> {
        Ok(())
    }

    fn as_any(&self) -> &(dyn Any + Send + Sync) {
        self
    }
}

/// Mock provider whose only model lacks tool support; it captures the request
/// it receives so the test can assert that no tools were attached.
struct NoToolSupportProvider {
    captured: Arc<Mutex<Option<ChatRequest>>>,
}

impl NoToolSupportProvider {
    fn new() -> Self {
        Self {
            captured: Arc::new(Mutex::new(None)),
        }
    }

    fn take_captured(&self) -> Option<ChatRequest> {
        self.captured.lock().expect("capture mutex").take()
    }
}

#[async_trait]
impl Provider for NoToolSupportProvider {
    fn name(&self) -> &str {
        "mock-tool-less-provider"
    }

    async fn list_models(&self) -> owlen_core::Result<Vec<ModelInfo>> {
        Ok(vec![ModelInfo {
            id: "tool-less-model".into(),
            name: "Toolless Model".into(),
            description: Some("A model without tool support.".into()),
            provider: self.name().into(),
            context_window: Some(4096),
            capabilities: vec!["chat".into()],
            supports_tools: false,
        }])
    }

    async fn send_prompt(&self, request: ChatRequest) -> owlen_core::Result<ChatResponse> {
        {
            let mut guard = self.captured.lock().expect("capture mutex");
            *guard = Some(request);
        }
        Ok(ChatResponse {
            message: Message::assistant("ack".to_string()),
            usage: None,
            is_streaming: false,
            is_final: true,
        })
    }

    // See the NOTE on `StreamingToolProvider::stream_prompt` about the
    // reconstructed return type.
    async fn stream_prompt(
        &self,
        _request: ChatRequest,
    ) -> owlen_core::Result<futures::stream::BoxStream<'static, owlen_core::Result<ChatResponse>>>
    {
        Err(Error::Provider(anyhow!(
            "streaming disabled for mock provider"
        )))
    }

    async fn health_check(&self) -> owlen_core::Result<()> {
        Ok(())
    }

    fn as_any(&self) -> &(dyn Any + Send + Sync) {
        self
    }
}

/// Descriptor for the web-search tool exposed by both mock MCP clients.
fn tool_descriptor() -> McpToolDescriptor {
    McpToolDescriptor {
        name: WEB_SEARCH_TOOL_NAME.to_string(),
        description: "search".to_string(),
        input_schema: serde_json::json!({"type": "object"}),
        requires_network: true,
        requires_filesystem: vec![],
    }
}

/// MCP client whose tool calls always fail with a simulated network timeout.
struct TimeoutClient;

#[async_trait]
impl McpClient for TimeoutClient {
    async fn list_tools(&self) -> owlen_core::Result<Vec<McpToolDescriptor>> {
        Ok(vec![tool_descriptor()])
    }

    async fn call_tool(&self, _call: McpToolCall) -> owlen_core::Result<McpToolResponse> {
        Err(Error::Network(
            "timeout while contacting remote web search endpoint".into(),
        ))
    }
}

/// MCP client that returns a canned, cached web-search response.
#[derive(Clone)]
struct CachedResponseClient {
    response: Arc<McpToolResponse>,
}

impl CachedResponseClient {
    fn new() -> Self {
        let mut metadata = HashMap::new();
        metadata.insert("source".to_string(), "cache".to_string());
        metadata.insert("cached".to_string(), "true".to_string());
        let response = McpToolResponse {
            name: WEB_SEARCH_TOOL_NAME.to_string(),
            success: true,
            output: serde_json::json!({
                "query": "rust",
                "results": [
                    {"title": "Rust Programming Language", "url": "https://www.rust-lang.org"}
                ],
                "note": "cached result"
            }),
            metadata,
            duration_ms: 0,
        };
        Self {
            response: Arc::new(response),
        }
    }
}

#[async_trait]
impl McpClient for CachedResponseClient {
    async fn list_tools(&self) -> owlen_core::Result<Vec<McpToolDescriptor>> {
        Ok(vec![tool_descriptor()])
    }

    async fn call_tool(&self, _call: McpToolCall) -> owlen_core::Result<McpToolResponse> {
        Ok((*self.response).clone())
    }
}

#[tokio::test(flavor = "multi_thread")]
async fn streaming_file_write_consent_denied_returns_resolution() {
    let temp_dir = tempdir().expect("temp dir");
    let storage = StorageManager::with_database_path(temp_dir.path().join("owlen-tests.db"))
        .await
        .expect("storage");

    let mut config = Config::default();
    config.general.enable_streaming = true;
    config.privacy.encrypt_local_data = false;
    config.privacy.require_consent_per_session = true;
    config.general.default_model = Some("mock-model".into());
    config.mcp.mode = McpMode::LocalOnly;
    config
        .refresh_mcp_servers(None)
        .expect("refresh MCP servers");

    let provider: Arc<dyn Provider> = Arc::new(StreamingToolProvider);
    let ui = Arc::new(NoOpUiController);
    let (event_tx, mut event_rx) = mpsc::unbounded_channel::<ControllerEvent>();
    let mut session = SessionController::new(
        provider,
        config,
        Arc::new(storage),
        ui,
        true,
        Some(event_tx),
    )
    .await
    .expect("session controller");

    session
        .set_operating_mode(Mode::Code)
        .await
        .expect("code mode");

    let outcome = session
        .send_message(
            "Please write to README".to_string(),
            ChatParameters {
                stream: true,
                ..Default::default()
            },
        )
        .await
        .expect("send message");

    let (response_id, mut stream) = if let SessionOutcome::Streaming {
        response_id,
        stream,
    } = outcome
    {
        (response_id, stream)
    } else {
        panic!("expected streaming outcome");
    };

    session
        .mark_stream_placeholder(response_id, "▌")
        .expect("placeholder");

    let chunk = stream
        .next()
        .await
        .expect("stream chunk")
        .expect("chunk result");
    session
        .apply_stream_chunk(response_id, &chunk)
        .expect("apply chunk");

    let tool_calls = session
        .check_streaming_tool_calls(response_id)
        .expect("tool calls");
    assert_eq!(tool_calls.len(), 1);
    assert_eq!(tool_calls[0].name, "resources_write");

    // Wait for the consent request for the file write, skipping any compression events.
    let request_id = loop {
        match event_rx.recv().await.expect("controller event") {
            ControllerEvent::ToolRequested {
                request_id,
                tool_name,
                data_types,
                endpoints,
                ..
            } => {
                assert_eq!(tool_name, "resources_write");
                assert!(data_types.iter().any(|t| t.contains("file")));
                assert!(endpoints.iter().any(|e| e.contains("filesystem")));
                break request_id;
            }
            ControllerEvent::CompressionCompleted { .. } => continue,
        }
    };

    let resolution = session
        .resolve_tool_consent(request_id, ConsentScope::Denied)
        .expect("resolution");
    assert_eq!(resolution.scope, ConsentScope::Denied);
    assert_eq!(resolution.tool_name, "resources_write");
    assert_eq!(resolution.tool_calls.len(), tool_calls.len());

    // Resolving the same consent request a second time must fail.
    let err = session
        .resolve_tool_consent(request_id, ConsentScope::Denied)
        .expect_err("second resolution should fail");
    assert!(matches!(err, Error::InvalidInput(_)));

    let conversation = session.conversation().clone();
    let assistant = conversation
        .messages
        .iter()
        .find(|message| message.role == Role::Assistant)
        .expect("assistant message present");
    assert!(
        assistant
            .tool_calls
            .as_ref()
            .and_then(|calls| calls.first())
            .is_some_and(|call| call.name == "resources_write"),
        "stream chunk should capture the tool call on the assistant message"
    );
}

#[tokio::test(flavor = "multi_thread")]
async fn disables_tools_when_model_lacks_support() {
    let raw_provider = Arc::new(NoToolSupportProvider::new());
    let provider: Arc<dyn Provider> = raw_provider.clone();

    let temp_dir = tempdir().expect("temp dir");
    let storage = StorageManager::with_database_path(temp_dir.path().join("owlen-tests.db"))
        .await
        .expect("storage");

    let mut config = Config::default();
    config.general.default_provider = "mock-tool-less-provider".into();
    config.general.default_model = Some("tool-less-model".into());
    config.general.enable_streaming = false;
    config.privacy.encrypt_local_data = false;
    config.privacy.require_consent_per_session = false;
    config.mcp.mode = McpMode::LocalOnly;
    config
        .refresh_mcp_servers(None)
        .expect("refresh MCP servers");

    let ui = Arc::new(NoOpUiController);
    let mut session = SessionController::new(provider, config, Arc::new(storage), ui, false, None)
        .await
        .expect("session controller");

    let outcome = session
        .send_message(
            "Please respond without tools.".to_string(),
            ChatParameters::default(),
        )
        .await
        .expect("send_message should succeed");

    if let SessionOutcome::Complete(response) = outcome {
        assert_eq!(response.message.content, "ack");
    } else {
        panic!("expected complete outcome when sending prompt");
    }

    let captured = raw_provider
        .take_captured()
        .expect("provider should capture chat request");
    assert!(
        captured.tools.is_none(),
        "tools should be disabled when the selected model does not support tools"
    );
}

#[tokio::test]
async fn web_tool_timeout_fails_over_to_cached_result() {
    let primary: Arc<dyn McpClient> = Arc::new(TimeoutClient);
    let cached = CachedResponseClient::new();
    let backup: Arc<dyn McpClient> = Arc::new(cached.clone());

    let client = FailoverMcpClient::with_servers(vec![
        ServerEntry::new("primary".into(), primary, 1),
        ServerEntry::new("cache".into(), backup, 2),
    ]);

    let call = McpToolCall {
        name: WEB_SEARCH_TOOL_NAME.to_string(),
        arguments: serde_json::json!({ "query": "rust", "max_results": 3 }),
    };

    // The primary client times out, so the failover client should fall back to
    // the cached backup and surface its metadata.
    let response = client.call_tool(call.clone()).await.expect("fallback");
    assert!(tool_name_matches(&response.name, WEB_SEARCH_TOOL_NAME));
    assert_eq!(
        response.metadata.get("source").map(String::as_str),
        Some("cache")
    );
    assert_eq!(
        response.output.get("note").and_then(|value| value.as_str()),
        Some("cached result")
    );

    let statuses = client.get_server_status().await;
    assert!(statuses.iter().any(|(name, health)| {
        name == "primary"
            && !matches!(health, owlen_core::mcp::failover::ServerHealth::Healthy)
    }));
    assert!(statuses.iter().any(|(name, health)| {
        name == "cache"
            && matches!(health, owlen_core::mcp::failover::ServerHealth::Healthy)
    }));
}