Skip to content

Commit

Permalink
Add client tools auto update
Browse files Browse the repository at this point in the history
  • Loading branch information
vapopov committed Oct 10, 2024
1 parent 40e6c49 commit ddcbaf9
Show file tree
Hide file tree
Showing 13 changed files with 1,332 additions and 5 deletions.
231 changes: 231 additions & 0 deletions integration/autoupdate/client_update_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package autoupdate_test

import (
"bytes"
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"regexp"
"runtime"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/lib/autoupdate"
)

var (
// pattern is template for response on version command for client tools {tsh, tctl}.
pattern = regexp.MustCompile(`(?m)Teleport v(.*) git`)
)

// TestUpdate verifies the basic update logic. We first download a lower version, then request
// an update to a newer version, expecting it to re-execute with the updated version.
func TestUpdate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

// Fetch compiled test binary with updater logic and install to $TELEPORT_HOME.
updater := autoupdate.NewClientUpdater(
clientTools(),
toolsDir,
testVersions[0],
autoupdate.WithBaseURL(fmt.Sprintf("http://%s", baseURL)),
)
err := updater.Update(ctx, testVersions[0])
require.NoError(t, err)

// Verify that the installed version is equal to requested one.
cmd := exec.CommandContext(ctx, filepath.Join(toolsDir, "tsh"), "version")
out, err := cmd.Output()
require.NoError(t, err)

matches := pattern.FindStringSubmatch(string(out))
require.Len(t, matches, 2)
require.Equal(t, testVersions[0], matches[1])

// Execute version command again with setting the new version which must
// trigger re-execution of the same command after downloading requested version.
cmd = exec.CommandContext(ctx, filepath.Join(toolsDir, "tsh"), "version")
cmd.Env = append(
os.Environ(),
fmt.Sprintf("%s=%s", teleportToolsVersion, testVersions[1]),
)
out, err = cmd.Output()
require.NoError(t, err)

matches = pattern.FindStringSubmatch(string(out))
require.Len(t, matches, 2)
require.Equal(t, testVersions[1], matches[1])
}

// TestParallelUpdate launches multiple updater commands in parallel while defining a new version.
// The first process should acquire a lock and block execution for the other processes. After the
// first update is complete, other processes should acquire the lock one by one and re-execute
// the command with the updated version without any new downloads.
func TestParallelUpdate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

// Initial fetch the updater binary un-archive and replace.
updater := autoupdate.NewClientUpdater(
clientTools(),
toolsDir,
testVersions[0],
autoupdate.WithBaseURL(fmt.Sprintf("http://%s", baseURL)),
)
err := updater.Update(ctx, testVersions[0])
require.NoError(t, err)

// By setting the limit request next test http serving file going blocked until unlock is sent.
lock := make(chan struct{})
limitedWriter.SetLimitRequest(limitRequest{
limit: 1024,
lock: lock,
})

var outputs [3]bytes.Buffer
errChan := make(chan error, cap(outputs))
for i := 0; i < cap(outputs); i++ {
cmd := exec.Command(filepath.Join(toolsDir, "tsh"), "version")
cmd.Stdout = &outputs[i]
cmd.Stderr = &outputs[i]
cmd.Env = append(
os.Environ(),
fmt.Sprintf("%s=%s", teleportToolsVersion, testVersions[1]),
)
err = cmd.Start()
require.NoError(t, err, "failed to start updater")

go func(cmd *exec.Cmd) {
errChan <- cmd.Wait()
}(cmd)
}

select {
case err := <-errChan:
require.Fail(t, "we shouldn't receive any error", err)
case <-time.After(5 * time.Second):
require.Fail(t, "failed to wait till the download is started")
case <-lock:
// Wait for a short period to allow other processes to launch and attempt to acquire the lock.
time.Sleep(100 * time.Millisecond)
lock <- struct{}{}
}

// Wait till process finished with exit code 0, but we still should get progress
// bar in output content.
for i := 0; i < cap(outputs); i++ {
select {
case <-time.After(5 * time.Second):
require.Fail(t, "failed to wait till the process is finished")
case err := <-errChan:
require.NoError(t, err)
}
}

var progressCount int
for i := 0; i < cap(outputs); i++ {
matches := pattern.FindStringSubmatch(outputs[i].String())
require.Len(t, matches, 2)
assert.Equal(t, testVersions[1], matches[1])
if strings.Contains(outputs[i].String(), "Update progress:") {
progressCount++
}
}
assert.Equal(t, 1, progressCount, "we should have only one progress bar downloading new version")
}

