Skip to content

Commit

Permalink
make access request notifications expire alongside the request (#42527)
Browse files Browse the repository at this point in the history
  • Loading branch information
rudream authored Jun 6, 2024
1 parent b6b7fec commit 76187e4
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 6 deletions.
7 changes: 5 additions & 2 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ import (
"golang.org/x/exp/maps"
"golang.org/x/time/rate"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/client"
Expand Down Expand Up @@ -4971,7 +4972,8 @@ func (a *Server) CreateAccessRequestV2(ctx context.Context, req types.AccessRequ
Spec: &notificationsv1.NotificationSpec{},
SubKind: types.NotificationAccessRequestPendingSubKind,
Metadata: &headerv1.Metadata{
Labels: map[string]string{types.NotificationTitleLabel: notificationText, "request-id": req.GetName()},
Labels: map[string]string{types.NotificationTitleLabel: notificationText, "request-id": req.GetName()},
Expires: timestamppb.New(req.Expiry()),
},
},
},
Expand Down Expand Up @@ -5244,7 +5246,8 @@ func generateAccessRequestReviewedNotification(req types.AccessRequest, params t
"request-id": params.RequestID,
"roles": strings.Join(req.GetRoles(), ","),
"assumable-time": assumableTime,
}},
},
Expires: timestamppb.New(req.Expiry())},
}
}

Expand Down
1 change: 1 addition & 0 deletions lib/auth/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5380,6 +5380,7 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) {
notificationsServer, err := notifications.NewService(notifications.ServiceConfig{
Authorizer: cfg.Authorizer,
Backend: cfg.AuthServer.Services,
Clock: cfg.AuthServer.GetClock(),
UserNotificationCache: cfg.AuthServer.UserNotificationCache,
GlobalNotificationCache: cfg.AuthServer.GlobalNotificationCache,
})
Expand Down
67 changes: 63 additions & 4 deletions lib/auth/notification_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,49 @@ func TestNotifications(t *testing.T) {
},
}),
},
{
userNotification: &notificationsv1.Notification{
SubKind: "test-subkind",
Spec: &notificationsv1.NotificationSpec{
Username: managerUsername,
},
Metadata: &headerv1.Metadata{
Labels: map[string]string{
types.NotificationTitleLabel: "manager-7-expires",
},
// Expires in 15 minutes.
Expires: timestamppb.New(fakeClock.Now().Add(15 * time.Minute)),
},
},
},
{
globalNotification: &notificationsv1.GlobalNotification{
Spec: &notificationsv1.GlobalNotificationSpec{
Matcher: &notificationsv1.GlobalNotificationSpec_ByPermissions{
ByPermissions: &notificationsv1.ByPermissions{
RoleConditions: []*types.RoleConditions{
{
ReviewRequests: &types.AccessReviewConditions{
Roles: []string{"intern"},
},
},
},
},
},
Notification: &notificationsv1.Notification{
SubKind: "test-subkind",
Spec: &notificationsv1.NotificationSpec{},
Metadata: &headerv1.Metadata{
Labels: map[string]string{
types.NotificationTitleLabel: "manager-8-expires",
},
// Expires in 10 minutes.
Expires: timestamppb.New(fakeClock.Now().Add(10 * time.Minute)),
},
},
},
},
},
}

