feat(cli): add MCP management subcommand with add/list/remove commands

Introduce `McpCommand` enum and handlers in `owlen-cli` to manage MCP server registrations, including adding, listing, and removing servers across configuration scopes. Add scoped configuration support (`ScopedMcpServer`, `McpConfigScope`) and OAuth token handling in core config, alongside runtime refresh of MCP servers. Implement toast notifications in the TUI (`render_toasts`, `Toast`, `ToastLevel`) and integrate async handling for session events. Update config loading, validation, and schema versioning to accommodate new MCP scopes and resources. Add `httpmock` as a dev dependency for testing.
This commit is contained in:
2025-10-13 17:54:14 +02:00
parent 0da8a3f193
commit 690f5c7056
23 changed files with 3388 additions and 74 deletions

View File

@@ -92,6 +92,7 @@ OWLEN uses a modal, vim-inspired interface. Press `F1` (available from any mode)
- **Editing Mode**: Enter with `i` or `a`. Send messages with `Enter`. - **Editing Mode**: Enter with `i` or `a`. Send messages with `Enter`.
- **Command Mode**: Enter with `:`. Access commands like `:quit`, `:save`, `:theme`. - **Command Mode**: Enter with `:`. Access commands like `:quit`, `:save`, `:theme`.
- **Tutorial Command**: Type `:tutorial` any time for a quick summary of the most important keybindings. - **Tutorial Command**: Type `:tutorial` any time for a quick summary of the most important keybindings.
- **MCP Slash Commands**: Owlen auto-registers zero-argument MCP tools as slash commands—type `/mcp__github__list_prs` (for example) to pull remote context directly into the chat log.
## Documentation ## Documentation

View File

