Skip to content

Commit

Permalink
refactor: make injection code more readable and make backwards-compat…
Browse files Browse the repository at this point in the history
… more explicit
  • Loading branch information
betaboon committed Nov 23, 2024
1 parent 0da2722 commit 199e903
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 127 deletions.
219 changes: 124 additions & 95 deletions mocket/inject.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
from __future__ import annotations

import contextlib
import os
import socket
import ssl
from types import ModuleType
from typing import Any

import urllib3
from packaging.version import Version

try: # pragma: no cover
from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
from mocket.utils import package_version, python_version

pyopenssl_override = True
except ImportError:
pyopenssl_override = False

def _replace(module: ModuleType, name: str, new_value: Any) -> None:
module.__dict__[name] = new_value


def _inject_stdlib_socket() -> None:
import socket

def enable(
namespace: str | None = None,
truesocket_recording_dir: str | None = None,
) -> None:
from mocket.mocket import Mocket
from mocket.socket import (
MocketSocket,
mock_create_connection,
Expand All @@ -27,99 +25,130 @@ def enable(
mock_gethostname,
mock_inet_pton,
mock_socketpair,
mock_urllib3_match_hostname,
)

_replace(socket, "socket", MocketSocket)
_replace(socket, "SocketType", MocketSocket)
_replace(socket, "create_connection", mock_create_connection)
_replace(socket, "getaddrinfo", mock_getaddrinfo)
_replace(socket, "gethostbyname", mock_gethostbyname)
_replace(socket, "gethostname", mock_gethostname)
_replace(socket, "inet_pton", mock_inet_pton)
_replace(socket, "socketpair", mock_socketpair)


def _restore_stdlib_socket() -> None:
import socket

from mocket.socket import (
true_socket_create_connection,
true_socket_getaddrinfo,
true_socket_gethostbyname,
true_socket_gethostname,
true_socket_inet_pton,
true_socket_socket,
true_socket_socket_type,
true_socket_socketpair,
)

_replace(socket, "SocketType", true_socket_socket_type)
_replace(socket, "create_connection", true_socket_create_connection)
_replace(socket, "getaddrinfo", true_socket_getaddrinfo)
_replace(socket, "gethostbyname", true_socket_gethostbyname)
_replace(socket, "gethostname", true_socket_gethostname)
_replace(socket, "inet_pton", true_socket_inet_pton)
_replace(socket, "socket", true_socket_socket)
_replace(socket, "socketpair", true_socket_socketpair)


def _inject_stdlib_ssl() -> None:
import ssl

from mocket.ssl.context import MocketSSLContext

_replace(ssl, "SSLContext", MocketSSLContext)

if python_version() < Version("3.12.0"):
_replace(ssl, "wrap_socket", MocketSSLContext.wrap_socket)


def _restore_stdlib_ssl() -> None:
import ssl

from mocket.ssl.context import true_ssl_ssl_context

_replace(ssl, "SSLContext", true_ssl_ssl_context)

if python_version() < Version("3.12.0"):
from mocket.ssl.context import true_ssl_wrap_socket

_replace(ssl, "wrap_socket", true_ssl_wrap_socket)


def _inject_urllib3() -> None:
import urllib3

from mocket.socket import mock_urllib3_match_hostname
from mocket.ssl.context import MocketSSLContext

_replace(urllib3.util.ssl_, "ssl_wrap_socket", MocketSSLContext.wrap_socket)
_replace(urllib3.util, "ssl_wrap_socket", MocketSSLContext.wrap_socket)
_replace(urllib3.connection, "ssl_wrap_socket", MocketSSLContext.wrap_socket)
_replace(urllib3.connection, "match_hostname", mock_urllib3_match_hostname)

if package_version("urllib3") < Version("2.0.0"):
_replace(urllib3.util.ssl_, "wrap_socket", MocketSSLContext.wrap_socket)

with contextlib.suppress(ImportError):
from urllib3.contrib.pyopenssl import extract_from_urllib3

extract_from_urllib3()


def _restore_urllib3() -> None:
import urllib3

from mocket.socket import true_urllib3_match_hostname
from mocket.ssl.context import true_urllib3_ssl_wrap_socket

_replace(urllib3.connection, "match_hostname", true_urllib3_match_hostname)
_replace(urllib3.util.ssl_, "ssl_wrap_socket", true_urllib3_ssl_wrap_socket)
_replace(urllib3.util, "ssl_wrap_socket", true_urllib3_ssl_wrap_socket)
_replace(urllib3.connection, "ssl_wrap_socket", true_urllib3_ssl_wrap_socket)

if package_version("urllib3") < Version("2.0.0"):
from mocket.ssl.context import true_urllib3_wrap_socket

_replace(urllib3.util.ssl_, "wrap_socket", true_urllib3_wrap_socket)

with contextlib.suppress(ImportError):
from urllib3.contrib.pyopenssl import inject_into_urllib3

inject_into_urllib3()


def enable(
namespace: str | None = None,
truesocket_recording_dir: str | None = None,
) -> None:
_inject_stdlib_socket()
_inject_stdlib_ssl()
_inject_urllib3()

from mocket.mocket import Mocket

Mocket._namespace = namespace
Mocket._truesocket_recording_dir = truesocket_recording_dir

if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
# JSON dumps will be saved here
raise AssertionError

socket.socket = socket.__dict__["socket"] = MocketSocket
socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
socket.create_connection = socket.__dict__["create_connection"] = (
mock_create_connection
)
socket.gethostname = socket.__dict__["gethostname"] = mock_gethostname
socket.gethostbyname = socket.__dict__["gethostbyname"] = mock_gethostbyname
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = mock_getaddrinfo
socket.socketpair = socket.__dict__["socketpair"] = mock_socketpair
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = MocketSSLContext.wrap_socket
ssl.SSLContext = ssl.__dict__["SSLContext"] = MocketSSLContext
socket.inet_pton = socket.__dict__["inet_pton"] = mock_inet_pton
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
MocketSSLContext.wrap_socket
)
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
"ssl_wrap_socket"
] = MocketSSLContext.wrap_socket
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
MocketSSLContext.wrap_socket
)
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
"ssl_wrap_socket"
] = MocketSSLContext.wrap_socket
urllib3.connection.match_hostname = urllib3.connection.__dict__[
"match_hostname"
] = mock_urllib3_match_hostname
if pyopenssl_override: # pragma: no cover
# Take out the pyopenssl version - use the default implementation
extract_from_urllib3()


