Phase 40: Routing Engine + Policy
- RoutingEngine with RouteDecision (model_pattern → provider)
- config/routing.toml: rules, fallback chain, cost gating
- Per-provider usage tracking in the /v1/usage response
- 12 gateway tests green
This commit is contained in:
parent
e27a17e950
commit
55f8e0fe6e
61
config/routing.toml
Normal file
61
config/routing.toml
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
# Phase 40: Routing Engine Configuration
#
# Human-editable rules for model → provider routing.
# Matching order: first match wins, so list more specific patterns
# before broader ones.

# Fallback chain: if the primary provider fails, try these in order.
# NOTE: this key must stay ABOVE the first [[rule]] header. In TOML, any
# key written after a [[rule]] header belongs to that rule's table, not to
# the top level of the document.
fallback = ["ollama", "openrouter"]

[[rule]]
model_pattern = "gpt-4*"
provider = "openrouter"
max_tokens = 4096
temperature = 0.7

[[rule]]
model_pattern = "claude*"
provider = "claude"
max_tokens = 4096
temperature = 0.7

[[rule]]
model_pattern = "gemini*"
provider = "gemini"
max_tokens = 8192
temperature = 0.9

# "qwen3.5*" has to come before "qwen3*": both match qwen3.5 models and
# the first matching rule wins.
[[rule]]
model_pattern = "qwen3.5*"
provider = "ollama"
max_tokens = 4096
temperature = 0.3

[[rule]]
model_pattern = "qwen3*"
provider = "ollama"
max_tokens = 2048
temperature = 0.3

[[rule]]
model_pattern = "gpt-oss*"
provider = "ollama"
temperature = 0.1

# Catch-all: keep this rule last, it matches every model name.
[[rule]]
model_pattern = "*"
provider = "ollama"
temperature = 0.5

# Cost gating: price per provider, in cents per 1M tokens
[cost]
ollama = 0
openrouter = 15
claude = 15
gemini = 0

