Add CLI input tolerance with fuzzy flag/command matching

Agent-friendly argument normalization that auto-corrects common
CLI syntax mistakes before cobra parses them:
- Single-dash long flags: -zip -> --zip
- Bare key=value: zip=33101 -> --zip=33101
- Typos via Levenshtein distance (max 2): --ziip -> --zip
- Command typos: categoriess -> categories
- Flag aliases: --zipcode, --dept, --search -> canonical names

Corrections emit a "note:" line to stderr showing what was rewritten.
Positional arguments for completion/help subcommands are preserved
(e.g., "completion zsh" is not rewritten). Integration tests verify
end-to-end behavior including tolerance notes, double-dash boundaries,
and help output for rewritten args.
This commit is contained in:
2026-02-22 21:12:19 -05:00
parent cca04bc11c
commit e299e16844
3 changed files with 445 additions and 0 deletions

319
cmd/cli_tolerance.go Normal file
View File

@@ -0,0 +1,319 @@
package cmd
import (
"fmt"
"strings"
)
type flagSpec struct {
name string
requiresValue bool
}
var knownFlags = map[string]flagSpec{
"store": {name: "store", requiresValue: true},
"zip": {name: "zip", requiresValue: true},
"json": {name: "json", requiresValue: false},
"category": {name: "category", requiresValue: true},
"department": {name: "department", requiresValue: true},
"bogo": {name: "bogo", requiresValue: false},
"query": {name: "query", requiresValue: true},
"limit": {name: "limit", requiresValue: true},
"help": {name: "help", requiresValue: false},
}
var knownCommands = []string{
"categories",
"stores",
"completion",
"help",
}
var flagAliases = map[string]string{
"zipcode": "zip",
"postal-code": "zip",
"store-number": "store",
"storeno": "store",
"dept": "department",
"search": "query",
"max": "limit",
}
func normalizeCLIArgs(args []string) ([]string, []string) {
out := make([]string, 0, len(args))
notes := make([]string, 0, 2)
commandChosen := false
activeCommand := ""
nestedCommandAllowed := false
nestedCommandChosen := false
allowBareFlagRewrite := true
expectingValue := false
afterDoubleDash := false
for i, tok := range args {
if afterDoubleDash {
out = append(out, tok)
continue
}
if expectingValue {
out = append(out, tok)
expectingValue = false
continue
}
if tok == "--" {
out = append(out, tok)
afterDoubleDash = true
continue
}
canBeCommand := !commandChosen || (nestedCommandAllowed && !nestedCommandChosen)
normalized, note, isFlag, needsValue, isCommand := normalizeToken(tok, canBeCommand, allowBareFlagRewrite)
if note != "" {
notes = append(notes, note)
}
out = append(out, normalized)
if isCommand {
if !commandChosen {
commandChosen = true
activeCommand = normalized
allowBareFlagRewrite = bareFlagRewriteAllowed(activeCommand)
nestedCommandAllowed = allowsNestedCommandArg(activeCommand)
continue
}
if nestedCommandAllowed && !nestedCommandChosen {
nestedCommandChosen = true
}
}
if isFlag && needsValue && !strings.Contains(normalized, "=") && i < len(args)-1 {
expectingValue = true
}
}
return out, notes
}
func normalizeToken(tok string, canBeCommand bool, allowBareFlagRewrite bool) (normalized, note string, isFlag, needsValue, isCommand bool) {
if tok == "--" {
return tok, "", false, false, false
}
if strings.HasPrefix(tok, "--") {
flagName, rest := splitFlag(strings.TrimPrefix(tok, "--"))
canonical, ok := resolveFlagName(flagName)
if ok {
newTok := "--" + canonical + rest
if newTok != tok {
return newTok, fmt.Sprintf("interpreted `%s` as `%s`; use `%s` next time.", tok, newTok, newTok), true, knownFlags[canonical].requiresValue, false
}
return newTok, "", true, knownFlags[canonical].requiresValue, false
}
return tok, "", true, false, false
}
if strings.HasPrefix(tok, "-") && len(tok) > 2 {
flagName, rest := splitFlag(strings.TrimPrefix(tok, "-"))
canonical, ok := resolveFlagName(flagName)
if ok {
newTok := "--" + canonical + rest
return newTok, fmt.Sprintf("interpreted `%s` as `%s`; use `%s` next time.", tok, newTok, newTok), true, knownFlags[canonical].requiresValue, false
}
return tok, "", true, false, false
}
if strings.Contains(tok, "=") && !strings.HasPrefix(tok, "-") {
flagName, rest := splitFlag(tok)
canonical, ok := resolveFlagName(flagName)
if ok {
newTok := "--" + canonical + rest
return newTok, fmt.Sprintf("interpreted `%s` as `%s`; use `%s` next time.", tok, newTok, newTok), true, knownFlags[canonical].requiresValue, false
}
}
if canBeCommand && !strings.HasPrefix(tok, "-") {
if corrected, ok := resolveCommand(tok); ok {
if corrected != tok {
return corrected, fmt.Sprintf("interpreted command `%s` as `%s`; use `%s` next time.", tok, corrected, corrected), false, false, true
}
return tok, "", false, false, true
}
}
if allowBareFlagRewrite && !strings.HasPrefix(tok, "-") {
canonical, ok := resolveFlagName(tok)
if ok {
newTok := "--" + canonical
return newTok, fmt.Sprintf("interpreted `%s` as `%s`; use `%s` next time.", tok, newTok, newTok), true, knownFlags[canonical].requiresValue, false
}
}
return tok, "", false, false, false
}
func bareFlagRewriteAllowed(command string) bool {
// Some commands (for example `stores` and `categories`) are flag-only, so
// rewriting bare tokens like `zip` -> `--zip` is helpful there.
switch command {
case "stores", "categories":
return true
default:
return false
}
}
func allowsNestedCommandArg(command string) bool {
// These commands accept another command token as a positional argument.
switch command {
case "help", "completion":
return true
default:
return false
}
}
func resolveFlagName(raw string) (string, bool) {
name := strings.ToLower(strings.TrimSpace(raw))
name = strings.ReplaceAll(name, "_", "-")
if canonical, ok := flagAliases[name]; ok {
return canonical, true
}
if _, ok := knownFlags[name]; ok {
return name, true
}
if suggestion, ok := closestMatch(name, mapKeys(knownFlags), 2); ok {
return suggestion, true
}
return "", false
}
func resolveCommand(raw string) (string, bool) {
name := strings.ToLower(strings.TrimSpace(raw))
for _, cmd := range knownCommands {
if name == cmd {
return cmd, true
}
}
if suggestion, ok := closestMatch(name, knownCommands, 2); ok {
return suggestion, true
}
return "", false
}
func explainCLIError(err error) string {
return formatCLIErrorText(classifyCLIError(err))
}
func splitFlag(value string) (string, string) {
parts := strings.SplitN(value, "=", 2)
if len(parts) == 2 {
return parts[0], "=" + parts[1]
}
return value, ""
}
func extractUnknownValue(msg, marker string) string {
idx := strings.Index(msg, marker)
if idx == -1 {
return ""
}
remaining := strings.TrimSpace(msg[idx+len(marker):])
remaining = strings.TrimPrefix(remaining, ":")
remaining = strings.TrimSpace(remaining)
if strings.HasPrefix(remaining, "\"") {
remaining = strings.TrimPrefix(remaining, "\"")
end := strings.Index(remaining, "\"")
if end >= 0 {
return remaining[:end]
}
}
if strings.HasPrefix(remaining, "`") {
remaining = strings.TrimPrefix(remaining, "`")
end := strings.Index(remaining, "`")
if end >= 0 {
return remaining[:end]
}
}
if fields := strings.Fields(remaining); len(fields) > 0 {
return strings.Trim(fields[0], "\"`")
}
return ""
}
func mapKeys[K comparable, V any](m map[K]V) []K {
keys := make([]K, 0, len(m))
for key := range m {
keys = append(keys, key)
}
return keys
}
func closestMatch(target string, candidates []string, maxDistance int) (string, bool) {
best := ""
bestDist := maxDistance + 1
for _, candidate := range candidates {
d := levenshtein(target, candidate)
if d < bestDist {
bestDist = d
best = candidate
}
}
if bestDist <= maxDistance {
return best, true
}
return "", false
}
func levenshtein(a, b string) int {
if a == b {
return 0
}
if len(a) == 0 {
return len(b)
}
if len(b) == 0 {
return len(a)
}
prev := make([]int, len(b)+1)
curr := make([]int, len(b)+1)
for j := range prev {
prev[j] = j
}
for i := 1; i <= len(a); i++ {
curr[0] = i
for j := 1; j <= len(b); j++ {
cost := 0
if a[i-1] != b[j-1] {
cost = 1
}
del := prev[j] + 1
ins := curr[j-1] + 1
sub := prev[j-1] + cost
curr[j] = minInt(del, ins, sub)
}
prev, curr = curr, prev
}
return prev[len(b)]
}
func minInt(vals ...int) int {
best := vals[0]
for _, v := range vals[1:] {
if v < best {
best = v
}
}
return best
}