def disable() -> None:
_restore_stdlib_socket()
_restore_stdlib_ssl()
_restore_urllib3()

from mocket.mocket import Mocket
from mocket.socket import (
true_create_connection,
true_getaddrinfo,
true_gethostbyname,
true_gethostname,
true_inet_pton,
true_socket,
true_socketpair,
true_urllib3_match_hostname,
)
from mocket.ssl.context import (
true_ssl_context,
true_ssl_wrap_socket,
true_urllib3_ssl_wrap_socket,
true_urllib3_wrap_socket,
)

socket.socket = socket.__dict__["socket"] = true_socket
socket._socketobject = socket.__dict__["_socketobject"] = true_socket
socket.SocketType = socket.__dict__["SocketType"] = true_socket
socket.create_connection = socket.__dict__["create_connection"] = (
true_create_connection
)
socket.gethostname = socket.__dict__["gethostname"] = true_gethostname
socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname
socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo
socket.socketpair = socket.__dict__["socketpair"] = true_socketpair
if true_ssl_wrap_socket:
ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket
ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context
socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton
urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
true_urllib3_wrap_socket
)
urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
"ssl_wrap_socket"
] = true_urllib3_ssl_wrap_socket
urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
true_urllib3_ssl_wrap_socket
)
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
"ssl_wrap_socket"
] = true_urllib3_ssl_wrap_socket
urllib3.connection.match_hostname = urllib3.connection.__dict__[
"match_hostname"
] = true_urllib3_match_hostname
Mocket.reset()
if pyopenssl_override: # pragma: no cover
# Put the pyopenssl version back in place
inject_into_urllib3()
23 changes: 12 additions & 11 deletions mocket/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from types import TracebackType
from typing import Any, Type

import urllib3.connection
import urllib3
from typing_extensions import Self

from mocket.compat import decode_from_bytes, encode_to_bytes
Expand All @@ -27,13 +27,14 @@
)
from mocket.utils import hexdump, hexload

