From 03c7499c6fb594d4dddd28b2602d95ffbed408d4 Mon Sep 17 00:00:00 2001
From: Dmitriy
Date: Mon, 1 Jul 2024 13:01:14 +0500
Subject: [PATCH] add typing to tests/test_conn.py, fix review

---
 Makefile                    |   1 +
 aiokafka/conn.py            |  18 +++++-
 aiokafka/protocol/admin.py  |   6 +-
 aiokafka/protocol/commit.py |  16 +++++-
 tests/_testutil.py          |  15 +++++
 tests/conftest.py           |  39 ++++++++-----
 tests/test_conn.py          | 108 +++++++++++++++++++-----------------
 7 files changed, 128 insertions(+), 75 deletions(-)

diff --git a/Makefile b/Makefile
index 83c872bb..8173a850 100644
--- a/Makefile
+++ b/Makefile
@@ -16,6 +16,7 @@ FORMATTED_AREAS=\
 	aiokafka/protocol/ \
 	aiokafka/record/ \
 	tests/test_codec.py \
+	tests/test_conn.py \
 	tests/test_helpers.py \
 	tests/test_protocol.py \
 	tests/test_protocol_object_conversion.py \
diff --git a/aiokafka/conn.py b/aiokafka/conn.py
index 11aede0d..7d52dc69 100644
--- a/aiokafka/conn.py
+++ b/aiokafka/conn.py
@@ -48,7 +48,11 @@ from aiokafka.protocol.admin import (
     ApiVersionRequest,
     SaslAuthenticateRequest,
+    SaslAuthenticateResponse_v0,
+    SaslAuthenticateResponse_v1,
     SaslHandShakeRequest,
+    SaslHandShakeResponse_v0,
+    SaslHandShakeResponse_v1,
 )
 from aiokafka.protocol.api import Request, Response
 from aiokafka.protocol.commit import (
@@ -100,7 +104,6 @@ def __init__(self, versions: Dict[int, Tuple[int, int]]) -> None:
         self._versions = versions
 
     def pick_best(self, request_versions: Sequence[Type[RequestT]]) -> Type[RequestT]:
-        api_key = request_versions[0].API_KEY
         api_key = cast(int, request_versions[0].API_KEY)
         if api_key not in self._versions:
             return request_versions[0]
@@ -173,7 +176,7 @@ def __init__(
         **kw: Any,
     ) -> None:
         self._closed_fut = closed_fut
-        super().__init__(*args, loop=loop, **kw)
+        super().__init__(*args, loop=loop, **kw)  # type: ignore[misc]
 
     def connection_lost(self, exc: Optional[Exception]) -> None:
         super().connection_lost(exc)
@@ -359,6 +362,9 @@ async def _do_sasl_handshake(self) -> None:
 
         sasl_handshake = handshake_klass(self._sasl_mechanism)
         response = await self.send(sasl_handshake)
+        response = cast(
+            Union[SaslHandShakeResponse_v0, SaslHandShakeResponse_v1], response
+        )
         error_type = Errors.for_code(response.error_code)
         if error_type is not Errors.NoError:
             error = error_type(self)
@@ -419,6 +425,10 @@ async def _do_sasl_handshake(self) -> None:
             else:
                 req = auth_klass(payload)
             resp = await self.send(req)
+            resp = cast(
+                Union[SaslAuthenticateResponse_v0, SaslAuthenticateResponse_v1],
+                resp,
+            )
             error_type = Errors.for_code(resp.error_code)
             if error_type is not Errors.NoError:
                 exc = error_type(resp.error_message)
@@ -583,6 +593,8 @@ def _send_sasl_token(
         self, payload: bytes, expect_response: Literal[True]
     ) -> Coroutine[None, None, bytes]: ...
     @overload
+    def _send_sasl_token(self, payload: bytes) -> Coroutine[None, None, bytes]: ...
+    @overload
     def _send_sasl_token(
         self, payload: bytes, expect_response: bool
     ) -> Union[Coroutine[None, None, None], Coroutine[None, None, bytes]]: ...
@@ -682,7 +694,7 @@ async def _read(self_ref: weakref.ReferenceType[AIOKafkaConnection]) -> None:
 
     def _handle_frame(self, resp: bytes) -> None:
         packet = self._requests[0]
-        if packet.correlation_id is None:  # Is a SASL packet, just pass it though
+        if isinstance(packet, SaslPacket):  # Is a SASL packet, just pass it though
             if not packet.fut.done():
                 packet.fut.set_result(resp)
             else:
diff --git a/aiokafka/protocol/admin.py b/aiokafka/protocol/admin.py
index b48ebe80..410ff47c 100644
--- a/aiokafka/protocol/admin.py
+++ b/aiokafka/protocol/admin.py
@@ -527,7 +527,7 @@ class SaslHandShakeRequest_v1(Request[SaslHandShakeResponse_v1]):


 SaslHandShakeRequest = (SaslHandShakeRequest_v0, SaslHandShakeRequest_v1)
-SaslHandShakeResponse = [SaslHandShakeResponse_v0, SaslHandShakeResponse_v1]
+SaslHandShakeResponse = (SaslHandShakeResponse_v0, SaslHandShakeResponse_v1)


 class DescribeAclsResponse_v0(Response):
@@ -1048,10 +1048,10 @@ class SaslAuthenticateRequest_v1(Request[SaslAuthenticateResponse_v1]):
     SaslAuthenticateRequest_v0,
     SaslAuthenticateRequest_v1,
 )
