Skip to content

Commit

Permalink
lots of cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
jairad26 committed Oct 25, 2024
1 parent 43083c2 commit a269097
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 86 deletions.
105 changes: 105 additions & 0 deletions runtime/middleware/authKeys.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright 2024 Hypermode Inc.
* Licensed under the terms of the Apache License, Version 2.0
* See the LICENSE file that accompanied this code for further details.
*
* SPDX-FileCopyrightText: 2024 Hypermode Inc. <hello@hypermode.com>
* SPDX-License-Identifier: Apache-2.0
*/

package middleware

import (
"context"
"os"
"strconv"
"sync"
"time"

"github.com/hypermodeinc/modus/runtime/logger"
)

var (
globalAuthKeys *AuthKeys
)

type AuthKeys struct {
pemPublicKeys map[string]any
jwksPublicKeys map[string]any
mu sync.RWMutex
quit chan struct{}
done chan struct{}
}

func newAuthKeys() *AuthKeys {
return &AuthKeys{
pemPublicKeys: make(map[string]any),
jwksPublicKeys: make(map[string]any),
quit: make(chan struct{}),
done: make(chan struct{}),
}
}

func (ak *AuthKeys) setPemPublicKeys(keys map[string]any) {
ak.mu.Lock()
defer ak.mu.Unlock()
ak.pemPublicKeys = keys
}

func (ak *AuthKeys) setJwksPublicKeys(keys map[string]any) {
ak.mu.Lock()
defer ak.mu.Unlock()
ak.jwksPublicKeys = keys
}

func (ak *AuthKeys) getPemPublicKeys() map[string]any {
ak.mu.RLock()
defer ak.mu.RUnlock()
return ak.pemPublicKeys
}

func (ak *AuthKeys) getJwksPublicKeys() map[string]any {
ak.mu.RLock()
defer ak.mu.RUnlock()
return ak.jwksPublicKeys
}

func getJwksRefreshMinutes(ctx context.Context) int {
refreshTimeStr := os.Getenv("MODUS_JWKS_REFRESH_MINUTES")
if refreshTimeStr == "" {
return 1440
}
refreshTime, err := strconv.Atoi(refreshTimeStr)
if err != nil {
logger.Warn(ctx).Err(err).Msg("Invalid MODUS_JWKS_REFRESH_MINUTES value. Using default value of 1440 minutes.")
return 1440
}
return refreshTime
}

func (ak *AuthKeys) worker(ctx context.Context) {
defer close(ak.done)
timer := time.NewTimer(time.Duration(getJwksRefreshMinutes(ctx)) * time.Minute)

defer timer.Stop()
for {
select {
case <-timer.C:
// refresh JWKS keys
keysStr := os.Getenv("MODUS_JWKS_ENDPOINTS")
if keysStr != "" {
timer.Reset(time.Duration(getJwksRefreshMinutes(ctx)) * time.Minute)
} else {
keys, err := jwksEndpointsJsonToKeys(ctx, keysStr)
if err != nil {
logger.Error(ctx).Err(err).Msg("Auth JWKS public keys deserializing error")
} else {
ak.setJwksPublicKeys(keys)
}
}
case <-ak.quit:
return
}
}

}
175 changes: 89 additions & 86 deletions runtime/middleware/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"net/http"
"os"
"strings"
Expand All @@ -30,107 +31,95 @@ type jwtClaimsKey string

const jwtClaims jwtClaimsKey = "claims"

var authPublicKeys map[string]any
func publicPemKeysJsonToKeys(publicPemKeysJson string) (map[string]any, error) {
var publicKeyStrings map[string]string
if err := json.Unmarshal([]byte(publicPemKeysJson), &publicKeyStrings); err != nil {
return nil, err
}
keys := make(map[string]any)
for key, value := range publicKeyStrings {
block, _ := pem.Decode([]byte(value))
if block == nil {
return nil, errors.New("Invalid PEM block for key: " + key)
}

pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
return nil, err
}
keys[key] = pubKey
}
return keys, nil
}

func jwksEndpointsJsonToKeys(ctx context.Context, jwksEndpointsJson string) (map[string]any, error) {
var jwksEndpoints map[string]string
if err := json.Unmarshal([]byte(jwksEndpointsJson), &jwksEndpoints); err != nil {
return nil, err
}
keys := make(map[string]any)
for key, value := range jwksEndpoints {
jwks, err := jwk.Fetch(ctx, value)
if err != nil {
return nil, err
}

jwkKey, exists := jwks.Get(0)
if !exists {
return nil, errors.New("No keys found in JWKS for key: " + key)
}

var rawKey any
err = jwkKey.Raw(&rawKey)
if err != nil {
return nil, err
}

// Marshal the raw key into DER-encoded PKIX format
derBytes, err := x509.MarshalPKIXPublicKey(rawKey)
if err != nil {
return nil, err
}

pubKey, err := x509.ParsePKIXPublicKey(derBytes)
if err != nil {
return nil, err
}
keys[key] = pubKey
}
return keys, nil
}

