//! Router for managing multiple providers and routing requests use crate::{provider::*, types::*, Result}; use std::sync::Arc; /// A router that can distribute requests across multiple providers pub struct Router { registry: ProviderRegistry, routing_rules: Vec, default_provider: Option, } /// A rule for routing requests to specific providers #[derive(Debug, Clone)] pub struct RoutingRule { /// Pattern to match against model names pub model_pattern: String, /// Provider to route to pub provider: String, /// Priority (higher numbers are checked first) pub priority: u32, } impl Router { /// Create a new router pub fn new() -> Self { Self { registry: ProviderRegistry::new(), routing_rules: Vec::new(), default_provider: None, } } /// Register a provider with the router pub fn register_provider(&mut self, provider: P) { self.registry.register(provider); } /// Set the default provider pub fn set_default_provider(&mut self, provider_name: String) { self.default_provider = Some(provider_name); } /// Add a routing rule pub fn add_routing_rule(&mut self, rule: RoutingRule) { self.routing_rules.push(rule); // Sort by priority (descending) self.routing_rules .sort_by(|a, b| b.priority.cmp(&a.priority)); } /// Route a request to the appropriate provider pub async fn chat(&self, request: ChatRequest) -> Result { let provider = self.find_provider_for_model(&request.model)?; provider.chat(request).await } /// Route a streaming request to the appropriate provider pub async fn chat_stream(&self, request: ChatRequest) -> Result { let provider = self.find_provider_for_model(&request.model)?; provider.chat_stream(request).await } /// List all available models from all providers pub async fn list_models(&self) -> Result> { self.registry.list_all_models().await } /// Find the appropriate provider for a given model fn find_provider_for_model(&self, model: &str) -> Result> { // Check routing rules first for rule in &self.routing_rules { if self.matches_pattern(&rule.model_pattern, model) { if let Some(provider) = self.registry.get(&rule.provider) { return Ok(provider); } } } // Fall back to default provider if let Some(default) = &self.default_provider { if let Some(provider) = self.registry.get(default) { return Ok(provider); } } // If no default, try to find any provider that has this model // This is a fallback for cases where routing isn't configured for provider_name in self.registry.list_providers() { if let Some(provider) = self.registry.get(&provider_name) { return Ok(provider); } } Err(crate::Error::Provider(anyhow::anyhow!( "No provider found for model: {}", model ))) } /// Check if a model name matches a pattern fn matches_pattern(&self, pattern: &str, model: &str) -> bool { // Simple pattern matching for now // Could be extended to support more complex patterns if pattern == "*" { return true; } if let Some(prefix) = pattern.strip_suffix('*') { return model.starts_with(prefix); } if let Some(suffix) = pattern.strip_prefix('*') { return model.ends_with(suffix); } pattern == model } /// Get routing configuration pub fn get_routing_rules(&self) -> &[RoutingRule] { &self.routing_rules } /// Get the default provider name pub fn get_default_provider(&self) -> Option<&str> { self.default_provider.as_deref() } } impl Default for Router { fn default() -> Self { Self::new() } } #[cfg(test)] mod tests { use super::*; #[test] fn test_pattern_matching() { let router = Router::new(); assert!(router.matches_pattern("*", "any-model")); assert!(router.matches_pattern("gpt*", "gpt-4")); assert!(router.matches_pattern("gpt*", "gpt-3.5-turbo")); assert!(!router.matches_pattern("gpt*", "claude-3")); assert!(router.matches_pattern("*:latest", "llama2:latest")); assert!(router.matches_pattern("exact-match", "exact-match")); assert!(!router.matches_pattern("exact-match", "different-model")); } }