diff --git a/crates/owlen-core/Cargo.toml b/crates/owlen-core/Cargo.toml index c1ab5d1..a4e1a94 100644 --- a/crates/owlen-core/Cargo.toml +++ b/crates/owlen-core/Cargo.toml @@ -46,6 +46,7 @@ path-clean = "1.0" tokio-stream = { workspace = true } tokio-tungstenite = "0.21" tungstenite = "0.21" +ollama-rs = { version = "0.3", features = ["stream", "headers"] } [dev-dependencies] tokio-test = { workspace = true } diff --git a/crates/owlen-core/src/config.rs b/crates/owlen-core/src/config.rs index 1af0bb0..65dd99f 100644 --- a/crates/owlen-core/src/config.rs +++ b/crates/owlen-core/src/config.rs @@ -134,10 +134,13 @@ impl Config { config.ensure_defaults(); config.mcp.apply_backward_compat(); config.apply_schema_migrations(&previous_version); + config.expand_provider_env_vars()?; config.validate()?; Ok(config) } else { - Ok(Config::default()) + let mut config = Config::default(); + config.expand_provider_env_vars()?; + Ok(config) } } @@ -201,6 +204,13 @@ impl Config { } } + fn expand_provider_env_vars(&mut self) -> Result<()> { + for (provider_name, provider) in self.providers.iter_mut() { + expand_provider_entry(provider_name, provider)?; + } + Ok(()) + } + /// Validate configuration invariants and surface actionable error messages. pub fn validate(&self) -> Result<()> { self.validate_default_provider()?; @@ -336,6 +346,56 @@ fn default_ollama_provider_config() -> ProviderConfig { } } +fn expand_provider_entry(provider_name: &str, provider: &mut ProviderConfig) -> Result<()> { + if let Some(ref mut base_url) = provider.base_url { + let expanded = expand_env_string( + base_url.as_str(), + &format!("providers.{provider_name}.base_url"), + )?; + *base_url = expanded; + } + + if let Some(ref mut api_key) = provider.api_key { + let expanded = expand_env_string( + api_key.as_str(), + &format!("providers.{provider_name}.api_key"), + )?; + *api_key = expanded; + } + + for (extra_key, extra_value) in provider.extra.iter_mut() { + if let serde_json::Value::String(current) = extra_value { + let expanded = expand_env_string( + current.as_str(), + &format!("providers.{provider_name}.{}", extra_key), + )?; + *current = expanded; + } + } + + Ok(()) +} + +fn expand_env_string(input: &str, field_path: &str) -> Result { + if !input.contains('$') { + return Ok(input.to_string()); + } + + match shellexpand::env(input) { + Ok(expanded) => Ok(expanded.into_owned()), + Err(err) => match err.cause { + std::env::VarError::NotPresent => Err(crate::Error::Config(format!( + "Environment variable {} referenced in {field_path} is not set", + err.var_name + ))), + std::env::VarError::NotUnicode(_) => Err(crate::Error::Config(format!( + "Environment variable {} referenced in {field_path} contains invalid Unicode", + err.var_name + ))), + }, + } +} + /// Default configuration path with user home expansion pub fn default_config_path() -> PathBuf { if let Some(config_dir) = dirs::config_dir() { @@ -836,6 +896,48 @@ pub fn session_timeout(config: &Config) -> Duration { mod tests { use super::*; + #[test] + fn expand_provider_env_vars_resolves_api_key() { + std::env::set_var("OWLEN_TEST_API_KEY", "super-secret"); + + let mut config = Config::default(); + if let Some(ollama) = config.providers.get_mut("ollama") { + ollama.api_key = Some("${OWLEN_TEST_API_KEY}".to_string()); + } + + config + .expand_provider_env_vars() + .expect("environment expansion succeeded"); + + assert_eq!( + config.providers["ollama"].api_key.as_deref(), + Some("super-secret") + ); + + std::env::remove_var("OWLEN_TEST_API_KEY"); + } + + #[test] + fn 
expand_provider_env_vars_errors_for_missing_variable() {
+        std::env::remove_var("OWLEN_TEST_MISSING");
+
+        let mut config = Config::default();
+        if let Some(ollama) = config.providers.get_mut("ollama") {
+            ollama.api_key = Some("${OWLEN_TEST_MISSING}".to_string());
+        }
+
+        let error = config
+            .expand_provider_env_vars()
+            .expect_err("missing variables should error");
+
+        match error {
+            crate::Error::Config(message) => {
+                assert!(message.contains("OWLEN_TEST_MISSING"));
+            }
+            other => panic!("expected config error, got {other:?}"),
+        }
+    }
+
     #[test]
     fn test_storage_platform_specific_paths() {
         let config = Config::default();
diff --git a/crates/owlen-core/src/providers/ollama.rs b/crates/owlen-core/src/providers/ollama.rs
index f55be33..9e98a15 100644
--- a/crates/owlen-core/src/providers/ollama.rs
+++ b/crates/owlen-core/src/providers/ollama.rs
@@ -1,27 +1,27 @@
-//! Unified Ollama provider that transparently supports local and cloud usage.
-//!
-//! When an API key is available (via configuration or environment variables),
-//!
-//! * Requests are sent to `https://ollama.com`.
-//! * The API key is attached as a bearer token.
-//! * Model listings are pulled from the cloud endpoint.
-//!
-//! Without an API key the provider talks to the local Ollama daemon
-//! (`http://localhost:11434` by default).
-
+//! Ollama provider built on top of the `ollama-rs` crate.
 use std::{
     collections::HashMap,
-    env, io,
+    env,
+    pin::Pin,
     time::{Duration, SystemTime},
 };
 
 use anyhow::anyhow;
-use futures_util::{future::BoxFuture, StreamExt};
-use reqwest::{header, Client, StatusCode, Url};
-use serde::{Deserialize, Serialize};
-use serde_json::{json, Value};
-use tokio::sync::mpsc;
-use tokio_stream::wrappers::UnboundedReceiverStream;
+use futures::{future::join_all, future::BoxFuture, Stream, StreamExt};
+use log::{debug, warn};
+use ollama_rs::{
+    error::OllamaError,
+    generation::chat::{
+        request::ChatMessageRequest as OllamaChatRequest, ChatMessage as OllamaMessage,
+        ChatMessageResponse as OllamaChatResponse, MessageRole as OllamaRole,
+    },
+    generation::tools::{ToolCall as OllamaToolCall, ToolCallFunction as OllamaToolCallFunction},
+    headers::{HeaderMap, HeaderValue, AUTHORIZATION},
+    models::{LocalModel, ModelInfo as OllamaModelInfo, ModelOptions},
+    Ollama,
+};
+use reqwest::{Client, StatusCode, Url};
+use serde_json::{json, Map as JsonMap, Value};
 use uuid::Uuid;
 
 use crate::{
@@ -52,13 +52,6 @@ impl OllamaMode {
             Self::Cloud => CLOUD_BASE_URL,
         }
     }
-
-    fn default_scheme(self) -> &'static str {
-        match self {
-            Self::Local => "http",
-            Self::Cloud => "https",
-        }
-    }
 }
 
 #[derive(Debug)]
