Skip to content

Commit

Permalink
Add NewKeySet method to JWTAuth
Browse files Browse the repository at this point in the history
This commit adds support for KeySets through a new method `NewKeySet` to the `JWTAuth` struct.

It includes tests and comments that seek to explain how it works inline.

There's also an example in the _example directory that shows how to use and rotate a KeySet.
  • Loading branch information
alexlovelltroy committed Aug 19, 2024
1 parent b5d850b commit 7552877
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 11 deletions.
49 changes: 48 additions & 1 deletion _example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,18 @@ import (
"github.com/lestrrat-go/jwx/v2/jwt"
)

type dynamicTokenAuth struct {
keySet []byte
}

func (d *dynamicTokenAuth) JWTAuth() (*jwtauth.JWTAuth, error) {
keySet, err := jwtauth.NewKeySet(d.keySet)
if err != nil {
return nil, err
}
return keySet, nil
}

var tokenAuth *jwtauth.JWTAuth

func init() {
Expand All @@ -76,7 +88,8 @@ func init() {
// For debugging/example purposes, we generate and print
// a sample jwt token with claims `user_id:123` here:
_, tokenString, _ := tokenAuth.Encode(map[string]interface{}{"user_id": 123})
fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString)
fmt.Printf("DEBUG: a sample jwt for /admin is %s\n\n", tokenString)
fmt.Printf("DEBUG: a sample jwt for /rotate is %s\n\n", sampleJWTRotate)
}

func main() {
Expand Down Expand Up @@ -105,6 +118,23 @@ func router() http.Handler {
})
})

r.Group(func(r chi.Router) {
dynamicTokenAuth := dynamicTokenAuth{keySet: keySet}
// Seek, verify and validate JWT tokens based on keys returned by the callback function
r.Use(jwtauth.VerifierDynamic(dynamicTokenAuth.JWTAuth))

// Handle valid / invalid tokens. In this example, we use
// the provided authenticator middleware, but you can write your
// own very easily, look at the Authenticator method in jwtauth.go
// and tweak it, its not scary.
r.Use(jwtauth.Authenticator)

r.Get("/rotate", func(w http.ResponseWriter, r *http.Request) {
_, claims, _ := jwtauth.FromContext(r.Context())
w.Write([]byte(fmt.Sprintf("protected area. hi %v", claims["user_id"])))
})
})

