154 lines
4.7 KiB
Rust
154 lines
4.7 KiB
Rust
//! Router for managing multiple providers and routing requests
|
|
|
|
use crate::{provider::*, types::*, Result};
|
|
use std::sync::Arc;
|
|
|
|
/// A router that can distribute requests across multiple providers
|
|
pub struct Router {
|
|
registry: ProviderRegistry,
|
|
routing_rules: Vec<RoutingRule>,
|
|
default_provider: Option<String>,
|
|
}
|
|
|
|
/// Maps model names matching a pattern to a named provider.
#[derive(Debug, Clone)]
pub struct RoutingRule {
    /// Pattern compared against the requested model name.
    pub model_pattern: String,
    /// Name of the provider that handles matching models.
    pub provider: String,
    /// Evaluation priority; rules with higher values are checked first.
    pub priority: u32,
}
|
|
|
|
impl Router {
|
|
/// Create a new router
|
|
pub fn new() -> Self {
|
|
Self {
|
|
registry: ProviderRegistry::new(),
|
|
routing_rules: Vec::new(),
|
|
default_provider: None,
|
|
}
|
|
}
|
|
|
|
/// Register a provider with the router
|
|
pub fn register_provider<P: Provider + 'static>(&mut self, provider: P) {
|
|
self.registry.register(provider);
|
|
}
|
|
|
|
/// Set the default provider
|
|
pub fn set_default_provider(&mut self, provider_name: String) {
|
|
self.default_provider = Some(provider_name);
|
|
}
|
|
|
|
/// Add a routing rule
|
|
pub fn add_routing_rule(&mut self, rule: RoutingRule) {
|
|
self.routing_rules.push(rule);
|
|
// Sort by priority (descending)
|
|
self.routing_rules
|
|
.sort_by(|a, b| b.priority.cmp(&a.priority));
|
|
}
|
|
|
|
/// Route a request to the appropriate provider
|
|
pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
|
|
let provider = self.find_provider_for_model(&request.model)?;
|
|
provider.chat(request).await
|
|
}
|
|
|
|
/// Route a streaming request to the appropriate provider
|
|
pub async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream> {
|
|
let provider = self.find_provider_for_model(&request.model)?;
|
|
provider.chat_stream(request).await
|
|
}
|
|
|
|
/// List all available models from all providers
|
|
pub async fn list_models(&self) -> Result<Vec<ModelInfo>> {
|
|
self.registry.list_all_models().await
|
|
}
|
|
|
|
/// Find the appropriate provider for a given model
|
|
fn find_provider_for_model(&self, model: &str) -> Result<Arc<dyn Provider>> {
|
|
// Check routing rules first
|
|
for rule in &self.routing_rules {
|
|
if self.matches_pattern(&rule.model_pattern, model) {
|
|
if let Some(provider) = self.registry.get(&rule.provider) {
|
|
return Ok(provider);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Fall back to default provider
|
|
if let Some(default) = &self.default_provider {
|
|
if let Some(provider) = self.registry.get(default) {
|
|
return Ok(provider);
|
|
}
|
|
}
|
|
|
|
// If no default, try to find any provider that has this model
|
|
// This is a fallback for cases where routing isn't configured
|
|
for provider_name in self.registry.list_providers() {
|
|
if let Some(provider) = self.registry.get(&provider_name) {
|
|
return Ok(provider);
|
|
}
|
|
}
|
|
|
|
Err(crate::Error::Provider(anyhow::anyhow!(
|
|
"No provider found for model: {}",
|
|
model
|
|
)))
|
|
}
|
|
|
|
/// Check if a model name matches a pattern
|
|
fn matches_pattern(&self, pattern: &str, model: &str) -> bool {
|
|
// Simple pattern matching for now
|
|
// Could be extended to support more complex patterns
|
|
if pattern == "*" {
|
|
return true;
|
|
}
|
|
|
|
if let Some(prefix) = pattern.strip_suffix('*') {
|
|
return model.starts_with(prefix);
|
|
}
|
|
|
|
if let Some(suffix) = pattern.strip_prefix('*') {
|
|
return model.ends_with(suffix);
|
|
}
|
|
|
|
pattern == model
|
|
}
|
|
|
|
/// Get routing configuration
|
|
pub fn get_routing_rules(&self) -> &[RoutingRule] {
|
|
&self.routing_rules
|
|
}
|
|
|
|
/// Get the default provider name
|
|
pub fn get_default_provider(&self) -> Option<&str> {
|
|
self.default_provider.as_deref()
|
|
}
|
|
}
|
|
|
|
impl Default for Router {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the match-all, prefix, suffix, and exact forms of `matches_pattern`.
    #[test]
    fn test_pattern_matching() {
        let router = Router::new();

        // (pattern, model name, whether the pattern should match)
        let cases = [
            ("*", "any-model", true),
            ("gpt*", "gpt-4", true),
            ("gpt*", "gpt-3.5-turbo", true),
            ("gpt*", "claude-3", false),
            ("*:latest", "llama2:latest", true),
            ("exact-match", "exact-match", true),
            ("exact-match", "different-model", false),
        ];

        for &(pattern, model, expected) in cases.iter() {
            assert_eq!(router.matches_pattern(pattern, model), expected);
        }
    }
}
|