Skip to content

Commit

Permalink
Merge pull request #224 from mindflayer/chore/review-pr
Browse files Browse the repository at this point in the history
Add allowed locations (whitelist) for STRICT mode
  • Loading branch information
mindflayer authored Feb 6, 2024
2 parents 6d861c2 + 324845f commit 4bd3e4f
Show file tree
Hide file tree
Showing 8 changed files with 139 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
make services-up
- name: Setup hostname
run: |
export CONTAINER_ID=$(docker-compose ps -q proxy)
export CONTAINER_ID=$(docker compose ps -q proxy)
export CONTAINER_IP=$(docker inspect -f '{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}' $CONTAINER_ID)
echo "$CONTAINER_IP httpbin.local" | sudo tee -a /etc/hosts
- name: Test
Expand Down
14 changes: 14 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,20 @@ NEW!!! Sometimes you just want your tests to fail when they attempt to use the n
with pytest.raises(StrictMocketException):
requests.get("https://duckduckgo.com/")
You can specify exceptions as a list of hosts or host-port pairs.

.. code-block:: python
with Mocketizer(strict_mode=True, strict_mode_allowed=["localhost", ("intake.ourmetrics.net", 443)]):
...
# OR
@mocketize(strict_mode=True, strict_mode_allowed=["localhost", ("intake.ourmetrics.net", 443)])
def test_get():
...
How to be sure that all the Entry instances have been served?
=============================================================
Add this instruction at the end of the test execution:
Expand Down
11 changes: 9 additions & 2 deletions mocket/async_mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@


async def wrapper(
test, truesocket_recording_dir=None, strict_mode=False, *args, **kwargs
test,
truesocket_recording_dir=None,
strict_mode=False,
strict_mode_allowed=None,
*args,
**kwargs,
):
async with Mocketizer.factory(test, truesocket_recording_dir, strict_mode, args):
async with Mocketizer.factory(
test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args
):
return await test(*args, **kwargs)


Expand Down
31 changes: 25 additions & 6 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
urllib3_wrap_socket = None

from .compat import basestring, byte_type, decode_from_bytes, encode_to_bytes, text_type
from .exceptions import StrictMocketException
from .utils import (
SSL_PROTOCOL,
MocketMode,
Expand Down Expand Up @@ -333,8 +332,8 @@ def recv(self, buffersize, flags=None):
raise exc

def true_sendall(self, data, *args, **kwargs):
if MocketMode().STRICT:
raise StrictMocketException("Mocket tried to use the real `socket` module.")
if not MocketMode().is_allowed((self._host, self._port)):
MocketMode.raise_not_allowed()

req = decode_from_bytes(data)
# make request unique again
Expand Down Expand Up @@ -642,6 +641,9 @@ def __init__(self, location, responses):
r = self.response_cls(r)
self.responses.append(r)

def __repr__(self):
return "{}(location={})".format(self.__class__.__name__, self.location)

@staticmethod
def can_handle(data):
return True
Expand Down Expand Up @@ -670,11 +672,18 @@ def __init__(
namespace=None,
truesocket_recording_dir=None,
strict_mode=False,
strict_mode_allowed=None,
):
self.instance = instance
self.truesocket_recording_dir = truesocket_recording_dir
self.namespace = namespace or text_type(id(self))
MocketMode().STRICT = strict_mode
if strict_mode:
MocketMode().STRICT_ALLOWED = strict_mode_allowed or []
elif strict_mode_allowed:
raise ValueError(
"Allowed locations are only accepted when STRICT mode is active."
)

def enter(self):
Mocket.enable(
Expand Down Expand Up @@ -709,7 +718,7 @@ def check_and_call(self, method_name):
method()

@staticmethod
def factory(test, truesocket_recording_dir, strict_mode, args):
def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args):
instance = args[0] if args else None
namespace = None
if truesocket_recording_dir:
Expand All @@ -726,11 +735,21 @@ def factory(test, truesocket_recording_dir, strict_mode, args):
namespace=namespace,
truesocket_recording_dir=truesocket_recording_dir,
strict_mode=strict_mode,
strict_mode_allowed=strict_mode_allowed,
)


def wrapper(test, truesocket_recording_dir=None, strict_mode=False, *args, **kwargs):
with Mocketizer.factory(test, truesocket_recording_dir, strict_mode, args):
def wrapper(
test,
truesocket_recording_dir=None,
strict_mode=False,
strict_mode_allowed=None,
*args,
**kwargs,
):
with Mocketizer.factory(
test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args
):
return test(*args, **kwargs)


