Skip to content

Commit

Permalink
Merge pull request #5 from ZachEddy/jwt-handling
Browse files Browse the repository at this point in the history
Update JWT Handling
  • Loading branch information
ZachEddy authored Apr 17, 2018
2 parents 3704645 + 431aadf commit 2fde752
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 132 deletions.
67 changes: 67 additions & 0 deletions mw/auth/jwt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package auth

import (
"context"
"errors"
"fmt"

jwt "github.com/dgrijalva/jwt-go"
"github.com/grpc-ecosystem/go-grpc-middleware/auth"
)

const (
// TODO: Field is tentatively called "AccountID" but will probably need to be
// changed. We don't know what the JWT will look like, so we're giving it our
// best guess for the time being.
MULTI_TENANCY_FIELD = "AccountID"
)

var (
errMissingField = errors.New("unable to get field from token")
errMissingToken = errors.New("unable to get token from context")
errInvalidAssertion = errors.New("unable to assert value as jwt.MapClaims")
)

func GetJWTField(ctx context.Context, field string, keyfunc jwt.Keyfunc) (string, error) {
token, err := getToken(ctx, keyfunc)
if err != nil {
return "", errMissingToken
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return "", errInvalidAssertion
}
jwtField, ok := claims[field]
if !ok {
return "", errMissingField
}
return fmt.Sprint(jwtField), nil
}

func GetAccountID(ctx context.Context, keyfunc jwt.Keyfunc) (string, error) {
return GetJWTField(ctx, MULTI_TENANCY_FIELD, keyfunc)
}

// getToken parses the token into a jwt.Token type from the grpc metadata.
// WARNING: if keyfunc is nil, the token will get parsed but not verified
// because it has been checked previously in the stack. More information
// here: https://godoc.org/github.com/dgrijalva/jwt-go#Parser.ParseUnverified
func getToken(ctx context.Context, keyfunc jwt.Keyfunc) (jwt.Token, error) {
tokenStr, err := grpc_auth.AuthFromMD(ctx, "token")
if err != nil {
return jwt.Token{}, ErrUnauthorized
}
parser := jwt.Parser{}
if keyfunc != nil {
token, err := parser.Parse(tokenStr, keyfunc)
if err != nil {
return jwt.Token{}, ErrUnauthorized
}
return *token, nil
}
token, _, err := parser.ParseUnverified(tokenStr, jwt.MapClaims{})
if err != nil {
return jwt.Token{}, ErrUnauthorized
}
return *token, nil
}
122 changes: 122 additions & 0 deletions mw/auth/jwt_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package auth

import (
"context"
"fmt"
"testing"

jwt "github.com/dgrijalva/jwt-go"
"google.golang.org/grpc/metadata"
)

const (
TEST_SECRET = "some-secret-123"
)

func TestGetJWTField(t *testing.T) {
var jwtFieldTests = []struct {
claims jwt.MapClaims
field string
expected string
err error
}{
{
claims: jwt.MapClaims{
"some-field": "id-abc-123",
},
field: "some-field",
expected: "id-abc-123",
err: nil,
},
{
claims: jwt.MapClaims{
"some-field": "id-abc-123",
},
field: "some-other-field",
expected: "",
err: errMissingField,
},
{
claims: jwt.MapClaims{},
field: "some-field",
expected: "",
err: errMissingToken,
},
}
for _, test := range jwtFieldTests {
ctx := context.Background()
if len(test.claims) != 0 {
token := makeToken(test.claims, t)
c, err := contextWithToken(token)
if err != nil {
t.Fatalf("Error when building request context: %v", err)
}
ctx = c
}
actual, err := GetJWTField(ctx, test.field, nil)
if err != test.err {
t.Errorf("Invalid error value: %v - expected %v", err, test.err)
}
if actual != test.expected {
t.Errorf("Invalid JWT field: %v - expected %v", actual, test.expected)
}
}
}

func TestGetAccountID(t *testing.T) {
var accountIDTests = []struct {
claims jwt.MapClaims
expected string
err error
}{
{
claims: jwt.MapClaims{
"AccountID": "id-abc-123",
},
expected: "id-abc-123",
err: nil,
},
{
claims: jwt.MapClaims{},
expected: "",
err: errMissingField,
},
}
for _, test := range accountIDTests {
token := makeToken(test.claims, t)
ctx, err := contextWithToken(token)
if err != nil {
t.Fatalf("Error when building request context: %v", err)
}
actual, err := GetAccountID(ctx, nil)
if err != test.err {
t.Errorf("Invalid error value: %v - expected %v", err, test.err)
}
if actual != test.expected {
t.Errorf("Invalid AccountID: %v - expected %v", actual, test.expected)
}
}
}