@@ -87,18 +80,18 @@ impl OllamaOptions {
     }
 }
 
-/// Ollama provider implementation that supports both local and cloud APIs.
+/// Ollama provider implementation backed by `ollama-rs`.
 #[derive(Debug)]
 pub struct OllamaProvider {
     mode: OllamaMode,
-    client: Client,
+    client: Ollama,
+    http_client: Client,
     base_url: String,
-    api_key: Option<String>,
     model_manager: ModelManager,
 }
 
 impl OllamaProvider {
-    /// Create a new provider targeting a specific base URL (local usage).
+    /// Create a provider targeting an explicit base URL (local usage).
     pub fn new(base_url: impl Into<String>) -> Result<Self> {
         let input = base_url.into();
         let normalized =
@@ -118,7 +111,6 @@ impl OllamaProvider {
             OllamaMode::Local
         };
 
-        // When an API key is present we always talk to the hosted cloud endpoint.
         let base_candidate = if mode == OllamaMode::Cloud {
             Some(CLOUD_BASE_URL)
         } else {
@@ -164,34 +156,181 @@ impl OllamaProvider {
             api_key,
         } = options;
 
-        let client = Client::builder()
-            .timeout(request_timeout)
+        let url = Url::parse(&base_url)
+            .map_err(|err| Error::Config(format!("Invalid Ollama base URL '{base_url}': {err}")))?;
+
+        let mut headers = HeaderMap::new();
+        if let Some(ref key) = api_key {
+            let value = HeaderValue::from_str(&format!("Bearer {key}")).map_err(|_| {
+                Error::Config("OLLAMA API key contains invalid characters".to_string())
+            })?;
+            headers.insert(AUTHORIZATION, value);
+        }
+
+        let mut client_builder = Client::builder().timeout(request_timeout);
+        if !headers.is_empty() {
+            client_builder = client_builder.default_headers(headers.clone());
+        }
+
+        let http_client = client_builder
             .build()
-            .map_err(|e| Error::Config(format!("Failed to build HTTP client: {e}")))?;
+            .map_err(|err| Error::Config(format!("Failed to build HTTP client: {err}")))?;
+
+        let port = url.port_or_known_default().ok_or_else(|| {
+            Error::Config(format!("Unable to determine port for Ollama URL '{}'", url))
+        })?;
+
+        let mut ollama_client = Ollama::new_with_client(url.clone(), port, http_client.clone());
+        if !headers.is_empty() {
+            ollama_client.set_headers(Some(headers.clone()));
+        }
 
         Ok(Self {
             mode,
-            client,
+            client: ollama_client,
+            http_client,
             base_url: base_url.trim_end_matches('/').to_string(),
-            api_key,
             model_manager: ModelManager::new(model_cache_ttl),
         })
     }
 
-    /// Access the underlying model manager cache (mainly used by tests).
-    pub fn model_manager(&self) -> &ModelManager {
-        &self.model_manager
-    }
-
     fn api_url(&self, endpoint: &str) -> String {
         build_api_endpoint(&self.base_url, endpoint)
     }
 
-    fn apply_auth(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
-        if let Some(api_key) = &self.api_key {
-            request.bearer_auth(api_key)
-        } else {
-            request
+    fn prepare_chat_request(
+        &self,
+        model: String,
+        messages: Vec<Message>,
+        parameters: ChatParameters,
+        tools: Option<Vec<McpToolDescriptor>>,
+    ) -> Result<(String, OllamaChatRequest)> {
+        if self.mode == OllamaMode::Cloud && !model.contains("-cloud") {
+            warn!(
+                "Model '{}' does not use the '-cloud' suffix. Cloud-only models may fail to load.",
+                model
+            );
+        }
+
+        if let Some(descriptors) = &tools {
+            if !descriptors.is_empty() {
+                debug!(
+                    "Ignoring {} MCP tool descriptors for Ollama request (tool calling unsupported)",
+                    descriptors.len()
+                );
+            }
+        }
+
+        let converted_messages = messages.into_iter().map(convert_message).collect();
+        let mut request = OllamaChatRequest::new(model.clone(), converted_messages);
+
+        if let Some(options) = build_model_options(&parameters)? {
+            request.options = Some(options);
+        }
+
+        Ok((model, request))
+    }
+
+    async fn fetch_models(&self) -> Result<Vec<ModelInfo>> {
+        let models = self
+            .client
+            .list_local_models()
+            .await
+            .map_err(|err| self.map_ollama_error("list models", err, None))?;
+
+        let client = self.client.clone();
+        let fetched = join_all(models.into_iter().map(|local| {
+            let client = client.clone();
+            async move {
+                let name = local.name.clone();
+                let detail = match client.show_model_info(name.clone()).await {
+                    Ok(info) => Some(info),
+                    Err(err) => {
+                        debug!("Failed to fetch Ollama model info for '{name}': {err}");
+                        None
+                    }
+                };
+                (local, detail)
+            }
+        }))
+        .await;
+
+        Ok(fetched
+            .into_iter()
+            .map(|(local, detail)| self.convert_model(local, detail))
+            .collect())
+    }
+
+    fn convert_model(&self, model: LocalModel, detail: Option<OllamaModelInfo>) -> ModelInfo {
+        let scope = match self.mode {
+            OllamaMode::Local => "local",
+            OllamaMode::Cloud => "cloud",
+        };
+
+        let name = model.name;
+        let mut capabilities: Vec<String> = detail
+            .as_ref()
+            .map(|info| {
+                info.capabilities
+                    .iter()
+                    .map(|cap| cap.to_ascii_lowercase())
+                    .collect()
+            })
+            .unwrap_or_default();
+
+        push_capability(&mut capabilities, "chat");
+
+        for heuristic in heuristic_capabilities(&name) {
+            push_capability(&mut capabilities, &heuristic);
+        }
+
+        let description = build_model_description(scope, detail.as_ref());
+
+        ModelInfo {
+            id: name.clone(),
+            name,
+            description: Some(description),
+            provider: "ollama".to_string(),
+            context_window: None,
+            capabilities,
+            supports_tools: false,
+        }
+    }
+
+    fn convert_ollama_response(response: OllamaChatResponse, streaming: bool) -> ChatResponse {
+        let usage = response.final_data.as_ref().map(|data| {
+            let prompt = clamp_to_u32(data.prompt_eval_count);
+            let completion = clamp_to_u32(data.eval_count);
+            TokenUsage {
+                prompt_tokens: prompt,
+                completion_tokens: completion,
+                total_tokens: prompt.saturating_add(completion),
+            }
+        });
+
+        ChatResponse {
+            message: convert_ollama_message(response.message),
+            usage,
+            is_streaming: streaming,
+            is_final: if streaming { response.done } else { true },
+        }
+    }
+
+    fn map_ollama_error(&self, action: &str, err: OllamaError, model: Option<&str>) -> Error {
+        match err {
+            OllamaError::ReqwestError(request_err) => {
+                if let Some(status) = request_err.status() {
+                    self.map_http_failure(action, status, request_err.to_string(), model)
+                } else if request_err.is_timeout() {
+                    Error::Timeout(format!("Ollama {action} timed out: {request_err}"))
+                } else {
+                    Error::Network(format!("Ollama {action} request failed: {request_err}"))
                }
            }
+            OllamaError::InternalError(internal) => Error::Provider(anyhow!(internal.message)),
+            OllamaError::Other(message) => Error::Provider(anyhow!(message)),
+            OllamaError::JsonError(err) => Error::Serialization(err),
+            OllamaError::ToolCallError(err) => Error::Provider(anyhow!(err)),
         }
     }
 
@@ -206,7 +345,7 @@ impl OllamaProvider {
             StatusCode::NOT_FOUND => {
                 if let Some(model) = model {
                     Error::InvalidInput(format!(
-                        "Model '{model}' was not found at {}. Verify the model name or load it with `ollama pull`.",
+                        "Model '{model}' was not found at {}. 
Verify the name or pull it with `ollama pull`.", self.base_url )) } else { @@ -232,163 +371,10 @@ impl OllamaProvider { )), } } - - fn debug_log_request(&self, label: &str, request: &reqwest::Request, body_json: Option<&str>) { - if !debug_requests_enabled() { - return; - } - - eprintln!("--- OWLEN Ollama request ({label}) ---"); - eprintln!("{} {}", request.method(), request.url()); - - match request - .headers() - .get(header::AUTHORIZATION) - .and_then(|value| value.to_str().ok()) - { - Some(value) => eprintln!("Authorization: {}", mask_authorization(value)), - None => eprintln!("Authorization: "), - } - - if let Some(body) = body_json { - eprintln!("Body:\n{body}"); - } - - eprintln!("---------------------------------------"); - } - - fn convert_tools_to_ollama(tools: &[McpToolDescriptor]) -> Vec { - tools - .iter() - .map(|tool| OllamaTool { - tool_type: "function".to_string(), - function: OllamaToolFunction { - name: tool.name.clone(), - description: tool.description.clone(), - parameters: tool.input_schema.clone(), - }, - }) - .collect() - } - - fn convert_message(message: &Message) -> OllamaMessage { - let role = match message.role { - Role::User => "user".to_string(), - Role::Assistant => "assistant".to_string(), - Role::System => "system".to_string(), - Role::Tool => "tool".to_string(), - }; - - let tool_calls = message.tool_calls.as_ref().map(|calls| { - calls - .iter() - .map(|tc| OllamaToolCall { - function: OllamaToolCallFunction { - name: tc.name.clone(), - arguments: tc.arguments.clone(), - }, - }) - .collect() - }); - - OllamaMessage { - role, - content: message.content.clone(), - tool_calls, - } - } - - fn convert_ollama_message(message: &OllamaMessage) -> Message { - let role = match message.role.as_str() { - "assistant" => Role::Assistant, - "system" => Role::System, - "tool" => Role::Tool, - _ => Role::User, - }; - - let tool_calls = message.tool_calls.as_ref().map(|calls| { - calls - .iter() - .enumerate() - .map(|(idx, tc)| ToolCall { - id: format!("tool-call-{idx}"), - name: tc.function.name.clone(), - arguments: tc.function.arguments.clone(), - }) - .collect::>() - }); - - Message { - id: Uuid::new_v4(), - role, - content: message.content.clone(), - metadata: HashMap::new(), - timestamp: SystemTime::now(), - tool_calls, - } - } - - fn build_options(parameters: ChatParameters) -> HashMap { - let mut options = parameters.extra; - - if let Some(temperature) = parameters.temperature { - options.insert("temperature".to_string(), json!(temperature)); - } - - if let Some(max_tokens) = parameters.max_tokens { - options.insert("num_predict".to_string(), json!(max_tokens)); - } - - options - } - - async fn fetch_models(&self) -> Result> { - let url = self.api_url("tags"); - let response = self - .apply_auth(self.client.get(&url)) - .send() - .await - .map_err(|e| map_reqwest_error("list models", e))?; - - if !response.status().is_success() { - let status = response.status(); - let detail = parse_error_body(response).await; - return Err(self.map_http_failure("list models", status, detail, None)); - } - - let body = response - .text() - .await - .map_err(|e| map_reqwest_error("list models", e))?; - - let models: OllamaModelsResponse = - serde_json::from_str(&body).map_err(Error::Serialization)?; - - Ok(models - .models - .into_iter() - .map(|model| { - let family = model - .details - .and_then(|details| details.family) - .unwrap_or_else(|| "unknown".to_string()); - - ModelInfo { - id: model.name.clone(), - name: model.name, - description: Some(format!("Ollama model 
({family})")), - provider: "ollama".to_string(), - context_window: None, - capabilities: vec!["chat".to_string()], - supports_tools: false, - } - }) - .collect()) - } } impl LLMProvider for OllamaProvider { - type Stream = UnboundedReceiverStream>; + type Stream = Pin> + Send>>; type ListModelsFuture<'a> = BoxFuture<'a, Result>> where @@ -427,87 +413,16 @@ impl LLMProvider for OllamaProvider { tools, } = request; - let model_id = model.clone(); - let messages: Vec = messages.iter().map(Self::convert_message).collect(); - let options = Self::build_options(parameters); - - let _ollama_tools = tools - .as_ref() - .filter(|t| !t.is_empty()) - .map(|t| Self::convert_tools_to_ollama(t)); - - let ollama_request = OllamaChatRequest { - model, - messages, - stream: false, - tools: None, - options, - }; - - let url = self.api_url("chat"); - let debug_body = if debug_requests_enabled() { - serde_json::to_string_pretty(&ollama_request).ok() - } else { - None - }; - - let mut request_builder = self.client.post(&url).json(&ollama_request); - request_builder = self.apply_auth(request_builder); - - let request = request_builder - .build() - .map_err(|e| Error::Network(format!("Failed to build chat request: {e}")))?; - - self.debug_log_request("chat", &request, debug_body.as_deref()); + let (model_id, ollama_request) = + self.prepare_chat_request(model, messages, parameters, tools)?; let response = self .client - .execute(request) + .send_chat_messages(ollama_request) .await - .map_err(|e| map_reqwest_error("chat", e))?; + .map_err(|err| self.map_ollama_error("chat", err, Some(&model_id)))?; - if !response.status().is_success() { - let status = response.status(); - let error = parse_error_body(response).await; - return Err(self.map_http_failure("chat", status, error, Some(&model_id))); - } - - let body = response - .text() - .await - .map_err(|e| map_reqwest_error("chat", e))?; - - let mut ollama_response: OllamaChatResponse = - serde_json::from_str(&body).map_err(Error::Serialization)?; - - if let Some(error) = ollama_response.error.take() { - return Err(Error::Provider(anyhow!(error))); - } - - let message = match ollama_response.message { - Some(ref msg) => Self::convert_ollama_message(msg), - None => return Err(Error::Provider(anyhow!("Ollama response missing message"))), - }; - - let usage = if let (Some(prompt_tokens), Some(completion_tokens)) = ( - ollama_response.prompt_eval_count, - ollama_response.eval_count, - ) { - Some(TokenUsage { - prompt_tokens, - completion_tokens, - total_tokens: prompt_tokens + completion_tokens, - }) - } else { - None - }; - - Ok(ChatResponse { - message, - usage, - is_streaming: false, - is_final: true, - }) + Ok(Self::convert_ollama_response(response, false)) }) } @@ -520,161 +435,43 @@ impl LLMProvider for OllamaProvider { tools, } = request; - let model_id = model.clone(); - let messages: Vec = messages.iter().map(Self::convert_message).collect(); - let options = Self::build_options(parameters); + let (model_id, ollama_request) = + self.prepare_chat_request(model, messages, parameters, tools)?; - let _ollama_tools = tools - .as_ref() - .filter(|t| !t.is_empty()) - .map(|t| Self::convert_tools_to_ollama(t)); - - let ollama_request = OllamaChatRequest { - model, - messages, - stream: true, - tools: None, - options, - }; - - let url = self.api_url("chat"); - let debug_body = if debug_requests_enabled() { - serde_json::to_string_pretty(&ollama_request).ok() - } else { - None - }; - - let mut request_builder = self.client.post(&url).json(&ollama_request); - 
request_builder = self.apply_auth(request_builder); - - let request = request_builder - .build() - .map_err(|e| Error::Network(format!("Failed to build streaming request: {e}")))?; - - self.debug_log_request("chat_stream", &request, debug_body.as_deref()); - - let response = self + let stream = self .client - .execute(request) + .send_chat_messages_stream(ollama_request) .await - .map_err(|e| map_reqwest_error("chat_stream", e))?; + .map_err(|err| self.map_ollama_error("chat_stream", err, Some(&model_id)))?; - if !response.status().is_success() { - let status = response.status(); - let error = parse_error_body(response).await; - return Err(self.map_http_failure("chat_stream", status, error, Some(&model_id))); - } - - let (tx, rx) = mpsc::unbounded_channel(); - let mut stream = response.bytes_stream(); - - tokio::spawn(async move { - let mut buffer = String::new(); - - while let Some(chunk) = stream.next().await { - match chunk { - Ok(bytes) => { - if let Ok(text) = String::from_utf8(bytes.to_vec()) { - buffer.push_str(&text); - - while let Some(pos) = buffer.find('\n') { - let mut line = buffer[..pos].trim().to_string(); - buffer.drain(..=pos); - - if line.is_empty() { - continue; - } - - if line.ends_with('\r') { - line.pop(); - } - - match serde_json::from_str::(&line) { - Ok(mut ollama_response) => { - if let Some(error) = ollama_response.error.take() { - let _ = - tx.send(Err(Error::Provider(anyhow!(error)))); - break; - } - - if let Some(message) = ollama_response.message { - let mut chat_response = ChatResponse { - message: Self::convert_ollama_message(&message), - usage: None, - is_streaming: true, - is_final: ollama_response.done, - }; - - if let ( - Some(prompt_tokens), - Some(completion_tokens), - ) = ( - ollama_response.prompt_eval_count, - ollama_response.eval_count, - ) { - chat_response.usage = Some(TokenUsage { - prompt_tokens, - completion_tokens, - total_tokens: prompt_tokens - + completion_tokens, - }); - } - - if tx.send(Ok(chat_response)).is_err() { - break; - } - - if ollama_response.done { - break; - } - } - } - Err(e) => { - let _ = tx.send(Err(Error::Serialization(e))); - break; - } - } - } - } else { - let _ = tx.send(Err(Error::Serialization(serde_json::Error::io( - io::Error::new( - io::ErrorKind::InvalidData, - "Non UTF-8 chunk from Ollama", - ), - )))); - break; - } - } - Err(e) => { - let _ = tx.send(Err(Error::Network(format!("Stream error: {e}")))); - break; - } - } - } + let mapped = stream.map(|item| match item { + Ok(chunk) => Ok(Self::convert_ollama_response(chunk, true)), + Err(_) => Err(Error::Provider(anyhow!( + "Ollama returned a malformed streaming chunk" + ))), }); - let stream = UnboundedReceiverStream::new(rx); - Ok(stream) + Ok(Box::pin(mapped) as Self::Stream) }) } fn health_check(&self) -> Self::HealthCheckFuture<'_> { Box::pin(async move { let url = self.api_url("version"); - let response = self - .apply_auth(self.client.get(&url)) + .http_client + .get(&url) .send() .await - .map_err(|e| map_reqwest_error("health check", e))?; + .map_err(|err| map_reqwest_error("health check", err))?; if response.status().is_success() { - Ok(()) - } else { - let status = response.status(); - let detail = parse_error_body(response).await; - Err(self.map_http_failure("health check", status, detail, None)) + return Ok(()); } + + let status = response.status(); + let detail = response.text().await.unwrap_or_else(|err| err.to_string()); + Err(self.map_http_failure("health check", status, detail, None)) }) } @@ -704,161 +501,187 @@ impl LLMProvider for 
OllamaProvider { } } -#[derive(Debug, Clone, Serialize, Deserialize)] -struct OllamaMessage { - role: String, - content: String, - #[serde(skip_serializing_if = "Option::is_none")] - tool_calls: Option>, -} +fn build_model_options(parameters: &ChatParameters) -> Result> { + let mut options = JsonMap::new(); -#[derive(Debug, Clone, Serialize, Deserialize)] -struct OllamaToolCall { - function: OllamaToolCallFunction, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct OllamaToolCallFunction { - name: String, - arguments: serde_json::Value, -} - -#[derive(Debug, Serialize)] -struct OllamaChatRequest { - model: String, - messages: Vec, - stream: bool, - #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>, - #[serde(flatten)] - options: HashMap, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct OllamaTool { - #[serde(rename = "type")] - tool_type: String, - function: OllamaToolFunction, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct OllamaToolFunction { - name: String, - description: String, - parameters: serde_json::Value, -} - -#[derive(Debug, Deserialize)] -struct OllamaChatResponse { - message: Option, - done: bool, - #[serde(default)] - prompt_eval_count: Option, - #[serde(default)] - eval_count: Option, - #[serde(default)] - error: Option, -} - -#[derive(Debug, Deserialize)] -struct OllamaErrorResponse { - error: Option, -} - -#[derive(Debug, Deserialize)] -struct OllamaModelsResponse { - models: Vec, -} - -#[derive(Debug, Deserialize)] -struct OllamaModelInfo { - name: String, - #[serde(default)] - details: Option, -} - -#[derive(Debug, Deserialize)] -struct OllamaModelDetails { - #[serde(default)] - family: Option, -} - -fn is_ollama_host(host: &str) -> bool { - host.eq_ignore_ascii_case("ollama.com") - || host.eq_ignore_ascii_case("www.ollama.com") - || host.eq_ignore_ascii_case("api.ollama.com") - || host.ends_with(".ollama.com") -} - -fn normalize_base_url( - input: Option<&str>, - mode_hint: OllamaMode, -) -> std::result::Result { - let mut candidate = input - .map(str::trim) - .filter(|value| !value.is_empty()) - .map(|value| value.to_string()) - .unwrap_or_else(|| mode_hint.default_base_url().to_string()); - - if !candidate.contains("://") { - candidate = format!("{}://{}", mode_hint.default_scheme(), candidate); + for (key, value) in ¶meters.extra { + options.insert(key.clone(), value.clone()); } - let mut url = - Url::parse(&candidate).map_err(|err| format!("Invalid base_url '{candidate}': {err}"))?; - - let mut is_cloud = matches!(mode_hint, OllamaMode::Cloud); - - if let Some(host) = url.host_str() { - if is_ollama_host(host) { - is_cloud = true; - } + if let Some(temperature) = parameters.temperature { + options.insert("temperature".to_string(), json!(temperature)); } - if is_cloud { - if url.scheme() != "https" { - url.set_scheme("https") - .map_err(|_| "Ollama Cloud requires an https URL".to_string())?; - } - - match url.host_str() { - Some(host) => { - if host.eq_ignore_ascii_case("www.ollama.com") { - url.set_host(Some("ollama.com")) - .map_err(|_| "Failed to normalize Ollama Cloud host".to_string())?; - } - } - None => { - return Err("Ollama Cloud base_url must include a hostname".to_string()); - } - } + if let Some(max_tokens) = parameters.max_tokens { + let capped = i32::try_from(max_tokens).unwrap_or(i32::MAX); + options.insert("num_predict".to_string(), json!(capped)); } - let current_path = url.path().to_string(); - let trimmed_path = current_path.trim_end_matches('/'); - if trimmed_path.is_empty() { - 
url.set_path("");
+    if options.is_empty() {
+        return Ok(None);
+    }
+
+    serde_json::from_value(Value::Object(options))
+        .map(Some)
+        .map_err(|err| Error::Config(format!("Invalid Ollama options: {err}")))
+}
+
+fn convert_message(message: Message) -> OllamaMessage {
+    let Message {
+        role,
+        content,
+        metadata,
+        tool_calls,
+        ..
+    } = message;
+
+    let role = match role {
+        Role::User => OllamaRole::User,
+        Role::Assistant => OllamaRole::Assistant,
+        Role::System => OllamaRole::System,
+        Role::Tool => OllamaRole::Tool,
+    };
+
+    let tool_calls = tool_calls
+        .unwrap_or_default()
+        .into_iter()
+        .map(|tool_call| OllamaToolCall {
+            function: OllamaToolCallFunction {
+                name: tool_call.name,
+                arguments: tool_call.arguments,
+            },
+        })
+        .collect();
+
+    let thinking = metadata
+        .get("thinking")
+        .and_then(|value| value.as_str().map(|s| s.to_owned()));
+
+    OllamaMessage {
+        role,
+        content,
+        tool_calls,
+        images: None,
+        thinking,
+    }
+}
+
+fn convert_ollama_message(message: OllamaMessage) -> Message {
+    let role = match message.role {
+        OllamaRole::Assistant => Role::Assistant,
+        OllamaRole::System => Role::System,
+        OllamaRole::Tool => Role::Tool,
+        OllamaRole::User => Role::User,
+    };
+
+    let tool_calls = if message.tool_calls.is_empty() {
+        None
     } else {
-        url.set_path(trimmed_path);
+        Some(
+            message
+                .tool_calls
+                .into_iter()
+                .enumerate()
+                .map(|(idx, tool_call)| ToolCall {
+                    id: format!("tool-call-{idx}"),
+                    name: tool_call.function.name,
+                    arguments: tool_call.function.arguments,
+                })
+                .collect::<Vec<_>>(),
+        )
+    };
+
+    let mut metadata = HashMap::new();
+    if let Some(thinking) = message.thinking {
+        metadata.insert("thinking".to_string(), Value::String(thinking));
     }
-    url.set_query(None);
-    url.set_fragment(None);
-
-    Ok(url.to_string().trim_end_matches('/').to_string())
+    Message {
+        id: Uuid::new_v4(),
+        role,
+        content: message.content,
+        metadata,
+        timestamp: SystemTime::now(),
+        tool_calls,
+    }
 }
 
-fn build_api_endpoint(base_url: &str, endpoint: &str) -> String {
-    let trimmed_base = base_url.trim_end_matches('/');
-    let trimmed_endpoint = endpoint.trim_start_matches('/');
+fn clamp_to_u32(value: u64) -> u32 {
+    u32::try_from(value).unwrap_or(u32::MAX)
+}
 
-    if trimmed_base.ends_with("/api") {
-        format!("{trimmed_base}/{trimmed_endpoint}")
-    } else {
-        format!("{trimmed_base}/api/{trimmed_endpoint}")
+fn push_capability(capabilities: &mut Vec<String>, capability: &str) {
+    let candidate = capability.to_ascii_lowercase();
+    if !capabilities
+        .iter()
+        .any(|existing| existing.eq_ignore_ascii_case(&candidate))
+    {
+        capabilities.push(candidate);
     }
 }
 
+fn heuristic_capabilities(name: &str) -> Vec<String> {
+    let lowercase = name.to_ascii_lowercase();
+    let mut detected = Vec::new();
+
+    if lowercase.contains("vision")
+        || lowercase.contains("multimodal")
+        || lowercase.contains("image")
+    {
+        detected.push("vision".to_string());
+    }
+
+    if lowercase.contains("think")
+        || lowercase.contains("reason")
+        || lowercase.contains("deepseek-r1")
+        || lowercase.contains("r1")
+    {
+        detected.push("thinking".to_string());
+    }
+
+    if lowercase.contains("audio") || lowercase.contains("speech") || lowercase.contains("voice") {
+        detected.push("audio".to_string());
+    }
+
+    detected
+}
+
+fn build_model_description(scope: &str, detail: Option<&OllamaModelInfo>) -> String {
+    if let Some(info) = detail {
+        let mut parts = Vec::new();
+
+        if let Some(family) = info
+            .model_info
+            .get("family")
+            .and_then(|value| value.as_str())
+        {
+            parts.push(family.to_string());
+        }
+
+        if let Some(parameter_size) = info
+            .model_info
+            .get("parameter_size")
+            .and_then(|value| value.as_str())
+        {
+            parts.push(parameter_size.to_string());
+        }
+
+        if let Some(variant) = info
+            .model_info
+            .get("variant")
+            .and_then(|value| value.as_str())
+        {
+            parts.push(variant.to_string());
+        }
+
+        if !parts.is_empty() {
+            return format!("Ollama ({scope}) – {}", parts.join(" · "));
+        }
+    }
+
+    format!("Ollama ({scope}) model")
+}
+
 fn env_var_non_empty(name: &str) -> Option<String> {
     env::var(name)
         .ok()
@@ -887,129 +710,66 @@ fn resolve_api_key(configured: Option<String>) -> Option<String> {
     Some(raw)
 }
 
-fn debug_requests_enabled() -> bool {
-    env::var("OWLEN_DEBUG_OLLAMA")
-        .ok()
-        .map(|value| {
-            matches!(
-                value.trim(),
-                "1" | "true" | "TRUE" | "True" | "yes" | "YES" | "Yes"
-            )
-        })
-        .unwrap_or(false)
-}
-
-fn mask_token(token: &str) -> String {
-    if token.len() <= 8 {
-        return "***".to_string();
-    }
-
-    let head = &token[..4];
-    let tail = &token[token.len() - 4..];
-    format!("{head}***{tail}")
-}
-
-fn mask_authorization(value: &str) -> String {
-    if let Some(token) = value.strip_prefix("Bearer ") {
-        format!("Bearer {}", mask_token(token))
-    } else {
-        "***".to_string()
-    }
-}
-
 fn map_reqwest_error(action: &str, err: reqwest::Error) -> Error {
     if err.is_timeout() {
-        return Error::Timeout(format!("{action} request timed out"));
+        Error::Timeout(format!("Ollama {action} request timed out: {err}"))
+    } else {
+        Error::Network(format!("Ollama {action} request failed: {err}"))
     }
-
-    if err.is_connect() {
-        return Error::Network(format!("{action} connection failed: {err}"));
-    }
-
-    if err.is_request() || err.is_body() {
-        return Error::Network(format!("{action} request failed: {err}"));
-    }
-
-    Error::Network(format!("{action} unexpected error: {err}"))
 }
 
-async fn parse_error_body(response: reqwest::Response) -> String {
-    match response.bytes().await {
-        Ok(bytes) => {
-            if bytes.is_empty() {
-                return "unknown error".to_string();
-            }
+fn normalize_base_url(
+    input: Option<&str>,
+    mode_hint: OllamaMode,
+) -> std::result::Result<String, String> {
+    let mut candidate = input
+        .map(str::trim)
+        .filter(|value| !value.is_empty())
+        .map(|value| value.to_string())
+        .unwrap_or_else(|| mode_hint.default_base_url().to_string());
 
-            if let Ok(err) = serde_json::from_slice::<OllamaErrorResponse>(&bytes) {
-                if let Some(error) = err.error {
-                    return error;
-                }
-            }
+    if !candidate.starts_with("http://") && !candidate.starts_with("https://") {
+        candidate = format!("https://{candidate}");
+    }
 
-            match String::from_utf8(bytes.to_vec()) {
-                Ok(text) if !text.trim().is_empty() => text,
-                _ => "unknown error".to_string(),
-            }
-        }
-        Err(_) => "unknown error".to_string(),
+    let mut url =
+        Url::parse(&candidate).map_err(|err| format!("Invalid Ollama URL '{candidate}': {err}"))?;
+
+    if url.cannot_be_a_base() {
+        return Err(format!("URL '{candidate}' cannot be used as a base URL"));
+    }
+
+    if mode_hint == OllamaMode::Cloud && url.scheme() != "https" {
+        return Err("Ollama Cloud requires https:// base URLs".to_string());
+    }
+
+    let path = url.path().trim_end_matches('/');
+    if path == "/api" {
+        url.set_path("/");
+    } else if !path.is_empty() && path != "/" {
+        return Err("Ollama base URLs must not include additional path segments".to_string());
+    }
+
+    url.set_query(None);
+    url.set_fragment(None);
+
+    Ok(url.to_string().trim_end_matches('/').to_string())
+}
+
+fn build_api_endpoint(base_url: &str, endpoint: &str) -> String {
+    let trimmed_base = base_url.trim_end_matches('/');
+    let trimmed_endpoint = endpoint.trim_start_matches('/');
+
+    if trimmed_base.ends_with("/api") {
format!("{trimmed_base}/{trimmed_endpoint}") + } else { + format!("{trimmed_base}/api/{trimmed_endpoint}") } } #[cfg(test)] mod tests { use super::*; - use crate::provider::ProviderConfig; - use std::collections::HashMap; - - #[test] - fn normalizes_local_base_url_and_infers_scheme() { - let normalized = - normalize_base_url(Some("localhost:11434"), OllamaMode::Local).expect("valid URL"); - assert_eq!(normalized, "http://localhost:11434"); - } - - #[test] - fn normalizes_cloud_base_url_and_host() { - let normalized = - normalize_base_url(Some("https://ollama.com"), OllamaMode::Cloud).expect("valid URL"); - assert_eq!(normalized, "https://ollama.com"); - } - - #[test] - fn infers_scheme_for_cloud_hosts() { - let normalized = - normalize_base_url(Some("ollama.com"), OllamaMode::Cloud).expect("valid URL"); - assert_eq!(normalized, "https://ollama.com"); - } - - #[test] - fn rewrites_www_cloud_host() { - let normalized = normalize_base_url(Some("https://www.ollama.com"), OllamaMode::Cloud) - .expect("valid URL"); - assert_eq!(normalized, "https://ollama.com"); - } - - #[test] - fn retains_explicit_api_suffix() { - let normalized = normalize_base_url(Some("https://api.ollama.com/api"), OllamaMode::Cloud) - .expect("valid URL"); - assert_eq!(normalized, "https://api.ollama.com/api"); - } - - #[test] - fn builds_api_endpoint_without_duplicate_segments() { - let base = "http://localhost:11434"; - assert_eq!( - build_api_endpoint(base, "chat"), - "http://localhost:11434/api/chat" - ); - - let base_with_api = "http://localhost:11434/api"; - assert_eq!( - build_api_endpoint(base_with_api, "chat"), - "http://localhost:11434/api/chat" - ); - } #[test] fn resolve_api_key_prefers_literal_value() { @@ -1021,28 +781,61 @@ mod tests { #[test] fn resolve_api_key_expands_env_var() { - env::set_var("OLLAMA_TEST_KEY", "env-key"); + std::env::set_var("OLLAMA_TEST_KEY", "secret"); assert_eq!( resolve_api_key(Some("${OLLAMA_TEST_KEY}".into())), - Some("env-key".into()) + Some("secret".into()) ); - env::remove_var("OLLAMA_TEST_KEY"); + std::env::remove_var("OLLAMA_TEST_KEY"); } #[test] - fn cloud_mode_forces_cloud_base_url() { - let mut config = ProviderConfig { - provider_type: "ollama".into(), - base_url: Some("http://localhost:11434".into()), - api_key: Some("dummy".into()), - extra: HashMap::new(), - }; - let provider = OllamaProvider::from_config(&config, None).expect("provider"); - assert!(provider.base_url.starts_with("https://ollama.com")); + fn normalize_base_url_removes_api_path() { + let url = normalize_base_url(Some("https://ollama.com/api"), OllamaMode::Cloud).unwrap(); + assert_eq!(url, "https://ollama.com"); + } - config.api_key = None; - config.base_url = Some("http://localhost:11434".into()); - let provider = OllamaProvider::from_config(&config, None).expect("provider"); - assert!(provider.base_url.starts_with("http://localhost:11434")); + #[test] + fn normalize_base_url_rejects_cloud_without_https() { + let err = normalize_base_url(Some("http://ollama.com"), OllamaMode::Cloud).unwrap_err(); + assert!(err.contains("https")); + } + + #[test] + fn build_model_options_merges_parameters() { + let mut parameters = ChatParameters::default(); + parameters.temperature = Some(0.3); + parameters.max_tokens = Some(128); + parameters + .extra + .insert("num_ctx".into(), Value::from(4096_u64)); + + let options = build_model_options(¶meters) + .expect("options built") + .expect("options present"); + let serialized = serde_json::to_value(&options).expect("serialize options"); + let temperature = 
serialized["temperature"] + .as_f64() + .expect("temperature present"); + assert!((temperature - 0.3).abs() < 1e-6); + assert_eq!(serialized["num_predict"], json!(128)); + assert_eq!(serialized["num_ctx"], json!(4096)); + } + + #[test] + fn heuristic_capabilities_detects_thinking_models() { + let caps = heuristic_capabilities("deepseek-r1"); + assert!(caps.iter().any(|cap| cap == "thinking")); + } + + #[test] + fn push_capability_avoids_duplicates() { + let mut caps = vec!["chat".to_string()]; + push_capability(&mut caps, "Chat"); + push_capability(&mut caps, "Vision"); + push_capability(&mut caps, "vision"); + + assert_eq!(caps.len(), 2); + assert!(caps.iter().any(|cap| cap == "vision")); } } diff --git a/crates/owlen-core/tests/provider_interface.rs b/crates/owlen-core/tests/provider_interface.rs deleted file mode 100644 index 59d5489..0000000 --- a/crates/owlen-core/tests/provider_interface.rs +++ /dev/null @@ -1,43 +0,0 @@ -use futures::StreamExt; -use owlen_core::provider::test_utils::MockProvider; -use owlen_core::{provider::ProviderRegistry, types::*, Router}; -use std::sync::Arc; - -fn request(message: &str) -> ChatRequest { - ChatRequest { - model: "mock-model".to_string(), - messages: vec![Message::new(Role::User, message.to_string())], - parameters: ChatParameters::default(), - tools: None, - } -} - -#[tokio::test] -async fn router_routes_to_registered_provider() { - let mut router = Router::new(); - router.register_provider(MockProvider::default()); - router.set_default_provider("mock".to_string()); - - let resp = router.chat(request("ping")).await.expect("chat succeeded"); - assert_eq!(resp.message.content, "Mock response to: ping"); - - let mut stream = router - .chat_stream(request("pong")) - .await - .expect("stream returned"); - let first = stream.next().await.expect("stream item").expect("ok item"); - assert_eq!(first.message.content, "Mock response to: pong"); -} - -#[tokio::test] -async fn registry_lists_models_from_all_providers() { - let mut registry = ProviderRegistry::new(); - registry.register(MockProvider::default()); - registry.register_arc(Arc::new(MockProvider::default())); - - let models = registry.list_all_models().await.expect("listed"); - assert!( - models.iter().any(|m| m.name == "mock-model"), - "expected mock-model in model list" - ); -} diff --git a/crates/owlen-tui/src/ui.rs b/crates/owlen-tui/src/ui.rs index 2f93d58..4dfe40e 100644 --- a/crates/owlen-tui/src/ui.rs +++ b/crates/owlen-tui/src/ui.rs @@ -9,7 +9,7 @@ use tui_textarea::TextArea; use unicode_width::UnicodeWidthStr; use crate::chat_app::{ChatApp, ModelSelectorItemKind, HELP_TAB_COUNT}; -use owlen_core::types::Role; +use owlen_core::types::{ModelInfo, Role}; use owlen_core::ui::{FocusedPanel, InputMode}; const PRIVACY_TAB_INDEX: usize = HELP_TAB_COUNT - 1; @@ -1371,6 +1371,47 @@ fn render_provider_selector(frame: &mut Frame<'_>, app: &ChatApp) { frame.render_stateful_widget(list, area, &mut state); } +fn model_badge_icons(model: &ModelInfo) -> Vec<&'static str> { + let mut badges = Vec::new(); + + if model.supports_tools { + badges.push("๐Ÿ”ง"); + } + + if model_has_feature(model, &["think", "reason"]) { + badges.push("๐Ÿง "); + } + + if model_has_feature(model, &["vision", "multimodal", "image"]) { + badges.push("๐Ÿ‘๏ธ"); + } + + if model_has_feature(model, &["audio", "speech", "voice"]) { + badges.push("๐ŸŽง"); + } + + badges +} + +fn model_has_feature(model: &ModelInfo, keywords: &[&str]) -> bool { + let name_lower = model.name.to_ascii_lowercase(); + if keywords.iter().any(|kw| 
name_lower.contains(kw)) { + return true; + } + + if let Some(description) = &model.description { + let description_lower = description.to_ascii_lowercase(); + if keywords.iter().any(|kw| description_lower.contains(kw)) { + return true; + } + } + + model.capabilities.iter().any(|cap| { + let lower = cap.to_ascii_lowercase(); + keywords.iter().any(|kw| lower.contains(kw)) + }) +} + fn render_model_selector(frame: &mut Frame<'_>, app: &ChatApp) { let theme = app.theme(); let area = centered_rect(60, 60, frame.area()); @@ -1392,10 +1433,7 @@ fn render_model_selector(frame: &mut Frame<'_>, app: &ChatApp) { } ModelSelectorItemKind::Model { model_index, .. } => { if let Some(model) = app.model_info_by_index(*model_index) { - let mut badges = Vec::new(); - if model.supports_tools { - badges.push("๐Ÿ”ง"); - } + let badges = model_badge_icons(model); let label = if badges.is_empty() { format!(" {}", model.id) @@ -1428,7 +1466,7 @@ fn render_model_selector(frame: &mut Frame<'_>, app: &ChatApp) { .block( Block::default() .title(Span::styled( - "Select Model โ€” ๐Ÿ”ง = Tool Support", + "Select Model โ€” ๐Ÿ”ง tools โ€ข ๐Ÿง  thinking โ€ข ๐Ÿ‘๏ธ vision โ€ข ๐ŸŽง audio", Style::default() .fg(theme.focused_panel_border) .add_modifier(Modifier::BOLD), @@ -1602,6 +1640,67 @@ fn render_consent_dialog(frame: &mut Frame<'_>, app: &ChatApp) { frame.render_widget(paragraph, area); } +#[cfg(test)] +mod tests { + use super::*; + + fn model_with(capabilities: Vec<&str>, description: Option<&str>) -> ModelInfo { + ModelInfo { + id: "model".into(), + name: "model".into(), + description: description.map(|s| s.to_string()), + provider: "test".into(), + context_window: None, + capabilities: capabilities.into_iter().map(|s| s.to_string()).collect(), + supports_tools: false, + } + } + + #[test] + fn badges_include_tool_icon() { + let model = ModelInfo { + id: "tool-model".into(), + name: "tool-model".into(), + description: None, + provider: "test".into(), + context_window: None, + capabilities: vec![], + supports_tools: true, + }; + + assert!(model_badge_icons(&model).contains(&"๐Ÿ”ง")); + } + + #[test] + fn badges_detect_thinking_capability() { + let model = model_with(vec!["Thinking"], None); + let icons = model_badge_icons(&model); + assert!(icons.contains(&"๐Ÿง ")); + } + + #[test] + fn badges_detect_vision_from_description() { + let model = model_with(vec!["chat"], Some("Supports multimodal vision")); + let icons = model_badge_icons(&model); + assert!(icons.contains(&"๐Ÿ‘๏ธ")); + } + + #[test] + fn badges_detect_audio_from_name() { + let model = ModelInfo { + id: "voice-specialist".into(), + name: "Voice-Specialist".into(), + description: None, + provider: "test".into(), + context_window: None, + capabilities: vec![], + supports_tools: false, + }; + let icons = model_badge_icons(&model); + assert!(icons.contains(&"๐ŸŽง")); + } +} + fn render_privacy_settings(frame: &mut Frame<'_>, area: Rect, app: &ChatApp) { let theme = app.theme(); let config = app.config();