From 2b7eaa765f7fd216878c6b12692f7b354ad2ddaa Mon Sep 17 00:00:00 2001 From: Walter BONETTI Date: Wed, 5 Jun 2024 17:00:23 -0400 Subject: [PATCH] Add pyOpenSSL context support --- pyproject.toml | 3 + src/paho/mqtt/client.py | 149 +++++++++++++++++++++++++++++++--------- 2 files changed, 119 insertions(+), 33 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e990812d..f2c3a944 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,9 @@ dependencies = [] proxy = [ "PySocks", ] +openssl = [ + "pyOpenSSL" +] [project.urls] Homepage = "http://eclipse.org/paho" diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index 4ccc8696..7456b44f 100644 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -45,6 +45,54 @@ from .reasoncodes import ReasonCode, ReasonCodes from .subscribeoptions import SubscribeOptions +try: + from OpenSSL import SSL + from OpenSSL.crypto import X509 + + def _subject_alt_name_string(cert: X509) -> list: + """Extracts the subject alternative name (SAN) entries from the certificate.""" + san = [] + for i in range(cert.get_extension_count()): + ext = cert.get_extension(i) + if ext.get_short_name() == b'subjectAltName': + san_entries = ext.__str__().split(', ') + for entry in san_entries: + key, value = entry.split(':', 1) + print(f"key {key}: value {value}") + san.append((key.strip(), value.strip())) + return san + + def _openssl_match_hostname(cert: X509, hostname: str): + """Verify that *cert* matches the *hostname* according to RFC 2818 and RFC 6125 rules. + CertificateError is raised on failure. On success, the function returns nothing. + """ + if not cert: + raise ValueError("Empty or no certificate. match_hostname needs a certificate.") + + dnsnames = [] + # Extract subject alternative name (SAN) entries + san = _subject_alt_name_string(cert) + for key, value in san: + if key == 'DNS': + if ssl._dnsname_match(value, hostname): + return + dnsnames.append(value) + + if not dnsnames: + # TODO: check if no dns entry to use subject + raise ValueError("pyOpenssl match_hostname: using subject is not supported.") + + if len(dnsnames) > 1: + raise ssl.CertificateError(f"Hostname {hostname} doesn't match any of {', '.join(map(repr, dnsnames))}") + elif len(dnsnames) == 1: + raise ssl.CertificateError(f"Hostname {hostname} doesn't match {dnsnames[0]}") + else: + raise ssl.CertificateError("No appropriate commonName or subjectAltName fields were found") + + HAS_OPENSSL = True +except ImportError: + HAS_OPENSSL = False + try: from typing import Literal except ImportError: @@ -851,7 +899,7 @@ def __init__( self._thread: threading.Thread | None = None self._thread_terminate = False self._ssl = False - self._ssl_context: ssl.SSLContext | None = None + self._ssl_context: ssl.SSLContext | SSL.Context | None = None # Only used when SSL context does not have check_hostname attribute self._tls_insecure = False self._logger: logging.Logger | None = None @@ -1181,26 +1229,37 @@ def ws_set_options( def tls_set_context( self, - context: ssl.SSLContext | None = None, + context: ssl.SSLContext | SSL.Context | None = None, ) -> None: """Configure network encryption and authentication context. Enables SSL/TLS support. - :param context: an ssl.SSLContext object. By default this is given by - ``ssl.create_default_context()``, if available. + :param context: an ssl.SSLContext or OpenSSL.SSL.Context object. By default, this is given by + ``ssl.create_default_context()`` if available. - Must be called before `connect()`, `connect_async()` or `connect_srv()`.""" + Must be called before `connect()`, `connect_async()` or `connect_srv()`. + """ if self._ssl_context is not None: raise ValueError('SSL/TLS has already been configured.') if context is None: - context = ssl.create_default_context() + if HAS_OPENSSL: + raise ValueError("OpenSSL custom context is not provided.") + else: + context = ssl.create_default_context() self._ssl = True self._ssl_context = context - # Ensure _tls_insecure is consistent with check_hostname attribute - if hasattr(context, 'check_hostname'): + # Ensure _tls_insecure is consistent with check_hostname attribute for ssl.SSLContext + if isinstance(context, ssl.SSLContext) and hasattr(context, 'check_hostname'): self._tls_insecure = not context.check_hostname + elif HAS_OPENSSL and isinstance(context, SSL.Context): + # PyOpenSSL Context does not have check_hostname attribute + # Set _tls_insecure based on custom logic if necessary + self._tls_insecure = False # Assuming default to False for PyOpenSSL + else: + # If OpenSSL is not available and context is an SSL.Context, raise an error + raise ValueError("OpenSSL is not available, cannot use SSL.Context.") def tls_set( self, @@ -4638,43 +4697,67 @@ def _create_socket_connection(self) -> _socket.socket: return socks.create_connection(addr, timeout=self._connect_timeout, source_address=source, **proxy) else: return socket.create_connection(addr, timeout=self._connect_timeout, source_address=source) - - def _ssl_wrap_socket(self, tcp_sock: _socket.socket) -> ssl.SSLSocket: + + def _ssl_wrap_socket(self, tcp_sock: _socket) -> _socket.socket: if self._ssl_context is None: raise ValueError( "Impossible condition. _ssl_context should never be None if _ssl is True" ) - + verify_host = not self._tls_insecure try: - # Try with server_hostname, even it's not supported in certain scenarios - ssl_sock = self._ssl_context.wrap_socket( - tcp_sock, - server_hostname=self._host, - do_handshake_on_connect=False, - ) + if isinstance(self._ssl_context, ssl.SSLContext): + # Use the built-in ssl.SSLContext + ssl_sock = self._ssl_context.wrap_socket( + tcp_sock, + server_hostname=self._host, + do_handshake_on_connect=False, + ) + elif HAS_OPENSSL and isinstance(self._ssl_context, SSL.Context): + # Use PyOpenSSL's SSL.Context + conn = SSL.Connection(self._ssl_context, tcp_sock) + conn.set_connect_state() + if self._host: + conn.set_tlsext_host_name(self._host.encode('utf-8')) + ssl_sock = conn + else: + raise ValueError("Unsupported SSL context type") except ssl.CertificateError: - # CertificateError is derived from ValueError raise except ValueError: - # Python version requires SNI in order to handle server_hostname, but SNI is not available - ssl_sock = self._ssl_context.wrap_socket( - tcp_sock, - do_handshake_on_connect=False, - ) - else: - # If SSL context has already checked hostname, then don't need to do it again - if getattr(self._ssl_context, 'check_hostname', False): # type: ignore - verify_host = False + if isinstance(self._ssl_context, ssl.SSLContext): + ssl_sock = self._ssl_context.wrap_socket( + tcp_sock, + do_handshake_on_connect=False, + ) + else: + raise ssl_sock.settimeout(self._keepalive) - ssl_sock.do_handshake() - if verify_host: - # TODO: this type error is a true error: - # error: Module has no attribute "match_hostname" [attr-defined] - # Python 3.12 no longer have this method. - ssl.match_hostname(ssl_sock.getpeercert(), self._host) # type: ignore + # Function to handle retries for non-blocking SSL handshake + def do_handshake_with_retries(ssl_sock, retries=35, delay=0.1): + for attempt in range(retries): + try: + ssl_sock.do_handshake() + return + except SSL.WantReadError: + if attempt == retries - 1: + raise RuntimeError("Handshake failed after maximum retries") + time.sleep(delay) + + if HAS_OPENSSL and isinstance(ssl_sock, SSL.Connection): + do_handshake_with_retries(ssl_sock) + if verify_host: + if getattr(self._ssl_context, 'check_hostname', False): + verify_host = False + _openssl_match_hostname(ssl_sock.get_peer_certificate(), self._host) + else: + ssl_sock.do_handshake() + if verify_host: + if getattr(self._ssl_context, 'check_hostname', False): + verify_host = False + ssl.match_hostname(ssl_sock.getpeercert(), self._host) return ssl_sock