diff --git a/mocket/async_mocket.py b/mocket/async_mocket.py index 936ec22d..2970e0f4 100644 --- a/mocket/async_mocket.py +++ b/mocket/async_mocket.py @@ -8,7 +8,7 @@ async def wrapper( strict_mode=False, strict_mode_allowed=None, *args, - **kwargs + **kwargs, ): async with Mocketizer.factory( test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args diff --git a/mocket/mocket.py b/mocket/mocket.py index 966ebc76..c2c065cf 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -22,7 +22,6 @@ urllib3_wrap_socket = None from .compat import basestring, byte_type, decode_from_bytes, encode_to_bytes, text_type -from .exceptions import StrictMocketException from .utils import ( SSL_PROTOCOL, MocketMode, @@ -333,22 +332,8 @@ def recv(self, buffersize, flags=None): raise exc def true_sendall(self, data, *args, **kwargs): - if MocketMode().STRICT: - if not MocketMode().allowed((self._host, self._port)): - current_entries = [ - (location, "\n ".join(map(str, entries))) - for location, entries in Mocket._entries.items() - ] - formatted_entries = "\n".join( - [ - f" {location}:\n {entries}" - for location, entries in current_entries - ] - ) - raise StrictMocketException( - "Mocket tried to use the real `socket` module while strict mode is active.\n" - f"Registered entries:\n{formatted_entries}" - ) + if not MocketMode().is_allowed((self._host, self._port)): + MocketMode.raise_not_allowed() req = decode_from_bytes(data) # make request unique again @@ -693,7 +678,12 @@ def __init__( self.truesocket_recording_dir = truesocket_recording_dir self.namespace = namespace or text_type(id(self)) MocketMode().STRICT = strict_mode - MocketMode().STRICT_ALLOWED = strict_mode_allowed + if strict_mode: + MocketMode().STRICT_ALLOWED = strict_mode_allowed or [] + elif strict_mode_allowed: + raise ValueError( + "Allowed locations are only accepted when STRICT mode is active." + ) def enter(self): Mocket.enable( @@ -755,7 +745,7 @@ def wrapper( strict_mode=False, strict_mode_allowed=None, *args, - **kwargs + **kwargs, ): with Mocketizer.factory( test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args diff --git a/mocket/mockhttp.py b/mocket/mockhttp.py index 60b134a1..4ab3345d 100644 --- a/mocket/mockhttp.py +++ b/mocket/mockhttp.py @@ -65,9 +65,11 @@ def headers(self): @property def querystring(self): parts = self._protocol.url.split("?", 1) - if len(parts) == 2: - return parse_qs(unquote(parts[1]), keep_blank_values=True) - return {} + return ( + parse_qs(unquote(parts[1]), keep_blank_values=True) + if len(parts) == 2 + else {} + ) @property def body(self): diff --git a/mocket/plugins/httpretty/__init__.py b/mocket/plugins/httpretty/__init__.py index 5aaebeb1..bf1e7e21 100644 --- a/mocket/plugins/httpretty/__init__.py +++ b/mocket/plugins/httpretty/__init__.py @@ -70,9 +70,8 @@ def register_uri( responses=None, match_querystring=False, priority=0, - **headers + **headers, ): - headers = httprettifier_headers(headers) if adding_headers is not None: @@ -101,7 +100,6 @@ def force_headers(self): class MocketHTTPretty: - Response = Response def __getattr__(self, name): diff --git a/mocket/utils.py b/mocket/utils.py index 88d7868c..2f17838b 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -2,8 +2,10 @@ import io import os import ssl +from typing import Tuple, Union from .compat import decode_from_bytes, encode_to_bytes +from .exceptions import StrictMocketException SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2 @@ -52,13 +54,28 @@ class MocketMode: def __init__(self): self.__dict__ = self.__shared_state - def allowed(self, location): - if not self.STRICT_ALLOWED: - return False - host, port = location - for allowed in self.STRICT_ALLOWED: - if isinstance(allowed, str) and host == allowed: - return True - elif location == allowed: - return True - return False + def is_allowed(self, location: Union[str, Tuple[str, int]]) -> bool: + """ + Checks if (`host`, `port`) or at least `host` + are allowed locationsto perform real `socket` calls + """ + if not self.STRICT: + return True + host, _ = location + return location in self.STRICT_ALLOWED or host in self.STRICT_ALLOWED + + @staticmethod + def raise_not_allowed(): + from .mocket import Mocket + + current_entries = [ + (location, "\n ".join(map(str, entries))) + for location, entries in Mocket._entries.items() + ] + formatted_entries = "\n".join( + [f" {location}:\n {entries}" for location, entries in current_entries] + ) + raise StrictMocketException( + "Mocket tried to use the real `socket` module while STRICT mode was active.\n" + f"Registered entries:\n{formatted_entries}" + ) diff --git a/tests/main/test_mode.py b/tests/main/test_mode.py index e3a909f3..0d2d2e7c 100644 --- a/tests/main/test_mode.py +++ b/tests/main/test_mode.py @@ -4,6 +4,7 @@ from mocket import Mocketizer, mocketize from mocket.exceptions import StrictMocketException from mocket.mockhttp import Entry, Response +from mocket.utils import MocketMode @mocketize(strict_mode=True) @@ -51,9 +52,20 @@ def test_strict_mode_error_message(): assert ( str(exc_info.value) == """ -Mocket tried to use the real `socket` module while strict mode is active. +Mocket tried to use the real `socket` module while STRICT mode was active. Registered entries: ('httpbin.local', 80): Entry(method='GET', schema='http', location=('httpbin.local', 80), path='/user.agent', query='') """.strip() ) + + +def test_strict_mode_false_with_allowed_hosts(): + with pytest.raises(ValueError): + Mocketizer(strict_mode=False, strict_mode_allowed=["foobar.local"]) + + +def test_strict_mode_false_always_allowed(): + with Mocketizer(strict_mode=False): + assert MocketMode().is_allowed("foobar.com") + assert MocketMode().is_allowed(("foobar.com", 443))