-SaslAuthenticateResponse = [
+SaslAuthenticateResponse = (
     SaslAuthenticateResponse_v0,
     SaslAuthenticateResponse_v1,
-]
+)


 class CreatePartitionsResponse_v0(Response):
diff --git a/aiokafka/protocol/commit.py b/aiokafka/protocol/commit.py
index b0fda8c3..8305d360 100644
--- a/aiokafka/protocol/commit.py
+++ b/aiokafka/protocol/commit.py
@@ -275,6 +275,11 @@ class GroupCoordinatorResponse_v0(Response):
         ("port", Int32),
     )

+    error_code: int
+    coordinator_id: int
+    host: str
+    port: int
+

 class GroupCoordinatorResponse_v1(Response):
     API_KEY = 10
@@ -288,8 +293,15 @@
         ("port", Int32),
     )

+    throttle_time_ms: int
+    error_code: int
+    error_message: str
+    coordinator_id: int
+    host: str
+    port: int
+

-class GroupCoordinatorRequest_v0(Request):
+class GroupCoordinatorRequest_v0(Request[GroupCoordinatorResponse_v0]):
     API_KEY = 10
     API_VERSION = 0
     RESPONSE_TYPE = GroupCoordinatorResponse_v0
@@ -298,7 +310,7 @@
     )


-class GroupCoordinatorRequest_v1(Request):
+class GroupCoordinatorRequest_v1(Request[GroupCoordinatorResponse_v1]):
     API_KEY = 10
     API_VERSION = 1
     RESPONSE_TYPE = GroupCoordinatorResponse_v1
diff --git a/tests/_testutil.py b/tests/_testutil.py
index 67cd2f75..049afbc3 100644
--- a/tests/_testutil.py
+++ b/tests/_testutil.py
@@ -13,6 +13,7 @@
 from concurrent import futures
 from contextlib import contextmanager
 from functools import wraps
+from typing import List
 from unittest.mock import Mock

 import pytest
@@ -352,6 +353,20 @@ def kdestroy(self):
 class KafkaIntegrationTestCase(unittest.TestCase):
     topic = None

+    # from setup_test_class fixture
+    loop: asyncio.AbstractEventLoop
+    kafka_host: str
+    kafka_port: int
+    kafka_ssl_port: int
+    kafka_sasl_plain_port: int
+    kafka_sasl_ssl_port: int
+    ssl_folder: pathlib.Path
+    acl_manager: ACLManager
+    kerberos_utils: KerberosUtils
+    kafka_config: KafkaConfig
+    hosts: List[str]
+    kafka_version: str
+
     @contextmanager
     def silence_loop_exception_handler(self):
         if hasattr(self.loop, "get_exception_handler"):
diff --git a/tests/conftest.py b/tests/conftest.py
index d582386c..6231d64c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import gc
 import logging
@@ -7,6 +9,7 @@ import sys
 import uuid
 from dataclasses import dataclass
+from typing import Generator

 import pytest

@@ -21,7 +24,7 @@
 )
 from aiokafka.util import NO_EXTENSIONS

