feat(providers/ollama): add variant support, retryable tag fetching with CLI fallback, and configurable provider name for robust model listing and health checks

This commit is contained in:
2025-10-18 05:59:50 +02:00
parent 4ce4ac0b0e
commit 3308b483f7
4 changed files with 682 additions and 88 deletions

View File

@@ -127,7 +127,8 @@ fn provider_from_config() -> Result<Arc<dyn Provider>, RpcError> {
match provider_cfg.provider_type.as_str() { match provider_cfg.provider_type.as_str() {
"ollama" | "ollama_cloud" => { "ollama" | "ollama_cloud" => {
let provider = OllamaProvider::from_config(&provider_cfg, Some(&config.general)) let provider =
OllamaProvider::from_config(&provider_key, &provider_cfg, Some(&config.general))
.map_err(|e| { .map_err(|e| {
RpcError::internal_error(format!( RpcError::internal_error(format!(
"Failed to init Ollama provider from config: {e}" "Failed to init Ollama provider from config: {e}"

View File

@@ -185,7 +185,8 @@ fn build_local_provider(cfg: &Config) -> Result<Arc<dyn Provider>> {
match provider_cfg.provider_type.as_str() { match provider_cfg.provider_type.as_str() {
"ollama" | "ollama_cloud" => { "ollama" | "ollama_cloud" => {
let provider = OllamaProvider::from_config(provider_cfg, Some(&cfg.general))?; let provider =
OllamaProvider::from_config(&provider_name, provider_cfg, Some(&cfg.general))?;
Ok(Arc::new(provider) as Arc<dyn Provider>) Ok(Arc::new(provider) as Arc<dyn Provider>)
} }
other => Err(anyhow!(format!( other => Err(anyhow!(format!(

View File

@@ -161,7 +161,7 @@ async fn status(provider: String) -> Result<()> {
Value::String("cloud".to_string()), Value::String("cloud".to_string()),
); );
let ollama = OllamaProvider::from_config(&runtime_cfg, Some(&config.general)) let ollama = OllamaProvider::from_config(&provider, &runtime_cfg, Some(&config.general))
.with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?; .with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?;
match ollama.health_check().await { match ollama.health_check().await {
@@ -212,7 +212,7 @@ async fn models(provider: String) -> Result<()> {
Value::String("cloud".to_string()), Value::String("cloud".to_string()),
); );
let ollama = OllamaProvider::from_config(&runtime_cfg, Some(&config.general)) let ollama = OllamaProvider::from_config(&provider, &runtime_cfg, Some(&config.general))
.with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?; .with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?;
match ollama.list_models().await { match ollama.list_models().await {

View File

@@ -4,6 +4,7 @@ use std::{
env, env,
net::{SocketAddr, TcpStream}, net::{SocketAddr, TcpStream},
pin::Pin, pin::Pin,
process::Command,
sync::Arc, sync::Arc,
time::{Duration, Instant, SystemTime}, time::{Duration, Instant, SystemTime},
}; };
@@ -23,11 +24,12 @@ use ollama_rs::{
models::{LocalModel, ModelInfo as OllamaModelInfo, ModelOptions}, models::{LocalModel, ModelInfo as OllamaModelInfo, ModelOptions},
}; };
use reqwest::{Client, StatusCode, Url}; use reqwest::{Client, StatusCode, Url};
use serde::Deserialize;
use serde_json::{Map as JsonMap, Value, json}; use serde_json::{Map as JsonMap, Value, json};
use tokio::{sync::RwLock, time::timeout}; use tokio::{sync::RwLock, time::sleep};
#[cfg(test)] #[cfg(test)]
use std::sync::{Mutex, OnceLock}; use std::sync::{Mutex, MutexGuard, OnceLock};
#[cfg(test)] #[cfg(test)]
use tokio_test::block_on; use tokio_test::block_on;
use uuid::Uuid; use uuid::Uuid;
@@ -48,6 +50,9 @@ const DEFAULT_MODEL_CACHE_TTL_SECS: u64 = 60;
pub(crate) const CLOUD_BASE_URL: &str = OLLAMA_CLOUD_BASE_URL; pub(crate) const CLOUD_BASE_URL: &str = OLLAMA_CLOUD_BASE_URL;
const LOCAL_PROBE_TIMEOUT_MS: u64 = 200; const LOCAL_PROBE_TIMEOUT_MS: u64 = 200;
const LOCAL_PROBE_TARGETS: &[&str] = &["127.0.0.1:11434", "[::1]:11434"]; const LOCAL_PROBE_TARGETS: &[&str] = &["127.0.0.1:11434", "[::1]:11434"];
const LOCAL_TAGS_TIMEOUT_STEPS_MS: [u64; 3] = [400, 800, 1_600];
const LOCAL_TAGS_RETRY_DELAYS_MS: [u64; 2] = [150, 300];
const HEALTHCHECK_TIMEOUT_MS: u64 = 1_000;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum OllamaMode { enum OllamaMode {
@@ -122,8 +127,53 @@ impl ScopeSnapshot {
} }
} }
/// Connection bundle for one scope (local daemon or cloud): the `ollama-rs`
/// client, a raw HTTP client for endpoints `ollama-rs` does not wrap, and the
/// base URL both clients were built against.
#[derive(Clone)]
struct ScopeHandle {
    // ollama-rs client used for model-detail lookups.
    client: Ollama,
    // Raw HTTP client used for direct endpoint calls (e.g. /api/tags).
    http_client: Client,
    // Base URL shared by both clients; used to build endpoint URLs.
    base_url: String,
}
impl ScopeHandle {
    /// Bundle an existing client pair with the base URL they target.
    fn new(client: Ollama, http_client: Client, base_url: impl Into<String>) -> Self {
        Self {
            client,
            http_client,
            base_url: base_url.into(),
        }
    }
    /// Build the full API URL for `endpoint` (e.g. "tags") under this scope's base URL.
    fn api_url(&self, endpoint: &str) -> String {
        build_api_endpoint(&self.base_url, endpoint)
    }
}
/// Minimal deserialization shape for the `/api/tags` response body.
/// `models` defaults to empty so a `{}` payload parses instead of erroring.
#[derive(Debug, Deserialize)]
struct TagsResponse {
    #[serde(default)]
    models: Vec<LocalModel>,
}
/// Which Ollama deployment flavor a configured provider entry represents.
///
/// Fixed at construction time; gates which scopes (local daemon vs. hosted
/// cloud) the provider will ever query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ProviderVariant {
    /// Talks to a locally running Ollama daemon.
    Local,
    /// Talks to the hosted Ollama Cloud endpoint.
    Cloud,
}

impl ProviderVariant {
    /// True when this variant may query the local daemon scope.
    fn supports_local(self) -> bool {
        self == ProviderVariant::Local
    }

    /// True when this variant may query the cloud scope.
    fn supports_cloud(self) -> bool {
        self == ProviderVariant::Cloud
    }
}
#[derive(Debug)] #[derive(Debug)]
struct OllamaOptions { struct OllamaOptions {
provider_name: String,
variant: ProviderVariant,
mode: OllamaMode, mode: OllamaMode,
base_url: String, base_url: String,
request_timeout: Duration, request_timeout: Duration,
@@ -133,8 +183,15 @@ struct OllamaOptions {
} }
impl OllamaOptions { impl OllamaOptions {
fn new(mode: OllamaMode, base_url: impl Into<String>) -> Self { fn new(
provider_name: impl Into<String>,
variant: ProviderVariant,
mode: OllamaMode,
base_url: impl Into<String>,
) -> Self {
Self { Self {
provider_name: provider_name.into(),
variant,
mode, mode,
base_url: base_url.into(), base_url: base_url.into(),
request_timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS), request_timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
@@ -153,6 +210,8 @@ impl OllamaOptions {
/// Ollama provider implementation backed by `ollama-rs`. /// Ollama provider implementation backed by `ollama-rs`.
#[derive(Debug)] #[derive(Debug)]
pub struct OllamaProvider { pub struct OllamaProvider {
provider_name: String,
variant: ProviderVariant,
mode: OllamaMode, mode: OllamaMode,
client: Ollama, client: Ollama,
http_client: Client, http_client: Client,
@@ -198,6 +257,16 @@ fn is_explicit_cloud_base(base_url: Option<&str>) -> bool {
#[cfg(test)] #[cfg(test)]
static PROBE_OVERRIDE: OnceLock<Mutex<Option<bool>>> = OnceLock::new(); static PROBE_OVERRIDE: OnceLock<Mutex<Option<bool>>> = OnceLock::new();
#[cfg(test)]
static TAGS_OVERRIDE: OnceLock<Mutex<Vec<std::result::Result<Vec<LocalModel>, Error>>>> =
OnceLock::new();
#[cfg(test)]
static TAGS_OVERRIDE_GATE: OnceLock<Mutex<()>> = OnceLock::new();
#[cfg(test)]
static PROBE_OVERRIDE_GATE: OnceLock<Mutex<()>> = OnceLock::new();
#[cfg(test)] #[cfg(test)]
fn set_probe_override(value: Option<bool>) { fn set_probe_override(value: Option<bool>) {
let guard = PROBE_OVERRIDE.get_or_init(|| Mutex::new(None)); let guard = PROBE_OVERRIDE.get_or_init(|| Mutex::new(None));
@@ -213,6 +282,51 @@ fn probe_override_value() -> Option<bool> {
.to_owned() .to_owned()
} }
/// Install a scripted sequence of tag-fetch results for tests.
///
/// The returned guard holds the serialization gate for as long as it lives, so
/// concurrent tests cannot interleave their override sequences; dropping the
/// guard clears any unconsumed entries.
#[cfg(test)]
fn set_tags_override(
    sequence: Vec<std::result::Result<Vec<LocalModel>, Error>>,
) -> TagsOverrideGuard {
    // Take the gate FIRST so only one test at a time owns the override store.
    let gate = TAGS_OVERRIDE_GATE
        .get_or_init(|| Mutex::new(()))
        .lock()
        .expect("tags override gate mutex poisoned");
    let store = TAGS_OVERRIDE.get_or_init(|| Mutex::new(Vec::new()));
    {
        let mut guard = store.lock().expect("tags override mutex poisoned");
        guard.clear();
        // Push reversed so pop() (from the back) yields results in the
        // caller-supplied order.
        for item in sequence.into_iter().rev() {
            guard.push(item);
        }
    }
    TagsOverrideGuard { gate: Some(gate) }
}
/// Consume the next scripted result, if any remain.
#[cfg(test)]
fn pop_tags_override() -> Option<std::result::Result<Vec<LocalModel>, Error>> {
    TAGS_OVERRIDE
        .get_or_init(|| Mutex::new(Vec::new()))
        .lock()
        .expect("tags override mutex poisoned")
        .pop()
}
/// RAII guard pairing override cleanup with release of the serialization gate.
#[cfg(test)]
struct TagsOverrideGuard {
    // Held for the guard's lifetime; released (via take) on drop.
    gate: Option<MutexGuard<'static, ()>>,
}
#[cfg(test)]
impl Drop for TagsOverrideGuard {
    fn drop(&mut self) {
        // Clear leftovers before releasing the gate so the next test starts clean.
        if let Some(store) = TAGS_OVERRIDE.get() {
            let mut guard = store.lock().expect("tags override mutex poisoned");
            guard.clear();
        }
        self.gate.take();
    }
}
fn probe_default_local_daemon(timeout: Duration) -> bool { fn probe_default_local_daemon(timeout: Duration) -> bool {
#[cfg(test)] #[cfg(test)]
{ {
@@ -237,14 +351,46 @@ impl OllamaProvider {
let input = base_url.into(); let input = base_url.into();
let normalized = let normalized =
normalize_base_url(Some(&input), OllamaMode::Local).map_err(Error::Config)?; normalize_base_url(Some(&input), OllamaMode::Local).map_err(Error::Config)?;
Self::with_options(OllamaOptions::new(OllamaMode::Local, normalized)) Self::with_options(OllamaOptions::new(
"ollama_local",
ProviderVariant::Local,
OllamaMode::Local,
normalized,
))
} }
/// Construct a provider from configuration settings. /// Construct a provider from configuration settings.
pub fn from_config(config: &ProviderConfig, general: Option<&GeneralSettings>) -> Result<Self> { pub fn from_config(
provider_id: &str,
config: &ProviderConfig,
general: Option<&GeneralSettings>,
) -> Result<Self> {
let provider_type = config.provider_type.trim().to_ascii_lowercase();
let register_name = {
let candidate = provider_id.trim();
if candidate.is_empty() {
if provider_type.is_empty() {
"ollama".to_string()
} else {
provider_type.clone()
}
} else {
candidate.replace('-', "_")
}
};
let variant = if register_name == "ollama_cloud" || provider_type == "ollama_cloud" {
ProviderVariant::Cloud
} else {
ProviderVariant::Local
};
let mut api_key = resolve_api_key(config.api_key.clone()) let mut api_key = resolve_api_key(config.api_key.clone())
.or_else(|| resolve_api_key_env_hint(config.api_key_env.as_deref()))
.or_else(|| env_var_non_empty("OLLAMA_API_KEY")) .or_else(|| env_var_non_empty("OLLAMA_API_KEY"))
.or_else(|| env_var_non_empty("OLLAMA_CLOUD_API_KEY")); .or_else(|| env_var_non_empty("OLLAMA_CLOUD_API_KEY"));
let api_key_present = api_key.is_some();
let configured_mode = configured_mode_from_extra(config); let configured_mode = configured_mode_from_extra(config);
let configured_mode_label = config let configured_mode_label = config
.extra .extra
@@ -254,7 +400,7 @@ impl OllamaProvider {
let base_url = config.base_url.as_deref(); let base_url = config.base_url.as_deref();
let base_is_local = is_explicit_local_base(base_url); let base_is_local = is_explicit_local_base(base_url);
let base_is_cloud = is_explicit_cloud_base(base_url); let base_is_cloud = is_explicit_cloud_base(base_url);
let base_is_other = base_url.is_some() && !base_is_local && !base_is_cloud; let _base_is_other = base_url.is_some() && !base_is_local && !base_is_cloud;
let mut local_probe_result = None; let mut local_probe_result = None;
let cloud_endpoint = config let cloud_endpoint = config
@@ -265,28 +411,25 @@ impl OllamaProvider {
.transpose() .transpose()
.map_err(Error::Config)?; .map_err(Error::Config)?;
let mode = match configured_mode { if matches!(variant, ProviderVariant::Local) && configured_mode.is_none() {
Some(mode) => mode, let probe = probe_default_local_daemon(Duration::from_millis(LOCAL_PROBE_TIMEOUT_MS));
None => {
if base_is_local || base_is_other {
OllamaMode::Local
} else if base_is_cloud && api_key.is_some() {
OllamaMode::Cloud
} else {
let probe =
probe_default_local_daemon(Duration::from_millis(LOCAL_PROBE_TIMEOUT_MS));
local_probe_result = Some(probe); local_probe_result = Some(probe);
if probe {
OllamaMode::Local
} else if api_key.is_some() {
OllamaMode::Cloud
} else {
OllamaMode::Local
}
}
} }
let mode = match variant {
ProviderVariant::Local => OllamaMode::Local,
ProviderVariant::Cloud => OllamaMode::Cloud,
}; };
if matches!(variant, ProviderVariant::Cloud) {
if !api_key_present {
return Err(Error::Config(
"Ollama Cloud API key not configured. Set providers.ollama_cloud.api_key or OLLAMA_CLOUD_API_KEY."
.into(),
));
}
}
let base_candidate = match mode { let base_candidate = match mode {
OllamaMode::Local => base_url, OllamaMode::Local => base_url,
OllamaMode::Cloud => { OllamaMode::Cloud => {
@@ -301,7 +444,12 @@ impl OllamaProvider {
let normalized_base_url = let normalized_base_url =
normalize_base_url(base_candidate, mode).map_err(Error::Config)?; normalize_base_url(base_candidate, mode).map_err(Error::Config)?;
let mut options = OllamaOptions::new(mode, normalized_base_url.clone()); let mut options = OllamaOptions::new(
register_name.clone(),
variant,
mode,
normalized_base_url.clone(),
);
options.cloud_endpoint = cloud_endpoint.clone(); options.cloud_endpoint = cloud_endpoint.clone();
if let Some(timeout) = config if let Some(timeout) = config
@@ -327,7 +475,8 @@ impl OllamaProvider {
} }
debug!( debug!(
"Resolved Ollama provider: mode={:?}, base_url={}, configured_mode={}, api_key_present={}, local_probe={}", "Resolved Ollama provider '{}': mode={:?}, base_url={}, configured_mode={}, api_key_present={}, local_probe={}",
register_name,
mode, mode,
normalized_base_url, normalized_base_url,
configured_mode_label, configured_mode_label,
@@ -348,6 +497,8 @@ impl OllamaProvider {
fn with_options(options: OllamaOptions) -> Result<Self> { fn with_options(options: OllamaOptions) -> Result<Self> {
let OllamaOptions { let OllamaOptions {
provider_name,
variant,
mode, mode,
base_url, base_url,
request_timeout, request_timeout,
@@ -368,6 +519,8 @@ impl OllamaProvider {
}; };
Ok(Self { Ok(Self {
provider_name: provider_name.trim().to_ascii_lowercase(),
variant,
mode, mode,
client: ollama_client, client: ollama_client,
http_client, http_client,
@@ -397,19 +550,47 @@ impl OllamaProvider {
} }
} }
fn build_local_client(&self) -> Result<Option<Ollama>> { fn supports_local_scope(&self) -> bool {
self.variant.supports_local()
}
fn supports_cloud_scope(&self) -> bool {
self.variant.supports_cloud()
}
fn build_local_client(&self) -> Result<Option<ScopeHandle>> {
if !self.supports_local_scope() {
return Ok(None);
}
if matches!(self.mode, OllamaMode::Local) { if matches!(self.mode, OllamaMode::Local) {
return Ok(Some(self.client.clone())); return Ok(Some(ScopeHandle::new(
self.client.clone(),
self.http_client.clone(),
self.base_url.clone(),
)));
} }
let (client, _) = let (client, http_client) =
build_client_for_base(Self::local_base_url(), self.request_timeout, None)?; build_client_for_base(Self::local_base_url(), self.request_timeout, None)?;
Ok(Some(client)) Ok(Some(ScopeHandle::new(
client,
http_client,
Self::local_base_url(),
)))
}
fn build_cloud_client(&self) -> Result<Option<ScopeHandle>> {
if !self.supports_cloud_scope() {
return Ok(None);
} }
fn build_cloud_client(&self) -> Result<Option<Ollama>> {
if matches!(self.mode, OllamaMode::Cloud) { if matches!(self.mode, OllamaMode::Cloud) {
return Ok(Some(self.client.clone())); return Ok(Some(ScopeHandle::new(
self.client.clone(),
self.http_client.clone(),
self.base_url.clone(),
)));
} }
let api_key = match self.api_key.as_deref() { let api_key = match self.api_key.as_deref() {
@@ -419,8 +600,9 @@ impl OllamaProvider {
let endpoint = self.cloud_endpoint.as_deref().unwrap_or(CLOUD_BASE_URL); let endpoint = self.cloud_endpoint.as_deref().unwrap_or(CLOUD_BASE_URL);
let (client, _) = build_client_for_base(endpoint, self.request_timeout, Some(api_key))?; let (client, http_client) =
Ok(Some(client)) build_client_for_base(endpoint, self.request_timeout, Some(api_key))?;
Ok(Some(ScopeHandle::new(client, http_client, endpoint)))
} }
async fn cached_scope_models(&self, scope: OllamaMode) -> Option<Vec<ModelInfo>> { async fn cached_scope_models(&self, scope: OllamaMode) -> Option<Vec<ModelInfo>> {
@@ -663,9 +845,9 @@ impl OllamaProvider {
let mut seen: HashSet<String> = HashSet::new(); let mut seen: HashSet<String> = HashSet::new();
let mut errors: Vec<Error> = Vec::new(); let mut errors: Vec<Error> = Vec::new();
if let Some(local_client) = self.build_local_client()? { if let Some(local_handle) = self.build_local_client()? {
match self match self
.fetch_models_for_scope(OllamaMode::Local, local_client.clone()) .fetch_models_for_scope(OllamaMode::Local, local_handle)
.await .await
{ {
Ok(models) => { Ok(models) => {
@@ -680,9 +862,9 @@ impl OllamaProvider {
} }
} }
if let Some(cloud_client) = self.build_cloud_client()? { if let Some(cloud_handle) = self.build_cloud_client()? {
match self match self
.fetch_models_for_scope(OllamaMode::Cloud, cloud_client.clone()) .fetch_models_for_scope(OllamaMode::Cloud, cloud_handle)
.await .await
{ {
Ok(models) => { Ok(models) => {
@@ -711,40 +893,31 @@ impl OllamaProvider {
async fn fetch_models_for_scope( async fn fetch_models_for_scope(
&self, &self,
scope: OllamaMode, scope: OllamaMode,
client: Ollama, handle: ScopeHandle,
) -> Result<Vec<ModelInfo>> { ) -> Result<Vec<ModelInfo>> {
let list_result = if matches!(scope, OllamaMode::Local) { let tags_result = self.fetch_scope_tags_with_retry(scope, &handle).await;
match timeout(
Duration::from_millis(LOCAL_PROBE_TIMEOUT_MS),
client.list_local_models(),
)
.await
{
Ok(result) => result.map_err(|err| self.map_ollama_error("list models", err, None)),
Err(_) => Err(Error::Timeout(
"Timed out while contacting the local Ollama daemon".to_string(),
)),
}
} else {
client
.list_local_models()
.await
.map_err(|err| self.map_ollama_error("list models", err, None))
};
let models = match list_result { let models = match tags_result {
Ok(models) => models, Ok(models) => models,
Err(err) => { Err(err) => {
let message = err.to_string(); let original_detail = err.to_string();
self.mark_scope_failure(scope, message).await; let (error, banner) = self.decorate_scope_error(scope, &handle.base_url, err);
if banner != original_detail {
debug!(
"Model list for {:?} at {} failed: {}",
scope, handle.base_url, original_detail
);
}
self.mark_scope_failure(scope, banner.clone()).await;
if let Some(cached) = self.cached_scope_models(scope).await { if let Some(cached) = self.cached_scope_models(scope).await {
return Ok(cached); return Ok(cached);
} }
return Err(err); return Err(error);
} }
}; };
let cache = self.model_details_cache.clone(); let cache = self.model_details_cache.clone();
let client = handle.client.clone();
let fetched = join_all(models.into_iter().map(|local| { let fetched = join_all(models.into_iter().map(|local| {
let client = client.clone(); let client = client.clone();
let cache = cache.clone(); let cache = cache.clone();
@@ -780,6 +953,186 @@ impl OllamaProvider {
Ok(converted) Ok(converted)
} }
async fn fetch_scope_tags_with_retry(
&self,
scope: OllamaMode,
handle: &ScopeHandle,
) -> Result<Vec<LocalModel>> {
let attempts = if matches!(scope, OllamaMode::Local) {
LOCAL_TAGS_TIMEOUT_STEPS_MS.len()
} else {
1
};
let mut last_error: Option<Error> = None;
for attempt in 0..attempts {
match self.fetch_scope_tags_once(scope, handle, attempt).await {
Ok(models) => return Ok(models),
Err(err) => {
let should_retry = matches!(scope, OllamaMode::Local)
&& attempt + 1 < attempts
&& matches!(err, Error::Timeout(_) | Error::Network(_));
if should_retry {
debug!(
"Retrying Ollama model list for {:?} (attempt {}): {}",
scope,
attempt + 1,
err
);
last_error = Some(err);
sleep(self.tags_retry_delay(attempt)).await;
continue;
}
return Err(err);
}
}
}
Err(last_error
.unwrap_or_else(|| Error::Unknown("Ollama model list retries exhausted".to_string())))
}
/// Perform a single model-listing attempt for `scope`.
///
/// Order of strategies:
/// 1. (tests only) a scripted override result, if one is queued;
/// 2. for the local scope, the `ollama ls` CLI — HTTP is only the fallback
///    when the CLI fails;
/// 3. a direct GET on `/api/tags` with an attempt-scaled timeout.
async fn fetch_scope_tags_once(
    &self,
    scope: OllamaMode,
    handle: &ScopeHandle,
    attempt: usize,
) -> Result<Vec<LocalModel>> {
    // Test hook: consume one scripted result instead of doing real I/O.
    #[cfg(test)]
    if let Some(result) = pop_tags_override() {
        return result;
    }
    if matches!(scope, OllamaMode::Local) {
        // CLI-first for local listings; a CLI failure only logs and falls
        // through to the HTTP path.
        match self.list_local_models_via_cli() {
            Ok(models) => return Ok(models),
            Err(err) => {
                debug!("`ollama ls` failed ({}); falling back to HTTP listing", err);
            }
        }
    }
    let url = handle.api_url("tags");
    let response = handle
        .http_client
        .get(&url)
        // Per-request timeout grows with the attempt number for local scope.
        .timeout(self.tags_request_timeout(scope, attempt))
        .send()
        .await
        .map_err(|err| map_reqwest_error("list models", err))?;
    if !response.status().is_success() {
        let status = response.status();
        // Body text is best-effort detail; a read failure substitutes its own message.
        let detail = response.text().await.unwrap_or_else(|err| err.to_string());
        return Err(self.map_http_failure("list models", status, detail, None));
    }
    let bytes = response
        .bytes()
        .await
        .map_err(|err| map_reqwest_error("list models", err))?;
    let parsed: TagsResponse = serde_json::from_slice(&bytes)?;
    Ok(parsed.models)
}
/// Per-attempt timeout for a tags request: local attempts escalate through
/// `LOCAL_TAGS_TIMEOUT_STEPS_MS` (clamped to the final step); every other
/// scope uses the provider-wide request timeout.
fn tags_request_timeout(&self, scope: OllamaMode, attempt: usize) -> Duration {
    match scope {
        OllamaMode::Local => {
            let steps = &LOCAL_TAGS_TIMEOUT_STEPS_MS;
            Duration::from_millis(steps[attempt.min(steps.len() - 1)])
        }
        _ => self.request_timeout,
    }
}
/// Back-off delay before retry number `attempt`, clamped to the last
/// configured step so late retries reuse the longest delay.
fn tags_retry_delay(&self, attempt: usize) -> Duration {
    let delays = &LOCAL_TAGS_RETRY_DELAYS_MS;
    Duration::from_millis(delays[attempt.min(delays.len() - 1)])
}
/// List installed models by shelling out to `ollama ls` and parsing its
/// table output; used as the fast path for the local scope before the HTTP
/// `/api/tags` fallback.
///
/// Parsing caveats (visible below):
/// - only the first whitespace-separated column (the model name) is trusted;
///   the raw remainder of the row (ID / SIZE / MODIFIED columns) is stored
///   verbatim in `modified_at`, and `size` is always 0;
/// - the header row is skipped via a case-insensitive "name" prefix check.
///   NOTE(review): a model whose name itself starts with "name" would be
///   dropped by this heuristic — unlikely, but worth confirming acceptable.
fn list_local_models_via_cli(&self) -> Result<Vec<LocalModel>> {
    // A spawn failure (binary missing from PATH) gets an actionable hint.
    let output = Command::new("ollama")
        .arg("ls")
        .output()
        .map_err(|err| {
            Error::Provider(anyhow!(
                "Failed to execute `ollama ls`: {err}. Ensure the Ollama CLI is installed and accessible in PATH."
            ))
        })?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(Error::Provider(anyhow!(
            "`ollama ls` exited with status {}: {}",
            output.status,
            stderr.trim()
        )));
    }
    let stdout = String::from_utf8(output.stdout).map_err(|err| {
        Error::Provider(anyhow!("`ollama ls` returned non-UTF8 output: {err}"))
    })?;
    let mut models = Vec::new();
    for line in stdout.lines() {
        let trimmed = line.trim();
        if trimmed.is_empty() {
            continue;
        }
        // Skip the table header row ("NAME  ID  SIZE  MODIFIED").
        let lowercase = trimmed.to_ascii_lowercase();
        if lowercase.starts_with("name") {
            continue;
        }
        let mut parts = trimmed.split_whitespace();
        let Some(name) = parts.next() else {
            continue;
        };
        // Everything after the name column, kept as-is (not parsed further).
        let metadata_start = trimmed[name.len()..].trim();
        models.push(LocalModel {
            name: name.to_string(),
            modified_at: metadata_start.to_string(),
            size: 0,
        });
    }
    Ok(models)
}
fn decorate_scope_error(
&self,
scope: OllamaMode,
base_url: &str,
err: Error,
) -> (Error, String) {
if matches!(scope, OllamaMode::Local) {
match err {
Error::Timeout(_) => {
let message = format_local_unreachable_message(base_url);
(Error::Timeout(message.clone()), message)
}
Error::Network(original) => {
if is_connectivity_error(&original) {
let message = format_local_unreachable_message(base_url);
(Error::Network(message.clone()), message)
} else {
let message = original.clone();
(Error::Network(original), message)
}
}
other => {
let message = other.to_string();
(other, message)
}
}
} else {
let message = err.to_string();
(err, message)
}
}
fn convert_detailed_model_info( fn convert_detailed_model_info(
mode: OllamaMode, mode: OllamaMode,
model_name: &str, model_name: &str,
@@ -893,7 +1246,7 @@ impl OllamaProvider {
id: name.clone(), id: name.clone(),
name, name,
description: Some(description), description: Some(description),
provider: "ollama".to_string(), provider: self.provider_name.clone(),
context_window: None, context_window: None,
capabilities, capabilities,
supports_tools: false, supports_tools: false,
@@ -948,7 +1301,7 @@ impl OllamaProvider {
StatusCode::NOT_FOUND => { StatusCode::NOT_FOUND => {
if let Some(model) = model { if let Some(model) = model {
Error::InvalidInput(format!( Error::InvalidInput(format!(
"Model '{model}' was not found at {}. Verify the name or pull it with `ollama pull`.", "Model '{model}' was not found at {}. Verify the name or pull it with `ollama pull {model}`.",
self.base_url self.base_url
)) ))
} else { } else {
@@ -992,7 +1345,7 @@ impl LlmProvider for OllamaProvider {
Self: 'a; Self: 'a;
fn name(&self) -> &str { fn name(&self) -> &str {
"ollama" &self.provider_name
} }
fn list_models(&self) -> Self::ListModelsFuture<'_> { fn list_models(&self) -> Self::ListModelsFuture<'_> {
@@ -1056,10 +1409,11 @@ impl LlmProvider for OllamaProvider {
fn health_check(&self) -> Self::HealthCheckFuture<'_> { fn health_check(&self) -> Self::HealthCheckFuture<'_> {
Box::pin(async move { Box::pin(async move {
let url = self.api_url("version"); let url = self.api_url("tags?limit=1");
let response = self let response = self
.http_client .http_client
.get(&url) .get(&url)
.timeout(Duration::from_millis(HEALTHCHECK_TIMEOUT_MS))
.send() .send()
.await .await
.map_err(|err| map_reqwest_error("health check", err))?; .map_err(|err| map_reqwest_error("health check", err))?;
@@ -1364,6 +1718,46 @@ fn value_to_u64(value: &Value) -> Option<u64> {
} }
} }
/// Build the user-facing message shown when the local Ollama daemon cannot be
/// reached at `base_url`: names the host:port and suggests `ollama serve`.
fn format_local_unreachable_message(base_url: &str) -> String {
    let display = display_host_port(base_url);
    format!(
        "Ollama not reachable on {display}. Start the Ollama daemon (`ollama serve`) and try again."
    )
}
/// Render the host (and port, when one is explicit in the URL) from
/// `base_url` for display, falling back to the raw input string when it does
/// not parse as a URL or has no host component.
fn display_host_port(base_url: &str) -> String {
    if let Ok(url) = Url::parse(base_url) {
        if let Some(host) = url.host_str() {
            return match url.port() {
                Some(port) => format!("{host}:{port}"),
                None => host.to_string(),
            };
        }
    }
    base_url.to_string()
}
/// Heuristically classify an error string as a transport/connectivity failure
/// (refused connection, DNS failure, unreachable host, connect timeout) rather
/// than an application-level error, via case-insensitive substring matching.
fn is_connectivity_error(message: &str) -> bool {
    const CONNECTIVITY_MARKERS: &[&str] = &[
        "connection refused",
        "failed to connect",
        "connect timeout",
        "timed out while contacting",
        "dns error",
        "failed to lookup address",
        "no route to host",
        "host unreachable",
    ];
    let haystack = message.to_ascii_lowercase();
    for marker in CONNECTIVITY_MARKERS {
        if haystack.contains(marker) {
            return true;
        }
    }
    false
}
fn env_var_non_empty(name: &str) -> Option<String> { fn env_var_non_empty(name: &str) -> Option<String> {
env::var(name) env::var(name)
.ok() .ok()
@@ -1371,6 +1765,13 @@ fn env_var_non_empty(name: &str) -> Option<String> {
.filter(|value| !value.is_empty()) .filter(|value| !value.is_empty())
} }
fn resolve_api_key_env_hint(env_var: Option<&str>) -> Option<String> {
env_var
.map(str::trim)
.filter(|value| !value.is_empty())
.and_then(env_var_non_empty)
}
fn resolve_api_key(configured: Option<String>) -> Option<String> { fn resolve_api_key(configured: Option<String>) -> Option<String> {
let raw = configured?.trim().to_string(); let raw = configured?.trim().to_string();
if raw.is_empty() { if raw.is_empty() {
@@ -1545,7 +1946,8 @@ mod tests {
Value::String("local".to_string()), Value::String("local".to_string()),
); );
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed"); let provider = OllamaProvider::from_config("ollama_local", &config, None)
.expect("provider constructed");
assert_eq!(provider.mode, OllamaMode::Local); assert_eq!(provider.mode, OllamaMode::Local);
assert_eq!(provider.base_url, "http://localhost:11434"); assert_eq!(provider.base_url, "http://localhost:11434");
@@ -1563,7 +1965,8 @@ mod tests {
}; };
// simulate missing explicit mode; defaults to auto // simulate missing explicit mode; defaults to auto
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed"); let provider = OllamaProvider::from_config("ollama_local", &config, None)
.expect("provider constructed");
assert_eq!(provider.mode, OllamaMode::Local); assert_eq!(provider.mode, OllamaMode::Local);
assert_eq!(provider.base_url, "http://localhost:11434"); assert_eq!(provider.base_url, "http://localhost:11434");
@@ -1584,12 +1987,191 @@ mod tests {
Value::String("auto".to_string()), Value::String("auto".to_string()),
); );
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed"); let provider = OllamaProvider::from_config("ollama_cloud", &config, None)
.expect("provider constructed");
assert_eq!(provider.mode, OllamaMode::Cloud); assert_eq!(provider.mode, OllamaMode::Cloud);
assert_eq!(provider.base_url, CLOUD_BASE_URL); assert_eq!(provider.base_url, CLOUD_BASE_URL);
} }
/// Cloud variant without any API key source must fail fast with a Config error
/// mentioning the missing key.
#[test]
fn cloud_provider_requires_api_key() {
    let config = ProviderConfig {
        enabled: true,
        provider_type: "ollama_cloud".to_string(),
        base_url: None,
        api_key: None,
        api_key_env: None,
        extra: HashMap::new(),
    };
    let err = OllamaProvider::from_config("ollama_cloud", &config, None)
        .expect_err("expected config error");
    match err {
        Error::Config(message) => {
            assert!(message.contains("API key"));
        }
        other => panic!("unexpected error variant: {other:?}"),
    }
}
/// An explicitly configured key is enough to construct the cloud provider;
/// mode and base URL resolve to the cloud defaults, name follows provider_id.
#[test]
fn cloud_provider_uses_explicit_api_key() {
    let config = ProviderConfig {
        enabled: true,
        provider_type: "ollama_cloud".to_string(),
        base_url: None,
        api_key: Some("secret-cloud-key".to_string()),
        api_key_env: None,
        extra: HashMap::new(),
    };
    let provider = OllamaProvider::from_config("ollama_cloud", &config, None)
        .expect("provider constructed");
    assert_eq!(provider.name(), "ollama_cloud");
    assert_eq!(provider.mode, OllamaMode::Cloud);
    assert_eq!(provider.base_url, CLOUD_BASE_URL);
}
/// `api_key_env` hint path: the key is read from the named environment variable.
///
/// NOTE(review): this test mutates the process environment without holding a
/// serialization gate (unlike the probe/tags overrides), so it could race with
/// other tests that read env vars when the harness runs tests in parallel —
/// confirm, or gate it like the other overrides.
#[test]
fn cloud_provider_reads_api_key_from_env_hint() {
    let config = ProviderConfig {
        enabled: true,
        provider_type: "ollama_cloud".to_string(),
        base_url: None,
        api_key: None,
        api_key_env: Some("OLLAMA_TEST_CLOUD_KEY".to_string()),
        extra: HashMap::new(),
    };
    unsafe {
        std::env::set_var("OLLAMA_TEST_CLOUD_KEY", "env-secret");
    }
    assert!(std::env::var("OLLAMA_TEST_CLOUD_KEY").is_ok());
    assert!(resolve_api_key_env_hint(config.api_key_env.as_deref()).is_some());
    assert_eq!(config.api_key_env.as_deref(), Some("OLLAMA_TEST_CLOUD_KEY"));
    let provider = OllamaProvider::from_config("ollama_cloud", &config, None)
        .expect("provider constructed");
    assert_eq!(provider.name(), "ollama_cloud");
    assert_eq!(provider.mode, OllamaMode::Cloud);
    unsafe {
        std::env::remove_var("OLLAMA_TEST_CLOUD_KEY");
    }
}
/// A single scripted Ok result from the test override is returned unchanged
/// (no retries, no real I/O).
#[test]
fn fetch_scope_tags_with_retry_success_uses_override() {
    let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
    let handle = ScopeHandle::new(
        provider.client.clone(),
        provider.http_client.clone(),
        provider.base_url.clone(),
    );
    let _guard = set_tags_override(vec![Ok(vec![LocalModel {
        name: "llama3".into(),
        modified_at: "2024-01-01T00:00:00Z".into(),
        size: 42,
    }])]);
    let models = block_on(provider.fetch_scope_tags_with_retry(OllamaMode::Local, &handle))
        .expect("models returned");
    assert_eq!(models.len(), 1);
    assert_eq!(models[0].name, "llama3");
}
/// A transient Timeout on the first local attempt is retried; the second
/// scripted (Ok) attempt wins.
#[test]
fn fetch_scope_tags_with_retry_retries_on_timeout_then_succeeds() {
    let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
    let handle = ScopeHandle::new(
        provider.client.clone(),
        provider.http_client.clone(),
        provider.base_url.clone(),
    );
    let _guard = set_tags_override(vec![
        Err(Error::Timeout("first attempt".into())),
        Ok(vec![LocalModel {
            name: "llama3".into(),
            modified_at: "2024-01-01T00:00:00Z".into(),
            size: 42,
        }]),
    ]);
    let models = block_on(provider.fetch_scope_tags_with_retry(OllamaMode::Local, &handle))
        .expect("models returned after retry");
    assert_eq!(models.len(), 1);
    assert_eq!(models[0].name, "llama3");
}
/// Connectivity-flavored network errors on the local scope become the
/// friendly "not reachable" message, carrying the host:port.
#[test]
fn decorate_scope_error_returns_friendly_message_for_connectivity() {
    let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
    let (error, message) = provider.decorate_scope_error(
        OllamaMode::Local,
        "http://localhost:11434",
        Error::Network("failed to connect to host".into()),
    );
    assert!(matches!(
        error,
        Error::Network(ref text) if text.contains("Ollama not reachable")
    ));
    assert!(message.contains("Ollama not reachable"));
    assert!(message.contains("localhost:11434"));
}
/// Non-connectivity network errors (e.g. HTTP 5xx detail) are passed through
/// verbatim — no friendly rewriting.
#[test]
fn decorate_scope_error_preserves_http_failure_message() {
    let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
    let original = "Ollama list models failed (500): boom".to_string();
    let (error, message) = provider.decorate_scope_error(
        OllamaMode::Local,
        "http://localhost:11434",
        Error::Network(original.clone()),
    );
    assert!(matches!(error, Error::Network(ref text) if text.contains("500")));
    assert_eq!(message, original);
}
/// Local-scope timeouts are always rewritten into the friendly message.
#[test]
fn decorate_scope_error_translates_timeout() {
    let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
    let (error, message) = provider.decorate_scope_error(
        OllamaMode::Local,
        "http://localhost:11434",
        Error::Timeout("deadline exceeded".into()),
    );
    assert!(matches!(
        error,
        Error::Timeout(ref text) if text.contains("Ollama not reachable")
    ));
    assert!(message.contains("Ollama not reachable"));
}
/// 404 with a model name yields an InvalidInput whose hint includes the exact
/// `ollama pull <model>` command.
#[test]
fn map_http_failure_model_not_found_suggests_pull_hint() {
    let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
    let err = provider.map_http_failure(
        "chat",
        StatusCode::NOT_FOUND,
        "missing model".to_string(),
        Some("llama3"),
    );
    let message = match err {
        Error::InvalidInput(message) => message,
        other => panic!("unexpected error variant: {other:?}"),
    };
    assert!(message.contains("ollama pull llama3"));
}
#[test] #[test]
fn build_model_options_merges_parameters() { fn build_model_options_merges_parameters() {
let mut parameters = ChatParameters::default(); let mut parameters = ChatParameters::default();
@@ -1630,13 +2212,19 @@ mod tests {
} }
#[cfg(test)] #[cfg(test)]
struct ProbeOverrideGuard; struct ProbeOverrideGuard {
gate: Option<MutexGuard<'static, ()>>,
}
#[cfg(test)] #[cfg(test)]
impl ProbeOverrideGuard { impl ProbeOverrideGuard {
fn set(value: Option<bool>) -> Self { fn set(value: Option<bool>) -> Self {
let gate = PROBE_OVERRIDE_GATE
.get_or_init(|| Mutex::new(()))
.lock()
.expect("probe override gate mutex poisoned");
set_probe_override(value); set_probe_override(value);
ProbeOverrideGuard ProbeOverrideGuard { gate: Some(gate) }
} }
} }
@@ -1644,6 +2232,7 @@ impl ProbeOverrideGuard {
impl Drop for ProbeOverrideGuard { impl Drop for ProbeOverrideGuard {
fn drop(&mut self) { fn drop(&mut self) {
set_probe_override(None); set_probe_override(None);
self.gate.take();
} }
} }
@@ -1666,7 +2255,8 @@ fn auto_mode_with_api_key_and_successful_probe_prefers_local() {
assert!(probe_default_local_daemon(Duration::from_millis(1))); assert!(probe_default_local_daemon(Duration::from_millis(1)));
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed"); let provider =
OllamaProvider::from_config("ollama_local", &config, None).expect("provider constructed");
assert_eq!(provider.mode, OllamaMode::Local); assert_eq!(provider.mode, OllamaMode::Local);
assert_eq!(provider.base_url, "http://localhost:11434"); assert_eq!(provider.base_url, "http://localhost:11434");
@@ -1689,7 +2279,8 @@ fn auto_mode_with_api_key_and_failed_probe_prefers_cloud() {
Value::String("auto".to_string()), Value::String("auto".to_string()),
); );
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed"); let provider =
OllamaProvider::from_config("ollama_cloud", &config, None).expect("provider constructed");
assert_eq!(provider.mode, OllamaMode::Cloud); assert_eq!(provider.mode, OllamaMode::Cloud);
assert_eq!(provider.base_url, CLOUD_BASE_URL); assert_eq!(provider.base_url, CLOUD_BASE_URL);
@@ -1706,7 +2297,8 @@ fn annotate_scope_status_adds_capabilities_for_unavailable_scopes() {
extra: HashMap::new(), extra: HashMap::new(),
}; };
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed"); let provider =
OllamaProvider::from_config("ollama_local", &config, None).expect("provider constructed");
let mut models = vec![ModelInfo { let mut models = vec![ModelInfo {
id: "llama3".to_string(), id: "llama3".to_string(),