Skip to content

Commit

Permalink
Merge pull request #219 from mindflayer/external-pr
Browse files Browse the repository at this point in the history
External contribution
  • Loading branch information
mindflayer authored Jan 16, 2024
2 parents 2bca049 + a5af5c3 commit 3edce14
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 26 deletions.
41 changes: 36 additions & 5 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,13 @@
except ImportError:
pyopenssl_override = False

try: # pragma: no cover
from aiohttp import TCPConnector

aiohttp_make_ssl_context_cache_clear = TCPConnector._make_ssl_context.cache_clear
except (ImportError, AttributeError):
aiohttp_make_ssl_context_cache_clear = None


true_socket = socket.socket
true_create_connection = socket.create_connection
Expand Down Expand Up @@ -85,6 +92,7 @@ class FakeSSLContext(SuperFakeSSLContext):
"load_verify_locations",
"set_alpn_protocols",
"set_ciphers",
"set_default_verify_paths",
)
sock = None
post_handshake_auth = None
Expand Down Expand Up @@ -180,6 +188,8 @@ def __init__(
self.type = int(type)
self.proto = int(proto)
self._truesocket_recording_dir = None
self._did_handshake = False
self._sent_non_empty_bytes = False
self.kwargs = kwargs

def __str__(self):
Expand Down Expand Up @@ -218,7 +228,7 @@ def getsockopt(level, optname, buflen=None):
return socket.SOCK_STREAM

def do_handshake(self):
pass
self._did_handshake = True

def getpeername(self):
return self._address
Expand Down Expand Up @@ -257,6 +267,8 @@ def write(self, data):

@staticmethod
def fileno():
if Mocket.r_fd is not None:
return Mocket.r_fd
Mocket.r_fd, Mocket.w_fd = os.pipe()
return Mocket.r_fd

Expand Down Expand Up @@ -292,10 +304,21 @@ def sendall(self, data, entry=None, *args, **kwargs):
self.fd.seek(0)

def read(self, buffersize):
return self.fd.read(buffersize)
rv = self.fd.read(buffersize)
if rv:
self._sent_non_empty_bytes = True
if self._did_handshake and not self._sent_non_empty_bytes:
raise ssl.SSLWantReadError("The operation did not complete (read)")
return rv

def recv_into(self, buffer, buffersize=None, flags=None):
return buffer.write(self.read(buffersize))
if hasattr(buffer, "write"):
return buffer.write(self.read(buffersize))
# buffer is a memoryview
data = self.read(buffersize)
if data:
buffer[: len(data)] = data
return len(data)

def recv(self, buffersize, flags=None):
if Mocket.r_fd and Mocket.w_fd:
Expand Down Expand Up @@ -455,8 +478,12 @@ def collect(cls, data):

@classmethod
def reset(cls):
cls.r_fd = None
cls.w_fd = None
if cls.r_fd is not None:
os.close(cls.r_fd)
cls.r_fd = None
if cls.w_fd is not None:
os.close(cls.w_fd)
cls.w_fd = None
cls._entries = collections.defaultdict(list)
cls._requests = []

Expand Down Expand Up @@ -527,6 +554,8 @@ def enable(namespace=None, truesocket_recording_dir=None):
if pyopenssl_override: # pragma: no cover
# Take out the pyopenssl version - use the default implementation
extract_from_urllib3()
if aiohttp_make_ssl_context_cache_clear: # pragma: no cover
aiohttp_make_ssl_context_cache_clear()

@staticmethod
def disable():
Expand Down Expand Up @@ -563,6 +592,8 @@ def disable():
if pyopenssl_override: # pragma: no cover
# Put the pyopenssl version back in place
inject_into_urllib3()
if aiohttp_make_ssl_context_cache_clear: # pragma: no cover
aiohttp_make_ssl_context_cache_clear()

@classmethod
def get_namespace(cls):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dynamic = ["version"]
[project.optional-dependencies]
test = [
"pre-commit",
"psutil",
"pytest",
"pytest-cov",
"pytest-asyncio",
Expand Down
18 changes: 18 additions & 0 deletions tests/main/test_mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from unittest import TestCase
from unittest.mock import patch

import httpx
import psutil
import pytest

from mocket import Mocket, MocketEntry, Mocketizer, mocketize
Expand Down Expand Up @@ -190,3 +192,19 @@ def test_patch(
):
method_patch.return_value = "foo"
assert os.getcwd() == "foo"


@pytest.mark.skipif(not psutil.POSIX, reason="Uses a POSIX-only API to test")
@pytest.mark.asyncio
async def test_no_dangling_fds():
url = "http://httpbin.local/ip"

proc = psutil.Process(os.getpid())

prev_num_fds = proc.num_fds()

async with Mocketizer(strict_mode=False):
async with httpx.AsyncClient() as client:
await client.get(url)

assert proc.num_fds() == prev_num_fds
59 changes: 38 additions & 21 deletions tests/tests38/test_http_aiohttp.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
from unittest import IsolatedAsyncioTestCase

import httpx
import pytest

from mocket.async_mocket import async_mocketize
from mocket.mocket import Mocket
from mocket.mocket import Mocket, Mocketizer
from mocket.mockhttp import Entry
from mocket.plugins.httpretty import HTTPretty, async_httprettified

Expand Down Expand Up @@ -46,6 +45,23 @@ async def test_http_session(self):

self.assertEqual(len(Mocket.request_list()), 2)

@async_httprettified
async def test_httprettish_session(self):
HTTPretty.register_uri(
HTTPretty.GET,
self.target_url,
body=json.dumps(dict(origin="127.0.0.1")),
)

async with aiohttp.ClientSession(timeout=self.timeout) as session:
async with session.get(self.target_url) as get_response:
assert get_response.status == 200
assert await get_response.text() == '{"origin": "127.0.0.1"}'

class AioHttpsEntryTestCase(IsolatedAsyncioTestCase):
timeout = aiohttp.ClientTimeout(total=3)
target_url = "https://httpbin.localhost/anything/"

@async_mocketize
async def test_https_session(self):
body = "asd" * 100
Expand All @@ -67,7 +83,14 @@ async def test_https_session(self):

self.assertEqual(len(Mocket.request_list()), 2)

@pytest.mark.xfail
@async_mocketize
async def test_no_verify(self):
Entry.single_register(Entry.GET, self.target_url, status=404)

async with aiohttp.ClientSession(timeout=self.timeout) as session:
async with session.get(self.target_url, ssl=False) as get_response:
assert get_response.status == 404

@async_httprettified
async def test_httprettish_session(self):
HTTPretty.register_uri(
Expand All @@ -81,21 +104,15 @@ async def test_httprettish_session(self):
assert get_response.status == 200
assert await get_response.text() == '{"origin": "127.0.0.1"}'


class HttpxEntryTestCase(IsolatedAsyncioTestCase):
target_url = "http://httpbin.local/ip"

@async_httprettified
async def test_httprettish_httpx_session(self):
expected_response = {"origin": "127.0.0.1"}

HTTPretty.register_uri(
HTTPretty.GET,
self.target_url,
body=json.dumps(expected_response),
)

async with httpx.AsyncClient() as client:
response = await client.get(self.target_url)
assert response.status_code == 200
assert response.json() == expected_response
@pytest.mark.skipif('os.getenv("SKIP_TRUE_HTTP", False)')
async def test_mocked_https_request_after_unmocked_https_request(self):
async with aiohttp.ClientSession(timeout=self.timeout) as session:
response = await session.get(self.target_url + "real", ssl=False)
assert response.status == 200

async with Mocketizer(None):
Entry.single_register(Entry.GET, self.target_url + "mocked", status=404)
async with aiohttp.ClientSession(timeout=self.timeout) as session:
response = await session.get(self.target_url + "mocked", ssl=False)
assert response.status == 404
self.assertEqual(len(Mocket.request_list()), 1)
44 changes: 44 additions & 0 deletions tests/tests38/test_http_httpx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import json
from unittest import IsolatedAsyncioTestCase

import httpx

from mocket.plugins.httpretty import HTTPretty, async_httprettified


class HttpxEntryTestCase(IsolatedAsyncioTestCase):
target_url = "http://httpbin.local/ip"

@async_httprettified
async def test_httprettish_httpx_session(self):
expected_response = {"origin": "127.0.0.1"}

HTTPretty.register_uri(
HTTPretty.GET,
self.target_url,
body=json.dumps(expected_response),
)

async with httpx.AsyncClient() as client:
response = await client.get(self.target_url)
assert response.status_code == 200
assert response.json() == expected_response


class HttpxHttpsEntryTestCase(IsolatedAsyncioTestCase):
target_url = "https://httpbin.local/ip"

@async_httprettified
async def test_httprettish_httpx_session(self):
expected_response = {"origin": "127.0.0.1"}

HTTPretty.register_uri(
HTTPretty.GET,
self.target_url,
body=json.dumps(expected_response),
)

async with httpx.AsyncClient() as client:
response = await client.get(self.target_url)
assert response.status_code == 200
assert response.json() == expected_response

0 comments on commit 3edce14

Please sign in to comment.