pub mod details; pub use details::{DetailedModelInfo, ModelInfoRetrievalError}; use crate::Result; use crate::types::ModelInfo; use std::collections::HashMap; use std::future::Future; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::RwLock; #[derive(Default, Debug)] struct ModelCache { models: Vec, last_refresh: Option, } /// Caches model listings for improved selection performance #[derive(Clone, Debug)] pub struct ModelManager { cache: Arc>, ttl: Duration, } impl ModelManager { /// Create a new manager with the desired cache TTL pub fn new(ttl: Duration) -> Self { Self { cache: Arc::new(RwLock::new(ModelCache::default())), ttl, } } /// Get cached models, refreshing via the provided fetcher when stale. Returns the up-to-date model list. pub async fn get_or_refresh( &self, force_refresh: bool, fetcher: F, ) -> Result> where F: FnOnce() -> Fut, Fut: Future>>, { if let (false, Some(models)) = (force_refresh, self.cached_if_fresh().await) { return Ok(models); } let models = fetcher().await?; let mut cache = self.cache.write().await; cache.models = models.clone(); cache.last_refresh = Some(Instant::now()); Ok(models) } /// Return cached models without refreshing pub async fn cached(&self) -> Vec { self.cache.read().await.models.clone() } /// Drop cached models, forcing next call to refresh pub async fn invalidate(&self) { let mut cache = self.cache.write().await; cache.models.clear(); cache.last_refresh = None; } /// Select a model by id or name from the cache pub async fn select(&self, identifier: &str) -> Option { let cache = self.cache.read().await; cache .models .iter() .find(|m| m.id == identifier || m.name == identifier) .cloned() } async fn cached_if_fresh(&self) -> Option> { let cache = self.cache.read().await; let fresh = matches!(cache.last_refresh, Some(ts) if ts.elapsed() < self.ttl); if fresh && !cache.models.is_empty() { Some(cache.models.clone()) } else { None } } } #[derive(Default, Debug)] struct ModelDetailsCacheInner { by_key: HashMap, name_to_key: HashMap, fetched_at: HashMap, } /// Cache for rich model details, indexed by digest when available. #[derive(Clone, Debug)] pub struct ModelDetailsCache { inner: Arc>, ttl: Duration, } impl ModelDetailsCache { /// Create a new details cache with the provided TTL. pub fn new(ttl: Duration) -> Self { Self { inner: Arc::new(RwLock::new(ModelDetailsCacheInner::default())), ttl, } } /// Try to read cached details for the provided model name. pub async fn get(&self, name: &str) -> Option { let mut inner = self.inner.write().await; let key = inner.name_to_key.get(name).cloned()?; let stale = inner .fetched_at .get(&key) .is_some_and(|ts| ts.elapsed() >= self.ttl); if stale { inner.by_key.remove(&key); inner.name_to_key.remove(name); inner.fetched_at.remove(&key); return None; } inner.by_key.get(&key).cloned() } /// Cache the provided details, overwriting existing entries. pub async fn insert(&self, info: DetailedModelInfo) { let key = info.digest.clone().unwrap_or_else(|| info.name.clone()); let mut inner = self.inner.write().await; // Remove prior mappings for this model name (possibly different digest). if let Some(previous_key) = inner.name_to_key.get(&info.name).cloned() && previous_key != key { inner.by_key.remove(&previous_key); inner.fetched_at.remove(&previous_key); } inner.fetched_at.insert(key.clone(), Instant::now()); inner.name_to_key.insert(info.name.clone(), key.clone()); inner.by_key.insert(key, info); } /// Remove a specific model from the cache. pub async fn invalidate(&self, name: &str) { let mut inner = self.inner.write().await; if let Some(key) = inner.name_to_key.remove(name) { inner.by_key.remove(&key); inner.fetched_at.remove(&key); } } /// Clear the entire cache. pub async fn invalidate_all(&self) { let mut inner = self.inner.write().await; inner.by_key.clear(); inner.name_to_key.clear(); inner.fetched_at.clear(); } /// Return all cached values regardless of freshness. pub async fn cached(&self) -> Vec { let inner = self.inner.read().await; inner.by_key.values().cloned().collect() } } #[cfg(test)] mod tests { use super::*; use std::time::Duration; use tokio::time::sleep; fn sample_details(name: &str) -> DetailedModelInfo { DetailedModelInfo { name: name.to_string(), ..Default::default() } } #[tokio::test] async fn model_details_cache_returns_cached_entry() { let cache = ModelDetailsCache::new(Duration::from_millis(50)); let info = sample_details("llama"); cache.insert(info.clone()).await; let cached = cache.get("llama").await; assert!(cached.is_some()); assert_eq!(cached.unwrap().name, "llama"); } #[tokio::test] async fn model_details_cache_expires_based_on_ttl() { let cache = ModelDetailsCache::new(Duration::from_millis(10)); cache.insert(sample_details("phi")).await; sleep(Duration::from_millis(30)).await; assert!(cache.get("phi").await.is_none()); } #[tokio::test] async fn model_details_cache_invalidate_removes_entry() { let cache = ModelDetailsCache::new(Duration::from_secs(1)); cache.insert(sample_details("mistral")).await; cache.invalidate("mistral").await; assert!(cache.get("mistral").await.is_none()); } }