From 1862620f2cf1fb7e7ca1147d7dba651314b355ff Mon Sep 17 00:00:00 2001 From: Dmitriy Date: Sun, 21 Apr 2024 23:24:33 +0500 Subject: [PATCH] add typing to aiokafka/protocol/* (#999) * add typing to aiokafka/protocol/* * fix review * fix VarInt64 * fix review tuple -> list * fix review * fix review * move ALL_TOPICS/NO_TOPICS to docs * remove default values from Message() * fix checking abstractproperty in test * fix review * fix review (from docstrings to comments) * fix: collections.abc.Sequence -> typing.Sequence * fix review: Message * add FIXME * fix review: Message * use NotImplemented instead of False --- Makefile | 1 + aiokafka/protocol/abstract.py | 12 +- aiokafka/protocol/admin.py | 13 ++- aiokafka/protocol/api.py | 77 ++++++++----- aiokafka/protocol/fetch.py | 6 +- aiokafka/protocol/message.py | 199 +++++++++++++++++++++++---------- aiokafka/protocol/metadata.py | 34 ++++-- aiokafka/protocol/produce.py | 32 +++--- aiokafka/protocol/struct.py | 24 ++-- aiokafka/protocol/types.py | 204 +++++++++++++++++++++------------- tests/test_protocol.py | 32 ++++-- 11 files changed, 415 insertions(+), 219 deletions(-) diff --git a/Makefile b/Makefile index ec298661..ca394ed2 100644 --- a/Makefile +++ b/Makefile @@ -11,6 +11,7 @@ FORMATTED_AREAS=\ aiokafka/helpers.py \ aiokafka/structs.py \ aiokafka/util.py \ + aiokafka/protocol/ \ tests/test_codec.py \ tests/test_helpers.py diff --git a/aiokafka/protocol/abstract.py b/aiokafka/protocol/abstract.py index 117d058e..c466357e 100644 --- a/aiokafka/protocol/abstract.py +++ b/aiokafka/protocol/abstract.py @@ -1,15 +1,19 @@ import abc +from io import BytesIO +from typing import Generic, TypeVar +T = TypeVar("T") -class AbstractType(metaclass=abc.ABCMeta): + +class AbstractType(Generic[T], metaclass=abc.ABCMeta): @classmethod @abc.abstractmethod - def encode(cls, value): ... + def encode(cls, value: T) -> bytes: ... @classmethod @abc.abstractmethod - def decode(cls, data): ... + def decode(cls, data: BytesIO) -> T: ... @classmethod - def repr(cls, value): + def repr(cls, value: T) -> str: return repr(value) diff --git a/aiokafka/protocol/admin.py b/aiokafka/protocol/admin.py index 2bb17eeb..2f374286 100644 --- a/aiokafka/protocol/admin.py +++ b/aiokafka/protocol/admin.py @@ -1,3 +1,5 @@ +from typing import Dict, Iterable, Optional, Tuple + from .api import Request, Response from .types import ( Array, @@ -429,8 +431,8 @@ class DescribeGroupsResponse_v3(Response): ("member_assignment", Bytes), ), ), + ("authorized_operations", Int32), ), - ("authorized_operations", Int32), ), ) @@ -1119,7 +1121,7 @@ class DeleteGroupsRequest_v1(Request): DeleteGroupsResponse = [DeleteGroupsResponse_v0, DeleteGroupsResponse_v1] -class DescribeClientQuotasResponse_v0(Request): +class DescribeClientQuotasResponse_v0(Response): API_KEY = 48 API_VERSION = 0 SCHEMA = Schema( @@ -1385,7 +1387,12 @@ class DeleteRecordsRequest_v2(Request): ("tags", TaggedFields), ) - def __init__(self, topics, timeout_ms, tags=None): + def __init__( + self, + topics: Iterable[Tuple[str, Iterable[Tuple[int, int]]]], + timeout_ms: int, + tags: Optional[Dict[int, bytes]] = None, + ) -> None: super().__init__( [ ( diff --git a/aiokafka/protocol/api.py b/aiokafka/protocol/api.py index 77a7a485..1e6ee3b6 100644 --- a/aiokafka/protocol/api.py +++ b/aiokafka/protocol/api.py @@ -1,4 +1,8 @@ +from __future__ import annotations + import abc +from io import BytesIO +from typing import Any, ClassVar, Dict, Optional, Type, Union from .struct import Struct from .types import Array, Int16, Int32, Schema, String, TaggedFields @@ -12,7 +16,9 @@ class RequestHeader_v0(Struct): ("client_id", String("utf-8")), ) - def __init__(self, request, correlation_id=0, client_id="aiokafka"): + def __init__( + self, request: Request, correlation_id: int = 0, client_id: str = "aiokafka" + ) -> None: super().__init__( request.API_KEY, request.API_VERSION, correlation_id, client_id ) @@ -28,7 +34,13 @@ class RequestHeader_v1(Struct): ("tags", TaggedFields), ) - def __init__(self, request, correlation_id=0, client_id="aiokafka", tags=None): + def __init__( + self, + request: Request, + correlation_id: int = 0, + client_id: str = "aiokafka", + tags: Optional[Dict[int, bytes]] = None, + ): super().__init__( request.API_KEY, request.API_VERSION, correlation_id, client_id, tags or {} ) @@ -48,32 +60,38 @@ class ResponseHeader_v1(Struct): class Request(Struct, metaclass=abc.ABCMeta): - FLEXIBLE_VERSION = False + FLEXIBLE_VERSION: ClassVar[bool] = False - @abc.abstractproperty - def API_KEY(self): + @property + @abc.abstractmethod + def API_KEY(self) -> int: """Integer identifier for api request""" - @abc.abstractproperty - def API_VERSION(self): + @property + @abc.abstractmethod + def API_VERSION(self) -> int: """Integer of api request version""" - @abc.abstractproperty - def SCHEMA(self): - """An instance of Schema() representing the request structure""" - - @abc.abstractproperty - def RESPONSE_TYPE(self): + @property + @abc.abstractmethod + def RESPONSE_TYPE(self) -> Type[Response]: """The Response class associated with the api request""" - def expect_response(self): + @property + @abc.abstractmethod + def SCHEMA(self) -> Schema: + """An instance of Schema() representing the request structure""" + + def expect_response(self) -> bool: """Override this method if an api request does not always generate a response""" return True - def to_object(self): + def to_object(self) -> Dict[str, Any]: return _to_object(self.SCHEMA, self) - def build_request_header(self, correlation_id, client_id): + def build_request_header( + self, correlation_id: int, client_id: str + ) -> Union[RequestHeader_v0, RequestHeader_v1]: if self.FLEXIBLE_VERSION: return RequestHeader_v1( self, correlation_id=correlation_id, client_id=client_id @@ -82,31 +100,36 @@ def build_request_header(self, correlation_id, client_id): self, correlation_id=correlation_id, client_id=client_id ) - def parse_response_header(self, read_buffer): + def parse_response_header( + self, read_buffer: Union[BytesIO, bytes] + ) -> Union[ResponseHeader_v0, ResponseHeader_v1]: if self.FLEXIBLE_VERSION: return ResponseHeader_v1.decode(read_buffer) return ResponseHeader_v0.decode(read_buffer) class Response(Struct, metaclass=abc.ABCMeta): - @abc.abstractproperty - def API_KEY(self): + @property + @abc.abstractmethod + def API_KEY(self) -> int: """Integer identifier for api request/response""" - @abc.abstractproperty - def API_VERSION(self): + @property + @abc.abstractmethod + def API_VERSION(self) -> int: """Integer of api request/response version""" - @abc.abstractproperty - def SCHEMA(self): + @property + @abc.abstractmethod + def SCHEMA(self) -> Schema: """An instance of Schema() representing the response structure""" - def to_object(self): + def to_object(self) -> Dict[str, Any]: return _to_object(self.SCHEMA, self) -def _to_object(schema, data): - obj = {} +def _to_object(schema: Schema, data: Union[Struct, Dict[int, Any]]) -> Dict[str, Any]: + obj: Dict[str, Any] = {} for idx, (name, _type) in enumerate(zip(schema.names, schema.fields)): if isinstance(data, Struct): val = data.get_item(name) @@ -116,7 +139,7 @@ def _to_object(schema, data): if isinstance(_type, Schema): obj[name] = _to_object(_type, val) elif isinstance(_type, Array): - if isinstance(_type.array_of, (Array, Schema)): + if isinstance(_type.array_of, Schema): obj[name] = [_to_object(_type.array_of, x) for x in val] else: obj[name] = val diff --git a/aiokafka/protocol/fetch.py b/aiokafka/protocol/fetch.py index 56cbdd73..c63256d7 100644 --- a/aiokafka/protocol/fetch.py +++ b/aiokafka/protocol/fetch.py @@ -376,7 +376,7 @@ class FetchRequest_v7(Request): ), ( "forgotten_topics_data", - Array(("topic", String), ("partitions", Array(Int32))), + Array(("topic", String("utf-8")), ("partitions", Array(Int32))), ), ) @@ -428,7 +428,7 @@ class FetchRequest_v9(Request): ( "forgotten_topics_data", Array( - ("topic", String), + ("topic", String("utf-8")), ("partitions", Array(Int32)), ), ), @@ -480,7 +480,7 @@ class FetchRequest_v11(Request): ), ( "forgotten_topics_data", - Array(("topic", String), ("partitions", Array(Int32))), + Array(("topic", String("utf-8")), ("partitions", Array(Int32))), ), ("rack_id", String("utf-8")), ) diff --git a/aiokafka/protocol/message.py b/aiokafka/protocol/message.py index 31993fe6..77103af4 100644 --- a/aiokafka/protocol/message.py +++ b/aiokafka/protocol/message.py @@ -1,6 +1,9 @@ import io import time from binascii import crc32 +from typing import Iterable, List, Literal, Optional, Tuple, Union, cast, overload + +from typing_extensions import Self from aiokafka.codec import ( gzip_decode, @@ -15,25 +18,34 @@ from aiokafka.errors import UnsupportedCodecError from .struct import Struct -from .types import AbstractType, Bytes, Int8, Int32, Int64, Schema, UInt32 +from .types import Bytes, Int8, Int32, Int64, Schema, UInt32 class Message(Struct): + # FIXME: override __eq__/__repr__ methods from Struct + + BASE_FIELDS = ( + ("crc", UInt32), + ("magic", Int8), + ("attributes", Int8), + ) + MAGIC0_FIELDS = ( + ("key", Bytes), + ("value", Bytes), + ) + MAGIC1_FIELDS = ( + ("timestamp", Int64), + ("key", Bytes), + ("value", Bytes), + ) SCHEMAS = [ Schema( - ("crc", UInt32), - ("magic", Int8), - ("attributes", Int8), - ("key", Bytes), - ("value", Bytes), + *BASE_FIELDS, + *MAGIC0_FIELDS, ), Schema( - ("crc", UInt32), - ("magic", Int8), - ("attributes", Int8), - ("timestamp", Int64), - ("key", Bytes), - ("value", Bytes), + *BASE_FIELDS, + *MAGIC1_FIELDS, ), ] SCHEMA = SCHEMAS[1] @@ -47,7 +59,39 @@ class Message(Struct): 22 # crc(4), magic(1), attributes(1), timestamp(8), key+value size(4*2) ) - def __init__(self, value, key=None, magic=0, attributes=0, crc=0, timestamp=None): + @overload + def __init__( + self, + *, + value: Optional[bytes], + key: Optional[bytes], + magic: Literal[0], + attributes: int, + crc: int, + ) -> None: ... + + @overload + def __init__( + self, + *, + value: Optional[bytes], + key: Optional[bytes], + magic: Literal[1], + attributes: int, + crc: int, + timestamp: int, + ) -> None: ... + + def __init__( + self, + *, + value: Optional[bytes], + key: Optional[bytes], + magic: Literal[0, 1], + attributes: int, + crc: int, + timestamp: Optional[int] = None, + ) -> None: assert value is None or isinstance(value, bytes), "value must be bytes" assert key is None or isinstance(key, bytes), "key must be bytes" assert magic > 0 or timestamp is None, "timestamp not supported in v0" @@ -57,14 +101,14 @@ def __init__(self, value, key=None, magic=0, attributes=0, crc=0, timestamp=None timestamp = int(time.time() * 1000) self.timestamp = timestamp self.crc = crc - self._validated_crc = None + self._validated_crc: Optional[int] = None self.magic = magic self.attributes = attributes self.key = key self.value = value @property - def timestamp_type(self): + def timestamp_type(self) -> Optional[Literal[0, 1]]: """0 for CreateTime; 1 for LogAppendTime; None if unsupported. Value is determined by broker; produced messages should always set to 0 @@ -77,55 +121,78 @@ def timestamp_type(self): else: return 0 - def encode(self, recalc_crc=True): + def encode(self, recalc_crc: bool = True) -> bytes: version = self.magic if version == 1: - fields = ( - self.crc, - self.magic, - self.attributes, - self.timestamp, - self.key, - self.value, + message = Message.SCHEMAS[version].encode( + ( + self.crc, + self.magic, + self.attributes, + self.timestamp, + self.key, + self.value, + ) ) elif version == 0: - fields = (self.crc, self.magic, self.attributes, self.key, self.value) + message = Message.SCHEMAS[version].encode( + (self.crc, self.magic, self.attributes, self.key, self.value) + ) else: raise ValueError(f"Unrecognized message version: {version}") - message = Message.SCHEMAS[version].encode(fields) if not recalc_crc: return message self.crc = crc32(message[4:]) - crc_field = self.SCHEMAS[version].fields[0] + crc_field = self.BASE_FIELDS[0][1] return crc_field.encode(self.crc) + message[4:] @classmethod - def decode(cls, data): - _validated_crc = None + def decode(cls, data: Union[io.BytesIO, bytes]) -> Self: + _validated_crc: Optional[int] = None if isinstance(data, bytes): _validated_crc = crc32(data[4:]) data = io.BytesIO(data) # Partial decode required to determine message version - base_fields = cls.SCHEMAS[0].fields[0:3] - crc, magic, attributes = (field.decode(data) for field in base_fields) - remaining = cls.SCHEMAS[magic].fields[3:] - fields = [field.decode(data) for field in remaining] + crc, magic, attributes = ( + cls.BASE_FIELDS[0][1].decode(data), + cls.BASE_FIELDS[1][1].decode(data), + cls.BASE_FIELDS[2][1].decode(data), + ) if magic == 1: - timestamp = fields[0] + magic = cast(Literal[1], magic) + timestamp, key, value = ( + cls.MAGIC1_FIELDS[0][1].decode(data), + cls.MAGIC1_FIELDS[1][1].decode(data), + cls.MAGIC1_FIELDS[2][1].decode(data), + ) + msg = cls( + value=value, + key=key, + magic=magic, + attributes=attributes, + crc=crc, + timestamp=timestamp, + ) + elif magic == 0: + magic = cast(Literal[0], magic) + key, value = ( + cls.MAGIC0_FIELDS[0][1].decode(data), + cls.MAGIC0_FIELDS[1][1].decode(data), + ) + msg = cls( + value=value, + key=key, + magic=magic, + attributes=attributes, + crc=crc, + ) else: - timestamp = None - msg = cls( - fields[-1], - key=fields[-2], - magic=magic, - attributes=attributes, - crc=crc, - timestamp=timestamp, - ) + raise ValueError(f"Unrecognized message version: {magic}") + msg._validated_crc = _validated_crc return msg - def validate_crc(self): + def validate_crc(self) -> bool: if self._validated_crc is None: raw_msg = self.encode(recalc_crc=False) self._validated_crc = crc32(raw_msg[4:]) @@ -133,10 +200,13 @@ def validate_crc(self): return True return False - def is_compressed(self): + def is_compressed(self) -> bool: return self.attributes & self.CODEC_MASK != 0 - def decompress(self): + def decompress( + self, + ) -> List[Union[Tuple[int, int, "Message"], Tuple[None, None, "PartialMessage"]]]: + assert self.value is not None codec = self.attributes & self.CODEC_MASK assert codec in ( self.CODEC_GZIP, @@ -167,21 +237,25 @@ def decompress(self): return MessageSet.decode(raw_bytes, bytes_to_read=len(raw_bytes)) - def __hash__(self): + def __hash__(self) -> int: return hash(self.encode(recalc_crc=False)) class PartialMessage(bytes): - def __repr__(self): - return f"PartialMessage({self})" + def __repr__(self) -> str: + return f"PartialMessage({self!r})" -class MessageSet(AbstractType): +class MessageSet: ITEM = Schema(("offset", Int64), ("message", Bytes)) HEADER_SIZE = 12 # offset + message_size @classmethod - def encode(cls, items, prepend_size=True): + def encode( + cls, + items: Union[io.BytesIO, Iterable[Tuple[int, bytes]]], + prepend_size: bool = True, + ) -> bytes: # RecordAccumulator encodes messagesets internally if isinstance(items, io.BytesIO): size = Int32.decode(items) @@ -191,7 +265,7 @@ def encode(cls, items, prepend_size=True): size += 4 return items.read(size) - encoded_values = [] + encoded_values: List[bytes] = [] for offset, message in items: encoded_values.append(Int64.encode(offset)) encoded_values.append(Bytes.encode(message)) @@ -202,7 +276,9 @@ def encode(cls, items, prepend_size=True): return encoded @classmethod - def decode(cls, data, bytes_to_read=None): + def decode( + cls, data: Union[io.BytesIO, bytes], bytes_to_read: Optional[int] = None + ) -> List[Union[Tuple[int, int, Message], Tuple[None, None, PartialMessage]]]: """Compressed messages should pass in bytes_to_read (via message size) otherwise, we decode from data as Int32 """ @@ -216,11 +292,14 @@ def decode(cls, data, bytes_to_read=None): # So create an internal buffer to avoid over-reading raw = io.BytesIO(data.read(bytes_to_read)) - items = [] + items: List[ + Union[Tuple[int, int, Message], Tuple[None, None, PartialMessage]] + ] = [] try: while bytes_to_read: offset = Int64.decode(raw) msg_bytes = Bytes.decode(raw) + assert msg_bytes is not None bytes_to_read -= 8 + 4 + len(msg_bytes) items.append( (offset, len(msg_bytes), Message.decode(msg_bytes)), @@ -233,10 +312,18 @@ def decode(cls, data, bytes_to_read=None): return items @classmethod - def repr(cls, messages): + def repr( + cls, + messages: Union[ + io.BytesIO, + List[Union[Tuple[int, int, Message], Tuple[None, None, PartialMessage]]], + ], + ) -> str: if isinstance(messages, io.BytesIO): offset = messages.tell() decoded = cls.decode(messages) messages.seek(offset) - messages = decoded - return str([cls.ITEM.repr(m) for m in messages]) + decoded_messages = decoded + else: + decoded_messages = messages + return str([cls.ITEM.repr(m) for m in decoded_messages]) diff --git a/aiokafka/protocol/metadata.py b/aiokafka/protocol/metadata.py index 79a5600a..2c9ca624 100644 --- a/aiokafka/protocol/metadata.py +++ b/aiokafka/protocol/metadata.py @@ -183,49 +183,59 @@ class MetadataResponse_v5(Response): class MetadataRequest_v0(Request): + # topics: + # None: Empty Array (len 0) for topics returns all topics + API_KEY = 3 API_VERSION = 0 RESPONSE_TYPE = MetadataResponse_v0 SCHEMA = Schema(("topics", Array(String("utf-8")))) - ALL_TOPICS = None # Empty Array (len 0) for topics returns all topics class MetadataRequest_v1(Request): + # topics: + # -1: Null Array (len -1) for topics returns all topics + # None: Empty array (len 0) for topics returns no topics + API_KEY = 3 API_VERSION = 1 RESPONSE_TYPE = MetadataResponse_v1 SCHEMA = MetadataRequest_v0.SCHEMA - ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics - NO_TOPICS = None # Empty array (len 0) for topics returns no topics class MetadataRequest_v2(Request): + # topics: + # -1: Null Array (len -1) for topics returns all topics + # None: Empty array (len 0) for topics returns no topics + API_KEY = 3 API_VERSION = 2 RESPONSE_TYPE = MetadataResponse_v2 SCHEMA = MetadataRequest_v1.SCHEMA - ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics - NO_TOPICS = None # Empty array (len 0) for topics returns no topics class MetadataRequest_v3(Request): + # topics: + # -1: Null Array (len -1) for topics returns all topics + # None: Empty array (len 0) for topics returns no topics + API_KEY = 3 API_VERSION = 3 RESPONSE_TYPE = MetadataResponse_v3 SCHEMA = MetadataRequest_v1.SCHEMA - ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics - NO_TOPICS = None # Empty array (len 0) for topics returns no topics class MetadataRequest_v4(Request): + # topics: + # -1: Null Array (len -1) for topics returns all topics + # None: Empty array (len 0) for topics returns no topics + API_KEY = 3 API_VERSION = 4 RESPONSE_TYPE = MetadataResponse_v4 SCHEMA = Schema( ("topics", Array(String("utf-8"))), ("allow_auto_topic_creation", Boolean) ) - ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics - NO_TOPICS = None # Empty array (len 0) for topics returns no topics class MetadataRequest_v5(Request): @@ -234,12 +244,14 @@ class MetadataRequest_v5(Request): An additional field for offline_replicas has been added to the v5 metadata response """ + # topics: + # -1: Null Array (len -1) for topics returns all topics + # None: Empty array (len 0) for topics returns no topics + API_KEY = 3 API_VERSION = 5 RESPONSE_TYPE = MetadataResponse_v5 SCHEMA = MetadataRequest_v4.SCHEMA - ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics - NO_TOPICS = None # Empty array (len 0) for topics returns no topics MetadataRequest = [ diff --git a/aiokafka/protocol/produce.py b/aiokafka/protocol/produce.py index e55f616a..6b69fc31 100644 --- a/aiokafka/protocol/produce.py +++ b/aiokafka/protocol/produce.py @@ -148,17 +148,15 @@ class ProduceResponse_v8(Response): ("offset", Int64), ("timestamp", Int64), ("log_start_offset", Int64), - ), - ( - "record_errors", ( + "record_errors", Array( ("batch_index", Int32), ("batch_index_error_message", String("utf-8")), - ) + ), ), + ("error_message", String("utf-8")), ), - ("error_message", String("utf-8")), ), ), ), @@ -166,16 +164,18 @@ class ProduceResponse_v8(Response): ) -class ProduceRequest(Request): +class ProduceRequestBase(Request): API_KEY = 0 - def expect_response(self): + required_acks: int + + def expect_response(self) -> bool: if self.required_acks == 0: return False return True -class ProduceRequest_v0(ProduceRequest): +class ProduceRequest_v0(ProduceRequestBase): API_VERSION = 0 RESPONSE_TYPE = ProduceResponse_v0 SCHEMA = Schema( @@ -191,19 +191,19 @@ class ProduceRequest_v0(ProduceRequest): ) -class ProduceRequest_v1(ProduceRequest): +class ProduceRequest_v1(ProduceRequestBase): API_VERSION = 1 RESPONSE_TYPE = ProduceResponse_v1 SCHEMA = ProduceRequest_v0.SCHEMA -class ProduceRequest_v2(ProduceRequest): +class ProduceRequest_v2(ProduceRequestBase): API_VERSION = 2 RESPONSE_TYPE = ProduceResponse_v2 SCHEMA = ProduceRequest_v1.SCHEMA -class ProduceRequest_v3(ProduceRequest): +class ProduceRequest_v3(ProduceRequestBase): API_VERSION = 3 RESPONSE_TYPE = ProduceResponse_v3 SCHEMA = Schema( @@ -220,7 +220,7 @@ class ProduceRequest_v3(ProduceRequest): ) -class ProduceRequest_v4(ProduceRequest): +class ProduceRequest_v4(ProduceRequestBase): """ The version number is bumped up to indicate that the client supports KafkaStorageException. The KafkaStorageException will be translated to @@ -232,7 +232,7 @@ class ProduceRequest_v4(ProduceRequest): SCHEMA = ProduceRequest_v3.SCHEMA -class ProduceRequest_v5(ProduceRequest): +class ProduceRequest_v5(ProduceRequestBase): """ Same as v4. The version number is bumped since the v5 response includes an additional partition level field: the log_start_offset. @@ -243,7 +243,7 @@ class ProduceRequest_v5(ProduceRequest): SCHEMA = ProduceRequest_v4.SCHEMA -class ProduceRequest_v6(ProduceRequest): +class ProduceRequest_v6(ProduceRequestBase): """ The version number is bumped to indicate that on quota violation brokers send out responses before throttling. @@ -254,7 +254,7 @@ class ProduceRequest_v6(ProduceRequest): SCHEMA = ProduceRequest_v5.SCHEMA -class ProduceRequest_v7(ProduceRequest): +class ProduceRequest_v7(ProduceRequestBase): """ V7 bumped up to indicate ZStandard capability. (see KIP-110) """ @@ -264,7 +264,7 @@ class ProduceRequest_v7(ProduceRequest): SCHEMA = ProduceRequest_v6.SCHEMA -class ProduceRequest_v8(ProduceRequest): +class ProduceRequest_v8(ProduceRequestBase): """ V8 bumped up to add two new fields record_errors offset list and error_message to PartitionResponse (See KIP-467) diff --git a/aiokafka/protocol/struct.py b/aiokafka/protocol/struct.py index ee99c75a..fc1461bf 100644 --- a/aiokafka/protocol/struct.py +++ b/aiokafka/protocol/struct.py @@ -1,13 +1,15 @@ from io import BytesIO +from typing import Any, ClassVar, List, Union + +from typing_extensions import Self -from .abstract import AbstractType from .types import Schema -class Struct(AbstractType): - SCHEMA = Schema() +class Struct: + SCHEMA: ClassVar = Schema() - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: if len(args) == len(self.SCHEMA.fields): for i, name in enumerate(self.SCHEMA.names): self.__dict__[name] = args[i] @@ -23,27 +25,29 @@ def __init__(self, *args, **kwargs): ) ) - def encode(self): + def encode(self) -> bytes: return self.SCHEMA.encode([self.__dict__[name] for name in self.SCHEMA.names]) @classmethod - def decode(cls, data): + def decode(cls, data: Union[BytesIO, bytes]) -> Self: if isinstance(data, bytes): data = BytesIO(data) return cls(*[field.decode(data) for field in cls.SCHEMA.fields]) - def get_item(self, name): + def get_item(self, name: str) -> Any: if name not in self.SCHEMA.names: raise KeyError("%s is not in the schema" % name) return self.__dict__[name] - def __repr__(self): - key_vals = [] + def __repr__(self) -> str: + key_vals: List[str] = [] for name, field in zip(self.SCHEMA.names, self.SCHEMA.fields): key_vals.append(f"{name}={field.repr(self.__dict__[name])}") return self.__class__.__name__ + "(" + ", ".join(key_vals) + ")" - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + if not isinstance(other, Struct): + return NotImplemented if self.SCHEMA != other.SCHEMA: return False for attr in self.SCHEMA.names: diff --git a/aiokafka/protocol/types.py b/aiokafka/protocol/types.py index 7eadf7fb..944783c0 100644 --- a/aiokafka/protocol/types.py +++ b/aiokafka/protocol/types.py @@ -1,10 +1,31 @@ import struct +from io import BytesIO from struct import error +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) + +from typing_extensions import Buffer, TypeAlias from .abstract import AbstractType +T = TypeVar("T") -def _pack(f, value): +ValueT: TypeAlias = Union[Type[AbstractType[Any]], "String", "Array", "Schema"] + + +def _pack(f: Callable[[T], bytes], value: T) -> bytes: try: return f(value) except error as e: @@ -14,7 +35,7 @@ def _pack(f, value): ) from e -def _unpack(f, data): +def _unpack(f: Callable[[Buffer], Tuple[T, ...]], data: Buffer) -> T: try: (value,) = f(data) except error as e: @@ -26,95 +47,95 @@ def _unpack(f, data): return value -class Int8(AbstractType): +class Int8(AbstractType[int]): _pack = struct.Struct(">b").pack _unpack = struct.Struct(">b").unpack @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> int: return _unpack(cls._unpack, data.read(1)) -class Int16(AbstractType): +class Int16(AbstractType[int]): _pack = struct.Struct(">h").pack _unpack = struct.Struct(">h").unpack @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> int: return _unpack(cls._unpack, data.read(2)) -class Int32(AbstractType): +class Int32(AbstractType[int]): _pack = struct.Struct(">i").pack _unpack = struct.Struct(">i").unpack @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> int: return _unpack(cls._unpack, data.read(4)) -class UInt32(AbstractType): +class UInt32(AbstractType[int]): _pack = struct.Struct(">I").pack _unpack = struct.Struct(">I").unpack @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> int: return _unpack(cls._unpack, data.read(4)) -class Int64(AbstractType): +class Int64(AbstractType[int]): _pack = struct.Struct(">q").pack _unpack = struct.Struct(">q").unpack @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> int: return _unpack(cls._unpack, data.read(8)) -class Float64(AbstractType): +class Float64(AbstractType[float]): _pack = struct.Struct(">d").pack _unpack = struct.Struct(">d").unpack @classmethod - def encode(cls, value): + def encode(cls, value: float) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> float: return _unpack(cls._unpack, data.read(8)) -class String(AbstractType): - def __init__(self, encoding="utf-8"): +class String: + def __init__(self, encoding: str = "utf-8"): self.encoding = encoding - def encode(self, value): + def encode(self, value: Optional[str]) -> bytes: if value is None: return Int16.encode(-1) - value = str(value).encode(self.encoding) - return Int16.encode(len(value)) + value + encoded_value = str(value).encode(self.encoding) + return Int16.encode(len(encoded_value)) + encoded_value - def decode(self, data): + def decode(self, data: BytesIO) -> Optional[str]: length = Int16.decode(data) if length < 0: return None @@ -123,17 +144,21 @@ def decode(self, data): raise ValueError("Buffer underrun decoding string") return value.decode(self.encoding) + @classmethod + def repr(cls, value: str) -> str: + return repr(value) -class Bytes(AbstractType): + +class Bytes(AbstractType[Optional[bytes]]): @classmethod - def encode(cls, value): + def encode(cls, value: Optional[bytes]) -> bytes: if value is None: return Int32.encode(-1) else: return Int32.encode(len(value)) + value @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> Optional[bytes]: length = Int32.decode(data) if length < 0: return None @@ -143,45 +168,50 @@ def decode(cls, data): return value @classmethod - def repr(cls, value): + def repr(cls, value: Optional[bytes]) -> str: return repr( value[:100] + b"..." if value is not None and len(value) > 100 else value ) -class Boolean(AbstractType): +class Boolean(AbstractType[bool]): _pack = struct.Struct(">?").pack _unpack = struct.Struct(">?").unpack @classmethod - def encode(cls, value): + def encode(cls, value: bool) -> bytes: return _pack(cls._pack, value) @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> bool: return _unpack(cls._unpack, data.read(1)) -class Schema(AbstractType): - def __init__(self, *fields): +class Schema: + names: Tuple[str, ...] + fields: Tuple[ValueT, ...] + + def __init__(self, *fields: Tuple[str, ValueT]): if fields: self.names, self.fields = zip(*fields) else: self.names, self.fields = (), () - def encode(self, item): + def encode(self, item: Sequence[Any]) -> bytes: if len(item) != len(self.fields): raise ValueError("Item field count does not match Schema") return b"".join(field.encode(item[i]) for i, field in enumerate(self.fields)) - def decode(self, data): + def decode( + self, data: BytesIO + ) -> Tuple[Union[Any, str, None, List[Union[Any, Tuple[Any, ...]]]], ...]: return tuple(field.decode(data) for field in self.fields) - def __len__(self): + def __len__(self) -> int: return len(self.fields) - def repr(self, value): - key_vals = [] + def repr(self, value: Any) -> str: + key_vals: List[str] = [] try: for i in range(len(self)): try: @@ -194,19 +224,35 @@ def repr(self, value): return repr(value) -class Array(AbstractType): - def __init__(self, *array_of): - if len(array_of) > 1: - self.array_of = Schema(*array_of) - elif len(array_of) == 1 and ( - isinstance(array_of[0], AbstractType) - or issubclass(array_of[0], AbstractType) - ): - self.array_of = array_of[0] - else: - raise ValueError("Array instantiated with no array_of type") +class Array: + array_of: ValueT + + @overload + def __init__(self, array_of_0: ValueT): ... + + @overload + def __init__( + self, array_of_0: Tuple[str, ValueT], *array_of: Tuple[str, ValueT] + ): ... - def encode(self, items): + def __init__( + self, + array_of_0: Union[ValueT, Tuple[str, ValueT]], + *array_of: Tuple[str, ValueT], + ) -> None: + if array_of: + array_of_0 = cast(Tuple[str, ValueT], array_of_0) + self.array_of = Schema(array_of_0, *array_of) + else: + array_of_0 = cast(ValueT, array_of_0) + if isinstance(array_of_0, (String, Array, Schema)) or issubclass( + array_of_0, AbstractType + ): + self.array_of = array_of_0 + else: + raise ValueError("Array instantiated with no array_of type") + + def encode(self, items: Optional[Sequence[Any]]) -> bytes: if items is None: return Int32.encode(-1) encoded_items = (self.array_of.encode(item) for item in items) @@ -214,22 +260,23 @@ def encode(self, items): (Int32.encode(len(items)), *encoded_items), ) - def decode(self, data): + def decode(self, data: BytesIO) -> Optional[List[Union[Any, Tuple[Any, ...]]]]: length = Int32.decode(data) if length == -1: return None return [self.array_of.decode(data) for _ in range(length)] - def repr(self, list_of_items): + def repr(self, list_of_items: Optional[Sequence[Any]]) -> str: if list_of_items is None: return "NULL" return "[" + ", ".join(self.array_of.repr(item) for item in list_of_items) + "]" -class UnsignedVarInt32(AbstractType): +class UnsignedVarInt32(AbstractType[int]): @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> int: value, i = 0, 0 + b: int while True: (b,) = struct.unpack("B", data.read(1)) if not (b & 0x80): @@ -242,7 +289,7 @@ def decode(cls, data): return value @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: value &= 0xFFFFFFFF ret = b"" while (value & 0xFFFFFF80) != 0: @@ -253,25 +300,26 @@ def encode(cls, value): return ret -class VarInt32(AbstractType): +class VarInt32(AbstractType[int]): @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> int: value = UnsignedVarInt32.decode(data) return (value >> 1) ^ -(value & 1) @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: # bring it in line with the java binary repr value &= 0xFFFFFFFF return UnsignedVarInt32.encode((value << 1) ^ (value >> 31)) -class VarInt64(AbstractType): +class VarInt64(AbstractType[int]): @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> int: value, i = 0, 0 + b: int while True: - b = data.read(1) + (b,) = struct.unpack("B", data.read(1)) if not (b & 0x80): break value |= (b & 0x7F) << i @@ -282,7 +330,7 @@ def decode(cls, data): return (value >> 1) ^ -(value & 1) @classmethod - def encode(cls, value): + def encode(cls, value: int) -> bytes: # bring it in line with the java binary repr value &= 0xFFFFFFFFFFFFFFFF v = (value << 1) ^ (value >> 63) @@ -296,7 +344,7 @@ def encode(cls, value): class CompactString(String): - def decode(self, data): + def decode(self, data: BytesIO) -> Optional[str]: length = UnsignedVarInt32.decode(data) - 1 if length < 0: return None @@ -305,18 +353,18 @@ def decode(self, data): raise ValueError("Buffer underrun decoding string") return value.decode(self.encoding) - def encode(self, value): + def encode(self, value: Optional[str]) -> bytes: if value is None: return UnsignedVarInt32.encode(0) - value = str(value).encode(self.encoding) - return UnsignedVarInt32.encode(len(value) + 1) + value + encoded_value = str(value).encode(self.encoding) + return UnsignedVarInt32.encode(len(encoded_value) + 1) + encoded_value -class TaggedFields(AbstractType): +class TaggedFields(AbstractType[Dict[int, bytes]]): @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> Dict[int, bytes]: num_fields = UnsignedVarInt32.decode(data) - ret = {} + ret: Dict[int, bytes] = {} if not num_fields: return ret prev_tag = -1 @@ -331,20 +379,20 @@ def decode(cls, data): return ret @classmethod - def encode(cls, value): + def encode(cls, value: Dict[int, bytes]) -> bytes: ret = UnsignedVarInt32.encode(len(value)) for k, v in value.items(): # do we allow for other data types ?? It could get complicated really fast - assert isinstance(v, bytes), f"Value {v} is not a byte array" + assert isinstance(v, bytes), f"Value {v!r} is not a byte array" assert isinstance(k, int) and k > 0, f"Key {k} is not a positive integer" ret += UnsignedVarInt32.encode(k) ret += v return ret -class CompactBytes(AbstractType): +class CompactBytes(AbstractType[Optional[bytes]]): @classmethod - def decode(cls, data): + def decode(cls, data: BytesIO) -> Optional[bytes]: length = UnsignedVarInt32.decode(data) - 1 if length < 0: return None @@ -354,7 +402,7 @@ def decode(cls, data): return value @classmethod - def encode(cls, value): + def encode(cls, value: Optional[bytes]) -> bytes: if value is None: return UnsignedVarInt32.encode(0) else: @@ -362,7 +410,7 @@ def encode(cls, value): class CompactArray(Array): - def encode(self, items): + def encode(self, items: Optional[Sequence[Any]]) -> bytes: if items is None: return UnsignedVarInt32.encode(0) encoded_items = (self.array_of.encode(item) for item in items) @@ -370,7 +418,7 @@ def encode(self, items): (UnsignedVarInt32.encode(len(items) + 1), *encoded_items), ) - def decode(self, data): + def decode(self, data: BytesIO) -> Optional[List[Union[Any, Tuple[Any, ...]]]]: length = UnsignedVarInt32.decode(data) - 1 if length == -1: return None diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 56680f07..1d81aea5 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,4 +1,3 @@ -import abc import io import struct @@ -24,7 +23,7 @@ def test_create_message(): payload = b"test" key = b"key" - msg = Message(payload, key=key) + msg = Message(value=payload, key=key, magic=0, attributes=0, crc=0) assert msg.magic == 0 assert msg.attributes == 0 assert msg.key == key @@ -32,7 +31,7 @@ def test_create_message(): def test_encode_message_v0(): - message = Message(b"test", key=b"key") + message = Message(value=b"test", key=b"key", magic=0, attributes=0, crc=0) encoded = message.encode() expect = b"".join( [ @@ -48,7 +47,9 @@ def test_encode_message_v0(): def test_encode_message_v1(): - message = Message(b"test", key=b"key", magic=1, timestamp=1234) + message = Message( + value=b"test", key=b"key", magic=1, attributes=0, crc=0, timestamp=1234 + ) encoded = message.encode() expect = b"".join( [ @@ -76,7 +77,7 @@ def test_decode_message(): ] ) decoded_message = Message.decode(encoded) - msg = Message(b"test", key=b"key") + msg = Message(value=b"test", key=b"key", magic=0, attributes=0, crc=0) msg.encode() # crc is recalculated during encoding assert decoded_message == msg @@ -110,7 +111,10 @@ def test_decode_message_validate_crc(): def test_encode_message_set(): - messages = [Message(b"v1", key=b"k1"), Message(b"v2", key=b"k2")] + messages = [ + Message(value=b"v1", key=b"k1", magic=0, attributes=0, crc=0), + Message(value=b"v2", key=b"k2", magic=0, attributes=0, crc=0), + ] encoded = MessageSet.encode([(0, msg.encode()) for msg in messages]) expect = b"".join( [ @@ -166,12 +170,12 @@ def test_decode_message_set(): returned_offset2, message2_size, decoded_message2 = msg2 assert returned_offset1 == 0 - message1 = Message(b"v1", key=b"k1") + message1 = Message(value=b"v1", key=b"k1", magic=0, attributes=0, crc=0) message1.encode() assert decoded_message1 == message1 assert returned_offset2 == 1 - message2 = Message(b"v2", key=b"k2") + message2 = Message(value=b"v2", key=b"k2", magic=0, attributes=0, crc=0) message2.encode() assert decoded_message2 == message2 @@ -222,7 +226,7 @@ def test_decode_message_set_partial(): returned_offset2, message2_size, decoded_message2 = msg2 assert returned_offset1 == 0 - message1 = Message(b"v1", key=b"k1") + message1 = Message(value=b"v1", key=b"k1", magic=0, attributes=0, crc=0) message1.encode() assert decoded_message1 == message1 @@ -353,7 +357,10 @@ def test_compact_data_structs(): attr_names = [ - n for n in dir(Request) if isinstance(getattr(Request, n), abc.abstractproperty) + n + for n in dir(Request) + if isinstance(getattr(Request, n), property) + and getattr(Request, n).__isabstractmethod__ is True ] @@ -364,7 +371,10 @@ def test_request_type_conformance(klass, attr_name): attr_names = [ - n for n in dir(Response) if isinstance(getattr(Response, n), abc.abstractproperty) + n + for n in dir(Response) + if isinstance(getattr(Response, n), property) + and getattr(Response, n).__isabstractmethod__ is True ]