Skip to content

Commit

Permalink
Merge pull request #1473 from alindeman/add-user-endpoint
Browse files Browse the repository at this point in the history
Add UserInfo endpoint
  • Loading branch information
srenatus authored Jul 2, 2019
2 parents d6fad19 + 5b66bf0 commit 8b4dbb9
Show file tree
Hide file tree
Showing 16 changed files with 547 additions and 373 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require (
github.com/boltdb/bolt v1.3.1 // indirect
github.com/cockroachdb/cmux v0.0.0-20170110192607-30d10be49292 // indirect
github.com/coreos/etcd v3.2.9+incompatible
github.com/coreos/go-oidc v0.0.0-20170307191026-be73733bb8cc
github.com/coreos/go-oidc v2.0.0+incompatible
github.com/coreos/go-semver v0.2.0 // indirect
github.com/coreos/go-systemd v0.0.0-20181031085051-9002847aa142 // indirect
github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ github.com/cockroachdb/cmux v0.0.0-20170110192607-30d10be49292 h1:dzj1/xcivGjNPw
github.com/cockroachdb/cmux v0.0.0-20170110192607-30d10be49292/go.mod h1:qRiX68mZX1lGBkTWyp3CLcenw9I94W2dLeRvMzcn9N4=
github.com/coreos/etcd v3.2.9+incompatible h1:3TbjfK5+aSRLTU/KgBC1xlgA2dn2ddYQngRqX6HFwlQ=
github.com/coreos/etcd v3.2.9+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
github.com/coreos/go-oidc v0.0.0-20170307191026-be73733bb8cc h1:9yuvA19Q5WFkLwJcMDoYm8m89ilzqZ5zEHqdvU+Zbds=
github.com/coreos/go-oidc v0.0.0-20170307191026-be73733bb8cc/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
github.com/coreos/go-oidc v2.0.0+incompatible h1:+RStIopZ8wooMx+Vs5Bt8zMXxV1ABl5LbakNExNmZIg=
github.com/coreos/go-oidc v2.0.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
github.com/coreos/go-semver v0.2.0 h1:3Jm3tLmsgAYcjC+4Up7hJrFBPr+n7rAqYeSw/SZazuY=
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-systemd v0.0.0-20181031085051-9002847aa142 h1:3jFq2xL4ZajGK4aZY8jz+DAF0FHjI51BXjjSwCzS1Dk=
Expand Down
61 changes: 55 additions & 6 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"sync"
"time"

oidc "github.com/coreos/go-oidc"
"github.com/gorilla/mux"
jose "gopkg.in/square/go-jose.v2"

Expand Down Expand Up @@ -151,6 +152,7 @@ type discovery struct {
Auth string `json:"authorization_endpoint"`
Token string `json:"token_endpoint"`
Keys string `json:"jwks_uri"`
UserInfo string `json:"userinfo_endpoint"`
ResponseTypes []string `json:"response_types_supported"`
Subjects []string `json:"subject_types_supported"`
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
Expand All @@ -165,6 +167,7 @@ func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
Auth: s.absURL("/auth"),
Token: s.absURL("/token"),
Keys: s.absURL("/keys"),
UserInfo: s.absURL("/userinfo"),
Subjects: []string{"public"},
IDTokenAlgs: []string{string(jose.RS256)},
Scopes: []string{"openid", "email", "groups", "profile", "offline_access"},
Expand Down Expand Up @@ -559,7 +562,8 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
idToken string
idTokenExpiry time.Time

accessToken = storage.NewID()
// Access token
accessToken string
)

for _, responseType := range authReq.ResponseTypes {
Expand Down Expand Up @@ -595,6 +599,14 @@ func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authRe
case responseTypeIDToken:
implicitOrHybrid = true
var err error

accessToken, err = s.newAccessToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, authReq.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}

idToken, idTokenExpiry, err = s.newIDToken(authReq.ClientID, authReq.Claims, authReq.Scopes, authReq.Nonce, accessToken, authReq.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create ID token: %v", err)
Expand Down Expand Up @@ -716,7 +728,13 @@ func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client s
return
}

accessToken := storage.NewID()
accessToken, err := s.newAccessToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, authCode.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}

idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce, accessToken, authCode.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create ID token: %v", err)
Expand Down Expand Up @@ -965,7 +983,13 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
Groups: ident.Groups,
}

accessToken := storage.NewID()
accessToken, err := s.newAccessToken(client.ID, claims, scopes, refresh.Nonce, refresh.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create new access token: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
return
}

idToken, expiry, err := s.newIDToken(client.ID, claims, scopes, refresh.Nonce, accessToken, refresh.ConnectorID)
if err != nil {
s.logger.Errorf("failed to create ID token: %v", err)
Expand Down Expand Up @@ -1026,10 +1050,35 @@ func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, clie
s.writeAccessToken(w, idToken, accessToken, rawNewToken, expiry)
}

func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
const prefix = "Bearer "

auth := r.Header.Get("authorization")
if len(auth) < len(prefix) || !strings.EqualFold(prefix, auth[:len(prefix)]) {
w.Header().Set("WWW-Authenticate", "Bearer")
s.tokenErrHelper(w, errAccessDenied, "Invalid bearer token.", http.StatusUnauthorized)
return
}
rawIDToken := auth[len(prefix):]

verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true})
idToken, err := verifier.Verify(r.Context(), rawIDToken)
if err != nil {
s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusForbidden)
return
}

var claims json.RawMessage
if err := idToken.Claims(&claims); err != nil {
s.tokenErrHelper(w, errServerError, err.Error(), http.StatusInternalServerError)
return
}

w.Header().Set("Content-Type", "application/json")
w.Write(claims)
}

func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {
// TODO(ericchiang): figure out an access token story and support the user info
// endpoint. For now use a random value so no one depends on the access_token
// holding a specific structure.
resp := struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
Expand Down
44 changes: 44 additions & 0 deletions server/oauth2.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
Expand Down Expand Up @@ -265,6 +266,11 @@ type federatedIDClaims struct {
UserID string `json:"user_id,omitempty"`
}

func (s *Server) newAccessToken(clientID string, claims storage.Claims, scopes []string, nonce, connID string) (accessToken string, err error) {
idToken, _, err := s.newIDToken(clientID, claims, scopes, nonce, storage.NewID(), connID)
return idToken, err
}

func (s *Server) newIDToken(clientID string, claims storage.Claims, scopes []string, nonce, accessToken, connID string) (idToken string, expiry time.Time, err error) {
keys, err := s.storage.GetKeys()
if err != nil {
Expand Down Expand Up @@ -561,3 +567,41 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
host, _, err := net.SplitHostPort(u.Host)
return err == nil && host == "localhost"
}

// storageKeySet implements the oidc.KeySet interface backed by Dex storage
type storageKeySet struct {
storage.Storage
}

func (s *storageKeySet) VerifySignature(_ context.Context, jwt string) (payload []byte, err error) {
jws, err := jose.ParseSigned(jwt)
if err != nil {
return nil, err
}

keyID := ""
for _, sig := range jws.Signatures {
keyID = sig.Header.KeyID
break
}

skeys, err := s.Storage.GetKeys()
if err != nil {
return nil, err
}

keys := []*jose.JSONWebKey{skeys.SigningKeyPub}
for _, vk := range skeys.VerificationKeys {
keys = append(keys, vk.PublicKey)
}

for _, key := range keys {
if keyID == "" || key.KeyID == keyID {
if payload, err := jws.Verify(key); err == nil {
return payload, nil
}
}
}

return nil, errors.New("failed to verify id token signature")
}
87 changes: 87 additions & 0 deletions server/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package server

import (
"context"
"crypto/rand"
"crypto/rsa"
"net/http"
"net/http/httptest"
"net/url"
Expand All @@ -11,6 +13,7 @@ import (
jose "gopkg.in/square/go-jose.v2"

"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory"
)

func TestParseAuthorizationRequest(t *testing.T) {
Expand Down Expand Up @@ -259,3 +262,87 @@ func TestValidRedirectURI(t *testing.T) {
}
}
}

func TestStorageKeySet(t *testing.T) {
s := memory.New(logger)
if err := s.UpdateKeys(func(keys storage.Keys) (storage.Keys, error) {
keys.SigningKey = &jose.JSONWebKey{
Key: testKey,
KeyID: "testkey",
Algorithm: "RS256",
Use: "sig",
}
keys.SigningKeyPub = &jose.JSONWebKey{
Key: testKey.Public(),
KeyID: "testkey",
Algorithm: "RS256",
Use: "sig",
}
return keys, nil
}); err != nil {
t.Fatal(err)
}

tests := []struct {
name string
tokenGenerator func() (jwt string, err error)
wantErr bool
}{
{
name: "valid token",
tokenGenerator: func() (string, error) {
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: testKey}, nil)
if err != nil {
return "", err
}

jws, err := signer.Sign([]byte("payload"))
if err != nil {
return "", err
}

return jws.CompactSerialize()
},
wantErr: false,
},
{
name: "token signed by different key",
tokenGenerator: func() (string, error) {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return "", err
}

signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: key}, nil)
if err != nil {
return "", err
}

jws, err := signer.Sign([]byte("payload"))
if err != nil {
return "", err
}

return jws.CompactSerialize()
},
wantErr: true,
},
}

for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
jwt, err := tc.tokenGenerator()
if err != nil {
t.Fatal(err)
}

keySet := &storageKeySet{s}

_, err = keySet.VerifySignature(context.Background(), jwt)
if (err != nil && !tc.wantErr) || (err == nil && tc.wantErr) {
t.Fatalf("wantErr = %v, but got err = %v", tc.wantErr, err)
}
})
}
}
1 change: 1 addition & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
// TODO(ericchiang): rate limit certain paths based on IP.
handleWithCORS("/token", s.handleToken)
handleWithCORS("/keys", s.handlePublicKeys)
handleWithCORS("/userinfo", s.handleUserInfo)
handleFunc("/auth", s.handleAuthorization)
handleFunc("/auth/{connector}", s.handleConnectorLogin)
r.HandleFunc(path.Join(issuerURL.Path, "/callback"), func(w http.ResponseWriter, r *http.Request) {
Expand Down
Loading

0 comments on commit 8b4dbb9

Please sign in to comment.