use std::collections::HashMap;

/// Smoothing constant `k` used in the reciprocal-rank term `1 / (k + rank)`.
const RRF_K: f64 = 60.0;

/// One fused entry produced by [`rank_rrf`].
#[derive(Debug, Clone, PartialEq)]
pub struct RrfResult {
    /// Identifier of the document, as supplied in the input lists.
    pub document_id: i64,
    /// Sum of `1 / (RRF_K + rank)` over every input list containing the document.
    pub rrf_score: f64,
    /// `rrf_score` divided by the highest `rrf_score` in the result set,
    /// so the top entry is `1.0`.
    pub normalized_score: f64,
    /// 1-indexed position of the document's first occurrence in `vector_results`,
    /// or `None` if it does not appear there.
    pub vector_rank: Option<usize>,
    /// 1-indexed position of the document's first occurrence in `fts_results`,
    /// or `None` if it does not appear there.
    pub fts_rank: Option<usize>,
}

/// Contribution of a single 1-indexed rank to the fused score.
fn reciprocal_rank(rank: usize) -> f64 {
    1.0 / (RRF_K + rank as f64)
}

/// Fuses two ranked result lists with Reciprocal Rank Fusion.
///
/// Each slice holds `(document_id, score)` pairs. Only the *position* of a pair
/// within its slice is used (position 0 is rank 1); the `f64` score component is
/// ignored. If a document id occurs more than once in the same slice, only its
/// first (best-ranked) occurrence contributes.
///
/// # Returns
///
/// One [`RrfResult`] per distinct document id, sorted by `rrf_score` descending.
/// Documents with equal scores are ordered by ascending `document_id`, so the
/// output is deterministic. Returns an empty vector when both inputs are empty.
pub fn rank_rrf(vector_results: &[(i64, f64)], fts_results: &[(i64, f64)]) -> Vec<RrfResult> {
    if vector_results.is_empty() && fts_results.is_empty() {
        return Vec::new();
    }

    // document_id -> (accumulated score, vector rank, fts rank)
    let mut scores: HashMap<i64, (f64, Option<usize>, Option<usize>)> =
        HashMap::with_capacity(vector_results.len() + fts_results.len());

    for (i, &(doc_id, _)) in vector_results.iter().enumerate() {
        let rank = i + 1;
        let entry = scores.entry(doc_id).or_insert((0.0, None, None));
        // Only the first occurrence in this list counts.
        if entry.1.is_none() {
            entry.0 += reciprocal_rank(rank);
            entry.1 = Some(rank);
        }
    }

    for (i, &(doc_id, _)) in fts_results.iter().enumerate() {
        let rank = i + 1;
        let entry = scores.entry(doc_id).or_insert((0.0, None, None));
        // Only the first occurrence in this list counts.
        if entry.2.is_none() {
            entry.0 += reciprocal_rank(rank);
            entry.2 = Some(rank);
        }
    }

    let mut results: Vec<RrfResult> = scores
        .into_iter()
        .map(|(doc_id, (rrf_score, vector_rank, fts_rank))| RrfResult {
            document_id: doc_id,
            rrf_score,
            normalized_score: 0.0,
            vector_rank,
            fts_rank,
        })
        .collect();

    // HashMap iteration order is arbitrary, so ties on score must be broken
    // explicitly; document ids are unique here, which makes the order total and
    // lets us use the unstable sort without losing determinism.
    results.sort_unstable_by(|a, b| {
        b.rrf_score
            .total_cmp(&a.rrf_score)
            .then_with(|| a.document_id.cmp(&b.document_id))
    });

    if let Some(max_score) = results.first().map(|r| r.rrf_score).filter(|&s| s > 0.0) {
        for result in &mut results {
            result.normalized_score = result.rrf_score / max_score;
        }
    }

    results
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dual_list_ranks_higher() {
        let vector = vec![(1, 0.1), (2, 0.2)];
        let fts = vec![(1, 5.0), (3, 3.0)];
        let results = rank_rrf(&vector, &fts);
        assert_eq!(results[0].document_id, 1);
        let doc1 = &results[0];
        let doc2_score = results
            .iter()
            .find(|r| r.document_id == 2)
            .unwrap()
            .rrf_score;
        let doc3_score = results
            .iter()
            .find(|r| r.document_id == 3)
            .unwrap()
            .rrf_score;
        assert!(doc1.rrf_score > doc2_score);
        assert!(doc1.rrf_score > doc3_score);
    }

    #[test]
    fn test_single_list_included() {
        let vector = vec![(1, 0.1)];
        let fts = vec![(2, 5.0)];
        let results = rank_rrf(&vector, &fts);
        assert_eq!(results.len(), 2);
        let doc_ids: Vec<i64> = results.iter().map(|r| r.document_id).collect();
        assert!(doc_ids.contains(&1));
        assert!(doc_ids.contains(&2));
    }

    #[test]
    fn test_normalization() {
        let vector = vec![(1, 0.1), (2, 0.2)];
        let fts = vec![(1, 5.0), (3, 3.0)];
        let results = rank_rrf(&vector, &fts);
        assert!((results[0].normalized_score - 1.0).abs() < f64::EPSILON);
        for r in &results {
            assert!(r.normalized_score >= 0.0);
            assert!(r.normalized_score <= 1.0);
        }
    }

    #[test]
    fn test_empty_inputs() {
        let results = rank_rrf(&[], &[]);
        assert!(results.is_empty());
    }

    #[test]
    fn test_ranks_are_1_indexed() {
        let vector = vec![(10, 0.1), (20, 0.2)];
        let fts = vec![(10, 5.0), (30, 3.0)];
        let results = rank_rrf(&vector, &fts);
        let doc10 = results.iter().find(|r| r.document_id == 10).unwrap();
        assert_eq!(doc10.vector_rank, Some(1));
        assert_eq!(doc10.fts_rank, Some(1));
        let doc20 = results.iter().find(|r| r.document_id == 20).unwrap();
        assert_eq!(doc20.vector_rank, Some(2));
        assert_eq!(doc20.fts_rank, None);
        let doc30 = results.iter().find(|r| r.document_id == 30).unwrap();
        assert_eq!(doc30.vector_rank, None);
        assert_eq!(doc30.fts_rank, Some(2));
    }

    #[test]
    fn test_raw_and_normalized_scores() {
        let vector = vec![(1, 0.1)];
        let fts = vec![(1, 5.0)];
        let results = rank_rrf(&vector, &fts);
        assert_eq!(results.len(), 1);
        let r = &results[0];
        let expected = 2.0 / 61.0;
        assert!((r.rrf_score - expected).abs() < 1e-10);
        assert!((r.normalized_score - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_one_empty_list() {
        let vector = vec![(1, 0.1), (2, 0.2)];
        let results = rank_rrf(&vector, &[]);
        assert_eq!(results.len(), 2);
        assert!((results[0].normalized_score - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_ties_ordered_by_document_id() {
        // Docs 2 and 3 each appear once at rank 2, so their scores tie exactly.
        let vector = vec![(1, 0.9), (3, 0.8)];
        let fts = vec![(1, 4.0), (2, 3.0)];
        for _ in 0..20 {
            let ids: Vec<i64> = rank_rrf(&vector, &fts)
                .iter()
                .map(|r| r.document_id)
                .collect();
            assert_eq!(ids, vec![1, 2, 3]);
        }
    }

    #[test]
    fn test_duplicate_in_one_list_counts_once() {
        let vector = vec![(7, 0.5), (7, 0.4)];
        let results = rank_rrf(&vector, &[]);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].vector_rank, Some(1));
        assert!((results[0].rrf_score - 1.0 / 61.0).abs() < 1e-12);
    }
}