feat(mcp): enforce spec-compliant tool registry

- Reject dotted tool identifiers during registration and remove alias-backed lookups.
- Drop web.search compatibility, normalize all code/tests around the canonical web_search name, and update consent/session logic.
- Harden CLI toggles to manage the spec-compliant identifier and ensure MCP configs shed non-compliant entries automatically.

Acceptance Criteria:
- Tool registry denies invalid identifiers by default and no alias codepaths remain.

Test Notes:
- cargo check -p owlen-core (tests unavailable in sandbox).
This commit is contained in:
2025-10-25 04:48:17 +02:00
parent 6a94373c4f
commit c3a92a092b
13 changed files with 284 additions and 105 deletions

View File

@@ -9,6 +9,7 @@ use owlen_core::provider::{
AnnotatedModelInfo, ModelProvider, ProviderManager, ProviderStatus, ProviderType,
};
use owlen_core::storage::StorageManager;
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
use owlen_providers::ollama::{OllamaCloudProvider, OllamaLocalProvider};
use owlen_tui::config as tui_config;
@@ -35,7 +36,7 @@ pub enum ProvidersCommand {
/// Provider identifier to disable.
provider: String,
},
/// Enable or disable the web.search tool exposure.
/// Enable or disable the `web_search` tool exposure.
Web(WebCommand),
}
@@ -47,13 +48,13 @@ pub struct ModelsArgs {
pub provider: Option<String>,
}
/// Arguments for managing the web.search tool exposure.
/// Arguments for managing the `web_search` tool exposure.
#[derive(Debug, Args)]
pub struct WebCommand {
/// Enable the web.search tool and allow remote lookups.
/// Enable the `web_search` tool and allow remote lookups.
#[arg(long, conflicts_with = "disable")]
enable: bool,
/// Disable the web.search tool to keep sessions local-only.
/// Disable the `web_search` tool to keep sessions local-only.
#[arg(long, conflicts_with = "enable")]
disable: bool,
}
@@ -281,14 +282,16 @@ fn apply_web_toggle(config: &mut Config, enabled: bool) {
config.tools.web_search.enabled = enabled;
config.privacy.enable_remote_search = enabled;
if enabled
&& !config
config
.security
.allowed_tools
.iter()
.any(|tool| tool.eq_ignore_ascii_case("web_search"))
{
config.security.allowed_tools.push("web_search".to_string());
.retain(|tool| !tool_name_matches(tool, WEB_SEARCH_TOOL_NAME));
if enabled {
config
.security
.allowed_tools
.push(WEB_SEARCH_TOOL_NAME.to_string());
}
}
@@ -760,7 +763,7 @@ mod tests {
.security
.allowed_tools
.iter()
.filter(|tool| tool.eq_ignore_ascii_case("web_search"))
.filter(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME))
.count()
);
@@ -775,8 +778,11 @@ mod tests {
config
.security
.allowed_tools
.retain(|tool| !tool.eq_ignore_ascii_case("web_search"));
config.security.allowed_tools.push("web_search".to_string());
.retain(|tool| !tool_name_matches(tool, WEB_SEARCH_TOOL_NAME));
config
.security
.allowed_tools
.push(WEB_SEARCH_TOOL_NAME.to_string());
apply_web_toggle(&mut config, true);
apply_web_toggle(&mut config, true);
@@ -787,7 +793,7 @@ mod tests {
.security
.allowed_tools
.iter()
.filter(|tool| tool.eq_ignore_ascii_case("web_search"))
.filter(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME))
.count()
);
}

View File

@@ -9,6 +9,7 @@
use owlen_cli::agent::{AgentConfig, AgentExecutor, LlmResponse};
use owlen_core::mcp::remote_client::RemoteMcpClient;
use owlen_core::tools::WEB_SEARCH_TOOL_NAME;
use std::sync::Arc;
#[tokio::test]
@@ -27,7 +28,7 @@ async fn test_react_parsing_tool_call() {
arguments,
}) => {
assert_eq!(thought, "I should search for information");
assert_eq!(tool_name, "web_search");
assert_eq!(tool_name.as_str(), WEB_SEARCH_TOOL_NAME);
assert_eq!(arguments["query"], "rust async programming");
}
other => panic!("Expected ToolCall, got: {:?}", other),