Expand Down
20 changes: 17 additions & 3 deletions mocket/mockhttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ def headers(self):
@property
def querystring(self):
parts = self._protocol.url.split("?", 1)
if len(parts) == 2:
return parse_qs(unquote(parts[1]), keep_blank_values=True)
return {}
return (
parse_qs(unquote(parts[1]), keep_blank_values=True)
if len(parts) == 2
else {}
)

@property
def body(self):
Expand Down Expand Up @@ -175,6 +177,18 @@ def __init__(self, uri, method, responses, match_querystring=True):
self._sent_data = b""
self._match_querystring = match_querystring

def __repr__(self):
return (
"{}(method={!r}, schema={!r}, location={!r}, path={!r}, query={!r})".format(
self.__class__.__name__,
self.method,
self.schema,
self.location,
self.path,
self.query,
)
)

def collect(self, data):
consume_response = True

Expand Down
4 changes: 1 addition & 3 deletions mocket/plugins/httpretty/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,8 @@ def register_uri(
responses=None,
match_querystring=False,
priority=0,
**headers
**headers,
):

headers = httprettifier_headers(headers)

if adding_headers is not None:
Expand Down Expand Up @@ -101,7 +100,6 @@ def force_headers(self):


class MocketHTTPretty:

Response = Response

def __getattr__(self, name):
Expand Down
29 changes: 29 additions & 0 deletions mocket/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import io
import os
import ssl
from typing import Tuple, Union

from .compat import decode_from_bytes, encode_to_bytes
from .exceptions import StrictMocketException

SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2

Expand Down Expand Up @@ -47,6 +49,33 @@ def get_mocketize(wrapper_):
class MocketMode:
__shared_state = {}
STRICT = None
STRICT_ALLOWED = None

def __init__(self):
self.__dict__ = self.__shared_state

def is_allowed(self, location: Union[str, Tuple[str, int]]) -> bool:
"""
Checks if (`host`, `port`) or at least `host`
are allowed locationsto perform real `socket` calls
"""
if not self.STRICT:
return True
host, _ = location
return location in self.STRICT_ALLOWED or host in self.STRICT_ALLOWED

@staticmethod
def raise_not_allowed():
from .mocket import Mocket

current_entries = [
(location, "\n ".join(map(str, entries)))
for location, entries in Mocket._entries.items()
]
formatted_entries = "\n".join(
[f" {location}:\n {entries}" for location, entries in current_entries]
)
raise StrictMocketException(
"Mocket tried to use the real `socket` module while STRICT mode was active.\n"
f"Registered entries:\n{formatted_entries}"
)
43 changes: 43 additions & 0 deletions tests/main/test_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from mocket import Mocketizer, mocketize
from mocket.exceptions import StrictMocketException
from mocket.mockhttp import Entry, Response
from mocket.utils import MocketMode


@mocketize(strict_mode=True)
Expand All @@ -26,3 +28,44 @@ def test_intermittent_strict_mode():

with Mocketizer(strict_mode=False):
requests.get(url)


@pytest.mark.skipif('os.getenv("SKIP_TRUE_HTTP", False)')
def test_strict_mode_exceptions():
url = "http://httpbin.local/ip"

with Mocketizer(strict_mode=True, strict_mode_allowed=["httpbin.local"]):
requests.get(url)

with Mocketizer(strict_mode=True, strict_mode_allowed=[("httpbin.local", 80)]):
requests.get(url)


def test_strict_mode_error_message():
url = "http://httpbin.local/ip"

Entry.register(Entry.GET, "http://httpbin.local/user.agent", Response(status=404))

with Mocketizer(strict_mode=True):
with pytest.raises(StrictMocketException) as exc_info:
requests.get(url)
assert (
str(exc_info.value)
== """
Mocket tried to use the real `socket` module while STRICT mode was active.
Registered entries:
('httpbin.local', 80):
Entry(method='GET', schema='http', location=('httpbin.local', 80), path='/user.agent', query='')
""".strip()
)


def test_strict_mode_false_with_allowed_hosts():
with pytest.raises(ValueError):
Mocketizer(strict_mode=False, strict_mode_allowed=["foobar.local"])


def test_strict_mode_false_always_allowed():
with Mocketizer(strict_mode=False):
assert MocketMode().is_allowed("foobar.com")
assert MocketMode().is_allowed(("foobar.com", 443))

0 comments on commit 4bd3e4f

Please sign in to comment.