Skip to content

Commit

Permalink
refactor: introduce state-object
Browse files Browse the repository at this point in the history
refactor: Mocket - add typing and get rid of cyclic import
  • Loading branch information
betaboon committed Nov 17, 2024
1 parent d6b3bb1 commit 6291c31
Show file tree
Hide file tree
Showing 9 changed files with 49 additions and 57 deletions.
3 changes: 1 addition & 2 deletions mocket/entry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections.abc

from mocket.compat import encode_to_bytes
from mocket.mocket import Mocket


class MocketEntry:
Expand Down Expand Up @@ -41,8 +42,6 @@ def can_handle(data):
return True

def collect(self, data):
from mocket import Mocket

req = self.request_cls(data)
Mocket.collect(req)

Expand Down
4 changes: 2 additions & 2 deletions mocket/io.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import io
import os

from mocket.mocket import Mocket


class MocketSocketCore(io.BytesIO):
def __init__(self, address) -> None:
self._address = address
super().__init__()

def write(self, content):
from mocket import Mocket

super().write(content)

_, w_fd = Mocket.get_pair(self._address)
Expand Down
64 changes: 33 additions & 31 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,68 @@
from __future__ import annotations

import collections
import itertools
import os
from typing import Optional, Tuple
from typing import TYPE_CHECKING, ClassVar

import mocket.inject

# NOTE this is here for backwards-compat to keep old import-paths working
from mocket.socket import MocketSocket as MocketSocket
# from mocket.socket import MocketSocket as MocketSocket

if TYPE_CHECKING:
from mocket.entry import MocketEntry
from mocket.types import Address


class Mocket:
_socket_pairs = {}
_address = (None, None)
_entries = collections.defaultdict(list)
_requests = []
_namespace = str(id(_entries))
_truesocket_recording_dir = None
_socket_pairs: ClassVar[dict[Address, tuple[int, int]]] = {}
_address: ClassVar[Address] = (None, None)
_entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list)
_requests: ClassVar[list] = []
_namespace: ClassVar[str] = str(id(_entries))
_truesocket_recording_dir: ClassVar[str | None] = None

enable = mocket.inject.enable
disable = mocket.inject.disable

@classmethod
def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]:
def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]:
"""
Given the id() of the caller, return a pair of file descriptors
as a tuple of two integers: (<read_fd>, <write_fd>)
"""
return cls._socket_pairs.get(address, (None, None))

@classmethod
def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None:
def set_pair(cls, address: Address, pair: tuple[int, int]) -> None:
"""
Store a pair of file descriptors under the key `id_`
as a tuple of two integers: (<read_fd>, <write_fd>)
"""
cls._socket_pairs[address] = pair

@classmethod
def register(cls, *entries):
def register(cls, *entries: MocketEntry) -> None:
for entry in entries:
cls._entries[entry.location].append(entry)

@classmethod
def get_entry(cls, host, port, data):
host = host or Mocket._address[0]
port = port or Mocket._address[1]
def get_entry(cls, host: str, port: int, data) -> MocketEntry | None:
host = host or cls._address[0]
port = port or cls._address[1]
entries = cls._entries.get((host, port), [])
for entry in entries:
if entry.can_handle(data):
return entry
return None

@classmethod
def collect(cls, data):
cls.request_list().append(data)
def collect(cls, data) -> None:
cls._requests.append(data)

@classmethod
def reset(cls):
def reset(cls) -> None:
for r_fd, w_fd in cls._socket_pairs.values():
os.close(r_fd)
os.close(w_fd)
Expand All @@ -63,39 +73,31 @@ def reset(cls):
@classmethod
def last_request(cls):
if cls.has_requests():
return cls.request_list()[-1]
return cls._requests[-1]

@classmethod
def request_list(cls):
return cls._requests

@classmethod
def remove_last_request(cls):
def remove_last_request(cls) -> None:
if cls.has_requests():
del cls._requests[-1]

@classmethod
def has_requests(cls):
def has_requests(cls) -> bool:
return bool(cls.request_list())

@classmethod
def get_namespace(cls):
def get_namespace(cls) -> str:
return cls._namespace

@staticmethod
def enable(namespace=None, truesocket_recording_dir=None):
mocket.inject.enable(namespace, truesocket_recording_dir)

@staticmethod
def disable():
mocket.inject.disable()

@classmethod
def get_truesocket_recording_dir(cls):
def get_truesocket_recording_dir(cls) -> str | None:
return cls._truesocket_recording_dir

@classmethod
def assert_fail_if_entries_not_served(cls):
def assert_fail_if_entries_not_served(cls) -> None:
"""Mocket checks that all entries have been served at least once."""
if not all(entry._served for entry in itertools.chain(*cls._entries.values())):
raise AssertionError("Some Mocket entries have not been served")
4 changes: 1 addition & 3 deletions mocket/mocketizer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from mocket.mocket import Mocket
from mocket.mode import MocketMode
from mocket.utils import get_mocketize

