From b3aa3a2a7a46cd60f7ec12411ddf5baa854ad17f Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Fri, 23 Jun 2023 18:17:07 -0400 Subject: [PATCH] Port RSA to rust --- .../hazmat/backends/openssl/backend.py | 270 +------- .../hazmat/backends/openssl/rsa.py | 574 ---------------- .../bindings/_rust/openssl/__init__.pyi | 2 + .../hazmat/bindings/_rust/openssl/rsa.pyi | 23 + .../hazmat/primitives/asymmetric/rsa.py | 3 + src/rust/Cargo.lock | 9 +- src/rust/Cargo.toml | 4 + src/rust/src/backend/ec.rs | 4 - src/rust/src/backend/mod.rs | 2 + src/rust/src/backend/rsa.rs | 614 ++++++++++++++++++ src/rust/src/backend/utils.rs | 42 +- tests/hazmat/primitives/test_rsa.py | 38 ++ 12 files changed, 745 insertions(+), 840 deletions(-) delete mode 100644 src/cryptography/hazmat/backends/openssl/rsa.py create mode 100644 src/cryptography/hazmat/bindings/_rust/openssl/rsa.pyi create mode 100644 src/rust/src/backend/rsa.rs diff --git a/src/cryptography/hazmat/backends/openssl/backend.py b/src/cryptography/hazmat/backends/openssl/backend.py index b4294224035ac..da70f793f128c 100644 --- a/src/cryptography/hazmat/backends/openssl/backend.py +++ b/src/cryptography/hazmat/backends/openssl/backend.py @@ -14,10 +14,6 @@ from cryptography.hazmat.backends.openssl import aead from cryptography.hazmat.backends.openssl.ciphers import _CipherContext from cryptography.hazmat.backends.openssl.cmac import _CMACContext -from cryptography.hazmat.backends.openssl.rsa import ( - _RSAPrivateKey, - _RSAPublicKey, -) from cryptography.hazmat.bindings._rust import openssl as rust_openssl from cryptography.hazmat.bindings.openssl import binding from cryptography.hazmat.primitives import hashes, serialization @@ -63,7 +59,6 @@ XTS, Mode, ) -from cryptography.hazmat.primitives.serialization import ssh from cryptography.hazmat.primitives.serialization.pkcs12 import ( PBES, PKCS12Certificate, @@ -360,24 +355,7 @@ def generate_rsa_private_key( self, public_exponent: int, key_size: int ) -> rsa.RSAPrivateKey: rsa._verify_rsa_parameters(public_exponent, key_size) - - rsa_cdata = self._lib.RSA_new() - self.openssl_assert(rsa_cdata != self._ffi.NULL) - rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free) - - bn = self._int_to_bn(public_exponent) - bn = self._ffi.gc(bn, self._lib.BN_free) - - res = self._lib.RSA_generate_key_ex( - rsa_cdata, key_size, bn, self._ffi.NULL - ) - self.openssl_assert(res == 1) - evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata) - - # We can skip RSA key validation here since we just generated the key - return _RSAPrivateKey( - self, rsa_cdata, evp_pkey, unsafe_skip_rsa_key_validation=True - ) + return rust_openssl.rsa.generate_private_key(public_exponent, key_size) def generate_rsa_parameters_supported( self, public_exponent: int, key_size: int @@ -403,46 +381,15 @@ def load_rsa_private_numbers( numbers.public_numbers.e, numbers.public_numbers.n, ) - rsa_cdata = self._lib.RSA_new() - self.openssl_assert(rsa_cdata != self._ffi.NULL) - rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free) - p = self._int_to_bn(numbers.p) - q = self._int_to_bn(numbers.q) - d = self._int_to_bn(numbers.d) - dmp1 = self._int_to_bn(numbers.dmp1) - dmq1 = self._int_to_bn(numbers.dmq1) - iqmp = self._int_to_bn(numbers.iqmp) - e = self._int_to_bn(numbers.public_numbers.e) - n = self._int_to_bn(numbers.public_numbers.n) - res = self._lib.RSA_set0_factors(rsa_cdata, p, q) - self.openssl_assert(res == 1) - res = self._lib.RSA_set0_key(rsa_cdata, n, e, d) - self.openssl_assert(res == 1) - res = self._lib.RSA_set0_crt_params(rsa_cdata, dmp1, dmq1, iqmp) - self.openssl_assert(res == 1) - evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata) - - return _RSAPrivateKey( - self, - rsa_cdata, - evp_pkey, - unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation, + return rust_openssl.rsa.from_private_numbers( + numbers, unsafe_skip_rsa_key_validation ) def load_rsa_public_numbers( self, numbers: rsa.RSAPublicNumbers ) -> rsa.RSAPublicKey: rsa._check_public_key_components(numbers.e, numbers.n) - rsa_cdata = self._lib.RSA_new() - self.openssl_assert(rsa_cdata != self._ffi.NULL) - rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free) - e = self._int_to_bn(numbers.e) - n = self._int_to_bn(numbers.n) - res = self._lib.RSA_set0_key(rsa_cdata, n, e, self._ffi.NULL) - self.openssl_assert(res == 1) - evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata) - - return _RSAPublicKey(self, rsa_cdata, evp_pkey) + return rust_openssl.rsa.from_public_numbers(numbers) def _create_evp_pkey_gc(self): evp_pkey = self._lib.EVP_PKEY_new() @@ -502,13 +449,8 @@ def _evp_pkey_to_private_key( key_type = self._lib.EVP_PKEY_id(evp_pkey) if key_type == self._lib.EVP_PKEY_RSA: - rsa_cdata = self._lib.EVP_PKEY_get1_RSA(evp_pkey) - self.openssl_assert(rsa_cdata != self._ffi.NULL) - rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free) - return _RSAPrivateKey( - self, - rsa_cdata, - evp_pkey, + return rust_openssl.rsa.private_key_from_ptr( + int(self._ffi.cast("uintptr_t", evp_pkey)), unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation, ) elif ( @@ -575,10 +517,9 @@ def _evp_pkey_to_public_key(self, evp_pkey) -> PublicKeyTypes: key_type = self._lib.EVP_PKEY_id(evp_pkey) if key_type == self._lib.EVP_PKEY_RSA: - rsa_cdata = self._lib.EVP_PKEY_get1_RSA(evp_pkey) - self.openssl_assert(rsa_cdata != self._ffi.NULL) - rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free) - return _RSAPublicKey(self, rsa_cdata, evp_pkey) + return rust_openssl.rsa.public_key_from_ptr( + int(self._ffi.cast("uintptr_t", evp_pkey)) + ) elif ( key_type == self._lib.EVP_PKEY_RSA_PSS and not self._lib.CRYPTOGRAPHY_IS_LIBRESSL @@ -735,7 +676,9 @@ def load_pem_public_key(self, data: bytes) -> PublicKeyTypes: if rsa_cdata != self._ffi.NULL: rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free) evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata) - return _RSAPublicKey(self, rsa_cdata, evp_pkey) + return rust_openssl.rsa.public_key_from_ptr( + int(self._ffi.cast("uintptr_t", evp_pkey)) + ) else: self._handle_key_loading_error() @@ -798,7 +741,9 @@ def load_der_public_key(self, data: bytes) -> PublicKeyTypes: if rsa_cdata != self._ffi.NULL: rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free) evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata) - return _RSAPublicKey(self, rsa_cdata, evp_pkey) + return rust_openssl.rsa.public_key_from_ptr( + int(self._ffi.cast("uintptr_t", evp_pkey)) + ) else: self._handle_key_loading_error() @@ -986,191 +931,6 @@ def elliptic_curve_exchange_algorithm_supported( algorithm, ec.ECDH ) - def _private_key_bytes( - self, - encoding: serialization.Encoding, - format: serialization.PrivateFormat, - encryption_algorithm: serialization.KeySerializationEncryption, - key, - evp_pkey, - cdata, - ) -> bytes: - # validate argument types - if not isinstance(encoding, serialization.Encoding): - raise TypeError("encoding must be an item from the Encoding enum") - if not isinstance(format, serialization.PrivateFormat): - raise TypeError( - "format must be an item from the PrivateFormat enum" - ) - if not isinstance( - encryption_algorithm, serialization.KeySerializationEncryption - ): - raise TypeError( - "Encryption algorithm must be a KeySerializationEncryption " - "instance" - ) - - # validate password - if isinstance(encryption_algorithm, serialization.NoEncryption): - password = b"" - elif isinstance( - encryption_algorithm, serialization.BestAvailableEncryption - ): - password = encryption_algorithm.password - if len(password) > 1023: - raise ValueError( - "Passwords longer than 1023 bytes are not supported by " - "this backend" - ) - elif ( - isinstance( - encryption_algorithm, serialization._KeySerializationEncryption - ) - and encryption_algorithm._format - is format - is serialization.PrivateFormat.OpenSSH - ): - password = encryption_algorithm.password - else: - raise ValueError("Unsupported encryption type") - - # PKCS8 + PEM/DER - if format is serialization.PrivateFormat.PKCS8: - if encoding is serialization.Encoding.PEM: - write_bio = self._lib.PEM_write_bio_PKCS8PrivateKey - elif encoding is serialization.Encoding.DER: - write_bio = self._lib.i2d_PKCS8PrivateKey_bio - else: - raise ValueError("Unsupported encoding for PKCS8") - return self._private_key_bytes_via_bio( - write_bio, evp_pkey, password - ) - - # TraditionalOpenSSL + PEM/DER - if format is serialization.PrivateFormat.TraditionalOpenSSL: - if self._fips_enabled and not isinstance( - encryption_algorithm, serialization.NoEncryption - ): - raise ValueError( - "Encrypted traditional OpenSSL format is not " - "supported in FIPS mode." - ) - key_type = self._lib.EVP_PKEY_id(evp_pkey) - - if encoding is serialization.Encoding.PEM: - assert key_type == self._lib.EVP_PKEY_RSA - write_bio = self._lib.PEM_write_bio_RSAPrivateKey - return self._private_key_bytes_via_bio( - write_bio, cdata, password - ) - - if encoding is serialization.Encoding.DER: - if password: - raise ValueError( - "Encryption is not supported for DER encoded " - "traditional OpenSSL keys" - ) - assert key_type == self._lib.EVP_PKEY_RSA - write_bio = self._lib.i2d_RSAPrivateKey_bio - return self._bio_func_output(write_bio, cdata) - - raise ValueError("Unsupported encoding for TraditionalOpenSSL") - - # OpenSSH + PEM - if format is serialization.PrivateFormat.OpenSSH: - if encoding is serialization.Encoding.PEM: - return ssh._serialize_ssh_private_key( - key, password, encryption_algorithm - ) - - raise ValueError( - "OpenSSH private key format can only be used" - " with PEM encoding" - ) - - # Anything that key-specific code was supposed to handle earlier, - # like Raw. - raise ValueError("format is invalid with this key") - - def _private_key_bytes_via_bio( - self, write_bio, evp_pkey, password - ) -> bytes: - if not password: - evp_cipher = self._ffi.NULL - else: - # This is a curated value that we will update over time. - evp_cipher = self._lib.EVP_get_cipherbyname(b"aes-256-cbc") - - return self._bio_func_output( - write_bio, - evp_pkey, - evp_cipher, - password, - len(password), - self._ffi.NULL, - self._ffi.NULL, - ) - - def _bio_func_output(self, write_bio, *args) -> bytes: - bio = self._create_mem_bio_gc() - res = write_bio(bio, *args) - self.openssl_assert(res == 1) - return self._read_mem_bio(bio) - - def _public_key_bytes( - self, - encoding: serialization.Encoding, - format: serialization.PublicFormat, - key, - evp_pkey, - cdata, - ) -> bytes: - if not isinstance(encoding, serialization.Encoding): - raise TypeError("encoding must be an item from the Encoding enum") - if not isinstance(format, serialization.PublicFormat): - raise TypeError( - "format must be an item from the PublicFormat enum" - ) - - # SubjectPublicKeyInfo + PEM/DER - if format is serialization.PublicFormat.SubjectPublicKeyInfo: - if encoding is serialization.Encoding.PEM: - write_bio = self._lib.PEM_write_bio_PUBKEY - elif encoding is serialization.Encoding.DER: - write_bio = self._lib.i2d_PUBKEY_bio - else: - raise ValueError( - "SubjectPublicKeyInfo works only with PEM or DER encoding" - ) - return self._bio_func_output(write_bio, evp_pkey) - - # PKCS1 + PEM/DER - if format is serialization.PublicFormat.PKCS1: - # Only RSA is supported here. - key_type = self._lib.EVP_PKEY_id(evp_pkey) - self.openssl_assert(key_type == self._lib.EVP_PKEY_RSA) - - if encoding is serialization.Encoding.PEM: - write_bio = self._lib.PEM_write_bio_RSAPublicKey - elif encoding is serialization.Encoding.DER: - write_bio = self._lib.i2d_RSAPublicKey_bio - else: - raise ValueError("PKCS1 works only with PEM or DER encoding") - return self._bio_func_output(write_bio, cdata) - - # OpenSSH + OpenSSH - if format is serialization.PublicFormat.OpenSSH: - if encoding is serialization.Encoding.OpenSSH: - return ssh.serialize_ssh_public_key(key) - - raise ValueError( - "OpenSSH format must be used with OpenSSH encoding" - ) - - # Anything that key-specific code was supposed to handle earlier, - # like Raw, CompressedPoint, UncompressedPoint - raise ValueError("format is invalid with this key") - def dh_supported(self) -> bool: return not self._lib.CRYPTOGRAPHY_IS_BORINGSSL diff --git a/src/cryptography/hazmat/backends/openssl/rsa.py b/src/cryptography/hazmat/backends/openssl/rsa.py deleted file mode 100644 index b9c96a78faa15..0000000000000 --- a/src/cryptography/hazmat/backends/openssl/rsa.py +++ /dev/null @@ -1,574 +0,0 @@ -# This file is dual licensed under the terms of the Apache License, Version -# 2.0, and the BSD License. See the LICENSE file in the root of this repository -# for complete details. - -from __future__ import annotations - -import typing - -from cryptography.exceptions import ( - InvalidSignature, - UnsupportedAlgorithm, - _Reasons, -) -from cryptography.hazmat.backends.openssl.utils import ( - _calculate_digest_and_algorithm, -) -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import utils as asym_utils -from cryptography.hazmat.primitives.asymmetric.padding import ( - MGF1, - OAEP, - PSS, - AsymmetricPadding, - PKCS1v15, - _Auto, - _DigestLength, - _MaxLength, - calculate_max_pss_salt_length, -) -from cryptography.hazmat.primitives.asymmetric.rsa import ( - RSAPrivateKey, - RSAPrivateNumbers, - RSAPublicKey, - RSAPublicNumbers, -) - -if typing.TYPE_CHECKING: - from cryptography.hazmat.backends.openssl.backend import Backend - - -def _get_rsa_pss_salt_length( - backend: Backend, - pss: PSS, - key: typing.Union[RSAPrivateKey, RSAPublicKey], - hash_algorithm: hashes.HashAlgorithm, -) -> int: - salt = pss._salt_length - - if isinstance(salt, _MaxLength): - return calculate_max_pss_salt_length(key, hash_algorithm) - elif isinstance(salt, _DigestLength): - return hash_algorithm.digest_size - elif isinstance(salt, _Auto): - if isinstance(key, RSAPrivateKey): - raise ValueError( - "PSS salt length can only be set to AUTO when verifying" - ) - return backend._lib.RSA_PSS_SALTLEN_AUTO - else: - return salt - - -def _enc_dec_rsa( - backend: Backend, - key: typing.Union[_RSAPrivateKey, _RSAPublicKey], - data: bytes, - padding: AsymmetricPadding, -) -> bytes: - if not isinstance(padding, AsymmetricPadding): - raise TypeError("Padding must be an instance of AsymmetricPadding.") - - if isinstance(padding, PKCS1v15): - padding_enum = backend._lib.RSA_PKCS1_PADDING - elif isinstance(padding, OAEP): - padding_enum = backend._lib.RSA_PKCS1_OAEP_PADDING - - if not isinstance(padding._mgf, MGF1): - raise UnsupportedAlgorithm( - "Only MGF1 is supported by this backend.", - _Reasons.UNSUPPORTED_MGF, - ) - - if not backend.rsa_padding_supported(padding): - raise UnsupportedAlgorithm( - "This combination of padding and hash algorithm is not " - "supported by this backend.", - _Reasons.UNSUPPORTED_PADDING, - ) - - else: - raise UnsupportedAlgorithm( - f"{padding.name} is not supported by this backend.", - _Reasons.UNSUPPORTED_PADDING, - ) - - return _enc_dec_rsa_pkey_ctx(backend, key, data, padding_enum, padding) - - -def _enc_dec_rsa_pkey_ctx( - backend: Backend, - key: typing.Union[_RSAPrivateKey, _RSAPublicKey], - data: bytes, - padding_enum: int, - padding: AsymmetricPadding, -) -> bytes: - init: typing.Callable[[typing.Any], int] - crypt: typing.Callable[[typing.Any, typing.Any, int, bytes, int], int] - if isinstance(key, _RSAPublicKey): - init = backend._lib.EVP_PKEY_encrypt_init - crypt = backend._lib.EVP_PKEY_encrypt - else: - init = backend._lib.EVP_PKEY_decrypt_init - crypt = backend._lib.EVP_PKEY_decrypt - - pkey_ctx = backend._lib.EVP_PKEY_CTX_new(key._evp_pkey, backend._ffi.NULL) - backend.openssl_assert(pkey_ctx != backend._ffi.NULL) - pkey_ctx = backend._ffi.gc(pkey_ctx, backend._lib.EVP_PKEY_CTX_free) - res = init(pkey_ctx) - backend.openssl_assert(res == 1) - res = backend._lib.EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, padding_enum) - backend.openssl_assert(res > 0) - buf_size = backend._lib.EVP_PKEY_size(key._evp_pkey) - backend.openssl_assert(buf_size > 0) - if isinstance(padding, OAEP): - mgf1_md = backend._evp_md_non_null_from_algorithm( - padding._mgf._algorithm - ) - res = backend._lib.EVP_PKEY_CTX_set_rsa_mgf1_md(pkey_ctx, mgf1_md) - backend.openssl_assert(res > 0) - oaep_md = backend._evp_md_non_null_from_algorithm(padding._algorithm) - res = backend._lib.EVP_PKEY_CTX_set_rsa_oaep_md(pkey_ctx, oaep_md) - backend.openssl_assert(res > 0) - - if ( - isinstance(padding, OAEP) - and padding._label is not None - and len(padding._label) > 0 - ): - # set0_rsa_oaep_label takes ownership of the char * so we need to - # copy it into some new memory - labelptr = backend._lib.OPENSSL_malloc(len(padding._label)) - backend.openssl_assert(labelptr != backend._ffi.NULL) - backend._ffi.memmove(labelptr, padding._label, len(padding._label)) - res = backend._lib.EVP_PKEY_CTX_set0_rsa_oaep_label( - pkey_ctx, labelptr, len(padding._label) - ) - backend.openssl_assert(res == 1) - - outlen = backend._ffi.new("size_t *", buf_size) - buf = backend._ffi.new("unsigned char[]", buf_size) - # Everything from this line onwards is written with the goal of being as - # constant-time as is practical given the constraints of Python and our - # API. See Bleichenbacher's '98 attack on RSA, and its many many variants. - # As such, you should not attempt to change this (particularly to "clean it - # up") without understanding why it was written this way (see - # Chesterton's Fence), and without measuring to verify you have not - # introduced observable time differences. - res = crypt(pkey_ctx, buf, outlen, data, len(data)) - resbuf = backend._ffi.buffer(buf)[: outlen[0]] - backend._lib.ERR_clear_error() - if res <= 0: - raise ValueError("Encryption/decryption failed.") - return resbuf - - -def _rsa_sig_determine_padding( - backend: Backend, - key: typing.Union[_RSAPrivateKey, _RSAPublicKey], - padding: AsymmetricPadding, - algorithm: typing.Optional[hashes.HashAlgorithm], -) -> int: - if not isinstance(padding, AsymmetricPadding): - raise TypeError("Expected provider of AsymmetricPadding.") - - pkey_size = backend._lib.EVP_PKEY_size(key._evp_pkey) - backend.openssl_assert(pkey_size > 0) - - if isinstance(padding, PKCS1v15): - # Hash algorithm is ignored for PKCS1v15-padding, may be None. - padding_enum = backend._lib.RSA_PKCS1_PADDING - elif isinstance(padding, PSS): - if not isinstance(padding._mgf, MGF1): - raise UnsupportedAlgorithm( - "Only MGF1 is supported by this backend.", - _Reasons.UNSUPPORTED_MGF, - ) - - # PSS padding requires a hash algorithm - if not isinstance(algorithm, hashes.HashAlgorithm): - raise TypeError("Expected instance of hashes.HashAlgorithm.") - - # Size of key in bytes - 2 is the maximum - # PSS signature length (salt length is checked later) - if pkey_size - algorithm.digest_size - 2 < 0: - raise ValueError( - "Digest too large for key size. Use a larger " - "key or different digest." - ) - - padding_enum = backend._lib.RSA_PKCS1_PSS_PADDING - else: - raise UnsupportedAlgorithm( - f"{padding.name} is not supported by this backend.", - _Reasons.UNSUPPORTED_PADDING, - ) - - return padding_enum - - -# Hash algorithm can be absent (None) to initialize the context without setting -# any message digest algorithm. This is currently only valid for the PKCS1v15 -# padding type, where it means that the signature data is encoded/decoded -# as provided, without being wrapped in a DigestInfo structure. -def _rsa_sig_setup( - backend: Backend, - padding: AsymmetricPadding, - algorithm: typing.Optional[hashes.HashAlgorithm], - key: typing.Union[_RSAPublicKey, _RSAPrivateKey], - init_func: typing.Callable[[typing.Any], int], -): - padding_enum = _rsa_sig_determine_padding(backend, key, padding, algorithm) - pkey_ctx = backend._lib.EVP_PKEY_CTX_new(key._evp_pkey, backend._ffi.NULL) - backend.openssl_assert(pkey_ctx != backend._ffi.NULL) - pkey_ctx = backend._ffi.gc(pkey_ctx, backend._lib.EVP_PKEY_CTX_free) - res = init_func(pkey_ctx) - if res != 1: - errors = backend._consume_errors() - raise ValueError("Unable to sign/verify with this key", errors) - - if algorithm is not None: - evp_md = backend._evp_md_non_null_from_algorithm(algorithm) - res = backend._lib.EVP_PKEY_CTX_set_signature_md(pkey_ctx, evp_md) - if res <= 0: - backend._consume_errors() - raise UnsupportedAlgorithm( - "{} is not supported by this backend for RSA signing.".format( - algorithm.name - ), - _Reasons.UNSUPPORTED_HASH, - ) - res = backend._lib.EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, padding_enum) - if res <= 0: - backend._consume_errors() - raise UnsupportedAlgorithm( - "{} is not supported for the RSA signature operation.".format( - padding.name - ), - _Reasons.UNSUPPORTED_PADDING, - ) - if isinstance(padding, PSS): - assert isinstance(algorithm, hashes.HashAlgorithm) - res = backend._lib.EVP_PKEY_CTX_set_rsa_pss_saltlen( - pkey_ctx, - _get_rsa_pss_salt_length(backend, padding, key, algorithm), - ) - backend.openssl_assert(res > 0) - - mgf1_md = backend._evp_md_non_null_from_algorithm( - padding._mgf._algorithm - ) - res = backend._lib.EVP_PKEY_CTX_set_rsa_mgf1_md(pkey_ctx, mgf1_md) - backend.openssl_assert(res > 0) - - return pkey_ctx - - -def _rsa_sig_sign( - backend: Backend, - padding: AsymmetricPadding, - algorithm: hashes.HashAlgorithm, - private_key: _RSAPrivateKey, - data: bytes, -) -> bytes: - pkey_ctx = _rsa_sig_setup( - backend, - padding, - algorithm, - private_key, - backend._lib.EVP_PKEY_sign_init, - ) - buflen = backend._ffi.new("size_t *") - res = backend._lib.EVP_PKEY_sign( - pkey_ctx, backend._ffi.NULL, buflen, data, len(data) - ) - backend.openssl_assert(res == 1) - buf = backend._ffi.new("unsigned char[]", buflen[0]) - res = backend._lib.EVP_PKEY_sign(pkey_ctx, buf, buflen, data, len(data)) - if res != 1: - errors = backend._consume_errors() - raise ValueError( - "Digest or salt length too long for key size. Use a larger key " - "or shorter salt length if you are specifying a PSS salt", - errors, - ) - - return backend._ffi.buffer(buf)[:] - - -def _rsa_sig_verify( - backend: Backend, - padding: AsymmetricPadding, - algorithm: hashes.HashAlgorithm, - public_key: _RSAPublicKey, - signature: bytes, - data: bytes, -) -> None: - pkey_ctx = _rsa_sig_setup( - backend, - padding, - algorithm, - public_key, - backend._lib.EVP_PKEY_verify_init, - ) - res = backend._lib.EVP_PKEY_verify( - pkey_ctx, signature, len(signature), data, len(data) - ) - # The previous call can return negative numbers in the event of an - # error. This is not a signature failure but we need to fail if it - # occurs. - backend.openssl_assert(res >= 0) - if res == 0: - backend._consume_errors() - raise InvalidSignature - - -def _rsa_sig_recover( - backend: Backend, - padding: AsymmetricPadding, - algorithm: typing.Optional[hashes.HashAlgorithm], - public_key: _RSAPublicKey, - signature: bytes, -) -> bytes: - pkey_ctx = _rsa_sig_setup( - backend, - padding, - algorithm, - public_key, - backend._lib.EVP_PKEY_verify_recover_init, - ) - - # Attempt to keep the rest of the code in this function as constant/time - # as possible. See the comment in _enc_dec_rsa_pkey_ctx. Note that the - # buflen parameter is used even though its value may be undefined in the - # error case. Due to the tolerant nature of Python slicing this does not - # trigger any exceptions. - maxlen = backend._lib.EVP_PKEY_size(public_key._evp_pkey) - backend.openssl_assert(maxlen > 0) - buf = backend._ffi.new("unsigned char[]", maxlen) - buflen = backend._ffi.new("size_t *", maxlen) - res = backend._lib.EVP_PKEY_verify_recover( - pkey_ctx, buf, buflen, signature, len(signature) - ) - resbuf = backend._ffi.buffer(buf)[: buflen[0]] - backend._lib.ERR_clear_error() - # Assume that all parameter errors are handled during the setup phase and - # any error here is due to invalid signature. - if res != 1: - raise InvalidSignature - return resbuf - - -class _RSAPrivateKey(RSAPrivateKey): - _evp_pkey: object - _rsa_cdata: object - _key_size: int - - def __init__( - self, - backend: Backend, - rsa_cdata, - evp_pkey, - *, - unsafe_skip_rsa_key_validation: bool, - ): - res: int - # RSA_check_key is slower in OpenSSL 3.0.0 due to improved - # primality checking. In normal use this is unlikely to be a problem - # since users don't load new keys constantly, but for TESTING we've - # added an init arg that allows skipping the checks. You should not - # use this in production code unless you understand the consequences. - if not unsafe_skip_rsa_key_validation: - res = backend._lib.RSA_check_key(rsa_cdata) - if res != 1: - errors = backend._consume_errors() - raise ValueError("Invalid private key", errors) - # 2 is prime and passes an RSA key check, so we also check - # if p and q are odd just to be safe. - p = backend._ffi.new("BIGNUM **") - q = backend._ffi.new("BIGNUM **") - backend._lib.RSA_get0_factors(rsa_cdata, p, q) - backend.openssl_assert(p[0] != backend._ffi.NULL) - backend.openssl_assert(q[0] != backend._ffi.NULL) - p_odd = backend._lib.BN_is_odd(p[0]) - q_odd = backend._lib.BN_is_odd(q[0]) - if p_odd != 1 or q_odd != 1: - errors = backend._consume_errors() - raise ValueError("Invalid private key", errors) - - self._backend = backend - self._rsa_cdata = rsa_cdata - self._evp_pkey = evp_pkey - - n = self._backend._ffi.new("BIGNUM **") - self._backend._lib.RSA_get0_key( - self._rsa_cdata, - n, - self._backend._ffi.NULL, - self._backend._ffi.NULL, - ) - self._backend.openssl_assert(n[0] != self._backend._ffi.NULL) - self._key_size = self._backend._lib.BN_num_bits(n[0]) - - @property - def key_size(self) -> int: - return self._key_size - - def decrypt(self, ciphertext: bytes, padding: AsymmetricPadding) -> bytes: - key_size_bytes = (self.key_size + 7) // 8 - if key_size_bytes != len(ciphertext): - raise ValueError("Ciphertext length must be equal to key size.") - - return _enc_dec_rsa(self._backend, self, ciphertext, padding) - - def public_key(self) -> RSAPublicKey: - ctx = self._backend._lib.RSAPublicKey_dup(self._rsa_cdata) - self._backend.openssl_assert(ctx != self._backend._ffi.NULL) - ctx = self._backend._ffi.gc(ctx, self._backend._lib.RSA_free) - evp_pkey = self._backend._rsa_cdata_to_evp_pkey(ctx) - return _RSAPublicKey(self._backend, ctx, evp_pkey) - - def private_numbers(self) -> RSAPrivateNumbers: - n = self._backend._ffi.new("BIGNUM **") - e = self._backend._ffi.new("BIGNUM **") - d = self._backend._ffi.new("BIGNUM **") - p = self._backend._ffi.new("BIGNUM **") - q = self._backend._ffi.new("BIGNUM **") - dmp1 = self._backend._ffi.new("BIGNUM **") - dmq1 = self._backend._ffi.new("BIGNUM **") - iqmp = self._backend._ffi.new("BIGNUM **") - self._backend._lib.RSA_get0_key(self._rsa_cdata, n, e, d) - self._backend.openssl_assert(n[0] != self._backend._ffi.NULL) - self._backend.openssl_assert(e[0] != self._backend._ffi.NULL) - self._backend.openssl_assert(d[0] != self._backend._ffi.NULL) - self._backend._lib.RSA_get0_factors(self._rsa_cdata, p, q) - self._backend.openssl_assert(p[0] != self._backend._ffi.NULL) - self._backend.openssl_assert(q[0] != self._backend._ffi.NULL) - self._backend._lib.RSA_get0_crt_params( - self._rsa_cdata, dmp1, dmq1, iqmp - ) - self._backend.openssl_assert(dmp1[0] != self._backend._ffi.NULL) - self._backend.openssl_assert(dmq1[0] != self._backend._ffi.NULL) - self._backend.openssl_assert(iqmp[0] != self._backend._ffi.NULL) - return RSAPrivateNumbers( - p=self._backend._bn_to_int(p[0]), - q=self._backend._bn_to_int(q[0]), - d=self._backend._bn_to_int(d[0]), - dmp1=self._backend._bn_to_int(dmp1[0]), - dmq1=self._backend._bn_to_int(dmq1[0]), - iqmp=self._backend._bn_to_int(iqmp[0]), - public_numbers=RSAPublicNumbers( - e=self._backend._bn_to_int(e[0]), - n=self._backend._bn_to_int(n[0]), - ), - ) - - def private_bytes( - self, - encoding: serialization.Encoding, - format: serialization.PrivateFormat, - encryption_algorithm: serialization.KeySerializationEncryption, - ) -> bytes: - return self._backend._private_key_bytes( - encoding, - format, - encryption_algorithm, - self, - self._evp_pkey, - self._rsa_cdata, - ) - - def sign( - self, - data: bytes, - padding: AsymmetricPadding, - algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm], - ) -> bytes: - data, algorithm = _calculate_digest_and_algorithm(data, algorithm) - return _rsa_sig_sign(self._backend, padding, algorithm, self, data) - - -class _RSAPublicKey(RSAPublicKey): - _evp_pkey: object - _rsa_cdata: object - _key_size: int - - def __init__(self, backend: Backend, rsa_cdata, evp_pkey): - self._backend = backend - self._rsa_cdata = rsa_cdata - self._evp_pkey = evp_pkey - - n = self._backend._ffi.new("BIGNUM **") - self._backend._lib.RSA_get0_key( - self._rsa_cdata, - n, - self._backend._ffi.NULL, - self._backend._ffi.NULL, - ) - self._backend.openssl_assert(n[0] != self._backend._ffi.NULL) - self._key_size = self._backend._lib.BN_num_bits(n[0]) - - @property - def key_size(self) -> int: - return self._key_size - - def __eq__(self, other: object) -> bool: - if not isinstance(other, _RSAPublicKey): - return NotImplemented - - return ( - self._backend._lib.EVP_PKEY_cmp(self._evp_pkey, other._evp_pkey) - == 1 - ) - - def encrypt(self, plaintext: bytes, padding: AsymmetricPadding) -> bytes: - return _enc_dec_rsa(self._backend, self, plaintext, padding) - - def public_numbers(self) -> RSAPublicNumbers: - n = self._backend._ffi.new("BIGNUM **") - e = self._backend._ffi.new("BIGNUM **") - self._backend._lib.RSA_get0_key( - self._rsa_cdata, n, e, self._backend._ffi.NULL - ) - self._backend.openssl_assert(n[0] != self._backend._ffi.NULL) - self._backend.openssl_assert(e[0] != self._backend._ffi.NULL) - return RSAPublicNumbers( - e=self._backend._bn_to_int(e[0]), - n=self._backend._bn_to_int(n[0]), - ) - - def public_bytes( - self, - encoding: serialization.Encoding, - format: serialization.PublicFormat, - ) -> bytes: - return self._backend._public_key_bytes( - encoding, format, self, self._evp_pkey, self._rsa_cdata - ) - - def verify( - self, - signature: bytes, - data: bytes, - padding: AsymmetricPadding, - algorithm: typing.Union[asym_utils.Prehashed, hashes.HashAlgorithm], - ) -> None: - data, algorithm = _calculate_digest_and_algorithm(data, algorithm) - _rsa_sig_verify( - self._backend, padding, algorithm, self, signature, data - ) - - def recover_data_from_signature( - self, - signature: bytes, - padding: AsymmetricPadding, - algorithm: typing.Optional[hashes.HashAlgorithm], - ) -> bytes: - if isinstance(algorithm, asym_utils.Prehashed): - raise TypeError( - "Prehashed is only supported in the sign and verify methods. " - "It cannot be used with recover_data_from_signature." - ) - return _rsa_sig_recover( - self._backend, padding, algorithm, self, signature - ) diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi index d0e6ccaed238c..16cfa04e420a6 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi @@ -14,6 +14,7 @@ from cryptography.hazmat.bindings._rust.openssl import ( hmac, kdf, poly1305, + rsa, x448, x25519, ) @@ -29,6 +30,7 @@ __all__ = [ "kdf", "ed448", "ed25519", + "rsa", "poly1305", "x448", "x25519", diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/rsa.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/rsa.pyi new file mode 100644 index 0000000000000..d42134f72c74c --- /dev/null +++ b/src/cryptography/hazmat/bindings/_rust/openssl/rsa.pyi @@ -0,0 +1,23 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from cryptography.hazmat.primitives.asymmetric import rsa + +class RSAPrivateKey: ... +class RSAPublicKey: ... + +def generate_private_key( + public_exponent: int, + key_size: int, +) -> rsa.RSAPrivateKey: ... +def private_key_from_ptr( + ptr: int, + unsafe_skip_rsa_key_validation: bool, +) -> rsa.RSAPrivateKey: ... +def public_key_from_ptr(ptr: int) -> rsa.RSAPublicKey: ... +def from_private_numbers( + numbers: rsa.RSAPrivateNumbers, + unsafe_skip_rsa_key_validation: bool, +) -> rsa.RSAPrivateKey: ... +def from_public_numbers(numbers: rsa.RSAPublicNumbers) -> rsa.RSAPublicKey: ... diff --git a/src/cryptography/hazmat/primitives/asymmetric/rsa.py b/src/cryptography/hazmat/primitives/asymmetric/rsa.py index b740f01f7c4cb..bda15f2d1abd7 100644 --- a/src/cryptography/hazmat/primitives/asymmetric/rsa.py +++ b/src/cryptography/hazmat/primitives/asymmetric/rsa.py @@ -8,6 +8,7 @@ import typing from math import gcd +from cryptography.hazmat.bindings._rust import openssl as rust_openssl from cryptography.hazmat.primitives import _serialization, hashes from cryptography.hazmat.primitives._asymmetric import AsymmetricPadding from cryptography.hazmat.primitives.asymmetric import utils as asym_utils @@ -63,6 +64,7 @@ def private_bytes( RSAPrivateKeyWithSerialization = RSAPrivateKey +RSAPrivateKey.register(rust_openssl.rsa.RSAPrivateKey) class RSAPublicKey(metaclass=abc.ABCMeta): @@ -126,6 +128,7 @@ def __eq__(self, other: object) -> bool: RSAPublicKeyWithSerialization = RSAPublicKey +RSAPublicKey.register(rust_openssl.rsa.RSAPublicKey) def generate_private_key( diff --git a/src/rust/Cargo.lock b/src/rust/Cargo.lock index b7b574c726e23..c9f2f45ae046a 100644 --- a/src/rust/Cargo.lock +++ b/src/rust/Cargo.lock @@ -151,8 +151,7 @@ checksum = "9670a07f94779e00908f3e686eab508878ebb390ba6e604d3a284c00e8d0487b" [[package]] name = "openssl" version = "0.10.55" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "345df152bc43501c5eb9e4654ff05f794effb78d4efe3d53abc158baddc0703d" +source = "git+https://github.com/sfackler/rust-openssl#994e5ff8c63557ab2aa85c85cc6956b0b0216ca7" dependencies = [ "bitflags", "cfg-if", @@ -166,8 +165,7 @@ dependencies = [ [[package]] name = "openssl-macros" version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +source = "git+https://github.com/sfackler/rust-openssl#994e5ff8c63557ab2aa85c85cc6956b0b0216ca7" dependencies = [ "proc-macro2", "quote", @@ -177,8 +175,7 @@ dependencies = [ [[package]] name = "openssl-sys" version = "0.9.90" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "374533b0e45f3a7ced10fcaeccca020e66656bc03dac384f852e4e5a7a8104a6" +source = "git+https://github.com/sfackler/rust-openssl#994e5ff8c63557ab2aa85c85cc6956b0b0216ca7" dependencies = [ "cc", "libc", diff --git a/src/rust/Cargo.toml b/src/rust/Cargo.toml index 7fc45add24b60..ca8835da807e8 100644 --- a/src/rust/Cargo.toml +++ b/src/rust/Cargo.toml @@ -36,3 +36,7 @@ overflow-checks = true [workspace] members = ["cryptography-cffi", "cryptography-openssl", "cryptography-x509"] + +[patch.crates-io] +openssl = { git = "https://github.com/sfackler/rust-openssl" } +openssl-sys = { git = "https://github.com/sfackler/rust-openssl" } diff --git a/src/rust/src/backend/ec.rs b/src/rust/src/backend/ec.rs index 59351b721a498..8d7bb351a7971 100644 --- a/src/rust/src/backend/ec.rs +++ b/src/rust/src/backend/ec.rs @@ -503,10 +503,6 @@ impl ECPublicKey { let mut verifier = openssl::pkey_ctx::PkeyCtx::new(&self.pkey)?; verifier.verify_init()?; let valid = verifier.verify(data, signature).unwrap_or(false); - // TODO: Empty the error stack. BoringSSL leaves one in the event of - // signature validation failure. Upstream to rust-openssl? - #[cfg(CRYPTOGRAPHY_IS_BORINGSSL)] - openssl::error::ErrorStack::get(); if !valid { return Err(CryptographyError::from( exceptions::InvalidSignature::new_err(()), diff --git a/src/rust/src/backend/mod.rs b/src/rust/src/backend/mod.rs index b032aaac44047..e9942e60ff476 100644 --- a/src/rust/src/backend/mod.rs +++ b/src/rust/src/backend/mod.rs @@ -13,6 +13,7 @@ pub(crate) mod hashes; pub(crate) mod hmac; pub(crate) mod kdf; pub(crate) mod poly1305; +pub(crate) mod rsa; pub(crate) mod utils; #[cfg(any(not(CRYPTOGRAPHY_IS_LIBRESSL), CRYPTOGRAPHY_LIBRESSL_370_OR_GREATER))] pub(crate) mod x25519; @@ -39,6 +40,7 @@ pub(crate) fn add_to_module(module: &pyo3::prelude::PyModule) -> pyo3::PyResult< module.add_submodule(hashes::create_module(module.py())?)?; module.add_submodule(hmac::create_module(module.py())?)?; module.add_submodule(kdf::create_module(module.py())?)?; + module.add_submodule(rsa::create_module(module.py())?)?; Ok(()) } diff --git a/src/rust/src/backend/rsa.rs b/src/rust/src/backend/rsa.rs new file mode 100644 index 0000000000000..9ce8316f8634f --- /dev/null +++ b/src/rust/src/backend/rsa.rs @@ -0,0 +1,614 @@ +// This file is dual licensed under the terms of the Apache License, Version +// 2.0, and the BSD License. See the LICENSE file in the root of this repository +// for complete details. + +use crate::backend::{hashes, utils}; +use crate::error::{CryptographyError, CryptographyResult}; +use crate::exceptions; +use foreign_types_shared::ForeignTypeRef; + +#[pyo3::prelude::pyclass( + module = "cryptography.hazmat.bindings._rust.openssl.rsa", + name = "RSAPrivateKey" +)] +struct RsaPrivateKey { + pkey: openssl::pkey::PKey, +} + +#[pyo3::prelude::pyclass( + module = "cryptography.hazmat.bindings._rust.openssl.rsa", + name = "RSAPublicKey" +)] +struct RsaPublicKey { + pkey: openssl::pkey::PKey, +} + +fn check_rsa_private_key( + rsa: &openssl::rsa::Rsa, +) -> CryptographyResult<()> { + if !rsa.check_key().unwrap_or(false) || rsa.p().unwrap().is_even() || rsa.q().unwrap().is_even() + { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Invalid private key"), + )); + } + Ok(()) +} + +#[pyo3::prelude::pyfunction] +fn private_key_from_ptr( + ptr: usize, + unsafe_skip_rsa_key_validation: bool, +) -> CryptographyResult { + let pkey = unsafe { openssl::pkey::PKeyRef::from_ptr(ptr as *mut _) }; + if !unsafe_skip_rsa_key_validation { + check_rsa_private_key(&pkey.rsa().unwrap())?; + } + Ok(RsaPrivateKey { + pkey: pkey.to_owned(), + }) +} + +#[pyo3::prelude::pyfunction] +fn public_key_from_ptr(ptr: usize) -> RsaPublicKey { + let pkey = unsafe { openssl::pkey::PKeyRef::from_ptr(ptr as *mut _) }; + RsaPublicKey { + pkey: pkey.to_owned(), + } +} + +#[pyo3::prelude::pyfunction] +fn generate_private_key(public_exponent: u32, key_size: u32) -> CryptographyResult { + let e = openssl::bn::BigNum::from_u32(public_exponent)?; + let rsa = openssl::rsa::Rsa::generate_with_e(key_size, &e)?; + let pkey = openssl::pkey::PKey::from_rsa(rsa)?; + Ok(RsaPrivateKey { pkey }) +} + +#[pyo3::prelude::pyfunction] +fn from_private_numbers( + py: pyo3::Python<'_>, + numbers: &pyo3::PyAny, + unsafe_skip_rsa_key_validation: bool, +) -> CryptographyResult { + let public_numbers = numbers.getattr(pyo3::intern!(py, "public_numbers"))?; + + let rsa = openssl::rsa::Rsa::from_private_components( + utils::py_int_to_bn(py, public_numbers.getattr(pyo3::intern!(py, "n"))?)?, + utils::py_int_to_bn(py, public_numbers.getattr(pyo3::intern!(py, "e"))?)?, + utils::py_int_to_bn(py, numbers.getattr(pyo3::intern!(py, "d"))?)?, + utils::py_int_to_bn(py, numbers.getattr(pyo3::intern!(py, "p"))?)?, + utils::py_int_to_bn(py, numbers.getattr(pyo3::intern!(py, "q"))?)?, + utils::py_int_to_bn(py, numbers.getattr(pyo3::intern!(py, "dmp1"))?)?, + utils::py_int_to_bn(py, numbers.getattr(pyo3::intern!(py, "dmq1"))?)?, + utils::py_int_to_bn(py, numbers.getattr(pyo3::intern!(py, "iqmp"))?)?, + ) + .unwrap(); + if !unsafe_skip_rsa_key_validation { + check_rsa_private_key(&rsa)?; + } + let pkey = openssl::pkey::PKey::from_rsa(rsa)?; + Ok(RsaPrivateKey { pkey }) +} + +#[pyo3::prelude::pyfunction] +fn from_public_numbers( + py: pyo3::Python<'_>, + numbers: &pyo3::PyAny, +) -> CryptographyResult { + let rsa = openssl::rsa::Rsa::from_public_components( + utils::py_int_to_bn(py, numbers.getattr(pyo3::intern!(py, "n"))?)?, + utils::py_int_to_bn(py, numbers.getattr(pyo3::intern!(py, "e"))?)?, + ) + .unwrap(); + let pkey = openssl::pkey::PKey::from_rsa(rsa)?; + Ok(RsaPublicKey { pkey }) +} + +fn oaep_hash_supported(md: &openssl::hash::MessageDigest) -> bool { + (!cryptography_openssl::fips::is_enabled() && md == &openssl::hash::MessageDigest::sha1()) + || md == &openssl::hash::MessageDigest::sha224() + || md == &openssl::hash::MessageDigest::sha256() + || md == &openssl::hash::MessageDigest::sha384() + || md == &openssl::hash::MessageDigest::sha512() +} + +fn setup_encryption_ctx( + py: pyo3::Python<'_>, + ctx: &mut openssl::pkey_ctx::PkeyCtx, + padding: &pyo3::PyAny, +) -> CryptographyResult<()> { + let padding_mod = py.import(pyo3::intern!( + py, + "cryptography.hazmat.primitives.asymmetric.padding" + ))?; + let asymmetric_padding_class = padding_mod + .getattr(pyo3::intern!(py, "AsymmetricPadding"))? + .extract()?; + let pkcs1_class = padding_mod + .getattr(pyo3::intern!(py, "PKCS1v15"))? + .extract()?; + let oaep_class = padding_mod.getattr(pyo3::intern!(py, "OAEP"))?.extract()?; + let mgf1_class = padding_mod.getattr(pyo3::intern!(py, "MGF1"))?.extract()?; + + if !padding.is_instance(asymmetric_padding_class)? { + return Err(CryptographyError::from( + pyo3::exceptions::PyTypeError::new_err( + "Padding must be an instance of AsymmetricPadding.", + ), + )); + } + + let padding_enum = if padding.is_instance(pkcs1_class)? { + openssl::rsa::Padding::PKCS1 + } else if padding.is_instance(oaep_class)? { + if !padding + .getattr(pyo3::intern!(py, "_mgf"))? + .is_instance(mgf1_class)? + { + return Err(CryptographyError::from( + exceptions::UnsupportedAlgorithm::new_err(( + "Only MGF1 is supported.", + exceptions::Reasons::UNSUPPORTED_MGF, + )), + )); + } + + openssl::rsa::Padding::PKCS1_OAEP + } else { + return Err(CryptographyError::from( + exceptions::UnsupportedAlgorithm::new_err(( + format!( + "{} is not supported by this backend.", + padding.getattr(pyo3::intern!(py, "name"))? + ), + exceptions::Reasons::UNSUPPORTED_PADDING, + )), + )); + }; + + ctx.set_rsa_padding(padding_enum)?; + + if padding_enum == openssl::rsa::Padding::PKCS1_OAEP { + let mgf1_md = hashes::message_digest_from_algorithm( + py, + padding + .getattr(pyo3::intern!(py, "_mgf"))? + .getattr(pyo3::intern!(py, "_algorithm"))?, + )?; + let oaep_md = hashes::message_digest_from_algorithm( + py, + padding.getattr(pyo3::intern!(py, "_algorithm"))?, + )?; + + if !oaep_hash_supported(&mgf1_md) || !oaep_hash_supported(&oaep_md) { + return Err(CryptographyError::from( + exceptions::UnsupportedAlgorithm::new_err(( + "This combination of padding and hash algorithm is not supported", + exceptions::Reasons::UNSUPPORTED_PADDING, + )), + )); + } + + ctx.set_rsa_mgf1_md(openssl::md::Md::from_nid(mgf1_md.type_()).unwrap())?; + ctx.set_rsa_oaep_md(openssl::md::Md::from_nid(oaep_md.type_()).unwrap())?; + + if let Some(label) = padding + .getattr(pyo3::intern!(py, "_label"))? + .extract::>()? + { + if !label.is_empty() { + ctx.set_rsa_oaep_label(label)?; + } + } + } + + Ok(()) +} + +fn setup_signature_ctx( + py: pyo3::Python<'_>, + ctx: &mut openssl::pkey_ctx::PkeyCtx, + padding: &pyo3::PyAny, + algorithm: &pyo3::PyAny, + key_size: usize, + is_signing: bool, +) -> CryptographyResult<()> { + let padding_mod = py.import(pyo3::intern!( + py, + "cryptography.hazmat.primitives.asymmetric.padding" + ))?; + let asymmetric_padding_class = padding_mod.getattr(pyo3::intern!(py, "AsymmetricPadding"))?; + let pkcs1_class = padding_mod.getattr(pyo3::intern!(py, "PKCS1v15"))?; + let pss_class = padding_mod.getattr(pyo3::intern!(py, "PSS"))?.extract()?; + let max_length_class = padding_mod.getattr(pyo3::intern!(py, "_MaxLength"))?; + let digest_length_class = padding_mod.getattr(pyo3::intern!(py, "_DigestLength"))?; + let auto_class = padding_mod.getattr(pyo3::intern!(py, "_Auto"))?; + let mgf1_class = padding_mod.getattr(pyo3::intern!(py, "MGF1"))?; + let hash_algorithm_class = py + .import(pyo3::intern!(py, "cryptography.hazmat.primitives.hashes"))? + .getattr(pyo3::intern!(py, "HashAlgorithm"))?; + + if !padding.is_instance(asymmetric_padding_class)? { + return Err(CryptographyError::from( + pyo3::exceptions::PyTypeError::new_err( + "Padding must be an instance of AsymmetricPadding.", + ), + )); + } + + let padding_enum = if padding.is_instance(pkcs1_class)? { + openssl::rsa::Padding::PKCS1 + } else if padding.is_instance(pss_class)? { + if !padding + .getattr(pyo3::intern!(py, "_mgf"))? + .is_instance(mgf1_class)? + { + return Err(CryptographyError::from( + exceptions::UnsupportedAlgorithm::new_err(( + "Only MGF1 is supported.", + exceptions::Reasons::UNSUPPORTED_MGF, + )), + )); + } + + // PSS padding requires a hash algorithm + if !algorithm.is_instance(hash_algorithm_class)? { + return Err(CryptographyError::from( + pyo3::exceptions::PyTypeError::new_err( + "Expected instance of hashes.HashAlgorithm.", + ), + )); + } + + if algorithm + .getattr(pyo3::intern!(py, "digest_size"))? + .extract::()? + + 2 + > key_size + { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "Digest too large for key size. Use a larger key or different digest.", + ), + )); + } + + openssl::rsa::Padding::PKCS1_PSS + } else { + return Err(CryptographyError::from( + exceptions::UnsupportedAlgorithm::new_err(( + format!( + "{} is not supported by this backend.", + padding.getattr(pyo3::intern!(py, "name"))? + ), + exceptions::Reasons::UNSUPPORTED_PADDING, + )), + )); + }; + + if !algorithm.is_none() { + let md = hashes::message_digest_from_algorithm(py, algorithm)?; + ctx.set_signature_md(openssl::md::Md::from_nid(md.type_()).unwrap()) + .or_else(|_| { + Err(CryptographyError::from( + exceptions::UnsupportedAlgorithm::new_err(( + format!( + "{} is not supported by this backend for RSA signing.", + algorithm.getattr(pyo3::intern!(py, "name"))? + ), + exceptions::Reasons::UNSUPPORTED_HASH, + )), + )) + })?; + } + ctx.set_rsa_padding(padding_enum).or_else(|_| { + Err(exceptions::UnsupportedAlgorithm::new_err(( + format!( + "{} is not supported for the RSA signature operation", + padding.getattr(pyo3::intern!(py, "name"))? + ), + exceptions::Reasons::UNSUPPORTED_PADDING, + ))) + })?; + + if padding_enum == openssl::rsa::Padding::PKCS1_PSS { + let salt = padding.getattr(pyo3::intern!(py, "_salt_length"))?; + if salt.is_instance(max_length_class)? { + ctx.set_rsa_pss_saltlen(openssl::sign::RsaPssSaltlen::MAXIMUM_LENGTH)?; + } else if salt.is_instance(digest_length_class)? { + ctx.set_rsa_pss_saltlen(openssl::sign::RsaPssSaltlen::DIGEST_LENGTH)?; + } else if salt.is_instance(auto_class)? { + if is_signing { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "PSS salt length can only be set to Auto when verifying", + ), + )); + } + } else { + ctx.set_rsa_pss_saltlen(openssl::sign::RsaPssSaltlen::custom(salt.extract::()?))?; + }; + + let mgf1_md = hashes::message_digest_from_algorithm( + py, + padding + .getattr(pyo3::intern!(py, "_mgf"))? + .getattr(pyo3::intern!(py, "_algorithm"))?, + )?; + ctx.set_rsa_mgf1_md(openssl::md::Md::from_nid(mgf1_md.type_()).unwrap())?; + } + + Ok(()) +} + +#[pyo3::prelude::pymethods] +impl RsaPrivateKey { + fn sign<'p>( + &self, + py: pyo3::Python<'p>, + data: &[u8], + padding: &pyo3::PyAny, + algorithm: &pyo3::PyAny, + ) -> CryptographyResult<&'p pyo3::PyAny> { + let (data, algorithm): (&[u8], &pyo3::PyAny) = py + .import(pyo3::intern!( + py, + "cryptography.hazmat.backends.openssl.utils" + ))? + .call_method1( + pyo3::intern!(py, "_calculate_digest_and_algorithm"), + (data, algorithm), + )? + .extract()?; + + let mut ctx = openssl::pkey_ctx::PkeyCtx::new(&self.pkey)?; + ctx.sign_init().map_err(|_| { + pyo3::exceptions::PyValueError::new_err("Unable to sign/verify with this key") + })?; + setup_signature_ctx(py, &mut ctx, padding, algorithm, self.pkey.size(), true)?; + + let length = ctx.sign(data, None)?; + Ok(pyo3::types::PyBytes::new_with(py, length, |b| { + ctx.sign(data, Some(b)).map_err(|_| { + pyo3::exceptions::PyValueError::new_err( + "Digest or salt length too long for key size. Use a larger key or shorter salt length if you are specifying a PSS salt", + ) + })?; + Ok(()) + })?) + } + + fn decrypt<'p>( + &self, + py: pyo3::Python<'p>, + ciphertext: &[u8], + padding: &pyo3::PyAny, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let key_size_bytes = + usize::try_from((self.pkey.rsa().unwrap().n().num_bits() + 7) / 8).unwrap(); + if key_size_bytes != ciphertext.len() { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "Ciphertext length must be equal to key size.", + ), + )); + } + + let mut ctx = openssl::pkey_ctx::PkeyCtx::new(&self.pkey)?; + ctx.decrypt_init()?; + + setup_encryption_ctx(py, &mut ctx, padding)?; + + // TODO: avoid copy + let mut plaintext = vec![]; + // XXX: reproduce the constant time handling from Python + ctx.decrypt_to_vec(ciphertext, &mut plaintext) + .map_err(|_| pyo3::exceptions::PyValueError::new_err("Decryption failed"))?; + Ok(pyo3::types::PyBytes::new(py, &plaintext)) + } + + #[getter] + fn key_size(&self) -> i32 { + self.pkey.rsa().unwrap().n().num_bits() + } + + fn public_key(&self) -> CryptographyResult { + let priv_rsa = self.pkey.rsa().unwrap(); + let rsa = openssl::rsa::Rsa::from_public_components( + priv_rsa.n().to_owned()?, + priv_rsa.e().to_owned()?, + ) + .unwrap(); + let pkey = openssl::pkey::PKey::from_rsa(rsa)?; + Ok(RsaPublicKey { pkey }) + } + + fn private_numbers<'p>(&self, py: pyo3::Python<'p>) -> CryptographyResult<&'p pyo3::PyAny> { + let rsa = self.pkey.rsa().unwrap(); + + let py_p = utils::bn_to_py_int(py, rsa.p().unwrap())?; + let py_q = utils::bn_to_py_int(py, rsa.q().unwrap())?; + let py_d = utils::bn_to_py_int(py, rsa.d())?; + let py_dmp1 = utils::bn_to_py_int(py, rsa.dmp1().unwrap())?; + let py_dmq1 = utils::bn_to_py_int(py, rsa.dmq1().unwrap())?; + let py_iqmp = utils::bn_to_py_int(py, rsa.iqmp().unwrap())?; + let py_e = utils::bn_to_py_int(py, rsa.e())?; + let py_n = utils::bn_to_py_int(py, rsa.n())?; + + let rsa_mod = py.import(pyo3::intern!( + py, + "cryptography.hazmat.primitives.asymmetric.rsa" + ))?; + + let public_numbers = + rsa_mod.call_method1(pyo3::intern!(py, "RSAPublicNumbers"), (py_e, py_n))?; + Ok(rsa_mod.call_method1( + pyo3::intern!(py, "RSAPrivateNumbers"), + (py_p, py_q, py_d, py_dmp1, py_dmq1, py_iqmp, public_numbers), + )?) + } + + fn private_bytes<'p>( + slf: &pyo3::PyCell, + py: pyo3::Python<'p>, + encoding: &pyo3::PyAny, + format: &pyo3::PyAny, + encryption_algorithm: &pyo3::PyAny, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + utils::pkey_private_bytes( + py, + slf, + &slf.borrow().pkey, + encoding, + format, + encryption_algorithm, + true, + false, + ) + } +} + +#[pyo3::prelude::pymethods] +impl RsaPublicKey { + fn verify( + &self, + py: pyo3::Python<'_>, + signature: &[u8], + data: &[u8], + padding: &pyo3::PyAny, + algorithm: &pyo3::PyAny, + ) -> CryptographyResult<()> { + let (data, algorithm): (&[u8], &pyo3::PyAny) = py + .import(pyo3::intern!( + py, + "cryptography.hazmat.backends.openssl.utils" + ))? + .call_method1( + pyo3::intern!(py, "_calculate_digest_and_algorithm"), + (data, algorithm), + )? + .extract()?; + + let mut ctx = openssl::pkey_ctx::PkeyCtx::new(&self.pkey)?; + ctx.verify_init()?; + setup_signature_ctx(py, &mut ctx, padding, algorithm, self.pkey.size(), false)?; + + let valid = ctx.verify(data, signature).unwrap_or(false); + if !valid { + return Err(CryptographyError::from( + exceptions::InvalidSignature::new_err(()), + )); + } + + Ok(()) + } + + fn encrypt<'p>( + &self, + py: pyo3::Python<'p>, + plaintext: &[u8], + padding: &pyo3::PyAny, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let mut ctx = openssl::pkey_ctx::PkeyCtx::new(&self.pkey)?; + ctx.encrypt_init()?; + + setup_encryption_ctx(py, &mut ctx, padding)?; + + // TODO: avoid copy + let mut ciphertext = vec![]; + ctx.encrypt_to_vec(plaintext, &mut ciphertext) + .map_err(|_| pyo3::exceptions::PyValueError::new_err("Encryption failed"))?; + Ok(pyo3::types::PyBytes::new(py, &ciphertext)) + } + + fn recover_data_from_signature<'p>( + &self, + py: pyo3::Python<'p>, + signature: &[u8], + padding: &pyo3::PyAny, + algorithm: &pyo3::PyAny, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let prehashed_class = py + .import(pyo3::intern!( + py, + "cryptography.hazmat.primitives.asymmetric.utils" + ))? + .getattr(pyo3::intern!(py, "Prehashed"))?; + + if algorithm.is_instance(prehashed_class)? { + return Err(CryptographyError::from( + pyo3::exceptions::PyTypeError::new_err( + "Prehashed is only supported in the sign and verify methods. It cannot be used with recover_data_from_signature.", + ), + )); + } + + // TODO: constant time stuff + let mut ctx = openssl::pkey_ctx::PkeyCtx::new(&self.pkey)?; + ctx.verify_recover_init()?; + setup_signature_ctx(py, &mut ctx, padding, algorithm, self.pkey.size(), false)?; + + let length = ctx.verify_recover(signature, None)?; + let mut buf = vec![0u8; length]; + let length = ctx + .verify_recover(signature, Some(&mut buf)) + .map_err(|_| exceptions::InvalidSignature::new_err(()))?; + + Ok(pyo3::types::PyBytes::new(py, &buf[..length])) + } + + #[getter] + fn key_size(&self) -> i32 { + self.pkey.rsa().unwrap().n().num_bits() + } + + fn public_numbers<'p>(&self, py: pyo3::Python<'p>) -> CryptographyResult<&'p pyo3::PyAny> { + let rsa = self.pkey.rsa().unwrap(); + + let py_e = utils::bn_to_py_int(py, rsa.e())?; + let py_n = utils::bn_to_py_int(py, rsa.n())?; + + let rsa_mod = py.import(pyo3::intern!( + py, + "cryptography.hazmat.primitives.asymmetric.rsa" + ))?; + + Ok(rsa_mod.call_method1(pyo3::intern!(py, "RSAPublicNumbers"), (py_e, py_n))?) + } + + fn public_bytes<'p>( + slf: &pyo3::PyCell, + py: pyo3::Python<'p>, + encoding: &pyo3::PyAny, + format: &pyo3::PyAny, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + utils::pkey_public_bytes(py, slf, &slf.borrow().pkey, encoding, format, true, false) + } + + fn __richcmp__( + &self, + other: pyo3::PyRef<'_, RsaPublicKey>, + op: pyo3::basic::CompareOp, + ) -> pyo3::PyResult { + match op { + pyo3::basic::CompareOp::Eq => Ok(self.pkey.public_eq(&other.pkey)), + pyo3::basic::CompareOp::Ne => Ok(!self.pkey.public_eq(&other.pkey)), + _ => Err(pyo3::exceptions::PyTypeError::new_err("Cannot be ordered")), + } + } +} + +pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> { + let m = pyo3::prelude::PyModule::new(py, "rsa")?; + m.add_function(pyo3::wrap_pyfunction!(private_key_from_ptr, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(public_key_from_ptr, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(generate_private_key, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(from_private_numbers, m)?)?; + m.add_function(pyo3::wrap_pyfunction!(from_public_numbers, m)?)?; + + m.add_class::()?; + m.add_class::()?; + + Ok(m) +} diff --git a/src/rust/src/backend/utils.rs b/src/rust/src/backend/utils.rs index 086f88ab93608..a2679cddedcf8 100644 --- a/src/rust/src/backend/utils.rs +++ b/src/rust/src/backend/utils.rs @@ -163,7 +163,30 @@ pub(crate) fn pkey_private_bytes<'p>( } if format.is(private_format_class.getattr(pyo3::intern!(py, "TraditionalOpenSSL"))?) { - if let Ok(dsa) = pkey.dsa() { + if let Ok(rsa) = pkey.rsa() { + if encoding.is(encoding_class.getattr(pyo3::intern!(py, "PEM"))?) { + let pem_bytes = if password.is_empty() { + rsa.private_key_to_pem()? + } else { + rsa.private_key_to_pem_passphrase( + openssl::symm::Cipher::aes_256_cbc(), + password, + )? + }; + return Ok(pyo3::types::PyBytes::new(py, &pem_bytes)); + } else if encoding.is(encoding_class.getattr(pyo3::intern!(py, "DER"))?) { + if !password.is_empty() { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "Encryption is not supported for DER encoded traditional OpenSSL keys", + ), + )); + } + + let der_bytes = rsa.private_key_to_der()?; + return Ok(pyo3::types::PyBytes::new(py, &der_bytes)); + } + } else if let Ok(dsa) = pkey.dsa() { if encoding.is(encoding_class.getattr(pyo3::intern!(py, "PEM"))?) { let pem_bytes = if password.is_empty() { dsa.private_key_to_pem()? @@ -332,6 +355,23 @@ pub(crate) fn pkey_public_bytes<'p>( } } + if let Ok(rsa) = pkey.rsa() { + if format.is(public_format_class.getattr(pyo3::intern!(py, "PKCS1"))?) { + if encoding.is(encoding_class.getattr(pyo3::intern!(py, "PEM"))?) { + let pem_bytes = rsa.public_key_to_pem_pkcs1()?; + return Ok(pyo3::types::PyBytes::new(py, &pem_bytes)); + } else if encoding.is(encoding_class.getattr(pyo3::intern!(py, "DER"))?) { + let der_bytes = rsa.public_key_to_der_pkcs1()?; + return Ok(pyo3::types::PyBytes::new(py, &der_bytes)); + } + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "PKCS1 works only with PEM or DER encoding", + ), + )); + } + } + // OpenSSH + OpenSSH if openssh_allowed && format.is(public_format_class.getattr(pyo3::intern!(py, "OpenSSH"))?) { if encoding.is(encoding_class.getattr(pyo3::intern!(py, "OpenSSH"))?) { diff --git a/tests/hazmat/primitives/test_rsa.py b/tests/hazmat/primitives/test_rsa.py index 3cb3b17efb228..e8c3ff3cdb681 100644 --- a/tests/hazmat/primitives/test_rsa.py +++ b/tests/hazmat/primitives/test_rsa.py @@ -845,6 +845,21 @@ def test_unsupported_hash(self, rsa_key_512: rsa.RSAPrivateKey, backend): with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_HASH): private_key.sign(message, pss, hashes.BLAKE2s(32)) + @pytest.mark.supported( + only_if=lambda backend: backend.rsa_padding_supported( + padding.PSS(mgf=padding.MGF1(hashes.SHA256()), salt_length=0) + ), + skip_message="Does not support PSS.", + ) + def test_unsupported_hash_pss_mgf1(self, rsa_key_2048: rsa.RSAPrivateKey): + private_key = rsa_key_2048 + message = b"my message" + pss = padding.PSS( + mgf=padding.MGF1(DummyHashAlgorithm()), salt_length=0 + ) + with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_HASH): + private_key.sign(message, pss, hashes.SHA256()) + @pytest.mark.supported( only_if=lambda backend: backend.rsa_padding_supported( padding.PSS(mgf=padding.MGF1(hashes.SHA1()), salt_length=0) @@ -1937,6 +1952,27 @@ def test_invalid_oaep_decryption_data_to_large_for_modulus(self, backend): ), ) + def test_unsupported_oaep_hash(self, rsa_key_2048: rsa.RSAPrivateKey): + private_key = rsa_key_2048 + with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_HASH): + private_key.decrypt( + b"0" * 256, + padding.OAEP( + mgf=padding.MGF1(DummyHashAlgorithm()), + algorithm=hashes.SHA256(), + label=None, + ), + ) + with raises_unsupported_algorithm(_Reasons.UNSUPPORTED_HASH): + private_key.decrypt( + b"0" * 256, + padding.OAEP( + mgf=padding.MGF1(hashes.SHA256()), + algorithm=DummyHashAlgorithm(), + label=None, + ), + ) + def test_unsupported_oaep_mgf( self, rsa_key_2048: rsa.RSAPrivateKey, backend ): @@ -2734,3 +2770,5 @@ def test_public_key_equality(self, rsa_key_2048: rsa.RSAPrivateKey): assert key1 == key2 assert key1 != key3 assert key1 != object() + with pytest.raises(TypeError): + key1 < key2 # type: ignore[operator]