From e54f4188a68df5d82335c6dca25b03902e7d0c2f Mon Sep 17 00:00:00 2001 From: Udara Premadasa Date: Fri, 17 Jun 2022 10:43:19 +0530 Subject: [PATCH] add jwt module --- jwt/check.go | 321 +++++++++++++++++++++++++++++++++ jwt/jwt.go | 314 ++++++++++++++++++++++++++++++++ jwt/register.go | 464 ++++++++++++++++++++++++++++++++++++++++++++++++ jwt/sign.go | 325 +++++++++++++++++++++++++++++++++ jwt/web.go | 267 ++++++++++++++++++++++++++++ 5 files changed, 1691 insertions(+) create mode 100644 jwt/check.go create mode 100644 jwt/jwt.go create mode 100644 jwt/register.go create mode 100644 jwt/sign.go create mode 100644 jwt/web.go diff --git a/jwt/check.go b/jwt/check.go new file mode 100644 index 0000000..8d9710b --- /dev/null +++ b/jwt/check.go @@ -0,0 +1,321 @@ +package jwt + +import ( + "bytes" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/hmac" + "crypto/rsa" + "encoding/json" + "errors" + "fmt" + "hash" + "math/big" +) + +// ErrSigMiss means the signature check failed. +var ErrSigMiss = errors.New("jwt: signature mismatch") + +var errNoPayload = errors.New("jwt: one part only—payload absent") + +// “Producers MUST NOT use the empty list "[]" as the "crit" value.” +// — “JSON Web Signature (JWS)” RFC 7515, subsection 4.1.11 +var errCritEmpty = errors.New("jwt: empty array in crit header") + +// EvalCrit is invoked by the Check functions for each token with one or more +// JOSE extensions. The crit slice has the JSON field names (for header) which +// “MUST be understood and processed” according to RFC 7515, subsection 4.1.11. +// “If any of the listed extension Header Parameters are not understood and +// supported by the recipient, then the JWS is invalid.” +// The respective Check function returns any error from EvalCrit as is. +var EvalCrit = func(token []byte, crit []string, header json.RawMessage) error { + return fmt.Errorf("jwt: unsupported critical extension in JOSE header: %q", crit) +} + +// ParseWithoutCheck skips the signature validation. +func ParseWithoutCheck(token []byte) (*Claims, error) { + var c Claims + _, _, _, err := c.scan(token) + if err != nil { + return nil, err + } + + return &c, c.applyPayload() +} + +// ECDSACheck parses a JWT if, and only if, the signature checks out. +// The return is an AlgError when the algorithm is not in ECDSAAlgs. +// Use Valid to complete the verification. +func ECDSACheck(token []byte, key *ecdsa.PublicKey) (*Claims, error) { + var c Claims + bodyLen, sig, alg, err := c.scan(token) + if err != nil { + return nil, err + } + + hash, err := hashLookup(alg, ECDSAAlgs) + if err != nil { + return nil, err + } + digest := hash.New() + digest.Write(token[:bodyLen]) + + r := new(big.Int).SetBytes(sig[:len(sig)/2]) + s := new(big.Int).SetBytes(sig[len(sig)/2:]) + buf := sig[len(sig):] + if !ecdsa.Verify(key, digest.Sum(buf), r, s) { + return nil, ErrSigMiss + } + + return &c, c.applyPayload() +} + +// EdDSACheck parses a JWT if, and only if, the signature checks out. +// Use Valid to complete the verification. +func EdDSACheck(token []byte, key ed25519.PublicKey) (*Claims, error) { + var c Claims + bodyLen, sig, alg, err := c.scan(token) + if err != nil { + return nil, err + } + + if alg != EdDSA { + return nil, AlgError(alg) + } + + if !ed25519.Verify(key, token[:bodyLen], sig) { + return nil, ErrSigMiss + } + + return &c, c.applyPayload() +} + +// HMACCheck parses a JWT if, and only if, the signature checks out. +// The return is an AlgError when the algorithm is not in HMACAlgs. +// Use Valid to complete the verification. +func HMACCheck(token, secret []byte) (*Claims, error) { + if len(secret) == 0 { + return nil, errNoSecret + } + + var c Claims + bodyLen, sig, alg, err := c.scan(token) + if err != nil { + return nil, err + } + + hash, err := hashLookup(alg, HMACAlgs) + if err != nil { + return nil, err + } + digest := hmac.New(hash.New, secret) + digest.Write(token[:bodyLen]) + + buf := sig[len(sig):] + if !hmac.Equal(sig, digest.Sum(buf)) { + return nil, ErrSigMiss + } + + return &c, c.applyPayload() +} + +// Check parses a JWT if, and only if, the signature checks out. +// The return is an AlgError when the algorithm does not match. +// Use Valid to complete the verification. +func (h *HMAC) Check(token []byte) (*Claims, error) { + var c Claims + bodyLen, sig, alg, err := c.scan(token) + if err != nil { + return nil, err + } + if alg != h.alg { + return nil, AlgError(alg) + } + + digest := h.digests.Get().(hash.Hash) + defer h.digests.Put(digest) + digest.Reset() + digest.Write(token[:bodyLen]) + + buf := sig[len(sig):] + if !hmac.Equal(sig, digest.Sum(buf)) { + return nil, ErrSigMiss + } + + return &c, c.applyPayload() +} + +// RSACheck parses a JWT if, and only if, the signature checks out. +// The return is an AlgError when the algorithm is not in RSAAlgs. +// Use Valid to complete the verification. +func RSACheck(token []byte, key *rsa.PublicKey) (*Claims, error) { + var c Claims + bodyLen, sig, alg, err := c.scan(token) + if err != nil { + return nil, err + } + + hash, err := hashLookup(alg, RSAAlgs) + if err != nil { + return nil, err + } + digest := hash.New() + digest.Write(token[:bodyLen]) + + buf := sig[len(sig):] + if alg != "" && alg[0] == 'P' { + err = rsa.VerifyPSS(key, hash, digest.Sum(buf), sig, &pSSOptions) + } else { + err = rsa.VerifyPKCS1v15(key, hash, digest.Sum(buf), sig) + } + if err != nil { + return nil, ErrSigMiss + } + + return &c, c.applyPayload() +} + +// DecodeParts reads up to three base64 parts. The result goes in c.RawHeader, c.Raw and sig. +func (c *Claims) decodeParts(token []byte) (bodyLen int, sig []byte, err error) { + // fits all 3 parts decoded + buffer space for Hash.Sum. + buf := make([]byte, len(token)) + + // header + i := bytes.IndexByte(token, '.') + if i < 0 { + i = len(token) + } + n, err := encoding.Decode(buf, token[:i]) + if err != nil { + return 0, nil, fmt.Errorf("jwt: malformed JOSE header: %w", err) + } + c.RawHeader = json.RawMessage(buf[:n]) + buf = buf[n:] + + if i >= len(token) { + return len(token), nil, nil + } + i++ // pass first dot + + // payload + bodyLen = i + bytes.IndexByte(token[i:], '.') + if bodyLen < i { + bodyLen = len(token) + } + n, err = encoding.Decode(buf, token[i:bodyLen]) + if err != nil { + return 0, nil, fmt.Errorf("jwt: malformed payload: %w", err) + } + c.Raw = json.RawMessage(buf[:n]) + buf = buf[n:] + + if bodyLen >= len(token) { + return bodyLen, nil, nil + } + + // signature + remain := token[bodyLen+1:] + end := bytes.IndexByte(remain, '.') + if end >= 0 { + remain = remain[:end] + } + n, err = encoding.Decode(buf, remain) + if err != nil { + return 0, nil, fmt.Errorf("jwt: malformed signature: %w", err) + } + return bodyLen, buf[:n], nil +} + +func (c *Claims) scan(token []byte) (bodyLen int, sig []byte, alg string, err error) { + bodyLen, sig, err = c.decodeParts(token) + if err != nil { + return 0, nil, "", err + } + + var header struct { + Kid string `json:"kid"` + Alg string `json:"alg"` + Crit []string `json:"crit"` + } + if err := json.Unmarshal([]byte(c.RawHeader), &header); err != nil { + return 0, nil, "", fmt.Errorf("jwt: malformed JOSE header: %w", err) + } + + if len(c.Raw) == 0 { + return 0, nil, "", errNoPayload + } + + // apply JOSE + alg = header.Alg + c.KeyID = header.Kid + if header.Crit != nil { + if len(header.Crit) == 0 { + return 0, nil, "", errCritEmpty + } + if err := EvalCrit(token, header.Crit, c.RawHeader); err != nil { + return 0, nil, "", err + } + } + + return +} + +func (c *Claims) applyPayload() error { + err := json.Unmarshal([]byte(c.Raw), &c.Set) + if err != nil { + return fmt.Errorf("jwt: malformed payload: %w", err) + } + + // move from Set to Registered on type match + m := c.Set + if s, ok := m[issuer].(string); ok { + delete(m, issuer) + c.Issuer = s + } + if s, ok := m[subject].(string); ok { + delete(m, subject) + c.Subject = s + } + + // “In the general case, the "aud" value is an array of case-sensitive + // strings, each containing a StringOrURI value. In the special case + // when the JWT has one audience, the "aud" value MAY be a single + // case-sensitive string containing a StringOrURI value.” + switch a := m[audience].(type) { + case []interface{}: + allStrings := true + for _, o := range a { + if s, ok := o.(string); ok { + c.Audiences = append(c.Audiences, s) + } else { + allStrings = false + } + } + if allStrings { + delete(m, audience) + } + + case string: + delete(m, audience) + c.Audiences = []string{a} + } + + if f, ok := m[expires].(float64); ok { + delete(m, expires) + c.Expires = (*NumericTime)(&f) + } + if f, ok := m[notBefore].(float64); ok { + delete(m, notBefore) + c.NotBefore = (*NumericTime)(&f) + } + if f, ok := m[issued].(float64); ok { + delete(m, issued) + c.Issued = (*NumericTime)(&f) + } + if s, ok := m[id].(string); ok { + delete(m, id) + c.ID = s + } + + return nil +} diff --git a/jwt/jwt.go b/jwt/jwt.go new file mode 100644 index 0000000..802dfab --- /dev/null +++ b/jwt/jwt.go @@ -0,0 +1,314 @@ +// Package jwt implements “JSON Web Token (JWT)” RFC 7519. +// Signatures only; no unsecured nor encrypted tokens. +package jwt + +import ( + "crypto" + "crypto/hmac" + "crypto/rsa" + _ "crypto/sha256" // link into binary + _ "crypto/sha512" // link into binary + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "math" + "sync" + "time" +) + +// Algorithm Identification Tokens +const ( + EdDSA = "EdDSA" // EdDSA signature algorithms + ES256 = "ES256" // ECDSA using P-256 and SHA-256 + ES384 = "ES384" // ECDSA using P-384 and SHA-384 + ES512 = "ES512" // ECDSA using P-521 and SHA-512 + HS256 = "HS256" // HMAC using SHA-256 + HS384 = "HS384" // HMAC using SHA-384 + HS512 = "HS512" // HMAC using SHA-512 + PS256 = "PS256" // RSASSA-PSS using SHA-256 and MGF1 with SHA-256 + PS384 = "PS384" // RSASSA-PSS using SHA-384 and MGF1 with SHA-384 + PS512 = "PS512" // RSASSA-PSS using SHA-512 and MGF1 with SHA-512 + RS256 = "RS256" // RSASSA-PKCS1-v1_5 using SHA-256 + RS384 = "RS384" // RSASSA-PKCS1-v1_5 using SHA-384 + RS512 = "RS512" // RSASSA-PKCS1-v1_5 using SHA-512 +) + +// Algorithm support is configured with hash registrations. +// Any modifications should be made before first use to prevent +// data races in the Check and Sign functions, i.e., customise +// from either main or init. +var ( + ECDSAAlgs = map[string]crypto.Hash{ + ES256: crypto.SHA256, + ES384: crypto.SHA384, + ES512: crypto.SHA512, + } + HMACAlgs = map[string]crypto.Hash{ + HS256: crypto.SHA256, + HS384: crypto.SHA384, + HS512: crypto.SHA512, + } + RSAAlgs = map[string]crypto.Hash{ + PS256: crypto.SHA256, + PS384: crypto.SHA384, + PS512: crypto.SHA512, + RS256: crypto.SHA256, + RS384: crypto.SHA384, + RS512: crypto.SHA512, + } +) + +// See crypto.Hash.Available. +var errHashLink = errors.New("jwt: hash function not linked into binary") + +func hashLookup(alg string, algs map[string]crypto.Hash) (crypto.Hash, error) { + hash, ok := algs[alg] + if !ok { + return 0, AlgError(alg) + } + if !hash.Available() { + return 0, errHashLink + } + return hash, nil +} + +// AlgError signals that the specified algorithm is not in use. +type AlgError string + +// Error honors the error interface. +func (e AlgError) Error() string { + return fmt.Sprintf("jwt: algorithm %q not in use", string(e)) +} + +// ErrUnsecured signals a token without a signature, as described in RFC 7519, +// section 6. +const ErrUnsecured = AlgError("none") + +// ErrNoSecret protects against programming and configuration mistakes. +var errNoSecret = errors.New("jwt: empty secret rejected") + +// HMAC is a reusable instance, optimized for high usage scenarios. +// +// Multiple goroutines may invoke methods on an HMAC simultaneously. +type HMAC struct { + alg string + digests sync.Pool +} + +// NewHMAC returns a new reusable instance. +func NewHMAC(alg string, secret []byte) (*HMAC, error) { + if len(secret) == 0 { + return nil, errNoSecret + } + hash, err := hashLookup(alg, HMACAlgs) + if err != nil { + return nil, err + } + return &HMAC{alg, sync.Pool{New: func() interface{} { + return hmac.New(hash.New, secret) + }}}, nil +} + +var encoding = base64.RawURLEncoding + +// “The size of the salt value is the same size as the hash function output.” +// — “JSON Web Algorithms (JWA)” RFC 7518, subsection 3.5 +var pSSOptions = rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash} + +// Standard (IANA registered) claim names. +const ( + issuer = "iss" + subject = "sub" + audience = "aud" + expires = "exp" + notBefore = "nbf" + issued = "iat" + id = "jti" +) + +// Registered “JSON Web Token Claims” has a subset of the IANA registration. +// See for the full listing. +// +// Each field is optional—there are no required claims. The string values are +// case sensitive. +type Registered struct { + // Issuer identifies the principal that issued the JWT. + Issuer string `json:"iss,omitempty"` + + // Subject identifies the principal that is the subject of the JWT. + Subject string `json:"sub,omitempty"` + + // Audiences identifies the recipients that the JWT is intended for. + Audiences []string `json:"aud,omitempty"` + + // Expires identifies the expiration time on or after which the JWT + // must not be accepted for processing. + Expires *NumericTime `json:"exp,omitempty"` + + // NotBefore identifies the time before which the JWT must not be + // accepted for processing. + NotBefore *NumericTime `json:"nbf,omitempty"` + + // Issued identifies the time at which the JWT was issued. + Issued *NumericTime `json:"iat,omitempty"` + + // ID provides a unique identifier for the JWT. + ID string `json:"jti,omitempty"` +} + +// Valid returns whether the claims set may be accepted for processing at the +// given moment in time. If the time is zero, then Valid returns whether there +// are no time constraints ("nbf" & "exp"). +func (r *Registered) Valid(t time.Time) bool { + n := NewNumericTime(t) + if n == nil { + return r.Expires == nil && r.NotBefore == nil + } + + return (r.Expires == nil || *r.Expires > *n) && + (r.NotBefore == nil || *r.NotBefore <= *n) +} + +// AcceptAudience verifies the applicability of an audience identified as +// stringOrURI. Any stringOrURI is accepted on absence of the aud(ience) claim. +func (r *Registered) AcceptAudience(stringOrURI string) bool { + for _, s := range r.Audiences { + if stringOrURI == s { + return true + } + } + return len(r.Audiences) == 0 +} + +// Claims are the (signed) statements of a JWT. +type Claims struct { + // Registered field values take precedence over Set. + Registered + + // Set maps claims by name, for usecases beyond the Registered fields. + // The Sign methods copy each non-zero Registered value into Set when + // the map is not nil. The Check methods map claims in Set if the name + // doesn't match any of the Registered, or if the data type won't fit. + // Entries are treated conform the encoding/json package. + // + // bool, for JSON booleans + // float64, for JSON numbers + // string, for JSON strings + // []interface{}, for JSON arrays + // map[string]interface{}, for JSON objects + // nil for JSON null + // + Set map[string]interface{} + + // Raw encoding as is within the token. This field is read-only. + Raw json.RawMessage + // RawHeader encoding as is within the token. This field is read-only. + RawHeader json.RawMessage + + // “The "kid" (key ID) Header Parameter is a hint indicating which key + // was used to secure the JWS. This parameter allows originators to + // explicitly signal a change of key to recipients. The structure of the + // "kid" value is unspecified. Its value MUST be a case-sensitive + // string. Use of this Header Parameter is OPTIONAL.” + // — “JSON Web Signature (JWS)” RFC 7515, subsection 4.1.4 + KeyID string +} + +// String returns the claim when present and if the representation is a JSON string. +// Note that null is not a string. +func (c *Claims) String(name string) (value string, ok bool) { + // try Registered first + switch name { + case issuer: + value = c.Issuer + case subject: + value = c.Subject + case audience: + if len(c.Audiences) == 1 { + return c.Audiences[0], true + } + if c.Audiences != nil { + return "", false + } + case id: + value = c.ID + } + if value != "" { + return value, true + } + + // fallback + value, ok = c.Set[name].(string) + return +} + +// Number returns the claim when present and if the representation is a JSON number. +// Note that null is not a number. +func (c *Claims) Number(name string) (value float64, ok bool) { + // try Registered first + switch name { + case expires: + if c.Expires != nil { + return float64(*c.Expires), true + } + case notBefore: + if c.NotBefore != nil { + return float64(*c.NotBefore), true + } + case issued: + if c.Issued != nil { + return float64(*c.Issued), true + } + } + + // fallback + value, ok = c.Set[name].(float64) + return +} + +// NumericTime implements NumericDate: “A JSON numeric value representing +// the number of seconds from 1970-01-01T00:00:00Z UTC until the specified +// UTC date/time, ignoring leap seconds.” +type NumericTime float64 + +// BUG(pascaldekloe): Some broken JWT implementations fail to parse tokens with +// fractions in Registered.Expires, .NotBefore or .Issued. Round to seconds—like +// NewNumericDate(time.Now().Round(time.Second))—for compatibility. + +// NewNumericTime returns the the corresponding representation with nil for the +// zero value. Do t.Round(time.Second) for slightly smaller token production and +// compatibility. See the bugs section for details. +func NewNumericTime(t time.Time) *NumericTime { + if t.IsZero() { + return nil + } + if t.Nanosecond() == 0 { + // no rounding errors + n := NumericTime(t.Unix()) + return &n + } + n := NumericTime(float64(t.UnixNano()) / 1e9) + return &n +} + +// Time returns the Go mapping with the zero value for nil. +func (n *NumericTime) Time() time.Time { + if n == nil { + return time.Time{} + } + int, frac := math.Modf(float64(*n)) + if frac == 0 { + // no rounding errors + return time.Unix(int64(int), 0).UTC() + } + return time.Unix(0, int64(*n*NumericTime(time.Second))).UTC() +} + +// String returns the ISO representation or the empty string for nil. +func (n *NumericTime) String() string { + if n == nil { + return "" + } + return n.Time().Format(time.RFC3339Nano) +} diff --git a/jwt/register.go b/jwt/register.go new file mode 100644 index 0000000..194466c --- /dev/null +++ b/jwt/register.go @@ -0,0 +1,464 @@ +package jwt + +import ( + "bytes" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/hmac" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "hash" + "math/big" +) + +// KeyRegister is a collection of recognized credentials. +type KeyRegister struct { + ECDSAs []*ecdsa.PublicKey // ECDSA credentials + EdDSAs []ed25519.PublicKey // EdDSA credentials + RSAs []*rsa.PublicKey // RSA credentials + HMACs []*HMAC // HMAC credentials + Secrets [][]byte // HMAC credentials + + // Optional key identification. See Claims.KeyID for details. + // Non-empty strings match the respective key or secret by index. + ECDSAIDs []string // ECDSAs key ID mapping + EdDSAIDs []string // EdDSA key ID mapping + RSAIDs []string // RSAs key ID mapping + HMACIDs []string // Secrets key ID mapping + SecretIDs []string // Secrets key ID mapping +} + +// Check parses a JWT if, and only if, the signature checks out. +// Use Claims.Valid to complete the verification. +func (keys *KeyRegister) Check(token []byte) (*Claims, error) { + var c Claims + lastDot, sig, alg, err := c.scan(token) + if err != nil { + return nil, err + } + body := token[:lastDot] + buf := sig[len(sig):] + + switch hashAlg, err := hashLookup(alg, HMACAlgs); err.(type) { + case nil: + hMACOptions := keys.HMACs + if c.KeyID != "" { + for i, kid := range keys.HMACIDs { + if kid == c.KeyID && i < len(hMACOptions) { + hMACOptions = hMACOptions[i : i+1] + break + } + } + } + for _, h := range hMACOptions { + if h.alg == alg { + digest := h.digests.Get().(hash.Hash) + digest.Reset() + digest.Write(body) + sum := digest.Sum(buf) + h.digests.Put(digest) + if hmac.Equal(sig, sum) { + return &c, c.applyPayload() + } + } + } + + keyOptions := keys.Secrets + if c.KeyID != "" { + for i, kid := range keys.SecretIDs { + if kid == c.KeyID && i < len(keyOptions) { + keyOptions = keyOptions[i : i+1] + break + } + } + } + + for _, secret := range keyOptions { + digest := hmac.New(hashAlg.New, secret) + digest.Write(body) + if hmac.Equal(sig, digest.Sum(buf)) { + return &c, c.applyPayload() + } + } + return nil, ErrSigMiss + + case AlgError: + break // next + default: + return nil, err + } + + if alg == EdDSA { + keyOptions := keys.EdDSAs + if c.KeyID != "" { + for i, kid := range keys.EdDSAIDs { + if kid == c.KeyID && i < len(keyOptions) { + keyOptions = keyOptions[i : i+1] + break + } + } + } + + for _, key := range keyOptions { + if ed25519.Verify(key, body, sig) { + return &c, c.applyPayload() + } + } + return nil, ErrSigMiss + } + + switch hash, err := hashLookup(alg, RSAAlgs); err.(type) { + case nil: + keyOptions := keys.RSAs + if c.KeyID != "" { + for i, kid := range keys.RSAIDs { + if kid == c.KeyID && i < len(keyOptions) { + keyOptions = keyOptions[i : i+1] + break + } + } + } + + digest := hash.New() + digest.Write(body) + digestSum := digest.Sum(buf) + for _, key := range keyOptions { + if alg != "" && alg[0] == 'P' { + err = rsa.VerifyPSS(key, hash, digestSum, sig, &pSSOptions) + } else { + err = rsa.VerifyPKCS1v15(key, hash, digestSum, sig) + } + if err == nil { + return &c, c.applyPayload() + } + } + return nil, ErrSigMiss + + case AlgError: + break // next + default: + return nil, err + } + + switch hash, err := hashLookup(alg, ECDSAAlgs); err { + case nil: + keyOptions := keys.ECDSAs + if c.KeyID != "" { + for i, kid := range keys.ECDSAIDs { + if kid == c.KeyID && i < len(keyOptions) { + keyOptions = keyOptions[i : i+1] + break + } + } + } + + r := new(big.Int).SetBytes(sig[:len(sig)/2]) + s := new(big.Int).SetBytes(sig[len(sig)/2:]) + digest := hash.New() + digest.Write(body) + digestSum := digest.Sum(buf) + for _, key := range keyOptions { + if ecdsa.Verify(key, digestSum, r, s) { + return &c, c.applyPayload() + } + } + return nil, ErrSigMiss + + default: + return nil, err + } +} + +var errUnencryptedPEM = errors.New("jwt: unencrypted PEM rejected due password expectation") + +// LoadPEM scans text for PEM-encoded keys. Each occurrence found is then added +// to the register. Extraction works with certificates, public keys and private +// keys. PEM encryption is enforced with a non-empty password to ensure security +// when ordered. +func (keys *KeyRegister) LoadPEM(text, password []byte) (keysAdded int, err error) { + for { + block, remainder := pem.Decode(text) + if block == nil { + return + } + text = remainder + + if x509.IsEncryptedPEMBlock(block) { + block.Bytes, err = x509.DecryptPEMBlock(block, password) + if err != nil { + return keysAdded, err + } + } else if len(password) != 0 { + return keysAdded, errUnencryptedPEM + } + + var key interface{} + var err error + + // See RFC 7468, section 4. + switch block.Type { + case "CERTIFICATE": + certs, err := x509.ParseCertificates(block.Bytes) + if err != nil { + return keysAdded, err + } + for _, c := range certs { + if err := keys.add(c.PublicKey, ""); err != nil { + return keysAdded, err + } + keysAdded++ + } + continue + + case "PUBLIC KEY": + key, err = x509.ParsePKIXPublicKey(block.Bytes) + + case "PRIVATE KEY": + key, err = x509.ParsePKCS8PrivateKey(block.Bytes) + + case "EC PRIVATE KEY": + key, err = x509.ParseECPrivateKey(block.Bytes) + + case "RSA PRIVATE KEY": + key, err = x509.ParsePKCS1PrivateKey(block.Bytes) + + default: + return keysAdded, fmt.Errorf("jwt: unknown PEM type %q", block.Type) + } + if err != nil { + return keysAdded, err + } + if err := keys.add(key, ""); err != nil { + return keysAdded, err + } + + keysAdded++ + } +} + +func (keys *KeyRegister) add(key interface{}, kid string) error { + var i int + var ids *[]string + + switch t := key.(type) { + case *ecdsa.PublicKey: + i = len(keys.ECDSAs) + keys.ECDSAs = append(keys.ECDSAs, t) + ids = &keys.ECDSAIDs + case *ecdsa.PrivateKey: + i = len(keys.ECDSAs) + keys.ECDSAs = append(keys.ECDSAs, &t.PublicKey) + ids = &keys.ECDSAIDs + case ed25519.PublicKey: + i = len(keys.EdDSAs) + keys.EdDSAs = append(keys.EdDSAs, t) + ids = &keys.EdDSAIDs + case ed25519.PrivateKey: + i = len(keys.EdDSAs) + keys.EdDSAs = append(keys.EdDSAs, t.Public().(ed25519.PublicKey)) + ids = &keys.EdDSAIDs + case *rsa.PublicKey: + i = len(keys.RSAs) + keys.RSAs = append(keys.RSAs, t) + ids = &keys.RSAIDs + case *rsa.PrivateKey: + i = len(keys.RSAs) + keys.RSAs = append(keys.RSAs, &t.PublicKey) + ids = &keys.RSAIDs + case []byte: + i = len(keys.Secrets) + keys.Secrets = append(keys.Secrets, t) + ids = &keys.SecretIDs + default: + return fmt.Errorf("jwt: unsupported key type %T", t) + } + + if kid != "" { + for len(*ids) <= i { + *ids = append(*ids, "") + } + (*ids)[i] = kid + } + + return nil +} + +// PEM exports the (public) keys as PEM-encoded PKIX. +// Elements from the Secret field, if any, are not included. +func (keys *KeyRegister) PEM() ([]byte, error) { + buf := new(bytes.Buffer) + for _, key := range keys.ECDSAs { + if err := encodePEM(buf, key); err != nil { + return nil, err + } + } + for _, key := range keys.EdDSAs { + // There is no error case for EdDSA at the moment. + // Still want check for future stability. + if err := encodePEM(buf, key); err != nil { + return nil, err + } + } + for _, key := range keys.RSAs { + if err := encodePEM(buf, key); err != nil { + return nil, err + } + } + return buf.Bytes(), nil +} + +func encodePEM(buf *bytes.Buffer, key interface{}) error { + der, err := x509.MarshalPKIXPublicKey(key) + if err != nil { + return err + } + return pem.Encode(buf, &pem.Block{ + Type: "PUBLIC KEY", + Bytes: der, + }) +} + +type jwk struct { + Keys []*jwk + + Kid string + Kty *string + Crv string + + K, X, Y, N, E *string +} + +// LoadJWK adds keys from the JSON data to the register, including the key ID, +// a.k.a "kid", when present. If the object has a "keys" attribute, then data is +// read as a JWKS (JSON Web Key Set). Otherwise, data is read as a single JWK. +func (keys *KeyRegister) LoadJWK(data []byte) (keysAdded int, err error) { + j := new(jwk) + if err := json.Unmarshal(data, j); err != nil { + return 0, err + } + + if j.Keys == nil { + if err := keys.addJWK(j); err != nil { + return 0, err + } + return 1, nil + } + + for i, k := range j.Keys { + if err := keys.addJWK(k); err != nil { + return i, err + } + } + return len(j.Keys), nil +} + +var ( + errJWKNoKty = errors.New("jwt: JWK missing \"kty\" field") + errJWKParam = errors.New("jwt: JWK missing key–parameter field") + + errJWKCurveSize = errors.New("jwt: JWK curve parameters don't match curve size") + errJWKCurveMiss = errors.New("jwt: JWK curve parameters are not on the curve") +) + +func (keys *KeyRegister) addJWK(j *jwk) error { + // See RFC 7518, subsection 6.1 + + if j.Kty == nil { + return errJWKNoKty + } + switch *j.Kty { + default: + return fmt.Errorf("jwt: JWK with unsupported key type %q", *j.Kty) + + case "EC": + var curve elliptic.Curve + switch j.Crv { + case "P-256": + curve = elliptic.P256() + case "P-384": + curve = elliptic.P384() + case "P-521": + curve = elliptic.P521() + default: + return fmt.Errorf("jwt: JWK with unsupported elliptic curve %q", j.Crv) + } + + x, err := intParam(j.X) + if err != nil { + return err + } + y, err := intParam(j.Y) + if err != nil { + return err + } + + size := (curve.Params().BitSize + 7) / 8 + xSize, ySize := (x.BitLen()+7)/8, (y.BitLen()+7)/8 + if xSize != size || ySize != size { + return errJWKCurveSize + } + + if !curve.IsOnCurve(x, y) { + return errJWKCurveMiss + } + + keys.add(&ecdsa.PublicKey{Curve: curve, X: x, Y: y}, j.Kid) + + case "RSA": + n, err := intParam(j.N) + if err != nil { + return err + } + e, err := intParam(j.E) + if err != nil { + return err + } + + keys.add(&rsa.PublicKey{N: n, E: int(e.Int64())}, j.Kid) + + case "oct": + bytes, err := dataParam(j.K) + if err != nil { + return err + } + keys.add(bytes, j.Kid) + + case "OKP": + switch j.Crv { + case "Ed25519": + bytes, err := dataParam(j.X) + if err != nil { + return err + } + keys.add(ed25519.PublicKey(bytes), j.Kid) + default: + return fmt.Errorf("jwt: JWK with unsupported elliptic curve %q", j.Crv) + } + } + + return nil +} + +func dataParam(p *string) ([]byte, error) { + if p == nil { + return nil, errJWKParam + } + bytes, err := encoding.DecodeString(*p) + if err != nil { + return nil, fmt.Errorf("jwt: JWK with malformed key–parameter field: %w", err) + } + return bytes, nil +} + +func intParam(p *string) (*big.Int, error) { + bytes, err := dataParam(p) + if err != nil { + return nil, err + } + return new(big.Int).SetBytes(bytes), nil +} diff --git a/jwt/sign.go b/jwt/sign.go new file mode 100644 index 0000000..ebd37b0 --- /dev/null +++ b/jwt/sign.go @@ -0,0 +1,325 @@ +package jwt + +import ( + "bytes" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/hmac" + "crypto/rand" + "crypto/rsa" + "encoding/json" + "errors" + "fmt" + "hash" + "strconv" +) + +// FormatWithoutSign updates the Raw fields and returns a new JWT, with only the +// first two parts. +// +// tokenWithoutSignature :≡ header-base64 '.' payload-base64 +// token :≡ tokenWithoutSignature '.' signature-base64 +// +// The JOSE header (content) can be extended with extraHeaders, in the form of +// JSON objects. Redundant and/or duplicate keys are applied as provided. +func (c *Claims) FormatWithoutSign(alg string, extraHeaders ...json.RawMessage) (tokenWithoutSignature []byte, err error) { + return c.newToken(alg, 0, extraHeaders) +} + +// ECDSASign updates the Raw fields and returns a new JWT. +// The return is an AlgError when alg is not in ECDSAAlgs. +// The caller must use the correct key for the respective algorithm (P-256 for +// ES256, P-384 for ES384 and P-521 for ES512) or risk malformed token production. +// +// The JOSE header (content) can be extended with extraHeaders, in the form of +// JSON objects. Redundant and/or duplicate keys are applied as provided. +func (c *Claims) ECDSASign(alg string, key *ecdsa.PrivateKey, extraHeaders ...json.RawMessage) (token []byte, err error) { + hash, err := hashLookup(alg, ECDSAAlgs) + if err != nil { + return nil, err + } + digest := hash.New() + + // signature contains pair (r, s) as per RFC 7518, subsection 3.4 + paramLen := (key.Curve.Params().BitSize + 7) / 8 + token, err = c.newToken(alg, encoding.EncodedLen(paramLen*2), extraHeaders) + if err != nil { + return nil, err + } + digest.Write(token) + + buf := token[len(token):] + r, s, err := ecdsa.Sign(rand.Reader, key, digest.Sum(buf)) + if err != nil { + return nil, err + } + + token = append(token, '.') + sig := token[len(token):cap(token)] + // serialize r and s, using sig as a buffer + i := len(sig) + for _, word := range s.Bits() { + for bitCount := strconv.IntSize; bitCount > 0; bitCount -= 8 { + i-- + sig[i] = byte(word) + word >>= 8 + } + } + // i might have exceeded paramLen due to the word size + i = len(sig) - paramLen + for _, word := range r.Bits() { + for bitCount := strconv.IntSize; bitCount > 0; bitCount -= 8 { + i-- + sig[i] = byte(word) + word >>= 8 + } + } + + // encoder won't overhaul source space + encoding.Encode(sig, sig[len(sig)-2*paramLen:]) + return token[:cap(token)], nil +} + +// EdDSASign updates the Raw fields and returns a new JWT. +// +// The JOSE header (content) can be extended with extraHeaders, in the form of +// JSON objects. Redundant and/or duplicate keys are applied as provided. +func (c *Claims) EdDSASign(key ed25519.PrivateKey, extraHeaders ...json.RawMessage) (token []byte, err error) { + token, err = c.newToken(EdDSA, encoding.EncodedLen(ed25519.SignatureSize), extraHeaders) + if err != nil { + return nil, err + } + + sig := ed25519.Sign(key, token) + + token = append(token, '.') + encoding.Encode(token[len(token):cap(token)], sig) + return token[:cap(token)], nil +} + +// HMACSign updates the Raw fields and returns a new JWT. +// The return is an AlgError when alg is not in HMACAlgs. +// +// The JOSE header (content) can be extended with extraHeaders, in the form of +// JSON objects. Redundant and/or duplicate keys are applied as provided. +func (c *Claims) HMACSign(alg string, secret []byte, extraHeaders ...json.RawMessage) (token []byte, err error) { + if len(secret) == 0 { + return nil, errNoSecret + } + + hash, err := hashLookup(alg, HMACAlgs) + if err != nil { + return nil, err + } + digest := hmac.New(hash.New, secret) + + token, err = c.newToken(alg, encoding.EncodedLen(digest.Size()), extraHeaders) + if err != nil { + return nil, err + } + digest.Write(token) + + token = append(token, '.') + i := cap(token) - digest.Size() + buf := token[i:i] + encoding.Encode(token[len(token):cap(token)], digest.Sum(buf)) + return token[:cap(token)], nil +} + +// Sign updates the Raw fields on c and returns a new JWT. +// +// The JOSE header (content) can be extended with extraHeaders, in the form of +// JSON objects. Redundant and/or duplicate keys are applied as provided. +func (h *HMAC) Sign(c *Claims, extraHeaders ...json.RawMessage) (token []byte, err error) { + digest := h.digests.Get().(hash.Hash) + defer h.digests.Put(digest) + digest.Reset() + + token, err = c.newToken(h.alg, encoding.EncodedLen(digest.Size()), extraHeaders) + if err != nil { + return nil, err + } + digest.Write(token) + + token = append(token, '.') + i := cap(token) - digest.Size() + buf := token[i:i] + encoding.Encode(token[len(token):cap(token)], digest.Sum(buf)) + return token[:cap(token)], nil +} + +// RSASign updates the Raw fields and returns a new JWT. +// The return is an AlgError when alg is not in RSAAlgs. +// +// The JOSE header (content) can be extended with extraHeaders, in the form of +// JSON objects. Redundant and/or duplicate keys are applied as provided. +func (c *Claims) RSASign(alg string, key *rsa.PrivateKey, extraHeaders ...json.RawMessage) (token []byte, err error) { + hash, err := hashLookup(alg, RSAAlgs) + if err != nil { + return nil, err + } + digest := hash.New() + + token, err = c.newToken(alg, encoding.EncodedLen(key.Size()), extraHeaders) + if err != nil { + return nil, err + } + digest.Write(token) + + var sig []byte + buf := token[len(token):] + if alg != "" && alg[0] == 'P' { + sig, err = rsa.SignPSS(rand.Reader, key, hash, digest.Sum(buf), &pSSOptions) + } else { + sig, err = rsa.SignPKCS1v15(rand.Reader, key, hash, digest.Sum(buf)) + } + if err != nil { + return nil, err + } + + token = append(token, '.') + encoding.Encode(token[len(token):cap(token)], sig) + return token[:cap(token)], nil +} + +var ( + headerES256 = []byte(`{"alg":"ES256"}`) + headerES384 = []byte(`{"alg":"ES384"}`) + headerES512 = []byte(`{"alg":"ES512"}`) + headerEdDSA = []byte(`{"alg":"EdDSA"}`) + headerHS256 = []byte(`{"alg":"HS256"}`) + headerHS384 = []byte(`{"alg":"HS384"}`) + headerHS512 = []byte(`{"alg":"HS512"}`) + headerPS256 = []byte(`{"alg":"PS256"}`) + headerPS384 = []byte(`{"alg":"PS384"}`) + headerPS512 = []byte(`{"alg":"PS512"}`) + headerRS256 = []byte(`{"alg":"RS256"}`) + headerRS384 = []byte(`{"alg":"RS384"}`) + headerRS512 = []byte(`{"alg":"RS512"}`) +) + +func (c *Claims) newToken(alg string, encSigLen int, extraHeaders []json.RawMessage) ([]byte, error) { + var payload interface{} + if c.Set == nil { + payload = &c.Registered + } else { + payload = c.Set + + // merge Registered + if c.Issuer != "" { + c.Set[issuer] = c.Issuer + } + if c.Subject != "" { + c.Set[subject] = c.Subject + } + if len(c.Audiences) != 0 { + array := make([]interface{}, len(c.Audiences)) + for i, s := range c.Audiences { + array[i] = s + } + c.Set[audience] = array + } + if c.Expires != nil { + c.Set[expires] = float64(*c.Expires) + } + if c.NotBefore != nil { + c.Set[notBefore] = float64(*c.NotBefore) + } + if c.Issued != nil { + c.Set[issued] = float64(*c.Issued) + } + if c.ID != "" { + c.Set[id] = c.ID + } + } + + // define Claims.Raw + if bytes, err := json.Marshal(payload); err != nil { + return nil, err + } else { + c.Raw = json.RawMessage(bytes) + } + + // try fixed JOSE header + if len(extraHeaders) == 0 && c.KeyID == "" { + var fixed string + switch alg { + case ES256: + fixed = "eyJhbGciOiJFUzI1NiJ9." + c.RawHeader = headerES256 + case ES384: + fixed = "eyJhbGciOiJFUzM4NCJ9." + c.RawHeader = headerES384 + case ES512: + fixed = "eyJhbGciOiJFUzUxMiJ9." + c.RawHeader = headerES512 + case EdDSA: + fixed = "eyJhbGciOiJFZERTQSJ9." + c.RawHeader = headerEdDSA + case HS256: + fixed = "eyJhbGciOiJIUzI1NiJ9." + c.RawHeader = headerHS256 + case HS384: + fixed = "eyJhbGciOiJIUzM4NCJ9." + c.RawHeader = headerHS384 + case HS512: + fixed = "eyJhbGciOiJIUzUxMiJ9." + c.RawHeader = headerHS512 + case PS256: + fixed = "eyJhbGciOiJQUzI1NiJ9." + c.RawHeader = headerPS256 + case PS384: + fixed = "eyJhbGciOiJQUzM4NCJ9." + c.RawHeader = headerPS384 + case PS512: + fixed = "eyJhbGciOiJQUzUxMiJ9." + c.RawHeader = headerPS512 + case RS256: + fixed = "eyJhbGciOiJSUzI1NiJ9." + c.RawHeader = headerRS256 + case RS384: + fixed = "eyJhbGciOiJSUzM4NCJ9." + c.RawHeader = headerRS384 + case RS512: + fixed = "eyJhbGciOiJSUzUxMiJ9." + c.RawHeader = headerRS512 + } + + if fixed != "" { + l := len(fixed) + encoding.EncodedLen(len(c.Raw)) + token := make([]byte, l, l+1+encSigLen) + copy(token, fixed) + encoding.Encode(token[len(fixed):], c.Raw) + return token, nil + } + } + + // compose JOSE header + var header bytes.Buffer + if c.KeyID == "" { + fmt.Fprintf(&header, `{"alg":%q}`, alg) + } else { + fmt.Fprintf(&header, `{"alg":%q,"kid":%q}`, alg, c.KeyID) + } + for _, raw := range extraHeaders { + if len(raw) == 0 || raw[0] != '{' { + return nil, errors.New("jwt: JOSE header addition is not a JSON object") + } + offset := header.Len() - 1 + header.Truncate(offset) + if err := json.Compact(&header, []byte(raw)); err != nil { + return nil, fmt.Errorf("jwt: malformed JOSE header addition: %w", err) + } + header.Bytes()[offset] = ',' + } + c.RawHeader = json.RawMessage(header.Bytes()) + + // compose token + headerLen := encoding.EncodedLen(header.Len()) + l := headerLen + 1 + encoding.EncodedLen(len(c.Raw)) + token := make([]byte, l, l+1+encSigLen) + encoding.Encode(token, header.Bytes()) + token[headerLen] = '.' + encoding.Encode(token[headerLen+1:], c.Raw) + return token, nil +} diff --git a/jwt/web.go b/jwt/web.go new file mode 100644 index 0000000..cd4d0c9 --- /dev/null +++ b/jwt/web.go @@ -0,0 +1,267 @@ +package jwt + +import ( + "context" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rsa" + "errors" + "net/http" + "strconv" + "strings" + "time" +) + +// MIMEType is the IANA registered media type. +const MIMEType = "application/jwt" + +// OAuthURN is the IANA registered OAuth URI. +const OAuthURN = "urn:ietf:params:oauth:token-type:jwt" + +// ErrNoHeader signals an HTTP request without authorization. +var ErrNoHeader = errors.New("jwt: no HTTP authorization header") + +var errNotBearer = errors.New("jwt: not HTTP Bearer scheme") + +// ECDSACheckHeader applies ECDSACheck on an HTTP request. +// Specifically it looks for a bearer token in the Authorization header. +func ECDSACheckHeader(r *http.Request, key *ecdsa.PublicKey) (*Claims, error) { + token, err := BearerToken(r.Header) + if err != nil { + return nil, err + } + return ECDSACheck([]byte(token), key) +} + +// EdDSACheckHeader applies EdDSACheck on an HTTP request. +// Specifically it looks for a bearer token in the Authorization header. +func EdDSACheckHeader(r *http.Request, key ed25519.PublicKey) (*Claims, error) { + token, err := BearerToken(r.Header) + if err != nil { + return nil, err + } + return EdDSACheck([]byte(token), key) +} + +// HMACCheckHeader applies HMACCheck on an HTTP request. +// Specifically it looks for a bearer token in the Authorization header. +func HMACCheckHeader(r *http.Request, secret []byte) (*Claims, error) { + token, err := BearerToken(r.Header) + if err != nil { + return nil, err + } + return HMACCheck([]byte(token), secret) +} + +// CheckHeader applies Check on an HTTP request. +// Specifically it looks for a bearer token in the Authorization header. +func (h *HMAC) CheckHeader(r *http.Request) (*Claims, error) { + token, err := BearerToken(r.Header) + if err != nil { + return nil, err + } + return h.Check([]byte(token)) +} + +// RSACheckHeader applies RSACheck on an HTTP request. +// Specifically it looks for a bearer token in the Authorization header. +func RSACheckHeader(r *http.Request, key *rsa.PublicKey) (*Claims, error) { + token, err := BearerToken(r.Header) + if err != nil { + return nil, err + } + return RSACheck([]byte(token), key) +} + +// CheckHeader applies KeyRegister.Check on an HTTP request. +// Specifically it looks for a bearer token in the Authorization header. +func (keys *KeyRegister) CheckHeader(r *http.Request) (*Claims, error) { + token, err := BearerToken(r.Header) + if err != nil { + return nil, err + } + return keys.Check([]byte(token)) +} + +// Bearer extracts the token from an HTTP header. +func BearerToken(h http.Header) (token string, err error) { + v := h.Values("Authorization") + if len(v) == 0 { + return "", ErrNoHeader + } + // “It MUST be possible to combine the multiple header fields into one + // "field-name: field-value" pair, without changing the semantics of the + // message, by appending each subsequent field-value to the first, each + // separated by a comma.” + // — “Hypertext Transfer Protocol” RFC 2616, subsection 4.2 + s := strings.Join(v, ", ") + + const prefix = "Bearer " + // The scheme is case-insensitive 🤦 as per RFC 2617, subsection 1.2. + if len(s) < len(prefix) || !strings.EqualFold(s[:len(prefix)], prefix) { + return "", errNotBearer + } + return s[len(prefix):], nil +} + +// ECDSASignHeader applies ECDSASign on an HTTP request. +// Specifically it sets a bearer token in the Authorization header. +func (c *Claims) ECDSASignHeader(r *http.Request, alg string, key *ecdsa.PrivateKey) error { + token, err := c.ECDSASign(alg, key) + if err != nil { + return err + } + r.Header.Set("Authorization", "Bearer "+string(token)) + return nil +} + +// EdDSASignHeader applies ECDSASign on an HTTP request. +// Specifically it sets a bearer token in the Authorization header. +func (c *Claims) EdDSASignHeader(r *http.Request, key ed25519.PrivateKey) error { + token, err := c.EdDSASign(key) + if err != nil { + return err + } + r.Header.Set("Authorization", "Bearer "+string(token)) + return nil +} + +// HMACSignHeader applies HMACSign on an HTTP request. +// Specifically it sets a bearer token in the Authorization header. +func (c *Claims) HMACSignHeader(r *http.Request, alg string, secret []byte) error { + token, err := c.HMACSign(alg, secret) + if err != nil { + return err + } + r.Header.Set("Authorization", "Bearer "+string(token)) + return nil +} + +// SignHeader applies Sign on an HTTP request. +// Specifically it sets a bearer token in the Authorization header. +func (h *HMAC) SignHeader(c *Claims, r *http.Request) error { + token, err := h.Sign(c) + if err != nil { + return err + } + r.Header.Set("Authorization", "Bearer "+string(token)) + return nil +} + +// RSASignHeader applies RSASign on an HTTP request. +// Specifically it sets a bearer token in the Authorization header. +func (c *Claims) RSASignHeader(r *http.Request, alg string, key *rsa.PrivateKey) error { + token, err := c.RSASign(alg, key) + if err != nil { + return err + } + r.Header.Set("Authorization", "Bearer "+string(token)) + return nil +} + +// Handler protects an http.Handler with security enforcements. +// Requests are only passed to Target if the JWT checks out. +type Handler struct { + // Target is the secured service. + Target http.Handler + + // Keys defines the trusted credentials. + Keys *KeyRegister + + // HeaderBinding maps JWT claim names to HTTP header names. + // All requests passed to Target have these headers set. In + // case of failure the request is rejected with status code + // 401 (Unauthorized) and a description. + HeaderBinding map[string]string + + // HeaderPrefix is an optional constraint for JWT claim binding. + // Any client headers that match the prefix are removed from the + // request. + HeaderPrefix string + + // ContextKey places the validated Claims in the context of + // each respective request passed to Target when set. See + // http.Request.Context and context.Context.Value. + ContextKey interface{} + + // When not nil, then Func is called after the JWT validation + // succeeds and before any header bindings. Target is skipped + // [request drop] when the return is false. + // This feature may be used to further customise requests or + // as a filter or as an extended http.HandlerFunc. + Func func(http.ResponseWriter, *http.Request, *Claims) (pass bool) + + // Error sends a custom response. Nil defaults to http.Error. + // The appropriate WWW-Authenticate value is already present. + Error func(w http.ResponseWriter, error string, statusCode int) +} + +func (h *Handler) error(w http.ResponseWriter, error string, statusCode int) { + if h.Error != nil { + h.Error(w, error, statusCode) + } else { + http.Error(w, error, statusCode) + } +} + +// ServeHTTP honors the http.Handler interface. +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // verify claims + claims, err := h.Keys.CheckHeader(r) + if err != nil { + if err == ErrNoHeader { + w.Header().Set("WWW-Authenticate", "Bearer") + } else { + w.Header().Set("WWW-Authenticate", `Bearer error="invalid_token", error_description=`+strconv.QuoteToASCII(err.Error())) + } + h.error(w, err.Error(), http.StatusUnauthorized) + return + } + + // verify time constraints + if !claims.Valid(time.Now()) { + w.Header().Set("WWW-Authenticate", `Bearer error="invalid_token", error_description="jwt: time constraints exceeded"`) + h.error(w, "jwt: time constraints exceeded", http.StatusUnauthorized) + return + } + + // filter request headers + headerPrefix := http.CanonicalHeaderKey(h.HeaderPrefix) + if headerPrefix != "" { + for name := range r.Header { + if strings.HasPrefix(name, headerPrefix) { + delete(r.Header, name) + } + } + } + + // apply the custom function when set + if h.Func != nil && !h.Func(w, r, claims) { + return + } + + // claim propagation + for claimName, headerName := range h.HeaderBinding { + headerName = http.CanonicalHeaderKey(headerName) + if !strings.HasPrefix(headerName, headerPrefix) { + h.error(w, "jwt: prefix mismatch in header binding", http.StatusInternalServerError) + return + } + + s, ok := claims.String(claimName) + if !ok { + msg := "jwt: want string for claim " + claimName + w.Header().Set("WWW-Authenticate", `Bearer error="invalid_token", error_description=`+strconv.QuoteToASCII(msg)) + h.error(w, msg, http.StatusUnauthorized) + return + } + r.Header[headerName] = []string{s} + } + + // place claims in request context + if h.ContextKey != nil { + r = r.WithContext(context.WithValue(r.Context(), h.ContextKey, claims)) + } + + h.Target.ServeHTTP(w, r) +}