- Add .github/workflows/roam.yml for automated codebase indexing - Add .roam/fitness.yaml with architectural fitness rules - Reformat show.rs and vector.rs (rustfmt line-wrapping only, no logic) - Sync beads issue tracker Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
207 lines
5.8 KiB
Rust
207 lines
5.8 KiB
Rust
use std::collections::HashMap;
|
|
|
|
use rusqlite::Connection;
|
|
use rusqlite::OptionalExtension;
|
|
|
|
use crate::core::error::Result;
|
|
use crate::embedding::chunk_ids::decode_rowid;
|
|
|
|
/// One document-level hit produced by [`search_vector`].
///
/// Several chunks of the same document may match a query; only the smallest
/// chunk distance is kept, so each document appears at most once.
#[derive(Debug, Clone, PartialEq)]
pub struct VectorResult {
    /// Document identifier, decoded from the matching chunk's rowid.
    pub document_id: i64,
    /// Smallest distance among this document's matching chunks (results are
    /// sorted ascending by this value).
    pub distance: f64,
}
|
|
|
|
/// Determine how many chunks the most-chunked document has.
///
/// The `chunk_count` recorded on chunk 0 of each document is preferred. When
/// no such value is stored, the error-free (`last_error IS NULL`) rows of
/// `embedding_metadata` are counted per `document_id` and the maximum is
/// taken, falling back to `1` when nothing is found.
///
/// # Errors
/// Returns an error if either SQL query fails.
fn max_chunks_per_document(conn: &Connection) -> Result<i64> {
    // An aggregate without GROUP BY still produces one row (holding NULL when
    // nothing matched), hence the nested Option that is flattened below.
    let recorded = conn
        .query_row(
            "SELECT MAX(chunk_count) FROM embedding_metadata
             WHERE chunk_index = 0 AND chunk_count IS NOT NULL",
            [],
            |row| row.get::<_, Option<i64>>(0),
        )
        .optional()?;

    match recorded.flatten() {
        Some(max) => Ok(max),
        None => {
            let counted: Option<i64> = conn
                .query_row(
                    "SELECT COALESCE(MAX(cnt), 1) FROM (
                        SELECT COUNT(*) as cnt FROM embedding_metadata
                        WHERE last_error IS NULL GROUP BY document_id
                    )",
                    [],
                    |row| row.get(0),
                )
                .optional()?;
            Ok(counted.unwrap_or(1))
        }
    }
}
|
|
|
|
/// sqlite-vec hard limit for KNN `k` parameter.
const SQLITE_VEC_KNN_MAX: usize = 4_096;

