Skip to content

Commit

Permalink
Refactoring the new feature from @ento before merging it.
Browse files Browse the repository at this point in the history
  • Loading branch information
mindflayer committed Feb 5, 2024
1 parent d0189ae commit 1b730ef
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 37 deletions.
2 changes: 1 addition & 1 deletion mocket/async_mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ async def wrapper(
strict_mode=False,
strict_mode_allowed=None,
*args,
**kwargs
**kwargs,
):
async with Mocketizer.factory(
test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args
Expand Down
28 changes: 9 additions & 19 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
urllib3_wrap_socket = None

from .compat import basestring, byte_type, decode_from_bytes, encode_to_bytes, text_type
from .exceptions import StrictMocketException
from .utils import (
SSL_PROTOCOL,
MocketMode,
Expand Down Expand Up @@ -333,22 +332,8 @@ def recv(self, buffersize, flags=None):
raise exc

def true_sendall(self, data, *args, **kwargs):
if MocketMode().STRICT:
if not MocketMode().allowed((self._host, self._port)):
current_entries = [
(location, "\n ".join(map(str, entries)))
for location, entries in Mocket._entries.items()
]
formatted_entries = "\n".join(
[
f" {location}:\n {entries}"
for location, entries in current_entries
]
)
raise StrictMocketException(
"Mocket tried to use the real `socket` module while strict mode is active.\n"
f"Registered entries:\n{formatted_entries}"
)
if not MocketMode().is_allowed((self._host, self._port)):
MocketMode.raise_not_allowed()

req = decode_from_bytes(data)
# make request unique again
Expand Down Expand Up @@ -693,7 +678,12 @@ def __init__(
self.truesocket_recording_dir = truesocket_recording_dir
self.namespace = namespace or text_type(id(self))
MocketMode().STRICT = strict_mode
MocketMode().STRICT_ALLOWED = strict_mode_allowed
if strict_mode:
MocketMode().STRICT_ALLOWED = strict_mode_allowed or []
elif strict_mode_allowed:
raise ValueError(
"Allowed locations are only accepted when STRICT mode is active."
)

def enter(self):
Mocket.enable(
Expand Down Expand Up @@ -755,7 +745,7 @@ def wrapper(
strict_mode=False,
strict_mode_allowed=None,
*args,
**kwargs
**kwargs,
):
with Mocketizer.factory(
test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args
Expand Down
8 changes: 5 additions & 3 deletions mocket/mockhttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ def headers(self):
@property
def querystring(self):
parts = self._protocol.url.split("?", 1)
if len(parts) == 2:
return parse_qs(unquote(parts[1]), keep_blank_values=True)
return {}
return (
parse_qs(unquote(parts[1]), keep_blank_values=True)
if len(parts) == 2
else {}
)

@property
def body(self):
Expand Down
4 changes: 1 addition & 3 deletions mocket/plugins/httpretty/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,8 @@ def register_uri(
responses=None,
match_querystring=False,
priority=0,
**headers
**headers,
):

headers = httprettifier_headers(headers)

if adding_headers is not None:
Expand Down Expand Up @@ -101,7 +100,6 @@ def force_headers(self):


class MocketHTTPretty:

Response = Response

def __getattr__(self, name):
Expand Down
37 changes: 27 additions & 10 deletions mocket/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import io
import os
import ssl
from typing import Tuple, Union

from .compat import decode_from_bytes, encode_to_bytes
from .exceptions import StrictMocketException

SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2

Expand Down Expand Up @@ -52,13 +54,28 @@ class MocketMode:
def __init__(self):
self.__dict__ = self.__shared_state

def allowed(self, location):
if not self.STRICT_ALLOWED:
return False
host, port = location
for allowed in self.STRICT_ALLOWED:
if isinstance(allowed, str) and host == allowed:
return True
elif location == allowed:
return True
return False
def is_allowed(self, location: Union[str, Tuple[str, int]]) -> bool:
"""
Checks if (`host`, `port`) or at least `host`
are allowed locationsto perform real `socket` calls
"""
if not self.STRICT:
return True
host, _ = location
return location in self.STRICT_ALLOWED or host in self.STRICT_ALLOWED

@staticmethod
def raise_not_allowed():
from .mocket import Mocket

current_entries = [
(location, "\n ".join(map(str, entries)))
for location, entries in Mocket._entries.items()
]
formatted_entries = "\n".join(
[f" {location}:\n {entries}" for location, entries in current_entries]
)
raise StrictMocketException(
"Mocket tried to use the real `socket` module while STRICT mode was active.\n"
f"Registered entries:\n{formatted_entries}"
)
14 changes: 13 additions & 1 deletion tests/main/test_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from mocket import Mocketizer, mocketize
from mocket.exceptions import StrictMocketException
from mocket.mockhttp import Entry, Response
from mocket.utils import MocketMode


@mocketize(strict_mode=True)
Expand Down Expand Up @@ -51,9 +52,20 @@ def test_strict_mode_error_message():
assert (
str(exc_info.value)
== """
Mocket tried to use the real `socket` module while strict mode is active.
Mocket tried to use the real `socket` module while STRICT mode was active.
Registered entries:
('httpbin.local', 80):
Entry(method='GET', schema='http', location=('httpbin.local', 80), path='/user.agent', query='')
""".strip()
)


def test_strict_mode_false_with_allowed_hosts():
with pytest.raises(ValueError):
Mocketizer(strict_mode=False, strict_mode_allowed=["foobar.local"])


def test_strict_mode_false_always_allowed():
with Mocketizer(strict_mode=False):
assert MocketMode().is_allowed("foobar.com")
assert MocketMode().is_allowed(("foobar.com", 443))

0 comments on commit 1b730ef

Please sign in to comment.