- JobTracker extended with JobType::ProfileActivation + Embed - activate_profile returns job_id immediately, work spawns in background - /v1/chat, /v1/usage, /v1/sessions endpoints (OpenAI-compatible) - Langfuse trace integration (Phase 40 early deliverable) - 12 gateway unit tests green, curl gates pass
287 lines
10 KiB
Rust
287 lines
10 KiB
Rust
//! Multi-pipeline supervisor for embedding ingestion.
//! Splits work into ranges, runs parallel pipelines, and handles failures
//! with bounded retries, dead-lettering, and periodic checkpointing.
|
|
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
use tokio::sync::RwLock;
|
|
|
|
use aibridge::client::{AiClient, EmbedRequest};
|
|
use crate::chunker::TextChunk;
|
|
use crate::jobs::JobTracker;
|
|
use crate::store;
|
|
use object_store::ObjectStore;
|
|
|
|
/// A half-open range `[start, end)` of chunk indices to embed as one unit of work.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ChunkRange {
    /// Index of the first chunk in the range (inclusive).
    start: usize,
    /// Index one past the last chunk in the range (exclusive).
    end: usize,
    /// Number of failed embedding attempts already made for this range.
    attempts: u32,
}
|
|
|
|
/// Checkpoint: tracks which ranges are done, persisted for crash recovery.
|
|
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
|
pub struct Checkpoint {
|
|
pub job_id: String,
|
|
pub index_name: String,
|
|
pub total_chunks: usize,
|
|
pub completed_ranges: Vec<(usize, usize)>, // (start, end) pairs
|
|
pub failed_ranges: Vec<(usize, usize, String)>, // (start, end, error)
|
|
pub embedded_count: usize,
|
|
}
|
|
|
|
/// Pipeline status for monitoring.
|
|
#[derive(Debug, Clone, serde::Serialize)]
|
|
pub struct PipelineStatus {
|
|
pub pipeline_id: String,
|
|
pub current_range: Option<(usize, usize)>,
|
|
pub chunks_done: usize,
|
|
pub rate: f32,
|
|
pub alive: bool,
|
|
}
|
|
|
|
/// Supervisor configuration.
///
/// The defaults quoted on each field are the values produced by the
/// `Default` implementation for this type.
#[derive(Debug, Clone)]
pub struct SupervisorConfig {
    /// Number of pipeline workers run in parallel (default 4).
    pub num_pipelines: usize,
    /// Number of texts sent per embedding API call (default 64).
    pub batch_size: usize,
    /// Number of chunks per work range (default 2500).
    pub range_size: usize,
    /// Maximum attempts per range before it is dead-lettered (default 3).
    pub max_retries: u32,
    /// Persist a checkpoint roughly every N embedded chunks (default 1000).
    pub checkpoint_interval: usize,
}
|
|
|
|
impl Default for SupervisorConfig {
|
|
fn default() -> Self {
|
|
Self {
|
|
num_pipelines: 4, // i9 + A4000 can handle parallel embedding
|
|
batch_size: 64, // larger batches for GPU throughput
|
|
range_size: 2500, // bigger ranges, fewer coordination overhead
|
|
max_retries: 3,
|
|
checkpoint_interval: 1000,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Run supervised dual-pipeline embedding.
|
|
pub async fn run_supervised(
|
|
job_id: &str,
|
|
index_name: &str,
|
|
chunks: Vec<TextChunk>,
|
|
ai_client: &AiClient,
|
|
obj_store: &Arc<dyn ObjectStore>,
|
|
tracker: &JobTracker,
|
|
config: SupervisorConfig,
|
|
) -> Result<String, String> {
|
|
let total = chunks.len();
|
|
let chunks = Arc::new(chunks);
|
|
|
|
tracing::info!("supervisor: starting {} pipelines for {} chunks (range_size={})",
|
|
config.num_pipelines, total, config.range_size);
|
|
|
|
// Split into ranges
|
|
let mut ranges: Vec<ChunkRange> = Vec::new();
|
|
let mut start = 0;
|
|
while start < total {
|
|
let end = (start + config.range_size).min(total);
|
|
ranges.push(ChunkRange { start, end, attempts: 0 });
|
|
start = end;
|
|
}
|
|
|
|
tracing::info!("supervisor: {} ranges created", ranges.len());
|
|
|
|
// Shared state
|
|
let work_queue: Arc<RwLock<Vec<ChunkRange>>> = Arc::new(RwLock::new(ranges));
|
|
let all_vectors: Arc<RwLock<HashMap<usize, Vec<Vec<f64>>>>> = Arc::new(RwLock::new(HashMap::new()));
|
|
let dead_letter: Arc<RwLock<Vec<(usize, usize, String)>>> = Arc::new(RwLock::new(Vec::new()));
|
|
let global_count: Arc<std::sync::atomic::AtomicUsize> = Arc::new(std::sync::atomic::AtomicUsize::new(0));
|
|
let start_time = std::time::Instant::now();
|
|
|
|
// Checkpoint state
|
|
let checkpoint: Arc<RwLock<Checkpoint>> = Arc::new(RwLock::new(Checkpoint {
|
|
job_id: job_id.to_string(),
|
|
index_name: index_name.to_string(),
|
|
total_chunks: total,
|
|
completed_ranges: vec![],
|
|
failed_ranges: vec![],
|
|
embedded_count: 0,
|
|
}));
|
|
|
|
// Try to load existing checkpoint (crash recovery)
|
|
if let Ok(data) = storaged::ops::get(obj_store, &format!("checkpoints/{job_id}.json")).await {
|
|
if let Ok(saved) = serde_json::from_slice::<Checkpoint>(&data) {
|
|
tracing::info!("supervisor: recovered checkpoint — {}/{} already done",
|
|
saved.embedded_count, saved.total_chunks);
|
|
|
|
// Remove already-completed ranges from work queue
|
|
let mut queue = work_queue.write().await;
|
|
queue.retain(|r| !saved.completed_ranges.iter().any(|(s, e)| *s == r.start && *e == r.end));
|
|
global_count.store(saved.embedded_count, std::sync::atomic::Ordering::Relaxed);
|
|
|
|
*checkpoint.write().await = saved;
|
|
}
|
|
}
|
|
|
|
// Spawn pipeline workers
|
|
let mut handles = Vec::new();
|
|
for pipeline_id in 0..config.num_pipelines {
|
|
let queue = work_queue.clone();
|
|
let vectors = all_vectors.clone();
|
|
let dead = dead_letter.clone();
|
|
let count = global_count.clone();
|
|
let ckpt = checkpoint.clone();
|
|
let client = ai_client.clone();
|
|
let store = obj_store.clone();
|
|
let chunks = chunks.clone();
|
|
let tracker = tracker.clone();
|
|
let jid = job_id.to_string();
|
|
let batch_size = config.batch_size;
|
|
let max_retries = config.max_retries;
|
|
let ckpt_interval = config.checkpoint_interval;
|
|
|
|
let handle = tokio::spawn(async move {
|
|
let pid = format!("pipeline-{pipeline_id}");
|
|
tracing::info!("{pid}: started");
|
|
|
|
loop {
|
|
// Grab next range from queue
|
|
let range = {
|
|
let mut q = queue.write().await;
|
|
q.pop()
|
|
};
|
|
|
|
let range = match range {
|
|
Some(r) => r,
|
|
None => {
|
|
tracing::info!("{pid}: no more work, shutting down");
|
|
break;
|
|
}
|
|
};
|
|
|
|
tracing::debug!("{pid}: processing range [{}, {})", range.start, range.end);
|
|
|
|
// Embed the range
|
|
match embed_range(&chunks[range.start..range.end], &client, batch_size).await {
|
|
Ok(range_vectors) => {
|
|
let n = range_vectors.len();
|
|
vectors.write().await.insert(range.start, range_vectors);
|
|
let prev = count.fetch_add(n, std::sync::atomic::Ordering::Relaxed);
|
|
let done = prev + n;
|
|
|
|
// Update job tracker
|
|
let elapsed = start_time.elapsed().as_secs_f32();
|
|
let rate = if elapsed > 0.0 { done as f32 / elapsed } else { 0.0 };
|
|
tracker.update_embed_progress(&jid, done, rate).await;
|
|
|
|
// Update checkpoint
|
|
let mut ckpt = ckpt.write().await;
|
|
ckpt.completed_ranges.push((range.start, range.end));
|
|
ckpt.embedded_count = done;
|
|
|
|
// Persist checkpoint periodically
|
|
if done % ckpt_interval < n {
|
|
let json = serde_json::to_vec(&*ckpt).unwrap_or_default();
|
|
let _ = storaged::ops::put(&store, &format!("checkpoints/{jid}.json"), json.into()).await;
|
|
}
|
|
|
|
tracing::info!("{pid}: range [{},{}) done — {done}/{} total",
|
|
range.start, range.end, chunks.len());
|
|
}
|
|
Err(e) => {
|
|
let attempt = range.attempts + 1;
|
|
tracing::warn!("{pid}: range [{},{}) failed (attempt {attempt}/{max_retries}): {e}",
|
|
range.start, range.end);
|
|
|
|
if attempt < max_retries {
|
|
// Push back to queue for retry (round-robin to other pipeline)
|
|
queue.write().await.push(ChunkRange {
|
|
start: range.start,
|
|
end: range.end,
|
|
attempts: attempt,
|
|
});
|
|
} else {
|
|
// Dead letter
|
|
tracing::error!("{pid}: range [{},{}) dead-lettered after {max_retries} attempts",
|
|
range.start, range.end);
|
|
dead.write().await.push((range.start, range.end, e));
|
|
|
|
let mut ckpt = ckpt.write().await;
|
|
ckpt.failed_ranges.push((range.start, range.end, format!("max retries exceeded")));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
});
|
|
handles.push(handle);
|
|
}
|
|
|
|
// Wait for all pipelines
|
|
for handle in handles {
|
|
let _ = handle.await;
|
|
}
|
|
|
|
// Check for dead letters
|
|
let dead = dead_letter.read().await;
|
|
if !dead.is_empty() {
|
|
tracing::warn!("supervisor: {} ranges failed permanently", dead.len());
|
|
}
|
|
|
|
// Assemble vectors in order
|
|
let vectors_map = all_vectors.read().await;
|
|
let mut sorted_starts: Vec<usize> = vectors_map.keys().cloned().collect();
|
|
sorted_starts.sort();
|
|
|
|
let mut final_vectors: Vec<Vec<f64>> = Vec::with_capacity(total);
|
|
for start in sorted_starts {
|
|
final_vectors.extend(vectors_map[&start].clone());
|
|
}
|
|
|
|
let embedded_count = final_vectors.len();
|
|
tracing::info!("supervisor: {embedded_count}/{total} chunks embedded, storing index");
|
|
|
|
// Store — only the successfully embedded chunks
|
|
// We need to match chunks to vectors
|
|
let successful_chunks: Vec<TextChunk> = {
|
|
let ckpt = checkpoint.read().await;
|
|
let mut result = Vec::new();
|
|
for (s, e) in &ckpt.completed_ranges {
|
|
result.extend(chunks[*s..*e].iter().cloned());
|
|
}
|
|
result
|
|
};
|
|
|
|
if successful_chunks.is_empty() {
|
|
return Err("no chunks were successfully embedded".into());
|
|
}
|
|
|
|
let key = store::store_embeddings(obj_store, index_name, &successful_chunks, &final_vectors).await?;
|
|
|
|
// Clean up checkpoint
|
|
let _ = storaged::ops::delete(obj_store, &format!("checkpoints/{job_id}.json")).await;
|
|
|
|
let elapsed = start_time.elapsed().as_secs_f32();
|
|
let rate = embedded_count as f32 / elapsed;
|
|
tracing::info!("supervisor: completed — {embedded_count} vectors in {elapsed:.0}s ({rate:.0}/sec)");
|
|
|
|
Ok(key)
|
|
}
|
|
|
|
/// Embed a range of chunks. Returns vectors or error.
|
|
async fn embed_range(
|
|
chunks: &[TextChunk],
|
|
ai_client: &AiClient,
|
|
batch_size: usize,
|
|
) -> Result<Vec<Vec<f64>>, String> {
|
|
let mut vectors = Vec::with_capacity(chunks.len());
|
|
|
|
for batch in chunks.chunks(batch_size) {
|
|
let texts: Vec<String> = batch.iter().map(|c| c.text.clone()).collect();
|
|
let resp = ai_client.embed(EmbedRequest {
|
|
texts,
|
|
model: None,
|
|
}).await.map_err(|e| format!("embed error: {e}"))?;
|
|
vectors.extend(resp.embeddings);
|
|
}
|
|
|
|
Ok(vectors)
|
|
}
|