//! Ollama provider built on top of the `ollama-rs` crate.

use std::{
    collections::HashMap,
    env,
    pin::Pin,
    time::{Duration, SystemTime},
};

use anyhow::anyhow;
use futures::{future::join_all, future::BoxFuture, Stream, StreamExt};
use log::{debug, warn};
use ollama_rs::{
    error::OllamaError,
    generation::chat::{
        request::ChatMessageRequest as OllamaChatRequest, ChatMessage as OllamaMessage,
        ChatMessageResponse as OllamaChatResponse, MessageRole as OllamaRole,
    },
    generation::tools::{ToolCall as OllamaToolCall, ToolCallFunction as OllamaToolCallFunction},
    headers::{HeaderMap, HeaderValue, AUTHORIZATION},
    models::{LocalModel, ModelInfo as OllamaModelInfo, ModelOptions},
    Ollama,
};
use reqwest::{Client, StatusCode, Url};
use serde_json::{json, Map as JsonMap, Value};
use uuid::Uuid;

use crate::{
    config::GeneralSettings,
    mcp::McpToolDescriptor,
    model::ModelManager,
    provider::{LLMProvider, ProviderConfig},
    types::{
        ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role, TokenUsage, ToolCall,
    },
    Error, Result,
};

const DEFAULT_TIMEOUT_SECS: u64 = 120;
const DEFAULT_MODEL_CACHE_TTL_SECS: u64 = 60;
const CLOUD_BASE_URL: &str = "https://ollama.com";

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OllamaMode {
    Local,
    Cloud,
}

impl OllamaMode {
    fn default_base_url(self) -> &'static str {
        match self {
            Self::Local => "http://localhost:11434",
            Self::Cloud => CLOUD_BASE_URL,
        }
    }
}

#[derive(Debug)]
struct OllamaOptions {
    mode: OllamaMode,
    base_url: String,
    request_timeout: Duration,
    model_cache_ttl: Duration,
    api_key: Option<String>,
}

impl OllamaOptions {
    fn new(mode: OllamaMode, base_url: impl Into<String>) -> Self {
        Self {
            mode,
            base_url: base_url.into(),
            request_timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
            model_cache_ttl: Duration::from_secs(DEFAULT_MODEL_CACHE_TTL_SECS),
            api_key: None,
        }
    }

    fn with_general(mut self, general: &GeneralSettings) -> Self {
        self.model_cache_ttl = general.model_cache_ttl();
        self
    }
}

/// Ollama provider implementation backed by `ollama-rs`.
#[derive(Debug)]
pub struct OllamaProvider {
    mode: OllamaMode,
    client: Ollama,
    http_client: Client,
    base_url: String,
    model_manager: ModelManager,
}

impl OllamaProvider {
    /// Create a provider targeting an explicit base URL (local usage).
    pub fn new(base_url: impl Into<String>) -> Result<Self> {
        let input = base_url.into();
        let normalized =
            normalize_base_url(Some(&input), OllamaMode::Local).map_err(Error::Config)?;
        Self::with_options(OllamaOptions::new(OllamaMode::Local, normalized))
    }

    /// Construct a provider from configuration settings.
    pub fn from_config(config: &ProviderConfig, general: Option<&GeneralSettings>) -> Result<Self> {
        let mut api_key = resolve_api_key(config.api_key.clone())
            .or_else(|| env_var_non_empty("OLLAMA_API_KEY"))
            .or_else(|| env_var_non_empty("OLLAMA_CLOUD_API_KEY"));
        let mode = if api_key.is_some() {
            OllamaMode::Cloud
        } else {
            OllamaMode::Local
        };
        let base_candidate = if mode == OllamaMode::Cloud {
            Some(CLOUD_BASE_URL)
        } else {
            config.base_url.as_deref()
        };
        let normalized_base_url = normalize_base_url(base_candidate, mode).map_err(Error::Config)?;

        let mut options = OllamaOptions::new(mode, normalized_base_url);
        if let Some(timeout) = config
            .extra
            .get("timeout_secs")
            .and_then(|value| value.as_u64())
        {
            options.request_timeout = Duration::from_secs(timeout.max(5));
        }
        if let Some(cache_ttl) = config
            .extra
            .get("model_cache_ttl_secs")
            .and_then(|value| value.as_u64())
        {
            options.model_cache_ttl = Duration::from_secs(cache_ttl.max(5));
        }
        options.api_key = api_key.take();
        if let Some(general) = general {
            options = options.with_general(general);
        }
        Self::with_options(options)
    }

    fn with_options(options: OllamaOptions) -> Result<Self> {
        let OllamaOptions {
            mode,
            base_url,
            request_timeout,
            model_cache_ttl,
            api_key,
        } = options;
        let url = Url::parse(&base_url)
            .map_err(|err| Error::Config(format!("Invalid Ollama base URL '{base_url}': {err}")))?;

        let mut headers = HeaderMap::new();
        if let Some(ref key) = api_key {
            let value = HeaderValue::from_str(&format!("Bearer {key}")).map_err(|_| {
                Error::Config("Ollama API key contains invalid characters".to_string())
            })?;
            headers.insert(AUTHORIZATION, value);
        }

        let mut client_builder = Client::builder().timeout(request_timeout);
        if !headers.is_empty() {
            client_builder = client_builder.default_headers(headers.clone());
        }
        let http_client = client_builder
            .build()
            .map_err(|err| Error::Config(format!("Failed to build HTTP client: {err}")))?;

        let port = url.port_or_known_default().ok_or_else(|| {
            Error::Config(format!("Unable to determine port for Ollama URL '{}'", url))
        })?;
        let mut ollama_client = Ollama::new_with_client(url.clone(), port, http_client.clone());
        if !headers.is_empty() {
            ollama_client.set_headers(Some(headers.clone()));
        }

        Ok(Self {
            mode,
            client: ollama_client,
            http_client,
            base_url: base_url.trim_end_matches('/').to_string(),
            model_manager: ModelManager::new(model_cache_ttl),
        })
    }

    fn api_url(&self, endpoint: &str) -> String {
        build_api_endpoint(&self.base_url, endpoint)
    }

    fn prepare_chat_request(
        &self,
        model: String,
        messages: Vec<Message>,
        parameters: ChatParameters,
        tools: Option<Vec<McpToolDescriptor>>,
    ) -> Result<(String, OllamaChatRequest)> {
        if self.mode == OllamaMode::Cloud && !model.contains("-cloud") {
            warn!(
                "Model '{}' does not use the '-cloud' suffix. Cloud-only models may fail to load.",
                model
            );
        }
        if let Some(descriptors) = &tools {
            if !descriptors.is_empty() {
                debug!(
                    "Ignoring {} MCP tool descriptors for Ollama request (tool calling unsupported)",
                    descriptors.len()
                );
            }
        }
        let converted_messages = messages.into_iter().map(convert_message).collect();
        let mut request = OllamaChatRequest::new(model.clone(), converted_messages);
        if let Some(options) = build_model_options(&parameters)? {
            request.options = Some(options);
        }
        Ok((model, request))
    }
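
    /// List the models installed on the Ollama server and enrich each entry with
    /// `show_model_info` detail. Detail lookups run concurrently, and a failed
    /// lookup is logged and skipped rather than failing the whole listing.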
    async fn fetch_models(&self) -> Result<Vec<ModelInfo>> {
        let models = self
            .client
            .list_local_models()
            .await
            .map_err(|err| self.map_ollama_error("list models", err, None))?;
        let client = self.client.clone();
        let fetched = join_all(models.into_iter().map(|local| {
            let client = client.clone();
            async move {
                let name = local.name.clone();
                let detail = match client.show_model_info(name.clone()).await {
                    Ok(info) => Some(info),
                    Err(err) => {
                        debug!("Failed to fetch Ollama model info for '{name}': {err}");
                        None
                    }
                };
                (local, detail)
            }
        }))
        .await;
        Ok(fetched
            .into_iter()
            .map(|(local, detail)| self.convert_model(local, detail))
            .collect())
    }

    fn convert_model(&self, model: LocalModel, detail: Option<OllamaModelInfo>) -> ModelInfo {
        let scope = match self.mode {
            OllamaMode::Local => "local",
            OllamaMode::Cloud => "cloud",
        };
        let name = model.name;
        let mut capabilities: Vec<String> = detail
            .as_ref()
            .map(|info| {
                info.capabilities
                    .iter()
                    .map(|cap| cap.to_ascii_lowercase())
                    .collect()
            })
            .unwrap_or_default();
        push_capability(&mut capabilities, "chat");
        for heuristic in heuristic_capabilities(&name) {
            push_capability(&mut capabilities, &heuristic);
        }
        let description = build_model_description(scope, detail.as_ref());
        ModelInfo {
            id: name.clone(),
            name,
            description: Some(description),
            provider: "ollama".to_string(),
            context_window: None,
            capabilities,
            supports_tools: false,
        }
    }

    fn convert_ollama_response(response: OllamaChatResponse, streaming: bool) -> ChatResponse {
        let usage = response.final_data.as_ref().map(|data| {
            let prompt = clamp_to_u32(data.prompt_eval_count);
            let completion = clamp_to_u32(data.eval_count);
            TokenUsage {
                prompt_tokens: prompt,
                completion_tokens: completion,
                total_tokens: prompt.saturating_add(completion),
            }
        });
        ChatResponse {
            message: convert_ollama_message(response.message),
            usage,
            is_streaming: streaming,
            is_final: if streaming { response.done } else { true },
        }
    }

    fn map_ollama_error(&self, action: &str, err: OllamaError, model: Option<&str>) -> Error {
        match err {
            OllamaError::ReqwestError(request_err) => {
                if let Some(status) = request_err.status() {
                    self.map_http_failure(action, status, request_err.to_string(), model)
                } else if request_err.is_timeout() {
                    Error::Timeout(format!("Ollama {action} timed out: {request_err}"))
                } else {
                    Error::Network(format!("Ollama {action} request failed: {request_err}"))
                }
            }
            OllamaError::InternalError(internal) => Error::Provider(anyhow!(internal.message)),
            OllamaError::Other(message) => Error::Provider(anyhow!(message)),
            OllamaError::JsonError(err) => Error::Serialization(err),
            OllamaError::ToolCallError(err) => Error::Provider(anyhow!(err)),
        }
    }

    fn map_http_failure(
        &self,
        action: &str,
        status: StatusCode,
        detail: String,
        model: Option<&str>,
    ) -> Error {
        match status {
            StatusCode::NOT_FOUND => {
                if let Some(model) = model {
                    Error::InvalidInput(format!(
                        "Model '{model}' was not found at {}. Verify the name or pull it with `ollama pull`.",
                        self.base_url
                    ))
                } else {
                    Error::InvalidInput(format!(
                        "{action} returned 404 from {}: {detail}",
                        self.base_url
                    ))
                }
            }
            StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => Error::Auth(format!(
                "Ollama rejected the request ({status}): {detail}. Check your API key and account permissions."
            )),
            StatusCode::BAD_REQUEST => Error::InvalidInput(format!(
                "{action} rejected by Ollama ({status}): {detail}"
            )),
            StatusCode::SERVICE_UNAVAILABLE | StatusCode::GATEWAY_TIMEOUT => Error::Timeout(format!(
                "Ollama {action} timed out ({status}). The model may still be loading."
            )),
            _ => Error::Network(format!("Ollama {action} failed ({status}): {detail}")),
        }
    }
}

impl LLMProvider for OllamaProvider {
    type Stream = Pin<Box<dyn Stream<Item = Result<ChatResponse>> + Send>>;
    type ListModelsFuture<'a> = BoxFuture<'a, Result<Vec<ModelInfo>>> where Self: 'a;
    type ChatFuture<'a> = BoxFuture<'a, Result<ChatResponse>> where Self: 'a;
    type ChatStreamFuture<'a> = BoxFuture<'a, Result<Self::Stream>> where Self: 'a;
    type HealthCheckFuture<'a> = BoxFuture<'a, Result<()>> where Self: 'a;

    fn name(&self) -> &str {
        "ollama"
    }

    fn list_models(&self) -> Self::ListModelsFuture<'_> {
        Box::pin(async move {
            self.model_manager
                .get_or_refresh(false, || async { self.fetch_models().await })
                .await
        })
    }

    fn chat(&self, request: ChatRequest) -> Self::ChatFuture<'_> {
        Box::pin(async move {
            let ChatRequest {
                model,
                messages,
                parameters,
                tools,
            } = request;
            let (model_id, ollama_request) =
                self.prepare_chat_request(model, messages, parameters, tools)?;
            let response = self
                .client
                .send_chat_messages(ollama_request)
                .await
                .map_err(|err| self.map_ollama_error("chat", err, Some(&model_id)))?;
            Ok(Self::convert_ollama_response(response, false))
        })
    }

    fn chat_stream(&self, request: ChatRequest) -> Self::ChatStreamFuture<'_> {
        Box::pin(async move {
            let ChatRequest {
                model,
                messages,
                parameters,
                tools,
            } = request;
            let (model_id, ollama_request) =
                self.prepare_chat_request(model, messages, parameters, tools)?;
            let stream = self
                .client
                .send_chat_messages_stream(ollama_request)
                .await
                .map_err(|err| self.map_ollama_error("chat_stream", err, Some(&model_id)))?;
            let mapped = stream.map(|item| match item {
                Ok(chunk) => Ok(Self::convert_ollama_response(chunk, true)),
                Err(_) => Err(Error::Provider(anyhow!(
                    "Ollama returned a malformed streaming chunk"
                ))),
            });
            Ok(Box::pin(mapped) as Self::Stream)
        })
    }

    fn health_check(&self) -> Self::HealthCheckFuture<'_> {
        Box::pin(async move {
            let url = self.api_url("version");
            let response = self
                .http_client
                .get(&url)
                .send()
                .await
                .map_err(|err| map_reqwest_error("health check", err))?;
            if response.status().is_success() {
                return Ok(());
            }
            let status = response.status();
            let detail = response.text().await.unwrap_or_else(|err| err.to_string());
            Err(self.map_http_failure("health check", status, detail, None))
        })
    }

    fn config_schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "base_url": {
                    "type": "string",
                    "description": "Base URL for the Ollama API (ignored when api_key is provided)",
                    "default": self.mode.default_base_url()
                },
                "timeout_secs": {
                    "type": "integer",
                    "description": "HTTP request timeout in seconds",
                    "minimum": 5,
                    "default": DEFAULT_TIMEOUT_SECS
                },
                "model_cache_ttl_secs": {
                    "type": "integer",
                    "description": "Seconds to cache model listings",
                    "minimum": 5,
                    "default": DEFAULT_MODEL_CACHE_TTL_SECS
                }
            }
        })
    }
}

fn build_model_options(parameters: &ChatParameters) -> Result<Option<ModelOptions>> {
    let mut options = JsonMap::new();
    for (key, value) in &parameters.extra {
        options.insert(key.clone(), value.clone());
    }
    if let Some(temperature) = parameters.temperature {
        options.insert("temperature".to_string(), json!(temperature));
    }
    if let Some(max_tokens) = parameters.max_tokens {
        let capped = i32::try_from(max_tokens).unwrap_or(i32::MAX);
        options.insert("num_predict".to_string(), json!(capped));
    }
    if options.is_empty() {
        return Ok(None);
    }
    serde_json::from_value(Value::Object(options))
        .map(Some)
        .map_err(|err| Error::Config(format!("Invalid Ollama options: {err}")))
}
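
/// Convert an internal `Message` into the `ollama-rs` chat message type, mapping
/// roles and tool calls one-to-one and forwarding any `thinking` metadata string.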
fn convert_message(message: Message) -> OllamaMessage {
    let Message {
        role,
        content,
        metadata,
        tool_calls,
        ..
    } = message;
    let role = match role {
        Role::User => OllamaRole::User,
        Role::Assistant => OllamaRole::Assistant,
        Role::System => OllamaRole::System,
        Role::Tool => OllamaRole::Tool,
    };
    let tool_calls = tool_calls
        .unwrap_or_default()
        .into_iter()
        .map(|tool_call| OllamaToolCall {
            function: OllamaToolCallFunction {
                name: tool_call.name,
                arguments: tool_call.arguments,
            },
        })
        .collect();
    let thinking = metadata
        .get("thinking")
        .and_then(|value| value.as_str().map(|s| s.to_owned()));
    OllamaMessage {
        role,
        content,
        tool_calls,
        images: None,
        thinking,
    }
}

fn convert_ollama_message(message: OllamaMessage) -> Message {
    let role = match message.role {
        OllamaRole::Assistant => Role::Assistant,
        OllamaRole::System => Role::System,
        OllamaRole::Tool => Role::Tool,
        OllamaRole::User => Role::User,
    };
    let tool_calls = if message.tool_calls.is_empty() {
        None
    } else {
        Some(
            message
                .tool_calls
                .into_iter()
                .enumerate()
                .map(|(idx, tool_call)| ToolCall {
                    id: format!("tool-call-{idx}"),
                    name: tool_call.function.name,
                    arguments: tool_call.function.arguments,
                })
                .collect::<Vec<_>>(),
        )
    };
    let mut metadata = HashMap::new();
    if let Some(thinking) = message.thinking {
        metadata.insert("thinking".to_string(), Value::String(thinking));
    }
    Message {
        id: Uuid::new_v4(),
        role,
        content: message.content,
        metadata,
        timestamp: SystemTime::now(),
        tool_calls,
    }
}

fn clamp_to_u32(value: u64) -> u32 {
    u32::try_from(value).unwrap_or(u32::MAX)
}

fn push_capability(capabilities: &mut Vec<String>, capability: &str) {
    let candidate = capability.to_ascii_lowercase();
    if !capabilities
        .iter()
        .any(|existing| existing.eq_ignore_ascii_case(&candidate))
    {
        capabilities.push(candidate);
    }
}

fn heuristic_capabilities(name: &str) -> Vec<String> {
    let lowercase = name.to_ascii_lowercase();
    let mut detected = Vec::new();
    if lowercase.contains("vision")
        || lowercase.contains("multimodal")
        || lowercase.contains("image")
    {
        detected.push("vision".to_string());
    }
    if lowercase.contains("think")
        || lowercase.contains("reason")
        || lowercase.contains("deepseek-r1")
        || lowercase.contains("r1")
    {
        detected.push("thinking".to_string());
    }
    if lowercase.contains("audio") || lowercase.contains("speech") || lowercase.contains("voice") {
        detected.push("audio".to_string());
    }
    detected
}

fn build_model_description(scope: &str, detail: Option<&OllamaModelInfo>) -> String {
    if let Some(info) = detail {
        let mut parts = Vec::new();
        if let Some(family) = info
            .model_info
            .get("family")
            .and_then(|value| value.as_str())
        {
            parts.push(family.to_string());
        }
        if let Some(parameter_size) = info
            .model_info
            .get("parameter_size")
            .and_then(|value| value.as_str())
        {
            parts.push(parameter_size.to_string());
        }
        if let Some(variant) = info
            .model_info
            .get("variant")
            .and_then(|value| value.as_str())
        {
            parts.push(variant.to_string());
        }
        if !parts.is_empty() {
            return format!("Ollama ({scope}) – {}", parts.join(" · "));
        }
    }
    format!("Ollama ({scope}) model")
}

fn env_var_non_empty(name: &str) -> Option<String> {
    env::var(name)
        .ok()
        .map(|value| value.trim().to_string())
        .filter(|value| !value.is_empty())
}

fn resolve_api_key(configured: Option<String>) -> Option<String> {
    let raw = configured?.trim().to_string();
    if raw.is_empty() {
        return None;
    }
    if let Some(variable) = raw
        .strip_prefix("${")
        .and_then(|value| value.strip_suffix('}'))
        .or_else(|| raw.strip_prefix('$'))
    {
        let var_name = variable.trim();
        if var_name.is_empty() {
            return None;
        }
        return env_var_non_empty(var_name);
    }
    Some(raw)
}

fn map_reqwest_error(action: &str, err: reqwest::Error) -> Error {
    if err.is_timeout() {
        Error::Timeout(format!("Ollama {action} request timed out: {err}"))
    } else {
        Error::Network(format!("Ollama {action} request failed: {err}"))
    }
}
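
/// Normalize a user-supplied base URL: fall back to the mode's default when the
/// input is empty, prepend `https://` when no scheme is given, strip a trailing
/// `/api` segment, reject any other path segments, drop query and fragment, and
/// require https for Ollama Cloud.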
out: {err}")) } else { Error::Network(format!("Ollama {action} request failed: {err}")) } } fn normalize_base_url( input: Option<&str>, mode_hint: OllamaMode, ) -> std::result::Result { let mut candidate = input .map(str::trim) .filter(|value| !value.is_empty()) .map(|value| value.to_string()) .unwrap_or_else(|| mode_hint.default_base_url().to_string()); if !candidate.starts_with("http://") && !candidate.starts_with("https://") { candidate = format!("https://{candidate}"); } let mut url = Url::parse(&candidate).map_err(|err| format!("Invalid Ollama URL '{candidate}': {err}"))?; if url.cannot_be_a_base() { return Err(format!("URL '{candidate}' cannot be used as a base URL")); } if mode_hint == OllamaMode::Cloud && url.scheme() != "https" { return Err("Ollama Cloud requires https:// base URLs".to_string()); } let path = url.path().trim_end_matches('/'); if path == "/api" { url.set_path("/"); } else if !path.is_empty() && path != "/" { return Err("Ollama base URLs must not include additional path segments".to_string()); } url.set_query(None); url.set_fragment(None); Ok(url.to_string().trim_end_matches('/').to_string()) } fn build_api_endpoint(base_url: &str, endpoint: &str) -> String { let trimmed_base = base_url.trim_end_matches('/'); let trimmed_endpoint = endpoint.trim_start_matches('/'); if trimmed_base.ends_with("/api") { format!("{trimmed_base}/{trimmed_endpoint}") } else { format!("{trimmed_base}/api/{trimmed_endpoint}") } } #[cfg(test)] mod tests { use super::*; #[test] fn resolve_api_key_prefers_literal_value() { assert_eq!( resolve_api_key(Some("direct-key".into())), Some("direct-key".into()) ); } #[test] fn resolve_api_key_expands_env_var() { std::env::set_var("OLLAMA_TEST_KEY", "secret"); assert_eq!( resolve_api_key(Some("${OLLAMA_TEST_KEY}".into())), Some("secret".into()) ); std::env::remove_var("OLLAMA_TEST_KEY"); } #[test] fn normalize_base_url_removes_api_path() { let url = normalize_base_url(Some("https://ollama.com/api"), OllamaMode::Cloud).unwrap(); assert_eq!(url, "https://ollama.com"); } #[test] fn normalize_base_url_rejects_cloud_without_https() { let err = normalize_base_url(Some("http://ollama.com"), OllamaMode::Cloud).unwrap_err(); assert!(err.contains("https")); } #[test] fn build_model_options_merges_parameters() { let mut parameters = ChatParameters::default(); parameters.temperature = Some(0.3); parameters.max_tokens = Some(128); parameters .extra .insert("num_ctx".into(), Value::from(4096_u64)); let options = build_model_options(¶meters) .expect("options built") .expect("options present"); let serialized = serde_json::to_value(&options).expect("serialize options"); let temperature = serialized["temperature"] .as_f64() .expect("temperature present"); assert!((temperature - 0.3).abs() < 1e-6); assert_eq!(serialized["num_predict"], json!(128)); assert_eq!(serialized["num_ctx"], json!(4096)); } #[test] fn heuristic_capabilities_detects_thinking_models() { let caps = heuristic_capabilities("deepseek-r1"); assert!(caps.iter().any(|cap| cap == "thinking")); } #[test] fn push_capability_avoids_duplicates() { let mut caps = vec!["chat".to_string()]; push_capability(&mut caps, "Chat"); push_capability(&mut caps, "Vision"); push_capability(&mut caps, "vision"); assert_eq!(caps.len(), 2); assert!(caps.iter().any(|cap| cap == "vision")); } }