// Public routes
r.Group(func(r chi.Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -114,3 +144,20 @@ func router() http.Handler {

return r
}

var (
keySet = []byte(`{
"keys": [
{
"kty": "RSA",
"alg": "RS256",
"kid": "kid",
"use": "sig",
"n": "rgzO_v14UXJ33MvccKI8aIw3YpknVJbRB-m1z1X4j3gaTmmzmb7_naEd1TOKhF6Z1BGupvAKhCs8uHtp5e1PCrp52kzrjv7nqQfDpdppPZmKpwf-OD_lVgLLuCljB71mX9w7T5vI_WiVknuNhm48y0TJQNslpDZum4E2e0BLKUDRKKlo25foGoDuQN535_Xso861U8KsA80jX37BJplQ6IHewV_bbe04NYTVqaFcmLaZCAzh2f8L1h4xt76Y0xF_u8FXt2-rgcWlz17CtZzxC8ZXNI_92pX8CY5LY2eQf_B_n5Rhd5TQvEIdoI1GNBrcKUI9pMeEC4pErcOGgKGH7w",
"e": "AQAB"
}
]
}`)

sampleJWTRotate = `eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6ImtpZCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.APC4bUOmfbcXjBnZnmyiGBpXqlboTB4Qbh_sqJrgSU5AEQlwzjvDJ79eBlty8h6kfq3i5ffy87s-g82ZoRsHqMjwCIvTOVnoEyDgVu68s9lE32uaA0cc2-hbA13DIBsyIUGjehh9c3h93BrUoUr7n0CHgoKgx2OEw1Bq8vm4EqvmFGF-mr_0qi32uudPy3I15SyP1NJfU0ogQEFUdDHww3c8omDmrTPiGlWZAl9AiBMroDu0nq3UOtC4d5Se-361NEGiZ9J_kHcVWGdoMwsi5KEB0Uf3wAfXK3wcXeRu1pTXYKOV3X3g_2ss6mh65bNMsSx-MZUnQv5v6qZMOxMBUA`
)
29 changes: 29 additions & 0 deletions jwtauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package jwtauth

import (
"context"
"encoding/json"
"errors"
"net/http"
"strings"
"time"

"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
)

Expand All @@ -17,6 +19,7 @@ type JWTAuth struct {
verifyKey interface{} // public-key, only used by RSA and ECDSA algorithms
verifier jwt.ParseOption
validateOptions []jwt.ValidateOption
keySet jwk.Set
}

var (
Expand Down Expand Up @@ -50,6 +53,24 @@ func New(alg string, signKey interface{}, verifyKey interface{}, validateOptions
return ja
}

// NewKeySet initializes a new JWTAuth instance with the provided key set.
// It takes a keySet parameter, which is a byte slice containing the key set in JSON format.
// The function returns a pointer to JWTAuth and an error.
// If the key set cannot be unmarshaled from the byte slice, an error is returned.
// Otherwise, the JWTAuth instance is created with the unmarshaled key set and a verifier is set using the key set.
func NewKeySet(keySet []byte) (*JWTAuth, error) {
ks := jwk.NewSet()
err := json.Unmarshal(keySet, &ks)
if err != nil {
return nil, err
}

ja := &JWTAuth{keySet: ks}
ja.verifier = jwt.WithKeySet(ks)

return ja, nil
}

// Verifier http middleware handler will verify a JWT string from a http request.
//
// Verifier will search for a JWT token in a http request, in the order:
Expand Down Expand Up @@ -119,13 +140,21 @@ func VerifyToken(ja *JWTAuth, tokenString string) (jwt.Token, error) {
return token, nil
}

// Encode generates a JWT token string with the provided claims.
// It returns the encoded token as a string, along with the token object and any error encountered.
func (ja *JWTAuth) Encode(claims map[string]interface{}) (t jwt.Token, tokenString string, err error) {
t = jwt.New()
for k, v := range claims {
if err := t.Set(k, v); err != nil {
return nil, "", err
}
}
// ja.sign() isn't going to work if ja.signKey is nil
if ja.signKey == nil {
// This generally means that you've called Encode on a KeySet
// which can't be supported.
return nil, "", errors.New("no signing key provided")
}
payload, err := ja.sign(t)
if err != nil {
return nil, "", err
Expand Down
217 changes: 207 additions & 10 deletions jwtauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"testing"
"time"

"github.com/lestrrat-go/jwx/v2/jws"

"github.com/go-chi/chi/v5"
"github.com/go-chi/jwtauth/v5"
"github.com/lestrrat-go/jwx/v2/jwa"
Expand Down Expand Up @@ -41,6 +43,27 @@ MFwwDQYJKoZIhvcNAQEBBQADSwAwSAJBALxo3PCjFw4QjgOX06QCJIJBnXXNiEYw
DLxxa5/7QyH6y77nCRQyJ3x3UwF9rUD0RCsp4sNdX5kOQ9PUyHyOtCUCAwEAAQ==
-----END PUBLIC KEY-----
`

KeySet = `{
"keys": [
{
"kty": "RSA",
"n": "vGjc8KMXDhCOA5fTpAIkgkGddc2IRjAMvHFrn_tDIfrLvucJFDInfHdTAX2tQPREKyniw11fmQ5D09TIfI60JQ",
"e": "AQAB",
"alg": "RS256",
"kid": "1",
"use": "sig"
},
{
"kty": "RSA",
"n": "foo",
"e": "AQAB",
"alg": "RS256",
"kid": "2",
"use": "sig"
}
]
}`
)

func init() {
Expand All @@ -51,6 +74,59 @@ func init() {
// Tests
//

func TestNewKeySet(t *testing.T) {
_, err := jwtauth.NewKeySet([]byte("not a valid key set"))
if err == nil {
t.Fatal("The error should not be nil")
}

_, err = jwtauth.NewKeySet([]byte(KeySet))
if err != nil {
t.Fatalf(err.Error())
}
}

func TestKeySetRSA(t *testing.T) {
privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String))

privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)

if err != nil {
t.Fatalf(err.Error())
}

KeySetAuth, _ := jwtauth.NewKeySet([]byte(KeySet))
claims := map[string]interface{}{
"key": "val",
"key2": "val2",
"key3": "val3",
}

signed := newJwtRSAToken(jwa.RS256, privateKey, "1", claims)

token, err := KeySetAuth.Decode(signed)

if err != nil {
t.Fatalf("Failed to decode token string %s\n", err.Error())
}

tokenClaims, err := token.AsMap(context.Background())
if err != nil {
t.Fatal(err.Error())
}

if !reflect.DeepEqual(claims, tokenClaims) {
t.Fatalf("The decoded claims don't match the original ones\n")
}

_, _, err = KeySetAuth.Encode(claims)
if err.Error() != "no signing key provided" {
t.Fatalf("Expect error to equal %s. Found: %s.", "no signing key provided", err.Error())
}
fmt.Println(token.PrivateClaims())

}

func TestSimple(t *testing.T) {
r := chi.NewRouter()

Expand Down Expand Up @@ -279,20 +355,118 @@ func TestMore(t *testing.T) {
}
}

func TestEncodeClaims(t *testing.T) {
func TestKeySet(t *testing.T) {
privateKeyBlock, _ := pem.Decode([]byte(PrivateKeyRS256String))
privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)
if err != nil {
t.Fatalf(err.Error())
}

r := chi.NewRouter()

keySet, err := jwtauth.NewKeySet([]byte(KeySet))
if err != nil {
t.Fatalf(err.Error())
}

// Protected routes
r.Group(func(r chi.Router) {
r.Use(jwtauth.Verifier(keySet))

authenticator := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token, _, err := jwtauth.FromContext(r.Context())

if err != nil {
http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized)
return
}

if err := jwt.Validate(token); err != nil {
http.Error(w, jwtauth.ErrorReason(err).Error(), http.StatusUnauthorized)
return
}

// Token is authenticated, pass it through
next.ServeHTTP(w, r)
})
}
r.Use(authenticator)

r.Get("/admin", func(w http.ResponseWriter, r *http.Request) {
_, claims, err := jwtauth.FromContext(r.Context())

if err != nil {
w.Write([]byte(fmt.Sprintf("error! %v", err)))
return
}

w.Write([]byte(fmt.Sprintf("protected, user:%v", claims["user_id"])))
})
})

// Public routes
r.Group(func(r chi.Router) {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("welcome"))
})
})

ts := httptest.NewServer(r)
defer ts.Close()

h := http.Header{}
h.Set("Authorization", "BEARER "+newJwtRSAToken(jwa.RS256, privateKey, "1", map[string]interface{}{"user_id": 31337, "exp": jwtauth.ExpireIn(5 * time.Minute)}))
if status, resp := testRequest(t, ts, "GET", "/admin", h, nil); status != 200 || resp != "protected, user:31337" {
t.Fatalf(resp)
}
}

func TestEncodeInvalidClaim(t *testing.T) {
ja := jwtauth.New("HS256", []byte("secretpass"), nil)
claims := map[string]interface{}{
"key1": "val1",
"key2": 2,
"key3": time.Now(),
"key4": []string{"1", "2"},
"key1": "val1",
"key2": 2,
"key3": time.Now(),
"key4": []string{"1", "2"},
jwt.JwtIDKey: 1, // This is invalid becasue it should be a string
}
claims[jwt.JwtIDKey] = 1
if _, _, err := TokenAuthHS256.Encode(claims); err == nil {
_, _, err := ja.Encode(claims)
if err == nil {

t.Fatal("encoding invalid claims succeeded")
}
claims[jwt.JwtIDKey] = "123"
if _, _, err := TokenAuthHS256.Encode(claims); err != nil {
t.Fatalf("unexpected error encoding valid claims: %v", err)
}
func TestEncode(t *testing.T) {
ja := jwtauth.New("HS256", []byte("secretpass"), nil)

claims := map[string]interface{}{
"sub": "1234567890",
"name": "John Doe",
"iat": 1516239022,
}

token, tokenString, err := ja.Encode(claims)
if err != nil {
t.Fatalf("Failed to encode claims: %s", err.Error())
}

if token == nil {
t.Fatal("Token should not be nil")
}

if tokenString == "" {
t.Fatal("Token string should not be empty")
}

// Verify the token string
verifiedToken, err := ja.Decode(tokenString)
if err != nil {
t.Fatalf("Failed to decode token string: %s", err.Error())
}

if !reflect.DeepEqual(token, verifiedToken) {
t.Fatal("Decoded token does not match the original token")
}
}

Expand Down Expand Up @@ -357,6 +531,29 @@ func newJwt512Token(secret []byte, claims ...map[string]interface{}) string {
return string(tokenPayload)
}

func newJwtRSAToken(alg jwa.SignatureAlgorithm, secret interface{}, kid string, claims ...map[string]interface{}) string {
token := jwt.New()
if len(claims) > 0 {
for k, v := range claims[0] {
token.Set(k, v)
}
}

headers := jws.NewHeaders()
if kid != "" {
err := headers.Set("kid", kid)
if err != nil {
log.Fatal(err)
}
}

tokenPayload, err := jwt.Sign(token, jwt.WithKey(alg, secret, jws.WithProtectedHeaders(headers)))
if err != nil {
log.Fatal(err)
}
return string(tokenPayload)
}

func newAuthHeader(claims ...map[string]interface{}) http.Header {
h := http.Header{}
h.Set("Authorization", "BEARER "+newJwtToken(TokenSecret, claims...))
Expand Down

0 comments on commit 7552877

Please sign in to comment.