- Reject dotted tool identifiers during registration and remove alias-backed lookups. - Drop web.search compatibility, normalize all code/tests around the canonical web_search name, and update consent/session logic. - Harden CLI toggles to manage the spec-compliant identifier and ensure MCP configs shed non-compliant entries automatically. Acceptance Criteria: - Tool registry denies invalid identifiers by default and no alias codepaths remain. Test Notes: - cargo check -p owlen-core (tests unavailable in sandbox).
312 lines
9.4 KiB
Rust
312 lines
9.4 KiB
Rust
use std::{any::Any, collections::HashMap, sync::Arc};
|
|
|
|
use async_trait::async_trait;
|
|
use futures::StreamExt;
|
|
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
|
|
use owlen_core::{
|
|
Config, Error, Mode, Provider,
|
|
config::McpMode,
|
|
consent::ConsentScope,
|
|
mcp::{
|
|
McpClient, McpToolCall, McpToolDescriptor, McpToolResponse,
|
|
failover::{FailoverMcpClient, ServerEntry},
|
|
},
|
|
session::{ControllerEvent, SessionController, SessionOutcome},
|
|
storage::StorageManager,
|
|
types::{ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role, ToolCall},
|
|
ui::NoOpUiController,
|
|
};
|
|
use tempfile::tempdir;
|
|
use tokio::sync::mpsc;
|
|
|
|
struct StreamingToolProvider;
|
|
|
|
#[async_trait]
|
|
impl Provider for StreamingToolProvider {
|
|
fn name(&self) -> &str {
|
|
"mock-streaming-provider"
|
|
}
|
|
|
|
async fn list_models(&self) -> owlen_core::Result<Vec<ModelInfo>> {
|
|
Ok(vec![ModelInfo {
|
|
id: "mock-model".into(),
|
|
name: "Mock Model".into(),
|
|
description: Some("A mock model that emits tool calls".into()),
|
|
provider: self.name().into(),
|
|
context_window: Some(4096),
|
|
capabilities: vec!["chat".into(), "tools".into()],
|
|
supports_tools: true,
|
|
}])
|
|
}
|
|
|
|
async fn send_prompt(&self, _request: ChatRequest) -> owlen_core::Result<ChatResponse> {
|
|
let mut message = Message::assistant("tool-call".to_string());
|
|
message.tool_calls = Some(vec![ToolCall {
|
|
id: "call-1".to_string(),
|
|
name: "resources/write".to_string(),
|
|
arguments: serde_json::json!({"path": "README.md", "content": "hello"}),
|
|
}]);
|
|
|
|
Ok(ChatResponse {
|
|
message,
|
|
usage: None,
|
|
is_streaming: false,
|
|
is_final: true,
|
|
})
|
|
}
|
|
|
|
async fn stream_prompt(
|
|
&self,
|
|
_request: ChatRequest,
|
|
) -> owlen_core::Result<owlen_core::ChatStream> {
|
|
let mut first_chunk = Message::assistant(
|
|
"Thought: need to update README.\nAction: resources/write".to_string(),
|
|
);
|
|
first_chunk.tool_calls = Some(vec![ToolCall {
|
|
id: "call-1".to_string(),
|
|
name: "resources/write".to_string(),
|
|
arguments: serde_json::json!({"path": "README.md", "content": "hello"}),
|
|
}]);
|
|
|
|
let chunk = ChatResponse {
|
|
message: first_chunk,
|
|
usage: None,
|
|
is_streaming: true,
|
|
is_final: false,
|
|
};
|
|
|
|
Ok(Box::pin(futures::stream::iter(vec![Ok(chunk)])))
|
|
}
|
|
|
|
async fn health_check(&self) -> owlen_core::Result<()> {
|
|
Ok(())
|
|
}
|
|
|
|
fn as_any(&self) -> &(dyn Any + Send + Sync) {
|
|
self
|
|
}
|
|
}
|
|
|
|
fn tool_descriptor() -> McpToolDescriptor {
|
|
McpToolDescriptor {
|
|
name: WEB_SEARCH_TOOL_NAME.to_string(),
|
|
description: "search".to_string(),
|
|
input_schema: serde_json::json!({"type": "object"}),
|
|
requires_network: true,
|
|
requires_filesystem: vec![],
|
|
}
|
|
}
|
|
|
|
struct TimeoutClient;
|
|
|
|
#[async_trait]
|
|
impl McpClient for TimeoutClient {
|
|
async fn list_tools(&self) -> owlen_core::Result<Vec<McpToolDescriptor>> {
|
|
Ok(vec![tool_descriptor()])
|
|
}
|
|
|
|
async fn call_tool(&self, _call: McpToolCall) -> owlen_core::Result<McpToolResponse> {
|
|
Err(Error::Network(
|
|
"timeout while contacting remote web search endpoint".into(),
|
|
))
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
struct CachedResponseClient {
|
|
response: Arc<McpToolResponse>,
|
|
}
|
|
|
|
impl CachedResponseClient {
|
|
fn new() -> Self {
|
|
let mut metadata = HashMap::new();
|
|
metadata.insert("source".to_string(), "cache".to_string());
|
|
metadata.insert("cached".to_string(), "true".to_string());
|
|
|
|
let response = McpToolResponse {
|
|
name: WEB_SEARCH_TOOL_NAME.to_string(),
|
|
success: true,
|
|
output: serde_json::json!({
|
|
"query": "rust",
|
|
"results": [
|
|
{"title": "Rust Programming Language", "url": "https://www.rust-lang.org"}
|
|
],
|
|
"note": "cached result"
|
|
}),
|
|
metadata,
|
|
duration_ms: 0,
|
|
};
|
|
|
|
Self {
|
|
response: Arc::new(response),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl McpClient for CachedResponseClient {
|
|
async fn list_tools(&self) -> owlen_core::Result<Vec<McpToolDescriptor>> {
|
|
Ok(vec![tool_descriptor()])
|
|
}
|
|
|
|
async fn call_tool(&self, _call: McpToolCall) -> owlen_core::Result<McpToolResponse> {
|
|
Ok((*self.response).clone())
|
|
}
|
|
}
|
|
|
|
#[tokio::test(flavor = "multi_thread")]
|
|
async fn streaming_file_write_consent_denied_returns_resolution() {
|
|
let temp_dir = tempdir().expect("temp dir");
|
|
let storage = StorageManager::with_database_path(temp_dir.path().join("owlen-tests.db"))
|
|
.await
|
|
.expect("storage");
|
|
|
|
let mut config = Config::default();
|
|
config.general.enable_streaming = true;
|
|
config.privacy.encrypt_local_data = false;
|
|
config.privacy.require_consent_per_session = true;
|
|
config.general.default_model = Some("mock-model".into());
|
|
config.mcp.mode = McpMode::LocalOnly;
|
|
config
|
|
.refresh_mcp_servers(None)
|
|
.expect("refresh MCP servers");
|
|
|
|
let provider: Arc<dyn Provider> = Arc::new(StreamingToolProvider);
|
|
let ui = Arc::new(NoOpUiController);
|
|
let (event_tx, mut event_rx) = mpsc::unbounded_channel::<ControllerEvent>();
|
|
|
|
let mut session = SessionController::new(
|
|
provider,
|
|
config,
|
|
Arc::new(storage),
|
|
ui,
|
|
true,
|
|
Some(event_tx),
|
|
)
|
|
.await
|
|
.expect("session controller");
|
|
|
|
session
|
|
.set_operating_mode(Mode::Code)
|
|
.await
|
|
.expect("code mode");
|
|
|
|
let outcome = session
|
|
.send_message(
|
|
"Please write to README".to_string(),
|
|
ChatParameters {
|
|
stream: true,
|
|
..Default::default()
|
|
},
|
|
)
|
|
.await
|
|
.expect("send message");
|
|
|
|
let (response_id, mut stream) = if let SessionOutcome::Streaming {
|
|
response_id,
|
|
stream,
|
|
} = outcome
|
|
{
|
|
(response_id, stream)
|
|
} else {
|
|
panic!("expected streaming outcome");
|
|
};
|
|
|
|
session
|
|
.mark_stream_placeholder(response_id, "▌")
|
|
.expect("placeholder");
|
|
|
|
let chunk = stream
|
|
.next()
|
|
.await
|
|
.expect("stream chunk")
|
|
.expect("chunk result");
|
|
session
|
|
.apply_stream_chunk(response_id, &chunk)
|
|
.expect("apply chunk");
|
|
|
|
let tool_calls = session
|
|
.check_streaming_tool_calls(response_id)
|
|
.expect("tool calls");
|
|
assert_eq!(tool_calls.len(), 1);
|
|
assert_eq!(tool_calls[0].name, "resources/write");
|
|
|
|
let event = event_rx.recv().await.expect("controller event");
|
|
let request_id = match event {
|
|
ControllerEvent::ToolRequested {
|
|
request_id,
|
|
tool_name,
|
|
data_types,
|
|
endpoints,
|
|
..
|
|
} => {
|
|
assert_eq!(tool_name, "resources/write");
|
|
assert!(data_types.iter().any(|t| t.contains("file")));
|
|
assert!(endpoints.iter().any(|e| e.contains("filesystem")));
|
|
request_id
|
|
}
|
|
};
|
|
|
|
let resolution = session
|
|
.resolve_tool_consent(request_id, ConsentScope::Denied)
|
|
.expect("resolution");
|
|
assert_eq!(resolution.scope, ConsentScope::Denied);
|
|
assert_eq!(resolution.tool_name, "resources/write");
|
|
assert_eq!(resolution.tool_calls.len(), tool_calls.len());
|
|
|
|
let err = session
|
|
.resolve_tool_consent(request_id, ConsentScope::Denied)
|
|
.expect_err("second resolution should fail");
|
|
matches!(err, Error::InvalidInput(_));
|
|
|
|
let conversation = session.conversation().clone();
|
|
let assistant = conversation
|
|
.messages
|
|
.iter()
|
|
.find(|message| message.role == Role::Assistant)
|
|
.expect("assistant message present");
|
|
assert!(
|
|
assistant
|
|
.tool_calls
|
|
.as_ref()
|
|
.and_then(|calls| calls.first())
|
|
.is_some_and(|call| call.name == "resources/write"),
|
|
"stream chunk should capture the tool call on the assistant message"
|
|
);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn web_tool_timeout_fails_over_to_cached_result() {
|
|
let primary: Arc<dyn McpClient> = Arc::new(TimeoutClient);
|
|
let cached = CachedResponseClient::new();
|
|
let backup: Arc<dyn McpClient> = Arc::new(cached.clone());
|
|
|
|
let client = FailoverMcpClient::with_servers(vec![
|
|
ServerEntry::new("primary".into(), primary, 1),
|
|
ServerEntry::new("cache".into(), backup, 2),
|
|
]);
|
|
|
|
let call = McpToolCall {
|
|
name: WEB_SEARCH_TOOL_NAME.to_string(),
|
|
arguments: serde_json::json!({ "query": "rust", "max_results": 3 }),
|
|
};
|
|
|
|
let response = client.call_tool(call.clone()).await.expect("fallback");
|
|
|
|
assert!(tool_name_matches(&response.name, WEB_SEARCH_TOOL_NAME));
|
|
assert_eq!(
|
|
response.metadata.get("source").map(String::as_str),
|
|
Some("cache")
|
|
);
|
|
assert_eq!(
|
|
response.output.get("note").and_then(|value| value.as_str()),
|
|
Some("cached result")
|
|
);
|
|
|
|
let statuses = client.get_server_status().await;
|
|
assert!(statuses.iter().any(|(name, health)| name == "primary"
|
|
&& !matches!(health, owlen_core::mcp::failover::ServerHealth::Healthy)));
|
|
assert!(statuses.iter().any(|(name, health)| name == "cache"
|
|
&& matches!(health, owlen_core::mcp::failover::ServerHealth::Healthy)));
|
|
}
|