Skip to content

Commit

Permalink
Merge pull request #3 from junk1tm/better-print-defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
cristaloleg authored Oct 30, 2022
2 parents 499df2b + 5d104da commit 9da1a13
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 10 deletions.
136 changes: 126 additions & 10 deletions flagx.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,29 @@ package flagx
import (
"encoding"
"flag"
"fmt"
"io"
"reflect"
"strings"
"time"
)

// FlagSet represents a set of flags.
// Flag names must be unique within a FlagSet.
// An attempt to define a flag whose name is already in use will cause a panic.
type FlagSet struct {
fs *flag.FlagSet
fs *flag.FlagSet
aliases map[string]string // a mapping from a flag's name to its alias, empty value means no alias is defined.
}

// NewFlagSet returns new FlagSet.
func NewFlagSet(name string, output io.Writer) *FlagSet {
fs := flag.NewFlagSet(name, flag.ContinueOnError)
fs.SetOutput(output)
return &FlagSet{fs: fs}
return &FlagSet{
fs: fs,
aliases: make(map[string]string),
}
}

// AsStdlib returns *flag.FlagSet with all flags.
Expand Down Expand Up @@ -49,21 +56,20 @@ func (f *FlagSet) IsParsed() bool { return f.fs.Parsed() }
// are defined and before flags are accessed by the program.
// The return value will be ErrHelp if -help or -h were set but not defined.
func (f *FlagSet) Parse(arguments []string) error { return f.fs.Parse(arguments) }
func (f *FlagSet) VisitAll(fn func(*flag.Flag)) { f.fs.VisitAll(fn) }
func (f *FlagSet) Visit(fn func(*flag.Flag)) { f.fs.Visit(fn) }
func (f *FlagSet) Lookup(name string) *flag.Flag { return f.fs.Lookup(name) }
func (f *FlagSet) Set(name, value string) error { return f.fs.Set(name, value) }

func (f *FlagSet) PrintDefaults() { f.fs.PrintDefaults() }
func (f *FlagSet) VisitAll(fn func(*flag.Flag)) { f.fs.VisitAll(fn) }
func (f *FlagSet) Visit(fn func(*flag.Flag)) { f.fs.Visit(fn) }
func (f *FlagSet) Lookup(name string) *flag.Flag { return f.fs.Lookup(name) }
func (f *FlagSet) Set(name, value string) error { return f.fs.Set(name, value) }

