use crate::config::{Config, McpResourceConfig, McpServerConfig}; use crate::consent::{ConsentManager, ConsentScope}; use crate::conversation::ConversationManager; use crate::credentials::CredentialManager; use crate::encryption::{self, VaultHandle}; use crate::formatting::MessageFormatter; use crate::input::InputBuffer; use crate::mcp::McpToolCall; use crate::mcp::client::McpClient; use crate::mcp::factory::McpClientFactory; use crate::mcp::permission::PermissionLayer; use crate::mcp::remote_client::{McpRuntimeSecrets, RemoteMcpClient}; use crate::mode::Mode; use crate::model::{DetailedModelInfo, ModelManager}; use crate::oauth::{DeviceAuthorization, DevicePollState, OAuthClient}; use crate::providers::OllamaProvider; use crate::storage::{SessionMeta, StorageManager}; use crate::types::{ ChatParameters, ChatRequest, ChatResponse, Conversation, Message, ModelInfo, ToolCall, }; use crate::ui::{RoleLabelDisplay, UiController}; use crate::validation::{SchemaValidator, get_builtin_schemas}; use crate::{ChatStream, Provider}; use crate::{ CodeExecTool, ResourcesDeleteTool, ResourcesGetTool, ResourcesListTool, ResourcesWriteTool, ToolRegistry, WebScrapeTool, WebSearchDetailedTool, WebSearchTool, }; use crate::{Error, Result}; use chrono::Utc; use log::warn; use serde_json::{Value, json}; use std::collections::HashMap; use std::env; use std::path::PathBuf; use std::sync::{Arc, Mutex}; use tokio::sync::Mutex as TokioMutex; use tokio::sync::mpsc::UnboundedSender; use uuid::Uuid; pub enum SessionOutcome { Complete(ChatResponse), Streaming { response_id: Uuid, stream: ChatStream, }, } #[derive(Debug, Clone)] pub enum ControllerEvent { ToolRequested { request_id: Uuid, message_id: Uuid, tool_name: String, data_types: Vec, endpoints: Vec, tool_calls: Vec, }, } #[derive(Clone, Debug)] struct PendingToolRequest { message_id: Uuid, tool_name: String, data_types: Vec, endpoints: Vec, tool_calls: Vec, } #[derive(Debug, Clone)] pub struct ToolConsentResolution { pub request_id: Uuid, 
pub message_id: Uuid, pub tool_name: String, pub scope: ConsentScope, pub tool_calls: Vec, } fn extract_resource_content(value: &Value) -> Option { match value { Value::Null => Some(String::new()), Value::Bool(flag) => Some(flag.to_string()), Value::Number(num) => Some(num.to_string()), Value::String(text) => Some(text.clone()), Value::Array(items) => { let mut segments = Vec::new(); for item in items { if let Some(segment) = extract_resource_content(item).filter(|segment| !segment.is_empty()) { segments.push(segment); } } if segments.is_empty() { None } else { Some(segments.join("\n")) } } Value::Object(map) => { const PREFERRED_FIELDS: [&str; 6] = ["content", "contents", "text", "value", "body", "data"]; for key in PREFERRED_FIELDS.iter() { if let Some(text) = map .get(*key) .and_then(extract_resource_content) .filter(|text| !text.is_empty()) { return Some(text); } } if let Some(text) = map .get("chunks") .and_then(extract_resource_content) .filter(|text| !text.is_empty()) { return Some(text); } None } } } pub struct SessionController { provider: Arc, conversation: ConversationManager, model_manager: ModelManager, input_buffer: InputBuffer, formatter: MessageFormatter, config: Arc>, consent_manager: Arc>, tool_registry: Arc, schema_validator: Arc, mcp_client: Arc, named_mcp_clients: HashMap>, storage: Arc, vault: Option>>, master_key: Option>>, credential_manager: Option>, ui: Arc, enable_code_tools: bool, current_mode: Mode, missing_oauth_servers: Vec, event_tx: Option>, pending_tool_requests: HashMap, } async fn build_tools( config: Arc>, ui: Arc, enable_code_tools: bool, consent_manager: Arc>, credential_manager: Option>, vault: Option>>, ) -> Result<(Arc, Arc)> { let mut registry = ToolRegistry::new(config.clone(), ui); let mut validator = SchemaValidator::new(); // Acquire config asynchronously to avoid blocking the async runtime. 
let config_guard = config.lock().await; for (name, schema) in get_builtin_schemas() { if let Err(err) = validator.register_schema(&name, schema) { warn!("Failed to register built-in schema {name}: {err}"); } } if config_guard .security .allowed_tools .iter() .any(|tool| tool == "web_search") && config_guard.tools.web_search.enabled && config_guard.privacy.enable_remote_search { let tool = WebSearchTool::new( consent_manager.clone(), credential_manager.clone(), vault.clone(), ); registry.register(tool); } // Register web_scrape tool if allowed. if config_guard .security .allowed_tools .iter() .any(|tool| tool == "web_scrape") && config_guard.tools.web_search.enabled // reuse web_search toggle for simplicity && config_guard.privacy.enable_remote_search { let tool = WebScrapeTool::new(); registry.register(tool); } if config_guard .security .allowed_tools .iter() .any(|tool| tool == "web_search") && config_guard.tools.web_search.enabled && config_guard.privacy.enable_remote_search { let tool = WebSearchDetailedTool::new( consent_manager.clone(), credential_manager.clone(), vault.clone(), ); registry.register(tool); } if enable_code_tools && config_guard .security .allowed_tools .iter() .any(|tool| tool == "code_exec") && config_guard.tools.code_exec.enabled { let tool = CodeExecTool::new(config_guard.tools.code_exec.allowed_languages.clone()); registry.register(tool); } registry.register(ResourcesListTool); registry.register(ResourcesGetTool); if config_guard .security .allowed_tools .iter() .any(|t| t == "file_write") { registry.register(ResourcesWriteTool); } if config_guard .security .allowed_tools .iter() .any(|t| t == "file_delete") { registry.register(ResourcesDeleteTool); } for tool in registry.all() { if let Err(err) = validator.register_schema(tool.name(), tool.schema()) { warn!("Failed to register schema for {}: {err}", tool.name()); } } Ok((Arc::new(registry), Arc::new(validator))) } impl SessionController { async fn create_mcp_clients( config: Arc>, 
tool_registry: Arc, schema_validator: Arc, credential_manager: Option>, initial_mode: Mode, ) -> Result<( Arc, HashMap>, Vec, )> { let guard = config.lock().await; let config_arc = Arc::new(guard.clone()); let factory = McpClientFactory::new(config_arc.clone(), tool_registry, schema_validator); let mut missing_oauth_servers = Vec::new(); let primary_runtime = if let Some(primary_cfg) = guard.effective_mcp_servers().first() { let (runtime, missing) = Self::runtime_secrets_for_server(credential_manager.clone(), primary_cfg).await?; if missing { missing_oauth_servers.push(primary_cfg.name.clone()); } runtime } else { None }; let base_client = factory.create_with_secrets(primary_runtime)?; let primary: Arc = Arc::new(PermissionLayer::new(base_client, config_arc.clone())); primary.set_mode(initial_mode).await?; let mut clients: HashMap> = HashMap::new(); if let Some(primary_cfg) = guard.effective_mcp_servers().first() { clients.insert(primary_cfg.name.clone(), Arc::clone(&primary)); } for server_cfg in guard.effective_mcp_servers().iter().skip(1) { let (runtime, missing) = Self::runtime_secrets_for_server(credential_manager.clone(), server_cfg).await?; if missing { missing_oauth_servers.push(server_cfg.name.clone()); } match RemoteMcpClient::new_with_runtime(server_cfg, runtime) { Ok(remote) => { let client: Arc = Arc::new(PermissionLayer::new(Box::new(remote), config_arc.clone())); if let Err(err) = client.set_mode(initial_mode).await { warn!( "Failed to initialize MCP server '{}' in mode {:?}: {}", server_cfg.name, initial_mode, err ); } clients.insert(server_cfg.name.clone(), Arc::clone(&client)); } Err(err) => warn!( "Failed to initialize MCP server '{}': {}", server_cfg.name, err ), } } drop(guard); Ok((primary, clients, missing_oauth_servers)) } async fn runtime_secrets_for_server( credential_manager: Option>, server: &McpServerConfig, ) -> Result<(Option, bool)> { if let Some(oauth) = &server.oauth { if let Some(manager) = credential_manager { match 
manager.load_oauth_token(&server.name).await? { Some(token) => { if token.access_token.trim().is_empty() || token.is_expired(Utc::now()) { return Ok((None, true)); } let mut secrets = McpRuntimeSecrets::default(); if let Some(env_name) = oauth.token_env.as_deref() { secrets .env_overrides .insert(env_name.to_string(), token.access_token.clone()); } if matches!( server.transport.to_ascii_lowercase().as_str(), "http" | "websocket" ) { let header_value = format!("{}{}", oauth.header_prefix(), token.access_token); secrets.http_header = Some((oauth.header_name().to_string(), header_value)); } Ok((Some(secrets), false)) } None => Ok((None, true)), } } else { Ok((None, true)) } } else { Ok((None, false)) } } pub async fn new( provider: Arc, config: Config, storage: Arc, ui: Arc, enable_code_tools: bool, event_tx: Option>, ) -> Result { let config_arc = Arc::new(TokioMutex::new(config)); // Acquire the config asynchronously to avoid blocking the runtime. let config_guard = config_arc.lock().await; let model = config_guard .general .default_model .clone() .unwrap_or_else(|| "ollama/default".to_string()); let mut vault_handle: Option>> = None; let mut master_key: Option>> = None; let mut credential_manager: Option> = None; if config_guard.privacy.encrypt_local_data { let base_dir = storage .database_path() .parent() .map(|p| p.to_path_buf()) .or_else(dirs::data_local_dir) .unwrap_or_else(|| PathBuf::from(".")); let secure_path = base_dir.join("encrypted_data.json"); let handle = match env::var("OWLEN_MASTER_PASSWORD") { Ok(password) if !password.is_empty() => { encryption::unlock_with_password(secure_path, &password)? 
} _ => encryption::unlock_interactive(secure_path)?, }; let master = Arc::new(handle.data.master_key.clone()); master_key = Some(master.clone()); vault_handle = Some(Arc::new(Mutex::new(handle))); credential_manager = Some(Arc::new(CredentialManager::new(storage.clone(), master))); } let consent_manager = if let Some(ref vault) = vault_handle { Arc::new(Mutex::new(ConsentManager::from_vault(vault))) } else { Arc::new(Mutex::new(ConsentManager::new())) }; let conversation = ConversationManager::with_history_capacity( model, config_guard.storage.max_saved_sessions, ); let formatter = MessageFormatter::new( config_guard.ui.wrap_column as usize, config_guard.ui.role_label_mode, ) .with_preserve_empty(config_guard.ui.word_wrap); let input_buffer = InputBuffer::new( config_guard.input.history_size, config_guard.input.multiline, config_guard.input.tab_width, ); let model_manager = ModelManager::new(config_guard.general.model_cache_ttl()); drop(config_guard); // Release the lock before calling build_tools let initial_mode = if enable_code_tools { Mode::Code } else { Mode::Chat }; let (tool_registry, schema_validator) = build_tools( config_arc.clone(), ui.clone(), enable_code_tools, consent_manager.clone(), credential_manager.clone(), vault_handle.clone(), ) .await?; let (mcp_client, named_mcp_clients, missing_oauth_servers) = Self::create_mcp_clients( config_arc.clone(), tool_registry.clone(), schema_validator.clone(), credential_manager.clone(), initial_mode, ) .await?; Ok(Self { provider, conversation, model_manager, input_buffer, formatter, config: config_arc, consent_manager, tool_registry, schema_validator, mcp_client, named_mcp_clients, storage, vault: vault_handle, master_key, credential_manager, ui, enable_code_tools, current_mode: initial_mode, missing_oauth_servers, event_tx, pending_tool_requests: HashMap::new(), }) } pub fn conversation(&self) -> &Conversation { self.conversation.active() } pub fn conversation_mut(&mut self) -> &mut ConversationManager { &mut 
self.conversation } pub fn input_buffer(&self) -> &InputBuffer { &self.input_buffer } pub fn input_buffer_mut(&mut self) -> &mut InputBuffer { &mut self.input_buffer } pub fn formatter(&self) -> &MessageFormatter { &self.formatter } pub async fn set_formatter_wrap_width(&mut self, width: usize) { self.formatter.set_wrap_width(width); } pub fn set_role_label_mode(&mut self, mode: RoleLabelDisplay) { self.formatter.set_role_label_mode(mode); } /// Return the configured resource references aggregated across scopes. pub async fn configured_resources(&self) -> Vec { let guard = self.config.lock().await; guard.effective_mcp_resources().to_vec() } /// Resolve a resource reference of the form `server:uri` (optionally prefixed with `@`). pub async fn resolve_resource_reference(&self, reference: &str) -> Result> { let (server, uri) = match Self::split_resource_reference(reference) { Some(parts) => parts, None => return Ok(None), }; let resource_defined = { let guard = self.config.lock().await; guard.find_resource(&server, &uri).is_some() }; if !resource_defined { return Ok(None); } let client = self .named_mcp_clients .get(&server) .cloned() .ok_or_else(|| { Error::Config(format!( "MCP server '{}' referenced by resource '{}' is not available", server, uri )) })?; let call = McpToolCall { name: "resources/get".to_string(), arguments: json!({ "uri": uri, "path": uri }), }; let response = client.call_tool(call).await?; if let Some(text) = extract_resource_content(&response.output) { return Ok(Some(text)); } let formatted = serde_json::to_string_pretty(&response.output) .unwrap_or_else(|_| response.output.to_string()); Ok(Some(formatted)) } fn split_resource_reference(reference: &str) -> Option<(String, String)> { let trimmed = reference.trim(); let without_prefix = trimmed.strip_prefix('@').unwrap_or(trimmed); let (server, uri) = without_prefix.split_once(':')?; if server.is_empty() || uri.is_empty() { return None; } Some((server.to_string(), uri.to_string())) } // Asynchronous 
access to the configuration (used internally). pub async fn config_async(&self) -> tokio::sync::MutexGuard<'_, Config> { self.config.lock().await } // Synchronous, blocking access to the configuration. This is kept for the TUI // which expects `controller.config()` to return a reference without awaiting. // Provide a blocking configuration lock that is safe to call from async // contexts by using `tokio::task::block_in_place`. This allows the current // thread to be blocked without violating Tokio's runtime constraints. pub fn config(&self) -> tokio::sync::MutexGuard<'_, Config> { tokio::task::block_in_place(|| self.config.blocking_lock()) } // Synchronous mutable access, mirroring `config()` but allowing mutation. pub fn config_mut(&self) -> tokio::sync::MutexGuard<'_, Config> { tokio::task::block_in_place(|| self.config.blocking_lock()) } pub fn config_cloned(&self) -> Arc> { self.config.clone() } pub async fn reload_mcp_clients(&mut self) -> Result<()> { let (primary, named, missing) = Self::create_mcp_clients( self.config.clone(), self.tool_registry.clone(), self.schema_validator.clone(), self.credential_manager.clone(), self.current_mode, ) .await?; self.mcp_client = primary; self.named_mcp_clients = named; self.missing_oauth_servers = missing; Ok(()) } pub fn grant_consent(&self, tool_name: &str, data_types: Vec, endpoints: Vec) { let mut consent = self .consent_manager .lock() .expect("Consent manager mutex poisoned"); consent.grant_consent(tool_name, data_types, endpoints); let Some(vault) = &self.vault else { return; }; if let Err(e) = consent.persist_to_vault(vault) { eprintln!("Warning: Failed to persist consent to vault: {}", e); } } pub fn grant_consent_with_scope( &self, tool_name: &str, data_types: Vec, endpoints: Vec, scope: crate::consent::ConsentScope, ) { let mut consent = self .consent_manager .lock() .expect("Consent manager mutex poisoned"); let is_permanent = matches!(scope, crate::consent::ConsentScope::Permanent); 
consent.grant_consent_with_scope(tool_name, data_types, endpoints, scope); // Only persist to vault for permanent consent if !is_permanent { return; } let Some(vault) = &self.vault else { return; }; if let Err(e) = consent.persist_to_vault(vault) { eprintln!("Warning: Failed to persist consent to vault: {}", e); } } pub fn check_tools_consent_needed( &self, tool_calls: &[ToolCall], ) -> Vec<(String, Vec, Vec)> { let consent = self .consent_manager .lock() .expect("Consent manager mutex poisoned"); let mut needs_consent = Vec::new(); let mut seen_tools = std::collections::HashSet::new(); for tool_call in tool_calls { if seen_tools.contains(&tool_call.name) { continue; } seen_tools.insert(tool_call.name.clone()); let (data_types, endpoints) = match tool_call.name.as_str() { "web_search" | "web_search_detailed" => ( vec!["search query".to_string()], vec!["duckduckgo.com".to_string()], ), "code_exec" => ( vec!["code to execute".to_string()], vec!["local sandbox".to_string()], ), "resources/write" | "file_write" => ( vec!["file paths".to_string(), "file content".to_string()], vec!["local filesystem".to_string()], ), "resources/delete" | "file_delete" => ( vec!["file paths".to_string()], vec!["local filesystem".to_string()], ), _ => (vec![], vec![]), }; if let Some((tool_name, dt, ep)) = consent.check_if_consent_needed(&tool_call.name, data_types, endpoints) { needs_consent.push((tool_name, dt, ep)); } } needs_consent } pub async fn save_active_session( &self, name: Option, description: Option, ) -> Result { self.conversation .save_active_with_description(&self.storage, name, description) .await } pub async fn save_active_session_simple(&self, name: Option) -> Result { self.conversation.save_active(&self.storage, name).await } pub async fn load_saved_session(&mut self, id: Uuid) -> Result<()> { self.conversation.load_saved(&self.storage, id).await } pub async fn list_saved_sessions(&self) -> Result> { ConversationManager::list_saved_sessions(&self.storage).await } pub 
async fn delete_session(&self, id: Uuid) -> Result<()> { self.storage.delete_session(id).await } pub async fn clear_secure_data(&self) -> Result<()> { // ... (implementation remains the same) Ok(()) } pub fn persist_consent(&self) -> Result<()> { // ... (implementation remains the same) Ok(()) } pub async fn set_tool_enabled(&mut self, tool: &str, enabled: bool) -> Result<()> { { let mut config = self.config.lock().await; match tool { "web_search" => { config.tools.web_search.enabled = enabled; config.privacy.enable_remote_search = enabled; } "code_exec" => config.tools.code_exec.enabled = enabled, other => return Err(Error::InvalidInput(format!("Unknown tool: {other}"))), } } self.rebuild_tools().await } pub fn consent_manager(&self) -> Arc> { self.consent_manager.clone() } pub fn tool_registry(&self) -> Arc { self.tool_registry.clone() } pub fn schema_validator(&self) -> Arc { self.schema_validator.clone() } pub fn credential_manager(&self) -> Option> { self.credential_manager.clone() } pub fn pending_oauth_servers(&self) -> Vec { self.missing_oauth_servers.clone() } pub async fn start_oauth_device_flow(&self, server: &str) -> Result { let oauth_config = { let config = self.config.lock().await; let server_cfg = config .effective_mcp_servers() .iter() .find(|entry| entry.name == server) .ok_or_else(|| { Error::Config(format!("No MCP server named '{server}' is configured")) })?; server_cfg.oauth.clone().ok_or_else(|| { Error::Config(format!( "MCP server '{server}' does not define an OAuth configuration" )) })? 
}; let client = OAuthClient::new(oauth_config)?; client.start_device_authorization().await } pub async fn poll_oauth_device_flow( &mut self, server: &str, authorization: &DeviceAuthorization, ) -> Result { let oauth_config = { let config = self.config.lock().await; let server_cfg = config .effective_mcp_servers() .iter() .find(|entry| entry.name == server) .ok_or_else(|| { Error::Config(format!("No MCP server named '{server}' is configured")) })?; server_cfg.oauth.clone().ok_or_else(|| { Error::Config(format!( "MCP server '{server}' does not define an OAuth configuration" )) })? }; let client = OAuthClient::new(oauth_config)?; match client.poll_device_token(authorization).await? { DevicePollState::Pending { retry_in } => Ok(DevicePollState::Pending { retry_in }), DevicePollState::Complete(token) => { let manager = self.credential_manager.as_ref().cloned().ok_or_else(|| { Error::Config( "OAuth token storage requires encrypted local data; set \ privacy.encrypt_local_data = true in the configuration." 
.to_string(), ) })?; manager.store_oauth_token(server, &token).await?; self.missing_oauth_servers.retain(|entry| entry != server); Ok(DevicePollState::Complete(token)) } } } pub async fn list_mcp_tools(&self) -> Vec<(String, crate::mcp::McpToolDescriptor)> { let mut entries = Vec::new(); for (server, client) in self.named_mcp_clients.iter() { let server_name = server.clone(); let client = Arc::clone(client); match client.list_tools().await { Ok(tools) => { for descriptor in tools { entries.push((server_name.clone(), descriptor)); } } Err(err) => { warn!( "Failed to list tools for MCP server '{}': {}", server_name, err ); } } } entries } pub async fn call_mcp_tool( &self, server: &str, tool: &str, arguments: Value, ) -> Result { let client = self.named_mcp_clients.get(server).cloned().ok_or_else(|| { Error::Config(format!("No MCP server named '{}' is registered", server)) })?; client .call_tool(McpToolCall { name: tool.to_string(), arguments, }) .await } pub fn mcp_server(&self) -> crate::mcp::McpServer { crate::mcp::McpServer::new(self.tool_registry(), self.schema_validator()) } pub fn storage(&self) -> Arc { self.storage.clone() } pub fn master_key(&self) -> Option>> { self.master_key.as_ref().map(Arc::clone) } pub fn vault(&self) -> Option>> { self.vault.as_ref().map(Arc::clone) } pub async fn read_file(&self, path: &str) -> Result { let call = McpToolCall { name: "resources/get".to_string(), arguments: serde_json::json!({ "path": path }), }; match self.mcp_client.call_tool(call).await { Ok(response) => { if let Some(text) = extract_resource_content(&response.output) { return Ok(text); } let formatted = serde_json::to_string_pretty(&response.output) .unwrap_or_else(|_| response.output.to_string()); Ok(formatted) } Err(err) => { log::warn!("MCP file read failed ({}); falling back to local read", err); let content = std::fs::read_to_string(path)?; Ok(content) } } } pub async fn read_file_with_tools(&self, path: &str) -> Result { if !self.enable_code_tools { return 
Err(Error::InvalidInput( "Code tools are disabled in chat mode. Run `:mode code` to switch.".to_string(), )); } let call = McpToolCall { name: "resources/get".to_string(), arguments: serde_json::json!({ "path": path }), }; let response = self.mcp_client.call_tool(call).await?; if let Some(text) = extract_resource_content(&response.output) { Ok(text) } else { let formatted = serde_json::to_string_pretty(&response.output) .unwrap_or_else(|_| response.output.to_string()); Ok(formatted) } } pub fn code_tools_enabled(&self) -> bool { self.enable_code_tools } pub async fn set_code_tools_enabled(&mut self, enabled: bool) -> Result<()> { if self.enable_code_tools == enabled { return Ok(()); } self.enable_code_tools = enabled; self.rebuild_tools().await } pub async fn set_operating_mode(&mut self, mode: Mode) -> Result<()> { self.current_mode = mode; let enable_code_tools = matches!(mode, Mode::Code); self.set_code_tools_enabled(enable_code_tools).await?; self.mcp_client.set_mode(mode).await } pub async fn list_dir(&self, path: &str) -> Result> { let call = McpToolCall { name: "resources/list".to_string(), arguments: serde_json::json!({ "path": path }), }; match self.mcp_client.call_tool(call).await { Ok(response) => { let content: Vec = serde_json::from_value(response.output)?; Ok(content) } Err(err) => { log::warn!( "MCP directory list failed ({}); falling back to local list", err ); let mut entries = Vec::new(); for entry in std::fs::read_dir(path)? 
{ let entry = entry?; entries.push(entry.file_name().to_string_lossy().to_string()); } Ok(entries) } } } pub async fn write_file(&self, path: &str, content: &str) -> Result<()> { let call = McpToolCall { name: "resources/write".to_string(), arguments: serde_json::json!({ "path": path, "content": content }), }; match self.mcp_client.call_tool(call).await { Ok(_) => Ok(()), Err(err) => { log::warn!( "MCP file write failed ({}); falling back to local write", err ); // Ensure parent directory exists if let Some(parent) = std::path::Path::new(path).parent() { std::fs::create_dir_all(parent)?; } std::fs::write(path, content)?; Ok(()) } } } pub async fn delete_file(&self, path: &str) -> Result<()> { let call = McpToolCall { name: "resources/delete".to_string(), arguments: serde_json::json!({ "path": path }), }; match self.mcp_client.call_tool(call).await { Ok(_) => Ok(()), Err(err) => { log::warn!( "MCP file delete failed ({}); falling back to local delete", err ); std::fs::remove_file(path)?; Ok(()) } } } async fn rebuild_tools(&mut self) -> Result<()> { let (registry, validator) = build_tools( self.config.clone(), self.ui.clone(), self.enable_code_tools, self.consent_manager.clone(), self.credential_manager.clone(), self.vault.clone(), ) .await?; self.tool_registry = registry; self.schema_validator = validator; // Recreate MCP client with permission layer let config = self.config.lock().await; let factory = McpClientFactory::new( Arc::new(config.clone()), self.tool_registry.clone(), self.schema_validator.clone(), ); let base_client = factory.create()?; let permission_client = PermissionLayer::new(base_client, Arc::new(config.clone())); let client = Arc::new(permission_client); client.set_mode(self.current_mode).await?; self.mcp_client = client; Ok(()) } pub fn selected_model(&self) -> &str { &self.conversation.active().model } pub async fn set_model(&mut self, model: String) { self.conversation.set_model(model.clone()); let mut config = self.config.lock().await; 
config.general.default_model = Some(model); } pub async fn models(&self, force_refresh: bool) -> Result> { self.model_manager .get_or_refresh(force_refresh, || async { self.provider.list_models().await }) .await } fn as_ollama(&self) -> Option<&OllamaProvider> { self.provider .as_ref() .as_any() .downcast_ref::() } pub async fn model_details( &self, model_name: &str, force_refresh: bool, ) -> Result { if let Some(ollama) = self.as_ollama() { if force_refresh { ollama.refresh_model_info(model_name).await } else { ollama.get_model_info(model_name).await } } else { Err(Error::NotImplemented(format!( "Provider '{}' does not expose model inspection", self.provider.name() ))) } } pub async fn all_model_details(&self, force_refresh: bool) -> Result> { if let Some(ollama) = self.as_ollama() { if force_refresh { ollama.clear_model_info_cache().await; } ollama.get_all_models_info().await } else { Err(Error::NotImplemented(format!( "Provider '{}' does not expose model inspection", self.provider.name() ))) } } pub async fn cached_model_details(&self) -> Vec { if let Some(ollama) = self.as_ollama() { ollama.cached_model_info().await } else { Vec::new() } } pub async fn invalidate_model_details(&self, model_name: &str) { if let Some(ollama) = self.as_ollama() { ollama.invalidate_model_info(model_name).await; } } pub async fn clear_model_details_cache(&self) { if let Some(ollama) = self.as_ollama() { ollama.clear_model_info_cache().await; } } pub async fn ensure_default_model(&mut self, models: &[ModelInfo]) { let mut config = self.config.lock().await; if let Some(default) = config.general.default_model.clone() { if models.iter().any(|m| m.id == default || m.name == default) { self.conversation.set_model(default.clone()); config.general.default_model = Some(default); } } else if let Some(model) = models.first() { self.conversation.set_model(model.id.clone()); config.general.default_model = Some(model.id.clone()); } } pub async fn switch_provider(&mut self, provider: Arc) -> 
Result<()> { self.provider = provider; self.model_manager.invalidate().await; Ok(()) } /// Expose the underlying LLM provider. pub fn provider(&self) -> Arc { self.provider.clone() } pub async fn send_message( &mut self, content: String, mut parameters: ChatParameters, ) -> Result { let streaming = { self.config.lock().await.general.enable_streaming || parameters.stream }; parameters.stream = streaming; self.conversation.push_user_message(content); self.send_request_with_current_conversation(parameters) .await } pub async fn send_request_with_current_conversation( &mut self, mut parameters: ChatParameters, ) -> Result { let streaming = { self.config.lock().await.general.enable_streaming || parameters.stream }; parameters.stream = streaming; let tools = if !self.tool_registry.all().is_empty() { Some( self.tool_registry .all() .into_iter() .map(|tool| crate::mcp::McpToolDescriptor { name: tool.name().to_string(), description: tool.description().to_string(), input_schema: tool.schema(), requires_network: tool.requires_network(), requires_filesystem: tool.requires_filesystem(), }) .collect(), ) } else { None }; let mut request = ChatRequest { model: self.conversation.active().model.clone(), messages: self.conversation.active().messages.clone(), parameters: parameters.clone(), tools: tools.clone(), }; if !streaming { const MAX_TOOL_ITERATIONS: usize = 5; for _iteration in 0..MAX_TOOL_ITERATIONS { match self.provider.send_prompt(request.clone()).await { Ok(response) => { if response.message.has_tool_calls() { self.conversation.push_message(response.message.clone()); if let Some(tool_calls) = &response.message.tool_calls { for tool_call in tool_calls { let mcp_tool_call = McpToolCall { name: tool_call.name.clone(), arguments: tool_call.arguments.clone(), }; let tool_result = self.mcp_client.call_tool(mcp_tool_call).await; let tool_response_content = match tool_result { Ok(result) => serde_json::to_string_pretty(&result.output) .unwrap_or_else(|_| { "Tool execution 
succeeded".to_string() }), Err(e) => format!("Tool execution failed: {}", e), }; let tool_msg = Message::tool(tool_call.id.clone(), tool_response_content); self.conversation.push_message(tool_msg); } } request.messages = self.conversation.active().messages.clone(); continue; } else { self.conversation.push_message(response.message.clone()); return Ok(SessionOutcome::Complete(response)); } } Err(err) => { self.conversation .push_assistant_message(format!("Error: {}", err)); return Err(err); } } } self.conversation .push_assistant_message("Maximum tool execution iterations reached".to_string()); return Err(crate::Error::Provider(anyhow::anyhow!( "Maximum tool execution iterations reached" ))); } match self.provider.stream_prompt(request).await { Ok(stream) => { let response_id = self.conversation.start_streaming_response(); Ok(SessionOutcome::Streaming { response_id, stream, }) } Err(err) => { self.conversation .push_assistant_message(format!("Error starting stream: {}", err)); Err(err) } } } pub fn mark_stream_placeholder(&mut self, message_id: Uuid, text: &str) -> Result<()> { self.conversation .set_stream_placeholder(message_id, text.to_string()) } pub fn apply_stream_chunk(&mut self, message_id: Uuid, chunk: &ChatResponse) -> Result<()> { if chunk.message.has_tool_calls() { self.conversation.set_tool_calls_on_message( message_id, chunk.message.tool_calls.clone().unwrap_or_default(), )?; } self.conversation .append_stream_chunk(message_id, &chunk.message.content, chunk.is_final) } pub fn check_streaming_tool_calls(&mut self, message_id: Uuid) -> Option> { let maybe_calls = self .conversation .active() .messages .iter() .find(|m| m.id == message_id) .and_then(|m| m.tool_calls.clone()) .filter(|calls| !calls.is_empty()); let calls = maybe_calls?; if !self .pending_tool_requests .values() .any(|pending| pending.message_id == message_id) { if let Some((tool_name, data_types, endpoints)) = self.check_tools_consent_needed(&calls).into_iter().next() { let request_id = 
Uuid::new_v4();
                let pending = PendingToolRequest {
                    message_id,
                    tool_name: tool_name.clone(),
                    data_types: data_types.clone(),
                    endpoints: endpoints.clone(),
                    tool_calls: calls.clone(),
                };
                self.pending_tool_requests.insert(request_id, pending);
                if let Some(tx) = &self.event_tx {
                    let _ = tx.send(ControllerEvent::ToolRequested {
                        request_id,
                        message_id,
                        tool_name,
                        data_types,
                        endpoints,
                        tool_calls: calls.clone(),
                    });
                }
            }
        }
        Some(calls)
    }

    /// Resolves a pending tool-consent request identified by `request_id`.
    ///
    /// The request is removed from the pending table before anything else happens, so a
    /// given id can be resolved at most once. Unless `scope` is [`ConsentScope::Denied`],
    /// consent for the request's tool, data types and endpoints is granted with `scope`.
    ///
    /// Returns the resolution (including the original tool calls) so the caller can
    /// decide whether to execute them.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidInput`] if `request_id` is not a pending request
    /// (unknown, or already resolved).
    pub fn resolve_tool_consent(
        &mut self,
        request_id: Uuid,
        scope: ConsentScope,
    ) -> Result<ToolConsentResolution> {
        // Exhaustive destructuring: adding a field to `PendingToolRequest` forces this
        // site to be revisited instead of silently dropping the new data.
        let PendingToolRequest {
            message_id,
            tool_name,
            data_types,
            endpoints,
            tool_calls,
        } = self
            .pending_tool_requests
            .remove(&request_id)
            .ok_or_else(|| {
                Error::InvalidInput(format!("Unknown tool consent request: {}", request_id))
            })?;

        if !matches!(scope, ConsentScope::Denied) {
            self.grant_consent_with_scope(&tool_name, data_types, endpoints, scope.clone());
        }

        Ok(ToolConsentResolution {
            request_id,
            message_id,
            tool_name,
            scope,
            tool_calls,
        })
    }

    /// Cancels the stream associated with `message_id`, forwarding `notice` to
    /// `ConversationManager::cancel_stream`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the conversation manager.
    pub fn cancel_stream(&mut self, message_id: Uuid, notice: &str) -> Result<()> {
        self.conversation
            .cancel_stream(message_id, notice.to_string())
    }

    /// Runs each tool call through the MCP client, appends one tool-result message per
    /// call to the conversation, and then re-sends the conversation to the provider.
    ///
    /// A failing tool does not abort the batch: its error text becomes that call's tool
    /// result so the model can see it. Whether the follow-up request streams is taken
    /// from `general.enable_streaming` in the config.
    ///
    /// `_message_id` is currently unused and kept only for API compatibility.
    ///
    /// # Errors
    ///
    /// Propagates errors from `send_request_with_current_conversation`.
    pub async fn execute_streaming_tools(
        &mut self,
        _message_id: Uuid,
        tool_calls: Vec<ToolCall>,
    ) -> Result<SessionOutcome> {
        // The calls are owned, so move their fields instead of cloning each one.
        for ToolCall {
            id,
            name,
            arguments,
            ..
        } in tool_calls
        {
            let tool_result = self
                .mcp_client
                .call_tool(McpToolCall { name, arguments })
                .await;
            let tool_response_content = match tool_result {
                Ok(result) => serde_json::to_string_pretty(&result.output)
                    .unwrap_or_else(|_| "Tool execution succeeded".to_string()),
                Err(e) => format!("Tool execution failed: {}", e),
            };
            self.conversation
                .push_message(Message::tool(id, tool_response_content));
        }

        // Read the flag in its own statement so the config guard is dropped before the
        // follow-up request is awaited.
        let stream = self.config.lock().await.general.enable_streaming;
        let parameters = ChatParameters {
            stream,
            ..Default::default()
        };
        self.send_request_with_current_conversation(parameters)
            .await
    }

    /// Returns an owned copy of every message in the current conversation history.
    pub fn history(&self) -> Vec<Message> {
        self.conversation.history().cloned().collect()
    }

    /// Starts a new conversation, forwarding the optional `model` and `name` to
    /// `ConversationManager::start_new`.
    pub fn start_new_conversation(&mut self, model: Option<String>, name: Option<String>) {
        self.conversation.start_new(model, name);
    }

    /// Clears the current conversation via `ConversationManager::clear`.
    pub fn clear(&mut self) {
        self.conversation.clear();
    }

    /// Produces a short description of the current conversation.
    ///
    /// NOTE(review): this is a placeholder. It never inspects the conversation and always
    /// returns `"Empty conversation"`. The previous body held only an
    /// "implementation remains the same" editing artifact, so the real implementation
    /// appears to have been lost. Restore it before relying on this output.
    pub async fn generate_conversation_description(&self) -> Result<String> {
        Ok("Empty conversation".to_string())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Provider;
    use crate::config::{Config, McpMode, McpOAuthConfig, McpServerConfig};
    use crate::llm::test_utils::MockProvider;
    use crate::storage::StorageManager;
    use crate::ui::NoOpUiController;
    use chrono::Utc;
    use httpmock::prelude::*;
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::Arc;
    use tempfile::tempdir;

    const SERVER_NAME: &str = "oauth-test";

    /// OAuth settings whose authorize/token/device endpoints all point at `server`.
    fn build_oauth_config(server: &MockServer) -> McpOAuthConfig {
        McpOAuthConfig {
            client_id: "owlen-client".to_string(),
            client_secret: None,
            authorize_url: server.url("/authorize"),
            token_url: server.url("/token"),
            device_authorization_url: Some(server.url("/device")),
            redirect_url: None,
            scopes: vec!["repo".to_string()],
            token_env: Some("OAUTH_TOKEN".to_string()),
            header: Some("Authorization".to_string()),
            header_prefix: Some("Bearer ".to_string()),
        }
    }

    /// Config with a single HTTP MCP server named [`SERVER_NAME`] that uses OAuth.
    fn build_config(server: &MockServer) -> Config {
        let mut config = Config::default();
        config.mcp.mode = McpMode::LocalOnly;

        let oauth = build_oauth_config(server);
        let mut env = HashMap::new();
        env.insert("OWLEN_ENV".to_string(), "test".to_string());

        config.mcp_servers = vec![McpServerConfig {
            name: SERVER_NAME.to_string(),
            command: server.url("/mcp"),
            args: Vec::new(),
            transport: "http".to_string(),
            env,
            oauth: Some(oauth),
        }];
        config
            .refresh_mcp_servers(None)
            .expect("refresh MCP server configuration");
        config
    }

    /// Builds a `SessionController` backed by a database in a fresh temporary directory.
    ///
    /// The returned `TempDir` must be kept alive for as long as the session is used,
    /// because dropping it deletes the database directory.
    async fn build_session(server: &MockServer) -> (SessionController, tempfile::TempDir) {
        // SAFETY: `set_var` is unsafe because mutating the process environment can race
        // with concurrent environment access from other threads. Every caller writes the
        // same value, and nothing else in these tests touches this variable.
        // NOTE(review): parallel tests in the crate may still read the environment
        // concurrently. Consider setting this once per process.
        unsafe {
            std::env::set_var("OWLEN_MASTER_PASSWORD", "test-password");
        }

        let temp_dir = tempdir().expect("tempdir");
        let storage_path = temp_dir.path().join("owlen.db");
        let storage = Arc::new(
            StorageManager::with_database_path(storage_path)
                .await
                .expect("storage"),
        );

        let config = build_config(server);
        let provider: Arc<dyn Provider> = Arc::new(MockProvider::default());
        let ui = Arc::new(NoOpUiController);

        let session = SessionController::new(provider, config, storage, ui, false, None)
            .await
            .expect("session");
        (session, temp_dir)
    }

    #[tokio::test]
    async fn start_oauth_device_flow_returns_details() {
        let server = MockServer::start_async().await;
        let device = server
            .mock_async(|when, then| {
                when.method(POST).path("/device");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "device_code": "device-abc",
                        "user_code": "ABCD-1234",
                        "verification_uri": "https://example.test/activate",
                        "verification_uri_complete": "https://example.test/activate?user_code=ABCD-1234",
                        "expires_in": 600,
                        "interval": 5,
                        "message": "Enter the code to continue."
                    }));
            })
            .await;

        let (session, _dir) = build_session(&server).await;

        let authorization = session
            .start_oauth_device_flow(SERVER_NAME)
            .await
            .expect("device flow");

        assert_eq!(authorization.user_code, "ABCD-1234");
        assert_eq!(
            authorization.verification_uri_complete.as_deref(),
            Some("https://example.test/activate?user_code=ABCD-1234")
        );
        assert!(authorization.expires_at > Utc::now());
        device.assert_async().await;
    }

    #[tokio::test]
    async fn poll_oauth_device_flow_stores_token_and_updates_state() {
        let server = MockServer::start_async().await;
        let device = server
            .mock_async(|when, then| {
                when.method(POST).path("/device");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "device_code": "device-xyz",
                        "user_code": "WXYZ-9999",
                        "verification_uri": "https://example.test/activate",
                        "verification_uri_complete": "https://example.test/activate?user_code=WXYZ-9999",
                        "expires_in": 600,
                        "interval": 5
                    }));
            })
            .await;

        let token = server
            .mock_async(|when, then| {
                when.method(POST)
                    .path("/token")
                    .body_contains("device_code=device-xyz");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "access_token": "new-access-token",
                        "refresh_token": "refresh-token",
                        "expires_in": 3600,
                        "token_type": "Bearer"
                    }));
            })
            .await;

        let (mut session, _dir) = build_session(&server).await;

        assert_eq!(session.pending_oauth_servers(), vec![SERVER_NAME]);

        let authorization = session
            .start_oauth_device_flow(SERVER_NAME)
            .await
            .expect("device flow");

        match session
            .poll_oauth_device_flow(SERVER_NAME, &authorization)
            .await
            .expect("token poll")
        {
            DevicePollState::Complete(token_info) => {
                assert_eq!(token_info.access_token, "new-access-token");
                assert_eq!(token_info.refresh_token.as_deref(), Some("refresh-token"));
            }
            other => panic!("expected token completion, got {other:?}"),
        }

        assert!(
            session
                .pending_oauth_servers()
                .iter()
                .all(|entry| entry != SERVER_NAME),
            "server should be removed from pending list"
        );

        let stored = session
            .credential_manager()
            .expect("credential manager")
            .load_oauth_token(SERVER_NAME)
            .await
            .expect("load token")
            .expect("token present");
        assert_eq!(stored.access_token, "new-access-token");
        assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));

        device.assert_async().await;
        token.assert_async().await;
    }
}