diff --git a/lib/autoupdate/tools/progress.go b/lib/autoupdate/tools/progress.go index 95395003730ec..5e9f7f4416a04 100644 --- a/lib/autoupdate/tools/progress.go +++ b/lib/autoupdate/tools/progress.go @@ -20,24 +20,54 @@ package tools import ( "fmt" + "io" "strings" + + "github.com/gravitational/trace" ) type progressWriter struct { - n int64 - limit int64 + n int64 + limit int64 + size int + progress int } -func (w *progressWriter) Write(p []byte) (int, error) { - w.n = w.n + int64(len(p)) +// newProgressWriter creates progress writer instance and prints empty +// progress bar right after initialisation. +func newProgressWriter(size int) (*progressWriter, func()) { + pw := &progressWriter{size: size} + pw.Print(0) + return pw, func() { + fmt.Print("\n") + } +} - n := int((w.n*100)/w.limit) / 10 - bricks := strings.Repeat("▒", n) + strings.Repeat(" ", 10-n) +// Print prints the update progress bar with `n` bricks. +func (w *progressWriter) Print(n int) { + bricks := strings.Repeat("▒", n) + strings.Repeat(" ", w.size-n) fmt.Print("\rUpdate progress: [" + bricks + "] (Ctrl-C to cancel update)") +} - if w.n == w.limit { - fmt.Print("\n") +func (w *progressWriter) Write(p []byte) (int, error) { + if w.limit == 0 || w.size == 0 { + return 0, io.EOF + } + + w.n += int64(len(p)) + bricks := int((w.n*100)/w.limit) / w.size + if w.progress != bricks { + w.Print(bricks) + w.progress = bricks } return len(p), nil } + +// CopyLimit sets the limit of writing bytes to the progress writer and initiate copying process. +func (w *progressWriter) CopyLimit(dst io.Writer, src io.Reader, limit int64) (written int64, err error) { + w.limit = limit + n, err := io.CopyN(dst, io.TeeReader(src, w), limit) + + return n, trace.Wrap(err) +} diff --git a/lib/autoupdate/tools/updater.go b/lib/autoupdate/tools/updater.go index e75d532d8883e..7b491b9ade792 100644 --- a/lib/autoupdate/tools/updater.go +++ b/lib/autoupdate/tools/updater.go @@ -270,14 +270,6 @@ func (u *Updater) Update(ctx context.Context, toolsVersion string) error { // update downloads the archive and validate against the hash. Download to a // temporary path within tools directory. func (u *Updater) update(ctx context.Context, pkg packageURL, pkgName string) error { - hash, err := u.downloadHash(ctx, pkg.Hash) - if pkg.Optional && trace.IsNotFound(err) { - return nil - } - if err != nil { - return trace.Wrap(err) - } - f, err := os.CreateTemp(u.toolsDir, "tmp-") if err != nil { return trace.Wrap(err) @@ -296,6 +288,15 @@ func (u *Updater) update(ctx context.Context, pkg packageURL, pkgName string) er if err != nil { return trace.Wrap(err) } + + hash, err := u.downloadHash(ctx, pkg.Hash) + if pkg.Optional && trace.IsNotFound(err) { + return nil + } + if err != nil { + return trace.Wrap(err) + } + if !bytes.Equal(archiveHash, hash) { return trace.BadParameter("hash of archive does not match downloaded archive") } @@ -378,6 +379,11 @@ func (u *Updater) downloadHash(ctx context.Context, url string) ([]byte, error) // downloadArchive downloads the archive package by `url` and writes content to the writer interface, // return calculated sha256 hash sum of the content. func (u *Updater) downloadArchive(ctx context.Context, url string, f io.Writer) ([]byte, error) { + // Display a progress bar before initiating the update request to inform the user that + // an update is in progress, allowing them the option to cancel before actual response + // which might be delayed with slow internet connection or complete isolation to CDN. + pw, finish := newProgressWriter(10) + defer finish() req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, trace.Wrap(err) @@ -401,14 +407,10 @@ func (u *Updater) downloadArchive(ctx context.Context, url string, f io.Writer) } h := sha256.New() - pw := &progressWriter{n: 0, limit: resp.ContentLength} - body := io.TeeReader(io.TeeReader(resp.Body, h), pw) - // It is a little inefficient to download the file to disk and then re-load // it into memory to unarchive later, but this is safer as it allows client // tools to validate the hash before trying to operate on the archive. - _, err = io.CopyN(f, body, resp.ContentLength) - if err != nil { + if _, err := pw.CopyLimit(f, io.TeeReader(resp.Body, h), resp.ContentLength); err != nil { return nil, trace.Wrap(err) } diff --git a/lib/client/api.go b/lib/client/api.go index 88bbc3b06ee5f..7733b9d4ba657 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -718,16 +718,17 @@ func RetryWithRelogin(ctx context.Context, tc *TeleportClient, fn func() error, if reExec { // Download the version of client tools required by the cluster. err := updater.UpdateWithLock(ctx, toolsVersion) - if err != nil { - return trace.Wrap(err) + if err != nil && !errors.Is(err, context.Canceled) { + utils.FatalError(err) } - // Re-execute client tools with the correct version of client tools. code, err := updater.Exec() - if err != nil { - return trace.Wrap(err) + if err != nil && !errors.Is(err, os.ErrNotExist) { + log.Debugf("Failed to re-exec client tool: %v.", err) + os.Exit(code) + } else if err == nil { + os.Exit(code) } - os.Exit(code) } if opt.afterLoginHook != nil { diff --git a/lib/utils/packaging/unarchive.go b/lib/utils/packaging/unarchive.go index dee4d8cdb915c..4ace902646f5e 100644 --- a/lib/utils/packaging/unarchive.go +++ b/lib/utils/packaging/unarchive.go @@ -132,7 +132,9 @@ func replaceZip(toolsDir string, archivePath string, extractDir string, execName if err := os.Remove(appPath); err != nil && !os.IsNotExist(err) { return trace.Wrap(err) } - if err := os.Symlink(dest, appPath); err != nil { + // For the Windows build we have to use hard links to be able + // to use client tools without administrative access. + if err := os.Link(dest, appPath); err != nil { return trace.Wrap(err) } return trace.Wrap(destFile.Close()) diff --git a/tool/tctl/common/tctl.go b/tool/tctl/common/tctl.go index e5e8354a20713..0794f0ecca3e5 100644 --- a/tool/tctl/common/tctl.go +++ b/tool/tctl/common/tctl.go @@ -25,8 +25,10 @@ import ( "io/fs" "log/slog" "os" + "os/signal" "path/filepath" "runtime" + "syscall" "time" "github.com/alecthomas/kingpin/v2" @@ -125,15 +127,21 @@ func Run(ctx context.Context, commands []CLICommand) { // is required if the user passed in the TELEPORT_TOOLS_VERSION // explicitly. err := updater.UpdateWithLock(ctx, toolsVersion) - if err != nil { + if errors.Is(err, context.Canceled) { + var cancel context.CancelFunc + ctx, cancel = signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) + defer cancel() + } else if err != nil { utils.FatalError(err) } // Re-execute client tools with the correct version of client tools. code, err := updater.Exec() - if err != nil { - utils.FatalError(err) + if err != nil && !errors.Is(err, os.ErrNotExist) { + log.Debugf("Failed to re-exec client tool: %v.", err) + os.Exit(code) + } else if err == nil { + os.Exit(code) } - os.Exit(code) } err = TryRun(commands, os.Args[1:]) diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index fc732697ead6d..715076638768d 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -712,16 +712,20 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { // Download the version of client tools required by the cluster. This // is required if the user passed in the TELEPORT_TOOLS_VERSION // explicitly. - if err := updater.UpdateWithLock(ctx, toolsVersion); err != nil { + err := updater.UpdateWithLock(ctx, toolsVersion) + if errors.Is(err, context.Canceled) { + var cancel context.CancelFunc + ctx, cancel = signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT) + defer cancel() + } else if err != nil { return trace.Wrap(err) } - // Re-execute client tools with the correct version of client tools. code, err := updater.Exec() - if err != nil { + if err != nil && !errors.Is(err, os.ErrNotExist) { log.Debugf("Failed to re-exec client tool: %v.", err) os.Exit(code) - } else { + } else if err == nil { os.Exit(code) } } @@ -5577,16 +5581,17 @@ func updateAndRun(ctx context.Context, proxy string, insecure bool) error { if reExec { // Download the version of client tools required by the cluster. err := updater.UpdateWithLock(ctx, toolsVersion) - if err != nil { - return trace.Wrap(err) + if err != nil && !errors.Is(err, context.Canceled) { + utils.FatalError(err) } - // Re-execute client tools with the correct version of client tools. code, err := updater.Exec() - if err != nil { - return trace.Wrap(err) + if err != nil && !errors.Is(err, os.ErrNotExist) { + log.Debugf("Failed to re-exec client tool: %v.", err) + os.Exit(code) + } else if err == nil { + os.Exit(code) } - os.Exit(code) } return nil