-from ._testutil import wait_kafka
+from ._testutil import ACLManager, KafkaConfig, KerberosUtils, wait_kafka

 if not NO_EXTENSIONS:
     assert (
@@ -67,23 +70,23 @@ def docker(request):


 @pytest.fixture(scope="class")
-def acl_manager(kafka_server, request):
+def acl_manager(
+    kafka_server: KafkaServer, request: pytest.FixtureRequest
+) -> ACLManager:
     image = request.config.getoption("--docker-image")
     tag = image.split(":")[-1].replace("_", "-")

-    from ._testutil import ACLManager
-
     manager = ACLManager(kafka_server.container, tag)
     return manager


 @pytest.fixture(scope="class")
-def kafka_config(kafka_server, request):
+def kafka_config(
+    kafka_server: KafkaServer, request: pytest.FixtureRequest
+) -> KafkaConfig:
     image = request.config.getoption("--docker-image")
     tag = image.split(":")[-1].replace("_", "-")

-    from ._testutil import KafkaConfig
-
     manager = KafkaConfig(kafka_server.container, tag)
     return manager

@@ -91,9 +94,7 @@
 if sys.platform != "win32":

     @pytest.fixture(scope="class")
-    def kerberos_utils(kafka_server):
-        from ._testutil import KerberosUtils
-
+    def kerberos_utils(kafka_server: KafkaServer) -> KerberosUtils:
         utils = KerberosUtils(kafka_server.container)
         utils.create_keytab()
         return utils
@@ -124,7 +125,9 @@ def kafka_image():


 @pytest.fixture(scope="session")
-def ssl_folder(docker_ip_address, docker, kafka_image):
+def ssl_folder(
+    docker_ip_address: str, docker: libdocker.DockerClient, kafka_image: str
+) -> pathlib.Path:
     ssl_dir = pathlib.Path("tests/ssl_cert")
     if ssl_dir.is_dir():
         # Skip generating certificates when they already exist. Remove
@@ -171,7 +174,7 @@ def ssl_folder(docker_ip_address, docker, kafka_image):


 @pytest.fixture(scope="session")
-def docker_ip_address():
+def docker_ip_address() -> str:
     """Returns IP address of the docker daemon service."""
     return "127.0.0.1"

@@ -210,7 +213,7 @@ def hosts(self):
     @pytest.fixture(scope="session")
     def kafka_server(
         kafka_image, docker, docker_ip_address, unused_port, session_id, ssl_folder
-    ):
+    ) -> Generator[KafkaServer, None, None]:
         kafka_host = docker_ip_address
         kafka_port = unused_port()
         kafka_ssl_port = unused_port()
@@ -316,8 +319,14 @@ def setup_test_class_serverless(request, loop):

 @pytest.fixture(scope="class")
 def setup_test_class(
-    request, loop, kafka_server, ssl_folder, acl_manager, kerberos_utils, kafka_config
-):
+    request: pytest.FixtureRequest,
+    loop: asyncio.AbstractEventLoop,
+    kafka_server: KafkaServer,
+    ssl_folder: pathlib.Path,
+    acl_manager: ACLManager,
+    kerberos_utils: KerberosUtils,
+    kafka_config: KafkaConfig,
+) -> None:
     request.cls.loop = loop
     request.cls.kafka_host = kafka_server.host
     request.cls.kafka_port = kafka_server.port
diff --git a/tests/test_conn.py b/tests/test_conn.py
index f0f4a075..11cbb3f3 100644
--- a/tests/test_conn.py
+++ b/tests/test_conn.py
@@ -1,13 +1,14 @@
 import asyncio
 import gc
 import struct
-from typing import Any
+from typing import Any, List, NoReturn, Type
 from unittest import mock

 import pytest

-from aiokafka.conn import AIOKafkaConnection, VersionInfo, create_conn
+from aiokafka.conn import AIOKafkaConnection, SaslPacket, VersionInfo, create_conn
 from aiokafka.errors import (
+    BrokerResponseError,
     CorrelationIdError,
     IllegalSaslStateError,
     KafkaConnectionError,
@@ -22,6 +23,7 @@
     SaslHandShakeRequest,
     SaslHandShakeResponse,
 )
