diff --git a/pkg/client/integration_test.go b/pkg/client/integration_test.go index d8b3f255..073efef7 100644 --- a/pkg/client/integration_test.go +++ b/pkg/client/integration_test.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "io" - "io/ioutil" "math/rand" "net/http" "net/http/cookiejar" @@ -56,11 +55,11 @@ func TestRelyingPartySession(t *testing.T) { clientID := t.Name() + "-" + strconv.FormatInt(seed.Int63(), 25) t.Log("------- run authorization code flow ------") - provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, "secret") + provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, "secret") t.Log("------- refresh tokens ------") - newTokens, err := rp.RefreshAccessToken(CTX, provider, refreshToken, "", "") + newTokens, err := rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "") require.NoError(t, err, "refresh token") assert.NotNil(t, newTokens, "access token") t.Logf("new access token %s", newTokens.AccessToken) @@ -68,11 +67,13 @@ func TestRelyingPartySession(t *testing.T) { t.Logf("new token type %s", newTokens.TokenType) t.Logf("new expiry %s", newTokens.Expiry.Format(time.RFC3339)) require.NotEmpty(t, newTokens.AccessToken, "new accessToken") - assert.NotEmpty(t, newTokens.Extra("id_token"), "new idToken") + assert.NotEmpty(t, newTokens.IDToken, "new idToken") + assert.NotNil(t, newTokens.IDTokenClaims) + assert.Equal(t, newTokens.IDTokenClaims.Subject, tokens.IDTokenClaims.Subject) t.Log("------ end session (logout) ------") - newLoc, err := rp.EndSession(CTX, provider, idToken, "", "") + newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "") require.NoError(t, err, "logout") if newLoc != nil { t.Logf("redirect to %s", newLoc) @@ -81,12 +82,12 @@ func TestRelyingPartySession(t *testing.T) { } t.Log("------ attempt refresh again (should fail) ------") - t.Log("trying original refresh token", refreshToken) - _, err = rp.RefreshAccessToken(CTX, provider, refreshToken, "", "") + t.Log("trying original refresh token", tokens.RefreshToken) + _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, tokens.RefreshToken, "", "") assert.Errorf(t, err, "refresh with original") if newTokens.RefreshToken != "" { t.Log("trying replacement refresh token", newTokens.RefreshToken) - _, err = rp.RefreshAccessToken(CTX, provider, newTokens.RefreshToken, "", "") + _, err = rp.RefreshTokens[*oidc.IDTokenClaims](CTX, provider, newTokens.RefreshToken, "", "") assert.Errorf(t, err, "refresh with replacement") } } @@ -106,7 +107,7 @@ func TestResourceServerTokenExchange(t *testing.T) { clientSecret := "secret" t.Log("------- run authorization code flow ------") - provider, _, refreshToken, idToken := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret) + provider, tokens := RunAuthorizationCodeFlow(t, opServer, clientID, clientSecret) resourceServer, err := rs.NewResourceServerClientCredentials(CTX, opServer.URL, clientID, clientSecret) require.NoError(t, err, "new resource server") @@ -116,7 +117,7 @@ func TestResourceServerTokenExchange(t *testing.T) { tokenExchangeResponse, err := tokenexchange.ExchangeToken( CTX, resourceServer, - refreshToken, + tokens.RefreshToken, oidc.RefreshTokenType, "", "", @@ -134,7 +135,7 @@ func TestResourceServerTokenExchange(t *testing.T) { t.Log("------ end session (logout) ------") - newLoc, err := rp.EndSession(CTX, provider, idToken, "", "") + newLoc, err := rp.EndSession(CTX, provider, tokens.IDToken, "", "") require.NoError(t, err, "logout") if newLoc != nil { t.Logf("redirect to %s", newLoc) @@ -147,7 +148,7 @@ func TestResourceServerTokenExchange(t *testing.T) { tokenExchangeResponse, err = tokenexchange.ExchangeToken( CTX, resourceServer, - refreshToken, + tokens.RefreshToken, oidc.RefreshTokenType, "", "", @@ -161,7 +162,7 @@ func TestResourceServerTokenExchange(t *testing.T) { require.Nil(t, tokenExchangeResponse, "token exchange response") } -func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, accessToken, refreshToken, idToken string) { +func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, clientSecret string) (provider rp.RelyingParty, tokens *oidc.Tokens[*oidc.IDTokenClaims]) { targetURL := "http://local-site" localURL, err := url.Parse(targetURL + "/login?requestID=1234") require.NoError(t, err, "local url") @@ -258,7 +259,8 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, } var email string - redirect := func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) { + redirect := func(w http.ResponseWriter, r *http.Request, newTokens *oidc.Tokens[*oidc.IDTokenClaims], state string, rp rp.RelyingParty, info *oidc.UserInfo) { + tokens = newTokens require.NotNil(t, tokens, "tokens") require.NotNil(t, info, "info") t.Log("access token", tokens.AccessToken) @@ -266,9 +268,6 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, t.Log("id token", tokens.IDToken) t.Log("email", info.Email) - accessToken = tokens.AccessToken - refreshToken = tokens.RefreshToken - idToken = tokens.IDToken email = info.Email http.Redirect(w, r, targetURL, 302) } @@ -290,12 +289,12 @@ func RunAuthorizationCodeFlow(t *testing.T, opServer *httptest.Server, clientID, require.NoError(t, err, "get fully-authorizied redirect location") require.Equal(t, targetURL, authorizedURL.String(), "fully-authorizied redirect location") - require.NotEmpty(t, idToken, "id token") - assert.NotEmpty(t, refreshToken, "refresh token") - assert.NotEmpty(t, accessToken, "access token") + require.NotEmpty(t, tokens.IDToken, "id token") + assert.NotEmpty(t, tokens.RefreshToken, "refresh token") + assert.NotEmpty(t, tokens.AccessToken, "access token") assert.NotEmpty(t, email, "email") - return provider, accessToken, refreshToken, idToken + return provider, tokens } type deferredHandler struct { @@ -343,7 +342,7 @@ func getForm(t *testing.T, desc string, httpClient *http.Client, uri *url.URL) [ func fillForm(t *testing.T, desc string, httpClient *http.Client, body []byte, uri *url.URL, opts ...gosubmit.Option) *url.URL { // TODO: switch to io.NopCloser when go1.15 support is dropped - req := gosubmit.ParseWithURL(ioutil.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest( + req := gosubmit.ParseWithURL(io.NopCloser(bytes.NewReader(body)), uri.String()).FirstForm().Testing(t).NewTestRequest( append([]gosubmit.Option{gosubmit.AutoFill()}, opts...)..., ) if req.URL.Scheme == "" { diff --git a/pkg/client/rp/relying_party.go b/pkg/client/rp/relying_party.go index 7d73a5a2..5597c9d9 100644 --- a/pkg/client/rp/relying_party.go +++ b/pkg/client/rp/relying_party.go @@ -356,6 +356,25 @@ func GenerateAndStoreCodeChallenge(w http.ResponseWriter, rp RelyingParty) (stri return oidc.NewSHACodeChallenge(codeVerifier), nil } +// ErrMissingIDToken is returned when an id_token was expected, +// but not received in the token response. +var ErrMissingIDToken = errors.New("id_token missing") + +func verifyTokenResponse[C oidc.IDClaims](ctx context.Context, token *oauth2.Token, rp RelyingParty) (*oidc.Tokens[C], error) { + if rp.IsOAuth2Only() { + return &oidc.Tokens[C]{Token: token}, nil + } + idTokenString, ok := token.Extra(idTokenKey).(string) + if !ok { + return &oidc.Tokens[C]{Token: token}, ErrMissingIDToken + } + idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier()) + if err != nil { + return nil, err + } + return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil +} + // CodeExchange handles the oauth2 code exchange, extracting and validating the id_token // returning it parsed together with the oauth2 tokens (access, refresh) func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingParty, opts ...CodeExchangeOpt) (tokens *oidc.Tokens[C], err error) { @@ -369,22 +388,7 @@ func CodeExchange[C oidc.IDClaims](ctx context.Context, code string, rp RelyingP if err != nil { return nil, err } - - if rp.IsOAuth2Only() { - return &oidc.Tokens[C]{Token: token}, nil - } - - idTokenString, ok := token.Extra(idTokenKey).(string) - if !ok { - return nil, errors.New("id_token missing") - } - - idToken, err := VerifyTokens[C](ctx, token.AccessToken, idTokenString, rp.IDTokenVerifier()) - if err != nil { - return nil, err - } - - return &oidc.Tokens[C]{Token: token, IDTokenClaims: idToken, IDToken: idTokenString}, nil + return verifyTokenResponse[C](ctx, token, rp) } type CodeExchangeCallback[C oidc.IDClaims] func(w http.ResponseWriter, r *http.Request, tokens *oidc.Tokens[C], state string, rp RelyingParty) @@ -609,11 +613,14 @@ type RefreshTokenRequest struct { GrantType oidc.GrantType `schema:"grant_type"` } -// RefreshAccessToken performs a token refresh. If it doesn't error, it will always +// RefreshTokens performs a token refresh. If it doesn't error, it will always // provide a new AccessToken. It may provide a new RefreshToken, and if it does, then -// the old one should be considered invalid. It may also provide a new IDToken. The -// new IDToken can be retrieved with token.Extra("id_token"). -func RefreshAccessToken(ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oauth2.Token, error) { +// the old one should be considered invalid. +// +// In case the RP is not OAuth2 only and an IDToken was part of the response, +// the IDToken and AccessToken will be verfied +// and the IDToken and IDTokenClaims fields will be populated in the returned object. +func RefreshTokens[C oidc.IDClaims](ctx context.Context, rp RelyingParty, refreshToken, clientAssertion, clientAssertionType string) (*oidc.Tokens[C], error) { request := RefreshTokenRequest{ RefreshToken: refreshToken, Scopes: rp.OAuthConfig().Scopes, @@ -623,7 +630,17 @@ func RefreshAccessToken(ctx context.Context, rp RelyingParty, refreshToken, clie ClientAssertionType: clientAssertionType, GrantType: oidc.GrantTypeRefreshToken, } - return client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp}) + newToken, err := client.CallTokenEndpoint(ctx, request, tokenEndpointCaller{RelyingParty: rp}) + if err != nil { + return nil, err + } + tokens, err := verifyTokenResponse[C](ctx, newToken, rp) + if err == nil || errors.Is(err, ErrMissingIDToken) { + // https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse + // ...except that it might not contain an id_token. + return tokens, nil + } + return nil, err } func EndSession(ctx context.Context, rp RelyingParty, idToken, optionalRedirectURI, optionalState string) (*url.URL, error) { diff --git a/pkg/client/rp/relying_party_test.go b/pkg/client/rp/relying_party_test.go new file mode 100644 index 00000000..4c5a1b31 --- /dev/null +++ b/pkg/client/rp/relying_party_test.go @@ -0,0 +1,107 @@ +package rp + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + tu "github.com/zitadel/oidc/v3/internal/testutil" + "github.com/zitadel/oidc/v3/pkg/oidc" + "golang.org/x/oauth2" +) + +func Test_verifyTokenResponse(t *testing.T) { + verifier := &IDTokenVerifier{ + Issuer: tu.ValidIssuer, + MaxAgeIAT: 2 * time.Minute, + ClientID: tu.ValidClientID, + Offset: time.Second, + SupportedSignAlgs: []string{string(tu.SignatureAlgorithm)}, + KeySet: tu.KeySet{}, + MaxAge: 2 * time.Minute, + ACR: tu.ACRVerify, + Nonce: func(context.Context) string { return tu.ValidNonce }, + } + tests := []struct { + name string + oauth2Only bool + tokens func() (token *oauth2.Token, want *oidc.Tokens[*oidc.IDTokenClaims]) + wantErr error + }{ + { + name: "succes, oauth2 only", + oauth2Only: true, + tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) { + accesToken, _ := tu.ValidAccessToken() + token := &oauth2.Token{ + AccessToken: accesToken, + } + return token, &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: token, + } + }, + }, + { + name: "id_token missing error", + oauth2Only: false, + tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) { + accesToken, _ := tu.ValidAccessToken() + token := &oauth2.Token{ + AccessToken: accesToken, + } + return token, &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: token, + } + }, + wantErr: ErrMissingIDToken, + }, + { + name: "verify tokens error", + oauth2Only: false, + tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) { + accesToken, _ := tu.ValidAccessToken() + token := &oauth2.Token{ + AccessToken: accesToken, + } + token = token.WithExtra(map[string]any{ + "id_token": "foobar", + }) + return token, nil + }, + wantErr: oidc.ErrParse, + }, + { + name: "success, with id_token", + oauth2Only: false, + tokens: func() (*oauth2.Token, *oidc.Tokens[*oidc.IDTokenClaims]) { + accesToken, _ := tu.ValidAccessToken() + token := &oauth2.Token{ + AccessToken: accesToken, + } + idToken, claims := tu.ValidIDToken() + token = token.WithExtra(map[string]any{ + "id_token": idToken, + }) + return token, &oidc.Tokens[*oidc.IDTokenClaims]{ + Token: token, + IDTokenClaims: claims, + IDToken: idToken, + } + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rp := &relyingParty{ + oauth2Only: tt.oauth2Only, + idTokenVerifier: verifier, + } + token, want := tt.tokens() + got, err := verifyTokenResponse[*oidc.IDTokenClaims](context.Background(), token, rp) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, want, got) + }) + } +}