feat(ollama): add explicit Ollama mode config, cloud endpoint storage, and scope‑availability caching with status annotations.
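How callers opt into the new behaviour, as a minimal sketch: the ProviderConfig shape, the from_config(&config, None) call, and the OLLAMA_MODE_KEY / OLLAMA_CLOUD_ENDPOINT_KEY constants all appear in this diff, while the helper name, the crate-level Result alias used for the return type, and the concrete endpoint value are illustrative assumptions only.

use std::collections::HashMap;
use serde_json::Value;

// Hypothetical helper (not part of this diff): pins the provider to the cloud scope
// and stores an explicit cloud endpoint through the new config keys. The literal key
// strings live in the config module and are not shown here.
fn cloud_pinned_provider() -> Result<OllamaProvider> {
    let mut config = ProviderConfig {
        provider_type: "ollama".to_string(),
        base_url: None,
        api_key: Some("secret-key".to_string()),
        extra: HashMap::new(),
    };
    config.extra.insert(
        OLLAMA_MODE_KEY.to_string(),
        Value::String("cloud".to_string()),
    );
    config.extra.insert(
        OLLAMA_CLOUD_ENDPOINT_KEY.to_string(),
        Value::String("https://ollama.com".to_string()),
    );
    OllamaProvider::from_config(&config, None)
}

Leaving the mode unset (or set to "auto") keeps the resolution order introduced here: an explicit local or cloud base URL wins, otherwise the provider probes 127.0.0.1:11434 / [::1]:11434 and falls back to cloud only when an API key is present.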
@@ -1,9 +1,11 @@
//! Ollama provider built on top of the `ollama-rs` crate.
use std::{
    collections::HashMap,
    collections::{HashMap, HashSet},
    env,
    net::{SocketAddr, TcpStream},
    pin::Pin,
    time::{Duration, SystemTime},
    sync::Arc,
    time::{Duration, Instant, SystemTime},
};

use anyhow::anyhow;
@@ -22,11 +24,17 @@ use ollama_rs::{
};
use reqwest::{Client, StatusCode, Url};
use serde_json::{Map as JsonMap, Value, json};
use tokio::{sync::RwLock, time::timeout};

#[cfg(test)]
use std::sync::{Mutex, OnceLock};
#[cfg(test)]
use tokio_test::block_on;
use uuid::Uuid;

use crate::{
    Error, Result,
    config::GeneralSettings,
    config::{GeneralSettings, OLLAMA_CLOUD_BASE_URL, OLLAMA_CLOUD_ENDPOINT_KEY, OLLAMA_MODE_KEY},
    llm::{LlmProvider, ProviderConfig},
    mcp::McpToolDescriptor,
    model::{DetailedModelInfo, ModelDetailsCache, ModelManager},
@@ -37,9 +45,11 @@ use crate::{

const DEFAULT_TIMEOUT_SECS: u64 = 120;
const DEFAULT_MODEL_CACHE_TTL_SECS: u64 = 60;
const CLOUD_BASE_URL: &str = "https://ollama.com";
pub(crate) const CLOUD_BASE_URL: &str = OLLAMA_CLOUD_BASE_URL;
const LOCAL_PROBE_TIMEOUT_MS: u64 = 200;
const LOCAL_PROBE_TARGETS: &[&str] = &["127.0.0.1:11434", "[::1]:11434"];

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum OllamaMode {
    Local,
    Cloud,
@@ -54,6 +64,44 @@ impl OllamaMode {
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ScopeAvailability {
    Unknown,
    Available,
    Unavailable,
}

impl ScopeAvailability {
    fn as_str(self) -> &'static str {
        match self {
            ScopeAvailability::Unknown => "unknown",
            ScopeAvailability::Available => "available",
            ScopeAvailability::Unavailable => "unavailable",
        }
    }
}

#[derive(Debug, Clone)]
struct ScopeSnapshot {
    models: Vec<ModelInfo>,
    fetched_at: Option<Instant>,
    availability: ScopeAvailability,
    last_error: Option<String>,
    last_checked: Option<Instant>,
}

impl Default for ScopeSnapshot {
    fn default() -> Self {
        Self {
            models: Vec::new(),
            fetched_at: None,
            availability: ScopeAvailability::Unknown,
            last_error: None,
            last_checked: None,
        }
    }
}

#[derive(Debug)]
struct OllamaOptions {
    mode: OllamaMode,
@@ -61,6 +109,7 @@ struct OllamaOptions {
    request_timeout: Duration,
    model_cache_ttl: Duration,
    api_key: Option<String>,
    cloud_endpoint: Option<String>,
}

impl OllamaOptions {
@@ -71,6 +120,7 @@ impl OllamaOptions {
            request_timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
            model_cache_ttl: Duration::from_secs(DEFAULT_MODEL_CACHE_TTL_SECS),
            api_key: None,
            cloud_endpoint: None,
        }
    }
@@ -87,8 +137,78 @@ pub struct OllamaProvider {
    client: Ollama,
    http_client: Client,
    base_url: String,
    request_timeout: Duration,
    api_key: Option<String>,
    cloud_endpoint: Option<String>,
    model_manager: ModelManager,
    model_details_cache: ModelDetailsCache,
    model_cache_ttl: Duration,
    scope_cache: Arc<RwLock<HashMap<OllamaMode, ScopeSnapshot>>>,
}

fn configured_mode_from_extra(config: &ProviderConfig) -> Option<OllamaMode> {
    config
        .extra
        .get(OLLAMA_MODE_KEY)
        .and_then(|value| value.as_str())
        .and_then(|value| match value.trim().to_ascii_lowercase().as_str() {
            "local" => Some(OllamaMode::Local),
            "cloud" => Some(OllamaMode::Cloud),
            _ => None,
        })
}

fn is_explicit_local_base(base_url: Option<&str>) -> bool {
    base_url
        .and_then(|raw| Url::parse(raw).ok())
        .and_then(|parsed| parsed.host_str().map(|host| host.to_ascii_lowercase()))
        .map(|host| host == "localhost" || host == "127.0.0.1" || host == "::1")
        .unwrap_or(false)
}

fn is_explicit_cloud_base(base_url: Option<&str>) -> bool {
    base_url
        .map(|raw| {
            let trimmed = raw.trim_end_matches('/');
            trimmed == CLOUD_BASE_URL || trimmed.starts_with("https://ollama.com/")
        })
        .unwrap_or(false)
}

#[cfg(test)]
static PROBE_OVERRIDE: OnceLock<Mutex<Option<bool>>> = OnceLock::new();

#[cfg(test)]
fn set_probe_override(value: Option<bool>) {
    let guard = PROBE_OVERRIDE.get_or_init(|| Mutex::new(None));
    *guard.lock().expect("probe override mutex poisoned") = value;
}

#[cfg(test)]
fn probe_override_value() -> Option<bool> {
    PROBE_OVERRIDE
        .get_or_init(|| Mutex::new(None))
        .lock()
        .expect("probe override mutex poisoned")
        .to_owned()
}

fn probe_default_local_daemon(timeout: Duration) -> bool {
    #[cfg(test)]
    {
        if let Some(value) = probe_override_value() {
            return value;
        }
    }

    for target in LOCAL_PROBE_TARGETS {
        if let Ok(address) = target.parse::<SocketAddr>() {
            if TcpStream::connect_timeout(&address, timeout).is_ok() {
                return true;
            }
        }
    }
    false
}

impl OllamaProvider {
@@ -105,23 +225,64 @@ impl OllamaProvider {
        let mut api_key = resolve_api_key(config.api_key.clone())
            .or_else(|| env_var_non_empty("OLLAMA_API_KEY"))
            .or_else(|| env_var_non_empty("OLLAMA_CLOUD_API_KEY"));
        let configured_mode = configured_mode_from_extra(config);
        let configured_mode_label = config
            .extra
            .get(OLLAMA_MODE_KEY)
            .and_then(|value| value.as_str())
            .unwrap_or("auto");
        let base_url = config.base_url.as_deref();
        let base_is_local = is_explicit_local_base(base_url);
        let base_is_cloud = is_explicit_cloud_base(base_url);
        let base_is_other = base_url.is_some() && !base_is_local && !base_is_cloud;

        let mode = if api_key.is_some() {
            OllamaMode::Cloud
        } else {
            OllamaMode::Local
        let mut local_probe_result = None;
        let cloud_endpoint = config
            .extra
            .get(OLLAMA_CLOUD_ENDPOINT_KEY)
            .and_then(Value::as_str)
            .map(normalize_cloud_endpoint)
            .transpose()
            .map_err(Error::Config)?;

        let mode = match configured_mode {
            Some(mode) => mode,
            None => {
                if base_is_local || base_is_other {
                    OllamaMode::Local
                } else if base_is_cloud && api_key.is_some() {
                    OllamaMode::Cloud
                } else {
                    let probe =
                        probe_default_local_daemon(Duration::from_millis(LOCAL_PROBE_TIMEOUT_MS));
                    local_probe_result = Some(probe);
                    if probe {
                        OllamaMode::Local
                    } else if api_key.is_some() {
                        OllamaMode::Cloud
                    } else {
                        OllamaMode::Local
                    }
                }
            }
        };

        let base_candidate = if mode == OllamaMode::Cloud {
            Some(CLOUD_BASE_URL)
        } else {
            config.base_url.as_deref()
        let base_candidate = match mode {
            OllamaMode::Local => base_url,
            OllamaMode::Cloud => {
                if base_is_cloud {
                    base_url
                } else {
                    Some(CLOUD_BASE_URL)
                }
            }
        };

        let normalized_base_url =
            normalize_base_url(base_candidate, mode).map_err(Error::Config)?;

        let mut options = OllamaOptions::new(mode, normalized_base_url);
        let mut options = OllamaOptions::new(mode, normalized_base_url.clone());
        options.cloud_endpoint = cloud_endpoint.clone();

        if let Some(timeout) = config
            .extra
@@ -145,6 +306,23 @@ impl OllamaProvider {
            options = options.with_general(general);
        }

        debug!(
            "Resolved Ollama provider: mode={:?}, base_url={}, configured_mode={}, api_key_present={}, local_probe={}",
            mode,
            normalized_base_url,
            configured_mode_label,
            if options.api_key.is_some() {
                "yes"
            } else {
                "no"
            },
            match local_probe_result {
                Some(true) => "success",
                Some(false) => "failed",
                None => "skipped",
            }
        );

        Self::with_options(options)
    }
@@ -155,44 +333,32 @@ impl OllamaProvider {
            request_timeout,
            model_cache_ttl,
            api_key,
            cloud_endpoint,
        } = options;

        let url = Url::parse(&base_url)
            .map_err(|err| Error::Config(format!("Invalid Ollama base URL '{base_url}': {err}")))?;
        let api_key_ref = api_key.as_deref();
        let (ollama_client, http_client) =
            build_client_for_base(&base_url, request_timeout, api_key_ref)?;

        let mut headers = HeaderMap::new();
        if let Some(ref key) = api_key {
            let value = HeaderValue::from_str(&format!("Bearer {key}")).map_err(|_| {
                Error::Config("OLLAMA API key contains invalid characters".to_string())
            })?;
            headers.insert(AUTHORIZATION, value);
        }

        let mut client_builder = Client::builder().timeout(request_timeout);
        if !headers.is_empty() {
            client_builder = client_builder.default_headers(headers.clone());
        }

        let http_client = client_builder
            .build()
            .map_err(|err| Error::Config(format!("Failed to build HTTP client: {err}")))?;

        let port = url.port_or_known_default().ok_or_else(|| {
            Error::Config(format!("Unable to determine port for Ollama URL '{}'", url))
        })?;

        let mut ollama_client = Ollama::new_with_client(url.clone(), port, http_client.clone());
        if !headers.is_empty() {
            ollama_client.set_headers(Some(headers.clone()));
        }
        let scope_cache = {
            let mut initial = HashMap::new();
            initial.insert(OllamaMode::Local, ScopeSnapshot::default());
            initial.insert(OllamaMode::Cloud, ScopeSnapshot::default());
            Arc::new(RwLock::new(initial))
        };

        Ok(Self {
            mode,
            client: ollama_client,
            http_client,
            base_url: base_url.trim_end_matches('/').to_string(),
            request_timeout,
            api_key,
            cloud_endpoint,
            model_manager: ModelManager::new(model_cache_ttl),
            model_details_cache: ModelDetailsCache::new(model_cache_ttl),
            model_cache_ttl,
            scope_cache,
        })
    }
@@ -200,6 +366,121 @@ impl OllamaProvider {
        build_api_endpoint(&self.base_url, endpoint)
    }

    fn local_base_url() -> &'static str {
        OllamaMode::Local.default_base_url()
    }

    fn scope_key(scope: OllamaMode) -> &'static str {
        match scope {
            OllamaMode::Local => "local",
            OllamaMode::Cloud => "cloud",
        }
    }

    fn build_local_client(&self) -> Result<Option<Ollama>> {
        if matches!(self.mode, OllamaMode::Local) {
            return Ok(Some(self.client.clone()));
        }

        let (client, _) =
            build_client_for_base(Self::local_base_url(), self.request_timeout, None)?;
        Ok(Some(client))
    }

    fn build_cloud_client(&self) -> Result<Option<Ollama>> {
        if matches!(self.mode, OllamaMode::Cloud) {
            return Ok(Some(self.client.clone()));
        }

        let api_key = match self.api_key.as_deref() {
            Some(key) if !key.trim().is_empty() => key,
            _ => return Ok(None),
        };

        let endpoint = self.cloud_endpoint.as_deref().unwrap_or(CLOUD_BASE_URL);

        let (client, _) = build_client_for_base(endpoint, self.request_timeout, Some(api_key))?;
        Ok(Some(client))
    }

    async fn cached_scope_models(&self, scope: OllamaMode) -> Option<Vec<ModelInfo>> {
        let cache = self.scope_cache.read().await;
        cache.get(&scope).and_then(|entry| {
            if entry.availability == ScopeAvailability::Unknown {
                return None;
            }

            entry.fetched_at.and_then(|ts| {
                if ts.elapsed() < self.model_cache_ttl {
                    Some(entry.models.clone())
                } else {
                    None
                }
            })
        })
    }

    async fn update_scope_success(&self, scope: OllamaMode, models: &[ModelInfo]) {
        let mut cache = self.scope_cache.write().await;
        let entry = cache.entry(scope).or_default();
        entry.models = models.to_vec();
        entry.fetched_at = Some(Instant::now());
        entry.last_checked = Some(Instant::now());
        entry.availability = ScopeAvailability::Available;
        entry.last_error = None;
    }

    async fn mark_scope_failure(&self, scope: OllamaMode, message: String) {
        let mut cache = self.scope_cache.write().await;
        let entry = cache.entry(scope).or_default();
        entry.availability = ScopeAvailability::Unavailable;
        entry.last_error = Some(message);
        entry.last_checked = Some(Instant::now());
    }

    async fn annotate_scope_status(&self, models: &mut [ModelInfo]) {
        if models.is_empty() {
            return;
        }

        let cache = self.scope_cache.read().await;
        for (scope, snapshot) in cache.iter() {
            if snapshot.availability == ScopeAvailability::Unknown {
                continue;
            }
            let scope_key = Self::scope_key(*scope);
            let capability = format!(
                "scope-status:{}:{}",
                scope_key,
                snapshot.availability.as_str()
            );

            for model in models.iter_mut() {
                if !model.capabilities.iter().any(|cap| cap == &capability) {
                    model.capabilities.push(capability.clone());
                }
            }

            if let Some(raw_reason) = snapshot.last_error.as_ref() {
                let cleaned = raw_reason.replace('\n', " ").trim().to_string();
                if !cleaned.is_empty() {
                    let truncated: String = cleaned.chars().take(160).collect();
                    let message_capability =
                        format!("scope-status-message:{}:{}", scope_key, truncated);
                    for model in models.iter_mut() {
                        if !model
                            .capabilities
                            .iter()
                            .any(|cap| cap == &message_capability)
                        {
                            model.capabilities.push(message_capability.clone());
                        }
                    }
                }
            }
        }
    }

    /// Attempt to resolve detailed model information for the given model, using the local cache when possible.
    pub async fn get_model_info(&self, model_name: &str) -> Result<DetailedModelInfo> {
        if let Some(info) = self.model_details_cache.get(model_name).await {
@@ -312,15 +593,92 @@ impl OllamaProvider {
    }

    async fn fetch_models(&self) -> Result<Vec<ModelInfo>> {
        let models = self
            .client
            .list_local_models()
            .await
            .map_err(|err| self.map_ollama_error("list models", err, None))?;
        let mut combined = Vec::new();
        let mut seen: HashSet<String> = HashSet::new();
        let mut errors: Vec<Error> = Vec::new();

        if let Some(local_client) = self.build_local_client()? {
            match self
                .fetch_models_for_scope(OllamaMode::Local, local_client.clone())
                .await
            {
                Ok(models) => {
                    for model in models {
                        let key = format!("local::{}", model.id);
                        if seen.insert(key) {
                            combined.push(model);
                        }
                    }
                }
                Err(err) => errors.push(err),
            }
        }

        if let Some(cloud_client) = self.build_cloud_client()? {
            match self
                .fetch_models_for_scope(OllamaMode::Cloud, cloud_client.clone())
                .await
            {
                Ok(models) => {
                    for model in models {
                        let key = format!("cloud::{}", model.id);
                        if seen.insert(key) {
                            combined.push(model);
                        }
                    }
                }
                Err(err) => errors.push(err),
            }
        }

        if combined.is_empty() {
            if let Some(err) = errors.pop() {
                return Err(err);
            }
        }

        self.annotate_scope_status(&mut combined).await;
        combined.sort_by(|a, b| a.name.to_lowercase().cmp(&b.name.to_lowercase()));
        Ok(combined)
    }

    async fn fetch_models_for_scope(
        &self,
        scope: OllamaMode,
        client: Ollama,
    ) -> Result<Vec<ModelInfo>> {
        let list_result = if matches!(scope, OllamaMode::Local) {
            match timeout(
                Duration::from_millis(LOCAL_PROBE_TIMEOUT_MS),
                client.list_local_models(),
            )
            .await
            {
                Ok(result) => result.map_err(|err| self.map_ollama_error("list models", err, None)),
                Err(_) => Err(Error::Timeout(
                    "Timed out while contacting the local Ollama daemon".to_string(),
                )),
            }
        } else {
            client
                .list_local_models()
                .await
                .map_err(|err| self.map_ollama_error("list models", err, None))
        };

        let models = match list_result {
            Ok(models) => models,
            Err(err) => {
                let message = err.to_string();
                self.mark_scope_failure(scope, message).await;
                if let Some(cached) = self.cached_scope_models(scope).await {
                    return Ok(cached);
                }
                return Err(err);
            }
        };

        let client = self.client.clone();
        let cache = self.model_details_cache.clone();
        let mode = self.mode;
        let fetched = join_all(models.into_iter().map(|local| {
            let client = client.clone();
            let cache = cache.clone();
@@ -329,7 +687,7 @@ impl OllamaProvider {
            let detail = match client.show_model_info(name.clone()).await {
                Ok(info) => {
                    let detailed = OllamaProvider::convert_detailed_model_info(
                        mode,
                        scope,
                        &name,
                        Some(&local),
                        &info,
@@ -347,10 +705,13 @@ impl OllamaProvider {
        }))
        .await;

        Ok(fetched
        let converted: Vec<ModelInfo> = fetched
            .into_iter()
            .map(|(local, detail)| self.convert_model(local, detail))
            .collect())
            .map(|(local, detail)| self.convert_model(scope, local, detail))
            .collect();

        self.update_scope_success(scope, &converted).await;
        Ok(converted)
    }

    fn convert_detailed_model_info(
@@ -430,8 +791,13 @@ impl OllamaProvider {
        info.with_normalised_strings()
    }

    fn convert_model(&self, model: LocalModel, detail: Option<OllamaModelInfo>) -> ModelInfo {
        let scope = match self.mode {
    fn convert_model(
        &self,
        scope: OllamaMode,
        model: LocalModel,
        detail: Option<OllamaModelInfo>,
    ) -> ModelInfo {
        let scope_tag = match scope {
            OllamaMode::Local => "local",
            OllamaMode::Cloud => "cloud",
        };
@@ -453,7 +819,9 @@ impl OllamaProvider {
            push_capability(&mut capabilities, &heuristic);
        }

        let description = build_model_description(scope, detail.as_ref());
        push_capability(&mut capabilities, &format!("scope:{scope_tag}"));

        let description = build_model_description(scope_tag, detail.as_ref());

        ModelInfo {
            id: name.clone(),
@@ -1004,6 +1372,10 @@ fn normalize_base_url(
    Ok(url.to_string().trim_end_matches('/').to_string())
}

fn normalize_cloud_endpoint(input: &str) -> std::result::Result<String, String> {
    normalize_base_url(Some(input), OllamaMode::Cloud)
}

fn build_api_endpoint(base_url: &str, endpoint: &str) -> String {
    let trimmed_base = base_url.trim_end_matches('/');
    let trimmed_endpoint = endpoint.trim_start_matches('/');
@@ -1015,9 +1387,48 @@ fn build_api_endpoint(base_url: &str, endpoint: &str) {
    }
}

fn build_client_for_base(
    base_url: &str,
    timeout: Duration,
    api_key: Option<&str>,
) -> Result<(Ollama, Client)> {
    let url = Url::parse(base_url)
        .map_err(|err| Error::Config(format!("Invalid Ollama base URL '{base_url}': {err}")))?;

    let mut headers = HeaderMap::new();
    if let Some(key) = api_key {
        let value = HeaderValue::from_str(&format!("Bearer {key}"))
            .map_err(|_| Error::Config("OLLAMA API key contains invalid characters".to_string()))?;
        headers.insert(AUTHORIZATION, value);
    }

    let mut client_builder = Client::builder().timeout(timeout);
    if !headers.is_empty() {
        client_builder = client_builder.default_headers(headers.clone());
    }

    let http_client = client_builder.build().map_err(|err| {
        Error::Config(format!(
            "Failed to build HTTP client for '{base_url}': {err}"
        ))
    })?;

    let port = url.port_or_known_default().ok_or_else(|| {
        Error::Config(format!("Unable to determine port for Ollama URL '{}'", url))
    })?;

    let mut ollama_client = Ollama::new_with_client(url.clone(), port, http_client.clone());
    if !headers.is_empty() {
        ollama_client.set_headers(Some(headers));
    }

    Ok((ollama_client, http_client))
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn resolve_api_key_prefers_literal_value() {
@@ -1053,6 +1464,60 @@ mod tests {
        assert!(err.contains("https"));
    }

    #[test]
    fn explicit_local_mode_overrides_api_key() {
        let mut config = ProviderConfig {
            provider_type: "ollama".to_string(),
            base_url: Some("http://localhost:11434".to_string()),
            api_key: Some("secret-key".to_string()),
            extra: HashMap::new(),
        };
        config.extra.insert(
            OLLAMA_MODE_KEY.to_string(),
            Value::String("local".to_string()),
        );

        let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");

        assert_eq!(provider.mode, OllamaMode::Local);
        assert_eq!(provider.base_url, "http://localhost:11434");
    }

    #[test]
    fn auto_mode_prefers_explicit_local_base() {
        let config = ProviderConfig {
            provider_type: "ollama".to_string(),
            base_url: Some("http://localhost:11434".to_string()),
            api_key: Some("secret-key".to_string()),
            extra: HashMap::new(),
        };
        // simulate missing explicit mode; defaults to auto

        let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");

        assert_eq!(provider.mode, OllamaMode::Local);
        assert_eq!(provider.base_url, "http://localhost:11434");
    }

    #[test]
    fn auto_mode_with_api_key_and_no_local_probe_switches_to_cloud() {
        let mut config = ProviderConfig {
            provider_type: "ollama".to_string(),
            base_url: None,
            api_key: Some("secret-key".to_string()),
            extra: HashMap::new(),
        };
        config.extra.insert(
            OLLAMA_MODE_KEY.to_string(),
            Value::String("auto".to_string()),
        );

        let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");

        assert_eq!(provider.mode, OllamaMode::Cloud);
        assert_eq!(provider.base_url, CLOUD_BASE_URL);
    }

    #[test]
    fn build_model_options_merges_parameters() {
        let mut parameters = ChatParameters::default();
@@ -1091,3 +1556,110 @@ mod tests {
        assert!(caps.iter().any(|cap| cap == "vision"));
    }
}

#[cfg(test)]
struct ProbeOverrideGuard;

#[cfg(test)]
impl ProbeOverrideGuard {
    fn set(value: Option<bool>) -> Self {
        set_probe_override(value);
        ProbeOverrideGuard
    }
}

#[cfg(test)]
impl Drop for ProbeOverrideGuard {
    fn drop(&mut self) {
        set_probe_override(None);
    }
}

#[test]
fn auto_mode_with_api_key_and_successful_probe_prefers_local() {
    let _guard = ProbeOverrideGuard::set(Some(true));

    let mut config = ProviderConfig {
        provider_type: "ollama".to_string(),
        base_url: None,
        api_key: Some("secret-key".to_string()),
        extra: HashMap::new(),
    };
    config.extra.insert(
        OLLAMA_MODE_KEY.to_string(),
        Value::String("auto".to_string()),
    );

    assert!(probe_default_local_daemon(Duration::from_millis(1)));

    let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");

    assert_eq!(provider.mode, OllamaMode::Local);
    assert_eq!(provider.base_url, "http://localhost:11434");
}

#[test]
fn auto_mode_with_api_key_and_failed_probe_prefers_cloud() {
    let _guard = ProbeOverrideGuard::set(Some(false));

    let mut config = ProviderConfig {
        provider_type: "ollama".to_string(),
        base_url: None,
        api_key: Some("secret-key".to_string()),
        extra: HashMap::new(),
    };
    config.extra.insert(
        OLLAMA_MODE_KEY.to_string(),
        Value::String("auto".to_string()),
    );

    let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");

    assert_eq!(provider.mode, OllamaMode::Cloud);
    assert_eq!(provider.base_url, CLOUD_BASE_URL);
}

#[test]
fn annotate_scope_status_adds_capabilities_for_unavailable_scopes() {
    let config = ProviderConfig {
        provider_type: "ollama".to_string(),
        base_url: Some("http://localhost:11434".to_string()),
        api_key: None,
        extra: HashMap::new(),
    };

    let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");

    let mut models = vec![ModelInfo {
        id: "llama3".to_string(),
        name: "Llama 3".to_string(),
        description: None,
        provider: "ollama".to_string(),
        context_window: None,
        capabilities: vec!["scope:local".to_string()],
        supports_tools: false,
    }];

    block_on(async {
        {
            let mut cache = provider.scope_cache.write().await;
            let entry = cache.entry(OllamaMode::Cloud).or_default();
            entry.availability = ScopeAvailability::Unavailable;
            entry.last_error = Some("Cloud endpoint unreachable".to_string());
        }

        provider.annotate_scope_status(&mut models).await;
    });

    let capabilities = &models[0].capabilities;
    assert!(
        capabilities
            .iter()
            .any(|cap| cap == "scope-status:cloud:unavailable")
    );
    assert!(
        capabilities
            .iter()
            .any(|cap| cap.starts_with("scope-status-message:cloud:"))
    );
}