Skip to content

Commit

Permalink
refactor: optimize error messages caused by improper use
Browse files Browse the repository at this point in the history
  • Loading branch information
windvalley committed Jan 14, 2022
1 parent 3fcd73e commit 2f0b06c
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 36 deletions.
2 changes: 2 additions & 0 deletions internal/cmd/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ Execute commands on target hosts.`,
task.SetCommand(shellCommand)

task.Start()

util.CobraCheckErrWithHelp(cmd, task.CheckErr())
},
}

Expand Down
2 changes: 2 additions & 0 deletions internal/cmd/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ Copy files/dirs from target hosts to local.`,
task.SetFetchOptions(localDstDir, tmpDir)

task.Start()

util.CobraCheckErrWithHelp(cmd, task.CheckErr())
},
}

Expand Down
13 changes: 13 additions & 0 deletions internal/cmd/push.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ Copy local files/dirs to target hosts.`,
task.SetPushOptions(fileDstPath, allowOverwrite)

task.Start()

util.CobraCheckErrWithHelp(cmd, task.CheckErr())
},
}

Expand All @@ -134,4 +136,15 @@ func init() {
false,
"allow overwrite files/dirs if they already exist on target hosts",
)

pushCmd.SetHelpFunc(func(command *cobra.Command, strings []string) {
util.CobraMarkHiddenGlobalFlags(
command,
"run.sudo",
"run.as-user",
"run.lang",
)

command.Parent().HelpFunc()(command, strings)
})
}
2 changes: 2 additions & 0 deletions internal/cmd/script.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ Execute a local shell script on target hosts.`,
task.SetScriptOptions(destPath, remove, force)

task.Start()

util.CobraCheckErrWithHelp(cmd, task.CheckErr())
},
}

Expand Down
6 changes: 3 additions & 3 deletions internal/cmd/vault/decrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ Decrypt content encrypted by vault.`,
$ gossh vault decrypt GOSSH-AES256:a5c1b3c0cdad4669f84 -V /path/vault-password-file`,
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
util.CheckErr("requires one arg to represent the vault encrypted content")
util.CobraCheckErrWithHelp(cmd, "requires one arg to represent the vault encrypted content")
}

if len(args) > 1 {
util.CheckErr("to many args, only need one")
util.CobraCheckErrWithHelp(cmd, "to many args, only need one")
}

