feat(provider): add ProviderManager to coordinate providers and cache health status

- Introduce `ProviderManager` for registering providers, routing generate calls, listing models, and refreshing health in parallel.
- Maintain a status cache to expose the last known health of each provider.
- Update `provider` module to re-export the new manager alongside existing types.
This commit is contained in:
2025-10-15 20:37:36 +02:00
parent 641c95131f
commit 9d85420bf6
2 changed files with 231 additions and 2 deletions

View File

@@ -0,0 +1,227 @@
use std::collections::HashMap;
use std::sync::Arc;
use futures::stream::{FuturesUnordered, StreamExt};
use log::{debug, warn};
use tokio::sync::RwLock;
use crate::config::Config;
use crate::{Error, Result};
use super::{GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderStatus};
/// Model information annotated with the originating provider metadata.
///
/// `ProviderManager::list_all_models` builds one of these for every model a
/// provider returns.
#[derive(Debug, Clone)]
pub struct AnnotatedModelInfo {
/// Identifier of the provider that listed this model.
pub provider_id: String,
/// Status recorded for that provider during the listing pass that
/// produced this entry.
pub provider_status: ProviderStatus,
/// The model description exactly as returned by the provider.
pub model: ModelInfo,
}
/// Coordinates multiple [`ModelProvider`] implementations and tracks their
/// health state.
///
/// Both tables sit behind their own async `RwLock`, so every method takes
/// `&self`.
pub struct ProviderManager {
/// Registered provider instances, keyed by the id taken from each
/// provider's metadata at registration time.
providers: RwLock<HashMap<String, Arc<dyn ModelProvider>>>,
/// Last recorded status per provider id. May contain ids that come from
/// the configuration only and have no registered provider instance.
status_cache: RwLock<HashMap<String, ProviderStatus>>,
}
impl ProviderManager {
/// Construct a new manager using the supplied configuration. Providers
/// defined in the configuration start with a `RequiresSetup` status so
/// that frontends can surface incomplete configuration to users.
pub fn new(config: &Config) -> Self {
let mut status_cache = HashMap::new();
for provider_id in config.providers.keys() {
status_cache.insert(provider_id.clone(), ProviderStatus::RequiresSetup);
}
Self {
providers: RwLock::new(HashMap::new()),
status_cache: RwLock::new(status_cache),
}
}
/// Add `provider` to the manager, keyed by the id from its metadata.
///
/// A provider previously stored under the same id is replaced. The cached
/// status for that id is then set to [`ProviderStatus::Unavailable`],
/// overwriting whatever was recorded before.
pub async fn register_provider(&self, provider: Arc<dyn ModelProvider>) {
    let id = provider.metadata().id.clone();
    debug!("registering provider {}", id);

    {
        // Scope the registry guard so it is released before the status
        // cache lock is requested.
        let mut registry = self.providers.write().await;
        registry.insert(id.clone(), provider);
    }

    let mut statuses = self.status_cache.write().await;
    statuses.insert(id, ProviderStatus::Unavailable);
}
/// Route a generation request to the provider registered as `provider_id`
/// and hand back the stream it produces.
///
/// The outcome of the call is recorded in the status cache:
/// [`ProviderStatus::Available`] when the provider returned a stream,
/// [`ProviderStatus::Unavailable`] when it returned an error.
///
/// # Errors
///
/// Returns [`Error::Config`] when no provider is registered under
/// `provider_id`. Any error from the provider's `generate_stream` call is
/// returned unchanged.
pub async fn generate(
    &self,
    provider_id: &str,
    request: GenerateRequest,
) -> Result<GenerateStream> {
    // Clone the handle out so the registry lock is not held during the call.
    let registered = self.providers.read().await.get(provider_id).cloned();
    let provider = registered
        .ok_or_else(|| Error::Config(format!("provider '{provider_id}' not registered")))?;

    let outcome = provider.generate_stream(request).await;
    let observed = if outcome.is_ok() {
        ProviderStatus::Available
    } else {
        ProviderStatus::Unavailable
    };
    self.status_cache
        .write()
        .await
        .insert(provider_id.to_string(), observed);

    outcome
}
/// Collect the models of every registered provider, refreshing the status
/// cache as a side effect.
///
/// Each provider is health-checked first; `list_models` is only called when
/// the health check reports [`ProviderStatus::Available`]. A failed health
/// check or a failed listing is logged as a warning and recorded as
/// [`ProviderStatus::Unavailable`]; such providers contribute no models.
/// Providers are probed concurrently, and results are appended as each
/// probe finishes.
///
/// # Errors
///
/// The current implementation always returns `Ok`; provider failures are
/// reflected in the cached status rather than in the return value.
pub async fn list_all_models(&self) -> Result<Vec<AnnotatedModelInfo>> {
    // Clone the handles out so the registry lock is released before any
    // provider is awaited.
    let snapshot: Vec<(String, Arc<dyn ModelProvider>)> = {
        let registry = self.providers.read().await;
        registry
            .iter()
            .map(|(id, provider)| (id.clone(), Arc::clone(provider)))
            .collect()
    };

    let mut probes = FuturesUnordered::new();
    for (provider_id, provider) in snapshot {
        probes.push(async move {
            let (status, models) = match provider.health_check().await {
                Ok(ProviderStatus::Available) => match provider.list_models().await {
                    Ok(models) => (ProviderStatus::Available, models),
                    Err(err) => {
                        warn!("listing models failed for provider {}: {}", provider_id, err);
                        (ProviderStatus::Unavailable, Vec::new())
                    }
                },
                // Healthy enough to answer, but not available: keep the
                // reported status and skip the model listing.
                Ok(other) => (other, Vec::new()),
                Err(err) => {
                    warn!("health check failed for provider {}: {}", provider_id, err);
                    (ProviderStatus::Unavailable, Vec::new())
                }
            };
            (provider_id, status, models)
        });
    }

    let mut annotated = Vec::new();
    let mut observed = HashMap::new();
    while let Some((provider_id, status, models)) = probes.next().await {
        annotated.extend(models.into_iter().map(|model| AnnotatedModelInfo {
            provider_id: provider_id.clone(),
            provider_status: status,
            model,
        }));
        observed.insert(provider_id, status);
    }

    // Apply all status changes under a single write lock.
    self.status_cache.write().await.extend(observed);

    Ok(annotated)
}
/// Run a health check against every registered provider concurrently and
/// return the resulting status for each provider id.
///
/// A health check that returns an error is logged as a warning and counted
/// as [`ProviderStatus::Unavailable`]. The same statuses are written to the
/// status cache before returning. Cache entries for ids without a
/// registered provider are left untouched.
pub async fn refresh_health(&self) -> HashMap<String, ProviderStatus> {
    // Clone the handles out so the registry lock is released before any
    // provider is awaited.
    let snapshot: Vec<(String, Arc<dyn ModelProvider>)> = {
        let registry = self.providers.read().await;
        registry
            .iter()
            .map(|(id, provider)| (id.clone(), Arc::clone(provider)))
            .collect()
    };

    let mut checks = FuturesUnordered::new();
    for (provider_id, provider) in snapshot {
        checks.push(async move {
            let status = provider.health_check().await.unwrap_or_else(|err| {
                warn!("health check failed for provider {}: {}", provider_id, err);
                ProviderStatus::Unavailable
            });
            (provider_id, status)
        });
    }

    let mut latest = HashMap::new();
    while let Some((provider_id, status)) = checks.next().await {
        latest.insert(provider_id, status);
    }

    {
        let mut cache = self.status_cache.write().await;
        cache.extend(latest.iter().map(|(id, status)| (id.clone(), *status)));
    }

    latest
}
/// Look up the provider registered under `provider_id`.
///
/// Returns a new handle to the shared provider instance, or `None` when no
/// provider has been registered with that id.
pub async fn get_provider(&self, provider_id: &str) -> Option<Arc<dyn ModelProvider>> {
    let registry = self.providers.read().await;
    registry.get(provider_id).map(Arc::clone)
}
/// Return the ids of all providers that currently have a registered
/// instance. The order of the returned ids is unspecified.
pub async fn provider_ids(&self) -> Vec<String> {
    let registry = self.providers.read().await;
    let mut ids = Vec::with_capacity(registry.len());
    ids.extend(registry.keys().cloned());
    ids
}
/// Return the status most recently cached for `provider_id`, or `None`
/// when the cache holds no entry for that id.
pub async fn provider_status(&self, provider_id: &str) -> Option<ProviderStatus> {
    self.status_cache.read().await.get(provider_id).copied()
}
/// Return a copy of the whole status cache as it is at the time of the
/// call. Later cache updates do not affect the returned map.
pub async fn provider_statuses(&self) -> HashMap<String, ProviderStatus> {
    let cache = self.status_cache.read().await;
    (*cache).clone()
}
}
/// An empty manager: no registered providers and no cached statuses.
impl Default for ProviderManager {
    fn default() -> Self {
        let providers = RwLock::new(HashMap::new());
        let status_cache = RwLock::new(HashMap::new());
        Self {
            providers,
            status_cache,
        }
    }
}

View File

@@ -2,8 +2,10 @@
//! //!
//! This module defines the async [`ModelProvider`] trait that all model //! This module defines the async [`ModelProvider`] trait that all model
//! backends implement, together with a small suite of shared data structures //! backends implement, together with a small suite of shared data structures
//! used for model discovery and streaming generation. //! used for model discovery and streaming generation. The [`ProviderManager`]
//! orchestrates multiple providers and coordinates their health state.
mod manager;
mod types; mod types;
use std::pin::Pin; use std::pin::Pin;
@@ -11,7 +13,7 @@ use std::pin::Pin;
use async_trait::async_trait; use async_trait::async_trait;
use futures::Stream; use futures::Stream;
pub use self::types::*; pub use self::{manager::*, types::*};
use crate::Result; use crate::Result;