owlen/crates/tools/web/src/lib.rs

use color_eyre::eyre::{Result, eyre};
use reqwest::redirect::Policy;
use scraper::{Html, Selector};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use url::Url;

/// WebFetch response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FetchResponse {
    pub url: String,
    pub status: u16,
    pub content: String,
    pub content_type: Option<String>,
}

/// WebFetch client with domain filtering
pub struct WebFetchClient {
    allowed_domains: HashSet<String>,
    blocked_domains: HashSet<String>,
    client: reqwest::Client,
}

impl WebFetchClient {
    /// Create a new WebFetch client
    pub fn new() -> Self {
        let client = reqwest::Client::builder()
            .redirect(Policy::none()) // Don't follow redirects automatically
            .build()
            .unwrap();
        Self {
            allowed_domains: HashSet::new(),
            blocked_domains: HashSet::new(),
            client,
        }
    }

    /// Add an allowed domain
    pub fn allow_domain(&mut self, domain: &str) {
        self.allowed_domains.insert(domain.to_lowercase());
    }

    /// Add a blocked domain
    pub fn block_domain(&mut self, domain: &str) {
        self.blocked_domains.insert(domain.to_lowercase());
    }

    /// Check if a domain is allowed
    fn is_domain_allowed(&self, domain: &str) -> bool {
        let domain_lower = domain.to_lowercase();
        // If explicitly blocked, deny
        if self.blocked_domains.contains(&domain_lower) {
            return false;
        }
        // If allowlist is empty, allow all (except blocked)
        if self.allowed_domains.is_empty() {
            return true;
        }
        // Otherwise, must be in allowlist
        self.allowed_domains.contains(&domain_lower)
    }

    /// Fetch a URL
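    ///
    /// A minimal usage sketch, marked `ignore` so it is not compiled as a doctest.
    /// It assumes the caller is already inside an async runtime (e.g. Tokio, which
    /// reqwest requires) and that the target host is reachable.
    ///
    /// ```ignore
    /// let mut client = WebFetchClient::new();
    /// client.allow_domain("example.com");
    ///
    /// // Denied hosts fail fast; allowed hosts are fetched with a single GET.
    /// let response = client.fetch("https://example.com/").await?;
    /// println!("{} -> HTTP {}", response.url, response.status);
    /// ```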
    pub async fn fetch(&self, url: &str) -> Result<FetchResponse> {
        let parsed_url = Url::parse(url)?;
        let domain = parsed_url
            .host_str()
            .ok_or_else(|| eyre!("No host in URL"))?;
        // Check domain permission
        if !self.is_domain_allowed(domain) {
            return Err(eyre!("Domain not allowed: {}", domain));
        }
        // Make the request
        let response = self.client.get(url).send().await?;
        let status = response.status().as_u16();
        // Handle redirects manually
        if (300..400).contains(&status)
            && let Some(location) = response.headers().get("location")
        {
            let location_str = location.to_str()?;
            // Parse the redirect URL (may be relative)
            let redirect_url = if location_str.starts_with("http") {
                Url::parse(location_str)?
            } else {
                parsed_url.join(location_str)?
            };
            let redirect_domain = redirect_url
                .host_str()
                .ok_or_else(|| eyre!("No host in redirect URL"))?;
            // Check if redirect domain is allowed
            if !self.is_domain_allowed(redirect_domain) {
                return Err(eyre!(
                    "Redirect to unapproved domain: {} -> {}",
                    domain,
                    redirect_domain
                ));
            }
            return Err(eyre!(
                "Redirect detected: {} -> {}. Use the redirect URL directly.",
                url,
                redirect_url
            ));
        }
        let content_type = response
            .headers()
            .get("content-type")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string());
        let content = response.text().await?;
        Ok(FetchResponse {
            url: url.to_string(),
            status,
            content,
            content_type,
        })
    }
}

impl Default for WebFetchClient {
    fn default() -> Self {
        Self::new()
    }
}

/// Search provider trait
#[async_trait::async_trait]
pub trait SearchProvider: Send + Sync {
    fn name(&self) -> &str;
    async fn search(&self, query: &str) -> Result<Vec<SearchResult>>;
}

/// Search result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    pub title: String,
    pub url: String,
    pub snippet: String,
}

/// Stub search provider for testing
pub struct StubSearchProvider {
    results: Vec<SearchResult>,
}

impl StubSearchProvider {
    pub fn new(results: Vec<SearchResult>) -> Self {
        Self { results }
    }
}

#[async_trait::async_trait]
impl SearchProvider for StubSearchProvider {
    fn name(&self) -> &str {
        "stub"
    }

    async fn search(&self, _query: &str) -> Result<Vec<SearchResult>> {
        Ok(self.results.clone())
    }
}

/// DuckDuckGo HTML search provider
pub struct DuckDuckGoSearchProvider {
    client: reqwest::Client,
    max_results: usize,
}

impl DuckDuckGoSearchProvider {
    /// Create a new DuckDuckGo search provider with default max results (10)
    pub fn new() -> Self {
        Self::with_max_results(10)
    }

    /// Create a new DuckDuckGo search provider with custom max results
    pub fn with_max_results(max_results: usize) -> Self {
        let client = reqwest::Client::builder()
            .user_agent("Mozilla/5.0 (compatible; Owlen/1.0)")
            .build()
            .unwrap();
        Self { client, max_results }
    }

    /// Parse DuckDuckGo HTML results
    fn parse_results(html: &str, max_results: usize) -> Result<Vec<SearchResult>> {
        let document = Html::parse_document(html);
        // DuckDuckGo HTML selectors
        let result_selector =
            Selector::parse(".result").map_err(|e| eyre!("Invalid selector: {:?}", e))?;
        let title_selector = Selector::parse(".result__title a")
            .map_err(|e| eyre!("Invalid selector: {:?}", e))?;
        let snippet_selector = Selector::parse(".result__snippet")
            .map_err(|e| eyre!("Invalid selector: {:?}", e))?;
        let mut results = Vec::new();
        for result in document.select(&result_selector).take(max_results) {
            let title = result
                .select(&title_selector)
                .next()
                .map(|e| e.text().collect::<String>().trim().to_string())
                .unwrap_or_default();
            let url = result
                .select(&title_selector)
                .next()
                .and_then(|e| e.value().attr("href"))
                .unwrap_or_default()
                .to_string();
            let snippet = result
                .select(&snippet_selector)
                .next()
                .map(|e| e.text().collect::<String>().trim().to_string())
                .unwrap_or_default();
            if !title.is_empty() && !url.is_empty() {
                results.push(SearchResult { title, url, snippet });
            }
        }
        Ok(results)
    }
}

impl Default for DuckDuckGoSearchProvider {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait::async_trait]
impl SearchProvider for DuckDuckGoSearchProvider {
    fn name(&self) -> &str {
        "duckduckgo"
    }

    async fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
        let encoded_query = urlencoding::encode(query);
        let url = format!("https://html.duckduckgo.com/html/?q={}", encoded_query);
        let response = self.client.get(&url).send().await?;
        let html = response.text().await?;
        Self::parse_results(&html, self.max_results)
    }
}

/// WebSearch client with pluggable providers
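///
/// A minimal wiring sketch, marked `ignore` (it assumes an async runtime such as
/// Tokio is driving the future):
///
/// ```ignore
/// let provider = StubSearchProvider::new(vec![SearchResult {
///     title: "Example".to_string(),
///     url: "https://example.com".to_string(),
///     snippet: "An example result".to_string(),
/// }]);
/// let client = WebSearchClient::new(Box::new(provider));
/// assert_eq!(client.provider_name(), "stub");
/// let results = client.search("anything").await?;
/// assert_eq!(results.len(), 1);
/// ```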
pub struct WebSearchClient {
    provider: Box<dyn SearchProvider>,
}

impl WebSearchClient {
    pub fn new(provider: Box<dyn SearchProvider>) -> Self {
        Self { provider }
    }

    pub fn provider_name(&self) -> &str {
        self.provider.name()
    }

    pub async fn search(&self, query: &str) -> Result<Vec<SearchResult>> {
        self.provider.search(query).await
    }
}

/// Format search results for LLM consumption (markdown format)
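///
/// Illustrative sketch of the output shape, marked `ignore`:
///
/// ```ignore
/// let results = vec![SearchResult {
///     title: "Rust".to_string(),
///     url: "https://www.rust-lang.org".to_string(),
///     snippet: "A language empowering everyone.".to_string(),
/// }];
/// let text = format_search_results(&results);
/// // Each entry becomes a numbered markdown link followed by its snippet.
/// assert!(text.starts_with("1. [Rust](https://www.rust-lang.org)"));
/// ```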
pub fn format_search_results(results: &[SearchResult]) -> String {
    if results.is_empty() {
        return "No results found.".to_string();
    }
    results
        .iter()
        .enumerate()
        .map(|(i, r)| format!("{}. [{}]({})\n {}", i + 1, r.title, r.url, r.snippet))
        .collect::<Vec<_>>()
        .join("\n\n")
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn domain_filtering_allowlist() {
        let mut client = WebFetchClient::new();
        client.allow_domain("example.com");
        assert!(client.is_domain_allowed("example.com"));
        assert!(!client.is_domain_allowed("evil.com"));
    }

    #[test]
    fn domain_filtering_blocklist() {
        let mut client = WebFetchClient::new();
        client.block_domain("evil.com");
        assert!(client.is_domain_allowed("example.com")); // Empty allowlist = allow all
        assert!(!client.is_domain_allowed("evil.com"));
    }

    #[test]
    fn domain_filtering_case_insensitive() {
        let mut client = WebFetchClient::new();
        client.allow_domain("Example.COM");
        assert!(client.is_domain_allowed("example.com"));
        assert!(client.is_domain_allowed("EXAMPLE.COM"));
    }
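
    // Illustrative tests for the synchronous helpers: blocklist precedence, result
    // formatting, and the DuckDuckGo HTML parser. They need no async runtime and no
    // network access; the HTML below is a hand-written sketch of the result markup
    // the selectors above target, not a captured DuckDuckGo page.

    #[test]
    fn blocklist_overrides_allowlist() {
        let mut client = WebFetchClient::new();
        client.allow_domain("example.com");
        client.block_domain("example.com");
        // A domain that is both allowed and blocked is denied.
        assert!(!client.is_domain_allowed("example.com"));
    }

    #[test]
    fn format_results_numbered_markdown() {
        assert_eq!(format_search_results(&[]), "No results found.");
        let results = vec![SearchResult {
            title: "Rust".to_string(),
            url: "https://www.rust-lang.org".to_string(),
            snippet: "A language empowering everyone.".to_string(),
        }];
        let text = format_search_results(&results);
        assert!(text.starts_with("1. [Rust](https://www.rust-lang.org)"));
        assert!(text.contains("A language empowering everyone."));
    }

    #[test]
    fn parse_results_extracts_title_url_snippet() {
        let html = r#"
            <div class="result">
                <h2 class="result__title"><a href="https://example.com">Example Title</a></h2>
                <div class="result__snippet">Example snippet text.</div>
            </div>
        "#;
        let results = DuckDuckGoSearchProvider::parse_results(html, 10).unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].title, "Example Title");
        assert_eq!(results[0].url, "https://example.com");
        assert_eq!(results[0].snippet, "Example snippet text.");
    }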
}