+from aiokafka.protocol.api import Request, Response
 from aiokafka.protocol.commit import (
     GroupCoordinatorRequest_v0 as GroupCoordinatorRequest,
 )
@@ -40,7 +42,7 @@
 @pytest.mark.usefixtures("setup_test_class")
 class ConnIntegrationTest(KafkaIntegrationTestCase):
     @run_until_complete
-    async def test_ctor(self):
+    async def test_ctor(self) -> None:
         conn = AIOKafkaConnection("localhost", 1234)
         self.assertEqual("localhost", conn.host)
         self.assertEqual(1234, conn.port)
@@ -49,7 +51,7 @@ async def test_ctor(self):
         self.assertIsNone(conn._writer)

     @run_until_complete
-    async def test_global_loop_for_create_conn(self):
+    async def test_global_loop_for_create_conn(self) -> None:
         loop = get_running_loop()
         host, port = self.kafka_host, self.kafka_port
         conn = await create_conn(host, port)
@@ -60,7 +62,7 @@ async def test_global_loop_for_create_conn(self):
         conn.close()

     @run_until_complete
-    async def test_conn_warn_unclosed(self):
+    async def test_conn_warn_unclosed(self) -> None:
         host, port = self.kafka_host, self.kafka_port
         conn = await create_conn(host, port, max_idle_ms=100000)

@@ -70,7 +72,7 @@ async def test_conn_warn_unclosed(self):
         gc.collect()

     @run_until_complete
-    async def test_basic_connection_load_meta(self):
+    async def test_basic_connection_load_meta(self) -> None:
         host, port = self.kafka_host, self.kafka_port
         conn = await create_conn(host, port)

@@ -81,7 +83,7 @@ async def test_basic_connection_load_meta(self):
         self.assertIsInstance(response, MetadataResponse)

     @run_until_complete
-    async def test_connections_max_idle_ms(self):
+    async def test_connections_max_idle_ms(self) -> None:
         host, port = self.kafka_host, self.kafka_port
         conn = await create_conn(host, port, max_idle_ms=200)
         self.assertEqual(conn.connected(), True)
@@ -94,10 +96,11 @@ async def test_connections_max_idle_ms(self):
         self.assertEqual(conn.connected(), True)

         # It shouldn't break if we have a long running call either
+        assert conn._reader is not None
         readexactly = conn._reader.readexactly
         with mock.patch.object(conn._reader, "readexactly") as mocked:

-            async def long_read(n):
+            async def long_read(n: int) -> bytes:
                 await asyncio.sleep(0.2)
                 return await readexactly(n)

@@ -109,7 +112,7 @@ async def long_read(n):
         self.assertEqual(conn.connected(), False)

     @run_until_complete
-    async def test_send_without_response(self):
+    async def test_send_without_response(self) -> None:
         """Imitate producer without acknowledge, in this case client produces
         messages and kafka does not send response, and we make sure that
         futures do not stuck in queue forever"""
@@ -137,7 +140,7 @@ async def test_send_without_response(self):
         conn.close()

     @run_until_complete
-    async def test_send_to_closed(self):
+    async def test_send_to_closed(self) -> None:
         host, port = self.kafka_host, self.kafka_port
         conn = AIOKafkaConnection(host=host, port=port)
         request = MetadataRequest([])
@@ -151,7 +154,7 @@ async def test_send_to_closed(self):
             await conn.send(request)

     @run_until_complete
-    async def test_invalid_correlation_id(self):
+    async def test_invalid_correlation_id(self) -> None:
         host, port = self.kafka_host, self.kafka_port

         request = MetadataRequest([])
@@ -163,14 +166,14 @@ async def test_invalid_correlation_id(self):
         reader = mock.MagicMock()
         int32 = struct.Struct(">i")
         resp = MetadataResponse(brokers=[], topics=[])
-        resp = resp.encode()
-        resp = int32.pack(999) + resp  # set invalid correlation id
+        resp_bytes = resp.encode()
+        resp_bytes = int32.pack(999) + resp_bytes  # set invalid correlation id

-        async def first_resp(*args: Any, **kw: Any):
-            return int32.pack(len(resp))
+        async def first_resp(*args: Any, **kw: Any) -> bytes:
+            return int32.pack(len(resp_bytes))

