feat(ollama): add cloud provider with API key handling and auth‑aware health check

Introduce `OllamaCloudProvider`, which resolves the API key from configuration or, failing that, from the `OLLAMA_CLOUD_API_KEY` environment variable; builds provider metadata (including the request timeout as a numeric millisecond value); and maps authentication errors from the health check to `ProviderStatus::RequiresSetup`. Export the new provider from the `ollama` module. Add shared HTTP error-mapping utilities (`map_http_error`, `truncated_body`) and update the local provider's metadata to store the timeout as a number instead of a string.
This commit is contained in:
2025-10-15 21:07:41 +02:00
parent cdc425ae93
commit b49f58bc16
4 changed files with 152 additions and 14 deletions

View File

@@ -0,0 +1,108 @@
use std::{env, time::Duration};
use async_trait::async_trait;
use owlen_core::{
Error as CoreError, Result as CoreResult,
config::OLLAMA_CLOUD_BASE_URL,
provider::{
GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderMetadata,
ProviderStatus, ProviderType,
},
};
use serde_json::{Number, Value};
use super::OllamaClient;
/// Name of the environment variable consulted for the API key when no usable
/// key is passed to the constructor. Also recorded in the provider metadata
/// under `api_key_env`.
const API_KEY_ENV: &str = "OLLAMA_CLOUD_API_KEY";
/// ModelProvider implementation for the hosted Ollama Cloud service.
///
/// Every `ModelProvider` method delegates to the wrapped [`OllamaClient`];
/// only the health check adds cloud-specific handling of auth errors.
pub struct OllamaCloudProvider {
// Client built in `new` with the resolved base URL, API key and metadata.
client: OllamaClient,
}
impl OllamaCloudProvider {
    /// Create a provider for the hosted Ollama Cloud service.
    ///
    /// * `base_url` - endpoint to talk to; `OLLAMA_CLOUD_BASE_URL` is used when `None`.
    /// * `api_key` - key taken from configuration; when absent or blank the
    ///   `OLLAMA_CLOUD_API_KEY` environment variable is consulted instead.
    /// * `request_timeout` - optional timeout forwarded to the client and, when
    ///   present, recorded in the metadata as `request_timeout_ms` (a number).
    ///
    /// # Errors
    ///
    /// Returns the error from `resolve_api_key` when no key can be found, or
    /// whatever error `OllamaClient::new` reports.
    pub fn new(
        base_url: Option<String>,
        api_key: Option<String>,
        request_timeout: Option<Duration>,
    ) -> CoreResult<Self> {
        let (resolved_key, key_origin) = resolve_api_key(api_key)?;
        let endpoint = match base_url {
            Some(url) => url,
            None => OLLAMA_CLOUD_BASE_URL.to_string(),
        };

        let mut descriptor =
            ProviderMetadata::new("ollama_cloud", "Ollama (Cloud)", ProviderType::Cloud, true);

        // Scope the mutable borrow of the extras map so `descriptor` can be
        // moved into the client afterwards.
        {
            let extras = &mut descriptor.metadata;
            extras.insert("base_url".into(), Value::String(endpoint.clone()));
            extras.insert(
                "api_key_source".into(),
                Value::String(key_origin.to_string()),
            );
            extras.insert("api_key_env".into(), Value::String(API_KEY_ENV.to_string()));

            if let Some(limit) = request_timeout {
                // Saturate at u64::MAX: `as_millis` yields a u128.
                let millis = limit.as_millis().min(u128::from(u64::MAX)) as u64;
                extras.insert(
                    "request_timeout_ms".into(),
                    Value::Number(Number::from(millis)),
                );
            }
        }

        let client = OllamaClient::new(&endpoint, Some(resolved_key), descriptor, request_timeout)?;
        Ok(Self { client })
    }
}
#[async_trait]
impl ModelProvider for OllamaCloudProvider {
    /// Metadata held by the underlying client.
    fn metadata(&self) -> &ProviderMetadata {
        self.client.metadata()
    }

    /// Run the client's health check, reporting an authentication failure as
    /// `ProviderStatus::RequiresSetup` instead of an error. Every other
    /// outcome, success or failure, is passed through unchanged.
    async fn health_check(&self) -> CoreResult<ProviderStatus> {
        let outcome = self.client.health_check().await;
        if matches!(outcome, Err(CoreError::Auth(_))) {
            return Ok(ProviderStatus::RequiresSetup);
        }
        outcome
    }

    /// List the models reported by the underlying client.
    async fn list_models(&self) -> CoreResult<Vec<ModelInfo>> {
        self.client.list_models().await
    }

    /// Start a streaming generation through the underlying client.
    async fn generate_stream(&self, request: GenerateRequest) -> CoreResult<GenerateStream> {
        self.client.generate_stream(request).await
    }
}
/// Pick the API key to use and report where it came from.
///
/// The explicitly supplied key wins when it is non-blank after trimming
/// (source `"config"`); otherwise the `OLLAMA_CLOUD_API_KEY` environment
/// variable is tried (source `"env"`). The returned key is always trimmed.
///
/// # Errors
///
/// Returns `CoreError::Config` when neither source yields a non-blank key.
fn resolve_api_key(api_key: Option<String>) -> CoreResult<(String, &'static str)> {
    if let Some(configured) = api_key.as_deref().map(str::trim) {
        if !configured.is_empty() {
            return Ok((configured.to_string(), "config"));
        }
    }

    // An unset or non-Unicode variable is treated the same as a blank one.
    if let Ok(raw) = env::var(API_KEY_ENV) {
        let from_env = raw.trim();
        if !from_env.is_empty() {
            return Ok((from_env.to_string(), "env"));
        }
    }

    Err(CoreError::Config(
        "Ollama Cloud API key not configured. Set OLLAMA_CLOUD_API_KEY or configure an API key."
            .into(),
    ))
}

