188 lines
5.3 KiB
Rust
188 lines
5.3 KiB
Rust
use crate::Result;
|
|
use crate::mode::Mode;
|
|
use crate::tools::registry::ToolRegistry;
|
|
use crate::validation::SchemaValidator;
|
|
use async_trait::async_trait;
|
|
pub use client::McpClient;
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::Value;
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
pub mod client;
|
|
pub mod factory;
|
|
pub mod failover;
|
|
pub mod permission;
|
|
pub mod protocol;
|
|
pub mod remote_client;
|
|
|
|
/// Descriptor for a tool exposed over MCP
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct McpToolDescriptor {
|
|
pub name: String,
|
|
pub description: String,
|
|
pub input_schema: Value,
|
|
pub requires_network: bool,
|
|
pub requires_filesystem: Vec<String>,
|
|
}
|
|
|
|
/// Invocation payload for a tool call
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct McpToolCall {
|
|
pub name: String,
|
|
pub arguments: Value,
|
|
}
|
|
|
|
/// Result returned by a tool invocation
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct McpToolResponse {
|
|
pub name: String,
|
|
pub success: bool,
|
|
pub output: Value,
|
|
pub metadata: HashMap<String, String>,
|
|
pub duration_ms: u128,
|
|
}
|
|
|
|
/// Thin MCP server facade over the tool registry
|
|
pub struct McpServer {
|
|
registry: Arc<ToolRegistry>,
|
|
validator: Arc<SchemaValidator>,
|
|
mode: Arc<tokio::sync::RwLock<Mode>>,
|
|
}
|
|
|
|
impl McpServer {
|
|
pub fn new(registry: Arc<ToolRegistry>, validator: Arc<SchemaValidator>) -> Self {
|
|
Self {
|
|
registry,
|
|
validator,
|
|
mode: Arc::new(tokio::sync::RwLock::new(Mode::default())),
|
|
}
|
|
}
|
|
|
|
/// Set the current operating mode
|
|
pub async fn set_mode(&self, mode: Mode) {
|
|
*self.mode.write().await = mode;
|
|
}
|
|
|
|
/// Get the current operating mode
|
|
pub async fn get_mode(&self) -> Mode {
|
|
*self.mode.read().await
|
|
}
|
|
|
|
/// Enumerate the registered tools as MCP descriptors
|
|
pub async fn list_tools(&self) -> Vec<McpToolDescriptor> {
|
|
let mode = self.get_mode().await;
|
|
let available_tools = self.registry.available_tools(mode).await;
|
|
|
|
self.registry
|
|
.all()
|
|
.into_iter()
|
|
.filter(|tool| available_tools.contains(&tool.name().to_string()))
|
|
.map(|tool| McpToolDescriptor {
|
|
name: tool.name().to_string(),
|
|
description: tool.description().to_string(),
|
|
input_schema: tool.schema(),
|
|
requires_network: tool.requires_network(),
|
|
requires_filesystem: tool.requires_filesystem(),
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
/// Execute a tool call after validating inputs against the registered schema
|
|
pub async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
|
|
self.validator.validate(&call.name, &call.arguments)?;
|
|
let mode = self.get_mode().await;
|
|
let result = self
|
|
.registry
|
|
.execute(&call.name, call.arguments, mode)
|
|
.await?;
|
|
Ok(McpToolResponse {
|
|
name: call.name,
|
|
success: result.success,
|
|
output: result.output,
|
|
metadata: result.metadata,
|
|
duration_ms: duration_to_millis(result.duration),
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Converts `duration` to whole milliseconds, truncating any sub-millisecond remainder.
///
/// Delegates to [`Duration::as_millis`], which already returns `u128` and computes
/// `secs * 1000 + subsec_millis`. The result cannot overflow for any `Duration`:
/// `u64::MAX` seconds is roughly 1.8e22 ms, far below `u128::MAX`.
fn duration_to_millis(duration: Duration) -> u128 {
    duration.as_millis()
}
|
|
|
|
pub struct LocalMcpClient {
|
|
server: McpServer,
|
|
}
|
|
|
|
impl LocalMcpClient {
|
|
pub fn new(registry: Arc<ToolRegistry>, validator: Arc<SchemaValidator>) -> Self {
|
|
Self {
|
|
server: McpServer::new(registry, validator),
|
|
}
|
|
}
|
|
|
|
/// Set the current operating mode
|
|
pub async fn set_mode(&self, mode: Mode) {
|
|
self.server.set_mode(mode).await;
|
|
}
|
|
|
|
/// Get the current operating mode
|
|
pub async fn get_mode(&self) -> Mode {
|
|
self.server.get_mode().await
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl McpClient for LocalMcpClient {
|
|
async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
|
|
Ok(self.server.list_tools().await)
|
|
}
|
|
|
|
async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
|
|
self.server.call_tool(call).await
|
|
}
|
|
|
|
async fn set_mode(&self, mode: Mode) -> Result<()> {
|
|
self.server.set_mode(mode).await;
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
pub mod test_utils {
    use super::*;

    /// Mock MCP client for testing.
    ///
    /// `list_tools` always returns a single `mock_tool` descriptor. `call_tool` echoes the
    /// requested tool name and returns a fixed successful payload with a 10 ms duration.
    #[derive(Default)]
    pub struct MockMcpClient;

    #[async_trait]
    impl McpClient for MockMcpClient {
        async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
            let schema = serde_json::json!({
                "type": "object",
                "properties": {
                    "query": {"type": "string"}
                }
            });
            let descriptor = McpToolDescriptor {
                name: "mock_tool".to_string(),
                description: "A mock tool for testing".to_string(),
                input_schema: schema,
                requires_network: false,
                requires_filesystem: Vec::new(),
            };
            Ok(vec![descriptor])
        }

        async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
            let output = serde_json::json!({"result": "mock result"});
            Ok(McpToolResponse {
                name: call.name,
                success: true,
                output,
                metadata: HashMap::new(),
                duration_ms: 10,
            })
        }
    }
}
|