feat(mcp): enforce spec-compliant tool registry

- Reject dotted tool identifiers during registration and remove alias-backed lookups.
- Drop web.search compatibility, normalize all code/tests around the canonical web_search name, and update consent/session logic.
- Harden CLI toggles to manage the spec-compliant identifier and ensure MCP configs shed non-compliant entries automatically.

Acceptance Criteria:
- Tool registry denies invalid identifiers by default and no alias codepaths remain.

Test Notes:
- cargo check -p owlen-core (tests unavailable in sandbox).
This commit is contained in:
2025-10-25 04:48:17 +02:00
parent 6a94373c4f
commit c3a92a092b
13 changed files with 284 additions and 105 deletions

View File

@@ -9,6 +9,7 @@ use owlen_core::provider::{
AnnotatedModelInfo, ModelProvider, ProviderManager, ProviderStatus, ProviderType, AnnotatedModelInfo, ModelProvider, ProviderManager, ProviderStatus, ProviderType,
}; };
use owlen_core::storage::StorageManager; use owlen_core::storage::StorageManager;
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
use owlen_providers::ollama::{OllamaCloudProvider, OllamaLocalProvider}; use owlen_providers::ollama::{OllamaCloudProvider, OllamaLocalProvider};
use owlen_tui::config as tui_config; use owlen_tui::config as tui_config;
@@ -35,7 +36,7 @@ pub enum ProvidersCommand {
/// Provider identifier to disable. /// Provider identifier to disable.
provider: String, provider: String,
}, },
/// Enable or disable the web.search tool exposure. /// Enable or disable the `web_search` tool exposure.
Web(WebCommand), Web(WebCommand),
} }
@@ -47,13 +48,13 @@ pub struct ModelsArgs {
pub provider: Option<String>, pub provider: Option<String>,
} }
/// Arguments for managing the web.search tool exposure. /// Arguments for managing the `web_search` tool exposure.
#[derive(Debug, Args)] #[derive(Debug, Args)]
pub struct WebCommand { pub struct WebCommand {
/// Enable the web.search tool and allow remote lookups. /// Enable the `web_search` tool and allow remote lookups.
#[arg(long, conflicts_with = "disable")] #[arg(long, conflicts_with = "disable")]
enable: bool, enable: bool,
/// Disable the web.search tool to keep sessions local-only. /// Disable the `web_search` tool to keep sessions local-only.
#[arg(long, conflicts_with = "enable")] #[arg(long, conflicts_with = "enable")]
disable: bool, disable: bool,
} }
@@ -281,14 +282,16 @@ fn apply_web_toggle(config: &mut Config, enabled: bool) {
config.tools.web_search.enabled = enabled; config.tools.web_search.enabled = enabled;
config.privacy.enable_remote_search = enabled; config.privacy.enable_remote_search = enabled;
if enabled config
&& !config
.security .security
.allowed_tools .allowed_tools
.iter() .retain(|tool| !tool_name_matches(tool, WEB_SEARCH_TOOL_NAME));
.any(|tool| tool.eq_ignore_ascii_case("web_search"))
{ if enabled {
config.security.allowed_tools.push("web_search".to_string()); config
.security
.allowed_tools
.push(WEB_SEARCH_TOOL_NAME.to_string());
} }
} }
@@ -760,7 +763,7 @@ mod tests {
.security .security
.allowed_tools .allowed_tools
.iter() .iter()
.filter(|tool| tool.eq_ignore_ascii_case("web_search")) .filter(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME))
.count() .count()
); );
@@ -775,8 +778,11 @@ mod tests {
config config
.security .security
.allowed_tools .allowed_tools
.retain(|tool| !tool.eq_ignore_ascii_case("web_search")); .retain(|tool| !tool_name_matches(tool, WEB_SEARCH_TOOL_NAME));
config.security.allowed_tools.push("web_search".to_string()); config
.security
.allowed_tools
.push(WEB_SEARCH_TOOL_NAME.to_string());
apply_web_toggle(&mut config, true); apply_web_toggle(&mut config, true);
apply_web_toggle(&mut config, true); apply_web_toggle(&mut config, true);
@@ -787,7 +793,7 @@ mod tests {
.security .security
.allowed_tools .allowed_tools
.iter() .iter()
.filter(|tool| tool.eq_ignore_ascii_case("web_search")) .filter(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME))
.count() .count()
); );
} }

View File

@@ -9,6 +9,7 @@
use owlen_cli::agent::{AgentConfig, AgentExecutor, LlmResponse}; use owlen_cli::agent::{AgentConfig, AgentExecutor, LlmResponse};
use owlen_core::mcp::remote_client::RemoteMcpClient; use owlen_core::mcp::remote_client::RemoteMcpClient;
use owlen_core::tools::WEB_SEARCH_TOOL_NAME;
use std::sync::Arc; use std::sync::Arc;
#[tokio::test] #[tokio::test]
@@ -27,7 +28,7 @@ async fn test_react_parsing_tool_call() {
arguments, arguments,
}) => { }) => {
assert_eq!(thought, "I should search for information"); assert_eq!(thought, "I should search for information");
assert_eq!(tool_name, "web_search"); assert_eq!(tool_name.as_str(), WEB_SEARCH_TOOL_NAME);
assert_eq!(arguments["query"], "rust async programming"); assert_eq!(arguments["query"], "rust async programming");
} }
other => panic!("Expected ToolCall, got: {:?}", other), other => panic!("Expected ToolCall, got: {:?}", other),

