diff --git a/crates/mcp/llm-server/src/main.rs b/crates/mcp/llm-server/src/main.rs index 7672a8e..fd0db69 100644 --- a/crates/mcp/llm-server/src/main.rs +++ b/crates/mcp/llm-server/src/main.rs @@ -127,12 +127,13 @@ fn provider_from_config() -> Result, RpcError> { match provider_cfg.provider_type.as_str() { "ollama" | "ollama_cloud" => { - let provider = OllamaProvider::from_config(&provider_cfg, Some(&config.general)) - .map_err(|e| { - RpcError::internal_error(format!( - "Failed to init Ollama provider from config: {e}" - )) - })?; + let provider = + OllamaProvider::from_config(&provider_key, &provider_cfg, Some(&config.general)) + .map_err(|e| { + RpcError::internal_error(format!( + "Failed to init Ollama provider from config: {e}" + )) + })?; Ok(Arc::new(provider) as Arc) } other => Err(RpcError::internal_error(format!( diff --git a/crates/owlen-cli/src/bootstrap.rs b/crates/owlen-cli/src/bootstrap.rs index 7282bcf..feb6f15 100644 --- a/crates/owlen-cli/src/bootstrap.rs +++ b/crates/owlen-cli/src/bootstrap.rs @@ -185,7 +185,8 @@ fn build_local_provider(cfg: &Config) -> Result> { match provider_cfg.provider_type.as_str() { "ollama" | "ollama_cloud" => { - let provider = OllamaProvider::from_config(provider_cfg, Some(&cfg.general))?; + let provider = + OllamaProvider::from_config(&provider_name, provider_cfg, Some(&cfg.general))?; Ok(Arc::new(provider) as Arc) } other => Err(anyhow!(format!( diff --git a/crates/owlen-cli/src/commands/cloud.rs b/crates/owlen-cli/src/commands/cloud.rs index 7b32605..3d4dd8a 100644 --- a/crates/owlen-cli/src/commands/cloud.rs +++ b/crates/owlen-cli/src/commands/cloud.rs @@ -161,7 +161,7 @@ async fn status(provider: String) -> Result<()> { Value::String("cloud".to_string()), ); - let ollama = OllamaProvider::from_config(&runtime_cfg, Some(&config.general)) + let ollama = OllamaProvider::from_config(&provider, &runtime_cfg, Some(&config.general)) .with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?; match ollama.health_check().await { @@ -212,7 +212,7 @@ async fn models(provider: String) -> Result<()> { Value::String("cloud".to_string()), ); - let ollama = OllamaProvider::from_config(&runtime_cfg, Some(&config.general)) + let ollama = OllamaProvider::from_config(&provider, &runtime_cfg, Some(&config.general)) .with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?; match ollama.list_models().await { diff --git a/crates/owlen-core/src/providers/ollama.rs b/crates/owlen-core/src/providers/ollama.rs index 4733992..e4cef1e 100644 --- a/crates/owlen-core/src/providers/ollama.rs +++ b/crates/owlen-core/src/providers/ollama.rs @@ -4,6 +4,7 @@ use std::{ env, net::{SocketAddr, TcpStream}, pin::Pin, + process::Command, sync::Arc, time::{Duration, Instant, SystemTime}, }; @@ -23,11 +24,12 @@ use ollama_rs::{ models::{LocalModel, ModelInfo as OllamaModelInfo, ModelOptions}, }; use reqwest::{Client, StatusCode, Url}; +use serde::Deserialize; use serde_json::{Map as JsonMap, Value, json}; -use tokio::{sync::RwLock, time::timeout}; +use tokio::{sync::RwLock, time::sleep}; #[cfg(test)] -use std::sync::{Mutex, OnceLock}; +use std::sync::{Mutex, MutexGuard, OnceLock}; #[cfg(test)] use tokio_test::block_on; use uuid::Uuid; @@ -48,6 +50,9 @@ const DEFAULT_MODEL_CACHE_TTL_SECS: u64 = 60; pub(crate) const CLOUD_BASE_URL: &str = OLLAMA_CLOUD_BASE_URL; const LOCAL_PROBE_TIMEOUT_MS: u64 = 200; const LOCAL_PROBE_TARGETS: &[&str] = &["127.0.0.1:11434", "[::1]:11434"]; +const LOCAL_TAGS_TIMEOUT_STEPS_MS: [u64; 3] = [400, 800, 1_600]; +const LOCAL_TAGS_RETRY_DELAYS_MS: [u64; 2] = [150, 300]; +const HEALTHCHECK_TIMEOUT_MS: u64 = 1_000; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] enum OllamaMode { @@ -122,8 +127,53 @@ impl ScopeSnapshot { } } +#[derive(Clone)] +struct ScopeHandle { + client: Ollama, + http_client: Client, + base_url: String, +} + +impl ScopeHandle { + fn new(client: Ollama, http_client: Client, base_url: impl Into) -> Self { + Self { + client, + http_client, + base_url: base_url.into(), + } + } + + fn api_url(&self, endpoint: &str) -> String { + build_api_endpoint(&self.base_url, endpoint) + } +} + +#[derive(Debug, Deserialize)] +struct TagsResponse { + #[serde(default)] + models: Vec, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ProviderVariant { + Local, + Cloud, +} + +impl ProviderVariant { + fn supports_local(self) -> bool { + matches!(self, ProviderVariant::Local) + } + + fn supports_cloud(self) -> bool { + matches!(self, ProviderVariant::Cloud) + } +} + #[derive(Debug)] struct OllamaOptions { + provider_name: String, + variant: ProviderVariant, mode: OllamaMode, base_url: String, request_timeout: Duration, @@ -133,8 +183,15 @@ struct OllamaOptions { } impl OllamaOptions { - fn new(mode: OllamaMode, base_url: impl Into) -> Self { + fn new( + provider_name: impl Into, + variant: ProviderVariant, + mode: OllamaMode, + base_url: impl Into, + ) -> Self { Self { + provider_name: provider_name.into(), + variant, mode, base_url: base_url.into(), request_timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS), @@ -153,6 +210,8 @@ impl OllamaOptions { /// Ollama provider implementation backed by `ollama-rs`. #[derive(Debug)] pub struct OllamaProvider { + provider_name: String, + variant: ProviderVariant, mode: OllamaMode, client: Ollama, http_client: Client, @@ -198,6 +257,16 @@ fn is_explicit_cloud_base(base_url: Option<&str>) -> bool { #[cfg(test)] static PROBE_OVERRIDE: OnceLock>> = OnceLock::new(); +#[cfg(test)] +static TAGS_OVERRIDE: OnceLock, Error>>>> = + OnceLock::new(); + +#[cfg(test)] +static TAGS_OVERRIDE_GATE: OnceLock> = OnceLock::new(); + +#[cfg(test)] +static PROBE_OVERRIDE_GATE: OnceLock> = OnceLock::new(); + #[cfg(test)] fn set_probe_override(value: Option) { let guard = PROBE_OVERRIDE.get_or_init(|| Mutex::new(None)); @@ -213,6 +282,51 @@ fn probe_override_value() -> Option { .to_owned() } +#[cfg(test)] +fn set_tags_override( + sequence: Vec, Error>>, +) -> TagsOverrideGuard { + let gate = TAGS_OVERRIDE_GATE + .get_or_init(|| Mutex::new(())) + .lock() + .expect("tags override gate mutex poisoned"); + + let store = TAGS_OVERRIDE.get_or_init(|| Mutex::new(Vec::new())); + { + let mut guard = store.lock().expect("tags override mutex poisoned"); + guard.clear(); + for item in sequence.into_iter().rev() { + guard.push(item); + } + } + TagsOverrideGuard { gate: Some(gate) } +} + +#[cfg(test)] +fn pop_tags_override() -> Option, Error>> { + TAGS_OVERRIDE + .get_or_init(|| Mutex::new(Vec::new())) + .lock() + .expect("tags override mutex poisoned") + .pop() +} + +#[cfg(test)] +struct TagsOverrideGuard { + gate: Option>, +} + +#[cfg(test)] +impl Drop for TagsOverrideGuard { + fn drop(&mut self) { + if let Some(store) = TAGS_OVERRIDE.get() { + let mut guard = store.lock().expect("tags override mutex poisoned"); + guard.clear(); + } + self.gate.take(); + } +} + fn probe_default_local_daemon(timeout: Duration) -> bool { #[cfg(test)] { @@ -237,14 +351,46 @@ impl OllamaProvider { let input = base_url.into(); let normalized = normalize_base_url(Some(&input), OllamaMode::Local).map_err(Error::Config)?; - Self::with_options(OllamaOptions::new(OllamaMode::Local, normalized)) + Self::with_options(OllamaOptions::new( + "ollama_local", + ProviderVariant::Local, + OllamaMode::Local, + normalized, + )) } /// Construct a provider from configuration settings. - pub fn from_config(config: &ProviderConfig, general: Option<&GeneralSettings>) -> Result { + pub fn from_config( + provider_id: &str, + config: &ProviderConfig, + general: Option<&GeneralSettings>, + ) -> Result { + let provider_type = config.provider_type.trim().to_ascii_lowercase(); + let register_name = { + let candidate = provider_id.trim(); + if candidate.is_empty() { + if provider_type.is_empty() { + "ollama".to_string() + } else { + provider_type.clone() + } + } else { + candidate.replace('-', "_") + } + }; + + let variant = if register_name == "ollama_cloud" || provider_type == "ollama_cloud" { + ProviderVariant::Cloud + } else { + ProviderVariant::Local + }; + let mut api_key = resolve_api_key(config.api_key.clone()) + .or_else(|| resolve_api_key_env_hint(config.api_key_env.as_deref())) .or_else(|| env_var_non_empty("OLLAMA_API_KEY")) .or_else(|| env_var_non_empty("OLLAMA_CLOUD_API_KEY")); + let api_key_present = api_key.is_some(); + let configured_mode = configured_mode_from_extra(config); let configured_mode_label = config .extra @@ -254,7 +400,7 @@ impl OllamaProvider { let base_url = config.base_url.as_deref(); let base_is_local = is_explicit_local_base(base_url); let base_is_cloud = is_explicit_cloud_base(base_url); - let base_is_other = base_url.is_some() && !base_is_local && !base_is_cloud; + let _base_is_other = base_url.is_some() && !base_is_local && !base_is_cloud; let mut local_probe_result = None; let cloud_endpoint = config @@ -265,28 +411,25 @@ impl OllamaProvider { .transpose() .map_err(Error::Config)?; - let mode = match configured_mode { - Some(mode) => mode, - None => { - if base_is_local || base_is_other { - OllamaMode::Local - } else if base_is_cloud && api_key.is_some() { - OllamaMode::Cloud - } else { - let probe = - probe_default_local_daemon(Duration::from_millis(LOCAL_PROBE_TIMEOUT_MS)); - local_probe_result = Some(probe); - if probe { - OllamaMode::Local - } else if api_key.is_some() { - OllamaMode::Cloud - } else { - OllamaMode::Local - } - } - } + if matches!(variant, ProviderVariant::Local) && configured_mode.is_none() { + let probe = probe_default_local_daemon(Duration::from_millis(LOCAL_PROBE_TIMEOUT_MS)); + local_probe_result = Some(probe); + } + + let mode = match variant { + ProviderVariant::Local => OllamaMode::Local, + ProviderVariant::Cloud => OllamaMode::Cloud, }; + if matches!(variant, ProviderVariant::Cloud) { + if !api_key_present { + return Err(Error::Config( + "Ollama Cloud API key not configured. Set providers.ollama_cloud.api_key or OLLAMA_CLOUD_API_KEY." + .into(), + )); + } + } + let base_candidate = match mode { OllamaMode::Local => base_url, OllamaMode::Cloud => { @@ -301,7 +444,12 @@ impl OllamaProvider { let normalized_base_url = normalize_base_url(base_candidate, mode).map_err(Error::Config)?; - let mut options = OllamaOptions::new(mode, normalized_base_url.clone()); + let mut options = OllamaOptions::new( + register_name.clone(), + variant, + mode, + normalized_base_url.clone(), + ); options.cloud_endpoint = cloud_endpoint.clone(); if let Some(timeout) = config @@ -327,7 +475,8 @@ impl OllamaProvider { } debug!( - "Resolved Ollama provider: mode={:?}, base_url={}, configured_mode={}, api_key_present={}, local_probe={}", + "Resolved Ollama provider '{}': mode={:?}, base_url={}, configured_mode={}, api_key_present={}, local_probe={}", + register_name, mode, normalized_base_url, configured_mode_label, @@ -348,6 +497,8 @@ impl OllamaProvider { fn with_options(options: OllamaOptions) -> Result { let OllamaOptions { + provider_name, + variant, mode, base_url, request_timeout, @@ -368,6 +519,8 @@ impl OllamaProvider { }; Ok(Self { + provider_name: provider_name.trim().to_ascii_lowercase(), + variant, mode, client: ollama_client, http_client, @@ -397,19 +550,47 @@ impl OllamaProvider { } } - fn build_local_client(&self) -> Result> { - if matches!(self.mode, OllamaMode::Local) { - return Ok(Some(self.client.clone())); - } - - let (client, _) = - build_client_for_base(Self::local_base_url(), self.request_timeout, None)?; - Ok(Some(client)) + fn supports_local_scope(&self) -> bool { + self.variant.supports_local() } - fn build_cloud_client(&self) -> Result> { + fn supports_cloud_scope(&self) -> bool { + self.variant.supports_cloud() + } + + fn build_local_client(&self) -> Result> { + if !self.supports_local_scope() { + return Ok(None); + } + + if matches!(self.mode, OllamaMode::Local) { + return Ok(Some(ScopeHandle::new( + self.client.clone(), + self.http_client.clone(), + self.base_url.clone(), + ))); + } + + let (client, http_client) = + build_client_for_base(Self::local_base_url(), self.request_timeout, None)?; + Ok(Some(ScopeHandle::new( + client, + http_client, + Self::local_base_url(), + ))) + } + + fn build_cloud_client(&self) -> Result> { + if !self.supports_cloud_scope() { + return Ok(None); + } + if matches!(self.mode, OllamaMode::Cloud) { - return Ok(Some(self.client.clone())); + return Ok(Some(ScopeHandle::new( + self.client.clone(), + self.http_client.clone(), + self.base_url.clone(), + ))); } let api_key = match self.api_key.as_deref() { @@ -419,8 +600,9 @@ impl OllamaProvider { let endpoint = self.cloud_endpoint.as_deref().unwrap_or(CLOUD_BASE_URL); - let (client, _) = build_client_for_base(endpoint, self.request_timeout, Some(api_key))?; - Ok(Some(client)) + let (client, http_client) = + build_client_for_base(endpoint, self.request_timeout, Some(api_key))?; + Ok(Some(ScopeHandle::new(client, http_client, endpoint))) } async fn cached_scope_models(&self, scope: OllamaMode) -> Option> { @@ -663,9 +845,9 @@ impl OllamaProvider { let mut seen: HashSet = HashSet::new(); let mut errors: Vec = Vec::new(); - if let Some(local_client) = self.build_local_client()? { + if let Some(local_handle) = self.build_local_client()? { match self - .fetch_models_for_scope(OllamaMode::Local, local_client.clone()) + .fetch_models_for_scope(OllamaMode::Local, local_handle) .await { Ok(models) => { @@ -680,9 +862,9 @@ impl OllamaProvider { } } - if let Some(cloud_client) = self.build_cloud_client()? { + if let Some(cloud_handle) = self.build_cloud_client()? { match self - .fetch_models_for_scope(OllamaMode::Cloud, cloud_client.clone()) + .fetch_models_for_scope(OllamaMode::Cloud, cloud_handle) .await { Ok(models) => { @@ -711,40 +893,31 @@ impl OllamaProvider { async fn fetch_models_for_scope( &self, scope: OllamaMode, - client: Ollama, + handle: ScopeHandle, ) -> Result> { - let list_result = if matches!(scope, OllamaMode::Local) { - match timeout( - Duration::from_millis(LOCAL_PROBE_TIMEOUT_MS), - client.list_local_models(), - ) - .await - { - Ok(result) => result.map_err(|err| self.map_ollama_error("list models", err, None)), - Err(_) => Err(Error::Timeout( - "Timed out while contacting the local Ollama daemon".to_string(), - )), - } - } else { - client - .list_local_models() - .await - .map_err(|err| self.map_ollama_error("list models", err, None)) - }; + let tags_result = self.fetch_scope_tags_with_retry(scope, &handle).await; - let models = match list_result { + let models = match tags_result { Ok(models) => models, Err(err) => { - let message = err.to_string(); - self.mark_scope_failure(scope, message).await; + let original_detail = err.to_string(); + let (error, banner) = self.decorate_scope_error(scope, &handle.base_url, err); + if banner != original_detail { + debug!( + "Model list for {:?} at {} failed: {}", + scope, handle.base_url, original_detail + ); + } + self.mark_scope_failure(scope, banner.clone()).await; if let Some(cached) = self.cached_scope_models(scope).await { return Ok(cached); } - return Err(err); + return Err(error); } }; let cache = self.model_details_cache.clone(); + let client = handle.client.clone(); let fetched = join_all(models.into_iter().map(|local| { let client = client.clone(); let cache = cache.clone(); @@ -780,6 +953,186 @@ impl OllamaProvider { Ok(converted) } + async fn fetch_scope_tags_with_retry( + &self, + scope: OllamaMode, + handle: &ScopeHandle, + ) -> Result> { + let attempts = if matches!(scope, OllamaMode::Local) { + LOCAL_TAGS_TIMEOUT_STEPS_MS.len() + } else { + 1 + }; + + let mut last_error: Option = None; + + for attempt in 0..attempts { + match self.fetch_scope_tags_once(scope, handle, attempt).await { + Ok(models) => return Ok(models), + Err(err) => { + let should_retry = matches!(scope, OllamaMode::Local) + && attempt + 1 < attempts + && matches!(err, Error::Timeout(_) | Error::Network(_)); + + if should_retry { + debug!( + "Retrying Ollama model list for {:?} (attempt {}): {}", + scope, + attempt + 1, + err + ); + last_error = Some(err); + sleep(self.tags_retry_delay(attempt)).await; + continue; + } + return Err(err); + } + } + } + + Err(last_error + .unwrap_or_else(|| Error::Unknown("Ollama model list retries exhausted".to_string()))) + } + + async fn fetch_scope_tags_once( + &self, + scope: OllamaMode, + handle: &ScopeHandle, + attempt: usize, + ) -> Result> { + #[cfg(test)] + if let Some(result) = pop_tags_override() { + return result; + } + + if matches!(scope, OllamaMode::Local) { + match self.list_local_models_via_cli() { + Ok(models) => return Ok(models), + Err(err) => { + debug!("`ollama ls` failed ({}); falling back to HTTP listing", err); + } + } + } + + let url = handle.api_url("tags"); + let response = handle + .http_client + .get(&url) + .timeout(self.tags_request_timeout(scope, attempt)) + .send() + .await + .map_err(|err| map_reqwest_error("list models", err))?; + + if !response.status().is_success() { + let status = response.status(); + let detail = response.text().await.unwrap_or_else(|err| err.to_string()); + return Err(self.map_http_failure("list models", status, detail, None)); + } + + let bytes = response + .bytes() + .await + .map_err(|err| map_reqwest_error("list models", err))?; + let parsed: TagsResponse = serde_json::from_slice(&bytes)?; + Ok(parsed.models) + } + + fn tags_request_timeout(&self, scope: OllamaMode, attempt: usize) -> Duration { + if matches!(scope, OllamaMode::Local) { + let idx = attempt.min(LOCAL_TAGS_TIMEOUT_STEPS_MS.len() - 1); + Duration::from_millis(LOCAL_TAGS_TIMEOUT_STEPS_MS[idx]) + } else { + self.request_timeout + } + } + + fn tags_retry_delay(&self, attempt: usize) -> Duration { + let idx = attempt.min(LOCAL_TAGS_RETRY_DELAYS_MS.len() - 1); + Duration::from_millis(LOCAL_TAGS_RETRY_DELAYS_MS[idx]) + } + + fn list_local_models_via_cli(&self) -> Result> { + let output = Command::new("ollama") + .arg("ls") + .output() + .map_err(|err| { + Error::Provider(anyhow!( + "Failed to execute `ollama ls`: {err}. Ensure the Ollama CLI is installed and accessible in PATH." + )) + })?; + + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(Error::Provider(anyhow!( + "`ollama ls` exited with status {}: {}", + output.status, + stderr.trim() + ))); + } + + let stdout = String::from_utf8(output.stdout).map_err(|err| { + Error::Provider(anyhow!("`ollama ls` returned non-UTF8 output: {err}")) + })?; + + let mut models = Vec::new(); + for line in stdout.lines() { + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + let lowercase = trimmed.to_ascii_lowercase(); + if lowercase.starts_with("name") { + continue; + } + + let mut parts = trimmed.split_whitespace(); + let Some(name) = parts.next() else { + continue; + }; + let metadata_start = trimmed[name.len()..].trim(); + + models.push(LocalModel { + name: name.to_string(), + modified_at: metadata_start.to_string(), + size: 0, + }); + } + + Ok(models) + } + + fn decorate_scope_error( + &self, + scope: OllamaMode, + base_url: &str, + err: Error, + ) -> (Error, String) { + if matches!(scope, OllamaMode::Local) { + match err { + Error::Timeout(_) => { + let message = format_local_unreachable_message(base_url); + (Error::Timeout(message.clone()), message) + } + Error::Network(original) => { + if is_connectivity_error(&original) { + let message = format_local_unreachable_message(base_url); + (Error::Network(message.clone()), message) + } else { + let message = original.clone(); + (Error::Network(original), message) + } + } + other => { + let message = other.to_string(); + (other, message) + } + } + } else { + let message = err.to_string(); + (err, message) + } + } + fn convert_detailed_model_info( mode: OllamaMode, model_name: &str, @@ -893,7 +1246,7 @@ impl OllamaProvider { id: name.clone(), name, description: Some(description), - provider: "ollama".to_string(), + provider: self.provider_name.clone(), context_window: None, capabilities, supports_tools: false, @@ -948,7 +1301,7 @@ impl OllamaProvider { StatusCode::NOT_FOUND => { if let Some(model) = model { Error::InvalidInput(format!( - "Model '{model}' was not found at {}. Verify the name or pull it with `ollama pull`.", + "Model '{model}' was not found at {}. Verify the name or pull it with `ollama pull {model}`.", self.base_url )) } else { @@ -992,7 +1345,7 @@ impl LlmProvider for OllamaProvider { Self: 'a; fn name(&self) -> &str { - "ollama" + &self.provider_name } fn list_models(&self) -> Self::ListModelsFuture<'_> { @@ -1056,10 +1409,11 @@ impl LlmProvider for OllamaProvider { fn health_check(&self) -> Self::HealthCheckFuture<'_> { Box::pin(async move { - let url = self.api_url("version"); + let url = self.api_url("tags?limit=1"); let response = self .http_client .get(&url) + .timeout(Duration::from_millis(HEALTHCHECK_TIMEOUT_MS)) .send() .await .map_err(|err| map_reqwest_error("health check", err))?; @@ -1364,6 +1718,46 @@ fn value_to_u64(value: &Value) -> Option { } } +fn format_local_unreachable_message(base_url: &str) -> String { + let display = display_host_port(base_url); + format!( + "Ollama not reachable on {display}. Start the Ollama daemon (`ollama serve`) and try again." + ) +} + +fn display_host_port(base_url: &str) -> String { + Url::parse(base_url) + .ok() + .and_then(|url| { + url.host_str().map(|host| { + if let Some(port) = url.port() { + format!("{host}:{port}") + } else { + host.to_string() + } + }) + }) + .unwrap_or_else(|| base_url.to_string()) +} + +fn is_connectivity_error(message: &str) -> bool { + let lower = message.to_ascii_lowercase(); + const CONNECTIVITY_MARKERS: &[&str] = &[ + "connection refused", + "failed to connect", + "connect timeout", + "timed out while contacting", + "dns error", + "failed to lookup address", + "no route to host", + "host unreachable", + ]; + + CONNECTIVITY_MARKERS + .iter() + .any(|marker| lower.contains(marker)) +} + fn env_var_non_empty(name: &str) -> Option { env::var(name) .ok() @@ -1371,6 +1765,13 @@ fn env_var_non_empty(name: &str) -> Option { .filter(|value| !value.is_empty()) } +fn resolve_api_key_env_hint(env_var: Option<&str>) -> Option { + env_var + .map(str::trim) + .filter(|value| !value.is_empty()) + .and_then(env_var_non_empty) +} + fn resolve_api_key(configured: Option) -> Option { let raw = configured?.trim().to_string(); if raw.is_empty() { @@ -1545,7 +1946,8 @@ mod tests { Value::String("local".to_string()), ); - let provider = OllamaProvider::from_config(&config, None).expect("provider constructed"); + let provider = OllamaProvider::from_config("ollama_local", &config, None) + .expect("provider constructed"); assert_eq!(provider.mode, OllamaMode::Local); assert_eq!(provider.base_url, "http://localhost:11434"); @@ -1563,7 +1965,8 @@ mod tests { }; // simulate missing explicit mode; defaults to auto - let provider = OllamaProvider::from_config(&config, None).expect("provider constructed"); + let provider = OllamaProvider::from_config("ollama_local", &config, None) + .expect("provider constructed"); assert_eq!(provider.mode, OllamaMode::Local); assert_eq!(provider.base_url, "http://localhost:11434"); @@ -1584,12 +1987,191 @@ mod tests { Value::String("auto".to_string()), ); - let provider = OllamaProvider::from_config(&config, None).expect("provider constructed"); + let provider = OllamaProvider::from_config("ollama_cloud", &config, None) + .expect("provider constructed"); assert_eq!(provider.mode, OllamaMode::Cloud); assert_eq!(provider.base_url, CLOUD_BASE_URL); } + #[test] + fn cloud_provider_requires_api_key() { + let config = ProviderConfig { + enabled: true, + provider_type: "ollama_cloud".to_string(), + base_url: None, + api_key: None, + api_key_env: None, + extra: HashMap::new(), + }; + + let err = OllamaProvider::from_config("ollama_cloud", &config, None) + .expect_err("expected config error"); + match err { + Error::Config(message) => { + assert!(message.contains("API key")); + } + other => panic!("unexpected error variant: {other:?}"), + } + } + + #[test] + fn cloud_provider_uses_explicit_api_key() { + let config = ProviderConfig { + enabled: true, + provider_type: "ollama_cloud".to_string(), + base_url: None, + api_key: Some("secret-cloud-key".to_string()), + api_key_env: None, + extra: HashMap::new(), + }; + + let provider = OllamaProvider::from_config("ollama_cloud", &config, None) + .expect("provider constructed"); + assert_eq!(provider.name(), "ollama_cloud"); + assert_eq!(provider.mode, OllamaMode::Cloud); + assert_eq!(provider.base_url, CLOUD_BASE_URL); + } + + #[test] + fn cloud_provider_reads_api_key_from_env_hint() { + let config = ProviderConfig { + enabled: true, + provider_type: "ollama_cloud".to_string(), + base_url: None, + api_key: None, + api_key_env: Some("OLLAMA_TEST_CLOUD_KEY".to_string()), + extra: HashMap::new(), + }; + + unsafe { + std::env::set_var("OLLAMA_TEST_CLOUD_KEY", "env-secret"); + } + + assert!(std::env::var("OLLAMA_TEST_CLOUD_KEY").is_ok()); + assert!(resolve_api_key_env_hint(config.api_key_env.as_deref()).is_some()); + assert_eq!(config.api_key_env.as_deref(), Some("OLLAMA_TEST_CLOUD_KEY")); + + let provider = OllamaProvider::from_config("ollama_cloud", &config, None) + .expect("provider constructed"); + assert_eq!(provider.name(), "ollama_cloud"); + assert_eq!(provider.mode, OllamaMode::Cloud); + + unsafe { + std::env::remove_var("OLLAMA_TEST_CLOUD_KEY"); + } + } + + #[test] + fn fetch_scope_tags_with_retry_success_uses_override() { + let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed"); + let handle = ScopeHandle::new( + provider.client.clone(), + provider.http_client.clone(), + provider.base_url.clone(), + ); + + let _guard = set_tags_override(vec![Ok(vec![LocalModel { + name: "llama3".into(), + modified_at: "2024-01-01T00:00:00Z".into(), + size: 42, + }])]); + + let models = block_on(provider.fetch_scope_tags_with_retry(OllamaMode::Local, &handle)) + .expect("models returned"); + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "llama3"); + } + + #[test] + fn fetch_scope_tags_with_retry_retries_on_timeout_then_succeeds() { + let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed"); + let handle = ScopeHandle::new( + provider.client.clone(), + provider.http_client.clone(), + provider.base_url.clone(), + ); + + let _guard = set_tags_override(vec![ + Err(Error::Timeout("first attempt".into())), + Ok(vec![LocalModel { + name: "llama3".into(), + modified_at: "2024-01-01T00:00:00Z".into(), + size: 42, + }]), + ]); + + let models = block_on(provider.fetch_scope_tags_with_retry(OllamaMode::Local, &handle)) + .expect("models returned after retry"); + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "llama3"); + } + + #[test] + fn decorate_scope_error_returns_friendly_message_for_connectivity() { + let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed"); + let (error, message) = provider.decorate_scope_error( + OllamaMode::Local, + "http://localhost:11434", + Error::Network("failed to connect to host".into()), + ); + + assert!(matches!( + error, + Error::Network(ref text) if text.contains("Ollama not reachable") + )); + assert!(message.contains("Ollama not reachable")); + assert!(message.contains("localhost:11434")); + } + + #[test] + fn decorate_scope_error_preserves_http_failure_message() { + let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed"); + let original = "Ollama list models failed (500): boom".to_string(); + let (error, message) = provider.decorate_scope_error( + OllamaMode::Local, + "http://localhost:11434", + Error::Network(original.clone()), + ); + + assert!(matches!(error, Error::Network(ref text) if text.contains("500"))); + assert_eq!(message, original); + } + + #[test] + fn decorate_scope_error_translates_timeout() { + let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed"); + let (error, message) = provider.decorate_scope_error( + OllamaMode::Local, + "http://localhost:11434", + Error::Timeout("deadline exceeded".into()), + ); + + assert!(matches!( + error, + Error::Timeout(ref text) if text.contains("Ollama not reachable") + )); + assert!(message.contains("Ollama not reachable")); + } + + #[test] + fn map_http_failure_model_not_found_suggests_pull_hint() { + let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed"); + let err = provider.map_http_failure( + "chat", + StatusCode::NOT_FOUND, + "missing model".to_string(), + Some("llama3"), + ); + + let message = match err { + Error::InvalidInput(message) => message, + other => panic!("unexpected error variant: {other:?}"), + }; + + assert!(message.contains("ollama pull llama3")); + } + #[test] fn build_model_options_merges_parameters() { let mut parameters = ChatParameters::default(); @@ -1630,13 +2212,19 @@ mod tests { } #[cfg(test)] -struct ProbeOverrideGuard; +struct ProbeOverrideGuard { + gate: Option>, +} #[cfg(test)] impl ProbeOverrideGuard { fn set(value: Option) -> Self { + let gate = PROBE_OVERRIDE_GATE + .get_or_init(|| Mutex::new(())) + .lock() + .expect("probe override gate mutex poisoned"); set_probe_override(value); - ProbeOverrideGuard + ProbeOverrideGuard { gate: Some(gate) } } } @@ -1644,6 +2232,7 @@ impl ProbeOverrideGuard { impl Drop for ProbeOverrideGuard { fn drop(&mut self) { set_probe_override(None); + self.gate.take(); } } @@ -1666,7 +2255,8 @@ fn auto_mode_with_api_key_and_successful_probe_prefers_local() { assert!(probe_default_local_daemon(Duration::from_millis(1))); - let provider = OllamaProvider::from_config(&config, None).expect("provider constructed"); + let provider = + OllamaProvider::from_config("ollama_local", &config, None).expect("provider constructed"); assert_eq!(provider.mode, OllamaMode::Local); assert_eq!(provider.base_url, "http://localhost:11434"); @@ -1689,7 +2279,8 @@ fn auto_mode_with_api_key_and_failed_probe_prefers_cloud() { Value::String("auto".to_string()), ); - let provider = OllamaProvider::from_config(&config, None).expect("provider constructed"); + let provider = + OllamaProvider::from_config("ollama_cloud", &config, None).expect("provider constructed"); assert_eq!(provider.mode, OllamaMode::Cloud); assert_eq!(provider.base_url, CLOUD_BASE_URL); @@ -1706,7 +2297,8 @@ fn annotate_scope_status_adds_capabilities_for_unavailable_scopes() { extra: HashMap::new(), }; - let provider = OllamaProvider::from_config(&config, None).expect("provider constructed"); + let provider = + OllamaProvider::from_config("ollama_local", &config, None).expect("provider constructed"); let mut models = vec![ModelInfo { id: "llama3".to_string(),