Files
gitlore/src/search/vector.rs
Taylor Eernisse 67e2498689 chore: add roam CI workflow, fitness config, formatting, beads sync
- Add .github/workflows/roam.yml for automated codebase indexing
- Add .roam/fitness.yaml with architectural fitness rules
- Reformat show.rs and vector.rs (rustfmt line-wrapping only, no logic)
- Sync beads issue tracker

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 12:38:12 -05:00

207 lines
5.8 KiB
Rust

use std::collections::HashMap;
use rusqlite::Connection;
use rusqlite::OptionalExtension;
use crate::core::error::Result;
use crate::embedding::chunk_ids::decode_rowid;
/// One document-level hit produced by [`search_vector`].
///
/// Both fields are plain `Copy` data, so the type derives `Clone`, `Copy` and
/// `PartialEq` in addition to `Debug`, letting callers copy and compare hits.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VectorResult {
    /// Document that the matching chunk(s) belong to.
    pub document_id: i64,
    /// Smallest distance seen among this document's returned chunks
    /// (results are ranked ascending, so lower means closer).
    pub distance: f64,
}
/// Largest number of embedding chunks that any single document has.
///
/// The function first reads the `chunk_count` recorded on each document's
/// first chunk (`chunk_index = 0`) and returns the maximum of those values.
/// If no such value is stored, it falls back to counting
/// `embedding_metadata` rows with `last_error IS NULL` per document. That
/// fallback yields 1 when there is nothing to count.
///
/// # Errors
/// Propagates any SQLite error raised while running either query.
fn max_chunks_per_document(conn: &Connection) -> Result<i64> {
    const RECORDED_MAX_SQL: &str = concat!(
        "SELECT MAX(chunk_count) FROM embedding_metadata\n",
        "WHERE chunk_index = 0 AND chunk_count IS NOT NULL"
    );
    const COUNTED_MAX_SQL: &str = concat!(
        "SELECT COALESCE(MAX(cnt), 1) FROM (\n",
        "SELECT COUNT(*) as cnt FROM embedding_metadata\n",
        "WHERE last_error IS NULL GROUP BY document_id\n",
        ")"
    );

    // The aggregate column can be SQL NULL, so the row value is itself an
    // Option; flatten it together with the "no row" case.
    let recorded = conn
        .query_row(RECORDED_MAX_SQL, [], |row| row.get::<_, Option<i64>>(0))
        .optional()?
        .flatten();

    match recorded {
        Some(max) => Ok(max),
        None => {
            let counted = conn
                .query_row(COUNTED_MAX_SQL, [], |row| row.get::<_, i64>(0))
                .optional()?;
            Ok(counted.unwrap_or(1))
        }
    }
}
/// sqlite-vec hard limit for KNN `k` parameter.
const SQLITE_VEC_KNN_MAX: usize = 4_096;

/// Compute the KNN `k` value from the requested `limit` and the maximum
/// number of chunks per document.
///
/// The function over-fetches chunk-level hits by a per-document multiplier
/// of roughly 1.5 × `max_chunks_per_doc` + 1, clamped to `8..=200`. A value
/// of zero is treated as 1, and negative values use their absolute value.
///
/// All arithmetic saturates, so extreme inputs (e.g. `limit == usize::MAX`
/// or `max_chunks_per_doc == i64::MIN`) cannot overflow or wrap. The result
/// is guaranteed to never exceed [`SQLITE_VEC_KNN_MAX`].
fn compute_knn_k(limit: usize, max_chunks_per_doc: i64) -> usize {
    let max_chunks = max_chunks_per_doc.unsigned_abs().max(1);
    // Work in u64 and clamp before narrowing. The multiplier is always in
    // 8..=200, so the `as usize` cast is lossless on every target.
    let multiplier = (max_chunks.saturating_mul(3) / 2)
        .saturating_add(1)
        .clamp(8, 200) as usize;
    limit.saturating_mul(multiplier).min(SQLITE_VEC_KNN_MAX)
}
/// Find the documents whose embedded chunks lie nearest to `query_embedding`.
///
/// The function runs a sqlite-vec KNN query against the `embeddings` table.
/// Because several chunks of one document can occupy the top results, it
/// over-fetches chunk-level hits (`k` comes from `compute_knn_k`). It then
/// collapses those hits to one [`VectorResult`] per document, carrying that
/// document's smallest distance.
///
/// Returns at most `limit` results in ascending distance order. Equal
/// distances are ordered by `document_id`, so output is deterministic. An
/// empty query or `limit == 0` returns an empty vector without touching the
/// database.
///
/// # Errors
/// Propagates errors from reading chunk metadata, preparing the statement,
/// or reading result rows.
pub fn search_vector(
    conn: &Connection,
    query_embedding: &[f32],
    limit: usize,
) -> Result<Vec<VectorResult>> {
    if query_embedding.is_empty() || limit == 0 {
        return Ok(Vec::new());
    }

    // The MATCH operand is the query vector serialized as raw little-endian
    // f32 bytes.
    let embedding_bytes: Vec<u8> = query_embedding
        .iter()
        .flat_map(|f| f.to_le_bytes())
        .collect();

    let max_chunks = max_chunks_per_document(conn)?.max(1);
    let k = compute_knn_k(limit, max_chunks);

    let mut stmt = conn.prepare(concat!(
        "SELECT rowid, distance\n",
        "FROM embeddings\n",
        "WHERE embedding MATCH ?1\n",
        "AND k = ?2\n",
        "ORDER BY distance"
    ))?;
    // `k` never exceeds SQLITE_VEC_KNN_MAX, so the i64 conversion is lossless.
    let rows: Vec<(i64, f64)> = stmt
        .query_map(rusqlite::params![embedding_bytes, k as i64], |row| {
            Ok((row.get(0)?, row.get(1)?))
        })?
        .collect::<std::result::Result<Vec<_>, _>>()?;

    Ok(best_per_document(rows, limit))
}