# Daily budget per provider (cents)
[daily_budget]
ollama = 0
openrouter = 1000
claude = 500
gemini = 0
|
||||||
@ -3,5 +3,6 @@ pub mod context;
|
|||||||
pub mod continuation;
|
pub mod continuation;
|
||||||
pub mod provider;
|
pub mod provider;
|
||||||
pub mod providers;
|
pub mod providers;
|
||||||
|
pub mod routing;
|
||||||
pub mod service;
|
pub mod service;
|
||||||
pub mod tree_split;
|
pub mod tree_split;
|
||||||
|
|||||||
94
crates/aibridge/src/routing.rs
Normal file
94
crates/aibridge/src/routing.rs
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct RoutingRule {
|
||||||
|
pub model_pattern: String,
|
||||||
|
pub provider: String,
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Default)]
|
||||||
|
pub struct RoutingEngine {
|
||||||
|
rules: Vec<RoutingRule>,
|
||||||
|
fallback_chain: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RoutingEngine {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_rules(mut self, rules: Vec<RoutingRule>) -> Self {
|
||||||
|
self.rules = rules;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn with_fallback(mut self, chain: Vec<String>) -> Self {
|
||||||
|
self.fallback_chain = chain;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn route(&self, model: &str) -> RouteDecision {
|
||||||
|
let lower = model.to_lowercase();
|
||||||
|
|
||||||
|
for rule in &self.rules {
|
||||||
|
if glob_match(&rule.model_pattern.to_lowercase(), &lower) {
|
||||||
|
return RouteDecision {
|
||||||
|
provider: rule.provider.clone(),
|
||||||
|
model: model.to_string(),
|
||||||
|
max_tokens: rule.max_tokens,
|
||||||
|
temperature: rule.temperature,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(first) = self.fallback_chain.first() {
|
||||||
|
RouteDecision {
|
||||||
|
provider: first.clone(),
|
||||||
|
model: model.to_string(),
|
||||||
|
max_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
RouteDecision {
|
||||||
|
provider: "ollama".to_string(),
|
||||||
|
model: model.to_string(),
|
||||||
|
max_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Outcome of a routing lookup: the chosen provider plus optional limits.
#[derive(Clone, Debug)]
pub struct RouteDecision {
    /// Name of the provider selected for the request.
    pub provider: String,
    /// Model name exactly as supplied by the caller.
    pub model: String,
    /// Token limit taken from the matching rule, if it set one.
    pub max_tokens: Option<u32>,
    /// Sampling temperature taken from the matching rule, if it set one.
    pub temperature: Option<f64>,
}
|
||||||
|
|
||||||
|
/// Returns `true` when `name` matches the glob `pattern`.
///
/// The only metacharacter is `*`, which matches any run of characters
/// (including none). A pattern without `*` must equal `name` exactly.
/// Matching is case-sensitive; callers that want case-insensitive matching
/// lowercase both arguments first.
///
/// Any number of wildcards is supported (`*gpt*`, `a*b*c`), and the literal
/// prefix and suffix may not overlap inside `name` (`ab*ba` does not match
/// `aba`).
fn glob_match(pattern: &str, name: &str) -> bool {
    // No wildcard at all: plain equality.
    let (prefix, rest) = match pattern.split_once('*') {
        Some(split) => split,
        None => return pattern == name,
    };

    // Everything before the first `*` is anchored at the start of `name`.
    let after_prefix = match name.strip_prefix(prefix) {
        Some(tail) => tail,
        None => return false,
    };

    // Everything after the last `*` is anchored at the end. Stripping it from
    // what remains after the prefix guarantees the two anchors cannot overlap.
    let (middle, suffix) = match rest.rsplit_once('*') {
        Some((middle, suffix)) => (middle, suffix),
        None => ("", rest),
    };
    let mut body = match after_prefix.strip_suffix(suffix) {
        Some(head) => head,
        None => return false,
    };

    // Literal segments between wildcards must appear in order; taking the
    // leftmost occurrence each time leaves the most room for later segments.
    for segment in middle.split('*') {
        match body.find(segment) {
            Some(at) => body = &body[at + segment.len()..],
            None => return false,
        }
    }
    true
}
|
||||||
|
|
||||||
|
impl Default for RoutingRule {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
model_pattern: "*".to_string(),
|
||||||
|
provider: "ollama".to_string(),
|
||||||
|
max_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -24,7 +24,7 @@ use axum::{
|
|||||||
Json,
|
Json,
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::sync::Arc;
|
use std::{collections::HashMap, sync::Arc};
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
@ -46,6 +46,16 @@ pub struct Usage {
|
|||||||
pub prompt_tokens: u64,
|
pub prompt_tokens: u64,
|
||||||
pub completion_tokens: u64,
|
pub completion_tokens: u64,
|
||||||
pub total_tokens: u64,
|
pub total_tokens: u64,
|
||||||
|
#[serde(default)]
|
||||||
|
pub by_provider: std::collections::HashMap<String, ProviderUsage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default, Clone, Serialize)]
|
||||||
|
pub struct ProviderUsage {
|
||||||
|
pub requests: u64,
|
||||||
|
pub prompt_tokens: u64,
|
||||||
|
pub completion_tokens: u64,
|
||||||
|
pub total_tokens: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn router(state: V1State) -> Router {
|
pub fn router(state: V1State) -> Router {
|
||||||
@ -135,18 +145,22 @@ async fn chat(
|
|||||||
let start_time = chrono::Utc::now();
|
let start_time = chrono::Utc::now();
|
||||||
let start_instant = std::time::Instant::now();
|
let start_instant = std::time::Instant::now();
|
||||||
|
|
||||||
let resp = match provider.as_str() {
|
let (resp, used_provider) = match provider.as_str() {
|
||||||
"ollama" | "local" | "" => ollama::chat(&state.ai_client, &req)
|
"ollama" | "local" | "" => {
|
||||||
.await
|
let r = ollama::chat(&state.ai_client, &req)
|
||||||
.map_err(|e| (StatusCode::BAD_GATEWAY, format!("ollama local: {e}")))?,
|
.await
|
||||||
|
.map_err(|e| (StatusCode::BAD_GATEWAY, format!("ollama local: {e}")))?;
|
||||||
|
(r, "ollama".to_string())
|
||||||
|
}
|
||||||
"ollama_cloud" | "cloud" => {
|
"ollama_cloud" | "cloud" => {
|
||||||
let key = state.ollama_cloud_key.as_deref().ok_or((
|
let key = state.ollama_cloud_key.as_deref().ok_or((
|
||||||
StatusCode::SERVICE_UNAVAILABLE,
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
"OLLAMA_CLOUD_KEY not configured — set env or populate /root/llm_team_config.json".to_string(),
|
"OLLAMA_CLOUD_KEY not configured".to_string(),
|
||||||
))?;
|
))?;
|
||||||
ollama_cloud::chat(key, &req)
|
let r = ollama_cloud::chat(key, &req)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| (StatusCode::BAD_GATEWAY, format!("ollama cloud: {e}")))?
|
.map_err(|e| (StatusCode::BAD_GATEWAY, format!("ollama cloud: {e}")))?;
|
||||||
|
(r, "ollama_cloud".to_string())
|
||||||
}
|
}
|
||||||
other => {
|
other => {
|
||||||
return Err((
|
return Err((
|
||||||
@ -168,7 +182,7 @@ async fn chat(
|
|||||||
.map(|c| c.message.content.clone())
|
.map(|c| c.message.content.clone())
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
lf.emit_chat(langfuse_trace::ChatTrace {
|
lf.emit_chat(langfuse_trace::ChatTrace {
|
||||||
provider: provider.clone(),
|
provider: used_provider.clone(),
|
||||||
model: resp.model.clone(),
|
model: resp.model.clone(),
|
||||||
input: req.messages.clone(),
|
input: req.messages.clone(),
|
||||||
output,
|
output,
|
||||||
@ -183,12 +197,19 @@ async fn chat(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Phase 40: per-provider usage tracking
|
||||||
{
|
{
|
||||||
let mut u = state.usage.write().await;
|
let mut u = state.usage.write().await;
|
||||||
u.requests += 1;
|
u.requests += 1;
|
||||||
u.prompt_tokens += resp.usage.prompt_tokens as u64;
|
u.prompt_tokens += resp.usage.prompt_tokens as u64;
|
||||||
u.completion_tokens += resp.usage.completion_tokens as u64;
|
u.completion_tokens += resp.usage.completion_tokens as u64;
|
||||||
u.total_tokens += resp.usage.total_tokens as u64;
|
u.total_tokens += resp.usage.total_tokens as u64;
|
||||||
|
|
||||||
|
let provider_usage = u.by_provider.entry(used_provider).or_default();
|
||||||
|
provider_usage.requests += 1;
|
||||||
|
provider_usage.prompt_tokens += resp.usage.prompt_tokens as u64;
|
||||||
|
provider_usage.completion_tokens += resp.usage.completion_tokens as u64;
|
||||||
|
provider_usage.total_tokens += resp.usage.total_tokens as u64;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Json(resp))
|
Ok(Json(resp))
|
||||||
|
|||||||
@ -339,6 +339,11 @@
|
|||||||
- `OllamaAdapter` — wraps existing AiClient
|
- `OllamaAdapter` — wraps existing AiClient
|
||||||
- `OpenRouterAdapter` — HTTP client to openrouter.ai
|
- `OpenRouterAdapter` — HTTP client to openrouter.ai
|
||||||
- `provider_key()` routing by model prefix (openrouter/* → OpenRouter)
|
- `provider_key()` routing by model prefix (openrouter/* → OpenRouter)
|
||||||
|
- [x] **Phase 40: Routing & Policy Engine** (2026-04-23)
|
||||||
|
- `RoutingEngine` with `RouteDecision` in aibridge::routing
|
||||||
|
- `config/routing.toml` — rules by model_pattern, fallback chain, cost gating
|
||||||
|
- Per-provider usage tracking: Usage.by_provider
|
||||||
|
- 12 gateway tests green, curl gates pass
|
||||||
- [ ] Fine-tuned domain models (Phase 25+)
|
- [ ] Fine-tuned domain models (Phase 25+)
|
||||||
- [ ] Multi-node query distribution (only if ceilings bite)
|
- [ ] Multi-node query distribution (only if ceilings bite)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user