// - Export `LLMProvider` from `owlen-core` and replace public `Provider` re-exports.
// - Convert `OllamaProvider` to implement the new `LLMProvider` trait with associated future types.
// - Adjust imports and trait bounds in `remote_client.rs` to use the updated types.
// - Add comprehensive provider interface tests (`provider_interface.rs`) verifying router routing
//   and provider registry model listing with `MockProvider`.
// - Align dependency versions across workspace crates by switching to workspace-managed versions.
// - Extend CI (`.woodpecker.yml`) with a dedicated test step and generate coverage reports.
// - Update architecture documentation to reflect the new provider abstraction.
//
// 370 lines, 11 KiB, Rust
//! Provider traits and registries.
|
|
|
|
use crate::{types::*, Error, Result};
|
|
use anyhow::anyhow;
|
|
use futures::{Stream, StreamExt};
|
|
use std::future::Future;
|
|
use std::pin::Pin;
|
|
use std::sync::Arc;
|
|
|
|
/// A stream of chat responses
|
|
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatResponse>> + Send>>;
|
|
|
|
/// Trait for LLM providers (Ollama, OpenAI, Anthropic, etc.) with zero-cost static dispatch.
|
|
pub trait LLMProvider: Send + Sync + 'static {
|
|
type Stream: Stream<Item = Result<ChatResponse>> + Send + 'static;
|
|
|
|
type ListModelsFuture<'a>: Future<Output = Result<Vec<ModelInfo>>> + Send
|
|
where
|
|
Self: 'a;
|
|
|
|
type ChatFuture<'a>: Future<Output = Result<ChatResponse>> + Send
|
|
where
|
|
Self: 'a;
|
|
|
|
type ChatStreamFuture<'a>: Future<Output = Result<Self::Stream>> + Send
|
|
where
|
|
Self: 'a;
|
|
|
|
type HealthCheckFuture<'a>: Future<Output = Result<()>> + Send
|
|
where
|
|
Self: 'a;
|
|
|
|
fn name(&self) -> &str;
|
|
|
|
fn list_models(&self) -> Self::ListModelsFuture<'_>;
|
|
fn chat(&self, request: ChatRequest) -> Self::ChatFuture<'_>;
|
|
fn chat_stream(&self, request: ChatRequest) -> Self::ChatStreamFuture<'_>;
|
|
fn health_check(&self) -> Self::HealthCheckFuture<'_>;
|
|
|
|
fn config_schema(&self) -> serde_json::Value {
|
|
serde_json::json!({})
|
|
}
|
|
}
|
|
|
|
/// Helper that implements [`LLMProvider::chat`] in terms of [`LLMProvider::chat_stream`].
|
|
pub async fn chat_via_stream<'a, P>(provider: &'a P, request: ChatRequest) -> Result<ChatResponse>
|
|
where
|
|
P: LLMProvider + 'a,
|
|
{
|
|
let stream = provider.chat_stream(request).await?;
|
|
let mut boxed: ChatStream = Box::pin(stream);
|
|
match boxed.next().await {
|
|
Some(Ok(response)) => Ok(response),
|
|
Some(Err(err)) => Err(err),
|
|
None => Err(Error::Provider(anyhow!(
|
|
"Empty chat stream from provider {}",
|
|
provider.name()
|
|
))),
|
|
}
|
|
}
|
|
|
|
/// Object-safe wrapper trait for runtime-configurable provider usage.
|
|
#[async_trait::async_trait]
|
|
pub trait Provider: Send + Sync {
|
|
/// Get the name of this provider.
|
|
fn name(&self) -> &str;
|
|
|
|
/// List available models from this provider.
|
|
async fn list_models(&self) -> Result<Vec<ModelInfo>>;
|
|
|
|
/// Send a chat completion request.
|
|
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse>;
|
|
|
|
/// Send a streaming chat completion request.
|
|
async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream>;
|
|
|
|
/// Check if the provider is available/healthy.
|
|
async fn health_check(&self) -> Result<()>;
|
|
|
|
/// Get provider-specific configuration schema.
|
|
fn config_schema(&self) -> serde_json::Value {
|
|
serde_json::json!({})
|
|
}
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl<T> Provider for T
|
|
where
|
|
T: LLMProvider,
|
|
{
|
|
fn name(&self) -> &str {
|
|
LLMProvider::name(self)
|
|
}
|
|
|
|
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
|
|
LLMProvider::list_models(self).await
|
|
}
|
|
|
|
async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
|
|
LLMProvider::chat(self, request).await
|
|
}
|
|
|
|
async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream> {
|
|
let stream = LLMProvider::chat_stream(self, request).await?;
|
|
Ok(Box::pin(stream))
|
|
}
|
|
|
|
async fn health_check(&self) -> Result<()> {
|
|
LLMProvider::health_check(self).await
|
|
}
|
|
|
|
fn config_schema(&self) -> serde_json::Value {
|
|
LLMProvider::config_schema(self)
|
|
}
|
|
}
|
|
|
|
/// Configuration for a provider
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
pub struct ProviderConfig {
|
|
/// Provider type identifier
|
|
pub provider_type: String,
|
|
/// Base URL for API calls
|
|
pub base_url: Option<String>,
|
|
/// API key or token
|
|
pub api_key: Option<String>,
|
|
/// Additional provider-specific configuration
|
|
#[serde(flatten)]
|
|
pub extra: std::collections::HashMap<String, serde_json::Value>,
|
|
}
|
|
|
|
/// A registry of providers
|
|
pub struct ProviderRegistry {
|
|
providers: std::collections::HashMap<String, Arc<dyn Provider>>,
|
|
}
|
|
|
|
impl ProviderRegistry {
|
|
/// Create a new provider registry
|
|
pub fn new() -> Self {
|
|
Self {
|
|
providers: std::collections::HashMap::new(),
|
|
}
|
|
}
|
|
|
|
/// Register a provider using static dispatch.
|
|
pub fn register<P: LLMProvider + 'static>(&mut self, provider: P) {
|
|
self.register_arc(Arc::new(provider));
|
|
}
|
|
|
|
/// Register an already wrapped provider
|
|
pub fn register_arc(&mut self, provider: Arc<dyn Provider>) {
|
|
let name = provider.name().to_string();
|
|
self.providers.insert(name, provider);
|
|
}
|
|
|
|
/// Get a provider by name
|
|
pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
|
|
self.providers.get(name).cloned()
|
|
}
|
|
|
|
/// List all registered provider names
|
|
pub fn list_providers(&self) -> Vec<String> {
|
|
self.providers.keys().cloned().collect()
|
|
}
|
|
|
|
/// Get all models from all providers
|
|
pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
|
|
let mut all_models = Vec::new();
|
|
|
|
for provider in self.providers.values() {
|
|
match provider.list_models().await {
|
|
Ok(mut models) => all_models.append(&mut models),
|
|
Err(_) => {
|
|
// Continue with other providers
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(all_models)
|
|
}
|
|
}
|
|
|
|
impl Default for ProviderRegistry {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
pub mod test_utils {
    use super::*;
    use crate::types::{ChatRequest, ChatResponse, Message, ModelInfo, Role};
    use futures::stream;
    use std::future::{ready, Ready};

    /// Mock provider for testing.
    ///
    /// Every operation resolves immediately; chat replies echo the content of the
    /// last message in the request.
    #[derive(Default)]
    pub struct MockProvider;

    impl LLMProvider for MockProvider {
        type Stream = stream::Iter<std::vec::IntoIter<Result<ChatResponse>>>;
        type ListModelsFuture<'a> = Ready<Result<Vec<ModelInfo>>>;
        type ChatFuture<'a> = Ready<Result<ChatResponse>>;
        type ChatStreamFuture<'a> = Ready<Result<Self::Stream>>;
        type HealthCheckFuture<'a> = Ready<Result<()>>;

        fn name(&self) -> &str {
            "mock"
        }

        /// Reports a single model, `mock-model`, with no capabilities.
        fn list_models(&self) -> Self::ListModelsFuture<'_> {
            let model = ModelInfo {
                id: "mock-model".to_string(),
                provider: "mock".to_string(),
                name: "mock-model".to_string(),
                description: None,
                context_window: None,
                capabilities: Vec::new(),
                supports_tools: false,
            };
            ready(Ok(vec![model]))
        }

        fn chat(&self, request: ChatRequest) -> Self::ChatFuture<'_> {
            let reply = self.build_response(&request);
            ready(Ok(reply))
        }

        /// Yields a one-item stream holding the same reply `chat` would return.
        fn chat_stream(&self, request: ChatRequest) -> Self::ChatStreamFuture<'_> {
            let items = vec![Ok(self.build_response(&request))];
            ready(Ok(stream::iter(items)))
        }

        fn health_check(&self) -> Self::HealthCheckFuture<'_> {
            ready(Ok(()))
        }
    }

    impl MockProvider {
        /// Builds the canned assistant reply for `request`.
        ///
        /// The reply quotes the content of the last message, or the content type's
        /// default value when the request has no messages.
        fn build_response(&self, request: &ChatRequest) -> ChatResponse {
            let prompt = request
                .messages
                .last()
                .map(|m| m.content.clone())
                .unwrap_or_default();
            let content = format!("Mock response to: {}", prompt);

            ChatResponse {
                message: Message::new(Role::Assistant, content),
                usage: None,
                is_streaming: false,
                is_final: true,
            }
        }
    }
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::test_utils::MockProvider;
    use super::*;
    use crate::types::{ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role};
    use futures::stream;
    use std::future::{ready, Ready};
    use std::sync::Arc;

    /// Test provider named `streaming` whose `chat` and `chat_stream` both echo
    /// the last message of the request.
    struct StreamingProvider;

    impl LLMProvider for StreamingProvider {
        type Stream = stream::Iter<std::vec::IntoIter<Result<ChatResponse>>>;
        type ListModelsFuture<'a> = Ready<Result<Vec<ModelInfo>>>;
        type ChatFuture<'a> = Ready<Result<ChatResponse>>;
        type ChatStreamFuture<'a> = Ready<Result<Self::Stream>>;
        type HealthCheckFuture<'a> = Ready<Result<()>>;

        fn name(&self) -> &str {
            "streaming"
        }

        fn list_models(&self) -> Self::ListModelsFuture<'_> {
            ready(Ok(vec![ModelInfo {
                id: "stream-model".to_string(),
                provider: "streaming".to_string(),
                name: "stream-model".to_string(),
                description: None,
                context_window: None,
                capabilities: vec!["chat".to_string()],
                supports_tools: false,
            }]))
        }

        fn chat(&self, request: ChatRequest) -> Self::ChatFuture<'_> {
            ready(Ok(self.response(&request)))
        }

        fn chat_stream(&self, request: ChatRequest) -> Self::ChatStreamFuture<'_> {
            let response = self.response(&request);
            ready(Ok(stream::iter(vec![Ok(response)])))
        }

        fn health_check(&self) -> Self::HealthCheckFuture<'_> {
            ready(Ok(()))
        }
    }

    impl StreamingProvider {
        /// Builds the `echo:<last message content>` reply, marked as a final
        /// streaming chunk.
        fn response(&self, request: &ChatRequest) -> ChatResponse {
            let reply = format!(
                "echo:{}",
                request
                    .messages
                    .last()
                    .map(|m| m.content.clone())
                    .unwrap_or_default()
            );
            ChatResponse {
                message: Message::new(Role::Assistant, reply),
                usage: None,
                is_streaming: true,
                is_final: true,
            }
        }
    }

    /// Builds a request for `model` holding one user message with content `text`,
    /// default parameters and no tools.
    fn user_request(model: &str, text: &str) -> ChatRequest {
        ChatRequest {
            model: model.to_string(),
            messages: vec![Message::new(Role::User, text.to_string())],
            parameters: ChatParameters::default(),
            tools: None,
        }
    }

    #[tokio::test]
    async fn default_chat_reads_from_stream() {
        let provider = StreamingProvider;

        // Go through `chat_via_stream` so the reply really comes from the stream;
        // calling `LLMProvider::chat` directly would bypass `chat_stream` entirely.
        let response = chat_via_stream(&provider, user_request("stream-model", "ping"))
            .await
            .expect("chat succeeded");
        assert_eq!(response.message.content, "echo:ping");
        assert!(response.is_final);
    }

    #[tokio::test]
    async fn registry_registers_static_provider() {
        let mut registry = ProviderRegistry::new();
        registry.register(StreamingProvider);

        let provider = registry.get("streaming").expect("provider registered");
        let models = provider.list_models().await.expect("models listed");
        assert_eq!(models[0].id, "stream-model");
    }

    #[tokio::test]
    async fn registry_accepts_dynamic_provider() {
        let mut registry = ProviderRegistry::new();
        let provider: Arc<dyn Provider> = Arc::new(MockProvider::default());
        registry.register_arc(provider);

        let fetched = registry.get("mock").expect("mock provider present");
        let response = Provider::chat(fetched.as_ref(), user_request("mock-model", "hi"))
            .await
            .expect("chat succeeded");
        assert_eq!(response.message.content, "Mock response to: hi");
    }
}
|