Skip to content

Commit

Permalink
add typing to tests/test_conn.py, fix review
Browse files Browse the repository at this point in the history
  • Loading branch information
dimastbk committed Jul 1, 2024
1 parent 2fe0965 commit 03c7499
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 75 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ FORMATTED_AREAS=\
aiokafka/protocol/ \
aiokafka/record/ \
tests/test_codec.py \
tests/test_conn.py \
tests/test_helpers.py \
tests/test_protocol.py \
tests/test_protocol_object_conversion.py \
Expand Down
18 changes: 15 additions & 3 deletions aiokafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@
from aiokafka.protocol.admin import (
ApiVersionRequest,
SaslAuthenticateRequest,
SaslAuthenticateResponse_v0,
SaslAuthenticateResponse_v1,
SaslHandShakeRequest,
SaslHandShakeResponse_v0,
SaslHandShakeResponse_v1,
)
from aiokafka.protocol.api import Request, Response
from aiokafka.protocol.commit import (
Expand Down Expand Up @@ -100,7 +104,6 @@ def __init__(self, versions: Dict[int, Tuple[int, int]]) -> None:
self._versions = versions

def pick_best(self, request_versions: Sequence[Type[RequestT]]) -> Type[RequestT]:
api_key = request_versions[0].API_KEY
api_key = cast(int, request_versions[0].API_KEY)

Check warning on line 107 in aiokafka/conn.py

View check run for this annotation

Codecov / codecov/patch

aiokafka/conn.py#L107

Added line #L107 was not covered by tests
if api_key not in self._versions:
return request_versions[0]
Expand Down Expand Up @@ -173,7 +176,7 @@ def __init__(
**kw: Any,
) -> None:
self._closed_fut = closed_fut
super().__init__(*args, loop=loop, **kw)
super().__init__(*args, loop=loop, **kw) # type: ignore[misc]

Check warning on line 179 in aiokafka/conn.py

View check run for this annotation

Codecov / codecov/patch

aiokafka/conn.py#L179

Added line #L179 was not covered by tests

def connection_lost(self, exc: Optional[Exception]) -> None:
super().connection_lost(exc)
Expand Down Expand Up @@ -359,6 +362,9 @@ async def _do_sasl_handshake(self) -> None:

sasl_handshake = handshake_klass(self._sasl_mechanism)
response = await self.send(sasl_handshake)
response = cast(

Check warning on line 365 in aiokafka/conn.py

View check run for this annotation

Codecov / codecov/patch

aiokafka/conn.py#L365

Added line #L365 was not covered by tests
Union[SaslHandShakeResponse_v0, SaslHandShakeResponse_v1], response
)
error_type = Errors.for_code(response.error_code)
if error_type is not Errors.NoError:
error = error_type(self)
Expand Down Expand Up @@ -419,6 +425,10 @@ async def _do_sasl_handshake(self) -> None:
else:
req = auth_klass(payload)
resp = await self.send(req)
resp = cast(

Check warning on line 428 in aiokafka/conn.py

View check run for this annotation

Codecov / codecov/patch

aiokafka/conn.py#L428

Added line #L428 was not covered by tests
Union[SaslAuthenticateResponse_v0, SaslAuthenticateResponse_v1],
resp,
)
error_type = Errors.for_code(resp.error_code)
if error_type is not Errors.NoError:
exc = error_type(resp.error_message)
Expand Down Expand Up @@ -583,6 +593,8 @@ def _send_sasl_token(
self, payload: bytes, expect_response: Literal[True]
) -> Coroutine[None, None, bytes]: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
@overload
def _send_sasl_token(self, payload: bytes) -> Coroutine[None, None, bytes]: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
@overload
def _send_sasl_token(
self, payload: bytes, expect_response: bool
) -> Union[Coroutine[None, None, None], Coroutine[None, None, bytes]]: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
Expand Down Expand Up @@ -682,7 +694,7 @@ async def _read(self_ref: weakref.ReferenceType[AIOKafkaConnection]) -> None:
def _handle_frame(self, resp: bytes) -> None:
packet = self._requests[0]

Check warning on line 695 in aiokafka/conn.py

View check run for this annotation

Codecov / codecov/patch

aiokafka/conn.py#L695

Added line #L695 was not covered by tests

if packet.correlation_id is None: # Is a SASL packet, just pass it though
if isinstance(packet, SaslPacket): # Is a SASL packet, just pass it though
if not packet.fut.done():
packet.fut.set_result(resp)

Check warning on line 699 in aiokafka/conn.py

View check run for this annotation

Codecov / codecov/patch

aiokafka/conn.py#L699

Added line #L699 was not covered by tests
else:
Expand Down
6 changes: 3 additions & 3 deletions aiokafka/protocol/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ class SaslHandShakeRequest_v1(Request[SaslHandShakeResponse_v1]):


SaslHandShakeRequest = (SaslHandShakeRequest_v0, SaslHandShakeRequest_v1)
SaslHandShakeResponse = [SaslHandShakeResponse_v0, SaslHandShakeResponse_v1]
SaslHandShakeResponse = (SaslHandShakeResponse_v0, SaslHandShakeResponse_v1)


class DescribeAclsResponse_v0(Response):
Expand Down Expand Up @@ -1048,10 +1048,10 @@ class SaslAuthenticateRequest_v1(Request[SaslAuthenticateResponse_v1]):
SaslAuthenticateRequest_v0,
SaslAuthenticateRequest_v1,
)
SaslAuthenticateResponse = [
SaslAuthenticateResponse = (
SaslAuthenticateResponse_v0,
SaslAuthenticateResponse_v1,
]
)