-        async def second_resp(*args: Any, **kw: Any):
-            return resp
+        async def second_resp(*args: Any, **kw: Any) -> bytes:
+            return resp_bytes

         reader.readexactly.side_effect = [first_resp(), second_resp()]
         writer = mock.MagicMock()
@@ -184,7 +187,7 @@ async def second_resp(*args: Any, **kw: Any):
             await conn.send(request)

     @run_until_complete
-    async def test_correlation_id_on_group_coordinator_req(self):
+    async def test_correlation_id_on_group_coordinator_req(self) -> None:
         host, port = self.kafka_host, self.kafka_port

         request = GroupCoordinatorRequest(consumer_group="test")
@@ -198,14 +201,14 @@ async def test_correlation_id_on_group_coordinator_req(self):
         resp = GroupCoordinatorResponse(
             error_code=0, coordinator_id=22, host="127.0.0.1", port=3333
         )
-        resp = resp.encode()
-        resp = int32.pack(0) + resp  # set correlation id to 0
+        resp_bytes = resp.encode()
+        resp_bytes = int32.pack(0) + resp_bytes  # set correlation id to 0

-        async def first_resp(*args: Any, **kw: Any):
-            return int32.pack(len(resp))
+        async def first_resp(*args: Any, **kw: Any) -> bytes:
+            return int32.pack(len(resp_bytes))

-        async def second_resp(*args: Any, **kw: Any):
-            return resp
+        async def second_resp(*args: Any, **kw: Any) -> bytes:
+            return resp_bytes

         reader.readexactly.side_effect = [first_resp(), second_resp()]
         writer = mock.MagicMock()
@@ -223,10 +226,10 @@ async def second_resp(*args: Any, **kw: Any):
         self.assertEqual(response.port, 3333)

     @run_until_complete
-    async def test_osserror_in_reader_task(self):
+    async def test_osserror_in_reader_task(self) -> None:
         host, port = self.kafka_host, self.kafka_port

-        async def invoke_osserror(*a, **kw):
+        async def invoke_osserror(*a: Any, **kw: Any) -> NoReturn:
             await asyncio.sleep(0.1)
             raise OSError("test oserror")

@@ -249,28 +252,28 @@ async def invoke_osserror(*a, **kw):
         self.assertEqual(conn.connected(), False)

     @run_until_complete
-    async def test_close_disconnects_connection(self):
+    async def test_close_disconnects_connection(self) -> None:
         host, port = self.kafka_host, self.kafka_port
         conn = await create_conn(host, port)
         self.assertTrue(conn.connected())
         conn.close()
         self.assertFalse(conn.connected())

-    def test_connection_version_info(self):
+    def test_connection_version_info(self) -> None:
         # All version supported
-        version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: [0, 1]})
+        version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: (0, 1)})
         self.assertEqual(
             version_info.pick_best(SaslHandShakeRequest), SaslHandShakeRequest[1]
         )

         # Broker only supports the lesser version
-        version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: [0, 0]})
+        version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: (0, 0)})
         self.assertEqual(
             version_info.pick_best(SaslHandShakeRequest), SaslHandShakeRequest[0]
         )

         # We don't support any version compatible with the broker
-        version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: [2, 3]})
+        version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: (2, 3)})
         with self.assertRaises(KafkaError):
             self.assertEqual(
                 version_info.pick_best(SaslHandShakeRequest), SaslHandShakeRequest[1]
@@ -283,7 +286,7 @@ def test_connection_version_info(self):
             )

     @run_until_complete
-    async def test__do_sasl_handshake_v0(self):
+    async def test__do_sasl_handshake_v0(self) -> None:
         host, port = self.kafka_host, self.kafka_port

         # setup connection with mocked send and send_bytes
@@ -294,22 +297,22 @@ async def test__do_sasl_handshake_v0(self):
             sasl_plain_username="admin",
             sasl_plain_password="123",
         )
-        conn.close = close_mock = mock.MagicMock()
+        conn.close = close_mock = mock.MagicMock()  # type: ignore[method-assign]

         supported_mechanisms = ["PLAIN"]
-        error_class = NoError
+        error_class: Type[BrokerResponseError] = NoError

-        async def mock_send(request, expect_response=True):
+        async def mock_send(request: Request, expect_response: bool = True) -> Response:
             return SaslHandShakeResponse[0](
                 error_code=error_class.errno, enabled_mechanisms=supported_mechanisms
             )

