From 89ce24b25edafe0b84155cbd6cfb236b798eda0a Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Mon, 4 Nov 2024 15:51:56 -0300 Subject: [PATCH] Fix `tsh play` `--skip-idle-time` not working correctly (#47304) --- lib/client/api.go | 6 ++-- lib/player/player.go | 68 +++++++++++++++++++++------------------ lib/player/player_test.go | 31 +++++++++++++++++- 3 files changed, 69 insertions(+), 36 deletions(-) diff --git a/lib/client/api.go b/lib/client/api.go index ece92d5b21c24..4337bcd241632 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -2223,13 +2223,11 @@ func playSession(ctx context.Context, sessionID string, speed float64, streamer } playing = !playing case keyLeft, keyDown: - current := time.Duration(player.LastPlayed() * int64(time.Millisecond)) - player.SetPos(max(current-skipDuration, 0)) // rewind + player.SetPos(max(player.LastPlayed()-skipDuration, 0)) // rewind term.Clear() term.SetCursorPos(1, 1) case keyRight, keyUp: - current := time.Duration(player.LastPlayed() * int64(time.Millisecond)) - player.SetPos(current + skipDuration) // advance forward + player.SetPos(player.LastPlayed() + skipDuration) // advance forward } } }() diff --git a/lib/player/player.go b/lib/player/player.go index 40beed295b5c1..f7aaf7041c566 100644 --- a/lib/player/player.go +++ b/lib/player/player.go @@ -59,7 +59,7 @@ type Player struct { advanceTo atomic.Int64 emit chan events.AuditEvent - wake chan int64 + wake chan time.Duration done chan struct{} // playPause holds a channel to be closed when @@ -75,7 +75,12 @@ type Player struct { err error } -const normalPlayback = math.MinInt64 +const ( + normalPlayback = time.Duration(0) + // MaxIdleTime defines the max idle time when skipping idle + // periods on the recording. + MaxIdleTime = 500 * time.Millisecond +) // Streamer is the underlying streamer that provides // access to recorded session events. @@ -116,18 +121,19 @@ func New(cfg *Config) (*Player, error) { } p := &Player{ - clock: clk, - log: log, - sessionID: cfg.SessionID, - streamer: cfg.Streamer, - emit: make(chan events.AuditEvent, 1024), - playPause: make(chan chan struct{}, 1), - wake: make(chan int64), - done: make(chan struct{}), + clock: clk, + log: log, + sessionID: cfg.SessionID, + streamer: cfg.Streamer, + emit: make(chan events.AuditEvent, 1024), + playPause: make(chan chan struct{}, 1), + wake: make(chan time.Duration), + done: make(chan struct{}), + skipIdleTime: cfg.SkipIdleTime, } p.speed.Store(float64(defaultPlaybackSpeed)) - p.advanceTo.Store(normalPlayback) + p.advanceTo.Store(int64(normalPlayback)) // start in a paused state p.playPause <- make(chan struct{}) @@ -165,7 +171,7 @@ func (p *Player) stream() { defer cancel() eventsC, errC := p.streamer.StreamSessionEvents(ctx, p.sessionID, 0) - lastDelay := int64(0) + var lastDelay time.Duration for { select { case <-p.done: @@ -191,7 +197,7 @@ func (p *Player) stream() { currentDelay := getDelay(evt) if currentDelay > 0 && currentDelay >= lastDelay { - switch adv := p.advanceTo.Load(); { + switch adv := time.Duration(p.advanceTo.Load()); { case adv >= currentDelay: // no timing delay necessary, we are fast forwarding break @@ -199,12 +205,12 @@ func (p *Player) stream() { // any negative value other than normalPlayback means // we rewind (by restarting the stream and seeking forward // to the rewind point) - p.advanceTo.Store(adv * -1) + p.advanceTo.Store(int64(adv) * -1) go p.stream() return default: if adv != normalPlayback { - p.advanceTo.Store(normalPlayback) + p.advanceTo.Store(int64(normalPlayback)) // we're catching back up to real time, so the delay // is calculated not from the last event but from the @@ -232,7 +238,7 @@ func (p *Player) stream() { // // TODO: consider a select with a timeout to detect blocked readers? p.emit <- evt - p.lastPlayed.Store(currentDelay) + p.lastPlayed.Store(int64(currentDelay)) } } } @@ -284,14 +290,14 @@ func (p *Player) SetPos(d time.Duration) error { if d == 0 { d = 1 * time.Millisecond } - if d.Milliseconds() < p.lastPlayed.Load() { + if d < time.Duration(p.lastPlayed.Load()) { d = -1 * d } - p.advanceTo.Store(d.Milliseconds()) + p.advanceTo.Store(int64(d)) // try to wake up the player if it's waiting to emit an event select { - case p.wake <- d.Milliseconds(): + case p.wake <- d: default: } @@ -308,18 +314,18 @@ func (p *Player) SetPos(d time.Duration) error { // // A nil return value indicates that the delay has elapsed and that // the next even can be emitted. -func (p *Player) applyDelay(lastDelay, currentDelay int64) error { +func (p *Player) applyDelay(lastDelay, currentDelay time.Duration) error { loop: for { // TODO(zmb3): changing play speed during a long sleep // will not apply until after the sleep completes speed := p.speed.Load().(float64) - scaled := float64(currentDelay-lastDelay) / speed + scaled := time.Duration(float64(currentDelay-lastDelay) / speed) if p.skipIdleTime { - scaled = min(scaled, 500.0*float64(time.Millisecond)) + scaled = min(scaled, MaxIdleTime) } - timer := p.clock.NewTimer(time.Duration(scaled) * time.Millisecond) + timer := p.clock.NewTimer(scaled) defer timer.Stop() start := time.Now() @@ -333,7 +339,7 @@ loop: case newPos == interruptForPause: // the user paused playback while we were waiting to emit the next event: // 1) figure out much of the sleep we completed - dur := float64(time.Since(start).Milliseconds()) * speed + dur := time.Duration(float64(time.Since(start)) * speed) // 2) wait here until the user resumes playback if err := p.waitWhilePaused(); errors.Is(err, errSeekWhilePaused) { @@ -345,7 +351,7 @@ loop: // now that we're playing again, update our delay to account // for the portion that was already satisfied and apply the // remaining delay - lastDelay += int64(dur) + lastDelay += dur timer.Stop() continue loop case newPos > currentDelay: @@ -430,17 +436,17 @@ func (p *Player) waitWhilePaused() error { // LastPlayed returns the time of the last played event, // expressed as milliseconds since the start of the session. -func (p *Player) LastPlayed() int64 { - return p.lastPlayed.Load() +func (p *Player) LastPlayed() time.Duration { + return time.Duration(p.lastPlayed.Load()) } -func getDelay(e events.AuditEvent) int64 { +func getDelay(e events.AuditEvent) time.Duration { switch x := e.(type) { case *events.DesktopRecording: - return x.DelayMilliseconds + return time.Duration(x.DelayMilliseconds) * time.Millisecond case *events.SessionPrint: - return x.DelayMilliseconds + return time.Duration(x.DelayMilliseconds) * time.Millisecond default: - return int64(0) + return time.Duration(0) } } diff --git a/lib/player/player_test.go b/lib/player/player_test.go index a64418df941f5..f799b87919ad4 100644 --- a/lib/player/player_test.go +++ b/lib/player/player_test.go @@ -26,6 +26,7 @@ import ( "time" "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" apievents "github.com/gravitational/teleport/api/types/events" @@ -169,7 +170,7 @@ func TestClose(t *testing.T) { _, ok := <-p.C() require.False(t, ok, "player channel should have been closed") require.NoError(t, p.Err()) - require.Equal(t, int64(1000), p.LastPlayed()) + require.Equal(t, time.Second, p.LastPlayed()) } func TestSeekForward(t *testing.T) { @@ -260,6 +261,34 @@ func TestRewind(t *testing.T) { p.Close() } +func TestSkipIdlePeriods(t *testing.T) { + eventCount := 3 + delayMilliseconds := 60000 + clk := clockwork.NewFakeClock() + p, err := player.New(&player.Config{ + Clock: clk, + SessionID: "test-session", + SkipIdleTime: true, + Streamer: &simpleStreamer{count: int64(eventCount), delay: int64(delayMilliseconds)}, + }) + require.NoError(t, err) + require.NoError(t, p.Play()) + + for i := 0; i < eventCount; i++ { + // Consume events in an eventually loop to avoid firing the clock + // events before the timer is set. + require.EventuallyWithT(t, func(t *assert.CollectT) { + clk.Advance(player.MaxIdleTime) + select { + case evt := <-p.C(): + assert.Equal(t, int64(i), evt.GetIndex()) + default: + assert.Fail(t, "expected to receive event after short period, but got nothing") + } + }, 3*time.Second, 100*time.Millisecond) + } +} + // simpleStreamer streams a fake session that contains // count events, emitted at a particular interval type simpleStreamer struct {