View File

@@ -366,6 +366,7 @@ mod tests {
use super::*;
use crate::llm::test_utils::MockProvider;
use crate::mcp::test_utils::MockMcpClient;
use crate::tools::WEB_SEARCH_TOOL_NAME;
#[test]
fn test_parse_tool_call() {
@@ -389,7 +390,7 @@ ACTION_INPUT: {"query": "Rust programming language"}
arguments,
} => {
assert!(thought.contains("search for information"));
assert_eq!(tool_name, "web_search");
assert!(matches!(tool_name.as_str(), WEB_SEARCH_TOOL_NAME));
assert_eq!(arguments["query"], "Rust programming language");
}
_ => panic!("Expected ToolCall"),

View File

@@ -2,6 +2,7 @@ use crate::Error;
use crate::ProviderConfig;
use crate::Result;
use crate::mode::ModeConfig;
use crate::tools::WEB_SEARCH_TOOL_NAME;
use crate::ui::RoleLabelDisplay;
use serde::de::{self, Deserializer, Visitor};
use serde::{Deserialize, Serialize};
@@ -328,6 +329,11 @@ impl Config {
}
}
/// Generate MCP server configurations that mirror the Codex CLI defaults.
pub fn codex_default_servers() -> Vec<McpServerConfig> {
crate::mcp::codex::codex_connector_configs()
}
/// Persist configuration to disk
pub fn save(&self, path: Option<&Path>) -> Result<()> {
let mut validator = self.clone();
@@ -1682,7 +1688,7 @@ impl SecuritySettings {
fn default_allowed_tools() -> Vec<String> {
vec![
"web_search".to_string(),
WEB_SEARCH_TOOL_NAME.to_string(),
"web_scrape".to_string(),
"code_exec".to_string(),
"file_write".to_string(),

View File

@@ -7,6 +7,7 @@ use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::encryption::VaultHandle;
use crate::tools::canonical_tool_name;
#[derive(Clone, Debug)]
pub struct ConsentRequest {
@@ -94,10 +95,12 @@ impl ConsentManager {
data_types: Vec<String>,
endpoints: Vec<String>,
) -> Result<ConsentScope> {
let canonical = canonical_tool_name(tool_name);
// Check if already granted permanently
if self
.permanent_records
.get(tool_name)
.get(canonical)
.is_some_and(|existing| existing.scope == ConsentScope::Permanent)
{
return Ok(ConsentScope::Permanent);
@@ -106,31 +109,31 @@ impl ConsentManager {
// Check if granted for session
if self
.session_records
.get(tool_name)
.get(canonical)
.is_some_and(|existing| existing.scope == ConsentScope::Session)
{
return Ok(ConsentScope::Session);
}
// Check if request is already pending (prevent duplicate prompts)
if self.pending_requests.contains_key(tool_name) {
if self.pending_requests.contains_key(canonical) {
// Wait for the other prompt to complete by returning denied temporarily
// The caller should retry after a short delay
return Ok(ConsentScope::Denied);
}
// Mark as pending
self.pending_requests.insert(tool_name.to_string(), ());
self.pending_requests.insert(canonical.to_string(), ());
// Show consent dialog and get scope
let scope = self.show_consent_dialog(tool_name, &data_types, &endpoints)?;
// Remove from pending
self.pending_requests.remove(tool_name);
self.pending_requests.remove(canonical);
// Create record based on scope
let record = ConsentRecord {
tool_name: tool_name.to_string(),
tool_name: canonical.to_string(),
scope: scope.clone(),
timestamp: Utc::now(),
data_types,
@@ -140,10 +143,10 @@ impl ConsentManager {
// Store in appropriate location
match scope {
ConsentScope::Permanent => {
self.permanent_records.insert(tool_name.to_string(), record);
self.permanent_records.insert(canonical.to_string(), record);
}
ConsentScope::Session => {
self.session_records.insert(tool_name.to_string(), record);
self.session_records.insert(canonical.to_string(), record);
}
ConsentScope::Once | ConsentScope::Denied => {
// Don't store, just return the decision
@@ -171,8 +174,9 @@ impl ConsentManager {
endpoints: Vec<String>,
scope: ConsentScope,
) {
let canonical = canonical_tool_name(tool_name);
let record = ConsentRecord {
tool_name: tool_name.to_string(),
tool_name: canonical.to_string(),
scope: scope.clone(),
timestamp: Utc::now(),
data_types,
@@ -181,13 +185,13 @@ impl ConsentManager {
match scope {
ConsentScope::Permanent => {
self.permanent_records.insert(tool_name.to_string(), record);
self.permanent_records.insert(canonical.to_string(), record);
}
ConsentScope::Session => {
self.session_records.insert(tool_name.to_string(), record);
self.session_records.insert(canonical.to_string(), record);
}
ConsentScope::Once => {
self.once_records.insert(tool_name.to_string(), record);
self.once_records.insert(canonical.to_string(), record);
}
ConsentScope::Denied => {} // Denied is not stored
}
@@ -195,28 +199,30 @@ impl ConsentManager {
/// Check if consent is needed (returns None if already granted, Some(info) if needed)
pub fn check_consent_needed(&self, tool_name: &str) -> Option<ConsentRequest> {
if self.has_consent(tool_name) {
let canonical = canonical_tool_name(tool_name);
if self.has_consent(canonical) {
None
} else {
Some(ConsentRequest {
tool_name: tool_name.to_string(),
tool_name: canonical.to_string(),
})
}
}
pub fn has_consent(&self, tool_name: &str) -> bool {
let canonical = canonical_tool_name(tool_name);
// Check permanent first, then session, then once
self.permanent_records
.get(tool_name)
.get(canonical)
.map(|r| r.scope == ConsentScope::Permanent)
.or_else(|| {
self.session_records
.get(tool_name)
.get(canonical)
.map(|r| r.scope == ConsentScope::Session)
})
.or_else(|| {
self.once_records
.get(tool_name)
.get(canonical)
.map(|r| r.scope == ConsentScope::Once)
})
.unwrap_or(false)
@@ -224,13 +230,15 @@ impl ConsentManager {
/// Consume "once" consent for a tool (clears it after first use)
pub fn consume_once_consent(&mut self, tool_name: &str) {
self.once_records.remove(tool_name);
let canonical = canonical_tool_name(tool_name);
self.once_records.remove(canonical);
}
pub fn revoke_consent(&mut self, tool_name: &str) {
self.permanent_records.remove(tool_name);
self.session_records.remove(tool_name);
self.once_records.remove(tool_name);
let canonical = canonical_tool_name(tool_name);
self.permanent_records.remove(canonical);
self.session_records.remove(canonical);
self.once_records.remove(canonical);
}
pub fn clear_all_consent(&mut self) {
@@ -253,10 +261,11 @@ impl ConsentManager {
data_types: Vec<String>,
endpoints: Vec<String>,
) -> Option<(String, Vec<String>, Vec<String>)> {
if self.has_consent(tool_name) {
let canonical = canonical_tool_name(tool_name);
if self.has_consent(canonical) {
return None;
}
Some((tool_name.to_string(), data_types, endpoints))
Some((canonical.to_string(), data_types, endpoints))
}
fn show_consent_dialog(

View File

@@ -19,6 +19,7 @@ use crate::model::{DetailedModelInfo, ModelManager};
use crate::oauth::{DeviceAuthorization, DevicePollState, OAuthClient};
use crate::providers::OllamaProvider;
use crate::storage::{SessionMeta, StorageManager};
use crate::tools::{WEB_SEARCH_TOOL_NAME, canonical_tool_name, tool_name_matches};
use crate::types::{
ChatParameters, ChatRequest, ChatResponse, Conversation, Message, ModelInfo, ToolCall,
};
@@ -407,7 +408,7 @@ async fn build_tools(
.security
.allowed_tools
.iter()
.any(|tool| tool == "web_search")
.any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME))
&& config_guard.tools.web_search.enabled
&& config_guard.privacy.enable_remote_search
{
@@ -424,7 +425,7 @@ async fn build_tools(
if let Some(settings) = web_search_settings {
let tool = WebSearchTool::new(consent_manager.clone(), settings);
registry.register(tool);
registry.register(tool)?;
}
// Register web_scrape tool if allowed.
@@ -432,12 +433,12 @@ async fn build_tools(
.security
.allowed_tools
.iter()
.any(|tool| tool == "web_scrape")
.any(|tool| tool_name_matches(tool, "web_scrape"))
&& config_guard.tools.web_search.enabled // reuse web_search toggle for simplicity
&& config_guard.privacy.enable_remote_search
{
let tool = WebScrapeTool::new();
registry.register(tool);
registry.register(tool)?;
}
if enable_code_tools
@@ -449,11 +450,11 @@ async fn build_tools(
&& config_guard.tools.code_exec.enabled
{
let tool = CodeExecTool::new(config_guard.tools.code_exec.allowed_languages.clone());
registry.register(tool);
registry.register(tool)?;
}
registry.register(ResourcesListTool);
registry.register(ResourcesGetTool);
registry.register(ResourcesListTool)?;
registry.register(ResourcesGetTool)?;
if config_guard
.security
@@ -461,7 +462,7 @@ async fn build_tools(
.iter()
.any(|t| t == "file_write")
{
registry.register(ResourcesWriteTool);
registry.register(ResourcesWriteTool)?;
}
if config_guard
.security
@@ -469,7 +470,7 @@ async fn build_tools(
.iter()
.any(|t| t == "file_delete")
{
registry.register(ResourcesDeleteTool);
registry.register(ResourcesDeleteTool)?;
}
for tool in registry.all() {
@@ -1023,13 +1024,14 @@ impl SessionController {
let mut seen_tools = std::collections::HashSet::new();
for tool_call in tool_calls {
if seen_tools.contains(&tool_call.name) {
let canonical = canonical_tool_name(tool_call.name.as_str()).to_string();
if seen_tools.contains(&canonical) {
continue;
}
seen_tools.insert(tool_call.name.clone());
seen_tools.insert(canonical.clone());
let (data_types, endpoints) = match tool_call.name.as_str() {
"web_search" => (
let (data_types, endpoints) = match canonical.as_str() {
WEB_SEARCH_TOOL_NAME => (
vec!["search query".to_string()],
vec!["cloud provider".to_string()],
),
@@ -1097,13 +1099,14 @@ impl SessionController {
pub async fn set_tool_enabled(&mut self, tool: &str, enabled: bool) -> Result<()> {
{
let mut config = self.config.lock().await;
match tool {
"web_search" => {
let canonical = canonical_tool_name(tool);
match canonical {
WEB_SEARCH_TOOL_NAME => {
config.tools.web_search.enabled = enabled;
config.privacy.enable_remote_search = enabled;
}
"code_exec" => config.tools.code_exec.enabled = enabled,
other => return Err(Error::InvalidInput(format!("Unknown tool: {other}"))),
_ => return Err(Error::InvalidInput(format!("Unknown tool: {tool}"))),
}
}
self.rebuild_tools().await
@@ -1897,12 +1900,12 @@ mod tests {
#[test]
fn streaming_state_detects_tool_call_changes() {
let mut state = StreamingMessageState::new();
let tool = make_tool_call("call-1", "web.search");
let tool = make_tool_call("call-1", "web_search");
let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
let calls = diff.tool_calls.expect("initial tool call");
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].name, "web.search");
assert_eq!(calls[0].name, "web_search");
let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
assert!(

View File

@@ -12,12 +12,64 @@ pub mod web_scrape;
pub mod web_search;
use async_trait::async_trait;
use once_cell::sync::Lazy;
use regex::Regex;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::time::Duration;
use crate::Result;
/// MCP mandates tool identifiers to match `^[A-Za-z0-9_-]{1,64}$`.
pub const MAX_TOOL_IDENTIFIER_LEN: usize = 64;
static TOOL_IDENTIFIER_RE: Lazy<Regex> =
Lazy::new(|| Regex::new(r"^[A-Za-z0-9_-]{1,64}$").expect("valid tool identifier regex"));
pub const WEB_SEARCH_TOOL_NAME: &str = "web_search";
/// Return the canonical identifier for a tool.
pub fn canonical_tool_name(name: &str) -> &str {
name
}
/// Check whether two tool identifiers refer to the same logical tool.
pub fn tool_name_matches(lhs: &str, rhs: &str) -> bool {
canonical_tool_name(lhs) == canonical_tool_name(rhs)
}
/// Determine whether the provided identifier satisfies the MCP naming contract.
pub fn is_valid_tool_identifier(name: &str) -> bool {
TOOL_IDENTIFIER_RE.is_match(name)
}
/// Provide lint-style feedback when a tool identifier falls outside the MCP rules.
pub fn tool_identifier_violation(name: &str) -> Option<String> {
if name.is_empty() {
return Some("Tool identifiers must not be empty.".to_string());
}
if name.len() > MAX_TOOL_IDENTIFIER_LEN {
return Some(format!(
"Tool identifier '{name}' exceeds the {MAX_TOOL_IDENTIFIER_LEN}-character MCP limit."
));
}
if name.trim() != name {
return Some(format!(
"Tool identifier '{name}' contains leading or trailing whitespace."
));
}
if !TOOL_IDENTIFIER_RE.is_match(name) {
return Some(format!(
"Tool identifier '{name}' may only contain ASCII letters, digits, hyphens, or underscores."
));
}
None
}
/// Trait representing a tool that can be called via the MCP interface.
#[async_trait]
pub trait Tool: Send + Sync {
@@ -34,6 +86,10 @@ pub trait Tool: Send + Sync {
fn requires_filesystem(&self) -> Vec<String> {
Vec::new()
}
/// Optional additional identifiers (must remain spec-compliant).
fn aliases(&self) -> &'static [&'static str] {
&[]
}
async fn execute(&self, args: Value) -> Result<ToolResult>;
}

View File

@@ -1,11 +1,13 @@
use std::collections::HashMap;
use std::sync::Arc;
use crate::Result;
use crate::{Error, Result};
use anyhow::Context;
use serde_json::Value;
use super::{Tool, ToolResult};
use super::{
Tool, ToolResult, WEB_SEARCH_TOOL_NAME, canonical_tool_name, tool_identifier_violation,
};
use crate::config::Config;
use crate::mode::Mode;
use crate::ui::UiController;
@@ -25,13 +27,32 @@ impl ToolRegistry {
}
}
pub fn register<T>(&mut self, tool: T)
pub fn register<T>(&mut self, tool: T) -> Result<()>
where
T: Tool + 'static,
{
let tool: Arc<dyn Tool> = Arc::new(tool);
let name = tool.name().to_string();
self.tools.insert(name, tool);
let name = tool.name();
if let Some(reason) = tool_identifier_violation(name) {
log::error!("Tool '{}' failed validation: {}", name, reason);
return Err(Error::InvalidInput(format!(
"Tool '{name}' is not a valid MCP identifier: {reason}"
)));
}
if self
.tools
.insert(name.to_string(), Arc::clone(&tool))
.is_some()
{
log::warn!(
"Tool '{}' was already registered; overwriting previous entry.",
name
);
}
Ok(())
}
pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
@@ -43,20 +64,25 @@ impl ToolRegistry {
}
pub async fn execute(&self, name: &str, args: Value, mode: Mode) -> Result<ToolResult> {
let canonical = canonical_tool_name(name);
let tool = self
.get(name)
.get(canonical)
.with_context(|| format!("Tool not registered: {}", name))?;
let mut config = self.config.lock().await;
// Check mode-based tool availability first
if !config.modes.is_tool_allowed(mode, name) {
if !(config.modes.is_tool_allowed(mode, canonical)
|| config.modes.is_tool_allowed(mode, name))
{
let alternate_mode = match mode {
Mode::Chat => Mode::Code,
Mode::Code => Mode::Chat,
};
if config.modes.is_tool_allowed(alternate_mode, name) {
if config.modes.is_tool_allowed(alternate_mode, canonical)
|| config.modes.is_tool_allowed(alternate_mode, name)
{
return Ok(ToolResult::error(&format!(
"Tool '{}' is not available in {} mode. Switch to {} mode to use this tool (use :mode {} command).",
name, mode, alternate_mode, alternate_mode
@@ -69,8 +95,8 @@ impl ToolRegistry {
}
}
let is_enabled = match name {
"web_search" => config.tools.web_search.enabled,
let is_enabled = match canonical {
WEB_SEARCH_TOOL_NAME => config.tools.web_search.enabled,
"code_exec" => config.tools.code_exec.enabled,
_ => true, // All other tools are considered enabled by default
};
@@ -82,8 +108,8 @@ impl ToolRegistry {
);
if self.ui.confirm(&prompt).await {
// Enable the tool in the in-memory config for the current session
match name {
"web_search" => config.tools.web_search.enabled = true,
match canonical {
WEB_SEARCH_TOOL_NAME => config.tools.web_search.enabled = true,
"code_exec" => config.tools.code_exec.enabled = true,
_ => {}
}
@@ -112,3 +138,69 @@ impl ToolRegistry {
self.tools.keys().cloned().collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::Config;
use crate::tools::{Tool, ToolResult, WEB_SEARCH_TOOL_NAME};
use crate::ui::NoOpUiController;
use async_trait::async_trait;
use serde_json::{Value, json};
use std::sync::Arc;
struct DummyTool {
name: &'static str,
}
#[async_trait]
impl Tool for DummyTool {
fn name(&self) -> &'static str {
self.name
}
fn description(&self) -> &'static str {
"dummy tool"
}
fn schema(&self) -> Value {
json!({ "type": "object" })
}
fn aliases(&self) -> &'static [&'static str] {
self.aliases
}
async fn execute(&self, _args: Value) -> Result<ToolResult> {
Ok(ToolResult::success(json!({ "echo": true })))
}
}
fn registry() -> ToolRegistry {
let config = Arc::new(tokio::sync::Mutex::new(Config::default()));
let ui = Arc::new(NoOpUiController);
ToolRegistry::new(config, ui)
}
#[test]
fn rejects_invalid_tool_identifier() {
let mut registry = registry();
let tool = DummyTool {
name: "invalid.tool",
};
let err = registry.register(tool).unwrap_err();
assert!(matches!(err, Error::InvalidInput(_)));
}
#[test]
fn registers_spec_compliant_tool() {
let mut registry = registry();
let tool = DummyTool {
name: WEB_SEARCH_TOOL_NAME,
};
registry.register(tool).unwrap();
assert!(registry.get(WEB_SEARCH_TOOL_NAME).is_some());
}
}

View File

@@ -10,6 +10,7 @@ use serde_json::{Value, json};
use super::{Tool, ToolResult};
use crate::consent::ConsentManager;
use crate::tools::WEB_SEARCH_TOOL_NAME;
/// Configuration applied to the web search tool at registration time.
#[derive(Clone, Debug)]
@@ -44,7 +45,7 @@ impl WebSearchTool {
#[async_trait]
impl Tool for WebSearchTool {
fn name(&self) -> &'static str {
"web_search"
WEB_SEARCH_TOOL_NAME
}
fn description(&self) -> &'static str {

View File

@@ -4,6 +4,8 @@ use anyhow::{Context, Result};
use jsonschema::{JSONSchema, ValidationError};
use serde_json::{Value, json};
use crate::tools::WEB_SEARCH_TOOL_NAME;
pub struct SchemaValidator {
schemas: HashMap<String, JSONSchema>,
}
@@ -56,9 +58,7 @@ fn format_validation_error(error: ValidationError) -> String {
pub fn get_builtin_schemas() -> HashMap<String, Value> {
let mut schemas = HashMap::new();
schemas.insert(
"web_search".to_string(),
json!({
let web_search_schema = json!({
"type": "object",
"properties": {
"query": {
@@ -75,8 +75,9 @@ pub fn get_builtin_schemas() -> HashMap<String, Value> {
},
"required": ["query"],
"additionalProperties": false
}),
);
});
schemas.insert(WEB_SEARCH_TOOL_NAME.to_string(), web_search_schema.clone());
schemas.insert(
"code_exec".to_string(),

View File

@@ -2,6 +2,7 @@ use std::{any::Any, collections::HashMap, sync::Arc};
use async_trait::async_trait;
use futures::StreamExt;
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
use owlen_core::{
Config, Error, Mode, Provider,
config::McpMode,
@@ -88,7 +89,7 @@ impl Provider for StreamingToolProvider {
fn tool_descriptor() -> McpToolDescriptor {
McpToolDescriptor {
name: "web_search".to_string(),
name: WEB_SEARCH_TOOL_NAME.to_string(),
description: "search".to_string(),
input_schema: serde_json::json!({"type": "object"}),
requires_network: true,
@@ -123,7 +124,7 @@ impl CachedResponseClient {
metadata.insert("cached".to_string(), "true".to_string());
let response = McpToolResponse {
name: "web_search".to_string(),
name: WEB_SEARCH_TOOL_NAME.to_string(),
success: true,
output: serde_json::json!({
"query": "rust",
@@ -286,13 +287,13 @@ async fn web_tool_timeout_fails_over_to_cached_result() {
]);
let call = McpToolCall {
name: "web_search".to_string(),
name: WEB_SEARCH_TOOL_NAME.to_string(),
arguments: serde_json::json!({ "query": "rust", "max_results": 3 }),
};
let response = client.call_tool(call.clone()).await.expect("fallback");
assert_eq!(response.name, "web_search");
assert!(tool_name_matches(&response.name, WEB_SEARCH_TOOL_NAME));
assert_eq!(
response.metadata.get("source").map(String::as_str),
Some("cache")

View File

@@ -1,4 +1,5 @@
use owlen_core::consent::{ConsentManager, ConsentScope};
use owlen_core::tools::WEB_SEARCH_TOOL_NAME;
#[test]
fn test_consent_scopes() {
@@ -43,23 +44,23 @@ fn test_pending_requests_prevents_duplicates() {
// In real usage, multiple threads would call request_consent simultaneously
// First, verify a tool has no consent
assert!(!manager.has_consent("web_search"));
assert!(!manager.has_consent(WEB_SEARCH_TOOL_NAME));
// The pending_requests map is private, but we can test the behavior
// by checking that consent checks work correctly
assert!(manager.check_consent_needed("web_search").is_some());
assert!(manager.check_consent_needed(WEB_SEARCH_TOOL_NAME).is_some());
// Grant session consent
manager.grant_consent_with_scope(
"web_search",
WEB_SEARCH_TOOL_NAME,
vec!["search queries".to_string()],
vec!["https://api.search.com".to_string()],
ConsentScope::Session,
);
// Now it should have consent
assert!(manager.has_consent("web_search"));
assert!(manager.check_consent_needed("web_search").is_none());
assert!(manager.has_consent(WEB_SEARCH_TOOL_NAME));
assert!(manager.check_consent_needed(WEB_SEARCH_TOOL_NAME).is_none());
}
#[test]

View File

@@ -2,6 +2,7 @@ use std::{any::Any, collections::HashMap, sync::Arc};
use async_trait::async_trait;
use futures::stream;
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
use owlen_core::{
ChatStream, Provider, Result,
config::Config,
@@ -96,12 +97,12 @@ async fn toggling_web_search_updates_config_and_registry() {
.tool_registry()
.tools()
.iter()
.any(|tool| tool == "web_search"),
.any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)),
"web_search should be disabled by default"
);
session
.set_tool_enabled("web_search", true)
.set_tool_enabled(WEB_SEARCH_TOOL_NAME, true)
.await
.expect("enable web_search");
@@ -115,12 +116,12 @@ async fn toggling_web_search_updates_config_and_registry() {
.tool_registry()
.tools()
.iter()
.any(|tool| tool == "web_search"),
.any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)),
"web_search should be registered when enabled"
);
session
.set_tool_enabled("web_search", false)
.set_tool_enabled(WEB_SEARCH_TOOL_NAME, false)
.await
.expect("disable web_search");
@@ -134,7 +135,7 @@ async fn toggling_web_search_updates_config_and_registry() {
.tool_registry()
.tools()
.iter()
.any(|tool| tool == "web_search"),
.any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)),
"web_search should be removed when disabled"
);
}