Skip to content

Commit

Permalink
Update local storage services to prefer ConditionalUpdate (#49182)
Browse files Browse the repository at this point in the history
Migrates existing uses of backend.CompareAndSwap to use
backend.ConditionalUpdate instead. As long as the revision is
being correctly and accurately provided, the conditional update
is a better version of atomic resource update than CAS.
  • Loading branch information
rosstimothy authored Nov 19, 2024
1 parent ce71412 commit 628536b
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 28 deletions.
3 changes: 1 addition & 2 deletions lib/services/local/connection_diagnostic.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ func (s *ConnectionDiagnosticService) UpdateConnectionDiagnostic(ctx context.Con
}

// AppendDiagnosticTrace adds a Trace into the ConnectionDiagnostics.
// It does a CompareAndSwap to ensure atomicity.
func (s *ConnectionDiagnosticService) AppendDiagnosticTrace(ctx context.Context, name string, t *types.ConnectionDiagnosticTrace) (types.ConnectionDiagnostic, error) {
existing, err := s.Get(ctx, backend.NewKey(connectionDiagnosticPrefix, name))
if err != nil {
Expand Down Expand Up @@ -115,7 +114,7 @@ func (s *ConnectionDiagnosticService) AppendDiagnosticTrace(ctx context.Context,
Revision: existing.Revision,
}

_, err = s.CompareAndSwap(ctx, *existing, newItem)
_, err = s.ConditionalUpdate(ctx, newItem)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
10 changes: 4 additions & 6 deletions lib/services/local/dynamic_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ func (s *DynamicAccessService) SetAccessRequestState(ctx context.Context, params
if err != nil {
return nil, trace.Wrap(err)
}
if _, err := s.CompareAndSwap(ctx, *item, newItem); err != nil {
if _, err := s.ConditionalUpdate(ctx, newItem); err != nil {
if trace.IsCompareFailed(err) {
select {
case <-retry.After():
Expand Down Expand Up @@ -195,7 +195,7 @@ func (s *DynamicAccessService) ApplyAccessReview(ctx context.Context, params typ
if err != nil {
return nil, trace.Wrap(err)
}
if _, err := s.CompareAndSwap(ctx, *item, newItem); err != nil {
if _, err := s.ConditionalUpdate(ctx, newItem); err != nil {
if trace.IsCompareFailed(err) {
select {
case <-retry.After():
Expand Down Expand Up @@ -411,10 +411,8 @@ func (s *DynamicAccessService) CreateAccessRequestAllowedPromotions(ctx context.
if err != nil {
return trace.Wrap(err)
}
// Currently, this logic is used only internally (no API exposed), and
// there is only one place that calls it. If this ever changes, we will
// need to do a CompareAndSwap here.
if _, err := s.Put(ctx, item); err != nil {

if _, err := s.Create(ctx, item); err != nil {
return trace.Wrap(err)
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion lib/services/local/plugin_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func (p *PluginDataService) updatePluginData(ctx context.Context, params types.P
return trace.Wrap(err)
}
} else {
if _, err := p.CompareAndSwap(ctx, *item, newItem); err != nil {
if _, err := p.ConditionalUpdate(ctx, newItem); err != nil {
if trace.IsCompareFailed(err) {
select {
case <-retry.After():
Expand Down
2 changes: 1 addition & 1 deletion lib/services/local/plugins.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ func (s *PluginsService) updateAndSwap(ctx context.Context, name string, modify
return trace.Wrap(err)
}

_, err = s.backend.CompareAndSwap(ctx, *item, backend.Item{
_, err = s.backend.ConditionalUpdate(ctx, backend.Item{
Key: backend.NewKey(pluginsPrefix, plugin.GetName()),
Value: value,
Expires: plugin.Expiry(),
Expand Down
12 changes: 6 additions & 6 deletions lib/services/local/presence.go
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ func (s *PresenceService) acquireSemaphore(ctx context.Context, key backend.Key,
if err != nil {
return nil, trace.Wrap(err)
}
sem, err := services.UnmarshalSemaphore(item.Value)
sem, err := services.UnmarshalSemaphore(item.Value, services.WithRevision(item.Revision))
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -711,7 +711,7 @@ func (s *PresenceService) acquireSemaphore(ctx context.Context, key backend.Key,
Revision: rev,
}

if _, err := s.CompareAndSwap(ctx, *item, newItem); err != nil {
if _, err := s.ConditionalUpdate(ctx, newItem); err != nil {
return nil, trace.Wrap(err)
}
return lease, nil
Expand All @@ -737,7 +737,7 @@ func (s *PresenceService) KeepAliveSemaphoreLease(ctx context.Context, lease typ
return trace.Wrap(err)
}

sem, err := services.UnmarshalSemaphore(item.Value)
sem, err := services.UnmarshalSemaphore(item.Value, services.WithRevision(item.Revision))
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -761,7 +761,7 @@ func (s *PresenceService) KeepAliveSemaphoreLease(ctx context.Context, lease typ
Revision: rev,
}

_, err = s.CompareAndSwap(ctx, *item, newItem)
_, err = s.ConditionalUpdate(ctx, newItem)
if err != nil {
if trace.IsCompareFailed(err) {
return trace.CompareFailed("semaphore %v/%v has been concurrently updated, try again", sem.GetSubKind(), sem.GetName())
Expand Down Expand Up @@ -801,7 +801,7 @@ func (s *PresenceService) CancelSemaphoreLease(ctx context.Context, lease types.
return trace.Wrap(err)
}

sem, err := services.UnmarshalSemaphore(item.Value)
sem, err := services.UnmarshalSemaphore(item.Value, services.WithRevision(item.Revision))
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -823,7 +823,7 @@ func (s *PresenceService) CancelSemaphoreLease(ctx context.Context, lease types.
Revision: rev,
}

_, err = s.CompareAndSwap(ctx, *item, newItem)
_, err = s.ConditionalUpdate(ctx, newItem)
switch {
case err == nil:
return nil
Expand Down
22 changes: 11 additions & 11 deletions lib/services/local/sessiontracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ import (
)

const (
sessionPrefix = "session_tracker"
retryDelay = time.Second
terminatedTTL = 3 * time.Minute
casRetryLimit = 7
casErrorMessage = "CompareAndSwap reached retry limit"
sessionPrefix = "session_tracker"
retryDelay = time.Second
terminatedTTL = 3 * time.Minute
updateRetryLimit = 7
updateRetryLimitMessage = "Update retry limit reached"
)

type sessionTracker struct {
Expand All @@ -63,7 +63,7 @@ func (s *sessionTracker) loadSession(ctx context.Context, sessionID string) (typ

// UpdatePresence updates the presence status of a user in a session.
func (s *sessionTracker) UpdatePresence(ctx context.Context, sessionID, user string) error {
for i := 0; i < casRetryLimit; i++ {
for i := 0; i < updateRetryLimit; i++ {
sessionItem, err := s.bk.Get(ctx, backend.NewKey(sessionPrefix, sessionID))
if err != nil {
return trace.Wrap(err)
Expand All @@ -89,7 +89,7 @@ func (s *sessionTracker) UpdatePresence(ctx context.Context, sessionID, user str
Expires: session.Expiry(),
Revision: sessionItem.Revision,
}
_, err = s.bk.CompareAndSwap(ctx, *sessionItem, item)
_, err = s.bk.ConditionalUpdate(ctx, item)
if trace.IsCompareFailed(err) {
select {
case <-ctx.Done():
Expand All @@ -102,7 +102,7 @@ func (s *sessionTracker) UpdatePresence(ctx context.Context, sessionID, user str
return trace.Wrap(err)
}

return trace.CompareFailed(casErrorMessage)
return trace.CompareFailed(updateRetryLimitMessage)
}

// GetSessionTracker returns the current state of a session tracker for an active session.
Expand Down Expand Up @@ -202,7 +202,7 @@ func (s *sessionTracker) CreateSessionTracker(ctx context.Context, tracker types

// UpdateSessionTracker updates a tracker resource for an active session.
func (s *sessionTracker) UpdateSessionTracker(ctx context.Context, req *proto.UpdateSessionTrackerRequest) error {
for i := 0; i < casRetryLimit; i++ {
for i := 0; i < updateRetryLimit; i++ {
sessionItem, err := s.bk.Get(ctx, backend.NewKey(sessionPrefix, req.SessionID))
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -268,7 +268,7 @@ func (s *sessionTracker) UpdateSessionTracker(ctx context.Context, req *proto.Up
Expires: expiry,
Revision: sessionItem.Revision,
}
_, err = s.bk.CompareAndSwap(ctx, *sessionItem, item)
_, err = s.bk.ConditionalUpdate(ctx, item)
if trace.IsCompareFailed(err) {
select {
case <-ctx.Done():
Expand All @@ -281,7 +281,7 @@ func (s *sessionTracker) UpdateSessionTracker(ctx context.Context, req *proto.Up
return trace.Wrap(err)
}

return trace.CompareFailed(casErrorMessage)
return trace.CompareFailed(updateRetryLimitMessage)
}

// RemoveSessionTracker removes a tracker resource for an active session.
Expand Down
3 changes: 2 additions & 1 deletion lib/services/local/unstable.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ func (s UnstableService) AssertSystemRole(ctx context.Context, req proto.SystemR
Expires: time.Now().Add(assertionTTL).UTC(),
}
if item != nil {
_, err = s.CompareAndSwap(ctx, *item, newItem)
newItem.Revision = item.Revision
_, err = s.ConditionalUpdate(ctx, newItem)
if trace.IsCompareFailed(err) {
// nodes are expected to perform assertions sequentially
return trace.CompareFailed("system role assertion set was concurrently modified (this is bug)")
Expand Down

0 comments on commit 628536b

Please sign in to comment.