Skip to content

Commit

Permalink
Add pyOpenSSL context support
Browse files Browse the repository at this point in the history
  • Loading branch information
IniterWorker committed Jun 5, 2024
1 parent d45de37 commit 2b7eaa7
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 33 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ dependencies = []
proxy = [
"PySocks",
]
openssl = [
"pyOpenSSL"
]

[project.urls]
Homepage = "http://eclipse.org/paho"
Expand Down
149 changes: 116 additions & 33 deletions src/paho/mqtt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,54 @@
from .reasoncodes import ReasonCode, ReasonCodes
from .subscribeoptions import SubscribeOptions

try:
from OpenSSL import SSL
from OpenSSL.crypto import X509

def _subject_alt_name_string(cert: X509) -> list:
"""Extracts the subject alternative name (SAN) entries from the certificate."""
san = []
for i in range(cert.get_extension_count()):
ext = cert.get_extension(i)
if ext.get_short_name() == b'subjectAltName':
san_entries = ext.__str__().split(', ')
for entry in san_entries:
key, value = entry.split(':', 1)
print(f"key {key}: value {value}")
san.append((key.strip(), value.strip()))
return san

def _openssl_match_hostname(cert: X509, hostname: str):
"""Verify that *cert* matches the *hostname* according to RFC 2818 and RFC 6125 rules.
CertificateError is raised on failure. On success, the function returns nothing.
"""
if not cert:
raise ValueError("Empty or no certificate. match_hostname needs a certificate.")

dnsnames = []
# Extract subject alternative name (SAN) entries
san = _subject_alt_name_string(cert)
for key, value in san:
if key == 'DNS':
if ssl._dnsname_match(value, hostname):
return
dnsnames.append(value)

if not dnsnames:
# TODO: check if no dns entry to use subject
raise ValueError("pyOpenssl match_hostname: using subject is not supported.")

if len(dnsnames) > 1:
raise ssl.CertificateError(f"Hostname {hostname} doesn't match any of {', '.join(map(repr, dnsnames))}")
elif len(dnsnames) == 1:
raise ssl.CertificateError(f"Hostname {hostname} doesn't match {dnsnames[0]}")
else:
raise ssl.CertificateError("No appropriate commonName or subjectAltName fields were found")

HAS_OPENSSL = True
except ImportError:
HAS_OPENSSL = False

