owlen/crates/owlen-core/src/providers/ollama.rs

//! Ollama provider built on top of the `ollama-rs` crate.
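//!
//! A minimal usage sketch, hedged: the re-export path
//! `owlen_core::providers::ollama` and the example model handling are assumptions,
//! and the block is marked `ignore` rather than compiled as a doc test.
//!
//! ```ignore
//! use owlen_core::llm::LlmProvider;
//! use owlen_core::providers::ollama::OllamaProvider;
//!
//! async fn demo() -> owlen_core::Result<()> {
//!     // Point the provider at a locally running Ollama daemon.
//!     let provider = OllamaProvider::new("http://localhost:11434")?;
//!     // Model listings are cached according to the configured cache TTL.
//!     for model in provider.list_models().await? {
//!         println!("{} - {}", model.id, model.description.unwrap_or_default());
//!     }
//!     Ok(())
//! }
//! ```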
use std::{
collections::HashMap,
env,
pin::Pin,
time::{Duration, SystemTime},
};
use anyhow::anyhow;
use futures::{Stream, StreamExt, future::BoxFuture, future::join_all};
use log::{debug, warn};
use ollama_rs::{
Ollama,
error::OllamaError,
generation::chat::{
ChatMessage as OllamaMessage, ChatMessageResponse as OllamaChatResponse,
MessageRole as OllamaRole, request::ChatMessageRequest as OllamaChatRequest,
},
generation::tools::{ToolCall as OllamaToolCall, ToolCallFunction as OllamaToolCallFunction},
headers::{AUTHORIZATION, HeaderMap, HeaderValue},
models::{LocalModel, ModelInfo as OllamaModelInfo, ModelOptions},
};
use reqwest::{Client, StatusCode, Url};
use serde_json::{Map as JsonMap, Value, json};
use uuid::Uuid;
use crate::{
Error, Result,
config::GeneralSettings,
llm::{LlmProvider, ProviderConfig},
mcp::McpToolDescriptor,
model::{DetailedModelInfo, ModelDetailsCache, ModelManager},
types::{
ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role, TokenUsage, ToolCall,
},
};
const DEFAULT_TIMEOUT_SECS: u64 = 120;
const DEFAULT_MODEL_CACHE_TTL_SECS: u64 = 60;
const CLOUD_BASE_URL: &str = "https://ollama.com";
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OllamaMode {
Local,
Cloud,
}
impl OllamaMode {
fn default_base_url(self) -> &'static str {
match self {
Self::Local => "http://localhost:11434",
Self::Cloud => CLOUD_BASE_URL,
}
}
}
#[derive(Debug)]
struct OllamaOptions {
mode: OllamaMode,
base_url: String,
request_timeout: Duration,
model_cache_ttl: Duration,
api_key: Option<String>,
}
impl OllamaOptions {
fn new(mode: OllamaMode, base_url: impl Into<String>) -> Self {
Self {
mode,
base_url: base_url.into(),
request_timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
model_cache_ttl: Duration::from_secs(DEFAULT_MODEL_CACHE_TTL_SECS),
api_key: None,
}
}
fn with_general(mut self, general: &GeneralSettings) -> Self {
self.model_cache_ttl = general.model_cache_ttl();
self
}
}
/// Ollama provider implementation backed by `ollama-rs`.
#[derive(Debug)]
pub struct OllamaProvider {
mode: OllamaMode,
client: Ollama,
http_client: Client,
base_url: String,
model_manager: ModelManager,
model_details_cache: ModelDetailsCache,
}
impl OllamaProvider {
/// Create a provider targeting an explicit base URL (local usage).
pub fn new(base_url: impl Into<String>) -> Result<Self> {
let input = base_url.into();
let normalized =
normalize_base_url(Some(&input), OllamaMode::Local).map_err(Error::Config)?;
Self::with_options(OllamaOptions::new(OllamaMode::Local, normalized))
}
/// Construct a provider from configuration settings.
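///
/// The API key may be given literally, as `${VAR}`/`$VAR` to expand an environment
/// variable, or omitted to fall back to `OLLAMA_API_KEY` / `OLLAMA_CLOUD_API_KEY`.
/// If a key resolves, the provider runs in cloud mode against `https://ollama.com`
/// and any configured `base_url` is ignored; otherwise it stays local. Two optional
/// `extra` keys are honoured: `timeout_secs` and `model_cache_ttl_secs` (both
/// clamped to a minimum of 5 seconds).
///
/// A hedged sketch of the relevant fields (how `ProviderConfig` is normally
/// constructed, including whether it implements `Default`, is assumed here):
///
/// ```ignore
/// let mut config = ProviderConfig::default();
/// config.api_key = Some("${OLLAMA_API_KEY}".into());
/// config.extra.insert("timeout_secs".into(), serde_json::json!(60));
/// config.extra.insert("model_cache_ttl_secs".into(), serde_json::json!(120));
/// let provider = OllamaProvider::from_config(&config, None)?;
/// ```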
pub fn from_config(config: &ProviderConfig, general: Option<&GeneralSettings>) -> Result<Self> {
let mut api_key = resolve_api_key(config.api_key.clone())
.or_else(|| env_var_non_empty("OLLAMA_API_KEY"))
.or_else(|| env_var_non_empty("OLLAMA_CLOUD_API_KEY"));
let mode = if api_key.is_some() {
OllamaMode::Cloud
} else {
OllamaMode::Local
};
let base_candidate = if mode == OllamaMode::Cloud {
Some(CLOUD_BASE_URL)
} else {
config.base_url.as_deref()
};
let normalized_base_url =
normalize_base_url(base_candidate, mode).map_err(Error::Config)?;
let mut options = OllamaOptions::new(mode, normalized_base_url);
if let Some(timeout) = config
.extra
.get("timeout_secs")
.and_then(|value| value.as_u64())
{
options.request_timeout = Duration::from_secs(timeout.max(5));
}
if let Some(cache_ttl) = config
.extra
.get("model_cache_ttl_secs")
.and_then(|value| value.as_u64())
{
options.model_cache_ttl = Duration::from_secs(cache_ttl.max(5));
}
options.api_key = api_key.take();
if let Some(general) = general {
options = options.with_general(general);
}
Self::with_options(options)
}
fn with_options(options: OllamaOptions) -> Result<Self> {
let OllamaOptions {
mode,
base_url,
request_timeout,
model_cache_ttl,
api_key,
} = options;
let url = Url::parse(&base_url)
.map_err(|err| Error::Config(format!("Invalid Ollama base URL '{base_url}': {err}")))?;
let mut headers = HeaderMap::new();
if let Some(ref key) = api_key {
let value = HeaderValue::from_str(&format!("Bearer {key}")).map_err(|_| {
Error::Config("OLLAMA API key contains invalid characters".to_string())
})?;
headers.insert(AUTHORIZATION, value);
}
let mut client_builder = Client::builder().timeout(request_timeout);
if !headers.is_empty() {
client_builder = client_builder.default_headers(headers.clone());
}
let http_client = client_builder
.build()
.map_err(|err| Error::Config(format!("Failed to build HTTP client: {err}")))?;
let port = url.port_or_known_default().ok_or_else(|| {
Error::Config(format!("Unable to determine port for Ollama URL '{}'", url))
})?;
let mut ollama_client = Ollama::new_with_client(url.clone(), port, http_client.clone());
if !headers.is_empty() {
ollama_client.set_headers(Some(headers.clone()));
}
Ok(Self {
mode,
client: ollama_client,
http_client,
base_url: base_url.trim_end_matches('/').to_string(),
model_manager: ModelManager::new(model_cache_ttl),
model_details_cache: ModelDetailsCache::new(model_cache_ttl),
})
}
fn api_url(&self, endpoint: &str) -> String {
build_api_endpoint(&self.base_url, endpoint)
}
/// Attempt to resolve detailed model information for the given model, using the local cache when possible.
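///
/// A hedged sketch (`ignore`d; the model name is illustrative):
///
/// ```ignore
/// let info = provider.get_model_info("llama3.1:8b").await?;
/// println!(
///     "arch={:?} ctx={:?} quant={:?}",
///     info.architecture, info.context_length, info.quantization
/// );
/// ```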
pub async fn get_model_info(&self, model_name: &str) -> Result<DetailedModelInfo> {
if let Some(info) = self.model_details_cache.get(model_name).await {
return Ok(info);
}
self.fetch_and_cache_model_info(model_name, None).await
}
/// Force-refresh model information for the specified model.
pub async fn refresh_model_info(&self, model_name: &str) -> Result<DetailedModelInfo> {
self.model_details_cache.invalidate(model_name).await;
self.fetch_and_cache_model_info(model_name, None).await
}
/// Retrieve detailed information for all locally available models.
pub async fn get_all_models_info(&self) -> Result<Vec<DetailedModelInfo>> {
let models = self
.client
.list_local_models()
.await
.map_err(|err| self.map_ollama_error("list models", err, None))?;
let mut details = Vec::with_capacity(models.len());
for local in &models {
match self
.fetch_and_cache_model_info(&local.name, Some(local))
.await
{
Ok(info) => details.push(info),
Err(err) => warn!("Failed to gather model info for '{}': {}", local.name, err),
}
}
Ok(details)
}
/// Return any cached model information without touching the Ollama daemon.
pub async fn cached_model_info(&self) -> Vec<DetailedModelInfo> {
self.model_details_cache.cached().await
}
/// Remove a single model's cached information.
pub async fn invalidate_model_info(&self, model_name: &str) {
self.model_details_cache.invalidate(model_name).await;
}
/// Clear the entire model information cache.
pub async fn clear_model_info_cache(&self) {
self.model_details_cache.invalidate_all().await;
}
async fn fetch_and_cache_model_info(
&self,
model_name: &str,
local: Option<&LocalModel>,
) -> Result<DetailedModelInfo> {
let detail = self
.client
.show_model_info(model_name.to_string())
.await
.map_err(|err| self.map_ollama_error("show_model_info", err, Some(model_name)))?;
let local_owned = if let Some(local) = local {
Some(local.clone())
} else {
let models = self
.client
.list_local_models()
.await
.map_err(|err| self.map_ollama_error("list models", err, None))?;
models.into_iter().find(|m| m.name == model_name)
};
let detailed =
Self::convert_detailed_model_info(self.mode, model_name, local_owned.as_ref(), &detail);
self.model_details_cache.insert(detailed.clone()).await;
Ok(detailed)
}
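/// Build an `ollama-rs` chat request from crate-level types. MCP tool descriptors
/// are currently ignored (tool calling is not forwarded to Ollama), and cloud mode
/// logs a warning when the model name lacks a `-cloud` suffix.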
fn prepare_chat_request(
&self,
model: String,
messages: Vec<Message>,
parameters: ChatParameters,
tools: Option<Vec<McpToolDescriptor>>,
) -> Result<(String, OllamaChatRequest)> {
if self.mode == OllamaMode::Cloud && !model.contains("-cloud") {
warn!(
"Model '{}' does not use the '-cloud' suffix. Cloud-only models may fail to load.",
model
);
}
if let Some(descriptors) = &tools
&& !descriptors.is_empty()
{
debug!(
"Ignoring {} MCP tool descriptors for Ollama request (tool calling unsupported)",
descriptors.len()
);
}
let converted_messages = messages.into_iter().map(convert_message).collect();
let mut request = OllamaChatRequest::new(model.clone(), converted_messages);
if let Some(options) = build_model_options(&parameters)? {
request.options = Some(options);
}
Ok((model, request))
}
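/// List local models and fetch per-model details concurrently, caching every
/// detail lookup that succeeds; lookups that fail are logged at debug level and
/// the model is still listed with a generic description.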
async fn fetch_models(&self) -> Result<Vec<ModelInfo>> {
let models = self
.client
.list_local_models()
.await
.map_err(|err| self.map_ollama_error("list models", err, None))?;
let client = self.client.clone();
let cache = self.model_details_cache.clone();
let mode = self.mode;
let fetched = join_all(models.into_iter().map(|local| {
let client = client.clone();
let cache = cache.clone();
async move {
let name = local.name.clone();
let detail = match client.show_model_info(name.clone()).await {
Ok(info) => {
let detailed = OllamaProvider::convert_detailed_model_info(
mode,
&name,
Some(&local),
&info,
);
cache.insert(detailed).await;
Some(info)
}
Err(err) => {
debug!("Failed to fetch Ollama model info for '{name}': {err}");
None
}
};
(local, detail)
}
}))
.await;
Ok(fetched
.into_iter()
.map(|(local, detail)| self.convert_model(local, detail))
.collect())
}
fn convert_detailed_model_info(
mode: OllamaMode,
model_name: &str,
local: Option<&LocalModel>,
detail: &OllamaModelInfo,
) -> DetailedModelInfo {
let map = &detail.model_info;
let architecture =
pick_first_string(map, &["architecture", "model_format", "model_type", "arch"]);
let parameters = non_empty(detail.parameters.clone())
.or_else(|| pick_first_string(map, &["parameters"]));
let parameter_size = pick_first_string(map, &["parameter_size"]);
let context_length = pick_first_u64(map, &["context_length", "num_ctx", "max_context"]);
let embedding_length = pick_first_u64(map, &["embedding_length"]);
let quantization =
pick_first_string(map, &["quantization_level", "quantization", "quantize"]);
let family = pick_first_string(map, &["family", "model_family"]);
let mut families = pick_string_list(map, &["families", "model_families"]);
if families.is_empty() {
families.extend(family.clone());
}
let system = pick_first_string(map, &["system"]);
let mut modified_at = local
.and_then(|entry| non_empty(entry.modified_at.clone()))
.or_else(|| pick_first_string(map, &["modified_at", "created_at"]));
if modified_at.is_none() && mode == OllamaMode::Cloud {
modified_at = pick_first_string(map, &["updated_at"]);
}
let size = local
.and_then(|entry| {
if entry.size > 0 {
Some(entry.size)
} else {
None
}
})
.or_else(|| pick_first_u64(map, &["size", "model_size", "download_size"]));
let digest = pick_first_string(map, &["digest", "sha256", "checksum"]);
let mut info = DetailedModelInfo {
name: model_name.to_string(),
architecture,
parameters,
context_length,
embedding_length,
quantization,
family,
families,
parameter_size,
template: non_empty(detail.template.clone()),
system,
license: non_empty(detail.license.clone()),
modelfile: non_empty(detail.modelfile.clone()),
modified_at,
size,
digest,
};
if info.parameter_size.is_none() {
info.parameter_size = info.parameters.clone();
}
info.with_normalised_strings()
}
fn convert_model(&self, model: LocalModel, detail: Option<OllamaModelInfo>) -> ModelInfo {
let scope = match self.mode {
OllamaMode::Local => "local",
OllamaMode::Cloud => "cloud",
};
let name = model.name;
let mut capabilities: Vec<String> = detail
.as_ref()
.map(|info| {
info.capabilities
.iter()
.map(|cap| cap.to_ascii_lowercase())
.collect()
})
.unwrap_or_default();
push_capability(&mut capabilities, "chat");
for heuristic in heuristic_capabilities(&name) {
push_capability(&mut capabilities, &heuristic);
}
let description = build_model_description(scope, detail.as_ref());
ModelInfo {
id: name.clone(),
name,
description: Some(description),
provider: "ollama".to_string(),
context_window: None,
capabilities,
supports_tools: false,
}
}
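/// Convert an `ollama-rs` chat response into the crate's `ChatResponse`, deriving
/// token usage from the final response data when present. For streaming chunks the
/// `done` flag decides finality; non-streaming responses are always final.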
fn convert_ollama_response(response: OllamaChatResponse, streaming: bool) -> ChatResponse {
let usage = response.final_data.as_ref().map(|data| {
let prompt = clamp_to_u32(data.prompt_eval_count);
let completion = clamp_to_u32(data.eval_count);
TokenUsage {
prompt_tokens: prompt,
completion_tokens: completion,
total_tokens: prompt.saturating_add(completion),
}
});
ChatResponse {
message: convert_ollama_message(response.message),
usage,
is_streaming: streaming,
is_final: if streaming { response.done } else { true },
}
}
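/// Translate `ollama-rs` errors into crate-level `Error` variants, routing HTTP
/// status failures through `map_http_failure` and distinguishing timeouts from
/// other network errors.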
fn map_ollama_error(&self, action: &str, err: OllamaError, model: Option<&str>) -> Error {
match err {
OllamaError::ReqwestError(request_err) => {
if let Some(status) = request_err.status() {
self.map_http_failure(action, status, request_err.to_string(), model)
} else if request_err.is_timeout() {
Error::Timeout(format!("Ollama {action} timed out: {request_err}"))
} else {
Error::Network(format!("Ollama {action} request failed: {request_err}"))
}
}
OllamaError::InternalError(internal) => Error::Provider(anyhow!(internal.message)),
OllamaError::Other(message) => Error::Provider(anyhow!(message)),
OllamaError::JsonError(err) => Error::Serialization(err),
OllamaError::ToolCallError(err) => Error::Provider(anyhow!(err)),
}
}
fn map_http_failure(
&self,
action: &str,
status: StatusCode,
detail: String,
model: Option<&str>,
) -> Error {
match status {
StatusCode::NOT_FOUND => {
if let Some(model) = model {
Error::InvalidInput(format!(
"Model '{model}' was not found at {}. Verify the name or pull it with `ollama pull`.",
self.base_url
))
} else {
Error::InvalidInput(format!(
"{action} returned 404 from {}: {detail}",
self.base_url
))
}
}
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => Error::Auth(format!(
"Ollama rejected the request ({status}): {detail}. Check your API key and account permissions."
)),
StatusCode::BAD_REQUEST => {
Error::InvalidInput(format!("{action} rejected by Ollama ({status}): {detail}"))
}
StatusCode::SERVICE_UNAVAILABLE | StatusCode::GATEWAY_TIMEOUT => Error::Timeout(
format!("Ollama {action} timed out ({status}). The model may still be loading."),
),
_ => Error::Network(format!("Ollama {action} failed ({status}): {detail}")),
}
}
}
impl LlmProvider for OllamaProvider {
type Stream = Pin<Box<dyn Stream<Item = Result<ChatResponse>> + Send>>;
type ListModelsFuture<'a>
= BoxFuture<'a, Result<Vec<ModelInfo>>>
where
Self: 'a;
type SendPromptFuture<'a>
= BoxFuture<'a, Result<ChatResponse>>
where
Self: 'a;
type StreamPromptFuture<'a>
= BoxFuture<'a, Result<Self::Stream>>
where
Self: 'a;
type HealthCheckFuture<'a>
= BoxFuture<'a, Result<()>>
where
Self: 'a;
fn name(&self) -> &str {
"ollama"
}
fn list_models(&self) -> Self::ListModelsFuture<'_> {
Box::pin(async move {
self.model_manager
.get_or_refresh(false, || async { self.fetch_models().await })
.await
})
}
fn send_prompt(&self, request: ChatRequest) -> Self::SendPromptFuture<'_> {
Box::pin(async move {
let ChatRequest {
model,
messages,
parameters,
tools,
} = request;
let (model_id, ollama_request) =
self.prepare_chat_request(model, messages, parameters, tools)?;
let response = self
.client
.send_chat_messages(ollama_request)
.await
.map_err(|err| self.map_ollama_error("chat", err, Some(&model_id)))?;
Ok(Self::convert_ollama_response(response, false))
})
}
fn stream_prompt(&self, request: ChatRequest) -> Self::StreamPromptFuture<'_> {
Box::pin(async move {
let ChatRequest {
model,
messages,
parameters,
tools,
} = request;
let (model_id, ollama_request) =
self.prepare_chat_request(model, messages, parameters, tools)?;
let stream = self
.client
.send_chat_messages_stream(ollama_request)
.await
.map_err(|err| self.map_ollama_error("chat_stream", err, Some(&model_id)))?;
let mapped = stream.map(|item| match item {
Ok(chunk) => Ok(Self::convert_ollama_response(chunk, true)),
Err(_) => Err(Error::Provider(anyhow!(
"Ollama returned a malformed streaming chunk"
))),
});
Ok(Box::pin(mapped) as Self::Stream)
})
}
fn health_check(&self) -> Self::HealthCheckFuture<'_> {
Box::pin(async move {
let url = self.api_url("version");
let response = self
.http_client
.get(&url)
.send()
.await
.map_err(|err| map_reqwest_error("health check", err))?;
if response.status().is_success() {
return Ok(());
}
let status = response.status();
let detail = response.text().await.unwrap_or_else(|err| err.to_string());
Err(self.map_http_failure("health check", status, detail, None))
})
}
fn config_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"base_url": {
"type": "string",
"description": "Base URL for the Ollama API (ignored when api_key is provided)",
"default": self.mode.default_base_url()
},
"timeout_secs": {
"type": "integer",
"description": "HTTP request timeout in seconds",
"minimum": 5,
"default": DEFAULT_TIMEOUT_SECS
},
"model_cache_ttl_secs": {
"type": "integer",
"description": "Seconds to cache model listings",
"minimum": 5,
"default": DEFAULT_MODEL_CACHE_TTL_SECS
}
}
})
}
}
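/// Build Ollama `ModelOptions` from chat parameters: `extra` values are copied
/// through, `temperature` maps to `temperature`, and `max_tokens` maps to
/// `num_predict` (capped at `i32::MAX`). Returns `Ok(None)` when nothing is set.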
fn build_model_options(parameters: &ChatParameters) -> Result<Option<ModelOptions>> {
let mut options = JsonMap::new();
for (key, value) in &parameters.extra {
options.insert(key.clone(), value.clone());
}
if let Some(temperature) = parameters.temperature {
options.insert("temperature".to_string(), json!(temperature));
}
if let Some(max_tokens) = parameters.max_tokens {
let capped = i32::try_from(max_tokens).unwrap_or(i32::MAX);
options.insert("num_predict".to_string(), json!(capped));
}
if options.is_empty() {
return Ok(None);
}
serde_json::from_value(Value::Object(options))
.map(Some)
.map_err(|err| Error::Config(format!("Invalid Ollama options: {err}")))
}
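/// Convert a crate `Message` into an `ollama-rs` chat message, carrying over the
/// role, any tool calls, and a `thinking` string stashed in the message metadata.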
fn convert_message(message: Message) -> OllamaMessage {
let Message {
role,
content,
metadata,
tool_calls,
..
} = message;
let role = match role {
Role::User => OllamaRole::User,
Role::Assistant => OllamaRole::Assistant,
Role::System => OllamaRole::System,
Role::Tool => OllamaRole::Tool,
};
let tool_calls = tool_calls
.unwrap_or_default()
.into_iter()
.map(|tool_call| OllamaToolCall {
function: OllamaToolCallFunction {
name: tool_call.name,
arguments: tool_call.arguments,
},
})
.collect();
let thinking = metadata
.get("thinking")
.and_then(|value| value.as_str().map(|s| s.to_owned()));
OllamaMessage {
role,
content,
tool_calls,
images: None,
thinking,
}
}
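/// Convert an `ollama-rs` chat message back into a crate `Message`, assigning
/// synthetic tool-call IDs (`tool-call-{idx}`) and preserving any `thinking`
/// payload in the metadata map.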
fn convert_ollama_message(message: OllamaMessage) -> Message {
let role = match message.role {
OllamaRole::Assistant => Role::Assistant,
OllamaRole::System => Role::System,
OllamaRole::Tool => Role::Tool,
OllamaRole::User => Role::User,
};
let tool_calls = if message.tool_calls.is_empty() {
None
} else {
Some(
message
.tool_calls
.into_iter()
.enumerate()
.map(|(idx, tool_call)| ToolCall {
id: format!("tool-call-{idx}"),
name: tool_call.function.name,
arguments: tool_call.function.arguments,
})
.collect::<Vec<_>>(),
)
};
let mut metadata = HashMap::new();
if let Some(thinking) = message.thinking {
metadata.insert("thinking".to_string(), Value::String(thinking));
}
Message {
id: Uuid::new_v4(),
role,
content: message.content,
metadata,
timestamp: SystemTime::now(),
tool_calls,
}
}
fn clamp_to_u32(value: u64) -> u32 {
u32::try_from(value).unwrap_or(u32::MAX)
}
fn push_capability(capabilities: &mut Vec<String>, capability: &str) {
let candidate = capability.to_ascii_lowercase();
if !capabilities
.iter()
.any(|existing| existing.eq_ignore_ascii_case(&candidate))
{
capabilities.push(candidate);
}
}
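/// Guess additional capabilities from the model name alone: substrings such as
/// "vision"/"multimodal"/"image", "think"/"reason"/"r1", and
/// "audio"/"speech"/"voice" map to the `vision`, `thinking`, and `audio`
/// capabilities respectively.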
fn heuristic_capabilities(name: &str) -> Vec<String> {
let lowercase = name.to_ascii_lowercase();
let mut detected = Vec::new();
if lowercase.contains("vision")
|| lowercase.contains("multimodal")
|| lowercase.contains("image")
{
detected.push("vision".to_string());
}
if lowercase.contains("think")
|| lowercase.contains("reason")
|| lowercase.contains("deepseek-r1")
|| lowercase.contains("r1")
{
detected.push("thinking".to_string());
}
if lowercase.contains("audio") || lowercase.contains("speech") || lowercase.contains("voice") {
detected.push("audio".to_string());
}
detected
}
fn build_model_description(scope: &str, detail: Option<&OllamaModelInfo>) -> String {
if let Some(info) = detail {
let mut parts = Vec::new();
if let Some(family) = info
.model_info
.get("family")
.and_then(|value| value.as_str())
{
parts.push(family.to_string());
}
if let Some(parameter_size) = info
.model_info
.get("parameter_size")
.and_then(|value| value.as_str())
{
parts.push(parameter_size.to_string());
}
if let Some(variant) = info
.model_info
.get("variant")
.and_then(|value| value.as_str())
{
parts.push(variant.to_string());
}
if !parts.is_empty() {
return format!("Ollama ({scope}) {}", parts.join(" · "));
}
}
format!("Ollama ({scope}) model")
}
fn non_empty(value: String) -> Option<String> {
let trimmed = value.trim();
if trimmed.is_empty() {
None
} else {
Some(value)
}
}
fn pick_first_string(map: &JsonMap<String, Value>, keys: &[&str]) -> Option<String> {
keys.iter()
.filter_map(|key| map.get(*key))
.find_map(value_to_string)
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
}
fn pick_first_u64(map: &JsonMap<String, Value>, keys: &[&str]) -> Option<u64> {
keys.iter()
.filter_map(|key| map.get(*key))
.find_map(value_to_u64)
}
fn pick_string_list(map: &JsonMap<String, Value>, keys: &[&str]) -> Vec<String> {
for key in keys {
if let Some(value) = map.get(*key) {
match value {
Value::Array(items) => {
let collected: Vec<String> = items
.iter()
.filter_map(value_to_string)
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect();
if !collected.is_empty() {
return collected;
}
}
Value::String(text) => {
let collected: Vec<String> = text
.split(',')
.map(|part| part.trim())
.filter(|part| !part.is_empty())
.map(|part| part.to_string())
.collect();
if !collected.is_empty() {
return collected;
}
}
_ => {}
}
}
}
Vec::new()
}
fn value_to_string(value: &Value) -> Option<String> {
match value {
Value::String(text) => Some(text.clone()),
Value::Number(num) => Some(num.to_string()),
Value::Bool(flag) => Some(flag.to_string()),
_ => None,
}
}
fn value_to_u64(value: &Value) -> Option<u64> {
match value {
Value::Number(num) => {
if let Some(v) = num.as_u64() {
Some(v)
} else if let Some(v) = num.as_i64() {
v.try_into().ok()
} else if let Some(v) = num.as_f64() {
if v >= 0.0 { Some(v as u64) } else { None }
} else {
None
}
}
Value::String(text) => text.trim().parse::<u64>().ok(),
_ => None,
}
}
fn env_var_non_empty(name: &str) -> Option<String> {
env::var(name)
.ok()
.map(|value| value.trim().to_string())
.filter(|value| !value.is_empty())
}
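/// Resolve a configured API key: literal values are used as-is, while `${VAR}` or
/// `$VAR` references are expanded from the environment. Empty or unresolvable
/// values yield `None`.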
fn resolve_api_key(configured: Option<String>) -> Option<String> {
let raw = configured?.trim().to_string();
if raw.is_empty() {
return None;
}
if let Some(variable) = raw
.strip_prefix("${")
.and_then(|value| value.strip_suffix('}'))
.or_else(|| raw.strip_prefix('$'))
{
let var_name = variable.trim();
if var_name.is_empty() {
return None;
}
return env_var_non_empty(var_name);
}
Some(raw)
}
fn map_reqwest_error(action: &str, err: reqwest::Error) -> Error {
if err.is_timeout() {
Error::Timeout(format!("Ollama {action} request timed out: {err}"))
} else {
Error::Network(format!("Ollama {action} request failed: {err}"))
}
}
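/// Normalise a user-supplied base URL: fall back to the mode's default, prepend
/// `https://` when no scheme is given, require HTTPS in cloud mode, strip a
/// trailing `/api` segment, reject any other path, and drop query and fragment
/// components.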
fn normalize_base_url(
input: Option<&str>,
mode_hint: OllamaMode,
) -> std::result::Result<String, String> {
let mut candidate = input
.map(str::trim)
.filter(|value| !value.is_empty())
.map(|value| value.to_string())
.unwrap_or_else(|| mode_hint.default_base_url().to_string());
if !candidate.starts_with("http://") && !candidate.starts_with("https://") {
candidate = format!("https://{candidate}");
}
let mut url =
Url::parse(&candidate).map_err(|err| format!("Invalid Ollama URL '{candidate}': {err}"))?;
if url.cannot_be_a_base() {
return Err(format!("URL '{candidate}' cannot be used as a base URL"));
}
if mode_hint == OllamaMode::Cloud && url.scheme() != "https" {
return Err("Ollama Cloud requires https:// base URLs".to_string());
}
let path = url.path().trim_end_matches('/');
if path == "/api" {
url.set_path("/");
} else if !path.is_empty() && path != "/" {
return Err("Ollama base URLs must not include additional path segments".to_string());
}
url.set_query(None);
url.set_fragment(None);
Ok(url.to_string().trim_end_matches('/').to_string())
}
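/// Join a base URL and endpoint under the `/api` prefix, e.g. a base of
/// `https://ollama.com` with endpoint `version` yields
/// `https://ollama.com/api/version`; bases already ending in `/api` are reused
/// as-is.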
fn build_api_endpoint(base_url: &str, endpoint: &str) -> String {
let trimmed_base = base_url.trim_end_matches('/');
let trimmed_endpoint = endpoint.trim_start_matches('/');
if trimmed_base.ends_with("/api") {
format!("{trimmed_base}/{trimmed_endpoint}")
} else {
format!("{trimmed_base}/api/{trimmed_endpoint}")
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn resolve_api_key_prefers_literal_value() {
assert_eq!(
resolve_api_key(Some("direct-key".into())),
Some("direct-key".into())
);
}
#[test]
fn resolve_api_key_expands_env_var() {
unsafe {
std::env::set_var("OLLAMA_TEST_KEY", "secret");
}
assert_eq!(
resolve_api_key(Some("${OLLAMA_TEST_KEY}".into())),
Some("secret".into())
);
unsafe {
std::env::remove_var("OLLAMA_TEST_KEY");
}
}
#[test]
fn normalize_base_url_removes_api_path() {
let url = normalize_base_url(Some("https://ollama.com/api"), OllamaMode::Cloud).unwrap();
assert_eq!(url, "https://ollama.com");
}
#[test]
fn normalize_base_url_rejects_cloud_without_https() {
let err = normalize_base_url(Some("http://ollama.com"), OllamaMode::Cloud).unwrap_err();
assert!(err.contains("https"));
}
#[test]
fn build_model_options_merges_parameters() {
let mut parameters = ChatParameters::default();
parameters.temperature = Some(0.3);
parameters.max_tokens = Some(128);
parameters
.extra
.insert("num_ctx".into(), Value::from(4096_u64));
let options = build_model_options(&parameters)
.expect("options built")
.expect("options present");
let serialized = serde_json::to_value(&options).expect("serialize options");
let temperature = serialized["temperature"]
.as_f64()
.expect("temperature present");
assert!((temperature - 0.3).abs() < 1e-6);
assert_eq!(serialized["num_predict"], json!(128));
assert_eq!(serialized["num_ctx"], json!(4096));
}
#[test]
fn heuristic_capabilities_detects_thinking_models() {
let caps = heuristic_capabilities("deepseek-r1");
assert!(caps.iter().any(|cap| cap == "thinking"));
}
#[test]
fn push_capability_avoids_duplicates() {
let mut caps = vec!["chat".to_string()];
push_capability(&mut caps, "Chat");
push_capability(&mut caps, "Vision");
push_capability(&mut caps, "vision");
assert_eq!(caps.len(), 2);
assert!(caps.iter().any(|cap| cap == "vision"));
}
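// An extra check, added as a sketch, covering `build_api_endpoint` with and
// without an existing `/api` segment on the base URL.
#[test]
fn build_api_endpoint_handles_api_suffix() {
assert_eq!(
build_api_endpoint("https://ollama.com", "version"),
"https://ollama.com/api/version"
);
assert_eq!(
build_api_endpoint("http://localhost:11434/api", "/tags"),
"http://localhost:11434/api/tags"
);
}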
}