func Init(ctx context.Context) {
publicKeysJson := os.Getenv("MODUS_PEMS")
globalAuthKeys = newAuthKeys()
go globalAuthKeys.worker(ctx)
publicPemKeysJson := os.Getenv("MODUS_PEMS")
jwksEndpointsJson := os.Getenv("MODUS_JWKS_ENDPOINTS")
if publicKeysJson == "" && jwksEndpointsJson == "" {
if publicPemKeysJson == "" && jwksEndpointsJson == "" {
return
}

authPublicKeys = make(map[string]any)

if publicKeysJson != "" {
var publicKeyStrings map[string]string
err := json.Unmarshal([]byte(publicKeysJson), &publicKeyStrings)
if publicPemKeysJson != "" {
keys, err := publicPemKeysJsonToKeys(publicPemKeysJson)
if err != nil {
if config.IsDevEnvironment() {
logger.Fatal(ctx).Err(err).Msg("Auth public keys deserializing error")
logger.Fatal(ctx).Err(err).Msg("Auth PEM public keys deserializing error")
}
logger.Error(ctx).Err(err).Msg("Auth public keys deserializing error")
logger.Error(ctx).Err(err).Msg("Auth PEM public keys deserializing error")
return
}
for key, value := range publicKeyStrings {
block, _ := pem.Decode([]byte(value))
if block == nil {
if config.IsDevEnvironment() {
logger.Fatal(ctx).Msg("Invalid PEM block for key: " + key)
}
logger.Error(ctx).Msg("Invalid PEM block for key: " + key)
return
}

pubKey, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
if config.IsDevEnvironment() {
logger.Fatal(ctx).Err(err).Msg("JWT public key parsing error for key: " + key)
}
logger.Error(ctx).Err(err).Msg("JWT public key parsing error for key: " + key)
return
}
authPublicKeys[key] = pubKey
}
globalAuthKeys.setPemPublicKeys(keys)
}
if jwksEndpointsJson != "" {
var jwksEndpoints map[string]string
err := json.Unmarshal([]byte(jwksEndpointsJson), &jwksEndpoints)
keys, err := jwksEndpointsJsonToKeys(ctx, jwksEndpointsJson)
if err != nil {
if config.IsDevEnvironment() {
logger.Fatal(ctx).Err(err).Msg("JWKS endpoints deserializing error")
logger.Fatal(ctx).Err(err).Msg("Auth JWKS public keys deserializing error")
}
logger.Error(ctx).Err(err).Msg("JWKS endpoints deserializing error")
logger.Error(ctx).Err(err).Msg("Auth JWKS public keys deserializing error")
return
}
for key, value := range jwksEndpoints {
jwks, err := jwk.Fetch(ctx, value)
if err != nil {
if config.IsDevEnvironment() {
logger.Fatal(ctx).Err(err).Msg("JWKS fetching error for key: " + key)
}
logger.Error(ctx).Err(err).Msg("JWKS fetching error for key: " + key)
return
}

jwkKey, exists := jwks.Get(0)
if !exists {
if config.IsDevEnvironment() {
logger.Fatal(ctx).Msg("No keys found in JWKS for key: " + key)
}
logger.Error(ctx).Msg("No keys found in JWKS for key: " + key)
return
}

var rawKey interface{}
err = jwkKey.Raw(&rawKey)
if err != nil {
if config.IsDevEnvironment() {
logger.Fatal(ctx).Err(err).Msg("Failed to get raw key for key: " + key)
}
logger.Error(ctx).Err(err).Msg("Failed to get raw key for key: " + key)
return
}

// Marshal the raw key into DER-encoded PKIX format
derBytes, err := x509.MarshalPKIXPublicKey(rawKey)
if err != nil {
if config.IsDevEnvironment() {
logger.Fatal(ctx).Err(err).Msg("Failed to marshal raw key for key: " + key)
}
logger.Error(ctx).Err(err).Msg("Failed to marshal raw key for key: " + key)
return
}

pubKey, err := x509.ParsePKIXPublicKey(derBytes)
if err != nil {
if config.IsDevEnvironment() {
logger.Fatal(ctx).Err(err).Msg("JWT public key fetching error for key: " + key)
}
logger.Error(ctx).Err(err).Msg("JWT public key fetching error for key: " + key)
return
}
authPublicKeys[key] = pubKey
}
globalAuthKeys.setJwksPublicKeys(keys)
}
}

Expand All @@ -149,7 +138,7 @@ func HandleJWT(next http.Handler) http.Handler {
}
}

if len(authPublicKeys) == 0 {
if len(globalAuthKeys.getPemPublicKeys()) == 0 && len(globalAuthKeys.getJwksPublicKeys()) == 0 {
if config.IsDevEnvironment() {
if tokenStr == "" {
next.ServeHTTP(w, r)
Expand Down Expand Up @@ -183,9 +172,9 @@ func HandleJWT(next http.Handler) http.Handler {
var err error
var found bool

for _, publicKey := range authPublicKeys {
for _, pemPublicKey := range globalAuthKeys.getPemPublicKeys() {
token, err = jwtParser.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
return publicKey, nil
return pemPublicKey, nil
})
if err == nil {
if utils.DebugModeEnabled() {
Expand All @@ -195,6 +184,20 @@ func HandleJWT(next http.Handler) http.Handler {
break
}
}
if !found {
for _, jwksPublicKey := range globalAuthKeys.getJwksPublicKeys() {
token, err = jwtParser.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
return jwksPublicKey, nil
})
if err == nil {
if utils.DebugModeEnabled() {
logger.Debug(ctx).Msg("JWT token parsed successfully")
}
found = true
break
}
}
}
if !found {
logger.Error(ctx).Err(err).Msg("JWT parse error")
http.Error(w, "Access Denied", http.StatusUnauthorized)
Expand Down

0 comments on commit a269097

Please sign in to comment.