Skip to content

Commit

Permalink
refactor: migrate mocket.compat.entry to use mocket.core.entry
Browse files Browse the repository at this point in the history
  • Loading branch information
betaboon committed Dec 2, 2024
1 parent aaf4399 commit 0a94080
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 82 deletions.
3 changes: 2 additions & 1 deletion mocket/compat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from mocket.compat.entry import MocketEntry
from mocket.compat.entry import MocketEntry, Response
from mocket.core.ssl.context import MocketSSLContext as FakeSSLContext

__all__ = [
"FakeSSLContext",
"MocketEntry",
"Response",
]
88 changes: 33 additions & 55 deletions mocket/compat/entry.py
Original file line number Diff line number Diff line change
@@ -1,58 +1,36 @@
import collections.abc

from mocket.core.compat import encode_to_bytes
from mocket.core.mocket import Mocket


class MocketEntry:
class Response(bytes):
@property
def data(self):
return self

response_index = 0
request_cls = bytes
response_cls = Response
responses = None
_served = None

def __init__(self, location, responses):
self._served = False
self.location = location

if not isinstance(responses, collections.abc.Iterable):
from __future__ import annotations

from mocket.bytes import MocketBytesEntry, MocketBytesResponse
from mocket.core.types import Address


class Response(MocketBytesResponse):
def __init__(self, data: bytes | str | bool) -> None:
if isinstance(data, str):
data = data.encode()
elif isinstance(data, bool):
data = bytes(data)
self._data = data


class MocketEntry(MocketBytesEntry):
def __init__(
self,
location: Address,
responses: list[MocketBytesResponse | Exception | bytes | str | bool]
| MocketBytesResponse
| Exception
| bytes
| str
| bool,
) -> None:
if not isinstance(responses, list):
responses = [responses]

if not responses:
self.responses = [self.response_cls(encode_to_bytes(""))]
else:
self.responses = []
for r in responses:
if not isinstance(r, BaseException) and not getattr(r, "data", False):
if isinstance(r, str):
r = encode_to_bytes(r)
r = self.response_cls(r)
self.responses.append(r)

def __repr__(self):
return f"{self.__class__.__name__}(location={self.location})"

@staticmethod
def can_handle(data):
return True

def collect(self, data):
req = self.request_cls(data)
Mocket.collect(req)

def get_response(self):
response = self.responses[self.response_index]
if self.response_index < len(self.responses) - 1:
self.response_index += 1

self._served = True

if isinstance(response, BaseException):
raise response
_responses = []
for response in responses:
if not isinstance(response, (MocketBytesResponse, Exception)):
response = MocketBytesResponse(response)
_responses.append(response)

return response.data
super().__init__(address=location, responses=_responses)
38 changes: 14 additions & 24 deletions mocket/core/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,19 @@
import mocket.core.inject
from mocket.core.recording import MocketRecordStorage

# NOTE this is here for backwards-compat to keep old import-paths working
# from mocket.socket import MocketSocket as MocketSocket

if TYPE_CHECKING:
from mocket.compat.entry import MocketEntry
from mocket.core.entry import MocketBaseEntry
from mocket.core.entry import MocketBaseEntry, MocketBaseRequest
from mocket.core.types import Address


class Mocket:
_socket_pairs: ClassVar[dict[Address, tuple[int, int]]] = {}
_address: ClassVar[Address] = (None, None)
_entries: ClassVar[dict[Address, list[MocketEntry | MocketBaseEntry]]] = (
collections.defaultdict(list)
_entries: ClassVar[dict[Address, list[MocketBaseEntry]]] = collections.defaultdict(
list
)
_requests: ClassVar[list] = []
_requests: ClassVar[list[MocketBaseRequest]] = []
_last_entry: ClassVar[MocketBaseEntry | None] = None
_record_storage: ClassVar[MocketRecordStorage | None] = None

@classmethod
Expand Down Expand Up @@ -73,18 +70,12 @@ def set_pair(cls, address: Address, pair: tuple[int, int]) -> None:
cls._socket_pairs[address] = pair

@classmethod
def register(cls, *entries: MocketEntry | MocketBaseEntry) -> None:
def register(cls, *entries: MocketBaseEntry) -> None:
for entry in entries:
address = entry.location if hasattr(entry, "location") else entry.address
cls._entries[address].append(entry)
cls._entries[entry.address].append(entry)

@classmethod
def get_entry(
cls,
host: str,
port: int,
data: bytes,
) -> MocketEntry | MocketBaseEntry | None:
def get_entry(cls, host: str, port: int, data) -> MocketBaseEntry | None:
host = host or cls._address[0]
port = port or cls._address[1]
entries = cls._entries.get((host, port), [])
Expand All @@ -108,12 +99,13 @@ def reset(cls) -> None:
cls._record_storage = None

@classmethod
def last_request(cls):
def last_request(cls) -> MocketBaseRequest | None:
if cls.has_requests():
return cls._requests[-1]
return None

@classmethod
def request_list(cls):
def request_list(cls) -> list[MocketBaseRequest]:
return cls._requests

@classmethod
Expand All @@ -140,9 +132,7 @@ def get_truesocket_recording_dir(cls) -> str | None:
@classmethod
def assert_fail_if_entries_not_served(cls) -> None:
"""Mocket checks that all entries have been served at least once."""

def served(entry: MocketEntry | MocketBaseEntry) -> bool | None:
return entry._served if hasattr(entry, "_served") else entry.served_response

if not all(served(entry) for entry in itertools.chain(*cls._entries.values())):
if not all(
entry.served_response for entry in itertools.chain(*cls._entries.values())
):
raise AssertionError("Some Mocket entries have not been served")
4 changes: 2 additions & 2 deletions mocket/core/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from typing_extensions import Self

from mocket.compat.entry import MocketEntry
from mocket.core.entry import MocketBaseEntry
from mocket.core.io import MocketSocketIO
from mocket.core.mocket import Mocket
from mocket.core.mode import MocketMode
Expand Down Expand Up @@ -167,7 +167,7 @@ def connect(self, address: Address) -> None:
def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketIO:
return self.io

def get_entry(self, data: bytes) -> MocketEntry | None:
def get_entry(self, data: bytes) -> MocketBaseEntry | None:
return Mocket.get_entry(self._host, self._port, data)

def sendall(self, data, entry=None, *args, **kwargs):
Expand Down

0 comments on commit 0a94080

Please sign in to comment.