Skip to content

Commit

Permalink
Propogate context for better span correlation of recording viewing (#…
Browse files Browse the repository at this point in the history
…49142)

Traces generated from playing back session recordings were
incomplete due to the correct context not being used at various
levels of the events and player code. This attempts to rectify that
by setting the correct context.Context along the way.
  • Loading branch information
rosstimothy authored Nov 21, 2024
1 parent bcbfa81 commit beec948
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 11 deletions.
10 changes: 5 additions & 5 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ func (a *ServerWithRoles) actionWithExtendedContext(namespace, kind, verb string
// actionForKindSession is a special checker that grants access to session
// recordings. It can allow access to a specific recording based on the
// `where` section of the user's access rule for kind `session`.
func (a *ServerWithRoles) actionForKindSession(namespace string, sid session.ID) (types.SessionKind, error) {
sessionEnd, err := a.findSessionEndEvent(namespace, sid)
func (a *ServerWithRoles) actionForKindSession(ctx context.Context, namespace string, sid session.ID) (types.SessionKind, error) {
sessionEnd, err := a.findSessionEndEvent(ctx, sid)

extendContext := func(ctx *services.Context) error {
ctx.Session = sessionEnd
Expand Down Expand Up @@ -4097,8 +4097,8 @@ func (s *streamWithRoles) RecordEvent(ctx context.Context, pe apievents.Prepared
return s.stream.RecordEvent(ctx, pe)
}

func (a *ServerWithRoles) findSessionEndEvent(namespace string, sid session.ID) (apievents.AuditEvent, error) {
sessionEvents, _, err := a.alog.SearchSessionEvents(context.TODO(), events.SearchSessionEventsRequest{
func (a *ServerWithRoles) findSessionEndEvent(ctx context.Context, sid session.ID) (apievents.AuditEvent, error) {
sessionEvents, _, err := a.alog.SearchSessionEvents(ctx, events.SearchSessionEventsRequest{
From: time.Time{},
To: a.authServer.clock.Now().UTC(),
Limit: defaults.EventsIterationLimit,
Expand Down Expand Up @@ -5948,7 +5948,7 @@ func (a *ServerWithRoles) StreamSessionEvents(ctx context.Context, sessionID ses
var sessionType types.SessionKind
if !isTeleportServer {
var err error
sessionType, err = a.actionForKindSession(apidefaults.Namespace, sessionID)
sessionType, err = a.actionForKindSession(ctx, apidefaults.Namespace, sessionID)
if err != nil {
c, e := make(chan apievents.AuditEvent), make(chan error, 1)
e <- trace.Wrap(err)
Expand Down
1 change: 1 addition & 0 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2313,6 +2313,7 @@ func playSession(ctx context.Context, sessionID string, speed float64, streamer
SessionID: *sid,
Streamer: streamer,
SkipIdleTime: skipIdleTime,
Context: ctx,
})
if err != nil {
return trace.Wrap(err)
Expand Down
2 changes: 1 addition & 1 deletion lib/events/auditlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID
}

start := time.Now()
if err := l.UploadHandler.Download(l.ctx, sessionID, rawSession); err != nil {
if err := l.UploadHandler.Download(ctx, sessionID, rawSession); err != nil {
_ = rawSession.Close()
if errors.Is(err, fs.ErrNotExist) {
err = trace.NotFound("a recording for session %v was not found", sessionID)
Expand Down
3 changes: 3 additions & 0 deletions lib/events/s3sessions/s3handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/aws/aws-sdk-go-v2/service/s3"
awstypes "github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/gravitational/trace"
"go.opentelemetry.io/contrib/instrumentation/github.com/aws/aws-sdk-go-v2/otelaws"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
Expand Down Expand Up @@ -232,6 +233,8 @@ func NewHandler(ctx context.Context, cfg Config) (*Handler, error) {
return nil, trace.Wrap(err)
}

otelaws.AppendMiddlewares(&awsConfig.APIOptions)

// Create S3 client with custom options
client := s3.NewFromConfig(awsConfig, s3Opts...)

Expand Down
16 changes: 11 additions & 5 deletions lib/player/player.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ type Config struct {
SessionID session.ID
Streamer Streamer
SkipIdleTime bool
Context context.Context
}

func New(cfg *Config) (*Player, error) {
Expand All @@ -140,6 +141,11 @@ func New(cfg *Config) (*Player, error) {
slog.With(teleport.ComponentKey, "player"),
)

ctx := context.Background()
if cfg.Context != nil {
ctx = cfg.Context
}

p := &Player{
clock: clk,
log: log,
Expand All @@ -158,7 +164,7 @@ func New(cfg *Config) (*Player, error) {
// start in a paused state
p.playPause <- make(chan struct{})

go p.stream()
go p.stream(ctx)

return p, nil
}
Expand Down Expand Up @@ -186,8 +192,8 @@ func (p *Player) SetSpeed(s float64) error {
return nil
}

func (p *Player) stream() {
ctx, cancel := context.WithCancel(context.Background())
func (p *Player) stream(baseContext context.Context) {
ctx, cancel := context.WithCancel(baseContext)
defer cancel()

eventsC, errC := p.streamer.StreamSessionEvents(metadata.WithSessionRecordingFormatContext(ctx, teleport.PTY), p.sessionID, 0)
Expand Down Expand Up @@ -232,7 +238,7 @@ func (p *Player) stream() {
// we rewind (by restarting the stream and seeking forward
// to the rewind point)
p.advanceTo.Store(int64(adv) * -1)
go p.stream()
go p.stream(baseContext)
return
default:
if adv != normalPlayback {
Expand All @@ -247,7 +253,7 @@ func (p *Player) stream() {
switch err := p.applyDelay(lastDelay, currentDelay); {
case errors.Is(err, errSeekWhilePaused):
p.log.DebugContext(ctx, "Seeked during pause, will restart stream")
go p.stream()
go p.stream(baseContext)
return
case err != nil:
close(p.emit)
Expand Down
1 change: 1 addition & 0 deletions lib/web/desktop_playback.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func (h *Handler) desktopPlaybackHandle(
Log: h.logger,
SessionID: session.ID(sID),
Streamer: clt,
Context: r.Context(),
})
if err != nil {
h.log.Errorf("couldn't create player for session %v: %v", sID, err)
Expand Down
1 change: 1 addition & 0 deletions lib/web/tty_playback.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ func (h *Handler) ttyPlaybackHandle(
Log: h.logger,
SessionID: session.ID(sID),
Streamer: clt,
Context: r.Context(),
})
if err != nil {
h.log.Warn("player error", err)
Expand Down

0 comments on commit beec948

Please sign in to comment.