//! Web fetching with domain filtering, and web search with pluggable providers.
use color_eyre::eyre::{Result, eyre};
|
|
use reqwest::redirect::Policy;
|
|
use scraper::{Html, Selector};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashSet;
|
|
use url::Url;
|
|
|
|
/// WebFetch response
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct FetchResponse {
|
|
pub url: String,
|
|
pub status: u16,
|
|
pub content: String,
|
|
pub content_type: Option<String>,
|
|
}
|
|
|
|
/// WebFetch client with domain filtering
|
|
pub struct WebFetchClient {
|
|
allowed_domains: HashSet<String>,
|
|
blocked_domains: HashSet<String>,
|
|
client: reqwest::Client,
|
|
}
|
|
|
|
impl WebFetchClient {
|
|
/// Create a new WebFetch client
|
|
pub fn new() -> Self {
|
|
let client = reqwest::Client::builder()
|
|
.redirect(Policy::none()) // Don't follow redirects automatically
|
|
.build()
|
|
.unwrap();
|
|
|
|
Self {
|
|
allowed_domains: HashSet::new(),
|
|
blocked_domains: HashSet::new(),
|
|
client,
|
|
}
|
|
}
|
|
|
|
/// Add an allowed domain
|
|
pub fn allow_domain(&mut self, domain: &str) {
|
|
self.allowed_domains.insert(domain.to_lowercase());
|
|
}
|
|
|
|
/// Add a blocked domain
|
|
pub fn block_domain(&mut self, domain: &str) {
|
|
self.blocked_domains.insert(domain.to_lowercase());
|
|
}
|
|
|
|
/// Check if a domain is allowed
|
|
fn is_domain_allowed(&self, domain: &str) -> bool {
|
|
let domain_lower = domain.to_lowercase();
|
|
|
|
// If explicitly blocked, deny
|
|
if self.blocked_domains.contains(&domain_lower) {
|
|
return false;
|
|
}
|
|
|
|
// If allowlist is empty, allow all (except blocked)
|
|
if self.allowed_domains.is_empty() {
|
|
return true;
|
|
}
|
|
|
|
// Otherwise, must be in allowlist
|
|
self.allowed_domains.contains(&domain_lower)
|
|
}
|
|
|
|
/// Fetch a URL
|
|
pub async fn fetch(&self, url: &str) -> Result<FetchResponse> {
|
|
let parsed_url = Url::parse(url)?;
|
|
let domain = parsed_url
|
|
.host_str()
|
|
.ok_or_else(|| eyre!("No host in URL"))?;
|
|
|
|
// Check domain permission
|
|
if !self.is_domain_allowed(domain) {
|
|
return Err(eyre!("Domain not allowed: {}", domain));
|
|
}
|
|
|
|
// Make the request
|
|
let response = self.client.get(url).send().await?;
|
|
|
|
let status = response.status().as_u16();
|
|
|
|
// Handle redirects manually
|
|
if (300..400).contains(&status)
|
|
&& let Some(location) = response.headers().get("location")
|
|
{
|
|
let location_str = location.to_str()?;
|
|
|
|
// Parse the redirect URL (may be relative)
|
|
let redirect_url = if location_str.starts_with("http") {
|
|
Url::parse(location_str)?
|
|
} else {
|
|
parsed_url.join(location_str)?
|
|
};
|
|
|
|
let redirect_domain = redirect_url
|
|
.host_str()
|
|
.ok_or_else(|| eyre!("No host in redirect URL"))?;
|
|
|
|
// Check if redirect domain is allowed
|
|
if !self.is_domain_allowed(redirect_domain) {
|
|
return Err(eyre!(
|
|
"Redirect to unapproved domain: {} -> {}",
|
|
domain,
|
|
redirect_domain
|
|
));
|
|
}
|
|
|
|
return Err(eyre!(
|
|
"Redirect detected: {} -> {}. Use the redirect URL directly.",
|
|
url,
|
|
redirect_url
|
|
));
|
|
}
|
|
|
|
let content_type = response
|
|
.headers()
|
|
.get("content-type")
|
|
.and_then(|v| v.to_str().ok())
|
|
.map(|s| s.to_string());
|
|
|
|
let content = response.text().await?;
|
|
|
|
Ok(FetchResponse {
|
|
url: url.to_string(),
|
|
status,
|
|
content,
|
|
content_type,
|
|
})
|
|
}
|
|
}
|
|
|
|
impl Default for WebFetchClient {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
/// Search provider trait
|
|
#[async_trait::async_trait]
|
|
pub trait SearchProvider: Send + Sync {
|
|
fn name(&self) -> &str;
|
|
async fn search(&self, query: &str) -> Result<Vec<SearchResult>>;
|
|
}
|
|
|
|
/// Search result
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct SearchResult {
|
|
pub title: String,
|
|
pub url: String,
|
|
pub snippet: String,
|
|
}
|
|
|
|
/// Stub search provider for testing
|
|
pub struct StubSearchProvider {
|
|
results: Vec<SearchResult>,
|
|
}
|
|
|
|
impl StubSearchProvider {
|
|
pub fn new(results: Vec<SearchResult>) -> Self {
|
|
Self { results }
|
|
}
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl SearchProvider for StubSearchProvider {
|
|
fn name(&self) -> &str {
|
|
"stub"
|
|
}
|
|
|
|
async fn search(&self, _query: &str) -> Result<Vec<SearchResult>> {
|
|
Ok(self.results.clone())
|
|
}
|
|
}
|
|
|
|
/// DuckDuckGo HTML search provider
|
|
pub struct DuckDuckGoSearchProvider {
|
|
client: reqwest::Client,
|
|
max_results: usize,
|
|
}
|
|
|
|
impl DuckDuckGoSearchProvider {
|
|
/// Create a new DuckDuckGo search provider with default max results (10)
|
|
pub fn new() -> Self {
|
|
Self::with_max_results(10)
|
|
}
|
|
|
|
/// Create a new DuckDuckGo search provider with custom max results
|
|
pub fn with_max_results(max_results: usize) -> Self {
|
|
let client = reqwest::Client::builder()
|
|
.user_agent("Mozilla/5.0 (compatible; Owlen/1.0)")
|
|
.build()
|
|
.unwrap();
|
|
|
|
Self { client, max_results }
|
|
}
|
|
|
|
/// Parse DuckDuckGo HTML results
|
|
fn parse_results(html: &str, max_results: usize) -> Result<Vec<SearchResult>> {
|
|
let document = Html::parse_document(html);
|
|
|
|
// DuckDuckGo HTML selectors
|
|
let result_selector = Selector::parse(".result").map_err(|e| eyre!("Invalid selector: {:?}", e))?;
|
|
let title_selector = Selector::parse(".result__title a").map_err(|e| eyre!("Invalid selector: {:?}", e))?;
|
|
let snippet_selector = Selector::parse(".result__snippet").map_err(|e| eyre!("Invalid selector: {:?}", e))?;
|
|
|
|
let mut results = Vec::new();
|
|
|
|
for result in document.select(&result_selector).take(max_results) {
|
|
let title = result
|
|
.select(&title_selector)
|
|
.next()
|
|
.map(|e| e.text().collect::<String>().trim().to_string())
|
|
.unwrap_or_default();
|
|
|
|
let url = result
|
|
.select(&title_selector)
|
|
.next()
|
|
.and_then(|e| e.value().attr("href"))
|
|
.unwrap_or_default()
|
|
.to_string();
|
|
|
|
let snippet = result
|
|
.select(&snippet_selector)
|
|
.next()
|
|
.map(|e| e.text().collect::<String>().trim().to_string())
|
|
.unwrap_or_default();
|
|
|
|
if !title.is_empty() && !url.is_empty() {
|
|
results.push(SearchResult { title, url, snippet });
|
|
}
|
|
}
|
|
|
|
Ok(results)
|
|
}
|
|
}
|
|
|
|
impl Default for DuckDuckGoSearchProvider {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl SearchProvider for DuckDuckGoSearchProvider {
|
|
fn name(&self) -> &str {
|
|
"duckduckgo"
|
|
}
|
|
|
|
async fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
|
|
let encoded_query = urlencoding::encode(query);
|
|
let url = format!("https://html.duckduckgo.com/html/?q={}", encoded_query);
|
|
|
|
let response = self.client.get(&url).send().await?;
|
|
let html = response.text().await?;
|
|
|
|
Self::parse_results(&html, self.max_results)
|
|
}
|
|
}
|
|
|
|
/// WebSearch client with pluggable providers
|
|
pub struct WebSearchClient {
|
|
provider: Box<dyn SearchProvider>,
|
|
}
|
|
|
|
impl WebSearchClient {
|
|
pub fn new(provider: Box<dyn SearchProvider>) -> Self {
|
|
Self { provider }
|
|
}
|
|
|
|
pub fn provider_name(&self) -> &str {
|
|
self.provider.name()
|
|
}
|
|
|
|
pub async fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
|
|
self.provider.search(query).await
|
|
}
|
|
}
|
|
|
|
/// Format search results for LLM consumption (markdown format)
|
|
pub fn format_search_results(results: &[SearchResult]) -> String {
|
|
if results.is_empty() {
|
|
return "No results found.".to_string();
|
|
}
|
|
|
|
results
|
|
.iter()
|
|
.enumerate()
|
|
.map(|(i, r)| format!("{}. [{}]({})\n {}", i + 1, r.title, r.url, r.snippet))
|
|
.collect::<Vec<_>>()
|
|
.join("\n\n")
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// With a non-empty allowlist, only listed domains pass.
    #[test]
    fn domain_filtering_allowlist() {
        let mut fetcher = WebFetchClient::new();
        fetcher.allow_domain("example.com");

        assert!(fetcher.is_domain_allowed("example.com"));
        assert!(!fetcher.is_domain_allowed("evil.com"));
    }

    /// With only a blocklist, everything except the blocked domain passes.
    #[test]
    fn domain_filtering_blocklist() {
        let mut fetcher = WebFetchClient::new();
        fetcher.block_domain("evil.com");

        assert!(fetcher.is_domain_allowed("example.com")); // Empty allowlist = allow all
        assert!(!fetcher.is_domain_allowed("evil.com"));
    }

    /// Letter case is ignored both when registering and when checking.
    #[test]
    fn domain_filtering_case_insensitive() {
        let mut fetcher = WebFetchClient::new();
        fetcher.allow_domain("Example.COM");

        assert!(fetcher.is_domain_allowed("example.com"));
        assert!(fetcher.is_domain_allowed("EXAMPLE.COM"));
    }
}
|