diff --git a/README.md b/README.md index dcf4212..3a7d536 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,7 @@ OWLEN uses a modal, vim-inspired interface. Press `F1` (available from any mode) - **Editing Mode**: Enter with `i` or `a`. Send messages with `Enter`. - **Command Mode**: Enter with `:`. Access commands like `:quit`, `:save`, `:theme`. - **Tutorial Command**: Type `:tutorial` any time for a quick summary of the most important keybindings. +- **MCP Slash Commands**: Owlen auto-registers zero-argument MCP tools as slash commands—type `/mcp__github__list_prs` (for example) to pull remote context directly into the chat log. ## Documentation diff --git a/crates/owlen-cli/src/main.rs b/crates/owlen-cli/src/main.rs index d1c4977..48f7577 100644 --- a/crates/owlen-cli/src/main.rs +++ b/crates/owlen-cli/src/main.rs @@ -1,11 +1,13 @@ //! OWLEN CLI - Chat TUI client mod cloud; +mod mcp; use anyhow::{Result, anyhow}; use async_trait::async_trait; use clap::{Parser, Subcommand}; use cloud::{CloudCommand, load_runtime_credentials, set_env_var}; +use mcp::{McpCommand, run_mcp_command}; use owlen_core::config as core_config; use owlen_core::{ ChatStream, Error, Provider, @@ -54,6 +56,9 @@ enum OwlenCommand { /// Manage Ollama Cloud credentials #[command(subcommand)] Cloud(CloudCommand), + /// Manage MCP server registrations + #[command(subcommand)] + Mcp(McpCommand), /// Show manual steps for updating Owlen to the latest revision Upgrade, } @@ -69,7 +74,7 @@ enum ConfigCommand { fn build_provider(cfg: &Config) -> anyhow::Result> { match cfg.mcp.mode { McpMode::RemotePreferred => { - let remote_result = if let Some(mcp_server) = cfg.mcp_servers.first() { + let remote_result = if let Some(mcp_server) = cfg.effective_mcp_servers().first() { RemoteMcpClient::new_with_config(mcp_server) } else { RemoteMcpClient::new() @@ -91,7 +96,7 @@ fn build_provider(cfg: &Config) -> anyhow::Result> { } } McpMode::RemoteOnly => { - let mcp_server = cfg.mcp_servers.first().ok_or_else(|| { + let mcp_server = cfg.effective_mcp_servers().first().ok_or_else(|| { anyhow::anyhow!( "[[mcp_servers]] must be configured when [mcp].mode = \"remote_only\"" ) @@ -130,6 +135,7 @@ async fn run_command(command: OwlenCommand) -> Result<()> { match command { OwlenCommand::Config(config_cmd) => run_config_command(config_cmd), OwlenCommand::Cloud(cloud_cmd) => cloud::run_cloud_command(cloud_cmd).await, + OwlenCommand::Mcp(mcp_cmd) => run_mcp_command(mcp_cmd), OwlenCommand::Upgrade => { println!( "To update Owlen from source:\n git pull\n cargo install --path crates/owlen-cli --force" @@ -157,6 +163,7 @@ fn run_config_doctor() -> Result<()> { let config_path = core_config::default_config_path(); let existed = config_path.exists(); let mut config = config::try_load_config().unwrap_or_default(); + let _ = config.refresh_mcp_servers(None); let mut changes = Vec::new(); if !existed { @@ -205,7 +212,7 @@ fn run_config_doctor() -> Result<()> { config.mcp.warn_on_legacy = true; changes.push("converted [mcp].mode = 'legacy' to 'local_only'".to_string()); } - McpMode::RemoteOnly if config.mcp_servers.is_empty() => { + McpMode::RemoteOnly if config.effective_mcp_servers().is_empty() => { config.mcp.mode = McpMode::RemotePreferred; config.mcp.allow_fallback = true; changes.push( @@ -213,7 +220,9 @@ fn run_config_doctor() -> Result<()> { .to_string(), ); } - McpMode::RemotePreferred if !config.mcp.allow_fallback && config.mcp_servers.is_empty() => { + McpMode::RemotePreferred + if !config.mcp.allow_fallback && config.effective_mcp_servers().is_empty() => + { config.mcp.allow_fallback = true; changes.push( "enabled [mcp].allow_fallback because no remote servers are configured".to_string(), @@ -369,6 +378,7 @@ async fn main() -> Result<()> { let color_support = detect_terminal_color_support(); // Load configuration (or fall back to defaults) for the session controller. let mut cfg = config::try_load_config().unwrap_or_default(); + let _ = cfg.refresh_mcp_servers(None); if let Some(previous_theme) = apply_terminal_theme(&mut cfg, &color_support) { let term_label = match &color_support { TerminalColorSupport::Limited { term } => Cow::from(term.as_str()), @@ -398,7 +408,7 @@ async fn main() -> Result<()> { Ok(_) => provider, Err(err) => { let hint = if matches!(cfg.mcp.mode, McpMode::RemotePreferred | McpMode::RemoteOnly) - && !cfg.mcp_servers.is_empty() + && !cfg.effective_mcp_servers().is_empty() { "Ensure the configured MCP server is running and reachable." } else { @@ -523,7 +533,7 @@ async fn run_app( } } Some(session_event) = session_rx.recv() => { - app.handle_session_event(session_event)?; + app.handle_session_event(session_event).await?; } _ = tokio::time::sleep(sleep_duration) => {} } diff --git a/crates/owlen-cli/src/mcp.rs b/crates/owlen-cli/src/mcp.rs new file mode 100644 index 0000000..34410eb --- /dev/null +++ b/crates/owlen-cli/src/mcp.rs @@ -0,0 +1,257 @@ +use std::collections::{HashMap, HashSet}; + +use anyhow::{Result, anyhow}; +use clap::{Args, Subcommand, ValueEnum}; +use owlen_core::config::{self as core_config, Config, McpConfigScope, McpServerConfig}; +use owlen_tui::config as tui_config; + +#[derive(Debug, Subcommand)] +pub enum McpCommand { + /// Add or update an MCP server in the selected scope + Add(AddArgs), + /// List MCP servers across scopes + List(ListArgs), + /// Remove an MCP server from a scope + Remove(RemoveArgs), +} + +pub fn run_mcp_command(command: McpCommand) -> Result<()> { + match command { + McpCommand::Add(args) => handle_add(args), + McpCommand::List(args) => handle_list(args), + McpCommand::Remove(args) => handle_remove(args), + } +} + +#[derive(Debug, Clone, Copy, ValueEnum, Default)] +pub enum ScopeArg { + User, + #[default] + Project, + Local, +} + +impl From for McpConfigScope { + fn from(value: ScopeArg) -> Self { + match value { + ScopeArg::User => McpConfigScope::User, + ScopeArg::Project => McpConfigScope::Project, + ScopeArg::Local => McpConfigScope::Local, + } + } +} + +#[derive(Debug, Args)] +pub struct AddArgs { + /// Logical name used to reference the server + pub name: String, + /// Command or endpoint invoked for the server + pub command: String, + /// Transport mechanism (stdio, http, websocket) + #[arg(long, default_value = "stdio")] + pub transport: String, + /// Configuration scope to write the server into + #[arg(long, value_enum, default_value_t = ScopeArg::Project)] + pub scope: ScopeArg, + /// Environment variables (KEY=VALUE) passed to the server process + #[arg(long = "env")] + pub env: Vec, + /// Additional arguments appended when launching the server + #[arg(trailing_var_arg = true, value_name = "ARG")] + pub args: Vec, +} + +#[derive(Debug, Args, Default)] +pub struct ListArgs { + /// Restrict output to a specific configuration scope + #[arg(long, value_enum)] + pub scope: Option, + /// Display only the effective servers (after precedence resolution) + #[arg(long)] + pub effective_only: bool, +} + +#[derive(Debug, Args)] +pub struct RemoveArgs { + /// Name of the server to remove + pub name: String, + /// Optional explicit scope to remove from + #[arg(long, value_enum)] + pub scope: Option, +} + +fn handle_add(args: AddArgs) -> Result<()> { + let mut config = load_config()?; + let scope: McpConfigScope = args.scope.into(); + let mut env_map = HashMap::new(); + for pair in &args.env { + let (key, value) = pair + .split_once('=') + .ok_or_else(|| anyhow!("Environment pairs must use KEY=VALUE syntax: '{}'", pair))?; + if key.trim().is_empty() { + return Err(anyhow!("Environment variable name cannot be empty")); + } + env_map.insert(key.trim().to_string(), value.to_string()); + } + + let server = McpServerConfig { + name: args.name.clone(), + command: args.command.clone(), + args: args.args.clone(), + transport: args.transport.to_lowercase(), + env: env_map, + oauth: None, + }; + + config.add_mcp_server(scope, server.clone(), None)?; + if matches!(scope, McpConfigScope::User) { + tui_config::save_config(&config)?; + } + + if let Some(path) = core_config::mcp_scope_path(scope, None) { + println!( + "Registered MCP server '{}' in {} scope ({})", + server.name, + scope, + path.display() + ); + } else { + println!( + "Registered MCP server '{}' in {} scope.", + server.name, scope + ); + } + + Ok(()) +} + +fn handle_list(args: ListArgs) -> Result<()> { + let mut config = load_config()?; + config.refresh_mcp_servers(None)?; + + let scoped = config.scoped_mcp_servers(); + if scoped.is_empty() { + println!("No MCP servers configured."); + return Ok(()); + } + + let filter_scope = args.scope.map(|scope| scope.into()); + let effective = config.effective_mcp_servers(); + let mut active = HashSet::new(); + for server in effective { + active.insert(( + server.name.clone(), + server.command.clone(), + server.transport.to_lowercase(), + )); + } + + println!( + "{:<2} {:<8} {:<20} {:<10} Command", + "", "Scope", "Name", "Transport" + ); + for entry in scoped { + if let Some(target_scope) = filter_scope + && entry.scope != target_scope + { + continue; + } + + let payload = format_command_line(&entry.config.command, &entry.config.args); + let key = ( + entry.config.name.clone(), + entry.config.command.clone(), + entry.config.transport.to_lowercase(), + ); + let marker = if active.contains(&key) { "*" } else { " " }; + + if args.effective_only && marker != "*" { + continue; + } + + println!( + "{} {:<8} {:<20} {:<10} {}", + marker, entry.scope, entry.config.name, entry.config.transport, payload + ); + } + + let scoped_resources = config.scoped_mcp_resources(); + if !scoped_resources.is_empty() { + println!(); + println!("{:<2} {:<8} {:<30} Title", "", "Scope", "Resource"); + let effective_keys: HashSet<(String, String)> = config + .effective_mcp_resources() + .iter() + .map(|res| (res.server.clone(), res.uri.clone())) + .collect(); + + for entry in scoped_resources { + if let Some(target_scope) = filter_scope + && entry.scope != target_scope + { + continue; + } + + let key = (entry.config.server.clone(), entry.config.uri.clone()); + let marker = if effective_keys.contains(&key) { + "*" + } else { + " " + }; + if args.effective_only && marker != "*" { + continue; + } + + let reference = format!("@{}:{}", entry.config.server, entry.config.uri); + let title = entry.config.title.as_deref().unwrap_or("—"); + + println!("{} {:<8} {:<30} {}", marker, entry.scope, reference, title); + } + } + + Ok(()) +} + +fn handle_remove(args: RemoveArgs) -> Result<()> { + let mut config = load_config()?; + let scope_hint = args.scope.map(|scope| scope.into()); + let result = config.remove_mcp_server(scope_hint, &args.name, None)?; + + match result { + Some(scope) => { + if matches!(scope, McpConfigScope::User) { + tui_config::save_config(&config)?; + } + + if let Some(path) = core_config::mcp_scope_path(scope, None) { + println!( + "Removed MCP server '{}' from {} scope ({})", + args.name, + scope, + path.display() + ); + } else { + println!("Removed MCP server '{}' from {} scope.", args.name, scope); + } + } + None => { + println!("No MCP server named '{}' was found.", args.name); + } + } + + Ok(()) +} + +fn load_config() -> Result { + let mut config = tui_config::try_load_config().unwrap_or_default(); + config.refresh_mcp_servers(None)?; + Ok(config) +} + +fn format_command_line(command: &str, args: &[String]) -> String { + if args.is_empty() { + command.to_string() + } else { + format!("{} {}", command, args.join(" ")) + } +} diff --git a/crates/owlen-core/Cargo.toml b/crates/owlen-core/Cargo.toml index a4e1a94..90e4e20 100644 --- a/crates/owlen-core/Cargo.toml +++ b/crates/owlen-core/Cargo.toml @@ -50,3 +50,4 @@ ollama-rs = { version = "0.3", features = ["stream", "headers"] } [dev-dependencies] tokio-test = { workspace = true } +httpmock = "0.7" diff --git a/crates/owlen-core/src/config.rs b/crates/owlen-core/src/config.rs index daed5dd..c354a7d 100644 --- a/crates/owlen-core/src/config.rs +++ b/crates/owlen-core/src/config.rs @@ -1,13 +1,15 @@ +use crate::Error; use crate::ProviderConfig; use crate::Result; use crate::mode::ModeConfig; use crate::ui::RoleLabelDisplay; use serde::de::{self, Deserializer, Visitor}; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::fmt; use std::fs; use std::path::{Path, PathBuf}; +use std::str::FromStr; use std::time::Duration; /// Default location for the OWLEN configuration file @@ -54,6 +56,21 @@ pub struct Config { /// External MCP server definitions #[serde(default)] pub mcp_servers: Vec, + /// User-scoped resource definitions + #[serde(default)] + pub mcp_resources: Vec, + /// Resolved MCP servers across scopes (runtime only). + #[serde(skip)] + pub scoped_mcp_servers: Vec, + /// Effective MCP servers after applying precedence rules (runtime only). + #[serde(skip)] + pub effective_mcp_servers: Vec, + /// Resolved MCP resources across scopes (runtime only). + #[serde(skip)] + pub scoped_mcp_resources: Vec, + /// Effective MCP resources after precedence (runtime only). + #[serde(skip)] + pub effective_mcp_resources: Vec, } impl Default for Config { @@ -74,6 +91,11 @@ impl Default for Config { tools: ToolSettings::default(), modes: ModeConfig::default(), mcp_servers: Vec::new(), + mcp_resources: Vec::new(), + scoped_mcp_servers: Vec::new(), + effective_mcp_servers: Vec::new(), + scoped_mcp_resources: Vec::new(), + effective_mcp_resources: Vec::new(), } } } @@ -94,6 +116,9 @@ pub struct McpServerConfig { /// Optional environment variable map for the process. #[serde(default)] pub env: std::collections::HashMap, + /// Optional OAuth configuration for remote servers. + #[serde(default)] + pub oauth: Option, } impl McpServerConfig { @@ -102,6 +127,126 @@ impl McpServerConfig { } } +/// OAuth configuration for MCP servers that require delegated authentication. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct McpOAuthConfig { + /// Public client identifier registered with the authorization server. + pub client_id: String, + /// Optional client secret for confidential clients. + #[serde(default)] + pub client_secret: Option, + /// OAuth authorization endpoint (used for web-based flows). + pub authorize_url: String, + /// OAuth token endpoint. + pub token_url: String, + /// Optional device authorization endpoint for device-code flows. + #[serde(default)] + pub device_authorization_url: Option, + /// Optional redirect URL (PKCE / authorization-code flows). + #[serde(default)] + pub redirect_url: Option, + /// Requested OAuth scopes. + #[serde(default)] + pub scopes: Vec, + /// Environment variable name populated with the bearer access token when spawning stdio servers. + #[serde(default)] + pub token_env: Option, + /// Optional HTTP header name for bearer authentication (defaults to "Authorization"). + #[serde(default)] + pub header: Option, + /// Optional prefix prepended to the access token (defaults to "Bearer "). + #[serde(default)] + pub header_prefix: Option, +} + +impl McpOAuthConfig { + pub fn header_name(&self) -> &str { + self.header.as_deref().unwrap_or("Authorization") + } + + pub fn header_prefix(&self) -> &str { + self.header_prefix.as_deref().unwrap_or("Bearer ") + } +} + +/// Scope for MCP server configuration entries. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum McpConfigScope { + /// User-level configuration stored under the user's config directory. + User, + /// Project configuration stored in the repository (e.g. `.mcp.json`). + Project, + /// Local overrides stored alongside the project but excluded from version control. + Local, +} + +impl McpConfigScope { + fn precedence_iter() -> impl Iterator { + [ + McpConfigScope::Local, + McpConfigScope::Project, + McpConfigScope::User, + ] + .into_iter() + } + + fn as_str(self) -> &'static str { + match self { + McpConfigScope::User => "user", + McpConfigScope::Project => "project", + McpConfigScope::Local => "local", + } + } +} + +impl fmt::Display for McpConfigScope { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl FromStr for McpConfigScope { + type Err = String; + + fn from_str(s: &str) -> std::result::Result { + match s.to_ascii_lowercase().as_str() { + "user" => Ok(McpConfigScope::User), + "project" => Ok(McpConfigScope::Project), + "local" => Ok(McpConfigScope::Local), + other => Err(format!("Unknown MCP scope '{other}'")), + } + } +} + +/// A resolved MCP server entry annotated with its configuration scope. +#[derive(Debug, Clone)] +pub struct ScopedMcpServer { + pub scope: McpConfigScope, + pub config: McpServerConfig, +} + +/// Configuration for a predefined MCP resource reference. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct McpResourceConfig { + /// Named MCP server that owns this resource. + pub server: String, + /// URI or path identifying the resource within the server. + pub uri: String, + /// Optional short title displayed in UI. + #[serde(default)] + pub title: Option, + /// Optional detailed description shown in tooltips. + #[serde(default)] + pub description: Option, +} + +/// Resource entry annotated with its originating scope. +#[derive(Debug, Clone)] +pub struct ScopedMcpResource { + pub scope: McpConfigScope, + pub config: McpResourceConfig, +} + impl Config { fn default_schema_version() -> String { CONFIG_SCHEMA_VERSION.to_string() @@ -138,18 +283,22 @@ impl Config { config.mcp.apply_backward_compat(); config.apply_schema_migrations(&previous_version); config.expand_provider_env_vars()?; + config.refresh_mcp_servers(None)?; config.validate()?; Ok(config) } else { let mut config = Config::default(); config.expand_provider_env_vars()?; + config.refresh_mcp_servers(None)?; Ok(config) } } /// Persist configuration to disk pub fn save(&self, path: Option<&Path>) -> Result<()> { - self.validate()?; + let mut validator = self.clone(); + validator.refresh_mcp_servers(None)?; + validator.validate()?; let path = match path { Some(path) => path.to_path_buf(), @@ -214,6 +363,192 @@ impl Config { Ok(()) } + /// Refresh the resolved MCP server list by loading scope-specific definitions. + pub fn refresh_mcp_servers(&mut self, project_hint: Option<&Path>) -> Result<()> { + let mut scoped_servers = Vec::new(); + let mut scoped_resources = Vec::new(); + + let mut user_servers = self.mcp_servers.clone(); + expand_mcp_servers(&mut user_servers, "config.mcp_servers")?; + for server in user_servers { + scoped_servers.push(ScopedMcpServer { + scope: McpConfigScope::User, + config: server, + }); + } + + let mut user_resources = self.mcp_resources.clone(); + expand_mcp_resources(&mut user_resources, "config.mcp_resources")?; + for resource in user_resources { + scoped_resources.push(ScopedMcpResource { + scope: McpConfigScope::User, + config: resource, + }); + } + + for scope in [McpConfigScope::Project, McpConfigScope::Local] { + if let Some(path) = mcp_scope_path(scope, project_hint) { + let mut file = read_scope_config(&path)?; + let server_context = format!("mcp.{scope}.servers"); + expand_mcp_servers(&mut file.servers, &server_context)?; + for server in file.servers { + scoped_servers.push(ScopedMcpServer { + scope, + config: server, + }); + } + + let resource_context = format!("mcp.{scope}.resources"); + expand_mcp_resources(&mut file.resources, &resource_context)?; + for resource in file.resources { + scoped_resources.push(ScopedMcpResource { + scope, + config: resource, + }); + } + } + } + + let mut effective_servers = Vec::new(); + let mut seen_servers = HashSet::new(); + for scope in McpConfigScope::precedence_iter() { + for entry in scoped_servers.iter().filter(|entry| entry.scope == scope) { + if seen_servers.insert(entry.config.name.clone()) { + effective_servers.push(entry.config.clone()); + } + } + } + + let mut effective_resources = Vec::new(); + let mut seen_resources: HashSet<(String, String)> = HashSet::new(); + for scope in McpConfigScope::precedence_iter() { + for entry in scoped_resources.iter().filter(|entry| entry.scope == scope) { + let key = (entry.config.server.clone(), entry.config.uri.clone()); + if seen_resources.insert(key) { + effective_resources.push(entry.config.clone()); + } + } + } + + self.scoped_mcp_servers = scoped_servers; + self.effective_mcp_servers = effective_servers; + self.scoped_mcp_resources = scoped_resources; + self.effective_mcp_resources = effective_resources; + Ok(()) + } + + /// Return the merged MCP servers using scope precedence (local > project > user). + pub fn effective_mcp_servers(&self) -> &[McpServerConfig] { + &self.effective_mcp_servers + } + + /// Return MCP servers annotated with their originating scope. + pub fn scoped_mcp_servers(&self) -> &[ScopedMcpServer] { + &self.scoped_mcp_servers + } + + /// Return merged MCP resources using scope precedence (local > project > user). + pub fn effective_mcp_resources(&self) -> &[McpResourceConfig] { + &self.effective_mcp_resources + } + + /// Return scoped MCP resources with their origin scope metadata. + pub fn scoped_mcp_resources(&self) -> &[ScopedMcpResource] { + &self.scoped_mcp_resources + } + + /// Locate a configured resource by server and URI. + pub fn find_resource(&self, server: &str, uri: &str) -> Option<&McpResourceConfig> { + self.effective_mcp_resources + .iter() + .find(|resource| resource.server == server && resource.uri == uri) + } + + /// Add or replace an MCP server definition within the specified scope. + pub fn add_mcp_server( + &mut self, + scope: McpConfigScope, + server: McpServerConfig, + project_hint: Option<&Path>, + ) -> Result<()> { + match scope { + McpConfigScope::User => { + self.mcp_servers + .retain(|existing| existing.name != server.name); + self.mcp_servers.push(server); + } + other => { + let path = mcp_scope_path(other, project_hint).ok_or_else(|| { + Error::Config(format!( + "Unable to resolve project root for MCP scope '{}'", + other + )) + })?; + let mut file = read_scope_config(&path)?; + file.servers.retain(|existing| existing.name != server.name); + file.servers.push(server); + write_scope_config(&path, &file)?; + } + } + + self.refresh_mcp_servers(project_hint)?; + Ok(()) + } + + /// Remove an MCP server from the given scope, or infer the scope if omitted. + pub fn remove_mcp_server( + &mut self, + scope: Option, + name: &str, + project_hint: Option<&Path>, + ) -> Result> { + let target_scope = if let Some(scope) = scope { + scope + } else { + self.refresh_mcp_servers(project_hint)?; + match self + .scoped_mcp_servers + .iter() + .find(|entry| entry.config.name == name) + { + Some(entry) => entry.scope, + None => return Ok(None), + } + }; + + let removed = match target_scope { + McpConfigScope::User => { + let before = self.mcp_servers.len(); + self.mcp_servers.retain(|entry| entry.name != name); + before != self.mcp_servers.len() + } + other => { + let path = mcp_scope_path(other, project_hint).ok_or_else(|| { + Error::Config(format!( + "Unable to resolve project root for MCP scope '{}'", + other + )) + })?; + let mut file = read_scope_config(&path)?; + let before = file.servers.len(); + file.servers.retain(|entry| entry.name != name); + if before == file.servers.len() { + false + } else { + write_scope_config(&path, &file)?; + true + } + } + }; + + if removed { + self.refresh_mcp_servers(project_hint)?; + Ok(Some(target_scope)) + } else { + Ok(None) + } + } + /// Validate configuration invariants and surface actionable error messages. pub fn validate(&self) -> Result<()> { self.validate_default_provider()?; @@ -284,9 +619,15 @@ impl Config { } fn validate_mcp_settings(&self) -> Result<()> { + let has_effective_servers = if self.effective_mcp_servers.is_empty() { + !self.mcp_servers.is_empty() + } else { + !self.effective_mcp_servers.is_empty() + }; + match self.mcp.mode { McpMode::RemoteOnly => { - if self.mcp_servers.is_empty() { + if !has_effective_servers { return Err(crate::Error::Config( "[mcp].mode = 'remote_only' requires at least one [[mcp_servers]] entry" .to_string(), @@ -294,7 +635,7 @@ impl Config { } } McpMode::RemotePreferred => { - if !self.mcp.allow_fallback && self.mcp_servers.is_empty() { + if !self.mcp.allow_fallback && !has_effective_servers { return Err(crate::Error::Config( "[mcp].allow_fallback = false requires at least one [[mcp_servers]] entry" .to_string(), @@ -313,26 +654,13 @@ impl Config { } fn validate_mcp_servers(&self) -> Result<()> { - for server in &self.mcp_servers { - if server.name.trim().is_empty() { - return Err(crate::Error::Config( - "Each [[mcp_servers]] entry must include a non-empty name".to_string(), - )); + if self.scoped_mcp_servers.is_empty() { + for server in &self.mcp_servers { + validate_mcp_server_entry(server, McpConfigScope::User)?; } - - if server.command.trim().is_empty() { - return Err(crate::Error::Config(format!( - "MCP server '{}' must define a command or endpoint", - server.name - ))); - } - - let transport = server.transport.to_lowercase(); - if !matches!(transport.as_str(), "stdio" | "http" | "websocket") { - return Err(crate::Error::Config(format!( - "Unknown MCP transport '{}' for server '{}'", - server.transport, server.name - ))); + } else { + for entry in &self.scoped_mcp_servers { + validate_mcp_server_entry(&entry.config, entry.scope)?; } } @@ -349,6 +677,58 @@ fn default_ollama_provider_config() -> ProviderConfig { } } +fn validate_mcp_server_entry(server: &McpServerConfig, scope: McpConfigScope) -> Result<()> { + if server.name.trim().is_empty() { + return Err(Error::Config(format!( + "Each MCP server entry must include a non-empty name (scope: {scope})" + ))); + } + + if server.command.trim().is_empty() { + return Err(Error::Config(format!( + "MCP server '{}' must define a command or endpoint (scope: {scope})", + server.name + ))); + } + + let transport = server.transport.to_lowercase(); + if !matches!(transport.as_str(), "stdio" | "http" | "websocket") { + return Err(Error::Config(format!( + "Unknown MCP transport '{}' for server '{}' (scope: {scope})", + server.transport, server.name + ))); + } + + if let Some(oauth) = &server.oauth { + if oauth.client_id.trim().is_empty() { + return Err(Error::Config(format!( + "MCP server '{}' defines OAuth without a client_id", + server.name + ))); + } + if oauth.authorize_url.trim().is_empty() { + return Err(Error::Config(format!( + "MCP server '{}' defines OAuth without an authorize_url", + server.name + ))); + } + if oauth.token_url.trim().is_empty() { + return Err(Error::Config(format!( + "MCP server '{}' defines OAuth without a token_url", + server.name + ))); + } + if oauth.device_authorization_url.is_none() && oauth.redirect_url.is_none() { + return Err(Error::Config(format!( + "MCP server '{}' must define either device_authorization_url or redirect_url for OAuth flows", + server.name + ))); + } + } + + Ok(()) +} + fn expand_provider_entry(provider_name: &str, provider: &mut ProviderConfig) -> Result<()> { if let Some(ref mut base_url) = provider.base_url { let expanded = expand_env_string( @@ -379,6 +759,136 @@ fn expand_provider_entry(provider_name: &str, provider: &mut ProviderConfig) -> Ok(()) } +fn expand_mcp_servers(servers: &mut [McpServerConfig], field_path: &str) -> Result<()> { + for (idx, server) in servers.iter_mut().enumerate() { + expand_mcp_server_entry(server, field_path, idx)?; + } + Ok(()) +} + +fn expand_mcp_server_entry( + server: &mut McpServerConfig, + field_path: &str, + index: usize, +) -> Result<()> { + server.command = expand_env_string( + server.command.as_str(), + &format!("{field_path}[{index}].command"), + )?; + + for (arg_idx, arg) in server.args.iter_mut().enumerate() { + *arg = expand_env_string( + arg.as_str(), + &format!("{field_path}[{index}].args[{arg_idx}]"), + )?; + } + + for (env_key, env_value) in server.env.iter_mut() { + *env_value = expand_env_string( + env_value.as_str(), + &format!("{field_path}[{index}].env.{env_key}"), + )?; + } + + if let Some(oauth) = server.oauth.as_mut() { + oauth.client_id = expand_env_string( + oauth.client_id.as_str(), + &format!("{field_path}[{index}].oauth.client_id"), + )?; + oauth.authorize_url = expand_env_string( + oauth.authorize_url.as_str(), + &format!("{field_path}[{index}].oauth.authorize_url"), + )?; + oauth.token_url = expand_env_string( + oauth.token_url.as_str(), + &format!("{field_path}[{index}].oauth.token_url"), + )?; + + if let Some(secret) = oauth.client_secret.as_mut() { + *secret = expand_env_string( + secret.as_str(), + &format!("{field_path}[{index}].oauth.client_secret"), + )?; + } + + if let Some(device_url) = oauth.device_authorization_url.as_mut() { + *device_url = expand_env_string( + device_url.as_str(), + &format!("{field_path}[{index}].oauth.device_authorization_url"), + )?; + } + + if let Some(redirect) = oauth.redirect_url.as_mut() { + *redirect = expand_env_string( + redirect.as_str(), + &format!("{field_path}[{index}].oauth.redirect_url"), + )?; + } + + if let Some(token_env) = oauth.token_env.as_mut() { + *token_env = expand_env_string( + token_env.as_str(), + &format!("{field_path}[{index}].oauth.token_env"), + )?; + } + + if let Some(header) = oauth.header.as_mut() { + *header = expand_env_string( + header.as_str(), + &format!("{field_path}[{index}].oauth.header"), + )?; + } + + if let Some(prefix) = oauth.header_prefix.as_mut() { + *prefix = expand_env_string( + prefix.as_str(), + &format!("{field_path}[{index}].oauth.header_prefix"), + )?; + } + + for (scope_idx, scope) in oauth.scopes.iter_mut().enumerate() { + *scope = expand_env_string( + scope.as_str(), + &format!("{field_path}[{index}].oauth.scopes[{scope_idx}]"), + )?; + } + } + + Ok(()) +} + +fn expand_mcp_resources(resources: &mut [McpResourceConfig], field_path: &str) -> Result<()> { + for (idx, resource) in resources.iter_mut().enumerate() { + expand_mcp_resource_entry(resource, field_path, idx)?; + } + Ok(()) +} + +fn expand_mcp_resource_entry( + resource: &mut McpResourceConfig, + field_path: &str, + index: usize, +) -> Result<()> { + resource.server = expand_env_string( + resource.server.as_str(), + &format!("{field_path}[{index}].server"), + )?; + resource.uri = expand_env_string(resource.uri.as_str(), &format!("{field_path}[{index}].uri"))?; + + if let Some(title) = resource.title.as_mut() { + *title = expand_env_string(title.as_str(), &format!("{field_path}[{index}].title"))?; + } + + if let Some(description) = resource.description.as_mut() { + *description = expand_env_string( + description.as_str(), + &format!("{field_path}[{index}].description"), + )?; + } + + Ok(()) +} + fn expand_env_string(input: &str, field_path: &str) -> Result { if !input.contains('$') { return Ok(input.to_string()); @@ -408,6 +918,106 @@ pub fn default_config_path() -> PathBuf { PathBuf::from(shellexpand::tilde(DEFAULT_CONFIG_PATH).as_ref()) } +#[derive(Serialize, Deserialize, Default, Clone)] +struct McpConfigFile { + #[serde(default)] + servers: Vec, + #[serde(default)] + resources: Vec, +} + +#[derive(Serialize, Deserialize)] +#[serde(untagged)] +enum McpConfigEnvelope { + Array(Vec), + Object(McpConfigFile), +} + +fn read_scope_config(path: &Path) -> Result { + if !path.exists() { + return Ok(McpConfigFile::default()); + } + + let contents = fs::read_to_string(path).map_err(Error::Io)?; + if contents.trim().is_empty() { + return Ok(McpConfigFile::default()); + } + + let doc: McpConfigEnvelope = serde_json::from_str(&contents).map_err(|err| { + Error::Config(format!( + "Failed to parse MCP configuration at {}: {err}", + path.display() + )) + })?; + + Ok(match doc { + McpConfigEnvelope::Array(servers) => McpConfigFile { + servers, + resources: Vec::new(), + }, + McpConfigEnvelope::Object(doc) => doc, + }) +} + +fn write_scope_config(path: &Path, file: &McpConfigFile) -> Result<()> { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).map_err(Error::Io)?; + } + + let serialized = serde_json::to_string_pretty(file).map_err(|err| { + Error::Config(format!( + "Failed to serialize MCP configuration for {}: {err}", + path.display() + )) + })?; + + fs::write(path, serialized).map_err(Error::Io) +} + +/// Resolve the configuration file path for a given scope. +pub fn mcp_scope_path(scope: McpConfigScope, project_hint: Option<&Path>) -> Option { + match scope { + McpConfigScope::User => dirs::config_dir() + .or_else(|| Some(PathBuf::from(shellexpand::tilde("~/.config").as_ref()))) + .map(|dir| dir.join("owlen").join("mcp.json")), + McpConfigScope::Project | McpConfigScope::Local => { + let root = project_hint + .map(PathBuf::from) + .or_else(|| discover_project_root(None))?; + + if matches!(scope, McpConfigScope::Project) { + Some(root.join(".mcp.json")) + } else { + Some(root.join(".owlen").join("mcp.local.json")) + } + } + } +} + +fn discover_project_root(start: Option<&Path>) -> Option { + let mut current = start + .map(PathBuf::from) + .or_else(|| std::env::current_dir().ok())?; + + loop { + if current.join(".mcp.json").exists() + || current.join(".owlen").exists() + || current.join(".git").exists() + || current.join("Cargo.toml").exists() + { + return Some(current); + } + + if !current.pop() { + break; + } + } + + start + .map(PathBuf::from) + .or_else(|| std::env::current_dir().ok()) +} + /// General behaviour settings shared across clients #[derive(Debug, Clone, Serialize, Deserialize)] pub struct GeneralSettings { @@ -1173,6 +1783,7 @@ mod tests { transport: "udp".into(), args: Vec::new(), env: std::collections::HashMap::new(), + oauth: None, }]; let result = config.validate(); assert!( @@ -1186,4 +1797,113 @@ mod tests { config.mcp.mode = McpMode::LocalOnly; assert!(config.validate().is_ok()); } + + #[test] + fn refresh_mcp_servers_merges_scopes_with_precedence() { + let temp = tempfile::tempdir().expect("tempdir"); + let project_root = temp.path(); + std::fs::write( + project_root.join(".mcp.json"), + r#"{ + "servers": [ + { "name": "shared", "command": "project-cmd", "transport": "stdio" }, + { "name": "project-only", "command": "proj", "transport": "stdio" } + ], + "resources": [ + { "server": "github", "uri": "issue://123", "title": "Project Issue" }, + { "server": "docs", "uri": "page://start", "title": "Project Doc" } + ] + }"#, + ) + .expect("write project scope"); + + let local_dir = project_root.join(".owlen"); + std::fs::create_dir_all(&local_dir).expect("local dir"); + std::fs::write( + local_dir.join("mcp.local.json"), + r#"{ + "servers": [ + { "name": "shared", "command": "local-cmd", "transport": "stdio" } + ], + "resources": [ + { "server": "github", "uri": "issue://123", "title": "Local Override" } + ] + }"#, + ) + .expect("write local scope"); + + let mut config = Config::default(); + config.mcp_servers.push(McpServerConfig { + name: "shared".into(), + command: "user-cmd".into(), + args: Vec::new(), + transport: "stdio".into(), + env: std::collections::HashMap::new(), + oauth: None, + }); + config.mcp_resources.push(McpResourceConfig { + server: "github".into(), + uri: "issue://123".into(), + title: Some("User Issue".into()), + description: None, + }); + + config + .refresh_mcp_servers(Some(project_root)) + .expect("refresh scopes"); + + // We should have four scoped entries (user + two project + local) and precedence should select local + assert_eq!(config.scoped_mcp_servers().len(), 4); + let effective = config.effective_mcp_servers(); + assert_eq!(effective.len(), 2); // shared + project-only + assert_eq!(effective[0].command, "local-cmd"); + assert_eq!(effective[0].name, "shared"); + + assert_eq!(config.scoped_mcp_resources().len(), 4); + let effective_resources = config.effective_mcp_resources(); + assert_eq!(effective_resources.len(), 2); + assert_eq!( + effective_resources + .iter() + .find(|res| res.server == "github") + .and_then(|res| res.title.as_deref()), + Some("Local Override") + ); + } + + #[test] + fn remove_mcp_server_reports_scope() { + let temp = tempfile::tempdir().expect("tempdir"); + let project_root = temp.path(); + std::fs::write( + project_root.join(".mcp.json"), + r#"{ "servers": [{ "name": "project", "command": "proj", "transport": "stdio" }] }"#, + ) + .expect("write project scope"); + + let mut config = Config::default(); + config.mcp_servers.push(McpServerConfig { + name: "user".into(), + command: "user".into(), + args: Vec::new(), + transport: "stdio".into(), + env: std::collections::HashMap::new(), + oauth: None, + }); + config + .refresh_mcp_servers(Some(project_root)) + .expect("refresh scopes"); + + // Remove without specifying scope should pick highest precedence (project) + let removed_scope = config + .remove_mcp_server(None, "project", Some(project_root)) + .expect("remove call"); + assert_eq!(removed_scope, Some(McpConfigScope::Project)); + + // Remove the remaining user scope explicitly + let removed_scope = config + .remove_mcp_server(Some(McpConfigScope::User), "user", Some(project_root)) + .expect("remove user"); + assert_eq!(removed_scope, Some(McpConfigScope::User)); + } } diff --git a/crates/owlen-core/src/credentials.rs b/crates/owlen-core/src/credentials.rs index 785b43e..8fa5fb4 100644 --- a/crates/owlen-core/src/credentials.rs +++ b/crates/owlen-core/src/credentials.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use serde::{Deserialize, Serialize}; -use crate::{Error, Result, storage::StorageManager}; +use crate::{Error, Result, oauth::OAuthToken, storage::StorageManager}; #[derive(Serialize, Deserialize, Debug)] pub struct ApiCredentials { @@ -31,6 +31,10 @@ impl CredentialManager { format!("{}_{}", self.namespace, tool_name) } + fn oauth_storage_key(&self, resource: &str) -> String { + self.namespaced_key(&format!("oauth_{resource}")) + } + pub async fn store_credentials( &self, tool_name: &str, @@ -68,4 +72,37 @@ impl CredentialManager { let key = self.namespaced_key(tool_name); self.storage.delete_secure_item(&key).await } + + pub async fn store_oauth_token(&self, resource: &str, token: &OAuthToken) -> Result<()> { + let key = self.oauth_storage_key(resource); + let payload = serde_json::to_vec(token).map_err(|err| { + Error::Storage(format!( + "Failed to serialize OAuth token for secure storage: {err}" + )) + })?; + self.storage + .store_secure_item(&key, &payload, &self.master_key) + .await + } + + pub async fn load_oauth_token(&self, resource: &str) -> Result> { + let key = self.oauth_storage_key(resource); + let raw = self + .storage + .load_secure_item(&key, &self.master_key) + .await?; + if let Some(bytes) = raw { + let token = serde_json::from_slice(&bytes).map_err(|err| { + Error::Storage(format!("Failed to deserialize stored OAuth token: {err}")) + })?; + Ok(Some(token)) + } else { + Ok(None) + } + } + + pub async fn delete_oauth_token(&self, resource: &str) -> Result<()> { + let key = self.oauth_storage_key(resource); + self.storage.delete_secure_item(&key).await + } } diff --git a/crates/owlen-core/src/lib.rs b/crates/owlen-core/src/lib.rs index 8b1692b..9507ddd 100644 --- a/crates/owlen-core/src/lib.rs +++ b/crates/owlen-core/src/lib.rs @@ -15,6 +15,7 @@ pub mod llm; pub mod mcp; pub mod mode; pub mod model; +pub mod oauth; pub mod providers; pub mod router; pub mod sandbox; @@ -36,6 +37,7 @@ pub use credentials::*; pub use encryption::*; pub use formatting::*; pub use input::*; +pub use oauth::*; // Export MCP types but exclude test_utils to avoid ambiguity pub use llm::{ ChatStream, LlmProvider, Provider, ProviderConfig, ProviderRegistry, send_via_stream, diff --git a/crates/owlen-core/src/mcp/factory.rs b/crates/owlen-core/src/mcp/factory.rs index a17a99e..2a66f66 100644 --- a/crates/owlen-core/src/mcp/factory.rs +++ b/crates/owlen-core/src/mcp/factory.rs @@ -3,7 +3,10 @@ /// Provides a unified interface for creating MCP clients based on configuration. /// Supports switching between local (in-process) and remote (STDIO) execution modes. use super::client::McpClient; -use super::{LocalMcpClient, remote_client::RemoteMcpClient}; +use super::{ + LocalMcpClient, + remote_client::{McpRuntimeSecrets, RemoteMcpClient}, +}; use crate::config::{Config, McpMode}; use crate::tools::registry::ToolRegistry; use crate::validation::SchemaValidator; @@ -33,6 +36,14 @@ impl McpClientFactory { /// Create an MCP client based on the current configuration. pub fn create(&self) -> Result> { + self.create_with_secrets(None) + } + + /// Create an MCP client using optional runtime secrets (OAuth tokens, env overrides). + pub fn create_with_secrets( + &self, + runtime: Option, + ) -> Result> { match self.config.mcp.mode { McpMode::Disabled => Err(Error::Config( "MCP mode is set to 'disabled'; tooling cannot function in this configuration." @@ -48,14 +59,14 @@ impl McpClientFactory { ))) } McpMode::RemoteOnly => { - let server_cfg = self.config.mcp_servers.first().ok_or_else(|| { + let server_cfg = self.config.effective_mcp_servers().first().ok_or_else(|| { Error::Config( "MCP mode 'remote_only' requires at least one entry in [[mcp_servers]]" .to_string(), ) })?; - RemoteMcpClient::new_with_config(server_cfg) + RemoteMcpClient::new_with_runtime(server_cfg, runtime) .map(|client| Box::new(client) as Box) .map_err(|e| { Error::Config(format!( @@ -65,8 +76,8 @@ impl McpClientFactory { }) } McpMode::RemotePreferred => { - if let Some(server_cfg) = self.config.mcp_servers.first() { - match RemoteMcpClient::new_with_config(server_cfg) { + if let Some(server_cfg) = self.config.effective_mcp_servers().first() { + match RemoteMcpClient::new_with_runtime(server_cfg, runtime.clone()) { Ok(client) => { info!( "Connected to remote MCP server '{}' via {} transport.", @@ -125,7 +136,8 @@ mod tests { #[test] fn test_factory_creates_local_client_when_no_servers_configured() { - let config = Config::default(); + let mut config = Config::default(); + config.refresh_mcp_servers(None).unwrap(); let factory = build_factory(config); @@ -139,6 +151,7 @@ mod tests { let mut config = Config::default(); config.mcp.mode = McpMode::RemoteOnly; config.mcp_servers.clear(); + config.refresh_mcp_servers(None).unwrap(); let factory = build_factory(config); let result = factory.create(); @@ -156,7 +169,9 @@ mod tests { args: Vec::new(), transport: "stdio".to_string(), env: std::collections::HashMap::new(), + oauth: None, }]; + config.refresh_mcp_servers(None).unwrap(); let factory = build_factory(config); let result = factory.create(); diff --git a/crates/owlen-core/src/mcp/failover.rs b/crates/owlen-core/src/mcp/failover.rs index 3acc794..9b806a9 100644 --- a/crates/owlen-core/src/mcp/failover.rs +++ b/crates/owlen-core/src/mcp/failover.rs @@ -305,6 +305,7 @@ mod tests { args: vec![], transport: "http".to_string(), env: std::collections::HashMap::new(), + oauth: None, }; if let Ok(client) = RemoteMcpClient::new_with_config(&config) { diff --git a/crates/owlen-core/src/mcp/remote_client.rs b/crates/owlen-core/src/mcp/remote_client.rs index f55b6cb..7cdf5e9 100644 --- a/crates/owlen-core/src/mcp/remote_client.rs +++ b/crates/owlen-core/src/mcp/remote_client.rs @@ -12,6 +12,7 @@ use anyhow::anyhow; use futures::{StreamExt, future::BoxFuture, stream}; use reqwest::Client as HttpClient; use serde_json::json; +use std::collections::HashMap; use std::path::Path; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; @@ -39,6 +40,15 @@ pub struct RemoteMcpClient { ws_endpoint: Option, // Incrementing request identifier. next_id: AtomicU64, + // Optional HTTP header (name, value) injected into every request. + http_header: Option<(String, String)>, +} + +/// Runtime secrets provided when constructing an MCP client. +#[derive(Debug, Default, Clone)] +pub struct McpRuntimeSecrets { + pub env_overrides: HashMap, + pub http_header: Option<(String, String)>, } impl RemoteMcpClient { @@ -48,6 +58,14 @@ impl RemoteMcpClient { /// Spawn an external MCP server based on a configuration entry. /// The server must communicate over STDIO (the only supported transport). pub fn new_with_config(config: &crate::config::McpServerConfig) -> Result { + Self::new_with_runtime(config, None) + } + + pub fn new_with_runtime( + config: &crate::config::McpServerConfig, + runtime: Option, + ) -> Result { + let mut runtime = runtime.unwrap_or_default(); let transport = config.transport.to_lowercase(); match transport.as_str() { "stdio" => { @@ -64,6 +82,9 @@ impl RemoteMcpClient { for (k, v) in config.env.iter() { cmd.env(k, v); } + for (k, v) in runtime.env_overrides.drain() { + cmd.env(k, v); + } let mut child = cmd.spawn().map_err(|e| { Error::Io(std::io::Error::new( @@ -92,6 +113,7 @@ impl RemoteMcpClient { ws_stream: None, ws_endpoint: None, next_id: AtomicU64::new(1), + http_header: None, }) } "http" => { @@ -109,6 +131,7 @@ impl RemoteMcpClient { ws_stream: None, ws_endpoint: None, next_id: AtomicU64::new(1), + http_header: runtime.http_header.take(), }) } "websocket" => { @@ -132,6 +155,7 @@ impl RemoteMcpClient { ws_stream: Some(Arc::new(Mutex::new(ws_stream))), ws_endpoint: Some(ws_url), next_id: AtomicU64::new(1), + http_header: runtime.http_header.take(), }) } other => Err(Error::NotImplemented(format!( @@ -171,6 +195,7 @@ impl RemoteMcpClient { args: Vec::new(), transport: "stdio".to_string(), env: std::collections::HashMap::new(), + oauth: None, }; Self::new_with_config(&config) } @@ -193,8 +218,11 @@ impl RemoteMcpClient { .http_endpoint .as_ref() .ok_or_else(|| Error::Network("Missing HTTP endpoint".into()))?; - let resp = client - .post(endpoint) + let mut builder = client.post(endpoint); + if let Some((ref header_name, ref header_value)) = self.http_header { + builder = builder.header(header_name, header_value); + } + let resp = builder .json(&request) .send() .await diff --git a/crates/owlen-core/src/oauth.rs b/crates/owlen-core/src/oauth.rs new file mode 100644 index 0000000..56be6e3 --- /dev/null +++ b/crates/owlen-core/src/oauth.rs @@ -0,0 +1,507 @@ +use std::time::Duration as StdDuration; + +use chrono::{DateTime, Duration, Utc}; +use reqwest::Client; +use serde::{Deserialize, Serialize}; + +use crate::{Error, Result, config::McpOAuthConfig}; + +/// Persisted OAuth token set for MCP servers and providers. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)] +pub struct OAuthToken { + /// Bearer access token returned by the authorization server. + pub access_token: String, + /// Optional refresh token if the provider issues one. + #[serde(default)] + pub refresh_token: Option, + /// Absolute UTC expiration timestamp for the access token. + #[serde(default)] + pub expires_at: Option>, + /// Optional space-delimited scope string supplied by the provider. + #[serde(default)] + pub scope: Option, + /// Token type reported by the provider (typically `Bearer`). + #[serde(default)] + pub token_type: Option, +} + +impl OAuthToken { + /// Returns `true` if the access token has expired at the provided instant. + pub fn is_expired(&self, now: DateTime) -> bool { + matches!(self.expires_at, Some(expiry) if now >= expiry) + } + + /// Returns `true` if the token will expire within the supplied duration window. + pub fn will_expire_within(&self, window: Duration, now: DateTime) -> bool { + matches!(self.expires_at, Some(expiry) if expiry - now <= window) + } +} + +/// Active device-authorization session details returned by the authorization server. +#[derive(Debug, Clone)] +pub struct DeviceAuthorization { + pub device_code: String, + pub user_code: String, + pub verification_uri: String, + pub verification_uri_complete: Option, + pub expires_at: DateTime, + pub interval: StdDuration, + pub message: Option, +} + +impl DeviceAuthorization { + pub fn is_expired(&self, now: DateTime) -> bool { + now >= self.expires_at + } +} + +/// Result of polling the token endpoint during a device-authorization flow. +#[derive(Debug, Clone)] +pub enum DevicePollState { + Pending { retry_in: StdDuration }, + Complete(OAuthToken), +} + +pub struct OAuthClient { + http: Client, + config: McpOAuthConfig, +} + +impl OAuthClient { + pub fn new(config: McpOAuthConfig) -> Result { + let http = Client::builder() + .user_agent("OwlenOAuth/1.0") + .build() + .map_err(|err| Error::Network(format!("Failed to construct HTTP client: {err}")))?; + Ok(Self { http, config }) + } + + fn scope_value(&self) -> Option { + if self.config.scopes.is_empty() { + None + } else { + Some(self.config.scopes.join(" ")) + } + } + + fn token_request_base(&self) -> Vec<(String, String)> { + let mut params = vec![("client_id".to_string(), self.config.client_id.clone())]; + if let Some(secret) = &self.config.client_secret { + params.push(("client_secret".to_string(), secret.clone())); + } + params + } + + pub async fn start_device_authorization(&self) -> Result { + let device_url = self + .config + .device_authorization_url + .as_ref() + .ok_or_else(|| { + Error::Config("Device authorization endpoint is not configured.".to_string()) + })?; + + let mut params = self.token_request_base(); + if let Some(scope) = self.scope_value() { + params.push(("scope".to_string(), scope)); + } + + let response = self + .http + .post(device_url) + .form(¶ms) + .send() + .await + .map_err(|err| map_http_error("start device authorization", err))?; + + let status = response.status(); + let payload = response + .json::() + .await + .map_err(|err| { + Error::Auth(format!( + "Failed to parse device authorization response (status {status}): {err}" + )) + })?; + + let expires_at = + Utc::now() + Duration::seconds(payload.expires_in.min(i64::MAX as u64) as i64); + let interval = StdDuration::from_secs(payload.interval.unwrap_or(5).max(1)); + + Ok(DeviceAuthorization { + device_code: payload.device_code, + user_code: payload.user_code, + verification_uri: payload.verification_uri, + verification_uri_complete: payload.verification_uri_complete, + expires_at, + interval, + message: payload.message, + }) + } + + pub async fn poll_device_token(&self, auth: &DeviceAuthorization) -> Result { + let mut params = self.token_request_base(); + params.push(("grant_type".to_string(), DEVICE_CODE_GRANT.to_string())); + params.push(("device_code".to_string(), auth.device_code.clone())); + if let Some(scope) = self.scope_value() { + params.push(("scope".to_string(), scope)); + } + + let response = self + .http + .post(&self.config.token_url) + .form(¶ms) + .send() + .await + .map_err(|err| map_http_error("poll device token", err))?; + + let status = response.status(); + let text = response + .text() + .await + .map_err(|err| map_http_error("read token response", err))?; + + if status.is_success() { + let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| { + Error::Auth(format!( + "Failed to parse OAuth token response: {err}; body: {text}" + )) + })?; + return Ok(DevicePollState::Complete(oauth_token_from_response( + payload, + ))); + } + + let error = serde_json::from_str::(&text).unwrap_or_else(|_| { + OAuthErrorResponse { + error: "unknown_error".to_string(), + error_description: Some(text.clone()), + } + }); + + match error.error.as_str() { + "authorization_pending" => Ok(DevicePollState::Pending { + retry_in: auth.interval, + }), + "slow_down" => Ok(DevicePollState::Pending { + retry_in: auth.interval.saturating_add(StdDuration::from_secs(5)), + }), + "access_denied" => { + Err(Error::Auth(error.error_description.unwrap_or_else(|| { + "User declined authorization".to_string() + }))) + } + "expired_token" | "expired_device_code" => { + Err(Error::Auth(error.error_description.unwrap_or_else(|| { + "Device authorization expired".to_string() + }))) + } + other => Err(Error::Auth( + error + .error_description + .unwrap_or_else(|| format!("OAuth error: {other}")), + )), + } + } + + pub async fn refresh_token(&self, refresh_token: &str) -> Result { + let mut params = self.token_request_base(); + params.push(("grant_type".to_string(), "refresh_token".to_string())); + params.push(("refresh_token".to_string(), refresh_token.to_string())); + if let Some(scope) = self.scope_value() { + params.push(("scope".to_string(), scope)); + } + + let response = self + .http + .post(&self.config.token_url) + .form(¶ms) + .send() + .await + .map_err(|err| map_http_error("refresh OAuth token", err))?; + + let status = response.status(); + let text = response + .text() + .await + .map_err(|err| map_http_error("read refresh response", err))?; + + if status.is_success() { + let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| { + Error::Auth(format!( + "Failed to parse OAuth refresh response: {err}; body: {text}" + )) + })?; + Ok(oauth_token_from_response(payload)) + } else { + let error = serde_json::from_str::(&text).unwrap_or_else(|_| { + OAuthErrorResponse { + error: "unknown_error".to_string(), + error_description: Some(text.clone()), + } + }); + Err(Error::Auth(error.error_description.unwrap_or_else(|| { + format!("OAuth token refresh failed: {}", error.error) + }))) + } + } +} + +const DEVICE_CODE_GRANT: &str = "urn:ietf:params:oauth:grant-type:device_code"; + +#[derive(Debug, Deserialize)] +struct DeviceAuthorizationResponse { + device_code: String, + user_code: String, + verification_uri: String, + #[serde(default)] + verification_uri_complete: Option, + expires_in: u64, + #[serde(default)] + interval: Option, + #[serde(default)] + message: Option, +} + +#[derive(Debug, Deserialize)] +struct TokenResponse { + access_token: String, + #[serde(default)] + refresh_token: Option, + #[serde(default)] + expires_in: Option, + #[serde(default)] + scope: Option, + #[serde(default)] + token_type: Option, +} + +#[derive(Debug, Deserialize)] +struct OAuthErrorResponse { + error: String, + #[serde(default)] + error_description: Option, +} + +fn oauth_token_from_response(payload: TokenResponse) -> OAuthToken { + let expires_at = payload + .expires_in + .map(|seconds| seconds.min(i64::MAX as u64) as i64) + .map(|seconds| Utc::now() + Duration::seconds(seconds)); + + OAuthToken { + access_token: payload.access_token, + refresh_token: payload.refresh_token, + expires_at, + scope: payload.scope, + token_type: payload.token_type, + } +} + +fn map_http_error(action: &str, err: reqwest::Error) -> Error { + if err.is_timeout() { + Error::Timeout(format!("OAuth {action} request timed out: {err}")) + } else if err.is_connect() { + Error::Network(format!("OAuth {action} connection error: {err}")) + } else { + Error::Network(format!("OAuth {action} request failed: {err}")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use httpmock::prelude::*; + use serde_json::json; + + fn config_for(server: &MockServer) -> McpOAuthConfig { + McpOAuthConfig { + client_id: "test-client".to_string(), + client_secret: None, + authorize_url: server.url("/authorize"), + token_url: server.url("/token"), + device_authorization_url: Some(server.url("/device")), + redirect_url: None, + scopes: vec!["repo".to_string(), "user".to_string()], + token_env: None, + header: None, + header_prefix: None, + } + } + + fn sample_device_authorization() -> DeviceAuthorization { + DeviceAuthorization { + device_code: "device-123".to_string(), + user_code: "ABCD-EFGH".to_string(), + verification_uri: "https://example.test/activate".to_string(), + verification_uri_complete: Some( + "https://example.test/activate?user_code=ABCD-EFGH".to_string(), + ), + expires_at: Utc::now() + Duration::minutes(10), + interval: StdDuration::from_secs(5), + message: Some("Open the verification URL and enter the code.".to_string()), + } + } + + #[tokio::test] + async fn start_device_authorization_returns_payload() { + let server = MockServer::start_async().await; + let device_mock = server + .mock_async(|when, then| { + when.method(POST).path("/device"); + then.status(200) + .header("content-type", "application/json") + .json_body(json!({ + "device_code": "device-123", + "user_code": "ABCD-EFGH", + "verification_uri": "https://example.test/activate", + "verification_uri_complete": "https://example.test/activate?user_code=ABCD-EFGH", + "expires_in": 600, + "interval": 7, + "message": "Open the verification URL and enter the code." + })); + }) + .await; + + let client = OAuthClient::new(config_for(&server)).expect("client"); + let auth = client + .start_device_authorization() + .await + .expect("device authorization payload"); + + assert_eq!(auth.user_code, "ABCD-EFGH"); + assert_eq!(auth.interval, StdDuration::from_secs(7)); + assert!(auth.expires_at > Utc::now()); + device_mock.assert_async().await; + } + + #[tokio::test] + async fn poll_device_token_reports_pending() { + let server = MockServer::start_async().await; + let pending = server + .mock_async(|when, then| { + when.method(POST) + .path("/token") + .body_contains( + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code", + ) + .body_contains("device_code=device-123"); + then.status(400) + .header("content-type", "application/json") + .json_body(json!({ + "error": "authorization_pending" + })); + }) + .await; + + let config = config_for(&server); + let client = OAuthClient::new(config).expect("client"); + let auth = sample_device_authorization(); + + let result = client.poll_device_token(&auth).await.expect("poll result"); + match result { + DevicePollState::Pending { retry_in } => { + assert_eq!(retry_in, StdDuration::from_secs(5)); + } + other => panic!("expected pending state, got {other:?}"), + } + + pending.assert_async().await; + } + + #[tokio::test] + async fn poll_device_token_applies_slow_down_backoff() { + let server = MockServer::start_async().await; + let slow = server + .mock_async(|when, then| { + when.method(POST).path("/token"); + then.status(400) + .header("content-type", "application/json") + .json_body(json!({ + "error": "slow_down" + })); + }) + .await; + + let config = config_for(&server); + let client = OAuthClient::new(config).expect("client"); + let auth = sample_device_authorization(); + + let result = client.poll_device_token(&auth).await.expect("poll result"); + match result { + DevicePollState::Pending { retry_in } => { + assert_eq!(retry_in, StdDuration::from_secs(10)); + } + other => panic!("expected pending state, got {other:?}"), + } + + slow.assert_async().await; + } + + #[tokio::test] + async fn poll_device_token_returns_token_when_authorized() { + let server = MockServer::start_async().await; + let token = server + .mock_async(|when, then| { + when.method(POST).path("/token"); + then.status(200) + .header("content-type", "application/json") + .json_body(json!({ + "access_token": "token-abc", + "refresh_token": "refresh-xyz", + "expires_in": 3600, + "token_type": "Bearer", + "scope": "repo user" + })); + }) + .await; + + let config = config_for(&server); + let client = OAuthClient::new(config).expect("client"); + let auth = sample_device_authorization(); + + let result = client.poll_device_token(&auth).await.expect("poll result"); + let token_info = match result { + DevicePollState::Complete(token) => token, + other => panic!("expected completion, got {other:?}"), + }; + + assert_eq!(token_info.access_token, "token-abc"); + assert_eq!(token_info.refresh_token.as_deref(), Some("refresh-xyz")); + assert!(token_info.expires_at.is_some()); + token.assert_async().await; + } + + #[tokio::test] + async fn refresh_token_roundtrip() { + let server = MockServer::start_async().await; + let refresh = server + .mock_async(|when, then| { + when.method(POST) + .path("/token") + .body_contains("grant_type=refresh_token") + .body_contains("refresh_token=old-refresh"); + then.status(200) + .header("content-type", "application/json") + .json_body(json!({ + "access_token": "token-new", + "refresh_token": "refresh-new", + "expires_in": 1200, + "token_type": "Bearer" + })); + }) + .await; + + let config = config_for(&server); + let client = OAuthClient::new(config).expect("client"); + let token = client + .refresh_token("old-refresh") + .await + .expect("refresh response"); + + assert_eq!(token.access_token, "token-new"); + assert_eq!(token.refresh_token.as_deref(), Some("refresh-new")); + assert!(token.expires_at.is_some()); + refresh.assert_async().await; + } +} diff --git a/crates/owlen-core/src/session.rs b/crates/owlen-core/src/session.rs index 62ada6f..bbdf1a9 100644 --- a/crates/owlen-core/src/session.rs +++ b/crates/owlen-core/src/session.rs @@ -1,4 +1,4 @@ -use crate::config::Config; +use crate::config::{Config, McpResourceConfig, McpServerConfig}; use crate::consent::ConsentManager; use crate::conversation::ConversationManager; use crate::credentials::CredentialManager; @@ -9,8 +9,10 @@ use crate::mcp::McpToolCall; use crate::mcp::client::McpClient; use crate::mcp::factory::McpClientFactory; use crate::mcp::permission::PermissionLayer; +use crate::mcp::remote_client::{McpRuntimeSecrets, RemoteMcpClient}; use crate::mode::Mode; use crate::model::{DetailedModelInfo, ModelManager}; +use crate::oauth::{DeviceAuthorization, DevicePollState, OAuthClient}; use crate::providers::OllamaProvider; use crate::storage::{SessionMeta, StorageManager}; use crate::types::{ @@ -24,8 +26,10 @@ use crate::{ ToolRegistry, WebScrapeTool, WebSearchDetailedTool, WebSearchTool, }; use crate::{Error, Result}; +use chrono::Utc; use log::warn; -use serde_json::Value; +use serde_json::{Value, json}; +use std::collections::HashMap; use std::env; use std::path::PathBuf; use std::sync::{Arc, Mutex}; @@ -96,6 +100,7 @@ pub struct SessionController { tool_registry: Arc, schema_validator: Arc, mcp_client: Arc, + named_mcp_clients: HashMap>, storage: Arc, vault: Option>>, master_key: Option>>, @@ -103,6 +108,7 @@ pub struct SessionController { ui: Arc, enable_code_tools: bool, current_mode: Mode, + missing_oauth_servers: Vec, } async fn build_tools( @@ -211,6 +217,112 @@ async fn build_tools( } impl SessionController { + async fn create_mcp_clients( + config: Arc>, + tool_registry: Arc, + schema_validator: Arc, + credential_manager: Option>, + initial_mode: Mode, + ) -> Result<( + Arc, + HashMap>, + Vec, + )> { + let guard = config.lock().await; + let config_arc = Arc::new(guard.clone()); + let factory = McpClientFactory::new(config_arc.clone(), tool_registry, schema_validator); + + let mut missing_oauth_servers = Vec::new(); + let primary_runtime = if let Some(primary_cfg) = guard.effective_mcp_servers().first() { + let (runtime, missing) = + Self::runtime_secrets_for_server(credential_manager.clone(), primary_cfg).await?; + if missing { + missing_oauth_servers.push(primary_cfg.name.clone()); + } + runtime + } else { + None + }; + + let base_client = factory.create_with_secrets(primary_runtime)?; + let primary: Arc = + Arc::new(PermissionLayer::new(base_client, config_arc.clone())); + primary.set_mode(initial_mode).await?; + + let mut clients: HashMap> = HashMap::new(); + if let Some(primary_cfg) = guard.effective_mcp_servers().first() { + clients.insert(primary_cfg.name.clone(), Arc::clone(&primary)); + } + + for server_cfg in guard.effective_mcp_servers().iter().skip(1) { + let (runtime, missing) = + Self::runtime_secrets_for_server(credential_manager.clone(), server_cfg).await?; + if missing { + missing_oauth_servers.push(server_cfg.name.clone()); + } + + match RemoteMcpClient::new_with_runtime(server_cfg, runtime) { + Ok(remote) => { + let client: Arc = + Arc::new(PermissionLayer::new(Box::new(remote), config_arc.clone())); + if let Err(err) = client.set_mode(initial_mode).await { + warn!( + "Failed to initialize MCP server '{}' in mode {:?}: {}", + server_cfg.name, initial_mode, err + ); + } + clients.insert(server_cfg.name.clone(), Arc::clone(&client)); + } + Err(err) => warn!( + "Failed to initialize MCP server '{}': {}", + server_cfg.name, err + ), + } + } + + drop(guard); + + Ok((primary, clients, missing_oauth_servers)) + } + + async fn runtime_secrets_for_server( + credential_manager: Option>, + server: &McpServerConfig, + ) -> Result<(Option, bool)> { + if let Some(oauth) = &server.oauth { + if let Some(manager) = credential_manager { + match manager.load_oauth_token(&server.name).await? { + Some(token) => { + if token.access_token.trim().is_empty() || token.is_expired(Utc::now()) { + return Ok((None, true)); + } + let mut secrets = McpRuntimeSecrets::default(); + if let Some(env_name) = oauth.token_env.as_deref() { + secrets + .env_overrides + .insert(env_name.to_string(), token.access_token.clone()); + } + if matches!( + server.transport.to_ascii_lowercase().as_str(), + "http" | "websocket" + ) { + let header_value = + format!("{}{}", oauth.header_prefix(), token.access_token); + secrets.http_header = + Some((oauth.header_name().to_string(), header_value)); + } + Ok((Some(secrets), false)) + } + None => Ok((None, true)), + } + } else { + Ok((None, true)) + } + } else { + Ok((None, false)) + } + } + pub async fn new( provider: Arc, config: Config, @@ -292,19 +404,14 @@ impl SessionController { ) .await?; - // Create MCP client with permission layer - let mcp_client: Arc = { - let guard = config_arc.lock().await; - let factory = McpClientFactory::new( - Arc::new(guard.clone()), - tool_registry.clone(), - schema_validator.clone(), - ); - let base_client = factory.create()?; - let client = Arc::new(PermissionLayer::new(base_client, Arc::new(guard.clone()))); - client.set_mode(initial_mode).await?; - client - }; + let (mcp_client, named_mcp_clients, missing_oauth_servers) = Self::create_mcp_clients( + config_arc.clone(), + tool_registry.clone(), + schema_validator.clone(), + credential_manager.clone(), + initial_mode, + ) + .await?; Ok(Self { provider, @@ -317,6 +424,7 @@ impl SessionController { tool_registry, schema_validator, mcp_client, + named_mcp_clients, storage, vault: vault_handle, master_key, @@ -324,6 +432,7 @@ impl SessionController { ui, enable_code_tools, current_mode: initial_mode, + missing_oauth_servers, }) } @@ -355,6 +464,63 @@ impl SessionController { self.formatter.set_role_label_mode(mode); } + /// Return the configured resource references aggregated across scopes. + pub async fn configured_resources(&self) -> Vec { + let guard = self.config.lock().await; + guard.effective_mcp_resources().to_vec() + } + + /// Resolve a resource reference of the form `server:uri` (optionally prefixed with `@`). + pub async fn resolve_resource_reference(&self, reference: &str) -> Result> { + let (server, uri) = match Self::split_resource_reference(reference) { + Some(parts) => parts, + None => return Ok(None), + }; + + let resource_defined = { + let guard = self.config.lock().await; + guard.find_resource(&server, &uri).is_some() + }; + + if !resource_defined { + return Ok(None); + } + + let client = self + .named_mcp_clients + .get(&server) + .cloned() + .ok_or_else(|| { + Error::Config(format!( + "MCP server '{}' referenced by resource '{}' is not available", + server, uri + )) + })?; + + let call = McpToolCall { + name: "resources/get".to_string(), + arguments: json!({ "uri": uri, "path": uri }), + }; + let response = client.call_tool(call).await?; + if let Some(text) = extract_resource_content(&response.output) { + return Ok(Some(text)); + } + + let formatted = serde_json::to_string_pretty(&response.output) + .unwrap_or_else(|_| response.output.to_string()); + Ok(Some(formatted)) + } + + fn split_resource_reference(reference: &str) -> Option<(String, String)> { + let trimmed = reference.trim(); + let without_prefix = trimmed.strip_prefix('@').unwrap_or(trimmed); + let (server, uri) = without_prefix.split_once(':')?; + if server.is_empty() || uri.is_empty() { + return None; + } + Some((server.to_string(), uri.to_string())) + } + // Asynchronous access to the configuration (used internally). pub async fn config_async(&self) -> tokio::sync::MutexGuard<'_, Config> { self.config.lock().await @@ -378,6 +544,21 @@ impl SessionController { self.config.clone() } + pub async fn reload_mcp_clients(&mut self) -> Result<()> { + let (primary, named, missing) = Self::create_mcp_clients( + self.config.clone(), + self.tool_registry.clone(), + self.schema_validator.clone(), + self.credential_manager.clone(), + self.current_mode, + ) + .await?; + self.mcp_client = primary; + self.named_mcp_clients = named; + self.missing_oauth_servers = missing; + Ok(()) + } + pub fn grant_consent(&self, tool_name: &str, data_types: Vec, endpoints: Vec) { let mut consent = self .consent_manager @@ -525,6 +706,115 @@ impl SessionController { self.schema_validator.clone() } + pub fn credential_manager(&self) -> Option> { + self.credential_manager.clone() + } + + pub fn pending_oauth_servers(&self) -> Vec { + self.missing_oauth_servers.clone() + } + + pub async fn start_oauth_device_flow(&self, server: &str) -> Result { + let oauth_config = { + let config = self.config.lock().await; + let server_cfg = config + .effective_mcp_servers() + .iter() + .find(|entry| entry.name == server) + .ok_or_else(|| { + Error::Config(format!("No MCP server named '{server}' is configured")) + })?; + server_cfg.oauth.clone().ok_or_else(|| { + Error::Config(format!( + "MCP server '{server}' does not define an OAuth configuration" + )) + })? + }; + + let client = OAuthClient::new(oauth_config)?; + client.start_device_authorization().await + } + + pub async fn poll_oauth_device_flow( + &mut self, + server: &str, + authorization: &DeviceAuthorization, + ) -> Result { + let oauth_config = { + let config = self.config.lock().await; + let server_cfg = config + .effective_mcp_servers() + .iter() + .find(|entry| entry.name == server) + .ok_or_else(|| { + Error::Config(format!("No MCP server named '{server}' is configured")) + })?; + server_cfg.oauth.clone().ok_or_else(|| { + Error::Config(format!( + "MCP server '{server}' does not define an OAuth configuration" + )) + })? + }; + + let client = OAuthClient::new(oauth_config)?; + match client.poll_device_token(authorization).await? { + DevicePollState::Pending { retry_in } => Ok(DevicePollState::Pending { retry_in }), + DevicePollState::Complete(token) => { + let manager = self.credential_manager.as_ref().cloned().ok_or_else(|| { + Error::Config( + "OAuth token storage requires encrypted local data; set \ + privacy.encrypt_local_data = true in the configuration." + .to_string(), + ) + })?; + + manager.store_oauth_token(server, &token).await?; + self.missing_oauth_servers.retain(|entry| entry != server); + + Ok(DevicePollState::Complete(token)) + } + } + } + + pub async fn list_mcp_tools(&self) -> Vec<(String, crate::mcp::McpToolDescriptor)> { + let mut entries = Vec::new(); + for (server, client) in self.named_mcp_clients.iter() { + let server_name = server.clone(); + let client = Arc::clone(client); + match client.list_tools().await { + Ok(tools) => { + for descriptor in tools { + entries.push((server_name.clone(), descriptor)); + } + } + Err(err) => { + warn!( + "Failed to list tools for MCP server '{}': {}", + server_name, err + ); + } + } + } + entries + } + + pub async fn call_mcp_tool( + &self, + server: &str, + tool: &str, + arguments: Value, + ) -> Result { + let client = self.named_mcp_clients.get(server).cloned().ok_or_else(|| { + Error::Config(format!("No MCP server named '{}' is registered", server)) + })?; + client + .call_tool(McpToolCall { + name: tool.to_string(), + arguments, + }) + .await + } + pub fn mcp_server(&self) -> crate::mcp::McpServer { crate::mcp::McpServer::new(self.tool_registry(), self.schema_validator()) } @@ -985,3 +1275,195 @@ impl SessionController { Ok("Empty conversation".to_string()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::Provider; + use crate::config::{Config, McpMode, McpOAuthConfig, McpServerConfig}; + use crate::llm::test_utils::MockProvider; + use crate::storage::StorageManager; + use crate::ui::NoOpUiController; + use chrono::Utc; + use httpmock::prelude::*; + use serde_json::json; + use std::collections::HashMap; + use std::sync::Arc; + use tempfile::tempdir; + + const SERVER_NAME: &str = "oauth-test"; + + fn build_oauth_config(server: &MockServer) -> McpOAuthConfig { + McpOAuthConfig { + client_id: "owlen-client".to_string(), + client_secret: None, + authorize_url: server.url("/authorize"), + token_url: server.url("/token"), + device_authorization_url: Some(server.url("/device")), + redirect_url: None, + scopes: vec!["repo".to_string()], + token_env: Some("OAUTH_TOKEN".to_string()), + header: Some("Authorization".to_string()), + header_prefix: Some("Bearer ".to_string()), + } + } + + fn build_config(server: &MockServer) -> Config { + let mut config = Config::default(); + config.mcp.mode = McpMode::LocalOnly; + let oauth = build_oauth_config(server); + + let mut env = HashMap::new(); + env.insert("OWLEN_ENV".to_string(), "test".to_string()); + + config.mcp_servers = vec![McpServerConfig { + name: SERVER_NAME.to_string(), + command: server.url("/mcp"), + args: Vec::new(), + transport: "http".to_string(), + env, + oauth: Some(oauth), + }]; + + config.refresh_mcp_servers(None).unwrap(); + config + } + + async fn build_session(server: &MockServer) -> (SessionController, tempfile::TempDir) { + unsafe { + std::env::set_var("OWLEN_MASTER_PASSWORD", "test-password"); + } + + let temp_dir = tempdir().expect("tempdir"); + let storage_path = temp_dir.path().join("owlen.db"); + let storage = Arc::new( + StorageManager::with_database_path(storage_path) + .await + .expect("storage"), + ); + + let config = build_config(server); + let provider: Arc = Arc::new(MockProvider::default()) as Arc; + let ui = Arc::new(NoOpUiController); + + let session = SessionController::new(provider, config, storage, ui, false) + .await + .expect("session"); + + (session, temp_dir) + } + + #[tokio::test] + async fn start_oauth_device_flow_returns_details() { + let server = MockServer::start_async().await; + let device = server + .mock_async(|when, then| { + when.method(POST).path("/device"); + then.status(200) + .header("content-type", "application/json") + .json_body(json!({ + "device_code": "device-abc", + "user_code": "ABCD-1234", + "verification_uri": "https://example.test/activate", + "verification_uri_complete": "https://example.test/activate?user_code=ABCD-1234", + "expires_in": 600, + "interval": 5, + "message": "Enter the code to continue." + })); + }) + .await; + + let (session, _dir) = build_session(&server).await; + let authorization = session + .start_oauth_device_flow(SERVER_NAME) + .await + .expect("device flow"); + + assert_eq!(authorization.user_code, "ABCD-1234"); + assert_eq!( + authorization.verification_uri_complete.as_deref(), + Some("https://example.test/activate?user_code=ABCD-1234") + ); + assert!(authorization.expires_at > Utc::now()); + device.assert_async().await; + } + + #[tokio::test] + async fn poll_oauth_device_flow_stores_token_and_updates_state() { + let server = MockServer::start_async().await; + + let device = server + .mock_async(|when, then| { + when.method(POST).path("/device"); + then.status(200) + .header("content-type", "application/json") + .json_body(json!({ + "device_code": "device-xyz", + "user_code": "WXYZ-9999", + "verification_uri": "https://example.test/activate", + "verification_uri_complete": "https://example.test/activate?user_code=WXYZ-9999", + "expires_in": 600, + "interval": 5 + })); + }) + .await; + + let token = server + .mock_async(|when, then| { + when.method(POST) + .path("/token") + .body_contains("device_code=device-xyz"); + then.status(200) + .header("content-type", "application/json") + .json_body(json!({ + "access_token": "new-access-token", + "refresh_token": "refresh-token", + "expires_in": 3600, + "token_type": "Bearer" + })); + }) + .await; + + let (mut session, _dir) = build_session(&server).await; + assert_eq!(session.pending_oauth_servers(), vec![SERVER_NAME]); + + let authorization = session + .start_oauth_device_flow(SERVER_NAME) + .await + .expect("device flow"); + + match session + .poll_oauth_device_flow(SERVER_NAME, &authorization) + .await + .expect("token poll") + { + DevicePollState::Complete(token_info) => { + assert_eq!(token_info.access_token, "new-access-token"); + assert_eq!(token_info.refresh_token.as_deref(), Some("refresh-token")); + } + other => panic!("expected token completion, got {other:?}"), + } + + assert!( + session + .pending_oauth_servers() + .iter() + .all(|entry| entry != SERVER_NAME), + "server should be removed from pending list" + ); + + let stored = session + .credential_manager() + .expect("credential manager") + .load_oauth_token(SERVER_NAME) + .await + .expect("load token") + .expect("token present"); + + assert_eq!(stored.access_token, "new-access-token"); + assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token")); + + device.assert_async().await; + token.assert_async().await; + } +} diff --git a/crates/owlen-core/tests/prompt_server.rs b/crates/owlen-core/tests/prompt_server.rs index 56797dc..dbc16b0 100644 --- a/crates/owlen-core/tests/prompt_server.rs +++ b/crates/owlen-core/tests/prompt_server.rs @@ -44,6 +44,7 @@ async fn test_render_prompt_via_external_server() -> Result<()> { args: Vec::new(), transport: "stdio".into(), env: std::collections::HashMap::new(), + oauth: None, }; let client = match RemoteMcpClient::new_with_config(&config) { diff --git a/crates/owlen-mcp-client/src/lib.rs b/crates/owlen-mcp-client/src/lib.rs index 07708ad..f5afb7c 100644 --- a/crates/owlen-mcp-client/src/lib.rs +++ b/crates/owlen-mcp-client/src/lib.rs @@ -5,6 +5,7 @@ //! crates can depend only on `owlen-mcp-client` without pulling in the entire //! core crate internals. +pub use owlen_core::config::{McpConfigScope, ScopedMcpServer}; pub use owlen_core::mcp::remote_client::RemoteMcpClient; pub use owlen_core::mcp::{McpClient, McpToolCall, McpToolDescriptor, McpToolResponse}; diff --git a/crates/owlen-tui/src/chat_app.rs b/crates/owlen-tui/src/chat_app.rs index 0677911..dca7f78 100644 --- a/crates/owlen-tui/src/chat_app.rs +++ b/crates/owlen-tui/src/chat_app.rs @@ -1,10 +1,13 @@ use anyhow::{Context, Result, anyhow}; -use chrono::{DateTime, Local}; +use chrono::{DateTime, Local, Utc}; use crossterm::terminal::{disable_raw_mode, enable_raw_mode}; use owlen_core::mcp::remote_client::RemoteMcpClient; +use owlen_core::mcp::{McpToolDescriptor, McpToolResponse}; use owlen_core::{ Provider, ProviderConfig, + config::McpResourceConfig, model::DetailedModelInfo, + oauth::{DeviceAuthorization, DevicePollState}, session::{SessionController, SessionOutcome}, storage::SessionMeta, theme::Theme, @@ -19,7 +22,7 @@ use tokio::{ sync::mpsc, task::{self, JoinHandle}, }; -use tui_textarea::{Input, TextArea}; +use tui_textarea::{CursorMove, Input, TextArea}; use unicode_width::UnicodeWidthStr; use uuid::Uuid; @@ -27,12 +30,14 @@ use crate::commands; use crate::config; use crate::events::Event; use crate::model_info_panel::ModelInfoPanel; +use crate::slash::{self, McpSlashCommand, SlashCommand}; use crate::state::{ CodeWorkspace, CommandPalette, FileFilterMode, FileIconResolver, FileNode, FileTreeState, ModelPaletteEntry, PaletteSuggestion, PaneDirection, PaneRestoreRequest, RepoSearchMessage, RepoSearchState, SplitAxis, SymbolSearchMessage, SymbolSearchState, WorkspaceSnapshot, spawn_repo_search_task, spawn_symbol_search_task, }; +use crate::toast::{Toast, ToastLevel, ToastManager}; use crate::ui::format_tool_output; // Agent executor moved to separate binary `owlen-agent`. The TUI no longer directly // imports `AgentExecutor` to avoid a circular dependency on `owlen-cli`. @@ -48,6 +53,7 @@ use std::sync::Arc; use std::time::{Duration, Instant, SystemTime}; use dirs::{config_dir, data_local_dir}; +use serde_json::{Value, json}; const ONBOARDING_STATUS_LINE: &str = "Welcome to Owlen! Press F1 for help or type :tutorial for keybinding tips."; @@ -61,6 +67,13 @@ const RESIZE_DOUBLE_TAP_WINDOW: Duration = Duration::from_millis(450); const RESIZE_STEP: f32 = 0.05; const RESIZE_SNAP_VALUES: [f32; 3] = [0.5, 0.75, 0.25]; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum SlashOutcome { + NotCommand, + Consumed, + Error, +} + #[derive(Clone, Debug)] pub(crate) struct ModelSelectorItem { kind: ModelSelectorItemKind, @@ -158,6 +171,11 @@ pub enum SessionEvent { AgentCompleted { answer: String }, /// Agent execution failed AgentFailed { error: String }, + /// Poll the OAuth device authorization flow for the given server + OAuthPoll { + server: String, + authorization: DeviceAuthorization, + }, } pub const HELP_TAB_COUNT: usize = 7; @@ -205,6 +223,9 @@ pub struct ChatApp { clipboard: String, // Vim-style clipboard for yank/paste pending_file_action: Option, // Active file action prompt command_palette: CommandPalette, // Command mode state (buffer + suggestions) + resource_catalog: Vec, // Configured MCP resources for autocompletion + pending_resource_refs: Vec, // Resource references to resolve before send + oauth_flows: HashMap, // Active OAuth device flows by server repo_search: RepoSearchState, // Repository search overlay state repo_search_task: Option>, repo_search_rx: Option>, @@ -235,6 +256,7 @@ pub struct ChatApp { selected_theme_index: usize, // Index of selected theme in browser pending_consent: Option, // Pending consent request system_status: String, // System/status messages (tool execution, status, etc) + toasts: ToastManager, /// Simple execution budget: maximum number of tool calls allowed per session. _execution_budget: usize, /// Agent mode enabled @@ -438,6 +460,9 @@ impl ChatApp { clipboard: String::new(), pending_file_action: None, command_palette: CommandPalette::new(), + resource_catalog: Vec::new(), + pending_resource_refs: Vec::new(), + oauth_flows: HashMap::new(), repo_search: RepoSearchState::new(), repo_search_task: None, repo_search_rx: None, @@ -472,6 +497,7 @@ impl ChatApp { } else { String::new() }, + toasts: ToastManager::new(), _execution_budget: 50, agent_mode: false, agent_running: false, @@ -490,6 +516,8 @@ impl ChatApp { )); app.update_command_palette_catalog(); + app.refresh_resource_catalog().await?; + app.refresh_mcp_slash_commands().await?; if let Err(err) = app.restore_workspace_layout().await { eprintln!("Warning: failed to restore workspace layout: {err}"); @@ -1371,6 +1399,18 @@ impl ChatApp { &self.theme } + pub fn toasts(&self) -> impl Iterator { + self.toasts.iter() + } + + pub fn push_toast(&mut self, level: ToastLevel, message: impl Into) { + self.toasts.push(message, level); + } + + fn prune_toasts(&mut self) { + self.toasts.retain_active(); + } + pub fn input_max_rows(&self) -> u16 { let config = self.controller.config(); config.ui.input_max_rows.max(1) @@ -1443,6 +1483,304 @@ impl ChatApp { .update_dynamic_sources(models, providers); } + async fn refresh_resource_catalog(&mut self) -> Result<()> { + let mut resources = self.controller.configured_resources().await; + resources.sort_by(|a, b| a.server.cmp(&b.server).then(a.uri.cmp(&b.uri))); + self.resource_catalog = resources; + Ok(()) + } + + async fn refresh_mcp_slash_commands(&mut self) -> Result<()> { + let mut commands = Vec::new(); + for (server, descriptor) in self.controller.list_mcp_tools().await { + if !Self::tool_supports_slash(&descriptor) { + continue; + } + let description = if descriptor.description.trim().is_empty() { + None + } else { + Some(descriptor.description.clone()) + }; + commands.push(McpSlashCommand::new( + server, + descriptor.name.clone(), + description, + )); + } + slash::set_mcp_commands(commands); + Ok(()) + } + + fn tool_supports_slash(descriptor: &McpToolDescriptor) -> bool { + if descriptor.name.trim().is_empty() { + return false; + } + Self::tool_allows_empty_arguments(&descriptor.input_schema) + } + + fn tool_allows_empty_arguments(schema: &Value) -> bool { + match schema { + Value::Object(map) => { + if let Some(Value::Array(required)) = map.get("required") { + !required + .iter() + .any(|entry| entry.as_str().is_some_and(|s| !s.is_empty())) + } else { + true + } + } + _ => true, + } + } + + fn format_mcp_slash_message(server: &str, tool: &str, response: &McpToolResponse) -> String { + let status = if response.success { "✓" } else { "✗" }; + let payload = if response.success { + Self::extract_mcp_primary_text(&response.output) + } else { + Self::extract_mcp_error(&response.output) + .or_else(|| Self::extract_mcp_primary_text(&response.output)) + } + .unwrap_or_else(|| Self::pretty_print_value(&response.output)); + + if payload.trim().is_empty() { + return format!("MCP {server}::{tool} {status}"); + } + + if payload.contains('\n') { + format!("MCP {server}::{tool} {status}\n```json\n{payload}\n```") + } else { + format!("MCP {server}::{tool} {status}\n{payload}") + } + } + + fn extract_mcp_primary_text(value: &Value) -> Option { + if let Some(text) = value.as_str().filter(|text| !text.trim().is_empty()) { + return Some(text.to_string()); + } + + if let Value::Object(map) = value { + const CANDIDATES: [&str; 6] = + ["rendered", "text", "content", "value", "message", "body"]; + for key in CANDIDATES { + if let Some(Value::String(text)) = map.get(key) + && !text.trim().is_empty() + { + return Some(text.clone()); + } + } + + if let Some(Value::Array(items)) = map.get("lines") { + let mut collected = Vec::new(); + for item in items { + if let Some(segment) = item.as_str() + && !segment.trim().is_empty() + { + collected.push(segment.trim()); + } + } + if !collected.is_empty() { + return Some(collected.join("\n")); + } + } + } + + None + } + + fn extract_mcp_error(value: &Value) -> Option { + if let Value::Object(map) = value + && let Some(Value::String(message)) = map.get("error") + && !message.trim().is_empty() + { + return Some(message.clone()); + } + None + } + + fn pretty_print_value(value: &Value) -> String { + serde_json::to_string_pretty(value).unwrap_or_else(|_| value.to_string()) + } + + async fn resolve_pending_resource_references(&mut self) -> Result<()> { + if self.pending_resource_refs.is_empty() { + return Ok(()); + } + + let mut resolved = 0usize; + let references: Vec = self.pending_resource_refs.drain(..).collect(); + for reference in references { + match self.controller.resolve_resource_reference(&reference).await { + Ok(Some(content)) => { + let message = format!("Resource @{}:\n{}", reference, content); + self.controller + .conversation_mut() + .push_system_message(message); + resolved += 1; + } + Ok(None) => { + self.push_toast( + ToastLevel::Warning, + format!( + "Resource @{} is not defined in the current project.", + reference + ), + ); + } + Err(err) => { + self.push_toast( + ToastLevel::Error, + format!("Failed to load resource @{}: {}", reference, err), + ); + } + } + } + + if resolved > 0 { + self.status = format!("Inserted {resolved} resource snippet(s)."); + } + + Ok(()) + } + + fn complete_resource_reference(&mut self) -> bool { + if self.resource_catalog.is_empty() { + return false; + } + + let (row, col) = self.textarea.cursor(); + let lines = self.textarea.lines().to_vec(); + if row >= lines.len() { + return false; + } + + let line = &lines[row]; + let chars: Vec = line.chars().collect(); + if col > chars.len() { + return false; + } + + let mut start = col; + while start > 0 { + let ch = chars[start - 1]; + if ch == '@' { + start -= 1; + break; + } + if ch.is_whitespace() { + return false; + } + start -= 1; + } + + if start >= col || chars.get(start) != Some(&'@') { + return false; + } + + if chars[start + 1..col].iter().any(|ch| ch.is_whitespace()) { + return false; + } + + let mut end = col; + while end < chars.len() { + let ch = chars[end]; + if ch.is_whitespace() { + break; + } + end += 1; + } + + let typed_prefix: String = chars[start + 1..col].iter().collect(); + let trailing_segment: String = chars[col..end].iter().collect(); + let lower_prefix = typed_prefix.to_ascii_lowercase(); + let lower_full = format!("{}{}", typed_prefix, trailing_segment).to_ascii_lowercase(); + + let mut matches: Vec<&McpResourceConfig> = self + .resource_catalog + .iter() + .filter(|resource| { + let reference = format!("{}:{}", resource.server, resource.uri); + let lower_reference = reference.to_ascii_lowercase(); + lower_reference.starts_with(&lower_full) + || lower_reference.starts_with(&lower_prefix) + || resource + .title + .as_ref() + .map(|title| title.to_ascii_lowercase().starts_with(&lower_prefix)) + .unwrap_or(false) + }) + .collect(); + + if matches.is_empty() { + return false; + } + + matches.sort_by(|a, b| a.server.cmp(&b.server).then(a.uri.cmp(&b.uri))); + let (selected_server, selected_uri, selected_title) = { + let selected = matches[0]; + ( + selected.server.clone(), + selected.uri.clone(), + selected.title.clone(), + ) + }; + let replacement = format!("@{}:{}", selected_server, selected_uri); + + let mut new_line = String::new(); + new_line.extend(chars[..start].iter()); + new_line.push_str(&replacement); + new_line.extend(chars[end..].iter()); + + let mut new_lines = lines; + new_lines[row] = new_line; + self.textarea = TextArea::new(new_lines); + configure_textarea_defaults(&mut self.textarea); + + let new_col = start + replacement.len(); + self.textarea + .move_cursor(CursorMove::Jump(row as u16, new_col as u16)); + + self.sync_textarea_to_buffer(); + + if let Some(title) = selected_title.as_deref() { + self.status = format!("Inserted resource {} ({title}).", replacement); + } else { + self.status = format!("Inserted resource {}.", replacement); + } + self.error = None; + + true + } + + fn extract_resource_references(text: &str) -> Vec { + let mut references = Vec::new(); + let mut current = String::new(); + let mut in_reference = false; + + for ch in text.chars() { + if in_reference { + if ch.is_whitespace() || matches!(ch, ',' | ';' | ')' | '(' | '.' | '!' | '?') { + if current.contains(':') { + references.push(current.clone()); + } + current.clear(); + in_reference = false; + } else { + current.push(ch); + } + } else if ch == '@' { + in_reference = true; + current.clear(); + } + } + + if in_reference && current.contains(':') { + references.push(current); + } + + references + } + fn display_name_for_model(model: &ModelInfo) -> String { if model.name.trim().is_empty() { model.id.clone() @@ -2110,6 +2448,204 @@ impl ChatApp { configure_textarea_defaults(&mut self.textarea); } + async fn process_slash_submission(&mut self) -> Result { + let raw = self.controller.input_buffer().text().to_string(); + if raw.trim().is_empty() { + return Ok(SlashOutcome::NotCommand); + } + + match slash::parse(&raw) { + Ok(None) => Ok(SlashOutcome::NotCommand), + Ok(Some(command)) => match self.execute_slash_command(command).await { + Ok(()) => { + self.input_buffer_mut().push_history_entry(raw.clone()); + self.controller.input_buffer_mut().clear(); + Ok(SlashOutcome::Consumed) + } + Err(err) => { + self.error = Some(err.to_string()); + self.status = "Slash command failed".to_string(); + self.controller.input_buffer_mut().set_text(raw); + Ok(SlashOutcome::Error) + } + }, + Err(err) => { + self.error = Some(err.to_string()); + self.status = "Slash command error".to_string(); + Ok(SlashOutcome::Error) + } + } + } + + async fn execute_slash_command(&mut self, command: SlashCommand) -> Result<()> { + match command { + SlashCommand::Summarize { count } => { + let prompt = if let Some(count) = count { + format!( + "Summarize the last {count} messages in this conversation. Highlight key decisions, open questions, and follow-up tasks." + ) + } else { + "Summarize the conversation so far, calling out major decisions, blockers, and immediate next steps.".to_string() + }; + self.status = "Summarizing conversation...".to_string(); + self.dispatch_user_prompt(prompt); + } + SlashCommand::Explain { snippet } => { + let prompt = format!( + "Explain the following code snippet. Cover what it does and call out any potential issues or improvements:\n```\n{}\n```", + snippet + ); + self.status = "Explaining snippet...".to_string(); + self.dispatch_user_prompt(prompt); + } + SlashCommand::Refactor { path } => { + let trimmed = path.trim(); + if trimmed.is_empty() { + anyhow::bail!("usage: /refactor "); + } + let source = self.controller.read_file(trimmed).await?; + let prompt = format!( + "Refactor the file `{}`. Provide specific improvements for readability, safety, and maintainability. Include updated code where relevant.\n\n```text\n{}\n```", + trimmed, source + ); + self.status = format!("Refactor review for {trimmed}..."); + self.dispatch_user_prompt(prompt); + } + SlashCommand::TestPlan => { + let prompt = "Generate a comprehensive test plan for this repository. Outline critical test suites, coverage gaps, and prioritized steps to reach confident automation.".to_string(); + self.status = "Generating test plan...".to_string(); + self.dispatch_user_prompt(prompt); + } + SlashCommand::Compact => { + let prompt = "Compress our conversation history to its essentials. Summarize previous exchanges, preserve critical context, and indicate what state can be safely forgotten.".to_string(); + self.status = "Compacting conversation...".to_string(); + self.dispatch_user_prompt(prompt); + } + SlashCommand::McpTool { server, tool } => { + self.status = format!("Running MCP tool {server}::{tool}..."); + let response = self + .controller + .call_mcp_tool(&server, &tool, json!({})) + .await + .map_err(|err| { + anyhow!("Failed to invoke MCP tool {}::{}: {}", server, tool, err) + })?; + + let content = Self::format_mcp_slash_message(&server, &tool, &response); + self.controller + .conversation_mut() + .push_system_message(content); + self.auto_scroll.stick_to_bottom = true; + self.new_message_alert = true; + + if response.success { + self.status = format!("MCP {server}::{tool} result added to chat."); + self.push_toast(ToastLevel::Info, format!("MCP {server}::{tool} completed.")); + } else { + self.status = format!("MCP {server}::{tool} reported an error (see chat)."); + self.push_toast( + ToastLevel::Warning, + format!("MCP {server}::{tool} reported an error."), + ); + } + self.error = None; + } + } + Ok(()) + } + + fn schedule_oauth_poll( + &self, + server: String, + authorization: DeviceAuthorization, + delay: Duration, + ) { + let sender = self.session_tx.clone(); + tokio::spawn(async move { + tokio::time::sleep(delay).await; + let _ = sender.send(SessionEvent::OAuthPoll { + server, + authorization, + }); + }); + } + + async fn start_oauth_login(&mut self, server: &str) -> Result<()> { + if self.oauth_flows.contains_key(server) { + self.error = Some(format!("OAuth flow for '{server}' is already in progress.")); + return Ok(()); + } + + let authorization = match self.controller.start_oauth_device_flow(server).await { + Ok(auth) => auth, + Err(err) => { + self.error = Some(format!("Failed to start OAuth for '{server}': {err}")); + return Ok(()); + } + }; + + self.oauth_flows + .insert(server.to_string(), authorization.clone()); + + let link = authorization + .verification_uri_complete + .clone() + .unwrap_or_else(|| authorization.verification_uri.clone()); + let status = format!( + "Authorize '{server}' via {} (code {}).", + link, authorization.user_code + ); + self.status = status; + self.error = None; + + let mut message = format!( + "OAuth authorization required for `{server}`.\nVisit:\n{}\nEnter code: `{}`", + link, authorization.user_code + ); + if let Some(hint) = &authorization.message + && !hint.trim().is_empty() + { + message.push_str("\n\n"); + message.push_str(hint); + } + if authorization.expires_at > Utc::now() { + message.push_str(&format!( + "\n\nThis code expires at {}.", + authorization + .expires_at + .to_rfc3339_opts(chrono::SecondsFormat::Secs, true) + )); + } + + self.controller + .conversation_mut() + .push_system_message(message); + self.auto_scroll.stick_to_bottom = true; + self.notify_new_activity(); + + self.push_toast( + ToastLevel::Warning, + format!("Authorize {server}: code {}", authorization.user_code), + ); + + let delay = authorization.interval; + self.schedule_oauth_poll(server.to_string(), authorization.clone(), delay); + Ok(()) + } + + fn dispatch_user_prompt(&mut self, prompt: String) { + if prompt.trim().is_empty() { + self.error = Some("Slash command generated an empty request".to_string()); + return; + } + + self.controller.conversation_mut().push_user_message(prompt); + self.auto_scroll.stick_to_bottom = true; + self.pending_llm_request = true; + self.set_system_status(String::new()); + self.error = None; + } + fn set_code_view_content( &mut self, display_path: impl Into, @@ -2216,14 +2752,14 @@ impl ChatApp { Ok(()) } - async fn restore_workspace_layout(&mut self) -> Result<()> { + async fn restore_workspace_layout(&mut self) -> Result { let path = match self.workspace_layout_path() { Ok(path) => path, - Err(_) => return Ok(()), + Err(_) => return Ok(false), }; if !path.exists() { - return Ok(()); + return Ok(false); } let contents = fs::read_to_string(&path) @@ -2247,7 +2783,7 @@ impl ChatApp { self.status = "Workspace layout restored".to_string(); } - Ok(()) + Ok(restored_any) } fn direction_label(direction: PaneDirection) -> &'static str { @@ -3289,6 +3825,7 @@ impl ChatApp { Event::Tick => { self.poll_repo_search(); self.poll_symbol_search(); + self.prune_toasts(); // Future: update streaming timers } Event::Resize(width, height) => { @@ -4172,13 +4709,24 @@ impl ChatApp { self.textarea.insert_newline(); } (KeyCode::Enter, KeyModifiers::NONE) => { - // Send message and return to normal mode self.sync_textarea_to_buffer(); - self.send_user_message_and_request_response(); - // Clear the textarea by setting it to empty - self.textarea = TextArea::default(); - configure_textarea_defaults(&mut self.textarea); - self.set_input_mode(InputMode::Normal); + match self.process_slash_submission().await? { + SlashOutcome::NotCommand => { + self.send_user_message_and_request_response(); + self.textarea = TextArea::default(); + configure_textarea_defaults(&mut self.textarea); + self.set_input_mode(InputMode::Normal); + } + SlashOutcome::Consumed => { + self.textarea = TextArea::default(); + configure_textarea_defaults(&mut self.textarea); + self.set_input_mode(InputMode::Normal); + } + SlashOutcome::Error => { + // Restore textarea content so the user can correct the command + self.sync_buffer_to_textarea(); + } + } } (KeyCode::Enter, _) => { // Any Enter with modifiers keeps editing and inserts a newline via tui-textarea @@ -4208,6 +4756,11 @@ impl ChatApp { self.textarea .move_cursor(tui_textarea::CursorMove::WordBack); } + (KeyCode::Tab, m) if m.is_empty() => { + if !self.complete_resource_reference() { + self.textarea.input(Input::from(key)); + } + } (KeyCode::Char('r'), m) if m.contains(KeyModifiers::CONTROL) => { // Redo - history next self.input_buffer_mut().history_next(); @@ -4538,6 +5091,31 @@ impl ChatApp { } } } + "oauth" => { + if args.is_empty() { + let pending = self.controller.pending_oauth_servers(); + if pending.is_empty() { + self.status = + "No OAuth-enabled MCP servers require authorization." + .to_string(); + } else { + self.status = format!( + "Pending OAuth servers: {}", + pending.join(", ") + ); + } + self.error = None; + } else if args.len() == 1 { + self.start_oauth_login(args[0]).await?; + } else if args.len() == 2 + && args[0].eq_ignore_ascii_case("login") + { + self.start_oauth_login(args[1]).await?; + } else { + self.error = + Some("Usage: :oauth [login] ".to_string()); + } + } "load" | "o" => { // Load saved sessions and enter browser mode match self.controller.list_saved_sessions().await { @@ -5015,29 +5593,58 @@ impl ChatApp { if self.code_workspace.tabs().is_empty() { self.status = "No open panes to save".to_string(); + self.error = None; + self.push_toast( + ToastLevel::Warning, + "Open a pane before saving layout.", + ); } else { self.persist_workspace_layout(); self.status = "Workspace layout saved".to_string(); self.error = None; + self.push_toast( + ToastLevel::Success, + "Workspace layout saved.", + ); } } "load" => match self.restore_workspace_layout().await { - Ok(()) => { + Ok(true) => { self.status = "Workspace layout restored".to_string(); self.error = None; + self.push_toast( + ToastLevel::Success, + "Workspace layout restored.", + ); + } + Ok(false) => { + self.status = + "No saved layout to restore".to_string(); + self.error = None; + self.push_toast( + ToastLevel::Info, + "No saved layout was found.", + ); } Err(err) => { - self.error = Some(err.to_string()); + let message = format!( + "Failed to restore workspace layout: {}", + err + ); + self.error = Some(message.clone()); self.status = "Failed to restore workspace layout" .to_string(); + self.push_toast(ToastLevel::Error, message); } }, other => { + self.status = + format!("Unknown layout command: {other}"); self.error = Some(format!( - "Unknown layout command: {other}" + "Unknown layout subcommand: {other}" )); } } @@ -5068,6 +5675,27 @@ impl ChatApp { self.error = None; self.sync_ui_preferences_from_config(); self.update_command_palette_catalog(); + if let Err(err) = self.refresh_resource_catalog().await + { + self.push_toast( + ToastLevel::Error, + format!( + "Failed to refresh MCP resources: {}", + err + ), + ); + } + if let Err(err) = + self.refresh_mcp_slash_commands().await + { + self.push_toast( + ToastLevel::Error, + format!( + "Failed to refresh MCP slash commands: {}", + err + ), + ); + } } Err(e) => { self.error = @@ -5666,7 +6294,7 @@ impl ChatApp { } } - pub fn handle_session_event(&mut self, event: SessionEvent) -> Result<()> { + pub async fn handle_session_event(&mut self, event: SessionEvent) -> Result<()> { match event { SessionEvent::StreamChunk { message_id, @@ -5760,6 +6388,52 @@ impl ChatApp { self.agent_actions = None; self.stop_loading_animation(); } + SessionEvent::OAuthPoll { + server, + authorization, + } => { + match self + .controller + .poll_oauth_device_flow(&server, &authorization) + .await + { + Ok(DevicePollState::Pending { retry_in }) => { + self.oauth_flows + .insert(server.clone(), authorization.clone()); + let server_name = server.clone(); + self.schedule_oauth_poll(server, authorization, retry_in); + self.status = format!("Waiting for OAuth approval for {server_name}..."); + } + Ok(DevicePollState::Complete(_token)) => { + self.oauth_flows.remove(&server); + self.push_toast( + ToastLevel::Success, + format!("OAuth authorization complete for {server}."), + ); + self.status = format!("OAuth authorization complete for {server}."); + if let Err(err) = self.refresh_resource_catalog().await { + self.push_toast( + ToastLevel::Error, + format!("Failed to refresh MCP resources: {err}"), + ); + } + if let Err(err) = self.refresh_mcp_slash_commands().await { + self.push_toast( + ToastLevel::Error, + format!("Failed to refresh MCP slash commands: {err}"), + ); + } + } + Err(err) => { + self.oauth_flows.remove(&server); + self.error = Some(format!("OAuth flow for '{server}' failed: {err}")); + self.push_toast( + ToastLevel::Error, + format!("OAuth failure for {server}: {err}"), + ); + } + } + } } Ok(()) } @@ -5825,6 +6499,7 @@ impl ChatApp { args: Vec::new(), transport: "stdio".to_string(), env: env_vars.clone(), + oauth: None, }; RemoteMcpClient::new_with_config(&config) } else { @@ -6176,6 +6851,7 @@ impl ChatApp { args: Vec::new(), transport: "stdio".to_string(), env: env_vars, + oauth: None, }; Arc::new(RemoteMcpClient::new_with_config(&config)?) } else { @@ -6423,6 +7099,10 @@ impl ChatApp { // Step 1: Add user message to conversation immediately (synchronous) let message = self.controller.input_buffer_mut().commit_to_history(); + let mut references = Self::extract_resource_references(&message); + references.sort(); + references.dedup(); + self.pending_resource_refs = references; self.controller .conversation_mut() .push_user_message(message.clone()); @@ -6539,6 +7219,8 @@ impl ChatApp { self.pending_llm_request = false; + self.resolve_pending_resource_references().await?; + // Check if agent mode is enabled if self.agent_mode { return self.process_agent_request().await; diff --git a/crates/owlen-tui/src/code_app.rs b/crates/owlen-tui/src/code_app.rs index 690dbb6..137026b 100644 --- a/crates/owlen-tui/src/code_app.rs +++ b/crates/owlen-tui/src/code_app.rs @@ -28,8 +28,8 @@ impl CodeApp { self.inner.handle_event(event).await } - pub fn handle_session_event(&mut self, event: SessionEvent) -> Result<()> { - self.inner.handle_session_event(event) + pub async fn handle_session_event(&mut self, event: SessionEvent) -> Result<()> { + self.inner.handle_session_event(event).await } pub fn mode(&self) -> InputMode { diff --git a/crates/owlen-tui/src/commands/mod.rs b/crates/owlen-tui/src/commands/mod.rs index e7e318b..b852d99 100644 --- a/crates/owlen-tui/src/commands/mod.rs +++ b/crates/owlen-tui/src/commands/mod.rs @@ -235,7 +235,7 @@ pub fn match_score(candidate: &str, query: &str) -> Option<(usize, usize)> { if candidate_normalized == query_normalized { Some((0, candidate.len())) } else if candidate_normalized.starts_with(&query_normalized) { - Some((1, candidate.len())) + Some((1, 0)) } else if let Some(pos) = candidate_normalized.find(&query_normalized) { Some((2, pos)) } else if is_subsequence(&candidate_normalized, &query_normalized) { diff --git a/crates/owlen-tui/src/lib.rs b/crates/owlen-tui/src/lib.rs index ad437b3..80be015 100644 --- a/crates/owlen-tui/src/lib.rs +++ b/crates/owlen-tui/src/lib.rs @@ -18,7 +18,9 @@ pub mod commands; pub mod config; pub mod events; pub mod model_info_panel; +pub mod slash; pub mod state; +pub mod toast; pub mod tui_controller; pub mod ui; diff --git a/crates/owlen-tui/src/slash.rs b/crates/owlen-tui/src/slash.rs new file mode 100644 index 0000000..3112344 --- /dev/null +++ b/crates/owlen-tui/src/slash.rs @@ -0,0 +1,238 @@ +//! Slash command parsing for chat input. +//! +//! Provides lightweight handling for inline commands such as `/summarize` +//! and `/testplan`. The parser returns owned data so callers can prepare +//! requests immediately without additional lifetime juggling. + +use std::collections::HashMap; +use std::fmt; +use std::str::FromStr; +use std::sync::{OnceLock, RwLock}; + +/// Supported slash commands. +#[derive(Debug, Clone)] +pub enum SlashCommand { + Summarize { count: Option }, + Explain { snippet: String }, + Refactor { path: String }, + TestPlan, + Compact, + McpTool { server: String, tool: String }, +} + +/// Errors emitted when parsing invalid slash input. +#[derive(Debug)] +pub enum SlashError { + UnknownCommand(String), + Message(String), +} + +impl fmt::Display for SlashError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SlashError::UnknownCommand(cmd) => write!(f, "unknown slash command: {cmd}"), + SlashError::Message(msg) => f.write_str(msg), + } + } +} + +impl std::error::Error for SlashError {} + +#[derive(Debug, Clone)] +pub struct McpSlashCommand { + pub server: String, + pub tool: String, + pub keyword: String, + pub description: Option, +} + +impl McpSlashCommand { + pub fn new( + server: impl Into, + tool: impl Into, + description: Option, + ) -> Self { + let server = server.into(); + let tool = tool.into(); + let keyword = format!( + "mcp__{}__{}", + canonicalize_component(&server), + canonicalize_component(&tool) + ); + Self { + server, + tool, + keyword, + description, + } + } +} + +static MCP_COMMANDS: OnceLock>> = OnceLock::new(); + +fn dynamic_registry() -> &'static RwLock> { + MCP_COMMANDS.get_or_init(|| RwLock::new(HashMap::new())) +} + +pub fn set_mcp_commands(commands: impl IntoIterator) { + let registry = dynamic_registry(); + let mut guard = registry.write().expect("MCP command registry poisoned"); + guard.clear(); + for command in commands { + guard.insert(command.keyword.clone(), command); + } +} + +fn find_mcp_command(keyword: &str) -> Option { + let registry = dynamic_registry(); + let guard = registry.read().expect("MCP command registry poisoned"); + guard.get(keyword).cloned() +} + +fn canonicalize_component(input: &str) -> String { + let mut out = String::new(); + let mut last_was_underscore = false; + for ch in input.chars() { + let mapped = if ch.is_ascii_alphanumeric() { + ch.to_ascii_lowercase() + } else { + '_' + }; + if mapped == '_' { + if !last_was_underscore { + out.push('_'); + last_was_underscore = true; + } + } else { + out.push(mapped); + last_was_underscore = false; + } + } + if out.is_empty() { "_".to_string() } else { out } +} + +/// Attempt to parse a slash command from the provided input. +pub fn parse(input: &str) -> Result, SlashError> { + let trimmed = input.trim(); + if !trimmed.starts_with('/') { + return Ok(None); + } + + let body = trimmed.trim_start_matches('/'); + if body.is_empty() { + return Err(SlashError::Message("missing command name after '/'".into())); + } + + let mut parts = body.split_whitespace(); + let command = parts.next().unwrap(); + let remainder = parts.collect::>(); + + if let Some(dynamic) = find_mcp_command(command) { + if !remainder.is_empty() { + return Err(SlashError::Message(format!( + "/{} does not accept arguments", + dynamic.keyword + ))); + } + return Ok(Some(SlashCommand::McpTool { + server: dynamic.server, + tool: dynamic.tool, + })); + } + + let cmd = match command { + "summarize" => { + let count = remainder + .first() + .and_then(|value| usize::from_str(value).ok()); + SlashCommand::Summarize { count } + } + "explain" => { + if remainder.is_empty() { + return Err(SlashError::Message( + "usage: /explain ".into(), + )); + } + SlashCommand::Explain { + snippet: remainder.join(" "), + } + } + "refactor" => { + if remainder.is_empty() { + return Err(SlashError::Message( + "usage: /refactor ".into(), + )); + } + SlashCommand::Refactor { + path: remainder.join(" "), + } + } + "testplan" => SlashCommand::TestPlan, + "compact" => SlashCommand::Compact, + other => return Err(SlashError::UnknownCommand(other.to_string())), + }; + + Ok(Some(cmd)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn ignores_non_command_input() { + let result = parse("hello world").unwrap(); + assert!(result.is_none()); + } + + #[test] + fn parses_summarize_with_count() { + let command = parse("/summarize 10").unwrap().expect("expected command"); + match command { + SlashCommand::Summarize { count } => assert_eq!(count, Some(10)), + other => panic!("unexpected command: {:?}", other), + } + } + + #[test] + fn returns_error_for_unknown_command() { + let err = parse("/unknown").unwrap_err(); + assert_eq!(err.to_string(), "unknown slash command: unknown"); + } + + #[test] + fn parses_registered_mcp_command() { + set_mcp_commands(Vec::new()); + set_mcp_commands(vec![McpSlashCommand::new("github", "list_prs", None)]); + + let command = parse("/mcp__github__list_prs") + .unwrap() + .expect("expected command"); + match command { + SlashCommand::McpTool { server, tool } => { + assert_eq!(server, "github"); + assert_eq!(tool, "list_prs"); + } + other => panic!("unexpected command variant: {:?}", other), + } + } + + #[test] + fn rejects_mcp_command_with_arguments() { + set_mcp_commands(Vec::new()); + set_mcp_commands(vec![McpSlashCommand::new("github", "list_prs", None)]); + + let err = parse("/mcp__github__list_prs extra").unwrap_err(); + assert_eq!( + err.to_string(), + "/mcp__github__list_prs does not accept arguments" + ); + } + + #[test] + fn canonicalizes_mcp_command_components() { + set_mcp_commands(Vec::new()); + let entry = McpSlashCommand::new("GitHub", "list/prs", None); + assert_eq!(entry.keyword, "mcp__github__list_prs"); + } +} diff --git a/crates/owlen-tui/src/theme_util.rs b/crates/owlen-tui/src/theme_util.rs new file mode 100644 index 0000000..4f6d4e4 --- /dev/null +++ b/crates/owlen-tui/src/theme_util.rs @@ -0,0 +1,96 @@ +macro_rules! adjust_fields { + ($theme:expr, $func:expr, $($field:ident),+ $(,)?) => { + $( + $theme.$field = $func($theme.$field); + )+ + }; +} + +use owlen_core::theme::Theme; +use ratatui::style::Color; + +/// Return a clone of `base` with contrast adjustments applied. +/// Positive `steps` increase contrast, negative values decrease it. +pub fn with_contrast(base: &Theme, steps: i8) -> Theme { + if steps == 0 { + return base.clone(); + } + + let factor = (1.0 + (steps as f32) * 0.18).clamp(0.3, 2.0); + let adjust = |color: Color| adjust_color(color, factor); + + let mut theme = base.clone(); + adjust_fields!( + theme, + adjust, + text, + background, + focused_panel_border, + unfocused_panel_border, + focus_beacon_fg, + focus_beacon_bg, + unfocused_beacon_fg, + pane_header_active, + pane_header_inactive, + pane_hint_text, + user_message_role, + assistant_message_role, + tool_output, + thinking_panel_title, + command_bar_background, + status_background, + mode_normal, + mode_editing, + mode_model_selection, + mode_provider_selection, + mode_help, + mode_visual, + mode_command, + selection_bg, + selection_fg, + cursor, + code_block_background, + code_block_border, + code_block_text, + code_block_keyword, + code_block_string, + code_block_comment, + placeholder, + error, + info, + agent_thought, + agent_action, + agent_action_input, + agent_observation, + agent_final_answer, + agent_badge_running_fg, + agent_badge_running_bg, + agent_badge_idle_fg, + agent_badge_idle_bg, + operating_chat_fg, + operating_chat_bg, + operating_code_fg, + operating_code_bg + ); + + theme +} + +fn adjust_color(color: Color, factor: f32) -> Color { + match color { + Color::Rgb(r, g, b) => { + let adjust_component = |component: u8| -> u8 { + let normalized = component as f32 / 255.0; + let contrasted = ((normalized - 0.5) * factor + 0.5).clamp(0.0, 1.0); + (contrasted * 255.0).round().clamp(0.0, 255.0) as u8 + }; + + Color::Rgb( + adjust_component(r), + adjust_component(g), + adjust_component(b), + ) + } + _ => color, + } +} diff --git a/crates/owlen-tui/src/toast.rs b/crates/owlen-tui/src/toast.rs new file mode 100644 index 0000000..5c60a08 --- /dev/null +++ b/crates/owlen-tui/src/toast.rs @@ -0,0 +1,114 @@ +use std::collections::VecDeque; +use std::time::{Duration, Instant}; + +/// Severity level for toast notifications. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ToastLevel { + Info, + Success, + Warning, + Error, +} + +#[derive(Debug, Clone)] +pub struct Toast { + pub message: String, + pub level: ToastLevel, + created: Instant, + duration: Duration, +} + +impl Toast { + fn new(message: String, level: ToastLevel, lifetime: Duration) -> Self { + Self { + message, + level, + created: Instant::now(), + duration: lifetime, + } + } + + fn is_expired(&self, now: Instant) -> bool { + now.duration_since(self.created) >= self.duration + } +} + +/// Fixed-size toast queue with automatic expiration. +#[derive(Debug)] +pub struct ToastManager { + items: VecDeque, + max_active: usize, + lifetime: Duration, +} + +impl Default for ToastManager { + fn default() -> Self { + Self::new() + } +} + +impl ToastManager { + pub fn new() -> Self { + Self { + items: VecDeque::new(), + max_active: 3, + lifetime: Duration::from_secs(3), + } + } + + pub fn with_lifetime(mut self, duration: Duration) -> Self { + self.lifetime = duration; + self + } + + pub fn push(&mut self, message: impl Into, level: ToastLevel) { + let toast = Toast::new(message.into(), level, self.lifetime); + self.items.push_front(toast); + while self.items.len() > self.max_active { + self.items.pop_back(); + } + } + + pub fn retain_active(&mut self) { + let now = Instant::now(); + self.items.retain(|toast| !toast.is_expired(now)); + } + + pub fn iter(&self) -> impl Iterator { + self.items.iter() + } + + pub fn is_empty(&self) -> bool { + self.items.is_empty() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::thread::sleep; + + #[test] + fn manager_limits_active_toasts() { + let mut manager = ToastManager::new(); + manager.push("first", ToastLevel::Info); + manager.push("second", ToastLevel::Warning); + manager.push("third", ToastLevel::Success); + manager.push("fourth", ToastLevel::Error); + + let collected: Vec<_> = manager.iter().map(|toast| toast.message.clone()).collect(); + assert_eq!(collected.len(), 3); + assert_eq!(collected[0], "fourth"); + assert_eq!(collected[2], "second"); + } + + #[test] + fn manager_expires_toasts_after_lifetime() { + let mut manager = ToastManager::new().with_lifetime(Duration::from_millis(1)); + manager.push("short lived", ToastLevel::Info); + assert!(!manager.is_empty()); + sleep(Duration::from_millis(5)); + manager.retain_active(); + assert!(manager.is_empty()); + } +} diff --git a/crates/owlen-tui/src/ui.rs b/crates/owlen-tui/src/ui.rs index 49a0925..63874b4 100644 --- a/crates/owlen-tui/src/ui.rs +++ b/crates/owlen-tui/src/ui.rs @@ -16,10 +16,12 @@ use crate::state::{ CodePane, EditorTab, FileFilterMode, FileNode, LayoutNode, PaletteGroup, PaneId, RepoSearchRowKind, SplitAxis, VisibleFileEntry, }; +use crate::toast::{Toast, ToastLevel}; use owlen_core::model::DetailedModelInfo; use owlen_core::theme::Theme; use owlen_core::types::{ModelInfo, Role}; use owlen_core::ui::{FocusedPanel, InputMode, RoleLabelDisplay}; +use textwrap::wrap; const PRIVACY_TAB_INDEX: usize = HELP_TAB_COUNT - 1; @@ -331,6 +333,113 @@ pub fn render_chat(frame: &mut Frame<'_>, app: &mut ChatApp) { if let Some(area) = code_area { render_code_workspace(frame, area, app); } + + render_toasts(frame, app, full_area); +} + +fn toast_palette(level: ToastLevel, theme: &Theme) -> (&'static str, Style, Style) { + let (label, color) = match level { + ToastLevel::Info => ("INFO", theme.info), + ToastLevel::Success => ("OK", theme.agent_badge_idle_bg), + ToastLevel::Warning => ("WARN", theme.agent_action), + ToastLevel::Error => ("ERROR", theme.error), + }; + + let badge_style = Style::default() + .fg(theme.background) + .bg(color) + .add_modifier(Modifier::BOLD); + let border_style = Style::default().fg(color); + (label, badge_style, border_style) +} + +fn render_toasts(frame: &mut Frame<'_>, app: &ChatApp, full_area: Rect) { + let toasts: Vec<&Toast> = app.toasts().collect(); + if toasts.is_empty() { + return; + } + + let theme = app.theme(); + let available_width = usize::from(full_area.width.saturating_sub(2)); + if available_width == 0 { + return; + } + + let max_text_width = toasts + .iter() + .map(|toast| UnicodeWidthStr::width(toast.message.as_str())) + .max() + .unwrap_or(0); + + let mut width = max_text_width.saturating_add(6); // padding + badge + width = width.clamp(14, available_width); + width = width.min(48); + if width == 0 { + return; + } + let width = width as u16; + + let offset_x = full_area + .x + .saturating_add(full_area.width.saturating_sub(width + 1)); + let mut offset_y = full_area.y.saturating_add(1); + let frame_bottom = full_area.y.saturating_add(full_area.height); + + for toast in toasts { + let (label, badge_style, border_style) = toast_palette(toast.level, theme); + let badge_text = format!(" {} ", label); + let indent_width = UnicodeWidthStr::width(badge_text.as_str()) + 1; + let indent = " ".repeat(indent_width); + + let content_width = width.saturating_sub(4).max(1) as usize; + let wrapped_lines = wrap(toast.message.as_str(), content_width); + let lines: Vec = if wrapped_lines.is_empty() { + vec![String::new()] + } else { + wrapped_lines + .into_iter() + .map(|cow| cow.into_owned()) + .collect() + }; + + let text_style = Style::default().fg(theme.text); + let mut paragraph_lines = Vec::with_capacity(lines.len()); + if let Some((first, rest)) = lines.split_first() { + paragraph_lines.push(Line::from(vec![ + Span::styled(badge_text.clone(), badge_style), + Span::raw(" "), + Span::styled(first.clone(), text_style), + ])); + for line in rest { + paragraph_lines.push(Line::from(vec![ + Span::raw(indent.clone()), + Span::styled(line.clone(), text_style), + ])); + } + } + + let height = (paragraph_lines.len() as u16).saturating_add(2); + if offset_y.saturating_add(height) > frame_bottom { + break; + } + + let area = Rect::new(offset_x, offset_y, width, height); + frame.render_widget(Clear, area); + let block = Block::default() + .borders(Borders::ALL) + .border_style(border_style) + .style(Style::default().bg(theme.background)); + let paragraph = Paragraph::new(paragraph_lines) + .block(block) + .alignment(Alignment::Left) + .wrap(Wrap { trim: false }); + frame.render_widget(paragraph, area); + + offset_y = offset_y.saturating_add(height + 1); + if offset_y >= frame_bottom { + break; + } + } } #[derive(Debug, Clone)] diff --git a/crates/owlen-tui/tests/state_tests.rs b/crates/owlen-tui/tests/state_tests.rs index 3caeace..ba3dda6 100644 --- a/crates/owlen-tui/tests/state_tests.rs +++ b/crates/owlen-tui/tests/state_tests.rs @@ -9,11 +9,21 @@ fn palette_tracks_buffer_and_suggestions() { palette.set_buffer("mo"); assert_eq!(palette.buffer(), "mo"); - assert!(palette.suggestions().iter().all(|s| s.starts_with("mo"))); + assert!( + palette + .suggestions() + .iter() + .all(|s| s.value.starts_with("mo")) + ); palette.push_char('d'); assert_eq!(palette.buffer(), "mod"); - assert!(palette.suggestions().iter().all(|s| s.starts_with("mod"))); + assert!( + palette + .suggestions() + .iter() + .all(|s| s.value.starts_with("mod")) + ); palette.pop_char(); assert_eq!(palette.buffer(), "mo");