try:
from typing import Literal
except ImportError:
Expand Down Expand Up @@ -851,7 +899,7 @@ def __init__(
self._thread: threading.Thread | None = None
self._thread_terminate = False
self._ssl = False
self._ssl_context: ssl.SSLContext | None = None
self._ssl_context: ssl.SSLContext | SSL.Context | None = None
# Only used when SSL context does not have check_hostname attribute
self._tls_insecure = False
self._logger: logging.Logger | None = None
Expand Down Expand Up @@ -1181,26 +1229,37 @@ def ws_set_options(

def tls_set_context(
self,
context: ssl.SSLContext | None = None,
context: ssl.SSLContext | SSL.Context | None = None,
) -> None:
"""Configure network encryption and authentication context. Enables SSL/TLS support.
:param context: an ssl.SSLContext object. By default this is given by
``ssl.create_default_context()``, if available.
:param context: an ssl.SSLContext or OpenSSL.SSL.Context object. By default, this is given by
``ssl.create_default_context()`` if available.
Must be called before `connect()`, `connect_async()` or `connect_srv()`."""
Must be called before `connect()`, `connect_async()` or `connect_srv()`.
"""
if self._ssl_context is not None:
raise ValueError('SSL/TLS has already been configured.')

if context is None:
context = ssl.create_default_context()
if HAS_OPENSSL:
raise ValueError("OpenSSL custom context is not provided.")
else:
context = ssl.create_default_context()

self._ssl = True
self._ssl_context = context

# Ensure _tls_insecure is consistent with check_hostname attribute
if hasattr(context, 'check_hostname'):
# Ensure _tls_insecure is consistent with check_hostname attribute for ssl.SSLContext
if isinstance(context, ssl.SSLContext) and hasattr(context, 'check_hostname'):
self._tls_insecure = not context.check_hostname
elif HAS_OPENSSL and isinstance(context, SSL.Context):
# PyOpenSSL Context does not have check_hostname attribute
# Set _tls_insecure based on custom logic if necessary
self._tls_insecure = False # Assuming default to False for PyOpenSSL
else:
# If OpenSSL is not available and context is an SSL.Context, raise an error
raise ValueError("OpenSSL is not available, cannot use SSL.Context.")

def tls_set(
self,
Expand Down Expand Up @@ -4638,43 +4697,67 @@ def _create_socket_connection(self) -> _socket.socket:
return socks.create_connection(addr, timeout=self._connect_timeout, source_address=source, **proxy)
else:
return socket.create_connection(addr, timeout=self._connect_timeout, source_address=source)

def _ssl_wrap_socket(self, tcp_sock: _socket.socket) -> ssl.SSLSocket:
def _ssl_wrap_socket(self, tcp_sock: _socket) -> _socket.socket:
if self._ssl_context is None:
raise ValueError(
"Impossible condition. _ssl_context should never be None if _ssl is True"
)

verify_host = not self._tls_insecure
try:
# Try with server_hostname, even it's not supported in certain scenarios
ssl_sock = self._ssl_context.wrap_socket(
tcp_sock,
server_hostname=self._host,
do_handshake_on_connect=False,
)
if isinstance(self._ssl_context, ssl.SSLContext):
# Use the built-in ssl.SSLContext
ssl_sock = self._ssl_context.wrap_socket(
tcp_sock,
server_hostname=self._host,
do_handshake_on_connect=False,
)
elif HAS_OPENSSL and isinstance(self._ssl_context, SSL.Context):
# Use PyOpenSSL's SSL.Context
conn = SSL.Connection(self._ssl_context, tcp_sock)
conn.set_connect_state()
if self._host:
conn.set_tlsext_host_name(self._host.encode('utf-8'))
ssl_sock = conn
else:
raise ValueError("Unsupported SSL context type")
except ssl.CertificateError:
# CertificateError is derived from ValueError
raise
except ValueError:
# Python version requires SNI in order to handle server_hostname, but SNI is not available
ssl_sock = self._ssl_context.wrap_socket(
tcp_sock,
do_handshake_on_connect=False,
)
else:
# If SSL context has already checked hostname, then don't need to do it again
if getattr(self._ssl_context, 'check_hostname', False): # type: ignore
verify_host = False
if isinstance(self._ssl_context, ssl.SSLContext):
ssl_sock = self._ssl_context.wrap_socket(
tcp_sock,
do_handshake_on_connect=False,
)
else:
raise

ssl_sock.settimeout(self._keepalive)
ssl_sock.do_handshake()

if verify_host:
# TODO: this type error is a true error:
# error: Module has no attribute "match_hostname" [attr-defined]
# Python 3.12 no longer have this method.
ssl.match_hostname(ssl_sock.getpeercert(), self._host) # type: ignore
# Function to handle retries for non-blocking SSL handshake
def do_handshake_with_retries(ssl_sock, retries=35, delay=0.1):
for attempt in range(retries):
try:
ssl_sock.do_handshake()
return
except SSL.WantReadError:
if attempt == retries - 1:
raise RuntimeError("Handshake failed after maximum retries")
time.sleep(delay)

if HAS_OPENSSL and isinstance(ssl_sock, SSL.Connection):
do_handshake_with_retries(ssl_sock)
if verify_host:
if getattr(self._ssl_context, 'check_hostname', False):
verify_host = False
_openssl_match_hostname(ssl_sock.get_peer_certificate(), self._host)
else:
ssl_sock.do_handshake()
if verify_host:
if getattr(self._ssl_context, 'check_hostname', False):
verify_host = False
ssl.match_hostname(ssl_sock.getpeercert(), self._host)

return ssl_sock

Expand Down

0 comments on commit 2b7eaa7

Please sign in to comment.