View File

@@ -6,7 +6,7 @@ use owlen_core::provider::{
ProviderType, ProviderType,
}; };
use owlen_core::{Error as CoreError, Result as CoreResult}; use owlen_core::{Error as CoreError, Result as CoreResult};
use serde_json::Value; use serde_json::{Number, Value};
use tokio::time::timeout; use tokio::time::timeout;
use super::OllamaClient; use super::OllamaClient;
@@ -37,9 +37,10 @@ impl OllamaLocalProvider {
.metadata .metadata
.insert("base_url".into(), Value::String(base_url.clone())); .insert("base_url".into(), Value::String(base_url.clone()));
if let Some(timeout) = request_timeout { if let Some(timeout) = request_timeout {
let timeout_ms = timeout.as_millis().min(u128::from(u64::MAX)) as u64;
metadata.metadata.insert( metadata.metadata.insert(
"request_timeout_ms".into(), "request_timeout_ms".into(),
Value::String(timeout.as_millis().to_string()), Value::Number(Number::from(timeout_ms)),
); );
} }

View File

@@ -1,5 +1,7 @@
pub mod cloud;
pub mod local; pub mod local;
pub mod shared; pub mod shared;
pub use cloud::OllamaCloudProvider;
pub use local::OllamaLocalProvider; pub use local::OllamaLocalProvider;
pub use shared::OllamaClient; pub use shared::OllamaClient;

View File

@@ -84,12 +84,7 @@ impl OllamaClient {
let bytes = response.bytes().await.map_err(map_reqwest_error)?; let bytes = response.bytes().await.map_err(map_reqwest_error)?;
if !status.is_success() { if !status.is_success() {
let body = String::from_utf8_lossy(&bytes); return Err(map_http_error("tags", status, &bytes));
return Err(CoreError::Provider(anyhow::anyhow!(
"Ollama tags request failed: HTTP {} - {}",
status,
body
)));
} }
let payload: TagsResponse = let payload: TagsResponse =
@@ -121,12 +116,7 @@ impl OllamaClient {
if !status.is_success() { if !status.is_success() {
let bytes = response.bytes().await.map_err(map_reqwest_error)?; let bytes = response.bytes().await.map_err(map_reqwest_error)?;
let body = String::from_utf8_lossy(&bytes); return Err(map_http_error("generate", status, &bytes));
return Err(CoreError::Provider(anyhow::anyhow!(
"Ollama generate request failed: HTTP {} - {}",
status,
body
)));
} }
let stream = response.bytes_stream(); let stream = response.bytes_stream();
@@ -351,6 +341,43 @@ fn parse_stream_line(line: &str) -> CoreResult<GenerateChunk> {
Ok(chunk) Ok(chunk)
} }
/// Translate a non-success HTTP response into a `CoreError`.
///
/// * 401 / 403 become `CoreError::Auth`; the response body is not included.
/// * 429 becomes a `CoreError::Provider` rate-limit message, also without the body.
/// * Any other status becomes a `CoreError::Provider` carrying the status and
///   a length-limited excerpt of the body.
///
/// `endpoint` is only used to label the message (e.g. `"tags"`, `"generate"`).
fn map_http_error(endpoint: &str, status: StatusCode, body: &[u8]) -> CoreError {
    if status == StatusCode::UNAUTHORIZED || status == StatusCode::FORBIDDEN {
        return CoreError::Auth(format!(
            "Ollama {} request unauthorized (status {})",
            endpoint, status
        ));
    }

    if status == StatusCode::TOO_MANY_REQUESTS {
        return CoreError::Provider(anyhow::anyhow!(
            "Ollama {} request rate limited (status {})",
            endpoint,
            status
        ));
    }

    let excerpt = truncated_body(body);
    CoreError::Provider(anyhow::anyhow!(
        "Ollama {} request failed: HTTP {} - {}",
        endpoint,
        status,
        excerpt
    ))
}
/// Decode `body` as UTF-8 (lossily) and cap it at 512 characters.
///
/// Bodies of up to 512 characters are returned whole. Longer bodies are cut
/// after the 512th character and an ellipsis (`…`) is appended. The limit
/// counts `char`s, not bytes, so multi-byte characters are never split.
fn truncated_body(body: &[u8]) -> String {
    const MAX_CHARS: usize = 512;
    let decoded = String::from_utf8_lossy(body);

    // Byte offset of the first character past the limit, if there is one.
    match decoded.char_indices().nth(MAX_CHARS) {
        Some((cut, _)) => {
            let mut snippet = String::with_capacity(cut + '…'.len_utf8());
            snippet.push_str(&decoded[..cut]);
            snippet.push('…');
            snippet
        }
        None => decoded.into_owned(),
    }
}
fn map_reqwest_error(err: reqwest::Error) -> CoreError { fn map_reqwest_error(err: reqwest::Error) -> CoreError {
if err.is_timeout() { if err.is_timeout() {
CoreError::Timeout(err.to_string()) CoreError::Timeout(err.to_string())