Files
owlen/crates/owlen-core/src/mcp/remote_client.rs

594 lines
24 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use super::protocol::methods;
use super::protocol::{
PROTOCOL_VERSION, RequestId, RpcErrorResponse, RpcNotification, RpcRequest, RpcResponse,
};
use super::{McpClient, McpToolCall, McpToolDescriptor, McpToolResponse};
use crate::consent::{ConsentManager, ConsentScope};
use crate::tools::{Tool, WebScrapeTool, WebSearchTool};
use crate::types::ModelInfo;
use crate::types::{ChatResponse, Message, Role};
use crate::{
ChatStream, Error, LlmProvider, Result, facade::llm_client::LlmClient, mode::Mode,
send_via_stream,
};
use anyhow::anyhow;
use futures::{StreamExt, future::BoxFuture, stream};
use reqwest::Client as HttpClient;
use serde_json::json;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::Mutex;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
use tungstenite::protocol::Message as WsMessage;
/// Client that talks to an external MCP server (such as `owlen-mcp-server`) over
/// STDIO, HTTP, or WebSocket.
///
/// `new_with_runtime` populates exactly one group of transport fields, chosen by
/// the configured transport, and leaves the other groups as `None`.
pub struct RemoteMcpClient {
// STDIO transport: handle of the spawned server process. Nothing in this file reads
// it after construction (hence the `dead_code` allowance); it is stored so the handle
// lives as long as the client does.
#[allow(dead_code)]
// STDIO transport: the child process handle followed by its piped stdin and stdout.
child: Option<Arc<Mutex<Child>>>,
stdin: Option<Arc<Mutex<tokio::process::ChildStdin>>>, // requests are written here, one JSON document per line
stdout: Option<Arc<Mutex<BufReader<tokio::process::ChildStdout>>>>,
// HTTP transport: reusable client plus the URL every request is POSTed to.
http_client: Option<HttpClient>,
http_endpoint: Option<String>,
// WebSocket transport: the established connection, locked for each request/response exchange.
ws_stream: Option<Arc<Mutex<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>>>>,
#[allow(dead_code)] // Not read anywhere in this file; kept for debugging/logging.
ws_endpoint: Option<String>,
// Source of JSON-RPC request ids; starts at 1 and is incremented once per request.
next_id: AtomicU64,
// Optional (name, value) header. `send_rpc` attaches it to every HTTP POST. It is also
// stored for WebSocket clients, but the code in this file does not send it there.
http_header: Option<(String, String)>,
}
/// Runtime secrets provided when constructing an MCP client.
///
/// NOTE: `Debug` is derived, so formatting this value with `{:?}` prints the
/// secret values verbatim.
#[derive(Debug, Default, Clone)]
pub struct McpRuntimeSecrets {
// Environment variables set on the spawned server process (STDIO transport only).
// They are applied after the variables from the configuration entry.
pub env_overrides: HashMap<String, String>,
// Optional (name, value) header stored on HTTP and WebSocket clients; the STDIO
// transport ignores it.
pub http_header: Option<(String, String)>,
}
impl RemoteMcpClient {
/// Create a client from a configuration entry, without any runtime secrets.
///
/// This is shorthand for `new_with_runtime(config, None)`. The transport is chosen
/// by `config.transport`, which `new_with_runtime` matches case-insensitively
/// against "stdio", "http" and "websocket".
pub fn new_with_config(config: &crate::config::McpServerConfig) -> Result<Self> {
Self::new_with_runtime(config, None)
}
pub fn new_with_runtime(
config: &crate::config::McpServerConfig,
runtime: Option<McpRuntimeSecrets>,
) -> Result<Self> {
let mut runtime = runtime.unwrap_or_default();
let transport = config.transport.to_lowercase();
match transport.as_str() {
"stdio" => {
// Build the command using the provided binary and arguments.
let mut cmd = Command::new(config.command.clone());
if !config.args.is_empty() {
cmd.args(config.args.clone());
}
cmd.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::inherit());
// Apply environment variables defined in the configuration.
for (k, v) in config.env.iter() {
cmd.env(k, v);
}
for (k, v) in runtime.env_overrides.drain() {
cmd.env(k, v);
}
let mut child = cmd.spawn().map_err(|e| {
Error::Io(std::io::Error::new(
e.kind(),
format!("Failed to spawn MCP server '{}': {}", config.name, e),
))
})?;
let stdin = child.stdin.take().ok_or_else(|| {
Error::Io(std::io::Error::other(
"Failed to capture stdin of MCP server",
))
})?;
let stdout = child.stdout.take().ok_or_else(|| {
Error::Io(std::io::Error::other(
"Failed to capture stdout of MCP server",
))
})?;
Ok(Self {
child: Some(Arc::new(Mutex::new(child))),
stdin: Some(Arc::new(Mutex::new(stdin))),
stdout: Some(Arc::new(Mutex::new(BufReader::new(stdout)))),
http_client: None,
http_endpoint: None,
ws_stream: None,
ws_endpoint: None,
next_id: AtomicU64::new(1),
http_header: None,
})
}
"http" => {
// For HTTP we treat `command` as the base URL.
let client = HttpClient::builder()
.timeout(Duration::from_secs(30))
.build()
.map_err(|e| Error::Network(e.to_string()))?;
Ok(Self {
child: None,
stdin: None,
stdout: None,
http_client: Some(client),
http_endpoint: Some(config.command.clone()),
ws_stream: None,
ws_endpoint: None,
next_id: AtomicU64::new(1),
http_header: runtime.http_header.take(),
})
}
"websocket" => {
// For WebSocket, the `command` field contains the WebSocket URL.
// We need to use a blocking task to establish the connection.
let ws_url = config.command.clone();
let (ws_stream, _response) = tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(async {
connect_async(&ws_url).await.map_err(|e| {
Error::Network(format!("WebSocket connection failed: {}", e))
})
})
})?;
Ok(Self {
child: None,
stdin: None,
stdout: None,
http_client: None,
http_endpoint: None,
ws_stream: Some(Arc::new(Mutex::new(ws_stream))),
ws_endpoint: Some(ws_url),
next_id: AtomicU64::new(1),
http_header: runtime.http_header.take(),
})
}
other => Err(Error::NotImplemented(format!(
"Transport '{}' not supported",
other
))),
}
}
/// Legacy constructor kept for compatibility; attempts to locate a binary.
pub fn new() -> Result<Self> {
// Fall back to searching for a binary as before, then delegate to new_with_config.
let workspace_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
.join("../..")
.canonicalize()
.map_err(Error::Io)?;
// Prefer the LLM server binary as it provides both LLM and resource tools.
// The generic file-server is kept as a fallback for testing.
let candidates = [
"target/debug/owlen-mcp-llm-server",
"target/release/owlen-mcp-llm-server",
"target/debug/owlen-mcp-server",
];
let binary_path = candidates
.iter()
.map(|rel| workspace_root.join(rel))
.find(|p| p.exists())
.ok_or_else(|| {
Error::NotImplemented(format!(
"owlen-mcp server binary not found; checked {}, {}, and {}",
candidates[0], candidates[1], candidates[2]
))
})?;
let config = crate::config::McpServerConfig {
name: "default".to_string(),
command: binary_path.to_string_lossy().into_owned(),
args: Vec::new(),
transport: "stdio".to_string(),
env: std::collections::HashMap::new(),
oauth: None,
};
Self::new_with_config(&config)
}
async fn send_rpc(&self, method: &str, params: serde_json::Value) -> Result<serde_json::Value> {
let id = RequestId::Number(self.next_id.fetch_add(1, Ordering::Relaxed));
let request = RpcRequest::new(id.clone(), method, Some(params));
let req_str = serde_json::to_string(&request)? + "\n";
// For stdio transport we forward the request to the child process.
if let Some(stdin_arc) = &self.stdin {
let mut stdin = stdin_arc.lock().await;
stdin.write_all(req_str.as_bytes()).await?;
stdin.flush().await?;
}
// Read a single line response
// Handle based on selected transport.
if let Some(client) = &self.http_client {
// HTTP: POST JSON body to endpoint.
let endpoint = self
.http_endpoint
.as_ref()
.ok_or_else(|| Error::Network("Missing HTTP endpoint".into()))?;
let mut builder = client.post(endpoint);
if let Some((ref header_name, ref header_value)) = self.http_header {
builder = builder.header(header_name, header_value);
}
let resp = builder
.json(&request)
.send()
.await
.map_err(|e| Error::Network(e.to_string()))?;
let text = resp
.text()
.await
.map_err(|e| Error::Network(e.to_string()))?;
// Try to parse as success then error.
if let Ok(r) = serde_json::from_str::<RpcResponse>(&text)
&& r.id == id
{
return Ok(r.result);
}
let err_resp: RpcErrorResponse =
serde_json::from_str(&text).map_err(Error::Serialization)?;
return Err(Error::Network(format!(
"MCP server error {}: {}",
err_resp.error.code, err_resp.error.message
)));
}
// WebSocket path.
if let Some(ws_arc) = &self.ws_stream {
use futures::SinkExt;
let mut ws = ws_arc.lock().await;
// Send request as text message
let req_json = serde_json::to_string(&request)?;
ws.send(WsMessage::Text(req_json))
.await
.map_err(|e| Error::Network(format!("WebSocket send failed: {}", e)))?;
// Read response
let response_msg = ws
.next()
.await
.ok_or_else(|| Error::Network("WebSocket stream closed".into()))?
.map_err(|e| Error::Network(format!("WebSocket receive failed: {}", e)))?;
let response_text = match response_msg {
WsMessage::Text(text) => text,
WsMessage::Binary(data) => String::from_utf8(data).map_err(|e| {
Error::Network(format!("Invalid UTF-8 in binary message: {}", e))
})?,
WsMessage::Close(_) => {
return Err(Error::Network(
"WebSocket connection closed by server".into(),
));
}
_ => return Err(Error::Network("Unexpected WebSocket message type".into())),
};
// Try to parse as success then error.
if let Ok(r) = serde_json::from_str::<RpcResponse>(&response_text)
&& r.id == id
{
return Ok(r.result);
}
let err_resp: RpcErrorResponse =
serde_json::from_str(&response_text).map_err(Error::Serialization)?;
return Err(Error::Network(format!(
"MCP server error {}: {}",
err_resp.error.code, err_resp.error.message
)));
}
// STDIO path (default).
// Loop to skip notifications and find the response with matching ID.
loop {
let mut line = String::new();
{
let mut stdout = self
.stdout
.as_ref()
.ok_or_else(|| Error::Network("STDIO stdout not available".into()))?
.lock()
.await;
stdout.read_line(&mut line).await?;
}
// Try to parse as notification first (has no id field)
if let Ok(_notif) = serde_json::from_str::<RpcNotification>(&line) {
// Skip notifications and continue reading
continue;
}
// Try to parse successful response
if let Ok(resp) = serde_json::from_str::<RpcResponse>(&line) {
if resp.id == id {
return Ok(resp.result);
}
// If ID doesn't match, continue (though this shouldn't happen)
continue;
}
// Fallback to error response
if let Ok(err_resp) = serde_json::from_str::<RpcErrorResponse>(&line) {
return Err(Error::Network(format!(
"MCP server error {}: {}",
err_resp.error.code, err_resp.error.message
)));
}
// If we can't parse as any known type, return error
return Err(Error::Network(format!(
"Unable to parse server response: {}",
line.trim()
)));
}
}
}
impl RemoteMcpClient {
/// Convenience wrapper delegating to the `McpClient` trait methods.
pub async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
<Self as McpClient>::list_tools(self).await
}
pub async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
<Self as McpClient>::call_tool(self, call).await
}
}
#[async_trait::async_trait]
impl McpClient for RemoteMcpClient {
/// Fetch the server's tool descriptors with the `tools/list` RPC method.
///
/// The request carries `null` params and the reply is deserialized as a JSON
/// array of `McpToolDescriptor` values.
async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
    let raw = self
        .send_rpc(methods::TOOLS_LIST, serde_json::Value::Null)
        .await?;
    let descriptors = serde_json::from_value::<Vec<McpToolDescriptor>>(raw)?;
    Ok(descriptors)
}
async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
// Local handling for simple resource tools to avoid needing the MCP server
// to implement them.
if call.name.starts_with("resources/get") {
let path = call
.arguments
.get("path")
.and_then(|v| v.as_str())
.unwrap_or("");
let content = std::fs::read_to_string(path).map_err(Error::Io)?;
return Ok(McpToolResponse {
name: call.name,
success: true,
output: serde_json::json!(content),
metadata: std::collections::HashMap::new(),
duration_ms: 0,
});
}
if call.name.starts_with("resources/list") {
let path = call
.arguments
.get("path")
.and_then(|v| v.as_str())
.unwrap_or(".");
let mut names = Vec::new();
for entry in std::fs::read_dir(path).map_err(Error::Io)?.flatten() {
if let Some(name) = entry.file_name().to_str() {
names.push(name.to_string());
}
}
return Ok(McpToolResponse {
name: call.name,
success: true,
output: serde_json::json!(names),
metadata: std::collections::HashMap::new(),
duration_ms: 0,
});
}
// Handle write and delete resources locally as well.
if call.name.starts_with("resources/write") {
let path = call
.arguments
.get("path")
.and_then(|v| v.as_str())
.ok_or_else(|| Error::InvalidInput("path missing".into()))?;
// Simple pathtraversal protection: reject any path containing ".." or absolute paths.
if path.contains("..") || Path::new(path).is_absolute() {
return Err(Error::InvalidInput("path traversal".into()));
}
let content = call
.arguments
.get("content")
.and_then(|v| v.as_str())
.ok_or_else(|| Error::InvalidInput("content missing".into()))?;
std::fs::write(path, content).map_err(Error::Io)?;
return Ok(McpToolResponse {
name: call.name,
success: true,
output: serde_json::json!(null),
metadata: std::collections::HashMap::new(),
duration_ms: 0,
});
}
if call.name.starts_with("resources/delete") {
let path = call
.arguments
.get("path")
.and_then(|v| v.as_str())
.ok_or_else(|| Error::InvalidInput("path missing".into()))?;
if path.contains("..") || Path::new(path).is_absolute() {
return Err(Error::InvalidInput("path traversal".into()));
}
std::fs::remove_file(path).map_err(Error::Io)?;
return Ok(McpToolResponse {
name: call.name,
success: true,
output: serde_json::json!(null),
metadata: std::collections::HashMap::new(),
duration_ms: 0,
});
}
// Local handling for web tools to avoid needing an external MCP server.
if call.name == "web_search" {
// Autogrant consent for the web_search tool (permanent for this process).
let consent_manager = std::sync::Arc::new(std::sync::Mutex::new(ConsentManager::new()));
{
let mut cm = consent_manager
.lock()
.map_err(|_| Error::Provider(anyhow!("Consent manager mutex poisoned")))?;
cm.grant_consent_with_scope(
"web_search",
Vec::new(),
Vec::new(),
ConsentScope::Permanent,
);
}
let tool = WebSearchTool::new(consent_manager.clone(), None, None);
let result = tool
.execute(call.arguments.clone())
.await
.map_err(|e| Error::Provider(e.into()))?;
return Ok(McpToolResponse {
name: call.name,
success: true,
output: result.output,
metadata: std::collections::HashMap::new(),
duration_ms: result.duration.as_millis() as u128,
});
}
if call.name == "web_scrape" {
let tool = WebScrapeTool::new();
let result = tool
.execute(call.arguments.clone())
.await
.map_err(|e| Error::Provider(e.into()))?;
return Ok(McpToolResponse {
name: call.name,
success: true,
output: result.output,
metadata: std::collections::HashMap::new(),
duration_ms: result.duration.as_millis() as u128,
});
}
// MCP server expects a generic "tools/call" method with a payload containing the
// specific tool name and its arguments. Wrap the incoming call accordingly.
let payload = serde_json::to_value(&call)?;
let result = self.send_rpc(methods::TOOLS_CALL, payload).await?;
// The server returns an McpToolResponse; deserialize it.
let response: McpToolResponse = serde_json::from_value(result)?;
Ok(response)
}
/// Accept a mode change without doing anything.
///
/// The requested mode is ignored and nothing is sent to the server; the call
/// always succeeds.
async fn set_mode(&self, _mode: Mode) -> Result<()> {
// Remote servers manage their own mode settings; treat as best-effort no-op.
Ok(())
}
}
// ---------------------------------------------------------------------------
// Provider implementation forwards chat requests to the generate_text tool.
// ---------------------------------------------------------------------------
impl LlmProvider for RemoteMcpClient {
    type Stream = stream::Iter<std::vec::IntoIter<Result<ChatResponse>>>;
    type ListModelsFuture<'a> = BoxFuture<'a, Result<Vec<ModelInfo>>>;
    type SendPromptFuture<'a> = BoxFuture<'a, Result<ChatResponse>>;
    type StreamPromptFuture<'a> = BoxFuture<'a, Result<Self::Stream>>;
    type HealthCheckFuture<'a> = BoxFuture<'a, Result<()>>;

