diff --git a/mocket/__init__.py b/mocket/__init__.py index f72917a0..d64cb11d 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -1,5 +1,8 @@ from mocket.async_mocket import async_mocketize -from mocket.mocket import FakeSSLContext, Mocket, MocketEntry, Mocketizer, mocketize +from mocket.entry import MocketEntry +from mocket.mocket import Mocket +from mocket.mocketizer import Mocketizer, mocketize +from mocket.ssl import FakeSSLContext __all__ = ( "async_mocketize", diff --git a/mocket/async_mocket.py b/mocket/async_mocket.py index c0f77253..709d225f 100644 --- a/mocket/async_mocket.py +++ b/mocket/async_mocket.py @@ -1,4 +1,4 @@ -from mocket.mocket import Mocketizer +from mocket.mocketizer import Mocketizer from mocket.utils import get_mocketize diff --git a/mocket/entry.py b/mocket/entry.py new file mode 100644 index 00000000..8fa28bc7 --- /dev/null +++ b/mocket/entry.py @@ -0,0 +1,59 @@ +import collections.abc + +from mocket.compat import encode_to_bytes + + +class MocketEntry: + class Response(bytes): + @property + def data(self): + return self + + response_index = 0 + request_cls = bytes + response_cls = Response + responses = None + _served = None + + def __init__(self, location, responses): + self._served = False + self.location = location + + if not isinstance(responses, collections.abc.Iterable): + responses = [responses] + + if not responses: + self.responses = [self.response_cls(encode_to_bytes(""))] + else: + self.responses = [] + for r in responses: + if not isinstance(r, BaseException) and not getattr(r, "data", False): + if isinstance(r, str): + r = encode_to_bytes(r) + r = self.response_cls(r) + self.responses.append(r) + + def __repr__(self): + return f"{self.__class__.__name__}(location={self.location})" + + @staticmethod + def can_handle(data): + return True + + def collect(self, data): + from mocket import Mocket + + req = self.request_cls(data) + Mocket.collect(req) + + def get_response(self): + response = self.responses[self.response_index] + if self.response_index < len(self.responses) - 1: + self.response_index += 1 + + self._served = True + + if isinstance(response, BaseException): + raise response + + return response.data diff --git a/mocket/io.py b/mocket/io.py new file mode 100644 index 00000000..45bb8272 --- /dev/null +++ b/mocket/io.py @@ -0,0 +1,17 @@ +import io +import os + + +class MocketSocketCore(io.BytesIO): + def __init__(self, address) -> None: + self._address = address + super().__init__() + + def write(self, content): + from mocket import Mocket + + super().write(content) + + _, w_fd = Mocket.get_pair(self._address) + if w_fd: + os.write(w_fd, content) diff --git a/mocket/mocket.py b/mocket/mocket.py index 81a42bfb..6bb0e566 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -1,16 +1,8 @@ import collections -import collections.abc as collections_abc -import contextlib -import errno -import hashlib import itertools -import json import os -import select import socket import ssl -from datetime import datetime, timedelta -from json.decoder import JSONDecodeError from typing import Optional, Tuple import urllib3 @@ -23,22 +15,8 @@ urllib3_wrap_socket = None -from mocket.compat import decode_from_bytes, encode_to_bytes -from mocket.utils import ( - MocketMode, - MocketSocketCore, - get_mocketize, - hexdump, - hexload, -) - -xxh32 = None -try: - from xxhash import xxh32 -except ImportError: # pragma: no cover - with contextlib.suppress(ImportError): - from xxhash_cffi import xxh32 -hasher = xxh32 or hashlib.md5 +from mocket.socket import MocketSocket, create_connection, socketpair +from mocket.ssl import FakeSSLContext try: # pragma: no cover from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3 @@ -64,364 +42,6 @@ true_urllib3_match_hostname = urllib3_match_hostname -class SuperFakeSSLContext: - """For Python 3.6 and newer.""" - - class FakeSetter(int): - def __set__(self, *args): - pass - - minimum_version = FakeSetter() - options = FakeSetter() - verify_mode = FakeSetter() - verify_flags = FakeSetter() - - -class FakeSSLContext(SuperFakeSSLContext): - DUMMY_METHODS = ( - "load_default_certs", - "load_verify_locations", - "set_alpn_protocols", - "set_ciphers", - "set_default_verify_paths", - ) - sock = None - post_handshake_auth = None - _check_hostname = False - - @property - def check_hostname(self): - return self._check_hostname - - @check_hostname.setter - def check_hostname(self, _): - self._check_hostname = False - - def __init__(self, *args, **kwargs): - self._set_dummy_methods() - - def _set_dummy_methods(self): - def dummy_method(*args, **kwargs): - pass - - for m in self.DUMMY_METHODS: - setattr(self, m, dummy_method) - - @staticmethod - def wrap_socket(sock, *args, **kwargs): - sock.kwargs = kwargs - sock._secure_socket = True - return sock - - @staticmethod - def wrap_bio(incoming, outcoming, *args, **kwargs): - ssl_obj = MocketSocket() - ssl_obj._host = kwargs["server_hostname"] - return ssl_obj - - -def create_connection(address, timeout=None, source_address=None): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) - if timeout: - s.settimeout(timeout) - s.connect(address) - return s - - -def socketpair(*args, **kwargs): - """Returns a real socketpair() used by asyncio loop for supporting calls made by fastapi and similar services.""" - import _socket - - return _socket.socketpair(*args, **kwargs) - - -def _hash_request(h, req): - return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest() - - -class MocketSocket: - timeout = None - _fd = None - family = None - type = None - proto = None - _host = None - _port = None - _address = None - cipher = lambda s: ("ADH", "AES256", "SHA") - compression = lambda s: ssl.OP_NO_COMPRESSION - _mode = None - _bufsize = None - _secure_socket = False - _did_handshake = False - _sent_non_empty_bytes = False - _io = None - - def __init__( - self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs - ): - self.true_socket = true_socket(family, type, proto) - self._buflen = 65536 - self._entry = None - self.family = int(family) - self.type = int(type) - self.proto = int(proto) - self._truesocket_recording_dir = None - self.kwargs = kwargs - - def __str__(self): - return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})" - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - @property - def io(self): - if self._io is None: - self._io = MocketSocketCore((self._host, self._port)) - return self._io - - def fileno(self): - address = (self._host, self._port) - r_fd, _ = Mocket.get_pair(address) - if not r_fd: - r_fd, w_fd = os.pipe() - Mocket.set_pair(address, (r_fd, w_fd)) - return r_fd - - def gettimeout(self): - return self.timeout - - def setsockopt(self, family, type, proto): - self.family = family - self.type = type - self.proto = proto - - if self.true_socket: - self.true_socket.setsockopt(family, type, proto) - - def settimeout(self, timeout): - self.timeout = timeout - - @staticmethod - def getsockopt(level, optname, buflen=None): - return socket.SOCK_STREAM - - def do_handshake(self): - self._did_handshake = True - - def getpeername(self): - return self._address - - def setblocking(self, block): - self.settimeout(None) if block else self.settimeout(0.0) - - def getblocking(self): - return self.gettimeout() is None - - def getsockname(self): - return socket.gethostbyname(self._address[0]), self._address[1] - - def getpeercert(self, *args, **kwargs): - if not (self._host and self._port): - self._address = self._host, self._port = Mocket._address - - now = datetime.now() - shift = now + timedelta(days=30 * 12) - return { - "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), - "subjectAltName": ( - ("DNS", f"*.{self._host}"), - ("DNS", self._host), - ("DNS", "*"), - ), - "subject": ( - (("organizationName", f"*.{self._host}"),), - (("organizationalUnitName", "Domain Control Validated"),), - (("commonName", f"*.{self._host}"),), - ), - } - - def unwrap(self): - return self - - def write(self, data): - return self.send(encode_to_bytes(data)) - - def connect(self, address): - self._address = self._host, self._port = address - Mocket._address = address - - def makefile(self, mode="r", bufsize=-1): - self._mode = mode - self._bufsize = bufsize - return self.io - - def get_entry(self, data): - return Mocket.get_entry(self._host, self._port, data) - - def sendall(self, data, entry=None, *args, **kwargs): - if entry is None: - entry = self.get_entry(data) - - if entry: - consume_response = entry.collect(data) - response = entry.get_response() if consume_response is not False else None - else: - response = self.true_sendall(data, *args, **kwargs) - - if response is not None: - self.io.seek(0) - self.io.write(response) - self.io.truncate() - self.io.seek(0) - - def read(self, buffersize): - rv = self.io.read(buffersize) - if rv: - self._sent_non_empty_bytes = True - if self._did_handshake and not self._sent_non_empty_bytes: - raise ssl.SSLWantReadError("The operation did not complete (read)") - return rv - - def recv_into(self, buffer, buffersize=None, flags=None): - if hasattr(buffer, "write"): - return buffer.write(self.read(buffersize)) - # buffer is a memoryview - data = self.read(buffersize) - if data: - buffer[: len(data)] = data - return len(data) - - def recv(self, buffersize, flags=None): - r_fd, _ = Mocket.get_pair((self._host, self._port)) - if r_fd: - return os.read(r_fd, buffersize) - data = self.read(buffersize) - if data: - return data - # used by Redis mock - exc = BlockingIOError() - exc.errno = errno.EWOULDBLOCK - exc.args = (0,) - raise exc - - def true_sendall(self, data, *args, **kwargs): - if not MocketMode().is_allowed((self._host, self._port)): - MocketMode.raise_not_allowed() - - req = decode_from_bytes(data) - # make request unique again - req_signature = _hash_request(hasher, req) - # port should be always a string - port = str(self._port) - - # prepare responses dictionary - responses = {} - - if Mocket.get_truesocket_recording_dir(): - path = os.path.join( - Mocket.get_truesocket_recording_dir(), Mocket.get_namespace() + ".json" - ) - # check if there's already a recorded session dumped to a JSON file - try: - with open(path) as f: - responses = json.load(f) - # if not, create a new dictionary - except (FileNotFoundError, JSONDecodeError): - pass - - try: - try: - response_dict = responses[self._host][port][req_signature] - except KeyError: - if hasher is not hashlib.md5: - # Fallback for backwards compatibility - req_signature = _hash_request(hashlib.md5, req) - response_dict = responses[self._host][port][req_signature] - else: - raise - except KeyError: - # preventing next KeyError exceptions - responses.setdefault(self._host, {}) - responses[self._host].setdefault(port, {}) - responses[self._host][port].setdefault(req_signature, {}) - response_dict = responses[self._host][port][req_signature] - - # try to get the response from the dictionary - try: - encoded_response = hexload(response_dict["response"]) - # if not available, call the real sendall - except KeyError: - host, port = self._host, self._port - host = true_gethostbyname(host) - - if isinstance(self.true_socket, true_socket) and self._secure_socket: - self.true_socket = true_urllib3_ssl_wrap_socket( - self.true_socket, - **self.kwargs, - ) - - with contextlib.suppress(OSError, ValueError): - # already connected - self.true_socket.connect((host, port)) - self.true_socket.sendall(data, *args, **kwargs) - encoded_response = b"" - # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12 - while True: - more_to_read = select.select([self.true_socket], [], [], 0.1)[0] - if not more_to_read and encoded_response: - break - new_content = self.true_socket.recv(self._buflen) - if not new_content: - break - encoded_response += new_content - - # dump the resulting dictionary to a JSON file - if Mocket.get_truesocket_recording_dir(): - # update the dictionary with request and response lines - response_dict["request"] = req - response_dict["response"] = hexdump(encoded_response) - - with open(path, mode="w") as f: - f.write( - decode_from_bytes( - json.dumps(responses, indent=4, sort_keys=True) - ) - ) - - # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO - return encoded_response - - def send(self, data, *args, **kwargs): # pragma: no cover - entry = self.get_entry(data) - if not entry or (entry and self._entry != entry): - kwargs["entry"] = entry - self.sendall(data, *args, **kwargs) - else: - req = Mocket.last_request() - if hasattr(req, "add_data"): - req.add_data(data) - self._entry = entry - return len(data) - - def close(self): - if self.true_socket and not self.true_socket._closed: - self.true_socket.close() - self._fd = None - - def __getattr__(self, name): - """Do nothing catchall function, for methods like shutdown()""" - - def do_nothing(*args, **kwargs): - pass - - return do_nothing - - class Mocket: _socket_pairs = {} _address = (None, None) @@ -589,148 +209,3 @@ def assert_fail_if_entries_not_served(cls): """Mocket checks that all entries have been served at least once.""" if not all(entry._served for entry in itertools.chain(*cls._entries.values())): raise AssertionError("Some Mocket entries have not been served") - - -class MocketEntry: - class Response(bytes): - @property - def data(self): - return self - - response_index = 0 - request_cls = bytes - response_cls = Response - responses = None - _served = None - - def __init__(self, location, responses): - self._served = False - self.location = location - - if not isinstance(responses, collections_abc.Iterable): - responses = [responses] - - if not responses: - self.responses = [self.response_cls(encode_to_bytes(""))] - else: - self.responses = [] - for r in responses: - if not isinstance(r, BaseException) and not getattr(r, "data", False): - if isinstance(r, str): - r = encode_to_bytes(r) - r = self.response_cls(r) - self.responses.append(r) - - def __repr__(self): - return f"{self.__class__.__name__}(location={self.location})" - - @staticmethod - def can_handle(data): - return True - - def collect(self, data): - req = self.request_cls(data) - Mocket.collect(req) - - def get_response(self): - response = self.responses[self.response_index] - if self.response_index < len(self.responses) - 1: - self.response_index += 1 - - self._served = True - - if isinstance(response, BaseException): - raise response - - return response.data - - -class Mocketizer: - def __init__( - self, - instance=None, - namespace=None, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - ): - self.instance = instance - self.truesocket_recording_dir = truesocket_recording_dir - self.namespace = namespace or str(id(self)) - MocketMode().STRICT = strict_mode - if strict_mode: - MocketMode().STRICT_ALLOWED = strict_mode_allowed or [] - elif strict_mode_allowed: - raise ValueError( - "Allowed locations are only accepted when STRICT mode is active." - ) - - def enter(self): - Mocket.enable( - namespace=self.namespace, - truesocket_recording_dir=self.truesocket_recording_dir, - ) - if self.instance: - self.check_and_call("mocketize_setup") - - def __enter__(self): - self.enter() - return self - - def exit(self): - if self.instance: - self.check_and_call("mocketize_teardown") - Mocket.disable() - - def __exit__(self, type, value, tb): - self.exit() - - async def __aenter__(self, *args, **kwargs): - self.enter() - return self - - async def __aexit__(self, *args, **kwargs): - self.exit() - - def check_and_call(self, method_name): - method = getattr(self.instance, method_name, None) - if callable(method): - method() - - @staticmethod - def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args): - instance = args[0] if args else None - namespace = None - if truesocket_recording_dir: - namespace = ".".join( - ( - instance.__class__.__module__, - instance.__class__.__name__, - test.__name__, - ) - ) - - return Mocketizer( - instance, - namespace=namespace, - truesocket_recording_dir=truesocket_recording_dir, - strict_mode=strict_mode, - strict_mode_allowed=strict_mode_allowed, - ) - - -def wrapper( - test, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - *args, - **kwargs, -): - with Mocketizer.factory( - test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args - ): - return test(*args, **kwargs) - - -mocketize = get_mocketize(wrapper_=wrapper) diff --git a/mocket/mocketizer.py b/mocket/mocketizer.py new file mode 100644 index 00000000..5a988c77 --- /dev/null +++ b/mocket/mocketizer.py @@ -0,0 +1,97 @@ +from mocket.mode import MocketMode +from mocket.utils import get_mocketize + + +class Mocketizer: + def __init__( + self, + instance=None, + namespace=None, + truesocket_recording_dir=None, + strict_mode=False, + strict_mode_allowed=None, + ): + self.instance = instance + self.truesocket_recording_dir = truesocket_recording_dir + self.namespace = namespace or str(id(self)) + MocketMode().STRICT = strict_mode + if strict_mode: + MocketMode().STRICT_ALLOWED = strict_mode_allowed or [] + elif strict_mode_allowed: + raise ValueError( + "Allowed locations are only accepted when STRICT mode is active." + ) + + def enter(self): + from mocket import Mocket + + Mocket.enable( + namespace=self.namespace, + truesocket_recording_dir=self.truesocket_recording_dir, + ) + if self.instance: + self.check_and_call("mocketize_setup") + + def __enter__(self): + self.enter() + return self + + def exit(self): + if self.instance: + self.check_and_call("mocketize_teardown") + from mocket import Mocket + + Mocket.disable() + + def __exit__(self, type, value, tb): + self.exit() + + async def __aenter__(self, *args, **kwargs): + self.enter() + return self + + async def __aexit__(self, *args, **kwargs): + self.exit() + + def check_and_call(self, method_name): + method = getattr(self.instance, method_name, None) + if callable(method): + method() + + @staticmethod + def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args): + instance = args[0] if args else None + namespace = None + if truesocket_recording_dir: + namespace = ".".join( + ( + instance.__class__.__module__, + instance.__class__.__name__, + test.__name__, + ) + ) + + return Mocketizer( + instance, + namespace=namespace, + truesocket_recording_dir=truesocket_recording_dir, + strict_mode=strict_mode, + strict_mode_allowed=strict_mode_allowed, + ) + + +def wrapper( + test, + truesocket_recording_dir=None, + strict_mode=False, + strict_mode_allowed=None, + *args, + **kwargs, +): + with Mocketizer.factory( + test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args + ): + return test(*args, **kwargs) + + +mocketize = get_mocketize(wrapper_=wrapper) diff --git a/mocket/mockhttp.py b/mocket/mockhttp.py index beb312c0..245a11af 100644 --- a/mocket/mockhttp.py +++ b/mocket/mockhttp.py @@ -8,7 +8,8 @@ from h11 import Request as H11Request from mocket.compat import ENCODING, decode_from_bytes, do_the_magic, encode_to_bytes -from mocket.mocket import Mocket, MocketEntry +from mocket.entry import MocketEntry +from mocket.mocket import Mocket STATUS = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()} CRLF = "\r\n" diff --git a/mocket/mockredis.py b/mocket/mockredis.py index 4ed69e1f..fc386e2d 100644 --- a/mocket/mockredis.py +++ b/mocket/mockredis.py @@ -5,7 +5,8 @@ encode_to_bytes, shsplit, ) -from mocket.mocket import Mocket, MocketEntry +from mocket.entry import MocketEntry +from mocket.mocket import Mocket class Request: diff --git a/mocket/mode.py b/mocket/mode.py new file mode 100644 index 00000000..3c0638e5 --- /dev/null +++ b/mocket/mode.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar + +from mocket.exceptions import StrictMocketException + +if TYPE_CHECKING: # pragma: no cover + from typing import NoReturn + + +class MocketMode: + __shared_state: ClassVar[dict[str, Any]] = {} + STRICT: ClassVar = None + STRICT_ALLOWED: ClassVar = None + + def __init__(self) -> None: + self.__dict__ = self.__shared_state + + def is_allowed(self, location: str | tuple[str, int]) -> bool: + """ + Checks if (`host`, `port`) or at least `host` + are allowed locations to perform real `socket` calls + """ + if not self.STRICT: + return True + + host_allowed = False + if isinstance(location, tuple): + host_allowed = location[0] in self.STRICT_ALLOWED + return host_allowed or location in self.STRICT_ALLOWED + + @staticmethod + def raise_not_allowed() -> NoReturn: + from mocket.mocket import Mocket + + current_entries = [ + (location, "\n ".join(map(str, entries))) + for location, entries in Mocket._entries.items() + ] + formatted_entries = "\n".join( + [f" {location}:\n {entries}" for location, entries in current_entries] + ) + raise StrictMocketException( + "Mocket tried to use the real `socket` module while STRICT mode was active.\n" + f"Registered entries:\n{formatted_entries}" + ) diff --git a/mocket/socket.py b/mocket/socket.py new file mode 100644 index 00000000..3a971af5 --- /dev/null +++ b/mocket/socket.py @@ -0,0 +1,346 @@ +import contextlib +import errno +import hashlib +import json +import os +import select +import socket +import ssl +from datetime import datetime, timedelta +from json.decoder import JSONDecodeError + +from mocket.compat import decode_from_bytes, encode_to_bytes +from mocket.io import MocketSocketCore +from mocket.mode import MocketMode +from mocket.utils import hexdump, hexload + +xxh32 = None +try: + from xxhash import xxh32 +except ImportError: # pragma: no cover + with contextlib.suppress(ImportError): + from xxhash_cffi import xxh32 +hasher = xxh32 or hashlib.md5 + + +def create_connection(address, timeout=None, source_address=None): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) + if timeout: + s.settimeout(timeout) + s.connect(address) + return s + + +def socketpair(*args, **kwargs): + """Returns a real socketpair() used by asyncio loop for supporting calls made by fastapi and similar services.""" + import _socket + + return _socket.socketpair(*args, **kwargs) + + +def _hash_request(h, req): + return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest() + + +class MocketSocket: + timeout = None + _fd = None + family = None + type = None + proto = None + _host = None + _port = None + _address = None + cipher = lambda s: ("ADH", "AES256", "SHA") + compression = lambda s: ssl.OP_NO_COMPRESSION + _mode = None + _bufsize = None + _secure_socket = False + _did_handshake = False + _sent_non_empty_bytes = False + _io = None + + def __init__( + self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs + ): + from mocket.mocket import true_socket + + self.true_socket = true_socket(family, type, proto) + self._buflen = 65536 + self._entry = None + self.family = int(family) + self.type = int(type) + self.proto = int(proto) + self._truesocket_recording_dir = None + self.kwargs = kwargs + + def __str__(self): + return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})" + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + @property + def io(self): + if self._io is None: + self._io = MocketSocketCore((self._host, self._port)) + return self._io + + def fileno(self): + from mocket.mocket import Mocket + + address = (self._host, self._port) + r_fd, _ = Mocket.get_pair(address) + if not r_fd: + r_fd, w_fd = os.pipe() + Mocket.set_pair(address, (r_fd, w_fd)) + return r_fd + + def gettimeout(self): + return self.timeout + + def setsockopt(self, family, type, proto): + self.family = family + self.type = type + self.proto = proto + + if self.true_socket: + self.true_socket.setsockopt(family, type, proto) + + def settimeout(self, timeout): + self.timeout = timeout + + @staticmethod + def getsockopt(level, optname, buflen=None): + return socket.SOCK_STREAM + + def do_handshake(self): + self._did_handshake = True + + def getpeername(self): + return self._address + + def setblocking(self, block): + self.settimeout(None) if block else self.settimeout(0.0) + + def getblocking(self): + return self.gettimeout() is None + + def getsockname(self): + return socket.gethostbyname(self._address[0]), self._address[1] + + def getpeercert(self, *args, **kwargs): + from mocket.mocket import Mocket + + if not (self._host and self._port): + self._address = self._host, self._port = Mocket._address + + now = datetime.now() + shift = now + timedelta(days=30 * 12) + return { + "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), + "subjectAltName": ( + ("DNS", f"*.{self._host}"), + ("DNS", self._host), + ("DNS", "*"), + ), + "subject": ( + (("organizationName", f"*.{self._host}"),), + (("organizationalUnitName", "Domain Control Validated"),), + (("commonName", f"*.{self._host}"),), + ), + } + + def unwrap(self): + return self + + def write(self, data): + return self.send(encode_to_bytes(data)) + + def connect(self, address): + from mocket.mocket import Mocket + + self._address = self._host, self._port = address + Mocket._address = address + + def makefile(self, mode="r", bufsize=-1): + self._mode = mode + self._bufsize = bufsize + return self.io + + def get_entry(self, data): + from mocket.mocket import Mocket + + return Mocket.get_entry(self._host, self._port, data) + + def sendall(self, data, entry=None, *args, **kwargs): + if entry is None: + entry = self.get_entry(data) + + if entry: + consume_response = entry.collect(data) + response = entry.get_response() if consume_response is not False else None + else: + response = self.true_sendall(data, *args, **kwargs) + + if response is not None: + self.io.seek(0) + self.io.write(response) + self.io.truncate() + self.io.seek(0) + + def read(self, buffersize): + rv = self.io.read(buffersize) + if rv: + self._sent_non_empty_bytes = True + if self._did_handshake and not self._sent_non_empty_bytes: + raise ssl.SSLWantReadError("The operation did not complete (read)") + return rv + + def recv_into(self, buffer, buffersize=None, flags=None): + if hasattr(buffer, "write"): + return buffer.write(self.read(buffersize)) + # buffer is a memoryview + data = self.read(buffersize) + if data: + buffer[: len(data)] = data + return len(data) + + def recv(self, buffersize, flags=None): + from mocket.mocket import Mocket + + r_fd, _ = Mocket.get_pair((self._host, self._port)) + if r_fd: + return os.read(r_fd, buffersize) + data = self.read(buffersize) + if data: + return data + # used by Redis mock + exc = BlockingIOError() + exc.errno = errno.EWOULDBLOCK + exc.args = (0,) + raise exc + + def true_sendall(self, data, *args, **kwargs): + from mocket.mocket import ( + Mocket, + true_gethostbyname, + true_socket, + true_urllib3_ssl_wrap_socket, + ) + + if not MocketMode().is_allowed((self._host, self._port)): + MocketMode.raise_not_allowed() + + req = decode_from_bytes(data) + # make request unique again + req_signature = _hash_request(hasher, req) + # port should be always a string + port = str(self._port) + + # prepare responses dictionary + responses = {} + + if Mocket.get_truesocket_recording_dir(): + path = os.path.join( + Mocket.get_truesocket_recording_dir(), Mocket.get_namespace() + ".json" + ) + # check if there's already a recorded session dumped to a JSON file + try: + with open(path) as f: + responses = json.load(f) + # if not, create a new dictionary + except (FileNotFoundError, JSONDecodeError): + pass + + try: + try: + response_dict = responses[self._host][port][req_signature] + except KeyError: + if hasher is not hashlib.md5: + # Fallback for backwards compatibility + req_signature = _hash_request(hashlib.md5, req) + response_dict = responses[self._host][port][req_signature] + else: + raise + except KeyError: + # preventing next KeyError exceptions + responses.setdefault(self._host, {}) + responses[self._host].setdefault(port, {}) + responses[self._host][port].setdefault(req_signature, {}) + response_dict = responses[self._host][port][req_signature] + + # try to get the response from the dictionary + try: + encoded_response = hexload(response_dict["response"]) + # if not available, call the real sendall + except KeyError: + host, port = self._host, self._port + host = true_gethostbyname(host) + + if isinstance(self.true_socket, true_socket) and self._secure_socket: + self.true_socket = true_urllib3_ssl_wrap_socket( + self.true_socket, + **self.kwargs, + ) + + with contextlib.suppress(OSError, ValueError): + # already connected + self.true_socket.connect((host, port)) + self.true_socket.sendall(data, *args, **kwargs) + encoded_response = b"" + # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12 + while True: + more_to_read = select.select([self.true_socket], [], [], 0.1)[0] + if not more_to_read and encoded_response: + break + new_content = self.true_socket.recv(self._buflen) + if not new_content: + break + encoded_response += new_content + + # dump the resulting dictionary to a JSON file + if Mocket.get_truesocket_recording_dir(): + # update the dictionary with request and response lines + response_dict["request"] = req + response_dict["response"] = hexdump(encoded_response) + + with open(path, mode="w") as f: + f.write( + decode_from_bytes( + json.dumps(responses, indent=4, sort_keys=True) + ) + ) + + # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO + return encoded_response + + def send(self, data, *args, **kwargs): # pragma: no cover + from mocket.mocket import Mocket + + entry = self.get_entry(data) + if not entry or (entry and self._entry != entry): + kwargs["entry"] = entry + self.sendall(data, *args, **kwargs) + else: + req = Mocket.last_request() + if hasattr(req, "add_data"): + req.add_data(data) + self._entry = entry + return len(data) + + def close(self): + if self.true_socket and not self.true_socket._closed: + self.true_socket.close() + self._fd = None + + def __getattr__(self, name): + """Do nothing catchall function, for methods like shutdown()""" + + def do_nothing(*args, **kwargs): + pass + + return do_nothing diff --git a/mocket/ssl.py b/mocket/ssl.py new file mode 100644 index 00000000..e4ae44cf --- /dev/null +++ b/mocket/ssl.py @@ -0,0 +1,56 @@ +class SuperFakeSSLContext: + """For Python 3.6 and newer.""" + + class FakeSetter(int): + def __set__(self, *args): + pass + + minimum_version = FakeSetter() + options = FakeSetter() + verify_mode = FakeSetter() + verify_flags = FakeSetter() + + +class FakeSSLContext(SuperFakeSSLContext): + DUMMY_METHODS = ( + "load_default_certs", + "load_verify_locations", + "set_alpn_protocols", + "set_ciphers", + "set_default_verify_paths", + ) + sock = None + post_handshake_auth = None + _check_hostname = False + + @property + def check_hostname(self): + return self._check_hostname + + @check_hostname.setter + def check_hostname(self, _): + self._check_hostname = False + + def __init__(self, *args, **kwargs): + self._set_dummy_methods() + + def _set_dummy_methods(self): + def dummy_method(*args, **kwargs): + pass + + for m in self.DUMMY_METHODS: + setattr(self, m, dummy_method) + + @staticmethod + def wrap_socket(sock, *args, **kwargs): + sock.kwargs = kwargs + sock._secure_socket = True + return sock + + @staticmethod + def wrap_bio(incoming, outcoming, *args, **kwargs): + from mocket.socket import MocketSocket + + ssl_obj = MocketSocket() + ssl_obj._host = kwargs["server_hostname"] + return ssl_obj diff --git a/mocket/utils.py b/mocket/utils.py index 35cfcea8..f94b78f7 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -1,36 +1,20 @@ from __future__ import annotations import binascii -import io -import os import ssl -from typing import TYPE_CHECKING, Any, Callable, ClassVar +from typing import Callable from mocket.compat import decode_from_bytes, encode_to_bytes -from mocket.exceptions import StrictMocketException -if TYPE_CHECKING: # pragma: no cover - from typing import NoReturn +# NOTE this is here for backwards-compat to keep old import-paths working +from mocket.io import MocketSocketCore as MocketSocketCore +# NOTE this is here for backwards-compat to keep old import-paths working +from mocket.mode import MocketMode as MocketMode SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2 -class MocketSocketCore(io.BytesIO): - def __init__(self, address) -> None: - self._address = address - super().__init__() - - def write(self, content): - from mocket import Mocket - - super().write(content) - - _, w_fd = Mocket.get_pair(self._address) - if w_fd: - os.write(w_fd, content) - - def hexdump(binary_string: bytes) -> str: r""" >>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F")) @@ -58,41 +42,3 @@ def get_mocketize(wrapper_: Callable) -> Callable: wrapper_, kwsyntax=True, ) - - -class MocketMode: - __shared_state: ClassVar[dict[str, Any]] = {} - STRICT: ClassVar = None - STRICT_ALLOWED: ClassVar = None - - def __init__(self) -> None: - self.__dict__ = self.__shared_state - - def is_allowed(self, location: str | tuple[str, int]) -> bool: - """ - Checks if (`host`, `port`) or at least `host` - are allowed locations to perform real `socket` calls - """ - if not self.STRICT: - return True - - host_allowed = False - if isinstance(location, tuple): - host_allowed = location[0] in self.STRICT_ALLOWED - return host_allowed or location in self.STRICT_ALLOWED - - @staticmethod - def raise_not_allowed() -> NoReturn: - from mocket.mocket import Mocket - - current_entries = [ - (location, "\n ".join(map(str, entries))) - for location, entries in Mocket._entries.items() - ] - formatted_entries = "\n".join( - [f" {location}:\n {entries}" for location, entries in current_entries] - ) - raise StrictMocketException( - "Mocket tried to use the real `socket` module while STRICT mode was active.\n" - f"Registered entries:\n{formatted_entries}" - )