Skip to content

Commit

Permalink
fix(op): do not redirect to unverified uri on error (#641)
Browse files Browse the repository at this point in the history
* fix(op): do not redirect to unverified uri on error

Backport of #640
Related to #627

* adjust tests
  • Loading branch information
muhlemmer authored Aug 21, 2024
1 parent 75759d9 commit e8769ce
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 25 deletions.
49 changes: 33 additions & 16 deletions pkg/op/auth_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,19 +75,27 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
if authReq.RequestParam != "" && authorizer.RequestObjectSupported() {
authReq, err = ParseRequestObject(ctx, authReq, authorizer.Storage(), IssuerFromContext(ctx))
if err != nil {
AuthRequestError(w, r, authReq, err, authorizer.Encoder())
AuthRequestError(w, r, nil, err, authorizer.Encoder())
return
}
}
if authReq.ClientID == "" {
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing client_id"), authorizer.Encoder())
AuthRequestError(w, r, nil, fmt.Errorf("auth request is missing client_id"), authorizer.Encoder())
return
}
if authReq.RedirectURI == "" {
AuthRequestError(w, r, authReq, fmt.Errorf("auth request is missing redirect_uri"), authorizer.Encoder())
AuthRequestError(w, r, nil, fmt.Errorf("auth request is missing redirect_uri"), authorizer.Encoder())
return
}
validation := ValidateAuthRequest

var client Client
validation := func(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (sub string, err error) {
client, err = authorizer.Storage().GetClientByClientID(ctx, authReq.ClientID)
if err != nil {
return "", oidc.ErrInvalidRequestRedirectURI().WithDescription("unable to retrieve client by id").WithParent(err)
}
return ValidateAuthRequestClient(ctx, authReq, client, verifier)
}
if validater, ok := authorizer.(AuthorizeValidator); ok {
validation = validater.ValidateAuthRequest
}
Expand All @@ -105,11 +113,6 @@ func Authorize(w http.ResponseWriter, r *http.Request, authorizer Authorizer) {
AuthRequestError(w, r, authReq, oidc.DefaultToServerError(err, "unable to save auth request"), authorizer.Encoder())
return
}
client, err := authorizer.Storage().GetClientByClientID(ctx, req.GetClientID())
if err != nil {
AuthRequestError(w, r, req, oidc.DefaultToServerError(err, "unable to retrieve client by id"), authorizer.Encoder())
return
}
RedirectToLogin(req.GetID(), client, w, r)
}

Expand Down Expand Up @@ -204,23 +207,37 @@ func CopyRequestObjectToAuthRequest(authReq *oidc.AuthRequest, requestObject *oi
authReq.RequestParam = ""
}

// ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed
// ValidateAuthRequest validates the authorize parameters and returns the userID of the id_token_hint if passed.
//
// Deprecated: Use [ValidateAuthRequestClient] to prevent querying for the Client twice.
func ValidateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, storage Storage, verifier IDTokenHintVerifier) (sub string, err error) {
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
ctx, span := tracer.Start(ctx, "ValidateAuthRequest")
defer span.End()

client, err := storage.GetClientByClientID(ctx, authReq.ClientID)
if err != nil {
return "", oidc.ErrInvalidRequestRedirectURI().WithDescription("unable to retrieve client by id").WithParent(err)
}
return ValidateAuthRequestClient(ctx, authReq, client, verifier)
}

// ValidateAuthRequestClient validates the Auth request against the passed client.
// If id_token_hint is part of the request, the subject of the token is returned.
func ValidateAuthRequestClient(ctx context.Context, authReq *oidc.AuthRequest, client Client, verifier IDTokenHintVerifier) (sub string, err error) {
ctx, span := tracer.Start(ctx, "ValidateAuthRequestClient")
defer span.End()

if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil {
return "", err
}
client, err := storage.GetClientByClientID(ctx, authReq.ClientID)
authReq.MaxAge, err = ValidateAuthReqPrompt(authReq.Prompt, authReq.MaxAge)
if err != nil {
return "", oidc.DefaultToServerError(err, "unable to retrieve client by id")
return "", err
}
authReq.Scopes, err = ValidateAuthReqScopes(client, authReq.Scopes)
if err != nil {
return "", err
}
if err := ValidateAuthReqRedirectURI(client, authReq.RedirectURI, authReq.ResponseType); err != nil {
return "", err
}
if err := ValidateAuthReqResponseType(client, authReq.ResponseType); err != nil {
return "", err
}
Expand Down
13 changes: 4 additions & 9 deletions pkg/op/auth_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,27 +132,22 @@ func TestValidateAuthRequest(t *testing.T) {
}{
{
"scope missing fails",
args{&oidc.AuthRequest{}, mock.NewMockStorageExpectValidClientID(t), nil},
args{&oidc.AuthRequest{ClientID: "client_id", RedirectURI: "https://registered.com/callback"}, mock.NewMockStorageExpectValidClientID(t), nil},
oidc.ErrInvalidRequest(),
},
{
"scope openid missing fails",
args{&oidc.AuthRequest{Scopes: []string{"profile"}}, mock.NewMockStorageExpectValidClientID(t), nil},
args{&oidc.AuthRequest{ClientID: "client_id", RedirectURI: "https://registered.com/callback", Scopes: []string{"profile"}}, mock.NewMockStorageExpectValidClientID(t), nil},
oidc.ErrInvalidScope(),
},
{
"response_type missing fails",
args{&oidc.AuthRequest{Scopes: []string{"openid"}}, mock.NewMockStorageExpectValidClientID(t), nil},
oidc.ErrInvalidRequest(),
},
{
"client_id missing fails",
args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode}, mock.NewMockStorageExpectValidClientID(t), nil},
args{&oidc.AuthRequest{ClientID: "client_id", RedirectURI: "https://registered.com/callback", Scopes: []string{"openid"}}, mock.NewMockStorageExpectValidClientID(t), nil},
oidc.ErrInvalidRequest(),
},
{
"redirect_uri missing fails",
args{&oidc.AuthRequest{Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode, ClientID: "client_id"}, mock.NewMockStorageExpectValidClientID(t), nil},
args{&oidc.AuthRequest{ClientID: "client_id", Scopes: []string{"openid"}, ResponseType: oidc.ResponseTypeCode}, mock.NewMockStorageExpectValidClientID(t), nil},
oidc.ErrInvalidRequest(),
},
}
Expand Down

0 comments on commit e8769ce

Please sign in to comment.