diff --git a/mocket/mocket.py b/mocket/mocket.py index 2b4d0c0f..c0a310a3 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -49,6 +49,13 @@ except ImportError: pyopenssl_override = False +try: # pragma: no cover + from aiohttp import TCPConnector + + aiohttp_make_ssl_context_cache_clear = TCPConnector._make_ssl_context.cache_clear +except (ImportError, AttributeError): + aiohttp_make_ssl_context_cache_clear = None + true_socket = socket.socket true_create_connection = socket.create_connection @@ -85,6 +92,7 @@ class FakeSSLContext(SuperFakeSSLContext): "load_verify_locations", "set_alpn_protocols", "set_ciphers", + "set_default_verify_paths", ) sock = None post_handshake_auth = None @@ -180,6 +188,8 @@ def __init__( self.type = int(type) self.proto = int(proto) self._truesocket_recording_dir = None + self._did_handshake = False + self._sent_non_empty_bytes = False self.kwargs = kwargs def __str__(self): @@ -218,7 +228,7 @@ def getsockopt(level, optname, buflen=None): return socket.SOCK_STREAM def do_handshake(self): - pass + self._did_handshake = True def getpeername(self): return self._address @@ -257,6 +267,8 @@ def write(self, data): @staticmethod def fileno(): + if Mocket.r_fd is not None: + return Mocket.r_fd Mocket.r_fd, Mocket.w_fd = os.pipe() return Mocket.r_fd @@ -292,10 +304,21 @@ def sendall(self, data, entry=None, *args, **kwargs): self.fd.seek(0) def read(self, buffersize): - return self.fd.read(buffersize) + rv = self.fd.read(buffersize) + if rv: + self._sent_non_empty_bytes = True + if self._did_handshake and not self._sent_non_empty_bytes: + raise ssl.SSLWantReadError("The operation did not complete (read)") + return rv def recv_into(self, buffer, buffersize=None, flags=None): - return buffer.write(self.read(buffersize)) + if hasattr(buffer, "write"): + return buffer.write(self.read(buffersize)) + # buffer is a memoryview + data = self.read(buffersize) + if data: + buffer[: len(data)] = data + return len(data) def recv(self, buffersize, flags=None): if Mocket.r_fd and Mocket.w_fd: @@ -455,8 +478,12 @@ def collect(cls, data): @classmethod def reset(cls): - cls.r_fd = None - cls.w_fd = None + if cls.r_fd is not None: + os.close(cls.r_fd) + cls.r_fd = None + if cls.w_fd is not None: + os.close(cls.w_fd) + cls.w_fd = None cls._entries = collections.defaultdict(list) cls._requests = [] @@ -527,6 +554,8 @@ def enable(namespace=None, truesocket_recording_dir=None): if pyopenssl_override: # pragma: no cover # Take out the pyopenssl version - use the default implementation extract_from_urllib3() + if aiohttp_make_ssl_context_cache_clear: # pragma: no cover + aiohttp_make_ssl_context_cache_clear() @staticmethod def disable(): @@ -563,6 +592,8 @@ def disable(): if pyopenssl_override: # pragma: no cover # Put the pyopenssl version back in place inject_into_urllib3() + if aiohttp_make_ssl_context_cache_clear: # pragma: no cover + aiohttp_make_ssl_context_cache_clear() @classmethod def get_namespace(cls): diff --git a/pyproject.toml b/pyproject.toml index 576ee047..e8f58b6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dynamic = ["version"] [project.optional-dependencies] test = [ "pre-commit", + "psutil", "pytest", "pytest-cov", "pytest-asyncio", diff --git a/tests/main/test_mocket.py b/tests/main/test_mocket.py index 14057b10..90d9738f 100644 --- a/tests/main/test_mocket.py +++ b/tests/main/test_mocket.py @@ -6,6 +6,8 @@ from unittest import TestCase from unittest.mock import patch +import httpx +import psutil import pytest from mocket import Mocket, MocketEntry, Mocketizer, mocketize @@ -190,3 +192,19 @@ def test_patch( ): method_patch.return_value = "foo" assert os.getcwd() == "foo" + + +@pytest.mark.skipif(not psutil.POSIX, reason="Uses a POSIX-only API to test") +@pytest.mark.asyncio +async def test_no_dangling_fds(): + url = "http://httpbin.local/ip" + + proc = psutil.Process(os.getpid()) + + prev_num_fds = proc.num_fds() + + async with Mocketizer(strict_mode=False): + async with httpx.AsyncClient() as client: + await client.get(url) + + assert proc.num_fds() == prev_num_fds diff --git a/tests/tests38/test_http_aiohttp.py b/tests/tests38/test_http_aiohttp.py index 7515eef8..8bdcd0b6 100644 --- a/tests/tests38/test_http_aiohttp.py +++ b/tests/tests38/test_http_aiohttp.py @@ -1,11 +1,10 @@ import json from unittest import IsolatedAsyncioTestCase -import httpx import pytest from mocket.async_mocket import async_mocketize -from mocket.mocket import Mocket +from mocket.mocket import Mocket, Mocketizer from mocket.mockhttp import Entry from mocket.plugins.httpretty import HTTPretty, async_httprettified @@ -46,6 +45,23 @@ async def test_http_session(self): self.assertEqual(len(Mocket.request_list()), 2) + @async_httprettified + async def test_httprettish_session(self): + HTTPretty.register_uri( + HTTPretty.GET, + self.target_url, + body=json.dumps(dict(origin="127.0.0.1")), + ) + + async with aiohttp.ClientSession(timeout=self.timeout) as session: + async with session.get(self.target_url) as get_response: + assert get_response.status == 200 + assert await get_response.text() == '{"origin": "127.0.0.1"}' + + class AioHttpsEntryTestCase(IsolatedAsyncioTestCase): + timeout = aiohttp.ClientTimeout(total=3) + target_url = "https://httpbin.localhost/anything/" + @async_mocketize async def test_https_session(self): body = "asd" * 100 @@ -67,7 +83,14 @@ async def test_https_session(self): self.assertEqual(len(Mocket.request_list()), 2) - @pytest.mark.xfail + @async_mocketize + async def test_no_verify(self): + Entry.single_register(Entry.GET, self.target_url, status=404) + + async with aiohttp.ClientSession(timeout=self.timeout) as session: + async with session.get(self.target_url, ssl=False) as get_response: + assert get_response.status == 404 + @async_httprettified async def test_httprettish_session(self): HTTPretty.register_uri( @@ -81,21 +104,15 @@ async def test_httprettish_session(self): assert get_response.status == 200 assert await get_response.text() == '{"origin": "127.0.0.1"}' - -class HttpxEntryTestCase(IsolatedAsyncioTestCase): - target_url = "http://httpbin.local/ip" - - @async_httprettified - async def test_httprettish_httpx_session(self): - expected_response = {"origin": "127.0.0.1"} - - HTTPretty.register_uri( - HTTPretty.GET, - self.target_url, - body=json.dumps(expected_response), - ) - - async with httpx.AsyncClient() as client: - response = await client.get(self.target_url) - assert response.status_code == 200 - assert response.json() == expected_response + @pytest.mark.skipif('os.getenv("SKIP_TRUE_HTTP", False)') + async def test_mocked_https_request_after_unmocked_https_request(self): + async with aiohttp.ClientSession(timeout=self.timeout) as session: + response = await session.get(self.target_url + "real", ssl=False) + assert response.status == 200 + + async with Mocketizer(None): + Entry.single_register(Entry.GET, self.target_url + "mocked", status=404) + async with aiohttp.ClientSession(timeout=self.timeout) as session: + response = await session.get(self.target_url + "mocked", ssl=False) + assert response.status == 404 + self.assertEqual(len(Mocket.request_list()), 1) diff --git a/tests/tests38/test_http_httpx.py b/tests/tests38/test_http_httpx.py new file mode 100644 index 00000000..6fb0fcab --- /dev/null +++ b/tests/tests38/test_http_httpx.py @@ -0,0 +1,44 @@ +import json +from unittest import IsolatedAsyncioTestCase + +import httpx + +from mocket.plugins.httpretty import HTTPretty, async_httprettified + + +class HttpxEntryTestCase(IsolatedAsyncioTestCase): + target_url = "http://httpbin.local/ip" + + @async_httprettified + async def test_httprettish_httpx_session(self): + expected_response = {"origin": "127.0.0.1"} + + HTTPretty.register_uri( + HTTPretty.GET, + self.target_url, + body=json.dumps(expected_response), + ) + + async with httpx.AsyncClient() as client: + response = await client.get(self.target_url) + assert response.status_code == 200 + assert response.json() == expected_response + + +class HttpxHttpsEntryTestCase(IsolatedAsyncioTestCase): + target_url = "https://httpbin.local/ip" + + @async_httprettified + async def test_httprettish_httpx_session(self): + expected_response = {"origin": "127.0.0.1"} + + HTTPretty.register_uri( + HTTPretty.GET, + self.target_url, + body=json.dumps(expected_response), + ) + + async with httpx.AsyncClient() as client: + response = await client.get(self.target_url) + assert response.status_code == 200 + assert response.json() == expected_response