73
cmd/cli_tolerance_test.go Normal file
View File

@@ -0,0 +1,73 @@
package cmd
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
func TestNormalizeCLIArgs_RewritesCommonFlagSyntax(t *testing.T) {
args, notes := normalizeCLIArgs([]string{"-zip", "33101", "json"})
assert.Equal(t, []string{"--zip", "33101", "--json"}, args)
assert.NotEmpty(t, notes)
}
func TestNormalizeCLIArgs_RewritesTypoFlag(t *testing.T) {
args, notes := normalizeCLIArgs([]string{"--ziip", "33101"})
assert.Equal(t, []string{"--zip", "33101"}, args)
assert.NotEmpty(t, notes)
}
func TestNormalizeCLIArgs_RewritesCommandTypo(t *testing.T) {
args, notes := normalizeCLIArgs([]string{"categoriess", "--zip", "33101"})
assert.Equal(t, []string{"categories", "--zip", "33101"}, args)
assert.NotEmpty(t, notes)
}
func TestNormalizeCLIArgs_DoesNotRewriteCompletionPositionalArgs(t *testing.T) {
args, notes := normalizeCLIArgs([]string{"completion", "zsh"})
assert.Equal(t, []string{"completion", "zsh"}, args)
assert.Empty(t, notes)
}
func TestNormalizeCLIArgs_DoesNotRewriteHelpCommandArgAsFlag(t *testing.T) {
args, notes := normalizeCLIArgs([]string{"help", "stores"})
assert.Equal(t, []string{"help", "stores"}, args)
assert.Empty(t, notes)
}
func TestNormalizeCLIArgs_RespectsDoubleDashBoundary(t *testing.T) {
args, notes := normalizeCLIArgs([]string{"stores", "--", "zip", "33101"})
assert.Equal(t, []string{"stores", "--", "zip", "33101"}, args)
assert.Empty(t, notes)
}
func TestNormalizeCLIArgs_LeavesKnownShorthandUntouched(t *testing.T) {
args, notes := normalizeCLIArgs([]string{"-z", "33101", "-n", "5"})
assert.Equal(t, []string{"-z", "33101", "-n", "5"}, args)
assert.Empty(t, notes)
}
func TestExplainCLIError_UnknownFlagIncludesSuggestionAndExamples(t *testing.T) {
msg := explainCLIError(errors.New("unknown flag: --ziip"))
assert.Contains(t, msg, "Try `--zip`.")
assert.Contains(t, msg, "pubcli --zip 33101")
assert.Contains(t, msg, "pubcli --store 1425 --bogo")
}
func TestExplainCLIError_UnknownCommandIncludesSuggestionAndExamples(t *testing.T) {
msg := explainCLIError(errors.New("unknown command \"stors\" for \"pubcli\""))
assert.Contains(t, msg, "Did you mean `stores`?")
assert.Contains(t, msg, "pubcli stores --zip 33101")
assert.Contains(t, msg, "pubcli categories --zip 33101")
}

