diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 70078ee0ff737..3357b9aa292e4 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -195,8 +195,8 @@ func (a *ServerWithRoles) actionWithExtendedContext(namespace, kind, verb string // actionForKindSession is a special checker that grants access to session // recordings. It can allow access to a specific recording based on the // `where` section of the user's access rule for kind `session`. -func (a *ServerWithRoles) actionForKindSession(namespace string, sid session.ID) (types.SessionKind, error) { - sessionEnd, err := a.findSessionEndEvent(namespace, sid) +func (a *ServerWithRoles) actionForKindSession(ctx context.Context, namespace string, sid session.ID) (types.SessionKind, error) { + sessionEnd, err := a.findSessionEndEvent(ctx, sid) extendContext := func(ctx *services.Context) error { ctx.Session = sessionEnd @@ -4097,8 +4097,8 @@ func (s *streamWithRoles) RecordEvent(ctx context.Context, pe apievents.Prepared return s.stream.RecordEvent(ctx, pe) } -func (a *ServerWithRoles) findSessionEndEvent(namespace string, sid session.ID) (apievents.AuditEvent, error) { - sessionEvents, _, err := a.alog.SearchSessionEvents(context.TODO(), events.SearchSessionEventsRequest{ +func (a *ServerWithRoles) findSessionEndEvent(ctx context.Context, sid session.ID) (apievents.AuditEvent, error) { + sessionEvents, _, err := a.alog.SearchSessionEvents(ctx, events.SearchSessionEventsRequest{ From: time.Time{}, To: a.authServer.clock.Now().UTC(), Limit: defaults.EventsIterationLimit, @@ -5948,7 +5948,7 @@ func (a *ServerWithRoles) StreamSessionEvents(ctx context.Context, sessionID ses var sessionType types.SessionKind if !isTeleportServer { var err error - sessionType, err = a.actionForKindSession(apidefaults.Namespace, sessionID) + sessionType, err = a.actionForKindSession(ctx, apidefaults.Namespace, sessionID) if err != nil { c, e := make(chan apievents.AuditEvent), make(chan error, 1) e <- trace.Wrap(err) diff --git a/lib/client/api.go b/lib/client/api.go index c74cbd81d3006..57a4a7eed500c 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -2313,6 +2313,7 @@ func playSession(ctx context.Context, sessionID string, speed float64, streamer SessionID: *sid, Streamer: streamer, SkipIdleTime: skipIdleTime, + Context: ctx, }) if err != nil { return trace.Wrap(err) diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index 5227935e6f236..51180746cbe7f 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -532,7 +532,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID } start := time.Now() - if err := l.UploadHandler.Download(l.ctx, sessionID, rawSession); err != nil { + if err := l.UploadHandler.Download(ctx, sessionID, rawSession); err != nil { _ = rawSession.Close() if errors.Is(err, fs.ErrNotExist) { err = trace.NotFound("a recording for session %v was not found", sessionID) diff --git a/lib/events/s3sessions/s3handler.go b/lib/events/s3sessions/s3handler.go index 7fc80f9063c5a..8b8487bd26a3f 100644 --- a/lib/events/s3sessions/s3handler.go +++ b/lib/events/s3sessions/s3handler.go @@ -38,6 +38,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/s3" awstypes "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/gravitational/trace" + "go.opentelemetry.io/contrib/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" @@ -232,6 +233,8 @@ func NewHandler(ctx context.Context, cfg Config) (*Handler, error) { return nil, trace.Wrap(err) } + otelaws.AppendMiddlewares(&awsConfig.APIOptions) + // Create S3 client with custom options client := s3.NewFromConfig(awsConfig, s3Opts...) diff --git a/lib/player/player.go b/lib/player/player.go index c548a8929200e..50c0ccf6a3fa4 100644 --- a/lib/player/player.go +++ b/lib/player/player.go @@ -119,6 +119,7 @@ type Config struct { SessionID session.ID Streamer Streamer SkipIdleTime bool + Context context.Context } func New(cfg *Config) (*Player, error) { @@ -140,6 +141,11 @@ func New(cfg *Config) (*Player, error) { slog.With(teleport.ComponentKey, "player"), ) + ctx := context.Background() + if cfg.Context != nil { + ctx = cfg.Context + } + p := &Player{ clock: clk, log: log, @@ -158,7 +164,7 @@ func New(cfg *Config) (*Player, error) { // start in a paused state p.playPause <- make(chan struct{}) - go p.stream() + go p.stream(ctx) return p, nil } @@ -186,8 +192,8 @@ func (p *Player) SetSpeed(s float64) error { return nil } -func (p *Player) stream() { - ctx, cancel := context.WithCancel(context.Background()) +func (p *Player) stream(baseContext context.Context) { + ctx, cancel := context.WithCancel(baseContext) defer cancel() eventsC, errC := p.streamer.StreamSessionEvents(metadata.WithSessionRecordingFormatContext(ctx, teleport.PTY), p.sessionID, 0) @@ -232,7 +238,7 @@ func (p *Player) stream() { // we rewind (by restarting the stream and seeking forward // to the rewind point) p.advanceTo.Store(int64(adv) * -1) - go p.stream() + go p.stream(baseContext) return default: if adv != normalPlayback { @@ -247,7 +253,7 @@ func (p *Player) stream() { switch err := p.applyDelay(lastDelay, currentDelay); { case errors.Is(err, errSeekWhilePaused): p.log.DebugContext(ctx, "Seeked during pause, will restart stream") - go p.stream() + go p.stream(baseContext) return case err != nil: close(p.emit) diff --git a/lib/web/desktop_playback.go b/lib/web/desktop_playback.go index f4755f18fe789..1467c14a28165 100644 --- a/lib/web/desktop_playback.go +++ b/lib/web/desktop_playback.go @@ -55,6 +55,7 @@ func (h *Handler) desktopPlaybackHandle( Log: h.logger, SessionID: session.ID(sID), Streamer: clt, + Context: r.Context(), }) if err != nil { h.log.Errorf("couldn't create player for session %v: %v", sID, err) diff --git a/lib/web/tty_playback.go b/lib/web/tty_playback.go index 76c603c1e49ec..f601f4237666c 100644 --- a/lib/web/tty_playback.go +++ b/lib/web/tty_playback.go @@ -88,6 +88,7 @@ func (h *Handler) ttyPlaybackHandle( Log: h.logger, SessionID: session.ID(sID), Streamer: clt, + Context: r.Context(), }) if err != nil { h.log.Warn("player error", err)