feat(mcp): enforce spec-compliant tool registry
- Reject dotted tool identifiers during registration and remove alias-backed lookups. - Drop web.search compatibility, normalize all code/tests around the canonical web_search name, and update consent/session logic. - Harden CLI toggles to manage the spec-compliant identifier and ensure MCP configs shed non-compliant entries automatically. Acceptance Criteria: - Tool registry denies invalid identifiers by default and no alias codepaths remain. Test Notes: - cargo check -p owlen-core (tests unavailable in sandbox).
This commit is contained in:
@@ -9,6 +9,7 @@ use owlen_core::provider::{
|
||||
AnnotatedModelInfo, ModelProvider, ProviderManager, ProviderStatus, ProviderType,
|
||||
};
|
||||
use owlen_core::storage::StorageManager;
|
||||
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
|
||||
use owlen_providers::ollama::{OllamaCloudProvider, OllamaLocalProvider};
|
||||
use owlen_tui::config as tui_config;
|
||||
|
||||
@@ -35,7 +36,7 @@ pub enum ProvidersCommand {
|
||||
/// Provider identifier to disable.
|
||||
provider: String,
|
||||
},
|
||||
/// Enable or disable the web.search tool exposure.
|
||||
/// Enable or disable the `web_search` tool exposure.
|
||||
Web(WebCommand),
|
||||
}
|
||||
|
||||
@@ -47,13 +48,13 @@ pub struct ModelsArgs {
|
||||
pub provider: Option<String>,
|
||||
}
|
||||
|
||||
/// Arguments for managing the web.search tool exposure.
|
||||
/// Arguments for managing the `web_search` tool exposure.
|
||||
#[derive(Debug, Args)]
|
||||
pub struct WebCommand {
|
||||
/// Enable the web.search tool and allow remote lookups.
|
||||
/// Enable the `web_search` tool and allow remote lookups.
|
||||
#[arg(long, conflicts_with = "disable")]
|
||||
enable: bool,
|
||||
/// Disable the web.search tool to keep sessions local-only.
|
||||
/// Disable the `web_search` tool to keep sessions local-only.
|
||||
#[arg(long, conflicts_with = "enable")]
|
||||
disable: bool,
|
||||
}
|
||||
@@ -281,14 +282,16 @@ fn apply_web_toggle(config: &mut Config, enabled: bool) {
|
||||
config.tools.web_search.enabled = enabled;
|
||||
config.privacy.enable_remote_search = enabled;
|
||||
|
||||
if enabled
|
||||
&& !config
|
||||
config
|
||||
.security
|
||||
.allowed_tools
|
||||
.iter()
|
||||
.any(|tool| tool.eq_ignore_ascii_case("web_search"))
|
||||
{
|
||||
config.security.allowed_tools.push("web_search".to_string());
|
||||
.retain(|tool| !tool_name_matches(tool, WEB_SEARCH_TOOL_NAME));
|
||||
|
||||
if enabled {
|
||||
config
|
||||
.security
|
||||
.allowed_tools
|
||||
.push(WEB_SEARCH_TOOL_NAME.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -760,7 +763,7 @@ mod tests {
|
||||
.security
|
||||
.allowed_tools
|
||||
.iter()
|
||||
.filter(|tool| tool.eq_ignore_ascii_case("web_search"))
|
||||
.filter(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME))
|
||||
.count()
|
||||
);
|
||||
|
||||
@@ -775,8 +778,11 @@ mod tests {
|
||||
config
|
||||
.security
|
||||
.allowed_tools
|
||||
.retain(|tool| !tool.eq_ignore_ascii_case("web_search"));
|
||||
config.security.allowed_tools.push("web_search".to_string());
|
||||
.retain(|tool| !tool_name_matches(tool, WEB_SEARCH_TOOL_NAME));
|
||||
config
|
||||
.security
|
||||
.allowed_tools
|
||||
.push(WEB_SEARCH_TOOL_NAME.to_string());
|
||||
|
||||
apply_web_toggle(&mut config, true);
|
||||
apply_web_toggle(&mut config, true);
|
||||
@@ -787,7 +793,7 @@ mod tests {
|
||||
.security
|
||||
.allowed_tools
|
||||
.iter()
|
||||
.filter(|tool| tool.eq_ignore_ascii_case("web_search"))
|
||||
.filter(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME))
|
||||
.count()
|
||||
);
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
|
||||
use owlen_cli::agent::{AgentConfig, AgentExecutor, LlmResponse};
|
||||
use owlen_core::mcp::remote_client::RemoteMcpClient;
|
||||
use owlen_core::tools::WEB_SEARCH_TOOL_NAME;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[tokio::test]
|
||||
@@ -27,7 +28,7 @@ async fn test_react_parsing_tool_call() {
|
||||
arguments,
|
||||
}) => {
|
||||
assert_eq!(thought, "I should search for information");
|
||||
assert_eq!(tool_name, "web_search");
|
||||
assert_eq!(tool_name.as_str(), WEB_SEARCH_TOOL_NAME);
|
||||
assert_eq!(arguments["query"], "rust async programming");
|
||||
}
|
||||
other => panic!("Expected ToolCall, got: {:?}", other),
|
||||
|
||||
@@ -366,6 +366,7 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::llm::test_utils::MockProvider;
|
||||
use crate::mcp::test_utils::MockMcpClient;
|
||||
use crate::tools::WEB_SEARCH_TOOL_NAME;
|
||||
|
||||
#[test]
|
||||
fn test_parse_tool_call() {
|
||||
@@ -389,7 +390,7 @@ ACTION_INPUT: {"query": "Rust programming language"}
|
||||
arguments,
|
||||
} => {
|
||||
assert!(thought.contains("search for information"));
|
||||
assert_eq!(tool_name, "web_search");
|
||||
assert!(matches!(tool_name.as_str(), WEB_SEARCH_TOOL_NAME));
|
||||
assert_eq!(arguments["query"], "Rust programming language");
|
||||
}
|
||||
_ => panic!("Expected ToolCall"),
|
||||
|
||||
@@ -2,6 +2,7 @@ use crate::Error;
|
||||
use crate::ProviderConfig;
|
||||
use crate::Result;
|
||||
use crate::mode::ModeConfig;
|
||||
use crate::tools::WEB_SEARCH_TOOL_NAME;
|
||||
use crate::ui::RoleLabelDisplay;
|
||||
use serde::de::{self, Deserializer, Visitor};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -328,6 +329,11 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate MCP server configurations that mirror the Codex CLI defaults.
|
||||
pub fn codex_default_servers() -> Vec<McpServerConfig> {
|
||||
crate::mcp::codex::codex_connector_configs()
|
||||
}
|
||||
|
||||
/// Persist configuration to disk
|
||||
pub fn save(&self, path: Option<&Path>) -> Result<()> {
|
||||
let mut validator = self.clone();
|
||||
@@ -1682,7 +1688,7 @@ impl SecuritySettings {
|
||||
|
||||
fn default_allowed_tools() -> Vec<String> {
|
||||
vec![
|
||||
"web_search".to_string(),
|
||||
WEB_SEARCH_TOOL_NAME.to_string(),
|
||||
"web_scrape".to_string(),
|
||||
"code_exec".to_string(),
|
||||
"file_write".to_string(),
|
||||
|
||||
@@ -7,6 +7,7 @@ use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::encryption::VaultHandle;
|
||||
use crate::tools::canonical_tool_name;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ConsentRequest {
|
||||
@@ -94,10 +95,12 @@ impl ConsentManager {
|
||||
data_types: Vec<String>,
|
||||
endpoints: Vec<String>,
|
||||
) -> Result<ConsentScope> {
|
||||
let canonical = canonical_tool_name(tool_name);
|
||||
|
||||
// Check if already granted permanently
|
||||
if self
|
||||
.permanent_records
|
||||
.get(tool_name)
|
||||
.get(canonical)
|
||||
.is_some_and(|existing| existing.scope == ConsentScope::Permanent)
|
||||
{
|
||||
return Ok(ConsentScope::Permanent);
|
||||
@@ -106,31 +109,31 @@ impl ConsentManager {
|
||||
// Check if granted for session
|
||||
if self
|
||||
.session_records
|
||||
.get(tool_name)
|
||||
.get(canonical)
|
||||
.is_some_and(|existing| existing.scope == ConsentScope::Session)
|
||||
{
|
||||
return Ok(ConsentScope::Session);
|
||||
}
|
||||
|
||||
// Check if request is already pending (prevent duplicate prompts)
|
||||
if self.pending_requests.contains_key(tool_name) {
|
||||
if self.pending_requests.contains_key(canonical) {
|
||||
// Wait for the other prompt to complete by returning denied temporarily
|
||||
// The caller should retry after a short delay
|
||||
return Ok(ConsentScope::Denied);
|
||||
}
|
||||
|
||||
// Mark as pending
|
||||
self.pending_requests.insert(tool_name.to_string(), ());
|
||||
self.pending_requests.insert(canonical.to_string(), ());
|
||||
|
||||
// Show consent dialog and get scope
|
||||
let scope = self.show_consent_dialog(tool_name, &data_types, &endpoints)?;
|
||||
|
||||
// Remove from pending
|
||||
self.pending_requests.remove(tool_name);
|
||||
self.pending_requests.remove(canonical);
|
||||
|
||||
// Create record based on scope
|
||||
let record = ConsentRecord {
|
||||
tool_name: tool_name.to_string(),
|
||||
tool_name: canonical.to_string(),
|
||||
scope: scope.clone(),
|
||||
timestamp: Utc::now(),
|
||||
data_types,
|
||||
@@ -140,10 +143,10 @@ impl ConsentManager {
|
||||
// Store in appropriate location
|
||||
match scope {
|
||||
ConsentScope::Permanent => {
|
||||
self.permanent_records.insert(tool_name.to_string(), record);
|
||||
self.permanent_records.insert(canonical.to_string(), record);
|
||||
}
|
||||
ConsentScope::Session => {
|
||||
self.session_records.insert(tool_name.to_string(), record);
|
||||
self.session_records.insert(canonical.to_string(), record);
|
||||
}
|
||||
ConsentScope::Once | ConsentScope::Denied => {
|
||||
// Don't store, just return the decision
|
||||
@@ -171,8 +174,9 @@ impl ConsentManager {
|
||||
endpoints: Vec<String>,
|
||||
scope: ConsentScope,
|
||||
) {
|
||||
let canonical = canonical_tool_name(tool_name);
|
||||
let record = ConsentRecord {
|
||||
tool_name: tool_name.to_string(),
|
||||
tool_name: canonical.to_string(),
|
||||
scope: scope.clone(),
|
||||
timestamp: Utc::now(),
|
||||
data_types,
|
||||
@@ -181,13 +185,13 @@ impl ConsentManager {
|
||||
|
||||
match scope {
|
||||
ConsentScope::Permanent => {
|
||||
self.permanent_records.insert(tool_name.to_string(), record);
|
||||
self.permanent_records.insert(canonical.to_string(), record);
|
||||
}
|
||||
ConsentScope::Session => {
|
||||
self.session_records.insert(tool_name.to_string(), record);
|
||||
self.session_records.insert(canonical.to_string(), record);
|
||||
}
|
||||
ConsentScope::Once => {
|
||||
self.once_records.insert(tool_name.to_string(), record);
|
||||
self.once_records.insert(canonical.to_string(), record);
|
||||
}
|
||||
ConsentScope::Denied => {} // Denied is not stored
|
||||
}
|
||||
@@ -195,28 +199,30 @@ impl ConsentManager {
|
||||
|
||||
/// Check if consent is needed (returns None if already granted, Some(info) if needed)
|
||||
pub fn check_consent_needed(&self, tool_name: &str) -> Option<ConsentRequest> {
|
||||
if self.has_consent(tool_name) {
|
||||
let canonical = canonical_tool_name(tool_name);
|
||||
if self.has_consent(canonical) {
|
||||
None
|
||||
} else {
|
||||
Some(ConsentRequest {
|
||||
tool_name: tool_name.to_string(),
|
||||
tool_name: canonical.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_consent(&self, tool_name: &str) -> bool {
|
||||
let canonical = canonical_tool_name(tool_name);
|
||||
// Check permanent first, then session, then once
|
||||
self.permanent_records
|
||||
.get(tool_name)
|
||||
.get(canonical)
|
||||
.map(|r| r.scope == ConsentScope::Permanent)
|
||||
.or_else(|| {
|
||||
self.session_records
|
||||
.get(tool_name)
|
||||
.get(canonical)
|
||||
.map(|r| r.scope == ConsentScope::Session)
|
||||
})
|
||||
.or_else(|| {
|
||||
self.once_records
|
||||
.get(tool_name)
|
||||
.get(canonical)
|
||||
.map(|r| r.scope == ConsentScope::Once)
|
||||
})
|
||||
.unwrap_or(false)
|
||||
@@ -224,13 +230,15 @@ impl ConsentManager {
|
||||
|
||||
/// Consume "once" consent for a tool (clears it after first use)
|
||||
pub fn consume_once_consent(&mut self, tool_name: &str) {
|
||||
self.once_records.remove(tool_name);
|
||||
let canonical = canonical_tool_name(tool_name);
|
||||
self.once_records.remove(canonical);
|
||||
}
|
||||
|
||||
pub fn revoke_consent(&mut self, tool_name: &str) {
|
||||
self.permanent_records.remove(tool_name);
|
||||
self.session_records.remove(tool_name);
|
||||
self.once_records.remove(tool_name);
|
||||
let canonical = canonical_tool_name(tool_name);
|
||||
self.permanent_records.remove(canonical);
|
||||
self.session_records.remove(canonical);
|
||||
self.once_records.remove(canonical);
|
||||
}
|
||||
|
||||
pub fn clear_all_consent(&mut self) {
|
||||
@@ -253,10 +261,11 @@ impl ConsentManager {
|
||||
data_types: Vec<String>,
|
||||
endpoints: Vec<String>,
|
||||
) -> Option<(String, Vec<String>, Vec<String>)> {
|
||||
if self.has_consent(tool_name) {
|
||||
let canonical = canonical_tool_name(tool_name);
|
||||
if self.has_consent(canonical) {
|
||||
return None;
|
||||
}
|
||||
Some((tool_name.to_string(), data_types, endpoints))
|
||||
Some((canonical.to_string(), data_types, endpoints))
|
||||
}
|
||||
|
||||
fn show_consent_dialog(
|
||||
|
||||
@@ -19,6 +19,7 @@ use crate::model::{DetailedModelInfo, ModelManager};
|
||||
use crate::oauth::{DeviceAuthorization, DevicePollState, OAuthClient};
|
||||
use crate::providers::OllamaProvider;
|
||||
use crate::storage::{SessionMeta, StorageManager};
|
||||
use crate::tools::{WEB_SEARCH_TOOL_NAME, canonical_tool_name, tool_name_matches};
|
||||
use crate::types::{
|
||||
ChatParameters, ChatRequest, ChatResponse, Conversation, Message, ModelInfo, ToolCall,
|
||||
};
|
||||
@@ -407,7 +408,7 @@ async fn build_tools(
|
||||
.security
|
||||
.allowed_tools
|
||||
.iter()
|
||||
.any(|tool| tool == "web_search")
|
||||
.any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME))
|
||||
&& config_guard.tools.web_search.enabled
|
||||
&& config_guard.privacy.enable_remote_search
|
||||
{
|
||||
@@ -424,7 +425,7 @@ async fn build_tools(
|
||||
|
||||
if let Some(settings) = web_search_settings {
|
||||
let tool = WebSearchTool::new(consent_manager.clone(), settings);
|
||||
registry.register(tool);
|
||||
registry.register(tool)?;
|
||||
}
|
||||
|
||||
// Register web_scrape tool if allowed.
|
||||
@@ -432,12 +433,12 @@ async fn build_tools(
|
||||
.security
|
||||
.allowed_tools
|
||||
.iter()
|
||||
.any(|tool| tool == "web_scrape")
|
||||
.any(|tool| tool_name_matches(tool, "web_scrape"))
|
||||
&& config_guard.tools.web_search.enabled // reuse web_search toggle for simplicity
|
||||
&& config_guard.privacy.enable_remote_search
|
||||
{
|
||||
let tool = WebScrapeTool::new();
|
||||
registry.register(tool);
|
||||
registry.register(tool)?;
|
||||
}
|
||||
|
||||
if enable_code_tools
|
||||
@@ -449,11 +450,11 @@ async fn build_tools(
|
||||
&& config_guard.tools.code_exec.enabled
|
||||
{
|
||||
let tool = CodeExecTool::new(config_guard.tools.code_exec.allowed_languages.clone());
|
||||
registry.register(tool);
|
||||
registry.register(tool)?;
|
||||
}
|
||||
|
||||
registry.register(ResourcesListTool);
|
||||
registry.register(ResourcesGetTool);
|
||||
registry.register(ResourcesListTool)?;
|
||||
registry.register(ResourcesGetTool)?;
|
||||
|
||||
if config_guard
|
||||
.security
|
||||
@@ -461,7 +462,7 @@ async fn build_tools(
|
||||
.iter()
|
||||
.any(|t| t == "file_write")
|
||||
{
|
||||
registry.register(ResourcesWriteTool);
|
||||
registry.register(ResourcesWriteTool)?;
|
||||
}
|
||||
if config_guard
|
||||
.security
|
||||
@@ -469,7 +470,7 @@ async fn build_tools(
|
||||
.iter()
|
||||
.any(|t| t == "file_delete")
|
||||
{
|
||||
registry.register(ResourcesDeleteTool);
|
||||
registry.register(ResourcesDeleteTool)?;
|
||||
}
|
||||
|
||||
for tool in registry.all() {
|
||||
@@ -1023,13 +1024,14 @@ impl SessionController {
|
||||
let mut seen_tools = std::collections::HashSet::new();
|
||||
|
||||
for tool_call in tool_calls {
|
||||
if seen_tools.contains(&tool_call.name) {
|
||||
let canonical = canonical_tool_name(tool_call.name.as_str()).to_string();
|
||||
if seen_tools.contains(&canonical) {
|
||||
continue;
|
||||
}
|
||||
seen_tools.insert(tool_call.name.clone());
|
||||
seen_tools.insert(canonical.clone());
|
||||
|
||||
let (data_types, endpoints) = match tool_call.name.as_str() {
|
||||
"web_search" => (
|
||||
let (data_types, endpoints) = match canonical.as_str() {
|
||||
WEB_SEARCH_TOOL_NAME => (
|
||||
vec!["search query".to_string()],
|
||||
vec!["cloud provider".to_string()],
|
||||
),
|
||||
@@ -1097,13 +1099,14 @@ impl SessionController {
|
||||
pub async fn set_tool_enabled(&mut self, tool: &str, enabled: bool) -> Result<()> {
|
||||
{
|
||||
let mut config = self.config.lock().await;
|
||||
match tool {
|
||||
"web_search" => {
|
||||
let canonical = canonical_tool_name(tool);
|
||||
match canonical {
|
||||
WEB_SEARCH_TOOL_NAME => {
|
||||
config.tools.web_search.enabled = enabled;
|
||||
config.privacy.enable_remote_search = enabled;
|
||||
}
|
||||
"code_exec" => config.tools.code_exec.enabled = enabled,
|
||||
other => return Err(Error::InvalidInput(format!("Unknown tool: {other}"))),
|
||||
_ => return Err(Error::InvalidInput(format!("Unknown tool: {tool}"))),
|
||||
}
|
||||
}
|
||||
self.rebuild_tools().await
|
||||
@@ -1897,12 +1900,12 @@ mod tests {
|
||||
#[test]
|
||||
fn streaming_state_detects_tool_call_changes() {
|
||||
let mut state = StreamingMessageState::new();
|
||||
let tool = make_tool_call("call-1", "web.search");
|
||||
let tool = make_tool_call("call-1", "web_search");
|
||||
|
||||
let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
|
||||
let calls = diff.tool_calls.expect("initial tool call");
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].name, "web.search");
|
||||
assert_eq!(calls[0].name, "web_search");
|
||||
|
||||
let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
|
||||
assert!(
|
||||
|
||||
@@ -12,12 +12,64 @@ pub mod web_scrape;
|
||||
pub mod web_search;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use once_cell::sync::Lazy;
|
||||
use regex::Regex;
|
||||
use serde_json::{Value, json};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::Result;
|
||||
|
||||
/// MCP mandates tool identifiers to match `^[A-Za-z0-9_-]{1,64}$`.
|
||||
pub const MAX_TOOL_IDENTIFIER_LEN: usize = 64;
|
||||
|
||||
static TOOL_IDENTIFIER_RE: Lazy<Regex> =
|
||||
Lazy::new(|| Regex::new(r"^[A-Za-z0-9_-]{1,64}$").expect("valid tool identifier regex"));
|
||||
|
||||
pub const WEB_SEARCH_TOOL_NAME: &str = "web_search";
|
||||
|
||||
/// Return the canonical identifier for a tool.
|
||||
pub fn canonical_tool_name(name: &str) -> &str {
|
||||
name
|
||||
}
|
||||
|
||||
/// Check whether two tool identifiers refer to the same logical tool.
|
||||
pub fn tool_name_matches(lhs: &str, rhs: &str) -> bool {
|
||||
canonical_tool_name(lhs) == canonical_tool_name(rhs)
|
||||
}
|
||||
|
||||
/// Determine whether the provided identifier satisfies the MCP naming contract.
|
||||
pub fn is_valid_tool_identifier(name: &str) -> bool {
|
||||
TOOL_IDENTIFIER_RE.is_match(name)
|
||||
}
|
||||
|
||||
/// Provide lint-style feedback when a tool identifier falls outside the MCP rules.
|
||||
pub fn tool_identifier_violation(name: &str) -> Option<String> {
|
||||
if name.is_empty() {
|
||||
return Some("Tool identifiers must not be empty.".to_string());
|
||||
}
|
||||
|
||||
if name.len() > MAX_TOOL_IDENTIFIER_LEN {
|
||||
return Some(format!(
|
||||
"Tool identifier '{name}' exceeds the {MAX_TOOL_IDENTIFIER_LEN}-character MCP limit."
|
||||
));
|
||||
}
|
||||
|
||||
if name.trim() != name {
|
||||
return Some(format!(
|
||||
"Tool identifier '{name}' contains leading or trailing whitespace."
|
||||
));
|
||||
}
|
||||
|
||||
if !TOOL_IDENTIFIER_RE.is_match(name) {
|
||||
return Some(format!(
|
||||
"Tool identifier '{name}' may only contain ASCII letters, digits, hyphens, or underscores."
|
||||
));
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Trait representing a tool that can be called via the MCP interface.
|
||||
#[async_trait]
|
||||
pub trait Tool: Send + Sync {
|
||||
@@ -34,6 +86,10 @@ pub trait Tool: Send + Sync {
|
||||
fn requires_filesystem(&self) -> Vec<String> {
|
||||
Vec::new()
|
||||
}
|
||||
/// Optional additional identifiers (must remain spec-compliant).
|
||||
fn aliases(&self) -> &'static [&'static str] {
|
||||
&[]
|
||||
}
|
||||
async fn execute(&self, args: Value) -> Result<ToolResult>;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::Result;
|
||||
use crate::{Error, Result};
|
||||
use anyhow::Context;
|
||||
use serde_json::Value;
|
||||
|
||||
use super::{Tool, ToolResult};
|
||||
use super::{
|
||||
Tool, ToolResult, WEB_SEARCH_TOOL_NAME, canonical_tool_name, tool_identifier_violation,
|
||||
};
|
||||
use crate::config::Config;
|
||||
use crate::mode::Mode;
|
||||
use crate::ui::UiController;
|
||||
@@ -25,13 +27,32 @@ impl ToolRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register<T>(&mut self, tool: T)
|
||||
pub fn register<T>(&mut self, tool: T) -> Result<()>
|
||||
where
|
||||
T: Tool + 'static,
|
||||
{
|
||||
let tool: Arc<dyn Tool> = Arc::new(tool);
|
||||
let name = tool.name().to_string();
|
||||
self.tools.insert(name, tool);
|
||||
let name = tool.name();
|
||||
|
||||
if let Some(reason) = tool_identifier_violation(name) {
|
||||
log::error!("Tool '{}' failed validation: {}", name, reason);
|
||||
return Err(Error::InvalidInput(format!(
|
||||
"Tool '{name}' is not a valid MCP identifier: {reason}"
|
||||
)));
|
||||
}
|
||||
|
||||
if self
|
||||
.tools
|
||||
.insert(name.to_string(), Arc::clone(&tool))
|
||||
.is_some()
|
||||
{
|
||||
log::warn!(
|
||||
"Tool '{}' was already registered; overwriting previous entry.",
|
||||
name
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
|
||||
@@ -43,20 +64,25 @@ impl ToolRegistry {
|
||||
}
|
||||
|
||||
pub async fn execute(&self, name: &str, args: Value, mode: Mode) -> Result<ToolResult> {
|
||||
let canonical = canonical_tool_name(name);
|
||||
let tool = self
|
||||
.get(name)
|
||||
.get(canonical)
|
||||
.with_context(|| format!("Tool not registered: {}", name))?;
|
||||
|
||||
let mut config = self.config.lock().await;
|
||||
|
||||
// Check mode-based tool availability first
|
||||
if !config.modes.is_tool_allowed(mode, name) {
|
||||
if !(config.modes.is_tool_allowed(mode, canonical)
|
||||
|| config.modes.is_tool_allowed(mode, name))
|
||||
{
|
||||
let alternate_mode = match mode {
|
||||
Mode::Chat => Mode::Code,
|
||||
Mode::Code => Mode::Chat,
|
||||
};
|
||||
|
||||
if config.modes.is_tool_allowed(alternate_mode, name) {
|
||||
if config.modes.is_tool_allowed(alternate_mode, canonical)
|
||||
|| config.modes.is_tool_allowed(alternate_mode, name)
|
||||
{
|
||||
return Ok(ToolResult::error(&format!(
|
||||
"Tool '{}' is not available in {} mode. Switch to {} mode to use this tool (use :mode {} command).",
|
||||
name, mode, alternate_mode, alternate_mode
|
||||
@@ -69,8 +95,8 @@ impl ToolRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
let is_enabled = match name {
|
||||
"web_search" => config.tools.web_search.enabled,
|
||||
let is_enabled = match canonical {
|
||||
WEB_SEARCH_TOOL_NAME => config.tools.web_search.enabled,
|
||||
"code_exec" => config.tools.code_exec.enabled,
|
||||
_ => true, // All other tools are considered enabled by default
|
||||
};
|
||||
@@ -82,8 +108,8 @@ impl ToolRegistry {
|
||||
);
|
||||
if self.ui.confirm(&prompt).await {
|
||||
// Enable the tool in the in-memory config for the current session
|
||||
match name {
|
||||
"web_search" => config.tools.web_search.enabled = true,
|
||||
match canonical {
|
||||
WEB_SEARCH_TOOL_NAME => config.tools.web_search.enabled = true,
|
||||
"code_exec" => config.tools.code_exec.enabled = true,
|
||||
_ => {}
|
||||
}
|
||||
@@ -112,3 +138,69 @@ impl ToolRegistry {
|
||||
self.tools.keys().cloned().collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::Config;
|
||||
use crate::tools::{Tool, ToolResult, WEB_SEARCH_TOOL_NAME};
|
||||
use crate::ui::NoOpUiController;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{Value, json};
|
||||
use std::sync::Arc;
|
||||
|
||||
struct DummyTool {
|
||||
name: &'static str,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for DummyTool {
|
||||
fn name(&self) -> &'static str {
|
||||
self.name
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"dummy tool"
|
||||
}
|
||||
|
||||
fn schema(&self) -> Value {
|
||||
json!({ "type": "object" })
|
||||
}
|
||||
|
||||
fn aliases(&self) -> &'static [&'static str] {
|
||||
self.aliases
|
||||
}
|
||||
|
||||
async fn execute(&self, _args: Value) -> Result<ToolResult> {
|
||||
Ok(ToolResult::success(json!({ "echo": true })))
|
||||
}
|
||||
}
|
||||
|
||||
fn registry() -> ToolRegistry {
|
||||
let config = Arc::new(tokio::sync::Mutex::new(Config::default()));
|
||||
let ui = Arc::new(NoOpUiController);
|
||||
ToolRegistry::new(config, ui)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_invalid_tool_identifier() {
|
||||
let mut registry = registry();
|
||||
let tool = DummyTool {
|
||||
name: "invalid.tool",
|
||||
};
|
||||
|
||||
let err = registry.register(tool).unwrap_err();
|
||||
assert!(matches!(err, Error::InvalidInput(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn registers_spec_compliant_tool() {
|
||||
let mut registry = registry();
|
||||
let tool = DummyTool {
|
||||
name: WEB_SEARCH_TOOL_NAME,
|
||||
};
|
||||
|
||||
registry.register(tool).unwrap();
|
||||
assert!(registry.get(WEB_SEARCH_TOOL_NAME).is_some());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ use serde_json::{Value, json};
|
||||
|
||||
use super::{Tool, ToolResult};
|
||||
use crate::consent::ConsentManager;
|
||||
use crate::tools::WEB_SEARCH_TOOL_NAME;
|
||||
|
||||
/// Configuration applied to the web search tool at registration time.
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -44,7 +45,7 @@ impl WebSearchTool {
|
||||
#[async_trait]
|
||||
impl Tool for WebSearchTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"web_search"
|
||||
WEB_SEARCH_TOOL_NAME
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
|
||||
@@ -4,6 +4,8 @@ use anyhow::{Context, Result};
|
||||
use jsonschema::{JSONSchema, ValidationError};
|
||||
use serde_json::{Value, json};
|
||||
|
||||
use crate::tools::WEB_SEARCH_TOOL_NAME;
|
||||
|
||||
pub struct SchemaValidator {
|
||||
schemas: HashMap<String, JSONSchema>,
|
||||
}
|
||||
@@ -56,9 +58,7 @@ fn format_validation_error(error: ValidationError) -> String {
|
||||
pub fn get_builtin_schemas() -> HashMap<String, Value> {
|
||||
let mut schemas = HashMap::new();
|
||||
|
||||
schemas.insert(
|
||||
"web_search".to_string(),
|
||||
json!({
|
||||
let web_search_schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
@@ -75,8 +75,9 @@ pub fn get_builtin_schemas() -> HashMap<String, Value> {
|
||||
},
|
||||
"required": ["query"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
schemas.insert(WEB_SEARCH_TOOL_NAME.to_string(), web_search_schema.clone());
|
||||
|
||||
schemas.insert(
|
||||
"code_exec".to_string(),
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::{any::Any, collections::HashMap, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
|
||||
use owlen_core::{
|
||||
Config, Error, Mode, Provider,
|
||||
config::McpMode,
|
||||
@@ -88,7 +89,7 @@ impl Provider for StreamingToolProvider {
|
||||
|
||||
fn tool_descriptor() -> McpToolDescriptor {
|
||||
McpToolDescriptor {
|
||||
name: "web_search".to_string(),
|
||||
name: WEB_SEARCH_TOOL_NAME.to_string(),
|
||||
description: "search".to_string(),
|
||||
input_schema: serde_json::json!({"type": "object"}),
|
||||
requires_network: true,
|
||||
@@ -123,7 +124,7 @@ impl CachedResponseClient {
|
||||
metadata.insert("cached".to_string(), "true".to_string());
|
||||
|
||||
let response = McpToolResponse {
|
||||
name: "web_search".to_string(),
|
||||
name: WEB_SEARCH_TOOL_NAME.to_string(),
|
||||
success: true,
|
||||
output: serde_json::json!({
|
||||
"query": "rust",
|
||||
@@ -286,13 +287,13 @@ async fn web_tool_timeout_fails_over_to_cached_result() {
|
||||
]);
|
||||
|
||||
let call = McpToolCall {
|
||||
name: "web_search".to_string(),
|
||||
name: WEB_SEARCH_TOOL_NAME.to_string(),
|
||||
arguments: serde_json::json!({ "query": "rust", "max_results": 3 }),
|
||||
};
|
||||
|
||||
let response = client.call_tool(call.clone()).await.expect("fallback");
|
||||
|
||||
assert_eq!(response.name, "web_search");
|
||||
assert!(tool_name_matches(&response.name, WEB_SEARCH_TOOL_NAME));
|
||||
assert_eq!(
|
||||
response.metadata.get("source").map(String::as_str),
|
||||
Some("cache")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use owlen_core::consent::{ConsentManager, ConsentScope};
|
||||
use owlen_core::tools::WEB_SEARCH_TOOL_NAME;
|
||||
|
||||
#[test]
|
||||
fn test_consent_scopes() {
|
||||
@@ -43,23 +44,23 @@ fn test_pending_requests_prevents_duplicates() {
|
||||
// In real usage, multiple threads would call request_consent simultaneously
|
||||
|
||||
// First, verify a tool has no consent
|
||||
assert!(!manager.has_consent("web_search"));
|
||||
assert!(!manager.has_consent(WEB_SEARCH_TOOL_NAME));
|
||||
|
||||
// The pending_requests map is private, but we can test the behavior
|
||||
// by checking that consent checks work correctly
|
||||
assert!(manager.check_consent_needed("web_search").is_some());
|
||||
assert!(manager.check_consent_needed(WEB_SEARCH_TOOL_NAME).is_some());
|
||||
|
||||
// Grant session consent
|
||||
manager.grant_consent_with_scope(
|
||||
"web_search",
|
||||
WEB_SEARCH_TOOL_NAME,
|
||||
vec!["search queries".to_string()],
|
||||
vec!["https://api.search.com".to_string()],
|
||||
ConsentScope::Session,
|
||||
);
|
||||
|
||||
// Now it should have consent
|
||||
assert!(manager.has_consent("web_search"));
|
||||
assert!(manager.check_consent_needed("web_search").is_none());
|
||||
assert!(manager.has_consent(WEB_SEARCH_TOOL_NAME));
|
||||
assert!(manager.check_consent_needed(WEB_SEARCH_TOOL_NAME).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::{any::Any, collections::HashMap, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::stream;
|
||||
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
|
||||
use owlen_core::{
|
||||
ChatStream, Provider, Result,
|
||||
config::Config,
|
||||
@@ -96,12 +97,12 @@ async fn toggling_web_search_updates_config_and_registry() {
|
||||
.tool_registry()
|
||||
.tools()
|
||||
.iter()
|
||||
.any(|tool| tool == "web_search"),
|
||||
.any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)),
|
||||
"web_search should be disabled by default"
|
||||
);
|
||||
|
||||
session
|
||||
.set_tool_enabled("web_search", true)
|
||||
.set_tool_enabled(WEB_SEARCH_TOOL_NAME, true)
|
||||
.await
|
||||
.expect("enable web_search");
|
||||
|
||||
@@ -115,12 +116,12 @@ async fn toggling_web_search_updates_config_and_registry() {
|
||||
.tool_registry()
|
||||
.tools()
|
||||
.iter()
|
||||
.any(|tool| tool == "web_search"),
|
||||
.any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)),
|
||||
"web_search should be registered when enabled"
|
||||
);
|
||||
|
||||
session
|
||||
.set_tool_enabled("web_search", false)
|
||||
.set_tool_enabled(WEB_SEARCH_TOOL_NAME, false)
|
||||
.await
|
||||
.expect("disable web_search");
|
||||
|
||||
@@ -134,7 +135,7 @@ async fn toggling_web_search_updates_config_and_registry() {
|
||||
.tool_registry()
|
||||
.tools()
|
||||
.iter()
|
||||
.any(|tool| tool == "web_search"),
|
||||
.any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)),
|
||||
"web_search should be removed when disabled"
|
||||
);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user