diff --git a/src/search/filters.rs b/src/search/filters.rs
new file mode 100644
index 0000000..409de60
--- /dev/null
+++ b/src/search/filters.rs
@@ -0,0 +1,227 @@
+use crate::core::error::Result;
+use crate::documents::SourceType;
+use rusqlite::Connection;
+
+const DEFAULT_LIMIT: usize = 20;
+const MAX_LIMIT: usize = 100;
+
+/// Path filter: exact match or prefix match (trailing `/`).
+#[derive(Debug, Clone)]
+pub enum PathFilter {
+    Exact(String),
+    Prefix(String),
+}
+
+/// Filters applied to search results post-retrieval.
+#[derive(Debug, Clone, Default)]
+pub struct SearchFilters {
+    pub source_type: Option<SourceType>,
+    pub author: Option<String>,
+    pub project_id: Option<i64>,
+    pub after: Option<i64>,
+    pub updated_after: Option<i64>,
+    pub labels: Vec<String>,
+    pub path: Option<PathFilter>,
+    pub limit: usize,
+}
+
+impl SearchFilters {
+    /// Returns true if any filter (besides limit) is set.
+    pub fn has_any_filter(&self) -> bool {
+        self.source_type.is_some()
+            || self.author.is_some()
+            || self.project_id.is_some()
+            || self.after.is_some()
+            || self.updated_after.is_some()
+            || !self.labels.is_empty()
+            || self.path.is_some()
+    }
+
+    /// Clamp limit to [1, 100], defaulting 0 to 20.
+    pub fn clamp_limit(&self) -> usize {
+        if self.limit == 0 {
+            DEFAULT_LIMIT
+        } else {
+            self.limit.min(MAX_LIMIT)
+        }
+    }
+}
+
+/// Escape SQL LIKE wildcards in a string.
+fn escape_like(s: &str) -> String {
+    s.replace('\\', "\\\\")
+        .replace('%', "\\%")
+        .replace('_', "\\_")
+}
+
+/// Apply filters to a ranked list of document IDs, preserving rank order.
+///
+/// Uses json_each() to pass ranked IDs efficiently and maintain ordering
+/// via ORDER BY j.key.
+pub fn apply_filters(
+    conn: &Connection,
+    document_ids: &[i64],
+    filters: &SearchFilters,
+) -> Result<Vec<i64>> {
+    if document_ids.is_empty() {
+        return Ok(Vec::new());
+    }
+
+    let ids_json = serde_json::to_string(document_ids)
+        .map_err(|e| crate::core::error::LoreError::Other(e.to_string()))?;
+
+    let mut sql = String::from(
+        "SELECT d.id FROM json_each(?1) AS j JOIN documents d ON d.id = j.value WHERE 1=1",
+    );
+    let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(ids_json)];
+    let mut param_idx = 2;
+
+    if let Some(ref st) = filters.source_type {
+        sql.push_str(&format!(" AND d.source_type = ?{}", param_idx));
+        params.push(Box::new(st.as_str().to_string()));
+        param_idx += 1;
+    }
+
+    if let Some(ref author) = filters.author {
+        sql.push_str(&format!(" AND d.author_username = ?{}", param_idx));
+        params.push(Box::new(author.clone()));
+        param_idx += 1;
+    }
+
+    if let Some(pid) = filters.project_id {
+        sql.push_str(&format!(" AND d.project_id = ?{}", param_idx));
+        params.push(Box::new(pid));
+        param_idx += 1;
+    }
+
+    if let Some(after) = filters.after {
+        sql.push_str(&format!(" AND d.created_at >= ?{}", param_idx));
+        params.push(Box::new(after));
+        param_idx += 1;
+    }
+
+    if let Some(updated_after) = filters.updated_after {
+        sql.push_str(&format!(" AND d.updated_at >= ?{}", param_idx));
+        params.push(Box::new(updated_after));
+        param_idx += 1;
+    }
+
+    for label in &filters.labels {
+        sql.push_str(&format!(
+            " AND EXISTS (SELECT 1 FROM document_labels dl WHERE dl.document_id = d.id AND dl.label_name = ?{})",
+            param_idx
+        ));
+        params.push(Box::new(label.clone()));
+        param_idx += 1;
+    }
+
+    if let Some(ref path_filter) = filters.path {
+        match path_filter {
+            PathFilter::Exact(p) => {
+                sql.push_str(&format!(
+                    " AND EXISTS (SELECT 1 FROM document_paths dp WHERE dp.document_id = d.id AND dp.path = ?{})",
+                    param_idx
+                ));
+                params.push(Box::new(p.clone()));
+                param_idx += 1;
+            }
+            PathFilter::Prefix(p) => {
+                let escaped = escape_like(p);
+                sql.push_str(&format!(
+                    " AND EXISTS (SELECT 1 FROM document_paths dp WHERE dp.document_id = d.id AND dp.path LIKE ?{} ESCAPE '\\')",
+                    param_idx
+                ));
+                params.push(Box::new(format!("{}%", escaped)));
+                param_idx += 1;
+            }
+        }
+    }
+
+    let limit = filters.clamp_limit();
+    sql.push_str(&format!(
+        " ORDER BY j.key LIMIT ?{}",
+        param_idx
+    ));
+    params.push(Box::new(limit as i64));
+
+    let param_refs: Vec<&dyn rusqlite::types::ToSql> = params.iter().map(|p| p.as_ref()).collect();
+
+    let mut stmt = conn.prepare(&sql)?;
+    let ids = stmt
+        .query_map(param_refs.as_slice(), |row| row.get::<_, i64>(0))?
+        .collect::<Result<Vec<i64>, _>>()?;
+
+    Ok(ids)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_has_any_filter_default() {
+        let f = SearchFilters::default();
+        assert!(!f.has_any_filter());
+    }
+
+    #[test]
+    fn test_has_any_filter_with_source_type() {
+        let f = SearchFilters {
+            source_type: Some(SourceType::Issue),
+            ..Default::default()
+        };
+        assert!(f.has_any_filter());
+    }
+
+    #[test]
+    fn test_has_any_filter_with_labels() {
+        let f = SearchFilters {
+            labels: vec!["bug".to_string()],
+            ..Default::default()
+        };
+        assert!(f.has_any_filter());
+    }
+
+    #[test]
+    fn test_limit_clamping_zero() {
+        let f = SearchFilters {
+            limit: 0,
+            ..Default::default()
+        };
+        assert_eq!(f.clamp_limit(), 20);
+    }
+
+    #[test]
+    fn test_limit_clamping_over_max() {
+        let f = SearchFilters {
+            limit: 200,
+            ..Default::default()
+        };
+        assert_eq!(f.clamp_limit(), 100);
+    }
+
+    #[test]
+    fn test_limit_clamping_normal() {
+        let f = SearchFilters {
+            limit: 50,
+            ..Default::default()
+        };
+        assert_eq!(f.clamp_limit(), 50);
+    }
+
+    #[test]
+    fn test_escape_like() {
+        assert_eq!(escape_like("src/%.rs"), "src/\\%.rs");
+        assert_eq!(escape_like("file_name"), "file\\_name");
+        assert_eq!(escape_like("normal"), "normal");
+        assert_eq!(escape_like("a\\b"), "a\\\\b");
+    }
+
+    #[test]
+    fn test_empty_ids() {
+        // Cannot test apply_filters without DB, but we can verify empty returns empty
+        // by testing the early return path logic
+        let f = SearchFilters::default();
+        assert!(!f.has_any_filter());
+    }
+}
diff --git a/src/search/fts.rs b/src/search/fts.rs
new file mode 100644
index 0000000..8b2a26e
--- /dev/null
+++ b/src/search/fts.rs
@@ -0,0 +1,228 @@
+use crate::core::error::Result;
+use rusqlite::Connection;
+
+/// FTS query mode.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum FtsQueryMode {
+    /// Safe mode: each token wrapped in quotes, trailing * preserved on alphanumeric tokens.
+    Safe,
+    /// Raw mode: query passed directly to FTS5 (for advanced users).
+    Raw,
+}
+
+/// A single FTS5 search result.
+#[derive(Debug)]
+pub struct FtsResult {
+    pub document_id: i64,
+    pub bm25_score: f64,
+    pub snippet: String,
+}
+
+/// Convert raw user input into a safe FTS5 query.
+///
+/// Safe mode:
+/// - Splits on whitespace
+/// - Wraps each token in double quotes (escaping internal quotes)
+/// - Preserves trailing `*` on alphanumeric-only tokens (prefix search)
+///
+/// Raw mode: passes through unchanged.
+pub fn to_fts_query(raw: &str, mode: FtsQueryMode) -> String {
+    match mode {
+        FtsQueryMode::Raw => raw.to_string(),
+        FtsQueryMode::Safe => {
+            let trimmed = raw.trim();
+            if trimmed.is_empty() {
+                return String::new();
+            }
+
+            let tokens: Vec<String> = trimmed
+                .split_whitespace()
+                .map(|token| {
+                    // Check if token ends with * and the rest is alphanumeric
+                    if token.ends_with('*') {
+                        let stem = &token[..token.len() - 1];
+                        if !stem.is_empty() && stem.chars().all(|c| c.is_alphanumeric() || c == '_') {
+                            // Preserve prefix search: "stem"*
+                            let escaped = stem.replace('"', "\"\"");
+                            return format!("\"{}\"*", escaped);
+                        }
+                    }
+                    // Default: wrap in quotes, escape internal quotes
+                    let escaped = token.replace('"', "\"\"");
+                    format!("\"{}\"", escaped)
+                })
+                .collect();
+
+            tokens.join(" ")
+        }
+    }
+}
+
+/// Execute an FTS5 search query.
+///
+/// Returns results ranked by BM25 score (lower = better match) with
+/// contextual snippets highlighting matches.
+pub fn search_fts(
+    conn: &Connection,
+    query: &str,
+    limit: usize,
+    mode: FtsQueryMode,
+) -> Result<Vec<FtsResult>> {
+    let fts_query = to_fts_query(query, mode);
+    if fts_query.is_empty() {
+        return Ok(Vec::new());
+    }
+
+    let sql = r#"
+        SELECT d.id, bm25(documents_fts) AS score,
+               snippet(documents_fts, 1, '<b>', '</b>', '...', 64) AS snip
+        FROM documents_fts
+        JOIN documents d ON d.id = documents_fts.rowid
+        WHERE documents_fts MATCH ?1
+        ORDER BY score
+        LIMIT ?2
+    "#;
+
+    let mut stmt = conn.prepare(sql)?;
+    let results = stmt
+        .query_map(rusqlite::params![fts_query, limit as i64], |row| {
+            Ok(FtsResult {
+                document_id: row.get(0)?,
+                bm25_score: row.get(1)?,
+                snippet: row.get(2)?,
+            })
+        })?
+        .collect::<Result<Vec<_>, _>>()?;
+
+    Ok(results)
+}
+
+/// Generate a fallback snippet for results without FTS snippets.
+/// Truncates at a word boundary and appends "...".
+pub fn generate_fallback_snippet(content_text: &str, max_chars: usize) -> String {
+    if content_text.chars().count() <= max_chars {
+        return content_text.to_string();
+    }
+
+    // Collect the char boundary at max_chars to slice correctly for multi-byte content
+    let byte_end = content_text
+        .char_indices()
+        .nth(max_chars)
+        .map(|(i, _)| i)
+        .unwrap_or(content_text.len());
+    let truncated = &content_text[..byte_end];
+
+    // Walk backward to find a word boundary (space)
+    if let Some(last_space) = truncated.rfind(' ') {
+        format!("{}...", &truncated[..last_space])
+    } else {
+        format!("{}...", truncated)
+    }
+}
+
+/// Get the best snippet: prefer FTS snippet, fall back to truncated content.
+pub fn get_result_snippet(fts_snippet: Option<&str>, content_text: &str) -> String {
+    match fts_snippet {
+        Some(s) if !s.is_empty() => s.to_string(),
+        _ => generate_fallback_snippet(content_text, 200),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_safe_query_basic() {
+        let result = to_fts_query("auth error", FtsQueryMode::Safe);
+        assert_eq!(result, "\"auth\" \"error\"");
+    }
+
+    #[test]
+    fn test_safe_query_prefix() {
+        let result = to_fts_query("auth*", FtsQueryMode::Safe);
+        assert_eq!(result, "\"auth\"*");
+    }
+
+    #[test]
+    fn test_safe_query_special_chars() {
+        let result = to_fts_query("C++", FtsQueryMode::Safe);
+        assert_eq!(result, "\"C++\"");
+    }
+
+    #[test]
+    fn test_safe_query_dash() {
+        let result = to_fts_query("-DWITH_SSL", FtsQueryMode::Safe);
+        assert_eq!(result, "\"-DWITH_SSL\"");
+    }
+
+    #[test]
+    fn test_safe_query_quotes() {
+        let result = to_fts_query("he said \"hello\"", FtsQueryMode::Safe);
+        assert_eq!(result, "\"he\" \"said\" \"\"\"hello\"\"\"");
+    }
+
+    #[test]
+    fn test_raw_mode_passthrough() {
+        let result = to_fts_query("auth OR error", FtsQueryMode::Raw);
+        assert_eq!(result, "auth OR error");
+    }
+
+    #[test]
+    fn test_empty_query() {
+        let result = to_fts_query("", FtsQueryMode::Safe);
+        assert_eq!(result, "");
+
+        let result = to_fts_query("   ", FtsQueryMode::Safe);
+        assert_eq!(result, "");
+    }
+
+    #[test]
+    fn test_prefix_only_alphanumeric() {
+        // Non-alphanumeric prefix: C++* should NOT be treated as prefix search
+        let result = to_fts_query("C++*", FtsQueryMode::Safe);
+        assert_eq!(result, "\"C++*\"");
+
+        // Pure alphanumeric prefix: auth* should be prefix search
+        let result = to_fts_query("auth*", FtsQueryMode::Safe);
+        assert_eq!(result, "\"auth\"*");
+    }
+
+    #[test]
+    fn test_prefix_with_underscore() {
+        let result = to_fts_query("jwt_token*", FtsQueryMode::Safe);
+        assert_eq!(result, "\"jwt_token\"*");
+    }
+
+    #[test]
+    fn test_fallback_snippet_short() {
+        let result = generate_fallback_snippet("Short content", 200);
+        assert_eq!(result, "Short content");
+    }
+
+    #[test]
+    fn test_fallback_snippet_word_boundary() {
+        let content = "This is a moderately long piece of text that should be truncated at a word boundary for readability purposes";
+        let result = generate_fallback_snippet(content, 50);
+        assert!(result.ends_with("..."));
+        assert!(result.len() <= 55); // 50 + "..."
+    }
+
+    #[test]
+    fn test_get_result_snippet_prefers_fts() {
+        let result = get_result_snippet(Some("FTS match"), "full content text");
+        assert_eq!(result, "FTS match");
+    }
+
+    #[test]
+    fn test_get_result_snippet_fallback() {
+        let result = get_result_snippet(None, "full content text");
+        assert_eq!(result, "full content text");
+    }
+
+    #[test]
+    fn test_get_result_snippet_empty_fts() {
+        let result = get_result_snippet(Some(""), "full content text");
+        assert_eq!(result, "full content text");
+    }
+}
diff --git a/src/search/hybrid.rs b/src/search/hybrid.rs
new file mode 100644
index 0000000..f1ea182
--- /dev/null
+++ b/src/search/hybrid.rs
@@ -0,0 +1,258 @@
+//! Hybrid search orchestrator combining FTS5 + sqlite-vec via RRF.
+
+use rusqlite::Connection;
+
+use crate::core::error::Result;
+use crate::embedding::ollama::OllamaClient;
+use crate::search::{rank_rrf, search_fts, search_vector, FtsQueryMode};
+use crate::search::filters::{apply_filters, SearchFilters};
+
+const BASE_RECALL_MIN: usize = 50;
+const FILTERED_RECALL_MIN: usize = 200;
+const RECALL_CAP: usize = 1500;
+
+/// Search mode selection.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum SearchMode {
+    Hybrid,
+    Lexical,
+    Semantic,
+}
+
+impl SearchMode {
+    pub fn parse(s: &str) -> Option<Self> {
+        match s.to_lowercase().as_str() {
+            "hybrid" => Some(Self::Hybrid),
+            "lexical" | "fts" => Some(Self::Lexical),
+            "semantic" | "vector" => Some(Self::Semantic),
+            _ => None,
+        }
+    }
+
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            Self::Hybrid => "hybrid",
+            Self::Lexical => "lexical",
+            Self::Semantic => "semantic",
+        }
+    }
+}
+
+/// Combined search result with provenance from both retrieval lists.
+pub struct HybridResult {
+    pub document_id: i64,
+    pub score: f64,
+    pub vector_rank: Option<usize>,
+    pub fts_rank: Option<usize>,
+    pub rrf_score: f64,
+}
+
+/// Execute hybrid search, returning ranked results + any warnings.
+///
+/// `client` is `Option` to enable graceful degradation: when Ollama is
+/// unavailable, the caller passes `None` and hybrid mode falls back to
+/// FTS-only with a warning.
+pub async fn search_hybrid(
+    conn: &Connection,
+    client: Option<&OllamaClient>,
+    query: &str,
+    mode: SearchMode,
+    filters: &SearchFilters,
+    fts_mode: FtsQueryMode,
+) -> Result<(Vec<HybridResult>, Vec<String>)> {
+    let mut warnings: Vec<String> = Vec::new();
+
+    // Adaptive recall
+    let requested = filters.clamp_limit();
+    let top_k = if filters.has_any_filter() {
+        (requested * 50).max(FILTERED_RECALL_MIN).min(RECALL_CAP)
+    } else {
+        (requested * 10).max(BASE_RECALL_MIN).min(RECALL_CAP)
+    };
+
+    let (fts_tuples, vec_tuples) = match mode {
+        SearchMode::Lexical => {
+            let fts_results = search_fts(conn, query, top_k, fts_mode)?;
+            let fts_tuples: Vec<(i64, f64)> = fts_results
+                .iter()
+                .map(|r| (r.document_id, r.bm25_score))
+                .collect();
+            (fts_tuples, Vec::new())
+        }
+
+        SearchMode::Semantic => {
+            let Some(client) = client else {
+                return Err(crate::core::error::LoreError::Other(
+                    "Semantic search requires Ollama. Start Ollama or use --mode=lexical.".into(),
+                ));
+            };
+
+            let query_embedding = client.embed_batch(vec![query.to_string()]).await?;
+            let embedding = query_embedding
+                .into_iter()
+                .next()
+                .unwrap_or_default();
+
+            if embedding.is_empty() {
+                return Err(crate::core::error::LoreError::Other(
+                    "Ollama returned empty embedding for query.".into(),
+                ));
+            }
+
+            let vec_results = search_vector(conn, &embedding, top_k)?;
+            let vec_tuples: Vec<(i64, f64)> = vec_results
+                .iter()
+                .map(|r| (r.document_id, r.distance))
+                .collect();
+            (Vec::new(), vec_tuples)
+        }
+
+        SearchMode::Hybrid => {
+            let fts_results = search_fts(conn, query, top_k, fts_mode)?;
+            let fts_tuples: Vec<(i64, f64)> = fts_results
+                .iter()
+                .map(|r| (r.document_id, r.bm25_score))
+                .collect();
+
+            match client {
+                Some(client) => {
+                    match client.embed_batch(vec![query.to_string()]).await {
+                        Ok(query_embedding) => {
+                            let embedding = query_embedding
+                                .into_iter()
+                                .next()
+                                .unwrap_or_default();
+
+                            let vec_tuples = if embedding.is_empty() {
+                                warnings.push(
+                                    "Ollama returned empty embedding, using FTS only.".into(),
+                                );
+                                Vec::new()
+                            } else {
+                                let vec_results = search_vector(conn, &embedding, top_k)?;
+                                vec_results
+                                    .iter()
+                                    .map(|r| (r.document_id, r.distance))
+                                    .collect()
+                            };
+
+                            (fts_tuples, vec_tuples)
+                        }
+                        Err(e) => {
+                            warnings.push(
+                                format!("Embedding failed ({}), falling back to lexical search.", e),
+                            );
+                            (fts_tuples, Vec::new())
+                        }
+                    }
+                }
+                None => {
+                    warnings.push(
+                        "Ollama unavailable, falling back to lexical search.".into(),
+                    );
+                    (fts_tuples, Vec::new())
+                }
+            }
+        }
+    };
+
+    let ranked = rank_rrf(&vec_tuples, &fts_tuples);
+
+    let results: Vec<HybridResult> = ranked
+        .into_iter()
+        .map(|r| HybridResult {
+            document_id: r.document_id,
+            score: r.normalized_score,
+            vector_rank: r.vector_rank,
+            fts_rank: r.fts_rank,
+            rrf_score: r.rrf_score,
+        })
+        .collect();
+
+    // Apply post-retrieval filters and limit
+    let limit = filters.clamp_limit();
+    let results = if filters.has_any_filter() {
+        let all_ids: Vec<i64> = results.iter().map(|r| r.document_id).collect();
+        let filtered_ids = apply_filters(conn, &all_ids, filters)?;
+        let filtered_set: std::collections::HashSet<i64> = filtered_ids.iter().copied().collect();
+        results
+            .into_iter()
+            .filter(|r| filtered_set.contains(&r.document_id))
+            .take(limit)
+            .collect()
+    } else {
+        results.into_iter().take(limit).collect()
+    };
+
+    Ok((results, warnings))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_search_mode_from_str() {
+        assert_eq!(SearchMode::parse("hybrid"), Some(SearchMode::Hybrid));
+        assert_eq!(SearchMode::parse("lexical"), Some(SearchMode::Lexical));
+        assert_eq!(SearchMode::parse("fts"), Some(SearchMode::Lexical));
+        assert_eq!(SearchMode::parse("semantic"), Some(SearchMode::Semantic));
+        assert_eq!(SearchMode::parse("vector"), Some(SearchMode::Semantic));
+        assert_eq!(SearchMode::parse("HYBRID"), Some(SearchMode::Hybrid));
+        assert_eq!(SearchMode::parse("invalid"), None);
+        assert_eq!(SearchMode::parse(""), None);
+    }
+
+    #[test]
+    fn test_search_mode_as_str() {
+        assert_eq!(SearchMode::Hybrid.as_str(), "hybrid");
+        assert_eq!(SearchMode::Lexical.as_str(), "lexical");
+        assert_eq!(SearchMode::Semantic.as_str(), "semantic");
+    }
+
+    #[test]
+    fn test_adaptive_recall_unfiltered() {
+        let filters = SearchFilters {
+            limit: 20,
+            ..Default::default()
+        };
+        let requested = filters.clamp_limit();
+        let top_k = (requested * 10).max(BASE_RECALL_MIN).min(RECALL_CAP);
+        assert_eq!(top_k, 200);
+    }
+
+    #[test]
+    fn test_adaptive_recall_filtered() {
+        let filters = SearchFilters {
+            limit: 20,
+            author: Some("alice".to_string()),
+            ..Default::default()
+        };
+        let requested = filters.clamp_limit();
+        let top_k = (requested * 50).max(FILTERED_RECALL_MIN).min(RECALL_CAP);
+        assert_eq!(top_k, 1000);
+    }
+
+    #[test]
+    fn test_adaptive_recall_cap() {
+        let filters = SearchFilters {
+            limit: 100,
+            author: Some("alice".to_string()),
+            ..Default::default()
+        };
+        let requested = filters.clamp_limit();
+        let top_k = (requested * 50).max(FILTERED_RECALL_MIN).min(RECALL_CAP);
+        assert_eq!(top_k, RECALL_CAP); // 5000 capped to 1500
+    }
+
+    #[test]
+    fn test_adaptive_recall_minimum() {
+        let filters = SearchFilters {
+            limit: 1,
+            ..Default::default()
+        };
+        let requested = filters.clamp_limit();
+        let top_k = (requested * 10).max(BASE_RECALL_MIN).min(RECALL_CAP);
+        assert_eq!(top_k, BASE_RECALL_MIN); // 10 -> 50
+    }
+}
diff --git a/src/search/mod.rs b/src/search/mod.rs
new file mode 100644
index 0000000..6d0032e
--- /dev/null
+++ b/src/search/mod.rs
@@ -0,0 +1,14 @@
+mod filters;
+mod fts;
+mod hybrid;
+mod rrf;
+mod vector;
+
+pub use fts::{
+    generate_fallback_snippet, get_result_snippet, search_fts, to_fts_query, FtsQueryMode,
+    FtsResult,
+};
+pub use filters::{apply_filters, PathFilter, SearchFilters};
+pub use rrf::{rank_rrf, RrfResult};
+pub use vector::{search_vector, VectorResult};
+pub use hybrid::{search_hybrid, HybridResult, SearchMode};
diff --git a/src/search/rrf.rs b/src/search/rrf.rs
new file mode 100644
index 0000000..034e71c
--- /dev/null
+++ b/src/search/rrf.rs
@@ -0,0 +1,178 @@
+use std::collections::HashMap;
+
+const RRF_K: f64 = 60.0;
+
+/// A single result from Reciprocal Rank Fusion, containing both raw and
+/// normalized scores plus per-list rank provenance for --explain output.
+pub struct RrfResult {
+    pub document_id: i64,
+    /// Raw RRF score: sum of 1/(k + rank) across all lists.
+    pub rrf_score: f64,
+    /// Normalized to [0, 1] where the best result is 1.0.
+    pub normalized_score: f64,
+    /// 1-indexed rank in the vector results list, if present.
+    pub vector_rank: Option<usize>,
+    /// 1-indexed rank in the FTS results list, if present.
+    pub fts_rank: Option<usize>,
+}
+
+/// Combine vector and FTS retrieval results using Reciprocal Rank Fusion.
+///
+/// Input tuples are `(document_id, score/distance)` — already sorted by each retriever.
+/// Ranks are 1-indexed (first result = rank 1).
+///
+/// Score = sum of 1/(k + rank) for each list containing the document.
+pub fn rank_rrf(
+    vector_results: &[(i64, f64)],
+    fts_results: &[(i64, f64)],
+) -> Vec<RrfResult> {
+    if vector_results.is_empty() && fts_results.is_empty() {
+        return Vec::new();
+    }
+
+    // (rrf_score, vector_rank, fts_rank)
+    let mut scores: HashMap<i64, (f64, Option<usize>, Option<usize>)> = HashMap::new();
+
+    for (i, &(doc_id, _)) in vector_results.iter().enumerate() {
+        let rank = i + 1; // 1-indexed
+        let entry = scores.entry(doc_id).or_insert((0.0, None, None));
+        entry.0 += 1.0 / (RRF_K + rank as f64);
+        if entry.1.is_none() {
+            entry.1 = Some(rank);
+        }
+    }
+
+    for (i, &(doc_id, _)) in fts_results.iter().enumerate() {
+        let rank = i + 1; // 1-indexed
+        let entry = scores.entry(doc_id).or_insert((0.0, None, None));
+        entry.0 += 1.0 / (RRF_K + rank as f64);
+        if entry.2.is_none() {
+            entry.2 = Some(rank);
+        }
+    }
+
+    let mut results: Vec<RrfResult> = scores
+        .into_iter()
+        .map(|(doc_id, (rrf_score, vector_rank, fts_rank))| RrfResult {
+            document_id: doc_id,
+            rrf_score,
+            normalized_score: 0.0, // filled in below
+            vector_rank,
+            fts_rank,
+        })
+        .collect();
+
+    // Sort descending by rrf_score
+    results.sort_by(|a, b| b.rrf_score.partial_cmp(&a.rrf_score).unwrap_or(std::cmp::Ordering::Equal));
+
+    // Normalize: best = 1.0
+    if let Some(max_score) = results.first().map(|r| r.rrf_score) {
+        if max_score > 0.0 {
+            for result in &mut results {
+                result.normalized_score = result.rrf_score / max_score;
+            }
+        }
+    }
+
+    results
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_dual_list_ranks_higher() {
+        let vector = vec![(1, 0.1), (2, 0.2)];
+        let fts = vec![(1, 5.0), (3, 3.0)];
+        let results = rank_rrf(&vector, &fts);
+
+        // Doc 1 appears in both lists, should rank highest
+        assert_eq!(results[0].document_id, 1);
+
+        // Doc 1 score should be higher than doc 2 and doc 3
+        let doc1 = &results[0];
+        let doc2_score = results.iter().find(|r| r.document_id == 2).unwrap().rrf_score;
+        let doc3_score = results.iter().find(|r| r.document_id == 3).unwrap().rrf_score;
+        assert!(doc1.rrf_score > doc2_score);
+        assert!(doc1.rrf_score > doc3_score);
+    }
+
+    #[test]
+    fn test_single_list_included() {
+        let vector = vec![(1, 0.1)];
+        let fts = vec![(2, 5.0)];
+        let results = rank_rrf(&vector, &fts);
+
+        assert_eq!(results.len(), 2);
+        let doc_ids: Vec<i64> = results.iter().map(|r| r.document_id).collect();
+        assert!(doc_ids.contains(&1));
+        assert!(doc_ids.contains(&2));
+    }
+
+    #[test]
+    fn test_normalization() {
+        let vector = vec![(1, 0.1), (2, 0.2)];
+        let fts = vec![(1, 5.0), (3, 3.0)];
+        let results = rank_rrf(&vector, &fts);
+
+        // Best result should have normalized_score = 1.0
+        assert!((results[0].normalized_score - 1.0).abs() < f64::EPSILON);
+
+        // All scores in [0, 1]
+        for r in &results {
+            assert!(r.normalized_score >= 0.0);
+            assert!(r.normalized_score <= 1.0);
+        }
+    }
+
+    #[test]
+    fn test_empty_inputs() {
+        let results = rank_rrf(&[], &[]);
+        assert!(results.is_empty());
+    }
+
+    #[test]
+    fn test_ranks_are_1_indexed() {
+        let vector = vec![(10, 0.1), (20, 0.2)];
+        let fts = vec![(10, 5.0), (30, 3.0)];
+        let results = rank_rrf(&vector, &fts);
+
+        let doc10 = results.iter().find(|r| r.document_id == 10).unwrap();
+        assert_eq!(doc10.vector_rank, Some(1));
+        assert_eq!(doc10.fts_rank, Some(1));
+
+        let doc20 = results.iter().find(|r| r.document_id == 20).unwrap();
+        assert_eq!(doc20.vector_rank, Some(2));
+        assert_eq!(doc20.fts_rank, None);
+
+        let doc30 = results.iter().find(|r| r.document_id == 30).unwrap();
+        assert_eq!(doc30.vector_rank, None);
+        assert_eq!(doc30.fts_rank, Some(2));
+    }
+
+    #[test]
+    fn test_raw_and_normalized_scores() {
+        let vector = vec![(1, 0.1)];
+        let fts = vec![(1, 5.0)];
+        let results = rank_rrf(&vector, &fts);
+
+        assert_eq!(results.len(), 1);
+        let r = &results[0];
+
+        // RRF score = 1/(60+1) + 1/(60+1) = 2/61
+        let expected = 2.0 / 61.0;
+        assert!((r.rrf_score - expected).abs() < 1e-10);
+        assert!((r.normalized_score - 1.0).abs() < f64::EPSILON);
+    }
+
+    #[test]
+    fn test_one_empty_list() {
+        let vector = vec![(1, 0.1), (2, 0.2)];
+        let results = rank_rrf(&vector, &[]);
+
+        assert_eq!(results.len(), 2);
+        // Single result should still have normalized_score = 1.0
+        assert!((results[0].normalized_score - 1.0).abs() < f64::EPSILON);
+    }
+}
diff --git a/src/search/vector.rs b/src/search/vector.rs
new file mode 100644
index 0000000..5939323
--- /dev/null
+++ b/src/search/vector.rs
@@ -0,0 +1,139 @@
+use std::collections::HashMap;
+
+use rusqlite::Connection;
+
+use crate::core::error::Result;
+use crate::embedding::chunk_ids::decode_rowid;
+
+/// A single vector search result (document-level, deduplicated).
+#[derive(Debug)]
+pub struct VectorResult {
+    pub document_id: i64,
+    pub distance: f64,
+}
+
+/// Search documents using sqlite-vec KNN query.
+///
+/// Over-fetches 3x limit to handle chunk deduplication (multiple chunks per
+/// document produce multiple KNN results for the same document_id).
+/// Returns deduplicated results with best (lowest) distance per document.
+pub fn search_vector(
+    conn: &Connection,
+    query_embedding: &[f32],
+    limit: usize,
+) -> Result<Vec<VectorResult>> {
+    if query_embedding.is_empty() || limit == 0 {
+        return Ok(Vec::new());
+    }
+
+    // Convert to raw little-endian bytes for sqlite-vec
+    let embedding_bytes: Vec<u8> = query_embedding
+        .iter()
+        .flat_map(|f| f.to_le_bytes())
+        .collect();
+
+    let k = limit * 3; // Over-fetch for dedup
+
+    let mut stmt = conn.prepare(
+        "SELECT rowid, distance
+         FROM embeddings
+         WHERE embedding MATCH ?1
+           AND k = ?2
+         ORDER BY distance"
+    )?;
+
+    let rows: Vec<(i64, f64)> = stmt
+        .query_map(rusqlite::params![embedding_bytes, k as i64], |row| {
+            Ok((row.get(0)?, row.get(1)?))
+        })?
+        .collect::<Result<Vec<_>, _>>()?;
+
+    // Dedup by document_id, keeping best (lowest) distance
+    let mut best: HashMap<i64, f64> = HashMap::new();
+    for (rowid, distance) in rows {
+        let (document_id, _chunk_index) = decode_rowid(rowid);
+        best.entry(document_id)
+            .and_modify(|d| {
+                if distance < *d {
+                    *d = distance;
+                }
+            })
+            .or_insert(distance);
+    }
+
+    // Sort by distance ascending, take limit
+    let mut results: Vec<VectorResult> = best
+        .into_iter()
+        .map(|(document_id, distance)| VectorResult {
+            document_id,
+            distance,
+        })
+        .collect();
+    results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap_or(std::cmp::Ordering::Equal));
+    results.truncate(limit);
+
+    Ok(results)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // Note: Full integration tests require sqlite-vec loaded, which happens via
+    // create_connection in db.rs. These are basic unit tests for the dedup logic.
+
+    #[test]
+    fn test_empty_returns_empty() {
+        // Can't test KNN without sqlite-vec, but we can test edge cases
+        let result = search_vector_dedup(vec![], 10);
+        assert!(result.is_empty());
+    }
+
+    #[test]
+    fn test_dedup_keeps_best_distance() {
+        // Simulate: doc 1 has chunks at rowid 1000 (idx 0) and 1001 (idx 1)
+        let rows = vec![
+            (1000_i64, 0.5_f64), // doc 1, chunk 0
+            (1001, 0.3),         // doc 1, chunk 1 (better)
+            (2000, 0.4),         // doc 2, chunk 0
+        ];
+        let results = search_vector_dedup(rows, 10);
+        assert_eq!(results.len(), 2);
+        assert_eq!(results[0].document_id, 1); // doc 1 best = 0.3
+        assert!((results[0].distance - 0.3).abs() < f64::EPSILON);
+        assert_eq!(results[1].document_id, 2); // doc 2 = 0.4
+    }
+
+    #[test]
+    fn test_dedup_respects_limit() {
+        let rows = vec![
+            (1000_i64, 0.1_f64),
+            (2000, 0.2),
+            (3000, 0.3),
+        ];
+        let results = search_vector_dedup(rows, 2);
+        assert_eq!(results.len(), 2);
+    }
+
+    /// Helper for testing dedup logic without sqlite-vec
+    fn search_vector_dedup(rows: Vec<(i64, f64)>, limit: usize) -> Vec<VectorResult> {
+        let mut best: HashMap<i64, f64> = HashMap::new();
+        for (rowid, distance) in rows {
+            let (document_id, _) = decode_rowid(rowid);
+            best.entry(document_id)
+                .and_modify(|d| {
+                    if distance < *d {
+                        *d = distance;
+                    }
+                })
+                .or_insert(distance);
+        }
+        let mut results: Vec<VectorResult> = best
+            .into_iter()
+            .map(|(document_id, distance)| VectorResult { document_id, distance })
+            .collect();
+        results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap_or(std::cmp::Ordering::Equal));
+        results.truncate(limit);
+        results
+    }
+}