diff --git a/Cargo.toml b/Cargo.toml index eaec354..aa33d81 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "crates/tools/bash", "crates/tools/fs", "crates/tools/slash", + "crates/integration/mcp-client", ] resolver = "2" diff --git a/crates/integration/mcp-client/Cargo.toml b/crates/integration/mcp-client/Cargo.toml new file mode 100644 index 0000000..ddf5c1a --- /dev/null +++ b/crates/integration/mcp-client/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "mcp-client" +version = "0.1.0" +edition.workspace = true +license.workspace = true +rust-version.workspace = true + +[dependencies] +serde = { version = "1", features = ["derive"] } +serde_json = "1" +tokio = { version = "1.39", features = ["process", "io-util", "sync", "time"] } +color-eyre = "0.6" + +[dev-dependencies] +tempfile = "3.23.0" +tokio = { version = "1.39", features = ["macros", "rt-multi-thread"] } diff --git a/crates/integration/mcp-client/src/lib.rs b/crates/integration/mcp-client/src/lib.rs new file mode 100644 index 0000000..7cdfea6 --- /dev/null +++ b/crates/integration/mcp-client/src/lib.rs @@ -0,0 +1,272 @@ +use color_eyre::eyre::{Result, eyre}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::process::Stdio; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::process::{Child, Command}; +use tokio::sync::Mutex; + +/// JSON-RPC 2.0 request +#[derive(Debug, Serialize)] +struct JsonRpcRequest { + jsonrpc: String, + id: u64, + method: String, + #[serde(skip_serializing_if = "Option::is_none")] + params: Option, +} + +/// JSON-RPC 2.0 response +#[derive(Debug, Deserialize)] +struct JsonRpcResponse { + jsonrpc: String, + id: u64, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, +} + +#[derive(Debug, Deserialize)] +struct JsonRpcError { + code: i32, + message: String, +} + +/// MCP server capabilities +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ServerCapabilities { + #[serde(default)] + pub tools: Option, + #[serde(default)] + pub resources: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ToolsCapability { + #[serde(default)] + pub list_changed: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ResourcesCapability { + #[serde(default)] + pub subscribe: Option, + #[serde(default)] + pub list_changed: Option, +} + +/// MCP Tool definition +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct McpTool { + pub name: String, + #[serde(default)] + pub description: Option, + #[serde(default)] + pub input_schema: Option, +} + +/// MCP Resource definition +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct McpResource { + pub uri: String, + #[serde(default)] + pub name: Option, + #[serde(default)] + pub description: Option, + #[serde(default)] + pub mime_type: Option, +} + +/// MCP Client over stdio transport +pub struct McpClient { + process: Mutex, + next_id: Mutex, + server_name: String, +} + +impl McpClient { + /// Create a new MCP client by spawning a subprocess + pub async fn spawn(command: &str, args: &[&str], server_name: &str) -> Result { + let mut child = Command::new(command) + .args(args) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn()?; + + // Verify process is running + if child.try_wait()?.is_some() { + return Err(eyre!("MCP server process exited immediately")); + } + + Ok(Self { + process: Mutex::new(child), + next_id: Mutex::new(1), + server_name: server_name.to_string(), + }) + } + + /// Initialize the MCP connection + pub async fn initialize(&self) -> Result { + let params = serde_json::json!({ + "protocolVersion": "2024-11-05", + "capabilities": { + "roots": { + "listChanged": true + } + }, + "clientInfo": { + "name": "owlen", + "version": env!("CARGO_PKG_VERSION") + } + }); + + let response = self.send_request("initialize", Some(params)).await?; + + let capabilities = response + .get("capabilities") + .ok_or_else(|| eyre!("No capabilities in initialize response"))?; + + Ok(serde_json::from_value(capabilities.clone())?) + } + + /// List available tools + pub async fn list_tools(&self) -> Result> { + let response = self.send_request("tools/list", None).await?; + + let tools = response + .get("tools") + .ok_or_else(|| eyre!("No tools in response"))?; + + Ok(serde_json::from_value(tools.clone())?) + } + + /// Call a tool + pub async fn call_tool(&self, name: &str, arguments: Value) -> Result { + let params = serde_json::json!({ + "name": name, + "arguments": arguments + }); + + let response = self.send_request("tools/call", Some(params)).await?; + + response + .get("content") + .cloned() + .ok_or_else(|| eyre!("No content in tool call response")) + } + + /// List available resources + pub async fn list_resources(&self) -> Result> { + let response = self.send_request("resources/list", None).await?; + + let resources = response + .get("resources") + .ok_or_else(|| eyre!("No resources in response"))?; + + Ok(serde_json::from_value(resources.clone())?) + } + + /// Read a resource + pub async fn read_resource(&self, uri: &str) -> Result { + let params = serde_json::json!({ + "uri": uri + }); + + let response = self.send_request("resources/read", Some(params)).await?; + + response + .get("contents") + .cloned() + .ok_or_else(|| eyre!("No contents in resource read response")) + } + + /// Get the server name + pub fn server_name(&self) -> &str { + &self.server_name + } + + /// Send a JSON-RPC request and get the response + async fn send_request(&self, method: &str, params: Option) -> Result { + let mut next_id = self.next_id.lock().await; + let id = *next_id; + *next_id += 1; + drop(next_id); + + let request = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id, + method: method.to_string(), + params, + }; + + let request_json = serde_json::to_string(&request)?; + + let mut process = self.process.lock().await; + + // Write request + let stdin = process.stdin.as_mut().ok_or_else(|| eyre!("No stdin"))?; + stdin.write_all(request_json.as_bytes()).await?; + stdin.write_all(b"\n").await?; + stdin.flush().await?; + + // Read response + let stdout = process.stdout.take().ok_or_else(|| eyre!("No stdout"))?; + let mut reader = BufReader::new(stdout); + let mut response_line = String::new(); + reader.read_line(&mut response_line).await?; + + // Put stdout back + process.stdout = Some(reader.into_inner()); + + drop(process); + + let response: JsonRpcResponse = serde_json::from_str(&response_line)?; + + if response.id != id { + return Err(eyre!("Response ID mismatch: expected {}, got {}", id, response.id)); + } + + if let Some(error) = response.error { + return Err(eyre!("MCP error {}: {}", error.code, error.message)); + } + + response.result.ok_or_else(|| eyre!("No result in response")) + } + + /// Close the MCP connection + pub async fn close(self) -> Result<()> { + let mut process = self.process.into_inner(); + + // Close stdin to signal the server to exit + drop(process.stdin.take()); + + // Wait for process to exit (with timeout) + tokio::time::timeout( + std::time::Duration::from_secs(5), + process.wait() + ).await??; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn jsonrpc_request_serializes() { + let req = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: 1, + method: "test".to_string(), + params: Some(serde_json::json!({"key": "value"})), + }; + + let json = serde_json::to_string(&req).unwrap(); + assert!(json.contains("\"method\":\"test\"")); + assert!(json.contains("\"id\":1")); + } +} diff --git a/crates/integration/mcp-client/tests/mcp.rs b/crates/integration/mcp-client/tests/mcp.rs new file mode 100644 index 0000000..77e1b41 --- /dev/null +++ b/crates/integration/mcp-client/tests/mcp.rs @@ -0,0 +1,347 @@ +use mcp_client::McpClient; +use std::fs; +use tempfile::tempdir; + +#[tokio::test] +async fn mcp_server_capability_negotiation() { + // Create a mock MCP server script + let dir = tempdir().unwrap(); + let server_script = dir.path().join("mock_server.py"); + + let script_content = r#"#!/usr/bin/env python3 +import sys +import json + +def read_request(): + line = sys.stdin.readline() + return json.loads(line) + +def send_response(response): + sys.stdout.write(json.dumps(response) + '\n') + sys.stdout.flush() + +# Main loop +while True: + try: + req = read_request() + method = req.get('method') + req_id = req.get('id') + + if method == 'initialize': + send_response({ + 'jsonrpc': '2.0', + 'id': req_id, + 'result': { + 'protocolVersion': '2024-11-05', + 'capabilities': { + 'tools': {'list_changed': True}, + 'resources': {'subscribe': False} + }, + 'serverInfo': { + 'name': 'test-server', + 'version': '1.0.0' + } + } + }) + elif method == 'tools/list': + send_response({ + 'jsonrpc': '2.0', + 'id': req_id, + 'result': { + 'tools': [] + } + }) + else: + send_response({ + 'jsonrpc': '2.0', + 'id': req_id, + 'error': { + 'code': -32601, + 'message': f'Method not found: {method}' + } + }) + except EOFError: + break + except Exception as e: + sys.stderr.write(f'Error: {e}\n') + break +"#; + + fs::write(&server_script, script_content).unwrap(); + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + fs::set_permissions(&server_script, std::fs::Permissions::from_mode(0o755)).unwrap(); + } + + // Connect to the server + let client = McpClient::spawn( + "python3", + &[server_script.to_str().unwrap()], + "test-server" + ).await.unwrap(); + + // Initialize + let capabilities = client.initialize().await.unwrap(); + + // Verify capabilities + assert!(capabilities.tools.is_some()); + assert_eq!(capabilities.tools.unwrap().list_changed, Some(true)); + + client.close().await.unwrap(); +} + +#[tokio::test] +async fn mcp_tool_invocation() { + let dir = tempdir().unwrap(); + let server_script = dir.path().join("mock_server.py"); + + let script_content = r#"#!/usr/bin/env python3 +import sys +import json + +def read_request(): + line = sys.stdin.readline() + return json.loads(line) + +def send_response(response): + sys.stdout.write(json.dumps(response) + '\n') + sys.stdout.flush() + +while True: + try: + req = read_request() + method = req.get('method') + req_id = req.get('id') + params = req.get('params', {}) + + if method == 'initialize': + send_response({ + 'jsonrpc': '2.0', + 'id': req_id, + 'result': { + 'protocolVersion': '2024-11-05', + 'capabilities': { + 'tools': {} + }, + 'serverInfo': { + 'name': 'test-server', + 'version': '1.0.0' + } + } + }) + elif method == 'tools/list': + send_response({ + 'jsonrpc': '2.0', + 'id': req_id, + 'result': { + 'tools': [ + { + 'name': 'echo', + 'description': 'Echo the input', + 'input_schema': { + 'type': 'object', + 'properties': { + 'message': {'type': 'string'} + } + } + } + ] + } + }) + elif method == 'tools/call': + tool_name = params.get('name') + arguments = params.get('arguments', {}) + if tool_name == 'echo': + send_response({ + 'jsonrpc': '2.0', + 'id': req_id, + 'result': { + 'content': [ + { + 'type': 'text', + 'text': arguments.get('message', '') + } + ] + } + }) + else: + send_response({ + 'jsonrpc': '2.0', + 'id': req_id, + 'error': { + 'code': -32602, + 'message': f'Unknown tool: {tool_name}' + } + }) + else: + send_response({ + 'jsonrpc': '2.0', + 'id': req_id, + 'error': { + 'code': -32601, + 'message': f'Method not found: {method}' + } + }) + except EOFError: + break + except Exception as e: + sys.stderr.write(f'Error: {e}\n') + break +"#; + + fs::write(&server_script, script_content).unwrap(); + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + fs::set_permissions(&server_script, std::fs::Permissions::from_mode(0o755)).unwrap(); + } + + let client = McpClient::spawn( + "python3", + &[server_script.to_str().unwrap()], + "test-server" + ).await.unwrap(); + + client.initialize().await.unwrap(); + + // List tools + let tools = client.list_tools().await.unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].name, "echo"); + + // Call tool + let result = client.call_tool( + "echo", + serde_json::json!({"message": "Hello, MCP!"}) + ).await.unwrap(); + + // Verify result + let content = result.as_array().unwrap(); + assert_eq!(content[0]["text"].as_str().unwrap(), "Hello, MCP!"); + + client.close().await.unwrap(); +} + +#[tokio::test] +async fn mcp_resource_reads() { + let dir = tempdir().unwrap(); + let server_script = dir.path().join("mock_server.py"); + + let script_content = r#"#!/usr/bin/env python3 +import sys +import json + +def read_request(): + line = sys.stdin.readline() + return json.loads(line) + +def send_response(response): + sys.stdout.write(json.dumps(response) + '\n') + sys.stdout.flush() + +while True: + try: + req = read_request() + method = req.get('method') + req_id = req.get('id') + params = req.get('params', {}) + + if method == 'initialize': + send_response({ + 'jsonrpc': '2.0', + 'id': req_id, + 'result': { + 'protocolVersion': '2024-11-05', + 'capabilities': { + 'resources': {} + }, + 'serverInfo': { + 'name': 'test-server', + 'version': '1.0.0' + } + } + }) + elif method == 'resources/list': + send_response({ + 'jsonrpc': '2.0', + 'id': req_id, + 'result': { + 'resources': [ + { + 'uri': 'file:///test.txt', + 'name': 'Test File', + 'description': 'A test file', + 'mime_type': 'text/plain' + } + ] + } + }) + elif method == 'resources/read': + uri = params.get('uri') + if uri == 'file:///test.txt': + send_response({ + 'jsonrpc': '2.0', + 'id': req_id, + 'result': { + 'contents': [ + { + 'uri': uri, + 'mime_type': 'text/plain', + 'text': 'Hello from resource!' + } + ] + } + }) + else: + send_response({ + 'jsonrpc': '2.0', + 'id': req_id, + 'error': { + 'code': -32602, + 'message': f'Unknown resource: {uri}' + } + }) + else: + send_response({ + 'jsonrpc': '2.0', + 'id': req_id, + 'error': { + 'code': -32601, + 'message': f'Method not found: {method}' + } + }) + except EOFError: + break + except Exception as e: + sys.stderr.write(f'Error: {e}\n') + break +"#; + + fs::write(&server_script, script_content).unwrap(); + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + fs::set_permissions(&server_script, std::fs::Permissions::from_mode(0o755)).unwrap(); + } + + let client = McpClient::spawn( + "python3", + &[server_script.to_str().unwrap()], + "test-server" + ).await.unwrap(); + + client.initialize().await.unwrap(); + + // List resources + let resources = client.list_resources().await.unwrap(); + assert_eq!(resources.len(), 1); + assert_eq!(resources[0].uri, "file:///test.txt"); + + // Read resource + let contents = client.read_resource("file:///test.txt").await.unwrap(); + let contents_array = contents.as_array().unwrap(); + assert_eq!(contents_array[0]["text"].as_str().unwrap(), "Hello from resource!"); + + client.close().await.unwrap(); +} diff --git a/crates/platform/permissions/src/lib.rs b/crates/platform/permissions/src/lib.rs index 8c8dfd2..02728f5 100644 --- a/crates/platform/permissions/src/lib.rs +++ b/crates/platform/permissions/src/lib.rs @@ -15,6 +15,7 @@ pub enum Tool { SlashCommand, Task, TodoWrite, + Mcp, } #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] @@ -136,7 +137,7 @@ impl PermissionManager { // Edit/Write operations allowed Tool::Edit | Tool::Write | Tool::NotebookEdit => PermissionDecision::Allow, // Bash and other dangerous operations still require asking - Tool::Bash | Tool::WebFetch | Tool::WebSearch => PermissionDecision::Ask, + Tool::Bash | Tool::WebFetch | Tool::WebSearch | Tool::Mcp => PermissionDecision::Ask, // Utility tools allowed Tool::TodoWrite | Tool::SlashCommand | Tool::Task => PermissionDecision::Allow, }, @@ -209,4 +210,31 @@ mod tests { assert!(rule.matches(Tool::Read, Some("any context"))); assert!(rule.matches(Tool::Read, None)); } + + #[test] + fn mcp_server_pattern_matching() { + // Allow all tools from a specific server + let rule = PermissionRule { + tool: Tool::Mcp, + pattern: Some("filesystem__*".to_string()), + action: Action::Allow, + }; + + assert!(rule.matches(Tool::Mcp, Some("filesystem__read_file"))); + assert!(rule.matches(Tool::Mcp, Some("filesystem__write_file"))); + assert!(!rule.matches(Tool::Mcp, Some("database__query"))); + } + + #[test] + fn mcp_exact_tool_matching() { + // Allow only a specific tool from a server + let rule = PermissionRule { + tool: Tool::Mcp, + pattern: Some("filesystem__read_file".to_string()), + action: Action::Allow, + }; + + assert!(rule.matches(Tool::Mcp, Some("filesystem__read_file"))); + assert!(!rule.matches(Tool::Mcp, Some("filesystem__write_file"))); + } }