use axum::{ Json, Router, extract::{Path, Query, State}, http::StatusCode, response::IntoResponse, routing::{get, post}, }; use object_store::ObjectStore; use serde::{Deserialize, Serialize}; use std::sync::Arc; use aibridge::client::{AiClient, EmbedRequest}; use crate::{chunker, hnsw, index_registry, jobs, rag, search, store, supervisor}; #[derive(Clone)] pub struct VectorState { pub store: Arc, pub ai_client: AiClient, pub job_tracker: jobs::JobTracker, pub index_registry: index_registry::IndexRegistry, pub hnsw_store: hnsw::HnswStore, } pub fn router(state: VectorState) -> Router { Router::new() .route("/health", get(health)) .route("/index", post(create_index)) .route("/indexes", get(list_indexes)) .route("/indexes/{name}", get(get_index_meta)) .route("/jobs", get(list_jobs)) .route("/jobs/{id}", get(get_job)) .route("/search", post(search_index)) .route("/rag", post(rag_query)) .route("/hnsw/build", post(build_hnsw)) .route("/hnsw/search", post(search_hnsw)) .route("/hnsw/list", get(list_hnsw)) .with_state(state) } async fn health() -> &'static str { "vectord ok" } // --- Background Index Creation --- #[derive(Deserialize)] struct CreateIndexRequest { index_name: String, source: String, documents: Vec, chunk_size: Option, overlap: Option, } #[derive(Deserialize)] struct DocInput { id: String, text: String, } #[derive(Serialize)] struct CreateIndexResponse { job_id: String, index_name: String, documents: usize, chunks: usize, message: String, } async fn create_index( State(state): State, Json(req): Json, ) -> impl IntoResponse { let chunk_size = req.chunk_size.unwrap_or(500); let overlap = req.overlap.unwrap_or(50); // Chunk synchronously (fast) let doc_ids: Vec = req.documents.iter().map(|d| d.id.clone()).collect(); let texts: Vec = req.documents.iter().map(|d| d.text.clone()).collect(); let chunks = chunker::chunk_column(&req.source, &doc_ids, &texts, chunk_size, overlap); if chunks.is_empty() { return Err((StatusCode::BAD_REQUEST, "no text to index".to_string())); } let n_docs = req.documents.len(); let n_chunks = chunks.len(); let index_name = req.index_name.clone(); // Create job and return immediately let job_id = state.job_tracker.create(&index_name, n_chunks).await; tracing::info!("job {job_id}: indexing '{}' — {} docs → {} chunks (background)", index_name, n_docs, n_chunks); // Spawn supervised dual-pipeline embedding let tracker = state.job_tracker.clone(); let ai_client = state.ai_client.clone(); let obj_store = state.store.clone(); let registry = state.index_registry.clone(); let jid = job_id.clone(); let source_name = req.source.clone(); let idx_name = req.index_name.clone(); tokio::spawn(async move { let start_time = std::time::Instant::now(); let config = supervisor::SupervisorConfig::default(); let result = supervisor::run_supervised( &jid, &idx_name, chunks, &ai_client, &obj_store, &tracker, config, ).await; match result { Ok(key) => { let elapsed = start_time.elapsed().as_secs_f32(); let rate = if elapsed > 0.0 { n_chunks as f32 / elapsed } else { 0.0 }; // Register index metadata with model version info let meta = index_registry::IndexMeta { index_name: idx_name.clone(), source: source_name, model_name: "nomic-embed-text".to_string(), // from sidecar config model_version: "latest".to_string(), dimensions: 768, chunk_count: n_chunks, doc_count: n_docs, chunk_size: chunk_size, overlap: overlap, storage_key: key.clone(), created_at: chrono::Utc::now(), build_time_secs: elapsed, chunks_per_sec: rate, }; let _ = registry.register(meta).await; tracker.complete(&jid, key).await; tracing::info!("job {jid}: completed — {n_chunks} chunks in {elapsed:.0}s ({rate:.0}/sec)"); } Err(e) => { tracker.fail(&jid, e.clone()).await; tracing::error!("job {jid}: failed — {e}"); } } }); Ok((StatusCode::ACCEPTED, Json(CreateIndexResponse { job_id, index_name: req.index_name, documents: n_docs, chunks: n_chunks, message: format!("embedding {} chunks in background — poll /vectors/jobs/{{id}} for progress", n_chunks), }))) } // --- Index Registry --- #[derive(Deserialize)] struct IndexListQuery { source: Option, model: Option, } async fn list_indexes( State(state): State, Query(q): Query, ) -> impl IntoResponse { let indexes = state.index_registry.list(q.source.as_deref(), q.model.as_deref()).await; Json(indexes) } async fn get_index_meta( State(state): State, Path(name): Path, ) -> impl IntoResponse { match state.index_registry.get(&name).await { Some(meta) => Ok(Json(meta)), None => Err((StatusCode::NOT_FOUND, format!("index not found: {name}"))), } } // --- unused legacy function below, kept for reference --- #[allow(dead_code)] /// Legacy single-pipeline embedding (replaced by supervisor). async fn _run_embedding_job_legacy( job_id: &str, index_name: &str, chunks: &[chunker::TextChunk], ai_client: &AiClient, store: &Arc, tracker: &jobs::JobTracker, ) -> Result { let batch_size = 32; let mut all_vectors: Vec> = Vec::new(); let start = std::time::Instant::now(); for (i, batch) in chunks.chunks(batch_size).enumerate() { let texts: Vec = batch.iter().map(|c| c.text.clone()).collect(); let embed_resp = ai_client.embed(EmbedRequest { texts, model: None, }).await.map_err(|e| format!("embed batch {} error: {e}", i))?; all_vectors.extend(embed_resp.embeddings); // Update progress let elapsed = start.elapsed().as_secs_f32(); let rate = if elapsed > 0.0 { all_vectors.len() as f32 / elapsed } else { 0.0 }; tracker.update_progress(job_id, all_vectors.len(), rate).await; // Log every 100 batches if (i + 1) % 100 == 0 { let pct = (all_vectors.len() as f32 / chunks.len() as f32) * 100.0; let eta = if rate > 0.0 { (chunks.len() - all_vectors.len()) as f32 / rate } else { 0.0 }; tracing::info!("job {job_id}: {}/{} chunks ({pct:.0}%), {rate:.0}/sec, ETA {eta:.0}s", all_vectors.len(), chunks.len()); } } // Store let key = store::store_embeddings(store, index_name, chunks, &all_vectors).await?; Ok(key) } // --- Job Status --- async fn list_jobs(State(state): State) -> impl IntoResponse { let jobs = state.job_tracker.list().await; Json(jobs) } async fn get_job( State(state): State, Path(id): Path, ) -> impl IntoResponse { match state.job_tracker.get(&id).await { Some(job) => Ok(Json(job)), None => Err((StatusCode::NOT_FOUND, format!("job not found: {id}"))), } } // --- Search --- #[derive(Deserialize)] struct SearchRequest { index_name: String, query: String, top_k: Option, } #[derive(Serialize)] struct SearchResponse { results: Vec, query: String, } async fn search_index( State(state): State, Json(req): Json, ) -> impl IntoResponse { let top_k = req.top_k.unwrap_or(5); let embed_resp = state.ai_client.embed(EmbedRequest { texts: vec![req.query.clone()], model: None, }).await.map_err(|e| (StatusCode::BAD_GATEWAY, format!("embed error: {e}")))?; if embed_resp.embeddings.is_empty() { return Err((StatusCode::BAD_GATEWAY, "no embedding returned".to_string())); } let query_vec: Vec = embed_resp.embeddings[0].iter().map(|&x| x as f32).collect(); let embeddings = store::load_embeddings(&state.store, &req.index_name) .await .map_err(|e| (StatusCode::NOT_FOUND, format!("index not found: {e}")))?; let results = search::search(&query_vec, &embeddings, top_k); Ok(Json(SearchResponse { results, query: req.query, })) } // --- RAG --- #[derive(Deserialize)] struct RagRequest { index_name: String, question: String, top_k: Option, } async fn rag_query( State(state): State, Json(req): Json, ) -> impl IntoResponse { let top_k = req.top_k.unwrap_or(5); match rag::query(&req.question, &req.index_name, top_k, &state.store, &state.ai_client).await { Ok(resp) => Ok(Json(resp)), Err(e) => Err((StatusCode::INTERNAL_SERVER_ERROR, e)), } } // --- HNSW Fast Search --- #[derive(Deserialize)] struct BuildHnswRequest { /// Name of the stored vector index to build HNSW from index_name: String, } /// Build an HNSW index from an existing stored vector index. /// Loads embeddings from Parquet, builds HNSW in memory. async fn build_hnsw( State(state): State, Json(req): Json, ) -> impl IntoResponse { tracing::info!("building HNSW for '{}'", req.index_name); // Load embeddings from Parquet let embeddings = store::load_embeddings(&state.store, &req.index_name) .await .map_err(|e| (StatusCode::NOT_FOUND, format!("index not found: {e}")))?; let n = embeddings.len(); tracing::info!("loaded {} embeddings, building HNSW...", n); // Build HNSW match state.hnsw_store.build_index(&req.index_name, embeddings).await { Ok(stats) => Ok(Json(stats)), Err(e) => Err((StatusCode::INTERNAL_SERVER_ERROR, e)), } } #[derive(Deserialize)] struct HnswSearchRequest { index_name: String, query: String, top_k: Option, } /// Search using HNSW — approximate nearest neighbors, much faster than brute-force. async fn search_hnsw( State(state): State, Json(req): Json, ) -> impl IntoResponse { let top_k = req.top_k.unwrap_or(5); // Embed query let embed_resp = state.ai_client.embed(EmbedRequest { texts: vec![req.query.clone()], model: None, }).await.map_err(|e| (StatusCode::BAD_GATEWAY, format!("embed error: {e}")))?; if embed_resp.embeddings.is_empty() { return Err((StatusCode::BAD_GATEWAY, "no embedding returned".to_string())); } let query_vec: Vec = embed_resp.embeddings[0].iter().map(|&x| x as f32).collect(); // Search HNSW match state.hnsw_store.search(&req.index_name, &query_vec, top_k).await { Ok(results) => Ok(Json(serde_json::json!({ "results": results, "query": req.query, "method": "hnsw", }))), Err(e) => Err((StatusCode::NOT_FOUND, e)), } } async fn list_hnsw(State(state): State) -> impl IntoResponse { Json(state.hnsw_store.list().await) }