Skip to content

Commit

Permalink
Support multiple Issuers
Browse files Browse the repository at this point in the history
  • Loading branch information
Richard87 committed Oct 9, 2024
1 parent f261139 commit 924d91c
Show file tree
Hide file tree
Showing 8 changed files with 442 additions and 190 deletions.
141 changes: 99 additions & 42 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,68 +3,125 @@ package main
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"slices"
"strings"
"time"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/rs/zerolog"
"github.com/auth0/go-jwt-middleware/v2/jwks"
"github.com/rs/zerolog/log"
"gopkg.in/go-jose/go-jose.v2/jwt"
)

var (
errInvalidAuthorizationHeader = errors.New("invalid Authorization header")
)

type Verifier interface {
Verify(ctx context.Context, rawIDToken string) (*oidc.IDToken, error)
type KeyFunc func(ctx context.Context) (interface{}, error)
type controller struct {
providers map[string]KeyFunc
audience string
subjects []string
}

// AuthHandler returns a Handler to authenticate requests
func AuthHandler(subjects []string, verifier Verifier) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Trace().Func(func(e *zerolog.Event) {
headers := r.Header.Clone()
headers.Del("Authorization")
if r.Header.Get("Authorization") != "" {
headers.Set("Authorization", "!REMOVED!")
}
e.Interface("headers", headers)
}).Msg("Request details")
t := time.Now()

auth := r.Header.Get("Authorization")
jwt, err := parseAuthHeader(auth)
// NewAuthHandler returns a Handler to authenticate requests
func NewAuthHandler(audience string, subjects, issuers []string) (RouteMapper, error) {
providers := make(map[string]KeyFunc, len(issuers))
for _, issuer := range issuers {
issuerUrl, err := url.Parse(issuer)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("Forbidden"))
log.Info().Err(err).Dur("elappsed_ms", time.Since(t)).Int("status", http.StatusUnauthorized).Msg("Unauthorized")
return
return nil, err
}

token, err := verifier.Verify(r.Context(), jwt)
provider := jwks.NewCachingProvider(issuerUrl, 5*time.Hour)
providers[issuer] = provider.KeyFunc
}

if err != nil {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("Forbidden"))
log.Info().Err(err).Dur("elappsed_ms", time.Since(t)).Int("status", http.StatusUnauthorized).Msg("Unauthorized")
return
}
c := &controller{
providers: providers,
audience: audience,
subjects: subjects,
}
return func(mux *http.ServeMux) {
mux.Handle("/auth", c)
}, nil
}

subject := token.Subject
found := slices.Contains(subjects, subject)
if !found {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte("Forbidden"))
log.Info().Err(err).Dur("elappsed_ms", time.Since(t)).Int("status", http.StatusForbidden).Str("sub", subject).Msg("Forbidden")
return
}
func (c *controller) ServeHTTP(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
authHeader, err := parseAuthHeader(auth)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("Unauthorized"))
log.Info().Err(err).Msg("Unauthorized: Invalid auth header")
return
}

claims, err := c.getClaims(r.Context(), authHeader)
if err != nil {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("Unauthorized"))
log.Warn().Err(err).Msg("Forbidden: Invalid token")
return
}

subject := claims.Subject

found := slices.Contains(c.subjects, subject)
if !found {
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte("Forbidden"))
log.Warn().Str("sub", subject).Msg("Forbidden")
return
}

w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
log.Info().Str("sub", subject).Msg("Authorized")
}

func (c *controller) getClaims(ctx context.Context, authHeader string) (*jwt.Claims, error) {
var unsafeClaims jwt.Claims
token, err := jwt.ParseSigned(authHeader)
if err != nil {
return nil, fmt.Errorf("failed to parse JWT token: %w", err)
}
err = token.UnsafeClaimsWithoutVerification(&unsafeClaims)
if err != nil {
return nil, fmt.Errorf("failed to extract JWT unsafeClaims: %w", err)
}
var keyId string
if len(token.Headers) == 1 {
keyId = token.Headers[0].KeyID
}
if keyId == "" {
return nil, fmt.Errorf("failed to find keyId in headers")
}

issuer := unsafeClaims.Issuer
keyFunc, ok := c.providers[issuer]
if !ok {
return nil, fmt.Errorf("unknown issuer: %s", issuer)
}
key, err := keyFunc(ctx)
if err != nil {
return nil, fmt.Errorf("error getting the keys from the key func: %w", err)
}

var verifiedClaims jwt.Claims
err = token.Claims(key, &verifiedClaims)
if err != nil {
return nil, fmt.Errorf("failed to verify token unsafeClaims: %w", err)
}

expected := jwt.Expected{Audience: []string{c.audience}}
if err = verifiedClaims.Validate(expected); err != nil {
return nil, fmt.Errorf("failed to verify token unsafeClaims: %w", err)
}

w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("OK"))
log.Info().Dur("elappsed_ms", time.Since(t)).Int("status", http.StatusOK).Str("sub", subject).Msg("Authorized")
})
return &verifiedClaims, nil
}

func parseAuthHeader(authorization string) (string, error) {
Expand Down
Loading

0 comments on commit 924d91c

Please sign in to comment.