use color_eyre::eyre::{Result, eyre};
use reqwest::redirect::Policy;
use scraper::{Html, Selector};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use url::Url;

/// WebFetch response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FetchResponse {
    pub url: String,
    pub status: u16,
    pub content: String,
    pub content_type: Option<String>,
}

/// WebFetch client with domain filtering
pub struct WebFetchClient {
    allowed_domains: HashSet<String>,
    blocked_domains: HashSet<String>,
    client: reqwest::Client,
}

impl WebFetchClient {
    /// Create a new WebFetch client
    pub fn new() -> Self {
        let client = reqwest::Client::builder()
            .redirect(Policy::none()) // Don't follow redirects automatically
            .build()
            .unwrap();

        Self {
            allowed_domains: HashSet::new(),
            blocked_domains: HashSet::new(),
            client,
        }
    }

    /// Add an allowed domain
    pub fn allow_domain(&mut self, domain: &str) {
        self.allowed_domains.insert(domain.to_lowercase());
    }

    /// Add a blocked domain
    pub fn block_domain(&mut self, domain: &str) {
        self.blocked_domains.insert(domain.to_lowercase());
    }

    /// Check if a domain is allowed
    fn is_domain_allowed(&self, domain: &str) -> bool {
        let domain_lower = domain.to_lowercase();

        // If explicitly blocked, deny
        if self.blocked_domains.contains(&domain_lower) {
            return false;
        }

        // If allowlist is empty, allow all (except blocked)
        if self.allowed_domains.is_empty() {
            return true;
        }

        // Otherwise, must be in allowlist
        self.allowed_domains.contains(&domain_lower)
    }

    /// Fetch a URL
    pub async fn fetch(&self, url: &str) -> Result<FetchResponse> {
        let parsed_url = Url::parse(url)?;
        let domain = parsed_url
            .host_str()
            .ok_or_else(|| eyre!("No host in URL"))?;

        // Check domain permission
        if !self.is_domain_allowed(domain) {
            return Err(eyre!("Domain not allowed: {}", domain));
        }

        // Make the request
        let response = self.client.get(url).send().await?;
        let status = response.status().as_u16();

        // Handle redirects manually
        if (300..400).contains(&status)
            && let Some(location) = response.headers().get("location")
        {
            let location_str = location.to_str()?;

            // Parse the redirect URL (may be relative)
            let redirect_url = if location_str.starts_with("http") {
                Url::parse(location_str)?
            } else {
                parsed_url.join(location_str)?
            };

            let redirect_domain = redirect_url
                .host_str()
                .ok_or_else(|| eyre!("No host in redirect URL"))?;

            // Check if redirect domain is allowed
            if !self.is_domain_allowed(redirect_domain) {
                return Err(eyre!(
                    "Redirect to unapproved domain: {} -> {}",
                    domain,
                    redirect_domain
                ));
            }

            return Err(eyre!(
                "Redirect detected: {} -> {}. Use the redirect URL directly.",
                url,
                redirect_url
            ));
        }

        let content_type = response
            .headers()
            .get("content-type")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string());

        let content = response.text().await?;

        Ok(FetchResponse {
            url: url.to_string(),
            status,
            content,
            content_type,
        })
    }
}
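// Usage sketch (illustrative only; assumes an async caller, and the domains and URL
// below are placeholders, not part of this module): allowlist one domain, block
// another, then fetch a page and read the response fields.
//
//     let mut client = WebFetchClient::new();
//     client.allow_domain("docs.rs");
//     client.block_domain("internal.corp");
//     let page = client.fetch("https://docs.rs/reqwest").await?;
//     println!("{} -> {} bytes", page.status, page.content.len());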
impl Default for WebFetchClient {
    fn default() -> Self {
        Self::new()
    }
}

/// Search provider trait
#[async_trait::async_trait]
pub trait SearchProvider: Send + Sync {
    fn name(&self) -> &str;
    async fn search(&self, query: &str) -> Result<Vec<SearchResult>>;
}

/// Search result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    pub title: String,
    pub url: String,
    pub snippet: String,
}

/// Stub search provider for testing
pub struct StubSearchProvider {
    results: Vec<SearchResult>,
}

impl StubSearchProvider {
    pub fn new(results: Vec<SearchResult>) -> Self {
        Self { results }
    }
}

#[async_trait::async_trait]
impl SearchProvider for StubSearchProvider {
    fn name(&self) -> &str {
        "stub"
    }

    async fn search(&self, _query: &str) -> Result<Vec<SearchResult>> {
        Ok(self.results.clone())
    }
}

/// DuckDuckGo HTML search provider
pub struct DuckDuckGoSearchProvider {
    client: reqwest::Client,
    max_results: usize,
}

impl DuckDuckGoSearchProvider {
    /// Create a new DuckDuckGo search provider with default max results (10)
    pub fn new() -> Self {
        Self::with_max_results(10)
    }

    /// Create a new DuckDuckGo search provider with custom max results
    pub fn with_max_results(max_results: usize) -> Self {
        let client = reqwest::Client::builder()
            .user_agent("Mozilla/5.0 (compatible; Owlen/1.0)")
            .build()
            .unwrap();

        Self { client, max_results }
    }

    /// Parse DuckDuckGo HTML results
    fn parse_results(html: &str, max_results: usize) -> Result<Vec<SearchResult>> {
        let document = Html::parse_document(html);

        // DuckDuckGo HTML selectors
        let result_selector =
            Selector::parse(".result").map_err(|e| eyre!("Invalid selector: {:?}", e))?;
        let title_selector = Selector::parse(".result__title a")
            .map_err(|e| eyre!("Invalid selector: {:?}", e))?;
        let snippet_selector = Selector::parse(".result__snippet")
            .map_err(|e| eyre!("Invalid selector: {:?}", e))?;

        let mut results = Vec::new();

        for result in document.select(&result_selector).take(max_results) {
            let title = result
                .select(&title_selector)
                .next()
                .map(|e| e.text().collect::<String>().trim().to_string())
                .unwrap_or_default();

            let url = result
                .select(&title_selector)
                .next()
                .and_then(|e| e.value().attr("href"))
                .unwrap_or_default()
                .to_string();

            let snippet = result
                .select(&snippet_selector)
                .next()
                .map(|e| e.text().collect::<String>().trim().to_string())
                .unwrap_or_default();

            if !title.is_empty() && !url.is_empty() {
                results.push(SearchResult { title, url, snippet });
            }
        }

        Ok(results)
    }
}

impl Default for DuckDuckGoSearchProvider {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait::async_trait]
impl SearchProvider for DuckDuckGoSearchProvider {
    fn name(&self) -> &str {
        "duckduckgo"
    }

    async fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
        let encoded_query = urlencoding::encode(query);
        let url = format!("https://html.duckduckgo.com/html/?q={}", encoded_query);

        let response = self.client.get(&url).send().await?;
        let html = response.text().await?;

        Self::parse_results(&html, self.max_results)
    }
}

/// WebSearch client with pluggable providers
pub struct WebSearchClient {
    provider: Box<dyn SearchProvider>,
}

impl WebSearchClient {
    pub fn new(provider: Box<dyn SearchProvider>) -> Self {
        Self { provider }
    }

    pub fn provider_name(&self) -> &str {
        self.provider.name()
    }

    pub async fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
        self.provider.search(query).await
    }
}
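// Usage sketch (illustrative query; assumes an async caller): plug a provider into
// `WebSearchClient`, run a search, and render the hits with `format_search_results`
// below. `StubSearchProvider` can be swapped in to exercise the same path in tests.
//
//     let search = WebSearchClient::new(Box::new(DuckDuckGoSearchProvider::default()));
//     let results = search.search("rust async runtimes").await?;
//     let markdown = format_search_results(&results);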
/// Format search results for LLM consumption (markdown format)
pub fn format_search_results(results: &[SearchResult]) -> String {
    if results.is_empty() {
        return "No results found.".to_string();
    }

    results
        .iter()
        .enumerate()
        .map(|(i, r)| format!("{}. [{}]({})\n {}", i + 1, r.title, r.url, r.snippet))
        .collect::<Vec<_>>()
        .join("\n\n")
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn domain_filtering_allowlist() {
        let mut client = WebFetchClient::new();
        client.allow_domain("example.com");

        assert!(client.is_domain_allowed("example.com"));
        assert!(!client.is_domain_allowed("evil.com"));
    }

    #[test]
    fn domain_filtering_blocklist() {
        let mut client = WebFetchClient::new();
        client.block_domain("evil.com");

        assert!(client.is_domain_allowed("example.com")); // Empty allowlist = allow all
        assert!(!client.is_domain_allowed("evil.com"));
    }

    #[test]
    fn domain_filtering_case_insensitive() {
        let mut client = WebFetchClient::new();
        client.allow_domain("Example.COM");

        assert!(client.is_domain_allowed("example.com"));
        assert!(client.is_domain_allowed("EXAMPLE.COM"));
    }
}
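// Sketch test for the markdown formatter; the sample result is illustrative data,
// and the assertions rely only on the numbering and link syntax, not exact spacing.
#[cfg(test)]
mod format_tests {
    use super::*;

    #[test]
    fn format_search_results_empty_and_numbered() {
        assert_eq!(format_search_results(&[]), "No results found.");

        let results = vec![SearchResult {
            title: "Example".to_string(),
            url: "https://example.com".to_string(),
            snippet: "An example snippet".to_string(),
        }];
        let formatted = format_search_results(&results);

        // Entries are numbered markdown links followed by the snippet
        assert!(formatted.starts_with("1. [Example](https://example.com)"));
        assert!(formatted.contains("An example snippet"));
    }
}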