- Switch to the `ollama-rs` crate for chat, model listing, and streaming.
- Remove custom request building, authentication handling, and debug logging.
- Drop unsupported tool conversion; tool descriptors are now ignored with a warning.
- Refactor model fetching to use local model info and optional cloud details.
- Consolidate error mapping via `map_ollama_error`.
- Update the health check to use the new HTTP client.
- Delete the obsolete `provider_interface.rs` test, as the provider interface has changed.
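For reviewers, a rough sketch of how a caller drives the refactored provider; the model name, the `demo` wrapper, and the import notes are illustrative assumptions rather than part of this change:

```rust
// Reviewer sketch only — not compiled against this crate. Assumes the crate's
// `Result` alias, `OllamaProvider`, and the shared chat types are in scope
// (e.g. via the `provider::LLMProvider` trait and `types::*`).
use std::{collections::HashMap, time::SystemTime};
use uuid::Uuid;

async fn demo() -> Result<()> {
    // Local mode: point at a running Ollama daemon.
    let provider = OllamaProvider::new("http://localhost:11434")?;
    provider.health_check().await?; // GET /api/version via the new HTTP client

    let request = ChatRequest {
        model: "llama3.2".to_string(), // illustrative model name
        messages: vec![Message {
            id: Uuid::new_v4(),
            role: Role::User,
            content: "Hello!".to_string(),
            metadata: HashMap::new(),
            timestamp: SystemTime::now(),
            tool_calls: None,
        }],
        parameters: ChatParameters::default(),
        tools: None, // tool descriptors are now ignored with a warning
    };

    let response = provider.chat(request).await?;
    println!("{}", response.message.content);
    Ok(())
}
```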
842 lines · 26 KiB · Rust
//! Ollama provider built on top of the `ollama-rs` crate.

use std::{
    collections::HashMap,
    env,
    pin::Pin,
    time::{Duration, SystemTime},
};

use anyhow::anyhow;
use futures::{future::join_all, future::BoxFuture, Stream, StreamExt};
use log::{debug, warn};
use ollama_rs::{
    error::OllamaError,
    generation::chat::{
        request::ChatMessageRequest as OllamaChatRequest, ChatMessage as OllamaMessage,
        ChatMessageResponse as OllamaChatResponse, MessageRole as OllamaRole,
    },
    generation::tools::{ToolCall as OllamaToolCall, ToolCallFunction as OllamaToolCallFunction},
    headers::{HeaderMap, HeaderValue, AUTHORIZATION},
    models::{LocalModel, ModelInfo as OllamaModelInfo, ModelOptions},
    Ollama,
};
use reqwest::{Client, StatusCode, Url};
use serde_json::{json, Map as JsonMap, Value};
use uuid::Uuid;

use crate::{
    config::GeneralSettings,
    mcp::McpToolDescriptor,
    model::ModelManager,
    provider::{LLMProvider, ProviderConfig},
    types::{
        ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role, TokenUsage, ToolCall,
    },
    Error, Result,
};

const DEFAULT_TIMEOUT_SECS: u64 = 120;
const DEFAULT_MODEL_CACHE_TTL_SECS: u64 = 60;
const CLOUD_BASE_URL: &str = "https://ollama.com";

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OllamaMode {
    Local,
    Cloud,
}

impl OllamaMode {
    fn default_base_url(self) -> &'static str {
        match self {
            Self::Local => "http://localhost:11434",
            Self::Cloud => CLOUD_BASE_URL,
        }
    }
}

#[derive(Debug)]
struct OllamaOptions {
    mode: OllamaMode,
    base_url: String,
    request_timeout: Duration,
    model_cache_ttl: Duration,
    api_key: Option<String>,
}

impl OllamaOptions {
    fn new(mode: OllamaMode, base_url: impl Into<String>) -> Self {
        Self {
            mode,
            base_url: base_url.into(),
            request_timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
            model_cache_ttl: Duration::from_secs(DEFAULT_MODEL_CACHE_TTL_SECS),
            api_key: None,
        }
    }

    fn with_general(mut self, general: &GeneralSettings) -> Self {
        self.model_cache_ttl = general.model_cache_ttl();
        self
    }
}

/// Ollama provider implementation backed by `ollama-rs`.
#[derive(Debug)]
pub struct OllamaProvider {
    mode: OllamaMode,
    client: Ollama,
    http_client: Client,
    base_url: String,
    model_manager: ModelManager,
}

impl OllamaProvider {
    /// Create a provider targeting an explicit base URL (local usage).
    pub fn new(base_url: impl Into<String>) -> Result<Self> {
        let input = base_url.into();
        let normalized =
            normalize_base_url(Some(&input), OllamaMode::Local).map_err(Error::Config)?;
        Self::with_options(OllamaOptions::new(OllamaMode::Local, normalized))
    }

    /// Construct a provider from configuration settings.
    pub fn from_config(config: &ProviderConfig, general: Option<&GeneralSettings>) -> Result<Self> {
        let mut api_key = resolve_api_key(config.api_key.clone())
            .or_else(|| env_var_non_empty("OLLAMA_API_KEY"))
            .or_else(|| env_var_non_empty("OLLAMA_CLOUD_API_KEY"));

        let mode = if api_key.is_some() {
            OllamaMode::Cloud
        } else {
            OllamaMode::Local
        };

        let base_candidate = if mode == OllamaMode::Cloud {
            Some(CLOUD_BASE_URL)
        } else {
            config.base_url.as_deref()
        };

        let normalized_base_url =
            normalize_base_url(base_candidate, mode).map_err(Error::Config)?;

        let mut options = OllamaOptions::new(mode, normalized_base_url);

        if let Some(timeout) = config
            .extra
            .get("timeout_secs")
            .and_then(|value| value.as_u64())
        {
            options.request_timeout = Duration::from_secs(timeout.max(5));
        }

        if let Some(cache_ttl) = config
            .extra
            .get("model_cache_ttl_secs")
            .and_then(|value| value.as_u64())
        {
            options.model_cache_ttl = Duration::from_secs(cache_ttl.max(5));
        }

        options.api_key = api_key.take();

        if let Some(general) = general {
            options = options.with_general(general);
        }

        Self::with_options(options)
    }

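    /// Shared constructor: parses the base URL, builds the HTTP client (attaching a
    /// bearer-token `Authorization` header when an API key is present), and wires up
    /// the `ollama-rs` client and model cache.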
    fn with_options(options: OllamaOptions) -> Result<Self> {
        let OllamaOptions {
            mode,
            base_url,
            request_timeout,
            model_cache_ttl,
            api_key,
        } = options;

        let url = Url::parse(&base_url)
            .map_err(|err| Error::Config(format!("Invalid Ollama base URL '{base_url}': {err}")))?;

        let mut headers = HeaderMap::new();
        if let Some(ref key) = api_key {
            let value = HeaderValue::from_str(&format!("Bearer {key}")).map_err(|_| {
                Error::Config("Ollama API key contains invalid characters".to_string())
            })?;
            headers.insert(AUTHORIZATION, value);
        }

        let mut client_builder = Client::builder().timeout(request_timeout);
        if !headers.is_empty() {
            client_builder = client_builder.default_headers(headers.clone());
        }

        let http_client = client_builder
            .build()
            .map_err(|err| Error::Config(format!("Failed to build HTTP client: {err}")))?;

        let port = url.port_or_known_default().ok_or_else(|| {
            Error::Config(format!("Unable to determine port for Ollama URL '{}'", url))
        })?;

        let mut ollama_client = Ollama::new_with_client(url.clone(), port, http_client.clone());
        if !headers.is_empty() {
            ollama_client.set_headers(Some(headers.clone()));
        }

        Ok(Self {
            mode,
            client: ollama_client,
            http_client,
            base_url: base_url.trim_end_matches('/').to_string(),
            model_manager: ModelManager::new(model_cache_ttl),
        })
    }

    fn api_url(&self, endpoint: &str) -> String {
        build_api_endpoint(&self.base_url, endpoint)
    }

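    /// Convert a crate-level chat request into an `ollama-rs` request, warning when a
    /// cloud model lacks the `-cloud` suffix and noting that tool descriptors are ignored.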
    fn prepare_chat_request(
        &self,
        model: String,
        messages: Vec<Message>,
        parameters: ChatParameters,
        tools: Option<Vec<McpToolDescriptor>>,
    ) -> Result<(String, OllamaChatRequest)> {
        if self.mode == OllamaMode::Cloud && !model.contains("-cloud") {
            warn!(
                "Model '{}' does not use the '-cloud' suffix. Cloud-only models may fail to load.",
                model
            );
        }

        if let Some(descriptors) = &tools {
            if !descriptors.is_empty() {
                debug!(
                    "Ignoring {} MCP tool descriptors for Ollama request (tool calling unsupported)",
                    descriptors.len()
                );
            }
        }

        let converted_messages = messages.into_iter().map(convert_message).collect();
        let mut request = OllamaChatRequest::new(model.clone(), converted_messages);

        if let Some(options) = build_model_options(&parameters)? {
            request.options = Some(options);
        }

        Ok((model, request))
    }

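    /// List the locally available models and enrich each entry with `show_model_info`
    /// details, fetched concurrently; detail lookups that fail are logged and skipped.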
    async fn fetch_models(&self) -> Result<Vec<ModelInfo>> {
        let models = self
            .client
            .list_local_models()
            .await
            .map_err(|err| self.map_ollama_error("list models", err, None))?;

        let client = self.client.clone();
        let fetched = join_all(models.into_iter().map(|local| {
            let client = client.clone();
            async move {
                let name = local.name.clone();
                let detail = match client.show_model_info(name.clone()).await {
                    Ok(info) => Some(info),
                    Err(err) => {
                        debug!("Failed to fetch Ollama model info for '{name}': {err}");
                        None
                    }
                };
                (local, detail)
            }
        }))
        .await;

        Ok(fetched
            .into_iter()
            .map(|(local, detail)| self.convert_model(local, detail))
            .collect())
    }

    fn convert_model(&self, model: LocalModel, detail: Option<OllamaModelInfo>) -> ModelInfo {
        let scope = match self.mode {
            OllamaMode::Local => "local",
            OllamaMode::Cloud => "cloud",
        };

        let name = model.name;
        let mut capabilities: Vec<String> = detail
            .as_ref()
            .map(|info| {
                info.capabilities
                    .iter()
                    .map(|cap| cap.to_ascii_lowercase())
                    .collect()
            })
            .unwrap_or_default();

        push_capability(&mut capabilities, "chat");

        for heuristic in heuristic_capabilities(&name) {
            push_capability(&mut capabilities, &heuristic);
        }

        let description = build_model_description(scope, detail.as_ref());

        ModelInfo {
            id: name.clone(),
            name,
            description: Some(description),
            provider: "ollama".to_string(),
            context_window: None,
            capabilities,
            supports_tools: false,
        }
    }

    fn convert_ollama_response(response: OllamaChatResponse, streaming: bool) -> ChatResponse {
        let usage = response.final_data.as_ref().map(|data| {
            let prompt = clamp_to_u32(data.prompt_eval_count);
            let completion = clamp_to_u32(data.eval_count);
            TokenUsage {
                prompt_tokens: prompt,
                completion_tokens: completion,
                total_tokens: prompt.saturating_add(completion),
            }
        });

        ChatResponse {
            message: convert_ollama_message(response.message),
            usage,
            is_streaming: streaming,
            is_final: if streaming { response.done } else { true },
        }
    }

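    /// Translate an `OllamaError` into the crate's error type, delegating HTTP status
    /// failures to `map_http_failure` and distinguishing timeouts from other network errors.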
    fn map_ollama_error(&self, action: &str, err: OllamaError, model: Option<&str>) -> Error {
        match err {
            OllamaError::ReqwestError(request_err) => {
                if let Some(status) = request_err.status() {
                    self.map_http_failure(action, status, request_err.to_string(), model)
                } else if request_err.is_timeout() {
                    Error::Timeout(format!("Ollama {action} timed out: {request_err}"))
                } else {
                    Error::Network(format!("Ollama {action} request failed: {request_err}"))
                }
            }
            OllamaError::InternalError(internal) => Error::Provider(anyhow!(internal.message)),
            OllamaError::Other(message) => Error::Provider(anyhow!(message)),
            OllamaError::JsonError(err) => Error::Serialization(err),
            OllamaError::ToolCallError(err) => Error::Provider(anyhow!(err)),
        }
    }

    fn map_http_failure(
        &self,
        action: &str,
        status: StatusCode,
        detail: String,
        model: Option<&str>,
    ) -> Error {
        match status {
            StatusCode::NOT_FOUND => {
                if let Some(model) = model {
                    Error::InvalidInput(format!(
                        "Model '{model}' was not found at {}. Verify the name or pull it with `ollama pull`.",
                        self.base_url
                    ))
                } else {
                    Error::InvalidInput(format!(
                        "{action} returned 404 from {}: {detail}",
                        self.base_url
                    ))
                }
            }
            StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => Error::Auth(format!(
                "Ollama rejected the request ({status}): {detail}. Check your API key and account permissions."
            )),
            StatusCode::BAD_REQUEST => Error::InvalidInput(format!(
                "{action} rejected by Ollama ({status}): {detail}"
            )),
            StatusCode::SERVICE_UNAVAILABLE | StatusCode::GATEWAY_TIMEOUT => Error::Timeout(
                format!(
                    "Ollama {action} timed out ({status}). The model may still be loading."
                ),
            ),
            _ => Error::Network(format!(
                "Ollama {action} failed ({status}): {detail}"
            )),
        }
    }
}

impl LLMProvider for OllamaProvider {
    type Stream = Pin<Box<dyn Stream<Item = Result<ChatResponse>> + Send>>;
    type ListModelsFuture<'a>
        = BoxFuture<'a, Result<Vec<ModelInfo>>>
    where
        Self: 'a;
    type ChatFuture<'a>
        = BoxFuture<'a, Result<ChatResponse>>
    where
        Self: 'a;
    type ChatStreamFuture<'a>
        = BoxFuture<'a, Result<Self::Stream>>
    where
        Self: 'a;
    type HealthCheckFuture<'a>
        = BoxFuture<'a, Result<()>>
    where
        Self: 'a;

    fn name(&self) -> &str {
        "ollama"
    }

    fn list_models(&self) -> Self::ListModelsFuture<'_> {
        Box::pin(async move {
            self.model_manager
                .get_or_refresh(false, || async { self.fetch_models().await })
                .await
        })
    }

    fn chat(&self, request: ChatRequest) -> Self::ChatFuture<'_> {
        Box::pin(async move {
            let ChatRequest {
                model,
                messages,
                parameters,
                tools,
            } = request;

            let (model_id, ollama_request) =
                self.prepare_chat_request(model, messages, parameters, tools)?;

            let response = self
                .client
                .send_chat_messages(ollama_request)
                .await
                .map_err(|err| self.map_ollama_error("chat", err, Some(&model_id)))?;

            Ok(Self::convert_ollama_response(response, false))
        })
    }

    fn chat_stream(&self, request: ChatRequest) -> Self::ChatStreamFuture<'_> {
        Box::pin(async move {
            let ChatRequest {
                model,
                messages,
                parameters,
                tools,
            } = request;

            let (model_id, ollama_request) =
                self.prepare_chat_request(model, messages, parameters, tools)?;

            let stream = self
                .client
                .send_chat_messages_stream(ollama_request)
                .await
                .map_err(|err| self.map_ollama_error("chat_stream", err, Some(&model_id)))?;

            let mapped = stream.map(|item| match item {
                Ok(chunk) => Ok(Self::convert_ollama_response(chunk, true)),
                Err(_) => Err(Error::Provider(anyhow!(
                    "Ollama returned a malformed streaming chunk"
                ))),
            });

            Ok(Box::pin(mapped) as Self::Stream)
        })
    }

    fn health_check(&self) -> Self::HealthCheckFuture<'_> {
        Box::pin(async move {
            let url = self.api_url("version");
            let response = self
                .http_client
                .get(&url)
                .send()
                .await
                .map_err(|err| map_reqwest_error("health check", err))?;

            if response.status().is_success() {
                return Ok(());
            }

            let status = response.status();
            let detail = response.text().await.unwrap_or_else(|err| err.to_string());
            Err(self.map_http_failure("health check", status, detail, None))
        })
    }

    fn config_schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "base_url": {
                    "type": "string",
                    "description": "Base URL for the Ollama API (ignored when api_key is provided)",
                    "default": self.mode.default_base_url()
                },
                "timeout_secs": {
                    "type": "integer",
                    "description": "HTTP request timeout in seconds",
                    "minimum": 5,
                    "default": DEFAULT_TIMEOUT_SECS
                },
                "model_cache_ttl_secs": {
                    "type": "integer",
                    "description": "Seconds to cache model listings",
                    "minimum": 5,
                    "default": DEFAULT_MODEL_CACHE_TTL_SECS
                }
            }
        })
    }
}

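/// Merge `ChatParameters` (extras, temperature, max_tokens) into Ollama `ModelOptions`,
/// returning `None` when no option is set.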
fn build_model_options(parameters: &ChatParameters) -> Result<Option<ModelOptions>> {
    let mut options = JsonMap::new();

    for (key, value) in &parameters.extra {
        options.insert(key.clone(), value.clone());
    }

    if let Some(temperature) = parameters.temperature {
        options.insert("temperature".to_string(), json!(temperature));
    }

    if let Some(max_tokens) = parameters.max_tokens {
        let capped = i32::try_from(max_tokens).unwrap_or(i32::MAX);
        options.insert("num_predict".to_string(), json!(capped));
    }

    if options.is_empty() {
        return Ok(None);
    }

    serde_json::from_value(Value::Object(options))
        .map(Some)
        .map_err(|err| Error::Config(format!("Invalid Ollama options: {err}")))
}

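/// Convert a crate-level `Message` into an `ollama-rs` chat message, carrying over
/// tool calls and any "thinking" metadata.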
fn convert_message(message: Message) -> OllamaMessage {
    let Message {
        role,
        content,
        metadata,
        tool_calls,
        ..
    } = message;

    let role = match role {
        Role::User => OllamaRole::User,
        Role::Assistant => OllamaRole::Assistant,
        Role::System => OllamaRole::System,
        Role::Tool => OllamaRole::Tool,
    };

    let tool_calls = tool_calls
        .unwrap_or_default()
        .into_iter()
        .map(|tool_call| OllamaToolCall {
            function: OllamaToolCallFunction {
                name: tool_call.name,
                arguments: tool_call.arguments,
            },
        })
        .collect();

    let thinking = metadata
        .get("thinking")
        .and_then(|value| value.as_str().map(|s| s.to_owned()));

    OllamaMessage {
        role,
        content,
        tool_calls,
        images: None,
        thinking,
    }
}

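/// Convert an `ollama-rs` chat message back into a crate-level `Message`, generating
/// synthetic tool-call IDs and preserving "thinking" output in the metadata map.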
fn convert_ollama_message(message: OllamaMessage) -> Message {
    let role = match message.role {
        OllamaRole::Assistant => Role::Assistant,
        OllamaRole::System => Role::System,
        OllamaRole::Tool => Role::Tool,
        OllamaRole::User => Role::User,
    };

    let tool_calls = if message.tool_calls.is_empty() {
        None
    } else {
        Some(
            message
                .tool_calls
                .into_iter()
                .enumerate()
                .map(|(idx, tool_call)| ToolCall {
                    id: format!("tool-call-{idx}"),
                    name: tool_call.function.name,
                    arguments: tool_call.function.arguments,
                })
                .collect::<Vec<_>>(),
        )
    };

    let mut metadata = HashMap::new();
    if let Some(thinking) = message.thinking {
        metadata.insert("thinking".to_string(), Value::String(thinking));
    }

    Message {
        id: Uuid::new_v4(),
        role,
        content: message.content,
        metadata,
        timestamp: SystemTime::now(),
        tool_calls,
    }
}

fn clamp_to_u32(value: u64) -> u32 {
    u32::try_from(value).unwrap_or(u32::MAX)
}

fn push_capability(capabilities: &mut Vec<String>, capability: &str) {
    let candidate = capability.to_ascii_lowercase();
    if !capabilities
        .iter()
        .any(|existing| existing.eq_ignore_ascii_case(&candidate))
    {
        capabilities.push(candidate);
    }
}

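/// Guess additional capabilities (vision, thinking, audio) from the model name when the
/// API does not report them explicitly.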
fn heuristic_capabilities(name: &str) -> Vec<String> {
    let lowercase = name.to_ascii_lowercase();
    let mut detected = Vec::new();

    if lowercase.contains("vision")
        || lowercase.contains("multimodal")
        || lowercase.contains("image")
    {
        detected.push("vision".to_string());
    }

    if lowercase.contains("think")
        || lowercase.contains("reason")
        || lowercase.contains("deepseek-r1")
        || lowercase.contains("r1")
    {
        detected.push("thinking".to_string());
    }

    if lowercase.contains("audio") || lowercase.contains("speech") || lowercase.contains("voice") {
        detected.push("audio".to_string());
    }

    detected
}

fn build_model_description(scope: &str, detail: Option<&OllamaModelInfo>) -> String {
    if let Some(info) = detail {
        let mut parts = Vec::new();

        if let Some(family) = info
            .model_info
            .get("family")
            .and_then(|value| value.as_str())
        {
            parts.push(family.to_string());
        }

        if let Some(parameter_size) = info
            .model_info
            .get("parameter_size")
            .and_then(|value| value.as_str())
        {
            parts.push(parameter_size.to_string());
        }

        if let Some(variant) = info
            .model_info
            .get("variant")
            .and_then(|value| value.as_str())
        {
            parts.push(variant.to_string());
        }

        if !parts.is_empty() {
            return format!("Ollama ({scope}) – {}", parts.join(" · "));
        }
    }

    format!("Ollama ({scope}) model")
}

fn env_var_non_empty(name: &str) -> Option<String> {
    env::var(name)
        .ok()
        .map(|value| value.trim().to_string())
        .filter(|value| !value.is_empty())
}

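/// Resolve a configured API key, expanding `$VAR` / `${VAR}` references from the
/// environment and discarding empty values.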
fn resolve_api_key(configured: Option<String>) -> Option<String> {
    let raw = configured?.trim().to_string();
    if raw.is_empty() {
        return None;
    }

    if let Some(variable) = raw
        .strip_prefix("${")
        .and_then(|value| value.strip_suffix('}'))
        .or_else(|| raw.strip_prefix('$'))
    {
        let var_name = variable.trim();
        if var_name.is_empty() {
            return None;
        }
        return env_var_non_empty(var_name);
    }

    Some(raw)
}

fn map_reqwest_error(action: &str, err: reqwest::Error) -> Error {
    if err.is_timeout() {
        Error::Timeout(format!("Ollama {action} request timed out: {err}"))
    } else {
        Error::Network(format!("Ollama {action} request failed: {err}"))
    }
}

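/// Validate and canonicalize a base URL: fall back to the mode's default, force a scheme,
/// strip a trailing `/api` segment, drop query/fragment, and reject extra path segments
/// (plus non-HTTPS URLs in cloud mode).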
fn normalize_base_url(
    input: Option<&str>,
    mode_hint: OllamaMode,
) -> std::result::Result<String, String> {
    let mut candidate = input
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .map(|value| value.to_string())
        .unwrap_or_else(|| mode_hint.default_base_url().to_string());

    if !candidate.starts_with("http://") && !candidate.starts_with("https://") {
        candidate = format!("https://{candidate}");
    }

    let mut url =
        Url::parse(&candidate).map_err(|err| format!("Invalid Ollama URL '{candidate}': {err}"))?;

    if url.cannot_be_a_base() {
        return Err(format!("URL '{candidate}' cannot be used as a base URL"));
    }

    if mode_hint == OllamaMode::Cloud && url.scheme() != "https" {
        return Err("Ollama Cloud requires https:// base URLs".to_string());
    }

    let path = url.path().trim_end_matches('/');
    if path == "/api" {
        url.set_path("/");
    } else if !path.is_empty() && path != "/" {
        return Err("Ollama base URLs must not include additional path segments".to_string());
    }

    url.set_query(None);
    url.set_fragment(None);

    Ok(url.to_string().trim_end_matches('/').to_string())
}

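/// Join a base URL and an endpoint under `/api`, avoiding a doubled `/api` segment when
/// the base already ends with it.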
fn build_api_endpoint(base_url: &str, endpoint: &str) -> String {
    let trimmed_base = base_url.trim_end_matches('/');
    let trimmed_endpoint = endpoint.trim_start_matches('/');

    if trimmed_base.ends_with("/api") {
        format!("{trimmed_base}/{trimmed_endpoint}")
    } else {
        format!("{trimmed_base}/api/{trimmed_endpoint}")
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resolve_api_key_prefers_literal_value() {
        assert_eq!(
            resolve_api_key(Some("direct-key".into())),
            Some("direct-key".into())
        );
    }

    #[test]
    fn resolve_api_key_expands_env_var() {
        std::env::set_var("OLLAMA_TEST_KEY", "secret");
        assert_eq!(
            resolve_api_key(Some("${OLLAMA_TEST_KEY}".into())),
            Some("secret".into())
        );
        std::env::remove_var("OLLAMA_TEST_KEY");
    }

    #[test]
    fn normalize_base_url_removes_api_path() {
        let url = normalize_base_url(Some("https://ollama.com/api"), OllamaMode::Cloud).unwrap();
        assert_eq!(url, "https://ollama.com");
    }

    #[test]
    fn normalize_base_url_rejects_cloud_without_https() {
        let err = normalize_base_url(Some("http://ollama.com"), OllamaMode::Cloud).unwrap_err();
        assert!(err.contains("https"));
    }

    #[test]
    fn build_model_options_merges_parameters() {
        let mut parameters = ChatParameters::default();
        parameters.temperature = Some(0.3);
        parameters.max_tokens = Some(128);
        parameters
            .extra
            .insert("num_ctx".into(), Value::from(4096_u64));

        let options = build_model_options(&parameters)
            .expect("options built")
            .expect("options present");
        let serialized = serde_json::to_value(&options).expect("serialize options");
        let temperature = serialized["temperature"]
            .as_f64()
            .expect("temperature present");
        assert!((temperature - 0.3).abs() < 1e-6);
        assert_eq!(serialized["num_predict"], json!(128));
        assert_eq!(serialized["num_ctx"], json!(4096));
    }

    #[test]
    fn heuristic_capabilities_detects_thinking_models() {
        let caps = heuristic_capabilities("deepseek-r1");
        assert!(caps.iter().any(|cap| cap == "thinking"));
    }

    #[test]
    fn push_capability_avoids_duplicates() {
        let mut caps = vec!["chat".to_string()];
        push_capability(&mut caps, "Chat");
        push_capability(&mut caps, "Vision");
        push_capability(&mut caps, "vision");

        assert_eq!(caps.len(), 2);
        assert!(caps.iter().any(|cap| cap == "vision"));
    }
}