diff --git a/crates/owlen-core/src/provider/manager.rs b/crates/owlen-core/src/provider/manager.rs new file mode 100644 index 0000000..7b197e0 --- /dev/null +++ b/crates/owlen-core/src/provider/manager.rs @@ -0,0 +1,227 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use futures::stream::{FuturesUnordered, StreamExt}; +use log::{debug, warn}; +use tokio::sync::RwLock; + +use crate::config::Config; +use crate::{Error, Result}; + +use super::{GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderStatus}; + +/// Model information annotated with the originating provider metadata. +#[derive(Debug, Clone)] +pub struct AnnotatedModelInfo { + pub provider_id: String, + pub provider_status: ProviderStatus, + pub model: ModelInfo, +} + +/// Coordinates multiple [`ModelProvider`] implementations and tracks their +/// health state. +pub struct ProviderManager { + providers: RwLock>>, + status_cache: RwLock>, +} + +impl ProviderManager { + /// Construct a new manager using the supplied configuration. Providers + /// defined in the configuration start with a `RequiresSetup` status so + /// that frontends can surface incomplete configuration to users. + pub fn new(config: &Config) -> Self { + let mut status_cache = HashMap::new(); + for provider_id in config.providers.keys() { + status_cache.insert(provider_id.clone(), ProviderStatus::RequiresSetup); + } + + Self { + providers: RwLock::new(HashMap::new()), + status_cache: RwLock::new(status_cache), + } + } + + /// Register a provider instance with the manager. + pub async fn register_provider(&self, provider: Arc) { + let provider_id = provider.metadata().id.clone(); + debug!("registering provider {}", provider_id); + + self.providers + .write() + .await + .insert(provider_id.clone(), provider); + self.status_cache + .write() + .await + .insert(provider_id, ProviderStatus::Unavailable); + } + + /// Return a stream by routing the request to the designated provider. + pub async fn generate( + &self, + provider_id: &str, + request: GenerateRequest, + ) -> Result { + let provider = { + let guard = self.providers.read().await; + guard.get(provider_id).cloned() + } + .ok_or_else(|| Error::Config(format!("provider '{provider_id}' not registered")))?; + + match provider.generate_stream(request).await { + Ok(stream) => { + self.status_cache + .write() + .await + .insert(provider_id.to_string(), ProviderStatus::Available); + Ok(stream) + } + Err(err) => { + self.status_cache + .write() + .await + .insert(provider_id.to_string(), ProviderStatus::Unavailable); + Err(err) + } + } + } + + /// List models across all providers, updating provider status along the way. + pub async fn list_all_models(&self) -> Result> { + let providers: Vec<(String, Arc)> = { + let guard = self.providers.read().await; + guard + .iter() + .map(|(id, provider)| (id.clone(), Arc::clone(provider))) + .collect() + }; + + let mut tasks = FuturesUnordered::new(); + + for (provider_id, provider) in providers { + tasks.push(async move { + let log_id = provider_id.clone(); + let mut status = ProviderStatus::Unavailable; + let mut models = Vec::new(); + + match provider.health_check().await { + Ok(health) => { + status = health; + if matches!(status, ProviderStatus::Available) { + match provider.list_models().await { + Ok(list) => { + models = list; + } + Err(err) => { + status = ProviderStatus::Unavailable; + warn!("listing models failed for provider {}: {}", log_id, err); + } + } + } + } + Err(err) => { + warn!("health check failed for provider {}: {}", log_id, err); + } + } + + (provider_id, status, models) + }); + } + + let mut annotated = Vec::new(); + let mut status_updates = HashMap::new(); + + while let Some((provider_id, status, models)) = tasks.next().await { + status_updates.insert(provider_id.clone(), status); + for model in models { + annotated.push(AnnotatedModelInfo { + provider_id: provider_id.clone(), + provider_status: status, + model, + }); + } + } + + { + let mut guard = self.status_cache.write().await; + for (provider_id, status) in status_updates { + guard.insert(provider_id, status); + } + } + + Ok(annotated) + } + + /// Refresh the health of all registered providers in parallel, returning + /// the latest status snapshot. + pub async fn refresh_health(&self) -> HashMap { + let providers: Vec<(String, Arc)> = { + let guard = self.providers.read().await; + guard + .iter() + .map(|(id, provider)| (id.clone(), Arc::clone(provider))) + .collect() + }; + + let mut tasks = FuturesUnordered::new(); + for (provider_id, provider) in providers { + tasks.push(async move { + let status = match provider.health_check().await { + Ok(status) => status, + Err(err) => { + warn!("health check failed for provider {}: {}", provider_id, err); + ProviderStatus::Unavailable + } + }; + (provider_id, status) + }); + } + + let mut updates = HashMap::new(); + while let Some((provider_id, status)) = tasks.next().await { + updates.insert(provider_id, status); + } + + { + let mut guard = self.status_cache.write().await; + for (provider_id, status) in &updates { + guard.insert(provider_id.clone(), *status); + } + } + + updates + } + + /// Return the provider instance for an identifier. + pub async fn get_provider(&self, provider_id: &str) -> Option> { + let guard = self.providers.read().await; + guard.get(provider_id).cloned() + } + + /// List the registered provider identifiers. + pub async fn provider_ids(&self) -> Vec { + let guard = self.providers.read().await; + guard.keys().cloned().collect() + } + + /// Retrieve the last known status for a provider. + pub async fn provider_status(&self, provider_id: &str) -> Option { + let guard = self.status_cache.read().await; + guard.get(provider_id).copied() + } + + /// Snapshot the currently cached statuses. + pub async fn provider_statuses(&self) -> HashMap { + let guard = self.status_cache.read().await; + guard.clone() + } +} + +impl Default for ProviderManager { + fn default() -> Self { + Self { + providers: RwLock::new(HashMap::new()), + status_cache: RwLock::new(HashMap::new()), + } + } +} diff --git a/crates/owlen-core/src/provider/mod.rs b/crates/owlen-core/src/provider/mod.rs index 7424a5b..5055ec9 100644 --- a/crates/owlen-core/src/provider/mod.rs +++ b/crates/owlen-core/src/provider/mod.rs @@ -2,8 +2,10 @@ //! //! This module defines the async [`ModelProvider`] trait that all model //! backends implement, together with a small suite of shared data structures -//! used for model discovery and streaming generation. +//! used for model discovery and streaming generation. The [`ProviderManager`] +//! orchestrates multiple providers and coordinates their health state. +mod manager; mod types; use std::pin::Pin; @@ -11,7 +13,7 @@ use std::pin::Pin; use async_trait::async_trait; use futures::Stream; -pub use self::types::*; +pub use self::{manager::*, types::*}; use crate::Result;