// defines a flag with the specified name and usage string. The type and
// Var defines a flag with the specified name and usage string. The type and
// value of the flag are represented by the first argument, of type Value, which
// typically holds a user-defined implementation of Value. For instance, the
// caller could create a flag that turns a comma-separated string into a slice
// of strings by giving the slice the methods of Value; in particular, Set would
// decompose the comma-separated string into the slice.
// Empty string for alias means no alias will be created.
func (f *FlagSet) Var(value flag.Value, name, alias, usage string) {
f.aliases[name] = alias
f.fs.Var(value, name, usage)
if alias != "" {
f.fs.Var(value, alias, usage)
Expand All @@ -75,9 +81,10 @@ func (f *FlagSet) Var(value flag.Value, name, alias, usage string) {
// If fn returns a non-nil error, it will be treated as a flag value parsing error.
// Empty string for alias means no alias will be created.
func (f *FlagSet) Func(name, alias, usage string, fn func(string) error) {
f.aliases[name] = alias
f.fs.Func(name, usage, fn)
if alias != "" {
f.fs.Func(name, alias, fn)
f.fs.Func(alias, usage, fn)
}
}

Expand All @@ -89,6 +96,7 @@ func (f *FlagSet) Func(name, alias, usage string, fn func(string) error) {
// Empty string for alias means no alias will be created.
func (f *FlagSet) Text(p encoding.TextUnmarshaler, name, alias string, value encoding.TextMarshaler, usage string) {
// TODO(cristaloleg): for Go 1.19 this can be f.fs.TextVar(...)
f.aliases[name] = alias
f.fs.Var(newTextValue(value, p), name, usage)
if alias != "" {
f.fs.Var(newTextValue(value, p), alias, usage)
Expand All @@ -99,6 +107,7 @@ func (f *FlagSet) Text(p encoding.TextUnmarshaler, name, alias string, value enc
// The argument p points to a bool variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) Bool(p *bool, name, alias string, value bool, usage string) {
f.aliases[name] = alias
f.fs.BoolVar(p, name, value, usage)
if alias != "" {
f.fs.BoolVar(p, alias, value, usage)
Expand All @@ -109,6 +118,7 @@ func (f *FlagSet) Bool(p *bool, name, alias string, value bool, usage string) {
// The argument p points to an int variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) Int(p *int, name, alias string, value int, usage string) {
f.aliases[name] = alias
f.fs.IntVar(p, name, value, usage)
if alias != "" {
f.fs.IntVar(p, alias, value, usage)
Expand All @@ -119,6 +129,7 @@ func (f *FlagSet) Int(p *int, name, alias string, value int, usage string) {
// The argument p points to an int64 variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) Int64(p *int64, name, alias string, value int64, usage string) {
f.aliases[name] = alias
f.fs.Int64Var(p, name, value, usage)
if alias != "" {
f.fs.Int64Var(p, alias, value, usage)
Expand All @@ -129,6 +140,7 @@ func (f *FlagSet) Int64(p *int64, name, alias string, value int64, usage string)
// The argument p points to a uint variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) Uint(p *uint, name, alias string, value uint, usage string) {
f.aliases[name] = alias
f.fs.UintVar(p, name, value, usage)
if alias != "" {
f.fs.UintVar(p, alias, value, usage)
Expand All @@ -139,6 +151,7 @@ func (f *FlagSet) Uint(p *uint, name, alias string, value uint, usage string) {
// The argument p points to a uint64 variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) Uint64(p *uint64, name, alias string, value uint64, usage string) {
f.aliases[name] = alias
f.fs.Uint64Var(p, name, value, usage)
if alias != "" {
f.fs.Uint64Var(p, alias, value, usage)
Expand All @@ -149,6 +162,7 @@ func (f *FlagSet) Uint64(p *uint64, name, alias string, value uint64, usage stri
// The argument p points to a string variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) String(p *string, name, alias, value, usage string) {
f.aliases[name] = alias
f.fs.StringVar(p, name, value, usage)
if alias != "" {
f.fs.StringVar(p, alias, value, usage)
Expand All @@ -159,6 +173,7 @@ func (f *FlagSet) String(p *string, name, alias, value, usage string) {
// The argument p points to a float64 variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) Float64(p *float64, name, alias string, value float64, usage string) {
f.aliases[name] = alias
f.fs.Float64Var(p, name, value, usage)
if alias != "" {
f.fs.Float64Var(p, alias, value, usage)
Expand All @@ -170,6 +185,7 @@ func (f *FlagSet) Float64(p *float64, name, alias string, value float64, usage s
// The flag accepts a value acceptable to time.ParseDuration.
// Empty string for alias means no alias will be created.
func (f *FlagSet) Duration(p *time.Duration, name, alias string, value time.Duration, usage string) {
f.aliases[name] = alias
f.fs.DurationVar(p, name, value, usage)
if alias != "" {
f.fs.DurationVar(p, alias, value, usage)
Expand All @@ -180,6 +196,7 @@ func (f *FlagSet) Duration(p *time.Duration, name, alias string, value time.Dura
// The argument p points to a bool variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) BoolSlice(p *[]bool, name, alias string, value []bool, usage string) {
f.aliases[name] = alias
var sb SliceOfBool
*p = []bool(sb)
f.fs.Var(&sb, name, usage)
Expand All @@ -192,6 +209,7 @@ func (f *FlagSet) BoolSlice(p *[]bool, name, alias string, value []bool, usage s
// The argument p points to an int variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) IntSlice(p *[]int, name, alias string, value []int, usage string) {
f.aliases[name] = alias
var si SliceOfInt = value
*p = []int(si)
f.fs.Var(&si, name, usage)
Expand All @@ -204,6 +222,7 @@ func (f *FlagSet) IntSlice(p *[]int, name, alias string, value []int, usage stri
// The argument p points to an int64 variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) Int64Slice(p *[]int64, name, alias string, value []int64, usage string) {
f.aliases[name] = alias
var si SliceOfInt64 = value
*p = []int64(si)
f.fs.Var(&si, name, usage)
Expand All @@ -216,6 +235,7 @@ func (f *FlagSet) Int64Slice(p *[]int64, name, alias string, value []int64, usag
// The argument p points to a uint variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) UintSlice(p *[]uint, name, alias string, value []uint, usage string) {
f.aliases[name] = alias
var su SliceOfUint = value
*p = []uint(su)
f.fs.Var(&su, name, usage)
Expand All @@ -228,6 +248,7 @@ func (f *FlagSet) UintSlice(p *[]uint, name, alias string, value []uint, usage s
// The argument p points to a uint64 variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) Uint64Slice(p *[]uint64, name, alias string, value []uint64, usage string) {
f.aliases[name] = alias
var su SliceOfUint64 = value
*p = []uint64(su)
f.fs.Var(&su, name, usage)
Expand All @@ -240,6 +261,7 @@ func (f *FlagSet) Uint64Slice(p *[]uint64, name, alias string, value []uint64, u
// The argument p points to a string variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) StringSlice(p *[]string, name, alias string, value []string, usage string) {
f.aliases[name] = alias
var ss SliceOfString = value
*p = []string(ss)
f.fs.Var(&ss, name, usage)
Expand All @@ -252,6 +274,7 @@ func (f *FlagSet) StringSlice(p *[]string, name, alias string, value []string, u
// The argument p points to a float64 variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) Float64Slice(p *[]float64, name, alias string, value []float64, usage string) {
f.aliases[name] = alias
var sf SliceOfFloat64 = value
*p = []float64(sf)
f.fs.Var(&sf, name, usage)
Expand All @@ -265,6 +288,7 @@ func (f *FlagSet) Float64Slice(p *[]float64, name, alias string, value []float64
// The flag accepts a value acceptable to time.ParseDuration.
// Empty string for alias means no alias will be created.
func (f *FlagSet) DurationSlice(p *[]time.Duration, name, alias string, value []time.Duration, usage string) {
f.aliases[name] = alias
var sd SliceOfDuration = value
*p = []time.Duration(sd)
f.fs.Var(&sd, name, usage)
Expand All @@ -277,6 +301,7 @@ func (f *FlagSet) DurationSlice(p *[]time.Duration, name, alias string, value []
// The argument p points to an int variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) IntSet(p *map[int]struct{}, name, alias string, value map[int]struct{}, usage string) {
f.aliases[name] = alias
var si SetOfInt = value
*p = map[int]struct{}(si)
f.fs.Var(&si, name, usage)
Expand All @@ -289,6 +314,7 @@ func (f *FlagSet) IntSet(p *map[int]struct{}, name, alias string, value map[int]
// The argument p points to an int64 variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) Int64Set(p *map[int64]struct{}, name, alias string, value map[int64]struct{}, usage string) {
f.aliases[name] = alias
var si SetOfInt64 = value
*p = map[int64]struct{}(si)
f.fs.Var(&si, name, usage)
Expand All @@ -301,6 +327,7 @@ func (f *FlagSet) Int64Set(p *map[int64]struct{}, name, alias string, value map[
// The argument p points to a uint variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) UintSet(p *map[uint]struct{}, name, alias string, value map[uint]struct{}, usage string) {
f.aliases[name] = alias
var su SetOfUint = value
*p = map[uint]struct{}(su)
f.fs.Var(&su, name, usage)
Expand All @@ -313,6 +340,7 @@ func (f *FlagSet) UintSet(p *map[uint]struct{}, name, alias string, value map[ui
// The argument p points to a uint64 variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) Uint64Set(p *map[uint64]struct{}, name, alias string, value map[uint64]struct{}, usage string) {
f.aliases[name] = alias
var su SetOfUint64 = value
*p = map[uint64]struct{}(su)
f.fs.Var(&su, name, usage)
Expand All @@ -325,6 +353,7 @@ func (f *FlagSet) Uint64Set(p *map[uint64]struct{}, name, alias string, value ma
// The argument p points to a string variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) StringSet(p *map[string]struct{}, name, alias string, value map[string]struct{}, usage string) {
f.aliases[name] = alias
var ss SetOfString = value
*p = map[string]struct{}(ss)
f.fs.Var(&ss, name, usage)
Expand All @@ -337,6 +366,7 @@ func (f *FlagSet) StringSet(p *map[string]struct{}, name, alias string, value ma
// The argument p points to a float64 variable in which to store the value of the flag.
// Empty string for alias means no alias will be created.
func (f *FlagSet) Float64Set(p *map[float64]struct{}, name, alias string, value map[float64]struct{}, usage string) {
f.aliases[name] = alias
var sf SetOfFloat64 = value
*p = map[float64]struct{}(sf)
f.fs.Var(&sf, name, usage)
Expand All @@ -350,10 +380,96 @@ func (f *FlagSet) Float64Set(p *map[float64]struct{}, name, alias string, value
// The flag accepts a value acceptable to time.ParseDuration.
// Empty string for alias means no alias will be created.
func (f *FlagSet) DurationSet(p *map[time.Duration]struct{}, name, alias string, value map[time.Duration]struct{}, usage string) {
f.aliases[name] = alias
var sd SetOfDuration = value
*p = map[time.Duration]struct{}(sd)
f.fs.Var(&sd, name, usage)
if alias != "" {
f.fs.Var(&sd, alias, usage)
}
}

// PrintDefaults prints, to standard error unless configured otherwise, the
// default values of all defined command-line flags in the set.
func (f *FlagSet) PrintDefaults() {
// NOTE(junk1tm): copy-pasted from flag.PrintDefaults with a few modifications to support aliases.

var isZeroValueErrs []error
f.VisitAll(func(fl *flag.Flag) {
if _, ok := f.aliases[fl.Name]; !ok {
// The flag is an alias, do not print it separately.
return
}
var b strings.Builder
fmt.Fprintf(&b, " -%s", fl.Name) // Two spaces before -; see next two comments.
if alias := f.aliases[fl.Name]; alias != "" {
fmt.Fprintf(&b, " (-%s)", alias)
}
name, usage := flag.UnquoteUsage(fl)
if len(name) > 0 {
b.WriteString(" ")
b.WriteString(name)
}
// Boolean flags of one ASCII letter are so common we
// treat them specially, putting their usage on the same line.
if b.Len() <= 4 { // space, space, '-', 'x'.
b.WriteString("\t")
} else {
// Four spaces before the tab triggers good alignment
// for both 4- and 8-space tab stops.
b.WriteString("\n \t")
}
b.WriteString(strings.ReplaceAll(usage, "\n", "\n \t"))

// Print the default value only if it differs to the zero value
// for this flag type.
if isZero, err := isZeroValue(fl, fl.DefValue); err != nil {
isZeroValueErrs = append(isZeroValueErrs, err)
} else if !isZero {
// HACK(junk1tm): flag.stringValue is unexported, so we have to compare the type's name.
if fmt.Sprintf("%T", fl.Value) == "*flag.stringValue" {
// put quotes on the value
fmt.Fprintf(&b, " (default %q)", fl.DefValue)
} else {
fmt.Fprintf(&b, " (default %v)", fl.DefValue)
}
}
fmt.Fprint(f.fs.Output(), b.String(), "\n")
})
// If calling String on any zero flag.Values triggered a panic, print
// the messages after the full set of defaults so that the programmer
// knows to fix the panic.
if errs := isZeroValueErrs; len(errs) > 0 {
fmt.Fprintln(f.fs.Output())
for _, err := range errs {
fmt.Fprintln(f.fs.Output(), err)
}
}
}

func isZeroValue(fl *flag.Flag, value string) (ok bool, err error) {
// NOTE(junk1tm): copy-pasted from flag.isZeroValue as a part of flag.PrintDefaults.

// Build a zero value of the flag's Value type, and see if the
// result of calling its String method equals the value passed in.
// This works unless the Value type is itself an interface type.
typ := reflect.TypeOf(fl.Value)
var z reflect.Value
if typ.Kind() == reflect.Pointer {
z = reflect.New(typ.Elem())
} else {
z = reflect.Zero(typ)
}
// Catch panics calling the String method, which shouldn't prevent the
// usage message from being printed, but that we should report to the
// user so that they know to fix their code.
defer func() {
if e := recover(); e != nil {
if typ.Kind() == reflect.Pointer {
typ = typ.Elem()
}
err = fmt.Errorf("panic calling String method on zero %v for flag %s: %v", typ, fl.Name, e)
}
}()
return value == z.Interface().(flag.Value).String(), nil
}
12 changes: 12 additions & 0 deletions flagx_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package flagx

import (
"bytes"
"flag"
"os"
"reflect"
Expand Down Expand Up @@ -34,6 +35,17 @@ func TestFlagSet(t *testing.T) {
mustEqual(t, offsets, wantOffsets)
}

func TestFlagSet_PrintDefaults(t *testing.T) {
const usage = ` -timeout (-t) duration
just a timeout (default 10s)
`
var buf bytes.Buffer
fset := NewFlagSet("testing", &buf)
fset.Duration(new(time.Duration), "timeout", "t", 10*time.Second, "just a timeout")
fset.PrintDefaults()
mustEqual(t, buf.String(), usage)
}

func failIfErr(t testing.TB, err error) {
t.Helper()
if err != nil {
Expand Down

0 comments on commit 9da1a13

Please sign in to comment.