View File

@@ -0,0 +1,53 @@
package cmd
import (
"bytes"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRunCLI_CompletionZsh(t *testing.T) {
var stdout bytes.Buffer
var stderr bytes.Buffer
code := runCLI([]string{"completion", "zsh"}, &stdout, &stderr)
assert.Equal(t, 0, code)
assert.Contains(t, stdout.String(), "#compdef pubcli")
assert.Empty(t, stderr.String())
}
func TestRunCLI_HelpStores(t *testing.T) {
var stdout bytes.Buffer
var stderr bytes.Buffer
code := runCLI([]string{"help", "stores"}, &stdout, &stderr)
assert.Equal(t, 0, code)
assert.Contains(t, stdout.String(), "pubcli stores [flags]")
assert.Empty(t, stderr.String())
}
func TestRunCLI_TolerantRewriteWithoutNetworkCall(t *testing.T) {
var stdout bytes.Buffer
var stderr bytes.Buffer
code := runCLI([]string{"stores", "-zip", "33101", "--help"}, &stdout, &stderr)
assert.Equal(t, 0, code)
assert.Contains(t, stdout.String(), "pubcli stores [flags]")
assert.Contains(t, stderr.String(), "interpreted `-zip` as `--zip`")
}
func TestRunCLI_DoubleDashBoundary(t *testing.T) {
var stdout bytes.Buffer
var stderr bytes.Buffer
code := runCLI([]string{"stores", "--", "zip", "33101", "--help"}, &stdout, &stderr)
assert.Equal(t, 0, code)
assert.Contains(t, stdout.String(), "pubcli stores [flags]")
assert.False(t, strings.Contains(stderr.String(), "interpreted `zip` as `--zip`"))
}