diff --git a/jose.py b/jose.py index c93089a..3257681 100644 --- a/jose.py +++ b/jose.py @@ -108,9 +108,9 @@ def deserialize_compact(jwt): return token_type(*parts) - def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP', - enc='A128CBC-HS256', rng=get_random_bytes, compression=None): + enc='A128CBC-HS256', rng=get_random_bytes, compression=None, + dir_key=""): """ Encrypts the given claims and produces a :class:`~jose.JWE` :param claims: A `dict` representing the claims for this @@ -128,6 +128,8 @@ def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP', as output. :param compression: The compression algorithm to use. Currently supports `'DEF'`. + :param dir_key: A symmetric key to be used for Direct Ciphertext Encryption. + Defined in RFC 7518, Section 4.1 :rtype: :class:`~jose.JWE` :raises: :class:`~jose.Error` if there is an error producing the JWE """ @@ -157,19 +159,35 @@ def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP', raise Error( 'Unsupported compression algorithm: {}'.format(compression)) plaintext = compress(plaintext) + + alg = header['alg'] + if(alg == 'dir'): + # body encryption/hash + ((cipher, _), key_size), ((hash_fn, _), hash_mod) = JWA[enc] + iv = rng(AES.block_size) + # for Direct encryption, pre-shared symmetric key is used + + #CHECK SECOND VALUE + ciphertext = cipher(plaintext, dir_key, iv) + hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata), + dir_key, hash_mod) + + # cek encryption + encryption_key_ciphertext = '' + + else: + # body encryption/hash + ((cipher, _), key_size), ((hash_fn, _), hash_mod) = JWA[enc] + iv = rng(AES.block_size) + encryption_key = rng(hash_mod.digest_size) - # body encryption/hash - ((cipher, _), key_size), ((hash_fn, _), hash_mod) = JWA[enc] - iv = rng(AES.block_size) - encryption_key = rng(hash_mod.digest_size) - - ciphertext = cipher(plaintext, encryption_key[-hash_mod.digest_size/2:], iv) - hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata), - encryption_key[:-hash_mod.digest_size/2], hash_mod) + ciphertext = cipher(plaintext, encryption_key[-hash_mod.digest_size/2:], iv) + hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata), + encryption_key[:-hash_mod.digest_size/2], hash_mod) - # cek encryption - (cipher, _), _ = JWA[alg] - encryption_key_ciphertext = cipher(encryption_key, jwk) + # cek encryption + (cipher, _), _ = JWA[alg] + encryption_key_ciphertext = cipher(encryption_key, jwk) return JWE(*map(b64encode_url, (json_encode(header), @@ -178,8 +196,7 @@ def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP', ciphertext, auth_tag(hash)))) - -def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None): +def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None, dir_key=""): """ Decrypts a deserialized :class:`~jose.JWE` :param jwe: An instance of :class:`~jose.JWE` @@ -193,6 +210,8 @@ def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None): :param expiry_seconds: An `int` containing the JWT expiry in seconds, used when evaluating the `iat` claim. Defaults to `None`, which disables `iat` claim validation. + :param dir_key: A symmetric key to be used for Direct Ciphertext Encryption. + Defined in RFC 7518, Section 4.1 :rtype: :class:`~jose.JWT` :raises: :class:`~jose.Expired` if the JWT has expired :raises: :class:`~jose.NotYetValid` if the JWT is not yet valid @@ -202,41 +221,73 @@ def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None): b64decode_url, jwe) header = json_decode(header) - # decrypt cek - (_, decipher), _ = JWA[header['alg']] - encryption_key = decipher(encryption_key_ciphertext, jwk) - - # decrypt body - ((_, decipher), _), ((hash_fn, _), mod) = JWA[header['enc']] + alg = header['alg'] + if(alg == 'dir'):#Use a shared symmetric key as the CEK + # decrypt body + ((_, decipher), _), ((hash_fn, _), mod) = JWA[header['enc']] + + version = header.get(_TEMP_VER_KEY) + if version: + plaintext = decipher(ciphertext, dir_key, iv) + hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata, version), + dir_key, mod=mod) + else: + plaintext = decipher(ciphertext, dir_key, iv) + hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata, version), + dir_key, mod=mod) + if not const_compare(auth_tag(hash), tag): + raise Error('Mismatched authentication tags') + if 'zip' in header: + try: + (_, decompress) = COMPRESSION[header['zip']] + except KeyError: + raise Error('Unsupported compression algorithm: {}'.format( + header['zip'])) + plaintext = decompress(plaintext) + + claims = json_decode(plaintext) + try: + del claims[_TEMP_VER_KEY] + except KeyError: + # expected when decrypting legacy tokens + pass - version = header.get(_TEMP_VER_KEY) - if version: - plaintext = decipher(ciphertext, encryption_key[-mod.digest_size/2:], iv) - hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata, version), - encryption_key[:-mod.digest_size/2], mod=mod) else: - plaintext = decipher(ciphertext, encryption_key[:-mod.digest_size], iv) - hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata, version), - encryption_key[-mod.digest_size:], mod=mod) - - if not const_compare(auth_tag(hash), tag): - raise Error('Mismatched authentication tags') - - if 'zip' in header: + # decrypt cek + (_, decipher), _ = JWA[header['alg']] + encryption_key = decipher(encryption_key_ciphertext, jwk) + + # decrypt body + ((_, decipher), _), ((hash_fn, _), mod) = JWA[header['enc']] + + version = header.get(_TEMP_VER_KEY) + if version: + plaintext = decipher(ciphertext, encryption_key[-mod.digest_size/2:], iv) + hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata, version), + encryption_key[:-mod.digest_size/2], mod=mod) + else: + plaintext = decipher(ciphertext, encryption_key[:-mod.digest_size], iv) + hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata, version), + encryption_key[-mod.digest_size:], mod=mod) + + if not const_compare(auth_tag(hash), tag): + raise Error('Mismatched authentication tags') + + if 'zip' in header: + try: + (_, decompress) = COMPRESSION[header['zip']] + except KeyError: + raise Error('Unsupported compression algorithm: {}'.format( + header['zip'])) + + plaintext = decompress(plaintext) + + claims = json_decode(plaintext) try: - (_, decompress) = COMPRESSION[header['zip']] + del claims[_TEMP_VER_KEY] except KeyError: - raise Error('Unsupported compression algorithm: {}'.format( - header['zip'])) - - plaintext = decompress(plaintext) - - claims = json_decode(plaintext) - try: - del claims[_TEMP_VER_KEY] - except KeyError: - # expected when decrypting legacy tokens - pass + # expected when decrypting legacy tokens + pass _validate(claims, validate_claims, expiry_seconds) diff --git a/tests.py b/tests.py index 04a4f72..f9029e1 100644 --- a/tests.py +++ b/tests.py @@ -159,7 +159,34 @@ def test_jwe_add_header(self): jwt = jose.decrypt(jose.deserialize_compact(et), rsa_priv_key) self.assertEqual(jwt.header['foo'], add_header['foo']) + def test_jwe_direct_encryption(self): + symmetric_key = "tisasymmetrickey" + jwe = jose.encrypt(claims, "", alg = "dir",enc="A128CBC-HS256", + dir_key = symmetric_key) + + # make sure the body can't be loaded as json (should be encrypted) + try: + json.loads(jose.b64decode_url(jwe.ciphertext)) + self.fail() + except ValueError: + pass + token = jose.serialize_compact(jwe) + + jwt = jose.decrypt(jose.deserialize_compact(token),"", + dir_key = symmetric_key) + self.assertNotIn(jose._TEMP_VER_KEY, claims) + + self.assertEqual(jwt.claims, claims) + + # invalid key + badkey = "1234123412341234" + try: + jose.decrypt(jose.deserialize_compact(token), '', dir_key=badkey) + self.fail() + except jose.Error as e: + self.assertEqual(e.message, 'Mismatched authentication tags') + def test_jwe_adata(self): adata = '42' for (alg, jwk), enc in product(self.algs, self.encs):