diff --git a/crates/owlen-core/tests/agent_tool_flow.rs b/crates/owlen-core/tests/agent_tool_flow.rs new file mode 100644 index 0000000..29b937b --- /dev/null +++ b/crates/owlen-core/tests/agent_tool_flow.rs @@ -0,0 +1,310 @@ +use std::{any::Any, collections::HashMap, sync::Arc}; + +use async_trait::async_trait; +use futures::StreamExt; +use owlen_core::{ + Config, Error, Mode, Provider, + config::McpMode, + consent::ConsentScope, + mcp::{ + McpClient, McpToolCall, McpToolDescriptor, McpToolResponse, + failover::{FailoverMcpClient, ServerEntry}, + }, + session::{ControllerEvent, SessionController, SessionOutcome}, + storage::StorageManager, + types::{ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role, ToolCall}, + ui::NoOpUiController, +}; +use tempfile::tempdir; +use tokio::sync::mpsc; + +struct StreamingToolProvider; + +#[async_trait] +impl Provider for StreamingToolProvider { + fn name(&self) -> &str { + "mock-streaming-provider" + } + + async fn list_models(&self) -> owlen_core::Result> { + Ok(vec![ModelInfo { + id: "mock-model".into(), + name: "Mock Model".into(), + description: Some("A mock model that emits tool calls".into()), + provider: self.name().into(), + context_window: Some(4096), + capabilities: vec!["chat".into(), "tools".into()], + supports_tools: true, + }]) + } + + async fn send_prompt(&self, _request: ChatRequest) -> owlen_core::Result { + let mut message = Message::assistant("tool-call".to_string()); + message.tool_calls = Some(vec![ToolCall { + id: "call-1".to_string(), + name: "resources/write".to_string(), + arguments: serde_json::json!({"path": "README.md", "content": "hello"}), + }]); + + Ok(ChatResponse { + message, + usage: None, + is_streaming: false, + is_final: true, + }) + } + + async fn stream_prompt( + &self, + _request: ChatRequest, + ) -> owlen_core::Result { + let mut first_chunk = Message::assistant( + "Thought: need to update README.\nAction: resources/write".to_string(), + ); + first_chunk.tool_calls = Some(vec![ToolCall { + id: "call-1".to_string(), + name: "resources/write".to_string(), + arguments: serde_json::json!({"path": "README.md", "content": "hello"}), + }]); + + let chunk = ChatResponse { + message: first_chunk, + usage: None, + is_streaming: true, + is_final: false, + }; + + Ok(Box::pin(futures::stream::iter(vec![Ok(chunk)]))) + } + + async fn health_check(&self) -> owlen_core::Result<()> { + Ok(()) + } + + fn as_any(&self) -> &(dyn Any + Send + Sync) { + self + } +} + +fn tool_descriptor() -> McpToolDescriptor { + McpToolDescriptor { + name: "web_search".to_string(), + description: "search".to_string(), + input_schema: serde_json::json!({"type": "object"}), + requires_network: true, + requires_filesystem: vec![], + } +} + +struct TimeoutClient; + +#[async_trait] +impl McpClient for TimeoutClient { + async fn list_tools(&self) -> owlen_core::Result> { + Ok(vec![tool_descriptor()]) + } + + async fn call_tool(&self, _call: McpToolCall) -> owlen_core::Result { + Err(Error::Network( + "timeout while contacting remote web search endpoint".into(), + )) + } +} + +#[derive(Clone)] +struct CachedResponseClient { + response: Arc, +} + +impl CachedResponseClient { + fn new() -> Self { + let mut metadata = HashMap::new(); + metadata.insert("source".to_string(), "cache".to_string()); + metadata.insert("cached".to_string(), "true".to_string()); + + let response = McpToolResponse { + name: "web_search".to_string(), + success: true, + output: serde_json::json!({ + "query": "rust", + "results": [ + {"title": "Rust Programming Language", "url": "https://www.rust-lang.org"} + ], + "note": "cached result" + }), + metadata, + duration_ms: 0, + }; + + Self { + response: Arc::new(response), + } + } +} + +#[async_trait] +impl McpClient for CachedResponseClient { + async fn list_tools(&self) -> owlen_core::Result> { + Ok(vec![tool_descriptor()]) + } + + async fn call_tool(&self, _call: McpToolCall) -> owlen_core::Result { + Ok((*self.response).clone()) + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn streaming_file_write_consent_denied_returns_resolution() { + let temp_dir = tempdir().expect("temp dir"); + let storage = StorageManager::with_database_path(temp_dir.path().join("owlen-tests.db")) + .await + .expect("storage"); + + let mut config = Config::default(); + config.general.enable_streaming = true; + config.privacy.encrypt_local_data = false; + config.privacy.require_consent_per_session = true; + config.general.default_model = Some("mock-model".into()); + config.mcp.mode = McpMode::LocalOnly; + config + .refresh_mcp_servers(None) + .expect("refresh MCP servers"); + + let provider: Arc = Arc::new(StreamingToolProvider); + let ui = Arc::new(NoOpUiController); + let (event_tx, mut event_rx) = mpsc::unbounded_channel::(); + + let mut session = SessionController::new( + provider, + config, + Arc::new(storage), + ui, + true, + Some(event_tx), + ) + .await + .expect("session controller"); + + session + .set_operating_mode(Mode::Code) + .await + .expect("code mode"); + + let outcome = session + .send_message( + "Please write to README".to_string(), + ChatParameters { + stream: true, + ..Default::default() + }, + ) + .await + .expect("send message"); + + let (response_id, mut stream) = if let SessionOutcome::Streaming { + response_id, + stream, + } = outcome + { + (response_id, stream) + } else { + panic!("expected streaming outcome"); + }; + + session + .mark_stream_placeholder(response_id, "▌") + .expect("placeholder"); + + let chunk = stream + .next() + .await + .expect("stream chunk") + .expect("chunk result"); + session + .apply_stream_chunk(response_id, &chunk) + .expect("apply chunk"); + + let tool_calls = session + .check_streaming_tool_calls(response_id) + .expect("tool calls"); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].name, "resources/write"); + + let event = event_rx.recv().await.expect("controller event"); + let request_id = match event { + ControllerEvent::ToolRequested { + request_id, + tool_name, + data_types, + endpoints, + .. + } => { + assert_eq!(tool_name, "resources/write"); + assert!(data_types.iter().any(|t| t.contains("file"))); + assert!(endpoints.iter().any(|e| e.contains("filesystem"))); + request_id + } + }; + + let resolution = session + .resolve_tool_consent(request_id, ConsentScope::Denied) + .expect("resolution"); + assert_eq!(resolution.scope, ConsentScope::Denied); + assert_eq!(resolution.tool_name, "resources/write"); + assert_eq!(resolution.tool_calls.len(), tool_calls.len()); + + let err = session + .resolve_tool_consent(request_id, ConsentScope::Denied) + .expect_err("second resolution should fail"); + matches!(err, Error::InvalidInput(_)); + + let conversation = session.conversation().clone(); + let assistant = conversation + .messages + .iter() + .find(|message| message.role == Role::Assistant) + .expect("assistant message present"); + assert!( + assistant + .tool_calls + .as_ref() + .and_then(|calls| calls.first()) + .is_some_and(|call| call.name == "resources/write"), + "stream chunk should capture the tool call on the assistant message" + ); +} + +#[tokio::test] +async fn web_tool_timeout_fails_over_to_cached_result() { + let primary: Arc = Arc::new(TimeoutClient); + let cached = CachedResponseClient::new(); + let backup: Arc = Arc::new(cached.clone()); + + let client = FailoverMcpClient::with_servers(vec![ + ServerEntry::new("primary".into(), primary, 1), + ServerEntry::new("cache".into(), backup, 2), + ]); + + let call = McpToolCall { + name: "web_search".to_string(), + arguments: serde_json::json!({ "query": "rust", "max_results": 3 }), + }; + + let response = client.call_tool(call.clone()).await.expect("fallback"); + + assert_eq!(response.name, "web_search"); + assert_eq!( + response.metadata.get("source").map(String::as_str), + Some("cache") + ); + assert_eq!( + response.output.get("note").and_then(|value| value.as_str()), + Some("cached result") + ); + + let statuses = client.get_server_status().await; + assert!(statuses.iter().any(|(name, health)| name == "primary" + && !matches!(health, owlen_core::mcp::failover::ServerHealth::Healthy))); + assert!(statuses.iter().any(|(name, health)| name == "cache" + && matches!(health, owlen_core::mcp::failover::ServerHealth::Healthy))); +} diff --git a/crates/owlen-tui/tests/agent_flow_ui.rs b/crates/owlen-tui/tests/agent_flow_ui.rs new file mode 100644 index 0000000..ccc263c --- /dev/null +++ b/crates/owlen-tui/tests/agent_flow_ui.rs @@ -0,0 +1,164 @@ +use std::{any::Any, sync::Arc}; + +use async_trait::async_trait; +use crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; +use futures_util::stream; +use owlen_core::{ + Config, Mode, Provider, + config::McpMode, + session::SessionController, + storage::StorageManager, + types::{ChatResponse, Message, Role, ToolCall}, + ui::{NoOpUiController, UiController}, +}; +use owlen_tui::ChatApp; +use owlen_tui::app::UiRuntime; +use owlen_tui::events::Event; +use tempfile::tempdir; +use tokio::sync::mpsc; + +struct StubProvider; + +#[async_trait] +impl Provider for StubProvider { + fn name(&self) -> &str { + "stub-provider" + } + + async fn list_models(&self) -> owlen_core::Result> { + Ok(vec![owlen_core::types::ModelInfo { + id: "stub-model".into(), + name: "Stub Model".into(), + description: Some("Stub model for testing".into()), + provider: self.name().into(), + context_window: Some(4096), + capabilities: vec!["chat".into()], + supports_tools: true, + }]) + } + + async fn send_prompt( + &self, + _request: owlen_core::types::ChatRequest, + ) -> owlen_core::Result { + Ok(ChatResponse { + message: Message::assistant("stub response".to_string()), + usage: None, + is_streaming: false, + is_final: true, + }) + } + + async fn stream_prompt( + &self, + _request: owlen_core::types::ChatRequest, + ) -> owlen_core::Result { + Ok(Box::pin(stream::empty())) + } + + async fn health_check(&self) -> owlen_core::Result<()> { + Ok(()) + } + + fn as_any(&self) -> &(dyn Any + Send + Sync) { + self + } +} + +#[tokio::test(flavor = "multi_thread")] +async fn denied_consent_appends_apology_message() { + let temp_dir = tempdir().expect("temp dir"); + let storage = Arc::new( + StorageManager::with_database_path(temp_dir.path().join("owlen-tui-tests.db")) + .await + .expect("storage"), + ); + + let mut config = Config::default(); + config.privacy.encrypt_local_data = false; + config.general.default_model = Some("stub-model".into()); + config.mcp.mode = McpMode::LocalOnly; + config + .refresh_mcp_servers(None) + .expect("refresh MCP servers"); + + let provider: Arc = Arc::new(StubProvider); + let ui: Arc = Arc::new(NoOpUiController); + let (event_tx, controller_event_rx) = mpsc::unbounded_channel(); + + // Pre-populate a pending consent request before handing the controller to the TUI. + let mut session = SessionController::new( + Arc::clone(&provider), + config, + Arc::clone(&storage), + Arc::clone(&ui), + true, + Some(event_tx.clone()), + ) + .await + .expect("session controller"); + + session + .set_operating_mode(Mode::Code) + .await + .expect("code mode"); + + let tool_call = ToolCall { + id: "call-1".to_string(), + name: "resources/delete".to_string(), + arguments: serde_json::json!({"path": "/tmp/example.txt"}), + }; + + let message_id = session + .conversation_mut() + .push_assistant_message("Preparing to modify files."); + session + .conversation_mut() + .set_tool_calls_on_message(message_id, vec![tool_call]) + .expect("tool calls"); + + let advertised_calls = session + .check_streaming_tool_calls(message_id) + .expect("queued consent"); + assert_eq!(advertised_calls.len(), 1); + + let (mut app, mut session_rx) = ChatApp::new(session, controller_event_rx) + .await + .expect("chat app"); + // Session events are not used in this test. + session_rx.close(); + + // Process the controller event emitted by check_streaming_tool_calls. + UiRuntime::poll_controller_events(&mut app).expect("poll controller events"); + assert!(app.has_pending_consent()); + + let consent_state = app + .consent_dialog() + .expect("consent dialog should be visible") + .clone(); + assert_eq!(consent_state.tool_name, "resources/delete"); + + // Simulate the user pressing "4" to deny consent. + let deny_key = KeyEvent::new(KeyCode::Char('4'), KeyModifiers::NONE); + UiRuntime::handle_ui_event(&mut app, Event::Key(deny_key)) + .await + .expect("handle deny key"); + + assert!(!app.has_pending_consent()); + assert!( + app.status_message() + .to_lowercase() + .contains("consent denied") + ); + + let conversation = app.conversation(); + let last_message = conversation.messages.last().expect("last message"); + assert_eq!(last_message.role, Role::Assistant); + assert!( + last_message + .content + .to_lowercase() + .contains("consent was denied"), + "assistant should acknowledge the denied consent" + ); +}