Acceptance-Criteria:
- Spec-compliant names are shared via WEB_SEARCH_TOOL_NAME, and ModeConfig checks canonical identifiers.
- The workspace depends on once_cell, so regex helpers build without local target hacks.
Test-Notes:
- cargo test
397 lines
12 KiB
Rust
397 lines
12 KiB
Rust
use std::sync::Arc;
|
|
|
|
use owlen_core::tools::WEB_SEARCH_TOOL_NAME;
|
|
use owlen_core::types::{ChatParameters, Role};
|
|
use owlen_core::{
|
|
Config, Provider,
|
|
providers::OllamaProvider,
|
|
session::{SessionController, SessionOutcome},
|
|
storage::StorageManager,
|
|
ui::NoOpUiController,
|
|
};
|
|
use serde_json::{Value, json};
|
|
use tempfile::{TempDir, tempdir};
|
|
use wiremock::{
|
|
Match, Mock, MockServer, Request, ResponseTemplate,
|
|
matchers::{header, method, path},
|
|
};
|
|
|
|
/// Wiremock request matcher keyed on a literal substring of the request body.
///
/// The matcher is polarised: with `should_contain == true` a request matches only when
/// `needle` occurs in its body; with `should_contain == false` it matches only when
/// `needle` does not occur.
#[derive(Clone, Copy)]
struct BodySubstringMatcher {
    /// Literal text looked up in the (UTF-8 decoded) request body.
    needle: &'static str,
    /// `true` = require `needle` to be present, `false` = require it to be absent.
    should_contain: bool,
}
|
|
|
|
impl BodySubstringMatcher {
|
|
const fn contains(needle: &'static str) -> Self {
|
|
Self {
|
|
needle,
|
|
should_contain: true,
|
|
}
|
|
}
|
|
|
|
const fn not_contains(needle: &'static str) -> Self {
|
|
Self {
|
|
needle,
|
|
should_contain: false,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Match for BodySubstringMatcher {
|
|
fn matches(&self, request: &Request) -> bool {
|
|
let body_str = std::str::from_utf8(&request.body).unwrap_or_default();
|
|
body_str.contains(self.needle) == self.should_contain
|
|
}
|
|
}
|
|
|
|
fn load_fixture(name: &str) -> Value {
|
|
match name {
|
|
"ollama_tags" => serde_json::from_str(include_str!("fixtures/ollama_tags.json"))
|
|
.expect("valid tags fixture"),
|
|
"ollama_local_completion" => {
|
|
serde_json::from_str(include_str!("fixtures/ollama_local_completion.json"))
|
|
.expect("valid local completion fixture")
|
|
}
|
|
"ollama_cloud_tool_call" => {
|
|
serde_json::from_str(include_str!("fixtures/ollama_cloud_tool_call.json"))
|
|
.expect("valid cloud tool call fixture")
|
|
}
|
|
"ollama_cloud_final" => {
|
|
serde_json::from_str(include_str!("fixtures/ollama_cloud_final.json"))
|
|
.expect("valid cloud final fixture")
|
|
}
|
|
other => panic!("unknown fixture '{other}'"),
|
|
}
|
|
}
|
|
|
|
async fn create_session(
|
|
provider: Arc<dyn Provider>,
|
|
config: Config,
|
|
) -> (SessionController, TempDir) {
|
|
let temp_dir = tempdir().expect("temp dir");
|
|
let storage_path = temp_dir.path().join("owlen-tests.db");
|
|
let storage = Arc::new(
|
|
StorageManager::with_database_path(storage_path)
|
|
.await
|
|
.expect("storage manager"),
|
|
);
|
|
let ui = Arc::new(NoOpUiController);
|
|
|
|
let session = SessionController::new(provider, config, storage, ui, false, None)
|
|
.await
|
|
.expect("session controller");
|
|
(session, temp_dir)
|
|
}
|
|
|
|
fn configure_local(base_url: &str) -> Config {
|
|
let mut config = Config::default();
|
|
config.general.default_provider = "ollama_local".into();
|
|
config.general.default_model = Some("local-mini".into());
|
|
config.general.enable_streaming = false;
|
|
config.privacy.encrypt_local_data = false;
|
|
config.privacy.require_consent_per_session = false;
|
|
|
|
if let Some(local) = config.providers.get_mut("ollama_local") {
|
|
local.enabled = true;
|
|
local.base_url = Some(base_url.to_string());
|
|
}
|
|
|
|
config
|
|
}
|
|
|
|
fn configure_cloud(base_url: &str) -> Config {
|
|
let mut config = Config::default();
|
|
config.general.default_provider = "ollama_cloud".into();
|
|
config.general.default_model = Some("llama3:8b-cloud".into());
|
|
config.general.enable_streaming = false;
|
|
config.privacy.enable_remote_search = true;
|
|
config.privacy.encrypt_local_data = false;
|
|
config.privacy.require_consent_per_session = false;
|
|
config.tools.web_search.enabled = true;
|
|
unsafe {
|
|
std::env::set_var("OWLEN_ALLOW_INSECURE_CLOUD", "1");
|
|
}
|
|
|
|
if let Some(cloud) = config.providers.get_mut("ollama_cloud") {
|
|
cloud.enabled = true;
|
|
cloud.base_url = Some(base_url.to_string());
|
|
cloud.api_key = Some("test-key".into());
|
|
cloud.extra.insert(
|
|
"web_search_endpoint".into(),
|
|
Value::String("/v1/web/search".into()),
|
|
);
|
|
cloud.extra.insert(
|
|
owlen_core::config::OLLAMA_CLOUD_ENDPOINT_KEY.into(),
|
|
Value::String(base_url.to_string()),
|
|
);
|
|
}
|
|
|
|
config
|
|
}
|
|
|
|
#[tokio::test(flavor = "multi_thread")]
|
|
async fn local_provider_happy_path_records_usage() {
|
|
let server = MockServer::start().await;
|
|
|
|
let tags = load_fixture("ollama_tags");
|
|
let completion = load_fixture("ollama_local_completion");
|
|
|
|
let base_url = server.uri();
|
|
|
|
let tags_template = ResponseTemplate::new(200).set_body_json(tags);
|
|
Mock::given(method("GET"))
|
|
.and(path("/api/tags"))
|
|
.respond_with(tags_template.clone())
|
|
.mount(&server)
|
|
.await;
|
|
|
|
Mock::given(method("POST"))
|
|
.and(path("/api/chat"))
|
|
.respond_with(ResponseTemplate::new(200).set_body_json(completion))
|
|
.mount(&server)
|
|
.await;
|
|
|
|
let config = configure_local(&base_url);
|
|
let provider: Arc<dyn Provider> =
|
|
Arc::new(OllamaProvider::new(base_url.clone()).expect("local provider"));
|
|
|
|
let (mut session, _tmp) = create_session(provider, config).await;
|
|
|
|
let outcome = session
|
|
.send_message(
|
|
"Summarise the local status.".to_string(),
|
|
ChatParameters::default(),
|
|
)
|
|
.await
|
|
.expect("local completion");
|
|
|
|
let response = match outcome {
|
|
SessionOutcome::Complete(response) => response,
|
|
_ => panic!("expected complete outcome"),
|
|
};
|
|
|
|
assert_eq!(response.message.content, "Local response complete.");
|
|
|
|
let snapshot = session
|
|
.current_usage_snapshot()
|
|
.await
|
|
.expect("usage snapshot");
|
|
assert_eq!(snapshot.provider, "ollama_local");
|
|
assert_eq!(snapshot.hourly.total_tokens, 36);
|
|
assert_eq!(snapshot.weekly.total_tokens, 36);
|
|
}
|
|
|
|
#[tokio::test(flavor = "multi_thread")]
|
|
async fn cloud_tool_call_flows_through_web_search() {
|
|
let server = MockServer::start().await;
|
|
|
|
let tags = load_fixture("ollama_tags");
|
|
let tool_call = load_fixture("ollama_cloud_tool_call");
|
|
let final_chunk = load_fixture("ollama_cloud_final");
|
|
|
|
let base_url = server.uri();
|
|
|
|
let tags_template = ResponseTemplate::new(200).set_body_json(tags);
|
|
Mock::given(method("GET"))
|
|
.and(path("/api/tags"))
|
|
.respond_with(tags_template.clone())
|
|
.mount(&server)
|
|
.await;
|
|
|
|
Mock::given(method("POST"))
|
|
.and(path("/api/chat"))
|
|
.and(BodySubstringMatcher::not_contains("\"role\":\"tool\""))
|
|
.respond_with(ResponseTemplate::new(200).set_body_json(tool_call))
|
|
.mount(&server)
|
|
.await;
|
|
|
|
Mock::given(method("POST"))
|
|
.and(path("/api/chat"))
|
|
.and(BodySubstringMatcher::contains("\"role\":\"tool\""))
|
|
.respond_with(ResponseTemplate::new(200).set_body_json(final_chunk))
|
|
.mount(&server)
|
|
.await;
|
|
|
|
Mock::given(method("POST"))
|
|
.and(path("/v1/web/search"))
|
|
.and(header("authorization", "Bearer test-key"))
|
|
.respond_with(ResponseTemplate::new(200).set_body_json(json!({
|
|
"results": [
|
|
{
|
|
"title": "Rust 1.85 Released",
|
|
"url": "https://blog.rust-lang.org/2025/10/23/Rust-1.85.html",
|
|
"snippet": "Rust 1.85 lands with incremental compilation improvements."
|
|
}
|
|
]
|
|
})))
|
|
.expect(1)
|
|
.mount(&server)
|
|
.await;
|
|
|
|
let config = configure_cloud(&base_url);
|
|
let cloud_cfg = config
|
|
.providers
|
|
.get("ollama_cloud")
|
|
.expect("cloud provider config")
|
|
.clone();
|
|
assert_eq!(cloud_cfg.api_key.as_deref(), Some("test-key"));
|
|
assert_eq!(cloud_cfg.base_url.as_deref(), Some(base_url.as_str()));
|
|
let provider: Arc<dyn Provider> = Arc::new(
|
|
OllamaProvider::from_config("ollama_cloud", &cloud_cfg, Some(&config.general))
|
|
.expect("cloud provider"),
|
|
);
|
|
|
|
let (mut session, _tmp) = create_session(provider, config).await;
|
|
session.grant_consent(
|
|
WEB_SEARCH_TOOL_NAME,
|
|
vec!["network".into()],
|
|
vec![format!("{}/v1/web/search", base_url)],
|
|
);
|
|
|
|
let outcome = session
|
|
.send_message(
|
|
"What is new in Rust today?".to_string(),
|
|
ChatParameters::default(),
|
|
)
|
|
.await;
|
|
|
|
let response = match outcome {
|
|
Ok(SessionOutcome::Complete(response)) => response,
|
|
Ok(_) => panic!("expected complete outcome"),
|
|
Err(err) => panic!("cloud completion: {err:?}"),
|
|
};
|
|
assert_eq!(
|
|
response.message.content,
|
|
"Rust 1.85 shipped today. Summarising the highlights now."
|
|
);
|
|
|
|
let convo = session.conversation();
|
|
let tool_messages: Vec<_> = convo
|
|
.messages
|
|
.iter()
|
|
.filter(|msg| msg.role == Role::Tool)
|
|
.collect();
|
|
assert_eq!(tool_messages.len(), 1);
|
|
assert!(
|
|
tool_messages[0].content.contains("Rust 1.85 Released"),
|
|
"tool response should include search result"
|
|
);
|
|
|
|
let snapshot = session
|
|
.current_usage_snapshot()
|
|
.await
|
|
.expect("usage snapshot");
|
|
assert_eq!(snapshot.provider, "ollama_cloud");
|
|
assert_eq!(snapshot.hourly.total_tokens, 112); // 64 prompt + 48 completion
|
|
assert_eq!(snapshot.weekly.total_tokens, 112);
|
|
}
|
|
|
|
#[tokio::test(flavor = "multi_thread")]
|
|
async fn cloud_unauthorized_degrades_without_usage() {
|
|
let server = MockServer::start().await;
|
|
|
|
let tags = load_fixture("ollama_tags");
|
|
let tags_template = ResponseTemplate::new(200).set_body_json(tags);
|
|
let base_url = server.uri();
|
|
Mock::given(method("GET"))
|
|
.and(path("/api/tags"))
|
|
.respond_with(tags_template.clone())
|
|
.mount(&server)
|
|
.await;
|
|
|
|
Mock::given(method("POST"))
|
|
.and(path("/api/chat"))
|
|
.respond_with(ResponseTemplate::new(401).set_body_json(json!({
|
|
"error": "unauthorized"
|
|
})))
|
|
.mount(&server)
|
|
.await;
|
|
|
|
let config = configure_cloud(&base_url);
|
|
let cloud_cfg = config
|
|
.providers
|
|
.get("ollama_cloud")
|
|
.expect("cloud provider config")
|
|
.clone();
|
|
let provider: Arc<dyn Provider> = Arc::new(
|
|
OllamaProvider::from_config("ollama_cloud", &cloud_cfg, Some(&config.general))
|
|
.expect("cloud provider"),
|
|
);
|
|
|
|
let (mut session, _tmp) = create_session(provider, config).await;
|
|
|
|
let err_text = match session
|
|
.send_message("Switch to cloud".to_string(), ChatParameters::default())
|
|
.await
|
|
{
|
|
Ok(_) => panic!("expected unauthorized error, but request succeeded"),
|
|
Err(err) => err.to_string(),
|
|
};
|
|
assert!(
|
|
err_text.contains("unauthorized") || err_text.contains("API key"),
|
|
"error should surface unauthorized detail, got: {err_text}"
|
|
);
|
|
|
|
let snapshot = session
|
|
.current_usage_snapshot()
|
|
.await
|
|
.expect("usage snapshot");
|
|
assert_eq!(snapshot.hourly.total_tokens, 0);
|
|
assert_eq!(snapshot.weekly.total_tokens, 0);
|
|
}
|
|
|
|
#[tokio::test(flavor = "multi_thread")]
|
|
async fn cloud_rate_limit_returns_error_without_usage() {
|
|
let server = MockServer::start().await;
|
|
|
|
let tags = load_fixture("ollama_tags");
|
|
let tags_template = ResponseTemplate::new(200).set_body_json(tags);
|
|
let base_url = server.uri();
|
|
Mock::given(method("GET"))
|
|
.and(path("/api/tags"))
|
|
.respond_with(tags_template.clone())
|
|
.mount(&server)
|
|
.await;
|
|
|
|
Mock::given(method("POST"))
|
|
.and(path("/api/chat"))
|
|
.respond_with(ResponseTemplate::new(429).set_body_json(json!({
|
|
"error": "too many requests"
|
|
})))
|
|
.mount(&server)
|
|
.await;
|
|
|
|
let config = configure_cloud(&base_url);
|
|
let cloud_cfg = config
|
|
.providers
|
|
.get("ollama_cloud")
|
|
.expect("cloud provider config")
|
|
.clone();
|
|
let provider: Arc<dyn Provider> = Arc::new(
|
|
OllamaProvider::from_config("ollama_cloud", &cloud_cfg, Some(&config.general))
|
|
.expect("cloud provider"),
|
|
);
|
|
|
|
let (mut session, _tmp) = create_session(provider, config).await;
|
|
|
|
let err_text = match session
|
|
.send_message("Hit rate limit".to_string(), ChatParameters::default())
|
|
.await
|
|
{
|
|
Ok(_) => panic!("expected rate-limit error, but request succeeded"),
|
|
Err(err) => err.to_string(),
|
|
};
|
|
assert!(
|
|
err_text.contains("rate limited") || err_text.contains("429"),
|
|
"error should mention rate limiting, got: {err_text}"
|
|
);
|
|
|
|
let snapshot = session
|
|
.current_usage_snapshot()
|
|
.await
|
|
.expect("usage snapshot");
|
|
assert_eq!(snapshot.hourly.total_tokens, 0);
|
|
assert_eq!(snapshot.weekly.total_tokens, 0);
|
|
}
|