diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d519ef75..5bb32812 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -40,7 +40,7 @@ jobs: make services-up - name: Setup hostname run: | - export CONTAINER_ID=$(docker-compose ps -q proxy) + export CONTAINER_ID=$(docker compose ps -q proxy) export CONTAINER_IP=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' $CONTAINER_ID) echo "$CONTAINER_IP httpbin.local" | sudo tee -a /etc/hosts - name: Test diff --git a/README.rst b/README.rst index f6e027c7..d8732c62 100644 --- a/README.rst +++ b/README.rst @@ -169,6 +169,20 @@ NEW!!! Sometimes you just want your tests to fail when they attempt to use the n with pytest.raises(StrictMocketException): requests.get("https://duckduckgo.com/") +You can specify exceptions as a list of hosts or host-port pairs. + +.. code-block:: python + + with Mocketizer(strict_mode=True, strict_mode_allowed=["localhost", ("intake.ourmetrics.net", 443)]): + ... + + # OR + + @mocketize(strict_mode=True, strict_mode_allowed=["localhost", ("intake.ourmetrics.net", 443)]) + def test_get(): + ... + + How to be sure that all the Entry instances have been served? ============================================================= Add this instruction at the end of the test execution: diff --git a/mocket/async_mocket.py b/mocket/async_mocket.py index 5ebe7348..2970e0f4 100644 --- a/mocket/async_mocket.py +++ b/mocket/async_mocket.py @@ -3,9 +3,16 @@ async def wrapper( - test, truesocket_recording_dir=None, strict_mode=False, *args, **kwargs + test, + truesocket_recording_dir=None, + strict_mode=False, + strict_mode_allowed=None, + *args, + **kwargs, ): - async with Mocketizer.factory(test, truesocket_recording_dir, strict_mode, args): + async with Mocketizer.factory( + test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args + ): return await test(*args, **kwargs) diff --git a/mocket/mocket.py b/mocket/mocket.py index eeaa98e6..c2c065cf 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -22,7 +22,6 @@ urllib3_wrap_socket = None from .compat import basestring, byte_type, decode_from_bytes, encode_to_bytes, text_type -from .exceptions import StrictMocketException from .utils import ( SSL_PROTOCOL, MocketMode, @@ -333,8 +332,8 @@ def recv(self, buffersize, flags=None): raise exc def true_sendall(self, data, *args, **kwargs): - if MocketMode().STRICT: - raise StrictMocketException("Mocket tried to use the real `socket` module.") + if not MocketMode().is_allowed((self._host, self._port)): + MocketMode.raise_not_allowed() req = decode_from_bytes(data) # make request unique again @@ -642,6 +641,9 @@ def __init__(self, location, responses): r = self.response_cls(r) self.responses.append(r) + def __repr__(self): + return "{}(location={})".format(self.__class__.__name__, self.location) + @staticmethod def can_handle(data): return True @@ -670,11 +672,18 @@ def __init__( namespace=None, truesocket_recording_dir=None, strict_mode=False, + strict_mode_allowed=None, ): self.instance = instance self.truesocket_recording_dir = truesocket_recording_dir self.namespace = namespace or text_type(id(self)) MocketMode().STRICT = strict_mode + if strict_mode: + MocketMode().STRICT_ALLOWED = strict_mode_allowed or [] + elif strict_mode_allowed: + raise ValueError( + "Allowed locations are only accepted when STRICT mode is active." + ) def enter(self): Mocket.enable( @@ -709,7 +718,7 @@ def check_and_call(self, method_name): method() @staticmethod - def factory(test, truesocket_recording_dir, strict_mode, args): + def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args): instance = args[0] if args else None namespace = None if truesocket_recording_dir: @@ -726,11 +735,21 @@ def factory(test, truesocket_recording_dir, strict_mode, args): namespace=namespace, truesocket_recording_dir=truesocket_recording_dir, strict_mode=strict_mode, + strict_mode_allowed=strict_mode_allowed, ) -def wrapper(test, truesocket_recording_dir=None, strict_mode=False, *args, **kwargs): - with Mocketizer.factory(test, truesocket_recording_dir, strict_mode, args): +def wrapper( + test, + truesocket_recording_dir=None, + strict_mode=False, + strict_mode_allowed=None, + *args, + **kwargs, +): + with Mocketizer.factory( + test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args + ): return test(*args, **kwargs) diff --git a/mocket/mockhttp.py b/mocket/mockhttp.py index 8cb5cdc6..4ab3345d 100644 --- a/mocket/mockhttp.py +++ b/mocket/mockhttp.py @@ -65,9 +65,11 @@ def headers(self): @property def querystring(self): parts = self._protocol.url.split("?", 1) - if len(parts) == 2: - return parse_qs(unquote(parts[1]), keep_blank_values=True) - return {} + return ( + parse_qs(unquote(parts[1]), keep_blank_values=True) + if len(parts) == 2 + else {} + ) @property def body(self): @@ -175,6 +177,18 @@ def __init__(self, uri, method, responses, match_querystring=True): self._sent_data = b"" self._match_querystring = match_querystring + def __repr__(self): + return ( + "{}(method={!r}, schema={!r}, location={!r}, path={!r}, query={!r})".format( + self.__class__.__name__, + self.method, + self.schema, + self.location, + self.path, + self.query, + ) + ) + def collect(self, data): consume_response = True diff --git a/mocket/plugins/httpretty/__init__.py b/mocket/plugins/httpretty/__init__.py index 5aaebeb1..bf1e7e21 100644 --- a/mocket/plugins/httpretty/__init__.py +++ b/mocket/plugins/httpretty/__init__.py @@ -70,9 +70,8 @@ def register_uri( responses=None, match_querystring=False, priority=0, - **headers + **headers, ): - headers = httprettifier_headers(headers) if adding_headers is not None: @@ -101,7 +100,6 @@ def force_headers(self): class MocketHTTPretty: - Response = Response def __getattr__(self, name): diff --git a/mocket/utils.py b/mocket/utils.py index 64a2c18e..2f17838b 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -2,8 +2,10 @@ import io import os import ssl +from typing import Tuple, Union from .compat import decode_from_bytes, encode_to_bytes +from .exceptions import StrictMocketException SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2 @@ -47,6 +49,33 @@ def get_mocketize(wrapper_): class MocketMode: __shared_state = {} STRICT = None + STRICT_ALLOWED = None def __init__(self): self.__dict__ = self.__shared_state + + def is_allowed(self, location: Union[str, Tuple[str, int]]) -> bool: + """ + Checks if (`host`, `port`) or at least `host` + are allowed locationsto perform real `socket` calls + """ + if not self.STRICT: + return True + host, _ = location + return location in self.STRICT_ALLOWED or host in self.STRICT_ALLOWED + + @staticmethod + def raise_not_allowed(): + from .mocket import Mocket + + current_entries = [ + (location, "\n ".join(map(str, entries))) + for location, entries in Mocket._entries.items() + ] + formatted_entries = "\n".join( + [f" {location}:\n {entries}" for location, entries in current_entries] + ) + raise StrictMocketException( + "Mocket tried to use the real `socket` module while STRICT mode was active.\n" + f"Registered entries:\n{formatted_entries}" + ) diff --git a/tests/main/test_mode.py b/tests/main/test_mode.py index b104589c..0d2d2e7c 100644 --- a/tests/main/test_mode.py +++ b/tests/main/test_mode.py @@ -3,6 +3,8 @@ from mocket import Mocketizer, mocketize from mocket.exceptions import StrictMocketException +from mocket.mockhttp import Entry, Response +from mocket.utils import MocketMode @mocketize(strict_mode=True) @@ -26,3 +28,44 @@ def test_intermittent_strict_mode(): with Mocketizer(strict_mode=False): requests.get(url) + + +@pytest.mark.skipif('os.getenv("SKIP_TRUE_HTTP", False)') +def test_strict_mode_exceptions(): + url = "http://httpbin.local/ip" + + with Mocketizer(strict_mode=True, strict_mode_allowed=["httpbin.local"]): + requests.get(url) + + with Mocketizer(strict_mode=True, strict_mode_allowed=[("httpbin.local", 80)]): + requests.get(url) + + +def test_strict_mode_error_message(): + url = "http://httpbin.local/ip" + + Entry.register(Entry.GET, "http://httpbin.local/user.agent", Response(status=404)) + + with Mocketizer(strict_mode=True): + with pytest.raises(StrictMocketException) as exc_info: + requests.get(url) + assert ( + str(exc_info.value) + == """ +Mocket tried to use the real `socket` module while STRICT mode was active. +Registered entries: + ('httpbin.local', 80): + Entry(method='GET', schema='http', location=('httpbin.local', 80), path='/user.agent', query='') +""".strip() + ) + + +def test_strict_mode_false_with_allowed_hosts(): + with pytest.raises(ValueError): + Mocketizer(strict_mode=False, strict_mode_allowed=["foobar.local"]) + + +def test_strict_mode_false_always_allowed(): + with Mocketizer(strict_mode=False): + assert MocketMode().is_allowed("foobar.com") + assert MocketMode().is_allowed(("foobar.com", 443))