owlen/crates/owlen-providers/src/ollama/shared.rs
vikingowl e0b14a42f2 fix(provider/ollama): keep stream whitespace intact
Acceptance Criteria:
- streaming chunks retain leading whitespace and indentation
- end-of-stream metadata is still propagated
- malformed frames emit defensive logging without crashing

Test Notes:
- cargo test -p owlen-providers
2025-10-23 19:40:53 +02:00


use std::collections::HashMap;
use std::time::Duration;
use futures::StreamExt;
use log::warn;
use owlen_core::provider::{
GenerateChunk, GenerateRequest, GenerateStream, ModelInfo, ProviderMetadata, ProviderStatus,
};
use owlen_core::{Error as CoreError, Result as CoreResult};
use reqwest::{Client, Method, StatusCode, Url};
use serde::Deserialize;
use serde_json::{Map as JsonMap, Value};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
const DEFAULT_TIMEOUT_SECS: u64 = 60;
/// Shared Ollama HTTP client used by both local and cloud providers.
#[derive(Clone)]
pub struct OllamaClient {
http: Client,
base_url: Url,
api_key: Option<String>,
provider_metadata: ProviderMetadata,
}
impl OllamaClient {
/// Create a new client from a base URL, optional API key, provider metadata, and an optional request timeout.
pub fn new(
base_url: impl AsRef<str>,
api_key: Option<String>,
provider_metadata: ProviderMetadata,
request_timeout: Option<Duration>,
) -> CoreResult<Self> {
let base_url = Url::parse(base_url.as_ref())
.map_err(|err| CoreError::Config(format!("invalid base url: {}", err)))?;
let timeout = request_timeout.unwrap_or_else(|| Duration::from_secs(DEFAULT_TIMEOUT_SECS));
let http = Client::builder()
.timeout(timeout)
.build()
.map_err(map_reqwest_error)?;
Ok(Self {
http,
base_url,
api_key,
provider_metadata,
})
}
/// Provider metadata associated with this client.
pub fn metadata(&self) -> &ProviderMetadata {
&self.provider_metadata
}
/// Perform a basic health check to determine provider availability.
pub async fn health_check(&self) -> CoreResult<ProviderStatus> {
let url = self.endpoint("api/tags")?;
let response = self
.request(Method::GET, url)
.send()
.await
.map_err(map_reqwest_error)?;
match response.status() {
status if status.is_success() => Ok(ProviderStatus::Available),
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => Ok(ProviderStatus::RequiresSetup),
_ => Ok(ProviderStatus::Unavailable),
}
}
/// Fetch the available models from the Ollama API.
pub async fn list_models(&self) -> CoreResult<Vec<ModelInfo>> {
let url = self.endpoint("api/tags")?;
let response = self
.request(Method::GET, url)
.send()
.await
.map_err(map_reqwest_error)?;
let status = response.status();
let bytes = response.bytes().await.map_err(map_reqwest_error)?;
if !status.is_success() {
return Err(map_http_error("tags", status, &bytes));
}
let payload: TagsResponse =
serde_json::from_slice(&bytes).map_err(CoreError::Serialization)?;
let models = payload
.models
.into_iter()
.map(|model| self.parse_model_info(model))
.collect();
Ok(models)
}
/// Request a streaming generation session from Ollama.
pub async fn generate_stream(&self, request: GenerateRequest) -> CoreResult<GenerateStream> {
let url = self.endpoint("api/generate")?;
let body = self.build_generate_body(request);
let response = self
.request(Method::POST, url)
.json(&body)
.send()
.await
.map_err(map_reqwest_error)?;
let status = response.status();
if !status.is_success() {
let bytes = response.bytes().await.map_err(map_reqwest_error)?;
return Err(map_http_error("generate", status, &bytes));
}
let stream = response.bytes_stream();
let (tx, rx) = mpsc::channel::<CoreResult<GenerateChunk>>(32);
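// Ollama streams newline-delimited JSON: buffer raw bytes and split on '\n',
// carrying any partial frame over until its terminating newline arrives.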
tokio::spawn(async move {
let mut stream = stream;
let mut buffer: Vec<u8> = Vec::new();
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
buffer.extend_from_slice(&bytes);
while let Some(pos) = buffer.iter().position(|byte| *byte == b'\n') {
let line_bytes: Vec<u8> = buffer.drain(..=pos).collect();
if let Some(line) = prepare_stream_line(&line_bytes) {
match parse_stream_line(&line) {
Ok(item) => {
if tx.send(Ok(item)).await.is_err() {
return;
}
}
Err(err) => {
let _ = tx.send(Err(err)).await;
return;
}
}
}
}
}
Err(err) => {
let _ = tx.send(Err(map_reqwest_error(err))).await;
return;
}
}
}
if !buffer.is_empty() {
let line_bytes = std::mem::take(&mut buffer);
if let Some(line) = prepare_stream_line(&line_bytes) {
match parse_stream_line(&line) {
Ok(item) => {
let _ = tx.send(Ok(item)).await;
}
Err(err) => {
let _ = tx.send(Err(err)).await;
}
}
}
}
});
let stream = ReceiverStream::new(rx);
Ok(Box::pin(stream))
}
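/// Build a request for the given method and URL, attaching bearer auth when an API key is configured.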
fn request(&self, method: Method, url: Url) -> reqwest::RequestBuilder {
let mut builder = self.http.request(method, url);
if let Some(api_key) = &self.api_key {
builder = builder.bearer_auth(api_key);
}
builder
}
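/// Resolve a relative API path against the configured base URL.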
fn endpoint(&self, path: &str) -> CoreResult<Url> {
self.base_url
.join(path)
.map_err(|err| CoreError::Config(format!("invalid endpoint '{}': {}", path, err)))
}
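/// Assemble the JSON body for a streaming `api/generate` call.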
fn build_generate_body(&self, request: GenerateRequest) -> Value {
let GenerateRequest {
model,
prompt,
context,
parameters,
metadata,
} = request;
let mut body = JsonMap::new();
body.insert("model".into(), Value::String(model));
body.insert("stream".into(), Value::Bool(true));
if let Some(prompt) = prompt {
body.insert("prompt".into(), Value::String(prompt));
}
if !context.is_empty() {
let items = context.into_iter().map(Value::String).collect();
body.insert("context".into(), Value::Array(items));
}
if !parameters.is_empty() {
body.insert("options".into(), Value::Object(to_json_map(parameters)));
}
if !metadata.is_empty() {
body.insert("metadata".into(), Value::Object(to_json_map(metadata)));
}
Value::Object(body)
}
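/// Convert an Ollama tag entry into a provider-agnostic `ModelInfo`, preserving digest,
/// modification timestamp, and model details as metadata.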
fn parse_model_info(&self, model: OllamaModel) -> ModelInfo {
let mut metadata = HashMap::new();
if let Some(digest) = model.digest {
metadata.insert("digest".to_string(), Value::String(digest));
}
if let Some(modified) = model.modified_at {
metadata.insert("modified_at".to_string(), Value::String(modified));
}
if let Some(details) = model.details {
let mut details_map = JsonMap::new();
if let Some(format) = details.format {
details_map.insert("format".into(), Value::String(format));
}
if let Some(family) = details.family {
details_map.insert("family".into(), Value::String(family));
}
if let Some(parameter_size) = details.parameter_size {
details_map.insert("parameter_size".into(), Value::String(parameter_size));
}
if let Some(quantisation) = details.quantization_level {
details_map.insert("quantization_level".into(), Value::String(quantisation));
}
if !details_map.is_empty() {
metadata.insert("details".to_string(), Value::Object(details_map));
}
}
ModelInfo {
name: model.name,
size_bytes: model.size,
capabilities: Vec::new(),
description: None,
provider: self.provider_metadata.clone(),
metadata,
}
}
}
#[derive(Debug, Deserialize)]
struct TagsResponse {
#[serde(default)]
models: Vec<OllamaModel>,
}
#[derive(Debug, Deserialize)]
struct OllamaModel {
name: String,
#[serde(default)]
size: Option<u64>,
#[serde(default)]
digest: Option<String>,
#[serde(default)]
modified_at: Option<String>,
#[serde(default)]
details: Option<OllamaModelDetails>,
}
#[derive(Debug, Deserialize)]
struct OllamaModelDetails {
#[serde(default)]
format: Option<String>,
#[serde(default)]
family: Option<String>,
#[serde(default)]
parameter_size: Option<String>,
#[serde(default)]
quantization_level: Option<String>,
}
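/// Convert a `HashMap` of values into a `serde_json` object map.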
fn to_json_map(source: HashMap<String, Value>) -> JsonMap<String, Value> {
source.into_iter().collect()
}
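/// Copy every top-level field except `response` and `done` into chunk metadata.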
fn to_metadata_map(value: &Value) -> HashMap<String, Value> {
let mut metadata = HashMap::new();
if let Value::Object(obj) = value {
for (key, item) in obj {
if key == "response" || key == "done" {
continue;
}
metadata.insert(key.clone(), item.clone());
}
}
metadata
}
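/// Strip only the trailing CR/LF framing from a raw stream line so leading whitespace inside
/// the JSON payload survives; blank or whitespace-only lines yield `None`.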
fn prepare_stream_line(bytes: &[u8]) -> Option<String> {
if bytes.is_empty() {
return None;
}
let mut line = String::from_utf8_lossy(bytes).into_owned();
while line.ends_with('\n') || line.ends_with('\r') {
line.pop();
}
if line.trim().is_empty() {
return None;
}
Some(line)
}
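/// Log a truncated, escaped preview of a stream line that failed to decode.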
fn log_stream_decode_error(line: &str, err: &serde_json::Error) {
const MAX_PREVIEW_CHARS: usize = 256;
let total_chars = line.chars().count();
let truncated = total_chars > MAX_PREVIEW_CHARS;
let mut preview: String = line.chars().take(MAX_PREVIEW_CHARS).collect();
if truncated {
preview.push_str("...");
}
let preview = preview
.replace('\n', "\\n")
.replace('\r', "\\r")
.replace('\t', "\\t");
warn!(
"Failed to parse Ollama stream chunk ({} chars): {}. Preview: \"{}\"",
total_chars, err, preview
);
}
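/// Parse one newline-delimited JSON frame (e.g. `{"response":"  indented text","done":false}`)
/// into a `GenerateChunk`, surfacing provider errors and end-of-stream metadata.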
fn parse_stream_line(line: &str) -> CoreResult<GenerateChunk> {
let value: Value = serde_json::from_str(line).map_err(|err| {
log_stream_decode_error(line, &err);
CoreError::Serialization(err)
})?;
if let Some(error) = value.get("error").and_then(Value::as_str) {
return Err(CoreError::Provider(anyhow::anyhow!(
"ollama generation error: {}",
error
)));
}
let mut chunk = GenerateChunk {
text: value
.get("response")
.and_then(Value::as_str)
.map(str::to_string),
is_final: value.get("done").and_then(Value::as_bool).unwrap_or(false),
metadata: to_metadata_map(&value),
};
if chunk.is_final {
if let Some(Value::Object(done_obj)) = value.get("done") {
for (key, item) in done_obj {
chunk.metadata.insert(key.clone(), item.clone());
}
}
if chunk.text.is_none() && chunk.metadata.is_empty() {
chunk
.metadata
.insert("status".into(), Value::String("done".into()));
}
}
Ok(chunk)
}
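/// Map a non-success HTTP response to a `CoreError`, classifying auth and rate-limit failures.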
fn map_http_error(endpoint: &str, status: StatusCode, body: &[u8]) -> CoreError {
match status {
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => CoreError::Auth(format!(
"Ollama {} request unauthorized (status {})",
endpoint, status
)),
StatusCode::TOO_MANY_REQUESTS => CoreError::Provider(anyhow::anyhow!(
"Ollama {} request rate limited (status {})",
endpoint,
status
)),
_ => {
let snippet = truncated_body(body);
CoreError::Provider(anyhow::anyhow!(
"Ollama {} request failed: HTTP {} - {}",
endpoint,
status,
snippet
))
}
}
}
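/// Return at most the first 512 characters of a response body for error messages.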
fn truncated_body(body: &[u8]) -> String {
const MAX_CHARS: usize = 512;
let text = String::from_utf8_lossy(body);
let mut value = String::new();
for (idx, ch) in text.chars().enumerate() {
if idx >= MAX_CHARS {
value.push('…');
return value;
}
value.push(ch);
}
value
}
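/// Translate `reqwest` failures into timeout, network, or generic provider errors.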
fn map_reqwest_error(err: reqwest::Error) -> CoreError {
if err.is_timeout() {
CoreError::Timeout(err.to_string())
} else if err.is_connect() || err.is_request() {
CoreError::Network(err.to_string())
} else {
CoreError::Provider(err.into())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn prepare_stream_line_preserves_leading_whitespace() {
let mut bytes = br#"{"response":" fn main() {}\n","done":false}"#.to_vec();
bytes.extend_from_slice(b"\r\n");
let line = prepare_stream_line(&bytes).expect("line should be parsed");
assert!(line.starts_with(r#"{"response""#));
assert!(line.ends_with(r#""done":false}"#));
let chunk = parse_stream_line(&line).expect("chunk should parse");
assert_eq!(
chunk.text.as_deref(),
Some(" fn main() {}\n"),
"leading indentation must be preserved"
);
assert!(!chunk.is_final);
}
#[test]
fn parse_stream_line_handles_samples_fixture() {
let data = include_str!("../../../../samples.json");
let values: Vec<Value> =
serde_json::from_str(data).expect("samples fixture should be valid json");
let mut chunks = Vec::new();
for value in values {
let line = serde_json::to_string(&value).expect("serialize chunk");
let chunk = parse_stream_line(&line).expect("parse chunk");
chunks.push(chunk);
}
assert!(
!chunks.is_empty(),
"fixture must produce at least one chunk"
);
assert_eq!(
chunks[0].text.as_deref(),
Some("first"),
"first chunk should match fixture payload"
);
let final_chunk = chunks.last().expect("final chunk must exist");
assert!(
final_chunk.is_final,
"last chunk should be marked final per fixture"
);
assert!(
final_chunk.text.as_deref().unwrap_or_default().is_empty(),
"final chunk should not include stray text"
);
assert!(
final_chunk.metadata.contains_key("final_data"),
"final chunk should surface metadata from fixture"
);
}
}