Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed Symmetric Key Encryption/Decryption #12

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 97 additions & 46 deletions jose.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ def deserialize_compact(jwt):

return token_type(*parts)


def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP',
enc='A128CBC-HS256', rng=get_random_bytes, compression=None):
enc='A128CBC-HS256', rng=get_random_bytes, compression=None,
dir_key=""):
""" Encrypts the given claims and produces a :class:`~jose.JWE`

:param claims: A `dict` representing the claims for this
Expand All @@ -128,6 +128,8 @@ def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP',
as output.
:param compression: The compression algorithm to use. Currently supports
`'DEF'`.
:param dir_key: A symmetric key to be used for Direct Ciphertext Encryption.
Defined in RFC 7518, Section 4.1
:rtype: :class:`~jose.JWE`
:raises: :class:`~jose.Error` if there is an error producing the JWE
"""
Expand Down Expand Up @@ -157,19 +159,35 @@ def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP',
raise Error(
'Unsupported compression algorithm: {}'.format(compression))
plaintext = compress(plaintext)

alg = header['alg']
if(alg == 'dir'):
# body encryption/hash
((cipher, _), key_size), ((hash_fn, _), hash_mod) = JWA[enc]
iv = rng(AES.block_size)
# for Direct encryption, pre-shared symmetric key is used

#CHECK SECOND VALUE
ciphertext = cipher(plaintext, dir_key, iv)
hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata),
dir_key, hash_mod)

# cek encryption
encryption_key_ciphertext = ''

else:
# body encryption/hash
((cipher, _), key_size), ((hash_fn, _), hash_mod) = JWA[enc]
iv = rng(AES.block_size)
encryption_key = rng(hash_mod.digest_size)

# body encryption/hash
((cipher, _), key_size), ((hash_fn, _), hash_mod) = JWA[enc]
iv = rng(AES.block_size)
encryption_key = rng(hash_mod.digest_size)

ciphertext = cipher(plaintext, encryption_key[-hash_mod.digest_size/2:], iv)
hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata),
encryption_key[:-hash_mod.digest_size/2], hash_mod)
ciphertext = cipher(plaintext, encryption_key[-hash_mod.digest_size/2:], iv)
hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata),
encryption_key[:-hash_mod.digest_size/2], hash_mod)

# cek encryption
(cipher, _), _ = JWA[alg]
encryption_key_ciphertext = cipher(encryption_key, jwk)
# cek encryption
(cipher, _), _ = JWA[alg]
encryption_key_ciphertext = cipher(encryption_key, jwk)

