use reqwest::Client; use serde::{Deserialize, Serialize}; use std::time::Duration; /// HTTP client for the Python AI sidecar. #[derive(Clone)] pub struct AiClient { client: Client, base_url: String, } // -- Request/Response types -- #[derive(Serialize, Deserialize)] pub struct EmbedRequest { pub texts: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub model: Option, } #[derive(Deserialize, Serialize, Clone)] pub struct EmbedResponse { pub embeddings: Vec>, pub model: String, pub dimensions: usize, } #[derive(Serialize, Deserialize)] pub struct GenerateRequest { pub prompt: String, #[serde(skip_serializing_if = "Option::is_none")] pub model: Option, #[serde(skip_serializing_if = "Option::is_none")] pub system: Option, #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] pub max_tokens: Option, } #[derive(Deserialize, Serialize, Clone)] pub struct GenerateResponse { pub text: String, pub model: String, pub tokens_evaluated: Option, pub tokens_generated: Option, } #[derive(Serialize, Deserialize)] pub struct RerankRequest { pub query: String, pub documents: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub model: Option, #[serde(skip_serializing_if = "Option::is_none")] pub top_k: Option, } #[derive(Deserialize, Serialize, Clone)] pub struct ScoredDocument { pub index: usize, pub text: String, pub score: f64, } #[derive(Deserialize, Serialize, Clone)] pub struct RerankResponse { pub results: Vec, pub model: String, } impl AiClient { pub fn new(base_url: &str) -> Self { let client = Client::builder() .timeout(Duration::from_secs(120)) .build() .expect("failed to build HTTP client"); Self { client, base_url: base_url.trim_end_matches('/').to_string(), } } pub async fn health(&self) -> Result { let resp = self.client .get(format!("{}/health", self.base_url)) .send() .await .map_err(|e| format!("sidecar unreachable: {e}"))?; resp.json().await.map_err(|e| format!("invalid response: {e}")) } pub async fn embed(&self, req: EmbedRequest) -> Result { let resp = self.client .post(format!("{}/embed", self.base_url)) .json(&req) .send() .await .map_err(|e| format!("embed request failed: {e}"))?; if !resp.status().is_success() { let text = resp.text().await.unwrap_or_default(); return Err(format!("embed error ({}): {text}", text.len())); } resp.json().await.map_err(|e| format!("embed parse error: {e}")) } pub async fn generate(&self, req: GenerateRequest) -> Result { let resp = self.client .post(format!("{}/generate", self.base_url)) .json(&req) .send() .await .map_err(|e| format!("generate request failed: {e}"))?; if !resp.status().is_success() { let text = resp.text().await.unwrap_or_default(); return Err(format!("generate error: {text}")); } resp.json().await.map_err(|e| format!("generate parse error: {e}")) } pub async fn rerank(&self, req: RerankRequest) -> Result { let resp = self.client .post(format!("{}/rerank", self.base_url)) .json(&req) .send() .await .map_err(|e| format!("rerank request failed: {e}"))?; if !resp.status().is_success() { let text = resp.text().await.unwrap_or_default(); return Err(format!("rerank error: {text}")); } resp.json().await.map_err(|e| format!("rerank parse error: {e}")) } }