//! Ollama provider for OWLEN LLM client
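//!
//! A minimal usage sketch (illustrative only; it assumes a local Ollama daemon on the
//! default port, the `Provider` trait in scope, and `request` standing in for a
//! prepared `ChatRequest`):
//!
//! ```ignore
//! let provider = OllamaProvider::new("http://localhost:11434")?;
//! let models = provider.list_models().await?;
//! let response = provider.chat(request).await?;
//! ```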

use futures_util::StreamExt;
use owlen_core::{
    config::GeneralSettings,
    model::ModelManager,
    provider::{ChatStream, Provider, ProviderConfig},
    types::{ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role, TokenUsage},
    Result,
};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::io;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;

const DEFAULT_TIMEOUT_SECS: u64 = 120;
const DEFAULT_MODEL_CACHE_TTL_SECS: u64 = 60;

/// Ollama provider implementation with enhanced configuration and caching
pub struct OllamaProvider {
    client: Client,
    base_url: String,
    model_manager: ModelManager,
}

/// Options for configuring the Ollama provider
pub struct OllamaOptions {
    pub base_url: String,
    pub request_timeout: Duration,
    pub model_cache_ttl: Duration,
}

impl OllamaOptions {
    pub fn new(base_url: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            request_timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
            model_cache_ttl: Duration::from_secs(DEFAULT_MODEL_CACHE_TTL_SECS),
        }
    }

    pub fn with_general(mut self, general: &GeneralSettings) -> Self {
        self.model_cache_ttl = general.model_cache_ttl();
        self
    }
}
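
// Options are usually derived from configuration via `OllamaProvider::from_config`;
// manual wiring looks roughly like this (sketch, assuming a loaded `GeneralSettings`):
//
//     let options = OllamaOptions::new("http://localhost:11434").with_general(&general);
//     let provider = OllamaProvider::with_options(options)?;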

/// Ollama-specific message format
#[derive(Debug, Clone, Serialize, Deserialize)]
struct OllamaMessage {
    role: String,
    content: String,
}

/// Ollama chat request format
#[derive(Debug, Serialize)]
struct OllamaChatRequest {
    model: String,
    messages: Vec<OllamaMessage>,
    stream: bool,
    #[serde(flatten)]
    options: HashMap<String, Value>,
}
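
// Because of `#[serde(flatten)]`, entries in `options` are serialized at the top level
// of the request body rather than nested under a separate key, e.g.
// `{"model":"llama3","messages":[...],"stream":false,"temperature":0.7}`.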

/// Ollama chat response format
#[derive(Debug, Deserialize)]
struct OllamaChatResponse {
    message: Option<OllamaMessage>,
    done: bool,
    #[serde(default)]
    prompt_eval_count: Option<u32>,
    #[serde(default)]
    eval_count: Option<u32>,
    #[serde(default)]
    error: Option<String>,
}
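
// A successful non-streaming reply deserializes from a body shaped roughly like
//     {"message":{"role":"assistant","content":"..."},"done":true,
//      "prompt_eval_count":26,"eval_count":112}
// and unrecognized fields (model name, timings, and so on) are simply ignored by serde.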

#[derive(Debug, Deserialize)]
struct OllamaErrorResponse {
    error: Option<String>,
}

/// Ollama models list response
#[derive(Debug, Deserialize)]
struct OllamaModelsResponse {
    models: Vec<OllamaModelInfo>,
}

/// Ollama model information
#[derive(Debug, Deserialize)]
struct OllamaModelInfo {
    name: String,
    #[serde(default)]
    details: Option<OllamaModelDetails>,
}

#[derive(Debug, Deserialize)]
struct OllamaModelDetails {
    #[serde(default)]
    family: Option<String>,
}

impl OllamaProvider {
    /// Create a new Ollama provider with sensible defaults
    pub fn new(base_url: impl Into<String>) -> Result<Self> {
        Self::with_options(OllamaOptions::new(base_url))
    }

    /// Create a provider from configuration settings
    pub fn from_config(config: &ProviderConfig, general: Option<&GeneralSettings>) -> Result<Self> {
        let mut options = OllamaOptions::new(
            config
                .base_url
                .clone()
                .unwrap_or_else(|| "http://localhost:11434".to_string()),
        );

        if let Some(timeout) = config
            .extra
            .get("timeout_secs")
            .and_then(|value| value.as_u64())
        {
            options.request_timeout = Duration::from_secs(timeout.max(5));
        }

        if let Some(cache_ttl) = config
            .extra
            .get("model_cache_ttl_secs")
            .and_then(|value| value.as_u64())
        {
            options.model_cache_ttl = Duration::from_secs(cache_ttl.max(5));
        }

        if let Some(general) = general {
            options = options.with_general(general);
        }

        Self::with_options(options)
    }

    /// Create a provider from explicit options
    pub fn with_options(options: OllamaOptions) -> Result<Self> {
        let client = Client::builder()
            .timeout(options.request_timeout)
            .build()
            .map_err(|e| owlen_core::Error::Config(format!("Failed to build HTTP client: {e}")))?;

        Ok(Self {
            client,
            base_url: options.base_url.trim_end_matches('/').to_string(),
            model_manager: ModelManager::new(options.model_cache_ttl),
        })
    }

    /// Accessor for the underlying model manager
    pub fn model_manager(&self) -> &ModelManager {
        &self.model_manager
    }

    fn convert_message(message: &Message) -> OllamaMessage {
        OllamaMessage {
            role: match message.role {
                Role::User => "user".to_string(),
                Role::Assistant => "assistant".to_string(),
                Role::System => "system".to_string(),
            },
            content: message.content.clone(),
        }
    }

    fn convert_ollama_message(message: &OllamaMessage) -> Message {
        let role = match message.role.as_str() {
            "user" => Role::User,
            "assistant" => Role::Assistant,
            "system" => Role::System,
            _ => Role::Assistant,
        };

        Message::new(role, message.content.clone())
    }

    fn build_options(parameters: ChatParameters) -> HashMap<String, Value> {
        let mut options = parameters.extra;

        if let Some(temperature) = parameters.temperature {
            options
                .entry("temperature".to_string())
                .or_insert(json!(temperature as f64));
        }

        if let Some(max_tokens) = parameters.max_tokens {
            options
                .entry("num_predict".to_string())
                .or_insert(json!(max_tokens));
        }

        options
    }
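
    // For example, `temperature: Some(0.7)` and `max_tokens: Some(256)` become
    // `{"temperature": 0.7, "num_predict": 256}`; keys already present in
    // `parameters.extra` take precedence because `or_insert` never overwrites them.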

    async fn fetch_models(&self) -> Result<Vec<ModelInfo>> {
        let url = format!("{}/api/tags", self.base_url);
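
        // `/api/tags` responds with a payload shaped roughly like
        //     {"models":[{"name":"llama3:latest","details":{"family":"llama"}}]}
        // and only the fields modeled by `OllamaModelsResponse` are read.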

        let response = self
            .client
            .get(&url)
            .send()
            .await
            .map_err(|e| owlen_core::Error::Network(format!("Failed to fetch models: {e}")))?;

        if !response.status().is_success() {
            let code = response.status();
            let error = parse_error_body(response).await;
            return Err(owlen_core::Error::Network(format!(
                "Ollama model listing failed ({code}): {error}"
            )));
        }

        let body = response.text().await.map_err(|e| {
            owlen_core::Error::Network(format!("Failed to read models response: {e}"))
        })?;

        let ollama_response: OllamaModelsResponse =
            serde_json::from_str(&body).map_err(owlen_core::Error::Serialization)?;

        let models = ollama_response
            .models
            .into_iter()
            .map(|model| ModelInfo {
                id: model.name.clone(),
                name: model.name.clone(),
                description: model
                    .details
                    .as_ref()
                    .and_then(|d| d.family.as_ref().map(|f| format!("Ollama {f} model"))),
                provider: "ollama".to_string(),
                context_window: None,
                capabilities: vec!["chat".to_string()],
            })
            .collect();

        Ok(models)
    }
}

#[async_trait::async_trait]
impl Provider for OllamaProvider {
    fn name(&self) -> &str {
        "ollama"
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        self.model_manager
            .get_or_refresh(false, || async { self.fetch_models().await })
            .await
    }

    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
        let ChatRequest {
            model,
            messages,
            parameters,
        } = request;

        let messages: Vec<OllamaMessage> = messages.iter().map(Self::convert_message).collect();

        let options = Self::build_options(parameters);

        let ollama_request = OllamaChatRequest {
            model,
            messages,
            stream: false,
            options,
        };

        let url = format!("{}/api/chat", self.base_url);
        let response = self
            .client
            .post(&url)
            .json(&ollama_request)
            .send()
            .await
            .map_err(|e| owlen_core::Error::Network(format!("Chat request failed: {e}")))?;

        if !response.status().is_success() {
            let code = response.status();
            let error = parse_error_body(response).await;
            return Err(owlen_core::Error::Network(format!(
                "Ollama chat failed ({code}): {error}"
            )));
        }

        let body = response.text().await.map_err(|e| {
            owlen_core::Error::Network(format!("Failed to read chat response: {e}"))
        })?;

        let mut ollama_response: OllamaChatResponse =
            serde_json::from_str(&body).map_err(owlen_core::Error::Serialization)?;

        if let Some(error) = ollama_response.error.take() {
            return Err(owlen_core::Error::Provider(anyhow::anyhow!(error)));
        }

        let message = match ollama_response.message {
            Some(ref msg) => Self::convert_ollama_message(msg),
            None => {
                return Err(owlen_core::Error::Provider(anyhow::anyhow!(
                    "Ollama response missing message"
                )))
            }
        };

        let usage = if let (Some(prompt_tokens), Some(completion_tokens)) = (
            ollama_response.prompt_eval_count,
            ollama_response.eval_count,
        ) {
            Some(TokenUsage {
                prompt_tokens,
                completion_tokens,
                total_tokens: prompt_tokens + completion_tokens,
            })
        } else {
            None
        };

        Ok(ChatResponse {
            message,
            usage,
            is_streaming: false,
            is_final: true,
        })
    }

    async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream> {
        let ChatRequest {
            model,
            messages,
            parameters,
        } = request;

        let messages: Vec<OllamaMessage> = messages.iter().map(Self::convert_message).collect();

        let options = Self::build_options(parameters);

        let ollama_request = OllamaChatRequest {
            model,
            messages,
            stream: true,
            options,
        };

        let url = format!("{}/api/chat", self.base_url);

        let response = self
            .client
            .post(&url)
            .json(&ollama_request)
            .send()
            .await
            .map_err(|e| owlen_core::Error::Network(format!("Streaming request failed: {e}")))?;

        if !response.status().is_success() {
            let code = response.status();
            let error = parse_error_body(response).await;
            return Err(owlen_core::Error::Network(format!(
                "Ollama streaming chat failed ({code}): {error}"
            )));
        }

        let (tx, rx) = mpsc::unbounded_channel();
        let mut stream = response.bytes_stream();
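
        // Ollama streams newline-delimited JSON: each line is a complete object such as
        //     {"message":{"role":"assistant","content":"Hel"},"done":false}
        // with token counts typically present only on the final (`"done": true`) chunk.
        // The task below buffers bytes, splits on '\n', and forwards one ChatResponse per line.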

        tokio::spawn(async move {
            let mut buffer = String::new();

            while let Some(chunk) = stream.next().await {
                match chunk {
                    Ok(bytes) => {
                        if let Ok(text) = String::from_utf8(bytes.to_vec()) {
                            buffer.push_str(&text);

                            while let Some(pos) = buffer.find('\n') {
                                let mut line = buffer[..pos].trim().to_string();
                                buffer.drain(..=pos);

                                if line.is_empty() {
                                    continue;
                                }

                                if line.ends_with('\r') {
                                    line.pop();
                                }

                                match serde_json::from_str::<OllamaChatResponse>(&line) {
                                    Ok(mut ollama_response) => {
                                        if let Some(error) = ollama_response.error.take() {
                                            let _ = tx.send(Err(owlen_core::Error::Provider(
                                                anyhow::anyhow!(error),
                                            )));
                                            break;
                                        }

                                        if let Some(message) = ollama_response.message {
                                            let mut chat_response = ChatResponse {
                                                message: Self::convert_ollama_message(&message),
                                                usage: None,
                                                is_streaming: true,
                                                is_final: ollama_response.done,
                                            };

                                            if let (Some(prompt_tokens), Some(completion_tokens)) = (
                                                ollama_response.prompt_eval_count,
                                                ollama_response.eval_count,
                                            ) {
                                                chat_response.usage = Some(TokenUsage {
                                                    prompt_tokens,
                                                    completion_tokens,
                                                    total_tokens: prompt_tokens + completion_tokens,
                                                });
                                            }

                                            if tx.send(Ok(chat_response)).is_err() {
                                                break;
                                            }

                                            if ollama_response.done {
                                                break;
                                            }
                                        }
                                    }
                                    Err(e) => {
                                        let _ = tx.send(Err(owlen_core::Error::Serialization(e)));
                                        break;
                                    }
                                }
                            }
                        } else {
                            let _ = tx.send(Err(owlen_core::Error::Serialization(
                                serde_json::Error::io(io::Error::new(
                                    io::ErrorKind::InvalidData,
                                    "Non UTF-8 chunk from Ollama",
                                )),
                            )));
                            break;
                        }
                    }
                    Err(e) => {
                        let _ = tx.send(Err(owlen_core::Error::Network(format!(
                            "Stream error: {e}"
                        ))));
                        break;
                    }
                }
            }
        });

        let stream = UnboundedReceiverStream::new(rx);
        Ok(Box::pin(stream))
    }

    async fn health_check(&self) -> Result<()> {
        let url = format!("{}/api/version", self.base_url);

        let response = self
            .client
            .get(&url)
            .send()
            .await
            .map_err(|e| owlen_core::Error::Network(format!("Health check failed: {e}")))?;

        if response.status().is_success() {
            Ok(())
        } else {
            Err(owlen_core::Error::Network(format!(
                "Ollama health check failed: HTTP {}",
                response.status()
            )))
        }
    }

    fn config_schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "base_url": {
                    "type": "string",
                    "description": "Base URL for Ollama API",
                    "default": "http://localhost:11434"
                },
                "timeout_secs": {
                    "type": "integer",
                    "description": "HTTP request timeout in seconds",
                    "minimum": 5,
                    "default": DEFAULT_TIMEOUT_SECS
                },
                "model_cache_ttl_secs": {
                    "type": "integer",
                    "description": "Seconds to cache model listings",
                    "minimum": 5,
                    "default": DEFAULT_MODEL_CACHE_TTL_SECS
                }
            }
        })
    }
}

async fn parse_error_body(response: reqwest::Response) -> String {
    match response.bytes().await {
        Ok(bytes) => {
            if bytes.is_empty() {
                return "unknown error".to_string();
            }

            if let Ok(err) = serde_json::from_slice::<OllamaErrorResponse>(&bytes) {
                if let Some(error) = err.error {
                    return error;
                }
            }

            match String::from_utf8(bytes.to_vec()) {
                Ok(text) if !text.trim().is_empty() => text,
                _ => "unknown error".to_string(),
            }
        }
        Err(_) => "unknown error".to_string(),
    }
}
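
// A small test-module sketch covering the role mapping and streaming-chunk
// deserialization above; it relies only on items defined in this file and makes
// no network calls.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn maps_ollama_roles_onto_core_roles() {
        let msg = OllamaMessage {
            role: "system".to_string(),
            content: "be brief".to_string(),
        };
        let converted = OllamaProvider::convert_ollama_message(&msg);
        assert!(matches!(converted.role, Role::System));
        assert_eq!(converted.content, "be brief");
    }

    #[test]
    fn parses_streaming_chunk_without_usage_counts() {
        let line = r#"{"message":{"role":"assistant","content":"Hi"},"done":false}"#;
        let parsed: OllamaChatResponse = serde_json::from_str(line).unwrap();
        assert!(!parsed.done);
        assert!(parsed.prompt_eval_count.is_none());
        assert_eq!(parsed.message.unwrap().content, "Hi");
    }
}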