//! Multi-pipeline supervisor for embedding ingestion.
//!
//! Splits work into ranges, runs parallel pipelines, and handles failures with
//! round-robin retry (failed ranges go back onto the shared queue) and checkpointing.

use std::collections::HashMap;
use std::sync::Arc;

use tokio::sync::RwLock;

use aibridge::client::{AiClient, EmbedRequest};
use object_store::ObjectStore;

use crate::chunker::TextChunk;
use crate::jobs::JobTracker;
use crate::store;

/// A range of chunks to embed.
#[derive(Debug, Clone)]
struct ChunkRange {
    start: usize,
    end: usize,
    attempts: u32,
}

/// Checkpoint: tracks which ranges are done, persisted for crash recovery.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct Checkpoint {
    pub job_id: String,
    pub index_name: String,
    pub total_chunks: usize,
    pub completed_ranges: Vec<(usize, usize)>,      // (start, end) pairs
    pub failed_ranges: Vec<(usize, usize, String)>, // (start, end, error)
    pub embedded_count: usize,
}

/// Pipeline status for monitoring.
#[derive(Debug, Clone, serde::Serialize)]
pub struct PipelineStatus {
    pub pipeline_id: String,
    pub current_range: Option<(usize, usize)>,
    pub chunks_done: usize,
    pub rate: f32,
    pub alive: bool,
}

/// Supervisor configuration.
pub struct SupervisorConfig {
    pub num_pipelines: usize,       // parallel pipelines (default 4)
    pub batch_size: usize,          // embeddings per API call (default 64)
    pub range_size: usize,          // chunks per range (default 2500)
    pub max_retries: u32,           // attempts per range (default 3)
    pub checkpoint_interval: usize, // checkpoint every N chunks (default 1000)
}

impl Default for SupervisorConfig {
    fn default() -> Self {
        Self {
            num_pipelines: 4, // i9 + A4000 can handle parallel embedding
            batch_size: 64,   // larger batches for GPU throughput
            range_size: 2500, // bigger ranges, less coordination overhead
            max_retries: 3,
            checkpoint_interval: 1000,
        }
    }
}

/// Run supervised multi-pipeline embedding.
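///
/// A minimal call sketch; the `AiClient`, `JobTracker`, and object store are
/// assumed to be constructed elsewhere in the crate, and the literal values
/// here are illustrative only.
///
/// ```ignore
/// let config = SupervisorConfig {
///     num_pipelines: 2,
///     batch_size: 32,
///     ..SupervisorConfig::default()
/// };
/// let key = run_supervised(
///     "job-123",
///     "docs-index",
///     chunks,      // Vec<TextChunk> produced by the chunker
///     &ai_client,  // aibridge::client::AiClient
///     &obj_store,  // Arc<dyn ObjectStore>
///     &tracker,    // JobTracker
///     config,
/// )
/// .await?;
/// ```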
pub async fn run_supervised(
    job_id: &str,
    index_name: &str,
    chunks: Vec<TextChunk>,
    ai_client: &AiClient,
    obj_store: &Arc<dyn ObjectStore>,
    tracker: &JobTracker,
    config: SupervisorConfig,
) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
    let total = chunks.len();
    let chunks = Arc::new(chunks);

    tracing::info!(
        "supervisor: starting {} pipelines for {} chunks (range_size={})",
        config.num_pipelines,
        total,
        config.range_size
    );

    // Split into ranges
    let mut ranges: Vec<ChunkRange> = Vec::new();
    let mut start = 0;
    while start < total {
        let end = (start + config.range_size).min(total);
        ranges.push(ChunkRange { start, end, attempts: 0 });
        start = end;
    }
    tracing::info!("supervisor: {} ranges created", ranges.len());

    // Shared state
    let work_queue: Arc<RwLock<Vec<ChunkRange>>> = Arc::new(RwLock::new(ranges));
    let all_vectors: Arc<RwLock<HashMap<usize, Vec<Vec<f32>>>>> =
        Arc::new(RwLock::new(HashMap::new()));
    let dead_letter: Arc<RwLock<Vec<(usize, usize, String)>>> = Arc::new(RwLock::new(Vec::new()));
    let global_count: Arc<std::sync::atomic::AtomicUsize> =
        Arc::new(std::sync::atomic::AtomicUsize::new(0));
    let start_time = std::time::Instant::now();

    // Checkpoint state
    let checkpoint: Arc<RwLock<Checkpoint>> = Arc::new(RwLock::new(Checkpoint {
        job_id: job_id.to_string(),
        index_name: index_name.to_string(),
        total_chunks: total,
        completed_ranges: vec![],
        failed_ranges: vec![],
        embedded_count: 0,
    }));

    // Try to load existing checkpoint (crash recovery)
    if let Ok(data) = storaged::ops::get(obj_store, &format!("checkpoints/{job_id}.json")).await {
        if let Ok(saved) = serde_json::from_slice::<Checkpoint>(&data) {
            tracing::info!(
                "supervisor: recovered checkpoint - {}/{} already done",
                saved.embedded_count,
                saved.total_chunks
            );
            // Remove already-completed ranges from work queue
            let mut queue = work_queue.write().await;
            queue.retain(|r| {
                !saved
                    .completed_ranges
                    .iter()
                    .any(|(s, e)| *s == r.start && *e == r.end)
            });
            global_count.store(saved.embedded_count, std::sync::atomic::Ordering::Relaxed);
            *checkpoint.write().await = saved;
        }
    }

    // Spawn pipeline workers
    let mut handles = Vec::new();
    for pipeline_id in 0..config.num_pipelines {
        let queue = work_queue.clone();
        let vectors = all_vectors.clone();
        let dead = dead_letter.clone();
        let count = global_count.clone();
        let ckpt = checkpoint.clone();
        let client = ai_client.clone();
        let store = obj_store.clone();
        let chunks = chunks.clone();
        let tracker = tracker.clone();
        let jid = job_id.to_string();
        let batch_size = config.batch_size;
        let max_retries = config.max_retries;
        let ckpt_interval = config.checkpoint_interval;

        let handle = tokio::spawn(async move {
            let pid = format!("pipeline-{pipeline_id}");
            tracing::info!("{pid}: started");

            loop {
                // Grab next range from queue
                let range = {
                    let mut q = queue.write().await;
                    q.pop()
                };
                let range = match range {
                    Some(r) => r,
                    None => {
                        tracing::info!("{pid}: no more work, shutting down");
                        break;
                    }
                };
                tracing::debug!("{pid}: processing range [{}, {})", range.start, range.end);

                // Embed the range
                match embed_range(&chunks[range.start..range.end], &client, batch_size).await {
                    Ok(range_vectors) => {
                        let n = range_vectors.len();
                        vectors.write().await.insert(range.start, range_vectors);
                        let prev = count.fetch_add(n, std::sync::atomic::Ordering::Relaxed);
                        let done = prev + n;

                        // Update job tracker
                        let elapsed = start_time.elapsed().as_secs_f32();
                        let rate = if elapsed > 0.0 { done as f32 / elapsed } else { 0.0 };
                        tracker.update_embed_progress(&jid, done, rate).await;

                        // Update checkpoint
                        let mut ckpt = ckpt.write().await;
                        ckpt.completed_ranges.push((range.start, range.end));
                        ckpt.embedded_count = done;
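                        // Worked example of the persist trigger below: with
                        // checkpoint_interval = 1000 and ranges of n = 500 chunks,
                        // `done` advances 500, 1000, 1500, 2000, so `done % 1000`
                        // is 500, 0, 500, 0 and the condition fires exactly when a
                        // multiple of the interval has just been crossed. With the
                        // default range_size (2500) above the interval, it fires
                        // after every completed range.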
                        // Persist checkpoint periodically
                        if done % ckpt_interval < n {
                            let json = serde_json::to_vec(&*ckpt).unwrap_or_default();
                            let _ = storaged::ops::put(
                                &store,
                                &format!("checkpoints/{jid}.json"),
                                json.into(),
                            )
                            .await;
                        }

                        tracing::info!(
                            "{pid}: range [{},{}) done - {done}/{} total",
                            range.start,
                            range.end,
                            chunks.len()
                        );
                    }
                    Err(e) => {
                        let attempt = range.attempts + 1;
                        tracing::warn!(
                            "{pid}: range [{},{}) failed (attempt {attempt}/{max_retries}): {e}",
                            range.start,
                            range.end
                        );

                        if attempt < max_retries {
                            // Push back onto the shared queue so any free pipeline can retry it
                            queue.write().await.push(ChunkRange {
                                start: range.start,
                                end: range.end,
                                attempts: attempt,
                            });
                        } else {
                            // Dead letter
                            tracing::error!(
                                "{pid}: range [{},{}) dead-lettered after {max_retries} attempts",
                                range.start,
                                range.end
                            );
                            let mut ckpt = ckpt.write().await;
                            ckpt.failed_ranges.push((range.start, range.end, e.clone()));
                            dead.write().await.push((range.start, range.end, e));
                        }
                    }
                }
            }
        });
        handles.push(handle);
    }

    // Wait for all pipelines
    for handle in handles {
        let _ = handle.await;
    }

    // Check for dead letters
    let dead = dead_letter.read().await;
    if !dead.is_empty() {
        tracing::warn!("supervisor: {} ranges failed permanently", dead.len());
    }

    // Assemble vectors in chunk order; ranges may have completed out of order,
    // so walk the range starts in ascending order and keep the surviving chunks
    // aligned with their vectors.
    let vectors_map = all_vectors.read().await;
    let mut sorted_starts: Vec<usize> = vectors_map.keys().cloned().collect();
    sorted_starts.sort();

    let mut final_vectors: Vec<Vec<f32>> = Vec::with_capacity(total);
    let mut successful_chunks: Vec<TextChunk> = Vec::with_capacity(total);
    for start in &sorted_starts {
        let range_vectors = &vectors_map[start];
        successful_chunks.extend(chunks[*start..*start + range_vectors.len()].iter().cloned());
        final_vectors.extend(range_vectors.iter().cloned());
    }
    let embedded_count = final_vectors.len();
    tracing::info!("supervisor: {embedded_count}/{total} chunks embedded, storing index");

    if successful_chunks.is_empty() {
        return Err("no chunks were successfully embedded".into());
    }

    // Store only the successfully embedded chunks, paired with their vectors
    let key =
        store::store_embeddings(obj_store, index_name, &successful_chunks, &final_vectors).await?;

    // Clean up checkpoint
    let _ = storaged::ops::delete(obj_store, &format!("checkpoints/{job_id}.json")).await;

    let elapsed = start_time.elapsed().as_secs_f32();
    let rate = embedded_count as f32 / elapsed;
    tracing::info!("supervisor: completed - {embedded_count} vectors in {elapsed:.0}s ({rate:.0}/sec)");

    Ok(key)
}

/// Embed a range of chunks. Returns vectors or an error string.
async fn embed_range(
    chunks: &[TextChunk],
    ai_client: &AiClient,
    batch_size: usize,
) -> Result<Vec<Vec<f32>>, String> {
    let mut vectors = Vec::with_capacity(chunks.len());

    for batch in chunks.chunks(batch_size) {
        let texts: Vec<String> = batch.iter().map(|c| c.text.clone()).collect();
        let resp = ai_client
            .embed(EmbedRequest { texts, model: None })
            .await
            .map_err(|e| format!("embed error: {e}"))?;
        vectors.extend(resp.embeddings);
    }

    Ok(vectors)
}
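
// A minimal round-trip sketch for the checkpoint format, documenting what the
// crash-recovery path above expects to read back. It only exercises the
// `Checkpoint` struct and `serde_json` (already used above); no object store,
// AI client, or job tracker is involved.
#[cfg(test)]
mod checkpoint_tests {
    use super::*;

    #[test]
    fn checkpoint_round_trips_through_json() {
        let ckpt = Checkpoint {
            job_id: "job-123".to_string(),
            index_name: "docs-index".to_string(),
            total_chunks: 5000,
            completed_ranges: vec![(0, 2500)],
            failed_ranges: vec![(2500, 5000, "embed error: timeout".to_string())],
            embedded_count: 2500,
        };

        let bytes = serde_json::to_vec(&ckpt).expect("serialize checkpoint");
        let restored: Checkpoint = serde_json::from_slice(&bytes).expect("deserialize checkpoint");

        assert_eq!(restored.job_id, ckpt.job_id);
        assert_eq!(restored.completed_ranges, ckpt.completed_ranges);
        assert_eq!(restored.failed_ranges, ckpt.failed_ranges);
        assert_eq!(restored.embedded_count, ckpt.embedded_count);
    }
}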