Skip to content

Commit

Permalink
refactor: merge user and device code storage
Browse files Browse the repository at this point in the history
  • Loading branch information
nsklikas committed Nov 13, 2024
1 parent f13fb21 commit b136580
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 219 deletions.
40 changes: 12 additions & 28 deletions handler/rfc8628/auth_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,7 @@ type DeviceAuthHandler struct {
func (d *DeviceAuthHandler) HandleDeviceEndpointRequest(ctx context.Context, dar fosite.DeviceRequester, resp fosite.DeviceResponder) error {
var err error

var deviceCode string
deviceCode, err = d.handleDeviceCode(ctx, dar)
if err != nil {
return err
}

var userCode string
userCode, err = d.handleUserCode(ctx, dar)
deviceCode, userCode, err := d.handleDeviceAuthSession(ctx, dar)
if err != nil {
return err
}
Expand All @@ -52,41 +45,32 @@ func (d *DeviceAuthHandler) HandleDeviceEndpointRequest(ctx context.Context, dar
return nil
}

func (d *DeviceAuthHandler) handleDeviceCode(ctx context.Context, dar fosite.DeviceRequester) (string, error) {
code, signature, err := d.Strategy.GenerateDeviceCode(ctx)
if err != nil {
return "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}
func (d *DeviceAuthHandler) handleDeviceAuthSession(ctx context.Context, dar fosite.DeviceRequester) (string, string, error) {
var userCode, userCodeSignature string

dar.GetSession().SetExpiresAt(fosite.DeviceCode, time.Now().UTC().Add(d.Config.GetDeviceAndUserCodeLifespan(ctx)))
if err = d.Storage.CreateDeviceCodeSession(ctx, signature, dar.Sanitize(nil)); err != nil {
return "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
deviceCode, deviceCodeSignature, err := d.Strategy.GenerateDeviceCode(ctx)
if err != nil {
return "", "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}

return code, nil
}

func (d *DeviceAuthHandler) handleUserCode(ctx context.Context, dar fosite.DeviceRequester) (string, error) {
var err error
var userCode, signature string
dar.GetSession().SetExpiresAt(fosite.UserCode, time.Now().UTC().Add(d.Config.GetDeviceAndUserCodeLifespan(ctx)).Round(time.Second))
// Note: the retries are added here because we need to ensure uniqueness of user codes.
// The chances of duplicates should however be diminishing, because they are the same
// chance an attacker will be able to hit a valid code with few guesses. However, as
// used codes will probably still be around for some time before they get cleaned,
// the chances of hitting a duplicate here can be higher.
// Three retries should be plenty, as otherwise the entropy is definitely off.
for i := 0; i < MaxAttempts; i++ {
userCode, signature, err = d.Strategy.GenerateUserCode(ctx)
userCode, userCodeSignature, err = d.Strategy.GenerateUserCode(ctx)
if err != nil {
return "", err
return "", "", err
}

dar.GetSession().SetExpiresAt(fosite.UserCode, time.Now().UTC().Add(d.Config.GetDeviceAndUserCodeLifespan(ctx)).Round(time.Second))
if err = d.Storage.CreateUserCodeSession(ctx, signature, dar.Sanitize(nil)); err == nil {
return userCode, nil
if err = d.Storage.CreateDeviceAuthSession(ctx, deviceCodeSignature, userCodeSignature, dar.Sanitize(nil)); err == nil {
return deviceCode, userCode, nil
}
}

errMsg := fmt.Sprintf("Exceeded user-code generation max attempts %v: %s", MaxAttempts, err.Error())
return "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(errMsg))
return "", "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(errMsg))
}
25 changes: 6 additions & 19 deletions handler/rfc8628/auth_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,20 +85,15 @@ func Test_HandleDeviceEndpointRequestWithRetry(t *testing.T) {
EXPECT().
GenerateDeviceCode(ctx).
Return("deviceCode", "signature", nil)
mockRFC8628CoreStorage.
EXPECT().
CreateDeviceCodeSession(ctx, "signature", gomock.Any()).
Return(nil)
mockRFC8628CodeStrategy.
EXPECT().
GenerateUserCode(ctx).
Return("userCode", "signature", nil).
Return("userCode", "signature2", nil).
Times(1)
mockRFC8628CoreStorage.
EXPECT().
CreateUserCodeSession(ctx, "signature", gomock.Any()).
Return(nil).
Times(1)
CreateDeviceAuthSession(ctx, "signature", "signature2", gomock.Any()).
Return(nil)
},
check: func(t *testing.T, resp *fosite.DeviceResponse) {
assert.Equal(t, "userCode", resp.GetUserCode())
Expand All @@ -111,26 +106,22 @@ func Test_HandleDeviceEndpointRequestWithRetry(t *testing.T) {
EXPECT().
GenerateDeviceCode(ctx).
Return("deviceCode", "signature", nil)
mockRFC8628CoreStorage.
EXPECT().
CreateDeviceCodeSession(ctx, "signature", gomock.Any()).
Return(nil)
gomock.InOrder(
mockRFC8628CodeStrategy.
EXPECT().
GenerateUserCode(ctx).
Return("duplicatedUserCode", "duplicatedSignature", nil),
mockRFC8628CoreStorage.
EXPECT().
CreateUserCodeSession(ctx, "duplicatedSignature", gomock.Any()).
CreateDeviceAuthSession(ctx, "signature", "duplicatedSignature", gomock.Any()).
Return(errors.New("unique constraint violation")),
mockRFC8628CodeStrategy.
EXPECT().
GenerateUserCode(ctx).
Return("uniqueUserCode", "uniqueSignature", nil),
mockRFC8628CoreStorage.
EXPECT().
CreateUserCodeSession(ctx, "uniqueSignature", gomock.Any()).
CreateDeviceAuthSession(ctx, "signature", "uniqueSignature", gomock.Any()).
Return(nil),
)
},
Expand All @@ -145,18 +136,14 @@ func Test_HandleDeviceEndpointRequestWithRetry(t *testing.T) {
EXPECT().
GenerateDeviceCode(ctx).
Return("deviceCode", "signature", nil)
mockRFC8628CoreStorage.
EXPECT().
CreateDeviceCodeSession(ctx, "signature", gomock.Any()).
Return(nil)
mockRFC8628CodeStrategy.
EXPECT().
GenerateUserCode(ctx).
Return("duplicatedUserCode", "duplicatedSignature", nil).
Times(rfc8628.MaxAttempts)
mockRFC8628CoreStorage.
EXPECT().
CreateUserCodeSession(ctx, "duplicatedSignature", gomock.Any()).
CreateDeviceAuthSession(ctx, "signature", "duplicatedSignature", gomock.Any()).
Return(errors.New("unique constraint violation")).
Times(rfc8628.MaxAttempts)
},
Expand Down
31 changes: 6 additions & 25 deletions handler/rfc8628/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,15 @@ import (

// RFC8628CoreStorage is the storage needed for the DeviceAuthHandler
type RFC8628CoreStorage interface {
DeviceCodeStorage
UserCodeStorage
DeviceAuthStorage
oauth2.AccessTokenStorage
oauth2.RefreshTokenStorage
}

// DeviceCodeStorage handles the device_code storage
type DeviceCodeStorage interface {
// CreateDeviceCodeSession stores the device request for a given device code.
CreateDeviceCodeSession(ctx context.Context, signature string, request fosite.Requester) (err error)
// DeviceAuthStorage handles the device auth session storage
type DeviceAuthStorage interface {
// CreateDeviceAuthSession stores the device auth request session.
CreateDeviceAuthSession(ctx context.Context, deviceCodeSignature, userCodeSignature string, request fosite.Requester) (err error)

// GetDeviceCodeSession hydrates the session based on the given device code and returns the device request.
// If the device code has been invalidated with `InvalidateDeviceCodeSession`, this
Expand All @@ -30,26 +29,8 @@ type DeviceCodeStorage interface {
// Make sure to also return the fosite.Requester value when returning the fosite.ErrInvalidatedDeviceCode error!
GetDeviceCodeSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error)

// InvalidateDeviceCodeSession is called when a device code is being used. The state of the user
// InvalidateDeviceCodeSession is called when a device code is being used. The state of the device
// code should be set to invalid and consecutive requests to GetDeviceCodeSession should return the
// ErrInvalidatedDeviceCode error.
InvalidateDeviceCodeSession(ctx context.Context, signature string) (err error)
}

// UserCodeStorage handles the user_code storage
type UserCodeStorage interface {
// CreateUserCodeSession stores the device request for a given user code.
CreateUserCodeSession(ctx context.Context, signature string, request fosite.Requester) (err error)

// GetUserCodeSession hydrates the session based on the given user code and returns the device request.
// If the user code has been invalidated with `InvalidateUserCodeSession`, this
// method should return the ErrInvalidatedUserCode error.
//
// Make sure to also return the fosite.Requester value when returning the fosite.ErrInvalidatedUserCode error!
GetUserCodeSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error)

// InvalidateUserCodeSession is called when a user code is being used. The state of the user
// code should be set to invalid and consecutive requests to GetUserCodeSession should return the
// ErrInvalidatedUserCode error.
InvalidateUserCodeSession(ctx context.Context, signature string) (err error)
}
32 changes: 24 additions & 8 deletions handler/rfc8628/token_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,11 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) {
setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest) {
code, signature, err := strategy.GenerateDeviceCode(context.TODO())
require.NoError(t, err)
_, userCodeSignature, err := strategy.GenerateUserCode(context.TODO())
require.NoError(t, err)
areq.Form.Add("device_code", code)

require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq))
require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq))
},
expectErr: fosite.ErrAuthorizationPending,
},
Expand Down Expand Up @@ -192,9 +194,11 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) {
setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest) {
code, signature, err := strategy.GenerateDeviceCode(context.TODO())
require.NoError(t, err)
_, userCodeSignature, err := strategy.GenerateUserCode(context.TODO())
require.NoError(t, err)
areq.Form.Add("device_code", code)

require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq))
require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq))
},
expectErr: fosite.ErrDeviceExpiredToken,
},
Expand Down Expand Up @@ -227,9 +231,11 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) {
setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest) {
token, signature, err := strategy.GenerateDeviceCode(context.TODO())
require.NoError(t, err)
_, userCodeSignature, err := strategy.GenerateUserCode(context.TODO())
require.NoError(t, err)
areq.Form = url.Values{"device_code": {token}}

require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq))
require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq))
},
expectErr: fosite.ErrInvalidGrant,
},
Expand Down Expand Up @@ -263,9 +269,11 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) {
setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest) {
token, signature, err := strategy.GenerateDeviceCode(context.TODO())
require.NoError(t, err)
_, userCodeSignature, err := strategy.GenerateUserCode(context.TODO())
require.NoError(t, err)

areq.Form = url.Values{"device_code": {token}}
require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq))
require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq))
},
},
}
Expand Down Expand Up @@ -342,9 +350,11 @@ func TestDeviceUserCode_HandleTokenEndpointRequest_RateLimiting(t *testing.T) {

token, signature, err := strategy.GenerateDeviceCode(context.TODO())
require.NoError(t, err)
_, userCodeSignature, err := strategy.GenerateUserCode(context.TODO())
require.NoError(t, err)

areq.Form = url.Values{"device_code": {token}}
require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq))
require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq))
err = h.HandleTokenEndpointRequest(context.Background(), areq)
require.NoError(t, err, "%+v", err)
err = h.HandleTokenEndpointRequest(context.Background(), areq)
Expand Down Expand Up @@ -441,9 +451,11 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) {
setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest, _ *fosite.Config) {
code, signature, err := strategy.GenerateDeviceCode(context.TODO())
require.NoError(t, err)
_, userCodeSignature, err := strategy.GenerateUserCode(context.TODO())
require.NoError(t, err)
areq.Form.Add("device_code", code)

require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq))
require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq))
},
check: func(t *testing.T, aresp *fosite.AccessResponse) {
assert.NotEmpty(t, aresp.AccessToken)
Expand Down Expand Up @@ -483,9 +495,11 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) {
config.RefreshTokenScopes = []string{}
code, signature, err := strategy.GenerateDeviceCode(context.TODO())
require.NoError(t, err)
_, userCodeSignature, err := strategy.GenerateUserCode(context.TODO())
require.NoError(t, err)
areq.Form.Add("device_code", code)

require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq))
require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq))
},
check: func(t *testing.T, aresp *fosite.AccessResponse) {
assert.NotEmpty(t, aresp.AccessToken)
Expand Down Expand Up @@ -524,9 +538,11 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) {
setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest, config *fosite.Config) {
code, signature, err := strategy.GenerateDeviceCode(context.TODO())
require.NoError(t, err)
_, userCodeSignature, err := strategy.GenerateUserCode(context.TODO())
require.NoError(t, err)
areq.Form.Add("device_code", code)

require.NoError(t, store.CreateDeviceCodeSession(context.TODO(), signature, authreq))
require.NoError(t, store.CreateDeviceAuthSession(context.TODO(), signature, userCodeSignature, authreq))
},
check: func(t *testing.T, aresp *fosite.AccessResponse) {
assert.NotEmpty(t, aresp.AccessToken)
Expand Down
5 changes: 2 additions & 3 deletions integration/helper_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,8 @@ var fositeStore = &storage.MemoryStore{
AccessTokenRequestIDs: map[string]string{},
RefreshTokenRequestIDs: map[string]string{},
PARSessions: map[string]fosite.AuthorizeRequester{},
DeviceCodes: map[string]fosite.Requester{},
UserCodes: map[string]fosite.Requester{},
DeviceCodesRequestIDs: map[string]string{},
DeviceAuths: map[string]fosite.Requester{},
DeviceCodesRequestIDs: map[string]storage.DeviceAuthPair{},
UserCodesRequestIDs: map[string]string{},
}

Expand Down
Loading

0 comments on commit b136580

Please sign in to comment.