// creates a context with a jwt
func contextWithToken(token string) (context.Context, error) {
md := metadata.Pairs(
"authorization", fmt.Sprintf("token %s", token),
)
return metadata.NewIncomingContext(context.Background(), md), nil
}

// generates a token string based on the given jwt claims
func makeToken(claims jwt.Claims, t *testing.T) string {
method := jwt.SigningMethodHS256
token := jwt.NewWithClaims(method, claims)
signingString, err := token.SigningString()
if err != nil {
t.Fatalf("Error when building token: %v", err)
}
signature, err := method.Sign(signingString, []byte(TEST_SECRET))
if err != nil {
t.Fatalf("Error when building token: %v", err)
}
return fmt.Sprintf("%s.%s", signingString, signature)
}
25 changes: 0 additions & 25 deletions mw/auth/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"strings"

jwt "github.com/dgrijalva/jwt-go"
"github.com/grpc-ecosystem/go-grpc-middleware/auth"
pdp "github.com/infobloxopen/themis/pdp-service"
"google.golang.org/grpc"
)
Expand Down Expand Up @@ -40,30 +39,6 @@ func WithJWT(keyfunc jwt.Keyfunc) option {
}
}

// getToken parses the token into a jwt.Token type from the grpc metadata.
// WARNING: if keyfunc is nil, the token will get parsed but not verified
// because it has been checked previously in the stack. More information
// here: https://godoc.org/github.com/dgrijalva/jwt-go#Parser.ParseUnverified
func getToken(ctx context.Context, keyfunc jwt.Keyfunc) (jwt.Token, error) {
tokenStr, err := grpc_auth.AuthFromMD(ctx, "token")
if err != nil {
return jwt.Token{}, ErrUnauthorized
}
parser := jwt.Parser{}
if keyfunc != nil {
token, err := parser.Parse(tokenStr, keyfunc)
if err != nil {
return jwt.Token{}, ErrUnauthorized
}
return *token, nil
}
token, _, err := parser.ParseUnverified(tokenStr, jwt.MapClaims{})
if err != nil {
return jwt.Token{}, ErrUnauthorized
}
return *token, nil
}

// WithCallback allows developers to pass their own attributer to the
// authorization service. It gives them the flexibility to add customization to
// the auth process without needing to write a Builder from scratch.
Expand Down
28 changes: 0 additions & 28 deletions mw/auth/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package auth

import (
"context"
"fmt"
"testing"

jwt "github.com/dgrijalva/jwt-go"
Expand All @@ -11,10 +10,6 @@ import (
"google.golang.org/grpc/metadata"
)

const (
TEST_SECRET = "some-secret-123"
)

func TestWithJWT(t *testing.T) {
var jwtTests = []struct {
token string
Expand Down Expand Up @@ -196,29 +191,6 @@ func TestStripPackageName(t *testing.T) {
}
}

// creates a context with a jwt
func contextWithToken(token string) (context.Context, error) {
md := metadata.Pairs(
"authorization", fmt.Sprintf("token %s", token),
)
return metadata.NewIncomingContext(context.Background(), md), nil
}

// generates a token string based on the given jwt claims
func makeToken(claims jwt.Claims, t *testing.T) string {
method := jwt.SigningMethodHS256
token := jwt.NewWithClaims(method, claims)
signingString, err := token.SigningString()
if err != nil {
t.Fatalf("Error when building token: %v", err)
}
signature, err := method.Sign(signingString, []byte(TEST_SECRET))
if err != nil {
t.Fatalf("Error when building token: %v", err)
}
return fmt.Sprintf("%s.%s", signingString, signature)
}

// checks if first and second attribute lists contain identical elements
func hasMatchingAttributes(first, second []*pdp.Attribute) bool {
if len(first) != len(second) {
Expand Down
40 changes: 0 additions & 40 deletions mw/auth/tenantid.go

This file was deleted.

39 changes: 0 additions & 39 deletions mw/auth/tenantid_test.go

This file was deleted.

0 comments on commit 2fde752

Please sign in to comment.