/// Compute the KNN `k` value from the requested limit and the max chunks per
/// document.
///
/// The KNN query returns chunks, not documents, so it over-fetches by a
/// multiplier of `max_chunks * 3 / 2 + 1`, clamped to `[8, 200]`. The sign of
/// `max_chunks_per_doc` is ignored and zero is treated as one.
///
/// All arithmetic saturates, so the result never exceeds
/// [`SQLITE_VEC_KNN_MAX`] and never overflows, even for extreme inputs such as
/// `limit == usize::MAX` or `max_chunks_per_doc == i64::MIN`.
fn compute_knn_k(limit: usize, max_chunks_per_doc: i64) -> usize {
    // Saturate instead of truncating on targets where usize is narrower than u64.
    let max_chunks = usize::try_from(max_chunks_per_doc.unsigned_abs())
        .unwrap_or(usize::MAX)
        .max(1);
    let multiplier = (max_chunks.saturating_mul(3) / 2)
        .saturating_add(1)
        .clamp(8, 200);
    // Plain multiplication could overflow (panic in debug, wrap to a tiny k in
    // release) for very large limits.
    limit.saturating_mul(multiplier).min(SQLITE_VEC_KNN_MAX)
}
|
|
|
|
pub fn search_vector(
|
|
conn: &Connection,
|
|
query_embedding: &[f32],
|
|
limit: usize,
|
|
) -> Result<Vec<VectorResult>> {
|
|
if query_embedding.is_empty() || limit == 0 {
|
|
return Ok(Vec::new());
|
|
}
|
|
|
|
let embedding_bytes: Vec<u8> = query_embedding
|
|
.iter()
|
|
.flat_map(|f| f.to_le_bytes())
|
|
.collect();
|
|
|
|
let max_chunks = max_chunks_per_document(conn)?.max(1);
|
|
let k = compute_knn_k(limit, max_chunks);
|
|
|
|
let mut stmt = conn.prepare(
|
|
"SELECT rowid, distance
|
|
FROM embeddings
|
|
WHERE embedding MATCH ?1
|
|
AND k = ?2
|
|
ORDER BY distance",
|
|
)?;
|
|
|
|
let rows: Vec<(i64, f64)> = stmt
|
|
.query_map(rusqlite::params![embedding_bytes, k as i64], |row| {
|
|
Ok((row.get(0)?, row.get(1)?))
|
|
})?
|
|
.collect::<std::result::Result<Vec<_>, _>>()?;
|
|
|
|
let mut best: HashMap<i64, f64> = HashMap::new();
|
|
for (rowid, distance) in rows {
|
|
let (document_id, _chunk_index) = decode_rowid(rowid);
|
|
best.entry(document_id)
|
|
.and_modify(|d| {
|
|
if distance < *d {
|
|
*d = distance;
|
|
}
|
|
})
|
|
.or_insert(distance);
|
|
}
|
|
|
|
let mut results: Vec<VectorResult> = best
|
|
.into_iter()
|
|
.map(|(document_id, distance)| VectorResult {
|
|
document_id,
|
|
distance,
|
|
})
|
|
.collect();
|
|
results.sort_by(|a, b| a.distance.total_cmp(&b.distance));
|
|
results.truncate(limit);
|
|
|
|
Ok(results)
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only stand-in for the per-document dedup step of `search_vector`:
    /// keep the minimum distance per decoded document id, sort ascending by
    /// distance, then cut to `limit`.
    ///
    /// NOTE(review): this duplicates production logic rather than calling it,
    /// so the dedup tests below do not exercise `search_vector` itself.
    fn search_vector_dedup(rows: Vec<(i64, f64)>, limit: usize) -> Vec<VectorResult> {
        let mut closest: HashMap<i64, f64> = HashMap::new();
        for (rowid, distance) in rows {
            let (doc, _) = decode_rowid(rowid);
            let slot = closest.entry(doc).or_insert(distance);
            if distance < *slot {
                *slot = distance;
            }
        }

        let mut ranked: Vec<VectorResult> = Vec::with_capacity(closest.len());
        for (document_id, distance) in closest {
            ranked.push(VectorResult {
                document_id,
                distance,
            });
        }
        ranked.sort_by(|a, b| a.distance.total_cmp(&b.distance));
        ranked.truncate(limit);
        ranked
    }

    #[test]
    fn test_empty_returns_empty() {
        assert!(search_vector_dedup(Vec::new(), 10).is_empty());
    }

    #[test]
    fn test_dedup_keeps_best_distance() {
        let results =
            search_vector_dedup(vec![(1000_i64, 0.5_f64), (1001, 0.3), (2000, 0.4)], 10);
        assert_eq!(results.len(), 2);
        let (first, second) = (&results[0], &results[1]);
        assert_eq!(first.document_id, 1);
        assert!((first.distance - 0.3).abs() < f64::EPSILON);
        assert_eq!(second.document_id, 2);
    }

    #[test]
    fn test_dedup_respects_limit() {
        let rows = vec![(1000_i64, 0.1_f64), (2000, 0.2), (3000, 0.3)];
        assert_eq!(search_vector_dedup(rows, 2).len(), 2);
    }

    #[test]
    fn test_knn_k_never_exceeds_sqlite_vec_limit() {
        const LIMITS: [usize; 9] = [1, 10, 50, 100, 500, 1000, 1500, 2000, 5000];
        const CHUNK_COUNTS: [i64; 9] = [1, 2, 5, 10, 50, 100, 200, 500, 1000];
        for &limit in &LIMITS {
            for &max_chunks in &CHUNK_COUNTS {
                let k = compute_knn_k(limit, max_chunks);
                assert!(
                    k <= SQLITE_VEC_KNN_MAX,
                    "k={k} exceeded limit for limit={limit}, max_chunks={max_chunks}"
                );
            }
        }
    }

    #[test]
    fn test_knn_k_reproduces_original_bug_scenario() {
        let k = compute_knn_k(1500, 1);
        assert!(
            k <= SQLITE_VEC_KNN_MAX,
            "k={k} exceeded 4096 at RECALL_CAP with 1 chunk"
        );
    }

    #[test]
    fn test_knn_k_small_limit_uses_minimum_multiplier() {
        assert_eq!(compute_knn_k(10, 1), 80);
    }

    #[test]
    fn test_knn_k_high_chunks_caps_multiplier() {
        assert_eq!(compute_knn_k(10, 200), 2000);
    }

    #[test]
    fn test_knn_k_zero_max_chunks_treated_as_one() {
        assert_eq!(compute_knn_k(10, 0), 80);
    }

    #[test]
    fn test_knn_k_negative_max_chunks_uses_absolute() {
        let from_negative = compute_knn_k(10, -5);
        let from_positive = compute_knn_k(10, 5);
        assert_eq!(from_negative, from_positive);
    }
}
|