use std::sync::Arc; use owlen_core::tools::WEB_SEARCH_TOOL_NAME; use owlen_core::types::{ChatParameters, Role}; use owlen_core::{ Config, Provider, providers::OllamaProvider, session::{SessionController, SessionOutcome}, storage::StorageManager, ui::NoOpUiController, }; use serde_json::{Value, json}; use tempfile::{TempDir, tempdir}; use wiremock::{ Match, Mock, MockServer, Request, ResponseTemplate, matchers::{header, method, path}, }; #[derive(Clone, Copy)] struct BodySubstringMatcher { needle: &'static str, should_contain: bool, } impl BodySubstringMatcher { const fn contains(needle: &'static str) -> Self { Self { needle, should_contain: true, } } const fn not_contains(needle: &'static str) -> Self { Self { needle, should_contain: false, } } } impl Match for BodySubstringMatcher { fn matches(&self, request: &Request) -> bool { let body_str = std::str::from_utf8(&request.body).unwrap_or_default(); body_str.contains(self.needle) == self.should_contain } } fn load_fixture(name: &str) -> Value { match name { "ollama_tags" => serde_json::from_str(include_str!("fixtures/ollama_tags.json")) .expect("valid tags fixture"), "ollama_local_completion" => { serde_json::from_str(include_str!("fixtures/ollama_local_completion.json")) .expect("valid local completion fixture") } "ollama_cloud_tool_call" => { serde_json::from_str(include_str!("fixtures/ollama_cloud_tool_call.json")) .expect("valid cloud tool call fixture") } "ollama_cloud_final" => { serde_json::from_str(include_str!("fixtures/ollama_cloud_final.json")) .expect("valid cloud final fixture") } other => panic!("unknown fixture '{other}'"), } } async fn create_session( provider: Arc, config: Config, ) -> (SessionController, TempDir) { let temp_dir = tempdir().expect("temp dir"); let storage_path = temp_dir.path().join("owlen-tests.db"); let storage = Arc::new( StorageManager::with_database_path(storage_path) .await .expect("storage manager"), ); let ui = Arc::new(NoOpUiController); let session = SessionController::new(provider, config, storage, ui, false, None) .await .expect("session controller"); (session, temp_dir) } fn configure_local(base_url: &str) -> Config { let mut config = Config::default(); config.general.default_provider = "ollama_local".into(); config.general.default_model = Some("local-mini".into()); config.general.enable_streaming = false; config.privacy.encrypt_local_data = false; config.privacy.require_consent_per_session = false; if let Some(local) = config.providers.get_mut("ollama_local") { local.enabled = true; local.base_url = Some(base_url.to_string()); } config } fn configure_cloud(base_url: &str) -> Config { let mut config = Config::default(); config.general.default_provider = "ollama_cloud".into(); config.general.default_model = Some("llama3:8b-cloud".into()); config.general.enable_streaming = false; config.privacy.enable_remote_search = true; config.privacy.encrypt_local_data = false; config.privacy.require_consent_per_session = false; config.tools.web_search.enabled = true; unsafe { std::env::set_var("OWLEN_ALLOW_INSECURE_CLOUD", "1"); } if let Some(cloud) = config.providers.get_mut("ollama_cloud") { cloud.enabled = true; cloud.base_url = Some(base_url.to_string()); cloud.api_key = Some("test-key".into()); cloud.extra.insert( "web_search_endpoint".into(), Value::String("/v1/web/search".into()), ); cloud.extra.insert( owlen_core::config::OLLAMA_CLOUD_ENDPOINT_KEY.into(), Value::String(base_url.to_string()), ); } config } #[tokio::test(flavor = "multi_thread")] async fn local_provider_happy_path_records_usage() { let server = MockServer::start().await; let tags = load_fixture("ollama_tags"); let completion = load_fixture("ollama_local_completion"); let base_url = server.uri(); let tags_template = ResponseTemplate::new(200).set_body_json(tags); Mock::given(method("GET")) .and(path("/api/tags")) .respond_with(tags_template.clone()) .mount(&server) .await; Mock::given(method("POST")) .and(path("/api/chat")) .respond_with(ResponseTemplate::new(200).set_body_json(completion)) .mount(&server) .await; let config = configure_local(&base_url); let provider: Arc = Arc::new(OllamaProvider::new(base_url.clone()).expect("local provider")); let (mut session, _tmp) = create_session(provider, config).await; let outcome = session .send_message( "Summarise the local status.".to_string(), ChatParameters::default(), ) .await .expect("local completion"); let response = match outcome { SessionOutcome::Complete(response) => response, _ => panic!("expected complete outcome"), }; assert_eq!(response.message.content, "Local response complete."); let snapshot = session .current_usage_snapshot() .await .expect("usage snapshot"); assert_eq!(snapshot.provider, "ollama_local"); assert_eq!(snapshot.hourly.total_tokens, 36); assert_eq!(snapshot.weekly.total_tokens, 36); } #[tokio::test(flavor = "multi_thread")] async fn cloud_tool_call_flows_through_web_search() { let server = MockServer::start().await; let tags = load_fixture("ollama_tags"); let tool_call = load_fixture("ollama_cloud_tool_call"); let final_chunk = load_fixture("ollama_cloud_final"); let base_url = server.uri(); let tags_template = ResponseTemplate::new(200).set_body_json(tags); Mock::given(method("GET")) .and(path("/api/tags")) .respond_with(tags_template.clone()) .mount(&server) .await; Mock::given(method("POST")) .and(path("/api/chat")) .and(BodySubstringMatcher::not_contains("\"role\":\"tool\"")) .respond_with(ResponseTemplate::new(200).set_body_json(tool_call)) .mount(&server) .await; Mock::given(method("POST")) .and(path("/api/chat")) .and(BodySubstringMatcher::contains("\"role\":\"tool\"")) .respond_with(ResponseTemplate::new(200).set_body_json(final_chunk)) .mount(&server) .await; Mock::given(method("POST")) .and(path("/v1/web/search")) .and(header("authorization", "Bearer test-key")) .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "results": [ { "title": "Rust 1.85 Released", "url": "https://blog.rust-lang.org/2025/10/23/Rust-1.85.html", "snippet": "Rust 1.85 lands with incremental compilation improvements." } ] }))) .expect(1) .mount(&server) .await; let config = configure_cloud(&base_url); let cloud_cfg = config .providers .get("ollama_cloud") .expect("cloud provider config") .clone(); assert_eq!(cloud_cfg.api_key.as_deref(), Some("test-key")); assert_eq!(cloud_cfg.base_url.as_deref(), Some(base_url.as_str())); let provider: Arc = Arc::new( OllamaProvider::from_config("ollama_cloud", &cloud_cfg, Some(&config.general)) .expect("cloud provider"), ); let (mut session, _tmp) = create_session(provider, config).await; session.grant_consent( WEB_SEARCH_TOOL_NAME, vec!["network".into()], vec![format!("{}/v1/web/search", base_url)], ); let outcome = session .send_message( "What is new in Rust today?".to_string(), ChatParameters::default(), ) .await; let response = match outcome { Ok(SessionOutcome::Complete(response)) => response, Ok(_) => panic!("expected complete outcome"), Err(err) => panic!("cloud completion: {err:?}"), }; assert_eq!( response.message.content, "Rust 1.85 shipped today. Summarising the highlights now." ); let convo = session.conversation(); let tool_messages: Vec<_> = convo .messages .iter() .filter(|msg| msg.role == Role::Tool) .collect(); assert_eq!(tool_messages.len(), 1); assert!( tool_messages[0].content.contains("Rust 1.85 Released"), "tool response should include search result" ); let snapshot = session .current_usage_snapshot() .await .expect("usage snapshot"); assert_eq!(snapshot.provider, "ollama_cloud"); assert_eq!(snapshot.hourly.total_tokens, 112); // 64 prompt + 48 completion assert_eq!(snapshot.weekly.total_tokens, 112); } #[tokio::test(flavor = "multi_thread")] async fn cloud_unauthorized_degrades_without_usage() { let server = MockServer::start().await; let tags = load_fixture("ollama_tags"); let tags_template = ResponseTemplate::new(200).set_body_json(tags); let base_url = server.uri(); Mock::given(method("GET")) .and(path("/api/tags")) .respond_with(tags_template.clone()) .mount(&server) .await; Mock::given(method("POST")) .and(path("/api/chat")) .respond_with(ResponseTemplate::new(401).set_body_json(json!({ "error": "unauthorized" }))) .mount(&server) .await; let config = configure_cloud(&base_url); let cloud_cfg = config .providers .get("ollama_cloud") .expect("cloud provider config") .clone(); let provider: Arc = Arc::new( OllamaProvider::from_config("ollama_cloud", &cloud_cfg, Some(&config.general)) .expect("cloud provider"), ); let (mut session, _tmp) = create_session(provider, config).await; let err_text = match session .send_message("Switch to cloud".to_string(), ChatParameters::default()) .await { Ok(_) => panic!("expected unauthorized error, but request succeeded"), Err(err) => err.to_string(), }; assert!( err_text.contains("unauthorized") || err_text.contains("API key"), "error should surface unauthorized detail, got: {err_text}" ); let snapshot = session .current_usage_snapshot() .await .expect("usage snapshot"); assert_eq!(snapshot.hourly.total_tokens, 0); assert_eq!(snapshot.weekly.total_tokens, 0); } #[tokio::test(flavor = "multi_thread")] async fn cloud_rate_limit_returns_error_without_usage() { let server = MockServer::start().await; let tags = load_fixture("ollama_tags"); let tags_template = ResponseTemplate::new(200).set_body_json(tags); let base_url = server.uri(); Mock::given(method("GET")) .and(path("/api/tags")) .respond_with(tags_template.clone()) .mount(&server) .await; Mock::given(method("POST")) .and(path("/api/chat")) .respond_with(ResponseTemplate::new(429).set_body_json(json!({ "error": "too many requests" }))) .mount(&server) .await; let config = configure_cloud(&base_url); let cloud_cfg = config .providers .get("ollama_cloud") .expect("cloud provider config") .clone(); let provider: Arc = Arc::new( OllamaProvider::from_config("ollama_cloud", &cloud_cfg, Some(&config.general)) .expect("cloud provider"), ); let (mut session, _tmp) = create_session(provider, config).await; let err_text = match session .send_message("Hit rate limit".to_string(), ChatParameters::default()) .await { Ok(_) => panic!("expected rate-limit error, but request succeeded"), Err(err) => err.to_string(), }; assert!( err_text.contains("rate limited") || err_text.contains("429"), "error should mention rate limiting, got: {err_text}" ); let snapshot = session .current_usage_snapshot() .await .expect("usage snapshot"); assert_eq!(snapshot.hourly.total_tokens, 0); assert_eq!(snapshot.weekly.total_tokens, 0); }