From 5bdd828c384e7203fd53864cf8855ef50ba2969a Mon Sep 17 00:00:00 2001
From: Jack Leow
Date: Fri, 15 Nov 2024 13:54:14 -0800
Subject: [PATCH] Added type annotations to public classes.

---
 aiokafka/abc.py                          |  10 +-
 aiokafka/cluster.py                      | 111 ++++++++-----
 aiokafka/conn.py                         |   5 +-
 aiokafka/consumer/consumer.py            | 198 ++++++++++++-----------
 aiokafka/producer/message_accumulator.py |   7 +-
 aiokafka/producer/producer.py            | 159 +++++++++---------
 aiokafka/py.typed                        |   0
 aiokafka/structs.py                      |   6 +-
 8 files changed, 278 insertions(+), 218 deletions(-)
 create mode 100644 aiokafka/py.typed

diff --git a/aiokafka/abc.py b/aiokafka/abc.py
index abb9f2167..b059b5b24 100644
--- a/aiokafka/abc.py
+++ b/aiokafka/abc.py
@@ -1,5 +1,7 @@
 import abc
 
+from aiokafka.structs import TopicPartition
+
 
 class ConsumerRebalanceListener(abc.ABC):
     """
@@ -45,7 +47,7 @@ class ConsumerRebalanceListener(abc.ABC):
     """
 
     @abc.abstractmethod
-    def on_partitions_revoked(self, revoked):
+    def on_partitions_revoked(self, revoked: list[TopicPartition]) -> None:
         """
         A coroutine or function the user can implement to provide cleanup
         or custom state save on the start of a rebalance operation.
@@ -65,7 +67,7 @@ def on_partitions_revoked(self, revoked):
         """
 
     @abc.abstractmethod
-    def on_partitions_assigned(self, assigned):
+    def on_partitions_assigned(self, assigned: list[TopicPartition]) -> None:
         """
         A coroutine or function the user can implement to provide load of
         custom consumer state or cache warmup on completion of a successful
@@ -103,7 +105,7 @@ class AbstractTokenProvider(abc.ABC):
     """
 
    @abc.abstractmethod
-    async def token(self):
+    async def token(self) -> str:
         """
         An async callback returning a :class:`str` ID/Access Token to be sent
         to the Kafka client. In case where a synchronous callback is needed,
@@ -122,7 +124,7 @@ def _token(self):
             # The actual synchronous token callback.
         """
 
-    def extensions(self):
+    def extensions(self) -> dict[str, str]:
         """
         This is an OPTIONAL method that may be implemented.
diff --git a/aiokafka/cluster.py b/aiokafka/cluster.py
index 85496ea92..2495df6a0 100644
--- a/aiokafka/cluster.py
+++ b/aiokafka/cluster.py
@@ -1,16 +1,39 @@
+from __future__ import annotations
+
 import collections
 import copy
 import logging
 import threading
 import time
 from concurrent.futures import Future
-from typing import Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, TypedDict, Union
 
 from aiokafka import errors as Errors
 from aiokafka.conn import collect_hosts
+from aiokafka.protocol.commit import (
+    GroupCoordinatorResponse_v0,
+    GroupCoordinatorResponse_v1,
+)
+from aiokafka.protocol.metadata import (
+    MetadataResponse_v0,
+    MetadataResponse_v1,
+    MetadataResponse_v2,
+    MetadataResponse_v3,
+    MetadataResponse_v4,
+    MetadataResponse_v5,
+)
 from aiokafka.structs import BrokerMetadata, PartitionMetadata, TopicPartition
 
+if TYPE_CHECKING:
+    # aiokafka.client imports this module, so a runtime import of
+    # CoordinationType here would be circular; it is only used in annotations.
+    from aiokafka.client import CoordinationType
+
 log = logging.getLogger(__name__)
 
+MetadataResponse = Union[
+    MetadataResponse_v0,
+    MetadataResponse_v1,
+    MetadataResponse_v2,
+    MetadataResponse_v3,
+    MetadataResponse_v4,
+    MetadataResponse_v5,
+]
+GroupCoordinatorResponse = Union[
+    GroupCoordinatorResponse_v0,
+    GroupCoordinatorResponse_v1,
+]
+
+
+class ClusterConfig(TypedDict):
+    retry_backoff_ms: int
+    metadata_max_age_ms: int
+    bootstrap_servers: str | list[str]
+
 
 class ClusterMetadata:
@@ -35,28 +58,28 @@ class ClusterMetadata:
     specified, will default to localhost:9092.
     """
 
-    DEFAULT_CONFIG = {
+    DEFAULT_CONFIG: ClusterConfig = {
         "retry_backoff_ms": 100,
         "metadata_max_age_ms": 300000,
         "bootstrap_servers": [],
     }
 
-    def __init__(self, **configs):
-        self._brokers = {}  # node_id -> BrokerMetadata
-        self._partitions = {}  # topic -> partition -> PartitionMetadata
+    def __init__(self, **configs: int | str | list[str]):
+        self._brokers: dict[str, BrokerMetadata] = {}  # node_id -> BrokerMetadata
+        # topic -> partition -> PartitionMetadata
+        self._partitions: dict[str, dict[int, PartitionMetadata]] = {}
         # node_id -> {TopicPartition...}
-        self._broker_partitions = collections.defaultdict(set)
-        self._groups = {}  # group_name -> node_id
-        self._last_refresh_ms = 0
-        self._last_successful_refresh_ms = 0
-        self._need_update = True
-        self._future = None
-        self._listeners = set()
-        self._lock = threading.Lock()
-        self.need_all_topic_metadata = False
-        self.unauthorized_topics = set()
-        self.internal_topics = set()
-        self.controller = None
+        self._broker_partitions: dict[int | str, set[TopicPartition]] = (
+            collections.defaultdict(set)
+        )
+        self._groups: dict[str, int | str] = {}  # group_name -> node_id
+        self._last_refresh_ms: int = 0
+        self._last_successful_refresh_ms: int = 0
+        self._need_update: bool = True
+        self._future: Future[ClusterMetadata] | None = None
+        self._listeners: set[Callable[[ClusterMetadata], Any]] = set()
+        self._lock: threading.Lock = threading.Lock()
+        self.need_all_topic_metadata: bool = False
+        self.unauthorized_topics: set[str] = set()
+        self.internal_topics: set[str] = set()
+        self.controller: BrokerMetadata | None = None
 
         self.config = copy.copy(self.DEFAULT_CONFIG)
         for key in self.config:
@@ -64,24 +87,24 @@ def __init__(self, **configs):
             self.config[key] = configs[key]
 
         self._bootstrap_brokers = self._generate_bootstrap_brokers()
-        self._coordinator_brokers = {}
-        self._coordinators = {}
-        self._coordinator_by_key = {}
+        self._coordinator_brokers: dict[str, BrokerMetadata] = {}
+        self._coordinators: dict[int | str, BrokerMetadata] = {}
+        self._coordinator_by_key: dict[tuple[CoordinationType, str], int | str] = {}
 
-    def _generate_bootstrap_brokers(self):
+    def _generate_bootstrap_brokers(self) -> dict[str, BrokerMetadata]:
         # collect_hosts does not perform DNS, so we should be fine to re-use
         bootstrap_hosts = collect_hosts(self.config["bootstrap_servers"])
 
-        brokers = {}
+        brokers: dict[str, BrokerMetadata] = {}
         for i, (host, port, _) in enumerate(bootstrap_hosts):
             node_id = f"bootstrap-{i}"
             brokers[node_id] = BrokerMetadata(node_id, host, port, None)
         return brokers
 
-    def is_bootstrap(self, node_id):
+    def is_bootstrap(self, node_id: str) -> bool:
         return node_id in self._bootstrap_brokers
 
-    def brokers(self):
+    def brokers(self) -> set[BrokerMetadata]:
         """Get all BrokerMetadata
 
         Returns:
@@ -89,11 +112,11 @@ def brokers(self):
         """
         return set(self._brokers.values()) or set(self._bootstrap_brokers.values())
 
-    def broker_metadata(self, broker_id):
+    def broker_metadata(self, broker_id: str) -> BrokerMetadata | None:
         """Get BrokerMetadata
 
         Arguments:
-            broker_id (int): node_id for a broker to check
+            broker_id (str): node_id for a broker to check
 
         Returns:
             BrokerMetadata or None if not found
@@ -117,7 +140,7 @@ def partitions_for_topic(self, topic: str) -> Optional[set[int]]:
             return None
         return set(self._partitions[topic].keys())
 
-    def available_partitions_for_topic(self, topic):
+    def available_partitions_for_topic(self, topic: str) -> set[int] | None:
         """Return set of partitions with known leaders
 
         Arguments:
@@ -135,7 +158,7 @@ def available_partitions_for_topic(self, topic):
             if metadata.leader != -1
         }
 
-    def leader_for_partition(self, partition):
+    def leader_for_partition(self, partition: TopicPartition) -> int | None:
         """Return node_id of leader, -1 unavailable, None if unknown."""
         if partition.topic not in self._partitions:
             return None
@@ -144,7 +167,7 @@ def leader_for_partition(self, partition):
             return None
         return partitions[partition.partition].leader
 
-    def partitions_for_broker(self, broker_id):
+    def partitions_for_broker(
+        self, broker_id: int | str
+    ) -> set[TopicPartition] | None:
         """Return TopicPartitions for which the broker is a leader.
 
         Arguments:
@@ -156,7 +179,7 @@ def partitions_for_broker(self, broker_id):
         """
         return self._broker_partitions.get(broker_id)
 
-    def coordinator_for_group(self, group):
+    def coordinator_for_group(self, group: str) -> int | str | None:
         """Return node_id of group coordinator.
 
         Arguments:
@@ -168,7 +191,7 @@ def coordinator_for_group(self, group):
         """
         return self._groups.get(group)
 
-    def request_update(self):
+    def request_update(self) -> Future[ClusterMetadata]:
         """Flags metadata for update, return Future()
 
         Actual update must be handled separately. This method will only
@@ -179,11 +202,11 @@ def request_update(self):
         """
         with self._lock:
             self._need_update = True
-            if not self._future or self._future.is_done:
+            if not self._future or self._future.done():
                 self._future = Future()
             return self._future
 
-    def topics(self, exclude_internal_topics=True):
+    def topics(self, exclude_internal_topics: bool = True) -> set[str]:
         """Get set of known topics.
 
         Arguments:
@@ -201,7 +224,7 @@ def topics(self, exclude_internal_topics=True):
         else:
             return topics
 
-    def failed_update(self, exception):
+    def failed_update(self, exception: BaseException) -> None:
         """Update cluster state given a failed MetadataRequest."""
         f = None
         with self._lock:
@@ -212,7 +235,7 @@ def failed_update(self, exception):
             f.set_exception(exception)
         self._last_refresh_ms = time.time() * 1000
 
-    def update_metadata(self, metadata):
+    def update_metadata(self, metadata: MetadataResponse) -> None:
         """Update cluster state given a MetadataResponse.
 
         Arguments:
@@ -241,8 +264,8 @@ def update_metadata(self, metadata):
 
         _new_partitions = {}
         _new_broker_partitions = collections.defaultdict(set)
-        _new_unauthorized_topics = set()
-        _new_internal_topics = set()
+        _new_unauthorized_topics: set[str] = set()
+        _new_internal_topics: set[str] = set()
 
         for topic_data in metadata.topics:
             if metadata.API_VERSION == 0:
@@ -320,15 +343,15 @@ def update_metadata(self, metadata):
         # another fetch should be unnecessary.
         self._need_update = False
 
-    def add_listener(self, listener):
+    def add_listener(self, listener: Callable[[ClusterMetadata], Any]) -> None:
         """Add a callback function to be called on each metadata update"""
         self._listeners.add(listener)
 
-    def remove_listener(self, listener):
+    def remove_listener(self, listener: Callable[[ClusterMetadata], Any]) -> None:
         """Remove a previously added listener callback"""
         self._listeners.remove(listener)
 
-    def add_group_coordinator(self, group, response):
+    def add_group_coordinator(
+        self, group: str, response: GroupCoordinatorResponse
+    ) -> str | None:
         """Update with metadata for a group coordinator
 
         Arguments:
@@ -355,7 +378,7 @@ def add_group_coordinator(self, group, response):
             self._groups[group] = node_id
         return node_id
 
-    def with_partitions(self, partitions_to_add):
+    def with_partitions(
+        self, partitions_to_add: Sequence[PartitionMetadata]
+    ) -> ClusterMetadata:
         """Returns a copy of cluster metadata with partitions added"""
         new_metadata = ClusterMetadata(**self.config)
         new_metadata._brokers = copy.deepcopy(self._brokers)
@@ -375,10 +398,10 @@ def with_partitions(self, partitions_to_add):
 
         return new_metadata
 
-    def coordinator_metadata(self, node_id):
+    def coordinator_metadata(self, node_id: int | str) -> BrokerMetadata | None:
         return self._coordinators.get(node_id)
 
-    def add_coordinator(self, node_id, host, port, rack=None, *, purpose):
+    def add_coordinator(
+        self,
+        node_id: int | str,
+        host: str,
+        port: int,
+        rack: str | None = None,
+        *,
+        purpose: tuple[CoordinationType, str],
+    ) -> None:
         """Keep track of all coordinator nodes separately and remove them if
         a new one was elected for the same purpose (For example group
         coordinator for group X).
@@ -390,7 +413,7 @@ def add_coordinator(self, node_id, host, port, rack=None, *, purpose):
         self._coordinators[node_id] = BrokerMetadata(node_id, host, port, rack)
         self._coordinator_by_key[purpose] = node_id
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "ClusterMetadata(brokers: %d, topics: %d, groups: %d)" % (
             len(self._brokers),
             len(self._partitions),
diff --git a/aiokafka/conn.py b/aiokafka/conn.py
index 859f8f245..ccef548b5 100644
--- a/aiokafka/conn.py
+++ b/aiokafka/conn.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import base64
 import collections
@@ -16,6 +18,7 @@
 import warnings
 import weakref
 from enum import IntEnum
+from typing import Literal
 
 import async_timeout
 
@@ -893,7 +896,7 @@ def get_ip_port_afi(host_and_port_str):
     return host, port, af
 
 
-def collect_hosts(hosts, randomize=True):
+def collect_hosts(
+    hosts: str | list[str], randomize: bool = True
+) -> list[tuple[str, int, Literal[0, 2, 10]]]:
     """
     Collects a comma-separated set of hosts (host:port) and optionally
     randomize the returned list.
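The listener and token-provider signatures annotated above are what user code subclasses. As a rough sketch of what now type-checks against them (class names and the token string are illustrative, not part of the patch):

    from aiokafka.abc import AbstractTokenProvider, ConsumerRebalanceListener
    from aiokafka.structs import TopicPartition


    class CommitOnRevokeListener(ConsumerRebalanceListener):
        # Called before a rebalance takes partitions away; a good place to
        # commit offsets or flush state keyed by the revoked partitions.
        async def on_partitions_revoked(self, revoked: list[TopicPartition]) -> None:
            print("revoked:", revoked)

        # Called after a rebalance hands partitions to this consumer.
        async def on_partitions_assigned(self, assigned: list[TopicPartition]) -> None:
            print("assigned:", assigned)


    class StaticTokenProvider(AbstractTokenProvider):
        # With token() annotated as returning str (matching its docstring),
        # implementations like this satisfy both mypy and the runtime.
        async def token(self) -> str:
            return "my-oauth2-token"
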
diff --git a/aiokafka/consumer/consumer.py b/aiokafka/consumer/consumer.py
index 559fa35ad..486a52596 100644
--- a/aiokafka/consumer/consumer.py
+++ b/aiokafka/consumer/consumer.py
@@ -1,13 +1,19 @@
+from __future__ import annotations
+
 import asyncio
 import logging
 import re
 import sys
 import traceback
 import warnings
+from ssl import SSLContext
+from types import ModuleType, TracebackType
+from typing import Callable, Generic, Literal, TypeVar
 
 from aiokafka import __version__
-from aiokafka.abc import ConsumerRebalanceListener
+from aiokafka.abc import AbstractTokenProvider, ConsumerRebalanceListener
 from aiokafka.client import AIOKafkaClient
+from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor
 from aiokafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor
 from aiokafka.errors import (
     ConsumerStoppedError,
@@ -16,7 +22,7 @@
     RecordTooLargeError,
     UnsupportedVersionError,
 )
-from aiokafka.structs import ConsumerRecord, TopicPartition
+from aiokafka.structs import (
+    ConsumerRecord,
+    OffsetAndMetadata,
+    OffsetAndTimestamp,
+    TopicPartition,
+)
 from aiokafka.util import commit_structure_validate, get_running_loop
 
 from .fetcher import Fetcher, OffsetResetStrategy
@@ -26,7 +32,12 @@
 log = logging.getLogger(__name__)
 
 
-class AIOKafkaConsumer:
+KT = TypeVar("KT", covariant=True)
+VT = TypeVar("VT", covariant=True)
+ET = TypeVar("ET", bound=BaseException)
+
+
+class AIOKafkaConsumer(Generic[KT, VT]):
     """
     A client that consumes records from a Kafka cluster.
 
@@ -42,7 +53,7 @@ class AIOKafkaConsumer:
     https://cwiki.apache.org/confluence/display/KAFKA/KIP-62%3A+Allow+consumer+to+send+heartbeats+from+a+background+thread
 
     Arguments:
-        *topics (list(str)): optional list of topics to subscribe to. If not set,
+        *topics (str): optional topics to subscribe to. If not set,
             call :meth:`.subscribe` or :meth:`.assign` before consuming records.
             Passing topics directly is same as calling :meth:`.subscribe` API.
         bootstrap_servers (str, list(str)): a ``host[:port]`` string (or list of
@@ -94,8 +105,9 @@ class AIOKafkaConsumer:
             send messages larger than the consumer can fetch. If that happens, the
             consumer can get stuck trying to fetch a large message on a certain
             partition. Default: 1048576.
-        max_poll_records (int): The maximum number of records returned in a
-            single call to :meth:`.getmany`. Defaults ``None``, no limit.
+        max_poll_records (int or None): The maximum number of records
+            returned in a single call to :meth:`.getmany`.
+            Defaults to ``None``, no limit.
         request_timeout_ms (int): Client request timeout in milliseconds.
             Default: 40000.
         retry_backoff_ms (int): Milliseconds to backoff when retrying on
@@ -117,7 +129,7 @@ class AIOKafkaConsumer:
             which we force a refresh of metadata even if we haven't seen any
             partition leadership changes to proactively discover any new brokers
             or partitions. Default: 300000
-        partition_assignment_strategy (list): List of objects to use to
+        partition_assignment_strategy (list or tuple): List of objects to use to
             distribute partition ownership amongst consumer instances when
             group management is used. This preference is implicit in the order
             of the strategies in the list. When assignment strategy changes:
@@ -209,11 +221,11 @@ class AIOKafkaConsumer:
             ``PLAIN``, ``GSSAPI``, ``SCRAM-SHA-256``, ``SCRAM-SHA-512``,
             ``OAUTHBEARER``.
             Default: ``PLAIN``
-        sasl_plain_username (str): username for SASL ``PLAIN`` authentication.
+        sasl_plain_username (str or None): username for SASL ``PLAIN`` authentication.
            Default: None
-        sasl_plain_password (str): password for SASL ``PLAIN`` authentication.
+        sasl_plain_password (str or None): password for SASL ``PLAIN`` authentication.
             Default: None
-        sasl_oauth_token_provider (~aiokafka.abc.AbstractTokenProvider):
+        sasl_oauth_token_provider (~aiokafka.abc.AbstractTokenProvider or None):
             OAuthBearer token provider instance. Default: None
 
@@ -228,45 +240,45 @@ class AIOKafkaConsumer:
 
     def __init__(
         self,
-        *topics,
-        loop=None,
-        bootstrap_servers="localhost",
-        client_id="aiokafka-" + __version__,
-        group_id=None,
-        group_instance_id=None,
-        key_deserializer=None,
-        value_deserializer=None,
-        fetch_max_wait_ms=500,
-        fetch_max_bytes=52428800,
-        fetch_min_bytes=1,
-        max_partition_fetch_bytes=1 * 1024 * 1024,
-        request_timeout_ms=40 * 1000,
-        retry_backoff_ms=100,
-        auto_offset_reset="latest",
-        enable_auto_commit=True,
-        auto_commit_interval_ms=5000,
-        check_crcs=True,
-        metadata_max_age_ms=5 * 60 * 1000,
-        partition_assignment_strategy=(RoundRobinPartitionAssignor,),
-        max_poll_interval_ms=300000,
-        rebalance_timeout_ms=None,
-        session_timeout_ms=10000,
-        heartbeat_interval_ms=3000,
-        consumer_timeout_ms=200,
-        max_poll_records=None,
-        ssl_context=None,
-        security_protocol="PLAINTEXT",
-        api_version="auto",
-        exclude_internal_topics=True,
-        connections_max_idle_ms=540000,
-        isolation_level="read_uncommitted",
-        sasl_mechanism="PLAIN",
-        sasl_plain_password=None,
-        sasl_plain_username=None,
-        sasl_kerberos_service_name="kafka",
-        sasl_kerberos_domain_name=None,
-        sasl_oauth_token_provider=None,
-    ):
+        *topics: str,
+        loop: asyncio.AbstractEventLoop | None = None,
+        bootstrap_servers: str | list[str] = "localhost",
+        client_id: str = "aiokafka-" + __version__,
+        group_id: str | None = None,
+        group_instance_id: str | None = None,
+        key_deserializer: Callable[[bytes], KT] = lambda x: x,
+        value_deserializer: Callable[[bytes], VT] = lambda x: x,
+        fetch_max_wait_ms: int = 500,
+        fetch_max_bytes: int = 52428800,
+        fetch_min_bytes: int = 1,
+        max_partition_fetch_bytes: int = 1 * 1024 * 1024,
+        request_timeout_ms: int = 40 * 1000,
+        retry_backoff_ms: int = 100,
+        auto_offset_reset: Literal["earliest", "latest", "none"] = "latest",
+        enable_auto_commit: bool = True,
+        auto_commit_interval_ms: int = 5000,
+        check_crcs: bool = True,
+        metadata_max_age_ms: int = 5 * 60 * 1000,
+        partition_assignment_strategy: tuple[
+            type[AbstractPartitionAssignor], ...
+        ] = (RoundRobinPartitionAssignor,),
+        max_poll_interval_ms: int = 300000,
+        rebalance_timeout_ms: int | None = None,
+        session_timeout_ms: int = 10000,
+        heartbeat_interval_ms: int = 3000,
+        consumer_timeout_ms: int = 200,
+        max_poll_records: int | None = None,
+        ssl_context: SSLContext | None = None,
+        security_protocol: Literal[
+            "PLAINTEXT", "SSL", "SASL_PLAINTEXT", "SASL_SSL"
+        ] = "PLAINTEXT",
+        api_version: str = "auto",
+        exclude_internal_topics: bool = True,
+        connections_max_idle_ms: int = 540000,
+        isolation_level: Literal[
+            "read_committed", "read_uncommitted"
+        ] = "read_uncommitted",
+        sasl_mechanism: Literal[
+            "PLAIN", "GSSAPI", "SCRAM-SHA-256", "SCRAM-SHA-512", "OAUTHBEARER"
+        ] = "PLAIN",
+        sasl_plain_password: str | None = None,
+        sasl_plain_username: str | None = None,
+        sasl_kerberos_service_name: str = "kafka",
+        sasl_kerberos_domain_name: str | None = None,
+        sasl_oauth_token_provider: AbstractTokenProvider | None = None,
+    ) -> None:
         if loop is None:
             loop = get_running_loop()
         else:
@@ -338,11 +350,11 @@ def __init__(
         self._closed = False
 
         if topics:
-            topics = self._validate_topics(topics)
-            self._client.set_topics(topics)
-            self._subscription.subscribe(topics=topics)
+            _topics: set[str] = self._validate_topics(topics)
+            self._client.set_topics(_topics)
+            self._subscription.subscribe(topics=_topics)
 
-    def __del__(self, _warnings=warnings):
+    def __del__(self, _warnings: ModuleType = warnings) -> None:
         if self._closed is False:
             _warnings.warn(
                 f"Unclosed AIOKafkaConsumer {self!r}",
@@ -357,7 +369,7 @@ def __del__(self, _warnings=warnings):
                 context["source_traceback"] = self._source_traceback
             self._loop.call_exception_handler(context)
 
-    async def start(self):
+    async def start(self) -> None:
         """Connect to Kafka cluster. This will:
 
         * Load metadata for all cluster nodes and partition allocation
@@ -446,17 +458,17 @@ async def start(self):
             await self._client.force_metadata_update()
             self._coordinator.assign_all_partitions(check_unknown=True)
 
-    async def _wait_topics(self):
+    async def _wait_topics(self) -> None:
         if self._subscription.subscription is not None:
             for topic in self._subscription.subscription.topics:
                 await self._client._wait_on_metadata(topic)
 
-    def _validate_topics(self, topics):
+    def _validate_topics(
+        self, topics: tuple[str, ...] | set[str] | list[str]
+    ) -> set[str]:
         if not isinstance(topics, (tuple, set, list)):
             raise TypeError("Topics should be list of strings")
         return set(topics)
 
-    def assign(self, partitions):
+    def assign(self, partitions: list[TopicPartition]) -> None:
         """Manually assign a list of :class:`.TopicPartition` to this consumer.
 
         This interface does not support incremental assignment and will
@@ -487,7 +499,7 @@ def assign(self, partitions):
             assignment = self._subscription.subscription.assignment
         self._coordinator.start_commit_offsets_refresh_task(assignment)
 
-    def assignment(self):
+    def assignment(self) -> set[TopicPartition]:
         """Get the set of partitions currently assigned to this consumer.
 
         If partitions were directly assigned using :meth:`assign`, then this will
@@ -504,7 +516,7 @@ def assignment(self):
         """
         return self._subscription.assigned_partitions()
 
-    async def stop(self):
+    async def stop(self) -> None:
         """Close the consumer, while waiting for finalizers:
 
         * Commit last consumed message if autocommit enabled
@@ -521,7 +533,7 @@ async def stop(self):
         await self._client.close()
         log.debug("The KafkaConsumer has closed.")
 
-    async def commit(self, offsets=None):
+    async def commit(
+        self,
+        offsets: dict[TopicPartition, int | tuple[int, str] | OffsetAndMetadata]
+        | None = None,
+    ) -> None:
         """Commit offsets to Kafka.
 
         This commits offsets only to Kafka. The offsets committed using this
@@ -595,7 +607,7 @@ async def commit(self, offsets=None):
 
         await self._coordinator.commit_offsets(assignment, offsets)
 
-    async def committed(self, partition):
+    async def committed(self, partition: TopicPartition) -> int | None:
         """Get the last committed offset for the given partition.
         (whether the commit happened by this process or another).
 
@@ -627,7 +639,7 @@ async def committed(self, partition):
             committed = None
         return committed
 
-    async def topics(self):
+    async def topics(self) -> set[str]:
         """Get all topics the user is authorized to view.
 
         Returns:
@@ -636,7 +648,7 @@ async def topics(self):
         cluster = await self._client.fetch_all_metadata()
         return cluster.topics()
 
-    def partitions_for_topic(self, topic):
+    def partitions_for_topic(self, topic: str) -> set[int] | None:
         """Get metadata about the partitions for a given topic.
        This method will return `None` if Consumer does not already have
@@ -650,7 +662,7 @@ def partitions_for_topic(self, topic):
         """
         return self._client.cluster.partitions_for_topic(topic)
 
-    async def position(self, partition):
+    async def position(self, partition: TopicPartition) -> int:
         """Get the offset of the *next record* that will be fetched (if a
         record with that offset exists on broker).
 
@@ -693,7 +705,7 @@ async def position(self, partition):
                 continue
             return tp_state.position
 
-    def highwater(self, partition):
+    # TODO: Return type seems wrong
+    def highwater(self, partition: TopicPartition) -> int | None:
         """Last known highwater offset for a partition.
 
         A highwater offset is the offset that will be assigned to the next
@@ -715,7 +727,7 @@ def highwater(self, partition):
             assignment = self._subscription.subscription.assignment
         return assignment.state_value(partition).highwater
 
-    def last_stable_offset(self, partition):
+    # TODO: Return type seems wrong
+    def last_stable_offset(self, partition: TopicPartition) -> int | None:
         """Returns the Last Stable Offset of a topic. It will be the last
         offset up to which point all transactions were completed. Only
         available in with isolation_level `read_committed`, in
@@ -735,7 +747,7 @@ def last_stable_offset(self, partition):
             assignment = self._subscription.subscription.assignment
         return assignment.state_value(partition).lso
 
-    def last_poll_timestamp(self, partition):
+    # TODO: Return type seems wrong
+    def last_poll_timestamp(self, partition: TopicPartition) -> int | None:
         """Returns the timestamp of the last poll of this partition (in ms).
         It is the last time :meth:`highwater` and :meth:`last_stable_offset` were
         updated. However it does not mean that new messages were received.
@@ -753,7 +765,7 @@ def last_poll_timestamp(self, partition):
             assignment = self._subscription.subscription.assignment
         return assignment.state_value(partition).timestamp
 
-    def seek(self, partition, offset):
+    def seek(self, partition: TopicPartition, offset: int) -> None:
         """Manually specify the fetch offset for a :class:`.TopicPartition`.
 
         Overrides the fetch offsets that the consumer will use on the next
@@ -785,7 +797,7 @@ def seek(self, partition, offset):
         log.debug("Seeking to offset %s for partition %s", offset, partition)
         self._fetcher.seek_to(partition, offset)
 
-    async def seek_to_beginning(self, *partitions):
+    async def seek_to_beginning(self, *partitions: TopicPartition) -> bool:
         """Seek to the oldest available offset for partitions.
 
         Arguments:
@@ -825,7 +837,7 @@ async def seek_to_beginning(self, *partitions):
         self._coordinator.check_errors()
         return fut.done()
 
-    async def seek_to_end(self, *partitions):
+    async def seek_to_end(self, *partitions: TopicPartition) -> bool:
         """Seek to the most recent available offset for partitions.
 
         Arguments:
@@ -862,7 +874,7 @@ async def seek_to_end(self, *partitions):
         self._coordinator.check_errors()
         return fut.done()
 
-    async def seek_to_committed(self, *partitions):
+    async def seek_to_committed(
+        self, *partitions: TopicPartition
+    ) -> dict[TopicPartition, int | None]:
         """Seek to the committed offset for partitions.
 
         Arguments:
@@ -894,7 +906,7 @@ async def seek_to_committed(self, *partitions):
         if not_assigned:
             raise IllegalStateError(f"Partitions {not_assigned} are not assigned")
 
-        committed_offsets = {}
+        committed_offsets: dict[TopicPartition, int | None] = {}
         for tp in partitions:
             offset = await self.committed(tp)
             committed_offsets[tp] = offset
@@ -903,7 +915,7 @@ async def seek_to_committed(self, *partitions):
             self._fetcher.seek_to(tp, offset)
         return committed_offsets
 
-    async def offsets_for_times(self, timestamps):
+    async def offsets_for_times(
+        self, timestamps: dict[TopicPartition, int]
+    ) -> dict[TopicPartition, OffsetAndTimestamp | None]:
         """
         Look up the offsets for the given partitions by timestamp. The returned
         offset for each partition is the earliest offset whose timestamp is
@@ -925,7 +937,7 @@ async def offsets_for_times(self, timestamps):
                 beginning of the epoch (midnight Jan 1, 1970 (UTC))
 
         Returns:
-            dict(TopicPartition, OffsetAndTimestamp): mapping from
+            dict(TopicPartition, OffsetAndTimestamp or None): mapping from
                 partition to the timestamp and offset of the first message with
                 timestamp greater than or equal to the target timestamp.
                 None will be returned for the partition if there is no such message.
@@ -951,12 +963,12 @@ async def offsets_for_times(self, timestamps):
                     f"The target time for partition {tp} is {ts}."
                     " The target time cannot be negative."
                 )
-        offsets = await self._fetcher.get_offsets_by_times(
+        offsets: dict[
+            TopicPartition, OffsetAndTimestamp | None
+        ] = await self._fetcher.get_offsets_by_times(
             timestamps, self._request_timeout_ms
         )
         return offsets
 
-    async def beginning_offsets(self, partitions):
+    async def beginning_offsets(
+        self, partitions: list[TopicPartition]
+    ) -> dict[TopicPartition, int]:
         """Get the first offset for the given partitions.
 
         This method does not change the current consumer position of the
@@ -991,7 +1003,7 @@ async def beginning_offsets(self, partitions):
         )
         return offsets
 
-    async def end_offsets(self, partitions):
+    async def end_offsets(
+        self, partitions: list[TopicPartition]
+    ) -> dict[TopicPartition, int]:
         """Get the last offset for the given partitions. The last offset of a
         partition is the offset of the upcoming message, i.e. the offset of the
         last available message + 1.
@@ -1026,7 +1038,7 @@ async def end_offsets(self, partitions):
         offsets = await self._fetcher.end_offsets(partitions, self._request_timeout_ms)
         return offsets
 
-    def subscribe(self, topics=(), pattern=None, listener=None):
+    def subscribe(
+        self,
+        topics: list[str] | tuple[str, ...] = (),
+        pattern: str | None = None,
+        listener: ConsumerRebalanceListener | None = None,
+    ) -> None:
         """Subscribe to a list of topics, or a topic regex pattern.
 
         Partitions will be dynamically assigned via a group coordinator.
@@ -1036,7 +1048,7 @@ def subscribe(self, topics=(), pattern=None, listener=None):
         This method is incompatible with :meth:`assign`.
 
         Arguments:
-            topics (list): List of topics for subscription.
+            topics (list or tuple): List of topics for subscription.
             pattern (str): Pattern to match available topics. You must provide
                 either topics or pattern, but not both.
             listener (ConsumerRebalanceListener): Optionally include listener
@@ -1100,15 +1112,15 @@ def subscribe(self, topics=(), pattern=None, listener=None):
             self._coordinator._metadata_snapshot = {}
         log.info("Subscribed to topic(s): %s", topics)
 
-    def subscription(self):
+    def subscription(self) -> set[str]:
         """Get the current topics subscription.
 
         Returns:
-            frozenset(str): a set of topics
+            set(str): a set of topics
         """
         return self._subscription.topics
 
-    def unsubscribe(self):
+    def unsubscribe(self) -> None:
         """Unsubscribe from all topics and clear all assigned partitions."""
         self._subscription.unsubscribe()
         if self._group_id is not None:
@@ -1116,7 +1128,7 @@ def unsubscribe(self):
         self._client.set_topics([])
         log.info("Unsubscribed all topics or patterns and assigned partitions")
 
-    async def getone(self, *partitions) -> ConsumerRecord:
+    async def getone(self, *partitions: TopicPartition) -> ConsumerRecord[KT, VT]:
         """
         Get one message from Kafka.
         If no new messages prefetched, this method will wait for it.
@@ -1161,8 +1173,8 @@ async def getone(self, *partitions) -> ConsumerRecord:
         return msg
 
     async def getmany(
-        self, *partitions, timeout_ms=0, max_records=None
-    ) -> dict[TopicPartition, list[ConsumerRecord]]:
+        self,
+        *partitions: TopicPartition,
+        timeout_ms: int = 0,
+        max_records: int | None = None,
+    ) -> dict[TopicPartition, list[ConsumerRecord[KT, VT]]]:
         """Get messages from assigned topics / partitions.
 
         Prefetched messages are returned in batches by topic-partition.
@@ -1215,7 +1227,7 @@ async def getmany(
         )
         return records
 
-    def pause(self, *partitions):
+    def pause(self, *partitions: TopicPartition) -> None:
         """Suspend fetching from the requested partitions.
 
         Future calls to :meth:`.getmany` will not return any records from these
@@ -1235,7 +1247,7 @@ def pause(self, *partitions):
             log.debug("Pausing partition %s", partition)
             self._subscription.pause(partition)
 
-    def paused(self):
+    def paused(self) -> set[TopicPartition]:
         """Get the partitions that were previously paused using
         :meth:`.pause`.
 
@@ -1244,11 +1256,11 @@ def paused(self):
         """
         return self._subscription.paused_partitions()
 
-    def resume(self, *partitions):
+    def resume(self, *partitions: TopicPartition) -> None:
         """Resume fetching from the specified (paused) partitions.
 
         Arguments:
-            *partitions (list[TopicPartition]): Partitions to resume.
+            *partitions (TopicPartition): Partitions to resume.
""" if not all(isinstance(p, TopicPartition) for p in partitions): raise TypeError("partitions must be TopicPartition namedtuples") @@ -1257,12 +1269,12 @@ def resume(self, *partitions): log.debug("Resuming partition %s", partition) self._subscription.resume(partition) - def __aiter__(self): + def __aiter__(self) -> AIOKafkaConsumer[KT, VT]: if self._closed: raise ConsumerStoppedError() return self - async def __anext__(self) -> ConsumerRecord: + async def __anext__(self) -> ConsumerRecord[KT, VT]: """Asyncio iterator interface for consumer Note: @@ -1278,9 +1290,9 @@ async def __anext__(self) -> ConsumerRecord: except RecordTooLargeError: log.exception("error in consumer iterator: %s") - async def __aenter__(self): + async def __aenter__(self) -> AIOKafkaConsumer[KT, VT]: await self.start() return self - async def __aexit__(self, exc_type, exc, tb): + async def __aexit__(self, exc_type: type[ET] | None, exc: ET | None, tb: TracebackType | None) -> None: await self.stop() diff --git a/aiokafka/producer/message_accumulator.py b/aiokafka/producer/message_accumulator.py index 2b2e37853..9b8584ea7 100644 --- a/aiokafka/producer/message_accumulator.py +++ b/aiokafka/producer/message_accumulator.py @@ -3,6 +3,7 @@ import copy import time from collections.abc import Sequence +from typing import Generic, TypeVar from aiokafka.errors import ( KafkaTimeoutError, @@ -16,7 +17,11 @@ from aiokafka.util import create_future, get_running_loop -class BatchBuilder: +KT = TypeVar("KT", contravariant=True) +VT = TypeVar("VT", contravariant=True) + + +class BatchBuilder(Generic[KT, VT]): def __init__( self, magic, diff --git a/aiokafka/producer/producer.py b/aiokafka/producer/producer.py index ab956e81b..7944bbf74 100644 --- a/aiokafka/producer/producer.py +++ b/aiokafka/producer/producer.py @@ -1,9 +1,15 @@ +from __future__ import annotations + import asyncio import logging import sys import traceback import warnings +from ssl import SSLContext +from types import ModuleType, TracebackType +from typing import Callable, Generic, Iterable, Literal, TypeVar +from aiokafka.abc import AbstractTokenProvider from aiokafka.client import AIOKafkaClient from aiokafka.codec import has_gzip, has_lz4, has_snappy, has_zstd from aiokafka.errors import ( @@ -14,7 +20,7 @@ from aiokafka.partitioner import DefaultPartitioner from aiokafka.record.default_records import DefaultRecordBatch from aiokafka.record.legacy_records import LegacyRecordBatchBuilder -from aiokafka.structs import TopicPartition +from aiokafka.structs import OffsetAndMetadata, RecordMetadata, TopicPartition from aiokafka.util import ( INTEGER_MAX_VALUE, commit_structure_validate, @@ -22,7 +28,7 @@ get_running_loop, ) -from .message_accumulator import MessageAccumulator +from .message_accumulator import BatchBuilder, MessageAccumulator from .sender import Sender from .transaction_manager import TransactionManager @@ -30,11 +36,19 @@ _missing = object() +def _identity(data: bytes) -> bytes: + return data + _DEFAULT_PARTITIONER = DefaultPartitioner() -class AIOKafkaProducer: +KT = TypeVar("KT", contravariant=True) +VT = TypeVar("VT", contravariant=True) +ET = TypeVar("ET", bound=BaseException) + + +class AIOKafkaProducer(Generic[KT, VT]): """A Kafka client that publishes records to the Kafka cluster. The producer consists of a pool of buffer space that holds records that @@ -60,17 +74,18 @@ class AIOKafkaProducer: full node list. It just needs to have at least one broker that will respond to a Metadata API Request. Default port is 9092. 
        If no servers are specified, will default to ``localhost:9092``.
-        client_id (str): a name for this client. This string is passed in
+        client_id (str or None): a name for this client. This string is passed in
             each request to servers and can be used to identify specific
             server-side log entries that correspond to this client.
-            Default: ``aiokafka-producer-#`` (appended with a unique number
-            per instance)
-        key_serializer (Callable): used to convert user-supplied keys to bytes
-            If not :data:`None`, called as ``f(key),`` should return
+            If ``None``, ``aiokafka-producer-#`` (appended with a unique number
+            per instance) is used.
+            Default: :data:`None`
+        key_serializer (Callable[[KT], bytes]): used to convert user-supplied keys
+            to bytes. If not :data:`None`, called as ``f(key)``, should return
             :class:`bytes`.
             Default: :data:`None`.
-        value_serializer (Callable): used to convert user-supplied message
-            values to :class:`bytes`. If not :data:`None`, called as
+        value_serializer (Callable[[VT], bytes]): used to convert user-supplied
+            message values to :class:`bytes`. If not :data:`None`, called as
             ``f(value)``, should return :class:`bytes`.
             Default: :data:`None`.
         acks (Any): one of ``0``, ``1``, ``all``. The number of acknowledgments
@@ -197,33 +212,33 @@ class AIOKafkaProducer:
 
     def __init__(
         self,
         *,
-        loop=None,
-        bootstrap_servers="localhost",
-        client_id=None,
-        metadata_max_age_ms=300000,
-        request_timeout_ms=40000,
-        api_version="auto",
-        acks=_missing,
-        key_serializer=None,
-        value_serializer=None,
-        compression_type=None,
-        max_batch_size=16384,
-        partitioner=_DEFAULT_PARTITIONER,
-        max_request_size=1048576,
-        linger_ms=0,
-        retry_backoff_ms=100,
-        security_protocol="PLAINTEXT",
-        ssl_context=None,
-        connections_max_idle_ms=540000,
-        enable_idempotence=False,
-        transactional_id=None,
-        transaction_timeout_ms=60000,
-        sasl_mechanism="PLAIN",
-        sasl_plain_password=None,
-        sasl_plain_username=None,
-        sasl_kerberos_service_name="kafka",
-        sasl_kerberos_domain_name=None,
-        sasl_oauth_token_provider=None,
-    ):
+        loop: asyncio.AbstractEventLoop | None = None,
+        bootstrap_servers: str | list[str] = "localhost",
+        client_id: str | None = None,
+        metadata_max_age_ms: int = 300000,
+        request_timeout_ms: int = 40000,
+        api_version: str = "auto",
+        acks: Literal[0, 1, "all"] | object = _missing,
+        key_serializer: Callable[[KT], bytes] = _identity,
+        value_serializer: Callable[[VT], bytes] = _identity,
+        compression_type: Literal["gzip", "snappy", "lz4", "zstd"] | None = None,
+        max_batch_size: int = 16384,
+        partitioner: Callable[
+            [bytes, list[int], list[int]], int
+        ] = _DEFAULT_PARTITIONER,
+        max_request_size: int = 1048576,
+        linger_ms: int = 0,
+        retry_backoff_ms: int = 100,
+        security_protocol: Literal[
+            "PLAINTEXT", "SSL", "SASL_PLAINTEXT", "SASL_SSL"
+        ] = "PLAINTEXT",
+        ssl_context: SSLContext | None = None,
+        connections_max_idle_ms: int = 540000,
+        enable_idempotence: bool = False,
+        # In theory, this could be any unique object:
+        transactional_id: int | str | None = None,
+        transaction_timeout_ms: int = 60000,
+        sasl_mechanism: Literal[
+            "PLAIN", "GSSAPI", "SCRAM-SHA-256", "SCRAM-SHA-512", "OAUTHBEARER"
+        ] = "PLAIN",
+        sasl_plain_password: str | None = None,
+        sasl_plain_username: str | None = None,
+        sasl_kerberos_service_name: str = "kafka",
+        sasl_kerberos_domain_name: str | None = None,
+        sasl_oauth_token_provider: AbstractTokenProvider | None = None,
+    ) -> None:
         if loop is None:
             loop = get_running_loop()
@@ -328,7 +343,7 @@ def __init__(
 
     # Warn if producer was not closed properly
     # We don't attempt to close the Consumer, as __del__ is synchronous
-    def __del__(self, _warnings=warnings):
+    def __del__(self, _warnings: ModuleType = warnings) -> None:
         if self._closed is False:
             _warnings.warn(
                 f"Unclosed AIOKafkaProducer {self!r}",
@@ -343,7 +358,7 @@ def __del__(self, _warnings=warnings):
                 context["source_traceback"] = self._source_traceback
             self._loop.call_exception_handler(context)
 
-    async def start(self):
+    async def start(self) -> None:
         """Connect to Kafka cluster and check server version"""
         assert (
             self._loop is get_running_loop()
@@ -371,11 +386,11 @@ async def start(self):
         self._producer_magic = 0 if self.client.api_version < (0, 10) else 1
         log.debug("Kafka producer started")
 
-    async def flush(self):
+    async def flush(self) -> None:
         """Wait until all batches are Delivered and futures resolved"""
         await self._message_accumulator.flush()
 
-    async def stop(self):
+    async def stop(self) -> None:
         """Flush all pending data and close all connections to kafka cluster"""
         if self._closed:
             return
@@ -396,11 +411,11 @@ async def stop(self):
         await self.client.close()
         log.debug("The Kafka producer has closed.")
 
-    async def partitions_for(self, topic):
+    async def partitions_for(self, topic: str) -> set[int]:
         """Returns set of all known partitions for the topic."""
         return await self.client._wait_on_metadata(topic)
 
-    def _serialize(self, topic, key, value):
+    def _serialize(self, topic: str, key: KT | None, value: VT | None):
         if self._key_serializer is None:
             serialized_key = key
         else:
@@ -425,8 +440,8 @@ def _serialize(self, topic, key, value):
         return serialized_key, serialized_value
 
     def _partition(
-        self, topic, partition, key, value, serialized_key, serialized_value
-    ):
+        self,
+        topic: str,
+        partition: int | None,
+        key: KT | None,
+        value: VT | None,
+        serialized_key: bytes | None,
+        serialized_value: bytes | None,
+    ) -> int:
         if partition is not None:
             assert partition >= 0
             assert partition in self._metadata.partitions_for_topic(
@@ -440,13 +455,13 @@ def _partition(
 
     async def send(
         self,
-        topic,
-        value=None,
-        key=None,
-        partition=None,
-        timestamp_ms=None,
-        headers=None,
-    ):
+        topic: str,
+        value: VT | None = None,
+        key: KT | None = None,
+        partition: int | None = None,
+        timestamp_ms: int | None = None,
+        headers: Iterable[tuple[str, bytes]] | None = None,
+    ) -> asyncio.Future[RecordMetadata]:
         """Publish a message to a topic.
 
         Arguments:
@@ -534,18 +549,18 @@ async def send(
 
     async def send_and_wait(
         self,
-        topic,
-        value=None,
-        key=None,
-        partition=None,
-        timestamp_ms=None,
-        headers=None,
-    ):
+        topic: str,
+        value: VT | None = None,
+        key: KT | None = None,
+        partition: int | None = None,
+        timestamp_ms: int | None = None,
+        headers: Iterable[tuple[str, bytes]] | None = None,
+    ) -> RecordMetadata:
         """Publish a message to a topic and wait the result"""
         future = await self.send(topic, value, key, partition, timestamp_ms, headers)
         return await future
 
-    def create_batch(self):
+    def create_batch(self) -> BatchBuilder[KT, VT]:
         """Create and return an empty :class:`.BatchBuilder`.
 
         The batch is not queued for send until submission to :meth:`send_batch`.
@@ -557,7 +572,7 @@ def create_batch(self):
             key_serializer=self._key_serializer, value_serializer=self._value_serializer
         )
 
-    async def send_batch(self, batch, topic, *, partition):
+    async def send_batch(
+        self, batch: BatchBuilder[KT, VT], topic: str, *, partition: int
+    ) -> asyncio.Future[RecordMetadata]:
         """Submit a BatchBuilder for publication.
 
         Arguments:
@@ -590,13 +605,13 @@ async def send_batch(self, batch, topic, *, partition):
         )
         return future
 
-    def _ensure_transactional(self):
+    def _ensure_transactional(self) -> None:
         if self._txn_manager is None or self._txn_manager.transactional_id is None:
             raise IllegalOperation(
                 "You need to configure transaction_id to use transactions"
             )
 
-    async def begin_transaction(self):
+    async def begin_transaction(self) -> None:
         self._ensure_transactional()
         log.debug(
             "Beginning a new transaction for id %s", self._txn_manager.transactional_id
@@ -604,7 +619,7 @@ async def begin_transaction(self):
         await asyncio.shield(self._txn_manager.wait_for_pid())
         self._txn_manager.begin_transaction()
 
-    async def commit_transaction(self):
+    async def commit_transaction(self) -> None:
         self._ensure_transactional()
         log.debug(
             "Committing transaction for id %s", self._txn_manager.transactional_id
@@ -614,7 +629,7 @@ async def commit_transaction(self):
             self._txn_manager.wait_for_transaction_end(),
         )
 
-    async def abort_transaction(self):
+    async def abort_transaction(self) -> None:
         self._ensure_transactional()
         log.debug("Aborting transaction for id %s", self._txn_manager.transactional_id)
         self._txn_manager.aborting_transaction()
@@ -622,12 +637,12 @@ async def abort_transaction(self):
             self._txn_manager.wait_for_transaction_end(),
         )
 
-    def transaction(self):
+    def transaction(self) -> TransactionContext[KT, VT]:
         """Start a transaction context"""
 
         return TransactionContext(self)
 
-    async def send_offsets_to_transaction(self, offsets, group_id):
+    async def send_offsets_to_transaction(
+        self,
+        offsets: dict[TopicPartition, int | tuple[int, str] | OffsetAndMetadata],
+        group_id: str,
+    ) -> None:
         self._ensure_transactional()
 
         if not self._txn_manager.is_in_transaction():
@@ -647,23 +662,23 @@ async def send_offsets_to_transaction(self, offsets, group_id):
             fut = self._txn_manager.add_offsets_to_txn(formatted_offsets, group_id)
         await asyncio.shield(fut)
 
-    async def __aenter__(self):
+    async def __aenter__(self) -> AIOKafkaProducer[KT, VT]:
         await self.start()
         return self
 
-    async def __aexit__(self, exc_type, exc, tb):
+    async def __aexit__(
+        self, exc_type: type[ET] | None, exc: ET | None, tb: TracebackType | None
+    ) -> None:
         await self.stop()
 
 
-class TransactionContext:
-    def __init__(self, producer):
+class TransactionContext(Generic[KT, VT]):
+    def __init__(self, producer: AIOKafkaProducer[KT, VT]):
         self._producer = producer
 
-    async def __aenter__(self):
+    async def __aenter__(self) -> TransactionContext[KT, VT]:
         await self._producer.begin_transaction()
         return self
 
-    async def __aexit__(self, exc_type, exc_value, traceback):
+    async def __aexit__(
+        self, exc_type: type[ET] | None, exc: ET | None, tb: TracebackType | None
+    ) -> None:
         if exc_type is not None:
             # If called directly we want the API to raise a InvalidState error,
             # but when exiting a context manager we should just let it out
diff --git a/aiokafka/py.typed b/aiokafka/py.typed
new file mode 100644
index 000000000..e69de29bb
diff --git a/aiokafka/structs.py b/aiokafka/structs.py
index 3db1d20a4..393e79f82 100644
--- a/aiokafka/structs.py
+++ b/aiokafka/structs.py
@@ -27,7 +27,7 @@ class TopicPartition(NamedTuple):
 class BrokerMetadata(NamedTuple):
     """A Kafka broker metadata used by admin tools"""
 
-    nodeId: int
+    nodeId: int | str
     "The Kafka broker id"
 
     host: str
@@ -117,8 +117,8 @@ class RecordMetadata(NamedTuple):
     ""
 
 
-KT = TypeVar("KT")
-VT = TypeVar("VT")
+KT = TypeVar("KT", covariant=True)
+VT = TypeVar("VT", covariant=True)
 
 
 @dataclass
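With AIOKafkaConsumer now generic in its key and value types, the deserializer arguments drive what getone()/getmany() yield. A minimal sketch under this patch's defaults (topic name and JSON payloads are assumed for illustration):

    import asyncio
    import json
    from typing import Any

    from aiokafka import AIOKafkaConsumer


    async def consume() -> None:
        # KT is bytes (the default identity deserializer); VT follows
        # value_deserializer, so iteration yields
        # ConsumerRecord[bytes, dict[str, Any]].
        consumer: AIOKafkaConsumer[bytes, dict[str, Any]] = AIOKafkaConsumer(
            "my-topic",
            bootstrap_servers="localhost:9092",
            value_deserializer=lambda raw: json.loads(raw.decode("utf-8")),
        )
        await consumer.start()
        try:
            async for record in consumer:
                value = record.value  # inferred as dict[str, Any]
                print(record.topic, record.partition, record.offset, value)
        finally:
            await consumer.stop()


    asyncio.run(consume())
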
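The producer mirrors this: the serializers fix KT and VT, so send() and send_and_wait() accept only matching keys and values and resolve to RecordMetadata. Another hedged sketch under the same assumptions:

    import asyncio
    import json
    from typing import Any

    from aiokafka import AIOKafkaProducer


    async def produce() -> None:
        producer: AIOKafkaProducer[str, dict[str, Any]] = AIOKafkaProducer(
            bootstrap_servers="localhost:9092",
            key_serializer=lambda key: key.encode("utf-8"),
            value_serializer=lambda value: json.dumps(value).encode("utf-8"),
        )
        await producer.start()
        try:
            # Passing a non-str key or a non-dict value is now a type error.
            metadata = await producer.send_and_wait(
                "my-topic", {"id": 1}, key="user-1"
            )
            print(metadata.topic, metadata.partition, metadata.offset)
        finally:
            await producer.stop()


    asyncio.run(produce())
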
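create_batch() now advertises BatchBuilder[KT, VT], so batch appends are checked against the same serializer types. A fragment reusing the producer from the sketch above (append's keyword-only signature follows the existing BatchBuilder API):

    async def produce_batch(producer: AIOKafkaProducer[str, dict[str, Any]]) -> None:
        batch = producer.create_batch()
        # append() returns per-record metadata, or None once the batch is
        # full and the record no longer fits.
        batch.append(key="user-1", value={"id": 1}, timestamp=None)
        fut = await producer.send_batch(batch, "my-topic", partition=0)
        print(await fut)
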
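Finally, transaction() returns the TransactionContext annotated above, so the usual transactional pattern (assuming a transactional_id was configured) reads:

    import asyncio

    from aiokafka import AIOKafkaProducer


    async def produce_transactionally() -> None:
        producer = AIOKafkaProducer(
            bootstrap_servers="localhost:9092",
            transactional_id="my-transactional-id",
        )
        await producer.start()
        try:
            # The async context manager commits on clean exit and aborts
            # if the block raises.
            async with producer.transaction():
                await producer.send_and_wait("my-topic", b"value-in-transaction")
        finally:
            await producer.stop()


    asyncio.run(produce_transactionally())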