Skip to content

Commit

Permalink
Replace fork for posix platform for re-exec
Browse files Browse the repository at this point in the history
Move integration tests to client tools specific dir
Use context cancellation with SIGTERM, SIGINT
Remove cancelable tee reader with context replacement
Renaming
  • Loading branch information
vapopov committed Oct 15, 2024
1 parent ddcbaf9 commit a3fa2c3
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package autoupdate_test
package tools_test

import (
"bytes"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package autoupdate_test
package tools_test

import (
"net/http"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package autoupdate_test
package tools_test

import (
"errors"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package autoupdate_test
package tools_test

import (
"syscall"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package autoupdate_test
package tools_test

import (
"context"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func main() {
// Download and update the version of client tools required by the cluster.
// This is required if the user passed in the TELEPORT_TOOLS_VERSION explicitly.
err := updater.UpdateWithLock(ctx, toolsVersion)
if errors.Is(err, autoupdate.ErrCanceled) {
if errors.Is(err, context.Canceled) {
os.Exit(0)
return
}
Expand Down
69 changes: 43 additions & 26 deletions lib/autoupdate/client_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ import (
"encoding/hex"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"os/exec"
"os/signal"
"path/filepath"
"regexp"
"runtime"
Expand All @@ -47,13 +49,13 @@ import (
const (
// teleportToolsVersionEnv is environment name for requesting specific version for update.
teleportToolsVersionEnv = "TELEPORT_TOOLS_VERSION"
// baseUrl is CDN URL for downloading official Teleport packages.
baseUrl = "https://cdn.teleport.dev"
// baseURL is CDN URL for downloading official Teleport packages.
baseURL = "https://cdn.teleport.dev"
// checksumHexLen is length of the hash sum.
checksumHexLen = 64
// reservedFreeDisk is the predefined amount of free disk space (in bytes) required
// to remain available after downloading archives.
reservedFreeDisk = 10 * 1024 * 1024
reservedFreeDisk = 10 * 1024 * 1024 // 10 Mb
// lockFileName is file used for locking update process in parallel.
lockFileName = ".lock"
// updatePackageSuffix is directory suffix used for package extraction in tools directory.
Expand All @@ -65,18 +67,18 @@ var (
pattern = regexp.MustCompile(`(?m)Teleport v(.*) git`)
)

// Option applies an option value for the ClientUpdater.
type Option func(u *ClientUpdater)
// ClientOption applies an option value for the ClientUpdater.
type ClientOption func(u *ClientUpdater)

// WithBaseURL defines custom base url for the updater.
func WithBaseURL(baseUrl string) Option {
func WithBaseURL(baseUrl string) ClientOption {
return func(u *ClientUpdater) {
u.baseUrl = baseUrl
}
}

// WithClient defines custom http client for the ClientUpdater.
func WithClient(client *http.Client) Option {
func WithClient(client *http.Client) ClientOption {
return func(u *ClientUpdater) {
u.client = client
}
Expand All @@ -93,12 +95,12 @@ type ClientUpdater struct {
}

// NewClientUpdater initiate updater for the client tools auto updates.
func NewClientUpdater(tools []string, toolsDir string, localVersion string, options ...Option) *ClientUpdater {
func NewClientUpdater(tools []string, toolsDir string, localVersion string, options ...ClientOption) *ClientUpdater {
updater := &ClientUpdater{
tools: tools,
toolsDir: toolsDir,
localVersion: localVersion,
baseUrl: baseUrl,
baseUrl: baseURL,
client: http.DefaultClient,
}
for _, option := range options {
Expand All @@ -123,7 +125,7 @@ func (u *ClientUpdater) CheckLocal() (string, bool) {

// If a version of client tools has already been downloaded to
// tools directory, return that.
toolsVersion, err := version(u.toolsDir)
toolsVersion, err := checkClientToolVersion(u.toolsDir)
if err != nil {
return "", false
}
Expand Down Expand Up @@ -165,7 +167,7 @@ func (u *ClientUpdater) CheckRemote(ctx context.Context, proxyAddr string) (stri

// If a version of client tools has already been downloaded to
// tools directory, return that.
toolsVersion, err := version(u.toolsDir)
toolsVersion, err := checkClientToolVersion(u.toolsDir)
if err != nil {
return "", false, trace.Wrap(err)
}
Expand All @@ -190,7 +192,7 @@ func (u *ClientUpdater) UpdateWithLock(ctx context.Context, toolsVersion string)
if err := os.MkdirAll(u.toolsDir, 0o755); err != nil {
return trace.Wrap(err)
}
// Lock concurrent {tsh, tctl} execution util requested version is updated.
// Lock concurrent client tools execution util requested version is updated.
unlock, err := utils.FSWriteLock(filepath.Join(u.toolsDir, lockFileName))
if err != nil {
return trace.Wrap(err)
Expand All @@ -202,7 +204,7 @@ func (u *ClientUpdater) UpdateWithLock(ctx context.Context, toolsVersion string)
// If the version of the running binary or the version downloaded to
// tools directory is the same as the requested version of client tools,
// nothing to be done, exit early.
teleportVersion, err := version(u.toolsDir)
teleportVersion, err := checkClientToolVersion(u.toolsDir)
if err != nil && !trace.IsNotFound(err) {
return trace.Wrap(err)

Expand All @@ -211,7 +213,7 @@ func (u *ClientUpdater) UpdateWithLock(ctx context.Context, toolsVersion string)
return nil
}

// Download and update {tsh, tctl} in tools directory.
// Download and update client tools in tools directory.
if err := u.Update(ctx, toolsVersion); err != nil {
return trace.Wrap(err)
}
Expand All @@ -228,17 +230,24 @@ func (u *ClientUpdater) Update(ctx context.Context, toolsVersion string) error {
return trace.Wrap(err)
}

signalCtx, cancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
defer cancel()

// Download the archive and validate against the hash. Download to a
// temporary path within tools directory.
hash, err := u.downloadHash(ctx, hashURL)
hash, err := u.downloadHash(signalCtx, hashURL)
if err != nil {
return trace.Wrap(err)
}
archivePath, err := u.downloadArchive(ctx, u.toolsDir, archiveURL, hash)
archivePath, err := u.downloadArchive(signalCtx, u.toolsDir, archiveURL, hash)
if err != nil {
return trace.Wrap(err)
}
defer os.Remove(archivePath)
defer func() {
if err := os.Remove(archivePath); err != nil {
slog.WarnContext(ctx, "failed to remove archive", "error", err)
}
}()

pkgName := fmt.Sprint(uuid.New().String(), updatePackageSuffix)
extractDir := filepath.Join(u.toolsDir, pkgName)
Expand Down Expand Up @@ -266,19 +275,27 @@ func (u *ClientUpdater) Exec() (int, error) {
if err != nil {
return 0, trace.Wrap(err)
}

cmd := exec.Command(path, os.Args[1:]...)
// To prevent re-execution loop we have to disable update logic for re-execution.
cmd.Env = append(os.Environ(), teleportToolsVersionEnv+"=off")
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
env := append(os.Environ(), teleportToolsVersionEnv+"=off")

if runtime.GOOS == constants.WindowsOS {
cmd := exec.Command(path, os.Args[1:]...)
cmd.Env = env
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return 0, trace.Wrap(err)
}

return cmd.ProcessState.ExitCode(), nil
}

if err := cmd.Run(); err != nil {
if err := syscall.Exec(path, os.Args, env); err != nil {
return 0, trace.Wrap(err)
}

return cmd.ProcessState.ExitCode(), nil
return 0, nil
}

func (u *ClientUpdater) downloadHash(ctx context.Context, url string) (string, error) {
Expand Down Expand Up @@ -336,7 +353,7 @@ func (u *ClientUpdater) downloadArchive(ctx context.Context, downloadDir string,

h := sha256.New()
pw := &progressWriter{n: 0, limit: resp.ContentLength}
body := cancelableTeeReader(io.TeeReader(resp.Body, h), pw, syscall.SIGINT, syscall.SIGTERM)
body := io.TeeReader(io.TeeReader(resp.Body, h), pw)

// It is a little inefficient to download the file to disk and then re-load
// it into memory to unarchive later, but this is safer as it allows {tsh,
Expand Down
38 changes: 0 additions & 38 deletions lib/autoupdate/progress.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,47 +20,9 @@ package autoupdate

import (
"fmt"
"io"
"os"
"os/signal"
"strings"
)

var (
// ErrCanceled represent the cancellation error for Ctrl-Break/Ctrl-C, depends on the platform.
ErrCanceled = fmt.Errorf("canceled")
)

// cancelableTeeReader is a copy of TeeReader with ability to react on signal notifier
// to cancel reading process.
func cancelableTeeReader(r io.Reader, w io.Writer, signals ...os.Signal) io.Reader {
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, signals...)

return &teeReader{r, w, sigs}
}

type teeReader struct {
r io.Reader
w io.Writer
sigs chan os.Signal
}

func (t *teeReader) Read(p []byte) (n int, err error) {
select {
case <-t.sigs:
return 0, ErrCanceled
default:
n, err = t.r.Read(p)
if n > 0 {
if n, err := t.w.Write(p[:n]); err != nil {
return n, err
}
}
}
return
}

type progressWriter struct {
n int64
limit int64
Expand Down
2 changes: 1 addition & 1 deletion lib/autoupdate/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func ToolsDir() (string, error) {
return filepath.Join(filepath.Clean(home), ".tsh", "bin"), nil
}

func version(toolsDir string) (string, error) {
func checkClientToolVersion(toolsDir string) (string, error) {
// Find the path to the current executable.
path, err := toolName(toolsDir)
if err != nil {
Expand Down

0 comments on commit a3fa2c3

Please sign in to comment.