From 199e90344a90d5ea9210b801202e2ed577df5e22 Mon Sep 17 00:00:00 2001 From: betaboon Date: Sat, 23 Nov 2024 17:54:17 +0100 Subject: [PATCH] refactor: make injection code more readable and make backwards-compat more explicit --- mocket/inject.py | 219 ++++++++++++++++++++++++------------------ mocket/socket.py | 23 ++--- mocket/ssl/context.py | 24 ++--- mocket/utils.py | 17 +++- pyproject.toml | 1 + tests/test_mode.py | 2 +- 6 files changed, 159 insertions(+), 127 deletions(-) diff --git a/mocket/inject.py b/mocket/inject.py index 35e9da01..204c5981 100644 --- a/mocket/inject.py +++ b/mocket/inject.py @@ -1,24 +1,22 @@ from __future__ import annotations +import contextlib import os -import socket -import ssl +from types import ModuleType +from typing import Any -import urllib3 +from packaging.version import Version -try: # pragma: no cover - from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3 +from mocket.utils import package_version, python_version - pyopenssl_override = True -except ImportError: - pyopenssl_override = False +def _replace(module: ModuleType, name: str, new_value: Any) -> None: + module.__dict__[name] = new_value + + +def _inject_stdlib_socket() -> None: + import socket -def enable( - namespace: str | None = None, - truesocket_recording_dir: str | None = None, -) -> None: - from mocket.mocket import Mocket from mocket.socket import ( MocketSocket, mock_create_connection, @@ -27,99 +25,130 @@ def enable( mock_gethostname, mock_inet_pton, mock_socketpair, - mock_urllib3_match_hostname, ) + + _replace(socket, "socket", MocketSocket) + _replace(socket, "SocketType", MocketSocket) + _replace(socket, "create_connection", mock_create_connection) + _replace(socket, "getaddrinfo", mock_getaddrinfo) + _replace(socket, "gethostbyname", mock_gethostbyname) + _replace(socket, "gethostname", mock_gethostname) + _replace(socket, "inet_pton", mock_inet_pton) + _replace(socket, "socketpair", mock_socketpair) + + +def _restore_stdlib_socket() -> None: + import socket + + from mocket.socket import ( + true_socket_create_connection, + true_socket_getaddrinfo, + true_socket_gethostbyname, + true_socket_gethostname, + true_socket_inet_pton, + true_socket_socket, + true_socket_socket_type, + true_socket_socketpair, + ) + + _replace(socket, "SocketType", true_socket_socket_type) + _replace(socket, "create_connection", true_socket_create_connection) + _replace(socket, "getaddrinfo", true_socket_getaddrinfo) + _replace(socket, "gethostbyname", true_socket_gethostbyname) + _replace(socket, "gethostname", true_socket_gethostname) + _replace(socket, "inet_pton", true_socket_inet_pton) + _replace(socket, "socket", true_socket_socket) + _replace(socket, "socketpair", true_socket_socketpair) + + +def _inject_stdlib_ssl() -> None: + import ssl + from mocket.ssl.context import MocketSSLContext + _replace(ssl, "SSLContext", MocketSSLContext) + + if python_version() < Version("3.12.0"): + _replace(ssl, "wrap_socket", MocketSSLContext.wrap_socket) + + +def _restore_stdlib_ssl() -> None: + import ssl + + from mocket.ssl.context import true_ssl_ssl_context + + _replace(ssl, "SSLContext", true_ssl_ssl_context) + + if python_version() < Version("3.12.0"): + from mocket.ssl.context import true_ssl_wrap_socket + + _replace(ssl, "wrap_socket", true_ssl_wrap_socket) + + +def _inject_urllib3() -> None: + import urllib3 + + from mocket.socket import mock_urllib3_match_hostname + from mocket.ssl.context import MocketSSLContext + + _replace(urllib3.util.ssl_, "ssl_wrap_socket", MocketSSLContext.wrap_socket) + _replace(urllib3.util, "ssl_wrap_socket", MocketSSLContext.wrap_socket) + _replace(urllib3.connection, "ssl_wrap_socket", MocketSSLContext.wrap_socket) + _replace(urllib3.connection, "match_hostname", mock_urllib3_match_hostname) + + if package_version("urllib3") < Version("2.0.0"): + _replace(urllib3.util.ssl_, "wrap_socket", MocketSSLContext.wrap_socket) + + with contextlib.suppress(ImportError): + from urllib3.contrib.pyopenssl import extract_from_urllib3 + + extract_from_urllib3() + + +def _restore_urllib3() -> None: + import urllib3 + + from mocket.socket import true_urllib3_match_hostname + from mocket.ssl.context import true_urllib3_ssl_wrap_socket + + _replace(urllib3.connection, "match_hostname", true_urllib3_match_hostname) + _replace(urllib3.util.ssl_, "ssl_wrap_socket", true_urllib3_ssl_wrap_socket) + _replace(urllib3.util, "ssl_wrap_socket", true_urllib3_ssl_wrap_socket) + _replace(urllib3.connection, "ssl_wrap_socket", true_urllib3_ssl_wrap_socket) + + if package_version("urllib3") < Version("2.0.0"): + from mocket.ssl.context import true_urllib3_wrap_socket + + _replace(urllib3.util.ssl_, "wrap_socket", true_urllib3_wrap_socket) + + with contextlib.suppress(ImportError): + from urllib3.contrib.pyopenssl import inject_into_urllib3 + + inject_into_urllib3() + + +def enable( + namespace: str | None = None, + truesocket_recording_dir: str | None = None, +) -> None: + _inject_stdlib_socket() + _inject_stdlib_ssl() + _inject_urllib3() + + from mocket.mocket import Mocket + Mocket._namespace = namespace Mocket._truesocket_recording_dir = truesocket_recording_dir - if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir): # JSON dumps will be saved here raise AssertionError - socket.socket = socket.__dict__["socket"] = MocketSocket - socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket - socket.SocketType = socket.__dict__["SocketType"] = MocketSocket - socket.create_connection = socket.__dict__["create_connection"] = ( - mock_create_connection - ) - socket.gethostname = socket.__dict__["gethostname"] = mock_gethostname - socket.gethostbyname = socket.__dict__["gethostbyname"] = mock_gethostbyname - socket.getaddrinfo = socket.__dict__["getaddrinfo"] = mock_getaddrinfo - socket.socketpair = socket.__dict__["socketpair"] = mock_socketpair - ssl.wrap_socket = ssl.__dict__["wrap_socket"] = MocketSSLContext.wrap_socket - ssl.SSLContext = ssl.__dict__["SSLContext"] = MocketSSLContext - socket.inet_pton = socket.__dict__["inet_pton"] = mock_inet_pton - urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( - MocketSSLContext.wrap_socket - ) - urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ - "ssl_wrap_socket" - ] = MocketSSLContext.wrap_socket - urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( - MocketSSLContext.wrap_socket - ) - urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ - "ssl_wrap_socket" - ] = MocketSSLContext.wrap_socket - urllib3.connection.match_hostname = urllib3.connection.__dict__[ - "match_hostname" - ] = mock_urllib3_match_hostname - if pyopenssl_override: # pragma: no cover - # Take out the pyopenssl version - use the default implementation - extract_from_urllib3() - def disable() -> None: + _restore_stdlib_socket() + _restore_stdlib_ssl() + _restore_urllib3() + from mocket.mocket import Mocket - from mocket.socket import ( - true_create_connection, - true_getaddrinfo, - true_gethostbyname, - true_gethostname, - true_inet_pton, - true_socket, - true_socketpair, - true_urllib3_match_hostname, - ) - from mocket.ssl.context import ( - true_ssl_context, - true_ssl_wrap_socket, - true_urllib3_ssl_wrap_socket, - true_urllib3_wrap_socket, - ) - socket.socket = socket.__dict__["socket"] = true_socket - socket._socketobject = socket.__dict__["_socketobject"] = true_socket - socket.SocketType = socket.__dict__["SocketType"] = true_socket - socket.create_connection = socket.__dict__["create_connection"] = ( - true_create_connection - ) - socket.gethostname = socket.__dict__["gethostname"] = true_gethostname - socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname - socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo - socket.socketpair = socket.__dict__["socketpair"] = true_socketpair - if true_ssl_wrap_socket: - ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket - ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context - socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton - urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( - true_urllib3_wrap_socket - ) - urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ - "ssl_wrap_socket" - ] = true_urllib3_ssl_wrap_socket - urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( - true_urllib3_ssl_wrap_socket - ) - urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ - "ssl_wrap_socket" - ] = true_urllib3_ssl_wrap_socket - urllib3.connection.match_hostname = urllib3.connection.__dict__[ - "match_hostname" - ] = true_urllib3_match_hostname Mocket.reset() - if pyopenssl_override: # pragma: no cover - # Put the pyopenssl version back in place - inject_into_urllib3() diff --git a/mocket/socket.py b/mocket/socket.py index 03c6f7e5..b5af1bca 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -11,7 +11,7 @@ from types import TracebackType from typing import Any, Type -import urllib3.connection +import urllib3 from typing_extensions import Self from mocket.compat import decode_from_bytes, encode_to_bytes @@ -27,13 +27,14 @@ ) from mocket.utils import hexdump, hexload -true_create_connection = socket.create_connection -true_getaddrinfo = socket.getaddrinfo -true_gethostbyname = socket.gethostbyname -true_gethostname = socket.gethostname -true_inet_pton = socket.inet_pton -true_socket = socket.socket -true_socketpair = socket.socketpair +true_socket_socket = socket.socket +true_socket_socket_type = socket.SocketType +true_socket_create_connection = socket.create_connection +true_socket_gethostname = socket.gethostname +true_socket_gethostbyname = socket.gethostbyname +true_socket_getaddrinfo = socket.getaddrinfo +true_socket_socketpair = socket.socketpair +true_socket_inet_pton = socket.inet_pton true_urllib3_match_hostname = urllib3.connection.match_hostname @@ -106,7 +107,7 @@ def __init__( self._proto = proto self._kwargs = kwargs - self._true_socket = true_socket(family, type, proto) + self._true_socket = true_socket_socket(family, type, proto) self._buflen = 65536 self._timeout: float | None = None @@ -187,7 +188,7 @@ def getblocking(self) -> bool: return self.gettimeout() is None def getsockname(self) -> _RetAddress: - return true_gethostbyname(self._address[0]), self._address[1] + return true_socket_gethostbyname(self._address[0]), self._address[1] def connect(self, address: Address) -> None: self._address = self._host, self._port = address @@ -295,7 +296,7 @@ def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int: # if not available, call the real sendall except KeyError: host, port = self._host, self._port - host = true_gethostbyname(host) + host = true_socket_gethostbyname(host) with contextlib.suppress(OSError, ValueError): # already connected diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py index 438faa10..54ed1e46 100644 --- a/mocket/ssl/context.py +++ b/mocket/ssl/context.py @@ -1,30 +1,24 @@ from __future__ import annotations -import contextlib import ssl from typing import Any -import urllib3.util.ssl_ +import urllib3 +from packaging.version import Version from mocket.socket import MocketSocket from mocket.ssl.socket import MocketSSLSocket +from mocket.utils import package_version, python_version -true_ssl_context = ssl.SSLContext +true_ssl_ssl_context = ssl.SSLContext -true_ssl_wrap_socket = None -true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket -true_urllib3_wrap_socket = None - -with contextlib.suppress(ImportError): - # from Py3.12 it's only under SSLContext - from ssl import wrap_socket as ssl_wrap_socket +if python_version() < Version("3.12.0"): + true_ssl_wrap_socket = ssl.wrap_socket - true_ssl_wrap_socket = ssl_wrap_socket - -with contextlib.suppress(ImportError): - from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket +true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket - true_urllib3_wrap_socket = urllib3_wrap_socket +if package_version("urllib3") < Version("2.0.0"): + true_urllib3_wrap_socket = urllib3.util.ssl_.wrap_socket class _MocketSSLContext: diff --git a/mocket/utils.py b/mocket/utils.py index b9e2c259..7948af25 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -1,18 +1,25 @@ from __future__ import annotations import binascii +import importlib.metadata +import platform import ssl from typing import Callable +import packaging.version + from mocket.compat import decode_from_bytes, encode_to_bytes -# NOTE this is here for backwards-compat to keep old import-paths working -from mocket.io import MocketSocketIO as MocketSocketCore +SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2 -# NOTE this is here for backwards-compat to keep old import-paths working -from mocket.mode import MocketMode -SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2 +def python_version() -> packaging.version.Version: + return packaging.version.Version(platform.python_version()) + + +def package_version(package_name: str) -> packaging.version.Version: + version = importlib.metadata.version(package_name) + return packaging.version.parse(version) def hexdump(binary_string: bytes) -> str: diff --git a/pyproject.toml b/pyproject.toml index 77d1f5d4..73024c75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "decorator>=4.0.0", "urllib3>=1.25.3", "h11", + "packaging>=24.2", ] dynamic = ["version"] diff --git a/tests/test_mode.py b/tests/test_mode.py index 2a764949..ea5905b0 100644 --- a/tests/test_mode.py +++ b/tests/test_mode.py @@ -4,7 +4,7 @@ from mocket import Mocketizer, mocketize from mocket.exceptions import StrictMocketException from mocket.mockhttp import Entry, Response -from mocket.utils import MocketMode +from mocket.mode import MocketMode @mocketize(strict_mode=True)