feat(provider): add ProviderManager to coordinate providers and cache health status
- Introduce `ProviderManager` for registering providers, routing generate calls, listing models, and refreshing health in parallel.
- Maintain a status cache to expose the last known health of each provider.
- Update the `provider` module to re-export the new manager alongside the existing types.
This commit is contained in:
227
crates/owlen-core/src/provider/manager.rs
Normal file
227
crates/owlen-core/src/provider/manager.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::stream::{FuturesUnordered, StreamExt};
|
||||
use log::{debug, warn};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::{Error, Result};
|
||||
|
||||
use super::{GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderStatus};
|
||||
|
||||
/// Model information annotated with the originating provider metadata.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AnnotatedModelInfo {
|
||||
pub provider_id: String,
|
||||
pub provider_status: ProviderStatus,
|
||||
pub model: ModelInfo,
|
||||
}
|
||||
|
||||
/// Coordinates multiple [`ModelProvider`] implementations and tracks their
|
||||
/// health state.
|
||||
pub struct ProviderManager {
|
||||
providers: RwLock<HashMap<String, Arc<dyn ModelProvider>>>,
|
||||
status_cache: RwLock<HashMap<String, ProviderStatus>>,
|
||||
}
|
||||
|
||||
impl ProviderManager {
|
||||
/// Construct a new manager using the supplied configuration. Providers
|
||||
/// defined in the configuration start with a `RequiresSetup` status so
|
||||
/// that frontends can surface incomplete configuration to users.
|
||||
pub fn new(config: &Config) -> Self {
|
||||
let mut status_cache = HashMap::new();
|
||||
for provider_id in config.providers.keys() {
|
||||
status_cache.insert(provider_id.clone(), ProviderStatus::RequiresSetup);
|
||||
}
|
||||
|
||||
Self {
|
||||
providers: RwLock::new(HashMap::new()),
|
||||
status_cache: RwLock::new(status_cache),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a provider instance with the manager.
|
||||
pub async fn register_provider(&self, provider: Arc<dyn ModelProvider>) {
|
||||
let provider_id = provider.metadata().id.clone();
|
||||
debug!("registering provider {}", provider_id);
|
||||
|
||||
self.providers
|
||||
.write()
|
||||
.await
|
||||
.insert(provider_id.clone(), provider);
|
||||
self.status_cache
|
||||
.write()
|
||||
.await
|
||||
.insert(provider_id, ProviderStatus::Unavailable);
|
||||
}
|
||||
|
||||
/// Return a stream by routing the request to the designated provider.
|
||||
pub async fn generate(
|
||||
&self,
|
||||
provider_id: &str,
|
||||
request: GenerateRequest,
|
||||
) -> Result<GenerateStream> {
|
||||
let provider = {
|
||||
let guard = self.providers.read().await;
|
||||
guard.get(provider_id).cloned()
|
||||
}
|
||||
.ok_or_else(|| Error::Config(format!("provider '{provider_id}' not registered")))?;
|
||||
|
||||
match provider.generate_stream(request).await {
|
||||
Ok(stream) => {
|
||||
self.status_cache
|
||||
.write()
|
||||
.await
|
||||
.insert(provider_id.to_string(), ProviderStatus::Available);
|
||||
Ok(stream)
|
||||
}
|
||||
Err(err) => {
|
||||
self.status_cache
|
||||
.write()
|
||||
.await
|
||||
.insert(provider_id.to_string(), ProviderStatus::Unavailable);
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// List models across all providers, updating provider status along the way.
|
||||
pub async fn list_all_models(&self) -> Result<Vec<AnnotatedModelInfo>> {
|
||||
let providers: Vec<(String, Arc<dyn ModelProvider>)> = {
|
||||
let guard = self.providers.read().await;
|
||||
guard
|
||||
.iter()
|
||||
.map(|(id, provider)| (id.clone(), Arc::clone(provider)))
|
||||
.collect()
|
||||
};
|
||||
|
||||
let mut tasks = FuturesUnordered::new();
|
||||
|
||||
for (provider_id, provider) in providers {
|
||||
tasks.push(async move {
|
||||
let log_id = provider_id.clone();
|
||||
let mut status = ProviderStatus::Unavailable;
|
||||
let mut models = Vec::new();
|
||||
|
||||
match provider.health_check().await {
|
||||
Ok(health) => {
|
||||
status = health;
|
||||
if matches!(status, ProviderStatus::Available) {
|
||||
match provider.list_models().await {
|
||||
Ok(list) => {
|
||||
models = list;
|
||||
}
|
||||
Err(err) => {
|
||||
status = ProviderStatus::Unavailable;
|
||||
warn!("listing models failed for provider {}: {}", log_id, err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("health check failed for provider {}: {}", log_id, err);
|
||||
}
|
||||
}
|
||||
|
||||
(provider_id, status, models)
|
||||
});
|
||||
}
|
||||
|
||||
let mut annotated = Vec::new();
|
||||
let mut status_updates = HashMap::new();
|
||||
|
||||
while let Some((provider_id, status, models)) = tasks.next().await {
|
||||
status_updates.insert(provider_id.clone(), status);
|
||||
for model in models {
|
||||
annotated.push(AnnotatedModelInfo {
|
||||
provider_id: provider_id.clone(),
|
||||
provider_status: status,
|
||||
model,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let mut guard = self.status_cache.write().await;
|
||||
for (provider_id, status) in status_updates {
|
||||
guard.insert(provider_id, status);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(annotated)
|
||||
}
|
||||
|
||||
/// Refresh the health of all registered providers in parallel, returning
|
||||
/// the latest status snapshot.
|
||||
pub async fn refresh_health(&self) -> HashMap<String, ProviderStatus> {
|
||||
let providers: Vec<(String, Arc<dyn ModelProvider>)> = {
|
||||
let guard = self.providers.read().await;
|
||||
guard
|
||||
.iter()
|
||||
.map(|(id, provider)| (id.clone(), Arc::clone(provider)))
|
||||
.collect()
|
||||
};
|
||||
|
||||
let mut tasks = FuturesUnordered::new();
|
||||
for (provider_id, provider) in providers {
|
||||
tasks.push(async move {
|
||||
let status = match provider.health_check().await {
|
||||
Ok(status) => status,
|
||||
Err(err) => {
|
||||
warn!("health check failed for provider {}: {}", provider_id, err);
|
||||
ProviderStatus::Unavailable
|
||||
}
|
||||
};
|
||||
(provider_id, status)
|
||||
});
|
||||
}
|
||||
|
||||
let mut updates = HashMap::new();
|
||||
while let Some((provider_id, status)) = tasks.next().await {
|
||||
updates.insert(provider_id, status);
|
||||
}
|
||||
|
||||
{
|
||||
let mut guard = self.status_cache.write().await;
|
||||
for (provider_id, status) in &updates {
|
||||
guard.insert(provider_id.clone(), *status);
|
||||
}
|
||||
}
|
||||
|
||||
updates
|
||||
}
|
||||
|
||||
/// Return the provider instance for an identifier.
|
||||
pub async fn get_provider(&self, provider_id: &str) -> Option<Arc<dyn ModelProvider>> {
|
||||
let guard = self.providers.read().await;
|
||||
guard.get(provider_id).cloned()
|
||||
}
|
||||
|
||||
/// List the registered provider identifiers.
|
||||
pub async fn provider_ids(&self) -> Vec<String> {
|
||||
let guard = self.providers.read().await;
|
||||
guard.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Retrieve the last known status for a provider.
|
||||
pub async fn provider_status(&self, provider_id: &str) -> Option<ProviderStatus> {
|
||||
let guard = self.status_cache.read().await;
|
||||
guard.get(provider_id).copied()
|
||||
}
|
||||
|
||||
/// Snapshot the currently cached statuses.
|
||||
pub async fn provider_statuses(&self) -> HashMap<String, ProviderStatus> {
|
||||
let guard = self.status_cache.read().await;
|
||||
guard.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ProviderManager {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
providers: RwLock::new(HashMap::new()),
|
||||
status_cache: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,8 +2,10 @@
|
||||
//!
|
||||
//! This module defines the async [`ModelProvider`] trait that all model
|
||||
//! backends implement, together with a small suite of shared data structures
|
||||
//! used for model discovery and streaming generation.
|
||||
//! used for model discovery and streaming generation. The [`ProviderManager`]
|
||||
//! orchestrates multiple providers and coordinates their health state.
|
||||
|
||||
mod manager;
|
||||
mod types;
|
||||
|
||||
use std::pin::Pin;
|
||||
@@ -11,7 +13,7 @@ use std::pin::Pin;
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
|
||||
pub use self::types::*;
|
||||
pub use self::{manager::*, types::*};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user