From 0f85b04c0e9eb15a0604f48899a3f70474523f15 Mon Sep 17 00:00:00 2001 From: shilinlee <836160610@qq.com> Date: Tue, 5 Dec 2023 14:41:44 +0800 Subject: [PATCH] feat: support start cluster cmd (#27) Signed-off-by: shilinlee <836160610@qq.com> --- cmd/cluster/display.go | 17 ++ cmd/cluster/root.go | 1 + cmd/cluster/start.go | 4 +- cmd/cluster/start2.go | 60 ++++ embed/templates/systemd/system.service.tpl | 41 +++ go.mod | 5 +- go.sum | 2 + pkg/cluster/manager/basic.go | 2 +- pkg/cluster/manager/builder.go | 4 +- pkg/cluster/manager/check.go | 2 +- pkg/cluster/manager/install2.go | 8 +- pkg/cluster/manager/manager.go | 97 +++--- pkg/cluster/manager/start.go | 62 +++- pkg/cluster/module/systemd.go | 104 +++++++ pkg/cluster/module/wait_for.go | 94 ++++++ pkg/cluster/operation/action.go | 330 +++++++++++++++++++++ pkg/cluster/operation/operation.go | 40 +++ pkg/cluster/spec/instance.go | 92 +++++- pkg/cluster/spec/parse_topology.go | 9 +- pkg/cluster/spec/profile.go | 2 +- pkg/cluster/spec/spec.go | 60 ++-- pkg/cluster/spec/spec_manager.go | 31 +- pkg/cluster/spec/ts_meta.go | 8 +- pkg/cluster/spec/ts_sql.go | 6 +- pkg/cluster/spec/ts_store.go | 10 +- pkg/cluster/spec/validate.go | 21 ++ pkg/cluster/task/builder.go | 55 ++-- pkg/cluster/task/func.go | 45 +++ pkg/cluster/task/ssh.go | 6 +- pkg/cluster/task/ssh_keyset.go | 46 +++ pkg/cluster/template/systemd/system.go | 114 +++++++ pkg/meta/resource_ctrl.go | 24 ++ pkg/set/any_set.go | 90 ++++++ pkg/set/string_set.go | 81 +++++ pkg/utils/retry.go | 112 +++++++ 35 files changed, 1556 insertions(+), 129 deletions(-) create mode 100644 cmd/cluster/display.go create mode 100644 cmd/cluster/start2.go create mode 100644 embed/templates/systemd/system.service.tpl create mode 100644 pkg/cluster/module/systemd.go create mode 100644 pkg/cluster/module/wait_for.go create mode 100644 pkg/cluster/operation/action.go create mode 100644 pkg/cluster/spec/validate.go create mode 100644 pkg/cluster/task/func.go create mode 100644 pkg/cluster/task/ssh_keyset.go create mode 100644 pkg/cluster/template/systemd/system.go create mode 100644 pkg/meta/resource_ctrl.go create mode 100644 pkg/set/any_set.go create mode 100644 pkg/set/string_set.go create mode 100644 pkg/utils/retry.go diff --git a/cmd/cluster/display.go b/cmd/cluster/display.go new file mode 100644 index 0000000..a5d27a2 --- /dev/null +++ b/cmd/cluster/display.go @@ -0,0 +1,17 @@ +package cluster + +import ( + "github.com/openGemini/gemix/pkg/cluster/manager" + "github.com/spf13/cobra" +) + +func shellCompGetClusterName(cm *manager.Manager, toComplete string) ([]string, cobra.ShellCompDirective) { + var result []string + //clusters, _ := cm.GetClusterList() + //for _, c := range clusters { + // if strings.HasPrefix(c.Name, toComplete) { + // result = append(result, c.Name) + // } + //} + return result, cobra.ShellCompDirectiveNoFileComp +} diff --git a/cmd/cluster/root.go b/cmd/cluster/root.go index fd6897a..ac9b622 100644 --- a/cmd/cluster/root.go +++ b/cmd/cluster/root.go @@ -58,6 +58,7 @@ func init() { installCmd(), installCmd2(), startCmd, + startCmd2(), stopCmd, uninstallCmd, statusCmd, diff --git a/cmd/cluster/start.go b/cmd/cluster/start.go index 8b3468d..02fc7d1 100644 --- a/cmd/cluster/start.go +++ b/cmd/cluster/start.go @@ -22,13 +22,11 @@ import ( "github.com/spf13/cobra" ) -var startOpts utils.StartOptions - // startCmd represents the start command var startCmd = &cobra.Command{ Use: "start ", Short: "Start an openGemini cluster", - Long: `Start an openGemini cluster based on configuration files and version numbers.`, + Long: `Start an openGemini cluster`, Run: func(cmd *cobra.Command, args []string) { var ops utils.ClusterOptions var err error diff --git a/cmd/cluster/start2.go b/cmd/cluster/start2.go new file mode 100644 index 0000000..320d05e --- /dev/null +++ b/cmd/cluster/start2.go @@ -0,0 +1,60 @@ +// Copyright 2023 Huawei Cloud Computing Technologies Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cluster + +import ( + "github.com/openGemini/gemix/utils" + "github.com/spf13/cobra" +) + +var startOpts utils.StartOptions + +func startCmd2() *cobra.Command { + var ( + initPasswd bool + ) + + var cmd = &cobra.Command{ + Use: "start2 ", + Short: "Start an openGemini cluster", + Long: `Start an openGemini cluster`, + RunE: func(cmd *cobra.Command, args []string) error { + if len(args) != 1 { + return cmd.Help() + } + + clusterName := args[0] + err := cm.StartCluster(clusterName, gOpt) + if err != nil { + return err + } + + // TODO: init password + return nil + }, + ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + switch len(args) { + case 0: + return shellCompGetClusterName(cm, toComplete) + default: + return nil, cobra.ShellCompDirectiveNoFileComp + } + }, + } + cmd.Flags().BoolVar(&initPasswd, "init", false, "Initialize a secure root password for the database") + cmd.Flags().StringSliceVarP(&gOpt.Roles, "role", "R", nil, "Only start specified roles") + cmd.Flags().StringSliceVarP(&gOpt.Nodes, "node", "N", nil, "Only start specified nodes") + return cmd +} diff --git a/embed/templates/systemd/system.service.tpl b/embed/templates/systemd/system.service.tpl new file mode 100644 index 0000000..2fcea57 --- /dev/null +++ b/embed/templates/systemd/system.service.tpl @@ -0,0 +1,41 @@ +[Unit] +Description={{.ServiceName}} service +After=syslog.target network.target remote-fs.target nss-lookup.target + +[Service] +{{- if .MemoryLimit}} +MemoryLimit={{.MemoryLimit}} +{{- end}} +{{- if .CPUQuota}} +CPUQuota={{.CPUQuota}} +{{- end}} +{{- if .IOReadBandwidthMax}} +IOReadBandwidthMax={{.IOReadBandwidthMax}} +{{- end}} +{{- if .IOWriteBandwidthMax}} +IOWriteBandwidthMax={{.IOWriteBandwidthMax}} +{{- end}} +{{- if .LimitCORE}} +LimitCORE={{.LimitCORE}} +{{- end}} +LimitNOFILE=1000000 +LimitSTACK=10485760 + +{{- if .GrantCapNetRaw}} +AmbientCapabilities=CAP_NET_RAW +{{- end}} +User={{.User}} +ExecStart=/bin/bash -c '{{.DeployDir}}/scripts/run_{{.ServiceName}}.sh' + +{{- if .Restart}} +Restart={{.Restart}} +{{else}} +Restart=always +{{end}} +RestartSec=15s +{{- if .DisableSendSigkill}} +SendSIGKILL=no +{{- end}} + +[Install] +WantedBy=multi-user.target diff --git a/go.mod b/go.mod index ede9847..58257d9 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/charmbracelet/lipgloss v0.9.1 github.com/creasty/defaults v1.7.0 github.com/fatih/color v1.16.0 + github.com/google/uuid v1.3.0 github.com/joomcode/errorx v1.1.1 github.com/olekukonko/tablewriter v0.0.5 github.com/pkg/errors v0.9.1 @@ -21,7 +22,9 @@ require ( go.uber.org/zap v1.26.0 golang.org/x/crypto v0.15.0 golang.org/x/mod v0.14.0 + golang.org/x/sync v0.1.0 golang.org/x/term v0.14.0 + golang.org/x/text v0.14.0 gopkg.in/yaml.v2 v2.4.0 ) @@ -48,8 +51,6 @@ require ( github.com/spf13/pflag v1.0.5 // indirect go.uber.org/goleak v1.2.1 // indirect go.uber.org/multierr v1.10.0 // indirect - golang.org/x/sync v0.1.0 // indirect golang.org/x/sys v0.14.0 // indirect - golang.org/x/text v0.14.0 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect ) diff --git a/go.sum b/go.sum index 899023c..51cca41 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,8 @@ github.com/dchest/bcrypt_pbkdf v0.0.0-20150205184540-83f37f9c154a h1:saTgr5tMLFn github.com/dchest/bcrypt_pbkdf v0.0.0-20150205184540-83f37f9c154a/go.mod h1:Bw9BbhOJVNR+t0jCqx2GC6zv0TGBsShs56Y3gfSCvl0= github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/joomcode/errorx v1.1.1 h1:/LFG/qSk1gUTuZjs+qlyOJEpcVjD9DXgBNFhdZkQrjY= diff --git a/pkg/cluster/manager/basic.go b/pkg/cluster/manager/basic.go index b21aeaf..48efb50 100644 --- a/pkg/cluster/manager/basic.go +++ b/pkg/cluster/manager/basic.go @@ -23,7 +23,7 @@ type hostInfo struct { } // getAllUniqueHosts gets all the instance -func getAllUniqueHosts(topo *spec.Specification) map[string]hostInfo { +func getAllUniqueHosts(topo spec.Topology) map[string]hostInfo { // monitor uniqueHosts := make(map[string]hostInfo) // host -> ssh-port, os, arch topo.IterInstance(func(inst spec.Instance) { diff --git a/pkg/cluster/manager/builder.go b/pkg/cluster/manager/builder.go index df2ed36..47cde01 100644 --- a/pkg/cluster/manager/builder.go +++ b/pkg/cluster/manager/builder.go @@ -25,7 +25,7 @@ import ( ) // buildDownloadCompTasks build download component tasks -func buildDownloadCompTasks(clusterVersion string, topo *spec.Specification) []*task.StepDisplay { +func buildDownloadCompTasks(clusterVersion string, topo spec.Topology) []*task.StepDisplay { var tasks []*task.StepDisplay uniqueTasks := make(map[string]struct{}) @@ -48,7 +48,7 @@ func buildDownloadCompTasks(clusterVersion string, topo *spec.Specification) []* func buildInitConfigTasks( m *Manager, clustername string, - topo *spec.Specification, + topo spec.Topology, base *spec.BaseMeta, gOpt operator.Options, ) []*task.StepDisplay { diff --git a/pkg/cluster/manager/check.go b/pkg/cluster/manager/check.go index bc446ea..e85791c 100644 --- a/pkg/cluster/manager/check.go +++ b/pkg/cluster/manager/check.go @@ -17,7 +17,7 @@ package manager import "github.com/openGemini/gemix/pkg/cluster/spec" // checkConflict checks cluster conflict -func checkConflict(m *Manager, clusterName string, topo *spec.Specification) error { +func checkConflict(m *Manager, clusterName string, topo spec.Topology) error { //clusterList, err := m.specManager.GetAllClusters() //if err != nil { // return err diff --git a/pkg/cluster/manager/install2.go b/pkg/cluster/manager/install2.go index ea050a7..ea7be23 100644 --- a/pkg/cluster/manager/install2.go +++ b/pkg/cluster/manager/install2.go @@ -287,10 +287,10 @@ func (m *Manager) Install( } // FIXME: remove me if you finish - //err = m.specManager.SaveMeta(clusterName, metadata) - //if err != nil { - // return err - //} + err = m.specManager.SaveMeta(clusterName, metadata) + if err != nil { + return err + } hint := color.New(color.FgBlue).Sprintf("%s start %s", "gemix cluster", clusterName) fmt.Printf("Cluster `%s` deployed successfully, you can start it with command: `%s`\n", clusterName, hint) diff --git a/pkg/cluster/manager/manager.go b/pkg/cluster/manager/manager.go index 8faf68b..29f0371 100644 --- a/pkg/cluster/manager/manager.go +++ b/pkg/cluster/manager/manager.go @@ -19,7 +19,9 @@ import ( "strings" "github.com/fatih/color" + operator "github.com/openGemini/gemix/pkg/cluster/operation" "github.com/openGemini/gemix/pkg/cluster/spec" + "github.com/openGemini/gemix/pkg/cluster/task" "github.com/openGemini/gemix/pkg/gui" "github.com/pkg/errors" "go.uber.org/zap" @@ -42,26 +44,26 @@ func NewManager(sysName string, specManager *spec.SpecManager, logger *zap.Logge } } -//func (m *Manager) meta(name string) (metadata spec.Metadata, err error) { -// exist, err := m.specManager.Exist(name) -// if err != nil { -// return nil, err -// } -// -// if !exist { -// return nil, perrs.Errorf("%s cluster `%s` not exists", m.sysName, name) -// } -// -// metadata = m.specManager.NewMetadata() -// err = m.specManager.Metadata(name, metadata) -// if err != nil { -// return metadata, err -// } -// -// return metadata, nil -//} +func (m *Manager) meta(name string) (metadata spec.Metadata, err error) { + exist, err := m.specManager.Exist(name) + if err != nil { + return nil, err + } + + if !exist { + return nil, errors.Errorf("%s cluster `%s` not exists", m.sysName, name) + } -func (m *Manager) confirmTopology(clusterName, version string, topo *spec.Specification) error { + metadata = m.specManager.NewMetadata() // TODO: 没有可用信息 + err = m.specManager.Metadata(name, metadata) + if err != nil { + return metadata, err + } + + return metadata, nil +} + +func (m *Manager) confirmTopology(clusterName, version string, topo spec.Topology) error { fmt.Println("Please confirm your topology:") cyan := color.New(color.FgCyan, color.Bold) @@ -95,44 +97,35 @@ func (m *Manager) confirmTopology(clusterName, version string, topo *spec.Specif return gui.PromptForConfirmOrAbortError("Do you want to continue? [y/N]: ") } -//func (m *Manager) sshTaskBuilder(name string, topo spec.Topology, user string, gOpt operator.Options) (*task.Builder, error) { -// var p *tui.SSHConnectionProps = &tui.SSHConnectionProps{} -// if gOpt.SSHType != executor.SSHTypeNone && len(gOpt.SSHProxyHost) != 0 { -// var err error -// if p, err = tui.ReadIdentityFileOrPassword(gOpt.SSHProxyIdentity, gOpt.SSHProxyUsePassword); err != nil { -// return nil, err -// } -// } -// -// return task.NewBuilder(m.logger). -// SSHKeySet( -// m.specManager.Path(name, "ssh", "id_rsa"), -// m.specManager.Path(name, "ssh", "id_rsa.pub"), -// ). -// ClusterSSH( -// topo, -// user, -// gOpt.SSHTimeout, -// gOpt.OptTimeout, -// gOpt.SSHProxyHost, -// gOpt.SSHProxyPort, -// gOpt.SSHProxyUser, -// p.Password, -// p.IdentityFile, -// p.IdentityFilePassphrase, -// gOpt.SSHProxyTimeout, -// gOpt.SSHType, -// topo.BaseTopo().GlobalOptions.SSHType, -// ), nil -//} +func (m *Manager) sshTaskBuilder(name string, topo spec.Topology, user string, gOpt operator.Options) (*task.Builder, error) { + //var p *gui.SSHConnectionProps = &gui.SSHConnectionProps{} + //if gOpt.SSHType != executor.SSHTypeNone && len(gOpt.SSHProxyHost) != 0 { + // var err error + // if p, err = gui.ReadIdentityFileOrPassword(gOpt.SSHProxyIdentity, gOpt.SSHProxyUsePassword); err != nil { + // return nil, err + // } + //} + + return task.NewBuilder(m.logger). + SSHKeySet( + m.specManager.Path(name, "ssh", "id_rsa"), + m.specManager.Path(name, "ssh", "id_rsa.pub"), + ). + ClusterSSH( + topo, + user, + gOpt.SSHTimeout, + gOpt.OptTimeout, + ), nil +} // fillHost full host cpu-arch and kernel-name -func (m *Manager) fillHost(s *gui.SSHConnectionProps, topo *spec.Specification, user string) error { +func (m *Manager) fillHost(s *gui.SSHConnectionProps, topo spec.Topology, user string) error { hostArchOrOS := map[string]string{} topo.IterInstance(func(instance spec.Instance) { insOS := instance.OS() if insOS == "" { - insOS = topo.GlobalOptions.OS + insOS = topo.BaseTopo().GlobalOptions.OS } hostArchOrOS[instance.GetHost()] = insOS }) @@ -144,7 +137,7 @@ func (m *Manager) fillHost(s *gui.SSHConnectionProps, topo *spec.Specification, topo.IterInstance(func(instance spec.Instance) { insArch := instance.Arch() if insArch == "" { - insArch = topo.GlobalOptions.Arch + insArch = topo.BaseTopo().GlobalOptions.Arch } hostArchOrOS[instance.GetHost()] = insArch }) diff --git a/pkg/cluster/manager/start.go b/pkg/cluster/manager/start.go index ce5f400..02569dc 100644 --- a/pkg/cluster/manager/start.go +++ b/pkg/cluster/manager/start.go @@ -15,7 +15,7 @@ package manager import ( - "errors" + "context" "fmt" "path/filepath" "strconv" @@ -23,12 +23,72 @@ import ( "sync" "time" + "github.com/joomcode/errorx" "github.com/openGemini/gemix/pkg/cluster/config" + "github.com/openGemini/gemix/pkg/cluster/ctxt" "github.com/openGemini/gemix/pkg/cluster/operation" + "github.com/openGemini/gemix/pkg/cluster/spec" + "github.com/openGemini/gemix/pkg/cluster/task" "github.com/openGemini/gemix/utils" + "github.com/pkg/errors" + "go.uber.org/zap" "golang.org/x/crypto/ssh" ) +// StartCluster start the cluster with specified name. +func (m *Manager) StartCluster(name string, gOpt operation.Options, fn ...func(b *task.Builder, metadata spec.Metadata)) error { + m.logger.Info("Starting cluster ...", zap.String("cluster name", name)) + + // check locked + //if err := m.specManager.ScaleOutLockedErr(name); err != nil { + // return err + //} + + metadata, err := m.meta(name) + if err != nil { + return err + } + + topo := metadata.GetTopology() + base := metadata.GetBaseMeta() + + //tlsCfg, err := topo.TLSConfig(m.specManager.Path(name, spec.TLSCertKeyDir)) + //if err != nil { + // return err + //} + + b, err := m.sshTaskBuilder(name, topo, base.User, gOpt) + if err != nil { + return err + } + + b.Func("StartCluster", func(ctx context.Context) error { + return operation.Start(ctx, topo, gOpt, nil) + }) + + for _, f := range fn { + f(b, metadata) + } + + t := b.Build() + + ctx := ctxt.New( + context.Background(), + gOpt.Concurrency, + m.logger, + ) + if err := t.Execute(ctx); err != nil { + if errorx.Cast(err) != nil { + // FIXME: Map possible task errors and give suggestions. + return err + } + return errors.WithStack(err) + } + + m.logger.Info("Started cluster successfully", zap.String("cluster name", name)) + return nil +} + type Starter interface { PrepareForStart() error Start() error diff --git a/pkg/cluster/module/systemd.go b/pkg/cluster/module/systemd.go new file mode 100644 index 0000000..fcfec0c --- /dev/null +++ b/pkg/cluster/module/systemd.go @@ -0,0 +1,104 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package module + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/openGemini/gemix/pkg/cluster/ctxt" +) + +// scope can be either "system", "user" or "global" +const ( + SystemdScopeSystem = "system" + SystemdScopeUser = "user" + SystemdScopeGlobal = "global" +) + +// SystemdModuleConfig is the configurations used to initialize a SystemdModule +type SystemdModuleConfig struct { + Unit string // the name of systemd unit(s) + Action string // the action to perform with the unit + ReloadDaemon bool // run daemon-reload before other actions + CheckActive bool // run is-active before action + Scope string // user, system or global + Force bool // add the `--force` arg to systemctl command + Signal string // specify the signal to send to process + Timeout time.Duration // timeout to execute the command +} + +// SystemdModule is the module used to control systemd units +type SystemdModule struct { + cmd string // the built command + sudo bool // does the command need to be run as root + timeout time.Duration // timeout to execute the command +} + +// NewSystemdModule builds and returns a SystemdModule object base on +// given config. +func NewSystemdModule(config SystemdModuleConfig) *SystemdModule { + systemctl := "systemctl" + sudo := true + + if config.Force { + systemctl = fmt.Sprintf("%s --force", systemctl) + } + + if config.Signal != "" { + systemctl = fmt.Sprintf("%s --signal %s", systemctl, config.Signal) + } + + switch config.Scope { + case SystemdScopeUser: + sudo = false // `--user` scope does not need root privilege + fallthrough + case SystemdScopeGlobal: + systemctl = fmt.Sprintf("%s --%s", systemctl, config.Scope) + } + + cmd := fmt.Sprintf("%s %s %s", + systemctl, strings.ToLower(config.Action), config.Unit) + + if config.CheckActive { + cmd = fmt.Sprintf("if [[ $(%s is-active %s) == \"active\" ]]; then %s; fi", + systemctl, config.Unit, cmd) + } + + if config.ReloadDaemon { + cmd = fmt.Sprintf("%s daemon-reload && %s", + systemctl, cmd) + } + mod := &SystemdModule{ + cmd: cmd, + sudo: sudo, + timeout: config.Timeout, + } + + // the default TimeoutStopSec of systemd is 90s, after which it sends a SIGKILL + // to remaining processes, set the default value slightly larger than it + if config.Timeout == 0 { + mod.timeout = time.Second * 100 + } + + return mod +} + +// Execute passes the command to executor and returns its results, the executor +// should be already initialized. +func (mod *SystemdModule) Execute(ctx context.Context, exec ctxt.Executor) ([]byte, []byte, error) { + return exec.Execute(ctx, mod.cmd, mod.sudo, mod.timeout) +} diff --git a/pkg/cluster/module/wait_for.go b/pkg/cluster/module/wait_for.go new file mode 100644 index 0000000..761df7b --- /dev/null +++ b/pkg/cluster/module/wait_for.go @@ -0,0 +1,94 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package module + +import ( + "bytes" + "context" + "fmt" + "time" + + "github.com/openGemini/gemix/pkg/cluster/ctxt" + "github.com/openGemini/gemix/pkg/utils" + "github.com/pkg/errors" + "go.uber.org/zap" +) + +// WaitForConfig is the configurations of WaitFor module. +type WaitForConfig struct { + Port int // Port number to poll. + Sleep time.Duration // Duration to sleep between checks, default 1 second. + // Choices: + // started + // stopped + // When checking a port started will ensure the port is open, stopped will check that it is closed + State string + Timeout time.Duration // Maximum duration to wait for. +} + +// WaitFor is the module used to wait for some condition. +type WaitFor struct { + c WaitForConfig +} + +// NewWaitFor create a WaitFor instance. +func NewWaitFor(c WaitForConfig) *WaitFor { + if c.Sleep == 0 { + c.Sleep = time.Second + } + if c.Timeout == 0 { + c.Timeout = time.Second * 60 + } + if c.State == "" { + c.State = "started" + } + + w := &WaitFor{ + c: c, + } + + return w +} + +// Execute the module return nil if successfully wait for the event. +func (w *WaitFor) Execute(ctx context.Context, e ctxt.Executor) (err error) { + pattern := []byte(fmt.Sprintf(":%d ", w.c.Port)) + + retryOpt := utils.RetryOption{ + Delay: w.c.Sleep, + Timeout: w.c.Timeout, + } + if err := utils.Retry(func() error { + // only listing TCP ports + stdout, _, err := e.Execute(ctx, "ss -ltn", false) + if err == nil { + switch w.c.State { + case "started": + if bytes.Contains(stdout, pattern) { + return nil + } + case "stopped": + if !bytes.Contains(stdout, pattern) { + return nil + } + } + return errors.New("still waiting for port state to be satisfied") + } + return err + }, retryOpt); err != nil { + zap.L().Debug("retry error", zap.Error(err)) + return errors.Errorf("timed out waiting for port %d to be %s after %s", w.c.Port, w.c.State, w.c.Timeout) + } + return nil +} diff --git a/pkg/cluster/operation/action.go b/pkg/cluster/operation/action.go new file mode 100644 index 0000000..9da2407 --- /dev/null +++ b/pkg/cluster/operation/action.go @@ -0,0 +1,330 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package operation + +import ( + "bytes" + "context" + "crypto/tls" + "fmt" + "time" + + "github.com/openGemini/gemix/pkg/cluster/ctxt" + "github.com/openGemini/gemix/pkg/cluster/module" + "github.com/openGemini/gemix/pkg/cluster/spec" + logprinter "github.com/openGemini/gemix/pkg/logger/printer" + "github.com/openGemini/gemix/pkg/set" + "github.com/pkg/errors" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +var ( + actionPrevMsgs = map[string]string{ + "start": "Starting", + "stop": "Stopping", + "enable": "Enabling", + "disable": "Disabling", + } + actionPostMsgs = map[string]string{} +) + +func init() { + for action := range actionPrevMsgs { + actionPostMsgs[action] = cases.Title(language.English).String(action) + } +} + +// Start the cluster. +func Start( + ctx context.Context, + cluster spec.Topology, + options Options, + tlsCfg *tls.Config, +) error { + uniqueHosts := set.NewStringSet() + roleFilter := set.NewStringSet(options.Roles...) + nodeFilter := set.NewStringSet(options.Nodes...) + components := cluster.ComponentsByStartOrder() + components = FilterComponent(components, roleFilter) + + for _, comp := range components { + insts := FilterInstance(comp.Instances(), nodeFilter) + err := StartComponent(ctx, insts, options, tlsCfg) + if err != nil { + return errors.WithMessagef(err, "failed to start %s", comp.Name()) + } + + errg, _ := errgroup.WithContext(ctx) + for _, inst := range insts { + uniqueHosts.Insert(inst.GetManageHost()) + } + if err = errg.Wait(); err != nil { + return err + } + } + + //hosts := make([]string, 0, len(uniqueHosts)) + //for host := range uniqueHosts { + // hosts = append(hosts, host) + //} + //return StartMonitored(ctx, hosts, noAgentHosts, monitoredOptions, options.OptTimeout) + return nil +} + +func enableInstance(ctx context.Context, ins spec.Instance, timeout uint64, isEnable bool) error { + e := ctxt.GetInner(ctx).Get(ins.GetManageHost()) + logger := ctx.Value(logprinter.ContextKeyLogger).(*zap.Logger) + + action := "disable" + if isEnable { + action = "enable" + } + logger.Info(fmt.Sprintf("\t%s instance %s", actionPrevMsgs[action], ins.ID())) + + // Enable/Disable by systemd. + if err := systemctl(ctx, e, ins.ServiceName(), action, timeout); err != nil { + return toFailedActionError(err, action, ins.GetManageHost(), ins.ServiceName(), ins.LogDir()) + } + + logger.Info(fmt.Sprintf("\t%s instance %s success", actionPostMsgs[action], ins.ID())) + + return nil +} + +func startInstance(ctx context.Context, ins spec.Instance, timeout uint64, tlsCfg *tls.Config) error { + e := ctxt.GetInner(ctx).Get(ins.GetManageHost()) + logger := ctx.Value(logprinter.ContextKeyLogger).(*zap.Logger) + logger.Info(fmt.Sprintf("\tStarting instance %s", ins.ID())) + + if err := systemctl(ctx, e, ins.ServiceName(), "start", timeout); err != nil { + return toFailedActionError(err, "start", ins.GetManageHost(), ins.ServiceName(), ins.LogDir()) + } + + // Check ready. + if err := ins.Ready(ctx, e, timeout, tlsCfg); err != nil { + return toFailedActionError(err, "start", ins.GetManageHost(), ins.ServiceName(), ins.LogDir()) + } + + logger.Info(fmt.Sprintf("\tStart instance %s success", ins.ID())) + + return nil +} + +func systemctl(ctx context.Context, executor ctxt.Executor, service string, action string, timeout uint64) error { + logger := ctx.Value(logprinter.ContextKeyLogger).(*zap.Logger) + c := module.SystemdModuleConfig{ + Unit: service, + ReloadDaemon: true, + Action: action, + Timeout: time.Second * time.Duration(timeout), + } + systemd := module.NewSystemdModule(c) + stdout, stderr, err := systemd.Execute(ctx, executor) + + if len(stdout) > 0 { + fmt.Println(string(stdout)) + } + if len(stderr) > 0 && !bytes.Contains(stderr, []byte("Created symlink ")) && !bytes.Contains(stderr, []byte("Removed symlink ")) { + logger.Error(string(stderr)) + } + if len(stderr) > 0 && action == "stop" { + // ignore "unit not loaded" error, as this means the unit is not + // exist, and that's exactly what we want + // NOTE: there will be a potential bug if the unit name is set + // wrong and the real unit still remains started. + if bytes.Contains(stderr, []byte(" not loaded.")) { + logger.Warn(string(stderr)) + return nil // reset the error to avoid exiting + } + logger.Error(string(stderr)) + } + return err +} + +// EnableComponent enable/disable the instances +func EnableComponent(ctx context.Context, instances []spec.Instance, noAgentHosts set.StringSet, options Options, isEnable bool) error { + if len(instances) == 0 { + return nil + } + + logger := ctx.Value(logprinter.ContextKeyLogger).(*zap.Logger) + name := instances[0].ComponentName() + if isEnable { + logger.Info(fmt.Sprintf("Enabling component %s", name)) + } else { + logger.Info(fmt.Sprintf("Disabling component %s", name)) + } + + errg, _ := errgroup.WithContext(ctx) + + for _, ins := range instances { + ins := ins + + errg.Go(func() error { + err := enableInstance(ctx, ins, options.OptTimeout, isEnable) + if err != nil { + return err + } + return nil + }) + } + + return errg.Wait() +} + +// StartComponent start the instances. +func StartComponent(ctx context.Context, instances []spec.Instance, options Options, tlsCfg *tls.Config) error { + if len(instances) == 0 { + return nil + } + + logger := ctx.Value(logprinter.ContextKeyLogger).(*zap.Logger) + name := instances[0].ComponentName() + logger.Info(fmt.Sprintf("Starting component %s", name)) + + errg, _ := errgroup.WithContext(ctx) + for _, ins := range instances { + ins := ins + + errg.Go(func() error { + if err := ins.PrepareStart(ctx, tlsCfg); err != nil { + return err + } + return startInstance(ctx, ins, options.OptTimeout, tlsCfg) + }) + } + + return errg.Wait() +} + +//func stopInstance(ctx context.Context, ins spec.Instance, timeout uint64) error { +// e := ctxt.GetInner(ctx).Get(ins.GetManageHost()) +// logger := ctx.Value(logprinter.ContextKeyLogger).(*zap.Logger) +// logger.Infof("\tStopping instance %s", ins.GetManageHost()) +// +// if err := systemctl(ctx, e, ins.ServiceName(), "stop", timeout); err != nil { +// return toFailedActionError(err, "stop", ins.GetManageHost(), ins.ServiceName(), ins.LogDir()) +// } +// +// logger.Infof("\tStop %s %s success", ins.ComponentName(), ins.ID()) +// +// return nil +//} + +// // StopComponent stop the instances. +// func StopComponent(ctx context.Context, +// +// topo spec.Topology, +// instances []spec.Instance, +// noAgentHosts set.StringSet, +// options Options, +// forceStop bool, +// evictLeader bool, +// tlsCfg *tls.Config, +// +// ) error { +// if len(instances) == 0 { +// return nil +// } +// +// logger := ctx.Value(logprinter.ContextKeyLogger).(*zap.Logger) +// name := instances[0].ComponentName() +// logger.Infof("Stopping component %s", name) +// +// errg, _ := errgroup.WithContext(ctx) +// +// for _, ins := range instances { +// ins := ins +// switch name { +// case spec.ComponentNodeExporter, +// spec.ComponentBlackboxExporter: +// if noAgentHosts.Exist(ins.GetManageHost()) { +// logger.Debugf("Ignored stopping %s for %s:%d", name, ins.GetManageHost(), ins.GetPort()) +// continue +// } +// case spec.ComponentCDC: +// nctx := checkpoint.NewContext(ctx) +// if !forceStop { +// // when scale-in cdc node, each node should be stopped one by one. +// cdc, ok := ins.(spec.RollingUpdateInstance) +// if !ok { +// panic("cdc should support rolling upgrade, but not") +// } +// err := cdc.PreRestart(nctx, topo, int(options.APITimeout), tlsCfg) +// if err != nil { +// // this should never hit, since all errors swallowed to trigger hard stop. +// return err +// } +// } +// if err := stopInstance(nctx, ins, options.OptTimeout); err != nil { +// return err +// } +// // continue here, to skip the logic below. +// continue +// } +// +// // the checkpoint part of context can't be shared between goroutines +// // since it's used to trace the stack, so we must create a new layer +// // of checkpoint context every time put it into a new goroutine. +// nctx := checkpoint.NewContext(ctx) +// errg.Go(func() error { +// if evictLeader { +// rIns, ok := ins.(spec.RollingUpdateInstance) +// if ok { +// err := rIns.PreRestart(nctx, topo, int(options.APITimeout), tlsCfg) +// if err != nil { +// return err +// } +// } +// } +// err := stopInstance(nctx, ins, options.OptTimeout) +// if err != nil { +// return err +// } +// return nil +// }) +// } +// +// return errg.Wait() +// } + +// toFailedActionError formats the errror msg for failed action +func toFailedActionError(err error, action string, host, service, logDir string) error { + return errors.WithMessagef(err, + "failed to %s: %s %s, please check the instance's log(%s) for more detail.", + action, host, service, logDir, + ) +} + +//lint:ignore U1000 keep this +func executeSSHCommand(ctx context.Context, action, host, command string) error { + if command == "" { + return nil + } + e, found := ctxt.GetInner(ctx).GetExecutor(host) + if !found { + return fmt.Errorf("no executor") + } + logger := ctx.Value(logprinter.ContextKeyLogger).(*zap.Logger) + logger.Info(fmt.Sprintf("\t%s on %s", action, host)) + stdout, stderr, err := e.Execute(ctx, command, false) + if err != nil { + return errors.WithMessagef(err, "stderr: %s", string(stderr)) + } + logger.Info(fmt.Sprintf("\t%s", stdout)) + return nil +} diff --git a/pkg/cluster/operation/operation.go b/pkg/cluster/operation/operation.go index b411620..53ba84e 100644 --- a/pkg/cluster/operation/operation.go +++ b/pkg/cluster/operation/operation.go @@ -1,5 +1,10 @@ package operation +import ( + "github.com/openGemini/gemix/pkg/cluster/spec" + "github.com/openGemini/gemix/pkg/set" +) + // Options represents the operation options type Options struct { Roles []string @@ -30,3 +35,38 @@ type Options struct { DisplayMode string // the output format // Operation Operation } + +// FilterComponent filter components by set +func FilterComponent(comps []spec.Component, components set.StringSet) (res []spec.Component) { + if len(components) == 0 { + res = comps + return + } + + for _, c := range comps { + role := c.Name() + if !components.Exist(role) { + continue + } + + res = append(res, c) + } + return +} + +// FilterInstance filter instances by set +func FilterInstance(instances []spec.Instance, nodes set.StringSet) (res []spec.Instance) { + if len(nodes) == 0 { + res = instances + return + } + + for _, c := range instances { + if !nodes.Exist(c.ID()) { + continue + } + res = append(res, c) + } + + return +} diff --git a/pkg/cluster/spec/instance.go b/pkg/cluster/spec/instance.go index cef942b..506fabe 100644 --- a/pkg/cluster/spec/instance.go +++ b/pkg/cluster/spec/instance.go @@ -24,8 +24,12 @@ import ( "strings" "time" + "github.com/google/uuid" "github.com/openGemini/gemix/pkg/cluster/ctxt" + "github.com/openGemini/gemix/pkg/cluster/module" + system "github.com/openGemini/gemix/pkg/cluster/template/systemd" "github.com/openGemini/gemix/pkg/meta" + "github.com/pkg/errors" ) // Components names @@ -52,7 +56,7 @@ type Component interface { type Instance interface { InstanceSpec ID() string - //Ready(context.Context, ctxt.Executor, uint64, *tls.Config) error + Ready(context.Context, ctxt.Executor, uint64, *tls.Config) error InitConfig(ctx context.Context, e ctxt.Executor, clusterName string, clusterVersion string, deployUser string, paths meta.DirPaths) error //ScaleConfig(ctx context.Context, e ctxt.Executor, topo Topology, clusterName string, clusterVersion string, deployUser string, paths meta.DirPaths) error PrepareStart(ctx context.Context, tlsCfg *tls.Config) error @@ -75,6 +79,28 @@ type Instance interface { Arch() string } +// PortStarted wait until a port is being listened +func PortStarted(ctx context.Context, e ctxt.Executor, port int, timeout uint64) error { + c := module.WaitForConfig{ + Port: port, + State: "started", + Timeout: time.Second * time.Duration(timeout), + } + w := module.NewWaitFor(c) + return w.Execute(ctx, e) +} + +// PortStopped wait until a port is being released +func PortStopped(ctx context.Context, e ctxt.Executor, port int, timeout uint64) error { + c := module.WaitForConfig{ + Port: port, + State: "stopped", + Timeout: time.Second * time.Duration(timeout), + } + w := module.NewWaitFor(c) + return w.Execute(ctx, e) +} + // BaseInstance implements some method of Instance interface.. type BaseInstance struct { InstanceSpec @@ -94,14 +120,38 @@ type BaseInstance struct { } // Ready implements Instance interface -//func (i *BaseInstance) Ready(ctx context.Context, e ctxt.Executor, timeout uint64, _ *tls.Config) error { -// return PortStarted(ctx, e, i.Port, timeout) -//} +func (i *BaseInstance) Ready(ctx context.Context, e ctxt.Executor, timeout uint64, _ *tls.Config) error { + return PortStarted(ctx, e, i.Port, timeout) +} // InitConfig init the service configuration. -//func (i *BaseInstance) InitConfig(ctx context.Context, e ctxt.Executor, opt GlobalOptions, user string, paths meta.DirPaths) (err error) { -// return nil -//} +func (i *BaseInstance) InitConfig(ctx context.Context, e ctxt.Executor, opt GlobalOptions, user string, paths meta.DirPaths) (err error) { + comp := i.ComponentName() + host := i.GetHost() + port := i.GetPort() + sysCfg := filepath.Join(paths.Cache, fmt.Sprintf("%s-%s-%d.service", comp, host, port)) + + resource := MergeResourceControl(opt.ResourceControl, i.ResourceControl()) + systemCfg := system.NewConfig(comp, user, paths.Deploy). + WithMemoryLimit(resource.MemoryLimit). + WithCPUQuota(resource.CPUQuota). + WithLimitCORE(resource.LimitCORE). + WithIOReadBandwidthMax(resource.IOReadBandwidthMax). + WithIOWriteBandwidthMax(resource.IOWriteBandwidthMax) + + if err = systemCfg.ConfigToFile(sysCfg); err != nil { + return errors.WithStack(err) + } + tgt := filepath.Join("/tmp", comp+"_"+uuid.New().String()+".service") + if err := e.Transfer(ctx, sysCfg, tgt, false, 0, false); err != nil { + return errors.WithMessagef(err, "transfer from %s to %s failed", sysCfg, tgt) + } + cmd := fmt.Sprintf("mv %s /etc/systemd/system/%s-%d.service", tgt, comp, port) + if _, _, err := e.Execute(ctx, cmd, true); err != nil { + return errors.WithMessagef(err, "execute: %s", cmd) + } + return nil +} // MergeServerConfig merges the server configuration and overwrite the global configuration func (i *BaseInstance) MergeServerConfig(ctx context.Context, e ctxt.Executor, globalConf, instanceConf map[string]any, paths meta.DirPaths) error { @@ -268,6 +318,34 @@ func (i *BaseInstance) PrepareStart(ctx context.Context, tlsCfg *tls.Config) err return nil } +// MergeResourceControl merge the rhs into lhs and overwrite rhs if lhs has value for same field +func MergeResourceControl(lhs, rhs meta.ResourceControl) meta.ResourceControl { + if rhs.MemoryLimit != "" { + lhs.MemoryLimit = rhs.MemoryLimit + } + if rhs.CPUQuota != "" { + lhs.CPUQuota = rhs.CPUQuota + } + if rhs.IOReadBandwidthMax != "" { + lhs.IOReadBandwidthMax = rhs.IOReadBandwidthMax + } + if rhs.IOWriteBandwidthMax != "" { + lhs.IOWriteBandwidthMax = rhs.IOWriteBandwidthMax + } + if rhs.LimitCORE != "" { + lhs.LimitCORE = rhs.LimitCORE + } + return lhs +} + +// ResourceControl return cgroups config of instance +func (i *BaseInstance) ResourceControl() meta.ResourceControl { + if v := reflect.Indirect(reflect.ValueOf(i.InstanceSpec)).FieldByName("ResourceControl"); v.IsValid() { + return v.Interface().(meta.ResourceControl) + } + return meta.ResourceControl{} +} + // GetPort implements Instance interface func (i *BaseInstance) GetPort() int { return i.Port diff --git a/pkg/cluster/spec/parse_topology.go b/pkg/cluster/spec/parse_topology.go index 9c6f75d..deb80d6 100644 --- a/pkg/cluster/spec/parse_topology.go +++ b/pkg/cluster/spec/parse_topology.go @@ -76,7 +76,7 @@ func ReadFromYaml(file string) (*Specification, error) { // ParseTopologyYaml read yaml content from `file` and unmarshal it to `out` // ignoreGlobal ignore global variables in file, only ignoreGlobal with a index of 0 is effective -func ParseTopologyYaml(file string, out *Specification, ignoreGlobal ...bool) error { +func ParseTopologyYaml(file string, out Topology, ignoreGlobal ...bool) error { zap.L().Debug("Parse topology file", zap.String("file", file)) yamlFile, err := ReadYamlFile(file) @@ -126,11 +126,12 @@ func Abs(user, path string) string { } // ExpandRelativeDir fill DeployDir, DataDir and LogDir to absolute path -func ExpandRelativeDir(topo *Specification) { +func ExpandRelativeDir(topo Topology) { expandRelativePath(deployUser(topo), topo) } -func expandRelativePath(user string, topo *Specification) { +func expandRelativePath(user string, topology Topology) { + topo := topology.(*Specification) topo.GlobalOptions.DeployDir = Abs(user, topo.GlobalOptions.DeployDir) topo.GlobalOptions.LogDir = Abs(user, topo.GlobalOptions.LogDir) @@ -155,7 +156,7 @@ func expandRelativePath(user string, topo *Specification) { } } -func deployUser(topo *Specification) string { +func deployUser(topo Topology) string { base := topo.BaseTopo() if base.GlobalOptions == nil || base.GlobalOptions.User == "" { return defaultDeployUser diff --git a/pkg/cluster/spec/profile.go b/pkg/cluster/spec/profile.go index 5cc87b1..faeadd6 100644 --- a/pkg/cluster/spec/profile.go +++ b/pkg/cluster/spec/profile.go @@ -65,7 +65,7 @@ func Initialize(base string) error { profileDir = filepath.Join(homeDir, localdata.ProfileDirName, localdata.StorageParentDir, base) } - clusterBaseDir := filepath.Join(profileDir, OpenGeminiClusterDir) // TODO: how to use? + clusterBaseDir := filepath.Join(profileDir, OpenGeminiClusterDir) openGeminiSpec = NewSpec(clusterBaseDir, func() *ClusterMeta { return &ClusterMeta{ Topology: new(Specification), diff --git a/pkg/cluster/spec/spec.go b/pkg/cluster/spec/spec.go index 81bea8f..51ab602 100644 --- a/pkg/cluster/spec/spec.go +++ b/pkg/cluster/spec/spec.go @@ -22,6 +22,7 @@ import ( "sync" "github.com/creasty/defaults" + "github.com/openGemini/gemix/pkg/meta" "github.com/openGemini/gemix/utils" "github.com/pkg/errors" ) @@ -46,16 +47,16 @@ type ( // GlobalOptions represents the global options for all groups in topology // specification in topology.yaml GlobalOptions struct { - User string `yaml:"user,omitempty" default:"gemini"` - Group string `yaml:"group,omitempty"` - SSHPort int `yaml:"ssh_port,omitempty" default:"22" validate:"ssh_port:editable"` - TLSEnabled bool `yaml:"enable_tls,omitempty"` - DeployDir string `yaml:"deploy_dir,omitempty" default:"deploy"` - DataDir string `yaml:"data_dir,omitempty" default:"data"` - LogDir string `yaml:"log_dir,omitempty" default:"data"` - //ResourceControl meta.ResourceControl `yaml:"resource_control,omitempty" validate:"resource_control:editable"` - OS string `yaml:"os,omitempty" default:"linux"` - Arch string `yaml:"arch,omitempty" default:"amd64"` + User string `yaml:"user,omitempty" default:"gemini"` + Group string `yaml:"group,omitempty"` + SSHPort int `yaml:"ssh_port,omitempty" default:"22" validate:"ssh_port:editable"` + TLSEnabled bool `yaml:"enable_tls,omitempty"` + DeployDir string `yaml:"deploy_dir,omitempty" default:"deploy"` + DataDir string `yaml:"data_dir,omitempty" default:"data"` + LogDir string `yaml:"log_dir,omitempty" default:"data"` + ResourceControl meta.ResourceControl `yaml:"resource_control,omitempty" validate:"resource_control:editable"` + OS string `yaml:"os,omitempty" default:"linux"` + Arch string `yaml:"arch,omitempty" default:"amd64"` //Custom any `yaml:"custom,omitempty" validate:"custom:ignore"` } @@ -81,6 +82,31 @@ type ( } ) +// Topology represents specification of the cluster. +type Topology interface { + BaseTopo() *BaseTopo + // Validate validates the topology specification and produce error if the + // specification invalid (e.g: port conflicts or directory conflicts) + Validate() error + + ComponentsByStartOrder() []Component + ComponentsByStopOrder() []Component + //ComponentsByUpdateOrder(curVer string) []Component + IterInstance(fn func(instance Instance), concurrency ...int) + //CountDir(host string, dir string) int // count how many time a path is used by instances in cluster + //TLSConfig(dir string) (*tls.Config, error) + //Merge(that Topology) Topology // TODO: for update + FillHostArchOrOS(hostArchmap map[string]string, fullType FullHostType) error + //GetGrafanaConfig() map[string]string +} + +type BaseTopo struct { + GlobalOptions *GlobalOptions + MasterList []string + + //Grafanas []*GrafanaSpec +} + // BaseMeta is the base info of metadata. type BaseMeta struct { User string @@ -88,6 +114,13 @@ type BaseMeta struct { Version string } +// Metadata of a cluster. +type Metadata interface { + GetTopology() Topology + SetTopology(topo Topology) + GetBaseMeta() *BaseMeta +} + // UnmarshalYAML implements the yaml.Unmarshaler interface, // it sets the default values when unmarshaling the topology file func (s *Specification) UnmarshalYAML(unmarshal func(any) error) error { @@ -282,13 +315,6 @@ func setCustomDefaults(globalOptions *GlobalOptions, field reflect.Value) error return nil } -type BaseTopo struct { - GlobalOptions *GlobalOptions - MasterList []string - - //Grafanas []*GrafanaSpec -} - // GetTSMetaListWithManageHost returns a list of PD API hosts of the current cluster func (s *Specification) GetTSMetaListWithManageHost() []string { var tsMetaList []string diff --git a/pkg/cluster/spec/spec_manager.go b/pkg/cluster/spec/spec_manager.go index 1e5166c..547808e 100644 --- a/pkg/cluster/spec/spec_manager.go +++ b/pkg/cluster/spec/spec_manager.go @@ -15,8 +15,10 @@ package spec import ( + "fmt" "os" "path/filepath" + "reflect" "github.com/joomcode/errorx" "github.com/openGemini/gemix/pkg/gui" @@ -100,6 +102,23 @@ func (s *SpecManager) SaveMeta(clusterName string, meta *ClusterMeta) error { return nil } +// Metadata tries to read the metadata of a cluster from file +func (s *SpecManager) Metadata(clusterName string, meta any) error { + fname := s.Path(clusterName, metaFileName) + + yamlFile, err := os.ReadFile(fname) + if err != nil { + return errors.WithStack(err) + } + + err = yaml.Unmarshal(yamlFile, meta) + if err != nil { + return errors.WithStack(err) + } + + return nil +} + // Exist checks if the cluster exist by checking the meta file. func (s *SpecManager) Exist(clusterName string) (exist bool, err error) { fname := s.Path(clusterName, metaFileName) @@ -143,10 +162,20 @@ type ClusterMeta struct { } // GetTopology implement Metadata interface. -func (m *ClusterMeta) GetTopology() *Specification { +func (m *ClusterMeta) GetTopology() Topology { return m.Topology } +// SetTopology implement Metadata interface. +func (m *ClusterMeta) SetTopology(topo Topology) { + tidbTopo, ok := topo.(*Specification) + if !ok { + panic(fmt.Sprintln("wrong type: ", reflect.TypeOf(topo))) + } + + m.Topology = tidbTopo +} + // GetBaseMeta implements Metadata interface. func (m *ClusterMeta) GetBaseMeta() *BaseMeta { return &BaseMeta{ diff --git a/pkg/cluster/spec/ts_meta.go b/pkg/cluster/spec/ts_meta.go index 58ad12d..821f693 100644 --- a/pkg/cluster/spec/ts_meta.go +++ b/pkg/cluster/spec/ts_meta.go @@ -144,9 +144,9 @@ type TSMetaInstance struct { func (i *TSMetaInstance) InitConfig(ctx context.Context, e ctxt.Executor, clusterName string, clusterVersion string, deployUser string, paths meta.DirPaths) error { topo := i.topo - //if err := i.BaseInstance.InitConfig(ctx, e, topo.GlobalOptions, deployUser, paths); err != nil { - // return err - //} + if err := i.BaseInstance.InitConfig(ctx, e, topo.GlobalOptions, deployUser, paths); err != nil { + return err + } spec := i.InstanceSpec.(*TSMetaSpec) cfg := &scripts.TSMetaScript{ @@ -191,7 +191,7 @@ func (i *TSMetaInstance) SetDefaultConfig(instanceConf map[string]any) map[strin var metaPeerAddrs []string var tsMetaSpec *TSMetaSpec for _, metaSpec := range i.topo.TSMetaServers { - if i.Host == metaSpec.Host { + if i.Host == metaSpec.Host && i.Port == metaSpec.ClientPort { tsMetaSpec = metaSpec } metaPeerAddrs = append(metaPeerAddrs, utils.JoinHostPort(metaSpec.Host, metaSpec.PeerPort)) diff --git a/pkg/cluster/spec/ts_sql.go b/pkg/cluster/spec/ts_sql.go index 5c35557..cbd9b35 100644 --- a/pkg/cluster/spec/ts_sql.go +++ b/pkg/cluster/spec/ts_sql.go @@ -138,9 +138,9 @@ type TSSqlInstance struct { func (i *TSSqlInstance) InitConfig(ctx context.Context, e ctxt.Executor, clusterName string, clusterVersion string, deployUser string, paths meta.DirPaths) error { topo := i.topo - //if err := i.BaseInstance.InitConfig(ctx, e, topo.GlobalOptions, deployUser, paths); err != nil { - // return err - //} + if err := i.BaseInstance.InitConfig(ctx, e, topo.GlobalOptions, deployUser, paths); err != nil { + return err + } spec := i.InstanceSpec.(*TSSqlSpec) cfg := &scripts.TSSqlScript{ diff --git a/pkg/cluster/spec/ts_store.go b/pkg/cluster/spec/ts_store.go index b7aca5c..5882a2e 100644 --- a/pkg/cluster/spec/ts_store.go +++ b/pkg/cluster/spec/ts_store.go @@ -109,7 +109,7 @@ func (c *TSStoreComponent) Instances() []Instance { Host: s.Host, ManageHost: s.ManageHost, ListenHost: s.ListenHost, - Port: s.SelectPort, + Port: s.SelectPort, // do not change me SSHP: s.SSHPort, Source: s.GetSource(), @@ -143,9 +143,9 @@ type TSStoreInstance struct { func (i *TSStoreInstance) InitConfig(ctx context.Context, e ctxt.Executor, clusterName string, clusterVersion string, deployUser string, paths meta.DirPaths) error { topo := i.topo - //if err := i.BaseInstance.InitConfig(ctx, e, topo.GlobalOptions, deployUser, paths); err != nil { - // return err - //} + if err := i.BaseInstance.InitConfig(ctx, e, topo.GlobalOptions, deployUser, paths); err != nil { + return err + } //enableTLS := topo.GlobalOptions.TLSEnabled spec := i.InstanceSpec.(*TSStoreSpec) @@ -199,7 +199,7 @@ func (i *TSStoreInstance) SetDefaultConfig(instanceConf map[string]any) map[stri var tsStoreSpec *TSStoreSpec for _, storeSpec := range i.topo.TSStoreServers { - if i.Host == storeSpec.Host { + if i.Host == storeSpec.Host && i.Port == storeSpec.SelectPort { tsStoreSpec = storeSpec } } diff --git a/pkg/cluster/spec/validate.go b/pkg/cluster/spec/validate.go new file mode 100644 index 0000000..ee5bec1 --- /dev/null +++ b/pkg/cluster/spec/validate.go @@ -0,0 +1,21 @@ +// Copyright 2023 Huawei Cloud Computing Technologies Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spec + +// Validate validates the topology specification and produce error if the +// specification invalid (e.g: port conflicts or directory conflicts) +func (s *Specification) Validate() error { + return nil +} diff --git a/pkg/cluster/task/builder.go b/pkg/cluster/task/builder.go index a879135..119b59d 100644 --- a/pkg/cluster/task/builder.go +++ b/pkg/cluster/task/builder.go @@ -14,6 +14,8 @@ package task import ( + "context" + "github.com/openGemini/gemix/pkg/cluster/spec" "github.com/openGemini/gemix/pkg/meta" "go.uber.org/zap" @@ -70,6 +72,34 @@ func (b *Builder) UserSSH(host string, port int, deployUser string, sshTimeout, return b } +// Func append a func task. +func (b *Builder) Func(name string, fn func(ctx context.Context) error) *Builder { + b.tasks = append(b.tasks, &Func{ + name: name, + fn: fn, + }) + return b +} + +// ClusterSSH init all UserSSH need for the cluster. +func (b *Builder) ClusterSSH( + topo spec.Topology, + deployUser string, sshTimeout, exeTimeout uint64) *Builder { + var tasks []Task + topo.IterInstance(func(inst spec.Instance) { + tasks = append(tasks, &UserSSH{ + host: inst.GetManageHost(), + port: inst.GetSSHPort(), + deployUser: deployUser, + timeout: sshTimeout, + exeTimeout: exeTimeout, + }) + }) + + b.tasks = append(b.tasks, &Parallel{inner: tasks}) + return b +} + // Download appends a Downloader task to the current task collection func (b *Builder) Download(component, os, arch string, version string) *Builder { b.tasks = append(b.tasks, NewDownloader(component, os, arch, version)) @@ -141,24 +171,13 @@ func (b *Builder) SSHKeyGen(keypath string) *Builder { } // SSHKeySet appends a SSHKeySet task to the current task collection -//func (b *Builder) SSHKeySet(privKeyPath, pubKeyPath string) *Builder { -// b.tasks = append(b.tasks, &SSHKeySet{ -// privateKeyPath: privKeyPath, -// publicKeyPath: pubKeyPath, -// }) -// return b -//} - -// EnvInit appends a EnvInit task to the current task collection -//func (b *Builder) EnvInit(host, deployUser string, userGroup string, skipCreateUser bool) *Builder { -// b.tasks = append(b.tasks, &EnvInit{ -// host: host, -// deployUser: deployUser, -// userGroup: userGroup, -// skipCreateUser: skipCreateUser, -// }) -// return b -//} +func (b *Builder) SSHKeySet(privKeyPath, pubKeyPath string) *Builder { + b.tasks = append(b.tasks, &SSHKeySet{ + privateKeyPath: privKeyPath, + publicKeyPath: pubKeyPath, + }) + return b +} // Mkdir appends a Mkdir task to the current task collection func (b *Builder) Mkdir(user, host string, dirs ...string) *Builder { diff --git a/pkg/cluster/task/func.go b/pkg/cluster/task/func.go new file mode 100644 index 0000000..1880448 --- /dev/null +++ b/pkg/cluster/task/func.go @@ -0,0 +1,45 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package task + +import "context" + +// Func wrap a closure. +type Func struct { + name string + fn func(ctx context.Context) error +} + +// NewFunc create a Func task +func NewFunc(name string, fn func(ctx context.Context) error) *Func { + return &Func{ + name: name, + fn: fn, + } +} + +// Execute implements the Task interface +func (m *Func) Execute(ctx context.Context) error { + return m.fn(ctx) +} + +// Rollback implements the Task interface +func (m *Func) Rollback(_ context.Context) error { + return ErrUnsupportedRollback +} + +// String implements the fmt.Stringer interface +func (m *Func) String() string { + return m.name +} diff --git a/pkg/cluster/task/ssh.go b/pkg/cluster/task/ssh.go index c6dba36..b050df1 100644 --- a/pkg/cluster/task/ssh.go +++ b/pkg/cluster/task/ssh.go @@ -99,9 +99,9 @@ type UserSSH struct { // Execute implements the Task interface func (s *UserSSH) Execute(ctx context.Context) error { sc := executor.SSHConfig{ - Host: s.host, - Port: s.port, - //KeyFile: ctxt.GetInner(ctx).PrivateKeyPath, + Host: s.host, + Port: s.port, + KeyFile: ctxt.GetInner(ctx).PrivateKeyPath, Password: s.proxyPassword, User: s.deployUser, Timeout: time.Second * time.Duration(s.timeout), diff --git a/pkg/cluster/task/ssh_keyset.go b/pkg/cluster/task/ssh_keyset.go new file mode 100644 index 0000000..7907cdc --- /dev/null +++ b/pkg/cluster/task/ssh_keyset.go @@ -0,0 +1,46 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package task + +import ( + "context" + "fmt" + + "github.com/openGemini/gemix/pkg/cluster/ctxt" +) + +// SSHKeySet is used to set the Context private/public key path +type SSHKeySet struct { + privateKeyPath string + publicKeyPath string +} + +// Execute implements the Task interface +func (s *SSHKeySet) Execute(ctx context.Context) error { + ctxt.GetInner(ctx).PublicKeyPath = s.publicKeyPath + ctxt.GetInner(ctx).PrivateKeyPath = s.privateKeyPath + return nil +} + +// Rollback implements the Task interface +func (s *SSHKeySet) Rollback(ctx context.Context) error { + ctxt.GetInner(ctx).PublicKeyPath = "" + ctxt.GetInner(ctx).PrivateKeyPath = "" + return nil +} + +// String implements the fmt.Stringer interface +func (s *SSHKeySet) String() string { + return fmt.Sprintf("SSHKeySet: privateKey=%s, publicKey=%s", s.privateKeyPath, s.publicKeyPath) +} diff --git a/pkg/cluster/template/systemd/system.go b/pkg/cluster/template/systemd/system.go new file mode 100644 index 0000000..2865a8c --- /dev/null +++ b/pkg/cluster/template/systemd/system.go @@ -0,0 +1,114 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package system + +import ( + "bytes" + "os" + "path" + "strings" + "text/template" + + "github.com/openGemini/gemix/embed" +) + +// Config represent the data to generate systemd config +type Config struct { + ServiceName string + User string + MemoryLimit string + CPUQuota string + IOReadBandwidthMax string + IOWriteBandwidthMax string + LimitCORE string + DeployDir string + DisableSendSigkill bool + GrantCapNetRaw bool + // Takes one of no, on-success, on-failure, on-abnormal, on-watchdog, on-abort, or always. + // The Template set as always if this is not setted. + Restart string +} + +// NewConfig returns a Config with given arguments +func NewConfig(service, user, deployDir string) *Config { + return &Config{ + ServiceName: strings.ReplaceAll(service, "-", "_"), + User: user, + DeployDir: deployDir, + } +} + +// WithMemoryLimit set the MemoryLimit field of Config +func (c *Config) WithMemoryLimit(mem string) *Config { + c.MemoryLimit = mem + return c +} + +// WithCPUQuota set the CPUQuota field of Config +func (c *Config) WithCPUQuota(cpu string) *Config { + c.CPUQuota = cpu + return c +} + +// WithIOReadBandwidthMax set the IOReadBandwidthMax field of Config +func (c *Config) WithIOReadBandwidthMax(io string) *Config { + c.IOReadBandwidthMax = io + return c +} + +// WithIOWriteBandwidthMax set the IOWriteBandwidthMax field of Config +func (c *Config) WithIOWriteBandwidthMax(io string) *Config { + c.IOWriteBandwidthMax = io + return c +} + +// WithLimitCORE set the LimitCORE field of Config +func (c *Config) WithLimitCORE(core string) *Config { + c.LimitCORE = core + return c +} + +// ConfigToFile write config content to specific path +func (c *Config) ConfigToFile(file string) error { + config, err := c.Config() + if err != nil { + return err + } + return os.WriteFile(file, config, 0750) +} + +// Config generate the config file data. +func (c *Config) Config() ([]byte, error) { + fp := path.Join("templates", "systemd", "system.service.tpl") + tpl, err := embed.ReadTemplate(fp) + if err != nil { + return nil, err + } + return c.ConfigWithTemplate(string(tpl)) +} + +// ConfigWithTemplate generate the system config content by tpl +func (c *Config) ConfigWithTemplate(tpl string) ([]byte, error) { + tmpl, err := template.New("system").Parse(tpl) + if err != nil { + return nil, err + } + + content := bytes.NewBufferString("") + if err := tmpl.Execute(content, c); err != nil { + return nil, err + } + + return content.Bytes(), nil +} diff --git a/pkg/meta/resource_ctrl.go b/pkg/meta/resource_ctrl.go new file mode 100644 index 0000000..b20bec9 --- /dev/null +++ b/pkg/meta/resource_ctrl.go @@ -0,0 +1,24 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package meta + +// ResourceControl is used to control the system resource +// See: https://www.freedesktop.org/software/systemd/man/systemd.resource-control.html +type ResourceControl struct { + MemoryLimit string `yaml:"memory_limit,omitempty" validate:"memory_limit:editable"` + CPUQuota string `yaml:"cpu_quota,omitempty" validate:"cpu_quota:editable"` + IOReadBandwidthMax string `yaml:"io_read_bandwidth_max,omitempty" validate:"io_read_bandwidth_max:editable"` + IOWriteBandwidthMax string `yaml:"io_write_bandwidth_max,omitempty" validate:"io_write_bandwidth_max:editable"` + LimitCORE string `yaml:"limit_core,omitempty" validate:"limit_core:editable"` +} diff --git a/pkg/set/any_set.go b/pkg/set/any_set.go new file mode 100644 index 0000000..a428f60 --- /dev/null +++ b/pkg/set/any_set.go @@ -0,0 +1,90 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package set + +// AnySet is a set stores any +type AnySet struct { + eq func(a any, b any) bool + slice []any +} + +// NewAnySet builds a AnySet +func NewAnySet(eq func(a any, b any) bool, aa ...any) *AnySet { + slice := []any{} +out: + for _, a := range aa { + for _, b := range slice { + if eq(a, b) { + continue out + } + } + slice = append(slice, a) + } + return &AnySet{eq, slice} +} + +// Exist checks whether `val` exists in `s`. +func (s *AnySet) Exist(val any) bool { + for _, a := range s.slice { + if s.eq(a, val) { + return true + } + } + return false +} + +// Insert inserts `val` into `s`. +func (s *AnySet) Insert(val any) { + if !s.Exist(val) { + s.slice = append(s.slice, val) + } +} + +// Intersection returns the intersection of two sets +func (s *AnySet) Intersection(rhs *AnySet) *AnySet { + newSet := NewAnySet(s.eq) + for elt := range rhs.slice { + if s.Exist(elt) { + newSet.Insert(elt) + } + } + return newSet +} + +// Remove removes `val` from `s` +func (s *AnySet) Remove(val any) { + for i, a := range s.slice { + if s.eq(a, val) { + s.slice = append(s.slice[:i], s.slice[i+1:]...) + return + } + } +} + +// Difference returns the difference of two sets +func (s *AnySet) Difference(rhs *AnySet) *AnySet { + newSet := NewAnySet(s.eq) + diffSet := NewAnySet(s.eq, rhs.slice...) + for elt := range s.slice { + if !diffSet.Exist(elt) { + newSet.Insert(elt) + } + } + return newSet +} + +// Slice converts the set to a slice +func (s *AnySet) Slice() []any { + return append([]any{}, s.slice...) +} diff --git a/pkg/set/string_set.go b/pkg/set/string_set.go new file mode 100644 index 0000000..a8784e5 --- /dev/null +++ b/pkg/set/string_set.go @@ -0,0 +1,81 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package set + +// StringSet is a string set. +type StringSet map[string]struct{} + +// NewStringSet builds a string set. +func NewStringSet(ss ...string) StringSet { + set := make(StringSet) + for _, s := range ss { + set.Insert(s) + } + return set +} + +// Exist checks whether `val` exists in `s`. +func (s StringSet) Exist(val string) bool { + _, ok := s[val] + return ok +} + +// Insert inserts `val` into `s`. +func (s StringSet) Insert(val string) { + s[val] = struct{}{} +} + +// Join add all elements of `add` to `s`. +func (s StringSet) Join(add StringSet) StringSet { + for elt := range add { + s.Insert(elt) + } + return s +} + +// Intersection returns the intersection of two sets +func (s StringSet) Intersection(rhs StringSet) StringSet { + newSet := NewStringSet() + for elt := range s { + if rhs.Exist(elt) { + newSet.Insert(elt) + } + } + return newSet +} + +// Remove removes `val` from `s` +func (s StringSet) Remove(val string) { + delete(s, val) +} + +// Difference returns the difference of two sets +func (s StringSet) Difference(rhs StringSet) StringSet { + newSet := NewStringSet() + for elt := range s { + if !rhs.Exist(elt) { + newSet.Insert(elt) + } + } + return newSet +} + +// Slice converts the set to a slice +func (s StringSet) Slice() []string { + res := make([]string, 0) + for val := range s { + res = append(res, val) + } + return res +} diff --git a/pkg/utils/retry.go b/pkg/utils/retry.go new file mode 100644 index 0000000..3a1f409 --- /dev/null +++ b/pkg/utils/retry.go @@ -0,0 +1,112 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package utils + +import ( + "fmt" + "strings" + "time" +) + +// RetryUntil when the `when` func returns true +func RetryUntil(f func() error, when func(error) bool) error { + e := f() + if e == nil { + return nil + } + if when == nil { + return RetryUntil(f, nil) + } else if when(e) { + return RetryUntil(f, when) + } + return e +} + +// RetryOption is options for Retry() +type RetryOption struct { + Attempts int64 + Delay time.Duration + Timeout time.Duration +} + +// default values for RetryOption +var ( + defaultAttempts int64 = 20 + defaultDelay = time.Millisecond * 500 // 500ms + defaultTimeout = time.Second * 10 // 10s +) + +// Retry retries the func until it returns no error or reaches attempts limit or +// timed out, either one is earlier +func Retry(doFunc func() error, opts ...RetryOption) error { + var cfg RetryOption + if len(opts) > 0 { + cfg = opts[0] + } else { + cfg = RetryOption{ + Attempts: defaultAttempts, + Delay: defaultDelay, + Timeout: defaultTimeout, + } + } + + // timeout must be greater than 0 + if cfg.Timeout <= 0 { + return fmt.Errorf("timeout (%s) must be greater than 0", cfg.Timeout) + } + // set options automatically for invalid value + if cfg.Delay <= 0 { + cfg.Delay = defaultDelay + } + if cfg.Attempts <= 0 { + cfg.Attempts = cfg.Timeout.Milliseconds()/cfg.Delay.Milliseconds() + 1 + } + + timeoutChan := time.After(cfg.Timeout) + + // call the function + var attemptCount int64 + var err error + for attemptCount = 0; attemptCount < cfg.Attempts; attemptCount++ { + if err = doFunc(); err == nil { + return nil + } + + // check for timeout + select { + case <-timeoutChan: + return fmt.Errorf("operation timed out after %s", cfg.Timeout) + default: + time.Sleep(cfg.Delay) + } + } + + return fmt.Errorf("operation exceeds the max retry attempts of %d. error of last attempt: %s", cfg.Attempts, err) +} + +// IsTimeoutOrMaxRetry return true if it's timeout or reach max retry. +func IsTimeoutOrMaxRetry(err error) bool { + if err == nil { + return false + } + + s := err.Error() + + if strings.Contains(s, "operation timed out after") || + strings.Contains(s, "operation exceeds the max retry attempts of") { + return true + } + + return false +}