// TestUpdateInterruptSignal verifies the interrupt signal send to the process must stop downloading.
func TestUpdateInterruptSignal(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

// Initial fetch the updater binary un-archive and replace.
updater := autoupdate.NewClientUpdater(
clientTools(),
toolsDir,
testVersions[0],
autoupdate.WithBaseURL(fmt.Sprintf("http://%s", baseURL)),
)
err := updater.Update(ctx, testVersions[0])
require.NoError(t, err)

var output bytes.Buffer
cmd := exec.Command(filepath.Join(toolsDir, "tsh"), "version")
cmd.Stdout = &output
cmd.Stderr = &output
cmd.Env = append(
os.Environ(),
fmt.Sprintf("%s=%s", teleportToolsVersion, testVersions[1]),
)
err = cmd.Start()
require.NoError(t, err, "failed to start updater")
pid := cmd.Process.Pid

errChan := make(chan error)
go func() {
errChan <- cmd.Wait()
}()

// By setting the limit request next test http serving file going blocked until unlock is sent.
lock := make(chan struct{})
limitedWriter.SetLimitRequest(limitRequest{
limit: 1024,
lock: lock,
})

select {
case err := <-errChan:
require.Fail(t, "we shouldn't receive any error", err)
case <-time.After(5 * time.Second):
require.Fail(t, "failed to wait till the download is started")
case <-lock:
time.Sleep(100 * time.Millisecond)
require.NoError(t, sendInterrupt(pid))
lock <- struct{}{}
}

// Wait till process finished with exit code 0, but we still should get progress
// bar in output content.
select {
case <-time.After(5 * time.Second):
require.Fail(t, "failed to wait till the process interrupted")
case err := <-errChan:
require.NoError(t, err)
}
assert.Contains(t, output.String(), "Update progress:")
}

func clientTools() []string {
switch runtime.GOOS {
case constants.WindowsOS:
return []string{"tsh.exe", "tctl.exe"}
default:
return []string{"tsh", "tctl"}
}
}
89 changes: 89 additions & 0 deletions integration/autoupdate/helper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package autoupdate_test

import (
"net/http"
"sync"
)

type limitRequest struct {
limit int64
lock chan struct{}
}

// limitedResponseWriter wraps http.ResponseWriter and enforces a write limit
// then block the response until signal is received.
type limitedResponseWriter struct {
requests chan limitRequest
}

// newLimitedResponseWriter creates a new limitedResponseWriter with the lock.
func newLimitedResponseWriter() *limitedResponseWriter {
lw := &limitedResponseWriter{
requests: make(chan limitRequest, 10),
}
return lw
}

// Wrap wraps response writer if limit was previously requested, if not, return original one.
func (lw *limitedResponseWriter) Wrap(w http.ResponseWriter) http.ResponseWriter {
select {
case request := <-lw.requests:
return &wrapper{
ResponseWriter: w,
request: request,
}
default:
return w
}
}

// SetLimitRequest sends limit request to the pool to wrap next response writer with defined limits.
func (lw *limitedResponseWriter) SetLimitRequest(limit limitRequest) {
lw.requests <- limit
}

// wrapper wraps the http response writer to control writing operation by blocking it.
type wrapper struct {
http.ResponseWriter

written int64
request limitRequest
released bool

mutex sync.Mutex
}

// Write writes data to the underlying ResponseWriter but respects the byte limit.
func (lw *wrapper) Write(p []byte) (int, error) {
lw.mutex.Lock()
defer lw.mutex.Unlock()

if lw.written >= lw.request.limit && !lw.released {
// Send signal that lock is acquired and wait till it was released by response.
lw.request.lock <- struct{}{}
<-lw.request.lock
lw.released = true
}

n, err := lw.ResponseWriter.Write(p)
lw.written += int64(n)
return n, err
}
37 changes: 37 additions & 0 deletions integration/autoupdate/helper_unix_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//go:build !windows

/*
* Teleport
* Copyright (C) 2024 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package autoupdate_test

import (
"errors"
"syscall"

"github.com/gravitational/trace"
)

// sendInterrupt sends a SIGINT to the process.
func sendInterrupt(pid int) error {
err := syscall.Kill(pid, syscall.SIGINT)
if errors.Is(err, syscall.ESRCH) {
return trace.BadParameter("can't find the process: %v", pid)
}
return trace.Wrap(err)
}
Loading

0 comments on commit ddcbaf9

Please sign in to comment.