From 9cd59d21fd5ed204631cc9fae368ca01651248ef Mon Sep 17 00:00:00 2001 From: thibaud stevelinck Date: Tue, 23 Jun 2020 23:43:14 +0200 Subject: [PATCH 1/2] add unknown flags support --- flag.go | 106 ++++++++++++++++++++++++++++++++++----------------- flag_test.go | 63 ++++++++++++++++++++++++++++++ 2 files changed, 135 insertions(+), 34 deletions(-) diff --git a/flag.go b/flag.go index 7c058de3..f274f8c0 100644 --- a/flag.go +++ b/flag.go @@ -165,6 +165,7 @@ type FlagSet struct { normalizeNameFunc func(f *FlagSet, name string) NormalizedName addedGoFlagSets []*goflag.FlagSet + unknownFlags []*Flag } // A Flag represents the state of a flag. @@ -182,6 +183,12 @@ type Flag struct { Annotations map[string][]string // used by cobra.Command bash autocomple code } +// A UnknownFlag represents the state of a flag that is not expected. +type UnknownFlag struct { + Name string // name as it appears on command line + Value Value // value as set +} + // Value is the interface to the dynamic value stored in a flag. // (The default value is represented as a string.) type Value interface { @@ -275,6 +282,16 @@ func (f *FlagSet) SetOutput(output io.Writer) { f.output = output } +func (f *FlagSet) VisitUnknowns(fn func(*Flag)) { + if len(f.unknownFlags) == 0 { + return + } + + for _, flag := range f.unknownFlags { + fn(flag) + } +} + // VisitAll visits the flags in lexicographical order or // in primordial order if f.SortFlags is false, calling fn for each. // It visits all flags, even those not set. @@ -956,6 +973,18 @@ func stripUnknownFlagValue(args []string) []string { return nil } +func createUnknownFlag(name string, value string) *Flag { + flag := new(Flag) + flag.Name = name + flag.Value = newStringValue(value, &value) + return flag +} + +func (f *FlagSet) addUnknownFlag(name string, value string) { + flag := createUnknownFlag(name, value) + f.unknownFlags = append(f.unknownFlags, flag) +} + func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []string, err error) { a = args name := s[2:] @@ -969,19 +998,11 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin flag, exists := f.formal[f.normalizeFlagName(name)] if !exists { - switch { - case name == "help": + if name == "help" { f.usage() return a, ErrHelp - case f.ParseErrorsWhitelist.UnknownFlags: - // --unknown=unknownval arg ... - // we do not want to lose arg in this case - if len(split) >= 2 { - return a, nil - } - - return stripUnknownFlagValue(a), nil - default: + } + if !f.ParseErrorsWhitelist.UnknownFlags { err = f.failf("unknown flag: --%s", name) return } @@ -991,15 +1012,28 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin if len(split) == 2 { // '--flag=arg' value = split[1] - } else if flag.NoOptDefVal != "" { + } else if exists && flag.NoOptDefVal != "" { // '--flag' (arg was optional) value = flag.NoOptDefVal } else if len(a) > 0 { // '--flag arg' - value = a[0] - a = a[1:] - } else { - // '--flag' (arg was required) + if !exists && strings.HasPrefix(a[0], "-") { + value = "" + } else { + value = a[0] + a = a[1:] + } + } else if f.ParseErrorsWhitelist.UnknownFlags { + value = "" + } + + if !exists && f.ParseErrorsWhitelist.UnknownFlags { + fmt.Println("in not exist flag") + f.addUnknownFlag(name, value) + return + } + + if flag.NoOptDefVal == "" && value == "" { err = f.failf("flag needs an argument: %s", s) return } @@ -1023,22 +1057,12 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse flag, exists := f.shorthands[c] if !exists { - switch { - case c == 'h': + if c == 'h' { f.usage() err = ErrHelp return - case f.ParseErrorsWhitelist.UnknownFlags: - // '-f=arg arg ...' - // we do not want to lose arg in this case - if len(shorthands) > 2 && shorthands[1] == '=' { - outShorts = "" - return - } - - outArgs = stripUnknownFlagValue(outArgs) - return - default: + } + if !f.ParseErrorsWhitelist.UnknownFlags { err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands) return } @@ -1049,8 +1073,8 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse // '-f=arg' value = shorthands[2:] outShorts = "" - } else if flag.NoOptDefVal != "" { - // '-f' (arg was optional) + } else if exists && flag.NoOptDefVal != "" { + // '--flag' (arg was optional) value = flag.NoOptDefVal } else if len(shorthands) > 1 { // '-farg' @@ -1058,9 +1082,23 @@ func (f *FlagSet) parseSingleShortArg(shorthands string, args []string, fn parse outShorts = "" } else if len(args) > 0 { // '-f arg' - value = args[0] - outArgs = args[1:] - } else { + if !exists && strings.HasPrefix(args[0], "-") { + value = "" + } else { + value = args[0] + outArgs = args[1:] + } + + } else if f.ParseErrorsWhitelist.UnknownFlags { + value = "" + } + + if !exists && f.ParseErrorsWhitelist.UnknownFlags { + f.addUnknownFlag(string(c), value) + return + } + + if flag.NoOptDefVal == "" && value == "" { // '-f' (arg was required) err = f.failf("flag needs an argument: %q in -%s", c, shorthands) return diff --git a/flag_test.go b/flag_test.go index 58a5d25a..13e9e242 100644 --- a/flag_test.go +++ b/flag_test.go @@ -480,6 +480,68 @@ func testParseWithUnknownFlags(f *FlagSet, t *testing.T) { } } +func testRetrieveUknowsWhenUnknownFlagsParsed(t *testing.T) { + f := NewFlagSet("unknwonFlags", ContinueOnError) + if f.Parsed() { + t.Error("f.Parse() = true before Parse") + } + boolaFlag := f.BoolP("boola", "a", false, "bool value") + stringaFlag := f.StringP("stringa", "s", "0", "string value") + + args := []string{ + "-a", + "--stringa", + "hello", + "--unknownFlag1", + "unknownValue1", + "--unknownFlag2", + "--unknownFlag3=unknownValue3", + "-e", + "unknownValue4", + "-f=unknownValue5", + "-g", + } + + f.ParseErrorsWhitelist.UnknownFlags = true + + want := map[string]string{ + "unknownFlag1": "unknownValue1", + "unknownFlag2": "", + "unknownFlag3": "unknownValue3", + "e": "unknownValue4", + "f": "unknownValue5", + "g": "", + } + + f.SetOutput(ioutil.Discard) + if err := f.Parse(args); err != nil { + t.Error("expected no error, got ", err) + } + if !f.Parsed() { + t.Error("f.Parse() = false after Parse") + } + if *boolaFlag != true { + t.Error("boola flag should be true, is ", *boolaFlag) + } + if *stringaFlag != "hello" { + t.Error("stringa flag should be `hello`, is ", *stringaFlag) + } + if len(f.unknownFlags) != len(want) { + t.Errorf("f.ParseAll() failed to parse unknown flags") + } + for _, flag := range f.unknownFlags { + wantedValue, ok := want[flag.Name] + if !ok { + t.Errorf("f.unknownFlags contains a flag \"%s\" and shouldn't", flag.Name) + break + } + if wantedValue != flag.Value.String() { + t.Errorf("value for the unknown flag \"%s\" should be \"%s\", got \"%s\"", flag.Name, wantedValue, flag.Value.String()) + } + + } +} + func TestShorthand(t *testing.T) { f := NewFlagSet("shorthand", ContinueOnError) if f.Parsed() { @@ -588,6 +650,7 @@ func TestParseAll(t *testing.T) { func TestIgnoreUnknownFlags(t *testing.T) { ResetForTesting(func() { t.Error("bad parse") }) + testRetrieveUknowsWhenUnknownFlagsParsed(t) testParseWithUnknownFlags(GetCommandLine(), t) } From 33dec6aac49458ea83cc46413ffcc7e0a189fef4 Mon Sep 17 00:00:00 2001 From: tibo Date: Fri, 18 Sep 2020 22:41:06 +0200 Subject: [PATCH 2/2] fixing test --- flag.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/flag.go b/flag.go index f274f8c0..529c84bb 100644 --- a/flag.go +++ b/flag.go @@ -282,6 +282,7 @@ func (f *FlagSet) SetOutput(output io.Writer) { f.output = output } +// VisitUnknowns visits all the flags that have not been registered. func (f *FlagSet) VisitUnknowns(fn func(*Flag)) { if len(f.unknownFlags) == 0 { return @@ -1028,16 +1029,10 @@ func (f *FlagSet) parseLongArg(s string, args []string, fn parseFunc) (a []strin } if !exists && f.ParseErrorsWhitelist.UnknownFlags { - fmt.Println("in not exist flag") f.addUnknownFlag(name, value) return } - if flag.NoOptDefVal == "" && value == "" { - err = f.failf("flag needs an argument: %s", s) - return - } - err = fn(flag, value) if err != nil { f.failf(err.Error())