//! Provider traits and registries. use crate::{types::*, Error, Result}; use anyhow::anyhow; use futures::{Stream, StreamExt}; use std::any::Any; use std::future::Future; use std::pin::Pin; use std::sync::Arc; /// A stream of chat responses pub type ChatStream = Pin> + Send>>; /// Trait for LLM providers (Ollama, OpenAI, Anthropic, etc.) with zero-cost static dispatch. pub trait LLMProvider: Send + Sync + 'static + Any + Sized { type Stream: Stream> + Send + 'static; type ListModelsFuture<'a>: Future>> + Send where Self: 'a; type ChatFuture<'a>: Future> + Send where Self: 'a; type ChatStreamFuture<'a>: Future> + Send where Self: 'a; type HealthCheckFuture<'a>: Future> + Send where Self: 'a; fn name(&self) -> &str; fn list_models(&self) -> Self::ListModelsFuture<'_>; fn chat(&self, request: ChatRequest) -> Self::ChatFuture<'_>; fn chat_stream(&self, request: ChatRequest) -> Self::ChatStreamFuture<'_>; fn health_check(&self) -> Self::HealthCheckFuture<'_>; fn config_schema(&self) -> serde_json::Value { serde_json::json!({}) } fn as_any(&self) -> &(dyn Any + Send + Sync) { self } } /// Helper that implements [`LLMProvider::chat`] in terms of [`LLMProvider::chat_stream`]. pub async fn chat_via_stream<'a, P>(provider: &'a P, request: ChatRequest) -> Result where P: LLMProvider + 'a, { let stream = provider.chat_stream(request).await?; let mut boxed: ChatStream = Box::pin(stream); match boxed.next().await { Some(Ok(response)) => Ok(response), Some(Err(err)) => Err(err), None => Err(Error::Provider(anyhow!( "Empty chat stream from provider {}", provider.name() ))), } } /// Object-safe wrapper trait for runtime-configurable provider usage. #[async_trait::async_trait] pub trait Provider: Send + Sync { /// Get the name of this provider. fn name(&self) -> &str; /// List available models from this provider. async fn list_models(&self) -> Result>; /// Send a chat completion request. async fn chat(&self, request: ChatRequest) -> Result; /// Send a streaming chat completion request. async fn chat_stream(&self, request: ChatRequest) -> Result; /// Check if the provider is available/healthy. async fn health_check(&self) -> Result<()>; /// Get provider-specific configuration schema. fn config_schema(&self) -> serde_json::Value { serde_json::json!({}) } fn as_any(&self) -> &(dyn Any + Send + Sync); } #[async_trait::async_trait] impl Provider for T where T: LLMProvider, { fn name(&self) -> &str { LLMProvider::name(self) } async fn list_models(&self) -> Result> { LLMProvider::list_models(self).await } async fn chat(&self, request: ChatRequest) -> Result { LLMProvider::chat(self, request).await } async fn chat_stream(&self, request: ChatRequest) -> Result { let stream = LLMProvider::chat_stream(self, request).await?; Ok(Box::pin(stream)) } async fn health_check(&self) -> Result<()> { LLMProvider::health_check(self).await } fn config_schema(&self) -> serde_json::Value { LLMProvider::config_schema(self) } fn as_any(&self) -> &(dyn Any + Send + Sync) { LLMProvider::as_any(self) } } /// Configuration for a provider #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct ProviderConfig { /// Provider type identifier pub provider_type: String, /// Base URL for API calls pub base_url: Option, /// API key or token pub api_key: Option, /// Additional provider-specific configuration #[serde(flatten)] pub extra: std::collections::HashMap, } /// A registry of providers pub struct ProviderRegistry { providers: std::collections::HashMap>, } impl ProviderRegistry { /// Create a new provider registry pub fn new() -> Self { Self { providers: std::collections::HashMap::new(), } } /// Register a provider using static dispatch. pub fn register(&mut self, provider: P) { self.register_arc(Arc::new(provider)); } /// Register an already wrapped provider pub fn register_arc(&mut self, provider: Arc) { let name = provider.name().to_string(); self.providers.insert(name, provider); } /// Get a provider by name pub fn get(&self, name: &str) -> Option> { self.providers.get(name).cloned() } /// List all registered provider names pub fn list_providers(&self) -> Vec { self.providers.keys().cloned().collect() } /// Get all models from all providers pub async fn list_all_models(&self) -> Result> { let mut all_models = Vec::new(); for provider in self.providers.values() { match provider.list_models().await { Ok(mut models) => all_models.append(&mut models), Err(_) => { // Continue with other providers } } } Ok(all_models) } } impl Default for ProviderRegistry { fn default() -> Self { Self::new() } } #[cfg(test)] pub mod test_utils { use super::*; use crate::types::{ChatRequest, ChatResponse, Message, ModelInfo, Role}; use futures::stream; use std::future::{ready, Ready}; /// Mock provider for testing #[derive(Default)] pub struct MockProvider; impl LLMProvider for MockProvider { type Stream = stream::Iter>>; type ListModelsFuture<'a> = Ready>>; type ChatFuture<'a> = Ready>; type ChatStreamFuture<'a> = Ready>; type HealthCheckFuture<'a> = Ready>; fn name(&self) -> &str { "mock" } fn list_models(&self) -> Self::ListModelsFuture<'_> { ready(Ok(vec![ModelInfo { id: "mock-model".to_string(), provider: "mock".to_string(), name: "mock-model".to_string(), description: None, context_window: None, capabilities: vec![], supports_tools: false, }])) } fn chat(&self, request: ChatRequest) -> Self::ChatFuture<'_> { ready(Ok(self.build_response(&request))) } fn chat_stream(&self, request: ChatRequest) -> Self::ChatStreamFuture<'_> { let response = self.build_response(&request); ready(Ok(stream::iter(vec![Ok(response)]))) } fn health_check(&self) -> Self::HealthCheckFuture<'_> { ready(Ok(())) } } impl MockProvider { fn build_response(&self, request: &ChatRequest) -> ChatResponse { let content = format!( "Mock response to: {}", request .messages .last() .map(|m| m.content.clone()) .unwrap_or_default() ); ChatResponse { message: Message::new(Role::Assistant, content), usage: None, is_streaming: false, is_final: true, } } } } #[cfg(test)] mod tests { use super::test_utils::MockProvider; use super::*; use crate::types::{ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role}; use futures::stream; use std::future::{ready, Ready}; use std::sync::Arc; struct StreamingProvider; impl LLMProvider for StreamingProvider { type Stream = stream::Iter>>; type ListModelsFuture<'a> = Ready>>; type ChatFuture<'a> = Ready>; type ChatStreamFuture<'a> = Ready>; type HealthCheckFuture<'a> = Ready>; fn name(&self) -> &str { "streaming" } fn list_models(&self) -> Self::ListModelsFuture<'_> { ready(Ok(vec![ModelInfo { id: "stream-model".to_string(), provider: "streaming".to_string(), name: "stream-model".to_string(), description: None, context_window: None, capabilities: vec!["chat".to_string()], supports_tools: false, }])) } fn chat(&self, request: ChatRequest) -> Self::ChatFuture<'_> { ready(Ok(self.response(&request))) } fn chat_stream(&self, request: ChatRequest) -> Self::ChatStreamFuture<'_> { let response = self.response(&request); ready(Ok(stream::iter(vec![Ok(response)]))) } fn health_check(&self) -> Self::HealthCheckFuture<'_> { ready(Ok(())) } } impl StreamingProvider { fn response(&self, request: &ChatRequest) -> ChatResponse { let reply = format!( "echo:{}", request .messages .last() .map(|m| m.content.clone()) .unwrap_or_default() ); ChatResponse { message: Message::new(Role::Assistant, reply), usage: None, is_streaming: true, is_final: true, } } } #[tokio::test] async fn default_chat_reads_from_stream() { let provider = StreamingProvider; let request = ChatRequest { model: "stream-model".to_string(), messages: vec![Message::new(Role::User, "ping".to_string())], parameters: ChatParameters::default(), tools: None, }; let response = LLMProvider::chat(&provider, request) .await .expect("chat succeeded"); assert_eq!(response.message.content, "echo:ping"); assert!(response.is_final); } #[tokio::test] async fn registry_registers_static_provider() { let mut registry = ProviderRegistry::new(); registry.register(StreamingProvider); let provider = registry.get("streaming").expect("provider registered"); let models = provider.list_models().await.expect("models listed"); assert_eq!(models[0].id, "stream-model"); } #[tokio::test] async fn registry_accepts_dynamic_provider() { let mut registry = ProviderRegistry::new(); let provider: Arc = Arc::new(MockProvider::default()); registry.register_arc(provider.clone()); let fetched = registry.get("mock").expect("mock provider present"); let request = ChatRequest { model: "mock-model".to_string(), messages: vec![Message::new(Role::User, "hi".to_string())], parameters: ChatParameters::default(), tools: None, }; let response = Provider::chat(fetched.as_ref(), request) .await .expect("chat succeeded"); assert_eq!(response.message.content, "Mock response to: hi"); } }