Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
jairad26 committed Oct 10, 2024
1 parent b757b2c commit 897cfc1
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions runtime/middleware/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,38 +29,47 @@ type jwtClaimsKey string

const jwtClaims jwtClaimsKey = "jwt_claims"

var jwtPublicKeys map[string]any
var authPublicKeys map[string]any

func Init(ctx context.Context) {
privKeysJson := os.Getenv("MODUS_RSA_PEMS")
if privKeysJson == "" {
publicKeysJson := os.Getenv("MODUS_PEMS")
if publicKeysJson == "" {
return
}
var publicKeyStrings map[string]string
err := json.Unmarshal([]byte(privKeysJson), &publicKeyStrings)
err := json.Unmarshal([]byte(publicKeysJson), &publicKeyStrings)
if err != nil {
logger.Error(ctx).Err(err).Msg("JWT private keys unmarshalling error")
if config.IsDevEnvironment() {
logger.Fatal(ctx).Err(err).Msg("JWT public keys deserializing error")
}
logger.Error(ctx).Err(err).Msg("JWT private keys deserializing error")
return
}
jwtPublicKeys = make(map[string]any)
authPublicKeys = make(map[string]any)
for key, value := range publicKeyStrings {
block, _ := pem.Decode([]byte(value))
if block == nil {
logger.Error(ctx).Msg("Invalid PEM block")
if config.IsDevEnvironment() {
logger.Fatal(ctx).Msg("Invalid PEM block for key: " + key)
}
logger.Error(ctx).Msg("Invalid PEM block for key: " + key)
return
}

pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
logger.Error(ctx).Err(err).Msg("JWT public key parsing error")
if config.IsDevEnvironment() {
logger.Fatal(ctx).Err(err).Msg("JWT public key parsing error for key: " + key)
}
logger.Error(ctx).Err(err).Msg("JWT public key parsing error for key: " + key)
return
}
jwtPublicKeys[key] = pubKey
authPublicKeys[key] = pubKey
}
}

func HandleJWT(next http.Handler) http.Handler {
var jwtParser = new(jwt.Parser)
var jwtParser = jwt.NewParser()
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var ctx context.Context = r.Context()
tokenStr := r.Header.Get("Authorization")
Expand All @@ -74,7 +83,7 @@ func HandleJWT(next http.Handler) http.Handler {
}
}

if len(jwtPublicKeys) == 0 {
if len(authPublicKeys) == 0 {
if !config.IsDevEnvironment() || tokenStr == "" {
next.ServeHTTP(w, r)
return
Expand All @@ -96,7 +105,7 @@ func HandleJWT(next http.Handler) http.Handler {
var token *jwt.Token
var err error

for _, publicKey := range jwtPublicKeys {
for _, publicKey := range authPublicKeys {
token, err = jwtParser.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
return publicKey, nil
})
Expand Down

0 comments on commit 897cfc1

Please sign in to comment.