use std::{sync::Arc, time::Duration}; use owlen_core::tools::WEB_SEARCH_TOOL_NAME; use owlen_core::types::{ChatParameters, ChatResponse, Role}; use owlen_core::{ Config, Provider, providers::OllamaProvider, session::{SessionController, SessionOutcome}, storage::StorageManager, ui::NoOpUiController, }; use serde_json::{Value, json}; use tempfile::{TempDir, tempdir}; use wiremock::{ Match, Mock, MockServer, Request, ResponseTemplate, matchers::{header, method, path}, }; #[derive(Clone, Copy)] struct BodySubstringMatcher { needle: &'static str, should_contain: bool, } impl BodySubstringMatcher { const fn contains(needle: &'static str) -> Self { Self { needle, should_contain: true, } } const fn not_contains(needle: &'static str) -> Self { Self { needle, should_contain: false, } } } impl Match for BodySubstringMatcher { fn matches(&self, request: &Request) -> bool { let body_str = std::str::from_utf8(&request.body).unwrap_or_default(); body_str.contains(self.needle) == self.should_contain } } fn load_fixture(name: &str) -> Value { match name { "ollama_tags" => serde_json::from_str(include_str!("fixtures/ollama_tags.json")) .expect("valid tags fixture"), "ollama_local_completion" => { serde_json::from_str(include_str!("fixtures/ollama_local_completion.json")) .expect("valid local completion fixture") } "ollama_cloud_tool_call" => { serde_json::from_str(include_str!("fixtures/ollama_cloud_tool_call.json")) .expect("valid cloud tool call fixture") } "ollama_cloud_final" => { serde_json::from_str(include_str!("fixtures/ollama_cloud_final.json")) .expect("valid cloud final fixture") } other => panic!("unknown fixture '{other}'"), } } async fn create_session( provider: Arc, config: Config, ) -> (SessionController, TempDir) { let temp_dir = tempdir().expect("temp dir"); let storage_path = temp_dir.path().join("owlen-tests.db"); let storage = Arc::new( StorageManager::with_database_path(storage_path) .await .expect("storage manager"), ); let ui = Arc::new(NoOpUiController); let session = SessionController::new(provider, config, storage, ui, false, None) .await .expect("session controller"); (session, temp_dir) } async fn send_prompt(session: &mut SessionController, message: &str) -> ChatResponse { match session .send_message(message.to_string(), ChatParameters::default()) .await { Ok(SessionOutcome::Complete(response)) => response, Ok(SessionOutcome::Streaming { .. }) => { panic!("expected complete outcome, got streaming response") } Err(err) => panic!("send_message failed: {err:?}"), } } fn configure_local(base_url: &str) -> Config { let mut config = Config::default(); config.general.default_provider = "ollama_local".into(); config.general.default_model = Some("local-mini".into()); config.general.enable_streaming = false; config.privacy.encrypt_local_data = false; config.privacy.require_consent_per_session = false; if let Some(local) = config.providers.get_mut("ollama_local") { local.enabled = true; local.base_url = Some(base_url.to_string()); } config } fn configure_cloud(base_url: &str, search_endpoint: Option<&str>) -> Config { let mut config = Config::default(); config.general.default_provider = "ollama_cloud".into(); config.general.default_model = Some("llama3:8b-cloud".into()); config.general.enable_streaming = false; config.privacy.enable_remote_search = true; config.privacy.encrypt_local_data = false; config.privacy.require_consent_per_session = false; config.tools.web_search.enabled = true; unsafe { std::env::set_var("OWLEN_ALLOW_INSECURE_CLOUD", "1"); } if let Some(cloud) = config.providers.get_mut("ollama_cloud") { cloud.enabled = true; cloud.base_url = Some(base_url.to_string()); cloud.api_key = Some("test-key".into()); if let Some(endpoint) = search_endpoint { cloud .extra .insert("web_search_endpoint".into(), Value::String(endpoint.into())); } cloud.extra.insert( owlen_core::config::OLLAMA_CLOUD_ENDPOINT_KEY.into(), Value::String(base_url.to_string()), ); } config } async fn run_local_completion(base_suffix: &str) { let server = MockServer::start().await; let raw_base = format!("{}{}", server.uri(), base_suffix); let tags = load_fixture("ollama_tags"); let completion = load_fixture("ollama_local_completion"); Mock::given(method("GET")) .and(path("/api/tags")) .respond_with(ResponseTemplate::new(200).set_body_json(tags)) .mount(&server) .await; Mock::given(method("POST")) .and(path("/api/chat")) .respond_with(ResponseTemplate::new(200).set_body_json(completion)) .expect(1) .mount(&server) .await; let config = configure_local(&raw_base); let provider: Arc = Arc::new(OllamaProvider::new(raw_base.clone()).expect("local provider")); let (mut session, _tmp) = create_session(provider, config).await; let response = send_prompt(&mut session, "Summarise the local status.").await; assert_eq!(response.message.content, "Local response complete."); let snapshot = session .current_usage_snapshot() .await .expect("usage snapshot"); assert_eq!(snapshot.provider, "ollama_local"); assert_eq!(snapshot.hourly.total_tokens, 36); assert_eq!(snapshot.weekly.total_tokens, 36); } async fn run_cloud_tool_flow( base_suffix: &str, search_endpoint: Option<&str>, expected_search_path: &str, ) { let server = MockServer::start().await; let raw_base = format!("{}{}", server.uri(), base_suffix); let tags = load_fixture("ollama_tags"); let tool_call = load_fixture("ollama_cloud_tool_call"); let final_chunk = load_fixture("ollama_cloud_final"); Mock::given(method("GET")) .and(path("/api/tags")) .and(header("authorization", "Bearer test-key")) .respond_with(ResponseTemplate::new(200).set_body_json(tags)) .mount(&server) .await; Mock::given(method("POST")) .and(path("/api/chat")) .and(header("authorization", "Bearer test-key")) .and(BodySubstringMatcher::not_contains("\"role\":\"tool\"")) .respond_with(ResponseTemplate::new(200).set_body_json(tool_call)) .expect(1) .mount(&server) .await; Mock::given(method("POST")) .and(path("/api/chat")) .and(header("authorization", "Bearer test-key")) .and(BodySubstringMatcher::contains("\"role\":\"tool\"")) .respond_with(ResponseTemplate::new(200).set_body_json(final_chunk)) .expect(1) .mount(&server) .await; Mock::given(method("POST")) .and(path(expected_search_path)) .and(header("authorization", "Bearer test-key")) .respond_with(ResponseTemplate::new(200).set_body_json(json!({ "results": [ { "title": "Rust 1.85 Released", "url": "https://blog.rust-lang.org/2025/10/23/Rust-1.85.html", "snippet": "Rust 1.85 lands with incremental compilation improvements." } ] }))) .expect(1) .mount(&server) .await; let config = configure_cloud(&raw_base, search_endpoint); let cloud_cfg = config .providers .get("ollama_cloud") .expect("cloud provider config") .clone(); assert_eq!(cloud_cfg.api_key.as_deref(), Some("test-key")); assert_eq!(cloud_cfg.base_url.as_deref(), Some(raw_base.as_str())); assert_eq!(cloud_cfg.api_key.as_deref(), Some("test-key")); assert_eq!(cloud_cfg.base_url.as_deref(), Some(raw_base.as_str())); let provider: Arc = Arc::new( OllamaProvider::from_config("ollama_cloud", &cloud_cfg, Some(&config.general)) .expect("cloud provider"), ); let (mut session, _tmp) = create_session(provider, config).await; let search_url = format!( "{}/{}", raw_base.trim_end_matches('/'), expected_search_path.trim_start_matches('/') ); session.grant_consent( WEB_SEARCH_TOOL_NAME, vec!["network".into()], vec![search_url], ); let response = send_prompt(&mut session, "What is new in Rust today?").await; assert_eq!( response.message.content, "Rust 1.85 shipped today. Summarising the highlights now." ); let convo = session.conversation(); let tool_messages: Vec<_> = convo .messages .iter() .filter(|msg| msg.role == Role::Tool) .collect(); assert_eq!(tool_messages.len(), 1); assert!( tool_messages[0].content.contains("Rust 1.85 Released"), "tool response should include search result" ); let snapshot = session .current_usage_snapshot() .await .expect("usage snapshot"); assert_eq!(snapshot.provider, "ollama_cloud"); assert_eq!(snapshot.hourly.total_tokens, 112); assert_eq!(snapshot.weekly.total_tokens, 112); } async fn run_cloud_error( base_suffix: &str, status: u16, error_body: Value, prompt: &str, ) -> String { let server = MockServer::start().await; let raw_base = format!("{}{}", server.uri(), base_suffix); let search_endpoint = if base_suffix.is_empty() { "/v1/web/search" } else { "/web/search" }; let tags = load_fixture("ollama_tags"); Mock::given(method("GET")) .and(path("/api/tags")) .and(header("authorization", "Bearer test-key")) .respond_with(ResponseTemplate::new(200).set_body_json(tags)) .mount(&server) .await; Mock::given(method("POST")) .and(path("/api/chat")) .and(header("authorization", "Bearer test-key")) .respond_with( ResponseTemplate::new(status) .set_body_json(error_body) .set_delay(Duration::from_millis(5)), ) .expect(1) .mount(&server) .await; let config = configure_cloud(&raw_base, Some(search_endpoint)); let cloud_cfg = config .providers .get("ollama_cloud") .expect("cloud provider config") .clone(); let provider: Arc = Arc::new( OllamaProvider::from_config("ollama_cloud", &cloud_cfg, Some(&config.general)) .expect("cloud provider"), ); let (mut session, _tmp) = create_session(provider, config).await; let err_text = match session .send_message(prompt.to_string(), ChatParameters::default()) .await { Ok(_) => panic!("expected error status {status} but request succeeded"), Err(err) => err.to_string(), }; let snapshot = session .current_usage_snapshot() .await .expect("usage snapshot"); assert_eq!(snapshot.hourly.total_tokens, 0); assert_eq!(snapshot.weekly.total_tokens, 0); err_text } #[tokio::test(flavor = "multi_thread")] async fn local_provider_happy_path_records_usage() { run_local_completion("").await; } #[tokio::test(flavor = "multi_thread")] async fn local_provider_accepts_v1_base_url() { run_local_completion("/v1").await; } #[tokio::test(flavor = "multi_thread")] async fn cloud_tool_call_flows_through_web_search() { run_cloud_tool_flow("", Some("/v1/web/search"), "/v1/web/search").await; } #[tokio::test(flavor = "multi_thread")] async fn cloud_tool_call_accepts_v1_base_url() { run_cloud_tool_flow("/v1", Some("/web/search"), "/web/search").await; } #[tokio::test(flavor = "multi_thread")] async fn cloud_tool_call_uses_default_search_endpoint() { run_cloud_tool_flow("", None, "/api/web_search").await; } #[tokio::test(flavor = "multi_thread")] async fn cloud_unauthorized_degrades_without_usage() { for suffix in ["", "/v1"] { let err_text = run_cloud_error( suffix, 401, json!({ "error": "unauthorized" }), "Switch to cloud", ) .await; assert!( err_text.contains("unauthorized") || err_text.contains("API key"), "error should surface unauthorized detail for base '{suffix}', got: {err_text}" ); } } #[tokio::test(flavor = "multi_thread")] async fn cloud_rate_limit_returns_error_without_usage() { for suffix in ["", "/v1"] { let err_text = run_cloud_error( suffix, 429, json!({ "error": "too many requests" }), "Hit rate limit", ) .await; assert!( err_text.contains("rate limited") || err_text.contains("429"), "error should mention rate limiting for base '{suffix}', got: {err_text}" ); } }