diff --git a/command_test.go b/command_test.go index 88338738a9..bfa86b86d7 100644 --- a/command_test.go +++ b/command_test.go @@ -3068,27 +3068,6 @@ func TestPersistentFlag(t *testing.T) { } func TestFlagDuplicates(t *testing.T) { - cmd := &Command{ - Flags: []Flag{ - &StringFlag{ - Name: "sflag", - OnlyOnce: true, - }, - &IntSliceFlag{ - Name: "isflag", - }, - &FloatSliceFlag{ - Name: "fsflag", - OnlyOnce: true, - }, - &IntFlag{ - Name: "iflag", - }, - }, - Action: func(ctx *Context) error { - return nil - }, - } tests := []struct { name string @@ -3117,6 +3096,28 @@ func TestFlagDuplicates(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { + cmd := &Command{ + Flags: []Flag{ + &StringFlag{ + Name: "sflag", + OnlyOnce: true, + }, + &IntSliceFlag{ + Name: "isflag", + }, + &FloatSliceFlag{ + Name: "fsflag", + OnlyOnce: true, + }, + &IntFlag{ + Name: "iflag", + }, + }, + Action: func(ctx *Context) error { + return nil + }, + } + err := cmd.Run(buildTestContext(t), test.args) if test.errExpected && err == nil { t.Error("expected error") diff --git a/errors.go b/errors.go index 1bb53ff287..1178558756 100644 --- a/errors.go +++ b/errors.go @@ -199,3 +199,12 @@ func handleMultiError(multiErr MultiError) int { } return code } + +type typeError[T any] struct { + other any +} + +func (te *typeError[T]) Error() string { + var t T + return fmt.Sprintf("Expected type %T got instead %T", t, te.other) +} diff --git a/flag_impl.go b/flag_impl.go index b1663d4c13..257549ba3f 100644 --- a/flag_impl.go +++ b/flag_impl.go @@ -13,47 +13,35 @@ type Value interface { flag.Getter } -// simple wrapper to intercept Value operations -// to check for duplicates -type valueWrapper struct { - value Value - count int - onlyOnce bool +type boolFlag interface { + IsBoolFlag() bool } -func (v *valueWrapper) String() string { - if v.value == nil { - return "" - } - return v.value.String() +type fnValue struct { + fn func(string) error + isBool bool + v Value } -func (v *valueWrapper) Set(s string) error { - if v.count == 1 && v.onlyOnce { - return fmt.Errorf("cant duplicate this flag") +func (f *fnValue) Get() any { return f.v.Get() } +func (f *fnValue) Set(s string) error { return f.fn(s) } +func (f *fnValue) String() string { + if f.v == nil { + return "" } - v.count++ - return v.value.Set(s) -} - -func (v *valueWrapper) Get() any { - return v.value.Get() -} - -func (v *valueWrapper) IsBoolFlag() bool { - _, ok := v.value.(*boolValue) - return ok + return f.v.String() } -func (v *valueWrapper) Serialize() string { - if s, ok := v.value.(Serializer); ok { +func (f *fnValue) Serialize() string { + if s, ok := f.v.(Serializer); ok { return s.Serialize() } - return v.value.String() + return f.v.String() } -func (v *valueWrapper) Count() int { - if s, ok := v.value.(Countable); ok { +func (f *fnValue) IsBoolFlag() bool { return f.isBool } +func (f *fnValue) Count() int { + if s, ok := f.v.(Countable); ok { return s.Count() } return 0 @@ -105,7 +93,10 @@ type FlagBase[T any, C any, VC ValueCreator[T, C]] struct { OnlyOnce bool // whether this flag can be duplicated on the command line + Validator func(T) error // custom function to validate this flag value + // unexported fields for internal use + count int // number of times the flag has been set hasBeenSet bool // whether the flag has been set from env or file applied bool // whether the flag has been applied to a flag set already creator VC // value creator for this flag type @@ -160,15 +151,48 @@ func (f *FlagBase[T, C, V]) Apply(set *flag.FlagSet) error { } else { f.value = f.creator.Create(newVal, f.Destination, f.Config) } + + // Validate the given default or values set from external sources as well + if f.Validator != nil { + if v, ok := f.value.Get().(T); !ok { + return &typeError[T]{ + other: f.value.Get(), + } + } else if err := f.Validator(v); err != nil { + return err + } + } } - vw := &valueWrapper{ - value: f.value, - onlyOnce: f.OnlyOnce, + isBool := false + if b, ok := f.value.(boolFlag); ok && b.IsBoolFlag() { + isBool = true } for _, name := range f.Names() { - set.Var(vw, name, f.Usage) + set.Var(&fnValue{ + fn: func(val string) error { + if f.count == 1 && f.OnlyOnce { + return fmt.Errorf("cant duplicate this flag") + } + f.count++ + if err := f.value.Set(val); err != nil { + return err + } + if f.Validator != nil { + if v, ok := f.value.Get().(T); !ok { + return &typeError[T]{ + other: f.value.Get(), + } + } else if err := f.Validator(v); err != nil { + return err + } + } + return nil + }, + isBool: isBool, + v: f.value, + }, name, f.Usage) } f.applied = true diff --git a/flag_int_slice.go b/flag_int_slice.go index bea6b8724c..6146e937bc 100644 --- a/flag_int_slice.go +++ b/flag_int_slice.go @@ -1,7 +1,5 @@ package cli -import "flag" - type IntSlice = SliceBase[int64, IntegerConfig, intValue] type IntSliceFlag = FlagBase[[]int64, IntegerConfig, IntSlice] @@ -10,18 +8,9 @@ var NewIntSlice = NewSliceBase[int64, IntegerConfig, intValue] // IntSlice looks up the value of a local IntSliceFlag, returns // nil if not found func (cCtx *Context) IntSlice(name string) []int64 { - if fs := cCtx.lookupFlagSet(name); fs != nil { - return lookupIntSlice(name, fs) + if v, ok := cCtx.Value(name).([]int64); ok { + return v } - return nil -} -func lookupIntSlice(name string, set *flag.FlagSet) []int64 { - f := set.Lookup(name) - if f != nil { - if slice, ok := f.Value.(flag.Getter).Get().([]int64); ok { - return slice - } - } return nil } diff --git a/flag_string_map.go b/flag_string_map.go index f75ed37201..58f07f7965 100644 --- a/flag_string_map.go +++ b/flag_string_map.go @@ -1,7 +1,5 @@ package cli -import "flag" - type StringMap = MapBase[string, StringConfig, stringValue] type StringMapFlag = FlagBase[map[string]string, StringConfig, StringMap] @@ -10,18 +8,8 @@ var NewStringMap = NewMapBase[string, StringConfig, stringValue] // StringMap looks up the value of a local StringMapFlag, returns // nil if not found func (cCtx *Context) StringMap(name string) map[string]string { - if fs := cCtx.lookupFlagSet(name); fs != nil { - return lookupStringMap(name, fs) - } - return nil -} - -func lookupStringMap(name string, set *flag.FlagSet) map[string]string { - f := set.Lookup(name) - if f != nil { - if mapping, ok := f.Value.(flag.Getter).Get().(map[string]string); ok { - return mapping - } + if v, ok := cCtx.Value(name).(map[string]string); ok { + return v } return nil } diff --git a/flag_string_slice.go b/flag_string_slice.go index 19e159320b..89370fa795 100644 --- a/flag_string_slice.go +++ b/flag_string_slice.go @@ -1,9 +1,5 @@ package cli -import ( - "flag" -) - type StringSlice = SliceBase[string, StringConfig, stringValue] type StringSliceFlag = FlagBase[[]string, StringConfig, StringSlice] @@ -12,18 +8,8 @@ var NewStringSlice = NewSliceBase[string, StringConfig, stringValue] // StringSlice looks up the value of a local StringSliceFlag, returns // nil if not found func (cCtx *Context) StringSlice(name string) []string { - if fs := cCtx.lookupFlagSet(name); fs != nil { - return lookupStringSlice(name, fs) - } - return nil -} - -func lookupStringSlice(name string, set *flag.FlagSet) []string { - f := set.Lookup(name) - if f != nil { - if slice, ok := f.Value.(flag.Getter).Get().([]string); ok { - return slice - } + if v, ok := cCtx.Value(name).([]string); ok { + return v } return nil } diff --git a/flag_uint_slice.go b/flag_uint_slice.go index d36ef62572..b25c46f1c6 100644 --- a/flag_uint_slice.go +++ b/flag_uint_slice.go @@ -1,9 +1,5 @@ package cli -import ( - "flag" -) - type UintSlice = SliceBase[uint64, IntegerConfig, uintValue] type UintSliceFlag = FlagBase[[]uint64, IntegerConfig, UintSlice] @@ -12,18 +8,8 @@ var NewUintSlice = NewSliceBase[uint64, IntegerConfig, uintValue] // UintSlice looks up the value of a local UintSliceFlag, returns // nil if not found func (cCtx *Context) UintSlice(name string) []uint64 { - if fs := cCtx.lookupFlagSet(name); fs != nil { - return lookupUintSlice(name, fs) - } - return nil -} - -func lookupUintSlice(name string, set *flag.FlagSet) []uint64 { - f := set.Lookup(name) - if f != nil { - if slice, ok := f.Value.(flag.Getter).Get().([]uint64); ok { - return slice - } + if v, ok := cCtx.Value(name).([]uint64); ok { + return v } return nil } diff --git a/flag_validation_test.go b/flag_validation_test.go new file mode 100644 index 0000000000..1a84345660 --- /dev/null +++ b/flag_validation_test.go @@ -0,0 +1,114 @@ +package cli + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFlagDefaultValidation(t *testing.T) { + + cmd := &Command{ + Name: "foo", + Flags: []Flag{ + &IntFlag{ + Name: "if", + Value: 2, // this value should fail validation + Validator: func(i int64) error { + if (i >= 3 && i <= 10) || (i >= 20 && i <= 24) { + return nil + } + return fmt.Errorf("Value %d not in range [3,10] or [20,24]", i) + }, + }, + }, + } + + r := require.New(t) + + // Default value of flag is 2 which should fail validation + err := cmd.Run(buildTestContext(t), []string{"foo", "--if", "5"}) + r.Error(err) +} + +func TestFlagValidation(t *testing.T) { + + r := require.New(t) + + testCases := []struct { + name string + arg string + errExpected bool + }{ + { + name: "first range less than min", + arg: "2", + errExpected: true, + }, + { + name: "first range min", + arg: "3", + }, + { + name: "first range mid", + arg: "7", + }, + { + name: "first range max", + arg: "10", + }, + { + name: "first range greater than max", + arg: "15", + errExpected: true, + }, + { + name: "second range less than min", + arg: "19", + errExpected: true, + }, + { + name: "second range min", + arg: "20", + }, + { + name: "second range mid", + arg: "21", + }, + { + name: "second range max", + arg: "24", + }, + { + name: "second range greater than max", + arg: "27", + errExpected: true, + }, + } + + for _, testCase := range testCases { + cmd := &Command{ + Name: "foo", + Flags: []Flag{ + &IntFlag{ + Name: "it", + Value: 5, // note that this value should pass validation + Validator: func(i int64) error { + if (i >= 3 && i <= 10) || (i >= 20 && i <= 24) { + return nil + } + return fmt.Errorf("Value %d not in range [3,10]U[20,24]", i) + }, + }, + }, + } + + err := cmd.Run(buildTestContext(t), []string{"foo", "--it", testCase.arg}) + if !testCase.errExpected { + r.NoError(err) + } else { + r.Error(err) + } + } +} diff --git a/godoc-current.txt b/godoc-current.txt index 018e5ee2bb..82aa6f4c8c 100644 --- a/godoc-current.txt +++ b/godoc-current.txt @@ -644,6 +644,8 @@ type FlagBase[T any, C any, VC ValueCreator[T, C]] struct { OnlyOnce bool // whether this flag can be duplicated on the command line + Validator func(T) error // custom function to validate this flag value + // Has unexported fields. } FlagBase[T,C,VC] is a generic flag base which can be used as a boilerplate diff --git a/testdata/godoc-v3.x.txt b/testdata/godoc-v3.x.txt index 018e5ee2bb..82aa6f4c8c 100644 --- a/testdata/godoc-v3.x.txt +++ b/testdata/godoc-v3.x.txt @@ -644,6 +644,8 @@ type FlagBase[T any, C any, VC ValueCreator[T, C]] struct { OnlyOnce bool // whether this flag can be duplicated on the command line + Validator func(T) error // custom function to validate this flag value + // Has unexported fields. } FlagBase[T,C,VC] is a generic flag base which can be used as a boilerplate