Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate EC support to Rust #9024

Merged
merged 1 commit into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 17 additions & 244 deletions src/cryptography/hazmat/backends/openssl/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,12 @@
import contextlib
import itertools
import typing
from contextlib import contextmanager

from cryptography import utils, x509
from cryptography.exceptions import UnsupportedAlgorithm, _Reasons
from cryptography.hazmat.backends.openssl import aead
from cryptography.hazmat.backends.openssl.ciphers import _CipherContext
from cryptography.hazmat.backends.openssl.cmac import _CMACContext
from cryptography.hazmat.backends.openssl.ec import (
_EllipticCurvePrivateKey,
_EllipticCurvePublicKey,
)
from cryptography.hazmat.backends.openssl.rsa import (
_RSAPrivateKey,
_RSAPublicKey,
Expand Down Expand Up @@ -542,10 +537,9 @@ def _evp_pkey_to_private_key(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type == self._lib.EVP_PKEY_EC:
ec_cdata = self._lib.EVP_PKEY_get1_EC_KEY(evp_pkey)
self.openssl_assert(ec_cdata != self._ffi.NULL)
ec_cdata = self._ffi.gc(ec_cdata, self._lib.EC_KEY_free)
return _EllipticCurvePrivateKey(self, ec_cdata, evp_pkey)
return rust_openssl.ec.private_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type in self._dh_types:
return rust_openssl.dh.private_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
Expand Down Expand Up @@ -603,12 +597,9 @@ def _evp_pkey_to_public_key(self, evp_pkey) -> PublicKeyTypes:
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type == self._lib.EVP_PKEY_EC:
ec_cdata = self._lib.EVP_PKEY_get1_EC_KEY(evp_pkey)
if ec_cdata == self._ffi.NULL:
errors = self._consume_errors()
raise ValueError("Unable to load EC key", errors)
ec_cdata = self._ffi.gc(ec_cdata, self._lib.EC_KEY_free)
return _EllipticCurvePublicKey(self, ec_cdata, evp_pkey)
return rust_openssl.ec.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type in self._dh_types:
return rust_openssl.dh.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
Expand Down Expand Up @@ -944,20 +935,7 @@ def elliptic_curve_supported(self, curve: ec.EllipticCurve) -> bool:
):
return False

try:
curve_nid = self._elliptic_curve_to_nid(curve)
except UnsupportedAlgorithm:
curve_nid = self._lib.NID_undef

group = self._lib.EC_GROUP_new_by_curve_name(curve_nid)

if group == self._ffi.NULL:
self._consume_errors()
return False
else:
self.openssl_assert(curve_nid != self._lib.NID_undef)
self._lib.EC_GROUP_free(group)
return True
return rust_openssl.ec.curve_supported(curve)

def elliptic_curve_signature_algorithm_supported(
self,
Expand All @@ -979,158 +957,27 @@ def generate_elliptic_curve_private_key(
"""
Generate a new private key on the named curve.
"""

if self.elliptic_curve_supported(curve):
ec_cdata = self._ec_key_new_by_curve(curve)

res = self._lib.EC_KEY_generate_key(ec_cdata)
self.openssl_assert(res == 1)

evp_pkey = self._ec_cdata_to_evp_pkey(ec_cdata)

return _EllipticCurvePrivateKey(self, ec_cdata, evp_pkey)
else:
raise UnsupportedAlgorithm(
f"Backend object does not support {curve.name}.",
_Reasons.UNSUPPORTED_ELLIPTIC_CURVE,
)
return rust_openssl.ec.generate_private_key(curve)

def load_elliptic_curve_private_numbers(
self, numbers: ec.EllipticCurvePrivateNumbers
) -> ec.EllipticCurvePrivateKey:
public = numbers.public_numbers

ec_cdata = self._ec_key_new_by_curve(public.curve)

private_value = self._ffi.gc(
self._int_to_bn(numbers.private_value), self._lib.BN_clear_free
)
res = self._lib.EC_KEY_set_private_key(ec_cdata, private_value)
if res != 1:
self._consume_errors()
raise ValueError("Invalid EC key.")

with self._tmp_bn_ctx() as bn_ctx:
self._ec_key_set_public_key_affine_coordinates(
ec_cdata, public.x, public.y, bn_ctx
)
# derive the expected public point and compare it to the one we
# just set based on the values we were given. If they don't match
# this isn't a valid key pair.
group = self._lib.EC_KEY_get0_group(ec_cdata)
self.openssl_assert(group != self._ffi.NULL)
set_point = backend._lib.EC_KEY_get0_public_key(ec_cdata)
self.openssl_assert(set_point != self._ffi.NULL)
computed_point = self._lib.EC_POINT_new(group)
self.openssl_assert(computed_point != self._ffi.NULL)
computed_point = self._ffi.gc(
computed_point, self._lib.EC_POINT_free
)
res = self._lib.EC_POINT_mul(
group,
computed_point,
private_value,
self._ffi.NULL,
self._ffi.NULL,
bn_ctx,
)
self.openssl_assert(res == 1)
if (
self._lib.EC_POINT_cmp(
group, set_point, computed_point, bn_ctx
)
!= 0
):
raise ValueError("Invalid EC key.")

evp_pkey = self._ec_cdata_to_evp_pkey(ec_cdata)

return _EllipticCurvePrivateKey(self, ec_cdata, evp_pkey)
return rust_openssl.ec.from_private_numbers(numbers)

def load_elliptic_curve_public_numbers(
self, numbers: ec.EllipticCurvePublicNumbers
) -> ec.EllipticCurvePublicKey:
ec_cdata = self._ec_key_new_by_curve(numbers.curve)
with self._tmp_bn_ctx() as bn_ctx:
self._ec_key_set_public_key_affine_coordinates(
ec_cdata, numbers.x, numbers.y, bn_ctx
)
evp_pkey = self._ec_cdata_to_evp_pkey(ec_cdata)

return _EllipticCurvePublicKey(self, ec_cdata, evp_pkey)
return rust_openssl.ec.from_public_numbers(numbers)

def load_elliptic_curve_public_bytes(
self, curve: ec.EllipticCurve, point_bytes: bytes
) -> ec.EllipticCurvePublicKey:
ec_cdata = self._ec_key_new_by_curve(curve)
group = self._lib.EC_KEY_get0_group(ec_cdata)
self.openssl_assert(group != self._ffi.NULL)
point = self._lib.EC_POINT_new(group)
self.openssl_assert(point != self._ffi.NULL)
point = self._ffi.gc(point, self._lib.EC_POINT_free)
with self._tmp_bn_ctx() as bn_ctx:
res = self._lib.EC_POINT_oct2point(
group, point, point_bytes, len(point_bytes), bn_ctx
)
if res != 1:
self._consume_errors()
raise ValueError("Invalid public bytes for the given curve")

res = self._lib.EC_KEY_set_public_key(ec_cdata, point)
self.openssl_assert(res == 1)
evp_pkey = self._ec_cdata_to_evp_pkey(ec_cdata)
return _EllipticCurvePublicKey(self, ec_cdata, evp_pkey)
return rust_openssl.ec.from_public_bytes(curve, point_bytes)

def derive_elliptic_curve_private_key(
self, private_value: int, curve: ec.EllipticCurve
) -> ec.EllipticCurvePrivateKey:
ec_cdata = self._ec_key_new_by_curve(curve)

group = self._lib.EC_KEY_get0_group(ec_cdata)
self.openssl_assert(group != self._ffi.NULL)

point = self._lib.EC_POINT_new(group)
self.openssl_assert(point != self._ffi.NULL)
point = self._ffi.gc(point, self._lib.EC_POINT_free)

value = self._int_to_bn(private_value)
value = self._ffi.gc(value, self._lib.BN_clear_free)

with self._tmp_bn_ctx() as bn_ctx:
res = self._lib.EC_POINT_mul(
group, point, value, self._ffi.NULL, self._ffi.NULL, bn_ctx
)
self.openssl_assert(res == 1)

bn_x = self._lib.BN_CTX_get(bn_ctx)
bn_y = self._lib.BN_CTX_get(bn_ctx)

res = self._lib.EC_POINT_get_affine_coordinates(
group, point, bn_x, bn_y, bn_ctx
)
if res != 1:
self._consume_errors()
raise ValueError("Unable to derive key from private_value")

res = self._lib.EC_KEY_set_public_key(ec_cdata, point)
self.openssl_assert(res == 1)
private = self._int_to_bn(private_value)
private = self._ffi.gc(private, self._lib.BN_clear_free)
res = self._lib.EC_KEY_set_private_key(ec_cdata, private)
self.openssl_assert(res == 1)

evp_pkey = self._ec_cdata_to_evp_pkey(ec_cdata)

return _EllipticCurvePrivateKey(self, ec_cdata, evp_pkey)

def _ec_key_new_by_curve(self, curve: ec.EllipticCurve):
curve_nid = self._elliptic_curve_to_nid(curve)
return self._ec_key_new_by_curve_nid(curve_nid)

def _ec_key_new_by_curve_nid(self, curve_nid: int):
ec_cdata = self._lib.EC_KEY_new_by_curve_name(curve_nid)
self.openssl_assert(ec_cdata != self._ffi.NULL)
return self._ffi.gc(ec_cdata, self._lib.EC_KEY_free)
return rust_openssl.ec.derive_private_key(private_value, curve)

def elliptic_curve_exchange_algorithm_supported(
self, algorithm: ec.ECDH, curve: ec.EllipticCurve
Expand All @@ -1139,73 +986,6 @@ def elliptic_curve_exchange_algorithm_supported(
algorithm, ec.ECDH
)

def _ec_cdata_to_evp_pkey(self, ec_cdata):
evp_pkey = self._create_evp_pkey_gc()
res = self._lib.EVP_PKEY_set1_EC_KEY(evp_pkey, ec_cdata)
self.openssl_assert(res == 1)
return evp_pkey

def _elliptic_curve_to_nid(self, curve: ec.EllipticCurve) -> int:
"""
Get the NID for a curve name.
"""

curve_aliases = {"secp192r1": "prime192v1", "secp256r1": "prime256v1"}

curve_name = curve_aliases.get(curve.name, curve.name)

curve_nid = self._lib.OBJ_sn2nid(curve_name.encode())
if curve_nid == self._lib.NID_undef:
raise UnsupportedAlgorithm(
f"{curve.name} is not a supported elliptic curve",
_Reasons.UNSUPPORTED_ELLIPTIC_CURVE,
)
return curve_nid

@contextmanager
def _tmp_bn_ctx(self):
bn_ctx = self._lib.BN_CTX_new()
self.openssl_assert(bn_ctx != self._ffi.NULL)
bn_ctx = self._ffi.gc(bn_ctx, self._lib.BN_CTX_free)
self._lib.BN_CTX_start(bn_ctx)
try:
yield bn_ctx
finally:
self._lib.BN_CTX_end(bn_ctx)

def _ec_key_set_public_key_affine_coordinates(
self,
ec_cdata,
x: int,
y: int,
bn_ctx,
) -> None:
"""
Sets the public key point in the EC_KEY context to the affine x and y
values.
"""

if x < 0 or y < 0:
raise ValueError(
"Invalid EC key. Both x and y must be non-negative."
)

x = self._ffi.gc(self._int_to_bn(x), self._lib.BN_free)
y = self._ffi.gc(self._int_to_bn(y), self._lib.BN_free)
group = self._lib.EC_KEY_get0_group(ec_cdata)
self.openssl_assert(group != self._ffi.NULL)
point = self._lib.EC_POINT_new(group)
self.openssl_assert(point != self._ffi.NULL)
point = self._ffi.gc(point, self._lib.EC_POINT_free)
res = self._lib.EC_POINT_set_affine_coordinates(
group, point, x, y, bn_ctx
)
if res != 1:
self._consume_errors()
raise ValueError("Invalid EC key.")
res = self._lib.EC_KEY_set_public_key(ec_cdata, point)
self.openssl_assert(res == 1)

def _private_key_bytes(
self,
encoding: serialization.Encoding,
Expand Down Expand Up @@ -1278,11 +1058,8 @@ def _private_key_bytes(
key_type = self._lib.EVP_PKEY_id(evp_pkey)

if encoding is serialization.Encoding.PEM:
if key_type == self._lib.EVP_PKEY_RSA:
write_bio = self._lib.PEM_write_bio_RSAPrivateKey
else:
assert key_type == self._lib.EVP_PKEY_EC
write_bio = self._lib.PEM_write_bio_ECPrivateKey
assert key_type == self._lib.EVP_PKEY_RSA
write_bio = self._lib.PEM_write_bio_RSAPrivateKey
return self._private_key_bytes_via_bio(
write_bio, cdata, password
)
Expand All @@ -1293,11 +1070,8 @@ def _private_key_bytes(
"Encryption is not supported for DER encoded "
"traditional OpenSSL keys"
)
if key_type == self._lib.EVP_PKEY_RSA:
write_bio = self._lib.i2d_RSAPrivateKey_bio
else:
assert key_type == self._lib.EVP_PKEY_EC
write_bio = self._lib.i2d_ECPrivateKey_bio
assert key_type == self._lib.EVP_PKEY_RSA
write_bio = self._lib.i2d_RSAPrivateKey_bio
return self._bio_func_output(write_bio, cdata)

raise ValueError("Unsupported encoding for TraditionalOpenSSL")
Expand Down Expand Up @@ -1374,8 +1148,7 @@ def _public_key_bytes(
if format is serialization.PublicFormat.PKCS1:
# Only RSA is supported here.
key_type = self._lib.EVP_PKEY_id(evp_pkey)
if key_type != self._lib.EVP_PKEY_RSA:
raise ValueError("PKCS1 format is supported only for RSA keys")
self.openssl_assert(key_type == self._lib.EVP_PKEY_RSA)

if encoding is serialization.Encoding.PEM:
write_bio = self._lib.PEM_write_bio_RSAPublicKey
Expand Down
Loading
Loading