Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 107 additions & 8 deletions flag.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,29 @@ const (
PanicOnError
)

// UnknownFlagsHandling decides how to handle unknown flags
type UnknownFlagsHandling int

const (
// ErrorOnUnknownFlag will return an error if an unknown flag is found
ErrorOnUnknownFlag UnknownFlagsHandling = iota
// IgnoreUnknownFlag will ignore unknown flags and continue parsing rest of the flags
IgnoreUnknownFlag
// PassUnknownFlagToArgs will treat unknown flags as non-flag arguments.
// Combined shorthand flags mixed with known ones and unknown ones results
// combined flags only with unknown ones.
// E.g. -fghi results -gh if only `f` and `i` are known.
PassUnknownFlagToArgs
)

// ParseErrorsAllowlist defines the parsing errors that can be ignored
type ParseErrorsAllowlist struct {
// UnknownFlags will ignore unknown flags errors and continue parsing rest of the flags
// Deprecated: Use UnknownFlagsHandling instead
UnknownFlags bool

// UnknownFlagsHandling decides how to handle unknown flags. Defaults to UnknownFlagsHandlingErrorOnUnknown.
UnknownFlagsHandling UnknownFlagsHandling
}

// ParseErrorsWhitelist defines the parsing errors that can be ignored.
Expand Down Expand Up @@ -337,6 +356,35 @@ func (f *FlagSet) HasAvailableFlags() bool {
return false
}

// getUnknownFlagsHandling returns the UnknownFlagsHandling value,
// considering deprecated ParseErrorsWhitelist and UnknownFlags field
// After removing ParseErrorsWhitelist, this function can be simplified
// and moved to ParseErrorsAllowlist.getUnknownFlagsHandling()
func (f *FlagSet) getUnknownFlagsHandling() UnknownFlagsHandling {
// first check ParseErrorsAllowlist:
// if UnknownFlagsHandling is set, use it
if f.ParseErrorsAllowlist.UnknownFlagsHandling != ErrorOnUnknownFlag {
return f.ParseErrorsAllowlist.UnknownFlagsHandling
}

if f.ParseErrorsAllowlist.UnknownFlags {
return IgnoreUnknownFlag
}

// then, check deprecated ParseErrorsWhitelist:
// if UnknownFlagsHandling is set, use it
if f.ParseErrorsWhitelist.UnknownFlagsHandling != ErrorOnUnknownFlag {
return f.ParseErrorsAllowlist.UnknownFlagsHandling
}

if f.ParseErrorsWhitelist.UnknownFlags {
return IgnoreUnknownFlag
}

// Otherwise, return the default value
return ErrorOnUnknownFlag
}

// VisitAll visits the command-line flags in lexicographical order or
// in primordial order if f.SortFlags is false, calling fn for each.
// It visits all flags, even those not set.
Expand Down Expand Up @@ -988,6 +1036,17 @@ func stripUnknownFlagValue(args []string) []string {
return nil
}

// errUnknownFlag is used for internal unknown flag handling.
type unknownFlagError struct {
// UnknownFlags is flags that are unknown and unprocessed.
// It depends on the context whether this has a prefix like '-' or '--'.
UnknownFlags string
}

func (e *unknownFlagError) Error() string {
return fmt.Sprintf("unknown flag: %v", e.UnknownFlags)
}

