-
Notifications
You must be signed in to change notification settings - Fork 115
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #5 from ZachEddy/jwt-handling
Update JWT Handling
- Loading branch information
Showing
6 changed files
with
189 additions
and
132 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
package auth | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
|
||
jwt "github.com/dgrijalva/jwt-go" | ||
"github.com/grpc-ecosystem/go-grpc-middleware/auth" | ||
) | ||
|
||
const ( | ||
// TODO: Field is tentatively called "AccountID" but will probably need to be | ||
// changed. We don't know what the JWT will look like, so we're giving it our | ||
// best guess for the time being. | ||
MULTI_TENANCY_FIELD = "AccountID" | ||
) | ||
|
||
var ( | ||
errMissingField = errors.New("unable to get field from token") | ||
errMissingToken = errors.New("unable to get token from context") | ||
errInvalidAssertion = errors.New("unable to assert value as jwt.MapClaims") | ||
) | ||
|
||
func GetJWTField(ctx context.Context, field string, keyfunc jwt.Keyfunc) (string, error) { | ||
token, err := getToken(ctx, keyfunc) | ||
if err != nil { | ||
return "", errMissingToken | ||
} | ||
claims, ok := token.Claims.(jwt.MapClaims) | ||
if !ok { | ||
return "", errInvalidAssertion | ||
} | ||
jwtField, ok := claims[field] | ||
if !ok { | ||
return "", errMissingField | ||
} | ||
return fmt.Sprint(jwtField), nil | ||
} | ||
|
||
func GetAccountID(ctx context.Context, keyfunc jwt.Keyfunc) (string, error) { | ||
return GetJWTField(ctx, MULTI_TENANCY_FIELD, keyfunc) | ||
} | ||
|
||
// getToken parses the token into a jwt.Token type from the grpc metadata. | ||
// WARNING: if keyfunc is nil, the token will get parsed but not verified | ||
// because it has been checked previously in the stack. More information | ||
// here: https://godoc.org/github.com/dgrijalva/jwt-go#Parser.ParseUnverified | ||
func getToken(ctx context.Context, keyfunc jwt.Keyfunc) (jwt.Token, error) { | ||
tokenStr, err := grpc_auth.AuthFromMD(ctx, "token") | ||
if err != nil { | ||
return jwt.Token{}, ErrUnauthorized | ||
} | ||
parser := jwt.Parser{} | ||
if keyfunc != nil { | ||
token, err := parser.Parse(tokenStr, keyfunc) | ||
if err != nil { | ||
return jwt.Token{}, ErrUnauthorized | ||
} | ||
return *token, nil | ||
} | ||
token, _, err := parser.ParseUnverified(tokenStr, jwt.MapClaims{}) | ||
if err != nil { | ||
return jwt.Token{}, ErrUnauthorized | ||
} | ||
return *token, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
package auth | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"testing" | ||
|
||
jwt "github.com/dgrijalva/jwt-go" | ||
"google.golang.org/grpc/metadata" | ||
) | ||
|
||
const ( | ||
TEST_SECRET = "some-secret-123" | ||
) | ||
|
||
func TestGetJWTField(t *testing.T) { | ||
var jwtFieldTests = []struct { | ||
claims jwt.MapClaims | ||
field string | ||
expected string | ||
err error | ||
}{ | ||
{ | ||
claims: jwt.MapClaims{ | ||
"some-field": "id-abc-123", | ||
}, | ||
field: "some-field", | ||
expected: "id-abc-123", | ||
err: nil, | ||
}, | ||
{ | ||
claims: jwt.MapClaims{ | ||
"some-field": "id-abc-123", | ||
}, | ||
field: "some-other-field", | ||
expected: "", | ||
err: errMissingField, | ||
}, | ||
{ | ||
claims: jwt.MapClaims{}, | ||
field: "some-field", | ||
expected: "", | ||
err: errMissingToken, | ||
}, | ||
} | ||
for _, test := range jwtFieldTests { | ||
ctx := context.Background() | ||
if len(test.claims) != 0 { | ||
token := makeToken(test.claims, t) | ||
c, err := contextWithToken(token) | ||
if err != nil { | ||
t.Fatalf("Error when building request context: %v", err) | ||
} | ||
ctx = c | ||
} | ||
actual, err := GetJWTField(ctx, test.field, nil) | ||
if err != test.err { | ||
t.Errorf("Invalid error value: %v - expected %v", err, test.err) | ||
} | ||
if actual != test.expected { | ||
t.Errorf("Invalid JWT field: %v - expected %v", actual, test.expected) | ||
} | ||
} | ||
} | ||
|
||
func TestGetAccountID(t *testing.T) { | ||
var accountIDTests = []struct { | ||
claims jwt.MapClaims | ||
expected string | ||
err error | ||
}{ | ||
{ | ||
claims: jwt.MapClaims{ | ||
"AccountID": "id-abc-123", | ||
}, | ||
expected: "id-abc-123", | ||
err: nil, | ||
}, | ||
{ | ||
claims: jwt.MapClaims{}, | ||
expected: "", | ||
err: errMissingField, | ||
}, | ||
} | ||
for _, test := range accountIDTests { | ||
token := makeToken(test.claims, t) | ||
ctx, err := contextWithToken(token) | ||
if err != nil { | ||
t.Fatalf("Error when building request context: %v", err) | ||
} | ||
actual, err := GetAccountID(ctx, nil) | ||
if err != test.err { | ||
t.Errorf("Invalid error value: %v - expected %v", err, test.err) | ||
} | ||
if actual != test.expected { | ||
t.Errorf("Invalid AccountID: %v - expected %v", actual, test.expected) | ||
} | ||
} | ||
} | ||
|
||
// creates a context with a jwt | ||
func contextWithToken(token string) (context.Context, error) { | ||
md := metadata.Pairs( | ||
"authorization", fmt.Sprintf("token %s", token), | ||
) | ||
return metadata.NewIncomingContext(context.Background(), md), nil | ||
} | ||
|
||
// generates a token string based on the given jwt claims | ||
func makeToken(claims jwt.Claims, t *testing.T) string { | ||
method := jwt.SigningMethodHS256 | ||
token := jwt.NewWithClaims(method, claims) | ||
signingString, err := token.SigningString() | ||
if err != nil { | ||
t.Fatalf("Error when building token: %v", err) | ||
} | ||
signature, err := method.Sign(signingString, []byte(TEST_SECRET)) | ||
if err != nil { | ||
t.Fatalf("Error when building token: %v", err) | ||
} | ||
return fmt.Sprintf("%s.%s", signingString, signature) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.