Skip to content

Commit

Permalink
Resolve Codeql complaints (#215)
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarprudnikov authored Aug 28, 2024
1 parent 5ac6f3a commit 359a78e
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 50 deletions.
75 changes: 37 additions & 38 deletions pyscitt/pyscitt/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from uuid import uuid4

warnings.filterwarnings("ignore", category=Warning)
Expand Down Expand Up @@ -69,12 +69,6 @@

RECOMMENDED_RSA_PUBLIC_EXPONENT = 65537

REGISTERED_EC_CURVES = {
"P-256": P256,
"P-384": P384,
"P-521": P521,
}

Pem = str

COSE_HEADER_PARAM_ISSUER = 391
Expand All @@ -90,12 +84,36 @@

RegistrationInfoValue = Union[str, bytes, int]
RegistrationInfo = Dict[str, RegistrationInfoValue]
CoseCurveTypes = Union[Type[P256], Type[P384], Type[P521]]
CoseCurveType = Tuple[str, CoseCurveTypes]


def ec_curve_from_name(name: str) -> EllipticCurve:
if name == "P-256":
return ec.SECP256R1()
elif name == "P-384":
return ec.SECP384R1()
elif name == "P-521":
return ec.SECP521R1()
else:
raise ValueError(f"Unsupported EC curve: {name}")


def cose_curve_from_ec(curve: EllipticCurve) -> CoseCurveType:
if isinstance(curve, ec.SECP256R1):
return ("P-256", P256)
elif isinstance(curve, ec.SECP384R1):
return ("P-384", P384)
elif isinstance(curve, ec.SECP521R1):
return ("P-521", P521)
else:
raise ValueError(f"Unsupported EC curve: {curve}")


def generate_rsa_keypair(key_size: int) -> Tuple[Pem, Pem]:
def generate_rsa_keypair() -> Tuple[Pem, Pem]:
priv = rsa.generate_private_key(
public_exponent=RECOMMENDED_RSA_PUBLIC_EXPONENT,
key_size=key_size,
key_size=2048,
)
pub = priv.public_key()
priv_pem = priv.private_bytes(
Expand All @@ -107,12 +125,8 @@ def generate_rsa_keypair(key_size: int) -> Tuple[Pem, Pem]:
return priv_pem, pub_pem


def generate_ec_keypair(curve: str) -> Tuple[Pem, Pem]:
if curve not in REGISTERED_EC_CURVES:
raise NotImplementedError(f"Unsupported curve: {curve}")
curve_obj = REGISTERED_EC_CURVES[curve].curve_obj
assert isinstance(curve_obj, EllipticCurve)
priv = ec.generate_private_key(curve=curve_obj)
def generate_ec_keypair(curve_name: str) -> Tuple[Pem, Pem]:
priv = ec.generate_private_key(curve=ec_curve_from_name(curve_name))
pub = priv.public_key()
priv_pem = priv.private_bytes(
Encoding.PEM, PrivateFormat.PKCS8, NoEncryption()
Expand Down Expand Up @@ -140,11 +154,10 @@ def generate_ed25519_keypair() -> Tuple[Pem, Pem]:
def generate_keypair(
kty: str,
*,
rsa_key_size: Optional[int] = None,
ec_curve: Optional[str] = None,
) -> Tuple[str, str]:
if kty == "rsa":
return generate_rsa_keypair(rsa_key_size or 2048)
return generate_rsa_keypair()
elif kty == "ec":
return generate_ec_keypair(ec_curve or "P-256")
elif kty == "ed25519":
Expand Down Expand Up @@ -492,13 +505,7 @@ def from_cryptography_eckey_obj(
priv_nums = None
pub_nums = ext_key.public_numbers()

# Create map of cryptography curves to cose curves. E.g. {ec.SECP256R1: P256, ...}
registered_crvs = {
type(crv.curve_obj): crv for crv in REGISTERED_EC_CURVES.values()
}
if type(pub_nums.curve) not in registered_crvs:
raise ValueError(f"Unsupported EC Curve: {type(pub_nums.curve)}")
curve = registered_crvs[type(pub_nums.curve)]
_, curve = cose_curve_from_ec(pub_nums.curve)

cose_key = {}
if pub_nums:
Expand Down Expand Up @@ -616,7 +623,7 @@ def get_last_embedded_receipt_from_cose(buf: bytes) -> Union[bytes, None]:


def load_private_key(key_path: Path) -> Pem:
with open(key_path) as f:
with open(key_path, encoding="utf-8") as f:
key_priv_pem = f.read()
if is_ssh_private_key(key_priv_pem):
key_priv_pem = ssh_private_key_to_pem(key_priv_pem)
Expand Down Expand Up @@ -654,16 +661,9 @@ def encode_pub_num_jwk(dec):
elif isinstance(pub_key, EllipticCurvePublicKey):
pub_numbers = pub_key.public_numbers()
curve = pub_numbers.curve
# Create map of curves to names. E.g. {ec.SECP256R1: "P-256", ...}
registered_crvs = {
type(crv.curve_obj): name for name, crv in REGISTERED_EC_CURVES.items()
}
if type(curve) not in registered_crvs:
raise ValueError(f"Unsupported EC Curve: {curve}")
crv_name = registered_crvs[type(curve)]
x = pub_numbers.x.to_bytes(REGISTERED_EC_CURVES[crv_name].size, "big")
y = pub_numbers.y.to_bytes(REGISTERED_EC_CURVES[crv_name].size, "big")

crv_name, crv = cose_curve_from_ec(curve)
x = pub_numbers.x.to_bytes(crv.size, "big")
y = pub_numbers.y.to_bytes(crv.size, "big")
jwk = {
"kty": "EC",
"crv": crv_name,
Expand Down Expand Up @@ -715,7 +715,7 @@ def sign_claimset(
claims: bytes,
content_type: str,
feed: Optional[str] = None,
registration_info: RegistrationInfo = {},
registration_info: Optional[RegistrationInfo] = None,
svn: Optional[int] = None,
cwt: bool = False,
) -> bytes:
Expand Down Expand Up @@ -804,8 +804,7 @@ def convert_jwk_to_pem(jwk: dict) -> Pem:
if jwk.get("kty") == "EC":
x = int.from_bytes(base64.urlsafe_b64decode(jwk["x"]), "big")
y = int.from_bytes(base64.urlsafe_b64decode(jwk["y"]), "big")
crv = REGISTERED_EC_CURVES[jwk["crv"]].curve_obj
assert isinstance(crv, EllipticCurve)
crv = ec_curve_from_name(jwk["crv"])
key = EllipticCurvePublicNumbers(x, y, crv).public_key()
else:
raise NotImplementedError("Unsupported JWK type")
Expand Down
16 changes: 9 additions & 7 deletions pyscitt/pyscitt/key_vault_sign_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,6 @@

from . import crypto

ALGORITHMS = {
256: ("ES256", "sha256"),
384: ("ES384", "sha384"),
521: ("ES512", "sha384"),
}


class KeyVaultSignClient(MemberAuthenticationMethod):
"""MemberIdentity implementation that uses Azure Key Vault."""
Expand Down Expand Up @@ -120,7 +114,15 @@ def http_sign(self, data: bytes):
pub_key = cert.public_key()
assert isinstance(pub_key, (EllipticCurvePublicKey))
key_size = pub_key.curve.key_size
signature_algorithm, hash_algorithm = ALGORITHMS[key_size]

if key_size == 256:
signature_algorithm, hash_algorithm = ("ES256", "sha256")
elif key_size == 384:
signature_algorithm, hash_algorithm = ("ES384", "sha384")
elif key_size == 521:
signature_algorithm, hash_algorithm = ("ES512", "sha512")
else:
raise ValueError(f"Unsupported EC size: {key_size}")

digest_to_sign = hashlib.new(hash_algorithm, data).digest()
sign_result = crypto_client.sign(
Expand Down
2 changes: 1 addition & 1 deletion test/infra/did_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def do_GET(handler_self):
self.port = self.httpd.server_address[1]
self.base_url = f"https://{self.host}:{self.port}"

tls_key_pem, _ = crypto.generate_rsa_keypair(2048)
tls_key_pem, _ = crypto.generate_rsa_keypair()
self.tls_cert_pem = crypto.generate_cert(tls_key_pem, cn=host)

context = _create_tls_context(self.tls_cert_pem, tls_key_pem)
Expand Down
2 changes: 1 addition & 1 deletion test/infra/jwt_issuer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
class JwtIssuer:
def __init__(self, name="example.com"):
self.name = name
self.key, _ = crypto.generate_rsa_keypair(2048)
self.key, _ = crypto.generate_rsa_keypair()
self.cert = crypto.generate_cert(self.key, cn=name)
self.key_id = crypto.get_cert_fingerprint(self.cert)

Expand Down
6 changes: 3 additions & 3 deletions test/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def test_local_development(run, service_url, tmp_path: Path):


def test_create_ssh_did_web(run, tmp_path: Path):
private_key, public_key = crypto.generate_rsa_keypair(2048)
private_key, public_key = crypto.generate_rsa_keypair()
ssh_private_key = crypto.private_key_pem_to_ssh(private_key)
ssh_public_key = crypto.pub_key_pem_to_ssh(public_key)

Expand Down Expand Up @@ -401,7 +401,7 @@ def test_create_ssh_did_web(run, tmp_path: Path):


def test_adhoc_signer(run, tmp_path: Path):
private_key, public_key = crypto.generate_rsa_keypair(2048)
private_key, public_key = crypto.generate_rsa_keypair()
(tmp_path / "key.pem").write_text(private_key)
(tmp_path / "key_pub.pem").write_text(public_key)
(tmp_path / "claims.json").write_text(json.dumps({"foo": "bar"}))
Expand Down Expand Up @@ -519,7 +519,7 @@ def test_prefix_tree(run, tmp_path: Path):


def test_registration_info(run, tmp_path: Path):
private_key, public_key = crypto.generate_rsa_keypair(2048)
private_key, public_key = crypto.generate_rsa_keypair()
(tmp_path / "key.pem").write_text(private_key)
(tmp_path / "claims.json").write_text(json.dumps({"foo": "bar"}))

Expand Down

0 comments on commit 359a78e

Please sign in to comment.