diff --git a/lib/autoupdate/client_update.go b/lib/autoupdate/client_update.go index 534ad4aeadf40..a82972f571ee3 100644 --- a/lib/autoupdate/client_update.go +++ b/lib/autoupdate/client_update.go @@ -239,7 +239,7 @@ func (u *ClientUpdater) Update(ctx context.Context, toolsVersion string) error { if err != nil { return trace.Wrap(err) } - archivePath, err := u.downloadArchive(signalCtx, u.toolsDir, archiveURL, hash) + archivePath, archiveHash, err := u.downloadArchive(signalCtx, u.toolsDir, archiveURL) if err != nil { return trace.Wrap(err) } @@ -248,6 +248,9 @@ func (u *ClientUpdater) Update(ctx context.Context, toolsVersion string) error { slog.WarnContext(ctx, "failed to remove archive", "error", err) } }() + if archiveHash != hash { + return trace.BadParameter("hash of archive does not match downloaded archive") + } pkgName := fmt.Sprint(uuid.New().String(), updatePackageSuffix) extractDir := filepath.Join(u.toolsDir, pkgName) @@ -291,7 +294,7 @@ func (u *ClientUpdater) Exec() (int, error) { return cmd.ProcessState.ExitCode(), nil } - if err := syscall.Exec(path, os.Args, env); err != nil { + if err := syscall.Exec(path, append([]string{path}, os.Args[1:]...), env); err != nil { return 0, trace.Wrap(err) } @@ -324,23 +327,23 @@ func (u *ClientUpdater) downloadHash(ctx context.Context, url string) (string, e return raw, nil } -func (u *ClientUpdater) downloadArchive(ctx context.Context, downloadDir string, url string, hash string) (string, error) { +func (u *ClientUpdater) downloadArchive(ctx context.Context, downloadDir string, url string) (string, string, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { - return "", trace.Wrap(err) + return "", "", trace.Wrap(err) } resp, err := u.client.Do(req) if err != nil { - return "", trace.Wrap(err) + return "", "", trace.Wrap(err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return "", trace.BadParameter("bad status when downloading archive: %v", resp.StatusCode) + return "", "", trace.BadParameter("bad status when downloading archive: %v", resp.StatusCode) } if resp.ContentLength != -1 { if err := checkFreeSpace(u.toolsDir, uint64(resp.ContentLength)); err != nil { - return "", trace.Wrap(err) + return "", "", trace.Wrap(err) } } @@ -348,7 +351,7 @@ func (u *ClientUpdater) downloadArchive(ctx context.Context, downloadDir string, // occurred. f, err := os.CreateTemp(downloadDir, "tmp-") if err != nil { - return "", trace.Wrap(err) + return "", "", trace.Wrap(err) } h := sha256.New() @@ -356,15 +359,12 @@ func (u *ClientUpdater) downloadArchive(ctx context.Context, downloadDir string, body := io.TeeReader(io.TeeReader(resp.Body, h), pw) // It is a little inefficient to download the file to disk and then re-load - // it into memory to unarchive later, but this is safer as it allows {tsh, - // tctl} to validate the hash before trying to operate on the archive. - _, err = io.Copy(f, body) + // it into memory to unarchive later, but this is safer as it allows client + // tools to validate the hash before trying to operate on the archive. + _, err = io.CopyN(f, body, resp.ContentLength) if err != nil { - return "", trace.Wrap(err) - } - if fmt.Sprintf("%x", h.Sum(nil)) != hash { - return "", trace.BadParameter("hash of archive does not match downloaded archive") + return "", "", trace.Wrap(err) } - return f.Name(), nil + return f.Name(), fmt.Sprintf("%x", h.Sum(nil)), nil }