// owlen/crates/owlen-core/tests/ollama_wiremock.rs

use std::{sync::Arc, time::Duration};

use owlen_core::tools::WEB_SEARCH_TOOL_NAME;
use owlen_core::types::{ChatParameters, ChatResponse, Role};
use owlen_core::{
    Config, Provider,
    providers::OllamaProvider,
    session::{SessionController, SessionOutcome},
    storage::StorageManager,
    ui::NoOpUiController,
};
use serde_json::{Value, json};
use tempfile::{TempDir, tempdir};
use wiremock::{
    Match, Mock, MockServer, Request, ResponseTemplate,
    matchers::{header, method, path},
};
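
/// Wiremock matcher that checks whether the raw request body contains (or does
/// not contain) a fixed substring. The cloud tests use it to tell the first
/// `/api/chat` request (no tool message yet) apart from the follow-up request
/// that carries the `"role":"tool"` result.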
#[derive(Clone, Copy)]
struct BodySubstringMatcher {
    needle: &'static str,
    should_contain: bool,
}

impl BodySubstringMatcher {
    const fn contains(needle: &'static str) -> Self {
        Self {
            needle,
            should_contain: true,
        }
    }

    const fn not_contains(needle: &'static str) -> Self {
        Self {
            needle,
            should_contain: false,
        }
    }
}

impl Match for BodySubstringMatcher {
    fn matches(&self, request: &Request) -> bool {
        let body_str = std::str::from_utf8(&request.body).unwrap_or_default();
        body_str.contains(self.needle) == self.should_contain
    }
}
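
/// Loads one of the JSON fixtures bundled next to this test. The fixture files
/// are not shown here; judging by the assertions below, the local completion
/// fixture is assumed to resemble Ollama's non-streaming `/api/chat` reply: an
/// assistant `message` whose content is "Local response complete." plus
/// `prompt_eval_count`/`eval_count` fields that sum to the 36 tokens asserted
/// in `run_local_completion`.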
fn load_fixture(name: &str) -> Value {
    match name {
        "ollama_tags" => serde_json::from_str(include_str!("fixtures/ollama_tags.json"))
            .expect("valid tags fixture"),
        "ollama_local_completion" => {
            serde_json::from_str(include_str!("fixtures/ollama_local_completion.json"))
                .expect("valid local completion fixture")
        }
        "ollama_cloud_tool_call" => {
            serde_json::from_str(include_str!("fixtures/ollama_cloud_tool_call.json"))
                .expect("valid cloud tool call fixture")
        }
        "ollama_cloud_final" => {
            serde_json::from_str(include_str!("fixtures/ollama_cloud_final.json"))
                .expect("valid cloud final fixture")
        }
        other => panic!("unknown fixture '{other}'"),
    }
}
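
/// Builds a `SessionController` backed by an on-disk database in a fresh temp
/// directory and a no-op UI controller. The `TempDir` is returned so the
/// database file outlives the session for the duration of the test.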
async fn create_session(
    provider: Arc<dyn Provider>,
    config: Config,
) -> (SessionController, TempDir) {
    let temp_dir = tempdir().expect("temp dir");
    let storage_path = temp_dir.path().join("owlen-tests.db");
    let storage = Arc::new(
        StorageManager::with_database_path(storage_path)
            .await
            .expect("storage manager"),
    );
    let ui = Arc::new(NoOpUiController);
    let session = SessionController::new(provider, config, storage, ui, false, None)
        .await
        .expect("session controller");
    (session, temp_dir)
}
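
/// Sends a prompt with default chat parameters and expects a complete,
/// non-streaming response; a streaming outcome or an error fails the test.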
async fn send_prompt(session: &mut SessionController, message: &str) -> ChatResponse {
    match session
        .send_message(message.to_string(), ChatParameters::default())
        .await
    {
        Ok(SessionOutcome::Complete(response)) => response,
        Ok(SessionOutcome::Streaming { .. }) => {
            panic!("expected complete outcome, got streaming response")
        }
        Err(err) => panic!("send_message failed: {err:?}"),
    }
}
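
/// Config for the local provider tests: non-streaming, no encryption or
/// per-session consent, with `ollama_local` pointed at the mock server.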
fn configure_local(base_url: &str) -> Config {
    let mut config = Config::default();
    config.general.default_provider = "ollama_local".into();
    config.general.default_model = Some("local-mini".into());
    config.general.enable_streaming = false;
    config.privacy.encrypt_local_data = false;
    config.privacy.require_consent_per_session = false;
    if let Some(local) = config.providers.get_mut("ollama_local") {
        local.enabled = true;
        local.base_url = Some(base_url.to_string());
    }
    config
}
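
/// Config for the cloud provider tests: enables remote web search and the
/// web-search tool, injects a test API key, and optionally overrides the
/// search endpoint via the provider's `extra` map. Setting
/// `OWLEN_ALLOW_INSECURE_CLOUD` mutates process-global environment state,
/// which is why the call sits in an `unsafe` block.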
fn configure_cloud(base_url: &str, search_endpoint: Option<&str>) -> Config {
    let mut config = Config::default();
    config.general.default_provider = "ollama_cloud".into();
    config.general.default_model = Some("llama3:8b-cloud".into());
    config.general.enable_streaming = false;
    config.privacy.enable_remote_search = true;
    config.privacy.encrypt_local_data = false;
    config.privacy.require_consent_per_session = false;
    config.tools.web_search.enabled = true;
    unsafe {
        std::env::set_var("OWLEN_ALLOW_INSECURE_CLOUD", "1");
    }
    if let Some(cloud) = config.providers.get_mut("ollama_cloud") {
        cloud.enabled = true;
        cloud.base_url = Some(base_url.to_string());
        cloud.api_key = Some("test-key".into());
        if let Some(endpoint) = search_endpoint {
            cloud
                .extra
                .insert("web_search_endpoint".into(), Value::String(endpoint.into()));
        }
        cloud.extra.insert(
            owlen_core::config::OLLAMA_CLOUD_ENDPOINT_KEY.into(),
            Value::String(base_url.to_string()),
        );
    }
    config
}
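
/// Happy-path scenario against the local provider: `/api/tags` lists the
/// models, a single `/api/chat` call returns the canned completion, and the
/// usage snapshot records the fixture's token counts. `base_suffix` lets the
/// same flow run against both a bare base URL and one ending in `/v1`.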
async fn run_local_completion(base_suffix: &str) {
    let server = MockServer::start().await;
    let raw_base = format!("{}{}", server.uri(), base_suffix);
    let tags = load_fixture("ollama_tags");
    let completion = load_fixture("ollama_local_completion");

    Mock::given(method("GET"))
        .and(path("/api/tags"))
        .respond_with(ResponseTemplate::new(200).set_body_json(tags))
        .mount(&server)
        .await;
    Mock::given(method("POST"))
        .and(path("/api/chat"))
        .respond_with(ResponseTemplate::new(200).set_body_json(completion))
        .expect(1)
        .mount(&server)
        .await;

    let config = configure_local(&raw_base);
    let provider: Arc<dyn Provider> =
        Arc::new(OllamaProvider::new(raw_base.clone()).expect("local provider"));
    let (mut session, _tmp) = create_session(provider, config).await;

    let response = send_prompt(&mut session, "Summarise the local status.").await;
    assert_eq!(response.message.content, "Local response complete.");

    let snapshot = session
        .current_usage_snapshot()
        .await
        .expect("usage snapshot");
    assert_eq!(snapshot.provider, "ollama_local");
    assert_eq!(snapshot.hourly.total_tokens, 36);
    assert_eq!(snapshot.weekly.total_tokens, 36);
}
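
/// Cloud tool-call scenario: the first `/api/chat` request (no tool message in
/// the body) returns a web-search tool call, the mocked search endpoint returns
/// one result, and the follow-up `/api/chat` request carrying the
/// `"role":"tool"` message returns the final answer. The helper also checks
/// that exactly one tool message lands in the conversation and that usage is
/// recorded once.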
async fn run_cloud_tool_flow(
    base_suffix: &str,
    search_endpoint: Option<&str>,
    expected_search_path: &str,
) {
    let server = MockServer::start().await;
    let raw_base = format!("{}{}", server.uri(), base_suffix);
    let tags = load_fixture("ollama_tags");
    let tool_call = load_fixture("ollama_cloud_tool_call");
    let final_chunk = load_fixture("ollama_cloud_final");

    Mock::given(method("GET"))
        .and(path("/api/tags"))
        .and(header("authorization", "Bearer test-key"))
        .respond_with(ResponseTemplate::new(200).set_body_json(tags))
        .mount(&server)
        .await;
    Mock::given(method("POST"))
        .and(path("/api/chat"))
        .and(header("authorization", "Bearer test-key"))
        .and(BodySubstringMatcher::not_contains("\"role\":\"tool\""))
        .respond_with(ResponseTemplate::new(200).set_body_json(tool_call))
        .expect(1)
        .mount(&server)
        .await;
    Mock::given(method("POST"))
        .and(path("/api/chat"))
        .and(header("authorization", "Bearer test-key"))
        .and(BodySubstringMatcher::contains("\"role\":\"tool\""))
        .respond_with(ResponseTemplate::new(200).set_body_json(final_chunk))
        .expect(1)
        .mount(&server)
        .await;
    Mock::given(method("POST"))
        .and(path(expected_search_path))
        .and(header("authorization", "Bearer test-key"))
        .respond_with(ResponseTemplate::new(200).set_body_json(json!({
            "results": [
                {
                    "title": "Rust 1.85 Released",
                    "url": "https://blog.rust-lang.org/2025/10/23/Rust-1.85.html",
                    "snippet": "Rust 1.85 lands with incremental compilation improvements."
                }
            ]
        })))
        .expect(1)
        .mount(&server)
        .await;

    let config = configure_cloud(&raw_base, search_endpoint);
    let cloud_cfg = config
        .providers
        .get("ollama_cloud")
        .expect("cloud provider config")
        .clone();
    assert_eq!(cloud_cfg.api_key.as_deref(), Some("test-key"));
    assert_eq!(cloud_cfg.base_url.as_deref(), Some(raw_base.as_str()));
    let provider: Arc<dyn Provider> = Arc::new(
        OllamaProvider::from_config("ollama_cloud", &cloud_cfg, Some(&config.general))
            .expect("cloud provider"),
    );
    let (mut session, _tmp) = create_session(provider, config).await;

    let search_url = format!(
        "{}/{}",
        raw_base.trim_end_matches('/'),
        expected_search_path.trim_start_matches('/')
    );
    session.grant_consent(
        WEB_SEARCH_TOOL_NAME,
        vec!["network".into()],
        vec![search_url],
    );

    let response = send_prompt(&mut session, "What is new in Rust today?").await;
    assert_eq!(
        response.message.content,
        "Rust 1.85 shipped today. Summarising the highlights now."
    );

    let convo = session.conversation();
    let tool_messages: Vec<_> = convo
        .messages
        .iter()
        .filter(|msg| msg.role == Role::Tool)
        .collect();
    assert_eq!(tool_messages.len(), 1);
    assert!(
        tool_messages[0].content.contains("Rust 1.85 Released"),
        "tool response should include search result"
    );

    let snapshot = session
        .current_usage_snapshot()
        .await
        .expect("usage snapshot");
    assert_eq!(snapshot.provider, "ollama_cloud");
    assert_eq!(snapshot.hourly.total_tokens, 112);
    assert_eq!(snapshot.weekly.total_tokens, 112);
}
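
/// Error scenario: `/api/chat` answers with the given HTTP status and body, the
/// session call is expected to fail, and no token usage should be recorded.
/// Returns the error's display text so callers can assert on its wording.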
async fn run_cloud_error(
    base_suffix: &str,
    status: u16,
    error_body: Value,
    prompt: &str,
) -> String {
    let server = MockServer::start().await;
    let raw_base = format!("{}{}", server.uri(), base_suffix);
    let search_endpoint = if base_suffix.is_empty() {
        "/v1/web/search"
    } else {
        "/web/search"
    };
    let tags = load_fixture("ollama_tags");

    Mock::given(method("GET"))
        .and(path("/api/tags"))
        .and(header("authorization", "Bearer test-key"))
        .respond_with(ResponseTemplate::new(200).set_body_json(tags))
        .mount(&server)
        .await;
    Mock::given(method("POST"))
        .and(path("/api/chat"))
        .and(header("authorization", "Bearer test-key"))
        .respond_with(
            ResponseTemplate::new(status)
                .set_body_json(error_body)
                .set_delay(Duration::from_millis(5)),
        )
        .expect(1)
        .mount(&server)
        .await;

    let config = configure_cloud(&raw_base, Some(search_endpoint));
    let cloud_cfg = config
        .providers
        .get("ollama_cloud")
        .expect("cloud provider config")
        .clone();
    let provider: Arc<dyn Provider> = Arc::new(
        OllamaProvider::from_config("ollama_cloud", &cloud_cfg, Some(&config.general))
            .expect("cloud provider"),
    );
    let (mut session, _tmp) = create_session(provider, config).await;

    let err_text = match session
        .send_message(prompt.to_string(), ChatParameters::default())
        .await
    {
        Ok(_) => panic!("expected error status {status} but request succeeded"),
        Err(err) => err.to_string(),
    };

    let snapshot = session
        .current_usage_snapshot()
        .await
        .expect("usage snapshot");
    assert_eq!(snapshot.hourly.total_tokens, 0);
    assert_eq!(snapshot.weekly.total_tokens, 0);
    err_text
}
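
// The scenarios below exercise both the mock server's bare URI and a
// `/v1`-suffixed base URL to confirm that the provider resolves either shape to
// the expected Ollama API and web-search paths.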
#[tokio::test(flavor = "multi_thread")]
async fn local_provider_happy_path_records_usage() {
    run_local_completion("").await;
}

#[tokio::test(flavor = "multi_thread")]
async fn local_provider_accepts_v1_base_url() {
    run_local_completion("/v1").await;
}

#[tokio::test(flavor = "multi_thread")]
async fn cloud_tool_call_flows_through_web_search() {
    run_cloud_tool_flow("", Some("/v1/web/search"), "/v1/web/search").await;
}

#[tokio::test(flavor = "multi_thread")]
async fn cloud_tool_call_accepts_v1_base_url() {
    run_cloud_tool_flow("/v1", Some("/web/search"), "/v1/web/search").await;
}

#[tokio::test(flavor = "multi_thread")]
async fn cloud_tool_call_uses_default_search_endpoint() {
    run_cloud_tool_flow("", None, "/api/web_search").await;
}

#[tokio::test(flavor = "multi_thread")]
async fn cloud_unauthorized_degrades_without_usage() {
    for suffix in ["", "/v1"] {
        let err_text = run_cloud_error(
            suffix,
            401,
            json!({ "error": "unauthorized" }),
            "Switch to cloud",
        )
        .await;
        assert!(
            err_text.contains("unauthorized") || err_text.contains("API key"),
            "error should surface unauthorized detail for base '{suffix}', got: {err_text}"
        );
    }
}

#[tokio::test(flavor = "multi_thread")]
async fn cloud_rate_limit_returns_error_without_usage() {
    for suffix in ["", "/v1"] {
        let err_text = run_cloud_error(
            suffix,
            429,
            json!({ "error": "too many requests" }),
            "Hit rate limit",
        )
        .await;
        assert!(
            err_text.contains("rate limited") || err_text.contains("429"),
            "error should mention rate limiting for base '{suffix}', got: {err_text}"
        );
    }
}