diff --git a/src/core/project.rs b/src/core/project.rs index c6f2722..055e7a3 100644 --- a/src/core/project.rs +++ b/src/core/project.rs @@ -21,13 +21,14 @@ pub fn resolve_project(conn: &Connection, project_str: &str) -> Result { return Ok(id); } + let escaped = escape_like(project_str); let mut suffix_stmt = conn.prepare( "SELECT id, path_with_namespace FROM projects - WHERE path_with_namespace LIKE '%/' || ?1 - OR path_with_namespace = ?1", + WHERE path_with_namespace LIKE '%/' || ?1 ESCAPE '\\' + OR path_with_namespace = ?2", )?; let suffix_matches: Vec<(i64, String)> = suffix_stmt - .query_map(rusqlite::params![project_str], |row| { + .query_map(rusqlite::params![escaped, project_str], |row| { Ok((row.get(0)?, row.get(1)?)) })? .collect::, _>>()?; @@ -52,10 +53,10 @@ pub fn resolve_project(conn: &Connection, project_str: &str) -> Result { let mut substr_stmt = conn.prepare( "SELECT id, path_with_namespace FROM projects - WHERE LOWER(path_with_namespace) LIKE '%' || LOWER(?1) || '%'", + WHERE LOWER(path_with_namespace) LIKE '%' || LOWER(?1) || '%' ESCAPE '\\'", )?; let substr_matches: Vec<(i64, String)> = substr_stmt - .query_map(rusqlite::params![project_str], |row| { + .query_map(rusqlite::params![escaped], |row| { Ok((row.get(0)?, row.get(1)?)) })? .collect::, _>>()?; @@ -103,6 +104,15 @@ pub fn resolve_project(conn: &Connection, project_str: &str) -> Result { ))) } +/// Escape LIKE metacharacters so `%` and `_` in user input are treated as +/// literals. All queries using this must include `ESCAPE '\'`. +fn escape_like(input: &str) -> String { + input + .replace('\\', "\\\\") + .replace('%', "\\%") + .replace('_', "\\_") +} + #[cfg(test)] mod tests { use super::*; @@ -241,4 +251,24 @@ mod tests { let msg = err.to_string(); assert!(msg.contains("No projects have been synced")); } + + #[test] + fn test_underscore_not_wildcard() { + let conn = setup_db(); + insert_project(&conn, 1, "backend/my_project"); + insert_project(&conn, 2, "backend/my-project"); + // `_` in user input must not match `-` (LIKE wildcard behavior) + let id = resolve_project(&conn, "my_project").unwrap(); + assert_eq!(id, 1); + } + + #[test] + fn test_percent_not_wildcard() { + let conn = setup_db(); + insert_project(&conn, 1, "backend/a%b"); + insert_project(&conn, 2, "backend/axyzb"); + // `%` in user input must not match arbitrary strings + let id = resolve_project(&conn, "a%b").unwrap(); + assert_eq!(id, 1); + } }