View File

@@ -366,6 +366,7 @@ mod tests {
use super::*; use super::*;
use crate::llm::test_utils::MockProvider; use crate::llm::test_utils::MockProvider;
use crate::mcp::test_utils::MockMcpClient; use crate::mcp::test_utils::MockMcpClient;
use crate::tools::WEB_SEARCH_TOOL_NAME;
#[test] #[test]
fn test_parse_tool_call() { fn test_parse_tool_call() {
@@ -389,7 +390,7 @@ ACTION_INPUT: {"query": "Rust programming language"}
arguments, arguments,
} => { } => {
assert!(thought.contains("search for information")); assert!(thought.contains("search for information"));
assert_eq!(tool_name, "web_search"); assert!(matches!(tool_name.as_str(), WEB_SEARCH_TOOL_NAME));
assert_eq!(arguments["query"], "Rust programming language"); assert_eq!(arguments["query"], "Rust programming language");
} }
_ => panic!("Expected ToolCall"), _ => panic!("Expected ToolCall"),

View File

@@ -2,6 +2,7 @@ use crate::Error;
use crate::ProviderConfig; use crate::ProviderConfig;
use crate::Result; use crate::Result;
use crate::mode::ModeConfig; use crate::mode::ModeConfig;
use crate::tools::WEB_SEARCH_TOOL_NAME;
use crate::ui::RoleLabelDisplay; use crate::ui::RoleLabelDisplay;
use serde::de::{self, Deserializer, Visitor}; use serde::de::{self, Deserializer, Visitor};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -328,6 +329,11 @@ impl Config {
} }
} }
/// Generate MCP server configurations that mirror the Codex CLI defaults.
pub fn codex_default_servers() -> Vec<McpServerConfig> {
crate::mcp::codex::codex_connector_configs()
}
/// Persist configuration to disk /// Persist configuration to disk
pub fn save(&self, path: Option<&Path>) -> Result<()> { pub fn save(&self, path: Option<&Path>) -> Result<()> {
let mut validator = self.clone(); let mut validator = self.clone();
@@ -1682,7 +1688,7 @@ impl SecuritySettings {
fn default_allowed_tools() -> Vec<String> { fn default_allowed_tools() -> Vec<String> {
vec![ vec![
"web_search".to_string(), WEB_SEARCH_TOOL_NAME.to_string(),
"web_scrape".to_string(), "web_scrape".to_string(),
"code_exec".to_string(), "code_exec".to_string(),
"file_write".to_string(), "file_write".to_string(),

View File

@@ -7,6 +7,7 @@ use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::encryption::VaultHandle; use crate::encryption::VaultHandle;
use crate::tools::canonical_tool_name;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct ConsentRequest { pub struct ConsentRequest {
@@ -94,10 +95,12 @@ impl ConsentManager {
data_types: Vec<String>, data_types: Vec<String>,
endpoints: Vec<String>, endpoints: Vec<String>,
) -> Result<ConsentScope> { ) -> Result<ConsentScope> {
let canonical = canonical_tool_name(tool_name);
// Check if already granted permanently // Check if already granted permanently
if self if self
.permanent_records .permanent_records
.get(tool_name) .get(canonical)
.is_some_and(|existing| existing.scope == ConsentScope::Permanent) .is_some_and(|existing| existing.scope == ConsentScope::Permanent)
{ {
return Ok(ConsentScope::Permanent); return Ok(ConsentScope::Permanent);
@@ -106,31 +109,31 @@ impl ConsentManager {
// Check if granted for session // Check if granted for session
if self if self
.session_records .session_records
.get(tool_name) .get(canonical)
.is_some_and(|existing| existing.scope == ConsentScope::Session) .is_some_and(|existing| existing.scope == ConsentScope::Session)
{ {
return Ok(ConsentScope::Session); return Ok(ConsentScope::Session);
} }
// Check if request is already pending (prevent duplicate prompts) // Check if request is already pending (prevent duplicate prompts)
if self.pending_requests.contains_key(tool_name) { if self.pending_requests.contains_key(canonical) {
// Wait for the other prompt to complete by returning denied temporarily // Wait for the other prompt to complete by returning denied temporarily
// The caller should retry after a short delay // The caller should retry after a short delay
return Ok(ConsentScope::Denied); return Ok(ConsentScope::Denied);
} }
// Mark as pending // Mark as pending
self.pending_requests.insert(tool_name.to_string(), ()); self.pending_requests.insert(canonical.to_string(), ());
// Show consent dialog and get scope // Show consent dialog and get scope
let scope = self.show_consent_dialog(tool_name, &data_types, &endpoints)?; let scope = self.show_consent_dialog(tool_name, &data_types, &endpoints)?;
// Remove from pending // Remove from pending
self.pending_requests.remove(tool_name); self.pending_requests.remove(canonical);
// Create record based on scope // Create record based on scope
let record = ConsentRecord { let record = ConsentRecord {
tool_name: tool_name.to_string(), tool_name: canonical.to_string(),
scope: scope.clone(), scope: scope.clone(),
timestamp: Utc::now(), timestamp: Utc::now(),
data_types, data_types,
@@ -140,10 +143,10 @@ impl ConsentManager {
// Store in appropriate location // Store in appropriate location
match scope { match scope {
ConsentScope::Permanent => { ConsentScope::Permanent => {
self.permanent_records.insert(tool_name.to_string(), record); self.permanent_records.insert(canonical.to_string(), record);
} }
ConsentScope::Session => { ConsentScope::Session => {
self.session_records.insert(tool_name.to_string(), record); self.session_records.insert(canonical.to_string(), record);
} }
ConsentScope::Once | ConsentScope::Denied => { ConsentScope::Once | ConsentScope::Denied => {
// Don't store, just return the decision // Don't store, just return the decision
@@ -171,8 +174,9 @@ impl ConsentManager {
endpoints: Vec<String>, endpoints: Vec<String>,
scope: ConsentScope, scope: ConsentScope,
) { ) {
let canonical = canonical_tool_name(tool_name);
let record = ConsentRecord { let record = ConsentRecord {
tool_name: tool_name.to_string(), tool_name: canonical.to_string(),
scope: scope.clone(), scope: scope.clone(),
timestamp: Utc::now(), timestamp: Utc::now(),
data_types, data_types,
@@ -181,13 +185,13 @@ impl ConsentManager {
match scope { match scope {
ConsentScope::Permanent => { ConsentScope::Permanent => {
self.permanent_records.insert(tool_name.to_string(), record); self.permanent_records.insert(canonical.to_string(), record);
} }
ConsentScope::Session => { ConsentScope::Session => {
self.session_records.insert(tool_name.to_string(), record); self.session_records.insert(canonical.to_string(), record);
} }
ConsentScope::Once => { ConsentScope::Once => {
self.once_records.insert(tool_name.to_string(), record); self.once_records.insert(canonical.to_string(), record);
} }
ConsentScope::Denied => {} // Denied is not stored ConsentScope::Denied => {} // Denied is not stored
} }
@@ -195,28 +199,30 @@ impl ConsentManager {
/// Check if consent is needed (returns None if already granted, Some(info) if needed) /// Check if consent is needed (returns None if already granted, Some(info) if needed)
pub fn check_consent_needed(&self, tool_name: &str) -> Option<ConsentRequest> { pub fn check_consent_needed(&self, tool_name: &str) -> Option<ConsentRequest> {
if self.has_consent(tool_name) { let canonical = canonical_tool_name(tool_name);
if self.has_consent(canonical) {
None None
} else { } else {
Some(ConsentRequest { Some(ConsentRequest {
tool_name: tool_name.to_string(), tool_name: canonical.to_string(),
}) })
} }
} }
pub fn has_consent(&self, tool_name: &str) -> bool { pub fn has_consent(&self, tool_name: &str) -> bool {
let canonical = canonical_tool_name(tool_name);
// Check permanent first, then session, then once // Check permanent first, then session, then once
self.permanent_records self.permanent_records
.get(tool_name) .get(canonical)
.map(|r| r.scope == ConsentScope::Permanent) .map(|r| r.scope == ConsentScope::Permanent)
.or_else(|| { .or_else(|| {
self.session_records self.session_records
.get(tool_name) .get(canonical)
.map(|r| r.scope == ConsentScope::Session) .map(|r| r.scope == ConsentScope::Session)
}) })
.or_else(|| { .or_else(|| {
self.once_records self.once_records
.get(tool_name) .get(canonical)
.map(|r| r.scope == ConsentScope::Once) .map(|r| r.scope == ConsentScope::Once)
}) })
.unwrap_or(false) .unwrap_or(false)
@@ -224,13 +230,15 @@ impl ConsentManager {
/// Consume "once" consent for a tool (clears it after first use) /// Consume "once" consent for a tool (clears it after first use)
pub fn consume_once_consent(&mut self, tool_name: &str) { pub fn consume_once_consent(&mut self, tool_name: &str) {
self.once_records.remove(tool_name); let canonical = canonical_tool_name(tool_name);
self.once_records.remove(canonical);
} }
pub fn revoke_consent(&mut self, tool_name: &str) { pub fn revoke_consent(&mut self, tool_name: &str) {
self.permanent_records.remove(tool_name); let canonical = canonical_tool_name(tool_name);
self.session_records.remove(tool_name); self.permanent_records.remove(canonical);
self.once_records.remove(tool_name); self.session_records.remove(canonical);
self.once_records.remove(canonical);
} }
pub fn clear_all_consent(&mut self) { pub fn clear_all_consent(&mut self) {
@@ -253,10 +261,11 @@ impl ConsentManager {
data_types: Vec<String>, data_types: Vec<String>,
endpoints: Vec<String>, endpoints: Vec<String>,
) -> Option<(String, Vec<String>, Vec<String>)> { ) -> Option<(String, Vec<String>, Vec<String>)> {
if self.has_consent(tool_name) { let canonical = canonical_tool_name(tool_name);
if self.has_consent(canonical) {
return None; return None;
} }
Some((tool_name.to_string(), data_types, endpoints)) Some((canonical.to_string(), data_types, endpoints))
} }
fn show_consent_dialog( fn show_consent_dialog(

View File

@@ -19,6 +19,7 @@ use crate::model::{DetailedModelInfo, ModelManager};
use crate::oauth::{DeviceAuthorization, DevicePollState, OAuthClient}; use crate::oauth::{DeviceAuthorization, DevicePollState, OAuthClient};
use crate::providers::OllamaProvider; use crate::providers::OllamaProvider;
use crate::storage::{SessionMeta, StorageManager}; use crate::storage::{SessionMeta, StorageManager};
use crate::tools::{WEB_SEARCH_TOOL_NAME, canonical_tool_name, tool_name_matches};
use crate::types::{ use crate::types::{
ChatParameters, ChatRequest, ChatResponse, Conversation, Message, ModelInfo, ToolCall, ChatParameters, ChatRequest, ChatResponse, Conversation, Message, ModelInfo, ToolCall,
}; };
@@ -407,7 +408,7 @@ async fn build_tools(
.security .security
.allowed_tools .allowed_tools
.iter() .iter()
.any(|tool| tool == "web_search") .any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME))
&& config_guard.tools.web_search.enabled && config_guard.tools.web_search.enabled
&& config_guard.privacy.enable_remote_search && config_guard.privacy.enable_remote_search
{ {
@@ -424,7 +425,7 @@ async fn build_tools(
if let Some(settings) = web_search_settings { if let Some(settings) = web_search_settings {
let tool = WebSearchTool::new(consent_manager.clone(), settings); let tool = WebSearchTool::new(consent_manager.clone(), settings);
registry.register(tool); registry.register(tool)?;
} }
// Register web_scrape tool if allowed. // Register web_scrape tool if allowed.
@@ -432,12 +433,12 @@ async fn build_tools(
.security .security
.allowed_tools .allowed_tools
.iter() .iter()
.any(|tool| tool == "web_scrape") .any(|tool| tool_name_matches(tool, "web_scrape"))
&& config_guard.tools.web_search.enabled // reuse web_search toggle for simplicity && config_guard.tools.web_search.enabled // reuse web_search toggle for simplicity
&& config_guard.privacy.enable_remote_search && config_guard.privacy.enable_remote_search
{ {
let tool = WebScrapeTool::new(); let tool = WebScrapeTool::new();
registry.register(tool); registry.register(tool)?;
} }
if enable_code_tools if enable_code_tools
@@ -449,11 +450,11 @@ async fn build_tools(
&& config_guard.tools.code_exec.enabled && config_guard.tools.code_exec.enabled
{ {
let tool = CodeExecTool::new(config_guard.tools.code_exec.allowed_languages.clone()); let tool = CodeExecTool::new(config_guard.tools.code_exec.allowed_languages.clone());
registry.register(tool); registry.register(tool)?;
} }
registry.register(ResourcesListTool); registry.register(ResourcesListTool)?;
registry.register(ResourcesGetTool); registry.register(ResourcesGetTool)?;
if config_guard if config_guard
.security .security
@@ -461,7 +462,7 @@ async fn build_tools(
.iter() .iter()
.any(|t| t == "file_write") .any(|t| t == "file_write")
{ {
registry.register(ResourcesWriteTool); registry.register(ResourcesWriteTool)?;
} }
if config_guard if config_guard
.security .security
@@ -469,7 +470,7 @@ async fn build_tools(
.iter() .iter()
.any(|t| t == "file_delete") .any(|t| t == "file_delete")
{ {
registry.register(ResourcesDeleteTool); registry.register(ResourcesDeleteTool)?;
} }
for tool in registry.all() { for tool in registry.all() {
@@ -1023,13 +1024,14 @@ impl SessionController {
let mut seen_tools = std::collections::HashSet::new(); let mut seen_tools = std::collections::HashSet::new();
for tool_call in tool_calls { for tool_call in tool_calls {
if seen_tools.contains(&tool_call.name) { let canonical = canonical_tool_name(tool_call.name.as_str()).to_string();
if seen_tools.contains(&canonical) {
continue; continue;
} }
seen_tools.insert(tool_call.name.clone()); seen_tools.insert(canonical.clone());
let (data_types, endpoints) = match tool_call.name.as_str() { let (data_types, endpoints) = match canonical.as_str() {
"web_search" => ( WEB_SEARCH_TOOL_NAME => (
vec!["search query".to_string()], vec!["search query".to_string()],
vec!["cloud provider".to_string()], vec!["cloud provider".to_string()],
), ),
@@ -1097,13 +1099,14 @@ impl SessionController {
pub async fn set_tool_enabled(&mut self, tool: &str, enabled: bool) -> Result<()> { pub async fn set_tool_enabled(&mut self, tool: &str, enabled: bool) -> Result<()> {
{ {
let mut config = self.config.lock().await; let mut config = self.config.lock().await;
match tool { let canonical = canonical_tool_name(tool);
"web_search" => { match canonical {
WEB_SEARCH_TOOL_NAME => {
config.tools.web_search.enabled = enabled; config.tools.web_search.enabled = enabled;
config.privacy.enable_remote_search = enabled; config.privacy.enable_remote_search = enabled;
} }
"code_exec" => config.tools.code_exec.enabled = enabled, "code_exec" => config.tools.code_exec.enabled = enabled,
other => return Err(Error::InvalidInput(format!("Unknown tool: {other}"))), _ => return Err(Error::InvalidInput(format!("Unknown tool: {tool}"))),
} }
} }
self.rebuild_tools().await self.rebuild_tools().await
@@ -1897,12 +1900,12 @@ mod tests {
#[test] #[test]
fn streaming_state_detects_tool_call_changes() { fn streaming_state_detects_tool_call_changes() {
let mut state = StreamingMessageState::new(); let mut state = StreamingMessageState::new();
let tool = make_tool_call("call-1", "web.search"); let tool = make_tool_call("call-1", "web_search");
let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false)); let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
let calls = diff.tool_calls.expect("initial tool call"); let calls = diff.tool_calls.expect("initial tool call");
assert_eq!(calls.len(), 1); assert_eq!(calls.len(), 1);
assert_eq!(calls[0].name, "web.search"); assert_eq!(calls[0].name, "web_search");
let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false)); let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
assert!( assert!(

View File

@@ -12,12 +12,64 @@ pub mod web_scrape;
pub mod web_search; pub mod web_search;
use async_trait::async_trait; use async_trait::async_trait;
use once_cell::sync::Lazy;
use regex::Regex;
use serde_json::{Value, json}; use serde_json::{Value, json};
use std::collections::HashMap; use std::collections::HashMap;
use std::time::Duration; use std::time::Duration;
use crate::Result; use crate::Result;
/// Maximum length of an MCP tool identifier.
///
/// MCP mandates tool identifiers to match `^[A-Za-z0-9_-]{1,64}$`; this
/// constant is the upper bound of that repetition.
pub const MAX_TOOL_IDENTIFIER_LEN: usize = 64;

/// Compiled form of the MCP identifier pattern `^[A-Za-z0-9_-]{1,64}$`.
///
/// The upper bound is interpolated from [`MAX_TOOL_IDENTIFIER_LEN`] instead of
/// being spelled out a second time, so the pattern and the length diagnostics
/// cannot drift apart.
static TOOL_IDENTIFIER_RE: Lazy<Regex> = Lazy::new(|| {
    // `{{` / `}}` are literal braces for `format!`; the result is `{1,64}`.
    Regex::new(&format!(
        r"^[A-Za-z0-9_-]{{1,{MAX_TOOL_IDENTIFIER_LEN}}}$"
    ))
    .expect("valid tool identifier regex")
});

/// Identifier under which the built-in web search tool is registered.
pub const WEB_SEARCH_TOOL_NAME: &str = "web_search";
/// Map a tool identifier to its canonical form.
///
/// This is currently the identity mapping: the input slice is handed back
/// unchanged. Several call sites in this crate pass names through it before
/// comparing them or using them as lookup keys.
pub fn canonical_tool_name(name: &str) -> &str {
    name
}
/// Report whether `lhs` and `rhs` name the same tool.
///
/// Both sides are passed through [`canonical_tool_name`] and the results are
/// compared for exact, case-sensitive string equality.
pub fn tool_name_matches(lhs: &str, rhs: &str) -> bool {
    let left = canonical_tool_name(lhs);
    let right = canonical_tool_name(rhs);
    left == right
}
/// Determine whether the provided identifier satisfies the MCP naming contract.
pub fn is_valid_tool_identifier(name: &str) -> bool {
TOOL_IDENTIFIER_RE.is_match(name)
}
/// Explain why `name` is not a valid MCP tool identifier.
///
/// Returns `None` when the identifier is acceptable; otherwise returns a
/// human-readable message describing the first rule it breaks. Rules are
/// checked in this order: empty, too long, leading/trailing whitespace, and
/// finally disallowed characters.
pub fn tool_identifier_violation(name: &str) -> Option<String> {
    if name.is_empty() {
        return Some("Tool identifiers must not be empty.".to_string());
    }

    // Count characters, not bytes: the message below talks about a character
    // limit, and a short name made of multi-byte characters should be reported
    // for its disallowed characters rather than for its length.
    if name.chars().count() > MAX_TOOL_IDENTIFIER_LEN {
        return Some(format!(
            "Tool identifier '{name}' exceeds the {MAX_TOOL_IDENTIFIER_LEN}-character MCP limit."
        ));
    }

    // Whitespace gets its own message because it is easy to miss when the
    // identifier is printed inside quotes.
    if name.trim() != name {
        return Some(format!(
            "Tool identifier '{name}' contains leading or trailing whitespace."
        ));
    }

    if !TOOL_IDENTIFIER_RE.is_match(name) {
        return Some(format!(
            "Tool identifier '{name}' may only contain ASCII letters, digits, hyphens, or underscores."
        ));
    }

    None
}
/// Trait representing a tool that can be called via the MCP interface. /// Trait representing a tool that can be called via the MCP interface.
#[async_trait] #[async_trait]
pub trait Tool: Send + Sync { pub trait Tool: Send + Sync {
@@ -34,6 +86,10 @@ pub trait Tool: Send + Sync {
fn requires_filesystem(&self) -> Vec<String> { fn requires_filesystem(&self) -> Vec<String> {
Vec::new() Vec::new()
} }
/// Extra identifiers this tool may also be known by.
///
/// The default implementation reports none. Implementors that override it
/// must return only spec-compliant identifiers.
// NOTE(review): the visible registration code keys tools by `name()` only and
// does not appear to read this list — confirm whether the hook is still needed.
fn aliases(&self) -> &'static [&'static str] {
    &[]
}
async fn execute(&self, args: Value) -> Result<ToolResult>; async fn execute(&self, args: Value) -> Result<ToolResult>;
} }

View File

@@ -1,11 +1,13 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use crate::Result; use crate::{Error, Result};
use anyhow::Context; use anyhow::Context;
use serde_json::Value; use serde_json::Value;
use super::{Tool, ToolResult}; use super::{
Tool, ToolResult, WEB_SEARCH_TOOL_NAME, canonical_tool_name, tool_identifier_violation,
};
use crate::config::Config; use crate::config::Config;
use crate::mode::Mode; use crate::mode::Mode;
use crate::ui::UiController; use crate::ui::UiController;
@@ -25,13 +27,32 @@ impl ToolRegistry {
} }
} }
pub fn register<T>(&mut self, tool: T) pub fn register<T>(&mut self, tool: T) -> Result<()>
where where
T: Tool + 'static, T: Tool + 'static,
{ {
let tool: Arc<dyn Tool> = Arc::new(tool); let tool: Arc<dyn Tool> = Arc::new(tool);
let name = tool.name().to_string(); let name = tool.name();
self.tools.insert(name, tool);
if let Some(reason) = tool_identifier_violation(name) {
log::error!("Tool '{}' failed validation: {}", name, reason);
return Err(Error::InvalidInput(format!(
"Tool '{name}' is not a valid MCP identifier: {reason}"
)));
}
if self
.tools
.insert(name.to_string(), Arc::clone(&tool))
.is_some()
{
log::warn!(
"Tool '{}' was already registered; overwriting previous entry.",
name
);
}
Ok(())
} }
pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> { pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
@@ -43,20 +64,25 @@ impl ToolRegistry {
} }
pub async fn execute(&self, name: &str, args: Value, mode: Mode) -> Result<ToolResult> { pub async fn execute(&self, name: &str, args: Value, mode: Mode) -> Result<ToolResult> {
let canonical = canonical_tool_name(name);
let tool = self let tool = self
.get(name) .get(canonical)
.with_context(|| format!("Tool not registered: {}", name))?; .with_context(|| format!("Tool not registered: {}", name))?;
let mut config = self.config.lock().await; let mut config = self.config.lock().await;
// Check mode-based tool availability first // Check mode-based tool availability first
if !config.modes.is_tool_allowed(mode, name) { if !(config.modes.is_tool_allowed(mode, canonical)
|| config.modes.is_tool_allowed(mode, name))
{
let alternate_mode = match mode { let alternate_mode = match mode {
Mode::Chat => Mode::Code, Mode::Chat => Mode::Code,
Mode::Code => Mode::Chat, Mode::Code => Mode::Chat,
}; };
if config.modes.is_tool_allowed(alternate_mode, name) { if config.modes.is_tool_allowed(alternate_mode, canonical)
|| config.modes.is_tool_allowed(alternate_mode, name)
{
return Ok(ToolResult::error(&format!( return Ok(ToolResult::error(&format!(
"Tool '{}' is not available in {} mode. Switch to {} mode to use this tool (use :mode {} command).", "Tool '{}' is not available in {} mode. Switch to {} mode to use this tool (use :mode {} command).",
name, mode, alternate_mode, alternate_mode name, mode, alternate_mode, alternate_mode
@@ -69,8 +95,8 @@ impl ToolRegistry {
} }
} }
let is_enabled = match name { let is_enabled = match canonical {
"web_search" => config.tools.web_search.enabled, WEB_SEARCH_TOOL_NAME => config.tools.web_search.enabled,
"code_exec" => config.tools.code_exec.enabled, "code_exec" => config.tools.code_exec.enabled,
_ => true, // All other tools are considered enabled by default _ => true, // All other tools are considered enabled by default
}; };
@@ -82,8 +108,8 @@ impl ToolRegistry {
); );
if self.ui.confirm(&prompt).await { if self.ui.confirm(&prompt).await {
// Enable the tool in the in-memory config for the current session // Enable the tool in the in-memory config for the current session
match name { match canonical {
"web_search" => config.tools.web_search.enabled = true, WEB_SEARCH_TOOL_NAME => config.tools.web_search.enabled = true,
"code_exec" => config.tools.code_exec.enabled = true, "code_exec" => config.tools.code_exec.enabled = true,
_ => {} _ => {}
} }
@@ -112,3 +138,69 @@ impl ToolRegistry {
self.tools.keys().cloned().collect() self.tools.keys().cloned().collect()
} }
} }
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::tools::{Tool, ToolResult, WEB_SEARCH_TOOL_NAME};
    use crate::ui::NoOpUiController;
    use async_trait::async_trait;
    use serde_json::{Value, json};
    use std::sync::Arc;

    /// Minimal tool whose only configurable property is its identifier.
    struct DummyTool {
        name: &'static str,
    }

    #[async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &'static str {
            self.name
        }

        fn description(&self) -> &'static str {
            "dummy tool"
        }

        fn schema(&self) -> Value {
            json!({ "type": "object" })
        }

        // `aliases` is deliberately not overridden: `DummyTool` has no alias
        // field, so the trait's default (an empty list) is used.

        async fn execute(&self, _args: Value) -> Result<ToolResult> {
            Ok(ToolResult::success(json!({ "echo": true })))
        }
    }

    /// Build an empty registry backed by a default config and a no-op UI.
    fn registry() -> ToolRegistry {
        let config = Arc::new(tokio::sync::Mutex::new(Config::default()));
        let ui = Arc::new(NoOpUiController);
        ToolRegistry::new(config, ui)
    }

    /// A dotted name breaks the MCP identifier rules and must be refused.
    #[test]
    fn rejects_invalid_tool_identifier() {
        let mut registry = registry();
        let tool = DummyTool {
            name: "invalid.tool",
        };

        let err = registry.register(tool).unwrap_err();
        assert!(matches!(err, Error::InvalidInput(_)));
    }

    /// A compliant name registers and can be looked up afterwards.
    #[test]
    fn registers_spec_compliant_tool() {
        let mut registry = registry();
        let tool = DummyTool {
            name: WEB_SEARCH_TOOL_NAME,
        };

        registry.register(tool).unwrap();
        assert!(registry.get(WEB_SEARCH_TOOL_NAME).is_some());
    }
}

View File

@@ -10,6 +10,7 @@ use serde_json::{Value, json};
use super::{Tool, ToolResult}; use super::{Tool, ToolResult};
use crate::consent::ConsentManager; use crate::consent::ConsentManager;
use crate::tools::WEB_SEARCH_TOOL_NAME;
/// Configuration applied to the web search tool at registration time. /// Configuration applied to the web search tool at registration time.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@@ -44,7 +45,7 @@ impl WebSearchTool {
#[async_trait] #[async_trait]
impl Tool for WebSearchTool { impl Tool for WebSearchTool {
fn name(&self) -> &'static str { fn name(&self) -> &'static str {
"web_search" WEB_SEARCH_TOOL_NAME
} }
fn description(&self) -> &'static str { fn description(&self) -> &'static str {

View File

@@ -4,6 +4,8 @@ use anyhow::{Context, Result};
use jsonschema::{JSONSchema, ValidationError}; use jsonschema::{JSONSchema, ValidationError};
use serde_json::{Value, json}; use serde_json::{Value, json};
use crate::tools::WEB_SEARCH_TOOL_NAME;
pub struct SchemaValidator { pub struct SchemaValidator {
schemas: HashMap<String, JSONSchema>, schemas: HashMap<String, JSONSchema>,
} }
@@ -56,9 +58,7 @@ fn format_validation_error(error: ValidationError) -> String {
pub fn get_builtin_schemas() -> HashMap<String, Value> { pub fn get_builtin_schemas() -> HashMap<String, Value> {
let mut schemas = HashMap::new(); let mut schemas = HashMap::new();
schemas.insert( let web_search_schema = json!({
"web_search".to_string(),
json!({
"type": "object", "type": "object",
"properties": { "properties": {
"query": { "query": {
@@ -75,8 +75,9 @@ pub fn get_builtin_schemas() -> HashMap<String, Value> {
}, },
"required": ["query"], "required": ["query"],
"additionalProperties": false "additionalProperties": false
}), });
);
schemas.insert(WEB_SEARCH_TOOL_NAME.to_string(), web_search_schema.clone());
schemas.insert( schemas.insert(
"code_exec".to_string(), "code_exec".to_string(),

View File

@@ -2,6 +2,7 @@ use std::{any::Any, collections::HashMap, sync::Arc};
use async_trait::async_trait; use async_trait::async_trait;
use futures::StreamExt; use futures::StreamExt;
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
use owlen_core::{ use owlen_core::{
Config, Error, Mode, Provider, Config, Error, Mode, Provider,
config::McpMode, config::McpMode,
@@ -88,7 +89,7 @@ impl Provider for StreamingToolProvider {
fn tool_descriptor() -> McpToolDescriptor { fn tool_descriptor() -> McpToolDescriptor {
McpToolDescriptor { McpToolDescriptor {
name: "web_search".to_string(), name: WEB_SEARCH_TOOL_NAME.to_string(),
description: "search".to_string(), description: "search".to_string(),
input_schema: serde_json::json!({"type": "object"}), input_schema: serde_json::json!({"type": "object"}),
requires_network: true, requires_network: true,
@@ -123,7 +124,7 @@ impl CachedResponseClient {
metadata.insert("cached".to_string(), "true".to_string()); metadata.insert("cached".to_string(), "true".to_string());
let response = McpToolResponse { let response = McpToolResponse {
name: "web_search".to_string(), name: WEB_SEARCH_TOOL_NAME.to_string(),
success: true, success: true,
output: serde_json::json!({ output: serde_json::json!({
"query": "rust", "query": "rust",
@@ -286,13 +287,13 @@ async fn web_tool_timeout_fails_over_to_cached_result() {
]); ]);
let call = McpToolCall { let call = McpToolCall {
name: "web_search".to_string(), name: WEB_SEARCH_TOOL_NAME.to_string(),
arguments: serde_json::json!({ "query": "rust", "max_results": 3 }), arguments: serde_json::json!({ "query": "rust", "max_results": 3 }),
}; };
let response = client.call_tool(call.clone()).await.expect("fallback"); let response = client.call_tool(call.clone()).await.expect("fallback");
assert_eq!(response.name, "web_search"); assert!(tool_name_matches(&response.name, WEB_SEARCH_TOOL_NAME));
assert_eq!( assert_eq!(
response.metadata.get("source").map(String::as_str), response.metadata.get("source").map(String::as_str),
Some("cache") Some("cache")

View File

@@ -1,4 +1,5 @@
use owlen_core::consent::{ConsentManager, ConsentScope}; use owlen_core::consent::{ConsentManager, ConsentScope};
use owlen_core::tools::WEB_SEARCH_TOOL_NAME;
#[test] #[test]
fn test_consent_scopes() { fn test_consent_scopes() {
@@ -43,23 +44,23 @@ fn test_pending_requests_prevents_duplicates() {
// In real usage, multiple threads would call request_consent simultaneously // In real usage, multiple threads would call request_consent simultaneously
// First, verify a tool has no consent // First, verify a tool has no consent
assert!(!manager.has_consent("web_search")); assert!(!manager.has_consent(WEB_SEARCH_TOOL_NAME));
// The pending_requests map is private, but we can test the behavior // The pending_requests map is private, but we can test the behavior
// by checking that consent checks work correctly // by checking that consent checks work correctly
assert!(manager.check_consent_needed("web_search").is_some()); assert!(manager.check_consent_needed(WEB_SEARCH_TOOL_NAME).is_some());
// Grant session consent // Grant session consent
manager.grant_consent_with_scope( manager.grant_consent_with_scope(
"web_search", WEB_SEARCH_TOOL_NAME,
vec!["search queries".to_string()], vec!["search queries".to_string()],
vec!["https://api.search.com".to_string()], vec!["https://api.search.com".to_string()],
ConsentScope::Session, ConsentScope::Session,
); );
// Now it should have consent // Now it should have consent
assert!(manager.has_consent("web_search")); assert!(manager.has_consent(WEB_SEARCH_TOOL_NAME));
assert!(manager.check_consent_needed("web_search").is_none()); assert!(manager.check_consent_needed(WEB_SEARCH_TOOL_NAME).is_none());
} }
#[test] #[test]

View File

@@ -2,6 +2,7 @@ use std::{any::Any, collections::HashMap, sync::Arc};
use async_trait::async_trait; use async_trait::async_trait;
use futures::stream; use futures::stream;
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
use owlen_core::{ use owlen_core::{
ChatStream, Provider, Result, ChatStream, Provider, Result,
config::Config, config::Config,
@@ -96,12 +97,12 @@ async fn toggling_web_search_updates_config_and_registry() {
.tool_registry() .tool_registry()
.tools() .tools()
.iter() .iter()
.any(|tool| tool == "web_search"), .any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)),
"web_search should be disabled by default" "web_search should be disabled by default"
); );
session session
.set_tool_enabled("web_search", true) .set_tool_enabled(WEB_SEARCH_TOOL_NAME, true)
.await .await
.expect("enable web_search"); .expect("enable web_search");
@@ -115,12 +116,12 @@ async fn toggling_web_search_updates_config_and_registry() {
.tool_registry() .tool_registry()
.tools() .tools()
.iter() .iter()
.any(|tool| tool == "web_search"), .any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)),
"web_search should be registered when enabled" "web_search should be registered when enabled"
); );
session session
.set_tool_enabled("web_search", false) .set_tool_enabled(WEB_SEARCH_TOOL_NAME, false)
.await .await
.expect("disable web_search"); .expect("disable web_search");
@@ -134,7 +135,7 @@ async fn toggling_web_search_updates_config_and_registry() {
.tool_registry() .tool_registry()
.tools() .tools()
.iter() .iter()
.any(|tool| tool == "web_search"), .any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)),
"web_search should be removed when disabled" "web_search should be removed when disabled"
); );
} }