Skip to content

Commit

Permalink
Make Mocket work with big requests (#234)
Browse files Browse the repository at this point in the history
* Make Mocket work with big requests.
* Getting rid of old tests using `aiohttp`.
* Adding a single `aiohttp` test with timeout.
* Skip for Python versions older than 3.11 (looks like aio-libs/aiohttp#5582).

---------

Co-authored-by: Giorgio Salluzzo <giorgio.salluzzo@satellogic.com>
  • Loading branch information
mindflayer and mindflayer authored May 14, 2024
1 parent d154be2 commit 9d5c031
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 262 deletions.
30 changes: 9 additions & 21 deletions mocket/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import socket
import ssl
from datetime import datetime, timedelta
from io import BytesIO
from json.decoder import JSONDecodeError

import urllib3
Expand All @@ -26,7 +27,6 @@
from .utils import (
SSL_PROTOCOL,
MocketMode,
MocketSocketCore,
get_mocketize,
hexdump,
hexload,
Expand Down Expand Up @@ -175,6 +175,8 @@ class MocketSocket:
_mode = None
_bufsize = None
_secure_socket = False
_did_handshake = False
_sent_non_empty_bytes = False

def __init__(
self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs
Expand All @@ -186,8 +188,6 @@ def __init__(
self.type = int(type)
self.proto = int(proto)
self._truesocket_recording_dir = None
self._did_handshake = False
self._sent_non_empty_bytes = False
self.kwargs = kwargs

def __str__(self):
Expand All @@ -202,7 +202,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
@property
def fd(self):
if self._fd is None:
self._fd = MocketSocketCore()
self._fd = BytesIO()
return self._fd

def gettimeout(self):
Expand Down Expand Up @@ -264,12 +264,10 @@ def unwrap(self):
def write(self, data):
return self.send(encode_to_bytes(data))

@staticmethod
def fileno():
if Mocket.r_fd is not None:
return Mocket.r_fd
Mocket.r_fd, Mocket.w_fd = os.pipe()
return Mocket.r_fd
def fileno(self):
if self.true_socket:
return self.true_socket.fileno()
return self.fd.fileno()

def connect(self, address):
self._address = self._host, self._port = address
Expand Down Expand Up @@ -317,8 +315,6 @@ def recv_into(self, buffer, buffersize=None, flags=None):
return len(data)

def recv(self, buffersize, flags=None):
if Mocket.r_fd and Mocket.w_fd:
return os.read(Mocket.r_fd, buffersize)
data = self.read(buffersize)
if data:
return data
Expand Down Expand Up @@ -436,7 +432,7 @@ def close(self):
self._fd = None

def __getattr__(self, name):
"""Do nothing catchall function, for methods like close() and shutdown()"""
"""Do nothing catchall function, for methods like shutdown()"""

def do_nothing(*args, **kwargs):
pass
Expand All @@ -450,8 +446,6 @@ class Mocket:
_requests = []
_namespace = text_type(id(_entries))
_truesocket_recording_dir = None
r_fd = None
w_fd = None

@classmethod
def register(cls, *entries):
Expand All @@ -473,12 +467,6 @@ def collect(cls, data):

@classmethod
def reset(cls):
if cls.r_fd is not None:
os.close(cls.r_fd)
cls.r_fd = None
if cls.w_fd is not None:
os.close(cls.w_fd)
cls.w_fd = None
cls._entries = collections.defaultdict(list)
cls._requests = []

Expand Down
16 changes: 0 additions & 16 deletions mocket/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from __future__ import annotations

import binascii
import io
import os
import ssl
from typing import TYPE_CHECKING, Any, Callable, ClassVar

Expand All @@ -12,24 +10,10 @@
if TYPE_CHECKING: # pragma: no cover
from typing import NoReturn

from _typeshed import ReadableBuffer

SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2


class MocketSocketCore(io.BytesIO):
def write( # type: ignore[override] # BytesIO returns int
self,
content: ReadableBuffer,
) -> None:
super().write(content)

from mocket import Mocket

if Mocket.r_fd and Mocket.w_fd:
os.write(Mocket.w_fd, content)


def hexdump(binary_string: bytes) -> str:
r"""
>>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F"))
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ test = [
"pook",
"flake8>5",
"xxhash",
"aiohttp;python_version<'3.12'",
"httpx",
"pipfile",
"build",
"twine",
"fastapi",
"aiohttp",
"wait-for-it",
"mypy",
"types-decorator",
Expand Down
31 changes: 30 additions & 1 deletion tests/tests37/test_asyncio.py → tests/main/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
import glob
import json
import socket
import sys
import tempfile

from mocket import Mocketizer
import aiohttp
import pytest

from mocket import Mocketizer, async_mocketize
from mocket.mockhttp import Entry


def test_asyncio_record_replay(event_loop):
Expand Down Expand Up @@ -37,3 +42,27 @@ async def test_asyncio_connection():
responses = json.load(f)

assert len(responses["google.com"]["80"].keys()) == 1


@pytest.mark.asyncio
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Looks like https://github.com/aio-libs/aiohttp/issues/5582",
)
@async_mocketize
async def test_aiohttp():
url = "https://bar.foo/"
data = {"message": "Hello"}

Entry.single_register(
Entry.GET,
url,
body=json.dumps(data),
headers={"content-type": "application/json"},
)

async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=3)
) as session, session.get(url) as response:
response = await response.json()
assert response == data
141 changes: 140 additions & 1 deletion tests/main/test_httpx.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import datetime
import json

import httpx
import pytest
from asgiref.sync import async_to_sync
from fastapi import FastAPI
from fastapi.testclient import TestClient

from mocket.mocket import Mocket, mocketize
from mocket import Mocket, Mocketizer, async_mocketize, mocketize
from mocket.mockhttp import Entry
from mocket.plugins.httpretty import httprettified, httpretty

Expand Down Expand Up @@ -55,3 +58,139 @@ async def perform_async_transactions():

perform_async_transactions()
assert len(httpretty.latest_requests) == 1


@mocketize(strict_mode=True)
def test_sync_case():
test_uri = "https://abc.de/testdata/"
base_timestamp = int(datetime.datetime.now().timestamp())
response = [
{"timestamp": base_timestamp + i, "value": 1337 + 42 * i} for i in range(30_000)
]
Entry.single_register(
method=Entry.POST,
uri=test_uri,
body=json.dumps(
response,
),
headers={"content-type": "application/json"},
)

with httpx.Client() as client:
response = client.post(test_uri)

assert len(response.json())


@pytest.mark.asyncio
@async_mocketize(strict_mode=True)
async def test_async_case_low_number():
test_uri = "https://abc.de/testdata/"
base_timestamp = int(datetime.datetime.now().timestamp())
response = [
{"timestamp": base_timestamp + i, "value": 1337 + 42 * i} for i in range(100)
]
Entry.single_register(
method=Entry.POST,
uri=test_uri,
body=json.dumps(
response,
),
headers={"content-type": "application/json"},
)

async with httpx.AsyncClient() as client:
response = await client.post(test_uri)

assert len(response.json())


@pytest.mark.asyncio
@async_mocketize(strict_mode=True)
async def test_async_case_high_number():
test_uri = "https://abc.de/testdata/"
base_timestamp = int(datetime.datetime.now().timestamp())
response = [
{"timestamp": base_timestamp + i, "value": 1337 + 42 * i} for i in range(30_000)
]
Entry.single_register(
method=Entry.POST,
uri=test_uri,
body=json.dumps(
response,
),
headers={"content-type": "application/json"},
)

async with httpx.AsyncClient() as client:
response = await client.post(test_uri)

assert len(response.json())


def create_app() -> FastAPI:
app = FastAPI()

@app.get("/")
async def read_main() -> dict:
async with httpx.AsyncClient() as client:
r = await client.get("https://example.org/")
return r.json()

return app


@mocketize
def test_call_from_fastapi() -> None:
app = create_app()
client = TestClient(app)

Entry.single_register(Entry.GET, "https://example.org/", body='{"id": 1}')

response = client.get("/")

assert response.status_code == 200
assert response.json() == {"id": 1}


@pytest.mark.asyncio
@async_mocketize
async def test_httpx_decorator():
url = "https://bar.foo/"
data = {"message": "Hello"}

Entry.single_register(
Entry.GET,
url,
body=json.dumps(data),
headers={"content-type": "application/json"},
)

async with httpx.AsyncClient() as client:
response = await client.get(url)

assert response.json() == data


@pytest.fixture
def httpx_client() -> httpx.AsyncClient:
with Mocketizer():
yield httpx.AsyncClient()


@pytest.mark.asyncio
async def test_httpx_fixture(httpx_client):
url = "https://foo.bar/"
data = {"message": "Hello"}

Entry.single_register(
Entry.GET,
url,
body=json.dumps(data),
headers={"content-type": "application/json"},
)

async with httpx_client as client:
response = await client.get(url)

assert response.json() == data
Loading

0 comments on commit 9d5c031

Please sign in to comment.