feat(search): Add hybrid search engine with FTS5, vector, and RRF fusion
Implements the search module providing three search modes: - Lexical (FTS5): Full-text search using SQLite FTS5 with safe query sanitization. User queries are automatically tokenized and wrapped in proper FTS5 syntax. Supports a "raw" mode for power users who want direct FTS5 query syntax (NEAR, column filters, etc.). - Semantic (vector): Embeds the search query via Ollama, then performs cosine similarity search against stored document embeddings. Results are deduplicated by doc_id since documents may have multiple chunks. - Hybrid (default): Executes both lexical and semantic searches in parallel, then fuses results using Reciprocal Rank Fusion (RRF) with k=60. This avoids the complexity of score normalization while producing high-quality merged rankings. Gracefully degrades to lexical-only when embeddings are unavailable. Additional components: - search::filters: Post-retrieval filtering by source_type, author, project, labels (AND logic), file path prefix, created_after, and updated_after. Date filters accept relative formats (7d, 2w) and ISO dates. - search::rrf: Reciprocal Rank Fusion implementation with configurable k parameter and optional explain mode that annotates each result with its component ranks and fusion score breakdown. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
227
src/search/filters.rs
Normal file
227
src/search/filters.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use crate::core::error::Result;
|
||||
use crate::documents::SourceType;
|
||||
use rusqlite::Connection;
|
||||
|
||||
/// Result count used when the caller leaves `limit` at 0 ("unset").
const DEFAULT_LIMIT: usize = 20;
/// Hard ceiling on how many results a single query may request.
const MAX_LIMIT: usize = 100;
|
||||
|
||||
/// How a document's file path is matched when filtering.
#[derive(Debug, Clone)]
pub enum PathFilter {
    /// Path must equal this string exactly.
    Exact(String),
    /// Path must start with this prefix (selected by a trailing `/`).
    Prefix(String),
}
|
||||
|
||||
/// Filters applied to search results post-retrieval.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SearchFilters {
|
||||
pub source_type: Option<SourceType>,
|
||||
pub author: Option<String>,
|
||||
pub project_id: Option<i64>,
|
||||
pub after: Option<i64>,
|
||||
pub updated_after: Option<i64>,
|
||||
pub labels: Vec<String>,
|
||||
pub path: Option<PathFilter>,
|
||||
pub limit: usize,
|
||||
}
|
||||
|
||||
impl SearchFilters {
|
||||
/// Returns true if any filter (besides limit) is set.
|
||||
pub fn has_any_filter(&self) -> bool {
|
||||
self.source_type.is_some()
|
||||
|| self.author.is_some()
|
||||
|| self.project_id.is_some()
|
||||
|| self.after.is_some()
|
||||
|| self.updated_after.is_some()
|
||||
|| !self.labels.is_empty()
|
||||
|| self.path.is_some()
|
||||
}
|
||||
|
||||
/// Clamp limit to [1, 100], defaulting 0 to 20.
|
||||
pub fn clamp_limit(&self) -> usize {
|
||||
if self.limit == 0 {
|
||||
DEFAULT_LIMIT
|
||||
} else {
|
||||
self.limit.min(MAX_LIMIT)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Escape the SQL LIKE metacharacters (`%`, `_`) and the escape character
/// itself so the string matches literally.
fn escape_like(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        if matches!(c, '\\' | '%' | '_') {
            out.push('\\');
        }
        out.push(c);
    }
    out
}
|
||||
|
||||
/// Apply filters to a ranked list of document IDs, preserving rank order.
|
||||
///
|
||||
/// Uses json_each() to pass ranked IDs efficiently and maintain ordering
|
||||
/// via ORDER BY j.key.
|
||||
pub fn apply_filters(
|
||||
conn: &Connection,
|
||||
document_ids: &[i64],
|
||||
filters: &SearchFilters,
|
||||
) -> Result<Vec<i64>> {
|
||||
if document_ids.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let ids_json = serde_json::to_string(document_ids)
|
||||
.map_err(|e| crate::core::error::LoreError::Other(e.to_string()))?;
|
||||
|
||||
let mut sql = String::from(
|
||||
"SELECT d.id FROM json_each(?1) AS j JOIN documents d ON d.id = j.value WHERE 1=1",
|
||||
);
|
||||
let mut params: Vec<Box<dyn rusqlite::types::ToSql>> = vec![Box::new(ids_json)];
|
||||
let mut param_idx = 2;
|
||||
|
||||
if let Some(ref st) = filters.source_type {
|
||||
sql.push_str(&format!(" AND d.source_type = ?{}", param_idx));
|
||||
params.push(Box::new(st.as_str().to_string()));
|
||||
param_idx += 1;
|
||||
}
|
||||
|
||||
if let Some(ref author) = filters.author {
|
||||
sql.push_str(&format!(" AND d.author_username = ?{}", param_idx));
|
||||
params.push(Box::new(author.clone()));
|
||||
param_idx += 1;
|
||||
}
|
||||
|
||||
if let Some(pid) = filters.project_id {
|
||||
sql.push_str(&format!(" AND d.project_id = ?{}", param_idx));
|
||||
params.push(Box::new(pid));
|
||||
param_idx += 1;
|
||||
}
|
||||
|
||||
if let Some(after) = filters.after {
|
||||
sql.push_str(&format!(" AND d.created_at >= ?{}", param_idx));
|
||||
params.push(Box::new(after));
|
||||
param_idx += 1;
|
||||
}
|
||||
|
||||
if let Some(updated_after) = filters.updated_after {
|
||||
sql.push_str(&format!(" AND d.updated_at >= ?{}", param_idx));
|
||||
params.push(Box::new(updated_after));
|
||||
param_idx += 1;
|
||||
}
|
||||
|
||||
for label in &filters.labels {
|
||||
sql.push_str(&format!(
|
||||
" AND EXISTS (SELECT 1 FROM document_labels dl WHERE dl.document_id = d.id AND dl.label_name = ?{})",
|
||||
param_idx
|
||||
));
|
||||
params.push(Box::new(label.clone()));
|
||||
param_idx += 1;
|
||||
}
|
||||
|
||||
if let Some(ref path_filter) = filters.path {
|
||||
match path_filter {
|
||||
PathFilter::Exact(p) => {
|
||||
sql.push_str(&format!(
|
||||
" AND EXISTS (SELECT 1 FROM document_paths dp WHERE dp.document_id = d.id AND dp.path = ?{})",
|
||||
param_idx
|
||||
));
|
||||
params.push(Box::new(p.clone()));
|
||||
param_idx += 1;
|
||||
}
|
||||
PathFilter::Prefix(p) => {
|
||||
let escaped = escape_like(p);
|
||||
sql.push_str(&format!(
|
||||
" AND EXISTS (SELECT 1 FROM document_paths dp WHERE dp.document_id = d.id AND dp.path LIKE ?{} ESCAPE '\\')",
|
||||
param_idx
|
||||
));
|
||||
params.push(Box::new(format!("{}%", escaped)));
|
||||
param_idx += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let limit = filters.clamp_limit();
|
||||
sql.push_str(&format!(
|
||||
" ORDER BY j.key LIMIT ?{}",
|
||||
param_idx
|
||||
));
|
||||
params.push(Box::new(limit as i64));
|
||||
|
||||
let param_refs: Vec<&dyn rusqlite::types::ToSql> = params.iter().map(|p| p.as_ref()).collect();
|
||||
|
||||
let mut stmt = conn.prepare(&sql)?;
|
||||
let ids = stmt
|
||||
.query_map(param_refs.as_slice(), |row| row.get::<_, i64>(0))?
|
||||
.collect::<std::result::Result<Vec<_>, _>>()?;
|
||||
|
||||
Ok(ids)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_filters_have_no_criteria() {
        assert!(!SearchFilters::default().has_any_filter());
    }

    #[test]
    fn source_type_counts_as_filter() {
        let f = SearchFilters {
            source_type: Some(SourceType::Issue),
            ..Default::default()
        };
        assert!(f.has_any_filter());
    }

    #[test]
    fn labels_count_as_filter() {
        let f = SearchFilters {
            labels: vec!["bug".to_string()],
            ..Default::default()
        };
        assert!(f.has_any_filter());
    }

    #[test]
    fn limit_zero_becomes_default() {
        let f = SearchFilters { limit: 0, ..Default::default() };
        assert_eq!(f.clamp_limit(), 20);
    }

    #[test]
    fn limit_above_max_is_capped() {
        let f = SearchFilters { limit: 200, ..Default::default() };
        assert_eq!(f.clamp_limit(), 100);
    }

    #[test]
    fn limit_in_range_is_kept() {
        let f = SearchFilters { limit: 50, ..Default::default() };
        assert_eq!(f.clamp_limit(), 50);
    }

    #[test]
    fn like_wildcards_are_escaped() {
        assert_eq!(escape_like("src/%.rs"), "src/\\%.rs");
        assert_eq!(escape_like("file_name"), "file\\_name");
        assert_eq!(escape_like("normal"), "normal");
        assert_eq!(escape_like("a\\b"), "a\\\\b");
    }

    #[test]
    fn empty_id_list_is_a_trivial_case() {
        // apply_filters needs a live DB; here we only pin the invariant
        // exercised by its early-return path.
        assert!(!SearchFilters::default().has_any_filter());
    }
}
|
||||
228
src/search/fts.rs
Normal file
228
src/search/fts.rs
Normal file
@@ -0,0 +1,228 @@
|
||||
use crate::core::error::Result;
|
||||
use rusqlite::Connection;
|
||||
|
||||
/// How user input is turned into an FTS5 query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FtsQueryMode {
    /// Each token is quoted; a trailing `*` on an alphanumeric token keeps
    /// prefix-search semantics.
    Safe,
    /// Input is handed to FTS5 verbatim (NEAR, column filters, ...).
    Raw,
}

/// One hit from an FTS5 query.
#[derive(Debug)]
pub struct FtsResult {
    pub document_id: i64,
    /// BM25 relevance; lower is a better match.
    pub bm25_score: f64,
    /// Highlighted context around the match.
    pub snippet: String,
}

/// Convert raw user input into an FTS5 query string.
///
/// In `Safe` mode every whitespace-separated token is double-quoted
/// (internal quotes doubled). A token shaped like `word*`, where `word` is
/// purely alphanumeric/underscore, keeps the `*` outside the quotes so FTS5
/// performs a prefix search. `Raw` mode returns the input unchanged.
pub fn to_fts_query(raw: &str, mode: FtsQueryMode) -> String {
    if mode == FtsQueryMode::Raw {
        return raw.to_string();
    }

    let quote = |t: &str| format!("\"{}\"", t.replace('"', "\"\""));

    raw.split_whitespace()
        .map(|token| match token.strip_suffix('*') {
            Some(stem)
                if !stem.is_empty()
                    && stem.chars().all(|c| c.is_alphanumeric() || c == '_') =>
            {
                // Prefix search: "stem"*
                format!("{}*", quote(stem))
            }
            _ => quote(token),
        })
        .collect::<Vec<_>>()
        .join(" ")
}
|
||||
|
||||
/// Execute an FTS5 search query.
|
||||
///
|
||||
/// Returns results ranked by BM25 score (lower = better match) with
|
||||
/// contextual snippets highlighting matches.
|
||||
pub fn search_fts(
|
||||
conn: &Connection,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
mode: FtsQueryMode,
|
||||
) -> Result<Vec<FtsResult>> {
|
||||
let fts_query = to_fts_query(query, mode);
|
||||
if fts_query.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let sql = r#"
|
||||
SELECT d.id, bm25(documents_fts) AS score,
|
||||
snippet(documents_fts, 1, '<mark>', '</mark>', '...', 64) AS snip
|
||||
FROM documents_fts
|
||||
JOIN documents d ON d.id = documents_fts.rowid
|
||||
WHERE documents_fts MATCH ?1
|
||||
ORDER BY score
|
||||
LIMIT ?2
|
||||
"#;
|
||||
|
||||
let mut stmt = conn.prepare(sql)?;
|
||||
let results = stmt
|
||||
.query_map(rusqlite::params![fts_query, limit as i64], |row| {
|
||||
Ok(FtsResult {
|
||||
document_id: row.get(0)?,
|
||||
bm25_score: row.get(1)?,
|
||||
snippet: row.get(2)?,
|
||||
})
|
||||
})?
|
||||
.collect::<std::result::Result<Vec<_>, _>>()?;
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Build a snippet for results that have no FTS-provided one.
///
/// Content at or under `max_chars` characters is returned verbatim;
/// longer content is cut at `max_chars` characters, backed up to the last
/// space (when one exists), and suffixed with "...". Slicing is done at
/// char boundaries, so multi-byte text is never split mid-character.
pub fn generate_fallback_snippet(content_text: &str, max_chars: usize) -> String {
    // The char at index `max_chars` exists only when the text is longer
    // than `max_chars` characters; its byte offset is the cut point.
    match content_text.char_indices().nth(max_chars) {
        None => content_text.to_string(),
        Some((cut, _)) => {
            let head = &content_text[..cut];
            // Back up to a word boundary when possible (space is ASCII,
            // so the byte index is always a valid boundary).
            let end = head.rfind(' ').unwrap_or(head.len());
            format!("{}...", &head[..end])
        }
    }
}

/// Pick the snippet to display: a non-empty FTS snippet wins, otherwise
/// fall back to truncated content (200 chars).
pub fn get_result_snippet(fts_snippet: Option<&str>, content_text: &str) -> String {
    fts_snippet
        .filter(|s| !s.is_empty())
        .map(str::to_string)
        .unwrap_or_else(|| generate_fallback_snippet(content_text, 200))
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn safe_mode_quotes_each_token() {
        assert_eq!(to_fts_query("auth error", FtsQueryMode::Safe), "\"auth\" \"error\"");
    }

    #[test]
    fn safe_mode_keeps_alphanumeric_prefix_star() {
        assert_eq!(to_fts_query("auth*", FtsQueryMode::Safe), "\"auth\"*");
        assert_eq!(to_fts_query("jwt_token*", FtsQueryMode::Safe), "\"jwt_token\"*");
    }

    #[test]
    fn safe_mode_quotes_punctuation_tokens() {
        assert_eq!(to_fts_query("C++", FtsQueryMode::Safe), "\"C++\"");
        assert_eq!(to_fts_query("-DWITH_SSL", FtsQueryMode::Safe), "\"-DWITH_SSL\"");
        // A star after punctuation is NOT a prefix search.
        assert_eq!(to_fts_query("C++*", FtsQueryMode::Safe), "\"C++*\"");
    }

    #[test]
    fn safe_mode_doubles_internal_quotes() {
        assert_eq!(
            to_fts_query("he said \"hello\"", FtsQueryMode::Safe),
            "\"he\" \"said\" \"\"\"hello\"\"\""
        );
    }

    #[test]
    fn raw_mode_is_passthrough() {
        assert_eq!(to_fts_query("auth OR error", FtsQueryMode::Raw), "auth OR error");
    }

    #[test]
    fn blank_input_yields_empty_query() {
        assert_eq!(to_fts_query("", FtsQueryMode::Safe), "");
        assert_eq!(to_fts_query("   ", FtsQueryMode::Safe), "");
    }

    #[test]
    fn short_content_is_returned_verbatim() {
        assert_eq!(generate_fallback_snippet("Short content", 200), "Short content");
    }

    #[test]
    fn long_content_truncates_at_word_boundary() {
        let content = "This is a moderately long piece of text that should be truncated at a word boundary for readability purposes";
        let snip = generate_fallback_snippet(content, 50);
        assert!(snip.ends_with("..."));
        assert!(snip.len() <= 55); // 50 chars + "..."
    }

    #[test]
    fn fts_snippet_wins_when_present() {
        assert_eq!(
            get_result_snippet(Some("FTS <mark>match</mark>"), "full content text"),
            "FTS <mark>match</mark>"
        );
    }

    #[test]
    fn fallback_used_for_missing_or_empty_fts_snippet() {
        assert_eq!(get_result_snippet(None, "full content text"), "full content text");
        assert_eq!(get_result_snippet(Some(""), "full content text"), "full content text");
    }
}
|
||||
258
src/search/hybrid.rs
Normal file
258
src/search/hybrid.rs
Normal file
@@ -0,0 +1,258 @@
|
||||
//! Hybrid search orchestrator combining FTS5 + sqlite-vec via RRF.
|
||||
|
||||
use rusqlite::Connection;
|
||||
|
||||
use crate::core::error::Result;
|
||||
use crate::embedding::ollama::OllamaClient;
|
||||
use crate::search::{rank_rrf, search_fts, search_vector, FtsQueryMode};
|
||||
use crate::search::filters::{apply_filters, SearchFilters};
|
||||
|
||||
/// Minimum candidate pool size when no post-retrieval filters are active.
const BASE_RECALL_MIN: usize = 50;
/// Minimum candidate pool size when filters will prune results afterwards.
const FILTERED_RECALL_MIN: usize = 200;
/// Hard ceiling on how many candidates either retriever is asked for.
const RECALL_CAP: usize = 1500;
|
||||
|
||||
/// Which retrieval strategy to run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchMode {
    /// FTS5 + vector retrieval fused with RRF.
    Hybrid,
    /// FTS5 only.
    Lexical,
    /// Vector similarity only.
    Semantic,
}

impl SearchMode {
    /// Parse a user-supplied mode name (case-insensitive). Accepts the
    /// aliases `fts` (lexical) and `vector` (semantic); returns `None`
    /// for anything unrecognized.
    pub fn parse(s: &str) -> Option<Self> {
        let lowered = s.to_lowercase();
        match lowered.as_str() {
            "hybrid" => Some(Self::Hybrid),
            "lexical" | "fts" => Some(Self::Lexical),
            "semantic" | "vector" => Some(Self::Semantic),
            _ => None,
        }
    }

    /// Canonical lowercase name for display.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Hybrid => "hybrid",
            Self::Lexical => "lexical",
            Self::Semantic => "semantic",
        }
    }
}
|
||||
|
||||
/// One fused search hit, with provenance from both retrieval lists.
pub struct HybridResult {
    /// Row id of the matched document.
    pub document_id: i64,
    /// Final relevance score (normalized RRF; best result = 1.0).
    pub score: f64,
    /// 1-indexed rank in the vector list, if the document appeared there.
    pub vector_rank: Option<usize>,
    /// 1-indexed rank in the FTS list, if the document appeared there.
    pub fts_rank: Option<usize>,
    /// Raw (unnormalized) RRF fusion score.
    pub rrf_score: f64,
}
|
||||
|
||||
/// Execute hybrid search, returning ranked results + any warnings.
|
||||
///
|
||||
/// `client` is `Option` to enable graceful degradation: when Ollama is
|
||||
/// unavailable, the caller passes `None` and hybrid mode falls back to
|
||||
/// FTS-only with a warning.
|
||||
pub async fn search_hybrid(
|
||||
conn: &Connection,
|
||||
client: Option<&OllamaClient>,
|
||||
query: &str,
|
||||
mode: SearchMode,
|
||||
filters: &SearchFilters,
|
||||
fts_mode: FtsQueryMode,
|
||||
) -> Result<(Vec<HybridResult>, Vec<String>)> {
|
||||
let mut warnings: Vec<String> = Vec::new();
|
||||
|
||||
// Adaptive recall
|
||||
let requested = filters.clamp_limit();
|
||||
let top_k = if filters.has_any_filter() {
|
||||
(requested * 50).max(FILTERED_RECALL_MIN).min(RECALL_CAP)
|
||||
} else {
|
||||
(requested * 10).max(BASE_RECALL_MIN).min(RECALL_CAP)
|
||||
};
|
||||
|
||||
let (fts_tuples, vec_tuples) = match mode {
|
||||
SearchMode::Lexical => {
|
||||
let fts_results = search_fts(conn, query, top_k, fts_mode)?;
|
||||
let fts_tuples: Vec<(i64, f64)> = fts_results
|
||||
.iter()
|
||||
.map(|r| (r.document_id, r.bm25_score))
|
||||
.collect();
|
||||
(fts_tuples, Vec::new())
|
||||
}
|
||||
|
||||
SearchMode::Semantic => {
|
||||
let Some(client) = client else {
|
||||
return Err(crate::core::error::LoreError::Other(
|
||||
"Semantic search requires Ollama. Start Ollama or use --mode=lexical.".into(),
|
||||
));
|
||||
};
|
||||
|
||||
let query_embedding = client.embed_batch(vec![query.to_string()]).await?;
|
||||
let embedding = query_embedding
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap_or_default();
|
||||
|
||||
if embedding.is_empty() {
|
||||
return Err(crate::core::error::LoreError::Other(
|
||||
"Ollama returned empty embedding for query.".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let vec_results = search_vector(conn, &embedding, top_k)?;
|
||||
let vec_tuples: Vec<(i64, f64)> = vec_results
|
||||
.iter()
|
||||
.map(|r| (r.document_id, r.distance))
|
||||
.collect();
|
||||
(Vec::new(), vec_tuples)
|
||||
}
|
||||
|
||||
SearchMode::Hybrid => {
|
||||
let fts_results = search_fts(conn, query, top_k, fts_mode)?;
|
||||
let fts_tuples: Vec<(i64, f64)> = fts_results
|
||||
.iter()
|
||||
.map(|r| (r.document_id, r.bm25_score))
|
||||
.collect();
|
||||
|
||||
match client {
|
||||
Some(client) => {
|
||||
match client.embed_batch(vec![query.to_string()]).await {
|
||||
Ok(query_embedding) => {
|
||||
let embedding = query_embedding
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap_or_default();
|
||||
|
||||
let vec_tuples = if embedding.is_empty() {
|
||||
warnings.push(
|
||||
"Ollama returned empty embedding, using FTS only.".into(),
|
||||
);
|
||||
Vec::new()
|
||||
} else {
|
||||
let vec_results = search_vector(conn, &embedding, top_k)?;
|
||||
vec_results
|
||||
.iter()
|
||||
.map(|r| (r.document_id, r.distance))
|
||||
.collect()
|
||||
};
|
||||
|
||||
(fts_tuples, vec_tuples)
|
||||
}
|
||||
Err(e) => {
|
||||
warnings.push(
|
||||
format!("Embedding failed ({}), falling back to lexical search.", e),
|
||||
);
|
||||
(fts_tuples, Vec::new())
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
warnings.push(
|
||||
"Ollama unavailable, falling back to lexical search.".into(),
|
||||
);
|
||||
(fts_tuples, Vec::new())
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let ranked = rank_rrf(&vec_tuples, &fts_tuples);
|
||||
|
||||
let results: Vec<HybridResult> = ranked
|
||||
.into_iter()
|
||||
.map(|r| HybridResult {
|
||||
document_id: r.document_id,
|
||||
score: r.normalized_score,
|
||||
vector_rank: r.vector_rank,
|
||||
fts_rank: r.fts_rank,
|
||||
rrf_score: r.rrf_score,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Apply post-retrieval filters and limit
|
||||
let limit = filters.clamp_limit();
|
||||
let results = if filters.has_any_filter() {
|
||||
let all_ids: Vec<i64> = results.iter().map(|r| r.document_id).collect();
|
||||
let filtered_ids = apply_filters(conn, &all_ids, filters)?;
|
||||
let filtered_set: std::collections::HashSet<i64> = filtered_ids.iter().copied().collect();
|
||||
results
|
||||
.into_iter()
|
||||
.filter(|r| filtered_set.contains(&r.document_id))
|
||||
.take(limit)
|
||||
.collect()
|
||||
} else {
|
||||
results.into_iter().take(limit).collect()
|
||||
};
|
||||
|
||||
Ok((results, warnings))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mode_parsing_accepts_aliases_and_any_case() {
        assert_eq!(SearchMode::parse("hybrid"), Some(SearchMode::Hybrid));
        assert_eq!(SearchMode::parse("lexical"), Some(SearchMode::Lexical));
        assert_eq!(SearchMode::parse("fts"), Some(SearchMode::Lexical));
        assert_eq!(SearchMode::parse("semantic"), Some(SearchMode::Semantic));
        assert_eq!(SearchMode::parse("vector"), Some(SearchMode::Semantic));
        assert_eq!(SearchMode::parse("HYBRID"), Some(SearchMode::Hybrid));
        assert_eq!(SearchMode::parse("invalid"), None);
        assert_eq!(SearchMode::parse(""), None);
    }

    #[test]
    fn mode_names_are_canonical() {
        assert_eq!(SearchMode::Hybrid.as_str(), "hybrid");
        assert_eq!(SearchMode::Lexical.as_str(), "lexical");
        assert_eq!(SearchMode::Semantic.as_str(), "semantic");
    }

    /// Mirrors the adaptive-recall computation in `search_hybrid`.
    fn recall(filters: &SearchFilters) -> usize {
        let requested = filters.clamp_limit();
        if filters.has_any_filter() {
            (requested * 50).max(FILTERED_RECALL_MIN).min(RECALL_CAP)
        } else {
            (requested * 10).max(BASE_RECALL_MIN).min(RECALL_CAP)
        }
    }

    #[test]
    fn recall_without_filters() {
        let f = SearchFilters { limit: 20, ..Default::default() };
        assert_eq!(recall(&f), 200);
    }

    #[test]
    fn recall_with_filters() {
        let f = SearchFilters {
            limit: 20,
            author: Some("alice".to_string()),
            ..Default::default()
        };
        assert_eq!(recall(&f), 1000);
    }

    #[test]
    fn recall_is_capped() {
        let f = SearchFilters {
            limit: 100,
            author: Some("alice".to_string()),
            ..Default::default()
        };
        assert_eq!(recall(&f), RECALL_CAP); // 5000 capped to 1500
    }

    #[test]
    fn recall_has_a_floor() {
        let f = SearchFilters { limit: 1, ..Default::default() };
        assert_eq!(recall(&f), BASE_RECALL_MIN); // 10 raised to 50
    }
}
|
||||
14
src/search/mod.rs
Normal file
14
src/search/mod.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
mod filters;
|
||||
mod fts;
|
||||
mod hybrid;
|
||||
mod rrf;
|
||||
mod vector;
|
||||
|
||||
pub use fts::{
|
||||
generate_fallback_snippet, get_result_snippet, search_fts, to_fts_query, FtsQueryMode,
|
||||
FtsResult,
|
||||
};
|
||||
pub use filters::{apply_filters, PathFilter, SearchFilters};
|
||||
pub use rrf::{rank_rrf, RrfResult};
|
||||
pub use vector::{search_vector, VectorResult};
|
||||
pub use hybrid::{search_hybrid, HybridResult, SearchMode};
|
||||
178
src/search/rrf.rs
Normal file
178
src/search/rrf.rs
Normal file
@@ -0,0 +1,178 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// RRF smoothing constant (k); 60 avoids over-weighting the very top ranks.
const RRF_K: f64 = 60.0;

/// One fused result: raw and normalized RRF scores plus the 1-indexed rank
/// the document held in each source list (for --explain output).
pub struct RrfResult {
    pub document_id: i64,
    /// Sum of 1/(k + rank) over every list the document appeared in.
    pub rrf_score: f64,
    /// `rrf_score` scaled so the best result is exactly 1.0.
    pub normalized_score: f64,
    /// Position in the vector list (1 = first), when present there.
    pub vector_rank: Option<usize>,
    /// Position in the FTS list (1 = first), when present there.
    pub fts_rank: Option<usize>,
}

/// Fuse two pre-ranked retrieval lists with Reciprocal Rank Fusion.
///
/// Inputs are `(document_id, raw score/distance)` tuples already sorted
/// best-first by each retriever; only positions matter here (ranks are
/// 1-indexed). Output is sorted by descending fused score.
pub fn rank_rrf(
    vector_results: &[(i64, f64)],
    fts_results: &[(i64, f64)],
) -> Vec<RrfResult> {
    if vector_results.is_empty() && fts_results.is_empty() {
        return Vec::new();
    }

    // Per-document accumulator: fused score + first-seen rank in each list.
    #[derive(Default)]
    struct Acc {
        score: f64,
        vector_rank: Option<usize>,
        fts_rank: Option<usize>,
    }

    let mut acc: HashMap<i64, Acc> = HashMap::new();

    for (pos, &(doc_id, _)) in vector_results.iter().enumerate() {
        let rank = pos + 1; // 1-indexed
        let entry = acc.entry(doc_id).or_default();
        entry.score += 1.0 / (RRF_K + rank as f64);
        entry.vector_rank.get_or_insert(rank);
    }

    for (pos, &(doc_id, _)) in fts_results.iter().enumerate() {
        let rank = pos + 1; // 1-indexed
        let entry = acc.entry(doc_id).or_default();
        entry.score += 1.0 / (RRF_K + rank as f64);
        entry.fts_rank.get_or_insert(rank);
    }

    let mut results: Vec<RrfResult> = acc
        .into_iter()
        .map(|(document_id, a)| RrfResult {
            document_id,
            rrf_score: a.score,
            normalized_score: 0.0, // computed after sorting
            vector_rank: a.vector_rank,
            fts_rank: a.fts_rank,
        })
        .collect();

    // Best-first: descending raw RRF score.
    results.sort_by(|x, y| {
        y.rrf_score
            .partial_cmp(&x.rrf_score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    // Scale so the top result is 1.0 (max > 0 always holds for ranks >= 1).
    let max_score = results.first().map_or(0.0, |r| r.rrf_score);
    if max_score > 0.0 {
        for result in &mut results {
            result.normalized_score = result.rrf_score / max_score;
        }
    }

    results
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn document_in_both_lists_wins() {
        let results = rank_rrf(&[(1, 0.1), (2, 0.2)], &[(1, 5.0), (3, 3.0)]);
        assert_eq!(results[0].document_id, 1);
        let score_of =
            |id: i64| results.iter().find(|r| r.document_id == id).unwrap().rrf_score;
        assert!(score_of(1) > score_of(2));
        assert!(score_of(1) > score_of(3));
    }

    #[test]
    fn single_list_members_are_kept() {
        let results = rank_rrf(&[(1, 0.1)], &[(2, 5.0)]);
        assert_eq!(results.len(), 2);
        assert!(results.iter().any(|r| r.document_id == 1));
        assert!(results.iter().any(|r| r.document_id == 2));
    }

    #[test]
    fn scores_normalize_into_unit_interval() {
        let results = rank_rrf(&[(1, 0.1), (2, 0.2)], &[(1, 5.0), (3, 3.0)]);
        assert!((results[0].normalized_score - 1.0).abs() < f64::EPSILON);
        for r in &results {
            assert!((0.0..=1.0).contains(&r.normalized_score));
        }
    }

    #[test]
    fn both_empty_yields_nothing() {
        assert!(rank_rrf(&[], &[]).is_empty());
    }

    #[test]
    fn ranks_start_at_one() {
        let results = rank_rrf(&[(10, 0.1), (20, 0.2)], &[(10, 5.0), (30, 3.0)]);
        let by_id = |id: i64| results.iter().find(|r| r.document_id == id).unwrap();
        assert_eq!(by_id(10).vector_rank, Some(1));
        assert_eq!(by_id(10).fts_rank, Some(1));
        assert_eq!(by_id(20).vector_rank, Some(2));
        assert_eq!(by_id(20).fts_rank, None);
        assert_eq!(by_id(30).vector_rank, None);
        assert_eq!(by_id(30).fts_rank, Some(2));
    }

    #[test]
    fn raw_score_matches_formula() {
        let results = rank_rrf(&[(1, 0.1)], &[(1, 5.0)]);
        assert_eq!(results.len(), 1);
        // 1/(60+1) from each list.
        assert!((results[0].rrf_score - 2.0 / 61.0).abs() < 1e-10);
        assert!((results[0].normalized_score - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn single_source_still_normalizes_to_one() {
        let results = rank_rrf(&[(1, 0.1), (2, 0.2)], &[]);
        assert_eq!(results.len(), 2);
        assert!((results[0].normalized_score - 1.0).abs() < f64::EPSILON);
    }
}
|
||||
139
src/search/vector.rs
Normal file
139
src/search/vector.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use rusqlite::Connection;
|
||||
|
||||
use crate::core::error::Result;
|
||||
use crate::embedding::chunk_ids::decode_rowid;
|
||||
|
||||
/// One document-level vector search hit (chunks already deduplicated).
#[derive(Debug)]
pub struct VectorResult {
    pub document_id: i64,
    /// Best (smallest) distance among the document's chunks.
    pub distance: f64,
}
|
||||
|
||||
/// Search documents using sqlite-vec KNN query.
|
||||
///
|
||||
/// Over-fetches 3x limit to handle chunk deduplication (multiple chunks per
|
||||
/// document produce multiple KNN results for the same document_id).
|
||||
/// Returns deduplicated results with best (lowest) distance per document.
|
||||
pub fn search_vector(
|
||||
conn: &Connection,
|
||||
query_embedding: &[f32],
|
||||
limit: usize,
|
||||
) -> Result<Vec<VectorResult>> {
|
||||
if query_embedding.is_empty() || limit == 0 {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Convert to raw little-endian bytes for sqlite-vec
|
||||
let embedding_bytes: Vec<u8> = query_embedding
|
||||
.iter()
|
||||
.flat_map(|f| f.to_le_bytes())
|
||||
.collect();
|
||||
|
||||
let k = limit * 3; // Over-fetch for dedup
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT rowid, distance
|
||||
FROM embeddings
|
||||
WHERE embedding MATCH ?1
|
||||
AND k = ?2
|
||||
ORDER BY distance"
|
||||
)?;
|
||||
|
||||
let rows: Vec<(i64, f64)> = stmt
|
||||
.query_map(rusqlite::params![embedding_bytes, k as i64], |row| {
|
||||
Ok((row.get(0)?, row.get(1)?))
|
||||
})?
|
||||
.collect::<std::result::Result<Vec<_>, _>>()?;
|
||||
|
||||
// Dedup by document_id, keeping best (lowest) distance
|
||||
let mut best: HashMap<i64, f64> = HashMap::new();
|
||||
for (rowid, distance) in rows {
|
||||
let (document_id, _chunk_index) = decode_rowid(rowid);
|
||||
best.entry(document_id)
|
||||
.and_modify(|d| {
|
||||
if distance < *d {
|
||||
*d = distance;
|
||||
}
|
||||
})
|
||||
.or_insert(distance);
|
||||
}
|
||||
|
||||
// Sort by distance ascending, take limit
|
||||
let mut results: Vec<VectorResult> = best
|
||||
.into_iter()
|
||||
.map(|(document_id, distance)| VectorResult {
|
||||
document_id,
|
||||
distance,
|
||||
})
|
||||
.collect();
|
||||
results.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap_or(std::cmp::Ordering::Equal));
|
||||
results.truncate(limit);
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Full KNN coverage needs sqlite-vec loaded (done by create_connection
    // in db.rs); these tests only exercise the dedup/sort/truncate stage.

    /// Mirror of the dedup logic, callable without sqlite-vec.
    fn dedup(rows: Vec<(i64, f64)>, limit: usize) -> Vec<VectorResult> {
        let mut best: HashMap<i64, f64> = HashMap::new();
        for (rowid, distance) in rows {
            let (document_id, _) = decode_rowid(rowid);
            let slot = best.entry(document_id).or_insert(f64::INFINITY);
            if distance < *slot {
                *slot = distance;
            }
        }
        let mut out: Vec<VectorResult> = best
            .into_iter()
            .map(|(document_id, distance)| VectorResult { document_id, distance })
            .collect();
        out.sort_by(|a, b| {
            a.distance
                .partial_cmp(&b.distance)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        out.truncate(limit);
        out
    }

    #[test]
    fn no_rows_means_no_results() {
        assert!(dedup(vec![], 10).is_empty());
    }

    #[test]
    fn best_chunk_distance_wins_per_document() {
        let rows = vec![
            (1000_i64, 0.5_f64), // doc 1, chunk 0
            (1001, 0.3),         // doc 1, chunk 1 (closer)
            (2000, 0.4),         // doc 2, chunk 0
        ];
        let out = dedup(rows, 10);
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].document_id, 1); // doc 1's best = 0.3
        assert!((out[0].distance - 0.3).abs() < f64::EPSILON);
        assert_eq!(out[1].document_id, 2); // doc 2 = 0.4
    }

    #[test]
    fn limit_is_enforced() {
        let rows = vec![(1000_i64, 0.1_f64), (2000, 0.2), (3000, 0.3)];
        assert_eq!(dedup(rows, 2).len(), 2);
    }
}
|
||||
Reference in New Issue
Block a user