feat(provider): add ProviderManager to coordinate providers and cache health status
- Introduce `ProviderManager` for registering providers, routing generate calls, listing models, and refreshing health in parallel. - Maintain a status cache to expose the last known health of each provider. - Update `provider` module to re‑export the new manager alongside existing types.
This commit is contained in:
227
crates/owlen-core/src/provider/manager.rs
Normal file
227
crates/owlen-core/src/provider/manager.rs
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use futures::stream::{FuturesUnordered, StreamExt};
|
||||||
|
use log::{debug, warn};
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::{Error, Result};
|
||||||
|
|
||||||
|
use super::{GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderStatus};
|
||||||
|
|
||||||
|
/// Model information annotated with the originating provider metadata.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct AnnotatedModelInfo {
|
||||||
|
pub provider_id: String,
|
||||||
|
pub provider_status: ProviderStatus,
|
||||||
|
pub model: ModelInfo,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Coordinates multiple [`ModelProvider`] implementations and tracks their
|
||||||
|
/// health state.
|
||||||
|
pub struct ProviderManager {
|
||||||
|
providers: RwLock<HashMap<String, Arc<dyn ModelProvider>>>,
|
||||||
|
status_cache: RwLock<HashMap<String, ProviderStatus>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProviderManager {
|
||||||
|
/// Construct a new manager using the supplied configuration. Providers
|
||||||
|
/// defined in the configuration start with a `RequiresSetup` status so
|
||||||
|
/// that frontends can surface incomplete configuration to users.
|
||||||
|
pub fn new(config: &Config) -> Self {
|
||||||
|
let mut status_cache = HashMap::new();
|
||||||
|
for provider_id in config.providers.keys() {
|
||||||
|
status_cache.insert(provider_id.clone(), ProviderStatus::RequiresSetup);
|
||||||
|
}
|
||||||
|
|
||||||
|
Self {
|
||||||
|
providers: RwLock::new(HashMap::new()),
|
||||||
|
status_cache: RwLock::new(status_cache),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a provider instance with the manager.
|
||||||
|
pub async fn register_provider(&self, provider: Arc<dyn ModelProvider>) {
|
||||||
|
let provider_id = provider.metadata().id.clone();
|
||||||
|
debug!("registering provider {}", provider_id);
|
||||||
|
|
||||||
|
self.providers
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(provider_id.clone(), provider);
|
||||||
|
self.status_cache
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(provider_id, ProviderStatus::Unavailable);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return a stream by routing the request to the designated provider.
|
||||||
|
pub async fn generate(
|
||||||
|
&self,
|
||||||
|
provider_id: &str,
|
||||||
|
request: GenerateRequest,
|
||||||
|
) -> Result<GenerateStream> {
|
||||||
|
let provider = {
|
||||||
|
let guard = self.providers.read().await;
|
||||||
|
guard.get(provider_id).cloned()
|
||||||
|
}
|
||||||
|
.ok_or_else(|| Error::Config(format!("provider '{provider_id}' not registered")))?;
|
||||||
|
|
||||||
|
match provider.generate_stream(request).await {
|
||||||
|
Ok(stream) => {
|
||||||
|
self.status_cache
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(provider_id.to_string(), ProviderStatus::Available);
|
||||||
|
Ok(stream)
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
self.status_cache
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.insert(provider_id.to_string(), ProviderStatus::Unavailable);
|
||||||
|
Err(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List models across all providers, updating provider status along the way.
|
||||||
|
pub async fn list_all_models(&self) -> Result<Vec<AnnotatedModelInfo>> {
|
||||||
|
let providers: Vec<(String, Arc<dyn ModelProvider>)> = {
|
||||||
|
let guard = self.providers.read().await;
|
||||||
|
guard
|
||||||
|
.iter()
|
||||||
|
.map(|(id, provider)| (id.clone(), Arc::clone(provider)))
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut tasks = FuturesUnordered::new();
|
||||||
|
|
||||||
|
for (provider_id, provider) in providers {
|
||||||
|
tasks.push(async move {
|
||||||
|
let log_id = provider_id.clone();
|
||||||
|
let mut status = ProviderStatus::Unavailable;
|
||||||
|
let mut models = Vec::new();
|
||||||
|
|
||||||
|
match provider.health_check().await {
|
||||||
|
Ok(health) => {
|
||||||
|
status = health;
|
||||||
|
if matches!(status, ProviderStatus::Available) {
|
||||||
|
match provider.list_models().await {
|
||||||
|
Ok(list) => {
|
||||||
|
models = list;
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
status = ProviderStatus::Unavailable;
|
||||||
|
warn!("listing models failed for provider {}: {}", log_id, err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
warn!("health check failed for provider {}: {}", log_id, err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(provider_id, status, models)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut annotated = Vec::new();
|
||||||
|
let mut status_updates = HashMap::new();
|
||||||
|
|
||||||
|
while let Some((provider_id, status, models)) = tasks.next().await {
|
||||||
|
status_updates.insert(provider_id.clone(), status);
|
||||||
|
for model in models {
|
||||||
|
annotated.push(AnnotatedModelInfo {
|
||||||
|
provider_id: provider_id.clone(),
|
||||||
|
provider_status: status,
|
||||||
|
model,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut guard = self.status_cache.write().await;
|
||||||
|
for (provider_id, status) in status_updates {
|
||||||
|
guard.insert(provider_id, status);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(annotated)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Refresh the health of all registered providers in parallel, returning
|
||||||
|
/// the latest status snapshot.
|
||||||
|
pub async fn refresh_health(&self) -> HashMap<String, ProviderStatus> {
|
||||||
|
let providers: Vec<(String, Arc<dyn ModelProvider>)> = {
|
||||||
|
let guard = self.providers.read().await;
|
||||||
|
guard
|
||||||
|
.iter()
|
||||||
|
.map(|(id, provider)| (id.clone(), Arc::clone(provider)))
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut tasks = FuturesUnordered::new();
|
||||||
|
for (provider_id, provider) in providers {
|
||||||
|
tasks.push(async move {
|
||||||
|
let status = match provider.health_check().await {
|
||||||
|
Ok(status) => status,
|
||||||
|
Err(err) => {
|
||||||
|
warn!("health check failed for provider {}: {}", provider_id, err);
|
||||||
|
ProviderStatus::Unavailable
|
||||||
|
}
|
||||||
|
};
|
||||||
|
(provider_id, status)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut updates = HashMap::new();
|
||||||
|
while let Some((provider_id, status)) = tasks.next().await {
|
||||||
|
updates.insert(provider_id, status);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut guard = self.status_cache.write().await;
|
||||||
|
for (provider_id, status) in &updates {
|
||||||
|
guard.insert(provider_id.clone(), *status);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
updates
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the provider instance for an identifier.
|
||||||
|
pub async fn get_provider(&self, provider_id: &str) -> Option<Arc<dyn ModelProvider>> {
|
||||||
|
let guard = self.providers.read().await;
|
||||||
|
guard.get(provider_id).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List the registered provider identifiers.
|
||||||
|
pub async fn provider_ids(&self) -> Vec<String> {
|
||||||
|
let guard = self.providers.read().await;
|
||||||
|
guard.keys().cloned().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retrieve the last known status for a provider.
|
||||||
|
pub async fn provider_status(&self, provider_id: &str) -> Option<ProviderStatus> {
|
||||||
|
let guard = self.status_cache.read().await;
|
||||||
|
guard.get(provider_id).copied()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Snapshot the currently cached statuses.
|
||||||
|
pub async fn provider_statuses(&self) -> HashMap<String, ProviderStatus> {
|
||||||
|
let guard = self.status_cache.read().await;
|
||||||
|
guard.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ProviderManager {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
providers: RwLock::new(HashMap::new()),
|
||||||
|
status_cache: RwLock::new(HashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,8 +2,10 @@
|
|||||||
//!
|
//!
|
||||||
//! This module defines the async [`ModelProvider`] trait that all model
|
//! This module defines the async [`ModelProvider`] trait that all model
|
||||||
//! backends implement, together with a small suite of shared data structures
|
//! backends implement, together with a small suite of shared data structures
|
||||||
//! used for model discovery and streaming generation.
|
//! used for model discovery and streaming generation. The [`ProviderManager`]
|
||||||
|
//! orchestrates multiple providers and coordinates their health state.
|
||||||
|
|
||||||
|
mod manager;
|
||||||
mod types;
|
mod types;
|
||||||
|
|
||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
@@ -11,7 +13,7 @@ use std::pin::Pin;
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
|
|
||||||
pub use self::types::*;
|
pub use self::{manager::*, types::*};
|
||||||
|
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user