//! Ollama provider built on top of the `ollama-rs` crate.
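//!
//! A minimal usage sketch (construction paths are illustrative and the block
//! is `ignore`d so it is not compiled as a doctest):
//!
//! ```ignore
//! let provider = OllamaProvider::new("http://localhost:11434")?;
//! let models = provider.list_models().await?;
//! ```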

use std::{
    collections::{HashMap, HashSet},
    convert::TryFrom,
    env, fs,
    net::{SocketAddr, TcpStream},
    pin::Pin,
    process::Command,
    sync::{Arc, OnceLock},
    time::{Duration, Instant, SystemTime},
};

use anyhow::anyhow;
use base64::{Engine, engine::general_purpose::STANDARD as BASE64_STANDARD};
use futures::{Stream, StreamExt, future::BoxFuture, future::join_all};
use log::{debug, warn};
use ollama_rs::{
    Ollama,
    error::OllamaError,
    generation::tools::{
        ToolCall as OllamaToolCall, ToolCallFunction as OllamaToolCallFunction,
        ToolInfo as OllamaToolInfo,
    },
    generation::{
        chat::{
            ChatMessage as OllamaMessage, ChatMessageResponse as OllamaChatResponse,
            MessageRole as OllamaRole, request::ChatMessageRequest as OllamaChatRequest,
        },
        images::Image,
    },
    headers::{AUTHORIZATION, HeaderMap, HeaderValue},
    models::{LocalModel, ModelInfo as OllamaModelInfo, ModelOptions},
};
use reqwest::{Client, StatusCode, Url};
use serde::Deserialize;
use serde_json::{Map as JsonMap, Value, json};
use tokio::{sync::RwLock, time::sleep};

#[cfg(test)]
use std::sync::{Mutex, MutexGuard};
#[cfg(test)]
use tokio_test::block_on;
use uuid::Uuid;

use crate::{
    Error, Result,
    config::{
        DEFAULT_OLLAMA_CLOUD_HOURLY_QUOTA, DEFAULT_OLLAMA_CLOUD_WEEKLY_QUOTA, GeneralSettings,
        LEGACY_OLLAMA_CLOUD_API_KEY_ENV, LEGACY_OWLEN_OLLAMA_CLOUD_API_KEY_ENV, OLLAMA_API_KEY_ENV,
        OLLAMA_CLOUD_BASE_URL, OLLAMA_CLOUD_ENDPOINT_KEY, OLLAMA_MODE_KEY,
    },
    llm::{LlmProvider, ProviderConfig},
    mcp::McpToolDescriptor,
    model::{DetailedModelInfo, ModelDetailsCache, ModelManager},
    provider::{ProviderError, ProviderErrorKind},
    types::{
        ChatParameters, ChatRequest, ChatResponse, Message, MessageAttachment, ModelInfo, Role,
        TokenUsage, ToolCall,
    },
};

const DEFAULT_TIMEOUT_SECS: u64 = 120;
const DEFAULT_MODEL_CACHE_TTL_SECS: u64 = 60;
pub(crate) const CLOUD_BASE_URL: &str = OLLAMA_CLOUD_BASE_URL;
const LOCAL_PROBE_TIMEOUT_MS: u64 = 200;
const LOCAL_PROBE_TARGETS: &[&str] = &["127.0.0.1:11434", "[::1]:11434"];
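// Escalating per-attempt timeouts and retry backoff delays used when listing
// local models via `/api/tags`; see `fetch_scope_tags_with_retry` below.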
const LOCAL_TAGS_TIMEOUT_STEPS_MS: [u64; 3] = [400, 800, 1_600];
const LOCAL_TAGS_RETRY_DELAYS_MS: [u64; 2] = [150, 300];
const HEALTHCHECK_TIMEOUT_MS: u64 = 1_000;

static LEGACY_CLOUD_ENV_WARNING: OnceLock<()> = OnceLock::new();

fn warn_legacy_cloud_env(var_name: &str) {
    if LEGACY_CLOUD_ENV_WARNING.set(()).is_ok() {
        warn!(
            "Using legacy Ollama Cloud API key environment variable `{var_name}`. \
             Prefer configuring OLLAMA_API_KEY; legacy names remain supported but may be removed."
        );
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum OllamaMode {
    Local,
    Cloud,
}

impl OllamaMode {
    fn default_base_url(self) -> &'static str {
        match self {
            Self::Local => "http://localhost:11434",
            Self::Cloud => CLOUD_BASE_URL,
        }
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScopeAvailability {
    Unknown,
    Available,
    Unavailable,
}

impl ScopeAvailability {
    fn as_str(self) -> &'static str {
        match self {
            ScopeAvailability::Unknown => "unknown",
            ScopeAvailability::Available => "available",
            ScopeAvailability::Unavailable => "unavailable",
        }
    }
}

#[derive(Debug, Clone)]
struct ScopeSnapshot {
    models: Vec<ModelInfo>,
    fetched_at: Option<Instant>,
    availability: ScopeAvailability,
    last_error: Option<String>,
    last_checked: Option<Instant>,
    last_success_at: Option<Instant>,
}

impl Default for ScopeSnapshot {
    fn default() -> Self {
        Self {
            models: Vec::new(),
            fetched_at: None,
            availability: ScopeAvailability::Unknown,
            last_error: None,
            last_checked: None,
            last_success_at: None,
        }
    }
}

impl ScopeSnapshot {
    fn is_stale(&self, ttl: Duration) -> bool {
        match self.fetched_at {
            Some(ts) => ts.elapsed() >= ttl,
            None => !self.models.is_empty(),
        }
    }

    fn last_checked_age_secs(&self) -> Option<u64> {
        self.last_checked.map(|instant| instant.elapsed().as_secs())
    }

    fn last_success_age_secs(&self) -> Option<u64> {
        self.last_success_at
            .map(|instant| instant.elapsed().as_secs())
    }
}

#[derive(Clone)]
struct ScopeHandle {
    client: Ollama,
    http_client: Client,
    base_url: String,
}

impl ScopeHandle {
    fn new(client: Ollama, http_client: Client, base_url: impl Into<String>) -> Self {
        Self {
            client,
            http_client,
            base_url: base_url.into(),
        }
    }

    fn api_url(&self, endpoint: &str) -> String {
        build_api_endpoint(&self.base_url, endpoint)
    }
}

#[derive(Debug, Deserialize)]
struct TagsResponse {
    #[serde(default)]
    models: Vec<LocalModel>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ProviderVariant {
    Local,
    Cloud,
}

impl ProviderVariant {
    fn supports_local(self) -> bool {
        matches!(self, ProviderVariant::Local)
    }

    fn supports_cloud(self) -> bool {
        matches!(self, ProviderVariant::Cloud)
    }
}

#[derive(Debug)]
struct OllamaOptions {
    provider_name: String,
    variant: ProviderVariant,
    mode: OllamaMode,
    base_url: String,
    request_timeout: Duration,
    model_cache_ttl: Duration,
    api_key: Option<String>,
    cloud_endpoint: Option<String>,
}

impl OllamaOptions {
    fn new(
        provider_name: impl Into<String>,
        variant: ProviderVariant,
        mode: OllamaMode,
        base_url: impl Into<String>,
    ) -> Self {
        Self {
            provider_name: provider_name.into(),
            variant,
            mode,
            base_url: base_url.into(),
            request_timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
            model_cache_ttl: Duration::from_secs(DEFAULT_MODEL_CACHE_TTL_SECS),
            api_key: None,
            cloud_endpoint: None,
        }
    }

    fn with_general(mut self, general: &GeneralSettings) -> Self {
        self.model_cache_ttl = general.model_cache_ttl();
        self
    }
}

/// Ollama provider implementation backed by `ollama-rs`.
#[derive(Debug)]
pub struct OllamaProvider {
    provider_name: String,
    variant: ProviderVariant,
    mode: OllamaMode,
    client: Ollama,
    http_client: Client,
    base_url: String,
    request_timeout: Duration,
    api_key: Option<String>,
    cloud_endpoint: Option<String>,
    model_manager: ModelManager,
    model_details_cache: ModelDetailsCache,
    model_cache_ttl: Duration,
    scope_cache: Arc<RwLock<HashMap<OllamaMode, ScopeSnapshot>>>,
}

fn configured_mode_from_extra(config: &ProviderConfig) -> Option<OllamaMode> {
    config
        .extra
        .get(OLLAMA_MODE_KEY)
        .and_then(|value| value.as_str())
        .and_then(|value| match value.trim().to_ascii_lowercase().as_str() {
            "local" => Some(OllamaMode::Local),
            "cloud" => Some(OllamaMode::Cloud),
            _ => None,
        })
}

fn is_explicit_local_base(base_url: Option<&str>) -> bool {
    base_url
        .and_then(|raw| Url::parse(raw).ok())
        .and_then(|parsed| parsed.host_str().map(|host| host.to_ascii_lowercase()))
        .map(|host| host == "localhost" || host == "127.0.0.1" || host == "::1")
        .unwrap_or(false)
}

fn is_explicit_cloud_base(base_url: Option<&str>) -> bool {
    base_url
        .map(|raw| {
            let trimmed = raw.trim_end_matches('/');
            trimmed == CLOUD_BASE_URL || trimmed.starts_with("https://ollama.com/")
        })
        .unwrap_or(false)
}

#[cfg(test)]
static PROBE_OVERRIDE: OnceLock<Mutex<Option<bool>>> = OnceLock::new();

#[cfg(test)]
static TAGS_OVERRIDE: OnceLock<Mutex<Vec<std::result::Result<Vec<LocalModel>, Error>>>> =
    OnceLock::new();

#[cfg(test)]
static TAGS_OVERRIDE_GATE: OnceLock<Mutex<()>> = OnceLock::new();

#[cfg(test)]
static PROBE_OVERRIDE_GATE: OnceLock<Mutex<()>> = OnceLock::new();

#[cfg(test)]
fn set_probe_override(value: Option<bool>) {
    let guard = PROBE_OVERRIDE.get_or_init(|| Mutex::new(None));
    *guard.lock().expect("probe override mutex poisoned") = value;
}

#[cfg(test)]
fn probe_override_value() -> Option<bool> {
    PROBE_OVERRIDE
        .get_or_init(|| Mutex::new(None))
        .lock()
        .expect("probe override mutex poisoned")
        .to_owned()
}

#[cfg(test)]
fn set_tags_override(
    sequence: Vec<std::result::Result<Vec<LocalModel>, Error>>,
) -> TagsOverrideGuard {
    let gate = TAGS_OVERRIDE_GATE
        .get_or_init(|| Mutex::new(()))
        .lock()
        .expect("tags override gate mutex poisoned");

    let store = TAGS_OVERRIDE.get_or_init(|| Mutex::new(Vec::new()));
    {
        let mut guard = store.lock().expect("tags override mutex poisoned");
        guard.clear();
        for item in sequence.into_iter().rev() {
            guard.push(item);
        }
    }
    TagsOverrideGuard { gate: Some(gate) }
}

#[cfg(test)]
fn pop_tags_override() -> Option<std::result::Result<Vec<LocalModel>, Error>> {
    TAGS_OVERRIDE
        .get_or_init(|| Mutex::new(Vec::new()))
        .lock()
        .expect("tags override mutex poisoned")
        .pop()
}

#[cfg(test)]
struct TagsOverrideGuard {
    gate: Option<MutexGuard<'static, ()>>,
}

#[cfg(test)]
impl Drop for TagsOverrideGuard {
    fn drop(&mut self) {
        if let Some(store) = TAGS_OVERRIDE.get() {
            let mut guard = store.lock().expect("tags override mutex poisoned");
            guard.clear();
        }
        self.gate.take();
    }
}

fn probe_default_local_daemon(timeout: Duration) -> bool {
    #[cfg(test)]
    {
        if let Some(value) = probe_override_value() {
            return value;
        }
    }

    for target in LOCAL_PROBE_TARGETS {
        if let Ok(address) = target.parse::<SocketAddr>() {
            if TcpStream::connect_timeout(&address, timeout).is_ok() {
                return true;
            }
        }
    }
    false
}

impl OllamaProvider {
    /// Create a provider targeting an explicit base URL (local usage).
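    ///
    /// A minimal sketch (the block is `ignore`d so it is not compiled as a
    /// doctest):
    ///
    /// ```ignore
    /// let provider = OllamaProvider::new("http://localhost:11434")?;
    /// ```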
    pub fn new(base_url: impl Into<String>) -> Result<Self> {
        let input = base_url.into();
        let normalized =
            normalize_base_url(Some(&input), OllamaMode::Local).map_err(Error::Config)?;
        Self::with_options(OllamaOptions::new(
            "ollama_local",
            ProviderVariant::Local,
            OllamaMode::Local,
            normalized,
        ))
    }

    /// Construct a provider from configuration settings.
    pub fn from_config(
        provider_id: &str,
        config: &ProviderConfig,
        general: Option<&GeneralSettings>,
    ) -> Result<Self> {
        let provider_type = config.provider_type.trim().to_ascii_lowercase();
        let register_name = {
            let candidate = provider_id.trim();
            if candidate.is_empty() {
                if provider_type.is_empty() {
                    "ollama".to_string()
                } else {
                    provider_type.clone()
                }
            } else {
                candidate.replace('-', "_")
            }
        };

        let variant = if register_name == "ollama_cloud" || provider_type == "ollama_cloud" {
            ProviderVariant::Cloud
        } else {
            ProviderVariant::Local
        };

        // Only warn about a legacy variable when it actually supplies the key;
        // warning unconditionally would also consume the one-shot warning lock.
        let mut api_key = resolve_api_key(config.api_key.clone())
            .or_else(|| resolve_api_key_env_hint(config.api_key_env.as_deref()))
            .or_else(|| env_var_non_empty(OLLAMA_API_KEY_ENV))
            .or_else(|| {
                env_var_non_empty(LEGACY_OLLAMA_CLOUD_API_KEY_ENV).map(|key| {
                    warn_legacy_cloud_env(LEGACY_OLLAMA_CLOUD_API_KEY_ENV);
                    key
                })
            })
            .or_else(|| {
                env_var_non_empty(LEGACY_OWLEN_OLLAMA_CLOUD_API_KEY_ENV).map(|key| {
                    warn_legacy_cloud_env(LEGACY_OWLEN_OLLAMA_CLOUD_API_KEY_ENV);
                    key
                })
            });
        let api_key_present = api_key.is_some();

        let configured_mode = configured_mode_from_extra(config);
        let configured_mode_label = config
            .extra
            .get(OLLAMA_MODE_KEY)
            .and_then(|value| value.as_str())
            .unwrap_or("auto");
        let base_url = config.base_url.as_deref();
        let base_is_local = is_explicit_local_base(base_url);
        let base_is_cloud = is_explicit_cloud_base(base_url);
        let _base_is_other = base_url.is_some() && !base_is_local && !base_is_cloud;

        let mut local_probe_result = None;
        let cloud_endpoint = config
            .extra
            .get(OLLAMA_CLOUD_ENDPOINT_KEY)
            .and_then(Value::as_str)
            .map(normalize_cloud_endpoint)
            .transpose()
            .map_err(Error::Config)?;

        if matches!(variant, ProviderVariant::Local) && configured_mode.is_none() {
            let probe = probe_default_local_daemon(Duration::from_millis(LOCAL_PROBE_TIMEOUT_MS));
            local_probe_result = Some(probe);
        }

        let mode = match variant {
            ProviderVariant::Local => OllamaMode::Local,
            ProviderVariant::Cloud => OllamaMode::Cloud,
        };

        if matches!(variant, ProviderVariant::Cloud) {
            if !api_key_present {
                return Err(Error::Config(
                    "Ollama Cloud API key not configured. Set providers.ollama_cloud.api_key or export OLLAMA_API_KEY (legacy: OLLAMA_CLOUD_API_KEY / OWLEN_OLLAMA_CLOUD_API_KEY)."
                        .into(),
                ));
            }
        }

        let base_candidate = match mode {
            OllamaMode::Local => base_url,
            OllamaMode::Cloud => base_url.or(Some(CLOUD_BASE_URL)),
        };

        let normalized_base_url =
            normalize_base_url(base_candidate, mode).map_err(Error::Config)?;

        let mut options = OllamaOptions::new(
            register_name.clone(),
            variant,
            mode,
            normalized_base_url.clone(),
        );
        options.cloud_endpoint = cloud_endpoint.clone();

        if let Some(timeout) = config
            .extra
            .get("timeout_secs")
            .and_then(|value| value.as_u64())
        {
            options.request_timeout = Duration::from_secs(timeout.max(5));
        }

        if let Some(cache_ttl) = config
            .extra
            .get("model_cache_ttl_secs")
            .and_then(|value| value.as_u64())
        {
            options.model_cache_ttl = Duration::from_secs(cache_ttl.max(5));
        }

        options.api_key = api_key.take();

        if let Some(general) = general {
            options = options.with_general(general);
        }

        debug!(
            "Resolved Ollama provider '{}': mode={:?}, base_url={}, configured_mode={}, api_key_present={}, local_probe={}",
            register_name,
            mode,
            normalized_base_url,
            configured_mode_label,
            if options.api_key.is_some() {
                "yes"
            } else {
                "no"
            },
            match local_probe_result {
                Some(true) => "success",
                Some(false) => "failed",
                None => "skipped",
            }
        );

        Self::with_options(options)
    }

    fn with_options(options: OllamaOptions) -> Result<Self> {
        let OllamaOptions {
            provider_name,
            variant,
            mode,
            base_url,
            request_timeout,
            model_cache_ttl,
            api_key,
            cloud_endpoint,
        } = options;

        let api_key_ref = api_key.as_deref();
        let (ollama_client, http_client) =
            build_client_for_base(&base_url, request_timeout, api_key_ref)?;

        let scope_cache = {
            let mut initial = HashMap::new();
            initial.insert(OllamaMode::Local, ScopeSnapshot::default());
            initial.insert(OllamaMode::Cloud, ScopeSnapshot::default());
            Arc::new(RwLock::new(initial))
        };

        Ok(Self {
            provider_name: provider_name.trim().to_ascii_lowercase(),
            variant,
            mode,
            client: ollama_client,
            http_client,
            base_url: base_url.trim_end_matches('/').to_string(),
            request_timeout,
            api_key,
            cloud_endpoint,
            model_manager: ModelManager::new(model_cache_ttl),
            model_details_cache: ModelDetailsCache::new(model_cache_ttl),
            model_cache_ttl,
            scope_cache,
        })
    }

    fn api_url(&self, endpoint: &str) -> String {
        build_api_endpoint(&self.base_url, endpoint)
    }

    fn local_base_url() -> &'static str {
        OllamaMode::Local.default_base_url()
    }

    fn scope_key(scope: OllamaMode) -> &'static str {
        match scope {
            OllamaMode::Local => "local",
            OllamaMode::Cloud => "cloud",
        }
    }

    fn supports_local_scope(&self) -> bool {
        self.variant.supports_local()
    }

    fn supports_cloud_scope(&self) -> bool {
        self.variant.supports_cloud()
    }

    fn build_local_client(&self) -> Result<Option<ScopeHandle>> {
        if !self.supports_local_scope() {
            return Ok(None);
        }

        if matches!(self.mode, OllamaMode::Local) {
            return Ok(Some(ScopeHandle::new(
                self.client.clone(),
                self.http_client.clone(),
                self.base_url.clone(),
            )));
        }

        let (client, http_client) =
            build_client_for_base(Self::local_base_url(), self.request_timeout, None)?;
        Ok(Some(ScopeHandle::new(
            client,
            http_client,
            Self::local_base_url(),
        )))
    }

    fn build_cloud_client(&self) -> Result<Option<ScopeHandle>> {
        if !self.supports_cloud_scope() {
            return Ok(None);
        }

        if matches!(self.mode, OllamaMode::Cloud) {
            return Ok(Some(ScopeHandle::new(
                self.client.clone(),
                self.http_client.clone(),
                self.base_url.clone(),
            )));
        }

        let api_key = match self.api_key.as_deref() {
            Some(key) if !key.trim().is_empty() => key,
            _ => return Ok(None),
        };

        let endpoint = self.cloud_endpoint.as_deref().unwrap_or(CLOUD_BASE_URL);

        let (client, http_client) =
            build_client_for_base(endpoint, self.request_timeout, Some(api_key))?;
        Ok(Some(ScopeHandle::new(client, http_client, endpoint)))
    }

    async fn cached_scope_models(&self, scope: OllamaMode) -> Option<Vec<ModelInfo>> {
        let cache = self.scope_cache.read().await;
        cache.get(&scope).and_then(|entry| {
            if entry.availability == ScopeAvailability::Unknown {
                return None;
            }

            if entry.models.is_empty() {
                return None;
            }

            if let Some(ts) = entry.fetched_at {
                if ts.elapsed() < self.model_cache_ttl {
                    return Some(entry.models.clone());
                }
            }

            // Fallback to last good models even if stale; UI will mark as degraded
            Some(entry.models.clone())
        })
    }

    async fn update_scope_success(&self, scope: OllamaMode, models: &[ModelInfo]) {
        let mut cache = self.scope_cache.write().await;
        let entry = cache.entry(scope).or_default();
        let now = Instant::now();
        entry.models = models.to_vec();
        entry.fetched_at = Some(now);
        entry.last_checked = Some(now);
        entry.last_success_at = Some(now);
        entry.availability = ScopeAvailability::Available;
        entry.last_error = None;
    }

    async fn mark_scope_failure(&self, scope: OllamaMode, message: String) {
        let mut cache = self.scope_cache.write().await;
        let entry = cache.entry(scope).or_default();
        entry.availability = ScopeAvailability::Unavailable;
        entry.last_error = Some(message);
        entry.last_checked = Some(Instant::now());
    }
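
    // Scope health is surfaced to callers by piggybacking synthetic capability
    // strings onto each `ModelInfo` (e.g. `scope-status:local:available`,
    // `scope-status-stale:local:0`, `scope-status-age:...`,
    // `scope-status-success-age:...`, `scope-status-message:...`), so the
    // model list carries availability metadata without a separate channel.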
    async fn annotate_scope_status(&self, models: &mut [ModelInfo]) {
        if models.is_empty() {
            return;
        }

        let cache = self.scope_cache.read().await;
        for (scope, snapshot) in cache.iter() {
            if snapshot.availability == ScopeAvailability::Unknown {
                continue;
            }
            let scope_key = Self::scope_key(*scope);
            let capability = format!(
                "scope-status:{}:{}",
                scope_key,
                snapshot.availability.as_str()
            );

            for model in models.iter_mut() {
                if !model.capabilities.iter().any(|cap| cap == &capability) {
                    model.capabilities.push(capability.clone());
                }
            }

            let stale = snapshot.is_stale(self.model_cache_ttl);
            let stale_capability = format!(
                "scope-status-stale:{}:{}",
                scope_key,
                if stale { "1" } else { "0" }
            );
            for model in models.iter_mut() {
                if !model
                    .capabilities
                    .iter()
                    .any(|cap| cap == &stale_capability)
                {
                    model.capabilities.push(stale_capability.clone());
                }
            }

            if let Some(age) = snapshot.last_checked_age_secs() {
                let age_capability = format!("scope-status-age:{}:{}", scope_key, age);
                for model in models.iter_mut() {
                    if !model.capabilities.iter().any(|cap| cap == &age_capability) {
                        model.capabilities.push(age_capability.clone());
                    }
                }
            }

            if let Some(success_age) = snapshot.last_success_age_secs() {
                let success_capability =
                    format!("scope-status-success-age:{}:{}", scope_key, success_age);
                for model in models.iter_mut() {
                    if !model
                        .capabilities
                        .iter()
                        .any(|cap| cap == &success_capability)
                    {
                        model.capabilities.push(success_capability.clone());
                    }
                }
            }

            if let Some(raw_reason) = snapshot.last_error.as_ref() {
                let cleaned = raw_reason.replace('\n', " ").trim().to_string();
                if !cleaned.is_empty() {
                    let truncated: String = cleaned.chars().take(160).collect();
                    let message_capability =
                        format!("scope-status-message:{}:{}", scope_key, truncated);
                    for model in models.iter_mut() {
                        if !model
                            .capabilities
                            .iter()
                            .any(|cap| cap == &message_capability)
                        {
                            model.capabilities.push(message_capability.clone());
                        }
                    }
                }
            }
        }
    }

    /// Attempt to resolve detailed model information for the given model, using the local cache when possible.
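    ///
    /// Illustrative call (the model name is an example; the block is
    /// `ignore`d so it is not compiled as a doctest):
    ///
    /// ```ignore
    /// let info = provider.get_model_info("llama3.2:latest").await?;
    /// println!("context length: {:?}", info.context_length);
    /// ```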
|
||
pub async fn get_model_info(&self, model_name: &str) -> Result<DetailedModelInfo> {
|
||
if let Some(info) = self.model_details_cache.get(model_name).await {
|
||
return Ok(info);
|
||
}
|
||
self.fetch_and_cache_model_info(model_name, None).await
|
||
}
|
||
|
||
/// Force-refresh model information for the specified model.
|
||
pub async fn refresh_model_info(&self, model_name: &str) -> Result<DetailedModelInfo> {
|
||
self.model_details_cache.invalidate(model_name).await;
|
||
self.fetch_and_cache_model_info(model_name, None).await
|
||
}
|
||
|
||
/// Retrieve detailed information for all locally available models.
|
||
pub async fn get_all_models_info(&self) -> Result<Vec<DetailedModelInfo>> {
|
||
let models = self
|
||
.client
|
||
.list_local_models()
|
||
.await
|
||
.map_err(|err| self.map_ollama_error("list models", err, None))?;
|
||
|
||
let mut details = Vec::with_capacity(models.len());
|
||
for local in &models {
|
||
match self
|
||
.fetch_and_cache_model_info(&local.name, Some(local))
|
||
.await
|
||
{
|
||
Ok(info) => details.push(info),
|
||
Err(err) => warn!("Failed to gather model info for '{}': {}", local.name, err),
|
||
}
|
||
}
|
||
Ok(details)
|
||
}
|
||
|
||
/// Return any cached model information without touching the Ollama daemon.
|
||
pub async fn cached_model_info(&self) -> Vec<DetailedModelInfo> {
|
||
self.model_details_cache.cached().await
|
||
}
|
||
|
||
/// Remove a single model's cached information.
|
||
pub async fn invalidate_model_info(&self, model_name: &str) {
|
||
self.model_details_cache.invalidate(model_name).await;
|
||
}
|
||
|
||
/// Clear the entire model information cache.
|
||
pub async fn clear_model_info_cache(&self) {
|
||
self.model_details_cache.invalidate_all().await;
|
||
}
|
||
|
||
async fn fetch_and_cache_model_info(
|
||
&self,
|
||
model_name: &str,
|
||
local: Option<&LocalModel>,
|
||
) -> Result<DetailedModelInfo> {
|
||
let detail = self
|
||
.client
|
||
.show_model_info(model_name.to_string())
|
||
.await
|
||
.map_err(|err| self.map_ollama_error("show_model_info", err, Some(model_name)))?;
|
||
|
||
let local_owned = if let Some(local) = local {
|
||
Some(local.clone())
|
||
} else {
|
||
let models = self
|
||
.client
|
||
.list_local_models()
|
||
.await
|
||
.map_err(|err| self.map_ollama_error("list models", err, None))?;
|
||
models.into_iter().find(|m| m.name == model_name)
|
||
};
|
||
|
||
let detailed =
|
||
Self::convert_detailed_model_info(self.mode, model_name, local_owned.as_ref(), &detail);
|
||
self.model_details_cache.insert(detailed.clone()).await;
|
||
Ok(detailed)
|
||
}
|
||
|
||
fn prepare_chat_request(
|
||
&self,
|
||
model: String,
|
||
messages: Vec<Message>,
|
||
parameters: ChatParameters,
|
||
tools: Option<Vec<McpToolDescriptor>>,
|
||
) -> Result<(String, OllamaChatRequest)> {
|
||
if self.mode == OllamaMode::Cloud && !model.contains("-cloud") {
|
||
warn!(
|
||
"Model '{}' does not use the '-cloud' suffix. Cloud-only models may fail to load.",
|
||
model
|
||
);
|
||
}
|
||
|
||
let converted_messages = messages.into_iter().map(convert_message).collect();
|
||
let mut request = OllamaChatRequest::new(model.clone(), converted_messages);
|
||
|
||
if let Some(options) = build_model_options(¶meters)? {
|
||
request.options = Some(options);
|
||
}
|
||
|
||
if let Some(tool_descriptors) = tools.as_ref() {
|
||
let tool_infos = convert_tool_descriptors(tool_descriptors)?;
|
||
if !tool_infos.is_empty() {
|
||
request.tools = tool_infos;
|
||
}
|
||
}
|
||
|
||
Ok((model, request))
|
||
}
|
||
|
||
async fn fetch_models(&self) -> Result<Vec<ModelInfo>> {
|
||
let mut combined = Vec::new();
|
||
let mut seen: HashSet<String> = HashSet::new();
|
||
let mut errors: Vec<Error> = Vec::new();
|
||
|
||
if let Some(local_handle) = self.build_local_client()? {
|
||
match self
|
||
.fetch_models_for_scope(OllamaMode::Local, local_handle)
|
||
.await
|
||
{
|
||
Ok(models) => {
|
||
for model in models {
|
||
let key = format!("local::{}", model.id);
|
||
if seen.insert(key) {
|
||
combined.push(model);
|
||
}
|
||
}
|
||
}
|
||
Err(err) => errors.push(err),
|
||
}
|
||
}
|
||
|
||
if let Some(cloud_handle) = self.build_cloud_client()? {
|
||
match self
|
||
.fetch_models_for_scope(OllamaMode::Cloud, cloud_handle)
|
||
.await
|
||
{
|
||
Ok(models) => {
|
||
for model in models {
|
||
let key = format!("cloud::{}", model.id);
|
||
if seen.insert(key) {
|
||
combined.push(model);
|
||
}
|
||
}
|
||
}
|
||
Err(err) => errors.push(err),
|
||
}
|
||
}
|
||
|
||
if combined.is_empty() {
|
||
if let Some(err) = errors.pop() {
|
||
return Err(err);
|
||
}
|
||
}
|
||
|
||
self.annotate_scope_status(&mut combined).await;
|
||
combined.sort_by(|a, b| a.name.to_lowercase().cmp(&b.name.to_lowercase()));
|
||
Ok(combined)
|
||
}
|
||
|
||
async fn fetch_models_for_scope(
|
||
&self,
|
||
scope: OllamaMode,
|
||
handle: ScopeHandle,
|
||
) -> Result<Vec<ModelInfo>> {
|
||
let tags_result = self.fetch_scope_tags_with_retry(scope, &handle).await;
|
||
|
||
let models = match tags_result {
|
||
Ok(models) => models,
|
||
Err(err) => {
|
||
let original_detail = err.to_string();
|
||
let (error, banner) = self.decorate_scope_error(scope, &handle.base_url, err);
|
||
if banner != original_detail {
|
||
debug!(
|
||
"Model list for {:?} at {} failed: {}",
|
||
scope, handle.base_url, original_detail
|
||
);
|
||
}
|
||
self.mark_scope_failure(scope, banner.clone()).await;
|
||
if let Some(cached) = self.cached_scope_models(scope).await {
|
||
return Ok(cached);
|
||
}
|
||
return Err(error);
|
||
}
|
||
};
|
||
|
||
let cache = self.model_details_cache.clone();
|
||
let client = handle.client.clone();
|
||
let fetched = join_all(models.into_iter().map(|local| {
|
||
let client = client.clone();
|
||
let cache = cache.clone();
|
||
async move {
|
||
let name = local.name.clone();
|
||
let detail = match client.show_model_info(name.clone()).await {
|
||
Ok(info) => {
|
||
let detailed = OllamaProvider::convert_detailed_model_info(
|
||
scope,
|
||
&name,
|
||
Some(&local),
|
||
&info,
|
||
);
|
||
cache.insert(detailed).await;
|
||
Some(info)
|
||
}
|
||
Err(err) => {
|
||
debug!("Failed to fetch Ollama model info for '{name}': {err}");
|
||
None
|
||
}
|
||
};
|
||
(local, detail)
|
||
}
|
||
}))
|
||
.await;
|
||
|
||
let converted: Vec<ModelInfo> = fetched
|
||
.into_iter()
|
||
.map(|(local, detail)| self.convert_model(scope, local, detail))
|
||
.collect();
|
||
|
||
self.update_scope_success(scope, &converted).await;
|
||
Ok(converted)
|
||
}
|
||
|
||
async fn fetch_scope_tags_with_retry(
|
||
&self,
|
||
scope: OllamaMode,
|
||
handle: &ScopeHandle,
|
||
) -> Result<Vec<LocalModel>> {
|
||
let attempts = if matches!(scope, OllamaMode::Local) {
|
||
LOCAL_TAGS_TIMEOUT_STEPS_MS.len()
|
||
} else {
|
||
1
|
||
};
|
||
|
||
let mut last_error: Option<Error> = None;
|
||
|
||
for attempt in 0..attempts {
|
||
match self.fetch_scope_tags_once(scope, handle, attempt).await {
|
||
Ok(models) => return Ok(models),
|
||
Err(err) => {
|
||
let should_retry = matches!(scope, OllamaMode::Local)
|
||
&& attempt + 1 < attempts
|
||
&& matches!(err, Error::Timeout(_) | Error::Network(_));
|
||
|
||
if should_retry {
|
||
debug!(
|
||
"Retrying Ollama model list for {:?} (attempt {}): {}",
|
||
scope,
|
||
attempt + 1,
|
||
err
|
||
);
|
||
last_error = Some(err);
|
||
sleep(self.tags_retry_delay(attempt)).await;
|
||
continue;
|
||
}
|
||
return Err(err);
|
||
}
|
||
}
|
||
}
|
||
|
||
Err(last_error
|
||
.unwrap_or_else(|| Error::Unknown("Ollama model list retries exhausted".to_string())))
|
||
}
|
||
|
||
async fn fetch_scope_tags_once(
|
||
&self,
|
||
scope: OllamaMode,
|
||
handle: &ScopeHandle,
|
||
attempt: usize,
|
||
) -> Result<Vec<LocalModel>> {
|
||
#[cfg(test)]
|
||
if let Some(result) = pop_tags_override() {
|
||
return result;
|
||
}
|
||
|
||
if matches!(scope, OllamaMode::Local) {
|
||
match self.list_local_models_via_cli() {
|
||
Ok(models) => return Ok(models),
|
||
Err(err) => {
|
||
debug!("`ollama ls` failed ({}); falling back to HTTP listing", err);
|
||
}
|
||
}
|
||
}
|
||
|
||
let url = handle.api_url("tags");
|
||
let response = handle
|
||
.http_client
|
||
.get(&url)
|
||
.timeout(self.tags_request_timeout(scope, attempt))
|
||
.send()
|
||
.await
|
||
.map_err(|err| map_reqwest_error("list models", err))?;
|
||
|
||
if !response.status().is_success() {
|
||
let status = response.status();
|
||
let detail = response.text().await.unwrap_or_else(|err| err.to_string());
|
||
return Err(self.map_http_failure("list models", status, detail, None));
|
||
}
|
||
|
||
let bytes = response
|
||
.bytes()
|
||
.await
|
||
.map_err(|err| map_reqwest_error("list models", err))?;
|
||
let parsed: TagsResponse = serde_json::from_slice(&bytes)?;
|
||
Ok(parsed.models)
|
||
}
|
||
|
||
fn tags_request_timeout(&self, scope: OllamaMode, attempt: usize) -> Duration {
|
||
if matches!(scope, OllamaMode::Local) {
|
||
let idx = attempt.min(LOCAL_TAGS_TIMEOUT_STEPS_MS.len() - 1);
|
||
Duration::from_millis(LOCAL_TAGS_TIMEOUT_STEPS_MS[idx])
|
||
} else {
|
||
self.request_timeout
|
||
}
|
||
}
|
||
|
||
fn tags_retry_delay(&self, attempt: usize) -> Duration {
|
||
let idx = attempt.min(LOCAL_TAGS_RETRY_DELAYS_MS.len() - 1);
|
||
Duration::from_millis(LOCAL_TAGS_RETRY_DELAYS_MS[idx])
|
||
}
|
||
|
||
fn list_local_models_via_cli(&self) -> Result<Vec<LocalModel>> {
|
||
let output = Command::new("ollama")
|
||
.arg("ls")
|
||
.output()
|
||
.map_err(|err| {
|
||
Error::Provider(anyhow!(
|
||
"Failed to execute `ollama ls`: {err}. Ensure the Ollama CLI is installed and accessible in PATH."
|
||
))
|
||
})?;
|
||
|
||
if !output.status.success() {
|
||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||
return Err(Error::Provider(anyhow!(
|
||
"`ollama ls` exited with status {}: {}",
|
||
output.status,
|
||
stderr.trim()
|
||
)));
|
||
}
|
||
|
||
let stdout = String::from_utf8(output.stdout).map_err(|err| {
|
||
Error::Provider(anyhow!("`ollama ls` returned non-UTF8 output: {err}"))
|
||
})?;
|
||
|
||
let mut models = Vec::new();
|
||
for line in stdout.lines() {
|
||
let trimmed = line.trim();
|
||
if trimmed.is_empty() {
|
||
continue;
|
||
}
|
||
let lowercase = trimmed.to_ascii_lowercase();
|
||
if lowercase.starts_with("name") {
|
||
continue;
|
||
}
|
||
|
||
let mut parts = trimmed.split_whitespace();
|
||
let Some(name) = parts.next() else {
|
||
continue;
|
||
};
|
||
let metadata_start = trimmed[name.len()..].trim();
|
||
|
||
models.push(LocalModel {
|
||
name: name.to_string(),
|
||
modified_at: metadata_start.to_string(),
|
||
size: 0,
|
||
});
|
||
}
|
||
|
||
Ok(models)
|
||
}
|
||
|
||
fn decorate_scope_error(
|
||
&self,
|
||
scope: OllamaMode,
|
||
base_url: &str,
|
||
err: Error,
|
||
) -> (Error, String) {
|
||
if matches!(scope, OllamaMode::Local) {
|
||
match err {
|
||
Error::Timeout(_) => {
|
||
let message = format_local_unreachable_message(base_url);
|
||
(Error::Timeout(message.clone()), message)
|
||
}
|
||
Error::Network(original) => {
|
||
if is_connectivity_error(&original) {
|
||
let message = format_local_unreachable_message(base_url);
|
||
(Error::Network(message.clone()), message)
|
||
} else {
|
||
let message = original.clone();
|
||
(Error::Network(original), message)
|
||
}
|
||
}
|
||
other => {
|
||
let message = other.to_string();
|
||
(other, message)
|
||
}
|
||
}
|
||
} else {
|
||
let message = err.to_string();
|
||
(err, message)
|
||
}
|
||
}
|
||
|
||
fn convert_detailed_model_info(
|
||
mode: OllamaMode,
|
||
model_name: &str,
|
||
local: Option<&LocalModel>,
|
||
detail: &OllamaModelInfo,
|
||
) -> DetailedModelInfo {
|
||
let map = &detail.model_info;
|
||
|
||
let architecture =
|
||
pick_first_string(map, &["architecture", "model_format", "model_type", "arch"]);
|
||
|
||
let parameters = non_empty(detail.parameters.clone())
|
||
.or_else(|| pick_first_string(map, &["parameters"]));
|
||
|
||
let parameter_size = pick_first_string(map, &["parameter_size"]);
|
||
|
||
let context_length = pick_first_u64(map, &["context_length", "num_ctx", "max_context"]);
|
||
let embedding_length = pick_first_u64(map, &["embedding_length"]);
|
||
|
||
let quantization =
|
||
pick_first_string(map, &["quantization_level", "quantization", "quantize"]);
|
||
|
||
let family = pick_first_string(map, &["family", "model_family"]);
|
||
let mut families = pick_string_list(map, &["families", "model_families"]);
|
||
|
||
if families.is_empty() {
|
||
families.extend(family.clone());
|
||
}
|
||
|
||
let system = pick_first_string(map, &["system"]);
|
||
|
||
let mut modified_at = local
|
||
.and_then(|entry| non_empty(entry.modified_at.clone()))
|
||
.or_else(|| pick_first_string(map, &["modified_at", "created_at"]));
|
||
|
||
if modified_at.is_none() && mode == OllamaMode::Cloud {
|
||
modified_at = pick_first_string(map, &["updated_at"]);
|
||
}
|
||
|
||
let size = local
|
||
.and_then(|entry| {
|
||
if entry.size > 0 {
|
||
Some(entry.size)
|
||
} else {
|
||
None
|
||
}
|
||
})
|
||
.or_else(|| pick_first_u64(map, &["size", "model_size", "download_size"]));
|
||
|
||
let digest = pick_first_string(map, &["digest", "sha256", "checksum"]);
|
||
|
||
let mut info = DetailedModelInfo {
|
||
name: model_name.to_string(),
|
||
architecture,
|
||
parameters,
|
||
context_length,
|
||
embedding_length,
|
||
quantization,
|
||
family,
|
||
families,
|
||
parameter_size,
|
||
template: non_empty(detail.template.clone()),
|
||
system,
|
||
license: non_empty(detail.license.clone()),
|
||
modelfile: non_empty(detail.modelfile.clone()),
|
||
modified_at,
|
||
size,
|
||
digest,
|
||
};
|
||
|
||
if info.parameter_size.is_none() {
|
||
info.parameter_size = info.parameters.clone();
|
||
}
|
||
|
||
info.with_normalised_strings()
|
||
}
|
||
|
||
fn convert_model(
|
||
&self,
|
||
scope: OllamaMode,
|
||
model: LocalModel,
|
||
detail: Option<OllamaModelInfo>,
|
||
) -> ModelInfo {
|
||
let scope_tag = match scope {
|
||
OllamaMode::Local => "local",
|
||
OllamaMode::Cloud => "cloud",
|
||
};
|
||
|
||
let name = model.name;
|
||
let mut capabilities: Vec<String> = detail
|
||
.as_ref()
|
||
.map(|info| {
|
||
info.capabilities
|
||
.iter()
|
||
.map(|cap| cap.to_ascii_lowercase())
|
||
.collect()
|
||
})
|
||
.unwrap_or_default();
|
||
|
||
push_capability(&mut capabilities, "chat");
|
||
|
||
for heuristic in heuristic_capabilities(&name) {
|
||
push_capability(&mut capabilities, &heuristic);
|
||
}
|
||
|
||
push_capability(&mut capabilities, &format!("scope:{scope_tag}"));
|
||
|
||
let description = build_model_description(scope_tag, detail.as_ref());
|
||
|
||
let context_window = detail.as_ref().and_then(|info| {
|
||
pick_first_u64(
|
||
&info.model_info,
|
||
&["context_length", "num_ctx", "max_context"],
|
||
)
|
||
.and_then(|raw| u32::try_from(raw).ok())
|
||
});
|
||
|
||
let supports_tools = model_supports_tools(&name, &capabilities, detail.as_ref());
|
||
|
||
ModelInfo {
|
||
id: name.clone(),
|
||
name,
|
||
description: Some(description),
|
||
provider: self.provider_name.clone(),
|
||
context_window,
|
||
capabilities,
|
||
supports_tools,
|
||
}
|
||
}
|
||
|
||
fn convert_ollama_response(response: OllamaChatResponse, streaming: bool) -> ChatResponse {
|
||
let OllamaChatResponse {
|
||
model,
|
||
created_at,
|
||
message,
|
||
done,
|
||
final_data,
|
||
} = response;
|
||
|
||
let usage = final_data.as_ref().map(|data| {
|
||
let prompt = clamp_to_u32(data.prompt_eval_count);
|
||
let completion = clamp_to_u32(data.eval_count);
|
||
TokenUsage {
|
||
prompt_tokens: prompt,
|
||
completion_tokens: completion,
|
||
total_tokens: prompt.saturating_add(completion),
|
||
}
|
||
});
|
||
|
||
let mut message = convert_ollama_message(message);
|
||
|
||
let mut provider_meta = JsonMap::new();
|
||
provider_meta.insert("model".into(), Value::String(model));
|
||
provider_meta.insert("created_at".into(), Value::String(created_at));
|
||
|
||
if let Some(ref final_block) = final_data {
|
||
if let Ok(value) = serde_json::to_value(final_block) {
|
||
provider_meta.insert("final_data".into(), value);
|
||
}
|
||
}
|
||
|
||
message
|
||
.metadata
|
||
.insert("ollama".into(), Value::Object(provider_meta));
|
||
|
||
ChatResponse {
|
||
message,
|
||
usage,
|
||
is_streaming: streaming,
|
||
is_final: if streaming { done } else { true },
|
||
}
|
||
}
|
||
|
||
fn provider_failure(
|
||
&self,
|
||
kind: ProviderErrorKind,
|
||
message: impl Into<String>,
|
||
detail: Option<String>,
|
||
) -> Error {
|
||
let error = ProviderError::new(kind, message).with_provider(self.provider_name.clone());
|
||
let error = if let Some(detail) = detail {
|
||
error.with_detail(detail)
|
||
} else {
|
||
error
|
||
};
|
||
Error::ProviderFailure(error)
|
||
}
|
||
|
||
fn map_ollama_error(&self, action: &str, err: OllamaError, model: Option<&str>) -> Error {
|
||
match err {
|
||
OllamaError::ReqwestError(request_err) => {
|
||
if let Some(status) = request_err.status() {
|
||
self.map_http_failure(action, status, request_err.to_string(), model)
|
||
} else if request_err.is_timeout() {
|
||
self.provider_failure(
|
||
ProviderErrorKind::Timeout,
|
||
format!("Ollama {action} timed out"),
|
||
Some(request_err.to_string()),
|
||
)
|
||
} else if request_err.is_connect() || request_err.is_request() {
|
||
self.provider_failure(
|
||
ProviderErrorKind::Network,
|
||
format!("Ollama {action} request failed"),
|
||
Some(request_err.to_string()),
|
||
)
|
||
} else {
|
||
Error::Provider(anyhow!(request_err))
|
||
}
|
||
}
|
||
OllamaError::InternalError(internal) => self.provider_failure(
|
||
ProviderErrorKind::Protocol,
|
||
format!("Ollama {action} internal error"),
|
||
Some(internal.message),
|
||
),
|
||
OllamaError::Other(message) => {
|
||
let parsed_error = serde_json::from_str::<Value>(&message)
|
||
.ok()
|
||
.and_then(|value| {
|
||
value
|
||
.get("error")
|
||
.and_then(Value::as_str)
|
||
.map(|err| err.trim().to_string())
|
||
})
|
||
.map(|err| err.to_ascii_lowercase());
|
||
|
||
if let Some(err) = parsed_error.as_deref() {
|
||
if err.contains("too many") || err.contains("rate limit") {
|
||
return self.provider_failure(
|
||
ProviderErrorKind::RateLimited,
|
||
format!("Ollama {action} request rate limited"),
|
||
Some(message),
|
||
);
|
||
}
|
||
|
||
if err.contains("unauthorized") || err.contains("invalid api key") {
|
||
return self.provider_failure(
|
||
ProviderErrorKind::Unauthorized,
|
||
format!("Ollama {action} rejected the request (unauthorized). Check your API key and account permissions."),
|
||
Some(message),
|
||
);
|
||
}
|
||
}
|
||
|
||
self.provider_failure(
|
||
ProviderErrorKind::Unknown,
|
||
format!("Ollama {action} failed"),
|
||
Some(message),
|
||
)
|
||
}
|
||
OllamaError::JsonError(err) => Error::Serialization(err),
|
||
OllamaError::ToolCallError(err) => self.provider_failure(
|
||
ProviderErrorKind::Protocol,
|
||
format!("Ollama {action} tool call failed"),
|
||
Some(err.to_string()),
|
||
),
|
||
}
|
||
}
|
||
|
||
fn map_http_failure(
|
||
&self,
|
||
action: &str,
|
||
status: StatusCode,
|
||
detail: String,
|
||
model: Option<&str>,
|
||
) -> Error {
|
||
match status {
|
||
StatusCode::NOT_FOUND => {
|
||
if let Some(model) = model {
|
||
Error::InvalidInput(format!(
|
||
"Model '{model}' was not found at {}. Verify the name or pull it with `ollama pull {model}`.",
|
||
self.base_url
|
||
))
|
||
} else {
|
||
Error::InvalidInput(format!(
|
||
"{action} returned 404 from {}: {detail}",
|
||
self.base_url
|
||
))
|
||
}
|
||
}
|
||
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => self.provider_failure(
|
||
ProviderErrorKind::Unauthorized,
|
||
format!(
|
||
"Ollama rejected the request ({status}). Check your API key and account permissions."
|
||
),
|
||
Some(detail),
|
||
),
|
||
StatusCode::TOO_MANY_REQUESTS => self.provider_failure(
|
||
ProviderErrorKind::RateLimited,
|
||
format!("Ollama {action} request rate limited"),
|
||
Some(detail),
|
||
),
|
||
StatusCode::BAD_REQUEST => {
|
||
Error::InvalidInput(format!("{action} rejected by Ollama ({status}): {detail}"))
|
||
}
|
||
StatusCode::SERVICE_UNAVAILABLE | StatusCode::GATEWAY_TIMEOUT => self.provider_failure(
|
||
ProviderErrorKind::Timeout,
|
||
format!("Ollama {action} timed out ({status}). The model may still be loading."),
|
||
Some(detail),
|
||
),
|
||
status if status.is_server_error() => self.provider_failure(
|
||
ProviderErrorKind::Unavailable,
|
||
format!("Ollama {action} request failed ({status}). Try again later."),
|
||
Some(detail),
|
||
),
|
||
status if status.is_client_error() => self.provider_failure(
|
||
ProviderErrorKind::InvalidRequest,
|
||
format!("Ollama {action} rejected the request ({status})."),
|
||
Some(detail),
|
||
),
|
||
_ => self.provider_failure(
|
||
ProviderErrorKind::Unknown,
|
||
format!("Ollama {action} failed ({status})."),
|
||
Some(detail),
|
||
),
|
||
}
|
||
}
|
||
}
|
||
|
||
impl LlmProvider for OllamaProvider {
|
||
type Stream = Pin<Box<dyn Stream<Item = Result<ChatResponse>> + Send>>;
|
||
type ListModelsFuture<'a>
|
||
= BoxFuture<'a, Result<Vec<ModelInfo>>>
|
||
where
|
||
Self: 'a;
|
||
type SendPromptFuture<'a>
|
||
= BoxFuture<'a, Result<ChatResponse>>
|
||
where
|
||
Self: 'a;
|
||
type StreamPromptFuture<'a>
|
||
= BoxFuture<'a, Result<Self::Stream>>
|
||
where
|
||
Self: 'a;
|
||
type HealthCheckFuture<'a>
|
||
= BoxFuture<'a, Result<()>>
|
||
where
|
||
Self: 'a;
|
||
|
||
fn name(&self) -> &str {
|
||
&self.provider_name
|
||
}
|
||
|
||
fn list_models(&self) -> Self::ListModelsFuture<'_> {
|
||
Box::pin(async move {
|
||
self.model_manager
|
||
.get_or_refresh(false, || async { self.fetch_models().await })
|
||
.await
|
||
})
|
||
}
|
||
|
||
fn send_prompt(&self, request: ChatRequest) -> Self::SendPromptFuture<'_> {
|
||
Box::pin(async move {
|
||
let ChatRequest {
|
||
model,
|
||
messages,
|
||
parameters,
|
||
tools,
|
||
} = request;
|
||
|
||
let (model_id, ollama_request) =
|
||
self.prepare_chat_request(model, messages, parameters, tools)?;
|
||
|
||
let response = self
|
||
.client
|
||
.send_chat_messages(ollama_request)
|
||
.await
|
||
.map_err(|err| self.map_ollama_error("chat", err, Some(&model_id)))?;
|
||
|
||
Ok(Self::convert_ollama_response(response, false))
|
||
})
|
||
}
|
||
|
||
fn stream_prompt(&self, request: ChatRequest) -> Self::StreamPromptFuture<'_> {
|
||
Box::pin(async move {
|
||
let ChatRequest {
|
||
model,
|
||
messages,
|
||
parameters,
|
||
tools,
|
||
} = request;
|
||
|
||
let (model_id, ollama_request) =
|
||
self.prepare_chat_request(model, messages, parameters, tools)?;
|
||
|
||
let stream = self
|
||
.client
|
||
.send_chat_messages_stream(ollama_request)
|
||
.await
|
||
.map_err(|err| self.map_ollama_error("chat_stream", err, Some(&model_id)))?;
|
||
|
||
let mapped = stream.map(|item| match item {
|
||
Ok(chunk) => Ok(Self::convert_ollama_response(chunk, true)),
|
||
Err(_) => Err(Error::Provider(anyhow!(
|
||
"Ollama returned a malformed streaming chunk"
|
||
))),
|
||
});
|
||
|
||
Ok(Box::pin(mapped) as Self::Stream)
|
||
})
|
||
}
|
||
|
||
fn health_check(&self) -> Self::HealthCheckFuture<'_> {
|
||
Box::pin(async move {
|
||
let url = self.api_url("tags?limit=1");
|
||
let response = self
|
||
.http_client
|
||
.get(&url)
|
||
.timeout(Duration::from_millis(HEALTHCHECK_TIMEOUT_MS))
|
||
.send()
|
||
.await
|
||
.map_err(|err| map_reqwest_error("health check", err))?;
|
||
|
||
if response.status().is_success() {
|
||
return Ok(());
|
||
}
|
||
|
||
let status = response.status();
|
||
let detail = response.text().await.unwrap_or_else(|err| err.to_string());
|
||
Err(self.map_http_failure("health check", status, detail, None))
|
||
})
|
||
}
|
||
|
||
fn config_schema(&self) -> serde_json::Value {
|
||
serde_json::json!({
|
||
"type": "object",
|
||
"properties": {
|
||
"base_url": {
|
||
"type": "string",
|
||
"description": "Base URL for the Ollama API (ignored when api_key is provided)",
|
||
"default": self.mode.default_base_url()
|
||
},
|
||
"timeout_secs": {
|
||
"type": "integer",
|
||
"description": "HTTP request timeout in seconds",
|
||
"minimum": 5,
|
||
"default": DEFAULT_TIMEOUT_SECS
|
||
},
|
||
"model_cache_ttl_secs": {
|
||
"type": "integer",
|
||
"description": "Seconds to cache model listings",
|
||
"minimum": 5,
|
||
"default": DEFAULT_MODEL_CACHE_TTL_SECS
|
||
},
|
||
"hourly_quota_tokens": {
|
||
"type": "integer",
|
||
"description": "Soft hourly token quota used for UI alerts",
|
||
"minimum": 0,
|
||
"default": DEFAULT_OLLAMA_CLOUD_HOURLY_QUOTA
|
||
},
|
||
"weekly_quota_tokens": {
|
||
"type": "integer",
|
||
"description": "Soft weekly token quota used for UI alerts",
|
||
"minimum": 0,
|
||
"default": DEFAULT_OLLAMA_CLOUD_WEEKLY_QUOTA
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
fn build_model_options(parameters: &ChatParameters) -> Result<Option<ModelOptions>> {
|
||
let mut options = JsonMap::new();
|
||
|
||
for (key, value) in ¶meters.extra {
|
||
options.insert(key.clone(), value.clone());
|
||
}
|
||
|
||
if let Some(temperature) = parameters.temperature {
|
||
options.insert("temperature".to_string(), json!(temperature));
|
||
}
|
||
|
||
if let Some(max_tokens) = parameters.max_tokens {
|
||
let capped = i32::try_from(max_tokens).unwrap_or(i32::MAX);
|
||
options.insert("num_predict".to_string(), json!(capped));
|
||
}
|
||
|
||
if options.is_empty() {
|
||
return Ok(None);
|
||
}
|
||
|
||
serde_json::from_value(Value::Object(options))
|
||
.map(Some)
|
||
.map_err(|err| Error::Config(format!("Invalid Ollama options: {err}")))
|
||
}
|
||
|
||
fn convert_tool_descriptors(descriptors: &[McpToolDescriptor]) -> Result<Vec<OllamaToolInfo>> {
|
||
descriptors
|
||
.iter()
|
||
.map(|descriptor| {
|
||
let payload = json!({
|
||
"type": "Function",
|
||
"function": {
|
||
"name": descriptor.name,
|
||
"description": descriptor.description,
|
||
"parameters": descriptor.input_schema
|
||
}
|
||
});
|
||
|
||
serde_json::from_value(payload).map_err(|err| {
|
||
Error::Config(format!(
|
||
"Invalid tool schema for '{}': {err}",
|
||
descriptor.name
|
||
))
|
||
})
|
||
})
|
||
.collect()
|
||
}
|
||
|
||
fn convert_message(message: Message) -> OllamaMessage {
|
||
let Message {
|
||
role,
|
||
content,
|
||
metadata,
|
||
tool_calls,
|
||
attachments,
|
||
..
|
||
} = message;
|
||
|
||
let role = match role {
|
||
Role::User => OllamaRole::User,
|
||
Role::Assistant => OllamaRole::Assistant,
|
||
Role::System => OllamaRole::System,
|
||
Role::Tool => OllamaRole::Tool,
|
||
};
|
||
|
||
let tool_calls = tool_calls
|
||
.unwrap_or_default()
|
||
.into_iter()
|
||
.map(|tool_call| OllamaToolCall {
|
||
function: OllamaToolCallFunction {
|
||
name: tool_call.name,
|
||
arguments: tool_call.arguments,
|
||
},
|
||
})
|
||
.collect();
|
||
|
||
let thinking = metadata
|
||
.get("thinking")
|
||
.and_then(|value| value.as_str().map(|s| s.to_owned()));
|
||
|
||
let images: Vec<Image> = attachments
|
||
.into_iter()
|
||
.filter_map(|attachment| {
|
||
if !attachment.is_image() {
|
||
return None;
|
||
}
|
||
if let Some(data) = attachment.data_base64 {
|
||
return Some(Image::from_base64(data));
|
||
}
|
||
if let Some(path) = attachment.source_path {
|
||
match fs::read(&path) {
|
||
Ok(bytes) => {
|
||
let encoded = BASE64_STANDARD.encode(bytes);
|
||
return Some(Image::from_base64(encoded));
|
||
}
|
||
Err(err) => {
|
||
warn!(
|
||
"Failed to read attachment '{}' for image conversion: {}",
|
||
path.display(),
|
||
err
|
||
);
|
||
}
|
||
}
|
||
}
|
||
None
|
||
})
|
||
.collect();
|
||
|
||
OllamaMessage {
|
||
role,
|
||
content,
|
||
tool_calls,
|
||
images: if images.is_empty() {
|
||
None
|
||
} else {
|
||
Some(images)
|
||
},
|
||
thinking,
|
||
}
|
||
}
|
||
|
||
fn convert_ollama_message(message: OllamaMessage) -> Message {
|
||
let role = match message.role {
|
||
OllamaRole::Assistant => Role::Assistant,
|
||
OllamaRole::System => Role::System,
|
||
OllamaRole::Tool => Role::Tool,
|
||
OllamaRole::User => Role::User,
|
||
};
|
||
|
||
let tool_calls = if message.tool_calls.is_empty() {
|
||
None
|
||
} else {
|
||
Some(
|
||
message
|
||
.tool_calls
|
||
.into_iter()
|
||
.enumerate()
|
||
.map(|(idx, tool_call)| ToolCall {
|
||
id: format!("tool-call-{idx}"),
|
||
name: tool_call.function.name,
|
||
arguments: tool_call.function.arguments,
|
||
})
|
||
.collect::<Vec<_>>(),
|
||
)
|
||
};
|
||
|
||
let mut metadata = HashMap::new();
|
||
if let Some(thinking) = message.thinking {
|
||
metadata.insert("thinking".to_string(), Value::String(thinking));
|
||
}
|
||
|
||
let attachments = message
|
||
.images
|
||
.unwrap_or_default()
|
||
.into_iter()
|
||
.enumerate()
|
||
.filter_map(|(idx, image)| {
|
||
let data = image.to_base64();
|
||
if data.is_empty() {
|
||
return None;
|
||
}
|
||
let size_bytes = (data.len() as u64).saturating_mul(3).saturating_div(4);
|
||
let name = format!("image-{}.png", idx + 1);
|
||
Some(
|
||
MessageAttachment::from_base64(
|
||
name,
|
||
"image/png",
|
||
data.to_string(),
|
||
Some(size_bytes),
|
||
)
|
||
.with_description(format!("Generated image {}", idx + 1)),
|
||
)
|
||
})
|
||
.collect();
|
||
|
||
Message {
|
||
id: Uuid::new_v4(),
|
||
role,
|
||
content: message.content,
|
||
metadata,
|
||
timestamp: SystemTime::now(),
|
||
tool_calls,
|
||
attachments,
|
||
}
|
||
}
|
||
|
||
fn clamp_to_u32(value: u64) -> u32 {
|
||
u32::try_from(value).unwrap_or(u32::MAX)
|
||
}
|
||
|
||
fn push_capability(capabilities: &mut Vec<String>, capability: &str) {
|
||
let candidate = capability.to_ascii_lowercase();
|
||
if !capabilities
|
||
.iter()
|
||
.any(|existing| existing.eq_ignore_ascii_case(&candidate))
|
||
{
|
||
capabilities.push(candidate);
|
||
}
|
||
}
|
||
|
||
fn heuristic_capabilities(name: &str) -> Vec<String> {
|
||
let lowercase = name.to_ascii_lowercase();
|
||
let mut detected = Vec::new();
|
||
|
||
if lowercase.contains("vision")
|
||
|| lowercase.contains("multimodal")
|
||
|| lowercase.contains("image")
|
||
{
|
||
detected.push("vision".to_string());
|
||
}
|
||
|
||
if lowercase.contains("think")
|
||
|| lowercase.contains("reason")
|
||
|| lowercase.contains("deepseek-r1")
|
||
|| lowercase.contains("r1")
|
||
{
|
||
detected.push("thinking".to_string());
|
||
}
|
||
|
||
if lowercase.contains("audio") || lowercase.contains("speech") || lowercase.contains("voice") {
|
||
detected.push("audio".to_string());
|
||
}
|
||
|
||
detected
|
||
}
|
||
|
||
fn capability_implies_tools(label: &str) -> bool {
|
||
let normalized = label.to_ascii_lowercase();
|
||
normalized.contains("tool")
|
||
|| normalized.contains("function_call")
|
||
|| normalized.contains("function-call")
|
||
|| normalized.contains("tool_call")
|
||
}
|
||
|
||
fn model_supports_tools(
|
||
name: &str,
|
||
capabilities: &[String],
|
||
detail: Option<&OllamaModelInfo>,
|
||
) -> bool {
|
||
if let Some(info) = detail {
|
||
if info
|
||
.capabilities
|
||
.iter()
|
||
.any(|capability| capability_implies_tools(capability))
|
||
{
|
||
return true;
|
||
}
|
||
}
|
||
|
||
if capabilities
|
||
.iter()
|
||
.any(|capability| capability_implies_tools(capability))
|
||
{
|
||
return true;
|
||
}
|
||
|
||
let lowered = name.to_ascii_lowercase();
|
||
["functioncall", "function-call", "function_call", "tool"]
|
||
.iter()
|
||
.any(|needle| lowered.contains(needle))
|
||
}
|
||
|
||
fn build_model_description(scope: &str, detail: Option<&OllamaModelInfo>) -> String {
|
||
if let Some(info) = detail {
|
||
let mut parts = Vec::new();
|
||
|
||
if let Some(family) = info
|
||
.model_info
|
||
.get("family")
|
||
.and_then(|value| value.as_str())
|
||
{
|
||
parts.push(family.to_string());
|
||
}
|
||
|
||
if let Some(parameter_size) = info
|
||
.model_info
|
||
.get("parameter_size")
|
||
.and_then(|value| value.as_str())
|
||
{
|
||
parts.push(parameter_size.to_string());
|
||
}
|
||
|
||
if let Some(variant) = info
|
||
.model_info
|
||
.get("variant")
|
||
.and_then(|value| value.as_str())
|
||
{
|
||
parts.push(variant.to_string());
|
||
}
|
||
|
||
if !parts.is_empty() {
|
||
return format!("Ollama ({scope}) – {}", parts.join(" · "));
|
||
}
|
||
}
|
||
|
||
format!("Ollama ({scope}) model")
|
||
}
|
||
|
||
fn non_empty(value: String) -> Option<String> {
|
||
let trimmed = value.trim();
|
||
if trimmed.is_empty() {
|
||
None
|
||
} else {
|
||
Some(value)
|
||
}
|
||
}
|
||
|
||
fn pick_first_string(map: &JsonMap<String, Value>, keys: &[&str]) -> Option<String> {
|
||
keys.iter()
|
||
.filter_map(|key| map.get(*key))
|
||
.find_map(value_to_string)
|
||
.map(|s| s.trim().to_string())
|
||
.filter(|s| !s.is_empty())
|
||
}
|
||
|
||
fn pick_first_u64(map: &JsonMap<String, Value>, keys: &[&str]) -> Option<u64> {
|
||
keys.iter()
|
||
.filter_map(|key| map.get(*key))
|
||
.find_map(value_to_u64)
|
||
}
|
||
|
||
fn pick_string_list(map: &JsonMap<String, Value>, keys: &[&str]) -> Vec<String> {
|
||
for key in keys {
|
||
if let Some(value) = map.get(*key) {
|
||
match value {
|
||
Value::Array(items) => {
|
||
let collected: Vec<String> = items
|
||
.iter()
|
||
.filter_map(value_to_string)
|
||
.map(|s| s.trim().to_string())
|
||
.filter(|s| !s.is_empty())
|
||
.collect();
|
||
if !collected.is_empty() {
|
||
return collected;
|
||
}
|
||
}
|
||
Value::String(text) => {
|
||
let collected: Vec<String> = text
|
||
.split(',')
|
||
.map(|part| part.trim())
|
||
.filter(|part| !part.is_empty())
|
||
.map(|part| part.to_string())
|
||
.collect();
|
||
if !collected.is_empty() {
|
||
return collected;
|
||
}
|
||
}
|
||
_ => {}
|
||
}
|
||
}
|
||
}
|
||
Vec::new()
|
||
}
|
||
|
||
fn value_to_string(value: &Value) -> Option<String> {
|
||
match value {
|
||
Value::String(text) => Some(text.clone()),
|
||
Value::Number(num) => Some(num.to_string()),
|
||
Value::Bool(flag) => Some(flag.to_string()),
|
||
_ => None,
|
||
}
|
||
}
|
||
|
||
fn value_to_u64(value: &Value) -> Option<u64> {
|
||
match value {
|
||
Value::Number(num) => {
|
||
if let Some(v) = num.as_u64() {
|
||
Some(v)
|
||
} else if let Some(v) = num.as_i64() {
|
||
v.try_into().ok()
|
||
} else if let Some(v) = num.as_f64() {
|
||
if v >= 0.0 { Some(v as u64) } else { None }
|
||
} else {
|
||
None
|
||
}
|
||
}
|
||
Value::String(text) => text.trim().parse::<u64>().ok(),
|
||
_ => None,
|
||
}
|
||
}

fn format_local_unreachable_message(base_url: &str) -> String {
    let display = display_host_port(base_url);
    format!(
        "Ollama not reachable on {display}. Start the Ollama daemon (`ollama serve`) and try again."
    )
}

fn display_host_port(base_url: &str) -> String {
    Url::parse(base_url)
        .ok()
        .and_then(|url| {
            url.host_str().map(|host| {
                if let Some(port) = url.port() {
                    format!("{host}:{port}")
                } else {
                    host.to_string()
                }
            })
        })
        .unwrap_or_else(|| base_url.to_string())
}

fn is_connectivity_error(message: &str) -> bool {
    let lower = message.to_ascii_lowercase();
    const CONNECTIVITY_MARKERS: &[&str] = &[
        "connection refused",
        "failed to connect",
        "connect timeout",
        "timed out while contacting",
        "dns error",
        "failed to lookup address",
        "no route to host",
        "host unreachable",
    ];

    CONNECTIVITY_MARKERS
        .iter()
        .any(|marker| lower.contains(marker))
}

fn env_var_non_empty(name: &str) -> Option<String> {
    env::var(name)
        .ok()
        .map(|value| value.trim().to_string())
        .filter(|value| !value.is_empty())
}

/// Reads an API key from the environment variable named by the config hint,
/// warning once when a legacy variable name is used.
fn resolve_api_key_env_hint(env_var: Option<&str>) -> Option<String> {
    let var = env_var?.trim();
    if var.is_empty() {
        return None;
    }

    if var.eq_ignore_ascii_case(LEGACY_OLLAMA_CLOUD_API_KEY_ENV)
        || var.eq_ignore_ascii_case(LEGACY_OWLEN_OLLAMA_CLOUD_API_KEY_ENV)
    {
        warn_legacy_cloud_env(var);
    }

    env_var_non_empty(var)
}

/// Resolves a configured API key, expanding `${VAR}` or `$VAR` references via
/// the environment; blank values resolve to `None`.
fn resolve_api_key(configured: Option<String>) -> Option<String> {
    let raw = configured?.trim().to_string();
    if raw.is_empty() {
        return None;
    }

    if let Some(variable) = raw
        .strip_prefix("${")
        .and_then(|value| value.strip_suffix('}'))
        .or_else(|| raw.strip_prefix('$'))
    {
        let var_name = variable.trim();
        if var_name.is_empty() {
            return None;
        }
        return env_var_non_empty(var_name);
    }

    Some(raw)
}

fn map_reqwest_error(action: &str, err: reqwest::Error) -> Error {
    if err.is_timeout() {
        Error::Timeout(format!("Ollama {action} request timed out: {err}"))
    } else {
        Error::Network(format!("Ollama {action} request failed: {err}"))
    }
}

/// Normalises a user-supplied Ollama Cloud base URL.
pub(crate) fn normalize_cloud_base_url(input: Option<&str>) -> std::result::Result<String, String> {
    normalize_base_url(input, OllamaMode::Cloud)
}

/// Normalises a base URL for the given mode: defaults the scheme to https,
/// strips trailing `/api` or `/v1` segments, rejects any other path, enforces
/// https for cloud (unless `OWLEN_ALLOW_INSECURE_CLOUD` is set), and folds
/// `api.ollama.com` onto `ollama.com`.
fn normalize_base_url(
    input: Option<&str>,
    mode_hint: OllamaMode,
) -> std::result::Result<String, String> {
    let mut candidate = input
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .map(|value| value.to_string())
        .unwrap_or_else(|| mode_hint.default_base_url().to_string());

    if !candidate.starts_with("http://") && !candidate.starts_with("https://") {
        candidate = format!("https://{candidate}");
    }

    let mut url =
        Url::parse(&candidate).map_err(|err| format!("Invalid Ollama URL '{candidate}': {err}"))?;

    if url.cannot_be_a_base() {
        return Err(format!("URL '{candidate}' cannot be used as a base URL"));
    }

    if mode_hint == OllamaMode::Cloud
        && url.scheme() != "https"
        && env::var("OWLEN_ALLOW_INSECURE_CLOUD").is_err()
    {
        return Err("Ollama Cloud requires https:// base URLs".to_string());
    }

    let path = url.path().trim_end_matches('/');
    match path {
        "" | "/" => {}
        "/api" | "/v1" => {
            url.set_path("/");
        }
        _ => {
            return Err("Ollama base URLs must not include additional path segments".to_string());
        }
    }

    if mode_hint == OllamaMode::Cloud {
        if let Some(host) = url.host_str() {
            if host.eq_ignore_ascii_case("api.ollama.com") {
                url.set_host(Some("ollama.com"))
                    .map_err(|err| format!("Failed to normalise Ollama Cloud host: {err}"))?;
            }
        }
    }

    url.set_query(None);
    url.set_fragment(None);

    Ok(url.to_string().trim_end_matches('/').to_string())
}

fn normalize_cloud_endpoint(input: &str) -> std::result::Result<String, String> {
    normalize_base_url(Some(input), OllamaMode::Cloud)
}

/// Joins `endpoint` onto `base_url` under `/api`, avoiding a doubled `/api`
/// segment when the base already ends with one.
fn build_api_endpoint(base_url: &str, endpoint: &str) -> String {
    let trimmed_base = base_url.trim_end_matches('/');
    let trimmed_endpoint = endpoint.trim_start_matches('/');

    if trimmed_base.ends_with("/api") {
        format!("{trimmed_base}/{trimmed_endpoint}")
    } else {
        format!("{trimmed_base}/api/{trimmed_endpoint}")
    }
}

fn build_client_for_base(
    base_url: &str,
    timeout: Duration,
    api_key: Option<&str>,
) -> Result<(Ollama, Client)> {
    let url = Url::parse(base_url)
        .map_err(|err| Error::Config(format!("Invalid Ollama base URL '{base_url}': {err}")))?;

    let mut headers = HeaderMap::new();
    if let Some(key) = api_key {
        let value = HeaderValue::from_str(&format!("Bearer {key}"))
            .map_err(|_| Error::Config("OLLAMA API key contains invalid characters".to_string()))?;
        headers.insert(AUTHORIZATION, value);
    }

    let mut client_builder = Client::builder().timeout(timeout);
    if !headers.is_empty() {
        client_builder = client_builder.default_headers(headers.clone());
    }

    let http_client = client_builder.build().map_err(|err| {
        Error::Config(format!(
            "Failed to build HTTP client for '{base_url}': {err}"
        ))
    })?;

    let port = url.port_or_known_default().ok_or_else(|| {
        Error::Config(format!("Unable to determine port for Ollama URL '{url}'"))
    })?;

    let mut ollama_client = Ollama::new_with_client(url.clone(), port, http_client.clone());
    if !headers.is_empty() {
        ollama_client.set_headers(Some(headers));
    }

    Ok((ollama_client, http_client))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::McpToolDescriptor;
    use ollama_rs::generation::chat::ChatMessageFinalResponseData;
    use serde_json::{Map as JsonMap, Value, json};
    use std::collections::HashMap;

    #[test]
    fn resolve_api_key_prefers_literal_value() {
        assert_eq!(
            resolve_api_key(Some("direct-key".into())),
            Some("direct-key".into())
        );
    }

    #[test]
    fn resolve_api_key_expands_env_var() {
        unsafe {
            std::env::set_var("OLLAMA_TEST_KEY", "secret");
        }
        assert_eq!(
            resolve_api_key(Some("${OLLAMA_TEST_KEY}".into())),
            Some("secret".into())
        );
        unsafe {
            std::env::remove_var("OLLAMA_TEST_KEY");
        }
    }
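
    // Added coverage (sketch): blank configured keys resolve to `None`, and the
    // bare `$VAR` spelling works alongside `${VAR}`. Test names below are new.
    #[test]
    fn resolve_api_key_rejects_blank_values() {
        assert_eq!(resolve_api_key(None), None);
        assert_eq!(resolve_api_key(Some("   ".into())), None);
    }

    #[test]
    fn resolve_api_key_expands_bare_env_var() {
        unsafe {
            std::env::set_var("OLLAMA_TEST_BARE_KEY", "secret");
        }
        assert_eq!(
            resolve_api_key(Some("$OLLAMA_TEST_BARE_KEY".into())),
            Some("secret".into())
        );
        unsafe {
            std::env::remove_var("OLLAMA_TEST_BARE_KEY");
        }
    }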

    #[test]
    fn normalize_base_url_removes_api_path() {
        let url = normalize_base_url(Some("https://ollama.com/api"), OllamaMode::Cloud).unwrap();
        assert_eq!(url, "https://ollama.com");
    }

    #[test]
    fn normalize_base_url_accepts_v1_path_for_local() {
        let url = normalize_base_url(Some("http://localhost:11434/v1"), OllamaMode::Local).unwrap();
        assert_eq!(url, "http://localhost:11434");
    }

    #[test]
    fn normalize_base_url_accepts_v1_path_for_cloud() {
        let url = normalize_base_url(Some("https://api.ollama.com/v1"), OllamaMode::Cloud).unwrap();
        assert_eq!(url, "https://ollama.com");
    }

    #[test]
    fn normalize_base_url_canonicalises_api_hostname() {
        let url = normalize_base_url(Some("https://api.ollama.com"), OllamaMode::Cloud).unwrap();
        assert_eq!(url, "https://ollama.com");
    }

    #[test]
    fn normalize_base_url_rejects_cloud_without_https() {
        let err = normalize_base_url(Some("http://ollama.com"), OllamaMode::Cloud).unwrap_err();
        assert!(err.contains("https"));
    }
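
    // Added coverage (sketch): scheme defaulting in `normalize_base_url` and the
    // `/api` joining rule in `build_api_endpoint`. Test names below are new.
    #[test]
    fn normalize_base_url_defaults_scheme_to_https() {
        let url = normalize_base_url(Some("ollama.com"), OllamaMode::Cloud).unwrap();
        assert_eq!(url, "https://ollama.com");
    }

    #[test]
    fn build_api_endpoint_joins_path_segments_once() {
        assert_eq!(
            build_api_endpoint("http://localhost:11434", "tags"),
            "http://localhost:11434/api/tags"
        );
        assert_eq!(
            build_api_endpoint("http://localhost:11434/api/", "/tags"),
            "http://localhost:11434/api/tags"
        );
    }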

    #[test]
    fn explicit_local_mode_overrides_api_key() {
        let mut config = ProviderConfig {
            enabled: true,
            provider_type: "ollama".to_string(),
            base_url: Some("http://localhost:11434".to_string()),
            api_key: Some("secret-key".to_string()),
            api_key_env: None,
            extra: HashMap::new(),
        };
        config.extra.insert(
            OLLAMA_MODE_KEY.to_string(),
            Value::String("local".to_string()),
        );

        let provider = OllamaProvider::from_config("ollama_local", &config, None)
            .expect("provider constructed");

        assert_eq!(provider.mode, OllamaMode::Local);
        assert_eq!(provider.base_url, "http://localhost:11434");
    }

    #[test]
    fn auto_mode_prefers_explicit_local_base() {
        let config = ProviderConfig {
            enabled: true,
            provider_type: "ollama".to_string(),
            base_url: Some("http://localhost:11434".to_string()),
            api_key: Some("secret-key".to_string()),
            api_key_env: None,
            extra: HashMap::new(),
        };
        // simulate missing explicit mode; defaults to auto

        let provider = OllamaProvider::from_config("ollama_local", &config, None)
            .expect("provider constructed");

        assert_eq!(provider.mode, OllamaMode::Local);
        assert_eq!(provider.base_url, "http://localhost:11434");
    }

    #[test]
    fn auto_mode_with_api_key_and_no_local_probe_switches_to_cloud() {
        let mut config = ProviderConfig {
            enabled: true,
            provider_type: "ollama".to_string(),
            base_url: None,
            api_key: Some("secret-key".to_string()),
            api_key_env: None,
            extra: HashMap::new(),
        };
        config.extra.insert(
            OLLAMA_MODE_KEY.to_string(),
            Value::String("auto".to_string()),
        );

        let provider = OllamaProvider::from_config("ollama_cloud", &config, None)
            .expect("provider constructed");

        assert_eq!(provider.mode, OllamaMode::Cloud);
        assert_eq!(provider.base_url, CLOUD_BASE_URL);
    }

    #[test]
    fn cloud_provider_requires_api_key() {
        let _primary = EnvVarGuard::clear(OLLAMA_API_KEY_ENV);
        let _legacy_primary = EnvVarGuard::clear(LEGACY_OLLAMA_CLOUD_API_KEY_ENV);
        let _legacy_secondary = EnvVarGuard::clear(LEGACY_OWLEN_OLLAMA_CLOUD_API_KEY_ENV);

        let config = ProviderConfig {
            enabled: true,
            provider_type: "ollama_cloud".to_string(),
            base_url: None,
            api_key: None,
            api_key_env: None,
            extra: HashMap::new(),
        };

        let err = OllamaProvider::from_config("ollama_cloud", &config, None)
            .expect_err("expected config error");
        match err {
            Error::Config(message) => {
                assert!(message.contains("API key"));
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }

    #[test]
    fn cloud_provider_uses_explicit_api_key() {
        let config = ProviderConfig {
            enabled: true,
            provider_type: "ollama_cloud".to_string(),
            base_url: None,
            api_key: Some("secret-cloud-key".to_string()),
            api_key_env: None,
            extra: HashMap::new(),
        };

        let provider = OllamaProvider::from_config("ollama_cloud", &config, None)
            .expect("provider constructed");
        assert_eq!(provider.name(), "ollama_cloud");
        assert_eq!(provider.mode, OllamaMode::Cloud);
        assert_eq!(provider.base_url, CLOUD_BASE_URL);
    }

    #[test]
    fn cloud_provider_reads_api_key_from_env_hint() {
        let config = ProviderConfig {
            enabled: true,
            provider_type: "ollama_cloud".to_string(),
            base_url: None,
            api_key: None,
            api_key_env: Some("OLLAMA_TEST_CLOUD_KEY".to_string()),
            extra: HashMap::new(),
        };

        unsafe {
            std::env::set_var("OLLAMA_TEST_CLOUD_KEY", "env-secret");
        }

        assert!(std::env::var("OLLAMA_TEST_CLOUD_KEY").is_ok());
        assert!(resolve_api_key_env_hint(config.api_key_env.as_deref()).is_some());
        assert_eq!(config.api_key_env.as_deref(), Some("OLLAMA_TEST_CLOUD_KEY"));

        let provider = OllamaProvider::from_config("ollama_cloud", &config, None)
            .expect("provider constructed");
        assert_eq!(provider.name(), "ollama_cloud");
        assert_eq!(provider.mode, OllamaMode::Cloud);

        unsafe {
            std::env::remove_var("OLLAMA_TEST_CLOUD_KEY");
        }
    }

    #[test]
    fn fetch_scope_tags_with_retry_success_uses_override() {
        let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
        let handle = ScopeHandle::new(
            provider.client.clone(),
            provider.http_client.clone(),
            provider.base_url.clone(),
        );

        let _guard = set_tags_override(vec![Ok(vec![LocalModel {
            name: "llama3".into(),
            modified_at: "2024-01-01T00:00:00Z".into(),
            size: 42,
        }])]);

        let models = block_on(provider.fetch_scope_tags_with_retry(OllamaMode::Local, &handle))
            .expect("models returned");
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "llama3");
    }

    #[test]
    fn convert_model_propagates_context_window_from_details() {
        let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
        let local = LocalModel {
            name: "gemma3n:e4b".into(),
            modified_at: "2024-01-01T00:00:00Z".into(),
            size: 0,
        };

        let mut meta = JsonMap::new();
        meta.insert(
            "context_length".into(),
            Value::Number(serde_json::Number::from(32_768)),
        );

        let detail = OllamaModelInfo {
            license: String::new(),
            modelfile: String::new(),
            parameters: String::new(),
            template: String::new(),
            model_info: meta,
            capabilities: vec![],
        };

        let info = provider.convert_model(OllamaMode::Local, local, Some(detail));
        assert_eq!(info.context_window, Some(32_768));
    }

    #[test]
    fn fetch_scope_tags_with_retry_retries_on_timeout_then_succeeds() {
        let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
        let handle = ScopeHandle::new(
            provider.client.clone(),
            provider.http_client.clone(),
            provider.base_url.clone(),
        );

        let _guard = set_tags_override(vec![
            Err(Error::Timeout("first attempt".into())),
            Ok(vec![LocalModel {
                name: "llama3".into(),
                modified_at: "2024-01-01T00:00:00Z".into(),
                size: 42,
            }]),
        ]);

        let models = block_on(provider.fetch_scope_tags_with_retry(OllamaMode::Local, &handle))
            .expect("models returned after retry");
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "llama3");
    }

    #[test]
    fn decorate_scope_error_returns_friendly_message_for_connectivity() {
        let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
        let (error, message) = provider.decorate_scope_error(
            OllamaMode::Local,
            "http://localhost:11434",
            Error::Network("failed to connect to host".into()),
        );

        assert!(matches!(
            error,
            Error::Network(ref text) if text.contains("Ollama not reachable")
        ));
        assert!(message.contains("Ollama not reachable"));
        assert!(message.contains("localhost:11434"));
    }

    #[test]
    fn decorate_scope_error_preserves_http_failure_message() {
        let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
        let original = "Ollama list models failed (500): boom".to_string();
        let (error, message) = provider.decorate_scope_error(
            OllamaMode::Local,
            "http://localhost:11434",
            Error::Network(original.clone()),
        );

        assert!(matches!(error, Error::Network(ref text) if text.contains("500")));
        assert_eq!(message, original);
    }

    #[test]
    fn decorate_scope_error_translates_timeout() {
        let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
        let (error, message) = provider.decorate_scope_error(
            OllamaMode::Local,
            "http://localhost:11434",
            Error::Timeout("deadline exceeded".into()),
        );

        assert!(matches!(
            error,
            Error::Timeout(ref text) if text.contains("Ollama not reachable")
        ));
        assert!(message.contains("Ollama not reachable"));
    }
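
    // Added coverage (sketch) for the helpers behind the friendly connectivity
    // message: marker matching and host:port display. Test names below are new.
    #[test]
    fn is_connectivity_error_matches_known_markers() {
        assert!(is_connectivity_error("Connection refused (os error 111)"));
        assert!(!is_connectivity_error("HTTP 500 internal server error"));
    }

    #[test]
    fn display_host_port_strips_scheme_and_falls_back_to_input() {
        assert_eq!(
            display_host_port("http://localhost:11434"),
            "localhost:11434"
        );
        assert_eq!(display_host_port("not a url"), "not a url");
    }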

    #[test]
    fn map_http_failure_model_not_found_suggests_pull_hint() {
        let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");
        let err = provider.map_http_failure(
            "chat",
            StatusCode::NOT_FOUND,
            "missing model".to_string(),
            Some("llama3"),
        );

        let message = match err {
            Error::InvalidInput(message) => message,
            other => panic!("unexpected error variant: {other:?}"),
        };

        assert!(message.contains("ollama pull llama3"));
    }

    #[test]
    fn build_model_options_merges_parameters() {
        let mut parameters = ChatParameters::default();
        parameters.temperature = Some(0.3);
        parameters.max_tokens = Some(128);
        parameters
            .extra
            .insert("num_ctx".into(), Value::from(4096_u64));

        let options = build_model_options(&parameters)
            .expect("options built")
            .expect("options present");
        let serialized = serde_json::to_value(&options).expect("serialize options");
        let temperature = serialized["temperature"]
            .as_f64()
            .expect("temperature present");
        assert!((temperature - 0.3).abs() < 1e-6);
        assert_eq!(serialized["num_predict"], json!(128));
        assert_eq!(serialized["num_ctx"], json!(4096));
    }
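
    // Added coverage (sketch): the lenient numeric coercion in `value_to_u64`,
    // which `pick_first_u64` relies on when reading model metadata. The test
    // name is new.
    #[test]
    fn value_to_u64_coerces_numbers_and_strings() {
        assert_eq!(value_to_u64(&json!(42)), Some(42));
        assert_eq!(value_to_u64(&json!("7 ")), Some(7));
        assert_eq!(value_to_u64(&json!(3.9)), Some(3));
        assert_eq!(value_to_u64(&json!(-1)), None);
    }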

    #[test]
    fn prepare_chat_request_serializes_tool_descriptors() {
        let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");

        let descriptor = McpToolDescriptor {
            name: crate::tools::WEB_SEARCH_TOOL_NAME.to_string(),
            description: "Perform a web search".to_string(),
            input_schema: json!({
                "type": "object",
                "properties": {
                    "query": {"type": "string"}
                },
                "required": ["query"]
            }),
            requires_network: true,
            requires_filesystem: Vec::new(),
        };

        let (_model_id, request) = provider
            .prepare_chat_request(
                "llama3".to_string(),
                vec![Message::user("Hello".to_string())],
                ChatParameters::default(),
                Some(vec![descriptor.clone()]),
            )
            .expect("request built");

        assert_eq!(request.tools.len(), 1);
        let tool = &request.tools[0];
        assert_eq!(tool.function.name, descriptor.name);
        assert_eq!(tool.function.description, descriptor.description);

        let serialized = serde_json::to_value(&tool.function.parameters).expect("serialize schema");
        assert_eq!(serialized, descriptor.input_schema);
    }

    #[test]
    fn convert_model_marks_tool_capability() {
        let provider = OllamaProvider::new("http://localhost:11434").expect("provider constructed");

        let local = LocalModel {
            name: "llama3-tool".to_string(),
            modified_at: "2025-10-23T00:00:00Z".to_string(),
            size: 0,
        };

        let detail = OllamaModelInfo {
            license: String::new(),
            modelfile: String::new(),
            parameters: String::new(),
            template: String::new(),
            model_info: JsonMap::new(),
            capabilities: vec!["function_call".to_string()],
        };

        let info = provider.convert_model(OllamaMode::Local, local, Some(detail));
        assert!(info.supports_tools);
    }
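
    // Added coverage (sketch): `pick_string_list` accepts both JSON arrays and
    // comma-separated strings when extracting capability-style lists. The test
    // name is new.
    #[test]
    fn pick_string_list_accepts_arrays_and_csv() {
        let mut map = JsonMap::new();
        map.insert("caps".into(), json!(["vision", " tools "]));
        assert_eq!(pick_string_list(&map, &["caps"]), vec!["vision", "tools"]);

        map.insert("caps".into(), json!("vision, tools"));
        assert_eq!(pick_string_list(&map, &["caps"]), vec!["vision", "tools"]);
    }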

    #[test]
    fn convert_response_attaches_provider_metadata() {
        let final_data = ChatMessageFinalResponseData {
            total_duration: 10,
            load_duration: 2,
            prompt_eval_count: 42,
            prompt_eval_duration: 4,
            eval_count: 21,
            eval_duration: 6,
        };

        let response = OllamaChatResponse {
            model: "llama3".to_string(),
            created_at: "2025-10-23T18:00:00Z".to_string(),
            message: OllamaMessage {
                role: OllamaRole::Assistant,
                content: "Tool output incoming".to_string(),
                tool_calls: Vec::new(),
                images: None,
                thinking: None,
            },
            done: true,
            final_data: Some(final_data),
        };

        let chunk = OllamaProvider::convert_ollama_response(response, false);

        let metadata = chunk
            .message
            .metadata
            .get("ollama")
            .and_then(Value::as_object)
            .expect("ollama metadata present");
        assert_eq!(
            metadata.get("model").and_then(Value::as_str),
            Some("llama3")
        );
        assert!(metadata.contains_key("final_data"));
        assert_eq!(
            metadata.get("created_at").and_then(Value::as_str).unwrap(),
            "2025-10-23T18:00:00Z"
        );

        let usage = chunk.usage.expect("usage populated");
        assert_eq!(usage.prompt_tokens, 42);
        assert_eq!(usage.completion_tokens, 21);
    }
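
    // Added coverage (sketch): `build_model_description` joins whatever metadata
    // fields are present and falls back to a generic label. The test name is new.
    #[test]
    fn build_model_description_joins_known_fields() {
        let mut meta = JsonMap::new();
        meta.insert("family".into(), json!("llama"));
        meta.insert("parameter_size".into(), json!("8B"));
        let detail = OllamaModelInfo {
            license: String::new(),
            modelfile: String::new(),
            parameters: String::new(),
            template: String::new(),
            model_info: meta,
            capabilities: vec![],
        };

        assert_eq!(
            build_model_description("local", Some(&detail)),
            "Ollama (local) – llama · 8B"
        );
        assert_eq!(build_model_description("cloud", None), "Ollama (cloud) model");
    }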

    #[test]
    fn heuristic_capabilities_detects_thinking_models() {
        let caps = heuristic_capabilities("deepseek-r1");
        assert!(caps.iter().any(|cap| cap == "thinking"));
    }

    #[test]
    fn push_capability_avoids_duplicates() {
        let mut caps = vec!["chat".to_string()];
        push_capability(&mut caps, "Chat");
        push_capability(&mut caps, "Vision");
        push_capability(&mut caps, "vision");

        assert_eq!(caps.len(), 2);
        assert!(caps.iter().any(|cap| cap == "vision"));
    }
}

#[cfg(test)]
struct ProbeOverrideGuard {
    gate: Option<MutexGuard<'static, ()>>,
}

#[cfg(test)]
impl ProbeOverrideGuard {
    fn set(value: Option<bool>) -> Self {
        let gate = PROBE_OVERRIDE_GATE
            .get_or_init(|| Mutex::new(()))
            .lock()
            .expect("probe override gate mutex poisoned");
        set_probe_override(value);
        ProbeOverrideGuard { gate: Some(gate) }
    }
}

#[cfg(test)]
impl Drop for ProbeOverrideGuard {
    fn drop(&mut self) {
        set_probe_override(None);
        self.gate.take();
    }
}

#[cfg(test)]
struct EnvVarGuard {
    key: &'static str,
    original: Option<String>,
}

#[cfg(test)]
impl EnvVarGuard {
    fn clear(key: &'static str) -> Self {
        let original = std::env::var(key).ok();
        unsafe {
            std::env::remove_var(key);
        }
        Self { key, original }
    }
}

#[cfg(test)]
impl Drop for EnvVarGuard {
    fn drop(&mut self) {
        match &self.original {
            Some(value) => unsafe {
                std::env::set_var(self.key, value);
            },
            None => unsafe {
                std::env::remove_var(self.key);
            },
        }
    }
}

#[test]
fn auto_mode_with_api_key_and_successful_probe_prefers_local() {
    let _guard = ProbeOverrideGuard::set(Some(true));

    let mut config = ProviderConfig {
        enabled: true,
        provider_type: "ollama".to_string(),
        base_url: None,
        api_key: Some("secret-key".to_string()),
        api_key_env: None,
        extra: HashMap::new(),
    };
    config.extra.insert(
        OLLAMA_MODE_KEY.to_string(),
        Value::String("auto".to_string()),
    );

    assert!(probe_default_local_daemon(Duration::from_millis(1)));

    let provider =
        OllamaProvider::from_config("ollama_local", &config, None).expect("provider constructed");

    assert_eq!(provider.mode, OllamaMode::Local);
    assert_eq!(provider.base_url, "http://localhost:11434");
}

#[test]
fn auto_mode_with_api_key_and_failed_probe_prefers_cloud() {
    let _guard = ProbeOverrideGuard::set(Some(false));

    let mut config = ProviderConfig {
        enabled: true,
        provider_type: "ollama".to_string(),
        base_url: None,
        api_key: Some("secret-key".to_string()),
        api_key_env: None,
        extra: HashMap::new(),
    };
    config.extra.insert(
        OLLAMA_MODE_KEY.to_string(),
        Value::String("auto".to_string()),
    );

    let provider =
        OllamaProvider::from_config("ollama_cloud", &config, None).expect("provider constructed");

    assert_eq!(provider.mode, OllamaMode::Cloud);
    assert_eq!(provider.base_url, CLOUD_BASE_URL);
}

#[test]
fn annotate_scope_status_adds_capabilities_for_unavailable_scopes() {
    let config = ProviderConfig {
        enabled: true,
        provider_type: "ollama".to_string(),
        base_url: Some("http://localhost:11434".to_string()),
        api_key: None,
        api_key_env: None,
        extra: HashMap::new(),
    };

    let provider =
        OllamaProvider::from_config("ollama_local", &config, None).expect("provider constructed");

    let mut models = vec![ModelInfo {
        id: "llama3".to_string(),
        name: "Llama 3".to_string(),
        description: None,
        provider: "ollama".to_string(),
        context_window: None,
        capabilities: vec!["scope:local".to_string()],
        supports_tools: false,
    }];

    block_on(async {
        {
            let mut cache = provider.scope_cache.write().await;
            let entry = cache.entry(OllamaMode::Cloud).or_default();
            entry.availability = ScopeAvailability::Unavailable;
            entry.last_error = Some("Cloud endpoint unreachable".to_string());
            entry.last_checked = Some(Instant::now());
        }

        provider.annotate_scope_status(&mut models).await;
    });

    let capabilities = &models[0].capabilities;
    assert!(
        capabilities
            .iter()
            .any(|cap| cap == "scope-status:cloud:unavailable")
    );
    assert!(
        capabilities
            .iter()
            .any(|cap| cap.starts_with("scope-status-message:cloud:"))
    );
    assert!(
        capabilities
            .iter()
            .any(|cap| cap.starts_with("scope-status-age:cloud:"))
    );
    assert!(
        capabilities
            .iter()
            .any(|cap| cap == "scope-status-stale:cloud:0")
    );
}