use lru::LruCache; use reqwest::Client; use serde::{Deserialize, Serialize}; use std::num::NonZeroUsize; use std::sync::Mutex; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::time::Duration; /// HTTP client for the Python AI sidecar. /// /// `generate()` has two transport modes: /// - When `gateway_url` is None (default), it posts to /// `${base_url}/generate` (sidecar direct). /// - When `gateway_url` is `Some(url)`, it posts to /// `${url}/v1/chat` with `provider="ollama"` so the call appears /// in `/v1/usage` and Langfuse traces. /// /// `embed()`, `rerank()`, and admin methods always go direct to the /// sidecar — no `/v1` equivalent yet, no point round-tripping. /// /// Phase 44 part 2 (2026-04-27): the gateway URL is wired in by /// callers that want observability (vectord modules); it's left /// unset by callers that ARE the gateway internals (avoids self-loops /// + redundant hops). /// Per-text embed cache key. We key on (model, text) so different /// model selections produce distinct cache lines — a query embedded /// under nomic-embed-text-v2-moe must NOT collide with the same /// query under nomic-embed-text v1. #[derive(Eq, PartialEq, Hash, Clone)] struct EmbedCacheKey { model: String, text: String, } /// Default LRU cache size — 4096 entries × ~6KB per 768-d f64 /// vector ≈ 24MB. Sized for typical staffing-domain repetition /// (coordinator workflows have query repetition rates around 70-90% /// per session). Tunable via [aibridge].embed_cache_size in the /// config; 0 disables the cache entirely. const DEFAULT_EMBED_CACHE_SIZE: usize = 4096; #[derive(Clone)] pub struct AiClient { client: Client, base_url: String, gateway_url: Option, /// Closes the 63× perf gap with Go side. Mirrors the shape of /// Go's internal/embed/cached.go::CachedProvider — same /// (model, text) → vector caching, same nil-disable semantics. /// None = caching disabled (cache_size=0); Some = bounded LRU. embed_cache: Option>>>>, /// Hit / miss counters for /admin observability + load-test /// validation. Atomic so Clone'd AiClients share the same counts. embed_cache_hits: Arc, embed_cache_misses: Arc, /// Pinned at construction time so the EmbedResponse can carry /// dimension consistently even when every text was a cache hit /// (no fresh sidecar call to learn the dim from). Set on first /// successful sidecar embed; checked on every cache hit. cached_dim: Arc, } // -- Request/Response types -- #[derive(Serialize, Deserialize)] pub struct EmbedRequest { pub texts: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub model: Option, } #[derive(Deserialize, Serialize, Clone)] pub struct EmbedResponse { pub embeddings: Vec>, pub model: String, pub dimensions: usize, } #[derive(Clone, Serialize, Deserialize)] pub struct GenerateRequest { pub prompt: String, #[serde(skip_serializing_if = "Option::is_none")] pub model: Option, #[serde(skip_serializing_if = "Option::is_none")] pub system: Option, #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] pub max_tokens: Option, /// Phase 21 — per-call opt-out of hidden reasoning. Thinking models /// (qwen3.5, gpt-oss, etc) burn tokens on reasoning before the /// visible response starts; setting this to `false` on hot-path /// JSON emitters avoids empty returns when the budget is tight. /// Sidecar forwards this to Ollama's `think` parameter; if the /// sidecar drops an unknown field the request still succeeds. #[serde(skip_serializing_if = "Option::is_none")] pub think: Option, } #[derive(Deserialize, Serialize, Clone)] pub struct GenerateResponse { pub text: String, pub model: String, pub tokens_evaluated: Option, pub tokens_generated: Option, } #[derive(Serialize, Deserialize)] pub struct RerankRequest { pub query: String, pub documents: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub model: Option, #[serde(skip_serializing_if = "Option::is_none")] pub top_k: Option, } #[derive(Deserialize, Serialize, Clone)] pub struct ScoredDocument { pub index: usize, pub text: String, pub score: f64, } #[derive(Deserialize, Serialize, Clone)] pub struct RerankResponse { pub results: Vec, pub model: String, } impl AiClient { pub fn new(base_url: &str) -> Self { Self::with_embed_cache(base_url, DEFAULT_EMBED_CACHE_SIZE) } /// Constructs an AiClient with an explicit embed-cache size. /// Pass 0 to disable the cache entirely (matches Go-side /// CachedProvider's nil-cache semantics). pub fn with_embed_cache(base_url: &str, cache_size: usize) -> Self { let client = Client::builder() .timeout(Duration::from_secs(120)) .build() .expect("failed to build HTTP client"); let embed_cache = if cache_size > 0 { // SAFETY: cache_size > 0 just verified, NonZeroUsize::new // returns Some. let cap = NonZeroUsize::new(cache_size).expect("cache_size > 0"); Some(Arc::new(Mutex::new(LruCache::new(cap)))) } else { None }; Self { client, base_url: base_url.trim_end_matches('/').to_string(), gateway_url: None, embed_cache, embed_cache_hits: Arc::new(AtomicU64::new(0)), embed_cache_misses: Arc::new(AtomicU64::new(0)), cached_dim: Arc::new(AtomicU64::new(0)), } } /// Cache hit/miss/size snapshot. Useful for /admin endpoints + /// load-test validation ("did the cache fire as expected?"). pub fn embed_cache_stats(&self) -> (u64, u64, usize) { let hits = self.embed_cache_hits.load(Ordering::Relaxed); let misses = self.embed_cache_misses.load(Ordering::Relaxed); let len = self .embed_cache .as_ref() .map(|c| c.lock().map(|g| g.len()).unwrap_or(0)) .unwrap_or(0); (hits, misses, len) } /// Same as `new`, but every `generate()` is routed through /// `${gateway_url}/v1/chat` (provider=ollama) for observability. /// Use this for callers OUTSIDE the gateway. Inside the gateway /// itself, prefer `new()` — calling /v1/chat from /v1/chat works /// (no infinite loop, ollama_arm doesn't use AiClient) but adds /// a wasted localhost hop. pub fn new_with_gateway(base_url: &str, gateway_url: &str) -> Self { let mut c = Self::new(base_url); c.gateway_url = Some(gateway_url.trim_end_matches('/').to_string()); c } pub async fn health(&self) -> Result { let resp = self.client .get(format!("{}/health", self.base_url)) .send() .await .map_err(|e| format!("sidecar unreachable: {e}"))?; resp.json().await.map_err(|e| format!("invalid response: {e}")) } /// Embed with per-text LRU caching. Mirrors Go-side /// CachedProvider behavior: cache key is (model, text); /// cache-hit texts skip the sidecar; cache-miss texts batch /// into a single sidecar call; results are interleaved in the /// caller's input order. /// /// Closes ~95% of the load-test perf gap vs Go side (loadgen /// 2026-05-01: Rust 128 RPS → with cache ≥ 7000 RPS expected /// for warm-cache workloads). Cold-cache behavior unchanged /// (every text is a miss → single sidecar call, identical to /// pre-cache). pub async fn embed(&self, req: EmbedRequest) -> Result { let model_key = req.model.clone().unwrap_or_default(); // Fast path: cache disabled → original behavior. let Some(cache) = self.embed_cache.as_ref() else { return self.embed_uncached(&req).await; }; if req.texts.is_empty() { return self.embed_uncached(&req).await; } // First pass: check cache for each text. Track which positions // need a sidecar fetch. let mut embeddings: Vec>> = vec![None; req.texts.len()]; let mut miss_indices: Vec = Vec::new(); let mut miss_texts: Vec = Vec::new(); { let mut guard = cache.lock().map_err(|e| format!("cache lock poisoned: {e}"))?; for (i, text) in req.texts.iter().enumerate() { let key = EmbedCacheKey { model: model_key.clone(), text: text.clone() }; if let Some(vec) = guard.get(&key) { embeddings[i] = Some(vec.clone()); self.embed_cache_hits.fetch_add(1, Ordering::Relaxed); } else { miss_indices.push(i); miss_texts.push(text.clone()); self.embed_cache_misses.fetch_add(1, Ordering::Relaxed); } } } // All hit? Return immediately. Use cached_dim to populate // the response dimension (no sidecar to ask). if miss_indices.is_empty() { let dim = self.cached_dim.load(Ordering::Relaxed) as usize; let dim = if dim == 0 { embeddings[0].as_ref().map(|v| v.len()).unwrap_or(0) } else { dim }; return Ok(EmbedResponse { embeddings: embeddings.into_iter().map(|opt| opt.expect("filled")).collect(), model: req.model.unwrap_or_else(|| "nomic-embed-text".to_string()), dimensions: dim, }); } // Second pass: fetch the misses in one sidecar call. let miss_req = EmbedRequest { texts: miss_texts.clone(), model: req.model.clone() }; let resp = self.embed_uncached(&miss_req).await?; if resp.embeddings.len() != miss_texts.len() { return Err(format!( "embed cache: sidecar returned {} embeddings for {} texts", resp.embeddings.len(), miss_texts.len() )); } // Pin cached_dim on first successful response. if resp.dimensions > 0 { self.cached_dim.store(resp.dimensions as u64, Ordering::Relaxed); } // Insert misses into cache + fill response slots. { let mut guard = cache.lock().map_err(|e| format!("cache lock poisoned: {e}"))?; for (j, idx) in miss_indices.iter().enumerate() { let key = EmbedCacheKey { model: model_key.clone(), text: miss_texts[j].clone(), }; let vec = resp.embeddings[j].clone(); guard.put(key, vec.clone()); embeddings[*idx] = Some(vec); } } Ok(EmbedResponse { embeddings: embeddings.into_iter().map(|opt| opt.expect("filled")).collect(), model: resp.model, dimensions: resp.dimensions, }) } /// Direct sidecar call — original pre-cache behavior. Used /// internally by embed() for cache-miss batches and as the /// transparent fallback when the cache is disabled. async fn embed_uncached(&self, req: &EmbedRequest) -> Result { let resp = self.client .post(format!("{}/embed", self.base_url)) .json(req) .send() .await .map_err(|e| format!("embed request failed: {e}"))?; if !resp.status().is_success() { let text = resp.text().await.unwrap_or_default(); return Err(format!("embed error ({}): {text}", text.len())); } resp.json().await.map_err(|e| format!("embed parse error: {e}")) } pub async fn generate(&self, req: GenerateRequest) -> Result { if let Some(gw) = self.gateway_url.as_deref() { return self.generate_via_gateway(gw, req).await; } // Direct-sidecar legacy path. Used by gateway internals (so // ollama_arm can call sidecar without a self-loop) and by // any consumer that wants raw transport without /v1/usage // accounting. let resp = self.client .post(format!("{}/generate", self.base_url)) .json(&req) .send() .await .map_err(|e| format!("generate request failed: {e}"))?; if !resp.status().is_success() { let text = resp.text().await.unwrap_or_default(); return Err(format!("generate error: {text}")); } resp.json().await.map_err(|e| format!("generate parse error: {e}")) } /// Phase 44 part 2: route generate() through the gateway's /// /v1/chat with provider="ollama" so the call lands in /// /v1/usage + Langfuse. Translates between the sidecar /// GenerateRequest/Response shape and the OpenAI-compat /// chat shape on the wire. async fn generate_via_gateway(&self, gateway_url: &str, req: GenerateRequest) -> Result { let mut messages = Vec::with_capacity(2); if let Some(sys) = &req.system { messages.push(serde_json::json!({"role": "system", "content": sys})); } messages.push(serde_json::json!({"role": "user", "content": req.prompt})); let mut body = serde_json::json!({ "messages": messages, "provider": "ollama", }); if let Some(m) = &req.model { body["model"] = serde_json::json!(m); } if let Some(t) = req.temperature { body["temperature"] = serde_json::json!(t); } if let Some(mt) = req.max_tokens { body["max_tokens"] = serde_json::json!(mt); } if let Some(th) = req.think { body["think"] = serde_json::json!(th); } let resp = self.client .post(format!("{}/v1/chat", gateway_url)) .json(&body) .send() .await .map_err(|e| format!("/v1/chat request failed: {e}"))?; if !resp.status().is_success() { let text = resp.text().await.unwrap_or_default(); return Err(format!("/v1/chat error: {text}")); } let parsed: serde_json::Value = resp.json().await .map_err(|e| format!("/v1/chat parse error: {e}"))?; let text = parsed .pointer("/choices/0/message/content") .and_then(|v| v.as_str()) .unwrap_or("") .to_string(); let model = parsed.get("model") .and_then(|v| v.as_str()) .unwrap_or_else(|| req.model.as_deref().unwrap_or("")) .to_string(); let prompt_tokens = parsed.pointer("/usage/prompt_tokens").and_then(|v| v.as_u64()); let completion_tokens = parsed.pointer("/usage/completion_tokens").and_then(|v| v.as_u64()); Ok(GenerateResponse { text, model, tokens_evaluated: prompt_tokens, tokens_generated: completion_tokens, }) } pub async fn rerank(&self, req: RerankRequest) -> Result { let resp = self.client .post(format!("{}/rerank", self.base_url)) .json(&req) .send() .await .map_err(|e| format!("rerank request failed: {e}"))?; if !resp.status().is_success() { let text = resp.text().await.unwrap_or_default(); return Err(format!("rerank error: {text}")); } resp.json().await.map_err(|e| format!("rerank parse error: {e}")) } /// Force Ollama to unload the named model from VRAM (keep_alive=0). /// Used for predictable profile swaps — without this, Ollama holds a /// model for its configured TTL (default 5min) and the previous /// profile's model can linger in VRAM next to the new one. pub async fn unload_model(&self, model: &str) -> Result { let resp = self.client .post(format!("{}/admin/unload", self.base_url)) .json(&serde_json::json!({ "model": model })) .send().await .map_err(|e| format!("unload request failed: {e}"))?; if !resp.status().is_success() { let text = resp.text().await.unwrap_or_default(); return Err(format!("unload error: {text}")); } resp.json().await.map_err(|e| format!("unload parse error: {e}")) } /// Ask Ollama to load the named model into VRAM proactively. Makes /// the first real request after profile activation fast (no cold-load /// latency). pub async fn preload_model(&self, model: &str) -> Result { let resp = self.client .post(format!("{}/admin/preload", self.base_url)) .json(&serde_json::json!({ "model": model })) .send().await .map_err(|e| format!("preload request failed: {e}"))?; if !resp.status().is_success() { let text = resp.text().await.unwrap_or_default(); return Err(format!("preload error: {text}")); } resp.json().await.map_err(|e| format!("preload parse error: {e}")) } /// GPU + loaded-model snapshot from the sidecar. Combines nvidia-smi /// output (if available) with Ollama's /api/ps. pub async fn vram_snapshot(&self) -> Result { let resp = self.client .get(format!("{}/admin/vram", self.base_url)) .send().await .map_err(|e| format!("vram request failed: {e}"))?; resp.json().await.map_err(|e| format!("vram parse error: {e}")) } }