- RoutingEngine with RouteDecision (model_pattern → provider)
- config/routing.toml: rules, fallback chain, cost gating
- Per-provider Usage tracking in /v1/usage response
- 12 gateway tests green
94 lines
2.4 KiB
Rust
94 lines
2.4 KiB
Rust
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
|
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
|
pub struct RoutingRule {
|
|
pub model_pattern: String,
|
|
pub provider: String,
|
|
pub max_tokens: Option<u32>,
|
|
pub temperature: Option<f64>,
|
|
}
|
|
|
|
#[derive(Clone, Default)]
|
|
pub struct RoutingEngine {
|
|
rules: Vec<RoutingRule>,
|
|
fallback_chain: Vec<String>,
|
|
}
|
|
|
|
impl RoutingEngine {
|
|
pub fn new() -> Self {
|
|
Self::default()
|
|
}
|
|
|
|
pub fn with_rules(mut self, rules: Vec<RoutingRule>) -> Self {
|
|
self.rules = rules;
|
|
self
|
|
}
|
|
|
|
pub fn with_fallback(mut self, chain: Vec<String>) -> Self {
|
|
self.fallback_chain = chain;
|
|
self
|
|
}
|
|
|
|
pub fn route(&self, model: &str) -> RouteDecision {
|
|
let lower = model.to_lowercase();
|
|
|
|
for rule in &self.rules {
|
|
if glob_match(&rule.model_pattern.to_lowercase(), &lower) {
|
|
return RouteDecision {
|
|
provider: rule.provider.clone(),
|
|
model: model.to_string(),
|
|
max_tokens: rule.max_tokens,
|
|
temperature: rule.temperature,
|
|
};
|
|
}
|
|
}
|
|
|
|
if let Some(first) = self.fallback_chain.first() {
|
|
RouteDecision {
|
|
provider: first.clone(),
|
|
model: model.to_string(),
|
|
max_tokens: None,
|
|
temperature: None,
|
|
}
|
|
} else {
|
|
RouteDecision {
|
|
provider: "ollama".to_string(),
|
|
model: model.to_string(),
|
|
max_tokens: None,
|
|
temperature: None,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// The outcome of routing a single model request.
///
/// Derives `PartialEq` so that decisions can be compared directly, for
/// example in assertions.
#[derive(Clone, Debug, PartialEq)]
pub struct RouteDecision {
    /// Name of the provider chosen for the request.
    pub provider: String,
    /// Model name exactly as it was requested (original casing preserved).
    pub model: String,
    /// Token limit taken from the matching rule; `None` when no rule matched
    /// or the rule sets no limit.
    pub max_tokens: Option<u32>,
    /// Temperature taken from the matching rule; `None` when no rule matched
    /// or the rule sets no temperature.
    pub temperature: Option<f64>,
}
|
|
|
|
/// Returns `true` when `name` matches the glob `pattern`.
///
/// The only metacharacter is `*`, which matches any run of characters
/// (including none). Any number of wildcards is supported. A pattern without
/// `*` must equal `name` exactly. Matching is case-sensitive; callers that
/// want case-insensitive matching should normalise both arguments first.
fn glob_match(pattern: &str, name: &str) -> bool {
    if !pattern.contains('*') {
        return pattern == name;
    }

    // A pattern containing `*` always splits into at least two segments, so
    // both ends exist; `unwrap_or` merely avoids a panic path.
    let mut segments = pattern.split('*');
    let prefix = segments.next().unwrap_or("");
    let suffix = segments.next_back().unwrap_or("");

    // The length guard stops the prefix and suffix from overlapping inside
    // `name` (e.g. pattern `ab*b` must not match `ab`).
    if name.len() < prefix.len() + suffix.len()
        || !name.starts_with(prefix)
        || !name.ends_with(suffix)
    {
        return false;
    }

    // Remaining segments must appear in order between prefix and suffix.
    // Taking the leftmost occurrence each time is sufficient because `*` can
    // absorb whatever is left over.
    let mut middle = &name[prefix.len()..name.len() - suffix.len()];
    for segment in segments {
        match middle.find(segment) {
            Some(idx) => middle = &middle[idx + segment.len()..],
            None => return false,
        }
    }
    true
}
|
|
|
|
impl Default for RoutingRule {
|
|
fn default() -> Self {
|
|
Self {
|
|
model_pattern: "*".to_string(),
|
|
provider: "ollama".to_string(),
|
|
max_tokens: None,
|
|
temperature: None,
|
|
}
|
|
}
|
|
} |