class CreatePartitionsResponse_v0(Response):
Expand Down
16 changes: 14 additions & 2 deletions aiokafka/protocol/commit.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,11 @@ class GroupCoordinatorResponse_v0(Response):
("port", Int32),
)

error_code: int
coordinator_id: int
host: str
port: int


class GroupCoordinatorResponse_v1(Response):
API_KEY = 10
Expand All @@ -288,8 +293,15 @@ class GroupCoordinatorResponse_v1(Response):
("port", Int32),
)

throttle_time_ms: int
error_code: int
error_message: str
coordinator_id: int
host: str
port: int


class GroupCoordinatorRequest_v0(Request):
class GroupCoordinatorRequest_v0(Request[GroupCoordinatorResponse_v0]):
API_KEY = 10
API_VERSION = 0
RESPONSE_TYPE = GroupCoordinatorResponse_v0
Expand All @@ -298,7 +310,7 @@ class GroupCoordinatorRequest_v0(Request):
)


class GroupCoordinatorRequest_v1(Request):
class GroupCoordinatorRequest_v1(Request[GroupCoordinatorResponse_v1]):
API_KEY = 10
API_VERSION = 1
RESPONSE_TYPE = GroupCoordinatorResponse_v1
Expand Down
15 changes: 15 additions & 0 deletions tests/_testutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from concurrent import futures
from contextlib import contextmanager
from functools import wraps
from typing import List
from unittest.mock import Mock

import pytest
Expand Down Expand Up @@ -352,6 +353,20 @@ def kdestroy(self):
class KafkaIntegrationTestCase(unittest.TestCase):
topic = None

# from setup_test_class fixture
loop: asyncio.AbstractEventLoop
kafka_host: str
kafka_port: int
kafka_ssl_port: int
kafka_sasl_plain_port: int
kafka_sasl_ssl_port: int
ssl_folder: pathlib.Path
acl_manager: ACLManager
kerberos_utils: KerberosUtils
kafka_config: KafkaConfig
hosts: List[str]
kafka_version: str

@contextmanager
def silence_loop_exception_handler(self):
if hasattr(self.loop, "get_exception_handler"):
Expand Down
39 changes: 24 additions & 15 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import gc
import logging
Expand All @@ -7,6 +9,7 @@
import sys
import uuid
from dataclasses import dataclass
from typing import Generator

import pytest

Expand All @@ -21,7 +24,7 @@
)
from aiokafka.util import NO_EXTENSIONS

from ._testutil import wait_kafka
from ._testutil import ACLManager, KafkaConfig, KerberosUtils, wait_kafka

if not NO_EXTENSIONS:
assert (
Expand Down Expand Up @@ -67,33 +70,31 @@ def docker(request):


@pytest.fixture(scope="class")
def acl_manager(kafka_server, request):
def acl_manager(
kafka_server: KafkaServer, request: pytest.FixtureRequest
) -> ACLManager:
image = request.config.getoption("--docker-image")
tag = image.split(":")[-1].replace("_", "-")

from ._testutil import ACLManager

manager = ACLManager(kafka_server.container, tag)
return manager


@pytest.fixture(scope="class")
def kafka_config(kafka_server, request):
def kafka_config(
kafka_server: KafkaServer, request: pytest.FixtureRequest
) -> KafkaConfig:
image = request.config.getoption("--docker-image")
tag = image.split(":")[-1].replace("_", "-")

from ._testutil import KafkaConfig

manager = KafkaConfig(kafka_server.container, tag)
return manager


if sys.platform != "win32":

@pytest.fixture(scope="class")
def kerberos_utils(kafka_server):
from ._testutil import KerberosUtils

def kerberos_utils(kafka_server: KafkaServer) -> KerberosUtils:
utils = KerberosUtils(kafka_server.container)
utils.create_keytab()
return utils
Expand Down Expand Up @@ -124,7 +125,9 @@ def kafka_image():


@pytest.fixture(scope="session")
def ssl_folder(docker_ip_address, docker, kafka_image):
def ssl_folder(
docker_ip_address: str, docker: libdocker.DockerClient, kafka_image: str
) -> pathlib.Path:
ssl_dir = pathlib.Path("tests/ssl_cert")
if ssl_dir.is_dir():
# Skip generating certificates when they already exist. Remove
Expand Down Expand Up @@ -171,7 +174,7 @@ def ssl_folder(docker_ip_address, docker, kafka_image):


@pytest.fixture(scope="session")
def docker_ip_address():
def docker_ip_address() -> str:
"""Returns IP address of the docker daemon service."""
return "127.0.0.1"

Expand Down Expand Up @@ -210,7 +213,7 @@ def hosts(self):
@pytest.fixture(scope="session")
def kafka_server(
kafka_image, docker, docker_ip_address, unused_port, session_id, ssl_folder
):
) -> Generator[KafkaServer, None, None]:
kafka_host = docker_ip_address
kafka_port = unused_port()
kafka_ssl_port = unused_port()
Expand Down Expand Up @@ -316,8 +319,14 @@ def setup_test_class_serverless(request, loop):

@pytest.fixture(scope="class")
def setup_test_class(
request, loop, kafka_server, ssl_folder, acl_manager, kerberos_utils, kafka_config
):
request: pytest.FixtureRequest,
loop: asyncio.AbstractEventLoop,
kafka_server: KafkaServer,
ssl_folder: pathlib.Path,
acl_manager: ACLManager,
kerberos_utils: KerberosUtils,
kafka_config: KafkaConfig,
) -> None:
request.cls.loop = loop
request.cls.kafka_host = kafka_server.host
request.cls.kafka_port = kafka_server.port
Expand Down
Loading

0 comments on commit 03c7499

Please sign in to comment.