true_create_connection = socket.create_connection
true_getaddrinfo = socket.getaddrinfo
true_gethostbyname = socket.gethostbyname
true_gethostname = socket.gethostname
true_inet_pton = socket.inet_pton
true_socket = socket.socket
true_socketpair = socket.socketpair
true_socket_socket = socket.socket
true_socket_socket_type = socket.SocketType
true_socket_create_connection = socket.create_connection
true_socket_gethostname = socket.gethostname
true_socket_gethostbyname = socket.gethostbyname
true_socket_getaddrinfo = socket.getaddrinfo
true_socket_socketpair = socket.socketpair
true_socket_inet_pton = socket.inet_pton
true_urllib3_match_hostname = urllib3.connection.match_hostname


Expand Down Expand Up @@ -106,7 +107,7 @@ def __init__(
self._proto = proto

self._kwargs = kwargs
self._true_socket = true_socket(family, type, proto)
self._true_socket = true_socket_socket(family, type, proto)

self._buflen = 65536
self._timeout: float | None = None
Expand Down Expand Up @@ -187,7 +188,7 @@ def getblocking(self) -> bool:
return self.gettimeout() is None

def getsockname(self) -> _RetAddress:
return true_gethostbyname(self._address[0]), self._address[1]
return true_socket_gethostbyname(self._address[0]), self._address[1]

def connect(self, address: Address) -> None:
self._address = self._host, self._port = address
Expand Down Expand Up @@ -295,7 +296,7 @@ def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int:
# if not available, call the real sendall
except KeyError:
host, port = self._host, self._port
host = true_gethostbyname(host)
host = true_socket_gethostbyname(host)

with contextlib.suppress(OSError, ValueError):
# already connected
Expand Down
24 changes: 9 additions & 15 deletions mocket/ssl/context.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,24 @@
from __future__ import annotations

import contextlib
import ssl
from typing import Any

import urllib3.util.ssl_
import urllib3
from packaging.version import Version

from mocket.socket import MocketSocket
from mocket.ssl.socket import MocketSSLSocket
from mocket.utils import package_version, python_version

true_ssl_context = ssl.SSLContext
true_ssl_ssl_context = ssl.SSLContext

true_ssl_wrap_socket = None
true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket
true_urllib3_wrap_socket = None

with contextlib.suppress(ImportError):
# from Py3.12 it's only under SSLContext
from ssl import wrap_socket as ssl_wrap_socket
if python_version() < Version("3.12.0"):
true_ssl_wrap_socket = ssl.wrap_socket

true_ssl_wrap_socket = ssl_wrap_socket

with contextlib.suppress(ImportError):
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket

true_urllib3_wrap_socket = urllib3_wrap_socket
if package_version("urllib3") < Version("2.0.0"):
true_urllib3_wrap_socket = urllib3.util.ssl_.wrap_socket


class _MocketSSLContext:
Expand Down
17 changes: 12 additions & 5 deletions mocket/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
from __future__ import annotations

import binascii
import importlib.metadata
import platform
import ssl
from typing import Callable

import packaging.version

from mocket.compat import decode_from_bytes, encode_to_bytes

# NOTE this is here for backwards-compat to keep old import-paths working
from mocket.io import MocketSocketIO as MocketSocketCore
SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2

# NOTE this is here for backwards-compat to keep old import-paths working
from mocket.mode import MocketMode

SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2
def python_version() -> packaging.version.Version:
return packaging.version.Version(platform.python_version())


def package_version(package_name: str) -> packaging.version.Version:
version = importlib.metadata.version(package_name)
return packaging.version.parse(version)


def hexdump(binary_string: bytes) -> str:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"decorator>=4.0.0",
"urllib3>=1.25.3",
"h11",
"packaging>=24.2",
]
dynamic = ["version"]

Expand Down
2 changes: 1 addition & 1 deletion tests/test_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from mocket import Mocketizer, mocketize
from mocket.exceptions import StrictMocketException
from mocket.mockhttp import Entry, Response
from mocket.utils import MocketMode
from mocket.mode import MocketMode


@mocketize(strict_mode=True)
Expand Down

0 comments on commit 199e903

Please sign in to comment.