diff --git a/Cargo.toml b/Cargo.toml index 0a88754..006298f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "crates/owlen-core", "crates/owlen-tui", "crates/owlen-cli", + "crates/owlen-providers", "crates/owlen-mcp-server", "crates/owlen-mcp-llm-server", "crates/owlen-mcp-client", diff --git a/crates/owlen-providers/Cargo.toml b/crates/owlen-providers/Cargo.toml new file mode 100644 index 0000000..ded8348 --- /dev/null +++ b/crates/owlen-providers/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "owlen-providers" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +homepage.workspace = true +description = "Provider implementations for OWLEN" + +[dependencies] +owlen-core = { path = "../owlen-core" } +anyhow = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tokio = { workspace = true } +tokio-stream = { workspace = true } +reqwest = { package = "reqwest", version = "0.11", features = ["json", "stream"] } diff --git a/crates/owlen-providers/src/lib.rs b/crates/owlen-providers/src/lib.rs new file mode 100644 index 0000000..59dd5ac --- /dev/null +++ b/crates/owlen-providers/src/lib.rs @@ -0,0 +1,3 @@ +//! Provider implementations for OWLEN. + +pub mod ollama; diff --git a/crates/owlen-providers/src/ollama/mod.rs b/crates/owlen-providers/src/ollama/mod.rs new file mode 100644 index 0000000..663f6f4 --- /dev/null +++ b/crates/owlen-providers/src/ollama/mod.rs @@ -0,0 +1,3 @@ +pub mod shared; + +pub use shared::OllamaClient; diff --git a/crates/owlen-providers/src/ollama/shared.rs b/crates/owlen-providers/src/ollama/shared.rs new file mode 100644 index 0000000..caec128 --- /dev/null +++ b/crates/owlen-providers/src/ollama/shared.rs @@ -0,0 +1,360 @@ +use std::collections::HashMap; +use std::time::Duration; + +use futures::StreamExt; +use owlen_core::provider::{ + GenerateChunk, GenerateRequest, GenerateStream, ModelInfo, ProviderMetadata, ProviderStatus, +}; +use owlen_core::{Error as CoreError, Result as CoreResult}; +use reqwest::{Client, Method, StatusCode, Url}; +use serde::Deserialize; +use serde_json::{Map as JsonMap, Value}; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; + +const DEFAULT_TIMEOUT_SECS: u64 = 60; + +/// Shared Ollama HTTP client used by both local and cloud providers. +#[derive(Clone)] +pub struct OllamaClient { + http: Client, + base_url: Url, + api_key: Option, + provider_metadata: ProviderMetadata, +} + +impl OllamaClient { + /// Create a new client with the given base URL and optional API key. + pub fn new( + base_url: impl AsRef, + api_key: Option, + provider_metadata: ProviderMetadata, + ) -> CoreResult { + let base_url = Url::parse(base_url.as_ref()) + .map_err(|err| CoreError::Config(format!("invalid base url: {}", err)))?; + + let http = Client::builder() + .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS)) + .build() + .map_err(map_reqwest_error)?; + + Ok(Self { + http, + base_url, + api_key, + provider_metadata, + }) + } + + /// Provider metadata associated with this client. + pub fn metadata(&self) -> &ProviderMetadata { + &self.provider_metadata + } + + /// Perform a basic health check to determine provider availability. + pub async fn health_check(&self) -> CoreResult { + let url = self.endpoint("api/tags")?; + + let response = self + .request(Method::GET, url) + .send() + .await + .map_err(map_reqwest_error)?; + + match response.status() { + status if status.is_success() => Ok(ProviderStatus::Available), + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => Ok(ProviderStatus::RequiresSetup), + _ => Ok(ProviderStatus::Unavailable), + } + } + + /// Fetch the available models from the Ollama API. + pub async fn list_models(&self) -> CoreResult> { + let url = self.endpoint("api/tags")?; + + let response = self + .request(Method::GET, url) + .send() + .await + .map_err(map_reqwest_error)?; + + let status = response.status(); + let bytes = response.bytes().await.map_err(map_reqwest_error)?; + + if !status.is_success() { + let body = String::from_utf8_lossy(&bytes); + return Err(CoreError::Provider(anyhow::anyhow!( + "Ollama tags request failed: HTTP {} - {}", + status, + body + ))); + } + + let payload: TagsResponse = + serde_json::from_slice(&bytes).map_err(CoreError::Serialization)?; + + let models = payload + .models + .into_iter() + .map(|model| self.parse_model_info(model)) + .collect(); + + Ok(models) + } + + /// Request a streaming generation session from Ollama. + pub async fn generate_stream(&self, request: GenerateRequest) -> CoreResult { + let url = self.endpoint("api/generate")?; + + let body = self.build_generate_body(request); + + let response = self + .request(Method::POST, url) + .json(&body) + .send() + .await + .map_err(map_reqwest_error)?; + + let status = response.status(); + + if !status.is_success() { + let bytes = response.bytes().await.map_err(map_reqwest_error)?; + let body = String::from_utf8_lossy(&bytes); + return Err(CoreError::Provider(anyhow::anyhow!( + "Ollama generate request failed: HTTP {} - {}", + status, + body + ))); + } + + let stream = response.bytes_stream(); + let (tx, rx) = mpsc::channel::>(32); + + tokio::spawn(async move { + let mut stream = stream; + let mut buffer: Vec = Vec::new(); + + while let Some(chunk) = stream.next().await { + match chunk { + Ok(bytes) => { + buffer.extend_from_slice(&bytes); + while let Some(pos) = buffer.iter().position(|byte| *byte == b'\n') { + let line_bytes: Vec = buffer.drain(..=pos).collect(); + let line = String::from_utf8_lossy(&line_bytes).trim().to_string(); + if line.is_empty() { + continue; + } + + match parse_stream_line(&line) { + Ok(item) => { + if tx.send(Ok(item)).await.is_err() { + return; + } + } + Err(err) => { + let _ = tx.send(Err(err)).await; + return; + } + } + } + } + Err(err) => { + let _ = tx.send(Err(map_reqwest_error(err))).await; + return; + } + } + } + + if !buffer.is_empty() { + let line = String::from_utf8_lossy(&buffer).trim().to_string(); + if !line.is_empty() { + match parse_stream_line(&line) { + Ok(item) => { + let _ = tx.send(Ok(item)).await; + } + Err(err) => { + let _ = tx.send(Err(err)).await; + } + } + } + } + }); + + let stream = ReceiverStream::new(rx); + Ok(Box::pin(stream)) + } + + fn request(&self, method: Method, url: Url) -> reqwest::RequestBuilder { + let mut builder = self.http.request(method, url); + if let Some(api_key) = &self.api_key { + builder = builder.bearer_auth(api_key); + } + builder + } + + fn endpoint(&self, path: &str) -> CoreResult { + self.base_url + .join(path) + .map_err(|err| CoreError::Config(format!("invalid endpoint '{}': {}", path, err))) + } + + fn build_generate_body(&self, request: GenerateRequest) -> Value { + let GenerateRequest { + model, + prompt, + context, + parameters, + metadata, + } = request; + + let mut body = JsonMap::new(); + body.insert("model".into(), Value::String(model)); + body.insert("stream".into(), Value::Bool(true)); + + if let Some(prompt) = prompt { + body.insert("prompt".into(), Value::String(prompt)); + } + + if !context.is_empty() { + let items = context.into_iter().map(Value::String).collect(); + body.insert("context".into(), Value::Array(items)); + } + + if !parameters.is_empty() { + body.insert("options".into(), Value::Object(to_json_map(parameters))); + } + + if !metadata.is_empty() { + body.insert("metadata".into(), Value::Object(to_json_map(metadata))); + } + + Value::Object(body) + } + + fn parse_model_info(&self, model: OllamaModel) -> ModelInfo { + let mut metadata = HashMap::new(); + + if let Some(digest) = model.digest { + metadata.insert("digest".to_string(), Value::String(digest)); + } + if let Some(modified) = model.modified_at { + metadata.insert("modified_at".to_string(), Value::String(modified)); + } + if let Some(details) = model.details { + let mut details_map = JsonMap::new(); + if let Some(format) = details.format { + details_map.insert("format".into(), Value::String(format)); + } + if let Some(family) = details.family { + details_map.insert("family".into(), Value::String(family)); + } + if let Some(parameter_size) = details.parameter_size { + details_map.insert("parameter_size".into(), Value::String(parameter_size)); + } + if let Some(quantisation) = details.quantization_level { + details_map.insert("quantization_level".into(), Value::String(quantisation)); + } + + if !details_map.is_empty() { + metadata.insert("details".to_string(), Value::Object(details_map)); + } + } + + ModelInfo { + name: model.name, + size_bytes: model.size, + capabilities: Vec::new(), + description: None, + provider: self.provider_metadata.clone(), + metadata, + } + } +} + +#[derive(Debug, Deserialize)] +struct TagsResponse { + #[serde(default)] + models: Vec, +} + +#[derive(Debug, Deserialize)] +struct OllamaModel { + name: String, + #[serde(default)] + size: Option, + #[serde(default)] + digest: Option, + #[serde(default)] + modified_at: Option, + #[serde(default)] + details: Option, +} + +#[derive(Debug, Deserialize)] +struct OllamaModelDetails { + #[serde(default)] + format: Option, + #[serde(default)] + family: Option, + #[serde(default)] + parameter_size: Option, + #[serde(default)] + quantization_level: Option, +} + +fn to_json_map(source: HashMap) -> JsonMap { + source.into_iter().collect() +} + +fn to_metadata_map(value: &Value) -> HashMap { + let mut metadata = HashMap::new(); + + if let Value::Object(obj) = value { + for (key, item) in obj { + if key == "response" || key == "done" { + continue; + } + metadata.insert(key.clone(), item.clone()); + } + } + + metadata +} + +fn parse_stream_line(line: &str) -> CoreResult { + let value: Value = serde_json::from_str(line).map_err(CoreError::Serialization)?; + + if let Some(error) = value.get("error").and_then(Value::as_str) { + return Err(CoreError::Provider(anyhow::anyhow!( + "ollama generation error: {}", + error + ))); + } + + let mut chunk = GenerateChunk { + text: value + .get("response") + .and_then(Value::as_str) + .map(str::to_string), + is_final: value.get("done").and_then(Value::as_bool).unwrap_or(false), + metadata: to_metadata_map(&value), + }; + + if chunk.is_final && chunk.text.is_none() && chunk.metadata.is_empty() { + chunk + .metadata + .insert("status".into(), Value::String("done".into())); + } + + Ok(chunk) +} + +fn map_reqwest_error(err: reqwest::Error) -> CoreError { + if err.is_timeout() { + CoreError::Timeout(err.to_string()) + } else if err.is_connect() || err.is_request() { + CoreError::Network(err.to_string()) + } else { + CoreError::Provider(err.into()) + } +}