-        async def mock_sasl_send(payload, expect_response):
+        async def mock_sasl_send(payload: bytes, expect_response: bool) -> bytes:
             return b""

-        conn.send = mock.Mock(side_effect=mock_send)
-        conn._send_sasl_token = mock.Mock(side_effect=mock_sasl_send)
-        conn._version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: [0, 0]})
+        conn.send = mock.Mock(side_effect=mock_send)  # type: ignore[method-assign]
+        conn._send_sasl_token = mock.Mock(side_effect=mock_sasl_send)  # type: ignore[method-assign]
+        conn._version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: (0, 0)})

         await conn._do_sasl_handshake()

@@ -326,7 +329,7 @@ async def mock_sasl_send(payload, expect_response):
         self.assertTrue(close_mock.call_count)

     @run_until_complete
-    async def test__do_sasl_handshake_v1(self):
+    async def test__do_sasl_handshake_v1(self) -> None:
         host, port = self.kafka_host, self.kafka_port

         # setup connection with mocked send and send_bytes
@@ -338,13 +341,13 @@ async def test__do_sasl_handshake_v1(self):
             sasl_plain_password="123",
             security_protocol="SASL_PLAINTEXT",
         )
-        conn.close = close_mock = mock.MagicMock()
+        conn.close = close_mock = mock.MagicMock()  # type: ignore[method-assign]

         supported_mechanisms = ["PLAIN"]
-        error_class = NoError
-        auth_error_class = NoError
+        error_class: Type[BrokerResponseError] = NoError
+        auth_error_class: Type[BrokerResponseError] = NoError

-        async def mock_send(request, expect_response=True):
+        async def mock_send(request: Request, expect_response: bool = True) -> Response:
             if request.API_KEY == SaslHandShakeRequest[0].API_KEY:
                 assert request.API_VERSION == 1
                 return SaslHandShakeResponse[1](
@@ -359,8 +362,8 @@ async def mock_send(request, expect_response=True):
                     sasl_auth_bytes=b"",
                 )

-        conn.send = mock.Mock(side_effect=mock_send)
-        conn._version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: [0, 1]})
+        conn.send = mock.Mock(side_effect=mock_send)  # type: ignore[method-assign]
+        conn._version_info = VersionInfo({SaslHandShakeRequest[0].API_KEY: (0, 1)})

         await conn._do_sasl_handshake()

@@ -386,16 +389,16 @@ async def mock_send(request, expect_response=True):
         self.assertTrue(close_mock.call_count)

     @run_until_complete
-    async def test__send_sasl_token(self):
+    async def test__send_sasl_token(self) -> None:
         # Before Kafka 1.0.0 SASL was performed on the wire without
         # KAFKA_HEADER in the protocol. So we needed another private
         # function to send `raw` data with only length prefixed

         # setup connection with mocked transport and protocol
         conn = AIOKafkaConnection(host="", port=9999)
-        conn.close = mock.MagicMock()
+        conn.close = mock.MagicMock()  # type: ignore[method-assign]
         conn._writer = mock.MagicMock()
-        out_buffer = []
+        out_buffer: List[bytes] = []
         conn._writer.write = mock.Mock(side_effect=out_buffer.append)
         conn._reader = mock.MagicMock()
         self.assertEqual(len(conn._requests), 0)
@@ -407,20 +410,21 @@ async def test__send_sasl_token(self):
         out_buffer.clear()

         # Resolve the request
-        conn._requests[0][2].set_result(None)
+        assert isinstance(conn._requests[0], SaslPacket)
+        conn._requests[0][2].set_result(b"")
         conn._requests.clear()
         await fut

         # Broken pipe error
         conn._writer.write.side_effect = OSError
         with self.assertRaises(KafkaConnectionError):
-            conn._send_sasl_token(b"Super data")
+            await conn._send_sasl_token(b"Super data")
         self.assertEqual(out_buffer, [])
         self.assertEqual(len(conn._requests), 0)
         self.assertEqual(conn.close.call_count, 1)

         conn._writer = None
         with self.assertRaises(KafkaConnectionError):
-            conn._send_sasl_token(b"Super data")
+            await conn._send_sasl_token(b"Super data")
         # We don't need to close 2ce
         self.assertEqual(conn.close.call_count, 1)