/// Brute-force vector search with cosine similarity. /// Works well up to ~100K vectors. HNSW index would go here for larger scale. use crate::store::StoredEmbedding; /// A search result with score. #[derive(Debug, Clone, serde::Serialize)] pub struct SearchResult { pub source: String, pub doc_id: String, pub chunk_idx: u32, pub chunk_text: String, pub score: f32, } /// Search embeddings by cosine similarity. Returns top_k results. pub fn search( query_vector: &[f32], embeddings: &[StoredEmbedding], top_k: usize, ) -> Vec { let query_norm = norm(query_vector); if query_norm == 0.0 { return vec![]; } let mut scored: Vec = embeddings.iter().map(|emb| { let score = cosine_similarity(query_vector, &emb.vector, query_norm); SearchResult { source: emb.source.clone(), doc_id: emb.doc_id.clone(), chunk_idx: emb.chunk_idx, chunk_text: emb.chunk_text.clone(), score, } }).collect(); // Sort descending by score scored.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal)); scored.truncate(top_k); scored } fn cosine_similarity(a: &[f32], b: &[f32], a_norm: f32) -> f32 { if a.len() != b.len() { return 0.0; } let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(); let b_norm = norm(b); if b_norm == 0.0 { return 0.0; } dot / (a_norm * b_norm) } fn norm(v: &[f32]) -> f32 { v.iter().map(|x| x * x).sum::().sqrt() } #[cfg(test)] mod tests { use super::*; #[test] fn identical_vectors_score_1() { let v = vec![1.0, 2.0, 3.0]; let emb = StoredEmbedding { source: "test".into(), doc_id: "1".into(), chunk_idx: 0, chunk_text: "hello".into(), vector: v.clone(), }; let results = search(&v, &[emb], 1); assert!((results[0].score - 1.0).abs() < 0.001); } #[test] fn orthogonal_vectors_score_0() { let q = vec![1.0, 0.0]; let emb = StoredEmbedding { source: "test".into(), doc_id: "1".into(), chunk_idx: 0, chunk_text: "hello".into(), vector: vec![0.0, 1.0], }; let results = search(&q, &[emb], 1); assert!(results[0].score.abs() < 0.001); } #[test] fn returns_top_k() { let q = vec![1.0, 0.0, 0.0]; let embs: Vec = (0..10).map(|i| StoredEmbedding { source: "test".into(), doc_id: format!("{i}"), chunk_idx: 0, chunk_text: format!("doc {i}"), vector: vec![1.0 - i as f32 * 0.1, i as f32 * 0.1, 0.0], }).collect(); let results = search(&q, &embs, 3); assert_eq!(results.len(), 3); assert!(results[0].score >= results[1].score); assert!(results[1].score >= results[2].score); } }