feat(providers/ollama): add variant support, retryable tag fetching with CLI fallback, and configurable provider name for robust model listing and health checks
This commit is contained in:
@@ -127,7 +127,8 @@ fn provider_from_config() -> Result<Arc<dyn Provider>, RpcError> {
|
|||||||
|
|
||||||
match provider_cfg.provider_type.as_str() {
|
match provider_cfg.provider_type.as_str() {
|
||||||
"ollama" | "ollama_cloud" => {
|
"ollama" | "ollama_cloud" => {
|
||||||
let provider = OllamaProvider::from_config(&provider_cfg, Some(&config.general))
|
let provider =
|
||||||
|
OllamaProvider::from_config(&provider_key, &provider_cfg, Some(&config.general))
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
RpcError::internal_error(format!(
|
RpcError::internal_error(format!(
|
||||||
"Failed to init Ollama provider from config: {e}"
|
"Failed to init Ollama provider from config: {e}"
|
||||||
|
|||||||
@@ -185,7 +185,8 @@ fn build_local_provider(cfg: &Config) -> Result<Arc<dyn Provider>> {
|
|||||||
|
|
||||||
match provider_cfg.provider_type.as_str() {
|
match provider_cfg.provider_type.as_str() {
|
||||||
"ollama" | "ollama_cloud" => {
|
"ollama" | "ollama_cloud" => {
|
||||||
let provider = OllamaProvider::from_config(provider_cfg, Some(&cfg.general))?;
|
let provider =
|
||||||
|
OllamaProvider::from_config(&provider_name, provider_cfg, Some(&cfg.general))?;
|
||||||
Ok(Arc::new(provider) as Arc<dyn Provider>)
|
Ok(Arc::new(provider) as Arc<dyn Provider>)
|
||||||
}
|
}
|
||||||
other => Err(anyhow!(format!(
|
other => Err(anyhow!(format!(
|
||||||
|
|||||||
@@ -161,7 +161,7 @@ async fn status(provider: String) -> Result<()> {
|
|||||||
Value::String("cloud".to_string()),
|
Value::String("cloud".to_string()),
|
||||||
);
|
);
|
||||||
|
|
||||||
let ollama = OllamaProvider::from_config(&runtime_cfg, Some(&config.general))
|
let ollama = OllamaProvider::from_config(&provider, &runtime_cfg, Some(&config.general))
|
||||||
.with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?;
|
.with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?;
|
||||||
|
|
||||||
match ollama.health_check().await {
|
match ollama.health_check().await {
|
||||||
@@ -212,7 +212,7 @@ async fn models(provider: String) -> Result<()> {
|
|||||||
Value::String("cloud".to_string()),
|
Value::String("cloud".to_string()),
|
||||||
);
|
);
|
||||||
|
|
||||||
let ollama = OllamaProvider::from_config(&runtime_cfg, Some(&config.general))
|
let ollama = OllamaProvider::from_config(&provider, &runtime_cfg, Some(&config.general))
|
||||||
.with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?;
|
.with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?;
|
||||||
|
|
||||||
match ollama.list_models().await {
|
match ollama.list_models().await {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ use std::{
|
|||||||
env,
|
env,
|
||||||
net::{SocketAddr, TcpStream},
|
net::{SocketAddr, TcpStream},
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
|
process::Command,
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
time::{Duration, Instant, SystemTime},
|
time::{Duration, Instant, SystemTime},
|
||||||
};
|
};
|
||||||
@@ -23,11 +24,12 @@ use ollama_rs::{
|
|||||||
models::{LocalModel, ModelInfo as OllamaModelInfo, ModelOptions},
|
models::{LocalModel, ModelInfo as OllamaModelInfo, ModelOptions},
|
||||||
};
|
};
|
||||||
use reqwest::{Client, StatusCode, Url};
|
use reqwest::{Client, StatusCode, Url};
|
||||||
|
use serde::Deserialize;
|
||||||
use serde_json::{Map as JsonMap, Value, json};
|
use serde_json::{Map as JsonMap, Value, json};
|
||||||
use tokio::{sync::RwLock, time::timeout};
|
use tokio::{sync::RwLock, time::sleep};
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use std::sync::{Mutex, OnceLock};
|
use std::sync::{Mutex, MutexGuard, OnceLock};
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
use tokio_test::block_on;
|
use tokio_test::block_on;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
@@ -48,6 +50,9 @@ const DEFAULT_MODEL_CACHE_TTL_SECS: u64 = 60;
|
|||||||
pub(crate) const CLOUD_BASE_URL: &str = OLLAMA_CLOUD_BASE_URL;
|
pub(crate) const CLOUD_BASE_URL: &str = OLLAMA_CLOUD_BASE_URL;
|
||||||
const LOCAL_PROBE_TIMEOUT_MS: u64 = 200;
|
const LOCAL_PROBE_TIMEOUT_MS: u64 = 200;
|
||||||
const LOCAL_PROBE_TARGETS: &[&str] = &["127.0.0.1:11434", "[::1]:11434"];
|
const LOCAL_PROBE_TARGETS: &[&str] = &["127.0.0.1:11434", "[::1]:11434"];
|
||||||
|
const LOCAL_TAGS_TIMEOUT_STEPS_MS: [u64; 3] = [400, 800, 1_600];
|
||||||
|
const LOCAL_TAGS_RETRY_DELAYS_MS: [u64; 2] = [150, 300];
|
||||||
|
const HEALTHCHECK_TIMEOUT_MS: u64 = 1_000;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
enum OllamaMode {
|
enum OllamaMode {
|
||||||
@@ -122,8 +127,53 @@ impl ScopeSnapshot {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
struct ScopeHandle {
|
||||||
|
client: Ollama,
|
||||||
|
http_client: Client,
|
||||||
|
base_url: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ScopeHandle {
|
||||||
|
fn new(client: Ollama, http_client: Client, base_url: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
client,
|
||||||
|
http_client,
|
||||||
|
base_url: base_url.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn api_url(&self, endpoint: &str) -> String {
|
||||||
|
build_api_endpoint(&self.base_url, endpoint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct TagsResponse {
|
||||||
|
#[serde(default)]
|
||||||
|
models: Vec<LocalModel>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
enum ProviderVariant {
|
||||||
|
Local,
|
||||||
|
Cloud,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProviderVariant {
|
||||||
|
fn supports_local(self) -> bool {
|
||||||
|
matches!(self, ProviderVariant::Local)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_cloud(self) -> bool {
|
||||||
|
matches!(self, ProviderVariant::Cloud)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct OllamaOptions {
|
struct OllamaOptions {
|
||||||
|
provider_name: String,
|
||||||
|
variant: ProviderVariant,
|
||||||
mode: OllamaMode,
|
mode: OllamaMode,
|
||||||
base_url: String,
|
base_url: String,
|
||||||
request_timeout: Duration,
|
request_timeout: Duration,
|
||||||
@@ -133,8 +183,15 @@ struct OllamaOptions {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl OllamaOptions {
|
impl OllamaOptions {
|
||||||
fn new(mode: OllamaMode, base_url: impl Into<String>) -> Self {
|
fn new(
|
||||||
|
provider_name: impl Into<String>,
|
||||||
|
variant: ProviderVariant,
|
||||||
|
mode: OllamaMode,
|
||||||
|
base_url: impl Into<String>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
provider_name: provider_name.into(),
|
||||||
|
variant,
|
||||||
mode,
|
mode,
|
||||||
base_url: base_url.into(),
|
base_url: base_url.into(),
|
||||||
request_timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
|
request_timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
|
||||||
@@ -153,6 +210,8 @@ impl OllamaOptions {
|
|||||||
/// Ollama provider implementation backed by `ollama-rs`.
|
/// Ollama provider implementation backed by `ollama-rs`.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct OllamaProvider {
|
pub struct OllamaProvider {
|
||||||
|
provider_name: String,
|
||||||
|
variant: ProviderVariant,
|
||||||
mode: OllamaMode,
|
mode: OllamaMode,
|
||||||
client: Ollama,
|
client: Ollama,
|
||||||
http_client: Client,
|
http_client: Client,
|
||||||
@@ -198,6 +257,16 @@ fn is_explicit_cloud_base(base_url: Option<&str>) -> bool {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
static PROBE_OVERRIDE: OnceLock<Mutex<Option<bool>>> = OnceLock::new();
|
static PROBE_OVERRIDE: OnceLock<Mutex<Option<bool>>> = OnceLock::new();
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
static TAGS_OVERRIDE: OnceLock<Mutex<Vec<std::result::Result<Vec<LocalModel>, Error>>>> =
|
||||||
|
OnceLock::new();
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
static TAGS_OVERRIDE_GATE: OnceLock<Mutex<()>> = OnceLock::new();
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
static PROBE_OVERRIDE_GATE: OnceLock<Mutex<()>> = OnceLock::new();
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
fn set_probe_override(value: Option<bool>) {
|
fn set_probe_override(value: Option<bool>) {
|
||||||
let guard = PROBE_OVERRIDE.get_or_init(|| Mutex::new(None));
|
let guard = PROBE_OVERRIDE.get_or_init(|| Mutex::new(None));
|
||||||
@@ -213,6 +282,51 @@ fn probe_override_value() -> Option<bool> {
|
|||||||
.to_owned()
|
.to_owned()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn set_tags_override(
|
||||||
|
sequence: Vec<std::result::Result<Vec<LocalModel>, Error>>,
|
||||||
|
) -> TagsOverrideGuard {
|
||||||
|
let gate = TAGS_OVERRIDE_GATE
|
||||||
|
.get_or_init(|| Mutex::new(()))
|
||||||
|
.lock()
|
||||||
|
.expect("tags override gate mutex poisoned");
|
||||||
|
|
||||||
|
let store = TAGS_OVERRIDE.get_or_init(|| Mutex::new(Vec::new()));
|
||||||
|
{
|
||||||
|
let mut guard = store.lock().expect("tags override mutex poisoned");
|
||||||
|
guard.clear();
|
||||||
|
for item in sequence.into_iter().rev() {
|
||||||
|
guard.push(item);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TagsOverrideGuard { gate: Some(gate) }
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
fn pop_tags_override() -> Option<std::result::Result<Vec<LocalModel>, Error>> {
|
||||||
|
TAGS_OVERRIDE
|
||||||
|
.get_or_init(|| Mutex::new(Vec::new()))
|
||||||
|
.lock()
|
||||||
|
.expect("tags override mutex poisoned")
|
||||||
|
.pop()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
struct TagsOverrideGuard {
|
||||||
|
gate: Option<MutexGuard<'static, ()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
impl Drop for TagsOverrideGuard {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let Some(store) = TAGS_OVERRIDE.get() {
|
||||||
|
let mut guard = store.lock().expect("tags override mutex poisoned");
|
||||||
|
guard.clear();
|
||||||
|
}
|
||||||
|
self.gate.take();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn probe_default_local_daemon(timeout: Duration) -> bool {
|
fn probe_default_local_daemon(timeout: Duration) -> bool {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
{
|
{
|
||||||
@@ -237,14 +351,46 @@ impl OllamaProvider {
|
|||||||
let input = base_url.into();
|
let input = base_url.into();
|
||||||
let normalized =
|
let normalized =
|
||||||
normalize_base_url(Some(&input), OllamaMode::Local).map_err(Error::Config)?;
|
normalize_base_url(Some(&input), OllamaMode::Local).map_err(Error::Config)?;
|
||||||
Self::with_options(OllamaOptions::new(OllamaMode::Local, normalized))
|
Self::with_options(OllamaOptions::new(
|
||||||
|
"ollama_local",
|
||||||
|
ProviderVariant::Local,
|
||||||
|
OllamaMode::Local,
|
||||||
|
normalized,
|
||||||
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Construct a provider from configuration settings.
|
/// Construct a provider from configuration settings.
|
||||||
pub fn from_config(config: &ProviderConfig, general: Option<&GeneralSettings>) -> Result<Self> {
|
pub fn from_config(
|
||||||
|
provider_id: &str,
|
||||||
|
config: &ProviderConfig,
|
||||||
|
general: Option<&GeneralSettings>,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let provider_type = config.provider_type.trim().to_ascii_lowercase();
|
||||||
|
let register_name = {
|
||||||
|
let candidate = provider_id.trim();
|
||||||
|
if candidate.is_empty() {
|
||||||
|
if provider_type.is_empty() {
|
||||||
|
"ollama".to_string()
|
||||||
|
} else {
|
||||||
|
provider_type.clone()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
candidate.replace('-', "_")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let variant = if register_name == "ollama_cloud" || provider_type == "ollama_cloud" {
|
||||||
|
ProviderVariant::Cloud
|
||||||
|
} else {
|
||||||
|
ProviderVariant::Local
|
||||||
|
};
|
||||||
|
|
||||||
let mut api_key = resolve_api_key(config.api_key.clone())
|
let mut api_key = resolve_api_key(config.api_key.clone())
|
||||||
|
.or_else(|| resolve_api_key_env_hint(config.api_key_env.as_deref()))
|
||||||
.or_else(|| env_var_non_empty("OLLAMA_API_KEY"))
|
.or_else(|| env_var_non_empty("OLLAMA_API_KEY"))
|
||||||
.or_else(|| env_var_non_empty("OLLAMA_CLOUD_API_KEY"));
|
.or_else(|| env_var_non_empty("OLLAMA_CLOUD_API_KEY"));
|
||||||
|
let api_key_present = api_key.is_some();
|
||||||
|
|
||||||
let configured_mode = configured_mode_from_extra(config);
|
let configured_mode = configured_mode_from_extra(config);
|
||||||
let configured_mode_label = config
|
let configured_mode_label = config
|
||||||
.extra
|
.extra
|
||||||
@@ -254,7 +400,7 @@ impl OllamaProvider {
|
|||||||
let base_url = config.base_url.as_deref();
|
let base_url = config.base_url.as_deref();
|
||||||
let base_is_local = is_explicit_local_base(base_url);
|
let base_is_local = is_explicit_local_base(base_url);
|
||||||
let base_is_cloud = is_explicit_cloud_base(base_url);
|
let base_is_cloud = is_explicit_cloud_base(base_url);
|
||||||
let base_is_other = base_url.is_some() && !base_is_local && !base_is_cloud;
|
let _base_is_other = base_url.is_some() && !base_is_local && !base_is_cloud;
|
||||||
|
|
||||||
let mut local_probe_result = None;
|
let mut local_probe_result = None;
|
||||||
let cloud_endpoint = config
|
let cloud_endpoint = config
|
||||||
@@ -265,28 +411,25 @@ impl OllamaProvider {
|
|||||||
.transpose()
|
.transpose()
|
||||||
.map_err(Error::Config)?;
|
.map_err(Error::Config)?;
|
||||||
|
|
||||||
let mode = match configured_mode {
|
if matches!(variant, ProviderVariant::Local) && configured_mode.is_none() {
|
||||||
Some(mode) => mode,
|
let probe = probe_default_local_daemon(Duration::from_millis(LOCAL_PROBE_TIMEOUT_MS));
|
||||||
None => {
|
|
||||||
if base_is_local || base_is_other {
|
|
||||||
OllamaMode::Local
|
|
||||||
} else if base_is_cloud && api_key.is_some() {
|
|
||||||
OllamaMode::Cloud
|
|
||||||
} else {
|
|
||||||
let probe =
|
|
||||||
probe_default_local_daemon(Duration::from_millis(LOCAL_PROBE_TIMEOUT_MS));
|
|
||||||
local_probe_result = Some(probe);
|
local_probe_result = Some(probe);
|
||||||
if probe {
|
|
||||||
OllamaMode::Local
|
|
||||||
} else if api_key.is_some() {
|
|
||||||
OllamaMode::Cloud
|
|
||||||
} else {
|
|
||||||
OllamaMode::Local
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let mode = match variant {
|
||||||
|
ProviderVariant::Local => OllamaMode::Local,
|
||||||
|
ProviderVariant::Cloud => OllamaMode::Cloud,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if matches!(variant, ProviderVariant::Cloud) {
|
||||||
|
if !api_key_present {
|
||||||
|
return Err(Error::Config(
|
||||||
|
"Ollama Cloud API key not configured. Set providers.ollama_cloud.api_key or OLLAMA_CLOUD_API_KEY."
|
||||||
|
.into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let base_candidate = match mode {
|
let base_candidate = match mode {
|
||||||
OllamaMode::Local => base_url,
|
OllamaMode::Local => base_url,
|
||||||
OllamaMode::Cloud => {
|
OllamaMode::Cloud => {
|
||||||
@@ -301,7 +444,12 @@ impl OllamaProvider {
|
|||||||
let normalized_base_url =
|
let normalized_base_url =
|
||||||
normalize_base_url(base_candidate, mode).map_err(Error::Config)?;
|
normalize_base_url(base_candidate, mode).map_err(Error::Config)?;
|
||||||
|
|
||||||
let mut options = OllamaOptions::new(mode, normalized_base_url.clone());
|
let mut options = OllamaOptions::new(
|
||||||
|
register_name.clone(),
|
||||||
|
variant,
|
||||||
|
mode,
|
||||||
|
normalized_base_url.clone(),
|
||||||
|
);
|
||||||
options.cloud_endpoint = cloud_endpoint.clone();
|
options.cloud_endpoint = cloud_endpoint.clone();
|
||||||
|
|
||||||
if let Some(timeout) = config
|
if let Some(timeout) = config
|
||||||
@@ -327,7 +475,8 @@ impl OllamaProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
debug!(
|
debug!(
|
||||||
"Resolved Ollama provider: mode={:?}, base_url={}, configured_mode={}, api_key_present={}, local_probe={}",
|
"Resolved Ollama provider '{}': mode={:?}, base_url={}, configured_mode={}, api_key_present={}, local_probe={}",
|
||||||
|
register_name,
|
||||||
mode,
|
mode,
|
||||||
normalized_base_url,
|
normalized_base_url,
|
||||||
configured_mode_label,
|
configured_mode_label,
|
||||||
@@ -348,6 +497,8 @@ impl OllamaProvider {
|
|||||||
|
|
||||||
fn with_options(options: OllamaOptions) -> Result<Self> {
|
fn with_options(options: OllamaOptions) -> Result<Self> {
|
||||||
let OllamaOptions {
|
let OllamaOptions {
|
||||||
|
provider_name,
|
||||||
|
variant,
|
||||||
mode,
|
mode,
|
||||||
base_url,
|
base_url,
|
||||||
request_timeout,
|
request_timeout,
|
||||||
@@ -368,6 +519,8 @@ impl OllamaProvider {
|
|||||||
};
|
};
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
provider_name: provider_name.trim().to_ascii_lowercase(),
|
||||||
|
variant,
|
||||||
mode,
|
mode,
|
||||||
client: ollama_client,
|
client: ollama_client,
|
||||||
http_client,
|
http_client,
|
||||||
@@ -397,19 +550,47 @@ impl OllamaProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_local_client(&self) -> Result<Option<Ollama>> {
|
fn supports_local_scope(&self) -> bool {
|
||||||
|
self.variant.supports_local()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_cloud_scope(&self) -> bool {
|
||||||
|
self.variant.supports_cloud()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_local_client(&self) -> Result<Option<ScopeHandle>> {
|
||||||
|
if !self.supports_local_scope() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
if matches!(self.mode, OllamaMode::Local) {
|
if matches!(self.mode, OllamaMode::Local) {
|
||||||
return Ok(Some(self.client.clone()));
|
return Ok(Some(ScopeHandle::new(
|
||||||
|
self.client.clone(),
|
||||||
|
self.http_client.clone(),
|
||||||
|
self.base_url.clone(),
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let (client, _) =
|
let (client, http_client) =
|
||||||
build_client_for_base(Self::local_base_url(), self.request_timeout, None)?;
|
build_client_for_base(Self::local_base_url(), self.request_timeout, None)?;
|
||||||
Ok(Some(client))
|
Ok(Some(ScopeHandle::new(
|
||||||
|
client,
|
||||||
|
http_client,
|
||||||
|
Self::local_base_url(),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_cloud_client(&self) -> Result<Option<ScopeHandle>> {
|
||||||
|
if !self.supports_cloud_scope() {
|
||||||
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_cloud_client(&self) -> Result<Option<Ollama>> {
|
|
||||||
if matches!(self.mode, OllamaMode::Cloud) {
|
if matches!(self.mode, OllamaMode::Cloud) {
|
||||||
return Ok(Some(self.client.clone()));
|
return Ok(Some(ScopeHandle::new(
|
||||||
|
self.client.clone(),
|
||||||
|
self.http_client.clone(),
|
||||||
|
self.base_url.clone(),
|
||||||
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
let api_key = match self.api_key.as_deref() {
|
let api_key = match self.api_key.as_deref() {
|
||||||
@@ -419,8 +600,9 @@ impl OllamaProvider {
|
|||||||
|
|
||||||
let endpoint = self.cloud_endpoint.as_deref().unwrap_or(CLOUD_BASE_URL);
|
let endpoint = self.cloud_endpoint.as_deref().unwrap_or(CLOUD_BASE_URL);
|
||||||
|
|
||||||
let (client, _) = build_client_for_base(endpoint, self.request_timeout, Some(api_key))?;
|
let (client, http_client) =
|
||||||
Ok(Some(client))
|
build_client_for_base(endpoint, self.request_timeout, Some(api_key))?;
|
||||||
|
Ok(Some(ScopeHandle::new(client, http_client, endpoint)))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn cached_scope_models(&self, scope: OllamaMode) -> Option<Vec<ModelInfo>> {
|
async fn cached_scope_models(&self, scope: OllamaMode) -> Option<Vec<ModelInfo>> {
|
||||||
@@ -663,9 +845,9 @@ impl OllamaProvider {
|
|||||||
let mut seen: HashSet<String> = HashSet::new();
|
let mut seen: HashSet<String> = HashSet::new();
|
||||||
let mut errors: Vec<Error> = Vec::new();
|
let mut errors: Vec<Error> = Vec::new();
|
||||||
|
|
||||||
if let Some(local_client) = self.build_local_client()? {
|
if let Some(local_handle) = self.build_local_client()? {
|
||||||
match self
|
match self
|
||||||
.fetch_models_for_scope(OllamaMode::Local, local_client.clone())
|
.fetch_models_for_scope(OllamaMode::Local, local_handle)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(models) => {
|
Ok(models) => {
|
||||||
@@ -680,9 +862,9 @@ impl OllamaProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(cloud_client) = self.build_cloud_client()? {
|
if let Some(cloud_handle) = self.build_cloud_client()? {
|
||||||
match self
|
match self
|
||||||
.fetch_models_for_scope(OllamaMode::Cloud, cloud_client.clone())
|
.fetch_models_for_scope(OllamaMode::Cloud, cloud_handle)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(models) => {
|
Ok(models) => {
|
||||||
@@ -711,40 +893,31 @@ impl OllamaProvider {
|
|||||||
async fn fetch_models_for_scope(
|
async fn fetch_models_for_scope(
|
||||||
&self,
|
&self,
|
||||||
scope: OllamaMode,
|
scope: OllamaMode,
|
||||||
client: Ollama,
|
handle: ScopeHandle,
|
||||||
) -> Result<Vec<ModelInfo>> {
|
) -> Result<Vec<ModelInfo>> {
|
||||||
let list_result = if matches!(scope, OllamaMode::Local) {
|
let tags_result = self.fetch_scope_tags_with_retry(scope, &handle).await;
|
||||||
match timeout(
|
|
||||||
Duration::from_millis(LOCAL_PROBE_TIMEOUT_MS),
|
|
||||||
client.list_local_models(),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(result) => result.map_err(|err| self.map_ollama_error("list models", err, None)),
|
|
||||||
Err(_) => Err(Error::Timeout(
|
|
||||||
"Timed out while contacting the local Ollama daemon".to_string(),
|
|
||||||
)),
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
client
|
|
||||||
.list_local_models()
|
|
||||||
.await
|
|
||||||
.map_err(|err| self.map_ollama_error("list models", err, None))
|
|
||||||
};
|
|
||||||
|
|
||||||
let models = match list_result {
|
let models = match tags_result {
|
||||||
Ok(models) => models,
|
Ok(models) => models,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
let message = err.to_string();
|
let original_detail = err.to_string();
|
||||||
self.mark_scope_failure(scope, message).await;
|
let (error, banner) = self.decorate_scope_error(scope, &handle.base_url, err);
|
||||||
|
if banner != original_detail {
|
||||||
|
debug!(
|
||||||
|
"Model list for {:?} at {} failed: {}",
|
||||||
|
scope, handle.base_url, original_detail
|
||||||
|
);
|
||||||
|
}
|
||||||
|
self.mark_scope_failure(scope, banner.clone()).await;
|
||||||
if let Some(cached) = self.cached_scope_models(scope).await {
|
if let Some(cached) = self.cached_scope_models(scope).await {
|
||||||
return Ok(cached);
|
return Ok(cached);
|
||||||
}
|
}
|
||||||
return Err(err);
|
return Err(error);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let cache = self.model_details_cache.clone();
|
let cache = self.model_details_cache.clone();
|
||||||
|
let client = handle.client.clone();
|
||||||
let fetched = join_all(models.into_iter().map(|local| {
|
let fetched = join_all(models.into_iter().map(|local| {
|
||||||
let client = client.clone();
|
let client = client.clone();
|
||||||
let cache = cache.clone();
|
let cache = cache.clone();
|
||||||
@@ -780,6 +953,186 @@ impl OllamaProvider {
|
|||||||
Ok(converted)
|
Ok(converted)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn fetch_scope_tags_with_retry(
|
||||||
|
&self,
|
||||||
|
scope: OllamaMode,
|
||||||
|
handle: &ScopeHandle,
|
||||||
|
) -> Result<Vec<LocalModel>> {
|
||||||
|
let attempts = if matches!(scope, OllamaMode::Local) {
|
||||||
|
LOCAL_TAGS_TIMEOUT_STEPS_MS.len()
|
||||||
|
} else {
|
||||||
|
1
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut last_error: Option<Error> = None;
|
||||||
|
|
||||||
|
for attempt in 0..attempts {
|
||||||
|
match self.fetch_scope_tags_once(scope, handle, attempt).await {
|
||||||
|
Ok(models) => return Ok(models),
|
||||||
|
Err(err) => {
|
||||||
|
let should_retry = matches!(scope, OllamaMode::Local)
|
||||||
|
&& attempt + 1 < attempts
|
||||||
|
&& matches!(err, Error::Timeout(_) | Error::Network(_));
|
||||||
|
|
||||||
|
if should_retry {
|
||||||
|
debug!(
|
||||||
|
"Retrying Ollama model list for {:?} (attempt {}): {}",
|
||||||
|
scope,
|
||||||
|
attempt + 1,
|
||||||
|
err
|
||||||
|
);
|
||||||
|
last_error = Some(err);
|
||||||
|
sleep(self.tags_retry_delay(attempt)).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(last_error
|
||||||
|
.unwrap_or_else(|| Error::Unknown("Ollama model list retries exhausted".to_string())))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn fetch_scope_tags_once(
|
||||||
|
&self,
|
||||||
|
scope: OllamaMode,
|
||||||
|
handle: &ScopeHandle,
|
||||||
|
attempt: usize,
|
||||||
|
) -> Result<Vec<LocalModel>> {
|
||||||
|
#[cfg(test)]
|
||||||
|
if let Some(result) = pop_tags_override() {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
if matches!(scope, OllamaMode::Local) {
|
||||||
|
match self.list_local_models_via_cli() {
|
||||||
|
Ok(models) => return Ok(models),
|
||||||
|
Err(err) => {
|
||||||
|
debug!("`ollama ls` failed ({}); falling back to HTTP listing", err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let url = handle.api_url("tags");
|
||||||
|
let response = handle
|
||||||
|
.http_client
|
||||||
|
.get(&url)
|
||||||
|
.timeout(self.tags_request_timeout(scope, attempt))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|err| map_reqwest_error("list models", err))?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
let status = response.status();
|
||||||
|
let detail = response.text().await.unwrap_or_else(|err| err.to_string());
|
||||||
|
return Err(self.map_http_failure("list models", status, detail, None));
|
||||||
|
}
|
||||||
|
|
||||||
|
let bytes = response
|
||||||
|
.bytes()
|
||||||
|
.await
|
||||||
|
.map_err(|err| map_reqwest_error("list models", err))?;
|
||||||
|
let parsed: TagsResponse = serde_json::from_slice(&bytes)?;
|
||||||
|
Ok(parsed.models)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tags_request_timeout(&self, scope: OllamaMode, attempt: usize) -> Duration {
|
||||||
|
if matches!(scope, OllamaMode::Local) {
|
||||||
|
let idx = attempt.min(LOCAL_TAGS_TIMEOUT_STEPS_MS.len() - 1);
|
||||||
|
Duration::from_millis(LOCAL_TAGS_TIMEOUT_STEPS_MS[idx])
|
||||||
|
} else {
|
||||||
|
self.request_timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tags_retry_delay(&self, attempt: usize) -> Duration {
|
||||||
|
let idx = attempt.min(LOCAL_TAGS_RETRY_DELAYS_MS.len() - 1);
|
||||||
|
Duration::from_millis(LOCAL_TAGS_RETRY_DELAYS_MS[idx])
|
||||||
|
}
|
||||||
|
|
||||||
|
fn list_local_models_via_cli(&self) -> Result<Vec<LocalModel>> {
|
||||||
|
let output = Command::new("ollama")
|
||||||
|
.arg("ls")
|
||||||
|
.output()
|
||||||
|
.map_err(|err| {
|
||||||
|
Error::Provider(anyhow!(
|
||||||
|
"Failed to execute `ollama ls`: {err}. Ensure the Ollama CLI is installed and accessible in PATH."
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if !output.status.success() {
|
||||||
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
|
return Err(Error::Provider(anyhow!(
|
||||||
|
"`ollama ls` exited with status {}: {}",
|
||||||
|
output.status,
|
||||||
|
stderr.trim()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let stdout = String::from_utf8(output.stdout).map_err(|err| {
|
||||||
|
Error::Provider(anyhow!("`ollama ls` returned non-UTF8 output: {err}"))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut models = Vec::new();
|
||||||
|
for line in stdout.lines() {
|
||||||
|
let trimmed = line.trim();
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let lowercase = trimmed.to_ascii_lowercase();
|
||||||
|
if lowercase.starts_with("name") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut parts = trimmed.split_whitespace();
|
||||||
|
let Some(name) = parts.next() else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let metadata_start = trimmed[name.len()..].trim();
|
||||||
|
|
||||||
|
models.push(LocalModel {
|
||||||
|
name: name.to_string(),
|
||||||
|
modified_at: metadata_start.to_string(),
|
||||||
|
size: 0,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(models)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decorate_scope_error(
|
||||||
|
&self,
|
||||||
|
scope: OllamaMode,
|
||||||
|
base_url: &str,
|
||||||
|
err: Error,
|
||||||
|
) -> (Error, String) {
|
||||||
|
if matches!(scope, OllamaMode::Local) {
|
||||||
|
match err {
|
||||||
|
Error::Timeout(_) => {
|
||||||
|
let message = format_local_unreachable_message(base_url);
|
||||||
|
(Error::Timeout(message.clone()), message)
|
||||||
|
}
|
||||||
|
Error::Network(original) => {
|
||||||
|
if is_connectivity_error(&original) {
|
||||||
|
let message = format_local_unreachable_message(base_url);
|
||||||
|
(Error::Network(message.clone()), message)
|
||||||
|
} else {
|
||||||
|
let message = original.clone();
|
||||||
|
(Error::Network(original), message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
other => {
|
||||||
|
let message = other.to_string();
|
||||||
|
(other, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let message = err.to_string();
|
||||||
|
(err, message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn convert_detailed_model_info(
|
fn convert_detailed_model_info(
|
||||||
mode: OllamaMode,
|
mode: OllamaMode,
|
||||||
model_name: &str,
|
model_name: &str,
|
||||||
@@ -893,7 +1246,7 @@ impl OllamaProvider {
|
|||||||
id: name.clone(),
|
id: name.clone(),
|
||||||
name,
|
name,
|
||||||
description: Some(description),
|
description: Some(description),
|
||||||
provider: "ollama".to_string(),
|
provider: self.provider_name.clone(),
|
||||||
context_window: None,
|
context_window: None,
|
||||||
capabilities,
|
capabilities,
|
||||||
supports_tools: false,
|
supports_tools: false,
|
||||||
@@ -948,7 +1301,7 @@ impl OllamaProvider {
|
|||||||
StatusCode::NOT_FOUND => {
|
StatusCode::NOT_FOUND => {
|
||||||
if let Some(model) = model {
|
if let Some(model) = model {
|
||||||
Error::InvalidInput(format!(
|
Error::InvalidInput(format!(
|
||||||
"Model '{model}' was not found at {}. Verify the name or pull it with `ollama pull`.",
|
"Model '{model}' was not found at {}. Verify the name or pull it with `ollama pull {model}`.",
|
||||||
self.base_url
|
self.base_url
|
||||||
))
|
))
|
||||||
} else {
|
} else {
|
||||||
@@ -992,7 +1345,7 @@ impl LlmProvider for OllamaProvider {
|
|||||||
Self: 'a;
|
Self: 'a;
|
||||||
|
|
||||||
fn name(&self) -> &str {
|
fn name(&self) -> &str {
|
||||||
"ollama"
|
&self.provider_name
|
||||||
}
|
}
|
||||||
|
|
||||||
fn list_models(&self) -> Self::ListModelsFuture<'_> {
|
fn list_models(&self) -> Self::ListModelsFuture<'_> {
|
||||||
@@ -1056,10 +1409,11 @@ impl LlmProvider for OllamaProvider {
|
|||||||
|
|
||||||
fn health_check(&self) -> Self::HealthCheckFuture<'_> {
|
fn health_check(&self) -> Self::HealthCheckFuture<'_> {
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let url = self.api_url("version");
|
let url = self.api_url("tags?limit=1");
|
||||||
let response = self
|
let response = self
|
||||||
.http_client
|
.http_client
|
||||||
.get(&url)
|
.get(&url)
|
||||||
|
.timeout(Duration::from_millis(HEALTHCHECK_TIMEOUT_MS))
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.map_err(|err| map_reqwest_error("health check", err))?;
|
.map_err(|err| map_reqwest_error("health check", err))?;
|
||||||
@@ -1364,6 +1718,46 @@ fn value_to_u64(value: &Value) -> Option<u64> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn format_local_unreachable_message(base_url: &str) -> String {
|
||||||
|
let display = display_host_port(base_url);
|
||||||
|
format!(
|
||||||
|
"Ollama not reachable on {display}. Start the Ollama daemon (`ollama serve`) and try again."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn display_host_port(base_url: &str) -> String {
|
||||||
|
Url::parse(base_url)
|
||||||
|
.ok()
|
||||||
|
.and_then(|url| {
|
||||||
|
url.host_str().map(|host| {
|
||||||
|
if let Some(port) = url.port() {
|
||||||
|
format!("{host}:{port}")
|
||||||
|
} else {
|
||||||
|
host.to_string()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.unwrap_or_else(|| base_url.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_connectivity_error(message: &str) -> bool {
|
||||||
|
let lower = message.to_ascii_lowercase();
|
||||||
|
const CONNECTIVITY_MARKERS: &[&str] = &[
|
||||||
|
"connection refused",
|
||||||
|
"failed to connect",
|
||||||
|
"connect timeout",
|
||||||
|
"timed out while contacting",
|
||||||
|
"dns error",
|
||||||
|
"failed to lookup address",
|
||||||
|
"no route to host",
|
||||||
|
"host unreachable",
|
||||||
|
];
|
||||||
|
|
||||||
|
CONNECTIVITY_MARKERS
|
||||||
|
.iter()
|
||||||
|
.any(|marker| lower.contains(marker))
|
||||||
|
}
|
||||||
|
|
||||||
fn env_var_non_empty(name: &str) -> Option<String> {
|
fn env_var_non_empty(name: &str) -> Option<String> {
|
||||||
env::var(name)
|
env::var(name)
|
||||||
.ok()
|
.ok()
|
||||||
@@ -1371,6 +1765,13 @@ fn env_var_non_empty(name: &str) -> Option<String> {
|
|||||||
.filter(|value| !value.is_empty())
|
.filter(|value| !value.is_empty())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn resolve_api_key_env_hint(env_var: Option<&str>) -> Option<String> {
|
||||||
|
env_var
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.and_then(env_var_non_empty)
|
||||||
|
}
|
||||||
|
|
||||||
fn resolve_api_key(configured: Option<String>) -> Option<String> {
|
fn resolve_api_key(configured: Option<String>) -> Option<String> {
|
||||||
let raw = configured?.trim().to_string();
|
let raw = configured?.trim().to_string();
|
||||||
if raw.is_empty() {
|
if raw.is_empty() {
|
||||||
@@ -1545,7 +1946,8 @@ mod tests {
|
|||||||
Value::String("local".to_string()),
|
Value::String("local".to_string()),
|
||||||
);
|
);
|
||||||
|
|
||||||
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");
|
let provider = OllamaProvider::from_config("ollama_local", &config, None)
|
||||||
|
.expect("provider constructed");
|
||||||
|
|
||||||
assert_eq!(provider.mode, OllamaMode::Local);
|
assert_eq!(provider.mode, OllamaMode::Local);
|
||||||
assert_eq!(provider.base_url, "http://localhost:11434");
|
assert_eq!(provider.base_url, "http://localhost:11434");
|
||||||
@@ -1563,7 +1965,8 @@ mod tests {
|
|||||||
};
|
};
|
||||||
// simulate missing explicit mode; defaults to auto
|
// simulate missing explicit mode; defaults to auto
|
||||||
|
|
||||||
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");
|
let provider = OllamaProvider::from_config("ollama_local", &config, None)
|
||||||
|
.expect("provider constructed");
|
||||||
|
|
||||||
assert_eq!(provider.mode, OllamaMode::Local);
|
assert_eq!(provider.mode, OllamaMode::Local);
|
||||||
assert_eq!(provider.base_url, "http://localhost:11434");
|
assert_eq!(provider.base_url, "http://localhost:11434");
|
||||||
@@ -1584,12 +1987,191 @@ mod tests {
|
|||||||
Value::String("auto".to_string()),
|
Value::String("auto".to_string()),
|
||||||
);
|
);
|
||||||
|
|
||||||
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");
|
let provider = OllamaProvider::from_config("ollama_cloud", &config, None)
|
||||||
|
.expect("provider constructed");
|
||||||
|
|
||||||
assert_eq!(provider.mode, OllamaMode::Cloud);
|
assert_eq!(provider.mode, OllamaMode::Cloud);
|
||||||
assert_eq!(provider.base_url, CLOUD_BASE_URL);
|
assert_eq!(provider.base_url, CLOUD_BASE_URL);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cloud_provider_requires_api_key() {
|
||||||
|
let config = ProviderConfig {
|
||||||
|
enabled: true,
|
||||||
|
provider_type: "ollama_cloud".to_string(),
|
||||||
|
base_url: None,
|
||||||
|
api_key: None,
|
||||||
|
api_key_env: None,
|
||||||
|
extra: HashMap::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let err = OllamaProvider::from_config("ollama_cloud", &config, None)
|
||||||
|
.expect_err("expected config error");
|
||||||
|
match err {
|
||||||
|
Error::Config(message) => {
|
||||||
|
assert!(message.contains("API key"));
|
||||||
|
}
|
||||||
|
other => panic!("unexpected error variant: {other:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cloud_provider_uses_explicit_api_key() {
|
||||||
|
let config = ProviderConfig {
|
||||||
|
enabled: true,
|
||||||
|
provider_type: "ollama_cloud".to_string(),
|
||||||
|
base_url: None,
|
||||||
|
api_key: Some("secret-cloud-key".to_string()),
|
||||||
|
api_key_env: None,
|
||||||
|
extra: HashMap::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let provider = OllamaProvider::from_config("ollama_cloud", &config, None)
|
||||||
|
.expect("provider constructed");
|
||||||
|
assert_eq!(provider.name(), "ollama_cloud");
|
||||||
|
assert_eq!(provider.mode, OllamaMode::Cloud);
|
||||||
|
assert_eq!(provider.base_url, CLOUD_BASE_URL);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cloud_provider_reads_api_key_from_env_hint() {
|
||||||
|
let config = ProviderConfig {
|
||||||
|
enabled: true,
|
||||||
|
provider_type: "ollama_cloud".to_string(),
|
||||||
|
base_url: None,
|
||||||
|
api_key: None,
|
||||||
|
api_key_env: Some("OLLAMA_TEST_CLOUD_KEY".to_string()),
|
||||||
|
extra: HashMap::new(),
|
||||||
|
};
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
std::env::set_var("OLLAMA_TEST_CLOUD_KEY", "env-secret");
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(std::env::var("OLLAMA_TEST_CLOUD_KEY").is_ok());
|
||||||
|
assert!(resolve_api_key_env_hint(config.api_key_env.as_deref()).is_some());
|
||||||
|
assert_eq!(config.api_key_env.as_deref(), Some("OLLAMA_TEST_CLOUD_KEY"));
|
||||||
|
|
||||||
|
let provider = OllamaProvider::from_config("ollama_cloud", &config, None)
|
||||||
|
.expect("provider constructed");
|
||||||
|
assert_eq!(provider.name(), "ollama_cloud");
|
||||||
|
assert_eq!(provider.mode, OllamaMode::Cloud);
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
std::env::remove_var("OLLAMA_TEST_CLOUD_KEY");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn fetch_scope_tags_with_retry_success_uses_override() {
|
||||||
|
let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
|
||||||
|
let handle = ScopeHandle::new(
|
||||||
|
provider.client.clone(),
|
||||||
|
provider.http_client.clone(),
|
||||||
|
provider.base_url.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let _guard = set_tags_override(vec![Ok(vec![LocalModel {
|
||||||
|
name: "llama3".into(),
|
||||||
|
modified_at: "2024-01-01T00:00:00Z".into(),
|
||||||
|
size: 42,
|
||||||
|
}])]);
|
||||||
|
|
||||||
|
let models = block_on(provider.fetch_scope_tags_with_retry(OllamaMode::Local, &handle))
|
||||||
|
.expect("models returned");
|
||||||
|
assert_eq!(models.len(), 1);
|
||||||
|
assert_eq!(models[0].name, "llama3");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn fetch_scope_tags_with_retry_retries_on_timeout_then_succeeds() {
|
||||||
|
let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
|
||||||
|
let handle = ScopeHandle::new(
|
||||||
|
provider.client.clone(),
|
||||||
|
provider.http_client.clone(),
|
||||||
|
provider.base_url.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let _guard = set_tags_override(vec![
|
||||||
|
Err(Error::Timeout("first attempt".into())),
|
||||||
|
Ok(vec![LocalModel {
|
||||||
|
name: "llama3".into(),
|
||||||
|
modified_at: "2024-01-01T00:00:00Z".into(),
|
||||||
|
size: 42,
|
||||||
|
}]),
|
||||||
|
]);
|
||||||
|
|
||||||
|
let models = block_on(provider.fetch_scope_tags_with_retry(OllamaMode::Local, &handle))
|
||||||
|
.expect("models returned after retry");
|
||||||
|
assert_eq!(models.len(), 1);
|
||||||
|
assert_eq!(models[0].name, "llama3");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decorate_scope_error_returns_friendly_message_for_connectivity() {
|
||||||
|
let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
|
||||||
|
let (error, message) = provider.decorate_scope_error(
|
||||||
|
OllamaMode::Local,
|
||||||
|
"http://localhost:11434",
|
||||||
|
Error::Network("failed to connect to host".into()),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
error,
|
||||||
|
Error::Network(ref text) if text.contains("Ollama not reachable")
|
||||||
|
));
|
||||||
|
assert!(message.contains("Ollama not reachable"));
|
||||||
|
assert!(message.contains("localhost:11434"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decorate_scope_error_preserves_http_failure_message() {
|
||||||
|
let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
|
||||||
|
let original = "Ollama list models failed (500): boom".to_string();
|
||||||
|
let (error, message) = provider.decorate_scope_error(
|
||||||
|
OllamaMode::Local,
|
||||||
|
"http://localhost:11434",
|
||||||
|
Error::Network(original.clone()),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(matches!(error, Error::Network(ref text) if text.contains("500")));
|
||||||
|
assert_eq!(message, original);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decorate_scope_error_translates_timeout() {
|
||||||
|
let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
|
||||||
|
let (error, message) = provider.decorate_scope_error(
|
||||||
|
OllamaMode::Local,
|
||||||
|
"http://localhost:11434",
|
||||||
|
Error::Timeout("deadline exceeded".into()),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert!(matches!(
|
||||||
|
error,
|
||||||
|
Error::Timeout(ref text) if text.contains("Ollama not reachable")
|
||||||
|
));
|
||||||
|
assert!(message.contains("Ollama not reachable"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn map_http_failure_model_not_found_suggests_pull_hint() {
|
||||||
|
let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
|
||||||
|
let err = provider.map_http_failure(
|
||||||
|
"chat",
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
"missing model".to_string(),
|
||||||
|
Some("llama3"),
|
||||||
|
);
|
||||||
|
|
||||||
|
let message = match err {
|
||||||
|
Error::InvalidInput(message) => message,
|
||||||
|
other => panic!("unexpected error variant: {other:?}"),
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(message.contains("ollama pull llama3"));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn build_model_options_merges_parameters() {
|
fn build_model_options_merges_parameters() {
|
||||||
let mut parameters = ChatParameters::default();
|
let mut parameters = ChatParameters::default();
|
||||||
@@ -1630,13 +2212,19 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
struct ProbeOverrideGuard;
|
struct ProbeOverrideGuard {
|
||||||
|
gate: Option<MutexGuard<'static, ()>>,
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
impl ProbeOverrideGuard {
|
impl ProbeOverrideGuard {
|
||||||
fn set(value: Option<bool>) -> Self {
|
fn set(value: Option<bool>) -> Self {
|
||||||
|
let gate = PROBE_OVERRIDE_GATE
|
||||||
|
.get_or_init(|| Mutex::new(()))
|
||||||
|
.lock()
|
||||||
|
.expect("probe override gate mutex poisoned");
|
||||||
set_probe_override(value);
|
set_probe_override(value);
|
||||||
ProbeOverrideGuard
|
ProbeOverrideGuard { gate: Some(gate) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1644,6 +2232,7 @@ impl ProbeOverrideGuard {
|
|||||||
impl Drop for ProbeOverrideGuard {
|
impl Drop for ProbeOverrideGuard {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
set_probe_override(None);
|
set_probe_override(None);
|
||||||
|
self.gate.take();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1666,7 +2255,8 @@ fn auto_mode_with_api_key_and_successful_probe_prefers_local() {
|
|||||||
|
|
||||||
assert!(probe_default_local_daemon(Duration::from_millis(1)));
|
assert!(probe_default_local_daemon(Duration::from_millis(1)));
|
||||||
|
|
||||||
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");
|
let provider =
|
||||||
|
OllamaProvider::from_config("ollama_local", &config, None).expect("provider constructed");
|
||||||
|
|
||||||
assert_eq!(provider.mode, OllamaMode::Local);
|
assert_eq!(provider.mode, OllamaMode::Local);
|
||||||
assert_eq!(provider.base_url, "http://localhost:11434");
|
assert_eq!(provider.base_url, "http://localhost:11434");
|
||||||
@@ -1689,7 +2279,8 @@ fn auto_mode_with_api_key_and_failed_probe_prefers_cloud() {
|
|||||||
Value::String("auto".to_string()),
|
Value::String("auto".to_string()),
|
||||||
);
|
);
|
||||||
|
|
||||||
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");
|
let provider =
|
||||||
|
OllamaProvider::from_config("ollama_cloud", &config, None).expect("provider constructed");
|
||||||
|
|
||||||
assert_eq!(provider.mode, OllamaMode::Cloud);
|
assert_eq!(provider.mode, OllamaMode::Cloud);
|
||||||
assert_eq!(provider.base_url, CLOUD_BASE_URL);
|
assert_eq!(provider.base_url, CLOUD_BASE_URL);
|
||||||
@@ -1706,7 +2297,8 @@ fn annotate_scope_status_adds_capabilities_for_unavailable_scopes() {
|
|||||||
extra: HashMap::new(),
|
extra: HashMap::new(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");
|
let provider =
|
||||||
|
OllamaProvider::from_config("ollama_local", &config, None).expect("provider constructed");
|
||||||
|
|
||||||
let mut models = vec![ModelInfo {
|
let mut models = vec![ModelInfo {
|
||||||
id: "llama3".to_string(),
|
id: "llama3".to_string(),
|
||||||
|
|||||||
Reference in New Issue
Block a user