Skip to content

Commit

Permalink
Merge pull request #1785 from dearchap/validation
Browse files Browse the repository at this point in the history
Feature: Add support for validation functions
  • Loading branch information
dearchap authored Jun 27, 2023
2 parents 4a9488f + 51cb2ef commit fc3b515
Show file tree
Hide file tree
Showing 10 changed files with 215 additions and 114 deletions.
43 changes: 22 additions & 21 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3068,27 +3068,6 @@ func TestPersistentFlag(t *testing.T) {
}

func TestFlagDuplicates(t *testing.T) {
cmd := &Command{
Flags: []Flag{
&StringFlag{
Name: "sflag",
OnlyOnce: true,
},
&IntSliceFlag{
Name: "isflag",
},
&FloatSliceFlag{
Name: "fsflag",
OnlyOnce: true,
},
&IntFlag{
Name: "iflag",
},
},
Action: func(ctx *Context) error {
return nil
},
}

tests := []struct {
name string
Expand Down Expand Up @@ -3117,6 +3096,28 @@ func TestFlagDuplicates(t *testing.T) {

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
cmd := &Command{
Flags: []Flag{
&StringFlag{
Name: "sflag",
OnlyOnce: true,
},
&IntSliceFlag{
Name: "isflag",
},
&FloatSliceFlag{
Name: "fsflag",
OnlyOnce: true,
},
&IntFlag{
Name: "iflag",
},
},
Action: func(ctx *Context) error {
return nil
},
}

err := cmd.Run(buildTestContext(t), test.args)
if test.errExpected && err == nil {
t.Error("expected error")
Expand Down
9 changes: 9 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,12 @@ func handleMultiError(multiErr MultiError) int {
}
return code
}

type typeError[T any] struct {
other any
}

func (te *typeError[T]) Error() string {
var t T
return fmt.Sprintf("Expected type %T got instead %T", t, te.other)
}
92 changes: 58 additions & 34 deletions flag_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,47 +13,35 @@ type Value interface {
flag.Getter
}

// simple wrapper to intercept Value operations
// to check for duplicates
type valueWrapper struct {
value Value
count int
onlyOnce bool
type boolFlag interface {
IsBoolFlag() bool
}

func (v *valueWrapper) String() string {
if v.value == nil {
return ""
}
return v.value.String()
type fnValue struct {
fn func(string) error
isBool bool
v Value
}

func (v *valueWrapper) Set(s string) error {
if v.count == 1 && v.onlyOnce {
return fmt.Errorf("cant duplicate this flag")
func (f *fnValue) Get() any { return f.v.Get() }
func (f *fnValue) Set(s string) error { return f.fn(s) }
func (f *fnValue) String() string {
if f.v == nil {
return ""
}
v.count++
return v.value.Set(s)
}

func (v *valueWrapper) Get() any {
return v.value.Get()
}

func (v *valueWrapper) IsBoolFlag() bool {
_, ok := v.value.(*boolValue)
return ok
return f.v.String()
}

func (v *valueWrapper) Serialize() string {
if s, ok := v.value.(Serializer); ok {
func (f *fnValue) Serialize() string {
if s, ok := f.v.(Serializer); ok {
return s.Serialize()
}
return v.value.String()
return f.v.String()
}

func (v *valueWrapper) Count() int {
if s, ok := v.value.(Countable); ok {
func (f *fnValue) IsBoolFlag() bool { return f.isBool }
func (f *fnValue) Count() int {
if s, ok := f.v.(Countable); ok {
return s.Count()
}
return 0
Expand Down Expand Up @@ -105,7 +93,10 @@ type FlagBase[T any, C any, VC ValueCreator[T, C]] struct {

OnlyOnce bool // whether this flag can be duplicated on the command line

Validator func(T) error // custom function to validate this flag value

// unexported fields for internal use
count int // number of times the flag has been set
hasBeenSet bool // whether the flag has been set from env or file
applied bool // whether the flag has been applied to a flag set already
creator VC // value creator for this flag type
Expand Down Expand Up @@ -160,15 +151,48 @@ func (f *FlagBase[T, C, V]) Apply(set *flag.FlagSet) error {
} else {
f.value = f.creator.Create(newVal, f.Destination, f.Config)
}

// Validate the given default or values set from external sources as well
if f.Validator != nil {
if v, ok := f.value.Get().(T); !ok {
return &typeError[T]{
other: f.value.Get(),
}
} else if err := f.Validator(v); err != nil {
return err
}
}
}

vw := &valueWrapper{
value: f.value,
onlyOnce: f.OnlyOnce,
isBool := false
if b, ok := f.value.(boolFlag); ok && b.IsBoolFlag() {
isBool = true
}

for _, name := range f.Names() {
set.Var(vw, name, f.Usage)
set.Var(&fnValue{
fn: func(val string) error {
if f.count == 1 && f.OnlyOnce {
return fmt.Errorf("cant duplicate this flag")
}
f.count++
if err := f.value.Set(val); err != nil {
return err
}
if f.Validator != nil {
if v, ok := f.value.Get().(T); !ok {
return &typeError[T]{
other: f.value.Get(),
}
} else if err := f.Validator(v); err != nil {
return err
}
}
return nil
},
isBool: isBool,
v: f.value,
}, name, f.Usage)
}

f.applied = true
Expand Down
15 changes: 2 additions & 13 deletions flag_int_slice.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package cli

import "flag"

type IntSlice = SliceBase[int64, IntegerConfig, intValue]
type IntSliceFlag = FlagBase[[]int64, IntegerConfig, IntSlice]

Expand All @@ -10,18 +8,9 @@ var NewIntSlice = NewSliceBase[int64, IntegerConfig, intValue]
// IntSlice looks up the value of a local IntSliceFlag, returns
// nil if not found
func (cCtx *Context) IntSlice(name string) []int64 {
if fs := cCtx.lookupFlagSet(name); fs != nil {
return lookupIntSlice(name, fs)
if v, ok := cCtx.Value(name).([]int64); ok {
return v
}
return nil
}

func lookupIntSlice(name string, set *flag.FlagSet) []int64 {
f := set.Lookup(name)
if f != nil {
if slice, ok := f.Value.(flag.Getter).Get().([]int64); ok {
return slice
}
}
return nil
}
16 changes: 2 additions & 14 deletions flag_string_map.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package cli

import "flag"

type StringMap = MapBase[string, StringConfig, stringValue]
type StringMapFlag = FlagBase[map[string]string, StringConfig, StringMap]

Expand All @@ -10,18 +8,8 @@ var NewStringMap = NewMapBase[string, StringConfig, stringValue]
// StringMap looks up the value of a local StringMapFlag, returns
// nil if not found
func (cCtx *Context) StringMap(name string) map[string]string {
if fs := cCtx.lookupFlagSet(name); fs != nil {
return lookupStringMap(name, fs)
}
return nil
}

func lookupStringMap(name string, set *flag.FlagSet) map[string]string {
f := set.Lookup(name)
if f != nil {
if mapping, ok := f.Value.(flag.Getter).Get().(map[string]string); ok {
return mapping
}
if v, ok := cCtx.Value(name).(map[string]string); ok {
return v
}
return nil
}
18 changes: 2 additions & 16 deletions flag_string_slice.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package cli

import (
"flag"
)

type StringSlice = SliceBase[string, StringConfig, stringValue]
type StringSliceFlag = FlagBase[[]string, StringConfig, StringSlice]

Expand All @@ -12,18 +8,8 @@ var NewStringSlice = NewSliceBase[string, StringConfig, stringValue]
// StringSlice looks up the value of a local StringSliceFlag, returns
// nil if not found
func (cCtx *Context) StringSlice(name string) []string {
if fs := cCtx.lookupFlagSet(name); fs != nil {
return lookupStringSlice(name, fs)
}
return nil
}

func lookupStringSlice(name string, set *flag.FlagSet) []string {
f := set.Lookup(name)
if f != nil {
if slice, ok := f.Value.(flag.Getter).Get().([]string); ok {
return slice
}
if v, ok := cCtx.Value(name).([]string); ok {
return v
}
return nil
}
18 changes: 2 additions & 16 deletions flag_uint_slice.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package cli

import (
"flag"
)

type UintSlice = SliceBase[uint64, IntegerConfig, uintValue]
type UintSliceFlag = FlagBase[[]uint64, IntegerConfig, UintSlice]

Expand All @@ -12,18 +8,8 @@ var NewUintSlice = NewSliceBase[uint64, IntegerConfig, uintValue]
// UintSlice looks up the value of a local UintSliceFlag, returns
// nil if not found
func (cCtx *Context) UintSlice(name string) []uint64 {
if fs := cCtx.lookupFlagSet(name); fs != nil {
return lookupUintSlice(name, fs)
}
return nil
}

func lookupUintSlice(name string, set *flag.FlagSet) []uint64 {
f := set.Lookup(name)
if f != nil {
if slice, ok := f.Value.(flag.Getter).Get().([]uint64); ok {
return slice
}
if v, ok := cCtx.Value(name).([]uint64); ok {
return v
}
return nil
}
Loading

0 comments on commit fc3b515

Please sign in to comment.