/// Collapse chunk-level `(rowid, distance)` hits into one result per document.
///
/// Each rowid is mapped to its document with `decode_rowid`, and each document
/// keeps the smallest distance among its chunks. The results are sorted by
/// ascending distance (`f64::total_cmp`). Equal distances are ordered by
/// `document_id`, so the output does not depend on `HashMap` iteration order.
/// Finally the list is truncated to `limit`.
fn best_per_document(rows: Vec<(i64, f64)>, limit: usize) -> Vec<VectorResult> {
    let mut best: HashMap<i64, f64> = HashMap::with_capacity(rows.len());
    for (rowid, distance) in rows {
        let (document_id, _chunk_index) = decode_rowid(rowid);
        best.entry(document_id)
            .and_modify(|d| {
                if distance < *d {
                    *d = distance;
                }
            })
            .or_insert(distance);
    }

    let mut results: Vec<VectorResult> = best
        .into_iter()
        .map(|(document_id, distance)| VectorResult {
            document_id,
            distance,
        })
        .collect();
    results.sort_by(|a, b| {
        a.distance
            .total_cmp(&b.distance)
            .then_with(|| a.document_id.cmp(&b.document_id))
    });
    results.truncate(limit);
    results
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collapses raw `(rowid, distance)` pairs into one [`VectorResult`] per
    /// document, keeping the smallest distance. Sorts the results ascending by
    /// distance and keeps at most `limit` of them. No database is required.
    fn search_vector_dedup(rows: Vec<(i64, f64)>, limit: usize) -> Vec<VectorResult> {
        let mut best_by_doc: HashMap<i64, f64> = HashMap::new();
        for (rowid, distance) in rows {
            let (doc, _) = decode_rowid(rowid);
            // A fresh slot holds `distance` already, so the comparison below
            // only ever lowers an existing entry.
            let slot = best_by_doc.entry(doc).or_insert(distance);
            if distance < *slot {
                *slot = distance;
            }
        }

        let mut ranked: Vec<VectorResult> = Vec::with_capacity(best_by_doc.len());
        ranked.extend(
            best_by_doc
                .into_iter()
                .map(|(document_id, distance)| VectorResult {
                    document_id,
                    distance,
                }),
        );
        ranked.sort_by(|lhs, rhs| lhs.distance.total_cmp(&rhs.distance));
        ranked.truncate(limit);
        ranked
    }

    #[test]
    fn test_empty_returns_empty() {
        assert!(search_vector_dedup(Vec::new(), 10).is_empty());
    }

    #[test]
    fn test_dedup_keeps_best_distance() {
        let hits = vec![(1000_i64, 0.5_f64), (1001, 0.3), (2000, 0.4)];
        let ranked = search_vector_dedup(hits, 10);
        assert_eq!(ranked.len(), 2);

        let top = &ranked[0];
        assert_eq!(top.document_id, 1);
        assert!((top.distance - 0.3).abs() < f64::EPSILON);

        assert_eq!(ranked[1].document_id, 2);
    }

    #[test]
    fn test_dedup_respects_limit() {
        let hits = vec![(1000_i64, 0.1_f64), (2000, 0.2), (3000, 0.3)];
        assert_eq!(search_vector_dedup(hits, 2).len(), 2);
    }

    #[test]
    fn test_knn_k_never_exceeds_sqlite_vec_limit() {
        const LIMITS: [usize; 9] = [1, 10, 50, 100, 500, 1000, 1500, 2000, 5000];
        const MAX_CHUNKS: [i64; 9] = [1, 2, 5, 10, 50, 100, 200, 500, 1000];
        for &limit in LIMITS.iter() {
            for &max_chunks in MAX_CHUNKS.iter() {
                let k = compute_knn_k(limit, max_chunks);
                assert!(
                    k <= SQLITE_VEC_KNN_MAX,
                    "k={k} exceeded limit for limit={limit}, max_chunks={max_chunks}"
                );
            }
        }
    }

    #[test]
    fn test_knn_k_reproduces_original_bug_scenario() {
        let k = compute_knn_k(1500, 1);
        assert!(
            k <= SQLITE_VEC_KNN_MAX,
            "k={k} exceeded 4096 at RECALL_CAP with 1 chunk"
        );
    }

    #[test]
    fn test_knn_k_small_limit_uses_minimum_multiplier() {
        assert_eq!(compute_knn_k(10, 1), 80);
    }

    #[test]
    fn test_knn_k_high_chunks_caps_multiplier() {
        assert_eq!(compute_knn_k(10, 200), 2000);
    }

    #[test]
    fn test_knn_k_zero_max_chunks_treated_as_one() {
        assert_eq!(compute_knn_k(10, 0), 80);
    }

    #[test]
    fn test_knn_k_negative_max_chunks_uses_absolute() {
        let from_negative = compute_knn_k(10, -5);
        let from_positive = compute_knn_k(10, 5);
        assert_eq!(from_negative, from_positive);
    }
}