diff --git a/mocket/entry.py b/mocket/entry.py index 8fa28bc7..9dbbf442 100644 --- a/mocket/entry.py +++ b/mocket/entry.py @@ -1,6 +1,7 @@ import collections.abc from mocket.compat import encode_to_bytes +from mocket.mocket import Mocket class MocketEntry: @@ -41,8 +42,6 @@ def can_handle(data): return True def collect(self, data): - from mocket import Mocket - req = self.request_cls(data) Mocket.collect(req) diff --git a/mocket/io.py b/mocket/io.py index 45bb8272..648b16dd 100644 --- a/mocket/io.py +++ b/mocket/io.py @@ -1,6 +1,8 @@ import io import os +from mocket.mocket import Mocket + class MocketSocketCore(io.BytesIO): def __init__(self, address) -> None: @@ -8,8 +10,6 @@ def __init__(self, address) -> None: super().__init__() def write(self, content): - from mocket import Mocket - super().write(content) _, w_fd = Mocket.get_pair(self._address) diff --git a/mocket/mocket.py b/mocket/mocket.py index b7b463af..3476902d 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -1,24 +1,33 @@ +from __future__ import annotations + import collections import itertools import os -from typing import Optional, Tuple +from typing import TYPE_CHECKING, ClassVar import mocket.inject # NOTE this is here for backwards-compat to keep old import-paths working -from mocket.socket import MocketSocket as MocketSocket +# from mocket.socket import MocketSocket as MocketSocket + +if TYPE_CHECKING: + from mocket.entry import MocketEntry + from mocket.types import Address class Mocket: - _socket_pairs = {} - _address = (None, None) - _entries = collections.defaultdict(list) - _requests = [] - _namespace = str(id(_entries)) - _truesocket_recording_dir = None + _socket_pairs: ClassVar[dict[Address, tuple[int, int]]] = {} + _address: ClassVar[Address] = (None, None) + _entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list) + _requests: ClassVar[list] = [] + _namespace: ClassVar[str] = str(id(_entries)) + _truesocket_recording_dir: ClassVar[str | None] = None + + enable = mocket.inject.enable + disable = mocket.inject.disable @classmethod - def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]: + def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]: """ Given the id() of the caller, return a pair of file descriptors as a tuple of two integers: (, ) @@ -26,7 +35,7 @@ def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]: return cls._socket_pairs.get(address, (None, None)) @classmethod - def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None: + def set_pair(cls, address: Address, pair: tuple[int, int]) -> None: """ Store a pair of file descriptors under the key `id_` as a tuple of two integers: (, ) @@ -34,25 +43,26 @@ def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None: cls._socket_pairs[address] = pair @classmethod - def register(cls, *entries): + def register(cls, *entries: MocketEntry) -> None: for entry in entries: cls._entries[entry.location].append(entry) @classmethod - def get_entry(cls, host, port, data): - host = host or Mocket._address[0] - port = port or Mocket._address[1] + def get_entry(cls, host: str, port: int, data) -> MocketEntry | None: + host = host or cls._address[0] + port = port or cls._address[1] entries = cls._entries.get((host, port), []) for entry in entries: if entry.can_handle(data): return entry + return None @classmethod - def collect(cls, data): - cls.request_list().append(data) + def collect(cls, data) -> None: + cls._requests.append(data) @classmethod - def reset(cls): + def reset(cls) -> None: for r_fd, w_fd in cls._socket_pairs.values(): os.close(r_fd) os.close(w_fd) @@ -63,39 +73,31 @@ def reset(cls): @classmethod def last_request(cls): if cls.has_requests(): - return cls.request_list()[-1] + return cls._requests[-1] @classmethod def request_list(cls): return cls._requests @classmethod - def remove_last_request(cls): + def remove_last_request(cls) -> None: if cls.has_requests(): del cls._requests[-1] @classmethod - def has_requests(cls): + def has_requests(cls) -> bool: return bool(cls.request_list()) @classmethod - def get_namespace(cls): + def get_namespace(cls) -> str: return cls._namespace - @staticmethod - def enable(namespace=None, truesocket_recording_dir=None): - mocket.inject.enable(namespace, truesocket_recording_dir) - - @staticmethod - def disable(): - mocket.inject.disable() - @classmethod - def get_truesocket_recording_dir(cls): + def get_truesocket_recording_dir(cls) -> str | None: return cls._truesocket_recording_dir @classmethod - def assert_fail_if_entries_not_served(cls): + def assert_fail_if_entries_not_served(cls) -> None: """Mocket checks that all entries have been served at least once.""" if not all(entry._served for entry in itertools.chain(*cls._entries.values())): raise AssertionError("Some Mocket entries have not been served") diff --git a/mocket/mocketizer.py b/mocket/mocketizer.py index 5a988c77..2bf2b9cd 100644 --- a/mocket/mocketizer.py +++ b/mocket/mocketizer.py @@ -1,3 +1,4 @@ +from mocket.mocket import Mocket from mocket.mode import MocketMode from mocket.utils import get_mocketize @@ -23,8 +24,6 @@ def __init__( ) def enter(self): - from mocket import Mocket - Mocket.enable( namespace=self.namespace, truesocket_recording_dir=self.truesocket_recording_dir, @@ -39,7 +38,6 @@ def __enter__(self): def exit(self): if self.instance: self.check_and_call("mocketize_teardown") - from mocket import Mocket Mocket.disable() diff --git a/mocket/mode.py b/mocket/mode.py index 3c0638e5..e1da7955 100644 --- a/mocket/mode.py +++ b/mocket/mode.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, ClassVar from mocket.exceptions import StrictMocketException +from mocket.mocket import Mocket if TYPE_CHECKING: # pragma: no cover from typing import NoReturn @@ -31,8 +32,6 @@ def is_allowed(self, location: str | tuple[str, int]) -> bool: @staticmethod def raise_not_allowed() -> NoReturn: - from mocket.mocket import Mocket - current_entries = [ (location, "\n ".join(map(str, entries))) for location, entries in Mocket._entries.items() diff --git a/mocket/plugins/httpretty/__init__.py b/mocket/plugins/httpretty/__init__.py index d5e41e30..fac61840 100644 --- a/mocket/plugins/httpretty/__init__.py +++ b/mocket/plugins/httpretty/__init__.py @@ -1,6 +1,7 @@ -from mocket import Mocket, mocketize +from mocket import mocketize from mocket.async_mocket import async_mocketize from mocket.compat import ENCODING +from mocket.mocket import Mocket from mocket.mockhttp import Entry as MocketHttpEntry from mocket.mockhttp import Request as MocketHttpRequest from mocket.mockhttp import Response as MocketHttpResponse diff --git a/mocket/socket.py b/mocket/socket.py index 06a61ba1..e4be00b6 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -16,6 +16,7 @@ true_urllib3_ssl_wrap_socket, ) from mocket.io import MocketSocketCore +from mocket.mocket import Mocket from mocket.mode import MocketMode from mocket.utils import hexdump, hexload @@ -93,8 +94,6 @@ def io(self): return self._io def fileno(self): - from mocket.mocket import Mocket - address = (self._host, self._port) r_fd, _ = Mocket.get_pair(address) if not r_fd: @@ -136,8 +135,6 @@ def getsockname(self): return socket.gethostbyname(self._address[0]), self._address[1] def getpeercert(self, *args, **kwargs): - from mocket.mocket import Mocket - if not (self._host and self._port): self._address = self._host, self._port = Mocket._address @@ -164,8 +161,6 @@ def write(self, data): return self.send(encode_to_bytes(data)) def connect(self, address): - from mocket.mocket import Mocket - self._address = self._host, self._port = address Mocket._address = address @@ -175,8 +170,6 @@ def makefile(self, mode="r", bufsize=-1): return self.io def get_entry(self, data): - from mocket.mocket import Mocket - return Mocket.get_entry(self._host, self._port, data) def sendall(self, data, entry=None, *args, **kwargs): @@ -213,8 +206,6 @@ def recv_into(self, buffer, buffersize=None, flags=None): return len(data) def recv(self, buffersize, flags=None): - from mocket.mocket import Mocket - r_fd, _ = Mocket.get_pair((self._host, self._port)) if r_fd: return os.read(r_fd, buffersize) @@ -228,8 +219,6 @@ def recv(self, buffersize, flags=None): raise exc def true_sendall(self, data, *args, **kwargs): - from mocket.mocket import Mocket - if not MocketMode().is_allowed((self._host, self._port)): MocketMode.raise_not_allowed() @@ -244,7 +233,8 @@ def true_sendall(self, data, *args, **kwargs): if Mocket.get_truesocket_recording_dir(): path = os.path.join( - Mocket.get_truesocket_recording_dir(), Mocket.get_namespace() + ".json" + Mocket.get_truesocket_recording_dir(), + Mocket.get_namespace() + ".json", ) # check if there's already a recorded session dumped to a JSON file try: @@ -317,8 +307,6 @@ def true_sendall(self, data, *args, **kwargs): return encoded_response def send(self, data, *args, **kwargs): # pragma: no cover - from mocket.mocket import Mocket - entry = self.get_entry(data) if not entry or (entry and self._entry != entry): kwargs["entry"] = entry diff --git a/mocket/types.py b/mocket/types.py new file mode 100644 index 00000000..61b7a4d5 --- /dev/null +++ b/mocket/types.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from typing import Tuple + +Address = Tuple[str, int] diff --git a/tests/test_socket.py b/tests/test_socket.py index 8a6e65ad..112a9089 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -2,7 +2,7 @@ import pytest -from mocket.mocket import MocketSocket +from mocket.socket import MocketSocket @pytest.mark.parametrize("blocking", (False, True))