feat(cli): add MCP management subcommand with add/list/remove commands
Introduce `McpCommand` enum and handlers in `owlen-cli` to manage MCP server registrations, including adding, listing, and removing servers across configuration scopes. Add scoped configuration support (`ScopedMcpServer`, `McpConfigScope`) and OAuth token handling in core config, alongside runtime refresh of MCP servers. Implement toast notifications in the TUI (`render_toasts`, `Toast`, `ToastLevel`) and integrate async handling for session events. Update config loading, validation, and schema versioning to accommodate new MCP scopes and resources. Add `httpmock` as a dev dependency for testing.
This commit is contained in:
@@ -92,6 +92,7 @@ OWLEN uses a modal, vim-inspired interface. Press `F1` (available from any mode)
|
||||
- **Editing Mode**: Enter with `i` or `a`. Send messages with `Enter`.
|
||||
- **Command Mode**: Enter with `:`. Access commands like `:quit`, `:save`, `:theme`.
|
||||
- **Tutorial Command**: Type `:tutorial` any time for a quick summary of the most important keybindings.
|
||||
- **MCP Slash Commands**: Owlen auto-registers zero-argument MCP tools as slash commands—type `/mcp__github__list_prs` (for example) to pull remote context directly into the chat log.
|
||||
|
||||
## Documentation
|
||||
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
//! OWLEN CLI - Chat TUI client
|
||||
|
||||
mod cloud;
|
||||
mod mcp;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use async_trait::async_trait;
|
||||
use clap::{Parser, Subcommand};
|
||||
use cloud::{CloudCommand, load_runtime_credentials, set_env_var};
|
||||
use mcp::{McpCommand, run_mcp_command};
|
||||
use owlen_core::config as core_config;
|
||||
use owlen_core::{
|
||||
ChatStream, Error, Provider,
|
||||
@@ -54,6 +56,9 @@ enum OwlenCommand {
|
||||
/// Manage Ollama Cloud credentials
|
||||
#[command(subcommand)]
|
||||
Cloud(CloudCommand),
|
||||
/// Manage MCP server registrations
|
||||
#[command(subcommand)]
|
||||
Mcp(McpCommand),
|
||||
/// Show manual steps for updating Owlen to the latest revision
|
||||
Upgrade,
|
||||
}
|
||||
@@ -69,7 +74,7 @@ enum ConfigCommand {
|
||||
fn build_provider(cfg: &Config) -> anyhow::Result<Arc<dyn Provider>> {
|
||||
match cfg.mcp.mode {
|
||||
McpMode::RemotePreferred => {
|
||||
let remote_result = if let Some(mcp_server) = cfg.mcp_servers.first() {
|
||||
let remote_result = if let Some(mcp_server) = cfg.effective_mcp_servers().first() {
|
||||
RemoteMcpClient::new_with_config(mcp_server)
|
||||
} else {
|
||||
RemoteMcpClient::new()
|
||||
@@ -91,7 +96,7 @@ fn build_provider(cfg: &Config) -> anyhow::Result<Arc<dyn Provider>> {
|
||||
}
|
||||
}
|
||||
McpMode::RemoteOnly => {
|
||||
let mcp_server = cfg.mcp_servers.first().ok_or_else(|| {
|
||||
let mcp_server = cfg.effective_mcp_servers().first().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"[[mcp_servers]] must be configured when [mcp].mode = \"remote_only\""
|
||||
)
|
||||
@@ -130,6 +135,7 @@ async fn run_command(command: OwlenCommand) -> Result<()> {
|
||||
match command {
|
||||
OwlenCommand::Config(config_cmd) => run_config_command(config_cmd),
|
||||
OwlenCommand::Cloud(cloud_cmd) => cloud::run_cloud_command(cloud_cmd).await,
|
||||
OwlenCommand::Mcp(mcp_cmd) => run_mcp_command(mcp_cmd),
|
||||
OwlenCommand::Upgrade => {
|
||||
println!(
|
||||
"To update Owlen from source:\n git pull\n cargo install --path crates/owlen-cli --force"
|
||||
@@ -157,6 +163,7 @@ fn run_config_doctor() -> Result<()> {
|
||||
let config_path = core_config::default_config_path();
|
||||
let existed = config_path.exists();
|
||||
let mut config = config::try_load_config().unwrap_or_default();
|
||||
let _ = config.refresh_mcp_servers(None);
|
||||
let mut changes = Vec::new();
|
||||
|
||||
if !existed {
|
||||
@@ -205,7 +212,7 @@ fn run_config_doctor() -> Result<()> {
|
||||
config.mcp.warn_on_legacy = true;
|
||||
changes.push("converted [mcp].mode = 'legacy' to 'local_only'".to_string());
|
||||
}
|
||||
McpMode::RemoteOnly if config.mcp_servers.is_empty() => {
|
||||
McpMode::RemoteOnly if config.effective_mcp_servers().is_empty() => {
|
||||
config.mcp.mode = McpMode::RemotePreferred;
|
||||
config.mcp.allow_fallback = true;
|
||||
changes.push(
|
||||
@@ -213,7 +220,9 @@ fn run_config_doctor() -> Result<()> {
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
McpMode::RemotePreferred if !config.mcp.allow_fallback && config.mcp_servers.is_empty() => {
|
||||
McpMode::RemotePreferred
|
||||
if !config.mcp.allow_fallback && config.effective_mcp_servers().is_empty() =>
|
||||
{
|
||||
config.mcp.allow_fallback = true;
|
||||
changes.push(
|
||||
"enabled [mcp].allow_fallback because no remote servers are configured".to_string(),
|
||||
@@ -369,6 +378,7 @@ async fn main() -> Result<()> {
|
||||
let color_support = detect_terminal_color_support();
|
||||
// Load configuration (or fall back to defaults) for the session controller.
|
||||
let mut cfg = config::try_load_config().unwrap_or_default();
|
||||
let _ = cfg.refresh_mcp_servers(None);
|
||||
if let Some(previous_theme) = apply_terminal_theme(&mut cfg, &color_support) {
|
||||
let term_label = match &color_support {
|
||||
TerminalColorSupport::Limited { term } => Cow::from(term.as_str()),
|
||||
@@ -398,7 +408,7 @@ async fn main() -> Result<()> {
|
||||
Ok(_) => provider,
|
||||
Err(err) => {
|
||||
let hint = if matches!(cfg.mcp.mode, McpMode::RemotePreferred | McpMode::RemoteOnly)
|
||||
&& !cfg.mcp_servers.is_empty()
|
||||
&& !cfg.effective_mcp_servers().is_empty()
|
||||
{
|
||||
"Ensure the configured MCP server is running and reachable."
|
||||
} else {
|
||||
@@ -523,7 +533,7 @@ async fn run_app(
|
||||
}
|
||||
}
|
||||
Some(session_event) = session_rx.recv() => {
|
||||
app.handle_session_event(session_event)?;
|
||||
app.handle_session_event(session_event).await?;
|
||||
}
|
||||
_ = tokio::time::sleep(sleep_duration) => {}
|
||||
}
|
||||
|
||||
257
crates/owlen-cli/src/mcp.rs
Normal file
257
crates/owlen-cli/src/mcp.rs
Normal file
@@ -0,0 +1,257 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use clap::{Args, Subcommand, ValueEnum};
|
||||
use owlen_core::config::{self as core_config, Config, McpConfigScope, McpServerConfig};
|
||||
use owlen_tui::config as tui_config;
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
pub enum McpCommand {
|
||||
/// Add or update an MCP server in the selected scope
|
||||
Add(AddArgs),
|
||||
/// List MCP servers across scopes
|
||||
List(ListArgs),
|
||||
/// Remove an MCP server from a scope
|
||||
Remove(RemoveArgs),
|
||||
}
|
||||
|
||||
pub fn run_mcp_command(command: McpCommand) -> Result<()> {
|
||||
match command {
|
||||
McpCommand::Add(args) => handle_add(args),
|
||||
McpCommand::List(args) => handle_list(args),
|
||||
McpCommand::Remove(args) => handle_remove(args),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, ValueEnum, Default)]
|
||||
pub enum ScopeArg {
|
||||
User,
|
||||
#[default]
|
||||
Project,
|
||||
Local,
|
||||
}
|
||||
|
||||
impl From<ScopeArg> for McpConfigScope {
|
||||
fn from(value: ScopeArg) -> Self {
|
||||
match value {
|
||||
ScopeArg::User => McpConfigScope::User,
|
||||
ScopeArg::Project => McpConfigScope::Project,
|
||||
ScopeArg::Local => McpConfigScope::Local,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct AddArgs {
|
||||
/// Logical name used to reference the server
|
||||
pub name: String,
|
||||
/// Command or endpoint invoked for the server
|
||||
pub command: String,
|
||||
/// Transport mechanism (stdio, http, websocket)
|
||||
#[arg(long, default_value = "stdio")]
|
||||
pub transport: String,
|
||||
/// Configuration scope to write the server into
|
||||
#[arg(long, value_enum, default_value_t = ScopeArg::Project)]
|
||||
pub scope: ScopeArg,
|
||||
/// Environment variables (KEY=VALUE) passed to the server process
|
||||
#[arg(long = "env")]
|
||||
pub env: Vec<String>,
|
||||
/// Additional arguments appended when launching the server
|
||||
#[arg(trailing_var_arg = true, value_name = "ARG")]
|
||||
pub args: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args, Default)]
|
||||
pub struct ListArgs {
|
||||
/// Restrict output to a specific configuration scope
|
||||
#[arg(long, value_enum)]
|
||||
pub scope: Option<ScopeArg>,
|
||||
/// Display only the effective servers (after precedence resolution)
|
||||
#[arg(long)]
|
||||
pub effective_only: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct RemoveArgs {
|
||||
/// Name of the server to remove
|
||||
pub name: String,
|
||||
/// Optional explicit scope to remove from
|
||||
#[arg(long, value_enum)]
|
||||
pub scope: Option<ScopeArg>,
|
||||
}
|
||||
|
||||
fn handle_add(args: AddArgs) -> Result<()> {
|
||||
let mut config = load_config()?;
|
||||
let scope: McpConfigScope = args.scope.into();
|
||||
let mut env_map = HashMap::new();
|
||||
for pair in &args.env {
|
||||
let (key, value) = pair
|
||||
.split_once('=')
|
||||
.ok_or_else(|| anyhow!("Environment pairs must use KEY=VALUE syntax: '{}'", pair))?;
|
||||
if key.trim().is_empty() {
|
||||
return Err(anyhow!("Environment variable name cannot be empty"));
|
||||
}
|
||||
env_map.insert(key.trim().to_string(), value.to_string());
|
||||
}
|
||||
|
||||
let server = McpServerConfig {
|
||||
name: args.name.clone(),
|
||||
command: args.command.clone(),
|
||||
args: args.args.clone(),
|
||||
transport: args.transport.to_lowercase(),
|
||||
env: env_map,
|
||||
oauth: None,
|
||||
};
|
||||
|
||||
config.add_mcp_server(scope, server.clone(), None)?;
|
||||
if matches!(scope, McpConfigScope::User) {
|
||||
tui_config::save_config(&config)?;
|
||||
}
|
||||
|
||||
if let Some(path) = core_config::mcp_scope_path(scope, None) {
|
||||
println!(
|
||||
"Registered MCP server '{}' in {} scope ({})",
|
||||
server.name,
|
||||
scope,
|
||||
path.display()
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
"Registered MCP server '{}' in {} scope.",
|
||||
server.name, scope
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_list(args: ListArgs) -> Result<()> {
|
||||
let mut config = load_config()?;
|
||||
config.refresh_mcp_servers(None)?;
|
||||
|
||||
let scoped = config.scoped_mcp_servers();
|
||||
if scoped.is_empty() {
|
||||
println!("No MCP servers configured.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let filter_scope = args.scope.map(|scope| scope.into());
|
||||
let effective = config.effective_mcp_servers();
|
||||
let mut active = HashSet::new();
|
||||
for server in effective {
|
||||
active.insert((
|
||||
server.name.clone(),
|
||||
server.command.clone(),
|
||||
server.transport.to_lowercase(),
|
||||
));
|
||||
}
|
||||
|
||||
println!(
|
||||
"{:<2} {:<8} {:<20} {:<10} Command",
|
||||
"", "Scope", "Name", "Transport"
|
||||
);
|
||||
for entry in scoped {
|
||||
if let Some(target_scope) = filter_scope
|
||||
&& entry.scope != target_scope
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let payload = format_command_line(&entry.config.command, &entry.config.args);
|
||||
let key = (
|
||||
entry.config.name.clone(),
|
||||
entry.config.command.clone(),
|
||||
entry.config.transport.to_lowercase(),
|
||||
);
|
||||
let marker = if active.contains(&key) { "*" } else { " " };
|
||||
|
||||
if args.effective_only && marker != "*" {
|
||||
continue;
|
||||
}
|
||||
|
||||
println!(
|
||||
"{} {:<8} {:<20} {:<10} {}",
|
||||
marker, entry.scope, entry.config.name, entry.config.transport, payload
|
||||
);
|
||||
}
|
||||
|
||||
let scoped_resources = config.scoped_mcp_resources();
|
||||
if !scoped_resources.is_empty() {
|
||||
println!();
|
||||
println!("{:<2} {:<8} {:<30} Title", "", "Scope", "Resource");
|
||||
let effective_keys: HashSet<(String, String)> = config
|
||||
.effective_mcp_resources()
|
||||
.iter()
|
||||
.map(|res| (res.server.clone(), res.uri.clone()))
|
||||
.collect();
|
||||
|
||||
for entry in scoped_resources {
|
||||
if let Some(target_scope) = filter_scope
|
||||
&& entry.scope != target_scope
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let key = (entry.config.server.clone(), entry.config.uri.clone());
|
||||
let marker = if effective_keys.contains(&key) {
|
||||
"*"
|
||||
} else {
|
||||
" "
|
||||
};
|
||||
if args.effective_only && marker != "*" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let reference = format!("@{}:{}", entry.config.server, entry.config.uri);
|
||||
let title = entry.config.title.as_deref().unwrap_or("—");
|
||||
|
||||
println!("{} {:<8} {:<30} {}", marker, entry.scope, reference, title);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_remove(args: RemoveArgs) -> Result<()> {
|
||||
let mut config = load_config()?;
|
||||
let scope_hint = args.scope.map(|scope| scope.into());
|
||||
let result = config.remove_mcp_server(scope_hint, &args.name, None)?;
|
||||
|
||||
match result {
|
||||
Some(scope) => {
|
||||
if matches!(scope, McpConfigScope::User) {
|
||||
tui_config::save_config(&config)?;
|
||||
}
|
||||
|
||||
if let Some(path) = core_config::mcp_scope_path(scope, None) {
|
||||
println!(
|
||||
"Removed MCP server '{}' from {} scope ({})",
|
||||
args.name,
|
||||
scope,
|
||||
path.display()
|
||||
);
|
||||
} else {
|
||||
println!("Removed MCP server '{}' from {} scope.", args.name, scope);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
println!("No MCP server named '{}' was found.", args.name);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_config() -> Result<Config> {
|
||||
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||
config.refresh_mcp_servers(None)?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn format_command_line(command: &str, args: &[String]) -> String {
|
||||
if args.is_empty() {
|
||||
command.to_string()
|
||||
} else {
|
||||
format!("{} {}", command, args.join(" "))
|
||||
}
|
||||
}
|
||||
@@ -50,3 +50,4 @@ ollama-rs = { version = "0.3", features = ["stream", "headers"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = { workspace = true }
|
||||
httpmock = "0.7"
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
use crate::Error;
|
||||
use crate::ProviderConfig;
|
||||
use crate::Result;
|
||||
use crate::mode::ModeConfig;
|
||||
use crate::ui::RoleLabelDisplay;
|
||||
use serde::de::{self, Deserializer, Visitor};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::fmt;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Default location for the OWLEN configuration file
|
||||
@@ -54,6 +56,21 @@ pub struct Config {
|
||||
/// External MCP server definitions
|
||||
#[serde(default)]
|
||||
pub mcp_servers: Vec<McpServerConfig>,
|
||||
/// User-scoped resource definitions
|
||||
#[serde(default)]
|
||||
pub mcp_resources: Vec<McpResourceConfig>,
|
||||
/// Resolved MCP servers across scopes (runtime only).
|
||||
#[serde(skip)]
|
||||
pub scoped_mcp_servers: Vec<ScopedMcpServer>,
|
||||
/// Effective MCP servers after applying precedence rules (runtime only).
|
||||
#[serde(skip)]
|
||||
pub effective_mcp_servers: Vec<McpServerConfig>,
|
||||
/// Resolved MCP resources across scopes (runtime only).
|
||||
#[serde(skip)]
|
||||
pub scoped_mcp_resources: Vec<ScopedMcpResource>,
|
||||
/// Effective MCP resources after precedence (runtime only).
|
||||
#[serde(skip)]
|
||||
pub effective_mcp_resources: Vec<McpResourceConfig>,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
@@ -74,6 +91,11 @@ impl Default for Config {
|
||||
tools: ToolSettings::default(),
|
||||
modes: ModeConfig::default(),
|
||||
mcp_servers: Vec::new(),
|
||||
mcp_resources: Vec::new(),
|
||||
scoped_mcp_servers: Vec::new(),
|
||||
effective_mcp_servers: Vec::new(),
|
||||
scoped_mcp_resources: Vec::new(),
|
||||
effective_mcp_resources: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -94,6 +116,9 @@ pub struct McpServerConfig {
|
||||
/// Optional environment variable map for the process.
|
||||
#[serde(default)]
|
||||
pub env: std::collections::HashMap<String, String>,
|
||||
/// Optional OAuth configuration for remote servers.
|
||||
#[serde(default)]
|
||||
pub oauth: Option<McpOAuthConfig>,
|
||||
}
|
||||
|
||||
impl McpServerConfig {
|
||||
@@ -102,6 +127,126 @@ impl McpServerConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// OAuth configuration for MCP servers that require delegated authentication.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct McpOAuthConfig {
|
||||
/// Public client identifier registered with the authorization server.
|
||||
pub client_id: String,
|
||||
/// Optional client secret for confidential clients.
|
||||
#[serde(default)]
|
||||
pub client_secret: Option<String>,
|
||||
/// OAuth authorization endpoint (used for web-based flows).
|
||||
pub authorize_url: String,
|
||||
/// OAuth token endpoint.
|
||||
pub token_url: String,
|
||||
/// Optional device authorization endpoint for device-code flows.
|
||||
#[serde(default)]
|
||||
pub device_authorization_url: Option<String>,
|
||||
/// Optional redirect URL (PKCE / authorization-code flows).
|
||||
#[serde(default)]
|
||||
pub redirect_url: Option<String>,
|
||||
/// Requested OAuth scopes.
|
||||
#[serde(default)]
|
||||
pub scopes: Vec<String>,
|
||||
/// Environment variable name populated with the bearer access token when spawning stdio servers.
|
||||
#[serde(default)]
|
||||
pub token_env: Option<String>,
|
||||
/// Optional HTTP header name for bearer authentication (defaults to "Authorization").
|
||||
#[serde(default)]
|
||||
pub header: Option<String>,
|
||||
/// Optional prefix prepended to the access token (defaults to "Bearer ").
|
||||
#[serde(default)]
|
||||
pub header_prefix: Option<String>,
|
||||
}
|
||||
|
||||
impl McpOAuthConfig {
|
||||
pub fn header_name(&self) -> &str {
|
||||
self.header.as_deref().unwrap_or("Authorization")
|
||||
}
|
||||
|
||||
pub fn header_prefix(&self) -> &str {
|
||||
self.header_prefix.as_deref().unwrap_or("Bearer ")
|
||||
}
|
||||
}
|
||||
|
||||
/// Scope for MCP server configuration entries.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum McpConfigScope {
|
||||
/// User-level configuration stored under the user's config directory.
|
||||
User,
|
||||
/// Project configuration stored in the repository (e.g. `.mcp.json`).
|
||||
Project,
|
||||
/// Local overrides stored alongside the project but excluded from version control.
|
||||
Local,
|
||||
}
|
||||
|
||||
impl McpConfigScope {
|
||||
fn precedence_iter() -> impl Iterator<Item = Self> {
|
||||
[
|
||||
McpConfigScope::Local,
|
||||
McpConfigScope::Project,
|
||||
McpConfigScope::User,
|
||||
]
|
||||
.into_iter()
|
||||
}
|
||||
|
||||
fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
McpConfigScope::User => "user",
|
||||
McpConfigScope::Project => "project",
|
||||
McpConfigScope::Local => "local",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for McpConfigScope {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for McpConfigScope {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s.to_ascii_lowercase().as_str() {
|
||||
"user" => Ok(McpConfigScope::User),
|
||||
"project" => Ok(McpConfigScope::Project),
|
||||
"local" => Ok(McpConfigScope::Local),
|
||||
other => Err(format!("Unknown MCP scope '{other}'")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A resolved MCP server entry annotated with its configuration scope.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ScopedMcpServer {
|
||||
pub scope: McpConfigScope,
|
||||
pub config: McpServerConfig,
|
||||
}
|
||||
|
||||
/// Configuration for a predefined MCP resource reference.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct McpResourceConfig {
|
||||
/// Named MCP server that owns this resource.
|
||||
pub server: String,
|
||||
/// URI or path identifying the resource within the server.
|
||||
pub uri: String,
|
||||
/// Optional short title displayed in UI.
|
||||
#[serde(default)]
|
||||
pub title: Option<String>,
|
||||
/// Optional detailed description shown in tooltips.
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
/// Resource entry annotated with its originating scope.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ScopedMcpResource {
|
||||
pub scope: McpConfigScope,
|
||||
pub config: McpResourceConfig,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
fn default_schema_version() -> String {
|
||||
CONFIG_SCHEMA_VERSION.to_string()
|
||||
@@ -138,18 +283,22 @@ impl Config {
|
||||
config.mcp.apply_backward_compat();
|
||||
config.apply_schema_migrations(&previous_version);
|
||||
config.expand_provider_env_vars()?;
|
||||
config.refresh_mcp_servers(None)?;
|
||||
config.validate()?;
|
||||
Ok(config)
|
||||
} else {
|
||||
let mut config = Config::default();
|
||||
config.expand_provider_env_vars()?;
|
||||
config.refresh_mcp_servers(None)?;
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
|
||||
/// Persist configuration to disk
|
||||
pub fn save(&self, path: Option<&Path>) -> Result<()> {
|
||||
self.validate()?;
|
||||
let mut validator = self.clone();
|
||||
validator.refresh_mcp_servers(None)?;
|
||||
validator.validate()?;
|
||||
|
||||
let path = match path {
|
||||
Some(path) => path.to_path_buf(),
|
||||
@@ -214,6 +363,192 @@ impl Config {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Refresh the resolved MCP server list by loading scope-specific definitions.
|
||||
pub fn refresh_mcp_servers(&mut self, project_hint: Option<&Path>) -> Result<()> {
|
||||
let mut scoped_servers = Vec::new();
|
||||
let mut scoped_resources = Vec::new();
|
||||
|
||||
let mut user_servers = self.mcp_servers.clone();
|
||||
expand_mcp_servers(&mut user_servers, "config.mcp_servers")?;
|
||||
for server in user_servers {
|
||||
scoped_servers.push(ScopedMcpServer {
|
||||
scope: McpConfigScope::User,
|
||||
config: server,
|
||||
});
|
||||
}
|
||||
|
||||
let mut user_resources = self.mcp_resources.clone();
|
||||
expand_mcp_resources(&mut user_resources, "config.mcp_resources")?;
|
||||
for resource in user_resources {
|
||||
scoped_resources.push(ScopedMcpResource {
|
||||
scope: McpConfigScope::User,
|
||||
config: resource,
|
||||
});
|
||||
}
|
||||
|
||||
for scope in [McpConfigScope::Project, McpConfigScope::Local] {
|
||||
if let Some(path) = mcp_scope_path(scope, project_hint) {
|
||||
let mut file = read_scope_config(&path)?;
|
||||
let server_context = format!("mcp.{scope}.servers");
|
||||
expand_mcp_servers(&mut file.servers, &server_context)?;
|
||||
for server in file.servers {
|
||||
scoped_servers.push(ScopedMcpServer {
|
||||
scope,
|
||||
config: server,
|
||||
});
|
||||
}
|
||||
|
||||
let resource_context = format!("mcp.{scope}.resources");
|
||||
expand_mcp_resources(&mut file.resources, &resource_context)?;
|
||||
for resource in file.resources {
|
||||
scoped_resources.push(ScopedMcpResource {
|
||||
scope,
|
||||
config: resource,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut effective_servers = Vec::new();
|
||||
let mut seen_servers = HashSet::new();
|
||||
for scope in McpConfigScope::precedence_iter() {
|
||||
for entry in scoped_servers.iter().filter(|entry| entry.scope == scope) {
|
||||
if seen_servers.insert(entry.config.name.clone()) {
|
||||
effective_servers.push(entry.config.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut effective_resources = Vec::new();
|
||||
let mut seen_resources: HashSet<(String, String)> = HashSet::new();
|
||||
for scope in McpConfigScope::precedence_iter() {
|
||||
for entry in scoped_resources.iter().filter(|entry| entry.scope == scope) {
|
||||
let key = (entry.config.server.clone(), entry.config.uri.clone());
|
||||
if seen_resources.insert(key) {
|
||||
effective_resources.push(entry.config.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.scoped_mcp_servers = scoped_servers;
|
||||
self.effective_mcp_servers = effective_servers;
|
||||
self.scoped_mcp_resources = scoped_resources;
|
||||
self.effective_mcp_resources = effective_resources;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Return the merged MCP servers using scope precedence (local > project > user).
|
||||
pub fn effective_mcp_servers(&self) -> &[McpServerConfig] {
|
||||
&self.effective_mcp_servers
|
||||
}
|
||||
|
||||
/// Return MCP servers annotated with their originating scope.
|
||||
pub fn scoped_mcp_servers(&self) -> &[ScopedMcpServer] {
|
||||
&self.scoped_mcp_servers
|
||||
}
|
||||
|
||||
/// Return merged MCP resources using scope precedence (local > project > user).
|
||||
pub fn effective_mcp_resources(&self) -> &[McpResourceConfig] {
|
||||
&self.effective_mcp_resources
|
||||
}
|
||||
|
||||
/// Return scoped MCP resources with their origin scope metadata.
|
||||
pub fn scoped_mcp_resources(&self) -> &[ScopedMcpResource] {
|
||||
&self.scoped_mcp_resources
|
||||
}
|
||||
|
||||
/// Locate a configured resource by server and URI.
|
||||
pub fn find_resource(&self, server: &str, uri: &str) -> Option<&McpResourceConfig> {
|
||||
self.effective_mcp_resources
|
||||
.iter()
|
||||
.find(|resource| resource.server == server && resource.uri == uri)
|
||||
}
|
||||
|
||||
/// Add or replace an MCP server definition within the specified scope.
|
||||
pub fn add_mcp_server(
|
||||
&mut self,
|
||||
scope: McpConfigScope,
|
||||
server: McpServerConfig,
|
||||
project_hint: Option<&Path>,
|
||||
) -> Result<()> {
|
||||
match scope {
|
||||
McpConfigScope::User => {
|
||||
self.mcp_servers
|
||||
.retain(|existing| existing.name != server.name);
|
||||
self.mcp_servers.push(server);
|
||||
}
|
||||
other => {
|
||||
let path = mcp_scope_path(other, project_hint).ok_or_else(|| {
|
||||
Error::Config(format!(
|
||||
"Unable to resolve project root for MCP scope '{}'",
|
||||
other
|
||||
))
|
||||
})?;
|
||||
let mut file = read_scope_config(&path)?;
|
||||
file.servers.retain(|existing| existing.name != server.name);
|
||||
file.servers.push(server);
|
||||
write_scope_config(&path, &file)?;
|
||||
}
|
||||
}
|
||||
|
||||
self.refresh_mcp_servers(project_hint)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove an MCP server from the given scope, or infer the scope if omitted.
|
||||
pub fn remove_mcp_server(
|
||||
&mut self,
|
||||
scope: Option<McpConfigScope>,
|
||||
name: &str,
|
||||
project_hint: Option<&Path>,
|
||||
) -> Result<Option<McpConfigScope>> {
|
||||
let target_scope = if let Some(scope) = scope {
|
||||
scope
|
||||
} else {
|
||||
self.refresh_mcp_servers(project_hint)?;
|
||||
match self
|
||||
.scoped_mcp_servers
|
||||
.iter()
|
||||
.find(|entry| entry.config.name == name)
|
||||
{
|
||||
Some(entry) => entry.scope,
|
||||
None => return Ok(None),
|
||||
}
|
||||
};
|
||||
|
||||
let removed = match target_scope {
|
||||
McpConfigScope::User => {
|
||||
let before = self.mcp_servers.len();
|
||||
self.mcp_servers.retain(|entry| entry.name != name);
|
||||
before != self.mcp_servers.len()
|
||||
}
|
||||
other => {
|
||||
let path = mcp_scope_path(other, project_hint).ok_or_else(|| {
|
||||
Error::Config(format!(
|
||||
"Unable to resolve project root for MCP scope '{}'",
|
||||
other
|
||||
))
|
||||
})?;
|
||||
let mut file = read_scope_config(&path)?;
|
||||
let before = file.servers.len();
|
||||
file.servers.retain(|entry| entry.name != name);
|
||||
if before == file.servers.len() {
|
||||
false
|
||||
} else {
|
||||
write_scope_config(&path, &file)?;
|
||||
true
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if removed {
|
||||
self.refresh_mcp_servers(project_hint)?;
|
||||
Ok(Some(target_scope))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate configuration invariants and surface actionable error messages.
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
self.validate_default_provider()?;
|
||||
@@ -284,9 +619,15 @@ impl Config {
|
||||
}
|
||||
|
||||
fn validate_mcp_settings(&self) -> Result<()> {
|
||||
let has_effective_servers = if self.effective_mcp_servers.is_empty() {
|
||||
!self.mcp_servers.is_empty()
|
||||
} else {
|
||||
!self.effective_mcp_servers.is_empty()
|
||||
};
|
||||
|
||||
match self.mcp.mode {
|
||||
McpMode::RemoteOnly => {
|
||||
if self.mcp_servers.is_empty() {
|
||||
if !has_effective_servers {
|
||||
return Err(crate::Error::Config(
|
||||
"[mcp].mode = 'remote_only' requires at least one [[mcp_servers]] entry"
|
||||
.to_string(),
|
||||
@@ -294,7 +635,7 @@ impl Config {
|
||||
}
|
||||
}
|
||||
McpMode::RemotePreferred => {
|
||||
if !self.mcp.allow_fallback && self.mcp_servers.is_empty() {
|
||||
if !self.mcp.allow_fallback && !has_effective_servers {
|
||||
return Err(crate::Error::Config(
|
||||
"[mcp].allow_fallback = false requires at least one [[mcp_servers]] entry"
|
||||
.to_string(),
|
||||
@@ -313,26 +654,13 @@ impl Config {
|
||||
}
|
||||
|
||||
fn validate_mcp_servers(&self) -> Result<()> {
|
||||
for server in &self.mcp_servers {
|
||||
if server.name.trim().is_empty() {
|
||||
return Err(crate::Error::Config(
|
||||
"Each [[mcp_servers]] entry must include a non-empty name".to_string(),
|
||||
));
|
||||
if self.scoped_mcp_servers.is_empty() {
|
||||
for server in &self.mcp_servers {
|
||||
validate_mcp_server_entry(server, McpConfigScope::User)?;
|
||||
}
|
||||
|
||||
if server.command.trim().is_empty() {
|
||||
return Err(crate::Error::Config(format!(
|
||||
"MCP server '{}' must define a command or endpoint",
|
||||
server.name
|
||||
)));
|
||||
}
|
||||
|
||||
let transport = server.transport.to_lowercase();
|
||||
if !matches!(transport.as_str(), "stdio" | "http" | "websocket") {
|
||||
return Err(crate::Error::Config(format!(
|
||||
"Unknown MCP transport '{}' for server '{}'",
|
||||
server.transport, server.name
|
||||
)));
|
||||
} else {
|
||||
for entry in &self.scoped_mcp_servers {
|
||||
validate_mcp_server_entry(&entry.config, entry.scope)?;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -349,6 +677,58 @@ fn default_ollama_provider_config() -> ProviderConfig {
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_mcp_server_entry(server: &McpServerConfig, scope: McpConfigScope) -> Result<()> {
|
||||
if server.name.trim().is_empty() {
|
||||
return Err(Error::Config(format!(
|
||||
"Each MCP server entry must include a non-empty name (scope: {scope})"
|
||||
)));
|
||||
}
|
||||
|
||||
if server.command.trim().is_empty() {
|
||||
return Err(Error::Config(format!(
|
||||
"MCP server '{}' must define a command or endpoint (scope: {scope})",
|
||||
server.name
|
||||
)));
|
||||
}
|
||||
|
||||
let transport = server.transport.to_lowercase();
|
||||
if !matches!(transport.as_str(), "stdio" | "http" | "websocket") {
|
||||
return Err(Error::Config(format!(
|
||||
"Unknown MCP transport '{}' for server '{}' (scope: {scope})",
|
||||
server.transport, server.name
|
||||
)));
|
||||
}
|
||||
|
||||
if let Some(oauth) = &server.oauth {
|
||||
if oauth.client_id.trim().is_empty() {
|
||||
return Err(Error::Config(format!(
|
||||
"MCP server '{}' defines OAuth without a client_id",
|
||||
server.name
|
||||
)));
|
||||
}
|
||||
if oauth.authorize_url.trim().is_empty() {
|
||||
return Err(Error::Config(format!(
|
||||
"MCP server '{}' defines OAuth without an authorize_url",
|
||||
server.name
|
||||
)));
|
||||
}
|
||||
if oauth.token_url.trim().is_empty() {
|
||||
return Err(Error::Config(format!(
|
||||
"MCP server '{}' defines OAuth without a token_url",
|
||||
server.name
|
||||
)));
|
||||
}
|
||||
if oauth.device_authorization_url.is_none() && oauth.redirect_url.is_none() {
|
||||
return Err(Error::Config(format!(
|
||||
"MCP server '{}' must define either device_authorization_url or redirect_url for OAuth flows",
|
||||
server.name
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn expand_provider_entry(provider_name: &str, provider: &mut ProviderConfig) -> Result<()> {
|
||||
if let Some(ref mut base_url) = provider.base_url {
|
||||
let expanded = expand_env_string(
|
||||
@@ -379,6 +759,136 @@ fn expand_provider_entry(provider_name: &str, provider: &mut ProviderConfig) ->
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn expand_mcp_servers(servers: &mut [McpServerConfig], field_path: &str) -> Result<()> {
|
||||
for (idx, server) in servers.iter_mut().enumerate() {
|
||||
expand_mcp_server_entry(server, field_path, idx)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn expand_mcp_server_entry(
|
||||
server: &mut McpServerConfig,
|
||||
field_path: &str,
|
||||
index: usize,
|
||||
) -> Result<()> {
|
||||
server.command = expand_env_string(
|
||||
server.command.as_str(),
|
||||
&format!("{field_path}[{index}].command"),
|
||||
)?;
|
||||
|
||||
for (arg_idx, arg) in server.args.iter_mut().enumerate() {
|
||||
*arg = expand_env_string(
|
||||
arg.as_str(),
|
||||
&format!("{field_path}[{index}].args[{arg_idx}]"),
|
||||
)?;
|
||||
}
|
||||
|
||||
for (env_key, env_value) in server.env.iter_mut() {
|
||||
*env_value = expand_env_string(
|
||||
env_value.as_str(),
|
||||
&format!("{field_path}[{index}].env.{env_key}"),
|
||||
)?;
|
||||
}
|
||||
|
||||
if let Some(oauth) = server.oauth.as_mut() {
|
||||
oauth.client_id = expand_env_string(
|
||||
oauth.client_id.as_str(),
|
||||
&format!("{field_path}[{index}].oauth.client_id"),
|
||||
)?;
|
||||
oauth.authorize_url = expand_env_string(
|
||||
oauth.authorize_url.as_str(),
|
||||
&format!("{field_path}[{index}].oauth.authorize_url"),
|
||||
)?;
|
||||
oauth.token_url = expand_env_string(
|
||||
oauth.token_url.as_str(),
|
||||
&format!("{field_path}[{index}].oauth.token_url"),
|
||||
)?;
|
||||
|
||||
if let Some(secret) = oauth.client_secret.as_mut() {
|
||||
*secret = expand_env_string(
|
||||
secret.as_str(),
|
||||
&format!("{field_path}[{index}].oauth.client_secret"),
|
||||
)?;
|
||||
}
|
||||
|
||||
if let Some(device_url) = oauth.device_authorization_url.as_mut() {
|
||||
*device_url = expand_env_string(
|
||||
device_url.as_str(),
|
||||
&format!("{field_path}[{index}].oauth.device_authorization_url"),
|
||||
)?;
|
||||
}
|
||||
|
||||
if let Some(redirect) = oauth.redirect_url.as_mut() {
|
||||
*redirect = expand_env_string(
|
||||
redirect.as_str(),
|
||||
&format!("{field_path}[{index}].oauth.redirect_url"),
|
||||
)?;
|
||||
}
|
||||
|
||||
if let Some(token_env) = oauth.token_env.as_mut() {
|
||||
*token_env = expand_env_string(
|
||||
token_env.as_str(),
|
||||
&format!("{field_path}[{index}].oauth.token_env"),
|
||||
)?;
|
||||
}
|
||||
|
||||
if let Some(header) = oauth.header.as_mut() {
|
||||
*header = expand_env_string(
|
||||
header.as_str(),
|
||||
&format!("{field_path}[{index}].oauth.header"),
|
||||
)?;
|
||||
}
|
||||
|
||||
if let Some(prefix) = oauth.header_prefix.as_mut() {
|
||||
*prefix = expand_env_string(
|
||||
prefix.as_str(),
|
||||
&format!("{field_path}[{index}].oauth.header_prefix"),
|
||||
)?;
|
||||
}
|
||||
|
||||
for (scope_idx, scope) in oauth.scopes.iter_mut().enumerate() {
|
||||
*scope = expand_env_string(
|
||||
scope.as_str(),
|
||||
&format!("{field_path}[{index}].oauth.scopes[{scope_idx}]"),
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn expand_mcp_resources(resources: &mut [McpResourceConfig], field_path: &str) -> Result<()> {
|
||||
for (idx, resource) in resources.iter_mut().enumerate() {
|
||||
expand_mcp_resource_entry(resource, field_path, idx)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn expand_mcp_resource_entry(
|
||||
resource: &mut McpResourceConfig,
|
||||
field_path: &str,
|
||||
index: usize,
|
||||
) -> Result<()> {
|
||||
resource.server = expand_env_string(
|
||||
resource.server.as_str(),
|
||||
&format!("{field_path}[{index}].server"),
|
||||
)?;
|
||||
resource.uri = expand_env_string(resource.uri.as_str(), &format!("{field_path}[{index}].uri"))?;
|
||||
|
||||
if let Some(title) = resource.title.as_mut() {
|
||||
*title = expand_env_string(title.as_str(), &format!("{field_path}[{index}].title"))?;
|
||||
}
|
||||
|
||||
if let Some(description) = resource.description.as_mut() {
|
||||
*description = expand_env_string(
|
||||
description.as_str(),
|
||||
&format!("{field_path}[{index}].description"),
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn expand_env_string(input: &str, field_path: &str) -> Result<String> {
|
||||
if !input.contains('$') {
|
||||
return Ok(input.to_string());
|
||||
@@ -408,6 +918,106 @@ pub fn default_config_path() -> PathBuf {
|
||||
PathBuf::from(shellexpand::tilde(DEFAULT_CONFIG_PATH).as_ref())
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
struct McpConfigFile {
|
||||
#[serde(default)]
|
||||
servers: Vec<McpServerConfig>,
|
||||
#[serde(default)]
|
||||
resources: Vec<McpResourceConfig>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum McpConfigEnvelope {
|
||||
Array(Vec<McpServerConfig>),
|
||||
Object(McpConfigFile),
|
||||
}
|
||||
|
||||
fn read_scope_config(path: &Path) -> Result<McpConfigFile> {
|
||||
if !path.exists() {
|
||||
return Ok(McpConfigFile::default());
|
||||
}
|
||||
|
||||
let contents = fs::read_to_string(path).map_err(Error::Io)?;
|
||||
if contents.trim().is_empty() {
|
||||
return Ok(McpConfigFile::default());
|
||||
}
|
||||
|
||||
let doc: McpConfigEnvelope = serde_json::from_str(&contents).map_err(|err| {
|
||||
Error::Config(format!(
|
||||
"Failed to parse MCP configuration at {}: {err}",
|
||||
path.display()
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(match doc {
|
||||
McpConfigEnvelope::Array(servers) => McpConfigFile {
|
||||
servers,
|
||||
resources: Vec::new(),
|
||||
},
|
||||
McpConfigEnvelope::Object(doc) => doc,
|
||||
})
|
||||
}
|
||||
|
||||
fn write_scope_config(path: &Path, file: &McpConfigFile) -> Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).map_err(Error::Io)?;
|
||||
}
|
||||
|
||||
let serialized = serde_json::to_string_pretty(file).map_err(|err| {
|
||||
Error::Config(format!(
|
||||
"Failed to serialize MCP configuration for {}: {err}",
|
||||
path.display()
|
||||
))
|
||||
})?;
|
||||
|
||||
fs::write(path, serialized).map_err(Error::Io)
|
||||
}
|
||||
|
||||
/// Resolve the configuration file path for a given scope.
|
||||
pub fn mcp_scope_path(scope: McpConfigScope, project_hint: Option<&Path>) -> Option<PathBuf> {
|
||||
match scope {
|
||||
McpConfigScope::User => dirs::config_dir()
|
||||
.or_else(|| Some(PathBuf::from(shellexpand::tilde("~/.config").as_ref())))
|
||||
.map(|dir| dir.join("owlen").join("mcp.json")),
|
||||
McpConfigScope::Project | McpConfigScope::Local => {
|
||||
let root = project_hint
|
||||
.map(PathBuf::from)
|
||||
.or_else(|| discover_project_root(None))?;
|
||||
|
||||
if matches!(scope, McpConfigScope::Project) {
|
||||
Some(root.join(".mcp.json"))
|
||||
} else {
|
||||
Some(root.join(".owlen").join("mcp.local.json"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn discover_project_root(start: Option<&Path>) -> Option<PathBuf> {
|
||||
let mut current = start
|
||||
.map(PathBuf::from)
|
||||
.or_else(|| std::env::current_dir().ok())?;
|
||||
|
||||
loop {
|
||||
if current.join(".mcp.json").exists()
|
||||
|| current.join(".owlen").exists()
|
||||
|| current.join(".git").exists()
|
||||
|| current.join("Cargo.toml").exists()
|
||||
{
|
||||
return Some(current);
|
||||
}
|
||||
|
||||
if !current.pop() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
start
|
||||
.map(PathBuf::from)
|
||||
.or_else(|| std::env::current_dir().ok())
|
||||
}
|
||||
|
||||
/// General behaviour settings shared across clients
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GeneralSettings {
|
||||
@@ -1173,6 +1783,7 @@ mod tests {
|
||||
transport: "udp".into(),
|
||||
args: Vec::new(),
|
||||
env: std::collections::HashMap::new(),
|
||||
oauth: None,
|
||||
}];
|
||||
let result = config.validate();
|
||||
assert!(
|
||||
@@ -1186,4 +1797,113 @@ mod tests {
|
||||
config.mcp.mode = McpMode::LocalOnly;
|
||||
assert!(config.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn refresh_mcp_servers_merges_scopes_with_precedence() {
|
||||
let temp = tempfile::tempdir().expect("tempdir");
|
||||
let project_root = temp.path();
|
||||
std::fs::write(
|
||||
project_root.join(".mcp.json"),
|
||||
r#"{
|
||||
"servers": [
|
||||
{ "name": "shared", "command": "project-cmd", "transport": "stdio" },
|
||||
{ "name": "project-only", "command": "proj", "transport": "stdio" }
|
||||
],
|
||||
"resources": [
|
||||
{ "server": "github", "uri": "issue://123", "title": "Project Issue" },
|
||||
{ "server": "docs", "uri": "page://start", "title": "Project Doc" }
|
||||
]
|
||||
}"#,
|
||||
)
|
||||
.expect("write project scope");
|
||||
|
||||
let local_dir = project_root.join(".owlen");
|
||||
std::fs::create_dir_all(&local_dir).expect("local dir");
|
||||
std::fs::write(
|
||||
local_dir.join("mcp.local.json"),
|
||||
r#"{
|
||||
"servers": [
|
||||
{ "name": "shared", "command": "local-cmd", "transport": "stdio" }
|
||||
],
|
||||
"resources": [
|
||||
{ "server": "github", "uri": "issue://123", "title": "Local Override" }
|
||||
]
|
||||
}"#,
|
||||
)
|
||||
.expect("write local scope");
|
||||
|
||||
let mut config = Config::default();
|
||||
config.mcp_servers.push(McpServerConfig {
|
||||
name: "shared".into(),
|
||||
command: "user-cmd".into(),
|
||||
args: Vec::new(),
|
||||
transport: "stdio".into(),
|
||||
env: std::collections::HashMap::new(),
|
||||
oauth: None,
|
||||
});
|
||||
config.mcp_resources.push(McpResourceConfig {
|
||||
server: "github".into(),
|
||||
uri: "issue://123".into(),
|
||||
title: Some("User Issue".into()),
|
||||
description: None,
|
||||
});
|
||||
|
||||
config
|
||||
.refresh_mcp_servers(Some(project_root))
|
||||
.expect("refresh scopes");
|
||||
|
||||
// We should have four scoped entries (user + two project + local) and precedence should select local
|
||||
assert_eq!(config.scoped_mcp_servers().len(), 4);
|
||||
let effective = config.effective_mcp_servers();
|
||||
assert_eq!(effective.len(), 2); // shared + project-only
|
||||
assert_eq!(effective[0].command, "local-cmd");
|
||||
assert_eq!(effective[0].name, "shared");
|
||||
|
||||
assert_eq!(config.scoped_mcp_resources().len(), 4);
|
||||
let effective_resources = config.effective_mcp_resources();
|
||||
assert_eq!(effective_resources.len(), 2);
|
||||
assert_eq!(
|
||||
effective_resources
|
||||
.iter()
|
||||
.find(|res| res.server == "github")
|
||||
.and_then(|res| res.title.as_deref()),
|
||||
Some("Local Override")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_mcp_server_reports_scope() {
|
||||
let temp = tempfile::tempdir().expect("tempdir");
|
||||
let project_root = temp.path();
|
||||
std::fs::write(
|
||||
project_root.join(".mcp.json"),
|
||||
r#"{ "servers": [{ "name": "project", "command": "proj", "transport": "stdio" }] }"#,
|
||||
)
|
||||
.expect("write project scope");
|
||||
|
||||
let mut config = Config::default();
|
||||
config.mcp_servers.push(McpServerConfig {
|
||||
name: "user".into(),
|
||||
command: "user".into(),
|
||||
args: Vec::new(),
|
||||
transport: "stdio".into(),
|
||||
env: std::collections::HashMap::new(),
|
||||
oauth: None,
|
||||
});
|
||||
config
|
||||
.refresh_mcp_servers(Some(project_root))
|
||||
.expect("refresh scopes");
|
||||
|
||||
// Remove without specifying scope should pick highest precedence (project)
|
||||
let removed_scope = config
|
||||
.remove_mcp_server(None, "project", Some(project_root))
|
||||
.expect("remove call");
|
||||
assert_eq!(removed_scope, Some(McpConfigScope::Project));
|
||||
|
||||
// Remove the remaining user scope explicitly
|
||||
let removed_scope = config
|
||||
.remove_mcp_server(Some(McpConfigScope::User), "user", Some(project_root))
|
||||
.expect("remove user");
|
||||
assert_eq!(removed_scope, Some(McpConfigScope::User));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::sync::Arc;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{Error, Result, storage::StorageManager};
|
||||
use crate::{Error, Result, oauth::OAuthToken, storage::StorageManager};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ApiCredentials {
|
||||
@@ -31,6 +31,10 @@ impl CredentialManager {
|
||||
format!("{}_{}", self.namespace, tool_name)
|
||||
}
|
||||
|
||||
fn oauth_storage_key(&self, resource: &str) -> String {
|
||||
self.namespaced_key(&format!("oauth_{resource}"))
|
||||
}
|
||||
|
||||
pub async fn store_credentials(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
@@ -68,4 +72,37 @@ impl CredentialManager {
|
||||
let key = self.namespaced_key(tool_name);
|
||||
self.storage.delete_secure_item(&key).await
|
||||
}
|
||||
|
||||
pub async fn store_oauth_token(&self, resource: &str, token: &OAuthToken) -> Result<()> {
|
||||
let key = self.oauth_storage_key(resource);
|
||||
let payload = serde_json::to_vec(token).map_err(|err| {
|
||||
Error::Storage(format!(
|
||||
"Failed to serialize OAuth token for secure storage: {err}"
|
||||
))
|
||||
})?;
|
||||
self.storage
|
||||
.store_secure_item(&key, &payload, &self.master_key)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn load_oauth_token(&self, resource: &str) -> Result<Option<OAuthToken>> {
|
||||
let key = self.oauth_storage_key(resource);
|
||||
let raw = self
|
||||
.storage
|
||||
.load_secure_item(&key, &self.master_key)
|
||||
.await?;
|
||||
if let Some(bytes) = raw {
|
||||
let token = serde_json::from_slice(&bytes).map_err(|err| {
|
||||
Error::Storage(format!("Failed to deserialize stored OAuth token: {err}"))
|
||||
})?;
|
||||
Ok(Some(token))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_oauth_token(&self, resource: &str) -> Result<()> {
|
||||
let key = self.oauth_storage_key(resource);
|
||||
self.storage.delete_secure_item(&key).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ pub mod llm;
|
||||
pub mod mcp;
|
||||
pub mod mode;
|
||||
pub mod model;
|
||||
pub mod oauth;
|
||||
pub mod providers;
|
||||
pub mod router;
|
||||
pub mod sandbox;
|
||||
@@ -36,6 +37,7 @@ pub use credentials::*;
|
||||
pub use encryption::*;
|
||||
pub use formatting::*;
|
||||
pub use input::*;
|
||||
pub use oauth::*;
|
||||
// Export MCP types but exclude test_utils to avoid ambiguity
|
||||
pub use llm::{
|
||||
ChatStream, LlmProvider, Provider, ProviderConfig, ProviderRegistry, send_via_stream,
|
||||
|
||||
@@ -3,7 +3,10 @@
|
||||
/// Provides a unified interface for creating MCP clients based on configuration.
|
||||
/// Supports switching between local (in-process) and remote (STDIO) execution modes.
|
||||
use super::client::McpClient;
|
||||
use super::{LocalMcpClient, remote_client::RemoteMcpClient};
|
||||
use super::{
|
||||
LocalMcpClient,
|
||||
remote_client::{McpRuntimeSecrets, RemoteMcpClient},
|
||||
};
|
||||
use crate::config::{Config, McpMode};
|
||||
use crate::tools::registry::ToolRegistry;
|
||||
use crate::validation::SchemaValidator;
|
||||
@@ -33,6 +36,14 @@ impl McpClientFactory {
|
||||
|
||||
/// Create an MCP client based on the current configuration.
|
||||
pub fn create(&self) -> Result<Box<dyn McpClient>> {
|
||||
self.create_with_secrets(None)
|
||||
}
|
||||
|
||||
/// Create an MCP client using optional runtime secrets (OAuth tokens, env overrides).
|
||||
pub fn create_with_secrets(
|
||||
&self,
|
||||
runtime: Option<McpRuntimeSecrets>,
|
||||
) -> Result<Box<dyn McpClient>> {
|
||||
match self.config.mcp.mode {
|
||||
McpMode::Disabled => Err(Error::Config(
|
||||
"MCP mode is set to 'disabled'; tooling cannot function in this configuration."
|
||||
@@ -48,14 +59,14 @@ impl McpClientFactory {
|
||||
)))
|
||||
}
|
||||
McpMode::RemoteOnly => {
|
||||
let server_cfg = self.config.mcp_servers.first().ok_or_else(|| {
|
||||
let server_cfg = self.config.effective_mcp_servers().first().ok_or_else(|| {
|
||||
Error::Config(
|
||||
"MCP mode 'remote_only' requires at least one entry in [[mcp_servers]]"
|
||||
.to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
RemoteMcpClient::new_with_config(server_cfg)
|
||||
RemoteMcpClient::new_with_runtime(server_cfg, runtime)
|
||||
.map(|client| Box::new(client) as Box<dyn McpClient>)
|
||||
.map_err(|e| {
|
||||
Error::Config(format!(
|
||||
@@ -65,8 +76,8 @@ impl McpClientFactory {
|
||||
})
|
||||
}
|
||||
McpMode::RemotePreferred => {
|
||||
if let Some(server_cfg) = self.config.mcp_servers.first() {
|
||||
match RemoteMcpClient::new_with_config(server_cfg) {
|
||||
if let Some(server_cfg) = self.config.effective_mcp_servers().first() {
|
||||
match RemoteMcpClient::new_with_runtime(server_cfg, runtime.clone()) {
|
||||
Ok(client) => {
|
||||
info!(
|
||||
"Connected to remote MCP server '{}' via {} transport.",
|
||||
@@ -125,7 +136,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_factory_creates_local_client_when_no_servers_configured() {
|
||||
let config = Config::default();
|
||||
let mut config = Config::default();
|
||||
config.refresh_mcp_servers(None).unwrap();
|
||||
|
||||
let factory = build_factory(config);
|
||||
|
||||
@@ -139,6 +151,7 @@ mod tests {
|
||||
let mut config = Config::default();
|
||||
config.mcp.mode = McpMode::RemoteOnly;
|
||||
config.mcp_servers.clear();
|
||||
config.refresh_mcp_servers(None).unwrap();
|
||||
|
||||
let factory = build_factory(config);
|
||||
let result = factory.create();
|
||||
@@ -156,7 +169,9 @@ mod tests {
|
||||
args: Vec::new(),
|
||||
transport: "stdio".to_string(),
|
||||
env: std::collections::HashMap::new(),
|
||||
oauth: None,
|
||||
}];
|
||||
config.refresh_mcp_servers(None).unwrap();
|
||||
|
||||
let factory = build_factory(config);
|
||||
let result = factory.create();
|
||||
|
||||
@@ -305,6 +305,7 @@ mod tests {
|
||||
args: vec![],
|
||||
transport: "http".to_string(),
|
||||
env: std::collections::HashMap::new(),
|
||||
oauth: None,
|
||||
};
|
||||
|
||||
if let Ok(client) = RemoteMcpClient::new_with_config(&config) {
|
||||
|
||||
@@ -12,6 +12,7 @@ use anyhow::anyhow;
|
||||
use futures::{StreamExt, future::BoxFuture, stream};
|
||||
use reqwest::Client as HttpClient;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
@@ -39,6 +40,15 @@ pub struct RemoteMcpClient {
|
||||
ws_endpoint: Option<String>,
|
||||
// Incrementing request identifier.
|
||||
next_id: AtomicU64,
|
||||
// Optional HTTP header (name, value) injected into every request.
|
||||
http_header: Option<(String, String)>,
|
||||
}
|
||||
|
||||
/// Runtime secrets provided when constructing an MCP client.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct McpRuntimeSecrets {
|
||||
pub env_overrides: HashMap<String, String>,
|
||||
pub http_header: Option<(String, String)>,
|
||||
}
|
||||
|
||||
impl RemoteMcpClient {
|
||||
@@ -48,6 +58,14 @@ impl RemoteMcpClient {
|
||||
/// Spawn an external MCP server based on a configuration entry.
|
||||
/// The server must communicate over STDIO (the only supported transport).
|
||||
pub fn new_with_config(config: &crate::config::McpServerConfig) -> Result<Self> {
|
||||
Self::new_with_runtime(config, None)
|
||||
}
|
||||
|
||||
pub fn new_with_runtime(
|
||||
config: &crate::config::McpServerConfig,
|
||||
runtime: Option<McpRuntimeSecrets>,
|
||||
) -> Result<Self> {
|
||||
let mut runtime = runtime.unwrap_or_default();
|
||||
let transport = config.transport.to_lowercase();
|
||||
match transport.as_str() {
|
||||
"stdio" => {
|
||||
@@ -64,6 +82,9 @@ impl RemoteMcpClient {
|
||||
for (k, v) in config.env.iter() {
|
||||
cmd.env(k, v);
|
||||
}
|
||||
for (k, v) in runtime.env_overrides.drain() {
|
||||
cmd.env(k, v);
|
||||
}
|
||||
|
||||
let mut child = cmd.spawn().map_err(|e| {
|
||||
Error::Io(std::io::Error::new(
|
||||
@@ -92,6 +113,7 @@ impl RemoteMcpClient {
|
||||
ws_stream: None,
|
||||
ws_endpoint: None,
|
||||
next_id: AtomicU64::new(1),
|
||||
http_header: None,
|
||||
})
|
||||
}
|
||||
"http" => {
|
||||
@@ -109,6 +131,7 @@ impl RemoteMcpClient {
|
||||
ws_stream: None,
|
||||
ws_endpoint: None,
|
||||
next_id: AtomicU64::new(1),
|
||||
http_header: runtime.http_header.take(),
|
||||
})
|
||||
}
|
||||
"websocket" => {
|
||||
@@ -132,6 +155,7 @@ impl RemoteMcpClient {
|
||||
ws_stream: Some(Arc::new(Mutex::new(ws_stream))),
|
||||
ws_endpoint: Some(ws_url),
|
||||
next_id: AtomicU64::new(1),
|
||||
http_header: runtime.http_header.take(),
|
||||
})
|
||||
}
|
||||
other => Err(Error::NotImplemented(format!(
|
||||
@@ -171,6 +195,7 @@ impl RemoteMcpClient {
|
||||
args: Vec::new(),
|
||||
transport: "stdio".to_string(),
|
||||
env: std::collections::HashMap::new(),
|
||||
oauth: None,
|
||||
};
|
||||
Self::new_with_config(&config)
|
||||
}
|
||||
@@ -193,8 +218,11 @@ impl RemoteMcpClient {
|
||||
.http_endpoint
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::Network("Missing HTTP endpoint".into()))?;
|
||||
let resp = client
|
||||
.post(endpoint)
|
||||
let mut builder = client.post(endpoint);
|
||||
if let Some((ref header_name, ref header_value)) = self.http_header {
|
||||
builder = builder.header(header_name, header_value);
|
||||
}
|
||||
let resp = builder
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
|
||||
507
crates/owlen-core/src/oauth.rs
Normal file
507
crates/owlen-core/src/oauth.rs
Normal file
@@ -0,0 +1,507 @@
|
||||
use std::time::Duration as StdDuration;
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{Error, Result, config::McpOAuthConfig};
|
||||
|
||||
/// Persisted OAuth token set for MCP servers and providers.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||
pub struct OAuthToken {
|
||||
/// Bearer access token returned by the authorization server.
|
||||
pub access_token: String,
|
||||
/// Optional refresh token if the provider issues one.
|
||||
#[serde(default)]
|
||||
pub refresh_token: Option<String>,
|
||||
/// Absolute UTC expiration timestamp for the access token.
|
||||
#[serde(default)]
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
/// Optional space-delimited scope string supplied by the provider.
|
||||
#[serde(default)]
|
||||
pub scope: Option<String>,
|
||||
/// Token type reported by the provider (typically `Bearer`).
|
||||
#[serde(default)]
|
||||
pub token_type: Option<String>,
|
||||
}
|
||||
|
||||
impl OAuthToken {
|
||||
/// Returns `true` if the access token has expired at the provided instant.
|
||||
pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
|
||||
matches!(self.expires_at, Some(expiry) if now >= expiry)
|
||||
}
|
||||
|
||||
/// Returns `true` if the token will expire within the supplied duration window.
|
||||
pub fn will_expire_within(&self, window: Duration, now: DateTime<Utc>) -> bool {
|
||||
matches!(self.expires_at, Some(expiry) if expiry - now <= window)
|
||||
}
|
||||
}
|
||||
|
||||
/// Active device-authorization session details returned by the authorization server.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DeviceAuthorization {
|
||||
pub device_code: String,
|
||||
pub user_code: String,
|
||||
pub verification_uri: String,
|
||||
pub verification_uri_complete: Option<String>,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
pub interval: StdDuration,
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
impl DeviceAuthorization {
|
||||
pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
|
||||
now >= self.expires_at
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of polling the token endpoint during a device-authorization flow.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DevicePollState {
|
||||
Pending { retry_in: StdDuration },
|
||||
Complete(OAuthToken),
|
||||
}
|
||||
|
||||
pub struct OAuthClient {
|
||||
http: Client,
|
||||
config: McpOAuthConfig,
|
||||
}
|
||||
|
||||
impl OAuthClient {
|
||||
pub fn new(config: McpOAuthConfig) -> Result<Self> {
|
||||
let http = Client::builder()
|
||||
.user_agent("OwlenOAuth/1.0")
|
||||
.build()
|
||||
.map_err(|err| Error::Network(format!("Failed to construct HTTP client: {err}")))?;
|
||||
Ok(Self { http, config })
|
||||
}
|
||||
|
||||
fn scope_value(&self) -> Option<String> {
|
||||
if self.config.scopes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(self.config.scopes.join(" "))
|
||||
}
|
||||
}
|
||||
|
||||
fn token_request_base(&self) -> Vec<(String, String)> {
|
||||
let mut params = vec![("client_id".to_string(), self.config.client_id.clone())];
|
||||
if let Some(secret) = &self.config.client_secret {
|
||||
params.push(("client_secret".to_string(), secret.clone()));
|
||||
}
|
||||
params
|
||||
}
|
||||
|
||||
pub async fn start_device_authorization(&self) -> Result<DeviceAuthorization> {
|
||||
let device_url = self
|
||||
.config
|
||||
.device_authorization_url
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
Error::Config("Device authorization endpoint is not configured.".to_string())
|
||||
})?;
|
||||
|
||||
let mut params = self.token_request_base();
|
||||
if let Some(scope) = self.scope_value() {
|
||||
params.push(("scope".to_string(), scope));
|
||||
}
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(device_url)
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| map_http_error("start device authorization", err))?;
|
||||
|
||||
let status = response.status();
|
||||
let payload = response
|
||||
.json::<DeviceAuthorizationResponse>()
|
||||
.await
|
||||
.map_err(|err| {
|
||||
Error::Auth(format!(
|
||||
"Failed to parse device authorization response (status {status}): {err}"
|
||||
))
|
||||
})?;
|
||||
|
||||
let expires_at =
|
||||
Utc::now() + Duration::seconds(payload.expires_in.min(i64::MAX as u64) as i64);
|
||||
let interval = StdDuration::from_secs(payload.interval.unwrap_or(5).max(1));
|
||||
|
||||
Ok(DeviceAuthorization {
|
||||
device_code: payload.device_code,
|
||||
user_code: payload.user_code,
|
||||
verification_uri: payload.verification_uri,
|
||||
verification_uri_complete: payload.verification_uri_complete,
|
||||
expires_at,
|
||||
interval,
|
||||
message: payload.message,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn poll_device_token(&self, auth: &DeviceAuthorization) -> Result<DevicePollState> {
|
||||
let mut params = self.token_request_base();
|
||||
params.push(("grant_type".to_string(), DEVICE_CODE_GRANT.to_string()));
|
||||
params.push(("device_code".to_string(), auth.device_code.clone()));
|
||||
if let Some(scope) = self.scope_value() {
|
||||
params.push(("scope".to_string(), scope));
|
||||
}
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&self.config.token_url)
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| map_http_error("poll device token", err))?;
|
||||
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(|err| map_http_error("read token response", err))?;
|
||||
|
||||
if status.is_success() {
|
||||
let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| {
|
||||
Error::Auth(format!(
|
||||
"Failed to parse OAuth token response: {err}; body: {text}"
|
||||
))
|
||||
})?;
|
||||
return Ok(DevicePollState::Complete(oauth_token_from_response(
|
||||
payload,
|
||||
)));
|
||||
}
|
||||
|
||||
let error = serde_json::from_str::<OAuthErrorResponse>(&text).unwrap_or_else(|_| {
|
||||
OAuthErrorResponse {
|
||||
error: "unknown_error".to_string(),
|
||||
error_description: Some(text.clone()),
|
||||
}
|
||||
});
|
||||
|
||||
match error.error.as_str() {
|
||||
"authorization_pending" => Ok(DevicePollState::Pending {
|
||||
retry_in: auth.interval,
|
||||
}),
|
||||
"slow_down" => Ok(DevicePollState::Pending {
|
||||
retry_in: auth.interval.saturating_add(StdDuration::from_secs(5)),
|
||||
}),
|
||||
"access_denied" => {
|
||||
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
|
||||
"User declined authorization".to_string()
|
||||
})))
|
||||
}
|
||||
"expired_token" | "expired_device_code" => {
|
||||
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
|
||||
"Device authorization expired".to_string()
|
||||
})))
|
||||
}
|
||||
other => Err(Error::Auth(
|
||||
error
|
||||
.error_description
|
||||
.unwrap_or_else(|| format!("OAuth error: {other}")),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn refresh_token(&self, refresh_token: &str) -> Result<OAuthToken> {
|
||||
let mut params = self.token_request_base();
|
||||
params.push(("grant_type".to_string(), "refresh_token".to_string()));
|
||||
params.push(("refresh_token".to_string(), refresh_token.to_string()));
|
||||
if let Some(scope) = self.scope_value() {
|
||||
params.push(("scope".to_string(), scope));
|
||||
}
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&self.config.token_url)
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| map_http_error("refresh OAuth token", err))?;
|
||||
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(|err| map_http_error("read refresh response", err))?;
|
||||
|
||||
if status.is_success() {
|
||||
let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| {
|
||||
Error::Auth(format!(
|
||||
"Failed to parse OAuth refresh response: {err}; body: {text}"
|
||||
))
|
||||
})?;
|
||||
Ok(oauth_token_from_response(payload))
|
||||
} else {
|
||||
let error = serde_json::from_str::<OAuthErrorResponse>(&text).unwrap_or_else(|_| {
|
||||
OAuthErrorResponse {
|
||||
error: "unknown_error".to_string(),
|
||||
error_description: Some(text.clone()),
|
||||
}
|
||||
});
|
||||
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
|
||||
format!("OAuth token refresh failed: {}", error.error)
|
||||
})))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const DEVICE_CODE_GRANT: &str = "urn:ietf:params:oauth:grant-type:device_code";
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeviceAuthorizationResponse {
|
||||
device_code: String,
|
||||
user_code: String,
|
||||
verification_uri: String,
|
||||
#[serde(default)]
|
||||
verification_uri_complete: Option<String>,
|
||||
expires_in: u64,
|
||||
#[serde(default)]
|
||||
interval: Option<u64>,
|
||||
#[serde(default)]
|
||||
message: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenResponse {
|
||||
access_token: String,
|
||||
#[serde(default)]
|
||||
refresh_token: Option<String>,
|
||||
#[serde(default)]
|
||||
expires_in: Option<u64>,
|
||||
#[serde(default)]
|
||||
scope: Option<String>,
|
||||
#[serde(default)]
|
||||
token_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OAuthErrorResponse {
|
||||
error: String,
|
||||
#[serde(default)]
|
||||
error_description: Option<String>,
|
||||
}
|
||||
|
||||
fn oauth_token_from_response(payload: TokenResponse) -> OAuthToken {
|
||||
let expires_at = payload
|
||||
.expires_in
|
||||
.map(|seconds| seconds.min(i64::MAX as u64) as i64)
|
||||
.map(|seconds| Utc::now() + Duration::seconds(seconds));
|
||||
|
||||
OAuthToken {
|
||||
access_token: payload.access_token,
|
||||
refresh_token: payload.refresh_token,
|
||||
expires_at,
|
||||
scope: payload.scope,
|
||||
token_type: payload.token_type,
|
||||
}
|
||||
}
|
||||
|
||||
fn map_http_error(action: &str, err: reqwest::Error) -> Error {
|
||||
if err.is_timeout() {
|
||||
Error::Timeout(format!("OAuth {action} request timed out: {err}"))
|
||||
} else if err.is_connect() {
|
||||
Error::Network(format!("OAuth {action} connection error: {err}"))
|
||||
} else {
|
||||
Error::Network(format!("OAuth {action} request failed: {err}"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use httpmock::prelude::*;
|
||||
use serde_json::json;
|
||||
|
||||
fn config_for(server: &MockServer) -> McpOAuthConfig {
|
||||
McpOAuthConfig {
|
||||
client_id: "test-client".to_string(),
|
||||
client_secret: None,
|
||||
authorize_url: server.url("/authorize"),
|
||||
token_url: server.url("/token"),
|
||||
device_authorization_url: Some(server.url("/device")),
|
||||
redirect_url: None,
|
||||
scopes: vec!["repo".to_string(), "user".to_string()],
|
||||
token_env: None,
|
||||
header: None,
|
||||
header_prefix: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_device_authorization() -> DeviceAuthorization {
|
||||
DeviceAuthorization {
|
||||
device_code: "device-123".to_string(),
|
||||
user_code: "ABCD-EFGH".to_string(),
|
||||
verification_uri: "https://example.test/activate".to_string(),
|
||||
verification_uri_complete: Some(
|
||||
"https://example.test/activate?user_code=ABCD-EFGH".to_string(),
|
||||
),
|
||||
expires_at: Utc::now() + Duration::minutes(10),
|
||||
interval: StdDuration::from_secs(5),
|
||||
message: Some("Open the verification URL and enter the code.".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_device_authorization_returns_payload() {
|
||||
let server = MockServer::start_async().await;
|
||||
let device_mock = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST).path("/device");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"device_code": "device-123",
|
||||
"user_code": "ABCD-EFGH",
|
||||
"verification_uri": "https://example.test/activate",
|
||||
"verification_uri_complete": "https://example.test/activate?user_code=ABCD-EFGH",
|
||||
"expires_in": 600,
|
||||
"interval": 7,
|
||||
"message": "Open the verification URL and enter the code."
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let client = OAuthClient::new(config_for(&server)).expect("client");
|
||||
let auth = client
|
||||
.start_device_authorization()
|
||||
.await
|
||||
.expect("device authorization payload");
|
||||
|
||||
assert_eq!(auth.user_code, "ABCD-EFGH");
|
||||
assert_eq!(auth.interval, StdDuration::from_secs(7));
|
||||
assert!(auth.expires_at > Utc::now());
|
||||
device_mock.assert_async().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn poll_device_token_reports_pending() {
|
||||
let server = MockServer::start_async().await;
|
||||
let pending = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST)
|
||||
.path("/token")
|
||||
.body_contains(
|
||||
"grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code",
|
||||
)
|
||||
.body_contains("device_code=device-123");
|
||||
then.status(400)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"error": "authorization_pending"
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let config = config_for(&server);
|
||||
let client = OAuthClient::new(config).expect("client");
|
||||
let auth = sample_device_authorization();
|
||||
|
||||
let result = client.poll_device_token(&auth).await.expect("poll result");
|
||||
match result {
|
||||
DevicePollState::Pending { retry_in } => {
|
||||
assert_eq!(retry_in, StdDuration::from_secs(5));
|
||||
}
|
||||
other => panic!("expected pending state, got {other:?}"),
|
||||
}
|
||||
|
||||
pending.assert_async().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn poll_device_token_applies_slow_down_backoff() {
|
||||
let server = MockServer::start_async().await;
|
||||
let slow = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST).path("/token");
|
||||
then.status(400)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"error": "slow_down"
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let config = config_for(&server);
|
||||
let client = OAuthClient::new(config).expect("client");
|
||||
let auth = sample_device_authorization();
|
||||
|
||||
let result = client.poll_device_token(&auth).await.expect("poll result");
|
||||
match result {
|
||||
DevicePollState::Pending { retry_in } => {
|
||||
assert_eq!(retry_in, StdDuration::from_secs(10));
|
||||
}
|
||||
other => panic!("expected pending state, got {other:?}"),
|
||||
}
|
||||
|
||||
slow.assert_async().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn poll_device_token_returns_token_when_authorized() {
|
||||
let server = MockServer::start_async().await;
|
||||
let token = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST).path("/token");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"access_token": "token-abc",
|
||||
"refresh_token": "refresh-xyz",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer",
|
||||
"scope": "repo user"
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let config = config_for(&server);
|
||||
let client = OAuthClient::new(config).expect("client");
|
||||
let auth = sample_device_authorization();
|
||||
|
||||
let result = client.poll_device_token(&auth).await.expect("poll result");
|
||||
let token_info = match result {
|
||||
DevicePollState::Complete(token) => token,
|
||||
other => panic!("expected completion, got {other:?}"),
|
||||
};
|
||||
|
||||
assert_eq!(token_info.access_token, "token-abc");
|
||||
assert_eq!(token_info.refresh_token.as_deref(), Some("refresh-xyz"));
|
||||
assert!(token_info.expires_at.is_some());
|
||||
token.assert_async().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn refresh_token_roundtrip() {
|
||||
let server = MockServer::start_async().await;
|
||||
let refresh = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST)
|
||||
.path("/token")
|
||||
.body_contains("grant_type=refresh_token")
|
||||
.body_contains("refresh_token=old-refresh");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"access_token": "token-new",
|
||||
"refresh_token": "refresh-new",
|
||||
"expires_in": 1200,
|
||||
"token_type": "Bearer"
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let config = config_for(&server);
|
||||
let client = OAuthClient::new(config).expect("client");
|
||||
let token = client
|
||||
.refresh_token("old-refresh")
|
||||
.await
|
||||
.expect("refresh response");
|
||||
|
||||
assert_eq!(token.access_token, "token-new");
|
||||
assert_eq!(token.refresh_token.as_deref(), Some("refresh-new"));
|
||||
assert!(token.expires_at.is_some());
|
||||
refresh.assert_async().await;
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::config::Config;
|
||||
use crate::config::{Config, McpResourceConfig, McpServerConfig};
|
||||
use crate::consent::ConsentManager;
|
||||
use crate::conversation::ConversationManager;
|
||||
use crate::credentials::CredentialManager;
|
||||
@@ -9,8 +9,10 @@ use crate::mcp::McpToolCall;
|
||||
use crate::mcp::client::McpClient;
|
||||
use crate::mcp::factory::McpClientFactory;
|
||||
use crate::mcp::permission::PermissionLayer;
|
||||
use crate::mcp::remote_client::{McpRuntimeSecrets, RemoteMcpClient};
|
||||
use crate::mode::Mode;
|
||||
use crate::model::{DetailedModelInfo, ModelManager};
|
||||
use crate::oauth::{DeviceAuthorization, DevicePollState, OAuthClient};
|
||||
use crate::providers::OllamaProvider;
|
||||
use crate::storage::{SessionMeta, StorageManager};
|
||||
use crate::types::{
|
||||
@@ -24,8 +26,10 @@ use crate::{
|
||||
ToolRegistry, WebScrapeTool, WebSearchDetailedTool, WebSearchTool,
|
||||
};
|
||||
use crate::{Error, Result};
|
||||
use chrono::Utc;
|
||||
use log::warn;
|
||||
use serde_json::Value;
|
||||
use serde_json::{Value, json};
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex};
|
||||
@@ -96,6 +100,7 @@ pub struct SessionController {
|
||||
tool_registry: Arc<ToolRegistry>,
|
||||
schema_validator: Arc<SchemaValidator>,
|
||||
mcp_client: Arc<dyn McpClient>,
|
||||
named_mcp_clients: HashMap<String, Arc<dyn McpClient>>,
|
||||
storage: Arc<StorageManager>,
|
||||
vault: Option<Arc<Mutex<VaultHandle>>>,
|
||||
master_key: Option<Arc<Vec<u8>>>,
|
||||
@@ -103,6 +108,7 @@ pub struct SessionController {
|
||||
ui: Arc<dyn UiController>,
|
||||
enable_code_tools: bool,
|
||||
current_mode: Mode,
|
||||
missing_oauth_servers: Vec<String>,
|
||||
}
|
||||
|
||||
async fn build_tools(
|
||||
@@ -211,6 +217,112 @@ async fn build_tools(
|
||||
}
|
||||
|
||||
impl SessionController {
|
||||
async fn create_mcp_clients(
|
||||
config: Arc<TokioMutex<Config>>,
|
||||
tool_registry: Arc<ToolRegistry>,
|
||||
schema_validator: Arc<SchemaValidator>,
|
||||
credential_manager: Option<Arc<CredentialManager>>,
|
||||
initial_mode: Mode,
|
||||
) -> Result<(
|
||||
Arc<dyn McpClient>,
|
||||
HashMap<String, Arc<dyn McpClient>>,
|
||||
Vec<String>,
|
||||
)> {
|
||||
let guard = config.lock().await;
|
||||
let config_arc = Arc::new(guard.clone());
|
||||
let factory = McpClientFactory::new(config_arc.clone(), tool_registry, schema_validator);
|
||||
|
||||
let mut missing_oauth_servers = Vec::new();
|
||||
let primary_runtime = if let Some(primary_cfg) = guard.effective_mcp_servers().first() {
|
||||
let (runtime, missing) =
|
||||
Self::runtime_secrets_for_server(credential_manager.clone(), primary_cfg).await?;
|
||||
if missing {
|
||||
missing_oauth_servers.push(primary_cfg.name.clone());
|
||||
}
|
||||
runtime
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let base_client = factory.create_with_secrets(primary_runtime)?;
|
||||
let primary: Arc<dyn McpClient> =
|
||||
Arc::new(PermissionLayer::new(base_client, config_arc.clone()));
|
||||
primary.set_mode(initial_mode).await?;
|
||||
|
||||
let mut clients: HashMap<String, Arc<dyn McpClient>> = HashMap::new();
|
||||
if let Some(primary_cfg) = guard.effective_mcp_servers().first() {
|
||||
clients.insert(primary_cfg.name.clone(), Arc::clone(&primary));
|
||||
}
|
||||
|
||||
for server_cfg in guard.effective_mcp_servers().iter().skip(1) {
|
||||
let (runtime, missing) =
|
||||
Self::runtime_secrets_for_server(credential_manager.clone(), server_cfg).await?;
|
||||
if missing {
|
||||
missing_oauth_servers.push(server_cfg.name.clone());
|
||||
}
|
||||
|
||||
match RemoteMcpClient::new_with_runtime(server_cfg, runtime) {
|
||||
Ok(remote) => {
|
||||
let client: Arc<dyn McpClient> =
|
||||
Arc::new(PermissionLayer::new(Box::new(remote), config_arc.clone()));
|
||||
if let Err(err) = client.set_mode(initial_mode).await {
|
||||
warn!(
|
||||
"Failed to initialize MCP server '{}' in mode {:?}: {}",
|
||||
server_cfg.name, initial_mode, err
|
||||
);
|
||||
}
|
||||
clients.insert(server_cfg.name.clone(), Arc::clone(&client));
|
||||
}
|
||||
Err(err) => warn!(
|
||||
"Failed to initialize MCP server '{}': {}",
|
||||
server_cfg.name, err
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
drop(guard);
|
||||
|
||||
Ok((primary, clients, missing_oauth_servers))
|
||||
}
|
||||
|
||||
async fn runtime_secrets_for_server(
|
||||
credential_manager: Option<Arc<CredentialManager>>,
|
||||
server: &McpServerConfig,
|
||||
) -> Result<(Option<McpRuntimeSecrets>, bool)> {
|
||||
if let Some(oauth) = &server.oauth {
|
||||
if let Some(manager) = credential_manager {
|
||||
match manager.load_oauth_token(&server.name).await? {
|
||||
Some(token) => {
|
||||
if token.access_token.trim().is_empty() || token.is_expired(Utc::now()) {
|
||||
return Ok((None, true));
|
||||
}
|
||||
let mut secrets = McpRuntimeSecrets::default();
|
||||
if let Some(env_name) = oauth.token_env.as_deref() {
|
||||
secrets
|
||||
.env_overrides
|
||||
.insert(env_name.to_string(), token.access_token.clone());
|
||||
}
|
||||
if matches!(
|
||||
server.transport.to_ascii_lowercase().as_str(),
|
||||
"http" | "websocket"
|
||||
) {
|
||||
let header_value =
|
||||
format!("{}{}", oauth.header_prefix(), token.access_token);
|
||||
secrets.http_header =
|
||||
Some((oauth.header_name().to_string(), header_value));
|
||||
}
|
||||
Ok((Some(secrets), false))
|
||||
}
|
||||
None => Ok((None, true)),
|
||||
}
|
||||
} else {
|
||||
Ok((None, true))
|
||||
}
|
||||
} else {
|
||||
Ok((None, false))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn new(
|
||||
provider: Arc<dyn Provider>,
|
||||
config: Config,
|
||||
@@ -292,19 +404,14 @@ impl SessionController {
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Create MCP client with permission layer
|
||||
let mcp_client: Arc<dyn McpClient> = {
|
||||
let guard = config_arc.lock().await;
|
||||
let factory = McpClientFactory::new(
|
||||
Arc::new(guard.clone()),
|
||||
tool_registry.clone(),
|
||||
schema_validator.clone(),
|
||||
);
|
||||
let base_client = factory.create()?;
|
||||
let client = Arc::new(PermissionLayer::new(base_client, Arc::new(guard.clone())));
|
||||
client.set_mode(initial_mode).await?;
|
||||
client
|
||||
};
|
||||
let (mcp_client, named_mcp_clients, missing_oauth_servers) = Self::create_mcp_clients(
|
||||
config_arc.clone(),
|
||||
tool_registry.clone(),
|
||||
schema_validator.clone(),
|
||||
credential_manager.clone(),
|
||||
initial_mode,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(Self {
|
||||
provider,
|
||||
@@ -317,6 +424,7 @@ impl SessionController {
|
||||
tool_registry,
|
||||
schema_validator,
|
||||
mcp_client,
|
||||
named_mcp_clients,
|
||||
storage,
|
||||
vault: vault_handle,
|
||||
master_key,
|
||||
@@ -324,6 +432,7 @@ impl SessionController {
|
||||
ui,
|
||||
enable_code_tools,
|
||||
current_mode: initial_mode,
|
||||
missing_oauth_servers,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -355,6 +464,63 @@ impl SessionController {
|
||||
self.formatter.set_role_label_mode(mode);
|
||||
}
|
||||
|
||||
/// Return the configured resource references aggregated across scopes.
|
||||
pub async fn configured_resources(&self) -> Vec<McpResourceConfig> {
|
||||
let guard = self.config.lock().await;
|
||||
guard.effective_mcp_resources().to_vec()
|
||||
}
|
||||
|
||||
/// Resolve a resource reference of the form `server:uri` (optionally prefixed with `@`).
|
||||
pub async fn resolve_resource_reference(&self, reference: &str) -> Result<Option<String>> {
|
||||
let (server, uri) = match Self::split_resource_reference(reference) {
|
||||
Some(parts) => parts,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let resource_defined = {
|
||||
let guard = self.config.lock().await;
|
||||
guard.find_resource(&server, &uri).is_some()
|
||||
};
|
||||
|
||||
if !resource_defined {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let client = self
|
||||
.named_mcp_clients
|
||||
.get(&server)
|
||||
.cloned()
|
||||
.ok_or_else(|| {
|
||||
Error::Config(format!(
|
||||
"MCP server '{}' referenced by resource '{}' is not available",
|
||||
server, uri
|
||||
))
|
||||
})?;
|
||||
|
||||
let call = McpToolCall {
|
||||
name: "resources/get".to_string(),
|
||||
arguments: json!({ "uri": uri, "path": uri }),
|
||||
};
|
||||
let response = client.call_tool(call).await?;
|
||||
if let Some(text) = extract_resource_content(&response.output) {
|
||||
return Ok(Some(text));
|
||||
}
|
||||
|
||||
let formatted = serde_json::to_string_pretty(&response.output)
|
||||
.unwrap_or_else(|_| response.output.to_string());
|
||||
Ok(Some(formatted))
|
||||
}
|
||||
|
||||
fn split_resource_reference(reference: &str) -> Option<(String, String)> {
|
||||
let trimmed = reference.trim();
|
||||
let without_prefix = trimmed.strip_prefix('@').unwrap_or(trimmed);
|
||||
let (server, uri) = without_prefix.split_once(':')?;
|
||||
if server.is_empty() || uri.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some((server.to_string(), uri.to_string()))
|
||||
}
|
||||
|
||||
// Asynchronous access to the configuration (used internally).
|
||||
pub async fn config_async(&self) -> tokio::sync::MutexGuard<'_, Config> {
|
||||
self.config.lock().await
|
||||
@@ -378,6 +544,21 @@ impl SessionController {
|
||||
self.config.clone()
|
||||
}
|
||||
|
||||
pub async fn reload_mcp_clients(&mut self) -> Result<()> {
|
||||
let (primary, named, missing) = Self::create_mcp_clients(
|
||||
self.config.clone(),
|
||||
self.tool_registry.clone(),
|
||||
self.schema_validator.clone(),
|
||||
self.credential_manager.clone(),
|
||||
self.current_mode,
|
||||
)
|
||||
.await?;
|
||||
self.mcp_client = primary;
|
||||
self.named_mcp_clients = named;
|
||||
self.missing_oauth_servers = missing;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn grant_consent(&self, tool_name: &str, data_types: Vec<String>, endpoints: Vec<String>) {
|
||||
let mut consent = self
|
||||
.consent_manager
|
||||
@@ -525,6 +706,115 @@ impl SessionController {
|
||||
self.schema_validator.clone()
|
||||
}
|
||||
|
||||
pub fn credential_manager(&self) -> Option<Arc<CredentialManager>> {
|
||||
self.credential_manager.clone()
|
||||
}
|
||||
|
||||
pub fn pending_oauth_servers(&self) -> Vec<String> {
|
||||
self.missing_oauth_servers.clone()
|
||||
}
|
||||
|
||||
pub async fn start_oauth_device_flow(&self, server: &str) -> Result<DeviceAuthorization> {
|
||||
let oauth_config = {
|
||||
let config = self.config.lock().await;
|
||||
let server_cfg = config
|
||||
.effective_mcp_servers()
|
||||
.iter()
|
||||
.find(|entry| entry.name == server)
|
||||
.ok_or_else(|| {
|
||||
Error::Config(format!("No MCP server named '{server}' is configured"))
|
||||
})?;
|
||||
server_cfg.oauth.clone().ok_or_else(|| {
|
||||
Error::Config(format!(
|
||||
"MCP server '{server}' does not define an OAuth configuration"
|
||||
))
|
||||
})?
|
||||
};
|
||||
|
||||
let client = OAuthClient::new(oauth_config)?;
|
||||
client.start_device_authorization().await
|
||||
}
|
||||
|
||||
pub async fn poll_oauth_device_flow(
|
||||
&mut self,
|
||||
server: &str,
|
||||
authorization: &DeviceAuthorization,
|
||||
) -> Result<DevicePollState> {
|
||||
let oauth_config = {
|
||||
let config = self.config.lock().await;
|
||||
let server_cfg = config
|
||||
.effective_mcp_servers()
|
||||
.iter()
|
||||
.find(|entry| entry.name == server)
|
||||
.ok_or_else(|| {
|
||||
Error::Config(format!("No MCP server named '{server}' is configured"))
|
||||
})?;
|
||||
server_cfg.oauth.clone().ok_or_else(|| {
|
||||
Error::Config(format!(
|
||||
"MCP server '{server}' does not define an OAuth configuration"
|
||||
))
|
||||
})?
|
||||
};
|
||||
|
||||
let client = OAuthClient::new(oauth_config)?;
|
||||
match client.poll_device_token(authorization).await? {
|
||||
DevicePollState::Pending { retry_in } => Ok(DevicePollState::Pending { retry_in }),
|
||||
DevicePollState::Complete(token) => {
|
||||
let manager = self.credential_manager.as_ref().cloned().ok_or_else(|| {
|
||||
Error::Config(
|
||||
"OAuth token storage requires encrypted local data; set \
|
||||
privacy.encrypt_local_data = true in the configuration."
|
||||
.to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
manager.store_oauth_token(server, &token).await?;
|
||||
self.missing_oauth_servers.retain(|entry| entry != server);
|
||||
|
||||
Ok(DevicePollState::Complete(token))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_mcp_tools(&self) -> Vec<(String, crate::mcp::McpToolDescriptor)> {
|
||||
let mut entries = Vec::new();
|
||||
for (server, client) in self.named_mcp_clients.iter() {
|
||||
let server_name = server.clone();
|
||||
let client = Arc::clone(client);
|
||||
match client.list_tools().await {
|
||||
Ok(tools) => {
|
||||
for descriptor in tools {
|
||||
entries.push((server_name.clone(), descriptor));
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"Failed to list tools for MCP server '{}': {}",
|
||||
server_name, err
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
entries
|
||||
}
|
||||
|
||||
pub async fn call_mcp_tool(
|
||||
&self,
|
||||
server: &str,
|
||||
tool: &str,
|
||||
arguments: Value,
|
||||
) -> Result<crate::mcp::McpToolResponse> {
|
||||
let client = self.named_mcp_clients.get(server).cloned().ok_or_else(|| {
|
||||
Error::Config(format!("No MCP server named '{}' is registered", server))
|
||||
})?;
|
||||
client
|
||||
.call_tool(McpToolCall {
|
||||
name: tool.to_string(),
|
||||
arguments,
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub fn mcp_server(&self) -> crate::mcp::McpServer {
|
||||
crate::mcp::McpServer::new(self.tool_registry(), self.schema_validator())
|
||||
}
|
||||
@@ -985,3 +1275,195 @@ impl SessionController {
|
||||
Ok("Empty conversation".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::Provider;
|
||||
use crate::config::{Config, McpMode, McpOAuthConfig, McpServerConfig};
|
||||
use crate::llm::test_utils::MockProvider;
|
||||
use crate::storage::StorageManager;
|
||||
use crate::ui::NoOpUiController;
|
||||
use chrono::Utc;
|
||||
use httpmock::prelude::*;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tempfile::tempdir;
|
||||
|
||||
const SERVER_NAME: &str = "oauth-test";
|
||||
|
||||
fn build_oauth_config(server: &MockServer) -> McpOAuthConfig {
|
||||
McpOAuthConfig {
|
||||
client_id: "owlen-client".to_string(),
|
||||
client_secret: None,
|
||||
authorize_url: server.url("/authorize"),
|
||||
token_url: server.url("/token"),
|
||||
device_authorization_url: Some(server.url("/device")),
|
||||
redirect_url: None,
|
||||
scopes: vec!["repo".to_string()],
|
||||
token_env: Some("OAUTH_TOKEN".to_string()),
|
||||
header: Some("Authorization".to_string()),
|
||||
header_prefix: Some("Bearer ".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_config(server: &MockServer) -> Config {
|
||||
let mut config = Config::default();
|
||||
config.mcp.mode = McpMode::LocalOnly;
|
||||
let oauth = build_oauth_config(server);
|
||||
|
||||
let mut env = HashMap::new();
|
||||
env.insert("OWLEN_ENV".to_string(), "test".to_string());
|
||||
|
||||
config.mcp_servers = vec![McpServerConfig {
|
||||
name: SERVER_NAME.to_string(),
|
||||
command: server.url("/mcp"),
|
||||
args: Vec::new(),
|
||||
transport: "http".to_string(),
|
||||
env,
|
||||
oauth: Some(oauth),
|
||||
}];
|
||||
|
||||
config.refresh_mcp_servers(None).unwrap();
|
||||
config
|
||||
}
|
||||
|
||||
async fn build_session(server: &MockServer) -> (SessionController, tempfile::TempDir) {
|
||||
unsafe {
|
||||
std::env::set_var("OWLEN_MASTER_PASSWORD", "test-password");
|
||||
}
|
||||
|
||||
let temp_dir = tempdir().expect("tempdir");
|
||||
let storage_path = temp_dir.path().join("owlen.db");
|
||||
let storage = Arc::new(
|
||||
StorageManager::with_database_path(storage_path)
|
||||
.await
|
||||
.expect("storage"),
|
||||
);
|
||||
|
||||
let config = build_config(server);
|
||||
let provider: Arc<dyn Provider> = Arc::new(MockProvider::default()) as Arc<dyn Provider>;
|
||||
let ui = Arc::new(NoOpUiController);
|
||||
|
||||
let session = SessionController::new(provider, config, storage, ui, false)
|
||||
.await
|
||||
.expect("session");
|
||||
|
||||
(session, temp_dir)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_oauth_device_flow_returns_details() {
|
||||
let server = MockServer::start_async().await;
|
||||
let device = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST).path("/device");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"device_code": "device-abc",
|
||||
"user_code": "ABCD-1234",
|
||||
"verification_uri": "https://example.test/activate",
|
||||
"verification_uri_complete": "https://example.test/activate?user_code=ABCD-1234",
|
||||
"expires_in": 600,
|
||||
"interval": 5,
|
||||
"message": "Enter the code to continue."
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let (session, _dir) = build_session(&server).await;
|
||||
let authorization = session
|
||||
.start_oauth_device_flow(SERVER_NAME)
|
||||
.await
|
||||
.expect("device flow");
|
||||
|
||||
assert_eq!(authorization.user_code, "ABCD-1234");
|
||||
assert_eq!(
|
||||
authorization.verification_uri_complete.as_deref(),
|
||||
Some("https://example.test/activate?user_code=ABCD-1234")
|
||||
);
|
||||
assert!(authorization.expires_at > Utc::now());
|
||||
device.assert_async().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn poll_oauth_device_flow_stores_token_and_updates_state() {
|
||||
let server = MockServer::start_async().await;
|
||||
|
||||
let device = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST).path("/device");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"device_code": "device-xyz",
|
||||
"user_code": "WXYZ-9999",
|
||||
"verification_uri": "https://example.test/activate",
|
||||
"verification_uri_complete": "https://example.test/activate?user_code=WXYZ-9999",
|
||||
"expires_in": 600,
|
||||
"interval": 5
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let token = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST)
|
||||
.path("/token")
|
||||
.body_contains("device_code=device-xyz");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"access_token": "new-access-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer"
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let (mut session, _dir) = build_session(&server).await;
|
||||
assert_eq!(session.pending_oauth_servers(), vec![SERVER_NAME]);
|
||||
|
||||
let authorization = session
|
||||
.start_oauth_device_flow(SERVER_NAME)
|
||||
.await
|
||||
.expect("device flow");
|
||||
|
||||
match session
|
||||
.poll_oauth_device_flow(SERVER_NAME, &authorization)
|
||||
.await
|
||||
.expect("token poll")
|
||||
{
|
||||
DevicePollState::Complete(token_info) => {
|
||||
assert_eq!(token_info.access_token, "new-access-token");
|
||||
assert_eq!(token_info.refresh_token.as_deref(), Some("refresh-token"));
|
||||
}
|
||||
other => panic!("expected token completion, got {other:?}"),
|
||||
}
|
||||
|
||||
assert!(
|
||||
session
|
||||
.pending_oauth_servers()
|
||||
.iter()
|
||||
.all(|entry| entry != SERVER_NAME),
|
||||
"server should be removed from pending list"
|
||||
);
|
||||
|
||||
let stored = session
|
||||
.credential_manager()
|
||||
.expect("credential manager")
|
||||
.load_oauth_token(SERVER_NAME)
|
||||
.await
|
||||
.expect("load token")
|
||||
.expect("token present");
|
||||
|
||||
assert_eq!(stored.access_token, "new-access-token");
|
||||
assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
|
||||
|
||||
device.assert_async().await;
|
||||
token.assert_async().await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,6 +44,7 @@ async fn test_render_prompt_via_external_server() -> Result<()> {
|
||||
args: Vec::new(),
|
||||
transport: "stdio".into(),
|
||||
env: std::collections::HashMap::new(),
|
||||
oauth: None,
|
||||
};
|
||||
|
||||
let client = match RemoteMcpClient::new_with_config(&config) {
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
//! crates can depend only on `owlen-mcp-client` without pulling in the entire
|
||||
//! core crate internals.
|
||||
|
||||
pub use owlen_core::config::{McpConfigScope, ScopedMcpServer};
|
||||
pub use owlen_core::mcp::remote_client::RemoteMcpClient;
|
||||
pub use owlen_core::mcp::{McpClient, McpToolCall, McpToolDescriptor, McpToolResponse};
|
||||
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use chrono::{DateTime, Local};
|
||||
use chrono::{DateTime, Local, Utc};
|
||||
use crossterm::terminal::{disable_raw_mode, enable_raw_mode};
|
||||
use owlen_core::mcp::remote_client::RemoteMcpClient;
|
||||
use owlen_core::mcp::{McpToolDescriptor, McpToolResponse};
|
||||
use owlen_core::{
|
||||
Provider, ProviderConfig,
|
||||
config::McpResourceConfig,
|
||||
model::DetailedModelInfo,
|
||||
oauth::{DeviceAuthorization, DevicePollState},
|
||||
session::{SessionController, SessionOutcome},
|
||||
storage::SessionMeta,
|
||||
theme::Theme,
|
||||
@@ -19,7 +22,7 @@ use tokio::{
|
||||
sync::mpsc,
|
||||
task::{self, JoinHandle},
|
||||
};
|
||||
use tui_textarea::{Input, TextArea};
|
||||
use tui_textarea::{CursorMove, Input, TextArea};
|
||||
use unicode_width::UnicodeWidthStr;
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -27,12 +30,14 @@ use crate::commands;
|
||||
use crate::config;
|
||||
use crate::events::Event;
|
||||
use crate::model_info_panel::ModelInfoPanel;
|
||||
use crate::slash::{self, McpSlashCommand, SlashCommand};
|
||||
use crate::state::{
|
||||
CodeWorkspace, CommandPalette, FileFilterMode, FileIconResolver, FileNode, FileTreeState,
|
||||
ModelPaletteEntry, PaletteSuggestion, PaneDirection, PaneRestoreRequest, RepoSearchMessage,
|
||||
RepoSearchState, SplitAxis, SymbolSearchMessage, SymbolSearchState, WorkspaceSnapshot,
|
||||
spawn_repo_search_task, spawn_symbol_search_task,
|
||||
};
|
||||
use crate::toast::{Toast, ToastLevel, ToastManager};
|
||||
use crate::ui::format_tool_output;
|
||||
// Agent executor moved to separate binary `owlen-agent`. The TUI no longer directly
|
||||
// imports `AgentExecutor` to avoid a circular dependency on `owlen-cli`.
|
||||
@@ -48,6 +53,7 @@ use std::sync::Arc;
|
||||
use std::time::{Duration, Instant, SystemTime};
|
||||
|
||||
use dirs::{config_dir, data_local_dir};
|
||||
use serde_json::{Value, json};
|
||||
|
||||
const ONBOARDING_STATUS_LINE: &str =
|
||||
"Welcome to Owlen! Press F1 for help or type :tutorial for keybinding tips.";
|
||||
@@ -61,6 +67,13 @@ const RESIZE_DOUBLE_TAP_WINDOW: Duration = Duration::from_millis(450);
|
||||
const RESIZE_STEP: f32 = 0.05;
|
||||
const RESIZE_SNAP_VALUES: [f32; 3] = [0.5, 0.75, 0.25];
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum SlashOutcome {
|
||||
NotCommand,
|
||||
Consumed,
|
||||
Error,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub(crate) struct ModelSelectorItem {
|
||||
kind: ModelSelectorItemKind,
|
||||
@@ -158,6 +171,11 @@ pub enum SessionEvent {
|
||||
AgentCompleted { answer: String },
|
||||
/// Agent execution failed
|
||||
AgentFailed { error: String },
|
||||
/// Poll the OAuth device authorization flow for the given server
|
||||
OAuthPoll {
|
||||
server: String,
|
||||
authorization: DeviceAuthorization,
|
||||
},
|
||||
}
|
||||
|
||||
pub const HELP_TAB_COUNT: usize = 7;
|
||||
@@ -205,6 +223,9 @@ pub struct ChatApp {
|
||||
clipboard: String, // Vim-style clipboard for yank/paste
|
||||
pending_file_action: Option<FileActionPrompt>, // Active file action prompt
|
||||
command_palette: CommandPalette, // Command mode state (buffer + suggestions)
|
||||
resource_catalog: Vec<McpResourceConfig>, // Configured MCP resources for autocompletion
|
||||
pending_resource_refs: Vec<String>, // Resource references to resolve before send
|
||||
oauth_flows: HashMap<String, DeviceAuthorization>, // Active OAuth device flows by server
|
||||
repo_search: RepoSearchState, // Repository search overlay state
|
||||
repo_search_task: Option<JoinHandle<()>>,
|
||||
repo_search_rx: Option<mpsc::UnboundedReceiver<RepoSearchMessage>>,
|
||||
@@ -235,6 +256,7 @@ pub struct ChatApp {
|
||||
selected_theme_index: usize, // Index of selected theme in browser
|
||||
pending_consent: Option<ConsentDialogState>, // Pending consent request
|
||||
system_status: String, // System/status messages (tool execution, status, etc)
|
||||
toasts: ToastManager,
|
||||
/// Simple execution budget: maximum number of tool calls allowed per session.
|
||||
_execution_budget: usize,
|
||||
/// Agent mode enabled
|
||||
@@ -438,6 +460,9 @@ impl ChatApp {
|
||||
clipboard: String::new(),
|
||||
pending_file_action: None,
|
||||
command_palette: CommandPalette::new(),
|
||||
resource_catalog: Vec::new(),
|
||||
pending_resource_refs: Vec::new(),
|
||||
oauth_flows: HashMap::new(),
|
||||
repo_search: RepoSearchState::new(),
|
||||
repo_search_task: None,
|
||||
repo_search_rx: None,
|
||||
@@ -472,6 +497,7 @@ impl ChatApp {
|
||||
} else {
|
||||
String::new()
|
||||
},
|
||||
toasts: ToastManager::new(),
|
||||
_execution_budget: 50,
|
||||
agent_mode: false,
|
||||
agent_running: false,
|
||||
@@ -490,6 +516,8 @@ impl ChatApp {
|
||||
));
|
||||
|
||||
app.update_command_palette_catalog();
|
||||
app.refresh_resource_catalog().await?;
|
||||
app.refresh_mcp_slash_commands().await?;
|
||||
|
||||
if let Err(err) = app.restore_workspace_layout().await {
|
||||
eprintln!("Warning: failed to restore workspace layout: {err}");
|
||||
@@ -1371,6 +1399,18 @@ impl ChatApp {
|
||||
&self.theme
|
||||
}
|
||||
|
||||
pub fn toasts(&self) -> impl Iterator<Item = &Toast> {
|
||||
self.toasts.iter()
|
||||
}
|
||||
|
||||
pub fn push_toast(&mut self, level: ToastLevel, message: impl Into<String>) {
|
||||
self.toasts.push(message, level);
|
||||
}
|
||||
|
||||
fn prune_toasts(&mut self) {
|
||||
self.toasts.retain_active();
|
||||
}
|
||||
|
||||
pub fn input_max_rows(&self) -> u16 {
|
||||
let config = self.controller.config();
|
||||
config.ui.input_max_rows.max(1)
|
||||
@@ -1443,6 +1483,304 @@ impl ChatApp {
|
||||
.update_dynamic_sources(models, providers);
|
||||
}
|
||||
|
||||
async fn refresh_resource_catalog(&mut self) -> Result<()> {
|
||||
let mut resources = self.controller.configured_resources().await;
|
||||
resources.sort_by(|a, b| a.server.cmp(&b.server).then(a.uri.cmp(&b.uri)));
|
||||
self.resource_catalog = resources;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn refresh_mcp_slash_commands(&mut self) -> Result<()> {
|
||||
let mut commands = Vec::new();
|
||||
for (server, descriptor) in self.controller.list_mcp_tools().await {
|
||||
if !Self::tool_supports_slash(&descriptor) {
|
||||
continue;
|
||||
}
|
||||
let description = if descriptor.description.trim().is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(descriptor.description.clone())
|
||||
};
|
||||
commands.push(McpSlashCommand::new(
|
||||
server,
|
||||
descriptor.name.clone(),
|
||||
description,
|
||||
));
|
||||
}
|
||||
slash::set_mcp_commands(commands);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn tool_supports_slash(descriptor: &McpToolDescriptor) -> bool {
|
||||
if descriptor.name.trim().is_empty() {
|
||||
return false;
|
||||
}
|
||||
Self::tool_allows_empty_arguments(&descriptor.input_schema)
|
||||
}
|
||||
|
||||
fn tool_allows_empty_arguments(schema: &Value) -> bool {
|
||||
match schema {
|
||||
Value::Object(map) => {
|
||||
if let Some(Value::Array(required)) = map.get("required") {
|
||||
!required
|
||||
.iter()
|
||||
.any(|entry| entry.as_str().is_some_and(|s| !s.is_empty()))
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn format_mcp_slash_message(server: &str, tool: &str, response: &McpToolResponse) -> String {
|
||||
let status = if response.success { "✓" } else { "✗" };
|
||||
let payload = if response.success {
|
||||
Self::extract_mcp_primary_text(&response.output)
|
||||
} else {
|
||||
Self::extract_mcp_error(&response.output)
|
||||
.or_else(|| Self::extract_mcp_primary_text(&response.output))
|
||||
}
|
||||
.unwrap_or_else(|| Self::pretty_print_value(&response.output));
|
||||
|
||||
if payload.trim().is_empty() {
|
||||
return format!("MCP {server}::{tool} {status}");
|
||||
}
|
||||
|
||||
if payload.contains('\n') {
|
||||
format!("MCP {server}::{tool} {status}\n```json\n{payload}\n```")
|
||||
} else {
|
||||
format!("MCP {server}::{tool} {status}\n{payload}")
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_mcp_primary_text(value: &Value) -> Option<String> {
|
||||
if let Some(text) = value.as_str().filter(|text| !text.trim().is_empty()) {
|
||||
return Some(text.to_string());
|
||||
}
|
||||
|
||||
if let Value::Object(map) = value {
|
||||
const CANDIDATES: [&str; 6] =
|
||||
["rendered", "text", "content", "value", "message", "body"];
|
||||
for key in CANDIDATES {
|
||||
if let Some(Value::String(text)) = map.get(key)
|
||||
&& !text.trim().is_empty()
|
||||
{
|
||||
return Some(text.clone());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(Value::Array(items)) = map.get("lines") {
|
||||
let mut collected = Vec::new();
|
||||
for item in items {
|
||||
if let Some(segment) = item.as_str()
|
||||
&& !segment.trim().is_empty()
|
||||
{
|
||||
collected.push(segment.trim());
|
||||
}
|
||||
}
|
||||
if !collected.is_empty() {
|
||||
return Some(collected.join("\n"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn extract_mcp_error(value: &Value) -> Option<String> {
|
||||
if let Value::Object(map) = value
|
||||
&& let Some(Value::String(message)) = map.get("error")
|
||||
&& !message.trim().is_empty()
|
||||
{
|
||||
return Some(message.clone());
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn pretty_print_value(value: &Value) -> String {
|
||||
serde_json::to_string_pretty(value).unwrap_or_else(|_| value.to_string())
|
||||
}
|
||||
|
||||
async fn resolve_pending_resource_references(&mut self) -> Result<()> {
|
||||
if self.pending_resource_refs.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut resolved = 0usize;
|
||||
let references: Vec<String> = self.pending_resource_refs.drain(..).collect();
|
||||
for reference in references {
|
||||
match self.controller.resolve_resource_reference(&reference).await {
|
||||
Ok(Some(content)) => {
|
||||
let message = format!("Resource @{}:\n{}", reference, content);
|
||||
self.controller
|
||||
.conversation_mut()
|
||||
.push_system_message(message);
|
||||
resolved += 1;
|
||||
}
|
||||
Ok(None) => {
|
||||
self.push_toast(
|
||||
ToastLevel::Warning,
|
||||
format!(
|
||||
"Resource @{} is not defined in the current project.",
|
||||
reference
|
||||
),
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
self.push_toast(
|
||||
ToastLevel::Error,
|
||||
format!("Failed to load resource @{}: {}", reference, err),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if resolved > 0 {
|
||||
self.status = format!("Inserted {resolved} resource snippet(s).");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn complete_resource_reference(&mut self) -> bool {
|
||||
if self.resource_catalog.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let (row, col) = self.textarea.cursor();
|
||||
let lines = self.textarea.lines().to_vec();
|
||||
if row >= lines.len() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let line = &lines[row];
|
||||
let chars: Vec<char> = line.chars().collect();
|
||||
if col > chars.len() {
|
||||
return false;
|
||||
}
|
||||
|
||||
let mut start = col;
|
||||
while start > 0 {
|
||||
let ch = chars[start - 1];
|
||||
if ch == '@' {
|
||||
start -= 1;
|
||||
break;
|
||||
}
|
||||
if ch.is_whitespace() {
|
||||
return false;
|
||||
}
|
||||
start -= 1;
|
||||
}
|
||||
|
||||
if start >= col || chars.get(start) != Some(&'@') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if chars[start + 1..col].iter().any(|ch| ch.is_whitespace()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let mut end = col;
|
||||
while end < chars.len() {
|
||||
let ch = chars[end];
|
||||
if ch.is_whitespace() {
|
||||
break;
|
||||
}
|
||||
end += 1;
|
||||
}
|
||||
|
||||
let typed_prefix: String = chars[start + 1..col].iter().collect();
|
||||
let trailing_segment: String = chars[col..end].iter().collect();
|
||||
let lower_prefix = typed_prefix.to_ascii_lowercase();
|
||||
let lower_full = format!("{}{}", typed_prefix, trailing_segment).to_ascii_lowercase();
|
||||
|
||||
let mut matches: Vec<&McpResourceConfig> = self
|
||||
.resource_catalog
|
||||
.iter()
|
||||
.filter(|resource| {
|
||||
let reference = format!("{}:{}", resource.server, resource.uri);
|
||||
let lower_reference = reference.to_ascii_lowercase();
|
||||
lower_reference.starts_with(&lower_full)
|
||||
|| lower_reference.starts_with(&lower_prefix)
|
||||
|| resource
|
||||
.title
|
||||
.as_ref()
|
||||
.map(|title| title.to_ascii_lowercase().starts_with(&lower_prefix))
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.collect();
|
||||
|
||||
if matches.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
matches.sort_by(|a, b| a.server.cmp(&b.server).then(a.uri.cmp(&b.uri)));
|
||||
let (selected_server, selected_uri, selected_title) = {
|
||||
let selected = matches[0];
|
||||
(
|
||||
selected.server.clone(),
|
||||
selected.uri.clone(),
|
||||
selected.title.clone(),
|
||||
)
|
||||
};
|
||||
let replacement = format!("@{}:{}", selected_server, selected_uri);
|
||||
|
||||
let mut new_line = String::new();
|
||||
new_line.extend(chars[..start].iter());
|
||||
new_line.push_str(&replacement);
|
||||
new_line.extend(chars[end..].iter());
|
||||
|
||||
let mut new_lines = lines;
|
||||
new_lines[row] = new_line;
|
||||
self.textarea = TextArea::new(new_lines);
|
||||
configure_textarea_defaults(&mut self.textarea);
|
||||
|
||||
let new_col = start + replacement.len();
|
||||
self.textarea
|
||||
.move_cursor(CursorMove::Jump(row as u16, new_col as u16));
|
||||
|
||||
self.sync_textarea_to_buffer();
|
||||
|
||||
if let Some(title) = selected_title.as_deref() {
|
||||
self.status = format!("Inserted resource {} ({title}).", replacement);
|
||||
} else {
|
||||
self.status = format!("Inserted resource {}.", replacement);
|
||||
}
|
||||
self.error = None;
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn extract_resource_references(text: &str) -> Vec<String> {
|
||||
let mut references = Vec::new();
|
||||
let mut current = String::new();
|
||||
let mut in_reference = false;
|
||||
|
||||
for ch in text.chars() {
|
||||
if in_reference {
|
||||
if ch.is_whitespace() || matches!(ch, ',' | ';' | ')' | '(' | '.' | '!' | '?') {
|
||||
if current.contains(':') {
|
||||
references.push(current.clone());
|
||||
}
|
||||
current.clear();
|
||||
in_reference = false;
|
||||
} else {
|
||||
current.push(ch);
|
||||
}
|
||||
} else if ch == '@' {
|
||||
in_reference = true;
|
||||
current.clear();
|
||||
}
|
||||
}
|
||||
|
||||
if in_reference && current.contains(':') {
|
||||
references.push(current);
|
||||
}
|
||||
|
||||
references
|
||||
}
|
||||
|
||||
fn display_name_for_model(model: &ModelInfo) -> String {
|
||||
if model.name.trim().is_empty() {
|
||||
model.id.clone()
|
||||
@@ -2110,6 +2448,204 @@ impl ChatApp {
|
||||
configure_textarea_defaults(&mut self.textarea);
|
||||
}
|
||||
|
||||
async fn process_slash_submission(&mut self) -> Result<SlashOutcome> {
|
||||
let raw = self.controller.input_buffer().text().to_string();
|
||||
if raw.trim().is_empty() {
|
||||
return Ok(SlashOutcome::NotCommand);
|
||||
}
|
||||
|
||||
match slash::parse(&raw) {
|
||||
Ok(None) => Ok(SlashOutcome::NotCommand),
|
||||
Ok(Some(command)) => match self.execute_slash_command(command).await {
|
||||
Ok(()) => {
|
||||
self.input_buffer_mut().push_history_entry(raw.clone());
|
||||
self.controller.input_buffer_mut().clear();
|
||||
Ok(SlashOutcome::Consumed)
|
||||
}
|
||||
Err(err) => {
|
||||
self.error = Some(err.to_string());
|
||||
self.status = "Slash command failed".to_string();
|
||||
self.controller.input_buffer_mut().set_text(raw);
|
||||
Ok(SlashOutcome::Error)
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
self.error = Some(err.to_string());
|
||||
self.status = "Slash command error".to_string();
|
||||
Ok(SlashOutcome::Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute_slash_command(&mut self, command: SlashCommand) -> Result<()> {
|
||||
match command {
|
||||
SlashCommand::Summarize { count } => {
|
||||
let prompt = if let Some(count) = count {
|
||||
format!(
|
||||
"Summarize the last {count} messages in this conversation. Highlight key decisions, open questions, and follow-up tasks."
|
||||
)
|
||||
} else {
|
||||
"Summarize the conversation so far, calling out major decisions, blockers, and immediate next steps.".to_string()
|
||||
};
|
||||
self.status = "Summarizing conversation...".to_string();
|
||||
self.dispatch_user_prompt(prompt);
|
||||
}
|
||||
SlashCommand::Explain { snippet } => {
|
||||
let prompt = format!(
|
||||
"Explain the following code snippet. Cover what it does and call out any potential issues or improvements:\n```\n{}\n```",
|
||||
snippet
|
||||
);
|
||||
self.status = "Explaining snippet...".to_string();
|
||||
self.dispatch_user_prompt(prompt);
|
||||
}
|
||||
SlashCommand::Refactor { path } => {
|
||||
let trimmed = path.trim();
|
||||
if trimmed.is_empty() {
|
||||
anyhow::bail!("usage: /refactor <relative/path/to/file>");
|
||||
}
|
||||
let source = self.controller.read_file(trimmed).await?;
|
||||
let prompt = format!(
|
||||
"Refactor the file `{}`. Provide specific improvements for readability, safety, and maintainability. Include updated code where relevant.\n\n```text\n{}\n```",
|
||||
trimmed, source
|
||||
);
|
||||
self.status = format!("Refactor review for {trimmed}...");
|
||||
self.dispatch_user_prompt(prompt);
|
||||
}
|
||||
SlashCommand::TestPlan => {
|
||||
let prompt = "Generate a comprehensive test plan for this repository. Outline critical test suites, coverage gaps, and prioritized steps to reach confident automation.".to_string();
|
||||
self.status = "Generating test plan...".to_string();
|
||||
self.dispatch_user_prompt(prompt);
|
||||
}
|
||||
SlashCommand::Compact => {
|
||||
let prompt = "Compress our conversation history to its essentials. Summarize previous exchanges, preserve critical context, and indicate what state can be safely forgotten.".to_string();
|
||||
self.status = "Compacting conversation...".to_string();
|
||||
self.dispatch_user_prompt(prompt);
|
||||
}
|
||||
SlashCommand::McpTool { server, tool } => {
|
||||
self.status = format!("Running MCP tool {server}::{tool}...");
|
||||
let response = self
|
||||
.controller
|
||||
.call_mcp_tool(&server, &tool, json!({}))
|
||||
.await
|
||||
.map_err(|err| {
|
||||
anyhow!("Failed to invoke MCP tool {}::{}: {}", server, tool, err)
|
||||
})?;
|
||||
|
||||
let content = Self::format_mcp_slash_message(&server, &tool, &response);
|
||||
self.controller
|
||||
.conversation_mut()
|
||||
.push_system_message(content);
|
||||
self.auto_scroll.stick_to_bottom = true;
|
||||
self.new_message_alert = true;
|
||||
|
||||
if response.success {
|
||||
self.status = format!("MCP {server}::{tool} result added to chat.");
|
||||
self.push_toast(ToastLevel::Info, format!("MCP {server}::{tool} completed."));
|
||||
} else {
|
||||
self.status = format!("MCP {server}::{tool} reported an error (see chat).");
|
||||
self.push_toast(
|
||||
ToastLevel::Warning,
|
||||
format!("MCP {server}::{tool} reported an error."),
|
||||
);
|
||||
}
|
||||
self.error = None;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn schedule_oauth_poll(
|
||||
&self,
|
||||
server: String,
|
||||
authorization: DeviceAuthorization,
|
||||
delay: Duration,
|
||||
) {
|
||||
let sender = self.session_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(delay).await;
|
||||
let _ = sender.send(SessionEvent::OAuthPoll {
|
||||
server,
|
||||
authorization,
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
async fn start_oauth_login(&mut self, server: &str) -> Result<()> {
|
||||
if self.oauth_flows.contains_key(server) {
|
||||
self.error = Some(format!("OAuth flow for '{server}' is already in progress."));
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let authorization = match self.controller.start_oauth_device_flow(server).await {
|
||||
Ok(auth) => auth,
|
||||
Err(err) => {
|
||||
self.error = Some(format!("Failed to start OAuth for '{server}': {err}"));
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
self.oauth_flows
|
||||
.insert(server.to_string(), authorization.clone());
|
||||
|
||||
let link = authorization
|
||||
.verification_uri_complete
|
||||
.clone()
|
||||
.unwrap_or_else(|| authorization.verification_uri.clone());
|
||||
let status = format!(
|
||||
"Authorize '{server}' via {} (code {}).",
|
||||
link, authorization.user_code
|
||||
);
|
||||
self.status = status;
|
||||
self.error = None;
|
||||
|
||||
let mut message = format!(
|
||||
"OAuth authorization required for `{server}`.\nVisit:\n{}\nEnter code: `{}`",
|
||||
link, authorization.user_code
|
||||
);
|
||||
if let Some(hint) = &authorization.message
|
||||
&& !hint.trim().is_empty()
|
||||
{
|
||||
message.push_str("\n\n");
|
||||
message.push_str(hint);
|
||||
}
|
||||
if authorization.expires_at > Utc::now() {
|
||||
message.push_str(&format!(
|
||||
"\n\nThis code expires at {}.",
|
||||
authorization
|
||||
.expires_at
|
||||
.to_rfc3339_opts(chrono::SecondsFormat::Secs, true)
|
||||
));
|
||||
}
|
||||
|
||||
self.controller
|
||||
.conversation_mut()
|
||||
.push_system_message(message);
|
||||
self.auto_scroll.stick_to_bottom = true;
|
||||
self.notify_new_activity();
|
||||
|
||||
self.push_toast(
|
||||
ToastLevel::Warning,
|
||||
format!("Authorize {server}: code {}", authorization.user_code),
|
||||
);
|
||||
|
||||
let delay = authorization.interval;
|
||||
self.schedule_oauth_poll(server.to_string(), authorization.clone(), delay);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn dispatch_user_prompt(&mut self, prompt: String) {
|
||||
if prompt.trim().is_empty() {
|
||||
self.error = Some("Slash command generated an empty request".to_string());
|
||||
return;
|
||||
}
|
||||
|
||||
self.controller.conversation_mut().push_user_message(prompt);
|
||||
self.auto_scroll.stick_to_bottom = true;
|
||||
self.pending_llm_request = true;
|
||||
self.set_system_status(String::new());
|
||||
self.error = None;
|
||||
}
|
||||
|
||||
fn set_code_view_content(
|
||||
&mut self,
|
||||
display_path: impl Into<String>,
|
||||
@@ -2216,14 +2752,14 @@ impl ChatApp {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn restore_workspace_layout(&mut self) -> Result<()> {
|
||||
async fn restore_workspace_layout(&mut self) -> Result<bool> {
|
||||
let path = match self.workspace_layout_path() {
|
||||
Ok(path) => path,
|
||||
Err(_) => return Ok(()),
|
||||
Err(_) => return Ok(false),
|
||||
};
|
||||
|
||||
if !path.exists() {
|
||||
return Ok(());
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let contents = fs::read_to_string(&path)
|
||||
@@ -2247,7 +2783,7 @@ impl ChatApp {
|
||||
self.status = "Workspace layout restored".to_string();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(restored_any)
|
||||
}
|
||||
|
||||
fn direction_label(direction: PaneDirection) -> &'static str {
|
||||
@@ -3289,6 +3825,7 @@ impl ChatApp {
|
||||
Event::Tick => {
|
||||
self.poll_repo_search();
|
||||
self.poll_symbol_search();
|
||||
self.prune_toasts();
|
||||
// Future: update streaming timers
|
||||
}
|
||||
Event::Resize(width, height) => {
|
||||
@@ -4172,13 +4709,24 @@ impl ChatApp {
|
||||
self.textarea.insert_newline();
|
||||
}
|
||||
(KeyCode::Enter, KeyModifiers::NONE) => {
|
||||
// Send message and return to normal mode
|
||||
self.sync_textarea_to_buffer();
|
||||
self.send_user_message_and_request_response();
|
||||
// Clear the textarea by setting it to empty
|
||||
self.textarea = TextArea::default();
|
||||
configure_textarea_defaults(&mut self.textarea);
|
||||
self.set_input_mode(InputMode::Normal);
|
||||
match self.process_slash_submission().await? {
|
||||
SlashOutcome::NotCommand => {
|
||||
self.send_user_message_and_request_response();
|
||||
self.textarea = TextArea::default();
|
||||
configure_textarea_defaults(&mut self.textarea);
|
||||
self.set_input_mode(InputMode::Normal);
|
||||
}
|
||||
SlashOutcome::Consumed => {
|
||||
self.textarea = TextArea::default();
|
||||
configure_textarea_defaults(&mut self.textarea);
|
||||
self.set_input_mode(InputMode::Normal);
|
||||
}
|
||||
SlashOutcome::Error => {
|
||||
// Restore textarea content so the user can correct the command
|
||||
self.sync_buffer_to_textarea();
|
||||
}
|
||||
}
|
||||
}
|
||||
(KeyCode::Enter, _) => {
|
||||
// Any Enter with modifiers keeps editing and inserts a newline via tui-textarea
|
||||
@@ -4208,6 +4756,11 @@ impl ChatApp {
|
||||
self.textarea
|
||||
.move_cursor(tui_textarea::CursorMove::WordBack);
|
||||
}
|
||||
(KeyCode::Tab, m) if m.is_empty() => {
|
||||
if !self.complete_resource_reference() {
|
||||
self.textarea.input(Input::from(key));
|
||||
}
|
||||
}
|
||||
(KeyCode::Char('r'), m) if m.contains(KeyModifiers::CONTROL) => {
|
||||
// Redo - history next
|
||||
self.input_buffer_mut().history_next();
|
||||
@@ -4538,6 +5091,31 @@ impl ChatApp {
|
||||
}
|
||||
}
|
||||
}
|
||||
"oauth" => {
|
||||
if args.is_empty() {
|
||||
let pending = self.controller.pending_oauth_servers();
|
||||
if pending.is_empty() {
|
||||
self.status =
|
||||
"No OAuth-enabled MCP servers require authorization."
|
||||
.to_string();
|
||||
} else {
|
||||
self.status = format!(
|
||||
"Pending OAuth servers: {}",
|
||||
pending.join(", ")
|
||||
);
|
||||
}
|
||||
self.error = None;
|
||||
} else if args.len() == 1 {
|
||||
self.start_oauth_login(args[0]).await?;
|
||||
} else if args.len() == 2
|
||||
&& args[0].eq_ignore_ascii_case("login")
|
||||
{
|
||||
self.start_oauth_login(args[1]).await?;
|
||||
} else {
|
||||
self.error =
|
||||
Some("Usage: :oauth [login] <server>".to_string());
|
||||
}
|
||||
}
|
||||
"load" | "o" => {
|
||||
// Load saved sessions and enter browser mode
|
||||
match self.controller.list_saved_sessions().await {
|
||||
@@ -5015,29 +5593,58 @@ impl ChatApp {
|
||||
if self.code_workspace.tabs().is_empty() {
|
||||
self.status =
|
||||
"No open panes to save".to_string();
|
||||
self.error = None;
|
||||
self.push_toast(
|
||||
ToastLevel::Warning,
|
||||
"Open a pane before saving layout.",
|
||||
);
|
||||
} else {
|
||||
self.persist_workspace_layout();
|
||||
self.status =
|
||||
"Workspace layout saved".to_string();
|
||||
self.error = None;
|
||||
self.push_toast(
|
||||
ToastLevel::Success,
|
||||
"Workspace layout saved.",
|
||||
);
|
||||
}
|
||||
}
|
||||
"load" => match self.restore_workspace_layout().await {
|
||||
Ok(()) => {
|
||||
Ok(true) => {
|
||||
self.status =
|
||||
"Workspace layout restored".to_string();
|
||||
self.error = None;
|
||||
self.push_toast(
|
||||
ToastLevel::Success,
|
||||
"Workspace layout restored.",
|
||||
);
|
||||
}
|
||||
Ok(false) => {
|
||||
self.status =
|
||||
"No saved layout to restore".to_string();
|
||||
self.error = None;
|
||||
self.push_toast(
|
||||
ToastLevel::Info,
|
||||
"No saved layout was found.",
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
self.error = Some(err.to_string());
|
||||
let message = format!(
|
||||
"Failed to restore workspace layout: {}",
|
||||
err
|
||||
);
|
||||
self.error = Some(message.clone());
|
||||
self.status =
|
||||
"Failed to restore workspace layout"
|
||||
.to_string();
|
||||
self.push_toast(ToastLevel::Error, message);
|
||||
}
|
||||
},
|
||||
other => {
|
||||
self.status =
|
||||
format!("Unknown layout command: {other}");
|
||||
self.error = Some(format!(
|
||||
"Unknown layout command: {other}"
|
||||
"Unknown layout subcommand: {other}"
|
||||
));
|
||||
}
|
||||
}
|
||||
@@ -5068,6 +5675,27 @@ impl ChatApp {
|
||||
self.error = None;
|
||||
self.sync_ui_preferences_from_config();
|
||||
self.update_command_palette_catalog();
|
||||
if let Err(err) = self.refresh_resource_catalog().await
|
||||
{
|
||||
self.push_toast(
|
||||
ToastLevel::Error,
|
||||
format!(
|
||||
"Failed to refresh MCP resources: {}",
|
||||
err
|
||||
),
|
||||
);
|
||||
}
|
||||
if let Err(err) =
|
||||
self.refresh_mcp_slash_commands().await
|
||||
{
|
||||
self.push_toast(
|
||||
ToastLevel::Error,
|
||||
format!(
|
||||
"Failed to refresh MCP slash commands: {}",
|
||||
err
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
self.error =
|
||||
@@ -5666,7 +6294,7 @@ impl ChatApp {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn handle_session_event(&mut self, event: SessionEvent) -> Result<()> {
|
||||
pub async fn handle_session_event(&mut self, event: SessionEvent) -> Result<()> {
|
||||
match event {
|
||||
SessionEvent::StreamChunk {
|
||||
message_id,
|
||||
@@ -5760,6 +6388,52 @@ impl ChatApp {
|
||||
self.agent_actions = None;
|
||||
self.stop_loading_animation();
|
||||
}
|
||||
SessionEvent::OAuthPoll {
|
||||
server,
|
||||
authorization,
|
||||
} => {
|
||||
match self
|
||||
.controller
|
||||
.poll_oauth_device_flow(&server, &authorization)
|
||||
.await
|
||||
{
|
||||
Ok(DevicePollState::Pending { retry_in }) => {
|
||||
self.oauth_flows
|
||||
.insert(server.clone(), authorization.clone());
|
||||
let server_name = server.clone();
|
||||
self.schedule_oauth_poll(server, authorization, retry_in);
|
||||
self.status = format!("Waiting for OAuth approval for {server_name}...");
|
||||
}
|
||||
Ok(DevicePollState::Complete(_token)) => {
|
||||
self.oauth_flows.remove(&server);
|
||||
self.push_toast(
|
||||
ToastLevel::Success,
|
||||
format!("OAuth authorization complete for {server}."),
|
||||
);
|
||||
self.status = format!("OAuth authorization complete for {server}.");
|
||||
if let Err(err) = self.refresh_resource_catalog().await {
|
||||
self.push_toast(
|
||||
ToastLevel::Error,
|
||||
format!("Failed to refresh MCP resources: {err}"),
|
||||
);
|
||||
}
|
||||
if let Err(err) = self.refresh_mcp_slash_commands().await {
|
||||
self.push_toast(
|
||||
ToastLevel::Error,
|
||||
format!("Failed to refresh MCP slash commands: {err}"),
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
self.oauth_flows.remove(&server);
|
||||
self.error = Some(format!("OAuth flow for '{server}' failed: {err}"));
|
||||
self.push_toast(
|
||||
ToastLevel::Error,
|
||||
format!("OAuth failure for {server}: {err}"),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -5825,6 +6499,7 @@ impl ChatApp {
|
||||
args: Vec::new(),
|
||||
transport: "stdio".to_string(),
|
||||
env: env_vars.clone(),
|
||||
oauth: None,
|
||||
};
|
||||
RemoteMcpClient::new_with_config(&config)
|
||||
} else {
|
||||
@@ -6176,6 +6851,7 @@ impl ChatApp {
|
||||
args: Vec::new(),
|
||||
transport: "stdio".to_string(),
|
||||
env: env_vars,
|
||||
oauth: None,
|
||||
};
|
||||
Arc::new(RemoteMcpClient::new_with_config(&config)?)
|
||||
} else {
|
||||
@@ -6423,6 +7099,10 @@ impl ChatApp {
|
||||
|
||||
// Step 1: Add user message to conversation immediately (synchronous)
|
||||
let message = self.controller.input_buffer_mut().commit_to_history();
|
||||
let mut references = Self::extract_resource_references(&message);
|
||||
references.sort();
|
||||
references.dedup();
|
||||
self.pending_resource_refs = references;
|
||||
self.controller
|
||||
.conversation_mut()
|
||||
.push_user_message(message.clone());
|
||||
@@ -6539,6 +7219,8 @@ impl ChatApp {
|
||||
|
||||
self.pending_llm_request = false;
|
||||
|
||||
self.resolve_pending_resource_references().await?;
|
||||
|
||||
// Check if agent mode is enabled
|
||||
if self.agent_mode {
|
||||
return self.process_agent_request().await;
|
||||
|
||||
@@ -28,8 +28,8 @@ impl CodeApp {
|
||||
self.inner.handle_event(event).await
|
||||
}
|
||||
|
||||
pub fn handle_session_event(&mut self, event: SessionEvent) -> Result<()> {
|
||||
self.inner.handle_session_event(event)
|
||||
pub async fn handle_session_event(&mut self, event: SessionEvent) -> Result<()> {
|
||||
self.inner.handle_session_event(event).await
|
||||
}
|
||||
|
||||
pub fn mode(&self) -> InputMode {
|
||||
|
||||
@@ -235,7 +235,7 @@ pub fn match_score(candidate: &str, query: &str) -> Option<(usize, usize)> {
|
||||
if candidate_normalized == query_normalized {
|
||||
Some((0, candidate.len()))
|
||||
} else if candidate_normalized.starts_with(&query_normalized) {
|
||||
Some((1, candidate.len()))
|
||||
Some((1, 0))
|
||||
} else if let Some(pos) = candidate_normalized.find(&query_normalized) {
|
||||
Some((2, pos))
|
||||
} else if is_subsequence(&candidate_normalized, &query_normalized) {
|
||||
|
||||
@@ -18,7 +18,9 @@ pub mod commands;
|
||||
pub mod config;
|
||||
pub mod events;
|
||||
pub mod model_info_panel;
|
||||
pub mod slash;
|
||||
pub mod state;
|
||||
pub mod toast;
|
||||
pub mod tui_controller;
|
||||
pub mod ui;
|
||||
|
||||
|
||||
238
crates/owlen-tui/src/slash.rs
Normal file
238
crates/owlen-tui/src/slash.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
//! Slash command parsing for chat input.
|
||||
//!
|
||||
//! Provides lightweight handling for inline commands such as `/summarize`
|
||||
//! and `/testplan`. The parser returns owned data so callers can prepare
|
||||
//! requests immediately without additional lifetime juggling.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
use std::sync::{OnceLock, RwLock};
|
||||
|
||||
/// Supported slash commands.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SlashCommand {
|
||||
Summarize { count: Option<usize> },
|
||||
Explain { snippet: String },
|
||||
Refactor { path: String },
|
||||
TestPlan,
|
||||
Compact,
|
||||
McpTool { server: String, tool: String },
|
||||
}
|
||||
|
||||
/// Errors emitted when parsing invalid slash input.
|
||||
#[derive(Debug)]
|
||||
pub enum SlashError {
|
||||
UnknownCommand(String),
|
||||
Message(String),
|
||||
}
|
||||
|
||||
impl fmt::Display for SlashError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
SlashError::UnknownCommand(cmd) => write!(f, "unknown slash command: {cmd}"),
|
||||
SlashError::Message(msg) => f.write_str(msg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for SlashError {}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct McpSlashCommand {
|
||||
pub server: String,
|
||||
pub tool: String,
|
||||
pub keyword: String,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
impl McpSlashCommand {
|
||||
pub fn new(
|
||||
server: impl Into<String>,
|
||||
tool: impl Into<String>,
|
||||
description: Option<String>,
|
||||
) -> Self {
|
||||
let server = server.into();
|
||||
let tool = tool.into();
|
||||
let keyword = format!(
|
||||
"mcp__{}__{}",
|
||||
canonicalize_component(&server),
|
||||
canonicalize_component(&tool)
|
||||
);
|
||||
Self {
|
||||
server,
|
||||
tool,
|
||||
keyword,
|
||||
description,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static MCP_COMMANDS: OnceLock<RwLock<HashMap<String, McpSlashCommand>>> = OnceLock::new();
|
||||
|
||||
fn dynamic_registry() -> &'static RwLock<HashMap<String, McpSlashCommand>> {
|
||||
MCP_COMMANDS.get_or_init(|| RwLock::new(HashMap::new()))
|
||||
}
|
||||
|
||||
pub fn set_mcp_commands(commands: impl IntoIterator<Item = McpSlashCommand>) {
|
||||
let registry = dynamic_registry();
|
||||
let mut guard = registry.write().expect("MCP command registry poisoned");
|
||||
guard.clear();
|
||||
for command in commands {
|
||||
guard.insert(command.keyword.clone(), command);
|
||||
}
|
||||
}
|
||||
|
||||
fn find_mcp_command(keyword: &str) -> Option<McpSlashCommand> {
|
||||
let registry = dynamic_registry();
|
||||
let guard = registry.read().expect("MCP command registry poisoned");
|
||||
guard.get(keyword).cloned()
|
||||
}
|
||||
|
||||
fn canonicalize_component(input: &str) -> String {
|
||||
let mut out = String::new();
|
||||
let mut last_was_underscore = false;
|
||||
for ch in input.chars() {
|
||||
let mapped = if ch.is_ascii_alphanumeric() {
|
||||
ch.to_ascii_lowercase()
|
||||
} else {
|
||||
'_'
|
||||
};
|
||||
if mapped == '_' {
|
||||
if !last_was_underscore {
|
||||
out.push('_');
|
||||
last_was_underscore = true;
|
||||
}
|
||||
} else {
|
||||
out.push(mapped);
|
||||
last_was_underscore = false;
|
||||
}
|
||||
}
|
||||
if out.is_empty() { "_".to_string() } else { out }
|
||||
}
|
||||
|
||||
/// Attempt to parse a slash command from the provided input.
|
||||
pub fn parse(input: &str) -> Result<Option<SlashCommand>, SlashError> {
|
||||
let trimmed = input.trim();
|
||||
if !trimmed.starts_with('/') {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let body = trimmed.trim_start_matches('/');
|
||||
if body.is_empty() {
|
||||
return Err(SlashError::Message("missing command name after '/'".into()));
|
||||
}
|
||||
|
||||
let mut parts = body.split_whitespace();
|
||||
let command = parts.next().unwrap();
|
||||
let remainder = parts.collect::<Vec<_>>();
|
||||
|
||||
if let Some(dynamic) = find_mcp_command(command) {
|
||||
if !remainder.is_empty() {
|
||||
return Err(SlashError::Message(format!(
|
||||
"/{} does not accept arguments",
|
||||
dynamic.keyword
|
||||
)));
|
||||
}
|
||||
return Ok(Some(SlashCommand::McpTool {
|
||||
server: dynamic.server,
|
||||
tool: dynamic.tool,
|
||||
}));
|
||||
}
|
||||
|
||||
let cmd = match command {
|
||||
"summarize" => {
|
||||
let count = remainder
|
||||
.first()
|
||||
.and_then(|value| usize::from_str(value).ok());
|
||||
SlashCommand::Summarize { count }
|
||||
}
|
||||
"explain" => {
|
||||
if remainder.is_empty() {
|
||||
return Err(SlashError::Message(
|
||||
"usage: /explain <code snippet or description>".into(),
|
||||
));
|
||||
}
|
||||
SlashCommand::Explain {
|
||||
snippet: remainder.join(" "),
|
||||
}
|
||||
}
|
||||
"refactor" => {
|
||||
if remainder.is_empty() {
|
||||
return Err(SlashError::Message(
|
||||
"usage: /refactor <relative/path/to/file>".into(),
|
||||
));
|
||||
}
|
||||
SlashCommand::Refactor {
|
||||
path: remainder.join(" "),
|
||||
}
|
||||
}
|
||||
"testplan" => SlashCommand::TestPlan,
|
||||
"compact" => SlashCommand::Compact,
|
||||
other => return Err(SlashError::UnknownCommand(other.to_string())),
|
||||
};
|
||||
|
||||
Ok(Some(cmd))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn ignores_non_command_input() {
|
||||
let result = parse("hello world").unwrap();
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_summarize_with_count() {
|
||||
let command = parse("/summarize 10").unwrap().expect("expected command");
|
||||
match command {
|
||||
SlashCommand::Summarize { count } => assert_eq!(count, Some(10)),
|
||||
other => panic!("unexpected command: {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_error_for_unknown_command() {
|
||||
let err = parse("/unknown").unwrap_err();
|
||||
assert_eq!(err.to_string(), "unknown slash command: unknown");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_registered_mcp_command() {
|
||||
set_mcp_commands(Vec::new());
|
||||
set_mcp_commands(vec![McpSlashCommand::new("github", "list_prs", None)]);
|
||||
|
||||
let command = parse("/mcp__github__list_prs")
|
||||
.unwrap()
|
||||
.expect("expected command");
|
||||
match command {
|
||||
SlashCommand::McpTool { server, tool } => {
|
||||
assert_eq!(server, "github");
|
||||
assert_eq!(tool, "list_prs");
|
||||
}
|
||||
other => panic!("unexpected command variant: {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_mcp_command_with_arguments() {
|
||||
set_mcp_commands(Vec::new());
|
||||
set_mcp_commands(vec![McpSlashCommand::new("github", "list_prs", None)]);
|
||||
|
||||
let err = parse("/mcp__github__list_prs extra").unwrap_err();
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"/mcp__github__list_prs does not accept arguments"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canonicalizes_mcp_command_components() {
|
||||
set_mcp_commands(Vec::new());
|
||||
let entry = McpSlashCommand::new("GitHub", "list/prs", None);
|
||||
assert_eq!(entry.keyword, "mcp__github__list_prs");
|
||||
}
|
||||
}
|
||||
96
crates/owlen-tui/src/theme_util.rs
Normal file
96
crates/owlen-tui/src/theme_util.rs
Normal file
@@ -0,0 +1,96 @@
|
||||
macro_rules! adjust_fields {
|
||||
($theme:expr, $func:expr, $($field:ident),+ $(,)?) => {
|
||||
$(
|
||||
$theme.$field = $func($theme.$field);
|
||||
)+
|
||||
};
|
||||
}
|
||||
|
||||
use owlen_core::theme::Theme;
|
||||
use ratatui::style::Color;
|
||||
|
||||
/// Return a clone of `base` with contrast adjustments applied.
|
||||
/// Positive `steps` increase contrast, negative values decrease it.
|
||||
pub fn with_contrast(base: &Theme, steps: i8) -> Theme {
|
||||
if steps == 0 {
|
||||
return base.clone();
|
||||
}
|
||||
|
||||
let factor = (1.0 + (steps as f32) * 0.18).clamp(0.3, 2.0);
|
||||
let adjust = |color: Color| adjust_color(color, factor);
|
||||
|
||||
let mut theme = base.clone();
|
||||
adjust_fields!(
|
||||
theme,
|
||||
adjust,
|
||||
text,
|
||||
background,
|
||||
focused_panel_border,
|
||||
unfocused_panel_border,
|
||||
focus_beacon_fg,
|
||||
focus_beacon_bg,
|
||||
unfocused_beacon_fg,
|
||||
pane_header_active,
|
||||
pane_header_inactive,
|
||||
pane_hint_text,
|
||||
user_message_role,
|
||||
assistant_message_role,
|
||||
tool_output,
|
||||
thinking_panel_title,
|
||||
command_bar_background,
|
||||
status_background,
|
||||
mode_normal,
|
||||
mode_editing,
|
||||
mode_model_selection,
|
||||
mode_provider_selection,
|
||||
mode_help,
|
||||
mode_visual,
|
||||
mode_command,
|
||||
selection_bg,
|
||||
selection_fg,
|
||||
cursor,
|
||||
code_block_background,
|
||||
code_block_border,
|
||||
code_block_text,
|
||||
code_block_keyword,
|
||||
code_block_string,
|
||||
code_block_comment,
|
||||
placeholder,
|
||||
error,
|
||||
info,
|
||||
agent_thought,
|
||||
agent_action,
|
||||
agent_action_input,
|
||||
agent_observation,
|
||||
agent_final_answer,
|
||||
agent_badge_running_fg,
|
||||
agent_badge_running_bg,
|
||||
agent_badge_idle_fg,
|
||||
agent_badge_idle_bg,
|
||||
operating_chat_fg,
|
||||
operating_chat_bg,
|
||||
operating_code_fg,
|
||||
operating_code_bg
|
||||
);
|
||||
|
||||
theme
|
||||
}
|
||||
|
||||
fn adjust_color(color: Color, factor: f32) -> Color {
|
||||
match color {
|
||||
Color::Rgb(r, g, b) => {
|
||||
let adjust_component = |component: u8| -> u8 {
|
||||
let normalized = component as f32 / 255.0;
|
||||
let contrasted = ((normalized - 0.5) * factor + 0.5).clamp(0.0, 1.0);
|
||||
(contrasted * 255.0).round().clamp(0.0, 255.0) as u8
|
||||
};
|
||||
|
||||
Color::Rgb(
|
||||
adjust_component(r),
|
||||
adjust_component(g),
|
||||
adjust_component(b),
|
||||
)
|
||||
}
|
||||
_ => color,
|
||||
}
|
||||
}
|
||||
114
crates/owlen-tui/src/toast.rs
Normal file
114
crates/owlen-tui/src/toast.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Severity level for toast notifications.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ToastLevel {
|
||||
Info,
|
||||
Success,
|
||||
Warning,
|
||||
Error,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Toast {
|
||||
pub message: String,
|
||||
pub level: ToastLevel,
|
||||
created: Instant,
|
||||
duration: Duration,
|
||||
}
|
||||
|
||||
impl Toast {
|
||||
fn new(message: String, level: ToastLevel, lifetime: Duration) -> Self {
|
||||
Self {
|
||||
message,
|
||||
level,
|
||||
created: Instant::now(),
|
||||
duration: lifetime,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_expired(&self, now: Instant) -> bool {
|
||||
now.duration_since(self.created) >= self.duration
|
||||
}
|
||||
}
|
||||
|
||||
/// Fixed-size toast queue with automatic expiration.
|
||||
#[derive(Debug)]
|
||||
pub struct ToastManager {
|
||||
items: VecDeque<Toast>,
|
||||
max_active: usize,
|
||||
lifetime: Duration,
|
||||
}
|
||||
|
||||
impl Default for ToastManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToastManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
items: VecDeque::new(),
|
||||
max_active: 3,
|
||||
lifetime: Duration::from_secs(3),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_lifetime(mut self, duration: Duration) -> Self {
|
||||
self.lifetime = duration;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn push(&mut self, message: impl Into<String>, level: ToastLevel) {
|
||||
let toast = Toast::new(message.into(), level, self.lifetime);
|
||||
self.items.push_front(toast);
|
||||
while self.items.len() > self.max_active {
|
||||
self.items.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn retain_active(&mut self) {
|
||||
let now = Instant::now();
|
||||
self.items.retain(|toast| !toast.is_expired(now));
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = &Toast> {
|
||||
self.items.iter()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.items.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::thread::sleep;
|
||||
|
||||
#[test]
|
||||
fn manager_limits_active_toasts() {
|
||||
let mut manager = ToastManager::new();
|
||||
manager.push("first", ToastLevel::Info);
|
||||
manager.push("second", ToastLevel::Warning);
|
||||
manager.push("third", ToastLevel::Success);
|
||||
manager.push("fourth", ToastLevel::Error);
|
||||
|
||||
let collected: Vec<_> = manager.iter().map(|toast| toast.message.clone()).collect();
|
||||
assert_eq!(collected.len(), 3);
|
||||
assert_eq!(collected[0], "fourth");
|
||||
assert_eq!(collected[2], "second");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_expires_toasts_after_lifetime() {
|
||||
let mut manager = ToastManager::new().with_lifetime(Duration::from_millis(1));
|
||||
manager.push("short lived", ToastLevel::Info);
|
||||
assert!(!manager.is_empty());
|
||||
sleep(Duration::from_millis(5));
|
||||
manager.retain_active();
|
||||
assert!(manager.is_empty());
|
||||
}
|
||||
}
|
||||
@@ -16,10 +16,12 @@ use crate::state::{
|
||||
CodePane, EditorTab, FileFilterMode, FileNode, LayoutNode, PaletteGroup, PaneId,
|
||||
RepoSearchRowKind, SplitAxis, VisibleFileEntry,
|
||||
};
|
||||
use crate::toast::{Toast, ToastLevel};
|
||||
use owlen_core::model::DetailedModelInfo;
|
||||
use owlen_core::theme::Theme;
|
||||
use owlen_core::types::{ModelInfo, Role};
|
||||
use owlen_core::ui::{FocusedPanel, InputMode, RoleLabelDisplay};
|
||||
use textwrap::wrap;
|
||||
|
||||
const PRIVACY_TAB_INDEX: usize = HELP_TAB_COUNT - 1;
|
||||
|
||||
@@ -331,6 +333,113 @@ pub fn render_chat(frame: &mut Frame<'_>, app: &mut ChatApp) {
|
||||
if let Some(area) = code_area {
|
||||
render_code_workspace(frame, area, app);
|
||||
}
|
||||
|
||||
render_toasts(frame, app, full_area);
|
||||
}
|
||||
|
||||
fn toast_palette(level: ToastLevel, theme: &Theme) -> (&'static str, Style, Style) {
|
||||
let (label, color) = match level {
|
||||
ToastLevel::Info => ("INFO", theme.info),
|
||||
ToastLevel::Success => ("OK", theme.agent_badge_idle_bg),
|
||||
ToastLevel::Warning => ("WARN", theme.agent_action),
|
||||
ToastLevel::Error => ("ERROR", theme.error),
|
||||
};
|
||||
|
||||
let badge_style = Style::default()
|
||||
.fg(theme.background)
|
||||
.bg(color)
|
||||
.add_modifier(Modifier::BOLD);
|
||||
let border_style = Style::default().fg(color);
|
||||
(label, badge_style, border_style)
|
||||
}
|
||||
|
||||
fn render_toasts(frame: &mut Frame<'_>, app: &ChatApp, full_area: Rect) {
|
||||
let toasts: Vec<&Toast> = app.toasts().collect();
|
||||
if toasts.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let theme = app.theme();
|
||||
let available_width = usize::from(full_area.width.saturating_sub(2));
|
||||
if available_width == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let max_text_width = toasts
|
||||
.iter()
|
||||
.map(|toast| UnicodeWidthStr::width(toast.message.as_str()))
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
|
||||
let mut width = max_text_width.saturating_add(6); // padding + badge
|
||||
width = width.clamp(14, available_width);
|
||||
width = width.min(48);
|
||||
if width == 0 {
|
||||
return;
|
||||
}
|
||||
let width = width as u16;
|
||||
|
||||
let offset_x = full_area
|
||||
.x
|
||||
.saturating_add(full_area.width.saturating_sub(width + 1));
|
||||
let mut offset_y = full_area.y.saturating_add(1);
|
||||
let frame_bottom = full_area.y.saturating_add(full_area.height);
|
||||
|
||||
for toast in toasts {
|
||||
let (label, badge_style, border_style) = toast_palette(toast.level, theme);
|
||||
let badge_text = format!(" {} ", label);
|
||||
let indent_width = UnicodeWidthStr::width(badge_text.as_str()) + 1;
|
||||
let indent = " ".repeat(indent_width);
|
||||
|
||||
let content_width = width.saturating_sub(4).max(1) as usize;
|
||||
let wrapped_lines = wrap(toast.message.as_str(), content_width);
|
||||
let lines: Vec<String> = if wrapped_lines.is_empty() {
|
||||
vec![String::new()]
|
||||
} else {
|
||||
wrapped_lines
|
||||
.into_iter()
|
||||
.map(|cow| cow.into_owned())
|
||||
.collect()
|
||||
};
|
||||
|
||||
let text_style = Style::default().fg(theme.text);
|
||||
let mut paragraph_lines = Vec::with_capacity(lines.len());
|
||||
if let Some((first, rest)) = lines.split_first() {
|
||||
paragraph_lines.push(Line::from(vec![
|
||||
Span::styled(badge_text.clone(), badge_style),
|
||||
Span::raw(" "),
|
||||
Span::styled(first.clone(), text_style),
|
||||
]));
|
||||
for line in rest {
|
||||
paragraph_lines.push(Line::from(vec![
|
||||
Span::raw(indent.clone()),
|
||||
Span::styled(line.clone(), text_style),
|
||||
]));
|
||||
}
|
||||
}
|
||||
|
||||
let height = (paragraph_lines.len() as u16).saturating_add(2);
|
||||
if offset_y.saturating_add(height) > frame_bottom {
|
||||
break;
|
||||
}
|
||||
|
||||
let area = Rect::new(offset_x, offset_y, width, height);
|
||||
frame.render_widget(Clear, area);
|
||||
let block = Block::default()
|
||||
.borders(Borders::ALL)
|
||||
.border_style(border_style)
|
||||
.style(Style::default().bg(theme.background));
|
||||
let paragraph = Paragraph::new(paragraph_lines)
|
||||
.block(block)
|
||||
.alignment(Alignment::Left)
|
||||
.wrap(Wrap { trim: false });
|
||||
frame.render_widget(paragraph, area);
|
||||
|
||||
offset_y = offset_y.saturating_add(height + 1);
|
||||
if offset_y >= frame_bottom {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
||||
@@ -9,11 +9,21 @@ fn palette_tracks_buffer_and_suggestions() {
|
||||
|
||||
palette.set_buffer("mo");
|
||||
assert_eq!(palette.buffer(), "mo");
|
||||
assert!(palette.suggestions().iter().all(|s| s.starts_with("mo")));
|
||||
assert!(
|
||||
palette
|
||||
.suggestions()
|
||||
.iter()
|
||||
.all(|s| s.value.starts_with("mo"))
|
||||
);
|
||||
|
||||
palette.push_char('d');
|
||||
assert_eq!(palette.buffer(), "mod");
|
||||
assert!(palette.suggestions().iter().all(|s| s.starts_with("mod")));
|
||||
assert!(
|
||||
palette
|
||||
.suggestions()
|
||||
.iter()
|
||||
.all(|s| s.value.starts_with("mod"))
|
||||
);
|
||||
|
||||
palette.pop_char();
|
||||
assert_eq!(palette.buffer(), "mo");
|
||||
|
||||
Reference in New Issue
Block a user