diff --git a/Cargo.toml b/Cargo.toml index aa33d81..7c0d03c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "crates/tools/bash", "crates/tools/fs", "crates/tools/slash", + "crates/tools/web", "crates/integration/mcp-client", ] resolver = "2" diff --git a/crates/tools/web/Cargo.toml b/crates/tools/web/Cargo.toml new file mode 100644 index 0000000..92a5973 --- /dev/null +++ b/crates/tools/web/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "tools-web" +version = "0.1.0" +edition.workspace = true +license.workspace = true +rust-version.workspace = true + +[dependencies] +reqwest = { version = "0.12", features = ["json"] } +tokio = { version = "1.39", features = ["macros"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +color-eyre = "0.6" +url = "2.5" +async-trait = "0.1" + +[dev-dependencies] +tokio = { version = "1.39", features = ["macros", "rt-multi-thread"] } +wiremock = "0.6" diff --git a/crates/tools/web/src/lib.rs b/crates/tools/web/src/lib.rs new file mode 100644 index 0000000..04a2a20 --- /dev/null +++ b/crates/tools/web/src/lib.rs @@ -0,0 +1,225 @@ +use color_eyre::eyre::{Result, eyre}; +use reqwest::redirect::Policy; +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; +use url::Url; + +/// WebFetch response +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FetchResponse { + pub url: String, + pub status: u16, + pub content: String, + pub content_type: Option, +} + +/// WebFetch client with domain filtering +pub struct WebFetchClient { + allowed_domains: HashSet, + blocked_domains: HashSet, + client: reqwest::Client, +} + +impl WebFetchClient { + /// Create a new WebFetch client + pub fn new() -> Self { + let client = reqwest::Client::builder() + .redirect(Policy::none()) // Don't follow redirects automatically + .build() + .unwrap(); + + Self { + allowed_domains: HashSet::new(), + blocked_domains: HashSet::new(), + client, + } + } + + /// Add an allowed domain + pub fn allow_domain(&mut self, domain: &str) { + self.allowed_domains.insert(domain.to_lowercase()); + } + + /// Add a blocked domain + pub fn block_domain(&mut self, domain: &str) { + self.blocked_domains.insert(domain.to_lowercase()); + } + + /// Check if a domain is allowed + fn is_domain_allowed(&self, domain: &str) -> bool { + let domain_lower = domain.to_lowercase(); + + // If explicitly blocked, deny + if self.blocked_domains.contains(&domain_lower) { + return false; + } + + // If allowlist is empty, allow all (except blocked) + if self.allowed_domains.is_empty() { + return true; + } + + // Otherwise, must be in allowlist + self.allowed_domains.contains(&domain_lower) + } + + /// Fetch a URL + pub async fn fetch(&self, url: &str) -> Result { + let parsed_url = Url::parse(url)?; + let domain = parsed_url + .host_str() + .ok_or_else(|| eyre!("No host in URL"))?; + + // Check domain permission + if !self.is_domain_allowed(domain) { + return Err(eyre!("Domain not allowed: {}", domain)); + } + + // Make the request + let response = self.client.get(url).send().await?; + + let status = response.status().as_u16(); + + // Handle redirects manually + if status >= 300 && status < 400 { + if let Some(location) = response.headers().get("location") { + let location_str = location.to_str()?; + + // Parse the redirect URL (may be relative) + let redirect_url = if location_str.starts_with("http") { + Url::parse(location_str)? + } else { + parsed_url.join(location_str)? + }; + + let redirect_domain = redirect_url + .host_str() + .ok_or_else(|| eyre!("No host in redirect URL"))?; + + // Check if redirect domain is allowed + if !self.is_domain_allowed(redirect_domain) { + return Err(eyre!( + "Redirect to unapproved domain: {} -> {}", + domain, + redirect_domain + )); + } + + return Err(eyre!( + "Redirect detected: {} -> {}. Use the redirect URL directly.", + url, + redirect_url + )); + } + } + + let content_type = response + .headers() + .get("content-type") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + + let content = response.text().await?; + + Ok(FetchResponse { + url: url.to_string(), + status, + content, + content_type, + }) + } +} + +impl Default for WebFetchClient { + fn default() -> Self { + Self::new() + } +} + +/// Search provider trait +#[async_trait::async_trait] +pub trait SearchProvider: Send + Sync { + fn name(&self) -> &str; + async fn search(&self, query: &str) -> Result>; +} + +/// Search result +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SearchResult { + pub title: String, + pub url: String, + pub snippet: String, +} + +/// Stub search provider for testing +pub struct StubSearchProvider { + results: Vec, +} + +impl StubSearchProvider { + pub fn new(results: Vec) -> Self { + Self { results } + } +} + +#[async_trait::async_trait] +impl SearchProvider for StubSearchProvider { + fn name(&self) -> &str { + "stub" + } + + async fn search(&self, _query: &str) -> Result> { + Ok(self.results.clone()) + } +} + +/// WebSearch client with pluggable providers +pub struct WebSearchClient { + provider: Box, +} + +impl WebSearchClient { + pub fn new(provider: Box) -> Self { + Self { provider } + } + + pub fn provider_name(&self) -> &str { + self.provider.name() + } + + pub async fn search(&self, query: &str) -> Result> { + self.provider.search(query).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn domain_filtering_allowlist() { + let mut client = WebFetchClient::new(); + client.allow_domain("example.com"); + + assert!(client.is_domain_allowed("example.com")); + assert!(!client.is_domain_allowed("evil.com")); + } + + #[test] + fn domain_filtering_blocklist() { + let mut client = WebFetchClient::new(); + client.block_domain("evil.com"); + + assert!(client.is_domain_allowed("example.com")); // Empty allowlist = allow all + assert!(!client.is_domain_allowed("evil.com")); + } + + #[test] + fn domain_filtering_case_insensitive() { + let mut client = WebFetchClient::new(); + client.allow_domain("Example.COM"); + + assert!(client.is_domain_allowed("example.com")); + assert!(client.is_domain_allowed("EXAMPLE.COM")); + } +} diff --git a/crates/tools/web/tests/web_tools.rs b/crates/tools/web/tests/web_tools.rs new file mode 100644 index 0000000..0c11288 --- /dev/null +++ b/crates/tools/web/tests/web_tools.rs @@ -0,0 +1,161 @@ +use tools_web::{WebFetchClient, WebSearchClient, StubSearchProvider, SearchResult}; +use wiremock::{MockServer, Mock, ResponseTemplate}; +use wiremock::matchers::{method, path}; + +#[tokio::test] +async fn webfetch_domain_whitelist_only() { + let mock_server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/test")) + .respond_with(ResponseTemplate::new(200).set_body_string("Hello from allowed domain")) + .mount(&mock_server) + .await; + + let mut client = WebFetchClient::new(); + client.allow_domain("localhost"); + client.allow_domain("127.0.0.1"); // Domain without port + + // Fetch from allowed domain should work + let url = format!("{}/test", mock_server.uri()); + let response = client.fetch(&url).await.unwrap(); + assert_eq!(response.status, 200); + assert!(response.content.contains("Hello from allowed domain")); + + // Create a client with different allowlist + let mut strict_client = WebFetchClient::new(); + strict_client.allow_domain("example.com"); + + // Fetch from non-allowed domain should fail + let result = strict_client.fetch(&url).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Domain not allowed")); +} + +#[tokio::test] +async fn webfetch_redirect_to_unapproved_domain() { + let mock_server = MockServer::start().await; + + // Mock a redirect to a different domain + Mock::given(method("GET")) + .and(path("/redirect")) + .respond_with( + ResponseTemplate::new(302) + .insert_header("location", "https://evil.com/malware") + ) + .mount(&mock_server) + .await; + + let mut client = WebFetchClient::new(); + client.allow_domain("localhost"); + client.allow_domain("127.0.0.1"); // Domain without port + // evil.com is NOT in the allowlist + + let url = format!("{}/redirect", mock_server.uri()); + let result = client.fetch(&url).await; + + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!(err_msg.contains("Redirect to unapproved domain") || err_msg.contains("evil.com")); +} + +#[tokio::test] +async fn webfetch_redirect_to_approved_domain() { + let mock_server = MockServer::start().await; + + let redirect_url = format!("{}/target", mock_server.uri()); + + // Mock a redirect to an approved domain + Mock::given(method("GET")) + .and(path("/redirect")) + .respond_with( + ResponseTemplate::new(302) + .insert_header("location", &redirect_url) + ) + .mount(&mock_server) + .await; + + let mut client = WebFetchClient::new(); + client.allow_domain("localhost"); + client.allow_domain("127.0.0.1"); // Domain without port + + let url = format!("{}/redirect", mock_server.uri()); + let result = client.fetch(&url).await; + + // Should fail but with a message about using the redirect URL + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!(err_msg.contains("Redirect detected") || err_msg.contains("Use the redirect URL")); +} + +#[tokio::test] +async fn webfetch_blocklist_overrides_allowlist() { + let mock_server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/test")) + .respond_with(ResponseTemplate::new(200).set_body_string("Hello")) + .mount(&mock_server) + .await; + + let domain = "127.0.0.1"; + let mut client = WebFetchClient::new(); + client.allow_domain(domain); + client.block_domain(domain); // Block overrides allow + + let url = format!("{}/test", mock_server.uri()); + let result = client.fetch(&url).await; + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Domain not allowed")); +} + +#[tokio::test] +async fn websearch_pluggable_provider() { + let stub_results = vec![ + SearchResult { + title: "Test Result 1".to_string(), + url: "https://example.com/1".to_string(), + snippet: "This is a test result".to_string(), + }, + SearchResult { + title: "Test Result 2".to_string(), + url: "https://example.com/2".to_string(), + snippet: "Another test result".to_string(), + }, + ]; + + let provider = StubSearchProvider::new(stub_results.clone()); + let client = WebSearchClient::new(Box::new(provider)); + + assert_eq!(client.provider_name(), "stub"); + + let results = client.search("test query").await.unwrap(); + assert_eq!(results.len(), 2); + assert_eq!(results[0].title, "Test Result 1"); + assert_eq!(results[1].url, "https://example.com/2"); +} + +#[tokio::test] +async fn webfetch_successful_request() { + let mock_server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/api/data")) + .respond_with( + ResponseTemplate::new(200) + .set_body_string(r#"{"status":"ok"}"#) + .insert_header("content-type", "application/json") + ) + .mount(&mock_server) + .await; + + let client = WebFetchClient::new(); // Empty allowlist = allow all + + let url = format!("{}/api/data", mock_server.uri()); + let response = client.fetch(&url).await.unwrap(); + + assert_eq!(response.status, 200); + assert!(response.content.contains("status")); + assert!(response.content_type.is_some()); // Just verify content-type is present +}