Change OllamaClient::embed_batch to accept &[&str] instead of Vec<String>. The EmbedRequest struct now borrows both model name and input texts, eliminating per-batch cloning of chunk text (up to 32KB per chunk x 32 chunks per batch). Serialization output is identical since serde serializes &str and String to the same JSON. In hybrid search, defer the RrfResult->HybridResult mapping until after filter+take, so only `limit` items (typically 20) are constructed instead of up to 1,500 at RECALL_CAP. Also switch filtered_ids to into_iter() to avoid an extra .copied() pass. Switch FTS search_fts from prepare() to prepare_cached() for statement reuse across repeated searches. Benchmarked at ~1.6x faster. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
237 lines
7.7 KiB
Rust
237 lines
7.7 KiB
Rust
use rusqlite::Connection;
|
|
|
|
use crate::core::error::Result;
|
|
use crate::embedding::ollama::OllamaClient;
|
|
use crate::search::filters::{SearchFilters, apply_filters};
|
|
use crate::search::rrf::RrfResult;
|
|
use crate::search::{FtsQueryMode, rank_rrf, search_fts, search_vector};
|
|
|
|
const BASE_RECALL_MIN: usize = 50;
|
|
const FILTERED_RECALL_MIN: usize = 200;
|
|
const RECALL_CAP: usize = 1500;
|
|
|
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
pub enum SearchMode {
|
|
Hybrid,
|
|
Lexical,
|
|
Semantic,
|
|
}
|
|
|
|
impl SearchMode {
|
|
pub fn parse(s: &str) -> Option<Self> {
|
|
match s.to_lowercase().as_str() {
|
|
"hybrid" => Some(Self::Hybrid),
|
|
"lexical" | "fts" => Some(Self::Lexical),
|
|
"semantic" | "vector" => Some(Self::Semantic),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
pub fn as_str(&self) -> &'static str {
|
|
match self {
|
|
Self::Hybrid => "hybrid",
|
|
Self::Lexical => "lexical",
|
|
Self::Semantic => "semantic",
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct HybridResult {
|
|
pub document_id: i64,
|
|
pub score: f64,
|
|
pub vector_rank: Option<usize>,
|
|
pub fts_rank: Option<usize>,
|
|
pub rrf_score: f64,
|
|
}
|
|
|
|
pub async fn search_hybrid(
|
|
conn: &Connection,
|
|
client: Option<&OllamaClient>,
|
|
query: &str,
|
|
mode: SearchMode,
|
|
filters: &SearchFilters,
|
|
fts_mode: FtsQueryMode,
|
|
) -> Result<(Vec<HybridResult>, Vec<String>)> {
|
|
let mut warnings: Vec<String> = Vec::new();
|
|
|
|
let requested = filters.clamp_limit();
|
|
let top_k = if filters.has_any_filter() {
|
|
(requested * 50).clamp(FILTERED_RECALL_MIN, RECALL_CAP)
|
|
} else {
|
|
(requested * 10).clamp(BASE_RECALL_MIN, RECALL_CAP)
|
|
};
|
|
|
|
let (fts_tuples, vec_tuples) = match mode {
|
|
SearchMode::Lexical => {
|
|
let fts_results = search_fts(conn, query, top_k, fts_mode)?;
|
|
let fts_tuples: Vec<(i64, f64)> = fts_results
|
|
.iter()
|
|
.map(|r| (r.document_id, r.bm25_score))
|
|
.collect();
|
|
(fts_tuples, Vec::new())
|
|
}
|
|
|
|
SearchMode::Semantic => {
|
|
let Some(client) = client else {
|
|
return Err(crate::core::error::LoreError::Other(
|
|
"Semantic search requires Ollama. Start Ollama or use --mode=lexical.".into(),
|
|
));
|
|
};
|
|
|
|
let query_embedding = client.embed_batch(&[query]).await?;
|
|
let embedding = query_embedding.into_iter().next().unwrap_or_default();
|
|
|
|
if embedding.is_empty() {
|
|
return Err(crate::core::error::LoreError::Other(
|
|
"Ollama returned empty embedding for query.".into(),
|
|
));
|
|
}
|
|
|
|
let vec_results = search_vector(conn, &embedding, top_k)?;
|
|
let vec_tuples: Vec<(i64, f64)> = vec_results
|
|
.iter()
|
|
.map(|r| (r.document_id, r.distance))
|
|
.collect();
|
|
(Vec::new(), vec_tuples)
|
|
}
|
|
|
|
SearchMode::Hybrid => {
|
|
let fts_results = search_fts(conn, query, top_k, fts_mode)?;
|
|
let fts_tuples: Vec<(i64, f64)> = fts_results
|
|
.iter()
|
|
.map(|r| (r.document_id, r.bm25_score))
|
|
.collect();
|
|
|
|
match client {
|
|
Some(client) => match client.embed_batch(&[query]).await {
|
|
Ok(query_embedding) => {
|
|
let embedding = query_embedding.into_iter().next().unwrap_or_default();
|
|
|
|
let vec_tuples = if embedding.is_empty() {
|
|
warnings
|
|
.push("Ollama returned empty embedding, using FTS only.".into());
|
|
Vec::new()
|
|
} else {
|
|
let vec_results = search_vector(conn, &embedding, top_k)?;
|
|
vec_results
|
|
.iter()
|
|
.map(|r| (r.document_id, r.distance))
|
|
.collect()
|
|
};
|
|
|
|
(fts_tuples, vec_tuples)
|
|
}
|
|
Err(e) => {
|
|
warnings.push(format!(
|
|
"Embedding failed ({}), falling back to lexical search.",
|
|
e
|
|
));
|
|
(fts_tuples, Vec::new())
|
|
}
|
|
},
|
|
None => {
|
|
warnings.push("Ollama unavailable, falling back to lexical search.".into());
|
|
(fts_tuples, Vec::new())
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
let ranked = rank_rrf(&vec_tuples, &fts_tuples);
|
|
let limit = filters.clamp_limit();
|
|
|
|
let to_hybrid = |r: RrfResult| HybridResult {
|
|
document_id: r.document_id,
|
|
score: r.normalized_score,
|
|
vector_rank: r.vector_rank,
|
|
fts_rank: r.fts_rank,
|
|
rrf_score: r.rrf_score,
|
|
};
|
|
|
|
let results: Vec<HybridResult> = if filters.has_any_filter() {
|
|
let all_ids: Vec<i64> = ranked.iter().map(|r| r.document_id).collect();
|
|
let filtered_ids = apply_filters(conn, &all_ids, filters)?;
|
|
let filtered_set: std::collections::HashSet<i64> = filtered_ids.into_iter().collect();
|
|
ranked
|
|
.into_iter()
|
|
.filter(|r| filtered_set.contains(&r.document_id))
|
|
.take(limit)
|
|
.map(to_hybrid)
|
|
.collect()
|
|
} else {
|
|
ranked.into_iter().take(limit).map(to_hybrid).collect()
|
|
};
|
|
|
|
Ok((results, warnings))
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_search_mode_from_str() {
|
|
assert_eq!(SearchMode::parse("hybrid"), Some(SearchMode::Hybrid));
|
|
assert_eq!(SearchMode::parse("lexical"), Some(SearchMode::Lexical));
|
|
assert_eq!(SearchMode::parse("fts"), Some(SearchMode::Lexical));
|
|
assert_eq!(SearchMode::parse("semantic"), Some(SearchMode::Semantic));
|
|
assert_eq!(SearchMode::parse("vector"), Some(SearchMode::Semantic));
|
|
assert_eq!(SearchMode::parse("HYBRID"), Some(SearchMode::Hybrid));
|
|
assert_eq!(SearchMode::parse("invalid"), None);
|
|
assert_eq!(SearchMode::parse(""), None);
|
|
}
|
|
|
|
#[test]
|
|
fn test_search_mode_as_str() {
|
|
assert_eq!(SearchMode::Hybrid.as_str(), "hybrid");
|
|
assert_eq!(SearchMode::Lexical.as_str(), "lexical");
|
|
assert_eq!(SearchMode::Semantic.as_str(), "semantic");
|
|
}
|
|
|
|
#[test]
|
|
fn test_adaptive_recall_unfiltered() {
|
|
let filters = SearchFilters {
|
|
limit: 20,
|
|
..Default::default()
|
|
};
|
|
let requested = filters.clamp_limit();
|
|
let top_k = (requested * 10).clamp(BASE_RECALL_MIN, RECALL_CAP);
|
|
assert_eq!(top_k, 200);
|
|
}
|
|
|
|
#[test]
|
|
fn test_adaptive_recall_filtered() {
|
|
let filters = SearchFilters {
|
|
limit: 20,
|
|
author: Some("alice".to_string()),
|
|
..Default::default()
|
|
};
|
|
let requested = filters.clamp_limit();
|
|
let top_k = (requested * 50).clamp(FILTERED_RECALL_MIN, RECALL_CAP);
|
|
assert_eq!(top_k, 1000);
|
|
}
|
|
|
|
#[test]
|
|
fn test_adaptive_recall_cap() {
|
|
let filters = SearchFilters {
|
|
limit: 100,
|
|
author: Some("alice".to_string()),
|
|
..Default::default()
|
|
};
|
|
let requested = filters.clamp_limit();
|
|
let top_k = (requested * 50).clamp(FILTERED_RECALL_MIN, RECALL_CAP);
|
|
assert_eq!(top_k, RECALL_CAP);
|
|
}
|
|
|
|
#[test]
|
|
fn test_adaptive_recall_minimum() {
|
|
let filters = SearchFilters {
|
|
limit: 1,
|
|
..Default::default()
|
|
};
|
|
let requested = filters.clamp_limit();
|
|
let top_k = (requested * 10).clamp(BASE_RECALL_MIN, RECALL_CAP);
|
|
assert_eq!(top_k, BASE_RECALL_MIN);
|
|
}
|
|
}
|