use serde::{Deserialize, Serialize}; use std::collections::HashMap; #[derive(Clone, Debug, Deserialize, Serialize)] pub struct RoutingRule { pub model_pattern: String, pub provider: String, pub max_tokens: Option, pub temperature: Option, } #[derive(Clone, Default)] pub struct RoutingEngine { rules: Vec, fallback_chain: Vec, } impl RoutingEngine { pub fn new() -> Self { Self::default() } pub fn with_rules(mut self, rules: Vec) -> Self { self.rules = rules; self } pub fn with_fallback(mut self, chain: Vec) -> Self { self.fallback_chain = chain; self } pub fn route(&self, model: &str) -> RouteDecision { let lower = model.to_lowercase(); for rule in &self.rules { if glob_match(&rule.model_pattern.to_lowercase(), &lower) { return RouteDecision { provider: rule.provider.clone(), model: model.to_string(), max_tokens: rule.max_tokens, temperature: rule.temperature, }; } } if let Some(first) = self.fallback_chain.first() { RouteDecision { provider: first.clone(), model: model.to_string(), max_tokens: None, temperature: None, } } else { RouteDecision { provider: "ollama".to_string(), model: model.to_string(), max_tokens: None, temperature: None, } } } } #[derive(Clone, Debug)] pub struct RouteDecision { pub provider: String, pub model: String, pub max_tokens: Option, pub temperature: Option, } fn glob_match(pattern: &str, name: &str) -> bool { if pattern.contains('*') { let parts: Vec<&str> = pattern.split('*').collect(); if parts.len() == 2 { return name.starts_with(parts[0]) && name.ends_with(parts[1]); } else if parts.len() == 1 { return name.starts_with(parts[0]) || name.ends_with(parts[1]); } } pattern == name } impl Default for RoutingRule { fn default() -> Self { Self { model_pattern: "*".to_string(), provider: "ollama".to_string(), max_tokens: None, temperature: None, } } }