Files
owlen/crates/owlen-core/src/router.rs

154 lines
4.7 KiB
Rust

//! Router for managing multiple providers and routing requests
use crate::{provider::*, types::*, Result};
use std::sync::Arc;
/// Dispatches chat requests to one of several registered providers.
///
/// Provider selection is driven by an ordered list of [`RoutingRule`]s, an
/// optional default provider name, and finally a last-resort fallback over the
/// registry itself.
pub struct Router {
    /// Rules consulted first, kept in evaluation order (highest priority first).
    routing_rules: Vec<RoutingRule>,
    /// Name of the provider to use when no routing rule yields a registered provider.
    default_provider: Option<String>,
    /// Every provider known to this router, looked up by name.
    registry: ProviderRegistry,
}
/// A rule for routing requests to specific providers.
///
/// Rules are plain data, so they can be compared, hashed and deduplicated.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RoutingRule {
    /// Pattern to match against model names, e.g. `"*"`, `"gpt*"`, `"*:latest"`
    /// or an exact model name.
    pub model_pattern: String,
    /// Name under which the target provider is registered.
    pub provider: String,
    /// Priority: higher numbers are checked first. Rules with equal priority are
    /// checked in the order they were added.
    pub priority: u32,
}
impl Router {
/// Create a new router
pub fn new() -> Self {
Self {
registry: ProviderRegistry::new(),
routing_rules: Vec::new(),
default_provider: None,
}
}
/// Register a provider with the router
pub fn register_provider<P: Provider + 'static>(&mut self, provider: P) {
self.registry.register(provider);
}
/// Set the default provider
pub fn set_default_provider(&mut self, provider_name: String) {
self.default_provider = Some(provider_name);
}
/// Add a routing rule
pub fn add_routing_rule(&mut self, rule: RoutingRule) {
self.routing_rules.push(rule);
// Sort by priority (descending)
self.routing_rules
.sort_by(|a, b| b.priority.cmp(&a.priority));
}
/// Route a request to the appropriate provider
pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
let provider = self.find_provider_for_model(&request.model)?;
provider.chat(request).await
}
/// Route a streaming request to the appropriate provider
pub async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream> {
let provider = self.find_provider_for_model(&request.model)?;
provider.chat_stream(request).await
}
/// List all available models from all providers
pub async fn list_models(&self) -> Result<Vec<ModelInfo>> {
self.registry.list_all_models().await
}
/// Find the appropriate provider for a given model
fn find_provider_for_model(&self, model: &str) -> Result<Arc<dyn Provider>> {
// Check routing rules first
for rule in &self.routing_rules {
if self.matches_pattern(&rule.model_pattern, model) {
if let Some(provider) = self.registry.get(&rule.provider) {
return Ok(provider);
}
}
}
// Fall back to default provider
if let Some(default) = &self.default_provider {
if let Some(provider) = self.registry.get(default) {
return Ok(provider);
}
}
// If no default, try to find any provider that has this model
// This is a fallback for cases where routing isn't configured
for provider_name in self.registry.list_providers() {
if let Some(provider) = self.registry.get(&provider_name) {
return Ok(provider);
}
}
Err(crate::Error::Provider(anyhow::anyhow!(
"No provider found for model: {}",
model
)))
}
/// Check if a model name matches a pattern
fn matches_pattern(&self, pattern: &str, model: &str) -> bool {
// Simple pattern matching for now
// Could be extended to support more complex patterns
if pattern == "*" {
return true;
}
if let Some(prefix) = pattern.strip_suffix('*') {
return model.starts_with(prefix);
}
if let Some(suffix) = pattern.strip_prefix('*') {
return model.ends_with(suffix);
}
pattern == model
}
/// Get routing configuration
pub fn get_routing_rules(&self) -> &[RoutingRule] {
&self.routing_rules
}
/// Get the default provider name
pub fn get_default_provider(&self) -> Option<&str> {
self.default_provider.as_deref()
}
}
impl Default for Router {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Table of (pattern, model, expected match) cases for `Router::matches_pattern`.
    const PATTERN_CASES: [(&str, &str, bool); 7] = [
        ("*", "any-model", true),
        ("gpt*", "gpt-4", true),
        ("gpt*", "gpt-3.5-turbo", true),
        ("gpt*", "claude-3", false),
        ("*:latest", "llama2:latest", true),
        ("exact-match", "exact-match", true),
        ("exact-match", "different-model", false),
    ];

    #[test]
    fn test_pattern_matching() {
        let router = Router::new();
        for (pattern, model, expected) in PATTERN_CASES.iter().copied() {
            assert_eq!(
                router.matches_pattern(pattern, model),
                expected,
                "pattern {:?} vs model {:?}",
                pattern,
                model
            );
        }
    }
}