@@ -1,11 +1,13 @@
//! OWLEN CLI - Chat TUI client //! OWLEN CLI - Chat TUI client
mod cloud; mod cloud;
mod mcp;
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use async_trait::async_trait; use async_trait::async_trait;
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use cloud::{CloudCommand, load_runtime_credentials, set_env_var}; use cloud::{CloudCommand, load_runtime_credentials, set_env_var};
use mcp::{McpCommand, run_mcp_command};
use owlen_core::config as core_config; use owlen_core::config as core_config;
use owlen_core::{ use owlen_core::{
ChatStream, Error, Provider, ChatStream, Error, Provider,
@@ -54,6 +56,9 @@ enum OwlenCommand {
/// Manage Ollama Cloud credentials /// Manage Ollama Cloud credentials
#[command(subcommand)] #[command(subcommand)]
Cloud(CloudCommand), Cloud(CloudCommand),
/// Manage MCP server registrations
#[command(subcommand)]
Mcp(McpCommand),
/// Show manual steps for updating Owlen to the latest revision /// Show manual steps for updating Owlen to the latest revision
Upgrade, Upgrade,
} }
@@ -69,7 +74,7 @@ enum ConfigCommand {
fn build_provider(cfg: &Config) -> anyhow::Result<Arc<dyn Provider>> { fn build_provider(cfg: &Config) -> anyhow::Result<Arc<dyn Provider>> {
match cfg.mcp.mode { match cfg.mcp.mode {
McpMode::RemotePreferred => { McpMode::RemotePreferred => {
let remote_result = if let Some(mcp_server) = cfg.mcp_servers.first() { let remote_result = if let Some(mcp_server) = cfg.effective_mcp_servers().first() {
RemoteMcpClient::new_with_config(mcp_server) RemoteMcpClient::new_with_config(mcp_server)
} else { } else {
RemoteMcpClient::new() RemoteMcpClient::new()
@@ -91,7 +96,7 @@ fn build_provider(cfg: &Config) -> anyhow::Result<Arc<dyn Provider>> {
} }
} }
McpMode::RemoteOnly => { McpMode::RemoteOnly => {
let mcp_server = cfg.mcp_servers.first().ok_or_else(|| { let mcp_server = cfg.effective_mcp_servers().first().ok_or_else(|| {
anyhow::anyhow!( anyhow::anyhow!(
"[[mcp_servers]] must be configured when [mcp].mode = \"remote_only\"" "[[mcp_servers]] must be configured when [mcp].mode = \"remote_only\""
) )
@@ -130,6 +135,7 @@ async fn run_command(command: OwlenCommand) -> Result<()> {
match command { match command {
OwlenCommand::Config(config_cmd) => run_config_command(config_cmd), OwlenCommand::Config(config_cmd) => run_config_command(config_cmd),
OwlenCommand::Cloud(cloud_cmd) => cloud::run_cloud_command(cloud_cmd).await, OwlenCommand::Cloud(cloud_cmd) => cloud::run_cloud_command(cloud_cmd).await,
OwlenCommand::Mcp(mcp_cmd) => run_mcp_command(mcp_cmd),
OwlenCommand::Upgrade => { OwlenCommand::Upgrade => {
println!( println!(
"To update Owlen from source:\n git pull\n cargo install --path crates/owlen-cli --force" "To update Owlen from source:\n git pull\n cargo install --path crates/owlen-cli --force"
@@ -157,6 +163,7 @@ fn run_config_doctor() -> Result<()> {
let config_path = core_config::default_config_path(); let config_path = core_config::default_config_path();
let existed = config_path.exists(); let existed = config_path.exists();
let mut config = config::try_load_config().unwrap_or_default(); let mut config = config::try_load_config().unwrap_or_default();
let _ = config.refresh_mcp_servers(None);
let mut changes = Vec::new(); let mut changes = Vec::new();
if !existed { if !existed {
@@ -205,7 +212,7 @@ fn run_config_doctor() -> Result<()> {
config.mcp.warn_on_legacy = true; config.mcp.warn_on_legacy = true;
changes.push("converted [mcp].mode = 'legacy' to 'local_only'".to_string()); changes.push("converted [mcp].mode = 'legacy' to 'local_only'".to_string());
} }
McpMode::RemoteOnly if config.mcp_servers.is_empty() => { McpMode::RemoteOnly if config.effective_mcp_servers().is_empty() => {
config.mcp.mode = McpMode::RemotePreferred; config.mcp.mode = McpMode::RemotePreferred;
config.mcp.allow_fallback = true; config.mcp.allow_fallback = true;
changes.push( changes.push(
@@ -213,7 +220,9 @@ fn run_config_doctor() -> Result<()> {
.to_string(), .to_string(),
); );
} }
McpMode::RemotePreferred if !config.mcp.allow_fallback && config.mcp_servers.is_empty() => { McpMode::RemotePreferred
if !config.mcp.allow_fallback && config.effective_mcp_servers().is_empty() =>
{
config.mcp.allow_fallback = true; config.mcp.allow_fallback = true;
changes.push( changes.push(
"enabled [mcp].allow_fallback because no remote servers are configured".to_string(), "enabled [mcp].allow_fallback because no remote servers are configured".to_string(),
@@ -369,6 +378,7 @@ async fn main() -> Result<()> {
let color_support = detect_terminal_color_support(); let color_support = detect_terminal_color_support();
// Load configuration (or fall back to defaults) for the session controller. // Load configuration (or fall back to defaults) for the session controller.
let mut cfg = config::try_load_config().unwrap_or_default(); let mut cfg = config::try_load_config().unwrap_or_default();
let _ = cfg.refresh_mcp_servers(None);
if let Some(previous_theme) = apply_terminal_theme(&mut cfg, &color_support) { if let Some(previous_theme) = apply_terminal_theme(&mut cfg, &color_support) {
let term_label = match &color_support { let term_label = match &color_support {
TerminalColorSupport::Limited { term } => Cow::from(term.as_str()), TerminalColorSupport::Limited { term } => Cow::from(term.as_str()),
@@ -398,7 +408,7 @@ async fn main() -> Result<()> {
Ok(_) => provider, Ok(_) => provider,
Err(err) => { Err(err) => {
let hint = if matches!(cfg.mcp.mode, McpMode::RemotePreferred | McpMode::RemoteOnly) let hint = if matches!(cfg.mcp.mode, McpMode::RemotePreferred | McpMode::RemoteOnly)
&& !cfg.mcp_servers.is_empty() && !cfg.effective_mcp_servers().is_empty()
{ {
"Ensure the configured MCP server is running and reachable." "Ensure the configured MCP server is running and reachable."
} else { } else {
@@ -523,7 +533,7 @@ async fn run_app(
} }
} }
Some(session_event) = session_rx.recv() => { Some(session_event) = session_rx.recv() => {
app.handle_session_event(session_event)?; app.handle_session_event(session_event).await?;
} }
_ = tokio::time::sleep(sleep_duration) => {} _ = tokio::time::sleep(sleep_duration) => {}
} }

257
crates/owlen-cli/src/mcp.rs Normal file
View File

@@ -0,0 +1,257 @@
use std::collections::{HashMap, HashSet};
use anyhow::{Result, anyhow};
use clap::{Args, Subcommand, ValueEnum};
use owlen_core::config::{self as core_config, Config, McpConfigScope, McpServerConfig};
use owlen_tui::config as tui_config;
#[derive(Debug, Subcommand)]
pub enum McpCommand {
/// Add or update an MCP server in the selected scope
Add(AddArgs),
/// List MCP servers across scopes
List(ListArgs),
/// Remove an MCP server from a scope
Remove(RemoveArgs),
}
pub fn run_mcp_command(command: McpCommand) -> Result<()> {
match command {
McpCommand::Add(args) => handle_add(args),
McpCommand::List(args) => handle_list(args),
McpCommand::Remove(args) => handle_remove(args),
}
}
#[derive(Debug, Clone, Copy, ValueEnum, Default)]
pub enum ScopeArg {
User,
#[default]
Project,
Local,
}
impl From<ScopeArg> for McpConfigScope {
fn from(value: ScopeArg) -> Self {
match value {
ScopeArg::User => McpConfigScope::User,
ScopeArg::Project => McpConfigScope::Project,
ScopeArg::Local => McpConfigScope::Local,
}
}
}
#[derive(Debug, Args)]
pub struct AddArgs {
/// Logical name used to reference the server
pub name: String,
/// Command or endpoint invoked for the server
pub command: String,
/// Transport mechanism (stdio, http, websocket)
#[arg(long, default_value = "stdio")]
pub transport: String,
/// Configuration scope to write the server into
#[arg(long, value_enum, default_value_t = ScopeArg::Project)]
pub scope: ScopeArg,
/// Environment variables (KEY=VALUE) passed to the server process
#[arg(long = "env")]
pub env: Vec<String>,
/// Additional arguments appended when launching the server
#[arg(trailing_var_arg = true, value_name = "ARG")]
pub args: Vec<String>,
}
#[derive(Debug, Args, Default)]
pub struct ListArgs {
/// Restrict output to a specific configuration scope
#[arg(long, value_enum)]
pub scope: Option<ScopeArg>,
/// Display only the effective servers (after precedence resolution)
#[arg(long)]
pub effective_only: bool,
}
#[derive(Debug, Args)]
pub struct RemoveArgs {
/// Name of the server to remove
pub name: String,
/// Optional explicit scope to remove from
#[arg(long, value_enum)]
pub scope: Option<ScopeArg>,
}
fn handle_add(args: AddArgs) -> Result<()> {
let mut config = load_config()?;
let scope: McpConfigScope = args.scope.into();
let mut env_map = HashMap::new();
for pair in &args.env {
let (key, value) = pair
.split_once('=')
.ok_or_else(|| anyhow!("Environment pairs must use KEY=VALUE syntax: '{}'", pair))?;
if key.trim().is_empty() {
return Err(anyhow!("Environment variable name cannot be empty"));
}
env_map.insert(key.trim().to_string(), value.to_string());
}
let server = McpServerConfig {
name: args.name.clone(),
command: args.command.clone(),
args: args.args.clone(),
transport: args.transport.to_lowercase(),
env: env_map,
oauth: None,
};
config.add_mcp_server(scope, server.clone(), None)?;
if matches!(scope, McpConfigScope::User) {
tui_config::save_config(&config)?;
}
if let Some(path) = core_config::mcp_scope_path(scope, None) {
println!(
"Registered MCP server '{}' in {} scope ({})",
server.name,
scope,
path.display()
);
} else {
println!(
"Registered MCP server '{}' in {} scope.",
server.name, scope
);
}
Ok(())
}
fn handle_list(args: ListArgs) -> Result<()> {
let mut config = load_config()?;
config.refresh_mcp_servers(None)?;
let scoped = config.scoped_mcp_servers();
if scoped.is_empty() {
println!("No MCP servers configured.");
return Ok(());
}
let filter_scope = args.scope.map(|scope| scope.into());
let effective = config.effective_mcp_servers();
let mut active = HashSet::new();
for server in effective {
active.insert((
server.name.clone(),
server.command.clone(),
server.transport.to_lowercase(),
));
}
println!(
"{:<2} {:<8} {:<20} {:<10} Command",
"", "Scope", "Name", "Transport"
);
for entry in scoped {
if let Some(target_scope) = filter_scope
&& entry.scope != target_scope
{
continue;
}
let payload = format_command_line(&entry.config.command, &entry.config.args);
let key = (
entry.config.name.clone(),
entry.config.command.clone(),
entry.config.transport.to_lowercase(),
);
let marker = if active.contains(&key) { "*" } else { " " };
if args.effective_only && marker != "*" {
continue;
}
println!(
"{} {:<8} {:<20} {:<10} {}",
marker, entry.scope, entry.config.name, entry.config.transport, payload
);
}
let scoped_resources = config.scoped_mcp_resources();
if !scoped_resources.is_empty() {
println!();
println!("{:<2} {:<8} {:<30} Title", "", "Scope", "Resource");
let effective_keys: HashSet<(String, String)> = config
.effective_mcp_resources()
.iter()
.map(|res| (res.server.clone(), res.uri.clone()))
.collect();
for entry in scoped_resources {
if let Some(target_scope) = filter_scope
&& entry.scope != target_scope
{
continue;
}
let key = (entry.config.server.clone(), entry.config.uri.clone());
let marker = if effective_keys.contains(&key) {
"*"
} else {
" "
};
if args.effective_only && marker != "*" {
continue;
}
let reference = format!("@{}:{}", entry.config.server, entry.config.uri);
let title = entry.config.title.as_deref().unwrap_or("");
println!("{} {:<8} {:<30} {}", marker, entry.scope, reference, title);
}
}
Ok(())
}
fn handle_remove(args: RemoveArgs) -> Result<()> {
let mut config = load_config()?;
let scope_hint = args.scope.map(|scope| scope.into());
let result = config.remove_mcp_server(scope_hint, &args.name, None)?;
match result {
Some(scope) => {
if matches!(scope, McpConfigScope::User) {
tui_config::save_config(&config)?;
}
if let Some(path) = core_config::mcp_scope_path(scope, None) {
println!(
"Removed MCP server '{}' from {} scope ({})",
args.name,
scope,
path.display()
);
} else {
println!("Removed MCP server '{}' from {} scope.", args.name, scope);
}
}
None => {
println!("No MCP server named '{}' was found.", args.name);
}
}
Ok(())
}
fn load_config() -> Result<Config> {
let mut config = tui_config::try_load_config().unwrap_or_default();
config.refresh_mcp_servers(None)?;
Ok(config)
}
fn format_command_line(command: &str, args: &[String]) -> String {
if args.is_empty() {
command.to_string()
} else {
format!("{} {}", command, args.join(" "))
}
}

View File

@@ -50,3 +50,4 @@ ollama-rs = { version = "0.3", features = ["stream", "headers"] }
[dev-dependencies] [dev-dependencies]
tokio-test = { workspace = true } tokio-test = { workspace = true }
httpmock = "0.7"

View File

@@ -1,13 +1,15 @@
use crate::Error;
use crate::ProviderConfig; use crate::ProviderConfig;
use crate::Result; use crate::Result;
use crate::mode::ModeConfig; use crate::mode::ModeConfig;
use crate::ui::RoleLabelDisplay; use crate::ui::RoleLabelDisplay;
use serde::de::{self, Deserializer, Visitor}; use serde::de::{self, Deserializer, Visitor};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::{HashMap, HashSet};
use std::fmt; use std::fmt;
use std::fs; use std::fs;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::time::Duration; use std::time::Duration;
/// Default location for the OWLEN configuration file /// Default location for the OWLEN configuration file
@@ -54,6 +56,21 @@ pub struct Config {
/// External MCP server definitions /// External MCP server definitions
#[serde(default)] #[serde(default)]
pub mcp_servers: Vec<McpServerConfig>, pub mcp_servers: Vec<McpServerConfig>,
/// User-scoped resource definitions
#[serde(default)]
pub mcp_resources: Vec<McpResourceConfig>,
/// Resolved MCP servers across scopes (runtime only).
#[serde(skip)]
pub scoped_mcp_servers: Vec<ScopedMcpServer>,
/// Effective MCP servers after applying precedence rules (runtime only).
#[serde(skip)]
pub effective_mcp_servers: Vec<McpServerConfig>,
/// Resolved MCP resources across scopes (runtime only).
#[serde(skip)]
pub scoped_mcp_resources: Vec<ScopedMcpResource>,
/// Effective MCP resources after precedence (runtime only).
#[serde(skip)]
pub effective_mcp_resources: Vec<McpResourceConfig>,
} }
impl Default for Config { impl Default for Config {
@@ -74,6 +91,11 @@ impl Default for Config {
tools: ToolSettings::default(), tools: ToolSettings::default(),
modes: ModeConfig::default(), modes: ModeConfig::default(),
mcp_servers: Vec::new(), mcp_servers: Vec::new(),
mcp_resources: Vec::new(),
scoped_mcp_servers: Vec::new(),
effective_mcp_servers: Vec::new(),
scoped_mcp_resources: Vec::new(),
effective_mcp_resources: Vec::new(),
} }
} }
} }
@@ -94,6 +116,9 @@ pub struct McpServerConfig {
/// Optional environment variable map for the process. /// Optional environment variable map for the process.
#[serde(default)] #[serde(default)]
pub env: std::collections::HashMap<String, String>, pub env: std::collections::HashMap<String, String>,
/// Optional OAuth configuration for remote servers.
#[serde(default)]
pub oauth: Option<McpOAuthConfig>,
} }
impl McpServerConfig { impl McpServerConfig {
@@ -102,6 +127,126 @@ impl McpServerConfig {
} }
} }
/// OAuth configuration for MCP servers that require delegated authentication.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct McpOAuthConfig {
/// Public client identifier registered with the authorization server.
pub client_id: String,
/// Optional client secret for confidential clients.
#[serde(default)]
pub client_secret: Option<String>,
/// OAuth authorization endpoint (used for web-based flows).
pub authorize_url: String,
/// OAuth token endpoint.
pub token_url: String,
/// Optional device authorization endpoint for device-code flows.
#[serde(default)]
pub device_authorization_url: Option<String>,
/// Optional redirect URL (PKCE / authorization-code flows).
#[serde(default)]
pub redirect_url: Option<String>,
/// Requested OAuth scopes.
#[serde(default)]
pub scopes: Vec<String>,
/// Environment variable name populated with the bearer access token when spawning stdio servers.
#[serde(default)]
pub token_env: Option<String>,
/// Optional HTTP header name for bearer authentication (defaults to "Authorization").
#[serde(default)]
pub header: Option<String>,
/// Optional prefix prepended to the access token (defaults to "Bearer ").
#[serde(default)]
pub header_prefix: Option<String>,
}
impl McpOAuthConfig {
pub fn header_name(&self) -> &str {
self.header.as_deref().unwrap_or("Authorization")
}
pub fn header_prefix(&self) -> &str {
self.header_prefix.as_deref().unwrap_or("Bearer ")
}
}
/// Scope for MCP server configuration entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum McpConfigScope {
/// User-level configuration stored under the user's config directory.
User,
/// Project configuration stored in the repository (e.g. `.mcp.json`).
Project,
/// Local overrides stored alongside the project but excluded from version control.
Local,
}
impl McpConfigScope {
fn precedence_iter() -> impl Iterator<Item = Self> {
[
McpConfigScope::Local,
McpConfigScope::Project,
McpConfigScope::User,
]
.into_iter()
}
fn as_str(self) -> &'static str {
match self {
McpConfigScope::User => "user",
McpConfigScope::Project => "project",
McpConfigScope::Local => "local",
}
}
}
impl fmt::Display for McpConfigScope {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.as_str())
}
}
impl FromStr for McpConfigScope {
type Err = String;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.to_ascii_lowercase().as_str() {
"user" => Ok(McpConfigScope::User),
"project" => Ok(McpConfigScope::Project),
"local" => Ok(McpConfigScope::Local),
other => Err(format!("Unknown MCP scope '{other}'")),
}
}
}
/// A resolved MCP server entry annotated with its configuration scope.
#[derive(Debug, Clone)]
pub struct ScopedMcpServer {
pub scope: McpConfigScope,
pub config: McpServerConfig,
}
/// Configuration for a predefined MCP resource reference.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct McpResourceConfig {
/// Named MCP server that owns this resource.
pub server: String,
/// URI or path identifying the resource within the server.
pub uri: String,
/// Optional short title displayed in UI.
#[serde(default)]
pub title: Option<String>,
/// Optional detailed description shown in tooltips.
#[serde(default)]
pub description: Option<String>,
}
/// Resource entry annotated with its originating scope.
#[derive(Debug, Clone)]
pub struct ScopedMcpResource {
pub scope: McpConfigScope,
pub config: McpResourceConfig,
}
impl Config { impl Config {
fn default_schema_version() -> String { fn default_schema_version() -> String {
CONFIG_SCHEMA_VERSION.to_string() CONFIG_SCHEMA_VERSION.to_string()
@@ -138,18 +283,22 @@ impl Config {
config.mcp.apply_backward_compat(); config.mcp.apply_backward_compat();
config.apply_schema_migrations(&previous_version); config.apply_schema_migrations(&previous_version);
config.expand_provider_env_vars()?; config.expand_provider_env_vars()?;
config.refresh_mcp_servers(None)?;
config.validate()?; config.validate()?;
Ok(config) Ok(config)
} else { } else {
let mut config = Config::default(); let mut config = Config::default();
config.expand_provider_env_vars()?; config.expand_provider_env_vars()?;
config.refresh_mcp_servers(None)?;
Ok(config) Ok(config)
} }
} }
/// Persist configuration to disk /// Persist configuration to disk
pub fn save(&self, path: Option<&Path>) -> Result<()> { pub fn save(&self, path: Option<&Path>) -> Result<()> {
self.validate()?; let mut validator = self.clone();
validator.refresh_mcp_servers(None)?;
validator.validate()?;
let path = match path { let path = match path {
Some(path) => path.to_path_buf(), Some(path) => path.to_path_buf(),
@@ -214,6 +363,192 @@ impl Config {
Ok(()) Ok(())
} }
/// Refresh the resolved MCP server list by loading scope-specific definitions.
pub fn refresh_mcp_servers(&mut self, project_hint: Option<&Path>) -> Result<()> {
let mut scoped_servers = Vec::new();
let mut scoped_resources = Vec::new();
let mut user_servers = self.mcp_servers.clone();
expand_mcp_servers(&mut user_servers, "config.mcp_servers")?;
for server in user_servers {
scoped_servers.push(ScopedMcpServer {
scope: McpConfigScope::User,
config: server,
});
}
let mut user_resources = self.mcp_resources.clone();
expand_mcp_resources(&mut user_resources, "config.mcp_resources")?;
for resource in user_resources {
scoped_resources.push(ScopedMcpResource {
scope: McpConfigScope::User,
config: resource,
});
}
for scope in [McpConfigScope::Project, McpConfigScope::Local] {
if let Some(path) = mcp_scope_path(scope, project_hint) {
let mut file = read_scope_config(&path)?;
let server_context = format!("mcp.{scope}.servers");
expand_mcp_servers(&mut file.servers, &server_context)?;
for server in file.servers {
scoped_servers.push(ScopedMcpServer {
scope,
config: server,
});
}
let resource_context = format!("mcp.{scope}.resources");
expand_mcp_resources(&mut file.resources, &resource_context)?;
for resource in file.resources {
scoped_resources.push(ScopedMcpResource {
scope,
config: resource,
});
}
}
}
let mut effective_servers = Vec::new();
let mut seen_servers = HashSet::new();
for scope in McpConfigScope::precedence_iter() {
for entry in scoped_servers.iter().filter(|entry| entry.scope == scope) {
if seen_servers.insert(entry.config.name.clone()) {
effective_servers.push(entry.config.clone());
}
}
}
let mut effective_resources = Vec::new();
let mut seen_resources: HashSet<(String, String)> = HashSet::new();
for scope in McpConfigScope::precedence_iter() {
for entry in scoped_resources.iter().filter(|entry| entry.scope == scope) {
let key = (entry.config.server.clone(), entry.config.uri.clone());
if seen_resources.insert(key) {
effective_resources.push(entry.config.clone());
}
}
}
self.scoped_mcp_servers = scoped_servers;
self.effective_mcp_servers = effective_servers;
self.scoped_mcp_resources = scoped_resources;
self.effective_mcp_resources = effective_resources;
Ok(())
}
/// Return the merged MCP servers using scope precedence (local > project > user).
pub fn effective_mcp_servers(&self) -> &[McpServerConfig] {
&self.effective_mcp_servers
}
/// Return MCP servers annotated with their originating scope.
pub fn scoped_mcp_servers(&self) -> &[ScopedMcpServer] {
&self.scoped_mcp_servers
}
/// Return merged MCP resources using scope precedence (local > project > user).
pub fn effective_mcp_resources(&self) -> &[McpResourceConfig] {
&self.effective_mcp_resources
}
/// Return scoped MCP resources with their origin scope metadata.
pub fn scoped_mcp_resources(&self) -> &[ScopedMcpResource] {
&self.scoped_mcp_resources
}
/// Locate a configured resource by server and URI.
pub fn find_resource(&self, server: &str, uri: &str) -> Option<&McpResourceConfig> {
self.effective_mcp_resources
.iter()
.find(|resource| resource.server == server && resource.uri == uri)
}
/// Add or replace an MCP server definition within the specified scope.
pub fn add_mcp_server(
&mut self,
scope: McpConfigScope,
server: McpServerConfig,
project_hint: Option<&Path>,
) -> Result<()> {
match scope {
McpConfigScope::User => {
self.mcp_servers
.retain(|existing| existing.name != server.name);
self.mcp_servers.push(server);
}
other => {
let path = mcp_scope_path(other, project_hint).ok_or_else(|| {
Error::Config(format!(
"Unable to resolve project root for MCP scope '{}'",
other
))
})?;
let mut file = read_scope_config(&path)?;
file.servers.retain(|existing| existing.name != server.name);
file.servers.push(server);
write_scope_config(&path, &file)?;
}
}
self.refresh_mcp_servers(project_hint)?;
Ok(())
}
/// Remove an MCP server from the given scope, or infer the scope if omitted.
pub fn remove_mcp_server(
&mut self,
scope: Option<McpConfigScope>,
name: &str,
project_hint: Option<&Path>,
) -> Result<Option<McpConfigScope>> {
let target_scope = if let Some(scope) = scope {
scope
} else {
self.refresh_mcp_servers(project_hint)?;
match self
.scoped_mcp_servers
.iter()
.find(|entry| entry.config.name == name)
{
Some(entry) => entry.scope,
None => return Ok(None),
}
};
let removed = match target_scope {
McpConfigScope::User => {
let before = self.mcp_servers.len();
self.mcp_servers.retain(|entry| entry.name != name);
before != self.mcp_servers.len()
}
other => {
let path = mcp_scope_path(other, project_hint).ok_or_else(|| {
Error::Config(format!(
"Unable to resolve project root for MCP scope '{}'",
other
))
})?;
let mut file = read_scope_config(&path)?;
let before = file.servers.len();
file.servers.retain(|entry| entry.name != name);
if before == file.servers.len() {
false
} else {
write_scope_config(&path, &file)?;
true
}
}
};
if removed {
self.refresh_mcp_servers(project_hint)?;
Ok(Some(target_scope))
} else {
Ok(None)
}
}
/// Validate configuration invariants and surface actionable error messages. /// Validate configuration invariants and surface actionable error messages.
pub fn validate(&self) -> Result<()> { pub fn validate(&self) -> Result<()> {
self.validate_default_provider()?; self.validate_default_provider()?;
@@ -284,9 +619,15 @@ impl Config {
} }
fn validate_mcp_settings(&self) -> Result<()> { fn validate_mcp_settings(&self) -> Result<()> {
let has_effective_servers = if self.effective_mcp_servers.is_empty() {
!self.mcp_servers.is_empty()
} else {
!self.effective_mcp_servers.is_empty()
};
match self.mcp.mode { match self.mcp.mode {
McpMode::RemoteOnly => { McpMode::RemoteOnly => {
if self.mcp_servers.is_empty() { if !has_effective_servers {
return Err(crate::Error::Config( return Err(crate::Error::Config(
"[mcp].mode = 'remote_only' requires at least one [[mcp_servers]] entry" "[mcp].mode = 'remote_only' requires at least one [[mcp_servers]] entry"
.to_string(), .to_string(),
@@ -294,7 +635,7 @@ impl Config {
} }
} }
McpMode::RemotePreferred => { McpMode::RemotePreferred => {
if !self.mcp.allow_fallback && self.mcp_servers.is_empty() { if !self.mcp.allow_fallback && !has_effective_servers {
return Err(crate::Error::Config( return Err(crate::Error::Config(
"[mcp].allow_fallback = false requires at least one [[mcp_servers]] entry" "[mcp].allow_fallback = false requires at least one [[mcp_servers]] entry"
.to_string(), .to_string(),
@@ -313,26 +654,13 @@ impl Config {
} }
fn validate_mcp_servers(&self) -> Result<()> { fn validate_mcp_servers(&self) -> Result<()> {
for server in &self.mcp_servers { if self.scoped_mcp_servers.is_empty() {
if server.name.trim().is_empty() { for server in &self.mcp_servers {
return Err(crate::Error::Config( validate_mcp_server_entry(server, McpConfigScope::User)?;
"Each [[mcp_servers]] entry must include a non-empty name".to_string(),
));
} }
} else {
if server.command.trim().is_empty() { for entry in &self.scoped_mcp_servers {
return Err(crate::Error::Config(format!( validate_mcp_server_entry(&entry.config, entry.scope)?;
"MCP server '{}' must define a command or endpoint",
server.name
)));
}
let transport = server.transport.to_lowercase();
if !matches!(transport.as_str(), "stdio" | "http" | "websocket") {
return Err(crate::Error::Config(format!(
"Unknown MCP transport '{}' for server '{}'",
server.transport, server.name
)));
} }
} }
@@ -349,6 +677,58 @@ fn default_ollama_provider_config() -> ProviderConfig {
} }
} }
fn validate_mcp_server_entry(server: &McpServerConfig, scope: McpConfigScope) -> Result<()> {
if server.name.trim().is_empty() {
return Err(Error::Config(format!(
"Each MCP server entry must include a non-empty name (scope: {scope})"
)));
}
if server.command.trim().is_empty() {
return Err(Error::Config(format!(
"MCP server '{}' must define a command or endpoint (scope: {scope})",
server.name
)));
}
let transport = server.transport.to_lowercase();
if !matches!(transport.as_str(), "stdio" | "http" | "websocket") {
return Err(Error::Config(format!(
"Unknown MCP transport '{}' for server '{}' (scope: {scope})",
server.transport, server.name
)));
}
if let Some(oauth) = &server.oauth {
if oauth.client_id.trim().is_empty() {
return Err(Error::Config(format!(
"MCP server '{}' defines OAuth without a client_id",
server.name
)));
}
if oauth.authorize_url.trim().is_empty() {
return Err(Error::Config(format!(
"MCP server '{}' defines OAuth without an authorize_url",
server.name
)));
}
if oauth.token_url.trim().is_empty() {
return Err(Error::Config(format!(
"MCP server '{}' defines OAuth without a token_url",
server.name
)));
}
if oauth.device_authorization_url.is_none() && oauth.redirect_url.is_none() {
return Err(Error::Config(format!(
"MCP server '{}' must define either device_authorization_url or redirect_url for OAuth flows",
server.name
)));
}
}
Ok(())
}
fn expand_provider_entry(provider_name: &str, provider: &mut ProviderConfig) -> Result<()> { fn expand_provider_entry(provider_name: &str, provider: &mut ProviderConfig) -> Result<()> {
if let Some(ref mut base_url) = provider.base_url { if let Some(ref mut base_url) = provider.base_url {
let expanded = expand_env_string( let expanded = expand_env_string(
@@ -379,6 +759,136 @@ fn expand_provider_entry(provider_name: &str, provider: &mut ProviderConfig) ->
Ok(()) Ok(())
} }
fn expand_mcp_servers(servers: &mut [McpServerConfig], field_path: &str) -> Result<()> {
for (idx, server) in servers.iter_mut().enumerate() {
expand_mcp_server_entry(server, field_path, idx)?;
}
Ok(())
}
fn expand_mcp_server_entry(
server: &mut McpServerConfig,
field_path: &str,
index: usize,
) -> Result<()> {
server.command = expand_env_string(
server.command.as_str(),
&format!("{field_path}[{index}].command"),
)?;
for (arg_idx, arg) in server.args.iter_mut().enumerate() {
*arg = expand_env_string(
arg.as_str(),
&format!("{field_path}[{index}].args[{arg_idx}]"),
)?;
}
for (env_key, env_value) in server.env.iter_mut() {
*env_value = expand_env_string(
env_value.as_str(),
&format!("{field_path}[{index}].env.{env_key}"),
)?;
}
if let Some(oauth) = server.oauth.as_mut() {
oauth.client_id = expand_env_string(
oauth.client_id.as_str(),
&format!("{field_path}[{index}].oauth.client_id"),
)?;
oauth.authorize_url = expand_env_string(
oauth.authorize_url.as_str(),
&format!("{field_path}[{index}].oauth.authorize_url"),
)?;
oauth.token_url = expand_env_string(
oauth.token_url.as_str(),
&format!("{field_path}[{index}].oauth.token_url"),
)?;
if let Some(secret) = oauth.client_secret.as_mut() {
*secret = expand_env_string(
secret.as_str(),
&format!("{field_path}[{index}].oauth.client_secret"),
)?;
}
if let Some(device_url) = oauth.device_authorization_url.as_mut() {
*device_url = expand_env_string(
device_url.as_str(),
&format!("{field_path}[{index}].oauth.device_authorization_url"),
)?;
}
if let Some(redirect) = oauth.redirect_url.as_mut() {
*redirect = expand_env_string(
redirect.as_str(),
&format!("{field_path}[{index}].oauth.redirect_url"),
)?;
}
if let Some(token_env) = oauth.token_env.as_mut() {
*token_env = expand_env_string(
token_env.as_str(),
&format!("{field_path}[{index}].oauth.token_env"),
)?;
}
if let Some(header) = oauth.header.as_mut() {
*header = expand_env_string(
header.as_str(),
&format!("{field_path}[{index}].oauth.header"),
)?;
}
if let Some(prefix) = oauth.header_prefix.as_mut() {
*prefix = expand_env_string(
prefix.as_str(),
&format!("{field_path}[{index}].oauth.header_prefix"),
)?;
}
for (scope_idx, scope) in oauth.scopes.iter_mut().enumerate() {
*scope = expand_env_string(
scope.as_str(),
&format!("{field_path}[{index}].oauth.scopes[{scope_idx}]"),
)?;
}
}
Ok(())
}
fn expand_mcp_resources(resources: &mut [McpResourceConfig], field_path: &str) -> Result<()> {
for (idx, resource) in resources.iter_mut().enumerate() {
expand_mcp_resource_entry(resource, field_path, idx)?;
}
Ok(())
}
fn expand_mcp_resource_entry(
resource: &mut McpResourceConfig,
field_path: &str,
index: usize,
) -> Result<()> {
resource.server = expand_env_string(
resource.server.as_str(),
&format!("{field_path}[{index}].server"),
)?;
resource.uri = expand_env_string(resource.uri.as_str(), &format!("{field_path}[{index}].uri"))?;
if let Some(title) = resource.title.as_mut() {
*title = expand_env_string(title.as_str(), &format!("{field_path}[{index}].title"))?;
}
if let Some(description) = resource.description.as_mut() {
*description = expand_env_string(
description.as_str(),
&format!("{field_path}[{index}].description"),
)?;
}
Ok(())
}
fn expand_env_string(input: &str, field_path: &str) -> Result<String> { fn expand_env_string(input: &str, field_path: &str) -> Result<String> {
if !input.contains('$') { if !input.contains('$') {
return Ok(input.to_string()); return Ok(input.to_string());
@@ -408,6 +918,106 @@ pub fn default_config_path() -> PathBuf {
PathBuf::from(shellexpand::tilde(DEFAULT_CONFIG_PATH).as_ref()) PathBuf::from(shellexpand::tilde(DEFAULT_CONFIG_PATH).as_ref())
} }
#[derive(Serialize, Deserialize, Default, Clone)]
struct McpConfigFile {
#[serde(default)]
servers: Vec<McpServerConfig>,
#[serde(default)]
resources: Vec<McpResourceConfig>,
}
#[derive(Serialize, Deserialize)]
#[serde(untagged)]
enum McpConfigEnvelope {
Array(Vec<McpServerConfig>),
Object(McpConfigFile),
}
fn read_scope_config(path: &Path) -> Result<McpConfigFile> {
if !path.exists() {
return Ok(McpConfigFile::default());
}
let contents = fs::read_to_string(path).map_err(Error::Io)?;
if contents.trim().is_empty() {
return Ok(McpConfigFile::default());
}
let doc: McpConfigEnvelope = serde_json::from_str(&contents).map_err(|err| {
Error::Config(format!(
"Failed to parse MCP configuration at {}: {err}",
path.display()
))
})?;
Ok(match doc {
McpConfigEnvelope::Array(servers) => McpConfigFile {
servers,
resources: Vec::new(),
},
McpConfigEnvelope::Object(doc) => doc,
})
}
fn write_scope_config(path: &Path, file: &McpConfigFile) -> Result<()> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).map_err(Error::Io)?;
}
let serialized = serde_json::to_string_pretty(file).map_err(|err| {
Error::Config(format!(
"Failed to serialize MCP configuration for {}: {err}",
path.display()
))
})?;
fs::write(path, serialized).map_err(Error::Io)
}
/// Resolve the configuration file path for a given scope.
pub fn mcp_scope_path(scope: McpConfigScope, project_hint: Option<&Path>) -> Option<PathBuf> {
match scope {
McpConfigScope::User => dirs::config_dir()
.or_else(|| Some(PathBuf::from(shellexpand::tilde("~/.config").as_ref())))
.map(|dir| dir.join("owlen").join("mcp.json")),
McpConfigScope::Project | McpConfigScope::Local => {
let root = project_hint
.map(PathBuf::from)
.or_else(|| discover_project_root(None))?;
if matches!(scope, McpConfigScope::Project) {
Some(root.join(".mcp.json"))
} else {
Some(root.join(".owlen").join("mcp.local.json"))
}
}
}
}
fn discover_project_root(start: Option<&Path>) -> Option<PathBuf> {
let mut current = start
.map(PathBuf::from)
.or_else(|| std::env::current_dir().ok())?;
loop {
if current.join(".mcp.json").exists()
|| current.join(".owlen").exists()
|| current.join(".git").exists()
|| current.join("Cargo.toml").exists()
{
return Some(current);
}
if !current.pop() {
break;
}
}
start
.map(PathBuf::from)
.or_else(|| std::env::current_dir().ok())
}
/// General behaviour settings shared across clients /// General behaviour settings shared across clients
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneralSettings { pub struct GeneralSettings {
@@ -1173,6 +1783,7 @@ mod tests {
transport: "udp".into(), transport: "udp".into(),
args: Vec::new(), args: Vec::new(),
env: std::collections::HashMap::new(), env: std::collections::HashMap::new(),
oauth: None,
}]; }];
let result = config.validate(); let result = config.validate();
assert!( assert!(
@@ -1186,4 +1797,113 @@ mod tests {
config.mcp.mode = McpMode::LocalOnly; config.mcp.mode = McpMode::LocalOnly;
assert!(config.validate().is_ok()); assert!(config.validate().is_ok());
} }
#[test]
fn refresh_mcp_servers_merges_scopes_with_precedence() {
let temp = tempfile::tempdir().expect("tempdir");
let project_root = temp.path();
std::fs::write(
project_root.join(".mcp.json"),
r#"{
"servers": [
{ "name": "shared", "command": "project-cmd", "transport": "stdio" },
{ "name": "project-only", "command": "proj", "transport": "stdio" }
],
"resources": [
{ "server": "github", "uri": "issue://123", "title": "Project Issue" },
{ "server": "docs", "uri": "page://start", "title": "Project Doc" }
]
}"#,
)
.expect("write project scope");
let local_dir = project_root.join(".owlen");
std::fs::create_dir_all(&local_dir).expect("local dir");
std::fs::write(
local_dir.join("mcp.local.json"),
r#"{
"servers": [
{ "name": "shared", "command": "local-cmd", "transport": "stdio" }
],
"resources": [
{ "server": "github", "uri": "issue://123", "title": "Local Override" }
]
}"#,
)
.expect("write local scope");
let mut config = Config::default();
config.mcp_servers.push(McpServerConfig {
name: "shared".into(),
command: "user-cmd".into(),
args: Vec::new(),
transport: "stdio".into(),
env: std::collections::HashMap::new(),
oauth: None,
});
config.mcp_resources.push(McpResourceConfig {
server: "github".into(),
uri: "issue://123".into(),
title: Some("User Issue".into()),
description: None,
});
config
.refresh_mcp_servers(Some(project_root))
.expect("refresh scopes");
// We should have four scoped entries (user + two project + local) and precedence should select local
assert_eq!(config.scoped_mcp_servers().len(), 4);
let effective = config.effective_mcp_servers();
assert_eq!(effective.len(), 2); // shared + project-only
assert_eq!(effective[0].command, "local-cmd");
assert_eq!(effective[0].name, "shared");
assert_eq!(config.scoped_mcp_resources().len(), 4);
let effective_resources = config.effective_mcp_resources();
assert_eq!(effective_resources.len(), 2);
assert_eq!(
effective_resources
.iter()
.find(|res| res.server == "github")
.and_then(|res| res.title.as_deref()),
Some("Local Override")
);
}
#[test]
fn remove_mcp_server_reports_scope() {
let temp = tempfile::tempdir().expect("tempdir");
let project_root = temp.path();
std::fs::write(
project_root.join(".mcp.json"),
r#"{ "servers": [{ "name": "project", "command": "proj", "transport": "stdio" }] }"#,
)
.expect("write project scope");
let mut config = Config::default();
config.mcp_servers.push(McpServerConfig {
name: "user".into(),
command: "user".into(),
args: Vec::new(),
transport: "stdio".into(),
env: std::collections::HashMap::new(),
oauth: None,
});
config
.refresh_mcp_servers(Some(project_root))
.expect("refresh scopes");
// Remove without specifying scope should pick highest precedence (project)
let removed_scope = config
.remove_mcp_server(None, "project", Some(project_root))
.expect("remove call");
assert_eq!(removed_scope, Some(McpConfigScope::Project));
// Remove the remaining user scope explicitly
let removed_scope = config
.remove_mcp_server(Some(McpConfigScope::User), "user", Some(project_root))
.expect("remove user");
assert_eq!(removed_scope, Some(McpConfigScope::User));
}
} }

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{Error, Result, storage::StorageManager}; use crate::{Error, Result, oauth::OAuthToken, storage::StorageManager};
#[derive(Serialize, Deserialize, Debug)] #[derive(Serialize, Deserialize, Debug)]
pub struct ApiCredentials { pub struct ApiCredentials {
@@ -31,6 +31,10 @@ impl CredentialManager {
format!("{}_{}", self.namespace, tool_name) format!("{}_{}", self.namespace, tool_name)
} }
fn oauth_storage_key(&self, resource: &str) -> String {
self.namespaced_key(&format!("oauth_{resource}"))
}
pub async fn store_credentials( pub async fn store_credentials(
&self, &self,
tool_name: &str, tool_name: &str,
@@ -68,4 +72,37 @@ impl CredentialManager {
let key = self.namespaced_key(tool_name); let key = self.namespaced_key(tool_name);
self.storage.delete_secure_item(&key).await self.storage.delete_secure_item(&key).await
} }
pub async fn store_oauth_token(&self, resource: &str, token: &OAuthToken) -> Result<()> {
let key = self.oauth_storage_key(resource);
let payload = serde_json::to_vec(token).map_err(|err| {
Error::Storage(format!(
"Failed to serialize OAuth token for secure storage: {err}"
))
})?;
self.storage
.store_secure_item(&key, &payload, &self.master_key)
.await
}
pub async fn load_oauth_token(&self, resource: &str) -> Result<Option<OAuthToken>> {
let key = self.oauth_storage_key(resource);
let raw = self
.storage
.load_secure_item(&key, &self.master_key)
.await?;
if let Some(bytes) = raw {
let token = serde_json::from_slice(&bytes).map_err(|err| {
Error::Storage(format!("Failed to deserialize stored OAuth token: {err}"))
})?;
Ok(Some(token))
} else {
Ok(None)
}
}
pub async fn delete_oauth_token(&self, resource: &str) -> Result<()> {
let key = self.oauth_storage_key(resource);
self.storage.delete_secure_item(&key).await
}
} }

View File

@@ -15,6 +15,7 @@ pub mod llm;
pub mod mcp; pub mod mcp;
pub mod mode; pub mod mode;
pub mod model; pub mod model;
pub mod oauth;
pub mod providers; pub mod providers;
pub mod router; pub mod router;
pub mod sandbox; pub mod sandbox;
@@ -36,6 +37,7 @@ pub use credentials::*;
pub use encryption::*; pub use encryption::*;
pub use formatting::*; pub use formatting::*;
pub use input::*; pub use input::*;
pub use oauth::*;
// Export MCP types but exclude test_utils to avoid ambiguity // Export MCP types but exclude test_utils to avoid ambiguity
pub use llm::{ pub use llm::{
ChatStream, LlmProvider, Provider, ProviderConfig, ProviderRegistry, send_via_stream, ChatStream, LlmProvider, Provider, ProviderConfig, ProviderRegistry, send_via_stream,

View File

@@ -3,7 +3,10 @@
/// Provides a unified interface for creating MCP clients based on configuration. /// Provides a unified interface for creating MCP clients based on configuration.
/// Supports switching between local (in-process) and remote (STDIO) execution modes. /// Supports switching between local (in-process) and remote (STDIO) execution modes.
use super::client::McpClient; use super::client::McpClient;
use super::{LocalMcpClient, remote_client::RemoteMcpClient}; use super::{
LocalMcpClient,
remote_client::{McpRuntimeSecrets, RemoteMcpClient},
};
use crate::config::{Config, McpMode}; use crate::config::{Config, McpMode};
use crate::tools::registry::ToolRegistry; use crate::tools::registry::ToolRegistry;
use crate::validation::SchemaValidator; use crate::validation::SchemaValidator;
@@ -33,6 +36,14 @@ impl McpClientFactory {
/// Create an MCP client based on the current configuration. /// Create an MCP client based on the current configuration.
pub fn create(&self) -> Result<Box<dyn McpClient>> { pub fn create(&self) -> Result<Box<dyn McpClient>> {
self.create_with_secrets(None)
}
/// Create an MCP client using optional runtime secrets (OAuth tokens, env overrides).
pub fn create_with_secrets(
&self,
runtime: Option<McpRuntimeSecrets>,
) -> Result<Box<dyn McpClient>> {
match self.config.mcp.mode { match self.config.mcp.mode {
McpMode::Disabled => Err(Error::Config( McpMode::Disabled => Err(Error::Config(
"MCP mode is set to 'disabled'; tooling cannot function in this configuration." "MCP mode is set to 'disabled'; tooling cannot function in this configuration."
@@ -48,14 +59,14 @@ impl McpClientFactory {
))) )))
} }
McpMode::RemoteOnly => { McpMode::RemoteOnly => {
let server_cfg = self.config.mcp_servers.first().ok_or_else(|| { let server_cfg = self.config.effective_mcp_servers().first().ok_or_else(|| {
Error::Config( Error::Config(
"MCP mode 'remote_only' requires at least one entry in [[mcp_servers]]" "MCP mode 'remote_only' requires at least one entry in [[mcp_servers]]"
.to_string(), .to_string(),
) )
})?; })?;
RemoteMcpClient::new_with_config(server_cfg) RemoteMcpClient::new_with_runtime(server_cfg, runtime)
.map(|client| Box::new(client) as Box<dyn McpClient>) .map(|client| Box::new(client) as Box<dyn McpClient>)
.map_err(|e| { .map_err(|e| {
Error::Config(format!( Error::Config(format!(
@@ -65,8 +76,8 @@ impl McpClientFactory {
}) })
} }
McpMode::RemotePreferred => { McpMode::RemotePreferred => {
if let Some(server_cfg) = self.config.mcp_servers.first() { if let Some(server_cfg) = self.config.effective_mcp_servers().first() {
match RemoteMcpClient::new_with_config(server_cfg) { match RemoteMcpClient::new_with_runtime(server_cfg, runtime.clone()) {
Ok(client) => { Ok(client) => {
info!( info!(
"Connected to remote MCP server '{}' via {} transport.", "Connected to remote MCP server '{}' via {} transport.",
@@ -125,7 +136,8 @@ mod tests {
#[test] #[test]
fn test_factory_creates_local_client_when_no_servers_configured() { fn test_factory_creates_local_client_when_no_servers_configured() {
let config = Config::default(); let mut config = Config::default();
config.refresh_mcp_servers(None).unwrap();
let factory = build_factory(config); let factory = build_factory(config);
@@ -139,6 +151,7 @@ mod tests {
let mut config = Config::default(); let mut config = Config::default();
config.mcp.mode = McpMode::RemoteOnly; config.mcp.mode = McpMode::RemoteOnly;
config.mcp_servers.clear(); config.mcp_servers.clear();
config.refresh_mcp_servers(None).unwrap();
let factory = build_factory(config); let factory = build_factory(config);
let result = factory.create(); let result = factory.create();
@@ -156,7 +169,9 @@ mod tests {
args: Vec::new(), args: Vec::new(),
transport: "stdio".to_string(), transport: "stdio".to_string(),
env: std::collections::HashMap::new(), env: std::collections::HashMap::new(),
oauth: None,
}]; }];
config.refresh_mcp_servers(None).unwrap();
let factory = build_factory(config); let factory = build_factory(config);
let result = factory.create(); let result = factory.create();

View File

@@ -305,6 +305,7 @@ mod tests {
args: vec![], args: vec![],
transport: "http".to_string(), transport: "http".to_string(),
env: std::collections::HashMap::new(), env: std::collections::HashMap::new(),
oauth: None,
}; };
if let Ok(client) = RemoteMcpClient::new_with_config(&config) { if let Ok(client) = RemoteMcpClient::new_with_config(&config) {

View File

@@ -12,6 +12,7 @@ use anyhow::anyhow;
use futures::{StreamExt, future::BoxFuture, stream}; use futures::{StreamExt, future::BoxFuture, stream};
use reqwest::Client as HttpClient; use reqwest::Client as HttpClient;
use serde_json::json; use serde_json::json;
use std::collections::HashMap;
use std::path::Path; use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
@@ -39,6 +40,15 @@ pub struct RemoteMcpClient {
ws_endpoint: Option<String>, ws_endpoint: Option<String>,
// Incrementing request identifier. // Incrementing request identifier.
next_id: AtomicU64, next_id: AtomicU64,
// Optional HTTP header (name, value) injected into every request.
http_header: Option<(String, String)>,
}
/// Runtime secrets provided when constructing an MCP client.
#[derive(Debug, Default, Clone)]
pub struct McpRuntimeSecrets {
pub env_overrides: HashMap<String, String>,
pub http_header: Option<(String, String)>,
} }
impl RemoteMcpClient { impl RemoteMcpClient {
@@ -48,6 +58,14 @@ impl RemoteMcpClient {
/// Spawn an external MCP server based on a configuration entry. /// Spawn an external MCP server based on a configuration entry.
/// The server must communicate over STDIO (the only supported transport). /// The server must communicate over STDIO (the only supported transport).
pub fn new_with_config(config: &crate::config::McpServerConfig) -> Result<Self> { pub fn new_with_config(config: &crate::config::McpServerConfig) -> Result<Self> {
Self::new_with_runtime(config, None)
}
pub fn new_with_runtime(
config: &crate::config::McpServerConfig,
runtime: Option<McpRuntimeSecrets>,
) -> Result<Self> {
let mut runtime = runtime.unwrap_or_default();
let transport = config.transport.to_lowercase(); let transport = config.transport.to_lowercase();
match transport.as_str() { match transport.as_str() {
"stdio" => { "stdio" => {
@@ -64,6 +82,9 @@ impl RemoteMcpClient {
for (k, v) in config.env.iter() { for (k, v) in config.env.iter() {
cmd.env(k, v); cmd.env(k, v);
} }
for (k, v) in runtime.env_overrides.drain() {
cmd.env(k, v);
}
let mut child = cmd.spawn().map_err(|e| { let mut child = cmd.spawn().map_err(|e| {
Error::Io(std::io::Error::new( Error::Io(std::io::Error::new(
@@ -92,6 +113,7 @@ impl RemoteMcpClient {
ws_stream: None, ws_stream: None,
ws_endpoint: None, ws_endpoint: None,
next_id: AtomicU64::new(1), next_id: AtomicU64::new(1),
http_header: None,
}) })
} }
"http" => { "http" => {
@@ -109,6 +131,7 @@ impl RemoteMcpClient {
ws_stream: None, ws_stream: None,
ws_endpoint: None, ws_endpoint: None,
next_id: AtomicU64::new(1), next_id: AtomicU64::new(1),
http_header: runtime.http_header.take(),
}) })
} }
"websocket" => { "websocket" => {
@@ -132,6 +155,7 @@ impl RemoteMcpClient {
ws_stream: Some(Arc::new(Mutex::new(ws_stream))), ws_stream: Some(Arc::new(Mutex::new(ws_stream))),
ws_endpoint: Some(ws_url), ws_endpoint: Some(ws_url),
next_id: AtomicU64::new(1), next_id: AtomicU64::new(1),
http_header: runtime.http_header.take(),
}) })
} }
other => Err(Error::NotImplemented(format!( other => Err(Error::NotImplemented(format!(
@@ -171,6 +195,7 @@ impl RemoteMcpClient {
args: Vec::new(), args: Vec::new(),
transport: "stdio".to_string(), transport: "stdio".to_string(),
env: std::collections::HashMap::new(), env: std::collections::HashMap::new(),
oauth: None,
}; };
Self::new_with_config(&config) Self::new_with_config(&config)
} }
@@ -193,8 +218,11 @@ impl RemoteMcpClient {
.http_endpoint .http_endpoint
.as_ref() .as_ref()
.ok_or_else(|| Error::Network("Missing HTTP endpoint".into()))?; .ok_or_else(|| Error::Network("Missing HTTP endpoint".into()))?;
let resp = client let mut builder = client.post(endpoint);
.post(endpoint) if let Some((ref header_name, ref header_value)) = self.http_header {
builder = builder.header(header_name, header_value);
}
let resp = builder
.json(&request) .json(&request)
.send() .send()
.await .await

View File

@@ -0,0 +1,507 @@
use std::time::Duration as StdDuration;
use chrono::{DateTime, Duration, Utc};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::{Error, Result, config::McpOAuthConfig};
/// Persisted OAuth token set for MCP servers and providers.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub struct OAuthToken {
/// Bearer access token returned by the authorization server.
pub access_token: String,
/// Optional refresh token if the provider issues one.
#[serde(default)]
pub refresh_token: Option<String>,
/// Absolute UTC expiration timestamp for the access token.
#[serde(default)]
pub expires_at: Option<DateTime<Utc>>,
/// Optional space-delimited scope string supplied by the provider.
#[serde(default)]
pub scope: Option<String>,
/// Token type reported by the provider (typically `Bearer`).
#[serde(default)]
pub token_type: Option<String>,
}
impl OAuthToken {
/// Returns `true` if the access token has expired at the provided instant.
pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
matches!(self.expires_at, Some(expiry) if now >= expiry)
}
/// Returns `true` if the token will expire within the supplied duration window.
pub fn will_expire_within(&self, window: Duration, now: DateTime<Utc>) -> bool {
matches!(self.expires_at, Some(expiry) if expiry - now <= window)
}
}
/// Active device-authorization session details returned by the authorization server.
#[derive(Debug, Clone)]
pub struct DeviceAuthorization {
pub device_code: String,
pub user_code: String,
pub verification_uri: String,
pub verification_uri_complete: Option<String>,
pub expires_at: DateTime<Utc>,
pub interval: StdDuration,
pub message: Option<String>,
}
impl DeviceAuthorization {
pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
now >= self.expires_at
}
}
/// Result of polling the token endpoint during a device-authorization flow.
#[derive(Debug, Clone)]
pub enum DevicePollState {
Pending { retry_in: StdDuration },
Complete(OAuthToken),
}
pub struct OAuthClient {
http: Client,
config: McpOAuthConfig,
}
impl OAuthClient {
pub fn new(config: McpOAuthConfig) -> Result<Self> {
let http = Client::builder()
.user_agent("OwlenOAuth/1.0")
.build()
.map_err(|err| Error::Network(format!("Failed to construct HTTP client: {err}")))?;
Ok(Self { http, config })
}
fn scope_value(&self) -> Option<String> {
if self.config.scopes.is_empty() {
None
} else {
Some(self.config.scopes.join(" "))
}
}
fn token_request_base(&self) -> Vec<(String, String)> {
let mut params = vec![("client_id".to_string(), self.config.client_id.clone())];
if let Some(secret) = &self.config.client_secret {
params.push(("client_secret".to_string(), secret.clone()));
}
params
}
pub async fn start_device_authorization(&self) -> Result<DeviceAuthorization> {
let device_url = self
.config
.device_authorization_url
.as_ref()
.ok_or_else(|| {
Error::Config("Device authorization endpoint is not configured.".to_string())
})?;
let mut params = self.token_request_base();
if let Some(scope) = self.scope_value() {
params.push(("scope".to_string(), scope));
}
let response = self
.http
.post(device_url)
.form(&params)
.send()
.await
.map_err(|err| map_http_error("start device authorization", err))?;
let status = response.status();
let payload = response
.json::<DeviceAuthorizationResponse>()
.await
.map_err(|err| {
Error::Auth(format!(
"Failed to parse device authorization response (status {status}): {err}"
))
})?;
let expires_at =
Utc::now() + Duration::seconds(payload.expires_in.min(i64::MAX as u64) as i64);
let interval = StdDuration::from_secs(payload.interval.unwrap_or(5).max(1));
Ok(DeviceAuthorization {
device_code: payload.device_code,
user_code: payload.user_code,
verification_uri: payload.verification_uri,
verification_uri_complete: payload.verification_uri_complete,
expires_at,
interval,
message: payload.message,
})
}
pub async fn poll_device_token(&self, auth: &DeviceAuthorization) -> Result<DevicePollState> {
let mut params = self.token_request_base();
params.push(("grant_type".to_string(), DEVICE_CODE_GRANT.to_string()));
params.push(("device_code".to_string(), auth.device_code.clone()));
if let Some(scope) = self.scope_value() {
params.push(("scope".to_string(), scope));
}
let response = self
.http
.post(&self.config.token_url)
.form(&params)
.send()
.await
.map_err(|err| map_http_error("poll device token", err))?;
let status = response.status();
let text = response
.text()
.await
.map_err(|err| map_http_error("read token response", err))?;
if status.is_success() {
let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| {
Error::Auth(format!(
"Failed to parse OAuth token response: {err}; body: {text}"
))
})?;
return Ok(DevicePollState::Complete(oauth_token_from_response(
payload,
)));
}
let error = serde_json::from_str::<OAuthErrorResponse>(&text).unwrap_or_else(|_| {
OAuthErrorResponse {
error: "unknown_error".to_string(),
error_description: Some(text.clone()),
}
});
match error.error.as_str() {
"authorization_pending" => Ok(DevicePollState::Pending {
retry_in: auth.interval,
}),
"slow_down" => Ok(DevicePollState::Pending {
retry_in: auth.interval.saturating_add(StdDuration::from_secs(5)),
}),
"access_denied" => {
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
"User declined authorization".to_string()
})))
}
"expired_token" | "expired_device_code" => {
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
"Device authorization expired".to_string()
})))
}
other => Err(Error::Auth(
error
.error_description
.unwrap_or_else(|| format!("OAuth error: {other}")),
)),
}
}
pub async fn refresh_token(&self, refresh_token: &str) -> Result<OAuthToken> {
let mut params = self.token_request_base();
params.push(("grant_type".to_string(), "refresh_token".to_string()));
params.push(("refresh_token".to_string(), refresh_token.to_string()));
if let Some(scope) = self.scope_value() {
params.push(("scope".to_string(), scope));
}
let response = self
.http
.post(&self.config.token_url)
.form(&params)
.send()
.await
.map_err(|err| map_http_error("refresh OAuth token", err))?;
let status = response.status();
let text = response
.text()
.await
.map_err(|err| map_http_error("read refresh response", err))?;
if status.is_success() {
let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| {
Error::Auth(format!(
"Failed to parse OAuth refresh response: {err}; body: {text}"
))
})?;
Ok(oauth_token_from_response(payload))
} else {
let error = serde_json::from_str::<OAuthErrorResponse>(&text).unwrap_or_else(|_| {
OAuthErrorResponse {
error: "unknown_error".to_string(),
error_description: Some(text.clone()),
}
});
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
format!("OAuth token refresh failed: {}", error.error)
})))
}
}
}
const DEVICE_CODE_GRANT: &str = "urn:ietf:params:oauth:grant-type:device_code";
#[derive(Debug, Deserialize)]
struct DeviceAuthorizationResponse {
device_code: String,
user_code: String,
verification_uri: String,
#[serde(default)]
verification_uri_complete: Option<String>,
expires_in: u64,
#[serde(default)]
interval: Option<u64>,
#[serde(default)]
message: Option<String>,
}
#[derive(Debug, Deserialize)]
struct TokenResponse {
access_token: String,
#[serde(default)]
refresh_token: Option<String>,
#[serde(default)]
expires_in: Option<u64>,
#[serde(default)]
scope: Option<String>,
#[serde(default)]
token_type: Option<String>,
}
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
error: String,
#[serde(default)]
error_description: Option<String>,
}
fn oauth_token_from_response(payload: TokenResponse) -> OAuthToken {
let expires_at = payload
.expires_in
.map(|seconds| seconds.min(i64::MAX as u64) as i64)
.map(|seconds| Utc::now() + Duration::seconds(seconds));
OAuthToken {
access_token: payload.access_token,
refresh_token: payload.refresh_token,
expires_at,
scope: payload.scope,
token_type: payload.token_type,
}
}
fn map_http_error(action: &str, err: reqwest::Error) -> Error {
if err.is_timeout() {
Error::Timeout(format!("OAuth {action} request timed out: {err}"))
} else if err.is_connect() {
Error::Network(format!("OAuth {action} connection error: {err}"))
} else {
Error::Network(format!("OAuth {action} request failed: {err}"))
}
}
#[cfg(test)]
mod tests {
use super::*;
use httpmock::prelude::*;
use serde_json::json;
fn config_for(server: &MockServer) -> McpOAuthConfig {
McpOAuthConfig {
client_id: "test-client".to_string(),
client_secret: None,
authorize_url: server.url("/authorize"),
token_url: server.url("/token"),
device_authorization_url: Some(server.url("/device")),
redirect_url: None,
scopes: vec!["repo".to_string(), "user".to_string()],
token_env: None,
header: None,
header_prefix: None,
}
}
fn sample_device_authorization() -> DeviceAuthorization {
DeviceAuthorization {
device_code: "device-123".to_string(),
user_code: "ABCD-EFGH".to_string(),
verification_uri: "https://example.test/activate".to_string(),
verification_uri_complete: Some(
"https://example.test/activate?user_code=ABCD-EFGH".to_string(),
),
expires_at: Utc::now() + Duration::minutes(10),
interval: StdDuration::from_secs(5),
message: Some("Open the verification URL and enter the code.".to_string()),
}
}
#[tokio::test]
async fn start_device_authorization_returns_payload() {
let server = MockServer::start_async().await;
let device_mock = server
.mock_async(|when, then| {
when.method(POST).path("/device");
then.status(200)
.header("content-type", "application/json")
.json_body(json!({
"device_code": "device-123",
"user_code": "ABCD-EFGH",
"verification_uri": "https://example.test/activate",
"verification_uri_complete": "https://example.test/activate?user_code=ABCD-EFGH",
"expires_in": 600,
"interval": 7,
"message": "Open the verification URL and enter the code."
}));
})
.await;
let client = OAuthClient::new(config_for(&server)).expect("client");
let auth = client
.start_device_authorization()
.await
.expect("device authorization payload");
assert_eq!(auth.user_code, "ABCD-EFGH");
assert_eq!(auth.interval, StdDuration::from_secs(7));
assert!(auth.expires_at > Utc::now());
device_mock.assert_async().await;
}
#[tokio::test]
async fn poll_device_token_reports_pending() {
let server = MockServer::start_async().await;
let pending = server
.mock_async(|when, then| {
when.method(POST)
.path("/token")
.body_contains(
"grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code",
)
.body_contains("device_code=device-123");
then.status(400)
.header("content-type", "application/json")
.json_body(json!({
"error": "authorization_pending"
}));
})
.await;
let config = config_for(&server);
let client = OAuthClient::new(config).expect("client");
let auth = sample_device_authorization();
let result = client.poll_device_token(&auth).await.expect("poll result");
match result {
DevicePollState::Pending { retry_in } => {
assert_eq!(retry_in, StdDuration::from_secs(5));
}
other => panic!("expected pending state, got {other:?}"),
}
pending.assert_async().await;
}
#[tokio::test]
async fn poll_device_token_applies_slow_down_backoff() {
let server = MockServer::start_async().await;
let slow = server
.mock_async(|when, then| {
when.method(POST).path("/token");
then.status(400)
.header("content-type", "application/json")
.json_body(json!({
"error": "slow_down"
}));
})
.await;
let config = config_for(&server);
let client = OAuthClient::new(config).expect("client");
let auth = sample_device_authorization();
let result = client.poll_device_token(&auth).await.expect("poll result");
match result {
DevicePollState::Pending { retry_in } => {
assert_eq!(retry_in, StdDuration::from_secs(10));
}
other => panic!("expected pending state, got {other:?}"),
}
slow.assert_async().await;
}
#[tokio::test]
async fn poll_device_token_returns_token_when_authorized() {
let server = MockServer::start_async().await;
let token = server
.mock_async(|when, then| {
when.method(POST).path("/token");
then.status(200)
.header("content-type", "application/json")
.json_body(json!({
"access_token": "token-abc",
"refresh_token": "refresh-xyz",
"expires_in": 3600,
"token_type": "Bearer",
"scope": "repo user"
}));
})
.await;
let config = config_for(&server);
let client = OAuthClient::new(config).expect("client");
let auth = sample_device_authorization();
let result = client.poll_device_token(&auth).await.expect("poll result");
let token_info = match result {
DevicePollState::Complete(token) => token,
other => panic!("expected completion, got {other:?}"),
};
assert_eq!(token_info.access_token, "token-abc");
assert_eq!(token_info.refresh_token.as_deref(), Some("refresh-xyz"));
assert!(token_info.expires_at.is_some());
token.assert_async().await;
}
#[tokio::test]
async fn refresh_token_roundtrip() {
let server = MockServer::start_async().await;
let refresh = server
.mock_async(|when, then| {
when.method(POST)
.path("/token")
.body_contains("grant_type=refresh_token")
.body_contains("refresh_token=old-refresh");
then.status(200)
.header("content-type", "application/json")
.json_body(json!({
"access_token": "token-new",
"refresh_token": "refresh-new",
"expires_in": 1200,
"token_type": "Bearer"
}));
})
.await;
let config = config_for(&server);
let client = OAuthClient::new(config).expect("client");
let token = client
.refresh_token("old-refresh")
.await
.expect("refresh response");
assert_eq!(token.access_token, "token-new");
assert_eq!(token.refresh_token.as_deref(), Some("refresh-new"));
assert!(token.expires_at.is_some());
refresh.assert_async().await;
}
}

View File

@@ -1,4 +1,4 @@
use crate::config::Config; use crate::config::{Config, McpResourceConfig, McpServerConfig};
use crate::consent::ConsentManager; use crate::consent::ConsentManager;
use crate::conversation::ConversationManager; use crate::conversation::ConversationManager;
use crate::credentials::CredentialManager; use crate::credentials::CredentialManager;
@@ -9,8 +9,10 @@ use crate::mcp::McpToolCall;
use crate::mcp::client::McpClient; use crate::mcp::client::McpClient;
use crate::mcp::factory::McpClientFactory; use crate::mcp::factory::McpClientFactory;
use crate::mcp::permission::PermissionLayer; use crate::mcp::permission::PermissionLayer;
use crate::mcp::remote_client::{McpRuntimeSecrets, RemoteMcpClient};
use crate::mode::Mode; use crate::mode::Mode;
use crate::model::{DetailedModelInfo, ModelManager}; use crate::model::{DetailedModelInfo, ModelManager};
use crate::oauth::{DeviceAuthorization, DevicePollState, OAuthClient};
use crate::providers::OllamaProvider; use crate::providers::OllamaProvider;
use crate::storage::{SessionMeta, StorageManager}; use crate::storage::{SessionMeta, StorageManager};
use crate::types::{ use crate::types::{
@@ -24,8 +26,10 @@ use crate::{
ToolRegistry, WebScrapeTool, WebSearchDetailedTool, WebSearchTool, ToolRegistry, WebScrapeTool, WebSearchDetailedTool, WebSearchTool,
}; };
use crate::{Error, Result}; use crate::{Error, Result};
use chrono::Utc;
use log::warn; use log::warn;
use serde_json::Value; use serde_json::{Value, json};
use std::collections::HashMap;
use std::env; use std::env;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
@@ -96,6 +100,7 @@ pub struct SessionController {
tool_registry: Arc<ToolRegistry>, tool_registry: Arc<ToolRegistry>,
schema_validator: Arc<SchemaValidator>, schema_validator: Arc<SchemaValidator>,
mcp_client: Arc<dyn McpClient>, mcp_client: Arc<dyn McpClient>,
named_mcp_clients: HashMap<String, Arc<dyn McpClient>>,
storage: Arc<StorageManager>, storage: Arc<StorageManager>,
vault: Option<Arc<Mutex<VaultHandle>>>, vault: Option<Arc<Mutex<VaultHandle>>>,
master_key: Option<Arc<Vec<u8>>>, master_key: Option<Arc<Vec<u8>>>,
@@ -103,6 +108,7 @@ pub struct SessionController {
ui: Arc<dyn UiController>, ui: Arc<dyn UiController>,
enable_code_tools: bool, enable_code_tools: bool,
current_mode: Mode, current_mode: Mode,
missing_oauth_servers: Vec<String>,
} }
async fn build_tools( async fn build_tools(
@@ -211,6 +217,112 @@ async fn build_tools(
} }
impl SessionController { impl SessionController {
async fn create_mcp_clients(
config: Arc<TokioMutex<Config>>,
tool_registry: Arc<ToolRegistry>,
schema_validator: Arc<SchemaValidator>,
credential_manager: Option<Arc<CredentialManager>>,
initial_mode: Mode,
) -> Result<(
Arc<dyn McpClient>,
HashMap<String, Arc<dyn McpClient>>,
Vec<String>,
)> {
let guard = config.lock().await;
let config_arc = Arc::new(guard.clone());
let factory = McpClientFactory::new(config_arc.clone(), tool_registry, schema_validator);
let mut missing_oauth_servers = Vec::new();
let primary_runtime = if let Some(primary_cfg) = guard.effective_mcp_servers().first() {
let (runtime, missing) =
Self::runtime_secrets_for_server(credential_manager.clone(), primary_cfg).await?;
if missing {
missing_oauth_servers.push(primary_cfg.name.clone());
}
runtime
} else {
None
};
let base_client = factory.create_with_secrets(primary_runtime)?;
let primary: Arc<dyn McpClient> =
Arc::new(PermissionLayer::new(base_client, config_arc.clone()));
primary.set_mode(initial_mode).await?;
let mut clients: HashMap<String, Arc<dyn McpClient>> = HashMap::new();
if let Some(primary_cfg) = guard.effective_mcp_servers().first() {
clients.insert(primary_cfg.name.clone(), Arc::clone(&primary));
}
for server_cfg in guard.effective_mcp_servers().iter().skip(1) {
let (runtime, missing) =
Self::runtime_secrets_for_server(credential_manager.clone(), server_cfg).await?;
if missing {
missing_oauth_servers.push(server_cfg.name.clone());
}
match RemoteMcpClient::new_with_runtime(server_cfg, runtime) {
Ok(remote) => {
let client: Arc<dyn McpClient> =
Arc::new(PermissionLayer::new(Box::new(remote), config_arc.clone()));
if let Err(err) = client.set_mode(initial_mode).await {
warn!(
"Failed to initialize MCP server '{}' in mode {:?}: {}",
server_cfg.name, initial_mode, err
);
}
clients.insert(server_cfg.name.clone(), Arc::clone(&client));
}
Err(err) => warn!(
"Failed to initialize MCP server '{}': {}",
server_cfg.name, err
),
}
}
drop(guard);
Ok((primary, clients, missing_oauth_servers))
}
async fn runtime_secrets_for_server(
credential_manager: Option<Arc<CredentialManager>>,
server: &McpServerConfig,
) -> Result<(Option<McpRuntimeSecrets>, bool)> {
if let Some(oauth) = &server.oauth {
if let Some(manager) = credential_manager {
match manager.load_oauth_token(&server.name).await? {
Some(token) => {
if token.access_token.trim().is_empty() || token.is_expired(Utc::now()) {
return Ok((None, true));
}
let mut secrets = McpRuntimeSecrets::default();
if let Some(env_name) = oauth.token_env.as_deref() {
secrets
.env_overrides
.insert(env_name.to_string(), token.access_token.clone());
}
if matches!(
server.transport.to_ascii_lowercase().as_str(),
"http" | "websocket"
) {
let header_value =
format!("{}{}", oauth.header_prefix(), token.access_token);
secrets.http_header =
Some((oauth.header_name().to_string(), header_value));
}
Ok((Some(secrets), false))
}
None => Ok((None, true)),
}
} else {
Ok((None, true))
}
} else {
Ok((None, false))
}
}
pub async fn new( pub async fn new(
provider: Arc<dyn Provider>, provider: Arc<dyn Provider>,
config: Config, config: Config,
@@ -292,19 +404,14 @@ impl SessionController {
) )
.await?; .await?;
// Create MCP client with permission layer let (mcp_client, named_mcp_clients, missing_oauth_servers) = Self::create_mcp_clients(
let mcp_client: Arc<dyn McpClient> = { config_arc.clone(),
let guard = config_arc.lock().await; tool_registry.clone(),
let factory = McpClientFactory::new( schema_validator.clone(),
Arc::new(guard.clone()), credential_manager.clone(),
tool_registry.clone(), initial_mode,
schema_validator.clone(), )
); .await?;
let base_client = factory.create()?;
let client = Arc::new(PermissionLayer::new(base_client, Arc::new(guard.clone())));
client.set_mode(initial_mode).await?;
client
};
Ok(Self { Ok(Self {
provider, provider,
@@ -317,6 +424,7 @@ impl SessionController {
tool_registry, tool_registry,
schema_validator, schema_validator,
mcp_client, mcp_client,
named_mcp_clients,
storage, storage,
vault: vault_handle, vault: vault_handle,
master_key, master_key,
@@ -324,6 +432,7 @@ impl SessionController {
ui, ui,
enable_code_tools, enable_code_tools,
current_mode: initial_mode, current_mode: initial_mode,
missing_oauth_servers,
}) })
} }
@@ -355,6 +464,63 @@ impl SessionController {
self.formatter.set_role_label_mode(mode); self.formatter.set_role_label_mode(mode);
} }
/// Return the configured resource references aggregated across scopes.
pub async fn configured_resources(&self) -> Vec<McpResourceConfig> {
let guard = self.config.lock().await;
guard.effective_mcp_resources().to_vec()
}
/// Resolve a resource reference of the form `server:uri` (optionally prefixed with `@`).
pub async fn resolve_resource_reference(&self, reference: &str) -> Result<Option<String>> {
let (server, uri) = match Self::split_resource_reference(reference) {
Some(parts) => parts,
None => return Ok(None),
};
let resource_defined = {
let guard = self.config.lock().await;
guard.find_resource(&server, &uri).is_some()
};
if !resource_defined {
return Ok(None);
}
let client = self
.named_mcp_clients
.get(&server)
.cloned()
.ok_or_else(|| {
Error::Config(format!(
"MCP server '{}' referenced by resource '{}' is not available",
server, uri
))
})?;
let call = McpToolCall {
name: "resources/get".to_string(),
arguments: json!({ "uri": uri, "path": uri }),
};
let response = client.call_tool(call).await?;
if let Some(text) = extract_resource_content(&response.output) {
return Ok(Some(text));
}
let formatted = serde_json::to_string_pretty(&response.output)
.unwrap_or_else(|_| response.output.to_string());
Ok(Some(formatted))
}
fn split_resource_reference(reference: &str) -> Option<(String, String)> {
let trimmed = reference.trim();
let without_prefix = trimmed.strip_prefix('@').unwrap_or(trimmed);
let (server, uri) = without_prefix.split_once(':')?;
if server.is_empty() || uri.is_empty() {
return None;
}
Some((server.to_string(), uri.to_string()))
}
// Asynchronous access to the configuration (used internally). // Asynchronous access to the configuration (used internally).
pub async fn config_async(&self) -> tokio::sync::MutexGuard<'_, Config> { pub async fn config_async(&self) -> tokio::sync::MutexGuard<'_, Config> {
self.config.lock().await self.config.lock().await
@@ -378,6 +544,21 @@ impl SessionController {
self.config.clone() self.config.clone()
} }
pub async fn reload_mcp_clients(&mut self) -> Result<()> {
let (primary, named, missing) = Self::create_mcp_clients(
self.config.clone(),
self.tool_registry.clone(),
self.schema_validator.clone(),
self.credential_manager.clone(),
self.current_mode,
)
.await?;
self.mcp_client = primary;
self.named_mcp_clients = named;
self.missing_oauth_servers = missing;
Ok(())
}
pub fn grant_consent(&self, tool_name: &str, data_types: Vec<String>, endpoints: Vec<String>) { pub fn grant_consent(&self, tool_name: &str, data_types: Vec<String>, endpoints: Vec<String>) {
let mut consent = self let mut consent = self
.consent_manager .consent_manager
@@ -525,6 +706,115 @@ impl SessionController {
self.schema_validator.clone() self.schema_validator.clone()
} }
pub fn credential_manager(&self) -> Option<Arc<CredentialManager>> {
self.credential_manager.clone()
}
pub fn pending_oauth_servers(&self) -> Vec<String> {
self.missing_oauth_servers.clone()
}
pub async fn start_oauth_device_flow(&self, server: &str) -> Result<DeviceAuthorization> {
let oauth_config = {
let config = self.config.lock().await;
let server_cfg = config
.effective_mcp_servers()
.iter()
.find(|entry| entry.name == server)
.ok_or_else(|| {
Error::Config(format!("No MCP server named '{server}' is configured"))
})?;
server_cfg.oauth.clone().ok_or_else(|| {
Error::Config(format!(
"MCP server '{server}' does not define an OAuth configuration"
))
})?
};
let client = OAuthClient::new(oauth_config)?;
client.start_device_authorization().await
}
pub async fn poll_oauth_device_flow(
&mut self,
server: &str,
authorization: &DeviceAuthorization,
) -> Result<DevicePollState> {
let oauth_config = {
let config = self.config.lock().await;
let server_cfg = config
.effective_mcp_servers()
.iter()
.find(|entry| entry.name == server)
.ok_or_else(|| {
Error::Config(format!("No MCP server named '{server}' is configured"))
})?;
server_cfg.oauth.clone().ok_or_else(|| {
Error::Config(format!(
"MCP server '{server}' does not define an OAuth configuration"
))
})?
};
let client = OAuthClient::new(oauth_config)?;
match client.poll_device_token(authorization).await? {
DevicePollState::Pending { retry_in } => Ok(DevicePollState::Pending { retry_in }),
DevicePollState::Complete(token) => {
let manager = self.credential_manager.as_ref().cloned().ok_or_else(|| {
Error::Config(
"OAuth token storage requires encrypted local data; set \
privacy.encrypt_local_data = true in the configuration."
.to_string(),
)
})?;
manager.store_oauth_token(server, &token).await?;
self.missing_oauth_servers.retain(|entry| entry != server);
Ok(DevicePollState::Complete(token))
}
}
}
pub async fn list_mcp_tools(&self) -> Vec<(String, crate::mcp::McpToolDescriptor)> {
let mut entries = Vec::new();
for (server, client) in self.named_mcp_clients.iter() {
let server_name = server.clone();
let client = Arc::clone(client);
match client.list_tools().await {
Ok(tools) => {
for descriptor in tools {
entries.push((server_name.clone(), descriptor));
}
}
Err(err) => {
warn!(
"Failed to list tools for MCP server '{}': {}",
server_name, err
);
}
}
}
entries
}
pub async fn call_mcp_tool(
&self,
server: &str,
tool: &str,
arguments: Value,
) -> Result<crate::mcp::McpToolResponse> {
let client = self.named_mcp_clients.get(server).cloned().ok_or_else(|| {
Error::Config(format!("No MCP server named '{}' is registered", server))
})?;
client
.call_tool(McpToolCall {
name: tool.to_string(),
arguments,
})
.await
}
pub fn mcp_server(&self) -> crate::mcp::McpServer { pub fn mcp_server(&self) -> crate::mcp::McpServer {
crate::mcp::McpServer::new(self.tool_registry(), self.schema_validator()) crate::mcp::McpServer::new(self.tool_registry(), self.schema_validator())
} }
@@ -985,3 +1275,195 @@ impl SessionController {
Ok("Empty conversation".to_string()) Ok("Empty conversation".to_string())
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::Provider;
use crate::config::{Config, McpMode, McpOAuthConfig, McpServerConfig};
use crate::llm::test_utils::MockProvider;
use crate::storage::StorageManager;
use crate::ui::NoOpUiController;
use chrono::Utc;
use httpmock::prelude::*;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use tempfile::tempdir;
const SERVER_NAME: &str = "oauth-test";
fn build_oauth_config(server: &MockServer) -> McpOAuthConfig {
McpOAuthConfig {
client_id: "owlen-client".to_string(),
client_secret: None,
authorize_url: server.url("/authorize"),
token_url: server.url("/token"),
device_authorization_url: Some(server.url("/device")),
redirect_url: None,
scopes: vec!["repo".to_string()],
token_env: Some("OAUTH_TOKEN".to_string()),
header: Some("Authorization".to_string()),
header_prefix: Some("Bearer ".to_string()),
}
}
fn build_config(server: &MockServer) -> Config {
let mut config = Config::default();
config.mcp.mode = McpMode::LocalOnly;
let oauth = build_oauth_config(server);
let mut env = HashMap::new();
env.insert("OWLEN_ENV".to_string(), "test".to_string());
config.mcp_servers = vec![McpServerConfig {
name: SERVER_NAME.to_string(),
command: server.url("/mcp"),
args: Vec::new(),
transport: "http".to_string(),
env,
oauth: Some(oauth),
}];
config.refresh_mcp_servers(None).unwrap();
config
}
async fn build_session(server: &MockServer) -> (SessionController, tempfile::TempDir) {
unsafe {
std::env::set_var("OWLEN_MASTER_PASSWORD", "test-password");
}
let temp_dir = tempdir().expect("tempdir");
let storage_path = temp_dir.path().join("owlen.db");
let storage = Arc::new(
StorageManager::with_database_path(storage_path)
.await
.expect("storage"),
);
let config = build_config(server);
let provider: Arc<dyn Provider> = Arc::new(MockProvider::default()) as Arc<dyn Provider>;
let ui = Arc::new(NoOpUiController);
let session = SessionController::new(provider, config, storage, ui, false)
.await
.expect("session");
(session, temp_dir)
}
#[tokio::test]
async fn start_oauth_device_flow_returns_details() {
let server = MockServer::start_async().await;
let device = server
.mock_async(|when, then| {
when.method(POST).path("/device");
then.status(200)
.header("content-type", "application/json")
.json_body(json!({
"device_code": "device-abc",
"user_code": "ABCD-1234",
"verification_uri": "https://example.test/activate",
"verification_uri_complete": "https://example.test/activate?user_code=ABCD-1234",
"expires_in": 600,
"interval": 5,
"message": "Enter the code to continue."
}));
})
.await;
let (session, _dir) = build_session(&server).await;
let authorization = session
.start_oauth_device_flow(SERVER_NAME)
.await
.expect("device flow");
assert_eq!(authorization.user_code, "ABCD-1234");
assert_eq!(
authorization.verification_uri_complete.as_deref(),
Some("https://example.test/activate?user_code=ABCD-1234")
);
assert!(authorization.expires_at > Utc::now());
device.assert_async().await;
}
#[tokio::test]
async fn poll_oauth_device_flow_stores_token_and_updates_state() {
let server = MockServer::start_async().await;
let device = server
.mock_async(|when, then| {
when.method(POST).path("/device");
then.status(200)
.header("content-type", "application/json")
.json_body(json!({
"device_code": "device-xyz",
"user_code": "WXYZ-9999",
"verification_uri": "https://example.test/activate",
"verification_uri_complete": "https://example.test/activate?user_code=WXYZ-9999",
"expires_in": 600,
"interval": 5
}));
})
.await;
let token = server
.mock_async(|when, then| {
when.method(POST)
.path("/token")
.body_contains("device_code=device-xyz");
then.status(200)
.header("content-type", "application/json")
.json_body(json!({
"access_token": "new-access-token",
"refresh_token": "refresh-token",
"expires_in": 3600,
"token_type": "Bearer"
}));
})
.await;
let (mut session, _dir) = build_session(&server).await;
assert_eq!(session.pending_oauth_servers(), vec![SERVER_NAME]);
let authorization = session
.start_oauth_device_flow(SERVER_NAME)
.await
.expect("device flow");
match session
.poll_oauth_device_flow(SERVER_NAME, &authorization)
.await
.expect("token poll")
{
DevicePollState::Complete(token_info) => {
assert_eq!(token_info.access_token, "new-access-token");
assert_eq!(token_info.refresh_token.as_deref(), Some("refresh-token"));
}
other => panic!("expected token completion, got {other:?}"),
}
assert!(
session
.pending_oauth_servers()
.iter()
.all(|entry| entry != SERVER_NAME),
"server should be removed from pending list"
);
let stored = session
.credential_manager()
.expect("credential manager")
.load_oauth_token(SERVER_NAME)
.await
.expect("load token")
.expect("token present");
assert_eq!(stored.access_token, "new-access-token");
assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
device.assert_async().await;
token.assert_async().await;
}
}

View File

@@ -44,6 +44,7 @@ async fn test_render_prompt_via_external_server() -> Result<()> {
args: Vec::new(), args: Vec::new(),
transport: "stdio".into(), transport: "stdio".into(),
env: std::collections::HashMap::new(), env: std::collections::HashMap::new(),
oauth: None,
}; };
let client = match RemoteMcpClient::new_with_config(&config) { let client = match RemoteMcpClient::new_with_config(&config) {

View File

@@ -5,6 +5,7 @@
//! crates can depend only on `owlen-mcp-client` without pulling in the entire //! crates can depend only on `owlen-mcp-client` without pulling in the entire
//! core crate internals. //! core crate internals.
pub use owlen_core::config::{McpConfigScope, ScopedMcpServer};
pub use owlen_core::mcp::remote_client::RemoteMcpClient; pub use owlen_core::mcp::remote_client::RemoteMcpClient;
pub use owlen_core::mcp::{McpClient, McpToolCall, McpToolDescriptor, McpToolResponse}; pub use owlen_core::mcp::{McpClient, McpToolCall, McpToolDescriptor, McpToolResponse};

View File

@@ -1,10 +1,13 @@
use anyhow::{Context, Result, anyhow}; use anyhow::{Context, Result, anyhow};
use chrono::{DateTime, Local}; use chrono::{DateTime, Local, Utc};
use crossterm::terminal::{disable_raw_mode, enable_raw_mode}; use crossterm::terminal::{disable_raw_mode, enable_raw_mode};
use owlen_core::mcp::remote_client::RemoteMcpClient; use owlen_core::mcp::remote_client::RemoteMcpClient;
use owlen_core::mcp::{McpToolDescriptor, McpToolResponse};
use owlen_core::{ use owlen_core::{
Provider, ProviderConfig, Provider, ProviderConfig,
config::McpResourceConfig,
model::DetailedModelInfo, model::DetailedModelInfo,
oauth::{DeviceAuthorization, DevicePollState},
session::{SessionController, SessionOutcome}, session::{SessionController, SessionOutcome},
storage::SessionMeta, storage::SessionMeta,
theme::Theme, theme::Theme,
@@ -19,7 +22,7 @@ use tokio::{
sync::mpsc, sync::mpsc,
task::{self, JoinHandle}, task::{self, JoinHandle},
}; };
use tui_textarea::{Input, TextArea}; use tui_textarea::{CursorMove, Input, TextArea};
use unicode_width::UnicodeWidthStr; use unicode_width::UnicodeWidthStr;
use uuid::Uuid; use uuid::Uuid;
@@ -27,12 +30,14 @@ use crate::commands;
use crate::config; use crate::config;
use crate::events::Event; use crate::events::Event;
use crate::model_info_panel::ModelInfoPanel; use crate::model_info_panel::ModelInfoPanel;
use crate::slash::{self, McpSlashCommand, SlashCommand};
use crate::state::{ use crate::state::{
CodeWorkspace, CommandPalette, FileFilterMode, FileIconResolver, FileNode, FileTreeState, CodeWorkspace, CommandPalette, FileFilterMode, FileIconResolver, FileNode, FileTreeState,
ModelPaletteEntry, PaletteSuggestion, PaneDirection, PaneRestoreRequest, RepoSearchMessage, ModelPaletteEntry, PaletteSuggestion, PaneDirection, PaneRestoreRequest, RepoSearchMessage,
RepoSearchState, SplitAxis, SymbolSearchMessage, SymbolSearchState, WorkspaceSnapshot, RepoSearchState, SplitAxis, SymbolSearchMessage, SymbolSearchState, WorkspaceSnapshot,
spawn_repo_search_task, spawn_symbol_search_task, spawn_repo_search_task, spawn_symbol_search_task,
}; };
use crate::toast::{Toast, ToastLevel, ToastManager};
use crate::ui::format_tool_output; use crate::ui::format_tool_output;
// Agent executor moved to separate binary `owlen-agent`. The TUI no longer directly // Agent executor moved to separate binary `owlen-agent`. The TUI no longer directly
// imports `AgentExecutor` to avoid a circular dependency on `owlen-cli`. // imports `AgentExecutor` to avoid a circular dependency on `owlen-cli`.
@@ -48,6 +53,7 @@ use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime}; use std::time::{Duration, Instant, SystemTime};
use dirs::{config_dir, data_local_dir}; use dirs::{config_dir, data_local_dir};
use serde_json::{Value, json};
const ONBOARDING_STATUS_LINE: &str = const ONBOARDING_STATUS_LINE: &str =
"Welcome to Owlen! Press F1 for help or type :tutorial for keybinding tips."; "Welcome to Owlen! Press F1 for help or type :tutorial for keybinding tips.";
@@ -61,6 +67,13 @@ const RESIZE_DOUBLE_TAP_WINDOW: Duration = Duration::from_millis(450);
const RESIZE_STEP: f32 = 0.05; const RESIZE_STEP: f32 = 0.05;
const RESIZE_SNAP_VALUES: [f32; 3] = [0.5, 0.75, 0.25]; const RESIZE_SNAP_VALUES: [f32; 3] = [0.5, 0.75, 0.25];
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SlashOutcome {
NotCommand,
Consumed,
Error,
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub(crate) struct ModelSelectorItem { pub(crate) struct ModelSelectorItem {
kind: ModelSelectorItemKind, kind: ModelSelectorItemKind,
@@ -158,6 +171,11 @@ pub enum SessionEvent {
AgentCompleted { answer: String }, AgentCompleted { answer: String },
/// Agent execution failed /// Agent execution failed
AgentFailed { error: String }, AgentFailed { error: String },
/// Poll the OAuth device authorization flow for the given server
OAuthPoll {
server: String,
authorization: DeviceAuthorization,
},
} }
pub const HELP_TAB_COUNT: usize = 7; pub const HELP_TAB_COUNT: usize = 7;
@@ -205,6 +223,9 @@ pub struct ChatApp {
clipboard: String, // Vim-style clipboard for yank/paste clipboard: String, // Vim-style clipboard for yank/paste
pending_file_action: Option<FileActionPrompt>, // Active file action prompt pending_file_action: Option<FileActionPrompt>, // Active file action prompt
command_palette: CommandPalette, // Command mode state (buffer + suggestions) command_palette: CommandPalette, // Command mode state (buffer + suggestions)
resource_catalog: Vec<McpResourceConfig>, // Configured MCP resources for autocompletion
pending_resource_refs: Vec<String>, // Resource references to resolve before send
oauth_flows: HashMap<String, DeviceAuthorization>, // Active OAuth device flows by server
repo_search: RepoSearchState, // Repository search overlay state repo_search: RepoSearchState, // Repository search overlay state
repo_search_task: Option<JoinHandle<()>>, repo_search_task: Option<JoinHandle<()>>,
repo_search_rx: Option<mpsc::UnboundedReceiver<RepoSearchMessage>>, repo_search_rx: Option<mpsc::UnboundedReceiver<RepoSearchMessage>>,
@@ -235,6 +256,7 @@ pub struct ChatApp {
selected_theme_index: usize, // Index of selected theme in browser selected_theme_index: usize, // Index of selected theme in browser
pending_consent: Option<ConsentDialogState>, // Pending consent request pending_consent: Option<ConsentDialogState>, // Pending consent request
system_status: String, // System/status messages (tool execution, status, etc) system_status: String, // System/status messages (tool execution, status, etc)
toasts: ToastManager,
/// Simple execution budget: maximum number of tool calls allowed per session. /// Simple execution budget: maximum number of tool calls allowed per session.
_execution_budget: usize, _execution_budget: usize,
/// Agent mode enabled /// Agent mode enabled
@@ -438,6 +460,9 @@ impl ChatApp {
clipboard: String::new(), clipboard: String::new(),
pending_file_action: None, pending_file_action: None,
command_palette: CommandPalette::new(), command_palette: CommandPalette::new(),
resource_catalog: Vec::new(),
pending_resource_refs: Vec::new(),
oauth_flows: HashMap::new(),
repo_search: RepoSearchState::new(), repo_search: RepoSearchState::new(),
repo_search_task: None, repo_search_task: None,
repo_search_rx: None, repo_search_rx: None,
@@ -472,6 +497,7 @@ impl ChatApp {
} else { } else {
String::new() String::new()
}, },
toasts: ToastManager::new(),
_execution_budget: 50, _execution_budget: 50,
agent_mode: false, agent_mode: false,
agent_running: false, agent_running: false,
@@ -490,6 +516,8 @@ impl ChatApp {
)); ));
app.update_command_palette_catalog(); app.update_command_palette_catalog();
app.refresh_resource_catalog().await?;
app.refresh_mcp_slash_commands().await?;
if let Err(err) = app.restore_workspace_layout().await { if let Err(err) = app.restore_workspace_layout().await {
eprintln!("Warning: failed to restore workspace layout: {err}"); eprintln!("Warning: failed to restore workspace layout: {err}");
@@ -1371,6 +1399,18 @@ impl ChatApp {
&self.theme &self.theme
} }
pub fn toasts(&self) -> impl Iterator<Item = &Toast> {
self.toasts.iter()
}
pub fn push_toast(&mut self, level: ToastLevel, message: impl Into<String>) {
self.toasts.push(message, level);
}
fn prune_toasts(&mut self) {
self.toasts.retain_active();
}
pub fn input_max_rows(&self) -> u16 { pub fn input_max_rows(&self) -> u16 {
let config = self.controller.config(); let config = self.controller.config();
config.ui.input_max_rows.max(1) config.ui.input_max_rows.max(1)
@@ -1443,6 +1483,304 @@ impl ChatApp {
.update_dynamic_sources(models, providers); .update_dynamic_sources(models, providers);
} }
async fn refresh_resource_catalog(&mut self) -> Result<()> {
let mut resources = self.controller.configured_resources().await;
resources.sort_by(|a, b| a.server.cmp(&b.server).then(a.uri.cmp(&b.uri)));
self.resource_catalog = resources;
Ok(())
}
async fn refresh_mcp_slash_commands(&mut self) -> Result<()> {
let mut commands = Vec::new();
for (server, descriptor) in self.controller.list_mcp_tools().await {
if !Self::tool_supports_slash(&descriptor) {
continue;
}
let description = if descriptor.description.trim().is_empty() {
None
} else {
Some(descriptor.description.clone())
};
commands.push(McpSlashCommand::new(
server,
descriptor.name.clone(),
description,
));
}
slash::set_mcp_commands(commands);
Ok(())
}
fn tool_supports_slash(descriptor: &McpToolDescriptor) -> bool {
if descriptor.name.trim().is_empty() {
return false;
}
Self::tool_allows_empty_arguments(&descriptor.input_schema)
}
fn tool_allows_empty_arguments(schema: &Value) -> bool {
match schema {
Value::Object(map) => {
if let Some(Value::Array(required)) = map.get("required") {
!required
.iter()
.any(|entry| entry.as_str().is_some_and(|s| !s.is_empty()))
} else {
true
}
}
_ => true,
}
}
fn format_mcp_slash_message(server: &str, tool: &str, response: &McpToolResponse) -> String {
let status = if response.success { "" } else { "" };
let payload = if response.success {
Self::extract_mcp_primary_text(&response.output)
} else {
Self::extract_mcp_error(&response.output)
.or_else(|| Self::extract_mcp_primary_text(&response.output))
}
.unwrap_or_else(|| Self::pretty_print_value(&response.output));
if payload.trim().is_empty() {
return format!("MCP {server}::{tool} {status}");
}
if payload.contains('\n') {
format!("MCP {server}::{tool} {status}\n```json\n{payload}\n```")
} else {
format!("MCP {server}::{tool} {status}\n{payload}")
}
}
fn extract_mcp_primary_text(value: &Value) -> Option<String> {
if let Some(text) = value.as_str().filter(|text| !text.trim().is_empty()) {
return Some(text.to_string());
}
if let Value::Object(map) = value {
const CANDIDATES: [&str; 6] =
["rendered", "text", "content", "value", "message", "body"];
for key in CANDIDATES {
if let Some(Value::String(text)) = map.get(key)
&& !text.trim().is_empty()
{
return Some(text.clone());
}
}
if let Some(Value::Array(items)) = map.get("lines") {
let mut collected = Vec::new();
for item in items {
if let Some(segment) = item.as_str()
&& !segment.trim().is_empty()
{
collected.push(segment.trim());
}
}
if !collected.is_empty() {
return Some(collected.join("\n"));
}
}
}
None
}
fn extract_mcp_error(value: &Value) -> Option<String> {
if let Value::Object(map) = value
&& let Some(Value::String(message)) = map.get("error")
&& !message.trim().is_empty()
{
return Some(message.clone());
}
None
}
fn pretty_print_value(value: &Value) -> String {
serde_json::to_string_pretty(value).unwrap_or_else(|_| value.to_string())
}
async fn resolve_pending_resource_references(&mut self) -> Result<()> {
if self.pending_resource_refs.is_empty() {
return Ok(());
}
let mut resolved = 0usize;
let references: Vec<String> = self.pending_resource_refs.drain(..).collect();
for reference in references {
match self.controller.resolve_resource_reference(&reference).await {
Ok(Some(content)) => {
let message = format!("Resource @{}:\n{}", reference, content);
self.controller
.conversation_mut()
.push_system_message(message);
resolved += 1;
}
Ok(None) => {
self.push_toast(
ToastLevel::Warning,
format!(
"Resource @{} is not defined in the current project.",
reference
),
);
}
Err(err) => {
self.push_toast(
ToastLevel::Error,
format!("Failed to load resource @{}: {}", reference, err),
);
}
}
}
if resolved > 0 {
self.status = format!("Inserted {resolved} resource snippet(s).");
}
Ok(())
}
fn complete_resource_reference(&mut self) -> bool {
if self.resource_catalog.is_empty() {
return false;
}
let (row, col) = self.textarea.cursor();
let lines = self.textarea.lines().to_vec();
if row >= lines.len() {
return false;
}
let line = &lines[row];
let chars: Vec<char> = line.chars().collect();
if col > chars.len() {
return false;
}
let mut start = col;
while start > 0 {
let ch = chars[start - 1];
if ch == '@' {
start -= 1;
break;
}
if ch.is_whitespace() {
return false;
}
start -= 1;
}
if start >= col || chars.get(start) != Some(&'@') {
return false;
}
if chars[start + 1..col].iter().any(|ch| ch.is_whitespace()) {
return false;
}
let mut end = col;
while end < chars.len() {
let ch = chars[end];
if ch.is_whitespace() {
break;
}
end += 1;
}
let typed_prefix: String = chars[start + 1..col].iter().collect();
let trailing_segment: String = chars[col..end].iter().collect();
let lower_prefix = typed_prefix.to_ascii_lowercase();
let lower_full = format!("{}{}", typed_prefix, trailing_segment).to_ascii_lowercase();
let mut matches: Vec<&McpResourceConfig> = self
.resource_catalog
.iter()
.filter(|resource| {
let reference = format!("{}:{}", resource.server, resource.uri);
let lower_reference = reference.to_ascii_lowercase();
lower_reference.starts_with(&lower_full)
|| lower_reference.starts_with(&lower_prefix)
|| resource
.title
.as_ref()
.map(|title| title.to_ascii_lowercase().starts_with(&lower_prefix))
.unwrap_or(false)
})
.collect();
if matches.is_empty() {
return false;
}
matches.sort_by(|a, b| a.server.cmp(&b.server).then(a.uri.cmp(&b.uri)));
let (selected_server, selected_uri, selected_title) = {
let selected = matches[0];
(
selected.server.clone(),
selected.uri.clone(),
selected.title.clone(),
)
};
let replacement = format!("@{}:{}", selected_server, selected_uri);
let mut new_line = String::new();
new_line.extend(chars[..start].iter());
new_line.push_str(&replacement);
new_line.extend(chars[end..].iter());
let mut new_lines = lines;
new_lines[row] = new_line;
self.textarea = TextArea::new(new_lines);
configure_textarea_defaults(&mut self.textarea);
let new_col = start + replacement.len();
self.textarea
.move_cursor(CursorMove::Jump(row as u16, new_col as u16));
self.sync_textarea_to_buffer();
if let Some(title) = selected_title.as_deref() {
self.status = format!("Inserted resource {} ({title}).", replacement);
} else {
self.status = format!("Inserted resource {}.", replacement);
}
self.error = None;
true
}
fn extract_resource_references(text: &str) -> Vec<String> {
let mut references = Vec::new();
let mut current = String::new();
let mut in_reference = false;
for ch in text.chars() {
if in_reference {
if ch.is_whitespace() || matches!(ch, ',' | ';' | ')' | '(' | '.' | '!' | '?') {
if current.contains(':') {
references.push(current.clone());
}
current.clear();
in_reference = false;
} else {
current.push(ch);
}
} else if ch == '@' {
in_reference = true;
current.clear();
}
}
if in_reference && current.contains(':') {
references.push(current);
}
references
}
fn display_name_for_model(model: &ModelInfo) -> String { fn display_name_for_model(model: &ModelInfo) -> String {
if model.name.trim().is_empty() { if model.name.trim().is_empty() {
model.id.clone() model.id.clone()
@@ -2110,6 +2448,204 @@ impl ChatApp {
configure_textarea_defaults(&mut self.textarea); configure_textarea_defaults(&mut self.textarea);
} }
async fn process_slash_submission(&mut self) -> Result<SlashOutcome> {
let raw = self.controller.input_buffer().text().to_string();
if raw.trim().is_empty() {
return Ok(SlashOutcome::NotCommand);
}
match slash::parse(&raw) {
Ok(None) => Ok(SlashOutcome::NotCommand),
Ok(Some(command)) => match self.execute_slash_command(command).await {
Ok(()) => {
self.input_buffer_mut().push_history_entry(raw.clone());
self.controller.input_buffer_mut().clear();
Ok(SlashOutcome::Consumed)
}
Err(err) => {
self.error = Some(err.to_string());
self.status = "Slash command failed".to_string();
self.controller.input_buffer_mut().set_text(raw);
Ok(SlashOutcome::Error)
}
},
Err(err) => {
self.error = Some(err.to_string());
self.status = "Slash command error".to_string();
Ok(SlashOutcome::Error)
}
}
}
async fn execute_slash_command(&mut self, command: SlashCommand) -> Result<()> {
match command {
SlashCommand::Summarize { count } => {
let prompt = if let Some(count) = count {
format!(
"Summarize the last {count} messages in this conversation. Highlight key decisions, open questions, and follow-up tasks."
)
} else {
"Summarize the conversation so far, calling out major decisions, blockers, and immediate next steps.".to_string()
};
self.status = "Summarizing conversation...".to_string();
self.dispatch_user_prompt(prompt);
}
SlashCommand::Explain { snippet } => {
let prompt = format!(
"Explain the following code snippet. Cover what it does and call out any potential issues or improvements:\n```\n{}\n```",
snippet
);
self.status = "Explaining snippet...".to_string();
self.dispatch_user_prompt(prompt);
}
SlashCommand::Refactor { path } => {
let trimmed = path.trim();
if trimmed.is_empty() {
anyhow::bail!("usage: /refactor <relative/path/to/file>");
}
let source = self.controller.read_file(trimmed).await?;
let prompt = format!(
"Refactor the file `{}`. Provide specific improvements for readability, safety, and maintainability. Include updated code where relevant.\n\n```text\n{}\n```",
trimmed, source
);
self.status = format!("Refactor review for {trimmed}...");
self.dispatch_user_prompt(prompt);
}
SlashCommand::TestPlan => {
let prompt = "Generate a comprehensive test plan for this repository. Outline critical test suites, coverage gaps, and prioritized steps to reach confident automation.".to_string();
self.status = "Generating test plan...".to_string();
self.dispatch_user_prompt(prompt);
}
SlashCommand::Compact => {
let prompt = "Compress our conversation history to its essentials. Summarize previous exchanges, preserve critical context, and indicate what state can be safely forgotten.".to_string();
self.status = "Compacting conversation...".to_string();
self.dispatch_user_prompt(prompt);
}
SlashCommand::McpTool { server, tool } => {
self.status = format!("Running MCP tool {server}::{tool}...");
let response = self
.controller
.call_mcp_tool(&server, &tool, json!({}))
.await
.map_err(|err| {
anyhow!("Failed to invoke MCP tool {}::{}: {}", server, tool, err)
})?;
let content = Self::format_mcp_slash_message(&server, &tool, &response);
self.controller
.conversation_mut()
.push_system_message(content);
self.auto_scroll.stick_to_bottom = true;
self.new_message_alert = true;
if response.success {
self.status = format!("MCP {server}::{tool} result added to chat.");
self.push_toast(ToastLevel::Info, format!("MCP {server}::{tool} completed."));
} else {
self.status = format!("MCP {server}::{tool} reported an error (see chat).");
self.push_toast(
ToastLevel::Warning,
format!("MCP {server}::{tool} reported an error."),
);
}
self.error = None;
}
}
Ok(())
}
fn schedule_oauth_poll(
&self,
server: String,
authorization: DeviceAuthorization,
delay: Duration,
) {
let sender = self.session_tx.clone();
tokio::spawn(async move {
tokio::time::sleep(delay).await;
let _ = sender.send(SessionEvent::OAuthPoll {
server,
authorization,
});
});
}
async fn start_oauth_login(&mut self, server: &str) -> Result<()> {
if self.oauth_flows.contains_key(server) {
self.error = Some(format!("OAuth flow for '{server}' is already in progress."));
return Ok(());
}
let authorization = match self.controller.start_oauth_device_flow(server).await {
Ok(auth) => auth,
Err(err) => {
self.error = Some(format!("Failed to start OAuth for '{server}': {err}"));
return Ok(());
}
};
self.oauth_flows
.insert(server.to_string(), authorization.clone());
let link = authorization
.verification_uri_complete
.clone()
.unwrap_or_else(|| authorization.verification_uri.clone());
let status = format!(
"Authorize '{server}' via {} (code {}).",
link, authorization.user_code
);
self.status = status;
self.error = None;
let mut message = format!(
"OAuth authorization required for `{server}`.\nVisit:\n{}\nEnter code: `{}`",
link, authorization.user_code
);
if let Some(hint) = &authorization.message
&& !hint.trim().is_empty()
{
message.push_str("\n\n");
message.push_str(hint);
}
if authorization.expires_at > Utc::now() {
message.push_str(&format!(
"\n\nThis code expires at {}.",
authorization
.expires_at
.to_rfc3339_opts(chrono::SecondsFormat::Secs, true)
));
}
self.controller
.conversation_mut()
.push_system_message(message);
self.auto_scroll.stick_to_bottom = true;
self.notify_new_activity();
self.push_toast(
ToastLevel::Warning,
format!("Authorize {server}: code {}", authorization.user_code),
);
let delay = authorization.interval;
self.schedule_oauth_poll(server.to_string(), authorization.clone(), delay);
Ok(())
}
fn dispatch_user_prompt(&mut self, prompt: String) {
if prompt.trim().is_empty() {
self.error = Some("Slash command generated an empty request".to_string());
return;
}
self.controller.conversation_mut().push_user_message(prompt);
self.auto_scroll.stick_to_bottom = true;
self.pending_llm_request = true;
self.set_system_status(String::new());
self.error = None;
}
fn set_code_view_content( fn set_code_view_content(
&mut self, &mut self,
display_path: impl Into<String>, display_path: impl Into<String>,
@@ -2216,14 +2752,14 @@ impl ChatApp {
Ok(()) Ok(())
} }
async fn restore_workspace_layout(&mut self) -> Result<()> { async fn restore_workspace_layout(&mut self) -> Result<bool> {
let path = match self.workspace_layout_path() { let path = match self.workspace_layout_path() {
Ok(path) => path, Ok(path) => path,
Err(_) => return Ok(()), Err(_) => return Ok(false),
}; };
if !path.exists() { if !path.exists() {
return Ok(()); return Ok(false);
} }
let contents = fs::read_to_string(&path) let contents = fs::read_to_string(&path)
@@ -2247,7 +2783,7 @@ impl ChatApp {
self.status = "Workspace layout restored".to_string(); self.status = "Workspace layout restored".to_string();
} }
Ok(()) Ok(restored_any)
} }
fn direction_label(direction: PaneDirection) -> &'static str { fn direction_label(direction: PaneDirection) -> &'static str {
@@ -3289,6 +3825,7 @@ impl ChatApp {
Event::Tick => { Event::Tick => {
self.poll_repo_search(); self.poll_repo_search();
self.poll_symbol_search(); self.poll_symbol_search();
self.prune_toasts();
// Future: update streaming timers // Future: update streaming timers
} }
Event::Resize(width, height) => { Event::Resize(width, height) => {
@@ -4172,13 +4709,24 @@ impl ChatApp {
self.textarea.insert_newline(); self.textarea.insert_newline();
} }
(KeyCode::Enter, KeyModifiers::NONE) => { (KeyCode::Enter, KeyModifiers::NONE) => {
// Send message and return to normal mode
self.sync_textarea_to_buffer(); self.sync_textarea_to_buffer();
self.send_user_message_and_request_response(); match self.process_slash_submission().await? {
// Clear the textarea by setting it to empty SlashOutcome::NotCommand => {
self.textarea = TextArea::default(); self.send_user_message_and_request_response();
configure_textarea_defaults(&mut self.textarea); self.textarea = TextArea::default();
self.set_input_mode(InputMode::Normal); configure_textarea_defaults(&mut self.textarea);
self.set_input_mode(InputMode::Normal);
}
SlashOutcome::Consumed => {
self.textarea = TextArea::default();
configure_textarea_defaults(&mut self.textarea);
self.set_input_mode(InputMode::Normal);
}
SlashOutcome::Error => {
// Restore textarea content so the user can correct the command
self.sync_buffer_to_textarea();
}
}
} }
(KeyCode::Enter, _) => { (KeyCode::Enter, _) => {
// Any Enter with modifiers keeps editing and inserts a newline via tui-textarea // Any Enter with modifiers keeps editing and inserts a newline via tui-textarea
@@ -4208,6 +4756,11 @@ impl ChatApp {
self.textarea self.textarea
.move_cursor(tui_textarea::CursorMove::WordBack); .move_cursor(tui_textarea::CursorMove::WordBack);
} }
(KeyCode::Tab, m) if m.is_empty() => {
if !self.complete_resource_reference() {
self.textarea.input(Input::from(key));
}
}
(KeyCode::Char('r'), m) if m.contains(KeyModifiers::CONTROL) => { (KeyCode::Char('r'), m) if m.contains(KeyModifiers::CONTROL) => {
// Redo - history next // Redo - history next
self.input_buffer_mut().history_next(); self.input_buffer_mut().history_next();
@@ -4538,6 +5091,31 @@ impl ChatApp {
} }
} }
} }
"oauth" => {
if args.is_empty() {
let pending = self.controller.pending_oauth_servers();
if pending.is_empty() {
self.status =
"No OAuth-enabled MCP servers require authorization."
.to_string();
} else {
self.status = format!(
"Pending OAuth servers: {}",
pending.join(", ")
);
}
self.error = None;
} else if args.len() == 1 {
self.start_oauth_login(args[0]).await?;
} else if args.len() == 2
&& args[0].eq_ignore_ascii_case("login")
{
self.start_oauth_login(args[1]).await?;
} else {
self.error =
Some("Usage: :oauth [login] <server>".to_string());
}
}
"load" | "o" => { "load" | "o" => {
// Load saved sessions and enter browser mode // Load saved sessions and enter browser mode
match self.controller.list_saved_sessions().await { match self.controller.list_saved_sessions().await {
@@ -5015,29 +5593,58 @@ impl ChatApp {
if self.code_workspace.tabs().is_empty() { if self.code_workspace.tabs().is_empty() {
self.status = self.status =
"No open panes to save".to_string(); "No open panes to save".to_string();
self.error = None;
self.push_toast(
ToastLevel::Warning,
"Open a pane before saving layout.",
);
} else { } else {
self.persist_workspace_layout(); self.persist_workspace_layout();
self.status = self.status =
"Workspace layout saved".to_string(); "Workspace layout saved".to_string();
self.error = None; self.error = None;
self.push_toast(
ToastLevel::Success,
"Workspace layout saved.",
);
} }
} }
"load" => match self.restore_workspace_layout().await { "load" => match self.restore_workspace_layout().await {
Ok(()) => { Ok(true) => {
self.status = self.status =
"Workspace layout restored".to_string(); "Workspace layout restored".to_string();
self.error = None; self.error = None;
self.push_toast(
ToastLevel::Success,
"Workspace layout restored.",
);
}
Ok(false) => {
self.status =
"No saved layout to restore".to_string();
self.error = None;
self.push_toast(
ToastLevel::Info,
"No saved layout was found.",
);
} }
Err(err) => { Err(err) => {
self.error = Some(err.to_string()); let message = format!(
"Failed to restore workspace layout: {}",
err
);
self.error = Some(message.clone());
self.status = self.status =
"Failed to restore workspace layout" "Failed to restore workspace layout"
.to_string(); .to_string();
self.push_toast(ToastLevel::Error, message);
} }
}, },
other => { other => {
self.status =
format!("Unknown layout command: {other}");
self.error = Some(format!( self.error = Some(format!(
"Unknown layout command: {other}" "Unknown layout subcommand: {other}"
)); ));
} }
} }
@@ -5068,6 +5675,27 @@ impl ChatApp {
self.error = None; self.error = None;
self.sync_ui_preferences_from_config(); self.sync_ui_preferences_from_config();
self.update_command_palette_catalog(); self.update_command_palette_catalog();
if let Err(err) = self.refresh_resource_catalog().await
{
self.push_toast(
ToastLevel::Error,
format!(
"Failed to refresh MCP resources: {}",
err
),
);
}
if let Err(err) =
self.refresh_mcp_slash_commands().await
{
self.push_toast(
ToastLevel::Error,
format!(
"Failed to refresh MCP slash commands: {}",
err
),
);
}
} }
Err(e) => { Err(e) => {
self.error = self.error =
@@ -5666,7 +6294,7 @@ impl ChatApp {
} }
} }
pub fn handle_session_event(&mut self, event: SessionEvent) -> Result<()> { pub async fn handle_session_event(&mut self, event: SessionEvent) -> Result<()> {
match event { match event {
SessionEvent::StreamChunk { SessionEvent::StreamChunk {
message_id, message_id,
@@ -5760,6 +6388,52 @@ impl ChatApp {
self.agent_actions = None; self.agent_actions = None;
self.stop_loading_animation(); self.stop_loading_animation();
} }
SessionEvent::OAuthPoll {
server,
authorization,
} => {
match self
.controller
.poll_oauth_device_flow(&server, &authorization)
.await
{
Ok(DevicePollState::Pending { retry_in }) => {
self.oauth_flows
.insert(server.clone(), authorization.clone());
let server_name = server.clone();
self.schedule_oauth_poll(server, authorization, retry_in);
self.status = format!("Waiting for OAuth approval for {server_name}...");
}
Ok(DevicePollState::Complete(_token)) => {
self.oauth_flows.remove(&server);
self.push_toast(
ToastLevel::Success,
format!("OAuth authorization complete for {server}."),
);
self.status = format!("OAuth authorization complete for {server}.");
if let Err(err) = self.refresh_resource_catalog().await {
self.push_toast(
ToastLevel::Error,
format!("Failed to refresh MCP resources: {err}"),
);
}
if let Err(err) = self.refresh_mcp_slash_commands().await {
self.push_toast(
ToastLevel::Error,
format!("Failed to refresh MCP slash commands: {err}"),
);
}
}
Err(err) => {
self.oauth_flows.remove(&server);
self.error = Some(format!("OAuth flow for '{server}' failed: {err}"));
self.push_toast(
ToastLevel::Error,
format!("OAuth failure for {server}: {err}"),
);
}
}
}
} }
Ok(()) Ok(())
} }
@@ -5825,6 +6499,7 @@ impl ChatApp {
args: Vec::new(), args: Vec::new(),
transport: "stdio".to_string(), transport: "stdio".to_string(),
env: env_vars.clone(), env: env_vars.clone(),
oauth: None,
}; };
RemoteMcpClient::new_with_config(&config) RemoteMcpClient::new_with_config(&config)
} else { } else {
@@ -6176,6 +6851,7 @@ impl ChatApp {
args: Vec::new(), args: Vec::new(),
transport: "stdio".to_string(), transport: "stdio".to_string(),
env: env_vars, env: env_vars,
oauth: None,
}; };
Arc::new(RemoteMcpClient::new_with_config(&config)?) Arc::new(RemoteMcpClient::new_with_config(&config)?)
} else { } else {
@@ -6423,6 +7099,10 @@ impl ChatApp {
// Step 1: Add user message to conversation immediately (synchronous) // Step 1: Add user message to conversation immediately (synchronous)
let message = self.controller.input_buffer_mut().commit_to_history(); let message = self.controller.input_buffer_mut().commit_to_history();
let mut references = Self::extract_resource_references(&message);
references.sort();
references.dedup();
self.pending_resource_refs = references;
self.controller self.controller
.conversation_mut() .conversation_mut()
.push_user_message(message.clone()); .push_user_message(message.clone());
@@ -6539,6 +7219,8 @@ impl ChatApp {
self.pending_llm_request = false; self.pending_llm_request = false;
self.resolve_pending_resource_references().await?;
// Check if agent mode is enabled // Check if agent mode is enabled
if self.agent_mode { if self.agent_mode {
return self.process_agent_request().await; return self.process_agent_request().await;

View File

@@ -28,8 +28,8 @@ impl CodeApp {
self.inner.handle_event(event).await self.inner.handle_event(event).await
} }
pub fn handle_session_event(&mut self, event: SessionEvent) -> Result<()> { pub async fn handle_session_event(&mut self, event: SessionEvent) -> Result<()> {
self.inner.handle_session_event(event) self.inner.handle_session_event(event).await
} }
pub fn mode(&self) -> InputMode { pub fn mode(&self) -> InputMode {

View File

@@ -235,7 +235,7 @@ pub fn match_score(candidate: &str, query: &str) -> Option<(usize, usize)> {
if candidate_normalized == query_normalized { if candidate_normalized == query_normalized {
Some((0, candidate.len())) Some((0, candidate.len()))
} else if candidate_normalized.starts_with(&query_normalized) { } else if candidate_normalized.starts_with(&query_normalized) {
Some((1, candidate.len())) Some((1, 0))
} else if let Some(pos) = candidate_normalized.find(&query_normalized) { } else if let Some(pos) = candidate_normalized.find(&query_normalized) {
Some((2, pos)) Some((2, pos))
} else if is_subsequence(&candidate_normalized, &query_normalized) { } else if is_subsequence(&candidate_normalized, &query_normalized) {

View File

@@ -18,7 +18,9 @@ pub mod commands;
pub mod config; pub mod config;
pub mod events; pub mod events;
pub mod model_info_panel; pub mod model_info_panel;
pub mod slash;
pub mod state; pub mod state;
pub mod toast;
pub mod tui_controller; pub mod tui_controller;
pub mod ui; pub mod ui;

View File

@@ -0,0 +1,238 @@
//! Slash command parsing for chat input.
//!
//! Provides lightweight handling for inline commands such as `/summarize`
//! and `/testplan`. The parser returns owned data so callers can prepare
//! requests immediately without additional lifetime juggling.
use std::collections::HashMap;
use std::fmt;
use std::str::FromStr;
use std::sync::{OnceLock, RwLock};
/// Supported slash commands.
#[derive(Debug, Clone)]
pub enum SlashCommand {
Summarize { count: Option<usize> },
Explain { snippet: String },
Refactor { path: String },
TestPlan,
Compact,
McpTool { server: String, tool: String },
}
/// Errors emitted when parsing invalid slash input.
#[derive(Debug)]
pub enum SlashError {
UnknownCommand(String),
Message(String),
}
impl fmt::Display for SlashError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SlashError::UnknownCommand(cmd) => write!(f, "unknown slash command: {cmd}"),
SlashError::Message(msg) => f.write_str(msg),
}
}
}
impl std::error::Error for SlashError {}
#[derive(Debug, Clone)]
pub struct McpSlashCommand {
pub server: String,
pub tool: String,
pub keyword: String,
pub description: Option<String>,
}
impl McpSlashCommand {
pub fn new(
server: impl Into<String>,
tool: impl Into<String>,
description: Option<String>,
) -> Self {
let server = server.into();
let tool = tool.into();
let keyword = format!(
"mcp__{}__{}",
canonicalize_component(&server),
canonicalize_component(&tool)
);
Self {
server,
tool,
keyword,
description,
}
}
}
static MCP_COMMANDS: OnceLock<RwLock<HashMap<String, McpSlashCommand>>> = OnceLock::new();
fn dynamic_registry() -> &'static RwLock<HashMap<String, McpSlashCommand>> {
MCP_COMMANDS.get_or_init(|| RwLock::new(HashMap::new()))
}
pub fn set_mcp_commands(commands: impl IntoIterator<Item = McpSlashCommand>) {
let registry = dynamic_registry();
let mut guard = registry.write().expect("MCP command registry poisoned");
guard.clear();
for command in commands {
guard.insert(command.keyword.clone(), command);
}
}
fn find_mcp_command(keyword: &str) -> Option<McpSlashCommand> {
let registry = dynamic_registry();
let guard = registry.read().expect("MCP command registry poisoned");
guard.get(keyword).cloned()
}
fn canonicalize_component(input: &str) -> String {
let mut out = String::new();
let mut last_was_underscore = false;
for ch in input.chars() {
let mapped = if ch.is_ascii_alphanumeric() {
ch.to_ascii_lowercase()
} else {
'_'
};
if mapped == '_' {
if !last_was_underscore {
out.push('_');
last_was_underscore = true;
}
} else {
out.push(mapped);
last_was_underscore = false;
}
}
if out.is_empty() { "_".to_string() } else { out }
}
/// Attempt to parse a slash command from the provided input.
pub fn parse(input: &str) -> Result<Option<SlashCommand>, SlashError> {
let trimmed = input.trim();
if !trimmed.starts_with('/') {
return Ok(None);
}
let body = trimmed.trim_start_matches('/');
if body.is_empty() {
return Err(SlashError::Message("missing command name after '/'".into()));
}
let mut parts = body.split_whitespace();
let command = parts.next().unwrap();
let remainder = parts.collect::<Vec<_>>();
if let Some(dynamic) = find_mcp_command(command) {
if !remainder.is_empty() {
return Err(SlashError::Message(format!(
"/{} does not accept arguments",
dynamic.keyword
)));
}
return Ok(Some(SlashCommand::McpTool {
server: dynamic.server,
tool: dynamic.tool,
}));
}
let cmd = match command {
"summarize" => {
let count = remainder
.first()
.and_then(|value| usize::from_str(value).ok());
SlashCommand::Summarize { count }
}
"explain" => {
if remainder.is_empty() {
return Err(SlashError::Message(
"usage: /explain <code snippet or description>".into(),
));
}
SlashCommand::Explain {
snippet: remainder.join(" "),
}
}
"refactor" => {
if remainder.is_empty() {
return Err(SlashError::Message(
"usage: /refactor <relative/path/to/file>".into(),
));
}
SlashCommand::Refactor {
path: remainder.join(" "),
}
}
"testplan" => SlashCommand::TestPlan,
"compact" => SlashCommand::Compact,
other => return Err(SlashError::UnknownCommand(other.to_string())),
};
Ok(Some(cmd))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn ignores_non_command_input() {
let result = parse("hello world").unwrap();
assert!(result.is_none());
}
#[test]
fn parses_summarize_with_count() {
let command = parse("/summarize 10").unwrap().expect("expected command");
match command {
SlashCommand::Summarize { count } => assert_eq!(count, Some(10)),
other => panic!("unexpected command: {:?}", other),
}
}
#[test]
fn returns_error_for_unknown_command() {
let err = parse("/unknown").unwrap_err();
assert_eq!(err.to_string(), "unknown slash command: unknown");
}
#[test]
fn parses_registered_mcp_command() {
set_mcp_commands(Vec::new());
set_mcp_commands(vec![McpSlashCommand::new("github", "list_prs", None)]);
let command = parse("/mcp__github__list_prs")
.unwrap()
.expect("expected command");
match command {
SlashCommand::McpTool { server, tool } => {
assert_eq!(server, "github");
assert_eq!(tool, "list_prs");
}
other => panic!("unexpected command variant: {:?}", other),
}
}
#[test]
fn rejects_mcp_command_with_arguments() {
set_mcp_commands(Vec::new());
set_mcp_commands(vec![McpSlashCommand::new("github", "list_prs", None)]);
let err = parse("/mcp__github__list_prs extra").unwrap_err();
assert_eq!(
err.to_string(),
"/mcp__github__list_prs does not accept arguments"
);
}
#[test]
fn canonicalizes_mcp_command_components() {
set_mcp_commands(Vec::new());
let entry = McpSlashCommand::new("GitHub", "list/prs", None);
assert_eq!(entry.keyword, "mcp__github__list_prs");
}
}

View File

@@ -0,0 +1,96 @@
macro_rules! adjust_fields {
($theme:expr, $func:expr, $($field:ident),+ $(,)?) => {
$(
$theme.$field = $func($theme.$field);
)+
};
}
use owlen_core::theme::Theme;
use ratatui::style::Color;
/// Return a clone of `base` with contrast adjustments applied.
/// Positive `steps` increase contrast, negative values decrease it.
pub fn with_contrast(base: &Theme, steps: i8) -> Theme {
if steps == 0 {
return base.clone();
}
let factor = (1.0 + (steps as f32) * 0.18).clamp(0.3, 2.0);
let adjust = |color: Color| adjust_color(color, factor);
let mut theme = base.clone();
adjust_fields!(
theme,
adjust,
text,
background,
focused_panel_border,
unfocused_panel_border,
focus_beacon_fg,
focus_beacon_bg,
unfocused_beacon_fg,
pane_header_active,
pane_header_inactive,
pane_hint_text,
user_message_role,
assistant_message_role,
tool_output,
thinking_panel_title,
command_bar_background,
status_background,
mode_normal,
mode_editing,
mode_model_selection,
mode_provider_selection,
mode_help,
mode_visual,
mode_command,
selection_bg,
selection_fg,
cursor,
code_block_background,
code_block_border,
code_block_text,
code_block_keyword,
code_block_string,
code_block_comment,
placeholder,
error,
info,
agent_thought,
agent_action,
agent_action_input,
agent_observation,
agent_final_answer,
agent_badge_running_fg,
agent_badge_running_bg,
agent_badge_idle_fg,
agent_badge_idle_bg,
operating_chat_fg,
operating_chat_bg,
operating_code_fg,
operating_code_bg
);
theme
}
fn adjust_color(color: Color, factor: f32) -> Color {
match color {
Color::Rgb(r, g, b) => {
let adjust_component = |component: u8| -> u8 {
let normalized = component as f32 / 255.0;
let contrasted = ((normalized - 0.5) * factor + 0.5).clamp(0.0, 1.0);
(contrasted * 255.0).round().clamp(0.0, 255.0) as u8
};
Color::Rgb(
adjust_component(r),
adjust_component(g),
adjust_component(b),
)
}
_ => color,
}
}

View File

@@ -0,0 +1,114 @@
use std::collections::VecDeque;
use std::time::{Duration, Instant};
/// Severity level for toast notifications.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToastLevel {
Info,
Success,
Warning,
Error,
}
#[derive(Debug, Clone)]
pub struct Toast {
pub message: String,
pub level: ToastLevel,
created: Instant,
duration: Duration,
}
impl Toast {
fn new(message: String, level: ToastLevel, lifetime: Duration) -> Self {
Self {
message,
level,
created: Instant::now(),
duration: lifetime,
}
}
fn is_expired(&self, now: Instant) -> bool {
now.duration_since(self.created) >= self.duration
}
}
/// Fixed-size toast queue with automatic expiration.
#[derive(Debug)]
pub struct ToastManager {
items: VecDeque<Toast>,
max_active: usize,
lifetime: Duration,
}
impl Default for ToastManager {
fn default() -> Self {
Self::new()
}
}
impl ToastManager {
pub fn new() -> Self {
Self {
items: VecDeque::new(),
max_active: 3,
lifetime: Duration::from_secs(3),
}
}
pub fn with_lifetime(mut self, duration: Duration) -> Self {
self.lifetime = duration;
self
}
pub fn push(&mut self, message: impl Into<String>, level: ToastLevel) {
let toast = Toast::new(message.into(), level, self.lifetime);
self.items.push_front(toast);
while self.items.len() > self.max_active {
self.items.pop_back();
}
}
pub fn retain_active(&mut self) {
let now = Instant::now();
self.items.retain(|toast| !toast.is_expired(now));
}
pub fn iter(&self) -> impl Iterator<Item = &Toast> {
self.items.iter()
}
pub fn is_empty(&self) -> bool {
self.items.is_empty()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::thread::sleep;
#[test]
fn manager_limits_active_toasts() {
let mut manager = ToastManager::new();
manager.push("first", ToastLevel::Info);
manager.push("second", ToastLevel::Warning);
manager.push("third", ToastLevel::Success);
manager.push("fourth", ToastLevel::Error);
let collected: Vec<_> = manager.iter().map(|toast| toast.message.clone()).collect();
assert_eq!(collected.len(), 3);
assert_eq!(collected[0], "fourth");
assert_eq!(collected[2], "second");
}
#[test]
fn manager_expires_toasts_after_lifetime() {
let mut manager = ToastManager::new().with_lifetime(Duration::from_millis(1));
manager.push("short lived", ToastLevel::Info);
assert!(!manager.is_empty());
sleep(Duration::from_millis(5));
manager.retain_active();
assert!(manager.is_empty());
}
}

View File

@@ -16,10 +16,12 @@ use crate::state::{
CodePane, EditorTab, FileFilterMode, FileNode, LayoutNode, PaletteGroup, PaneId, CodePane, EditorTab, FileFilterMode, FileNode, LayoutNode, PaletteGroup, PaneId,
RepoSearchRowKind, SplitAxis, VisibleFileEntry, RepoSearchRowKind, SplitAxis, VisibleFileEntry,
}; };
use crate::toast::{Toast, ToastLevel};
use owlen_core::model::DetailedModelInfo; use owlen_core::model::DetailedModelInfo;
use owlen_core::theme::Theme; use owlen_core::theme::Theme;
use owlen_core::types::{ModelInfo, Role}; use owlen_core::types::{ModelInfo, Role};
use owlen_core::ui::{FocusedPanel, InputMode, RoleLabelDisplay}; use owlen_core::ui::{FocusedPanel, InputMode, RoleLabelDisplay};
use textwrap::wrap;
const PRIVACY_TAB_INDEX: usize = HELP_TAB_COUNT - 1; const PRIVACY_TAB_INDEX: usize = HELP_TAB_COUNT - 1;
@@ -331,6 +333,113 @@ pub fn render_chat(frame: &mut Frame<'_>, app: &mut ChatApp) {
if let Some(area) = code_area { if let Some(area) = code_area {
render_code_workspace(frame, area, app); render_code_workspace(frame, area, app);
} }
render_toasts(frame, app, full_area);
}
fn toast_palette(level: ToastLevel, theme: &Theme) -> (&'static str, Style, Style) {
let (label, color) = match level {
ToastLevel::Info => ("INFO", theme.info),
ToastLevel::Success => ("OK", theme.agent_badge_idle_bg),
ToastLevel::Warning => ("WARN", theme.agent_action),
ToastLevel::Error => ("ERROR", theme.error),
};
let badge_style = Style::default()
.fg(theme.background)
.bg(color)
.add_modifier(Modifier::BOLD);
let border_style = Style::default().fg(color);
(label, badge_style, border_style)
}
fn render_toasts(frame: &mut Frame<'_>, app: &ChatApp, full_area: Rect) {
let toasts: Vec<&Toast> = app.toasts().collect();
if toasts.is_empty() {
return;
}
let theme = app.theme();
let available_width = usize::from(full_area.width.saturating_sub(2));
if available_width == 0 {
return;
}
let max_text_width = toasts
.iter()
.map(|toast| UnicodeWidthStr::width(toast.message.as_str()))
.max()
.unwrap_or(0);
let mut width = max_text_width.saturating_add(6); // padding + badge
width = width.clamp(14, available_width);
width = width.min(48);
if width == 0 {
return;
}
let width = width as u16;
let offset_x = full_area
.x
.saturating_add(full_area.width.saturating_sub(width + 1));
let mut offset_y = full_area.y.saturating_add(1);
let frame_bottom = full_area.y.saturating_add(full_area.height);
for toast in toasts {
let (label, badge_style, border_style) = toast_palette(toast.level, theme);
let badge_text = format!(" {} ", label);
let indent_width = UnicodeWidthStr::width(badge_text.as_str()) + 1;
let indent = " ".repeat(indent_width);
let content_width = width.saturating_sub(4).max(1) as usize;
let wrapped_lines = wrap(toast.message.as_str(), content_width);
let lines: Vec<String> = if wrapped_lines.is_empty() {
vec![String::new()]
} else {
wrapped_lines
.into_iter()
.map(|cow| cow.into_owned())
.collect()
};
let text_style = Style::default().fg(theme.text);
let mut paragraph_lines = Vec::with_capacity(lines.len());
if let Some((first, rest)) = lines.split_first() {
paragraph_lines.push(Line::from(vec![
Span::styled(badge_text.clone(), badge_style),
Span::raw(" "),
Span::styled(first.clone(), text_style),
]));
for line in rest {
paragraph_lines.push(Line::from(vec![
Span::raw(indent.clone()),
Span::styled(line.clone(), text_style),
]));
}
}
let height = (paragraph_lines.len() as u16).saturating_add(2);
if offset_y.saturating_add(height) > frame_bottom {
break;
}
let area = Rect::new(offset_x, offset_y, width, height);
frame.render_widget(Clear, area);
let block = Block::default()
.borders(Borders::ALL)
.border_style(border_style)
.style(Style::default().bg(theme.background));
let paragraph = Paragraph::new(paragraph_lines)
.block(block)
.alignment(Alignment::Left)
.wrap(Wrap { trim: false });
frame.render_widget(paragraph, area);
offset_y = offset_y.saturating_add(height + 1);
if offset_y >= frame_bottom {
break;
}
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]

View File

@@ -9,11 +9,21 @@ fn palette_tracks_buffer_and_suggestions() {
palette.set_buffer("mo"); palette.set_buffer("mo");
assert_eq!(palette.buffer(), "mo"); assert_eq!(palette.buffer(), "mo");
assert!(palette.suggestions().iter().all(|s| s.starts_with("mo"))); assert!(
palette
.suggestions()
.iter()
.all(|s| s.value.starts_with("mo"))
);
palette.push_char('d'); palette.push_char('d');
assert_eq!(palette.buffer(), "mod"); assert_eq!(palette.buffer(), "mod");
assert!(palette.suggestions().iter().all(|s| s.starts_with("mod"))); assert!(
palette
.suggestions()
.iter()
.all(|s| s.value.starts_with("mod"))
);
palette.pop_char(); palette.pop_char();
assert_eq!(palette.buffer(), "mo"); assert_eq!(palette.buffer(), "mo");