Expand All @@ -23,8 +24,6 @@ def __init__(
)

def enter(self):
from mocket import Mocket

Mocket.enable(
namespace=self.namespace,
truesocket_recording_dir=self.truesocket_recording_dir,
Expand All @@ -39,7 +38,6 @@ def __enter__(self):
def exit(self):
if self.instance:
self.check_and_call("mocketize_teardown")
from mocket import Mocket

Mocket.disable()

Expand Down
3 changes: 1 addition & 2 deletions mocket/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Any, ClassVar

from mocket.exceptions import StrictMocketException
from mocket.mocket import Mocket

if TYPE_CHECKING: # pragma: no cover
from typing import NoReturn
Expand Down Expand Up @@ -31,8 +32,6 @@ def is_allowed(self, location: str | tuple[str, int]) -> bool:

@staticmethod
def raise_not_allowed() -> NoReturn:
from mocket.mocket import Mocket

current_entries = [
(location, "\n ".join(map(str, entries)))
for location, entries in Mocket._entries.items()
Expand Down
3 changes: 2 additions & 1 deletion mocket/plugins/httpretty/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from mocket import Mocket, mocketize
from mocket import mocketize
from mocket.async_mocket import async_mocketize
from mocket.compat import ENCODING
from mocket.mocket import Mocket
from mocket.mockhttp import Entry as MocketHttpEntry
from mocket.mockhttp import Request as MocketHttpRequest
from mocket.mockhttp import Response as MocketHttpResponse
Expand Down
18 changes: 3 additions & 15 deletions mocket/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
true_urllib3_ssl_wrap_socket,
)
from mocket.io import MocketSocketCore
from mocket.mocket import Mocket
from mocket.mode import MocketMode
from mocket.utils import hexdump, hexload

Expand Down Expand Up @@ -93,8 +94,6 @@ def io(self):
return self._io

def fileno(self):
from mocket.mocket import Mocket

address = (self._host, self._port)
r_fd, _ = Mocket.get_pair(address)
if not r_fd:
Expand Down Expand Up @@ -136,8 +135,6 @@ def getsockname(self):
return socket.gethostbyname(self._address[0]), self._address[1]

def getpeercert(self, *args, **kwargs):
from mocket.mocket import Mocket

if not (self._host and self._port):
self._address = self._host, self._port = Mocket._address

Expand All @@ -164,8 +161,6 @@ def write(self, data):
return self.send(encode_to_bytes(data))

def connect(self, address):
from mocket.mocket import Mocket

self._address = self._host, self._port = address
Mocket._address = address

Expand All @@ -175,8 +170,6 @@ def makefile(self, mode="r", bufsize=-1):
return self.io

def get_entry(self, data):
from mocket.mocket import Mocket

return Mocket.get_entry(self._host, self._port, data)

def sendall(self, data, entry=None, *args, **kwargs):
Expand Down Expand Up @@ -213,8 +206,6 @@ def recv_into(self, buffer, buffersize=None, flags=None):
return len(data)

def recv(self, buffersize, flags=None):
from mocket.mocket import Mocket

r_fd, _ = Mocket.get_pair((self._host, self._port))
if r_fd:
return os.read(r_fd, buffersize)
Expand All @@ -228,8 +219,6 @@ def recv(self, buffersize, flags=None):
raise exc

def true_sendall(self, data, *args, **kwargs):
from mocket.mocket import Mocket

if not MocketMode().is_allowed((self._host, self._port)):
MocketMode.raise_not_allowed()

Expand All @@ -244,7 +233,8 @@ def true_sendall(self, data, *args, **kwargs):

if Mocket.get_truesocket_recording_dir():
path = os.path.join(
Mocket.get_truesocket_recording_dir(), Mocket.get_namespace() + ".json"
Mocket.get_truesocket_recording_dir(),
Mocket.get_namespace() + ".json",
)
# check if there's already a recorded session dumped to a JSON file
try:
Expand Down Expand Up @@ -317,8 +307,6 @@ def true_sendall(self, data, *args, **kwargs):
return encoded_response

def send(self, data, *args, **kwargs): # pragma: no cover
from mocket.mocket import Mocket

entry = self.get_entry(data)
if not entry or (entry and self._entry != entry):
kwargs["entry"] = entry
Expand Down
5 changes: 5 additions & 0 deletions mocket/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

from typing import Tuple

Address = Tuple[str, int]
2 changes: 1 addition & 1 deletion tests/test_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from mocket.mocket import MocketSocket
from mocket.socket import MocketSocket


@pytest.mark.parametrize("blocking", (False, True))
Expand Down

0 comments on commit 6291c31

Please sign in to comment.