From a3fa2c3986129cc44d3dc7c722431ec3a7d96a1d Mon Sep 17 00:00:00 2001 From: Vadym Popov Date: Tue, 15 Oct 2024 17:35:00 -0400 Subject: [PATCH] Replace fork for posix platform for re-exec Move integration tests to client tools specific dir Use context cancellation with SIGTERM, SIGINT Remove cancelable tee reader with context replacement Renaming --- .../{ => tools}/client_update_test.go | 2 +- .../autoupdate/{ => tools}/helper_test.go | 2 +- .../{ => tools}/helper_unix_test.go | 2 +- .../{ => tools}/helper_windows_test.go | 2 +- .../autoupdate/{ => tools}/main_test.go | 2 +- .../autoupdate/{ => tools}/updater/main.go | 2 +- lib/autoupdate/client_update.go | 69 ++++++++++++------- lib/autoupdate/progress.go | 38 ---------- lib/autoupdate/utils.go | 2 +- 9 files changed, 50 insertions(+), 71 deletions(-) rename integration/autoupdate/{ => tools}/client_update_test.go (99%) rename integration/autoupdate/{ => tools}/helper_test.go (99%) rename integration/autoupdate/{ => tools}/helper_unix_test.go (97%) rename integration/autoupdate/{ => tools}/helper_windows_test.go (98%) rename integration/autoupdate/{ => tools}/main_test.go (99%) rename integration/autoupdate/{ => tools}/updater/main.go (97%) diff --git a/integration/autoupdate/client_update_test.go b/integration/autoupdate/tools/client_update_test.go similarity index 99% rename from integration/autoupdate/client_update_test.go rename to integration/autoupdate/tools/client_update_test.go index 0fb76bef5c3d2..d30516705252e 100644 --- a/integration/autoupdate/client_update_test.go +++ b/integration/autoupdate/tools/client_update_test.go @@ -16,7 +16,7 @@ * along with this program. If not, see . */ -package autoupdate_test +package tools_test import ( "bytes" diff --git a/integration/autoupdate/helper_test.go b/integration/autoupdate/tools/helper_test.go similarity index 99% rename from integration/autoupdate/helper_test.go rename to integration/autoupdate/tools/helper_test.go index ada3539186a1b..a3c37a9e94b55 100644 --- a/integration/autoupdate/helper_test.go +++ b/integration/autoupdate/tools/helper_test.go @@ -16,7 +16,7 @@ * along with this program. If not, see . */ -package autoupdate_test +package tools_test import ( "net/http" diff --git a/integration/autoupdate/helper_unix_test.go b/integration/autoupdate/tools/helper_unix_test.go similarity index 97% rename from integration/autoupdate/helper_unix_test.go rename to integration/autoupdate/tools/helper_unix_test.go index 632defbc5f29b..61ba0766b90d4 100644 --- a/integration/autoupdate/helper_unix_test.go +++ b/integration/autoupdate/tools/helper_unix_test.go @@ -18,7 +18,7 @@ * along with this program. If not, see . */ -package autoupdate_test +package tools_test import ( "errors" diff --git a/integration/autoupdate/helper_windows_test.go b/integration/autoupdate/tools/helper_windows_test.go similarity index 98% rename from integration/autoupdate/helper_windows_test.go rename to integration/autoupdate/tools/helper_windows_test.go index 89d109b0ad26d..b2ede9ade8c19 100644 --- a/integration/autoupdate/helper_windows_test.go +++ b/integration/autoupdate/tools/helper_windows_test.go @@ -18,7 +18,7 @@ * along with this program. If not, see . */ -package autoupdate_test +package tools_test import ( "syscall" diff --git a/integration/autoupdate/main_test.go b/integration/autoupdate/tools/main_test.go similarity index 99% rename from integration/autoupdate/main_test.go rename to integration/autoupdate/tools/main_test.go index 4df033950c450..06f36894217d8 100644 --- a/integration/autoupdate/main_test.go +++ b/integration/autoupdate/tools/main_test.go @@ -16,7 +16,7 @@ * along with this program. If not, see . */ -package autoupdate_test +package tools_test import ( "context" diff --git a/integration/autoupdate/updater/main.go b/integration/autoupdate/tools/updater/main.go similarity index 97% rename from integration/autoupdate/updater/main.go rename to integration/autoupdate/tools/updater/main.go index 63c82adf1d13c..ba6663f8a8165 100644 --- a/integration/autoupdate/updater/main.go +++ b/integration/autoupdate/tools/updater/main.go @@ -52,7 +52,7 @@ func main() { // Download and update the version of client tools required by the cluster. // This is required if the user passed in the TELEPORT_TOOLS_VERSION explicitly. err := updater.UpdateWithLock(ctx, toolsVersion) - if errors.Is(err, autoupdate.ErrCanceled) { + if errors.Is(err, context.Canceled) { os.Exit(0) return } diff --git a/lib/autoupdate/client_update.go b/lib/autoupdate/client_update.go index ac089b073e120..534ad4aeadf40 100644 --- a/lib/autoupdate/client_update.go +++ b/lib/autoupdate/client_update.go @@ -26,9 +26,11 @@ import ( "encoding/hex" "fmt" "io" + "log/slog" "net/http" "os" "os/exec" + "os/signal" "path/filepath" "regexp" "runtime" @@ -47,13 +49,13 @@ import ( const ( // teleportToolsVersionEnv is environment name for requesting specific version for update. teleportToolsVersionEnv = "TELEPORT_TOOLS_VERSION" - // baseUrl is CDN URL for downloading official Teleport packages. - baseUrl = "https://cdn.teleport.dev" + // baseURL is CDN URL for downloading official Teleport packages. + baseURL = "https://cdn.teleport.dev" // checksumHexLen is length of the hash sum. checksumHexLen = 64 // reservedFreeDisk is the predefined amount of free disk space (in bytes) required // to remain available after downloading archives. - reservedFreeDisk = 10 * 1024 * 1024 + reservedFreeDisk = 10 * 1024 * 1024 // 10 Mb // lockFileName is file used for locking update process in parallel. lockFileName = ".lock" // updatePackageSuffix is directory suffix used for package extraction in tools directory. @@ -65,18 +67,18 @@ var ( pattern = regexp.MustCompile(`(?m)Teleport v(.*) git`) ) -// Option applies an option value for the ClientUpdater. -type Option func(u *ClientUpdater) +// ClientOption applies an option value for the ClientUpdater. +type ClientOption func(u *ClientUpdater) // WithBaseURL defines custom base url for the updater. -func WithBaseURL(baseUrl string) Option { +func WithBaseURL(baseUrl string) ClientOption { return func(u *ClientUpdater) { u.baseUrl = baseUrl } } // WithClient defines custom http client for the ClientUpdater. -func WithClient(client *http.Client) Option { +func WithClient(client *http.Client) ClientOption { return func(u *ClientUpdater) { u.client = client } @@ -93,12 +95,12 @@ type ClientUpdater struct { } // NewClientUpdater initiate updater for the client tools auto updates. -func NewClientUpdater(tools []string, toolsDir string, localVersion string, options ...Option) *ClientUpdater { +func NewClientUpdater(tools []string, toolsDir string, localVersion string, options ...ClientOption) *ClientUpdater { updater := &ClientUpdater{ tools: tools, toolsDir: toolsDir, localVersion: localVersion, - baseUrl: baseUrl, + baseUrl: baseURL, client: http.DefaultClient, } for _, option := range options { @@ -123,7 +125,7 @@ func (u *ClientUpdater) CheckLocal() (string, bool) { // If a version of client tools has already been downloaded to // tools directory, return that. - toolsVersion, err := version(u.toolsDir) + toolsVersion, err := checkClientToolVersion(u.toolsDir) if err != nil { return "", false } @@ -165,7 +167,7 @@ func (u *ClientUpdater) CheckRemote(ctx context.Context, proxyAddr string) (stri // If a version of client tools has already been downloaded to // tools directory, return that. - toolsVersion, err := version(u.toolsDir) + toolsVersion, err := checkClientToolVersion(u.toolsDir) if err != nil { return "", false, trace.Wrap(err) } @@ -190,7 +192,7 @@ func (u *ClientUpdater) UpdateWithLock(ctx context.Context, toolsVersion string) if err := os.MkdirAll(u.toolsDir, 0o755); err != nil { return trace.Wrap(err) } - // Lock concurrent {tsh, tctl} execution util requested version is updated. + // Lock concurrent client tools execution util requested version is updated. unlock, err := utils.FSWriteLock(filepath.Join(u.toolsDir, lockFileName)) if err != nil { return trace.Wrap(err) @@ -202,7 +204,7 @@ func (u *ClientUpdater) UpdateWithLock(ctx context.Context, toolsVersion string) // If the version of the running binary or the version downloaded to // tools directory is the same as the requested version of client tools, // nothing to be done, exit early. - teleportVersion, err := version(u.toolsDir) + teleportVersion, err := checkClientToolVersion(u.toolsDir) if err != nil && !trace.IsNotFound(err) { return trace.Wrap(err) @@ -211,7 +213,7 @@ func (u *ClientUpdater) UpdateWithLock(ctx context.Context, toolsVersion string) return nil } - // Download and update {tsh, tctl} in tools directory. + // Download and update client tools in tools directory. if err := u.Update(ctx, toolsVersion); err != nil { return trace.Wrap(err) } @@ -228,17 +230,24 @@ func (u *ClientUpdater) Update(ctx context.Context, toolsVersion string) error { return trace.Wrap(err) } + signalCtx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) + defer cancel() + // Download the archive and validate against the hash. Download to a // temporary path within tools directory. - hash, err := u.downloadHash(ctx, hashURL) + hash, err := u.downloadHash(signalCtx, hashURL) if err != nil { return trace.Wrap(err) } - archivePath, err := u.downloadArchive(ctx, u.toolsDir, archiveURL, hash) + archivePath, err := u.downloadArchive(signalCtx, u.toolsDir, archiveURL, hash) if err != nil { return trace.Wrap(err) } - defer os.Remove(archivePath) + defer func() { + if err := os.Remove(archivePath); err != nil { + slog.WarnContext(ctx, "failed to remove archive", "error", err) + } + }() pkgName := fmt.Sprint(uuid.New().String(), updatePackageSuffix) extractDir := filepath.Join(u.toolsDir, pkgName) @@ -266,19 +275,27 @@ func (u *ClientUpdater) Exec() (int, error) { if err != nil { return 0, trace.Wrap(err) } - - cmd := exec.Command(path, os.Args[1:]...) // To prevent re-execution loop we have to disable update logic for re-execution. - cmd.Env = append(os.Environ(), teleportToolsVersionEnv+"=off") - cmd.Stdin = os.Stdin - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr + env := append(os.Environ(), teleportToolsVersionEnv+"=off") + + if runtime.GOOS == constants.WindowsOS { + cmd := exec.Command(path, os.Args[1:]...) + cmd.Env = env + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return 0, trace.Wrap(err) + } + + return cmd.ProcessState.ExitCode(), nil + } - if err := cmd.Run(); err != nil { + if err := syscall.Exec(path, os.Args, env); err != nil { return 0, trace.Wrap(err) } - return cmd.ProcessState.ExitCode(), nil + return 0, nil } func (u *ClientUpdater) downloadHash(ctx context.Context, url string) (string, error) { @@ -336,7 +353,7 @@ func (u *ClientUpdater) downloadArchive(ctx context.Context, downloadDir string, h := sha256.New() pw := &progressWriter{n: 0, limit: resp.ContentLength} - body := cancelableTeeReader(io.TeeReader(resp.Body, h), pw, syscall.SIGINT, syscall.SIGTERM) + body := io.TeeReader(io.TeeReader(resp.Body, h), pw) // It is a little inefficient to download the file to disk and then re-load // it into memory to unarchive later, but this is safer as it allows {tsh, diff --git a/lib/autoupdate/progress.go b/lib/autoupdate/progress.go index d455c2bfee487..60d94cb65e886 100644 --- a/lib/autoupdate/progress.go +++ b/lib/autoupdate/progress.go @@ -20,47 +20,9 @@ package autoupdate import ( "fmt" - "io" - "os" - "os/signal" "strings" ) -var ( - // ErrCanceled represent the cancellation error for Ctrl-Break/Ctrl-C, depends on the platform. - ErrCanceled = fmt.Errorf("canceled") -) - -// cancelableTeeReader is a copy of TeeReader with ability to react on signal notifier -// to cancel reading process. -func cancelableTeeReader(r io.Reader, w io.Writer, signals ...os.Signal) io.Reader { - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, signals...) - - return &teeReader{r, w, sigs} -} - -type teeReader struct { - r io.Reader - w io.Writer - sigs chan os.Signal -} - -func (t *teeReader) Read(p []byte) (n int, err error) { - select { - case <-t.sigs: - return 0, ErrCanceled - default: - n, err = t.r.Read(p) - if n > 0 { - if n, err := t.w.Write(p[:n]); err != nil { - return n, err - } - } - } - return -} - type progressWriter struct { n int64 limit int64 diff --git a/lib/autoupdate/utils.go b/lib/autoupdate/utils.go index d326888f14f42..04733018e7490 100644 --- a/lib/autoupdate/utils.go +++ b/lib/autoupdate/utils.go @@ -63,7 +63,7 @@ func ToolsDir() (string, error) { return filepath.Join(filepath.Clean(home), ".tsh", "bin"), nil } -func version(toolsDir string) (string, error) { +func checkClientToolVersion(toolsDir string) (string, error) { // Find the path to the current executable. path, err := toolName(toolsDir) if err != nil {