Skip to content

Commit

Permalink
fix: handle serialization errors that can be thrown by call to 'Commi…
Browse files Browse the repository at this point in the history
…t' (#403)
  • Loading branch information
vancity-amir authored Mar 26, 2020
1 parent b1ba04e commit 35a1558
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 9 deletions.
20 changes: 11 additions & 9 deletions handler/oauth2/flow_refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,22 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con

ts, err := c.TokenRevocationStorage.GetRefreshTokenSession(ctx, signature, nil)
if err != nil {
return handleRefreshTokenEndpointResponseStorageError(ctx, c.TokenRevocationStorage, err)
return handleRefreshTokenEndpointResponseStorageError(ctx, true, c.TokenRevocationStorage, err)
} else if err := c.TokenRevocationStorage.RevokeAccessToken(ctx, ts.GetID()); err != nil {
return handleRefreshTokenEndpointResponseStorageError(ctx, c.TokenRevocationStorage, err)
return handleRefreshTokenEndpointResponseStorageError(ctx, true, c.TokenRevocationStorage, err)
} else if err := c.TokenRevocationStorage.RevokeRefreshToken(ctx, ts.GetID()); err != nil {
return handleRefreshTokenEndpointResponseStorageError(ctx, c.TokenRevocationStorage, err)
return handleRefreshTokenEndpointResponseStorageError(ctx, true, c.TokenRevocationStorage, err)
}

storeReq := requester.Sanitize([]string{})
storeReq.SetID(ts.GetID())

if err := c.TokenRevocationStorage.CreateAccessTokenSession(ctx, accessSignature, storeReq); err != nil {
return handleRefreshTokenEndpointResponseStorageError(ctx, c.TokenRevocationStorage, err)
return handleRefreshTokenEndpointResponseStorageError(ctx, true, c.TokenRevocationStorage, err)
}

if err := c.TokenRevocationStorage.CreateRefreshTokenSession(ctx, refreshSignature, storeReq); err != nil {
return handleRefreshTokenEndpointResponseStorageError(ctx, c.TokenRevocationStorage, err)
return handleRefreshTokenEndpointResponseStorageError(ctx, true, c.TokenRevocationStorage, err)
}

responder.SetAccessToken(accessToken)
Expand All @@ -162,16 +162,18 @@ func (c *RefreshTokenGrantHandler) PopulateTokenEndpointResponse(ctx context.Con
responder.SetExtra("refresh_token", refreshToken)

if err := storage.MaybeCommitTx(ctx, c.TokenRevocationStorage); err != nil {
return errors.WithStack(fosite.ErrServerError.WithDebug(err.Error()))
return handleRefreshTokenEndpointResponseStorageError(ctx, false, c.TokenRevocationStorage, err)
}

return nil
}

func handleRefreshTokenEndpointResponseStorageError(ctx context.Context, store TokenRevocationStorage, storageErr error) (err error) {
func handleRefreshTokenEndpointResponseStorageError(ctx context.Context, rollback bool, store TokenRevocationStorage, storageErr error) (err error) {
defer func() {
if rbErr := storage.MaybeRollbackTx(ctx, store); rbErr != nil {
err = errors.WithStack(fosite.ErrServerError.WithDebug(rbErr.Error()))
if rollback {
if rbErr := storage.MaybeRollbackTx(ctx, store); rbErr != nil {
err = errors.WithStack(fosite.ErrServerError.WithDebug(rbErr.Error()))
}
}
}()

Expand Down
43 changes: 43 additions & 0 deletions handler/oauth2/flow_refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,49 @@ func TestRefreshFlowTransactional_PopulateTokenEndpointResponse(t *testing.T) {
},
expectError: fosite.ErrServerError,
},
{
description: "should result in a `fosite.ErrInvalidRequest` if transaction fails to commit due to a " +
"`fosite.ErrSerializationFailure` error",
setup: func() {
request.GrantTypes = fosite.Arguments{"refresh_token"}
mockTransactional.
EXPECT().
BeginTX(propagatedContext).
Return(propagatedContext, nil).
Times(1)
mockRevocationStore.
EXPECT().
GetRefreshTokenSession(propagatedContext, gomock.Any(), nil).
Return(request, nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeAccessToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
RevokeRefreshToken(propagatedContext, gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
CreateAccessTokenSession(propagatedContext, gomock.Any(), gomock.Any()).
Return(nil).
Times(1)
mockRevocationStore.
EXPECT().
CreateRefreshTokenSession(propagatedContext, gomock.Any(), gomock.Any()).
Return(nil).
Times(1)
mockTransactional.
EXPECT().
Commit(propagatedContext).
Return(fosite.ErrSerializationFailure).
Times(1)
},
expectError: fosite.ErrInvalidRequest,
},
} {
t.Run(fmt.Sprintf("scenario=%s", testCase.description), func(t *testing.T) {
ctrl := gomock.NewController(t)
Expand Down

0 comments on commit 35a1558

Please sign in to comment.