if !aes.IsAES256CipherText(args[0]) {
Expand All @@ -66,6 +66,6 @@ Decrypt content encrypted by vault.`,
}
util.CheckErr(err)

fmt.Println(plainText)
fmt.Printf("\n%s\n", plainText)
},
}
7 changes: 4 additions & 3 deletions internal/cmd/vault/encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,24 @@ Encrypt sensitive content.`,
$ gossh vault encrypt your-sensitive-content -V /path/vault-password-file`,
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
util.CheckErr("requires one arg to represent the plaintxt to be encrypted")
util.CobraCheckErrWithHelp(cmd, "requires one arg to represent the plaintxt to be encrypted")
}

if len(args) > 1 {
util.CheckErr("to many args, only need one")
util.CobraCheckErrWithHelp(cmd, "to many args, only need one")
}

return nil
},
Run: func(cmd *cobra.Command, args []string) {
vaultPass := getVaultConfirmPassword()

encryptContent, err := aes.AES256Encode(args[0], vaultPass)
if err != nil {
err = fmt.Errorf("encrypt failed: %w", err)
}
util.CheckErr(err)

fmt.Println(encryptContent)
fmt.Printf("\n%s\n", encryptContent)
},
}
52 changes: 28 additions & 24 deletions internal/pkg/sshtask/sshtask.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ type Task struct {

taskOutput chan taskResult
detailOutput chan detailResult

err error
}

// NewTask ...
Expand Down Expand Up @@ -227,7 +229,8 @@ func (t *Task) BatchRun() {

allHosts, err := t.getAllHosts()
if err != nil {
util.CheckErr(err)
t.err = err
return
}

log.Debugf("got target hosts, count: %d", len(allHosts))
Expand All @@ -251,29 +254,31 @@ func (t *Task) BatchRun() {
switch t.taskType {
case CommandTask:
if t.command == "" {
util.CheckErr(errors.New("need flag '-e/--execute' or '-L/--hosts.list'"))
t.err = errors.New("need flag '-e/--execute' or '-L/--hosts.list'")
}
case ScriptTask:
if t.scriptFile == "" {
util.CheckErr(errors.New("need flag '-e/--execute' or '-L/--hosts.list'"))
t.err = errors.New("need flag '-e/--execute' or '-L/--hosts.list'")
}
case PushTask:
if t.pushFiles == nil || len(t.pushFiles.files) == 0 {
util.CheckErr(errors.New("need flag '-f/--files' or '-L/--hosts.list'"))
t.err = errors.New("need flag '-f/--files' or '-L/--hosts.list'")
}
case FetchTask:
if len(t.fetchFiles) == 0 {
util.CheckErr(errors.New("need flag '-f/--files' or '-L/--hosts.list'"))
}

if len(t.dstDir) == 0 {
util.CheckErr(errors.New("need flag '-d/--dest-path' or '-L/--hosts.list'"))
t.err = errors.New("need flag '-f/--files' or '-L/--hosts.list'")
} else if len(t.dstDir) == 0 {
t.err = errors.New("need flag '-d/--dest-path' or '-L/--hosts.list'")
} else {
if !util.DirExists(t.dstDir) {
err := os.MkdirAll(t.dstDir, os.ModePerm)
util.CheckErr(err)
}
}
}

if !util.DirExists(t.dstDir) {
err := os.MkdirAll(t.dstDir, os.ModePerm)
util.CheckErr(err)
}
if t.err != nil {
return
}

t.buildSSHClient()
Expand Down Expand Up @@ -348,6 +353,11 @@ func (t *Task) HandleOutput() {
}
}

// CheckErr ...
func (t *Task) CheckErr() error {
return t.err
}

func (t *Task) getAllHosts() ([]string, error) {
var hosts []string

Expand Down Expand Up @@ -392,17 +402,15 @@ func (t *Task) getAllHosts() ([]string, error) {
}

if len(hosts) == 0 {
return nil, fmt.Errorf("no target hosts provided")
return nil, fmt.Errorf("need target hosts, you can specify hosts file by flag '-H' or " +
"provide host/pattern as positional arguments")
}

return util.RemoveDuplStr(hosts), nil
}

func (t *Task) buildSSHClient() {
var (
sshClient *batchssh.Client
err error
)
var sshClient *batchssh.Client

password, err := t.getPassword()
if err != nil {
Expand All @@ -414,7 +422,7 @@ func (t *Task) buildSSHClient() {
if t.configFlags.Proxy.Server != "" {
proxyAuths := t.getProxySSHAuthMethods(&password)

sshClient, err = batchssh.NewClient(
sshClient = batchssh.NewClient(
t.configFlags.Auth.User,
password,
auths,
Expand All @@ -430,7 +438,7 @@ func (t *Task) buildSSHClient() {
),
)
} else {
sshClient, err = batchssh.NewClient(
sshClient = batchssh.NewClient(
t.configFlags.Auth.User,
password,
auths,
Expand All @@ -441,10 +449,6 @@ func (t *Task) buildSSHClient() {
)
}

if err != nil {
util.CheckErr(err)
}

t.sshClient = sshClient
}

Expand Down
8 changes: 2 additions & 6 deletions pkg/batchssh/batchssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,7 @@ type Proxy struct {
}

// NewClient session.
func NewClient(
user, password string,
auths []ssh.AuthMethod,
options ...func(*Client),
) (*Client, error) {
func NewClient(user, password string, auths []ssh.AuthMethod, options ...func(*Client)) *Client {
client := Client{
User: user,
Password: password,
Expand All @@ -103,7 +99,7 @@ func NewClient(
option(&client)
}

return &client, nil
return &client
}

// BatchRun command on remote servers.
Expand Down
13 changes: 13 additions & 0 deletions pkg/util/cobra.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,19 @@ import (
"github.com/spf13/pflag"
)

// CobraCheckErrWithHelp instead of cobra default behavior.
func CobraCheckErrWithHelp(cmd *cobra.Command, errMsg interface{}) {
if errMsg != nil {
PrintErr(errMsg)

_ = cmd.Help()

fmt.Println()

CheckErr(errMsg)
}
}

// CobraMarkHiddenGlobalFlags that from params.
func CobraMarkHiddenGlobalFlags(command *cobra.Command, flags ...string) {
for _, v := range flags {
Expand Down
7 changes: 7 additions & 0 deletions pkg/util/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,10 @@ func CheckErr(msg interface{}) {
os.Exit(1)
}
}

// PrintErr with red color if err not nil.
func PrintErr(msg interface{}) {
if msg != nil {
fmt.Fprintln(os.Stderr, color.RedString("Error:"), msg)
}
}

0 comments on commit 2f0b06c

Please sign in to comment.