use crate::core::error::Result; use rusqlite::Connection; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum FtsQueryMode { Safe, Raw, } #[derive(Debug)] pub struct FtsResult { pub document_id: i64, pub bm25_score: f64, pub snippet: String, } /// Validate an FTS5 query string for safety. /// Rejects known-dangerous patterns: unbalanced quotes, excessive wildcards, /// and empty queries. Returns the sanitized query or falls back to Safe mode. fn sanitize_raw_fts(raw: &str) -> Option { let trimmed = raw.trim(); if trimmed.is_empty() { return None; } // Reject unbalanced double quotes (FTS5 syntax error) let quote_count = trimmed.chars().filter(|&c| c == '"').count(); if quote_count % 2 != 0 { return None; } // Reject leading wildcard-only queries (expensive full-table scan) if trimmed == "*" || trimmed.starts_with("* ") { return None; } Some(trimmed.to_string()) } pub fn to_fts_query(raw: &str, mode: FtsQueryMode) -> String { match mode { FtsQueryMode::Raw => { // Validate raw FTS5 input; fall back to Safe mode if invalid match sanitize_raw_fts(raw) { Some(sanitized) => sanitized, None => to_fts_query(raw, FtsQueryMode::Safe), } } FtsQueryMode::Safe => { let trimmed = raw.trim(); if trimmed.is_empty() { return String::new(); } // FTS5 boolean operators are case-sensitive uppercase keywords. // Pass them through unquoted so users can write "switch AND health". // Note: NEAR is a function NEAR(term1 term2, N), not an infix operator. // Users who need NEAR syntax should use FtsQueryMode::Raw. const FTS5_OPERATORS: &[&str] = &["AND", "OR", "NOT"]; let mut result = String::with_capacity(trimmed.len() + 20); for (i, token) in trimmed.split_whitespace().enumerate() { if i > 0 { result.push(' '); } if FTS5_OPERATORS.contains(&token) { result.push_str(token); } else if let Some(stem) = token.strip_suffix('*') && !stem.is_empty() && stem.chars().all(|c| c.is_alphanumeric() || c == '_') { result.push('"'); result.push_str(&stem.replace('"', "\"\"")); result.push_str("\"*"); } else { result.push('"'); result.push_str(&token.replace('"', "\"\"")); result.push('"'); } } result } } } pub fn search_fts( conn: &Connection, query: &str, limit: usize, mode: FtsQueryMode, ) -> Result> { let fts_query = to_fts_query(query, mode); if fts_query.is_empty() { return Ok(Vec::new()); } let sql = r#" SELECT d.id, bm25(documents_fts) AS score, snippet(documents_fts, 1, '', '', '...', 64) AS snip FROM documents_fts JOIN documents d ON d.id = documents_fts.rowid WHERE documents_fts MATCH ?1 ORDER BY score LIMIT ?2 "#; let mut stmt = conn.prepare_cached(sql)?; let results = stmt .query_map(rusqlite::params![fts_query, limit as i64], |row| { Ok(FtsResult { document_id: row.get(0)?, bm25_score: row.get(1)?, snippet: row.get(2)?, }) })? .collect::, _>>()?; Ok(results) } pub fn generate_fallback_snippet(content_text: &str, max_chars: usize) -> String { // Use char_indices to find the boundary at max_chars in a single pass, // short-circuiting early for large strings instead of counting all chars. let byte_end = match content_text.char_indices().nth(max_chars) { Some((i, _)) => i, None => return content_text.to_string(), // content fits within max_chars }; let truncated = &content_text[..byte_end]; if let Some(last_space) = truncated.rfind(' ') { format!("{}...", &truncated[..last_space]) } else { format!("{}...", truncated) } } pub fn get_result_snippet(fts_snippet: Option<&str>, content_text: &str) -> String { match fts_snippet { Some(s) if !s.is_empty() => s.to_string(), _ => generate_fallback_snippet(content_text, 200), } } #[cfg(test)] mod tests { use super::*; #[test] fn test_safe_query_basic() { let result = to_fts_query("auth error", FtsQueryMode::Safe); assert_eq!(result, "\"auth\" \"error\""); } #[test] fn test_safe_query_prefix() { let result = to_fts_query("auth*", FtsQueryMode::Safe); assert_eq!(result, "\"auth\"*"); } #[test] fn test_safe_query_special_chars() { let result = to_fts_query("C++", FtsQueryMode::Safe); assert_eq!(result, "\"C++\""); } #[test] fn test_safe_query_dash() { let result = to_fts_query("-DWITH_SSL", FtsQueryMode::Safe); assert_eq!(result, "\"-DWITH_SSL\""); } #[test] fn test_safe_query_quotes() { let result = to_fts_query("he said \"hello\"", FtsQueryMode::Safe); assert_eq!(result, "\"he\" \"said\" \"\"\"hello\"\"\""); } #[test] fn test_raw_mode_passthrough() { let result = to_fts_query("auth OR error", FtsQueryMode::Raw); assert_eq!(result, "auth OR error"); } #[test] fn test_empty_query() { let result = to_fts_query("", FtsQueryMode::Safe); assert_eq!(result, ""); let result = to_fts_query(" ", FtsQueryMode::Safe); assert_eq!(result, ""); } #[test] fn test_prefix_only_alphanumeric() { let result = to_fts_query("C++*", FtsQueryMode::Safe); assert_eq!(result, "\"C++*\""); let result = to_fts_query("auth*", FtsQueryMode::Safe); assert_eq!(result, "\"auth\"*"); } #[test] fn test_prefix_with_underscore() { let result = to_fts_query("jwt_token*", FtsQueryMode::Safe); assert_eq!(result, "\"jwt_token\"*"); } #[test] fn test_fallback_snippet_short() { let result = generate_fallback_snippet("Short content", 200); assert_eq!(result, "Short content"); } #[test] fn test_fallback_snippet_word_boundary() { let content = "This is a moderately long piece of text that should be truncated at a word boundary for readability purposes"; let result = generate_fallback_snippet(content, 50); assert!(result.ends_with("...")); assert!(result.len() <= 55); } #[test] fn test_get_result_snippet_prefers_fts() { let result = get_result_snippet(Some("FTS match"), "full content text"); assert_eq!(result, "FTS match"); } #[test] fn test_get_result_snippet_fallback() { let result = get_result_snippet(None, "full content text"); assert_eq!(result, "full content text"); } #[test] fn test_get_result_snippet_empty_fts() { let result = get_result_snippet(Some(""), "full content text"); assert_eq!(result, "full content text"); } #[test] fn test_raw_mode_valid_fts5_passes_through() { let result = to_fts_query("auth OR error", FtsQueryMode::Raw); assert_eq!(result, "auth OR error"); let result = to_fts_query("\"exact phrase\"", FtsQueryMode::Raw); assert_eq!(result, "\"exact phrase\""); } #[test] fn test_raw_mode_unbalanced_quotes_falls_back_to_safe() { let result = to_fts_query("auth \"error", FtsQueryMode::Raw); // Falls back to Safe mode: each token quoted assert_eq!(result, "\"auth\" \"\"\"error\""); } #[test] fn test_raw_mode_leading_wildcard_falls_back_to_safe() { let result = to_fts_query("* OR auth", FtsQueryMode::Raw); // Falls back to Safe mode; OR is an FTS5 operator so it passes through unquoted assert_eq!(result, "\"*\" OR \"auth\""); let result = to_fts_query("*", FtsQueryMode::Raw); assert_eq!(result, "\"*\""); } #[test] fn test_raw_mode_empty_falls_back_to_safe() { let result = to_fts_query("", FtsQueryMode::Raw); assert_eq!(result, ""); let result = to_fts_query(" ", FtsQueryMode::Raw); assert_eq!(result, ""); } }