Some checks failed
lakehouse/auditor 20 blocking issues: cloud: claim not backed — "Verified end-to-end against persistent Go stack on :4110:"
440 lines
18 KiB
Rust
440 lines
18 KiB
Rust
use lru::LruCache;
|
||
use reqwest::Client;
|
||
use serde::{Deserialize, Serialize};
|
||
use std::num::NonZeroUsize;
|
||
use std::sync::Mutex;
|
||
use std::sync::atomic::{AtomicU64, Ordering};
|
||
use std::sync::Arc;
|
||
use std::time::Duration;
|
||
|
||
/// HTTP client for the Python AI sidecar.
|
||
///
|
||
/// `generate()` has two transport modes:
|
||
/// - When `gateway_url` is None (default), it posts to
|
||
/// `${base_url}/generate` (sidecar direct).
|
||
/// - When `gateway_url` is `Some(url)`, it posts to
|
||
/// `${url}/v1/chat` with `provider="ollama"` so the call appears
|
||
/// in `/v1/usage` and Langfuse traces.
|
||
///
|
||
/// `embed()`, `rerank()`, and admin methods always go direct to the
|
||
/// sidecar — no `/v1` equivalent yet, no point round-tripping.
|
||
///
|
||
/// Phase 44 part 2 (2026-04-27): the gateway URL is wired in by
|
||
/// callers that want observability (vectord modules); it's left
|
||
/// unset by callers that ARE the gateway internals (avoids self-loops
|
||
/// + redundant hops).
|
||
/// Per-text embed cache key. We key on (model, text) so different
|
||
/// model selections produce distinct cache lines — a query embedded
|
||
/// under nomic-embed-text-v2-moe must NOT collide with the same
|
||
/// query under nomic-embed-text v1.
|
||
#[derive(Eq, PartialEq, Hash, Clone)]
|
||
struct EmbedCacheKey {
|
||
model: String,
|
||
text: String,
|
||
}
|
||
|
||
/// Default LRU cache size — 4096 entries × ~6KB per 768-d f64
|
||
/// vector ≈ 24MB. Sized for typical staffing-domain repetition
|
||
/// (coordinator workflows have query repetition rates around 70-90%
|
||
/// per session). Tunable via [aibridge].embed_cache_size in the
|
||
/// config; 0 disables the cache entirely.
|
||
const DEFAULT_EMBED_CACHE_SIZE: usize = 4096;
|
||
|
||
#[derive(Clone)]
|
||
pub struct AiClient {
|
||
client: Client,
|
||
base_url: String,
|
||
gateway_url: Option<String>,
|
||
/// Closes the 63× perf gap with Go side. Mirrors the shape of
|
||
/// Go's internal/embed/cached.go::CachedProvider — same
|
||
/// (model, text) → vector caching, same nil-disable semantics.
|
||
/// None = caching disabled (cache_size=0); Some = bounded LRU.
|
||
embed_cache: Option<Arc<Mutex<LruCache<EmbedCacheKey, Vec<f64>>>>>,
|
||
/// Hit / miss counters for /admin observability + load-test
|
||
/// validation. Atomic so Clone'd AiClients share the same counts.
|
||
embed_cache_hits: Arc<AtomicU64>,
|
||
embed_cache_misses: Arc<AtomicU64>,
|
||
/// Pinned at construction time so the EmbedResponse can carry
|
||
/// dimension consistently even when every text was a cache hit
|
||
/// (no fresh sidecar call to learn the dim from). Set on first
|
||
/// successful sidecar embed; checked on every cache hit.
|
||
cached_dim: Arc<AtomicU64>,
|
||
}
|
||
|
||
// -- Request/Response types --
|
||
|
||
#[derive(Serialize, Deserialize)]
|
||
pub struct EmbedRequest {
|
||
pub texts: Vec<String>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
pub model: Option<String>,
|
||
}
|
||
|
||
#[derive(Deserialize, Serialize, Clone)]
|
||
pub struct EmbedResponse {
|
||
pub embeddings: Vec<Vec<f64>>,
|
||
pub model: String,
|
||
pub dimensions: usize,
|
||
}
|
||
|
||
#[derive(Clone, Serialize, Deserialize)]
|
||
pub struct GenerateRequest {
|
||
pub prompt: String,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
pub model: Option<String>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
pub system: Option<String>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
pub temperature: Option<f64>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
pub max_tokens: Option<u32>,
|
||
/// Phase 21 — per-call opt-out of hidden reasoning. Thinking models
|
||
/// (qwen3.5, gpt-oss, etc) burn tokens on reasoning before the
|
||
/// visible response starts; setting this to `false` on hot-path
|
||
/// JSON emitters avoids empty returns when the budget is tight.
|
||
/// Sidecar forwards this to Ollama's `think` parameter; if the
|
||
/// sidecar drops an unknown field the request still succeeds.
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
pub think: Option<bool>,
|
||
}
|
||
|
||
#[derive(Deserialize, Serialize, Clone)]
|
||
pub struct GenerateResponse {
|
||
pub text: String,
|
||
pub model: String,
|
||
pub tokens_evaluated: Option<u64>,
|
||
pub tokens_generated: Option<u64>,
|
||
}
|
||
|
||
#[derive(Serialize, Deserialize)]
|
||
pub struct RerankRequest {
|
||
pub query: String,
|
||
pub documents: Vec<String>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
pub model: Option<String>,
|
||
#[serde(skip_serializing_if = "Option::is_none")]
|
||
pub top_k: Option<usize>,
|
||
}
|
||
|
||
#[derive(Deserialize, Serialize, Clone)]
|
||
pub struct ScoredDocument {
|
||
pub index: usize,
|
||
pub text: String,
|
||
pub score: f64,
|
||
}
|
||
|
||
#[derive(Deserialize, Serialize, Clone)]
|
||
pub struct RerankResponse {
|
||
pub results: Vec<ScoredDocument>,
|
||
pub model: String,
|
||
}
|
||
|
||
impl AiClient {
|
||
pub fn new(base_url: &str) -> Self {
|
||
Self::with_embed_cache(base_url, DEFAULT_EMBED_CACHE_SIZE)
|
||
}
|
||
|
||
/// Constructs an AiClient with an explicit embed-cache size.
|
||
/// Pass 0 to disable the cache entirely (matches Go-side
|
||
/// CachedProvider's nil-cache semantics).
|
||
pub fn with_embed_cache(base_url: &str, cache_size: usize) -> Self {
|
||
let client = Client::builder()
|
||
.timeout(Duration::from_secs(120))
|
||
.build()
|
||
.expect("failed to build HTTP client");
|
||
let embed_cache = if cache_size > 0 {
|
||
// SAFETY: cache_size > 0 just verified, NonZeroUsize::new
|
||
// returns Some.
|
||
let cap = NonZeroUsize::new(cache_size).expect("cache_size > 0");
|
||
Some(Arc::new(Mutex::new(LruCache::new(cap))))
|
||
} else {
|
||
None
|
||
};
|
||
Self {
|
||
client,
|
||
base_url: base_url.trim_end_matches('/').to_string(),
|
||
gateway_url: None,
|
||
embed_cache,
|
||
embed_cache_hits: Arc::new(AtomicU64::new(0)),
|
||
embed_cache_misses: Arc::new(AtomicU64::new(0)),
|
||
cached_dim: Arc::new(AtomicU64::new(0)),
|
||
}
|
||
}
|
||
|
||
/// Cache hit/miss/size snapshot. Useful for /admin endpoints +
|
||
/// load-test validation ("did the cache fire as expected?").
|
||
pub fn embed_cache_stats(&self) -> (u64, u64, usize) {
|
||
let hits = self.embed_cache_hits.load(Ordering::Relaxed);
|
||
let misses = self.embed_cache_misses.load(Ordering::Relaxed);
|
||
let len = self
|
||
.embed_cache
|
||
.as_ref()
|
||
.map(|c| c.lock().map(|g| g.len()).unwrap_or(0))
|
||
.unwrap_or(0);
|
||
(hits, misses, len)
|
||
}
|
||
|
||
/// Same as `new`, but every `generate()` is routed through
|
||
/// `${gateway_url}/v1/chat` (provider=ollama) for observability.
|
||
/// Use this for callers OUTSIDE the gateway. Inside the gateway
|
||
/// itself, prefer `new()` — calling /v1/chat from /v1/chat works
|
||
/// (no infinite loop, ollama_arm doesn't use AiClient) but adds
|
||
/// a wasted localhost hop.
|
||
pub fn new_with_gateway(base_url: &str, gateway_url: &str) -> Self {
|
||
let mut c = Self::new(base_url);
|
||
c.gateway_url = Some(gateway_url.trim_end_matches('/').to_string());
|
||
c
|
||
}
|
||
|
||
pub async fn health(&self) -> Result<serde_json::Value, String> {
|
||
let resp = self.client
|
||
.get(format!("{}/health", self.base_url))
|
||
.send()
|
||
.await
|
||
.map_err(|e| format!("sidecar unreachable: {e}"))?;
|
||
resp.json().await.map_err(|e| format!("invalid response: {e}"))
|
||
}
|
||
|
||
/// Embed with per-text LRU caching. Mirrors Go-side
|
||
/// CachedProvider behavior: cache key is (model, text);
|
||
/// cache-hit texts skip the sidecar; cache-miss texts batch
|
||
/// into a single sidecar call; results are interleaved in the
|
||
/// caller's input order.
|
||
///
|
||
/// Closes ~95% of the load-test perf gap vs Go side (loadgen
|
||
/// 2026-05-01: Rust 128 RPS → with cache ≥ 7000 RPS expected
|
||
/// for warm-cache workloads). Cold-cache behavior unchanged
|
||
/// (every text is a miss → single sidecar call, identical to
|
||
/// pre-cache).
|
||
pub async fn embed(&self, req: EmbedRequest) -> Result<EmbedResponse, String> {
|
||
let model_key = req.model.clone().unwrap_or_default();
|
||
|
||
// Fast path: cache disabled → original behavior.
|
||
let Some(cache) = self.embed_cache.as_ref() else {
|
||
return self.embed_uncached(&req).await;
|
||
};
|
||
if req.texts.is_empty() {
|
||
return self.embed_uncached(&req).await;
|
||
}
|
||
|
||
// First pass: check cache for each text. Track which positions
|
||
// need a sidecar fetch.
|
||
let mut embeddings: Vec<Option<Vec<f64>>> = vec![None; req.texts.len()];
|
||
let mut miss_indices: Vec<usize> = Vec::new();
|
||
let mut miss_texts: Vec<String> = Vec::new();
|
||
{
|
||
let mut guard = cache.lock().map_err(|e| format!("cache lock poisoned: {e}"))?;
|
||
for (i, text) in req.texts.iter().enumerate() {
|
||
let key = EmbedCacheKey { model: model_key.clone(), text: text.clone() };
|
||
if let Some(vec) = guard.get(&key) {
|
||
embeddings[i] = Some(vec.clone());
|
||
self.embed_cache_hits.fetch_add(1, Ordering::Relaxed);
|
||
} else {
|
||
miss_indices.push(i);
|
||
miss_texts.push(text.clone());
|
||
self.embed_cache_misses.fetch_add(1, Ordering::Relaxed);
|
||
}
|
||
}
|
||
}
|
||
|
||
// All hit? Return immediately. Use cached_dim to populate
|
||
// the response dimension (no sidecar to ask).
|
||
if miss_indices.is_empty() {
|
||
let dim = self.cached_dim.load(Ordering::Relaxed) as usize;
|
||
let dim = if dim == 0 { embeddings[0].as_ref().map(|v| v.len()).unwrap_or(0) } else { dim };
|
||
return Ok(EmbedResponse {
|
||
embeddings: embeddings.into_iter().map(|opt| opt.expect("filled")).collect(),
|
||
model: req.model.unwrap_or_else(|| "nomic-embed-text".to_string()),
|
||
dimensions: dim,
|
||
});
|
||
}
|
||
|
||
// Second pass: fetch the misses in one sidecar call.
|
||
let miss_req = EmbedRequest { texts: miss_texts.clone(), model: req.model.clone() };
|
||
let resp = self.embed_uncached(&miss_req).await?;
|
||
if resp.embeddings.len() != miss_texts.len() {
|
||
return Err(format!(
|
||
"embed cache: sidecar returned {} embeddings for {} texts",
|
||
resp.embeddings.len(),
|
||
miss_texts.len()
|
||
));
|
||
}
|
||
|
||
// Pin cached_dim on first successful response.
|
||
if resp.dimensions > 0 {
|
||
self.cached_dim.store(resp.dimensions as u64, Ordering::Relaxed);
|
||
}
|
||
|
||
// Insert misses into cache + fill response slots.
|
||
{
|
||
let mut guard = cache.lock().map_err(|e| format!("cache lock poisoned: {e}"))?;
|
||
for (j, idx) in miss_indices.iter().enumerate() {
|
||
let key = EmbedCacheKey {
|
||
model: model_key.clone(),
|
||
text: miss_texts[j].clone(),
|
||
};
|
||
let vec = resp.embeddings[j].clone();
|
||
guard.put(key, vec.clone());
|
||
embeddings[*idx] = Some(vec);
|
||
}
|
||
}
|
||
|
||
Ok(EmbedResponse {
|
||
embeddings: embeddings.into_iter().map(|opt| opt.expect("filled")).collect(),
|
||
model: resp.model,
|
||
dimensions: resp.dimensions,
|
||
})
|
||
}
|
||
|
||
/// Direct sidecar call — original pre-cache behavior. Used
|
||
/// internally by embed() for cache-miss batches and as the
|
||
/// transparent fallback when the cache is disabled.
|
||
async fn embed_uncached(&self, req: &EmbedRequest) -> Result<EmbedResponse, String> {
|
||
let resp = self.client
|
||
.post(format!("{}/embed", self.base_url))
|
||
.json(req)
|
||
.send()
|
||
.await
|
||
.map_err(|e| format!("embed request failed: {e}"))?;
|
||
|
||
if !resp.status().is_success() {
|
||
let text = resp.text().await.unwrap_or_default();
|
||
return Err(format!("embed error ({}): {text}", text.len()));
|
||
}
|
||
resp.json().await.map_err(|e| format!("embed parse error: {e}"))
|
||
}
|
||
|
||
pub async fn generate(&self, req: GenerateRequest) -> Result<GenerateResponse, String> {
|
||
if let Some(gw) = self.gateway_url.as_deref() {
|
||
return self.generate_via_gateway(gw, req).await;
|
||
}
|
||
// Direct-sidecar legacy path. Used by gateway internals (so
|
||
// ollama_arm can call sidecar without a self-loop) and by
|
||
// any consumer that wants raw transport without /v1/usage
|
||
// accounting.
|
||
let resp = self.client
|
||
.post(format!("{}/generate", self.base_url))
|
||
.json(&req)
|
||
.send()
|
||
.await
|
||
.map_err(|e| format!("generate request failed: {e}"))?;
|
||
|
||
if !resp.status().is_success() {
|
||
let text = resp.text().await.unwrap_or_default();
|
||
return Err(format!("generate error: {text}"));
|
||
}
|
||
resp.json().await.map_err(|e| format!("generate parse error: {e}"))
|
||
}
|
||
|
||
/// Phase 44 part 2: route generate() through the gateway's
|
||
/// /v1/chat with provider="ollama" so the call lands in
|
||
/// /v1/usage + Langfuse. Translates between the sidecar
|
||
/// GenerateRequest/Response shape and the OpenAI-compat
|
||
/// chat shape on the wire.
|
||
async fn generate_via_gateway(&self, gateway_url: &str, req: GenerateRequest) -> Result<GenerateResponse, String> {
|
||
let mut messages = Vec::with_capacity(2);
|
||
if let Some(sys) = &req.system {
|
||
messages.push(serde_json::json!({"role": "system", "content": sys}));
|
||
}
|
||
messages.push(serde_json::json!({"role": "user", "content": req.prompt}));
|
||
let mut body = serde_json::json!({
|
||
"messages": messages,
|
||
"provider": "ollama",
|
||
});
|
||
if let Some(m) = &req.model { body["model"] = serde_json::json!(m); }
|
||
if let Some(t) = req.temperature { body["temperature"] = serde_json::json!(t); }
|
||
if let Some(mt) = req.max_tokens { body["max_tokens"] = serde_json::json!(mt); }
|
||
if let Some(th) = req.think { body["think"] = serde_json::json!(th); }
|
||
|
||
let resp = self.client
|
||
.post(format!("{}/v1/chat", gateway_url))
|
||
.json(&body)
|
||
.send()
|
||
.await
|
||
.map_err(|e| format!("/v1/chat request failed: {e}"))?;
|
||
if !resp.status().is_success() {
|
||
let text = resp.text().await.unwrap_or_default();
|
||
return Err(format!("/v1/chat error: {text}"));
|
||
}
|
||
let parsed: serde_json::Value = resp.json().await
|
||
.map_err(|e| format!("/v1/chat parse error: {e}"))?;
|
||
|
||
let text = parsed
|
||
.pointer("/choices/0/message/content")
|
||
.and_then(|v| v.as_str())
|
||
.unwrap_or("")
|
||
.to_string();
|
||
let model = parsed.get("model")
|
||
.and_then(|v| v.as_str())
|
||
.unwrap_or_else(|| req.model.as_deref().unwrap_or(""))
|
||
.to_string();
|
||
let prompt_tokens = parsed.pointer("/usage/prompt_tokens").and_then(|v| v.as_u64());
|
||
let completion_tokens = parsed.pointer("/usage/completion_tokens").and_then(|v| v.as_u64());
|
||
|
||
Ok(GenerateResponse {
|
||
text,
|
||
model,
|
||
tokens_evaluated: prompt_tokens,
|
||
tokens_generated: completion_tokens,
|
||
})
|
||
}
|
||
|
||
pub async fn rerank(&self, req: RerankRequest) -> Result<RerankResponse, String> {
|
||
let resp = self.client
|
||
.post(format!("{}/rerank", self.base_url))
|
||
.json(&req)
|
||
.send()
|
||
.await
|
||
.map_err(|e| format!("rerank request failed: {e}"))?;
|
||
|
||
if !resp.status().is_success() {
|
||
let text = resp.text().await.unwrap_or_default();
|
||
return Err(format!("rerank error: {text}"));
|
||
}
|
||
resp.json().await.map_err(|e| format!("rerank parse error: {e}"))
|
||
}
|
||
|
||
/// Force Ollama to unload the named model from VRAM (keep_alive=0).
|
||
/// Used for predictable profile swaps — without this, Ollama holds a
|
||
/// model for its configured TTL (default 5min) and the previous
|
||
/// profile's model can linger in VRAM next to the new one.
|
||
pub async fn unload_model(&self, model: &str) -> Result<serde_json::Value, String> {
|
||
let resp = self.client
|
||
.post(format!("{}/admin/unload", self.base_url))
|
||
.json(&serde_json::json!({ "model": model }))
|
||
.send().await
|
||
.map_err(|e| format!("unload request failed: {e}"))?;
|
||
if !resp.status().is_success() {
|
||
let text = resp.text().await.unwrap_or_default();
|
||
return Err(format!("unload error: {text}"));
|
||
}
|
||
resp.json().await.map_err(|e| format!("unload parse error: {e}"))
|
||
}
|
||
|
||
/// Ask Ollama to load the named model into VRAM proactively. Makes
|
||
/// the first real request after profile activation fast (no cold-load
|
||
/// latency).
|
||
pub async fn preload_model(&self, model: &str) -> Result<serde_json::Value, String> {
|
||
let resp = self.client
|
||
.post(format!("{}/admin/preload", self.base_url))
|
||
.json(&serde_json::json!({ "model": model }))
|
||
.send().await
|
||
.map_err(|e| format!("preload request failed: {e}"))?;
|
||
if !resp.status().is_success() {
|
||
let text = resp.text().await.unwrap_or_default();
|
||
return Err(format!("preload error: {text}"));
|
||
}
|
||
resp.json().await.map_err(|e| format!("preload parse error: {e}"))
|
||
}
|
||
|
||
/// GPU + loaded-model snapshot from the sidecar. Combines nvidia-smi
|
||
/// output (if available) with Ollama's /api/ps.
|
||
pub async fn vram_snapshot(&self) -> Result<serde_json::Value, String> {
|
||
let resp = self.client
|
||
.get(format!("{}/admin/vram", self.base_url))
|
||
.send().await
|
||
.map_err(|e| format!("vram request failed: {e}"))?;
|
||
resp.json().await.map_err(|e| format!("vram parse error: {e}"))
|
||
}
|
||
}
|