return JWE(*map(b64encode_url,
(json_encode(header),
Expand All @@ -178,8 +196,7 @@ def encrypt(claims, jwk, adata='', add_header=None, alg='RSA-OAEP',
ciphertext,
auth_tag(hash))))


def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None):
def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None, dir_key=""):
""" Decrypts a deserialized :class:`~jose.JWE`

:param jwe: An instance of :class:`~jose.JWE`
Expand All @@ -193,6 +210,8 @@ def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None):
:param expiry_seconds: An `int` containing the JWT expiry in seconds, used
when evaluating the `iat` claim. Defaults to `None`,
which disables `iat` claim validation.
:param dir_key: A symmetric key to be used for Direct Ciphertext Encryption.
Defined in RFC 7518, Section 4.1
:rtype: :class:`~jose.JWT`
:raises: :class:`~jose.Expired` if the JWT has expired
:raises: :class:`~jose.NotYetValid` if the JWT is not yet valid
Expand All @@ -202,41 +221,73 @@ def decrypt(jwe, jwk, adata='', validate_claims=True, expiry_seconds=None):
b64decode_url, jwe)
header = json_decode(header)

# decrypt cek
(_, decipher), _ = JWA[header['alg']]
encryption_key = decipher(encryption_key_ciphertext, jwk)

# decrypt body
((_, decipher), _), ((hash_fn, _), mod) = JWA[header['enc']]
alg = header['alg']
if(alg == 'dir'):#Use a shared symmetric key as the CEK
# decrypt body
((_, decipher), _), ((hash_fn, _), mod) = JWA[header['enc']]

version = header.get(_TEMP_VER_KEY)
if version:
plaintext = decipher(ciphertext, dir_key, iv)
hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata, version),
dir_key, mod=mod)
else:
plaintext = decipher(ciphertext, dir_key, iv)
hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata, version),
dir_key, mod=mod)
if not const_compare(auth_tag(hash), tag):
raise Error('Mismatched authentication tags')
if 'zip' in header:
try:
(_, decompress) = COMPRESSION[header['zip']]
except KeyError:
raise Error('Unsupported compression algorithm: {}'.format(
header['zip']))
plaintext = decompress(plaintext)

claims = json_decode(plaintext)
try:
del claims[_TEMP_VER_KEY]
except KeyError:
# expected when decrypting legacy tokens
pass

version = header.get(_TEMP_VER_KEY)
if version:
plaintext = decipher(ciphertext, encryption_key[-mod.digest_size/2:], iv)
hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata, version),
encryption_key[:-mod.digest_size/2], mod=mod)
else:
plaintext = decipher(ciphertext, encryption_key[:-mod.digest_size], iv)
hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata, version),
encryption_key[-mod.digest_size:], mod=mod)

if not const_compare(auth_tag(hash), tag):
raise Error('Mismatched authentication tags')

if 'zip' in header:
# decrypt cek
(_, decipher), _ = JWA[header['alg']]
encryption_key = decipher(encryption_key_ciphertext, jwk)

# decrypt body
((_, decipher), _), ((hash_fn, _), mod) = JWA[header['enc']]

version = header.get(_TEMP_VER_KEY)
if version:
plaintext = decipher(ciphertext, encryption_key[-mod.digest_size/2:], iv)
hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata, version),
encryption_key[:-mod.digest_size/2], mod=mod)
else:
plaintext = decipher(ciphertext, encryption_key[:-mod.digest_size], iv)
hash = hash_fn(_jwe_hash_str(ciphertext, iv, adata, version),
encryption_key[-mod.digest_size:], mod=mod)

if not const_compare(auth_tag(hash), tag):
raise Error('Mismatched authentication tags')

if 'zip' in header:
try:
(_, decompress) = COMPRESSION[header['zip']]
except KeyError:
raise Error('Unsupported compression algorithm: {}'.format(
header['zip']))

plaintext = decompress(plaintext)

claims = json_decode(plaintext)
try:
(_, decompress) = COMPRESSION[header['zip']]
del claims[_TEMP_VER_KEY]
except KeyError:
raise Error('Unsupported compression algorithm: {}'.format(
header['zip']))

plaintext = decompress(plaintext)

claims = json_decode(plaintext)
try:
del claims[_TEMP_VER_KEY]
except KeyError:
# expected when decrypting legacy tokens
pass
# expected when decrypting legacy tokens
pass

_validate(claims, validate_claims, expiry_seconds)

Expand Down
27 changes: 27 additions & 0 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,34 @@ def test_jwe_add_header(self):
jwt = jose.decrypt(jose.deserialize_compact(et), rsa_priv_key)

self.assertEqual(jwt.header['foo'], add_header['foo'])
def test_jwe_direct_encryption(self):
symmetric_key = "tisasymmetrickey"

jwe = jose.encrypt(claims, "", alg = "dir",enc="A128CBC-HS256",
dir_key = symmetric_key)

# make sure the body can't be loaded as json (should be encrypted)
try:
json.loads(jose.b64decode_url(jwe.ciphertext))
self.fail()
except ValueError:
pass
token = jose.serialize_compact(jwe)

jwt = jose.decrypt(jose.deserialize_compact(token),"",
dir_key = symmetric_key)
self.assertNotIn(jose._TEMP_VER_KEY, claims)

self.assertEqual(jwt.claims, claims)

# invalid key
badkey = "1234123412341234"
try:
jose.decrypt(jose.deserialize_compact(token), '', dir_key=badkey)
self.fail()
except jose.Error as e:
self.assertEqual(e.message, 'Mismatched authentication tags')

def test_jwe_adata(self):
adata = '42'
for (alg, jwk), enc in product(self.algs, self.encs):
Expand Down