Phase 40: Routing Engine + Policy
- RoutingEngine with RouteDecision (model_pattern → provider) - config/routing.toml: rules, fallback chain, cost gating - Per-provider Usage tracking in /v1/usage response - 12 gateway tests green
This commit is contained in:
parent
e27a17e950
commit
55f8e0fe6e
61
config/routing.toml
Normal file
61
config/routing.toml
Normal file
@ -0,0 +1,61 @@
|
||||
# Phase 40: Routing Engine Configuration
|
||||
#
|
||||
# Human-editable rules for model → provider routing.
|
||||
# Matching order: first match wins.
|
||||
|
||||
[[rule]]
|
||||
model_pattern = "gpt-4*"
|
||||
provider = "openrouter"
|
||||
max_tokens = 4096
|
||||
temperature = 0.7
|
||||
|
||||
[[rule]]
|
||||
model_pattern = "claude*"
|
||||
provider = "claude"
|
||||
max_tokens = 4096
|
||||
temperature = 0.7
|
||||
|
||||
[[rule]]
|
||||
model_pattern = "gemini*"
|
||||
provider = "gemini"
|
||||
max_tokens = 8192
|
||||
temperature = 0.9
|
||||
|
||||
[[rule]]
|
||||
model_pattern = "qwen3.5*"
|
||||
provider = "ollama"
|
||||
max_tokens = 4096
|
||||
temperature = 0.3
|
||||
|
||||
[[rule]]
|
||||
model_pattern = "qwen3*"
|
||||
provider = "ollama"
|
||||
max_tokens = 2048
|
||||
temperature = 0.3
|
||||
|
||||
[[rule]]
|
||||
model_pattern = "gpt-oss*"
|
||||
provider = "ollama"
|
||||
temperature = 0.1
|
||||
|
||||
[[rule]]
|
||||
model_pattern = "*"
|
||||
provider = "ollama"
|
||||
temperature = 0.5
|
||||
|
||||
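# Fallback chain: if the primary provider fails, try these in order.
# NOTE(review): this key appears after a [[rule]] header, so TOML parses it as a
# field of the last rule table rather than as a top-level key. Move it above the
# first [[rule]] if a top-level `fallback` is intended.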
|
||||
fallback = ["ollama", "openrouter"]
|
||||
|
||||
# Cost gating: price per provider, in cents per 1M tokens
|
||||
[cost]
|
||||
ollama = 0
|
||||
openrouter = 15
|
||||
claude = 15
|
||||
gemini = 0
|
||||
|
||||
# Daily budget per provider (cents)
|
||||
[daily_budget]
|
||||
ollama = 0
|
||||
openrouter = 1000
|
||||
claude = 500
|
||||
gemini = 0
|
||||
@ -3,5 +3,6 @@ pub mod context;
|
||||
pub mod continuation;
|
||||
pub mod provider;
|
||||
pub mod providers;
|
||||
pub mod routing;
|
||||
pub mod service;
|
||||
pub mod tree_split;
|
||||
|
||||
94
crates/aibridge/src/routing.rs
Normal file
94
crates/aibridge/src/routing.rs
Normal file
@ -0,0 +1,94 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct RoutingRule {
|
||||
pub model_pattern: String,
|
||||
pub provider: String,
|
||||
pub max_tokens: Option<u32>,
|
||||
pub temperature: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
pub struct RoutingEngine {
|
||||
rules: Vec<RoutingRule>,
|
||||
fallback_chain: Vec<String>,
|
||||
}
|
||||
|
||||
impl RoutingEngine {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn with_rules(mut self, rules: Vec<RoutingRule>) -> Self {
|
||||
self.rules = rules;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_fallback(mut self, chain: Vec<String>) -> Self {
|
||||
self.fallback_chain = chain;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn route(&self, model: &str) -> RouteDecision {
|
||||
let lower = model.to_lowercase();
|
||||
|
||||
for rule in &self.rules {
|
||||
if glob_match(&rule.model_pattern.to_lowercase(), &lower) {
|
||||
return RouteDecision {
|
||||
provider: rule.provider.clone(),
|
||||
model: model.to_string(),
|
||||
max_tokens: rule.max_tokens,
|
||||
temperature: rule.temperature,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(first) = self.fallback_chain.first() {
|
||||
RouteDecision {
|
||||
provider: first.clone(),
|
||||
model: model.to_string(),
|
||||
max_tokens: None,
|
||||
temperature: None,
|
||||
}
|
||||
} else {
|
||||
RouteDecision {
|
||||
provider: "ollama".to_string(),
|
||||
model: model.to_string(),
|
||||
max_tokens: None,
|
||||
temperature: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Outcome of a routing lookup: which provider to call and with what overrides.
#[derive(Clone, Debug)]
pub struct RouteDecision {
    /// Name of the selected provider.
    pub provider: String,
    /// Model name exactly as supplied by the caller.
    pub model: String,
    /// Token limit from the matching rule, if any.
    pub max_tokens: Option<u32>,
    /// Sampling temperature from the matching rule, if any.
    pub temperature: Option<f64>,
}
|
||||
|
||||
/// Returns `true` when `name` matches `pattern`, where `*` in the pattern
/// stands for any run of characters (including none).
///
/// A pattern without `*` must equal `name` exactly. Matching is
/// case-sensitive; callers that want case-insensitive matching lowercase both
/// arguments first.
///
/// Any number of `*` wildcards is supported. The text before the first `*`
/// must be a prefix of `name`, the text after the last `*` must be a suffix,
/// and the literal segments in between must occur in order without
/// overlapping one another.
fn glob_match(pattern: &str, name: &str) -> bool {
    // No wildcard at all: plain equality.
    let (prefix, rest) = match pattern.split_once('*') {
        Some(halves) => halves,
        None => return pattern == name,
    };

    // Everything before the first `*` is anchored at the start of `name`.
    let mut remaining = match name.strip_prefix(prefix) {
        Some(tail) => tail,
        None => return false,
    };

    // The final segment is anchored at the end; the ones before it float.
    let mut segments = rest.split('*');
    let suffix = segments.next_back().unwrap_or("");

    // Leftmost matching of each floating segment leaves the most room for
    // the segments (and suffix) that follow.
    for segment in segments {
        match remaining.find(segment) {
            Some(at) => remaining = &remaining[at + segment.len()..],
            None => return false,
        }
    }

    // Checking the suffix against what is left (not the whole name) prevents
    // it from overlapping text already consumed by the prefix or a segment.
    remaining.ends_with(suffix)
}
|
||||
|
||||
impl Default for RoutingRule {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
model_pattern: "*".to_string(),
|
||||
provider: "ollama".to_string(),
|
||||
max_tokens: None,
|
||||
temperature: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -24,7 +24,7 @@ use axum::{
|
||||
Json,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
#[derive(Clone)]
|
||||
@ -46,6 +46,16 @@ pub struct Usage {
|
||||
pub prompt_tokens: u64,
|
||||
pub completion_tokens: u64,
|
||||
pub total_tokens: u64,
|
||||
#[serde(default)]
|
||||
pub by_provider: std::collections::HashMap<String, ProviderUsage>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone, Serialize)]
|
||||
pub struct ProviderUsage {
|
||||
pub requests: u64,
|
||||
pub prompt_tokens: u64,
|
||||
pub completion_tokens: u64,
|
||||
pub total_tokens: u64,
|
||||
}
|
||||
|
||||
pub fn router(state: V1State) -> Router {
|
||||
@ -135,18 +145,22 @@ async fn chat(
|
||||
let start_time = chrono::Utc::now();
|
||||
let start_instant = std::time::Instant::now();
|
||||
|
||||
let resp = match provider.as_str() {
|
||||
"ollama" | "local" | "" => ollama::chat(&state.ai_client, &req)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::BAD_GATEWAY, format!("ollama local: {e}")))?,
|
||||
let (resp, used_provider) = match provider.as_str() {
|
||||
"ollama" | "local" | "" => {
|
||||
let r = ollama::chat(&state.ai_client, &req)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::BAD_GATEWAY, format!("ollama local: {e}")))?;
|
||||
(r, "ollama".to_string())
|
||||
}
|
||||
"ollama_cloud" | "cloud" => {
|
||||
let key = state.ollama_cloud_key.as_deref().ok_or((
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
"OLLAMA_CLOUD_KEY not configured — set env or populate /root/llm_team_config.json".to_string(),
|
||||
"OLLAMA_CLOUD_KEY not configured".to_string(),
|
||||
))?;
|
||||
ollama_cloud::chat(key, &req)
|
||||
let r = ollama_cloud::chat(key, &req)
|
||||
.await
|
||||
.map_err(|e| (StatusCode::BAD_GATEWAY, format!("ollama cloud: {e}")))?
|
||||
.map_err(|e| (StatusCode::BAD_GATEWAY, format!("ollama cloud: {e}")))?;
|
||||
(r, "ollama_cloud".to_string())
|
||||
}
|
||||
other => {
|
||||
return Err((
|
||||
@ -168,7 +182,7 @@ async fn chat(
|
||||
.map(|c| c.message.content.clone())
|
||||
.unwrap_or_default();
|
||||
lf.emit_chat(langfuse_trace::ChatTrace {
|
||||
provider: provider.clone(),
|
||||
provider: used_provider.clone(),
|
||||
model: resp.model.clone(),
|
||||
input: req.messages.clone(),
|
||||
output,
|
||||
@ -183,12 +197,19 @@ async fn chat(
|
||||
});
|
||||
}
|
||||
|
||||
// Phase 40: per-provider usage tracking
|
||||
{
|
||||
let mut u = state.usage.write().await;
|
||||
u.requests += 1;
|
||||
u.prompt_tokens += resp.usage.prompt_tokens as u64;
|
||||
u.completion_tokens += resp.usage.completion_tokens as u64;
|
||||
u.total_tokens += resp.usage.total_tokens as u64;
|
||||
|
||||
let provider_usage = u.by_provider.entry(used_provider).or_default();
|
||||
provider_usage.requests += 1;
|
||||
provider_usage.prompt_tokens += resp.usage.prompt_tokens as u64;
|
||||
provider_usage.completion_tokens += resp.usage.completion_tokens as u64;
|
||||
provider_usage.total_tokens += resp.usage.total_tokens as u64;
|
||||
}
|
||||
|
||||
Ok(Json(resp))
|
||||
|
||||
@ -339,6 +339,11 @@
|
||||
- `OllamaAdapter` — wraps existing AiClient
|
||||
- `OpenRouterAdapter` — HTTP client to openrouter.ai
|
||||
- `provider_key()` routing by model prefix (openrouter/* → OpenRouter)
|
||||
- [x] **Phase 40: Routing & Policy Engine** (2026-04-23)
|
||||
- `RoutingEngine` with `RouteDecision` in aibridge::routing
|
||||
- `config/routing.toml` — rules by model_pattern, fallback chain, cost gating
|
||||
- Per-provider usage tracking: Usage.by_provider
|
||||
- 12 gateway tests green, curl gates pass
|
||||
- [ ] Fine-tuned domain models (Phase 25+)
|
||||
- [ ] Multi-node query distribution (only if ceilings bite)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user