diff --git a/Makefile b/Makefile index f0214347..83c872bb 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,7 @@ DOCKER_IMAGE=aiolibs/kafka:$(SCALA_VERSION)_$(KAFKA_VERSION) DIFF_BRANCH=origin/master FORMATTED_AREAS=\ aiokafka/codec.py \ + aiokafka/conn.py \ aiokafka/coordinator/ \ aiokafka/errors.py \ aiokafka/helpers.py \ diff --git a/aiokafka/abc.py b/aiokafka/abc.py index abb9f216..f9264574 100644 --- a/aiokafka/abc.py +++ b/aiokafka/abc.py @@ -1,4 +1,5 @@ import abc +from typing import Dict class ConsumerRebalanceListener(abc.ABC): @@ -103,7 +104,7 @@ class AbstractTokenProvider(abc.ABC): """ @abc.abstractmethod - async def token(self): + async def token(self) -> str: """ An async callback returning a :class:`str` ID/Access Token to be sent to the Kafka client. In case where a synchronous callback is needed, @@ -122,7 +123,7 @@ def _token(self): # The actual synchronous token callback. """ - def extensions(self): + def extensions(self) -> Dict[str, str]: """ This is an OPTIONAL method that may be implemented. diff --git a/aiokafka/conn.py b/aiokafka/conn.py index a2402b72..11aede0d 100644 --- a/aiokafka/conn.py +++ b/aiokafka/conn.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import asyncio import base64 import collections +import enum import functools import hashlib import hmac @@ -8,6 +11,7 @@ import logging import random import socket +import ssl import struct import sys import time @@ -15,8 +19,29 @@ import uuid import warnings import weakref +from typing import ( + Any, + Awaitable, + Callable, + Coroutine, + Dict, + Generator, + Iterable, + List, + Literal, + NamedTuple, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) import async_timeout +from typing_extensions import Buffer import aiokafka.errors as Errors from aiokafka.abc import AbstractTokenProvider @@ -25,6 +50,7 @@ SaslAuthenticateRequest, SaslHandShakeRequest, ) +from aiokafka.protocol.api import Request, Response from aiokafka.protocol.commit import ( GroupCoordinatorResponse_v0 as GroupCoordinatorResponse, ) @@ -33,7 +59,10 @@ try: import gssapi except ImportError: - gssapi = None + gssapi = None # type: ignore[assignment] + +RequestT = TypeVar("RequestT", bound=Request) +ResponseT = TypeVar("ResponseT", bound=Response) __all__ = ["AIOKafkaConnection", "create_conn"] @@ -45,7 +74,19 @@ SASL_QOP_AUTH = 1 -class CloseReason: +class Packet(NamedTuple): + correlation_id: int + request: Request[Response] + fut: asyncio.Future[Response] + + +class SaslPacket(NamedTuple): + correlation_id: None + request: None + fut: asyncio.Future[bytes] + + +class CloseReason(enum.IntEnum): CONNECTION_BROKEN = 0 CONNECTION_TIMEOUT = 1 OUT_OF_SYNC = 2 @@ -55,17 +96,19 @@ class CloseReason: class VersionInfo: - def __init__(self, versions): + def __init__(self, versions: Dict[int, Tuple[int, int]]) -> None: self._versions = versions - def pick_best(self, request_versions): + def pick_best(self, request_versions: Sequence[Type[RequestT]]) -> Type[RequestT]: api_key = request_versions[0].API_KEY + api_key = cast(int, request_versions[0].API_KEY) if api_key not in self._versions: return request_versions[0] min_version, max_version = self._versions[api_key] for req_klass in reversed(request_versions): - if min_version <= req_klass.API_VERSION <= max_version: + req_api_version = cast(int, req_klass.API_VERSION) + if min_version <= req_api_version <= max_version: return req_klass raise Errors.KafkaError( @@ -75,24 +118,30 @@ def pick_best(self, request_versions): async def create_conn( - host, - port, + host: str, + port: int, *, - client_id="aiokafka", - request_timeout_ms=40000, - api_version=(0, 8, 2), - ssl_context=None, - security_protocol="PLAINTEXT", - max_idle_ms=None, - on_close=None, - sasl_mechanism=None, - sasl_plain_username=None, - sasl_plain_password=None, - sasl_kerberos_service_name="kafka", - sasl_kerberos_domain_name=None, - sasl_oauth_token_provider=None, - version_hint=None, -): + client_id: str = "aiokafka", + request_timeout_ms: float = 40000, + api_version: Union[Tuple[int, int], Tuple[int, int, int]] = (0, 8, 2), + ssl_context: Optional[ssl.SSLContext] = None, + security_protocol: Literal[ + "PLAINTEXT", "SASL_PLAINTEXT", "SSL", "SASL_SSL" + ] = "PLAINTEXT", + max_idle_ms: Optional[float] = None, + on_close: Optional[ + Callable[[AIOKafkaConnection, Optional[CloseReason]], None] + ] = None, + sasl_mechanism: Optional[ + Literal["PLAIN", "GSSAPI", "SCRAM-SHA-256", "SCRAM-SHA-512", "OAUTHBEARER"] + ] = None, + sasl_plain_username: Optional[str] = None, + sasl_plain_password: Optional[str] = None, + sasl_kerberos_service_name: str = "kafka", + sasl_kerberos_domain_name: Optional[str] = None, + sasl_oauth_token_provider: Optional[AbstractTokenProvider] = None, + version_hint: Optional[Union[Tuple[int, int], Tuple[int, int, int]]] = None, +) -> AIOKafkaConnection: conn = AIOKafkaConnection( host, port, @@ -116,11 +165,17 @@ async def create_conn( class AIOKafkaProtocol(asyncio.StreamReaderProtocol): - def __init__(self, closed_fut, *args, loop, **kw): + def __init__( + self, + closed_fut: asyncio.Future[None], + *args: Any, + loop: asyncio.AbstractEventLoop, + **kw: Any, + ) -> None: self._closed_fut = closed_fut super().__init__(*args, loop=loop, **kw) - def connection_lost(self, exc): + def connection_lost(self, exc: Optional[Exception]) -> None: super().connection_lost(exc) if not self._closed_fut.cancelled(): self._closed_fut.set_result(None) @@ -129,29 +184,37 @@ def connection_lost(self, exc): class AIOKafkaConnection: """Class for manage connection to Kafka node""" - _reader = None # For __del__ to work properly, just in case - _source_traceback = None + _reader: Optional[asyncio.StreamReader] = ( + None # For __del__ to work properly, just in case + ) + _source_traceback: Optional[traceback.StackSummary] = None def __init__( self, - host, - port, + host: str, + port: int, *, - client_id="aiokafka", - request_timeout_ms=40000, - api_version=(0, 8, 2), - ssl_context=None, - security_protocol="PLAINTEXT", - max_idle_ms=None, - on_close=None, - sasl_mechanism=None, - sasl_plain_password=None, - sasl_plain_username=None, - sasl_kerberos_service_name="kafka", - sasl_kerberos_domain_name=None, - sasl_oauth_token_provider=None, - version_hint=None, - ): + client_id: str = "aiokafka", + request_timeout_ms: float = 40000, + api_version: Union[Tuple[int, int], Tuple[int, int, int]] = (0, 8, 2), + ssl_context: Optional[ssl.SSLContext] = None, + security_protocol: Literal[ + "PLAINTEXT", "SASL_PLAINTEXT", "SSL", "SASL_SSL" + ] = "PLAINTEXT", + max_idle_ms: Optional[float] = None, + on_close: Optional[ + Callable[[AIOKafkaConnection, Optional[CloseReason]], None] + ] = None, + sasl_mechanism: Optional[ + Literal["PLAIN", "GSSAPI", "SCRAM-SHA-256", "SCRAM-SHA-512", "OAUTHBEARER"] + ] = None, + sasl_plain_password: Optional[str] = None, + sasl_plain_username: Optional[str] = None, + sasl_kerberos_service_name: str = "kafka", + sasl_kerberos_domain_name: Optional[str] = None, + sasl_oauth_token_provider: Optional[AbstractTokenProvider] = None, + version_hint: Optional[Union[Tuple[int, int], Tuple[int, int, int]]] = None, + ) -> None: loop = get_running_loop() if sasl_mechanism == "GSSAPI": @@ -188,17 +251,21 @@ def __init__( self._version_hint = version_hint self._version_info = VersionInfo({}) - self._reader = self._writer = self._protocol = None + self._reader: Optional[asyncio.StreamReader] = None + self._writer: Optional[asyncio.StreamWriter] = None + self._protocol: Optional[AIOKafkaProtocol] = None # Even on small size seems to be a bit faster than list. # ~2x on size of 2 in Python3.6 - self._requests = collections.deque() - self._read_task = None - self._correlation_id = 0 - self._closed_fut = None + self._requests: collections.deque[Union[Packet, SaslPacket]] = ( + collections.deque() + ) + self._read_task: Optional[asyncio.Task[None]] = None + self._correlation_id: int = 0 + self._closed_fut: Optional[asyncio.Future[None]] = None self._max_idle_ms = max_idle_ms self._last_action = time.monotonic() - self._idle_handle = None + self._idle_handle: Optional[asyncio.Handle] = None self._on_close_cb = on_close @@ -207,7 +274,7 @@ def __init__( # Warn and try to close. We can close synchronously, so will attempt # that - def __del__(self, _warnings=warnings): + def __del__(self, _warnings=warnings) -> None: # type: ignore[no-untyped-def] if self.connected(): _warnings.warn( f"Unclosed AIOKafkaConnection {self!r}", @@ -230,7 +297,7 @@ def __del__(self, _warnings=warnings): context["source_traceback"] = self._source_traceback self._loop.call_exception_handler(context) - async def connect(self): + async def connect(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: loop = self._loop self._closed_fut = create_future() if self._security_protocol in ["PLAINTEXT", "SASL_PLAINTEXT"]: @@ -268,10 +335,10 @@ async def connect(self): return reader, writer - async def _do_version_lookup(self): + async def _do_version_lookup(self) -> None: version_req = ApiVersionRequest[0]() response = await self.send(version_req) - versions = {} + versions: Dict[int, Tuple[int, int]] = {} for api_key, min_version, max_version in response.api_versions: assert min_version <= max_version, ( f"{min_version} should be less than" @@ -280,9 +347,10 @@ async def _do_version_lookup(self): versions[api_key] = (min_version, max_version) self._version_info = VersionInfo(versions) - async def _do_sasl_handshake(self): + async def _do_sasl_handshake(self) -> None: # NOTE: We will only fallback to v0.9 gssapi scheme if user explicitly # stated, that api_version is "0.9" + exc: Errors.KafkaError if self._version_hint and self._version_hint < (0, 10): handshake_klass = None assert self._sasl_mechanism == "GSSAPI", "Only GSSAPI supported for v0.9" @@ -318,6 +386,7 @@ async def _do_sasl_handshake(self): ): log.warning("Sending username and password in the clear") + authenticator: BaseSaslAuthenticator if self._sasl_mechanism == "GSSAPI": authenticator = self.authenticator_gssapi() elif self._sasl_mechanism.startswith("SCRAM-SHA-"): @@ -332,7 +401,7 @@ async def _do_sasl_handshake(self): else: auth_klass = None - auth_bytes = None + auth_bytes: Optional[bytes] = None expect_response = True while True: @@ -368,41 +437,51 @@ async def _do_sasl_handshake(self): self._sasl_mechanism, ) - def authenticator_plain(self): + def authenticator_plain(self) -> SaslPlainAuthenticator: + assert self._sasl_plain_password is not None + assert self._sasl_plain_username is not None return SaslPlainAuthenticator( loop=self._loop, sasl_plain_password=self._sasl_plain_password, sasl_plain_username=self._sasl_plain_username, ) - def authenticator_gssapi(self): + def authenticator_gssapi(self) -> SaslGSSAPIAuthenticator: return SaslGSSAPIAuthenticator( loop=self._loop, principal=self.sasl_principal, ) - def authenticator_scram(self): + def authenticator_scram(self) -> ScramAuthenticator: + assert self._sasl_plain_password is not None + assert self._sasl_plain_username is not None + assert self._sasl_mechanism in ("SCRAM-SHA-256", "SCRAM-SHA-512") return ScramAuthenticator( loop=self._loop, sasl_plain_password=self._sasl_plain_password, sasl_plain_username=self._sasl_plain_username, - sasl_mechanism=self._sasl_mechanism, + sasl_mechanism=self._sasl_mechanism, # type: ignore[arg-type] ) - def authenticator_oauth(self): + def authenticator_oauth(self) -> OAuthAuthenticator: + assert self._sasl_oauth_token_provider is not None return OAuthAuthenticator( sasl_oauth_token_provider=self._sasl_oauth_token_provider, ) @property - def sasl_principal(self): + def sasl_principal(self) -> str: service = self._sasl_kerberos_service_name domain = self._sasl_kerberos_domain_name or self.host return f"{service}@{domain}" @classmethod - def _on_read_task_error(cls, self_ref, read_task): + def _on_read_task_error( + cls, + self_ref: weakref.ReferenceType[AIOKafkaConnection], + read_task: asyncio.Task[None], + ) -> None: # We don't want to react to cancelled errors if read_task.cancelled(): return @@ -418,12 +497,13 @@ def _on_read_task_error(cls, self_ref, read_task): self.close(reason=CloseReason.CONNECTION_BROKEN, exc=exc) @staticmethod - def _idle_check(self_ref): + def _idle_check(self_ref: weakref.ReferenceType[AIOKafkaConnection]) -> None: self = self_ref() if self is None: return idle_for = time.monotonic() - self._last_action + assert self._max_idle_ms is not None timeout = self._max_idle_ms / 1000 # If we have any pending requests, we are assumed to be not idle. # it's up to `request_timeout_ms` to break those. @@ -440,18 +520,31 @@ def _idle_check(self_ref): wake_up_in, self._idle_check, self_ref ) - def __repr__(self): + def __repr__(self) -> str: return f"" @property - def host(self): + def host(self) -> str: return self._host @property - def port(self): + def port(self) -> int: return self._port - def send(self, request, expect_response=True): + @overload + def send(self, request: Request[ResponseT]) -> Coroutine[None, None, ResponseT]: ... + @overload + def send( + self, request: Request[ResponseT], expect_response: Literal[False] + ) -> Coroutine[None, None, None]: ... + @overload + def send( + self, request: Request[ResponseT], expect_response: Literal[True] + ) -> Coroutine[None, None, ResponseT]: ... + + def send( + self, request: Request[ResponseT], expect_response: bool = True + ) -> Union[Coroutine[None, None, ResponseT], Coroutine[None, None, None]]: if self._writer is None: raise Errors.KafkaConnectionError( f"No connection to broker at {self._host}:{self._port}" @@ -477,11 +570,26 @@ def send(self, request, expect_response=True): return self._writer.drain() fut = self._loop.create_future() self._requests.append( - (correlation_id, request, fut), + Packet(correlation_id, request, fut), ) return wait_for(fut, self._request_timeout) - def _send_sasl_token(self, payload, expect_response=True): + @overload + def _send_sasl_token( + self, payload: bytes, expect_response: Literal[False] + ) -> Coroutine[None, None, None]: ... + @overload + def _send_sasl_token( + self, payload: bytes, expect_response: Literal[True] + ) -> Coroutine[None, None, bytes]: ... + @overload + def _send_sasl_token( + self, payload: bytes, expect_response: bool + ) -> Union[Coroutine[None, None, None], Coroutine[None, None, bytes]]: ... + + def _send_sasl_token( + self, payload: bytes, expect_response: bool = True + ) -> Union[Coroutine[None, None, None], Coroutine[None, None, bytes]]: if self._writer is None: raise Errors.KafkaConnectionError( f"No connection to broker at {self._host}:{self._port}" @@ -499,17 +607,21 @@ def _send_sasl_token(self, payload, expect_response=True): return self._writer.drain() fut = self._loop.create_future() - self._requests.append((None, None, fut)) + self._requests.append(SaslPacket(None, None, fut)) return wait_for(fut, self._request_timeout) - def connected(self): + def connected(self) -> bool: return bool(self._reader is not None and not self._reader.at_eof()) - def close(self, reason=None, exc=None): + def close( + self, reason: Optional[CloseReason] = None, exc: Optional[Exception] = None + ) -> asyncio.Future[None]: log.debug("Closing connection at %s:%s", self._host, self._port) if self._reader is not None: + assert self._writer is not None self._writer.close() self._writer = self._reader = None + assert self._read_task is not None if not self._read_task.done(): self._read_task.cancel() self._read_task = None @@ -531,9 +643,10 @@ def close(self, reason=None, exc=None): # transport.close() will close socket, but not right ahead. Return # a future in case we need to wait on it. + assert self._closed_fut is not None return self._closed_fut - def _create_reader_task(self): + def _create_reader_task(self) -> asyncio.Task[None]: self_ref = weakref.ref(self) read_task = create_task(self._read(self_ref)) read_task.add_done_callback( @@ -542,7 +655,7 @@ def _create_reader_task(self): return read_task @staticmethod - async def _read(self_ref): + async def _read(self_ref: weakref.ReferenceType[AIOKafkaConnection]) -> None: # XXX: I know that it become a bit more ugly once cyclic references # were removed, but it's needed to allow connections to properly # release resources if leaked. @@ -552,6 +665,7 @@ async def _read(self_ref): return reader = self._reader del self + assert reader is not None while True: resp = await reader.readexactly(4) @@ -565,15 +679,16 @@ async def _read(self_ref): self._handle_frame(resp) del self - def _handle_frame(self, resp): - correlation_id, request, fut = self._requests[0] + def _handle_frame(self, resp: bytes) -> None: + packet = self._requests[0] - if correlation_id is None: # Is a SASL packet, just pass it though - if not fut.done(): - fut.set_result(resp) + if packet.correlation_id is None: # Is a SASL packet, just pass it though + if not packet.fut.done(): + packet.fut.set_result(resp) else: - resp = io.BytesIO(resp) - response_header = request.parse_response_header(resp) + correlation_id, request, fut = packet + resp_io = io.BytesIO(resp) + response_header = request.parse_response_header(resp_io) resp_type = request.RESPONSE_TYPE if ( @@ -600,7 +715,7 @@ def _handle_frame(self, resp): return if not fut.done(): - response = resp_type.decode(resp) + response = resp_type.decode(resp_io) log.debug("%s Response %d: %s", self, correlation_id, response) fut.set_result(response) @@ -611,23 +726,30 @@ def _handle_frame(self, resp): # this future. self._requests.popleft() - def _next_correlation_id(self): + def _next_correlation_id(self) -> int: self._correlation_id = (self._correlation_id + 1) % 2**31 return self._correlation_id class BaseSaslAuthenticator: - def step(self, payload): + # FIXME: move to __init__? + _loop: asyncio.AbstractEventLoop + _authenticator: Generator[Tuple[bytes, bool], bytes, None] + + def step(self, payload: Optional[bytes]) -> Awaitable[Optional[Tuple[bytes, bool]]]: return self._loop.run_in_executor(None, self._step, payload) - def _step(self, payload): + def _step(self, payload: Optional[bytes]) -> Optional[Tuple[bytes, bool]]: """Process next token in sequence and return with: ``None`` if it was the last needed exchange ``tuple`` tuple with new token and a boolean whether it requires an answer token """ try: - data = self._authenticator.send(payload) + if payload is None: + data = next(self._authenticator) + else: + data = self._authenticator.send(payload) except StopIteration: return None else: @@ -635,13 +757,19 @@ def _step(self, payload): class SaslPlainAuthenticator(BaseSaslAuthenticator): - def __init__(self, *, loop, sasl_plain_password, sasl_plain_username): + def __init__( + self, + *, + loop: asyncio.AbstractEventLoop, + sasl_plain_password: str, + sasl_plain_username: str, + ) -> None: self._loop = loop self._sasl_plain_username = sasl_plain_username self._sasl_plain_password = sasl_plain_password self._authenticator = self.authenticator_plain() - def authenticator_plain(self): + def authenticator_plain(self) -> Generator[Tuple[bytes, bool], bytes, None]: """Automaton to authenticate with SASL tokens""" # Send PLAIN credentials per RFC-4616 data = "\0".join( @@ -658,12 +786,12 @@ def authenticator_plain(self): class SaslGSSAPIAuthenticator(BaseSaslAuthenticator): - def __init__(self, *, loop, principal): + def __init__(self, *, loop: asyncio.AbstractEventLoop, principal: str) -> None: self._loop = loop self._principal = principal self._authenticator = self.authenticator_gssapi() - def authenticator_gssapi(self): + def authenticator_gssapi(self) -> Generator[Tuple[bytes, bool], bytes, None]: name = gssapi.Name( self._principal, name_type=gssapi.NameType.hostbased_service, @@ -679,6 +807,7 @@ def authenticator_gssapi(self): server_token = yield client_token, True + assert server_token is not None msg = client_ctx.unwrap(server_token).message qop = struct.pack("b", SASL_QOP_AUTH & msg[0]) @@ -697,33 +826,33 @@ class ScramAuthenticator(BaseSaslAuthenticator): def __init__( self, *, - loop, - sasl_plain_password, - sasl_plain_username, - sasl_mechanism, - ): + loop: asyncio.AbstractEventLoop, + sasl_plain_password: str, + sasl_plain_username: str, + sasl_mechanism: Literal["SCRAM-SHA-256", "SCRAM-SHA-512"], + ) -> None: self._loop = loop self._nonce = str(uuid.uuid4()).replace("-", "") self._auth_message = "" - self._salted_password = None + self._salted_password: Optional[bytes] = None self._sasl_plain_username = sasl_plain_username self._sasl_plain_password = sasl_plain_password.encode("utf-8") self._hashfunc = self.MECHANISMS[sasl_mechanism] self._hashname = "".join(sasl_mechanism.lower().split("-")[1:3]) - self._stored_key = None - self._client_key = None - self._client_signature = None - self._client_proof = None - self._server_key = None - self._server_signature = None + self._stored_key: Optional[bytes] = None + self._client_key: Optional[bytes] = None + self._client_signature: Optional[bytes] = None + self._client_proof: Optional[bytes] = None + self._server_key: Optional[bytes] = None + self._server_signature: Optional[bytes] = None self._authenticator = self.authenticator_scram() - def first_message(self): + def first_message(self) -> str: client_first_bare = f"n={self._sasl_plain_username},r={self._nonce}" self._auth_message += client_first_bare return "n,," + client_first_bare - def process_server_first_message(self, server_first): + def process_server_first_message(self, server_first: str) -> None: self._auth_message += "," + server_first params = dict(pair.split("=", 1) for pair in server_first.split(",")) server_nonce = params["r"] @@ -734,8 +863,10 @@ def process_server_first_message(self, server_first): salt = base64.b64decode(params["s"].encode("utf-8")) iterations = int(params["i"]) - self.create_salted_password(salt, iterations) + self._salted_password = hashlib.pbkdf2_hmac( + self._hashname, self._sasl_plain_password, salt, iterations + ) self._client_key = self.hmac(self._salted_password, b"Client Key") self._stored_key = self._hashfunc(self._client_key).digest() self._client_signature = self.hmac( @@ -749,16 +880,17 @@ def process_server_first_message(self, server_first): self._server_key, self._auth_message.encode("utf-8") ) - def final_message(self): + def final_message(self) -> str: + assert self._client_proof is not None client_proof = base64.b64encode(self._client_proof).decode("utf-8") return f"c=biws,r={self._nonce},p={client_proof}" - def process_server_final_message(self, server_final): + def process_server_final_message(self, server_final: str) -> None: params = dict(pair.split("=", 1) for pair in server_final.split(",")) if self._server_signature != base64.b64decode(params["v"].encode("utf-8")): raise ValueError("Server sent wrong signature!") - def authenticator_scram(self): + def authenticator_scram(self) -> Generator[Tuple[bytes, bool], bytes, None]: client_first = self.first_message().encode("utf-8") server_first = yield client_first, True self.process_server_first_message(server_first.decode("utf-8")) @@ -766,25 +898,22 @@ def authenticator_scram(self): server_final = yield client_final, True self.process_server_final_message(server_final.decode("utf-8")) - def hmac(self, key, msg): + def hmac(self, key: bytes, msg: Buffer) -> bytes: return hmac.new(key, msg, digestmod=self._hashfunc).digest() - def create_salted_password(self, salt, iterations): - self._salted_password = hashlib.pbkdf2_hmac( - self._hashname, self._sasl_plain_password, salt, iterations - ) - @staticmethod - def _xor_bytes(left, right): + def _xor_bytes(left: Iterable[int], right: Iterable[int]) -> bytes: return bytes(lb ^ rb for lb, rb in zip(left, right)) class OAuthAuthenticator(BaseSaslAuthenticator): - def __init__(self, *, sasl_oauth_token_provider): + def __init__(self, *, sasl_oauth_token_provider: AbstractTokenProvider) -> None: self._sasl_oauth_token_provider = sasl_oauth_token_provider self._token_sent = False - async def step(self, payload): + async def step( + self, payload: Optional[bytes] + ) -> Optional[Tuple[bytes, Literal[True]]]: if self._token_sent: return None token = await self._sasl_oauth_token_provider.token() @@ -795,10 +924,10 @@ async def step(self, payload): True, ) - def _build_oauth_client_request(self, token, token_extensions): + def _build_oauth_client_request(self, token: str, token_extensions: str) -> str: return f"n,,\x01auth=Bearer {token}{token_extensions}\x01\x01" - def _token_extensions(self): + def _token_extensions(self) -> str: """ Return a string representation of the OPTIONAL key-value pairs that can be sent with an OAUTHBEARER initial request. @@ -815,7 +944,7 @@ def _token_extensions(self): return "" -def _address_family(address): +def _address_family(address: str) -> socket.AddressFamily: """ Attempt to determine the family of an address (or hostname) @@ -834,7 +963,7 @@ def _address_family(address): return socket.AF_UNSPEC -def get_ip_port_afi(host_and_port_str): +def get_ip_port_afi(host_and_port_str: str) -> Tuple[str, int, socket.AddressFamily]: """ Parse the IP and port from a string in the format of: @@ -879,14 +1008,16 @@ def get_ip_port_afi(host_and_port_str): pass else: return host_and_port_str, DEFAULT_KAFKA_PORT, socket.AF_INET6 - host, port = host_and_port_str.rsplit(":", 1) - port = int(port) + host, port_str = host_and_port_str.rsplit(":", 1) + port = int(port_str) af = _address_family(host) return host, port, af -def collect_hosts(hosts, randomize=True): +def collect_hosts( + hosts: Union[str, Iterable[str]], randomize: bool = True +) -> List[Tuple[str, int, socket.AddressFamily]]: """ Collects a comma-separated set of hosts (host:port) and optionally randomize the returned list. diff --git a/aiokafka/protocol/admin.py b/aiokafka/protocol/admin.py index 2f374286..b48ebe80 100644 --- a/aiokafka/protocol/admin.py +++ b/aiokafka/protocol/admin.py @@ -1,4 +1,4 @@ -from typing import Dict, Iterable, Optional, Tuple +from typing import Dict, Iterable, Optional, Sequence, Tuple from .api import Request, Response from .types import ( @@ -29,6 +29,9 @@ class ApiVersionResponse_v0(Response): ), ) + error_code: int + api_versions: Sequence[Tuple[int, int, int]] + class ApiVersionResponse_v1(Response): API_KEY = 18 @@ -42,39 +45,47 @@ class ApiVersionResponse_v1(Response): ("throttle_time_ms", Int32), ) + error_code: int + api_versions: Sequence[Tuple[int, int, int]] + throttle_time_ms: int + class ApiVersionResponse_v2(Response): API_KEY = 18 API_VERSION = 2 SCHEMA = ApiVersionResponse_v1.SCHEMA + error_code: int + api_versions: Sequence[Tuple[int, int, int]] + throttle_time_ms: int + -class ApiVersionRequest_v0(Request): +class ApiVersionRequest_v0(Request[ApiVersionResponse_v0]): API_KEY = 18 API_VERSION = 0 RESPONSE_TYPE = ApiVersionResponse_v0 SCHEMA = Schema() -class ApiVersionRequest_v1(Request): +class ApiVersionRequest_v1(Request[ApiVersionResponse_v1]): API_KEY = 18 API_VERSION = 1 RESPONSE_TYPE = ApiVersionResponse_v1 SCHEMA = ApiVersionRequest_v0.SCHEMA -class ApiVersionRequest_v2(Request): +class ApiVersionRequest_v2(Request[ApiVersionResponse_v1]): API_KEY = 18 API_VERSION = 2 - RESPONSE_TYPE = ApiVersionResponse_v1 + RESPONSE_TYPE = ApiVersionResponse_v1 # TODO: Why v1? SCHEMA = ApiVersionRequest_v0.SCHEMA -ApiVersionRequest = [ +ApiVersionRequest = ( ApiVersionRequest_v0, ApiVersionRequest_v1, ApiVersionRequest_v2, -] +) ApiVersionResponse = [ ApiVersionResponse_v0, ApiVersionResponse_v1, @@ -488,28 +499,34 @@ class SaslHandShakeResponse_v0(Response): ("error_code", Int16), ("enabled_mechanisms", Array(String("utf-8"))) ) + error_code: int + enabled_mechanisms: Sequence[str] + class SaslHandShakeResponse_v1(Response): API_KEY = 17 API_VERSION = 1 SCHEMA = SaslHandShakeResponse_v0.SCHEMA + error_code: int + enabled_mechanisms: Sequence[str] + -class SaslHandShakeRequest_v0(Request): +class SaslHandShakeRequest_v0(Request[SaslHandShakeResponse_v0]): API_KEY = 17 API_VERSION = 0 RESPONSE_TYPE = SaslHandShakeResponse_v0 SCHEMA = Schema(("mechanism", String("utf-8"))) -class SaslHandShakeRequest_v1(Request): +class SaslHandShakeRequest_v1(Request[SaslHandShakeResponse_v1]): API_KEY = 17 API_VERSION = 1 RESPONSE_TYPE = SaslHandShakeResponse_v1 SCHEMA = SaslHandShakeRequest_v0.SCHEMA -SaslHandShakeRequest = [SaslHandShakeRequest_v0, SaslHandShakeRequest_v1] +SaslHandShakeRequest = (SaslHandShakeRequest_v0, SaslHandShakeRequest_v1) SaslHandShakeResponse = [SaslHandShakeResponse_v0, SaslHandShakeResponse_v1] @@ -992,6 +1009,10 @@ class SaslAuthenticateResponse_v0(Response): ("sasl_auth_bytes", Bytes), ) + error_code: int + error_message: str + sasl_auth_bytes: bytes + class SaslAuthenticateResponse_v1(Response): API_KEY = 36 @@ -1003,25 +1024,30 @@ class SaslAuthenticateResponse_v1(Response): ("session_lifetime_ms", Int64), ) + error_code: int + error_message: str + sasl_auth_bytes: bytes + session_lifetime_ms: int -class SaslAuthenticateRequest_v0(Request): + +class SaslAuthenticateRequest_v0(Request[SaslAuthenticateResponse_v0]): API_KEY = 36 API_VERSION = 0 RESPONSE_TYPE = SaslAuthenticateResponse_v0 SCHEMA = Schema(("sasl_auth_bytes", Bytes)) -class SaslAuthenticateRequest_v1(Request): +class SaslAuthenticateRequest_v1(Request[SaslAuthenticateResponse_v1]): API_KEY = 36 API_VERSION = 1 RESPONSE_TYPE = SaslAuthenticateResponse_v1 SCHEMA = SaslAuthenticateRequest_v0.SCHEMA -SaslAuthenticateRequest = [ +SaslAuthenticateRequest = ( SaslAuthenticateRequest_v0, SaslAuthenticateRequest_v1, -] +) SaslAuthenticateResponse = [ SaslAuthenticateResponse_v0, SaslAuthenticateResponse_v1, diff --git a/aiokafka/protocol/api.py b/aiokafka/protocol/api.py index 1e6ee3b6..c6c5a4ba 100644 --- a/aiokafka/protocol/api.py +++ b/aiokafka/protocol/api.py @@ -2,11 +2,17 @@ import abc from io import BytesIO -from typing import Any, ClassVar, Dict, Optional, Type, Union +from typing import Any, ClassVar, Dict, Generic, Optional, Type, Union + +from typing_extensions import TypeVar from .struct import Struct from .types import Array, Int16, Int32, Schema, String, TaggedFields +ResponseT_co = TypeVar( + "ResponseT_co", bound="Response", default="Response", covariant=True +) + class RequestHeader_v0(Struct): SCHEMA = Schema( @@ -17,7 +23,10 @@ class RequestHeader_v0(Struct): ) def __init__( - self, request: Request, correlation_id: int = 0, client_id: str = "aiokafka" + self, + request: Request[Any], + correlation_id: int = 0, + client_id: str = "aiokafka", ) -> None: super().__init__( request.API_KEY, request.API_VERSION, correlation_id, client_id @@ -36,7 +45,7 @@ class RequestHeader_v1(Struct): def __init__( self, - request: Request, + request: Request[Any], correlation_id: int = 0, client_id: str = "aiokafka", tags: Optional[Dict[int, bytes]] = None, @@ -51,6 +60,8 @@ class ResponseHeader_v0(Struct): ("correlation_id", Int32), ) + correlation_id: int + class ResponseHeader_v1(Struct): SCHEMA = Schema( @@ -58,8 +69,10 @@ class ResponseHeader_v1(Struct): ("tags", TaggedFields), ) + correlation_id: int + -class Request(Struct, metaclass=abc.ABCMeta): +class Request(Struct, Generic[ResponseT_co], metaclass=abc.ABCMeta): FLEXIBLE_VERSION: ClassVar[bool] = False @property @@ -74,7 +87,7 @@ def API_VERSION(self) -> int: @property @abc.abstractmethod - def RESPONSE_TYPE(self) -> Type[Response]: + def RESPONSE_TYPE(self) -> Type[ResponseT_co]: """The Response class associated with the api request""" @property