use serde::Serialize; use strsim::jaro_winkler; // --------------------------------------------------------------------------- // Types // --------------------------------------------------------------------------- /// A single correction applied to one argument. #[derive(Debug, Clone, Serialize)] pub struct Correction { pub original: String, pub corrected: String, pub rule: CorrectionRule, pub confidence: f64, } /// Which rule triggered the correction. #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] #[serde(rename_all = "snake_case")] pub enum CorrectionRule { SingleDashLongFlag, CaseNormalization, FuzzyFlag, } /// Result of the correction pass over raw args. #[derive(Debug, Clone)] pub struct CorrectionResult { pub args: Vec, pub corrections: Vec, } // --------------------------------------------------------------------------- // Flag registry // --------------------------------------------------------------------------- /// Global flags accepted by every command (from `Cli` struct). const GLOBAL_FLAGS: &[&str] = &[ "--config", "--robot", "--json", "--color", "--quiet", "--no-quiet", "--verbose", "--no-verbose", "--log-format", ]; /// Per-subcommand flags. Each entry is `(command_name, &[flags])`. /// Hidden `--no-*` variants are included so they can be fuzzy-matched too. const COMMAND_FLAGS: &[(&str, &[&str])] = &[ ( "issues", &[ "--limit", "--fields", "--state", "--project", "--author", "--assignee", "--label", "--milestone", "--since", "--due-before", "--has-due", "--no-has-due", "--sort", "--asc", "--no-asc", "--open", "--no-open", ], ), ( "mrs", &[ "--limit", "--fields", "--state", "--project", "--author", "--assignee", "--reviewer", "--label", "--since", "--draft", "--no-draft", "--target", "--source", "--sort", "--asc", "--no-asc", "--open", "--no-open", ], ), ( "ingest", &[ "--project", "--force", "--no-force", "--full", "--no-full", "--dry-run", "--no-dry-run", ], ), ( "sync", &[ "--full", "--no-full", "--force", "--no-force", "--no-embed", "--no-docs", "--no-events", "--no-file-changes", "--dry-run", "--no-dry-run", ], ), ( "search", &[ "--mode", "--type", "--author", "--project", "--label", "--path", "--since", "--updated-since", "--limit", "--explain", "--no-explain", "--fts-mode", ], ), ( "embed", &["--full", "--no-full", "--retry-failed", "--no-retry-failed"], ), ( "stats", &[ "--check", "--no-check", "--repair", "--dry-run", "--no-dry-run", ], ), ("count", &["--for"]), ( "timeline", &[ "--project", "--since", "--depth", "--expand-mentions", "--limit", "--max-seeds", "--max-entities", "--max-evidence", ], ), ( "who", &[ "--path", "--active", "--overlap", "--reviews", "--since", "--project", "--limit", "--detail", "--no-detail", ], ), ( "init", &[ "--force", "--non-interactive", "--gitlab-url", "--token-env-var", "--projects", ], ), ("generate-docs", &["--full", "--project"]), ("completions", &[]), ( "list", &[ "--limit", "--project", "--state", "--author", "--assignee", "--label", "--milestone", "--since", "--due-before", "--has-due-date", "--sort", "--order", "--open", "--draft", "--no-draft", "--reviewer", "--target-branch", "--source-branch", ], ), ("show", &["--project"]), ("reset", &["--yes"]), ]; /// Valid values for enum-like flags, used for post-clap error enhancement. pub const ENUM_VALUES: &[(&str, &[&str])] = &[ ("--state", &["opened", "closed", "merged", "locked", "all"]), ("--mode", &["lexical", "hybrid", "semantic"]), ("--sort", &["updated", "created", "iid"]), ("--type", &["issue", "mr", "discussion"]), ("--fts-mode", &["safe", "raw"]), ("--color", &["auto", "always", "never"]), ("--log-format", &["text", "json"]), ("--for", &["issue", "mr"]), ]; // --------------------------------------------------------------------------- // Correction thresholds // --------------------------------------------------------------------------- const FUZZY_FLAG_THRESHOLD: f64 = 0.8; // --------------------------------------------------------------------------- // Core logic // --------------------------------------------------------------------------- /// Detect which subcommand is being invoked by finding the first positional /// arg (not a flag, not a flag value). fn detect_subcommand(args: &[String]) -> Option<&str> { // Skip args[0] (binary name). Walk forward looking for the first // arg that isn't a flag and isn't the value to a flag that takes one. let mut skip_next = false; for arg in args.iter().skip(1) { if skip_next { skip_next = false; continue; } if arg.starts_with('-') { // Flags that take a value: we know global ones; for simplicity // skip the next arg for any `--flag=value` form (handled inline) // or known value-taking global flags. if arg.contains('=') { continue; } if matches!(arg.as_str(), "--config" | "-c" | "--color" | "--log-format") { skip_next = true; } continue; } // First non-flag positional = subcommand return Some(arg.as_str()); } None } /// Build the set of valid long flags for the detected subcommand. fn valid_flags_for(subcommand: Option<&str>) -> Vec<&'static str> { let mut flags: Vec<&str> = GLOBAL_FLAGS.to_vec(); if let Some(cmd) = subcommand { for (name, cmd_flags) in COMMAND_FLAGS { if *name == cmd { flags.extend_from_slice(cmd_flags); break; } } } else { // No subcommand detected — include all flags for maximum matching for (_, cmd_flags) in COMMAND_FLAGS { for flag in *cmd_flags { if !flags.contains(flag) { flags.push(flag); } } } } flags } /// Run the pre-clap correction pass on raw args. /// /// When `strict` is true (robot mode), only deterministic corrections are applied /// (single-dash long flags, case normalization). Fuzzy matching is disabled to /// prevent misleading agents with speculative corrections. /// /// Returns the (possibly modified) args and any corrections applied. pub fn correct_args(raw: Vec, strict: bool) -> CorrectionResult { let subcommand = detect_subcommand(&raw); let valid = valid_flags_for(subcommand); let mut corrected = Vec::with_capacity(raw.len()); let mut corrections = Vec::new(); let mut past_terminator = false; for arg in raw { // B1: Stop correcting after POSIX `--` option terminator if arg == "--" { past_terminator = true; corrected.push(arg); continue; } if past_terminator { corrected.push(arg); continue; } if let Some(fixed) = try_correct(&arg, &valid, strict) { let s = fixed.corrected.clone(); corrections.push(fixed); corrected.push(s); } else { corrected.push(arg); } } CorrectionResult { args: corrected, corrections, } } /// Clap built-in flags that should never be corrected. These are handled by clap /// directly and are not in our GLOBAL_FLAGS registry. const CLAP_BUILTINS: &[&str] = &["--help", "--version"]; /// Try to correct a single arg. Returns `None` if no correction needed. /// /// When `strict` is true, fuzzy matching is disabled — only deterministic /// corrections (single-dash fix, case normalization) are applied. fn try_correct(arg: &str, valid_flags: &[&str], strict: bool) -> Option { // Only attempt correction on flag-like args (starts with `-`) if !arg.starts_with('-') { return None; } // B2: Never correct clap built-in flags (--help, --version) let flag_part_for_builtin = if let Some(eq_pos) = arg.find('=') { &arg[..eq_pos] } else { arg }; if CLAP_BUILTINS .iter() .any(|b| b.eq_ignore_ascii_case(flag_part_for_builtin)) { return None; } // Skip short flags — they're unambiguous single chars (-p, -n, -v, -J) // Also skip stacked short flags (-vvv) if !arg.starts_with("--") { // Rule 1: Single-dash long flag — e.g. `-robot` (len > 2, not a valid short flag) // A short flag is `-` + single char, optionally stacked (-vvv). // If it's `-` + multiple chars and NOT all the same char, it's likely a single-dash long flag. let after_dash = &arg[1..]; // Check if it's a stacked short flag like -vvv (all same char) let all_same_char = after_dash.len() > 1 && after_dash .chars() .all(|c| c == after_dash.chars().next().unwrap_or('\0')); if all_same_char { return None; } // Single char = valid short flag, don't touch if after_dash.len() == 1 { return None; } // It looks like a single-dash long flag (e.g. `-robot`, `-state`) let candidate = format!("--{after_dash}"); // Check exact match first (case-sensitive) if valid_flags.contains(&candidate.as_str()) { return Some(Correction { original: arg.to_string(), corrected: candidate, rule: CorrectionRule::SingleDashLongFlag, confidence: 0.95, }); } // Check case-insensitive exact match let lower = candidate.to_lowercase(); if let Some(&flag) = valid_flags.iter().find(|f| f.to_lowercase() == lower) { return Some(Correction { original: arg.to_string(), corrected: flag.to_string(), rule: CorrectionRule::SingleDashLongFlag, confidence: 0.95, }); } // Try fuzzy on the single-dash candidate (skip in strict mode) if !strict && let Some((best_flag, score)) = best_fuzzy_match(&lower, valid_flags) && score >= FUZZY_FLAG_THRESHOLD { return Some(Correction { original: arg.to_string(), corrected: best_flag.to_string(), rule: CorrectionRule::SingleDashLongFlag, confidence: score * 0.95, // discount slightly for compound correction }); } return None; } // For `--flag` or `--flag=value` forms: only correct the flag name let (flag_part, value_suffix) = if let Some(eq_pos) = arg.find('=') { (&arg[..eq_pos], Some(&arg[eq_pos..])) } else { (arg, None) }; // Already valid? No correction needed. if valid_flags.contains(&flag_part) { return None; } // Rule 2: Case normalization — `--Robot` -> `--robot` let lower = flag_part.to_lowercase(); if lower != flag_part && let Some(&flag) = valid_flags.iter().find(|f| f.to_lowercase() == lower) { let corrected = match value_suffix { Some(suffix) => format!("{flag}{suffix}"), None => flag.to_string(), }; return Some(Correction { original: arg.to_string(), corrected, rule: CorrectionRule::CaseNormalization, confidence: 0.9, }); } // Rule 3: Fuzzy flag match — `--staate` -> `--state` (skip in strict mode) if !strict && let Some((best_flag, score)) = best_fuzzy_match(&lower, valid_flags) && score >= FUZZY_FLAG_THRESHOLD { let corrected = match value_suffix { Some(suffix) => format!("{best_flag}{suffix}"), None => best_flag.to_string(), }; return Some(Correction { original: arg.to_string(), corrected, rule: CorrectionRule::FuzzyFlag, confidence: score, }); } None } /// Find the best fuzzy match among valid flags for a given (lowercased) input. fn best_fuzzy_match<'a>(input: &str, valid_flags: &[&'a str]) -> Option<(&'a str, f64)> { valid_flags .iter() .map(|&flag| (flag, jaro_winkler(input, flag))) .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)) } // --------------------------------------------------------------------------- // Post-clap suggestion helpers // --------------------------------------------------------------------------- /// Given an unrecognized flag (from a clap error), suggest the most similar /// valid flag for the detected subcommand. pub fn suggest_similar_flag(invalid_flag: &str, raw_args: &[String]) -> Option { let subcommand = detect_subcommand(raw_args); let valid = valid_flags_for(subcommand); let lower = invalid_flag.to_lowercase(); let (best_flag, score) = best_fuzzy_match(&lower, &valid)?; if score >= 0.6 { Some(best_flag.to_string()) } else { None } } /// Given a flag name, return its valid enum values (if known). pub fn valid_values_for_flag(flag: &str) -> Option<&'static [&'static str]> { let lower = flag.to_lowercase(); ENUM_VALUES .iter() .find(|(f, _)| f.to_lowercase() == lower) .map(|(_, vals)| *vals) } /// Format a human/robot teaching note for a correction. pub fn format_teaching_note(correction: &Correction) -> String { match correction.rule { CorrectionRule::SingleDashLongFlag => { format!( "Use double-dash for long flags: {} (not {})", correction.corrected, correction.original ) } CorrectionRule::CaseNormalization => { format!( "Flags are lowercase: {} (not {})", correction.corrected, correction.original ) } CorrectionRule::FuzzyFlag => { format!( "Correct spelling: {} (not {})", correction.corrected, correction.original ) } } } // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- #[cfg(test)] mod tests { use super::*; fn args(s: &str) -> Vec { s.split_whitespace().map(String::from).collect() } // ---- Single-dash long flag ---- #[test] fn single_dash_robot() { let result = correct_args(args("lore -robot issues -n 5"), false); assert_eq!(result.corrections.len(), 1); assert_eq!(result.corrections[0].original, "-robot"); assert_eq!(result.corrections[0].corrected, "--robot"); assert_eq!( result.corrections[0].rule, CorrectionRule::SingleDashLongFlag ); assert_eq!(result.args, args("lore --robot issues -n 5")); } #[test] fn single_dash_state() { let result = correct_args(args("lore --robot issues -state opened"), false); assert_eq!(result.corrections.len(), 1); assert_eq!(result.corrections[0].corrected, "--state"); } // ---- Case normalization ---- #[test] fn case_robot() { let result = correct_args(args("lore --Robot issues"), false); assert_eq!(result.corrections.len(), 1); assert_eq!(result.corrections[0].corrected, "--robot"); assert_eq!( result.corrections[0].rule, CorrectionRule::CaseNormalization ); } #[test] fn case_state_upper() { let result = correct_args(args("lore --robot issues --State opened"), false); assert_eq!(result.corrections.len(), 1); assert_eq!(result.corrections[0].corrected, "--state"); assert_eq!( result.corrections[0].rule, CorrectionRule::CaseNormalization ); } #[test] fn case_all_upper() { let result = correct_args(args("lore --ROBOT issues --STATE opened"), false); assert_eq!(result.corrections.len(), 2); assert_eq!(result.corrections[0].corrected, "--robot"); assert_eq!(result.corrections[1].corrected, "--state"); } // ---- Fuzzy flag match ---- #[test] fn fuzzy_staate() { let result = correct_args(args("lore --robot issues --staate opened"), false); assert_eq!(result.corrections.len(), 1); assert_eq!(result.corrections[0].corrected, "--state"); assert_eq!(result.corrections[0].rule, CorrectionRule::FuzzyFlag); } #[test] fn fuzzy_projct() { let result = correct_args(args("lore --robot issues --projct group/repo"), false); assert_eq!(result.corrections.len(), 1); assert_eq!(result.corrections[0].corrected, "--project"); assert_eq!(result.corrections[0].rule, CorrectionRule::FuzzyFlag); } // ---- No corrections ---- #[test] fn already_correct() { let original = args("lore --robot issues --state opened -n 10"); let result = correct_args(original.clone(), false); assert!(result.corrections.is_empty()); assert_eq!(result.args, original); } #[test] fn short_flags_untouched() { let original = args("lore -J issues -n 10 -s opened -p group/repo"); let result = correct_args(original.clone(), false); assert!(result.corrections.is_empty()); } #[test] fn stacked_short_flags_untouched() { let original = args("lore -vvv issues"); let result = correct_args(original.clone(), false); assert!(result.corrections.is_empty()); } #[test] fn positional_args_untouched() { let result = correct_args(args("lore --robot search authentication"), false); assert!(result.corrections.is_empty()); } #[test] fn wildly_wrong_flag_not_corrected() { // `--xyzzy` shouldn't match anything above 0.8 let result = correct_args(args("lore --robot issues --xyzzy foo"), false); assert!(result.corrections.is_empty()); } // ---- Flag with = value ---- #[test] fn flag_eq_value_case_correction() { let result = correct_args(args("lore --robot issues --State=opened"), false); assert_eq!(result.corrections.len(), 1); assert_eq!(result.corrections[0].corrected, "--state=opened"); } // ---- Multiple corrections in one invocation ---- #[test] fn multiple_corrections() { let result = correct_args( args("lore -robot issues --State opened --projct group/repo"), false, ); assert_eq!(result.corrections.len(), 3); assert_eq!(result.args[1], "--robot"); assert_eq!(result.args[3], "--state"); assert_eq!(result.args[5], "--project"); } // ---- B1: POSIX -- option terminator ---- #[test] fn option_terminator_stops_corrections() { let result = correct_args(args("lore issues -- --staate --projct"), false); // Nothing after `--` should be corrected assert!(result.corrections.is_empty()); assert_eq!(result.args[2], "--"); assert_eq!(result.args[3], "--staate"); assert_eq!(result.args[4], "--projct"); } #[test] fn correction_before_terminator_still_works() { let result = correct_args(args("lore --Robot issues -- --staate"), false); assert_eq!(result.corrections.len(), 1); assert_eq!(result.corrections[0].corrected, "--robot"); assert_eq!(result.args[4], "--staate"); // untouched after -- } // ---- B2: Clap built-in flags not corrected ---- #[test] fn version_flag_not_corrected() { let result = correct_args(args("lore --version"), false); assert!(result.corrections.is_empty()); assert_eq!(result.args[1], "--version"); } #[test] fn help_flag_not_corrected() { let result = correct_args(args("lore --help"), false); assert!(result.corrections.is_empty()); assert_eq!(result.args[1], "--help"); } // ---- I6: Strict mode (robot) disables fuzzy matching ---- #[test] fn strict_mode_disables_fuzzy() { // Fuzzy match works in non-strict let non_strict = correct_args(args("lore --robot issues --staate opened"), false); assert_eq!(non_strict.corrections.len(), 1); assert_eq!(non_strict.corrections[0].rule, CorrectionRule::FuzzyFlag); // Fuzzy match disabled in strict let strict = correct_args(args("lore --robot issues --staate opened"), true); assert!(strict.corrections.is_empty()); } #[test] fn strict_mode_still_fixes_case() { let result = correct_args(args("lore --Robot issues --State opened"), true); assert_eq!(result.corrections.len(), 2); assert_eq!(result.corrections[0].corrected, "--robot"); assert_eq!(result.corrections[1].corrected, "--state"); } #[test] fn strict_mode_still_fixes_single_dash() { let result = correct_args(args("lore -robot issues"), true); assert_eq!(result.corrections.len(), 1); assert_eq!(result.corrections[0].corrected, "--robot"); } // ---- Teaching notes ---- #[test] fn teaching_note_single_dash() { let c = Correction { original: "-robot".to_string(), corrected: "--robot".to_string(), rule: CorrectionRule::SingleDashLongFlag, confidence: 0.95, }; let note = format_teaching_note(&c); assert!(note.contains("double-dash")); assert!(note.contains("--robot")); } #[test] fn teaching_note_case() { let c = Correction { original: "--State".to_string(), corrected: "--state".to_string(), rule: CorrectionRule::CaseNormalization, confidence: 0.9, }; let note = format_teaching_note(&c); assert!(note.contains("lowercase")); } #[test] fn teaching_note_fuzzy() { let c = Correction { original: "--staate".to_string(), corrected: "--state".to_string(), rule: CorrectionRule::FuzzyFlag, confidence: 0.85, }; let note = format_teaching_note(&c); assert!(note.contains("spelling")); } // ---- Post-clap suggestion helpers ---- #[test] fn suggest_similar_flag_works() { let raw = args("lore --robot issues --xstat opened"); let suggestion = suggest_similar_flag("--xstat", &raw); // Should suggest --state (close enough with lower threshold 0.6) assert!(suggestion.is_some()); } #[test] fn valid_values_for_state() { let vals = valid_values_for_flag("--state"); assert!(vals.is_some()); let vals = vals.unwrap(); assert!(vals.contains(&"opened")); assert!(vals.contains(&"closed")); } #[test] fn valid_values_unknown_flag() { assert!(valid_values_for_flag("--xyzzy").is_none()); } // ---- Subcommand detection ---- #[test] fn detect_subcommand_basic() { assert_eq!( detect_subcommand(&args("lore issues -n 10")), Some("issues") ); } #[test] fn detect_subcommand_with_globals() { assert_eq!( detect_subcommand(&args("lore --robot --config /tmp/c.json mrs")), Some("mrs") ); } #[test] fn detect_subcommand_with_color() { assert_eq!( detect_subcommand(&args("lore --color never issues")), Some("issues") ); } #[test] fn detect_subcommand_none() { assert_eq!(detect_subcommand(&args("lore --robot")), None); } // ---- Registry drift test ---- // This test uses clap introspection to verify our static registry covers // all long flags defined in the Cli struct. #[test] fn registry_covers_global_flags() { use clap::CommandFactory; let cmd = crate::cli::Cli::command(); let clap_globals: Vec = cmd .get_arguments() .filter_map(|a| a.get_long().map(|l| format!("--{l}"))) .collect(); for flag in &clap_globals { // Skip help/version — clap adds these automatically if flag == "--help" || flag == "--version" { continue; } assert!( GLOBAL_FLAGS.contains(&flag.as_str()), "Clap global flag {flag} is missing from GLOBAL_FLAGS registry. \ Add it to GLOBAL_FLAGS in autocorrect.rs." ); } } #[test] fn registry_covers_command_flags() { use clap::CommandFactory; let cmd = crate::cli::Cli::command(); for sub in cmd.get_subcommands() { let sub_name = sub.get_name().to_string(); // Find our registry entry let registry_entry = COMMAND_FLAGS.iter().find(|(name, _)| *name == sub_name); // Not all subcommands need entries (e.g., version, auth, status // with no subcommand-specific flags) let clap_flags: Vec = sub .get_arguments() .filter_map(|a| a.get_long().map(|l| format!("--{l}"))) .filter(|f| !GLOBAL_FLAGS.contains(&f.as_str())) .filter(|f| f != "--help" && f != "--version") .collect(); if clap_flags.is_empty() { continue; } let registry_flags = registry_entry.map(|(_, flags)| *flags); let registry_flags = registry_flags.unwrap_or_else(|| { panic!( "Subcommand '{sub_name}' has clap flags {clap_flags:?} but no COMMAND_FLAGS \ registry entry. Add it to COMMAND_FLAGS in autocorrect.rs." ) }); for flag in &clap_flags { assert!( registry_flags.contains(&flag.as_str()), "Clap flag {flag} on subcommand '{sub_name}' is missing from \ COMMAND_FLAGS registry. Add it to the '{sub_name}' entry in autocorrect.rs." ); } } } }