diff --git a/command.go b/command.go index d37cbad..1b03ebf 100644 --- a/command.go +++ b/command.go @@ -158,17 +158,24 @@ func (c *Command) Execute() error { return errors.New("Execute called on a nil Command") } - // Regardless of where we call execute, run it only from the root command + // Regardless of where we call execute, run it only from the root command, this is to ensure + // that when we use the arguments to go and find the subcommand to run (if needed), then we + // at the root of the command tree. if c.parent != nil { return c.root().Execute() } - if err := c.Flags().Parse(c.args); err != nil { + // Use the raw arguments and the command tree to determine which subcommand (if any) + // we should be invoking. If it turns out we want to invoke the root command, then + // cmd here will be c. + cmd, args := findRequestedCommand(c, c.args) + + if err := cmd.Flags().Parse(args); err != nil { return fmt.Errorf("failed to parse command flags: %w", err) } // Check if we should be responding to -h/--help - helpCalled, err := c.Flags().GetBool("help") + helpCalled, err := cmd.Flags().GetBool("help") if err != nil { // We shouldn't ever get here because we define a default for help return fmt.Errorf("could not parse help flag: %w", err) @@ -177,14 +184,14 @@ func (c *Command) Execute() error { // If -h/--help was called, call the defined helpFunc and exit so that // the run function is never called. if helpCalled { - if err = defaultHelp(c); err != nil { + if err = defaultHelp(cmd); err != nil { return fmt.Errorf("help function returned an error: %w", err) } return nil } // Check if we should be responding to -v/--version - versionCalled, err := c.Flags().GetBool("version") + versionCalled, err := cmd.Flags().GetBool("version") if err != nil { // Again, shouldn't ever get here return fmt.Errorf("could not parse version flag: %w", err) @@ -193,35 +200,28 @@ func (c *Command) Execute() error { // If -v/--version was called, call the defined versionFunc and exit so that // the run function is never called if versionCalled { - if c.versionFunc == nil { + if cmd.versionFunc == nil { return errors.New("versionFunc was nil") } - if err := c.versionFunc(c); err != nil { + if err := cmd.versionFunc(c); err != nil { return fmt.Errorf("version function returned an error: %w", err) } return nil } - // Not all commands are runnable, e.g. if this command is the root of a subcommand - // it will define subcommands but no run function itself. We must decide here what to do when - // this command is executed. - - // A command cannot have - // no subcommands and no run function, it must define one or the other - if c.run == nil && len(c.subcommands) == 0 { + // A command cannot have no subcommands and no run function, it must define one or the other + if cmd.run == nil && len(cmd.subcommands) == 0 { return fmt.Errorf( "command %s has no subcommands and no run function, a command must either be runnable or have subcommands", - c.name, + cmd.name, ) } - // If the command is runnable, go and execute it's run function - if c.run != nil { - return c.run(c, c.Flags().Args()) + // If the command is runnable, go and execute its run function + if cmd.run != nil { + return cmd.run(cmd, cmd.Flags().Args()) } - // TODO: If the command defines subcommands, we need to parse the args and determine which subcommand to go and run - return nil } @@ -242,17 +242,17 @@ func (c *Command) Flags() *pflag.FlagSet { // Stdout returns the configured Stdout for the Command. func (c *Command) Stdout() io.Writer { - return c.stdout + return c.root().stdout } // Stderr returns the configured Stderr for the Command. func (c *Command) Stderr() io.Writer { - return c.stderr + return c.root().stderr } // Stdin returns the configured Stdin for the Command. func (c *Command) Stdin() io.Reader { - return c.stdin + return c.root().stdin } // root returns the root of the command tree. @@ -263,6 +263,142 @@ func (c *Command) root() *Command { return c } +// hasFlag returns whether the command has a flag of the given name defined. +func (c *Command) hasFlag(name string) bool { + flag := c.Flags().Lookup(name) + if flag == nil { + return false + } + return flag.NoOptDefVal != "" +} + +// hasShortFlag returns whether the command has a shorthand flag of the given name defined. +func (c *Command) hasShortFlag(name string) bool { + if len(name) == 0 { + return false + } + + flag := c.Flags().ShorthandLookup(name[:1]) + if flag == nil { + return false + } + return flag.NoOptDefVal != "" +} + +// findRequestedCommand uses the raw arguments and the command tree to determine what +// (if any) subcommand is being requested and return that command along with the arguments +// that were meant for it. +func findRequestedCommand(cmd *Command, args []string) (*Command, []string) { + // Any arguments without flags could be names of subcommands + argsWithoutFlags := stripFlags(cmd, args) + if len(argsWithoutFlags) == 0 { + // If there are no non-flag arguments, we must already be either at the root command + // or the correct subcommand + return cmd, args + } + + // The next non-flag argument will be the first immediate subcommand + // e.g. in 'go mod tidy', argsWithoutFlags[0] will be 'mod' + nextSubCommand := argsWithoutFlags[0] + + // Lookup this immediate subcommand by name and if we find it, recursively call + // this function so we eventually end up at the end of the command tree with + // the right arguments + next := findSubCommand(cmd, nextSubCommand) + if next != nil { + return findRequestedCommand(next, argsMinusFirstX(cmd, args, nextSubCommand)) + } + + // Found it + return cmd, args +} + +// argsMinusFirstX removes only the first x from args. Otherwise, commands that look like +// openshift admin policy add-role-to-user admin my-user, lose the admin argument (arg[4]). +// Special care needs to be taken not to remove a flag value. +func argsMinusFirstX(cmd *Command, args []string, x string) []string { + if len(args) == 0 { + return args + } + +Loop: + for pos := 0; pos < len(args); pos++ { + s := args[pos] + switch { + case s == "--": + // -- means we have reached the end of the parseable args. Break out of the loop now. + break Loop + case strings.HasPrefix(s, "--") && !strings.Contains(s, "=") && !cmd.hasFlag(s[2:]): + fallthrough + case strings.HasPrefix(s, "-") && !strings.Contains(s, "=") && len(s) == 2 && !cmd.hasShortFlag(s[1:]): + // This is a flag without a default value, and an equal sign is not used. Increment pos in order to skip + // over the next arg, because that is the value of this flag. + pos++ + continue + case !strings.HasPrefix(s, "-"): + // This is not a flag or a flag value. Check to see if it matches what we're looking for, and if so, + // return the args, excluding the one at this position. + if s == x { + ret := make([]string, 0, len(args)-1) + ret = append(ret, args[:pos]...) + ret = append(ret, args[pos+1:]...) + return ret + } + } + } + return args +} + +// findSubCommand searches the immediate subcommands of cmd by name, looking for next. +// +// If next is not found, it will return nil. +func findSubCommand(cmd *Command, next string) *Command { + for _, subcommand := range cmd.subcommands { + if subcommand.name == next { + return subcommand + } + } + return nil +} + +// stripFlags takes a slice of raw command line arguments (including possible flags) and removes +// any arguments that are flags or values passed in to flags e.g. --flag value. +func stripFlags(cmd *Command, args []string) []string { + if len(args) == 0 { + return args + } + + argsWithoutFlags := []string{} + +Loop: + for len(args) > 0 { + arg := args[0] + args = args[1:] + switch { + case arg == "--": + // "--" terminates the flags + break Loop + case strings.HasPrefix(arg, "--") && !strings.Contains(arg, "=") && !cmd.hasFlag(arg[2:]): + // If '--flag arg' then + // delete arg from args. + fallthrough // (do the same as below) + case strings.HasPrefix(arg, "-") && !strings.Contains(arg, "=") && len(arg) == 2 && !cmd.hasShortFlag(arg[1:]): + // If '-f arg' then + // delete 'arg' from args or break the loop if len(args) <= 1. + if len(args) <= 1 { + break Loop + } else { + args = args[1:] + continue + } + case arg != "" && !strings.HasPrefix(arg, "-"): + argsWithoutFlags = append(argsWithoutFlags, arg) + } + } + + return argsWithoutFlags +} + // defaultHelp is the default for a command's helpFunc. func defaultHelp(cmd *Command) error { if cmd == nil { diff --git a/command_test.go b/command_test.go index 7d50b37..fcc3649 100644 --- a/command_test.go +++ b/command_test.go @@ -108,6 +108,115 @@ func TestExecute(t *testing.T) { } } +func TestSubCommandExecute(t *testing.T) { + sub1 := cli.New( + "sub1", + cli.Run(func(cmd *cli.Command, args []string) error { + force, err := cmd.Flags().GetBool("force") + if err != nil { + return err + } + something, err := cmd.Flags().GetString("something") + if err != nil { + return err + } + if something == "" { + something = "" + } + fmt.Fprintf( + cmd.Stdout(), + "Hello from sub1, my args were: %v, force was %v, something was %s", + args, + force, + something, + ) + return nil + }), + ) + sub1.Flags().BoolP("force", "f", false, "Force for sub1") + sub1.Flags().StringP("something", "s", "", "Something for sub1") + + sub2 := cli.New( + "sub2", + cli.Run(func(cmd *cli.Command, args []string) error { + deleteFlag, err := cmd.Flags().GetBool("delete") + if err != nil { + return err + } + number, err := cmd.Flags().GetInt("number") + if err != nil { + return err + } + fmt.Fprintf( + cmd.Stdout(), + "Hello from sub2, my args were: %v, delete was %v, number was %d", + args, + deleteFlag, + number, + ) + return nil + }), + ) + sub2.Flags().BoolP("delete", "d", false, "Delete for sub2") + sub2.Flags().IntP("number", "n", -1, "Number for sub2") + + root := cli.New( + "root", + cli.SubCommands(sub1, sub2), + ) + + tests := []struct { + name string // Test case name + stdout string // Expected stdout + stderr string // Expected stderr + args []string // Args passed to root command + wantErr bool // Whether or not we wanted an error + }{ + { + name: "invoke sub1 no flags", + stdout: "Hello from sub1, my args were: [my subcommand args], force was false, something was ", + stderr: "", + args: []string{"sub1", "my", "subcommand", "args"}, + wantErr: false, + }, + { + name: "invoke sub2 no flags", + stdout: "Hello from sub2, my args were: [my different args], delete was false, number was -1", + stderr: "", + args: []string{"sub2", "my", "different", "args"}, + wantErr: false, + }, + { + name: "invoke sub1 with flags", + stdout: "Hello from sub1, my args were: [my subcommand args], force was true, something was here", + stderr: "", + args: []string{"sub1", "my", "subcommand", "args", "--force", "--something", "here"}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set the args on the root command + cli.Args(tt.args)(root) + + // Test output streams + stderr := &bytes.Buffer{} + stdout := &bytes.Buffer{} + + cli.Stderr(stderr)(root) + cli.Stdout(stdout)(root) + + // Execute the command, we should see the sub commands get executed based on what args we provide + err := root.Execute() + test.Ok(t, err) + + test.Equal(t, stdout.String(), tt.stdout) + test.Equal(t, stderr.String(), tt.stderr) + }) + } +} + func TestHelp(t *testing.T) { tests := []struct { cmd *cli.Command // The command under test