Add App core struct with event-handling and initialization logic for TUI.
This commit is contained in:
155
crates/owlen-core/src/router.rs
Normal file
155
crates/owlen-core/src/router.rs
Normal file
@@ -0,0 +1,155 @@
|
||||
//! Router for managing multiple providers and routing requests
|
||||
|
||||
use crate::{provider::*, types::*, Result};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// A router that can distribute requests across multiple providers
|
||||
pub struct Router {
|
||||
registry: ProviderRegistry,
|
||||
routing_rules: Vec<RoutingRule>,
|
||||
default_provider: Option<String>,
|
||||
}
|
||||
|
||||
/// A rule for routing requests to specific providers
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RoutingRule {
|
||||
/// Pattern to match against model names
|
||||
pub model_pattern: String,
|
||||
/// Provider to route to
|
||||
pub provider: String,
|
||||
/// Priority (higher numbers are checked first)
|
||||
pub priority: u32,
|
||||
}
|
||||
|
||||
impl Router {
|
||||
/// Create a new router
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
registry: ProviderRegistry::new(),
|
||||
routing_rules: Vec::new(),
|
||||
default_provider: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a provider with the router
|
||||
pub fn register_provider<P: Provider + 'static>(&mut self, provider: P) {
|
||||
self.registry.register(provider);
|
||||
}
|
||||
|
||||
/// Set the default provider
|
||||
pub fn set_default_provider(&mut self, provider_name: String) {
|
||||
self.default_provider = Some(provider_name);
|
||||
}
|
||||
|
||||
/// Add a routing rule
|
||||
pub fn add_routing_rule(&mut self, rule: RoutingRule) {
|
||||
self.routing_rules.push(rule);
|
||||
// Sort by priority (descending)
|
||||
self.routing_rules
|
||||
.sort_by(|a, b| b.priority.cmp(&a.priority));
|
||||
}
|
||||
|
||||
/// Route a request to the appropriate provider
|
||||
pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
|
||||
let provider = self.find_provider_for_model(&request.model)?;
|
||||
provider.chat(request).await
|
||||
}
|
||||
|
||||
/// Route a streaming request to the appropriate provider
|
||||
pub async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream> {
|
||||
let provider = self.find_provider_for_model(&request.model)?;
|
||||
provider.chat_stream(request).await
|
||||
}
|
||||
|
||||
/// List all available models from all providers
|
||||
pub async fn list_models(&self) -> Result<Vec<ModelInfo>> {
|
||||
self.registry.list_all_models().await
|
||||
}
|
||||
|
||||
/// Find the appropriate provider for a given model
|
||||
fn find_provider_for_model(&self, model: &str) -> Result<Arc<dyn Provider>> {
|
||||
// Check routing rules first
|
||||
for rule in &self.routing_rules {
|
||||
if self.matches_pattern(&rule.model_pattern, model) {
|
||||
if let Some(provider) = self.registry.get(&rule.provider) {
|
||||
return Ok(provider);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default provider
|
||||
if let Some(default) = &self.default_provider {
|
||||
if let Some(provider) = self.registry.get(default) {
|
||||
return Ok(provider);
|
||||
}
|
||||
}
|
||||
|
||||
// If no default, try to find any provider that has this model
|
||||
// This is a fallback for cases where routing isn't configured
|
||||
for provider_name in self.registry.list_providers() {
|
||||
if let Some(provider) = self.registry.get(&provider_name) {
|
||||
return Ok(provider);
|
||||
}
|
||||
}
|
||||
|
||||
Err(crate::Error::Provider(anyhow::anyhow!(
|
||||
"No provider found for model: {}",
|
||||
model
|
||||
)))
|
||||
}
|
||||
|
||||
/// Check if a model name matches a pattern
|
||||
fn matches_pattern(&self, pattern: &str, model: &str) -> bool {
|
||||
// Simple pattern matching for now
|
||||
// Could be extended to support more complex patterns
|
||||
if pattern == "*" {
|
||||
return true;
|
||||
}
|
||||
|
||||
if pattern.ends_with('*') {
|
||||
let prefix = &pattern[..pattern.len() - 1];
|
||||
return model.starts_with(prefix);
|
||||
}
|
||||
|
||||
if pattern.starts_with('*') {
|
||||
let suffix = &pattern[1..];
|
||||
return model.ends_with(suffix);
|
||||
}
|
||||
|
||||
pattern == model
|
||||
}
|
||||
|
||||
/// Get routing configuration
|
||||
pub fn get_routing_rules(&self) -> &[RoutingRule] {
|
||||
&self.routing_rules
|
||||
}
|
||||
|
||||
/// Get the default provider name
|
||||
pub fn get_default_provider(&self) -> Option<&str> {
|
||||
self.default_provider.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Router {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_pattern_matching() {
|
||||
let router = Router::new();
|
||||
|
||||
assert!(router.matches_pattern("*", "any-model"));
|
||||
assert!(router.matches_pattern("gpt*", "gpt-4"));
|
||||
assert!(router.matches_pattern("gpt*", "gpt-3.5-turbo"));
|
||||
assert!(!router.matches_pattern("gpt*", "claude-3"));
|
||||
assert!(router.matches_pattern("*:latest", "llama2:latest"));
|
||||
assert!(router.matches_pattern("exact-match", "exact-match"));
|
||||
assert!(!router.matches_pattern("exact-match", "different-model"));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user