use crate::core::error::Result; use rusqlite::Connection; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum FtsQueryMode { Safe, Raw, } #[derive(Debug)] pub struct FtsResult { pub document_id: i64, pub bm25_score: f64, pub snippet: String, } pub fn to_fts_query(raw: &str, mode: FtsQueryMode) -> String { match mode { FtsQueryMode::Raw => raw.to_string(), FtsQueryMode::Safe => { let trimmed = raw.trim(); if trimmed.is_empty() { return String::new(); } let mut result = String::with_capacity(trimmed.len() + 20); for (i, token) in trimmed.split_whitespace().enumerate() { if i > 0 { result.push(' '); } if let Some(stem) = token.strip_suffix('*') && !stem.is_empty() && stem.chars().all(|c| c.is_alphanumeric() || c == '_') { result.push('"'); result.push_str(&stem.replace('"', "\"\"")); result.push_str("\"*"); } else { result.push('"'); result.push_str(&token.replace('"', "\"\"")); result.push('"'); } } result } } } pub fn search_fts( conn: &Connection, query: &str, limit: usize, mode: FtsQueryMode, ) -> Result> { let fts_query = to_fts_query(query, mode); if fts_query.is_empty() { return Ok(Vec::new()); } let sql = r#" SELECT d.id, bm25(documents_fts) AS score, snippet(documents_fts, 1, '', '', '...', 64) AS snip FROM documents_fts JOIN documents d ON d.id = documents_fts.rowid WHERE documents_fts MATCH ?1 ORDER BY score LIMIT ?2 "#; let mut stmt = conn.prepare(sql)?; let results = stmt .query_map(rusqlite::params![fts_query, limit as i64], |row| { Ok(FtsResult { document_id: row.get(0)?, bm25_score: row.get(1)?, snippet: row.get(2)?, }) })? .collect::, _>>()?; Ok(results) } pub fn generate_fallback_snippet(content_text: &str, max_chars: usize) -> String { if content_text.chars().count() <= max_chars { return content_text.to_string(); } let byte_end = content_text .char_indices() .nth(max_chars) .map(|(i, _)| i) .unwrap_or(content_text.len()); let truncated = &content_text[..byte_end]; if let Some(last_space) = truncated.rfind(' ') { format!("{}...", &truncated[..last_space]) } else { format!("{}...", truncated) } } pub fn get_result_snippet(fts_snippet: Option<&str>, content_text: &str) -> String { match fts_snippet { Some(s) if !s.is_empty() => s.to_string(), _ => generate_fallback_snippet(content_text, 200), } } #[cfg(test)] mod tests { use super::*; #[test] fn test_safe_query_basic() { let result = to_fts_query("auth error", FtsQueryMode::Safe); assert_eq!(result, "\"auth\" \"error\""); } #[test] fn test_safe_query_prefix() { let result = to_fts_query("auth*", FtsQueryMode::Safe); assert_eq!(result, "\"auth\"*"); } #[test] fn test_safe_query_special_chars() { let result = to_fts_query("C++", FtsQueryMode::Safe); assert_eq!(result, "\"C++\""); } #[test] fn test_safe_query_dash() { let result = to_fts_query("-DWITH_SSL", FtsQueryMode::Safe); assert_eq!(result, "\"-DWITH_SSL\""); } #[test] fn test_safe_query_quotes() { let result = to_fts_query("he said \"hello\"", FtsQueryMode::Safe); assert_eq!(result, "\"he\" \"said\" \"\"\"hello\"\"\""); } #[test] fn test_raw_mode_passthrough() { let result = to_fts_query("auth OR error", FtsQueryMode::Raw); assert_eq!(result, "auth OR error"); } #[test] fn test_empty_query() { let result = to_fts_query("", FtsQueryMode::Safe); assert_eq!(result, ""); let result = to_fts_query(" ", FtsQueryMode::Safe); assert_eq!(result, ""); } #[test] fn test_prefix_only_alphanumeric() { let result = to_fts_query("C++*", FtsQueryMode::Safe); assert_eq!(result, "\"C++*\""); let result = to_fts_query("auth*", FtsQueryMode::Safe); assert_eq!(result, "\"auth\"*"); } #[test] fn test_prefix_with_underscore() { let result = to_fts_query("jwt_token*", FtsQueryMode::Safe); assert_eq!(result, "\"jwt_token\"*"); } #[test] fn test_fallback_snippet_short() { let result = generate_fallback_snippet("Short content", 200); assert_eq!(result, "Short content"); } #[test] fn test_fallback_snippet_word_boundary() { let content = "This is a moderately long piece of text that should be truncated at a word boundary for readability purposes"; let result = generate_fallback_snippet(content, 50); assert!(result.ends_with("...")); assert!(result.len() <= 55); } #[test] fn test_get_result_snippet_prefers_fts() { let result = get_result_snippet(Some("FTS match"), "full content text"); assert_eq!(result, "FTS match"); } #[test] fn test_get_result_snippet_fallback() { let result = get_result_snippet(None, "full content text"); assert_eq!(result, "full content text"); } #[test] fn test_get_result_snippet_empty_fts() { let result = get_result_snippet(Some(""), "full content text"); assert_eq!(result, "full content text"); } }