From 4cba502b3827cbfdf64773d0b060a980ea32a4e3 Mon Sep 17 00:00:00 2001 From: Dmitriy Date: Sat, 29 Jun 2024 17:29:27 +0500 Subject: [PATCH] add typing to aiokafka/coordinator/* (#1006) * add typing to aiokafka/record/* * add some annotations to tests/record * fix almost all errors * test w/o protocols * Revert "test w/o protocols" This reverts commit 7fa1efa9f65a4cfaf889302a87426c8099711360. * use TypeIs * use dataclass * remove timestamp/timestamp_type from cython DefaultRecord * sync cython stubs with code * simplify types * add typing to aiokafka/coordinator/* * fix review * fix format * fix review * fix type errors * fix review * fix review * assert consumer is not None * fix review (continue if consumer is None) --- Makefile | 2 + aiokafka/cluster.py | 3 +- aiokafka/coordinator/assignors/abstract.py | 27 ++- aiokafka/coordinator/assignors/range.py | 30 ++- aiokafka/coordinator/assignors/roundrobin.py | 26 ++- .../assignors/sticky/partition_movements.py | 49 ++-- .../assignors/sticky/sorted_set.py | 58 +++-- .../assignors/sticky/sticky_assignor.py | 209 ++++++++++++------ aiokafka/coordinator/protocol.py | 16 +- tests/coordinator/test_assignors.py | 135 +++++++---- tests/coordinator/test_partition_movements.py | 6 +- 11 files changed, 382 insertions(+), 179 deletions(-) diff --git a/Makefile b/Makefile index 36714cbf..f0214347 100644 --- a/Makefile +++ b/Makefile @@ -7,6 +7,7 @@ DOCKER_IMAGE=aiolibs/kafka:$(SCALA_VERSION)_$(KAFKA_VERSION) DIFF_BRANCH=origin/master FORMATTED_AREAS=\ aiokafka/codec.py \ + aiokafka/coordinator/ \ aiokafka/errors.py \ aiokafka/helpers.py \ aiokafka/structs.py \ @@ -17,6 +18,7 @@ FORMATTED_AREAS=\ tests/test_helpers.py \ tests/test_protocol.py \ tests/test_protocol_object_conversion.py \ + tests/coordinator/ \ tests/record/ .PHONY: setup diff --git a/aiokafka/cluster.py b/aiokafka/cluster.py index 724bfd65..23e688bb 100644 --- a/aiokafka/cluster.py +++ b/aiokafka/cluster.py @@ -4,6 +4,7 @@ import threading import time from 
concurrent.futures import Future +from typing import Optional, Set from aiokafka import errors as Errors from aiokafka.conn import collect_hosts @@ -103,7 +104,7 @@ def broker_metadata(self, broker_id): or self._coordinator_brokers.get(broker_id) ) - def partitions_for_topic(self, topic): + def partitions_for_topic(self, topic: str) -> Optional[Set[int]]: """Return set of all partitions for topic (whether available or not) Arguments: diff --git a/aiokafka/coordinator/assignors/abstract.py b/aiokafka/coordinator/assignors/abstract.py index 329f2fc4..0946afb9 100644 --- a/aiokafka/coordinator/assignors/abstract.py +++ b/aiokafka/coordinator/assignors/abstract.py @@ -1,20 +1,33 @@ import abc import logging +from typing import Dict, Iterable, Mapping + +from aiokafka.cluster import ClusterMetadata +from aiokafka.coordinator.protocol import ( + ConsumerProtocolMemberAssignment, + ConsumerProtocolMemberMetadata, +) log = logging.getLogger(__name__) -class AbstractPartitionAssignor: +class AbstractPartitionAssignor(abc.ABC): """Abstract assignor implementation which does some common grunt work (in particular collecting partition counts which are always needed in assignors). """ - @abc.abstractproperty - def name(self): + @property + @abc.abstractmethod + def name(self) -> str: """.name should be a string identifying the assignor""" + @classmethod @abc.abstractmethod - def assign(self, cluster, members): + def assign( + cls, + cluster: ClusterMetadata, + members: Mapping[str, ConsumerProtocolMemberMetadata], + ) -> Dict[str, ConsumerProtocolMemberAssignment]: """Perform group assignment given cluster metadata and member subscriptions Arguments: @@ -26,8 +39,9 @@ def assign(self, cluster, members): dict: {member_id: MemberAssignment} """ + @classmethod @abc.abstractmethod - def metadata(self, topics): + def metadata(cls, topics: Iterable[str]) -> ConsumerProtocolMemberMetadata: """Generate ProtocolMetadata to be submitted via JoinGroupRequest. 
Arguments: @@ -37,8 +51,9 @@ def metadata(self, topics): MemberMetadata struct """ + @classmethod @abc.abstractmethod - def on_assignment(self, assignment): + def on_assignment(cls, assignment: ConsumerProtocolMemberAssignment) -> None: """Callback that runs on each assignment. This method can be used to update internal state, if any, of the diff --git a/aiokafka/coordinator/assignors/range.py b/aiokafka/coordinator/assignors/range.py index 101b1387..7d398f9e 100644 --- a/aiokafka/coordinator/assignors/range.py +++ b/aiokafka/coordinator/assignors/range.py @@ -1,6 +1,8 @@ import collections import logging +from typing import Dict, Iterable, List, Mapping +from aiokafka.cluster import ClusterMetadata from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor from aiokafka.coordinator.protocol import ( ConsumerProtocolMemberAssignment, @@ -32,25 +34,29 @@ class RangePartitionAssignor(AbstractPartitionAssignor): version = 0 @classmethod - def assign(cls, cluster, member_metadata): - consumers_per_topic = collections.defaultdict(list) - for member, metadata in member_metadata.items(): + def assign( + cls, + cluster: ClusterMetadata, + members: Mapping[str, ConsumerProtocolMemberMetadata], + ) -> Dict[str, ConsumerProtocolMemberAssignment]: + consumers_per_topic: Dict[str, List[str]] = collections.defaultdict(list) + for member, metadata in members.items(): for topic in metadata.subscription: consumers_per_topic[topic].append(member) # construct {member_id: {topic: [partition, ...]}} - assignment = collections.defaultdict(dict) + assignment: Dict[str, Dict[str, List[int]]] = collections.defaultdict(dict) for topic, consumers_for_topic in consumers_per_topic.items(): partitions = cluster.partitions_for_topic(topic) if partitions is None: log.warning("No partition metadata for topic %s", topic) continue - partitions = sorted(partitions) + partitions_list = sorted(partitions) consumers_for_topic.sort() - partitions_per_consumer = len(partitions) // 
len(consumers_for_topic) - consumers_with_extra = len(partitions) % len(consumers_for_topic) + partitions_per_consumer = len(partitions_list) // len(consumers_for_topic) + consumers_with_extra = len(partitions_list) % len(consumers_for_topic) for i, member in enumerate(consumers_for_topic): start = partitions_per_consumer * i @@ -58,19 +64,19 @@ def assign(cls, cluster, member_metadata): length = partitions_per_consumer if not i + 1 > consumers_with_extra: length += 1 - assignment[member][topic] = partitions[start : start + length] + assignment[member][topic] = partitions_list[start : start + length] - protocol_assignment = {} - for member_id in member_metadata: + protocol_assignment: Dict[str, ConsumerProtocolMemberAssignment] = {} + for member_id in members: protocol_assignment[member_id] = ConsumerProtocolMemberAssignment( cls.version, sorted(assignment[member_id].items()), b"" ) return protocol_assignment @classmethod - def metadata(cls, topics): + def metadata(cls, topics: Iterable[str]) -> ConsumerProtocolMemberMetadata: return ConsumerProtocolMemberMetadata(cls.version, list(topics), b"") @classmethod - def on_assignment(cls, assignment): + def on_assignment(cls, assignment: ConsumerProtocolMemberAssignment) -> None: pass diff --git a/aiokafka/coordinator/assignors/roundrobin.py b/aiokafka/coordinator/assignors/roundrobin.py index 4d6a464f..0399b199 100644 --- a/aiokafka/coordinator/assignors/roundrobin.py +++ b/aiokafka/coordinator/assignors/roundrobin.py @@ -1,7 +1,9 @@ import collections import itertools import logging +from typing import Dict, Iterable, List, Mapping +from aiokafka.cluster import ClusterMetadata from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor from aiokafka.coordinator.protocol import ( ConsumerProtocolMemberAssignment, @@ -49,12 +51,16 @@ class RoundRobinPartitionAssignor(AbstractPartitionAssignor): version = 0 @classmethod - def assign(cls, cluster, member_metadata): + def assign( + cls, + cluster: 
ClusterMetadata, + members: Mapping[str, ConsumerProtocolMemberMetadata], + ) -> Dict[str, ConsumerProtocolMemberAssignment]: all_topics = set() - for metadata in member_metadata.values(): + for metadata in members.values(): all_topics.update(metadata.subscription) - all_topic_partitions = [] + all_topic_partitions: List[TopicPartition] = [] for topic in all_topics: partitions = cluster.partitions_for_topic(topic) if partitions is None: @@ -66,9 +72,11 @@ def assign(cls, cluster, member_metadata): all_topic_partitions.sort() # construct {member_id: {topic: [partition, ...]}} - assignment = collections.defaultdict(lambda: collections.defaultdict(list)) + assignment: Dict[str, Dict[str, List[int]]] = collections.defaultdict( + lambda: collections.defaultdict(list) + ) - member_iter = itertools.cycle(sorted(member_metadata.keys())) + member_iter = itertools.cycle(sorted(members.keys())) for partition in all_topic_partitions: member_id = next(member_iter) @@ -76,21 +84,21 @@ def assign(cls, cluster, member_metadata): # member subscribed topics, we should be safe assuming that # each topic in all_topic_partitions is in at least one member # subscription; otherwise this could yield an infinite loop - while partition.topic not in member_metadata[member_id].subscription: + while partition.topic not in members[member_id].subscription: member_id = next(member_iter) assignment[member_id][partition.topic].append(partition.partition) protocol_assignment = {} - for member_id in member_metadata: + for member_id in members: protocol_assignment[member_id] = ConsumerProtocolMemberAssignment( cls.version, sorted(assignment[member_id].items()), b"" ) return protocol_assignment @classmethod - def metadata(cls, topics): + def metadata(cls, topics: Iterable[str]) -> ConsumerProtocolMemberMetadata: return ConsumerProtocolMemberMetadata(cls.version, list(topics), b"") @classmethod - def on_assignment(cls, assignment): + def on_assignment(cls, assignment: ConsumerProtocolMemberAssignment) 
-> None: pass diff --git a/aiokafka/coordinator/assignors/sticky/partition_movements.py b/aiokafka/coordinator/assignors/sticky/partition_movements.py index d5858ddd..e9531113 100644 --- a/aiokafka/coordinator/assignors/sticky/partition_movements.py +++ b/aiokafka/coordinator/assignors/sticky/partition_movements.py @@ -1,11 +1,18 @@ import logging -from collections import defaultdict, namedtuple +from collections import defaultdict from copy import deepcopy +from typing import Any, Dict, List, NamedTuple, Sequence, Set, Tuple + +from aiokafka.structs import TopicPartition log = logging.getLogger(__name__) -ConsumerPair = namedtuple("ConsumerPair", ["src_member_id", "dst_member_id"]) +class ConsumerPair(NamedTuple): + src_member_id: str + dst_member_id: str + + """ Represents a pair of Kafka consumer ids involved in a partition reassignment. Each ConsumerPair corresponds to a particular partition or topic, indicates that the @@ -16,7 +23,7 @@ """ -def is_sublist(source, target): +def is_sublist(source: Sequence[Any], target: Sequence[Any]) -> bool: """Checks if one list is a sublist of another. Arguments: @@ -40,11 +47,13 @@ class PartitionMovements: form a ConsumerPair object) for each partition. 
""" - def __init__(self): - self.partition_movements_by_topic = defaultdict(lambda: defaultdict(set)) - self.partition_movements = {} + def __init__(self) -> None: + self.partition_movements_by_topic: Dict[str, Dict[ConsumerPair, Set[TopicPartition]]] = defaultdict(lambda: defaultdict(set)) # fmt: skip # noqa: E501 + self.partition_movements: Dict[TopicPartition, ConsumerPair] = {} - def move_partition(self, partition, old_consumer, new_consumer): + def move_partition( + self, partition: TopicPartition, old_consumer: str, new_consumer: str + ) -> None: pair = ConsumerPair(src_member_id=old_consumer, dst_member_id=new_consumer) if partition in self.partition_movements: # this partition has previously moved @@ -62,7 +71,9 @@ def move_partition(self, partition, old_consumer, new_consumer): else: self._add_partition_movement_record(partition, pair) - def get_partition_to_be_moved(self, partition, old_consumer, new_consumer): + def get_partition_to_be_moved( + self, partition: TopicPartition, old_consumer: str, new_consumer: str + ) -> TopicPartition: if partition.topic not in self.partition_movements_by_topic: return partition if partition in self.partition_movements: @@ -79,7 +90,7 @@ def get_partition_to_be_moved(self, partition, old_consumer, new_consumer): iter(self.partition_movements_by_topic[partition.topic][reverse_pair]) ) - def are_sticky(self): + def are_sticky(self) -> bool: for topic, movements in self.partition_movements_by_topic.items(): movement_pairs = set(movements.keys()) if self._has_cycles(movement_pairs): @@ -93,7 +104,9 @@ def are_sticky(self): return False return True - def _remove_movement_record_of_partition(self, partition): + def _remove_movement_record_of_partition( + self, partition: TopicPartition + ) -> ConsumerPair: pair = self.partition_movements[partition] del self.partition_movements[partition] @@ -105,16 +118,18 @@ def _remove_movement_record_of_partition(self, partition): return pair - def _add_partition_movement_record(self, 
partition, pair): + def _add_partition_movement_record( + self, partition: TopicPartition, pair: ConsumerPair + ) -> None: self.partition_movements[partition] = pair self.partition_movements_by_topic[partition.topic][pair].add(partition) - def _has_cycles(self, consumer_pairs): - cycles = set() + def _has_cycles(self, consumer_pairs: Set[ConsumerPair]) -> bool: + cycles: Set[Tuple[str, ...]] = set() for pair in consumer_pairs: reduced_pairs = deepcopy(consumer_pairs) reduced_pairs.remove(pair) - path = [pair.src_member_id] + path: List[str] = [pair.src_member_id] if self._is_linked( pair.dst_member_id, pair.src_member_id, reduced_pairs, path ) and not self._is_subcycle(path, cycles): @@ -132,7 +147,7 @@ def _has_cycles(self, consumer_pairs): ) @staticmethod - def _is_subcycle(cycle, cycles): + def _is_subcycle(cycle: List[str], cycles: Set[Tuple[str, ...]]) -> bool: super_cycle = deepcopy(cycle) super_cycle = super_cycle[:-1] super_cycle.extend(cycle) @@ -141,7 +156,9 @@ def _is_subcycle(cycle, cycles): return True return False - def _is_linked(self, src, dst, pairs, current_path): + def _is_linked( + self, src: str, dst: str, pairs: Set[ConsumerPair], current_path: List[str] + ) -> bool: if src == dst: return False if not pairs: diff --git a/aiokafka/coordinator/assignors/sticky/sorted_set.py b/aiokafka/coordinator/assignors/sticky/sorted_set.py index 7903f6ca..8ffc53bf 100644 --- a/aiokafka/coordinator/assignors/sticky/sorted_set.py +++ b/aiokafka/coordinator/assignors/sticky/sorted_set.py @@ -1,12 +1,33 @@ -class SortedSet: - def __init__(self, iterable=None, key=None): - self._key = key if key is not None else lambda x: x - self._set = set(iterable) if iterable is not None else set() +from typing import ( + Any, + Callable, + Collection, + Generic, + Iterable, + Iterator, + Optional, + Set, + TypeVar, + final, +) - self._cached_last = None - self._cached_first = None +T = TypeVar("T") + + +@final +class SortedSet(Generic[T], Collection[T]): + def __init__( + 
self, + iterable: Optional[Iterable[T]] = None, + key: Optional[Callable[[T], Any]] = None, + ) -> None: + self._key: Callable[[T], Any] = key if key is not None else lambda x: x + self._set: Set[T] = set(iterable) if iterable is not None else set() + + self._cached_last: Optional[T] = None + self._cached_first: Optional[T] = None - def first(self): + def first(self) -> Optional[T]: if self._cached_first is not None: return self._cached_first @@ -17,7 +38,7 @@ def first(self): self._cached_first = first return first - def last(self): + def last(self) -> Optional[T]: if self._cached_last is not None: return self._cached_last @@ -28,13 +49,17 @@ def last(self): self._cached_last = last return last - def pop_last(self): + def pop_last(self) -> T: value = self.last() + + if value is None: + raise KeyError + self._set.remove(value) self._cached_last = None return value - def add(self, value): + def add(self, value: T) -> None: if self._cached_last is not None and self._key(value) > self._key( self._cached_last ): @@ -46,7 +71,7 @@ def add(self, value): return self._set.add(value) - def remove(self, value): + def remove(self, value: T) -> None: if self._cached_last is not None and self._cached_last == value: self._cached_last = None if self._cached_first is not None and self._cached_first == value: @@ -54,14 +79,11 @@ def remove(self, value): return self._set.remove(value) - def __contains__(self, value): + def __contains__(self, value: Any) -> bool: return value in self._set - def __iter__(self): + def __iter__(self) -> Iterator[T]: return iter(sorted(self._set, key=self._key)) - def _bool(self): - return len(self._set) != 0 - - __nonzero__ = _bool - __bool__ = _bool + def __len__(self) -> int: + return len(self._set) diff --git a/aiokafka/coordinator/assignors/sticky/sticky_assignor.py b/aiokafka/coordinator/assignors/sticky/sticky_assignor.py index e1941230..c0462ce1 100644 --- a/aiokafka/coordinator/assignors/sticky/sticky_assignor.py +++ 
b/aiokafka/coordinator/assignors/sticky/sticky_assignor.py @@ -1,8 +1,23 @@ import contextlib import logging -from collections import defaultdict, namedtuple +from collections import defaultdict from copy import deepcopy +from typing import ( + Any, + Collection, + Dict, + Iterable, + List, + Mapping, + MutableSequence, + NamedTuple, + Optional, + Sequence, + Sized, + Tuple, +) +from aiokafka.cluster import ClusterMetadata from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor from aiokafka.coordinator.assignors.sticky.partition_movements import PartitionMovements from aiokafka.coordinator.assignors.sticky.sorted_set import SortedSet @@ -17,12 +32,18 @@ log = logging.getLogger(__name__) -ConsumerGenerationPair = namedtuple( - "ConsumerGenerationPair", ["consumer", "generation"] -) +class ConsumerGenerationPair(NamedTuple): + consumer: str + generation: int + + +class ConsumerSubscription(NamedTuple): + consumer: str + partitions: Sequence[TopicPartition] -def has_identical_list_elements(list_): + +def has_identical_list_elements(list_: Sequence[List[Any]]) -> bool: """Checks if all lists in the collection have the same members Arguments: @@ -36,22 +57,25 @@ def has_identical_list_elements(list_): return all(list_[i] == list_[i - 1] for i in range(1, len(list_))) -def subscriptions_comparator_key(element): +def subscriptions_comparator_key(element: Tuple[str, Sized]) -> Tuple[int, str]: return len(element[1]), element[0] -def partitions_comparator_key(element): +def partitions_comparator_key( + element: Tuple[TopicPartition, Sized], +) -> Tuple[int, str, int]: return len(element[1]), element[0].topic, element[0].partition -def remove_if_present(collection, element): +def remove_if_present(collection: MutableSequence[Any], element: Any) -> None: with contextlib.suppress(ValueError, KeyError): collection.remove(element) -StickyAssignorMemberMetadataV1 = namedtuple( - "StickyAssignorMemberMetadataV1", ["subscription", "partitions", "generation"] 
-) +class StickyAssignorMemberMetadataV1(NamedTuple): + subscription: List[str] + partitions: List[TopicPartition] + generation: int class StickyAssignorUserDataV1(Struct): @@ -60,6 +84,13 @@ class StickyAssignorUserDataV1(Struct): list and sending it as user data to the leader during a rebalance """ + class PreviousAssignment(NamedTuple): + topic: str + partitions: List[int] + + previous_assignment: List[PreviousAssignment] + generation: int + SCHEMA = Schema( ( "previous_assignment", @@ -70,31 +101,35 @@ class StickyAssignorUserDataV1(Struct): class StickyAssignmentExecutor: - def __init__(self, cluster, members): + def __init__( + self, + cluster: ClusterMetadata, + members: Dict[str, StickyAssignorMemberMetadataV1], + ) -> None: self.members = members # a mapping between consumers and their assigned partitions that is updated # during assignment procedure - self.current_assignment = defaultdict(list) + self.current_assignment: Dict[str, List[TopicPartition]] = defaultdict(list) # an assignment from a previous generation - self.previous_assignment = {} + self.previous_assignment: Dict[TopicPartition, ConsumerGenerationPair] = {} # a mapping between partitions and their assigned consumers - self.current_partition_consumer = {} + self.current_partition_consumer: Dict[TopicPartition, str] = {} # a flag indicating that there were no previous assignments performed ever self.is_fresh_assignment = False # a mapping of all topic partitions to all consumers that can be assigned to # them - self.partition_to_all_potential_consumers = {} + self.partition_to_all_potential_consumers: Dict[TopicPartition, List[str]] = {} # a mapping of all consumers to all potential topic partitions that can be # assigned to them - self.consumer_to_all_potential_partitions = {} + self.consumer_to_all_potential_partitions: Dict[str, List[TopicPartition]] = {} # an ascending sorted set of consumers based on how many topic partitions are # already assigned to them - 
self.sorted_current_subscriptions = SortedSet() + self.sorted_current_subscriptions: SortedSet[ConsumerSubscription] = SortedSet() # an ascending sorted list of topic partitions based on how many consumers can # potentially use them - self.sorted_partitions = [] + self.sorted_partitions: List[TopicPartition] = [] # all partitions that need to be assigned - self.unassigned_partitions = [] + self.unassigned_partitions: List[TopicPartition] = [] # a flag indicating that a certain partition cannot remain assigned to its # current consumer because the consumer is no longer subscribed to its topic self.revocation_required = False @@ -102,11 +137,11 @@ def __init__(self, cluster, members): self.partition_movements = PartitionMovements() self._initialize(cluster) - def perform_initial_assignment(self): + def perform_initial_assignment(self) -> None: self._populate_sorted_partitions() self._populate_partitions_to_reassign() - def balance(self): + def balance(self) -> None: self._initialize_current_subscriptions() initializing = ( len(self.current_assignment[self._get_consumer_with_most_subscriptions()]) @@ -132,7 +167,7 @@ def balance(self): # narrow down the reassignment scope to only those consumers that are subject to # reassignment - fixed_assignments = {} + fixed_assignments: Dict[str, List[TopicPartition]] = {} for consumer in self.consumer_to_all_potential_partitions: if not self._can_consumer_participate_in_reassignment(consumer): self._remove_consumer_from_current_subscriptions_and_maintain_order( @@ -170,14 +205,14 @@ def balance(self): self.current_assignment[consumer] = partitions self._add_consumer_to_current_subscriptions_and_maintain_order(consumer) - def get_final_assignment(self, member_id): - assignment = defaultdict(list) + def get_final_assignment(self, member_id: str) -> Collection[Tuple[str, List[int]]]: + assignment: Dict[str, List[int]] = defaultdict(list) for topic_partition in self.current_assignment[member_id]: 
assignment[topic_partition.topic].append(topic_partition.partition) assignment = {k: sorted(v) for k, v in assignment.items()} return assignment.items() - def _initialize(self, cluster): + def _initialize(self, cluster: ClusterMetadata) -> None: self._init_current_assignments(self.members) for topic in cluster.topics(): @@ -191,10 +226,11 @@ def _initialize(self, cluster): for consumer_id, member_metadata in self.members.items(): self.consumer_to_all_potential_partitions[consumer_id] = [] for topic in member_metadata.subscription: - if cluster.partitions_for_topic(topic) is None: + partitions_for_topic = cluster.partitions_for_topic(topic) + if partitions_for_topic is None: log.warning("No partition metadata for topic %r", topic) continue - for p in cluster.partitions_for_topic(topic): + for p in partitions_for_topic: partition = TopicPartition(topic=topic, partition=p) self.consumer_to_all_potential_partitions[consumer_id].append( partition @@ -205,14 +241,18 @@ def _initialize(self, cluster): if consumer_id not in self.current_assignment: self.current_assignment[consumer_id] = [] - def _init_current_assignments(self, members): + def _init_current_assignments( + self, members: Dict[str, StickyAssignorMemberMetadataV1] + ) -> None: # we need to process subscriptions' user data with each consumer's reported # generation in mind higher generations overwrite lower generations in case of # a conflict note that a conflict could exists only if user data is for # different generations # for each partition we create a map of its consumers by generation - sorted_partition_consumers_by_generation = {} + sorted_partition_consumers_by_generation: Dict[ + TopicPartition, Dict[int, str] + ] = {} for consumer, member_metadata in members.items(): for partition in member_metadata.partitions: if partition in sorted_partition_consumers_by_generation: @@ -255,7 +295,7 @@ def _init_current_assignments(self, members): for partition in partitions: 
self.current_partition_consumer[partition] = consumer_id - def _are_subscriptions_identical(self): + def _are_subscriptions_identical(self) -> bool: """ Returns: true, if both potential consumers of partitions and potential partitions @@ -269,7 +309,7 @@ def _are_subscriptions_identical(self): list(self.consumer_to_all_potential_partitions.values()) ) - def _populate_sorted_partitions(self): + def _populate_sorted_partitions(self) -> None: # set of topic partitions with their respective potential consumers all_partitions = { (tp, tuple(consumers)) @@ -303,6 +343,7 @@ def _populate_sorted_partitions(self): # at this point, sorted_consumers contains an ascending-sorted list of # consumers based on how many valid partitions are currently assigned to # them + while sorted_consumers: # take the consumer with the most partitions consumer, _ = sorted_consumers.pop_last() @@ -336,10 +377,10 @@ def _populate_sorted_partitions(self): partitions_sorted_by_num_of_potential_consumers.pop(0)[0] ) - def _populate_partitions_to_reassign(self): + def _populate_partitions_to_reassign(self) -> None: self.unassigned_partitions = deepcopy(self.sorted_partitions) - assignments_to_remove = [] + assignments_to_remove: List[str] = [] for consumer_id, partitions in self.current_assignment.items(): if consumer_id not in self.members: # if a consumer that existed before (and had some partition assignments) @@ -373,32 +414,46 @@ def _populate_partitions_to_reassign(self): for consumer_id in assignments_to_remove: del self.current_assignment[consumer_id] - def _initialize_current_subscriptions(self): + def _initialize_current_subscriptions(self) -> None: self.sorted_current_subscriptions = SortedSet( iterable=[ - (consumer, tuple(partitions)) + ConsumerSubscription(consumer=consumer, partitions=tuple(partitions)) for consumer, partitions in self.current_assignment.items() ], key=subscriptions_comparator_key, ) - def _get_consumer_with_least_subscriptions(self): - return 
self.sorted_current_subscriptions.first()[0] + def _get_consumer_with_least_subscriptions(self) -> str: + if current_subscription := self.sorted_current_subscriptions.first(): + return current_subscription[0] + raise ValueError("sorted_current_subscriptions is empty") - def _get_consumer_with_most_subscriptions(self): - return self.sorted_current_subscriptions.last()[0] + def _get_consumer_with_most_subscriptions(self) -> str: + if current_subscription := self.sorted_current_subscriptions.last(): + return current_subscription[0] + raise ValueError("sorted_current_subscriptions is empty") - def _remove_consumer_from_current_subscriptions_and_maintain_order(self, consumer): + def _remove_consumer_from_current_subscriptions_and_maintain_order( + self, consumer: str + ) -> None: self.sorted_current_subscriptions.remove( - (consumer, tuple(self.current_assignment[consumer])) + ConsumerSubscription( + consumer=consumer, + partitions=tuple(self.current_assignment[consumer]), + ) ) - def _add_consumer_to_current_subscriptions_and_maintain_order(self, consumer): + def _add_consumer_to_current_subscriptions_and_maintain_order( + self, consumer: str + ) -> None: self.sorted_current_subscriptions.add( - (consumer, tuple(self.current_assignment[consumer])) + ConsumerSubscription( + consumer=consumer, + partitions=tuple(self.current_assignment[consumer]), + ) ) - def _is_balanced(self): + def _is_balanced(self) -> bool: """Determines if the current assignment is a balanced one""" if ( len(self.current_assignment[self._get_consumer_with_least_subscriptions()]) @@ -443,7 +498,7 @@ def _is_balanced(self): return False return True - def _assign_partition(self, partition): + def _assign_partition(self, partition: TopicPartition) -> None: for consumer, _ in self.sorted_current_subscriptions: if partition in self.consumer_to_all_potential_partitions[consumer]: self._remove_consumer_from_current_subscriptions_and_maintain_order( @@ -454,10 +509,12 @@ def _assign_partition(self, 
partition): self._add_consumer_to_current_subscriptions_and_maintain_order(consumer) break - def _can_partition_participate_in_reassignment(self, partition): + def _can_partition_participate_in_reassignment( + self, partition: TopicPartition + ) -> bool: return len(self.partition_to_all_potential_consumers[partition]) >= 2 - def _can_consumer_participate_in_reassignment(self, consumer): + def _can_consumer_participate_in_reassignment(self, consumer: str) -> bool: current_partitions = self.current_assignment[consumer] current_assignment_size = len(current_partitions) max_assignment_size = len(self.consumer_to_all_potential_partitions[consumer]) @@ -478,7 +535,9 @@ def _can_consumer_participate_in_reassignment(self, consumer): return True return False - def _perform_reassignments(self, reassignable_partitions): + def _perform_reassignments( + self, reassignable_partitions: List[TopicPartition] + ) -> bool: reassignment_performed = False # repeat reassignment until no partition can be moved to improve the balance @@ -502,6 +561,7 @@ def _perform_reassignments(self, reassignable_partitions): log.error( "Expected partition %r to be assigned to a consumer", partition ) + continue if ( partition in self.previous_assignment @@ -539,7 +599,7 @@ def _perform_reassignments(self, reassignable_partitions): break return reassignment_performed - def _reassign_partition(self, partition): + def _reassign_partition(self, partition: TopicPartition) -> None: new_consumer = None for another_consumer, _ in self.sorted_current_subscriptions: if partition in self.consumer_to_all_potential_partitions[another_consumer]: @@ -548,7 +608,9 @@ def _reassign_partition(self, partition): assert new_consumer is not None self._reassign_partition_to_consumer(partition, new_consumer) - def _reassign_partition_to_consumer(self, partition, new_consumer): + def _reassign_partition_to_consumer( + self, partition: TopicPartition, new_consumer: str + ) -> None: consumer = 
self.current_partition_consumer[partition] # find the correct partition movement considering the stickiness requirement partition_to_be_moved = self.partition_movements.get_partition_to_be_moved( @@ -556,7 +618,7 @@ def _reassign_partition_to_consumer(self, partition, new_consumer): ) self._move_partition(partition_to_be_moved, new_consumer) - def _move_partition(self, partition, new_consumer): + def _move_partition(self, partition: TopicPartition, new_consumer: str) -> None: old_consumer = self.current_partition_consumer[partition] self._remove_consumer_from_current_subscriptions_and_maintain_order( old_consumer @@ -575,7 +637,7 @@ def _move_partition(self, partition, new_consumer): self._add_consumer_to_current_subscriptions_and_maintain_order(old_consumer) @staticmethod - def _get_balance_score(assignment): + def _get_balance_score(assignment: Dict[str, List[TopicPartition]]) -> int: """Calculates a balance score of a give assignment as the sum of assigned partitions size difference of all consumer pairs. 
A perfectly balanced assignment (with all consumers getting the same number of @@ -589,7 +651,7 @@ def _get_balance_score(assignment): the balance score of the assignment """ score = 0 - consumer_to_assignment = {} + consumer_to_assignment: Dict[str, int] = {} for consumer_id, partitions in assignment.items(): consumer_to_assignment[consumer_id] = len(partitions) @@ -684,13 +746,17 @@ class StickyPartitionAssignor(AbstractPartitionAssignor): name = "sticky" version = 0 - member_assignment = None - generation = DEFAULT_GENERATION_ID + member_assignment: Optional[List[TopicPartition]] = None + generation: int = DEFAULT_GENERATION_ID - _latest_partition_movements = None + _latest_partition_movements: Optional[PartitionMovements] = None @classmethod - def assign(cls, cluster, members): + def assign( + cls, + cluster: ClusterMetadata, + members: Mapping[str, ConsumerProtocolMemberMetadata], + ) -> Dict[str, ConsumerProtocolMemberAssignment]: """Performs group assignment given cluster metadata and member subscriptions Arguments: @@ -701,7 +767,7 @@ def assign(cls, cluster, members): Returns: dict: {member_id: MemberAssignment} """ - members_metadata = {} + members_metadata: Dict[str, StickyAssignorMemberMetadataV1] = {} for consumer, member_metadata in members.items(): members_metadata[consumer] = cls.parse_member_metadata(member_metadata) @@ -711,15 +777,19 @@ def assign(cls, cluster, members): cls._latest_partition_movements = executor.partition_movements - assignment = {} + assignment: Dict[str, ConsumerProtocolMemberAssignment] = {} for member_id in members: assignment[member_id] = ConsumerProtocolMemberAssignment( - cls.version, sorted(executor.get_final_assignment(member_id)), b"" + cls.version, + sorted(executor.get_final_assignment(member_id)), + b"", ) return assignment @classmethod - def parse_member_metadata(cls, metadata): + def parse_member_metadata( + cls, metadata: ConsumerProtocolMemberMetadata + ) -> StickyAssignorMemberMetadataV1: """ Parses member 
metadata into a python object. This implementation only serializes and deserializes the @@ -752,7 +822,7 @@ def parse_member_metadata(cls, metadata): subscription=metadata.subscription, ) - member_partitions = [] + member_partitions: List[TopicPartition] = [] for ( topic, partitions, @@ -767,11 +837,16 @@ def parse_member_metadata(cls, metadata): ) @classmethod - def metadata(cls, topics): + def metadata(cls, topics: Iterable[str]) -> ConsumerProtocolMemberMetadata: return cls._metadata(topics, cls.member_assignment, cls.generation) @classmethod - def _metadata(cls, topics, member_assignment_partitions, generation=-1): + def _metadata( + cls, + topics: Iterable[str], + member_assignment_partitions: Optional[List[TopicPartition]], + generation: int = -1, + ) -> ConsumerProtocolMemberMetadata: if member_assignment_partitions is None: log.debug("No member assignment available") user_data = b"" @@ -791,7 +866,7 @@ def _metadata(cls, topics, member_assignment_partitions, generation=-1): return ConsumerProtocolMemberMetadata(cls.version, list(topics), user_data) @classmethod - def on_assignment(cls, assignment): + def on_assignment(cls, assignment: ConsumerProtocolMemberAssignment) -> None: """Callback that runs on each assignment. Updates assignor's state. Arguments: @@ -801,7 +876,7 @@ def on_assignment(cls, assignment): cls.member_assignment = assignment.partitions() @classmethod - def on_generation_assignment(cls, generation): + def on_generation_assignment(cls, generation: int) -> None: """Callback that runs on each assignment. Updates assignor's generation id. 
Arguments: diff --git a/aiokafka/coordinator/protocol.py b/aiokafka/coordinator/protocol.py index 1dc79434..afc15e56 100644 --- a/aiokafka/coordinator/protocol.py +++ b/aiokafka/coordinator/protocol.py @@ -1,9 +1,15 @@ +from typing import List, NamedTuple + from aiokafka.protocol.struct import Struct from aiokafka.protocol.types import Array, Bytes, Int16, Int32, Schema, String from aiokafka.structs import TopicPartition class ConsumerProtocolMemberMetadata(Struct): + version: int + subscription: List[str] + user_data: bytes + SCHEMA = Schema( ("version", Int16), ("subscription", Array(String("utf-8"))), @@ -12,13 +18,21 @@ class ConsumerProtocolMemberMetadata(Struct): class ConsumerProtocolMemberAssignment(Struct): + class Assignment(NamedTuple): + topic: str + partitions: List[int] + + version: int + assignment: List[Assignment] + user_data: bytes + SCHEMA = Schema( ("version", Int16), ("assignment", Array(("topic", String("utf-8")), ("partitions", Array(Int32)))), ("user_data", Bytes), ) - def partitions(self): + def partitions(self) -> List[TopicPartition]: return [ TopicPartition(topic, partition) for topic, partitions in self.assignment diff --git a/tests/coordinator/test_assignors.py b/tests/coordinator/test_assignors.py index a7b327c9..86572459 100644 --- a/tests/coordinator/test_assignors.py +++ b/tests/coordinator/test_assignors.py @@ -1,27 +1,36 @@ from collections import defaultdict from random import randint, sample +from typing import Callable, Dict, Generator, Optional, Sequence, Set +from unittest.mock import MagicMock import pytest +from pytest_mock import MockerFixture from aiokafka.coordinator.assignors.range import RangePartitionAssignor from aiokafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor from aiokafka.coordinator.assignors.sticky.sticky_assignor import ( StickyPartitionAssignor, ) -from aiokafka.coordinator.protocol import ConsumerProtocolMemberAssignment +from aiokafka.coordinator.protocol import ( + 
ConsumerProtocolMemberAssignment, + ConsumerProtocolMemberMetadata, +) from aiokafka.structs import TopicPartition @pytest.fixture(autouse=True) -def reset_sticky_assignor(): +def reset_sticky_assignor() -> Generator[None, None, None]: yield StickyPartitionAssignor.member_assignment = None StickyPartitionAssignor.generation = -1 def create_cluster( - mocker, topics, topics_partitions=None, topic_partitions_lambda=None -): + mocker: MockerFixture, + topics: Set[str], + topics_partitions: Optional[Set[int]] = None, + topic_partitions_lambda: Optional[Callable[[str], Optional[Set[int]]]] = None, +) -> MagicMock: cluster = mocker.MagicMock() cluster.topics.return_value = topics if topics_partitions is not None: @@ -31,7 +40,7 @@ def create_cluster( return cluster -def test_assignor_roundrobin(mocker): +def test_assignor_roundrobin(mocker: MockerFixture) -> None: assignor = RoundRobinPartitionAssignor member_metadata = { @@ -55,7 +64,7 @@ def test_assignor_roundrobin(mocker): assert ret[member].encode() == expected[member].encode() -def test_assignor_range(mocker): +def test_assignor_range(mocker: MockerFixture) -> None: assignor = RangePartitionAssignor member_metadata = { @@ -79,7 +88,7 @@ def test_assignor_range(mocker): assert ret[member].encode() == expected[member].encode() -def test_sticky_assignor1(mocker): +def test_sticky_assignor1(mocker: MockerFixture) -> None: """ Given: there are three consumers C0, C1, C2, four topics t0, t1, t2, t3, and each topic has 2 partitions, @@ -147,7 +156,7 @@ def test_sticky_assignor1(mocker): assert_assignment(sticky_assignment, expected_assignment) -def test_sticky_assignor2(mocker): +def test_sticky_assignor2(mocker: MockerFixture) -> None: """ Given: there are three consumers C0, C1, C2, and three topics t0, t1, t2, with 1, 2, and 3 partitions respectively. 
@@ -215,10 +224,10 @@ def test_sticky_assignor2(mocker): assert_assignment(sticky_assignment, expected_assignment) -def test_sticky_one_consumer_no_topic(mocker): - cluster = create_cluster(mocker, topics={}, topics_partitions={}) +def test_sticky_one_consumer_no_topic(mocker: MockerFixture) -> None: + cluster = create_cluster(mocker, topics=set(), topics_partitions=set()) - subscriptions = { + subscriptions: Dict[str, Set[str]] = { "C": set(), } member_metadata = make_member_metadata(subscriptions) @@ -230,8 +239,8 @@ def test_sticky_one_consumer_no_topic(mocker): assert_assignment(sticky_assignment, expected_assignment) -def test_sticky_one_consumer_nonexisting_topic(mocker): - cluster = create_cluster(mocker, topics={}, topics_partitions={}) +def test_sticky_one_consumer_nonexisting_topic(mocker: MockerFixture) -> None: + cluster = create_cluster(mocker, topics=set(), topics_partitions=set()) subscriptions = { "C": {"t"}, @@ -245,7 +254,7 @@ def test_sticky_one_consumer_nonexisting_topic(mocker): assert_assignment(sticky_assignment, expected_assignment) -def test_sticky_one_consumer_one_topic(mocker): +def test_sticky_one_consumer_one_topic(mocker: MockerFixture) -> None: cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1, 2}) subscriptions = { @@ -262,7 +271,9 @@ def test_sticky_one_consumer_one_topic(mocker): assert_assignment(sticky_assignment, expected_assignment) -def test_sticky_should_only_assign_partitions_from_subscribed_topics(mocker): +def test_sticky_should_only_assign_partitions_from_subscribed_topics( + mocker: MockerFixture, +) -> None: cluster = create_cluster( mocker, topics={"t", "other-t"}, topics_partitions={0, 1, 2} ) @@ -281,7 +292,7 @@ def test_sticky_should_only_assign_partitions_from_subscribed_topics(mocker): assert_assignment(sticky_assignment, expected_assignment) -def test_sticky_one_consumer_multiple_topics(mocker): +def test_sticky_one_consumer_multiple_topics(mocker: MockerFixture) -> None: cluster = 
create_cluster(mocker, topics={"t1", "t2"}, topics_partitions={0, 1, 2}) subscriptions = { @@ -298,7 +309,7 @@ def test_sticky_one_consumer_multiple_topics(mocker): assert_assignment(sticky_assignment, expected_assignment) -def test_sticky_two_consumers_one_topic_one_partition(mocker): +def test_sticky_two_consumers_one_topic_one_partition(mocker: MockerFixture) -> None: cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0}) subscriptions = { @@ -319,7 +330,7 @@ def test_sticky_two_consumers_one_topic_one_partition(mocker): assert_assignment(sticky_assignment, expected_assignment) -def test_sticky_two_consumers_one_topic_two_partitions(mocker): +def test_sticky_two_consumers_one_topic_two_partitions(mocker: MockerFixture) -> None: cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1}) subscriptions = { @@ -340,7 +351,9 @@ def test_sticky_two_consumers_one_topic_two_partitions(mocker): assert_assignment(sticky_assignment, expected_assignment) -def test_sticky_multiple_consumers_mixed_topic_subscriptions(mocker): +def test_sticky_multiple_consumers_mixed_topic_subscriptions( + mocker: MockerFixture, +) -> None: partitions = {"t1": {0, 1, 2}, "t2": {0, 1}} cluster = create_cluster( mocker, topics={"t1", "t2"}, topic_partitions_lambda=lambda t: partitions[t] @@ -368,7 +381,7 @@ def test_sticky_multiple_consumers_mixed_topic_subscriptions(mocker): assert_assignment(sticky_assignment, expected_assignment) -def test_sticky_add_remove_consumer_one_topic(mocker): +def test_sticky_add_remove_consumer_one_topic(mocker: MockerFixture) -> None: cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1, 2}) subscriptions = { @@ -411,7 +424,7 @@ def test_sticky_add_remove_consumer_one_topic(mocker): assert len(assignment["C2"].assignment[0][1]) == 3 -def test_sticky_add_remove_topic_two_consumers(mocker): +def test_sticky_add_remove_topic_two_consumers(mocker: MockerFixture) -> None: cluster = create_cluster(mocker, topics={"t1", 
"t2"}, topics_partitions={0, 1, 2}) subscriptions = { @@ -474,7 +487,7 @@ def test_sticky_add_remove_topic_two_consumers(mocker): assert_assignment(sticky_assignment, expected_assignment) -def test_sticky_reassignment_after_one_consumer_leaves(mocker): +def test_sticky_reassignment_after_one_consumer_leaves(mocker: MockerFixture) -> None: partitions = {f"t{i}": set(range(i)) for i in range(1, 20)} cluster = create_cluster( mocker, @@ -503,10 +516,11 @@ def test_sticky_reassignment_after_one_consumer_leaves(mocker): assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) + assert StickyPartitionAssignor._latest_partition_movements assert StickyPartitionAssignor._latest_partition_movements.are_sticky() -def test_sticky_reassignment_after_one_consumer_added(mocker): +def test_sticky_reassignment_after_one_consumer_added(mocker: MockerFixture) -> None: cluster = create_cluster(mocker, topics={"t"}, topics_partitions=set(range(20))) subscriptions = defaultdict(set) @@ -526,10 +540,11 @@ def test_sticky_reassignment_after_one_consumer_added(mocker): ) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) + assert StickyPartitionAssignor._latest_partition_movements assert StickyPartitionAssignor._latest_partition_movements.are_sticky() -def test_sticky_same_subscriptions(mocker): +def test_sticky_same_subscriptions(mocker: MockerFixture) -> None: partitions = {f"t{i}": set(range(i)) for i in range(1, 15)} cluster = create_cluster( mocker, @@ -555,10 +570,13 @@ def test_sticky_same_subscriptions(mocker): ) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) + assert StickyPartitionAssignor._latest_partition_movements assert StickyPartitionAssignor._latest_partition_movements.are_sticky() -def test_sticky_large_assignment_with_multiple_consumers_leaving(mocker): 
+def test_sticky_large_assignment_with_multiple_consumers_leaving( + mocker: MockerFixture, +) -> None: n_topics = 40 n_consumers = 200 @@ -592,10 +610,11 @@ def test_sticky_large_assignment_with_multiple_consumers_leaving(mocker): assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) + assert StickyPartitionAssignor._latest_partition_movements assert StickyPartitionAssignor._latest_partition_movements.are_sticky() -def test_new_subscription(mocker): +def test_new_subscription(mocker: MockerFixture) -> None: cluster = create_cluster( mocker, topics={"t1", "t2", "t3", "t4"}, topics_partitions={0} ) @@ -617,10 +636,11 @@ def test_new_subscription(mocker): assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) + assert StickyPartitionAssignor._latest_partition_movements assert StickyPartitionAssignor._latest_partition_movements.are_sticky() -def test_move_existing_assignments(mocker): +def test_move_existing_assignments(mocker: MockerFixture) -> None: cluster = create_cluster( mocker, topics={"t1", "t2", "t3", "t4", "t5", "t6"}, topics_partitions={0} ) @@ -650,7 +670,7 @@ def test_move_existing_assignments(mocker): verify_validity_and_balance(subscriptions, assignment) -def test_stickiness(mocker): +def test_stickiness(mocker: MockerFixture) -> None: cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1, 2}) subscriptions = { "C1": {"t"}, @@ -680,6 +700,7 @@ def test_stickiness(mocker): assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) + assert StickyPartitionAssignor._latest_partition_movements assert StickyPartitionAssignor._latest_partition_movements.are_sticky() for consumer, consumer_assignment in assignment.items(): @@ -692,8 +713,8 @@ def test_stickiness(mocker): ), f"Stickiness was not honored for consumer {consumer}" -def 
test_assignment_updated_for_deleted_topic(mocker): - def topic_partitions(topic): +def test_assignment_updated_for_deleted_topic(mocker: MockerFixture) -> None: + def topic_partitions(topic: str) -> Optional[Set[int]]: if topic == "t1": return {0} elif topic == "t3": @@ -720,7 +741,9 @@ def topic_partitions(topic): assert_assignment(sticky_assignment, expected_assignment) -def test_no_exceptions_when_only_subscribed_topic_is_deleted(mocker): +def test_no_exceptions_when_only_subscribed_topic_is_deleted( + mocker: MockerFixture, +) -> None: cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1, 2}) subscriptions = { @@ -737,7 +760,7 @@ def test_no_exceptions_when_only_subscribed_topic_is_deleted(mocker): assert_assignment(sticky_assignment, expected_assignment) subscriptions = { - "C": {}, + "C": set(), } member_metadata = {} for member, topics in subscriptions.items(): @@ -745,7 +768,7 @@ def test_no_exceptions_when_only_subscribed_topic_is_deleted(mocker): topics, sticky_assignment[member].partitions() ) - cluster = create_cluster(mocker, topics={}, topics_partitions={}) + cluster = create_cluster(mocker, topics=set(), topics_partitions=set()) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { "C": ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [], b""), @@ -753,7 +776,7 @@ def test_no_exceptions_when_only_subscribed_topic_is_deleted(mocker): assert_assignment(sticky_assignment, expected_assignment) -def test_conflicting_previous_assignments(mocker): +def test_conflicting_previous_assignments(mocker: MockerFixture) -> None: cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1}) subscriptions = { @@ -776,12 +799,14 @@ def test_conflicting_previous_assignments(mocker): [(i, randint(10, 20), randint(20, 40)) for i in range(100)], ) def test_reassignment_with_random_subscriptions_and_changes( - mocker, execution_number, n_topics, n_consumers -): + mocker: 
MockerFixture, execution_number: int, n_topics: int, n_consumers: int +) -> None: all_topics = sorted([f"t{i}" for i in range(1, n_topics + 1)]) partitions = {t: set(range(1, i + 1)) for i, t in enumerate(all_topics)} cluster = create_cluster( - mocker, topics=all_topics, topic_partitions_lambda=lambda t: partitions[t] + mocker, + topics=all_topics, # type: ignore[arg-type] + topic_partitions_lambda=lambda t: partitions[t], ) subscriptions = defaultdict(set) @@ -807,10 +832,11 @@ def test_reassignment_with_random_subscriptions_and_changes( assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) + assert StickyPartitionAssignor._latest_partition_movements assert StickyPartitionAssignor._latest_partition_movements.are_sticky() -def test_assignment_with_multiple_generations1(mocker): +def test_assignment_with_multiple_generations1(mocker: MockerFixture) -> None: cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1, 2, 3, 4, 5}) member_metadata = { @@ -842,6 +868,7 @@ def test_assignment_with_multiple_generations1(mocker): partition in assignment2["C2"].assignment[0][1] for partition in assignment1["C2"].assignment[0][1] ) + assert StickyPartitionAssignor._latest_partition_movements assert StickyPartitionAssignor._latest_partition_movements.are_sticky() member_metadata = { @@ -857,10 +884,11 @@ def test_assignment_with_multiple_generations1(mocker): verify_validity_and_balance({"C2": {"t"}, "C3": {"t"}}, assignment3) assert len(assignment3["C2"].assignment[0][1]) == 3 assert len(assignment3["C3"].assignment[0][1]) == 3 + assert StickyPartitionAssignor._latest_partition_movements assert StickyPartitionAssignor._latest_partition_movements.are_sticky() -def test_assignment_with_multiple_generations2(mocker): +def test_assignment_with_multiple_generations2(mocker: MockerFixture) -> None: cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1, 2, 3, 4, 5}) member_metadata 
= { @@ -888,6 +916,7 @@ def test_assignment_with_multiple_generations2(mocker): partition in assignment2["C2"].assignment[0][1] for partition in assignment1["C2"].assignment[0][1] ) + assert StickyPartitionAssignor._latest_partition_movements assert StickyPartitionAssignor._latest_partition_movements.are_sticky() member_metadata = { @@ -904,6 +933,7 @@ def test_assignment_with_multiple_generations2(mocker): assignment3 = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance({"C1": {"t"}, "C2": {"t"}, "C3": {"t"}}, assignment3) + assert StickyPartitionAssignor._latest_partition_movements assert StickyPartitionAssignor._latest_partition_movements.are_sticky() assert set(assignment3["C1"].assignment[0][1]) == set( assignment1["C1"].assignment[0][1] @@ -917,7 +947,9 @@ def test_assignment_with_multiple_generations2(mocker): @pytest.mark.parametrize("execution_number", range(50)) -def test_assignment_with_conflicting_previous_generations(mocker, execution_number): +def test_assignment_with_conflicting_previous_generations( + mocker: MockerFixture, execution_number: int +) -> None: cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1, 2, 3, 4, 5}) member_assignments = { @@ -930,7 +962,7 @@ def test_assignment_with_conflicting_previous_generations(mocker, execution_numb "C2": 1, "C3": 2, } - member_metadata = {} + member_metadata: Dict[str, ConsumerProtocolMemberMetadata] = {} for member in member_assignments: member_metadata[member] = StickyPartitionAssignor._metadata( {"t"}, member_assignments[member], member_generations[member] @@ -938,17 +970,23 @@ def test_assignment_with_conflicting_previous_generations(mocker, execution_numb assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance({"C1": {"t"}, "C2": {"t"}, "C3": {"t"}}, assignment) + assert StickyPartitionAssignor._latest_partition_movements assert StickyPartitionAssignor._latest_partition_movements.are_sticky() -def 
make_member_metadata(subscriptions): - member_metadata = {} +def make_member_metadata( + subscriptions: Dict[str, Set[str]], +) -> Dict[str, ConsumerProtocolMemberMetadata]: + member_metadata: Dict[str, ConsumerProtocolMemberMetadata] = {} for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, []) return member_metadata -def assert_assignment(result_assignment, expected_assignment): +def assert_assignment( + result_assignment: Dict[str, ConsumerProtocolMemberAssignment], + expected_assignment: Dict[str, ConsumerProtocolMemberAssignment], +) -> None: assert result_assignment == expected_assignment assert set(result_assignment) == set(expected_assignment) for member in result_assignment: @@ -957,7 +995,10 @@ def assert_assignment(result_assignment, expected_assignment): ) -def verify_validity_and_balance(subscriptions, assignment): +def verify_validity_and_balance( + subscriptions: Dict[str, Set[str]], + assignment: Dict[str, ConsumerProtocolMemberAssignment], +) -> None: """ Verifies that the given assignment is valid with respect to the given subscriptions Validity requirements: @@ -1027,7 +1068,9 @@ def verify_validity_and_balance(subscriptions, assignment): ) -def group_partitions_by_topic(partitions): +def group_partitions_by_topic( + partitions: Sequence[TopicPartition], +) -> Dict[str, Set[int]]: result = defaultdict(set) for p in partitions: result[p.topic].add(p.partition) diff --git a/tests/coordinator/test_partition_movements.py b/tests/coordinator/test_partition_movements.py index d901e4fe..1029f2e7 100644 --- a/tests/coordinator/test_partition_movements.py +++ b/tests/coordinator/test_partition_movements.py @@ -2,12 +2,12 @@ from aiokafka.structs import TopicPartition -def test_empty_movements_are_sticky(): +def test_empty_movements_are_sticky() -> None: partition_movements = PartitionMovements() assert partition_movements.are_sticky() -def test_sticky_movements(): +def test_sticky_movements() -> 
None: partition_movements = PartitionMovements() partition_movements.move_partition(TopicPartition("t", 1), "C1", "C2") partition_movements.move_partition(TopicPartition("t", 1), "C2", "C3") @@ -15,7 +15,7 @@ def test_sticky_movements(): assert partition_movements.are_sticky() -def test_should_detect_non_sticky_assignment(): +def test_should_detect_non_sticky_assignment() -> None: partition_movements = PartitionMovements() partition_movements.move_partition(TopicPartition("t", 1), "C1", "C2") partition_movements.move_partition(TopicPartition("t", 2), "C2", "C1")