feat(ollama): add cloud provider with API key handling and auth‑aware health check
Introduce `OllamaCloudProvider` that resolves the API key from configuration or the `OLLAMA_CLOUD_API_KEY` environment variable, constructs provider metadata (including timeout as numeric), and maps auth errors to `ProviderStatus::RequiresSetup`. Export the new provider in the `ollama` module. Add shared HTTP error mapping utilities (`map_http_error`, `truncated_body`) and update local provider metadata to store timeout as a number.
This commit is contained in:
108
crates/owlen-providers/src/ollama/cloud.rs
Normal file
108
crates/owlen-providers/src/ollama/cloud.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use std::{env, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use owlen_core::{
|
||||
Error as CoreError, Result as CoreResult,
|
||||
config::OLLAMA_CLOUD_BASE_URL,
|
||||
provider::{
|
||||
GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderMetadata,
|
||||
ProviderStatus, ProviderType,
|
||||
},
|
||||
};
|
||||
use serde_json::{Number, Value};
|
||||
|
||||
use super::OllamaClient;
|
||||
|
||||
const API_KEY_ENV: &str = "OLLAMA_CLOUD_API_KEY";
|
||||
|
||||
/// ModelProvider implementation for the hosted Ollama Cloud service.
|
||||
pub struct OllamaCloudProvider {
|
||||
client: OllamaClient,
|
||||
}
|
||||
|
||||
impl OllamaCloudProvider {
|
||||
/// Construct a new cloud provider. An API key must be supplied either
|
||||
/// directly or via the `OLLAMA_CLOUD_API_KEY` environment variable.
|
||||
pub fn new(
|
||||
base_url: Option<String>,
|
||||
api_key: Option<String>,
|
||||
request_timeout: Option<Duration>,
|
||||
) -> CoreResult<Self> {
|
||||
let (api_key, key_source) = resolve_api_key(api_key)?;
|
||||
let base_url = base_url.unwrap_or_else(|| OLLAMA_CLOUD_BASE_URL.to_string());
|
||||
|
||||
let mut metadata =
|
||||
ProviderMetadata::new("ollama_cloud", "Ollama (Cloud)", ProviderType::Cloud, true);
|
||||
metadata
|
||||
.metadata
|
||||
.insert("base_url".into(), Value::String(base_url.clone()));
|
||||
metadata.metadata.insert(
|
||||
"api_key_source".into(),
|
||||
Value::String(key_source.to_string()),
|
||||
);
|
||||
metadata
|
||||
.metadata
|
||||
.insert("api_key_env".into(), Value::String(API_KEY_ENV.to_string()));
|
||||
|
||||
if let Some(timeout) = request_timeout {
|
||||
let timeout_ms = timeout.as_millis().min(u128::from(u64::MAX)) as u64;
|
||||
metadata.metadata.insert(
|
||||
"request_timeout_ms".into(),
|
||||
Value::Number(Number::from(timeout_ms)),
|
||||
);
|
||||
}
|
||||
|
||||
let client = OllamaClient::new(&base_url, Some(api_key), metadata, request_timeout)?;
|
||||
|
||||
Ok(Self { client })
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ModelProvider for OllamaCloudProvider {
|
||||
fn metadata(&self) -> &ProviderMetadata {
|
||||
self.client.metadata()
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> CoreResult<ProviderStatus> {
|
||||
match self.client.health_check().await {
|
||||
Ok(status) => Ok(status),
|
||||
Err(CoreError::Auth(_)) => Ok(ProviderStatus::RequiresSetup),
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> CoreResult<Vec<ModelInfo>> {
|
||||
self.client.list_models().await
|
||||
}
|
||||
|
||||
async fn generate_stream(&self, request: GenerateRequest) -> CoreResult<GenerateStream> {
|
||||
self.client.generate_stream(request).await
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_api_key(api_key: Option<String>) -> CoreResult<(String, &'static str)> {
|
||||
let key_from_config = api_key
|
||||
.as_ref()
|
||||
.map(|value| value.trim())
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(str::to_string);
|
||||
|
||||
if let Some(key) = key_from_config {
|
||||
return Ok((key, "config"));
|
||||
}
|
||||
|
||||
let key_from_env = env::var(API_KEY_ENV)
|
||||
.ok()
|
||||
.map(|value| value.trim().to_string())
|
||||
.filter(|value| !value.is_empty());
|
||||
|
||||
if let Some(key) = key_from_env {
|
||||
return Ok((key, "env"));
|
||||
}
|
||||
|
||||
Err(CoreError::Config(
|
||||
"Ollama Cloud API key not configured. Set OLLAMA_CLOUD_API_KEY or configure an API key."
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
@@ -6,7 +6,7 @@ use owlen_core::provider::{
|
||||
ProviderType,
|
||||
};
|
||||
use owlen_core::{Error as CoreError, Result as CoreResult};
|
||||
use serde_json::Value;
|
||||
use serde_json::{Number, Value};
|
||||
use tokio::time::timeout;
|
||||
|
||||
use super::OllamaClient;
|
||||
@@ -37,9 +37,10 @@ impl OllamaLocalProvider {
|
||||
.metadata
|
||||
.insert("base_url".into(), Value::String(base_url.clone()));
|
||||
if let Some(timeout) = request_timeout {
|
||||
let timeout_ms = timeout.as_millis().min(u128::from(u64::MAX)) as u64;
|
||||
metadata.metadata.insert(
|
||||
"request_timeout_ms".into(),
|
||||
Value::String(timeout.as_millis().to_string()),
|
||||
Value::Number(Number::from(timeout_ms)),
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
pub mod cloud;
|
||||
pub mod local;
|
||||
pub mod shared;
|
||||
|
||||
pub use cloud::OllamaCloudProvider;
|
||||
pub use local::OllamaLocalProvider;
|
||||
pub use shared::OllamaClient;
|
||||
|
||||
@@ -84,12 +84,7 @@ impl OllamaClient {
|
||||
let bytes = response.bytes().await.map_err(map_reqwest_error)?;
|
||||
|
||||
if !status.is_success() {
|
||||
let body = String::from_utf8_lossy(&bytes);
|
||||
return Err(CoreError::Provider(anyhow::anyhow!(
|
||||
"Ollama tags request failed: HTTP {} - {}",
|
||||
status,
|
||||
body
|
||||
)));
|
||||
return Err(map_http_error("tags", status, &bytes));
|
||||
}
|
||||
|
||||
let payload: TagsResponse =
|
||||
@@ -121,12 +116,7 @@ impl OllamaClient {
|
||||
|
||||
if !status.is_success() {
|
||||
let bytes = response.bytes().await.map_err(map_reqwest_error)?;
|
||||
let body = String::from_utf8_lossy(&bytes);
|
||||
return Err(CoreError::Provider(anyhow::anyhow!(
|
||||
"Ollama generate request failed: HTTP {} - {}",
|
||||
status,
|
||||
body
|
||||
)));
|
||||
return Err(map_http_error("generate", status, &bytes));
|
||||
}
|
||||
|
||||
let stream = response.bytes_stream();
|
||||
@@ -351,6 +341,43 @@ fn parse_stream_line(line: &str) -> CoreResult<GenerateChunk> {
|
||||
Ok(chunk)
|
||||
}
|
||||
|
||||
fn map_http_error(endpoint: &str, status: StatusCode, body: &[u8]) -> CoreError {
|
||||
match status {
|
||||
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => CoreError::Auth(format!(
|
||||
"Ollama {} request unauthorized (status {})",
|
||||
endpoint, status
|
||||
)),
|
||||
StatusCode::TOO_MANY_REQUESTS => CoreError::Provider(anyhow::anyhow!(
|
||||
"Ollama {} request rate limited (status {})",
|
||||
endpoint,
|
||||
status
|
||||
)),
|
||||
_ => {
|
||||
let snippet = truncated_body(body);
|
||||
CoreError::Provider(anyhow::anyhow!(
|
||||
"Ollama {} request failed: HTTP {} - {}",
|
||||
endpoint,
|
||||
status,
|
||||
snippet
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn truncated_body(body: &[u8]) -> String {
|
||||
const MAX_CHARS: usize = 512;
|
||||
let text = String::from_utf8_lossy(body);
|
||||
let mut value = String::new();
|
||||
for (idx, ch) in text.chars().enumerate() {
|
||||
if idx >= MAX_CHARS {
|
||||
value.push('…');
|
||||
return value;
|
||||
}
|
||||
value.push(ch);
|
||||
}
|
||||
value
|
||||
}
|
||||
|
||||
fn map_reqwest_error(err: reqwest::Error) -> CoreError {
|
||||
if err.is_timeout() {
|
||||
CoreError::Timeout(err.to_string())
|
||||
|
||||
Reference in New Issue
Block a user