    /// Fixed provider name reported for this client.
    fn name(&self) -> &str {
        "mcp-llm-server"
    }

    /// Ask the server for its models via the models-list RPC method.
    fn list_models(&self) -> Self::ListModelsFuture<'_> {
        Box::pin(async move {
            let raw = self
                .send_rpc(methods::MODELS_LIST, serde_json::Value::Null)
                .await?;
            let models = serde_json::from_value::<Vec<ModelInfo>>(raw)?;
            Ok(models)
        })
    }

    /// Send a prompt by driving `stream_prompt` through `send_via_stream`.
    fn send_prompt(&self, request: crate::types::ChatRequest) -> Self::SendPromptFuture<'_> {
        Box::pin(send_via_stream(self, request))
    }

    /// Run the `generate_text` tool and wrap its output in a one-item stream.
    ///
    /// The tool output is used as the assistant message when it is a JSON
    /// string; any other output yields an empty message. The single response
    /// is marked as final and non-streaming.
    fn stream_prompt(&self, request: crate::types::ChatRequest) -> Self::StreamPromptFuture<'_> {
        Box::pin(async move {
            let call = McpToolCall {
                name: String::from("generate_text"),
                arguments: json!({
                    "messages": request.messages,
                    "temperature": request.parameters.temperature,
                    "max_tokens": request.parameters.max_tokens,
                    "model": request.model,
                    "stream": request.parameters.stream,
                }),
            };
            let tool_reply = self.call_tool(call).await?;
            let text = tool_reply
                .output
                .as_str()
                .map(str::to_owned)
                .unwrap_or_default();
            let only_item = ChatResponse {
                message: Message::new(Role::Assistant, text),
                usage: None,
                is_streaming: false,
                is_final: true,
            };
            Ok(stream::iter(vec![Ok(only_item)]))
        })
    }

    /// Check the connection by sending an initialize request; the reply's
    /// content is discarded.
    fn health_check(&self) -> Self::HealthCheckFuture<'_> {
        Box::pin(async move {
            let handshake = json!({
                "protocol_version": PROTOCOL_VERSION,
                "client_info": {
                    "name": "owlen",
                    "version": env!("CARGO_PKG_VERSION"),
                },
                "capabilities": {}
            });
            self.send_rpc(methods::INITIALIZE, handshake).await?;
            Ok(())
        })
    }
}
#[async_trait::async_trait]
impl LlmClient for RemoteMcpClient {
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
<Self as LlmProvider>::list_models(self).await
}
async fn send_chat(&self, request: crate::types::ChatRequest) -> Result<ChatResponse> {
<Self as LlmProvider>::send_prompt(self, request).await
}
async fn stream_chat(&self, request: crate::types::ChatRequest) -> Result<ChatStream> {
let stream = <Self as LlmProvider>::stream_prompt(self, request).await?;
Ok(Box::pin(stream))
}
async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
<Self as McpClient>::list_tools(self).await
}
async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
<Self as McpClient>::call_tool(self, call).await
}
}