use std::collections::HashMap; use rusqlite::Connection; use rusqlite::OptionalExtension; use crate::core::error::Result; use crate::embedding::chunk_ids::decode_rowid; #[derive(Debug)] pub struct VectorResult { pub document_id: i64, pub distance: f64, } fn max_chunks_per_document(conn: &Connection) -> Result { let stored: Option = conn .query_row( "SELECT MAX(chunk_count) FROM embedding_metadata WHERE chunk_index = 0 AND chunk_count IS NOT NULL", [], |row| row.get(0), ) .optional()? .flatten(); if let Some(max) = stored { return Ok(max); } Ok(conn .query_row( "SELECT COALESCE(MAX(cnt), 1) FROM ( SELECT COUNT(*) as cnt FROM embedding_metadata WHERE last_error IS NULL GROUP BY document_id )", [], |row| row.get(0), ) .optional()? .unwrap_or(1)) } /// sqlite-vec hard limit for KNN `k` parameter. const SQLITE_VEC_KNN_MAX: usize = 4_096; /// Compute the KNN k value from the requested limit and the max chunks per /// document. The result is guaranteed to never exceed [`SQLITE_VEC_KNN_MAX`]. fn compute_knn_k(limit: usize, max_chunks_per_doc: i64) -> usize { let max_chunks = max_chunks_per_doc.unsigned_abs().max(1) as usize; let multiplier = (max_chunks * 3 / 2 + 1).clamp(8, 200); (limit * multiplier).min(SQLITE_VEC_KNN_MAX) } pub fn search_vector( conn: &Connection, query_embedding: &[f32], limit: usize, ) -> Result> { if query_embedding.is_empty() || limit == 0 { return Ok(Vec::new()); } let embedding_bytes: Vec = query_embedding .iter() .flat_map(|f| f.to_le_bytes()) .collect(); let max_chunks = max_chunks_per_document(conn)?.max(1); let k = compute_knn_k(limit, max_chunks); let mut stmt = conn.prepare( "SELECT rowid, distance FROM embeddings WHERE embedding MATCH ?1 AND k = ?2 ORDER BY distance", )?; let rows: Vec<(i64, f64)> = stmt .query_map(rusqlite::params![embedding_bytes, k as i64], |row| { Ok((row.get(0)?, row.get(1)?)) })? .collect::, _>>()?; let mut best: HashMap = HashMap::new(); for (rowid, distance) in rows { let (document_id, _chunk_index) = decode_rowid(rowid); best.entry(document_id) .and_modify(|d| { if distance < *d { *d = distance; } }) .or_insert(distance); } let mut results: Vec = best .into_iter() .map(|(document_id, distance)| VectorResult { document_id, distance, }) .collect(); results.sort_by(|a, b| a.distance.total_cmp(&b.distance)); results.truncate(limit); Ok(results) } #[cfg(test)] mod tests { use super::*; #[test] fn test_empty_returns_empty() { let result = search_vector_dedup(vec![], 10); assert!(result.is_empty()); } #[test] fn test_dedup_keeps_best_distance() { let rows = vec![(1000_i64, 0.5_f64), (1001, 0.3), (2000, 0.4)]; let results = search_vector_dedup(rows, 10); assert_eq!(results.len(), 2); assert_eq!(results[0].document_id, 1); assert!((results[0].distance - 0.3).abs() < f64::EPSILON); assert_eq!(results[1].document_id, 2); } #[test] fn test_dedup_respects_limit() { let rows = vec![(1000_i64, 0.1_f64), (2000, 0.2), (3000, 0.3)]; let results = search_vector_dedup(rows, 2); assert_eq!(results.len(), 2); } #[test] fn test_knn_k_never_exceeds_sqlite_vec_limit() { // Brute-force: no combination of valid inputs should exceed 4096. for limit in [1, 10, 50, 100, 500, 1000, 1500, 2000, 5000] { for max_chunks in [1, 2, 5, 10, 50, 100, 200, 500, 1000] { let k = compute_knn_k(limit, max_chunks); assert!( k <= SQLITE_VEC_KNN_MAX, "k={k} exceeded limit for limit={limit}, max_chunks={max_chunks}" ); } } } #[test] fn test_knn_k_reproduces_original_bug_scenario() { // The original bug: limit=1500 (RECALL_CAP) with multiplier >= 8 // produced k=10_000 which exceeded sqlite-vec's 4096 cap. let k = compute_knn_k(1500, 1); assert!(k <= SQLITE_VEC_KNN_MAX, "k={k} exceeded 4096 at RECALL_CAP with 1 chunk"); } #[test] fn test_knn_k_small_limit_uses_minimum_multiplier() { // With 1 chunk, multiplier clamps to minimum of 8. let k = compute_knn_k(10, 1); assert_eq!(k, 80); } #[test] fn test_knn_k_high_chunks_caps_multiplier() { // With 200 chunks, multiplier = (200*3/2 + 1) = 301 → clamped to 200. // limit=10 * 200 = 2000 < 4096. let k = compute_knn_k(10, 200); assert_eq!(k, 2000); } #[test] fn test_knn_k_zero_max_chunks_treated_as_one() { // max_chunks_per_doc=0 → unsigned_abs().max(1) = 1 let k = compute_knn_k(10, 0); assert_eq!(k, 80); // multiplier clamp(8, 200) → 8 } #[test] fn test_knn_k_negative_max_chunks_uses_absolute() { // Defensive: negative values use unsigned_abs let k = compute_knn_k(10, -5); assert_eq!(k, compute_knn_k(10, 5)); } fn search_vector_dedup(rows: Vec<(i64, f64)>, limit: usize) -> Vec { let mut best: HashMap = HashMap::new(); for (rowid, distance) in rows { let (document_id, _) = decode_rowid(rowid); best.entry(document_id) .and_modify(|d| { if distance < *d { *d = distance; } }) .or_insert(distance); } let mut results: Vec = best .into_iter() .map(|(document_id, distance)| VectorResult { document_id, distance, }) .collect(); results.sort_by(|a, b| a.distance.total_cmp(&b.distance)); results.truncate(limit); results } }