diff --git a/flagx.go b/flagx.go index 725ad89..b40c711 100644 --- a/flagx.go +++ b/flagx.go @@ -3,7 +3,10 @@ package flagx import ( "encoding" "flag" + "fmt" "io" + "reflect" + "strings" "time" ) @@ -11,14 +14,18 @@ import ( // Flag names must be unique within a FlagSet. // An attempt to define a flag whose name is already in use will cause a panic. type FlagSet struct { - fs *flag.FlagSet + fs *flag.FlagSet + aliases map[string]string // a mapping from a flag's name to its alias, empty value means no alias is defined. } // NewFlagSet returns new FlagSet. func NewFlagSet(name string, output io.Writer) *FlagSet { fs := flag.NewFlagSet(name, flag.ContinueOnError) fs.SetOutput(output) - return &FlagSet{fs: fs} + return &FlagSet{ + fs: fs, + aliases: make(map[string]string), + } } // AsStdlib returns *flag.FlagSet with all flags. @@ -49,14 +56,12 @@ func (f *FlagSet) IsParsed() bool { return f.fs.Parsed() } // are defined and before flags are accessed by the program. // The return value will be ErrHelp if -help or -h were set but not defined. func (f *FlagSet) Parse(arguments []string) error { return f.fs.Parse(arguments) } +func (f *FlagSet) VisitAll(fn func(*flag.Flag)) { f.fs.VisitAll(fn) } +func (f *FlagSet) Visit(fn func(*flag.Flag)) { f.fs.Visit(fn) } +func (f *FlagSet) Lookup(name string) *flag.Flag { return f.fs.Lookup(name) } +func (f *FlagSet) Set(name, value string) error { return f.fs.Set(name, value) } -func (f *FlagSet) PrintDefaults() { f.fs.PrintDefaults() } -func (f *FlagSet) VisitAll(fn func(*flag.Flag)) { f.fs.VisitAll(fn) } -func (f *FlagSet) Visit(fn func(*flag.Flag)) { f.fs.Visit(fn) } -func (f *FlagSet) Lookup(name string) *flag.Flag { return f.fs.Lookup(name) } -func (f *FlagSet) Set(name, value string) error { return f.fs.Set(name, value) } - -// defines a flag with the specified name and usage string. The type and +// Var defines a flag with the specified name and usage string. The type and // value of the flag are represented by the first argument, of type Value, which // typically holds a user-defined implementation of Value. For instance, the // caller could create a flag that turns a comma-separated string into a slice @@ -64,6 +69,7 @@ func (f *FlagSet) Set(name, value string) error { return f.fs.Set(name, value) // decompose the comma-separated string into the slice. // Empty string for alias means no alias will be created. func (f *FlagSet) Var(value flag.Value, name, alias, usage string) { + f.aliases[name] = alias f.fs.Var(value, name, usage) if alias != "" { f.fs.Var(value, alias, usage) @@ -75,9 +81,10 @@ func (f *FlagSet) Var(value flag.Value, name, alias, usage string) { // If fn returns a non-nil error, it will be treated as a flag value parsing error. // Empty string for alias means no alias will be created. func (f *FlagSet) Func(name, alias, usage string, fn func(string) error) { + f.aliases[name] = alias f.fs.Func(name, usage, fn) if alias != "" { - f.fs.Func(name, alias, fn) + f.fs.Func(alias, usage, fn) } } @@ -89,6 +96,7 @@ func (f *FlagSet) Func(name, alias, usage string, fn func(string) error) { // Empty string for alias means no alias will be created. func (f *FlagSet) Text(p encoding.TextUnmarshaler, name, alias string, value encoding.TextMarshaler, usage string) { // TODO(cristaloleg): for Go 1.19 this can be f.fs.TextVar(...) + f.aliases[name] = alias f.fs.Var(newTextValue(value, p), name, usage) if alias != "" { f.fs.Var(newTextValue(value, p), alias, usage) @@ -99,6 +107,7 @@ func (f *FlagSet) Text(p encoding.TextUnmarshaler, name, alias string, value enc // The argument p points to a bool variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) Bool(p *bool, name, alias string, value bool, usage string) { + f.aliases[name] = alias f.fs.BoolVar(p, name, value, usage) if alias != "" { f.fs.BoolVar(p, alias, value, usage) @@ -109,6 +118,7 @@ func (f *FlagSet) Bool(p *bool, name, alias string, value bool, usage string) { // The argument p points to an int variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) Int(p *int, name, alias string, value int, usage string) { + f.aliases[name] = alias f.fs.IntVar(p, name, value, usage) if alias != "" { f.fs.IntVar(p, alias, value, usage) @@ -119,6 +129,7 @@ func (f *FlagSet) Int(p *int, name, alias string, value int, usage string) { // The argument p points to an int64 variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) Int64(p *int64, name, alias string, value int64, usage string) { + f.aliases[name] = alias f.fs.Int64Var(p, name, value, usage) if alias != "" { f.fs.Int64Var(p, alias, value, usage) @@ -129,6 +140,7 @@ func (f *FlagSet) Int64(p *int64, name, alias string, value int64, usage string) // The argument p points to a uint variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) Uint(p *uint, name, alias string, value uint, usage string) { + f.aliases[name] = alias f.fs.UintVar(p, name, value, usage) if alias != "" { f.fs.UintVar(p, alias, value, usage) @@ -139,6 +151,7 @@ func (f *FlagSet) Uint(p *uint, name, alias string, value uint, usage string) { // The argument p points to a uint64 variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) Uint64(p *uint64, name, alias string, value uint64, usage string) { + f.aliases[name] = alias f.fs.Uint64Var(p, name, value, usage) if alias != "" { f.fs.Uint64Var(p, alias, value, usage) @@ -149,6 +162,7 @@ func (f *FlagSet) Uint64(p *uint64, name, alias string, value uint64, usage stri // The argument p points to a string variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) String(p *string, name, alias, value, usage string) { + f.aliases[name] = alias f.fs.StringVar(p, name, value, usage) if alias != "" { f.fs.StringVar(p, alias, value, usage) @@ -159,6 +173,7 @@ func (f *FlagSet) String(p *string, name, alias, value, usage string) { // The argument p points to a float64 variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) Float64(p *float64, name, alias string, value float64, usage string) { + f.aliases[name] = alias f.fs.Float64Var(p, name, value, usage) if alias != "" { f.fs.Float64Var(p, alias, value, usage) @@ -170,6 +185,7 @@ func (f *FlagSet) Float64(p *float64, name, alias string, value float64, usage s // The flag accepts a value acceptable to time.ParseDuration. // Empty string for alias means no alias will be created. func (f *FlagSet) Duration(p *time.Duration, name, alias string, value time.Duration, usage string) { + f.aliases[name] = alias f.fs.DurationVar(p, name, value, usage) if alias != "" { f.fs.DurationVar(p, alias, value, usage) @@ -180,6 +196,7 @@ func (f *FlagSet) Duration(p *time.Duration, name, alias string, value time.Dura // The argument p points to a bool variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) BoolSlice(p *[]bool, name, alias string, value []bool, usage string) { + f.aliases[name] = alias var sb SliceOfBool *p = []bool(sb) f.fs.Var(&sb, name, usage) @@ -192,6 +209,7 @@ func (f *FlagSet) BoolSlice(p *[]bool, name, alias string, value []bool, usage s // The argument p points to an int variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) IntSlice(p *[]int, name, alias string, value []int, usage string) { + f.aliases[name] = alias var si SliceOfInt = value *p = []int(si) f.fs.Var(&si, name, usage) @@ -204,6 +222,7 @@ func (f *FlagSet) IntSlice(p *[]int, name, alias string, value []int, usage stri // The argument p points to an int64 variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) Int64Slice(p *[]int64, name, alias string, value []int64, usage string) { + f.aliases[name] = alias var si SliceOfInt64 = value *p = []int64(si) f.fs.Var(&si, name, usage) @@ -216,6 +235,7 @@ func (f *FlagSet) Int64Slice(p *[]int64, name, alias string, value []int64, usag // The argument p points to a uint variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) UintSlice(p *[]uint, name, alias string, value []uint, usage string) { + f.aliases[name] = alias var su SliceOfUint = value *p = []uint(su) f.fs.Var(&su, name, usage) @@ -228,6 +248,7 @@ func (f *FlagSet) UintSlice(p *[]uint, name, alias string, value []uint, usage s // The argument p points to a uint64 variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) Uint64Slice(p *[]uint64, name, alias string, value []uint64, usage string) { + f.aliases[name] = alias var su SliceOfUint64 = value *p = []uint64(su) f.fs.Var(&su, name, usage) @@ -240,6 +261,7 @@ func (f *FlagSet) Uint64Slice(p *[]uint64, name, alias string, value []uint64, u // The argument p points to a string variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) StringSlice(p *[]string, name, alias string, value []string, usage string) { + f.aliases[name] = alias var ss SliceOfString = value *p = []string(ss) f.fs.Var(&ss, name, usage) @@ -252,6 +274,7 @@ func (f *FlagSet) StringSlice(p *[]string, name, alias string, value []string, u // The argument p points to a float64 variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) Float64Slice(p *[]float64, name, alias string, value []float64, usage string) { + f.aliases[name] = alias var sf SliceOfFloat64 = value *p = []float64(sf) f.fs.Var(&sf, name, usage) @@ -265,6 +288,7 @@ func (f *FlagSet) Float64Slice(p *[]float64, name, alias string, value []float64 // The flag accepts a value acceptable to time.ParseDuration. // Empty string for alias means no alias will be created. func (f *FlagSet) DurationSlice(p *[]time.Duration, name, alias string, value []time.Duration, usage string) { + f.aliases[name] = alias var sd SliceOfDuration = value *p = []time.Duration(sd) f.fs.Var(&sd, name, usage) @@ -277,6 +301,7 @@ func (f *FlagSet) DurationSlice(p *[]time.Duration, name, alias string, value [] // The argument p points to an int variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) IntSet(p *map[int]struct{}, name, alias string, value map[int]struct{}, usage string) { + f.aliases[name] = alias var si SetOfInt = value *p = map[int]struct{}(si) f.fs.Var(&si, name, usage) @@ -289,6 +314,7 @@ func (f *FlagSet) IntSet(p *map[int]struct{}, name, alias string, value map[int] // The argument p points to an int64 variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) Int64Set(p *map[int64]struct{}, name, alias string, value map[int64]struct{}, usage string) { + f.aliases[name] = alias var si SetOfInt64 = value *p = map[int64]struct{}(si) f.fs.Var(&si, name, usage) @@ -301,6 +327,7 @@ func (f *FlagSet) Int64Set(p *map[int64]struct{}, name, alias string, value map[ // The argument p points to a uint variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) UintSet(p *map[uint]struct{}, name, alias string, value map[uint]struct{}, usage string) { + f.aliases[name] = alias var su SetOfUint = value *p = map[uint]struct{}(su) f.fs.Var(&su, name, usage) @@ -313,6 +340,7 @@ func (f *FlagSet) UintSet(p *map[uint]struct{}, name, alias string, value map[ui // The argument p points to a uint64 variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) Uint64Set(p *map[uint64]struct{}, name, alias string, value map[uint64]struct{}, usage string) { + f.aliases[name] = alias var su SetOfUint64 = value *p = map[uint64]struct{}(su) f.fs.Var(&su, name, usage) @@ -325,6 +353,7 @@ func (f *FlagSet) Uint64Set(p *map[uint64]struct{}, name, alias string, value ma // The argument p points to a string variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) StringSet(p *map[string]struct{}, name, alias string, value map[string]struct{}, usage string) { + f.aliases[name] = alias var ss SetOfString = value *p = map[string]struct{}(ss) f.fs.Var(&ss, name, usage) @@ -337,6 +366,7 @@ func (f *FlagSet) StringSet(p *map[string]struct{}, name, alias string, value ma // The argument p points to a float64 variable in which to store the value of the flag. // Empty string for alias means no alias will be created. func (f *FlagSet) Float64Set(p *map[float64]struct{}, name, alias string, value map[float64]struct{}, usage string) { + f.aliases[name] = alias var sf SetOfFloat64 = value *p = map[float64]struct{}(sf) f.fs.Var(&sf, name, usage) @@ -350,6 +380,7 @@ func (f *FlagSet) Float64Set(p *map[float64]struct{}, name, alias string, value // The flag accepts a value acceptable to time.ParseDuration. // Empty string for alias means no alias will be created. func (f *FlagSet) DurationSet(p *map[time.Duration]struct{}, name, alias string, value map[time.Duration]struct{}, usage string) { + f.aliases[name] = alias var sd SetOfDuration = value *p = map[time.Duration]struct{}(sd) f.fs.Var(&sd, name, usage) @@ -357,3 +388,88 @@ func (f *FlagSet) DurationSet(p *map[time.Duration]struct{}, name, alias string, f.fs.Var(&sd, alias, usage) } } + +// PrintDefaults prints, to standard error unless configured otherwise, the +// default values of all defined command-line flags in the set. +func (f *FlagSet) PrintDefaults() { + // NOTE(junk1tm): copy-pasted from flag.PrintDefaults with a few modifications to support aliases. + + var isZeroValueErrs []error + f.VisitAll(func(fl *flag.Flag) { + if _, ok := f.aliases[fl.Name]; !ok { + // The flag is an alias, do not print it separately. + return + } + var b strings.Builder + fmt.Fprintf(&b, " -%s", fl.Name) // Two spaces before -; see next two comments. + if alias := f.aliases[fl.Name]; alias != "" { + fmt.Fprintf(&b, " (-%s)", alias) + } + name, usage := flag.UnquoteUsage(fl) + if len(name) > 0 { + b.WriteString(" ") + b.WriteString(name) + } + // Boolean flags of one ASCII letter are so common we + // treat them specially, putting their usage on the same line. + if b.Len() <= 4 { // space, space, '-', 'x'. + b.WriteString("\t") + } else { + // Four spaces before the tab triggers good alignment + // for both 4- and 8-space tab stops. + b.WriteString("\n \t") + } + b.WriteString(strings.ReplaceAll(usage, "\n", "\n \t")) + + // Print the default value only if it differs to the zero value + // for this flag type. + if isZero, err := isZeroValue(fl, fl.DefValue); err != nil { + isZeroValueErrs = append(isZeroValueErrs, err) + } else if !isZero { + // HACK(junk1tm): flag.stringValue is unexported, so we have to compare the type's name. + if fmt.Sprintf("%T", fl.Value) == "*flag.stringValue" { + // put quotes on the value + fmt.Fprintf(&b, " (default %q)", fl.DefValue) + } else { + fmt.Fprintf(&b, " (default %v)", fl.DefValue) + } + } + fmt.Fprint(f.fs.Output(), b.String(), "\n") + }) + // If calling String on any zero flag.Values triggered a panic, print + // the messages after the full set of defaults so that the programmer + // knows to fix the panic. + if errs := isZeroValueErrs; len(errs) > 0 { + fmt.Fprintln(f.fs.Output()) + for _, err := range errs { + fmt.Fprintln(f.fs.Output(), err) + } + } +} + +func isZeroValue(fl *flag.Flag, value string) (ok bool, err error) { + // NOTE(junk1tm): copy-pasted from flag.isZeroValue as a part of flag.PrintDefaults. + + // Build a zero value of the flag's Value type, and see if the + // result of calling its String method equals the value passed in. + // This works unless the Value type is itself an interface type. + typ := reflect.TypeOf(fl.Value) + var z reflect.Value + if typ.Kind() == reflect.Pointer { + z = reflect.New(typ.Elem()) + } else { + z = reflect.Zero(typ) + } + // Catch panics calling the String method, which shouldn't prevent the + // usage message from being printed, but that we should report to the + // user so that they know to fix their code. + defer func() { + if e := recover(); e != nil { + if typ.Kind() == reflect.Pointer { + typ = typ.Elem() + } + err = fmt.Errorf("panic calling String method on zero %v for flag %s: %v", typ, fl.Name, e) + } + }() + return value == z.Interface().(flag.Value).String(), nil +} diff --git a/flagx_test.go b/flagx_test.go index 17d70b6..8b71d2c 100644 --- a/flagx_test.go +++ b/flagx_test.go @@ -1,6 +1,7 @@ package flagx import ( + "bytes" "flag" "os" "reflect" @@ -34,6 +35,17 @@ func TestFlagSet(t *testing.T) { mustEqual(t, offsets, wantOffsets) } +func TestFlagSet_PrintDefaults(t *testing.T) { + const usage = ` -timeout (-t) duration + just a timeout (default 10s) +` + var buf bytes.Buffer + fset := NewFlagSet("testing", &buf) + fset.Duration(new(time.Duration), "timeout", "t", 10*time.Second, "just a timeout") + fset.PrintDefaults() + mustEqual(t, buf.String(), usage) +} + func failIfErr(t testing.TB, err error) { t.Helper() if err != nil {