feat(mcp): enforce spec-compliant tool registry
- Reject dotted tool identifiers during registration and remove alias-backed lookups. - Drop web.search compatibility, normalize all code/tests around the canonical web_search name, and update consent/session logic. - Harden CLI toggles to manage the spec-compliant identifier and ensure MCP configs shed non-compliant entries automatically. Acceptance Criteria: - Tool registry denies invalid identifiers by default and no alias codepaths remain. Test Notes: - cargo check -p owlen-core (tests unavailable in sandbox).
This commit is contained in:
@@ -9,6 +9,7 @@ use owlen_core::provider::{
|
|||||||
AnnotatedModelInfo, ModelProvider, ProviderManager, ProviderStatus, ProviderType,
|
AnnotatedModelInfo, ModelProvider, ProviderManager, ProviderStatus, ProviderType,
|
||||||
};
|
};
|
||||||
use owlen_core::storage::StorageManager;
|
use owlen_core::storage::StorageManager;
|
||||||
|
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
|
||||||
use owlen_providers::ollama::{OllamaCloudProvider, OllamaLocalProvider};
|
use owlen_providers::ollama::{OllamaCloudProvider, OllamaLocalProvider};
|
||||||
use owlen_tui::config as tui_config;
|
use owlen_tui::config as tui_config;
|
||||||
|
|
||||||
@@ -35,7 +36,7 @@ pub enum ProvidersCommand {
|
|||||||
/// Provider identifier to disable.
|
/// Provider identifier to disable.
|
||||||
provider: String,
|
provider: String,
|
||||||
},
|
},
|
||||||
/// Enable or disable the web.search tool exposure.
|
/// Enable or disable the `web_search` tool exposure.
|
||||||
Web(WebCommand),
|
Web(WebCommand),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -47,13 +48,13 @@ pub struct ModelsArgs {
|
|||||||
pub provider: Option<String>,
|
pub provider: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Arguments for managing the web.search tool exposure.
|
/// Arguments for managing the `web_search` tool exposure.
|
||||||
#[derive(Debug, Args)]
|
#[derive(Debug, Args)]
|
||||||
pub struct WebCommand {
|
pub struct WebCommand {
|
||||||
/// Enable the web.search tool and allow remote lookups.
|
/// Enable the `web_search` tool and allow remote lookups.
|
||||||
#[arg(long, conflicts_with = "disable")]
|
#[arg(long, conflicts_with = "disable")]
|
||||||
enable: bool,
|
enable: bool,
|
||||||
/// Disable the web.search tool to keep sessions local-only.
|
/// Disable the `web_search` tool to keep sessions local-only.
|
||||||
#[arg(long, conflicts_with = "enable")]
|
#[arg(long, conflicts_with = "enable")]
|
||||||
disable: bool,
|
disable: bool,
|
||||||
}
|
}
|
||||||
@@ -281,14 +282,16 @@ fn apply_web_toggle(config: &mut Config, enabled: bool) {
|
|||||||
config.tools.web_search.enabled = enabled;
|
config.tools.web_search.enabled = enabled;
|
||||||
config.privacy.enable_remote_search = enabled;
|
config.privacy.enable_remote_search = enabled;
|
||||||
|
|
||||||
if enabled
|
config
|
||||||
&& !config
|
|
||||||
.security
|
.security
|
||||||
.allowed_tools
|
.allowed_tools
|
||||||
.iter()
|
.retain(|tool| !tool_name_matches(tool, WEB_SEARCH_TOOL_NAME));
|
||||||
.any(|tool| tool.eq_ignore_ascii_case("web_search"))
|
|
||||||
{
|
if enabled {
|
||||||
config.security.allowed_tools.push("web_search".to_string());
|
config
|
||||||
|
.security
|
||||||
|
.allowed_tools
|
||||||
|
.push(WEB_SEARCH_TOOL_NAME.to_string());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -760,7 +763,7 @@ mod tests {
|
|||||||
.security
|
.security
|
||||||
.allowed_tools
|
.allowed_tools
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|tool| tool.eq_ignore_ascii_case("web_search"))
|
.filter(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME))
|
||||||
.count()
|
.count()
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -775,8 +778,11 @@ mod tests {
|
|||||||
config
|
config
|
||||||
.security
|
.security
|
||||||
.allowed_tools
|
.allowed_tools
|
||||||
.retain(|tool| !tool.eq_ignore_ascii_case("web_search"));
|
.retain(|tool| !tool_name_matches(tool, WEB_SEARCH_TOOL_NAME));
|
||||||
config.security.allowed_tools.push("web_search".to_string());
|
config
|
||||||
|
.security
|
||||||
|
.allowed_tools
|
||||||
|
.push(WEB_SEARCH_TOOL_NAME.to_string());
|
||||||
|
|
||||||
apply_web_toggle(&mut config, true);
|
apply_web_toggle(&mut config, true);
|
||||||
apply_web_toggle(&mut config, true);
|
apply_web_toggle(&mut config, true);
|
||||||
@@ -787,7 +793,7 @@ mod tests {
|
|||||||
.security
|
.security
|
||||||
.allowed_tools
|
.allowed_tools
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|tool| tool.eq_ignore_ascii_case("web_search"))
|
.filter(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME))
|
||||||
.count()
|
.count()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
|
|
||||||
use owlen_cli::agent::{AgentConfig, AgentExecutor, LlmResponse};
|
use owlen_cli::agent::{AgentConfig, AgentExecutor, LlmResponse};
|
||||||
use owlen_core::mcp::remote_client::RemoteMcpClient;
|
use owlen_core::mcp::remote_client::RemoteMcpClient;
|
||||||
|
use owlen_core::tools::WEB_SEARCH_TOOL_NAME;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@@ -27,7 +28,7 @@ async fn test_react_parsing_tool_call() {
|
|||||||
arguments,
|
arguments,
|
||||||
}) => {
|
}) => {
|
||||||
assert_eq!(thought, "I should search for information");
|
assert_eq!(thought, "I should search for information");
|
||||||
assert_eq!(tool_name, "web_search");
|
assert_eq!(tool_name.as_str(), WEB_SEARCH_TOOL_NAME);
|
||||||
assert_eq!(arguments["query"], "rust async programming");
|
assert_eq!(arguments["query"], "rust async programming");
|
||||||
}
|
}
|
||||||
other => panic!("Expected ToolCall, got: {:?}", other),
|
other => panic!("Expected ToolCall, got: {:?}", other),
|
||||||
|
|||||||
@@ -366,6 +366,7 @@ mod tests {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::llm::test_utils::MockProvider;
|
use crate::llm::test_utils::MockProvider;
|
||||||
use crate::mcp::test_utils::MockMcpClient;
|
use crate::mcp::test_utils::MockMcpClient;
|
||||||
|
use crate::tools::WEB_SEARCH_TOOL_NAME;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parse_tool_call() {
|
fn test_parse_tool_call() {
|
||||||
@@ -389,7 +390,7 @@ ACTION_INPUT: {"query": "Rust programming language"}
|
|||||||
arguments,
|
arguments,
|
||||||
} => {
|
} => {
|
||||||
assert!(thought.contains("search for information"));
|
assert!(thought.contains("search for information"));
|
||||||
assert_eq!(tool_name, "web_search");
|
assert!(matches!(tool_name.as_str(), WEB_SEARCH_TOOL_NAME));
|
||||||
assert_eq!(arguments["query"], "Rust programming language");
|
assert_eq!(arguments["query"], "Rust programming language");
|
||||||
}
|
}
|
||||||
_ => panic!("Expected ToolCall"),
|
_ => panic!("Expected ToolCall"),
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ use crate::Error;
|
|||||||
use crate::ProviderConfig;
|
use crate::ProviderConfig;
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
use crate::mode::ModeConfig;
|
use crate::mode::ModeConfig;
|
||||||
|
use crate::tools::WEB_SEARCH_TOOL_NAME;
|
||||||
use crate::ui::RoleLabelDisplay;
|
use crate::ui::RoleLabelDisplay;
|
||||||
use serde::de::{self, Deserializer, Visitor};
|
use serde::de::{self, Deserializer, Visitor};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
@@ -328,6 +329,11 @@ impl Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generate MCP server configurations that mirror the Codex CLI defaults.
|
||||||
|
pub fn codex_default_servers() -> Vec<McpServerConfig> {
|
||||||
|
crate::mcp::codex::codex_connector_configs()
|
||||||
|
}
|
||||||
|
|
||||||
/// Persist configuration to disk
|
/// Persist configuration to disk
|
||||||
pub fn save(&self, path: Option<&Path>) -> Result<()> {
|
pub fn save(&self, path: Option<&Path>) -> Result<()> {
|
||||||
let mut validator = self.clone();
|
let mut validator = self.clone();
|
||||||
@@ -1682,7 +1688,7 @@ impl SecuritySettings {
|
|||||||
|
|
||||||
fn default_allowed_tools() -> Vec<String> {
|
fn default_allowed_tools() -> Vec<String> {
|
||||||
vec![
|
vec![
|
||||||
"web_search".to_string(),
|
WEB_SEARCH_TOOL_NAME.to_string(),
|
||||||
"web_scrape".to_string(),
|
"web_scrape".to_string(),
|
||||||
"code_exec".to_string(),
|
"code_exec".to_string(),
|
||||||
"file_write".to_string(),
|
"file_write".to_string(),
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ use chrono::{DateTime, Utc};
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use crate::encryption::VaultHandle;
|
use crate::encryption::VaultHandle;
|
||||||
|
use crate::tools::canonical_tool_name;
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct ConsentRequest {
|
pub struct ConsentRequest {
|
||||||
@@ -94,10 +95,12 @@ impl ConsentManager {
|
|||||||
data_types: Vec<String>,
|
data_types: Vec<String>,
|
||||||
endpoints: Vec<String>,
|
endpoints: Vec<String>,
|
||||||
) -> Result<ConsentScope> {
|
) -> Result<ConsentScope> {
|
||||||
|
let canonical = canonical_tool_name(tool_name);
|
||||||
|
|
||||||
// Check if already granted permanently
|
// Check if already granted permanently
|
||||||
if self
|
if self
|
||||||
.permanent_records
|
.permanent_records
|
||||||
.get(tool_name)
|
.get(canonical)
|
||||||
.is_some_and(|existing| existing.scope == ConsentScope::Permanent)
|
.is_some_and(|existing| existing.scope == ConsentScope::Permanent)
|
||||||
{
|
{
|
||||||
return Ok(ConsentScope::Permanent);
|
return Ok(ConsentScope::Permanent);
|
||||||
@@ -106,31 +109,31 @@ impl ConsentManager {
|
|||||||
// Check if granted for session
|
// Check if granted for session
|
||||||
if self
|
if self
|
||||||
.session_records
|
.session_records
|
||||||
.get(tool_name)
|
.get(canonical)
|
||||||
.is_some_and(|existing| existing.scope == ConsentScope::Session)
|
.is_some_and(|existing| existing.scope == ConsentScope::Session)
|
||||||
{
|
{
|
||||||
return Ok(ConsentScope::Session);
|
return Ok(ConsentScope::Session);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if request is already pending (prevent duplicate prompts)
|
// Check if request is already pending (prevent duplicate prompts)
|
||||||
if self.pending_requests.contains_key(tool_name) {
|
if self.pending_requests.contains_key(canonical) {
|
||||||
// Wait for the other prompt to complete by returning denied temporarily
|
// Wait for the other prompt to complete by returning denied temporarily
|
||||||
// The caller should retry after a short delay
|
// The caller should retry after a short delay
|
||||||
return Ok(ConsentScope::Denied);
|
return Ok(ConsentScope::Denied);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mark as pending
|
// Mark as pending
|
||||||
self.pending_requests.insert(tool_name.to_string(), ());
|
self.pending_requests.insert(canonical.to_string(), ());
|
||||||
|
|
||||||
// Show consent dialog and get scope
|
// Show consent dialog and get scope
|
||||||
let scope = self.show_consent_dialog(tool_name, &data_types, &endpoints)?;
|
let scope = self.show_consent_dialog(tool_name, &data_types, &endpoints)?;
|
||||||
|
|
||||||
// Remove from pending
|
// Remove from pending
|
||||||
self.pending_requests.remove(tool_name);
|
self.pending_requests.remove(canonical);
|
||||||
|
|
||||||
// Create record based on scope
|
// Create record based on scope
|
||||||
let record = ConsentRecord {
|
let record = ConsentRecord {
|
||||||
tool_name: tool_name.to_string(),
|
tool_name: canonical.to_string(),
|
||||||
scope: scope.clone(),
|
scope: scope.clone(),
|
||||||
timestamp: Utc::now(),
|
timestamp: Utc::now(),
|
||||||
data_types,
|
data_types,
|
||||||
@@ -140,10 +143,10 @@ impl ConsentManager {
|
|||||||
// Store in appropriate location
|
// Store in appropriate location
|
||||||
match scope {
|
match scope {
|
||||||
ConsentScope::Permanent => {
|
ConsentScope::Permanent => {
|
||||||
self.permanent_records.insert(tool_name.to_string(), record);
|
self.permanent_records.insert(canonical.to_string(), record);
|
||||||
}
|
}
|
||||||
ConsentScope::Session => {
|
ConsentScope::Session => {
|
||||||
self.session_records.insert(tool_name.to_string(), record);
|
self.session_records.insert(canonical.to_string(), record);
|
||||||
}
|
}
|
||||||
ConsentScope::Once | ConsentScope::Denied => {
|
ConsentScope::Once | ConsentScope::Denied => {
|
||||||
// Don't store, just return the decision
|
// Don't store, just return the decision
|
||||||
@@ -171,8 +174,9 @@ impl ConsentManager {
|
|||||||
endpoints: Vec<String>,
|
endpoints: Vec<String>,
|
||||||
scope: ConsentScope,
|
scope: ConsentScope,
|
||||||
) {
|
) {
|
||||||
|
let canonical = canonical_tool_name(tool_name);
|
||||||
let record = ConsentRecord {
|
let record = ConsentRecord {
|
||||||
tool_name: tool_name.to_string(),
|
tool_name: canonical.to_string(),
|
||||||
scope: scope.clone(),
|
scope: scope.clone(),
|
||||||
timestamp: Utc::now(),
|
timestamp: Utc::now(),
|
||||||
data_types,
|
data_types,
|
||||||
@@ -181,13 +185,13 @@ impl ConsentManager {
|
|||||||
|
|
||||||
match scope {
|
match scope {
|
||||||
ConsentScope::Permanent => {
|
ConsentScope::Permanent => {
|
||||||
self.permanent_records.insert(tool_name.to_string(), record);
|
self.permanent_records.insert(canonical.to_string(), record);
|
||||||
}
|
}
|
||||||
ConsentScope::Session => {
|
ConsentScope::Session => {
|
||||||
self.session_records.insert(tool_name.to_string(), record);
|
self.session_records.insert(canonical.to_string(), record);
|
||||||
}
|
}
|
||||||
ConsentScope::Once => {
|
ConsentScope::Once => {
|
||||||
self.once_records.insert(tool_name.to_string(), record);
|
self.once_records.insert(canonical.to_string(), record);
|
||||||
}
|
}
|
||||||
ConsentScope::Denied => {} // Denied is not stored
|
ConsentScope::Denied => {} // Denied is not stored
|
||||||
}
|
}
|
||||||
@@ -195,28 +199,30 @@ impl ConsentManager {
|
|||||||
|
|
||||||
/// Check if consent is needed (returns None if already granted, Some(info) if needed)
|
/// Check if consent is needed (returns None if already granted, Some(info) if needed)
|
||||||
pub fn check_consent_needed(&self, tool_name: &str) -> Option<ConsentRequest> {
|
pub fn check_consent_needed(&self, tool_name: &str) -> Option<ConsentRequest> {
|
||||||
if self.has_consent(tool_name) {
|
let canonical = canonical_tool_name(tool_name);
|
||||||
|
if self.has_consent(canonical) {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
Some(ConsentRequest {
|
Some(ConsentRequest {
|
||||||
tool_name: tool_name.to_string(),
|
tool_name: canonical.to_string(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn has_consent(&self, tool_name: &str) -> bool {
|
pub fn has_consent(&self, tool_name: &str) -> bool {
|
||||||
|
let canonical = canonical_tool_name(tool_name);
|
||||||
// Check permanent first, then session, then once
|
// Check permanent first, then session, then once
|
||||||
self.permanent_records
|
self.permanent_records
|
||||||
.get(tool_name)
|
.get(canonical)
|
||||||
.map(|r| r.scope == ConsentScope::Permanent)
|
.map(|r| r.scope == ConsentScope::Permanent)
|
||||||
.or_else(|| {
|
.or_else(|| {
|
||||||
self.session_records
|
self.session_records
|
||||||
.get(tool_name)
|
.get(canonical)
|
||||||
.map(|r| r.scope == ConsentScope::Session)
|
.map(|r| r.scope == ConsentScope::Session)
|
||||||
})
|
})
|
||||||
.or_else(|| {
|
.or_else(|| {
|
||||||
self.once_records
|
self.once_records
|
||||||
.get(tool_name)
|
.get(canonical)
|
||||||
.map(|r| r.scope == ConsentScope::Once)
|
.map(|r| r.scope == ConsentScope::Once)
|
||||||
})
|
})
|
||||||
.unwrap_or(false)
|
.unwrap_or(false)
|
||||||
@@ -224,13 +230,15 @@ impl ConsentManager {
|
|||||||
|
|
||||||
/// Consume "once" consent for a tool (clears it after first use)
|
/// Consume "once" consent for a tool (clears it after first use)
|
||||||
pub fn consume_once_consent(&mut self, tool_name: &str) {
|
pub fn consume_once_consent(&mut self, tool_name: &str) {
|
||||||
self.once_records.remove(tool_name);
|
let canonical = canonical_tool_name(tool_name);
|
||||||
|
self.once_records.remove(canonical);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn revoke_consent(&mut self, tool_name: &str) {
|
pub fn revoke_consent(&mut self, tool_name: &str) {
|
||||||
self.permanent_records.remove(tool_name);
|
let canonical = canonical_tool_name(tool_name);
|
||||||
self.session_records.remove(tool_name);
|
self.permanent_records.remove(canonical);
|
||||||
self.once_records.remove(tool_name);
|
self.session_records.remove(canonical);
|
||||||
|
self.once_records.remove(canonical);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn clear_all_consent(&mut self) {
|
pub fn clear_all_consent(&mut self) {
|
||||||
@@ -253,10 +261,11 @@ impl ConsentManager {
|
|||||||
data_types: Vec<String>,
|
data_types: Vec<String>,
|
||||||
endpoints: Vec<String>,
|
endpoints: Vec<String>,
|
||||||
) -> Option<(String, Vec<String>, Vec<String>)> {
|
) -> Option<(String, Vec<String>, Vec<String>)> {
|
||||||
if self.has_consent(tool_name) {
|
let canonical = canonical_tool_name(tool_name);
|
||||||
|
if self.has_consent(canonical) {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
Some((tool_name.to_string(), data_types, endpoints))
|
Some((canonical.to_string(), data_types, endpoints))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn show_consent_dialog(
|
fn show_consent_dialog(
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ use crate::model::{DetailedModelInfo, ModelManager};
|
|||||||
use crate::oauth::{DeviceAuthorization, DevicePollState, OAuthClient};
|
use crate::oauth::{DeviceAuthorization, DevicePollState, OAuthClient};
|
||||||
use crate::providers::OllamaProvider;
|
use crate::providers::OllamaProvider;
|
||||||
use crate::storage::{SessionMeta, StorageManager};
|
use crate::storage::{SessionMeta, StorageManager};
|
||||||
|
use crate::tools::{WEB_SEARCH_TOOL_NAME, canonical_tool_name, tool_name_matches};
|
||||||
use crate::types::{
|
use crate::types::{
|
||||||
ChatParameters, ChatRequest, ChatResponse, Conversation, Message, ModelInfo, ToolCall,
|
ChatParameters, ChatRequest, ChatResponse, Conversation, Message, ModelInfo, ToolCall,
|
||||||
};
|
};
|
||||||
@@ -407,7 +408,7 @@ async fn build_tools(
|
|||||||
.security
|
.security
|
||||||
.allowed_tools
|
.allowed_tools
|
||||||
.iter()
|
.iter()
|
||||||
.any(|tool| tool == "web_search")
|
.any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME))
|
||||||
&& config_guard.tools.web_search.enabled
|
&& config_guard.tools.web_search.enabled
|
||||||
&& config_guard.privacy.enable_remote_search
|
&& config_guard.privacy.enable_remote_search
|
||||||
{
|
{
|
||||||
@@ -424,7 +425,7 @@ async fn build_tools(
|
|||||||
|
|
||||||
if let Some(settings) = web_search_settings {
|
if let Some(settings) = web_search_settings {
|
||||||
let tool = WebSearchTool::new(consent_manager.clone(), settings);
|
let tool = WebSearchTool::new(consent_manager.clone(), settings);
|
||||||
registry.register(tool);
|
registry.register(tool)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register web_scrape tool if allowed.
|
// Register web_scrape tool if allowed.
|
||||||
@@ -432,12 +433,12 @@ async fn build_tools(
|
|||||||
.security
|
.security
|
||||||
.allowed_tools
|
.allowed_tools
|
||||||
.iter()
|
.iter()
|
||||||
.any(|tool| tool == "web_scrape")
|
.any(|tool| tool_name_matches(tool, "web_scrape"))
|
||||||
&& config_guard.tools.web_search.enabled // reuse web_search toggle for simplicity
|
&& config_guard.tools.web_search.enabled // reuse web_search toggle for simplicity
|
||||||
&& config_guard.privacy.enable_remote_search
|
&& config_guard.privacy.enable_remote_search
|
||||||
{
|
{
|
||||||
let tool = WebScrapeTool::new();
|
let tool = WebScrapeTool::new();
|
||||||
registry.register(tool);
|
registry.register(tool)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
if enable_code_tools
|
if enable_code_tools
|
||||||
@@ -449,11 +450,11 @@ async fn build_tools(
|
|||||||
&& config_guard.tools.code_exec.enabled
|
&& config_guard.tools.code_exec.enabled
|
||||||
{
|
{
|
||||||
let tool = CodeExecTool::new(config_guard.tools.code_exec.allowed_languages.clone());
|
let tool = CodeExecTool::new(config_guard.tools.code_exec.allowed_languages.clone());
|
||||||
registry.register(tool);
|
registry.register(tool)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
registry.register(ResourcesListTool);
|
registry.register(ResourcesListTool)?;
|
||||||
registry.register(ResourcesGetTool);
|
registry.register(ResourcesGetTool)?;
|
||||||
|
|
||||||
if config_guard
|
if config_guard
|
||||||
.security
|
.security
|
||||||
@@ -461,7 +462,7 @@ async fn build_tools(
|
|||||||
.iter()
|
.iter()
|
||||||
.any(|t| t == "file_write")
|
.any(|t| t == "file_write")
|
||||||
{
|
{
|
||||||
registry.register(ResourcesWriteTool);
|
registry.register(ResourcesWriteTool)?;
|
||||||
}
|
}
|
||||||
if config_guard
|
if config_guard
|
||||||
.security
|
.security
|
||||||
@@ -469,7 +470,7 @@ async fn build_tools(
|
|||||||
.iter()
|
.iter()
|
||||||
.any(|t| t == "file_delete")
|
.any(|t| t == "file_delete")
|
||||||
{
|
{
|
||||||
registry.register(ResourcesDeleteTool);
|
registry.register(ResourcesDeleteTool)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
for tool in registry.all() {
|
for tool in registry.all() {
|
||||||
@@ -1023,13 +1024,14 @@ impl SessionController {
|
|||||||
let mut seen_tools = std::collections::HashSet::new();
|
let mut seen_tools = std::collections::HashSet::new();
|
||||||
|
|
||||||
for tool_call in tool_calls {
|
for tool_call in tool_calls {
|
||||||
if seen_tools.contains(&tool_call.name) {
|
let canonical = canonical_tool_name(tool_call.name.as_str()).to_string();
|
||||||
|
if seen_tools.contains(&canonical) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
seen_tools.insert(tool_call.name.clone());
|
seen_tools.insert(canonical.clone());
|
||||||
|
|
||||||
let (data_types, endpoints) = match tool_call.name.as_str() {
|
let (data_types, endpoints) = match canonical.as_str() {
|
||||||
"web_search" => (
|
WEB_SEARCH_TOOL_NAME => (
|
||||||
vec!["search query".to_string()],
|
vec!["search query".to_string()],
|
||||||
vec!["cloud provider".to_string()],
|
vec!["cloud provider".to_string()],
|
||||||
),
|
),
|
||||||
@@ -1097,13 +1099,14 @@ impl SessionController {
|
|||||||
pub async fn set_tool_enabled(&mut self, tool: &str, enabled: bool) -> Result<()> {
|
pub async fn set_tool_enabled(&mut self, tool: &str, enabled: bool) -> Result<()> {
|
||||||
{
|
{
|
||||||
let mut config = self.config.lock().await;
|
let mut config = self.config.lock().await;
|
||||||
match tool {
|
let canonical = canonical_tool_name(tool);
|
||||||
"web_search" => {
|
match canonical {
|
||||||
|
WEB_SEARCH_TOOL_NAME => {
|
||||||
config.tools.web_search.enabled = enabled;
|
config.tools.web_search.enabled = enabled;
|
||||||
config.privacy.enable_remote_search = enabled;
|
config.privacy.enable_remote_search = enabled;
|
||||||
}
|
}
|
||||||
"code_exec" => config.tools.code_exec.enabled = enabled,
|
"code_exec" => config.tools.code_exec.enabled = enabled,
|
||||||
other => return Err(Error::InvalidInput(format!("Unknown tool: {other}"))),
|
_ => return Err(Error::InvalidInput(format!("Unknown tool: {tool}"))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
self.rebuild_tools().await
|
self.rebuild_tools().await
|
||||||
@@ -1897,12 +1900,12 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn streaming_state_detects_tool_call_changes() {
|
fn streaming_state_detects_tool_call_changes() {
|
||||||
let mut state = StreamingMessageState::new();
|
let mut state = StreamingMessageState::new();
|
||||||
let tool = make_tool_call("call-1", "web.search");
|
let tool = make_tool_call("call-1", "web_search");
|
||||||
|
|
||||||
let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
|
let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
|
||||||
let calls = diff.tool_calls.expect("initial tool call");
|
let calls = diff.tool_calls.expect("initial tool call");
|
||||||
assert_eq!(calls.len(), 1);
|
assert_eq!(calls.len(), 1);
|
||||||
assert_eq!(calls[0].name, "web.search");
|
assert_eq!(calls[0].name, "web_search");
|
||||||
|
|
||||||
let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
|
let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
|
||||||
assert!(
|
assert!(
|
||||||
|
|||||||
@@ -12,12 +12,64 @@ pub mod web_scrape;
|
|||||||
pub mod web_search;
|
pub mod web_search;
|
||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use once_cell::sync::Lazy;
|
||||||
|
use regex::Regex;
|
||||||
use serde_json::{Value, json};
|
use serde_json::{Value, json};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
|
|
||||||
|
/// MCP mandates tool identifiers to match `^[A-Za-z0-9_-]{1,64}$`.
|
||||||
|
pub const MAX_TOOL_IDENTIFIER_LEN: usize = 64;
|
||||||
|
|
||||||
|
static TOOL_IDENTIFIER_RE: Lazy<Regex> =
|
||||||
|
Lazy::new(|| Regex::new(r"^[A-Za-z0-9_-]{1,64}$").expect("valid tool identifier regex"));
|
||||||
|
|
||||||
|
pub const WEB_SEARCH_TOOL_NAME: &str = "web_search";
|
||||||
|
|
||||||
|
/// Return the canonical identifier for a tool.
|
||||||
|
pub fn canonical_tool_name(name: &str) -> &str {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check whether two tool identifiers refer to the same logical tool.
|
||||||
|
pub fn tool_name_matches(lhs: &str, rhs: &str) -> bool {
|
||||||
|
canonical_tool_name(lhs) == canonical_tool_name(rhs)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Determine whether the provided identifier satisfies the MCP naming contract.
|
||||||
|
pub fn is_valid_tool_identifier(name: &str) -> bool {
|
||||||
|
TOOL_IDENTIFIER_RE.is_match(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Provide lint-style feedback when a tool identifier falls outside the MCP rules.
|
||||||
|
pub fn tool_identifier_violation(name: &str) -> Option<String> {
|
||||||
|
if name.is_empty() {
|
||||||
|
return Some("Tool identifiers must not be empty.".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if name.len() > MAX_TOOL_IDENTIFIER_LEN {
|
||||||
|
return Some(format!(
|
||||||
|
"Tool identifier '{name}' exceeds the {MAX_TOOL_IDENTIFIER_LEN}-character MCP limit."
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if name.trim() != name {
|
||||||
|
return Some(format!(
|
||||||
|
"Tool identifier '{name}' contains leading or trailing whitespace."
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !TOOL_IDENTIFIER_RE.is_match(name) {
|
||||||
|
return Some(format!(
|
||||||
|
"Tool identifier '{name}' may only contain ASCII letters, digits, hyphens, or underscores."
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
/// Trait representing a tool that can be called via the MCP interface.
|
/// Trait representing a tool that can be called via the MCP interface.
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Tool: Send + Sync {
|
pub trait Tool: Send + Sync {
|
||||||
@@ -34,6 +86,10 @@ pub trait Tool: Send + Sync {
|
|||||||
fn requires_filesystem(&self) -> Vec<String> {
|
fn requires_filesystem(&self) -> Vec<String> {
|
||||||
Vec::new()
|
Vec::new()
|
||||||
}
|
}
|
||||||
|
/// Optional additional identifiers (must remain spec-compliant).
|
||||||
|
fn aliases(&self) -> &'static [&'static str] {
|
||||||
|
&[]
|
||||||
|
}
|
||||||
async fn execute(&self, args: Value) -> Result<ToolResult>;
|
async fn execute(&self, args: Value) -> Result<ToolResult>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::Result;
|
use crate::{Error, Result};
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
use super::{Tool, ToolResult};
|
use super::{
|
||||||
|
Tool, ToolResult, WEB_SEARCH_TOOL_NAME, canonical_tool_name, tool_identifier_violation,
|
||||||
|
};
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::mode::Mode;
|
use crate::mode::Mode;
|
||||||
use crate::ui::UiController;
|
use crate::ui::UiController;
|
||||||
@@ -25,13 +27,32 @@ impl ToolRegistry {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn register<T>(&mut self, tool: T)
|
pub fn register<T>(&mut self, tool: T) -> Result<()>
|
||||||
where
|
where
|
||||||
T: Tool + 'static,
|
T: Tool + 'static,
|
||||||
{
|
{
|
||||||
let tool: Arc<dyn Tool> = Arc::new(tool);
|
let tool: Arc<dyn Tool> = Arc::new(tool);
|
||||||
let name = tool.name().to_string();
|
let name = tool.name();
|
||||||
self.tools.insert(name, tool);
|
|
||||||
|
if let Some(reason) = tool_identifier_violation(name) {
|
||||||
|
log::error!("Tool '{}' failed validation: {}", name, reason);
|
||||||
|
return Err(Error::InvalidInput(format!(
|
||||||
|
"Tool '{name}' is not a valid MCP identifier: {reason}"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if self
|
||||||
|
.tools
|
||||||
|
.insert(name.to_string(), Arc::clone(&tool))
|
||||||
|
.is_some()
|
||||||
|
{
|
||||||
|
log::warn!(
|
||||||
|
"Tool '{}' was already registered; overwriting previous entry.",
|
||||||
|
name
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
|
pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
|
||||||
@@ -43,20 +64,25 @@ impl ToolRegistry {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn execute(&self, name: &str, args: Value, mode: Mode) -> Result<ToolResult> {
|
pub async fn execute(&self, name: &str, args: Value, mode: Mode) -> Result<ToolResult> {
|
||||||
|
let canonical = canonical_tool_name(name);
|
||||||
let tool = self
|
let tool = self
|
||||||
.get(name)
|
.get(canonical)
|
||||||
.with_context(|| format!("Tool not registered: {}", name))?;
|
.with_context(|| format!("Tool not registered: {}", name))?;
|
||||||
|
|
||||||
let mut config = self.config.lock().await;
|
let mut config = self.config.lock().await;
|
||||||
|
|
||||||
// Check mode-based tool availability first
|
// Check mode-based tool availability first
|
||||||
if !config.modes.is_tool_allowed(mode, name) {
|
if !(config.modes.is_tool_allowed(mode, canonical)
|
||||||
|
|| config.modes.is_tool_allowed(mode, name))
|
||||||
|
{
|
||||||
let alternate_mode = match mode {
|
let alternate_mode = match mode {
|
||||||
Mode::Chat => Mode::Code,
|
Mode::Chat => Mode::Code,
|
||||||
Mode::Code => Mode::Chat,
|
Mode::Code => Mode::Chat,
|
||||||
};
|
};
|
||||||
|
|
||||||
if config.modes.is_tool_allowed(alternate_mode, name) {
|
if config.modes.is_tool_allowed(alternate_mode, canonical)
|
||||||
|
|| config.modes.is_tool_allowed(alternate_mode, name)
|
||||||
|
{
|
||||||
return Ok(ToolResult::error(&format!(
|
return Ok(ToolResult::error(&format!(
|
||||||
"Tool '{}' is not available in {} mode. Switch to {} mode to use this tool (use :mode {} command).",
|
"Tool '{}' is not available in {} mode. Switch to {} mode to use this tool (use :mode {} command).",
|
||||||
name, mode, alternate_mode, alternate_mode
|
name, mode, alternate_mode, alternate_mode
|
||||||
@@ -69,8 +95,8 @@ impl ToolRegistry {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let is_enabled = match name {
|
let is_enabled = match canonical {
|
||||||
"web_search" => config.tools.web_search.enabled,
|
WEB_SEARCH_TOOL_NAME => config.tools.web_search.enabled,
|
||||||
"code_exec" => config.tools.code_exec.enabled,
|
"code_exec" => config.tools.code_exec.enabled,
|
||||||
_ => true, // All other tools are considered enabled by default
|
_ => true, // All other tools are considered enabled by default
|
||||||
};
|
};
|
||||||
@@ -82,8 +108,8 @@ impl ToolRegistry {
|
|||||||
);
|
);
|
||||||
if self.ui.confirm(&prompt).await {
|
if self.ui.confirm(&prompt).await {
|
||||||
// Enable the tool in the in-memory config for the current session
|
// Enable the tool in the in-memory config for the current session
|
||||||
match name {
|
match canonical {
|
||||||
"web_search" => config.tools.web_search.enabled = true,
|
WEB_SEARCH_TOOL_NAME => config.tools.web_search.enabled = true,
|
||||||
"code_exec" => config.tools.code_exec.enabled = true,
|
"code_exec" => config.tools.code_exec.enabled = true,
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
@@ -112,3 +138,69 @@ impl ToolRegistry {
|
|||||||
self.tools.keys().cloned().collect()
|
self.tools.keys().cloned().collect()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::tools::{Tool, ToolResult, WEB_SEARCH_TOOL_NAME};
|
||||||
|
use crate::ui::NoOpUiController;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::{Value, json};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
struct DummyTool {
|
||||||
|
name: &'static str,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for DummyTool {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
self.name
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &'static str {
|
||||||
|
"dummy tool"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn schema(&self) -> Value {
|
||||||
|
json!({ "type": "object" })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn aliases(&self) -> &'static [&'static str] {
|
||||||
|
self.aliases
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, _args: Value) -> Result<ToolResult> {
|
||||||
|
Ok(ToolResult::success(json!({ "echo": true })))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn registry() -> ToolRegistry {
|
||||||
|
let config = Arc::new(tokio::sync::Mutex::new(Config::default()));
|
||||||
|
let ui = Arc::new(NoOpUiController);
|
||||||
|
ToolRegistry::new(config, ui)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn rejects_invalid_tool_identifier() {
|
||||||
|
let mut registry = registry();
|
||||||
|
let tool = DummyTool {
|
||||||
|
name: "invalid.tool",
|
||||||
|
};
|
||||||
|
|
||||||
|
let err = registry.register(tool).unwrap_err();
|
||||||
|
assert!(matches!(err, Error::InvalidInput(_)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn registers_spec_compliant_tool() {
|
||||||
|
let mut registry = registry();
|
||||||
|
let tool = DummyTool {
|
||||||
|
name: WEB_SEARCH_TOOL_NAME,
|
||||||
|
};
|
||||||
|
|
||||||
|
registry.register(tool).unwrap();
|
||||||
|
assert!(registry.get(WEB_SEARCH_TOOL_NAME).is_some());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ use serde_json::{Value, json};
|
|||||||
|
|
||||||
use super::{Tool, ToolResult};
|
use super::{Tool, ToolResult};
|
||||||
use crate::consent::ConsentManager;
|
use crate::consent::ConsentManager;
|
||||||
|
use crate::tools::WEB_SEARCH_TOOL_NAME;
|
||||||
|
|
||||||
/// Configuration applied to the web search tool at registration time.
|
/// Configuration applied to the web search tool at registration time.
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
@@ -44,7 +45,7 @@ impl WebSearchTool {
|
|||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl Tool for WebSearchTool {
|
impl Tool for WebSearchTool {
|
||||||
fn name(&self) -> &'static str {
|
fn name(&self) -> &'static str {
|
||||||
"web_search"
|
WEB_SEARCH_TOOL_NAME
|
||||||
}
|
}
|
||||||
|
|
||||||
fn description(&self) -> &'static str {
|
fn description(&self) -> &'static str {
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ use anyhow::{Context, Result};
|
|||||||
use jsonschema::{JSONSchema, ValidationError};
|
use jsonschema::{JSONSchema, ValidationError};
|
||||||
use serde_json::{Value, json};
|
use serde_json::{Value, json};
|
||||||
|
|
||||||
|
use crate::tools::WEB_SEARCH_TOOL_NAME;
|
||||||
|
|
||||||
pub struct SchemaValidator {
|
pub struct SchemaValidator {
|
||||||
schemas: HashMap<String, JSONSchema>,
|
schemas: HashMap<String, JSONSchema>,
|
||||||
}
|
}
|
||||||
@@ -56,9 +58,7 @@ fn format_validation_error(error: ValidationError) -> String {
|
|||||||
pub fn get_builtin_schemas() -> HashMap<String, Value> {
|
pub fn get_builtin_schemas() -> HashMap<String, Value> {
|
||||||
let mut schemas = HashMap::new();
|
let mut schemas = HashMap::new();
|
||||||
|
|
||||||
schemas.insert(
|
let web_search_schema = json!({
|
||||||
"web_search".to_string(),
|
|
||||||
json!({
|
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
"query": {
|
"query": {
|
||||||
@@ -75,8 +75,9 @@ pub fn get_builtin_schemas() -> HashMap<String, Value> {
|
|||||||
},
|
},
|
||||||
"required": ["query"],
|
"required": ["query"],
|
||||||
"additionalProperties": false
|
"additionalProperties": false
|
||||||
}),
|
});
|
||||||
);
|
|
||||||
|
schemas.insert(WEB_SEARCH_TOOL_NAME.to_string(), web_search_schema.clone());
|
||||||
|
|
||||||
schemas.insert(
|
schemas.insert(
|
||||||
"code_exec".to_string(),
|
"code_exec".to_string(),
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ use std::{any::Any, collections::HashMap, sync::Arc};
|
|||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
|
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
|
||||||
use owlen_core::{
|
use owlen_core::{
|
||||||
Config, Error, Mode, Provider,
|
Config, Error, Mode, Provider,
|
||||||
config::McpMode,
|
config::McpMode,
|
||||||
@@ -88,7 +89,7 @@ impl Provider for StreamingToolProvider {
|
|||||||
|
|
||||||
fn tool_descriptor() -> McpToolDescriptor {
|
fn tool_descriptor() -> McpToolDescriptor {
|
||||||
McpToolDescriptor {
|
McpToolDescriptor {
|
||||||
name: "web_search".to_string(),
|
name: WEB_SEARCH_TOOL_NAME.to_string(),
|
||||||
description: "search".to_string(),
|
description: "search".to_string(),
|
||||||
input_schema: serde_json::json!({"type": "object"}),
|
input_schema: serde_json::json!({"type": "object"}),
|
||||||
requires_network: true,
|
requires_network: true,
|
||||||
@@ -123,7 +124,7 @@ impl CachedResponseClient {
|
|||||||
metadata.insert("cached".to_string(), "true".to_string());
|
metadata.insert("cached".to_string(), "true".to_string());
|
||||||
|
|
||||||
let response = McpToolResponse {
|
let response = McpToolResponse {
|
||||||
name: "web_search".to_string(),
|
name: WEB_SEARCH_TOOL_NAME.to_string(),
|
||||||
success: true,
|
success: true,
|
||||||
output: serde_json::json!({
|
output: serde_json::json!({
|
||||||
"query": "rust",
|
"query": "rust",
|
||||||
@@ -286,13 +287,13 @@ async fn web_tool_timeout_fails_over_to_cached_result() {
|
|||||||
]);
|
]);
|
||||||
|
|
||||||
let call = McpToolCall {
|
let call = McpToolCall {
|
||||||
name: "web_search".to_string(),
|
name: WEB_SEARCH_TOOL_NAME.to_string(),
|
||||||
arguments: serde_json::json!({ "query": "rust", "max_results": 3 }),
|
arguments: serde_json::json!({ "query": "rust", "max_results": 3 }),
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = client.call_tool(call.clone()).await.expect("fallback");
|
let response = client.call_tool(call.clone()).await.expect("fallback");
|
||||||
|
|
||||||
assert_eq!(response.name, "web_search");
|
assert!(tool_name_matches(&response.name, WEB_SEARCH_TOOL_NAME));
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
response.metadata.get("source").map(String::as_str),
|
response.metadata.get("source").map(String::as_str),
|
||||||
Some("cache")
|
Some("cache")
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use owlen_core::consent::{ConsentManager, ConsentScope};
|
use owlen_core::consent::{ConsentManager, ConsentScope};
|
||||||
|
use owlen_core::tools::WEB_SEARCH_TOOL_NAME;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_consent_scopes() {
|
fn test_consent_scopes() {
|
||||||
@@ -43,23 +44,23 @@ fn test_pending_requests_prevents_duplicates() {
|
|||||||
// In real usage, multiple threads would call request_consent simultaneously
|
// In real usage, multiple threads would call request_consent simultaneously
|
||||||
|
|
||||||
// First, verify a tool has no consent
|
// First, verify a tool has no consent
|
||||||
assert!(!manager.has_consent("web_search"));
|
assert!(!manager.has_consent(WEB_SEARCH_TOOL_NAME));
|
||||||
|
|
||||||
// The pending_requests map is private, but we can test the behavior
|
// The pending_requests map is private, but we can test the behavior
|
||||||
// by checking that consent checks work correctly
|
// by checking that consent checks work correctly
|
||||||
assert!(manager.check_consent_needed("web_search").is_some());
|
assert!(manager.check_consent_needed(WEB_SEARCH_TOOL_NAME).is_some());
|
||||||
|
|
||||||
// Grant session consent
|
// Grant session consent
|
||||||
manager.grant_consent_with_scope(
|
manager.grant_consent_with_scope(
|
||||||
"web_search",
|
WEB_SEARCH_TOOL_NAME,
|
||||||
vec!["search queries".to_string()],
|
vec!["search queries".to_string()],
|
||||||
vec!["https://api.search.com".to_string()],
|
vec!["https://api.search.com".to_string()],
|
||||||
ConsentScope::Session,
|
ConsentScope::Session,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Now it should have consent
|
// Now it should have consent
|
||||||
assert!(manager.has_consent("web_search"));
|
assert!(manager.has_consent(WEB_SEARCH_TOOL_NAME));
|
||||||
assert!(manager.check_consent_needed("web_search").is_none());
|
assert!(manager.check_consent_needed(WEB_SEARCH_TOOL_NAME).is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ use std::{any::Any, collections::HashMap, sync::Arc};
|
|||||||
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures::stream;
|
use futures::stream;
|
||||||
|
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
|
||||||
use owlen_core::{
|
use owlen_core::{
|
||||||
ChatStream, Provider, Result,
|
ChatStream, Provider, Result,
|
||||||
config::Config,
|
config::Config,
|
||||||
@@ -96,12 +97,12 @@ async fn toggling_web_search_updates_config_and_registry() {
|
|||||||
.tool_registry()
|
.tool_registry()
|
||||||
.tools()
|
.tools()
|
||||||
.iter()
|
.iter()
|
||||||
.any(|tool| tool == "web_search"),
|
.any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)),
|
||||||
"web_search should be disabled by default"
|
"web_search should be disabled by default"
|
||||||
);
|
);
|
||||||
|
|
||||||
session
|
session
|
||||||
.set_tool_enabled("web_search", true)
|
.set_tool_enabled(WEB_SEARCH_TOOL_NAME, true)
|
||||||
.await
|
.await
|
||||||
.expect("enable web_search");
|
.expect("enable web_search");
|
||||||
|
|
||||||
@@ -115,12 +116,12 @@ async fn toggling_web_search_updates_config_and_registry() {
|
|||||||
.tool_registry()
|
.tool_registry()
|
||||||
.tools()
|
.tools()
|
||||||
.iter()
|
.iter()
|
||||||
.any(|tool| tool == "web_search"),
|
.any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)),
|
||||||
"web_search should be registered when enabled"
|
"web_search should be registered when enabled"
|
||||||
);
|
);
|
||||||
|
|
||||||
session
|
session
|
||||||
.set_tool_enabled("web_search", false)
|
.set_tool_enabled(WEB_SEARCH_TOOL_NAME, false)
|
||||||
.await
|
.await
|
||||||
.expect("disable web_search");
|
.expect("disable web_search");
|
||||||
|
|
||||||
@@ -134,7 +135,7 @@ async fn toggling_web_search_updates_config_and_registry() {
|
|||||||
.tool_registry()
|
.tool_registry()
|
||||||
.tools()
|
.tools()
|
||||||
.iter()
|
.iter()
|
||||||
.any(|tool| tool == "web_search"),
|
.any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME)),
|
||||||
"web_search should be removed when disabled"
|
"web_search should be removed when disabled"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user