diff --git a/aiokafka/conn.py b/aiokafka/conn.py index 859f8f24..ae3f70df 100644 --- a/aiokafka/conn.py +++ b/aiokafka/conn.py @@ -458,6 +458,12 @@ def send(self, request, expect_response=True): f"No connection to broker at {self._host}:{self._port}" ) + if self._writer.is_closing(): + self.close(reason=CloseReason.CONNECTION_BROKEN) + raise Errors.KafkaConnectionError( + f"Connection at {self._host}:{self._port} is closing" + ) + correlation_id = self._next_correlation_id() header = request.build_request_header( correlation_id=correlation_id, client_id=self._client_id diff --git a/requirements-ci.txt b/requirements-ci.txt index 615ffb65..ed530519 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -11,3 +11,4 @@ Pygments==2.18.0 gssapi==1.9.0 async-timeout==4.0.3 cramjam==2.9.0 +uvloop==0.19.0 diff --git a/tests/test_conn.py b/tests/test_conn.py index f0f4a075..a92913ba 100644 --- a/tests/test_conn.py +++ b/tests/test_conn.py @@ -1,10 +1,13 @@ import asyncio import gc +import socket import struct -from typing import Any +import sys +from typing import Any, AsyncIterable, Iterable, Tuple from unittest import mock import pytest +import pytest_asyncio from aiokafka.conn import AIOKafkaConnection, VersionInfo, create_conn from aiokafka.errors import ( @@ -144,7 +147,7 @@ async def test_send_to_closed(self): with self.assertRaises(KafkaConnectionError): await conn.send(request) - conn._writer = mock.MagicMock() + conn._writer = mock.MagicMock(is_closing=mock.Mock(return_value=False)) conn._writer.write.side_effect = OSError("mocked writer is closed") with self.assertRaises(KafkaConnectionError): @@ -173,7 +176,7 @@ async def second_resp(*args: Any, **kw: Any): return resp reader.readexactly.side_effect = [first_resp(), second_resp()] - writer = mock.MagicMock() + writer = mock.MagicMock(is_closing=mock.Mock(return_value=False)) conn._reader = reader conn._writer = writer @@ -208,7 +211,7 @@ async def second_resp(*args: Any, **kw: Any): return resp reader.readexactly.side_effect = [first_resp(), second_resp()] - writer = mock.MagicMock() + writer = mock.MagicMock(is_closing=mock.Mock(return_value=False)) conn._reader = reader conn._writer = writer @@ -237,7 +240,7 @@ async def invoke_osserror(*a, **kw): # setup reader reader = mock.MagicMock() reader.readexactly.return_value = invoke_osserror() - writer = mock.MagicMock() + writer = mock.MagicMock(is_closing=mock.Mock(return_value=False)) conn._reader = reader conn._writer = writer @@ -394,7 +397,7 @@ async def test__send_sasl_token(self): # setup connection with mocked transport and protocol conn = AIOKafkaConnection(host="", port=9999) conn.close = mock.MagicMock() - conn._writer = mock.MagicMock() + conn._writer = mock.MagicMock(is_closing=mock.Mock(return_value=False)) out_buffer = [] conn._writer.write = mock.Mock(side_effect=out_buffer.append) conn._reader = mock.MagicMock() @@ -424,3 +427,80 @@ async def test__send_sasl_token(self): conn._send_sasl_token(b"Super data") # We don't need to close 2ce self.assertEqual(conn.close.call_count, 1) + + +@pytest.mark.skipif(sys.platform == "win32", reason="Uvloop doesn't support Windows") +class TestClosedSocket: + @pytest.fixture( + params=( + pytest.param("asyncio", id="asyncio"), + pytest.param("uvloop", id="uvloop"), + ), + ) + def event_loop( + self, request: pytest.FixtureRequest + ) -> Iterable[asyncio.AbstractEventLoop]: + if request.param == "asyncio": + policy = asyncio.DefaultEventLoopPolicy() + elif request.param == "uvloop": + import uvloop + + policy = uvloop.EventLoopPolicy() + else: + raise ValueError(f"loop {request.param} is not supported") + + loop: asyncio.AbstractEventLoop = policy.new_event_loop() + yield loop + loop.close() + + @pytest.fixture() + def server(self, unused_tcp_port: int) -> Iterable[Tuple[str, int, socket.socket]]: + host = "localhost" + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind((host, unused_tcp_port)) + sock.listen(8) + sock.setblocking(False) + + yield host, unused_tcp_port, sock + + sock.close() + + @pytest_asyncio.fixture() + async def conn( + self, server: Tuple[str, int, socket.socket] + ) -> AsyncIterable[AIOKafkaConnection]: + host, port, _ = server + + conn = AIOKafkaConnection(host=host, port=port, request_timeout_ms=1000) + conn._create_reader_task = mock.Mock() + + yield conn + + fut = conn.close() + if fut: + await fut + + @pytest.mark.asyncio + async def test_send_to_closed_socket( + self, server: Tuple[str, int, socket.socket], conn: AIOKafkaConnection + ) -> None: + host, port, sock = server + + request = MetadataRequest([]) + + with pytest.raises( + KafkaConnectionError, + match=f"KafkaConnectionError: No connection to broker at {host}:{port}", + ): + await conn.send(request) + + await conn.connect() + + sock.close() + await asyncio.sleep(0.1) + + with pytest.raises( + KafkaConnectionError, + match=f"KafkaConnectionError: Connection at {host}:{port} is closing", + ): + await conn.send(request)