notificationIdMap := map[string]string{}
Expand Down Expand Up @@ -348,7 +391,7 @@ func TestNotifications(t *testing.T) {
require.NoError(t, err)
defer managerClient.Close()

managerExpectedNotifications := []string{"auditor-8,manager-6", "manager-5", "manager-4", "manager-3", "auditor-5,manager-2", "manager-1"}
managerExpectedNotifications := []string{"manager-8-expires", "manager-7-expires", "auditor-8,manager-6", "manager-5", "manager-4", "manager-3", "auditor-5,manager-2", "manager-1"}

resp, err = managerClient.ListNotifications(ctx, &notificationsv1.ListNotificationsRequest{
PageSize: 10,
Expand All @@ -359,10 +402,10 @@ func TestNotifications(t *testing.T) {
// Verify that we've reached the end of both lists.
require.Equal(t, "", resp.NextPageToken)

// Mark "auditor-8,manager-6" as clicked.
// Mark "manager-8-expires" as clicked.
_, err = managerClient.UpsertUserNotificationState(ctx, managerUsername, &notificationsv1.UserNotificationState{
Spec: &notificationsv1.UserNotificationStateSpec{
NotificationId: notificationIdMap["auditor-8,manager-6"],
NotificationId: notificationIdMap["manager-8-expires"],
},
Status: &notificationsv1.UserNotificationStateStatus{
NotificationState: notificationsv1.NotificationState_NOTIFICATION_STATE_CLICKED,
Expand All @@ -376,10 +419,26 @@ func TestNotifications(t *testing.T) {
})
require.NoError(t, err)

clickedNotification := resp.Notifications[0] // "auditor-8,manager-6" is the first item in the list
clickedNotification := resp.Notifications[0] // "manager-8-expires" is the first item in the list
clickedLabelValue := clickedNotification.GetMetadata().GetLabels()[types.NotificationClickedLabel]
require.Equal(t, "true", clickedLabelValue)

// Advance 11 minutes.
fakeClock.Advance(11 * time.Minute)

// Verify that notification "manager-8-expires" is now no longer returned.
resp, err = managerClient.ListNotifications(ctx, &notificationsv1.ListNotificationsRequest{})
require.NoError(t, err)
require.Equal(t, managerExpectedNotifications[1:], notificationsToTitlesList(t, resp.Notifications))

// Advance 16 minutes.
fakeClock.Advance(16 * time.Minute)

// Verify that notification "manager-7-expires" is now no longer returned either.
resp, err = managerClient.ListNotifications(ctx, &notificationsv1.ListNotificationsRequest{})
require.NoError(t, err)
require.Equal(t, managerExpectedNotifications[2:], notificationsToTitlesList(t, resp.Notifications))

// Verify that manager can't upsert a notification state for auditor
_, err = managerClient.UpsertUserNotificationState(ctx, auditorUsername, &notificationsv1.UserNotificationState{
Spec: &notificationsv1.UserNotificationStateSpec{
Expand Down
23 changes: 23 additions & 0 deletions lib/auth/notifications/notificationsv1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"strings"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"

"github.com/gravitational/teleport/api/client"
apidefaults "github.com/gravitational/teleport/api/defaults"
Expand All @@ -47,6 +48,8 @@ type ServiceConfig struct {
// GlobalNotificationCache is a custom cache for user-specific notifications,
// this is to allow fetching notifications by date in descending order.
GlobalNotificationCache *services.GlobalNotificationCache

Clock clockwork.Clock
}

// Backend contains the getters required for notification states and user last seen notifications,
Expand All @@ -70,6 +73,7 @@ type Service struct {
backend Backend
userNotificationCache *services.UserNotificationCache
globalNotificationCache *services.GlobalNotificationCache
clock clockwork.Clock
}

// NewService returns a new notifications gRPC service.
Expand All @@ -83,13 +87,16 @@ func NewService(cfg ServiceConfig) (*Service, error) {
return nil, trace.BadParameter("user notification cache is required")
case cfg.GlobalNotificationCache == nil:
return nil, trace.BadParameter("global notification cache is required")
case cfg.Clock == nil:
cfg.Clock = clockwork.NewRealClock()
}

return &Service{
authorizer: cfg.Authorizer,
backend: cfg.Backend,
userNotificationCache: cfg.UserNotificationCache,
globalNotificationCache: cfg.GlobalNotificationCache,
clock: cfg.Clock,
}, nil
}

Expand Down Expand Up @@ -120,6 +127,12 @@ func (s *Service) ListNotifications(ctx context.Context, req *notificationsv1.Li
startKey = nextKey
}

currentTime := s.clock.Now()
var hasNotificationExpired = func(n *notificationsv1.Notification) bool {
notificationExpiryTime := n.GetMetadata().GetExpires().AsTime()
return currentTime.After(notificationExpiryTime)
}

var userNotifMatchFn = func(n *notificationsv1.Notification) bool {
// Return true if the user hasn't dismissed this notification
return notificationStatesMap[n.GetMetadata().GetName()] != notificationsv1.NotificationState_NOTIFICATION_STATE_DISMISSED
Expand All @@ -134,6 +147,11 @@ func (s *Service) ListNotifications(ctx context.Context, req *notificationsv1.Li
userNotifsStream = stream.FilterMap(
s.userNotificationCache.StreamUserNotifications(ctx, username, userKey),
func(n *notificationsv1.Notification) (*notificationsv1.Notification, bool) {
// If the notification is expired, return false right away.
if hasNotificationExpired(n) {
return nil, false
}

if !userNotifMatchFn(n) {
return nil, false
}
Expand All @@ -145,6 +163,11 @@ func (s *Service) ListNotifications(ctx context.Context, req *notificationsv1.Li
globalNotifsStream = stream.FilterMap(
s.globalNotificationCache.StreamGlobalNotifications(ctx, globalKey),
func(gn *notificationsv1.GlobalNotification) (*notificationsv1.GlobalNotification, bool) {
// If the notification is expired, return false right away.
if hasNotificationExpired(gn.GetSpec().GetNotification()) {
return nil, false
}

if !s.matchGlobalNotification(ctx, authCtx, gn, notificationStatesMap) {
return nil, false
}
Expand Down

0 comments on commit 76187e4

Please sign in to comment.