owlen/crates/owlen-core/src/providers/ollama.rs
vikingowl fab63d224b refactor(ollama): replace handcrafted HTTP logic with ollama-rs client and simplify request handling
- Switch to `ollama-rs` crate for chat, model listing, and streaming.
- Remove custom request building, authentication handling, and debug logging.
- Drop unsupported tool conversion; now ignore tool descriptors with a warning.
- Refactor model fetching to use local model info and optional cloud details.
- Consolidate error mapping via `map_ollama_error`.
- Update health check to use the new HTTP client.
- Delete obsolete `provider_interface.rs` test as the provider interface has changed.
2025-10-12 07:09:58 +02:00

//! Ollama provider built on top of the `ollama-rs` crate.
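//!
//! A minimal usage sketch (illustration only: the module paths, re-exports, an async
//! runtime, and a reachable Ollama daemon are assumed, so the block is not compiled):
//!
//! ```ignore
//! use owlen_core::{provider::LLMProvider, providers::ollama::OllamaProvider};
//!
//! # async fn demo() -> owlen_core::Result<()> {
//! // Point at a local daemon; `from_config` switches to Ollama Cloud when an API key is present.
//! let provider = OllamaProvider::new("http://localhost:11434")?;
//! let models = provider.list_models().await?;
//! println!("{} models available", models.len());
//! # Ok(())
//! # }
//! ```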
use std::{
collections::HashMap,
env,
pin::Pin,
time::{Duration, SystemTime},
};
use anyhow::anyhow;
use futures::{future::join_all, future::BoxFuture, Stream, StreamExt};
use log::{debug, warn};
use ollama_rs::{
error::OllamaError,
generation::chat::{
request::ChatMessageRequest as OllamaChatRequest, ChatMessage as OllamaMessage,
ChatMessageResponse as OllamaChatResponse, MessageRole as OllamaRole,
},
generation::tools::{ToolCall as OllamaToolCall, ToolCallFunction as OllamaToolCallFunction},
headers::{HeaderMap, HeaderValue, AUTHORIZATION},
models::{LocalModel, ModelInfo as OllamaModelInfo, ModelOptions},
Ollama,
};
use reqwest::{Client, StatusCode, Url};
use serde_json::{json, Map as JsonMap, Value};
use uuid::Uuid;
use crate::{
config::GeneralSettings,
mcp::McpToolDescriptor,
model::ModelManager,
provider::{LLMProvider, ProviderConfig},
types::{
ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role, TokenUsage, ToolCall,
},
Error, Result,
};
const DEFAULT_TIMEOUT_SECS: u64 = 120;
const DEFAULT_MODEL_CACHE_TTL_SECS: u64 = 60;
const CLOUD_BASE_URL: &str = "https://ollama.com";
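/// Deployment target: a local Ollama daemon or the hosted Ollama Cloud API.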
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OllamaMode {
Local,
Cloud,
}
impl OllamaMode {
fn default_base_url(self) -> &'static str {
match self {
Self::Local => "http://localhost:11434",
Self::Cloud => CLOUD_BASE_URL,
}
}
}
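/// Resolved connection settings (mode, base URL, timeouts, cache TTL, optional API key)
/// used to construct the provider.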
#[derive(Debug)]
struct OllamaOptions {
mode: OllamaMode,
base_url: String,
request_timeout: Duration,
model_cache_ttl: Duration,
api_key: Option<String>,
}
impl OllamaOptions {
fn new(mode: OllamaMode, base_url: impl Into<String>) -> Self {
Self {
mode,
base_url: base_url.into(),
request_timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
model_cache_ttl: Duration::from_secs(DEFAULT_MODEL_CACHE_TTL_SECS),
api_key: None,
}
}
fn with_general(mut self, general: &GeneralSettings) -> Self {
self.model_cache_ttl = general.model_cache_ttl();
self
}
}
/// Ollama provider implementation backed by `ollama-rs`.
#[derive(Debug)]
pub struct OllamaProvider {
mode: OllamaMode,
client: Ollama,
http_client: Client,
base_url: String,
model_manager: ModelManager,
}
impl OllamaProvider {
/// Create a provider targeting an explicit base URL (local usage).
pub fn new(base_url: impl Into<String>) -> Result<Self> {
let input = base_url.into();
let normalized =
normalize_base_url(Some(&input), OllamaMode::Local).map_err(Error::Config)?;
Self::with_options(OllamaOptions::new(OllamaMode::Local, normalized))
}
/// Construct a provider from configuration settings.
pub fn from_config(config: &ProviderConfig, general: Option<&GeneralSettings>) -> Result<Self> {
let mut api_key = resolve_api_key(config.api_key.clone())
.or_else(|| env_var_non_empty("OLLAMA_API_KEY"))
.or_else(|| env_var_non_empty("OLLAMA_CLOUD_API_KEY"));
let mode = if api_key.is_some() {
OllamaMode::Cloud
} else {
OllamaMode::Local
};
let base_candidate = if mode == OllamaMode::Cloud {
Some(CLOUD_BASE_URL)
} else {
config.base_url.as_deref()
};
let normalized_base_url =
normalize_base_url(base_candidate, mode).map_err(Error::Config)?;
let mut options = OllamaOptions::new(mode, normalized_base_url);
if let Some(timeout) = config
.extra
.get("timeout_secs")
.and_then(|value| value.as_u64())
{
options.request_timeout = Duration::from_secs(timeout.max(5));
}
if let Some(cache_ttl) = config
.extra
.get("model_cache_ttl_secs")
.and_then(|value| value.as_u64())
{
options.model_cache_ttl = Duration::from_secs(cache_ttl.max(5));
}
options.api_key = api_key.take();
if let Some(general) = general {
options = options.with_general(general);
}
Self::with_options(options)
}
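/// Construct the provider from resolved options: build the authenticated `reqwest`
/// client, derive the `ollama-rs` client from the parsed base URL, and set up the
/// model cache.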
fn with_options(options: OllamaOptions) -> Result<Self> {
let OllamaOptions {
mode,
base_url,
request_timeout,
model_cache_ttl,
api_key,
} = options;
let url = Url::parse(&base_url)
.map_err(|err| Error::Config(format!("Invalid Ollama base URL '{base_url}': {err}")))?;
let mut headers = HeaderMap::new();
if let Some(ref key) = api_key {
let value = HeaderValue::from_str(&format!("Bearer {key}")).map_err(|_| {
Error::Config("OLLAMA API key contains invalid characters".to_string())
})?;
headers.insert(AUTHORIZATION, value);
}
let mut client_builder = Client::builder().timeout(request_timeout);
if !headers.is_empty() {
client_builder = client_builder.default_headers(headers.clone());
}
let http_client = client_builder
.build()
.map_err(|err| Error::Config(format!("Failed to build HTTP client: {err}")))?;
let port = url.port_or_known_default().ok_or_else(|| {
Error::Config(format!("Unable to determine port for Ollama URL '{}'", url))
})?;
let mut ollama_client = Ollama::new_with_client(url.clone(), port, http_client.clone());
if !headers.is_empty() {
ollama_client.set_headers(Some(headers.clone()));
}
Ok(Self {
mode,
client: ollama_client,
http_client,
base_url: base_url.trim_end_matches('/').to_string(),
model_manager: ModelManager::new(model_cache_ttl),
})
}
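/// Build the absolute URL for an Ollama REST endpoint (e.g. `version`) under `/api`.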
fn api_url(&self, endpoint: &str) -> String {
build_api_endpoint(&self.base_url, endpoint)
}
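/// Validate the target model, convert messages, and assemble an `ollama-rs` chat request.
///
/// MCP tool descriptors are ignored (with a debug log) because tool calling is not
/// supported by this provider.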
fn prepare_chat_request(
&self,
model: String,
messages: Vec<Message>,
parameters: ChatParameters,
tools: Option<Vec<McpToolDescriptor>>,
) -> Result<(String, OllamaChatRequest)> {
if self.mode == OllamaMode::Cloud && !model.contains("-cloud") {
warn!(
"Model '{}' does not use the '-cloud' suffix. Cloud-only models may fail to load.",
model
);
}
if let Some(descriptors) = &tools {
if !descriptors.is_empty() {
debug!(
"Ignoring {} MCP tool descriptors for Ollama request (tool calling unsupported)",
descriptors.len()
);
}
}
let converted_messages = messages.into_iter().map(convert_message).collect();
let mut request = OllamaChatRequest::new(model.clone(), converted_messages);
if let Some(options) = build_model_options(&parameters)? {
request.options = Some(options);
}
Ok((model, request))
}
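/// List locally installed models and enrich each entry with `show_model_info` details,
/// tolerating per-model lookup failures.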
async fn fetch_models(&self) -> Result<Vec<ModelInfo>> {
let models = self
.client
.list_local_models()
.await
.map_err(|err| self.map_ollama_error("list models", err, None))?;
let client = self.client.clone();
let fetched = join_all(models.into_iter().map(|local| {
let client = client.clone();
async move {
let name = local.name.clone();
let detail = match client.show_model_info(name.clone()).await {
Ok(info) => Some(info),
Err(err) => {
debug!("Failed to fetch Ollama model info for '{name}': {err}");
None
}
};
(local, detail)
}
}))
.await;
Ok(fetched
.into_iter()
.map(|(local, detail)| self.convert_model(local, detail))
.collect())
}
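/// Map a `LocalModel` (plus optional detail) into the provider-agnostic `ModelInfo`,
/// merging reported and heuristic capabilities.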
fn convert_model(&self, model: LocalModel, detail: Option<OllamaModelInfo>) -> ModelInfo {
let scope = match self.mode {
OllamaMode::Local => "local",
OllamaMode::Cloud => "cloud",
};
let name = model.name;
let mut capabilities: Vec<String> = detail
.as_ref()
.map(|info| {
info.capabilities
.iter()
.map(|cap| cap.to_ascii_lowercase())
.collect()
})
.unwrap_or_default();
push_capability(&mut capabilities, "chat");
for heuristic in heuristic_capabilities(&name) {
push_capability(&mut capabilities, &heuristic);
}
let description = build_model_description(scope, detail.as_ref());
ModelInfo {
id: name.clone(),
name,
description: Some(description),
provider: "ollama".to_string(),
context_window: None,
capabilities,
supports_tools: false,
}
}
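/// Translate an `ollama-rs` chat response into a `ChatResponse`, deriving token usage
/// from the final response data when present.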
fn convert_ollama_response(response: OllamaChatResponse, streaming: bool) -> ChatResponse {
let usage = response.final_data.as_ref().map(|data| {
let prompt = clamp_to_u32(data.prompt_eval_count);
let completion = clamp_to_u32(data.eval_count);
TokenUsage {
prompt_tokens: prompt,
completion_tokens: completion,
total_tokens: prompt.saturating_add(completion),
}
});
ChatResponse {
message: convert_ollama_message(response.message),
usage,
is_streaming: streaming,
is_final: if streaming { response.done } else { true },
}
}
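/// Map `ollama-rs` errors onto this crate's `Error` variants, delegating HTTP status
/// handling to `map_http_failure`.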
fn map_ollama_error(&self, action: &str, err: OllamaError, model: Option<&str>) -> Error {
match err {
OllamaError::ReqwestError(request_err) => {
if let Some(status) = request_err.status() {
self.map_http_failure(action, status, request_err.to_string(), model)
} else if request_err.is_timeout() {
Error::Timeout(format!("Ollama {action} timed out: {request_err}"))
} else {
Error::Network(format!("Ollama {action} request failed: {request_err}"))
}
}
OllamaError::InternalError(internal) => Error::Provider(anyhow!(internal.message)),
OllamaError::Other(message) => Error::Provider(anyhow!(message)),
OllamaError::JsonError(err) => Error::Serialization(err),
OllamaError::ToolCallError(err) => Error::Provider(anyhow!(err)),
}
}
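/// Translate an HTTP failure status into an actionable error (missing model, auth
/// problem, bad request, timeout, or generic network failure).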
fn map_http_failure(
&self,
action: &str,
status: StatusCode,
detail: String,
model: Option<&str>,
) -> Error {
match status {
StatusCode::NOT_FOUND => {
if let Some(model) = model {
Error::InvalidInput(format!(
"Model '{model}' was not found at {}. Verify the name or pull it with `ollama pull`.",
self.base_url
))
} else {
Error::InvalidInput(format!(
"{action} returned 404 from {}: {detail}",
self.base_url
))
}
}
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => Error::Auth(format!(
"Ollama rejected the request ({status}): {detail}. Check your API key and account permissions."
)),
StatusCode::BAD_REQUEST => Error::InvalidInput(format!(
"{action} rejected by Ollama ({status}): {detail}"
)),
StatusCode::SERVICE_UNAVAILABLE | StatusCode::GATEWAY_TIMEOUT => Error::Timeout(
format!(
"Ollama {action} timed out ({status}). The model may still be loading."
),
),
_ => Error::Network(format!(
"Ollama {action} failed ({status}): {detail}"
)),
}
}
}
impl LLMProvider for OllamaProvider {
type Stream = Pin<Box<dyn Stream<Item = Result<ChatResponse>> + Send>>;
type ListModelsFuture<'a>
= BoxFuture<'a, Result<Vec<ModelInfo>>>
where
Self: 'a;
type ChatFuture<'a>
= BoxFuture<'a, Result<ChatResponse>>
where
Self: 'a;
type ChatStreamFuture<'a>
= BoxFuture<'a, Result<Self::Stream>>
where
Self: 'a;
type HealthCheckFuture<'a>
= BoxFuture<'a, Result<()>>
where
Self: 'a;
fn name(&self) -> &str {
"ollama"
}
fn list_models(&self) -> Self::ListModelsFuture<'_> {
Box::pin(async move {
self.model_manager
.get_or_refresh(false, || async { self.fetch_models().await })
.await
})
}
fn chat(&self, request: ChatRequest) -> Self::ChatFuture<'_> {
Box::pin(async move {
let ChatRequest {
model,
messages,
parameters,
tools,
} = request;
let (model_id, ollama_request) =
self.prepare_chat_request(model, messages, parameters, tools)?;
let response = self
.client
.send_chat_messages(ollama_request)
.await
.map_err(|err| self.map_ollama_error("chat", err, Some(&model_id)))?;
Ok(Self::convert_ollama_response(response, false))
})
}
fn chat_stream(&self, request: ChatRequest) -> Self::ChatStreamFuture<'_> {
Box::pin(async move {
let ChatRequest {
model,
messages,
parameters,
tools,
} = request;
let (model_id, ollama_request) =
self.prepare_chat_request(model, messages, parameters, tools)?;
let stream = self
.client
.send_chat_messages_stream(ollama_request)
.await
.map_err(|err| self.map_ollama_error("chat_stream", err, Some(&model_id)))?;
let mapped = stream.map(|item| match item {
Ok(chunk) => Ok(Self::convert_ollama_response(chunk, true)),
Err(_) => Err(Error::Provider(anyhow!(
"Ollama returned a malformed streaming chunk"
))),
});
Ok(Box::pin(mapped) as Self::Stream)
})
}
fn health_check(&self) -> Self::HealthCheckFuture<'_> {
Box::pin(async move {
let url = self.api_url("version");
let response = self
.http_client
.get(&url)
.send()
.await
.map_err(|err| map_reqwest_error("health check", err))?;
if response.status().is_success() {
return Ok(());
}
let status = response.status();
let detail = response.text().await.unwrap_or_else(|err| err.to_string());
Err(self.map_http_failure("health check", status, detail, None))
})
}
fn config_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"base_url": {
"type": "string",
"description": "Base URL for the Ollama API (ignored when api_key is provided)",
"default": self.mode.default_base_url()
},
"timeout_secs": {
"type": "integer",
"description": "HTTP request timeout in seconds",
"minimum": 5,
"default": DEFAULT_TIMEOUT_SECS
},
"model_cache_ttl_secs": {
"type": "integer",
"description": "Seconds to cache model listings",
"minimum": 5,
"default": DEFAULT_MODEL_CACHE_TTL_SECS
}
}
})
}
}
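/// Merge chat parameter extras, temperature, and `max_tokens` (as `num_predict`) into
/// `ModelOptions`; returns `Ok(None)` when nothing is set.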
fn build_model_options(parameters: &ChatParameters) -> Result<Option<ModelOptions>> {
let mut options = JsonMap::new();
for (key, value) in &parameters.extra {
options.insert(key.clone(), value.clone());
}
if let Some(temperature) = parameters.temperature {
options.insert("temperature".to_string(), json!(temperature));
}
if let Some(max_tokens) = parameters.max_tokens {
let capped = i32::try_from(max_tokens).unwrap_or(i32::MAX);
options.insert("num_predict".to_string(), json!(capped));
}
if options.is_empty() {
return Ok(None);
}
serde_json::from_value(Value::Object(options))
.map(Some)
.map_err(|err| Error::Config(format!("Invalid Ollama options: {err}")))
}
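/// Convert an internal `Message` into the `ollama-rs` chat format, forwarding tool calls
/// and any `thinking` metadata.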
fn convert_message(message: Message) -> OllamaMessage {
let Message {
role,
content,
metadata,
tool_calls,
..
} = message;
let role = match role {
Role::User => OllamaRole::User,
Role::Assistant => OllamaRole::Assistant,
Role::System => OllamaRole::System,
Role::Tool => OllamaRole::Tool,
};
let tool_calls = tool_calls
.unwrap_or_default()
.into_iter()
.map(|tool_call| OllamaToolCall {
function: OllamaToolCallFunction {
name: tool_call.name,
arguments: tool_call.arguments,
},
})
.collect();
let thinking = metadata
.get("thinking")
.and_then(|value| value.as_str().map(|s| s.to_owned()));
OllamaMessage {
role,
content,
tool_calls,
images: None,
thinking,
}
}
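/// Convert an `ollama-rs` chat message into an internal `Message`, assigning synthetic
/// tool-call IDs and preserving `thinking` output in metadata.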
fn convert_ollama_message(message: OllamaMessage) -> Message {
let role = match message.role {
OllamaRole::Assistant => Role::Assistant,
OllamaRole::System => Role::System,
OllamaRole::Tool => Role::Tool,
OllamaRole::User => Role::User,
};
let tool_calls = if message.tool_calls.is_empty() {
None
} else {
Some(
message
.tool_calls
.into_iter()
.enumerate()
.map(|(idx, tool_call)| ToolCall {
id: format!("tool-call-{idx}"),
name: tool_call.function.name,
arguments: tool_call.function.arguments,
})
.collect::<Vec<_>>(),
)
};
let mut metadata = HashMap::new();
if let Some(thinking) = message.thinking {
metadata.insert("thinking".to_string(), Value::String(thinking));
}
Message {
id: Uuid::new_v4(),
role,
content: message.content,
metadata,
timestamp: SystemTime::now(),
tool_calls,
}
}
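/// Clamp a token count to `u32::MAX` when it does not fit.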
fn clamp_to_u32(value: u64) -> u32 {
u32::try_from(value).unwrap_or(u32::MAX)
}
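/// Add a capability in lowercase unless an equivalent entry (case-insensitive) already exists.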
fn push_capability(capabilities: &mut Vec<String>, capability: &str) {
let candidate = capability.to_ascii_lowercase();
if !capabilities
.iter()
.any(|existing| existing.eq_ignore_ascii_case(&candidate))
{
capabilities.push(candidate);
}
}
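/// Infer extra capabilities (vision, thinking, audio) from well-known substrings in the model name.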
fn heuristic_capabilities(name: &str) -> Vec<String> {
let lowercase = name.to_ascii_lowercase();
let mut detected = Vec::new();
if lowercase.contains("vision")
|| lowercase.contains("multimodal")
|| lowercase.contains("image")
{
detected.push("vision".to_string());
}
if lowercase.contains("think")
|| lowercase.contains("reason")
|| lowercase.contains("deepseek-r1")
|| lowercase.contains("r1")
{
detected.push("thinking".to_string());
}
if lowercase.contains("audio") || lowercase.contains("speech") || lowercase.contains("voice") {
detected.push("audio".to_string());
}
detected
}
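/// Compose a short description from the model's family, parameter size, and variant when
/// detail is available; otherwise fall back to a generic label.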
fn build_model_description(scope: &str, detail: Option<&OllamaModelInfo>) -> String {
if let Some(info) = detail {
let mut parts = Vec::new();
if let Some(family) = info
.model_info
.get("family")
.and_then(|value| value.as_str())
{
parts.push(family.to_string());
}
if let Some(parameter_size) = info
.model_info
.get("parameter_size")
.and_then(|value| value.as_str())
{
parts.push(parameter_size.to_string());
}
if let Some(variant) = info
.model_info
.get("variant")
.and_then(|value| value.as_str())
{
parts.push(variant.to_string());
}
if !parts.is_empty() {
return format!("Ollama ({scope}) {}", parts.join(" · "));
}
}
format!("Ollama ({scope}) model")
}
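/// Read an environment variable, treating unset or blank values as `None`.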
fn env_var_non_empty(name: &str) -> Option<String> {
env::var(name)
.ok()
.map(|value| value.trim().to_string())
.filter(|value| !value.is_empty())
}
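/// Resolve a configured API key, expanding `${VAR}` and `$VAR` references against the
/// environment and discarding blank values.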
fn resolve_api_key(configured: Option<String>) -> Option<String> {
let raw = configured?.trim().to_string();
if raw.is_empty() {
return None;
}
if let Some(variable) = raw
.strip_prefix("${")
.and_then(|value| value.strip_suffix('}'))
.or_else(|| raw.strip_prefix('$'))
{
let var_name = variable.trim();
if var_name.is_empty() {
return None;
}
return env_var_non_empty(var_name);
}
Some(raw)
}
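/// Classify a raw `reqwest` failure as either a timeout or a generic network error.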
fn map_reqwest_error(action: &str, err: reqwest::Error) -> Error {
if err.is_timeout() {
Error::Timeout(format!("Ollama {action} request timed out: {err}"))
} else {
Error::Network(format!("Ollama {action} request failed: {err}"))
}
}
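/// Normalize a base URL: apply the mode's default when missing, default the scheme to
/// `https`, strip a trailing `/api` segment and any query or fragment, reject extra path
/// segments, and require `https` for Ollama Cloud.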
fn normalize_base_url(
input: Option<&str>,
mode_hint: OllamaMode,
) -> std::result::Result<String, String> {
let mut candidate = input
.map(str::trim)
.filter(|value| !value.is_empty())
.map(|value| value.to_string())
.unwrap_or_else(|| mode_hint.default_base_url().to_string());
if !candidate.starts_with("http://") && !candidate.starts_with("https://") {
candidate = format!("https://{candidate}");
}
let mut url =
Url::parse(&candidate).map_err(|err| format!("Invalid Ollama URL '{candidate}': {err}"))?;
if url.cannot_be_a_base() {
return Err(format!("URL '{candidate}' cannot be used as a base URL"));
}
if mode_hint == OllamaMode::Cloud && url.scheme() != "https" {
return Err("Ollama Cloud requires https:// base URLs".to_string());
}
let path = url.path().trim_end_matches('/');
if path == "/api" {
url.set_path("/");
} else if !path.is_empty() && path != "/" {
return Err("Ollama base URLs must not include additional path segments".to_string());
}
url.set_query(None);
url.set_fragment(None);
Ok(url.to_string().trim_end_matches('/').to_string())
}
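/// Join the base URL and endpoint, adding the `/api` prefix unless the base already ends with it.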
fn build_api_endpoint(base_url: &str, endpoint: &str) -> String {
let trimmed_base = base_url.trim_end_matches('/');
let trimmed_endpoint = endpoint.trim_start_matches('/');
if trimmed_base.ends_with("/api") {
format!("{trimmed_base}/{trimmed_endpoint}")
} else {
format!("{trimmed_base}/api/{trimmed_endpoint}")
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn resolve_api_key_prefers_literal_value() {
assert_eq!(
resolve_api_key(Some("direct-key".into())),
Some("direct-key".into())
);
}
#[test]
fn resolve_api_key_expands_env_var() {
std::env::set_var("OLLAMA_TEST_KEY", "secret");
assert_eq!(
resolve_api_key(Some("${OLLAMA_TEST_KEY}".into())),
Some("secret".into())
);
std::env::remove_var("OLLAMA_TEST_KEY");
}
#[test]
fn normalize_base_url_removes_api_path() {
let url = normalize_base_url(Some("https://ollama.com/api"), OllamaMode::Cloud).unwrap();
assert_eq!(url, "https://ollama.com");
}
#[test]
fn normalize_base_url_rejects_cloud_without_https() {
let err = normalize_base_url(Some("http://ollama.com"), OllamaMode::Cloud).unwrap_err();
assert!(err.contains("https"));
}
#[test]
fn build_model_options_merges_parameters() {
let mut parameters = ChatParameters::default();
parameters.temperature = Some(0.3);
parameters.max_tokens = Some(128);
parameters
.extra
.insert("num_ctx".into(), Value::from(4096_u64));
let options = build_model_options(&parameters)
.expect("options built")
.expect("options present");
let serialized = serde_json::to_value(&options).expect("serialize options");
let temperature = serialized["temperature"]
.as_f64()
.expect("temperature present");
assert!((temperature - 0.3).abs() < 1e-6);
assert_eq!(serialized["num_predict"], json!(128));
assert_eq!(serialized["num_ctx"], json!(4096));
}
#[test]
fn heuristic_capabilities_detects_thinking_models() {
let caps = heuristic_capabilities("deepseek-r1");
assert!(caps.iter().any(|cap| cap == "thinking"));
}
#[test]
fn push_capability_avoids_duplicates() {
let mut caps = vec!["chat".to_string()];
push_capability(&mut caps, "Chat");
push_capability(&mut caps, "Vision");
push_capability(&mut caps, "vision");
assert_eq!(caps.len(), 2);
assert!(caps.iter().any(|cap| cap == "vision"));
}
}