From dfc68d18b974af17b56542d6b0955832ee26fc34 Mon Sep 17 00:00:00 2001 From: saltiyazan Date: Thu, 10 Oct 2024 19:50:52 +0400 Subject: [PATCH] chore: Use tls lib V4.0 (#515) --- .../v4/tls_certificates.py | 167 +++++------------- lib/charms/vault_k8s/v0/vault_tls.py | 6 +- src/charm.py | 10 +- tests/unit/certificates.py | 8 +- .../lib/charms/vault_k8s/v0/test_vault_tls.py | 4 +- 5 files changed, 56 insertions(+), 139 deletions(-) diff --git a/lib/charms/tls_certificates_interface/v4/tls_certificates.py b/lib/charms/tls_certificates_interface/v4/tls_certificates.py index c6c10fe3..0de221e2 100644 --- a/lib/charms/tls_certificates_interface/v4/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v4/tls_certificates.py @@ -1,12 +1,17 @@ # Copyright 2024 Canonical Ltd. # See LICENSE file for licensing details. -"""Charm library for managing TLS certificates (V4) - BETA. +"""Charm library for managing TLS certificates (V4). -> Warning: This is a beta version of the tls-certificates interface library. -> Use at your own risk. +This library contains the Requires and Provides classes for handling the tls-certificates +interface. -Learn how-to use the TLS Certificates interface library by reading the documentation: +Pre-requisites: + - Juju >= 3.0 + - cryptography >= 43.0.0 + - pydantic + +Learn more on how-to use the TLS Certificates interface library by reading the documentation: - https://charmhub.io/tls-certificates-interface/ """ # noqa: D214, D405, D411, D416 @@ -47,7 +52,7 @@ # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 9 +LIBPATCH = 0 PYDEPS = ["cryptography", "pydantic"] @@ -138,7 +143,6 @@ class _Certificate(BaseModel): certificate_signing_request: str certificate: str chain: Optional[List[str]] = None - recommended_expiry_notification_time: Optional[int] = None revoked: Optional[bool] = None def to_provider_certificate(self, relation_id: int) -> "ProviderCertificate": @@ -153,7 +157,6 @@ def to_provider_certificate(self, relation_id: int) -> "ProviderCertificate": chain=[Certificate.from_string(certificate) for certificate in self.chain] if self.chain else [], - recommended_expiry_notification_time=self.recommended_expiry_notification_time, revoked=self.revoked, ) @@ -215,6 +218,8 @@ class Certificate: raw: str common_name: str + expiry_time: datetime + validity_start_time: datetime is_ca: bool = False sans_dns: Optional[FrozenSet[str]] = frozenset() sans_ip: Optional[FrozenSet[str]] = frozenset() @@ -225,8 +230,6 @@ class Certificate: country_name: Optional[str] = None state_or_province_name: Optional[str] = None locality_name: Optional[str] = None - expiry_time: Optional[datetime] = None - validity_start_time: Optional[datetime] = None def __str__(self) -> str: """Return the certificate as a string.""" @@ -424,8 +427,8 @@ def get_sha256_hex(self) -> str: @dataclass(frozen=True) -class CertificateRequest: - """This class represents a certificate request. +class CertificateRequestAttributes: + """A representation of the certificate request attributes. This class should be used inside the requirer charm to specify the requested attributes for the certificate. @@ -477,7 +480,7 @@ def generate_csr( @classmethod def from_csr(cls, csr: CertificateSigningRequest, is_ca: bool): - """Create a CertificateRequest object from a CSR.""" + """Create a CertificateRequestAttributes object from a CSR.""" return cls( common_name=csr.common_name, sans_dns=csr.sans_dns, @@ -502,7 +505,6 @@ class ProviderCertificate: certificate_signing_request: CertificateSigningRequest ca: Certificate chain: List[Certificate] - recommended_expiry_notification_time: Optional[int] = None revoked: Optional[bool] = None def to_json(self) -> str: @@ -523,8 +525,8 @@ def to_json(self) -> str: @dataclass(frozen=True) -class RequirerCSR: - """This class represents a certificate signing request requested by the TLS requirer.""" +class RequirerCertificateRequest: + """This class represents a certificate signing request requested by a specific TLS requirer.""" relation_id: int certificate_signing_request: CertificateSigningRequest @@ -572,60 +574,6 @@ def chain_as_pem(self) -> str: return "\n\n".join([str(cert) for cert in self.chain]) -def _get_closest_future_time( - expiry_notification_time: datetime, expiry_time: datetime -) -> datetime: - """Return expiry_notification_time if not in the past, otherwise return expiry_time. - - Args: - expiry_notification_time (datetime): Notification time of impending expiration - expiry_time (datetime): Expiration time - - Returns: - datetime: expiry_notification_time if not in the past, expiry_time otherwise - """ - return ( - expiry_notification_time - if datetime.now(timezone.utc) < expiry_notification_time - else expiry_time - ) - - -def calculate_expiry_notification_time( - not_valid_before: datetime, - not_valid_after: datetime, - provider_recommended_notification_time: Optional[int] = None, -) -> datetime: - """Calculate a reasonable time to notify the user about the certificate expiry. - - It takes into account the time recommended by the provider. - Time recommended by the provider is preferred, - then dynamically calculated time. - - Args: - not_valid_before: Time when the certificate is valid from. - not_valid_after: Time when the certificate is valid until. - provider_recommended_notification_time: - Time in hours prior to expiry to notify the user. - Recommended by the provider. - - Returns: - datetime: Time to notify the user about the certificate expiry. - """ - if provider_recommended_notification_time is not None: - provider_recommended_notification_time = abs(provider_recommended_notification_time) - provider_recommendation_time_delta = not_valid_after - timedelta( - hours=provider_recommended_notification_time - ) - if not_valid_before < provider_recommendation_time_delta: - return provider_recommendation_time_delta - # Divide the time between not_valid_after and not_valid_before by 3 - # For example, if there are 3 days between not_valid_after and not_valid_before, - # the notification time will be 1 day before not_valid_after. - calculated_time = (not_valid_after - not_valid_before) / 3 - return not_valid_after - calculated_time - - def generate_private_key( key_size: int = 2048, public_exponent: int = 65537, @@ -996,7 +944,7 @@ def __init__( self, charm: CharmBase, relationship_name: str, - certificate_requests: List[CertificateRequest], + certificate_requests: List[CertificateRequestAttributes], mode: Mode = Mode.UNIT, refresh_events: List[BoundEvent] = [], ): @@ -1005,7 +953,8 @@ def __init__( Args: charm (CharmBase): The charm instance to relate to. relationship_name (str): The name of the relation that provides the certificates. - certificate_requests (List[CertificateRequest]): A list of certificate requests. + certificate_requests (List[CertificateRequestAttributes]): + A list with the attributes of the certificate requests. mode (Mode): Whether to use unit or app certificates mode. Default is Mode.UNIT. refresh_events (List[BoundEvent]): A list of events to trigger a refresh of the certificates. @@ -1165,14 +1114,14 @@ def _csr_matches_certificate_request( self, certificate_signing_request: CertificateSigningRequest, is_ca: bool ) -> bool: for certificate_request in self.certificate_requests: - if certificate_request == CertificateRequest.from_csr( + if certificate_request == CertificateRequestAttributes.from_csr( certificate_signing_request, is_ca, ): return True return False - def _certificate_requested(self, certificate_request: CertificateRequest) -> bool: + def _certificate_requested(self, certificate_request: CertificateRequestAttributes) -> bool: if not self.private_key: return False csr = self._certificate_requested_for_attributes(certificate_request) @@ -1184,17 +1133,17 @@ def _certificate_requested(self, certificate_request: CertificateRequest) -> boo def _certificate_requested_for_attributes( self, - certificate_request: CertificateRequest, - ) -> Optional[RequirerCSR]: + certificate_request: CertificateRequestAttributes, + ) -> Optional[RequirerCertificateRequest]: for requirer_csr in self.get_csrs_from_requirer_relation_data(): - if certificate_request == CertificateRequest.from_csr( + if certificate_request == CertificateRequestAttributes.from_csr( requirer_csr.certificate_signing_request, requirer_csr.is_ca, ): return requirer_csr return None - def get_csrs_from_requirer_relation_data(self) -> List[RequirerCSR]: + def get_csrs_from_requirer_relation_data(self) -> List[RequirerCertificateRequest]: """Return list of requirer's CSRs from relation data.""" if self.mode == Mode.APP and not self.model.unit.is_leader(): logger.debug("Not a leader unit - Skipping") @@ -1212,7 +1161,7 @@ def get_csrs_from_requirer_relation_data(self) -> List[RequirerCSR]: requirer_csrs = [] for csr in requirer_relation_data.certificate_signing_requests: requirer_csrs.append( - RequirerCSR( + RequirerCertificateRequest( relation_id=relation.id, certificate_signing_request=CertificateSigningRequest.from_string( csr.certificate_signing_request @@ -1288,11 +1237,11 @@ def _send_certificate_requests(self): self._request_certificate(csr=csr, is_ca=certificate_request.is_ca) def get_assigned_certificate( - self, certificate_request: CertificateRequest + self, certificate_request: CertificateRequestAttributes ) -> Tuple[ProviderCertificate | None, PrivateKey | None]: """Get the certificate that was assigned to the given certificate request.""" for requirer_csr in self.get_csrs_from_requirer_relation_data(): - if certificate_request == CertificateRequest.from_csr( + if certificate_request == CertificateRequestAttributes.from_csr( requirer_csr.certificate_signing_request, requirer_csr.is_ca, ): @@ -1308,7 +1257,7 @@ def get_assigned_certificates(self) -> Tuple[List[ProviderCertificate], PrivateK return assigned_certificates, self.private_key def _find_certificate_in_relation_data( - self, csr: RequirerCSR + self, csr: RequirerCertificateRequest ) -> Optional[ProviderCertificate]: """Return the certificate that match the given CSR.""" for provider_certificate in self.get_provider_certificates(): @@ -1359,7 +1308,7 @@ def _find_available_certificates(self): } ) secret.set_info( - expire=self._get_next_secret_expiry_time(provider_certificate), + expire=provider_certificate.certificate.expiry_time, ) except SecretNotFoundError: logger.debug("Creating new secret with label %s", secret_label) @@ -1369,7 +1318,7 @@ def _find_available_certificates(self): "csr": str(provider_certificate.certificate_signing_request), }, label=secret_label, - expire=self._get_next_secret_expiry_time(provider_certificate), + expire=provider_certificate.certificate.expiry_time, ) self.on.certificate_available.emit( certificate_signing_request=provider_certificate.certificate_signing_request, @@ -1410,41 +1359,6 @@ def _cleanup_certificate_requests(self): "Removed CSR from relation data because it did not match the private key" # noqa: E501 ) - def _get_next_secret_expiry_time( - self, provider_certificate: ProviderCertificate - ) -> Optional[datetime]: - """Return the expiry time or expiry notification time. - - Extracts the expiry time from the provided certificate, calculates the - expiry notification time and return the closest of the two, that is in - the future. - - Args: - provider_certificate: ProviderCertificate object - - Returns: - Optional[datetime]: None if the certificate expiry time cannot be read, - next expiry time otherwise. - """ - if not provider_certificate.certificate.expiry_time: - logger.warning("Certificate has no expiry time") - return None - if not provider_certificate.certificate.validity_start_time: - logger.warning("Certificate has no validity start time") - return None - expiry_notification_time = calculate_expiry_notification_time( - not_valid_before=provider_certificate.certificate.validity_start_time, - not_valid_after=provider_certificate.certificate.expiry_time, - provider_recommended_notification_time=provider_certificate.recommended_expiry_notification_time, - ) - if not expiry_notification_time: - logger.warning("Could not calculate expiry notification time") - return None - return _get_closest_future_time( - expiry_notification_time, - provider_certificate.certificate.expiry_time, - ) - def _tls_relation_created(self) -> bool: relation = self.model.get_relation(self.relationship_name) if not relation: @@ -1520,10 +1434,12 @@ def _get_tls_relations(self, relation_id: Optional[int] = None) -> List[Relation else self.model.relations.get(self.relationship_name, []) ) - def get_certificate_requests(self, relation_id: Optional[int] = None) -> List[RequirerCSR]: + def get_certificate_requests( + self, relation_id: Optional[int] = None + ) -> List[RequirerCertificateRequest]: """Load certificate requests from the relation data.""" relations = self._get_tls_relations(relation_id) - requirer_csrs: List[RequirerCSR] = [] + requirer_csrs: List[RequirerCertificateRequest] = [] for relation in relations: for unit in relation.units: requirer_csrs.extend(self._load_requirer_databag(relation, unit)) @@ -1532,14 +1448,14 @@ def get_certificate_requests(self, relation_id: Optional[int] = None) -> List[Re def _load_requirer_databag( self, relation: Relation, unit_or_app: Union[Application, Unit] - ) -> List[RequirerCSR]: + ) -> List[RequirerCertificateRequest]: try: requirer_relation_data = _RequirerData.load(relation.data[unit_or_app]) except DataValidationError: logger.debug("Invalid requirer relation data for %s", unit_or_app.name) return [] return [ - RequirerCSR( + RequirerCertificateRequest( relation_id=relation.id, certificate_signing_request=CertificateSigningRequest.from_string( csr.certificate_signing_request @@ -1559,7 +1475,6 @@ def _add_provider_certificate( certificate_signing_request=str(provider_certificate.certificate_signing_request), ca=str(provider_certificate.ca), chain=[str(certificate) for certificate in provider_certificate.chain], - recommended_expiry_notification_time=provider_certificate.recommended_expiry_notification_time, ) provider_certificates = self._load_provider_certificates(relation) if new_certificate in provider_certificates: @@ -1692,17 +1607,17 @@ def get_unsolicited_certificates( def get_outstanding_certificate_requests( self, relation_id: Optional[int] = None - ) -> List[RequirerCSR]: + ) -> List[RequirerCertificateRequest]: """Return CSR's for which no certificate has been issued. Args: relation_id (int): Relation id Returns: - list: List of RequirerCSR objects. + list: List of RequirerCertificateRequest objects. """ requirer_csrs = self.get_certificate_requests(relation_id=relation_id) - outstanding_csrs: List[RequirerCSR] = [] + outstanding_csrs: List[RequirerCertificateRequest] = [] for relation_csr in requirer_csrs: if not self._certificate_issued_for_csr( csr=relation_csr.certificate_signing_request, diff --git a/lib/charms/vault_k8s/v0/vault_tls.py b/lib/charms/vault_k8s/v0/vault_tls.py index e3810f19..a1d84bc6 100644 --- a/lib/charms/vault_k8s/v0/vault_tls.py +++ b/lib/charms/vault_k8s/v0/vault_tls.py @@ -15,7 +15,7 @@ ) from charms.tls_certificates_interface.v4.tls_certificates import ( Certificate, - CertificateRequest, + CertificateRequestAttributes, PrivateKey, TLSCertificatesRequiresV4, generate_ca, @@ -200,11 +200,11 @@ def _configure_ca_cert_relation(self, event: EventBase): """Send the CA certificate to the relation.""" self.send_ca_cert() - def _get_certificate_requests(self) -> List[CertificateRequest]: + def _get_certificate_requests(self) -> List[CertificateRequestAttributes]: if not self.common_name: return [] return [ - CertificateRequest( + CertificateRequestAttributes( common_name=self.common_name, sans_dns=self.sans_dns, sans_ip=self.sans_ip ) ] diff --git a/src/charm.py b/src/charm.py index ec8ffac3..28fc09f7 100755 --- a/src/charm.py +++ b/src/charm.py @@ -22,11 +22,11 @@ from charms.prometheus_k8s.v0.prometheus_scrape import MetricsEndpointProvider from charms.tls_certificates_interface.v4.tls_certificates import ( Certificate, - CertificateRequest, + CertificateRequestAttributes, Mode, PrivateKey, ProviderCertificate, - RequirerCSR, + RequirerCertificateRequest, TLSCertificatesProvidesV4, TLSCertificatesRequiresV4, ) @@ -608,11 +608,11 @@ def _get_pki_intermediate_ca( return None, None return provider_certificate, private_key - def _get_certificate_request(self) -> CertificateRequest | None: + def _get_certificate_request(self) -> CertificateRequestAttributes | None: common_name = self._get_config_common_name() if not common_name: return None - return CertificateRequest( + return CertificateRequestAttributes( common_name=common_name, is_ca=True, ) @@ -679,7 +679,7 @@ def _generate_kv_for_requirer( self._set_kv_relation_data(relation, mount, ca_certificate, egress_subnets) self._remove_stale_nonce(relation=relation, nonce=nonce) - def _generate_pki_certificate_for_requirer(self, requirer_csr: RequirerCSR): + def _generate_pki_certificate_for_requirer(self, requirer_csr: RequirerCertificateRequest): """Generate a PKI certificate for a TLS requirer.""" if not self.unit.is_leader(): logger.debug("Only leader unit can handle a vault-pki request") diff --git a/tests/unit/certificates.py b/tests/unit/certificates.py index 881a4b44..3dfd1517 100644 --- a/tests/unit/certificates.py +++ b/tests/unit/certificates.py @@ -9,7 +9,7 @@ CertificateSigningRequest, PrivateKey, ProviderCertificate, - RequirerCSR, + RequirerCertificateRequest, generate_ca, generate_certificate, generate_csr, @@ -51,13 +51,15 @@ def generate_example_provider_certificate( return provider_certificate, private_key -def generate_example_requirer_csr(common_name: str, relation_id: int) -> RequirerCSR: +def generate_example_requirer_csr( + common_name: str, relation_id: int +) -> RequirerCertificateRequest: private_key = generate_private_key() csr = generate_csr( private_key=private_key, common_name=common_name, ) - return RequirerCSR( + return RequirerCertificateRequest( relation_id=relation_id, certificate_signing_request=csr, is_ca=False, diff --git a/tests/unit/lib/charms/vault_k8s/v0/test_vault_tls.py b/tests/unit/lib/charms/vault_k8s/v0/test_vault_tls.py index d04ab8be..4a69bea4 100644 --- a/tests/unit/lib/charms/vault_k8s/v0/test_vault_tls.py +++ b/tests/unit/lib/charms/vault_k8s/v0/test_vault_tls.py @@ -10,7 +10,7 @@ import pytest import scenario from charms.tls_certificates_interface.v4.tls_certificates import ( - CertificateRequest, + CertificateRequestAttributes, ProviderCertificate, generate_ca, generate_certificate, @@ -191,7 +191,7 @@ def test_given_certificate_access_relation_when_relation_changed_then_new_reques self.ctx.run(certificates_relation.changed_event, state_in) self.mock_get_assigned_certificate.assert_called_once_with( - certificate_request=CertificateRequest( + certificate_request=CertificateRequestAttributes( common_name=ingress_address, sans_dns=frozenset({self.fqdn}), sans_ip=frozenset({ingress_address}),