Introduce `McpCommand` enum and handlers in `owlen-cli` to manage MCP server registrations, including adding, listing, and removing servers across configuration scopes. Add scoped configuration support (`ScopedMcpServer`, `McpConfigScope`) and OAuth token handling in core config, alongside runtime refresh of MCP servers. Implement toast notifications in the TUI (`render_toasts`, `Toast`, `ToastLevel`) and integrate async handling for session events. Update config loading, validation, and schema versioning to accommodate new MCP scopes and resources. Add `httpmock` as a dev dependency for testing.
1470 lines
51 KiB
Rust
1470 lines
51 KiB
Rust
use crate::config::{Config, McpResourceConfig, McpServerConfig};
|
|
use crate::consent::ConsentManager;
|
|
use crate::conversation::ConversationManager;
|
|
use crate::credentials::CredentialManager;
|
|
use crate::encryption::{self, VaultHandle};
|
|
use crate::formatting::MessageFormatter;
|
|
use crate::input::InputBuffer;
|
|
use crate::mcp::McpToolCall;
|
|
use crate::mcp::client::McpClient;
|
|
use crate::mcp::factory::McpClientFactory;
|
|
use crate::mcp::permission::PermissionLayer;
|
|
use crate::mcp::remote_client::{McpRuntimeSecrets, RemoteMcpClient};
|
|
use crate::mode::Mode;
|
|
use crate::model::{DetailedModelInfo, ModelManager};
|
|
use crate::oauth::{DeviceAuthorization, DevicePollState, OAuthClient};
|
|
use crate::providers::OllamaProvider;
|
|
use crate::storage::{SessionMeta, StorageManager};
|
|
use crate::types::{
|
|
ChatParameters, ChatRequest, ChatResponse, Conversation, Message, ModelInfo, ToolCall,
|
|
};
|
|
use crate::ui::{RoleLabelDisplay, UiController};
|
|
use crate::validation::{SchemaValidator, get_builtin_schemas};
|
|
use crate::{ChatStream, Provider};
|
|
use crate::{
|
|
CodeExecTool, ResourcesDeleteTool, ResourcesGetTool, ResourcesListTool, ResourcesWriteTool,
|
|
ToolRegistry, WebScrapeTool, WebSearchDetailedTool, WebSearchTool,
|
|
};
|
|
use crate::{Error, Result};
|
|
use chrono::Utc;
|
|
use log::warn;
|
|
use serde_json::{Value, json};
|
|
use std::collections::HashMap;
|
|
use std::env;
|
|
use std::path::PathBuf;
|
|
use std::sync::{Arc, Mutex};
|
|
use tokio::sync::Mutex as TokioMutex;
|
|
use uuid::Uuid;
|
|
|
|
pub enum SessionOutcome {
|
|
Complete(ChatResponse),
|
|
Streaming {
|
|
response_id: Uuid,
|
|
stream: ChatStream,
|
|
},
|
|
}
|
|
|
|
fn extract_resource_content(value: &Value) -> Option<String> {
|
|
match value {
|
|
Value::Null => Some(String::new()),
|
|
Value::Bool(flag) => Some(flag.to_string()),
|
|
Value::Number(num) => Some(num.to_string()),
|
|
Value::String(text) => Some(text.clone()),
|
|
Value::Array(items) => {
|
|
let mut segments = Vec::new();
|
|
for item in items {
|
|
if let Some(segment) = extract_resource_content(item)
|
|
&& !segment.is_empty()
|
|
{
|
|
segments.push(segment);
|
|
}
|
|
}
|
|
if segments.is_empty() {
|
|
None
|
|
} else {
|
|
Some(segments.join("\n"))
|
|
}
|
|
}
|
|
Value::Object(map) => {
|
|
const PREFERRED_FIELDS: [&str; 6] =
|
|
["content", "contents", "text", "value", "body", "data"];
|
|
for key in PREFERRED_FIELDS.iter() {
|
|
if let Some(inner) = map.get(*key)
|
|
&& let Some(text) = extract_resource_content(inner)
|
|
&& !text.is_empty()
|
|
{
|
|
return Some(text);
|
|
}
|
|
}
|
|
|
|
if let Some(inner) = map.get("chunks")
|
|
&& let Some(text) = extract_resource_content(inner)
|
|
&& !text.is_empty()
|
|
{
|
|
return Some(text);
|
|
}
|
|
|
|
None
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct SessionController {
|
|
provider: Arc<dyn Provider>,
|
|
conversation: ConversationManager,
|
|
model_manager: ModelManager,
|
|
input_buffer: InputBuffer,
|
|
formatter: MessageFormatter,
|
|
config: Arc<TokioMutex<Config>>,
|
|
consent_manager: Arc<Mutex<ConsentManager>>,
|
|
tool_registry: Arc<ToolRegistry>,
|
|
schema_validator: Arc<SchemaValidator>,
|
|
mcp_client: Arc<dyn McpClient>,
|
|
named_mcp_clients: HashMap<String, Arc<dyn McpClient>>,
|
|
storage: Arc<StorageManager>,
|
|
vault: Option<Arc<Mutex<VaultHandle>>>,
|
|
master_key: Option<Arc<Vec<u8>>>,
|
|
credential_manager: Option<Arc<CredentialManager>>,
|
|
ui: Arc<dyn UiController>,
|
|
enable_code_tools: bool,
|
|
current_mode: Mode,
|
|
missing_oauth_servers: Vec<String>,
|
|
}
|
|
|
|
async fn build_tools(
|
|
config: Arc<TokioMutex<Config>>,
|
|
ui: Arc<dyn UiController>,
|
|
enable_code_tools: bool,
|
|
consent_manager: Arc<Mutex<ConsentManager>>,
|
|
credential_manager: Option<Arc<CredentialManager>>,
|
|
vault: Option<Arc<Mutex<VaultHandle>>>,
|
|
) -> Result<(Arc<ToolRegistry>, Arc<SchemaValidator>)> {
|
|
let mut registry = ToolRegistry::new(config.clone(), ui);
|
|
let mut validator = SchemaValidator::new();
|
|
// Acquire config asynchronously to avoid blocking the async runtime.
|
|
let config_guard = config.lock().await;
|
|
|
|
for (name, schema) in get_builtin_schemas() {
|
|
if let Err(err) = validator.register_schema(&name, schema) {
|
|
warn!("Failed to register built-in schema {name}: {err}");
|
|
}
|
|
}
|
|
|
|
if config_guard
|
|
.security
|
|
.allowed_tools
|
|
.iter()
|
|
.any(|tool| tool == "web_search")
|
|
&& config_guard.tools.web_search.enabled
|
|
&& config_guard.privacy.enable_remote_search
|
|
{
|
|
let tool = WebSearchTool::new(
|
|
consent_manager.clone(),
|
|
credential_manager.clone(),
|
|
vault.clone(),
|
|
);
|
|
registry.register(tool);
|
|
}
|
|
|
|
// Register web_scrape tool if allowed.
|
|
if config_guard
|
|
.security
|
|
.allowed_tools
|
|
.iter()
|
|
.any(|tool| tool == "web_scrape")
|
|
&& config_guard.tools.web_search.enabled // reuse web_search toggle for simplicity
|
|
&& config_guard.privacy.enable_remote_search
|
|
{
|
|
let tool = WebScrapeTool::new();
|
|
registry.register(tool);
|
|
}
|
|
|
|
if config_guard
|
|
.security
|
|
.allowed_tools
|
|
.iter()
|
|
.any(|tool| tool == "web_search")
|
|
&& config_guard.tools.web_search.enabled
|
|
&& config_guard.privacy.enable_remote_search
|
|
{
|
|
let tool = WebSearchDetailedTool::new(
|
|
consent_manager.clone(),
|
|
credential_manager.clone(),
|
|
vault.clone(),
|
|
);
|
|
registry.register(tool);
|
|
}
|
|
|
|
if enable_code_tools
|
|
&& config_guard
|
|
.security
|
|
.allowed_tools
|
|
.iter()
|
|
.any(|tool| tool == "code_exec")
|
|
&& config_guard.tools.code_exec.enabled
|
|
{
|
|
let tool = CodeExecTool::new(config_guard.tools.code_exec.allowed_languages.clone());
|
|
registry.register(tool);
|
|
}
|
|
|
|
registry.register(ResourcesListTool);
|
|
registry.register(ResourcesGetTool);
|
|
|
|
if config_guard
|
|
.security
|
|
.allowed_tools
|
|
.iter()
|
|
.any(|t| t == "file_write")
|
|
{
|
|
registry.register(ResourcesWriteTool);
|
|
}
|
|
if config_guard
|
|
.security
|
|
.allowed_tools
|
|
.iter()
|
|
.any(|t| t == "file_delete")
|
|
{
|
|
registry.register(ResourcesDeleteTool);
|
|
}
|
|
|
|
for tool in registry.all() {
|
|
if let Err(err) = validator.register_schema(tool.name(), tool.schema()) {
|
|
warn!("Failed to register schema for {}: {err}", tool.name());
|
|
}
|
|
}
|
|
|
|
Ok((Arc::new(registry), Arc::new(validator)))
|
|
}
|
|
|
|
impl SessionController {
|
|
async fn create_mcp_clients(
|
|
config: Arc<TokioMutex<Config>>,
|
|
tool_registry: Arc<ToolRegistry>,
|
|
schema_validator: Arc<SchemaValidator>,
|
|
credential_manager: Option<Arc<CredentialManager>>,
|
|
initial_mode: Mode,
|
|
) -> Result<(
|
|
Arc<dyn McpClient>,
|
|
HashMap<String, Arc<dyn McpClient>>,
|
|
Vec<String>,
|
|
)> {
|
|
let guard = config.lock().await;
|
|
let config_arc = Arc::new(guard.clone());
|
|
let factory = McpClientFactory::new(config_arc.clone(), tool_registry, schema_validator);
|
|
|
|
let mut missing_oauth_servers = Vec::new();
|
|
let primary_runtime = if let Some(primary_cfg) = guard.effective_mcp_servers().first() {
|
|
let (runtime, missing) =
|
|
Self::runtime_secrets_for_server(credential_manager.clone(), primary_cfg).await?;
|
|
if missing {
|
|
missing_oauth_servers.push(primary_cfg.name.clone());
|
|
}
|
|
runtime
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let base_client = factory.create_with_secrets(primary_runtime)?;
|
|
let primary: Arc<dyn McpClient> =
|
|
Arc::new(PermissionLayer::new(base_client, config_arc.clone()));
|
|
primary.set_mode(initial_mode).await?;
|
|
|
|
let mut clients: HashMap<String, Arc<dyn McpClient>> = HashMap::new();
|
|
if let Some(primary_cfg) = guard.effective_mcp_servers().first() {
|
|
clients.insert(primary_cfg.name.clone(), Arc::clone(&primary));
|
|
}
|
|
|
|
for server_cfg in guard.effective_mcp_servers().iter().skip(1) {
|
|
let (runtime, missing) =
|
|
Self::runtime_secrets_for_server(credential_manager.clone(), server_cfg).await?;
|
|
if missing {
|
|
missing_oauth_servers.push(server_cfg.name.clone());
|
|
}
|
|
|
|
match RemoteMcpClient::new_with_runtime(server_cfg, runtime) {
|
|
Ok(remote) => {
|
|
let client: Arc<dyn McpClient> =
|
|
Arc::new(PermissionLayer::new(Box::new(remote), config_arc.clone()));
|
|
if let Err(err) = client.set_mode(initial_mode).await {
|
|
warn!(
|
|
"Failed to initialize MCP server '{}' in mode {:?}: {}",
|
|
server_cfg.name, initial_mode, err
|
|
);
|
|
}
|
|
clients.insert(server_cfg.name.clone(), Arc::clone(&client));
|
|
}
|
|
Err(err) => warn!(
|
|
"Failed to initialize MCP server '{}': {}",
|
|
server_cfg.name, err
|
|
),
|
|
}
|
|
}
|
|
|
|
drop(guard);
|
|
|
|
Ok((primary, clients, missing_oauth_servers))
|
|
}
|
|
|
|
async fn runtime_secrets_for_server(
|
|
credential_manager: Option<Arc<CredentialManager>>,
|
|
server: &McpServerConfig,
|
|
) -> Result<(Option<McpRuntimeSecrets>, bool)> {
|
|
if let Some(oauth) = &server.oauth {
|
|
if let Some(manager) = credential_manager {
|
|
match manager.load_oauth_token(&server.name).await? {
|
|
Some(token) => {
|
|
if token.access_token.trim().is_empty() || token.is_expired(Utc::now()) {
|
|
return Ok((None, true));
|
|
}
|
|
let mut secrets = McpRuntimeSecrets::default();
|
|
if let Some(env_name) = oauth.token_env.as_deref() {
|
|
secrets
|
|
.env_overrides
|
|
.insert(env_name.to_string(), token.access_token.clone());
|
|
}
|
|
if matches!(
|
|
server.transport.to_ascii_lowercase().as_str(),
|
|
"http" | "websocket"
|
|
) {
|
|
let header_value =
|
|
format!("{}{}", oauth.header_prefix(), token.access_token);
|
|
secrets.http_header =
|
|
Some((oauth.header_name().to_string(), header_value));
|
|
}
|
|
Ok((Some(secrets), false))
|
|
}
|
|
None => Ok((None, true)),
|
|
}
|
|
} else {
|
|
Ok((None, true))
|
|
}
|
|
} else {
|
|
Ok((None, false))
|
|
}
|
|
}
|
|
|
|
pub async fn new(
|
|
provider: Arc<dyn Provider>,
|
|
config: Config,
|
|
storage: Arc<StorageManager>,
|
|
ui: Arc<dyn UiController>,
|
|
enable_code_tools: bool,
|
|
) -> Result<Self> {
|
|
let config_arc = Arc::new(TokioMutex::new(config));
|
|
// Acquire the config asynchronously to avoid blocking the runtime.
|
|
let config_guard = config_arc.lock().await;
|
|
|
|
let model = config_guard
|
|
.general
|
|
.default_model
|
|
.clone()
|
|
.unwrap_or_else(|| "ollama/default".to_string());
|
|
|
|
let mut vault_handle: Option<Arc<Mutex<VaultHandle>>> = None;
|
|
let mut master_key: Option<Arc<Vec<u8>>> = None;
|
|
let mut credential_manager: Option<Arc<CredentialManager>> = None;
|
|
|
|
if config_guard.privacy.encrypt_local_data {
|
|
let base_dir = storage
|
|
.database_path()
|
|
.parent()
|
|
.map(|p| p.to_path_buf())
|
|
.or_else(dirs::data_local_dir)
|
|
.unwrap_or_else(|| PathBuf::from("."));
|
|
let secure_path = base_dir.join("encrypted_data.json");
|
|
let handle = match env::var("OWLEN_MASTER_PASSWORD") {
|
|
Ok(password) if !password.is_empty() => {
|
|
encryption::unlock_with_password(secure_path, &password)?
|
|
}
|
|
_ => encryption::unlock_interactive(secure_path)?,
|
|
};
|
|
let master = Arc::new(handle.data.master_key.clone());
|
|
master_key = Some(master.clone());
|
|
vault_handle = Some(Arc::new(Mutex::new(handle)));
|
|
credential_manager = Some(Arc::new(CredentialManager::new(storage.clone(), master)));
|
|
}
|
|
|
|
let consent_manager = if let Some(ref vault) = vault_handle {
|
|
Arc::new(Mutex::new(ConsentManager::from_vault(vault)))
|
|
} else {
|
|
Arc::new(Mutex::new(ConsentManager::new()))
|
|
};
|
|
|
|
let conversation = ConversationManager::with_history_capacity(
|
|
model,
|
|
config_guard.storage.max_saved_sessions,
|
|
);
|
|
let formatter = MessageFormatter::new(
|
|
config_guard.ui.wrap_column as usize,
|
|
config_guard.ui.role_label_mode,
|
|
)
|
|
.with_preserve_empty(config_guard.ui.word_wrap);
|
|
let input_buffer = InputBuffer::new(
|
|
config_guard.input.history_size,
|
|
config_guard.input.multiline,
|
|
config_guard.input.tab_width,
|
|
);
|
|
let model_manager = ModelManager::new(config_guard.general.model_cache_ttl());
|
|
|
|
drop(config_guard); // Release the lock before calling build_tools
|
|
|
|
let initial_mode = if enable_code_tools {
|
|
Mode::Code
|
|
} else {
|
|
Mode::Chat
|
|
};
|
|
|
|
let (tool_registry, schema_validator) = build_tools(
|
|
config_arc.clone(),
|
|
ui.clone(),
|
|
enable_code_tools,
|
|
consent_manager.clone(),
|
|
credential_manager.clone(),
|
|
vault_handle.clone(),
|
|
)
|
|
.await?;
|
|
|
|
let (mcp_client, named_mcp_clients, missing_oauth_servers) = Self::create_mcp_clients(
|
|
config_arc.clone(),
|
|
tool_registry.clone(),
|
|
schema_validator.clone(),
|
|
credential_manager.clone(),
|
|
initial_mode,
|
|
)
|
|
.await?;
|
|
|
|
Ok(Self {
|
|
provider,
|
|
conversation,
|
|
model_manager,
|
|
input_buffer,
|
|
formatter,
|
|
config: config_arc,
|
|
consent_manager,
|
|
tool_registry,
|
|
schema_validator,
|
|
mcp_client,
|
|
named_mcp_clients,
|
|
storage,
|
|
vault: vault_handle,
|
|
master_key,
|
|
credential_manager,
|
|
ui,
|
|
enable_code_tools,
|
|
current_mode: initial_mode,
|
|
missing_oauth_servers,
|
|
})
|
|
}
|
|
|
|
/// Borrow the active conversation.
pub fn conversation(&self) -> &Conversation {
    let manager = &self.conversation;
    manager.active()
}
|
|
|
|
/// Mutably borrow the conversation manager.
pub fn conversation_mut(&mut self) -> &mut ConversationManager {
    let manager = &mut self.conversation;
    manager
}
|
|
|
|
/// Borrow the input buffer.
pub fn input_buffer(&self) -> &InputBuffer {
    let buffer = &self.input_buffer;
    buffer
}
|
|
|
|
/// Mutably borrow the input buffer.
pub fn input_buffer_mut(&mut self) -> &mut InputBuffer {
    let buffer = &mut self.input_buffer;
    buffer
}
|
|
|
|
/// Borrow the message formatter.
pub fn formatter(&self) -> &MessageFormatter {
    let formatter = &self.formatter;
    formatter
}
|
|
|
|
pub async fn set_formatter_wrap_width(&mut self, width: usize) {
|
|
self.formatter.set_wrap_width(width);
|
|
}
|
|
|
|
/// Set how the formatter displays role labels.
pub fn set_role_label_mode(&mut self, mode: RoleLabelDisplay) {
    let formatter = &mut self.formatter;
    formatter.set_role_label_mode(mode);
}
|
|
|
|
/// Return the configured resource references aggregated across scopes.
|
|
pub async fn configured_resources(&self) -> Vec<McpResourceConfig> {
|
|
let guard = self.config.lock().await;
|
|
guard.effective_mcp_resources().to_vec()
|
|
}
|
|
|
|
/// Resolve a resource reference of the form `server:uri` (optionally prefixed with `@`).
|
|
pub async fn resolve_resource_reference(&self, reference: &str) -> Result<Option<String>> {
|
|
let (server, uri) = match Self::split_resource_reference(reference) {
|
|
Some(parts) => parts,
|
|
None => return Ok(None),
|
|
};
|
|
|
|
let resource_defined = {
|
|
let guard = self.config.lock().await;
|
|
guard.find_resource(&server, &uri).is_some()
|
|
};
|
|
|
|
if !resource_defined {
|
|
return Ok(None);
|
|
}
|
|
|
|
let client = self
|
|
.named_mcp_clients
|
|
.get(&server)
|
|
.cloned()
|
|
.ok_or_else(|| {
|
|
Error::Config(format!(
|
|
"MCP server '{}' referenced by resource '{}' is not available",
|
|
server, uri
|
|
))
|
|
})?;
|
|
|
|
let call = McpToolCall {
|
|
name: "resources/get".to_string(),
|
|
arguments: json!({ "uri": uri, "path": uri }),
|
|
};
|
|
let response = client.call_tool(call).await?;
|
|
if let Some(text) = extract_resource_content(&response.output) {
|
|
return Ok(Some(text));
|
|
}
|
|
|
|
let formatted = serde_json::to_string_pretty(&response.output)
|
|
.unwrap_or_else(|_| response.output.to_string());
|
|
Ok(Some(formatted))
|
|
}
|
|
|
|
/// Split `server:uri` (with an optional leading `@`) into its two parts.
///
/// Surrounding whitespace is trimmed and at most one `@` is stripped. The split happens
/// at the first `:`; `None` is returned when there is no `:` or either side is empty.
fn split_resource_reference(reference: &str) -> Option<(String, String)> {
    let body = reference.trim();
    let body = body.strip_prefix('@').unwrap_or(body);
    match body.split_once(':') {
        Some((server, uri)) if !server.is_empty() && !uri.is_empty() => {
            Some((server.to_owned(), uri.to_owned()))
        }
        _ => None,
    }
}
|
|
|
|
// Asynchronous access to the configuration (used internally).
|
|
pub async fn config_async(&self) -> tokio::sync::MutexGuard<'_, Config> {
|
|
self.config.lock().await
|
|
}
|
|
|
|
// Synchronous, blocking access to the configuration. This is kept for the TUI
|
|
// which expects `controller.config()` to return a reference without awaiting.
|
|
// Provide a blocking configuration lock that is safe to call from async
|
|
// contexts by using `tokio::task::block_in_place`. This allows the current
|
|
// thread to be blocked without violating Tokio's runtime constraints.
|
|
pub fn config(&self) -> tokio::sync::MutexGuard<'_, Config> {
|
|
tokio::task::block_in_place(|| self.config.blocking_lock())
|
|
}
|
|
|
|
// Synchronous mutable access, mirroring `config()` but allowing mutation.
|
|
pub fn config_mut(&self) -> tokio::sync::MutexGuard<'_, Config> {
|
|
tokio::task::block_in_place(|| self.config.blocking_lock())
|
|
}
|
|
|
|
/// Return another handle to the shared configuration mutex.
pub fn config_cloned(&self) -> Arc<TokioMutex<Config>> {
    Arc::clone(&self.config)
}
|
|
|
|
pub async fn reload_mcp_clients(&mut self) -> Result<()> {
|
|
let (primary, named, missing) = Self::create_mcp_clients(
|
|
self.config.clone(),
|
|
self.tool_registry.clone(),
|
|
self.schema_validator.clone(),
|
|
self.credential_manager.clone(),
|
|
self.current_mode,
|
|
)
|
|
.await?;
|
|
self.mcp_client = primary;
|
|
self.named_mcp_clients = named;
|
|
self.missing_oauth_servers = missing;
|
|
Ok(())
|
|
}
|
|
|
|
pub fn grant_consent(&self, tool_name: &str, data_types: Vec<String>, endpoints: Vec<String>) {
|
|
let mut consent = self
|
|
.consent_manager
|
|
.lock()
|
|
.expect("Consent manager mutex poisoned");
|
|
consent.grant_consent(tool_name, data_types, endpoints);
|
|
|
|
if let Some(vault) = &self.vault
|
|
&& let Err(e) = consent.persist_to_vault(vault)
|
|
{
|
|
eprintln!("Warning: Failed to persist consent to vault: {}", e);
|
|
}
|
|
}
|
|
|
|
pub fn grant_consent_with_scope(
|
|
&self,
|
|
tool_name: &str,
|
|
data_types: Vec<String>,
|
|
endpoints: Vec<String>,
|
|
scope: crate::consent::ConsentScope,
|
|
) {
|
|
let mut consent = self
|
|
.consent_manager
|
|
.lock()
|
|
.expect("Consent manager mutex poisoned");
|
|
let is_permanent = matches!(scope, crate::consent::ConsentScope::Permanent);
|
|
consent.grant_consent_with_scope(tool_name, data_types, endpoints, scope);
|
|
|
|
// Only persist to vault for permanent consent
|
|
if is_permanent
|
|
&& let Some(vault) = &self.vault
|
|
&& let Err(e) = consent.persist_to_vault(vault)
|
|
{
|
|
eprintln!("Warning: Failed to persist consent to vault: {}", e);
|
|
}
|
|
}
|
|
|
|
pub fn check_tools_consent_needed(
|
|
&self,
|
|
tool_calls: &[ToolCall],
|
|
) -> Vec<(String, Vec<String>, Vec<String>)> {
|
|
let consent = self
|
|
.consent_manager
|
|
.lock()
|
|
.expect("Consent manager mutex poisoned");
|
|
let mut needs_consent = Vec::new();
|
|
let mut seen_tools = std::collections::HashSet::new();
|
|
|
|
for tool_call in tool_calls {
|
|
if seen_tools.contains(&tool_call.name) {
|
|
continue;
|
|
}
|
|
seen_tools.insert(tool_call.name.clone());
|
|
|
|
let (data_types, endpoints) = match tool_call.name.as_str() {
|
|
"web_search" | "web_search_detailed" => (
|
|
vec!["search query".to_string()],
|
|
vec!["duckduckgo.com".to_string()],
|
|
),
|
|
"code_exec" => (
|
|
vec!["code to execute".to_string()],
|
|
vec!["local sandbox".to_string()],
|
|
),
|
|
"resources/write" | "file_write" => (
|
|
vec!["file paths".to_string(), "file content".to_string()],
|
|
vec!["local filesystem".to_string()],
|
|
),
|
|
"resources/delete" | "file_delete" => (
|
|
vec!["file paths".to_string()],
|
|
vec!["local filesystem".to_string()],
|
|
),
|
|
_ => (vec![], vec![]),
|
|
};
|
|
|
|
if let Some((tool_name, dt, ep)) =
|
|
consent.check_if_consent_needed(&tool_call.name, data_types, endpoints)
|
|
{
|
|
needs_consent.push((tool_name, dt, ep));
|
|
}
|
|
}
|
|
|
|
needs_consent
|
|
}
|
|
|
|
pub async fn save_active_session(
|
|
&self,
|
|
name: Option<String>,
|
|
description: Option<String>,
|
|
) -> Result<Uuid> {
|
|
self.conversation
|
|
.save_active_with_description(&self.storage, name, description)
|
|
.await
|
|
}
|
|
|
|
/// Save the active conversation to storage with an optional name only.
pub async fn save_active_session_simple(&self, name: Option<String>) -> Result<Uuid> {
    let storage = &self.storage;
    self.conversation.save_active(storage, name).await
}
|
|
|
|
/// Load the saved session `id` from storage into the conversation manager.
pub async fn load_saved_session(&mut self, id: Uuid) -> Result<()> {
    let storage = &self.storage;
    self.conversation.load_saved(storage, id).await
}
|
|
|
|
pub async fn list_saved_sessions(&self) -> Result<Vec<SessionMeta>> {
|
|
ConversationManager::list_saved_sessions(&self.storage).await
|
|
}
|
|
|
|
/// Delete the saved session `id` from storage.
pub async fn delete_session(&self, id: Uuid) -> Result<()> {
    let storage = &self.storage;
    storage.delete_session(id).await
}
|
|
|
|
/// Clear locally stored secure data.
///
/// NOTE(review): the body visible here is a placeholder that performs no work and always
/// returns `Ok(())`; the real implementation appears to have been elided and must be
/// restored before callers can rely on this method.
pub async fn clear_secure_data(&self) -> Result<()> {
    // Intentionally empty until the elided implementation is restored.
    Ok(())
}
|
|
|
|
/// Persist the current consent records to the vault.
///
/// Does nothing (and succeeds) when no vault is configured. This uses the same
/// `persist_to_vault` call that `grant_consent` relies on, but surfaces the failure to
/// the caller instead of only printing it.
///
/// # Errors
/// Returns the error reported by the consent manager when writing to the vault fails.
///
/// # Panics
/// Panics if the consent manager mutex is poisoned.
pub fn persist_consent(&self) -> Result<()> {
    let Some(vault) = self.vault.as_ref() else {
        return Ok(());
    };
    self.consent_manager
        .lock()
        .expect("Consent manager mutex poisoned")
        .persist_to_vault(vault)?;
    Ok(())
}
|
|
|
|
pub async fn set_tool_enabled(&mut self, tool: &str, enabled: bool) -> Result<()> {
|
|
{
|
|
let mut config = self.config.lock().await;
|
|
match tool {
|
|
"web_search" => {
|
|
config.tools.web_search.enabled = enabled;
|
|
config.privacy.enable_remote_search = enabled;
|
|
}
|
|
"code_exec" => config.tools.code_exec.enabled = enabled,
|
|
other => return Err(Error::InvalidInput(format!("Unknown tool: {other}"))),
|
|
}
|
|
}
|
|
self.rebuild_tools().await
|
|
}
|
|
|
|
/// Return another handle to the shared consent manager.
pub fn consent_manager(&self) -> Arc<Mutex<ConsentManager>> {
    Arc::clone(&self.consent_manager)
}
|
|
|
|
/// Return another handle to the current tool registry.
pub fn tool_registry(&self) -> Arc<ToolRegistry> {
    Arc::clone(&self.tool_registry)
}
|
|
|
|
/// Return another handle to the current schema validator.
pub fn schema_validator(&self) -> Arc<SchemaValidator> {
    Arc::clone(&self.schema_validator)
}
|
|
|
|
/// Return a handle to the credential manager, if one was created.
pub fn credential_manager(&self) -> Option<Arc<CredentialManager>> {
    self.credential_manager.as_ref().map(Arc::clone)
}
|
|
|
|
pub fn pending_oauth_servers(&self) -> Vec<String> {
|
|
self.missing_oauth_servers.clone()
|
|
}
|
|
|
|
pub async fn start_oauth_device_flow(&self, server: &str) -> Result<DeviceAuthorization> {
|
|
let oauth_config = {
|
|
let config = self.config.lock().await;
|
|
let server_cfg = config
|
|
.effective_mcp_servers()
|
|
.iter()
|
|
.find(|entry| entry.name == server)
|
|
.ok_or_else(|| {
|
|
Error::Config(format!("No MCP server named '{server}' is configured"))
|
|
})?;
|
|
server_cfg.oauth.clone().ok_or_else(|| {
|
|
Error::Config(format!(
|
|
"MCP server '{server}' does not define an OAuth configuration"
|
|
))
|
|
})?
|
|
};
|
|
|
|
let client = OAuthClient::new(oauth_config)?;
|
|
client.start_device_authorization().await
|
|
}
|
|
|
|
pub async fn poll_oauth_device_flow(
|
|
&mut self,
|
|
server: &str,
|
|
authorization: &DeviceAuthorization,
|
|
) -> Result<DevicePollState> {
|
|
let oauth_config = {
|
|
let config = self.config.lock().await;
|
|
let server_cfg = config
|
|
.effective_mcp_servers()
|
|
.iter()
|
|
.find(|entry| entry.name == server)
|
|
.ok_or_else(|| {
|
|
Error::Config(format!("No MCP server named '{server}' is configured"))
|
|
})?;
|
|
server_cfg.oauth.clone().ok_or_else(|| {
|
|
Error::Config(format!(
|
|
"MCP server '{server}' does not define an OAuth configuration"
|
|
))
|
|
})?
|
|
};
|
|
|
|
let client = OAuthClient::new(oauth_config)?;
|
|
match client.poll_device_token(authorization).await? {
|
|
DevicePollState::Pending { retry_in } => Ok(DevicePollState::Pending { retry_in }),
|
|
DevicePollState::Complete(token) => {
|
|
let manager = self.credential_manager.as_ref().cloned().ok_or_else(|| {
|
|
Error::Config(
|
|
"OAuth token storage requires encrypted local data; set \
|
|
privacy.encrypt_local_data = true in the configuration."
|
|
.to_string(),
|
|
)
|
|
})?;
|
|
|
|
manager.store_oauth_token(server, &token).await?;
|
|
self.missing_oauth_servers.retain(|entry| entry != server);
|
|
|
|
Ok(DevicePollState::Complete(token))
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn list_mcp_tools(&self) -> Vec<(String, crate::mcp::McpToolDescriptor)> {
|
|
let mut entries = Vec::new();
|
|
for (server, client) in self.named_mcp_clients.iter() {
|
|
let server_name = server.clone();
|
|
let client = Arc::clone(client);
|
|
match client.list_tools().await {
|
|
Ok(tools) => {
|
|
for descriptor in tools {
|
|
entries.push((server_name.clone(), descriptor));
|
|
}
|
|
}
|
|
Err(err) => {
|
|
warn!(
|
|
"Failed to list tools for MCP server '{}': {}",
|
|
server_name, err
|
|
);
|
|
}
|
|
}
|
|
}
|
|
entries
|
|
}
|
|
|
|
pub async fn call_mcp_tool(
|
|
&self,
|
|
server: &str,
|
|
tool: &str,
|
|
arguments: Value,
|
|
) -> Result<crate::mcp::McpToolResponse> {
|
|
let client = self.named_mcp_clients.get(server).cloned().ok_or_else(|| {
|
|
Error::Config(format!("No MCP server named '{}' is registered", server))
|
|
})?;
|
|
client
|
|
.call_tool(McpToolCall {
|
|
name: tool.to_string(),
|
|
arguments,
|
|
})
|
|
.await
|
|
}
|
|
|
|
pub fn mcp_server(&self) -> crate::mcp::McpServer {
|
|
crate::mcp::McpServer::new(self.tool_registry(), self.schema_validator())
|
|
}
|
|
|
|
/// Return another handle to the storage manager.
pub fn storage(&self) -> Arc<StorageManager> {
    Arc::clone(&self.storage)
}
|
|
|
|
pub fn master_key(&self) -> Option<Arc<Vec<u8>>> {
|
|
self.master_key.as_ref().map(Arc::clone)
|
|
}
|
|
|
|
/// Return a handle to the vault, if local encryption is active.
pub fn vault(&self) -> Option<Arc<Mutex<VaultHandle>>> {
    self.vault.clone()
}
|
|
|
|
pub async fn read_file(&self, path: &str) -> Result<String> {
|
|
let call = McpToolCall {
|
|
name: "resources/get".to_string(),
|
|
arguments: serde_json::json!({ "path": path }),
|
|
};
|
|
match self.mcp_client.call_tool(call).await {
|
|
Ok(response) => {
|
|
if let Some(text) = extract_resource_content(&response.output) {
|
|
return Ok(text);
|
|
}
|
|
|
|
let formatted = serde_json::to_string_pretty(&response.output)
|
|
.unwrap_or_else(|_| response.output.to_string());
|
|
Ok(formatted)
|
|
}
|
|
Err(err) => {
|
|
log::warn!("MCP file read failed ({}); falling back to local read", err);
|
|
let content = std::fs::read_to_string(path)?;
|
|
Ok(content)
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn read_file_with_tools(&self, path: &str) -> Result<String> {
|
|
if !self.enable_code_tools {
|
|
return Err(Error::InvalidInput(
|
|
"Code tools are disabled in chat mode. Run `:mode code` to switch.".to_string(),
|
|
));
|
|
}
|
|
|
|
let call = McpToolCall {
|
|
name: "resources/get".to_string(),
|
|
arguments: serde_json::json!({ "path": path }),
|
|
};
|
|
|
|
let response = self.mcp_client.call_tool(call).await?;
|
|
if let Some(text) = extract_resource_content(&response.output) {
|
|
Ok(text)
|
|
} else {
|
|
let formatted = serde_json::to_string_pretty(&response.output)
|
|
.unwrap_or_else(|_| response.output.to_string());
|
|
Ok(formatted)
|
|
}
|
|
}
|
|
|
|
pub fn code_tools_enabled(&self) -> bool {
|
|
self.enable_code_tools
|
|
}
|
|
|
|
/// Enable or disable code tools, rebuilding the tool registry when the flag changes.
///
/// Nothing happens when the flag already has the requested value.
///
/// # Errors
/// Propagates failures from rebuilding the tools; the flag stays updated in that case.
pub async fn set_code_tools_enabled(&mut self, enabled: bool) -> Result<()> {
    if self.enable_code_tools != enabled {
        self.enable_code_tools = enabled;
        self.rebuild_tools().await?;
    }
    Ok(())
}
|
|
|
|
pub async fn set_operating_mode(&mut self, mode: Mode) -> Result<()> {
|
|
self.current_mode = mode;
|
|
let enable_code_tools = matches!(mode, Mode::Code);
|
|
self.set_code_tools_enabled(enable_code_tools).await?;
|
|
self.mcp_client.set_mode(mode).await
|
|
}
|
|
|
|
pub async fn list_dir(&self, path: &str) -> Result<Vec<String>> {
|
|
let call = McpToolCall {
|
|
name: "resources/list".to_string(),
|
|
arguments: serde_json::json!({ "path": path }),
|
|
};
|
|
match self.mcp_client.call_tool(call).await {
|
|
Ok(response) => {
|
|
let content: Vec<String> = serde_json::from_value(response.output)?;
|
|
Ok(content)
|
|
}
|
|
Err(err) => {
|
|
log::warn!(
|
|
"MCP directory list failed ({}); falling back to local list",
|
|
err
|
|
);
|
|
let mut entries = Vec::new();
|
|
for entry in std::fs::read_dir(path)? {
|
|
let entry = entry?;
|
|
entries.push(entry.file_name().to_string_lossy().to_string());
|
|
}
|
|
Ok(entries)
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn write_file(&self, path: &str, content: &str) -> Result<()> {
|
|
let call = McpToolCall {
|
|
name: "resources/write".to_string(),
|
|
arguments: serde_json::json!({ "path": path, "content": content }),
|
|
};
|
|
match self.mcp_client.call_tool(call).await {
|
|
Ok(_) => Ok(()),
|
|
Err(err) => {
|
|
log::warn!(
|
|
"MCP file write failed ({}); falling back to local write",
|
|
err
|
|
);
|
|
// Ensure parent directory exists
|
|
if let Some(parent) = std::path::Path::new(path).parent() {
|
|
std::fs::create_dir_all(parent)?;
|
|
}
|
|
std::fs::write(path, content)?;
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn delete_file(&self, path: &str) -> Result<()> {
|
|
let call = McpToolCall {
|
|
name: "resources/delete".to_string(),
|
|
arguments: serde_json::json!({ "path": path }),
|
|
};
|
|
match self.mcp_client.call_tool(call).await {
|
|
Ok(_) => Ok(()),
|
|
Err(err) => {
|
|
log::warn!(
|
|
"MCP file delete failed ({}); falling back to local delete",
|
|
err
|
|
);
|
|
std::fs::remove_file(path)?;
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn rebuild_tools(&mut self) -> Result<()> {
|
|
let (registry, validator) = build_tools(
|
|
self.config.clone(),
|
|
self.ui.clone(),
|
|
self.enable_code_tools,
|
|
self.consent_manager.clone(),
|
|
self.credential_manager.clone(),
|
|
self.vault.clone(),
|
|
)
|
|
.await?;
|
|
self.tool_registry = registry;
|
|
self.schema_validator = validator;
|
|
|
|
// Recreate MCP client with permission layer
|
|
let config = self.config.lock().await;
|
|
let factory = McpClientFactory::new(
|
|
Arc::new(config.clone()),
|
|
self.tool_registry.clone(),
|
|
self.schema_validator.clone(),
|
|
);
|
|
let base_client = factory.create()?;
|
|
let permission_client = PermissionLayer::new(base_client, Arc::new(config.clone()));
|
|
let client = Arc::new(permission_client);
|
|
client.set_mode(self.current_mode).await?;
|
|
self.mcp_client = client;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub fn selected_model(&self) -> &str {
|
|
&self.conversation.active().model
|
|
}
|
|
|
|
pub async fn set_model(&mut self, model: String) {
|
|
self.conversation.set_model(model.clone());
|
|
let mut config = self.config.lock().await;
|
|
config.general.default_model = Some(model);
|
|
}
|
|
|
|
/// Return the model list through the model manager's `get_or_refresh`, which is handed
/// `force_refresh` and a closure that asks the provider for its models.
pub async fn models(&self, force_refresh: bool) -> Result<Vec<ModelInfo>> {
    let listing = self
        .model_manager
        .get_or_refresh(force_refresh, || async {
            self.provider.list_models().await
        })
        .await;
    listing
}
|
|
|
|
/// Downcast the provider to `OllamaProvider`, returning `None` for any other provider.
fn as_ollama(&self) -> Option<&OllamaProvider> {
    let provider = self.provider.as_ref();
    provider.as_any().downcast_ref::<OllamaProvider>()
}
|
|
|
|
pub async fn model_details(
|
|
&self,
|
|
model_name: &str,
|
|
force_refresh: bool,
|
|
) -> Result<DetailedModelInfo> {
|
|
if let Some(ollama) = self.as_ollama() {
|
|
if force_refresh {
|
|
ollama.refresh_model_info(model_name).await
|
|
} else {
|
|
ollama.get_model_info(model_name).await
|
|
}
|
|
} else {
|
|
Err(Error::NotImplemented(format!(
|
|
"Provider '{}' does not expose model inspection",
|
|
self.provider.name()
|
|
)))
|
|
}
|
|
}
|
|
|
|
pub async fn all_model_details(&self, force_refresh: bool) -> Result<Vec<DetailedModelInfo>> {
|
|
if let Some(ollama) = self.as_ollama() {
|
|
if force_refresh {
|
|
ollama.clear_model_info_cache().await;
|
|
}
|
|
ollama.get_all_models_info().await
|
|
} else {
|
|
Err(Error::NotImplemented(format!(
|
|
"Provider '{}' does not expose model inspection",
|
|
self.provider.name()
|
|
)))
|
|
}
|
|
}
|
|
|
|
/// Return the Ollama provider's cached model details, or an empty list for any other
/// provider.
pub async fn cached_model_details(&self) -> Vec<DetailedModelInfo> {
    match self.as_ollama() {
        Some(ollama) => ollama.cached_model_info().await,
        None => Vec::new(),
    }
}
|
|
|
|
pub async fn invalidate_model_details(&self, model_name: &str) {
|
|
if let Some(ollama) = self.as_ollama() {
|
|
ollama.invalidate_model_info(model_name).await;
|
|
}
|
|
}
|
|
|
|
pub async fn clear_model_details_cache(&self) {
|
|
if let Some(ollama) = self.as_ollama() {
|
|
ollama.clear_model_info_cache().await;
|
|
}
|
|
}
|
|
|
|
/// Align the conversation's model with the configured default.
///
/// * A configured default that matches the id or name of one of `models` is applied to
///   the conversation.
/// * A configured default that matches nothing leaves everything untouched.
/// * With no configured default, the first entry of `models` (if any) is applied to the
///   conversation and stored as the default.
pub async fn ensure_default_model(&mut self, models: &[ModelInfo]) {
    let mut settings = self.config.lock().await;
    match settings.general.default_model.clone() {
        Some(default) => {
            let known = models
                .iter()
                .any(|m| m.id == default || m.name == default);
            if known {
                self.conversation.set_model(default.clone());
                settings.general.default_model = Some(default);
            }
        }
        None => {
            if let Some(first) = models.first() {
                self.conversation.set_model(first.id.clone());
                settings.general.default_model = Some(first.id.clone());
            }
        }
    }
}
|
|
|
|
/// Replaces the active provider with `provider`.
///
/// After the swap, `invalidate` is called on the model manager. The visible
/// body always returns `Ok(())`.
pub async fn switch_provider(&mut self, provider: Arc<dyn Provider>) -> Result<()> {
    self.provider = provider;

    // Runs only after the new provider is in place.
    self.model_manager.invalidate().await;

    Ok(())
}
|
|
|
|
/// Expose the underlying LLM provider.
|
|
pub fn provider(&self) -> Arc<dyn Provider> {
|
|
self.provider.clone()
|
|
}
|
|
|
|
pub async fn send_message(
|
|
&mut self,
|
|
content: String,
|
|
mut parameters: ChatParameters,
|
|
) -> Result<SessionOutcome> {
|
|
let streaming = { self.config.lock().await.general.enable_streaming || parameters.stream };
|
|
parameters.stream = streaming;
|
|
self.conversation.push_user_message(content);
|
|
self.send_request_with_current_conversation(parameters)
|
|
.await
|
|
}
|
|
|
|
pub async fn send_request_with_current_conversation(
|
|
&mut self,
|
|
mut parameters: ChatParameters,
|
|
) -> Result<SessionOutcome> {
|
|
let streaming = { self.config.lock().await.general.enable_streaming || parameters.stream };
|
|
parameters.stream = streaming;
|
|
|
|
let tools = if !self.tool_registry.all().is_empty() {
|
|
Some(
|
|
self.tool_registry
|
|
.all()
|
|
.into_iter()
|
|
.map(|tool| crate::mcp::McpToolDescriptor {
|
|
name: tool.name().to_string(),
|
|
description: tool.description().to_string(),
|
|
input_schema: tool.schema(),
|
|
requires_network: tool.requires_network(),
|
|
requires_filesystem: tool.requires_filesystem(),
|
|
})
|
|
.collect(),
|
|
)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let mut request = ChatRequest {
|
|
model: self.conversation.active().model.clone(),
|
|
messages: self.conversation.active().messages.clone(),
|
|
parameters: parameters.clone(),
|
|
tools: tools.clone(),
|
|
};
|
|
|
|
if !streaming {
|
|
const MAX_TOOL_ITERATIONS: usize = 5;
|
|
for _iteration in 0..MAX_TOOL_ITERATIONS {
|
|
match self.provider.send_prompt(request.clone()).await {
|
|
Ok(response) => {
|
|
if response.message.has_tool_calls() {
|
|
self.conversation.push_message(response.message.clone());
|
|
if let Some(tool_calls) = &response.message.tool_calls {
|
|
for tool_call in tool_calls {
|
|
let mcp_tool_call = McpToolCall {
|
|
name: tool_call.name.clone(),
|
|
arguments: tool_call.arguments.clone(),
|
|
};
|
|
let tool_result =
|
|
self.mcp_client.call_tool(mcp_tool_call).await;
|
|
let tool_response_content = match tool_result {
|
|
Ok(result) => serde_json::to_string_pretty(&result.output)
|
|
.unwrap_or_else(|_| {
|
|
"Tool execution succeeded".to_string()
|
|
}),
|
|
Err(e) => format!("Tool execution failed: {}", e),
|
|
};
|
|
let tool_msg =
|
|
Message::tool(tool_call.id.clone(), tool_response_content);
|
|
self.conversation.push_message(tool_msg);
|
|
}
|
|
}
|
|
request.messages = self.conversation.active().messages.clone();
|
|
continue;
|
|
} else {
|
|
self.conversation.push_message(response.message.clone());
|
|
return Ok(SessionOutcome::Complete(response));
|
|
}
|
|
}
|
|
Err(err) => {
|
|
self.conversation
|
|
.push_assistant_message(format!("Error: {}", err));
|
|
return Err(err);
|
|
}
|
|
}
|
|
}
|
|
self.conversation
|
|
.push_assistant_message("Maximum tool execution iterations reached".to_string());
|
|
return Err(crate::Error::Provider(anyhow::anyhow!(
|
|
"Maximum tool execution iterations reached"
|
|
)));
|
|
}
|
|
|
|
match self.provider.stream_prompt(request).await {
|
|
Ok(stream) => {
|
|
let response_id = self.conversation.start_streaming_response();
|
|
Ok(SessionOutcome::Streaming {
|
|
response_id,
|
|
stream,
|
|
})
|
|
}
|
|
Err(err) => {
|
|
self.conversation
|
|
.push_assistant_message(format!("Error starting stream: {}", err));
|
|
Err(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Sets `text` as the stream placeholder of the message identified by `message_id`.
///
/// The result of the conversation manager's `set_stream_placeholder` is
/// returned unchanged.
pub fn mark_stream_placeholder(&mut self, message_id: Uuid, text: &str) -> Result<()> {
    let placeholder = text.to_owned();
    self.conversation
        .set_stream_placeholder(message_id, placeholder)
}
|
|
|
|
/// Applies one streamed `chunk` to the message identified by `message_id`.
///
/// If the chunk's message reports tool calls, they are stored on the message
/// first. The chunk's content is then appended, passing along `chunk.is_final`.
///
/// # Errors
///
/// Forwards any error from storing the tool calls or appending the content.
pub fn apply_stream_chunk(&mut self, message_id: Uuid, chunk: &ChatResponse) -> Result<()> {
    let incoming = &chunk.message;

    if incoming.has_tool_calls() {
        let calls = incoming.tool_calls.clone().unwrap_or_default();
        self.conversation
            .set_tool_calls_on_message(message_id, calls)?;
    }

    self.conversation
        .append_stream_chunk(message_id, &incoming.content, chunk.is_final)
}
|
|
|
|
/// Returns the tool calls attached to the message identified by `message_id`.
///
/// Yields `None` when the active conversation has no such message, or when the
/// message carries no tool calls or an empty list of them.
pub fn check_streaming_tool_calls(&self, message_id: Uuid) -> Option<Vec<ToolCall>> {
    let message = self
        .conversation
        .active()
        .messages
        .iter()
        .find(|candidate| candidate.id == message_id)?;

    match message.tool_calls.as_ref() {
        Some(calls) if !calls.is_empty() => Some(calls.clone()),
        _ => None,
    }
}
|
|
|
|
/// Cancels the streaming message identified by `message_id`, passing `notice`
/// along to the conversation manager.
///
/// The result of the conversation manager's `cancel_stream` is returned unchanged.
pub fn cancel_stream(&mut self, message_id: Uuid, notice: &str) -> Result<()> {
    let notice_text = notice.to_owned();
    self.conversation.cancel_stream(message_id, notice_text)
}
|
|
|
|
pub async fn execute_streaming_tools(
|
|
&mut self,
|
|
_message_id: Uuid,
|
|
tool_calls: Vec<ToolCall>,
|
|
) -> Result<SessionOutcome> {
|
|
for tool_call in &tool_calls {
|
|
let mcp_tool_call = McpToolCall {
|
|
name: tool_call.name.clone(),
|
|
arguments: tool_call.arguments.clone(),
|
|
};
|
|
let tool_result = self.mcp_client.call_tool(mcp_tool_call).await;
|
|
let tool_response_content = match tool_result {
|
|
Ok(result) => serde_json::to_string_pretty(&result.output)
|
|
.unwrap_or_else(|_| "Tool execution succeeded".to_string()),
|
|
Err(e) => format!("Tool execution failed: {}", e),
|
|
};
|
|
let tool_msg = Message::tool(tool_call.id.clone(), tool_response_content);
|
|
self.conversation.push_message(tool_msg);
|
|
}
|
|
let parameters = ChatParameters {
|
|
stream: self.config.lock().await.general.enable_streaming,
|
|
..Default::default()
|
|
};
|
|
self.send_request_with_current_conversation(parameters)
|
|
.await
|
|
}
|
|
|
|
/// Returns owned copies of the conversations yielded by the conversation
/// manager's `history`.
pub fn history(&self) -> Vec<Conversation> {
    self.conversation.history().cloned().collect::<Vec<_>>()
}
|
|
|
|
pub fn start_new_conversation(&mut self, model: Option<String>, name: Option<String>) {
|
|
self.conversation.start_new(model, name);
|
|
}
|
|
|
|
pub fn clear(&mut self) {
|
|
self.conversation.clear();
|
|
}
|
|
|
|
/// Produces a short description of the current conversation.
///
/// NOTE(review): the body visible here is a placeholder. It carries an
/// "implementation remains the same" marker and always yields the literal
/// `"Empty conversation"`; confirm against the full implementation.
pub async fn generate_conversation_description(&self) -> Result<String> {
    // ... (implementation remains the same)
    let description = String::from("Empty conversation");
    Ok(description)
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    //! Tests for the OAuth device-authorization flow, run against an `httpmock`
    //! server that stands in for the authorization endpoints.

    use super::*;
    use crate::Provider;
    use crate::config::{Config, McpMode, McpOAuthConfig, McpServerConfig};
    use crate::llm::test_utils::MockProvider;
    use crate::storage::StorageManager;
    use crate::ui::NoOpUiController;
    use chrono::Utc;
    use httpmock::prelude::*;
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::Arc;
    use tempfile::tempdir;

    /// Name under which the mock MCP server is registered in the test config.
    const SERVER_NAME: &str = "oauth-test";

    /// Sets `OWLEN_MASTER_PASSWORD` for the test process exactly once.
    fn ensure_master_password() {
        static SET_PASSWORD: std::sync::Once = std::sync::Once::new();
        SET_PASSWORD.call_once(|| {
            // SAFETY: `set_var` is unsafe because mutating the environment can race
            // with other threads accessing it. `call_once` guarantees this module
            // performs the write a single time, and every caller blocks until it
            // has completed, so the tests here never write concurrently.
            // NOTE(review): tests elsewhere in the same binary that touch the
            // environment could still overlap with this write; confirm none do.
            unsafe {
                std::env::set_var("OWLEN_MASTER_PASSWORD", "test-password");
            }
        });
    }

    /// Builds an OAuth configuration whose endpoints all point at `server`.
    fn build_oauth_config(server: &MockServer) -> McpOAuthConfig {
        McpOAuthConfig {
            client_id: "owlen-client".to_string(),
            client_secret: None,
            authorize_url: server.url("/authorize"),
            token_url: server.url("/token"),
            device_authorization_url: Some(server.url("/device")),
            redirect_url: None,
            scopes: vec!["repo".to_string()],
            token_env: Some("OAUTH_TOKEN".to_string()),
            header: Some("Authorization".to_string()),
            header_prefix: Some("Bearer ".to_string()),
        }
    }

    /// Builds a configuration with a single HTTP MCP server entry named
    /// [`SERVER_NAME`] that uses the OAuth settings from [`build_oauth_config`].
    fn build_config(server: &MockServer) -> Config {
        let mut config = Config::default();
        config.mcp.mode = McpMode::LocalOnly;
        let oauth = build_oauth_config(server);

        let mut env = HashMap::new();
        env.insert("OWLEN_ENV".to_string(), "test".to_string());

        config.mcp_servers = vec![McpServerConfig {
            name: SERVER_NAME.to_string(),
            command: server.url("/mcp"),
            args: Vec::new(),
            transport: "http".to_string(),
            env,
            oauth: Some(oauth),
        }];

        config.refresh_mcp_servers(None).unwrap();
        config
    }

    /// Creates a session backed by a temporary database and a mock provider.
    ///
    /// The returned `TempDir` must be kept alive by the caller for as long as the
    /// session is used, because the database file lives inside it.
    async fn build_session(server: &MockServer) -> (SessionController, tempfile::TempDir) {
        ensure_master_password();

        let temp_dir = tempdir().expect("tempdir");
        let storage_path = temp_dir.path().join("owlen.db");
        let storage = Arc::new(
            StorageManager::with_database_path(storage_path)
                .await
                .expect("storage"),
        );

        let config = build_config(server);
        let provider: Arc<dyn Provider> = Arc::new(MockProvider::default());
        let ui = Arc::new(NoOpUiController);

        let session = SessionController::new(provider, config, storage, ui, false)
            .await
            .expect("session");

        (session, temp_dir)
    }

    /// Starting the device flow surfaces the details returned by the device endpoint.
    #[tokio::test]
    async fn start_oauth_device_flow_returns_details() {
        let server = MockServer::start_async().await;
        let device = server
            .mock_async(|when, then| {
                when.method(POST).path("/device");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "device_code": "device-abc",
                        "user_code": "ABCD-1234",
                        "verification_uri": "https://example.test/activate",
                        "verification_uri_complete": "https://example.test/activate?user_code=ABCD-1234",
                        "expires_in": 600,
                        "interval": 5,
                        "message": "Enter the code to continue."
                    }));
            })
            .await;

        let (session, _dir) = build_session(&server).await;
        let authorization = session
            .start_oauth_device_flow(SERVER_NAME)
            .await
            .expect("device flow");

        assert_eq!(authorization.user_code, "ABCD-1234");
        assert_eq!(
            authorization.verification_uri_complete.as_deref(),
            Some("https://example.test/activate?user_code=ABCD-1234")
        );
        assert!(authorization.expires_at > Utc::now());
        device.assert_async().await;
    }

    /// A successful poll returns the token, removes the server from the pending
    /// list, and persists the token through the credential manager.
    #[tokio::test]
    async fn poll_oauth_device_flow_stores_token_and_updates_state() {
        let server = MockServer::start_async().await;

        let device = server
            .mock_async(|when, then| {
                when.method(POST).path("/device");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "device_code": "device-xyz",
                        "user_code": "WXYZ-9999",
                        "verification_uri": "https://example.test/activate",
                        "verification_uri_complete": "https://example.test/activate?user_code=WXYZ-9999",
                        "expires_in": 600,
                        "interval": 5
                    }));
            })
            .await;

        let token = server
            .mock_async(|when, then| {
                when.method(POST)
                    .path("/token")
                    .body_contains("device_code=device-xyz");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "access_token": "new-access-token",
                        "refresh_token": "refresh-token",
                        "expires_in": 3600,
                        "token_type": "Bearer"
                    }));
            })
            .await;

        let (mut session, _dir) = build_session(&server).await;
        assert_eq!(session.pending_oauth_servers(), vec![SERVER_NAME]);

        let authorization = session
            .start_oauth_device_flow(SERVER_NAME)
            .await
            .expect("device flow");

        match session
            .poll_oauth_device_flow(SERVER_NAME, &authorization)
            .await
            .expect("token poll")
        {
            DevicePollState::Complete(token_info) => {
                assert_eq!(token_info.access_token, "new-access-token");
                assert_eq!(token_info.refresh_token.as_deref(), Some("refresh-token"));
            }
            other => panic!("expected token completion, got {other:?}"),
        }

        assert!(
            session
                .pending_oauth_servers()
                .iter()
                .all(|entry| entry != SERVER_NAME),
            "server should be removed from pending list"
        );

        let stored = session
            .credential_manager()
            .expect("credential manager")
            .load_oauth_token(SERVER_NAME)
            .await
            .expect("load token")
            .expect("token present");

        assert_eq!(stored.access_token, "new-access-token");
        assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));

        device.assert_async().await;
        token.assert_async().await;
    }
}
|