diff --git a/crates/owlen-cli/src/commands/providers.rs b/crates/owlen-cli/src/commands/providers.rs index 6aef1be..67842e0 100644 --- a/crates/owlen-cli/src/commands/providers.rs +++ b/crates/owlen-cli/src/commands/providers.rs @@ -9,6 +9,7 @@ use owlen_core::provider::{ AnnotatedModelInfo, ModelProvider, ProviderManager, ProviderStatus, ProviderType, }; use owlen_core::storage::StorageManager; +use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches}; use owlen_providers::ollama::{OllamaCloudProvider, OllamaLocalProvider}; use owlen_tui::config as tui_config; @@ -35,7 +36,7 @@ pub enum ProvidersCommand { /// Provider identifier to disable. provider: String, }, - /// Enable or disable the web.search tool exposure. + /// Enable or disable the `web_search` tool exposure. Web(WebCommand), } @@ -47,13 +48,13 @@ pub struct ModelsArgs { pub provider: Option, } -/// Arguments for managing the web.search tool exposure. +/// Arguments for managing the `web_search` tool exposure. #[derive(Debug, Args)] pub struct WebCommand { - /// Enable the web.search tool and allow remote lookups. + /// Enable the `web_search` tool and allow remote lookups. #[arg(long, conflicts_with = "disable")] enable: bool, - /// Disable the web.search tool to keep sessions local-only. + /// Disable the `web_search` tool to keep sessions local-only. #[arg(long, conflicts_with = "enable")] disable: bool, } @@ -281,14 +282,16 @@ fn apply_web_toggle(config: &mut Config, enabled: bool) { config.tools.web_search.enabled = enabled; config.privacy.enable_remote_search = enabled; - if enabled - && !config + config + .security + .allowed_tools + .retain(|tool| !tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)); + + if enabled { + config .security .allowed_tools - .iter() - .any(|tool| tool.eq_ignore_ascii_case("web_search")) - { - config.security.allowed_tools.push("web_search".to_string()); + .push(WEB_SEARCH_TOOL_NAME.to_string()); } } @@ -760,7 +763,7 @@ mod tests { .security .allowed_tools .iter() - .filter(|tool| tool.eq_ignore_ascii_case("web_search")) + .filter(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)) .count() ); @@ -775,8 +778,11 @@ mod tests { config .security .allowed_tools - .retain(|tool| !tool.eq_ignore_ascii_case("web_search")); - config.security.allowed_tools.push("web_search".to_string()); + .retain(|tool| !tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)); + config + .security + .allowed_tools + .push(WEB_SEARCH_TOOL_NAME.to_string()); apply_web_toggle(&mut config, true); apply_web_toggle(&mut config, true); @@ -787,7 +793,7 @@ mod tests { .security .allowed_tools .iter() - .filter(|tool| tool.eq_ignore_ascii_case("web_search")) + .filter(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)) .count() ); } diff --git a/crates/owlen-cli/tests/agent_tests.rs b/crates/owlen-cli/tests/agent_tests.rs index 5e6a726..41c4cc8 100644 --- a/crates/owlen-cli/tests/agent_tests.rs +++ b/crates/owlen-cli/tests/agent_tests.rs @@ -9,6 +9,7 @@ use owlen_cli::agent::{AgentConfig, AgentExecutor, LlmResponse}; use owlen_core::mcp::remote_client::RemoteMcpClient; +use owlen_core::tools::WEB_SEARCH_TOOL_NAME; use std::sync::Arc; #[tokio::test] @@ -27,7 +28,7 @@ async fn test_react_parsing_tool_call() { arguments, }) => { assert_eq!(thought, "I should search for information"); - assert_eq!(tool_name, "web_search"); + assert_eq!(tool_name.as_str(), WEB_SEARCH_TOOL_NAME); assert_eq!(arguments["query"], "rust async programming"); } other => panic!("Expected ToolCall, got: {:?}", other), diff --git a/crates/owlen-core/src/agent.rs b/crates/owlen-core/src/agent.rs index 0e455d7..382facd 100644 --- a/crates/owlen-core/src/agent.rs +++ b/crates/owlen-core/src/agent.rs @@ -366,6 +366,7 @@ mod tests { use super::*; use crate::llm::test_utils::MockProvider; use crate::mcp::test_utils::MockMcpClient; + use crate::tools::WEB_SEARCH_TOOL_NAME; #[test] fn test_parse_tool_call() { @@ -389,7 +390,7 @@ ACTION_INPUT: {"query": "Rust programming language"} arguments, } => { assert!(thought.contains("search for information")); - assert_eq!(tool_name, "web_search"); + assert!(matches!(tool_name.as_str(), WEB_SEARCH_TOOL_NAME)); assert_eq!(arguments["query"], "Rust programming language"); } _ => panic!("Expected ToolCall"), diff --git a/crates/owlen-core/src/config.rs b/crates/owlen-core/src/config.rs index 3fcf246..e2b385a 100644 --- a/crates/owlen-core/src/config.rs +++ b/crates/owlen-core/src/config.rs @@ -2,6 +2,7 @@ use crate::Error; use crate::ProviderConfig; use crate::Result; use crate::mode::ModeConfig; +use crate::tools::WEB_SEARCH_TOOL_NAME; use crate::ui::RoleLabelDisplay; use serde::de::{self, Deserializer, Visitor}; use serde::{Deserialize, Serialize}; @@ -328,6 +329,11 @@ impl Config { } } + /// Generate MCP server configurations that mirror the Codex CLI defaults. + pub fn codex_default_servers() -> Vec { + crate::mcp::codex::codex_connector_configs() + } + /// Persist configuration to disk pub fn save(&self, path: Option<&Path>) -> Result<()> { let mut validator = self.clone(); @@ -1682,7 +1688,7 @@ impl SecuritySettings { fn default_allowed_tools() -> Vec { vec![ - "web_search".to_string(), + WEB_SEARCH_TOOL_NAME.to_string(), "web_scrape".to_string(), "code_exec".to_string(), "file_write".to_string(), diff --git a/crates/owlen-core/src/consent.rs b/crates/owlen-core/src/consent.rs index f851bf9..5347e5f 100644 --- a/crates/owlen-core/src/consent.rs +++ b/crates/owlen-core/src/consent.rs @@ -7,6 +7,7 @@ use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use crate::encryption::VaultHandle; +use crate::tools::canonical_tool_name; #[derive(Clone, Debug)] pub struct ConsentRequest { @@ -94,10 +95,12 @@ impl ConsentManager { data_types: Vec, endpoints: Vec, ) -> Result { + let canonical = canonical_tool_name(tool_name); + // Check if already granted permanently if self .permanent_records - .get(tool_name) + .get(canonical) .is_some_and(|existing| existing.scope == ConsentScope::Permanent) { return Ok(ConsentScope::Permanent); @@ -106,31 +109,31 @@ impl ConsentManager { // Check if granted for session if self .session_records - .get(tool_name) + .get(canonical) .is_some_and(|existing| existing.scope == ConsentScope::Session) { return Ok(ConsentScope::Session); } // Check if request is already pending (prevent duplicate prompts) - if self.pending_requests.contains_key(tool_name) { + if self.pending_requests.contains_key(canonical) { // Wait for the other prompt to complete by returning denied temporarily // The caller should retry after a short delay return Ok(ConsentScope::Denied); } // Mark as pending - self.pending_requests.insert(tool_name.to_string(), ()); + self.pending_requests.insert(canonical.to_string(), ()); // Show consent dialog and get scope let scope = self.show_consent_dialog(tool_name, &data_types, &endpoints)?; // Remove from pending - self.pending_requests.remove(tool_name); + self.pending_requests.remove(canonical); // Create record based on scope let record = ConsentRecord { - tool_name: tool_name.to_string(), + tool_name: canonical.to_string(), scope: scope.clone(), timestamp: Utc::now(), data_types, @@ -140,10 +143,10 @@ impl ConsentManager { // Store in appropriate location match scope { ConsentScope::Permanent => { - self.permanent_records.insert(tool_name.to_string(), record); + self.permanent_records.insert(canonical.to_string(), record); } ConsentScope::Session => { - self.session_records.insert(tool_name.to_string(), record); + self.session_records.insert(canonical.to_string(), record); } ConsentScope::Once | ConsentScope::Denied => { // Don't store, just return the decision @@ -171,8 +174,9 @@ impl ConsentManager { endpoints: Vec, scope: ConsentScope, ) { + let canonical = canonical_tool_name(tool_name); let record = ConsentRecord { - tool_name: tool_name.to_string(), + tool_name: canonical.to_string(), scope: scope.clone(), timestamp: Utc::now(), data_types, @@ -181,13 +185,13 @@ impl ConsentManager { match scope { ConsentScope::Permanent => { - self.permanent_records.insert(tool_name.to_string(), record); + self.permanent_records.insert(canonical.to_string(), record); } ConsentScope::Session => { - self.session_records.insert(tool_name.to_string(), record); + self.session_records.insert(canonical.to_string(), record); } ConsentScope::Once => { - self.once_records.insert(tool_name.to_string(), record); + self.once_records.insert(canonical.to_string(), record); } ConsentScope::Denied => {} // Denied is not stored } @@ -195,28 +199,30 @@ impl ConsentManager { /// Check if consent is needed (returns None if already granted, Some(info) if needed) pub fn check_consent_needed(&self, tool_name: &str) -> Option { - if self.has_consent(tool_name) { + let canonical = canonical_tool_name(tool_name); + if self.has_consent(canonical) { None } else { Some(ConsentRequest { - tool_name: tool_name.to_string(), + tool_name: canonical.to_string(), }) } } pub fn has_consent(&self, tool_name: &str) -> bool { + let canonical = canonical_tool_name(tool_name); // Check permanent first, then session, then once self.permanent_records - .get(tool_name) + .get(canonical) .map(|r| r.scope == ConsentScope::Permanent) .or_else(|| { self.session_records - .get(tool_name) + .get(canonical) .map(|r| r.scope == ConsentScope::Session) }) .or_else(|| { self.once_records - .get(tool_name) + .get(canonical) .map(|r| r.scope == ConsentScope::Once) }) .unwrap_or(false) @@ -224,13 +230,15 @@ impl ConsentManager { /// Consume "once" consent for a tool (clears it after first use) pub fn consume_once_consent(&mut self, tool_name: &str) { - self.once_records.remove(tool_name); + let canonical = canonical_tool_name(tool_name); + self.once_records.remove(canonical); } pub fn revoke_consent(&mut self, tool_name: &str) { - self.permanent_records.remove(tool_name); - self.session_records.remove(tool_name); - self.once_records.remove(tool_name); + let canonical = canonical_tool_name(tool_name); + self.permanent_records.remove(canonical); + self.session_records.remove(canonical); + self.once_records.remove(canonical); } pub fn clear_all_consent(&mut self) { @@ -253,10 +261,11 @@ impl ConsentManager { data_types: Vec, endpoints: Vec, ) -> Option<(String, Vec, Vec)> { - if self.has_consent(tool_name) { + let canonical = canonical_tool_name(tool_name); + if self.has_consent(canonical) { return None; } - Some((tool_name.to_string(), data_types, endpoints)) + Some((canonical.to_string(), data_types, endpoints)) } fn show_consent_dialog( diff --git a/crates/owlen-core/src/session.rs b/crates/owlen-core/src/session.rs index f86f1fa..c473b6f 100644 --- a/crates/owlen-core/src/session.rs +++ b/crates/owlen-core/src/session.rs @@ -19,6 +19,7 @@ use crate::model::{DetailedModelInfo, ModelManager}; use crate::oauth::{DeviceAuthorization, DevicePollState, OAuthClient}; use crate::providers::OllamaProvider; use crate::storage::{SessionMeta, StorageManager}; +use crate::tools::{WEB_SEARCH_TOOL_NAME, canonical_tool_name, tool_name_matches}; use crate::types::{ ChatParameters, ChatRequest, ChatResponse, Conversation, Message, ModelInfo, ToolCall, }; @@ -407,7 +408,7 @@ async fn build_tools( .security .allowed_tools .iter() - .any(|tool| tool == "web_search") + .any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)) && config_guard.tools.web_search.enabled && config_guard.privacy.enable_remote_search { @@ -424,7 +425,7 @@ async fn build_tools( if let Some(settings) = web_search_settings { let tool = WebSearchTool::new(consent_manager.clone(), settings); - registry.register(tool); + registry.register(tool)?; } // Register web_scrape tool if allowed. @@ -432,12 +433,12 @@ async fn build_tools( .security .allowed_tools .iter() - .any(|tool| tool == "web_scrape") + .any(|tool| tool_name_matches(tool, "web_scrape")) && config_guard.tools.web_search.enabled // reuse web_search toggle for simplicity && config_guard.privacy.enable_remote_search { let tool = WebScrapeTool::new(); - registry.register(tool); + registry.register(tool)?; } if enable_code_tools @@ -449,11 +450,11 @@ async fn build_tools( && config_guard.tools.code_exec.enabled { let tool = CodeExecTool::new(config_guard.tools.code_exec.allowed_languages.clone()); - registry.register(tool); + registry.register(tool)?; } - registry.register(ResourcesListTool); - registry.register(ResourcesGetTool); + registry.register(ResourcesListTool)?; + registry.register(ResourcesGetTool)?; if config_guard .security @@ -461,7 +462,7 @@ async fn build_tools( .iter() .any(|t| t == "file_write") { - registry.register(ResourcesWriteTool); + registry.register(ResourcesWriteTool)?; } if config_guard .security @@ -469,7 +470,7 @@ async fn build_tools( .iter() .any(|t| t == "file_delete") { - registry.register(ResourcesDeleteTool); + registry.register(ResourcesDeleteTool)?; } for tool in registry.all() { @@ -1023,13 +1024,14 @@ impl SessionController { let mut seen_tools = std::collections::HashSet::new(); for tool_call in tool_calls { - if seen_tools.contains(&tool_call.name) { + let canonical = canonical_tool_name(tool_call.name.as_str()).to_string(); + if seen_tools.contains(&canonical) { continue; } - seen_tools.insert(tool_call.name.clone()); + seen_tools.insert(canonical.clone()); - let (data_types, endpoints) = match tool_call.name.as_str() { - "web_search" => ( + let (data_types, endpoints) = match canonical.as_str() { + WEB_SEARCH_TOOL_NAME => ( vec!["search query".to_string()], vec!["cloud provider".to_string()], ), @@ -1097,13 +1099,14 @@ impl SessionController { pub async fn set_tool_enabled(&mut self, tool: &str, enabled: bool) -> Result<()> { { let mut config = self.config.lock().await; - match tool { - "web_search" => { + let canonical = canonical_tool_name(tool); + match canonical { + WEB_SEARCH_TOOL_NAME => { config.tools.web_search.enabled = enabled; config.privacy.enable_remote_search = enabled; } "code_exec" => config.tools.code_exec.enabled = enabled, - other => return Err(Error::InvalidInput(format!("Unknown tool: {other}"))), + _ => return Err(Error::InvalidInput(format!("Unknown tool: {tool}"))), } } self.rebuild_tools().await @@ -1897,12 +1900,12 @@ mod tests { #[test] fn streaming_state_detects_tool_call_changes() { let mut state = StreamingMessageState::new(); - let tool = make_tool_call("call-1", "web.search"); + let tool = make_tool_call("call-1", "web_search"); let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false)); let calls = diff.tool_calls.expect("initial tool call"); assert_eq!(calls.len(), 1); - assert_eq!(calls[0].name, "web.search"); + assert_eq!(calls[0].name, "web_search"); let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false)); assert!( diff --git a/crates/owlen-core/src/tools.rs b/crates/owlen-core/src/tools.rs index 1069c66..46b2fb4 100644 --- a/crates/owlen-core/src/tools.rs +++ b/crates/owlen-core/src/tools.rs @@ -12,12 +12,64 @@ pub mod web_scrape; pub mod web_search; use async_trait::async_trait; +use once_cell::sync::Lazy; +use regex::Regex; use serde_json::{Value, json}; use std::collections::HashMap; use std::time::Duration; use crate::Result; +/// MCP mandates tool identifiers to match `^[A-Za-z0-9_-]{1,64}$`. +pub const MAX_TOOL_IDENTIFIER_LEN: usize = 64; + +static TOOL_IDENTIFIER_RE: Lazy = + Lazy::new(|| Regex::new(r"^[A-Za-z0-9_-]{1,64}$").expect("valid tool identifier regex")); + +pub const WEB_SEARCH_TOOL_NAME: &str = "web_search"; + +/// Return the canonical identifier for a tool. +pub fn canonical_tool_name(name: &str) -> &str { + name +} + +/// Check whether two tool identifiers refer to the same logical tool. +pub fn tool_name_matches(lhs: &str, rhs: &str) -> bool { + canonical_tool_name(lhs) == canonical_tool_name(rhs) +} + +/// Determine whether the provided identifier satisfies the MCP naming contract. +pub fn is_valid_tool_identifier(name: &str) -> bool { + TOOL_IDENTIFIER_RE.is_match(name) +} + +/// Provide lint-style feedback when a tool identifier falls outside the MCP rules. +pub fn tool_identifier_violation(name: &str) -> Option { + if name.is_empty() { + return Some("Tool identifiers must not be empty.".to_string()); + } + + if name.len() > MAX_TOOL_IDENTIFIER_LEN { + return Some(format!( + "Tool identifier '{name}' exceeds the {MAX_TOOL_IDENTIFIER_LEN}-character MCP limit." + )); + } + + if name.trim() != name { + return Some(format!( + "Tool identifier '{name}' contains leading or trailing whitespace." + )); + } + + if !TOOL_IDENTIFIER_RE.is_match(name) { + return Some(format!( + "Tool identifier '{name}' may only contain ASCII letters, digits, hyphens, or underscores." + )); + } + + None +} + /// Trait representing a tool that can be called via the MCP interface. #[async_trait] pub trait Tool: Send + Sync { @@ -34,6 +86,10 @@ pub trait Tool: Send + Sync { fn requires_filesystem(&self) -> Vec { Vec::new() } + /// Optional additional identifiers (must remain spec-compliant). + fn aliases(&self) -> &'static [&'static str] { + &[] + } async fn execute(&self, args: Value) -> Result; } diff --git a/crates/owlen-core/src/tools/registry.rs b/crates/owlen-core/src/tools/registry.rs index ce05ce4..ff262b9 100644 --- a/crates/owlen-core/src/tools/registry.rs +++ b/crates/owlen-core/src/tools/registry.rs @@ -1,11 +1,13 @@ use std::collections::HashMap; use std::sync::Arc; -use crate::Result; +use crate::{Error, Result}; use anyhow::Context; use serde_json::Value; -use super::{Tool, ToolResult}; +use super::{ + Tool, ToolResult, WEB_SEARCH_TOOL_NAME, canonical_tool_name, tool_identifier_violation, +}; use crate::config::Config; use crate::mode::Mode; use crate::ui::UiController; @@ -25,13 +27,32 @@ impl ToolRegistry { } } - pub fn register(&mut self, tool: T) + pub fn register(&mut self, tool: T) -> Result<()> where T: Tool + 'static, { let tool: Arc = Arc::new(tool); - let name = tool.name().to_string(); - self.tools.insert(name, tool); + let name = tool.name(); + + if let Some(reason) = tool_identifier_violation(name) { + log::error!("Tool '{}' failed validation: {}", name, reason); + return Err(Error::InvalidInput(format!( + "Tool '{name}' is not a valid MCP identifier: {reason}" + ))); + } + + if self + .tools + .insert(name.to_string(), Arc::clone(&tool)) + .is_some() + { + log::warn!( + "Tool '{}' was already registered; overwriting previous entry.", + name + ); + } + + Ok(()) } pub fn get(&self, name: &str) -> Option> { @@ -43,20 +64,25 @@ impl ToolRegistry { } pub async fn execute(&self, name: &str, args: Value, mode: Mode) -> Result { + let canonical = canonical_tool_name(name); let tool = self - .get(name) + .get(canonical) .with_context(|| format!("Tool not registered: {}", name))?; let mut config = self.config.lock().await; // Check mode-based tool availability first - if !config.modes.is_tool_allowed(mode, name) { + if !(config.modes.is_tool_allowed(mode, canonical) + || config.modes.is_tool_allowed(mode, name)) + { let alternate_mode = match mode { Mode::Chat => Mode::Code, Mode::Code => Mode::Chat, }; - if config.modes.is_tool_allowed(alternate_mode, name) { + if config.modes.is_tool_allowed(alternate_mode, canonical) + || config.modes.is_tool_allowed(alternate_mode, name) + { return Ok(ToolResult::error(&format!( "Tool '{}' is not available in {} mode. Switch to {} mode to use this tool (use :mode {} command).", name, mode, alternate_mode, alternate_mode @@ -69,8 +95,8 @@ impl ToolRegistry { } } - let is_enabled = match name { - "web_search" => config.tools.web_search.enabled, + let is_enabled = match canonical { + WEB_SEARCH_TOOL_NAME => config.tools.web_search.enabled, "code_exec" => config.tools.code_exec.enabled, _ => true, // All other tools are considered enabled by default }; @@ -82,8 +108,8 @@ impl ToolRegistry { ); if self.ui.confirm(&prompt).await { // Enable the tool in the in-memory config for the current session - match name { - "web_search" => config.tools.web_search.enabled = true, + match canonical { + WEB_SEARCH_TOOL_NAME => config.tools.web_search.enabled = true, "code_exec" => config.tools.code_exec.enabled = true, _ => {} } @@ -112,3 +138,69 @@ impl ToolRegistry { self.tools.keys().cloned().collect() } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::Config; + use crate::tools::{Tool, ToolResult, WEB_SEARCH_TOOL_NAME}; + use crate::ui::NoOpUiController; + use async_trait::async_trait; + use serde_json::{Value, json}; + use std::sync::Arc; + + struct DummyTool { + name: &'static str, + } + + #[async_trait] + impl Tool for DummyTool { + fn name(&self) -> &'static str { + self.name + } + + fn description(&self) -> &'static str { + "dummy tool" + } + + fn schema(&self) -> Value { + json!({ "type": "object" }) + } + + fn aliases(&self) -> &'static [&'static str] { + self.aliases + } + + async fn execute(&self, _args: Value) -> Result { + Ok(ToolResult::success(json!({ "echo": true }))) + } + } + + fn registry() -> ToolRegistry { + let config = Arc::new(tokio::sync::Mutex::new(Config::default())); + let ui = Arc::new(NoOpUiController); + ToolRegistry::new(config, ui) + } + + #[test] + fn rejects_invalid_tool_identifier() { + let mut registry = registry(); + let tool = DummyTool { + name: "invalid.tool", + }; + + let err = registry.register(tool).unwrap_err(); + assert!(matches!(err, Error::InvalidInput(_))); + } + + #[test] + fn registers_spec_compliant_tool() { + let mut registry = registry(); + let tool = DummyTool { + name: WEB_SEARCH_TOOL_NAME, + }; + + registry.register(tool).unwrap(); + assert!(registry.get(WEB_SEARCH_TOOL_NAME).is_some()); + } +} diff --git a/crates/owlen-core/src/tools/web_search.rs b/crates/owlen-core/src/tools/web_search.rs index f9dc942..8c3d8e3 100644 --- a/crates/owlen-core/src/tools/web_search.rs +++ b/crates/owlen-core/src/tools/web_search.rs @@ -10,6 +10,7 @@ use serde_json::{Value, json}; use super::{Tool, ToolResult}; use crate::consent::ConsentManager; +use crate::tools::WEB_SEARCH_TOOL_NAME; /// Configuration applied to the web search tool at registration time. #[derive(Clone, Debug)] @@ -44,7 +45,7 @@ impl WebSearchTool { #[async_trait] impl Tool for WebSearchTool { fn name(&self) -> &'static str { - "web_search" + WEB_SEARCH_TOOL_NAME } fn description(&self) -> &'static str { diff --git a/crates/owlen-core/src/validation.rs b/crates/owlen-core/src/validation.rs index 3f445fa..24f8d5e 100644 --- a/crates/owlen-core/src/validation.rs +++ b/crates/owlen-core/src/validation.rs @@ -4,6 +4,8 @@ use anyhow::{Context, Result}; use jsonschema::{JSONSchema, ValidationError}; use serde_json::{Value, json}; +use crate::tools::WEB_SEARCH_TOOL_NAME; + pub struct SchemaValidator { schemas: HashMap, } @@ -56,27 +58,26 @@ fn format_validation_error(error: ValidationError) -> String { pub fn get_builtin_schemas() -> HashMap { let mut schemas = HashMap::new(); - schemas.insert( - "web_search".to_string(), - json!({ - "type": "object", - "properties": { - "query": { - "type": "string", - "minLength": 1, - "maxLength": 500 - }, - "max_results": { - "type": "integer", - "minimum": 1, - "maximum": 10, - "default": 5 - } + let web_search_schema = json!({ + "type": "object", + "properties": { + "query": { + "type": "string", + "minLength": 1, + "maxLength": 500 }, - "required": ["query"], - "additionalProperties": false - }), - ); + "max_results": { + "type": "integer", + "minimum": 1, + "maximum": 10, + "default": 5 + } + }, + "required": ["query"], + "additionalProperties": false + }); + + schemas.insert(WEB_SEARCH_TOOL_NAME.to_string(), web_search_schema.clone()); schemas.insert( "code_exec".to_string(), diff --git a/crates/owlen-core/tests/agent_tool_flow.rs b/crates/owlen-core/tests/agent_tool_flow.rs index 29b937b..916cbb1 100644 --- a/crates/owlen-core/tests/agent_tool_flow.rs +++ b/crates/owlen-core/tests/agent_tool_flow.rs @@ -2,6 +2,7 @@ use std::{any::Any, collections::HashMap, sync::Arc}; use async_trait::async_trait; use futures::StreamExt; +use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches}; use owlen_core::{ Config, Error, Mode, Provider, config::McpMode, @@ -88,7 +89,7 @@ impl Provider for StreamingToolProvider { fn tool_descriptor() -> McpToolDescriptor { McpToolDescriptor { - name: "web_search".to_string(), + name: WEB_SEARCH_TOOL_NAME.to_string(), description: "search".to_string(), input_schema: serde_json::json!({"type": "object"}), requires_network: true, @@ -123,7 +124,7 @@ impl CachedResponseClient { metadata.insert("cached".to_string(), "true".to_string()); let response = McpToolResponse { - name: "web_search".to_string(), + name: WEB_SEARCH_TOOL_NAME.to_string(), success: true, output: serde_json::json!({ "query": "rust", @@ -286,13 +287,13 @@ async fn web_tool_timeout_fails_over_to_cached_result() { ]); let call = McpToolCall { - name: "web_search".to_string(), + name: WEB_SEARCH_TOOL_NAME.to_string(), arguments: serde_json::json!({ "query": "rust", "max_results": 3 }), }; let response = client.call_tool(call.clone()).await.expect("fallback"); - assert_eq!(response.name, "web_search"); + assert!(tool_name_matches(&response.name, WEB_SEARCH_TOOL_NAME)); assert_eq!( response.metadata.get("source").map(String::as_str), Some("cache") diff --git a/crates/owlen-core/tests/consent_scope.rs b/crates/owlen-core/tests/consent_scope.rs index 6ee4b36..fa93a6a 100644 --- a/crates/owlen-core/tests/consent_scope.rs +++ b/crates/owlen-core/tests/consent_scope.rs @@ -1,4 +1,5 @@ use owlen_core::consent::{ConsentManager, ConsentScope}; +use owlen_core::tools::WEB_SEARCH_TOOL_NAME; #[test] fn test_consent_scopes() { @@ -43,23 +44,23 @@ fn test_pending_requests_prevents_duplicates() { // In real usage, multiple threads would call request_consent simultaneously // First, verify a tool has no consent - assert!(!manager.has_consent("web_search")); + assert!(!manager.has_consent(WEB_SEARCH_TOOL_NAME)); // The pending_requests map is private, but we can test the behavior // by checking that consent checks work correctly - assert!(manager.check_consent_needed("web_search").is_some()); + assert!(manager.check_consent_needed(WEB_SEARCH_TOOL_NAME).is_some()); // Grant session consent manager.grant_consent_with_scope( - "web_search", + WEB_SEARCH_TOOL_NAME, vec!["search queries".to_string()], vec!["https://api.search.com".to_string()], ConsentScope::Session, ); // Now it should have consent - assert!(manager.has_consent("web_search")); - assert!(manager.check_consent_needed("web_search").is_none()); + assert!(manager.has_consent(WEB_SEARCH_TOOL_NAME)); + assert!(manager.check_consent_needed(WEB_SEARCH_TOOL_NAME).is_none()); } #[test] diff --git a/crates/owlen-core/tests/web_search_toggle.rs b/crates/owlen-core/tests/web_search_toggle.rs index d8c465a..5a80355 100644 --- a/crates/owlen-core/tests/web_search_toggle.rs +++ b/crates/owlen-core/tests/web_search_toggle.rs @@ -2,6 +2,7 @@ use std::{any::Any, collections::HashMap, sync::Arc}; use async_trait::async_trait; use futures::stream; +use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches}; use owlen_core::{ ChatStream, Provider, Result, config::Config, @@ -96,12 +97,12 @@ async fn toggling_web_search_updates_config_and_registry() { .tool_registry() .tools() .iter() - .any(|tool| tool == "web_search"), + .any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)), "web_search should be disabled by default" ); session - .set_tool_enabled("web_search", true) + .set_tool_enabled(WEB_SEARCH_TOOL_NAME, true) .await .expect("enable web_search"); @@ -115,12 +116,12 @@ async fn toggling_web_search_updates_config_and_registry() { .tool_registry() .tools() .iter() - .any(|tool| tool == "web_search"), + .any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)), "web_search should be registered when enabled" ); session - .set_tool_enabled("web_search", false) + .set_tool_enabled(WEB_SEARCH_TOOL_NAME, false) .await .expect("disable web_search"); @@ -134,7 +135,7 @@ async fn toggling_web_search_updates_config_and_registry() { .tool_registry() .tools() .iter() - .any(|tool| tool == "web_search"), + .any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)), "web_search should be removed when disabled" ); }