func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []string, err error) {
a = args
name := s[2:]
Expand All @@ -999,22 +1058,25 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin
split := strings.SplitN(name, "=", 2)
name = split[0]
flag, exists := f.formal[f.normalizeFlagName(name)]
unknownFlagsHandling := f.getUnknownFlagsHandling()

if !exists {
switch {
case name == "help":
f.usage()
return a, ErrHelp
case f.ParseErrorsWhitelist.UnknownFlags:
fallthrough
case f.ParseErrorsAllowlist.UnknownFlags:
case unknownFlagsHandling == IgnoreUnknownFlag:
// --unknown=unknownval arg ...
// we do not want to lose arg in this case
if len(split) >= 2 {
return a, nil
}

return stripUnknownFlagValue(a), nil
case unknownFlagsHandling == PassUnknownFlagToArgs:
return a, &unknownFlagError{
UnknownFlags: s,
}
default:
err = f.fail(&NotExistError{name: name, messageType: flagUnknownFlagMessage})
return
Expand Down Expand Up @@ -1060,14 +1122,14 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse

flag, exists := f.shorthands[c]
if !exists {
unknownFlagsHandling := f.getUnknownFlagsHandling()

switch {
case c == 'h':
f.usage()
err = ErrHelp
return
case f.ParseErrorsWhitelist.UnknownFlags:
fallthrough
case f.ParseErrorsAllowlist.UnknownFlags:
case unknownFlagsHandling == IgnoreUnknownFlag:
// '-f=arg arg ...'
// we do not want to lose arg in this case
if len(shorthands) > 2 && shorthands[1] == '=' {
Expand All @@ -1077,6 +1139,20 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse

outArgs = stripUnknownFlagValue(outArgs)
return
case unknownFlagsHandling == PassUnknownFlagToArgs:
// '-f=arg': pass all the argument
if len(shorthands) > 2 && shorthands[1] == '=' {
outShorts = ""
err = &unknownFlagError{
UnknownFlags: shorthands,
}
return
}
// '-fgh': pass only the first switch
err = &unknownFlagError{
UnknownFlags: shorthands[0:1],
}
return
default:
err = f.fail(&NotExistError{
name: string(c),
Expand Down Expand Up @@ -1127,14 +1203,31 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse
func (f *FlagSet) parseShortArg(s string, args []string, fn parseFunc) (a []string, err error) {
a = args
shorthands := s[1:]
var errUnknownFlagAll *unknownFlagError

// "shorthands" can be a series of shorthand letters of flags (e.g. "-vvv").
for len(shorthands) > 0 {
shorthands, a, err = f.parseSingleShortArg(shorthands, args, fn)
if err != nil {
return
if errUnknownFlag, ok := err.(*unknownFlagError); ok {
// this means f.ParseErrorsAllowlist.UnknownFlagsHandling is set to UnknownFlagsHandlingPassUnknownToArgs
if errUnknownFlagAll == nil {
errUnknownFlagAll = &unknownFlagError{
UnknownFlags: "-",
}
}

errUnknownFlagAll.UnknownFlags = errUnknownFlagAll.UnknownFlags +
errUnknownFlag.UnknownFlags
err = nil
} else {
return
}
}
}
if errUnknownFlagAll != nil {
err = errUnknownFlagAll
}

return
}
Expand Down Expand Up @@ -1164,7 +1257,13 @@ func (f *FlagSet) parseArgs(args []string, fn parseFunc) (err error) {
args, err = f.parseShortArg(s, args, fn)
}
if err != nil {
return
if errUnknownFlag, ok := err.(*unknownFlagError); ok {
// this means f.ParseErrorsAllowlist.UnknownFlagsHandling is set to UnknownFlagsHandlingPassUnknownToArgs
f.args = append(f.args, errUnknownFlag.UnknownFlags)
err = nil
} else {
return
}
}
}
return
Expand Down
109 changes: 109 additions & 0 deletions flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,111 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T, setUnknownFlags func(f
}
}

func testParseWithUnknownFlagsAndPassToArgs(f *FlagSet, t *testing.T) {
if f.Parsed() {
t.Fatal("f.Parse() = true before Parse")
}
f.ParseErrorsAllowlist.UnknownFlagsHandling = PassUnknownFlagToArgs
f.SetInterspersed(true)

f.BoolP("boola", "a", false, "bool value")
f.BoolP("boolb", "b", false, "bool2 value")
f.BoolP("boolc", "c", false, "bool3 value")
f.BoolP("boold", "d", false, "bool4 value")
f.BoolP("boole", "e", false, "bool4 value")
f.StringP("stringa", "s", "0", "string value")
f.StringP("stringz", "z", "0", "string value")
f.StringP("stringx", "x", "0", "string value")
f.StringP("stringy", "y", "0", "string value")
f.StringP("stringo", "o", "0", "string value")
f.Lookup("stringx").NoOptDefVal = "1"
args := []string{
"-ab",
// -f and -g is unknown
"-fcgs=xx",
"--stringz=something",
"--unknown1",
"unknown1Value",
"-d=true",
"-x",
"--unknown2=unknown2Value",
"-u=unknown3Value",
"-p",
"unknown4Value",
"-q", //another unknown with bool value
"-y",
"ee",
"--unknown7=unknown7value",
"--stringo=ovalue",
"--unknown8=unknown8value",
"--boole",
"--unknown6",
"",
"-uuuuu",
"",
"--unknown10",
"--unknown11",
"arg0",
"arg1",
}
want := []string{
"boola", "true",
"boolb", "true",
"boolc", "true",
"stringa", "xx",
"stringz", "something",
"boold", "true",
"stringx", "1",
"stringy", "ee",
"stringo", "ovalue",
"boole", "true",
}
wantArgs := []string{
"-fg",
"--unknown1",
"unknown1Value",
"--unknown2=unknown2Value",
"-u=unknown3Value",
"-p",
"unknown4Value",
"-q", //another unknown with bool value
"--unknown7=unknown7value",
"--unknown8=unknown8value",
"--unknown6",
"",
"-uuuuu",
"",
"--unknown10",
"--unknown11",
"arg0",
"arg1",
}
got := []string{}
store := func(flag *Flag, value string) error {
got = append(got, flag.Name)
if len(value) > 0 {
got = append(got, value)
}
return nil
}
if err := f.ParseAll(args, store); err != nil {
t.Errorf("expected no error, got %s", err)
}
if !f.Parsed() {
t.Errorf("f.Parse() = false after Parse")
}
if !reflect.DeepEqual(got, want) {
t.Errorf("f.ParseAll() fail to restore the args")
t.Errorf("Got: %v", got)
t.Errorf("Want: %v", want)
}
if !reflect.DeepEqual(f.Args(), wantArgs) {
t.Errorf("f.ParseAll() fail to restore the args")
t.Errorf("Got: %v", f.Args())
t.Errorf("Want: %v", wantArgs)
}
}

func TestShorthand(t *testing.T) {
f := NewFlagSet("shorthand", ContinueOnError)
if f.Parsed() {
Expand Down Expand Up @@ -676,6 +781,10 @@ func TestIgnoreUnknownFlagsBackwardsCompat(t *testing.T) {
testParseWithUnknownFlags(GetCommandLine(), t, func(f *FlagSet) { f.ParseErrorsWhitelist.UnknownFlags = true })
}

func TestIgnoreUnknownFlagsAndPassToArgs(t *testing.T) {
ResetForTesting(func() { t.Error("bad parse") })
testParseWithUnknownFlagsAndPassToArgs(GetCommandLine(), t)
}
func TestFlagSetParse(t *testing.T) {
testParse(NewFlagSet("test", ContinueOnError), t)
}
Expand Down