From ee20b3781a478c96da9e16c843081420dc28a483 Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Sat, 21 Oct 2023 20:54:34 +0300 Subject: [PATCH 01/20] Speed up local testing: don't regenerate certs --- docker/Makefile | 4 ++-- tests/conftest.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/docker/Makefile b/docker/Makefile index 4163f79b..d6b77a9f 100644 --- a/docker/Makefile +++ b/docker/Makefile @@ -1,5 +1,5 @@ -SCALA_VERSION?=2.12 -KAFKA_VERSION?=1.1.1 +SCALA_VERSION?=2.13 +KAFKA_VERSION?=2.8.1 IMAGE_NAME?=aiolibs/kafka IMAGE_TAG=$(IMAGE_NAME):$(SCALA_VERSION)_$(KAFKA_VERSION) diff --git a/tests/conftest.py b/tests/conftest.py index 7a1338ab..1cd91b22 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,6 @@ import uuid import sys import pathlib -import shutil from aiokafka.record.legacy_records import ( LegacyRecordBatchBuilder, _LegacyRecordBatchBuilderPy) @@ -117,8 +116,10 @@ def kafka_image(): @pytest.fixture(scope='session') def ssl_folder(docker_ip_address, docker, kafka_image): ssl_dir = pathlib.Path('tests/ssl_cert') - if ssl_dir.exists(): - shutil.rmtree(str(ssl_dir)) + if ssl_dir.is_dir(): + # Skip generating certificates when they already exist. Remove + # directory to re-generate them. + return ssl_dir ssl_dir.mkdir() From 86562e779cf8e05733131645051801d309157c3d Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Sat, 21 Oct 2023 21:48:13 +0300 Subject: [PATCH 02/20] Squashed 'kafka/' content from commit c89bd516 git-subtree-dir: kafka git-subtree-split: c89bd5169277e2e7a6f29674d54c085a8cdbbfe3 --- __init__.py | 34 + admin/__init__.py | 14 + admin/acl_resource.py | 244 +++ admin/client.py | 1347 +++++++++++++++ admin/config_resource.py | 36 + admin/new_partitions.py | 19 + admin/new_topic.py | 34 + client_async.py | 1077 ++++++++++++ cluster.py | 397 +++++ codec.py | 326 ++++ conn.py | 1534 +++++++++++++++++ consumer/__init__.py | 7 + consumer/fetcher.py | 1016 +++++++++++ consumer/group.py | 1225 +++++++++++++ consumer/subscription_state.py | 501 ++++++ coordinator/__init__.py | 0 coordinator/assignors/__init__.py | 0 coordinator/assignors/abstract.py | 56 + coordinator/assignors/range.py | 77 + coordinator/assignors/roundrobin.py | 96 ++ coordinator/assignors/sticky/__init__.py | 0 .../assignors/sticky/partition_movements.py | 149 ++ coordinator/assignors/sticky/sorted_set.py | 63 + .../assignors/sticky/sticky_assignor.py | 685 ++++++++ coordinator/base.py | 1023 +++++++++++ coordinator/consumer.py | 833 +++++++++ coordinator/heartbeat.py | 68 + coordinator/protocol.py | 33 + errors.py | 538 ++++++ future.py | 83 + metrics/__init__.py | 15 + metrics/compound_stat.py | 34 + metrics/dict_reporter.py | 83 + metrics/kafka_metric.py | 36 + metrics/measurable.py | 29 + metrics/measurable_stat.py | 16 + metrics/metric_config.py | 33 + metrics/metric_name.py | 106 ++ metrics/metrics.py | 261 +++ metrics/metrics_reporter.py | 57 + metrics/quota.py | 42 + metrics/stat.py | 23 + metrics/stats/__init__.py | 17 + metrics/stats/avg.py | 24 + metrics/stats/count.py | 17 + metrics/stats/histogram.py | 95 + metrics/stats/max_stat.py | 17 + metrics/stats/min_stat.py | 19 + metrics/stats/percentile.py | 15 + metrics/stats/percentiles.py | 74 + metrics/stats/rate.py | 117 ++ metrics/stats/sampled_stat.py | 101 ++ metrics/stats/sensor.py | 134 ++ metrics/stats/total.py | 15 + oauth/__init__.py | 3 + oauth/abstract.py | 42 + partitioner/__init__.py | 8 + partitioner/default.py | 102 ++ producer/__init__.py | 7 + producer/buffer.py | 115 ++ 
producer/future.py | 71 + producer/kafka.py | 752 ++++++++ producer/record_accumulator.py | 590 +++++++ producer/sender.py | 517 ++++++ protocol/__init__.py | 49 + protocol/abstract.py | 19 + protocol/admin.py | 1054 +++++++++++ protocol/api.py | 138 ++ protocol/commit.py | 255 +++ protocol/fetch.py | 386 +++++ protocol/frame.py | 30 + protocol/group.py | 230 +++ protocol/message.py | 216 +++ protocol/metadata.py | 200 +++ protocol/offset.py | 194 +++ protocol/parser.py | 176 ++ protocol/pickle.py | 35 + protocol/produce.py | 232 +++ protocol/struct.py | 72 + protocol/types.py | 365 ++++ record/README | 8 + record/__init__.py | 3 + record/_crc32c.py | 145 ++ record/abc.py | 124 ++ record/default_records.py | 630 +++++++ record/legacy_records.py | 548 ++++++ record/memory_records.py | 187 ++ record/util.py | 135 ++ scram.py | 81 + serializer/__init__.py | 3 + serializer/abstract.py | 31 + structs.py | 87 + util.py | 66 + vendor/__init__.py | 0 vendor/enum34.py | 841 +++++++++ vendor/selectors34.py | 637 +++++++ vendor/six.py | 897 ++++++++++ vendor/socketpair.py | 58 + version.py | 1 + 99 files changed, 23235 insertions(+) create mode 100644 __init__.py create mode 100644 admin/__init__.py create mode 100644 admin/acl_resource.py create mode 100644 admin/client.py create mode 100644 admin/config_resource.py create mode 100644 admin/new_partitions.py create mode 100644 admin/new_topic.py create mode 100644 client_async.py create mode 100644 cluster.py create mode 100644 codec.py create mode 100644 conn.py create mode 100644 consumer/__init__.py create mode 100644 consumer/fetcher.py create mode 100644 consumer/group.py create mode 100644 consumer/subscription_state.py create mode 100644 coordinator/__init__.py create mode 100644 coordinator/assignors/__init__.py create mode 100644 coordinator/assignors/abstract.py create mode 100644 coordinator/assignors/range.py create mode 100644 coordinator/assignors/roundrobin.py create mode 100644 coordinator/assignors/sticky/__init__.py create mode 100644 coordinator/assignors/sticky/partition_movements.py create mode 100644 coordinator/assignors/sticky/sorted_set.py create mode 100644 coordinator/assignors/sticky/sticky_assignor.py create mode 100644 coordinator/base.py create mode 100644 coordinator/consumer.py create mode 100644 coordinator/heartbeat.py create mode 100644 coordinator/protocol.py create mode 100644 errors.py create mode 100644 future.py create mode 100644 metrics/__init__.py create mode 100644 metrics/compound_stat.py create mode 100644 metrics/dict_reporter.py create mode 100644 metrics/kafka_metric.py create mode 100644 metrics/measurable.py create mode 100644 metrics/measurable_stat.py create mode 100644 metrics/metric_config.py create mode 100644 metrics/metric_name.py create mode 100644 metrics/metrics.py create mode 100644 metrics/metrics_reporter.py create mode 100644 metrics/quota.py create mode 100644 metrics/stat.py create mode 100644 metrics/stats/__init__.py create mode 100644 metrics/stats/avg.py create mode 100644 metrics/stats/count.py create mode 100644 metrics/stats/histogram.py create mode 100644 metrics/stats/max_stat.py create mode 100644 metrics/stats/min_stat.py create mode 100644 metrics/stats/percentile.py create mode 100644 metrics/stats/percentiles.py create mode 100644 metrics/stats/rate.py create mode 100644 metrics/stats/sampled_stat.py create mode 100644 metrics/stats/sensor.py create mode 100644 metrics/stats/total.py create mode 100644 oauth/__init__.py create mode 100644 oauth/abstract.py create mode 
100644 partitioner/__init__.py create mode 100644 partitioner/default.py create mode 100644 producer/__init__.py create mode 100644 producer/buffer.py create mode 100644 producer/future.py create mode 100644 producer/kafka.py create mode 100644 producer/record_accumulator.py create mode 100644 producer/sender.py create mode 100644 protocol/__init__.py create mode 100644 protocol/abstract.py create mode 100644 protocol/admin.py create mode 100644 protocol/api.py create mode 100644 protocol/commit.py create mode 100644 protocol/fetch.py create mode 100644 protocol/frame.py create mode 100644 protocol/group.py create mode 100644 protocol/message.py create mode 100644 protocol/metadata.py create mode 100644 protocol/offset.py create mode 100644 protocol/parser.py create mode 100644 protocol/pickle.py create mode 100644 protocol/produce.py create mode 100644 protocol/struct.py create mode 100644 protocol/types.py create mode 100644 record/README create mode 100644 record/__init__.py create mode 100644 record/_crc32c.py create mode 100644 record/abc.py create mode 100644 record/default_records.py create mode 100644 record/legacy_records.py create mode 100644 record/memory_records.py create mode 100644 record/util.py create mode 100644 scram.py create mode 100644 serializer/__init__.py create mode 100644 serializer/abstract.py create mode 100644 structs.py create mode 100644 util.py create mode 100644 vendor/__init__.py create mode 100644 vendor/enum34.py create mode 100644 vendor/selectors34.py create mode 100644 vendor/six.py create mode 100644 vendor/socketpair.py create mode 100644 version.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 00000000..d5e30aff --- /dev/null +++ b/__init__.py @@ -0,0 +1,34 @@ +from __future__ import absolute_import + +__title__ = 'kafka' +from kafka.version import __version__ +__author__ = 'Dana Powers' +__license__ = 'Apache License 2.0' +__copyright__ = 'Copyright 2016 Dana Powers, David Arthur, and Contributors' + +# Set default logging handler to avoid "No handler found" warnings. 
+import logging +try: # Python 2.7+ + from logging import NullHandler +except ImportError: + class NullHandler(logging.Handler): + def emit(self, record): + pass + +logging.getLogger(__name__).addHandler(NullHandler()) + + +from kafka.admin import KafkaAdminClient +from kafka.client_async import KafkaClient +from kafka.consumer import KafkaConsumer +from kafka.consumer.subscription_state import ConsumerRebalanceListener +from kafka.producer import KafkaProducer +from kafka.conn import BrokerConnection +from kafka.serializer import Serializer, Deserializer +from kafka.structs import TopicPartition, OffsetAndMetadata + + +__all__ = [ + 'BrokerConnection', 'ConsumerRebalanceListener', 'KafkaAdminClient', + 'KafkaClient', 'KafkaConsumer', 'KafkaProducer', +] diff --git a/admin/__init__.py b/admin/__init__.py new file mode 100644 index 00000000..c240fc6d --- /dev/null +++ b/admin/__init__.py @@ -0,0 +1,14 @@ +from __future__ import absolute_import + +from kafka.admin.config_resource import ConfigResource, ConfigResourceType +from kafka.admin.client import KafkaAdminClient +from kafka.admin.acl_resource import (ACL, ACLFilter, ResourcePattern, ResourcePatternFilter, ACLOperation, + ResourceType, ACLPermissionType, ACLResourcePatternType) +from kafka.admin.new_topic import NewTopic +from kafka.admin.new_partitions import NewPartitions + +__all__ = [ + 'ConfigResource', 'ConfigResourceType', 'KafkaAdminClient', 'NewTopic', 'NewPartitions', 'ACL', 'ACLFilter', + 'ResourcePattern', 'ResourcePatternFilter', 'ACLOperation', 'ResourceType', 'ACLPermissionType', + 'ACLResourcePatternType' +] diff --git a/admin/acl_resource.py b/admin/acl_resource.py new file mode 100644 index 00000000..fd997a10 --- /dev/null +++ b/admin/acl_resource.py @@ -0,0 +1,244 @@ +from __future__ import absolute_import +from kafka.errors import IllegalArgumentError + +# enum in stdlib as of py3.4 +try: + from enum import IntEnum # pylint: disable=import-error +except ImportError: + # vendored backport module + from kafka.vendor.enum34 import IntEnum + + +class ResourceType(IntEnum): + """Type of kafka resource to set ACL for + + The ANY value is only valid in a filter context + """ + + UNKNOWN = 0, + ANY = 1, + CLUSTER = 4, + DELEGATION_TOKEN = 6, + GROUP = 3, + TOPIC = 2, + TRANSACTIONAL_ID = 5 + + +class ACLOperation(IntEnum): + """Type of operation + + The ANY value is only valid in a filter context + """ + + ANY = 1, + ALL = 2, + READ = 3, + WRITE = 4, + CREATE = 5, + DELETE = 6, + ALTER = 7, + DESCRIBE = 8, + CLUSTER_ACTION = 9, + DESCRIBE_CONFIGS = 10, + ALTER_CONFIGS = 11, + IDEMPOTENT_WRITE = 12 + + +class ACLPermissionType(IntEnum): + """An enumerated type of permissions + + The ANY value is only valid in a filter context + """ + + ANY = 1, + DENY = 2, + ALLOW = 3 + + +class ACLResourcePatternType(IntEnum): + """An enumerated type of resource patterns + + More details on the pattern types and how they work + can be found in KIP-290 (Support for prefixed ACLs) + https://cwiki.apache.org/confluence/display/KAFKA/KIP-290%3A+Support+for+Prefixed+ACLs + """ + + ANY = 1, + MATCH = 2, + LITERAL = 3, + PREFIXED = 4 + + +class ACLFilter(object): + """Represents a filter to use with describing and deleting ACLs + + The difference between this class and the ACL class is mainly that + we allow using ANY with the operation, permission, and resource type objects + to fetch ALCs matching any of the properties. 
+ + To make a filter matching any principal, set principal to None + """ + + def __init__( + self, + principal, + host, + operation, + permission_type, + resource_pattern + ): + self.principal = principal + self.host = host + self.operation = operation + self.permission_type = permission_type + self.resource_pattern = resource_pattern + + self.validate() + + def validate(self): + if not isinstance(self.operation, ACLOperation): + raise IllegalArgumentError("operation must be an ACLOperation object, and cannot be ANY") + if not isinstance(self.permission_type, ACLPermissionType): + raise IllegalArgumentError("permission_type must be an ACLPermissionType object, and cannot be ANY") + if not isinstance(self.resource_pattern, ResourcePatternFilter): + raise IllegalArgumentError("resource_pattern must be a ResourcePatternFilter object") + + def __repr__(self): + return "".format( + principal=self.principal, + host=self.host, + operation=self.operation.name, + type=self.permission_type.name, + resource=self.resource_pattern + ) + + def __eq__(self, other): + return all(( + self.principal == other.principal, + self.host == other.host, + self.operation == other.operation, + self.permission_type == other.permission_type, + self.resource_pattern == other.resource_pattern + )) + + def __hash__(self): + return hash(( + self.principal, + self.host, + self.operation, + self.permission_type, + self.resource_pattern, + )) + + +class ACL(ACLFilter): + """Represents a concrete ACL for a specific ResourcePattern + + In kafka an ACL is a 4-tuple of (principal, host, operation, permission_type) + that limits who can do what on a specific resource (or since KIP-290 a resource pattern) + + Terminology: + Principal -> This is the identifier for the user. Depending on the authorization method used (SSL, SASL etc) + the principal will look different. See http://kafka.apache.org/documentation/#security_authz for details. + The principal must be on the format "User:" or kafka will treat it as invalid. It's possible to use + other principal types than "User" if using a custom authorizer for the cluster. + Host -> This must currently be an IP address. It cannot be a range, and it cannot be a domain name. + It can be set to "*", which is special cased in kafka to mean "any host" + Operation -> Which client operation this ACL refers to. Has different meaning depending + on the resource type the ACL refers to. See https://docs.confluent.io/current/kafka/authorization.html#acl-format + for a list of which combinations of resource/operation that unlocks which kafka APIs + Permission Type: Whether this ACL is allowing or denying access + Resource Pattern -> This is a representation of the resource or resource pattern that the ACL + refers to. See the ResourcePattern class for details. 
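# Example (illustrative sketch only, not part of the vendored patch): a minimal
# combination of the ACL helpers documented above. The principal "User:alice" and
# the topic name "payments" are made-up placeholder values; the classes and their
# constructor arguments are the ones defined in admin/acl_resource.py above.
from kafka.admin import (ACL, ACLOperation, ACLPermissionType,
                         ResourcePattern, ResourceType)

# Allow the principal to READ a single, literally-named topic from any host.
example_acl = ACL(
    principal="User:alice",
    host="*",
    operation=ACLOperation.READ,
    permission_type=ACLPermissionType.ALLOW,
    resource_pattern=ResourcePattern(ResourceType.TOPIC, "payments"),
)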
+ + """ + + def __init__( + self, + principal, + host, + operation, + permission_type, + resource_pattern + ): + super(ACL, self).__init__(principal, host, operation, permission_type, resource_pattern) + self.validate() + + def validate(self): + if self.operation == ACLOperation.ANY: + raise IllegalArgumentError("operation cannot be ANY") + if self.permission_type == ACLPermissionType.ANY: + raise IllegalArgumentError("permission_type cannot be ANY") + if not isinstance(self.resource_pattern, ResourcePattern): + raise IllegalArgumentError("resource_pattern must be a ResourcePattern object") + + +class ResourcePatternFilter(object): + def __init__( + self, + resource_type, + resource_name, + pattern_type + ): + self.resource_type = resource_type + self.resource_name = resource_name + self.pattern_type = pattern_type + + self.validate() + + def validate(self): + if not isinstance(self.resource_type, ResourceType): + raise IllegalArgumentError("resource_type must be a ResourceType object") + if not isinstance(self.pattern_type, ACLResourcePatternType): + raise IllegalArgumentError("pattern_type must be an ACLResourcePatternType object") + + def __repr__(self): + return "".format( + self.resource_type.name, + self.resource_name, + self.pattern_type.name + ) + + def __eq__(self, other): + return all(( + self.resource_type == other.resource_type, + self.resource_name == other.resource_name, + self.pattern_type == other.pattern_type, + )) + + def __hash__(self): + return hash(( + self.resource_type, + self.resource_name, + self.pattern_type + )) + + +class ResourcePattern(ResourcePatternFilter): + """A resource pattern to apply the ACL to + + Resource patterns are used to be able to specify which resources an ACL + describes in a more flexible way than just pointing to a literal topic name for example. + Since KIP-290 (kafka 2.0) it's possible to set an ACL for a prefixed resource name, which + can cut down considerably on the number of ACLs needed when the number of topics and + consumer groups start to grow. + The default pattern_type is LITERAL, and it describes a specific resource. This is also how + ACLs worked before the introduction of prefixed ACLs + """ + + def __init__( + self, + resource_type, + resource_name, + pattern_type=ACLResourcePatternType.LITERAL + ): + super(ResourcePattern, self).__init__(resource_type, resource_name, pattern_type) + self.validate() + + def validate(self): + if self.resource_type == ResourceType.ANY: + raise IllegalArgumentError("resource_type cannot be ANY") + if self.pattern_type in [ACLResourcePatternType.ANY, ACLResourcePatternType.MATCH]: + raise IllegalArgumentError( + "pattern_type cannot be {} on a concrete ResourcePattern".format(self.pattern_type.name) + ) diff --git a/admin/client.py b/admin/client.py new file mode 100644 index 00000000..8eb7504a --- /dev/null +++ b/admin/client.py @@ -0,0 +1,1347 @@ +from __future__ import absolute_import + +from collections import defaultdict +import copy +import logging +import socket + +from . 
import ConfigResourceType +from kafka.vendor import six + +from kafka.admin.acl_resource import ACLOperation, ACLPermissionType, ACLFilter, ACL, ResourcePattern, ResourceType, \ + ACLResourcePatternType +from kafka.client_async import KafkaClient, selectors +from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment, ConsumerProtocol +import kafka.errors as Errors +from kafka.errors import ( + IncompatibleBrokerVersion, KafkaConfigurationError, NotControllerError, + UnrecognizedBrokerVersion, IllegalArgumentError) +from kafka.metrics import MetricConfig, Metrics +from kafka.protocol.admin import ( + CreateTopicsRequest, DeleteTopicsRequest, DescribeConfigsRequest, AlterConfigsRequest, CreatePartitionsRequest, + ListGroupsRequest, DescribeGroupsRequest, DescribeAclsRequest, CreateAclsRequest, DeleteAclsRequest, + DeleteGroupsRequest +) +from kafka.protocol.commit import GroupCoordinatorRequest, OffsetFetchRequest +from kafka.protocol.metadata import MetadataRequest +from kafka.protocol.types import Array +from kafka.structs import TopicPartition, OffsetAndMetadata, MemberInformation, GroupInformation +from kafka.version import __version__ + + +log = logging.getLogger(__name__) + + +class KafkaAdminClient(object): + """A class for administering the Kafka cluster. + + Warning: + This is an unstable interface that was recently added and is subject to + change without warning. In particular, many methods currently return + raw protocol tuples. In future releases, we plan to make these into + nicer, more pythonic objects. Unfortunately, this will likely break + those interfaces. + + The KafkaAdminClient class will negotiate for the latest version of each message + protocol format supported by both the kafka-python client library and the + Kafka broker. Usage of optional fields from protocol versions that are not + supported by the broker will result in IncompatibleBrokerVersion exceptions. + + Use of this class requires a minimum broker version >= 0.10.0.0. + + Keyword Arguments: + bootstrap_servers: 'host[:port]' string (or list of 'host[:port]' + strings) that the consumer should contact to bootstrap initial + cluster metadata. This does not have to be the full node list. + It just needs to have at least one broker that will respond to a + Metadata API Request. Default port is 9092. If no servers are + specified, will default to localhost:9092. + client_id (str): a name for this client. This string is passed in + each request to servers and can be used to identify specific + server-side log entries that correspond to this client. Also + submitted to GroupCoordinator for logging with respect to + consumer group administration. Default: 'kafka-python-{version}' + reconnect_backoff_ms (int): The amount of time in milliseconds to + wait before attempting to reconnect to a given host. + Default: 50. + reconnect_backoff_max_ms (int): The maximum amount of time in + milliseconds to backoff/wait when reconnecting to a broker that has + repeatedly failed to connect. If provided, the backoff per host + will increase exponentially for each consecutive connection + failure, up to this maximum. Once the maximum is reached, + reconnection attempts will continue periodically with this fixed + rate. To avoid connection storms, a randomization factor of 0.2 + will be applied to the backoff resulting in a random range between + 20% below and 20% above the computed value. Default: 1000. + request_timeout_ms (int): Client request timeout in milliseconds. 
+ Default: 30000. + connections_max_idle_ms: Close idle connections after the number of + milliseconds specified by this config. The broker closes idle + connections after connections.max.idle.ms, so this avoids hitting + unexpected socket disconnected errors on the client. + Default: 540000 + retry_backoff_ms (int): Milliseconds to backoff when retrying on + errors. Default: 100. + max_in_flight_requests_per_connection (int): Requests are pipelined + to kafka brokers up to this number of maximum requests per + broker connection. Default: 5. + receive_buffer_bytes (int): The size of the TCP receive buffer + (SO_RCVBUF) to use when reading data. Default: None (relies on + system defaults). Java client defaults to 32768. + send_buffer_bytes (int): The size of the TCP send buffer + (SO_SNDBUF) to use when sending data. Default: None (relies on + system defaults). Java client defaults to 131072. + socket_options (list): List of tuple-arguments to socket.setsockopt + to apply to broker connection sockets. Default: + [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)] + metadata_max_age_ms (int): The period of time in milliseconds after + which we force a refresh of metadata even if we haven't seen any + partition leadership changes to proactively discover any new + brokers or partitions. Default: 300000 + security_protocol (str): Protocol used to communicate with brokers. + Valid values are: PLAINTEXT, SSL, SASL_PLAINTEXT, SASL_SSL. + Default: PLAINTEXT. + ssl_context (ssl.SSLContext): Pre-configured SSLContext for wrapping + socket connections. If provided, all other ssl_* configurations + will be ignored. Default: None. + ssl_check_hostname (bool): Flag to configure whether SSL handshake + should verify that the certificate matches the broker's hostname. + Default: True. + ssl_cafile (str): Optional filename of CA file to use in certificate + verification. Default: None. + ssl_certfile (str): Optional filename of file in PEM format containing + the client certificate, as well as any CA certificates needed to + establish the certificate's authenticity. Default: None. + ssl_keyfile (str): Optional filename containing the client private key. + Default: None. + ssl_password (str): Optional password to be used when loading the + certificate chain. Default: None. + ssl_crlfile (str): Optional filename containing the CRL to check for + certificate expiration. By default, no CRL check is done. When + providing a file, only the leaf certificate will be checked against + this CRL. The CRL can only be checked with Python 3.4+ or 2.7.9+. + Default: None. + api_version (tuple): Specify which Kafka API version to use. If set + to None, KafkaClient will attempt to infer the broker version by + probing various APIs. Example: (0, 10, 2). Default: None + api_version_auto_timeout_ms (int): number of milliseconds to throw a + timeout exception from the constructor when checking the broker + api version. Only applies if api_version is None + selector (selectors.BaseSelector): Provide a specific selector + implementation to use for I/O multiplexing. + Default: selectors.DefaultSelector + metrics (kafka.metrics.Metrics): Optionally provide a metrics + instance for capturing network IO stats. Default: None. + metric_group_prefix (str): Prefix for metric names. Default: '' + sasl_mechanism (str): Authentication mechanism when security_protocol + is configured for SASL_PLAINTEXT or SASL_SSL. Valid values are: + PLAIN, GSSAPI, OAUTHBEARER, SCRAM-SHA-256, SCRAM-SHA-512. 
+ sasl_plain_username (str): username for sasl PLAIN and SCRAM authentication. + Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. + sasl_plain_password (str): password for sasl PLAIN and SCRAM authentication. + Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. + sasl_kerberos_service_name (str): Service name to include in GSSAPI + sasl mechanism handshake. Default: 'kafka' + sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI + sasl mechanism handshake. Default: one of bootstrap servers + sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider + instance. (See kafka.oauth.abstract). Default: None + kafka_client (callable): Custom class / callable for creating KafkaClient instances + + """ + DEFAULT_CONFIG = { + # client configs + 'bootstrap_servers': 'localhost', + 'client_id': 'kafka-python-' + __version__, + 'request_timeout_ms': 30000, + 'connections_max_idle_ms': 9 * 60 * 1000, + 'reconnect_backoff_ms': 50, + 'reconnect_backoff_max_ms': 1000, + 'max_in_flight_requests_per_connection': 5, + 'receive_buffer_bytes': None, + 'send_buffer_bytes': None, + 'socket_options': [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)], + 'sock_chunk_bytes': 4096, # undocumented experimental option + 'sock_chunk_buffer_count': 1000, # undocumented experimental option + 'retry_backoff_ms': 100, + 'metadata_max_age_ms': 300000, + 'security_protocol': 'PLAINTEXT', + 'ssl_context': None, + 'ssl_check_hostname': True, + 'ssl_cafile': None, + 'ssl_certfile': None, + 'ssl_keyfile': None, + 'ssl_password': None, + 'ssl_crlfile': None, + 'api_version': None, + 'api_version_auto_timeout_ms': 2000, + 'selector': selectors.DefaultSelector, + 'sasl_mechanism': None, + 'sasl_plain_username': None, + 'sasl_plain_password': None, + 'sasl_kerberos_service_name': 'kafka', + 'sasl_kerberos_domain_name': None, + 'sasl_oauth_token_provider': None, + + # metrics configs + 'metric_reporters': [], + 'metrics_num_samples': 2, + 'metrics_sample_window_ms': 30000, + 'kafka_client': KafkaClient, + } + + def __init__(self, **configs): + log.debug("Starting KafkaAdminClient with configuration: %s", configs) + extra_configs = set(configs).difference(self.DEFAULT_CONFIG) + if extra_configs: + raise KafkaConfigurationError("Unrecognized configs: {}".format(extra_configs)) + + self.config = copy.copy(self.DEFAULT_CONFIG) + self.config.update(configs) + + # Configure metrics + metrics_tags = {'client-id': self.config['client_id']} + metric_config = MetricConfig(samples=self.config['metrics_num_samples'], + time_window_ms=self.config['metrics_sample_window_ms'], + tags=metrics_tags) + reporters = [reporter() for reporter in self.config['metric_reporters']] + self._metrics = Metrics(metric_config, reporters) + + self._client = self.config['kafka_client']( + metrics=self._metrics, + metric_group_prefix='admin', + **self.config + ) + self._client.check_version(timeout=(self.config['api_version_auto_timeout_ms'] / 1000)) + + # Get auto-discovered version from client if necessary + if self.config['api_version'] is None: + self.config['api_version'] = self._client.config['api_version'] + + self._closed = False + self._refresh_controller_id() + log.debug("KafkaAdminClient started.") + + def close(self): + """Close the KafkaAdminClient connection to the Kafka broker.""" + if not hasattr(self, '_closed') or self._closed: + log.info("KafkaAdminClient already closed.") + return + + self._metrics.close() + self._client.close() + self._closed = True + log.debug("KafkaAdminClient 
is now closed.") + + def _matching_api_version(self, operation): + """Find the latest version of the protocol operation supported by both + this library and the broker. + + This resolves to the lesser of either the latest api version this + library supports, or the max version supported by the broker. + + :param operation: A list of protocol operation versions from kafka.protocol. + :return: The max matching version number between client and broker. + """ + broker_api_versions = self._client.get_api_versions() + api_key = operation[0].API_KEY + if broker_api_versions is None or api_key not in broker_api_versions: + raise IncompatibleBrokerVersion( + "Kafka broker does not support the '{}' Kafka protocol." + .format(operation[0].__name__)) + min_version, max_version = broker_api_versions[api_key] + version = min(len(operation) - 1, max_version) + if version < min_version: + # max library version is less than min broker version. Currently, + # no Kafka versions specify a min msg version. Maybe in the future? + raise IncompatibleBrokerVersion( + "No version of the '{}' Kafka protocol is supported by both the client and broker." + .format(operation[0].__name__)) + return version + + def _validate_timeout(self, timeout_ms): + """Validate the timeout is set or use the configuration default. + + :param timeout_ms: The timeout provided by api call, in milliseconds. + :return: The timeout to use for the operation. + """ + return timeout_ms or self.config['request_timeout_ms'] + + def _refresh_controller_id(self): + """Determine the Kafka cluster controller.""" + version = self._matching_api_version(MetadataRequest) + if 1 <= version <= 6: + request = MetadataRequest[version]() + future = self._send_request_to_node(self._client.least_loaded_node(), request) + + self._wait_for_futures([future]) + + response = future.value + controller_id = response.controller_id + # verify the controller is new enough to support our requests + controller_version = self._client.check_version(controller_id, timeout=(self.config['api_version_auto_timeout_ms'] / 1000)) + if controller_version < (0, 10, 0): + raise IncompatibleBrokerVersion( + "The controller appears to be running Kafka {}. KafkaAdminClient requires brokers >= 0.10.0.0." + .format(controller_version)) + self._controller_id = controller_id + else: + raise UnrecognizedBrokerVersion( + "Kafka Admin interface cannot determine the controller using MetadataRequest_v{}." + .format(version)) + + def _find_coordinator_id_send_request(self, group_id): + """Send a FindCoordinatorRequest to a broker. + + :param group_id: The consumer group ID. This is typically the group + name as a string. + :return: A message future + """ + # TODO add support for dynamically picking version of + # GroupCoordinatorRequest which was renamed to FindCoordinatorRequest. + # When I experimented with this, the coordinator value returned in + # GroupCoordinatorResponse_v1 didn't match the value returned by + # GroupCoordinatorResponse_v0 and I couldn't figure out why. + version = 0 + # version = self._matching_api_version(GroupCoordinatorRequest) + if version <= 0: + request = GroupCoordinatorRequest[version](group_id) + else: + raise NotImplementedError( + "Support for GroupCoordinatorRequest_v{} has not yet been added to KafkaAdminClient." + .format(version)) + return self._send_request_to_node(self._client.least_loaded_node(), request) + + def _find_coordinator_id_process_response(self, response): + """Process a FindCoordinatorResponse. + + :param response: a FindCoordinatorResponse. 
+ :return: The node_id of the broker that is the coordinator. + """ + if response.API_VERSION <= 0: + error_type = Errors.for_code(response.error_code) + if error_type is not Errors.NoError: + # Note: When error_type.retriable, Java will retry... see + # KafkaAdminClient's handleFindCoordinatorError method + raise error_type( + "FindCoordinatorRequest failed with response '{}'." + .format(response)) + else: + raise NotImplementedError( + "Support for FindCoordinatorRequest_v{} has not yet been added to KafkaAdminClient." + .format(response.API_VERSION)) + return response.coordinator_id + + def _find_coordinator_ids(self, group_ids): + """Find the broker node_ids of the coordinators of the given groups. + + Sends a FindCoordinatorRequest message to the cluster for each group_id. + Will block until the FindCoordinatorResponse is received for all groups. + Any errors are immediately raised. + + :param group_ids: A list of consumer group IDs. This is typically the group + name as a string. + :return: A dict of {group_id: node_id} where node_id is the id of the + broker that is the coordinator for the corresponding group. + """ + groups_futures = { + group_id: self._find_coordinator_id_send_request(group_id) + for group_id in group_ids + } + self._wait_for_futures(groups_futures.values()) + groups_coordinators = { + group_id: self._find_coordinator_id_process_response(future.value) + for group_id, future in groups_futures.items() + } + return groups_coordinators + + def _send_request_to_node(self, node_id, request, wakeup=True): + """Send a Kafka protocol message to a specific broker. + + Returns a future that may be polled for status and results. + + :param node_id: The broker id to which to send the message. + :param request: The message to send. + :param wakeup: Optional flag to disable thread-wakeup. + :return: A future object that may be polled for status and results. + :exception: The exception if the message could not be sent. + """ + while not self._client.ready(node_id): + # poll until the connection to broker is ready, otherwise send() + # will fail with NodeNotReadyError + self._client.poll() + return self._client.send(node_id, request, wakeup) + + def _send_request_to_controller(self, request): + """Send a Kafka protocol message to the cluster controller. + + Will block until the message result is received. + + :param request: The message to send. + :return: The Kafka protocol response for the message. + """ + tries = 2 # in case our cached self._controller_id is outdated + while tries: + tries -= 1 + future = self._send_request_to_node(self._controller_id, request) + + self._wait_for_futures([future]) + + response = future.value + # In Java, the error field name is inconsistent: + # - CreateTopicsResponse / CreatePartitionsResponse uses topic_errors + # - DeleteTopicsResponse uses topic_error_codes + # So this is a little brittle in that it assumes all responses have + # one of these attributes and that they always unpack into + # (topic, error_code) tuples. + topic_error_tuples = (response.topic_errors if hasattr(response, 'topic_errors') + else response.topic_error_codes) + # Also small py2/py3 compatibility -- py3 can ignore extra values + # during unpack via: for x, y, *rest in list_of_values. py2 cannot. 
+ # So for now we have to map across the list and explicitly drop any + # extra values (usually the error_message) + for topic, error_code in map(lambda e: e[:2], topic_error_tuples): + error_type = Errors.for_code(error_code) + if tries and error_type is NotControllerError: + # No need to inspect the rest of the errors for + # non-retriable errors because NotControllerError should + # either be thrown for all errors or no errors. + self._refresh_controller_id() + break + elif error_type is not Errors.NoError: + raise error_type( + "Request '{}' failed with response '{}'." + .format(request, response)) + else: + return response + raise RuntimeError("This should never happen, please file a bug with full stacktrace if encountered") + + @staticmethod + def _convert_new_topic_request(new_topic): + return ( + new_topic.name, + new_topic.num_partitions, + new_topic.replication_factor, + [ + (partition_id, replicas) for partition_id, replicas in new_topic.replica_assignments.items() + ], + [ + (config_key, config_value) for config_key, config_value in new_topic.topic_configs.items() + ] + ) + + def create_topics(self, new_topics, timeout_ms=None, validate_only=False): + """Create new topics in the cluster. + + :param new_topics: A list of NewTopic objects. + :param timeout_ms: Milliseconds to wait for new topics to be created + before the broker returns. + :param validate_only: If True, don't actually create new topics. + Not supported by all versions. Default: False + :return: Appropriate version of CreateTopicResponse class. + """ + version = self._matching_api_version(CreateTopicsRequest) + timeout_ms = self._validate_timeout(timeout_ms) + if version == 0: + if validate_only: + raise IncompatibleBrokerVersion( + "validate_only requires CreateTopicsRequest >= v1, which is not supported by Kafka {}." + .format(self.config['api_version'])) + request = CreateTopicsRequest[version]( + create_topic_requests=[self._convert_new_topic_request(new_topic) for new_topic in new_topics], + timeout=timeout_ms + ) + elif version <= 3: + request = CreateTopicsRequest[version]( + create_topic_requests=[self._convert_new_topic_request(new_topic) for new_topic in new_topics], + timeout=timeout_ms, + validate_only=validate_only + ) + else: + raise NotImplementedError( + "Support for CreateTopics v{} has not yet been added to KafkaAdminClient." + .format(version)) + # TODO convert structs to a more pythonic interface + # TODO raise exceptions if errors + return self._send_request_to_controller(request) + + def delete_topics(self, topics, timeout_ms=None): + """Delete topics from the cluster. + + :param topics: A list of topic name strings. + :param timeout_ms: Milliseconds to wait for topics to be deleted + before the broker returns. + :return: Appropriate version of DeleteTopicsResponse class. + """ + version = self._matching_api_version(DeleteTopicsRequest) + timeout_ms = self._validate_timeout(timeout_ms) + if version <= 3: + request = DeleteTopicsRequest[version]( + topics=topics, + timeout=timeout_ms + ) + response = self._send_request_to_controller(request) + else: + raise NotImplementedError( + "Support for DeleteTopics v{} has not yet been added to KafkaAdminClient." 
+ .format(version)) + return response + + + def _get_cluster_metadata(self, topics=None, auto_topic_creation=False): + """ + topics == None means "get all topics" + """ + version = self._matching_api_version(MetadataRequest) + if version <= 3: + if auto_topic_creation: + raise IncompatibleBrokerVersion( + "auto_topic_creation requires MetadataRequest >= v4, which" + " is not supported by Kafka {}" + .format(self.config['api_version'])) + + request = MetadataRequest[version](topics=topics) + elif version <= 5: + request = MetadataRequest[version]( + topics=topics, + allow_auto_topic_creation=auto_topic_creation + ) + + future = self._send_request_to_node( + self._client.least_loaded_node(), + request + ) + self._wait_for_futures([future]) + return future.value + + def list_topics(self): + metadata = self._get_cluster_metadata(topics=None) + obj = metadata.to_object() + return [t['topic'] for t in obj['topics']] + + def describe_topics(self, topics=None): + metadata = self._get_cluster_metadata(topics=topics) + obj = metadata.to_object() + return obj['topics'] + + def describe_cluster(self): + metadata = self._get_cluster_metadata() + obj = metadata.to_object() + obj.pop('topics') # We have 'describe_topics' for this + return obj + + @staticmethod + def _convert_describe_acls_response_to_acls(describe_response): + version = describe_response.API_VERSION + + error = Errors.for_code(describe_response.error_code) + acl_list = [] + for resources in describe_response.resources: + if version == 0: + resource_type, resource_name, acls = resources + resource_pattern_type = ACLResourcePatternType.LITERAL.value + elif version <= 1: + resource_type, resource_name, resource_pattern_type, acls = resources + else: + raise NotImplementedError( + "Support for DescribeAcls Response v{} has not yet been added to KafkaAdmin." + .format(version) + ) + for acl in acls: + principal, host, operation, permission_type = acl + conv_acl = ACL( + principal=principal, + host=host, + operation=ACLOperation(operation), + permission_type=ACLPermissionType(permission_type), + resource_pattern=ResourcePattern( + ResourceType(resource_type), + resource_name, + ACLResourcePatternType(resource_pattern_type) + ) + ) + acl_list.append(conv_acl) + + return (acl_list, error,) + + def describe_acls(self, acl_filter): + """Describe a set of ACLs + + Used to return a set of ACLs matching the supplied ACLFilter. + The cluster must be configured with an authorizer for this to work, or + you will get a SecurityDisabledError + + :param acl_filter: an ACLFilter object + :return: tuple of a list of matching ACL objects and a KafkaError (NoError if successful) + """ + + version = self._matching_api_version(DescribeAclsRequest) + if version == 0: + request = DescribeAclsRequest[version]( + resource_type=acl_filter.resource_pattern.resource_type, + resource_name=acl_filter.resource_pattern.resource_name, + principal=acl_filter.principal, + host=acl_filter.host, + operation=acl_filter.operation, + permission_type=acl_filter.permission_type + ) + elif version <= 1: + request = DescribeAclsRequest[version]( + resource_type=acl_filter.resource_pattern.resource_type, + resource_name=acl_filter.resource_pattern.resource_name, + resource_pattern_type_filter=acl_filter.resource_pattern.pattern_type, + principal=acl_filter.principal, + host=acl_filter.host, + operation=acl_filter.operation, + permission_type=acl_filter.permission_type + + ) + else: + raise NotImplementedError( + "Support for DescribeAcls v{} has not yet been added to KafkaAdmin." 
+ .format(version) + ) + + future = self._send_request_to_node(self._client.least_loaded_node(), request) + self._wait_for_futures([future]) + response = future.value + + error_type = Errors.for_code(response.error_code) + if error_type is not Errors.NoError: + # optionally we could retry if error_type.retriable + raise error_type( + "Request '{}' failed with response '{}'." + .format(request, response)) + + return self._convert_describe_acls_response_to_acls(response) + + @staticmethod + def _convert_create_acls_resource_request_v0(acl): + + return ( + acl.resource_pattern.resource_type, + acl.resource_pattern.resource_name, + acl.principal, + acl.host, + acl.operation, + acl.permission_type + ) + + @staticmethod + def _convert_create_acls_resource_request_v1(acl): + + return ( + acl.resource_pattern.resource_type, + acl.resource_pattern.resource_name, + acl.resource_pattern.pattern_type, + acl.principal, + acl.host, + acl.operation, + acl.permission_type + ) + + @staticmethod + def _convert_create_acls_response_to_acls(acls, create_response): + version = create_response.API_VERSION + + creations_error = [] + creations_success = [] + for i, creations in enumerate(create_response.creation_responses): + if version <= 1: + error_code, error_message = creations + acl = acls[i] + error = Errors.for_code(error_code) + else: + raise NotImplementedError( + "Support for DescribeAcls Response v{} has not yet been added to KafkaAdmin." + .format(version) + ) + + if error is Errors.NoError: + creations_success.append(acl) + else: + creations_error.append((acl, error,)) + + return {"succeeded": creations_success, "failed": creations_error} + + def create_acls(self, acls): + """Create a list of ACLs + + This endpoint only accepts a list of concrete ACL objects, no ACLFilters. + Throws TopicAlreadyExistsError if topic is already present. + + :param acls: a list of ACL objects + :return: dict of successes and failures + """ + + for acl in acls: + if not isinstance(acl, ACL): + raise IllegalArgumentError("acls must contain ACL objects") + + version = self._matching_api_version(CreateAclsRequest) + if version == 0: + request = CreateAclsRequest[version]( + creations=[self._convert_create_acls_resource_request_v0(acl) for acl in acls] + ) + elif version <= 1: + request = CreateAclsRequest[version]( + creations=[self._convert_create_acls_resource_request_v1(acl) for acl in acls] + ) + else: + raise NotImplementedError( + "Support for CreateAcls v{} has not yet been added to KafkaAdmin." 
+ .format(version) + ) + + future = self._send_request_to_node(self._client.least_loaded_node(), request) + self._wait_for_futures([future]) + response = future.value + + return self._convert_create_acls_response_to_acls(acls, response) + + @staticmethod + def _convert_delete_acls_resource_request_v0(acl): + return ( + acl.resource_pattern.resource_type, + acl.resource_pattern.resource_name, + acl.principal, + acl.host, + acl.operation, + acl.permission_type + ) + + @staticmethod + def _convert_delete_acls_resource_request_v1(acl): + return ( + acl.resource_pattern.resource_type, + acl.resource_pattern.resource_name, + acl.resource_pattern.pattern_type, + acl.principal, + acl.host, + acl.operation, + acl.permission_type + ) + + @staticmethod + def _convert_delete_acls_response_to_matching_acls(acl_filters, delete_response): + version = delete_response.API_VERSION + filter_result_list = [] + for i, filter_responses in enumerate(delete_response.filter_responses): + filter_error_code, filter_error_message, matching_acls = filter_responses + filter_error = Errors.for_code(filter_error_code) + acl_result_list = [] + for acl in matching_acls: + if version == 0: + error_code, error_message, resource_type, resource_name, principal, host, operation, permission_type = acl + resource_pattern_type = ACLResourcePatternType.LITERAL.value + elif version == 1: + error_code, error_message, resource_type, resource_name, resource_pattern_type, principal, host, operation, permission_type = acl + else: + raise NotImplementedError( + "Support for DescribeAcls Response v{} has not yet been added to KafkaAdmin." + .format(version) + ) + acl_error = Errors.for_code(error_code) + conv_acl = ACL( + principal=principal, + host=host, + operation=ACLOperation(operation), + permission_type=ACLPermissionType(permission_type), + resource_pattern=ResourcePattern( + ResourceType(resource_type), + resource_name, + ACLResourcePatternType(resource_pattern_type) + ) + ) + acl_result_list.append((conv_acl, acl_error,)) + filter_result_list.append((acl_filters[i], acl_result_list, filter_error,)) + return filter_result_list + + def delete_acls(self, acl_filters): + """Delete a set of ACLs + + Deletes all ACLs matching the list of input ACLFilter + + :param acl_filters: a list of ACLFilter + :return: a list of 3-tuples corresponding to the list of input filters. + The tuples hold (the input ACLFilter, list of affected ACLs, KafkaError instance) + """ + + for acl in acl_filters: + if not isinstance(acl, ACLFilter): + raise IllegalArgumentError("acl_filters must contain ACLFilter type objects") + + version = self._matching_api_version(DeleteAclsRequest) + + if version == 0: + request = DeleteAclsRequest[version]( + filters=[self._convert_delete_acls_resource_request_v0(acl) for acl in acl_filters] + ) + elif version <= 1: + request = DeleteAclsRequest[version]( + filters=[self._convert_delete_acls_resource_request_v1(acl) for acl in acl_filters] + ) + else: + raise NotImplementedError( + "Support for DeleteAcls v{} has not yet been added to KafkaAdmin." 
+ .format(version) + ) + + future = self._send_request_to_node(self._client.least_loaded_node(), request) + self._wait_for_futures([future]) + response = future.value + + return self._convert_delete_acls_response_to_matching_acls(acl_filters, response) + + @staticmethod + def _convert_describe_config_resource_request(config_resource): + return ( + config_resource.resource_type, + config_resource.name, + [ + config_key for config_key, config_value in config_resource.configs.items() + ] if config_resource.configs else None + ) + + def describe_configs(self, config_resources, include_synonyms=False): + """Fetch configuration parameters for one or more Kafka resources. + + :param config_resources: An list of ConfigResource objects. + Any keys in ConfigResource.configs dict will be used to filter the + result. Setting the configs dict to None will get all values. An + empty dict will get zero values (as per Kafka protocol). + :param include_synonyms: If True, return synonyms in response. Not + supported by all versions. Default: False. + :return: Appropriate version of DescribeConfigsResponse class. + """ + + # Break up requests by type - a broker config request must be sent to the specific broker. + # All other (currently just topic resources) can be sent to any broker. + broker_resources = [] + topic_resources = [] + + for config_resource in config_resources: + if config_resource.resource_type == ConfigResourceType.BROKER: + broker_resources.append(self._convert_describe_config_resource_request(config_resource)) + else: + topic_resources.append(self._convert_describe_config_resource_request(config_resource)) + + futures = [] + version = self._matching_api_version(DescribeConfigsRequest) + if version == 0: + if include_synonyms: + raise IncompatibleBrokerVersion( + "include_synonyms requires DescribeConfigsRequest >= v1, which is not supported by Kafka {}." 
+ .format(self.config['api_version'])) + + if len(broker_resources) > 0: + for broker_resource in broker_resources: + try: + broker_id = int(broker_resource[1]) + except ValueError: + raise ValueError("Broker resource names must be an integer or a string represented integer") + + futures.append(self._send_request_to_node( + broker_id, + DescribeConfigsRequest[version](resources=[broker_resource]) + )) + + if len(topic_resources) > 0: + futures.append(self._send_request_to_node( + self._client.least_loaded_node(), + DescribeConfigsRequest[version](resources=topic_resources) + )) + + elif version <= 2: + if len(broker_resources) > 0: + for broker_resource in broker_resources: + try: + broker_id = int(broker_resource[1]) + except ValueError: + raise ValueError("Broker resource names must be an integer or a string represented integer") + + futures.append(self._send_request_to_node( + broker_id, + DescribeConfigsRequest[version]( + resources=[broker_resource], + include_synonyms=include_synonyms) + )) + + if len(topic_resources) > 0: + futures.append(self._send_request_to_node( + self._client.least_loaded_node(), + DescribeConfigsRequest[version](resources=topic_resources, include_synonyms=include_synonyms) + )) + else: + raise NotImplementedError( + "Support for DescribeConfigs v{} has not yet been added to KafkaAdminClient.".format(version)) + + self._wait_for_futures(futures) + return [f.value for f in futures] + + @staticmethod + def _convert_alter_config_resource_request(config_resource): + return ( + config_resource.resource_type, + config_resource.name, + [ + (config_key, config_value) for config_key, config_value in config_resource.configs.items() + ] + ) + + def alter_configs(self, config_resources): + """Alter configuration parameters of one or more Kafka resources. + + Warning: + This is currently broken for BROKER resources because those must be + sent to that specific broker, versus this always picks the + least-loaded node. See the comment in the source code for details. + We would happily accept a PR fixing this. + + :param config_resources: A list of ConfigResource objects. + :return: Appropriate version of AlterConfigsResponse class. + """ + version = self._matching_api_version(AlterConfigsRequest) + if version <= 1: + request = AlterConfigsRequest[version]( + resources=[self._convert_alter_config_resource_request(config_resource) for config_resource in config_resources] + ) + else: + raise NotImplementedError( + "Support for AlterConfigs v{} has not yet been added to KafkaAdminClient." + .format(version)) + # TODO the Java client has the note: + # // We must make a separate AlterConfigs request for every BROKER resource we want to alter + # // and send the request to that specific broker. Other resources are grouped together into + # // a single request that may be sent to any broker. 
+ # + # So this is currently broken as it always sends to the least_loaded_node() + future = self._send_request_to_node(self._client.least_loaded_node(), request) + + self._wait_for_futures([future]) + response = future.value + return response + + # alter replica logs dir protocol not yet implemented + # Note: have to lookup the broker with the replica assignment and send the request to that broker + + # describe log dirs protocol not yet implemented + # Note: have to lookup the broker with the replica assignment and send the request to that broker + + @staticmethod + def _convert_create_partitions_request(topic_name, new_partitions): + return ( + topic_name, + ( + new_partitions.total_count, + new_partitions.new_assignments + ) + ) + + def create_partitions(self, topic_partitions, timeout_ms=None, validate_only=False): + """Create additional partitions for an existing topic. + + :param topic_partitions: A map of topic name strings to NewPartition objects. + :param timeout_ms: Milliseconds to wait for new partitions to be + created before the broker returns. + :param validate_only: If True, don't actually create new partitions. + Default: False + :return: Appropriate version of CreatePartitionsResponse class. + """ + version = self._matching_api_version(CreatePartitionsRequest) + timeout_ms = self._validate_timeout(timeout_ms) + if version <= 1: + request = CreatePartitionsRequest[version]( + topic_partitions=[self._convert_create_partitions_request(topic_name, new_partitions) for topic_name, new_partitions in topic_partitions.items()], + timeout=timeout_ms, + validate_only=validate_only + ) + else: + raise NotImplementedError( + "Support for CreatePartitions v{} has not yet been added to KafkaAdminClient." + .format(version)) + return self._send_request_to_controller(request) + + # delete records protocol not yet implemented + # Note: send the request to the partition leaders + + # create delegation token protocol not yet implemented + # Note: send the request to the least_loaded_node() + + # renew delegation token protocol not yet implemented + # Note: send the request to the least_loaded_node() + + # expire delegation_token protocol not yet implemented + # Note: send the request to the least_loaded_node() + + # describe delegation_token protocol not yet implemented + # Note: send the request to the least_loaded_node() + + def _describe_consumer_groups_send_request(self, group_id, group_coordinator_id, include_authorized_operations=False): + """Send a DescribeGroupsRequest to the group's coordinator. + + :param group_id: The group name as a string + :param group_coordinator_id: The node_id of the groups' coordinator + broker. + :return: A message future. + """ + version = self._matching_api_version(DescribeGroupsRequest) + if version <= 2: + if include_authorized_operations: + raise IncompatibleBrokerVersion( + "include_authorized_operations requests " + "DescribeGroupsRequest >= v3, which is not " + "supported by Kafka {}".format(version) + ) + # Note: KAFKA-6788 A potential optimization is to group the + # request per coordinator and send one request with a list of + # all consumer groups. Java still hasn't implemented this + # because the error checking is hard to get right when some + # groups error and others don't. 
+ request = DescribeGroupsRequest[version](groups=(group_id,)) + elif version <= 3: + request = DescribeGroupsRequest[version]( + groups=(group_id,), + include_authorized_operations=include_authorized_operations + ) + else: + raise NotImplementedError( + "Support for DescribeGroupsRequest_v{} has not yet been added to KafkaAdminClient." + .format(version)) + return self._send_request_to_node(group_coordinator_id, request) + + def _describe_consumer_groups_process_response(self, response): + """Process a DescribeGroupsResponse into a group description.""" + if response.API_VERSION <= 3: + assert len(response.groups) == 1 + for response_field, response_name in zip(response.SCHEMA.fields, response.SCHEMA.names): + if isinstance(response_field, Array): + described_groups_field_schema = response_field.array_of + described_group = response.__dict__[response_name][0] + described_group_information_list = [] + protocol_type_is_consumer = False + for (described_group_information, group_information_name, group_information_field) in zip(described_group, described_groups_field_schema.names, described_groups_field_schema.fields): + if group_information_name == 'protocol_type': + protocol_type = described_group_information + protocol_type_is_consumer = (protocol_type == ConsumerProtocol.PROTOCOL_TYPE or not protocol_type) + if isinstance(group_information_field, Array): + member_information_list = [] + member_schema = group_information_field.array_of + for members in described_group_information: + member_information = [] + for (member, member_field, member_name) in zip(members, member_schema.fields, member_schema.names): + if protocol_type_is_consumer: + if member_name == 'member_metadata' and member: + member_information.append(ConsumerProtocolMemberMetadata.decode(member)) + elif member_name == 'member_assignment' and member: + member_information.append(ConsumerProtocolMemberAssignment.decode(member)) + else: + member_information.append(member) + member_info_tuple = MemberInformation._make(member_information) + member_information_list.append(member_info_tuple) + described_group_information_list.append(member_information_list) + else: + described_group_information_list.append(described_group_information) + # Version 3 of the DescribeGroups API introduced the "authorized_operations" field. + # This will cause the namedtuple to fail. + # Therefore, appending a placeholder of None in it. + if response.API_VERSION <=2: + described_group_information_list.append(None) + group_description = GroupInformation._make(described_group_information_list) + error_code = group_description.error_code + error_type = Errors.for_code(error_code) + # Java has the note: KAFKA-6789, we can retry based on the error code + if error_type is not Errors.NoError: + raise error_type( + "DescribeGroupsResponse failed with response '{}'." + .format(response)) + else: + raise NotImplementedError( + "Support for DescribeGroupsResponse_v{} has not yet been added to KafkaAdminClient." + .format(response.API_VERSION)) + return group_description + + def describe_consumer_groups(self, group_ids, group_coordinator_id=None, include_authorized_operations=False): + """Describe a set of consumer groups. + + Any errors are immediately raised. + + :param group_ids: A list of consumer group IDs. These are typically the + group names as strings. + :param group_coordinator_id: The node_id of the groups' coordinator + broker. If set to None, it will query the cluster for each group to + find that group's coordinator. 
Explicitly specifying this can be + useful for avoiding extra network round trips if you already know + the group coordinator. This is only useful when all the group_ids + have the same coordinator, otherwise it will error. Default: None. + :param include_authorized_operations: Whether or not to include + information about the operations a group is allowed to perform. + Only supported on API version >= v3. Default: False. + :return: A list of group descriptions. For now the group descriptions + are the raw results from the DescribeGroupsResponse. Long-term, we + plan to change this to return namedtuples as well as decoding the + partition assignments. + """ + group_descriptions = [] + + if group_coordinator_id is not None: + groups_coordinators = {group_id: group_coordinator_id for group_id in group_ids} + else: + groups_coordinators = self._find_coordinator_ids(group_ids) + + futures = [ + self._describe_consumer_groups_send_request( + group_id, + coordinator_id, + include_authorized_operations) + for group_id, coordinator_id in groups_coordinators.items() + ] + self._wait_for_futures(futures) + + for future in futures: + response = future.value + group_description = self._describe_consumer_groups_process_response(response) + group_descriptions.append(group_description) + + return group_descriptions + + def _list_consumer_groups_send_request(self, broker_id): + """Send a ListGroupsRequest to a broker. + + :param broker_id: The broker's node_id. + :return: A message future + """ + version = self._matching_api_version(ListGroupsRequest) + if version <= 2: + request = ListGroupsRequest[version]() + else: + raise NotImplementedError( + "Support for ListGroupsRequest_v{} has not yet been added to KafkaAdminClient." + .format(version)) + return self._send_request_to_node(broker_id, request) + + def _list_consumer_groups_process_response(self, response): + """Process a ListGroupsResponse into a list of groups.""" + if response.API_VERSION <= 2: + error_type = Errors.for_code(response.error_code) + if error_type is not Errors.NoError: + raise error_type( + "ListGroupsRequest failed with response '{}'." + .format(response)) + else: + raise NotImplementedError( + "Support for ListGroupsResponse_v{} has not yet been added to KafkaAdminClient." + .format(response.API_VERSION)) + return response.groups + + def list_consumer_groups(self, broker_ids=None): + """List all consumer groups known to the cluster. + + This returns a list of Consumer Group tuples. The tuples are + composed of the consumer group name and the consumer group protocol + type. + + Only consumer groups that store their offsets in Kafka are returned. + The protocol type will be an empty string for groups created using + Kafka < 0.9 APIs because, although they store their offsets in Kafka, + they don't use Kafka for group coordination. For groups created using + Kafka >= 0.9, the protocol type will typically be "consumer". + + As soon as any error is encountered, it is immediately raised. + + :param broker_ids: A list of broker node_ids to query for consumer + groups. If set to None, will query all brokers in the cluster. + Explicitly specifying broker(s) can be useful for determining which + consumer groups are coordinated by those broker(s). Default: None + :return list: List of tuples of Consumer Groups. + :exception GroupCoordinatorNotAvailableError: The coordinator is not + available, so cannot process requests. + :exception GroupLoadInProgressError: The coordinator is loading and + hence can't process requests. 
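+
+        Example (illustrative sketch; ``admin`` is assumed to be an existing
+        KafkaAdminClient instance)::
+
+            groups = admin.list_consumer_groups()
+            # e.g. [('my-group', 'consumer')]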
+ """ + # While we return a list, internally use a set to prevent duplicates + # because if a group coordinator fails after being queried, and its + # consumer groups move to new brokers that haven't yet been queried, + # then the same group could be returned by multiple brokers. + consumer_groups = set() + if broker_ids is None: + broker_ids = [broker.nodeId for broker in self._client.cluster.brokers()] + futures = [self._list_consumer_groups_send_request(b) for b in broker_ids] + self._wait_for_futures(futures) + for f in futures: + response = f.value + consumer_groups.update(self._list_consumer_groups_process_response(response)) + return list(consumer_groups) + + def _list_consumer_group_offsets_send_request(self, group_id, + group_coordinator_id, partitions=None): + """Send an OffsetFetchRequest to a broker. + + :param group_id: The consumer group id name for which to fetch offsets. + :param group_coordinator_id: The node_id of the group's coordinator + broker. + :return: A message future + """ + version = self._matching_api_version(OffsetFetchRequest) + if version <= 3: + if partitions is None: + if version <= 1: + raise ValueError( + """OffsetFetchRequest_v{} requires specifying the + partitions for which to fetch offsets. Omitting the + partitions is only supported on brokers >= 0.10.2. + For details, see KIP-88.""".format(version)) + topics_partitions = None + else: + # transform from [TopicPartition("t1", 1), TopicPartition("t1", 2)] to [("t1", [1, 2])] + topics_partitions_dict = defaultdict(set) + for topic, partition in partitions: + topics_partitions_dict[topic].add(partition) + topics_partitions = list(six.iteritems(topics_partitions_dict)) + request = OffsetFetchRequest[version](group_id, topics_partitions) + else: + raise NotImplementedError( + "Support for OffsetFetchRequest_v{} has not yet been added to KafkaAdminClient." + .format(version)) + return self._send_request_to_node(group_coordinator_id, request) + + def _list_consumer_group_offsets_process_response(self, response): + """Process an OffsetFetchResponse. + + :param response: an OffsetFetchResponse. + :return: A dictionary composed of TopicPartition keys and + OffsetAndMetadata values. + """ + if response.API_VERSION <= 3: + + # OffsetFetchResponse_v1 lacks a top-level error_code + if response.API_VERSION > 1: + error_type = Errors.for_code(response.error_code) + if error_type is not Errors.NoError: + # optionally we could retry if error_type.retriable + raise error_type( + "OffsetFetchResponse failed with response '{}'." + .format(response)) + + # transform response into a dictionary with TopicPartition keys and + # OffsetAndMetadata values--this is what the Java AdminClient returns + offsets = {} + for topic, partitions in response.topics: + for partition, offset, metadata, error_code in partitions: + error_type = Errors.for_code(error_code) + if error_type is not Errors.NoError: + raise error_type( + "Unable to fetch consumer group offsets for topic {}, partition {}" + .format(topic, partition)) + offsets[TopicPartition(topic, partition)] = OffsetAndMetadata(offset, metadata) + else: + raise NotImplementedError( + "Support for OffsetFetchResponse_v{} has not yet been added to KafkaAdminClient." + .format(response.API_VERSION)) + return offsets + + def list_consumer_group_offsets(self, group_id, group_coordinator_id=None, + partitions=None): + """Fetch Consumer Offsets for a single consumer group. + + Note: + This does not verify that the group_id or partitions actually exist + in the cluster. 
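+
+        Example (illustrative sketch; ``admin`` is assumed to be an existing
+        KafkaAdminClient instance)::
+
+            offsets = admin.list_consumer_group_offsets('my-group')
+            # e.g. {TopicPartition(topic='my-topic', partition=0):
+            #       OffsetAndMetadata(offset=42, metadata='')}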
+
+        As soon as any error is encountered, it is immediately raised.
+
+        :param group_id: The consumer group id name for which to fetch offsets.
+        :param group_coordinator_id: The node_id of the group's coordinator
+            broker. If set to None, will query the cluster to find the group
+            coordinator. Explicitly specifying this can be useful to prevent
+            that extra network round trip if you already know the group
+            coordinator. Default: None.
+        :param partitions: A list of TopicPartitions for which to fetch
+            offsets. On brokers >= 0.10.2, this can be set to None to fetch all
+            known offsets for the consumer group. Default: None.
+        :return dictionary: A dictionary with TopicPartition keys and
+            OffsetAndMetadata values. Partitions that are not specified and for
+            which the group_id does not have a recorded offset are omitted. An
+            offset value of `-1` indicates the group_id has no offset for that
+            TopicPartition. A `-1` can only happen for partitions that are
+            explicitly specified.
+        """
+        if group_coordinator_id is None:
+            group_coordinator_id = self._find_coordinator_ids([group_id])[group_id]
+        future = self._list_consumer_group_offsets_send_request(
+            group_id, group_coordinator_id, partitions)
+        self._wait_for_futures([future])
+        response = future.value
+        return self._list_consumer_group_offsets_process_response(response)
+
+    def delete_consumer_groups(self, group_ids, group_coordinator_id=None):
+        """Delete Consumer Group Offsets for given consumer groups.
+
+        Note:
+            This does not verify that the group ids actually exist and
+            group_coordinator_id is the correct coordinator for all these groups.
+
+            The result needs checking for potential errors.
+
+        :param group_ids: The consumer group ids of the groups which are to be deleted.
+        :param group_coordinator_id: The node_id of the broker which is the coordinator for
+            all the groups. Use only if all groups are coordinated by the same broker.
+            If set to None, will query the cluster to find the coordinator for every single group.
+            Explicitly specifying this can be useful to prevent
+            the extra network round trips if you already know the group
+            coordinator. Default: None.
+        :return: A list of tuples (group_id, KafkaError)
+        """
+        if group_coordinator_id is not None:
+            futures = [self._delete_consumer_groups_send_request(group_ids, group_coordinator_id)]
+        else:
+            coordinators_groups = defaultdict(list)
+            for group_id, coordinator_id in self._find_coordinator_ids(group_ids).items():
+                coordinators_groups[coordinator_id].append(group_id)
+            futures = [
+                self._delete_consumer_groups_send_request(group_ids, coordinator_id)
+                for coordinator_id, group_ids in coordinators_groups.items()
+            ]
+
+        self._wait_for_futures(futures)
+
+        results = []
+        for f in futures:
+            results.extend(self._convert_delete_groups_response(f.value))
+        return results
+
+    def _convert_delete_groups_response(self, response):
+        if response.API_VERSION <= 1:
+            results = []
+            for group_id, error_code in response.results:
+                results.append((group_id, Errors.for_code(error_code)))
+            return results
+        else:
+            raise NotImplementedError(
+                "Support for DeleteGroupsResponse_v{} has not yet been added to KafkaAdminClient."
+                .format(response.API_VERSION))
+
+    def _delete_consumer_groups_send_request(self, group_ids, group_coordinator_id):
+        """Send a DeleteGroups request to a broker.
+
+        :param group_ids: The consumer group ids of the groups which are to be deleted.
+        :param group_coordinator_id: The node_id of the broker which is the coordinator for
+            all the groups.
+        :return: A message future
+        """
+        version = self._matching_api_version(DeleteGroupsRequest)
+        if version <= 1:
+            request = DeleteGroupsRequest[version](group_ids)
+        else:
+            raise NotImplementedError(
+                "Support for DeleteGroupsRequest_v{} has not yet been added to KafkaAdminClient."
+                .format(version))
+        return self._send_request_to_node(group_coordinator_id, request)
+
+    def _wait_for_futures(self, futures):
+        while not all(future.succeeded() for future in futures):
+            for future in futures:
+                self._client.poll(future=future)
+
+                if future.failed():
+                    raise future.exception  # pylint: disable-msg=raising-bad-type
diff --git a/admin/config_resource.py b/admin/config_resource.py
new file mode 100644
index 00000000..e3294c9c
--- /dev/null
+++ b/admin/config_resource.py
@@ -0,0 +1,36 @@
+from __future__ import absolute_import
+
+# enum in stdlib as of py3.4
+try:
+    from enum import IntEnum  # pylint: disable=import-error
+except ImportError:
+    # vendored backport module
+    from kafka.vendor.enum34 import IntEnum
+
+
+class ConfigResourceType(IntEnum):
+    """An enumerated type of config resources"""
+
+    BROKER = 4
+    TOPIC = 2
+
+
+class ConfigResource(object):
+    """A class for specifying config resources.
+    Arguments:
+        resource_type (ConfigResourceType): the type of kafka resource
+        name (string): The name of the kafka resource
+        configs ({key : value}): A map of config keys to values.
+    """
+
+    def __init__(
+        self,
+        resource_type,
+        name,
+        configs=None
+    ):
+        if not isinstance(resource_type, ConfigResourceType):
+            resource_type = ConfigResourceType[str(resource_type).upper()]  # pylint: disable-msg=unsubscriptable-object
+        self.resource_type = resource_type
+        self.name = name
+        self.configs = configs
diff --git a/admin/new_partitions.py b/admin/new_partitions.py
new file mode 100644
index 00000000..429b2e19
--- /dev/null
+++ b/admin/new_partitions.py
@@ -0,0 +1,19 @@
+from __future__ import absolute_import
+
+
+class NewPartitions(object):
+    """A class for new partition creation on existing topics. Note that the length of new_assignments, if specified,
+    must be the difference between the new total number of partitions and the existing number of partitions.
+    Arguments:
+        total_count (int): the total number of partitions that should exist on the topic
+        new_assignments ([[int]]): an array of arrays of replica assignments for new partitions.
+            If not set, broker assigns replicas per an internal algorithm.
+    """
+
+    def __init__(
+        self,
+        total_count,
+        new_assignments=None
+    ):
+        self.total_count = total_count
+        self.new_assignments = new_assignments
diff --git a/admin/new_topic.py b/admin/new_topic.py
new file mode 100644
index 00000000..645ac383
--- /dev/null
+++ b/admin/new_topic.py
@@ -0,0 +1,34 @@
+from __future__ import absolute_import
+
+from kafka.errors import IllegalArgumentError
+
+
+class NewTopic(object):
+    """ A class for new topic creation
+    Arguments:
+        name (string): name of the topic
+        num_partitions (int): number of partitions
+            or -1 if replica_assignments has been specified
+        replication_factor (int): replication factor or -1 if
+            replica assignment is specified
+        replica_assignments (dict of int: [int]): A mapping containing
+            partition id and replicas to assign to it.
+        topic_configs (dict of str: str): A mapping of config key
+            and value for the topic.
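+
+    Example (illustrative sketch; ``admin`` is assumed to be an existing
+    KafkaAdminClient instance)::
+
+        topic = NewTopic(name='my-topic', num_partitions=3, replication_factor=1)
+        admin.create_topics([topic])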
+ """ + + def __init__( + self, + name, + num_partitions, + replication_factor, + replica_assignments=None, + topic_configs=None, + ): + if not (num_partitions == -1 or replication_factor == -1) ^ (replica_assignments is None): + raise IllegalArgumentError('either num_partitions/replication_factor or replica_assignment must be specified') + self.name = name + self.num_partitions = num_partitions + self.replication_factor = replication_factor + self.replica_assignments = replica_assignments or {} + self.topic_configs = topic_configs or {} diff --git a/client_async.py b/client_async.py new file mode 100644 index 00000000..58f22d4e --- /dev/null +++ b/client_async.py @@ -0,0 +1,1077 @@ +from __future__ import absolute_import, division + +import collections +import copy +import logging +import random +import socket +import threading +import time +import weakref + +# selectors in stdlib as of py3.4 +try: + import selectors # pylint: disable=import-error +except ImportError: + # vendored backport module + from kafka.vendor import selectors34 as selectors + +from kafka.vendor import six + +from kafka.cluster import ClusterMetadata +from kafka.conn import BrokerConnection, ConnectionStates, collect_hosts, get_ip_port_afi +from kafka import errors as Errors +from kafka.future import Future +from kafka.metrics import AnonMeasurable +from kafka.metrics.stats import Avg, Count, Rate +from kafka.metrics.stats.rate import TimeUnit +from kafka.protocol.metadata import MetadataRequest +from kafka.util import Dict, WeakMethod +# Although this looks unused, it actually monkey-patches socket.socketpair() +# and should be left in as long as we're using socket.socketpair() in this file +from kafka.vendor import socketpair +from kafka.version import __version__ + +if six.PY2: + ConnectionError = None + + +log = logging.getLogger('kafka.client') + + +class KafkaClient(object): + """ + A network client for asynchronous request/response network I/O. + + This is an internal class used to implement the user-facing producer and + consumer clients. + + This class is not thread-safe! + + Attributes: + cluster (:any:`ClusterMetadata`): Local cache of cluster metadata, retrieved + via MetadataRequests during :meth:`~kafka.KafkaClient.poll`. + + Keyword Arguments: + bootstrap_servers: 'host[:port]' string (or list of 'host[:port]' + strings) that the client should contact to bootstrap initial + cluster metadata. This does not have to be the full node list. + It just needs to have at least one broker that will respond to a + Metadata API Request. Default port is 9092. If no servers are + specified, will default to localhost:9092. + client_id (str): a name for this client. This string is passed in + each request to servers and can be used to identify specific + server-side log entries that correspond to this client. Also + submitted to GroupCoordinator for logging with respect to + consumer group administration. Default: 'kafka-python-{version}' + reconnect_backoff_ms (int): The amount of time in milliseconds to + wait before attempting to reconnect to a given host. + Default: 50. + reconnect_backoff_max_ms (int): The maximum amount of time in + milliseconds to backoff/wait when reconnecting to a broker that has + repeatedly failed to connect. If provided, the backoff per host + will increase exponentially for each consecutive connection + failure, up to this maximum. Once the maximum is reached, + reconnection attempts will continue periodically with this fixed + rate. 
To avoid connection storms, a randomization factor of 0.2 + will be applied to the backoff resulting in a random range between + 20% below and 20% above the computed value. Default: 1000. + request_timeout_ms (int): Client request timeout in milliseconds. + Default: 30000. + connections_max_idle_ms: Close idle connections after the number of + milliseconds specified by this config. The broker closes idle + connections after connections.max.idle.ms, so this avoids hitting + unexpected socket disconnected errors on the client. + Default: 540000 + retry_backoff_ms (int): Milliseconds to backoff when retrying on + errors. Default: 100. + max_in_flight_requests_per_connection (int): Requests are pipelined + to kafka brokers up to this number of maximum requests per + broker connection. Default: 5. + receive_buffer_bytes (int): The size of the TCP receive buffer + (SO_RCVBUF) to use when reading data. Default: None (relies on + system defaults). Java client defaults to 32768. + send_buffer_bytes (int): The size of the TCP send buffer + (SO_SNDBUF) to use when sending data. Default: None (relies on + system defaults). Java client defaults to 131072. + socket_options (list): List of tuple-arguments to socket.setsockopt + to apply to broker connection sockets. Default: + [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)] + metadata_max_age_ms (int): The period of time in milliseconds after + which we force a refresh of metadata even if we haven't seen any + partition leadership changes to proactively discover any new + brokers or partitions. Default: 300000 + security_protocol (str): Protocol used to communicate with brokers. + Valid values are: PLAINTEXT, SSL, SASL_PLAINTEXT, SASL_SSL. + Default: PLAINTEXT. + ssl_context (ssl.SSLContext): Pre-configured SSLContext for wrapping + socket connections. If provided, all other ssl_* configurations + will be ignored. Default: None. + ssl_check_hostname (bool): Flag to configure whether SSL handshake + should verify that the certificate matches the broker's hostname. + Default: True. + ssl_cafile (str): Optional filename of CA file to use in certificate + verification. Default: None. + ssl_certfile (str): Optional filename of file in PEM format containing + the client certificate, as well as any CA certificates needed to + establish the certificate's authenticity. Default: None. + ssl_keyfile (str): Optional filename containing the client private key. + Default: None. + ssl_password (str): Optional password to be used when loading the + certificate chain. Default: None. + ssl_crlfile (str): Optional filename containing the CRL to check for + certificate expiration. By default, no CRL check is done. When + providing a file, only the leaf certificate will be checked against + this CRL. The CRL can only be checked with Python 3.4+ or 2.7.9+. + Default: None. + ssl_ciphers (str): optionally set the available ciphers for ssl + connections. It should be a string in the OpenSSL cipher list + format. If no cipher can be selected (because compile-time options + or other configuration forbids use of all the specified ciphers), + an ssl.SSLError will be raised. See ssl.SSLContext.set_ciphers + api_version (tuple): Specify which Kafka API version to use. If set + to None, KafkaClient will attempt to infer the broker version by + probing various APIs. Example: (0, 10, 2). Default: None + api_version_auto_timeout_ms (int): number of milliseconds to throw a + timeout exception from the constructor when checking the broker + api version. 
Only applies if api_version is None + selector (selectors.BaseSelector): Provide a specific selector + implementation to use for I/O multiplexing. + Default: selectors.DefaultSelector + metrics (kafka.metrics.Metrics): Optionally provide a metrics + instance for capturing network IO stats. Default: None. + metric_group_prefix (str): Prefix for metric names. Default: '' + sasl_mechanism (str): Authentication mechanism when security_protocol + is configured for SASL_PLAINTEXT or SASL_SSL. Valid values are: + PLAIN, GSSAPI, OAUTHBEARER, SCRAM-SHA-256, SCRAM-SHA-512. + sasl_plain_username (str): username for sasl PLAIN and SCRAM authentication. + Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. + sasl_plain_password (str): password for sasl PLAIN and SCRAM authentication. + Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. + sasl_kerberos_service_name (str): Service name to include in GSSAPI + sasl mechanism handshake. Default: 'kafka' + sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI + sasl mechanism handshake. Default: one of bootstrap servers + sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider + instance. (See kafka.oauth.abstract). Default: None + """ + + DEFAULT_CONFIG = { + 'bootstrap_servers': 'localhost', + 'bootstrap_topics_filter': set(), + 'client_id': 'kafka-python-' + __version__, + 'request_timeout_ms': 30000, + 'wakeup_timeout_ms': 3000, + 'connections_max_idle_ms': 9 * 60 * 1000, + 'reconnect_backoff_ms': 50, + 'reconnect_backoff_max_ms': 1000, + 'max_in_flight_requests_per_connection': 5, + 'receive_buffer_bytes': None, + 'send_buffer_bytes': None, + 'socket_options': [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)], + 'sock_chunk_bytes': 4096, # undocumented experimental option + 'sock_chunk_buffer_count': 1000, # undocumented experimental option + 'retry_backoff_ms': 100, + 'metadata_max_age_ms': 300000, + 'security_protocol': 'PLAINTEXT', + 'ssl_context': None, + 'ssl_check_hostname': True, + 'ssl_cafile': None, + 'ssl_certfile': None, + 'ssl_keyfile': None, + 'ssl_password': None, + 'ssl_crlfile': None, + 'ssl_ciphers': None, + 'api_version': None, + 'api_version_auto_timeout_ms': 2000, + 'selector': selectors.DefaultSelector, + 'metrics': None, + 'metric_group_prefix': '', + 'sasl_mechanism': None, + 'sasl_plain_username': None, + 'sasl_plain_password': None, + 'sasl_kerberos_service_name': 'kafka', + 'sasl_kerberos_domain_name': None, + 'sasl_oauth_token_provider': None + } + + def __init__(self, **configs): + self.config = copy.copy(self.DEFAULT_CONFIG) + for key in self.config: + if key in configs: + self.config[key] = configs[key] + + # these properties need to be set on top of the initialization pipeline + # because they are used when __del__ method is called + self._closed = False + self._wake_r, self._wake_w = socket.socketpair() + self._selector = self.config['selector']() + + self.cluster = ClusterMetadata(**self.config) + self._topics = set() # empty set will fetch all topic metadata + self._metadata_refresh_in_progress = False + self._conns = Dict() # object to support weakrefs + self._api_versions = None + self._connecting = set() + self._sending = set() + self._refresh_on_disconnects = True + self._last_bootstrap = 0 + self._bootstrap_fails = 0 + self._wake_r.setblocking(False) + self._wake_w.settimeout(self.config['wakeup_timeout_ms'] / 1000.0) + self._wake_lock = threading.Lock() + + self._lock = threading.RLock() + + # when requests complete, they are transferred to this 
queue prior to + # invocation. The purpose is to avoid invoking them while holding the + # lock above. + self._pending_completion = collections.deque() + + self._selector.register(self._wake_r, selectors.EVENT_READ) + self._idle_expiry_manager = IdleConnectionManager(self.config['connections_max_idle_ms']) + self._sensors = None + if self.config['metrics']: + self._sensors = KafkaClientMetrics(self.config['metrics'], + self.config['metric_group_prefix'], + weakref.proxy(self._conns)) + + self._num_bootstrap_hosts = len(collect_hosts(self.config['bootstrap_servers'])) + + # Check Broker Version if not set explicitly + if self.config['api_version'] is None: + check_timeout = self.config['api_version_auto_timeout_ms'] / 1000 + self.config['api_version'] = self.check_version(timeout=check_timeout) + + def _can_bootstrap(self): + effective_failures = self._bootstrap_fails // self._num_bootstrap_hosts + backoff_factor = 2 ** effective_failures + backoff_ms = min(self.config['reconnect_backoff_ms'] * backoff_factor, + self.config['reconnect_backoff_max_ms']) + + backoff_ms *= random.uniform(0.8, 1.2) + + next_at = self._last_bootstrap + backoff_ms / 1000.0 + now = time.time() + if next_at > now: + return False + return True + + def _can_connect(self, node_id): + if node_id not in self._conns: + if self.cluster.broker_metadata(node_id): + return True + return False + conn = self._conns[node_id] + return conn.disconnected() and not conn.blacked_out() + + def _conn_state_change(self, node_id, sock, conn): + with self._lock: + if conn.connecting(): + # SSL connections can enter this state 2x (second during Handshake) + if node_id not in self._connecting: + self._connecting.add(node_id) + try: + self._selector.register(sock, selectors.EVENT_WRITE, conn) + except KeyError: + self._selector.modify(sock, selectors.EVENT_WRITE, conn) + + if self.cluster.is_bootstrap(node_id): + self._last_bootstrap = time.time() + + elif conn.connected(): + log.debug("Node %s connected", node_id) + if node_id in self._connecting: + self._connecting.remove(node_id) + + try: + self._selector.modify(sock, selectors.EVENT_READ, conn) + except KeyError: + self._selector.register(sock, selectors.EVENT_READ, conn) + + if self._sensors: + self._sensors.connection_created.record() + + self._idle_expiry_manager.update(node_id) + + if self.cluster.is_bootstrap(node_id): + self._bootstrap_fails = 0 + + else: + for node_id in list(self._conns.keys()): + if self.cluster.is_bootstrap(node_id): + self._conns.pop(node_id).close() + + # Connection failures imply that our metadata is stale, so let's refresh + elif conn.state is ConnectionStates.DISCONNECTED: + if node_id in self._connecting: + self._connecting.remove(node_id) + try: + self._selector.unregister(sock) + except KeyError: + pass + + if self._sensors: + self._sensors.connection_closed.record() + + idle_disconnect = False + if self._idle_expiry_manager.is_expired(node_id): + idle_disconnect = True + self._idle_expiry_manager.remove(node_id) + + # If the connection has already by popped from self._conns, + # we can assume the disconnect was intentional and not a failure + if node_id not in self._conns: + pass + + elif self.cluster.is_bootstrap(node_id): + self._bootstrap_fails += 1 + + elif self._refresh_on_disconnects and not self._closed and not idle_disconnect: + log.warning("Node %s connection failed -- refreshing metadata", node_id) + self.cluster.request_update() + + def maybe_connect(self, node_id, wakeup=True): + """Queues a node for asynchronous connection during the next 
.poll()""" + if self._can_connect(node_id): + self._connecting.add(node_id) + # Wakeup signal is useful in case another thread is + # blocked waiting for incoming network traffic while holding + # the client lock in poll(). + if wakeup: + self.wakeup() + return True + return False + + def _should_recycle_connection(self, conn): + # Never recycle unless disconnected + if not conn.disconnected(): + return False + + # Otherwise, only recycle when broker metadata has changed + broker = self.cluster.broker_metadata(conn.node_id) + if broker is None: + return False + + host, _, afi = get_ip_port_afi(broker.host) + if conn.host != host or conn.port != broker.port: + log.info("Broker metadata change detected for node %s" + " from %s:%s to %s:%s", conn.node_id, conn.host, conn.port, + broker.host, broker.port) + return True + + return False + + def _maybe_connect(self, node_id): + """Idempotent non-blocking connection attempt to the given node id.""" + with self._lock: + conn = self._conns.get(node_id) + + if conn is None: + broker = self.cluster.broker_metadata(node_id) + assert broker, 'Broker id %s not in current metadata' % (node_id,) + + log.debug("Initiating connection to node %s at %s:%s", + node_id, broker.host, broker.port) + host, port, afi = get_ip_port_afi(broker.host) + cb = WeakMethod(self._conn_state_change) + conn = BrokerConnection(host, broker.port, afi, + state_change_callback=cb, + node_id=node_id, + **self.config) + self._conns[node_id] = conn + + # Check if existing connection should be recreated because host/port changed + elif self._should_recycle_connection(conn): + self._conns.pop(node_id) + return False + + elif conn.connected(): + return True + + conn.connect() + return conn.connected() + + def ready(self, node_id, metadata_priority=True): + """Check whether a node is connected and ok to send more requests. + + Arguments: + node_id (int): the id of the node to check + metadata_priority (bool): Mark node as not-ready if a metadata + refresh is required. Default: True + + Returns: + bool: True if we are ready to send to the given node + """ + self.maybe_connect(node_id) + return self.is_ready(node_id, metadata_priority=metadata_priority) + + def connected(self, node_id): + """Return True iff the node_id is connected.""" + conn = self._conns.get(node_id) + if conn is None: + return False + return conn.connected() + + def _close(self): + if not self._closed: + self._closed = True + self._wake_r.close() + self._wake_w.close() + self._selector.close() + + def close(self, node_id=None): + """Close one or all broker connections. + + Arguments: + node_id (int, optional): the id of the node to close + """ + with self._lock: + if node_id is None: + self._close() + conns = list(self._conns.values()) + self._conns.clear() + for conn in conns: + conn.close() + elif node_id in self._conns: + self._conns.pop(node_id).close() + else: + log.warning("Node %s not found in current connection list; skipping", node_id) + return + + def __del__(self): + self._close() + + def is_disconnected(self, node_id): + """Check whether the node connection has been disconnected or failed. + + A disconnected node has either been closed or has failed. Connection + failures are usually transient and can be resumed in the next ready() + call, but there are cases where transient failures need to be caught + and re-acted upon. 
+ + Arguments: + node_id (int): the id of the node to check + + Returns: + bool: True iff the node exists and is disconnected + """ + conn = self._conns.get(node_id) + if conn is None: + return False + return conn.disconnected() + + def connection_delay(self, node_id): + """ + Return the number of milliseconds to wait, based on the connection + state, before attempting to send data. When disconnected, this respects + the reconnect backoff time. When connecting, returns 0 to allow + non-blocking connect to finish. When connected, returns a very large + number to handle slow/stalled connections. + + Arguments: + node_id (int): The id of the node to check + + Returns: + int: The number of milliseconds to wait. + """ + conn = self._conns.get(node_id) + if conn is None: + return 0 + return conn.connection_delay() + + def is_ready(self, node_id, metadata_priority=True): + """Check whether a node is ready to send more requests. + + In addition to connection-level checks, this method also is used to + block additional requests from being sent during a metadata refresh. + + Arguments: + node_id (int): id of the node to check + metadata_priority (bool): Mark node as not-ready if a metadata + refresh is required. Default: True + + Returns: + bool: True if the node is ready and metadata is not refreshing + """ + if not self._can_send_request(node_id): + return False + + # if we need to update our metadata now declare all requests unready to + # make metadata requests first priority + if metadata_priority: + if self._metadata_refresh_in_progress: + return False + if self.cluster.ttl() == 0: + return False + return True + + def _can_send_request(self, node_id): + conn = self._conns.get(node_id) + if not conn: + return False + return conn.connected() and conn.can_send_more() + + def send(self, node_id, request, wakeup=True): + """Send a request to a specific node. Bytes are placed on an + internal per-connection send-queue. Actual network I/O will be + triggered in a subsequent call to .poll() + + Arguments: + node_id (int): destination node + request (Struct): request object (not-encoded) + wakeup (bool): optional flag to disable thread-wakeup + + Raises: + AssertionError: if node_id is not in current cluster metadata + + Returns: + Future: resolves to Response struct or Error + """ + conn = self._conns.get(node_id) + if not conn or not self._can_send_request(node_id): + self.maybe_connect(node_id, wakeup=wakeup) + return Future().failure(Errors.NodeNotReadyError(node_id)) + + # conn.send will queue the request internally + # we will need to call send_pending_requests() + # to trigger network I/O + future = conn.send(request, blocking=False) + self._sending.add(conn) + + # Wakeup signal is useful in case another thread is + # blocked waiting for incoming network traffic while holding + # the client lock in poll(). + if wakeup: + self.wakeup() + + return future + + def poll(self, timeout_ms=None, future=None): + """Try to read and write to sockets. + + This method will also attempt to complete node connections, refresh + stale metadata, and run previously-scheduled tasks. + + Arguments: + timeout_ms (int, optional): maximum amount of time to wait (in ms) + for at least one response. Must be non-negative. The actual + timeout will be the minimum of timeout, request timeout and + metadata timeout. 
Default: request_timeout_ms + future (Future, optional): if provided, blocks until future.is_done + + Returns: + list: responses received (can be empty) + """ + if future is not None: + timeout_ms = 100 + elif timeout_ms is None: + timeout_ms = self.config['request_timeout_ms'] + elif not isinstance(timeout_ms, (int, float)): + raise TypeError('Invalid type for timeout: %s' % type(timeout_ms)) + + # Loop for futures, break after first loop if None + responses = [] + while True: + with self._lock: + if self._closed: + break + + # Attempt to complete pending connections + for node_id in list(self._connecting): + self._maybe_connect(node_id) + + # Send a metadata request if needed + metadata_timeout_ms = self._maybe_refresh_metadata() + + # If we got a future that is already done, don't block in _poll + if future is not None and future.is_done: + timeout = 0 + else: + idle_connection_timeout_ms = self._idle_expiry_manager.next_check_ms() + timeout = min( + timeout_ms, + metadata_timeout_ms, + idle_connection_timeout_ms, + self.config['request_timeout_ms']) + # if there are no requests in flight, do not block longer than the retry backoff + if self.in_flight_request_count() == 0: + timeout = min(timeout, self.config['retry_backoff_ms']) + timeout = max(0, timeout) # avoid negative timeouts + + self._poll(timeout / 1000) + + # called without the lock to avoid deadlock potential + # if handlers need to acquire locks + responses.extend(self._fire_pending_completed_requests()) + + # If all we had was a timeout (future is None) - only do one poll + # If we do have a future, we keep looping until it is done + if future is None or future.is_done: + break + + return responses + + def _register_send_sockets(self): + while self._sending: + conn = self._sending.pop() + try: + key = self._selector.get_key(conn._sock) + events = key.events | selectors.EVENT_WRITE + self._selector.modify(key.fileobj, events, key.data) + except KeyError: + self._selector.register(conn._sock, selectors.EVENT_WRITE, conn) + + def _poll(self, timeout): + # This needs to be locked, but since it is only called from within the + # locked section of poll(), there is no additional lock acquisition here + processed = set() + + # Send pending requests first, before polling for responses + self._register_send_sockets() + + start_select = time.time() + ready = self._selector.select(timeout) + end_select = time.time() + if self._sensors: + self._sensors.select_time.record((end_select - start_select) * 1000000000) + + for key, events in ready: + if key.fileobj is self._wake_r: + self._clear_wake_fd() + continue + + # Send pending requests if socket is ready to write + if events & selectors.EVENT_WRITE: + conn = key.data + if conn.connecting(): + conn.connect() + else: + if conn.send_pending_requests_v2(): + # If send is complete, we dont need to track write readiness + # for this socket anymore + if key.events ^ selectors.EVENT_WRITE: + self._selector.modify( + key.fileobj, + key.events ^ selectors.EVENT_WRITE, + key.data) + else: + self._selector.unregister(key.fileobj) + + if not (events & selectors.EVENT_READ): + continue + conn = key.data + processed.add(conn) + + if not conn.in_flight_requests: + # if we got an EVENT_READ but there were no in-flight requests, one of + # two things has happened: + # + # 1. The remote end closed the connection (because it died, or because + # a firewall timed out, or whatever) + # 2. The protocol is out of sync. 
+ # + # either way, we can no longer safely use this connection + # + # Do a 1-byte read to check protocol didnt get out of sync, and then close the conn + try: + unexpected_data = key.fileobj.recv(1) + if unexpected_data: # anything other than a 0-byte read means protocol issues + log.warning('Protocol out of sync on %r, closing', conn) + except socket.error: + pass + conn.close(Errors.KafkaConnectionError('Socket EVENT_READ without in-flight-requests')) + continue + + self._idle_expiry_manager.update(conn.node_id) + self._pending_completion.extend(conn.recv()) + + # Check for additional pending SSL bytes + if self.config['security_protocol'] in ('SSL', 'SASL_SSL'): + # TODO: optimize + for conn in self._conns.values(): + if conn not in processed and conn.connected() and conn._sock.pending(): + self._pending_completion.extend(conn.recv()) + + for conn in six.itervalues(self._conns): + if conn.requests_timed_out(): + log.warning('%s timed out after %s ms. Closing connection.', + conn, conn.config['request_timeout_ms']) + conn.close(error=Errors.RequestTimedOutError( + 'Request timed out after %s ms' % + conn.config['request_timeout_ms'])) + + if self._sensors: + self._sensors.io_time.record((time.time() - end_select) * 1000000000) + + self._maybe_close_oldest_connection() + + def in_flight_request_count(self, node_id=None): + """Get the number of in-flight requests for a node or all nodes. + + Arguments: + node_id (int, optional): a specific node to check. If unspecified, + return the total for all nodes + + Returns: + int: pending in-flight requests for the node, or all nodes if None + """ + if node_id is not None: + conn = self._conns.get(node_id) + if conn is None: + return 0 + return len(conn.in_flight_requests) + else: + return sum([len(conn.in_flight_requests) + for conn in list(self._conns.values())]) + + def _fire_pending_completed_requests(self): + responses = [] + while True: + try: + # We rely on deque.popleft remaining threadsafe + # to allow both the heartbeat thread and the main thread + # to process responses + response, future = self._pending_completion.popleft() + except IndexError: + break + future.success(response) + responses.append(response) + return responses + + def least_loaded_node(self): + """Choose the node with fewest outstanding requests, with fallbacks. + + This method will prefer a node with an existing connection and no + in-flight-requests. If no such node is found, a node will be chosen + randomly from disconnected nodes that are not "blacked out" (i.e., + are not subject to a reconnect backoff). If no node metadata has been + obtained, will return a bootstrap node (subject to exponential backoff). + + Returns: + node_id or None if no suitable node was found + """ + nodes = [broker.nodeId for broker in self.cluster.brokers()] + random.shuffle(nodes) + + inflight = float('inf') + found = None + for node_id in nodes: + conn = self._conns.get(node_id) + connected = conn is not None and conn.connected() + blacked_out = conn is not None and conn.blacked_out() + curr_inflight = len(conn.in_flight_requests) if conn is not None else 0 + if connected and curr_inflight == 0: + # if we find an established connection + # with no in-flight requests, we can stop right away + return node_id + elif not blacked_out and curr_inflight < inflight: + # otherwise if this is the best we have found so far, record that + inflight = curr_inflight + found = node_id + + return found + + def set_topics(self, topics): + """Set specific topics to track for metadata. 
+ + Arguments: + topics (list of str): topics to check for metadata + + Returns: + Future: resolves after metadata request/response + """ + if set(topics).difference(self._topics): + future = self.cluster.request_update() + else: + future = Future().success(set(topics)) + self._topics = set(topics) + return future + + def add_topic(self, topic): + """Add a topic to the list of topics tracked via metadata. + + Arguments: + topic (str): topic to track + + Returns: + Future: resolves after metadata request/response + """ + if topic in self._topics: + return Future().success(set(self._topics)) + + self._topics.add(topic) + return self.cluster.request_update() + + # This method should be locked when running multi-threaded + def _maybe_refresh_metadata(self, wakeup=False): + """Send a metadata request if needed. + + Returns: + int: milliseconds until next refresh + """ + ttl = self.cluster.ttl() + wait_for_in_progress_ms = self.config['request_timeout_ms'] if self._metadata_refresh_in_progress else 0 + metadata_timeout = max(ttl, wait_for_in_progress_ms) + + if metadata_timeout > 0: + return metadata_timeout + + # Beware that the behavior of this method and the computation of + # timeouts for poll() are highly dependent on the behavior of + # least_loaded_node() + node_id = self.least_loaded_node() + if node_id is None: + log.debug("Give up sending metadata request since no node is available"); + return self.config['reconnect_backoff_ms'] + + if self._can_send_request(node_id): + topics = list(self._topics) + if not topics and self.cluster.is_bootstrap(node_id): + topics = list(self.config['bootstrap_topics_filter']) + + if self.cluster.need_all_topic_metadata or not topics: + topics = [] if self.config['api_version'] < (0, 10) else None + api_version = 0 if self.config['api_version'] < (0, 10) else 1 + request = MetadataRequest[api_version](topics) + log.debug("Sending metadata request %s to node %s", request, node_id) + future = self.send(node_id, request, wakeup=wakeup) + future.add_callback(self.cluster.update_metadata) + future.add_errback(self.cluster.failed_update) + + self._metadata_refresh_in_progress = True + def refresh_done(val_or_error): + self._metadata_refresh_in_progress = False + future.add_callback(refresh_done) + future.add_errback(refresh_done) + return self.config['request_timeout_ms'] + + # If there's any connection establishment underway, wait until it completes. This prevents + # the client from unnecessarily connecting to additional nodes while a previous connection + # attempt has not been completed. + if self._connecting: + return self.config['reconnect_backoff_ms'] + + if self.maybe_connect(node_id, wakeup=wakeup): + log.debug("Initializing connection to node %s for metadata request", node_id) + return self.config['reconnect_backoff_ms'] + + # connected but can't send more, OR connecting + # In either case we just need to wait for a network event + # to let us know the selected connection might be usable again. + return float('inf') + + def get_api_versions(self): + """Return the ApiVersions map, if available. + + Note: A call to check_version must previously have succeeded and returned + version 0.10.0 or later + + Returns: a map of dict mapping {api_key : (min_version, max_version)}, + or None if ApiVersion is not supported by the kafka cluster. + """ + return self._api_versions + + def check_version(self, node_id=None, timeout=2, strict=False): + """Attempt to guess the version of a Kafka broker. 
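+
+        Example (illustrative sketch; ``client`` is assumed to be an existing
+        KafkaClient instance)::
+
+            client.check_version()
+            # e.g. (0, 10, 2)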
+ + Note: It is possible that this method blocks longer than the + specified timeout. This can happen if the entire cluster + is down and the client enters a bootstrap backoff sleep. + This is only possible if node_id is None. + + Returns: version tuple, i.e. (0, 10), (0, 9), (0, 8, 2), ... + + Raises: + NodeNotReadyError (if node_id is provided) + NoBrokersAvailable (if node_id is None) + UnrecognizedBrokerVersion: please file bug if seen! + AssertionError (if strict=True): please file bug if seen! + """ + self._lock.acquire() + end = time.time() + timeout + while time.time() < end: + + # It is possible that least_loaded_node falls back to bootstrap, + # which can block for an increasing backoff period + try_node = node_id or self.least_loaded_node() + if try_node is None: + self._lock.release() + raise Errors.NoBrokersAvailable() + self._maybe_connect(try_node) + conn = self._conns[try_node] + + # We will intentionally cause socket failures + # These should not trigger metadata refresh + self._refresh_on_disconnects = False + try: + remaining = end - time.time() + version = conn.check_version(timeout=remaining, strict=strict, topics=list(self.config['bootstrap_topics_filter'])) + if version >= (0, 10, 0): + # cache the api versions map if it's available (starting + # in 0.10 cluster version) + self._api_versions = conn.get_api_versions() + self._lock.release() + return version + except Errors.NodeNotReadyError: + # Only raise to user if this is a node-specific request + if node_id is not None: + self._lock.release() + raise + finally: + self._refresh_on_disconnects = True + + # Timeout + else: + self._lock.release() + raise Errors.NoBrokersAvailable() + + def wakeup(self): + with self._wake_lock: + try: + self._wake_w.sendall(b'x') + except socket.timeout: + log.warning('Timeout to send to wakeup socket!') + raise Errors.KafkaTimeoutError() + except socket.error: + log.warning('Unable to send to wakeup socket!') + + def _clear_wake_fd(self): + # reading from wake socket should only happen in a single thread + while True: + try: + self._wake_r.recv(1024) + except socket.error: + break + + def _maybe_close_oldest_connection(self): + expired_connection = self._idle_expiry_manager.poll_expired_connection() + if expired_connection: + conn_id, ts = expired_connection + idle_ms = (time.time() - ts) * 1000 + log.info('Closing idle connection %s, last active %d ms ago', conn_id, idle_ms) + self.close(node_id=conn_id) + + def bootstrap_connected(self): + """Return True if a bootstrap node is connected""" + for node_id in self._conns: + if not self.cluster.is_bootstrap(node_id): + continue + if self._conns[node_id].connected(): + return True + else: + return False + + +# OrderedDict requires python2.7+ +try: + from collections import OrderedDict +except ImportError: + # If we dont have OrderedDict, we'll fallback to dict with O(n) priority reads + OrderedDict = dict + + +class IdleConnectionManager(object): + def __init__(self, connections_max_idle_ms): + if connections_max_idle_ms > 0: + self.connections_max_idle = connections_max_idle_ms / 1000 + else: + self.connections_max_idle = float('inf') + self.next_idle_close_check_time = None + self.update_next_idle_close_check_time(time.time()) + self.lru_connections = OrderedDict() + + def update(self, conn_id): + # order should reflect last-update + if conn_id in self.lru_connections: + del self.lru_connections[conn_id] + self.lru_connections[conn_id] = time.time() + + def remove(self, conn_id): + if conn_id in self.lru_connections: + del 
self.lru_connections[conn_id] + + def is_expired(self, conn_id): + if conn_id not in self.lru_connections: + return None + return time.time() >= self.lru_connections[conn_id] + self.connections_max_idle + + def next_check_ms(self): + now = time.time() + if not self.lru_connections: + return float('inf') + elif self.next_idle_close_check_time <= now: + return 0 + else: + return int((self.next_idle_close_check_time - now) * 1000) + + def update_next_idle_close_check_time(self, ts): + self.next_idle_close_check_time = ts + self.connections_max_idle + + def poll_expired_connection(self): + if time.time() < self.next_idle_close_check_time: + return None + + if not len(self.lru_connections): + return None + + oldest_conn_id = None + oldest_ts = None + if OrderedDict is dict: + for conn_id, ts in self.lru_connections.items(): + if oldest_conn_id is None or ts < oldest_ts: + oldest_conn_id = conn_id + oldest_ts = ts + else: + (oldest_conn_id, oldest_ts) = next(iter(self.lru_connections.items())) + + self.update_next_idle_close_check_time(oldest_ts) + + if time.time() >= oldest_ts + self.connections_max_idle: + return (oldest_conn_id, oldest_ts) + else: + return None + + +class KafkaClientMetrics(object): + def __init__(self, metrics, metric_group_prefix, conns): + self.metrics = metrics + self.metric_group_name = metric_group_prefix + '-metrics' + + self.connection_closed = metrics.sensor('connections-closed') + self.connection_closed.add(metrics.metric_name( + 'connection-close-rate', self.metric_group_name, + 'Connections closed per second in the window.'), Rate()) + self.connection_created = metrics.sensor('connections-created') + self.connection_created.add(metrics.metric_name( + 'connection-creation-rate', self.metric_group_name, + 'New connections established per second in the window.'), Rate()) + + self.select_time = metrics.sensor('select-time') + self.select_time.add(metrics.metric_name( + 'select-rate', self.metric_group_name, + 'Number of times the I/O layer checked for new I/O to perform per' + ' second'), Rate(sampled_stat=Count())) + self.select_time.add(metrics.metric_name( + 'io-wait-time-ns-avg', self.metric_group_name, + 'The average length of time the I/O thread spent waiting for a' + ' socket ready for reads or writes in nanoseconds.'), Avg()) + self.select_time.add(metrics.metric_name( + 'io-wait-ratio', self.metric_group_name, + 'The fraction of time the I/O thread spent waiting.'), + Rate(time_unit=TimeUnit.NANOSECONDS)) + + self.io_time = metrics.sensor('io-time') + self.io_time.add(metrics.metric_name( + 'io-time-ns-avg', self.metric_group_name, + 'The average length of time for I/O per select call in nanoseconds.'), + Avg()) + self.io_time.add(metrics.metric_name( + 'io-ratio', self.metric_group_name, + 'The fraction of time the I/O thread spent doing I/O'), + Rate(time_unit=TimeUnit.NANOSECONDS)) + + metrics.add_metric(metrics.metric_name( + 'connection-count', self.metric_group_name, + 'The current number of active connections.'), AnonMeasurable( + lambda config, now: len(conns))) diff --git a/cluster.py b/cluster.py new file mode 100644 index 00000000..438baf29 --- /dev/null +++ b/cluster.py @@ -0,0 +1,397 @@ +from __future__ import absolute_import + +import collections +import copy +import logging +import threading +import time + +from kafka.vendor import six + +from kafka import errors as Errors +from kafka.conn import collect_hosts +from kafka.future import Future +from kafka.structs import BrokerMetadata, PartitionMetadata, TopicPartition + +log = 
logging.getLogger(__name__) + + +class ClusterMetadata(object): + """ + A class to manage kafka cluster metadata. + + This class does not perform any IO. It simply updates internal state + given API responses (MetadataResponse, GroupCoordinatorResponse). + + Keyword Arguments: + retry_backoff_ms (int): Milliseconds to backoff when retrying on + errors. Default: 100. + metadata_max_age_ms (int): The period of time in milliseconds after + which we force a refresh of metadata even if we haven't seen any + partition leadership changes to proactively discover any new + brokers or partitions. Default: 300000 + bootstrap_servers: 'host[:port]' string (or list of 'host[:port]' + strings) that the client should contact to bootstrap initial + cluster metadata. This does not have to be the full node list. + It just needs to have at least one broker that will respond to a + Metadata API Request. Default port is 9092. If no servers are + specified, will default to localhost:9092. + """ + DEFAULT_CONFIG = { + 'retry_backoff_ms': 100, + 'metadata_max_age_ms': 300000, + 'bootstrap_servers': [], + } + + def __init__(self, **configs): + self._brokers = {} # node_id -> BrokerMetadata + self._partitions = {} # topic -> partition -> PartitionMetadata + self._broker_partitions = collections.defaultdict(set) # node_id -> {TopicPartition...} + self._groups = {} # group_name -> node_id + self._last_refresh_ms = 0 + self._last_successful_refresh_ms = 0 + self._need_update = True + self._future = None + self._listeners = set() + self._lock = threading.Lock() + self.need_all_topic_metadata = False + self.unauthorized_topics = set() + self.internal_topics = set() + self.controller = None + + self.config = copy.copy(self.DEFAULT_CONFIG) + for key in self.config: + if key in configs: + self.config[key] = configs[key] + + self._bootstrap_brokers = self._generate_bootstrap_brokers() + self._coordinator_brokers = {} + + def _generate_bootstrap_brokers(self): + # collect_hosts does not perform DNS, so we should be fine to re-use + bootstrap_hosts = collect_hosts(self.config['bootstrap_servers']) + + brokers = {} + for i, (host, port, _) in enumerate(bootstrap_hosts): + node_id = 'bootstrap-%s' % i + brokers[node_id] = BrokerMetadata(node_id, host, port, None) + return brokers + + def is_bootstrap(self, node_id): + return node_id in self._bootstrap_brokers + + def brokers(self): + """Get all BrokerMetadata + + Returns: + set: {BrokerMetadata, ...} + """ + return set(self._brokers.values()) or set(self._bootstrap_brokers.values()) + + def broker_metadata(self, broker_id): + """Get BrokerMetadata + + Arguments: + broker_id (int): node_id for a broker to check + + Returns: + BrokerMetadata or None if not found + """ + return ( + self._brokers.get(broker_id) or + self._bootstrap_brokers.get(broker_id) or + self._coordinator_brokers.get(broker_id) + ) + + def partitions_for_topic(self, topic): + """Return set of all partitions for topic (whether available or not) + + Arguments: + topic (str): topic to check for partitions + + Returns: + set: {partition (int), ...} + """ + if topic not in self._partitions: + return None + return set(self._partitions[topic].keys()) + + def available_partitions_for_topic(self, topic): + """Return set of partitions with known leaders + + Arguments: + topic (str): topic to check for partitions + + Returns: + set: {partition (int), ...} + None if topic not found. 
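+
+        Example (illustrative sketch; ``cluster`` is assumed to be a populated
+        ClusterMetadata instance)::
+
+            cluster.available_partitions_for_topic('my-topic')
+            # e.g. {0, 1, 2}; returns None if the topic is unknown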
+ """ + if topic not in self._partitions: + return None + return set([partition for partition, metadata + in six.iteritems(self._partitions[topic]) + if metadata.leader != -1]) + + def leader_for_partition(self, partition): + """Return node_id of leader, -1 unavailable, None if unknown.""" + if partition.topic not in self._partitions: + return None + elif partition.partition not in self._partitions[partition.topic]: + return None + return self._partitions[partition.topic][partition.partition].leader + + def partitions_for_broker(self, broker_id): + """Return TopicPartitions for which the broker is a leader. + + Arguments: + broker_id (int): node id for a broker + + Returns: + set: {TopicPartition, ...} + None if the broker either has no partitions or does not exist. + """ + return self._broker_partitions.get(broker_id) + + def coordinator_for_group(self, group): + """Return node_id of group coordinator. + + Arguments: + group (str): name of consumer group + + Returns: + int: node_id for group coordinator + None if the group does not exist. + """ + return self._groups.get(group) + + def ttl(self): + """Milliseconds until metadata should be refreshed""" + now = time.time() * 1000 + if self._need_update: + ttl = 0 + else: + metadata_age = now - self._last_successful_refresh_ms + ttl = self.config['metadata_max_age_ms'] - metadata_age + + retry_age = now - self._last_refresh_ms + next_retry = self.config['retry_backoff_ms'] - retry_age + + return max(ttl, next_retry, 0) + + def refresh_backoff(self): + """Return milliseconds to wait before attempting to retry after failure""" + return self.config['retry_backoff_ms'] + + def request_update(self): + """Flags metadata for update, return Future() + + Actual update must be handled separately. This method will only + change the reported ttl() + + Returns: + kafka.future.Future (value will be the cluster object after update) + """ + with self._lock: + self._need_update = True + if not self._future or self._future.is_done: + self._future = Future() + return self._future + + def topics(self, exclude_internal_topics=True): + """Get set of known topics. + + Arguments: + exclude_internal_topics (bool): Whether records from internal topics + (such as offsets) should be exposed to the consumer. If set to + True the only way to receive records from an internal topic is + subscribing to it. Default True + + Returns: + set: {topic (str), ...} + """ + topics = set(self._partitions.keys()) + if exclude_internal_topics: + return topics - self.internal_topics + else: + return topics + + def failed_update(self, exception): + """Update cluster state given a failed MetadataRequest.""" + f = None + with self._lock: + if self._future: + f = self._future + self._future = None + if f: + f.failure(exception) + self._last_refresh_ms = time.time() * 1000 + + def update_metadata(self, metadata): + """Update cluster state given a MetadataResponse. 
+ + Arguments: + metadata (MetadataResponse): broker response to a metadata request + + Returns: None + """ + # In the common case where we ask for a single topic and get back an + # error, we should fail the future + if len(metadata.topics) == 1 and metadata.topics[0][0] != 0: + error_code, topic = metadata.topics[0][:2] + error = Errors.for_code(error_code)(topic) + return self.failed_update(error) + + if not metadata.brokers: + log.warning("No broker metadata found in MetadataResponse -- ignoring.") + return self.failed_update(Errors.MetadataEmptyBrokerList(metadata)) + + _new_brokers = {} + for broker in metadata.brokers: + if metadata.API_VERSION == 0: + node_id, host, port = broker + rack = None + else: + node_id, host, port, rack = broker + _new_brokers.update({ + node_id: BrokerMetadata(node_id, host, port, rack) + }) + + if metadata.API_VERSION == 0: + _new_controller = None + else: + _new_controller = _new_brokers.get(metadata.controller_id) + + _new_partitions = {} + _new_broker_partitions = collections.defaultdict(set) + _new_unauthorized_topics = set() + _new_internal_topics = set() + + for topic_data in metadata.topics: + if metadata.API_VERSION == 0: + error_code, topic, partitions = topic_data + is_internal = False + else: + error_code, topic, is_internal, partitions = topic_data + if is_internal: + _new_internal_topics.add(topic) + error_type = Errors.for_code(error_code) + if error_type is Errors.NoError: + _new_partitions[topic] = {} + for p_error, partition, leader, replicas, isr in partitions: + _new_partitions[topic][partition] = PartitionMetadata( + topic=topic, partition=partition, leader=leader, + replicas=replicas, isr=isr, error=p_error) + if leader != -1: + _new_broker_partitions[leader].add( + TopicPartition(topic, partition)) + + # Specific topic errors can be ignored if this is a full metadata fetch + elif self.need_all_topic_metadata: + continue + + elif error_type is Errors.LeaderNotAvailableError: + log.warning("Topic %s is not available during auto-create" + " initialization", topic) + elif error_type is Errors.UnknownTopicOrPartitionError: + log.error("Topic %s not found in cluster metadata", topic) + elif error_type is Errors.TopicAuthorizationFailedError: + log.error("Topic %s is not authorized for this client", topic) + _new_unauthorized_topics.add(topic) + elif error_type is Errors.InvalidTopicError: + log.error("'%s' is not a valid topic name", topic) + else: + log.error("Error fetching metadata for topic %s: %s", + topic, error_type) + + with self._lock: + self._brokers = _new_brokers + self.controller = _new_controller + self._partitions = _new_partitions + self._broker_partitions = _new_broker_partitions + self.unauthorized_topics = _new_unauthorized_topics + self.internal_topics = _new_internal_topics + f = None + if self._future: + f = self._future + self._future = None + self._need_update = False + + now = time.time() * 1000 + self._last_refresh_ms = now + self._last_successful_refresh_ms = now + + if f: + f.success(self) + log.debug("Updated cluster metadata to %s", self) + + for listener in self._listeners: + listener(self) + + if self.need_all_topic_metadata: + # the listener may change the interested topics, + # which could cause another metadata refresh. + # If we have already fetched all topics, however, + # another fetch should be unnecessary. 
+ self._need_update = False + + def add_listener(self, listener): + """Add a callback function to be called on each metadata update""" + self._listeners.add(listener) + + def remove_listener(self, listener): + """Remove a previously added listener callback""" + self._listeners.remove(listener) + + def add_group_coordinator(self, group, response): + """Update with metadata for a group coordinator + + Arguments: + group (str): name of group from GroupCoordinatorRequest + response (GroupCoordinatorResponse): broker response + + Returns: + string: coordinator node_id if metadata is updated, None on error + """ + log.debug("Updating coordinator for %s: %s", group, response) + error_type = Errors.for_code(response.error_code) + if error_type is not Errors.NoError: + log.error("GroupCoordinatorResponse error: %s", error_type) + self._groups[group] = -1 + return + + # Use a coordinator-specific node id so that group requests + # get a dedicated connection + node_id = 'coordinator-{}'.format(response.coordinator_id) + coordinator = BrokerMetadata( + node_id, + response.host, + response.port, + None) + + log.info("Group coordinator for %s is %s", group, coordinator) + self._coordinator_brokers[node_id] = coordinator + self._groups[group] = node_id + return node_id + + def with_partitions(self, partitions_to_add): + """Returns a copy of cluster metadata with partitions added""" + new_metadata = ClusterMetadata(**self.config) + new_metadata._brokers = copy.deepcopy(self._brokers) + new_metadata._partitions = copy.deepcopy(self._partitions) + new_metadata._broker_partitions = copy.deepcopy(self._broker_partitions) + new_metadata._groups = copy.deepcopy(self._groups) + new_metadata.internal_topics = copy.deepcopy(self.internal_topics) + new_metadata.unauthorized_topics = copy.deepcopy(self.unauthorized_topics) + + for partition in partitions_to_add: + new_metadata._partitions[partition.topic][partition.partition] = partition + + if partition.leader is not None and partition.leader != -1: + new_metadata._broker_partitions[partition.leader].add( + TopicPartition(partition.topic, partition.partition)) + + return new_metadata + + def __str__(self): + return 'ClusterMetadata(brokers: %d, topics: %d, groups: %d)' % \ + (len(self._brokers), len(self._partitions), len(self._groups)) diff --git a/codec.py b/codec.py new file mode 100644 index 00000000..c740a181 --- /dev/null +++ b/codec.py @@ -0,0 +1,326 @@ +from __future__ import absolute_import + +import gzip +import io +import platform +import struct + +from kafka.vendor import six +from kafka.vendor.six.moves import range + +_XERIAL_V1_HEADER = (-126, b'S', b'N', b'A', b'P', b'P', b'Y', 0, 1, 1) +_XERIAL_V1_FORMAT = 'bccccccBii' +ZSTD_MAX_OUTPUT_SIZE = 1024 * 1024 + +try: + import snappy +except ImportError: + snappy = None + +try: + import zstandard as zstd +except ImportError: + zstd = None + +try: + import lz4.frame as lz4 + + def _lz4_compress(payload, **kwargs): + # Kafka does not support LZ4 dependent blocks + try: + # For lz4>=0.12.0 + kwargs.pop('block_linked', None) + return lz4.compress(payload, block_linked=False, **kwargs) + except TypeError: + # For earlier versions of lz4 + kwargs.pop('block_mode', None) + return lz4.compress(payload, block_mode=1, **kwargs) + +except ImportError: + lz4 = None + +try: + import lz4f +except ImportError: + lz4f = None + +try: + import lz4framed +except ImportError: + lz4framed = None + +try: + import xxhash +except ImportError: + xxhash = None + +PYPY = bool(platform.python_implementation() == 'PyPy') + +def 
has_gzip(): + return True + + +def has_snappy(): + return snappy is not None + + +def has_zstd(): + return zstd is not None + + +def has_lz4(): + if lz4 is not None: + return True + if lz4f is not None: + return True + if lz4framed is not None: + return True + return False + + +def gzip_encode(payload, compresslevel=None): + if not compresslevel: + compresslevel = 9 + + buf = io.BytesIO() + + # Gzip context manager introduced in python 2.7 + # so old-fashioned way until we decide to not support 2.6 + gzipper = gzip.GzipFile(fileobj=buf, mode="w", compresslevel=compresslevel) + try: + gzipper.write(payload) + finally: + gzipper.close() + + return buf.getvalue() + + +def gzip_decode(payload): + buf = io.BytesIO(payload) + + # Gzip context manager introduced in python 2.7 + # so old-fashioned way until we decide to not support 2.6 + gzipper = gzip.GzipFile(fileobj=buf, mode='r') + try: + return gzipper.read() + finally: + gzipper.close() + + +def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32*1024): + """Encodes the given data with snappy compression. + + If xerial_compatible is set then the stream is encoded in a fashion + compatible with the xerial snappy library. + + The block size (xerial_blocksize) controls how frequent the blocking occurs + 32k is the default in the xerial library. + + The format winds up being: + + + +-------------+------------+--------------+------------+--------------+ + | Header | Block1 len | Block1 data | Blockn len | Blockn data | + +-------------+------------+--------------+------------+--------------+ + | 16 bytes | BE int32 | snappy bytes | BE int32 | snappy bytes | + +-------------+------------+--------------+------------+--------------+ + + + It is important to note that the blocksize is the amount of uncompressed + data presented to snappy at each block, whereas the blocklen is the number + of bytes that will be present in the stream; so the length will always be + <= blocksize. + + """ + + if not has_snappy(): + raise NotImplementedError("Snappy codec is not available") + + if not xerial_compatible: + return snappy.compress(payload) + + out = io.BytesIO() + for fmt, dat in zip(_XERIAL_V1_FORMAT, _XERIAL_V1_HEADER): + out.write(struct.pack('!' + fmt, dat)) + + # Chunk through buffers to avoid creating intermediate slice copies + if PYPY: + # on pypy, snappy.compress() on a sliced buffer consumes the entire + # buffer... likely a python-snappy bug, so just use a slice copy + chunker = lambda payload, i, size: payload[i:size+i] + + elif six.PY2: + # Sliced buffer avoids additional copies + # pylint: disable-msg=undefined-variable + chunker = lambda payload, i, size: buffer(payload, i, size) + else: + # snappy.compress does not like raw memoryviews, so we have to convert + # tobytes, which is a copy... oh well. it's the thought that counts. + # pylint: disable-msg=undefined-variable + chunker = lambda payload, i, size: memoryview(payload)[i:size+i].tobytes() + + for chunk in (chunker(payload, i, xerial_blocksize) + for i in range(0, len(payload), xerial_blocksize)): + + block = snappy.compress(chunk) + block_size = len(block) + out.write(struct.pack('!i', block_size)) + out.write(block) + + return out.getvalue() + + +def _detect_xerial_stream(payload): + """Detects if the data given might have been encoded with the blocking mode + of the xerial snappy library. 
+ + This mode writes a magic header of the format: + +--------+--------------+------------+---------+--------+ + | Marker | Magic String | Null / Pad | Version | Compat | + +--------+--------------+------------+---------+--------+ + | byte | c-string | byte | int32 | int32 | + +--------+--------------+------------+---------+--------+ + | -126 | 'SNAPPY' | \0 | | | + +--------+--------------+------------+---------+--------+ + + The pad appears to be to ensure that SNAPPY is a valid cstring + The version is the version of this format as written by xerial, + in the wild this is currently 1 as such we only support v1. + + Compat is there to claim the minimum supported version that + can read a xerial block stream, presently in the wild this is + 1. + """ + + if len(payload) > 16: + header = struct.unpack('!' + _XERIAL_V1_FORMAT, bytes(payload)[:16]) + return header == _XERIAL_V1_HEADER + return False + + +def snappy_decode(payload): + if not has_snappy(): + raise NotImplementedError("Snappy codec is not available") + + if _detect_xerial_stream(payload): + # TODO ? Should become a fileobj ? + out = io.BytesIO() + byt = payload[16:] + length = len(byt) + cursor = 0 + + while cursor < length: + block_size = struct.unpack_from('!i', byt[cursor:])[0] + # Skip the block size + cursor += 4 + end = cursor + block_size + out.write(snappy.decompress(byt[cursor:end])) + cursor = end + + out.seek(0) + return out.read() + else: + return snappy.decompress(payload) + + +if lz4: + lz4_encode = _lz4_compress # pylint: disable-msg=no-member +elif lz4f: + lz4_encode = lz4f.compressFrame # pylint: disable-msg=no-member +elif lz4framed: + lz4_encode = lz4framed.compress # pylint: disable-msg=no-member +else: + lz4_encode = None + + +def lz4f_decode(payload): + """Decode payload using interoperable LZ4 framing. Requires Kafka >= 0.10""" + # pylint: disable-msg=no-member + ctx = lz4f.createDecompContext() + data = lz4f.decompressFrame(payload, ctx) + lz4f.freeDecompContext(ctx) + + # lz4f python module does not expose how much of the payload was + # actually read if the decompression was only partial. 
+ if data['next'] != 0: + raise RuntimeError('lz4f unable to decompress full payload') + return data['decomp'] + + +if lz4: + lz4_decode = lz4.decompress # pylint: disable-msg=no-member +elif lz4f: + lz4_decode = lz4f_decode +elif lz4framed: + lz4_decode = lz4framed.decompress # pylint: disable-msg=no-member +else: + lz4_decode = None + + +def lz4_encode_old_kafka(payload): + """Encode payload for 0.8/0.9 brokers -- requires an incorrect header checksum.""" + assert xxhash is not None + data = lz4_encode(payload) + header_size = 7 + flg = data[4] + if not isinstance(flg, int): + flg = ord(flg) + + content_size_bit = ((flg >> 3) & 1) + if content_size_bit: + # Old kafka does not accept the content-size field + # so we need to discard it and reset the header flag + flg -= 8 + data = bytearray(data) + data[4] = flg + data = bytes(data) + payload = data[header_size+8:] + else: + payload = data[header_size:] + + # This is the incorrect hc + hc = xxhash.xxh32(data[0:header_size-1]).digest()[-2:-1] # pylint: disable-msg=no-member + + return b''.join([ + data[0:header_size-1], + hc, + payload + ]) + + +def lz4_decode_old_kafka(payload): + assert xxhash is not None + # Kafka's LZ4 code has a bug in its header checksum implementation + header_size = 7 + if isinstance(payload[4], int): + flg = payload[4] + else: + flg = ord(payload[4]) + content_size_bit = ((flg >> 3) & 1) + if content_size_bit: + header_size += 8 + + # This should be the correct hc + hc = xxhash.xxh32(payload[4:header_size-1]).digest()[-2:-1] # pylint: disable-msg=no-member + + munged_payload = b''.join([ + payload[0:header_size-1], + hc, + payload[header_size:] + ]) + return lz4_decode(munged_payload) + + +def zstd_encode(payload): + if not zstd: + raise NotImplementedError("Zstd codec is not available") + return zstd.ZstdCompressor().compress(payload) + + +def zstd_decode(payload): + if not zstd: + raise NotImplementedError("Zstd codec is not available") + try: + return zstd.ZstdDecompressor().decompress(payload) + except zstd.ZstdError: + return zstd.ZstdDecompressor().decompress(payload, max_output_size=ZSTD_MAX_OUTPUT_SIZE) diff --git a/conn.py b/conn.py new file mode 100644 index 00000000..cac35487 --- /dev/null +++ b/conn.py @@ -0,0 +1,1534 @@ +from __future__ import absolute_import, division + +import copy +import errno +import io +import logging +from random import shuffle, uniform + +# selectors in stdlib as of py3.4 +try: + import selectors # pylint: disable=import-error +except ImportError: + # vendored backport module + from kafka.vendor import selectors34 as selectors + +import socket +import struct +import threading +import time + +from kafka.vendor import six + +import kafka.errors as Errors +from kafka.future import Future +from kafka.metrics.stats import Avg, Count, Max, Rate +from kafka.oauth.abstract import AbstractTokenProvider +from kafka.protocol.admin import SaslHandShakeRequest, DescribeAclsRequest_v2, DescribeClientQuotasRequest +from kafka.protocol.commit import OffsetFetchRequest +from kafka.protocol.offset import OffsetRequest +from kafka.protocol.produce import ProduceRequest +from kafka.protocol.metadata import MetadataRequest +from kafka.protocol.fetch import FetchRequest +from kafka.protocol.parser import KafkaProtocol +from kafka.protocol.types import Int32, Int8 +from kafka.scram import ScramClient +from kafka.version import __version__ + + +if six.PY2: + ConnectionError = socket.error + TimeoutError = socket.error + BlockingIOError = Exception + +log = logging.getLogger(__name__) + 
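+# Illustrative sketch (hypothetical helper, not part of the upstream module and
+# not referenced elsewhere in this file): with the PY2 aliases above, transport
+# errors can be caught with a single except clause on both Python 2 and Python 3.
+def _example_best_effort_send(sock, payload):
+    # Attempt a send on a non-blocking socket; on PY2 socket.error is caught via
+    # the ConnectionError/TimeoutError aliases, on PY3 the builtins are used.
+    try:
+        return sock.send(payload)
+    except (ConnectionError, TimeoutError, BlockingIOError):
+        return 0
+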
+DEFAULT_KAFKA_PORT = 9092
+
+SASL_QOP_AUTH = 1
+SASL_QOP_AUTH_INT = 2
+SASL_QOP_AUTH_CONF = 4
+
+try:
+    import ssl
+    ssl_available = True
+    try:
+        SSLEOFError = ssl.SSLEOFError
+        SSLWantReadError = ssl.SSLWantReadError
+        SSLWantWriteError = ssl.SSLWantWriteError
+        SSLZeroReturnError = ssl.SSLZeroReturnError
+    except AttributeError:
+        # support older ssl libraries
+        log.warning('Old SSL module detected.'
+                    ' SSL error handling may not operate cleanly.'
+                    ' Consider upgrading to Python 3.3 or 2.7.9')
+        SSLEOFError = ssl.SSLError
+        SSLWantReadError = ssl.SSLError
+        SSLWantWriteError = ssl.SSLError
+        SSLZeroReturnError = ssl.SSLError
+except ImportError:
+    # support Python without ssl libraries
+    ssl_available = False
+    class SSLWantReadError(Exception):
+        pass
+    class SSLWantWriteError(Exception):
+        pass
+
+# needed for SASL_GSSAPI authentication:
+try:
+    import gssapi
+    from gssapi.raw.misc import GSSError
+except ImportError:
+    # no gssapi available, will disable gssapi mechanism
+    gssapi = None
+    GSSError = None
+
+
+AFI_NAMES = {
+    socket.AF_UNSPEC: "unspecified",
+    socket.AF_INET: "IPv4",
+    socket.AF_INET6: "IPv6",
+}
+
+
+class ConnectionStates(object):
+    DISCONNECTING = '<disconnecting>'
+    DISCONNECTED = '<disconnected>'
+    CONNECTING = '<connecting>'
+    HANDSHAKE = '<handshake>'
+    CONNECTED = '<connected>'
+    AUTHENTICATING = '<authenticating>'
+
+
+class BrokerConnection(object):
+    """Initialize a Kafka broker connection
+
+    Keyword Arguments:
+        client_id (str): a name for this client. This string is passed in
+            each request to servers and can be used to identify specific
+            server-side log entries that correspond to this client. Also
+            submitted to GroupCoordinator for logging with respect to
+            consumer group administration. Default: 'kafka-python-{version}'
+        reconnect_backoff_ms (int): The amount of time in milliseconds to
+            wait before attempting to reconnect to a given host.
+            Default: 50.
+        reconnect_backoff_max_ms (int): The maximum amount of time in
+            milliseconds to backoff/wait when reconnecting to a broker that has
+            repeatedly failed to connect. If provided, the backoff per host
+            will increase exponentially for each consecutive connection
+            failure, up to this maximum. Once the maximum is reached,
+            reconnection attempts will continue periodically with this fixed
+            rate. To avoid connection storms, a randomization factor of 0.2
+            will be applied to the backoff resulting in a random range between
+            20% below and 20% above the computed value. Default: 1000.
+        request_timeout_ms (int): Client request timeout in milliseconds.
+            Default: 30000.
+        max_in_flight_requests_per_connection (int): Requests are pipelined
+            to kafka brokers up to this number of maximum requests per
+            broker connection. Default: 5.
+        receive_buffer_bytes (int): The size of the TCP receive buffer
+            (SO_RCVBUF) to use when reading data. Default: None (relies on
+            system defaults). Java client defaults to 32768.
+        send_buffer_bytes (int): The size of the TCP send buffer
+            (SO_SNDBUF) to use when sending data. Default: None (relies on
+            system defaults). Java client defaults to 131072.
+        socket_options (list): List of tuple-arguments to socket.setsockopt
+            to apply to broker connection sockets. Default:
+            [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]
+        security_protocol (str): Protocol used to communicate with brokers.
+            Valid values are: PLAINTEXT, SSL, SASL_PLAINTEXT, SASL_SSL.
+            Default: PLAINTEXT.
+        ssl_context (ssl.SSLContext): pre-configured SSLContext for wrapping
+            socket connections. If provided, all other ssl_* configurations
+            will be ignored. Default: None.
+ ssl_check_hostname (bool): flag to configure whether ssl handshake + should verify that the certificate matches the brokers hostname. + default: True. + ssl_cafile (str): optional filename of ca file to use in certificate + verification. default: None. + ssl_certfile (str): optional filename of file in pem format containing + the client certificate, as well as any ca certificates needed to + establish the certificate's authenticity. default: None. + ssl_keyfile (str): optional filename containing the client private key. + default: None. + ssl_password (callable, str, bytes, bytearray): optional password or + callable function that returns a password, for decrypting the + client private key. Default: None. + ssl_crlfile (str): optional filename containing the CRL to check for + certificate expiration. By default, no CRL check is done. When + providing a file, only the leaf certificate will be checked against + this CRL. The CRL can only be checked with Python 3.4+ or 2.7.9+. + default: None. + ssl_ciphers (str): optionally set the available ciphers for ssl + connections. It should be a string in the OpenSSL cipher list + format. If no cipher can be selected (because compile-time options + or other configuration forbids use of all the specified ciphers), + an ssl.SSLError will be raised. See ssl.SSLContext.set_ciphers + api_version (tuple): Specify which Kafka API version to use. + Accepted values are: (0, 8, 0), (0, 8, 1), (0, 8, 2), (0, 9), + (0, 10). Default: (0, 8, 2) + api_version_auto_timeout_ms (int): number of milliseconds to throw a + timeout exception from the constructor when checking the broker + api version. Only applies if api_version is None + selector (selectors.BaseSelector): Provide a specific selector + implementation to use for I/O multiplexing. + Default: selectors.DefaultSelector + state_change_callback (callable): function to be called when the + connection state changes from CONNECTING to CONNECTED etc. + metrics (kafka.metrics.Metrics): Optionally provide a metrics + instance for capturing network IO stats. Default: None. + metric_group_prefix (str): Prefix for metric names. Default: '' + sasl_mechanism (str): Authentication mechanism when security_protocol + is configured for SASL_PLAINTEXT or SASL_SSL. Valid values are: + PLAIN, GSSAPI, OAUTHBEARER, SCRAM-SHA-256, SCRAM-SHA-512. + sasl_plain_username (str): username for sasl PLAIN and SCRAM authentication. + Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. + sasl_plain_password (str): password for sasl PLAIN and SCRAM authentication. + Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. + sasl_kerberos_service_name (str): Service name to include in GSSAPI + sasl mechanism handshake. Default: 'kafka' + sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI + sasl mechanism handshake. Default: one of bootstrap servers + sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider + instance. (See kafka.oauth.abstract). 
Default: None + """ + + DEFAULT_CONFIG = { + 'client_id': 'kafka-python-' + __version__, + 'node_id': 0, + 'request_timeout_ms': 30000, + 'reconnect_backoff_ms': 50, + 'reconnect_backoff_max_ms': 1000, + 'max_in_flight_requests_per_connection': 5, + 'receive_buffer_bytes': None, + 'send_buffer_bytes': None, + 'socket_options': [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)], + 'sock_chunk_bytes': 4096, # undocumented experimental option + 'sock_chunk_buffer_count': 1000, # undocumented experimental option + 'security_protocol': 'PLAINTEXT', + 'ssl_context': None, + 'ssl_check_hostname': True, + 'ssl_cafile': None, + 'ssl_certfile': None, + 'ssl_keyfile': None, + 'ssl_crlfile': None, + 'ssl_password': None, + 'ssl_ciphers': None, + 'api_version': (0, 8, 2), # default to most restrictive + 'selector': selectors.DefaultSelector, + 'state_change_callback': lambda node_id, sock, conn: True, + 'metrics': None, + 'metric_group_prefix': '', + 'sasl_mechanism': None, + 'sasl_plain_username': None, + 'sasl_plain_password': None, + 'sasl_kerberos_service_name': 'kafka', + 'sasl_kerberos_domain_name': None, + 'sasl_oauth_token_provider': None + } + SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL') + SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', "SCRAM-SHA-256", "SCRAM-SHA-512") + + def __init__(self, host, port, afi, **configs): + self.host = host + self.port = port + self.afi = afi + self._sock_afi = afi + self._sock_addr = None + self._api_versions = None + + self.config = copy.copy(self.DEFAULT_CONFIG) + for key in self.config: + if key in configs: + self.config[key] = configs[key] + + self.node_id = self.config.pop('node_id') + + if self.config['receive_buffer_bytes'] is not None: + self.config['socket_options'].append( + (socket.SOL_SOCKET, socket.SO_RCVBUF, + self.config['receive_buffer_bytes'])) + if self.config['send_buffer_bytes'] is not None: + self.config['socket_options'].append( + (socket.SOL_SOCKET, socket.SO_SNDBUF, + self.config['send_buffer_bytes'])) + + assert self.config['security_protocol'] in self.SECURITY_PROTOCOLS, ( + 'security_protocol must be in ' + ', '.join(self.SECURITY_PROTOCOLS)) + + if self.config['security_protocol'] in ('SSL', 'SASL_SSL'): + assert ssl_available, "Python wasn't built with SSL support" + + if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'): + assert self.config['sasl_mechanism'] in self.SASL_MECHANISMS, ( + 'sasl_mechanism must be in ' + ', '.join(self.SASL_MECHANISMS)) + if self.config['sasl_mechanism'] in ('PLAIN', 'SCRAM-SHA-256', 'SCRAM-SHA-512'): + assert self.config['sasl_plain_username'] is not None, ( + 'sasl_plain_username required for PLAIN or SCRAM sasl' + ) + assert self.config['sasl_plain_password'] is not None, ( + 'sasl_plain_password required for PLAIN or SCRAM sasl' + ) + if self.config['sasl_mechanism'] == 'GSSAPI': + assert gssapi is not None, 'GSSAPI lib not available' + assert self.config['sasl_kerberos_service_name'] is not None, 'sasl_kerberos_service_name required for GSSAPI sasl' + if self.config['sasl_mechanism'] == 'OAUTHBEARER': + token_provider = self.config['sasl_oauth_token_provider'] + assert token_provider is not None, 'sasl_oauth_token_provider required for OAUTHBEARER sasl' + assert callable(getattr(token_provider, "token", None)), 'sasl_oauth_token_provider must implement method #token()' + # This is not a general lock / this class is not generally thread-safe yet + # However, to avoid pushing responsibility for maintaining + # per-connection locks to the upstream client, 
we will use this lock to + # make sure that access to the protocol buffer is synchronized + # when sends happen on multiple threads + self._lock = threading.Lock() + + # the protocol parser instance manages actual tracking of the + # sequence of in-flight requests to responses, which should + # function like a FIFO queue. For additional request data, + # including tracking request futures and timestamps, we + # can use a simple dictionary of correlation_id => request data + self.in_flight_requests = dict() + + self._protocol = KafkaProtocol( + client_id=self.config['client_id'], + api_version=self.config['api_version']) + self.state = ConnectionStates.DISCONNECTED + self._reset_reconnect_backoff() + self._sock = None + self._send_buffer = b'' + self._ssl_context = None + if self.config['ssl_context'] is not None: + self._ssl_context = self.config['ssl_context'] + self._sasl_auth_future = None + self.last_attempt = 0 + self._gai = [] + self._sensors = None + if self.config['metrics']: + self._sensors = BrokerConnectionMetrics(self.config['metrics'], + self.config['metric_group_prefix'], + self.node_id) + + def _dns_lookup(self): + self._gai = dns_lookup(self.host, self.port, self.afi) + if not self._gai: + log.error('DNS lookup failed for %s:%i (%s)', + self.host, self.port, self.afi) + return False + return True + + def _next_afi_sockaddr(self): + if not self._gai: + if not self._dns_lookup(): + return + afi, _, __, ___, sockaddr = self._gai.pop(0) + return (afi, sockaddr) + + def connect_blocking(self, timeout=float('inf')): + if self.connected(): + return True + timeout += time.time() + # First attempt to perform dns lookup + # note that the underlying interface, socket.getaddrinfo, + # has no explicit timeout so we may exceed the user-specified timeout + self._dns_lookup() + + # Loop once over all returned dns entries + selector = None + while self._gai: + while time.time() < timeout: + self.connect() + if self.connected(): + if selector is not None: + selector.close() + return True + elif self.connecting(): + if selector is None: + selector = self.config['selector']() + selector.register(self._sock, selectors.EVENT_WRITE) + selector.select(1) + elif self.disconnected(): + if selector is not None: + selector.close() + selector = None + break + else: + break + return False + + def connect(self): + """Attempt to connect and return ConnectionState""" + if self.state is ConnectionStates.DISCONNECTED and not self.blacked_out(): + self.last_attempt = time.time() + next_lookup = self._next_afi_sockaddr() + if not next_lookup: + self.close(Errors.KafkaConnectionError('DNS failure')) + return self.state + else: + log.debug('%s: creating new socket', self) + assert self._sock is None + self._sock_afi, self._sock_addr = next_lookup + self._sock = socket.socket(self._sock_afi, socket.SOCK_STREAM) + + for option in self.config['socket_options']: + log.debug('%s: setting socket option %s', self, option) + self._sock.setsockopt(*option) + + self._sock.setblocking(False) + self.state = ConnectionStates.CONNECTING + self.config['state_change_callback'](self.node_id, self._sock, self) + log.info('%s: connecting to %s:%d [%s %s]', self, self.host, + self.port, self._sock_addr, AFI_NAMES[self._sock_afi]) + + if self.state is ConnectionStates.CONNECTING: + # in non-blocking mode, use repeated calls to socket.connect_ex + # to check connection status + ret = None + try: + ret = self._sock.connect_ex(self._sock_addr) + except socket.error as err: + ret = err.errno + + # Connection succeeded + if not ret or 
ret == errno.EISCONN: + log.debug('%s: established TCP connection', self) + + if self.config['security_protocol'] in ('SSL', 'SASL_SSL'): + log.debug('%s: initiating SSL handshake', self) + self.state = ConnectionStates.HANDSHAKE + self.config['state_change_callback'](self.node_id, self._sock, self) + # _wrap_ssl can alter the connection state -- disconnects on failure + self._wrap_ssl() + + elif self.config['security_protocol'] == 'SASL_PLAINTEXT': + log.debug('%s: initiating SASL authentication', self) + self.state = ConnectionStates.AUTHENTICATING + self.config['state_change_callback'](self.node_id, self._sock, self) + + else: + # security_protocol PLAINTEXT + log.info('%s: Connection complete.', self) + self.state = ConnectionStates.CONNECTED + self._reset_reconnect_backoff() + self.config['state_change_callback'](self.node_id, self._sock, self) + + # Connection failed + # WSAEINVAL == 10022, but errno.WSAEINVAL is not available on non-win systems + elif ret not in (errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK, 10022): + log.error('Connect attempt to %s returned error %s.' + ' Disconnecting.', self, ret) + errstr = errno.errorcode.get(ret, 'UNKNOWN') + self.close(Errors.KafkaConnectionError('{} {}'.format(ret, errstr))) + return self.state + + # Needs retry + else: + pass + + if self.state is ConnectionStates.HANDSHAKE: + if self._try_handshake(): + log.debug('%s: completed SSL handshake.', self) + if self.config['security_protocol'] == 'SASL_SSL': + log.debug('%s: initiating SASL authentication', self) + self.state = ConnectionStates.AUTHENTICATING + else: + log.info('%s: Connection complete.', self) + self.state = ConnectionStates.CONNECTED + self._reset_reconnect_backoff() + self.config['state_change_callback'](self.node_id, self._sock, self) + + if self.state is ConnectionStates.AUTHENTICATING: + assert self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL') + if self._try_authenticate(): + # _try_authenticate has side-effects: possibly disconnected on socket errors + if self.state is ConnectionStates.AUTHENTICATING: + log.info('%s: Connection complete.', self) + self.state = ConnectionStates.CONNECTED + self._reset_reconnect_backoff() + self.config['state_change_callback'](self.node_id, self._sock, self) + + if self.state not in (ConnectionStates.CONNECTED, + ConnectionStates.DISCONNECTED): + # Connection timed out + request_timeout = self.config['request_timeout_ms'] / 1000.0 + if time.time() > request_timeout + self.last_attempt: + log.error('Connection attempt to %s timed out', self) + self.close(Errors.KafkaConnectionError('timeout')) + return self.state + + return self.state + + def _wrap_ssl(self): + assert self.config['security_protocol'] in ('SSL', 'SASL_SSL') + if self._ssl_context is None: + log.debug('%s: configuring default SSL Context', self) + self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) # pylint: disable=no-member + self._ssl_context.options |= ssl.OP_NO_SSLv2 # pylint: disable=no-member + self._ssl_context.options |= ssl.OP_NO_SSLv3 # pylint: disable=no-member + self._ssl_context.verify_mode = ssl.CERT_OPTIONAL + if self.config['ssl_check_hostname']: + self._ssl_context.check_hostname = True + if self.config['ssl_cafile']: + log.info('%s: Loading SSL CA from %s', self, self.config['ssl_cafile']) + self._ssl_context.load_verify_locations(self.config['ssl_cafile']) + self._ssl_context.verify_mode = ssl.CERT_REQUIRED + else: + log.info('%s: Loading system default SSL CAs from %s', self, ssl.get_default_verify_paths()) + 
self._ssl_context.load_default_certs() + if self.config['ssl_certfile'] and self.config['ssl_keyfile']: + log.info('%s: Loading SSL Cert from %s', self, self.config['ssl_certfile']) + log.info('%s: Loading SSL Key from %s', self, self.config['ssl_keyfile']) + self._ssl_context.load_cert_chain( + certfile=self.config['ssl_certfile'], + keyfile=self.config['ssl_keyfile'], + password=self.config['ssl_password']) + if self.config['ssl_crlfile']: + if not hasattr(ssl, 'VERIFY_CRL_CHECK_LEAF'): + raise RuntimeError('This version of Python does not support ssl_crlfile!') + log.info('%s: Loading SSL CRL from %s', self, self.config['ssl_crlfile']) + self._ssl_context.load_verify_locations(self.config['ssl_crlfile']) + # pylint: disable=no-member + self._ssl_context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF + if self.config['ssl_ciphers']: + log.info('%s: Setting SSL Ciphers: %s', self, self.config['ssl_ciphers']) + self._ssl_context.set_ciphers(self.config['ssl_ciphers']) + log.debug('%s: wrapping socket in ssl context', self) + try: + self._sock = self._ssl_context.wrap_socket( + self._sock, + server_hostname=self.host, + do_handshake_on_connect=False) + except ssl.SSLError as e: + log.exception('%s: Failed to wrap socket in SSLContext!', self) + self.close(e) + + def _try_handshake(self): + assert self.config['security_protocol'] in ('SSL', 'SASL_SSL') + try: + self._sock.do_handshake() + return True + # old ssl in python2.6 will swallow all SSLErrors here... + except (SSLWantReadError, SSLWantWriteError): + pass + except (SSLZeroReturnError, ConnectionError, TimeoutError, SSLEOFError): + log.warning('SSL connection closed by server during handshake.') + self.close(Errors.KafkaConnectionError('SSL connection closed by server during handshake')) + # Other SSLErrors will be raised to user + + return False + + def _try_authenticate(self): + assert self.config['api_version'] is None or self.config['api_version'] >= (0, 10) + + if self._sasl_auth_future is None: + # Build a SaslHandShakeRequest message + request = SaslHandShakeRequest[0](self.config['sasl_mechanism']) + future = Future() + sasl_response = self._send(request) + sasl_response.add_callback(self._handle_sasl_handshake_response, future) + sasl_response.add_errback(lambda f, e: f.failure(e), future) + self._sasl_auth_future = future + + for r, f in self.recv(): + f.success(r) + + # A connection error could trigger close() which will reset the future + if self._sasl_auth_future is None: + return False + elif self._sasl_auth_future.failed(): + ex = self._sasl_auth_future.exception + if not isinstance(ex, Errors.KafkaConnectionError): + raise ex # pylint: disable-msg=raising-bad-type + return self._sasl_auth_future.succeeded() + + def _handle_sasl_handshake_response(self, future, response): + error_type = Errors.for_code(response.error_code) + if error_type is not Errors.NoError: + error = error_type(self) + self.close(error=error) + return future.failure(error_type(self)) + + if self.config['sasl_mechanism'] not in response.enabled_mechanisms: + return future.failure( + Errors.UnsupportedSaslMechanismError( + 'Kafka broker does not support %s sasl mechanism. 
Enabled mechanisms are: %s' + % (self.config['sasl_mechanism'], response.enabled_mechanisms))) + elif self.config['sasl_mechanism'] == 'PLAIN': + return self._try_authenticate_plain(future) + elif self.config['sasl_mechanism'] == 'GSSAPI': + return self._try_authenticate_gssapi(future) + elif self.config['sasl_mechanism'] == 'OAUTHBEARER': + return self._try_authenticate_oauth(future) + elif self.config['sasl_mechanism'].startswith("SCRAM-SHA-"): + return self._try_authenticate_scram(future) + else: + return future.failure( + Errors.UnsupportedSaslMechanismError( + 'kafka-python does not support SASL mechanism %s' % + self.config['sasl_mechanism'])) + + def _send_bytes(self, data): + """Send some data via non-blocking IO + + Note: this method is not synchronized internally; you should + always hold the _lock before calling + + Returns: number of bytes + Raises: socket exception + """ + total_sent = 0 + while total_sent < len(data): + try: + sent_bytes = self._sock.send(data[total_sent:]) + total_sent += sent_bytes + except (SSLWantReadError, SSLWantWriteError): + break + except (ConnectionError, TimeoutError) as e: + if six.PY2 and e.errno == errno.EWOULDBLOCK: + break + raise + except BlockingIOError: + if six.PY3: + break + raise + return total_sent + + def _send_bytes_blocking(self, data): + self._sock.settimeout(self.config['request_timeout_ms'] / 1000) + total_sent = 0 + try: + while total_sent < len(data): + sent_bytes = self._sock.send(data[total_sent:]) + total_sent += sent_bytes + if total_sent != len(data): + raise ConnectionError('Buffer overrun during socket send') + return total_sent + finally: + self._sock.settimeout(0.0) + + def _recv_bytes_blocking(self, n): + self._sock.settimeout(self.config['request_timeout_ms'] / 1000) + try: + data = b'' + while len(data) < n: + fragment = self._sock.recv(n - len(data)) + if not fragment: + raise ConnectionError('Connection reset during recv') + data += fragment + return data + finally: + self._sock.settimeout(0.0) + + def _try_authenticate_plain(self, future): + if self.config['security_protocol'] == 'SASL_PLAINTEXT': + log.warning('%s: Sending username and password in the clear', self) + + data = b'' + # Send PLAIN credentials per RFC-4616 + msg = bytes('\0'.join([self.config['sasl_plain_username'], + self.config['sasl_plain_username'], + self.config['sasl_plain_password']]).encode('utf-8')) + size = Int32.encode(len(msg)) + + err = None + close = False + with self._lock: + if not self._can_send_recv(): + err = Errors.NodeNotReadyError(str(self)) + close = False + else: + try: + self._send_bytes_blocking(size + msg) + + # The server will send a zero sized message (that is Int32(0)) on success. 
+ # The connection is closed on failure + data = self._recv_bytes_blocking(4) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", self) + err = Errors.KafkaConnectionError("%s: %s" % (self, e)) + close = True + + if err is not None: + if close: + self.close(error=err) + return future.failure(err) + + if data != b'\x00\x00\x00\x00': + error = Errors.AuthenticationFailedError('Unrecognized response during authentication') + return future.failure(error) + + log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username']) + return future.success(True) + + def _try_authenticate_scram(self, future): + if self.config['security_protocol'] == 'SASL_PLAINTEXT': + log.warning('%s: Exchanging credentials in the clear', self) + + scram_client = ScramClient( + self.config['sasl_plain_username'], self.config['sasl_plain_password'], self.config['sasl_mechanism'] + ) + + err = None + close = False + with self._lock: + if not self._can_send_recv(): + err = Errors.NodeNotReadyError(str(self)) + close = False + else: + try: + client_first = scram_client.first_message().encode('utf-8') + size = Int32.encode(len(client_first)) + self._send_bytes_blocking(size + client_first) + + (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4)) + server_first = self._recv_bytes_blocking(data_len).decode('utf-8') + scram_client.process_server_first_message(server_first) + + client_final = scram_client.final_message().encode('utf-8') + size = Int32.encode(len(client_final)) + self._send_bytes_blocking(size + client_final) + + (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4)) + server_final = self._recv_bytes_blocking(data_len).decode('utf-8') + scram_client.process_server_final_message(server_final) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", self) + err = Errors.KafkaConnectionError("%s: %s" % (self, e)) + close = True + + if err is not None: + if close: + self.close(error=err) + return future.failure(err) + + log.info( + '%s: Authenticated as %s via %s', self, self.config['sasl_plain_username'], self.config['sasl_mechanism'] + ) + return future.success(True) + + def _try_authenticate_gssapi(self, future): + kerberos_damin_name = self.config['sasl_kerberos_domain_name'] or self.host + auth_id = self.config['sasl_kerberos_service_name'] + '@' + kerberos_damin_name + gssapi_name = gssapi.Name( + auth_id, + name_type=gssapi.NameType.hostbased_service + ).canonicalize(gssapi.MechType.kerberos) + log.debug('%s: GSSAPI name: %s', self, gssapi_name) + + err = None + close = False + with self._lock: + if not self._can_send_recv(): + err = Errors.NodeNotReadyError(str(self)) + close = False + else: + # Establish security context and negotiate protection level + # For reference RFC 2222, section 7.2.1 + try: + # Exchange tokens until authentication either succeeds or fails + client_ctx = gssapi.SecurityContext(name=gssapi_name, usage='initiate') + received_token = None + while not client_ctx.complete: + # calculate an output token from kafka token (or None if first iteration) + output_token = client_ctx.step(received_token) + + # pass output token to kafka, or send empty response if the security + # context is complete (output token is None in that case) + if output_token is None: + self._send_bytes_blocking(Int32.encode(0)) + else: + msg = output_token + size = Int32.encode(len(msg)) + self._send_bytes_blocking(size + msg) + + # The server will send a token back. 
Processing of this token either + # establishes a security context, or it needs further token exchange. + # The gssapi will be able to identify the needed next step. + # The connection is closed on failure. + header = self._recv_bytes_blocking(4) + (token_size,) = struct.unpack('>i', header) + received_token = self._recv_bytes_blocking(token_size) + + # Process the security layer negotiation token, sent by the server + # once the security context is established. + + # unwraps message containing supported protection levels and msg size + msg = client_ctx.unwrap(received_token).message + # Kafka currently doesn't support integrity or confidentiality security layers, so we + # simply set QoP to 'auth' only (first octet). We reuse the max message size proposed + # by the server + msg = Int8.encode(SASL_QOP_AUTH & Int8.decode(io.BytesIO(msg[0:1]))) + msg[1:] + # add authorization identity to the response, GSS-wrap and send it + msg = client_ctx.wrap(msg + auth_id.encode(), False).message + size = Int32.encode(len(msg)) + self._send_bytes_blocking(size + msg) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", self) + err = Errors.KafkaConnectionError("%s: %s" % (self, e)) + close = True + except Exception as e: + err = e + close = True + + if err is not None: + if close: + self.close(error=err) + return future.failure(err) + + log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name) + return future.success(True) + + def _try_authenticate_oauth(self, future): + data = b'' + + msg = bytes(self._build_oauth_client_request().encode("utf-8")) + size = Int32.encode(len(msg)) + + err = None + close = False + with self._lock: + if not self._can_send_recv(): + err = Errors.NodeNotReadyError(str(self)) + close = False + else: + try: + # Send SASL OAuthBearer request with OAuth token + self._send_bytes_blocking(size + msg) + + # The server will send a zero sized message (that is Int32(0)) on success. + # The connection is closed on failure + data = self._recv_bytes_blocking(4) + + except (ConnectionError, TimeoutError) as e: + log.exception("%s: Error receiving reply from server", self) + err = Errors.KafkaConnectionError("%s: %s" % (self, e)) + close = True + + if err is not None: + if close: + self.close(error=err) + return future.failure(err) + + if data != b'\x00\x00\x00\x00': + error = Errors.AuthenticationFailedError('Unrecognized response during authentication') + return future.failure(error) + + log.info('%s: Authenticated via OAuth', self) + return future.success(True) + + def _build_oauth_client_request(self): + token_provider = self.config['sasl_oauth_token_provider'] + return "n,,\x01auth=Bearer {}{}\x01\x01".format(token_provider.token(), self._token_extensions()) + + def _token_extensions(self): + """ + Return a string representation of the OPTIONAL key-value pairs that can be sent with an OAUTHBEARER + initial request. 
+ """ + token_provider = self.config['sasl_oauth_token_provider'] + + # Only run if the #extensions() method is implemented by the clients Token Provider class + # Builds up a string separated by \x01 via a dict of key value pairs + if callable(getattr(token_provider, "extensions", None)) and len(token_provider.extensions()) > 0: + msg = "\x01".join(["{}={}".format(k, v) for k, v in token_provider.extensions().items()]) + return "\x01" + msg + else: + return "" + + def blacked_out(self): + """ + Return true if we are disconnected from the given node and can't + re-establish a connection yet + """ + if self.state is ConnectionStates.DISCONNECTED: + if time.time() < self.last_attempt + self._reconnect_backoff: + return True + return False + + def connection_delay(self): + """ + Return the number of milliseconds to wait, based on the connection + state, before attempting to send data. When disconnected, this respects + the reconnect backoff time. When connecting or connected, returns a very + large number to handle slow/stalled connections. + """ + time_waited = time.time() - (self.last_attempt or 0) + if self.state is ConnectionStates.DISCONNECTED: + return max(self._reconnect_backoff - time_waited, 0) * 1000 + else: + # When connecting or connected, we should be able to delay + # indefinitely since other events (connection or data acked) will + # cause a wakeup once data can be sent. + return float('inf') + + def connected(self): + """Return True iff socket is connected.""" + return self.state is ConnectionStates.CONNECTED + + def connecting(self): + """Returns True if still connecting (this may encompass several + different states, such as SSL handshake, authorization, etc).""" + return self.state in (ConnectionStates.CONNECTING, + ConnectionStates.HANDSHAKE, + ConnectionStates.AUTHENTICATING) + + def disconnected(self): + """Return True iff socket is closed""" + return self.state is ConnectionStates.DISCONNECTED + + def _reset_reconnect_backoff(self): + self._failures = 0 + self._reconnect_backoff = self.config['reconnect_backoff_ms'] / 1000.0 + + def _update_reconnect_backoff(self): + # Do not mark as failure if there are more dns entries available to try + if len(self._gai) > 0: + return + if self.config['reconnect_backoff_max_ms'] > self.config['reconnect_backoff_ms']: + self._failures += 1 + self._reconnect_backoff = self.config['reconnect_backoff_ms'] * 2 ** (self._failures - 1) + self._reconnect_backoff = min(self._reconnect_backoff, self.config['reconnect_backoff_max_ms']) + self._reconnect_backoff *= uniform(0.8, 1.2) + self._reconnect_backoff /= 1000.0 + log.debug('%s: reconnect backoff %s after %s failures', self, self._reconnect_backoff, self._failures) + + def _close_socket(self): + if hasattr(self, '_sock') and self._sock is not None: + self._sock.close() + self._sock = None + + def __del__(self): + self._close_socket() + + def close(self, error=None): + """Close socket and fail all in-flight-requests. + + Arguments: + error (Exception, optional): pending in-flight-requests + will be failed with this exception. + Default: kafka.errors.KafkaConnectionError. + """ + if self.state is ConnectionStates.DISCONNECTED: + return + with self._lock: + if self.state is ConnectionStates.DISCONNECTED: + return + log.info('%s: Closing connection. 
%s', self, error or '') + self._update_reconnect_backoff() + self._sasl_auth_future = None + self._protocol = KafkaProtocol( + client_id=self.config['client_id'], + api_version=self.config['api_version']) + self._send_buffer = b'' + if error is None: + error = Errors.Cancelled(str(self)) + ifrs = list(self.in_flight_requests.items()) + self.in_flight_requests.clear() + self.state = ConnectionStates.DISCONNECTED + # To avoid race conditions and/or deadlocks + # keep a reference to the socket but leave it + # open until after the state_change_callback + # This should give clients a change to deregister + # the socket fd from selectors cleanly. + sock = self._sock + self._sock = None + + # drop lock before state change callback and processing futures + self.config['state_change_callback'](self.node_id, sock, self) + sock.close() + for (_correlation_id, (future, _timestamp)) in ifrs: + future.failure(error) + + def _can_send_recv(self): + """Return True iff socket is ready for requests / responses""" + return self.state in (ConnectionStates.AUTHENTICATING, + ConnectionStates.CONNECTED) + + def send(self, request, blocking=True): + """Queue request for async network send, return Future()""" + future = Future() + if self.connecting(): + return future.failure(Errors.NodeNotReadyError(str(self))) + elif not self.connected(): + return future.failure(Errors.KafkaConnectionError(str(self))) + elif not self.can_send_more(): + return future.failure(Errors.TooManyInFlightRequests(str(self))) + return self._send(request, blocking=blocking) + + def _send(self, request, blocking=True): + future = Future() + with self._lock: + if not self._can_send_recv(): + # In this case, since we created the future above, + # we know there are no callbacks/errbacks that could fire w/ + # lock. So failing + returning inline should be safe + return future.failure(Errors.NodeNotReadyError(str(self))) + + correlation_id = self._protocol.send_request(request) + + log.debug('%s Request %d: %s', self, correlation_id, request) + if request.expect_response(): + sent_time = time.time() + assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!' + self.in_flight_requests[correlation_id] = (future, sent_time) + else: + future.success(None) + + # Attempt to replicate behavior from prior to introduction of + # send_pending_requests() / async sends + if blocking: + self.send_pending_requests() + + return future + + def send_pending_requests(self): + """Attempts to send pending requests messages via blocking IO + If all requests have been sent, return True + Otherwise, if the socket is blocked and there are more bytes to send, + return False. + """ + try: + with self._lock: + if not self._can_send_recv(): + return False + data = self._protocol.send_bytes() + total_bytes = self._send_bytes_blocking(data) + + if self._sensors: + self._sensors.bytes_sent.record(total_bytes) + return True + + except (ConnectionError, TimeoutError) as e: + log.exception("Error sending request data to %s", self) + error = Errors.KafkaConnectionError("%s: %s" % (self, e)) + self.close(error=error) + return False + + def send_pending_requests_v2(self): + """Attempts to send pending requests messages via non-blocking IO + If all requests have been sent, return True + Otherwise, if the socket is blocked and there are more bytes to send, + return False. 
+ """ + try: + with self._lock: + if not self._can_send_recv(): + return False + + # _protocol.send_bytes returns encoded requests to send + # we send them via _send_bytes() + # and hold leftover bytes in _send_buffer + if not self._send_buffer: + self._send_buffer = self._protocol.send_bytes() + + total_bytes = 0 + if self._send_buffer: + total_bytes = self._send_bytes(self._send_buffer) + self._send_buffer = self._send_buffer[total_bytes:] + + if self._sensors: + self._sensors.bytes_sent.record(total_bytes) + # Return True iff send buffer is empty + return len(self._send_buffer) == 0 + + except (ConnectionError, TimeoutError, Exception) as e: + log.exception("Error sending request data to %s", self) + error = Errors.KafkaConnectionError("%s: %s" % (self, e)) + self.close(error=error) + return False + + def can_send_more(self): + """Return True unless there are max_in_flight_requests_per_connection.""" + max_ifrs = self.config['max_in_flight_requests_per_connection'] + return len(self.in_flight_requests) < max_ifrs + + def recv(self): + """Non-blocking network receive. + + Return list of (response, future) tuples + """ + responses = self._recv() + if not responses and self.requests_timed_out(): + log.warning('%s timed out after %s ms. Closing connection.', + self, self.config['request_timeout_ms']) + self.close(error=Errors.RequestTimedOutError( + 'Request timed out after %s ms' % + self.config['request_timeout_ms'])) + return () + + # augment responses w/ correlation_id, future, and timestamp + for i, (correlation_id, response) in enumerate(responses): + try: + with self._lock: + (future, timestamp) = self.in_flight_requests.pop(correlation_id) + except KeyError: + self.close(Errors.KafkaConnectionError('Received unrecognized correlation id')) + return () + latency_ms = (time.time() - timestamp) * 1000 + if self._sensors: + self._sensors.request_time.record(latency_ms) + + log.debug('%s Response %d (%s ms): %s', self, correlation_id, latency_ms, response) + responses[i] = (response, future) + + return responses + + def _recv(self): + """Take all available bytes from socket, return list of any responses from parser""" + recvd = [] + err = None + with self._lock: + if not self._can_send_recv(): + log.warning('%s cannot recv: socket not connected', self) + return () + + while len(recvd) < self.config['sock_chunk_buffer_count']: + try: + data = self._sock.recv(self.config['sock_chunk_bytes']) + # We expect socket.recv to raise an exception if there are no + # bytes available to read from the socket in non-blocking mode. 
+ # but if the socket is disconnected, we will get empty data + # without an exception raised + if not data: + log.error('%s: socket disconnected', self) + err = Errors.KafkaConnectionError('socket disconnected') + break + else: + recvd.append(data) + + except (SSLWantReadError, SSLWantWriteError): + break + except (ConnectionError, TimeoutError) as e: + if six.PY2 and e.errno == errno.EWOULDBLOCK: + break + log.exception('%s: Error receiving network data' + ' closing socket', self) + err = Errors.KafkaConnectionError(e) + break + except BlockingIOError: + if six.PY3: + break + # For PY2 this is a catchall and should be re-raised + raise + + # Only process bytes if there was no connection exception + if err is None: + recvd_data = b''.join(recvd) + if self._sensors: + self._sensors.bytes_received.record(len(recvd_data)) + + # We need to keep the lock through protocol receipt + # so that we ensure that the processed byte order is the + # same as the received byte order + try: + return self._protocol.receive_bytes(recvd_data) + except Errors.KafkaProtocolError as e: + err = e + + self.close(error=err) + return () + + def requests_timed_out(self): + with self._lock: + if self.in_flight_requests: + get_timestamp = lambda v: v[1] + oldest_at = min(map(get_timestamp, + self.in_flight_requests.values())) + timeout = self.config['request_timeout_ms'] / 1000.0 + if time.time() >= oldest_at + timeout: + return True + return False + + def _handle_api_version_response(self, response): + error_type = Errors.for_code(response.error_code) + assert error_type is Errors.NoError, "API version check failed" + self._api_versions = dict([ + (api_key, (min_version, max_version)) + for api_key, min_version, max_version in response.api_versions + ]) + return self._api_versions + + def get_api_versions(self): + if self._api_versions is not None: + return self._api_versions + + version = self.check_version() + if version < (0, 10, 0): + raise Errors.UnsupportedVersionError( + "ApiVersion not supported by cluster version {} < 0.10.0" + .format(version)) + # _api_versions is set as a side effect of check_versions() on a cluster + # that supports 0.10.0 or later + return self._api_versions + + def _infer_broker_version_from_api_versions(self, api_versions): + # The logic here is to check the list of supported request versions + # in reverse order. As soon as we find one that works, return it + test_cases = [ + # format (, ) + ((2, 6, 0), DescribeClientQuotasRequest[0]), + ((2, 5, 0), DescribeAclsRequest_v2), + ((2, 4, 0), ProduceRequest[8]), + ((2, 3, 0), FetchRequest[11]), + ((2, 2, 0), OffsetRequest[5]), + ((2, 1, 0), FetchRequest[10]), + ((2, 0, 0), FetchRequest[8]), + ((1, 1, 0), FetchRequest[7]), + ((1, 0, 0), MetadataRequest[5]), + ((0, 11, 0), MetadataRequest[4]), + ((0, 10, 2), OffsetFetchRequest[2]), + ((0, 10, 1), MetadataRequest[2]), + ] + + # Get the best match of test cases + for broker_version, struct in sorted(test_cases, reverse=True): + if struct.API_KEY not in api_versions: + continue + min_version, max_version = api_versions[struct.API_KEY] + if min_version <= struct.API_VERSION <= max_version: + return broker_version + + # We know that ApiVersionResponse is only supported in 0.10+ + # so if all else fails, choose that + return (0, 10, 0) + + def check_version(self, timeout=2, strict=False, topics=[]): + """Attempt to guess the broker version. + + Note: This is a blocking call. + + Returns: version tuple, i.e. (0, 10), (0, 9), (0, 8, 2), ... 
+ """ + timeout_at = time.time() + timeout + log.info('Probing node %s broker version', self.node_id) + # Monkeypatch some connection configurations to avoid timeouts + override_config = { + 'request_timeout_ms': timeout * 1000, + 'max_in_flight_requests_per_connection': 5 + } + stashed = {} + for key in override_config: + stashed[key] = self.config[key] + self.config[key] = override_config[key] + + def reset_override_configs(): + for key in stashed: + self.config[key] = stashed[key] + + # kafka kills the connection when it doesn't recognize an API request + # so we can send a test request and then follow immediately with a + # vanilla MetadataRequest. If the server did not recognize the first + # request, both will be failed with a ConnectionError that wraps + # socket.error (32, 54, or 104) + from kafka.protocol.admin import ApiVersionRequest, ListGroupsRequest + from kafka.protocol.commit import OffsetFetchRequest, GroupCoordinatorRequest + + test_cases = [ + # All cases starting from 0.10 will be based on ApiVersionResponse + ((0, 10), ApiVersionRequest[0]()), + ((0, 9), ListGroupsRequest[0]()), + ((0, 8, 2), GroupCoordinatorRequest[0]('kafka-python-default-group')), + ((0, 8, 1), OffsetFetchRequest[0]('kafka-python-default-group', [])), + ((0, 8, 0), MetadataRequest[0](topics)), + ] + + for version, request in test_cases: + if not self.connect_blocking(timeout_at - time.time()): + reset_override_configs() + raise Errors.NodeNotReadyError() + f = self.send(request) + # HACK: sleeping to wait for socket to send bytes + time.sleep(0.1) + # when broker receives an unrecognized request API + # it abruptly closes our socket. + # so we attempt to send a second request immediately + # that we believe it will definitely recognize (metadata) + # the attempt to write to a disconnected socket should + # immediately fail and allow us to infer that the prior + # request was unrecognized + mr = self.send(MetadataRequest[0](topics)) + + selector = self.config['selector']() + selector.register(self._sock, selectors.EVENT_READ) + while not (f.is_done and mr.is_done): + selector.select(1) + for response, future in self.recv(): + future.success(response) + selector.close() + + if f.succeeded(): + if isinstance(request, ApiVersionRequest[0]): + # Starting from 0.10 kafka broker we determine version + # by looking at ApiVersionResponse + api_versions = self._handle_api_version_response(f.value) + version = self._infer_broker_version_from_api_versions(api_versions) + log.info('Broker version identified as %s', '.'.join(map(str, version))) + log.info('Set configuration api_version=%s to skip auto' + ' check_version requests on startup', version) + break + + # Only enable strict checking to verify that we understand failure + # modes. For most users, the fact that the request failed should be + # enough to rule out a particular broker version. + if strict: + # If the socket flush hack did not work (which should force the + # connection to close and fail all pending requests), then we + # get a basic Request Timeout. This is not ideal, but we'll deal + if isinstance(f.exception, Errors.RequestTimedOutError): + pass + + # 0.9 brokers do not close the socket on unrecognized api + # requests (bug...). 
In this case we expect to see a correlation + # id mismatch + elif (isinstance(f.exception, Errors.CorrelationIdError) and + version == (0, 10)): + pass + elif six.PY2: + assert isinstance(f.exception.args[0], socket.error) + assert f.exception.args[0].errno in (32, 54, 104) + else: + assert isinstance(f.exception.args[0], ConnectionError) + log.info("Broker is not v%s -- it did not recognize %s", + version, request.__class__.__name__) + else: + reset_override_configs() + raise Errors.UnrecognizedBrokerVersion() + + reset_override_configs() + return version + + def __str__(self): + return "" % ( + self.node_id, self.host, self.port, self.state, + AFI_NAMES[self._sock_afi], self._sock_addr) + + +class BrokerConnectionMetrics(object): + def __init__(self, metrics, metric_group_prefix, node_id): + self.metrics = metrics + + # Any broker may have registered summary metrics already + # but if not, we need to create them so we can set as parents below + all_conns_transferred = metrics.get_sensor('bytes-sent-received') + if not all_conns_transferred: + metric_group_name = metric_group_prefix + '-metrics' + + bytes_transferred = metrics.sensor('bytes-sent-received') + bytes_transferred.add(metrics.metric_name( + 'network-io-rate', metric_group_name, + 'The average number of network operations (reads or writes) on all' + ' connections per second.'), Rate(sampled_stat=Count())) + + bytes_sent = metrics.sensor('bytes-sent', + parents=[bytes_transferred]) + bytes_sent.add(metrics.metric_name( + 'outgoing-byte-rate', metric_group_name, + 'The average number of outgoing bytes sent per second to all' + ' servers.'), Rate()) + bytes_sent.add(metrics.metric_name( + 'request-rate', metric_group_name, + 'The average number of requests sent per second.'), + Rate(sampled_stat=Count())) + bytes_sent.add(metrics.metric_name( + 'request-size-avg', metric_group_name, + 'The average size of all requests in the window.'), Avg()) + bytes_sent.add(metrics.metric_name( + 'request-size-max', metric_group_name, + 'The maximum size of any request sent in the window.'), Max()) + + bytes_received = metrics.sensor('bytes-received', + parents=[bytes_transferred]) + bytes_received.add(metrics.metric_name( + 'incoming-byte-rate', metric_group_name, + 'Bytes/second read off all sockets'), Rate()) + bytes_received.add(metrics.metric_name( + 'response-rate', metric_group_name, + 'Responses received sent per second.'), + Rate(sampled_stat=Count())) + + request_latency = metrics.sensor('request-latency') + request_latency.add(metrics.metric_name( + 'request-latency-avg', metric_group_name, + 'The average request latency in ms.'), + Avg()) + request_latency.add(metrics.metric_name( + 'request-latency-max', metric_group_name, + 'The maximum request latency in ms.'), + Max()) + + # if one sensor of the metrics has been registered for the connection, + # then all other sensors should have been registered; and vice versa + node_str = 'node-{0}'.format(node_id) + node_sensor = metrics.get_sensor(node_str + '.bytes-sent') + if not node_sensor: + metric_group_name = metric_group_prefix + '-node-metrics.' 
+ node_str + + bytes_sent = metrics.sensor( + node_str + '.bytes-sent', + parents=[metrics.get_sensor('bytes-sent')]) + bytes_sent.add(metrics.metric_name( + 'outgoing-byte-rate', metric_group_name, + 'The average number of outgoing bytes sent per second.'), + Rate()) + bytes_sent.add(metrics.metric_name( + 'request-rate', metric_group_name, + 'The average number of requests sent per second.'), + Rate(sampled_stat=Count())) + bytes_sent.add(metrics.metric_name( + 'request-size-avg', metric_group_name, + 'The average size of all requests in the window.'), + Avg()) + bytes_sent.add(metrics.metric_name( + 'request-size-max', metric_group_name, + 'The maximum size of any request sent in the window.'), + Max()) + + bytes_received = metrics.sensor( + node_str + '.bytes-received', + parents=[metrics.get_sensor('bytes-received')]) + bytes_received.add(metrics.metric_name( + 'incoming-byte-rate', metric_group_name, + 'Bytes/second read off node-connection socket'), + Rate()) + bytes_received.add(metrics.metric_name( + 'response-rate', metric_group_name, + 'The average number of responses received per second.'), + Rate(sampled_stat=Count())) + + request_time = metrics.sensor( + node_str + '.latency', + parents=[metrics.get_sensor('request-latency')]) + request_time.add(metrics.metric_name( + 'request-latency-avg', metric_group_name, + 'The average request latency in ms.'), + Avg()) + request_time.add(metrics.metric_name( + 'request-latency-max', metric_group_name, + 'The maximum request latency in ms.'), + Max()) + + self.bytes_sent = metrics.sensor(node_str + '.bytes-sent') + self.bytes_received = metrics.sensor(node_str + '.bytes-received') + self.request_time = metrics.sensor(node_str + '.latency') + + +def _address_family(address): + """ + Attempt to determine the family of an address (or hostname) + + :return: either socket.AF_INET or socket.AF_INET6 or socket.AF_UNSPEC if the address family + could not be determined + """ + if address.startswith('[') and address.endswith(']'): + return socket.AF_INET6 + for af in (socket.AF_INET, socket.AF_INET6): + try: + socket.inet_pton(af, address) + return af + except (ValueError, AttributeError, socket.error): + continue + return socket.AF_UNSPEC + + +def get_ip_port_afi(host_and_port_str): + """ + Parse the IP and port from a string in the format of: + + * host_or_ip <- Can be either IPv4 address literal or hostname/fqdn + * host_or_ipv4:port <- Can be either IPv4 address literal or hostname/fqdn + * [host_or_ip] <- IPv6 address literal + * [host_or_ip]:port. <- IPv6 address literal + + .. note:: IPv6 address literals with ports *must* be enclosed in brackets + + .. note:: If the port is not specified, default will be returned. + + :return: tuple (host, port, afi), afi will be socket.AF_INET or socket.AF_INET6 or socket.AF_UNSPEC + """ + host_and_port_str = host_and_port_str.strip() + if host_and_port_str.startswith('['): + af = socket.AF_INET6 + host, rest = host_and_port_str[1:].split(']') + if rest: + port = int(rest[1:]) + else: + port = DEFAULT_KAFKA_PORT + return host, port, af + else: + if ':' not in host_and_port_str: + af = _address_family(host_and_port_str) + return host_and_port_str, DEFAULT_KAFKA_PORT, af + else: + # now we have something with a colon in it and no square brackets. 
It could be + # either an IPv6 address literal (e.g., "::1") or an IP:port pair or a host:port pair + try: + # if it decodes as an IPv6 address, use that + socket.inet_pton(socket.AF_INET6, host_and_port_str) + return host_and_port_str, DEFAULT_KAFKA_PORT, socket.AF_INET6 + except AttributeError: + log.warning('socket.inet_pton not available on this platform.' + ' consider `pip install win_inet_pton`') + pass + except (ValueError, socket.error): + # it's a host:port pair + pass + host, port = host_and_port_str.rsplit(':', 1) + port = int(port) + + af = _address_family(host) + return host, port, af + + +def collect_hosts(hosts, randomize=True): + """ + Collects a comma-separated set of hosts (host:port) and optionally + randomize the returned list. + """ + + if isinstance(hosts, six.string_types): + hosts = hosts.strip().split(',') + + result = [] + afi = socket.AF_INET + for host_port in hosts: + + host, port, afi = get_ip_port_afi(host_port) + + if port < 0: + port = DEFAULT_KAFKA_PORT + + result.append((host, port, afi)) + + if randomize: + shuffle(result) + + return result + + +def is_inet_4_or_6(gai): + """Given a getaddrinfo struct, return True iff ipv4 or ipv6""" + return gai[0] in (socket.AF_INET, socket.AF_INET6) + + +def dns_lookup(host, port, afi=socket.AF_UNSPEC): + """Returns a list of getaddrinfo structs, optionally filtered to an afi (ipv4 / ipv6)""" + # XXX: all DNS functions in Python are blocking. If we really + # want to be non-blocking here, we need to use a 3rd-party + # library like python-adns, or move resolution onto its + # own thread. This will be subject to the default libc + # name resolution timeout (5s on most Linux boxes) + try: + return list(filter(is_inet_4_or_6, + socket.getaddrinfo(host, port, afi, + socket.SOCK_STREAM))) + except socket.gaierror as ex: + log.warning('DNS lookup failed for %s:%d,' + ' exception was %s. 
Is your' + ' advertised.listeners (called' + ' advertised.host.name before Kafka 9)' + ' correct and resolvable?', + host, port, ex) + return [] diff --git a/consumer/__init__.py b/consumer/__init__.py new file mode 100644 index 00000000..e09bcc1b --- /dev/null +++ b/consumer/__init__.py @@ -0,0 +1,7 @@ +from __future__ import absolute_import + +from kafka.consumer.group import KafkaConsumer + +__all__ = [ + 'KafkaConsumer' +] diff --git a/consumer/fetcher.py b/consumer/fetcher.py new file mode 100644 index 00000000..7ff9daf7 --- /dev/null +++ b/consumer/fetcher.py @@ -0,0 +1,1016 @@ +from __future__ import absolute_import + +import collections +import copy +import logging +import random +import sys +import time + +from kafka.vendor import six + +import kafka.errors as Errors +from kafka.future import Future +from kafka.metrics.stats import Avg, Count, Max, Rate +from kafka.protocol.fetch import FetchRequest +from kafka.protocol.offset import ( + OffsetRequest, OffsetResetStrategy, UNKNOWN_OFFSET +) +from kafka.record import MemoryRecords +from kafka.serializer import Deserializer +from kafka.structs import TopicPartition, OffsetAndTimestamp + +log = logging.getLogger(__name__) + + +# Isolation levels +READ_UNCOMMITTED = 0 +READ_COMMITTED = 1 + +ConsumerRecord = collections.namedtuple("ConsumerRecord", + ["topic", "partition", "offset", "timestamp", "timestamp_type", + "key", "value", "headers", "checksum", "serialized_key_size", "serialized_value_size", "serialized_header_size"]) + + +CompletedFetch = collections.namedtuple("CompletedFetch", + ["topic_partition", "fetched_offset", "response_version", + "partition_data", "metric_aggregator"]) + + +class NoOffsetForPartitionError(Errors.KafkaError): + pass + + +class RecordTooLargeError(Errors.KafkaError): + pass + + +class Fetcher(six.Iterator): + DEFAULT_CONFIG = { + 'key_deserializer': None, + 'value_deserializer': None, + 'fetch_min_bytes': 1, + 'fetch_max_wait_ms': 500, + 'fetch_max_bytes': 52428800, + 'max_partition_fetch_bytes': 1048576, + 'max_poll_records': sys.maxsize, + 'check_crcs': True, + 'iterator_refetch_records': 1, # undocumented -- interface may change + 'metric_group_prefix': 'consumer', + 'api_version': (0, 8, 0), + 'retry_backoff_ms': 100 + } + + def __init__(self, client, subscriptions, metrics, **configs): + """Initialize a Kafka Message Fetcher. + + Keyword Arguments: + key_deserializer (callable): Any callable that takes a + raw message key and returns a deserialized key. + value_deserializer (callable, optional): Any callable that takes a + raw message value and returns a deserialized value. + fetch_min_bytes (int): Minimum amount of data the server should + return for a fetch request, otherwise wait up to + fetch_max_wait_ms for more data to accumulate. Default: 1. + fetch_max_wait_ms (int): The maximum amount of time in milliseconds + the server will block before answering the fetch request if + there isn't sufficient data to immediately satisfy the + requirement given by fetch_min_bytes. Default: 500. + fetch_max_bytes (int): The maximum amount of data the server should + return for a fetch request. This is not an absolute maximum, if + the first message in the first non-empty partition of the fetch + is larger than this value, the message will still be returned + to ensure that the consumer can make progress. NOTE: consumer + performs fetches to multiple brokers in parallel so memory + usage will depend on the number of brokers containing + partitions for the topic. + Supported Kafka version >= 0.10.1.0. 
Default: 52428800 (50 MB). + max_partition_fetch_bytes (int): The maximum amount of data + per-partition the server will return. The maximum total memory + used for a request = #partitions * max_partition_fetch_bytes. + This size must be at least as large as the maximum message size + the server allows or else it is possible for the producer to + send messages larger than the consumer can fetch. If that + happens, the consumer can get stuck trying to fetch a large + message on a certain partition. Default: 1048576. + check_crcs (bool): Automatically check the CRC32 of the records + consumed. This ensures no on-the-wire or on-disk corruption to + the messages occurred. This check adds some overhead, so it may + be disabled in cases seeking extreme performance. Default: True + """ + self.config = copy.copy(self.DEFAULT_CONFIG) + for key in self.config: + if key in configs: + self.config[key] = configs[key] + + self._client = client + self._subscriptions = subscriptions + self._completed_fetches = collections.deque() # Unparsed responses + self._next_partition_records = None # Holds a single PartitionRecords until fully consumed + self._iterator = None + self._fetch_futures = collections.deque() + self._sensors = FetchManagerMetrics(metrics, self.config['metric_group_prefix']) + self._isolation_level = READ_UNCOMMITTED + + def send_fetches(self): + """Send FetchRequests for all assigned partitions that do not already have + an in-flight fetch or pending fetch data. + + Returns: + List of Futures: each future resolves to a FetchResponse + """ + futures = [] + for node_id, request in six.iteritems(self._create_fetch_requests()): + if self._client.ready(node_id): + log.debug("Sending FetchRequest to node %s", node_id) + future = self._client.send(node_id, request, wakeup=False) + future.add_callback(self._handle_fetch_response, request, time.time()) + future.add_errback(log.error, 'Fetch to node %s failed: %s', node_id) + futures.append(future) + self._fetch_futures.extend(futures) + self._clean_done_fetch_futures() + return futures + + def reset_offsets_if_needed(self, partitions): + """Lookup and set offsets for any partitions which are awaiting an + explicit reset. + + Arguments: + partitions (set of TopicPartitions): the partitions to reset + """ + for tp in partitions: + # TODO: If there are several offsets to reset, we could submit offset requests in parallel + if self._subscriptions.is_assigned(tp) and self._subscriptions.is_offset_reset_needed(tp): + self._reset_offset(tp) + + def _clean_done_fetch_futures(self): + while True: + if not self._fetch_futures: + break + if not self._fetch_futures[0].is_done: + break + self._fetch_futures.popleft() + + def in_flight_fetches(self): + """Return True if there are any unprocessed FetchRequests in flight.""" + self._clean_done_fetch_futures() + return bool(self._fetch_futures) + + def update_fetch_positions(self, partitions): + """Update the fetch positions for the provided partitions. 
+ + Arguments: + partitions (list of TopicPartitions): partitions to update + + Raises: + NoOffsetForPartitionError: if no offset is stored for a given + partition and no reset policy is available + """ + # reset the fetch position to the committed position + for tp in partitions: + if not self._subscriptions.is_assigned(tp): + log.warning("partition %s is not assigned - skipping offset" + " update", tp) + continue + elif self._subscriptions.is_fetchable(tp): + log.warning("partition %s is still fetchable -- skipping offset" + " update", tp) + continue + + if self._subscriptions.is_offset_reset_needed(tp): + self._reset_offset(tp) + elif self._subscriptions.assignment[tp].committed is None: + # there's no committed position, so we need to reset with the + # default strategy + self._subscriptions.need_offset_reset(tp) + self._reset_offset(tp) + else: + committed = self._subscriptions.assignment[tp].committed.offset + log.debug("Resetting offset for partition %s to the committed" + " offset %s", tp, committed) + self._subscriptions.seek(tp, committed) + + def get_offsets_by_times(self, timestamps, timeout_ms): + offsets = self._retrieve_offsets(timestamps, timeout_ms) + for tp in timestamps: + if tp not in offsets: + offsets[tp] = None + else: + offset, timestamp = offsets[tp] + offsets[tp] = OffsetAndTimestamp(offset, timestamp) + return offsets + + def beginning_offsets(self, partitions, timeout_ms): + return self.beginning_or_end_offset( + partitions, OffsetResetStrategy.EARLIEST, timeout_ms) + + def end_offsets(self, partitions, timeout_ms): + return self.beginning_or_end_offset( + partitions, OffsetResetStrategy.LATEST, timeout_ms) + + def beginning_or_end_offset(self, partitions, timestamp, timeout_ms): + timestamps = dict([(tp, timestamp) for tp in partitions]) + offsets = self._retrieve_offsets(timestamps, timeout_ms) + for tp in timestamps: + offsets[tp] = offsets[tp][0] + return offsets + + def _reset_offset(self, partition): + """Reset offsets for the given partition using the offset reset strategy. + + Arguments: + partition (TopicPartition): the partition that needs reset offset + + Raises: + NoOffsetForPartitionError: if no offset reset strategy is defined + """ + timestamp = self._subscriptions.assignment[partition].reset_strategy + if timestamp is OffsetResetStrategy.EARLIEST: + strategy = 'earliest' + elif timestamp is OffsetResetStrategy.LATEST: + strategy = 'latest' + else: + raise NoOffsetForPartitionError(partition) + + log.debug("Resetting offset for partition %s to %s offset.", + partition, strategy) + offsets = self._retrieve_offsets({partition: timestamp}) + + if partition in offsets: + offset = offsets[partition][0] + + # we might lose the assignment while fetching the offset, + # so check it is still active + if self._subscriptions.is_assigned(partition): + self._subscriptions.seek(partition, offset) + else: + log.debug("Could not find offset for partition %s since it is probably deleted" % (partition,)) + + def _retrieve_offsets(self, timestamps, timeout_ms=float("inf")): + """Fetch offset for each partition passed in ``timestamps`` map. + + Blocks until offsets are obtained, a non-retriable exception is raised + or ``timeout_ms`` passed. + + Arguments: + timestamps: {TopicPartition: int} dict with timestamps to fetch + offsets by. -1 for the latest available, -2 for the earliest + available. Otherwise timestamp is treated as epoch milliseconds. + + Returns: + {TopicPartition: (int, int)}: Mapping of partition to + retrieved offset and timestamp. 
If offset does not exist for + the provided timestamp, that partition will be missing from + this mapping. + """ + if not timestamps: + return {} + + start_time = time.time() + remaining_ms = timeout_ms + timestamps = copy.copy(timestamps) + while remaining_ms > 0: + if not timestamps: + return {} + + future = self._send_offset_requests(timestamps) + self._client.poll(future=future, timeout_ms=remaining_ms) + + if future.succeeded(): + return future.value + if not future.retriable(): + raise future.exception # pylint: disable-msg=raising-bad-type + + elapsed_ms = (time.time() - start_time) * 1000 + remaining_ms = timeout_ms - elapsed_ms + if remaining_ms < 0: + break + + if future.exception.invalid_metadata: + refresh_future = self._client.cluster.request_update() + self._client.poll(future=refresh_future, timeout_ms=remaining_ms) + + # Issue #1780 + # Recheck partition existence after after a successful metadata refresh + if refresh_future.succeeded() and isinstance(future.exception, Errors.StaleMetadata): + log.debug("Stale metadata was raised, and we now have an updated metadata. Rechecking partition existence") + unknown_partition = future.exception.args[0] # TopicPartition from StaleMetadata + if self._client.cluster.leader_for_partition(unknown_partition) is None: + log.debug("Removed partition %s from offsets retrieval" % (unknown_partition, )) + timestamps.pop(unknown_partition) + else: + time.sleep(self.config['retry_backoff_ms'] / 1000.0) + + elapsed_ms = (time.time() - start_time) * 1000 + remaining_ms = timeout_ms - elapsed_ms + + raise Errors.KafkaTimeoutError( + "Failed to get offsets by timestamps in %s ms" % (timeout_ms,)) + + def fetched_records(self, max_records=None, update_offsets=True): + """Returns previously fetched records and updates consumed offsets. + + Arguments: + max_records (int): Maximum number of records returned. Defaults + to max_poll_records configuration. + + Raises: + OffsetOutOfRangeError: if no subscription offset_reset_strategy + CorruptRecordException: if message crc validation fails (check_crcs + must be set to True) + RecordTooLargeError: if a message is larger than the currently + configured max_partition_fetch_bytes + TopicAuthorizationError: if consumer is not authorized to fetch + messages from the topic + + Returns: (records (dict), partial (bool)) + records: {TopicPartition: [messages]} + partial: True if records returned did not fully drain any pending + partition requests. This may be useful for choosing when to + pipeline additional fetch requests. 
+ """ + if max_records is None: + max_records = self.config['max_poll_records'] + assert max_records > 0 + + drained = collections.defaultdict(list) + records_remaining = max_records + + while records_remaining > 0: + if not self._next_partition_records: + if not self._completed_fetches: + break + completion = self._completed_fetches.popleft() + self._next_partition_records = self._parse_fetched_data(completion) + else: + records_remaining -= self._append(drained, + self._next_partition_records, + records_remaining, + update_offsets) + return dict(drained), bool(self._completed_fetches) + + def _append(self, drained, part, max_records, update_offsets): + if not part: + return 0 + + tp = part.topic_partition + fetch_offset = part.fetch_offset + if not self._subscriptions.is_assigned(tp): + # this can happen when a rebalance happened before + # fetched records are returned to the consumer's poll call + log.debug("Not returning fetched records for partition %s" + " since it is no longer assigned", tp) + else: + # note that the position should always be available + # as long as the partition is still assigned + position = self._subscriptions.assignment[tp].position + if not self._subscriptions.is_fetchable(tp): + # this can happen when a partition is paused before + # fetched records are returned to the consumer's poll call + log.debug("Not returning fetched records for assigned partition" + " %s since it is no longer fetchable", tp) + + elif fetch_offset == position: + # we are ensured to have at least one record since we already checked for emptiness + part_records = part.take(max_records) + next_offset = part_records[-1].offset + 1 + + log.log(0, "Returning fetched records at offset %d for assigned" + " partition %s and update position to %s", position, + tp, next_offset) + + for record in part_records: + drained[tp].append(record) + + if update_offsets: + self._subscriptions.assignment[tp].position = next_offset + return len(part_records) + + else: + # these records aren't next in line based on the last consumed + # position, ignore them they must be from an obsolete request + log.debug("Ignoring fetched records for %s at offset %s since" + " the current position is %d", tp, part.fetch_offset, + position) + + part.discard() + return 0 + + def _message_generator(self): + """Iterate over fetched_records""" + while self._next_partition_records or self._completed_fetches: + + if not self._next_partition_records: + completion = self._completed_fetches.popleft() + self._next_partition_records = self._parse_fetched_data(completion) + continue + + # Send additional FetchRequests when the internal queue is low + # this should enable moderate pipelining + if len(self._completed_fetches) <= self.config['iterator_refetch_records']: + self.send_fetches() + + tp = self._next_partition_records.topic_partition + + # We can ignore any prior signal to drop pending message sets + # because we are starting from a fresh one where fetch_offset == position + # i.e., the user seek()'d to this position + self._subscriptions.assignment[tp].drop_pending_message_set = False + + for msg in self._next_partition_records.take(): + + # Because we are in a generator, it is possible for + # subscription state to change between yield calls + # so we need to re-check on each loop + # this should catch assignment changes, pauses + # and resets via seek_to_beginning / seek_to_end + if not self._subscriptions.is_fetchable(tp): + log.debug("Not returning fetched records for partition %s" + " since it is no longer fetchable", tp) 
+ self._next_partition_records = None + break + + # If there is a seek during message iteration, + # we should stop unpacking this message set and + # wait for a new fetch response that aligns with the + # new seek position + elif self._subscriptions.assignment[tp].drop_pending_message_set: + log.debug("Skipping remainder of message set for partition %s", tp) + self._subscriptions.assignment[tp].drop_pending_message_set = False + self._next_partition_records = None + break + + # Compressed messagesets may include earlier messages + elif msg.offset < self._subscriptions.assignment[tp].position: + log.debug("Skipping message offset: %s (expecting %s)", + msg.offset, + self._subscriptions.assignment[tp].position) + continue + + self._subscriptions.assignment[tp].position = msg.offset + 1 + yield msg + + self._next_partition_records = None + + def _unpack_message_set(self, tp, records): + try: + batch = records.next_batch() + while batch is not None: + + # LegacyRecordBatch cannot access either base_offset or last_offset_delta + try: + self._subscriptions.assignment[tp].last_offset_from_message_batch = batch.base_offset + \ + batch.last_offset_delta + except AttributeError: + pass + + for record in batch: + key_size = len(record.key) if record.key is not None else -1 + value_size = len(record.value) if record.value is not None else -1 + key = self._deserialize( + self.config['key_deserializer'], + tp.topic, record.key) + value = self._deserialize( + self.config['value_deserializer'], + tp.topic, record.value) + headers = record.headers + header_size = sum( + len(h_key.encode("utf-8")) + (len(h_val) if h_val is not None else 0) for h_key, h_val in + headers) if headers else -1 + yield ConsumerRecord( + tp.topic, tp.partition, record.offset, record.timestamp, + record.timestamp_type, key, value, headers, record.checksum, + key_size, value_size, header_size) + + batch = records.next_batch() + + # If unpacking raises StopIteration, it is erroneously + # caught by the generator. We want all exceptions to be raised + # back to the user. See Issue 545 + except StopIteration as e: + log.exception('StopIteration raised unpacking messageset') + raise RuntimeError('StopIteration raised unpacking messageset') + + def __iter__(self): # pylint: disable=non-iterator-returned + return self + + def __next__(self): + if not self._iterator: + self._iterator = self._message_generator() + try: + return next(self._iterator) + except StopIteration: + self._iterator = None + raise + + def _deserialize(self, f, topic, bytes_): + if not f: + return bytes_ + if isinstance(f, Deserializer): + return f.deserialize(topic, bytes_) + return f(bytes_) + + def _send_offset_requests(self, timestamps): + """Fetch offsets for each partition in timestamps dict. This may send + request to multiple nodes, based on who is Leader for partition. + + Arguments: + timestamps (dict): {TopicPartition: int} mapping of fetching + timestamps. 
+ + Returns: + Future: resolves to a mapping of retrieved offsets + """ + timestamps_by_node = collections.defaultdict(dict) + for partition, timestamp in six.iteritems(timestamps): + node_id = self._client.cluster.leader_for_partition(partition) + if node_id is None: + self._client.add_topic(partition.topic) + log.debug("Partition %s is unknown for fetching offset," + " wait for metadata refresh", partition) + return Future().failure(Errors.StaleMetadata(partition)) + elif node_id == -1: + log.debug("Leader for partition %s unavailable for fetching " + "offset, wait for metadata refresh", partition) + return Future().failure( + Errors.LeaderNotAvailableError(partition)) + else: + timestamps_by_node[node_id][partition] = timestamp + + # Aggregate results until we have all + list_offsets_future = Future() + responses = [] + node_count = len(timestamps_by_node) + + def on_success(value): + responses.append(value) + if len(responses) == node_count: + offsets = {} + for r in responses: + offsets.update(r) + list_offsets_future.success(offsets) + + def on_fail(err): + if not list_offsets_future.is_done: + list_offsets_future.failure(err) + + for node_id, timestamps in six.iteritems(timestamps_by_node): + _f = self._send_offset_request(node_id, timestamps) + _f.add_callback(on_success) + _f.add_errback(on_fail) + return list_offsets_future + + def _send_offset_request(self, node_id, timestamps): + by_topic = collections.defaultdict(list) + for tp, timestamp in six.iteritems(timestamps): + if self.config['api_version'] >= (0, 10, 1): + data = (tp.partition, timestamp) + else: + data = (tp.partition, timestamp, 1) + by_topic[tp.topic].append(data) + + if self.config['api_version'] >= (0, 10, 1): + request = OffsetRequest[1](-1, list(six.iteritems(by_topic))) + else: + request = OffsetRequest[0](-1, list(six.iteritems(by_topic))) + + # Client returns a future that only fails on network issues + # so create a separate future and attach a callback to update it + # based on response error codes + future = Future() + + _f = self._client.send(node_id, request) + _f.add_callback(self._handle_offset_response, future) + _f.add_errback(lambda e: future.failure(e)) + return future + + def _handle_offset_response(self, future, response): + """Callback for the response of the list offset call above. + + Arguments: + future (Future): the future to update based on response + response (OffsetResponse): response from the server + + Raises: + AssertionError: if response does not match partition + """ + timestamp_offset_map = {} + for topic, part_data in response.topics: + for partition_info in part_data: + partition, error_code = partition_info[:2] + partition = TopicPartition(topic, partition) + error_type = Errors.for_code(error_code) + if error_type is Errors.NoError: + if response.API_VERSION == 0: + offsets = partition_info[2] + assert len(offsets) <= 1, 'Expected OffsetResponse with one offset' + if not offsets: + offset = UNKNOWN_OFFSET + else: + offset = offsets[0] + log.debug("Handling v0 ListOffsetResponse response for %s. " + "Fetched offset %s", partition, offset) + if offset != UNKNOWN_OFFSET: + timestamp_offset_map[partition] = (offset, None) + else: + timestamp, offset = partition_info[2:] + log.debug("Handling ListOffsetResponse response for %s. 
" + "Fetched offset %s, timestamp %s", + partition, offset, timestamp) + if offset != UNKNOWN_OFFSET: + timestamp_offset_map[partition] = (offset, timestamp) + elif error_type is Errors.UnsupportedForMessageFormatError: + # The message format on the broker side is before 0.10.0, + # we simply put None in the response. + log.debug("Cannot search by timestamp for partition %s because the" + " message format version is before 0.10.0", partition) + elif error_type is Errors.NotLeaderForPartitionError: + log.debug("Attempt to fetch offsets for partition %s failed due" + " to obsolete leadership information, retrying.", + partition) + future.failure(error_type(partition)) + return + elif error_type is Errors.UnknownTopicOrPartitionError: + log.warning("Received unknown topic or partition error in ListOffset " + "request for partition %s. The topic/partition " + + "may not exist or the user may not have Describe access " + "to it.", partition) + future.failure(error_type(partition)) + return + else: + log.warning("Attempt to fetch offsets for partition %s failed due to:" + " %s", partition, error_type) + future.failure(error_type(partition)) + return + if not future.is_done: + future.success(timestamp_offset_map) + + def _fetchable_partitions(self): + fetchable = self._subscriptions.fetchable_partitions() + # do not fetch a partition if we have a pending fetch response to process + current = self._next_partition_records + pending = copy.copy(self._completed_fetches) + if current: + fetchable.discard(current.topic_partition) + for fetch in pending: + fetchable.discard(fetch.topic_partition) + return fetchable + + def _create_fetch_requests(self): + """Create fetch requests for all assigned partitions, grouped by node. + + FetchRequests skipped if no leader, or node has requests in flight + + Returns: + dict: {node_id: FetchRequest, ...} (version depends on api_version) + """ + # create the fetch info as a dict of lists of partition info tuples + # which can be passed to FetchRequest() via .items() + fetchable = collections.defaultdict(lambda: collections.defaultdict(list)) + + for partition in self._fetchable_partitions(): + node_id = self._client.cluster.leader_for_partition(partition) + + # advance position for any deleted compacted messages if required + if self._subscriptions.assignment[partition].last_offset_from_message_batch: + next_offset_from_batch_header = self._subscriptions.assignment[partition].last_offset_from_message_batch + 1 + if next_offset_from_batch_header > self._subscriptions.assignment[partition].position: + log.debug( + "Advance position for partition %s from %s to %s (last message batch location plus one)" + " to correct for deleted compacted messages", + partition, self._subscriptions.assignment[partition].position, next_offset_from_batch_header) + self._subscriptions.assignment[partition].position = next_offset_from_batch_header + + position = self._subscriptions.assignment[partition].position + + # fetch if there is a leader and no in-flight requests + if node_id is None or node_id == -1: + log.debug("No leader found for partition %s." 
+ " Requesting metadata update", partition) + self._client.cluster.request_update() + + elif self._client.in_flight_request_count(node_id) == 0: + partition_info = ( + partition.partition, + position, + self.config['max_partition_fetch_bytes'] + ) + fetchable[node_id][partition.topic].append(partition_info) + log.debug("Adding fetch request for partition %s at offset %d", + partition, position) + else: + log.log(0, "Skipping fetch for partition %s because there is an inflight request to node %s", + partition, node_id) + + if self.config['api_version'] >= (0, 11, 0): + version = 4 + elif self.config['api_version'] >= (0, 10, 1): + version = 3 + elif self.config['api_version'] >= (0, 10): + version = 2 + elif self.config['api_version'] == (0, 9): + version = 1 + else: + version = 0 + requests = {} + for node_id, partition_data in six.iteritems(fetchable): + if version < 3: + requests[node_id] = FetchRequest[version]( + -1, # replica_id + self.config['fetch_max_wait_ms'], + self.config['fetch_min_bytes'], + partition_data.items()) + else: + # As of version == 3 partitions will be returned in order as + # they are requested, so to avoid starvation with + # `fetch_max_bytes` option we need this shuffle + # NOTE: we do have partition_data in random order due to usage + # of unordered structures like dicts, but that does not + # guarantee equal distribution, and starting in Python3.6 + # dicts retain insert order. + partition_data = list(partition_data.items()) + random.shuffle(partition_data) + if version == 3: + requests[node_id] = FetchRequest[version]( + -1, # replica_id + self.config['fetch_max_wait_ms'], + self.config['fetch_min_bytes'], + self.config['fetch_max_bytes'], + partition_data) + else: + requests[node_id] = FetchRequest[version]( + -1, # replica_id + self.config['fetch_max_wait_ms'], + self.config['fetch_min_bytes'], + self.config['fetch_max_bytes'], + self._isolation_level, + partition_data) + return requests + + def _handle_fetch_response(self, request, send_time, response): + """The callback for fetch completion""" + fetch_offsets = {} + for topic, partitions in request.topics: + for partition_data in partitions: + partition, offset = partition_data[:2] + fetch_offsets[TopicPartition(topic, partition)] = offset + + partitions = set([TopicPartition(topic, partition_data[0]) + for topic, partitions in response.topics + for partition_data in partitions]) + metric_aggregator = FetchResponseMetricAggregator(self._sensors, partitions) + + # randomized ordering should improve balance for short-lived consumers + random.shuffle(response.topics) + for topic, partitions in response.topics: + random.shuffle(partitions) + for partition_data in partitions: + tp = TopicPartition(topic, partition_data[0]) + completed_fetch = CompletedFetch( + tp, fetch_offsets[tp], + response.API_VERSION, + partition_data[1:], + metric_aggregator + ) + self._completed_fetches.append(completed_fetch) + + if response.API_VERSION >= 1: + self._sensors.fetch_throttle_time_sensor.record(response.throttle_time_ms) + self._sensors.fetch_latency.record((time.time() - send_time) * 1000) + + def _parse_fetched_data(self, completed_fetch): + tp = completed_fetch.topic_partition + fetch_offset = completed_fetch.fetched_offset + num_bytes = 0 + records_count = 0 + parsed_records = None + + error_code, highwater = completed_fetch.partition_data[:2] + error_type = Errors.for_code(error_code) + + try: + if not self._subscriptions.is_fetchable(tp): + # this can happen when a rebalance happened or a partition + # consumption 
paused while fetch is still in-flight + log.debug("Ignoring fetched records for partition %s" + " since it is no longer fetchable", tp) + + elif error_type is Errors.NoError: + self._subscriptions.assignment[tp].highwater = highwater + + # we are interested in this fetch only if the beginning + # offset (of the *request*) matches the current consumed position + # Note that the *response* may return a messageset that starts + # earlier (e.g., compressed messages) or later (e.g., compacted topic) + position = self._subscriptions.assignment[tp].position + if position is None or position != fetch_offset: + log.debug("Discarding fetch response for partition %s" + " since its offset %d does not match the" + " expected offset %d", tp, fetch_offset, + position) + return None + + records = MemoryRecords(completed_fetch.partition_data[-1]) + if records.has_next(): + log.debug("Adding fetched record for partition %s with" + " offset %d to buffered record list", tp, + position) + unpacked = list(self._unpack_message_set(tp, records)) + parsed_records = self.PartitionRecords(fetch_offset, tp, unpacked) + if unpacked: + last_offset = unpacked[-1].offset + self._sensors.records_fetch_lag.record(highwater - last_offset) + num_bytes = records.valid_bytes() + records_count = len(unpacked) + elif records.size_in_bytes() > 0: + # we did not read a single message from a non-empty + # buffer because that message's size is larger than + # fetch size, in this case record this exception + record_too_large_partitions = {tp: fetch_offset} + raise RecordTooLargeError( + "There are some messages at [Partition=Offset]: %s " + " whose size is larger than the fetch size %s" + " and hence cannot be ever returned." + " Increase the fetch size, or decrease the maximum message" + " size the broker will allow." % ( + record_too_large_partitions, + self.config['max_partition_fetch_bytes']), + record_too_large_partitions) + self._sensors.record_topic_fetch_metrics(tp.topic, num_bytes, records_count) + + elif error_type in (Errors.NotLeaderForPartitionError, + Errors.UnknownTopicOrPartitionError): + self._client.cluster.request_update() + elif error_type is Errors.OffsetOutOfRangeError: + position = self._subscriptions.assignment[tp].position + if position is None or position != fetch_offset: + log.debug("Discarding stale fetch response for partition %s" + " since the fetched offset %d does not match the" + " current offset %d", tp, fetch_offset, position) + elif self._subscriptions.has_default_offset_reset_policy(): + log.info("Fetch offset %s is out of range for topic-partition %s", fetch_offset, tp) + self._subscriptions.need_offset_reset(tp) + else: + raise Errors.OffsetOutOfRangeError({tp: fetch_offset}) + + elif error_type is Errors.TopicAuthorizationFailedError: + log.warning("Not authorized to read from topic %s.", tp.topic) + raise Errors.TopicAuthorizationFailedError(set(tp.topic)) + elif error_type is Errors.UnknownError: + log.warning("Unknown error fetching data for topic-partition %s", tp) + else: + raise error_type('Unexpected error while fetching data') + + finally: + completed_fetch.metric_aggregator.record(tp, num_bytes, records_count) + + return parsed_records + + class PartitionRecords(object): + def __init__(self, fetch_offset, tp, messages): + self.fetch_offset = fetch_offset + self.topic_partition = tp + self.messages = messages + # When fetching an offset that is in the middle of a + # compressed batch, we will get all messages in the batch. 
+ # But we want to start 'take' at the fetch_offset + # (or the next highest offset in case the message was compacted) + for i, msg in enumerate(messages): + if msg.offset < fetch_offset: + log.debug("Skipping message offset: %s (expecting %s)", + msg.offset, fetch_offset) + else: + self.message_idx = i + break + + else: + self.message_idx = 0 + self.messages = None + + # For truthiness evaluation we need to define __len__ or __nonzero__ + def __len__(self): + if self.messages is None or self.message_idx >= len(self.messages): + return 0 + return len(self.messages) - self.message_idx + + def discard(self): + self.messages = None + + def take(self, n=None): + if not len(self): + return [] + if n is None or n > len(self): + n = len(self) + next_idx = self.message_idx + n + res = self.messages[self.message_idx:next_idx] + self.message_idx = next_idx + # fetch_offset should be incremented by 1 to parallel the + # subscription position (also incremented by 1) + self.fetch_offset = max(self.fetch_offset, res[-1].offset + 1) + return res + + +class FetchResponseMetricAggregator(object): + """ + Since we parse the message data for each partition from each fetch + response lazily, fetch-level metrics need to be aggregated as the messages + from each partition are parsed. This class is used to facilitate this + incremental aggregation. + """ + def __init__(self, sensors, partitions): + self.sensors = sensors + self.unrecorded_partitions = partitions + self.total_bytes = 0 + self.total_records = 0 + + def record(self, partition, num_bytes, num_records): + """ + After each partition is parsed, we update the current metric totals + with the total bytes and number of records parsed. After all partitions + have reported, we write the metric. + """ + self.unrecorded_partitions.remove(partition) + self.total_bytes += num_bytes + self.total_records += num_records + + # once all expected partitions from the fetch have reported in, record the metrics + if not self.unrecorded_partitions: + self.sensors.bytes_fetched.record(self.total_bytes) + self.sensors.records_fetched.record(self.total_records) + + +class FetchManagerMetrics(object): + def __init__(self, metrics, prefix): + self.metrics = metrics + self.group_name = '%s-fetch-manager-metrics' % (prefix,) + + self.bytes_fetched = metrics.sensor('bytes-fetched') + self.bytes_fetched.add(metrics.metric_name('fetch-size-avg', self.group_name, + 'The average number of bytes fetched per request'), Avg()) + self.bytes_fetched.add(metrics.metric_name('fetch-size-max', self.group_name, + 'The maximum number of bytes fetched per request'), Max()) + self.bytes_fetched.add(metrics.metric_name('bytes-consumed-rate', self.group_name, + 'The average number of bytes consumed per second'), Rate()) + + self.records_fetched = self.metrics.sensor('records-fetched') + self.records_fetched.add(metrics.metric_name('records-per-request-avg', self.group_name, + 'The average number of records in each request'), Avg()) + self.records_fetched.add(metrics.metric_name('records-consumed-rate', self.group_name, + 'The average number of records consumed per second'), Rate()) + + self.fetch_latency = metrics.sensor('fetch-latency') + self.fetch_latency.add(metrics.metric_name('fetch-latency-avg', self.group_name, + 'The average time taken for a fetch request.'), Avg()) + self.fetch_latency.add(metrics.metric_name('fetch-latency-max', self.group_name, + 'The max time taken for any fetch request.'), Max()) + self.fetch_latency.add(metrics.metric_name('fetch-rate', self.group_name, + 'The 
number of fetch requests per second.'), Rate(sampled_stat=Count())) + + self.records_fetch_lag = metrics.sensor('records-lag') + self.records_fetch_lag.add(metrics.metric_name('records-lag-max', self.group_name, + 'The maximum lag in terms of number of records for any partition in self window'), Max()) + + self.fetch_throttle_time_sensor = metrics.sensor('fetch-throttle-time') + self.fetch_throttle_time_sensor.add(metrics.metric_name('fetch-throttle-time-avg', self.group_name, + 'The average throttle time in ms'), Avg()) + self.fetch_throttle_time_sensor.add(metrics.metric_name('fetch-throttle-time-max', self.group_name, + 'The maximum throttle time in ms'), Max()) + + def record_topic_fetch_metrics(self, topic, num_bytes, num_records): + # record bytes fetched + name = '.'.join(['topic', topic, 'bytes-fetched']) + bytes_fetched = self.metrics.get_sensor(name) + if not bytes_fetched: + metric_tags = {'topic': topic.replace('.', '_')} + + bytes_fetched = self.metrics.sensor(name) + bytes_fetched.add(self.metrics.metric_name('fetch-size-avg', + self.group_name, + 'The average number of bytes fetched per request for topic %s' % (topic,), + metric_tags), Avg()) + bytes_fetched.add(self.metrics.metric_name('fetch-size-max', + self.group_name, + 'The maximum number of bytes fetched per request for topic %s' % (topic,), + metric_tags), Max()) + bytes_fetched.add(self.metrics.metric_name('bytes-consumed-rate', + self.group_name, + 'The average number of bytes consumed per second for topic %s' % (topic,), + metric_tags), Rate()) + bytes_fetched.record(num_bytes) + + # record records fetched + name = '.'.join(['topic', topic, 'records-fetched']) + records_fetched = self.metrics.get_sensor(name) + if not records_fetched: + metric_tags = {'topic': topic.replace('.', '_')} + + records_fetched = self.metrics.sensor(name) + records_fetched.add(self.metrics.metric_name('records-per-request-avg', + self.group_name, + 'The average number of records in each request for topic %s' % (topic,), + metric_tags), Avg()) + records_fetched.add(self.metrics.metric_name('records-consumed-rate', + self.group_name, + 'The average number of records consumed per second for topic %s' % (topic,), + metric_tags), Rate()) + records_fetched.record(num_records) diff --git a/consumer/group.py b/consumer/group.py new file mode 100644 index 00000000..a1d1dfa3 --- /dev/null +++ b/consumer/group.py @@ -0,0 +1,1225 @@ +from __future__ import absolute_import, division + +import copy +import logging +import socket +import time + +from kafka.errors import KafkaConfigurationError, UnsupportedVersionError + +from kafka.vendor import six + +from kafka.client_async import KafkaClient, selectors +from kafka.consumer.fetcher import Fetcher +from kafka.consumer.subscription_state import SubscriptionState +from kafka.coordinator.consumer import ConsumerCoordinator +from kafka.coordinator.assignors.range import RangePartitionAssignor +from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor +from kafka.metrics import MetricConfig, Metrics +from kafka.protocol.offset import OffsetResetStrategy +from kafka.structs import TopicPartition +from kafka.version import __version__ + +log = logging.getLogger(__name__) + + +class KafkaConsumer(six.Iterator): + """Consume records from a Kafka cluster. + + The consumer will transparently handle the failure of servers in the Kafka + cluster, and adapt as topic-partitions are created or migrate between + brokers. 
It also interacts with the assigned kafka Group Coordinator node + to allow multiple consumers to load balance consumption of topics (requires + kafka >= 0.9.0.0). + + The consumer is not thread safe and should not be shared across threads. + + Arguments: + *topics (str): optional list of topics to subscribe to. If not set, + call :meth:`~kafka.KafkaConsumer.subscribe` or + :meth:`~kafka.KafkaConsumer.assign` before consuming records. + + Keyword Arguments: + bootstrap_servers: 'host[:port]' string (or list of 'host[:port]' + strings) that the consumer should contact to bootstrap initial + cluster metadata. This does not have to be the full node list. + It just needs to have at least one broker that will respond to a + Metadata API Request. Default port is 9092. If no servers are + specified, will default to localhost:9092. + client_id (str): A name for this client. This string is passed in + each request to servers and can be used to identify specific + server-side log entries that correspond to this client. Also + submitted to GroupCoordinator for logging with respect to + consumer group administration. Default: 'kafka-python-{version}' + group_id (str or None): The name of the consumer group to join for dynamic + partition assignment (if enabled), and to use for fetching and + committing offsets. If None, auto-partition assignment (via + group coordinator) and offset commits are disabled. + Default: None + key_deserializer (callable): Any callable that takes a + raw message key and returns a deserialized key. + value_deserializer (callable): Any callable that takes a + raw message value and returns a deserialized value. + fetch_min_bytes (int): Minimum amount of data the server should + return for a fetch request, otherwise wait up to + fetch_max_wait_ms for more data to accumulate. Default: 1. + fetch_max_wait_ms (int): The maximum amount of time in milliseconds + the server will block before answering the fetch request if + there isn't sufficient data to immediately satisfy the + requirement given by fetch_min_bytes. Default: 500. + fetch_max_bytes (int): The maximum amount of data the server should + return for a fetch request. This is not an absolute maximum, if the + first message in the first non-empty partition of the fetch is + larger than this value, the message will still be returned to + ensure that the consumer can make progress. NOTE: consumer performs + fetches to multiple brokers in parallel so memory usage will depend + on the number of brokers containing partitions for the topic. + Supported Kafka version >= 0.10.1.0. Default: 52428800 (50 MB). + max_partition_fetch_bytes (int): The maximum amount of data + per-partition the server will return. The maximum total memory + used for a request = #partitions * max_partition_fetch_bytes. + This size must be at least as large as the maximum message size + the server allows or else it is possible for the producer to + send messages larger than the consumer can fetch. If that + happens, the consumer can get stuck trying to fetch a large + message on a certain partition. Default: 1048576. + request_timeout_ms (int): Client request timeout in milliseconds. + Default: 305000. + retry_backoff_ms (int): Milliseconds to backoff when retrying on + errors. Default: 100. + reconnect_backoff_ms (int): The amount of time in milliseconds to + wait before attempting to reconnect to a given host. + Default: 50. 
+ reconnect_backoff_max_ms (int): The maximum amount of time in + milliseconds to backoff/wait when reconnecting to a broker that has + repeatedly failed to connect. If provided, the backoff per host + will increase exponentially for each consecutive connection + failure, up to this maximum. Once the maximum is reached, + reconnection attempts will continue periodically with this fixed + rate. To avoid connection storms, a randomization factor of 0.2 + will be applied to the backoff resulting in a random range between + 20% below and 20% above the computed value. Default: 1000. + max_in_flight_requests_per_connection (int): Requests are pipelined + to kafka brokers up to this number of maximum requests per + broker connection. Default: 5. + auto_offset_reset (str): A policy for resetting offsets on + OffsetOutOfRange errors: 'earliest' will move to the oldest + available message, 'latest' will move to the most recent. Any + other value will raise the exception. Default: 'latest'. + enable_auto_commit (bool): If True , the consumer's offset will be + periodically committed in the background. Default: True. + auto_commit_interval_ms (int): Number of milliseconds between automatic + offset commits, if enable_auto_commit is True. Default: 5000. + default_offset_commit_callback (callable): Called as + callback(offsets, response) response will be either an Exception + or an OffsetCommitResponse struct. This callback can be used to + trigger custom actions when a commit request completes. + check_crcs (bool): Automatically check the CRC32 of the records + consumed. This ensures no on-the-wire or on-disk corruption to + the messages occurred. This check adds some overhead, so it may + be disabled in cases seeking extreme performance. Default: True + metadata_max_age_ms (int): The period of time in milliseconds after + which we force a refresh of metadata, even if we haven't seen any + partition leadership changes to proactively discover any new + brokers or partitions. Default: 300000 + partition_assignment_strategy (list): List of objects to use to + distribute partition ownership amongst consumer instances when + group management is used. + Default: [RangePartitionAssignor, RoundRobinPartitionAssignor] + max_poll_records (int): The maximum number of records returned in a + single call to :meth:`~kafka.KafkaConsumer.poll`. Default: 500 + max_poll_interval_ms (int): The maximum delay between invocations of + :meth:`~kafka.KafkaConsumer.poll` when using consumer group + management. This places an upper bound on the amount of time that + the consumer can be idle before fetching more records. If + :meth:`~kafka.KafkaConsumer.poll` is not called before expiration + of this timeout, then the consumer is considered failed and the + group will rebalance in order to reassign the partitions to another + member. Default 300000 + session_timeout_ms (int): The timeout used to detect failures when + using Kafka's group management facilities. The consumer sends + periodic heartbeats to indicate its liveness to the broker. If + no heartbeats are received by the broker before the expiration of + this session timeout, then the broker will remove this consumer + from the group and initiate a rebalance. Note that the value must + be in the allowable range as configured in the broker configuration + by group.min.session.timeout.ms and group.max.session.timeout.ms. 
+ Default: 10000 + heartbeat_interval_ms (int): The expected time in milliseconds + between heartbeats to the consumer coordinator when using + Kafka's group management facilities. Heartbeats are used to ensure + that the consumer's session stays active and to facilitate + rebalancing when new consumers join or leave the group. The + value must be set lower than session_timeout_ms, but typically + should be set no higher than 1/3 of that value. It can be + adjusted even lower to control the expected time for normal + rebalances. Default: 3000 + receive_buffer_bytes (int): The size of the TCP receive buffer + (SO_RCVBUF) to use when reading data. Default: None (relies on + system defaults). The java client defaults to 32768. + send_buffer_bytes (int): The size of the TCP send buffer + (SO_SNDBUF) to use when sending data. Default: None (relies on + system defaults). The java client defaults to 131072. + socket_options (list): List of tuple-arguments to socket.setsockopt + to apply to broker connection sockets. Default: + [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)] + consumer_timeout_ms (int): number of milliseconds to block during + message iteration before raising StopIteration (i.e., ending the + iterator). Default block forever [float('inf')]. + security_protocol (str): Protocol used to communicate with brokers. + Valid values are: PLAINTEXT, SSL, SASL_PLAINTEXT, SASL_SSL. + Default: PLAINTEXT. + ssl_context (ssl.SSLContext): Pre-configured SSLContext for wrapping + socket connections. If provided, all other ssl_* configurations + will be ignored. Default: None. + ssl_check_hostname (bool): Flag to configure whether ssl handshake + should verify that the certificate matches the brokers hostname. + Default: True. + ssl_cafile (str): Optional filename of ca file to use in certificate + verification. Default: None. + ssl_certfile (str): Optional filename of file in pem format containing + the client certificate, as well as any ca certificates needed to + establish the certificate's authenticity. Default: None. + ssl_keyfile (str): Optional filename containing the client private key. + Default: None. + ssl_password (str): Optional password to be used when loading the + certificate chain. Default: None. + ssl_crlfile (str): Optional filename containing the CRL to check for + certificate expiration. By default, no CRL check is done. When + providing a file, only the leaf certificate will be checked against + this CRL. The CRL can only be checked with Python 3.4+ or 2.7.9+. + Default: None. + ssl_ciphers (str): optionally set the available ciphers for ssl + connections. It should be a string in the OpenSSL cipher list + format. If no cipher can be selected (because compile-time options + or other configuration forbids use of all the specified ciphers), + an ssl.SSLError will be raised. See ssl.SSLContext.set_ciphers + api_version (tuple): Specify which Kafka API version to use. If set to + None, the client will attempt to infer the broker version by probing + various APIs. Different versions enable different functionality. + + Examples: + (0, 9) enables full group coordination features with automatic + partition assignment and rebalancing, + (0, 8, 2) enables kafka-storage offset commits with manual + partition assignment only, + (0, 8, 1) enables zookeeper-storage offset commits with manual + partition assignment only, + (0, 8, 0) enables basic functionality but requires manual + partition assignment and offset management. 
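As a rough usage sketch of the options documented above (the broker address, topic name, and JSON value deserializer are hypothetical placeholders, not part of the vendored source):

    import json

    from kafka import KafkaConsumer

    # Hypothetical consumer: joins group 'my-group', starts from the earliest
    # available offset when no committed offset exists, and JSON-decodes
    # message values.
    consumer = KafkaConsumer(
        'my-topic',
        bootstrap_servers='localhost:9092',
        group_id='my-group',
        auto_offset_reset='earliest',
        value_deserializer=lambda v: json.loads(v.decode('utf-8')),
    )
    for message in consumer:
        print(message.topic, message.partition, message.offset, message.value)

Iterating the consumer yields one record at a time; the configuration names used here are the same keyword arguments documented in this docstring.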
+ + Default: None + api_version_auto_timeout_ms (int): number of milliseconds to throw a + timeout exception from the constructor when checking the broker + api version. Only applies if api_version set to None. + connections_max_idle_ms: Close idle connections after the number of + milliseconds specified by this config. The broker closes idle + connections after connections.max.idle.ms, so this avoids hitting + unexpected socket disconnected errors on the client. + Default: 540000 + metric_reporters (list): A list of classes to use as metrics reporters. + Implementing the AbstractMetricsReporter interface allows plugging + in classes that will be notified of new metric creation. Default: [] + metrics_num_samples (int): The number of samples maintained to compute + metrics. Default: 2 + metrics_sample_window_ms (int): The maximum age in milliseconds of + samples used to compute metrics. Default: 30000 + selector (selectors.BaseSelector): Provide a specific selector + implementation to use for I/O multiplexing. + Default: selectors.DefaultSelector + exclude_internal_topics (bool): Whether records from internal topics + (such as offsets) should be exposed to the consumer. If set to True + the only way to receive records from an internal topic is + subscribing to it. Requires 0.10+ Default: True + sasl_mechanism (str): Authentication mechanism when security_protocol + is configured for SASL_PLAINTEXT or SASL_SSL. Valid values are: + PLAIN, GSSAPI, OAUTHBEARER, SCRAM-SHA-256, SCRAM-SHA-512. + sasl_plain_username (str): username for sasl PLAIN and SCRAM authentication. + Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. + sasl_plain_password (str): password for sasl PLAIN and SCRAM authentication. + Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. + sasl_kerberos_service_name (str): Service name to include in GSSAPI + sasl mechanism handshake. Default: 'kafka' + sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI + sasl mechanism handshake. Default: one of bootstrap servers + sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider + instance. (See kafka.oauth.abstract). 
Default: None + kafka_client (callable): Custom class / callable for creating KafkaClient instances + + Note: + Configuration parameters are described in more detail at + https://kafka.apache.org/documentation/#consumerconfigs + """ + DEFAULT_CONFIG = { + 'bootstrap_servers': 'localhost', + 'client_id': 'kafka-python-' + __version__, + 'group_id': None, + 'key_deserializer': None, + 'value_deserializer': None, + 'fetch_max_wait_ms': 500, + 'fetch_min_bytes': 1, + 'fetch_max_bytes': 52428800, + 'max_partition_fetch_bytes': 1 * 1024 * 1024, + 'request_timeout_ms': 305000, # chosen to be higher than the default of max_poll_interval_ms + 'retry_backoff_ms': 100, + 'reconnect_backoff_ms': 50, + 'reconnect_backoff_max_ms': 1000, + 'max_in_flight_requests_per_connection': 5, + 'auto_offset_reset': 'latest', + 'enable_auto_commit': True, + 'auto_commit_interval_ms': 5000, + 'default_offset_commit_callback': lambda offsets, response: True, + 'check_crcs': True, + 'metadata_max_age_ms': 5 * 60 * 1000, + 'partition_assignment_strategy': (RangePartitionAssignor, RoundRobinPartitionAssignor), + 'max_poll_records': 500, + 'max_poll_interval_ms': 300000, + 'session_timeout_ms': 10000, + 'heartbeat_interval_ms': 3000, + 'receive_buffer_bytes': None, + 'send_buffer_bytes': None, + 'socket_options': [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)], + 'sock_chunk_bytes': 4096, # undocumented experimental option + 'sock_chunk_buffer_count': 1000, # undocumented experimental option + 'consumer_timeout_ms': float('inf'), + 'security_protocol': 'PLAINTEXT', + 'ssl_context': None, + 'ssl_check_hostname': True, + 'ssl_cafile': None, + 'ssl_certfile': None, + 'ssl_keyfile': None, + 'ssl_crlfile': None, + 'ssl_password': None, + 'ssl_ciphers': None, + 'api_version': None, + 'api_version_auto_timeout_ms': 2000, + 'connections_max_idle_ms': 9 * 60 * 1000, + 'metric_reporters': [], + 'metrics_num_samples': 2, + 'metrics_sample_window_ms': 30000, + 'metric_group_prefix': 'consumer', + 'selector': selectors.DefaultSelector, + 'exclude_internal_topics': True, + 'sasl_mechanism': None, + 'sasl_plain_username': None, + 'sasl_plain_password': None, + 'sasl_kerberos_service_name': 'kafka', + 'sasl_kerberos_domain_name': None, + 'sasl_oauth_token_provider': None, + 'legacy_iterator': False, # enable to revert to < 1.4.7 iterator + 'kafka_client': KafkaClient, + } + DEFAULT_SESSION_TIMEOUT_MS_0_9 = 30000 + + def __init__(self, *topics, **configs): + # Only check for extra config keys in top-level class + extra_configs = set(configs).difference(self.DEFAULT_CONFIG) + if extra_configs: + raise KafkaConfigurationError("Unrecognized configs: %s" % (extra_configs,)) + + self.config = copy.copy(self.DEFAULT_CONFIG) + self.config.update(configs) + + deprecated = {'smallest': 'earliest', 'largest': 'latest'} + if self.config['auto_offset_reset'] in deprecated: + new_config = deprecated[self.config['auto_offset_reset']] + log.warning('use auto_offset_reset=%s (%s is deprecated)', + new_config, self.config['auto_offset_reset']) + self.config['auto_offset_reset'] = new_config + + connections_max_idle_ms = self.config['connections_max_idle_ms'] + request_timeout_ms = self.config['request_timeout_ms'] + fetch_max_wait_ms = self.config['fetch_max_wait_ms'] + if not (fetch_max_wait_ms < request_timeout_ms < connections_max_idle_ms): + raise KafkaConfigurationError( + "connections_max_idle_ms ({}) must be larger than " + "request_timeout_ms ({}) which must be larger than " + "fetch_max_wait_ms ({})." 
+ .format(connections_max_idle_ms, request_timeout_ms, fetch_max_wait_ms)) + + metrics_tags = {'client-id': self.config['client_id']} + metric_config = MetricConfig(samples=self.config['metrics_num_samples'], + time_window_ms=self.config['metrics_sample_window_ms'], + tags=metrics_tags) + reporters = [reporter() for reporter in self.config['metric_reporters']] + self._metrics = Metrics(metric_config, reporters) + # TODO _metrics likely needs to be passed to KafkaClient, etc. + + # api_version was previously a str. Accept old format for now + if isinstance(self.config['api_version'], str): + str_version = self.config['api_version'] + if str_version == 'auto': + self.config['api_version'] = None + else: + self.config['api_version'] = tuple(map(int, str_version.split('.'))) + log.warning('use api_version=%s [tuple] -- "%s" as str is deprecated', + str(self.config['api_version']), str_version) + + self._client = self.config['kafka_client'](metrics=self._metrics, **self.config) + + # Get auto-discovered version from client if necessary + if self.config['api_version'] is None: + self.config['api_version'] = self._client.config['api_version'] + + # Coordinator configurations are different for older brokers + # max_poll_interval_ms is not supported directly -- it must the be + # the same as session_timeout_ms. If the user provides one of them, + # use it for both. Otherwise use the old default of 30secs + if self.config['api_version'] < (0, 10, 1): + if 'session_timeout_ms' not in configs: + if 'max_poll_interval_ms' in configs: + self.config['session_timeout_ms'] = configs['max_poll_interval_ms'] + else: + self.config['session_timeout_ms'] = self.DEFAULT_SESSION_TIMEOUT_MS_0_9 + if 'max_poll_interval_ms' not in configs: + self.config['max_poll_interval_ms'] = self.config['session_timeout_ms'] + + if self.config['group_id'] is not None: + if self.config['request_timeout_ms'] <= self.config['session_timeout_ms']: + raise KafkaConfigurationError( + "Request timeout (%s) must be larger than session timeout (%s)" % + (self.config['request_timeout_ms'], self.config['session_timeout_ms'])) + + self._subscription = SubscriptionState(self.config['auto_offset_reset']) + self._fetcher = Fetcher( + self._client, self._subscription, self._metrics, **self.config) + self._coordinator = ConsumerCoordinator( + self._client, self._subscription, self._metrics, + assignors=self.config['partition_assignment_strategy'], + **self.config) + self._closed = False + self._iterator = None + self._consumer_timeout = float('inf') + + if topics: + self._subscription.subscribe(topics=topics) + self._client.set_topics(topics) + + def bootstrap_connected(self): + """Return True if the bootstrap is connected.""" + return self._client.bootstrap_connected() + + def assign(self, partitions): + """Manually assign a list of TopicPartitions to this consumer. + + Arguments: + partitions (list of TopicPartition): Assignment for this instance. + + Raises: + IllegalStateError: If consumer has already called + :meth:`~kafka.KafkaConsumer.subscribe`. + + Warning: + It is not possible to use both manual partition assignment with + :meth:`~kafka.KafkaConsumer.assign` and group assignment with + :meth:`~kafka.KafkaConsumer.subscribe`. + + Note: + This interface does not support incremental assignment and will + replace the previous assignment (if there was one). + + Note: + Manual topic assignment through this method does not use the + consumer's group management functionality. 
As such, there will be + no rebalance operation triggered when group membership or cluster + and topic metadata change. + """ + self._subscription.assign_from_user(partitions) + self._client.set_topics([tp.topic for tp in partitions]) + + def assignment(self): + """Get the TopicPartitions currently assigned to this consumer. + + If partitions were directly assigned using + :meth:`~kafka.KafkaConsumer.assign`, then this will simply return the + same partitions that were previously assigned. If topics were + subscribed using :meth:`~kafka.KafkaConsumer.subscribe`, then this will + give the set of topic partitions currently assigned to the consumer + (which may be None if the assignment hasn't happened yet, or if the + partitions are in the process of being reassigned). + + Returns: + set: {TopicPartition, ...} + """ + return self._subscription.assigned_partitions() + + def close(self, autocommit=True): + """Close the consumer, waiting indefinitely for any needed cleanup. + + Keyword Arguments: + autocommit (bool): If auto-commit is configured for this consumer, + this optional flag causes the consumer to attempt to commit any + pending consumed offsets prior to close. Default: True + """ + if self._closed: + return + log.debug("Closing the KafkaConsumer.") + self._closed = True + self._coordinator.close(autocommit=autocommit) + self._metrics.close() + self._client.close() + try: + self.config['key_deserializer'].close() + except AttributeError: + pass + try: + self.config['value_deserializer'].close() + except AttributeError: + pass + log.debug("The KafkaConsumer has closed.") + + def commit_async(self, offsets=None, callback=None): + """Commit offsets to kafka asynchronously, optionally firing callback. + + This commits offsets only to Kafka. The offsets committed using this API + will be used on the first fetch after every rebalance and also on + startup. As such, if you need to store offsets in anything other than + Kafka, this API should not be used. To avoid re-processing the last + message read if a consumer is restarted, the committed offset should be + the next message your application should consume, i.e.: last_offset + 1. + + This is an asynchronous call and will not block. Any errors encountered + are either passed to the callback (if provided) or discarded. + + Arguments: + offsets (dict, optional): {TopicPartition: OffsetAndMetadata} dict + to commit with the configured group_id. Defaults to currently + consumed offsets for all subscribed partitions. + callback (callable, optional): Called as callback(offsets, response) + with response as either an Exception or an OffsetCommitResponse + struct. This callback can be used to trigger custom actions when + a commit request completes. + + Returns: + kafka.future.Future + """ + assert self.config['api_version'] >= (0, 8, 1), 'Requires >= Kafka 0.8.1' + assert self.config['group_id'] is not None, 'Requires group_id' + if offsets is None: + offsets = self._subscription.all_consumed_offsets() + log.debug("Committing offsets: %s", offsets) + future = self._coordinator.commit_offsets_async( + offsets, callback=callback) + return future + + def commit(self, offsets=None): + """Commit offsets to kafka, blocking until success or error. + + This commits offsets only to Kafka. The offsets committed using this API + will be used on the first fetch after every rebalance and also on + startup. As such, if you need to store offsets in anything other than + Kafka, this API should not be used. 
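A hedged sketch of the commit pattern these docstrings describe: asynchronous commits with a status callback during normal processing, plus a blocking commit before shutdown. The broker address, topic, and group name are hypothetical:

    import logging

    from kafka import KafkaConsumer

    log = logging.getLogger(__name__)

    def on_commit(offsets, response):
        # Per the docstring above, response is an OffsetCommitResponse struct
        # on success or an Exception on failure.
        if isinstance(response, Exception):
            log.error('Async commit failed for %s: %s', offsets, response)

    # Hypothetical consumer; a group_id is required for committing offsets.
    consumer = KafkaConsumer('my-topic',
                             bootstrap_servers='localhost:9092',
                             group_id='my-group',
                             enable_auto_commit=False)
    try:
        while True:
            records = consumer.poll(timeout_ms=1000)
            # ... process records, then commit asynchronously ...
            consumer.commit_async(callback=on_commit)
    finally:
        consumer.commit()  # blocking commit so processed offsets are not lost
        consumer.close()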
To avoid re-processing the last + message read if a consumer is restarted, the committed offset should be + the next message your application should consume, i.e.: last_offset + 1. + + Blocks until either the commit succeeds or an unrecoverable error is + encountered (in which case it is thrown to the caller). + + Currently only supports kafka-topic offset storage (not zookeeper). + + Arguments: + offsets (dict, optional): {TopicPartition: OffsetAndMetadata} dict + to commit with the configured group_id. Defaults to currently + consumed offsets for all subscribed partitions. + """ + assert self.config['api_version'] >= (0, 8, 1), 'Requires >= Kafka 0.8.1' + assert self.config['group_id'] is not None, 'Requires group_id' + if offsets is None: + offsets = self._subscription.all_consumed_offsets() + self._coordinator.commit_offsets_sync(offsets) + + def committed(self, partition, metadata=False): + """Get the last committed offset for the given partition. + + This offset will be used as the position for the consumer + in the event of a failure. + + This call may block to do a remote call if the partition in question + isn't assigned to this consumer or if the consumer hasn't yet + initialized its cache of committed offsets. + + Arguments: + partition (TopicPartition): The partition to check. + metadata (bool, optional): If True, return OffsetAndMetadata struct + instead of offset int. Default: False. + + Returns: + The last committed offset (int or OffsetAndMetadata), or None if there was no prior commit. + """ + assert self.config['api_version'] >= (0, 8, 1), 'Requires >= Kafka 0.8.1' + assert self.config['group_id'] is not None, 'Requires group_id' + if not isinstance(partition, TopicPartition): + raise TypeError('partition must be a TopicPartition namedtuple') + if self._subscription.is_assigned(partition): + committed = self._subscription.assignment[partition].committed + if committed is None: + self._coordinator.refresh_committed_offsets_if_needed() + committed = self._subscription.assignment[partition].committed + else: + commit_map = self._coordinator.fetch_committed_offsets([partition]) + if partition in commit_map: + committed = commit_map[partition] + else: + committed = None + + if committed is not None: + if metadata: + return committed + else: + return committed.offset + + def _fetch_all_topic_metadata(self): + """A blocking call that fetches topic metadata for all topics in the + cluster that the user is authorized to view. + """ + cluster = self._client.cluster + if self._client._metadata_refresh_in_progress and self._client._topics: + future = cluster.request_update() + self._client.poll(future=future) + stash = cluster.need_all_topic_metadata + cluster.need_all_topic_metadata = True + future = cluster.request_update() + self._client.poll(future=future) + cluster.need_all_topic_metadata = stash + + def topics(self): + """Get all topics the user is authorized to view. + This will always issue a remote call to the cluster to fetch the latest + information. + + Returns: + set: topics + """ + self._fetch_all_topic_metadata() + return self._client.cluster.topics() + + def partitions_for_topic(self, topic): + """This method first checks the local metadata cache for information + about the topic. If the topic is not found (either because the topic + does not exist, the user is not authorized to view the topic, or the + metadata cache is not populated), then it will issue a metadata update + call to the cluster. + + Arguments: + topic (str): Topic to check. 
+ + Returns: + set: Partition ids + """ + cluster = self._client.cluster + partitions = cluster.partitions_for_topic(topic) + if partitions is None: + self._fetch_all_topic_metadata() + partitions = cluster.partitions_for_topic(topic) + return partitions + + def poll(self, timeout_ms=0, max_records=None, update_offsets=True): + """Fetch data from assigned topics / partitions. + + Records are fetched and returned in batches by topic-partition. + On each poll, consumer will try to use the last consumed offset as the + starting offset and fetch sequentially. The last consumed offset can be + manually set through :meth:`~kafka.KafkaConsumer.seek` or automatically + set as the last committed offset for the subscribed list of partitions. + + Incompatible with iterator interface -- use one or the other, not both. + + Arguments: + timeout_ms (int, optional): Milliseconds spent waiting in poll if + data is not available in the buffer. If 0, returns immediately + with any records that are available currently in the buffer, + else returns empty. Must not be negative. Default: 0 + max_records (int, optional): The maximum number of records returned + in a single call to :meth:`~kafka.KafkaConsumer.poll`. + Default: Inherit value from max_poll_records. + + Returns: + dict: Topic to list of records since the last fetch for the + subscribed list of topics and partitions. + """ + # Note: update_offsets is an internal-use only argument. It is used to + # support the python iterator interface, and which wraps consumer.poll() + # and requires that the partition offsets tracked by the fetcher are not + # updated until the iterator returns each record to the user. As such, + # the argument is not documented and should not be relied on by library + # users to not break in the future. + assert timeout_ms >= 0, 'Timeout must not be negative' + if max_records is None: + max_records = self.config['max_poll_records'] + assert isinstance(max_records, int), 'max_records must be an integer' + assert max_records > 0, 'max_records must be positive' + assert not self._closed, 'KafkaConsumer is closed' + + # Poll for new data until the timeout expires + start = time.time() + remaining = timeout_ms + while not self._closed: + records = self._poll_once(remaining, max_records, update_offsets=update_offsets) + if records: + return records + + elapsed_ms = (time.time() - start) * 1000 + remaining = timeout_ms - elapsed_ms + + if remaining <= 0: + break + + return {} + + def _poll_once(self, timeout_ms, max_records, update_offsets=True): + """Do one round of polling. In addition to checking for new data, this does + any needed heart-beating, auto-commits, and offset updates. + + Arguments: + timeout_ms (int): The maximum time in milliseconds to block. + + Returns: + dict: Map of topic to list of records (may be empty). + """ + self._coordinator.poll() + + # Fetch positions if we have partitions we're subscribed to that we + # don't know the offset for + if not self._subscription.has_all_fetch_positions(): + self._update_fetch_positions(self._subscription.missing_fetch_positions()) + + # If data is available already, e.g. from a previous network client + # poll() call to commit, then just return it immediately + records, partial = self._fetcher.fetched_records(max_records, update_offsets=update_offsets) + if records: + # Before returning the fetched records, we can send off the + # next round of fetches and avoid block waiting for their + # responses to enable pipelining while the user is handling the + # fetched records. 
+ if not partial: + futures = self._fetcher.send_fetches() + if len(futures): + self._client.poll(timeout_ms=0) + return records + + # Send any new fetches (won't resend pending fetches) + futures = self._fetcher.send_fetches() + if len(futures): + self._client.poll(timeout_ms=0) + + timeout_ms = min(timeout_ms, self._coordinator.time_to_next_poll() * 1000) + self._client.poll(timeout_ms=timeout_ms) + # after the long poll, we should check whether the group needs to rebalance + # prior to returning data so that the group can stabilize faster + if self._coordinator.need_rejoin(): + return {} + + records, _ = self._fetcher.fetched_records(max_records, update_offsets=update_offsets) + return records + + def position(self, partition): + """Get the offset of the next record that will be fetched + + Arguments: + partition (TopicPartition): Partition to check + + Returns: + int: Offset + """ + if not isinstance(partition, TopicPartition): + raise TypeError('partition must be a TopicPartition namedtuple') + assert self._subscription.is_assigned(partition), 'Partition is not assigned' + offset = self._subscription.assignment[partition].position + if offset is None: + self._update_fetch_positions([partition]) + offset = self._subscription.assignment[partition].position + return offset + + def highwater(self, partition): + """Last known highwater offset for a partition. + + A highwater offset is the offset that will be assigned to the next + message that is produced. It may be useful for calculating lag, by + comparing with the reported position. Note that both position and + highwater refer to the *next* offset -- i.e., highwater offset is + one greater than the newest available message. + + Highwater offsets are returned in FetchResponse messages, so will + not be available if no FetchRequests have been sent for this partition + yet. + + Arguments: + partition (TopicPartition): Partition to check + + Returns: + int or None: Offset if available + """ + if not isinstance(partition, TopicPartition): + raise TypeError('partition must be a TopicPartition namedtuple') + assert self._subscription.is_assigned(partition), 'Partition is not assigned' + return self._subscription.assignment[partition].highwater + + def pause(self, *partitions): + """Suspend fetching from the requested partitions. + + Future calls to :meth:`~kafka.KafkaConsumer.poll` will not return any + records from these partitions until they have been resumed using + :meth:`~kafka.KafkaConsumer.resume`. + + Note: This method does not affect partition subscription. In particular, + it does not cause a group rebalance when automatic assignment is used. + + Arguments: + *partitions (TopicPartition): Partitions to pause. + """ + if not all([isinstance(p, TopicPartition) for p in partitions]): + raise TypeError('partitions must be TopicPartition namedtuples') + for partition in partitions: + log.debug("Pausing partition %s", partition) + self._subscription.pause(partition) + # Because the iterator checks is_fetchable() on each iteration + # we expect pauses to get handled automatically and therefore + # we do not need to reset the full iterator (forcing a full refetch) + + def paused(self): + """Get the partitions that were previously paused using + :meth:`~kafka.KafkaConsumer.pause`. + + Returns: + set: {partition (TopicPartition), ...} + """ + return self._subscription.paused_partitions() + + def resume(self, *partitions): + """Resume fetching from the specified (paused) partitions. 
+ + Arguments: + *partitions (TopicPartition): Partitions to resume. + """ + if not all([isinstance(p, TopicPartition) for p in partitions]): + raise TypeError('partitions must be TopicPartition namedtuples') + for partition in partitions: + log.debug("Resuming partition %s", partition) + self._subscription.resume(partition) + + def seek(self, partition, offset): + """Manually specify the fetch offset for a TopicPartition. + + Overrides the fetch offsets that the consumer will use on the next + :meth:`~kafka.KafkaConsumer.poll`. If this API is invoked for the same + partition more than once, the latest offset will be used on the next + :meth:`~kafka.KafkaConsumer.poll`. + + Note: You may lose data if this API is arbitrarily used in the middle of + consumption to reset the fetch offsets. + + Arguments: + partition (TopicPartition): Partition for seek operation + offset (int): Message offset in partition + + Raises: + AssertionError: If offset is not an int >= 0; or if partition is not + currently assigned. + """ + if not isinstance(partition, TopicPartition): + raise TypeError('partition must be a TopicPartition namedtuple') + assert isinstance(offset, int) and offset >= 0, 'Offset must be >= 0' + assert partition in self._subscription.assigned_partitions(), 'Unassigned partition' + log.debug("Seeking to offset %s for partition %s", offset, partition) + self._subscription.assignment[partition].seek(offset) + if not self.config['legacy_iterator']: + self._iterator = None + + def seek_to_beginning(self, *partitions): + """Seek to the oldest available offset for partitions. + + Arguments: + *partitions: Optionally provide specific TopicPartitions, otherwise + default to all assigned partitions. + + Raises: + AssertionError: If any partition is not currently assigned, or if + no partitions are assigned. + """ + if not all([isinstance(p, TopicPartition) for p in partitions]): + raise TypeError('partitions must be TopicPartition namedtuples') + if not partitions: + partitions = self._subscription.assigned_partitions() + assert partitions, 'No partitions are currently assigned' + else: + for p in partitions: + assert p in self._subscription.assigned_partitions(), 'Unassigned partition' + + for tp in partitions: + log.debug("Seeking to beginning of partition %s", tp) + self._subscription.need_offset_reset(tp, OffsetResetStrategy.EARLIEST) + if not self.config['legacy_iterator']: + self._iterator = None + + def seek_to_end(self, *partitions): + """Seek to the most recent available offset for partitions. + + Arguments: + *partitions: Optionally provide specific TopicPartitions, otherwise + default to all assigned partitions. + + Raises: + AssertionError: If any partition is not currently assigned, or if + no partitions are assigned. + """ + if not all([isinstance(p, TopicPartition) for p in partitions]): + raise TypeError('partitions must be TopicPartition namedtuples') + if not partitions: + partitions = self._subscription.assigned_partitions() + assert partitions, 'No partitions are currently assigned' + else: + for p in partitions: + assert p in self._subscription.assigned_partitions(), 'Unassigned partition' + + for tp in partitions: + log.debug("Seeking to end of partition %s", tp) + self._subscription.need_offset_reset(tp, OffsetResetStrategy.LATEST) + if not self.config['legacy_iterator']: + self._iterator = None + + def subscribe(self, topics=(), pattern=None, listener=None): + """Subscribe to a list of topics, or a topic regex pattern. 
+ + Partitions will be dynamically assigned via a group coordinator. + Topic subscriptions are not incremental: this list will replace the + current assignment (if there is one). + + This method is incompatible with :meth:`~kafka.KafkaConsumer.assign`. + + Arguments: + topics (list): List of topics for subscription. + pattern (str): Pattern to match available topics. You must provide + either topics or pattern, but not both. + listener (ConsumerRebalanceListener): Optionally include listener + callback, which will be called before and after each rebalance + operation. + + As part of group management, the consumer will keep track of the + list of consumers that belong to a particular group and will + trigger a rebalance operation if one of the following events + trigger: + + * Number of partitions change for any of the subscribed topics + * Topic is created or deleted + * An existing member of the consumer group dies + * A new member is added to the consumer group + + When any of these events are triggered, the provided listener + will be invoked first to indicate that the consumer's assignment + has been revoked, and then again when the new assignment has + been received. Note that this listener will immediately override + any listener set in a previous call to subscribe. It is + guaranteed, however, that the partitions revoked/assigned + through this interface are from topics subscribed in this call. + + Raises: + IllegalStateError: If called after previously calling + :meth:`~kafka.KafkaConsumer.assign`. + AssertionError: If neither topics or pattern is provided. + TypeError: If listener is not a ConsumerRebalanceListener. + """ + # SubscriptionState handles error checking + self._subscription.subscribe(topics=topics, + pattern=pattern, + listener=listener) + + # Regex will need all topic metadata + if pattern is not None: + self._client.cluster.need_all_topic_metadata = True + self._client.set_topics([]) + self._client.cluster.request_update() + log.debug("Subscribed to topic pattern: %s", pattern) + else: + self._client.cluster.need_all_topic_metadata = False + self._client.set_topics(self._subscription.group_subscription()) + log.debug("Subscribed to topic(s): %s", topics) + + def subscription(self): + """Get the current topic subscription. + + Returns: + set: {topic, ...} + """ + if self._subscription.subscription is None: + return None + return self._subscription.subscription.copy() + + def unsubscribe(self): + """Unsubscribe from all topics and clear all assigned partitions.""" + self._subscription.unsubscribe() + self._coordinator.close() + self._client.cluster.need_all_topic_metadata = False + self._client.set_topics([]) + log.debug("Unsubscribed all topics or patterns and assigned partitions") + if not self.config['legacy_iterator']: + self._iterator = None + + def metrics(self, raw=False): + """Get metrics on consumer performance. + + This is ported from the Java Consumer, for details see: + https://kafka.apache.org/documentation/#consumer_monitoring + + Warning: + This is an unstable interface. It may change in future + releases without warning. + """ + if raw: + return self._metrics.metrics.copy() + + metrics = {} + for k, v in six.iteritems(self._metrics.metrics.copy()): + if k.group not in metrics: + metrics[k.group] = {} + if k.name not in metrics[k.group]: + metrics[k.group][k.name] = {} + metrics[k.group][k.name] = v.value() + return metrics + + def offsets_for_times(self, timestamps): + """Look up the offsets for the given partitions by timestamp. 
The + returned offset for each partition is the earliest offset whose + timestamp is greater than or equal to the given timestamp in the + corresponding partition. + + This is a blocking call. The consumer does not have to be assigned the + partitions. + + If the message format version in a partition is before 0.10.0, i.e. + the messages do not have timestamps, ``None`` will be returned for that + partition. ``None`` will also be returned for the partition if there + are no messages in it. + + Note: + This method may block indefinitely if the partition does not exist. + + Arguments: + timestamps (dict): ``{TopicPartition: int}`` mapping from partition + to the timestamp to look up. Unit should be milliseconds since + beginning of the epoch (midnight Jan 1, 1970 (UTC)) + + Returns: + ``{TopicPartition: OffsetAndTimestamp}``: mapping from partition + to the timestamp and offset of the first message with timestamp + greater than or equal to the target timestamp. + + Raises: + ValueError: If the target timestamp is negative + UnsupportedVersionError: If the broker does not support looking + up the offsets by timestamp. + KafkaTimeoutError: If fetch failed in request_timeout_ms + """ + if self.config['api_version'] <= (0, 10, 0): + raise UnsupportedVersionError( + "offsets_for_times API not supported for cluster version {}" + .format(self.config['api_version'])) + for tp, ts in six.iteritems(timestamps): + timestamps[tp] = int(ts) + if ts < 0: + raise ValueError( + "The target time for partition {} is {}. The target time " + "cannot be negative.".format(tp, ts)) + return self._fetcher.get_offsets_by_times( + timestamps, self.config['request_timeout_ms']) + + def beginning_offsets(self, partitions): + """Get the first offset for the given partitions. + + This method does not change the current consumer position of the + partitions. + + Note: + This method may block indefinitely if the partition does not exist. + + Arguments: + partitions (list): List of TopicPartition instances to fetch + offsets for. + + Returns: + ``{TopicPartition: int}``: The earliest available offsets for the + given partitions. + + Raises: + UnsupportedVersionError: If the broker does not support looking + up the offsets by timestamp. + KafkaTimeoutError: If fetch failed in request_timeout_ms. + """ + offsets = self._fetcher.beginning_offsets( + partitions, self.config['request_timeout_ms']) + return offsets + + def end_offsets(self, partitions): + """Get the last offset for the given partitions. The last offset of a + partition is the offset of the upcoming message, i.e. the offset of the + last available message + 1. + + This method does not change the current consumer position of the + partitions. + + Note: + This method may block indefinitely if the partition does not exist. + + Arguments: + partitions (list): List of TopicPartition instances to fetch + offsets for. + + Returns: + ``{TopicPartition: int}``: The end offsets for the given partitions. + + Raises: + UnsupportedVersionError: If the broker does not support looking + up the offsets by timestamp. 
+ KafkaTimeoutError: If fetch failed in request_timeout_ms + """ + offsets = self._fetcher.end_offsets( + partitions, self.config['request_timeout_ms']) + return offsets + + def _use_consumer_group(self): + """Return True iff this consumer can/should join a broker-coordinated group.""" + if self.config['api_version'] < (0, 9): + return False + elif self.config['group_id'] is None: + return False + elif not self._subscription.partitions_auto_assigned(): + return False + return True + + def _update_fetch_positions(self, partitions): + """Set the fetch position to the committed position (if there is one) + or reset it using the offset reset policy the user has configured. + + Arguments: + partitions (List[TopicPartition]): The partitions that need + updating fetch positions. + + Raises: + NoOffsetForPartitionError: If no offset is stored for a given + partition and no offset reset policy is defined. + """ + # Lookup any positions for partitions which are awaiting reset (which may be the + # case if the user called :meth:`seek_to_beginning` or :meth:`seek_to_end`. We do + # this check first to avoid an unnecessary lookup of committed offsets (which + # typically occurs when the user is manually assigning partitions and managing + # their own offsets). + self._fetcher.reset_offsets_if_needed(partitions) + + if not self._subscription.has_all_fetch_positions(): + # if we still don't have offsets for all partitions, then we should either seek + # to the last committed position or reset using the auto reset policy + if (self.config['api_version'] >= (0, 8, 1) and + self.config['group_id'] is not None): + # first refresh commits for all assigned partitions + self._coordinator.refresh_committed_offsets_if_needed() + + # Then, do any offset lookups in case some positions are not known + self._fetcher.update_fetch_positions(partitions) + + def _message_generator_v2(self): + timeout_ms = 1000 * (self._consumer_timeout - time.time()) + record_map = self.poll(timeout_ms=timeout_ms, update_offsets=False) + for tp, records in six.iteritems(record_map): + # Generators are stateful, and it is possible that the tp / records + # here may become stale during iteration -- i.e., we seek to a + # different offset, pause consumption, or lose assignment. 
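A rough sketch of computing consumer lag from the offset APIs documented above; both position() and end_offsets() refer to the next offset, so their difference is the number of messages not yet returned to this consumer. The broker address and topic are hypothetical:

    from kafka import KafkaConsumer, TopicPartition

    # Hypothetical lag check for a single manually assigned partition.
    consumer = KafkaConsumer(bootstrap_servers='localhost:9092',
                             group_id='my-group')
    tp = TopicPartition('my-topic', 0)
    consumer.assign([tp])
    end = consumer.end_offsets([tp])[tp]   # offset of the next message to be produced
    lag = end - consumer.position(tp)      # messages this consumer has not yet seen
    print('lag for', tp, 'is', lag)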
+ for record in records: + # is_fetchable(tp) should handle assignment changes and offset + # resets; for all other changes (e.g., seeks) we'll rely on the + # outer function destroying the existing iterator/generator + # via self._iterator = None + if not self._subscription.is_fetchable(tp): + log.debug("Not returning fetched records for partition %s" + " since it is no longer fetchable", tp) + break + self._subscription.assignment[tp].position = record.offset + 1 + yield record + + def _message_generator(self): + assert self.assignment() or self.subscription() is not None, 'No topic subscription or manual partition assignment' + while time.time() < self._consumer_timeout: + + self._coordinator.poll() + + # Fetch offsets for any subscribed partitions that we arent tracking yet + if not self._subscription.has_all_fetch_positions(): + partitions = self._subscription.missing_fetch_positions() + self._update_fetch_positions(partitions) + + poll_ms = min((1000 * (self._consumer_timeout - time.time())), self.config['retry_backoff_ms']) + self._client.poll(timeout_ms=poll_ms) + + # after the long poll, we should check whether the group needs to rebalance + # prior to returning data so that the group can stabilize faster + if self._coordinator.need_rejoin(): + continue + + # We need to make sure we at least keep up with scheduled tasks, + # like heartbeats, auto-commits, and metadata refreshes + timeout_at = self._next_timeout() + + # Short-circuit the fetch iterator if we are already timed out + # to avoid any unintentional interaction with fetcher setup + if time.time() > timeout_at: + continue + + for msg in self._fetcher: + yield msg + if time.time() > timeout_at: + log.debug("internal iterator timeout - breaking for poll") + break + self._client.poll(timeout_ms=0) + + # An else block on a for loop only executes if there was no break + # so this should only be called on a StopIteration from the fetcher + # We assume that it is safe to init_fetches when fetcher is done + # i.e., there are no more records stored internally + else: + self._fetcher.send_fetches() + + def _next_timeout(self): + timeout = min(self._consumer_timeout, + self._client.cluster.ttl() / 1000.0 + time.time(), + self._coordinator.time_to_next_poll() + time.time()) + return timeout + + def __iter__(self): # pylint: disable=non-iterator-returned + return self + + def __next__(self): + if self._closed: + raise StopIteration('KafkaConsumer closed') + # Now that the heartbeat thread runs in the background + # there should be no reason to maintain a separate iterator + # but we'll keep it available for a few releases just in case + if self.config['legacy_iterator']: + return self.next_v1() + else: + return self.next_v2() + + def next_v2(self): + self._set_consumer_timeout() + while time.time() < self._consumer_timeout: + if not self._iterator: + self._iterator = self._message_generator_v2() + try: + return next(self._iterator) + except StopIteration: + self._iterator = None + raise StopIteration() + + def next_v1(self): + if not self._iterator: + self._iterator = self._message_generator() + + self._set_consumer_timeout() + try: + return next(self._iterator) + except StopIteration: + self._iterator = None + raise + + def _set_consumer_timeout(self): + # consumer_timeout_ms can be used to stop iteration early + if self.config['consumer_timeout_ms'] >= 0: + self._consumer_timeout = time.time() + ( + self.config['consumer_timeout_ms'] / 1000.0) diff --git a/consumer/subscription_state.py b/consumer/subscription_state.py new file mode 
100644 index 00000000..08842d13 --- /dev/null +++ b/consumer/subscription_state.py @@ -0,0 +1,501 @@ +from __future__ import absolute_import + +import abc +import logging +import re + +from kafka.vendor import six + +from kafka.errors import IllegalStateError +from kafka.protocol.offset import OffsetResetStrategy +from kafka.structs import OffsetAndMetadata + +log = logging.getLogger(__name__) + + +class SubscriptionState(object): + """ + A class for tracking the topics, partitions, and offsets for the consumer. + A partition is "assigned" either directly with assign_from_user() (manual + assignment) or with assign_from_subscribed() (automatic assignment from + subscription). + + Once assigned, the partition is not considered "fetchable" until its initial + position has been set with seek(). Fetchable partitions track a fetch + position which is used to set the offset of the next fetch, and a consumed + position which is the last offset that has been returned to the user. You + can suspend fetching from a partition through pause() without affecting the + fetched/consumed offsets. The partition will remain unfetchable until the + resume() is used. You can also query the pause state independently with + is_paused(). + + Note that pause state as well as fetch/consumed positions are not preserved + when partition assignment is changed whether directly by the user or + through a group rebalance. + + This class also maintains a cache of the latest commit position for each of + the assigned partitions. This is updated through committed() and can be used + to set the initial fetch position (e.g. Fetcher._reset_offset() ). + """ + _SUBSCRIPTION_EXCEPTION_MESSAGE = ( + "You must choose only one way to configure your consumer:" + " (1) subscribe to specific topics by name," + " (2) subscribe to topics matching a regex pattern," + " (3) assign itself specific topic-partitions.") + + # Taken from: https://github.com/apache/kafka/blob/39eb31feaeebfb184d98cc5d94da9148c2319d81/clients/src/main/java/org/apache/kafka/common/internals/Topic.java#L29 + _MAX_NAME_LENGTH = 249 + _TOPIC_LEGAL_CHARS = re.compile('^[a-zA-Z0-9._-]+$') + + def __init__(self, offset_reset_strategy='earliest'): + """Initialize a SubscriptionState instance + + Keyword Arguments: + offset_reset_strategy: 'earliest' or 'latest', otherwise + exception will be raised when fetching an offset that is no + longer available. Default: 'earliest' + """ + try: + offset_reset_strategy = getattr(OffsetResetStrategy, + offset_reset_strategy.upper()) + except AttributeError: + log.warning('Unrecognized offset_reset_strategy, using NONE') + offset_reset_strategy = OffsetResetStrategy.NONE + self._default_offset_reset_strategy = offset_reset_strategy + + self.subscription = None # set() or None + self.subscribed_pattern = None # regex str or None + self._group_subscription = set() + self._user_assignment = set() + self.assignment = dict() + self.listener = None + + # initialize to true for the consumers to fetch offset upon starting up + self.needs_fetch_committed_offsets = True + + def subscribe(self, topics=(), pattern=None, listener=None): + """Subscribe to a list of topics, or a topic regex pattern. + + Partitions will be dynamically assigned via a group coordinator. + Topic subscriptions are not incremental: this list will replace the + current assignment (if there is one). + + This method is incompatible with assign_from_user() + + Arguments: + topics (list): List of topics for subscription. + pattern (str): Pattern to match available topics. 
You must provide + either topics or pattern, but not both. + listener (ConsumerRebalanceListener): Optionally include listener + callback, which will be called before and after each rebalance + operation. + + As part of group management, the consumer will keep track of the + list of consumers that belong to a particular group and will + trigger a rebalance operation if one of the following events + trigger: + + * Number of partitions change for any of the subscribed topics + * Topic is created or deleted + * An existing member of the consumer group dies + * A new member is added to the consumer group + + When any of these events are triggered, the provided listener + will be invoked first to indicate that the consumer's assignment + has been revoked, and then again when the new assignment has + been received. Note that this listener will immediately override + any listener set in a previous call to subscribe. It is + guaranteed, however, that the partitions revoked/assigned + through this interface are from topics subscribed in this call. + """ + if self._user_assignment or (topics and pattern): + raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) + assert topics or pattern, 'Must provide topics or pattern' + + if pattern: + log.info('Subscribing to pattern: /%s/', pattern) + self.subscription = set() + self.subscribed_pattern = re.compile(pattern) + else: + self.change_subscription(topics) + + if listener and not isinstance(listener, ConsumerRebalanceListener): + raise TypeError('listener must be a ConsumerRebalanceListener') + self.listener = listener + + def _ensure_valid_topic_name(self, topic): + """ Ensures that the topic name is valid according to the kafka source. """ + + # See Kafka Source: + # https://github.com/apache/kafka/blob/39eb31feaeebfb184d98cc5d94da9148c2319d81/clients/src/main/java/org/apache/kafka/common/internals/Topic.java + if topic is None: + raise TypeError('All topics must not be None') + if not isinstance(topic, six.string_types): + raise TypeError('All topics must be strings') + if len(topic) == 0: + raise ValueError('All topics must be non-empty strings') + if topic == '.' or topic == '..': + raise ValueError('Topic name cannot be "." or ".."') + if len(topic) > self._MAX_NAME_LENGTH: + raise ValueError('Topic name is illegal, it can\'t be longer than {0} characters, topic: "{1}"'.format(self._MAX_NAME_LENGTH, topic)) + if not self._TOPIC_LEGAL_CHARS.match(topic): + raise ValueError('Topic name "{0}" is illegal, it contains a character other than ASCII alphanumerics, ".", "_" and "-"'.format(topic)) + + def change_subscription(self, topics): + """Change the topic subscription. + + Arguments: + topics (list of str): topics for subscription + + Raises: + IllegalStateError: if assign_from_user has been used already + TypeError: if a topic is None or a non-str + ValueError: if a topic is an empty string or + - a topic name is '.' or '..' or + - a topic name does not consist of ASCII-characters/'-'/'_'/'.' 
+ """ + if self._user_assignment: + raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) + + if isinstance(topics, six.string_types): + topics = [topics] + + if self.subscription == set(topics): + log.warning("subscription unchanged by change_subscription(%s)", + topics) + return + + for t in topics: + self._ensure_valid_topic_name(t) + + log.info('Updating subscribed topics to: %s', topics) + self.subscription = set(topics) + self._group_subscription.update(topics) + + # Remove any assigned partitions which are no longer subscribed to + for tp in set(self.assignment.keys()): + if tp.topic not in self.subscription: + del self.assignment[tp] + + def group_subscribe(self, topics): + """Add topics to the current group subscription. + + This is used by the group leader to ensure that it receives metadata + updates for all topics that any member of the group is subscribed to. + + Arguments: + topics (list of str): topics to add to the group subscription + """ + if self._user_assignment: + raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) + self._group_subscription.update(topics) + + def reset_group_subscription(self): + """Reset the group's subscription to only contain topics subscribed by this consumer.""" + if self._user_assignment: + raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) + assert self.subscription is not None, 'Subscription required' + self._group_subscription.intersection_update(self.subscription) + + def assign_from_user(self, partitions): + """Manually assign a list of TopicPartitions to this consumer. + + This interface does not allow for incremental assignment and will + replace the previous assignment (if there was one). + + Manual topic assignment through this method does not use the consumer's + group management functionality. As such, there will be no rebalance + operation triggered when group membership or cluster and topic metadata + change. Note that it is not possible to use both manual partition + assignment with assign() and group assignment with subscribe(). + + Arguments: + partitions (list of TopicPartition): assignment for this instance. + + Raises: + IllegalStateError: if consumer has already called subscribe() + """ + if self.subscription is not None: + raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) + + if self._user_assignment != set(partitions): + self._user_assignment = set(partitions) + + for partition in partitions: + if partition not in self.assignment: + self._add_assigned_partition(partition) + + for tp in set(self.assignment.keys()) - self._user_assignment: + del self.assignment[tp] + + self.needs_fetch_committed_offsets = True + + def assign_from_subscribed(self, assignments): + """Update the assignment to the specified partitions + + This method is called by the coordinator to dynamically assign + partitions based on the consumer's topic subscription. This is different + from assign_from_user() which directly sets the assignment from a + user-supplied TopicPartition list. + + Arguments: + assignments (list of TopicPartition): partitions to assign to this + consumer instance. + """ + if not self.partitions_auto_assigned(): + raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) + + for tp in assignments: + if tp.topic not in self.subscription: + raise ValueError("Assigned partition %s for non-subscribed topic." 
% (tp,)) + + # after rebalancing, we always reinitialize the assignment state + self.assignment.clear() + for tp in assignments: + self._add_assigned_partition(tp) + self.needs_fetch_committed_offsets = True + log.info("Updated partition assignment: %s", assignments) + + def unsubscribe(self): + """Clear all topic subscriptions and partition assignments""" + self.subscription = None + self._user_assignment.clear() + self.assignment.clear() + self.subscribed_pattern = None + + def group_subscription(self): + """Get the topic subscription for the group. + + For the leader, this will include the union of all member subscriptions. + For followers, it is the member's subscription only. + + This is used when querying topic metadata to detect metadata changes + that would require rebalancing (the leader fetches metadata for all + topics in the group so that it can do partition assignment). + + Returns: + set: topics + """ + return self._group_subscription + + def seek(self, partition, offset): + """Manually specify the fetch offset for a TopicPartition. + + Overrides the fetch offsets that the consumer will use on the next + poll(). If this API is invoked for the same partition more than once, + the latest offset will be used on the next poll(). Note that you may + lose data if this API is arbitrarily used in the middle of consumption, + to reset the fetch offsets. + + Arguments: + partition (TopicPartition): partition for seek operation + offset (int): message offset in partition + """ + self.assignment[partition].seek(offset) + + def assigned_partitions(self): + """Return set of TopicPartitions in current assignment.""" + return set(self.assignment.keys()) + + def paused_partitions(self): + """Return current set of paused TopicPartitions.""" + return set(partition for partition in self.assignment + if self.is_paused(partition)) + + def fetchable_partitions(self): + """Return set of TopicPartitions that should be Fetched.""" + fetchable = set() + for partition, state in six.iteritems(self.assignment): + if state.is_fetchable(): + fetchable.add(partition) + return fetchable + + def partitions_auto_assigned(self): + """Return True unless user supplied partitions manually.""" + return self.subscription is not None + + def all_consumed_offsets(self): + """Returns consumed offsets as {TopicPartition: OffsetAndMetadata}""" + all_consumed = {} + for partition, state in six.iteritems(self.assignment): + if state.has_valid_position: + all_consumed[partition] = OffsetAndMetadata(state.position, '') + return all_consumed + + def need_offset_reset(self, partition, offset_reset_strategy=None): + """Mark partition for offset reset using specified or default strategy. 
+ + Arguments: + partition (TopicPartition): partition to mark + offset_reset_strategy (OffsetResetStrategy, optional) + """ + if offset_reset_strategy is None: + offset_reset_strategy = self._default_offset_reset_strategy + self.assignment[partition].await_reset(offset_reset_strategy) + + def has_default_offset_reset_policy(self): + """Return True if default offset reset policy is Earliest or Latest""" + return self._default_offset_reset_strategy != OffsetResetStrategy.NONE + + def is_offset_reset_needed(self, partition): + return self.assignment[partition].awaiting_reset + + def has_all_fetch_positions(self): + for state in self.assignment.values(): + if not state.has_valid_position: + return False + return True + + def missing_fetch_positions(self): + missing = set() + for partition, state in six.iteritems(self.assignment): + if not state.has_valid_position: + missing.add(partition) + return missing + + def is_assigned(self, partition): + return partition in self.assignment + + def is_paused(self, partition): + return partition in self.assignment and self.assignment[partition].paused + + def is_fetchable(self, partition): + return partition in self.assignment and self.assignment[partition].is_fetchable() + + def pause(self, partition): + self.assignment[partition].pause() + + def resume(self, partition): + self.assignment[partition].resume() + + def _add_assigned_partition(self, partition): + self.assignment[partition] = TopicPartitionState() + + +class TopicPartitionState(object): + def __init__(self): + self.committed = None # last committed OffsetAndMetadata + self.has_valid_position = False # whether we have valid position + self.paused = False # whether this partition has been paused by the user + self.awaiting_reset = False # whether we are awaiting reset + self.reset_strategy = None # the reset strategy if awaitingReset is set + self._position = None # offset exposed to the user + self.highwater = None + self.drop_pending_message_set = False + # The last message offset hint available from a message batch with + # magic=2 which includes deleted compacted messages + self.last_offset_from_message_batch = None + + def _set_position(self, offset): + assert self.has_valid_position, 'Valid position required' + self._position = offset + + def _get_position(self): + return self._position + + position = property(_get_position, _set_position, None, "last position") + + def await_reset(self, strategy): + self.awaiting_reset = True + self.reset_strategy = strategy + self._position = None + self.last_offset_from_message_batch = None + self.has_valid_position = False + + def seek(self, offset): + self._position = offset + self.awaiting_reset = False + self.reset_strategy = None + self.has_valid_position = True + self.drop_pending_message_set = True + self.last_offset_from_message_batch = None + + def pause(self): + self.paused = True + + def resume(self): + self.paused = False + + def is_fetchable(self): + return not self.paused and self.has_valid_position + + +class ConsumerRebalanceListener(object): + """ + A callback interface that the user can implement to trigger custom actions + when the set of partitions assigned to the consumer changes. + + This is applicable when the consumer is having Kafka auto-manage group + membership. If the consumer's directly assign partitions, those + partitions will never be reassigned and this callback is not applicable. 
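A minimal sketch of subclassing the listener defined here for one common use described in the paragraphs that follow, committing offsets when partitions are revoked; the topic, group name, and wiring are hypothetical:

    from kafka import KafkaConsumer
    from kafka.consumer.subscription_state import ConsumerRebalanceListener

    class CommitOnRevoke(ConsumerRebalanceListener):
        """Commit current offsets just before partitions are taken away."""

        def __init__(self, consumer):
            self._consumer = consumer

        def on_partitions_revoked(self, revoked):
            # Runs before the rebalance completes, so the next owner of these
            # partitions starts from the last position committed here.
            self._consumer.commit()

        def on_partitions_assigned(self, assigned):
            pass

    # Hypothetical wiring: the listener is passed when subscribing.
    consumer = KafkaConsumer(bootstrap_servers='localhost:9092',
                             group_id='my-group',
                             enable_auto_commit=False)
    consumer.subscribe(['my-topic'], listener=CommitOnRevoke(consumer))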
+ + When Kafka is managing the group membership, a partition re-assignment will + be triggered any time the members of the group changes or the subscription + of the members changes. This can occur when processes die, new process + instances are added or old instances come back to life after failure. + Rebalances can also be triggered by changes affecting the subscribed + topics (e.g. when then number of partitions is administratively adjusted). + + There are many uses for this functionality. One common use is saving offsets + in a custom store. By saving offsets in the on_partitions_revoked(), call we + can ensure that any time partition assignment changes the offset gets saved. + + Another use is flushing out any kind of cache of intermediate results the + consumer may be keeping. For example, consider a case where the consumer is + subscribed to a topic containing user page views, and the goal is to count + the number of page views per users for each five minute window. Let's say + the topic is partitioned by the user id so that all events for a particular + user will go to a single consumer instance. The consumer can keep in memory + a running tally of actions per user and only flush these out to a remote + data store when its cache gets too big. However if a partition is reassigned + it may want to automatically trigger a flush of this cache, before the new + owner takes over consumption. + + This callback will execute in the user thread as part of the Consumer.poll() + whenever partition assignment changes. + + It is guaranteed that all consumer processes will invoke + on_partitions_revoked() prior to any process invoking + on_partitions_assigned(). So if offsets or other state is saved in the + on_partitions_revoked() call, it should be saved by the time the process + taking over that partition has their on_partitions_assigned() callback + called to load the state. + """ + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def on_partitions_revoked(self, revoked): + """ + A callback method the user can implement to provide handling of offset + commits to a customized store on the start of a rebalance operation. + This method will be called before a rebalance operation starts and + after the consumer stops fetching data. It is recommended that offsets + should be committed in this callback to either Kafka or a custom offset + store to prevent duplicate data. + + NOTE: This method is only called before rebalances. It is not called + prior to KafkaConsumer.close() + + Arguments: + revoked (list of TopicPartition): the partitions that were assigned + to the consumer on the last rebalance + """ + pass + + @abc.abstractmethod + def on_partitions_assigned(self, assigned): + """ + A callback method the user can implement to provide handling of + customized offsets on completion of a successful partition + re-assignment. This method will be called after an offset re-assignment + completes and before the consumer starts fetching data. + + It is guaranteed that all the processes in a consumer group will execute + their on_partitions_revoked() callback before any instance executes its + on_partitions_assigned() callback. 
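The offset-saving pattern described in the docstring above can be sketched as a small listener subclass (editor's illustration, not part of the patch). The store object is a hypothetical external offset store, the consumer's position()/seek() calls assume a kafka-python style KafkaConsumer, and the import path assumes the subtree is importable as the `kafka` package:

from kafka.consumer.subscription_state import ConsumerRebalanceListener

class SaveOffsetsOnRebalance(ConsumerRebalanceListener):
    def __init__(self, consumer, store):
        self._consumer = consumer
        self._store = store   # hypothetical external offset store

    def on_partitions_revoked(self, revoked):
        # flush positions before another member takes over these partitions
        for tp in revoked:
            self._store.save(tp, self._consumer.position(tp))

    def on_partitions_assigned(self, assigned):
        # resume newly assigned partitions from the externally stored offsets
        for tp in assigned:
            offset = self._store.load(tp)
            if offset is not None:
                self._consumer.seek(tp, offset)

Because every member runs on_partitions_revoked() before any member runs on_partitions_assigned() (as guaranteed above), the store is up to date by the time the new owner seeks.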
+ + Arguments: + assigned (list of TopicPartition): the partitions assigned to the + consumer (may include partitions that were previously assigned) + """ + pass diff --git a/coordinator/__init__.py b/coordinator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/coordinator/assignors/__init__.py b/coordinator/assignors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/coordinator/assignors/abstract.py b/coordinator/assignors/abstract.py new file mode 100644 index 00000000..a1fef384 --- /dev/null +++ b/coordinator/assignors/abstract.py @@ -0,0 +1,56 @@ +from __future__ import absolute_import + +import abc +import logging + +log = logging.getLogger(__name__) + + +class AbstractPartitionAssignor(object): + """ + Abstract assignor implementation which does some common grunt work (in particular collecting + partition counts which are always needed in assignors). + """ + + @abc.abstractproperty + def name(self): + """.name should be a string identifying the assignor""" + pass + + @abc.abstractmethod + def assign(self, cluster, members): + """Perform group assignment given cluster metadata and member subscriptions + + Arguments: + cluster (ClusterMetadata): metadata for use in assignment + members (dict of {member_id: MemberMetadata}): decoded metadata for + each member in the group. + + Returns: + dict: {member_id: MemberAssignment} + """ + pass + + @abc.abstractmethod + def metadata(self, topics): + """Generate ProtocolMetadata to be submitted via JoinGroupRequest. + + Arguments: + topics (set): a member's subscribed topics + + Returns: + MemberMetadata struct + """ + pass + + @abc.abstractmethod + def on_assignment(self, assignment): + """Callback that runs on each assignment. + + This method can be used to update internal state, if any, of the + partition assignor. + + Arguments: + assignment (MemberAssignment): the member's assignment + """ + pass diff --git a/coordinator/assignors/range.py b/coordinator/assignors/range.py new file mode 100644 index 00000000..299e39c4 --- /dev/null +++ b/coordinator/assignors/range.py @@ -0,0 +1,77 @@ +from __future__ import absolute_import + +import collections +import logging + +from kafka.vendor import six + +from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor +from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment + +log = logging.getLogger(__name__) + + +class RangePartitionAssignor(AbstractPartitionAssignor): + """ + The range assignor works on a per-topic basis. For each topic, we lay out + the available partitions in numeric order and the consumers in + lexicographic order. We then divide the number of partitions by the total + number of consumers to determine the number of partitions to assign to each + consumer. If it does not evenly divide, then the first few consumers will + have one extra partition. + + For example, suppose there are two consumers C0 and C1, two topics t0 and + t1, and each topic has 3 partitions, resulting in partitions t0p0, t0p1, + t0p2, t1p0, t1p1, and t1p2. 
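To make the AbstractPartitionAssignor contract above concrete, here is a deliberately trivial assignor that hands every partition of every subscribed topic to the lexicographically first member (editor's sketch only, not part of the patch; the import paths assume the subtree is importable as the `kafka` package):

from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor
from kafka.coordinator.protocol import (
    ConsumerProtocolMemberAssignment, ConsumerProtocolMemberMetadata)


class FirstMemberAssignor(AbstractPartitionAssignor):
    name = 'first-member'   # hypothetical protocol name
    version = 0

    @classmethod
    def assign(cls, cluster, members):
        first = min(members)   # lexicographically first member id
        topics = set()
        for metadata in members.values():
            topics.update(metadata.subscription)
        full = {topic: sorted(cluster.partitions_for_topic(topic) or [])
                for topic in topics}
        return {
            member: ConsumerProtocolMemberAssignment(
                cls.version,
                sorted(full.items()) if member == first else [],
                b'')
            for member in members
        }

    @classmethod
    def metadata(cls, topics):
        return ConsumerProtocolMemberMetadata(cls.version, list(topics), b'')

    @classmethod
    def on_assignment(cls, assignment):
        pass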
+ + The assignment will be: + C0: [t0p0, t0p1, t1p0, t1p1] + C1: [t0p2, t1p2] + """ + name = 'range' + version = 0 + + @classmethod + def assign(cls, cluster, member_metadata): + consumers_per_topic = collections.defaultdict(list) + for member, metadata in six.iteritems(member_metadata): + for topic in metadata.subscription: + consumers_per_topic[topic].append(member) + + # construct {member_id: {topic: [partition, ...]}} + assignment = collections.defaultdict(dict) + + for topic, consumers_for_topic in six.iteritems(consumers_per_topic): + partitions = cluster.partitions_for_topic(topic) + if partitions is None: + log.warning('No partition metadata for topic %s', topic) + continue + partitions = sorted(partitions) + consumers_for_topic.sort() + + partitions_per_consumer = len(partitions) // len(consumers_for_topic) + consumers_with_extra = len(partitions) % len(consumers_for_topic) + + for i, member in enumerate(consumers_for_topic): + start = partitions_per_consumer * i + start += min(i, consumers_with_extra) + length = partitions_per_consumer + if not i + 1 > consumers_with_extra: + length += 1 + assignment[member][topic] = partitions[start:start+length] + + protocol_assignment = {} + for member_id in member_metadata: + protocol_assignment[member_id] = ConsumerProtocolMemberAssignment( + cls.version, + sorted(assignment[member_id].items()), + b'') + return protocol_assignment + + @classmethod + def metadata(cls, topics): + return ConsumerProtocolMemberMetadata(cls.version, list(topics), b'') + + @classmethod + def on_assignment(cls, assignment): + pass diff --git a/coordinator/assignors/roundrobin.py b/coordinator/assignors/roundrobin.py new file mode 100644 index 00000000..2d24a5c8 --- /dev/null +++ b/coordinator/assignors/roundrobin.py @@ -0,0 +1,96 @@ +from __future__ import absolute_import + +import collections +import itertools +import logging + +from kafka.vendor import six + +from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor +from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment +from kafka.structs import TopicPartition + +log = logging.getLogger(__name__) + + +class RoundRobinPartitionAssignor(AbstractPartitionAssignor): + """ + The roundrobin assignor lays out all the available partitions and all the + available consumers. It then proceeds to do a roundrobin assignment from + partition to consumer. If the subscriptions of all consumer instances are + identical, then the partitions will be uniformly distributed. (i.e., the + partition ownership counts will be within a delta of exactly one across all + consumers.) + + For example, suppose there are two consumers C0 and C1, two topics t0 and + t1, and each topic has 3 partitions, resulting in partitions t0p0, t0p1, + t0p2, t1p0, t1p1, and t1p2. + + The assignment will be: + C0: [t0p0, t0p2, t1p1] + C1: [t0p1, t1p0, t1p2] + + When subscriptions differ across consumer instances, the assignment process + still considers each consumer instance in round robin fashion but skips + over an instance if it is not subscribed to the topic. Unlike the case when + subscriptions are identical, this can result in imbalanced assignments. + + For example, suppose we have three consumers C0, C1, C2, and three topics + t0, t1, t2, with unbalanced partitions t0p0, t1p0, t1p1, t2p0, t2p1, t2p2, + where C0 is subscribed to t0; C1 is subscribed to t0, t1; and C2 is + subscribed to t0, t1, t2. 
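The splitting arithmetic in RangePartitionAssignor.assign() above (an even share per consumer, plus one extra partition for the first len(partitions) % len(consumers) members) can be checked with a standalone sketch; split_range is a hypothetical helper name used only for this illustration:

def split_range(num_partitions, consumers):
    consumers = sorted(consumers)
    per_consumer, extra = divmod(num_partitions, len(consumers))
    layout = {}
    for i, member in enumerate(consumers):
        start = per_consumer * i + min(i, extra)
        length = per_consumer + (1 if i < extra else 0)
        layout[member] = list(range(start, start + length))
    return layout

# Two consumers over a 3-partition topic: C0 takes [0, 1] and C1 takes [2],
# matching the t0/t1 example in the class docstring.
assert split_range(3, ['C0', 'C1']) == {'C0': [0, 1], 'C1': [2]}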
+ + The assignment will be: + C0: [t0p0] + C1: [t1p0] + C2: [t1p1, t2p0, t2p1, t2p2] + """ + name = 'roundrobin' + version = 0 + + @classmethod + def assign(cls, cluster, member_metadata): + all_topics = set() + for metadata in six.itervalues(member_metadata): + all_topics.update(metadata.subscription) + + all_topic_partitions = [] + for topic in all_topics: + partitions = cluster.partitions_for_topic(topic) + if partitions is None: + log.warning('No partition metadata for topic %s', topic) + continue + for partition in partitions: + all_topic_partitions.append(TopicPartition(topic, partition)) + all_topic_partitions.sort() + + # construct {member_id: {topic: [partition, ...]}} + assignment = collections.defaultdict(lambda: collections.defaultdict(list)) + + member_iter = itertools.cycle(sorted(member_metadata.keys())) + for partition in all_topic_partitions: + member_id = next(member_iter) + + # Because we constructed all_topic_partitions from the set of + # member subscribed topics, we should be safe assuming that + # each topic in all_topic_partitions is in at least one member + # subscription; otherwise this could yield an infinite loop + while partition.topic not in member_metadata[member_id].subscription: + member_id = next(member_iter) + assignment[member_id][partition.topic].append(partition.partition) + + protocol_assignment = {} + for member_id in member_metadata: + protocol_assignment[member_id] = ConsumerProtocolMemberAssignment( + cls.version, + sorted(assignment[member_id].items()), + b'') + return protocol_assignment + + @classmethod + def metadata(cls, topics): + return ConsumerProtocolMemberMetadata(cls.version, list(topics), b'') + + @classmethod + def on_assignment(cls, assignment): + pass diff --git a/coordinator/assignors/sticky/__init__.py b/coordinator/assignors/sticky/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/coordinator/assignors/sticky/partition_movements.py b/coordinator/assignors/sticky/partition_movements.py new file mode 100644 index 00000000..8851e4cd --- /dev/null +++ b/coordinator/assignors/sticky/partition_movements.py @@ -0,0 +1,149 @@ +import logging +from collections import defaultdict, namedtuple +from copy import deepcopy + +from kafka.vendor import six + +log = logging.getLogger(__name__) + + +ConsumerPair = namedtuple("ConsumerPair", ["src_member_id", "dst_member_id"]) +""" +Represents a pair of Kafka consumer ids involved in a partition reassignment. +Each ConsumerPair corresponds to a particular partition or topic, indicates that the particular partition or some +partition of the particular topic was moved from the source consumer to the destination consumer +during the rebalance. This class helps in determining whether a partition reassignment results in cycles among +the generated graph of consumer pairs. +""" + + +def is_sublist(source, target): + """Checks if one list is a sublist of another. + + Arguments: + source: the list in which to search for the occurrence of target. + target: the list to search for as a sublist of source + + Returns: + true if target is in source; false otherwise + """ + for index in (i for i, e in enumerate(source) if e == target[0]): + if tuple(source[index: index + len(target)]) == target: + return True + return False + + +class PartitionMovements: + """ + This class maintains some data structures to simplify lookup of partition movements among consumers. 
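The round robin walk in RoundRobinPartitionAssignor.assign() above can be sketched with nothing but the standard library: cycle over the sorted member ids and skip any member that is not subscribed to the current partition's topic. The subscriptions below reproduce the unbalanced example from the class docstring (editor's illustration only):

import itertools
from collections import defaultdict

subscriptions = {'C0': {'t0'}, 'C1': {'t0', 't1'}, 'C2': {'t0', 't1', 't2'}}
partitions = [('t0', 0), ('t1', 0), ('t1', 1), ('t2', 0), ('t2', 1), ('t2', 2)]

assignment = defaultdict(list)
members = itertools.cycle(sorted(subscriptions))
for topic, partition in sorted(partitions):
    member = next(members)
    # safe because every topic here comes from some member's subscription
    while topic not in subscriptions[member]:
        member = next(members)
    assignment[member].append((topic, partition))

# C0 -> [t0p0], C1 -> [t1p0], C2 -> [t1p1, t2p0, t2p1, t2p2], as documented.
assert assignment['C2'] == [('t1', 1), ('t2', 0), ('t2', 1), ('t2', 2)]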
+ At each point of time during a partition rebalance it keeps track of partition movements + corresponding to each topic, and also possible movement (in form a ConsumerPair object) for each partition. + """ + + def __init__(self): + self.partition_movements_by_topic = defaultdict( + lambda: defaultdict(set) + ) + self.partition_movements = {} + + def move_partition(self, partition, old_consumer, new_consumer): + pair = ConsumerPair(src_member_id=old_consumer, dst_member_id=new_consumer) + if partition in self.partition_movements: + # this partition has previously moved + existing_pair = self._remove_movement_record_of_partition(partition) + assert existing_pair.dst_member_id == old_consumer + if existing_pair.src_member_id != new_consumer: + # the partition is not moving back to its previous consumer + self._add_partition_movement_record( + partition, ConsumerPair(src_member_id=existing_pair.src_member_id, dst_member_id=new_consumer) + ) + else: + self._add_partition_movement_record(partition, pair) + + def get_partition_to_be_moved(self, partition, old_consumer, new_consumer): + if partition.topic not in self.partition_movements_by_topic: + return partition + if partition in self.partition_movements: + # this partition has previously moved + assert old_consumer == self.partition_movements[partition].dst_member_id + old_consumer = self.partition_movements[partition].src_member_id + reverse_pair = ConsumerPair(src_member_id=new_consumer, dst_member_id=old_consumer) + if reverse_pair not in self.partition_movements_by_topic[partition.topic]: + return partition + + return next(iter(self.partition_movements_by_topic[partition.topic][reverse_pair])) + + def are_sticky(self): + for topic, movements in six.iteritems(self.partition_movements_by_topic): + movement_pairs = set(movements.keys()) + if self._has_cycles(movement_pairs): + log.error( + "Stickiness is violated for topic {}\n" + "Partition movements for this topic occurred among the following consumer pairs:\n" + "{}".format(topic, movement_pairs) + ) + return False + return True + + def _remove_movement_record_of_partition(self, partition): + pair = self.partition_movements[partition] + del self.partition_movements[partition] + + self.partition_movements_by_topic[partition.topic][pair].remove(partition) + if not self.partition_movements_by_topic[partition.topic][pair]: + del self.partition_movements_by_topic[partition.topic][pair] + if not self.partition_movements_by_topic[partition.topic]: + del self.partition_movements_by_topic[partition.topic] + + return pair + + def _add_partition_movement_record(self, partition, pair): + self.partition_movements[partition] = pair + self.partition_movements_by_topic[partition.topic][pair].add(partition) + + def _has_cycles(self, consumer_pairs): + cycles = set() + for pair in consumer_pairs: + reduced_pairs = deepcopy(consumer_pairs) + reduced_pairs.remove(pair) + path = [pair.src_member_id] + if self._is_linked(pair.dst_member_id, pair.src_member_id, reduced_pairs, path) and not self._is_subcycle( + path, cycles + ): + cycles.add(tuple(path)) + log.error("A cycle of length {} was found: {}".format(len(path) - 1, path)) + + # for now we want to make sure there is no partition movements of the same topic between a pair of consumers. + # the odds of finding a cycle among more than two consumers seem to be very low (according to various randomized + # tests with the given sticky algorithm) that it should not worth the added complexity of handling those cases. 
+ for cycle in cycles: + if len(cycle) == 3: # indicates a cycle of length 2 + return True + return False + + @staticmethod + def _is_subcycle(cycle, cycles): + super_cycle = deepcopy(cycle) + super_cycle = super_cycle[:-1] + super_cycle.extend(cycle) + for found_cycle in cycles: + if len(found_cycle) == len(cycle) and is_sublist(super_cycle, found_cycle): + return True + return False + + def _is_linked(self, src, dst, pairs, current_path): + if src == dst: + return False + if not pairs: + return False + if ConsumerPair(src, dst) in pairs: + current_path.append(src) + current_path.append(dst) + return True + for pair in pairs: + if pair.src_member_id == src: + reduced_set = deepcopy(pairs) + reduced_set.remove(pair) + current_path.append(pair.src_member_id) + return self._is_linked(pair.dst_member_id, dst, reduced_set, current_path) + return False diff --git a/coordinator/assignors/sticky/sorted_set.py b/coordinator/assignors/sticky/sorted_set.py new file mode 100644 index 00000000..6a454a42 --- /dev/null +++ b/coordinator/assignors/sticky/sorted_set.py @@ -0,0 +1,63 @@ +class SortedSet: + def __init__(self, iterable=None, key=None): + self._key = key if key is not None else lambda x: x + self._set = set(iterable) if iterable is not None else set() + + self._cached_last = None + self._cached_first = None + + def first(self): + if self._cached_first is not None: + return self._cached_first + + first = None + for element in self._set: + if first is None or self._key(first) > self._key(element): + first = element + self._cached_first = first + return first + + def last(self): + if self._cached_last is not None: + return self._cached_last + + last = None + for element in self._set: + if last is None or self._key(last) < self._key(element): + last = element + self._cached_last = last + return last + + def pop_last(self): + value = self.last() + self._set.remove(value) + self._cached_last = None + return value + + def add(self, value): + if self._cached_last is not None and self._key(value) > self._key(self._cached_last): + self._cached_last = value + if self._cached_first is not None and self._key(value) < self._key(self._cached_first): + self._cached_first = value + + return self._set.add(value) + + def remove(self, value): + if self._cached_last is not None and self._cached_last == value: + self._cached_last = None + if self._cached_first is not None and self._cached_first == value: + self._cached_first = None + + return self._set.remove(value) + + def __contains__(self, value): + return value in self._set + + def __iter__(self): + return iter(sorted(self._set, key=self._key)) + + def _bool(self): + return len(self._set) != 0 + + __nonzero__ = _bool + __bool__ = _bool diff --git a/coordinator/assignors/sticky/sticky_assignor.py b/coordinator/assignors/sticky/sticky_assignor.py new file mode 100644 index 00000000..dce714f1 --- /dev/null +++ b/coordinator/assignors/sticky/sticky_assignor.py @@ -0,0 +1,685 @@ +import logging +from collections import defaultdict, namedtuple +from copy import deepcopy + +from kafka.cluster import ClusterMetadata +from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor +from kafka.coordinator.assignors.sticky.partition_movements import PartitionMovements +from kafka.coordinator.assignors.sticky.sorted_set import SortedSet +from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment +from kafka.coordinator.protocol import Schema +from kafka.protocol.struct import Struct +from kafka.protocol.types import 
String, Array, Int32 +from kafka.structs import TopicPartition +from kafka.vendor import six + +log = logging.getLogger(__name__) + +ConsumerGenerationPair = namedtuple("ConsumerGenerationPair", ["consumer", "generation"]) + + +def has_identical_list_elements(list_): + """Checks if all lists in the collection have the same members + + Arguments: + list_: collection of lists + + Returns: + true if all lists in the collection have the same members; false otherwise + """ + if not list_: + return True + for i in range(1, len(list_)): + if list_[i] != list_[i - 1]: + return False + return True + + +def subscriptions_comparator_key(element): + return len(element[1]), element[0] + + +def partitions_comparator_key(element): + return len(element[1]), element[0].topic, element[0].partition + + +def remove_if_present(collection, element): + try: + collection.remove(element) + except (ValueError, KeyError): + pass + + +StickyAssignorMemberMetadataV1 = namedtuple("StickyAssignorMemberMetadataV1", + ["subscription", "partitions", "generation"]) + + +class StickyAssignorUserDataV1(Struct): + """ + Used for preserving consumer's previously assigned partitions + list and sending it as user data to the leader during a rebalance + """ + + SCHEMA = Schema( + ("previous_assignment", Array(("topic", String("utf-8")), ("partitions", Array(Int32)))), ("generation", Int32) + ) + + +class StickyAssignmentExecutor: + def __init__(self, cluster, members): + self.members = members + # a mapping between consumers and their assigned partitions that is updated during assignment procedure + self.current_assignment = defaultdict(list) + # an assignment from a previous generation + self.previous_assignment = {} + # a mapping between partitions and their assigned consumers + self.current_partition_consumer = {} + # a flag indicating that there were no previous assignments performed ever + self.is_fresh_assignment = False + # a mapping of all topic partitions to all consumers that can be assigned to them + self.partition_to_all_potential_consumers = {} + # a mapping of all consumers to all potential topic partitions that can be assigned to them + self.consumer_to_all_potential_partitions = {} + # an ascending sorted set of consumers based on how many topic partitions are already assigned to them + self.sorted_current_subscriptions = SortedSet() + # an ascending sorted list of topic partitions based on how many consumers can potentially use them + self.sorted_partitions = [] + # all partitions that need to be assigned + self.unassigned_partitions = [] + # a flag indicating that a certain partition cannot remain assigned to its current consumer because the consumer + # is no longer subscribed to its topic + self.revocation_required = False + + self.partition_movements = PartitionMovements() + self._initialize(cluster) + + def perform_initial_assignment(self): + self._populate_sorted_partitions() + self._populate_partitions_to_reassign() + + def balance(self): + self._initialize_current_subscriptions() + initializing = len(self.current_assignment[self._get_consumer_with_most_subscriptions()]) == 0 + + # assign all unassigned partitions + for partition in self.unassigned_partitions: + # skip if there is no potential consumer for the partition + if not self.partition_to_all_potential_consumers[partition]: + continue + self._assign_partition(partition) + + # narrow down the reassignment scope to only those partitions that can actually be reassigned + fixed_partitions = set() + for partition in 
six.iterkeys(self.partition_to_all_potential_consumers): + if not self._can_partition_participate_in_reassignment(partition): + fixed_partitions.add(partition) + for fixed_partition in fixed_partitions: + remove_if_present(self.sorted_partitions, fixed_partition) + remove_if_present(self.unassigned_partitions, fixed_partition) + + # narrow down the reassignment scope to only those consumers that are subject to reassignment + fixed_assignments = {} + for consumer in six.iterkeys(self.consumer_to_all_potential_partitions): + if not self._can_consumer_participate_in_reassignment(consumer): + self._remove_consumer_from_current_subscriptions_and_maintain_order(consumer) + fixed_assignments[consumer] = self.current_assignment[consumer] + del self.current_assignment[consumer] + + # create a deep copy of the current assignment so we can revert to it + # if we do not get a more balanced assignment later + prebalance_assignment = deepcopy(self.current_assignment) + prebalance_partition_consumers = deepcopy(self.current_partition_consumer) + + # if we don't already need to revoke something due to subscription changes, + # first try to balance by only moving newly added partitions + if not self.revocation_required: + self._perform_reassignments(self.unassigned_partitions) + reassignment_performed = self._perform_reassignments(self.sorted_partitions) + + # if we are not preserving existing assignments and we have made changes to the current assignment + # make sure we are getting a more balanced assignment; otherwise, revert to previous assignment + if ( + not initializing + and reassignment_performed + and self._get_balance_score(self.current_assignment) >= self._get_balance_score(prebalance_assignment) + ): + self.current_assignment = prebalance_assignment + self.current_partition_consumer.clear() + self.current_partition_consumer.update(prebalance_partition_consumers) + + # add the fixed assignments (those that could not change) back + for consumer, partitions in six.iteritems(fixed_assignments): + self.current_assignment[consumer] = partitions + self._add_consumer_to_current_subscriptions_and_maintain_order(consumer) + + def get_final_assignment(self, member_id): + assignment = defaultdict(list) + for topic_partition in self.current_assignment[member_id]: + assignment[topic_partition.topic].append(topic_partition.partition) + assignment = {k: sorted(v) for k, v in six.iteritems(assignment)} + return six.viewitems(assignment) + + def _initialize(self, cluster): + self._init_current_assignments(self.members) + + for topic in cluster.topics(): + partitions = cluster.partitions_for_topic(topic) + if partitions is None: + log.warning("No partition metadata for topic %s", topic) + continue + for p in partitions: + partition = TopicPartition(topic=topic, partition=p) + self.partition_to_all_potential_consumers[partition] = [] + for consumer_id, member_metadata in six.iteritems(self.members): + self.consumer_to_all_potential_partitions[consumer_id] = [] + for topic in member_metadata.subscription: + if cluster.partitions_for_topic(topic) is None: + log.warning("No partition metadata for topic {}".format(topic)) + continue + for p in cluster.partitions_for_topic(topic): + partition = TopicPartition(topic=topic, partition=p) + self.consumer_to_all_potential_partitions[consumer_id].append(partition) + self.partition_to_all_potential_consumers[partition].append(consumer_id) + if consumer_id not in self.current_assignment: + self.current_assignment[consumer_id] = [] + + def _init_current_assignments(self, 
members): + # we need to process subscriptions' user data with each consumer's reported generation in mind + # higher generations overwrite lower generations in case of a conflict + # note that a conflict could exists only if user data is for different generations + + # for each partition we create a map of its consumers by generation + sorted_partition_consumers_by_generation = {} + for consumer, member_metadata in six.iteritems(members): + for partitions in member_metadata.partitions: + if partitions in sorted_partition_consumers_by_generation: + consumers = sorted_partition_consumers_by_generation[partitions] + if member_metadata.generation and member_metadata.generation in consumers: + # same partition is assigned to two consumers during the same rebalance. + # log a warning and skip this record + log.warning( + "Partition {} is assigned to multiple consumers " + "following sticky assignment generation {}.".format(partitions, member_metadata.generation) + ) + else: + consumers[member_metadata.generation] = consumer + else: + sorted_consumers = {member_metadata.generation: consumer} + sorted_partition_consumers_by_generation[partitions] = sorted_consumers + + # previous_assignment holds the prior ConsumerGenerationPair (before current) of each partition + # current and previous consumers are the last two consumers of each partition in the above sorted map + for partitions, consumers in six.iteritems(sorted_partition_consumers_by_generation): + generations = sorted(consumers.keys(), reverse=True) + self.current_assignment[consumers[generations[0]]].append(partitions) + # now update previous assignment if any + if len(generations) > 1: + self.previous_assignment[partitions] = ConsumerGenerationPair( + consumer=consumers[generations[1]], generation=generations[1] + ) + + self.is_fresh_assignment = len(self.current_assignment) == 0 + + for consumer_id, partitions in six.iteritems(self.current_assignment): + for partition in partitions: + self.current_partition_consumer[partition] = consumer_id + + def _are_subscriptions_identical(self): + """ + Returns: + true, if both potential consumers of partitions and potential partitions that consumers can + consume are the same + """ + if not has_identical_list_elements(list(six.itervalues(self.partition_to_all_potential_consumers))): + return False + return has_identical_list_elements(list(six.itervalues(self.consumer_to_all_potential_partitions))) + + def _populate_sorted_partitions(self): + # set of topic partitions with their respective potential consumers + all_partitions = set((tp, tuple(consumers)) + for tp, consumers in six.iteritems(self.partition_to_all_potential_consumers)) + partitions_sorted_by_num_of_potential_consumers = sorted(all_partitions, key=partitions_comparator_key) + + self.sorted_partitions = [] + if not self.is_fresh_assignment and self._are_subscriptions_identical(): + # if this is a reassignment and the subscriptions are identical (all consumers can consumer from all topics) + # then we just need to simply list partitions in a round robin fashion (from consumers with + # most assigned partitions to those with least) + assignments = deepcopy(self.current_assignment) + for consumer_id, partitions in six.iteritems(assignments): + to_remove = [] + for partition in partitions: + if partition not in self.partition_to_all_potential_consumers: + to_remove.append(partition) + for partition in to_remove: + partitions.remove(partition) + + sorted_consumers = SortedSet( + iterable=[(consumer, tuple(partitions)) for consumer, 
partitions in six.iteritems(assignments)], + key=subscriptions_comparator_key, + ) + # at this point, sorted_consumers contains an ascending-sorted list of consumers based on + # how many valid partitions are currently assigned to them + while sorted_consumers: + # take the consumer with the most partitions + consumer, _ = sorted_consumers.pop_last() + # currently assigned partitions to this consumer + remaining_partitions = assignments[consumer] + # from partitions that had a different consumer before, + # keep only those that are assigned to this consumer now + previous_partitions = set(six.iterkeys(self.previous_assignment)).intersection(set(remaining_partitions)) + if previous_partitions: + # if there is a partition of this consumer that was assigned to another consumer before + # mark it as good options for reassignment + partition = previous_partitions.pop() + remaining_partitions.remove(partition) + self.sorted_partitions.append(partition) + sorted_consumers.add((consumer, tuple(assignments[consumer]))) + elif remaining_partitions: + # otherwise, mark any other one of the current partitions as a reassignment candidate + self.sorted_partitions.append(remaining_partitions.pop()) + sorted_consumers.add((consumer, tuple(assignments[consumer]))) + + while partitions_sorted_by_num_of_potential_consumers: + partition = partitions_sorted_by_num_of_potential_consumers.pop(0)[0] + if partition not in self.sorted_partitions: + self.sorted_partitions.append(partition) + else: + while partitions_sorted_by_num_of_potential_consumers: + self.sorted_partitions.append(partitions_sorted_by_num_of_potential_consumers.pop(0)[0]) + + def _populate_partitions_to_reassign(self): + self.unassigned_partitions = deepcopy(self.sorted_partitions) + + assignments_to_remove = [] + for consumer_id, partitions in six.iteritems(self.current_assignment): + if consumer_id not in self.members: + # if a consumer that existed before (and had some partition assignments) is now removed, + # remove it from current_assignment + for partition in partitions: + del self.current_partition_consumer[partition] + assignments_to_remove.append(consumer_id) + else: + # otherwise (the consumer still exists) + partitions_to_remove = [] + for partition in partitions: + if partition not in self.partition_to_all_potential_consumers: + # if this topic partition of this consumer no longer exists + # remove it from current_assignment of the consumer + partitions_to_remove.append(partition) + elif partition.topic not in self.members[consumer_id].subscription: + # if this partition cannot remain assigned to its current consumer because the consumer + # is no longer subscribed to its topic remove it from current_assignment of the consumer + partitions_to_remove.append(partition) + self.revocation_required = True + else: + # otherwise, remove the topic partition from those that need to be assigned only if + # its current consumer is still subscribed to its topic (because it is already assigned + # and we would want to preserve that assignment as much as possible) + self.unassigned_partitions.remove(partition) + for partition in partitions_to_remove: + self.current_assignment[consumer_id].remove(partition) + del self.current_partition_consumer[partition] + for consumer_id in assignments_to_remove: + del self.current_assignment[consumer_id] + + def _initialize_current_subscriptions(self): + self.sorted_current_subscriptions = SortedSet( + iterable=[(consumer, tuple(partitions)) for consumer, partitions in six.iteritems(self.current_assignment)], + 
key=subscriptions_comparator_key, + ) + + def _get_consumer_with_least_subscriptions(self): + return self.sorted_current_subscriptions.first()[0] + + def _get_consumer_with_most_subscriptions(self): + return self.sorted_current_subscriptions.last()[0] + + def _remove_consumer_from_current_subscriptions_and_maintain_order(self, consumer): + self.sorted_current_subscriptions.remove((consumer, tuple(self.current_assignment[consumer]))) + + def _add_consumer_to_current_subscriptions_and_maintain_order(self, consumer): + self.sorted_current_subscriptions.add((consumer, tuple(self.current_assignment[consumer]))) + + def _is_balanced(self): + """Determines if the current assignment is a balanced one""" + if ( + len(self.current_assignment[self._get_consumer_with_least_subscriptions()]) + >= len(self.current_assignment[self._get_consumer_with_most_subscriptions()]) - 1 + ): + # if minimum and maximum numbers of partitions assigned to consumers differ by at most one return true + return True + + # create a mapping from partitions to the consumer assigned to them + all_assigned_partitions = {} + for consumer_id, consumer_partitions in six.iteritems(self.current_assignment): + for partition in consumer_partitions: + if partition in all_assigned_partitions: + log.error("{} is assigned to more than one consumer.".format(partition)) + all_assigned_partitions[partition] = consumer_id + + # for each consumer that does not have all the topic partitions it can get + # make sure none of the topic partitions it could but did not get cannot be moved to it + # (because that would break the balance) + for consumer, _ in self.sorted_current_subscriptions: + consumer_partition_count = len(self.current_assignment[consumer]) + # skip if this consumer already has all the topic partitions it can get + if consumer_partition_count == len(self.consumer_to_all_potential_partitions[consumer]): + continue + + # otherwise make sure it cannot get any more + for partition in self.consumer_to_all_potential_partitions[consumer]: + if partition not in self.current_assignment[consumer]: + other_consumer = all_assigned_partitions[partition] + other_consumer_partition_count = len(self.current_assignment[other_consumer]) + if consumer_partition_count < other_consumer_partition_count: + return False + return True + + def _assign_partition(self, partition): + for consumer, _ in self.sorted_current_subscriptions: + if partition in self.consumer_to_all_potential_partitions[consumer]: + self._remove_consumer_from_current_subscriptions_and_maintain_order(consumer) + self.current_assignment[consumer].append(partition) + self.current_partition_consumer[partition] = consumer + self._add_consumer_to_current_subscriptions_and_maintain_order(consumer) + break + + def _can_partition_participate_in_reassignment(self, partition): + return len(self.partition_to_all_potential_consumers[partition]) >= 2 + + def _can_consumer_participate_in_reassignment(self, consumer): + current_partitions = self.current_assignment[consumer] + current_assignment_size = len(current_partitions) + max_assignment_size = len(self.consumer_to_all_potential_partitions[consumer]) + if current_assignment_size > max_assignment_size: + log.error("The consumer {} is assigned more partitions than the maximum possible.".format(consumer)) + if current_assignment_size < max_assignment_size: + # if a consumer is not assigned all its potential partitions it is subject to reassignment + return True + for partition in current_partitions: + # if any of the partitions assigned to a 
consumer is subject to reassignment the consumer itself + # is subject to reassignment + if self._can_partition_participate_in_reassignment(partition): + return True + return False + + def _perform_reassignments(self, reassignable_partitions): + reassignment_performed = False + + # repeat reassignment until no partition can be moved to improve the balance + while True: + modified = False + # reassign all reassignable partitions until the full list is processed or a balance is achieved + # (starting from the partition with least potential consumers and if needed) + for partition in reassignable_partitions: + if self._is_balanced(): + break + # the partition must have at least two potential consumers + if len(self.partition_to_all_potential_consumers[partition]) <= 1: + log.error("Expected more than one potential consumer for partition {}".format(partition)) + # the partition must have a current consumer + consumer = self.current_partition_consumer.get(partition) + if consumer is None: + log.error("Expected partition {} to be assigned to a consumer".format(partition)) + + if ( + partition in self.previous_assignment + and len(self.current_assignment[consumer]) + > len(self.current_assignment[self.previous_assignment[partition].consumer]) + 1 + ): + self._reassign_partition_to_consumer( + partition, self.previous_assignment[partition].consumer, + ) + reassignment_performed = True + modified = True + continue + + # check if a better-suited consumer exist for the partition; if so, reassign it + for other_consumer in self.partition_to_all_potential_consumers[partition]: + if len(self.current_assignment[consumer]) > len(self.current_assignment[other_consumer]) + 1: + self._reassign_partition(partition) + reassignment_performed = True + modified = True + break + + if not modified: + break + return reassignment_performed + + def _reassign_partition(self, partition): + new_consumer = None + for another_consumer, _ in self.sorted_current_subscriptions: + if partition in self.consumer_to_all_potential_partitions[another_consumer]: + new_consumer = another_consumer + break + assert new_consumer is not None + self._reassign_partition_to_consumer(partition, new_consumer) + + def _reassign_partition_to_consumer(self, partition, new_consumer): + consumer = self.current_partition_consumer[partition] + # find the correct partition movement considering the stickiness requirement + partition_to_be_moved = self.partition_movements.get_partition_to_be_moved(partition, consumer, new_consumer) + self._move_partition(partition_to_be_moved, new_consumer) + + def _move_partition(self, partition, new_consumer): + old_consumer = self.current_partition_consumer[partition] + self._remove_consumer_from_current_subscriptions_and_maintain_order(old_consumer) + self._remove_consumer_from_current_subscriptions_and_maintain_order(new_consumer) + + self.partition_movements.move_partition(partition, old_consumer, new_consumer) + + self.current_assignment[old_consumer].remove(partition) + self.current_assignment[new_consumer].append(partition) + self.current_partition_consumer[partition] = new_consumer + + self._add_consumer_to_current_subscriptions_and_maintain_order(new_consumer) + self._add_consumer_to_current_subscriptions_and_maintain_order(old_consumer) + + @staticmethod + def _get_balance_score(assignment): + """Calculates a balance score of a give assignment + as the sum of assigned partitions size difference of all consumer pairs. 
+ A perfectly balanced assignment (with all consumers getting the same number of partitions) + has a balance score of 0. Lower balance score indicates a more balanced assignment. + + Arguments: + assignment (dict): {consumer: list of assigned topic partitions} + + Returns: + the balance score of the assignment + """ + score = 0 + consumer_to_assignment = {} + for consumer_id, partitions in six.iteritems(assignment): + consumer_to_assignment[consumer_id] = len(partitions) + + consumers_to_explore = set(consumer_to_assignment.keys()) + for consumer_id in consumer_to_assignment.keys(): + if consumer_id in consumers_to_explore: + consumers_to_explore.remove(consumer_id) + for other_consumer_id in consumers_to_explore: + score += abs(consumer_to_assignment[consumer_id] - consumer_to_assignment[other_consumer_id]) + return score + + +class StickyPartitionAssignor(AbstractPartitionAssignor): + """ + https://cwiki.apache.org/confluence/display/KAFKA/KIP-54+-+Sticky+Partition+Assignment+Strategy + + The sticky assignor serves two purposes. First, it guarantees an assignment that is as balanced as possible, meaning either: + - the numbers of topic partitions assigned to consumers differ by at most one; or + - each consumer that has 2+ fewer topic partitions than some other consumer cannot get any of those topic partitions transferred to it. + + Second, it preserved as many existing assignment as possible when a reassignment occurs. + This helps in saving some of the overhead processing when topic partitions move from one consumer to another. + + Starting fresh it would work by distributing the partitions over consumers as evenly as possible. + Even though this may sound similar to how round robin assignor works, the second example below shows that it is not. + During a reassignment it would perform the reassignment in such a way that in the new assignment + - topic partitions are still distributed as evenly as possible, and + - topic partitions stay with their previously assigned consumers as much as possible. + + The first goal above takes precedence over the second one. + + Example 1. + Suppose there are three consumers C0, C1, C2, + four topics t0, t1, t2, t3, and each topic has 2 partitions, + resulting in partitions t0p0, t0p1, t1p0, t1p1, t2p0, t2p1, t3p0, t3p1. + Each consumer is subscribed to all three topics. + + The assignment with both sticky and round robin assignors will be: + - C0: [t0p0, t1p1, t3p0] + - C1: [t0p1, t2p0, t3p1] + - C2: [t1p0, t2p1] + + Now, let's assume C1 is removed and a reassignment is about to happen. The round robin assignor would produce: + - C0: [t0p0, t1p0, t2p0, t3p0] + - C2: [t0p1, t1p1, t2p1, t3p1] + + while the sticky assignor would result in: + - C0 [t0p0, t1p1, t3p0, t2p0] + - C2 [t1p0, t2p1, t0p1, t3p1] + preserving all the previous assignments (unlike the round robin assignor). + + + Example 2. + There are three consumers C0, C1, C2, + and three topics t0, t1, t2, with 1, 2, and 3 partitions respectively. + Therefore, the partitions are t0p0, t1p0, t1p1, t2p0, t2p1, t2p2. + C0 is subscribed to t0; + C1 is subscribed to t0, t1; + and C2 is subscribed to t0, t1, t2. + + The round robin assignor would come up with the following assignment: + - C0 [t0p0] + - C1 [t1p0] + - C2 [t1p1, t2p0, t2p1, t2p2] + + which is not as balanced as the assignment suggested by sticky assignor: + - C0 [t0p0] + - C1 [t1p0, t1p1] + - C2 [t2p0, t2p1, t2p2] + + Now, if consumer C0 is removed, these two assignors would produce the following assignments. 
+ Round Robin (preserves 3 partition assignments): + - C1 [t0p0, t1p1] + - C2 [t1p0, t2p0, t2p1, t2p2] + + Sticky (preserves 5 partition assignments): + - C1 [t1p0, t1p1, t0p0] + - C2 [t2p0, t2p1, t2p2] + """ + + DEFAULT_GENERATION_ID = -1 + + name = "sticky" + version = 0 + + member_assignment = None + generation = DEFAULT_GENERATION_ID + + _latest_partition_movements = None + + @classmethod + def assign(cls, cluster, members): + """Performs group assignment given cluster metadata and member subscriptions + + Arguments: + cluster (ClusterMetadata): cluster metadata + members (dict of {member_id: MemberMetadata}): decoded metadata for each member in the group. + + Returns: + dict: {member_id: MemberAssignment} + """ + members_metadata = {} + for consumer, member_metadata in six.iteritems(members): + members_metadata[consumer] = cls.parse_member_metadata(member_metadata) + + executor = StickyAssignmentExecutor(cluster, members_metadata) + executor.perform_initial_assignment() + executor.balance() + + cls._latest_partition_movements = executor.partition_movements + + assignment = {} + for member_id in members: + assignment[member_id] = ConsumerProtocolMemberAssignment( + cls.version, sorted(executor.get_final_assignment(member_id)), b'' + ) + return assignment + + @classmethod + def parse_member_metadata(cls, metadata): + """ + Parses member metadata into a python object. + This implementation only serializes and deserializes the StickyAssignorMemberMetadataV1 user data, + since no StickyAssignor written in Python was deployed ever in the wild with version V0, meaning that + there is no need to support backward compatibility with V0. + + Arguments: + metadata (MemberMetadata): decoded metadata for a member of the group. + + Returns: + parsed metadata (StickyAssignorMemberMetadataV1) + """ + user_data = metadata.user_data + if not user_data: + return StickyAssignorMemberMetadataV1( + partitions=[], generation=cls.DEFAULT_GENERATION_ID, subscription=metadata.subscription + ) + + try: + decoded_user_data = StickyAssignorUserDataV1.decode(user_data) + except Exception as e: + # ignore the consumer's previous assignment if it cannot be parsed + log.error("Could not parse member data", e) # pylint: disable=logging-too-many-args + return StickyAssignorMemberMetadataV1( + partitions=[], generation=cls.DEFAULT_GENERATION_ID, subscription=metadata.subscription + ) + + member_partitions = [] + for topic, partitions in decoded_user_data.previous_assignment: # pylint: disable=no-member + member_partitions.extend([TopicPartition(topic, partition) for partition in partitions]) + return StickyAssignorMemberMetadataV1( + # pylint: disable=no-member + partitions=member_partitions, generation=decoded_user_data.generation, subscription=metadata.subscription + ) + + @classmethod + def metadata(cls, topics): + return cls._metadata(topics, cls.member_assignment, cls.generation) + + @classmethod + def _metadata(cls, topics, member_assignment_partitions, generation=-1): + if member_assignment_partitions is None: + log.debug("No member assignment available") + user_data = b'' + else: + log.debug("Member assignment is available, generating the metadata: generation {}".format(cls.generation)) + partitions_by_topic = defaultdict(list) + for topic_partition in member_assignment_partitions: + partitions_by_topic[topic_partition.topic].append(topic_partition.partition) + data = StickyAssignorUserDataV1(six.viewitems(partitions_by_topic), generation) + user_data = data.encode() + return 
ConsumerProtocolMemberMetadata(cls.version, list(topics), user_data) + + @classmethod + def on_assignment(cls, assignment): + """Callback that runs on each assignment. Updates assignor's state. + + Arguments: + assignment: MemberAssignment + """ + log.debug("On assignment: assignment={}".format(assignment)) + cls.member_assignment = assignment.partitions() + + @classmethod + def on_generation_assignment(cls, generation): + """Callback that runs on each assignment. Updates assignor's generation id. + + Arguments: + generation: generation id + """ + log.debug("On generation assignment: generation={}".format(generation)) + cls.generation = generation diff --git a/coordinator/base.py b/coordinator/base.py new file mode 100644 index 00000000..e7198410 --- /dev/null +++ b/coordinator/base.py @@ -0,0 +1,1023 @@ +from __future__ import absolute_import, division + +import abc +import copy +import logging +import threading +import time +import weakref + +from kafka.vendor import six + +from kafka.coordinator.heartbeat import Heartbeat +from kafka import errors as Errors +from kafka.future import Future +from kafka.metrics import AnonMeasurable +from kafka.metrics.stats import Avg, Count, Max, Rate +from kafka.protocol.commit import GroupCoordinatorRequest, OffsetCommitRequest +from kafka.protocol.group import (HeartbeatRequest, JoinGroupRequest, + LeaveGroupRequest, SyncGroupRequest) + +log = logging.getLogger('kafka.coordinator') + + +class MemberState(object): + UNJOINED = '' # the client is not part of a group + REBALANCING = '' # the client has begun rebalancing + STABLE = '' # the client has joined and is sending heartbeats + + +class Generation(object): + def __init__(self, generation_id, member_id, protocol): + self.generation_id = generation_id + self.member_id = member_id + self.protocol = protocol + +Generation.NO_GENERATION = Generation( + OffsetCommitRequest[2].DEFAULT_GENERATION_ID, + JoinGroupRequest[0].UNKNOWN_MEMBER_ID, + None) + + +class UnjoinedGroupException(Errors.KafkaError): + retriable = True + + +class BaseCoordinator(object): + """ + BaseCoordinator implements group management for a single group member + by interacting with a designated Kafka broker (the coordinator). Group + semantics are provided by extending this class. See ConsumerCoordinator + for example usage. + + From a high level, Kafka's group management protocol consists of the + following sequence of actions: + + 1. Group Registration: Group members register with the coordinator providing + their own metadata (such as the set of topics they are interested in). + + 2. Group/Leader Selection: The coordinator select the members of the group + and chooses one member as the leader. + + 3. State Assignment: The leader collects the metadata from all the members + of the group and assigns state. + + 4. Group Stabilization: Each member receives the state assigned by the + leader and begins processing. + + To leverage this protocol, an implementation must define the format of + metadata provided by each member for group registration in + :meth:`.group_protocols` and the format of the state assignment provided by + the leader in :meth:`._perform_assignment` and which becomes available to + members in :meth:`._on_join_complete`. + + Note on locking: this class shares state between the caller and a background + thread which is used for sending heartbeats after the client has joined the + group. All mutable state as well as state transitions are protected with the + class's monitor. 
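The user data exchanged by the sticky assignor is an ordinary protocol Struct, so its encode/decode round trip can be exercised directly, mirroring _metadata() and parse_member_metadata() above (editor's sketch; the import path assumes the subtree is importable as the `kafka` package):

from kafka.coordinator.assignors.sticky.sticky_assignor import (
    StickyAssignorUserDataV1)

data = StickyAssignorUserDataV1([('t0', [0, 1]), ('t1', [2])], 5)
decoded = StickyAssignorUserDataV1.decode(data.encode())
assert decoded.generation == 5
assert dict(decoded.previous_assignment) == {'t0': [0, 1], 't1': [2]}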
Generally this means acquiring the lock before reading or + writing the state of the group (e.g. generation, member_id) and holding the + lock when sending a request that affects the state of the group + (e.g. JoinGroup, LeaveGroup). + """ + + DEFAULT_CONFIG = { + 'group_id': 'kafka-python-default-group', + 'session_timeout_ms': 10000, + 'heartbeat_interval_ms': 3000, + 'max_poll_interval_ms': 300000, + 'retry_backoff_ms': 100, + 'api_version': (0, 10, 1), + 'metric_group_prefix': '', + } + + def __init__(self, client, metrics, **configs): + """ + Keyword Arguments: + group_id (str): name of the consumer group to join for dynamic + partition assignment (if enabled), and to use for fetching and + committing offsets. Default: 'kafka-python-default-group' + session_timeout_ms (int): The timeout used to detect failures when + using Kafka's group management facilities. Default: 30000 + heartbeat_interval_ms (int): The expected time in milliseconds + between heartbeats to the consumer coordinator when using + Kafka's group management feature. Heartbeats are used to ensure + that the consumer's session stays active and to facilitate + rebalancing when new consumers join or leave the group. The + value must be set lower than session_timeout_ms, but typically + should be set no higher than 1/3 of that value. It can be + adjusted even lower to control the expected time for normal + rebalances. Default: 3000 + retry_backoff_ms (int): Milliseconds to backoff when retrying on + errors. Default: 100. + """ + self.config = copy.copy(self.DEFAULT_CONFIG) + for key in self.config: + if key in configs: + self.config[key] = configs[key] + + if self.config['api_version'] < (0, 10, 1): + if self.config['max_poll_interval_ms'] != self.config['session_timeout_ms']: + raise Errors.KafkaConfigurationError("Broker version %s does not support " + "different values for max_poll_interval_ms " + "and session_timeout_ms") + + self._client = client + self.group_id = self.config['group_id'] + self.heartbeat = Heartbeat(**self.config) + self._heartbeat_thread = None + self._lock = threading.Condition() + self.rejoin_needed = True + self.rejoining = False # renamed / complement of java needsJoinPrepare + self.state = MemberState.UNJOINED + self.join_future = None + self.coordinator_id = None + self._find_coordinator_future = None + self._generation = Generation.NO_GENERATION + self.sensors = GroupCoordinatorMetrics(self.heartbeat, metrics, + self.config['metric_group_prefix']) + + @abc.abstractmethod + def protocol_type(self): + """ + Unique identifier for the class of supported protocols + (e.g. "consumer" or "connect"). + + Returns: + str: protocol type name + """ + pass + + @abc.abstractmethod + def group_protocols(self): + """Return the list of supported group protocols and metadata. + + This list is submitted by each group member via a JoinGroupRequest. + The order of the protocols in the list indicates the preference of the + protocol (the first entry is the most preferred). The coordinator takes + this preference into account when selecting the generation protocol + (generally more preferred protocols will be selected as long as all + members support them and there is no disagreement on the preference). + + Note: metadata must be type bytes or support an encode() method + + Returns: + list: [(protocol, metadata), ...] + """ + pass + + @abc.abstractmethod + def _on_join_prepare(self, generation, member_id): + """Invoked prior to each group join or rejoin. 
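The configuration handling in BaseCoordinator.__init__ above copies DEFAULT_CONFIG and then only overrides keys that already exist in it, so unrecognized keyword arguments are silently dropped rather than rejected. A standalone illustration of that merge pattern, using a trimmed-down copy of the defaults (editor's sketch only; merge_config is a hypothetical helper name):

import copy

DEFAULT_CONFIG = {
    'group_id': 'kafka-python-default-group',
    'session_timeout_ms': 10000,
    'heartbeat_interval_ms': 3000,
}

def merge_config(**configs):
    merged = copy.copy(DEFAULT_CONFIG)
    for key in merged:            # only keys present in the defaults are considered
        if key in configs:
            merged[key] = configs[key]
    return merged

cfg = merge_config(session_timeout_ms=30000, not_a_known_key=1)
assert cfg['session_timeout_ms'] == 30000
assert 'not_a_known_key' not in cfg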
+ + This is typically used to perform any cleanup from the previous + generation (such as committing offsets for the consumer) + + Arguments: + generation (int): The previous generation or -1 if there was none + member_id (str): The identifier of this member in the previous group + or '' if there was none + """ + pass + + @abc.abstractmethod + def _perform_assignment(self, leader_id, protocol, members): + """Perform assignment for the group. + + This is used by the leader to push state to all the members of the group + (e.g. to push partition assignments in the case of the new consumer) + + Arguments: + leader_id (str): The id of the leader (which is this member) + protocol (str): the chosen group protocol (assignment strategy) + members (list): [(member_id, metadata_bytes)] from + JoinGroupResponse. metadata_bytes are associated with the chosen + group protocol, and the Coordinator subclass is responsible for + decoding metadata_bytes based on that protocol. + + Returns: + dict: {member_id: assignment}; assignment must either be bytes + or have an encode() method to convert to bytes + """ + pass + + @abc.abstractmethod + def _on_join_complete(self, generation, member_id, protocol, + member_assignment_bytes): + """Invoked when a group member has successfully joined a group. + + Arguments: + generation (int): the generation that was joined + member_id (str): the identifier for the local member in the group + protocol (str): the protocol selected by the coordinator + member_assignment_bytes (bytes): the protocol-encoded assignment + propagated from the group leader. The Coordinator instance is + responsible for decoding based on the chosen protocol. + """ + pass + + def coordinator_unknown(self): + """Check if we know who the coordinator is and have an active connection + + Side-effect: reset coordinator_id to None if connection failed + + Returns: + bool: True if the coordinator is unknown + """ + return self.coordinator() is None + + def coordinator(self): + """Get the current coordinator + + Returns: the current coordinator id or None if it is unknown + """ + if self.coordinator_id is None: + return None + elif self._client.is_disconnected(self.coordinator_id): + self.coordinator_dead('Node Disconnected') + return None + else: + return self.coordinator_id + + def ensure_coordinator_ready(self): + """Block until the coordinator for this group is known + (and we have an active connection -- java client uses unsent queue). 
+ """ + with self._client._lock, self._lock: + while self.coordinator_unknown(): + + # Prior to 0.8.2 there was no group coordinator + # so we will just pick a node at random and treat + # it as the "coordinator" + if self.config['api_version'] < (0, 8, 2): + self.coordinator_id = self._client.least_loaded_node() + if self.coordinator_id is not None: + self._client.maybe_connect(self.coordinator_id) + continue + + future = self.lookup_coordinator() + self._client.poll(future=future) + + if future.failed(): + if future.retriable(): + if getattr(future.exception, 'invalid_metadata', False): + log.debug('Requesting metadata for group coordinator request: %s', future.exception) + metadata_update = self._client.cluster.request_update() + self._client.poll(future=metadata_update) + else: + time.sleep(self.config['retry_backoff_ms'] / 1000) + else: + raise future.exception # pylint: disable-msg=raising-bad-type + + def _reset_find_coordinator_future(self, result): + self._find_coordinator_future = None + + def lookup_coordinator(self): + with self._lock: + if self._find_coordinator_future is not None: + return self._find_coordinator_future + + # If there is an error sending the group coordinator request + # then _reset_find_coordinator_future will immediately fire and + # set _find_coordinator_future = None + # To avoid returning None, we capture the future in a local variable + future = self._send_group_coordinator_request() + self._find_coordinator_future = future + self._find_coordinator_future.add_both(self._reset_find_coordinator_future) + return future + + def need_rejoin(self): + """Check whether the group should be rejoined (e.g. if metadata changes) + + Returns: + bool: True if it should, False otherwise + """ + return self.rejoin_needed + + def poll_heartbeat(self): + """ + Check the status of the heartbeat thread (if it is active) and indicate + the liveness of the client. This must be called periodically after + joining with :meth:`.ensure_active_group` to ensure that the member stays + in the group. If an interval of time longer than the provided rebalance + timeout (max_poll_interval_ms) expires without calling this method, then + the client will proactively leave the group. + + Raises: RuntimeError for unexpected errors raised from the heartbeat thread + """ + with self._lock: + if self._heartbeat_thread is not None: + if self._heartbeat_thread.failed: + # set the heartbeat thread to None and raise an exception. + # If the user catches it, the next call to ensure_active_group() + # will spawn a new heartbeat thread. 
+ cause = self._heartbeat_thread.failed + self._heartbeat_thread = None + raise cause # pylint: disable-msg=raising-bad-type + + # Awake the heartbeat thread if needed + if self.heartbeat.should_heartbeat(): + self._lock.notify() + self.heartbeat.poll() + + def time_to_next_heartbeat(self): + """Returns seconds (float) remaining before next heartbeat should be sent + + Note: Returns infinite if group is not joined + """ + with self._lock: + # if we have not joined the group, we don't need to send heartbeats + if self.state is MemberState.UNJOINED: + return float('inf') + return self.heartbeat.time_to_next_heartbeat() + + def _handle_join_success(self, member_assignment_bytes): + with self._lock: + log.info("Successfully joined group %s with generation %s", + self.group_id, self._generation.generation_id) + self.state = MemberState.STABLE + self.rejoin_needed = False + if self._heartbeat_thread: + self._heartbeat_thread.enable() + + def _handle_join_failure(self, _): + with self._lock: + self.state = MemberState.UNJOINED + + def ensure_active_group(self): + """Ensure that the group is active (i.e. joined and synced)""" + with self._client._lock, self._lock: + if self._heartbeat_thread is None: + self._start_heartbeat_thread() + + while self.need_rejoin() or self._rejoin_incomplete(): + self.ensure_coordinator_ready() + + # call on_join_prepare if needed. We set a flag + # to make sure that we do not call it a second + # time if the client is woken up before a pending + # rebalance completes. This must be called on each + # iteration of the loop because an event requiring + # a rebalance (such as a metadata refresh which + # changes the matched subscription set) can occur + # while another rebalance is still in progress. + if not self.rejoining: + self._on_join_prepare(self._generation.generation_id, + self._generation.member_id) + self.rejoining = True + + # ensure that there are no pending requests to the coordinator. + # This is important in particular to avoid resending a pending + # JoinGroup request. + while not self.coordinator_unknown(): + if not self._client.in_flight_request_count(self.coordinator_id): + break + self._client.poll() + else: + continue + + # we store the join future in case we are woken up by the user + # after beginning the rebalance in the call to poll below. + # This ensures that we do not mistakenly attempt to rejoin + # before the pending rebalance has completed. + if self.join_future is None: + # Fence off the heartbeat thread explicitly so that it cannot + # interfere with the join group. Note that this must come after + # the call to _on_join_prepare since we must be able to continue + # sending heartbeats if that callback takes some time. + self._heartbeat_thread.disable() + + self.state = MemberState.REBALANCING + future = self._send_join_group_request() + + self.join_future = future # this should happen before adding callbacks + + # handle join completion in the callback so that the + # callback will be invoked even if the consumer is woken up + # before finishing the rebalance + future.add_callback(self._handle_join_success) + + # we handle failures below after the request finishes. 
+ # If the join completes after having been woken up, the + # exception is ignored and we will rejoin + future.add_errback(self._handle_join_failure) + + else: + future = self.join_future + + self._client.poll(future=future) + + if future.succeeded(): + self._on_join_complete(self._generation.generation_id, + self._generation.member_id, + self._generation.protocol, + future.value) + self.join_future = None + self.rejoining = False + + else: + self.join_future = None + exception = future.exception + if isinstance(exception, (Errors.UnknownMemberIdError, + Errors.RebalanceInProgressError, + Errors.IllegalGenerationError)): + continue + elif not future.retriable(): + raise exception # pylint: disable-msg=raising-bad-type + time.sleep(self.config['retry_backoff_ms'] / 1000) + + def _rejoin_incomplete(self): + return self.join_future is not None + + def _send_join_group_request(self): + """Join the group and return the assignment for the next generation. + + This function handles both JoinGroup and SyncGroup, delegating to + :meth:`._perform_assignment` if elected leader by the coordinator. + + Returns: + Future: resolves to the encoded-bytes assignment returned from the + group leader + """ + if self.coordinator_unknown(): + e = Errors.GroupCoordinatorNotAvailableError(self.coordinator_id) + return Future().failure(e) + + elif not self._client.ready(self.coordinator_id, metadata_priority=False): + e = Errors.NodeNotReadyError(self.coordinator_id) + return Future().failure(e) + + # send a join group request to the coordinator + log.info("(Re-)joining group %s", self.group_id) + member_metadata = [ + (protocol, metadata if isinstance(metadata, bytes) else metadata.encode()) + for protocol, metadata in self.group_protocols() + ] + if self.config['api_version'] < (0, 9): + raise Errors.KafkaError('JoinGroupRequest api requires 0.9+ brokers') + elif (0, 9) <= self.config['api_version'] < (0, 10, 1): + request = JoinGroupRequest[0]( + self.group_id, + self.config['session_timeout_ms'], + self._generation.member_id, + self.protocol_type(), + member_metadata) + elif (0, 10, 1) <= self.config['api_version'] < (0, 11, 0): + request = JoinGroupRequest[1]( + self.group_id, + self.config['session_timeout_ms'], + self.config['max_poll_interval_ms'], + self._generation.member_id, + self.protocol_type(), + member_metadata) + else: + request = JoinGroupRequest[2]( + self.group_id, + self.config['session_timeout_ms'], + self.config['max_poll_interval_ms'], + self._generation.member_id, + self.protocol_type(), + member_metadata) + + # create the request for the coordinator + log.debug("Sending JoinGroup (%s) to coordinator %s", request, self.coordinator_id) + future = Future() + _f = self._client.send(self.coordinator_id, request) + _f.add_callback(self._handle_join_group_response, future, time.time()) + _f.add_errback(self._failed_request, self.coordinator_id, + request, future) + return future + + def _failed_request(self, node_id, request, future, error): + # Marking coordinator dead + # unless the error is caused by internal client pipelining + if not isinstance(error, (Errors.NodeNotReadyError, + Errors.TooManyInFlightRequests)): + log.error('Error sending %s to node %s [%s]', + request.__class__.__name__, node_id, error) + self.coordinator_dead(error) + else: + log.debug('Error sending %s to node %s [%s]', + request.__class__.__name__, node_id, error) + future.failure(error) + + def _handle_join_group_response(self, future, send_time, response): + error_type = Errors.for_code(response.error_code) + if 
error_type is Errors.NoError: + log.debug("Received successful JoinGroup response for group %s: %s", + self.group_id, response) + self.sensors.join_latency.record((time.time() - send_time) * 1000) + with self._lock: + if self.state is not MemberState.REBALANCING: + # if the consumer was woken up before a rebalance completes, + # we may have already left the group. In this case, we do + # not want to continue with the sync group. + future.failure(UnjoinedGroupException()) + else: + self._generation = Generation(response.generation_id, + response.member_id, + response.group_protocol) + + if response.leader_id == response.member_id: + log.info("Elected group leader -- performing partition" + " assignments using %s", self._generation.protocol) + self._on_join_leader(response).chain(future) + else: + self._on_join_follower().chain(future) + + elif error_type is Errors.GroupLoadInProgressError: + log.debug("Attempt to join group %s rejected since coordinator %s" + " is loading the group.", self.group_id, self.coordinator_id) + # backoff and retry + future.failure(error_type(response)) + elif error_type is Errors.UnknownMemberIdError: + # reset the member id and retry immediately + error = error_type(self._generation.member_id) + self.reset_generation() + log.debug("Attempt to join group %s failed due to unknown member id", + self.group_id) + future.failure(error) + elif error_type in (Errors.GroupCoordinatorNotAvailableError, + Errors.NotCoordinatorForGroupError): + # re-discover the coordinator and retry with backoff + self.coordinator_dead(error_type()) + log.debug("Attempt to join group %s failed due to obsolete " + "coordinator information: %s", self.group_id, + error_type.__name__) + future.failure(error_type()) + elif error_type in (Errors.InconsistentGroupProtocolError, + Errors.InvalidSessionTimeoutError, + Errors.InvalidGroupIdError): + # log the error and re-throw the exception + error = error_type(response) + log.error("Attempt to join group %s failed due to fatal error: %s", + self.group_id, error) + future.failure(error) + elif error_type is Errors.GroupAuthorizationFailedError: + future.failure(error_type(self.group_id)) + else: + # unexpected error, throw the exception + error = error_type() + log.error("Unexpected error in join group response: %s", error) + future.failure(error) + + def _on_join_follower(self): + # send follower's sync group with an empty assignment + version = 0 if self.config['api_version'] < (0, 11, 0) else 1 + request = SyncGroupRequest[version]( + self.group_id, + self._generation.generation_id, + self._generation.member_id, + {}) + log.debug("Sending follower SyncGroup for group %s to coordinator %s: %s", + self.group_id, self.coordinator_id, request) + return self._send_sync_group_request(request) + + def _on_join_leader(self, response): + """ + Perform leader synchronization and send back the assignment + for the group via SyncGroupRequest + + Arguments: + response (JoinResponse): broker response to parse + + Returns: + Future: resolves to member assignment encoded-bytes + """ + try: + group_assignment = self._perform_assignment(response.leader_id, + response.group_protocol, + response.members) + except Exception as e: + return Future().failure(e) + + version = 0 if self.config['api_version'] < (0, 11, 0) else 1 + request = SyncGroupRequest[version]( + self.group_id, + self._generation.generation_id, + self._generation.member_id, + [(member_id, + assignment if isinstance(assignment, bytes) else assignment.encode()) + for member_id, assignment in 
six.iteritems(group_assignment)]) + + log.debug("Sending leader SyncGroup for group %s to coordinator %s: %s", + self.group_id, self.coordinator_id, request) + return self._send_sync_group_request(request) + + def _send_sync_group_request(self, request): + if self.coordinator_unknown(): + e = Errors.GroupCoordinatorNotAvailableError(self.coordinator_id) + return Future().failure(e) + + # We assume that coordinator is ready if we're sending SyncGroup + # as it typically follows a successful JoinGroup + # Also note that if client.ready() enforces a metadata priority policy, + # we can get into an infinite loop if the leader assignment process + # itself requests a metadata update + + future = Future() + _f = self._client.send(self.coordinator_id, request) + _f.add_callback(self._handle_sync_group_response, future, time.time()) + _f.add_errback(self._failed_request, self.coordinator_id, + request, future) + return future + + def _handle_sync_group_response(self, future, send_time, response): + error_type = Errors.for_code(response.error_code) + if error_type is Errors.NoError: + self.sensors.sync_latency.record((time.time() - send_time) * 1000) + future.success(response.member_assignment) + return + + # Always rejoin on error + self.request_rejoin() + if error_type is Errors.GroupAuthorizationFailedError: + future.failure(error_type(self.group_id)) + elif error_type is Errors.RebalanceInProgressError: + log.debug("SyncGroup for group %s failed due to coordinator" + " rebalance", self.group_id) + future.failure(error_type(self.group_id)) + elif error_type in (Errors.UnknownMemberIdError, + Errors.IllegalGenerationError): + error = error_type() + log.debug("SyncGroup for group %s failed due to %s", self.group_id, error) + self.reset_generation() + future.failure(error) + elif error_type in (Errors.GroupCoordinatorNotAvailableError, + Errors.NotCoordinatorForGroupError): + error = error_type() + log.debug("SyncGroup for group %s failed due to %s", self.group_id, error) + self.coordinator_dead(error) + future.failure(error) + else: + error = error_type() + log.error("Unexpected error from SyncGroup: %s", error) + future.failure(error) + + def _send_group_coordinator_request(self): + """Discover the current coordinator for the group. 
+ + Returns: + Future: resolves to the node id of the coordinator + """ + node_id = self._client.least_loaded_node() + if node_id is None: + return Future().failure(Errors.NoBrokersAvailable()) + + elif not self._client.ready(node_id, metadata_priority=False): + e = Errors.NodeNotReadyError(node_id) + return Future().failure(e) + + log.debug("Sending group coordinator request for group %s to broker %s", + self.group_id, node_id) + request = GroupCoordinatorRequest[0](self.group_id) + future = Future() + _f = self._client.send(node_id, request) + _f.add_callback(self._handle_group_coordinator_response, future) + _f.add_errback(self._failed_request, node_id, request, future) + return future + + def _handle_group_coordinator_response(self, future, response): + log.debug("Received group coordinator response %s", response) + + error_type = Errors.for_code(response.error_code) + if error_type is Errors.NoError: + with self._lock: + coordinator_id = self._client.cluster.add_group_coordinator(self.group_id, response) + if not coordinator_id: + # This could happen if coordinator metadata is different + # than broker metadata + future.failure(Errors.IllegalStateError()) + return + + self.coordinator_id = coordinator_id + log.info("Discovered coordinator %s for group %s", + self.coordinator_id, self.group_id) + self._client.maybe_connect(self.coordinator_id) + self.heartbeat.reset_timeouts() + future.success(self.coordinator_id) + + elif error_type is Errors.GroupCoordinatorNotAvailableError: + log.debug("Group Coordinator Not Available; retry") + future.failure(error_type()) + elif error_type is Errors.GroupAuthorizationFailedError: + error = error_type(self.group_id) + log.error("Group Coordinator Request failed: %s", error) + future.failure(error) + else: + error = error_type() + log.error("Group coordinator lookup for group %s failed: %s", + self.group_id, error) + future.failure(error) + + def coordinator_dead(self, error): + """Mark the current coordinator as dead.""" + if self.coordinator_id is not None: + log.warning("Marking the coordinator dead (node %s) for group %s: %s.", + self.coordinator_id, self.group_id, error) + self.coordinator_id = None + + def generation(self): + """Get the current generation state if the group is stable. 
+ + Returns: the current generation or None if the group is unjoined/rebalancing + """ + with self._lock: + if self.state is not MemberState.STABLE: + return None + return self._generation + + def reset_generation(self): + """Reset the generation and memberId because we have fallen out of the group.""" + with self._lock: + self._generation = Generation.NO_GENERATION + self.rejoin_needed = True + self.state = MemberState.UNJOINED + + def request_rejoin(self): + self.rejoin_needed = True + + def _start_heartbeat_thread(self): + if self._heartbeat_thread is None: + log.info('Starting new heartbeat thread') + self._heartbeat_thread = HeartbeatThread(weakref.proxy(self)) + self._heartbeat_thread.daemon = True + self._heartbeat_thread.start() + + def _close_heartbeat_thread(self): + if self._heartbeat_thread is not None: + log.info('Stopping heartbeat thread') + try: + self._heartbeat_thread.close() + except ReferenceError: + pass + self._heartbeat_thread = None + + def __del__(self): + self._close_heartbeat_thread() + + def close(self): + """Close the coordinator, leave the current group, + and reset local generation / member_id""" + self._close_heartbeat_thread() + self.maybe_leave_group() + + def maybe_leave_group(self): + """Leave the current group and reset local generation/memberId.""" + with self._client._lock, self._lock: + if (not self.coordinator_unknown() + and self.state is not MemberState.UNJOINED + and self._generation is not Generation.NO_GENERATION): + + # this is a minimal effort attempt to leave the group. we do not + # attempt any resending if the request fails or times out. + log.info('Leaving consumer group (%s).', self.group_id) + version = 0 if self.config['api_version'] < (0, 11, 0) else 1 + request = LeaveGroupRequest[version](self.group_id, self._generation.member_id) + future = self._client.send(self.coordinator_id, request) + future.add_callback(self._handle_leave_group_response) + future.add_errback(log.error, "LeaveGroup request failed: %s") + self._client.poll(future=future) + + self.reset_generation() + + def _handle_leave_group_response(self, response): + error_type = Errors.for_code(response.error_code) + if error_type is Errors.NoError: + log.debug("LeaveGroup request for group %s returned successfully", + self.group_id) + else: + log.error("LeaveGroup request for group %s failed with error: %s", + self.group_id, error_type()) + + def _send_heartbeat_request(self): + """Send a heartbeat request""" + if self.coordinator_unknown(): + e = Errors.GroupCoordinatorNotAvailableError(self.coordinator_id) + return Future().failure(e) + + elif not self._client.ready(self.coordinator_id, metadata_priority=False): + e = Errors.NodeNotReadyError(self.coordinator_id) + return Future().failure(e) + + version = 0 if self.config['api_version'] < (0, 11, 0) else 1 + request = HeartbeatRequest[version](self.group_id, + self._generation.generation_id, + self._generation.member_id) + log.debug("Heartbeat: %s[%s] %s", request.group, request.generation_id, request.member_id) # pylint: disable-msg=no-member + future = Future() + _f = self._client.send(self.coordinator_id, request) + _f.add_callback(self._handle_heartbeat_response, future, time.time()) + _f.add_errback(self._failed_request, self.coordinator_id, + request, future) + return future + + def _handle_heartbeat_response(self, future, send_time, response): + self.sensors.heartbeat_latency.record((time.time() - send_time) * 1000) + error_type = Errors.for_code(response.error_code) + if error_type is Errors.NoError: + 
log.debug("Received successful heartbeat response for group %s", + self.group_id) + future.success(None) + elif error_type in (Errors.GroupCoordinatorNotAvailableError, + Errors.NotCoordinatorForGroupError): + log.warning("Heartbeat failed for group %s: coordinator (node %s)" + " is either not started or not valid", self.group_id, + self.coordinator()) + self.coordinator_dead(error_type()) + future.failure(error_type()) + elif error_type is Errors.RebalanceInProgressError: + log.warning("Heartbeat failed for group %s because it is" + " rebalancing", self.group_id) + self.request_rejoin() + future.failure(error_type()) + elif error_type is Errors.IllegalGenerationError: + log.warning("Heartbeat failed for group %s: generation id is not " + " current.", self.group_id) + self.reset_generation() + future.failure(error_type()) + elif error_type is Errors.UnknownMemberIdError: + log.warning("Heartbeat: local member_id was not recognized;" + " this consumer needs to re-join") + self.reset_generation() + future.failure(error_type) + elif error_type is Errors.GroupAuthorizationFailedError: + error = error_type(self.group_id) + log.error("Heartbeat failed: authorization error: %s", error) + future.failure(error) + else: + error = error_type() + log.error("Heartbeat failed: Unhandled error: %s", error) + future.failure(error) + + +class GroupCoordinatorMetrics(object): + def __init__(self, heartbeat, metrics, prefix, tags=None): + self.heartbeat = heartbeat + self.metrics = metrics + self.metric_group_name = prefix + "-coordinator-metrics" + + self.heartbeat_latency = metrics.sensor('heartbeat-latency') + self.heartbeat_latency.add(metrics.metric_name( + 'heartbeat-response-time-max', self.metric_group_name, + 'The max time taken to receive a response to a heartbeat request', + tags), Max()) + self.heartbeat_latency.add(metrics.metric_name( + 'heartbeat-rate', self.metric_group_name, + 'The average number of heartbeats per second', + tags), Rate(sampled_stat=Count())) + + self.join_latency = metrics.sensor('join-latency') + self.join_latency.add(metrics.metric_name( + 'join-time-avg', self.metric_group_name, + 'The average time taken for a group rejoin', + tags), Avg()) + self.join_latency.add(metrics.metric_name( + 'join-time-max', self.metric_group_name, + 'The max time taken for a group rejoin', + tags), Max()) + self.join_latency.add(metrics.metric_name( + 'join-rate', self.metric_group_name, + 'The number of group joins per second', + tags), Rate(sampled_stat=Count())) + + self.sync_latency = metrics.sensor('sync-latency') + self.sync_latency.add(metrics.metric_name( + 'sync-time-avg', self.metric_group_name, + 'The average time taken for a group sync', + tags), Avg()) + self.sync_latency.add(metrics.metric_name( + 'sync-time-max', self.metric_group_name, + 'The max time taken for a group sync', + tags), Max()) + self.sync_latency.add(metrics.metric_name( + 'sync-rate', self.metric_group_name, + 'The number of group syncs per second', + tags), Rate(sampled_stat=Count())) + + metrics.add_metric(metrics.metric_name( + 'last-heartbeat-seconds-ago', self.metric_group_name, + 'The number of seconds since the last controller heartbeat was sent', + tags), AnonMeasurable( + lambda _, now: (now / 1000) - self.heartbeat.last_send)) + + +class HeartbeatThread(threading.Thread): + def __init__(self, coordinator): + super(HeartbeatThread, self).__init__() + self.name = coordinator.group_id + '-heartbeat' + self.coordinator = coordinator + self.enabled = False + self.closed = False + self.failed = None + + 
def enable(self): + with self.coordinator._lock: + self.enabled = True + self.coordinator.heartbeat.reset_timeouts() + self.coordinator._lock.notify() + + def disable(self): + self.enabled = False + + def close(self): + self.closed = True + with self.coordinator._lock: + self.coordinator._lock.notify() + if self.is_alive(): + self.join(self.coordinator.config['heartbeat_interval_ms'] / 1000) + if self.is_alive(): + log.warning("Heartbeat thread did not fully terminate during close") + + def run(self): + try: + log.debug('Heartbeat thread started') + while not self.closed: + self._run_once() + + except ReferenceError: + log.debug('Heartbeat thread closed due to coordinator gc') + + except RuntimeError as e: + log.error("Heartbeat thread for group %s failed due to unexpected error: %s", + self.coordinator.group_id, e) + self.failed = e + + finally: + log.debug('Heartbeat thread closed') + + def _run_once(self): + with self.coordinator._client._lock, self.coordinator._lock: + if self.enabled and self.coordinator.state is MemberState.STABLE: + # TODO: When consumer.wakeup() is implemented, we need to + # disable here to prevent propagating an exception to this + # heartbeat thread + # must get client._lock, or maybe deadlock at heartbeat + # failure callback in consumer poll + self.coordinator._client.poll(timeout_ms=0) + + with self.coordinator._lock: + if not self.enabled: + log.debug('Heartbeat disabled. Waiting') + self.coordinator._lock.wait() + log.debug('Heartbeat re-enabled.') + return + + if self.coordinator.state is not MemberState.STABLE: + # the group is not stable (perhaps because we left the + # group or because the coordinator kicked us out), so + # disable heartbeats and wait for the main thread to rejoin. + log.debug('Group state is not stable, disabling heartbeats') + self.disable() + return + + if self.coordinator.coordinator_unknown(): + future = self.coordinator.lookup_coordinator() + if not future.is_done or future.failed(): + # the immediate future check ensures that we backoff + # properly in the case that no brokers are available + # to connect to (and the future is automatically failed). + self.coordinator._lock.wait(self.coordinator.config['retry_backoff_ms'] / 1000) + + elif self.coordinator.heartbeat.session_timeout_expired(): + # the session timeout has expired without seeing a + # successful heartbeat, so we should probably make sure + # the coordinator is still healthy. + log.warning('Heartbeat session expired, marking coordinator dead') + self.coordinator.coordinator_dead('Heartbeat session expired') + + elif self.coordinator.heartbeat.poll_timeout_expired(): + # the poll timeout has expired, which means that the + # foreground thread has stalled in between calls to + # poll(), so we explicitly leave the group. 
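+                # Concretely, poll_timeout_expired() (see
+                # coordinator/heartbeat.py later in this patch) is roughly
+                #     time.time() - heartbeat.last_poll
+                #         > max_poll_interval_ms / 1000
+                # and last_poll is refreshed by the *application* thread via
+                # poll_heartbeat() -> heartbeat.poll() (or reset_timeouts()),
+                # so a consumer loop that stalls for longer than
+                # max_poll_interval_ms (300s by default) ends up here.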
+ log.warning('Heartbeat poll expired, leaving group') + self.coordinator.maybe_leave_group() + + elif not self.coordinator.heartbeat.should_heartbeat(): + # poll again after waiting for the retry backoff in case + # the heartbeat failed or the coordinator disconnected + log.log(0, 'Not ready to heartbeat, waiting') + self.coordinator._lock.wait(self.coordinator.config['retry_backoff_ms'] / 1000) + + else: + self.coordinator.heartbeat.sent_heartbeat() + future = self.coordinator._send_heartbeat_request() + future.add_callback(self._handle_heartbeat_success) + future.add_errback(self._handle_heartbeat_failure) + + def _handle_heartbeat_success(self, result): + with self.coordinator._lock: + self.coordinator.heartbeat.received_heartbeat() + + def _handle_heartbeat_failure(self, exception): + with self.coordinator._lock: + if isinstance(exception, Errors.RebalanceInProgressError): + # it is valid to continue heartbeating while the group is + # rebalancing. This ensures that the coordinator keeps the + # member in the group for as long as the duration of the + # rebalance timeout. If we stop sending heartbeats, however, + # then the session timeout may expire before we can rejoin. + self.coordinator.heartbeat.received_heartbeat() + else: + self.coordinator.heartbeat.fail_heartbeat() + # wake up the thread if it's sleeping to reschedule the heartbeat + self.coordinator._lock.notify() diff --git a/coordinator/consumer.py b/coordinator/consumer.py new file mode 100644 index 00000000..971f5e80 --- /dev/null +++ b/coordinator/consumer.py @@ -0,0 +1,833 @@ +from __future__ import absolute_import, division + +import collections +import copy +import functools +import logging +import time + +from kafka.vendor import six + +from kafka.coordinator.base import BaseCoordinator, Generation +from kafka.coordinator.assignors.range import RangePartitionAssignor +from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor +from kafka.coordinator.assignors.sticky.sticky_assignor import StickyPartitionAssignor +from kafka.coordinator.protocol import ConsumerProtocol +import kafka.errors as Errors +from kafka.future import Future +from kafka.metrics import AnonMeasurable +from kafka.metrics.stats import Avg, Count, Max, Rate +from kafka.protocol.commit import OffsetCommitRequest, OffsetFetchRequest +from kafka.structs import OffsetAndMetadata, TopicPartition +from kafka.util import WeakMethod + + +log = logging.getLogger(__name__) + + +class ConsumerCoordinator(BaseCoordinator): + """This class manages the coordination process with the consumer coordinator.""" + DEFAULT_CONFIG = { + 'group_id': 'kafka-python-default-group', + 'enable_auto_commit': True, + 'auto_commit_interval_ms': 5000, + 'default_offset_commit_callback': None, + 'assignors': (RangePartitionAssignor, RoundRobinPartitionAssignor, StickyPartitionAssignor), + 'session_timeout_ms': 10000, + 'heartbeat_interval_ms': 3000, + 'max_poll_interval_ms': 300000, + 'retry_backoff_ms': 100, + 'api_version': (0, 10, 1), + 'exclude_internal_topics': True, + 'metric_group_prefix': 'consumer' + } + + def __init__(self, client, subscription, metrics, **configs): + """Initialize the coordination manager. + + Keyword Arguments: + group_id (str): name of the consumer group to join for dynamic + partition assignment (if enabled), and to use for fetching and + committing offsets. Default: 'kafka-python-default-group' + enable_auto_commit (bool): If true the consumer's offset will be + periodically committed in the background. Default: True. 
+            auto_commit_interval_ms (int): milliseconds between automatic
+                offset commits, if enable_auto_commit is True. Default: 5000.
+            default_offset_commit_callback (callable): called as
+                callback(offsets, exception), where exception is None on
+                success or the error that caused the commit to fail. This
+                callback can be used to trigger custom actions when a commit
+                request completes.
+            assignors (list): List of objects to use to distribute partition
+                ownership amongst consumer instances when group management is
+                used. Default: [RangePartitionAssignor,
+                RoundRobinPartitionAssignor, StickyPartitionAssignor]
+            heartbeat_interval_ms (int): The expected time in milliseconds
+                between heartbeats to the consumer coordinator when using
+                Kafka's group management feature. Heartbeats are used to ensure
+                that the consumer's session stays active and to facilitate
+                rebalancing when new consumers join or leave the group. The
+                value must be set lower than session_timeout_ms, but typically
+                should be set no higher than 1/3 of that value. It can be
+                adjusted even lower to control the expected time for normal
+                rebalances. Default: 3000
+            session_timeout_ms (int): The timeout used to detect failures when
+                using Kafka's group management facilities. Default: 10000
+            retry_backoff_ms (int): Milliseconds to backoff when retrying on
+                errors. Default: 100.
+            exclude_internal_topics (bool): Whether records from internal topics
+                (such as offsets) should be exposed to the consumer. If set to
+                True the only way to receive records from an internal topic is
+                subscribing to it. Requires 0.10+. Default: True
+        """
+        super(ConsumerCoordinator, self).__init__(client, metrics, **configs)
+
+        self.config = copy.copy(self.DEFAULT_CONFIG)
+        for key in self.config:
+            if key in configs:
+                self.config[key] = configs[key]
+
+        self._subscription = subscription
+        self._is_leader = False
+        self._joined_subscription = set()
+        self._metadata_snapshot = self._build_metadata_snapshot(subscription, client.cluster)
+        self._assignment_snapshot = None
+        self._cluster = client.cluster
+        self.auto_commit_interval = self.config['auto_commit_interval_ms'] / 1000
+        self.next_auto_commit_deadline = None
+        self.completed_offset_commits = collections.deque()
+
+        if self.config['default_offset_commit_callback'] is None:
+            self.config['default_offset_commit_callback'] = self._default_offset_commit_callback
+
+        if self.config['group_id'] is not None:
+            if self.config['api_version'] >= (0, 9):
+                if not self.config['assignors']:
+                    raise Errors.KafkaConfigurationError('Coordinator requires assignors')
+            if self.config['api_version'] < (0, 10, 1):
+                if self.config['max_poll_interval_ms'] != self.config['session_timeout_ms']:
+                    raise Errors.KafkaConfigurationError(
+                        "Broker version %s does not support different values for "
+                        "max_poll_interval_ms and session_timeout_ms"
+                        % (self.config['api_version'],))
+
+        if self.config['enable_auto_commit']:
+            if self.config['api_version'] < (0, 8, 1):
+                log.warning('Broker version (%s) does not support offset'
+                            ' commits; disabling auto-commit.',
+                            self.config['api_version'])
+                self.config['enable_auto_commit'] = False
+            elif self.config['group_id'] is None:
+                log.warning('group_id is None: disabling auto-commit.')
+                self.config['enable_auto_commit'] = False
+            else:
+                self.next_auto_commit_deadline = time.time() + self.auto_commit_interval
+
+        self.consumer_sensors = ConsumerCoordinatorMetrics(
+            metrics, self.config['metric_group_prefix'], self._subscription)
+
+        self._cluster.request_update()
+        self._cluster.add_listener(WeakMethod(self._handle_metadata_update))
+
+    def __del__(self):
+        if 
hasattr(self, '_cluster') and self._cluster: + self._cluster.remove_listener(WeakMethod(self._handle_metadata_update)) + super(ConsumerCoordinator, self).__del__() + + def protocol_type(self): + return ConsumerProtocol.PROTOCOL_TYPE + + def group_protocols(self): + """Returns list of preferred (protocols, metadata)""" + if self._subscription.subscription is None: + raise Errors.IllegalStateError('Consumer has not subscribed to topics') + # dpkp note: I really dislike this. + # why? because we are using this strange method group_protocols, + # which is seemingly innocuous, to set internal state (_joined_subscription) + # that is later used to check whether metadata has changed since we joined a group + # but there is no guarantee that this method, group_protocols, will get called + # in the correct sequence or that it will only be called when we want it to be. + # So this really should be moved elsewhere, but I don't have the energy to + # work that out right now. If you read this at some later date after the mutable + # state has bitten you... I'm sorry! It mimics the java client, and that's the + # best I've got for now. + self._joined_subscription = set(self._subscription.subscription) + metadata_list = [] + for assignor in self.config['assignors']: + metadata = assignor.metadata(self._joined_subscription) + group_protocol = (assignor.name, metadata) + metadata_list.append(group_protocol) + return metadata_list + + def _handle_metadata_update(self, cluster): + # if we encounter any unauthorized topics, raise an exception + if cluster.unauthorized_topics: + raise Errors.TopicAuthorizationFailedError(cluster.unauthorized_topics) + + if self._subscription.subscribed_pattern: + topics = [] + for topic in cluster.topics(self.config['exclude_internal_topics']): + if self._subscription.subscribed_pattern.match(topic): + topics.append(topic) + + if set(topics) != self._subscription.subscription: + self._subscription.change_subscription(topics) + self._client.set_topics(self._subscription.group_subscription()) + + # check if there are any changes to the metadata which should trigger + # a rebalance + if self._subscription.partitions_auto_assigned(): + metadata_snapshot = self._build_metadata_snapshot(self._subscription, cluster) + if self._metadata_snapshot != metadata_snapshot: + self._metadata_snapshot = metadata_snapshot + + # If we haven't got group coordinator support, + # just assign all partitions locally + if self._auto_assign_all_partitions(): + self._subscription.assign_from_subscribed([ + TopicPartition(topic, partition) + for topic in self._subscription.subscription + for partition in self._metadata_snapshot[topic] + ]) + + def _auto_assign_all_partitions(self): + # For users that use "subscribe" without group support, + # we will simply assign all partitions to this consumer + if self.config['api_version'] < (0, 9): + return True + elif self.config['group_id'] is None: + return True + else: + return False + + def _build_metadata_snapshot(self, subscription, cluster): + metadata_snapshot = {} + for topic in subscription.group_subscription(): + partitions = cluster.partitions_for_topic(topic) or [] + metadata_snapshot[topic] = set(partitions) + return metadata_snapshot + + def _lookup_assignor(self, name): + for assignor in self.config['assignors']: + if assignor.name == name: + return assignor + return None + + def _on_join_complete(self, generation, member_id, protocol, + member_assignment_bytes): + # only the leader is responsible for monitoring for metadata changes + # (i.e. 
partition changes) + if not self._is_leader: + self._assignment_snapshot = None + + assignor = self._lookup_assignor(protocol) + assert assignor, 'Coordinator selected invalid assignment protocol: %s' % (protocol,) + + assignment = ConsumerProtocol.ASSIGNMENT.decode(member_assignment_bytes) + + # set the flag to refresh last committed offsets + self._subscription.needs_fetch_committed_offsets = True + + # update partition assignment + try: + self._subscription.assign_from_subscribed(assignment.partitions()) + except ValueError as e: + log.warning("%s. Probably due to a deleted topic. Requesting Re-join" % e) + self.request_rejoin() + + # give the assignor a chance to update internal state + # based on the received assignment + assignor.on_assignment(assignment) + if assignor.name == 'sticky': + assignor.on_generation_assignment(generation) + + # reschedule the auto commit starting from now + self.next_auto_commit_deadline = time.time() + self.auto_commit_interval + + assigned = set(self._subscription.assigned_partitions()) + log.info("Setting newly assigned partitions %s for group %s", + assigned, self.group_id) + + # execute the user's callback after rebalance + if self._subscription.listener: + try: + self._subscription.listener.on_partitions_assigned(assigned) + except Exception: + log.exception("User provided listener %s for group %s" + " failed on partition assignment: %s", + self._subscription.listener, self.group_id, + assigned) + + def poll(self): + """ + Poll for coordinator events. Only applicable if group_id is set, and + broker version supports GroupCoordinators. This ensures that the + coordinator is known, and if using automatic partition assignment, + ensures that the consumer has joined the group. This also handles + periodic offset commits if they are enabled. + """ + if self.group_id is None: + return + + self._invoke_completed_offset_commit_callbacks() + self.ensure_coordinator_ready() + + if self.config['api_version'] >= (0, 9) and self._subscription.partitions_auto_assigned(): + if self.need_rejoin(): + # due to a race condition between the initial metadata fetch and the + # initial rebalance, we need to ensure that the metadata is fresh + # before joining initially, and then request the metadata update. If + # metadata update arrives while the rebalance is still pending (for + # example, when the join group is still inflight), then we will lose + # track of the fact that we need to rebalance again to reflect the + # change to the topic subscription. Without ensuring that the + # metadata is fresh, any metadata update that changes the topic + # subscriptions and arrives while a rebalance is in progress will + # essentially be ignored. See KAFKA-3949 for the complete + # description of the problem. 
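+                # As a concrete (hypothetical) illustration: with a pattern
+                # subscription such as 'events-.*', if topic 'events-2' is
+                # created while the JoinGroup for ('events-1',) is still in
+                # flight, refreshing metadata here ensures the new topic is
+                # noticed and a follow-up rebalance is triggered instead of
+                # the update being silently ignored.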
+ if self._subscription.subscribed_pattern: + metadata_update = self._client.cluster.request_update() + self._client.poll(future=metadata_update) + + self.ensure_active_group() + + self.poll_heartbeat() + + self._maybe_auto_commit_offsets_async() + + def time_to_next_poll(self): + """Return seconds (float) remaining until :meth:`.poll` should be called again""" + if not self.config['enable_auto_commit']: + return self.time_to_next_heartbeat() + + if time.time() > self.next_auto_commit_deadline: + return 0 + + return min(self.next_auto_commit_deadline - time.time(), + self.time_to_next_heartbeat()) + + def _perform_assignment(self, leader_id, assignment_strategy, members): + assignor = self._lookup_assignor(assignment_strategy) + assert assignor, 'Invalid assignment protocol: %s' % (assignment_strategy,) + member_metadata = {} + all_subscribed_topics = set() + for member_id, metadata_bytes in members: + metadata = ConsumerProtocol.METADATA.decode(metadata_bytes) + member_metadata[member_id] = metadata + all_subscribed_topics.update(metadata.subscription) # pylint: disable-msg=no-member + + # the leader will begin watching for changes to any of the topics + # the group is interested in, which ensures that all metadata changes + # will eventually be seen + # Because assignment typically happens within response callbacks, + # we cannot block on metadata updates here (no recursion into poll()) + self._subscription.group_subscribe(all_subscribed_topics) + self._client.set_topics(self._subscription.group_subscription()) + + # keep track of the metadata used for assignment so that we can check + # after rebalance completion whether anything has changed + self._cluster.request_update() + self._is_leader = True + self._assignment_snapshot = self._metadata_snapshot + + log.debug("Performing assignment for group %s using strategy %s" + " with subscriptions %s", self.group_id, assignor.name, + member_metadata) + + assignments = assignor.assign(self._cluster, member_metadata) + + log.debug("Finished assignment for group %s: %s", self.group_id, assignments) + + group_assignment = {} + for member_id, assignment in six.iteritems(assignments): + group_assignment[member_id] = assignment + return group_assignment + + def _on_join_prepare(self, generation, member_id): + # commit offsets prior to rebalance if auto-commit enabled + self._maybe_auto_commit_offsets_sync() + + # execute the user's callback before rebalance + log.info("Revoking previously assigned partitions %s for group %s", + self._subscription.assigned_partitions(), self.group_id) + if self._subscription.listener: + try: + revoked = set(self._subscription.assigned_partitions()) + self._subscription.listener.on_partitions_revoked(revoked) + except Exception: + log.exception("User provided subscription listener %s" + " for group %s failed on_partitions_revoked", + self._subscription.listener, self.group_id) + + self._is_leader = False + self._subscription.reset_group_subscription() + + def need_rejoin(self): + """Check whether the group should be rejoined + + Returns: + bool: True if consumer should rejoin group, False otherwise + """ + if not self._subscription.partitions_auto_assigned(): + return False + + if self._auto_assign_all_partitions(): + return False + + # we need to rejoin if we performed the assignment and metadata has changed + if (self._assignment_snapshot is not None + and self._assignment_snapshot != self._metadata_snapshot): + return True + + # we need to join if our subscription has changed since the last join + if 
(self._joined_subscription is not None + and self._joined_subscription != self._subscription.subscription): + return True + + return super(ConsumerCoordinator, self).need_rejoin() + + def refresh_committed_offsets_if_needed(self): + """Fetch committed offsets for assigned partitions.""" + if self._subscription.needs_fetch_committed_offsets: + offsets = self.fetch_committed_offsets(self._subscription.assigned_partitions()) + for partition, offset in six.iteritems(offsets): + # verify assignment is still active + if self._subscription.is_assigned(partition): + self._subscription.assignment[partition].committed = offset + self._subscription.needs_fetch_committed_offsets = False + + def fetch_committed_offsets(self, partitions): + """Fetch the current committed offsets for specified partitions + + Arguments: + partitions (list of TopicPartition): partitions to fetch + + Returns: + dict: {TopicPartition: OffsetAndMetadata} + """ + if not partitions: + return {} + + while True: + self.ensure_coordinator_ready() + + # contact coordinator to fetch committed offsets + future = self._send_offset_fetch_request(partitions) + self._client.poll(future=future) + + if future.succeeded(): + return future.value + + if not future.retriable(): + raise future.exception # pylint: disable-msg=raising-bad-type + + time.sleep(self.config['retry_backoff_ms'] / 1000) + + def close(self, autocommit=True): + """Close the coordinator, leave the current group, + and reset local generation / member_id. + + Keyword Arguments: + autocommit (bool): If auto-commit is configured for this consumer, + this optional flag causes the consumer to attempt to commit any + pending consumed offsets prior to close. Default: True + """ + try: + if autocommit: + self._maybe_auto_commit_offsets_sync() + finally: + super(ConsumerCoordinator, self).close() + + def _invoke_completed_offset_commit_callbacks(self): + while self.completed_offset_commits: + callback, offsets, exception = self.completed_offset_commits.popleft() + callback(offsets, exception) + + def commit_offsets_async(self, offsets, callback=None): + """Commit specific offsets asynchronously. + + Arguments: + offsets (dict {TopicPartition: OffsetAndMetadata}): what to commit + callback (callable, optional): called as callback(offsets, response) + response will be either an Exception or a OffsetCommitResponse + struct. This callback can be used to trigger custom actions when + a commit request completes. + + Returns: + kafka.future.Future + """ + self._invoke_completed_offset_commit_callbacks() + if not self.coordinator_unknown(): + future = self._do_commit_offsets_async(offsets, callback) + else: + # we don't know the current coordinator, so try to find it and then + # send the commit or fail (we don't want recursive retries which can + # cause offset commits to arrive out of order). Note that there may + # be multiple offset commits chained to the same coordinator lookup + # request. This is fine because the listeners will be invoked in the + # same order that they were added. Note also that BaseCoordinator + # prevents multiple concurrent coordinator lookup requests. + future = self.lookup_coordinator() + future.add_callback(lambda r: functools.partial(self._do_commit_offsets_async, offsets, callback)()) + if callback: + future.add_errback(lambda e: self.completed_offset_commits.appendleft((callback, offsets, e))) + + # ensure the commit has a chance to be transmitted (without blocking on + # its completion). 
Note that commits are treated as heartbeats by the + # coordinator, so there is no need to explicitly allow heartbeats + # through delayed task execution. + self._client.poll(timeout_ms=0) # no wakeup if we add that feature + + return future + + def _do_commit_offsets_async(self, offsets, callback=None): + assert self.config['api_version'] >= (0, 8, 1), 'Unsupported Broker API' + assert all(map(lambda k: isinstance(k, TopicPartition), offsets)) + assert all(map(lambda v: isinstance(v, OffsetAndMetadata), + offsets.values())) + if callback is None: + callback = self.config['default_offset_commit_callback'] + self._subscription.needs_fetch_committed_offsets = True + future = self._send_offset_commit_request(offsets) + future.add_both(lambda res: self.completed_offset_commits.appendleft((callback, offsets, res))) + return future + + def commit_offsets_sync(self, offsets): + """Commit specific offsets synchronously. + + This method will retry until the commit completes successfully or an + unrecoverable error is encountered. + + Arguments: + offsets (dict {TopicPartition: OffsetAndMetadata}): what to commit + + Raises error on failure + """ + assert self.config['api_version'] >= (0, 8, 1), 'Unsupported Broker API' + assert all(map(lambda k: isinstance(k, TopicPartition), offsets)) + assert all(map(lambda v: isinstance(v, OffsetAndMetadata), + offsets.values())) + self._invoke_completed_offset_commit_callbacks() + if not offsets: + return + + while True: + self.ensure_coordinator_ready() + + future = self._send_offset_commit_request(offsets) + self._client.poll(future=future) + + if future.succeeded(): + return future.value + + if not future.retriable(): + raise future.exception # pylint: disable-msg=raising-bad-type + + time.sleep(self.config['retry_backoff_ms'] / 1000) + + def _maybe_auto_commit_offsets_sync(self): + if self.config['enable_auto_commit']: + try: + self.commit_offsets_sync(self._subscription.all_consumed_offsets()) + + # The three main group membership errors are known and should not + # require a stacktrace -- just a warning + except (Errors.UnknownMemberIdError, + Errors.IllegalGenerationError, + Errors.RebalanceInProgressError): + log.warning("Offset commit failed: group membership out of date" + " This is likely to cause duplicate message" + " delivery.") + except Exception: + log.exception("Offset commit failed: This is likely to cause" + " duplicate message delivery") + + def _send_offset_commit_request(self, offsets): + """Commit offsets for the specified list of topics and partitions. + + This is a non-blocking call which returns a request future that can be + polled in the case of a synchronous commit or ignored in the + asynchronous case. 
+ + Arguments: + offsets (dict of {TopicPartition: OffsetAndMetadata}): what should + be committed + + Returns: + Future: indicating whether the commit was successful or not + """ + assert self.config['api_version'] >= (0, 8, 1), 'Unsupported Broker API' + assert all(map(lambda k: isinstance(k, TopicPartition), offsets)) + assert all(map(lambda v: isinstance(v, OffsetAndMetadata), + offsets.values())) + if not offsets: + log.debug('No offsets to commit') + return Future().success(None) + + node_id = self.coordinator() + if node_id is None: + return Future().failure(Errors.GroupCoordinatorNotAvailableError) + + + # create the offset commit request + offset_data = collections.defaultdict(dict) + for tp, offset in six.iteritems(offsets): + offset_data[tp.topic][tp.partition] = offset + + if self._subscription.partitions_auto_assigned(): + generation = self.generation() + else: + generation = Generation.NO_GENERATION + + # if the generation is None, we are not part of an active group + # (and we expect to be). The only thing we can do is fail the commit + # and let the user rejoin the group in poll() + if self.config['api_version'] >= (0, 9) and generation is None: + return Future().failure(Errors.CommitFailedError()) + + if self.config['api_version'] >= (0, 9): + request = OffsetCommitRequest[2]( + self.group_id, + generation.generation_id, + generation.member_id, + OffsetCommitRequest[2].DEFAULT_RETENTION_TIME, + [( + topic, [( + partition, + offset.offset, + offset.metadata + ) for partition, offset in six.iteritems(partitions)] + ) for topic, partitions in six.iteritems(offset_data)] + ) + elif self.config['api_version'] >= (0, 8, 2): + request = OffsetCommitRequest[1]( + self.group_id, -1, '', + [( + topic, [( + partition, + offset.offset, + -1, + offset.metadata + ) for partition, offset in six.iteritems(partitions)] + ) for topic, partitions in six.iteritems(offset_data)] + ) + elif self.config['api_version'] >= (0, 8, 1): + request = OffsetCommitRequest[0]( + self.group_id, + [( + topic, [( + partition, + offset.offset, + offset.metadata + ) for partition, offset in six.iteritems(partitions)] + ) for topic, partitions in six.iteritems(offset_data)] + ) + + log.debug("Sending offset-commit request with %s for group %s to %s", + offsets, self.group_id, node_id) + + future = Future() + _f = self._client.send(node_id, request) + _f.add_callback(self._handle_offset_commit_response, offsets, future, time.time()) + _f.add_errback(self._failed_request, node_id, request, future) + return future + + def _handle_offset_commit_response(self, offsets, future, send_time, response): + # TODO look at adding request_latency_ms to response (like java kafka) + self.consumer_sensors.commit_latency.record((time.time() - send_time) * 1000) + unauthorized_topics = set() + + for topic, partitions in response.topics: + for partition, error_code in partitions: + tp = TopicPartition(topic, partition) + offset = offsets[tp] + + error_type = Errors.for_code(error_code) + if error_type is Errors.NoError: + log.debug("Group %s committed offset %s for partition %s", + self.group_id, offset, tp) + if self._subscription.is_assigned(tp): + self._subscription.assignment[tp].committed = offset + elif error_type is Errors.GroupAuthorizationFailedError: + log.error("Not authorized to commit offsets for group %s", + self.group_id) + future.failure(error_type(self.group_id)) + return + elif error_type is Errors.TopicAuthorizationFailedError: + unauthorized_topics.add(topic) + elif error_type in 
(Errors.OffsetMetadataTooLargeError, + Errors.InvalidCommitOffsetSizeError): + # raise the error to the user + log.debug("OffsetCommit for group %s failed on partition %s" + " %s", self.group_id, tp, error_type.__name__) + future.failure(error_type()) + return + elif error_type is Errors.GroupLoadInProgressError: + # just retry + log.debug("OffsetCommit for group %s failed: %s", + self.group_id, error_type.__name__) + future.failure(error_type(self.group_id)) + return + elif error_type in (Errors.GroupCoordinatorNotAvailableError, + Errors.NotCoordinatorForGroupError, + Errors.RequestTimedOutError): + log.debug("OffsetCommit for group %s failed: %s", + self.group_id, error_type.__name__) + self.coordinator_dead(error_type()) + future.failure(error_type(self.group_id)) + return + elif error_type in (Errors.UnknownMemberIdError, + Errors.IllegalGenerationError, + Errors.RebalanceInProgressError): + # need to re-join group + error = error_type(self.group_id) + log.debug("OffsetCommit for group %s failed: %s", + self.group_id, error) + self.reset_generation() + future.failure(Errors.CommitFailedError()) + return + else: + log.error("Group %s failed to commit partition %s at offset" + " %s: %s", self.group_id, tp, offset, + error_type.__name__) + future.failure(error_type()) + return + + if unauthorized_topics: + log.error("Not authorized to commit to topics %s for group %s", + unauthorized_topics, self.group_id) + future.failure(Errors.TopicAuthorizationFailedError(unauthorized_topics)) + else: + future.success(None) + + def _send_offset_fetch_request(self, partitions): + """Fetch the committed offsets for a set of partitions. + + This is a non-blocking call. The returned future can be polled to get + the actual offsets returned from the broker. + + Arguments: + partitions (list of TopicPartition): the partitions to fetch + + Returns: + Future: resolves to dict of offsets: {TopicPartition: OffsetAndMetadata} + """ + assert self.config['api_version'] >= (0, 8, 1), 'Unsupported Broker API' + assert all(map(lambda k: isinstance(k, TopicPartition), partitions)) + if not partitions: + return Future().success({}) + + node_id = self.coordinator() + if node_id is None: + return Future().failure(Errors.GroupCoordinatorNotAvailableError) + + # Verify node is ready + if not self._client.ready(node_id): + log.debug("Node %s not ready -- failing offset fetch request", + node_id) + return Future().failure(Errors.NodeNotReadyError) + + log.debug("Group %s fetching committed offsets for partitions: %s", + self.group_id, partitions) + # construct the request + topic_partitions = collections.defaultdict(set) + for tp in partitions: + topic_partitions[tp.topic].add(tp.partition) + + if self.config['api_version'] >= (0, 8, 2): + request = OffsetFetchRequest[1]( + self.group_id, + list(topic_partitions.items()) + ) + else: + request = OffsetFetchRequest[0]( + self.group_id, + list(topic_partitions.items()) + ) + + # send the request with a callback + future = Future() + _f = self._client.send(node_id, request) + _f.add_callback(self._handle_offset_fetch_response, future) + _f.add_errback(self._failed_request, node_id, request, future) + return future + + def _handle_offset_fetch_response(self, future, response): + offsets = {} + for topic, partitions in response.topics: + for partition, offset, metadata, error_code in partitions: + tp = TopicPartition(topic, partition) + error_type = Errors.for_code(error_code) + if error_type is not Errors.NoError: + error = error_type() + log.debug("Group %s failed to fetch 
offset for partition" + " %s: %s", self.group_id, tp, error) + if error_type is Errors.GroupLoadInProgressError: + # just retry + future.failure(error) + elif error_type is Errors.NotCoordinatorForGroupError: + # re-discover the coordinator and retry + self.coordinator_dead(error_type()) + future.failure(error) + elif error_type is Errors.UnknownTopicOrPartitionError: + log.warning("OffsetFetchRequest -- unknown topic %s" + " (have you committed any offsets yet?)", + topic) + continue + else: + log.error("Unknown error fetching offsets for %s: %s", + tp, error) + future.failure(error) + return + elif offset >= 0: + # record the position with the offset + # (-1 indicates no committed offset to fetch) + offsets[tp] = OffsetAndMetadata(offset, metadata) + else: + log.debug("Group %s has no committed offset for partition" + " %s", self.group_id, tp) + future.success(offsets) + + def _default_offset_commit_callback(self, offsets, exception): + if exception is not None: + log.error("Offset commit failed: %s", exception) + + def _commit_offsets_async_on_complete(self, offsets, exception): + if exception is not None: + log.warning("Auto offset commit failed for group %s: %s", + self.group_id, exception) + if getattr(exception, 'retriable', False): + self.next_auto_commit_deadline = min(time.time() + self.config['retry_backoff_ms'] / 1000, self.next_auto_commit_deadline) + else: + log.debug("Completed autocommit of offsets %s for group %s", + offsets, self.group_id) + + def _maybe_auto_commit_offsets_async(self): + if self.config['enable_auto_commit']: + if self.coordinator_unknown(): + self.next_auto_commit_deadline = time.time() + self.config['retry_backoff_ms'] / 1000 + elif time.time() > self.next_auto_commit_deadline: + self.next_auto_commit_deadline = time.time() + self.auto_commit_interval + self.commit_offsets_async(self._subscription.all_consumed_offsets(), + self._commit_offsets_async_on_complete) + + +class ConsumerCoordinatorMetrics(object): + def __init__(self, metrics, metric_group_prefix, subscription): + self.metrics = metrics + self.metric_group_name = '%s-coordinator-metrics' % (metric_group_prefix,) + + self.commit_latency = metrics.sensor('commit-latency') + self.commit_latency.add(metrics.metric_name( + 'commit-latency-avg', self.metric_group_name, + 'The average time taken for a commit request'), Avg()) + self.commit_latency.add(metrics.metric_name( + 'commit-latency-max', self.metric_group_name, + 'The max time taken for a commit request'), Max()) + self.commit_latency.add(metrics.metric_name( + 'commit-rate', self.metric_group_name, + 'The number of commit calls per second'), Rate(sampled_stat=Count())) + + num_parts = AnonMeasurable(lambda config, now: + len(subscription.assigned_partitions())) + metrics.add_metric(metrics.metric_name( + 'assigned-partitions', self.metric_group_name, + 'The number of partitions currently assigned to this consumer'), + num_parts) diff --git a/coordinator/heartbeat.py b/coordinator/heartbeat.py new file mode 100644 index 00000000..2f5930b6 --- /dev/null +++ b/coordinator/heartbeat.py @@ -0,0 +1,68 @@ +from __future__ import absolute_import, division + +import copy +import time + + +class Heartbeat(object): + DEFAULT_CONFIG = { + 'group_id': None, + 'heartbeat_interval_ms': 3000, + 'session_timeout_ms': 10000, + 'max_poll_interval_ms': 300000, + 'retry_backoff_ms': 100, + } + + def __init__(self, **configs): + self.config = copy.copy(self.DEFAULT_CONFIG) + for key in self.config: + if key in configs: + self.config[key] = configs[key] + + if 
self.config['group_id'] is not None: + assert (self.config['heartbeat_interval_ms'] + <= self.config['session_timeout_ms']), ( + 'Heartbeat interval must be lower than the session timeout') + + self.last_send = -1 * float('inf') + self.last_receive = -1 * float('inf') + self.last_poll = -1 * float('inf') + self.last_reset = time.time() + self.heartbeat_failed = None + + def poll(self): + self.last_poll = time.time() + + def sent_heartbeat(self): + self.last_send = time.time() + self.heartbeat_failed = False + + def fail_heartbeat(self): + self.heartbeat_failed = True + + def received_heartbeat(self): + self.last_receive = time.time() + + def time_to_next_heartbeat(self): + """Returns seconds (float) remaining before next heartbeat should be sent""" + time_since_last_heartbeat = time.time() - max(self.last_send, self.last_reset) + if self.heartbeat_failed: + delay_to_next_heartbeat = self.config['retry_backoff_ms'] / 1000 + else: + delay_to_next_heartbeat = self.config['heartbeat_interval_ms'] / 1000 + return max(0, delay_to_next_heartbeat - time_since_last_heartbeat) + + def should_heartbeat(self): + return self.time_to_next_heartbeat() == 0 + + def session_timeout_expired(self): + last_recv = max(self.last_receive, self.last_reset) + return (time.time() - last_recv) > (self.config['session_timeout_ms'] / 1000) + + def reset_timeouts(self): + self.last_reset = time.time() + self.last_poll = time.time() + self.heartbeat_failed = False + + def poll_timeout_expired(self): + return (time.time() - self.last_poll) > (self.config['max_poll_interval_ms'] / 1000) diff --git a/coordinator/protocol.py b/coordinator/protocol.py new file mode 100644 index 00000000..56a39015 --- /dev/null +++ b/coordinator/protocol.py @@ -0,0 +1,33 @@ +from __future__ import absolute_import + +from kafka.protocol.struct import Struct +from kafka.protocol.types import Array, Bytes, Int16, Int32, Schema, String +from kafka.structs import TopicPartition + + +class ConsumerProtocolMemberMetadata(Struct): + SCHEMA = Schema( + ('version', Int16), + ('subscription', Array(String('utf-8'))), + ('user_data', Bytes)) + + +class ConsumerProtocolMemberAssignment(Struct): + SCHEMA = Schema( + ('version', Int16), + ('assignment', Array( + ('topic', String('utf-8')), + ('partitions', Array(Int32)))), + ('user_data', Bytes)) + + def partitions(self): + return [TopicPartition(topic, partition) + for topic, partitions in self.assignment # pylint: disable-msg=no-member + for partition in partitions] + + +class ConsumerProtocol(object): + PROTOCOL_TYPE = 'consumer' + ASSIGNMENT_STRATEGIES = ('range', 'roundrobin') + METADATA = ConsumerProtocolMemberMetadata + ASSIGNMENT = ConsumerProtocolMemberAssignment diff --git a/errors.py b/errors.py new file mode 100644 index 00000000..b33cf51e --- /dev/null +++ b/errors.py @@ -0,0 +1,538 @@ +from __future__ import absolute_import + +import inspect +import sys + + +class KafkaError(RuntimeError): + retriable = False + # whether metadata should be refreshed on error + invalid_metadata = False + + def __str__(self): + if not self.args: + return self.__class__.__name__ + return '{0}: {1}'.format(self.__class__.__name__, + super(KafkaError, self).__str__()) + + +class IllegalStateError(KafkaError): + pass + + +class IllegalArgumentError(KafkaError): + pass + + +class NoBrokersAvailable(KafkaError): + retriable = True + invalid_metadata = True + + +class NodeNotReadyError(KafkaError): + retriable = True + + +class KafkaProtocolError(KafkaError): + retriable = True + + +class 
CorrelationIdError(KafkaProtocolError): + retriable = True + + +class Cancelled(KafkaError): + retriable = True + + +class TooManyInFlightRequests(KafkaError): + retriable = True + + +class StaleMetadata(KafkaError): + retriable = True + invalid_metadata = True + + +class MetadataEmptyBrokerList(KafkaError): + retriable = True + + +class UnrecognizedBrokerVersion(KafkaError): + pass + + +class IncompatibleBrokerVersion(KafkaError): + pass + + +class CommitFailedError(KafkaError): + def __init__(self, *args, **kwargs): + super(CommitFailedError, self).__init__( + """Commit cannot be completed since the group has already + rebalanced and assigned the partitions to another member. + This means that the time between subsequent calls to poll() + was longer than the configured max_poll_interval_ms, which + typically implies that the poll loop is spending too much + time message processing. You can address this either by + increasing the rebalance timeout with max_poll_interval_ms, + or by reducing the maximum size of batches returned in poll() + with max_poll_records. + """, *args, **kwargs) + + +class AuthenticationMethodNotSupported(KafkaError): + pass + + +class AuthenticationFailedError(KafkaError): + retriable = False + + +class BrokerResponseError(KafkaError): + errno = None + message = None + description = None + + def __str__(self): + """Add errno to standard KafkaError str""" + return '[Error {0}] {1}'.format( + self.errno, + super(BrokerResponseError, self).__str__()) + + +class NoError(BrokerResponseError): + errno = 0 + message = 'NO_ERROR' + description = 'No error--it worked!' + + +class UnknownError(BrokerResponseError): + errno = -1 + message = 'UNKNOWN' + description = 'An unexpected server error.' + + +class OffsetOutOfRangeError(BrokerResponseError): + errno = 1 + message = 'OFFSET_OUT_OF_RANGE' + description = ('The requested offset is outside the range of offsets' + ' maintained by the server for the given topic/partition.') + + +class CorruptRecordException(BrokerResponseError): + errno = 2 + message = 'CORRUPT_MESSAGE' + description = ('This message has failed its CRC checksum, exceeds the' + ' valid size, or is otherwise corrupt.') + +# Backward compatibility +InvalidMessageError = CorruptRecordException + + +class UnknownTopicOrPartitionError(BrokerResponseError): + errno = 3 + message = 'UNKNOWN_TOPIC_OR_PARTITION' + description = ('This request is for a topic or partition that does not' + ' exist on this broker.') + retriable = True + invalid_metadata = True + + +class InvalidFetchRequestError(BrokerResponseError): + errno = 4 + message = 'INVALID_FETCH_SIZE' + description = 'The message has a negative size.' + + +class LeaderNotAvailableError(BrokerResponseError): + errno = 5 + message = 'LEADER_NOT_AVAILABLE' + description = ('This error is thrown if we are in the middle of a' + ' leadership election and there is currently no leader for' + ' this partition and hence it is unavailable for writes.') + retriable = True + invalid_metadata = True + + +class NotLeaderForPartitionError(BrokerResponseError): + errno = 6 + message = 'NOT_LEADER_FOR_PARTITION' + description = ('This error is thrown if the client attempts to send' + ' messages to a replica that is not the leader for some' + ' partition. 
It indicates that the clients metadata is out' + ' of date.') + retriable = True + invalid_metadata = True + + +class RequestTimedOutError(BrokerResponseError): + errno = 7 + message = 'REQUEST_TIMED_OUT' + description = ('This error is thrown if the request exceeds the' + ' user-specified time limit in the request.') + retriable = True + + +class BrokerNotAvailableError(BrokerResponseError): + errno = 8 + message = 'BROKER_NOT_AVAILABLE' + description = ('This is not a client facing error and is used mostly by' + ' tools when a broker is not alive.') + + +class ReplicaNotAvailableError(BrokerResponseError): + errno = 9 + message = 'REPLICA_NOT_AVAILABLE' + description = ('If replica is expected on a broker, but is not (this can be' + ' safely ignored).') + + +class MessageSizeTooLargeError(BrokerResponseError): + errno = 10 + message = 'MESSAGE_SIZE_TOO_LARGE' + description = ('The server has a configurable maximum message size to avoid' + ' unbounded memory allocation. This error is thrown if the' + ' client attempt to produce a message larger than this' + ' maximum.') + + +class StaleControllerEpochError(BrokerResponseError): + errno = 11 + message = 'STALE_CONTROLLER_EPOCH' + description = 'Internal error code for broker-to-broker communication.' + + +class OffsetMetadataTooLargeError(BrokerResponseError): + errno = 12 + message = 'OFFSET_METADATA_TOO_LARGE' + description = ('If you specify a string larger than configured maximum for' + ' offset metadata.') + + +# TODO is this deprecated? https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-ErrorCodes +class StaleLeaderEpochCodeError(BrokerResponseError): + errno = 13 + message = 'STALE_LEADER_EPOCH_CODE' + + +class GroupLoadInProgressError(BrokerResponseError): + errno = 14 + message = 'OFFSETS_LOAD_IN_PROGRESS' + description = ('The broker returns this error code for an offset fetch' + ' request if it is still loading offsets (after a leader' + ' change for that offsets topic partition), or in response' + ' to group membership requests (such as heartbeats) when' + ' group metadata is being loaded by the coordinator.') + retriable = True + + +class GroupCoordinatorNotAvailableError(BrokerResponseError): + errno = 15 + message = 'CONSUMER_COORDINATOR_NOT_AVAILABLE' + description = ('The broker returns this error code for group coordinator' + ' requests, offset commits, and most group management' + ' requests if the offsets topic has not yet been created, or' + ' if the group coordinator is not active.') + retriable = True + + +class NotCoordinatorForGroupError(BrokerResponseError): + errno = 16 + message = 'NOT_COORDINATOR_FOR_CONSUMER' + description = ('The broker returns this error code if it receives an offset' + ' fetch or commit request for a group that it is not a' + ' coordinator for.') + retriable = True + + +class InvalidTopicError(BrokerResponseError): + errno = 17 + message = 'INVALID_TOPIC' + description = ('For a request which attempts to access an invalid topic' + ' (e.g. 
one which has an illegal name), or if an attempt' + ' is made to write to an internal topic (such as the' + ' consumer offsets topic).') + + +class RecordListTooLargeError(BrokerResponseError): + errno = 18 + message = 'RECORD_LIST_TOO_LARGE' + description = ('If a message batch in a produce request exceeds the maximum' + ' configured segment size.') + + +class NotEnoughReplicasError(BrokerResponseError): + errno = 19 + message = 'NOT_ENOUGH_REPLICAS' + description = ('Returned from a produce request when the number of in-sync' + ' replicas is lower than the configured minimum and' + ' requiredAcks is -1.') + retriable = True + + +class NotEnoughReplicasAfterAppendError(BrokerResponseError): + errno = 20 + message = 'NOT_ENOUGH_REPLICAS_AFTER_APPEND' + description = ('Returned from a produce request when the message was' + ' written to the log, but with fewer in-sync replicas than' + ' required.') + retriable = True + + +class InvalidRequiredAcksError(BrokerResponseError): + errno = 21 + message = 'INVALID_REQUIRED_ACKS' + description = ('Returned from a produce request if the requested' + ' requiredAcks is invalid (anything other than -1, 1, or 0).') + + +class IllegalGenerationError(BrokerResponseError): + errno = 22 + message = 'ILLEGAL_GENERATION' + description = ('Returned from group membership requests (such as heartbeats)' + ' when the generation id provided in the request is not the' + ' current generation.') + + +class InconsistentGroupProtocolError(BrokerResponseError): + errno = 23 + message = 'INCONSISTENT_GROUP_PROTOCOL' + description = ('Returned in join group when the member provides a protocol' + ' type or set of protocols which is not compatible with the' + ' current group.') + + +class InvalidGroupIdError(BrokerResponseError): + errno = 24 + message = 'INVALID_GROUP_ID' + description = 'Returned in join group when the groupId is empty or null.' + + +class UnknownMemberIdError(BrokerResponseError): + errno = 25 + message = 'UNKNOWN_MEMBER_ID' + description = ('Returned from group requests (offset commits/fetches,' + ' heartbeats, etc) when the memberId is not in the current' + ' generation.') + + +class InvalidSessionTimeoutError(BrokerResponseError): + errno = 26 + message = 'INVALID_SESSION_TIMEOUT' + description = ('Return in join group when the requested session timeout is' + ' outside of the allowed range on the broker') + + +class RebalanceInProgressError(BrokerResponseError): + errno = 27 + message = 'REBALANCE_IN_PROGRESS' + description = ('Returned in heartbeat requests when the coordinator has' + ' begun rebalancing the group. 
This indicates to the client' + ' that it should rejoin the group.') + + +class InvalidCommitOffsetSizeError(BrokerResponseError): + errno = 28 + message = 'INVALID_COMMIT_OFFSET_SIZE' + description = ('This error indicates that an offset commit was rejected' + ' because of oversize metadata.') + + +class TopicAuthorizationFailedError(BrokerResponseError): + errno = 29 + message = 'TOPIC_AUTHORIZATION_FAILED' + description = ('Returned by the broker when the client is not authorized to' + ' access the requested topic.') + + +class GroupAuthorizationFailedError(BrokerResponseError): + errno = 30 + message = 'GROUP_AUTHORIZATION_FAILED' + description = ('Returned by the broker when the client is not authorized to' + ' access a particular groupId.') + + +class ClusterAuthorizationFailedError(BrokerResponseError): + errno = 31 + message = 'CLUSTER_AUTHORIZATION_FAILED' + description = ('Returned by the broker when the client is not authorized to' + ' use an inter-broker or administrative API.') + + +class InvalidTimestampError(BrokerResponseError): + errno = 32 + message = 'INVALID_TIMESTAMP' + description = 'The timestamp of the message is out of acceptable range.' + + +class UnsupportedSaslMechanismError(BrokerResponseError): + errno = 33 + message = 'UNSUPPORTED_SASL_MECHANISM' + description = 'The broker does not support the requested SASL mechanism.' + + +class IllegalSaslStateError(BrokerResponseError): + errno = 34 + message = 'ILLEGAL_SASL_STATE' + description = 'Request is not valid given the current SASL state.' + + +class UnsupportedVersionError(BrokerResponseError): + errno = 35 + message = 'UNSUPPORTED_VERSION' + description = 'The version of API is not supported.' + + +class TopicAlreadyExistsError(BrokerResponseError): + errno = 36 + message = 'TOPIC_ALREADY_EXISTS' + description = 'Topic with this name already exists.' + + +class InvalidPartitionsError(BrokerResponseError): + errno = 37 + message = 'INVALID_PARTITIONS' + description = 'Number of partitions is invalid.' + + +class InvalidReplicationFactorError(BrokerResponseError): + errno = 38 + message = 'INVALID_REPLICATION_FACTOR' + description = 'Replication-factor is invalid.' + + +class InvalidReplicationAssignmentError(BrokerResponseError): + errno = 39 + message = 'INVALID_REPLICATION_ASSIGNMENT' + description = 'Replication assignment is invalid.' + + +class InvalidConfigurationError(BrokerResponseError): + errno = 40 + message = 'INVALID_CONFIG' + description = 'Configuration is invalid.' + + +class NotControllerError(BrokerResponseError): + errno = 41 + message = 'NOT_CONTROLLER' + description = 'This is not the correct controller for this cluster.' + retriable = True + + +class InvalidRequestError(BrokerResponseError): + errno = 42 + message = 'INVALID_REQUEST' + description = ('This most likely occurs because of a request being' + ' malformed by the client library or the message was' + ' sent to an incompatible broker. See the broker logs' + ' for more details.') + + +class UnsupportedForMessageFormatError(BrokerResponseError): + errno = 43 + message = 'UNSUPPORTED_FOR_MESSAGE_FORMAT' + description = ('The message format version on the broker does not' + ' support this request.') + + +class PolicyViolationError(BrokerResponseError): + errno = 44 + message = 'POLICY_VIOLATION' + description = 'Request parameters do not satisfy the configured policy.' + + +class SecurityDisabledError(BrokerResponseError): + errno = 54 + message = 'SECURITY_DISABLED' + description = 'Security features are disabled.' 
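Illustrative sketch only, not part of the committed diff: the coordinator handlers earlier in this patch call Errors.for_code(error_code) and then branch on the returned class and its `retriable` flag. The lookup behaves as below; the import path is an assumption that the subtree is importable as the `kafka` package, matching the in-tree `from kafka...` imports.

    # Hypothetical usage sketch; assumes the vendored tree is on sys.path
    # as the `kafka` package.
    from kafka import errors as Errors

    error_type = Errors.for_code(16)     # NOT_COORDINATOR_FOR_CONSUMER
    print(error_type.__name__)           # NotCoordinatorForGroupError
    print(error_type.retriable)          # True: callers retry after
                                         # re-discovering the coordinator
    print(error_type('my-group'))        # [Error 16] NotCoordinatorForGroupError: my-group

    # Unrecognized codes fall back to UnknownError rather than raising.
    assert Errors.for_code(-9999) is Errors.UnknownError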
+ + +class NonEmptyGroupError(BrokerResponseError): + errno = 68 + message = 'NON_EMPTY_GROUP' + description = 'The group is not empty.' + + +class GroupIdNotFoundError(BrokerResponseError): + errno = 69 + message = 'GROUP_ID_NOT_FOUND' + description = 'The group id does not exist.' + + +class KafkaUnavailableError(KafkaError): + pass + + +class KafkaTimeoutError(KafkaError): + pass + + +class FailedPayloadsError(KafkaError): + def __init__(self, payload, *args): + super(FailedPayloadsError, self).__init__(*args) + self.payload = payload + + +class KafkaConnectionError(KafkaError): + retriable = True + invalid_metadata = True + + +class ProtocolError(KafkaError): + pass + + +class UnsupportedCodecError(KafkaError): + pass + + +class KafkaConfigurationError(KafkaError): + pass + + +class QuotaViolationError(KafkaError): + pass + + +class AsyncProducerQueueFull(KafkaError): + def __init__(self, failed_msgs, *args): + super(AsyncProducerQueueFull, self).__init__(*args) + self.failed_msgs = failed_msgs + + +def _iter_broker_errors(): + for name, obj in inspect.getmembers(sys.modules[__name__]): + if inspect.isclass(obj) and issubclass(obj, BrokerResponseError) and obj != BrokerResponseError: + yield obj + + +kafka_errors = dict([(x.errno, x) for x in _iter_broker_errors()]) + + +def for_code(error_code): + return kafka_errors.get(error_code, UnknownError) + + +def check_error(response): + if isinstance(response, Exception): + raise response + if response.error: + error_class = kafka_errors.get(response.error, UnknownError) + raise error_class(response) + + +RETRY_BACKOFF_ERROR_TYPES = ( + KafkaUnavailableError, LeaderNotAvailableError, + KafkaConnectionError, FailedPayloadsError +) + + +RETRY_REFRESH_ERROR_TYPES = ( + NotLeaderForPartitionError, UnknownTopicOrPartitionError, + LeaderNotAvailableError, KafkaConnectionError +) + + +RETRY_ERROR_TYPES = RETRY_BACKOFF_ERROR_TYPES + RETRY_REFRESH_ERROR_TYPES diff --git a/future.py b/future.py new file mode 100644 index 00000000..d0f3c665 --- /dev/null +++ b/future.py @@ -0,0 +1,83 @@ +from __future__ import absolute_import + +import functools +import logging + +log = logging.getLogger(__name__) + + +class Future(object): + error_on_callbacks = False # and errbacks + + def __init__(self): + self.is_done = False + self.value = None + self.exception = None + self._callbacks = [] + self._errbacks = [] + + def succeeded(self): + return self.is_done and not bool(self.exception) + + def failed(self): + return self.is_done and bool(self.exception) + + def retriable(self): + try: + return self.exception.retriable + except AttributeError: + return False + + def success(self, value): + assert not self.is_done, 'Future is already complete' + self.value = value + self.is_done = True + if self._callbacks: + self._call_backs('callback', self._callbacks, self.value) + return self + + def failure(self, e): + assert not self.is_done, 'Future is already complete' + self.exception = e if type(e) is not type else e() + assert isinstance(self.exception, BaseException), ( + 'future failed without an exception') + self.is_done = True + self._call_backs('errback', self._errbacks, self.exception) + return self + + def add_callback(self, f, *args, **kwargs): + if args or kwargs: + f = functools.partial(f, *args, **kwargs) + if self.is_done and not self.exception: + self._call_backs('callback', [f], self.value) + else: + self._callbacks.append(f) + return self + + def add_errback(self, f, *args, **kwargs): + if args or kwargs: + f = functools.partial(f, *args, **kwargs) + if 
self.is_done and self.exception: + self._call_backs('errback', [f], self.exception) + else: + self._errbacks.append(f) + return self + + def add_both(self, f, *args, **kwargs): + self.add_callback(f, *args, **kwargs) + self.add_errback(f, *args, **kwargs) + return self + + def chain(self, future): + self.add_callback(future.success) + self.add_errback(future.failure) + return self + + def _call_backs(self, back_type, backs, value): + for f in backs: + try: + f(value) + except Exception as e: + log.exception('Error processing %s', back_type) + if self.error_on_callbacks: + raise e diff --git a/metrics/__init__.py b/metrics/__init__.py new file mode 100644 index 00000000..2a62d633 --- /dev/null +++ b/metrics/__init__.py @@ -0,0 +1,15 @@ +from __future__ import absolute_import + +from kafka.metrics.compound_stat import NamedMeasurable +from kafka.metrics.dict_reporter import DictReporter +from kafka.metrics.kafka_metric import KafkaMetric +from kafka.metrics.measurable import AnonMeasurable +from kafka.metrics.metric_config import MetricConfig +from kafka.metrics.metric_name import MetricName +from kafka.metrics.metrics import Metrics +from kafka.metrics.quota import Quota + +__all__ = [ + 'AnonMeasurable', 'DictReporter', 'KafkaMetric', 'MetricConfig', + 'MetricName', 'Metrics', 'NamedMeasurable', 'Quota' +] diff --git a/metrics/compound_stat.py b/metrics/compound_stat.py new file mode 100644 index 00000000..ac92480d --- /dev/null +++ b/metrics/compound_stat.py @@ -0,0 +1,34 @@ +from __future__ import absolute_import + +import abc + +from kafka.metrics.stat import AbstractStat + + +class AbstractCompoundStat(AbstractStat): + """ + A compound stat is a stat where a single measurement and associated + data structure feeds many metrics. This is the example for a + histogram which has many associated percentiles. + """ + __metaclass__ = abc.ABCMeta + + def stats(self): + """ + Return list of NamedMeasurable + """ + raise NotImplementedError + + +class NamedMeasurable(object): + def __init__(self, metric_name, measurable_stat): + self._name = metric_name + self._stat = measurable_stat + + @property + def name(self): + return self._name + + @property + def stat(self): + return self._stat diff --git a/metrics/dict_reporter.py b/metrics/dict_reporter.py new file mode 100644 index 00000000..0b98fe1e --- /dev/null +++ b/metrics/dict_reporter.py @@ -0,0 +1,83 @@ +from __future__ import absolute_import + +import logging +import threading + +from kafka.metrics.metrics_reporter import AbstractMetricsReporter + +logger = logging.getLogger(__name__) + + +class DictReporter(AbstractMetricsReporter): + """A basic dictionary based metrics reporter. + + Store all metrics in a two level dictionary of category > name > metric. + """ + def __init__(self, prefix=''): + self._lock = threading.Lock() + self._prefix = prefix if prefix else '' # never allow None + self._store = {} + + def snapshot(self): + """ + Return a nested dictionary snapshot of all metrics and their + values at this time. 
Example: + { + 'category': { + 'metric1_name': 42.0, + 'metric2_name': 'foo' + } + } + """ + return dict((category, dict((name, metric.value()) + for name, metric in list(metrics.items()))) + for category, metrics in + list(self._store.items())) + + def init(self, metrics): + for metric in metrics: + self.metric_change(metric) + + def metric_change(self, metric): + with self._lock: + category = self.get_category(metric) + if category not in self._store: + self._store[category] = {} + self._store[category][metric.metric_name.name] = metric + + def metric_removal(self, metric): + with self._lock: + category = self.get_category(metric) + metrics = self._store.get(category, {}) + removed = metrics.pop(metric.metric_name.name, None) + if not metrics: + self._store.pop(category, None) + return removed + + def get_category(self, metric): + """ + Return a string category for the metric. + + The category is made up of this reporter's prefix and the + metric's group and tags. + + Examples: + prefix = 'foo', group = 'bar', tags = {'a': 1, 'b': 2} + returns: 'foo.bar.a=1,b=2' + + prefix = 'foo', group = 'bar', tags = None + returns: 'foo.bar' + + prefix = None, group = 'bar', tags = None + returns: 'bar' + """ + tags = ','.join('%s=%s' % (k, v) for k, v in + sorted(metric.metric_name.tags.items())) + return '.'.join(x for x in + [self._prefix, metric.metric_name.group, tags] if x) + + def configure(self, configs): + pass + + def close(self): + pass diff --git a/metrics/kafka_metric.py b/metrics/kafka_metric.py new file mode 100644 index 00000000..9fb8d89f --- /dev/null +++ b/metrics/kafka_metric.py @@ -0,0 +1,36 @@ +from __future__ import absolute_import + +import time + + +class KafkaMetric(object): + # NOTE java constructor takes a lock instance + def __init__(self, metric_name, measurable, config): + if not metric_name: + raise ValueError('metric_name must be non-empty') + if not measurable: + raise ValueError('measurable must be non-empty') + self._metric_name = metric_name + self._measurable = measurable + self._config = config + + @property + def metric_name(self): + return self._metric_name + + @property + def measurable(self): + return self._measurable + + @property + def config(self): + return self._config + + @config.setter + def config(self, config): + self._config = config + + def value(self, time_ms=None): + if time_ms is None: + time_ms = time.time() * 1000 + return self.measurable.measure(self.config, time_ms) diff --git a/metrics/measurable.py b/metrics/measurable.py new file mode 100644 index 00000000..b06d4d78 --- /dev/null +++ b/metrics/measurable.py @@ -0,0 +1,29 @@ +from __future__ import absolute_import + +import abc + + +class AbstractMeasurable(object): + """A measurable quantity that can be registered as a metric""" + @abc.abstractmethod + def measure(self, config, now): + """ + Measure this quantity and return the result + + Arguments: + config (MetricConfig): The configuration for this metric + now (int): The POSIX time in milliseconds the measurement + is being taken + + Returns: + The measured value + """ + raise NotImplementedError + + +class AnonMeasurable(AbstractMeasurable): + def __init__(self, measure_fn): + self._measure_fn = measure_fn + + def measure(self, config, now): + return float(self._measure_fn(config, now)) diff --git a/metrics/measurable_stat.py b/metrics/measurable_stat.py new file mode 100644 index 00000000..4487adf6 --- /dev/null +++ b/metrics/measurable_stat.py @@ -0,0 +1,16 @@ +from __future__ import absolute_import + +import abc + +from 
kafka.metrics.measurable import AbstractMeasurable +from kafka.metrics.stat import AbstractStat + + +class AbstractMeasurableStat(AbstractStat, AbstractMeasurable): + """ + An AbstractMeasurableStat is an AbstractStat that is also + an AbstractMeasurable (i.e. can produce a single floating point value). + This is the interface used for most of the simple statistics such + as Avg, Max, Count, etc. + """ + __metaclass__ = abc.ABCMeta diff --git a/metrics/metric_config.py b/metrics/metric_config.py new file mode 100644 index 00000000..2e55abfc --- /dev/null +++ b/metrics/metric_config.py @@ -0,0 +1,33 @@ +from __future__ import absolute_import + +import sys + + +class MetricConfig(object): + """Configuration values for metrics""" + def __init__(self, quota=None, samples=2, event_window=sys.maxsize, + time_window_ms=30 * 1000, tags=None): + """ + Arguments: + quota (Quota, optional): Upper or lower bound of a value. + samples (int, optional): Max number of samples kept per metric. + event_window (int, optional): Max number of values per sample. + time_window_ms (int, optional): Max age of an individual sample. + tags (dict of {str: str}, optional): Tags for each metric. + """ + self.quota = quota + self._samples = samples + self.event_window = event_window + self.time_window_ms = time_window_ms + # tags should be OrderedDict (not supported in py26) + self.tags = tags if tags else {} + + @property + def samples(self): + return self._samples + + @samples.setter + def samples(self, value): + if value < 1: + raise ValueError('The number of samples must be at least 1.') + self._samples = value diff --git a/metrics/metric_name.py b/metrics/metric_name.py new file mode 100644 index 00000000..b5acd166 --- /dev/null +++ b/metrics/metric_name.py @@ -0,0 +1,106 @@ +from __future__ import absolute_import + +import copy + + +class MetricName(object): + """ + This class encapsulates a metric's name, logical group and its + related attributes (tags). + + group, tags parameters can be used to create unique metric names. + e.g. domainName:type=group,key1=val1,key2=val2 + + Usage looks something like this: + + # set up metrics: + metric_tags = {'client-id': 'producer-1', 'topic': 'topic'} + metric_config = MetricConfig(tags=metric_tags) + + # metrics is the global repository of metrics and sensors + metrics = Metrics(metric_config) + + sensor = metrics.sensor('message-sizes') + metric_name = metrics.metric_name('message-size-avg', + 'producer-metrics', + 'average message size') + sensor.add(metric_name, Avg()) + + metric_name = metrics.metric_name('message-size-max', + sensor.add(metric_name, Max()) + + tags = {'client-id': 'my-client', 'topic': 'my-topic'} + metric_name = metrics.metric_name('message-size-min', + 'producer-metrics', + 'message minimum size', tags) + sensor.add(metric_name, Min()) + + # as messages are sent we record the sizes + sensor.record(message_size) + """ + + def __init__(self, name, group, description=None, tags=None): + """ + Arguments: + name (str): The name of the metric. + group (str): The logical group name of the metrics to which this + metric belongs. + description (str, optional): A human-readable description to + include in the metric. + tags (dict, optional): Additional key/val attributes of the metric. 
+ """ + if not (name and group): + raise ValueError('name and group must be non-empty.') + if tags is not None and not isinstance(tags, dict): + raise ValueError('tags must be a dict if present.') + + self._name = name + self._group = group + self._description = description + self._tags = copy.copy(tags) + self._hash = 0 + + @property + def name(self): + return self._name + + @property + def group(self): + return self._group + + @property + def description(self): + return self._description + + @property + def tags(self): + return copy.copy(self._tags) + + def __hash__(self): + if self._hash != 0: + return self._hash + prime = 31 + result = 1 + result = prime * result + hash(self.group) + result = prime * result + hash(self.name) + tags_hash = hash(frozenset(self.tags.items())) if self.tags else 0 + result = prime * result + tags_hash + self._hash = result + return result + + def __eq__(self, other): + if self is other: + return True + if other is None: + return False + return (type(self) == type(other) and + self.group == other.group and + self.name == other.name and + self.tags == other.tags) + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + return 'MetricName(name=%s, group=%s, description=%s, tags=%s)' % ( + self.name, self.group, self.description, self.tags) diff --git a/metrics/metrics.py b/metrics/metrics.py new file mode 100644 index 00000000..2c53488f --- /dev/null +++ b/metrics/metrics.py @@ -0,0 +1,261 @@ +from __future__ import absolute_import + +import logging +import sys +import time +import threading + +from kafka.metrics import AnonMeasurable, KafkaMetric, MetricConfig, MetricName +from kafka.metrics.stats import Sensor + +logger = logging.getLogger(__name__) + + +class Metrics(object): + """ + A registry of sensors and metrics. + + A metric is a named, numerical measurement. A sensor is a handle to + record numerical measurements as they occur. Each Sensor has zero or + more associated metrics. For example a Sensor might represent message + sizes and we might associate with this sensor a metric for the average, + maximum, or other statistics computed off the sequence of message sizes + that are recorded by the sensor. 
+ + Usage looks something like this: + # set up metrics: + metrics = Metrics() # the global repository of metrics and sensors + sensor = metrics.sensor('message-sizes') + metric_name = MetricName('message-size-avg', 'producer-metrics') + sensor.add(metric_name, Avg()) + metric_name = MetricName('message-size-max', 'producer-metrics') + sensor.add(metric_name, Max()) + + # as messages are sent we record the sizes + sensor.record(message_size); + """ + def __init__(self, default_config=None, reporters=None, + enable_expiration=False): + """ + Create a metrics repository with a default config, given metric + reporters and the ability to expire eligible sensors + + Arguments: + default_config (MetricConfig, optional): The default config + reporters (list of AbstractMetricsReporter, optional): + The metrics reporters + enable_expiration (bool, optional): true if the metrics instance + can garbage collect inactive sensors, false otherwise + """ + self._lock = threading.RLock() + self._config = default_config or MetricConfig() + self._sensors = {} + self._metrics = {} + self._children_sensors = {} + self._reporters = reporters or [] + for reporter in self._reporters: + reporter.init([]) + + if enable_expiration: + def expire_loop(): + while True: + # delay 30 seconds + time.sleep(30) + self.ExpireSensorTask.run(self) + metrics_scheduler = threading.Thread(target=expire_loop) + # Creating a daemon thread to not block shutdown + metrics_scheduler.daemon = True + metrics_scheduler.start() + + self.add_metric(self.metric_name('count', 'kafka-metrics-count', + 'total number of registered metrics'), + AnonMeasurable(lambda config, now: len(self._metrics))) + + @property + def config(self): + return self._config + + @property + def metrics(self): + """ + Get all the metrics currently maintained and indexed by metricName + """ + return self._metrics + + def metric_name(self, name, group, description='', tags=None): + """ + Create a MetricName with the given name, group, description and tags, + plus default tags specified in the metric configuration. + Tag in tags takes precedence if the same tag key is specified in + the default metric configuration. + + Arguments: + name (str): The name of the metric + group (str): logical group name of the metrics to which this + metric belongs + description (str, optional): A human-readable description to + include in the metric + tags (dict, optionals): additional key/value attributes of + the metric + """ + combined_tags = dict(self.config.tags) + combined_tags.update(tags or {}) + return MetricName(name, group, description, combined_tags) + + def get_sensor(self, name): + """ + Get the sensor with the given name if it exists + + Arguments: + name (str): The name of the sensor + + Returns: + Sensor: The sensor or None if no such sensor exists + """ + if not name: + raise ValueError('name must be non-empty') + return self._sensors.get(name, None) + + def sensor(self, name, config=None, + inactive_sensor_expiration_time_seconds=sys.maxsize, + parents=None): + """ + Get or create a sensor with the given unique name and zero or + more parent sensors. All parent sensors will receive every value + recorded with this sensor. 
+ + Arguments: + name (str): The name of the sensor + config (MetricConfig, optional): A default configuration to use + for this sensor for metrics that don't have their own config + inactive_sensor_expiration_time_seconds (int, optional): + If no value if recorded on the Sensor for this duration of + time, it is eligible for removal + parents (list of Sensor): The parent sensors + + Returns: + Sensor: The sensor that is created + """ + sensor = self.get_sensor(name) + if sensor: + return sensor + + with self._lock: + sensor = self.get_sensor(name) + if not sensor: + sensor = Sensor(self, name, parents, config or self.config, + inactive_sensor_expiration_time_seconds) + self._sensors[name] = sensor + if parents: + for parent in parents: + children = self._children_sensors.get(parent) + if not children: + children = [] + self._children_sensors[parent] = children + children.append(sensor) + logger.debug('Added sensor with name %s', name) + return sensor + + def remove_sensor(self, name): + """ + Remove a sensor (if it exists), associated metrics and its children. + + Arguments: + name (str): The name of the sensor to be removed + """ + sensor = self._sensors.get(name) + if sensor: + child_sensors = None + with sensor._lock: + with self._lock: + val = self._sensors.pop(name, None) + if val and val == sensor: + for metric in sensor.metrics: + self.remove_metric(metric.metric_name) + logger.debug('Removed sensor with name %s', name) + child_sensors = self._children_sensors.pop(sensor, None) + if child_sensors: + for child_sensor in child_sensors: + self.remove_sensor(child_sensor.name) + + def add_metric(self, metric_name, measurable, config=None): + """ + Add a metric to monitor an object that implements measurable. + This metric won't be associated with any sensor. + This is a way to expose existing values as metrics. + + Arguments: + metricName (MetricName): The name of the metric + measurable (AbstractMeasurable): The measurable that will be + measured by this metric + config (MetricConfig, optional): The configuration to use when + measuring this measurable + """ + # NOTE there was a lock here, but i don't think it's needed + metric = KafkaMetric(metric_name, measurable, config or self.config) + self.register_metric(metric) + + def remove_metric(self, metric_name): + """ + Remove a metric if it exists and return it. Return None otherwise. + If a metric is removed, `metric_removal` will be invoked + for each reporter. + + Arguments: + metric_name (MetricName): The name of the metric + + Returns: + KafkaMetric: the removed `KafkaMetric` or None if no such + metric exists + """ + with self._lock: + metric = self._metrics.pop(metric_name, None) + if metric: + for reporter in self._reporters: + reporter.metric_removal(metric) + return metric + + def add_reporter(self, reporter): + """Add a MetricReporter""" + with self._lock: + reporter.init(list(self.metrics.values())) + self._reporters.append(reporter) + + def register_metric(self, metric): + with self._lock: + if metric.metric_name in self.metrics: + raise ValueError('A metric named "%s" already exists, cannot' + ' register another one.' % (metric.metric_name,)) + self.metrics[metric.metric_name] = metric + for reporter in self._reporters: + reporter.metric_change(metric) + + class ExpireSensorTask(object): + """ + This iterates over every Sensor and triggers a remove_sensor + if it has expired. 
Package private for testing + """ + @staticmethod + def run(metrics): + items = list(metrics._sensors.items()) + for name, sensor in items: + # remove_sensor also locks the sensor object. This is fine + # because synchronized is reentrant. There is however a minor + # race condition here. Assume we have a parent sensor P and + # child sensor C. Calling record on C would cause a record on + # P as well. So expiration time for P == expiration time for C. + # If the record on P happens via C just after P is removed, + # that will cause C to also get removed. Since the expiration + # time is typically high it is not expected to be a significant + # concern and thus not necessary to optimize + with sensor._lock: + if sensor.has_expired(): + logger.debug('Removing expired sensor %s', name) + metrics.remove_sensor(name) + + def close(self): + """Close this metrics repository.""" + for reporter in self._reporters: + reporter.close() + + self._metrics.clear() diff --git a/metrics/metrics_reporter.py b/metrics/metrics_reporter.py new file mode 100644 index 00000000..d8bd12b3 --- /dev/null +++ b/metrics/metrics_reporter.py @@ -0,0 +1,57 @@ +from __future__ import absolute_import + +import abc + + +class AbstractMetricsReporter(object): + """ + An abstract class to allow things to listen as new metrics + are created so they can be reported. + """ + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def init(self, metrics): + """ + This is called when the reporter is first registered + to initially register all existing metrics + + Arguments: + metrics (list of KafkaMetric): All currently existing metrics + """ + raise NotImplementedError + + @abc.abstractmethod + def metric_change(self, metric): + """ + This is called whenever a metric is updated or added + + Arguments: + metric (KafkaMetric) + """ + raise NotImplementedError + + @abc.abstractmethod + def metric_removal(self, metric): + """ + This is called whenever a metric is removed + + Arguments: + metric (KafkaMetric) + """ + raise NotImplementedError + + @abc.abstractmethod + def configure(self, configs): + """ + Configure this class with the given key-value pairs + + Arguments: + configs (dict of {str, ?}) + """ + raise NotImplementedError + + @abc.abstractmethod + def close(self): + """Called when the metrics repository is closed.""" + raise NotImplementedError diff --git a/metrics/quota.py b/metrics/quota.py new file mode 100644 index 00000000..4d1b0d6c --- /dev/null +++ b/metrics/quota.py @@ -0,0 +1,42 @@ +from __future__ import absolute_import + + +class Quota(object): + """An upper or lower bound for metrics""" + def __init__(self, bound, is_upper): + self._bound = bound + self._upper = is_upper + + @staticmethod + def upper_bound(upper_bound): + return Quota(upper_bound, True) + + @staticmethod + def lower_bound(lower_bound): + return Quota(lower_bound, False) + + def is_upper_bound(self): + return self._upper + + @property + def bound(self): + return self._bound + + def is_acceptable(self, value): + return ((self.is_upper_bound() and value <= self.bound) or + (not self.is_upper_bound() and value >= self.bound)) + + def __hash__(self): + prime = 31 + result = prime + self.bound + return prime * result + self.is_upper_bound() + + def __eq__(self, other): + if self is other: + return True + return (type(self) == type(other) and + self.bound == other.bound and + self.is_upper_bound() == other.is_upper_bound()) + + def __ne__(self, other): + return not self.__eq__(other) diff --git a/metrics/stat.py b/metrics/stat.py new file mode 100644 
index 00000000..9fd2f01e --- /dev/null +++ b/metrics/stat.py @@ -0,0 +1,23 @@ +from __future__ import absolute_import + +import abc + + +class AbstractStat(object): + """ + An AbstractStat is a quantity such as average, max, etc that is computed + off the stream of updates to a sensor + """ + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def record(self, config, value, time_ms): + """ + Record the given value + + Arguments: + config (MetricConfig): The configuration to use for this metric + value (float): The value to record + timeMs (int): The POSIX time in milliseconds this value occurred + """ + raise NotImplementedError diff --git a/metrics/stats/__init__.py b/metrics/stats/__init__.py new file mode 100644 index 00000000..a3d535df --- /dev/null +++ b/metrics/stats/__init__.py @@ -0,0 +1,17 @@ +from __future__ import absolute_import + +from kafka.metrics.stats.avg import Avg +from kafka.metrics.stats.count import Count +from kafka.metrics.stats.histogram import Histogram +from kafka.metrics.stats.max_stat import Max +from kafka.metrics.stats.min_stat import Min +from kafka.metrics.stats.percentile import Percentile +from kafka.metrics.stats.percentiles import Percentiles +from kafka.metrics.stats.rate import Rate +from kafka.metrics.stats.sensor import Sensor +from kafka.metrics.stats.total import Total + +__all__ = [ + 'Avg', 'Count', 'Histogram', 'Max', 'Min', 'Percentile', 'Percentiles', + 'Rate', 'Sensor', 'Total' +] diff --git a/metrics/stats/avg.py b/metrics/stats/avg.py new file mode 100644 index 00000000..cfbaec30 --- /dev/null +++ b/metrics/stats/avg.py @@ -0,0 +1,24 @@ +from __future__ import absolute_import + +from kafka.metrics.stats.sampled_stat import AbstractSampledStat + + +class Avg(AbstractSampledStat): + """ + An AbstractSampledStat that maintains a simple average over its samples. + """ + def __init__(self): + super(Avg, self).__init__(0.0) + + def update(self, sample, config, value, now): + sample.value += value + + def combine(self, samples, config, now): + total_sum = 0 + total_count = 0 + for sample in samples: + total_sum += sample.value + total_count += sample.event_count + if not total_count: + return 0 + return float(total_sum) / total_count diff --git a/metrics/stats/count.py b/metrics/stats/count.py new file mode 100644 index 00000000..6e0a2d54 --- /dev/null +++ b/metrics/stats/count.py @@ -0,0 +1,17 @@ +from __future__ import absolute_import + +from kafka.metrics.stats.sampled_stat import AbstractSampledStat + + +class Count(AbstractSampledStat): + """ + An AbstractSampledStat that maintains a simple count of what it has seen. 
+ """ + def __init__(self): + super(Count, self).__init__(0.0) + + def update(self, sample, config, value, now): + sample.value += 1.0 + + def combine(self, samples, config, now): + return float(sum(sample.value for sample in samples)) diff --git a/metrics/stats/histogram.py b/metrics/stats/histogram.py new file mode 100644 index 00000000..ecc6c9db --- /dev/null +++ b/metrics/stats/histogram.py @@ -0,0 +1,95 @@ +from __future__ import absolute_import + +import math + + +class Histogram(object): + def __init__(self, bin_scheme): + self._hist = [0.0] * bin_scheme.bins + self._count = 0.0 + self._bin_scheme = bin_scheme + + def record(self, value): + self._hist[self._bin_scheme.to_bin(value)] += 1.0 + self._count += 1.0 + + def value(self, quantile): + if self._count == 0.0: + return float('NaN') + _sum = 0.0 + quant = float(quantile) + for i, value in enumerate(self._hist[:-1]): + _sum += value + if _sum / self._count > quant: + return self._bin_scheme.from_bin(i) + return float('inf') + + @property + def counts(self): + return self._hist + + def clear(self): + for i in range(self._hist): + self._hist[i] = 0.0 + self._count = 0 + + def __str__(self): + values = ['%.10f:%.0f' % (self._bin_scheme.from_bin(i), value) for + i, value in enumerate(self._hist[:-1])] + values.append('%s:%s' % (float('inf'), self._hist[-1])) + return '{%s}' % ','.join(values) + + class ConstantBinScheme(object): + def __init__(self, bins, min_val, max_val): + if bins < 2: + raise ValueError('Must have at least 2 bins.') + self._min = float(min_val) + self._max = float(max_val) + self._bins = int(bins) + self._bucket_width = (max_val - min_val) / (bins - 2) + + @property + def bins(self): + return self._bins + + def from_bin(self, b): + if b == 0: + return float('-inf') + elif b == self._bins - 1: + return float('inf') + else: + return self._min + (b - 1) * self._bucket_width + + def to_bin(self, x): + if x < self._min: + return 0 + elif x > self._max: + return self._bins - 1 + else: + return int(((x - self._min) / self._bucket_width) + 1) + + class LinearBinScheme(object): + def __init__(self, num_bins, max_val): + self._bins = num_bins + self._max = max_val + self._scale = max_val / (num_bins * (num_bins - 1) / 2) + + @property + def bins(self): + return self._bins + + def from_bin(self, b): + if b == self._bins - 1: + return float('inf') + else: + unscaled = (b * (b + 1.0)) / 2.0 + return unscaled * self._scale + + def to_bin(self, x): + if x < 0.0: + raise ValueError('Values less than 0.0 not accepted.') + elif x > self._max: + return self._bins - 1 + else: + scaled = x / self._scale + return int(-0.5 + math.sqrt(2.0 * scaled + 0.25)) diff --git a/metrics/stats/max_stat.py b/metrics/stats/max_stat.py new file mode 100644 index 00000000..08aebddf --- /dev/null +++ b/metrics/stats/max_stat.py @@ -0,0 +1,17 @@ +from __future__ import absolute_import + +from kafka.metrics.stats.sampled_stat import AbstractSampledStat + + +class Max(AbstractSampledStat): + """An AbstractSampledStat that gives the max over its samples.""" + def __init__(self): + super(Max, self).__init__(float('-inf')) + + def update(self, sample, config, value, now): + sample.value = max(sample.value, value) + + def combine(self, samples, config, now): + if not samples: + return float('-inf') + return float(max(sample.value for sample in samples)) diff --git a/metrics/stats/min_stat.py b/metrics/stats/min_stat.py new file mode 100644 index 00000000..072106d8 --- /dev/null +++ b/metrics/stats/min_stat.py @@ -0,0 +1,19 @@ +from __future__ import 
absolute_import + +import sys + +from kafka.metrics.stats.sampled_stat import AbstractSampledStat + + +class Min(AbstractSampledStat): + """An AbstractSampledStat that gives the min over its samples.""" + def __init__(self): + super(Min, self).__init__(float(sys.maxsize)) + + def update(self, sample, config, value, now): + sample.value = min(sample.value, value) + + def combine(self, samples, config, now): + if not samples: + return float(sys.maxsize) + return float(min(sample.value for sample in samples)) diff --git a/metrics/stats/percentile.py b/metrics/stats/percentile.py new file mode 100644 index 00000000..3a86a84a --- /dev/null +++ b/metrics/stats/percentile.py @@ -0,0 +1,15 @@ +from __future__ import absolute_import + + +class Percentile(object): + def __init__(self, metric_name, percentile): + self._metric_name = metric_name + self._percentile = float(percentile) + + @property + def name(self): + return self._metric_name + + @property + def percentile(self): + return self._percentile diff --git a/metrics/stats/percentiles.py b/metrics/stats/percentiles.py new file mode 100644 index 00000000..6d702e80 --- /dev/null +++ b/metrics/stats/percentiles.py @@ -0,0 +1,74 @@ +from __future__ import absolute_import + +from kafka.metrics import AnonMeasurable, NamedMeasurable +from kafka.metrics.compound_stat import AbstractCompoundStat +from kafka.metrics.stats import Histogram +from kafka.metrics.stats.sampled_stat import AbstractSampledStat + + +class BucketSizing(object): + CONSTANT = 0 + LINEAR = 1 + + +class Percentiles(AbstractSampledStat, AbstractCompoundStat): + """A compound stat that reports one or more percentiles""" + def __init__(self, size_in_bytes, bucketing, max_val, min_val=0.0, + percentiles=None): + super(Percentiles, self).__init__(0.0) + self._percentiles = percentiles or [] + self._buckets = int(size_in_bytes / 4) + if bucketing == BucketSizing.CONSTANT: + self._bin_scheme = Histogram.ConstantBinScheme(self._buckets, + min_val, max_val) + elif bucketing == BucketSizing.LINEAR: + if min_val != 0.0: + raise ValueError('Linear bucket sizing requires min_val' + ' to be 0.0.') + self.bin_scheme = Histogram.LinearBinScheme(self._buckets, max_val) + else: + ValueError('Unknown bucket type: %s' % (bucketing,)) + + def stats(self): + measurables = [] + + def make_measure_fn(pct): + return lambda config, now: self.value(config, now, + pct / 100.0) + + for percentile in self._percentiles: + measure_fn = make_measure_fn(percentile.percentile) + stat = NamedMeasurable(percentile.name, AnonMeasurable(measure_fn)) + measurables.append(stat) + return measurables + + def value(self, config, now, quantile): + self.purge_obsolete_samples(config, now) + count = sum(sample.event_count for sample in self._samples) + if count == 0.0: + return float('NaN') + sum_val = 0.0 + quant = float(quantile) + for b in range(self._buckets): + for sample in self._samples: + assert type(sample) is self.HistogramSample + hist = sample.histogram.counts + sum_val += hist[b] + if sum_val / count > quant: + return self._bin_scheme.from_bin(b) + return float('inf') + + def combine(self, samples, config, now): + return self.value(config, now, 0.5) + + def new_sample(self, time_ms): + return Percentiles.HistogramSample(self._bin_scheme, time_ms) + + def update(self, sample, config, value, time_ms): + assert type(sample) is self.HistogramSample + sample.histogram.record(value) + + class HistogramSample(AbstractSampledStat.Sample): + def __init__(self, scheme, now): + super(Percentiles.HistogramSample, 
self).__init__(0.0, now) + self.histogram = Histogram(scheme) diff --git a/metrics/stats/rate.py b/metrics/stats/rate.py new file mode 100644 index 00000000..68393fbf --- /dev/null +++ b/metrics/stats/rate.py @@ -0,0 +1,117 @@ +from __future__ import absolute_import + +from kafka.metrics.measurable_stat import AbstractMeasurableStat +from kafka.metrics.stats.sampled_stat import AbstractSampledStat + + +class TimeUnit(object): + _names = { + 'nanosecond': 0, + 'microsecond': 1, + 'millisecond': 2, + 'second': 3, + 'minute': 4, + 'hour': 5, + 'day': 6, + } + + NANOSECONDS = _names['nanosecond'] + MICROSECONDS = _names['microsecond'] + MILLISECONDS = _names['millisecond'] + SECONDS = _names['second'] + MINUTES = _names['minute'] + HOURS = _names['hour'] + DAYS = _names['day'] + + @staticmethod + def get_name(time_unit): + return TimeUnit._names[time_unit] + + +class Rate(AbstractMeasurableStat): + """ + The rate of the given quantity. By default this is the total observed + over a set of samples from a sampled statistic divided by the elapsed + time over the sample windows. Alternative AbstractSampledStat + implementations can be provided, however, to record the rate of + occurrences (e.g. the count of values measured over the time interval) + or other such values. + """ + def __init__(self, time_unit=TimeUnit.SECONDS, sampled_stat=None): + self._stat = sampled_stat or SampledTotal() + self._unit = time_unit + + def unit_name(self): + return TimeUnit.get_name(self._unit) + + def record(self, config, value, time_ms): + self._stat.record(config, value, time_ms) + + def measure(self, config, now): + value = self._stat.measure(config, now) + return float(value) / self.convert(self.window_size(config, now)) + + def window_size(self, config, now): + # purge old samples before we compute the window size + self._stat.purge_obsolete_samples(config, now) + + """ + Here we check the total amount of time elapsed since the oldest + non-obsolete window. This give the total window_size of the batch + which is the time used for Rate computation. However, there is + an issue if we do not have sufficient data for e.g. if only + 1 second has elapsed in a 30 second window, the measured rate + will be very high. Hence we assume that the elapsed time is + always N-1 complete windows plus whatever fraction of the final + window is complete. + + Note that we could simply count the amount of time elapsed in + the current window and add n-1 windows to get the total time, + but this approach does not account for sleeps. AbstractSampledStat + only creates samples whenever record is called, if no record is + called for a period of time that time is not accounted for in + window_size and produces incorrect results. 
+ """ + total_elapsed_time_ms = now - self._stat.oldest(now).last_window_ms + # Check how many full windows of data we have currently retained + num_full_windows = int(total_elapsed_time_ms / config.time_window_ms) + min_full_windows = config.samples - 1 + + # If the available windows are less than the minimum required, + # add the difference to the totalElapsedTime + if num_full_windows < min_full_windows: + total_elapsed_time_ms += ((min_full_windows - num_full_windows) * + config.time_window_ms) + + return total_elapsed_time_ms + + def convert(self, time_ms): + if self._unit == TimeUnit.NANOSECONDS: + return time_ms * 1000.0 * 1000.0 + elif self._unit == TimeUnit.MICROSECONDS: + return time_ms * 1000.0 + elif self._unit == TimeUnit.MILLISECONDS: + return time_ms + elif self._unit == TimeUnit.SECONDS: + return time_ms / 1000.0 + elif self._unit == TimeUnit.MINUTES: + return time_ms / (60.0 * 1000.0) + elif self._unit == TimeUnit.HOURS: + return time_ms / (60.0 * 60.0 * 1000.0) + elif self._unit == TimeUnit.DAYS: + return time_ms / (24.0 * 60.0 * 60.0 * 1000.0) + else: + raise ValueError('Unknown unit: %s' % (self._unit,)) + + +class SampledTotal(AbstractSampledStat): + def __init__(self, initial_value=None): + if initial_value is not None: + raise ValueError('initial_value cannot be set on SampledTotal') + super(SampledTotal, self).__init__(0.0) + + def update(self, sample, config, value, time_ms): + sample.value += value + + def combine(self, samples, config, now): + return float(sum(sample.value for sample in samples)) diff --git a/metrics/stats/sampled_stat.py b/metrics/stats/sampled_stat.py new file mode 100644 index 00000000..c41b14bb --- /dev/null +++ b/metrics/stats/sampled_stat.py @@ -0,0 +1,101 @@ +from __future__ import absolute_import + +import abc + +from kafka.metrics.measurable_stat import AbstractMeasurableStat + + +class AbstractSampledStat(AbstractMeasurableStat): + """ + An AbstractSampledStat records a single scalar value measured over + one or more samples. Each sample is recorded over a configurable + window. The window can be defined by number of events or elapsed + time (or both, if both are given the window is complete when + *either* the event count or elapsed time criterion is met). + + All the samples are combined to produce the measurement. When a + window is complete the oldest sample is cleared and recycled to + begin recording the next sample. + + Subclasses of this class define different statistics measured + using this basic pattern. 
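# --- Illustrative sketch, not part of the vendored patch -----------------
# Rate divides the windowed total by the elapsed window time, so recording
# 300 units across roughly three one-second windows measures about 100
# units/second. Minimal sketch, using only Rate, TimeUnit and MetricConfig
# as defined in this subtree.
import time
from kafka.metrics import MetricConfig
from kafka.metrics.stats.rate import Rate, TimeUnit

config = MetricConfig(samples=4, time_window_ms=1000)
rate = Rate(time_unit=TimeUnit.SECONDS)

now_ms = int(time.time() * 1000)
for i in range(3):
    rate.record(config, 100, now_ms + i * 1000)   # 100 units each second

print(rate.measure(config, now_ms + 3000))        # ~100.0 per second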
+ """ + __metaclass__ = abc.ABCMeta + + def __init__(self, initial_value): + self._initial_value = initial_value + self._samples = [] + self._current = 0 + + @abc.abstractmethod + def update(self, sample, config, value, time_ms): + raise NotImplementedError + + @abc.abstractmethod + def combine(self, samples, config, now): + raise NotImplementedError + + def record(self, config, value, time_ms): + sample = self.current(time_ms) + if sample.is_complete(time_ms, config): + sample = self._advance(config, time_ms) + self.update(sample, config, float(value), time_ms) + sample.event_count += 1 + + def new_sample(self, time_ms): + return self.Sample(self._initial_value, time_ms) + + def measure(self, config, now): + self.purge_obsolete_samples(config, now) + return float(self.combine(self._samples, config, now)) + + def current(self, time_ms): + if not self._samples: + self._samples.append(self.new_sample(time_ms)) + return self._samples[self._current] + + def oldest(self, now): + if not self._samples: + self._samples.append(self.new_sample(now)) + oldest = self._samples[0] + for sample in self._samples[1:]: + if sample.last_window_ms < oldest.last_window_ms: + oldest = sample + return oldest + + def purge_obsolete_samples(self, config, now): + """ + Timeout any windows that have expired in the absence of any events + """ + expire_age = config.samples * config.time_window_ms + for sample in self._samples: + if now - sample.last_window_ms >= expire_age: + sample.reset(now) + + def _advance(self, config, time_ms): + self._current = (self._current + 1) % config.samples + if self._current >= len(self._samples): + sample = self.new_sample(time_ms) + self._samples.append(sample) + return sample + else: + sample = self.current(time_ms) + sample.reset(time_ms) + return sample + + class Sample(object): + + def __init__(self, initial_value, now): + self.initial_value = initial_value + self.event_count = 0 + self.last_window_ms = now + self.value = initial_value + + def reset(self, now): + self.event_count = 0 + self.last_window_ms = now + self.value = self.initial_value + + def is_complete(self, time_ms, config): + return (time_ms - self.last_window_ms >= config.time_window_ms or + self.event_count >= config.event_window) diff --git a/metrics/stats/sensor.py b/metrics/stats/sensor.py new file mode 100644 index 00000000..571723f9 --- /dev/null +++ b/metrics/stats/sensor.py @@ -0,0 +1,134 @@ +from __future__ import absolute_import + +import threading +import time + +from kafka.errors import QuotaViolationError +from kafka.metrics import KafkaMetric + + +class Sensor(object): + """ + A sensor applies a continuous sequence of numerical values + to a set of associated metrics. For example a sensor on + message size would record a sequence of message sizes using + the `record(double)` api and would maintain a set + of metrics about request sizes such as the average or max. 
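# --- Illustrative sketch, not part of the vendored patch -----------------
# Concrete stats typically only implement update() and combine(); the
# windowing, sample recycling and expiry all live in AbstractSampledStat
# above. 'SampledMax' is a hypothetical example mirroring the bundled
# Max/Min stats.
from kafka.metrics.stats.sampled_stat import AbstractSampledStat


class SampledMax(AbstractSampledStat):
    """Track the maximum value seen across the retained sample windows."""

    def __init__(self):
        super(SampledMax, self).__init__(float('-inf'))

    def update(self, sample, config, value, now):
        # Called for every record(); fold the new value into this sample.
        sample.value = max(sample.value, value)

    def combine(self, samples, config, now):
        # Called from measure(); merge the per-window values.
        if not samples:
            return float('-inf')
        return float(max(sample.value for sample in samples))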
+ """ + def __init__(self, registry, name, parents, config, + inactive_sensor_expiration_time_seconds): + if not name: + raise ValueError('name must be non-empty') + self._lock = threading.RLock() + self._registry = registry + self._name = name + self._parents = parents or [] + self._metrics = [] + self._stats = [] + self._config = config + self._inactive_sensor_expiration_time_ms = ( + inactive_sensor_expiration_time_seconds * 1000) + self._last_record_time = time.time() * 1000 + self._check_forest(set()) + + def _check_forest(self, sensors): + """Validate that this sensor doesn't end up referencing itself.""" + if self in sensors: + raise ValueError('Circular dependency in sensors: %s is its own' + 'parent.' % (self.name,)) + sensors.add(self) + for parent in self._parents: + parent._check_forest(sensors) + + @property + def name(self): + """ + The name this sensor is registered with. + This name will be unique among all registered sensors. + """ + return self._name + + @property + def metrics(self): + return tuple(self._metrics) + + def record(self, value=1.0, time_ms=None): + """ + Record a value at a known time. + Arguments: + value (double): The value we are recording + time_ms (int): A POSIX timestamp in milliseconds. + Default: The time when record() is evaluated (now) + + Raises: + QuotaViolationException: if recording this value moves a + metric beyond its configured maximum or minimum bound + """ + if time_ms is None: + time_ms = time.time() * 1000 + self._last_record_time = time_ms + with self._lock: # XXX high volume, might be performance issue + # increment all the stats + for stat in self._stats: + stat.record(self._config, value, time_ms) + self._check_quotas(time_ms) + for parent in self._parents: + parent.record(value, time_ms) + + def _check_quotas(self, time_ms): + """ + Check if we have violated our quota for any metric that + has a configured quota + """ + for metric in self._metrics: + if metric.config and metric.config.quota: + value = metric.value(time_ms) + if not metric.config.quota.is_acceptable(value): + raise QuotaViolationError("'%s' violated quota. Actual: " + "%d, Threshold: %d" % + (metric.metric_name, + value, + metric.config.quota.bound)) + + def add_compound(self, compound_stat, config=None): + """ + Register a compound statistic with this sensor which + yields multiple measurable quantities (like a histogram) + + Arguments: + stat (AbstractCompoundStat): The stat to register + config (MetricConfig): The configuration for this stat. + If None then the stat will use the default configuration + for this sensor. + """ + if not compound_stat: + raise ValueError('compound stat must be non-empty') + self._stats.append(compound_stat) + for named_measurable in compound_stat.stats(): + metric = KafkaMetric(named_measurable.name, named_measurable.stat, + config or self._config) + self._registry.register_metric(metric) + self._metrics.append(metric) + + def add(self, metric_name, stat, config=None): + """ + Register a metric with this sensor + + Arguments: + metric_name (MetricName): The name of the metric + stat (AbstractMeasurableStat): The statistic to keep + config (MetricConfig): A special configuration for this metric. + If None use the sensor default configuration. + """ + with self._lock: + metric = KafkaMetric(metric_name, stat, config or self._config) + self._registry.register_metric(metric) + self._metrics.append(metric) + self._stats.append(stat) + + def has_expired(self): + """ + Return True if the Sensor is eligible for removal due to inactivity. 
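# --- Illustrative sketch, not part of the vendored patch -----------------
# Typical Sensor usage goes through the Metrics registry: create a sensor,
# attach one stat per metric, then feed it observations with record().
# Minimal sketch, assuming Metrics, Avg, Max, Rate and Quota from the other
# files in this subtree match upstream kafka-python.
from kafka.metrics import Metrics, MetricConfig, Quota
from kafka.metrics.stats import Avg, Max, Rate

metrics = Metrics()
sensor = metrics.sensor('message-size')
sensor.add(metrics.metric_name('message-size-avg', 'producer-metrics'), Avg())
sensor.add(metrics.metric_name('message-size-max', 'producer-metrics'), Max())
sensor.record(512)
sensor.record(4096)

# A quota turns record() into a guard: once the measured rate exceeds the
# bound, _check_quotas() above raises QuotaViolationError.
guarded = metrics.sensor('request-rate',
                         config=MetricConfig(quota=Quota.upper_bound(10.0)))
guarded.add(metrics.metric_name('request-rate', 'producer-metrics'), Rate())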
+ """ + return ((time.time() * 1000 - self._last_record_time) > + self._inactive_sensor_expiration_time_ms) diff --git a/metrics/stats/total.py b/metrics/stats/total.py new file mode 100644 index 00000000..5b3bb87f --- /dev/null +++ b/metrics/stats/total.py @@ -0,0 +1,15 @@ +from __future__ import absolute_import + +from kafka.metrics.measurable_stat import AbstractMeasurableStat + + +class Total(AbstractMeasurableStat): + """An un-windowed cumulative total maintained over all time.""" + def __init__(self, value=0.0): + self._total = value + + def record(self, config, value, now): + self._total += value + + def measure(self, config, now): + return float(self._total) diff --git a/oauth/__init__.py b/oauth/__init__.py new file mode 100644 index 00000000..8c834956 --- /dev/null +++ b/oauth/__init__.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import + +from kafka.oauth.abstract import AbstractTokenProvider diff --git a/oauth/abstract.py b/oauth/abstract.py new file mode 100644 index 00000000..8d89ff51 --- /dev/null +++ b/oauth/abstract.py @@ -0,0 +1,42 @@ +from __future__ import absolute_import + +import abc + +# This statement is compatible with both Python 2.7 & 3+ +ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()}) + +class AbstractTokenProvider(ABC): + """ + A Token Provider must be used for the SASL OAuthBearer protocol. + + The implementation should ensure token reuse so that multiple + calls at connect time do not create multiple tokens. The implementation + should also periodically refresh the token in order to guarantee + that each call returns an unexpired token. A timeout error should + be returned after a short period of inactivity so that the + broker can log debugging info and retry. + + Token Providers MUST implement the token() method + """ + + def __init__(self, **config): + pass + + @abc.abstractmethod + def token(self): + """ + Returns a (str) ID/Access Token to be sent to the Kafka + client. + """ + pass + + def extensions(self): + """ + This is an OPTIONAL method that may be implemented. + + Returns a map of key-value pairs that can + be sent with the SASL/OAUTHBEARER initial client request. If + not implemented, the values are ignored. This feature is only available + in Kafka >= 2.1.0. + """ + return {} diff --git a/partitioner/__init__.py b/partitioner/__init__.py new file mode 100644 index 00000000..21a3bbb6 --- /dev/null +++ b/partitioner/__init__.py @@ -0,0 +1,8 @@ +from __future__ import absolute_import + +from kafka.partitioner.default import DefaultPartitioner, murmur2 + + +__all__ = [ + 'DefaultPartitioner', 'murmur2' +] diff --git a/partitioner/default.py b/partitioner/default.py new file mode 100644 index 00000000..d0914c68 --- /dev/null +++ b/partitioner/default.py @@ -0,0 +1,102 @@ +from __future__ import absolute_import + +import random + +from kafka.vendor import six + + +class DefaultPartitioner(object): + """Default partitioner. 
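# --- Illustrative sketch, not part of the vendored patch -----------------
# A SASL/OAUTHBEARER client supplies tokens by subclassing
# AbstractTokenProvider and passing an instance as sasl_oauth_token_provider.
# 'fetch_token_somehow' is a hypothetical placeholder for whatever the
# identity provider exposes; the subclass itself only has to implement
# token(), reusing the cached token until it is close to expiry.
import time

from kafka.oauth import AbstractTokenProvider


class CachedTokenProvider(AbstractTokenProvider):
    """Reuse a token until shortly before it expires, then refresh it."""

    def __init__(self, fetch_token_somehow, refresh_margin_s=60, **config):
        super(CachedTokenProvider, self).__init__(**config)
        self._fetch = fetch_token_somehow   # hypothetical callable -> (token, expires_at)
        self._margin = refresh_margin_s
        self._token = None
        self._expires_at = 0

    def token(self):
        if self._token is None or time.time() > self._expires_at - self._margin:
            self._token, self._expires_at = self._fetch()
        return self._token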
+ + Hashes key to partition using murmur2 hashing (from java client) + If key is None, selects partition randomly from available, + or from all partitions if none are currently available + """ + @classmethod + def __call__(cls, key, all_partitions, available): + """ + Get the partition corresponding to key + :param key: partitioning key + :param all_partitions: list of all partitions sorted by partition ID + :param available: list of available partitions in no particular order + :return: one of the values from all_partitions or available + """ + if key is None: + if available: + return random.choice(available) + return random.choice(all_partitions) + + idx = murmur2(key) + idx &= 0x7fffffff + idx %= len(all_partitions) + return all_partitions[idx] + + +# https://github.com/apache/kafka/blob/0.8.2/clients/src/main/java/org/apache/kafka/common/utils/Utils.java#L244 +def murmur2(data): + """Pure-python Murmur2 implementation. + + Based on java client, see org.apache.kafka.common.utils.Utils.murmur2 + + Args: + data (bytes): opaque bytes + + Returns: MurmurHash2 of data + """ + # Python2 bytes is really a str, causing the bitwise operations below to fail + # so convert to bytearray. + if six.PY2: + data = bytearray(bytes(data)) + + length = len(data) + seed = 0x9747b28c + # 'm' and 'r' are mixing constants generated offline. + # They're not really 'magic', they just happen to work well. + m = 0x5bd1e995 + r = 24 + + # Initialize the hash to a random value + h = seed ^ length + length4 = length // 4 + + for i in range(length4): + i4 = i * 4 + k = ((data[i4 + 0] & 0xff) + + ((data[i4 + 1] & 0xff) << 8) + + ((data[i4 + 2] & 0xff) << 16) + + ((data[i4 + 3] & 0xff) << 24)) + k &= 0xffffffff + k *= m + k &= 0xffffffff + k ^= (k % 0x100000000) >> r # k ^= k >>> r + k &= 0xffffffff + k *= m + k &= 0xffffffff + + h *= m + h &= 0xffffffff + h ^= k + h &= 0xffffffff + + # Handle the last few bytes of the input array + extra_bytes = length % 4 + if extra_bytes >= 3: + h ^= (data[(length & ~3) + 2] & 0xff) << 16 + h &= 0xffffffff + if extra_bytes >= 2: + h ^= (data[(length & ~3) + 1] & 0xff) << 8 + h &= 0xffffffff + if extra_bytes >= 1: + h ^= (data[length & ~3] & 0xff) + h &= 0xffffffff + h *= m + h &= 0xffffffff + + h ^= (h % 0x100000000) >> 13 # h >>> 13; + h &= 0xffffffff + h *= m + h &= 0xffffffff + h ^= (h % 0x100000000) >> 15 # h >>> 15; + h &= 0xffffffff + + return h diff --git a/producer/__init__.py b/producer/__init__.py new file mode 100644 index 00000000..576c772a --- /dev/null +++ b/producer/__init__.py @@ -0,0 +1,7 @@ +from __future__ import absolute_import + +from kafka.producer.kafka import KafkaProducer + +__all__ = [ + 'KafkaProducer' +] diff --git a/producer/buffer.py b/producer/buffer.py new file mode 100644 index 00000000..10080170 --- /dev/null +++ b/producer/buffer.py @@ -0,0 +1,115 @@ +from __future__ import absolute_import, division + +import collections +import io +import threading +import time + +from kafka.metrics.stats import Rate + +import kafka.errors as Errors + + +class SimpleBufferPool(object): + """A simple pool of BytesIO objects with a weak memory ceiling.""" + def __init__(self, memory, poolable_size, metrics=None, metric_group_prefix='producer-metrics'): + """Create a new buffer pool. 
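# --- Illustrative sketch, not part of the vendored patch -----------------
# The default partitioner is deterministic for non-None keys: the same key
# always lands on the same partition (as long as the partition count does
# not change), matching the Java client's murmur2-based placement.
from kafka.partitioner import DefaultPartitioner, murmur2

all_partitions = [0, 1, 2, 3]       # sorted by partition id
available = [0, 1, 3]               # e.g. partition 2 currently has no leader

partitioner = DefaultPartitioner()
p = partitioner(b'user-42', all_partitions, available)
assert p == (murmur2(b'user-42') & 0x7fffffff) % len(all_partitions)

# A None key falls back to a random choice from the available partitions.
partitioner(None, all_partitions, available)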
+ + Arguments: + memory (int): maximum memory that this buffer pool can allocate + poolable_size (int): memory size per buffer to cache in the free + list rather than deallocating + """ + self._poolable_size = poolable_size + self._lock = threading.RLock() + + buffers = int(memory / poolable_size) if poolable_size else 0 + self._free = collections.deque([io.BytesIO() for _ in range(buffers)]) + + self._waiters = collections.deque() + self.wait_time = None + if metrics: + self.wait_time = metrics.sensor('bufferpool-wait-time') + self.wait_time.add(metrics.metric_name( + 'bufferpool-wait-ratio', metric_group_prefix, + 'The fraction of time an appender waits for space allocation.'), + Rate()) + + def allocate(self, size, max_time_to_block_ms): + """ + Allocate a buffer of the given size. This method blocks if there is not + enough memory and the buffer pool is configured with blocking mode. + + Arguments: + size (int): The buffer size to allocate in bytes [ignored] + max_time_to_block_ms (int): The maximum time in milliseconds to + block for buffer memory to be available + + Returns: + io.BytesIO + """ + with self._lock: + # check if we have a free buffer of the right size pooled + if self._free: + return self._free.popleft() + + elif self._poolable_size == 0: + return io.BytesIO() + + else: + # we are out of buffers and will have to block + buf = None + more_memory = threading.Condition(self._lock) + self._waiters.append(more_memory) + # loop over and over until we have a buffer or have reserved + # enough memory to allocate one + while buf is None: + start_wait = time.time() + more_memory.wait(max_time_to_block_ms / 1000.0) + end_wait = time.time() + if self.wait_time: + self.wait_time.record(end_wait - start_wait) + + if self._free: + buf = self._free.popleft() + else: + self._waiters.remove(more_memory) + raise Errors.KafkaTimeoutError( + "Failed to allocate memory within the configured" + " max blocking time") + + # remove the condition for this thread to let the next thread + # in line start getting memory + removed = self._waiters.popleft() + assert removed is more_memory, 'Wrong condition' + + # signal any additional waiters if there is more memory left + # over for them + if self._free and self._waiters: + self._waiters[0].notify() + + # unlock and return the buffer + return buf + + def deallocate(self, buf): + """ + Return buffers to the pool. If they are of the poolable size add them + to the free list, otherwise just mark the memory as free. + + Arguments: + buffer_ (io.BytesIO): The buffer to return + """ + with self._lock: + # BytesIO.truncate here makes the pool somewhat pointless + # but we stick with the BufferPool API until migrating to + # bytesarray / memoryview. The buffer we return must not + # expose any prior data on read(). 
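# --- Illustrative sketch, not part of the vendored patch -----------------
# The pool pre-creates memory/poolable_size BytesIO buffers; allocate()
# hands one out, blocking up to max_time_to_block_ms when none are free,
# and deallocate() truncates the buffer and returns it to the free list.
from kafka.producer.buffer import SimpleBufferPool

pool = SimpleBufferPool(memory=64 * 1024, poolable_size=16 * 1024)
buf = pool.allocate(16 * 1024, max_time_to_block_ms=1000)
buf.write(b'serialized batch bytes')
pool.deallocate(buf)        # truncated and put back on the free list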
+ buf.truncate(0) + self._free.append(buf) + if self._waiters: + self._waiters[0].notify() + + def queued(self): + """The number of threads blocked waiting on memory.""" + with self._lock: + return len(self._waiters) diff --git a/producer/future.py b/producer/future.py new file mode 100644 index 00000000..07fa4adb --- /dev/null +++ b/producer/future.py @@ -0,0 +1,71 @@ +from __future__ import absolute_import + +import collections +import threading + +from kafka import errors as Errors +from kafka.future import Future + + +class FutureProduceResult(Future): + def __init__(self, topic_partition): + super(FutureProduceResult, self).__init__() + self.topic_partition = topic_partition + self._latch = threading.Event() + + def success(self, value): + ret = super(FutureProduceResult, self).success(value) + self._latch.set() + return ret + + def failure(self, error): + ret = super(FutureProduceResult, self).failure(error) + self._latch.set() + return ret + + def wait(self, timeout=None): + # wait() on python2.6 returns None instead of the flag value + return self._latch.wait(timeout) or self._latch.is_set() + + +class FutureRecordMetadata(Future): + def __init__(self, produce_future, relative_offset, timestamp_ms, checksum, serialized_key_size, serialized_value_size, serialized_header_size): + super(FutureRecordMetadata, self).__init__() + self._produce_future = produce_future + # packing args as a tuple is a minor speed optimization + self.args = (relative_offset, timestamp_ms, checksum, serialized_key_size, serialized_value_size, serialized_header_size) + produce_future.add_callback(self._produce_success) + produce_future.add_errback(self.failure) + + def _produce_success(self, offset_and_timestamp): + offset, produce_timestamp_ms, log_start_offset = offset_and_timestamp + + # Unpacking from args tuple is minor speed optimization + (relative_offset, timestamp_ms, checksum, + serialized_key_size, serialized_value_size, serialized_header_size) = self.args + + # None is when Broker does not support the API (<0.10) and + # -1 is when the broker is configured for CREATE_TIME timestamps + if produce_timestamp_ms is not None and produce_timestamp_ms != -1: + timestamp_ms = produce_timestamp_ms + if offset != -1 and relative_offset is not None: + offset += relative_offset + tp = self._produce_future.topic_partition + metadata = RecordMetadata(tp[0], tp[1], tp, offset, timestamp_ms, log_start_offset, + checksum, serialized_key_size, + serialized_value_size, serialized_header_size) + self.success(metadata) + + def get(self, timeout=None): + if not self.is_done and not self._produce_future.wait(timeout): + raise Errors.KafkaTimeoutError( + "Timeout after waiting for %s secs." 
% (timeout,)) + assert self.is_done + if self.failed(): + raise self.exception # pylint: disable-msg=raising-bad-type + return self.value + + +RecordMetadata = collections.namedtuple( + 'RecordMetadata', ['topic', 'partition', 'topic_partition', 'offset', 'timestamp', 'log_start_offset', + 'checksum', 'serialized_key_size', 'serialized_value_size', 'serialized_header_size']) diff --git a/producer/kafka.py b/producer/kafka.py new file mode 100644 index 00000000..dd1cc508 --- /dev/null +++ b/producer/kafka.py @@ -0,0 +1,752 @@ +from __future__ import absolute_import + +import atexit +import copy +import logging +import socket +import threading +import time +import weakref + +from kafka.vendor import six + +import kafka.errors as Errors +from kafka.client_async import KafkaClient, selectors +from kafka.codec import has_gzip, has_snappy, has_lz4, has_zstd +from kafka.metrics import MetricConfig, Metrics +from kafka.partitioner.default import DefaultPartitioner +from kafka.producer.future import FutureRecordMetadata, FutureProduceResult +from kafka.producer.record_accumulator import AtomicInteger, RecordAccumulator +from kafka.producer.sender import Sender +from kafka.record.default_records import DefaultRecordBatchBuilder +from kafka.record.legacy_records import LegacyRecordBatchBuilder +from kafka.serializer import Serializer +from kafka.structs import TopicPartition + + +log = logging.getLogger(__name__) +PRODUCER_CLIENT_ID_SEQUENCE = AtomicInteger() + + +class KafkaProducer(object): + """A Kafka client that publishes records to the Kafka cluster. + + The producer is thread safe and sharing a single producer instance across + threads will generally be faster than having multiple instances. + + The producer consists of a pool of buffer space that holds records that + haven't yet been transmitted to the server as well as a background I/O + thread that is responsible for turning these records into requests and + transmitting them to the cluster. + + :meth:`~kafka.KafkaProducer.send` is asynchronous. When called it adds the + record to a buffer of pending record sends and immediately returns. This + allows the producer to batch together individual records for efficiency. + + The 'acks' config controls the criteria under which requests are considered + complete. The "all" setting will result in blocking on the full commit of + the record, the slowest but most durable setting. + + If the request fails, the producer can automatically retry, unless + 'retries' is configured to 0. Enabling retries also opens up the + possibility of duplicates (see the documentation on message + delivery semantics for details: + https://kafka.apache.org/documentation.html#semantics + ). + + The producer maintains buffers of unsent records for each partition. These + buffers are of a size specified by the 'batch_size' config. Making this + larger can result in more batching, but requires more memory (since we will + generally have one of these buffers for each active partition). + + By default a buffer is available to send immediately even if there is + additional unused space in the buffer. However if you want to reduce the + number of requests you can set 'linger_ms' to something greater than 0. + This will instruct the producer to wait up to that number of milliseconds + before sending a request in hope that more records will arrive to fill up + the same batch. This is analogous to Nagle's algorithm in TCP. 
Note that + records that arrive close together in time will generally batch together + even with linger_ms=0 so under heavy load batching will occur regardless of + the linger configuration; however setting this to something larger than 0 + can lead to fewer, more efficient requests when not under maximal load at + the cost of a small amount of latency. + + The buffer_memory controls the total amount of memory available to the + producer for buffering. If records are sent faster than they can be + transmitted to the server then this buffer space will be exhausted. When + the buffer space is exhausted additional send calls will block. + + The key_serializer and value_serializer instruct how to turn the key and + value objects the user provides into bytes. + + Keyword Arguments: + bootstrap_servers: 'host[:port]' string (or list of 'host[:port]' + strings) that the producer should contact to bootstrap initial + cluster metadata. This does not have to be the full node list. + It just needs to have at least one broker that will respond to a + Metadata API Request. Default port is 9092. If no servers are + specified, will default to localhost:9092. + client_id (str): a name for this client. This string is passed in + each request to servers and can be used to identify specific + server-side log entries that correspond to this client. + Default: 'kafka-python-producer-#' (appended with a unique number + per instance) + key_serializer (callable): used to convert user-supplied keys to bytes + If not None, called as f(key), should return bytes. Default: None. + value_serializer (callable): used to convert user-supplied message + values to bytes. If not None, called as f(value), should return + bytes. Default: None. + acks (0, 1, 'all'): The number of acknowledgments the producer requires + the leader to have received before considering a request complete. + This controls the durability of records that are sent. The + following settings are common: + + 0: Producer will not wait for any acknowledgment from the server. + The message will immediately be added to the socket + buffer and considered sent. No guarantee can be made that the + server has received the record in this case, and the retries + configuration will not take effect (as the client won't + generally know of any failures). The offset given back for each + record will always be set to -1. + 1: Wait for leader to write the record to its local log only. + Broker will respond without awaiting full acknowledgement from + all followers. In this case should the leader fail immediately + after acknowledging the record but before the followers have + replicated it then the record will be lost. + all: Wait for the full set of in-sync replicas to write the record. + This guarantees that the record will not be lost as long as at + least one in-sync replica remains alive. This is the strongest + available guarantee. + If unset, defaults to acks=1. + compression_type (str): The compression type for all data generated by + the producer. Valid values are 'gzip', 'snappy', 'lz4', 'zstd' or None. + Compression is of full batches of data, so the efficacy of batching + will also impact the compression ratio (more batching means better + compression). Default: None. + retries (int): Setting a value greater than zero will cause the client + to resend any record whose send fails with a potentially transient + error. Note that this retry is no different than if the client + resent the record upon receiving the error. 
Allowing retries + without setting max_in_flight_requests_per_connection to 1 will + potentially change the ordering of records because if two batches + are sent to a single partition, and the first fails and is retried + but the second succeeds, then the records in the second batch may + appear first. + Default: 0. + batch_size (int): Requests sent to brokers will contain multiple + batches, one for each partition with data available to be sent. + A small batch size will make batching less common and may reduce + throughput (a batch size of zero will disable batching entirely). + Default: 16384 + linger_ms (int): The producer groups together any records that arrive + in between request transmissions into a single batched request. + Normally this occurs only under load when records arrive faster + than they can be sent out. However in some circumstances the client + may want to reduce the number of requests even under moderate load. + This setting accomplishes this by adding a small amount of + artificial delay; that is, rather than immediately sending out a + record the producer will wait for up to the given delay to allow + other records to be sent so that the sends can be batched together. + This can be thought of as analogous to Nagle's algorithm in TCP. + This setting gives the upper bound on the delay for batching: once + we get batch_size worth of records for a partition it will be sent + immediately regardless of this setting, however if we have fewer + than this many bytes accumulated for this partition we will + 'linger' for the specified time waiting for more records to show + up. This setting defaults to 0 (i.e. no delay). Setting linger_ms=5 + would have the effect of reducing the number of requests sent but + would add up to 5ms of latency to records sent in the absence of + load. Default: 0. + partitioner (callable): Callable used to determine which partition + each message is assigned to. Called (after key serialization): + partitioner(key_bytes, all_partitions, available_partitions). + The default partitioner implementation hashes each non-None key + using the same murmur2 algorithm as the java client so that + messages with the same key are assigned to the same partition. + When a key is None, the message is delivered to a random partition + (filtered to partitions with available leaders only, if possible). + buffer_memory (int): The total bytes of memory the producer should use + to buffer records waiting to be sent to the server. If records are + sent faster than they can be delivered to the server the producer + will block up to max_block_ms, raising an exception on timeout. + In the current implementation, this setting is an approximation. + Default: 33554432 (32MB) + connections_max_idle_ms: Close idle connections after the number of + milliseconds specified by this config. The broker closes idle + connections after connections.max.idle.ms, so this avoids hitting + unexpected socket disconnected errors on the client. + Default: 540000 + max_block_ms (int): Number of milliseconds to block during + :meth:`~kafka.KafkaProducer.send` and + :meth:`~kafka.KafkaProducer.partitions_for`. These methods can be + blocked either because the buffer is full or metadata unavailable. + Blocking in the user-supplied serializers or partitioner will not be + counted against this timeout. Default: 60000. + max_request_size (int): The maximum size of a request. This is also + effectively a cap on the maximum record size. 
Note that the server + has its own cap on record size which may be different from this. + This setting will limit the number of record batches the producer + will send in a single request to avoid sending huge requests. + Default: 1048576. + metadata_max_age_ms (int): The period of time in milliseconds after + which we force a refresh of metadata even if we haven't seen any + partition leadership changes to proactively discover any new + brokers or partitions. Default: 300000 + retry_backoff_ms (int): Milliseconds to backoff when retrying on + errors. Default: 100. + request_timeout_ms (int): Client request timeout in milliseconds. + Default: 30000. + receive_buffer_bytes (int): The size of the TCP receive buffer + (SO_RCVBUF) to use when reading data. Default: None (relies on + system defaults). Java client defaults to 32768. + send_buffer_bytes (int): The size of the TCP send buffer + (SO_SNDBUF) to use when sending data. Default: None (relies on + system defaults). Java client defaults to 131072. + socket_options (list): List of tuple-arguments to socket.setsockopt + to apply to broker connection sockets. Default: + [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)] + reconnect_backoff_ms (int): The amount of time in milliseconds to + wait before attempting to reconnect to a given host. + Default: 50. + reconnect_backoff_max_ms (int): The maximum amount of time in + milliseconds to backoff/wait when reconnecting to a broker that has + repeatedly failed to connect. If provided, the backoff per host + will increase exponentially for each consecutive connection + failure, up to this maximum. Once the maximum is reached, + reconnection attempts will continue periodically with this fixed + rate. To avoid connection storms, a randomization factor of 0.2 + will be applied to the backoff resulting in a random range between + 20% below and 20% above the computed value. Default: 1000. + max_in_flight_requests_per_connection (int): Requests are pipelined + to kafka brokers up to this number of maximum requests per + broker connection. Note that if this setting is set to be greater + than 1 and there are failed sends, there is a risk of message + re-ordering due to retries (i.e., if retries are enabled). + Default: 5. + security_protocol (str): Protocol used to communicate with brokers. + Valid values are: PLAINTEXT, SSL, SASL_PLAINTEXT, SASL_SSL. + Default: PLAINTEXT. + ssl_context (ssl.SSLContext): pre-configured SSLContext for wrapping + socket connections. If provided, all other ssl_* configurations + will be ignored. Default: None. + ssl_check_hostname (bool): flag to configure whether ssl handshake + should verify that the certificate matches the brokers hostname. + default: true. + ssl_cafile (str): optional filename of ca file to use in certificate + verification. default: none. + ssl_certfile (str): optional filename of file in pem format containing + the client certificate, as well as any ca certificates needed to + establish the certificate's authenticity. default: none. + ssl_keyfile (str): optional filename containing the client private key. + default: none. + ssl_password (str): optional password to be used when loading the + certificate chain. default: none. + ssl_crlfile (str): optional filename containing the CRL to check for + certificate expiration. By default, no CRL check is done. When + providing a file, only the leaf certificate will be checked against + this CRL. The CRL can only be checked with Python 3.4+ or 2.7.9+. + default: none. 
+ ssl_ciphers (str): optionally set the available ciphers for ssl + connections. It should be a string in the OpenSSL cipher list + format. If no cipher can be selected (because compile-time options + or other configuration forbids use of all the specified ciphers), + an ssl.SSLError will be raised. See ssl.SSLContext.set_ciphers + api_version (tuple): Specify which Kafka API version to use. If set to + None, the client will attempt to infer the broker version by probing + various APIs. Example: (0, 10, 2). Default: None + api_version_auto_timeout_ms (int): number of milliseconds to throw a + timeout exception from the constructor when checking the broker + api version. Only applies if api_version set to None. + metric_reporters (list): A list of classes to use as metrics reporters. + Implementing the AbstractMetricsReporter interface allows plugging + in classes that will be notified of new metric creation. Default: [] + metrics_num_samples (int): The number of samples maintained to compute + metrics. Default: 2 + metrics_sample_window_ms (int): The maximum age in milliseconds of + samples used to compute metrics. Default: 30000 + selector (selectors.BaseSelector): Provide a specific selector + implementation to use for I/O multiplexing. + Default: selectors.DefaultSelector + sasl_mechanism (str): Authentication mechanism when security_protocol + is configured for SASL_PLAINTEXT or SASL_SSL. Valid values are: + PLAIN, GSSAPI, OAUTHBEARER, SCRAM-SHA-256, SCRAM-SHA-512. + sasl_plain_username (str): username for sasl PLAIN and SCRAM authentication. + Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. + sasl_plain_password (str): password for sasl PLAIN and SCRAM authentication. + Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. + sasl_kerberos_service_name (str): Service name to include in GSSAPI + sasl mechanism handshake. Default: 'kafka' + sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI + sasl mechanism handshake. Default: one of bootstrap servers + sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider + instance. (See kafka.oauth.abstract). 
Default: None + kafka_client (callable): Custom class / callable for creating KafkaClient instances + + Note: + Configuration parameters are described in more detail at + https://kafka.apache.org/0100/documentation/#producerconfigs + """ + DEFAULT_CONFIG = { + 'bootstrap_servers': 'localhost', + 'client_id': None, + 'key_serializer': None, + 'value_serializer': None, + 'acks': 1, + 'bootstrap_topics_filter': set(), + 'compression_type': None, + 'retries': 0, + 'batch_size': 16384, + 'linger_ms': 0, + 'partitioner': DefaultPartitioner(), + 'buffer_memory': 33554432, + 'connections_max_idle_ms': 9 * 60 * 1000, + 'max_block_ms': 60000, + 'max_request_size': 1048576, + 'metadata_max_age_ms': 300000, + 'retry_backoff_ms': 100, + 'request_timeout_ms': 30000, + 'receive_buffer_bytes': None, + 'send_buffer_bytes': None, + 'socket_options': [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)], + 'sock_chunk_bytes': 4096, # undocumented experimental option + 'sock_chunk_buffer_count': 1000, # undocumented experimental option + 'reconnect_backoff_ms': 50, + 'reconnect_backoff_max_ms': 1000, + 'max_in_flight_requests_per_connection': 5, + 'security_protocol': 'PLAINTEXT', + 'ssl_context': None, + 'ssl_check_hostname': True, + 'ssl_cafile': None, + 'ssl_certfile': None, + 'ssl_keyfile': None, + 'ssl_crlfile': None, + 'ssl_password': None, + 'ssl_ciphers': None, + 'api_version': None, + 'api_version_auto_timeout_ms': 2000, + 'metric_reporters': [], + 'metrics_num_samples': 2, + 'metrics_sample_window_ms': 30000, + 'selector': selectors.DefaultSelector, + 'sasl_mechanism': None, + 'sasl_plain_username': None, + 'sasl_plain_password': None, + 'sasl_kerberos_service_name': 'kafka', + 'sasl_kerberos_domain_name': None, + 'sasl_oauth_token_provider': None, + 'kafka_client': KafkaClient, + } + + _COMPRESSORS = { + 'gzip': (has_gzip, LegacyRecordBatchBuilder.CODEC_GZIP), + 'snappy': (has_snappy, LegacyRecordBatchBuilder.CODEC_SNAPPY), + 'lz4': (has_lz4, LegacyRecordBatchBuilder.CODEC_LZ4), + 'zstd': (has_zstd, DefaultRecordBatchBuilder.CODEC_ZSTD), + None: (lambda: True, LegacyRecordBatchBuilder.CODEC_NONE), + } + + def __init__(self, **configs): + log.debug("Starting the Kafka producer") # trace + self.config = copy.copy(self.DEFAULT_CONFIG) + for key in self.config: + if key in configs: + self.config[key] = configs.pop(key) + + # Only check for extra config keys in top-level class + assert not configs, 'Unrecognized configs: %s' % (configs,) + + if self.config['client_id'] is None: + self.config['client_id'] = 'kafka-python-producer-%s' % \ + (PRODUCER_CLIENT_ID_SEQUENCE.increment(),) + + if self.config['acks'] == 'all': + self.config['acks'] = -1 + + # api_version was previously a str. 
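# --- Illustrative sketch, not part of the vendored patch -----------------
# Typical construction mirrors DEFAULT_CONFIG above: unrecognized keys are
# rejected, acks='all' is normalized to -1, and the chosen compression type
# is checked against the installed codec libraries.
import json

from kafka import KafkaProducer

producer = KafkaProducer(
    bootstrap_servers=['broker1:9092', 'broker2:9092'],
    acks='all',                       # stored internally as -1
    compression_type='gzip',
    linger_ms=5,                      # trade a little latency for better batching
    value_serializer=lambda v: json.dumps(v).encode('utf-8'),
)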
accept old format for now + if isinstance(self.config['api_version'], str): + deprecated = self.config['api_version'] + if deprecated == 'auto': + self.config['api_version'] = None + else: + self.config['api_version'] = tuple(map(int, deprecated.split('.'))) + log.warning('use api_version=%s [tuple] -- "%s" as str is deprecated', + str(self.config['api_version']), deprecated) + + # Configure metrics + metrics_tags = {'client-id': self.config['client_id']} + metric_config = MetricConfig(samples=self.config['metrics_num_samples'], + time_window_ms=self.config['metrics_sample_window_ms'], + tags=metrics_tags) + reporters = [reporter() for reporter in self.config['metric_reporters']] + self._metrics = Metrics(metric_config, reporters) + + client = self.config['kafka_client']( + metrics=self._metrics, metric_group_prefix='producer', + wakeup_timeout_ms=self.config['max_block_ms'], + **self.config) + + # Get auto-discovered version from client if necessary + if self.config['api_version'] is None: + self.config['api_version'] = client.config['api_version'] + + if self.config['compression_type'] == 'lz4': + assert self.config['api_version'] >= (0, 8, 2), 'LZ4 Requires >= Kafka 0.8.2 Brokers' + + if self.config['compression_type'] == 'zstd': + assert self.config['api_version'] >= (2, 1, 0), 'Zstd Requires >= Kafka 2.1.0 Brokers' + + # Check compression_type for library support + ct = self.config['compression_type'] + if ct not in self._COMPRESSORS: + raise ValueError("Not supported codec: {}".format(ct)) + else: + checker, compression_attrs = self._COMPRESSORS[ct] + assert checker(), "Libraries for {} compression codec not found".format(ct) + self.config['compression_attrs'] = compression_attrs + + message_version = self._max_usable_produce_magic() + self._accumulator = RecordAccumulator(message_version=message_version, metrics=self._metrics, **self.config) + self._metadata = client.cluster + guarantee_message_order = bool(self.config['max_in_flight_requests_per_connection'] == 1) + self._sender = Sender(client, self._metadata, + self._accumulator, self._metrics, + guarantee_message_order=guarantee_message_order, + **self.config) + self._sender.daemon = True + self._sender.start() + self._closed = False + + self._cleanup = self._cleanup_factory() + atexit.register(self._cleanup) + log.debug("Kafka producer started") + + def bootstrap_connected(self): + """Return True if the bootstrap is connected.""" + return self._sender.bootstrap_connected() + + def _cleanup_factory(self): + """Build a cleanup clojure that doesn't increase our ref count""" + _self = weakref.proxy(self) + def wrapper(): + try: + _self.close(timeout=0) + except (ReferenceError, AttributeError): + pass + return wrapper + + def _unregister_cleanup(self): + if getattr(self, '_cleanup', None): + if hasattr(atexit, 'unregister'): + atexit.unregister(self._cleanup) # pylint: disable=no-member + + # py2 requires removing from private attribute... + else: + + # ValueError on list.remove() if the exithandler no longer exists + # but that is fine here + try: + atexit._exithandlers.remove( # pylint: disable=no-member + (self._cleanup, (), {})) + except ValueError: + pass + self._cleanup = None + + def __del__(self): + # Disable logger during destruction to avoid touching dangling references + class NullLogger(object): + def __getattr__(self, name): + return lambda *args: None + + global log + log = NullLogger() + + self.close() + + def close(self, timeout=None): + """Close this producer. 
+ + Arguments: + timeout (float, optional): timeout in seconds to wait for completion. + """ + + # drop our atexit handler now to avoid leaks + self._unregister_cleanup() + + if not hasattr(self, '_closed') or self._closed: + log.info('Kafka producer closed') + return + if timeout is None: + # threading.TIMEOUT_MAX is available in Python3.3+ + timeout = getattr(threading, 'TIMEOUT_MAX', float('inf')) + if getattr(threading, 'TIMEOUT_MAX', False): + assert 0 <= timeout <= getattr(threading, 'TIMEOUT_MAX') + else: + assert timeout >= 0 + + log.info("Closing the Kafka producer with %s secs timeout.", timeout) + invoked_from_callback = bool(threading.current_thread() is self._sender) + if timeout > 0: + if invoked_from_callback: + log.warning("Overriding close timeout %s secs to 0 in order to" + " prevent useless blocking due to self-join. This" + " means you have incorrectly invoked close with a" + " non-zero timeout from the producer call-back.", + timeout) + else: + # Try to close gracefully. + if self._sender is not None: + self._sender.initiate_close() + self._sender.join(timeout) + + if self._sender is not None and self._sender.is_alive(): + log.info("Proceeding to force close the producer since pending" + " requests could not be completed within timeout %s.", + timeout) + self._sender.force_close() + + self._metrics.close() + try: + self.config['key_serializer'].close() + except AttributeError: + pass + try: + self.config['value_serializer'].close() + except AttributeError: + pass + self._closed = True + log.debug("The Kafka producer has closed.") + + def partitions_for(self, topic): + """Returns set of all known partitions for the topic.""" + max_wait = self.config['max_block_ms'] / 1000.0 + return self._wait_on_metadata(topic, max_wait) + + def _max_usable_produce_magic(self): + if self.config['api_version'] >= (0, 11): + return 2 + elif self.config['api_version'] >= (0, 10): + return 1 + else: + return 0 + + def _estimate_size_in_bytes(self, key, value, headers=[]): + magic = self._max_usable_produce_magic() + if magic == 2: + return DefaultRecordBatchBuilder.estimate_size_in_bytes( + key, value, headers) + else: + return LegacyRecordBatchBuilder.estimate_size_in_bytes( + magic, self.config['compression_type'], key, value) + + def send(self, topic, value=None, key=None, headers=None, partition=None, timestamp_ms=None): + """Publish a message to a topic. + + Arguments: + topic (str): topic where the message will be published + value (optional): message value. Must be type bytes, or be + serializable to bytes via configured value_serializer. If value + is None, key is required and message acts as a 'delete'. + See kafka compaction documentation for more details: + https://kafka.apache.org/documentation.html#compaction + (compaction requires kafka >= 0.8.1) + partition (int, optional): optionally specify a partition. If not + set, the partition will be selected using the configured + 'partitioner'. + key (optional): a key to associate with the message. Can be used to + determine which partition to send the message to. If partition + is None (and producer's partitioner config is left as default), + then messages with the same key will be delivered to the same + partition (but if key is None, partition is chosen randomly). + Must be type bytes, or be serializable to bytes via configured + key_serializer. + headers (optional): a list of header key value pairs. List items + are tuples of str key and bytes value. 
+ timestamp_ms (int, optional): epoch milliseconds (from Jan 1 1970 UTC) + to use as the message timestamp. Defaults to current time. + + Returns: + FutureRecordMetadata: resolves to RecordMetadata + + Raises: + KafkaTimeoutError: if unable to fetch topic metadata, or unable + to obtain memory buffer prior to configured max_block_ms + """ + assert value is not None or self.config['api_version'] >= (0, 8, 1), ( + 'Null messages require kafka >= 0.8.1') + assert not (value is None and key is None), 'Need at least one: key or value' + key_bytes = value_bytes = None + try: + self._wait_on_metadata(topic, self.config['max_block_ms'] / 1000.0) + + key_bytes = self._serialize( + self.config['key_serializer'], + topic, key) + value_bytes = self._serialize( + self.config['value_serializer'], + topic, value) + assert type(key_bytes) in (bytes, bytearray, memoryview, type(None)) + assert type(value_bytes) in (bytes, bytearray, memoryview, type(None)) + + partition = self._partition(topic, partition, key, value, + key_bytes, value_bytes) + + if headers is None: + headers = [] + assert type(headers) == list + assert all(type(item) == tuple and len(item) == 2 and type(item[0]) == str and type(item[1]) == bytes for item in headers) + + message_size = self._estimate_size_in_bytes(key_bytes, value_bytes, headers) + self._ensure_valid_record_size(message_size) + + tp = TopicPartition(topic, partition) + log.debug("Sending (key=%r value=%r headers=%r) to %s", key, value, headers, tp) + result = self._accumulator.append(tp, timestamp_ms, + key_bytes, value_bytes, headers, + self.config['max_block_ms'], + estimated_size=message_size) + future, batch_is_full, new_batch_created = result + if batch_is_full or new_batch_created: + log.debug("Waking up the sender since %s is either full or" + " getting a new batch", tp) + self._sender.wakeup() + + return future + # handling exceptions and record the errors; + # for API exceptions return them in the future, + # for other exceptions raise directly + except Errors.BrokerResponseError as e: + log.debug("Exception occurred during message send: %s", e) + return FutureRecordMetadata( + FutureProduceResult(TopicPartition(topic, partition)), + -1, None, None, + len(key_bytes) if key_bytes is not None else -1, + len(value_bytes) if value_bytes is not None else -1, + sum(len(h_key.encode("utf-8")) + len(h_value) for h_key, h_value in headers) if headers else -1, + ).failure(e) + + def flush(self, timeout=None): + """ + Invoking this method makes all buffered records immediately available + to send (even if linger_ms is greater than 0) and blocks on the + completion of the requests associated with these records. The + post-condition of :meth:`~kafka.KafkaProducer.flush` is that any + previously sent record will have completed + (e.g. Future.is_done() == True). A request is considered completed when + either it is successfully acknowledged according to the 'acks' + configuration for the producer, or it results in an error. + + Other threads can continue sending messages while one thread is blocked + waiting for a flush call to complete; however, no guarantee is made + about the completion of messages sent after the flush call begins. + + Arguments: + timeout (float, optional): timeout in seconds to wait for completion. 
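# --- Illustrative sketch, not part of the vendored patch -----------------
# send() returns a FutureRecordMetadata immediately; get() blocks until the
# broker acknowledgement arrives, and flush() drains anything still
# buffered. Assumes a 'producer' constructed as above and an existing
# 'my-topic'.
future = producer.send(
    'my-topic',
    key=b'user-42',                   # already bytes; no key_serializer configured
    value={'event': 'login'},         # run through the configured value_serializer
    headers=[('source', b'web')],
)
record_metadata = future.get(timeout=10)   # raises on error or timeout
print(record_metadata.topic, record_metadata.partition, record_metadata.offset)

producer.flush()                      # block until all buffered records are sent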
+ + Raises: + KafkaTimeoutError: failure to flush buffered records within the + provided timeout + """ + log.debug("Flushing accumulated records in producer.") # trace + self._accumulator.begin_flush() + self._sender.wakeup() + self._accumulator.await_flush_completion(timeout=timeout) + + def _ensure_valid_record_size(self, size): + """Validate that the record size isn't too large.""" + if size > self.config['max_request_size']: + raise Errors.MessageSizeTooLargeError( + "The message is %d bytes when serialized which is larger than" + " the maximum request size you have configured with the" + " max_request_size configuration" % (size,)) + if size > self.config['buffer_memory']: + raise Errors.MessageSizeTooLargeError( + "The message is %d bytes when serialized which is larger than" + " the total memory buffer you have configured with the" + " buffer_memory configuration." % (size,)) + + def _wait_on_metadata(self, topic, max_wait): + """ + Wait for cluster metadata including partitions for the given topic to + be available. + + Arguments: + topic (str): topic we want metadata for + max_wait (float): maximum time in secs for waiting on the metadata + + Returns: + set: partition ids for the topic + + Raises: + KafkaTimeoutError: if partitions for topic were not obtained before + specified max_wait timeout + """ + # add topic to metadata topic list if it is not there already. + self._sender.add_topic(topic) + begin = time.time() + elapsed = 0.0 + metadata_event = None + while True: + partitions = self._metadata.partitions_for_topic(topic) + if partitions is not None: + return partitions + + if not metadata_event: + metadata_event = threading.Event() + + log.debug("Requesting metadata update for topic %s", topic) + + metadata_event.clear() + future = self._metadata.request_update() + future.add_both(lambda e, *args: e.set(), metadata_event) + self._sender.wakeup() + metadata_event.wait(max_wait - elapsed) + elapsed = time.time() - begin + if not metadata_event.is_set(): + raise Errors.KafkaTimeoutError( + "Failed to update metadata after %.1f secs." % (max_wait,)) + elif topic in self._metadata.unauthorized_topics: + raise Errors.TopicAuthorizationFailedError(topic) + else: + log.debug("_wait_on_metadata woke after %s secs.", elapsed) + + def _serialize(self, f, topic, data): + if not f: + return data + if isinstance(f, Serializer): + return f.serialize(topic, data) + return f(data) + + def _partition(self, topic, partition, key, value, + serialized_key, serialized_value): + if partition is not None: + assert partition >= 0 + assert partition in self._metadata.partitions_for_topic(topic), 'Unrecognized partition' + return partition + + all_partitions = sorted(self._metadata.partitions_for_topic(topic)) + available = list(self._metadata.available_partitions_for_topic(topic)) + return self.config['partitioner'](serialized_key, + all_partitions, + available) + + def metrics(self, raw=False): + """Get metrics on producer performance. + + This is ported from the Java Producer, for details see: + https://kafka.apache.org/documentation/#producer_monitoring + + Warning: + This is an unstable interface. It may change in future + releases without warning. 
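# --- Illustrative sketch, not part of the vendored patch -----------------
# metrics() flattens the registry into {group: {name: value}} dictionaries.
# Exact group and metric names follow the Java producer and may vary between
# versions, so treat the keys as informational.
snapshot = producer.metrics()
for group, by_name in snapshot.items():
    for name, value in by_name.items():
        print('%s / %s = %s' % (group, name, value))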
+ """ + if raw: + return self._metrics.metrics.copy() + + metrics = {} + for k, v in six.iteritems(self._metrics.metrics.copy()): + if k.group not in metrics: + metrics[k.group] = {} + if k.name not in metrics[k.group]: + metrics[k.group][k.name] = {} + metrics[k.group][k.name] = v.value() + return metrics diff --git a/producer/record_accumulator.py b/producer/record_accumulator.py new file mode 100644 index 00000000..a2aa0e8e --- /dev/null +++ b/producer/record_accumulator.py @@ -0,0 +1,590 @@ +from __future__ import absolute_import + +import collections +import copy +import logging +import threading +import time + +import kafka.errors as Errors +from kafka.producer.buffer import SimpleBufferPool +from kafka.producer.future import FutureRecordMetadata, FutureProduceResult +from kafka.record.memory_records import MemoryRecordsBuilder +from kafka.structs import TopicPartition + + +log = logging.getLogger(__name__) + + +class AtomicInteger(object): + def __init__(self, val=0): + self._lock = threading.Lock() + self._val = val + + def increment(self): + with self._lock: + self._val += 1 + return self._val + + def decrement(self): + with self._lock: + self._val -= 1 + return self._val + + def get(self): + return self._val + + +class ProducerBatch(object): + def __init__(self, tp, records, buffer): + self.max_record_size = 0 + now = time.time() + self.created = now + self.drained = None + self.attempts = 0 + self.last_attempt = now + self.last_append = now + self.records = records + self.topic_partition = tp + self.produce_future = FutureProduceResult(tp) + self._retry = False + self._buffer = buffer # We only save it, we don't write to it + + @property + def record_count(self): + return self.records.next_offset() + + def try_append(self, timestamp_ms, key, value, headers): + metadata = self.records.append(timestamp_ms, key, value, headers) + if metadata is None: + return None + + self.max_record_size = max(self.max_record_size, metadata.size) + self.last_append = time.time() + future = FutureRecordMetadata(self.produce_future, metadata.offset, + metadata.timestamp, metadata.crc, + len(key) if key is not None else -1, + len(value) if value is not None else -1, + sum(len(h_key.encode("utf-8")) + len(h_val) for h_key, h_val in headers) if headers else -1) + return future + + def done(self, base_offset=None, timestamp_ms=None, exception=None, log_start_offset=None, global_error=None): + level = logging.DEBUG if exception is None else logging.WARNING + log.log(level, "Produced messages to topic-partition %s with base offset" + " %s log start offset %s and error %s.", self.topic_partition, base_offset, + log_start_offset, global_error) # trace + if self.produce_future.is_done: + log.warning('Batch is already closed -- ignoring batch.done()') + return + elif exception is None: + self.produce_future.success((base_offset, timestamp_ms, log_start_offset)) + else: + self.produce_future.failure(exception) + + def maybe_expire(self, request_timeout_ms, retry_backoff_ms, linger_ms, is_full): + """Expire batches if metadata is not available + + A batch whose metadata is not available should be expired if one + of the following is true: + + * the batch is not in retry AND request timeout has elapsed after + it is ready (full or linger.ms has reached). + + * the batch is in retry AND request timeout has elapsed after the + backoff period ended. 
+ """ + now = time.time() + since_append = now - self.last_append + since_ready = now - (self.created + linger_ms / 1000.0) + since_backoff = now - (self.last_attempt + retry_backoff_ms / 1000.0) + timeout = request_timeout_ms / 1000.0 + + error = None + if not self.in_retry() and is_full and timeout < since_append: + error = "%d seconds have passed since last append" % (since_append,) + elif not self.in_retry() and timeout < since_ready: + error = "%d seconds have passed since batch creation plus linger time" % (since_ready,) + elif self.in_retry() and timeout < since_backoff: + error = "%d seconds have passed since last attempt plus backoff time" % (since_backoff,) + + if error: + self.records.close() + self.done(-1, None, Errors.KafkaTimeoutError( + "Batch for %s containing %s record(s) expired: %s" % ( + self.topic_partition, self.records.next_offset(), error))) + return True + return False + + def in_retry(self): + return self._retry + + def set_retry(self): + self._retry = True + + def buffer(self): + return self._buffer + + def __str__(self): + return 'ProducerBatch(topic_partition=%s, record_count=%d)' % ( + self.topic_partition, self.records.next_offset()) + + +class RecordAccumulator(object): + """ + This class maintains a dequeue per TopicPartition that accumulates messages + into MessageSets to be sent to the server. + + The accumulator attempts to bound memory use, and append calls will block + when that memory is exhausted. + + Keyword Arguments: + batch_size (int): Requests sent to brokers will contain multiple + batches, one for each partition with data available to be sent. + A small batch size will make batching less common and may reduce + throughput (a batch size of zero will disable batching entirely). + Default: 16384 + buffer_memory (int): The total bytes of memory the producer should use + to buffer records waiting to be sent to the server. If records are + sent faster than they can be delivered to the server the producer + will block up to max_block_ms, raising an exception on timeout. + In the current implementation, this setting is an approximation. + Default: 33554432 (32MB) + compression_attrs (int): The compression type for all data generated by + the producer. Valid values are gzip(1), snappy(2), lz4(3), or + none(0). + Compression is of full batches of data, so the efficacy of batching + will also impact the compression ratio (more batching means better + compression). Default: None. + linger_ms (int): An artificial delay time to add before declaring a + messageset (that isn't full) ready for sending. This allows + time for more records to arrive. Setting a non-zero linger_ms + will trade off some latency for potentially better throughput + due to more batching (and hence fewer, larger requests). + Default: 0 + retry_backoff_ms (int): An artificial delay time to retry the + produce request upon receiving an error. This avoids exhausting + all retries in a short period of time. 
Default: 100 + """ + DEFAULT_CONFIG = { + 'buffer_memory': 33554432, + 'batch_size': 16384, + 'compression_attrs': 0, + 'linger_ms': 0, + 'retry_backoff_ms': 100, + 'message_version': 0, + 'metrics': None, + 'metric_group_prefix': 'producer-metrics', + } + + def __init__(self, **configs): + self.config = copy.copy(self.DEFAULT_CONFIG) + for key in self.config: + if key in configs: + self.config[key] = configs.pop(key) + + self._closed = False + self._flushes_in_progress = AtomicInteger() + self._appends_in_progress = AtomicInteger() + self._batches = collections.defaultdict(collections.deque) # TopicPartition: [ProducerBatch] + self._tp_locks = {None: threading.Lock()} # TopicPartition: Lock, plus a lock to add entries + self._free = SimpleBufferPool(self.config['buffer_memory'], + self.config['batch_size'], + metrics=self.config['metrics'], + metric_group_prefix=self.config['metric_group_prefix']) + self._incomplete = IncompleteProducerBatches() + # The following variables should only be accessed by the sender thread, + # so we don't need to protect them w/ locking. + self.muted = set() + self._drain_index = 0 + + def append(self, tp, timestamp_ms, key, value, headers, max_time_to_block_ms, + estimated_size=0): + """Add a record to the accumulator, return the append result. + + The append result will contain the future metadata, and flag for + whether the appended batch is full or a new batch is created + + Arguments: + tp (TopicPartition): The topic/partition to which this record is + being sent + timestamp_ms (int): The timestamp of the record (epoch ms) + key (bytes): The key for the record + value (bytes): The value for the record + headers (List[Tuple[str, bytes]]): The header fields for the record + max_time_to_block_ms (int): The maximum time in milliseconds to + block for buffer memory to be available + + Returns: + tuple: (future, batch_is_full, new_batch_created) + """ + assert isinstance(tp, TopicPartition), 'not TopicPartition' + assert not self._closed, 'RecordAccumulator is closed' + # We keep track of the number of appending thread to make sure we do + # not miss batches in abortIncompleteBatches(). + self._appends_in_progress.increment() + try: + if tp not in self._tp_locks: + with self._tp_locks[None]: + if tp not in self._tp_locks: + self._tp_locks[tp] = threading.Lock() + + with self._tp_locks[tp]: + # check if we have an in-progress batch + dq = self._batches[tp] + if dq: + last = dq[-1] + future = last.try_append(timestamp_ms, key, value, headers) + if future is not None: + batch_is_full = len(dq) > 1 or last.records.is_full() + return future, batch_is_full, False + + size = max(self.config['batch_size'], estimated_size) + log.debug("Allocating a new %d byte message buffer for %s", size, tp) # trace + buf = self._free.allocate(size, max_time_to_block_ms) + with self._tp_locks[tp]: + # Need to check if producer is closed again after grabbing the + # dequeue lock. + assert not self._closed, 'RecordAccumulator is closed' + + if dq: + last = dq[-1] + future = last.try_append(timestamp_ms, key, value, headers) + if future is not None: + # Somebody else found us a batch, return the one we + # waited for! Hopefully this doesn't happen often... 
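+                        # The buffer we allocated above is no longer needed,
+                        # so hand it straight back to the free pool.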
+ self._free.deallocate(buf) + batch_is_full = len(dq) > 1 or last.records.is_full() + return future, batch_is_full, False + + records = MemoryRecordsBuilder( + self.config['message_version'], + self.config['compression_attrs'], + self.config['batch_size'] + ) + + batch = ProducerBatch(tp, records, buf) + future = batch.try_append(timestamp_ms, key, value, headers) + if not future: + raise Exception() + + dq.append(batch) + self._incomplete.add(batch) + batch_is_full = len(dq) > 1 or batch.records.is_full() + return future, batch_is_full, True + finally: + self._appends_in_progress.decrement() + + def abort_expired_batches(self, request_timeout_ms, cluster): + """Abort the batches that have been sitting in RecordAccumulator for + more than the configured request_timeout due to metadata being + unavailable. + + Arguments: + request_timeout_ms (int): milliseconds to timeout + cluster (ClusterMetadata): current metadata for kafka cluster + + Returns: + list of ProducerBatch that were expired + """ + expired_batches = [] + to_remove = [] + count = 0 + for tp in list(self._batches.keys()): + assert tp in self._tp_locks, 'TopicPartition not in locks dict' + + # We only check if the batch should be expired if the partition + # does not have a batch in flight. This is to avoid the later + # batches get expired when an earlier batch is still in progress. + # This protection only takes effect when user sets + # max.in.flight.request.per.connection=1. Otherwise the expiration + # order is not guranteed. + if tp in self.muted: + continue + + with self._tp_locks[tp]: + # iterate over the batches and expire them if they have stayed + # in accumulator for more than request_timeout_ms + dq = self._batches[tp] + for batch in dq: + is_full = bool(bool(batch != dq[-1]) or batch.records.is_full()) + # check if the batch is expired + if batch.maybe_expire(request_timeout_ms, + self.config['retry_backoff_ms'], + self.config['linger_ms'], + is_full): + expired_batches.append(batch) + to_remove.append(batch) + count += 1 + self.deallocate(batch) + else: + # Stop at the first batch that has not expired. + break + + # Python does not allow us to mutate the dq during iteration + # Assuming expired batches are infrequent, this is better than + # creating a new copy of the deque for iteration on every loop + if to_remove: + for batch in to_remove: + dq.remove(batch) + to_remove = [] + + if expired_batches: + log.warning("Expired %d batches in accumulator", count) # trace + + return expired_batches + + def reenqueue(self, batch): + """Re-enqueue the given record batch in the accumulator to retry.""" + now = time.time() + batch.attempts += 1 + batch.last_attempt = now + batch.last_append = now + batch.set_retry() + assert batch.topic_partition in self._tp_locks, 'TopicPartition not in locks dict' + assert batch.topic_partition in self._batches, 'TopicPartition not in batches' + dq = self._batches[batch.topic_partition] + with self._tp_locks[batch.topic_partition]: + dq.appendleft(batch) + + def ready(self, cluster): + """ + Get a list of nodes whose partitions are ready to be sent, and the + earliest time at which any non-sendable partition will be ready; + Also return the flag for whether there are any unknown leaders for the + accumulated partition batches. 
+ + A destination node is ready to send if: + + * There is at least one partition that is not backing off its send + * and those partitions are not muted (to prevent reordering if + max_in_flight_requests_per_connection is set to 1) + * and any of the following are true: + + * The record set is full + * The record set has sat in the accumulator for at least linger_ms + milliseconds + * The accumulator is out of memory and threads are blocking waiting + for data (in this case all partitions are immediately considered + ready). + * The accumulator has been closed + + Arguments: + cluster (ClusterMetadata): + + Returns: + tuple: + ready_nodes (set): node_ids that have ready batches + next_ready_check (float): secs until next ready after backoff + unknown_leaders_exist (bool): True if metadata refresh needed + """ + ready_nodes = set() + next_ready_check = 9999999.99 + unknown_leaders_exist = False + now = time.time() + + exhausted = bool(self._free.queued() > 0) + # several threads are accessing self._batches -- to simplify + # concurrent access, we iterate over a snapshot of partitions + # and lock each partition separately as needed + partitions = list(self._batches.keys()) + for tp in partitions: + leader = cluster.leader_for_partition(tp) + if leader is None or leader == -1: + unknown_leaders_exist = True + continue + elif leader in ready_nodes: + continue + elif tp in self.muted: + continue + + with self._tp_locks[tp]: + dq = self._batches[tp] + if not dq: + continue + batch = dq[0] + retry_backoff = self.config['retry_backoff_ms'] / 1000.0 + linger = self.config['linger_ms'] / 1000.0 + backing_off = bool(batch.attempts > 0 and + batch.last_attempt + retry_backoff > now) + waited_time = now - batch.last_attempt + time_to_wait = retry_backoff if backing_off else linger + time_left = max(time_to_wait - waited_time, 0) + full = bool(len(dq) > 1 or batch.records.is_full()) + expired = bool(waited_time >= time_to_wait) + + sendable = (full or expired or exhausted or self._closed or + self._flush_in_progress()) + + if sendable and not backing_off: + ready_nodes.add(leader) + else: + # Note that this results in a conservative estimate since + # an un-sendable partition may have a leader that will + # later be found to have sendable data. However, this is + # good enough since we'll just wake up and then sleep again + # for the remaining time. + next_ready_check = min(time_left, next_ready_check) + + return ready_nodes, next_ready_check, unknown_leaders_exist + + def has_unsent(self): + """Return whether there is any unsent record in the accumulator.""" + for tp in list(self._batches.keys()): + with self._tp_locks[tp]: + dq = self._batches[tp] + if len(dq): + return True + return False + + def drain(self, cluster, nodes, max_size): + """ + Drain all the data for the given nodes and collate them into a list of + batches that will fit within the specified size on a per-node basis. + This method attempts to avoid choosing the same topic-node repeatedly. + + Arguments: + cluster (ClusterMetadata): The current cluster metadata + nodes (list): list of node_ids to drain + max_size (int): maximum number of bytes to drain + + Returns: + dict: {node_id: list of ProducerBatch} with total size less than the + requested max_size. 
+ """ + if not nodes: + return {} + + now = time.time() + batches = {} + for node_id in nodes: + size = 0 + partitions = list(cluster.partitions_for_broker(node_id)) + ready = [] + # to make starvation less likely this loop doesn't start at 0 + self._drain_index %= len(partitions) + start = self._drain_index + while True: + tp = partitions[self._drain_index] + if tp in self._batches and tp not in self.muted: + with self._tp_locks[tp]: + dq = self._batches[tp] + if dq: + first = dq[0] + backoff = ( + bool(first.attempts > 0) and + bool(first.last_attempt + + self.config['retry_backoff_ms'] / 1000.0 + > now) + ) + # Only drain the batch if it is not during backoff + if not backoff: + if (size + first.records.size_in_bytes() > max_size + and len(ready) > 0): + # there is a rare case that a single batch + # size is larger than the request size due + # to compression; in this case we will + # still eventually send this batch in a + # single request + break + else: + batch = dq.popleft() + batch.records.close() + size += batch.records.size_in_bytes() + ready.append(batch) + batch.drained = now + + self._drain_index += 1 + self._drain_index %= len(partitions) + if start == self._drain_index: + break + + batches[node_id] = ready + return batches + + def deallocate(self, batch): + """Deallocate the record batch.""" + self._incomplete.remove(batch) + self._free.deallocate(batch.buffer()) + + def _flush_in_progress(self): + """Are there any threads currently waiting on a flush?""" + return self._flushes_in_progress.get() > 0 + + def begin_flush(self): + """ + Initiate the flushing of data from the accumulator...this makes all + requests immediately ready + """ + self._flushes_in_progress.increment() + + def await_flush_completion(self, timeout=None): + """ + Mark all partitions as ready to send and block until the send is complete + """ + try: + for batch in self._incomplete.all(): + log.debug('Waiting on produce to %s', + batch.produce_future.topic_partition) + if not batch.produce_future.wait(timeout=timeout): + raise Errors.KafkaTimeoutError('Timeout waiting for future') + if not batch.produce_future.is_done: + raise Errors.UnknownError('Future not done') + + if batch.produce_future.failed(): + log.warning(batch.produce_future.exception) + finally: + self._flushes_in_progress.decrement() + + def abort_incomplete_batches(self): + """ + This function is only called when sender is closed forcefully. It will fail all the + incomplete batches and return. + """ + # We need to keep aborting the incomplete batch until no thread is trying to append to + # 1. Avoid losing batches. + # 2. Free up memory in case appending threads are blocked on buffer full. + # This is a tight loop but should be able to get through very quickly. + while True: + self._abort_batches() + if not self._appends_in_progress.get(): + break + # After this point, no thread will append any messages because they will see the close + # flag set. We need to do the last abort after no thread was appending in case the there was a new + # batch appended by the last appending thread. 
+ self._abort_batches() + self._batches.clear() + + def _abort_batches(self): + """Go through incomplete batches and abort them.""" + error = Errors.IllegalStateError("Producer is closed forcefully.") + for batch in self._incomplete.all(): + tp = batch.topic_partition + # Close the batch before aborting + with self._tp_locks[tp]: + batch.records.close() + batch.done(exception=error) + self.deallocate(batch) + + def close(self): + """Close this accumulator and force all the record buffers to be drained.""" + self._closed = True + + +class IncompleteProducerBatches(object): + """A threadsafe helper class to hold ProducerBatches that haven't been ack'd yet""" + + def __init__(self): + self._incomplete = set() + self._lock = threading.Lock() + + def add(self, batch): + with self._lock: + return self._incomplete.add(batch) + + def remove(self, batch): + with self._lock: + return self._incomplete.remove(batch) + + def all(self): + with self._lock: + return list(self._incomplete) diff --git a/producer/sender.py b/producer/sender.py new file mode 100644 index 00000000..35688d3f --- /dev/null +++ b/producer/sender.py @@ -0,0 +1,517 @@ +from __future__ import absolute_import, division + +import collections +import copy +import logging +import threading +import time + +from kafka.vendor import six + +from kafka import errors as Errors +from kafka.metrics.measurable import AnonMeasurable +from kafka.metrics.stats import Avg, Max, Rate +from kafka.protocol.produce import ProduceRequest +from kafka.structs import TopicPartition +from kafka.version import __version__ + +log = logging.getLogger(__name__) + + +class Sender(threading.Thread): + """ + The background thread that handles the sending of produce requests to the + Kafka cluster. This thread makes metadata requests to renew its view of the + cluster and then sends produce requests to the appropriate nodes. + """ + DEFAULT_CONFIG = { + 'max_request_size': 1048576, + 'acks': 1, + 'retries': 0, + 'request_timeout_ms': 30000, + 'guarantee_message_order': False, + 'client_id': 'kafka-python-' + __version__, + 'api_version': (0, 8, 0), + } + + def __init__(self, client, metadata, accumulator, metrics, **configs): + super(Sender, self).__init__() + self.config = copy.copy(self.DEFAULT_CONFIG) + for key in self.config: + if key in configs: + self.config[key] = configs.pop(key) + + self.name = self.config['client_id'] + '-network-thread' + self._client = client + self._accumulator = accumulator + self._metadata = client.cluster + self._running = True + self._force_close = False + self._topics_to_add = set() + self._sensors = SenderMetrics(metrics, self._client, self._metadata) + + def run(self): + """The main run loop for the sender thread.""" + log.debug("Starting Kafka producer I/O thread.") + + # main loop, runs until close is called + while self._running: + try: + self.run_once() + except Exception: + log.exception("Uncaught error in kafka producer I/O thread") + + log.debug("Beginning shutdown of Kafka producer I/O thread, sending" + " remaining records.") + + # okay we stopped accepting requests but there may still be + # requests in the accumulator or waiting for acknowledgment, + # wait until these are completed. 
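+        # Note that a forced close skips this draining phase: the loop below
+        # exits immediately once _force_close is set.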
+ while (not self._force_close + and (self._accumulator.has_unsent() + or self._client.in_flight_request_count() > 0)): + try: + self.run_once() + except Exception: + log.exception("Uncaught error in kafka producer I/O thread") + + if self._force_close: + # We need to fail all the incomplete batches and wake up the + # threads waiting on the futures. + self._accumulator.abort_incomplete_batches() + + try: + self._client.close() + except Exception: + log.exception("Failed to close network client") + + log.debug("Shutdown of Kafka producer I/O thread has completed.") + + def run_once(self): + """Run a single iteration of sending.""" + while self._topics_to_add: + self._client.add_topic(self._topics_to_add.pop()) + + # get the list of partitions with data ready to send + result = self._accumulator.ready(self._metadata) + ready_nodes, next_ready_check_delay, unknown_leaders_exist = result + + # if there are any partitions whose leaders are not known yet, force + # metadata update + if unknown_leaders_exist: + log.debug('Unknown leaders exist, requesting metadata update') + self._metadata.request_update() + + # remove any nodes we aren't ready to send to + not_ready_timeout = float('inf') + for node in list(ready_nodes): + if not self._client.is_ready(node): + log.debug('Node %s not ready; delaying produce of accumulated batch', node) + self._client.maybe_connect(node, wakeup=False) + ready_nodes.remove(node) + not_ready_timeout = min(not_ready_timeout, + self._client.connection_delay(node)) + + # create produce requests + batches_by_node = self._accumulator.drain( + self._metadata, ready_nodes, self.config['max_request_size']) + + if self.config['guarantee_message_order']: + # Mute all the partitions drained + for batch_list in six.itervalues(batches_by_node): + for batch in batch_list: + self._accumulator.muted.add(batch.topic_partition) + + expired_batches = self._accumulator.abort_expired_batches( + self.config['request_timeout_ms'], self._metadata) + for expired_batch in expired_batches: + self._sensors.record_errors(expired_batch.topic_partition.topic, expired_batch.record_count) + + self._sensors.update_produce_request_metrics(batches_by_node) + requests = self._create_produce_requests(batches_by_node) + # If we have any nodes that are ready to send + have sendable data, + # poll with 0 timeout so this can immediately loop and try sending more + # data. Otherwise, the timeout is determined by nodes that have + # partitions with data that isn't yet sendable (e.g. lingering, backing + # off). Note that this specifically does not include nodes with + # sendable data that aren't ready to send since they would cause busy + # looping. 
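+        # next_ready_check_delay is in seconds (see RecordAccumulator.ready()),
+        # while not_ready_timeout comes from connection_delay() and is already
+        # in milliseconds.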
+ poll_timeout_ms = min(next_ready_check_delay * 1000, not_ready_timeout) + if ready_nodes: + log.debug("Nodes with data ready to send: %s", ready_nodes) # trace + log.debug("Created %d produce requests: %s", len(requests), requests) # trace + poll_timeout_ms = 0 + + for node_id, request in six.iteritems(requests): + batches = batches_by_node[node_id] + log.debug('Sending Produce Request: %r', request) + (self._client.send(node_id, request, wakeup=False) + .add_callback( + self._handle_produce_response, node_id, time.time(), batches) + .add_errback( + self._failed_produce, batches, node_id)) + + # if some partitions are already ready to be sent, the select time + # would be 0; otherwise if some partition already has some data + # accumulated but not ready yet, the select time will be the time + # difference between now and its linger expiry time; otherwise the + # select time will be the time difference between now and the + # metadata expiry time + self._client.poll(timeout_ms=poll_timeout_ms) + + def initiate_close(self): + """Start closing the sender (won't complete until all data is sent).""" + self._running = False + self._accumulator.close() + self.wakeup() + + def force_close(self): + """Closes the sender without sending out any pending messages.""" + self._force_close = True + self.initiate_close() + + def add_topic(self, topic): + # This is generally called from a separate thread + # so this needs to be a thread-safe operation + # we assume that checking set membership across threads + # is ok where self._client._topics should never + # remove topics for a producer instance, only add them. + if topic not in self._client._topics: + self._topics_to_add.add(topic) + self.wakeup() + + def _failed_produce(self, batches, node_id, error): + log.debug("Error sending produce request to node %d: %s", node_id, error) # trace + for batch in batches: + self._complete_batch(batch, error, -1, None) + + def _handle_produce_response(self, node_id, send_time, batches, response): + """Handle a produce response.""" + # if we have a response, parse it + log.debug('Parsing produce response: %r', response) + if response: + batches_by_partition = dict([(batch.topic_partition, batch) + for batch in batches]) + + for topic, partitions in response.topics: + for partition_info in partitions: + global_error = None + log_start_offset = None + if response.API_VERSION < 2: + partition, error_code, offset = partition_info + ts = None + elif 2 <= response.API_VERSION <= 4: + partition, error_code, offset, ts = partition_info + elif 5 <= response.API_VERSION <= 7: + partition, error_code, offset, ts, log_start_offset = partition_info + else: + # the ignored parameter is record_error of type list[(batch_index: int, error_message: str)] + partition, error_code, offset, ts, log_start_offset, _, global_error = partition_info + tp = TopicPartition(topic, partition) + error = Errors.for_code(error_code) + batch = batches_by_partition[tp] + self._complete_batch(batch, error, offset, ts, log_start_offset, global_error) + + if response.API_VERSION > 0: + self._sensors.record_throttle_time(response.throttle_time_ms, node=node_id) + + else: + # this is the acks = 0 case, just complete all requests + for batch in batches: + self._complete_batch(batch, None, -1, None) + + def _complete_batch(self, batch, error, base_offset, timestamp_ms=None, log_start_offset=None, global_error=None): + """Complete or retry the given batch of records. 
+ + Arguments: + batch (RecordBatch): The record batch + error (Exception): The error (or None if none) + base_offset (int): The base offset assigned to the records if successful + timestamp_ms (int, optional): The timestamp returned by the broker for this batch + log_start_offset (int): The start offset of the log at the time this produce response was created + global_error (str): The summarising error message + """ + # Standardize no-error to None + if error is Errors.NoError: + error = None + + if error is not None and self._can_retry(batch, error): + # retry + log.warning("Got error produce response on topic-partition %s," + " retrying (%d attempts left). Error: %s", + batch.topic_partition, + self.config['retries'] - batch.attempts - 1, + global_error or error) + self._accumulator.reenqueue(batch) + self._sensors.record_retries(batch.topic_partition.topic, batch.record_count) + else: + if error is Errors.TopicAuthorizationFailedError: + error = error(batch.topic_partition.topic) + + # tell the user the result of their request + batch.done(base_offset, timestamp_ms, error, log_start_offset, global_error) + self._accumulator.deallocate(batch) + if error is not None: + self._sensors.record_errors(batch.topic_partition.topic, batch.record_count) + + if getattr(error, 'invalid_metadata', False): + self._metadata.request_update() + + # Unmute the completed partition. + if self.config['guarantee_message_order']: + self._accumulator.muted.remove(batch.topic_partition) + + def _can_retry(self, batch, error): + """ + We can retry a send if the error is transient and the number of + attempts taken is fewer than the maximum allowed + """ + return (batch.attempts < self.config['retries'] + and getattr(error, 'retriable', False)) + + def _create_produce_requests(self, collated): + """ + Transfer the record batches into a list of produce requests on a + per-node basis. + + Arguments: + collated: {node_id: [RecordBatch]} + + Returns: + dict: {node_id: ProduceRequest} (version depends on api_version) + """ + requests = {} + for node_id, batches in six.iteritems(collated): + requests[node_id] = self._produce_request( + node_id, self.config['acks'], + self.config['request_timeout_ms'], batches) + return requests + + def _produce_request(self, node_id, acks, timeout, batches): + """Create a produce request from the given record batches. 
+ + Returns: + ProduceRequest (version depends on api_version) + """ + produce_records_by_partition = collections.defaultdict(dict) + for batch in batches: + topic = batch.topic_partition.topic + partition = batch.topic_partition.partition + + buf = batch.records.buffer() + produce_records_by_partition[topic][partition] = buf + + kwargs = {} + if self.config['api_version'] >= (2, 1): + version = 7 + elif self.config['api_version'] >= (2, 0): + version = 6 + elif self.config['api_version'] >= (1, 1): + version = 5 + elif self.config['api_version'] >= (1, 0): + version = 4 + elif self.config['api_version'] >= (0, 11): + version = 3 + kwargs = dict(transactional_id=None) + elif self.config['api_version'] >= (0, 10): + version = 2 + elif self.config['api_version'] == (0, 9): + version = 1 + else: + version = 0 + return ProduceRequest[version]( + required_acks=acks, + timeout=timeout, + topics=[(topic, list(partition_info.items())) + for topic, partition_info + in six.iteritems(produce_records_by_partition)], + **kwargs + ) + + def wakeup(self): + """Wake up the selector associated with this send thread.""" + self._client.wakeup() + + def bootstrap_connected(self): + return self._client.bootstrap_connected() + + +class SenderMetrics(object): + + def __init__(self, metrics, client, metadata): + self.metrics = metrics + self._client = client + self._metadata = metadata + + sensor_name = 'batch-size' + self.batch_size_sensor = self.metrics.sensor(sensor_name) + self.add_metric('batch-size-avg', Avg(), + sensor_name=sensor_name, + description='The average number of bytes sent per partition per-request.') + self.add_metric('batch-size-max', Max(), + sensor_name=sensor_name, + description='The max number of bytes sent per partition per-request.') + + sensor_name = 'compression-rate' + self.compression_rate_sensor = self.metrics.sensor(sensor_name) + self.add_metric('compression-rate-avg', Avg(), + sensor_name=sensor_name, + description='The average compression rate of record batches.') + + sensor_name = 'queue-time' + self.queue_time_sensor = self.metrics.sensor(sensor_name) + self.add_metric('record-queue-time-avg', Avg(), + sensor_name=sensor_name, + description='The average time in ms record batches spent in the record accumulator.') + self.add_metric('record-queue-time-max', Max(), + sensor_name=sensor_name, + description='The maximum time in ms record batches spent in the record accumulator.') + + sensor_name = 'produce-throttle-time' + self.produce_throttle_time_sensor = self.metrics.sensor(sensor_name) + self.add_metric('produce-throttle-time-avg', Avg(), + sensor_name=sensor_name, + description='The average throttle time in ms') + self.add_metric('produce-throttle-time-max', Max(), + sensor_name=sensor_name, + description='The maximum throttle time in ms') + + sensor_name = 'records-per-request' + self.records_per_request_sensor = self.metrics.sensor(sensor_name) + self.add_metric('record-send-rate', Rate(), + sensor_name=sensor_name, + description='The average number of records sent per second.') + self.add_metric('records-per-request-avg', Avg(), + sensor_name=sensor_name, + description='The average number of records per request.') + + sensor_name = 'bytes' + self.byte_rate_sensor = self.metrics.sensor(sensor_name) + self.add_metric('byte-rate', Rate(), + sensor_name=sensor_name, + description='The average number of bytes sent per second.') + + sensor_name = 'record-retries' + self.retry_sensor = self.metrics.sensor(sensor_name) + self.add_metric('record-retry-rate', Rate(), + 
sensor_name=sensor_name, + description='The average per-second number of retried record sends') + + sensor_name = 'errors' + self.error_sensor = self.metrics.sensor(sensor_name) + self.add_metric('record-error-rate', Rate(), + sensor_name=sensor_name, + description='The average per-second number of record sends that resulted in errors') + + sensor_name = 'record-size-max' + self.max_record_size_sensor = self.metrics.sensor(sensor_name) + self.add_metric('record-size-max', Max(), + sensor_name=sensor_name, + description='The maximum record size across all batches') + self.add_metric('record-size-avg', Avg(), + sensor_name=sensor_name, + description='The average maximum record size per batch') + + self.add_metric('requests-in-flight', + AnonMeasurable(lambda *_: self._client.in_flight_request_count()), + description='The current number of in-flight requests awaiting a response.') + + self.add_metric('metadata-age', + AnonMeasurable(lambda _, now: (now - self._metadata._last_successful_refresh_ms) / 1000), + description='The age in seconds of the current producer metadata being used.') + + def add_metric(self, metric_name, measurable, group_name='producer-metrics', + description=None, tags=None, + sensor_name=None): + m = self.metrics + metric = m.metric_name(metric_name, group_name, description, tags) + if sensor_name: + sensor = m.sensor(sensor_name) + sensor.add(metric, measurable) + else: + m.add_metric(metric, measurable) + + def maybe_register_topic_metrics(self, topic): + + def sensor_name(name): + return 'topic.{0}.{1}'.format(topic, name) + + # if one sensor of the metrics has been registered for the topic, + # then all other sensors should have been registered; and vice versa + if not self.metrics.get_sensor(sensor_name('records-per-batch')): + + self.add_metric('record-send-rate', Rate(), + sensor_name=sensor_name('records-per-batch'), + group_name='producer-topic-metrics.' + topic, + description= 'Records sent per second for topic ' + topic) + + self.add_metric('byte-rate', Rate(), + sensor_name=sensor_name('bytes'), + group_name='producer-topic-metrics.' + topic, + description='Bytes per second for topic ' + topic) + + self.add_metric('compression-rate', Avg(), + sensor_name=sensor_name('compression-rate'), + group_name='producer-topic-metrics.' + topic, + description='Average Compression ratio for topic ' + topic) + + self.add_metric('record-retry-rate', Rate(), + sensor_name=sensor_name('record-retries'), + group_name='producer-topic-metrics.' + topic, + description='Record retries per second for topic ' + topic) + + self.add_metric('record-error-rate', Rate(), + sensor_name=sensor_name('record-errors'), + group_name='producer-topic-metrics.' + topic, + description='Record errors per second for topic ' + topic) + + def update_produce_request_metrics(self, batches_map): + for node_batch in batches_map.values(): + records = 0 + total_bytes = 0 + for batch in node_batch: + # register all per-topic metrics at once + topic = batch.topic_partition.topic + self.maybe_register_topic_metrics(topic) + + # per-topic record send rate + topic_records_count = self.metrics.get_sensor( + 'topic.' + topic + '.records-per-batch') + topic_records_count.record(batch.record_count) + + # per-topic bytes send rate + topic_byte_rate = self.metrics.get_sensor( + 'topic.' + topic + '.bytes') + topic_byte_rate.record(batch.records.size_in_bytes()) + + # per-topic compression rate + topic_compression_rate = self.metrics.get_sensor( + 'topic.' 
+ topic + '.compression-rate') + topic_compression_rate.record(batch.records.compression_rate()) + + # global metrics + self.batch_size_sensor.record(batch.records.size_in_bytes()) + if batch.drained: + self.queue_time_sensor.record(batch.drained - batch.created) + self.compression_rate_sensor.record(batch.records.compression_rate()) + self.max_record_size_sensor.record(batch.max_record_size) + records += batch.record_count + total_bytes += batch.records.size_in_bytes() + + self.records_per_request_sensor.record(records) + self.byte_rate_sensor.record(total_bytes) + + def record_retries(self, topic, count): + self.retry_sensor.record(count) + sensor = self.metrics.get_sensor('topic.' + topic + '.record-retries') + if sensor: + sensor.record(count) + + def record_errors(self, topic, count): + self.error_sensor.record(count) + sensor = self.metrics.get_sensor('topic.' + topic + '.record-errors') + if sensor: + sensor.record(count) + + def record_throttle_time(self, throttle_time_ms, node=None): + self.produce_throttle_time_sensor.record(throttle_time_ms) diff --git a/protocol/__init__.py b/protocol/__init__.py new file mode 100644 index 00000000..025447f9 --- /dev/null +++ b/protocol/__init__.py @@ -0,0 +1,49 @@ +from __future__ import absolute_import + + +API_KEYS = { + 0: 'Produce', + 1: 'Fetch', + 2: 'ListOffsets', + 3: 'Metadata', + 4: 'LeaderAndIsr', + 5: 'StopReplica', + 6: 'UpdateMetadata', + 7: 'ControlledShutdown', + 8: 'OffsetCommit', + 9: 'OffsetFetch', + 10: 'FindCoordinator', + 11: 'JoinGroup', + 12: 'Heartbeat', + 13: 'LeaveGroup', + 14: 'SyncGroup', + 15: 'DescribeGroups', + 16: 'ListGroups', + 17: 'SaslHandshake', + 18: 'ApiVersions', + 19: 'CreateTopics', + 20: 'DeleteTopics', + 21: 'DeleteRecords', + 22: 'InitProducerId', + 23: 'OffsetForLeaderEpoch', + 24: 'AddPartitionsToTxn', + 25: 'AddOffsetsToTxn', + 26: 'EndTxn', + 27: 'WriteTxnMarkers', + 28: 'TxnOffsetCommit', + 29: 'DescribeAcls', + 30: 'CreateAcls', + 31: 'DeleteAcls', + 32: 'DescribeConfigs', + 33: 'AlterConfigs', + 36: 'SaslAuthenticate', + 37: 'CreatePartitions', + 38: 'CreateDelegationToken', + 39: 'RenewDelegationToken', + 40: 'ExpireDelegationToken', + 41: 'DescribeDelegationToken', + 42: 'DeleteGroups', + 45: 'AlterPartitionReassignments', + 46: 'ListPartitionReassignments', + 48: 'DescribeClientQuotas', +} diff --git a/protocol/abstract.py b/protocol/abstract.py new file mode 100644 index 00000000..2de65c4b --- /dev/null +++ b/protocol/abstract.py @@ -0,0 +1,19 @@ +from __future__ import absolute_import + +import abc + + +class AbstractType(object): + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def encode(cls, value): # pylint: disable=no-self-argument + pass + + @abc.abstractmethod + def decode(cls, data): # pylint: disable=no-self-argument + pass + + @classmethod + def repr(cls, value): + return repr(value) diff --git a/protocol/admin.py b/protocol/admin.py new file mode 100644 index 00000000..f9d61e5c --- /dev/null +++ b/protocol/admin.py @@ -0,0 +1,1054 @@ +from __future__ import absolute_import + +from kafka.protocol.api import Request, Response +from kafka.protocol.types import Array, Boolean, Bytes, Int8, Int16, Int32, Int64, Schema, String, Float64, CompactString, CompactArray, TaggedFields + + +class ApiVersionResponse_v0(Response): + API_KEY = 18 + API_VERSION = 0 + SCHEMA = Schema( + ('error_code', Int16), + ('api_versions', Array( + ('api_key', Int16), + ('min_version', Int16), + ('max_version', Int16))) + ) + + +class ApiVersionResponse_v1(Response): + API_KEY = 18 + API_VERSION = 1 
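+    # Unlike most response schemas in this module, v1 places throttle_time_ms
+    # after the api_versions array rather than first.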
+ SCHEMA = Schema( + ('error_code', Int16), + ('api_versions', Array( + ('api_key', Int16), + ('min_version', Int16), + ('max_version', Int16))), + ('throttle_time_ms', Int32) + ) + + +class ApiVersionResponse_v2(Response): + API_KEY = 18 + API_VERSION = 2 + SCHEMA = ApiVersionResponse_v1.SCHEMA + + +class ApiVersionRequest_v0(Request): + API_KEY = 18 + API_VERSION = 0 + RESPONSE_TYPE = ApiVersionResponse_v0 + SCHEMA = Schema() + + +class ApiVersionRequest_v1(Request): + API_KEY = 18 + API_VERSION = 1 + RESPONSE_TYPE = ApiVersionResponse_v1 + SCHEMA = ApiVersionRequest_v0.SCHEMA + + +class ApiVersionRequest_v2(Request): + API_KEY = 18 + API_VERSION = 2 + RESPONSE_TYPE = ApiVersionResponse_v1 + SCHEMA = ApiVersionRequest_v0.SCHEMA + + +ApiVersionRequest = [ + ApiVersionRequest_v0, ApiVersionRequest_v1, ApiVersionRequest_v2, +] +ApiVersionResponse = [ + ApiVersionResponse_v0, ApiVersionResponse_v1, ApiVersionResponse_v2, +] + + +class CreateTopicsResponse_v0(Response): + API_KEY = 19 + API_VERSION = 0 + SCHEMA = Schema( + ('topic_errors', Array( + ('topic', String('utf-8')), + ('error_code', Int16))) + ) + + +class CreateTopicsResponse_v1(Response): + API_KEY = 19 + API_VERSION = 1 + SCHEMA = Schema( + ('topic_errors', Array( + ('topic', String('utf-8')), + ('error_code', Int16), + ('error_message', String('utf-8')))) + ) + + +class CreateTopicsResponse_v2(Response): + API_KEY = 19 + API_VERSION = 2 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('topic_errors', Array( + ('topic', String('utf-8')), + ('error_code', Int16), + ('error_message', String('utf-8')))) + ) + +class CreateTopicsResponse_v3(Response): + API_KEY = 19 + API_VERSION = 3 + SCHEMA = CreateTopicsResponse_v2.SCHEMA + + +class CreateTopicsRequest_v0(Request): + API_KEY = 19 + API_VERSION = 0 + RESPONSE_TYPE = CreateTopicsResponse_v0 + SCHEMA = Schema( + ('create_topic_requests', Array( + ('topic', String('utf-8')), + ('num_partitions', Int32), + ('replication_factor', Int16), + ('replica_assignment', Array( + ('partition_id', Int32), + ('replicas', Array(Int32)))), + ('configs', Array( + ('config_key', String('utf-8')), + ('config_value', String('utf-8')))))), + ('timeout', Int32) + ) + + +class CreateTopicsRequest_v1(Request): + API_KEY = 19 + API_VERSION = 1 + RESPONSE_TYPE = CreateTopicsResponse_v1 + SCHEMA = Schema( + ('create_topic_requests', Array( + ('topic', String('utf-8')), + ('num_partitions', Int32), + ('replication_factor', Int16), + ('replica_assignment', Array( + ('partition_id', Int32), + ('replicas', Array(Int32)))), + ('configs', Array( + ('config_key', String('utf-8')), + ('config_value', String('utf-8')))))), + ('timeout', Int32), + ('validate_only', Boolean) + ) + + +class CreateTopicsRequest_v2(Request): + API_KEY = 19 + API_VERSION = 2 + RESPONSE_TYPE = CreateTopicsResponse_v2 + SCHEMA = CreateTopicsRequest_v1.SCHEMA + + +class CreateTopicsRequest_v3(Request): + API_KEY = 19 + API_VERSION = 3 + RESPONSE_TYPE = CreateTopicsResponse_v3 + SCHEMA = CreateTopicsRequest_v1.SCHEMA + + +CreateTopicsRequest = [ + CreateTopicsRequest_v0, CreateTopicsRequest_v1, + CreateTopicsRequest_v2, CreateTopicsRequest_v3, +] +CreateTopicsResponse = [ + CreateTopicsResponse_v0, CreateTopicsResponse_v1, + CreateTopicsResponse_v2, CreateTopicsResponse_v3, +] + + +class DeleteTopicsResponse_v0(Response): + API_KEY = 20 + API_VERSION = 0 + SCHEMA = Schema( + ('topic_error_codes', Array( + ('topic', String('utf-8')), + ('error_code', Int16))) + ) + + +class DeleteTopicsResponse_v1(Response): + API_KEY = 20 + API_VERSION = 1 + 
SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('topic_error_codes', Array( + ('topic', String('utf-8')), + ('error_code', Int16))) + ) + + +class DeleteTopicsResponse_v2(Response): + API_KEY = 20 + API_VERSION = 2 + SCHEMA = DeleteTopicsResponse_v1.SCHEMA + + +class DeleteTopicsResponse_v3(Response): + API_KEY = 20 + API_VERSION = 3 + SCHEMA = DeleteTopicsResponse_v1.SCHEMA + + +class DeleteTopicsRequest_v0(Request): + API_KEY = 20 + API_VERSION = 0 + RESPONSE_TYPE = DeleteTopicsResponse_v0 + SCHEMA = Schema( + ('topics', Array(String('utf-8'))), + ('timeout', Int32) + ) + + +class DeleteTopicsRequest_v1(Request): + API_KEY = 20 + API_VERSION = 1 + RESPONSE_TYPE = DeleteTopicsResponse_v1 + SCHEMA = DeleteTopicsRequest_v0.SCHEMA + + +class DeleteTopicsRequest_v2(Request): + API_KEY = 20 + API_VERSION = 2 + RESPONSE_TYPE = DeleteTopicsResponse_v2 + SCHEMA = DeleteTopicsRequest_v0.SCHEMA + + +class DeleteTopicsRequest_v3(Request): + API_KEY = 20 + API_VERSION = 3 + RESPONSE_TYPE = DeleteTopicsResponse_v3 + SCHEMA = DeleteTopicsRequest_v0.SCHEMA + + +DeleteTopicsRequest = [ + DeleteTopicsRequest_v0, DeleteTopicsRequest_v1, + DeleteTopicsRequest_v2, DeleteTopicsRequest_v3, +] +DeleteTopicsResponse = [ + DeleteTopicsResponse_v0, DeleteTopicsResponse_v1, + DeleteTopicsResponse_v2, DeleteTopicsResponse_v3, +] + + +class ListGroupsResponse_v0(Response): + API_KEY = 16 + API_VERSION = 0 + SCHEMA = Schema( + ('error_code', Int16), + ('groups', Array( + ('group', String('utf-8')), + ('protocol_type', String('utf-8')))) + ) + + +class ListGroupsResponse_v1(Response): + API_KEY = 16 + API_VERSION = 1 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16), + ('groups', Array( + ('group', String('utf-8')), + ('protocol_type', String('utf-8')))) + ) + +class ListGroupsResponse_v2(Response): + API_KEY = 16 + API_VERSION = 2 + SCHEMA = ListGroupsResponse_v1.SCHEMA + + +class ListGroupsRequest_v0(Request): + API_KEY = 16 + API_VERSION = 0 + RESPONSE_TYPE = ListGroupsResponse_v0 + SCHEMA = Schema() + + +class ListGroupsRequest_v1(Request): + API_KEY = 16 + API_VERSION = 1 + RESPONSE_TYPE = ListGroupsResponse_v1 + SCHEMA = ListGroupsRequest_v0.SCHEMA + +class ListGroupsRequest_v2(Request): + API_KEY = 16 + API_VERSION = 1 + RESPONSE_TYPE = ListGroupsResponse_v2 + SCHEMA = ListGroupsRequest_v0.SCHEMA + + +ListGroupsRequest = [ + ListGroupsRequest_v0, ListGroupsRequest_v1, + ListGroupsRequest_v2, +] +ListGroupsResponse = [ + ListGroupsResponse_v0, ListGroupsResponse_v1, + ListGroupsResponse_v2, +] + + +class DescribeGroupsResponse_v0(Response): + API_KEY = 15 + API_VERSION = 0 + SCHEMA = Schema( + ('groups', Array( + ('error_code', Int16), + ('group', String('utf-8')), + ('state', String('utf-8')), + ('protocol_type', String('utf-8')), + ('protocol', String('utf-8')), + ('members', Array( + ('member_id', String('utf-8')), + ('client_id', String('utf-8')), + ('client_host', String('utf-8')), + ('member_metadata', Bytes), + ('member_assignment', Bytes))))) + ) + + +class DescribeGroupsResponse_v1(Response): + API_KEY = 15 + API_VERSION = 1 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('groups', Array( + ('error_code', Int16), + ('group', String('utf-8')), + ('state', String('utf-8')), + ('protocol_type', String('utf-8')), + ('protocol', String('utf-8')), + ('members', Array( + ('member_id', String('utf-8')), + ('client_id', String('utf-8')), + ('client_host', String('utf-8')), + ('member_metadata', Bytes), + ('member_assignment', Bytes))))) + ) + + +class 
DescribeGroupsResponse_v2(Response): + API_KEY = 15 + API_VERSION = 2 + SCHEMA = DescribeGroupsResponse_v1.SCHEMA + + +class DescribeGroupsResponse_v3(Response): + API_KEY = 15 + API_VERSION = 3 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('groups', Array( + ('error_code', Int16), + ('group', String('utf-8')), + ('state', String('utf-8')), + ('protocol_type', String('utf-8')), + ('protocol', String('utf-8')), + ('members', Array( + ('member_id', String('utf-8')), + ('client_id', String('utf-8')), + ('client_host', String('utf-8')), + ('member_metadata', Bytes), + ('member_assignment', Bytes)))), + ('authorized_operations', Int32)) + ) + + +class DescribeGroupsRequest_v0(Request): + API_KEY = 15 + API_VERSION = 0 + RESPONSE_TYPE = DescribeGroupsResponse_v0 + SCHEMA = Schema( + ('groups', Array(String('utf-8'))) + ) + + +class DescribeGroupsRequest_v1(Request): + API_KEY = 15 + API_VERSION = 1 + RESPONSE_TYPE = DescribeGroupsResponse_v1 + SCHEMA = DescribeGroupsRequest_v0.SCHEMA + + +class DescribeGroupsRequest_v2(Request): + API_KEY = 15 + API_VERSION = 2 + RESPONSE_TYPE = DescribeGroupsResponse_v2 + SCHEMA = DescribeGroupsRequest_v0.SCHEMA + + +class DescribeGroupsRequest_v3(Request): + API_KEY = 15 + API_VERSION = 3 + RESPONSE_TYPE = DescribeGroupsResponse_v2 + SCHEMA = Schema( + ('groups', Array(String('utf-8'))), + ('include_authorized_operations', Boolean) + ) + + +DescribeGroupsRequest = [ + DescribeGroupsRequest_v0, DescribeGroupsRequest_v1, + DescribeGroupsRequest_v2, DescribeGroupsRequest_v3, +] +DescribeGroupsResponse = [ + DescribeGroupsResponse_v0, DescribeGroupsResponse_v1, + DescribeGroupsResponse_v2, DescribeGroupsResponse_v3, +] + + +class SaslHandShakeResponse_v0(Response): + API_KEY = 17 + API_VERSION = 0 + SCHEMA = Schema( + ('error_code', Int16), + ('enabled_mechanisms', Array(String('utf-8'))) + ) + + +class SaslHandShakeResponse_v1(Response): + API_KEY = 17 + API_VERSION = 1 + SCHEMA = SaslHandShakeResponse_v0.SCHEMA + + +class SaslHandShakeRequest_v0(Request): + API_KEY = 17 + API_VERSION = 0 + RESPONSE_TYPE = SaslHandShakeResponse_v0 + SCHEMA = Schema( + ('mechanism', String('utf-8')) + ) + + +class SaslHandShakeRequest_v1(Request): + API_KEY = 17 + API_VERSION = 1 + RESPONSE_TYPE = SaslHandShakeResponse_v1 + SCHEMA = SaslHandShakeRequest_v0.SCHEMA + + +SaslHandShakeRequest = [SaslHandShakeRequest_v0, SaslHandShakeRequest_v1] +SaslHandShakeResponse = [SaslHandShakeResponse_v0, SaslHandShakeResponse_v1] + + +class DescribeAclsResponse_v0(Response): + API_KEY = 29 + API_VERSION = 0 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16), + ('error_message', String('utf-8')), + ('resources', Array( + ('resource_type', Int8), + ('resource_name', String('utf-8')), + ('acls', Array( + ('principal', String('utf-8')), + ('host', String('utf-8')), + ('operation', Int8), + ('permission_type', Int8))))) + ) + + +class DescribeAclsResponse_v1(Response): + API_KEY = 29 + API_VERSION = 1 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16), + ('error_message', String('utf-8')), + ('resources', Array( + ('resource_type', Int8), + ('resource_name', String('utf-8')), + ('resource_pattern_type', Int8), + ('acls', Array( + ('principal', String('utf-8')), + ('host', String('utf-8')), + ('operation', Int8), + ('permission_type', Int8))))) + ) + + +class DescribeAclsResponse_v2(Response): + API_KEY = 29 + API_VERSION = 2 + SCHEMA = DescribeAclsResponse_v1.SCHEMA + + +class DescribeAclsRequest_v0(Request): + API_KEY = 29 + API_VERSION = 0 + 
RESPONSE_TYPE = DescribeAclsResponse_v0 + SCHEMA = Schema( + ('resource_type', Int8), + ('resource_name', String('utf-8')), + ('principal', String('utf-8')), + ('host', String('utf-8')), + ('operation', Int8), + ('permission_type', Int8) + ) + + +class DescribeAclsRequest_v1(Request): + API_KEY = 29 + API_VERSION = 1 + RESPONSE_TYPE = DescribeAclsResponse_v1 + SCHEMA = Schema( + ('resource_type', Int8), + ('resource_name', String('utf-8')), + ('resource_pattern_type_filter', Int8), + ('principal', String('utf-8')), + ('host', String('utf-8')), + ('operation', Int8), + ('permission_type', Int8) + ) + + +class DescribeAclsRequest_v2(Request): + """ + Enable flexible version + """ + API_KEY = 29 + API_VERSION = 2 + RESPONSE_TYPE = DescribeAclsResponse_v2 + SCHEMA = DescribeAclsRequest_v1.SCHEMA + + +DescribeAclsRequest = [DescribeAclsRequest_v0, DescribeAclsRequest_v1] +DescribeAclsResponse = [DescribeAclsResponse_v0, DescribeAclsResponse_v1] + +class CreateAclsResponse_v0(Response): + API_KEY = 30 + API_VERSION = 0 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('creation_responses', Array( + ('error_code', Int16), + ('error_message', String('utf-8')))) + ) + +class CreateAclsResponse_v1(Response): + API_KEY = 30 + API_VERSION = 1 + SCHEMA = CreateAclsResponse_v0.SCHEMA + +class CreateAclsRequest_v0(Request): + API_KEY = 30 + API_VERSION = 0 + RESPONSE_TYPE = CreateAclsResponse_v0 + SCHEMA = Schema( + ('creations', Array( + ('resource_type', Int8), + ('resource_name', String('utf-8')), + ('principal', String('utf-8')), + ('host', String('utf-8')), + ('operation', Int8), + ('permission_type', Int8))) + ) + +class CreateAclsRequest_v1(Request): + API_KEY = 30 + API_VERSION = 1 + RESPONSE_TYPE = CreateAclsResponse_v1 + SCHEMA = Schema( + ('creations', Array( + ('resource_type', Int8), + ('resource_name', String('utf-8')), + ('resource_pattern_type', Int8), + ('principal', String('utf-8')), + ('host', String('utf-8')), + ('operation', Int8), + ('permission_type', Int8))) + ) + +CreateAclsRequest = [CreateAclsRequest_v0, CreateAclsRequest_v1] +CreateAclsResponse = [CreateAclsResponse_v0, CreateAclsResponse_v1] + +class DeleteAclsResponse_v0(Response): + API_KEY = 31 + API_VERSION = 0 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('filter_responses', Array( + ('error_code', Int16), + ('error_message', String('utf-8')), + ('matching_acls', Array( + ('error_code', Int16), + ('error_message', String('utf-8')), + ('resource_type', Int8), + ('resource_name', String('utf-8')), + ('principal', String('utf-8')), + ('host', String('utf-8')), + ('operation', Int8), + ('permission_type', Int8))))) + ) + +class DeleteAclsResponse_v1(Response): + API_KEY = 31 + API_VERSION = 1 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('filter_responses', Array( + ('error_code', Int16), + ('error_message', String('utf-8')), + ('matching_acls', Array( + ('error_code', Int16), + ('error_message', String('utf-8')), + ('resource_type', Int8), + ('resource_name', String('utf-8')), + ('resource_pattern_type', Int8), + ('principal', String('utf-8')), + ('host', String('utf-8')), + ('operation', Int8), + ('permission_type', Int8))))) + ) + +class DeleteAclsRequest_v0(Request): + API_KEY = 31 + API_VERSION = 0 + RESPONSE_TYPE = DeleteAclsResponse_v0 + SCHEMA = Schema( + ('filters', Array( + ('resource_type', Int8), + ('resource_name', String('utf-8')), + ('principal', String('utf-8')), + ('host', String('utf-8')), + ('operation', Int8), + ('permission_type', Int8))) + ) + +class DeleteAclsRequest_v1(Request): + 
API_KEY = 31 + API_VERSION = 1 + RESPONSE_TYPE = DeleteAclsResponse_v1 + SCHEMA = Schema( + ('filters', Array( + ('resource_type', Int8), + ('resource_name', String('utf-8')), + ('resource_pattern_type_filter', Int8), + ('principal', String('utf-8')), + ('host', String('utf-8')), + ('operation', Int8), + ('permission_type', Int8))) + ) + +DeleteAclsRequest = [DeleteAclsRequest_v0, DeleteAclsRequest_v1] +DeleteAclsResponse = [DeleteAclsResponse_v0, DeleteAclsResponse_v1] + +class AlterConfigsResponse_v0(Response): + API_KEY = 33 + API_VERSION = 0 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('resources', Array( + ('error_code', Int16), + ('error_message', String('utf-8')), + ('resource_type', Int8), + ('resource_name', String('utf-8')))) + ) + + +class AlterConfigsResponse_v1(Response): + API_KEY = 33 + API_VERSION = 1 + SCHEMA = AlterConfigsResponse_v0.SCHEMA + + +class AlterConfigsRequest_v0(Request): + API_KEY = 33 + API_VERSION = 0 + RESPONSE_TYPE = AlterConfigsResponse_v0 + SCHEMA = Schema( + ('resources', Array( + ('resource_type', Int8), + ('resource_name', String('utf-8')), + ('config_entries', Array( + ('config_name', String('utf-8')), + ('config_value', String('utf-8')))))), + ('validate_only', Boolean) + ) + +class AlterConfigsRequest_v1(Request): + API_KEY = 33 + API_VERSION = 1 + RESPONSE_TYPE = AlterConfigsResponse_v1 + SCHEMA = AlterConfigsRequest_v0.SCHEMA + +AlterConfigsRequest = [AlterConfigsRequest_v0, AlterConfigsRequest_v1] +AlterConfigsResponse = [AlterConfigsResponse_v0, AlterConfigsRequest_v1] + + +class DescribeConfigsResponse_v0(Response): + API_KEY = 32 + API_VERSION = 0 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('resources', Array( + ('error_code', Int16), + ('error_message', String('utf-8')), + ('resource_type', Int8), + ('resource_name', String('utf-8')), + ('config_entries', Array( + ('config_names', String('utf-8')), + ('config_value', String('utf-8')), + ('read_only', Boolean), + ('is_default', Boolean), + ('is_sensitive', Boolean))))) + ) + +class DescribeConfigsResponse_v1(Response): + API_KEY = 32 + API_VERSION = 1 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('resources', Array( + ('error_code', Int16), + ('error_message', String('utf-8')), + ('resource_type', Int8), + ('resource_name', String('utf-8')), + ('config_entries', Array( + ('config_names', String('utf-8')), + ('config_value', String('utf-8')), + ('read_only', Boolean), + ('is_default', Boolean), + ('is_sensitive', Boolean), + ('config_synonyms', Array( + ('config_name', String('utf-8')), + ('config_value', String('utf-8')), + ('config_source', Int8))))))) + ) + +class DescribeConfigsResponse_v2(Response): + API_KEY = 32 + API_VERSION = 2 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('resources', Array( + ('error_code', Int16), + ('error_message', String('utf-8')), + ('resource_type', Int8), + ('resource_name', String('utf-8')), + ('config_entries', Array( + ('config_names', String('utf-8')), + ('config_value', String('utf-8')), + ('read_only', Boolean), + ('config_source', Int8), + ('is_sensitive', Boolean), + ('config_synonyms', Array( + ('config_name', String('utf-8')), + ('config_value', String('utf-8')), + ('config_source', Int8))))))) + ) + +class DescribeConfigsRequest_v0(Request): + API_KEY = 32 + API_VERSION = 0 + RESPONSE_TYPE = DescribeConfigsResponse_v0 + SCHEMA = Schema( + ('resources', Array( + ('resource_type', Int8), + ('resource_name', String('utf-8')), + ('config_names', Array(String('utf-8'))))) + ) + +class DescribeConfigsRequest_v1(Request): + 
API_KEY = 32 + API_VERSION = 1 + RESPONSE_TYPE = DescribeConfigsResponse_v1 + SCHEMA = Schema( + ('resources', Array( + ('resource_type', Int8), + ('resource_name', String('utf-8')), + ('config_names', Array(String('utf-8'))))), + ('include_synonyms', Boolean) + ) + + +class DescribeConfigsRequest_v2(Request): + API_KEY = 32 + API_VERSION = 2 + RESPONSE_TYPE = DescribeConfigsResponse_v2 + SCHEMA = DescribeConfigsRequest_v1.SCHEMA + + +DescribeConfigsRequest = [ + DescribeConfigsRequest_v0, DescribeConfigsRequest_v1, + DescribeConfigsRequest_v2, +] +DescribeConfigsResponse = [ + DescribeConfigsResponse_v0, DescribeConfigsResponse_v1, + DescribeConfigsResponse_v2, +] + + +class SaslAuthenticateResponse_v0(Response): + API_KEY = 36 + API_VERSION = 0 + SCHEMA = Schema( + ('error_code', Int16), + ('error_message', String('utf-8')), + ('sasl_auth_bytes', Bytes) + ) + + +class SaslAuthenticateResponse_v1(Response): + API_KEY = 36 + API_VERSION = 1 + SCHEMA = Schema( + ('error_code', Int16), + ('error_message', String('utf-8')), + ('sasl_auth_bytes', Bytes), + ('session_lifetime_ms', Int64) + ) + + +class SaslAuthenticateRequest_v0(Request): + API_KEY = 36 + API_VERSION = 0 + RESPONSE_TYPE = SaslAuthenticateResponse_v0 + SCHEMA = Schema( + ('sasl_auth_bytes', Bytes) + ) + + +class SaslAuthenticateRequest_v1(Request): + API_KEY = 36 + API_VERSION = 1 + RESPONSE_TYPE = SaslAuthenticateResponse_v1 + SCHEMA = SaslAuthenticateRequest_v0.SCHEMA + + +SaslAuthenticateRequest = [ + SaslAuthenticateRequest_v0, SaslAuthenticateRequest_v1, +] +SaslAuthenticateResponse = [ + SaslAuthenticateResponse_v0, SaslAuthenticateResponse_v1, +] + + +class CreatePartitionsResponse_v0(Response): + API_KEY = 37 + API_VERSION = 0 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('topic_errors', Array( + ('topic', String('utf-8')), + ('error_code', Int16), + ('error_message', String('utf-8')))) + ) + + +class CreatePartitionsResponse_v1(Response): + API_KEY = 37 + API_VERSION = 1 + SCHEMA = CreatePartitionsResponse_v0.SCHEMA + + +class CreatePartitionsRequest_v0(Request): + API_KEY = 37 + API_VERSION = 0 + RESPONSE_TYPE = CreatePartitionsResponse_v0 + SCHEMA = Schema( + ('topic_partitions', Array( + ('topic', String('utf-8')), + ('new_partitions', Schema( + ('count', Int32), + ('assignment', Array(Array(Int32))))))), + ('timeout', Int32), + ('validate_only', Boolean) + ) + + +class CreatePartitionsRequest_v1(Request): + API_KEY = 37 + API_VERSION = 1 + SCHEMA = CreatePartitionsRequest_v0.SCHEMA + RESPONSE_TYPE = CreatePartitionsResponse_v1 + + +CreatePartitionsRequest = [ + CreatePartitionsRequest_v0, CreatePartitionsRequest_v1, +] +CreatePartitionsResponse = [ + CreatePartitionsResponse_v0, CreatePartitionsResponse_v1, +] + + +class DeleteGroupsResponse_v0(Response): + API_KEY = 42 + API_VERSION = 0 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ("results", Array( + ("group_id", String("utf-8")), + ("error_code", Int16))) + ) + + +class DeleteGroupsResponse_v1(Response): + API_KEY = 42 + API_VERSION = 1 + SCHEMA = DeleteGroupsResponse_v0.SCHEMA + + +class DeleteGroupsRequest_v0(Request): + API_KEY = 42 + API_VERSION = 0 + RESPONSE_TYPE = DeleteGroupsResponse_v0 + SCHEMA = Schema( + ("groups_names", Array(String("utf-8"))) + ) + + +class DeleteGroupsRequest_v1(Request): + API_KEY = 42 + API_VERSION = 1 + RESPONSE_TYPE = DeleteGroupsResponse_v1 + SCHEMA = DeleteGroupsRequest_v0.SCHEMA + + +DeleteGroupsRequest = [ + DeleteGroupsRequest_v0, DeleteGroupsRequest_v1 +] + +DeleteGroupsResponse = [ + DeleteGroupsResponse_v0, 
DeleteGroupsResponse_v1 +] + + +class DescribeClientQuotasResponse_v0(Request): + API_KEY = 48 + API_VERSION = 0 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16), + ('error_message', String('utf-8')), + ('entries', Array( + ('entity', Array( + ('entity_type', String('utf-8')), + ('entity_name', String('utf-8')))), + ('values', Array( + ('name', String('utf-8')), + ('value', Float64))))), + ) + + +class DescribeClientQuotasRequest_v0(Request): + API_KEY = 48 + API_VERSION = 0 + RESPONSE_TYPE = DescribeClientQuotasResponse_v0 + SCHEMA = Schema( + ('components', Array( + ('entity_type', String('utf-8')), + ('match_type', Int8), + ('match', String('utf-8')), + )), + ('strict', Boolean) + ) + + +DescribeClientQuotasRequest = [ + DescribeClientQuotasRequest_v0, +] + +DescribeClientQuotasResponse = [ + DescribeClientQuotasResponse_v0, +] + + +class AlterPartitionReassignmentsResponse_v0(Response): + API_KEY = 45 + API_VERSION = 0 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ("error_code", Int16), + ("error_message", CompactString("utf-8")), + ("responses", CompactArray( + ("name", CompactString("utf-8")), + ("partitions", CompactArray( + ("partition_index", Int32), + ("error_code", Int16), + ("error_message", CompactString("utf-8")), + ("tags", TaggedFields) + )), + ("tags", TaggedFields) + )), + ("tags", TaggedFields) + ) + + +class AlterPartitionReassignmentsRequest_v0(Request): + FLEXIBLE_VERSION = True + API_KEY = 45 + API_VERSION = 0 + RESPONSE_TYPE = AlterPartitionReassignmentsResponse_v0 + SCHEMA = Schema( + ("timeout_ms", Int32), + ("topics", CompactArray( + ("name", CompactString("utf-8")), + ("partitions", CompactArray( + ("partition_index", Int32), + ("replicas", CompactArray(Int32)), + ("tags", TaggedFields) + )), + ("tags", TaggedFields) + )), + ("tags", TaggedFields) + ) + + +AlterPartitionReassignmentsRequest = [AlterPartitionReassignmentsRequest_v0] + +AlterPartitionReassignmentsResponse = [AlterPartitionReassignmentsResponse_v0] + + +class ListPartitionReassignmentsResponse_v0(Response): + API_KEY = 46 + API_VERSION = 0 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ("error_code", Int16), + ("error_message", CompactString("utf-8")), + ("topics", CompactArray( + ("name", CompactString("utf-8")), + ("partitions", CompactArray( + ("partition_index", Int32), + ("replicas", CompactArray(Int32)), + ("adding_replicas", CompactArray(Int32)), + ("removing_replicas", CompactArray(Int32)), + ("tags", TaggedFields) + )), + ("tags", TaggedFields) + )), + ("tags", TaggedFields) + ) + + +class ListPartitionReassignmentsRequest_v0(Request): + FLEXIBLE_VERSION = True + API_KEY = 46 + API_VERSION = 0 + RESPONSE_TYPE = ListPartitionReassignmentsResponse_v0 + SCHEMA = Schema( + ("timeout_ms", Int32), + ("topics", CompactArray( + ("name", CompactString("utf-8")), + ("partition_index", CompactArray(Int32)), + ("tags", TaggedFields) + )), + ("tags", TaggedFields) + ) + + +ListPartitionReassignmentsRequest = [ListPartitionReassignmentsRequest_v0] + +ListPartitionReassignmentsResponse = [ListPartitionReassignmentsResponse_v0] diff --git a/protocol/api.py b/protocol/api.py new file mode 100644 index 00000000..f12cb972 --- /dev/null +++ b/protocol/api.py @@ -0,0 +1,138 @@ +from __future__ import absolute_import + +import abc + +from kafka.protocol.struct import Struct +from kafka.protocol.types import Int16, Int32, String, Schema, Array, TaggedFields + + +class RequestHeader(Struct): + SCHEMA = Schema( + ('api_key', Int16), + ('api_version', Int16), + ('correlation_id', 
Int32), + ('client_id', String('utf-8')) + ) + + def __init__(self, request, correlation_id=0, client_id='kafka-python'): + super(RequestHeader, self).__init__( + request.API_KEY, request.API_VERSION, correlation_id, client_id + ) + + +class RequestHeaderV2(Struct): + # Flexible response / request headers end in field buffer + SCHEMA = Schema( + ('api_key', Int16), + ('api_version', Int16), + ('correlation_id', Int32), + ('client_id', String('utf-8')), + ('tags', TaggedFields), + ) + + def __init__(self, request, correlation_id=0, client_id='kafka-python', tags=None): + super(RequestHeaderV2, self).__init__( + request.API_KEY, request.API_VERSION, correlation_id, client_id, tags or {} + ) + + +class ResponseHeader(Struct): + SCHEMA = Schema( + ('correlation_id', Int32), + ) + + +class ResponseHeaderV2(Struct): + SCHEMA = Schema( + ('correlation_id', Int32), + ('tags', TaggedFields), + ) + + +class Request(Struct): + __metaclass__ = abc.ABCMeta + + FLEXIBLE_VERSION = False + + @abc.abstractproperty + def API_KEY(self): + """Integer identifier for api request""" + pass + + @abc.abstractproperty + def API_VERSION(self): + """Integer of api request version""" + pass + + @abc.abstractproperty + def SCHEMA(self): + """An instance of Schema() representing the request structure""" + pass + + @abc.abstractproperty + def RESPONSE_TYPE(self): + """The Response class associated with the api request""" + pass + + def expect_response(self): + """Override this method if an api request does not always generate a response""" + return True + + def to_object(self): + return _to_object(self.SCHEMA, self) + + def build_request_header(self, correlation_id, client_id): + if self.FLEXIBLE_VERSION: + return RequestHeaderV2(self, correlation_id=correlation_id, client_id=client_id) + return RequestHeader(self, correlation_id=correlation_id, client_id=client_id) + + def parse_response_header(self, read_buffer): + if self.FLEXIBLE_VERSION: + return ResponseHeaderV2.decode(read_buffer) + return ResponseHeader.decode(read_buffer) + + +class Response(Struct): + __metaclass__ = abc.ABCMeta + + @abc.abstractproperty + def API_KEY(self): + """Integer identifier for api request/response""" + pass + + @abc.abstractproperty + def API_VERSION(self): + """Integer of api request/response version""" + pass + + @abc.abstractproperty + def SCHEMA(self): + """An instance of Schema() representing the response structure""" + pass + + def to_object(self): + return _to_object(self.SCHEMA, self) + + +def _to_object(schema, data): + obj = {} + for idx, (name, _type) in enumerate(zip(schema.names, schema.fields)): + if isinstance(data, Struct): + val = data.get_item(name) + else: + val = data[idx] + + if isinstance(_type, Schema): + obj[name] = _to_object(_type, val) + elif isinstance(_type, Array): + if isinstance(_type.array_of, (Array, Schema)): + obj[name] = [ + _to_object(_type.array_of, x) + for x in val + ] + else: + obj[name] = val + else: + obj[name] = val + + return obj diff --git a/protocol/commit.py b/protocol/commit.py new file mode 100644 index 00000000..31fc2370 --- /dev/null +++ b/protocol/commit.py @@ -0,0 +1,255 @@ +from __future__ import absolute_import + +from kafka.protocol.api import Request, Response +from kafka.protocol.types import Array, Int8, Int16, Int32, Int64, Schema, String + + +class OffsetCommitResponse_v0(Response): + API_KEY = 8 + API_VERSION = 0 + SCHEMA = Schema( + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('error_code', Int16))))) + ) + + +class 
OffsetCommitResponse_v1(Response): + API_KEY = 8 + API_VERSION = 1 + SCHEMA = OffsetCommitResponse_v0.SCHEMA + + +class OffsetCommitResponse_v2(Response): + API_KEY = 8 + API_VERSION = 2 + SCHEMA = OffsetCommitResponse_v1.SCHEMA + + +class OffsetCommitResponse_v3(Response): + API_KEY = 8 + API_VERSION = 3 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('error_code', Int16))))) + ) + + +class OffsetCommitRequest_v0(Request): + API_KEY = 8 + API_VERSION = 0 # Zookeeper-backed storage + RESPONSE_TYPE = OffsetCommitResponse_v0 + SCHEMA = Schema( + ('consumer_group', String('utf-8')), + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('offset', Int64), + ('metadata', String('utf-8')))))) + ) + + +class OffsetCommitRequest_v1(Request): + API_KEY = 8 + API_VERSION = 1 # Kafka-backed storage + RESPONSE_TYPE = OffsetCommitResponse_v1 + SCHEMA = Schema( + ('consumer_group', String('utf-8')), + ('consumer_group_generation_id', Int32), + ('consumer_id', String('utf-8')), + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('offset', Int64), + ('timestamp', Int64), + ('metadata', String('utf-8')))))) + ) + + +class OffsetCommitRequest_v2(Request): + API_KEY = 8 + API_VERSION = 2 # added retention_time, dropped timestamp + RESPONSE_TYPE = OffsetCommitResponse_v2 + SCHEMA = Schema( + ('consumer_group', String('utf-8')), + ('consumer_group_generation_id', Int32), + ('consumer_id', String('utf-8')), + ('retention_time', Int64), + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('offset', Int64), + ('metadata', String('utf-8')))))) + ) + DEFAULT_GENERATION_ID = -1 + DEFAULT_RETENTION_TIME = -1 + + +class OffsetCommitRequest_v3(Request): + API_KEY = 8 + API_VERSION = 3 + RESPONSE_TYPE = OffsetCommitResponse_v3 + SCHEMA = OffsetCommitRequest_v2.SCHEMA + + +OffsetCommitRequest = [ + OffsetCommitRequest_v0, OffsetCommitRequest_v1, + OffsetCommitRequest_v2, OffsetCommitRequest_v3 +] +OffsetCommitResponse = [ + OffsetCommitResponse_v0, OffsetCommitResponse_v1, + OffsetCommitResponse_v2, OffsetCommitResponse_v3 +] + + +class OffsetFetchResponse_v0(Response): + API_KEY = 9 + API_VERSION = 0 + SCHEMA = Schema( + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('offset', Int64), + ('metadata', String('utf-8')), + ('error_code', Int16))))) + ) + + +class OffsetFetchResponse_v1(Response): + API_KEY = 9 + API_VERSION = 1 + SCHEMA = OffsetFetchResponse_v0.SCHEMA + + +class OffsetFetchResponse_v2(Response): + # Added in KIP-88 + API_KEY = 9 + API_VERSION = 2 + SCHEMA = Schema( + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('offset', Int64), + ('metadata', String('utf-8')), + ('error_code', Int16))))), + ('error_code', Int16) + ) + + +class OffsetFetchResponse_v3(Response): + API_KEY = 9 + API_VERSION = 3 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('offset', Int64), + ('metadata', String('utf-8')), + ('error_code', Int16))))), + ('error_code', Int16) + ) + + +class OffsetFetchRequest_v0(Request): + API_KEY = 9 + API_VERSION = 0 # zookeeper-backed storage + RESPONSE_TYPE = OffsetFetchResponse_v0 + SCHEMA = Schema( + ('consumer_group', String('utf-8')), + ('topics', 
Array( + ('topic', String('utf-8')), + ('partitions', Array(Int32)))) + ) + + +class OffsetFetchRequest_v1(Request): + API_KEY = 9 + API_VERSION = 1 # kafka-backed storage + RESPONSE_TYPE = OffsetFetchResponse_v1 + SCHEMA = OffsetFetchRequest_v0.SCHEMA + + +class OffsetFetchRequest_v2(Request): + # KIP-88: Allows passing null topics to return offsets for all partitions + # that the consumer group has a stored offset for, even if no consumer in + # the group is currently consuming that partition. + API_KEY = 9 + API_VERSION = 2 + RESPONSE_TYPE = OffsetFetchResponse_v2 + SCHEMA = OffsetFetchRequest_v1.SCHEMA + + +class OffsetFetchRequest_v3(Request): + API_KEY = 9 + API_VERSION = 3 + RESPONSE_TYPE = OffsetFetchResponse_v3 + SCHEMA = OffsetFetchRequest_v2.SCHEMA + + +OffsetFetchRequest = [ + OffsetFetchRequest_v0, OffsetFetchRequest_v1, + OffsetFetchRequest_v2, OffsetFetchRequest_v3, +] +OffsetFetchResponse = [ + OffsetFetchResponse_v0, OffsetFetchResponse_v1, + OffsetFetchResponse_v2, OffsetFetchResponse_v3, +] + + +class GroupCoordinatorResponse_v0(Response): + API_KEY = 10 + API_VERSION = 0 + SCHEMA = Schema( + ('error_code', Int16), + ('coordinator_id', Int32), + ('host', String('utf-8')), + ('port', Int32) + ) + + +class GroupCoordinatorResponse_v1(Response): + API_KEY = 10 + API_VERSION = 1 + SCHEMA = Schema( + ('error_code', Int16), + ('error_message', String('utf-8')), + ('coordinator_id', Int32), + ('host', String('utf-8')), + ('port', Int32) + ) + + +class GroupCoordinatorRequest_v0(Request): + API_KEY = 10 + API_VERSION = 0 + RESPONSE_TYPE = GroupCoordinatorResponse_v0 + SCHEMA = Schema( + ('consumer_group', String('utf-8')) + ) + + +class GroupCoordinatorRequest_v1(Request): + API_KEY = 10 + API_VERSION = 1 + RESPONSE_TYPE = GroupCoordinatorResponse_v1 + SCHEMA = Schema( + ('coordinator_key', String('utf-8')), + ('coordinator_type', Int8) + ) + + +GroupCoordinatorRequest = [GroupCoordinatorRequest_v0, GroupCoordinatorRequest_v1] +GroupCoordinatorResponse = [GroupCoordinatorResponse_v0, GroupCoordinatorResponse_v1] diff --git a/protocol/fetch.py b/protocol/fetch.py new file mode 100644 index 00000000..f367848c --- /dev/null +++ b/protocol/fetch.py @@ -0,0 +1,386 @@ +from __future__ import absolute_import + +from kafka.protocol.api import Request, Response +from kafka.protocol.types import Array, Int8, Int16, Int32, Int64, Schema, String, Bytes + + +class FetchResponse_v0(Response): + API_KEY = 1 + API_VERSION = 0 + SCHEMA = Schema( + ('topics', Array( + ('topics', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('error_code', Int16), + ('highwater_offset', Int64), + ('message_set', Bytes))))) + ) + + +class FetchResponse_v1(Response): + API_KEY = 1 + API_VERSION = 1 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('topics', Array( + ('topics', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('error_code', Int16), + ('highwater_offset', Int64), + ('message_set', Bytes))))) + ) + + +class FetchResponse_v2(Response): + API_KEY = 1 + API_VERSION = 2 + SCHEMA = FetchResponse_v1.SCHEMA # message format changed internally + + +class FetchResponse_v3(Response): + API_KEY = 1 + API_VERSION = 3 + SCHEMA = FetchResponse_v2.SCHEMA + + +class FetchResponse_v4(Response): + API_KEY = 1 + API_VERSION = 4 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('topics', Array( + ('topics', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('error_code', Int16), + ('highwater_offset', Int64), + ('last_stable_offset', Int64), + 
('aborted_transactions', Array( + ('producer_id', Int64), + ('first_offset', Int64))), + ('message_set', Bytes))))) + ) + + +class FetchResponse_v5(Response): + API_KEY = 1 + API_VERSION = 5 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('topics', Array( + ('topics', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('error_code', Int16), + ('highwater_offset', Int64), + ('last_stable_offset', Int64), + ('log_start_offset', Int64), + ('aborted_transactions', Array( + ('producer_id', Int64), + ('first_offset', Int64))), + ('message_set', Bytes))))) + ) + + +class FetchResponse_v6(Response): + """ + Same as FetchResponse_v5. The version number is bumped up to indicate that the client supports KafkaStorageException. + The KafkaStorageException will be translated to NotLeaderForPartitionException in the response if version <= 5 + """ + API_KEY = 1 + API_VERSION = 6 + SCHEMA = FetchResponse_v5.SCHEMA + + +class FetchResponse_v7(Response): + """ + Add error_code and session_id to response + """ + API_KEY = 1 + API_VERSION = 7 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16), + ('session_id', Int32), + ('topics', Array( + ('topics', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('error_code', Int16), + ('highwater_offset', Int64), + ('last_stable_offset', Int64), + ('log_start_offset', Int64), + ('aborted_transactions', Array( + ('producer_id', Int64), + ('first_offset', Int64))), + ('message_set', Bytes))))) + ) + + +class FetchResponse_v8(Response): + API_KEY = 1 + API_VERSION = 8 + SCHEMA = FetchResponse_v7.SCHEMA + + +class FetchResponse_v9(Response): + API_KEY = 1 + API_VERSION = 9 + SCHEMA = FetchResponse_v7.SCHEMA + + +class FetchResponse_v10(Response): + API_KEY = 1 + API_VERSION = 10 + SCHEMA = FetchResponse_v7.SCHEMA + + +class FetchResponse_v11(Response): + API_KEY = 1 + API_VERSION = 11 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16), + ('session_id', Int32), + ('topics', Array( + ('topics', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('error_code', Int16), + ('highwater_offset', Int64), + ('last_stable_offset', Int64), + ('log_start_offset', Int64), + ('aborted_transactions', Array( + ('producer_id', Int64), + ('first_offset', Int64))), + ('preferred_read_replica', Int32), + ('message_set', Bytes))))) + ) + + +class FetchRequest_v0(Request): + API_KEY = 1 + API_VERSION = 0 + RESPONSE_TYPE = FetchResponse_v0 + SCHEMA = Schema( + ('replica_id', Int32), + ('max_wait_time', Int32), + ('min_bytes', Int32), + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('offset', Int64), + ('max_bytes', Int32))))) + ) + + +class FetchRequest_v1(Request): + API_KEY = 1 + API_VERSION = 1 + RESPONSE_TYPE = FetchResponse_v1 + SCHEMA = FetchRequest_v0.SCHEMA + + +class FetchRequest_v2(Request): + API_KEY = 1 + API_VERSION = 2 + RESPONSE_TYPE = FetchResponse_v2 + SCHEMA = FetchRequest_v1.SCHEMA + + +class FetchRequest_v3(Request): + API_KEY = 1 + API_VERSION = 3 + RESPONSE_TYPE = FetchResponse_v3 + SCHEMA = Schema( + ('replica_id', Int32), + ('max_wait_time', Int32), + ('min_bytes', Int32), + ('max_bytes', Int32), # This new field is only difference from FR_v2 + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('offset', Int64), + ('max_bytes', Int32))))) + ) + + +class FetchRequest_v4(Request): + # Adds isolation_level field + API_KEY = 1 + API_VERSION = 4 + RESPONSE_TYPE = FetchResponse_v4 
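    # (editor's note, not part of the vendored source) isolation_level selects
    # transactional read visibility: 0 = READ_UNCOMMITTED, 1 = READ_COMMITTED
    # (added alongside the KIP-98 transaction work); the remaining fields are
    # unchanged from FetchRequest_v3.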
+ SCHEMA = Schema( + ('replica_id', Int32), + ('max_wait_time', Int32), + ('min_bytes', Int32), + ('max_bytes', Int32), + ('isolation_level', Int8), + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('offset', Int64), + ('max_bytes', Int32))))) + ) + + +class FetchRequest_v5(Request): + # This may only be used in broker-broker api calls + API_KEY = 1 + API_VERSION = 5 + RESPONSE_TYPE = FetchResponse_v5 + SCHEMA = Schema( + ('replica_id', Int32), + ('max_wait_time', Int32), + ('min_bytes', Int32), + ('max_bytes', Int32), + ('isolation_level', Int8), + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('fetch_offset', Int64), + ('log_start_offset', Int64), + ('max_bytes', Int32))))) + ) + + +class FetchRequest_v6(Request): + """ + The body of FETCH_REQUEST_V6 is the same as FETCH_REQUEST_V5. + The version number is bumped up to indicate that the client supports KafkaStorageException. + The KafkaStorageException will be translated to NotLeaderForPartitionException in the response if version <= 5 + """ + API_KEY = 1 + API_VERSION = 6 + RESPONSE_TYPE = FetchResponse_v6 + SCHEMA = FetchRequest_v5.SCHEMA + + +class FetchRequest_v7(Request): + """ + Add incremental fetch requests + """ + API_KEY = 1 + API_VERSION = 7 + RESPONSE_TYPE = FetchResponse_v7 + SCHEMA = Schema( + ('replica_id', Int32), + ('max_wait_time', Int32), + ('min_bytes', Int32), + ('max_bytes', Int32), + ('isolation_level', Int8), + ('session_id', Int32), + ('session_epoch', Int32), + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('fetch_offset', Int64), + ('log_start_offset', Int64), + ('max_bytes', Int32))))), + ('forgotten_topics_data', Array( + ('topic', String), + ('partitions', Array(Int32)) + )), + ) + + +class FetchRequest_v8(Request): + """ + bump used to indicate that on quota violation brokers send out responses before throttling. + """ + API_KEY = 1 + API_VERSION = 8 + RESPONSE_TYPE = FetchResponse_v8 + SCHEMA = FetchRequest_v7.SCHEMA + + +class FetchRequest_v9(Request): + """ + adds the current leader epoch (see KIP-320) + """ + API_KEY = 1 + API_VERSION = 9 + RESPONSE_TYPE = FetchResponse_v9 + SCHEMA = Schema( + ('replica_id', Int32), + ('max_wait_time', Int32), + ('min_bytes', Int32), + ('max_bytes', Int32), + ('isolation_level', Int8), + ('session_id', Int32), + ('session_epoch', Int32), + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('current_leader_epoch', Int32), + ('fetch_offset', Int64), + ('log_start_offset', Int64), + ('max_bytes', Int32))))), + ('forgotten_topics_data', Array( + ('topic', String), + ('partitions', Array(Int32)), + )), + ) + + +class FetchRequest_v10(Request): + """ + bumped up to indicate ZStandard capability. 
(see KIP-110) + """ + API_KEY = 1 + API_VERSION = 10 + RESPONSE_TYPE = FetchResponse_v10 + SCHEMA = FetchRequest_v9.SCHEMA + + +class FetchRequest_v11(Request): + """ + added rack ID to support read from followers (KIP-392) + """ + API_KEY = 1 + API_VERSION = 11 + RESPONSE_TYPE = FetchResponse_v11 + SCHEMA = Schema( + ('replica_id', Int32), + ('max_wait_time', Int32), + ('min_bytes', Int32), + ('max_bytes', Int32), + ('isolation_level', Int8), + ('session_id', Int32), + ('session_epoch', Int32), + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('current_leader_epoch', Int32), + ('fetch_offset', Int64), + ('log_start_offset', Int64), + ('max_bytes', Int32))))), + ('forgotten_topics_data', Array( + ('topic', String), + ('partitions', Array(Int32)) + )), + ('rack_id', String('utf-8')), + ) + + +FetchRequest = [ + FetchRequest_v0, FetchRequest_v1, FetchRequest_v2, + FetchRequest_v3, FetchRequest_v4, FetchRequest_v5, + FetchRequest_v6, FetchRequest_v7, FetchRequest_v8, + FetchRequest_v9, FetchRequest_v10, FetchRequest_v11, +] +FetchResponse = [ + FetchResponse_v0, FetchResponse_v1, FetchResponse_v2, + FetchResponse_v3, FetchResponse_v4, FetchResponse_v5, + FetchResponse_v6, FetchResponse_v7, FetchResponse_v8, + FetchResponse_v9, FetchResponse_v10, FetchResponse_v11, +] diff --git a/protocol/frame.py b/protocol/frame.py new file mode 100644 index 00000000..7b4a32bc --- /dev/null +++ b/protocol/frame.py @@ -0,0 +1,30 @@ +class KafkaBytes(bytearray): + def __init__(self, size): + super(KafkaBytes, self).__init__(size) + self._idx = 0 + + def read(self, nbytes=None): + if nbytes is None: + nbytes = len(self) - self._idx + start = self._idx + self._idx += nbytes + if self._idx > len(self): + self._idx = len(self) + return bytes(self[start:self._idx]) + + def write(self, data): + start = self._idx + self._idx += len(data) + self[start:self._idx] = data + + def seek(self, idx): + self._idx = idx + + def tell(self): + return self._idx + + def __str__(self): + return 'KafkaBytes(%d)' % len(self) + + def __repr__(self): + return str(self) diff --git a/protocol/group.py b/protocol/group.py new file mode 100644 index 00000000..bcb96553 --- /dev/null +++ b/protocol/group.py @@ -0,0 +1,230 @@ +from __future__ import absolute_import + +from kafka.protocol.api import Request, Response +from kafka.protocol.struct import Struct +from kafka.protocol.types import Array, Bytes, Int16, Int32, Schema, String + + +class JoinGroupResponse_v0(Response): + API_KEY = 11 + API_VERSION = 0 + SCHEMA = Schema( + ('error_code', Int16), + ('generation_id', Int32), + ('group_protocol', String('utf-8')), + ('leader_id', String('utf-8')), + ('member_id', String('utf-8')), + ('members', Array( + ('member_id', String('utf-8')), + ('member_metadata', Bytes))) + ) + + +class JoinGroupResponse_v1(Response): + API_KEY = 11 + API_VERSION = 1 + SCHEMA = JoinGroupResponse_v0.SCHEMA + + +class JoinGroupResponse_v2(Response): + API_KEY = 11 + API_VERSION = 2 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16), + ('generation_id', Int32), + ('group_protocol', String('utf-8')), + ('leader_id', String('utf-8')), + ('member_id', String('utf-8')), + ('members', Array( + ('member_id', String('utf-8')), + ('member_metadata', Bytes))) + ) + + +class JoinGroupRequest_v0(Request): + API_KEY = 11 + API_VERSION = 0 + RESPONSE_TYPE = JoinGroupResponse_v0 + SCHEMA = Schema( + ('group', String('utf-8')), + ('session_timeout', Int32), + ('member_id', String('utf-8')), + 
('protocol_type', String('utf-8')), + ('group_protocols', Array( + ('protocol_name', String('utf-8')), + ('protocol_metadata', Bytes))) + ) + UNKNOWN_MEMBER_ID = '' + + +class JoinGroupRequest_v1(Request): + API_KEY = 11 + API_VERSION = 1 + RESPONSE_TYPE = JoinGroupResponse_v1 + SCHEMA = Schema( + ('group', String('utf-8')), + ('session_timeout', Int32), + ('rebalance_timeout', Int32), + ('member_id', String('utf-8')), + ('protocol_type', String('utf-8')), + ('group_protocols', Array( + ('protocol_name', String('utf-8')), + ('protocol_metadata', Bytes))) + ) + UNKNOWN_MEMBER_ID = '' + + +class JoinGroupRequest_v2(Request): + API_KEY = 11 + API_VERSION = 2 + RESPONSE_TYPE = JoinGroupResponse_v2 + SCHEMA = JoinGroupRequest_v1.SCHEMA + UNKNOWN_MEMBER_ID = '' + + +JoinGroupRequest = [ + JoinGroupRequest_v0, JoinGroupRequest_v1, JoinGroupRequest_v2 +] +JoinGroupResponse = [ + JoinGroupResponse_v0, JoinGroupResponse_v1, JoinGroupResponse_v2 +] + + +class ProtocolMetadata(Struct): + SCHEMA = Schema( + ('version', Int16), + ('subscription', Array(String('utf-8'))), # topics list + ('user_data', Bytes) + ) + + +class SyncGroupResponse_v0(Response): + API_KEY = 14 + API_VERSION = 0 + SCHEMA = Schema( + ('error_code', Int16), + ('member_assignment', Bytes) + ) + + +class SyncGroupResponse_v1(Response): + API_KEY = 14 + API_VERSION = 1 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16), + ('member_assignment', Bytes) + ) + + +class SyncGroupRequest_v0(Request): + API_KEY = 14 + API_VERSION = 0 + RESPONSE_TYPE = SyncGroupResponse_v0 + SCHEMA = Schema( + ('group', String('utf-8')), + ('generation_id', Int32), + ('member_id', String('utf-8')), + ('group_assignment', Array( + ('member_id', String('utf-8')), + ('member_metadata', Bytes))) + ) + + +class SyncGroupRequest_v1(Request): + API_KEY = 14 + API_VERSION = 1 + RESPONSE_TYPE = SyncGroupResponse_v1 + SCHEMA = SyncGroupRequest_v0.SCHEMA + + +SyncGroupRequest = [SyncGroupRequest_v0, SyncGroupRequest_v1] +SyncGroupResponse = [SyncGroupResponse_v0, SyncGroupResponse_v1] + + +class MemberAssignment(Struct): + SCHEMA = Schema( + ('version', Int16), + ('assignment', Array( + ('topic', String('utf-8')), + ('partitions', Array(Int32)))), + ('user_data', Bytes) + ) + + +class HeartbeatResponse_v0(Response): + API_KEY = 12 + API_VERSION = 0 + SCHEMA = Schema( + ('error_code', Int16) + ) + + +class HeartbeatResponse_v1(Response): + API_KEY = 12 + API_VERSION = 1 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16) + ) + + +class HeartbeatRequest_v0(Request): + API_KEY = 12 + API_VERSION = 0 + RESPONSE_TYPE = HeartbeatResponse_v0 + SCHEMA = Schema( + ('group', String('utf-8')), + ('generation_id', Int32), + ('member_id', String('utf-8')) + ) + + +class HeartbeatRequest_v1(Request): + API_KEY = 12 + API_VERSION = 1 + RESPONSE_TYPE = HeartbeatResponse_v1 + SCHEMA = HeartbeatRequest_v0.SCHEMA + + +HeartbeatRequest = [HeartbeatRequest_v0, HeartbeatRequest_v1] +HeartbeatResponse = [HeartbeatResponse_v0, HeartbeatResponse_v1] + + +class LeaveGroupResponse_v0(Response): + API_KEY = 13 + API_VERSION = 0 + SCHEMA = Schema( + ('error_code', Int16) + ) + + +class LeaveGroupResponse_v1(Response): + API_KEY = 13 + API_VERSION = 1 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('error_code', Int16) + ) + + +class LeaveGroupRequest_v0(Request): + API_KEY = 13 + API_VERSION = 0 + RESPONSE_TYPE = LeaveGroupResponse_v0 + SCHEMA = Schema( + ('group', String('utf-8')), + ('member_id', String('utf-8')) + ) + + +class 
LeaveGroupRequest_v1(Request): + API_KEY = 13 + API_VERSION = 1 + RESPONSE_TYPE = LeaveGroupResponse_v1 + SCHEMA = LeaveGroupRequest_v0.SCHEMA + + +LeaveGroupRequest = [LeaveGroupRequest_v0, LeaveGroupRequest_v1] +LeaveGroupResponse = [LeaveGroupResponse_v0, LeaveGroupResponse_v1] diff --git a/protocol/message.py b/protocol/message.py new file mode 100644 index 00000000..4c5c031b --- /dev/null +++ b/protocol/message.py @@ -0,0 +1,216 @@ +from __future__ import absolute_import + +import io +import time + +from kafka.codec import (has_gzip, has_snappy, has_lz4, has_zstd, + gzip_decode, snappy_decode, zstd_decode, + lz4_decode, lz4_decode_old_kafka) +from kafka.protocol.frame import KafkaBytes +from kafka.protocol.struct import Struct +from kafka.protocol.types import ( + Int8, Int32, Int64, Bytes, Schema, AbstractType +) +from kafka.util import crc32, WeakMethod + + +class Message(Struct): + SCHEMAS = [ + Schema( + ('crc', Int32), + ('magic', Int8), + ('attributes', Int8), + ('key', Bytes), + ('value', Bytes)), + Schema( + ('crc', Int32), + ('magic', Int8), + ('attributes', Int8), + ('timestamp', Int64), + ('key', Bytes), + ('value', Bytes)), + ] + SCHEMA = SCHEMAS[1] + CODEC_MASK = 0x07 + CODEC_GZIP = 0x01 + CODEC_SNAPPY = 0x02 + CODEC_LZ4 = 0x03 + CODEC_ZSTD = 0x04 + TIMESTAMP_TYPE_MASK = 0x08 + HEADER_SIZE = 22 # crc(4), magic(1), attributes(1), timestamp(8), key+value size(4*2) + + def __init__(self, value, key=None, magic=0, attributes=0, crc=0, + timestamp=None): + assert value is None or isinstance(value, bytes), 'value must be bytes' + assert key is None or isinstance(key, bytes), 'key must be bytes' + assert magic > 0 or timestamp is None, 'timestamp not supported in v0' + + # Default timestamp to now for v1 messages + if magic > 0 and timestamp is None: + timestamp = int(time.time() * 1000) + self.timestamp = timestamp + self.crc = crc + self._validated_crc = None + self.magic = magic + self.attributes = attributes + self.key = key + self.value = value + self.encode = WeakMethod(self._encode_self) + + @property + def timestamp_type(self): + """0 for CreateTime; 1 for LogAppendTime; None if unsupported. 
+ + Value is determined by broker; produced messages should always set to 0 + Requires Kafka >= 0.10 / message version >= 1 + """ + if self.magic == 0: + return None + elif self.attributes & self.TIMESTAMP_TYPE_MASK: + return 1 + else: + return 0 + + def _encode_self(self, recalc_crc=True): + version = self.magic + if version == 1: + fields = (self.crc, self.magic, self.attributes, self.timestamp, self.key, self.value) + elif version == 0: + fields = (self.crc, self.magic, self.attributes, self.key, self.value) + else: + raise ValueError('Unrecognized message version: %s' % (version,)) + message = Message.SCHEMAS[version].encode(fields) + if not recalc_crc: + return message + self.crc = crc32(message[4:]) + crc_field = self.SCHEMAS[version].fields[0] + return crc_field.encode(self.crc) + message[4:] + + @classmethod + def decode(cls, data): + _validated_crc = None + if isinstance(data, bytes): + _validated_crc = crc32(data[4:]) + data = io.BytesIO(data) + # Partial decode required to determine message version + base_fields = cls.SCHEMAS[0].fields[0:3] + crc, magic, attributes = [field.decode(data) for field in base_fields] + remaining = cls.SCHEMAS[magic].fields[3:] + fields = [field.decode(data) for field in remaining] + if magic == 1: + timestamp = fields[0] + else: + timestamp = None + msg = cls(fields[-1], key=fields[-2], + magic=magic, attributes=attributes, crc=crc, + timestamp=timestamp) + msg._validated_crc = _validated_crc + return msg + + def validate_crc(self): + if self._validated_crc is None: + raw_msg = self._encode_self(recalc_crc=False) + self._validated_crc = crc32(raw_msg[4:]) + if self.crc == self._validated_crc: + return True + return False + + def is_compressed(self): + return self.attributes & self.CODEC_MASK != 0 + + def decompress(self): + codec = self.attributes & self.CODEC_MASK + assert codec in (self.CODEC_GZIP, self.CODEC_SNAPPY, self.CODEC_LZ4, self.CODEC_ZSTD) + if codec == self.CODEC_GZIP: + assert has_gzip(), 'Gzip decompression unsupported' + raw_bytes = gzip_decode(self.value) + elif codec == self.CODEC_SNAPPY: + assert has_snappy(), 'Snappy decompression unsupported' + raw_bytes = snappy_decode(self.value) + elif codec == self.CODEC_LZ4: + assert has_lz4(), 'LZ4 decompression unsupported' + if self.magic == 0: + raw_bytes = lz4_decode_old_kafka(self.value) + else: + raw_bytes = lz4_decode(self.value) + elif codec == self.CODEC_ZSTD: + assert has_zstd(), "ZSTD decompression unsupported" + raw_bytes = zstd_decode(self.value) + else: + raise Exception('This should be impossible') + + return MessageSet.decode(raw_bytes, bytes_to_read=len(raw_bytes)) + + def __hash__(self): + return hash(self._encode_self(recalc_crc=False)) + + +class PartialMessage(bytes): + def __repr__(self): + return 'PartialMessage(%s)' % (self,) + + +class MessageSet(AbstractType): + ITEM = Schema( + ('offset', Int64), + ('message', Bytes) + ) + HEADER_SIZE = 12 # offset + message_size + + @classmethod + def encode(cls, items, prepend_size=True): + # RecordAccumulator encodes messagesets internally + if isinstance(items, (io.BytesIO, KafkaBytes)): + size = Int32.decode(items) + if prepend_size: + # rewind and return all the bytes + items.seek(items.tell() - 4) + size += 4 + return items.read(size) + + encoded_values = [] + for (offset, message) in items: + encoded_values.append(Int64.encode(offset)) + encoded_values.append(Bytes.encode(message)) + encoded = b''.join(encoded_values) + if prepend_size: + return Bytes.encode(encoded) + else: + return encoded + + @classmethod + def 
decode(cls, data, bytes_to_read=None): + """Compressed messages should pass in bytes_to_read (via message size) + otherwise, we decode from data as Int32 + """ + if isinstance(data, bytes): + data = io.BytesIO(data) + if bytes_to_read is None: + bytes_to_read = Int32.decode(data) + + # if FetchRequest max_bytes is smaller than the available message set + # the server returns partial data for the final message + # So create an internal buffer to avoid over-reading + raw = io.BytesIO(data.read(bytes_to_read)) + + items = [] + while bytes_to_read: + try: + offset = Int64.decode(raw) + msg_bytes = Bytes.decode(raw) + bytes_to_read -= 8 + 4 + len(msg_bytes) + items.append((offset, len(msg_bytes), Message.decode(msg_bytes))) + except ValueError: + # PartialMessage to signal that max_bytes may be too small + items.append((None, None, PartialMessage())) + break + return items + + @classmethod + def repr(cls, messages): + if isinstance(messages, (KafkaBytes, io.BytesIO)): + offset = messages.tell() + decoded = cls.decode(messages) + messages.seek(offset) + messages = decoded + return str([cls.ITEM.repr(m) for m in messages]) diff --git a/protocol/metadata.py b/protocol/metadata.py new file mode 100644 index 00000000..414e5b84 --- /dev/null +++ b/protocol/metadata.py @@ -0,0 +1,200 @@ +from __future__ import absolute_import + +from kafka.protocol.api import Request, Response +from kafka.protocol.types import Array, Boolean, Int16, Int32, Schema, String + + +class MetadataResponse_v0(Response): + API_KEY = 3 + API_VERSION = 0 + SCHEMA = Schema( + ('brokers', Array( + ('node_id', Int32), + ('host', String('utf-8')), + ('port', Int32))), + ('topics', Array( + ('error_code', Int16), + ('topic', String('utf-8')), + ('partitions', Array( + ('error_code', Int16), + ('partition', Int32), + ('leader', Int32), + ('replicas', Array(Int32)), + ('isr', Array(Int32)))))) + ) + + +class MetadataResponse_v1(Response): + API_KEY = 3 + API_VERSION = 1 + SCHEMA = Schema( + ('brokers', Array( + ('node_id', Int32), + ('host', String('utf-8')), + ('port', Int32), + ('rack', String('utf-8')))), + ('controller_id', Int32), + ('topics', Array( + ('error_code', Int16), + ('topic', String('utf-8')), + ('is_internal', Boolean), + ('partitions', Array( + ('error_code', Int16), + ('partition', Int32), + ('leader', Int32), + ('replicas', Array(Int32)), + ('isr', Array(Int32)))))) + ) + + +class MetadataResponse_v2(Response): + API_KEY = 3 + API_VERSION = 2 + SCHEMA = Schema( + ('brokers', Array( + ('node_id', Int32), + ('host', String('utf-8')), + ('port', Int32), + ('rack', String('utf-8')))), + ('cluster_id', String('utf-8')), # <-- Added cluster_id field in v2 + ('controller_id', Int32), + ('topics', Array( + ('error_code', Int16), + ('topic', String('utf-8')), + ('is_internal', Boolean), + ('partitions', Array( + ('error_code', Int16), + ('partition', Int32), + ('leader', Int32), + ('replicas', Array(Int32)), + ('isr', Array(Int32)))))) + ) + + +class MetadataResponse_v3(Response): + API_KEY = 3 + API_VERSION = 3 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('brokers', Array( + ('node_id', Int32), + ('host', String('utf-8')), + ('port', Int32), + ('rack', String('utf-8')))), + ('cluster_id', String('utf-8')), + ('controller_id', Int32), + ('topics', Array( + ('error_code', Int16), + ('topic', String('utf-8')), + ('is_internal', Boolean), + ('partitions', Array( + ('error_code', Int16), + ('partition', Int32), + ('leader', Int32), + ('replicas', Array(Int32)), + ('isr', Array(Int32)))))) + ) + + +class 
MetadataResponse_v4(Response): + API_KEY = 3 + API_VERSION = 4 + SCHEMA = MetadataResponse_v3.SCHEMA + + +class MetadataResponse_v5(Response): + API_KEY = 3 + API_VERSION = 5 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('brokers', Array( + ('node_id', Int32), + ('host', String('utf-8')), + ('port', Int32), + ('rack', String('utf-8')))), + ('cluster_id', String('utf-8')), + ('controller_id', Int32), + ('topics', Array( + ('error_code', Int16), + ('topic', String('utf-8')), + ('is_internal', Boolean), + ('partitions', Array( + ('error_code', Int16), + ('partition', Int32), + ('leader', Int32), + ('replicas', Array(Int32)), + ('isr', Array(Int32)), + ('offline_replicas', Array(Int32)))))) + ) + + +class MetadataRequest_v0(Request): + API_KEY = 3 + API_VERSION = 0 + RESPONSE_TYPE = MetadataResponse_v0 + SCHEMA = Schema( + ('topics', Array(String('utf-8'))) + ) + ALL_TOPICS = None # Empty Array (len 0) for topics returns all topics + + +class MetadataRequest_v1(Request): + API_KEY = 3 + API_VERSION = 1 + RESPONSE_TYPE = MetadataResponse_v1 + SCHEMA = MetadataRequest_v0.SCHEMA + ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS = None # Empty array (len 0) for topics returns no topics + + +class MetadataRequest_v2(Request): + API_KEY = 3 + API_VERSION = 2 + RESPONSE_TYPE = MetadataResponse_v2 + SCHEMA = MetadataRequest_v1.SCHEMA + ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS = None # Empty array (len 0) for topics returns no topics + + +class MetadataRequest_v3(Request): + API_KEY = 3 + API_VERSION = 3 + RESPONSE_TYPE = MetadataResponse_v3 + SCHEMA = MetadataRequest_v1.SCHEMA + ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS = None # Empty array (len 0) for topics returns no topics + + +class MetadataRequest_v4(Request): + API_KEY = 3 + API_VERSION = 4 + RESPONSE_TYPE = MetadataResponse_v4 + SCHEMA = Schema( + ('topics', Array(String('utf-8'))), + ('allow_auto_topic_creation', Boolean) + ) + ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS = None # Empty array (len 0) for topics returns no topics + + +class MetadataRequest_v5(Request): + """ + The v5 metadata request is the same as v4. 
+ An additional field for offline_replicas has been added to the v5 metadata response + """ + API_KEY = 3 + API_VERSION = 5 + RESPONSE_TYPE = MetadataResponse_v5 + SCHEMA = MetadataRequest_v4.SCHEMA + ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS = None # Empty array (len 0) for topics returns no topics + + +MetadataRequest = [ + MetadataRequest_v0, MetadataRequest_v1, MetadataRequest_v2, + MetadataRequest_v3, MetadataRequest_v4, MetadataRequest_v5 +] +MetadataResponse = [ + MetadataResponse_v0, MetadataResponse_v1, MetadataResponse_v2, + MetadataResponse_v3, MetadataResponse_v4, MetadataResponse_v5 +] diff --git a/protocol/offset.py b/protocol/offset.py new file mode 100644 index 00000000..1ed382b0 --- /dev/null +++ b/protocol/offset.py @@ -0,0 +1,194 @@ +from __future__ import absolute_import + +from kafka.protocol.api import Request, Response +from kafka.protocol.types import Array, Int8, Int16, Int32, Int64, Schema, String + +UNKNOWN_OFFSET = -1 + + +class OffsetResetStrategy(object): + LATEST = -1 + EARLIEST = -2 + NONE = 0 + + +class OffsetResponse_v0(Response): + API_KEY = 2 + API_VERSION = 0 + SCHEMA = Schema( + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('error_code', Int16), + ('offsets', Array(Int64)))))) + ) + +class OffsetResponse_v1(Response): + API_KEY = 2 + API_VERSION = 1 + SCHEMA = Schema( + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('error_code', Int16), + ('timestamp', Int64), + ('offset', Int64))))) + ) + + +class OffsetResponse_v2(Response): + API_KEY = 2 + API_VERSION = 2 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('error_code', Int16), + ('timestamp', Int64), + ('offset', Int64))))) + ) + + +class OffsetResponse_v3(Response): + """ + on quota violation, brokers send out responses before throttling + """ + API_KEY = 2 + API_VERSION = 3 + SCHEMA = OffsetResponse_v2.SCHEMA + + +class OffsetResponse_v4(Response): + """ + Add leader_epoch to response + """ + API_KEY = 2 + API_VERSION = 4 + SCHEMA = Schema( + ('throttle_time_ms', Int32), + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('error_code', Int16), + ('timestamp', Int64), + ('offset', Int64), + ('leader_epoch', Int32))))) + ) + + +class OffsetResponse_v5(Response): + """ + adds a new error code, OFFSET_NOT_AVAILABLE + """ + API_KEY = 2 + API_VERSION = 5 + SCHEMA = OffsetResponse_v4.SCHEMA + + +class OffsetRequest_v0(Request): + API_KEY = 2 + API_VERSION = 0 + RESPONSE_TYPE = OffsetResponse_v0 + SCHEMA = Schema( + ('replica_id', Int32), + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('timestamp', Int64), + ('max_offsets', Int32))))) + ) + DEFAULTS = { + 'replica_id': -1 + } + +class OffsetRequest_v1(Request): + API_KEY = 2 + API_VERSION = 1 + RESPONSE_TYPE = OffsetResponse_v1 + SCHEMA = Schema( + ('replica_id', Int32), + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('timestamp', Int64))))) + ) + DEFAULTS = { + 'replica_id': -1 + } + + +class OffsetRequest_v2(Request): + API_KEY = 2 + API_VERSION = 2 + RESPONSE_TYPE = OffsetResponse_v2 + SCHEMA = Schema( + ('replica_id', Int32), + ('isolation_level', Int8), # <- added isolation_level + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + 
('partition', Int32), + ('timestamp', Int64))))) + ) + DEFAULTS = { + 'replica_id': -1 + } + + +class OffsetRequest_v3(Request): + API_KEY = 2 + API_VERSION = 3 + RESPONSE_TYPE = OffsetResponse_v3 + SCHEMA = OffsetRequest_v2.SCHEMA + DEFAULTS = { + 'replica_id': -1 + } + + +class OffsetRequest_v4(Request): + """ + Add current_leader_epoch to request + """ + API_KEY = 2 + API_VERSION = 4 + RESPONSE_TYPE = OffsetResponse_v4 + SCHEMA = Schema( + ('replica_id', Int32), + ('isolation_level', Int8), # <- added isolation_level + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('current_leader_epoch', Int64), + ('timestamp', Int64))))) + ) + DEFAULTS = { + 'replica_id': -1 + } + + +class OffsetRequest_v5(Request): + API_KEY = 2 + API_VERSION = 5 + RESPONSE_TYPE = OffsetResponse_v5 + SCHEMA = OffsetRequest_v4.SCHEMA + DEFAULTS = { + 'replica_id': -1 + } + + +OffsetRequest = [ + OffsetRequest_v0, OffsetRequest_v1, OffsetRequest_v2, + OffsetRequest_v3, OffsetRequest_v4, OffsetRequest_v5, +] +OffsetResponse = [ + OffsetResponse_v0, OffsetResponse_v1, OffsetResponse_v2, + OffsetResponse_v3, OffsetResponse_v4, OffsetResponse_v5, +] diff --git a/protocol/parser.py b/protocol/parser.py new file mode 100644 index 00000000..a9e76722 --- /dev/null +++ b/protocol/parser.py @@ -0,0 +1,176 @@ +from __future__ import absolute_import + +import collections +import logging + +import kafka.errors as Errors +from kafka.protocol.commit import GroupCoordinatorResponse +from kafka.protocol.frame import KafkaBytes +from kafka.protocol.types import Int32, TaggedFields +from kafka.version import __version__ + +log = logging.getLogger(__name__) + + +class KafkaProtocol(object): + """Manage the kafka network protocol + + Use an instance of KafkaProtocol to manage bytes send/recv'd + from a network socket to a broker. + + Arguments: + client_id (str): identifier string to be included in each request + api_version (tuple): Optional tuple to specify api_version to use. + Currently only used to check for 0.8.2 protocol quirks, but + may be used for more in the future. + """ + def __init__(self, client_id=None, api_version=None): + if client_id is None: + client_id = self._gen_client_id() + self._client_id = client_id + self._api_version = api_version + self._correlation_id = 0 + self._header = KafkaBytes(4) + self._rbuffer = None + self._receiving = False + self.in_flight_requests = collections.deque() + self.bytes_to_send = [] + + def _next_correlation_id(self): + self._correlation_id = (self._correlation_id + 1) % 2**31 + return self._correlation_id + + def _gen_client_id(self): + return 'kafka-python' + __version__ + + def send_request(self, request, correlation_id=None): + """Encode and queue a kafka api request for sending. + + Arguments: + request (object): An un-encoded kafka request. + correlation_id (int, optional): Optionally specify an ID to + correlate requests with responses. If not provided, an ID will + be generated automatically. 
+ + Returns: + correlation_id + """ + log.debug('Sending request %s', request) + if correlation_id is None: + correlation_id = self._next_correlation_id() + + header = request.build_request_header(correlation_id=correlation_id, client_id=self._client_id) + message = b''.join([header.encode(), request.encode()]) + size = Int32.encode(len(message)) + data = size + message + self.bytes_to_send.append(data) + if request.expect_response(): + ifr = (correlation_id, request) + self.in_flight_requests.append(ifr) + return correlation_id + + def send_bytes(self): + """Retrieve all pending bytes to send on the network""" + data = b''.join(self.bytes_to_send) + self.bytes_to_send = [] + return data + + def receive_bytes(self, data): + """Process bytes received from the network. + + Arguments: + data (bytes): any length bytes received from a network connection + to a kafka broker. + + Returns: + responses (list of (correlation_id, response)): any/all completed + responses, decoded from bytes to python objects. + + Raises: + KafkaProtocolError: if the bytes received could not be decoded. + CorrelationIdError: if the response does not match the request + correlation id. + """ + i = 0 + n = len(data) + responses = [] + while i < n: + + # Not receiving is the state of reading the payload header + if not self._receiving: + bytes_to_read = min(4 - self._header.tell(), n - i) + self._header.write(data[i:i+bytes_to_read]) + i += bytes_to_read + + if self._header.tell() == 4: + self._header.seek(0) + nbytes = Int32.decode(self._header) + # reset buffer and switch state to receiving payload bytes + self._rbuffer = KafkaBytes(nbytes) + self._receiving = True + elif self._header.tell() > 4: + raise Errors.KafkaError('this should not happen - are you threading?') + + if self._receiving: + total_bytes = len(self._rbuffer) + staged_bytes = self._rbuffer.tell() + bytes_to_read = min(total_bytes - staged_bytes, n - i) + self._rbuffer.write(data[i:i+bytes_to_read]) + i += bytes_to_read + + staged_bytes = self._rbuffer.tell() + if staged_bytes > total_bytes: + raise Errors.KafkaError('Receive buffer has more bytes than expected?') + + if staged_bytes != total_bytes: + break + + self._receiving = False + self._rbuffer.seek(0) + resp = self._process_response(self._rbuffer) + responses.append(resp) + self._reset_buffer() + return responses + + def _process_response(self, read_buffer): + if not self.in_flight_requests: + raise Errors.CorrelationIdError('No in-flight-request found for server response') + (correlation_id, request) = self.in_flight_requests.popleft() + response_header = request.parse_response_header(read_buffer) + recv_correlation_id = response_header.correlation_id + log.debug('Received correlation id: %d', recv_correlation_id) + # 0.8.2 quirk + if (recv_correlation_id == 0 and + correlation_id != 0 and + request.RESPONSE_TYPE is GroupCoordinatorResponse[0] and + (self._api_version == (0, 8, 2) or self._api_version is None)): + log.warning('Kafka 0.8.2 quirk -- GroupCoordinatorResponse' + ' Correlation ID does not match request. This' + ' should go away once at least one topic has been' + ' initialized on the broker.') + + elif correlation_id != recv_correlation_id: + # return or raise? 
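            # (editor's note, not part of the vendored source) raising, rather
            # than returning the mismatched pair, prevents a desynchronised
            # response from being delivered against the wrong in-flight
            # request; the caller is presumably expected to treat this as a
            # fatal error for the connection.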
+ raise Errors.CorrelationIdError( + 'Correlation IDs do not match: sent %d, recv %d' + % (correlation_id, recv_correlation_id)) + + # decode response + log.debug('Processing response %s', request.RESPONSE_TYPE.__name__) + try: + response = request.RESPONSE_TYPE.decode(read_buffer) + except ValueError: + read_buffer.seek(0) + buf = read_buffer.read() + log.error('Response %d [ResponseType: %s Request: %s]:' + ' Unable to decode %d-byte buffer: %r', + correlation_id, request.RESPONSE_TYPE, + request, len(buf), buf) + raise Errors.KafkaProtocolError('Unable to decode response') + + return (correlation_id, response) + + def _reset_buffer(self): + self._receiving = False + self._header.seek(0) + self._rbuffer = None diff --git a/protocol/pickle.py b/protocol/pickle.py new file mode 100644 index 00000000..d6e5fa74 --- /dev/null +++ b/protocol/pickle.py @@ -0,0 +1,35 @@ +from __future__ import absolute_import + +try: + import copyreg # pylint: disable=import-error +except ImportError: + import copy_reg as copyreg # pylint: disable=import-error + +import types + + +def _pickle_method(method): + try: + func_name = method.__func__.__name__ + obj = method.__self__ + cls = method.__self__.__class__ + except AttributeError: + func_name = method.im_func.__name__ + obj = method.im_self + cls = method.im_class + + return _unpickle_method, (func_name, obj, cls) + + +def _unpickle_method(func_name, obj, cls): + for cls in cls.mro(): + try: + func = cls.__dict__[func_name] + except KeyError: + pass + else: + break + return func.__get__(obj, cls) + +# https://bytes.com/topic/python/answers/552476-why-cant-you-pickle-instancemethods +copyreg.pickle(types.MethodType, _pickle_method, _unpickle_method) diff --git a/protocol/produce.py b/protocol/produce.py new file mode 100644 index 00000000..9b3f6bf5 --- /dev/null +++ b/protocol/produce.py @@ -0,0 +1,232 @@ +from __future__ import absolute_import + +from kafka.protocol.api import Request, Response +from kafka.protocol.types import Int16, Int32, Int64, String, Array, Schema, Bytes + + +class ProduceResponse_v0(Response): + API_KEY = 0 + API_VERSION = 0 + SCHEMA = Schema( + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('error_code', Int16), + ('offset', Int64))))) + ) + + +class ProduceResponse_v1(Response): + API_KEY = 0 + API_VERSION = 1 + SCHEMA = Schema( + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('error_code', Int16), + ('offset', Int64))))), + ('throttle_time_ms', Int32) + ) + + +class ProduceResponse_v2(Response): + API_KEY = 0 + API_VERSION = 2 + SCHEMA = Schema( + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('error_code', Int16), + ('offset', Int64), + ('timestamp', Int64))))), + ('throttle_time_ms', Int32) + ) + + +class ProduceResponse_v3(Response): + API_KEY = 0 + API_VERSION = 3 + SCHEMA = ProduceResponse_v2.SCHEMA + + +class ProduceResponse_v4(Response): + """ + The version number is bumped up to indicate that the client supports KafkaStorageException. 
+ The KafkaStorageException will be translated to NotLeaderForPartitionException in the response if version <= 3 + """ + API_KEY = 0 + API_VERSION = 4 + SCHEMA = ProduceResponse_v3.SCHEMA + + +class ProduceResponse_v5(Response): + API_KEY = 0 + API_VERSION = 5 + SCHEMA = Schema( + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('error_code', Int16), + ('offset', Int64), + ('timestamp', Int64), + ('log_start_offset', Int64))))), + ('throttle_time_ms', Int32) + ) + + +class ProduceResponse_v6(Response): + """ + The version number is bumped to indicate that on quota violation brokers send out responses before throttling. + """ + API_KEY = 0 + API_VERSION = 6 + SCHEMA = ProduceResponse_v5.SCHEMA + + +class ProduceResponse_v7(Response): + """ + V7 bumped up to indicate ZStandard capability. (see KIP-110) + """ + API_KEY = 0 + API_VERSION = 7 + SCHEMA = ProduceResponse_v6.SCHEMA + + +class ProduceResponse_v8(Response): + """ + V8 bumped up to add two new fields record_errors offset list and error_message + (See KIP-467) + """ + API_KEY = 0 + API_VERSION = 8 + SCHEMA = Schema( + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('error_code', Int16), + ('offset', Int64), + ('timestamp', Int64), + ('log_start_offset', Int64)), + ('record_errors', (Array( + ('batch_index', Int32), + ('batch_index_error_message', String('utf-8')) + ))), + ('error_message', String('utf-8')) + ))), + ('throttle_time_ms', Int32) + ) + + +class ProduceRequest(Request): + API_KEY = 0 + + def expect_response(self): + if self.required_acks == 0: # pylint: disable=no-member + return False + return True + + +class ProduceRequest_v0(ProduceRequest): + API_VERSION = 0 + RESPONSE_TYPE = ProduceResponse_v0 + SCHEMA = Schema( + ('required_acks', Int16), + ('timeout', Int32), + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('messages', Bytes))))) + ) + + +class ProduceRequest_v1(ProduceRequest): + API_VERSION = 1 + RESPONSE_TYPE = ProduceResponse_v1 + SCHEMA = ProduceRequest_v0.SCHEMA + + +class ProduceRequest_v2(ProduceRequest): + API_VERSION = 2 + RESPONSE_TYPE = ProduceResponse_v2 + SCHEMA = ProduceRequest_v1.SCHEMA + + +class ProduceRequest_v3(ProduceRequest): + API_VERSION = 3 + RESPONSE_TYPE = ProduceResponse_v3 + SCHEMA = Schema( + ('transactional_id', String('utf-8')), + ('required_acks', Int16), + ('timeout', Int32), + ('topics', Array( + ('topic', String('utf-8')), + ('partitions', Array( + ('partition', Int32), + ('messages', Bytes))))) + ) + + +class ProduceRequest_v4(ProduceRequest): + """ + The version number is bumped up to indicate that the client supports KafkaStorageException. + The KafkaStorageException will be translated to NotLeaderForPartitionException in the response if version <= 3 + """ + API_VERSION = 4 + RESPONSE_TYPE = ProduceResponse_v4 + SCHEMA = ProduceRequest_v3.SCHEMA + + +class ProduceRequest_v5(ProduceRequest): + """ + Same as v4. The version number is bumped since the v5 response includes an additional + partition level field: the log_start_offset. + """ + API_VERSION = 5 + RESPONSE_TYPE = ProduceResponse_v5 + SCHEMA = ProduceRequest_v4.SCHEMA + + +class ProduceRequest_v6(ProduceRequest): + """ + The version number is bumped to indicate that on quota violation brokers send out responses before throttling. 
+ """ + API_VERSION = 6 + RESPONSE_TYPE = ProduceResponse_v6 + SCHEMA = ProduceRequest_v5.SCHEMA + + +class ProduceRequest_v7(ProduceRequest): + """ + V7 bumped up to indicate ZStandard capability. (see KIP-110) + """ + API_VERSION = 7 + RESPONSE_TYPE = ProduceResponse_v7 + SCHEMA = ProduceRequest_v6.SCHEMA + + +class ProduceRequest_v8(ProduceRequest): + """ + V8 bumped up to add two new fields record_errors offset list and error_message to PartitionResponse + (See KIP-467) + """ + API_VERSION = 8 + RESPONSE_TYPE = ProduceResponse_v8 + SCHEMA = ProduceRequest_v7.SCHEMA + + +ProduceRequest = [ + ProduceRequest_v0, ProduceRequest_v1, ProduceRequest_v2, + ProduceRequest_v3, ProduceRequest_v4, ProduceRequest_v5, + ProduceRequest_v6, ProduceRequest_v7, ProduceRequest_v8, +] +ProduceResponse = [ + ProduceResponse_v0, ProduceResponse_v1, ProduceResponse_v2, + ProduceResponse_v3, ProduceResponse_v4, ProduceResponse_v5, + ProduceResponse_v6, ProduceResponse_v7, ProduceResponse_v8, +] diff --git a/protocol/struct.py b/protocol/struct.py new file mode 100644 index 00000000..e9da6e6c --- /dev/null +++ b/protocol/struct.py @@ -0,0 +1,72 @@ +from __future__ import absolute_import + +from io import BytesIO + +from kafka.protocol.abstract import AbstractType +from kafka.protocol.types import Schema + +from kafka.util import WeakMethod + + +class Struct(AbstractType): + SCHEMA = Schema() + + def __init__(self, *args, **kwargs): + if len(args) == len(self.SCHEMA.fields): + for i, name in enumerate(self.SCHEMA.names): + self.__dict__[name] = args[i] + elif len(args) > 0: + raise ValueError('Args must be empty or mirror schema') + else: + for name in self.SCHEMA.names: + self.__dict__[name] = kwargs.pop(name, None) + if kwargs: + raise ValueError('Keyword(s) not in schema %s: %s' + % (list(self.SCHEMA.names), + ', '.join(kwargs.keys()))) + + # overloading encode() to support both class and instance + # Without WeakMethod() this creates circular ref, which + # causes instances to "leak" to garbage + self.encode = WeakMethod(self._encode_self) + + + @classmethod + def encode(cls, item): # pylint: disable=E0202 + bits = [] + for i, field in enumerate(cls.SCHEMA.fields): + bits.append(field.encode(item[i])) + return b''.join(bits) + + def _encode_self(self): + return self.SCHEMA.encode( + [self.__dict__[name] for name in self.SCHEMA.names] + ) + + @classmethod + def decode(cls, data): + if isinstance(data, bytes): + data = BytesIO(data) + return cls(*[field.decode(data) for field in cls.SCHEMA.fields]) + + def get_item(self, name): + if name not in self.SCHEMA.names: + raise KeyError("%s is not in the schema" % name) + return self.__dict__[name] + + def __repr__(self): + key_vals = [] + for name, field in zip(self.SCHEMA.names, self.SCHEMA.fields): + key_vals.append('%s=%s' % (name, field.repr(self.__dict__[name]))) + return self.__class__.__name__ + '(' + ', '.join(key_vals) + ')' + + def __hash__(self): + return hash(self.encode()) + + def __eq__(self, other): + if self.SCHEMA != other.SCHEMA: + return False + for attr in self.SCHEMA.names: + if self.__dict__[attr] != other.__dict__[attr]: + return False + return True diff --git a/protocol/types.py b/protocol/types.py new file mode 100644 index 00000000..0e3685d7 --- /dev/null +++ b/protocol/types.py @@ -0,0 +1,365 @@ +from __future__ import absolute_import + +import struct +from struct import error + +from kafka.protocol.abstract import AbstractType + + +def _pack(f, value): + try: + return f(value) + except error as e: + raise ValueError("Error encountered 
when attempting to convert value: " + "{!r} to struct format: '{}', hit error: {}" + .format(value, f, e)) + + +def _unpack(f, data): + try: + (value,) = f(data) + return value + except error as e: + raise ValueError("Error encountered when attempting to convert value: " + "{!r} to struct format: '{}', hit error: {}" + .format(data, f, e)) + + +class Int8(AbstractType): + _pack = struct.Struct('>b').pack + _unpack = struct.Struct('>b').unpack + + @classmethod + def encode(cls, value): + return _pack(cls._pack, value) + + @classmethod + def decode(cls, data): + return _unpack(cls._unpack, data.read(1)) + + +class Int16(AbstractType): + _pack = struct.Struct('>h').pack + _unpack = struct.Struct('>h').unpack + + @classmethod + def encode(cls, value): + return _pack(cls._pack, value) + + @classmethod + def decode(cls, data): + return _unpack(cls._unpack, data.read(2)) + + +class Int32(AbstractType): + _pack = struct.Struct('>i').pack + _unpack = struct.Struct('>i').unpack + + @classmethod + def encode(cls, value): + return _pack(cls._pack, value) + + @classmethod + def decode(cls, data): + return _unpack(cls._unpack, data.read(4)) + + +class Int64(AbstractType): + _pack = struct.Struct('>q').pack + _unpack = struct.Struct('>q').unpack + + @classmethod + def encode(cls, value): + return _pack(cls._pack, value) + + @classmethod + def decode(cls, data): + return _unpack(cls._unpack, data.read(8)) + + +class Float64(AbstractType): + _pack = struct.Struct('>d').pack + _unpack = struct.Struct('>d').unpack + + @classmethod + def encode(cls, value): + return _pack(cls._pack, value) + + @classmethod + def decode(cls, data): + return _unpack(cls._unpack, data.read(8)) + + +class String(AbstractType): + def __init__(self, encoding='utf-8'): + self.encoding = encoding + + def encode(self, value): + if value is None: + return Int16.encode(-1) + value = str(value).encode(self.encoding) + return Int16.encode(len(value)) + value + + def decode(self, data): + length = Int16.decode(data) + if length < 0: + return None + value = data.read(length) + if len(value) != length: + raise ValueError('Buffer underrun decoding string') + return value.decode(self.encoding) + + +class Bytes(AbstractType): + @classmethod + def encode(cls, value): + if value is None: + return Int32.encode(-1) + else: + return Int32.encode(len(value)) + value + + @classmethod + def decode(cls, data): + length = Int32.decode(data) + if length < 0: + return None + value = data.read(length) + if len(value) != length: + raise ValueError('Buffer underrun decoding Bytes') + return value + + @classmethod + def repr(cls, value): + return repr(value[:100] + b'...' 
if value is not None and len(value) > 100 else value) + + +class Boolean(AbstractType): + _pack = struct.Struct('>?').pack + _unpack = struct.Struct('>?').unpack + + @classmethod + def encode(cls, value): + return _pack(cls._pack, value) + + @classmethod + def decode(cls, data): + return _unpack(cls._unpack, data.read(1)) + + +class Schema(AbstractType): + def __init__(self, *fields): + if fields: + self.names, self.fields = zip(*fields) + else: + self.names, self.fields = (), () + + def encode(self, item): + if len(item) != len(self.fields): + raise ValueError('Item field count does not match Schema') + return b''.join([ + field.encode(item[i]) + for i, field in enumerate(self.fields) + ]) + + def decode(self, data): + return tuple([field.decode(data) for field in self.fields]) + + def __len__(self): + return len(self.fields) + + def repr(self, value): + key_vals = [] + try: + for i in range(len(self)): + try: + field_val = getattr(value, self.names[i]) + except AttributeError: + field_val = value[i] + key_vals.append('%s=%s' % (self.names[i], self.fields[i].repr(field_val))) + return '(' + ', '.join(key_vals) + ')' + except Exception: + return repr(value) + + +class Array(AbstractType): + def __init__(self, *array_of): + if len(array_of) > 1: + self.array_of = Schema(*array_of) + elif len(array_of) == 1 and (isinstance(array_of[0], AbstractType) or + issubclass(array_of[0], AbstractType)): + self.array_of = array_of[0] + else: + raise ValueError('Array instantiated with no array_of type') + + def encode(self, items): + if items is None: + return Int32.encode(-1) + encoded_items = [self.array_of.encode(item) for item in items] + return b''.join( + [Int32.encode(len(encoded_items))] + + encoded_items + ) + + def decode(self, data): + length = Int32.decode(data) + if length == -1: + return None + return [self.array_of.decode(data) for _ in range(length)] + + def repr(self, list_of_items): + if list_of_items is None: + return 'NULL' + return '[' + ', '.join([self.array_of.repr(item) for item in list_of_items]) + ']' + + +class UnsignedVarInt32(AbstractType): + @classmethod + def decode(cls, data): + value, i = 0, 0 + while True: + b, = struct.unpack('B', data.read(1)) + if not (b & 0x80): + break + value |= (b & 0x7f) << i + i += 7 + if i > 28: + raise ValueError('Invalid value {}'.format(value)) + value |= b << i + return value + + @classmethod + def encode(cls, value): + value &= 0xffffffff + ret = b'' + while (value & 0xffffff80) != 0: + b = (value & 0x7f) | 0x80 + ret += struct.pack('B', b) + value >>= 7 + ret += struct.pack('B', value) + return ret + + +class VarInt32(AbstractType): + @classmethod + def decode(cls, data): + value = UnsignedVarInt32.decode(data) + return (value >> 1) ^ -(value & 1) + + @classmethod + def encode(cls, value): + # bring it in line with the java binary repr + value &= 0xffffffff + return UnsignedVarInt32.encode((value << 1) ^ (value >> 31)) + + +class VarInt64(AbstractType): + @classmethod + def decode(cls, data): + value, i = 0, 0 + while True: + b = data.read(1) + if not (b & 0x80): + break + value |= (b & 0x7f) << i + i += 7 + if i > 63: + raise ValueError('Invalid value {}'.format(value)) + value |= b << i + return (value >> 1) ^ -(value & 1) + + @classmethod + def encode(cls, value): + # bring it in line with the java binary repr + value &= 0xffffffffffffffff + v = (value << 1) ^ (value >> 63) + ret = b'' + while (v & 0xffffffffffffff80) != 0: + b = (value & 0x7f) | 0x80 + ret += struct.pack('B', b) + v >>= 7 + ret += struct.pack('B', v) + return ret + 
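# Editor's note: a small illustrative sketch, not part of the upstream
# kafka-python source.  It exercises the base-128 varint coding implemented
# by UnsignedVarInt32 above (7 payload bits per byte, high bit = "more bytes
# follow") and the "length + 1" convention used by the Compact* types defined
# just below.  It assumes the squashed subtree is importable as the `kafka`
# package.
from io import BytesIO
from kafka.protocol.types import CompactString, UnsignedVarInt32

encoded = UnsignedVarInt32.encode(300)
assert encoded == b'\xac\x02'                         # 300 -> 0xAC 0x02
assert UnsignedVarInt32.decode(BytesIO(encoded)) == 300
assert CompactString().encode("abc") == b'\x04abc'    # varint(len + 1) + bytes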
+ +class CompactString(String): + def decode(self, data): + length = UnsignedVarInt32.decode(data) - 1 + if length < 0: + return None + value = data.read(length) + if len(value) != length: + raise ValueError('Buffer underrun decoding string') + return value.decode(self.encoding) + + def encode(self, value): + if value is None: + return UnsignedVarInt32.encode(0) + value = str(value).encode(self.encoding) + return UnsignedVarInt32.encode(len(value) + 1) + value + + +class TaggedFields(AbstractType): + @classmethod + def decode(cls, data): + num_fields = UnsignedVarInt32.decode(data) + ret = {} + if not num_fields: + return ret + prev_tag = -1 + for i in range(num_fields): + tag = UnsignedVarInt32.decode(data) + if tag <= prev_tag: + raise ValueError('Invalid or out-of-order tag {}'.format(tag)) + prev_tag = tag + size = UnsignedVarInt32.decode(data) + val = data.read(size) + ret[tag] = val + return ret + + @classmethod + def encode(cls, value): + ret = UnsignedVarInt32.encode(len(value)) + for k, v in value.items(): + # do we allow for other data types ?? It could get complicated really fast + assert isinstance(v, bytes), 'Value {} is not a byte array'.format(v) + assert isinstance(k, int) and k > 0, 'Key {} is not a positive integer'.format(k) + ret += UnsignedVarInt32.encode(k) + ret += v + return ret + + +class CompactBytes(AbstractType): + @classmethod + def decode(cls, data): + length = UnsignedVarInt32.decode(data) - 1 + if length < 0: + return None + value = data.read(length) + if len(value) != length: + raise ValueError('Buffer underrun decoding Bytes') + return value + + @classmethod + def encode(cls, value): + if value is None: + return UnsignedVarInt32.encode(0) + else: + return UnsignedVarInt32.encode(len(value) + 1) + value + + +class CompactArray(Array): + + def encode(self, items): + if items is None: + return UnsignedVarInt32.encode(0) + return b''.join( + [UnsignedVarInt32.encode(len(items) + 1)] + + [self.array_of.encode(item) for item in items] + ) + + def decode(self, data): + length = UnsignedVarInt32.decode(data) - 1 + if length == -1: + return None + return [self.array_of.decode(data) for _ in range(length)] + diff --git a/record/README b/record/README new file mode 100644 index 00000000..e4454554 --- /dev/null +++ b/record/README @@ -0,0 +1,8 @@ +Module structured mostly based on +kafka/clients/src/main/java/org/apache/kafka/common/record/ module of Java +Client. + +See abc.py for abstract declarations. `ABCRecords` is used as a facade to hide +version differences. `ABCRecordBatch` subclasses will implement actual parsers +for different versions (v0/v1 as LegacyBatch and v2 as DefaultBatch. Names +taken from Java). diff --git a/record/__init__.py b/record/__init__.py new file mode 100644 index 00000000..93936df4 --- /dev/null +++ b/record/__init__.py @@ -0,0 +1,3 @@ +from kafka.record.memory_records import MemoryRecords, MemoryRecordsBuilder + +__all__ = ["MemoryRecords", "MemoryRecordsBuilder"] diff --git a/record/_crc32c.py b/record/_crc32c.py new file mode 100644 index 00000000..9b51ad8a --- /dev/null +++ b/record/_crc32c.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python +# +# Taken from https://cloud.google.com/appengine/docs/standard/python/refdocs/\ +# modules/google/appengine/api/files/crc32c?hl=ru +# +# Copyright 2007 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Implementation of CRC-32C checksumming as in rfc3720 section B.4. +See https://en.wikipedia.org/wiki/Cyclic_redundancy_check for details on CRC-32C +This code is a manual python translation of c code generated by +pycrc 0.7.1 (https://pycrc.org/). Command line used: +'./pycrc.py --model=crc-32c --generate c --algorithm=table-driven' +""" + +import array + +CRC_TABLE = ( + 0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4, + 0xc79a971f, 0x35f1141c, 0x26a1e7e8, 0xd4ca64eb, + 0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b, + 0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24, + 0x105ec76f, 0xe235446c, 0xf165b798, 0x030e349b, + 0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384, + 0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54, + 0x5d1d08bf, 0xaf768bbc, 0xbc267848, 0x4e4dfb4b, + 0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a, + 0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35, + 0xaa64d611, 0x580f5512, 0x4b5fa6e6, 0xb93425e5, + 0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa, + 0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45, + 0xf779deae, 0x05125dad, 0x1642ae59, 0xe4292d5a, + 0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a, + 0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595, + 0x417b1dbc, 0xb3109ebf, 0xa0406d4b, 0x522bee48, + 0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957, + 0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687, + 0x0c38d26c, 0xfe53516f, 0xed03a29b, 0x1f682198, + 0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927, + 0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38, + 0xdbfc821c, 0x2997011f, 0x3ac7f2eb, 0xc8ac71e8, + 0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7, + 0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096, + 0xa65c047d, 0x5437877e, 0x4767748a, 0xb50cf789, + 0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859, + 0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46, + 0x7198540d, 0x83f3d70e, 0x90a324fa, 0x62c8a7f9, + 0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6, + 0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36, + 0x3cdb9bdd, 0xceb018de, 0xdde0eb2a, 0x2f8b6829, + 0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c, + 0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93, + 0x082f63b7, 0xfa44e0b4, 0xe9141340, 0x1b7f9043, + 0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c, + 0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3, + 0x55326b08, 0xa759e80b, 0xb4091bff, 0x466298fc, + 0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c, + 0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033, + 0xa24bb5a6, 0x502036a5, 0x4370c551, 0xb11b4652, + 0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d, + 0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d, + 0xef087a76, 0x1d63f975, 0x0e330a81, 0xfc588982, + 0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d, + 0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622, + 0x38cc2a06, 0xcaa7a905, 0xd9f75af1, 0x2b9cd9f2, + 0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed, + 0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530, + 0x0417b1db, 0xf67c32d8, 0xe52cc12c, 0x1747422f, + 0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff, + 0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0, + 0xd3d3e1ab, 0x21b862a8, 0x32e8915c, 0xc083125f, + 0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540, + 0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90, + 
0x9e902e7b, 0x6cfbad78, 0x7fab5e8c, 0x8dc0dd8f, + 0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee, + 0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1, + 0x69e9f0d5, 0x9b8273d6, 0x88d28022, 0x7ab90321, + 0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e, + 0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81, + 0x34f4f86a, 0xc69f7b69, 0xd5cf889d, 0x27a40b9e, + 0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e, + 0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351, +) + +CRC_INIT = 0 +_MASK = 0xFFFFFFFF + + +def crc_update(crc, data): + """Update CRC-32C checksum with data. + Args: + crc: 32-bit checksum to update as long. + data: byte array, string or iterable over bytes. + Returns: + 32-bit updated CRC-32C as long. + """ + if not isinstance(data, array.array) or data.itemsize != 1: + buf = array.array("B", data) + else: + buf = data + crc = crc ^ _MASK + for b in buf: + table_index = (crc ^ b) & 0xff + crc = (CRC_TABLE[table_index] ^ (crc >> 8)) & _MASK + return crc ^ _MASK + + +def crc_finalize(crc): + """Finalize CRC-32C checksum. + This function should be called as last step of crc calculation. + Args: + crc: 32-bit checksum as long. + Returns: + finalized 32-bit checksum as long + """ + return crc & _MASK + + +def crc(data): + """Compute CRC-32C checksum of the data. + Args: + data: byte array, string or iterable over bytes. + Returns: + 32-bit CRC-32C checksum of data as long. + """ + return crc_finalize(crc_update(CRC_INIT, data)) + + +if __name__ == "__main__": + import sys + # TODO remove the pylint disable once pylint fixes + # https://github.com/PyCQA/pylint/issues/2571 + data = sys.stdin.read() # pylint: disable=assignment-from-no-return + print(hex(crc(data))) diff --git a/record/abc.py b/record/abc.py new file mode 100644 index 00000000..8509e23e --- /dev/null +++ b/record/abc.py @@ -0,0 +1,124 @@ +from __future__ import absolute_import +import abc + + +class ABCRecord(object): + __metaclass__ = abc.ABCMeta + __slots__ = () + + @abc.abstractproperty + def offset(self): + """ Absolute offset of record + """ + + @abc.abstractproperty + def timestamp(self): + """ Epoch milliseconds + """ + + @abc.abstractproperty + def timestamp_type(self): + """ CREATE_TIME(0) or APPEND_TIME(1) + """ + + @abc.abstractproperty + def key(self): + """ Bytes key or None + """ + + @abc.abstractproperty + def value(self): + """ Bytes value or None + """ + + @abc.abstractproperty + def checksum(self): + """ Prior to v2 format CRC was contained in every message. This will + be the checksum for v0 and v1 and None for v2 and above. + """ + + @abc.abstractproperty + def headers(self): + """ If supported by version list of key-value tuples, or empty list if + not supported by format. + """ + + +class ABCRecordBatchBuilder(object): + __metaclass__ = abc.ABCMeta + __slots__ = () + + @abc.abstractmethod + def append(self, offset, timestamp, key, value, headers=None): + """ Writes record to internal buffer. + + Arguments: + offset (int): Relative offset of record, starting from 0 + timestamp (int or None): Timestamp in milliseconds since beginning + of the epoch (midnight Jan 1, 1970 (UTC)). If omitted, will be + set to current time. + key (bytes or None): Key of the record + value (bytes or None): Value of the record + headers (List[Tuple[str, bytes]]): Headers of the record. Header + keys can not be ``None``. + + Returns: + (bytes, int): Checksum of the written record (or None for v2 and + above) and size of the written record. 
+ """ + + @abc.abstractmethod + def size_in_bytes(self, offset, timestamp, key, value, headers): + """ Return the expected size change on buffer (uncompressed) if we add + this message. This will account for varint size changes and give a + reliable size. + """ + + @abc.abstractmethod + def build(self): + """ Close for append, compress if needed, write size and header and + return a ready to send buffer object. + + Return: + bytearray: finished batch, ready to send. + """ + + +class ABCRecordBatch(object): + """ For v2 encapsulates a RecordBatch, for v0/v1 a single (maybe + compressed) message. + """ + __metaclass__ = abc.ABCMeta + __slots__ = () + + @abc.abstractmethod + def __iter__(self): + """ Return iterator over records (ABCRecord instances). Will decompress + if needed. + """ + + +class ABCRecords(object): + __metaclass__ = abc.ABCMeta + __slots__ = () + + @abc.abstractmethod + def __init__(self, buffer): + """ Initialize with bytes-like object conforming to the buffer + interface (ie. bytes, bytearray, memoryview etc.). + """ + + @abc.abstractmethod + def size_in_bytes(self): + """ Returns the size of inner buffer. + """ + + @abc.abstractmethod + def next_batch(self): + """ Return next batch of records (ABCRecordBatch instances). + """ + + @abc.abstractmethod + def has_next(self): + """ True if there are more batches to read, False otherwise. + """ diff --git a/record/default_records.py b/record/default_records.py new file mode 100644 index 00000000..a098c42a --- /dev/null +++ b/record/default_records.py @@ -0,0 +1,630 @@ +# See: +# https://github.com/apache/kafka/blob/trunk/clients/src/main/java/org/\ +# apache/kafka/common/record/DefaultRecordBatch.java +# https://github.com/apache/kafka/blob/trunk/clients/src/main/java/org/\ +# apache/kafka/common/record/DefaultRecord.java + +# RecordBatch and Record implementation for magic 2 and above. +# The schema is given below: + +# RecordBatch => +# BaseOffset => Int64 +# Length => Int32 +# PartitionLeaderEpoch => Int32 +# Magic => Int8 +# CRC => Uint32 +# Attributes => Int16 +# LastOffsetDelta => Int32 // also serves as LastSequenceDelta +# FirstTimestamp => Int64 +# MaxTimestamp => Int64 +# ProducerId => Int64 +# ProducerEpoch => Int16 +# BaseSequence => Int32 +# Records => [Record] + +# Record => +# Length => Varint +# Attributes => Int8 +# TimestampDelta => Varlong +# OffsetDelta => Varint +# Key => Bytes +# Value => Bytes +# Headers => [HeaderKey HeaderValue] +# HeaderKey => String +# HeaderValue => Bytes + +# Note that when compression is enabled (see attributes below), the compressed +# record data is serialized directly following the count of the number of +# records. (ie Records => [Record], but without length bytes) + +# The CRC covers the data from the attributes to the end of the batch (i.e. all +# the bytes that follow the CRC). It is located after the magic byte, which +# means that clients must parse the magic byte before deciding how to interpret +# the bytes between the batch length and the magic byte. The partition leader +# epoch field is not included in the CRC computation to avoid the need to +# recompute the CRC when this field is assigned for every batch that is +# received by the broker. The CRC-32C (Castagnoli) polynomial is used for the +# computation. 
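# Editor's note: an illustrative sketch, not part of the upstream module.
# The fixed-size portion of the RecordBatch header described above maps to a
# 61-byte struct (the name V2_BATCH_HEADER below is ours; the format mirrors
# the HEADER_STRUCT defined further down in this file).  The CRC field itself
# sits at byte offset 17, and the CRC-32C computation covers everything from
# the attributes field (byte offset 21) to the end of the batch.
import struct

V2_BATCH_HEADER = struct.Struct(
    ">q"  # BaseOffset
    "i"   # Length
    "i"   # PartitionLeaderEpoch
    "b"   # Magic
    "I"   # CRC
    "h"   # Attributes
    "i"   # LastOffsetDelta
    "q"   # FirstTimestamp
    "q"   # MaxTimestamp
    "q"   # ProducerId
    "h"   # ProducerEpoch
    "i"   # BaseSequence
    "i"   # Records count
)
assert V2_BATCH_HEADER.size == 61            # fixed v2 batch overhead
assert struct.calcsize(">qiib") == 17        # offset of the CRC field
assert struct.calcsize(">qiibI") == 21       # CRC covers bytes from here on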
+ +# The current RecordBatch attributes are given below: +# +# * Unused (6-15) +# * Control (5) +# * Transactional (4) +# * Timestamp Type (3) +# * Compression Type (0-2) + +import struct +import time +from kafka.record.abc import ABCRecord, ABCRecordBatch, ABCRecordBatchBuilder +from kafka.record.util import ( + decode_varint, encode_varint, calc_crc32c, size_of_varint +) +from kafka.errors import CorruptRecordException, UnsupportedCodecError +from kafka.codec import ( + gzip_encode, snappy_encode, lz4_encode, zstd_encode, + gzip_decode, snappy_decode, lz4_decode, zstd_decode +) +import kafka.codec as codecs + + +class DefaultRecordBase(object): + + __slots__ = () + + HEADER_STRUCT = struct.Struct( + ">q" # BaseOffset => Int64 + "i" # Length => Int32 + "i" # PartitionLeaderEpoch => Int32 + "b" # Magic => Int8 + "I" # CRC => Uint32 + "h" # Attributes => Int16 + "i" # LastOffsetDelta => Int32 // also serves as LastSequenceDelta + "q" # FirstTimestamp => Int64 + "q" # MaxTimestamp => Int64 + "q" # ProducerId => Int64 + "h" # ProducerEpoch => Int16 + "i" # BaseSequence => Int32 + "i" # Records count => Int32 + ) + # Byte offset in HEADER_STRUCT of attributes field. Used to calculate CRC + ATTRIBUTES_OFFSET = struct.calcsize(">qiibI") + CRC_OFFSET = struct.calcsize(">qiib") + AFTER_LEN_OFFSET = struct.calcsize(">qi") + + CODEC_MASK = 0x07 + CODEC_NONE = 0x00 + CODEC_GZIP = 0x01 + CODEC_SNAPPY = 0x02 + CODEC_LZ4 = 0x03 + CODEC_ZSTD = 0x04 + TIMESTAMP_TYPE_MASK = 0x08 + TRANSACTIONAL_MASK = 0x10 + CONTROL_MASK = 0x20 + + LOG_APPEND_TIME = 1 + CREATE_TIME = 0 + + def _assert_has_codec(self, compression_type): + if compression_type == self.CODEC_GZIP: + checker, name = codecs.has_gzip, "gzip" + elif compression_type == self.CODEC_SNAPPY: + checker, name = codecs.has_snappy, "snappy" + elif compression_type == self.CODEC_LZ4: + checker, name = codecs.has_lz4, "lz4" + elif compression_type == self.CODEC_ZSTD: + checker, name = codecs.has_zstd, "zstd" + if not checker(): + raise UnsupportedCodecError( + "Libraries for {} compression codec not found".format(name)) + + +class DefaultRecordBatch(DefaultRecordBase, ABCRecordBatch): + + __slots__ = ("_buffer", "_header_data", "_pos", "_num_records", + "_next_record_index", "_decompressed") + + def __init__(self, buffer): + self._buffer = bytearray(buffer) + self._header_data = self.HEADER_STRUCT.unpack_from(self._buffer) + self._pos = self.HEADER_STRUCT.size + self._num_records = self._header_data[12] + self._next_record_index = 0 + self._decompressed = False + + @property + def base_offset(self): + return self._header_data[0] + + @property + def magic(self): + return self._header_data[3] + + @property + def crc(self): + return self._header_data[4] + + @property + def attributes(self): + return self._header_data[5] + + @property + def last_offset_delta(self): + return self._header_data[6] + + @property + def compression_type(self): + return self.attributes & self.CODEC_MASK + + @property + def timestamp_type(self): + return int(bool(self.attributes & self.TIMESTAMP_TYPE_MASK)) + + @property + def is_transactional(self): + return bool(self.attributes & self.TRANSACTIONAL_MASK) + + @property + def is_control_batch(self): + return bool(self.attributes & self.CONTROL_MASK) + + @property + def first_timestamp(self): + return self._header_data[7] + + @property + def max_timestamp(self): + return self._header_data[8] + + def _maybe_uncompress(self): + if not self._decompressed: + compression_type = self.compression_type + if compression_type != self.CODEC_NONE: 
+ self._assert_has_codec(compression_type) + data = memoryview(self._buffer)[self._pos:] + if compression_type == self.CODEC_GZIP: + uncompressed = gzip_decode(data) + if compression_type == self.CODEC_SNAPPY: + uncompressed = snappy_decode(data.tobytes()) + if compression_type == self.CODEC_LZ4: + uncompressed = lz4_decode(data.tobytes()) + if compression_type == self.CODEC_ZSTD: + uncompressed = zstd_decode(data.tobytes()) + self._buffer = bytearray(uncompressed) + self._pos = 0 + self._decompressed = True + + def _read_msg( + self, + decode_varint=decode_varint): + # Record => + # Length => Varint + # Attributes => Int8 + # TimestampDelta => Varlong + # OffsetDelta => Varint + # Key => Bytes + # Value => Bytes + # Headers => [HeaderKey HeaderValue] + # HeaderKey => String + # HeaderValue => Bytes + + buffer = self._buffer + pos = self._pos + length, pos = decode_varint(buffer, pos) + start_pos = pos + _, pos = decode_varint(buffer, pos) # attrs can be skipped for now + + ts_delta, pos = decode_varint(buffer, pos) + if self.timestamp_type == self.LOG_APPEND_TIME: + timestamp = self.max_timestamp + else: + timestamp = self.first_timestamp + ts_delta + + offset_delta, pos = decode_varint(buffer, pos) + offset = self.base_offset + offset_delta + + key_len, pos = decode_varint(buffer, pos) + if key_len >= 0: + key = bytes(buffer[pos: pos + key_len]) + pos += key_len + else: + key = None + + value_len, pos = decode_varint(buffer, pos) + if value_len >= 0: + value = bytes(buffer[pos: pos + value_len]) + pos += value_len + else: + value = None + + header_count, pos = decode_varint(buffer, pos) + if header_count < 0: + raise CorruptRecordException("Found invalid number of record " + "headers {}".format(header_count)) + headers = [] + while header_count: + # Header key is of type String, that can't be None + h_key_len, pos = decode_varint(buffer, pos) + if h_key_len < 0: + raise CorruptRecordException( + "Invalid negative header key size {}".format(h_key_len)) + h_key = buffer[pos: pos + h_key_len].decode("utf-8") + pos += h_key_len + + # Value is of type NULLABLE_BYTES, so it can be None + h_value_len, pos = decode_varint(buffer, pos) + if h_value_len >= 0: + h_value = bytes(buffer[pos: pos + h_value_len]) + pos += h_value_len + else: + h_value = None + + headers.append((h_key, h_value)) + header_count -= 1 + + # validate whether we have read all header bytes in the current record + if pos - start_pos != length: + raise CorruptRecordException( + "Invalid record size: expected to read {} bytes in record " + "payload, but instead read {}".format(length, pos - start_pos)) + self._pos = pos + + return DefaultRecord( + offset, timestamp, self.timestamp_type, key, value, headers) + + def __iter__(self): + self._maybe_uncompress() + return self + + def __next__(self): + if self._next_record_index >= self._num_records: + if self._pos != len(self._buffer): + raise CorruptRecordException( + "{} unconsumed bytes after all records consumed".format( + len(self._buffer) - self._pos)) + raise StopIteration + try: + msg = self._read_msg() + except (ValueError, IndexError) as err: + raise CorruptRecordException( + "Found invalid record structure: {!r}".format(err)) + else: + self._next_record_index += 1 + return msg + + next = __next__ + + def validate_crc(self): + assert self._decompressed is False, \ + "Validate should be called before iteration" + + crc = self.crc + data_view = memoryview(self._buffer)[self.ATTRIBUTES_OFFSET:] + verify_crc = calc_crc32c(data_view.tobytes()) + return crc == verify_crc + + 
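# Editor's note: an illustrative round-trip sketch, not part of the upstream
# module.  It writes a single uncompressed v2 record with the
# DefaultRecordBatchBuilder defined below and reads it back with the
# DefaultRecordBatch parser above, assuming the squashed subtree is
# importable as the `kafka` package.
from kafka.record.default_records import (
    DefaultRecordBatch, DefaultRecordBatchBuilder
)

builder = DefaultRecordBatchBuilder(
    magic=2, compression_type=0, is_transactional=False,
    producer_id=-1, producer_epoch=-1, base_sequence=-1, batch_size=1024)
meta = builder.append(0, timestamp=None, key=b"key", value=b"value",
                      headers=[("trace", b"abc")])
assert meta is not None                  # None would mean "did not fit"
batch = DefaultRecordBatch(bytes(builder.build()))
assert batch.validate_crc()              # CRC-32C over attributes..end
[record] = list(batch)
assert (record.key, record.value) == (b"key", b"value")
assert record.headers == [("trace", b"abc")]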
+class DefaultRecord(ABCRecord): + + __slots__ = ("_offset", "_timestamp", "_timestamp_type", "_key", "_value", + "_headers") + + def __init__(self, offset, timestamp, timestamp_type, key, value, headers): + self._offset = offset + self._timestamp = timestamp + self._timestamp_type = timestamp_type + self._key = key + self._value = value + self._headers = headers + + @property + def offset(self): + return self._offset + + @property + def timestamp(self): + """ Epoch milliseconds + """ + return self._timestamp + + @property + def timestamp_type(self): + """ CREATE_TIME(0) or APPEND_TIME(1) + """ + return self._timestamp_type + + @property + def key(self): + """ Bytes key or None + """ + return self._key + + @property + def value(self): + """ Bytes value or None + """ + return self._value + + @property + def headers(self): + return self._headers + + @property + def checksum(self): + return None + + def __repr__(self): + return ( + "DefaultRecord(offset={!r}, timestamp={!r}, timestamp_type={!r}," + " key={!r}, value={!r}, headers={!r})".format( + self._offset, self._timestamp, self._timestamp_type, + self._key, self._value, self._headers) + ) + + +class DefaultRecordBatchBuilder(DefaultRecordBase, ABCRecordBatchBuilder): + + # excluding key, value and headers: + # 5 bytes length + 10 bytes timestamp + 5 bytes offset + 1 byte attributes + MAX_RECORD_OVERHEAD = 21 + + __slots__ = ("_magic", "_compression_type", "_batch_size", "_is_transactional", + "_producer_id", "_producer_epoch", "_base_sequence", + "_first_timestamp", "_max_timestamp", "_last_offset", "_num_records", + "_buffer") + + def __init__( + self, magic, compression_type, is_transactional, + producer_id, producer_epoch, base_sequence, batch_size): + assert magic >= 2 + self._magic = magic + self._compression_type = compression_type & self.CODEC_MASK + self._batch_size = batch_size + self._is_transactional = bool(is_transactional) + # KIP-98 fields for EOS + self._producer_id = producer_id + self._producer_epoch = producer_epoch + self._base_sequence = base_sequence + + self._first_timestamp = None + self._max_timestamp = None + self._last_offset = 0 + self._num_records = 0 + + self._buffer = bytearray(self.HEADER_STRUCT.size) + + def _get_attributes(self, include_compression_type=True): + attrs = 0 + if include_compression_type: + attrs |= self._compression_type + # Timestamp Type is set by Broker + if self._is_transactional: + attrs |= self.TRANSACTIONAL_MASK + # Control batches are only created by Broker + return attrs + + def append(self, offset, timestamp, key, value, headers, + # Cache for LOAD_FAST opcodes + encode_varint=encode_varint, size_of_varint=size_of_varint, + get_type=type, type_int=int, time_time=time.time, + byte_like=(bytes, bytearray, memoryview), + bytearray_type=bytearray, len_func=len, zero_len_varint=1 + ): + """ Write message to messageset buffer with MsgVersion 2 + """ + # Check types + if get_type(offset) != type_int: + raise TypeError(offset) + if timestamp is None: + timestamp = type_int(time_time() * 1000) + elif get_type(timestamp) != type_int: + raise TypeError(timestamp) + if not (key is None or get_type(key) in byte_like): + raise TypeError( + "Not supported type for key: {}".format(type(key))) + if not (value is None or get_type(value) in byte_like): + raise TypeError( + "Not supported type for value: {}".format(type(value))) + + # We will always add the first message, so those will be set + if self._first_timestamp is None: + self._first_timestamp = timestamp + self._max_timestamp = timestamp + 
timestamp_delta = 0 + first_message = 1 + else: + timestamp_delta = timestamp - self._first_timestamp + first_message = 0 + + # We can't write record right away to out buffer, we need to + # precompute the length as first value... + message_buffer = bytearray_type(b"\x00") # Attributes + write_byte = message_buffer.append + write = message_buffer.extend + + encode_varint(timestamp_delta, write_byte) + # Base offset is always 0 on Produce + encode_varint(offset, write_byte) + + if key is not None: + encode_varint(len_func(key), write_byte) + write(key) + else: + write_byte(zero_len_varint) + + if value is not None: + encode_varint(len_func(value), write_byte) + write(value) + else: + write_byte(zero_len_varint) + + encode_varint(len_func(headers), write_byte) + + for h_key, h_value in headers: + h_key = h_key.encode("utf-8") + encode_varint(len_func(h_key), write_byte) + write(h_key) + if h_value is not None: + encode_varint(len_func(h_value), write_byte) + write(h_value) + else: + write_byte(zero_len_varint) + + message_len = len_func(message_buffer) + main_buffer = self._buffer + + required_size = message_len + size_of_varint(message_len) + # Check if we can write this message + if (required_size + len_func(main_buffer) > self._batch_size and + not first_message): + return None + + # Those should be updated after the length check + if self._max_timestamp < timestamp: + self._max_timestamp = timestamp + self._num_records += 1 + self._last_offset = offset + + encode_varint(message_len, main_buffer.append) + main_buffer.extend(message_buffer) + + return DefaultRecordMetadata(offset, required_size, timestamp) + + def write_header(self, use_compression_type=True): + batch_len = len(self._buffer) + self.HEADER_STRUCT.pack_into( + self._buffer, 0, + 0, # BaseOffset, set by broker + batch_len - self.AFTER_LEN_OFFSET, # Size from here to end + 0, # PartitionLeaderEpoch, set by broker + self._magic, + 0, # CRC will be set below, as we need a filled buffer for it + self._get_attributes(use_compression_type), + self._last_offset, + self._first_timestamp, + self._max_timestamp, + self._producer_id, + self._producer_epoch, + self._base_sequence, + self._num_records + ) + crc = calc_crc32c(self._buffer[self.ATTRIBUTES_OFFSET:]) + struct.pack_into(">I", self._buffer, self.CRC_OFFSET, crc) + + def _maybe_compress(self): + if self._compression_type != self.CODEC_NONE: + self._assert_has_codec(self._compression_type) + header_size = self.HEADER_STRUCT.size + data = bytes(self._buffer[header_size:]) + if self._compression_type == self.CODEC_GZIP: + compressed = gzip_encode(data) + elif self._compression_type == self.CODEC_SNAPPY: + compressed = snappy_encode(data) + elif self._compression_type == self.CODEC_LZ4: + compressed = lz4_encode(data) + elif self._compression_type == self.CODEC_ZSTD: + compressed = zstd_encode(data) + compressed_size = len(compressed) + if len(data) <= compressed_size: + # We did not get any benefit from compression, lets send + # uncompressed + return False + else: + # Trim bytearray to the required size + needed_size = header_size + compressed_size + del self._buffer[needed_size:] + self._buffer[header_size:needed_size] = compressed + return True + return False + + def build(self): + send_compressed = self._maybe_compress() + self.write_header(send_compressed) + return self._buffer + + def size(self): + """ Return current size of data written to buffer + """ + return len(self._buffer) + + def size_in_bytes(self, offset, timestamp, key, value, headers): + if self._first_timestamp 
is not None: + timestamp_delta = timestamp - self._first_timestamp + else: + timestamp_delta = 0 + size_of_body = ( + 1 + # Attrs + size_of_varint(offset) + + size_of_varint(timestamp_delta) + + self.size_of(key, value, headers) + ) + return size_of_body + size_of_varint(size_of_body) + + @classmethod + def size_of(cls, key, value, headers): + size = 0 + # Key size + if key is None: + size += 1 + else: + key_len = len(key) + size += size_of_varint(key_len) + key_len + # Value size + if value is None: + size += 1 + else: + value_len = len(value) + size += size_of_varint(value_len) + value_len + # Header size + size += size_of_varint(len(headers)) + for h_key, h_value in headers: + h_key_len = len(h_key.encode("utf-8")) + size += size_of_varint(h_key_len) + h_key_len + + if h_value is None: + size += 1 + else: + h_value_len = len(h_value) + size += size_of_varint(h_value_len) + h_value_len + return size + + @classmethod + def estimate_size_in_bytes(cls, key, value, headers): + """ Get the upper bound estimate on the size of record + """ + return ( + cls.HEADER_STRUCT.size + cls.MAX_RECORD_OVERHEAD + + cls.size_of(key, value, headers) + ) + + +class DefaultRecordMetadata(object): + + __slots__ = ("_size", "_timestamp", "_offset") + + def __init__(self, offset, size, timestamp): + self._offset = offset + self._size = size + self._timestamp = timestamp + + @property + def offset(self): + return self._offset + + @property + def crc(self): + return None + + @property + def size(self): + return self._size + + @property + def timestamp(self): + return self._timestamp + + def __repr__(self): + return ( + "DefaultRecordMetadata(offset={!r}, size={!r}, timestamp={!r})" + .format(self._offset, self._size, self._timestamp) + ) diff --git a/record/legacy_records.py b/record/legacy_records.py new file mode 100644 index 00000000..2f8523fc --- /dev/null +++ b/record/legacy_records.py @@ -0,0 +1,548 @@ +# See: +# https://github.com/apache/kafka/blob/trunk/clients/src/main/java/org/\ +# apache/kafka/common/record/LegacyRecord.java + +# Builder and reader implementation for V0 and V1 record versions. As of Kafka +# 0.11.0.0 those were replaced with V2, thus the Legacy naming. + +# The schema is given below (see +# https://kafka.apache.org/protocol#protocol_message_sets for more details): + +# MessageSet => [Offset MessageSize Message] +# Offset => int64 +# MessageSize => int32 + +# v0 +# Message => Crc MagicByte Attributes Key Value +# Crc => int32 +# MagicByte => int8 +# Attributes => int8 +# Key => bytes +# Value => bytes + +# v1 (supported since 0.10.0) +# Message => Crc MagicByte Attributes Key Value +# Crc => int32 +# MagicByte => int8 +# Attributes => int8 +# Timestamp => int64 +# Key => bytes +# Value => bytes + +# The message attribute bits are given below: +# * Unused (4-7) +# * Timestamp Type (3) (added in V1) +# * Compression Type (0-2) + +# Note that when compression is enabled (see attributes above), the whole +# array of MessageSet's is compressed and places into a message as the `value`. +# Only the parent message is marked with `compression` bits in attributes. + +# The CRC covers the data from the Magic byte to the end of the message. 
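# Editor's note: an illustrative round-trip sketch, not part of the upstream
# module.  It builds one uncompressed v1 message with the
# LegacyRecordBatchBuilder defined below and reads it back with
# LegacyRecordBatch, matching the MessageSet layout described above.  It
# assumes the squashed subtree is importable as the `kafka` package.
from kafka.record.legacy_records import (
    LegacyRecordBatch, LegacyRecordBatchBuilder
)

builder = LegacyRecordBatchBuilder(magic=1, compression_type=0,
                                   batch_size=1024)
meta = builder.append(0, timestamp=1234567890000, key=b"key", value=b"value")
assert meta is not None
batch = LegacyRecordBatch(bytes(builder.build()), magic=1)
assert batch.validate_crc()              # CRC-32 from the magic byte onward
[record] = list(batch)
assert (record.offset, record.timestamp) == (0, 1234567890000)
assert (record.key, record.value) == (b"key", b"value")
assert record.checksum == meta.crc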
+ + +import struct +import time + +from kafka.record.abc import ABCRecord, ABCRecordBatch, ABCRecordBatchBuilder +from kafka.record.util import calc_crc32 + +from kafka.codec import ( + gzip_encode, snappy_encode, lz4_encode, lz4_encode_old_kafka, + gzip_decode, snappy_decode, lz4_decode, lz4_decode_old_kafka, +) +import kafka.codec as codecs +from kafka.errors import CorruptRecordException, UnsupportedCodecError + + +class LegacyRecordBase(object): + + __slots__ = () + + HEADER_STRUCT_V0 = struct.Struct( + ">q" # BaseOffset => Int64 + "i" # Length => Int32 + "I" # CRC => Int32 + "b" # Magic => Int8 + "b" # Attributes => Int8 + ) + HEADER_STRUCT_V1 = struct.Struct( + ">q" # BaseOffset => Int64 + "i" # Length => Int32 + "I" # CRC => Int32 + "b" # Magic => Int8 + "b" # Attributes => Int8 + "q" # timestamp => Int64 + ) + + LOG_OVERHEAD = CRC_OFFSET = struct.calcsize( + ">q" # Offset + "i" # Size + ) + MAGIC_OFFSET = LOG_OVERHEAD + struct.calcsize( + ">I" # CRC + ) + # Those are used for fast size calculations + RECORD_OVERHEAD_V0 = struct.calcsize( + ">I" # CRC + "b" # magic + "b" # attributes + "i" # Key length + "i" # Value length + ) + RECORD_OVERHEAD_V1 = struct.calcsize( + ">I" # CRC + "b" # magic + "b" # attributes + "q" # timestamp + "i" # Key length + "i" # Value length + ) + + KEY_OFFSET_V0 = HEADER_STRUCT_V0.size + KEY_OFFSET_V1 = HEADER_STRUCT_V1.size + KEY_LENGTH = VALUE_LENGTH = struct.calcsize(">i") # Bytes length is Int32 + + CODEC_MASK = 0x07 + CODEC_NONE = 0x00 + CODEC_GZIP = 0x01 + CODEC_SNAPPY = 0x02 + CODEC_LZ4 = 0x03 + TIMESTAMP_TYPE_MASK = 0x08 + + LOG_APPEND_TIME = 1 + CREATE_TIME = 0 + + NO_TIMESTAMP = -1 + + def _assert_has_codec(self, compression_type): + if compression_type == self.CODEC_GZIP: + checker, name = codecs.has_gzip, "gzip" + elif compression_type == self.CODEC_SNAPPY: + checker, name = codecs.has_snappy, "snappy" + elif compression_type == self.CODEC_LZ4: + checker, name = codecs.has_lz4, "lz4" + if not checker(): + raise UnsupportedCodecError( + "Libraries for {} compression codec not found".format(name)) + + +class LegacyRecordBatch(ABCRecordBatch, LegacyRecordBase): + + __slots__ = ("_buffer", "_magic", "_offset", "_crc", "_timestamp", + "_attributes", "_decompressed") + + def __init__(self, buffer, magic): + self._buffer = memoryview(buffer) + self._magic = magic + + offset, length, crc, magic_, attrs, timestamp = self._read_header(0) + assert length == len(buffer) - self.LOG_OVERHEAD + assert magic == magic_ + + self._offset = offset + self._crc = crc + self._timestamp = timestamp + self._attributes = attrs + self._decompressed = False + + @property + def timestamp_type(self): + """0 for CreateTime; 1 for LogAppendTime; None if unsupported. 
+ + Value is determined by broker; produced messages should always set to 0 + Requires Kafka >= 0.10 / message version >= 1 + """ + if self._magic == 0: + return None + elif self._attributes & self.TIMESTAMP_TYPE_MASK: + return 1 + else: + return 0 + + @property + def compression_type(self): + return self._attributes & self.CODEC_MASK + + def validate_crc(self): + crc = calc_crc32(self._buffer[self.MAGIC_OFFSET:]) + return self._crc == crc + + def _decompress(self, key_offset): + # Copy of `_read_key_value`, but uses memoryview + pos = key_offset + key_size = struct.unpack_from(">i", self._buffer, pos)[0] + pos += self.KEY_LENGTH + if key_size != -1: + pos += key_size + value_size = struct.unpack_from(">i", self._buffer, pos)[0] + pos += self.VALUE_LENGTH + if value_size == -1: + raise CorruptRecordException("Value of compressed message is None") + else: + data = self._buffer[pos:pos + value_size] + + compression_type = self.compression_type + self._assert_has_codec(compression_type) + if compression_type == self.CODEC_GZIP: + uncompressed = gzip_decode(data) + elif compression_type == self.CODEC_SNAPPY: + uncompressed = snappy_decode(data.tobytes()) + elif compression_type == self.CODEC_LZ4: + if self._magic == 0: + uncompressed = lz4_decode_old_kafka(data.tobytes()) + else: + uncompressed = lz4_decode(data.tobytes()) + return uncompressed + + def _read_header(self, pos): + if self._magic == 0: + offset, length, crc, magic_read, attrs = \ + self.HEADER_STRUCT_V0.unpack_from(self._buffer, pos) + timestamp = None + else: + offset, length, crc, magic_read, attrs, timestamp = \ + self.HEADER_STRUCT_V1.unpack_from(self._buffer, pos) + return offset, length, crc, magic_read, attrs, timestamp + + def _read_all_headers(self): + pos = 0 + msgs = [] + buffer_len = len(self._buffer) + while pos < buffer_len: + header = self._read_header(pos) + msgs.append((header, pos)) + pos += self.LOG_OVERHEAD + header[1] # length + return msgs + + def _read_key_value(self, pos): + key_size = struct.unpack_from(">i", self._buffer, pos)[0] + pos += self.KEY_LENGTH + if key_size == -1: + key = None + else: + key = self._buffer[pos:pos + key_size].tobytes() + pos += key_size + + value_size = struct.unpack_from(">i", self._buffer, pos)[0] + pos += self.VALUE_LENGTH + if value_size == -1: + value = None + else: + value = self._buffer[pos:pos + value_size].tobytes() + return key, value + + def __iter__(self): + if self._magic == 1: + key_offset = self.KEY_OFFSET_V1 + else: + key_offset = self.KEY_OFFSET_V0 + timestamp_type = self.timestamp_type + + if self.compression_type: + # In case we will call iter again + if not self._decompressed: + self._buffer = memoryview(self._decompress(key_offset)) + self._decompressed = True + + # If relative offset is used, we need to decompress the entire + # message first to compute the absolute offset. + headers = self._read_all_headers() + if self._magic > 0: + msg_header, _ = headers[-1] + absolute_base_offset = self._offset - msg_header[0] + else: + absolute_base_offset = -1 + + for header, msg_pos in headers: + offset, _, crc, _, attrs, timestamp = header + # There should only ever be a single layer of compression + assert not attrs & self.CODEC_MASK, ( + 'MessageSet at offset %d appears double-compressed. This ' + 'should not happen -- check your producers!' 
% (offset,)) + + # When magic value is greater than 0, the timestamp + # of a compressed message depends on the + # timestamp type of the wrapper message: + if timestamp_type == self.LOG_APPEND_TIME: + timestamp = self._timestamp + + if absolute_base_offset >= 0: + offset += absolute_base_offset + + key, value = self._read_key_value(msg_pos + key_offset) + yield LegacyRecord( + offset, timestamp, timestamp_type, + key, value, crc) + else: + key, value = self._read_key_value(key_offset) + yield LegacyRecord( + self._offset, self._timestamp, timestamp_type, + key, value, self._crc) + + +class LegacyRecord(ABCRecord): + + __slots__ = ("_offset", "_timestamp", "_timestamp_type", "_key", "_value", + "_crc") + + def __init__(self, offset, timestamp, timestamp_type, key, value, crc): + self._offset = offset + self._timestamp = timestamp + self._timestamp_type = timestamp_type + self._key = key + self._value = value + self._crc = crc + + @property + def offset(self): + return self._offset + + @property + def timestamp(self): + """ Epoch milliseconds + """ + return self._timestamp + + @property + def timestamp_type(self): + """ CREATE_TIME(0) or APPEND_TIME(1) + """ + return self._timestamp_type + + @property + def key(self): + """ Bytes key or None + """ + return self._key + + @property + def value(self): + """ Bytes value or None + """ + return self._value + + @property + def headers(self): + return [] + + @property + def checksum(self): + return self._crc + + def __repr__(self): + return ( + "LegacyRecord(offset={!r}, timestamp={!r}, timestamp_type={!r}," + " key={!r}, value={!r}, crc={!r})".format( + self._offset, self._timestamp, self._timestamp_type, + self._key, self._value, self._crc) + ) + + +class LegacyRecordBatchBuilder(ABCRecordBatchBuilder, LegacyRecordBase): + + __slots__ = ("_magic", "_compression_type", "_batch_size", "_buffer") + + def __init__(self, magic, compression_type, batch_size): + self._magic = magic + self._compression_type = compression_type + self._batch_size = batch_size + self._buffer = bytearray() + + def append(self, offset, timestamp, key, value, headers=None): + """ Append message to batch. + """ + assert not headers, "Headers not supported in v0/v1" + # Check types + if type(offset) != int: + raise TypeError(offset) + if self._magic == 0: + timestamp = self.NO_TIMESTAMP + elif timestamp is None: + timestamp = int(time.time() * 1000) + elif type(timestamp) != int: + raise TypeError( + "`timestamp` should be int, but {} provided".format( + type(timestamp))) + if not (key is None or + isinstance(key, (bytes, bytearray, memoryview))): + raise TypeError( + "Not supported type for key: {}".format(type(key))) + if not (value is None or + isinstance(value, (bytes, bytearray, memoryview))): + raise TypeError( + "Not supported type for value: {}".format(type(value))) + + # Check if we have room for another message + pos = len(self._buffer) + size = self.size_in_bytes(offset, timestamp, key, value) + # We always allow at least one record to be appended + if offset != 0 and pos + size >= self._batch_size: + return None + + # Allocate proper buffer length + self._buffer.extend(bytearray(size)) + + # Encode message + crc = self._encode_msg(pos, offset, timestamp, key, value) + + return LegacyRecordMetadata(offset, crc, size, timestamp) + + def _encode_msg(self, start_pos, offset, timestamp, key, value, + attributes=0): + """ Encode msg data into the `msg_buffer`, which should be allocated + to at least the size of this message. 
+ """ + magic = self._magic + buf = self._buffer + pos = start_pos + + # Write key and value + pos += self.KEY_OFFSET_V0 if magic == 0 else self.KEY_OFFSET_V1 + + if key is None: + struct.pack_into(">i", buf, pos, -1) + pos += self.KEY_LENGTH + else: + key_size = len(key) + struct.pack_into(">i", buf, pos, key_size) + pos += self.KEY_LENGTH + buf[pos: pos + key_size] = key + pos += key_size + + if value is None: + struct.pack_into(">i", buf, pos, -1) + pos += self.VALUE_LENGTH + else: + value_size = len(value) + struct.pack_into(">i", buf, pos, value_size) + pos += self.VALUE_LENGTH + buf[pos: pos + value_size] = value + pos += value_size + length = (pos - start_pos) - self.LOG_OVERHEAD + + # Write msg header. Note, that Crc will be updated later + if magic == 0: + self.HEADER_STRUCT_V0.pack_into( + buf, start_pos, + offset, length, 0, magic, attributes) + else: + self.HEADER_STRUCT_V1.pack_into( + buf, start_pos, + offset, length, 0, magic, attributes, timestamp) + + # Calculate CRC for msg + crc_data = memoryview(buf)[start_pos + self.MAGIC_OFFSET:] + crc = calc_crc32(crc_data) + struct.pack_into(">I", buf, start_pos + self.CRC_OFFSET, crc) + return crc + + def _maybe_compress(self): + if self._compression_type: + self._assert_has_codec(self._compression_type) + data = bytes(self._buffer) + if self._compression_type == self.CODEC_GZIP: + compressed = gzip_encode(data) + elif self._compression_type == self.CODEC_SNAPPY: + compressed = snappy_encode(data) + elif self._compression_type == self.CODEC_LZ4: + if self._magic == 0: + compressed = lz4_encode_old_kafka(data) + else: + compressed = lz4_encode(data) + size = self.size_in_bytes( + 0, timestamp=0, key=None, value=compressed) + # We will try to reuse the same buffer if we have enough space + if size > len(self._buffer): + self._buffer = bytearray(size) + else: + del self._buffer[size:] + self._encode_msg( + start_pos=0, + offset=0, timestamp=0, key=None, value=compressed, + attributes=self._compression_type) + return True + return False + + def build(self): + """Compress batch to be ready for send""" + self._maybe_compress() + return self._buffer + + def size(self): + """ Return current size of data written to buffer + """ + return len(self._buffer) + + # Size calculations. Just copied Java's implementation + + def size_in_bytes(self, offset, timestamp, key, value, headers=None): + """ Actual size of message to add + """ + assert not headers, "Headers not supported in v0/v1" + magic = self._magic + return self.LOG_OVERHEAD + self.record_size(magic, key, value) + + @classmethod + def record_size(cls, magic, key, value): + message_size = cls.record_overhead(magic) + if key is not None: + message_size += len(key) + if value is not None: + message_size += len(value) + return message_size + + @classmethod + def record_overhead(cls, magic): + assert magic in [0, 1], "Not supported magic" + if magic == 0: + return cls.RECORD_OVERHEAD_V0 + else: + return cls.RECORD_OVERHEAD_V1 + + @classmethod + def estimate_size_in_bytes(cls, magic, compression_type, key, value): + """ Upper bound estimate of record size. 
+ """ + assert magic in [0, 1], "Not supported magic" + # In case of compression we may need another overhead for inner msg + if compression_type: + return ( + cls.LOG_OVERHEAD + cls.record_overhead(magic) + + cls.record_size(magic, key, value) + ) + return cls.LOG_OVERHEAD + cls.record_size(magic, key, value) + + +class LegacyRecordMetadata(object): + + __slots__ = ("_crc", "_size", "_timestamp", "_offset") + + def __init__(self, offset, crc, size, timestamp): + self._offset = offset + self._crc = crc + self._size = size + self._timestamp = timestamp + + @property + def offset(self): + return self._offset + + @property + def crc(self): + return self._crc + + @property + def size(self): + return self._size + + @property + def timestamp(self): + return self._timestamp + + def __repr__(self): + return ( + "LegacyRecordMetadata(offset={!r}, crc={!r}, size={!r}," + " timestamp={!r})".format( + self._offset, self._crc, self._size, self._timestamp) + ) diff --git a/record/memory_records.py b/record/memory_records.py new file mode 100644 index 00000000..fc2ef2d6 --- /dev/null +++ b/record/memory_records.py @@ -0,0 +1,187 @@ +# This class takes advantage of the fact that all formats v0, v1 and v2 of +# messages storage has the same byte offsets for Length and Magic fields. +# Lets look closely at what leading bytes all versions have: +# +# V0 and V1 (Offset is MessageSet part, other bytes are Message ones): +# Offset => Int64 +# BytesLength => Int32 +# CRC => Int32 +# Magic => Int8 +# ... +# +# V2: +# BaseOffset => Int64 +# Length => Int32 +# PartitionLeaderEpoch => Int32 +# Magic => Int8 +# ... +# +# So we can iterate over batches just by knowing offsets of Length. Magic is +# used to construct the correct class for Batch itself. +from __future__ import division + +import struct + +from kafka.errors import CorruptRecordException +from kafka.record.abc import ABCRecords +from kafka.record.legacy_records import LegacyRecordBatch, LegacyRecordBatchBuilder +from kafka.record.default_records import DefaultRecordBatch, DefaultRecordBatchBuilder + + +class MemoryRecords(ABCRecords): + + LENGTH_OFFSET = struct.calcsize(">q") + LOG_OVERHEAD = struct.calcsize(">qi") + MAGIC_OFFSET = struct.calcsize(">qii") + + # Minimum space requirements for Record V0 + MIN_SLICE = LOG_OVERHEAD + LegacyRecordBatch.RECORD_OVERHEAD_V0 + + __slots__ = ("_buffer", "_pos", "_next_slice", "_remaining_bytes") + + def __init__(self, bytes_data): + self._buffer = bytes_data + self._pos = 0 + # We keep one slice ahead so `has_next` will return very fast + self._next_slice = None + self._remaining_bytes = None + self._cache_next() + + def size_in_bytes(self): + return len(self._buffer) + + def valid_bytes(self): + # We need to read the whole buffer to get the valid_bytes. + # NOTE: in Fetcher we do the call after iteration, so should be fast + if self._remaining_bytes is None: + next_slice = self._next_slice + pos = self._pos + while self._remaining_bytes is None: + self._cache_next() + # Reset previous iterator position + self._next_slice = next_slice + self._pos = pos + return len(self._buffer) - self._remaining_bytes + + # NOTE: we cache offsets here as kwargs for a bit more speed, as cPython + # will use LOAD_FAST opcode in this case + def _cache_next(self, len_offset=LENGTH_OFFSET, log_overhead=LOG_OVERHEAD): + buffer = self._buffer + buffer_len = len(buffer) + pos = self._pos + remaining = buffer_len - pos + if remaining < log_overhead: + # Will be re-checked in Fetcher for remaining bytes. 
+ self._remaining_bytes = remaining + self._next_slice = None + return + + length, = struct.unpack_from( + ">i", buffer, pos + len_offset) + + slice_end = pos + log_overhead + length + if slice_end > buffer_len: + # Will be re-checked in Fetcher for remaining bytes + self._remaining_bytes = remaining + self._next_slice = None + return + + self._next_slice = memoryview(buffer)[pos: slice_end] + self._pos = slice_end + + def has_next(self): + return self._next_slice is not None + + # NOTE: same cache for LOAD_FAST as above + def next_batch(self, _min_slice=MIN_SLICE, + _magic_offset=MAGIC_OFFSET): + next_slice = self._next_slice + if next_slice is None: + return None + if len(next_slice) < _min_slice: + raise CorruptRecordException( + "Record size is less than the minimum record overhead " + "({})".format(_min_slice - self.LOG_OVERHEAD)) + self._cache_next() + magic, = struct.unpack_from(">b", next_slice, _magic_offset) + if magic <= 1: + return LegacyRecordBatch(next_slice, magic) + else: + return DefaultRecordBatch(next_slice) + + +class MemoryRecordsBuilder(object): + + __slots__ = ("_builder", "_batch_size", "_buffer", "_next_offset", "_closed", + "_bytes_written") + + def __init__(self, magic, compression_type, batch_size): + assert magic in [0, 1, 2], "Not supported magic" + assert compression_type in [0, 1, 2, 3, 4], "Not valid compression type" + if magic >= 2: + self._builder = DefaultRecordBatchBuilder( + magic=magic, compression_type=compression_type, + is_transactional=False, producer_id=-1, producer_epoch=-1, + base_sequence=-1, batch_size=batch_size) + else: + self._builder = LegacyRecordBatchBuilder( + magic=magic, compression_type=compression_type, + batch_size=batch_size) + self._batch_size = batch_size + self._buffer = None + + self._next_offset = 0 + self._closed = False + self._bytes_written = 0 + + def append(self, timestamp, key, value, headers=[]): + """ Append a message to the buffer. + + Returns: RecordMetadata or None if unable to append + """ + if self._closed: + return None + + offset = self._next_offset + metadata = self._builder.append(offset, timestamp, key, value, headers) + # Return of None means there's no space to add a new message + if metadata is None: + return None + + self._next_offset += 1 + return metadata + + def close(self): + # This method may be called multiple times on the same batch + # i.e., on retries + # we need to make sure we only close it out once + # otherwise compressed messages may be double-compressed + # see Issue 718 + if not self._closed: + self._bytes_written = self._builder.size() + self._buffer = bytes(self._builder.build()) + self._builder = None + self._closed = True + + def size_in_bytes(self): + if not self._closed: + return self._builder.size() + else: + return len(self._buffer) + + def compression_rate(self): + assert self._closed + return self.size_in_bytes() / self._bytes_written + + def is_full(self): + if self._closed: + return True + else: + return self._builder.size() >= self._batch_size + + def next_offset(self): + return self._next_offset + + def buffer(self): + assert self._closed + return self._buffer diff --git a/record/util.py b/record/util.py new file mode 100644 index 00000000..3b712005 --- /dev/null +++ b/record/util.py @@ -0,0 +1,135 @@ +import binascii + +from kafka.record._crc32c import crc as crc32c_py +try: + from crc32c import crc32c as crc32c_c +except ImportError: + crc32c_c = None + + +def encode_varint(value, write): + """ Encode an integer to a varint presentation. 
See + https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints + on how those can be produced. + + Arguments: + value (int): Value to encode + write (function): Called per byte that needs to be writen + + Returns: + int: Number of bytes written + """ + value = (value << 1) ^ (value >> 63) + + if value <= 0x7f: # 1 byte + write(value) + return 1 + if value <= 0x3fff: # 2 bytes + write(0x80 | (value & 0x7f)) + write(value >> 7) + return 2 + if value <= 0x1fffff: # 3 bytes + write(0x80 | (value & 0x7f)) + write(0x80 | ((value >> 7) & 0x7f)) + write(value >> 14) + return 3 + if value <= 0xfffffff: # 4 bytes + write(0x80 | (value & 0x7f)) + write(0x80 | ((value >> 7) & 0x7f)) + write(0x80 | ((value >> 14) & 0x7f)) + write(value >> 21) + return 4 + if value <= 0x7ffffffff: # 5 bytes + write(0x80 | (value & 0x7f)) + write(0x80 | ((value >> 7) & 0x7f)) + write(0x80 | ((value >> 14) & 0x7f)) + write(0x80 | ((value >> 21) & 0x7f)) + write(value >> 28) + return 5 + else: + # Return to general algorithm + bits = value & 0x7f + value >>= 7 + i = 0 + while value: + write(0x80 | bits) + bits = value & 0x7f + value >>= 7 + i += 1 + write(bits) + return i + + +def size_of_varint(value): + """ Number of bytes needed to encode an integer in variable-length format. + """ + value = (value << 1) ^ (value >> 63) + if value <= 0x7f: + return 1 + if value <= 0x3fff: + return 2 + if value <= 0x1fffff: + return 3 + if value <= 0xfffffff: + return 4 + if value <= 0x7ffffffff: + return 5 + if value <= 0x3ffffffffff: + return 6 + if value <= 0x1ffffffffffff: + return 7 + if value <= 0xffffffffffffff: + return 8 + if value <= 0x7fffffffffffffff: + return 9 + return 10 + + +def decode_varint(buffer, pos=0): + """ Decode an integer from a varint presentation. See + https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints + on how those can be produced. + + Arguments: + buffer (bytearray): buffer to read from. 
+ pos (int): optional position to read from + + Returns: + (int, int): Decoded int value and next read position + """ + result = buffer[pos] + if not (result & 0x81): + return (result >> 1), pos + 1 + if not (result & 0x80): + return (result >> 1) ^ (~0), pos + 1 + + result &= 0x7f + pos += 1 + shift = 7 + while 1: + b = buffer[pos] + result |= ((b & 0x7f) << shift) + pos += 1 + if not (b & 0x80): + return ((result >> 1) ^ -(result & 1), pos) + shift += 7 + if shift >= 64: + raise ValueError("Out of int64 range") + + +_crc32c = crc32c_py +if crc32c_c is not None: + _crc32c = crc32c_c + + +def calc_crc32c(memview, _crc32c=_crc32c): + """ Calculate CRC-32C (Castagnoli) checksum over a memoryview of data + """ + return _crc32c(memview) + + +def calc_crc32(memview): + """ Calculate simple CRC-32 checksum over a memoryview of data + """ + crc = binascii.crc32(memview) & 0xffffffff + return crc diff --git a/scram.py b/scram.py new file mode 100644 index 00000000..7f003750 --- /dev/null +++ b/scram.py @@ -0,0 +1,81 @@ +from __future__ import absolute_import + +import base64 +import hashlib +import hmac +import uuid + +from kafka.vendor import six + + +if six.PY2: + def xor_bytes(left, right): + return bytearray(ord(lb) ^ ord(rb) for lb, rb in zip(left, right)) +else: + def xor_bytes(left, right): + return bytes(lb ^ rb for lb, rb in zip(left, right)) + + +class ScramClient: + MECHANISMS = { + 'SCRAM-SHA-256': hashlib.sha256, + 'SCRAM-SHA-512': hashlib.sha512 + } + + def __init__(self, user, password, mechanism): + self.nonce = str(uuid.uuid4()).replace('-', '') + self.auth_message = '' + self.salted_password = None + self.user = user + self.password = password.encode('utf-8') + self.hashfunc = self.MECHANISMS[mechanism] + self.hashname = ''.join(mechanism.lower().split('-')[1:3]) + self.stored_key = None + self.client_key = None + self.client_signature = None + self.client_proof = None + self.server_key = None + self.server_signature = None + + def first_message(self): + client_first_bare = 'n={},r={}'.format(self.user, self.nonce) + self.auth_message += client_first_bare + return 'n,,' + client_first_bare + + def process_server_first_message(self, server_first_message): + self.auth_message += ',' + server_first_message + params = dict(pair.split('=', 1) for pair in server_first_message.split(',')) + server_nonce = params['r'] + if not server_nonce.startswith(self.nonce): + raise ValueError("Server nonce, did not start with client nonce!") + self.nonce = server_nonce + self.auth_message += ',c=biws,r=' + self.nonce + + salt = base64.b64decode(params['s'].encode('utf-8')) + iterations = int(params['i']) + self.create_salted_password(salt, iterations) + + self.client_key = self.hmac(self.salted_password, b'Client Key') + self.stored_key = self.hashfunc(self.client_key).digest() + self.client_signature = self.hmac(self.stored_key, self.auth_message.encode('utf-8')) + self.client_proof = xor_bytes(self.client_key, self.client_signature) + self.server_key = self.hmac(self.salted_password, b'Server Key') + self.server_signature = self.hmac(self.server_key, self.auth_message.encode('utf-8')) + + def hmac(self, key, msg): + return hmac.new(key, msg, digestmod=self.hashfunc).digest() + + def create_salted_password(self, salt, iterations): + self.salted_password = hashlib.pbkdf2_hmac( + self.hashname, self.password, salt, iterations + ) + + def final_message(self): + return 'c=biws,r={},p={}'.format(self.nonce, base64.b64encode(self.client_proof).decode('utf-8')) + + def process_server_final_message(self, 
server_final_message): + params = dict(pair.split('=', 1) for pair in server_final_message.split(',')) + if self.server_signature != base64.b64decode(params['v'].encode('utf-8')): + raise ValueError("Server sent wrong signature!") + + diff --git a/serializer/__init__.py b/serializer/__init__.py new file mode 100644 index 00000000..90cd93ab --- /dev/null +++ b/serializer/__init__.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import + +from kafka.serializer.abstract import Serializer, Deserializer diff --git a/serializer/abstract.py b/serializer/abstract.py new file mode 100644 index 00000000..18ad8d69 --- /dev/null +++ b/serializer/abstract.py @@ -0,0 +1,31 @@ +from __future__ import absolute_import + +import abc + + +class Serializer(object): + __meta__ = abc.ABCMeta + + def __init__(self, **config): + pass + + @abc.abstractmethod + def serialize(self, topic, value): + pass + + def close(self): + pass + + +class Deserializer(object): + __meta__ = abc.ABCMeta + + def __init__(self, **config): + pass + + @abc.abstractmethod + def deserialize(self, topic, bytes_): + pass + + def close(self): + pass diff --git a/structs.py b/structs.py new file mode 100644 index 00000000..bcb02367 --- /dev/null +++ b/structs.py @@ -0,0 +1,87 @@ +""" Other useful structs """ +from __future__ import absolute_import + +from collections import namedtuple + + +"""A topic and partition tuple + +Keyword Arguments: + topic (str): A topic name + partition (int): A partition id +""" +TopicPartition = namedtuple("TopicPartition", + ["topic", "partition"]) + + +"""A Kafka broker metadata used by admin tools. + +Keyword Arguments: + nodeID (int): The Kafka broker id. + host (str): The Kafka broker hostname. + port (int): The Kafka broker port. + rack (str): The rack of the broker, which is used to in rack aware + partition assignment for fault tolerance. + Examples: `RACK1`, `us-east-1d`. Default: None +""" +BrokerMetadata = namedtuple("BrokerMetadata", + ["nodeId", "host", "port", "rack"]) + + +"""A topic partition metadata describing the state in the MetadataResponse. + +Keyword Arguments: + topic (str): The topic name of the partition this metadata relates to. + partition (int): The id of the partition this metadata relates to. + leader (int): The id of the broker that is the leader for the partition. + replicas (List[int]): The ids of all brokers that contain replicas of the + partition. + isr (List[int]): The ids of all brokers that contain in-sync replicas of + the partition. + error (KafkaError): A KafkaError object associated with the request for + this partition metadata. +""" +PartitionMetadata = namedtuple("PartitionMetadata", + ["topic", "partition", "leader", "replicas", "isr", "error"]) + + +"""The Kafka offset commit API + +The Kafka offset commit API allows users to provide additional metadata +(in the form of a string) when an offset is committed. This can be useful +(for example) to store information about which node made the commit, +what time the commit was made, etc. 
+ +Keyword Arguments: + offset (int): The offset to be committed + metadata (str): Non-null metadata +""" +OffsetAndMetadata = namedtuple("OffsetAndMetadata", + # TODO add leaderEpoch: OffsetAndMetadata(offset, leaderEpoch, metadata) + ["offset", "metadata"]) + + +"""An offset and timestamp tuple + +Keyword Arguments: + offset (int): An offset + timestamp (int): The timestamp associated to the offset +""" +OffsetAndTimestamp = namedtuple("OffsetAndTimestamp", + ["offset", "timestamp"]) + +MemberInformation = namedtuple("MemberInformation", + ["member_id", "client_id", "client_host", "member_metadata", "member_assignment"]) + +GroupInformation = namedtuple("GroupInformation", + ["error_code", "group", "state", "protocol_type", "protocol", "members", "authorized_operations"]) + +"""Define retry policy for async producer + +Keyword Arguments: + Limit (int): Number of retries. limit >= 0, 0 means no retries + backoff_ms (int): Milliseconds to backoff. + retry_on_timeouts: +""" +RetryOptions = namedtuple("RetryOptions", + ["limit", "backoff_ms", "retry_on_timeouts"]) diff --git a/util.py b/util.py new file mode 100644 index 00000000..e31d9930 --- /dev/null +++ b/util.py @@ -0,0 +1,66 @@ +from __future__ import absolute_import + +import binascii +import weakref + +from kafka.vendor import six + + +if six.PY3: + MAX_INT = 2 ** 31 + TO_SIGNED = 2 ** 32 + + def crc32(data): + crc = binascii.crc32(data) + # py2 and py3 behave a little differently + # CRC is encoded as a signed int in kafka protocol + # so we'll convert the py3 unsigned result to signed + if crc >= MAX_INT: + crc -= TO_SIGNED + return crc +else: + from binascii import crc32 + + +class WeakMethod(object): + """ + Callable that weakly references a method and the object it is bound to. It + is based on https://stackoverflow.com/a/24287465. + + Arguments: + + object_dot_method: A bound instance method (i.e. 'object.method'). + """ + def __init__(self, object_dot_method): + try: + self.target = weakref.ref(object_dot_method.__self__) + except AttributeError: + self.target = weakref.ref(object_dot_method.im_self) + self._target_id = id(self.target()) + try: + self.method = weakref.ref(object_dot_method.__func__) + except AttributeError: + self.method = weakref.ref(object_dot_method.im_func) + self._method_id = id(self.method()) + + def __call__(self, *args, **kwargs): + """ + Calls the method on target with args and kwargs. 
+ """ + return self.method()(self.target(), *args, **kwargs) + + def __hash__(self): + return hash(self.target) ^ hash(self.method) + + def __eq__(self, other): + if not isinstance(other, WeakMethod): + return False + return self._target_id == other._target_id and self._method_id == other._method_id + + +class Dict(dict): + """Utility class to support passing weakrefs to dicts + + See: https://docs.python.org/2/library/weakref.html + """ + pass diff --git a/vendor/__init__.py b/vendor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vendor/enum34.py b/vendor/enum34.py new file mode 100644 index 00000000..5f64bd2d --- /dev/null +++ b/vendor/enum34.py @@ -0,0 +1,841 @@ +# pylint: skip-file +# vendored from: +# https://bitbucket.org/stoneleaf/enum34/src/58c4cd7174ca35f164304c8a6f0a4d47b779c2a7/enum/__init__.py?at=1.1.6 + +"""Python Enumerations""" + +import sys as _sys + +__all__ = ['Enum', 'IntEnum', 'unique'] + +version = 1, 1, 6 + +pyver = float('%s.%s' % _sys.version_info[:2]) + +try: + any +except NameError: + def any(iterable): + for element in iterable: + if element: + return True + return False + +try: + from collections import OrderedDict +except ImportError: + OrderedDict = None + +try: + basestring +except NameError: + # In Python 2 basestring is the ancestor of both str and unicode + # in Python 3 it's just str, but was missing in 3.1 + basestring = str + +try: + unicode +except NameError: + # In Python 3 unicode no longer exists (it's just str) + unicode = str + +class _RouteClassAttributeToGetattr(object): + """Route attribute access on a class to __getattr__. + + This is a descriptor, used to define attributes that act differently when + accessed through an instance and through a class. Instance access remains + normal, but access to an attribute through a class will be routed to the + class's __getattr__ method; this is done by raising AttributeError. + + """ + def __init__(self, fget=None): + self.fget = fget + + def __get__(self, instance, ownerclass=None): + if instance is None: + raise AttributeError() + return self.fget(instance) + + def __set__(self, instance, value): + raise AttributeError("can't set attribute") + + def __delete__(self, instance): + raise AttributeError("can't delete attribute") + + +def _is_descriptor(obj): + """Returns True if obj is a descriptor, False otherwise.""" + return ( + hasattr(obj, '__get__') or + hasattr(obj, '__set__') or + hasattr(obj, '__delete__')) + + +def _is_dunder(name): + """Returns True if a __dunder__ name, False otherwise.""" + return (name[:2] == name[-2:] == '__' and + name[2:3] != '_' and + name[-3:-2] != '_' and + len(name) > 4) + + +def _is_sunder(name): + """Returns True if a _sunder_ name, False otherwise.""" + return (name[0] == name[-1] == '_' and + name[1:2] != '_' and + name[-2:-1] != '_' and + len(name) > 2) + + +def _make_class_unpicklable(cls): + """Make the given class un-picklable.""" + def _break_on_call_reduce(self, protocol=None): + raise TypeError('%r cannot be pickled' % self) + cls.__reduce_ex__ = _break_on_call_reduce + cls.__module__ = '' + + +class _EnumDict(dict): + """Track enum member order and ensure member names are not reused. + + EnumMeta will use the names found in self._member_names as the + enumeration member names. + + """ + def __init__(self): + super(_EnumDict, self).__init__() + self._member_names = [] + + def __setitem__(self, key, value): + """Changes anything not dundered or not a descriptor. 
+ + If a descriptor is added with the same name as an enum member, the name + is removed from _member_names (this may leave a hole in the numerical + sequence of values). + + If an enum member name is used twice, an error is raised; duplicate + values are not checked for. + + Single underscore (sunder) names are reserved. + + Note: in 3.x __order__ is simply discarded as a not necessary piece + leftover from 2.x + + """ + if pyver >= 3.0 and key in ('_order_', '__order__'): + return + elif key == '__order__': + key = '_order_' + if _is_sunder(key): + if key != '_order_': + raise ValueError('_names_ are reserved for future Enum use') + elif _is_dunder(key): + pass + elif key in self._member_names: + # descriptor overwriting an enum? + raise TypeError('Attempted to reuse key: %r' % key) + elif not _is_descriptor(value): + if key in self: + # enum overwriting a descriptor? + raise TypeError('Key already defined as: %r' % self[key]) + self._member_names.append(key) + super(_EnumDict, self).__setitem__(key, value) + + +# Dummy value for Enum as EnumMeta explicity checks for it, but of course until +# EnumMeta finishes running the first time the Enum class doesn't exist. This +# is also why there are checks in EnumMeta like `if Enum is not None` +Enum = None + + +class EnumMeta(type): + """Metaclass for Enum""" + @classmethod + def __prepare__(metacls, cls, bases): + return _EnumDict() + + def __new__(metacls, cls, bases, classdict): + # an Enum class is final once enumeration items have been defined; it + # cannot be mixed with other types (int, float, etc.) if it has an + # inherited __new__ unless a new __new__ is defined (or the resulting + # class will fail). + if type(classdict) is dict: + original_dict = classdict + classdict = _EnumDict() + for k, v in original_dict.items(): + classdict[k] = v + + member_type, first_enum = metacls._get_mixins_(bases) + __new__, save_new, use_args = metacls._find_new_(classdict, member_type, + first_enum) + # save enum items into separate mapping so they don't get baked into + # the new class + members = dict((k, classdict[k]) for k in classdict._member_names) + for name in classdict._member_names: + del classdict[name] + + # py2 support for definition order + _order_ = classdict.get('_order_') + if _order_ is None: + if pyver < 3.0: + try: + _order_ = [name for (name, value) in sorted(members.items(), key=lambda item: item[1])] + except TypeError: + _order_ = [name for name in sorted(members.keys())] + else: + _order_ = classdict._member_names + else: + del classdict['_order_'] + if pyver < 3.0: + _order_ = _order_.replace(',', ' ').split() + aliases = [name for name in members if name not in _order_] + _order_ += aliases + + # check for illegal enum names (any others?) + invalid_names = set(members) & set(['mro']) + if invalid_names: + raise ValueError('Invalid enum member name(s): %s' % ( + ', '.join(invalid_names), )) + + # save attributes from super classes so we know if we can take + # the shortcut of storing members in the class dict + base_attributes = set([a for b in bases for a in b.__dict__]) + # create our new Enum type + enum_class = super(EnumMeta, metacls).__new__(metacls, cls, bases, classdict) + enum_class._member_names_ = [] # names in random order + if OrderedDict is not None: + enum_class._member_map_ = OrderedDict() + else: + enum_class._member_map_ = {} # name->value map + enum_class._member_type_ = member_type + + # Reverse value->name map for hashable values. 
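+ # (unhashable values cannot be stored in this map; Enum.__new__ below
+ # falls back to a linear scan of _member_map_ for them)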
+ enum_class._value2member_map_ = {} + + # instantiate them, checking for duplicates as we go + # we instantiate first instead of checking for duplicates first in case + # a custom __new__ is doing something funky with the values -- such as + # auto-numbering ;) + if __new__ is None: + __new__ = enum_class.__new__ + for member_name in _order_: + value = members[member_name] + if not isinstance(value, tuple): + args = (value, ) + else: + args = value + if member_type is tuple: # special case for tuple enums + args = (args, ) # wrap it one more time + if not use_args or not args: + enum_member = __new__(enum_class) + if not hasattr(enum_member, '_value_'): + enum_member._value_ = value + else: + enum_member = __new__(enum_class, *args) + if not hasattr(enum_member, '_value_'): + enum_member._value_ = member_type(*args) + value = enum_member._value_ + enum_member._name_ = member_name + enum_member.__objclass__ = enum_class + enum_member.__init__(*args) + # If another member with the same value was already defined, the + # new member becomes an alias to the existing one. + for name, canonical_member in enum_class._member_map_.items(): + if canonical_member.value == enum_member._value_: + enum_member = canonical_member + break + else: + # Aliases don't appear in member names (only in __members__). + enum_class._member_names_.append(member_name) + # performance boost for any member that would not shadow + # a DynamicClassAttribute (aka _RouteClassAttributeToGetattr) + if member_name not in base_attributes: + setattr(enum_class, member_name, enum_member) + # now add to _member_map_ + enum_class._member_map_[member_name] = enum_member + try: + # This may fail if value is not hashable. We can't add the value + # to the map, and by-value lookups for this value will be + # linear. + enum_class._value2member_map_[value] = enum_member + except TypeError: + pass + + + # If a custom type is mixed into the Enum, and it does not know how + # to pickle itself, pickle.dumps will succeed but pickle.loads will + # fail. Rather than have the error show up later and possibly far + # from the source, sabotage the pickle protocol for this class so + # that pickle.dumps also fails. + # + # However, if the new class implements its own __reduce_ex__, do not + # sabotage -- it's on them to make sure it works correctly. We use + # __reduce_ex__ instead of any of the others as it is preferred by + # pickle over __reduce__, and it handles all pickle protocols. 
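+ # Remember whether pickling was sabotaged so the __reduce_ex__ fix-up
+ # below leaves the sabotaged method in place.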
+ unpicklable = False + if '__reduce_ex__' not in classdict: + if member_type is not object: + methods = ('__getnewargs_ex__', '__getnewargs__', + '__reduce_ex__', '__reduce__') + if not any(m in member_type.__dict__ for m in methods): + _make_class_unpicklable(enum_class) + unpicklable = True + + + # double check that repr and friends are not the mixin's or various + # things break (such as pickle) + for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'): + class_method = getattr(enum_class, name) + obj_method = getattr(member_type, name, None) + enum_method = getattr(first_enum, name, None) + if name not in classdict and class_method is not enum_method: + if name == '__reduce_ex__' and unpicklable: + continue + setattr(enum_class, name, enum_method) + + # method resolution and int's are not playing nice + # Python's less than 2.6 use __cmp__ + + if pyver < 2.6: + + if issubclass(enum_class, int): + setattr(enum_class, '__cmp__', getattr(int, '__cmp__')) + + elif pyver < 3.0: + + if issubclass(enum_class, int): + for method in ( + '__le__', + '__lt__', + '__gt__', + '__ge__', + '__eq__', + '__ne__', + '__hash__', + ): + setattr(enum_class, method, getattr(int, method)) + + # replace any other __new__ with our own (as long as Enum is not None, + # anyway) -- again, this is to support pickle + if Enum is not None: + # if the user defined their own __new__, save it before it gets + # clobbered in case they subclass later + if save_new: + setattr(enum_class, '__member_new__', enum_class.__dict__['__new__']) + setattr(enum_class, '__new__', Enum.__dict__['__new__']) + return enum_class + + def __bool__(cls): + """ + classes/types should always be True. + """ + return True + + def __call__(cls, value, names=None, module=None, type=None, start=1): + """Either returns an existing member, or creates a new enum class. + + This method is used both when an enum class is given a value to match + to an enumeration member (i.e. Color(3)) and for the functional API + (i.e. Color = Enum('Color', names='red green blue')). + + When used for the functional API: `module`, if set, will be stored in + the new class' __module__ attribute; `type`, if set, will be mixed in + as the first base class. + + Note: if `module` is not set this routine will attempt to discover the + calling module by walking the frame stack; if this is unsuccessful + the resulting class will not be pickleable. + + """ + if names is None: # simple value lookup + return cls.__new__(cls, value) + # otherwise, functional API: we're creating a new Enum type + return cls._create_(value, names, module=module, type=type, start=start) + + def __contains__(cls, member): + return isinstance(member, cls) and member.name in cls._member_map_ + + def __delattr__(cls, attr): + # nicer error message when someone tries to delete an attribute + # (see issue19025). + if attr in cls._member_map_: + raise AttributeError( + "%s: cannot delete Enum member." % cls.__name__) + super(EnumMeta, cls).__delattr__(attr) + + def __dir__(self): + return (['__class__', '__doc__', '__members__', '__module__'] + + self._member_names_) + + @property + def __members__(cls): + """Returns a mapping of member name->value. + + This mapping lists all enum members, including aliases. Note that this + is a copy of the internal mapping. 
+ + """ + return cls._member_map_.copy() + + def __getattr__(cls, name): + """Return the enum member matching `name` + + We use __getattr__ instead of descriptors or inserting into the enum + class' __dict__ in order to support `name` and `value` being both + properties for enum members (which live in the class' __dict__) and + enum members themselves. + + """ + if _is_dunder(name): + raise AttributeError(name) + try: + return cls._member_map_[name] + except KeyError: + raise AttributeError(name) + + def __getitem__(cls, name): + return cls._member_map_[name] + + def __iter__(cls): + return (cls._member_map_[name] for name in cls._member_names_) + + def __reversed__(cls): + return (cls._member_map_[name] for name in reversed(cls._member_names_)) + + def __len__(cls): + return len(cls._member_names_) + + __nonzero__ = __bool__ + + def __repr__(cls): + return "" % cls.__name__ + + def __setattr__(cls, name, value): + """Block attempts to reassign Enum members. + + A simple assignment to the class namespace only changes one of the + several possible ways to get an Enum member from the Enum class, + resulting in an inconsistent Enumeration. + + """ + member_map = cls.__dict__.get('_member_map_', {}) + if name in member_map: + raise AttributeError('Cannot reassign members.') + super(EnumMeta, cls).__setattr__(name, value) + + def _create_(cls, class_name, names=None, module=None, type=None, start=1): + """Convenience method to create a new Enum class. + + `names` can be: + + * A string containing member names, separated either with spaces or + commas. Values are auto-numbered from 1. + * An iterable of member names. Values are auto-numbered from 1. + * An iterable of (member name, value) pairs. + * A mapping of member name -> value. + + """ + if pyver < 3.0: + # if class_name is unicode, attempt a conversion to ASCII + if isinstance(class_name, unicode): + try: + class_name = class_name.encode('ascii') + except UnicodeEncodeError: + raise TypeError('%r is not representable in ASCII' % class_name) + metacls = cls.__class__ + if type is None: + bases = (cls, ) + else: + bases = (type, cls) + classdict = metacls.__prepare__(class_name, bases) + _order_ = [] + + # special processing needed for names? + if isinstance(names, basestring): + names = names.replace(',', ' ').split() + if isinstance(names, (tuple, list)) and isinstance(names[0], basestring): + names = [(e, i+start) for (i, e) in enumerate(names)] + + # Here, names is either an iterable of (name, value) or a mapping. + item = None # in case names is empty + for item in names: + if isinstance(item, basestring): + member_name, member_value = item, names[item] + else: + member_name, member_value = item + classdict[member_name] = member_value + _order_.append(member_name) + # only set _order_ in classdict if name/value was not from a mapping + if not isinstance(item, basestring): + classdict['_order_'] = ' '.join(_order_) + enum_class = metacls.__new__(metacls, class_name, bases, classdict) + + # TODO: replace the frame hack if a blessed way to know the calling + # module is ever developed + if module is None: + try: + module = _sys._getframe(2).f_globals['__name__'] + except (AttributeError, ValueError): + pass + if module is None: + _make_class_unpicklable(enum_class) + else: + enum_class.__module__ = module + + return enum_class + + @staticmethod + def _get_mixins_(bases): + """Returns the type for creating enum members, and the first inherited + enum class. 
+ + bases: the tuple of bases that was given to __new__ + + """ + if not bases or Enum is None: + return object, Enum + + + # double check that we are not subclassing a class with existing + # enumeration members; while we're at it, see if any other data + # type has been mixed in so we can use the correct __new__ + member_type = first_enum = None + for base in bases: + if (base is not Enum and + issubclass(base, Enum) and + base._member_names_): + raise TypeError("Cannot extend enumerations") + # base is now the last base in bases + if not issubclass(base, Enum): + raise TypeError("new enumerations must be created as " + "`ClassName([mixin_type,] enum_type)`") + + # get correct mix-in type (either mix-in type of Enum subclass, or + # first base if last base is Enum) + if not issubclass(bases[0], Enum): + member_type = bases[0] # first data type + first_enum = bases[-1] # enum type + else: + for base in bases[0].__mro__: + # most common: (IntEnum, int, Enum, object) + # possible: (, , + # , , + # ) + if issubclass(base, Enum): + if first_enum is None: + first_enum = base + else: + if member_type is None: + member_type = base + + return member_type, first_enum + + if pyver < 3.0: + @staticmethod + def _find_new_(classdict, member_type, first_enum): + """Returns the __new__ to be used for creating the enum members. + + classdict: the class dictionary given to __new__ + member_type: the data type whose __new__ will be used by default + first_enum: enumeration to check for an overriding __new__ + + """ + # now find the correct __new__, checking to see of one was defined + # by the user; also check earlier enum classes in case a __new__ was + # saved as __member_new__ + __new__ = classdict.get('__new__', None) + if __new__: + return None, True, True # __new__, save_new, use_args + + N__new__ = getattr(None, '__new__') + O__new__ = getattr(object, '__new__') + if Enum is None: + E__new__ = N__new__ + else: + E__new__ = Enum.__dict__['__new__'] + # check all possibles for __member_new__ before falling back to + # __new__ + for method in ('__member_new__', '__new__'): + for possible in (member_type, first_enum): + try: + target = possible.__dict__[method] + except (AttributeError, KeyError): + target = getattr(possible, method, None) + if target not in [ + None, + N__new__, + O__new__, + E__new__, + ]: + if method == '__member_new__': + classdict['__new__'] = target + return None, False, True + if isinstance(target, staticmethod): + target = target.__get__(member_type) + __new__ = target + break + if __new__ is not None: + break + else: + __new__ = object.__new__ + + # if a non-object.__new__ is used then whatever value/tuple was + # assigned to the enum member name will be passed to __new__ and to the + # new enum member's __init__ + if __new__ is object.__new__: + use_args = False + else: + use_args = True + + return __new__, False, use_args + else: + @staticmethod + def _find_new_(classdict, member_type, first_enum): + """Returns the __new__ to be used for creating the enum members. + + classdict: the class dictionary given to __new__ + member_type: the data type whose __new__ will be used by default + first_enum: enumeration to check for an overriding __new__ + + """ + # now find the correct __new__, checking to see of one was defined + # by the user; also check earlier enum classes in case a __new__ was + # saved as __member_new__ + __new__ = classdict.get('__new__', None) + + # should __new__ be saved as __member_new__ later? 
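+ # (yes, whenever the class body defined its own __new__)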
+ save_new = __new__ is not None + + if __new__ is None: + # check all possibles for __member_new__ before falling back to + # __new__ + for method in ('__member_new__', '__new__'): + for possible in (member_type, first_enum): + target = getattr(possible, method, None) + if target not in ( + None, + None.__new__, + object.__new__, + Enum.__new__, + ): + __new__ = target + break + if __new__ is not None: + break + else: + __new__ = object.__new__ + + # if a non-object.__new__ is used then whatever value/tuple was + # assigned to the enum member name will be passed to __new__ and to the + # new enum member's __init__ + if __new__ is object.__new__: + use_args = False + else: + use_args = True + + return __new__, save_new, use_args + + +######################################################## +# In order to support Python 2 and 3 with a single +# codebase we have to create the Enum methods separately +# and then use the `type(name, bases, dict)` method to +# create the class. +######################################################## +temp_enum_dict = {} +temp_enum_dict['__doc__'] = "Generic enumeration.\n\n Derive from this class to define new enumerations.\n\n" + +def __new__(cls, value): + # all enum instances are actually created during class construction + # without calling this method; this method is called by the metaclass' + # __call__ (i.e. Color(3) ), and by pickle + if type(value) is cls: + # For lookups like Color(Color.red) + value = value.value + #return value + # by-value search for a matching enum member + # see if it's in the reverse mapping (for hashable values) + try: + if value in cls._value2member_map_: + return cls._value2member_map_[value] + except TypeError: + # not there, now do long search -- O(n) behavior + for member in cls._member_map_.values(): + if member.value == value: + return member + raise ValueError("%s is not a valid %s" % (value, cls.__name__)) +temp_enum_dict['__new__'] = __new__ +del __new__ + +def __repr__(self): + return "<%s.%s: %r>" % ( + self.__class__.__name__, self._name_, self._value_) +temp_enum_dict['__repr__'] = __repr__ +del __repr__ + +def __str__(self): + return "%s.%s" % (self.__class__.__name__, self._name_) +temp_enum_dict['__str__'] = __str__ +del __str__ + +if pyver >= 3.0: + def __dir__(self): + added_behavior = [ + m + for cls in self.__class__.mro() + for m in cls.__dict__ + if m[0] != '_' and m not in self._member_map_ + ] + return (['__class__', '__doc__', '__module__', ] + added_behavior) + temp_enum_dict['__dir__'] = __dir__ + del __dir__ + +def __format__(self, format_spec): + # mixed-in Enums should use the mixed-in type's __format__, otherwise + # we can get strange results with the Enum name showing up instead of + # the value + + # pure Enum branch + if self._member_type_ is object: + cls = str + val = str(self) + # mix-in branch + else: + cls = self._member_type_ + val = self.value + return cls.__format__(val, format_spec) +temp_enum_dict['__format__'] = __format__ +del __format__ + + +#################################### +# Python's less than 2.6 use __cmp__ + +if pyver < 2.6: + + def __cmp__(self, other): + if type(other) is self.__class__: + if self is other: + return 0 + return -1 + return NotImplemented + raise TypeError("unorderable types: %s() and %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__cmp__'] = __cmp__ + del __cmp__ + +else: + + def __le__(self, other): + raise TypeError("unorderable types: %s() <= %s()" % (self.__class__.__name__, other.__class__.__name__)) + 
temp_enum_dict['__le__'] = __le__ + del __le__ + + def __lt__(self, other): + raise TypeError("unorderable types: %s() < %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__lt__'] = __lt__ + del __lt__ + + def __ge__(self, other): + raise TypeError("unorderable types: %s() >= %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__ge__'] = __ge__ + del __ge__ + + def __gt__(self, other): + raise TypeError("unorderable types: %s() > %s()" % (self.__class__.__name__, other.__class__.__name__)) + temp_enum_dict['__gt__'] = __gt__ + del __gt__ + + +def __eq__(self, other): + if type(other) is self.__class__: + return self is other + return NotImplemented +temp_enum_dict['__eq__'] = __eq__ +del __eq__ + +def __ne__(self, other): + if type(other) is self.__class__: + return self is not other + return NotImplemented +temp_enum_dict['__ne__'] = __ne__ +del __ne__ + +def __hash__(self): + return hash(self._name_) +temp_enum_dict['__hash__'] = __hash__ +del __hash__ + +def __reduce_ex__(self, proto): + return self.__class__, (self._value_, ) +temp_enum_dict['__reduce_ex__'] = __reduce_ex__ +del __reduce_ex__ + +# _RouteClassAttributeToGetattr is used to provide access to the `name` +# and `value` properties of enum members while keeping some measure of +# protection from modification, while still allowing for an enumeration +# to have members named `name` and `value`. This works because enumeration +# members are not set directly on the enum class -- __getattr__ is +# used to look them up. + +@_RouteClassAttributeToGetattr +def name(self): + return self._name_ +temp_enum_dict['name'] = name +del name + +@_RouteClassAttributeToGetattr +def value(self): + return self._value_ +temp_enum_dict['value'] = value +del value + +@classmethod +def _convert(cls, name, module, filter, source=None): + """ + Create a new Enum subclass that replaces a collection of global constants + """ + # convert all constants from source (or module) that pass filter() to + # a new Enum called name, and export the enum and its members back to + # module; + # also, replace the __reduce_ex__ method so unpickling works in + # previous Python versions + module_globals = vars(_sys.modules[module]) + if source: + source = vars(source) + else: + source = module_globals + members = dict((name, value) for name, value in source.items() if filter(name)) + cls = cls(name, members, module=module) + cls.__reduce_ex__ = _reduce_ex_by_name + module_globals.update(cls.__members__) + module_globals[name] = cls + return cls +temp_enum_dict['_convert'] = _convert +del _convert + +Enum = EnumMeta('Enum', (object, ), temp_enum_dict) +del temp_enum_dict + +# Enum has now been created +########################### + +class IntEnum(int, Enum): + """Enum where members are also (and must be) ints""" + +def _reduce_ex_by_name(self, proto): + return self.name + +def unique(enumeration): + """Class decorator that ensures only unique members exist in an enumeration.""" + duplicates = [] + for name, member in enumeration.__members__.items(): + if name != member.name: + duplicates.append((name, member.name)) + if duplicates: + duplicate_names = ', '.join( + ["%s -> %s" % (alias, name) for (alias, name) in duplicates] + ) + raise ValueError('duplicate names found in %r: %s' % + (enumeration, duplicate_names) + ) + return enumeration diff --git a/vendor/selectors34.py b/vendor/selectors34.py new file mode 100644 index 00000000..ebf5d515 --- /dev/null +++ b/vendor/selectors34.py @@ -0,0 +1,637 @@ +# pylint: 
skip-file +# vendored from https://github.com/berkerpeksag/selectors34 +# at commit ff61b82168d2cc9c4922ae08e2a8bf94aab61ea2 (unreleased, ~1.2) +# +# Original author: Charles-Francois Natali (c.f.natali[at]gmail.com) +# Maintainer: Berker Peksag (berker.peksag[at]gmail.com) +# Also see https://pypi.python.org/pypi/selectors34 +"""Selectors module. + +This module allows high-level and efficient I/O multiplexing, built upon the +`select` module primitives. + +The following code adapted from trollius.selectors. +""" +from __future__ import absolute_import + +from abc import ABCMeta, abstractmethod +from collections import namedtuple, Mapping +from errno import EINTR +import math +import select +import sys + +from kafka.vendor import six + + +def _wrap_error(exc, mapping, key): + if key not in mapping: + return + new_err_cls = mapping[key] + new_err = new_err_cls(*exc.args) + + # raise a new exception with the original traceback + if hasattr(exc, '__traceback__'): + traceback = exc.__traceback__ + else: + traceback = sys.exc_info()[2] + six.reraise(new_err_cls, new_err, traceback) + + +# generic events, that must be mapped to implementation-specific ones +EVENT_READ = (1 << 0) +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + corresponding file descriptor + + Raises: + ValueError if the object is invalid + """ + if isinstance(fileobj, six.integer_types): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (AttributeError, TypeError, ValueError): + raise ValueError("Invalid file object: " + "{0!r}".format(fileobj)) + if fd < 0: + raise ValueError("Invalid file descriptor: {0}".format(fd)) + return fd + + +SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) +"""Object used to associate a file object to its backing file descriptor, +selected event mask and attached data.""" + + +class _SelectorMapping(Mapping): + """Mapping of file objects to selector keys.""" + + def __init__(self, selector): + self._selector = selector + + def __len__(self): + return len(self._selector._fd_to_key) + + def __getitem__(self, fileobj): + try: + fd = self._selector._fileobj_lookup(fileobj) + return self._selector._fd_to_key[fd] + except KeyError: + raise KeyError("{0!r} is not registered".format(fileobj)) + + def __iter__(self): + return iter(self._selector._fd_to_key) + +# Using six.add_metaclass() decorator instead of six.with_metaclass() because +# the latter leaks temporary_class to garbage with gc disabled +@six.add_metaclass(ABCMeta) +class BaseSelector(object): + """Selector abstract base class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + efficient implementation on the current platform. + """ + + @abstractmethod + def register(self, fileobj, events, data=None): + """Register a file object. 
+ + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + + Raises: + ValueError if events is invalid + KeyError if fileobj is already registered + OSError if fileobj is closed or otherwise is unacceptable to + the underlying system call (if a system call is made) + + Note: + OSError may or may not be raised + """ + raise NotImplementedError + + @abstractmethod + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + SelectorKey instance + + Raises: + KeyError if fileobj is not registered + + Note: + If fileobj is registered but has since been closed this does + *not* raise OSError (even if the wrapped syscall does) + """ + raise NotImplementedError + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + + Raises: + Anything that unregister() or register() raises + """ + self.unregister(fileobj) + return self.register(fileobj, events, data) + + @abstractmethod + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. + + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout <= 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (key, events) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + pass + + def get_key(self, fileobj): + """Return the key associated to a registered file object. + + Returns: + SelectorKey for this file object + """ + mapping = self.get_map() + if mapping is None: + raise RuntimeError('Selector is closed') + try: + return mapping[fileobj] + except KeyError: + raise KeyError("{0!r} is not registered".format(fileobj)) + + @abstractmethod + def get_map(self): + """Return a mapping of file objects to selector keys.""" + raise NotImplementedError + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + +class _BaseSelectorImpl(BaseSelector): + """Base selector implementation.""" + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # read-only mapping returned by get_map() + self._map = _SelectorMapping(self) + + def _fileobj_lookup(self, fileobj): + """Return a file descriptor from a file object. + + This wraps _fileobj_to_fd() to do an exhaustive search in case + the object is invalid but we still have it in our map. This + is used by unregister() so we can unregister an object that + was previously registered even if it is closed. It is also + used by _SelectorMapping. + """ + try: + return _fileobj_to_fd(fileobj) + except ValueError: + # Do an exhaustive search. + for key in self._fd_to_key.values(): + if key.fileobj is fileobj: + return key.fd + # Raise ValueError after all. 
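+ # (re-raises the ValueError caught from _fileobj_to_fd above)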
+ raise + + def register(self, fileobj, events, data=None): + if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): + raise ValueError("Invalid events: {0!r}".format(events)) + + key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data) + + if key.fd in self._fd_to_key: + raise KeyError("{0!r} (FD {1}) is already registered" + .format(fileobj, key.fd)) + + self._fd_to_key[key.fd] = key + return key + + def unregister(self, fileobj): + try: + key = self._fd_to_key.pop(self._fileobj_lookup(fileobj)) + except KeyError: + raise KeyError("{0!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fd_to_key[self._fileobj_lookup(fileobj)] + except KeyError: + raise KeyError("{0!r} is not registered".format(fileobj)) + if events != key.events: + self.unregister(fileobj) + key = self.register(fileobj, events, data) + elif data != key.data: + # Use a shortcut to update the data. + key = key._replace(data=data) + self._fd_to_key[key.fd] = key + return key + + def close(self): + self._fd_to_key.clear() + self._map = None + + def get_map(self): + return self._map + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. + + Parameters: + fd -- file descriptor + + Returns: + corresponding key, or None if not found + """ + try: + return self._fd_to_key[fd] + except KeyError: + return None + + +class SelectSelector(_BaseSelectorImpl): + """Select-based selector.""" + + def __init__(self): + super(SelectSelector, self).__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super(SelectSelector, self).register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super(SelectSelector, self).unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select.select(r, w, w, timeout) + return r, w + x, [] + else: + _select = staticmethod(select.select) + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + ready = [] + try: + r, w, _ = self._select(self._readers, self._writers, [], timeout) + except select.error as exc: + if exc.args[0] == EINTR: + return ready + else: + raise + r = set(r) + w = set(w) + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'poll'): + + class PollSelector(_BaseSelectorImpl): + """Poll-based selector.""" + + def __init__(self): + super(PollSelector, self).__init__() + self._poll = select.poll() + + def register(self, fileobj, events, data=None): + key = super(PollSelector, self).register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super(PollSelector, self).unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + if timeout is None: + timeout = None + elif timeout <= 0: + timeout = 0 + else: + # poll() has a resolution 
of 1 millisecond, round away from + # zero to wait *at least* timeout seconds. + timeout = int(math.ceil(timeout * 1e3)) + ready = [] + try: + fd_event_list = self._poll.poll(timeout) + except select.error as exc: + if exc.args[0] == EINTR: + return ready + else: + raise + for fd, event in fd_event_list: + events = 0 + if event & ~select.POLLIN: + events |= EVENT_WRITE + if event & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'epoll'): + + class EpollSelector(_BaseSelectorImpl): + """Epoll-based selector.""" + + def __init__(self): + super(EpollSelector, self).__init__() + self._epoll = select.epoll() + + def fileno(self): + return self._epoll.fileno() + + def register(self, fileobj, events, data=None): + key = super(EpollSelector, self).register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= select.EPOLLIN + if events & EVENT_WRITE: + epoll_events |= select.EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super(EpollSelector, self).unregister(fileobj) + try: + self._epoll.unregister(key.fd) + except IOError: + # This can happen if the FD was closed since it + # was registered. + pass + return key + + def select(self, timeout=None): + if timeout is None: + timeout = -1 + elif timeout <= 0: + timeout = 0 + else: + # epoll_wait() has a resolution of 1 millisecond, round away + # from zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) * 1e-3 + + # epoll_wait() expects `maxevents` to be greater than zero; + # we want to make sure that `select()` can be called when no + # FD is registered. + max_ev = max(len(self._fd_to_key), 1) + + ready = [] + try: + fd_event_list = self._epoll.poll(timeout, max_ev) + except IOError as exc: + if exc.errno == EINTR: + return ready + else: + raise + for fd, event in fd_event_list: + events = 0 + if event & ~select.EPOLLIN: + events |= EVENT_WRITE + if event & ~select.EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._epoll.close() + super(EpollSelector, self).close() + + +if hasattr(select, 'devpoll'): + + class DevpollSelector(_BaseSelectorImpl): + """Solaris /dev/poll selector.""" + + def __init__(self): + super(DevpollSelector, self).__init__() + self._devpoll = select.devpoll() + + def fileno(self): + return self._devpoll.fileno() + + def register(self, fileobj, events, data=None): + key = super(DevpollSelector, self).register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._devpoll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super(DevpollSelector, self).unregister(fileobj) + self._devpoll.unregister(key.fd) + return key + + def select(self, timeout=None): + if timeout is None: + timeout = None + elif timeout <= 0: + timeout = 0 + else: + # devpoll() has a resolution of 1 millisecond, round away from + # zero to wait *at least* timeout seconds. 
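+ # devpoll, like poll, takes its timeout in milliseconds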
+ timeout = math.ceil(timeout * 1e3) + ready = [] + try: + fd_event_list = self._devpoll.poll(timeout) + except OSError as exc: + if exc.errno == EINTR: + return ready + else: + raise + for fd, event in fd_event_list: + events = 0 + if event & ~select.POLLIN: + events |= EVENT_WRITE + if event & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._devpoll.close() + super(DevpollSelector, self).close() + + +if hasattr(select, 'kqueue'): + + class KqueueSelector(_BaseSelectorImpl): + """Kqueue-based selector.""" + + def __init__(self): + super(KqueueSelector, self).__init__() + self._kqueue = select.kqueue() + + def fileno(self): + return self._kqueue.fileno() + + def register(self, fileobj, events, data=None): + key = super(KqueueSelector, self).register(fileobj, events, data) + if events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def unregister(self, fileobj): + key = super(KqueueSelector, self).unregister(fileobj) + if key.events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_DELETE) + try: + self._kqueue.control([kev], 0, 0) + except OSError: + # This can happen if the FD was closed since it + # was registered. + pass + if key.events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_DELETE) + try: + self._kqueue.control([kev], 0, 0) + except OSError: + # See comment above. + pass + return key + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) + ready = [] + try: + kev_list = self._kqueue.control(None, max_ev, timeout) + except OSError as exc: + if exc.errno == EINTR: + return ready + else: + raise + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == select.KQ_FILTER_READ: + events |= EVENT_READ + if flag == select.KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._kqueue.close() + super(KqueueSelector, self).close() + + +# Choose the best implementation, roughly: +# epoll|kqueue|devpoll > poll > select. 
+# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + DefaultSelector = KqueueSelector +elif 'EpollSelector' in globals(): + DefaultSelector = EpollSelector +elif 'DevpollSelector' in globals(): + DefaultSelector = DevpollSelector +elif 'PollSelector' in globals(): + DefaultSelector = PollSelector +else: + DefaultSelector = SelectSelector diff --git a/vendor/six.py b/vendor/six.py new file mode 100644 index 00000000..3621a0ab --- /dev/null +++ b/vendor/six.py @@ -0,0 +1,897 @@ +# pylint: skip-file + +# Copyright (c) 2010-2017 Benjamin Peterson +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Utilities for writing code that runs on Python 2 and 3""" + +from __future__ import absolute_import + +import functools +import itertools +import operator +import sys +import types + +__author__ = "Benjamin Peterson " +__version__ = "1.11.0" + + +# Useful for very coarse version differentiation. +PY2 = sys.version_info[0] == 2 +PY3 = sys.version_info[0] == 3 +PY34 = sys.version_info[0:2] >= (3, 4) + +if PY3: + string_types = str, + integer_types = int, + class_types = type, + text_type = str + binary_type = bytes + + MAXSIZE = sys.maxsize +else: + string_types = basestring, + integer_types = (int, long) + class_types = (type, types.ClassType) + text_type = unicode + binary_type = str + + if sys.platform.startswith("java"): + # Jython always uses 32 bits. + MAXSIZE = int((1 << 31) - 1) + else: + # It's possible to have sizeof(long) != sizeof(Py_ssize_t). + class X(object): + + def __len__(self): + return 1 << 31 + try: + len(X()) + except OverflowError: + # 32-bit + MAXSIZE = int((1 << 31) - 1) + else: + # 64-bit + MAXSIZE = int((1 << 63) - 1) + + # Don't del it here, cause with gc disabled this "leaks" to garbage. + # Note: This is a kafka-python customization, details at: + # https://github.com/dpkp/kafka-python/pull/979#discussion_r100403389 + # del X + + +def _add_doc(func, doc): + """Add documentation to a function.""" + func.__doc__ = doc + + +def _import_module(name): + """Import module, returning the module after the last dot.""" + __import__(name) + return sys.modules[name] + + +class _LazyDescr(object): + + def __init__(self, name): + self.name = name + + def __get__(self, obj, tp): + result = self._resolve() + setattr(obj, self.name, result) # Invokes __set__. + try: + # This is a bit ugly, but it avoids running this again by + # removing this descriptor. 
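+ # (AttributeError means the descriptor is not set directly on this class)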
+ delattr(obj.__class__, self.name) + except AttributeError: + pass + return result + + +class MovedModule(_LazyDescr): + + def __init__(self, name, old, new=None): + super(MovedModule, self).__init__(name) + if PY3: + if new is None: + new = name + self.mod = new + else: + self.mod = old + + def _resolve(self): + return _import_module(self.mod) + + def __getattr__(self, attr): + _module = self._resolve() + value = getattr(_module, attr) + setattr(self, attr, value) + return value + + +class _LazyModule(types.ModuleType): + + def __init__(self, name): + super(_LazyModule, self).__init__(name) + self.__doc__ = self.__class__.__doc__ + + def __dir__(self): + attrs = ["__doc__", "__name__"] + attrs += [attr.name for attr in self._moved_attributes] + return attrs + + # Subclasses should override this + _moved_attributes = [] + + +class MovedAttribute(_LazyDescr): + + def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None): + super(MovedAttribute, self).__init__(name) + if PY3: + if new_mod is None: + new_mod = name + self.mod = new_mod + if new_attr is None: + if old_attr is None: + new_attr = name + else: + new_attr = old_attr + self.attr = new_attr + else: + self.mod = old_mod + if old_attr is None: + old_attr = name + self.attr = old_attr + + def _resolve(self): + module = _import_module(self.mod) + return getattr(module, self.attr) + + +class _SixMetaPathImporter(object): + + """ + A meta path importer to import six.moves and its submodules. + + This class implements a PEP302 finder and loader. It should be compatible + with Python 2.5 and all existing versions of Python3 + """ + + def __init__(self, six_module_name): + self.name = six_module_name + self.known_modules = {} + + def _add_module(self, mod, *fullnames): + for fullname in fullnames: + self.known_modules[self.name + "." + fullname] = mod + + def _get_module(self, fullname): + return self.known_modules[self.name + "." + fullname] + + def find_module(self, fullname, path=None): + if fullname in self.known_modules: + return self + return None + + def __get_module(self, fullname): + try: + return self.known_modules[fullname] + except KeyError: + raise ImportError("This loader does not know module " + fullname) + + def load_module(self, fullname): + try: + # in case of a reload + return sys.modules[fullname] + except KeyError: + pass + mod = self.__get_module(fullname) + if isinstance(mod, MovedModule): + mod = mod._resolve() + else: + mod.__loader__ = self + sys.modules[fullname] = mod + return mod + + def is_package(self, fullname): + """ + Return true, if the named module is a package. 
+ + We need this method to get correct spec objects with + Python 3.4 (see PEP451) + """ + return hasattr(self.__get_module(fullname), "__path__") + + def get_code(self, fullname): + """Return None + + Required, if is_package is implemented""" + self.__get_module(fullname) # eventually raises ImportError + return None + get_source = get_code # same as get_code + +_importer = _SixMetaPathImporter(__name__) + + +class _MovedItems(_LazyModule): + + """Lazy loading of moved objects""" + __path__ = [] # mark as package + + +_moved_attributes = [ + MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), + MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), + MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"), + MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), + MovedAttribute("intern", "__builtin__", "sys"), + MovedAttribute("map", "itertools", "builtins", "imap", "map"), + MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"), + MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"), + MovedAttribute("getoutput", "commands", "subprocess"), + MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"), + MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"), + MovedAttribute("reduce", "__builtin__", "functools"), + MovedAttribute("shlex_quote", "pipes", "shlex", "quote"), + MovedAttribute("StringIO", "StringIO", "io"), + MovedAttribute("UserDict", "UserDict", "collections"), + MovedAttribute("UserList", "UserList", "collections"), + MovedAttribute("UserString", "UserString", "collections"), + MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), + MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), + MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"), + MovedModule("builtins", "__builtin__"), + MovedModule("configparser", "ConfigParser"), + MovedModule("copyreg", "copy_reg"), + MovedModule("dbm_gnu", "gdbm", "dbm.gnu"), + MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread"), + MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), + MovedModule("http_cookies", "Cookie", "http.cookies"), + MovedModule("html_entities", "htmlentitydefs", "html.entities"), + MovedModule("html_parser", "HTMLParser", "html.parser"), + MovedModule("http_client", "httplib", "http.client"), + MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"), + MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"), + MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"), + MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"), + MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"), + MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"), + MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"), + MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"), + MovedModule("cPickle", "cPickle", "pickle"), + MovedModule("queue", "Queue"), + MovedModule("reprlib", "repr"), + MovedModule("socketserver", "SocketServer"), + MovedModule("_thread", "thread", "_thread"), + MovedModule("tkinter", "Tkinter"), + MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"), + MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"), + MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"), + MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"), + 
MovedModule("tkinter_tix", "Tix", "tkinter.tix"), + MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"), + MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"), + MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"), + MovedModule("tkinter_colorchooser", "tkColorChooser", + "tkinter.colorchooser"), + MovedModule("tkinter_commondialog", "tkCommonDialog", + "tkinter.commondialog"), + MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"), + MovedModule("tkinter_font", "tkFont", "tkinter.font"), + MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"), + MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", + "tkinter.simpledialog"), + MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"), + MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"), + MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"), + MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"), + MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"), + MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"), +] +# Add windows specific modules. +if sys.platform == "win32": + _moved_attributes += [ + MovedModule("winreg", "_winreg"), + ] + +for attr in _moved_attributes: + setattr(_MovedItems, attr.name, attr) + if isinstance(attr, MovedModule): + _importer._add_module(attr, "moves." + attr.name) +del attr + +_MovedItems._moved_attributes = _moved_attributes + +moves = _MovedItems(__name__ + ".moves") +_importer._add_module(moves, "moves") + + +class Module_six_moves_urllib_parse(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_parse""" + + +_urllib_parse_moved_attributes = [ + MovedAttribute("ParseResult", "urlparse", "urllib.parse"), + MovedAttribute("SplitResult", "urlparse", "urllib.parse"), + MovedAttribute("parse_qs", "urlparse", "urllib.parse"), + MovedAttribute("parse_qsl", "urlparse", "urllib.parse"), + MovedAttribute("urldefrag", "urlparse", "urllib.parse"), + MovedAttribute("urljoin", "urlparse", "urllib.parse"), + MovedAttribute("urlparse", "urlparse", "urllib.parse"), + MovedAttribute("urlsplit", "urlparse", "urllib.parse"), + MovedAttribute("urlunparse", "urlparse", "urllib.parse"), + MovedAttribute("urlunsplit", "urlparse", "urllib.parse"), + MovedAttribute("quote", "urllib", "urllib.parse"), + MovedAttribute("quote_plus", "urllib", "urllib.parse"), + MovedAttribute("unquote", "urllib", "urllib.parse"), + MovedAttribute("unquote_plus", "urllib", "urllib.parse"), + MovedAttribute("unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"), + MovedAttribute("urlencode", "urllib", "urllib.parse"), + MovedAttribute("splitquery", "urllib", "urllib.parse"), + MovedAttribute("splittag", "urllib", "urllib.parse"), + MovedAttribute("splituser", "urllib", "urllib.parse"), + MovedAttribute("splitvalue", "urllib", "urllib.parse"), + MovedAttribute("uses_fragment", "urlparse", "urllib.parse"), + MovedAttribute("uses_netloc", "urlparse", "urllib.parse"), + MovedAttribute("uses_params", "urlparse", "urllib.parse"), + MovedAttribute("uses_query", "urlparse", "urllib.parse"), + MovedAttribute("uses_relative", "urlparse", "urllib.parse"), +] +for attr in _urllib_parse_moved_attributes: + setattr(Module_six_moves_urllib_parse, attr.name, attr) +del attr + +Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes + +_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"), + 
"moves.urllib_parse", "moves.urllib.parse") + + +class Module_six_moves_urllib_error(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_error""" + + +_urllib_error_moved_attributes = [ + MovedAttribute("URLError", "urllib2", "urllib.error"), + MovedAttribute("HTTPError", "urllib2", "urllib.error"), + MovedAttribute("ContentTooShortError", "urllib", "urllib.error"), +] +for attr in _urllib_error_moved_attributes: + setattr(Module_six_moves_urllib_error, attr.name, attr) +del attr + +Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes + +_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"), + "moves.urllib_error", "moves.urllib.error") + + +class Module_six_moves_urllib_request(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_request""" + + +_urllib_request_moved_attributes = [ + MovedAttribute("urlopen", "urllib2", "urllib.request"), + MovedAttribute("install_opener", "urllib2", "urllib.request"), + MovedAttribute("build_opener", "urllib2", "urllib.request"), + MovedAttribute("pathname2url", "urllib", "urllib.request"), + MovedAttribute("url2pathname", "urllib", "urllib.request"), + MovedAttribute("getproxies", "urllib", "urllib.request"), + MovedAttribute("Request", "urllib2", "urllib.request"), + MovedAttribute("OpenerDirector", "urllib2", "urllib.request"), + MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"), + MovedAttribute("ProxyHandler", "urllib2", "urllib.request"), + MovedAttribute("BaseHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"), + MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"), + MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"), + MovedAttribute("FileHandler", "urllib2", "urllib.request"), + MovedAttribute("FTPHandler", "urllib2", "urllib.request"), + MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"), + MovedAttribute("UnknownHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"), + MovedAttribute("urlretrieve", "urllib", "urllib.request"), + MovedAttribute("urlcleanup", "urllib", "urllib.request"), + MovedAttribute("URLopener", "urllib", "urllib.request"), + MovedAttribute("FancyURLopener", "urllib", "urllib.request"), + MovedAttribute("proxy_bypass", "urllib", "urllib.request"), + MovedAttribute("parse_http_list", "urllib2", "urllib.request"), + MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"), +] +for attr in _urllib_request_moved_attributes: + setattr(Module_six_moves_urllib_request, attr.name, attr) +del attr + +Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes + +_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"), + "moves.urllib_request", "moves.urllib.request") + + 
+class Module_six_moves_urllib_response(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_response""" + + +_urllib_response_moved_attributes = [ + MovedAttribute("addbase", "urllib", "urllib.response"), + MovedAttribute("addclosehook", "urllib", "urllib.response"), + MovedAttribute("addinfo", "urllib", "urllib.response"), + MovedAttribute("addinfourl", "urllib", "urllib.response"), +] +for attr in _urllib_response_moved_attributes: + setattr(Module_six_moves_urllib_response, attr.name, attr) +del attr + +Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes + +_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"), + "moves.urllib_response", "moves.urllib.response") + + +class Module_six_moves_urllib_robotparser(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_robotparser""" + + +_urllib_robotparser_moved_attributes = [ + MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"), +] +for attr in _urllib_robotparser_moved_attributes: + setattr(Module_six_moves_urllib_robotparser, attr.name, attr) +del attr + +Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes + +_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"), + "moves.urllib_robotparser", "moves.urllib.robotparser") + + +class Module_six_moves_urllib(types.ModuleType): + + """Create a six.moves.urllib namespace that resembles the Python 3 namespace""" + __path__ = [] # mark as package + parse = _importer._get_module("moves.urllib_parse") + error = _importer._get_module("moves.urllib_error") + request = _importer._get_module("moves.urllib_request") + response = _importer._get_module("moves.urllib_response") + robotparser = _importer._get_module("moves.urllib_robotparser") + + def __dir__(self): + return ['parse', 'error', 'request', 'response', 'robotparser'] + +_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"), + "moves.urllib") + + +def add_move(move): + """Add an item to six.moves.""" + setattr(_MovedItems, move.name, move) + + +def remove_move(name): + """Remove item from six.moves.""" + try: + delattr(_MovedItems, name) + except AttributeError: + try: + del moves.__dict__[name] + except KeyError: + raise AttributeError("no such move, %r" % (name,)) + + +if PY3: + _meth_func = "__func__" + _meth_self = "__self__" + + _func_closure = "__closure__" + _func_code = "__code__" + _func_defaults = "__defaults__" + _func_globals = "__globals__" +else: + _meth_func = "im_func" + _meth_self = "im_self" + + _func_closure = "func_closure" + _func_code = "func_code" + _func_defaults = "func_defaults" + _func_globals = "func_globals" + + +try: + advance_iterator = next +except NameError: + def advance_iterator(it): + return it.next() +next = advance_iterator + + +try: + callable = callable +except NameError: + def callable(obj): + return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) + + +if PY3: + def get_unbound_function(unbound): + return unbound + + create_bound_method = types.MethodType + + def create_unbound_method(func, cls): + return func + + Iterator = object +else: + def get_unbound_function(unbound): + return unbound.im_func + + def create_bound_method(func, obj): + return types.MethodType(func, obj, obj.__class__) + + def create_unbound_method(func, cls): + return types.MethodType(func, None, cls) + + class Iterator(object): + + def next(self): + return type(self).__next__(self) + + 
callable = callable +_add_doc(get_unbound_function, + """Get the function out of a possibly unbound function""") + + +get_method_function = operator.attrgetter(_meth_func) +get_method_self = operator.attrgetter(_meth_self) +get_function_closure = operator.attrgetter(_func_closure) +get_function_code = operator.attrgetter(_func_code) +get_function_defaults = operator.attrgetter(_func_defaults) +get_function_globals = operator.attrgetter(_func_globals) + + +if PY3: + def iterkeys(d, **kw): + return iter(d.keys(**kw)) + + def itervalues(d, **kw): + return iter(d.values(**kw)) + + def iteritems(d, **kw): + return iter(d.items(**kw)) + + def iterlists(d, **kw): + return iter(d.lists(**kw)) + + viewkeys = operator.methodcaller("keys") + + viewvalues = operator.methodcaller("values") + + viewitems = operator.methodcaller("items") +else: + def iterkeys(d, **kw): + return d.iterkeys(**kw) + + def itervalues(d, **kw): + return d.itervalues(**kw) + + def iteritems(d, **kw): + return d.iteritems(**kw) + + def iterlists(d, **kw): + return d.iterlists(**kw) + + viewkeys = operator.methodcaller("viewkeys") + + viewvalues = operator.methodcaller("viewvalues") + + viewitems = operator.methodcaller("viewitems") + +_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.") +_add_doc(itervalues, "Return an iterator over the values of a dictionary.") +_add_doc(iteritems, + "Return an iterator over the (key, value) pairs of a dictionary.") +_add_doc(iterlists, + "Return an iterator over the (key, [values]) pairs of a dictionary.") + + +if PY3: + def b(s): + return s.encode("latin-1") + + def u(s): + return s + unichr = chr + import struct + int2byte = struct.Struct(">B").pack + del struct + byte2int = operator.itemgetter(0) + indexbytes = operator.getitem + iterbytes = iter + import io + StringIO = io.StringIO + BytesIO = io.BytesIO + _assertCountEqual = "assertCountEqual" + if sys.version_info[1] <= 1: + _assertRaisesRegex = "assertRaisesRegexp" + _assertRegex = "assertRegexpMatches" + else: + _assertRaisesRegex = "assertRaisesRegex" + _assertRegex = "assertRegex" +else: + def b(s): + return s + # Workaround for standalone backslash + + def u(s): + return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape") + unichr = unichr + int2byte = chr + + def byte2int(bs): + return ord(bs[0]) + + def indexbytes(buf, i): + return ord(buf[i]) + iterbytes = functools.partial(itertools.imap, ord) + import StringIO + StringIO = BytesIO = StringIO.StringIO + _assertCountEqual = "assertItemsEqual" + _assertRaisesRegex = "assertRaisesRegexp" + _assertRegex = "assertRegexpMatches" +_add_doc(b, """Byte literal""") +_add_doc(u, """Text literal""") + + +def assertCountEqual(self, *args, **kwargs): + return getattr(self, _assertCountEqual)(*args, **kwargs) + + +def assertRaisesRegex(self, *args, **kwargs): + return getattr(self, _assertRaisesRegex)(*args, **kwargs) + + +def assertRegex(self, *args, **kwargs): + return getattr(self, _assertRegex)(*args, **kwargs) + + +if PY3: + exec_ = getattr(moves.builtins, "exec") + + def reraise(tp, value, tb=None): + try: + if value is None: + value = tp() + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + finally: + value = None + tb = None + +else: + def exec_(_code_, _globs_=None, _locs_=None): + """Execute code in a namespace.""" + if _globs_ is None: + frame = sys._getframe(1) + _globs_ = frame.f_globals + if _locs_ is None: + _locs_ = frame.f_locals + del frame + elif _locs_ is None: + _locs_ = _globs_ + exec("""exec _code_ in _globs_, 
_locs_""") + + exec_("""def reraise(tp, value, tb=None): + try: + raise tp, value, tb + finally: + tb = None +""") + + +if sys.version_info[:2] == (3, 2): + exec_("""def raise_from(value, from_value): + try: + if from_value is None: + raise value + raise value from from_value + finally: + value = None +""") +elif sys.version_info[:2] > (3, 2): + exec_("""def raise_from(value, from_value): + try: + raise value from from_value + finally: + value = None +""") +else: + def raise_from(value, from_value): + raise value + + +print_ = getattr(moves.builtins, "print", None) +if print_ is None: + def print_(*args, **kwargs): + """The new-style print function for Python 2.4 and 2.5.""" + fp = kwargs.pop("file", sys.stdout) + if fp is None: + return + + def write(data): + if not isinstance(data, basestring): + data = str(data) + # If the file has an encoding, encode unicode with it. + if (isinstance(fp, file) and + isinstance(data, unicode) and + fp.encoding is not None): + errors = getattr(fp, "errors", None) + if errors is None: + errors = "strict" + data = data.encode(fp.encoding, errors) + fp.write(data) + want_unicode = False + sep = kwargs.pop("sep", None) + if sep is not None: + if isinstance(sep, unicode): + want_unicode = True + elif not isinstance(sep, str): + raise TypeError("sep must be None or a string") + end = kwargs.pop("end", None) + if end is not None: + if isinstance(end, unicode): + want_unicode = True + elif not isinstance(end, str): + raise TypeError("end must be None or a string") + if kwargs: + raise TypeError("invalid keyword arguments to print()") + if not want_unicode: + for arg in args: + if isinstance(arg, unicode): + want_unicode = True + break + if want_unicode: + newline = unicode("\n") + space = unicode(" ") + else: + newline = "\n" + space = " " + if sep is None: + sep = space + if end is None: + end = newline + for i, arg in enumerate(args): + if i: + write(sep) + write(arg) + write(end) +if sys.version_info[:2] < (3, 3): + _print = print_ + + def print_(*args, **kwargs): + fp = kwargs.get("file", sys.stdout) + flush = kwargs.pop("flush", False) + _print(*args, **kwargs) + if flush and fp is not None: + fp.flush() + +_add_doc(reraise, """Reraise an exception.""") + +if sys.version_info[0:2] < (3, 4): + def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, + updated=functools.WRAPPER_UPDATES): + def wrapper(f): + f = functools.wraps(wrapped, assigned, updated)(f) + f.__wrapped__ = wrapped + return f + return wrapper +else: + wraps = functools.wraps + + +def with_metaclass(meta, *bases): + """Create a base class with a metaclass.""" + # This requires a bit of explanation: the basic idea is to make a dummy + # metaclass for one level of class instantiation that replaces itself with + # the actual metaclass. 
+ class metaclass(type): + + def __new__(cls, name, this_bases, d): + return meta(name, bases, d) + + @classmethod + def __prepare__(cls, name, this_bases): + return meta.__prepare__(name, bases) + return type.__new__(metaclass, 'temporary_class', (), {}) + + +def add_metaclass(metaclass): + """Class decorator for creating a class with a metaclass.""" + def wrapper(cls): + orig_vars = cls.__dict__.copy() + slots = orig_vars.get('__slots__') + if slots is not None: + if isinstance(slots, str): + slots = [slots] + for slots_var in slots: + orig_vars.pop(slots_var) + orig_vars.pop('__dict__', None) + orig_vars.pop('__weakref__', None) + return metaclass(cls.__name__, cls.__bases__, orig_vars) + return wrapper + + +def python_2_unicode_compatible(klass): + """ + A decorator that defines __unicode__ and __str__ methods under Python 2. + Under Python 3 it does nothing. + + To support Python 2 and 3 with a single code base, define a __str__ method + returning text and apply this decorator to the class. + """ + if PY2: + if '__str__' not in klass.__dict__: + raise ValueError("@python_2_unicode_compatible cannot be applied " + "to %s because it doesn't define __str__()." % + klass.__name__) + klass.__unicode__ = klass.__str__ + klass.__str__ = lambda self: self.__unicode__().encode('utf-8') + return klass + + +# Complete the moves implementation. +# This code is at the end of this module to speed up module loading. +# Turn this module into a package. +__path__ = [] # required for PEP 302 and PEP 451 +__package__ = __name__ # see PEP 366 @ReservedAssignment +if globals().get("__spec__") is not None: + __spec__.submodule_search_locations = [] # PEP 451 @UndefinedVariable +# Remove other six meta path importers, since they cause problems. This can +# happen if six is removed from sys.modules and then reloaded. (Setuptools does +# this for some reason.) +if sys.meta_path: + for i, importer in enumerate(sys.meta_path): + # Here's some real nastiness: Another "instance" of the six module might + # be floating around. Therefore, we can't use isinstance() to check for + # the six meta path importer, since the other six instance will have + # inserted an importer with different class. + if (type(importer).__name__ == "_SixMetaPathImporter" and + importer.name == __name__): + del sys.meta_path[i] + break + del i, importer +# Finally, add the importer to the meta path import hook. +sys.meta_path.append(_importer) diff --git a/vendor/socketpair.py b/vendor/socketpair.py new file mode 100644 index 00000000..b55e629e --- /dev/null +++ b/vendor/socketpair.py @@ -0,0 +1,58 @@ +# pylint: skip-file +# vendored from https://github.com/mhils/backports.socketpair +from __future__ import absolute_import + +import sys +import socket +import errno + +_LOCALHOST = '127.0.0.1' +_LOCALHOST_V6 = '::1' + +if not hasattr(socket, "socketpair"): + # Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. + def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): + if family == socket.AF_INET: + host = _LOCALHOST + elif family == socket.AF_INET6: + host = _LOCALHOST_V6 + else: + raise ValueError("Only AF_INET and AF_INET6 socket address families " + "are supported") + if type != socket.SOCK_STREAM: + raise ValueError("Only SOCK_STREAM socket type is supported") + if proto != 0: + raise ValueError("Only protocol zero is supported") + + # We create a connected TCP socket. Note the trick with + # setblocking(False) that prevents us from having to create a thread. 
+ lsock = socket.socket(family, type, proto) + try: + lsock.bind((host, 0)) + lsock.listen(min(socket.SOMAXCONN, 128)) + # On IPv6, ignore flow_info and scope_id + addr, port = lsock.getsockname()[:2] + csock = socket.socket(family, type, proto) + try: + csock.setblocking(False) + if sys.version_info >= (3, 0): + try: + csock.connect((addr, port)) + except (BlockingIOError, InterruptedError): + pass + else: + try: + csock.connect((addr, port)) + except socket.error as e: + if e.errno != errno.WSAEWOULDBLOCK: + raise + csock.setblocking(True) + ssock, _ = lsock.accept() + except Exception: + csock.close() + raise + finally: + lsock.close() + return (ssock, csock) + + socket.socketpair = socketpair diff --git a/version.py b/version.py new file mode 100644 index 00000000..06306bd1 --- /dev/null +++ b/version.py @@ -0,0 +1 @@ +__version__ = '2.0.3-dev' From 3f3c169f0adaab621199ea6318e3fd10a21dec47 Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Sat, 21 Oct 2023 22:02:54 +0300 Subject: [PATCH 03/20] Remove kafka-python from dependencies --- .github/ISSUE_TEMPLATE/bug_report.md | 3 +-- setup.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index fb5c9f92..f7d99612 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -15,8 +15,7 @@ A clear and concise description of what you expected to happen. **Environment (please complete the following information):** - aiokafka version (`python -c "import aiokafka; print(aiokafka.__version__)"`): - - kafka-python version (`python -c "import kafka; print(kafka.__version__)"`): - - Kafka Broker version (`kafka-topics.sh --version`): + - Kafka Broker version (`kafka-topics.sh --version`): - Other information (Confluent Cloud version, etc.): **Reproducible example** diff --git a/setup.py b/setup.py index 5ff8cc54..fc69494b 100644 --- a/setup.py +++ b/setup.py @@ -112,7 +112,6 @@ def build_extension(self, ext): install_requires = [ "async-timeout", - "kafka-python>=2.0.2", "packaging", ] @@ -173,7 +172,7 @@ def read_version(): }, download_url="https://pypi.python.org/pypi/aiokafka", license="Apache 2", - packages=["aiokafka"], + packages=["aiokafka", "kafka"], python_requires=">=3.8", install_requires=install_requires, extras_require=extras_require, From 90a88510adfc4376689500a44389ed68b1ec3025 Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Sat, 21 Oct 2023 22:06:16 +0300 Subject: [PATCH 04/20] Squashed 'tests/kafka/' content from commit 1d17d3b7 git-subtree-dir: tests/kafka git-subtree-split: 1d17d3b7950b40aeeeef52281b8b7ab51622b272 --- __init__.py | 8 + conftest.py | 175 ++++++ fixtures.py | 673 +++++++++++++++++++++++ record/test_default_records.py | 208 +++++++ record/test_legacy_records.py | 197 +++++++ record/test_records.py | 232 ++++++++ record/test_util.py | 96 ++++ service.py | 133 +++++ test_acl_comparisons.py | 92 ++++ test_admin.py | 78 +++ test_admin_integration.py | 314 +++++++++++ test_api_object_implementation.py | 18 + test_assignors.py | 871 ++++++++++++++++++++++++++++++ test_client_async.py | 409 ++++++++++++++ test_cluster.py | 22 + test_codec.py | 124 +++++ test_conn.py | 342 ++++++++++++ test_consumer.py | 26 + test_consumer_group.py | 179 ++++++ test_consumer_integration.py | 297 ++++++++++ test_coordinator.py | 638 ++++++++++++++++++++++ test_fetcher.py | 553 +++++++++++++++++++ test_metrics.py | 499 +++++++++++++++++ test_object_conversion.py | 236 ++++++++ test_package.py | 25 + 
test_partition_movements.py | 23 + test_partitioner.py | 38 ++ test_producer.py | 137 +++++ test_protocol.py | 336 ++++++++++++ test_sasl_integration.py | 80 +++ test_sender.py | 53 ++ test_subscription_state.py | 25 + testutil.py | 46 ++ 33 files changed, 7183 insertions(+) create mode 100644 __init__.py create mode 100644 conftest.py create mode 100644 fixtures.py create mode 100644 record/test_default_records.py create mode 100644 record/test_legacy_records.py create mode 100644 record/test_records.py create mode 100644 record/test_util.py create mode 100644 service.py create mode 100644 test_acl_comparisons.py create mode 100644 test_admin.py create mode 100644 test_admin_integration.py create mode 100644 test_api_object_implementation.py create mode 100644 test_assignors.py create mode 100644 test_client_async.py create mode 100644 test_cluster.py create mode 100644 test_codec.py create mode 100644 test_conn.py create mode 100644 test_consumer.py create mode 100644 test_consumer_group.py create mode 100644 test_consumer_integration.py create mode 100644 test_coordinator.py create mode 100644 test_fetcher.py create mode 100644 test_metrics.py create mode 100644 test_object_conversion.py create mode 100644 test_package.py create mode 100644 test_partition_movements.py create mode 100644 test_partitioner.py create mode 100644 test_producer.py create mode 100644 test_protocol.py create mode 100644 test_sasl_integration.py create mode 100644 test_sender.py create mode 100644 test_subscription_state.py create mode 100644 testutil.py diff --git a/__init__.py b/__init__.py new file mode 100644 index 00000000..329277dc --- /dev/null +++ b/__init__.py @@ -0,0 +1,8 @@ +from __future__ import absolute_import + +# Set default logging handler to avoid "No handler found" warnings. 
+import logging +logging.basicConfig(level=logging.INFO) + +from kafka.future import Future +Future.error_on_callbacks = True # always fail during testing diff --git a/conftest.py b/conftest.py new file mode 100644 index 00000000..3fa0262f --- /dev/null +++ b/conftest.py @@ -0,0 +1,175 @@ +from __future__ import absolute_import + +import uuid + +import pytest + +from test.testutil import env_kafka_version, random_string +from test.fixtures import KafkaFixture, ZookeeperFixture + +@pytest.fixture(scope="module") +def zookeeper(): + """Return a Zookeeper fixture""" + zk_instance = ZookeeperFixture.instance() + yield zk_instance + zk_instance.close() + + +@pytest.fixture(scope="module") +def kafka_broker(kafka_broker_factory): + """Return a Kafka broker fixture""" + return kafka_broker_factory()[0] + + +@pytest.fixture(scope="module") +def kafka_broker_factory(zookeeper): + """Return a Kafka broker fixture factory""" + assert env_kafka_version(), 'KAFKA_VERSION must be specified to run integration tests' + + _brokers = [] + def factory(**broker_params): + params = {} if broker_params is None else broker_params.copy() + params.setdefault('partitions', 4) + num_brokers = params.pop('num_brokers', 1) + brokers = tuple(KafkaFixture.instance(x, zookeeper, **params) + for x in range(num_brokers)) + _brokers.extend(brokers) + return brokers + + yield factory + + for broker in _brokers: + broker.close() + + +@pytest.fixture +def kafka_client(kafka_broker, request): + """Return a KafkaClient fixture""" + (client,) = kafka_broker.get_clients(cnt=1, client_id='%s_client' % (request.node.name,)) + yield client + client.close() + + +@pytest.fixture +def kafka_consumer(kafka_consumer_factory): + """Return a KafkaConsumer fixture""" + return kafka_consumer_factory() + + +@pytest.fixture +def kafka_consumer_factory(kafka_broker, topic, request): + """Return a KafkaConsumer factory fixture""" + _consumer = [None] + + def factory(**kafka_consumer_params): + params = {} if kafka_consumer_params is None else kafka_consumer_params.copy() + params.setdefault('client_id', 'consumer_%s' % (request.node.name,)) + params.setdefault('auto_offset_reset', 'earliest') + _consumer[0] = next(kafka_broker.get_consumers(cnt=1, topics=[topic], **params)) + return _consumer[0] + + yield factory + + if _consumer[0]: + _consumer[0].close() + + +@pytest.fixture +def kafka_producer(kafka_producer_factory): + """Return a KafkaProducer fixture""" + yield kafka_producer_factory() + + +@pytest.fixture +def kafka_producer_factory(kafka_broker, request): + """Return a KafkaProduce factory fixture""" + _producer = [None] + + def factory(**kafka_producer_params): + params = {} if kafka_producer_params is None else kafka_producer_params.copy() + params.setdefault('client_id', 'producer_%s' % (request.node.name,)) + _producer[0] = next(kafka_broker.get_producers(cnt=1, **params)) + return _producer[0] + + yield factory + + if _producer[0]: + _producer[0].close() + +@pytest.fixture +def kafka_admin_client(kafka_admin_client_factory): + """Return a KafkaAdminClient fixture""" + yield kafka_admin_client_factory() + +@pytest.fixture +def kafka_admin_client_factory(kafka_broker): + """Return a KafkaAdminClient factory fixture""" + _admin_client = [None] + + def factory(**kafka_admin_client_params): + params = {} if kafka_admin_client_params is None else kafka_admin_client_params.copy() + _admin_client[0] = next(kafka_broker.get_admin_clients(cnt=1, **params)) + return _admin_client[0] + + yield factory + + if _admin_client[0]: + 
_admin_client[0].close() + +@pytest.fixture +def topic(kafka_broker, request): + """Return a topic fixture""" + topic_name = '%s_%s' % (request.node.name, random_string(10)) + kafka_broker.create_topics([topic_name]) + return topic_name + + +@pytest.fixture +def conn(mocker): + """Return a connection mocker fixture""" + from kafka.conn import ConnectionStates + from kafka.future import Future + from kafka.protocol.metadata import MetadataResponse + conn = mocker.patch('kafka.client_async.BrokerConnection') + conn.return_value = conn + conn.state = ConnectionStates.CONNECTED + conn.send.return_value = Future().success( + MetadataResponse[0]( + [(0, 'foo', 12), (1, 'bar', 34)], # brokers + [])) # topics + conn.blacked_out.return_value = False + def _set_conn_state(state): + conn.state = state + return state + conn._set_conn_state = _set_conn_state + conn.connect.side_effect = lambda: conn.state + conn.connect_blocking.return_value = True + conn.connecting = lambda: conn.state in (ConnectionStates.CONNECTING, + ConnectionStates.HANDSHAKE) + conn.connected = lambda: conn.state is ConnectionStates.CONNECTED + conn.disconnected = lambda: conn.state is ConnectionStates.DISCONNECTED + return conn + + +@pytest.fixture() +def send_messages(topic, kafka_producer, request): + """A factory that returns a send_messages function with a pre-populated + topic topic / producer.""" + + def _send_messages(number_range, partition=0, topic=topic, producer=kafka_producer, request=request): + """ + messages is typically `range(0,100)` + partition is an int + """ + messages_and_futures = [] # [(message, produce_future),] + for i in number_range: + # request.node.name provides the test name (including parametrized values) + encoded_msg = '{}-{}-{}'.format(i, request.node.name, uuid.uuid4()).encode('utf-8') + future = kafka_producer.send(topic, value=encoded_msg, partition=partition) + messages_and_futures.append((encoded_msg, future)) + kafka_producer.flush() + for (msg, f) in messages_and_futures: + assert f.succeeded() + return [msg for (msg, f) in messages_and_futures] + + return _send_messages diff --git a/fixtures.py b/fixtures.py new file mode 100644 index 00000000..d9c072b8 --- /dev/null +++ b/fixtures.py @@ -0,0 +1,673 @@ +from __future__ import absolute_import + +import atexit +import logging +import os +import os.path +import socket +import subprocess +import time +import uuid + +import py +from kafka.vendor.six.moves import urllib, range +from kafka.vendor.six.moves.urllib.parse import urlparse # pylint: disable=E0611,F0401 + +from kafka import errors, KafkaAdminClient, KafkaClient, KafkaConsumer, KafkaProducer +from kafka.errors import InvalidReplicationFactorError +from kafka.protocol.admin import CreateTopicsRequest +from kafka.protocol.metadata import MetadataRequest +from test.testutil import env_kafka_version, random_string +from test.service import ExternalService, SpawnedService + +log = logging.getLogger(__name__) + + +def get_open_port(): + sock = socket.socket() + sock.bind(("127.0.0.1", 0)) + port = sock.getsockname()[1] + sock.close() + return port + + +def gen_ssl_resources(directory): + os.system(""" + cd {0} + echo Generating SSL resources in {0} + + # Step 1 + keytool -keystore kafka.server.keystore.jks -alias localhost -validity 1 \ + -genkey -storepass foobar -keypass foobar \ + -dname "CN=localhost, OU=kafka-python, O=kafka-python, L=SF, ST=CA, C=US" \ + -ext SAN=dns:localhost + + # Step 2 + openssl genrsa -out ca-key 2048 + openssl req -new -x509 -key ca-key -out ca-cert -days 1 \ + 
-subj "/C=US/ST=CA/O=MyOrg, Inc./CN=mydomain.com" + keytool -keystore kafka.server.truststore.jks -alias CARoot -import \ + -file ca-cert -storepass foobar -noprompt + + # Step 3 + keytool -keystore kafka.server.keystore.jks -alias localhost -certreq \ + -file cert-file -storepass foobar + openssl x509 -req -CA ca-cert -CAkey ca-key -in cert-file -out cert-signed \ + -days 1 -CAcreateserial -passin pass:foobar + keytool -keystore kafka.server.keystore.jks -alias CARoot -import \ + -file ca-cert -storepass foobar -noprompt + keytool -keystore kafka.server.keystore.jks -alias localhost -import \ + -file cert-signed -storepass foobar -noprompt + """.format(directory)) + + +class Fixture(object): + kafka_version = os.environ.get('KAFKA_VERSION', '0.11.0.2') + scala_version = os.environ.get("SCALA_VERSION", '2.8.0') + project_root = os.environ.get('PROJECT_ROOT', + os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + kafka_root = os.environ.get("KAFKA_ROOT", + os.path.join(project_root, 'servers', kafka_version, "kafka-bin")) + + def __init__(self): + self.child = None + + @classmethod + def download_official_distribution(cls, + kafka_version=None, + scala_version=None, + output_dir=None): + if not kafka_version: + kafka_version = cls.kafka_version + if not scala_version: + scala_version = cls.scala_version + if not output_dir: + output_dir = os.path.join(cls.project_root, 'servers', 'dist') + + distfile = 'kafka_%s-%s' % (scala_version, kafka_version,) + url_base = 'https://archive.apache.org/dist/kafka/%s/' % (kafka_version,) + output_file = os.path.join(output_dir, distfile + '.tgz') + + if os.path.isfile(output_file): + log.info("Found file already on disk: %s", output_file) + return output_file + + # New tarballs are .tgz, older ones are sometimes .tar.gz + try: + url = url_base + distfile + '.tgz' + log.info("Attempting to download %s", url) + response = urllib.request.urlopen(url) + except urllib.error.HTTPError: + log.exception("HTTP Error") + url = url_base + distfile + '.tar.gz' + log.info("Attempting to download %s", url) + response = urllib.request.urlopen(url) + + log.info("Saving distribution file to %s", output_file) + with open(output_file, 'w') as output_file_fd: + output_file_fd.write(response.read()) + + return output_file + + @classmethod + def test_resource(cls, filename): + return os.path.join(cls.project_root, "servers", cls.kafka_version, "resources", filename) + + @classmethod + def kafka_run_class_args(cls, *args): + result = [os.path.join(cls.kafka_root, 'bin', 'kafka-run-class.sh')] + result.extend([str(arg) for arg in args]) + return result + + def kafka_run_class_env(self): + env = os.environ.copy() + env['KAFKA_LOG4J_OPTS'] = "-Dlog4j.configuration=file:%s" % \ + (self.test_resource("log4j.properties"),) + return env + + @classmethod + def render_template(cls, source_file, target_file, binding): + log.info('Rendering %s from template %s', target_file.strpath, source_file) + with open(source_file, "r") as handle: + template = handle.read() + assert len(template) > 0, 'Empty template %s' % (source_file,) + with open(target_file.strpath, "w") as handle: + handle.write(template.format(**binding)) + handle.flush() + os.fsync(handle) + + # fsync directory for durability + # https://blog.gocept.com/2013/07/15/reliable-file-updates-with-python/ + dirfd = os.open(os.path.dirname(target_file.strpath), os.O_DIRECTORY) + os.fsync(dirfd) + os.close(dirfd) + log.debug("Template string:") + for line in template.splitlines(): + log.debug(' ' + line.strip()) + 
log.debug("Rendered template:") + with open(target_file.strpath, 'r') as o: + for line in o: + log.debug(' ' + line.strip()) + log.debug("binding:") + for key, value in binding.items(): + log.debug(" {key}={value}".format(key=key, value=value)) + + def dump_logs(self): + self.child.dump_logs() + + +class ZookeeperFixture(Fixture): + @classmethod + def instance(cls): + if "ZOOKEEPER_URI" in os.environ: + parse = urlparse(os.environ["ZOOKEEPER_URI"]) + (host, port) = (parse.hostname, parse.port) + fixture = ExternalService(host, port) + else: + (host, port) = ("127.0.0.1", None) + fixture = cls(host, port) + + fixture.open() + return fixture + + def __init__(self, host, port, tmp_dir=None): + super(ZookeeperFixture, self).__init__() + self.host = host + self.port = port + + self.tmp_dir = tmp_dir + + def kafka_run_class_env(self): + env = super(ZookeeperFixture, self).kafka_run_class_env() + env['LOG_DIR'] = self.tmp_dir.join('logs').strpath + return env + + def out(self, message): + log.info("*** Zookeeper [%s:%s]: %s", self.host, self.port or '(auto)', message) + + def open(self): + if self.tmp_dir is None: + self.tmp_dir = py.path.local.mkdtemp() #pylint: disable=no-member + self.tmp_dir.ensure(dir=True) + + self.out("Running local instance...") + log.info(" host = %s", self.host) + log.info(" port = %s", self.port or '(auto)') + log.info(" tmp_dir = %s", self.tmp_dir.strpath) + + # Configure Zookeeper child process + template = self.test_resource("zookeeper.properties") + properties = self.tmp_dir.join("zookeeper.properties") + args = self.kafka_run_class_args("org.apache.zookeeper.server.quorum.QuorumPeerMain", + properties.strpath) + env = self.kafka_run_class_env() + + # Party! + timeout = 5 + max_timeout = 120 + backoff = 1 + end_at = time.time() + max_timeout + tries = 1 + auto_port = (self.port is None) + while time.time() < end_at: + if auto_port: + self.port = get_open_port() + self.out('Attempting to start on port %d (try #%d)' % (self.port, tries)) + self.render_template(template, properties, vars(self)) + self.child = SpawnedService(args, env) + self.child.start() + timeout = min(timeout, max(end_at - time.time(), 0)) + if self.child.wait_for(r"binding to port", timeout=timeout): + break + self.child.dump_logs() + self.child.stop() + timeout *= 2 + time.sleep(backoff) + tries += 1 + backoff += 1 + else: + raise RuntimeError('Failed to start Zookeeper before max_timeout') + self.out("Done!") + atexit.register(self.close) + + def close(self): + if self.child is None: + return + self.out("Stopping...") + self.child.stop() + self.child = None + self.out("Done!") + self.tmp_dir.remove() + + def __del__(self): + self.close() + + +class KafkaFixture(Fixture): + broker_user = 'alice' + broker_password = 'alice-secret' + + @classmethod + def instance(cls, broker_id, zookeeper, zk_chroot=None, + host=None, port=None, + transport='PLAINTEXT', replicas=1, partitions=2, + sasl_mechanism=None, auto_create_topic=True, tmp_dir=None): + + if zk_chroot is None: + zk_chroot = "kafka-python_" + str(uuid.uuid4()).replace("-", "_") + if "KAFKA_URI" in os.environ: + parse = urlparse(os.environ["KAFKA_URI"]) + (host, port) = (parse.hostname, parse.port) + fixture = ExternalService(host, port) + else: + if host is None: + host = "localhost" + fixture = KafkaFixture(host, port, broker_id, + zookeeper, zk_chroot, + transport=transport, + replicas=replicas, partitions=partitions, + sasl_mechanism=sasl_mechanism, + auto_create_topic=auto_create_topic, + tmp_dir=tmp_dir) + + fixture.open() + return fixture 
+ + def __init__(self, host, port, broker_id, zookeeper, zk_chroot, + replicas=1, partitions=2, transport='PLAINTEXT', + sasl_mechanism=None, auto_create_topic=True, + tmp_dir=None): + super(KafkaFixture, self).__init__() + + self.host = host + self.port = port + + self.broker_id = broker_id + self.auto_create_topic = auto_create_topic + self.transport = transport.upper() + if sasl_mechanism is not None: + self.sasl_mechanism = sasl_mechanism.upper() + else: + self.sasl_mechanism = None + self.ssl_dir = self.test_resource('ssl') + + # TODO: checking for port connection would be better than scanning logs + # until then, we need the pattern to work across all supported broker versions + # The logging format changed slightly in 1.0.0 + self.start_pattern = r"\[Kafka ?Server (id=)?%d\],? started" % (broker_id,) + # Need to wait until the broker has fetched user configs from zookeeper in case we use scram as sasl mechanism + self.scram_pattern = r"Removing Produce quota for user %s" % (self.broker_user) + + self.zookeeper = zookeeper + self.zk_chroot = zk_chroot + # Add the attributes below for the template binding + self.zk_host = self.zookeeper.host + self.zk_port = self.zookeeper.port + + self.replicas = replicas + self.partitions = partitions + + self.tmp_dir = tmp_dir + self.running = False + + self._client = None + self.sasl_config = '' + self.jaas_config = '' + + def _sasl_config(self): + if not self.sasl_enabled: + return '' + + sasl_config = ( + 'sasl.enabled.mechanisms={mechanism}\n' + 'sasl.mechanism.inter.broker.protocol={mechanism}\n' + ) + return sasl_config.format(mechanism=self.sasl_mechanism) + + def _jaas_config(self): + if not self.sasl_enabled: + return '' + + elif self.sasl_mechanism == 'PLAIN': + jaas_config = ( + 'org.apache.kafka.common.security.plain.PlainLoginModule required\n' + ' username="{user}" password="{password}" user_{user}="{password}";\n' + ) + elif self.sasl_mechanism in ("SCRAM-SHA-256", "SCRAM-SHA-512"): + jaas_config = ( + 'org.apache.kafka.common.security.scram.ScramLoginModule required\n' + ' username="{user}" password="{password}";\n' + ) + else: + raise ValueError("SASL mechanism {} currently not supported".format(self.sasl_mechanism)) + return jaas_config.format(user=self.broker_user, password=self.broker_password) + + def _add_scram_user(self): + self.out("Adding SCRAM credentials for user {} to zookeeper.".format(self.broker_user)) + args = self.kafka_run_class_args( + "kafka.admin.ConfigCommand", + "--zookeeper", + "%s:%d/%s" % (self.zookeeper.host, + self.zookeeper.port, + self.zk_chroot), + "--alter", + "--entity-type", "users", + "--entity-name", self.broker_user, + "--add-config", + "{}=[password={}]".format(self.sasl_mechanism, self.broker_password), + ) + env = self.kafka_run_class_env() + proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + stdout, stderr = proc.communicate() + + if proc.returncode != 0: + self.out("Failed to save credentials to zookeeper!") + self.out(stdout) + self.out(stderr) + raise RuntimeError("Failed to save credentials to zookeeper!") + self.out("User created.") + + @property + def sasl_enabled(self): + return self.sasl_mechanism is not None + + def bootstrap_server(self): + return '%s:%d' % (self.host, self.port) + + def kafka_run_class_env(self): + env = super(KafkaFixture, self).kafka_run_class_env() + env['LOG_DIR'] = self.tmp_dir.join('logs').strpath + return env + + def out(self, message): + log.info("*** Kafka [%s:%s]: %s", self.host, self.port or '(auto)', message) + + 
def _create_zk_chroot(self): + self.out("Creating Zookeeper chroot node...") + args = self.kafka_run_class_args("org.apache.zookeeper.ZooKeeperMain", + "-server", + "%s:%d" % (self.zookeeper.host, + self.zookeeper.port), + "create", + "/%s" % (self.zk_chroot,), + "kafka-python") + env = self.kafka_run_class_env() + proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + stdout, stderr = proc.communicate() + + if proc.returncode != 0: + self.out("Failed to create Zookeeper chroot node") + self.out(stdout) + self.out(stderr) + raise RuntimeError("Failed to create Zookeeper chroot node") + self.out("Kafka chroot created in Zookeeper!") + + def start(self): + # Configure Kafka child process + properties = self.tmp_dir.join("kafka.properties") + jaas_conf = self.tmp_dir.join("kafka_server_jaas.conf") + properties_template = self.test_resource("kafka.properties") + jaas_conf_template = self.test_resource("kafka_server_jaas.conf") + + args = self.kafka_run_class_args("kafka.Kafka", properties.strpath) + env = self.kafka_run_class_env() + if self.sasl_enabled: + opts = env.get('KAFKA_OPTS', '').strip() + opts += ' -Djava.security.auth.login.config={}'.format(jaas_conf.strpath) + env['KAFKA_OPTS'] = opts + self.render_template(jaas_conf_template, jaas_conf, vars(self)) + + timeout = 5 + max_timeout = 120 + backoff = 1 + end_at = time.time() + max_timeout + tries = 1 + auto_port = (self.port is None) + while time.time() < end_at: + # We have had problems with port conflicts on travis + # so we will try a different port on each retry + # unless the fixture was passed a specific port + if auto_port: + self.port = get_open_port() + self.out('Attempting to start on port %d (try #%d)' % (self.port, tries)) + self.render_template(properties_template, properties, vars(self)) + + self.child = SpawnedService(args, env) + self.child.start() + timeout = min(timeout, max(end_at - time.time(), 0)) + if self._broker_ready(timeout) and self._scram_user_present(timeout): + break + + self.child.dump_logs() + self.child.stop() + + timeout *= 2 + time.sleep(backoff) + tries += 1 + backoff += 1 + else: + raise RuntimeError('Failed to start KafkaInstance before max_timeout') + + (self._client,) = self.get_clients(1, client_id='_internal_client') + + self.out("Done!") + self.running = True + + def _broker_ready(self, timeout): + return self.child.wait_for(self.start_pattern, timeout=timeout) + + def _scram_user_present(self, timeout): + # no need to wait for scram user if scram is not used + if not self.sasl_enabled or not self.sasl_mechanism.startswith('SCRAM-SHA-'): + return True + return self.child.wait_for(self.scram_pattern, timeout=timeout) + + def open(self): + if self.running: + self.out("Instance already running") + return + + # Create directories + if self.tmp_dir is None: + self.tmp_dir = py.path.local.mkdtemp() #pylint: disable=no-member + self.tmp_dir.ensure(dir=True) + self.tmp_dir.ensure('logs', dir=True) + self.tmp_dir.ensure('data', dir=True) + + self.out("Running local instance...") + log.info(" host = %s", self.host) + log.info(" port = %s", self.port or '(auto)') + log.info(" transport = %s", self.transport) + log.info(" sasl_mechanism = %s", self.sasl_mechanism) + log.info(" broker_id = %s", self.broker_id) + log.info(" zk_host = %s", self.zookeeper.host) + log.info(" zk_port = %s", self.zookeeper.port) + log.info(" zk_chroot = %s", self.zk_chroot) + log.info(" replicas = %s", self.replicas) + log.info(" partitions = %s", self.partitions) + log.info(" tmp_dir = %s", 
self.tmp_dir.strpath) + + self._create_zk_chroot() + self.sasl_config = self._sasl_config() + self.jaas_config = self._jaas_config() + # add user to zookeeper for the first server + if self.sasl_enabled and self.sasl_mechanism.startswith("SCRAM-SHA") and self.broker_id == 0: + self._add_scram_user() + self.start() + + atexit.register(self.close) + + def __del__(self): + self.close() + + def stop(self): + if not self.running: + self.out("Instance already stopped") + return + + self.out("Stopping...") + self.child.stop() + self.child = None + self.running = False + self.out("Stopped!") + + def close(self): + self.stop() + if self.tmp_dir is not None: + self.tmp_dir.remove() + self.tmp_dir = None + self.out("Done!") + + def dump_logs(self): + super(KafkaFixture, self).dump_logs() + self.zookeeper.dump_logs() + + def _send_request(self, request, timeout=None): + def _failure(error): + raise error + retries = 10 + while True: + node_id = self._client.least_loaded_node() + for connect_retry in range(40): + self._client.maybe_connect(node_id) + if self._client.connected(node_id): + break + self._client.poll(timeout_ms=100) + else: + raise RuntimeError('Could not connect to broker with node id %d' % (node_id,)) + + try: + future = self._client.send(node_id, request) + future.error_on_callbacks = True + future.add_errback(_failure) + self._client.poll(future=future, timeout_ms=timeout) + return future.value + except Exception as exc: + time.sleep(1) + retries -= 1 + if retries == 0: + raise exc + else: + pass # retry + + def _create_topic(self, topic_name, num_partitions=None, replication_factor=None, timeout_ms=10000): + if num_partitions is None: + num_partitions = self.partitions + if replication_factor is None: + replication_factor = self.replicas + + # Try different methods to create a topic, from the fastest to the slowest + if self.auto_create_topic and num_partitions == self.partitions and replication_factor == self.replicas: + self._create_topic_via_metadata(topic_name, timeout_ms) + elif env_kafka_version() >= (0, 10, 1, 0): + try: + self._create_topic_via_admin_api(topic_name, num_partitions, replication_factor, timeout_ms) + except InvalidReplicationFactorError: + # wait and try again + # on travis the brokers sometimes take a while to find themselves + time.sleep(0.5) + self._create_topic_via_admin_api(topic_name, num_partitions, replication_factor, timeout_ms) + else: + self._create_topic_via_cli(topic_name, num_partitions, replication_factor) + + def _create_topic_via_metadata(self, topic_name, timeout_ms=10000): + self._send_request(MetadataRequest[0]([topic_name]), timeout_ms) + + def _create_topic_via_admin_api(self, topic_name, num_partitions, replication_factor, timeout_ms=10000): + request = CreateTopicsRequest[0]([(topic_name, num_partitions, + replication_factor, [], [])], timeout_ms) + response = self._send_request(request, timeout=timeout_ms) + for topic_result in response.topic_errors: + error_code = topic_result[1] + if error_code != 0: + raise errors.for_code(error_code) + + def _create_topic_via_cli(self, topic_name, num_partitions, replication_factor): + args = self.kafka_run_class_args('kafka.admin.TopicCommand', + '--zookeeper', '%s:%s/%s' % (self.zookeeper.host, + self.zookeeper.port, + self.zk_chroot), + '--create', + '--topic', topic_name, + '--partitions', self.partitions \ + if num_partitions is None else num_partitions, + '--replication-factor', self.replicas \ + if replication_factor is None \ + else replication_factor) + if env_kafka_version() >= (0, 10): + 
args.append('--if-not-exists') + env = self.kafka_run_class_env() + proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = proc.communicate() + if proc.returncode != 0: + if 'kafka.common.TopicExistsException' not in stdout: + self.out("Failed to create topic %s" % (topic_name,)) + self.out(stdout) + self.out(stderr) + raise RuntimeError("Failed to create topic %s" % (topic_name,)) + + def get_topic_names(self): + args = self.kafka_run_class_args('kafka.admin.TopicCommand', + '--zookeeper', '%s:%s/%s' % (self.zookeeper.host, + self.zookeeper.port, + self.zk_chroot), + '--list' + ) + env = self.kafka_run_class_env() + env.pop('KAFKA_LOG4J_OPTS') + proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = proc.communicate() + if proc.returncode != 0: + self.out("Failed to list topics!") + self.out(stdout) + self.out(stderr) + raise RuntimeError("Failed to list topics!") + return stdout.decode().splitlines(False) + + def create_topics(self, topic_names, num_partitions=None, replication_factor=None): + for topic_name in topic_names: + self._create_topic(topic_name, num_partitions, replication_factor) + + def _enrich_client_params(self, params, **defaults): + params = params.copy() + for key, value in defaults.items(): + params.setdefault(key, value) + params.setdefault('bootstrap_servers', self.bootstrap_server()) + if self.sasl_enabled: + params.setdefault('sasl_mechanism', self.sasl_mechanism) + params.setdefault('security_protocol', self.transport) + if self.sasl_mechanism in ('PLAIN', 'SCRAM-SHA-256', 'SCRAM-SHA-512'): + params.setdefault('sasl_plain_username', self.broker_user) + params.setdefault('sasl_plain_password', self.broker_password) + return params + + @staticmethod + def _create_many_clients(cnt, cls, *args, **params): + client_id = params['client_id'] + for _ in range(cnt): + params['client_id'] = '%s_%s' % (client_id, random_string(4)) + yield cls(*args, **params) + + def get_clients(self, cnt=1, **params): + params = self._enrich_client_params(params, client_id='client') + for client in self._create_many_clients(cnt, KafkaClient, **params): + yield client + + def get_admin_clients(self, cnt, **params): + params = self._enrich_client_params(params, client_id='admin_client') + for client in self._create_many_clients(cnt, KafkaAdminClient, **params): + yield client + + def get_consumers(self, cnt, topics, **params): + params = self._enrich_client_params( + params, client_id='consumer', heartbeat_interval_ms=500, auto_offset_reset='earliest' + ) + for client in self._create_many_clients(cnt, KafkaConsumer, *topics, **params): + yield client + + def get_producers(self, cnt, **params): + params = self._enrich_client_params(params, client_id='producer') + for client in self._create_many_clients(cnt, KafkaProducer, **params): + yield client diff --git a/record/test_default_records.py b/record/test_default_records.py new file mode 100644 index 00000000..c3a7b02c --- /dev/null +++ b/record/test_default_records.py @@ -0,0 +1,208 @@ +# -*- coding: utf-8 -*- +from __future__ import unicode_literals +import pytest +from mock import patch +import kafka.codec +from kafka.record.default_records import ( + DefaultRecordBatch, DefaultRecordBatchBuilder +) +from kafka.errors import UnsupportedCodecError + + +@pytest.mark.parametrize("compression_type", [ + DefaultRecordBatch.CODEC_NONE, + DefaultRecordBatch.CODEC_GZIP, + DefaultRecordBatch.CODEC_SNAPPY, + DefaultRecordBatch.CODEC_LZ4 +]) +def 
test_read_write_serde_v2(compression_type): + builder = DefaultRecordBatchBuilder( + magic=2, compression_type=compression_type, is_transactional=1, + producer_id=123456, producer_epoch=123, base_sequence=9999, + batch_size=999999) + headers = [("header1", b"aaa"), ("header2", b"bbb")] + for offset in range(10): + builder.append( + offset, timestamp=9999999, key=b"test", value=b"Super", + headers=headers) + buffer = builder.build() + reader = DefaultRecordBatch(bytes(buffer)) + msgs = list(reader) + + assert reader.is_transactional is True + assert reader.compression_type == compression_type + assert reader.magic == 2 + assert reader.timestamp_type == 0 + assert reader.base_offset == 0 + for offset, msg in enumerate(msgs): + assert msg.offset == offset + assert msg.timestamp == 9999999 + assert msg.key == b"test" + assert msg.value == b"Super" + assert msg.headers == headers + + +def test_written_bytes_equals_size_in_bytes_v2(): + key = b"test" + value = b"Super" + headers = [("header1", b"aaa"), ("header2", b"bbb"), ("xx", None)] + builder = DefaultRecordBatchBuilder( + magic=2, compression_type=0, is_transactional=0, + producer_id=-1, producer_epoch=-1, base_sequence=-1, + batch_size=999999) + + size_in_bytes = builder.size_in_bytes( + 0, timestamp=9999999, key=key, value=value, headers=headers) + + pos = builder.size() + meta = builder.append( + 0, timestamp=9999999, key=key, value=value, headers=headers) + + assert builder.size() - pos == size_in_bytes + assert meta.size == size_in_bytes + + +def test_estimate_size_in_bytes_bigger_than_batch_v2(): + key = b"Super Key" + value = b"1" * 100 + headers = [("header1", b"aaa"), ("header2", b"bbb")] + estimate_size = DefaultRecordBatchBuilder.estimate_size_in_bytes( + key, value, headers) + + builder = DefaultRecordBatchBuilder( + magic=2, compression_type=0, is_transactional=0, + producer_id=-1, producer_epoch=-1, base_sequence=-1, + batch_size=999999) + builder.append( + 0, timestamp=9999999, key=key, value=value, headers=headers) + buf = builder.build() + assert len(buf) <= estimate_size, \ + "Estimate should always be upper bound" + + +def test_default_batch_builder_validates_arguments(): + builder = DefaultRecordBatchBuilder( + magic=2, compression_type=0, is_transactional=0, + producer_id=-1, producer_epoch=-1, base_sequence=-1, + batch_size=999999) + + # Key should not be str + with pytest.raises(TypeError): + builder.append( + 0, timestamp=9999999, key="some string", value=None, headers=[]) + + # Value should not be str + with pytest.raises(TypeError): + builder.append( + 0, timestamp=9999999, key=None, value="some string", headers=[]) + + # Timestamp should be of proper type + with pytest.raises(TypeError): + builder.append( + 0, timestamp="1243812793", key=None, value=b"some string", + headers=[]) + + # Offset of invalid type + with pytest.raises(TypeError): + builder.append( + "0", timestamp=9999999, key=None, value=b"some string", headers=[]) + + # Ok to pass value as None + builder.append( + 0, timestamp=9999999, key=b"123", value=None, headers=[]) + + # Timestamp can be None + builder.append( + 1, timestamp=None, key=None, value=b"some string", headers=[]) + + # Ok to pass offsets in not incremental order. 
This should not happen thou + builder.append( + 5, timestamp=9999999, key=b"123", value=None, headers=[]) + + # Check record with headers + builder.append( + 6, timestamp=9999999, key=b"234", value=None, headers=[("hkey", b"hval")]) + + # in case error handling code fails to fix inner buffer in builder + assert len(builder.build()) == 124 + + +def test_default_correct_metadata_response(): + builder = DefaultRecordBatchBuilder( + magic=2, compression_type=0, is_transactional=0, + producer_id=-1, producer_epoch=-1, base_sequence=-1, + batch_size=1024 * 1024) + meta = builder.append( + 0, timestamp=9999999, key=b"test", value=b"Super", headers=[]) + + assert meta.offset == 0 + assert meta.timestamp == 9999999 + assert meta.crc is None + assert meta.size == 16 + assert repr(meta) == ( + "DefaultRecordMetadata(offset=0, size={}, timestamp={})" + .format(meta.size, meta.timestamp) + ) + + +def test_default_batch_size_limit(): + # First message can be added even if it's too big + builder = DefaultRecordBatchBuilder( + magic=2, compression_type=0, is_transactional=0, + producer_id=-1, producer_epoch=-1, base_sequence=-1, + batch_size=1024) + + meta = builder.append( + 0, timestamp=None, key=None, value=b"M" * 2000, headers=[]) + assert meta.size > 0 + assert meta.crc is None + assert meta.offset == 0 + assert meta.timestamp is not None + assert len(builder.build()) > 2000 + + builder = DefaultRecordBatchBuilder( + magic=2, compression_type=0, is_transactional=0, + producer_id=-1, producer_epoch=-1, base_sequence=-1, + batch_size=1024) + meta = builder.append( + 0, timestamp=None, key=None, value=b"M" * 700, headers=[]) + assert meta is not None + meta = builder.append( + 1, timestamp=None, key=None, value=b"M" * 700, headers=[]) + assert meta is None + meta = builder.append( + 2, timestamp=None, key=None, value=b"M" * 700, headers=[]) + assert meta is None + assert len(builder.build()) < 1000 + + +@pytest.mark.parametrize("compression_type,name,checker_name", [ + (DefaultRecordBatch.CODEC_GZIP, "gzip", "has_gzip"), + (DefaultRecordBatch.CODEC_SNAPPY, "snappy", "has_snappy"), + (DefaultRecordBatch.CODEC_LZ4, "lz4", "has_lz4") +]) +@pytest.mark.parametrize("magic", [0, 1]) +def test_unavailable_codec(magic, compression_type, name, checker_name): + builder = DefaultRecordBatchBuilder( + magic=2, compression_type=compression_type, is_transactional=0, + producer_id=-1, producer_epoch=-1, base_sequence=-1, + batch_size=1024) + builder.append(0, timestamp=None, key=None, value=b"M" * 2000, headers=[]) + correct_buffer = builder.build() + + with patch.object(kafka.codec, checker_name) as mocked: + mocked.return_value = False + # Check that builder raises error + builder = DefaultRecordBatchBuilder( + magic=2, compression_type=compression_type, is_transactional=0, + producer_id=-1, producer_epoch=-1, base_sequence=-1, + batch_size=1024) + error_msg = "Libraries for {} compression codec not found".format(name) + with pytest.raises(UnsupportedCodecError, match=error_msg): + builder.append(0, timestamp=None, key=None, value=b"M", headers=[]) + builder.build() + + # Check that reader raises same error + batch = DefaultRecordBatch(bytes(correct_buffer)) + with pytest.raises(UnsupportedCodecError, match=error_msg): + list(batch) diff --git a/record/test_legacy_records.py b/record/test_legacy_records.py new file mode 100644 index 00000000..43970f7c --- /dev/null +++ b/record/test_legacy_records.py @@ -0,0 +1,197 @@ +from __future__ import unicode_literals +import pytest +from mock import patch +from 
kafka.record.legacy_records import ( + LegacyRecordBatch, LegacyRecordBatchBuilder +) +import kafka.codec +from kafka.errors import UnsupportedCodecError + + +@pytest.mark.parametrize("magic", [0, 1]) +def test_read_write_serde_v0_v1_no_compression(magic): + builder = LegacyRecordBatchBuilder( + magic=magic, compression_type=0, batch_size=9999999) + builder.append( + 0, timestamp=9999999, key=b"test", value=b"Super") + buffer = builder.build() + + batch = LegacyRecordBatch(bytes(buffer), magic) + msgs = list(batch) + assert len(msgs) == 1 + msg = msgs[0] + + assert msg.offset == 0 + assert msg.timestamp == (9999999 if magic else None) + assert msg.timestamp_type == (0 if magic else None) + assert msg.key == b"test" + assert msg.value == b"Super" + assert msg.checksum == (-2095076219 if magic else 278251978) & 0xffffffff + + +@pytest.mark.parametrize("compression_type", [ + LegacyRecordBatch.CODEC_GZIP, + LegacyRecordBatch.CODEC_SNAPPY, + LegacyRecordBatch.CODEC_LZ4 +]) +@pytest.mark.parametrize("magic", [0, 1]) +def test_read_write_serde_v0_v1_with_compression(compression_type, magic): + builder = LegacyRecordBatchBuilder( + magic=magic, compression_type=compression_type, batch_size=9999999) + for offset in range(10): + builder.append( + offset, timestamp=9999999, key=b"test", value=b"Super") + buffer = builder.build() + + batch = LegacyRecordBatch(bytes(buffer), magic) + msgs = list(batch) + + for offset, msg in enumerate(msgs): + assert msg.offset == offset + assert msg.timestamp == (9999999 if magic else None) + assert msg.timestamp_type == (0 if magic else None) + assert msg.key == b"test" + assert msg.value == b"Super" + assert msg.checksum == (-2095076219 if magic else 278251978) & \ + 0xffffffff + + +@pytest.mark.parametrize("magic", [0, 1]) +def test_written_bytes_equals_size_in_bytes(magic): + key = b"test" + value = b"Super" + builder = LegacyRecordBatchBuilder( + magic=magic, compression_type=0, batch_size=9999999) + + size_in_bytes = builder.size_in_bytes( + 0, timestamp=9999999, key=key, value=value) + + pos = builder.size() + builder.append(0, timestamp=9999999, key=key, value=value) + + assert builder.size() - pos == size_in_bytes + + +@pytest.mark.parametrize("magic", [0, 1]) +def test_estimate_size_in_bytes_bigger_than_batch(magic): + key = b"Super Key" + value = b"1" * 100 + estimate_size = LegacyRecordBatchBuilder.estimate_size_in_bytes( + magic, compression_type=0, key=key, value=value) + + builder = LegacyRecordBatchBuilder( + magic=magic, compression_type=0, batch_size=9999999) + builder.append( + 0, timestamp=9999999, key=key, value=value) + buf = builder.build() + assert len(buf) <= estimate_size, \ + "Estimate should always be upper bound" + + +@pytest.mark.parametrize("magic", [0, 1]) +def test_legacy_batch_builder_validates_arguments(magic): + builder = LegacyRecordBatchBuilder( + magic=magic, compression_type=0, batch_size=1024 * 1024) + + # Key should not be str + with pytest.raises(TypeError): + builder.append( + 0, timestamp=9999999, key="some string", value=None) + + # Value should not be str + with pytest.raises(TypeError): + builder.append( + 0, timestamp=9999999, key=None, value="some string") + + # Timestamp should be of proper type + if magic != 0: + with pytest.raises(TypeError): + builder.append( + 0, timestamp="1243812793", key=None, value=b"some string") + + # Offset of invalid type + with pytest.raises(TypeError): + builder.append( + "0", timestamp=9999999, key=None, value=b"some string") + + # Ok to pass value as None + builder.append( + 0, 
timestamp=9999999, key=b"123", value=None)
+
+    # Timestamp can be None
+    builder.append(
+        1, timestamp=None, key=None, value=b"some string")
+
+    # Ok to pass offsets in not incremental order. This should not happen though
+    builder.append(
+        5, timestamp=9999999, key=b"123", value=None)
+
+    # in case error handling code fails to fix inner buffer in builder
+    assert len(builder.build()) == (119 if magic else 95)
+
+
+@pytest.mark.parametrize("magic", [0, 1])
+def test_legacy_correct_metadata_response(magic):
+    builder = LegacyRecordBatchBuilder(
+        magic=magic, compression_type=0, batch_size=1024 * 1024)
+    meta = builder.append(
+        0, timestamp=9999999, key=b"test", value=b"Super")
+
+    assert meta.offset == 0
+    assert meta.timestamp == (9999999 if magic else -1)
+    assert meta.crc == (-2095076219 if magic else 278251978) & 0xffffffff
+    assert repr(meta) == (
+        "LegacyRecordMetadata(offset=0, crc={!r}, size={}, "
+        "timestamp={})".format(meta.crc, meta.size, meta.timestamp)
+    )
+
+
+@pytest.mark.parametrize("magic", [0, 1])
+def test_legacy_batch_size_limit(magic):
+    # First message can be added even if it's too big
+    builder = LegacyRecordBatchBuilder(
+        magic=magic, compression_type=0, batch_size=1024)
+    meta = builder.append(0, timestamp=None, key=None, value=b"M" * 2000)
+    assert meta.size > 0
+    assert meta.crc is not None
+    assert meta.offset == 0
+    assert meta.timestamp is not None
+    assert len(builder.build()) > 2000
+
+    builder = LegacyRecordBatchBuilder(
+        magic=magic, compression_type=0, batch_size=1024)
+    meta = builder.append(0, timestamp=None, key=None, value=b"M" * 700)
+    assert meta is not None
+    meta = builder.append(1, timestamp=None, key=None, value=b"M" * 700)
+    assert meta is None
+    meta = builder.append(2, timestamp=None, key=None, value=b"M" * 700)
+    assert meta is None
+    assert len(builder.build()) < 1000
+
+
+@pytest.mark.parametrize("compression_type,name,checker_name", [
+    (LegacyRecordBatch.CODEC_GZIP, "gzip", "has_gzip"),
+    (LegacyRecordBatch.CODEC_SNAPPY, "snappy", "has_snappy"),
+    (LegacyRecordBatch.CODEC_LZ4, "lz4", "has_lz4")
+])
+@pytest.mark.parametrize("magic", [0, 1])
+def test_unavailable_codec(magic, compression_type, name, checker_name):
+    builder = LegacyRecordBatchBuilder(
+        magic=magic, compression_type=compression_type, batch_size=1024)
+    builder.append(0, timestamp=None, key=None, value=b"M")
+    correct_buffer = builder.build()
+
+    with patch.object(kafka.codec, checker_name) as mocked:
+        mocked.return_value = False
+        # Check that builder raises error
+        builder = LegacyRecordBatchBuilder(
+            magic=magic, compression_type=compression_type, batch_size=1024)
+        error_msg = "Libraries for {} compression codec not found".format(name)
+        with pytest.raises(UnsupportedCodecError, match=error_msg):
+            builder.append(0, timestamp=None, key=None, value=b"M")
+            builder.build()
+
+        # Check that reader raises same error
+        batch = LegacyRecordBatch(bytes(correct_buffer), magic)
+        with pytest.raises(UnsupportedCodecError, match=error_msg):
+            list(batch)
diff --git a/record/test_records.py b/record/test_records.py
new file mode 100644
index 00000000..9f72234a
--- /dev/null
+++ b/record/test_records.py
@@ -0,0 +1,232 @@
+# -*- coding: utf-8 -*-
+from __future__ import unicode_literals
+import pytest
+from kafka.record import MemoryRecords, MemoryRecordsBuilder
+from kafka.errors import CorruptRecordException
+
+# This is real live data from Kafka 11 broker
+record_batch_data_v2 = [
+    # First Batch value == "123"
+
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00;\x00\x00\x00\x01\x02\x03' + b'\x18\xa2p\x00\x00\x00\x00\x00\x00\x00\x00\x01]\xff{\x06<\x00\x00\x01]' + b'\xff{\x06<\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00' + b'\x00\x00\x01\x12\x00\x00\x00\x01\x06123\x00', + # Second Batch value = "" and value = "". 2 records + b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00@\x00\x00\x00\x02\x02\xc8' + b'\\\xbd#\x00\x00\x00\x00\x00\x01\x00\x00\x01]\xff|\xddl\x00\x00\x01]\xff' + b'|\xde\x14\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00' + b'\x00\x00\x02\x0c\x00\x00\x00\x01\x00\x00\x0e\x00\xd0\x02\x02\x01\x00' + b'\x00', + # Third batch value = "123" + b'\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00;\x00\x00\x00\x02\x02.\x0b' + b'\x85\xb7\x00\x00\x00\x00\x00\x00\x00\x00\x01]\xff|\xe7\x9d\x00\x00\x01]' + b'\xff|\xe7\x9d\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff' + b'\x00\x00\x00\x01\x12\x00\x00\x00\x01\x06123\x00' + # Fourth batch value = "hdr" with header hkey=hval + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00E\x00\x00\x00\x00\x02\\' + b'\xd8\xefR\x00\x00\x00\x00\x00\x00\x00\x00\x01e\x85\xb6\xf3\xc1\x00\x00' + b'\x01e\x85\xb6\xf3\xc1\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff' + b'\xff\xff\x00\x00\x00\x01&\x00\x00\x00\x01\x06hdr\x02\x08hkey\x08hval' +] + +record_batch_data_v1 = [ + # First Message value == "123" + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x19G\x86(\xc2\x01\x00\x00' + b'\x00\x01^\x18g\xab\xae\xff\xff\xff\xff\x00\x00\x00\x03123', + # Second Message value == "" + b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x16\xef\x98\xc9 \x01\x00' + b'\x00\x00\x01^\x18g\xaf\xc0\xff\xff\xff\xff\x00\x00\x00\x00', + # Third Message value == "" + b'\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x16_\xaf\xfb^\x01\x00\x00' + b'\x00\x01^\x18g\xb0r\xff\xff\xff\xff\x00\x00\x00\x00', + # Fourth Message value = "123" + b'\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x19\xa8\x12W \x01\x00\x00' + b'\x00\x01^\x18g\xb8\x03\xff\xff\xff\xff\x00\x00\x00\x03123' +] + +# This is real live data from Kafka 10 broker +record_batch_data_v0 = [ + # First Message value == "123" + b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11\xfe\xb0\x1d\xbf\x00' + b'\x00\xff\xff\xff\xff\x00\x00\x00\x03123', + # Second Message value == "" + b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x0eyWH\xe0\x00\x00\xff' + b'\xff\xff\xff\x00\x00\x00\x00', + # Third Message value == "" + b'\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x0eyWH\xe0\x00\x00\xff' + b'\xff\xff\xff\x00\x00\x00\x00', + # Fourth Message value = "123" + b'\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x11\xfe\xb0\x1d\xbf\x00' + b'\x00\xff\xff\xff\xff\x00\x00\x00\x03123' +] + + +def test_memory_records_v2(): + data_bytes = b"".join(record_batch_data_v2) + b"\x00" * 4 + records = MemoryRecords(data_bytes) + + assert records.size_in_bytes() == 303 + assert records.valid_bytes() == 299 + + assert records.has_next() is True + batch = records.next_batch() + recs = list(batch) + assert len(recs) == 1 + assert recs[0].value == b"123" + assert recs[0].key is None + assert recs[0].timestamp == 1503229838908 + assert recs[0].timestamp_type == 0 + assert recs[0].checksum is None + assert recs[0].headers == [] + + assert records.next_batch() is not None + assert records.next_batch() is not None + + batch = records.next_batch() + recs = list(batch) + assert len(recs) == 1 + assert recs[0].value == b"hdr" + assert recs[0].headers == [('hkey', b'hval')] + + assert records.has_next() is False + assert records.next_batch() is None + assert 
records.next_batch() is None + + +def test_memory_records_v1(): + data_bytes = b"".join(record_batch_data_v1) + b"\x00" * 4 + records = MemoryRecords(data_bytes) + + assert records.size_in_bytes() == 146 + assert records.valid_bytes() == 142 + + assert records.has_next() is True + batch = records.next_batch() + recs = list(batch) + assert len(recs) == 1 + assert recs[0].value == b"123" + assert recs[0].key is None + assert recs[0].timestamp == 1503648000942 + assert recs[0].timestamp_type == 0 + assert recs[0].checksum == 1199974594 & 0xffffffff + + assert records.next_batch() is not None + assert records.next_batch() is not None + assert records.next_batch() is not None + + assert records.has_next() is False + assert records.next_batch() is None + assert records.next_batch() is None + + +def test_memory_records_v0(): + data_bytes = b"".join(record_batch_data_v0) + records = MemoryRecords(data_bytes + b"\x00" * 4) + + assert records.size_in_bytes() == 114 + assert records.valid_bytes() == 110 + + records = MemoryRecords(data_bytes) + + assert records.has_next() is True + batch = records.next_batch() + recs = list(batch) + assert len(recs) == 1 + assert recs[0].value == b"123" + assert recs[0].key is None + assert recs[0].timestamp is None + assert recs[0].timestamp_type is None + assert recs[0].checksum == -22012481 & 0xffffffff + + assert records.next_batch() is not None + assert records.next_batch() is not None + assert records.next_batch() is not None + + assert records.has_next() is False + assert records.next_batch() is None + assert records.next_batch() is None + + +def test_memory_records_corrupt(): + records = MemoryRecords(b"") + assert records.size_in_bytes() == 0 + assert records.valid_bytes() == 0 + assert records.has_next() is False + + records = MemoryRecords(b"\x00\x00\x00") + assert records.size_in_bytes() == 3 + assert records.valid_bytes() == 0 + assert records.has_next() is False + + records = MemoryRecords( + b"\x00\x00\x00\x00\x00\x00\x00\x03" # Offset=3 + b"\x00\x00\x00\x03" # Length=3 + b"\xfe\xb0\x1d", # Some random bytes + ) + with pytest.raises(CorruptRecordException): + records.next_batch() + + +@pytest.mark.parametrize("compression_type", [0, 1, 2, 3]) +@pytest.mark.parametrize("magic", [0, 1, 2]) +def test_memory_records_builder(magic, compression_type): + builder = MemoryRecordsBuilder( + magic=magic, compression_type=compression_type, batch_size=1024 * 10) + base_size = builder.size_in_bytes() # V2 has a header before + + msg_sizes = [] + for offset in range(10): + metadata = builder.append( + timestamp=10000 + offset, key=b"test", value=b"Super") + msg_sizes.append(metadata.size) + assert metadata.offset == offset + if magic > 0: + assert metadata.timestamp == 10000 + offset + else: + assert metadata.timestamp == -1 + assert builder.next_offset() == offset + 1 + + # Error appends should not leave junk behind, like null bytes or something + with pytest.raises(TypeError): + builder.append( + timestamp=None, key="test", value="Super") # Not bytes, but str + + assert not builder.is_full() + size_before_close = builder.size_in_bytes() + assert size_before_close == sum(msg_sizes) + base_size + + # Size should remain the same after closing. 
No trailing bytes + builder.close() + assert builder.compression_rate() > 0 + expected_size = size_before_close * builder.compression_rate() + assert builder.is_full() + assert builder.size_in_bytes() == expected_size + buffer = builder.buffer() + assert len(buffer) == expected_size + + # We can close second time, as in retry + builder.close() + assert builder.size_in_bytes() == expected_size + assert builder.buffer() == buffer + + # Can't append after close + meta = builder.append(timestamp=None, key=b"test", value=b"Super") + assert meta is None + + +@pytest.mark.parametrize("compression_type", [0, 1, 2, 3]) +@pytest.mark.parametrize("magic", [0, 1, 2]) +def test_memory_records_builder_full(magic, compression_type): + builder = MemoryRecordsBuilder( + magic=magic, compression_type=compression_type, batch_size=1024 * 10) + + # 1 message should always be appended + metadata = builder.append( + key=None, timestamp=None, value=b"M" * 10240) + assert metadata is not None + assert builder.is_full() + + metadata = builder.append( + key=None, timestamp=None, value=b"M") + assert metadata is None + assert builder.next_offset() == 1 diff --git a/record/test_util.py b/record/test_util.py new file mode 100644 index 00000000..0b2782e7 --- /dev/null +++ b/record/test_util.py @@ -0,0 +1,96 @@ +import struct +import pytest +from kafka.record import util + + +varint_data = [ + (b"\x00", 0), + (b"\x01", -1), + (b"\x02", 1), + (b"\x7E", 63), + (b"\x7F", -64), + (b"\x80\x01", 64), + (b"\x81\x01", -65), + (b"\xFE\x7F", 8191), + (b"\xFF\x7F", -8192), + (b"\x80\x80\x01", 8192), + (b"\x81\x80\x01", -8193), + (b"\xFE\xFF\x7F", 1048575), + (b"\xFF\xFF\x7F", -1048576), + (b"\x80\x80\x80\x01", 1048576), + (b"\x81\x80\x80\x01", -1048577), + (b"\xFE\xFF\xFF\x7F", 134217727), + (b"\xFF\xFF\xFF\x7F", -134217728), + (b"\x80\x80\x80\x80\x01", 134217728), + (b"\x81\x80\x80\x80\x01", -134217729), + (b"\xFE\xFF\xFF\xFF\x7F", 17179869183), + (b"\xFF\xFF\xFF\xFF\x7F", -17179869184), + (b"\x80\x80\x80\x80\x80\x01", 17179869184), + (b"\x81\x80\x80\x80\x80\x01", -17179869185), + (b"\xFE\xFF\xFF\xFF\xFF\x7F", 2199023255551), + (b"\xFF\xFF\xFF\xFF\xFF\x7F", -2199023255552), + (b"\x80\x80\x80\x80\x80\x80\x01", 2199023255552), + (b"\x81\x80\x80\x80\x80\x80\x01", -2199023255553), + (b"\xFE\xFF\xFF\xFF\xFF\xFF\x7F", 281474976710655), + (b"\xFF\xFF\xFF\xFF\xFF\xFF\x7F", -281474976710656), + (b"\x80\x80\x80\x80\x80\x80\x80\x01", 281474976710656), + (b"\x81\x80\x80\x80\x80\x80\x80\x01", -281474976710657), + (b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\x7F", 36028797018963967), + (b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x7F", -36028797018963968), + (b"\x80\x80\x80\x80\x80\x80\x80\x80\x01", 36028797018963968), + (b"\x81\x80\x80\x80\x80\x80\x80\x80\x01", -36028797018963969), + (b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x7F", 4611686018427387903), + (b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x7F", -4611686018427387904), + (b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01", 4611686018427387904), + (b"\x81\x80\x80\x80\x80\x80\x80\x80\x80\x01", -4611686018427387905), +] + + +@pytest.mark.parametrize("encoded, decoded", varint_data) +def test_encode_varint(encoded, decoded): + res = bytearray() + util.encode_varint(decoded, res.append) + assert res == encoded + + +@pytest.mark.parametrize("encoded, decoded", varint_data) +def test_decode_varint(encoded, decoded): + # We add a bit of bytes around just to check position is calculated + # correctly + value, pos = util.decode_varint( + bytearray(b"\x01\xf0" + encoded + b"\xff\x01"), 2) + assert value == decoded + assert pos - 2 == 
len(encoded) + + +@pytest.mark.parametrize("encoded, decoded", varint_data) +def test_size_of_varint(encoded, decoded): + assert util.size_of_varint(decoded) == len(encoded) + + +@pytest.mark.parametrize("crc32_func", [util.crc32c_c, util.crc32c_py]) +def test_crc32c(crc32_func): + def make_crc(data): + crc = crc32_func(data) + return struct.pack(">I", crc) + assert make_crc(b"") == b"\x00\x00\x00\x00" + assert make_crc(b"a") == b"\xc1\xd0\x43\x30" + + # Took from librdkafka testcase + long_text = b"""\ + This software is provided 'as-is', without any express or implied + warranty. In no event will the author be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + 3. This notice may not be removed or altered from any source distribution.""" + assert make_crc(long_text) == b"\x7d\xcd\xe1\x13" diff --git a/service.py b/service.py new file mode 100644 index 00000000..045d780e --- /dev/null +++ b/service.py @@ -0,0 +1,133 @@ +from __future__ import absolute_import + +import logging +import os +import re +import select +import subprocess +import sys +import threading +import time + +__all__ = [ + 'ExternalService', + 'SpawnedService', +] + +log = logging.getLogger(__name__) + + +class ExternalService(object): + def __init__(self, host, port): + log.info("Using already running service at %s:%d", host, port) + self.host = host + self.port = port + + def open(self): + pass + + def close(self): + pass + + +class SpawnedService(threading.Thread): + def __init__(self, args=None, env=None): + super(SpawnedService, self).__init__() + + if args is None: + raise TypeError("args parameter is required") + self.args = args + self.env = env + self.captured_stdout = [] + self.captured_stderr = [] + + self.should_die = threading.Event() + self.child = None + self.alive = False + self.daemon = True + log.info("Created service for command:") + log.info(" "+' '.join(self.args)) + log.debug("With environment:") + for key, value in self.env.items(): + log.debug(" {key}={value}".format(key=key, value=value)) + + def _spawn(self): + if self.alive: return + if self.child and self.child.poll() is None: return + + self.child = subprocess.Popen( + self.args, + preexec_fn=os.setsid, # to avoid propagating signals + env=self.env, + bufsize=1, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + self.alive = self.child.poll() is None + + def _despawn(self): + if self.child.poll() is None: + self.child.terminate() + self.alive = False + for _ in range(50): + if self.child.poll() is not None: + self.child = None + break + time.sleep(0.1) + else: + self.child.kill() + + def run(self): + self._spawn() + while True: + try: + (rds, _, _) = select.select([self.child.stdout, self.child.stderr], [], [], 1) + except select.error as ex: + if ex.args[0] == 4: + continue + else: + raise + + if self.child.stdout in rds: + line = self.child.stdout.readline().decode('utf-8').rstrip() + if line: + self.captured_stdout.append(line) + + if self.child.stderr 
in rds: + line = self.child.stderr.readline().decode('utf-8').rstrip() + if line: + self.captured_stderr.append(line) + + if self.child.poll() is not None: + self.dump_logs() + break + + if self.should_die.is_set(): + self._despawn() + break + + def dump_logs(self): + sys.stderr.write('\n'.join(self.captured_stderr)) + sys.stdout.write('\n'.join(self.captured_stdout)) + + def wait_for(self, pattern, timeout=30): + start = time.time() + while True: + if not self.is_alive(): + raise RuntimeError("Child thread died already.") + + elapsed = time.time() - start + if elapsed >= timeout: + log.error("Waiting for %r timed out after %d seconds", pattern, timeout) + return False + + if re.search(pattern, '\n'.join(self.captured_stdout), re.IGNORECASE) is not None: + log.info("Found pattern %r in %d seconds via stdout", pattern, elapsed) + return True + if re.search(pattern, '\n'.join(self.captured_stderr), re.IGNORECASE) is not None: + log.info("Found pattern %r in %d seconds via stderr", pattern, elapsed) + return True + time.sleep(0.1) + + def stop(self): + self.should_die.set() + self.join() diff --git a/test_acl_comparisons.py b/test_acl_comparisons.py new file mode 100644 index 00000000..291bf0e2 --- /dev/null +++ b/test_acl_comparisons.py @@ -0,0 +1,92 @@ +from kafka.admin.acl_resource import ACL +from kafka.admin.acl_resource import ACLOperation +from kafka.admin.acl_resource import ACLPermissionType +from kafka.admin.acl_resource import ResourcePattern +from kafka.admin.acl_resource import ResourceType +from kafka.admin.acl_resource import ACLResourcePatternType + + +def test_different_acls_are_different(): + one = ACL( + principal='User:A', + host='*', + operation=ACLOperation.ALL, + permission_type=ACLPermissionType.ALLOW, + resource_pattern=ResourcePattern( + resource_type=ResourceType.TOPIC, + resource_name='some-topic', + pattern_type=ACLResourcePatternType.LITERAL + ) + ) + + two = ACL( + principal='User:B', # Different principal + host='*', + operation=ACLOperation.ALL, + permission_type=ACLPermissionType.ALLOW, + resource_pattern=ResourcePattern( + resource_type=ResourceType.TOPIC, + resource_name='some-topic', + pattern_type=ACLResourcePatternType.LITERAL + ) + ) + + assert one != two + assert hash(one) != hash(two) + +def test_different_acls_are_different_with_glob_topics(): + one = ACL( + principal='User:A', + host='*', + operation=ACLOperation.ALL, + permission_type=ACLPermissionType.ALLOW, + resource_pattern=ResourcePattern( + resource_type=ResourceType.TOPIC, + resource_name='*', + pattern_type=ACLResourcePatternType.LITERAL + ) + ) + + two = ACL( + principal='User:B', # Different principal + host='*', + operation=ACLOperation.ALL, + permission_type=ACLPermissionType.ALLOW, + resource_pattern=ResourcePattern( + resource_type=ResourceType.TOPIC, + resource_name='*', + pattern_type=ACLResourcePatternType.LITERAL + ) + ) + + assert one != two + assert hash(one) != hash(two) + +def test_same_acls_are_same(): + one = ACL( + principal='User:A', + host='*', + operation=ACLOperation.ALL, + permission_type=ACLPermissionType.ALLOW, + resource_pattern=ResourcePattern( + resource_type=ResourceType.TOPIC, + resource_name='some-topic', + pattern_type=ACLResourcePatternType.LITERAL + ) + ) + + two = ACL( + principal='User:A', + host='*', + operation=ACLOperation.ALL, + permission_type=ACLPermissionType.ALLOW, + resource_pattern=ResourcePattern( + resource_type=ResourceType.TOPIC, + resource_name='some-topic', + pattern_type=ACLResourcePatternType.LITERAL + ) + ) + + assert one == two + assert 
hash(one) == hash(two) + assert len(set((one, two))) == 1 diff --git a/test_admin.py b/test_admin.py new file mode 100644 index 00000000..279f85ab --- /dev/null +++ b/test_admin.py @@ -0,0 +1,78 @@ +import pytest + +import kafka.admin +from kafka.errors import IllegalArgumentError + + +def test_config_resource(): + with pytest.raises(KeyError): + bad_resource = kafka.admin.ConfigResource('something', 'foo') + good_resource = kafka.admin.ConfigResource('broker', 'bar') + assert good_resource.resource_type == kafka.admin.ConfigResourceType.BROKER + assert good_resource.name == 'bar' + assert good_resource.configs is None + good_resource = kafka.admin.ConfigResource(kafka.admin.ConfigResourceType.TOPIC, 'baz', {'frob': 'nob'}) + assert good_resource.resource_type == kafka.admin.ConfigResourceType.TOPIC + assert good_resource.name == 'baz' + assert good_resource.configs == {'frob': 'nob'} + + +def test_new_partitions(): + good_partitions = kafka.admin.NewPartitions(6) + assert good_partitions.total_count == 6 + assert good_partitions.new_assignments is None + good_partitions = kafka.admin.NewPartitions(7, [[1, 2, 3]]) + assert good_partitions.total_count == 7 + assert good_partitions.new_assignments == [[1, 2, 3]] + + +def test_acl_resource(): + good_acl = kafka.admin.ACL( + "User:bar", + "*", + kafka.admin.ACLOperation.ALL, + kafka.admin.ACLPermissionType.ALLOW, + kafka.admin.ResourcePattern( + kafka.admin.ResourceType.TOPIC, + "foo", + kafka.admin.ACLResourcePatternType.LITERAL + ) + ) + + assert(good_acl.resource_pattern.resource_type == kafka.admin.ResourceType.TOPIC) + assert(good_acl.operation == kafka.admin.ACLOperation.ALL) + assert(good_acl.permission_type == kafka.admin.ACLPermissionType.ALLOW) + assert(good_acl.resource_pattern.pattern_type == kafka.admin.ACLResourcePatternType.LITERAL) + + with pytest.raises(IllegalArgumentError): + kafka.admin.ACL( + "User:bar", + "*", + kafka.admin.ACLOperation.ANY, + kafka.admin.ACLPermissionType.ANY, + kafka.admin.ResourcePattern( + kafka.admin.ResourceType.TOPIC, + "foo", + kafka.admin.ACLResourcePatternType.LITERAL + ) + ) + +def test_new_topic(): + with pytest.raises(IllegalArgumentError): + bad_topic = kafka.admin.NewTopic('foo', -1, -1) + with pytest.raises(IllegalArgumentError): + bad_topic = kafka.admin.NewTopic('foo', 1, -1) + with pytest.raises(IllegalArgumentError): + bad_topic = kafka.admin.NewTopic('foo', 1, 1, {1: [1, 1, 1]}) + good_topic = kafka.admin.NewTopic('foo', 1, 2) + assert good_topic.name == 'foo' + assert good_topic.num_partitions == 1 + assert good_topic.replication_factor == 2 + assert good_topic.replica_assignments == {} + assert good_topic.topic_configs == {} + good_topic = kafka.admin.NewTopic('bar', -1, -1, {1: [1, 2, 3]}, {'key': 'value'}) + assert good_topic.name == 'bar' + assert good_topic.num_partitions == -1 + assert good_topic.replication_factor == -1 + assert good_topic.replica_assignments == {1: [1, 2, 3]} + assert good_topic.topic_configs == {'key': 'value'} diff --git a/test_admin_integration.py b/test_admin_integration.py new file mode 100644 index 00000000..06c40a22 --- /dev/null +++ b/test_admin_integration.py @@ -0,0 +1,314 @@ +import pytest + +from logging import info +from test.testutil import env_kafka_version, random_string +from threading import Event, Thread +from time import time, sleep + +from kafka.admin import ( + ACLFilter, ACLOperation, ACLPermissionType, ResourcePattern, ResourceType, ACL, ConfigResource, ConfigResourceType) +from kafka.errors import (NoError, 
GroupCoordinatorNotAvailableError, NonEmptyGroupError, GroupIdNotFoundError) + + +@pytest.mark.skipif(env_kafka_version() < (0, 11), reason="ACL features require broker >=0.11") +def test_create_describe_delete_acls(kafka_admin_client): + """Tests that we can add, list and remove ACLs + """ + + # Check that we don't have any ACLs in the cluster + acls, error = kafka_admin_client.describe_acls( + ACLFilter( + principal=None, + host="*", + operation=ACLOperation.ANY, + permission_type=ACLPermissionType.ANY, + resource_pattern=ResourcePattern(ResourceType.TOPIC, "topic") + ) + ) + + assert error is NoError + assert len(acls) == 0 + + # Try to add an ACL + acl = ACL( + principal="User:test", + host="*", + operation=ACLOperation.READ, + permission_type=ACLPermissionType.ALLOW, + resource_pattern=ResourcePattern(ResourceType.TOPIC, "topic") + ) + result = kafka_admin_client.create_acls([acl]) + + assert len(result["failed"]) == 0 + assert len(result["succeeded"]) == 1 + + # Check that we can list the ACL we created + acl_filter = ACLFilter( + principal=None, + host="*", + operation=ACLOperation.ANY, + permission_type=ACLPermissionType.ANY, + resource_pattern=ResourcePattern(ResourceType.TOPIC, "topic") + ) + acls, error = kafka_admin_client.describe_acls(acl_filter) + + assert error is NoError + assert len(acls) == 1 + + # Remove the ACL + delete_results = kafka_admin_client.delete_acls( + [ + ACLFilter( + principal="User:test", + host="*", + operation=ACLOperation.READ, + permission_type=ACLPermissionType.ALLOW, + resource_pattern=ResourcePattern(ResourceType.TOPIC, "topic") + ) + ] + ) + + assert len(delete_results) == 1 + assert len(delete_results[0][1]) == 1 # Check number of affected ACLs + + # Make sure the ACL does not exist in the cluster anymore + acls, error = kafka_admin_client.describe_acls( + ACLFilter( + principal="*", + host="*", + operation=ACLOperation.ANY, + permission_type=ACLPermissionType.ANY, + resource_pattern=ResourcePattern(ResourceType.TOPIC, "topic") + ) + ) + + assert error is NoError + assert len(acls) == 0 + + +@pytest.mark.skipif(env_kafka_version() < (0, 11), reason="Describe config features require broker >=0.11") +def test_describe_configs_broker_resource_returns_configs(kafka_admin_client): + """Tests that describe config returns configs for broker + """ + broker_id = kafka_admin_client._client.cluster._brokers[0].nodeId + configs = kafka_admin_client.describe_configs([ConfigResource(ConfigResourceType.BROKER, broker_id)]) + + assert len(configs) == 1 + assert configs[0].resources[0][2] == ConfigResourceType.BROKER + assert configs[0].resources[0][3] == str(broker_id) + assert len(configs[0].resources[0][4]) > 1 + + +@pytest.mark.xfail(condition=True, + reason="https://github.com/dpkp/kafka-python/issues/1929", + raises=AssertionError) +@pytest.mark.skipif(env_kafka_version() < (0, 11), reason="Describe config features require broker >=0.11") +def test_describe_configs_topic_resource_returns_configs(topic, kafka_admin_client): + """Tests that describe config returns configs for topic + """ + configs = kafka_admin_client.describe_configs([ConfigResource(ConfigResourceType.TOPIC, topic)]) + + assert len(configs) == 1 + assert configs[0].resources[0][2] == ConfigResourceType.TOPIC + assert configs[0].resources[0][3] == topic + assert len(configs[0].resources[0][4]) > 1 + + +@pytest.mark.skipif(env_kafka_version() < (0, 11), reason="Describe config features require broker >=0.11") +def test_describe_configs_mixed_resources_returns_configs(topic, 
kafka_admin_client): + """Tests that describe config returns configs for mixed resource types (topic + broker) + """ + broker_id = kafka_admin_client._client.cluster._brokers[0].nodeId + configs = kafka_admin_client.describe_configs([ + ConfigResource(ConfigResourceType.TOPIC, topic), + ConfigResource(ConfigResourceType.BROKER, broker_id)]) + + assert len(configs) == 2 + + for config in configs: + assert (config.resources[0][2] == ConfigResourceType.TOPIC + and config.resources[0][3] == topic) or \ + (config.resources[0][2] == ConfigResourceType.BROKER + and config.resources[0][3] == str(broker_id)) + assert len(config.resources[0][4]) > 1 + + +@pytest.mark.skipif(env_kafka_version() < (0, 11), reason="Describe config features require broker >=0.11") +def test_describe_configs_invalid_broker_id_raises(kafka_admin_client): + """Tests that describe config raises exception on non-integer broker id + """ + broker_id = "str" + + with pytest.raises(ValueError): + configs = kafka_admin_client.describe_configs([ConfigResource(ConfigResourceType.BROKER, broker_id)]) + + +@pytest.mark.skipif(env_kafka_version() < (0, 11), reason='Describe consumer group requires broker >=0.11') +def test_describe_consumer_group_does_not_exist(kafka_admin_client): + """Tests that the describe consumer group call fails if the group coordinator is not available + """ + with pytest.raises(GroupCoordinatorNotAvailableError): + group_description = kafka_admin_client.describe_consumer_groups(['test']) + + +@pytest.mark.skipif(env_kafka_version() < (0, 11), reason='Describe consumer group requires broker >=0.11') +def test_describe_consumer_group_exists(kafka_admin_client, kafka_consumer_factory, topic): + """Tests that the describe consumer group call returns valid consumer group information + This test takes inspiration from the test 'test_group' in test_consumer_group.py. + """ + consumers = {} + stop = {} + threads = {} + random_group_id = 'test-group-' + random_string(6) + group_id_list = [random_group_id, random_group_id + '_2'] + generations = {group_id_list[0]: set(), group_id_list[1]: set()} + def consumer_thread(i, group_id): + assert i not in consumers + assert i not in stop + stop[i] = Event() + consumers[i] = kafka_consumer_factory(group_id=group_id) + while not stop[i].is_set(): + consumers[i].poll(20) + consumers[i].close() + consumers[i] = None + stop[i] = None + + num_consumers = 3 + for i in range(num_consumers): + group_id = group_id_list[i % 2] + t = Thread(target=consumer_thread, args=(i, group_id,)) + t.start() + threads[i] = t + + try: + timeout = time() + 35 + while True: + for c in range(num_consumers): + + # Verify all consumers have been created + if c not in consumers: + break + + # Verify all consumers have an assignment + elif not consumers[c].assignment(): + break + + # If all consumers exist and have an assignment + else: + + info('All consumers have assignment... 
checking for stable group') + # Verify all consumers are in the same generation + # then log state and break while loop + + for consumer in consumers.values(): + generations[consumer.config['group_id']].add(consumer._coordinator._generation.generation_id) + + is_same_generation = any([len(consumer_generation) == 1 for consumer_generation in generations.values()]) + + # New generation assignment is not complete until + # coordinator.rejoining = False + rejoining = any([consumer._coordinator.rejoining + for consumer in list(consumers.values())]) + + if not rejoining and is_same_generation: + break + else: + sleep(1) + assert time() < timeout, "timeout waiting for assignments" + + info('Group stabilized; verifying assignment') + output = kafka_admin_client.describe_consumer_groups(group_id_list) + assert len(output) == 2 + consumer_groups = set() + for consumer_group in output: + assert(consumer_group.group in group_id_list) + if consumer_group.group == group_id_list[0]: + assert(len(consumer_group.members) == 2) + else: + assert(len(consumer_group.members) == 1) + for member in consumer_group.members: + assert(member.member_metadata.subscription[0] == topic) + assert(member.member_assignment.assignment[0][0] == topic) + consumer_groups.add(consumer_group.group) + assert(sorted(list(consumer_groups)) == group_id_list) + finally: + info('Shutting down %s consumers', num_consumers) + for c in range(num_consumers): + info('Stopping consumer %s', c) + stop[c].set() + threads[c].join() + threads[c] = None + + +@pytest.mark.skipif(env_kafka_version() < (1, 1), reason="Delete consumer groups requires broker >=1.1") +def test_delete_consumergroups(kafka_admin_client, kafka_consumer_factory, send_messages): + random_group_id = 'test-group-' + random_string(6) + group1 = random_group_id + "_1" + group2 = random_group_id + "_2" + group3 = random_group_id + "_3" + + send_messages(range(0, 100), partition=0) + consumer1 = kafka_consumer_factory(group_id=group1) + next(consumer1) + consumer1.close() + + consumer2 = kafka_consumer_factory(group_id=group2) + next(consumer2) + consumer2.close() + + consumer3 = kafka_consumer_factory(group_id=group3) + next(consumer3) + consumer3.close() + + consumergroups = {group_id for group_id, _ in kafka_admin_client.list_consumer_groups()} + assert group1 in consumergroups + assert group2 in consumergroups + assert group3 in consumergroups + + delete_results = { + group_id: error + for group_id, error in kafka_admin_client.delete_consumer_groups([group1, group2]) + } + assert delete_results[group1] == NoError + assert delete_results[group2] == NoError + assert group3 not in delete_results + + consumergroups = {group_id for group_id, _ in kafka_admin_client.list_consumer_groups()} + assert group1 not in consumergroups + assert group2 not in consumergroups + assert group3 in consumergroups + + +@pytest.mark.skipif(env_kafka_version() < (1, 1), reason="Delete consumer groups requires broker >=1.1") +def test_delete_consumergroups_with_errors(kafka_admin_client, kafka_consumer_factory, send_messages): + random_group_id = 'test-group-' + random_string(6) + group1 = random_group_id + "_1" + group2 = random_group_id + "_2" + group3 = random_group_id + "_3" + + send_messages(range(0, 100), partition=0) + consumer1 = kafka_consumer_factory(group_id=group1) + next(consumer1) + consumer1.close() + + consumer2 = kafka_consumer_factory(group_id=group2) + next(consumer2) + + consumergroups = {group_id for group_id, _ in kafka_admin_client.list_consumer_groups()} + assert group1 in 
consumergroups + assert group2 in consumergroups + assert group3 not in consumergroups + + delete_results = { + group_id: error + for group_id, error in kafka_admin_client.delete_consumer_groups([group1, group2, group3]) + } + + assert delete_results[group1] == NoError + assert delete_results[group2] == NonEmptyGroupError + assert delete_results[group3] == GroupIdNotFoundError + + consumergroups = {group_id for group_id, _ in kafka_admin_client.list_consumer_groups()} + assert group1 not in consumergroups + assert group2 in consumergroups + assert group3 not in consumergroups diff --git a/test_api_object_implementation.py b/test_api_object_implementation.py new file mode 100644 index 00000000..da80f148 --- /dev/null +++ b/test_api_object_implementation.py @@ -0,0 +1,18 @@ +import abc +import pytest + +from kafka.protocol.api import Request +from kafka.protocol.api import Response + + +attr_names = [n for n in dir(Request) if isinstance(getattr(Request, n), abc.abstractproperty)] +@pytest.mark.parametrize('klass', Request.__subclasses__()) +@pytest.mark.parametrize('attr_name', attr_names) +def test_request_type_conformance(klass, attr_name): + assert hasattr(klass, attr_name) + +attr_names = [n for n in dir(Response) if isinstance(getattr(Response, n), abc.abstractproperty)] +@pytest.mark.parametrize('klass', Response.__subclasses__()) +@pytest.mark.parametrize('attr_name', attr_names) +def test_response_type_conformance(klass, attr_name): + assert hasattr(klass, attr_name) diff --git a/test_assignors.py b/test_assignors.py new file mode 100644 index 00000000..858ef426 --- /dev/null +++ b/test_assignors.py @@ -0,0 +1,871 @@ +# pylint: skip-file +from __future__ import absolute_import + +from collections import defaultdict +from random import randint, sample + +import pytest + +from kafka.structs import TopicPartition +from kafka.coordinator.assignors.range import RangePartitionAssignor +from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor +from kafka.coordinator.assignors.sticky.sticky_assignor import StickyPartitionAssignor, StickyAssignorUserDataV1 +from kafka.coordinator.protocol import ConsumerProtocolMemberAssignment, ConsumerProtocolMemberMetadata +from kafka.vendor import six + + +@pytest.fixture(autouse=True) +def reset_sticky_assignor(): + yield + StickyPartitionAssignor.member_assignment = None + StickyPartitionAssignor.generation = -1 + + +def create_cluster(mocker, topics, topics_partitions=None, topic_partitions_lambda=None): + cluster = mocker.MagicMock() + cluster.topics.return_value = topics + if topics_partitions is not None: + cluster.partitions_for_topic.return_value = topics_partitions + if topic_partitions_lambda is not None: + cluster.partitions_for_topic.side_effect = topic_partitions_lambda + return cluster + + +def test_assignor_roundrobin(mocker): + assignor = RoundRobinPartitionAssignor + + member_metadata = { + 'C0': assignor.metadata({'t0', 't1'}), + 'C1': assignor.metadata({'t0', 't1'}), + } + + cluster = create_cluster(mocker, {'t0', 't1'}, topics_partitions={0, 1, 2}) + ret = assignor.assign(cluster, member_metadata) + expected = { + 'C0': ConsumerProtocolMemberAssignment( + assignor.version, [('t0', [0, 2]), ('t1', [1])], b''), + 'C1': ConsumerProtocolMemberAssignment( + assignor.version, [('t0', [1]), ('t1', [0, 2])], b'') + } + assert ret == expected + assert set(ret) == set(expected) + for member in ret: + assert ret[member].encode() == expected[member].encode() + + +def test_assignor_range(mocker): + assignor = 
RangePartitionAssignor
+
+    member_metadata = {
+        'C0': assignor.metadata({'t0', 't1'}),
+        'C1': assignor.metadata({'t0', 't1'}),
+    }
+
+    cluster = create_cluster(mocker, {'t0', 't1'}, topics_partitions={0, 1, 2})
+    ret = assignor.assign(cluster, member_metadata)
+    expected = {
+        'C0': ConsumerProtocolMemberAssignment(
+            assignor.version, [('t0', [0, 1]), ('t1', [0, 1])], b''),
+        'C1': ConsumerProtocolMemberAssignment(
+            assignor.version, [('t0', [2]), ('t1', [2])], b'')
+    }
+    assert ret == expected
+    assert set(ret) == set(expected)
+    for member in ret:
+        assert ret[member].encode() == expected[member].encode()
+
+
+def test_sticky_assignor1(mocker):
+    """
+    Given: there are three consumers C0, C1, C2,
+        four topics t0, t1, t2, t3, and each topic has 2 partitions,
+        resulting in partitions t0p0, t0p1, t1p0, t1p1, t2p0, t2p1, t3p0, t3p1.
+        Each consumer is subscribed to all four topics.
+    Then: perform fresh assignment
+    Expected: the assignment is
+        - C0: [t0p0, t1p1, t3p0]
+        - C1: [t0p1, t2p0, t3p1]
+        - C2: [t1p0, t2p1]
+    Then: remove C1 consumer and perform the reassignment
+    Expected: the new assignment is
+        - C0 [t0p0, t1p1, t2p0, t3p0]
+        - C2 [t0p1, t1p0, t2p1, t3p1]
+    """
+    cluster = create_cluster(mocker, topics={'t0', 't1', 't2', 't3'}, topics_partitions={0, 1})
+
+    subscriptions = {
+        'C0': {'t0', 't1', 't2', 't3'},
+        'C1': {'t0', 't1', 't2', 't3'},
+        'C2': {'t0', 't1', 't2', 't3'},
+    }
+    member_metadata = make_member_metadata(subscriptions)
+
+    sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata)
+    expected_assignment = {
+        'C0': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t0', [0]), ('t1', [1]), ('t3', [0])], b''),
+        'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t0', [1]), ('t2', [0]), ('t3', [1])], b''),
+        'C2': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [0]), ('t2', [1])], b''),
+    }
+    assert_assignment(sticky_assignment, expected_assignment)
+
+    del subscriptions['C1']
+    member_metadata = {}
+    for member, topics in six.iteritems(subscriptions):
+        member_metadata[member] = StickyPartitionAssignor._metadata(topics, sticky_assignment[member].partitions())
+
+    sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata)
+    expected_assignment = {
+        'C0': ConsumerProtocolMemberAssignment(
+            StickyPartitionAssignor.version, [('t0', [0]), ('t1', [1]), ('t2', [0]), ('t3', [0])], b''
+        ),
+        'C2': ConsumerProtocolMemberAssignment(
+            StickyPartitionAssignor.version, [('t0', [1]), ('t1', [0]), ('t2', [1]), ('t3', [1])], b''
+        ),
+    }
+    assert_assignment(sticky_assignment, expected_assignment)
+
+
+def test_sticky_assignor2(mocker):
+    """
+    Given: there are three consumers C0, C1, C2,
+        and three topics t0, t1, t2, with 1, 2, and 3 partitions respectively.
+        Therefore, the partitions are t0p0, t1p0, t1p1, t2p0, t2p1, t2p2.
+        C0 is subscribed to t0;
+        C1 is subscribed to t0, t1;
+        and C2 is subscribed to t0, t1, t2.
+ Then: perform the assignment + Expected: the assignment is + - C0 [t0p0] + - C1 [t1p0, t1p1] + - C2 [t2p0, t2p1, t2p2] + Then: remove C0 and perform the assignment + Expected: the assignment is + - C1 [t0p0, t1p0, t1p1] + - C2 [t2p0, t2p1, t2p2] + """ + + partitions = {'t0': {0}, 't1': {0, 1}, 't2': {0, 1, 2}} + cluster = create_cluster(mocker, topics={'t0', 't1', 't2'}, topic_partitions_lambda=lambda t: partitions[t]) + + subscriptions = { + 'C0': {'t0'}, + 'C1': {'t0', 't1'}, + 'C2': {'t0', 't1', 't2'}, + } + member_metadata = {} + for member, topics in six.iteritems(subscriptions): + member_metadata[member] = StickyPartitionAssignor._metadata(topics, []) + + sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + expected_assignment = { + 'C0': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t0', [0])], b''), + 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [0, 1])], b''), + 'C2': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t2', [0, 1, 2])], b''), + } + assert_assignment(sticky_assignment, expected_assignment) + + del subscriptions['C0'] + member_metadata = {} + for member, topics in six.iteritems(subscriptions): + member_metadata[member] = StickyPartitionAssignor._metadata(topics, sticky_assignment[member].partitions()) + + sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + expected_assignment = { + 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t0', [0]), ('t1', [0, 1])], b''), + 'C2': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t2', [0, 1, 2])], b''), + } + assert_assignment(sticky_assignment, expected_assignment) + + +def test_sticky_one_consumer_no_topic(mocker): + cluster = create_cluster(mocker, topics={}, topics_partitions={}) + + subscriptions = { + 'C': set(), + } + member_metadata = make_member_metadata(subscriptions) + + sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + expected_assignment = { + 'C': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [], b''), + } + assert_assignment(sticky_assignment, expected_assignment) + + +def test_sticky_one_consumer_nonexisting_topic(mocker): + cluster = create_cluster(mocker, topics={}, topics_partitions={}) + + subscriptions = { + 'C': {'t'}, + } + member_metadata = make_member_metadata(subscriptions) + + sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + expected_assignment = { + 'C': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [], b''), + } + assert_assignment(sticky_assignment, expected_assignment) + + +def test_sticky_one_consumer_one_topic(mocker): + cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1, 2}) + + subscriptions = { + 'C': {'t'}, + } + member_metadata = make_member_metadata(subscriptions) + + sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + expected_assignment = { + 'C': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t', [0, 1, 2])], b''), + } + assert_assignment(sticky_assignment, expected_assignment) + + +def test_sticky_should_only_assign_partitions_from_subscribed_topics(mocker): + cluster = create_cluster(mocker, topics={'t', 'other-t'}, topics_partitions={0, 1, 2}) + + subscriptions = { + 'C': {'t'}, + } + member_metadata = make_member_metadata(subscriptions) + + sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + expected_assignment = { + 
'C': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t', [0, 1, 2])], b''), + } + assert_assignment(sticky_assignment, expected_assignment) + + +def test_sticky_one_consumer_multiple_topics(mocker): + cluster = create_cluster(mocker, topics={'t1', 't2'}, topics_partitions={0, 1, 2}) + + subscriptions = { + 'C': {'t1', 't2'}, + } + member_metadata = make_member_metadata(subscriptions) + + sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + expected_assignment = { + 'C': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [0, 1, 2]), ('t2', [0, 1, 2])], b''), + } + assert_assignment(sticky_assignment, expected_assignment) + + +def test_sticky_two_consumers_one_topic_one_partition(mocker): + cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0}) + + subscriptions = { + 'C1': {'t'}, + 'C2': {'t'}, + } + member_metadata = make_member_metadata(subscriptions) + + sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + expected_assignment = { + 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t', [0])], b''), + 'C2': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [], b''), + } + assert_assignment(sticky_assignment, expected_assignment) + + +def test_sticky_two_consumers_one_topic_two_partitions(mocker): + cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1}) + + subscriptions = { + 'C1': {'t'}, + 'C2': {'t'}, + } + member_metadata = make_member_metadata(subscriptions) + + sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + expected_assignment = { + 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t', [0])], b''), + 'C2': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t', [1])], b''), + } + assert_assignment(sticky_assignment, expected_assignment) + + +def test_sticky_multiple_consumers_mixed_topic_subscriptions(mocker): + partitions = {'t1': {0, 1, 2}, 't2': {0, 1}} + cluster = create_cluster(mocker, topics={'t1', 't2'}, topic_partitions_lambda=lambda t: partitions[t]) + + subscriptions = { + 'C1': {'t1'}, + 'C2': {'t1', 't2'}, + 'C3': {'t1'}, + } + member_metadata = make_member_metadata(subscriptions) + + sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + expected_assignment = { + 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [0, 2])], b''), + 'C2': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t2', [0, 1])], b''), + 'C3': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [1])], b''), + } + assert_assignment(sticky_assignment, expected_assignment) + + +def test_sticky_add_remove_consumer_one_topic(mocker): + cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1, 2}) + + subscriptions = { + 'C1': {'t'}, + } + member_metadata = make_member_metadata(subscriptions) + + assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + expected_assignment = { + 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t', [0, 1, 2])], b''), + } + assert_assignment(assignment, expected_assignment) + + subscriptions = { + 'C1': {'t'}, + 'C2': {'t'}, + } + member_metadata = {} + for member, topics in six.iteritems(subscriptions): + member_metadata[member] = StickyPartitionAssignor._metadata( + topics, assignment[member].partitions() if member in assignment else [] + ) + + assignment = 
StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance(subscriptions, assignment) + + subscriptions = { + 'C2': {'t'}, + } + member_metadata = {} + for member, topics in six.iteritems(subscriptions): + member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions()) + + assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance(subscriptions, assignment) + assert len(assignment['C2'].assignment[0][1]) == 3 + + +def test_sticky_add_remove_topic_two_consumers(mocker): + cluster = create_cluster(mocker, topics={'t1', 't2'}, topics_partitions={0, 1, 2}) + + subscriptions = { + 'C1': {'t1'}, + 'C2': {'t1'}, + } + member_metadata = make_member_metadata(subscriptions) + + sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + expected_assignment = { + 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [0, 2])], b''), + 'C2': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [1])], b''), + } + assert_assignment(sticky_assignment, expected_assignment) + + subscriptions = { + 'C1': {'t1', 't2'}, + 'C2': {'t1', 't2'}, + } + member_metadata = {} + for member, topics in six.iteritems(subscriptions): + member_metadata[member] = StickyPartitionAssignor._metadata(topics, sticky_assignment[member].partitions()) + + sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + expected_assignment = { + 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [0, 2]), ('t2', [1])], b''), + 'C2': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [1]), ('t2', [0, 2])], b''), + } + assert_assignment(sticky_assignment, expected_assignment) + + subscriptions = { + 'C1': {'t2'}, + 'C2': {'t2'}, + } + member_metadata = {} + for member, topics in six.iteritems(subscriptions): + member_metadata[member] = StickyPartitionAssignor._metadata(topics, sticky_assignment[member].partitions()) + + sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + expected_assignment = { + 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t2', [1])], b''), + 'C2': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t2', [0, 2])], b''), + } + assert_assignment(sticky_assignment, expected_assignment) + + +def test_sticky_reassignment_after_one_consumer_leaves(mocker): + partitions = dict([('t{}'.format(i), set(range(i))) for i in range(1, 20)]) + cluster = create_cluster( + mocker, topics=set(['t{}'.format(i) for i in range(1, 20)]), topic_partitions_lambda=lambda t: partitions[t] + ) + + subscriptions = {} + for i in range(1, 20): + topics = set() + for j in range(1, i + 1): + topics.add('t{}'.format(j)) + subscriptions['C{}'.format(i)] = topics + + member_metadata = make_member_metadata(subscriptions) + + assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance(subscriptions, assignment) + + del subscriptions['C10'] + member_metadata = {} + for member, topics in six.iteritems(subscriptions): + member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions()) + + assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance(subscriptions, assignment) + assert StickyPartitionAssignor._latest_partition_movements.are_sticky() + + +def test_sticky_reassignment_after_one_consumer_added(mocker): + cluster = 
create_cluster(mocker, topics={'t'}, topics_partitions=set(range(20))) + + subscriptions = defaultdict(set) + for i in range(1, 10): + subscriptions['C{}'.format(i)] = {'t'} + + member_metadata = make_member_metadata(subscriptions) + + assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance(subscriptions, assignment) + + subscriptions['C10'] = {'t'} + member_metadata = {} + for member, topics in six.iteritems(subscriptions): + member_metadata[member] = StickyPartitionAssignor._metadata( + topics, assignment[member].partitions() if member in assignment else [] + ) + assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance(subscriptions, assignment) + assert StickyPartitionAssignor._latest_partition_movements.are_sticky() + + +def test_sticky_same_subscriptions(mocker): + partitions = dict([('t{}'.format(i), set(range(i))) for i in range(1, 15)]) + cluster = create_cluster( + mocker, topics=set(['t{}'.format(i) for i in range(1, 15)]), topic_partitions_lambda=lambda t: partitions[t] + ) + + subscriptions = defaultdict(set) + for i in range(1, 9): + for j in range(1, len(six.viewkeys(partitions)) + 1): + subscriptions['C{}'.format(i)].add('t{}'.format(j)) + + member_metadata = make_member_metadata(subscriptions) + + assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance(subscriptions, assignment) + + del subscriptions['C5'] + member_metadata = {} + for member, topics in six.iteritems(subscriptions): + member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions()) + assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance(subscriptions, assignment) + assert StickyPartitionAssignor._latest_partition_movements.are_sticky() + + +def test_sticky_large_assignment_with_multiple_consumers_leaving(mocker): + n_topics = 40 + n_consumers = 200 + + all_topics = set(['t{}'.format(i) for i in range(1, n_topics + 1)]) + partitions = dict([(t, set(range(1, randint(0, 10) + 1))) for t in all_topics]) + cluster = create_cluster(mocker, topics=all_topics, topic_partitions_lambda=lambda t: partitions[t]) + + subscriptions = defaultdict(set) + for i in range(1, n_consumers + 1): + for j in range(0, randint(1, 20)): + subscriptions['C{}'.format(i)].add('t{}'.format(randint(1, n_topics))) + + member_metadata = make_member_metadata(subscriptions) + + assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance(subscriptions, assignment) + + member_metadata = {} + for member, topics in six.iteritems(subscriptions): + member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions()) + + for i in range(50): + member = 'C{}'.format(randint(1, n_consumers)) + if member in subscriptions: + del subscriptions[member] + del member_metadata[member] + + assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance(subscriptions, assignment) + assert StickyPartitionAssignor._latest_partition_movements.are_sticky() + + +def test_new_subscription(mocker): + cluster = create_cluster(mocker, topics={'t1', 't2', 't3', 't4'}, topics_partitions={0}) + + subscriptions = defaultdict(set) + for i in range(3): + for j in range(i, 3 * i - 2 + 1): + subscriptions['C{}'.format(i)].add('t{}'.format(j)) + + member_metadata = make_member_metadata(subscriptions) + + assignment = StickyPartitionAssignor.assign(cluster, 
member_metadata) + verify_validity_and_balance(subscriptions, assignment) + + subscriptions['C0'].add('t1') + member_metadata = {} + for member, topics in six.iteritems(subscriptions): + member_metadata[member] = StickyPartitionAssignor._metadata(topics, []) + + assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance(subscriptions, assignment) + assert StickyPartitionAssignor._latest_partition_movements.are_sticky() + + +def test_move_existing_assignments(mocker): + cluster = create_cluster(mocker, topics={'t1', 't2', 't3', 't4', 't5', 't6'}, topics_partitions={0}) + + subscriptions = { + 'C1': {'t1', 't2'}, + 'C2': {'t1', 't2', 't3', 't4'}, + 'C3': {'t2', 't3', 't4', 't5', 't6'}, + } + member_assignments = { + 'C1': [TopicPartition('t1', 0)], + 'C2': [TopicPartition('t2', 0), TopicPartition('t3', 0)], + 'C3': [TopicPartition('t4', 0), TopicPartition('t5', 0), TopicPartition('t6', 0)], + } + + member_metadata = {} + for member, topics in six.iteritems(subscriptions): + member_metadata[member] = StickyPartitionAssignor._metadata(topics, member_assignments[member]) + + assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance(subscriptions, assignment) + + +def test_stickiness(mocker): + cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1, 2}) + subscriptions = { + 'C1': {'t'}, + 'C2': {'t'}, + 'C3': {'t'}, + 'C4': {'t'}, + } + member_metadata = make_member_metadata(subscriptions) + + assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance(subscriptions, assignment) + partitions_assigned = {} + for consumer, consumer_assignment in six.iteritems(assignment): + assert ( + len(consumer_assignment.partitions()) <= 1 + ), 'Consumer {} is assigned more topic partitions than expected.'.format(consumer) + if len(consumer_assignment.partitions()) == 1: + partitions_assigned[consumer] = consumer_assignment.partitions()[0] + + # removing the potential group leader + del subscriptions['C1'] + member_metadata = {} + for member, topics in six.iteritems(subscriptions): + member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions()) + + assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance(subscriptions, assignment) + assert StickyPartitionAssignor._latest_partition_movements.are_sticky() + + for consumer, consumer_assignment in six.iteritems(assignment): + assert ( + len(consumer_assignment.partitions()) <= 1 + ), 'Consumer {} is assigned more topic partitions than expected.'.format(consumer) + assert ( + consumer not in partitions_assigned or partitions_assigned[consumer] in consumer_assignment.partitions() + ), 'Stickiness was not honored for consumer {}'.format(consumer) + + +def test_assignment_updated_for_deleted_topic(mocker): + def topic_partitions(topic): + if topic == 't1': + return {0} + if topic == 't3': + return set(range(100)) + + cluster = create_cluster(mocker, topics={'t1', 't3'}, topic_partitions_lambda=topic_partitions) + + subscriptions = { + 'C': {'t1', 't2', 't3'}, + } + member_metadata = make_member_metadata(subscriptions) + + sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) + expected_assignment = { + 'C': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [0]), ('t3', list(range(100)))], b''), + } + assert_assignment(sticky_assignment, expected_assignment) + + +def 
test_no_exceptions_when_only_subscribed_topic_is_deleted(mocker):
+    cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1, 2})
+
+    subscriptions = {
+        'C': {'t'},
+    }
+    member_metadata = make_member_metadata(subscriptions)
+
+    sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata)
+    expected_assignment = {
+        'C': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t', [0, 1, 2])], b''),
+    }
+    assert_assignment(sticky_assignment, expected_assignment)
+
+    subscriptions = {
+        'C': {},
+    }
+    member_metadata = {}
+    for member, topics in six.iteritems(subscriptions):
+        member_metadata[member] = StickyPartitionAssignor._metadata(topics, sticky_assignment[member].partitions())
+
+    cluster = create_cluster(mocker, topics={}, topics_partitions={})
+    sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata)
+    expected_assignment = {
+        'C': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [], b''),
+    }
+    assert_assignment(sticky_assignment, expected_assignment)
+
+
+def test_conflicting_previous_assignments(mocker):
+    cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1})
+
+    subscriptions = {
+        'C1': {'t'},
+        'C2': {'t'},
+    }
+    member_metadata = {}
+    for member, topics in six.iteritems(subscriptions):
+        # assume both C1 and C2 have partition 0 assigned to them in generation 1
+        member_metadata[member] = StickyPartitionAssignor._metadata(topics, [TopicPartition('t', 0), TopicPartition('t', 0)], 1)
+
+    assignment = StickyPartitionAssignor.assign(cluster, member_metadata)
+    verify_validity_and_balance(subscriptions, assignment)
+
+
+@pytest.mark.parametrize(
+    'execution_number,n_topics,n_consumers', [(i, randint(10, 20), randint(20, 40)) for i in range(100)]
+)
+def test_reassignment_with_random_subscriptions_and_changes(mocker, execution_number, n_topics, n_consumers):
+    all_topics = sorted(['t{}'.format(i) for i in range(1, n_topics + 1)])
+    partitions = dict([(t, set(range(1, i + 1))) for i, t in enumerate(all_topics)])
+    cluster = create_cluster(mocker, topics=all_topics, topic_partitions_lambda=lambda t: partitions[t])
+
+    subscriptions = defaultdict(set)
+    for i in range(n_consumers):
+        topics_sample = sample(all_topics, randint(1, len(all_topics) - 1))
+        subscriptions['C{}'.format(i)].update(topics_sample)
+
+    member_metadata = make_member_metadata(subscriptions)
+
+    assignment = StickyPartitionAssignor.assign(cluster, member_metadata)
+    verify_validity_and_balance(subscriptions, assignment)
+
+    subscriptions = defaultdict(set)
+    for i in range(n_consumers):
+        topics_sample = sample(all_topics, randint(1, len(all_topics) - 1))
+        subscriptions['C{}'.format(i)].update(topics_sample)
+
+    member_metadata = {}
+    for member, topics in six.iteritems(subscriptions):
+        member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions())
+
+    assignment = StickyPartitionAssignor.assign(cluster, member_metadata)
+    verify_validity_and_balance(subscriptions, assignment)
+    assert StickyPartitionAssignor._latest_partition_movements.are_sticky()
+
+
+def test_assignment_with_multiple_generations1(mocker):
+    cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1, 2, 3, 4, 5})
+
+    member_metadata = {
+        'C1': StickyPartitionAssignor._metadata({'t'}, []),
+        'C2': StickyPartitionAssignor._metadata({'t'}, []),
+        'C3': StickyPartitionAssignor._metadata({'t'}, []),
+    }
+
+    assignment1 = StickyPartitionAssignor.assign(cluster, member_metadata)
+
verify_validity_and_balance({'C1': {'t'}, 'C2': {'t'}, 'C3': {'t'}}, assignment1) + assert len(assignment1['C1'].assignment[0][1]) == 2 + assert len(assignment1['C2'].assignment[0][1]) == 2 + assert len(assignment1['C3'].assignment[0][1]) == 2 + + member_metadata = { + 'C1': StickyPartitionAssignor._metadata({'t'}, assignment1['C1'].partitions()), + 'C2': StickyPartitionAssignor._metadata({'t'}, assignment1['C2'].partitions()), + } + + assignment2 = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance({'C1': {'t'}, 'C2': {'t'}}, assignment2) + assert len(assignment2['C1'].assignment[0][1]) == 3 + assert len(assignment2['C2'].assignment[0][1]) == 3 + assert all([partition in assignment2['C1'].assignment[0][1] for partition in assignment1['C1'].assignment[0][1]]) + assert all([partition in assignment2['C2'].assignment[0][1] for partition in assignment1['C2'].assignment[0][1]]) + assert StickyPartitionAssignor._latest_partition_movements.are_sticky() + + member_metadata = { + 'C2': StickyPartitionAssignor._metadata({'t'}, assignment2['C2'].partitions(), 2), + 'C3': StickyPartitionAssignor._metadata({'t'}, assignment1['C3'].partitions(), 1), + } + + assignment3 = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance({'C2': {'t'}, 'C3': {'t'}}, assignment3) + assert len(assignment3['C2'].assignment[0][1]) == 3 + assert len(assignment3['C3'].assignment[0][1]) == 3 + assert StickyPartitionAssignor._latest_partition_movements.are_sticky() + + +def test_assignment_with_multiple_generations2(mocker): + cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1, 2, 3, 4, 5}) + + member_metadata = { + 'C1': StickyPartitionAssignor._metadata({'t'}, []), + 'C2': StickyPartitionAssignor._metadata({'t'}, []), + 'C3': StickyPartitionAssignor._metadata({'t'}, []), + } + + assignment1 = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance({'C1': {'t'}, 'C2': {'t'}, 'C3': {'t'}}, assignment1) + assert len(assignment1['C1'].assignment[0][1]) == 2 + assert len(assignment1['C2'].assignment[0][1]) == 2 + assert len(assignment1['C3'].assignment[0][1]) == 2 + + member_metadata = { + 'C2': StickyPartitionAssignor._metadata({'t'}, assignment1['C2'].partitions(), 1), + } + + assignment2 = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance({'C2': {'t'}}, assignment2) + assert len(assignment2['C2'].assignment[0][1]) == 6 + assert all([partition in assignment2['C2'].assignment[0][1] for partition in assignment1['C2'].assignment[0][1]]) + assert StickyPartitionAssignor._latest_partition_movements.are_sticky() + + member_metadata = { + 'C1': StickyPartitionAssignor._metadata({'t'}, assignment1['C1'].partitions(), 1), + 'C2': StickyPartitionAssignor._metadata({'t'}, assignment2['C2'].partitions(), 2), + 'C3': StickyPartitionAssignor._metadata({'t'}, assignment1['C3'].partitions(), 1), + } + + assignment3 = StickyPartitionAssignor.assign(cluster, member_metadata) + verify_validity_and_balance({'C1': {'t'}, 'C2': {'t'}, 'C3': {'t'}}, assignment3) + assert StickyPartitionAssignor._latest_partition_movements.are_sticky() + assert set(assignment3['C1'].assignment[0][1]) == set(assignment1['C1'].assignment[0][1]) + assert set(assignment3['C2'].assignment[0][1]) == set(assignment1['C2'].assignment[0][1]) + assert set(assignment3['C3'].assignment[0][1]) == set(assignment1['C3'].assignment[0][1]) + + +@pytest.mark.parametrize('execution_number', range(50)) +def 
test_assignment_with_conflicting_previous_generations(mocker, execution_number):
+    cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1, 2, 3, 4, 5})
+
+    member_assignments = {
+        'C1': [TopicPartition('t', p) for p in {0, 1, 4}],
+        'C2': [TopicPartition('t', p) for p in {0, 2, 3}],
+        'C3': [TopicPartition('t', p) for p in {3, 4, 5}],
+    }
+    member_generations = {
+        'C1': 1,
+        'C2': 1,
+        'C3': 2,
+    }
+    member_metadata = {}
+    for member in six.iterkeys(member_assignments):
+        member_metadata[member] = StickyPartitionAssignor._metadata({'t'}, member_assignments[member], member_generations[member])
+
+    assignment = StickyPartitionAssignor.assign(cluster, member_metadata)
+    verify_validity_and_balance({'C1': {'t'}, 'C2': {'t'}, 'C3': {'t'}}, assignment)
+    assert StickyPartitionAssignor._latest_partition_movements.are_sticky()
+
+
+def make_member_metadata(subscriptions):
+    member_metadata = {}
+    for member, topics in six.iteritems(subscriptions):
+        member_metadata[member] = StickyPartitionAssignor._metadata(topics, [])
+    return member_metadata
+
+
+def assert_assignment(result_assignment, expected_assignment):
+    assert result_assignment == expected_assignment
+    assert set(result_assignment) == set(expected_assignment)
+    for member in result_assignment:
+        assert result_assignment[member].encode() == expected_assignment[member].encode()
+
+
+def verify_validity_and_balance(subscriptions, assignment):
+    """
+    Verifies that the given assignment is valid with respect to the given subscriptions
+    Validity requirements:
+        - each consumer is subscribed to topics of all partitions assigned to it, and
+        - each partition is assigned to no more than one consumer
+    Balance requirements:
+        - the assignment is fully balanced (the numbers of topic partitions assigned to consumers differ by at most one), or
+        - there is no topic partition that can be moved from one consumer to another with 2+ fewer topic partitions
+
+    :param subscriptions: topic subscriptions of each consumer
+    :param assignment: given assignment for balance check
+    """
+    assert six.viewkeys(subscriptions) == six.viewkeys(assignment)
+
+    consumers = sorted(six.viewkeys(assignment))
+    for i in range(len(consumers)):
+        consumer = consumers[i]
+        partitions = assignment[consumer].partitions()
+        for partition in partitions:
+            assert partition.topic in subscriptions[consumer], (
+                'Error: Partition {} is assigned to consumer {}, '
+                'but it is not subscribed to topic {}\n'
+                'Subscriptions: {}\n'
+                'Assignments: {}'.format(partition, consumers[i], partition.topic, subscriptions, assignment)
+            )
+        if i == len(consumers) - 1:
+            continue
+
+        for j in range(i + 1, len(consumers)):
+            other_consumer = consumers[j]
+            other_partitions = assignment[other_consumer].partitions()
+            partitions_intersection = set(partitions).intersection(set(other_partitions))
+            assert partitions_intersection == set(), (
+                'Error: Consumers {} and {} have common partitions '
+                'assigned to them: {}\n'
+                'Subscriptions: {}\n'
+                'Assignments: {}'.format(consumer, other_consumer, partitions_intersection, subscriptions, assignment)
+            )
+
+            if abs(len(partitions) - len(other_partitions)) <= 1:
+                continue
+
+            assignments_by_topic = group_partitions_by_topic(partitions)
+            other_assignments_by_topic = group_partitions_by_topic(other_partitions)
+            if len(partitions) > len(other_partitions):
+                for topic in six.iterkeys(assignments_by_topic):
+                    assert topic not in other_assignments_by_topic, (
+                        'Error: Some partitions can be moved from {} ({} partitions) '
+                        'to {} ({}
partitions) ' + 'to achieve a better balance\n' + 'Subscriptions: {}\n' + 'Assignments: {}'.format(consumer, len(partitions), other_consumer, len(other_partitions), subscriptions, assignment) + ) + if len(other_partitions) > len(partitions): + for topic in six.iterkeys(other_assignments_by_topic): + assert topic not in assignments_by_topic, ( + 'Error: Some partitions can be moved from {} ({} partitions) ' + 'to {} ({} partitions) ' + 'to achieve a better balance\n' + 'Subscriptions: {}\n' + 'Assignments: {}'.format(other_consumer, len(other_partitions), consumer, len(partitions), subscriptions, assignment) + ) + + +def group_partitions_by_topic(partitions): + result = defaultdict(set) + for p in partitions: + result[p.topic].add(p.partition) + return result diff --git a/test_client_async.py b/test_client_async.py new file mode 100644 index 00000000..74da66a3 --- /dev/null +++ b/test_client_async.py @@ -0,0 +1,409 @@ +from __future__ import absolute_import, division + +# selectors in stdlib as of py3.4 +try: + import selectors # pylint: disable=import-error +except ImportError: + # vendored backport module + import kafka.vendor.selectors34 as selectors + +import socket +import time + +import pytest + +from kafka.client_async import KafkaClient, IdleConnectionManager +from kafka.cluster import ClusterMetadata +from kafka.conn import ConnectionStates +import kafka.errors as Errors +from kafka.future import Future +from kafka.protocol.metadata import MetadataRequest +from kafka.protocol.produce import ProduceRequest +from kafka.structs import BrokerMetadata + + +@pytest.fixture +def cli(mocker, conn): + client = KafkaClient(api_version=(0, 9)) + mocker.patch.object(client, '_selector') + client.poll(future=client.cluster.request_update()) + return client + + +def test_bootstrap(mocker, conn): + conn.state = ConnectionStates.CONNECTED + cli = KafkaClient(api_version=(0, 9)) + mocker.patch.object(cli, '_selector') + future = cli.cluster.request_update() + cli.poll(future=future) + + assert future.succeeded() + args, kwargs = conn.call_args + assert args == ('localhost', 9092, socket.AF_UNSPEC) + kwargs.pop('state_change_callback') + kwargs.pop('node_id') + assert kwargs == cli.config + conn.send.assert_called_once_with(MetadataRequest[0]([]), blocking=False) + assert cli._bootstrap_fails == 0 + assert cli.cluster.brokers() == set([BrokerMetadata(0, 'foo', 12, None), + BrokerMetadata(1, 'bar', 34, None)]) + + +def test_can_connect(cli, conn): + # Node is not in broker metadata - can't connect + assert not cli._can_connect(2) + + # Node is in broker metadata but not in _conns + assert 0 not in cli._conns + assert cli._can_connect(0) + + # Node is connected, can't reconnect + assert cli._maybe_connect(0) is True + assert not cli._can_connect(0) + + # Node is disconnected, can connect + cli._conns[0].state = ConnectionStates.DISCONNECTED + assert cli._can_connect(0) + + # Node is disconnected, but blacked out + conn.blacked_out.return_value = True + assert not cli._can_connect(0) + + +def test_maybe_connect(cli, conn): + try: + # Node not in metadata, raises AssertionError + cli._maybe_connect(2) + except AssertionError: + pass + else: + assert False, 'Exception not raised' + + # New node_id creates a conn object + assert 0 not in cli._conns + conn.state = ConnectionStates.DISCONNECTED + conn.connect.side_effect = lambda: conn._set_conn_state(ConnectionStates.CONNECTING) + assert cli._maybe_connect(0) is False + assert cli._conns[0] is conn + + +def test_conn_state_change(mocker, cli, conn): + sel 
= cli._selector + + node_id = 0 + cli._conns[node_id] = conn + conn.state = ConnectionStates.CONNECTING + sock = conn._sock + cli._conn_state_change(node_id, sock, conn) + assert node_id in cli._connecting + sel.register.assert_called_with(sock, selectors.EVENT_WRITE, conn) + + conn.state = ConnectionStates.CONNECTED + cli._conn_state_change(node_id, sock, conn) + assert node_id not in cli._connecting + sel.modify.assert_called_with(sock, selectors.EVENT_READ, conn) + + # Failure to connect should trigger metadata update + assert cli.cluster._need_update is False + conn.state = ConnectionStates.DISCONNECTED + cli._conn_state_change(node_id, sock, conn) + assert node_id not in cli._connecting + assert cli.cluster._need_update is True + sel.unregister.assert_called_with(sock) + + conn.state = ConnectionStates.CONNECTING + cli._conn_state_change(node_id, sock, conn) + assert node_id in cli._connecting + conn.state = ConnectionStates.DISCONNECTED + cli._conn_state_change(node_id, sock, conn) + assert node_id not in cli._connecting + + +def test_ready(mocker, cli, conn): + maybe_connect = mocker.patch.object(cli, 'maybe_connect') + node_id = 1 + cli.ready(node_id) + maybe_connect.assert_called_with(node_id) + + +def test_is_ready(mocker, cli, conn): + cli._maybe_connect(0) + cli._maybe_connect(1) + + # metadata refresh blocks ready nodes + assert cli.is_ready(0) + assert cli.is_ready(1) + cli._metadata_refresh_in_progress = True + assert not cli.is_ready(0) + assert not cli.is_ready(1) + + # requesting metadata update also blocks ready nodes + cli._metadata_refresh_in_progress = False + assert cli.is_ready(0) + assert cli.is_ready(1) + cli.cluster.request_update() + cli.cluster.config['retry_backoff_ms'] = 0 + assert not cli._metadata_refresh_in_progress + assert not cli.is_ready(0) + assert not cli.is_ready(1) + cli.cluster._need_update = False + + # if connection can't send more, not ready + assert cli.is_ready(0) + conn.can_send_more.return_value = False + assert not cli.is_ready(0) + conn.can_send_more.return_value = True + + # disconnected nodes, not ready + assert cli.is_ready(0) + conn.state = ConnectionStates.DISCONNECTED + assert not cli.is_ready(0) + + +def test_close(mocker, cli, conn): + mocker.patch.object(cli, '_selector') + + call_count = conn.close.call_count + + # Unknown node - silent + cli.close(2) + call_count += 0 + assert conn.close.call_count == call_count + + # Single node close + cli._maybe_connect(0) + assert conn.close.call_count == call_count + cli.close(0) + call_count += 1 + assert conn.close.call_count == call_count + + # All node close + cli._maybe_connect(1) + cli.close() + # +2 close: node 1, node bootstrap (node 0 already closed) + call_count += 2 + assert conn.close.call_count == call_count + + +def test_is_disconnected(cli, conn): + # False if not connected yet + conn.state = ConnectionStates.DISCONNECTED + assert not cli.is_disconnected(0) + + cli._maybe_connect(0) + assert cli.is_disconnected(0) + + conn.state = ConnectionStates.CONNECTING + assert not cli.is_disconnected(0) + + conn.state = ConnectionStates.CONNECTED + assert not cli.is_disconnected(0) + + +def test_send(cli, conn): + # Send to unknown node => raises AssertionError + try: + cli.send(2, None) + assert False, 'Exception not raised' + except AssertionError: + pass + + # Send to disconnected node => NodeNotReady + conn.state = ConnectionStates.DISCONNECTED + f = cli.send(0, None) + assert f.failed() + assert isinstance(f.exception, Errors.NodeNotReadyError) + + conn.state = 
ConnectionStates.CONNECTED + cli._maybe_connect(0) + # ProduceRequest w/ 0 required_acks -> no response + request = ProduceRequest[0](0, 0, []) + assert request.expect_response() is False + ret = cli.send(0, request) + assert conn.send.called_with(request) + assert isinstance(ret, Future) + + request = MetadataRequest[0]([]) + cli.send(0, request) + assert conn.send.called_with(request) + + +def test_poll(mocker): + metadata = mocker.patch.object(KafkaClient, '_maybe_refresh_metadata') + _poll = mocker.patch.object(KafkaClient, '_poll') + ifrs = mocker.patch.object(KafkaClient, 'in_flight_request_count') + ifrs.return_value = 1 + cli = KafkaClient(api_version=(0, 9)) + + # metadata timeout wins + metadata.return_value = 1000 + cli.poll() + _poll.assert_called_with(1.0) + + # user timeout wins + cli.poll(250) + _poll.assert_called_with(0.25) + + # default is request_timeout_ms + metadata.return_value = 1000000 + cli.poll() + _poll.assert_called_with(cli.config['request_timeout_ms'] / 1000.0) + + # If no in-flight-requests, drop timeout to retry_backoff_ms + ifrs.return_value = 0 + cli.poll() + _poll.assert_called_with(cli.config['retry_backoff_ms'] / 1000.0) + + +def test__poll(): + pass + + +def test_in_flight_request_count(): + pass + + +def test_least_loaded_node(): + pass + + +def test_set_topics(mocker): + request_update = mocker.patch.object(ClusterMetadata, 'request_update') + request_update.side_effect = lambda: Future() + cli = KafkaClient(api_version=(0, 10)) + + # replace 'empty' with 'non empty' + request_update.reset_mock() + fut = cli.set_topics(['t1', 't2']) + assert not fut.is_done + request_update.assert_called_with() + + # replace 'non empty' with 'same' + request_update.reset_mock() + fut = cli.set_topics(['t1', 't2']) + assert fut.is_done + assert fut.value == set(['t1', 't2']) + request_update.assert_not_called() + + # replace 'non empty' with 'empty' + request_update.reset_mock() + fut = cli.set_topics([]) + assert fut.is_done + assert fut.value == set() + request_update.assert_not_called() + + +@pytest.fixture +def client(mocker): + _poll = mocker.patch.object(KafkaClient, '_poll') + + cli = KafkaClient(request_timeout_ms=9999999, + reconnect_backoff_ms=2222, + connections_max_idle_ms=float('inf'), + api_version=(0, 9)) + + ttl = mocker.patch.object(cli.cluster, 'ttl') + ttl.return_value = 0 + return cli + + +def test_maybe_refresh_metadata_ttl(mocker, client): + client.cluster.ttl.return_value = 1234 + mocker.patch.object(KafkaClient, 'in_flight_request_count', return_value=1) + + client.poll(timeout_ms=12345678) + client._poll.assert_called_with(1.234) + + +def test_maybe_refresh_metadata_backoff(mocker, client): + mocker.patch.object(KafkaClient, 'in_flight_request_count', return_value=1) + now = time.time() + t = mocker.patch('time.time') + t.return_value = now + + client.poll(timeout_ms=12345678) + client._poll.assert_called_with(2.222) # reconnect backoff + + +def test_maybe_refresh_metadata_in_progress(mocker, client): + client._metadata_refresh_in_progress = True + mocker.patch.object(KafkaClient, 'in_flight_request_count', return_value=1) + + client.poll(timeout_ms=12345678) + client._poll.assert_called_with(9999.999) # request_timeout_ms + + +def test_maybe_refresh_metadata_update(mocker, client): + mocker.patch.object(client, 'least_loaded_node', return_value='foobar') + mocker.patch.object(client, '_can_send_request', return_value=True) + mocker.patch.object(KafkaClient, 'in_flight_request_count', return_value=1) + send = mocker.patch.object(client, 
'send') + + client.poll(timeout_ms=12345678) + client._poll.assert_called_with(9999.999) # request_timeout_ms + assert client._metadata_refresh_in_progress + request = MetadataRequest[0]([]) + send.assert_called_once_with('foobar', request, wakeup=False) + + +def test_maybe_refresh_metadata_cant_send(mocker, client): + mocker.patch.object(client, 'least_loaded_node', return_value='foobar') + mocker.patch.object(client, '_can_connect', return_value=True) + mocker.patch.object(client, '_maybe_connect', return_value=True) + mocker.patch.object(client, 'maybe_connect', return_value=True) + mocker.patch.object(KafkaClient, 'in_flight_request_count', return_value=1) + + now = time.time() + t = mocker.patch('time.time') + t.return_value = now + + # first poll attempts connection + client.poll(timeout_ms=12345678) + client._poll.assert_called_with(2.222) # reconnect backoff + client.maybe_connect.assert_called_once_with('foobar', wakeup=False) + + # poll while connecting should not attempt a new connection + client._connecting.add('foobar') + client._can_connect.reset_mock() + client.poll(timeout_ms=12345678) + client._poll.assert_called_with(2.222) # connection timeout (reconnect timeout) + assert not client._can_connect.called + + assert not client._metadata_refresh_in_progress + + +def test_schedule(): + pass + + +def test_unschedule(): + pass + + +def test_idle_connection_manager(mocker): + t = mocker.patch.object(time, 'time') + t.return_value = 0 + + idle = IdleConnectionManager(100) + assert idle.next_check_ms() == float('inf') + + idle.update('foo') + assert not idle.is_expired('foo') + assert idle.poll_expired_connection() is None + assert idle.next_check_ms() == 100 + + t.return_value = 90 / 1000 + assert not idle.is_expired('foo') + assert idle.poll_expired_connection() is None + assert idle.next_check_ms() == 10 + + t.return_value = 100 / 1000 + assert idle.is_expired('foo') + assert idle.next_check_ms() == 0 + + conn_id, conn_ts = idle.poll_expired_connection() + assert conn_id == 'foo' + assert conn_ts == 0 + + idle.remove('foo') + assert idle.next_check_ms() == float('inf') diff --git a/test_cluster.py b/test_cluster.py new file mode 100644 index 00000000..f010c4f7 --- /dev/null +++ b/test_cluster.py @@ -0,0 +1,22 @@ +# pylint: skip-file +from __future__ import absolute_import + +import pytest + +from kafka.cluster import ClusterMetadata +from kafka.protocol.metadata import MetadataResponse + + +def test_empty_broker_list(): + cluster = ClusterMetadata() + assert len(cluster.brokers()) == 0 + + cluster.update_metadata(MetadataResponse[0]( + [(0, 'foo', 12), (1, 'bar', 34)], [])) + assert len(cluster.brokers()) == 2 + + # empty broker list response should be ignored + cluster.update_metadata(MetadataResponse[0]( + [], # empty brokers + [(17, 'foo', []), (17, 'bar', [])])) # topics w/ error + assert len(cluster.brokers()) == 2 diff --git a/test_codec.py b/test_codec.py new file mode 100644 index 00000000..e0570745 --- /dev/null +++ b/test_codec.py @@ -0,0 +1,124 @@ +from __future__ import absolute_import + +import platform +import struct + +import pytest +from kafka.vendor.six.moves import range + +from kafka.codec import ( + has_snappy, has_lz4, has_zstd, + gzip_encode, gzip_decode, + snappy_encode, snappy_decode, + lz4_encode, lz4_decode, + lz4_encode_old_kafka, lz4_decode_old_kafka, + zstd_encode, zstd_decode, +) + +from test.testutil import random_string + + +def test_gzip(): + for i in range(1000): + b1 = random_string(100).encode('utf-8') + b2 = gzip_decode(gzip_encode(b1)) + 
assert b1 == b2 + + +@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") +def test_snappy(): + for i in range(1000): + b1 = random_string(100).encode('utf-8') + b2 = snappy_decode(snappy_encode(b1)) + assert b1 == b2 + + +@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") +def test_snappy_detect_xerial(): + import kafka as kafka1 + _detect_xerial_stream = kafka1.codec._detect_xerial_stream + + header = b'\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01Some extra bytes' + false_header = b'\x01SNAPPY\x00\x00\x00\x01\x00\x00\x00\x01' + default_snappy = snappy_encode(b'foobar' * 50) + random_snappy = snappy_encode(b'SNAPPY' * 50, xerial_compatible=False) + short_data = b'\x01\x02\x03\x04' + + assert _detect_xerial_stream(header) is True + assert _detect_xerial_stream(b'') is False + assert _detect_xerial_stream(b'\x00') is False + assert _detect_xerial_stream(false_header) is False + assert _detect_xerial_stream(default_snappy) is True + assert _detect_xerial_stream(random_snappy) is False + assert _detect_xerial_stream(short_data) is False + + +@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") +def test_snappy_decode_xerial(): + header = b'\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01' + random_snappy = snappy_encode(b'SNAPPY' * 50, xerial_compatible=False) + block_len = len(random_snappy) + random_snappy2 = snappy_encode(b'XERIAL' * 50, xerial_compatible=False) + block_len2 = len(random_snappy2) + + to_test = header \ + + struct.pack('!i', block_len) + random_snappy \ + + struct.pack('!i', block_len2) + random_snappy2 \ + + assert snappy_decode(to_test) == (b'SNAPPY' * 50) + (b'XERIAL' * 50) + + +@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") +def test_snappy_encode_xerial(): + to_ensure = ( + b'\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01' + b'\x00\x00\x00\x18' + b'\xac\x02\x14SNAPPY\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\x96\x06\x00' + b'\x00\x00\x00\x18' + b'\xac\x02\x14XERIAL\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\x96\x06\x00' + ) + + to_test = (b'SNAPPY' * 50) + (b'XERIAL' * 50) + + compressed = snappy_encode(to_test, xerial_compatible=True, xerial_blocksize=300) + assert compressed == to_ensure + + +@pytest.mark.skipif(not has_lz4() or platform.python_implementation() == 'PyPy', + reason="python-lz4 crashes on old versions of pypy") +def test_lz4(): + for i in range(1000): + b1 = random_string(100).encode('utf-8') + b2 = lz4_decode(lz4_encode(b1)) + assert len(b1) == len(b2) + assert b1 == b2 + + +@pytest.mark.skipif(not has_lz4() or platform.python_implementation() == 'PyPy', + reason="python-lz4 crashes on old versions of pypy") +def test_lz4_old(): + for i in range(1000): + b1 = random_string(100).encode('utf-8') + b2 = lz4_decode_old_kafka(lz4_encode_old_kafka(b1)) + assert len(b1) == len(b2) + assert b1 == b2 + + +@pytest.mark.skipif(not has_lz4() or platform.python_implementation() == 'PyPy', + reason="python-lz4 crashes on old versions of pypy") +def test_lz4_incremental(): + for i in range(1000): + # lz4 max single block size is 4MB + # make sure we test with multiple-blocks + b1 = random_string(100).encode('utf-8') * 50000 + b2 = lz4_decode(lz4_encode(b1)) + assert len(b1) == len(b2) + assert b1 == b2 + + +@pytest.mark.skipif(not has_zstd(), reason="Zstd not available") +def test_zstd(): + for _ in range(1000): + b1 = random_string(100).encode('utf-8') + b2 = zstd_decode(zstd_encode(b1)) + assert b1 == b2 diff --git a/test_conn.py b/test_conn.py new file mode 100644 index 
00000000..966f7b34 --- /dev/null +++ b/test_conn.py @@ -0,0 +1,342 @@ +# pylint: skip-file +from __future__ import absolute_import + +from errno import EALREADY, EINPROGRESS, EISCONN, ECONNRESET +import socket + +import mock +import pytest + +from kafka.conn import BrokerConnection, ConnectionStates, collect_hosts +from kafka.protocol.api import RequestHeader +from kafka.protocol.metadata import MetadataRequest +from kafka.protocol.produce import ProduceRequest + +import kafka.errors as Errors + + +@pytest.fixture +def dns_lookup(mocker): + return mocker.patch('kafka.conn.dns_lookup', + return_value=[(socket.AF_INET, + None, None, None, + ('localhost', 9092))]) + +@pytest.fixture +def _socket(mocker): + socket = mocker.MagicMock() + socket.connect_ex.return_value = 0 + mocker.patch('socket.socket', return_value=socket) + return socket + + +@pytest.fixture +def conn(_socket, dns_lookup): + conn = BrokerConnection('localhost', 9092, socket.AF_INET) + return conn + + +@pytest.mark.parametrize("states", [ + (([EINPROGRESS, EALREADY], ConnectionStates.CONNECTING),), + (([EALREADY, EALREADY], ConnectionStates.CONNECTING),), + (([0], ConnectionStates.CONNECTED),), + (([EINPROGRESS, EALREADY], ConnectionStates.CONNECTING), + ([ECONNRESET], ConnectionStates.DISCONNECTED)), + (([EINPROGRESS, EALREADY], ConnectionStates.CONNECTING), + ([EALREADY], ConnectionStates.CONNECTING), + ([EISCONN], ConnectionStates.CONNECTED)), +]) +def test_connect(_socket, conn, states): + assert conn.state is ConnectionStates.DISCONNECTED + + for errno, state in states: + _socket.connect_ex.side_effect = errno + conn.connect() + assert conn.state is state + + +def test_connect_timeout(_socket, conn): + assert conn.state is ConnectionStates.DISCONNECTED + + # Initial connect returns EINPROGRESS + # immediate inline connect returns EALREADY + # second explicit connect returns EALREADY + # third explicit connect returns EALREADY and times out via last_attempt + _socket.connect_ex.side_effect = [EINPROGRESS, EALREADY, EALREADY, EALREADY] + conn.connect() + assert conn.state is ConnectionStates.CONNECTING + conn.connect() + assert conn.state is ConnectionStates.CONNECTING + conn.last_attempt = 0 + conn.connect() + assert conn.state is ConnectionStates.DISCONNECTED + + +def test_blacked_out(conn): + with mock.patch("time.time", return_value=1000): + conn.last_attempt = 0 + assert conn.blacked_out() is False + conn.last_attempt = 1000 + assert conn.blacked_out() is True + + +def test_connection_delay(conn): + with mock.patch("time.time", return_value=1000): + conn.last_attempt = 1000 + assert conn.connection_delay() == conn.config['reconnect_backoff_ms'] + conn.state = ConnectionStates.CONNECTING + assert conn.connection_delay() == float('inf') + conn.state = ConnectionStates.CONNECTED + assert conn.connection_delay() == float('inf') + + +def test_connected(conn): + assert conn.connected() is False + conn.state = ConnectionStates.CONNECTED + assert conn.connected() is True + + +def test_connecting(conn): + assert conn.connecting() is False + conn.state = ConnectionStates.CONNECTING + assert conn.connecting() is True + conn.state = ConnectionStates.CONNECTED + assert conn.connecting() is False + + +def test_send_disconnected(conn): + conn.state = ConnectionStates.DISCONNECTED + f = conn.send('foobar') + assert f.failed() is True + assert isinstance(f.exception, Errors.KafkaConnectionError) + + +def test_send_connecting(conn): + conn.state = ConnectionStates.CONNECTING + f = conn.send('foobar') + assert f.failed() is True + assert 
isinstance(f.exception, Errors.NodeNotReadyError) + + +def test_send_max_ifr(conn): + conn.state = ConnectionStates.CONNECTED + max_ifrs = conn.config['max_in_flight_requests_per_connection'] + for i in range(max_ifrs): + conn.in_flight_requests[i] = 'foo' + f = conn.send('foobar') + assert f.failed() is True + assert isinstance(f.exception, Errors.TooManyInFlightRequests) + + +def test_send_no_response(_socket, conn): + conn.connect() + assert conn.state is ConnectionStates.CONNECTED + req = ProduceRequest[0](required_acks=0, timeout=0, topics=()) + header = RequestHeader(req, client_id=conn.config['client_id']) + payload_bytes = len(header.encode()) + len(req.encode()) + third = payload_bytes // 3 + remainder = payload_bytes % 3 + _socket.send.side_effect = [4, third, third, third, remainder] + + assert len(conn.in_flight_requests) == 0 + f = conn.send(req) + assert f.succeeded() is True + assert f.value is None + assert len(conn.in_flight_requests) == 0 + + +def test_send_response(_socket, conn): + conn.connect() + assert conn.state is ConnectionStates.CONNECTED + req = MetadataRequest[0]([]) + header = RequestHeader(req, client_id=conn.config['client_id']) + payload_bytes = len(header.encode()) + len(req.encode()) + third = payload_bytes // 3 + remainder = payload_bytes % 3 + _socket.send.side_effect = [4, third, third, third, remainder] + + assert len(conn.in_flight_requests) == 0 + f = conn.send(req) + assert f.is_done is False + assert len(conn.in_flight_requests) == 1 + + +def test_send_error(_socket, conn): + conn.connect() + assert conn.state is ConnectionStates.CONNECTED + req = MetadataRequest[0]([]) + try: + _socket.send.side_effect = ConnectionError + except NameError: + _socket.send.side_effect = socket.error + f = conn.send(req) + assert f.failed() is True + assert isinstance(f.exception, Errors.KafkaConnectionError) + assert _socket.close.call_count == 1 + assert conn.state is ConnectionStates.DISCONNECTED + + +def test_can_send_more(conn): + assert conn.can_send_more() is True + max_ifrs = conn.config['max_in_flight_requests_per_connection'] + for i in range(max_ifrs): + assert conn.can_send_more() is True + conn.in_flight_requests[i] = 'foo' + assert conn.can_send_more() is False + + +def test_recv_disconnected(_socket, conn): + conn.connect() + assert conn.connected() + + req = MetadataRequest[0]([]) + header = RequestHeader(req, client_id=conn.config['client_id']) + payload_bytes = len(header.encode()) + len(req.encode()) + _socket.send.side_effect = [4, payload_bytes] + conn.send(req) + + # Empty data on recv means the socket is disconnected + _socket.recv.return_value = b'' + + # Attempt to receive should mark connection as disconnected + assert conn.connected() + conn.recv() + assert conn.disconnected() + + +def test_recv(_socket, conn): + pass # TODO + + +def test_close(conn): + pass # TODO + + +def test_collect_hosts__happy_path(): + hosts = "127.0.0.1:1234,127.0.0.1" + results = collect_hosts(hosts) + assert set(results) == set([ + ('127.0.0.1', 1234, socket.AF_INET), + ('127.0.0.1', 9092, socket.AF_INET), + ]) + + +def test_collect_hosts__ipv6(): + hosts = "[localhost]:1234,[2001:1000:2000::1],[2001:1000:2000::1]:1234" + results = collect_hosts(hosts) + assert set(results) == set([ + ('localhost', 1234, socket.AF_INET6), + ('2001:1000:2000::1', 9092, socket.AF_INET6), + ('2001:1000:2000::1', 1234, socket.AF_INET6), + ]) + + +def test_collect_hosts__string_list(): + hosts = [ + 'localhost:1234', + 'localhost', + '[localhost]', + '2001::1', + '[2001::1]', + 
'[2001::1]:1234', + ] + results = collect_hosts(hosts) + assert set(results) == set([ + ('localhost', 1234, socket.AF_UNSPEC), + ('localhost', 9092, socket.AF_UNSPEC), + ('localhost', 9092, socket.AF_INET6), + ('2001::1', 9092, socket.AF_INET6), + ('2001::1', 9092, socket.AF_INET6), + ('2001::1', 1234, socket.AF_INET6), + ]) + + +def test_collect_hosts__with_spaces(): + hosts = "localhost:1234, localhost" + results = collect_hosts(hosts) + assert set(results) == set([ + ('localhost', 1234, socket.AF_UNSPEC), + ('localhost', 9092, socket.AF_UNSPEC), + ]) + + +def test_lookup_on_connect(): + hostname = 'example.org' + port = 9092 + conn = BrokerConnection(hostname, port, socket.AF_UNSPEC) + assert conn.host == hostname + assert conn.port == port + assert conn.afi == socket.AF_UNSPEC + afi1 = socket.AF_INET + sockaddr1 = ('127.0.0.1', 9092) + mock_return1 = [ + (afi1, socket.SOCK_STREAM, 6, '', sockaddr1), + ] + with mock.patch("socket.getaddrinfo", return_value=mock_return1) as m: + conn.connect() + m.assert_called_once_with(hostname, port, 0, socket.SOCK_STREAM) + assert conn._sock_afi == afi1 + assert conn._sock_addr == sockaddr1 + conn.close() + + afi2 = socket.AF_INET6 + sockaddr2 = ('::1', 9092, 0, 0) + mock_return2 = [ + (afi2, socket.SOCK_STREAM, 6, '', sockaddr2), + ] + + with mock.patch("socket.getaddrinfo", return_value=mock_return2) as m: + conn.last_attempt = 0 + conn.connect() + m.assert_called_once_with(hostname, port, 0, socket.SOCK_STREAM) + assert conn._sock_afi == afi2 + assert conn._sock_addr == sockaddr2 + conn.close() + + +def test_relookup_on_failure(): + hostname = 'example.org' + port = 9092 + conn = BrokerConnection(hostname, port, socket.AF_UNSPEC) + assert conn.host == hostname + mock_return1 = [] + with mock.patch("socket.getaddrinfo", return_value=mock_return1) as m: + last_attempt = conn.last_attempt + conn.connect() + m.assert_called_once_with(hostname, port, 0, socket.SOCK_STREAM) + assert conn.disconnected() + assert conn.last_attempt > last_attempt + + afi2 = socket.AF_INET + sockaddr2 = ('127.0.0.2', 9092) + mock_return2 = [ + (afi2, socket.SOCK_STREAM, 6, '', sockaddr2), + ] + + with mock.patch("socket.getaddrinfo", return_value=mock_return2) as m: + conn.last_attempt = 0 + conn.connect() + m.assert_called_once_with(hostname, port, 0, socket.SOCK_STREAM) + assert conn._sock_afi == afi2 + assert conn._sock_addr == sockaddr2 + conn.close() + + +def test_requests_timed_out(conn): + with mock.patch("time.time", return_value=0): + # No in-flight requests, not timed out + assert not conn.requests_timed_out() + + # Single request, timestamp = now (0) + conn.in_flight_requests[0] = ('foo', 0) + assert not conn.requests_timed_out() + + # Add another request w/ timestamp > request_timeout ago + request_timeout = conn.config['request_timeout_ms'] + expired_timestamp = 0 - request_timeout - 1 + conn.in_flight_requests[1] = ('bar', expired_timestamp) + assert conn.requests_timed_out() + + # Drop the expired request and we should be good to go again + conn.in_flight_requests.pop(1) + assert not conn.requests_timed_out() diff --git a/test_consumer.py b/test_consumer.py new file mode 100644 index 00000000..436fe55c --- /dev/null +++ b/test_consumer.py @@ -0,0 +1,26 @@ +import pytest + +from kafka import KafkaConsumer +from kafka.errors import KafkaConfigurationError + + +class TestKafkaConsumer: + def test_session_timeout_larger_than_request_timeout_raises(self): + with pytest.raises(KafkaConfigurationError): + KafkaConsumer(bootstrap_servers='localhost:9092', 
api_version=(0, 9), group_id='foo', session_timeout_ms=50000, request_timeout_ms=40000)
+
+    def test_fetch_max_wait_larger_than_request_timeout_raises(self):
+        with pytest.raises(KafkaConfigurationError):
+            KafkaConsumer(bootstrap_servers='localhost:9092', fetch_max_wait_ms=50000, request_timeout_ms=40000)
+
+    def test_request_timeout_larger_than_connections_max_idle_ms_raises(self):
+        with pytest.raises(KafkaConfigurationError):
+            KafkaConsumer(bootstrap_servers='localhost:9092', api_version=(0, 9), request_timeout_ms=50000, connections_max_idle_ms=40000)
+
+    def test_subscription_copy(self):
+        consumer = KafkaConsumer('foo', api_version=(0, 10))
+        sub = consumer.subscription()
+        assert sub is not consumer.subscription()
+        assert sub == set(['foo'])
+        sub.add('fizz')
+        assert consumer.subscription() == set(['foo'])
diff --git a/test_consumer_group.py b/test_consumer_group.py
new file mode 100644
index 00000000..58dc7ebf
--- /dev/null
+++ b/test_consumer_group.py
@@ -0,0 +1,179 @@
+import collections
+import logging
+import threading
+import time
+
+import pytest
+from kafka.vendor import six
+
+from kafka.conn import ConnectionStates
+from kafka.consumer.group import KafkaConsumer
+from kafka.coordinator.base import MemberState
+from kafka.structs import TopicPartition
+
+from test.testutil import env_kafka_version, random_string
+
+
+def get_connect_str(kafka_broker):
+    return kafka_broker.host + ':' + str(kafka_broker.port)
+
+
+@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
+def test_consumer(kafka_broker, topic):
+    # The `topic` fixture is included because
+    # 0.8.2 brokers need a topic to function well
+    consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker))
+    consumer.poll(500)
+    assert len(consumer._client._conns) > 0
+    node_id = list(consumer._client._conns.keys())[0]
+    assert consumer._client._conns[node_id].state is ConnectionStates.CONNECTED
+    consumer.close()
+
+
+@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
+def test_consumer_topics(kafka_broker, topic):
+    consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker))
+    # Necessary to drive the IO
+    consumer.poll(500)
+    assert topic in consumer.topics()
+    assert len(consumer.partitions_for_topic(topic)) > 0
+    consumer.close()
+
+
+@pytest.mark.skipif(env_kafka_version() < (0, 9), reason='Unsupported Kafka Version')
+def test_group(kafka_broker, topic):
+    num_partitions = 4
+    connect_str = get_connect_str(kafka_broker)
+    consumers = {}
+    stop = {}
+    threads = {}
+    messages = collections.defaultdict(lambda: collections.defaultdict(list))
+    group_id = 'test-group-' + random_string(6)
+    def consumer_thread(i):
+        assert i not in consumers
+        assert i not in stop
+        stop[i] = threading.Event()
+        consumers[i] = KafkaConsumer(topic,
+                                     bootstrap_servers=connect_str,
+                                     group_id=group_id,
+                                     heartbeat_interval_ms=500)
+        while not stop[i].is_set():
+            for tp, records in six.iteritems(consumers[i].poll(100)):
+                messages[i][tp].extend(records)
+        consumers[i].close()
+        consumers[i] = None
+        stop[i] = None
+
+    num_consumers = 4
+    for i in range(num_consumers):
+        t = threading.Thread(target=consumer_thread, args=(i,))
+        t.start()
+        threads[i] = t
+
+    try:
+        timeout = time.time() + 35
+        while True:
+            for c in range(num_consumers):
+
+                # Verify all consumers have been created
+                if c not in consumers:
+                    break
+
+                # Verify all consumers have an assignment
+                elif not consumers[c].assignment():
+                    break
+
+            # If all consumers exist and have an assignment
+            else:
+
have assignment... checking for stable group') + # Verify all consumers are in the same generation + # then log state and break while loop + generations = set([consumer._coordinator._generation.generation_id + for consumer in list(consumers.values())]) + + # New generation assignment is not complete until + # coordinator.rejoining = False + rejoining = any([consumer._coordinator.rejoining + for consumer in list(consumers.values())]) + + if not rejoining and len(generations) == 1: + for c, consumer in list(consumers.items()): + logging.info("[%s] %s %s: %s", c, + consumer._coordinator._generation.generation_id, + consumer._coordinator._generation.member_id, + consumer.assignment()) + break + else: + logging.info('Rejoining: %s, generations: %s', rejoining, generations) + time.sleep(1) + assert time.time() < timeout, "timeout waiting for assignments" + + logging.info('Group stabilized; verifying assignment') + group_assignment = set() + for c in range(num_consumers): + assert len(consumers[c].assignment()) != 0 + assert set.isdisjoint(consumers[c].assignment(), group_assignment) + group_assignment.update(consumers[c].assignment()) + + assert group_assignment == set([ + TopicPartition(topic, partition) + for partition in range(num_partitions)]) + logging.info('Assignment looks good!') + + finally: + logging.info('Shutting down %s consumers', num_consumers) + for c in range(num_consumers): + logging.info('Stopping consumer %s', c) + stop[c].set() + threads[c].join() + threads[c] = None + + +@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") +def test_paused(kafka_broker, topic): + consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker)) + topics = [TopicPartition(topic, 1)] + consumer.assign(topics) + assert set(topics) == consumer.assignment() + assert set() == consumer.paused() + + consumer.pause(topics[0]) + assert set([topics[0]]) == consumer.paused() + + consumer.resume(topics[0]) + assert set() == consumer.paused() + + consumer.unsubscribe() + assert set() == consumer.paused() + consumer.close() + + +@pytest.mark.skipif(env_kafka_version() < (0, 9), reason='Unsupported Kafka Version') +def test_heartbeat_thread(kafka_broker, topic): + group_id = 'test-group-' + random_string(6) + consumer = KafkaConsumer(topic, + bootstrap_servers=get_connect_str(kafka_broker), + group_id=group_id, + heartbeat_interval_ms=500) + + # poll until we have joined group / have assignment + while not consumer.assignment(): + consumer.poll(timeout_ms=100) + + assert consumer._coordinator.state is MemberState.STABLE + last_poll = consumer._coordinator.heartbeat.last_poll + last_beat = consumer._coordinator.heartbeat.last_send + + timeout = time.time() + 30 + while True: + if time.time() > timeout: + raise RuntimeError('timeout waiting for heartbeat') + if consumer._coordinator.heartbeat.last_send > last_beat: + break + time.sleep(0.5) + + assert consumer._coordinator.heartbeat.last_poll == last_poll + consumer.poll(timeout_ms=100) + assert consumer._coordinator.heartbeat.last_poll > last_poll + consumer.close() diff --git a/test_consumer_integration.py b/test_consumer_integration.py new file mode 100644 index 00000000..90b7ed20 --- /dev/null +++ b/test_consumer_integration.py @@ -0,0 +1,297 @@ +import logging +import time + +from mock import patch +import pytest +from kafka.vendor.six.moves import range + +import kafka.codec +from kafka.errors import UnsupportedCodecError, UnsupportedVersionError +from kafka.structs import TopicPartition, OffsetAndTimestamp + +from 
test.testutil import Timer, assert_message_count, env_kafka_version, random_string + + +@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") +def test_kafka_version_infer(kafka_consumer_factory): + consumer = kafka_consumer_factory() + actual_ver_major_minor = env_kafka_version()[:2] + client = consumer._client + conn = list(client._conns.values())[0] + inferred_ver_major_minor = conn.check_version()[:2] + assert actual_ver_major_minor == inferred_ver_major_minor, \ + "Was expecting inferred broker version to be %s but was %s" % (actual_ver_major_minor, inferred_ver_major_minor) + + +@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") +def test_kafka_consumer(kafka_consumer_factory, send_messages): + """Test KafkaConsumer""" + consumer = kafka_consumer_factory(auto_offset_reset='earliest') + send_messages(range(0, 100), partition=0) + send_messages(range(0, 100), partition=1) + cnt = 0 + messages = {0: [], 1: []} + for message in consumer: + logging.debug("Consumed message %s", repr(message)) + cnt += 1 + messages[message.partition].append(message) + if cnt >= 200: + break + + assert_message_count(messages[0], 100) + assert_message_count(messages[1], 100) + + +@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") +def test_kafka_consumer_unsupported_encoding( + topic, kafka_producer_factory, kafka_consumer_factory): + # Send a compressed message + producer = kafka_producer_factory(compression_type="gzip") + fut = producer.send(topic, b"simple message" * 200) + fut.get(timeout=5) + producer.close() + + # Consume, but with the related compression codec not available + with patch.object(kafka.codec, "has_gzip") as mocked: + mocked.return_value = False + consumer = kafka_consumer_factory(auto_offset_reset='earliest') + error_msg = "Libraries for gzip compression codec not found" + with pytest.raises(UnsupportedCodecError, match=error_msg): + consumer.poll(timeout_ms=2000) + + +@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") +def test_kafka_consumer__blocking(kafka_consumer_factory, topic, send_messages): + TIMEOUT_MS = 500 + consumer = kafka_consumer_factory(auto_offset_reset='earliest', + enable_auto_commit=False, + consumer_timeout_ms=TIMEOUT_MS) + + # Manual assignment avoids overhead of consumer group mgmt + consumer.unsubscribe() + consumer.assign([TopicPartition(topic, 0)]) + + # Ask for 5 messages, nothing in queue, block 500ms + with Timer() as t: + with pytest.raises(StopIteration): + msg = next(consumer) + assert t.interval >= (TIMEOUT_MS / 1000.0) + + send_messages(range(0, 10)) + + # Ask for 5 messages, 10 in queue. 
Get 5 back, no blocking + messages = [] + with Timer() as t: + for i in range(5): + msg = next(consumer) + messages.append(msg) + assert_message_count(messages, 5) + assert t.interval < (TIMEOUT_MS / 1000.0) + + # Ask for 10 messages, get 5 back, block 500ms + messages = [] + with Timer() as t: + with pytest.raises(StopIteration): + for i in range(10): + msg = next(consumer) + messages.append(msg) + assert_message_count(messages, 5) + assert t.interval >= (TIMEOUT_MS / 1000.0) + + +@pytest.mark.skipif(env_kafka_version() < (0, 8, 1), reason="Requires KAFKA_VERSION >= 0.8.1") +def test_kafka_consumer__offset_commit_resume(kafka_consumer_factory, send_messages): + GROUP_ID = random_string(10) + + send_messages(range(0, 100), partition=0) + send_messages(range(100, 200), partition=1) + + # Start a consumer and grab the first 180 messages + consumer1 = kafka_consumer_factory( + group_id=GROUP_ID, + enable_auto_commit=True, + auto_commit_interval_ms=100, + auto_offset_reset='earliest', + ) + output_msgs1 = [] + for _ in range(180): + m = next(consumer1) + output_msgs1.append(m) + assert_message_count(output_msgs1, 180) + + # Normally we let the pytest fixture `kafka_consumer_factory` handle + # closing as part of its teardown. Here we manually call close() to force + # auto-commit to occur before the second consumer starts. That way the + # second consumer only consumes previously unconsumed messages. + consumer1.close() + + # Start a second consumer to grab 181-200 + consumer2 = kafka_consumer_factory( + group_id=GROUP_ID, + enable_auto_commit=True, + auto_commit_interval_ms=100, + auto_offset_reset='earliest', + ) + output_msgs2 = [] + for _ in range(20): + m = next(consumer2) + output_msgs2.append(m) + assert_message_count(output_msgs2, 20) + + # Verify the second consumer wasn't reconsuming messages that the first + # consumer already saw + assert_message_count(output_msgs1 + output_msgs2, 200) + + +@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1") +def test_kafka_consumer_max_bytes_simple(kafka_consumer_factory, topic, send_messages): + send_messages(range(100, 200), partition=0) + send_messages(range(200, 300), partition=1) + + # Start a consumer + consumer = kafka_consumer_factory( + auto_offset_reset='earliest', fetch_max_bytes=300) + seen_partitions = set() + for i in range(90): + poll_res = consumer.poll(timeout_ms=100) + for partition, msgs in poll_res.items(): + for msg in msgs: + seen_partitions.add(partition) + + # Check that we fetched at least 1 message from both partitions + assert seen_partitions == {TopicPartition(topic, 0), TopicPartition(topic, 1)} + + +@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1") +def test_kafka_consumer_max_bytes_one_msg(kafka_consumer_factory, send_messages): + # We send to only 1 partition so we don't have parallel requests to 2 + # nodes for data. + send_messages(range(100, 200)) + + # Start a consumer. FetchResponse_v3 should always include at least 1 + # full msg, so by setting fetch_max_bytes=1 we should get 1 msg at a time + # But 0.11.0.0 returns 1 MessageSet at a time when the messages are + # stored in the new v2 format by the broker. + # + # DP Note: This is a strange test. The consumer shouldn't care + # how many messages are included in a FetchResponse, as long as it is + # non-zero. I would not mind if we deleted this test. It caused + # a minor headache when testing 0.11.0.0. 
+ group = 'test-kafka-consumer-max-bytes-one-msg-' + random_string(5) + consumer = kafka_consumer_factory( + group_id=group, + auto_offset_reset='earliest', + consumer_timeout_ms=5000, + fetch_max_bytes=1) + + fetched_msgs = [next(consumer) for i in range(10)] + assert_message_count(fetched_msgs, 10) + + +@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1") +def test_kafka_consumer_offsets_for_time(topic, kafka_consumer, kafka_producer): + late_time = int(time.time()) * 1000 + middle_time = late_time - 1000 + early_time = late_time - 2000 + tp = TopicPartition(topic, 0) + + timeout = 10 + early_msg = kafka_producer.send( + topic, partition=0, value=b"first", + timestamp_ms=early_time).get(timeout) + late_msg = kafka_producer.send( + topic, partition=0, value=b"last", + timestamp_ms=late_time).get(timeout) + + consumer = kafka_consumer + offsets = consumer.offsets_for_times({tp: early_time}) + assert len(offsets) == 1 + assert offsets[tp].offset == early_msg.offset + assert offsets[tp].timestamp == early_time + + offsets = consumer.offsets_for_times({tp: middle_time}) + assert offsets[tp].offset == late_msg.offset + assert offsets[tp].timestamp == late_time + + offsets = consumer.offsets_for_times({tp: late_time}) + assert offsets[tp].offset == late_msg.offset + assert offsets[tp].timestamp == late_time + + offsets = consumer.offsets_for_times({}) + assert offsets == {} + + # Out of bound timestamps check + + offsets = consumer.offsets_for_times({tp: 0}) + assert offsets[tp].offset == early_msg.offset + assert offsets[tp].timestamp == early_time + + offsets = consumer.offsets_for_times({tp: 9999999999999}) + assert offsets[tp] is None + + # Beginning/End offsets + + offsets = consumer.beginning_offsets([tp]) + assert offsets == {tp: early_msg.offset} + offsets = consumer.end_offsets([tp]) + assert offsets == {tp: late_msg.offset + 1} + + +@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1") +def test_kafka_consumer_offsets_search_many_partitions(kafka_consumer, kafka_producer, topic): + tp0 = TopicPartition(topic, 0) + tp1 = TopicPartition(topic, 1) + + send_time = int(time.time() * 1000) + timeout = 10 + p0msg = kafka_producer.send( + topic, partition=0, value=b"XXX", + timestamp_ms=send_time).get(timeout) + p1msg = kafka_producer.send( + topic, partition=1, value=b"XXX", + timestamp_ms=send_time).get(timeout) + + consumer = kafka_consumer + offsets = consumer.offsets_for_times({ + tp0: send_time, + tp1: send_time + }) + + assert offsets == { + tp0: OffsetAndTimestamp(p0msg.offset, send_time), + tp1: OffsetAndTimestamp(p1msg.offset, send_time) + } + + offsets = consumer.beginning_offsets([tp0, tp1]) + assert offsets == { + tp0: p0msg.offset, + tp1: p1msg.offset + } + + offsets = consumer.end_offsets([tp0, tp1]) + assert offsets == { + tp0: p0msg.offset + 1, + tp1: p1msg.offset + 1 + } + + +@pytest.mark.skipif(env_kafka_version() >= (0, 10, 1), reason="Requires KAFKA_VERSION < 0.10.1") +def test_kafka_consumer_offsets_for_time_old(kafka_consumer, topic): + consumer = kafka_consumer + tp = TopicPartition(topic, 0) + + with pytest.raises(UnsupportedVersionError): + consumer.offsets_for_times({tp: int(time.time())}) + + +@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1") +def test_kafka_consumer_offsets_for_times_errors(kafka_consumer_factory, topic): + consumer = kafka_consumer_factory(fetch_max_wait_ms=200, + request_timeout_ms=500) + tp = 
TopicPartition(topic, 0) + bad_tp = TopicPartition(topic, 100) + + with pytest.raises(ValueError): + consumer.offsets_for_times({tp: -1}) + + assert consumer.offsets_for_times({bad_tp: 0}) == {bad_tp: None} diff --git a/test_coordinator.py b/test_coordinator.py new file mode 100644 index 00000000..a35cdd1a --- /dev/null +++ b/test_coordinator.py @@ -0,0 +1,638 @@ +# pylint: skip-file +from __future__ import absolute_import +import time + +import pytest + +from kafka.client_async import KafkaClient +from kafka.consumer.subscription_state import ( + SubscriptionState, ConsumerRebalanceListener) +from kafka.coordinator.assignors.range import RangePartitionAssignor +from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor +from kafka.coordinator.assignors.sticky.sticky_assignor import StickyPartitionAssignor +from kafka.coordinator.base import Generation, MemberState, HeartbeatThread +from kafka.coordinator.consumer import ConsumerCoordinator +from kafka.coordinator.protocol import ( + ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment) +import kafka.errors as Errors +from kafka.future import Future +from kafka.metrics import Metrics +from kafka.protocol.commit import ( + OffsetCommitRequest, OffsetCommitResponse, + OffsetFetchRequest, OffsetFetchResponse) +from kafka.protocol.metadata import MetadataResponse +from kafka.structs import OffsetAndMetadata, TopicPartition +from kafka.util import WeakMethod + + +@pytest.fixture +def client(conn): + return KafkaClient(api_version=(0, 9)) + +@pytest.fixture +def coordinator(client): + return ConsumerCoordinator(client, SubscriptionState(), Metrics()) + + +def test_init(client, coordinator): + # metadata update on init + assert client.cluster._need_update is True + assert WeakMethod(coordinator._handle_metadata_update) in client.cluster._listeners + + +@pytest.mark.parametrize("api_version", [(0, 8, 0), (0, 8, 1), (0, 8, 2), (0, 9)]) +def test_autocommit_enable_api_version(client, api_version): + coordinator = ConsumerCoordinator(client, SubscriptionState(), + Metrics(), + enable_auto_commit=True, + session_timeout_ms=30000, # session_timeout_ms and max_poll_interval_ms + max_poll_interval_ms=30000, # should be the same to avoid KafkaConfigurationError + group_id='foobar', + api_version=api_version) + if api_version < (0, 8, 1): + assert coordinator.config['enable_auto_commit'] is False + else: + assert coordinator.config['enable_auto_commit'] is True + + +def test_protocol_type(coordinator): + assert coordinator.protocol_type() == 'consumer' + + +def test_group_protocols(coordinator): + # Requires a subscription + try: + coordinator.group_protocols() + except Errors.IllegalStateError: + pass + else: + assert False, 'Exception not raised when expected' + + coordinator._subscription.subscribe(topics=['foobar']) + assert coordinator.group_protocols() == [ + ('range', ConsumerProtocolMemberMetadata( + RangePartitionAssignor.version, + ['foobar'], + b'')), + ('roundrobin', ConsumerProtocolMemberMetadata( + RoundRobinPartitionAssignor.version, + ['foobar'], + b'')), + ('sticky', ConsumerProtocolMemberMetadata( + StickyPartitionAssignor.version, + ['foobar'], + b'')), + ] + + +@pytest.mark.parametrize('api_version', [(0, 8, 0), (0, 8, 1), (0, 8, 2), (0, 9)]) +def test_pattern_subscription(coordinator, api_version): + coordinator.config['api_version'] = api_version + coordinator._subscription.subscribe(pattern='foo') + assert coordinator._subscription.subscription == set([]) + assert coordinator._metadata_snapshot == 
coordinator._build_metadata_snapshot(coordinator._subscription, {}) + + cluster = coordinator._client.cluster + cluster.update_metadata(MetadataResponse[0]( + # brokers + [(0, 'foo', 12), (1, 'bar', 34)], + # topics + [(0, 'fizz', []), + (0, 'foo1', [(0, 0, 0, [], [])]), + (0, 'foo2', [(0, 0, 1, [], [])])])) + assert coordinator._subscription.subscription == {'foo1', 'foo2'} + + # 0.9 consumers should trigger dynamic partition assignment + if api_version >= (0, 9): + assert coordinator._subscription.assignment == {} + + # earlier consumers get all partitions assigned locally + else: + assert set(coordinator._subscription.assignment.keys()) == {TopicPartition('foo1', 0), + TopicPartition('foo2', 0)} + + +def test_lookup_assignor(coordinator): + assert coordinator._lookup_assignor('roundrobin') is RoundRobinPartitionAssignor + assert coordinator._lookup_assignor('range') is RangePartitionAssignor + assert coordinator._lookup_assignor('sticky') is StickyPartitionAssignor + assert coordinator._lookup_assignor('foobar') is None + + +def test_join_complete(mocker, coordinator): + coordinator._subscription.subscribe(topics=['foobar']) + assignor = RoundRobinPartitionAssignor() + coordinator.config['assignors'] = (assignor,) + mocker.spy(assignor, 'on_assignment') + assert assignor.on_assignment.call_count == 0 + assignment = ConsumerProtocolMemberAssignment(0, [('foobar', [0, 1])], b'') + coordinator._on_join_complete(0, 'member-foo', 'roundrobin', assignment.encode()) + assert assignor.on_assignment.call_count == 1 + assignor.on_assignment.assert_called_with(assignment) + + +def test_join_complete_with_sticky_assignor(mocker, coordinator): + coordinator._subscription.subscribe(topics=['foobar']) + assignor = StickyPartitionAssignor() + coordinator.config['assignors'] = (assignor,) + mocker.spy(assignor, 'on_assignment') + mocker.spy(assignor, 'on_generation_assignment') + assert assignor.on_assignment.call_count == 0 + assert assignor.on_generation_assignment.call_count == 0 + assignment = ConsumerProtocolMemberAssignment(0, [('foobar', [0, 1])], b'') + coordinator._on_join_complete(0, 'member-foo', 'sticky', assignment.encode()) + assert assignor.on_assignment.call_count == 1 + assert assignor.on_generation_assignment.call_count == 1 + assignor.on_assignment.assert_called_with(assignment) + assignor.on_generation_assignment.assert_called_with(0) + + +def test_subscription_listener(mocker, coordinator): + listener = mocker.MagicMock(spec=ConsumerRebalanceListener) + coordinator._subscription.subscribe( + topics=['foobar'], + listener=listener) + + coordinator._on_join_prepare(0, 'member-foo') + assert listener.on_partitions_revoked.call_count == 1 + listener.on_partitions_revoked.assert_called_with(set([])) + + assignment = ConsumerProtocolMemberAssignment(0, [('foobar', [0, 1])], b'') + coordinator._on_join_complete( + 0, 'member-foo', 'roundrobin', assignment.encode()) + assert listener.on_partitions_assigned.call_count == 1 + listener.on_partitions_assigned.assert_called_with({TopicPartition('foobar', 0), TopicPartition('foobar', 1)}) + + +def test_subscription_listener_failure(mocker, coordinator): + listener = mocker.MagicMock(spec=ConsumerRebalanceListener) + coordinator._subscription.subscribe( + topics=['foobar'], + listener=listener) + + # exception raised in listener should not be re-raised by coordinator + listener.on_partitions_revoked.side_effect = Exception('crash') + coordinator._on_join_prepare(0, 'member-foo') + assert listener.on_partitions_revoked.call_count == 1 + + 
assignment = ConsumerProtocolMemberAssignment(0, [('foobar', [0, 1])], b'') + coordinator._on_join_complete( + 0, 'member-foo', 'roundrobin', assignment.encode()) + assert listener.on_partitions_assigned.call_count == 1 + + +def test_perform_assignment(mocker, coordinator): + member_metadata = { + 'member-foo': ConsumerProtocolMemberMetadata(0, ['foo1'], b''), + 'member-bar': ConsumerProtocolMemberMetadata(0, ['foo1'], b'') + } + assignments = { + 'member-foo': ConsumerProtocolMemberAssignment( + 0, [('foo1', [0])], b''), + 'member-bar': ConsumerProtocolMemberAssignment( + 0, [('foo1', [1])], b'') + } + + mocker.patch.object(RoundRobinPartitionAssignor, 'assign') + RoundRobinPartitionAssignor.assign.return_value = assignments + + ret = coordinator._perform_assignment( + 'member-foo', 'roundrobin', + [(member, metadata.encode()) + for member, metadata in member_metadata.items()]) + + assert RoundRobinPartitionAssignor.assign.call_count == 1 + RoundRobinPartitionAssignor.assign.assert_called_with( + coordinator._client.cluster, member_metadata) + assert ret == assignments + + +def test_on_join_prepare(coordinator): + coordinator._subscription.subscribe(topics=['foobar']) + coordinator._on_join_prepare(0, 'member-foo') + + +def test_need_rejoin(coordinator): + # No subscription - no rejoin + assert coordinator.need_rejoin() is False + + coordinator._subscription.subscribe(topics=['foobar']) + assert coordinator.need_rejoin() is True + + +def test_refresh_committed_offsets_if_needed(mocker, coordinator): + mocker.patch.object(ConsumerCoordinator, 'fetch_committed_offsets', + return_value = { + TopicPartition('foobar', 0): OffsetAndMetadata(123, b''), + TopicPartition('foobar', 1): OffsetAndMetadata(234, b'')}) + coordinator._subscription.assign_from_user([TopicPartition('foobar', 0)]) + assert coordinator._subscription.needs_fetch_committed_offsets is True + coordinator.refresh_committed_offsets_if_needed() + assignment = coordinator._subscription.assignment + assert assignment[TopicPartition('foobar', 0)].committed == OffsetAndMetadata(123, b'') + assert TopicPartition('foobar', 1) not in assignment + assert coordinator._subscription.needs_fetch_committed_offsets is False + + +def test_fetch_committed_offsets(mocker, coordinator): + + # No partitions, no IO polling + mocker.patch.object(coordinator._client, 'poll') + assert coordinator.fetch_committed_offsets([]) == {} + assert coordinator._client.poll.call_count == 0 + + # general case -- send offset fetch request, get successful future + mocker.patch.object(coordinator, 'ensure_coordinator_ready') + mocker.patch.object(coordinator, '_send_offset_fetch_request', + return_value=Future().success('foobar')) + partitions = [TopicPartition('foobar', 0)] + ret = coordinator.fetch_committed_offsets(partitions) + assert ret == 'foobar' + coordinator._send_offset_fetch_request.assert_called_with(partitions) + assert coordinator._client.poll.call_count == 1 + + # Failed future is raised if not retriable + coordinator._send_offset_fetch_request.return_value = Future().failure(AssertionError) + coordinator._client.poll.reset_mock() + try: + coordinator.fetch_committed_offsets(partitions) + except AssertionError: + pass + else: + assert False, 'Exception not raised when expected' + assert coordinator._client.poll.call_count == 1 + + coordinator._client.poll.reset_mock() + coordinator._send_offset_fetch_request.side_effect = [ + Future().failure(Errors.RequestTimedOutError), + Future().success('fizzbuzz')] + + ret = 
coordinator.fetch_committed_offsets(partitions) + assert ret == 'fizzbuzz' + assert coordinator._client.poll.call_count == 2 # call + retry + + +def test_close(mocker, coordinator): + mocker.patch.object(coordinator, '_maybe_auto_commit_offsets_sync') + mocker.patch.object(coordinator, '_handle_leave_group_response') + mocker.patch.object(coordinator, 'coordinator_unknown', return_value=False) + coordinator.coordinator_id = 0 + coordinator._generation = Generation(1, 'foobar', b'') + coordinator.state = MemberState.STABLE + cli = coordinator._client + mocker.patch.object(cli, 'send', return_value=Future().success('foobar')) + mocker.patch.object(cli, 'poll') + + coordinator.close() + assert coordinator._maybe_auto_commit_offsets_sync.call_count == 1 + coordinator._handle_leave_group_response.assert_called_with('foobar') + + assert coordinator.generation() is None + assert coordinator._generation is Generation.NO_GENERATION + assert coordinator.state is MemberState.UNJOINED + assert coordinator.rejoin_needed is True + + +@pytest.fixture +def offsets(): + return { + TopicPartition('foobar', 0): OffsetAndMetadata(123, b''), + TopicPartition('foobar', 1): OffsetAndMetadata(234, b''), + } + + +def test_commit_offsets_async(mocker, coordinator, offsets): + mocker.patch.object(coordinator._client, 'poll') + mocker.patch.object(coordinator, 'coordinator_unknown', return_value=False) + mocker.patch.object(coordinator, 'ensure_coordinator_ready') + mocker.patch.object(coordinator, '_send_offset_commit_request', + return_value=Future().success('fizzbuzz')) + coordinator.commit_offsets_async(offsets) + assert coordinator._send_offset_commit_request.call_count == 1 + + +def test_commit_offsets_sync(mocker, coordinator, offsets): + mocker.patch.object(coordinator, 'ensure_coordinator_ready') + mocker.patch.object(coordinator, '_send_offset_commit_request', + return_value=Future().success('fizzbuzz')) + cli = coordinator._client + mocker.patch.object(cli, 'poll') + + # No offsets, no calls + assert coordinator.commit_offsets_sync({}) is None + assert coordinator._send_offset_commit_request.call_count == 0 + assert cli.poll.call_count == 0 + + ret = coordinator.commit_offsets_sync(offsets) + assert coordinator._send_offset_commit_request.call_count == 1 + assert cli.poll.call_count == 1 + assert ret == 'fizzbuzz' + + # Failed future is raised if not retriable + coordinator._send_offset_commit_request.return_value = Future().failure(AssertionError) + coordinator._client.poll.reset_mock() + try: + coordinator.commit_offsets_sync(offsets) + except AssertionError: + pass + else: + assert False, 'Exception not raised when expected' + assert coordinator._client.poll.call_count == 1 + + coordinator._client.poll.reset_mock() + coordinator._send_offset_commit_request.side_effect = [ + Future().failure(Errors.RequestTimedOutError), + Future().success('fizzbuzz')] + + ret = coordinator.commit_offsets_sync(offsets) + assert ret == 'fizzbuzz' + assert coordinator._client.poll.call_count == 2 # call + retry + + +@pytest.mark.parametrize( + 'api_version,group_id,enable,error,has_auto_commit,commit_offsets,warn,exc', [ + ((0, 8, 0), 'foobar', True, None, False, False, True, False), + ((0, 8, 1), 'foobar', True, None, True, True, False, False), + ((0, 8, 2), 'foobar', True, None, True, True, False, False), + ((0, 9), 'foobar', False, None, False, False, False, False), + ((0, 9), 'foobar', True, Errors.UnknownMemberIdError(), True, True, True, False), + ((0, 9), 'foobar', True, Errors.IllegalGenerationError(), True, True, 
True, False), + ((0, 9), 'foobar', True, Errors.RebalanceInProgressError(), True, True, True, False), + ((0, 9), 'foobar', True, Exception(), True, True, False, True), + ((0, 9), 'foobar', True, None, True, True, False, False), + ((0, 9), None, True, None, False, False, True, False), + ]) +def test_maybe_auto_commit_offsets_sync(mocker, api_version, group_id, enable, + error, has_auto_commit, commit_offsets, + warn, exc): + mock_warn = mocker.patch('kafka.coordinator.consumer.log.warning') + mock_exc = mocker.patch('kafka.coordinator.consumer.log.exception') + client = KafkaClient(api_version=api_version) + coordinator = ConsumerCoordinator(client, SubscriptionState(), + Metrics(), + api_version=api_version, + session_timeout_ms=30000, + max_poll_interval_ms=30000, + enable_auto_commit=enable, + group_id=group_id) + commit_sync = mocker.patch.object(coordinator, 'commit_offsets_sync', + side_effect=error) + if has_auto_commit: + assert coordinator.next_auto_commit_deadline is not None + else: + assert coordinator.next_auto_commit_deadline is None + + assert coordinator._maybe_auto_commit_offsets_sync() is None + + if has_auto_commit: + assert coordinator.next_auto_commit_deadline is not None + + assert commit_sync.call_count == (1 if commit_offsets else 0) + assert mock_warn.call_count == (1 if warn else 0) + assert mock_exc.call_count == (1 if exc else 0) + + +@pytest.fixture +def patched_coord(mocker, coordinator): + coordinator._subscription.subscribe(topics=['foobar']) + mocker.patch.object(coordinator, 'coordinator_unknown', return_value=False) + coordinator.coordinator_id = 0 + mocker.patch.object(coordinator, 'coordinator', return_value=0) + coordinator._generation = Generation(0, 'foobar', b'') + coordinator.state = MemberState.STABLE + coordinator.rejoin_needed = False + mocker.patch.object(coordinator, 'need_rejoin', return_value=False) + mocker.patch.object(coordinator._client, 'least_loaded_node', + return_value=1) + mocker.patch.object(coordinator._client, 'ready', return_value=True) + mocker.patch.object(coordinator._client, 'send') + mocker.patch.object(coordinator, '_heartbeat_thread') + mocker.spy(coordinator, '_failed_request') + mocker.spy(coordinator, '_handle_offset_commit_response') + mocker.spy(coordinator, '_handle_offset_fetch_response') + return coordinator + + +def test_send_offset_commit_request_fail(mocker, patched_coord, offsets): + patched_coord.coordinator_unknown.return_value = True + patched_coord.coordinator_id = None + patched_coord.coordinator.return_value = None + + # No offsets + ret = patched_coord._send_offset_commit_request({}) + assert isinstance(ret, Future) + assert ret.succeeded() + + # No coordinator + ret = patched_coord._send_offset_commit_request(offsets) + assert ret.failed() + assert isinstance(ret.exception, Errors.GroupCoordinatorNotAvailableError) + + +@pytest.mark.parametrize('api_version,req_type', [ + ((0, 8, 1), OffsetCommitRequest[0]), + ((0, 8, 2), OffsetCommitRequest[1]), + ((0, 9), OffsetCommitRequest[2])]) +def test_send_offset_commit_request_versions(patched_coord, offsets, + api_version, req_type): + expect_node = 0 + patched_coord.config['api_version'] = api_version + + patched_coord._send_offset_commit_request(offsets) + (node, request), _ = patched_coord._client.send.call_args + assert node == expect_node, 'Unexpected coordinator node' + assert isinstance(request, req_type) + + +def test_send_offset_commit_request_failure(patched_coord, offsets): + _f = Future() + patched_coord._client.send.return_value = _f + future = 
patched_coord._send_offset_commit_request(offsets) + (node, request), _ = patched_coord._client.send.call_args + error = Exception() + _f.failure(error) + patched_coord._failed_request.assert_called_with(0, request, future, error) + assert future.failed() + assert future.exception is error + + +def test_send_offset_commit_request_success(mocker, patched_coord, offsets): + _f = Future() + patched_coord._client.send.return_value = _f + future = patched_coord._send_offset_commit_request(offsets) + (node, request), _ = patched_coord._client.send.call_args + response = OffsetCommitResponse[0]([('foobar', [(0, 0), (1, 0)])]) + _f.success(response) + patched_coord._handle_offset_commit_response.assert_called_with( + offsets, future, mocker.ANY, response) + + +@pytest.mark.parametrize('response,error,dead', [ + (OffsetCommitResponse[0]([('foobar', [(0, 30), (1, 30)])]), + Errors.GroupAuthorizationFailedError, False), + (OffsetCommitResponse[0]([('foobar', [(0, 12), (1, 12)])]), + Errors.OffsetMetadataTooLargeError, False), + (OffsetCommitResponse[0]([('foobar', [(0, 28), (1, 28)])]), + Errors.InvalidCommitOffsetSizeError, False), + (OffsetCommitResponse[0]([('foobar', [(0, 14), (1, 14)])]), + Errors.GroupLoadInProgressError, False), + (OffsetCommitResponse[0]([('foobar', [(0, 15), (1, 15)])]), + Errors.GroupCoordinatorNotAvailableError, True), + (OffsetCommitResponse[0]([('foobar', [(0, 16), (1, 16)])]), + Errors.NotCoordinatorForGroupError, True), + (OffsetCommitResponse[0]([('foobar', [(0, 7), (1, 7)])]), + Errors.RequestTimedOutError, True), + (OffsetCommitResponse[0]([('foobar', [(0, 25), (1, 25)])]), + Errors.CommitFailedError, False), + (OffsetCommitResponse[0]([('foobar', [(0, 22), (1, 22)])]), + Errors.CommitFailedError, False), + (OffsetCommitResponse[0]([('foobar', [(0, 27), (1, 27)])]), + Errors.CommitFailedError, False), + (OffsetCommitResponse[0]([('foobar', [(0, 17), (1, 17)])]), + Errors.InvalidTopicError, False), + (OffsetCommitResponse[0]([('foobar', [(0, 29), (1, 29)])]), + Errors.TopicAuthorizationFailedError, False), +]) +def test_handle_offset_commit_response(mocker, patched_coord, offsets, + response, error, dead): + future = Future() + patched_coord._handle_offset_commit_response(offsets, future, time.time(), + response) + assert isinstance(future.exception, error) + assert patched_coord.coordinator_id is (None if dead else 0) + + +@pytest.fixture +def partitions(): + return [TopicPartition('foobar', 0), TopicPartition('foobar', 1)] + + +def test_send_offset_fetch_request_fail(mocker, patched_coord, partitions): + patched_coord.coordinator_unknown.return_value = True + patched_coord.coordinator_id = None + patched_coord.coordinator.return_value = None + + # No partitions + ret = patched_coord._send_offset_fetch_request([]) + assert isinstance(ret, Future) + assert ret.succeeded() + assert ret.value == {} + + # No coordinator + ret = patched_coord._send_offset_fetch_request(partitions) + assert ret.failed() + assert isinstance(ret.exception, Errors.GroupCoordinatorNotAvailableError) + + +@pytest.mark.parametrize('api_version,req_type', [ + ((0, 8, 1), OffsetFetchRequest[0]), + ((0, 8, 2), OffsetFetchRequest[1]), + ((0, 9), OffsetFetchRequest[1])]) +def test_send_offset_fetch_request_versions(patched_coord, partitions, + api_version, req_type): + # assuming fixture sets coordinator=0, least_loaded_node=1 + expect_node = 0 + patched_coord.config['api_version'] = api_version + + patched_coord._send_offset_fetch_request(partitions) + (node, request), _ = 
patched_coord._client.send.call_args + assert node == expect_node, 'Unexpected coordinator node' + assert isinstance(request, req_type) + + +def test_send_offset_fetch_request_failure(patched_coord, partitions): + _f = Future() + patched_coord._client.send.return_value = _f + future = patched_coord._send_offset_fetch_request(partitions) + (node, request), _ = patched_coord._client.send.call_args + error = Exception() + _f.failure(error) + patched_coord._failed_request.assert_called_with(0, request, future, error) + assert future.failed() + assert future.exception is error + + +def test_send_offset_fetch_request_success(patched_coord, partitions): + _f = Future() + patched_coord._client.send.return_value = _f + future = patched_coord._send_offset_fetch_request(partitions) + (node, request), _ = patched_coord._client.send.call_args + response = OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 0), (1, 234, b'', 0)])]) + _f.success(response) + patched_coord._handle_offset_fetch_response.assert_called_with( + future, response) + + +@pytest.mark.parametrize('response,error,dead', [ + (OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 14), (1, 234, b'', 14)])]), + Errors.GroupLoadInProgressError, False), + (OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 16), (1, 234, b'', 16)])]), + Errors.NotCoordinatorForGroupError, True), + (OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 25), (1, 234, b'', 25)])]), + Errors.UnknownMemberIdError, False), + (OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 22), (1, 234, b'', 22)])]), + Errors.IllegalGenerationError, False), + (OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 29), (1, 234, b'', 29)])]), + Errors.TopicAuthorizationFailedError, False), + (OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 0), (1, 234, b'', 0)])]), + None, False), +]) +def test_handle_offset_fetch_response(patched_coord, offsets, + response, error, dead): + future = Future() + patched_coord._handle_offset_fetch_response(future, response) + if error is not None: + assert isinstance(future.exception, error) + else: + assert future.succeeded() + assert future.value == offsets + assert patched_coord.coordinator_id is (None if dead else 0) + + +def test_heartbeat(mocker, patched_coord): + heartbeat = HeartbeatThread(patched_coord) + + assert not heartbeat.enabled and not heartbeat.closed + + heartbeat.enable() + assert heartbeat.enabled + + heartbeat.disable() + assert not heartbeat.enabled + + # heartbeat disables when un-joined + heartbeat.enable() + patched_coord.state = MemberState.UNJOINED + heartbeat._run_once() + assert not heartbeat.enabled + + heartbeat.enable() + patched_coord.state = MemberState.STABLE + mocker.spy(patched_coord, '_send_heartbeat_request') + mocker.patch.object(patched_coord.heartbeat, 'should_heartbeat', return_value=True) + heartbeat._run_once() + assert patched_coord._send_heartbeat_request.call_count == 1 + + heartbeat.close() + assert heartbeat.closed + + +def test_lookup_coordinator_failure(mocker, coordinator): + + mocker.patch.object(coordinator, '_send_group_coordinator_request', + return_value=Future().failure(Exception('foobar'))) + future = coordinator.lookup_coordinator() + assert future.failed() + + +def test_ensure_active_group(mocker, coordinator): + coordinator._subscription.subscribe(topics=['foobar']) + mocker.patch.object(coordinator, 'coordinator_unknown', return_value=False) + mocker.patch.object(coordinator, '_send_join_group_request', return_value=Future().success(True)) + mocker.patch.object(coordinator, 'need_rejoin', 
side_effect=[True, False]) + mocker.patch.object(coordinator, '_on_join_complete') + mocker.patch.object(coordinator, '_heartbeat_thread') + + coordinator.ensure_active_group() + + coordinator._send_join_group_request.assert_called_once_with() diff --git a/test_fetcher.py b/test_fetcher.py new file mode 100644 index 00000000..697f8be1 --- /dev/null +++ b/test_fetcher.py @@ -0,0 +1,553 @@ +# pylint: skip-file +from __future__ import absolute_import + +import pytest + +from collections import OrderedDict +import itertools +import time + +from kafka.client_async import KafkaClient +from kafka.consumer.fetcher import ( + CompletedFetch, ConsumerRecord, Fetcher, NoOffsetForPartitionError +) +from kafka.consumer.subscription_state import SubscriptionState +from kafka.future import Future +from kafka.metrics import Metrics +from kafka.protocol.fetch import FetchRequest, FetchResponse +from kafka.protocol.offset import OffsetResponse +from kafka.errors import ( + StaleMetadata, LeaderNotAvailableError, NotLeaderForPartitionError, + UnknownTopicOrPartitionError, OffsetOutOfRangeError +) +from kafka.record.memory_records import MemoryRecordsBuilder, MemoryRecords +from kafka.structs import OffsetAndMetadata, TopicPartition + + +@pytest.fixture +def client(mocker): + return mocker.Mock(spec=KafkaClient(bootstrap_servers=(), api_version=(0, 9))) + + +@pytest.fixture +def subscription_state(): + return SubscriptionState() + + +@pytest.fixture +def topic(): + return 'foobar' + + +@pytest.fixture +def fetcher(client, subscription_state, topic): + subscription_state.subscribe(topics=[topic]) + assignment = [TopicPartition(topic, i) for i in range(3)] + subscription_state.assign_from_subscribed(assignment) + for tp in assignment: + subscription_state.seek(tp, 0) + return Fetcher(client, subscription_state, Metrics()) + + +def _build_record_batch(msgs, compression=0): + builder = MemoryRecordsBuilder( + magic=1, compression_type=0, batch_size=9999999) + for msg in msgs: + key, value, timestamp = msg + builder.append(key=key, value=value, timestamp=timestamp, headers=[]) + builder.close() + return builder.buffer() + + +def test_send_fetches(fetcher, topic, mocker): + fetch_requests = [ + FetchRequest[0]( + -1, fetcher.config['fetch_max_wait_ms'], + fetcher.config['fetch_min_bytes'], + [(topic, [ + (0, 0, fetcher.config['max_partition_fetch_bytes']), + (1, 0, fetcher.config['max_partition_fetch_bytes']), + ])]), + FetchRequest[0]( + -1, fetcher.config['fetch_max_wait_ms'], + fetcher.config['fetch_min_bytes'], + [(topic, [ + (2, 0, fetcher.config['max_partition_fetch_bytes']), + ])]) + ] + + mocker.patch.object(fetcher, '_create_fetch_requests', + return_value=dict(enumerate(fetch_requests))) + + ret = fetcher.send_fetches() + for node, request in enumerate(fetch_requests): + fetcher._client.send.assert_any_call(node, request, wakeup=False) + assert len(ret) == len(fetch_requests) + + +@pytest.mark.parametrize(("api_version", "fetch_version"), [ + ((0, 10, 1), 3), + ((0, 10, 0), 2), + ((0, 9), 1), + ((0, 8), 0) +]) +def test_create_fetch_requests(fetcher, mocker, api_version, fetch_version): + fetcher._client.in_flight_request_count.return_value = 0 + fetcher.config['api_version'] = api_version + by_node = fetcher._create_fetch_requests() + requests = by_node.values() + assert all([isinstance(r, FetchRequest[fetch_version]) for r in requests]) + + +def test_update_fetch_positions(fetcher, topic, mocker): + mocker.patch.object(fetcher, '_reset_offset') + partition = TopicPartition(topic, 0) + + # unassigned 
partition + fetcher.update_fetch_positions([TopicPartition('fizzbuzz', 0)]) + assert fetcher._reset_offset.call_count == 0 + + # fetchable partition (has offset, not paused) + fetcher.update_fetch_positions([partition]) + assert fetcher._reset_offset.call_count == 0 + + # partition needs reset, no committed offset + fetcher._subscriptions.need_offset_reset(partition) + fetcher._subscriptions.assignment[partition].awaiting_reset = False + fetcher.update_fetch_positions([partition]) + fetcher._reset_offset.assert_called_with(partition) + assert fetcher._subscriptions.assignment[partition].awaiting_reset is True + fetcher.update_fetch_positions([partition]) + fetcher._reset_offset.assert_called_with(partition) + + # partition needs reset, has committed offset + fetcher._reset_offset.reset_mock() + fetcher._subscriptions.need_offset_reset(partition) + fetcher._subscriptions.assignment[partition].awaiting_reset = False + fetcher._subscriptions.assignment[partition].committed = OffsetAndMetadata(123, b'') + mocker.patch.object(fetcher._subscriptions, 'seek') + fetcher.update_fetch_positions([partition]) + assert fetcher._reset_offset.call_count == 0 + fetcher._subscriptions.seek.assert_called_with(partition, 123) + + +def test__reset_offset(fetcher, mocker): + tp = TopicPartition("topic", 0) + fetcher._subscriptions.subscribe(topics="topic") + fetcher._subscriptions.assign_from_subscribed([tp]) + fetcher._subscriptions.need_offset_reset(tp) + mocked = mocker.patch.object(fetcher, '_retrieve_offsets') + + mocked.return_value = {tp: (1001, None)} + fetcher._reset_offset(tp) + assert not fetcher._subscriptions.assignment[tp].awaiting_reset + assert fetcher._subscriptions.assignment[tp].position == 1001 + + +def test__send_offset_requests(fetcher, mocker): + tp = TopicPartition("topic_send_offset", 1) + mocked_send = mocker.patch.object(fetcher, "_send_offset_request") + send_futures = [] + + def send_side_effect(*args, **kw): + f = Future() + send_futures.append(f) + return f + mocked_send.side_effect = send_side_effect + + mocked_leader = mocker.patch.object( + fetcher._client.cluster, "leader_for_partition") + # First we report unavailable leader 2 times different ways and later + # always as available + mocked_leader.side_effect = itertools.chain( + [None, -1], itertools.cycle([0])) + + # Leader == None + fut = fetcher._send_offset_requests({tp: 0}) + assert fut.failed() + assert isinstance(fut.exception, StaleMetadata) + assert not mocked_send.called + + # Leader == -1 + fut = fetcher._send_offset_requests({tp: 0}) + assert fut.failed() + assert isinstance(fut.exception, LeaderNotAvailableError) + assert not mocked_send.called + + # Leader == 0, send failed + fut = fetcher._send_offset_requests({tp: 0}) + assert not fut.is_done + assert mocked_send.called + # Check that we bound the futures correctly to chain failure + send_futures.pop().failure(NotLeaderForPartitionError(tp)) + assert fut.failed() + assert isinstance(fut.exception, NotLeaderForPartitionError) + + # Leader == 0, send success + fut = fetcher._send_offset_requests({tp: 0}) + assert not fut.is_done + assert mocked_send.called + # Check that we bound the futures correctly to chain success + send_futures.pop().success({tp: (10, 10000)}) + assert fut.succeeded() + assert fut.value == {tp: (10, 10000)} + + +def test__send_offset_requests_multiple_nodes(fetcher, mocker): + tp1 = TopicPartition("topic_send_offset", 1) + tp2 = TopicPartition("topic_send_offset", 2) + tp3 = TopicPartition("topic_send_offset", 3) + tp4 = 
TopicPartition("topic_send_offset", 4) + mocked_send = mocker.patch.object(fetcher, "_send_offset_request") + send_futures = [] + + def send_side_effect(node_id, timestamps): + f = Future() + send_futures.append((node_id, timestamps, f)) + return f + mocked_send.side_effect = send_side_effect + + mocked_leader = mocker.patch.object( + fetcher._client.cluster, "leader_for_partition") + mocked_leader.side_effect = itertools.cycle([0, 1]) + + # -- All node succeeded case + tss = OrderedDict([(tp1, 0), (tp2, 0), (tp3, 0), (tp4, 0)]) + fut = fetcher._send_offset_requests(tss) + assert not fut.is_done + assert mocked_send.call_count == 2 + + req_by_node = {} + second_future = None + for node, timestamps, f in send_futures: + req_by_node[node] = timestamps + if node == 0: + # Say tp3 does not have any messages so it's missing + f.success({tp1: (11, 1001)}) + else: + second_future = f + assert req_by_node == { + 0: {tp1: 0, tp3: 0}, + 1: {tp2: 0, tp4: 0} + } + + # We only resolved 1 future so far, so result future is not yet ready + assert not fut.is_done + second_future.success({tp2: (12, 1002), tp4: (14, 1004)}) + assert fut.succeeded() + assert fut.value == {tp1: (11, 1001), tp2: (12, 1002), tp4: (14, 1004)} + + # -- First succeeded second not + del send_futures[:] + fut = fetcher._send_offset_requests(tss) + assert len(send_futures) == 2 + send_futures[0][2].success({tp1: (11, 1001)}) + send_futures[1][2].failure(UnknownTopicOrPartitionError(tp1)) + assert fut.failed() + assert isinstance(fut.exception, UnknownTopicOrPartitionError) + + # -- First fails second succeeded + del send_futures[:] + fut = fetcher._send_offset_requests(tss) + assert len(send_futures) == 2 + send_futures[0][2].failure(UnknownTopicOrPartitionError(tp1)) + send_futures[1][2].success({tp1: (11, 1001)}) + assert fut.failed() + assert isinstance(fut.exception, UnknownTopicOrPartitionError) + + +def test__handle_offset_response(fetcher, mocker): + # Broker returns UnsupportedForMessageFormatError, will omit partition + fut = Future() + res = OffsetResponse[1]([ + ("topic", [(0, 43, -1, -1)]), + ("topic", [(1, 0, 1000, 9999)]) + ]) + fetcher._handle_offset_response(fut, res) + assert fut.succeeded() + assert fut.value == {TopicPartition("topic", 1): (9999, 1000)} + + # Broker returns NotLeaderForPartitionError + fut = Future() + res = OffsetResponse[1]([ + ("topic", [(0, 6, -1, -1)]), + ]) + fetcher._handle_offset_response(fut, res) + assert fut.failed() + assert isinstance(fut.exception, NotLeaderForPartitionError) + + # Broker returns UnknownTopicOrPartitionError + fut = Future() + res = OffsetResponse[1]([ + ("topic", [(0, 3, -1, -1)]), + ]) + fetcher._handle_offset_response(fut, res) + assert fut.failed() + assert isinstance(fut.exception, UnknownTopicOrPartitionError) + + # Broker returns many errors and 1 result + # Will fail on 1st error and return + fut = Future() + res = OffsetResponse[1]([ + ("topic", [(0, 43, -1, -1)]), + ("topic", [(1, 6, -1, -1)]), + ("topic", [(2, 3, -1, -1)]), + ("topic", [(3, 0, 1000, 9999)]) + ]) + fetcher._handle_offset_response(fut, res) + assert fut.failed() + assert isinstance(fut.exception, NotLeaderForPartitionError) + + +def test_fetched_records(fetcher, topic, mocker): + fetcher.config['check_crcs'] = False + tp = TopicPartition(topic, 0) + + msgs = [] + for i in range(10): + msgs.append((None, b"foo", None)) + completed_fetch = CompletedFetch( + tp, 0, 0, [0, 100, _build_record_batch(msgs)], + mocker.MagicMock() + ) + fetcher._completed_fetches.append(completed_fetch) + records, 
partial = fetcher.fetched_records() + assert tp in records + assert len(records[tp]) == len(msgs) + assert all(map(lambda x: isinstance(x, ConsumerRecord), records[tp])) + assert partial is False + + +@pytest.mark.parametrize(("fetch_request", "fetch_response", "num_partitions"), [ + ( + FetchRequest[0]( + -1, 100, 100, + [('foo', [(0, 0, 1000),])]), + FetchResponse[0]( + [("foo", [(0, 0, 1000, [(0, b'xxx'),])]),]), + 1, + ), + ( + FetchRequest[1]( + -1, 100, 100, + [('foo', [(0, 0, 1000), (1, 0, 1000),])]), + FetchResponse[1]( + 0, + [("foo", [ + (0, 0, 1000, [(0, b'xxx'),]), + (1, 0, 1000, [(0, b'xxx'),]), + ]),]), + 2, + ), + ( + FetchRequest[2]( + -1, 100, 100, + [('foo', [(0, 0, 1000),])]), + FetchResponse[2]( + 0, [("foo", [(0, 0, 1000, [(0, b'xxx'),])]),]), + 1, + ), + ( + FetchRequest[3]( + -1, 100, 100, 10000, + [('foo', [(0, 0, 1000),])]), + FetchResponse[3]( + 0, [("foo", [(0, 0, 1000, [(0, b'xxx'),])]),]), + 1, + ), + ( + FetchRequest[4]( + -1, 100, 100, 10000, 0, + [('foo', [(0, 0, 1000),])]), + FetchResponse[4]( + 0, [("foo", [(0, 0, 1000, 0, [], [(0, b'xxx'),])]),]), + 1, + ), + ( + # This may only be used in broker-broker api calls + FetchRequest[5]( + -1, 100, 100, 10000, 0, + [('foo', [(0, 0, 1000),])]), + FetchResponse[5]( + 0, [("foo", [(0, 0, 1000, 0, 0, [], [(0, b'xxx'),])]),]), + 1, + ), +]) +def test__handle_fetch_response(fetcher, fetch_request, fetch_response, num_partitions): + fetcher._handle_fetch_response(fetch_request, time.time(), fetch_response) + assert len(fetcher._completed_fetches) == num_partitions + + +def test__unpack_message_set(fetcher): + fetcher.config['check_crcs'] = False + tp = TopicPartition('foo', 0) + messages = [ + (None, b"a", None), + (None, b"b", None), + (None, b"c", None), + ] + memory_records = MemoryRecords(_build_record_batch(messages)) + records = list(fetcher._unpack_message_set(tp, memory_records)) + assert len(records) == 3 + assert all(map(lambda x: isinstance(x, ConsumerRecord), records)) + assert records[0].value == b'a' + assert records[1].value == b'b' + assert records[2].value == b'c' + assert records[0].offset == 0 + assert records[1].offset == 1 + assert records[2].offset == 2 + + +def test__message_generator(fetcher, topic, mocker): + fetcher.config['check_crcs'] = False + tp = TopicPartition(topic, 0) + msgs = [] + for i in range(10): + msgs.append((None, b"foo", None)) + completed_fetch = CompletedFetch( + tp, 0, 0, [0, 100, _build_record_batch(msgs)], + mocker.MagicMock() + ) + fetcher._completed_fetches.append(completed_fetch) + for i in range(10): + msg = next(fetcher) + assert isinstance(msg, ConsumerRecord) + assert msg.offset == i + assert msg.value == b'foo' + + +def test__parse_fetched_data(fetcher, topic, mocker): + fetcher.config['check_crcs'] = False + tp = TopicPartition(topic, 0) + msgs = [] + for i in range(10): + msgs.append((None, b"foo", None)) + completed_fetch = CompletedFetch( + tp, 0, 0, [0, 100, _build_record_batch(msgs)], + mocker.MagicMock() + ) + partition_record = fetcher._parse_fetched_data(completed_fetch) + assert isinstance(partition_record, fetcher.PartitionRecords) + assert len(partition_record) == 10 + + +def test__parse_fetched_data__paused(fetcher, topic, mocker): + fetcher.config['check_crcs'] = False + tp = TopicPartition(topic, 0) + msgs = [] + for i in range(10): + msgs.append((None, b"foo", None)) + completed_fetch = CompletedFetch( + tp, 0, 0, [0, 100, _build_record_batch(msgs)], + mocker.MagicMock() + ) + fetcher._subscriptions.pause(tp) + partition_record = 
fetcher._parse_fetched_data(completed_fetch)
+    assert partition_record is None
+
+
+def test__parse_fetched_data__stale_offset(fetcher, topic, mocker):
+    fetcher.config['check_crcs'] = False
+    tp = TopicPartition(topic, 0)
+    msgs = []
+    for i in range(10):
+        msgs.append((None, b"foo", None))
+    completed_fetch = CompletedFetch(
+        tp, 10, 0, [0, 100, _build_record_batch(msgs)],
+        mocker.MagicMock()
+    )
+    partition_record = fetcher._parse_fetched_data(completed_fetch)
+    assert partition_record is None
+
+
+def test__parse_fetched_data__not_leader(fetcher, topic, mocker):
+    fetcher.config['check_crcs'] = False
+    tp = TopicPartition(topic, 0)
+    completed_fetch = CompletedFetch(
+        tp, 0, 0, [NotLeaderForPartitionError.errno, -1, None],
+        mocker.MagicMock()
+    )
+    partition_record = fetcher._parse_fetched_data(completed_fetch)
+    assert partition_record is None
+    fetcher._client.cluster.request_update.assert_called_with()
+
+
+def test__parse_fetched_data__unknown_tp(fetcher, topic, mocker):
+    fetcher.config['check_crcs'] = False
+    tp = TopicPartition(topic, 0)
+    completed_fetch = CompletedFetch(
+        tp, 0, 0, [UnknownTopicOrPartitionError.errno, -1, None],
+        mocker.MagicMock()
+    )
+    partition_record = fetcher._parse_fetched_data(completed_fetch)
+    assert partition_record is None
+    fetcher._client.cluster.request_update.assert_called_with()
+
+
+def test__parse_fetched_data__out_of_range(fetcher, topic, mocker):
+    fetcher.config['check_crcs'] = False
+    tp = TopicPartition(topic, 0)
+    completed_fetch = CompletedFetch(
+        tp, 0, 0, [OffsetOutOfRangeError.errno, -1, None],
+        mocker.MagicMock()
+    )
+    partition_record = fetcher._parse_fetched_data(completed_fetch)
+    assert partition_record is None
+    assert fetcher._subscriptions.assignment[tp].awaiting_reset is True
+
+
+def test_partition_records_offset():
+    """Test that compressed messagesets are handled correctly
+    when fetch offset is in the middle of the message list
+    """
+    batch_start = 120
+    batch_end = 130
+    fetch_offset = 123
+    tp = TopicPartition('foo', 0)
+    messages = [ConsumerRecord(tp.topic, tp.partition, i,
+                               None, None, 'key', 'value', [], 'checksum', 0, 0, -1)
+                for i in range(batch_start, batch_end)]
+    records = Fetcher.PartitionRecords(fetch_offset, None, messages)
+    assert len(records) > 0
+    msgs = records.take(1)
+    assert msgs[0].offset == fetch_offset
+    assert records.fetch_offset == fetch_offset + 1
+    msgs = records.take(2)
+    assert len(msgs) == 2
+    assert len(records) > 0
+    records.discard()
+    assert len(records) == 0
+
+
+def test_partition_records_empty():
+    records = Fetcher.PartitionRecords(0, None, [])
+    assert len(records) == 0
+
+
+def test_partition_records_no_fetch_offset():
+    batch_start = 0
+    batch_end = 100
+    fetch_offset = 123
+    tp = TopicPartition('foo', 0)
+    messages = [ConsumerRecord(tp.topic, tp.partition, i,
+                               None, None, 'key', 'value', None, 'checksum', 0, 0, -1)
+                for i in range(batch_start, batch_end)]
+    records = Fetcher.PartitionRecords(fetch_offset, None, messages)
+    assert len(records) == 0
+
+
+def test_partition_records_compacted_offset():
+    """Test that messagesets are handled correctly
+    when the fetch offset points to a message that has been compacted
+    """
+    batch_start = 0
+    batch_end = 100
+    fetch_offset = 42
+    tp = TopicPartition('foo', 0)
+    messages = [ConsumerRecord(tp.topic, tp.partition, i,
+                               None, None, 'key', 'value', None, 'checksum', 0, 0, -1)
+                for i in range(batch_start, batch_end) if i != fetch_offset]
+    records = Fetcher.PartitionRecords(fetch_offset, None, messages)
+    assert len(records)
== batch_end - fetch_offset - 1 + msgs = records.take(1) + assert msgs[0].offset == fetch_offset + 1 diff --git a/test_metrics.py b/test_metrics.py new file mode 100644 index 00000000..308ea583 --- /dev/null +++ b/test_metrics.py @@ -0,0 +1,499 @@ +import sys +import time + +import pytest + +from kafka.errors import QuotaViolationError +from kafka.metrics import DictReporter, MetricConfig, MetricName, Metrics, Quota +from kafka.metrics.measurable import AbstractMeasurable +from kafka.metrics.stats import (Avg, Count, Max, Min, Percentile, Percentiles, + Rate, Total) +from kafka.metrics.stats.percentiles import BucketSizing +from kafka.metrics.stats.rate import TimeUnit + +EPS = 0.000001 + + +@pytest.fixture +def time_keeper(): + return TimeKeeper() + + +@pytest.fixture +def config(): + return MetricConfig() + + +@pytest.fixture +def reporter(): + return DictReporter() + + +@pytest.fixture +def metrics(request, config, reporter): + metrics = Metrics(config, [reporter], enable_expiration=True) + yield metrics + metrics.close() + + +def test_MetricName(): + # The Java test only cover the differences between the deprecated + # constructors, so I'm skipping them but doing some other basic testing. + + # In short, metrics should be equal IFF their name, group, and tags are + # the same. Descriptions do not matter. + name1 = MetricName('name', 'group', 'A metric.', {'a': 1, 'b': 2}) + name2 = MetricName('name', 'group', 'A description.', {'a': 1, 'b': 2}) + assert name1 == name2 + + name1 = MetricName('name', 'group', tags={'a': 1, 'b': 2}) + name2 = MetricName('name', 'group', tags={'a': 1, 'b': 2}) + assert name1 == name2 + + name1 = MetricName('foo', 'group') + name2 = MetricName('name', 'group') + assert name1 != name2 + + name1 = MetricName('name', 'foo') + name2 = MetricName('name', 'group') + assert name1 != name2 + + # name and group must be non-empty. Everything else is optional. + with pytest.raises(Exception): + MetricName('', 'group') + with pytest.raises(Exception): + MetricName('name', None) + # tags must be a dict if supplied + with pytest.raises(Exception): + MetricName('name', 'group', tags=set()) + + # Because of the implementation of __eq__ and __hash__, the values of + # a MetricName cannot be mutable. 
+ tags = {'a': 1} + name = MetricName('name', 'group', 'description', tags=tags) + with pytest.raises(AttributeError): + name.name = 'new name' + with pytest.raises(AttributeError): + name.group = 'new name' + with pytest.raises(AttributeError): + name.tags = {} + # tags is a copy, so the instance isn't altered + name.tags['b'] = 2 + assert name.tags == tags + + +def test_simple_stats(mocker, time_keeper, config, metrics): + mocker.patch('time.time', side_effect=time_keeper.time) + + measurable = ConstantMeasurable() + + metrics.add_metric(metrics.metric_name('direct.measurable', 'grp1', + 'The fraction of time an appender waits for space allocation.'), + measurable) + sensor = metrics.sensor('test.sensor') + sensor.add(metrics.metric_name('test.avg', 'grp1'), Avg()) + sensor.add(metrics.metric_name('test.max', 'grp1'), Max()) + sensor.add(metrics.metric_name('test.min', 'grp1'), Min()) + sensor.add(metrics.metric_name('test.rate', 'grp1'), Rate(TimeUnit.SECONDS)) + sensor.add(metrics.metric_name('test.occurences', 'grp1'),Rate(TimeUnit.SECONDS, Count())) + sensor.add(metrics.metric_name('test.count', 'grp1'), Count()) + percentiles = [Percentile(metrics.metric_name('test.median', 'grp1'), 50.0), + Percentile(metrics.metric_name('test.perc99_9', 'grp1'), 99.9)] + sensor.add_compound(Percentiles(100, BucketSizing.CONSTANT, 100, -100, + percentiles=percentiles)) + + sensor2 = metrics.sensor('test.sensor2') + sensor2.add(metrics.metric_name('s2.total', 'grp1'), Total()) + sensor2.record(5.0) + + sum_val = 0 + count = 10 + for i in range(count): + sensor.record(i) + sum_val += i + + # prior to any time passing + elapsed_secs = (config.time_window_ms * (config.samples - 1)) / 1000.0 + assert abs(count / elapsed_secs - + metrics.metrics.get(metrics.metric_name('test.occurences', 'grp1')).value()) \ + < EPS, 'Occurrences(0...%d) = %f' % (count, count / elapsed_secs) + + # pretend 2 seconds passed... 
+ sleep_time_seconds = 2.0 + time_keeper.sleep(sleep_time_seconds) + elapsed_secs += sleep_time_seconds + + assert abs(5.0 - metrics.metrics.get(metrics.metric_name('s2.total', 'grp1')).value()) \ + < EPS, 's2 reflects the constant value' + assert abs(4.5 - metrics.metrics.get(metrics.metric_name('test.avg', 'grp1')).value()) \ + < EPS, 'Avg(0...9) = 4.5' + assert abs((count - 1) - metrics.metrics.get(metrics.metric_name('test.max', 'grp1')).value()) \ + < EPS, 'Max(0...9) = 9' + assert abs(0.0 - metrics.metrics.get(metrics.metric_name('test.min', 'grp1')).value()) \ + < EPS, 'Min(0...9) = 0' + assert abs((sum_val / elapsed_secs) - metrics.metrics.get(metrics.metric_name('test.rate', 'grp1')).value()) \ + < EPS, 'Rate(0...9) = 1.40625' + assert abs((count / elapsed_secs) - metrics.metrics.get(metrics.metric_name('test.occurences', 'grp1')).value()) \ + < EPS, 'Occurrences(0...%d) = %f' % (count, count / elapsed_secs) + assert abs(count - metrics.metrics.get(metrics.metric_name('test.count', 'grp1')).value()) \ + < EPS, 'Count(0...9) = 10' + + +def test_hierarchical_sensors(metrics): + parent1 = metrics.sensor('test.parent1') + parent1.add(metrics.metric_name('test.parent1.count', 'grp1'), Count()) + parent2 = metrics.sensor('test.parent2') + parent2.add(metrics.metric_name('test.parent2.count', 'grp1'), Count()) + child1 = metrics.sensor('test.child1', parents=[parent1, parent2]) + child1.add(metrics.metric_name('test.child1.count', 'grp1'), Count()) + child2 = metrics.sensor('test.child2', parents=[parent1]) + child2.add(metrics.metric_name('test.child2.count', 'grp1'), Count()) + grandchild = metrics.sensor('test.grandchild', parents=[child1]) + grandchild.add(metrics.metric_name('test.grandchild.count', 'grp1'), Count()) + + # increment each sensor one time + parent1.record() + parent2.record() + child1.record() + child2.record() + grandchild.record() + + p1 = parent1.metrics[0].value() + p2 = parent2.metrics[0].value() + c1 = child1.metrics[0].value() + c2 = child2.metrics[0].value() + gc = grandchild.metrics[0].value() + + # each metric should have a count equal to one + its children's count + assert 1.0 == gc + assert 1.0 + gc == c1 + assert 1.0 == c2 + assert 1.0 + c1 == p2 + assert 1.0 + c1 + c2 == p1 + assert [child1, child2] == metrics._children_sensors.get(parent1) + assert [child1] == metrics._children_sensors.get(parent2) + assert metrics._children_sensors.get(grandchild) is None + + +def test_bad_sensor_hierarchy(metrics): + parent = metrics.sensor('parent') + child1 = metrics.sensor('child1', parents=[parent]) + child2 = metrics.sensor('child2', parents=[parent]) + + with pytest.raises(ValueError): + metrics.sensor('gc', parents=[child1, child2]) + + +def test_remove_sensor(metrics): + size = len(metrics.metrics) + parent1 = metrics.sensor('test.parent1') + parent1.add(metrics.metric_name('test.parent1.count', 'grp1'), Count()) + parent2 = metrics.sensor('test.parent2') + parent2.add(metrics.metric_name('test.parent2.count', 'grp1'), Count()) + child1 = metrics.sensor('test.child1', parents=[parent1, parent2]) + child1.add(metrics.metric_name('test.child1.count', 'grp1'), Count()) + child2 = metrics.sensor('test.child2', parents=[parent2]) + child2.add(metrics.metric_name('test.child2.count', 'grp1'), Count()) + grandchild1 = metrics.sensor('test.gchild2', parents=[child2]) + grandchild1.add(metrics.metric_name('test.gchild2.count', 'grp1'), Count()) + + sensor = metrics.get_sensor('test.parent1') + assert sensor is not None + metrics.remove_sensor('test.parent1') + assert 
metrics.get_sensor('test.parent1') is None + assert metrics.metrics.get(metrics.metric_name('test.parent1.count', 'grp1')) is None + assert metrics.get_sensor('test.child1') is None + assert metrics._children_sensors.get(sensor) is None + assert metrics.metrics.get(metrics.metric_name('test.child1.count', 'grp1')) is None + + sensor = metrics.get_sensor('test.gchild2') + assert sensor is not None + metrics.remove_sensor('test.gchild2') + assert metrics.get_sensor('test.gchild2') is None + assert metrics._children_sensors.get(sensor) is None + assert metrics.metrics.get(metrics.metric_name('test.gchild2.count', 'grp1')) is None + + sensor = metrics.get_sensor('test.child2') + assert sensor is not None + metrics.remove_sensor('test.child2') + assert metrics.get_sensor('test.child2') is None + assert metrics._children_sensors.get(sensor) is None + assert metrics.metrics.get(metrics.metric_name('test.child2.count', 'grp1')) is None + + sensor = metrics.get_sensor('test.parent2') + assert sensor is not None + metrics.remove_sensor('test.parent2') + assert metrics.get_sensor('test.parent2') is None + assert metrics._children_sensors.get(sensor) is None + assert metrics.metrics.get(metrics.metric_name('test.parent2.count', 'grp1')) is None + + assert size == len(metrics.metrics) + + +def test_remove_inactive_metrics(mocker, time_keeper, metrics): + mocker.patch('time.time', side_effect=time_keeper.time) + + s1 = metrics.sensor('test.s1', None, 1) + s1.add(metrics.metric_name('test.s1.count', 'grp1'), Count()) + + s2 = metrics.sensor('test.s2', None, 3) + s2.add(metrics.metric_name('test.s2.count', 'grp1'), Count()) + + purger = Metrics.ExpireSensorTask + purger.run(metrics) + assert metrics.get_sensor('test.s1') is not None, \ + 'Sensor test.s1 must be present' + assert metrics.metrics.get(metrics.metric_name('test.s1.count', 'grp1')) is not None, \ + 'MetricName test.s1.count must be present' + assert metrics.get_sensor('test.s2') is not None, \ + 'Sensor test.s2 must be present' + assert metrics.metrics.get(metrics.metric_name('test.s2.count', 'grp1')) is not None, \ + 'MetricName test.s2.count must be present' + + time_keeper.sleep(1.001) + purger.run(metrics) + assert metrics.get_sensor('test.s1') is None, \ + 'Sensor test.s1 should have been purged' + assert metrics.metrics.get(metrics.metric_name('test.s1.count', 'grp1')) is None, \ + 'MetricName test.s1.count should have been purged' + assert metrics.get_sensor('test.s2') is not None, \ + 'Sensor test.s2 must be present' + assert metrics.metrics.get(metrics.metric_name('test.s2.count', 'grp1')) is not None, \ + 'MetricName test.s2.count must be present' + + # record a value in sensor s2. This should reset the clock for that sensor. 
+ # It should not get purged at the 3 second mark after creation + s2.record() + + time_keeper.sleep(2) + purger.run(metrics) + assert metrics.get_sensor('test.s2') is not None, \ + 'Sensor test.s2 must be present' + assert metrics.metrics.get(metrics.metric_name('test.s2.count', 'grp1')) is not None, \ + 'MetricName test.s2.count must be present' + + # After another 1 second sleep, the metric should be purged + time_keeper.sleep(1) + purger.run(metrics) + assert metrics.get_sensor('test.s1') is None, \ + 'Sensor test.s2 should have been purged' + assert metrics.metrics.get(metrics.metric_name('test.s1.count', 'grp1')) is None, \ + 'MetricName test.s2.count should have been purged' + + # After purging, it should be possible to recreate a metric + s1 = metrics.sensor('test.s1', None, 1) + s1.add(metrics.metric_name('test.s1.count', 'grp1'), Count()) + assert metrics.get_sensor('test.s1') is not None, \ + 'Sensor test.s1 must be present' + assert metrics.metrics.get(metrics.metric_name('test.s1.count', 'grp1')) is not None, \ + 'MetricName test.s1.count must be present' + + +def test_remove_metric(metrics): + size = len(metrics.metrics) + metrics.add_metric(metrics.metric_name('test1', 'grp1'), Count()) + metrics.add_metric(metrics.metric_name('test2', 'grp1'), Count()) + + assert metrics.remove_metric(metrics.metric_name('test1', 'grp1')) is not None + assert metrics.metrics.get(metrics.metric_name('test1', 'grp1')) is None + assert metrics.metrics.get(metrics.metric_name('test2', 'grp1')) is not None + + assert metrics.remove_metric(metrics.metric_name('test2', 'grp1')) is not None + assert metrics.metrics.get(metrics.metric_name('test2', 'grp1')) is None + + assert size == len(metrics.metrics) + + +def test_event_windowing(mocker, time_keeper): + mocker.patch('time.time', side_effect=time_keeper.time) + + count = Count() + config = MetricConfig(event_window=1, samples=2) + count.record(config, 1.0, time_keeper.ms()) + count.record(config, 1.0, time_keeper.ms()) + assert 2.0 == count.measure(config, time_keeper.ms()) + count.record(config, 1.0, time_keeper.ms()) # first event times out + assert 2.0 == count.measure(config, time_keeper.ms()) + + +def test_time_windowing(mocker, time_keeper): + mocker.patch('time.time', side_effect=time_keeper.time) + + count = Count() + config = MetricConfig(time_window_ms=1, samples=2) + count.record(config, 1.0, time_keeper.ms()) + time_keeper.sleep(.001) + count.record(config, 1.0, time_keeper.ms()) + assert 2.0 == count.measure(config, time_keeper.ms()) + time_keeper.sleep(.001) + count.record(config, 1.0, time_keeper.ms()) # oldest event times out + assert 2.0 == count.measure(config, time_keeper.ms()) + + +def test_old_data_has_no_effect(mocker, time_keeper): + mocker.patch('time.time', side_effect=time_keeper.time) + + max_stat = Max() + min_stat = Min() + avg_stat = Avg() + count_stat = Count() + window_ms = 100 + samples = 2 + config = MetricConfig(time_window_ms=window_ms, samples=samples) + max_stat.record(config, 50, time_keeper.ms()) + min_stat.record(config, 50, time_keeper.ms()) + avg_stat.record(config, 50, time_keeper.ms()) + count_stat.record(config, 50, time_keeper.ms()) + + time_keeper.sleep(samples * window_ms / 1000.0) + assert float('-inf') == max_stat.measure(config, time_keeper.ms()) + assert float(sys.maxsize) == min_stat.measure(config, time_keeper.ms()) + assert 0.0 == avg_stat.measure(config, time_keeper.ms()) + assert 0 == count_stat.measure(config, time_keeper.ms()) + + +def test_duplicate_MetricName(metrics): + 
metrics.sensor('test').add(metrics.metric_name('test', 'grp1'), Avg()) + with pytest.raises(ValueError): + metrics.sensor('test2').add(metrics.metric_name('test', 'grp1'), Total()) + + +def test_Quotas(metrics): + sensor = metrics.sensor('test') + sensor.add(metrics.metric_name('test1.total', 'grp1'), Total(), + MetricConfig(quota=Quota.upper_bound(5.0))) + sensor.add(metrics.metric_name('test2.total', 'grp1'), Total(), + MetricConfig(quota=Quota.lower_bound(0.0))) + sensor.record(5.0) + with pytest.raises(QuotaViolationError): + sensor.record(1.0) + + assert abs(6.0 - metrics.metrics.get(metrics.metric_name('test1.total', 'grp1')).value()) \ + < EPS + + sensor.record(-6.0) + with pytest.raises(QuotaViolationError): + sensor.record(-1.0) + + +def test_Quotas_equality(): + quota1 = Quota.upper_bound(10.5) + quota2 = Quota.lower_bound(10.5) + assert quota1 != quota2, 'Quota with different upper values should not be equal' + + quota3 = Quota.lower_bound(10.5) + assert quota2 == quota3, 'Quota with same upper and bound values should be equal' + + +def test_Percentiles(metrics): + buckets = 100 + _percentiles = [ + Percentile(metrics.metric_name('test.p25', 'grp1'), 25), + Percentile(metrics.metric_name('test.p50', 'grp1'), 50), + Percentile(metrics.metric_name('test.p75', 'grp1'), 75), + ] + percs = Percentiles(4 * buckets, BucketSizing.CONSTANT, 100.0, 0.0, + percentiles=_percentiles) + config = MetricConfig(event_window=50, samples=2) + sensor = metrics.sensor('test', config) + sensor.add_compound(percs) + p25 = metrics.metrics.get(metrics.metric_name('test.p25', 'grp1')) + p50 = metrics.metrics.get(metrics.metric_name('test.p50', 'grp1')) + p75 = metrics.metrics.get(metrics.metric_name('test.p75', 'grp1')) + + # record two windows worth of sequential values + for i in range(buckets): + sensor.record(i) + + assert abs(p25.value() - 25) < 1.0 + assert abs(p50.value() - 50) < 1.0 + assert abs(p75.value() - 75) < 1.0 + + for i in range(buckets): + sensor.record(0.0) + + assert p25.value() < 1.0 + assert p50.value() < 1.0 + assert p75.value() < 1.0 + +def test_rate_windowing(mocker, time_keeper, metrics): + mocker.patch('time.time', side_effect=time_keeper.time) + + # Use the default time window. Set 3 samples + config = MetricConfig(samples=3) + sensor = metrics.sensor('test.sensor', config) + sensor.add(metrics.metric_name('test.rate', 'grp1'), Rate(TimeUnit.SECONDS)) + + sum_val = 0 + count = config.samples - 1 + # Advance 1 window after every record + for i in range(count): + sensor.record(100) + sum_val += 100 + time_keeper.sleep(config.time_window_ms / 1000.0) + + # Sleep for half the window. 
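    # (MetricConfig's default time window is 30 s; with samples=3 the loop
    #  above recorded 100 twice while advancing two full windows, so after the
    #  half-window sleep below the elapsed time is 2.5 windows = 75 s and the
    #  expected rate is 200 / 75, roughly 2.666, matching the assertions below.)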
+ time_keeper.sleep(config.time_window_ms / 2.0 / 1000.0) + + # prior to any time passing + elapsed_secs = (config.time_window_ms * (config.samples - 1) + config.time_window_ms / 2.0) / 1000.0 + + kafka_metric = metrics.metrics.get(metrics.metric_name('test.rate', 'grp1')) + assert abs((sum_val / elapsed_secs) - kafka_metric.value()) < EPS, \ + 'Rate(0...2) = 2.666' + assert abs(elapsed_secs - (kafka_metric.measurable.window_size(config, time.time() * 1000) / 1000.0)) \ + < EPS, 'Elapsed Time = 75 seconds' + + +def test_reporter(metrics): + reporter = DictReporter() + foo_reporter = DictReporter(prefix='foo') + metrics.add_reporter(reporter) + metrics.add_reporter(foo_reporter) + sensor = metrics.sensor('kafka.requests') + sensor.add(metrics.metric_name('pack.bean1.avg', 'grp1'), Avg()) + sensor.add(metrics.metric_name('pack.bean2.total', 'grp2'), Total()) + sensor2 = metrics.sensor('kafka.blah') + sensor2.add(metrics.metric_name('pack.bean1.some', 'grp1'), Total()) + sensor2.add(metrics.metric_name('pack.bean2.some', 'grp1', + tags={'a': 42, 'b': 'bar'}), Total()) + + # kafka-metrics-count > count is the total number of metrics and automatic + expected = { + 'kafka-metrics-count': {'count': 5.0}, + 'grp2': {'pack.bean2.total': 0.0}, + 'grp1': {'pack.bean1.avg': 0.0, 'pack.bean1.some': 0.0}, + 'grp1.a=42,b=bar': {'pack.bean2.some': 0.0}, + } + assert expected == reporter.snapshot() + + for key in list(expected.keys()): + metrics = expected.pop(key) + expected['foo.%s' % (key,)] = metrics + assert expected == foo_reporter.snapshot() + + +class ConstantMeasurable(AbstractMeasurable): + _value = 0.0 + + def measure(self, config, now): + return self._value + + +class TimeKeeper(object): + """ + A clock that you can manually advance by calling sleep + """ + def __init__(self, auto_tick_ms=0): + self._millis = time.time() * 1000 + self._auto_tick_ms = auto_tick_ms + + def time(self): + return self.ms() / 1000.0 + + def ms(self): + self.sleep(self._auto_tick_ms) + return self._millis + + def sleep(self, seconds): + self._millis += (seconds * 1000) diff --git a/test_object_conversion.py b/test_object_conversion.py new file mode 100644 index 00000000..9b1ff213 --- /dev/null +++ b/test_object_conversion.py @@ -0,0 +1,236 @@ +from kafka.protocol.admin import Request +from kafka.protocol.admin import Response +from kafka.protocol.types import Schema +from kafka.protocol.types import Array +from kafka.protocol.types import Int16 +from kafka.protocol.types import String + +import pytest + +@pytest.mark.parametrize('superclass', (Request, Response)) +class TestObjectConversion: + def test_get_item(self, superclass): + class TestClass(superclass): + API_KEY = 0 + API_VERSION = 0 + RESPONSE_TYPE = None # To satisfy the Request ABC + SCHEMA = Schema( + ('myobject', Int16)) + + tc = TestClass(myobject=0) + assert tc.get_item('myobject') == 0 + with pytest.raises(KeyError): + tc.get_item('does-not-exist') + + def test_with_empty_schema(self, superclass): + class TestClass(superclass): + API_KEY = 0 + API_VERSION = 0 + RESPONSE_TYPE = None # To satisfy the Request ABC + SCHEMA = Schema() + + tc = TestClass() + tc.encode() + assert tc.to_object() == {} + + def test_with_basic_schema(self, superclass): + class TestClass(superclass): + API_KEY = 0 + API_VERSION = 0 + RESPONSE_TYPE = None # To satisfy the Request ABC + SCHEMA = Schema( + ('myobject', Int16)) + + tc = TestClass(myobject=0) + tc.encode() + assert tc.to_object() == {'myobject': 0} + + def test_with_basic_array_schema(self, superclass): + class 
TestClass(superclass): + API_KEY = 0 + API_VERSION = 0 + RESPONSE_TYPE = None # To satisfy the Request ABC + SCHEMA = Schema( + ('myarray', Array(Int16))) + + tc = TestClass(myarray=[1,2,3]) + tc.encode() + assert tc.to_object()['myarray'] == [1, 2, 3] + + def test_with_complex_array_schema(self, superclass): + class TestClass(superclass): + API_KEY = 0 + API_VERSION = 0 + RESPONSE_TYPE = None # To satisfy the Request ABC + SCHEMA = Schema( + ('myarray', Array( + ('subobject', Int16), + ('othersubobject', String('utf-8'))))) + + tc = TestClass( + myarray=[[10, 'hello']] + ) + tc.encode() + obj = tc.to_object() + assert len(obj['myarray']) == 1 + assert obj['myarray'][0]['subobject'] == 10 + assert obj['myarray'][0]['othersubobject'] == 'hello' + + def test_with_array_and_other(self, superclass): + class TestClass(superclass): + API_KEY = 0 + API_VERSION = 0 + RESPONSE_TYPE = None # To satisfy the Request ABC + SCHEMA = Schema( + ('myarray', Array( + ('subobject', Int16), + ('othersubobject', String('utf-8')))), + ('notarray', Int16)) + + tc = TestClass( + myarray=[[10, 'hello']], + notarray=42 + ) + + obj = tc.to_object() + assert len(obj['myarray']) == 1 + assert obj['myarray'][0]['subobject'] == 10 + assert obj['myarray'][0]['othersubobject'] == 'hello' + assert obj['notarray'] == 42 + + def test_with_nested_array(self, superclass): + class TestClass(superclass): + API_KEY = 0 + API_VERSION = 0 + RESPONSE_TYPE = None # To satisfy the Request ABC + SCHEMA = Schema( + ('myarray', Array( + ('subarray', Array(Int16)), + ('otherobject', Int16)))) + + tc = TestClass( + myarray=[ + [[1, 2], 2], + [[2, 3], 4], + ] + ) + print(tc.encode()) + + + obj = tc.to_object() + assert len(obj['myarray']) == 2 + assert obj['myarray'][0]['subarray'] == [1, 2] + assert obj['myarray'][0]['otherobject'] == 2 + assert obj['myarray'][1]['subarray'] == [2, 3] + assert obj['myarray'][1]['otherobject'] == 4 + + def test_with_complex_nested_array(self, superclass): + class TestClass(superclass): + API_KEY = 0 + API_VERSION = 0 + RESPONSE_TYPE = None # To satisfy the Request ABC + SCHEMA = Schema( + ('myarray', Array( + ('subarray', Array( + ('innertest', String('utf-8')), + ('otherinnertest', String('utf-8')))), + ('othersubarray', Array(Int16)))), + ('notarray', String('utf-8'))) + + tc = TestClass( + myarray=[ + [[['hello', 'hello'], ['hello again', 'hello again']], [0]], + [[['hello', 'hello again']], [1]], + ], + notarray='notarray' + ) + tc.encode() + + obj = tc.to_object() + + assert obj['notarray'] == 'notarray' + myarray = obj['myarray'] + assert len(myarray) == 2 + + assert myarray[0]['othersubarray'] == [0] + assert len(myarray[0]['subarray']) == 2 + assert myarray[0]['subarray'][0]['innertest'] == 'hello' + assert myarray[0]['subarray'][0]['otherinnertest'] == 'hello' + assert myarray[0]['subarray'][1]['innertest'] == 'hello again' + assert myarray[0]['subarray'][1]['otherinnertest'] == 'hello again' + + assert myarray[1]['othersubarray'] == [1] + assert len(myarray[1]['subarray']) == 1 + assert myarray[1]['subarray'][0]['innertest'] == 'hello' + assert myarray[1]['subarray'][0]['otherinnertest'] == 'hello again' + +def test_with_metadata_response(): + from kafka.protocol.metadata import MetadataResponse_v5 + tc = MetadataResponse_v5( + throttle_time_ms=0, + brokers=[ + [0, 'testhost0', 9092, 'testrack0'], + [1, 'testhost1', 9092, 'testrack1'], + ], + cluster_id='abcd', + controller_id=0, + topics=[ + [0, 'testtopic1', False, [ + [0, 0, 0, [0, 1], [0, 1], []], + [0, 1, 1, [1, 0], [1, 0], []], + ], + ], [0, 
'other-test-topic', True, [ + [0, 0, 0, [0, 1], [0, 1], []], + ] + ]] + ) + tc.encode() # Make sure this object encodes successfully + + + obj = tc.to_object() + + assert obj['throttle_time_ms'] == 0 + + assert len(obj['brokers']) == 2 + assert obj['brokers'][0]['node_id'] == 0 + assert obj['brokers'][0]['host'] == 'testhost0' + assert obj['brokers'][0]['port'] == 9092 + assert obj['brokers'][0]['rack'] == 'testrack0' + assert obj['brokers'][1]['node_id'] == 1 + assert obj['brokers'][1]['host'] == 'testhost1' + assert obj['brokers'][1]['port'] == 9092 + assert obj['brokers'][1]['rack'] == 'testrack1' + + assert obj['cluster_id'] == 'abcd' + assert obj['controller_id'] == 0 + + assert len(obj['topics']) == 2 + assert obj['topics'][0]['error_code'] == 0 + assert obj['topics'][0]['topic'] == 'testtopic1' + assert obj['topics'][0]['is_internal'] == False + assert len(obj['topics'][0]['partitions']) == 2 + assert obj['topics'][0]['partitions'][0]['error_code'] == 0 + assert obj['topics'][0]['partitions'][0]['partition'] == 0 + assert obj['topics'][0]['partitions'][0]['leader'] == 0 + assert obj['topics'][0]['partitions'][0]['replicas'] == [0, 1] + assert obj['topics'][0]['partitions'][0]['isr'] == [0, 1] + assert obj['topics'][0]['partitions'][0]['offline_replicas'] == [] + assert obj['topics'][0]['partitions'][1]['error_code'] == 0 + assert obj['topics'][0]['partitions'][1]['partition'] == 1 + assert obj['topics'][0]['partitions'][1]['leader'] == 1 + assert obj['topics'][0]['partitions'][1]['replicas'] == [1, 0] + assert obj['topics'][0]['partitions'][1]['isr'] == [1, 0] + assert obj['topics'][0]['partitions'][1]['offline_replicas'] == [] + + assert obj['topics'][1]['error_code'] == 0 + assert obj['topics'][1]['topic'] == 'other-test-topic' + assert obj['topics'][1]['is_internal'] == True + assert len(obj['topics'][1]['partitions']) == 1 + assert obj['topics'][1]['partitions'][0]['error_code'] == 0 + assert obj['topics'][1]['partitions'][0]['partition'] == 0 + assert obj['topics'][1]['partitions'][0]['leader'] == 0 + assert obj['topics'][1]['partitions'][0]['replicas'] == [0, 1] + assert obj['topics'][1]['partitions'][0]['isr'] == [0, 1] + assert obj['topics'][1]['partitions'][0]['offline_replicas'] == [] + + tc.encode() diff --git a/test_package.py b/test_package.py new file mode 100644 index 00000000..aa42c9ce --- /dev/null +++ b/test_package.py @@ -0,0 +1,25 @@ +class TestPackage: + def test_top_level_namespace(self): + import kafka as kafka1 + assert kafka1.KafkaConsumer.__name__ == "KafkaConsumer" + assert kafka1.consumer.__name__ == "kafka.consumer" + assert kafka1.codec.__name__ == "kafka.codec" + + def test_submodule_namespace(self): + import kafka.client_async as client1 + assert client1.__name__ == "kafka.client_async" + + from kafka import client_async as client2 + assert client2.__name__ == "kafka.client_async" + + from kafka.client_async import KafkaClient as KafkaClient1 + assert KafkaClient1.__name__ == "KafkaClient" + + from kafka import KafkaClient as KafkaClient2 + assert KafkaClient2.__name__ == "KafkaClient" + + from kafka.codec import gzip_encode as gzip_encode1 + assert gzip_encode1.__name__ == "gzip_encode" + + from kafka.codec import snappy_encode + assert snappy_encode.__name__ == "snappy_encode" diff --git a/test_partition_movements.py b/test_partition_movements.py new file mode 100644 index 00000000..bc990bf3 --- /dev/null +++ b/test_partition_movements.py @@ -0,0 +1,23 @@ +from kafka.structs import TopicPartition + +from 
kafka.coordinator.assignors.sticky.partition_movements import PartitionMovements + + +def test_empty_movements_are_sticky(): + partition_movements = PartitionMovements() + assert partition_movements.are_sticky() + + +def test_sticky_movements(): + partition_movements = PartitionMovements() + partition_movements.move_partition(TopicPartition('t', 1), 'C1', 'C2') + partition_movements.move_partition(TopicPartition('t', 1), 'C2', 'C3') + partition_movements.move_partition(TopicPartition('t', 1), 'C3', 'C1') + assert partition_movements.are_sticky() + + +def test_should_detect_non_sticky_assignment(): + partition_movements = PartitionMovements() + partition_movements.move_partition(TopicPartition('t', 1), 'C1', 'C2') + partition_movements.move_partition(TopicPartition('t', 2), 'C2', 'C1') + assert not partition_movements.are_sticky() diff --git a/test_partitioner.py b/test_partitioner.py new file mode 100644 index 00000000..853fbf69 --- /dev/null +++ b/test_partitioner.py @@ -0,0 +1,38 @@ +from __future__ import absolute_import + +import pytest + +from kafka.partitioner import DefaultPartitioner, murmur2 + + +def test_default_partitioner(): + partitioner = DefaultPartitioner() + all_partitions = available = list(range(100)) + # partitioner should return the same partition for the same key + p1 = partitioner(b'foo', all_partitions, available) + p2 = partitioner(b'foo', all_partitions, available) + assert p1 == p2 + assert p1 in all_partitions + + # when key is None, choose one of available partitions + assert partitioner(None, all_partitions, [123]) == 123 + + # with fallback to all_partitions + assert partitioner(None, all_partitions, []) in all_partitions + + +@pytest.mark.parametrize("bytes_payload,partition_number", [ + (b'', 681), (b'a', 524), (b'ab', 434), (b'abc', 107), (b'123456789', 566), + (b'\x00 ', 742) +]) +def test_murmur2_java_compatibility(bytes_payload, partition_number): + partitioner = DefaultPartitioner() + all_partitions = available = list(range(1000)) + # compare with output from Kafka's org.apache.kafka.clients.producer.Partitioner + assert partitioner(bytes_payload, all_partitions, available) == partition_number + + +def test_murmur2_not_ascii(): + # Verify no regression of murmur2() bug encoding py2 bytes that don't ascii encode + murmur2(b'\xa4') + murmur2(b'\x81' * 1000) diff --git a/test_producer.py b/test_producer.py new file mode 100644 index 00000000..7263130d --- /dev/null +++ b/test_producer.py @@ -0,0 +1,137 @@ +import gc +import platform +import time +import threading + +import pytest + +from kafka import KafkaConsumer, KafkaProducer, TopicPartition +from kafka.producer.buffer import SimpleBufferPool +from test.testutil import env_kafka_version, random_string + + +def test_buffer_pool(): + pool = SimpleBufferPool(1000, 1000) + + buf1 = pool.allocate(1000, 1000) + message = ''.join(map(str, range(100))) + buf1.write(message.encode('utf-8')) + pool.deallocate(buf1) + + buf2 = pool.allocate(1000, 1000) + assert buf2.read() == b'' + + +@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") +@pytest.mark.parametrize("compression", [None, 'gzip', 'snappy', 'lz4', 'zstd']) +def test_end_to_end(kafka_broker, compression): + if compression == 'lz4': + if env_kafka_version() < (0, 8, 2): + pytest.skip('LZ4 requires 0.8.2') + elif platform.python_implementation() == 'PyPy': + pytest.skip('python-lz4 crashes on older versions of pypy') + + if compression == 'zstd' and env_kafka_version() < (2, 1, 0): + pytest.skip('zstd requires kafka 2.1.0 or newer') 
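    # env_kafka_version() (defined in testutil further down) parses
    # KAFKA_VERSION into a tuple of ints and returns () when it is unset; an
    # empty tuple is falsy, so the skipif above fires, and it also sorts below
    # any real version, e.g. () < (2, 1, 0) is True.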
+ + connect_str = ':'.join([kafka_broker.host, str(kafka_broker.port)]) + producer = KafkaProducer(bootstrap_servers=connect_str, + retries=5, + max_block_ms=30000, + compression_type=compression, + value_serializer=str.encode) + consumer = KafkaConsumer(bootstrap_servers=connect_str, + group_id=None, + consumer_timeout_ms=30000, + auto_offset_reset='earliest', + value_deserializer=bytes.decode) + + topic = random_string(5) + + messages = 100 + futures = [] + for i in range(messages): + futures.append(producer.send(topic, 'msg %d' % i)) + ret = [f.get(timeout=30) for f in futures] + assert len(ret) == messages + producer.close() + + consumer.subscribe([topic]) + msgs = set() + for i in range(messages): + try: + msgs.add(next(consumer).value) + except StopIteration: + break + + assert msgs == set(['msg %d' % (i,) for i in range(messages)]) + consumer.close() + + +@pytest.mark.skipif(platform.python_implementation() != 'CPython', + reason='Test relies on CPython-specific gc policies') +def test_kafka_producer_gc_cleanup(): + gc.collect() + threads = threading.active_count() + producer = KafkaProducer(api_version='0.9') # set api_version explicitly to avoid auto-detection + assert threading.active_count() == threads + 1 + del(producer) + gc.collect() + assert threading.active_count() == threads + + +@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") +@pytest.mark.parametrize("compression", [None, 'gzip', 'snappy', 'lz4', 'zstd']) +def test_kafka_producer_proper_record_metadata(kafka_broker, compression): + if compression == 'zstd' and env_kafka_version() < (2, 1, 0): + pytest.skip('zstd requires 2.1.0 or more') + connect_str = ':'.join([kafka_broker.host, str(kafka_broker.port)]) + producer = KafkaProducer(bootstrap_servers=connect_str, + retries=5, + max_block_ms=30000, + compression_type=compression) + magic = producer._max_usable_produce_magic() + + # record headers are supported in 0.11.0 + if env_kafka_version() < (0, 11, 0): + headers = None + else: + headers = [("Header Key", b"Header Value")] + + topic = random_string(5) + future = producer.send( + topic, + value=b"Simple value", key=b"Simple key", headers=headers, timestamp_ms=9999999, + partition=0) + record = future.get(timeout=5) + assert record is not None + assert record.topic == topic + assert record.partition == 0 + assert record.topic_partition == TopicPartition(topic, 0) + assert record.offset == 0 + if magic >= 1: + assert record.timestamp == 9999999 + else: + assert record.timestamp == -1 # NO_TIMESTAMP + + if magic >= 2: + assert record.checksum is None + elif magic == 1: + assert record.checksum == 1370034956 + else: + assert record.checksum == 3296137851 + + assert record.serialized_key_size == 10 + assert record.serialized_value_size == 12 + if headers: + assert record.serialized_header_size == 22 + + if magic == 0: + pytest.skip('generated timestamp case is skipped for broker 0.9 and below') + send_time = time.time() * 1000 + future = producer.send( + topic, + value=b"Simple value", key=b"Simple key", timestamp_ms=None, + partition=0) + record = future.get(timeout=5) + assert abs(record.timestamp - send_time) <= 1000 # Allow 1s deviation diff --git a/test_protocol.py b/test_protocol.py new file mode 100644 index 00000000..6a77e19d --- /dev/null +++ b/test_protocol.py @@ -0,0 +1,336 @@ +#pylint: skip-file +import io +import struct + +import pytest + +from kafka.protocol.api import RequestHeader +from kafka.protocol.commit import GroupCoordinatorRequest +from kafka.protocol.fetch import 
FetchRequest, FetchResponse +from kafka.protocol.message import Message, MessageSet, PartialMessage +from kafka.protocol.metadata import MetadataRequest +from kafka.protocol.types import Int16, Int32, Int64, String, UnsignedVarInt32, CompactString, CompactArray, CompactBytes + + +def test_create_message(): + payload = b'test' + key = b'key' + msg = Message(payload, key=key) + assert msg.magic == 0 + assert msg.attributes == 0 + assert msg.key == key + assert msg.value == payload + + +def test_encode_message_v0(): + message = Message(b'test', key=b'key') + encoded = message.encode() + expect = b''.join([ + struct.pack('>i', -1427009701), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 3), # Length of key + b'key', # key + struct.pack('>i', 4), # Length of value + b'test', # value + ]) + assert encoded == expect + + +def test_encode_message_v1(): + message = Message(b'test', key=b'key', magic=1, timestamp=1234) + encoded = message.encode() + expect = b''.join([ + struct.pack('>i', 1331087195), # CRC + struct.pack('>bb', 1, 0), # Magic, flags + struct.pack('>q', 1234), # Timestamp + struct.pack('>i', 3), # Length of key + b'key', # key + struct.pack('>i', 4), # Length of value + b'test', # value + ]) + assert encoded == expect + + +def test_decode_message(): + encoded = b''.join([ + struct.pack('>i', -1427009701), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 3), # Length of key + b'key', # key + struct.pack('>i', 4), # Length of value + b'test', # value + ]) + decoded_message = Message.decode(encoded) + msg = Message(b'test', key=b'key') + msg.encode() # crc is recalculated during encoding + assert decoded_message == msg + + +def test_decode_message_validate_crc(): + encoded = b''.join([ + struct.pack('>i', -1427009701), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 3), # Length of key + b'key', # key + struct.pack('>i', 4), # Length of value + b'test', # value + ]) + decoded_message = Message.decode(encoded) + assert decoded_message.validate_crc() is True + + encoded = b''.join([ + struct.pack('>i', 1234), # Incorrect CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 3), # Length of key + b'key', # key + struct.pack('>i', 4), # Length of value + b'test', # value + ]) + decoded_message = Message.decode(encoded) + assert decoded_message.validate_crc() is False + + +def test_encode_message_set(): + messages = [ + Message(b'v1', key=b'k1'), + Message(b'v2', key=b'k2') + ] + encoded = MessageSet.encode([(0, msg.encode()) + for msg in messages]) + expect = b''.join([ + struct.pack('>q', 0), # MsgSet Offset + struct.pack('>i', 18), # Msg Size + struct.pack('>i', 1474775406), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k1', # Key + struct.pack('>i', 2), # Length of value + b'v1', # Value + + struct.pack('>q', 0), # MsgSet Offset + struct.pack('>i', 18), # Msg Size + struct.pack('>i', -16383415), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k2', # Key + struct.pack('>i', 2), # Length of value + b'v2', # Value + ]) + expect = struct.pack('>i', len(expect)) + expect + assert encoded == expect + + +def test_decode_message_set(): + encoded = b''.join([ + struct.pack('>q', 0), # MsgSet Offset + struct.pack('>i', 18), # Msg Size + struct.pack('>i', 1474775406), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k1', # Key + struct.pack('>i', 2), # Length of value + b'v1', # Value + + 
struct.pack('>q', 1), # MsgSet Offset + struct.pack('>i', 18), # Msg Size + struct.pack('>i', -16383415), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k2', # Key + struct.pack('>i', 2), # Length of value + b'v2', # Value + ]) + + msgs = MessageSet.decode(encoded, bytes_to_read=len(encoded)) + assert len(msgs) == 2 + msg1, msg2 = msgs + + returned_offset1, message1_size, decoded_message1 = msg1 + returned_offset2, message2_size, decoded_message2 = msg2 + + assert returned_offset1 == 0 + message1 = Message(b'v1', key=b'k1') + message1.encode() + assert decoded_message1 == message1 + + assert returned_offset2 == 1 + message2 = Message(b'v2', key=b'k2') + message2.encode() + assert decoded_message2 == message2 + + +def test_encode_message_header(): + expect = b''.join([ + struct.pack('>h', 10), # API Key + struct.pack('>h', 0), # API Version + struct.pack('>i', 4), # Correlation Id + struct.pack('>h', len('client3')), # Length of clientId + b'client3', # ClientId + ]) + + req = GroupCoordinatorRequest[0]('foo') + header = RequestHeader(req, correlation_id=4, client_id='client3') + assert header.encode() == expect + + +def test_decode_message_set_partial(): + encoded = b''.join([ + struct.pack('>q', 0), # Msg Offset + struct.pack('>i', 18), # Msg Size + struct.pack('>i', 1474775406), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k1', # Key + struct.pack('>i', 2), # Length of value + b'v1', # Value + + struct.pack('>q', 1), # Msg Offset + struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size) + struct.pack('>i', -16383415), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k2', # Key + struct.pack('>i', 8), # Length of value + b'ar', # Value (truncated) + ]) + + msgs = MessageSet.decode(encoded, bytes_to_read=len(encoded)) + assert len(msgs) == 2 + msg1, msg2 = msgs + + returned_offset1, message1_size, decoded_message1 = msg1 + returned_offset2, message2_size, decoded_message2 = msg2 + + assert returned_offset1 == 0 + message1 = Message(b'v1', key=b'k1') + message1.encode() + assert decoded_message1 == message1 + + assert returned_offset2 is None + assert message2_size is None + assert decoded_message2 == PartialMessage() + + +def test_decode_fetch_response_partial(): + encoded = b''.join([ + Int32.encode(1), # Num Topics (Array) + String('utf-8').encode('foobar'), + Int32.encode(2), # Num Partitions (Array) + Int32.encode(0), # Partition id + Int16.encode(0), # Error Code + Int64.encode(1234), # Highwater offset + Int32.encode(52), # MessageSet size + Int64.encode(0), # Msg Offset + Int32.encode(18), # Msg Size + struct.pack('>i', 1474775406), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k1', # Key + struct.pack('>i', 2), # Length of value + b'v1', # Value + + Int64.encode(1), # Msg Offset + struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size) + struct.pack('>i', -16383415), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k2', # Key + struct.pack('>i', 8), # Length of value + b'ar', # Value (truncated) + Int32.encode(1), + Int16.encode(0), + Int64.encode(2345), + Int32.encode(52), # MessageSet size + Int64.encode(0), # Msg Offset + Int32.encode(18), # Msg Size + struct.pack('>i', 1474775406), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k1', # Key + struct.pack('>i', 2), # Length of value 
+ b'v1', # Value + + Int64.encode(1), # Msg Offset + struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size) + struct.pack('>i', -16383415), # CRC + struct.pack('>bb', 0, 0), # Magic, flags + struct.pack('>i', 2), # Length of key + b'k2', # Key + struct.pack('>i', 8), # Length of value + b'ar', # Value (truncated) + ]) + resp = FetchResponse[0].decode(io.BytesIO(encoded)) + assert len(resp.topics) == 1 + topic, partitions = resp.topics[0] + assert topic == 'foobar' + assert len(partitions) == 2 + + m1 = MessageSet.decode( + partitions[0][3], bytes_to_read=len(partitions[0][3])) + assert len(m1) == 2 + assert m1[1] == (None, None, PartialMessage()) + + +def test_struct_unrecognized_kwargs(): + try: + mr = MetadataRequest[0](topicz='foo') + assert False, 'Structs should not allow unrecognized kwargs' + except ValueError: + pass + + +def test_struct_missing_kwargs(): + fr = FetchRequest[0](max_wait_time=100) + assert fr.min_bytes is None + + +def test_unsigned_varint_serde(): + pairs = { + 0: [0], + -1: [0xff, 0xff, 0xff, 0xff, 0x0f], + 1: [1], + 63: [0x3f], + -64: [0xc0, 0xff, 0xff, 0xff, 0x0f], + 64: [0x40], + 8191: [0xff, 0x3f], + -8192: [0x80, 0xc0, 0xff, 0xff, 0x0f], + 8192: [0x80, 0x40], + -8193: [0xff, 0xbf, 0xff, 0xff, 0x0f], + 1048575: [0xff, 0xff, 0x3f], + + } + for value, expected_encoded in pairs.items(): + value &= 0xffffffff + encoded = UnsignedVarInt32.encode(value) + assert encoded == b''.join(struct.pack('>B', x) for x in expected_encoded) + assert value == UnsignedVarInt32.decode(io.BytesIO(encoded)) + + +def test_compact_data_structs(): + cs = CompactString() + encoded = cs.encode(None) + assert encoded == struct.pack('B', 0) + decoded = cs.decode(io.BytesIO(encoded)) + assert decoded is None + assert b'\x01' == cs.encode('') + assert '' == cs.decode(io.BytesIO(b'\x01')) + encoded = cs.encode("foobarbaz") + assert cs.decode(io.BytesIO(encoded)) == "foobarbaz" + + arr = CompactArray(CompactString()) + assert arr.encode(None) == b'\x00' + assert arr.decode(io.BytesIO(b'\x00')) is None + enc = arr.encode([]) + assert enc == b'\x01' + assert [] == arr.decode(io.BytesIO(enc)) + encoded = arr.encode(["foo", "bar", "baz", "quux"]) + assert arr.decode(io.BytesIO(encoded)) == ["foo", "bar", "baz", "quux"] + + enc = CompactBytes.encode(None) + assert enc == b'\x00' + assert CompactBytes.decode(io.BytesIO(b'\x00')) is None + enc = CompactBytes.encode(b'') + assert enc == b'\x01' + assert CompactBytes.decode(io.BytesIO(b'\x01')) is b'' + enc = CompactBytes.encode(b'foo') + assert CompactBytes.decode(io.BytesIO(enc)) == b'foo' diff --git a/test_sasl_integration.py b/test_sasl_integration.py new file mode 100644 index 00000000..e3a4813a --- /dev/null +++ b/test_sasl_integration.py @@ -0,0 +1,80 @@ +import logging +import uuid + +import pytest + +from kafka.admin import NewTopic +from kafka.protocol.metadata import MetadataRequest_v1 +from test.testutil import assert_message_count, env_kafka_version, random_string, special_to_underscore + + +@pytest.fixture( + params=[ + pytest.param( + "PLAIN", marks=pytest.mark.skipif(env_kafka_version() < (0, 10), reason="Requires KAFKA_VERSION >= 0.10") + ), + pytest.param( + "SCRAM-SHA-256", + marks=pytest.mark.skipif(env_kafka_version() < (0, 10, 2), reason="Requires KAFKA_VERSION >= 0.10.2"), + ), + pytest.param( + "SCRAM-SHA-512", + marks=pytest.mark.skipif(env_kafka_version() < (0, 10, 2), reason="Requires KAFKA_VERSION >= 0.10.2"), + ), + ] +) +def sasl_kafka(request, kafka_broker_factory): + sasl_kafka = 
kafka_broker_factory(transport="SASL_PLAINTEXT", sasl_mechanism=request.param)[0] + yield sasl_kafka + sasl_kafka.child.dump_logs() + + +def test_admin(request, sasl_kafka): + topic_name = special_to_underscore(request.node.name + random_string(4)) + admin, = sasl_kafka.get_admin_clients(1) + admin.create_topics([NewTopic(topic_name, 1, 1)]) + assert topic_name in sasl_kafka.get_topic_names() + + +def test_produce_and_consume(request, sasl_kafka): + topic_name = special_to_underscore(request.node.name + random_string(4)) + sasl_kafka.create_topics([topic_name], num_partitions=2) + producer, = sasl_kafka.get_producers(1) + + messages_and_futures = [] # [(message, produce_future),] + for i in range(100): + encoded_msg = "{}-{}-{}".format(i, request.node.name, uuid.uuid4()).encode("utf-8") + future = producer.send(topic_name, value=encoded_msg, partition=i % 2) + messages_and_futures.append((encoded_msg, future)) + producer.flush() + + for (msg, f) in messages_and_futures: + assert f.succeeded() + + consumer, = sasl_kafka.get_consumers(1, [topic_name]) + messages = {0: [], 1: []} + for i, message in enumerate(consumer, 1): + logging.debug("Consumed message %s", repr(message)) + messages[message.partition].append(message) + if i >= 100: + break + + assert_message_count(messages[0], 50) + assert_message_count(messages[1], 50) + + +def test_client(request, sasl_kafka): + topic_name = special_to_underscore(request.node.name + random_string(4)) + sasl_kafka.create_topics([topic_name], num_partitions=1) + + client, = sasl_kafka.get_clients(1) + request = MetadataRequest_v1(None) + client.send(0, request) + for _ in range(10): + result = client.poll(timeout_ms=10000) + if len(result) > 0: + break + else: + raise RuntimeError("Couldn't fetch topic response from Broker.") + result = result[0] + assert topic_name in [t[1] for t in result.topics] diff --git a/test_sender.py b/test_sender.py new file mode 100644 index 00000000..2a68defc --- /dev/null +++ b/test_sender.py @@ -0,0 +1,53 @@ +# pylint: skip-file +from __future__ import absolute_import + +import pytest +import io + +from kafka.client_async import KafkaClient +from kafka.cluster import ClusterMetadata +from kafka.metrics import Metrics +from kafka.protocol.produce import ProduceRequest +from kafka.producer.record_accumulator import RecordAccumulator, ProducerBatch +from kafka.producer.sender import Sender +from kafka.record.memory_records import MemoryRecordsBuilder +from kafka.structs import TopicPartition + + +@pytest.fixture +def client(mocker): + _cli = mocker.Mock(spec=KafkaClient(bootstrap_servers=(), api_version=(0, 9))) + _cli.cluster = mocker.Mock(spec=ClusterMetadata()) + return _cli + + +@pytest.fixture +def accumulator(): + return RecordAccumulator() + + +@pytest.fixture +def metrics(): + return Metrics() + + +@pytest.fixture +def sender(client, accumulator, metrics): + return Sender(client, client.cluster, accumulator, metrics) + + +@pytest.mark.parametrize(("api_version", "produce_version"), [ + ((0, 10), 2), + ((0, 9), 1), + ((0, 8), 0) +]) +def test_produce_request(sender, mocker, api_version, produce_version): + sender.config['api_version'] = api_version + tp = TopicPartition('foo', 0) + buffer = io.BytesIO() + records = MemoryRecordsBuilder( + magic=1, compression_type=0, batch_size=100000) + batch = ProducerBatch(tp, records, buffer) + records.close() + produce_request = sender._produce_request(0, 0, 0, [batch]) + assert isinstance(produce_request, ProduceRequest[produce_version]) diff --git a/test_subscription_state.py 
b/test_subscription_state.py new file mode 100644 index 00000000..9718f6af --- /dev/null +++ b/test_subscription_state.py @@ -0,0 +1,25 @@ +# pylint: skip-file +from __future__ import absolute_import + +import pytest + +from kafka.consumer.subscription_state import SubscriptionState + +@pytest.mark.parametrize(('topic_name', 'expectation'), [ + (0, pytest.raises(TypeError)), + (None, pytest.raises(TypeError)), + ('', pytest.raises(ValueError)), + ('.', pytest.raises(ValueError)), + ('..', pytest.raises(ValueError)), + ('a' * 250, pytest.raises(ValueError)), + ('abc/123', pytest.raises(ValueError)), + ('/abc/123', pytest.raises(ValueError)), + ('/abc123', pytest.raises(ValueError)), + ('name with space', pytest.raises(ValueError)), + ('name*with*stars', pytest.raises(ValueError)), + ('name+with+plus', pytest.raises(ValueError)), +]) +def test_topic_name_validation(topic_name, expectation): + state = SubscriptionState() + with expectation: + state._ensure_valid_topic_name(topic_name) diff --git a/testutil.py b/testutil.py new file mode 100644 index 00000000..ec4d70bf --- /dev/null +++ b/testutil.py @@ -0,0 +1,46 @@ +from __future__ import absolute_import + +import os +import random +import re +import string +import time + + +def special_to_underscore(string, _matcher=re.compile(r'[^a-zA-Z0-9_]+')): + return _matcher.sub('_', string) + + +def random_string(length): + return "".join(random.choice(string.ascii_letters) for i in range(length)) + + +def env_kafka_version(): + """Return the Kafka version set in the OS environment as a tuple. + + Example: '0.8.1.1' --> (0, 8, 1, 1) + """ + if 'KAFKA_VERSION' not in os.environ: + return () + return tuple(map(int, os.environ['KAFKA_VERSION'].split('.'))) + + +def assert_message_count(messages, num_messages): + """Check that we received the expected number of messages with no duplicates.""" + # Make sure we got them all + assert len(messages) == num_messages + # Make sure there are no duplicates + # Note: Currently duplicates are identified only using key/value. Other attributes like topic, partition, headers, + # timestamp, etc are ignored... this could be changed if necessary, but will be more tolerant of dupes. 
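    # A set keyed on (m.key, m.value) collapses any duplicates, so its length
    # equals num_messages only when every key/value pair seen was unique.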
+ unique_messages = {(m.key, m.value) for m in messages} + assert len(unique_messages) == num_messages + + +class Timer(object): + def __enter__(self): + self.start = time.time() + return self + + def __exit__(self, *args): + self.end = time.time() + self.interval = self.end - self.start From c220efc684ce94d6a1d70ae703a94651eea4c179 Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Sat, 21 Oct 2023 22:47:07 +0300 Subject: [PATCH 05/20] Adjust imports in kafka-python tests --- requirements-ci.txt | 1 + requirements-win-test.txt | 1 + tests/kafka/conftest.py | 4 ++-- tests/kafka/fixtures.py | 4 ++-- tests/kafka/record/__init__.py | 0 tests/kafka/record/test_default_records.py | 2 +- tests/kafka/record/test_legacy_records.py | 2 +- tests/kafka/test_admin_integration.py | 2 +- tests/kafka/test_codec.py | 2 +- tests/kafka/test_conn.py | 2 +- tests/kafka/test_consumer_group.py | 2 +- tests/kafka/test_consumer_integration.py | 4 ++-- tests/kafka/test_producer.py | 2 +- tests/kafka/test_sasl_integration.py | 2 +- 14 files changed, 16 insertions(+), 14 deletions(-) create mode 100644 tests/kafka/record/__init__.py diff --git a/requirements-ci.txt b/requirements-ci.txt index 315662ef..3d09463d 100644 --- a/requirements-ci.txt +++ b/requirements-ci.txt @@ -6,6 +6,7 @@ isort[colors]==5.10.0 pytest==7.1.2 pytest-cov==3.0.0 pytest-asyncio==0.18.3 +pytest-mock==3.12.0 docker==6.1.2 chardet==4.0.0 # Until fixed requests is released lz4==3.1.3 diff --git a/requirements-win-test.txt b/requirements-win-test.txt index 6d3cca85..e2e51c61 100644 --- a/requirements-win-test.txt +++ b/requirements-win-test.txt @@ -6,6 +6,7 @@ isort[colors]==5.10.0 pytest==7.1.2 pytest-cov==3.0.0 pytest-asyncio==0.18.3 +pytest-mock==3.12.0 docker==6.0.1 chardet==4.0.0 # Until fixed requests is released lz4==3.1.3 diff --git a/tests/kafka/conftest.py b/tests/kafka/conftest.py index 3fa0262f..0bbd1a2b 100644 --- a/tests/kafka/conftest.py +++ b/tests/kafka/conftest.py @@ -4,8 +4,8 @@ import pytest -from test.testutil import env_kafka_version, random_string -from test.fixtures import KafkaFixture, ZookeeperFixture +from tests.kafka.testutil import env_kafka_version, random_string +from tests.kafka.fixtures import KafkaFixture, ZookeeperFixture @pytest.fixture(scope="module") def zookeeper(): diff --git a/tests/kafka/fixtures.py b/tests/kafka/fixtures.py index d9c072b8..5299bf3e 100644 --- a/tests/kafka/fixtures.py +++ b/tests/kafka/fixtures.py @@ -17,8 +17,8 @@ from kafka.errors import InvalidReplicationFactorError from kafka.protocol.admin import CreateTopicsRequest from kafka.protocol.metadata import MetadataRequest -from test.testutil import env_kafka_version, random_string -from test.service import ExternalService, SpawnedService +from tests.kafka.testutil import env_kafka_version, random_string +from tests.kafka.service import ExternalService, SpawnedService log = logging.getLogger(__name__) diff --git a/tests/kafka/record/__init__.py b/tests/kafka/record/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/kafka/record/test_default_records.py b/tests/kafka/record/test_default_records.py index c3a7b02c..3c809ebc 100644 --- a/tests/kafka/record/test_default_records.py +++ b/tests/kafka/record/test_default_records.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals import pytest -from mock import patch +from unittest.mock import patch import kafka.codec from kafka.record.default_records import ( DefaultRecordBatch, DefaultRecordBatchBuilder diff --git 
a/tests/kafka/record/test_legacy_records.py b/tests/kafka/record/test_legacy_records.py index 43970f7c..0c87ad9a 100644 --- a/tests/kafka/record/test_legacy_records.py +++ b/tests/kafka/record/test_legacy_records.py @@ -1,6 +1,6 @@ from __future__ import unicode_literals import pytest -from mock import patch +from unittest.mock import patch from kafka.record.legacy_records import ( LegacyRecordBatch, LegacyRecordBatchBuilder ) diff --git a/tests/kafka/test_admin_integration.py b/tests/kafka/test_admin_integration.py index 06c40a22..87ba289d 100644 --- a/tests/kafka/test_admin_integration.py +++ b/tests/kafka/test_admin_integration.py @@ -1,7 +1,7 @@ import pytest from logging import info -from test.testutil import env_kafka_version, random_string +from tests.kafka.testutil import env_kafka_version, random_string from threading import Event, Thread from time import time, sleep diff --git a/tests/kafka/test_codec.py b/tests/kafka/test_codec.py index e0570745..db6a14b6 100644 --- a/tests/kafka/test_codec.py +++ b/tests/kafka/test_codec.py @@ -15,7 +15,7 @@ zstd_encode, zstd_decode, ) -from test.testutil import random_string +from tests.kafka.testutil import random_string def test_gzip(): diff --git a/tests/kafka/test_conn.py b/tests/kafka/test_conn.py index 966f7b34..b49a8bd3 100644 --- a/tests/kafka/test_conn.py +++ b/tests/kafka/test_conn.py @@ -4,7 +4,7 @@ from errno import EALREADY, EINPROGRESS, EISCONN, ECONNRESET import socket -import mock +from unittest import mock import pytest from kafka.conn import BrokerConnection, ConnectionStates, collect_hosts diff --git a/tests/kafka/test_consumer_group.py b/tests/kafka/test_consumer_group.py index 58dc7ebf..40dc9d70 100644 --- a/tests/kafka/test_consumer_group.py +++ b/tests/kafka/test_consumer_group.py @@ -11,7 +11,7 @@ from kafka.coordinator.base import MemberState from kafka.structs import TopicPartition -from test.testutil import env_kafka_version, random_string +from tests.kafka.testutil import env_kafka_version, random_string def get_connect_str(kafka_broker): diff --git a/tests/kafka/test_consumer_integration.py b/tests/kafka/test_consumer_integration.py index 90b7ed20..a2644bae 100644 --- a/tests/kafka/test_consumer_integration.py +++ b/tests/kafka/test_consumer_integration.py @@ -1,7 +1,7 @@ import logging import time -from mock import patch +from unittest.mock import patch import pytest from kafka.vendor.six.moves import range @@ -9,7 +9,7 @@ from kafka.errors import UnsupportedCodecError, UnsupportedVersionError from kafka.structs import TopicPartition, OffsetAndTimestamp -from test.testutil import Timer, assert_message_count, env_kafka_version, random_string +from tests.kafka.testutil import Timer, assert_message_count, env_kafka_version, random_string @pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") diff --git a/tests/kafka/test_producer.py b/tests/kafka/test_producer.py index 7263130d..97099e93 100644 --- a/tests/kafka/test_producer.py +++ b/tests/kafka/test_producer.py @@ -7,7 +7,7 @@ from kafka import KafkaConsumer, KafkaProducer, TopicPartition from kafka.producer.buffer import SimpleBufferPool -from test.testutil import env_kafka_version, random_string +from tests.kafka.testutil import env_kafka_version, random_string def test_buffer_pool(): diff --git a/tests/kafka/test_sasl_integration.py b/tests/kafka/test_sasl_integration.py index e3a4813a..d66a7349 100644 --- a/tests/kafka/test_sasl_integration.py +++ b/tests/kafka/test_sasl_integration.py @@ -5,7 +5,7 @@ from kafka.admin import NewTopic from 
kafka.protocol.metadata import MetadataRequest_v1 -from test.testutil import assert_message_count, env_kafka_version, random_string, special_to_underscore +from tests.kafka.testutil import assert_message_count, env_kafka_version, random_string, special_to_underscore @pytest.fixture( From e7c11e680b61c023c1a3afc854e528a158a585f8 Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Sun, 22 Oct 2023 14:57:08 +0300 Subject: [PATCH 06/20] De-dup: record packages --- aiokafka/consumer/fetcher.py | 2 +- kafka/__init__.py | 2 - kafka/consumer/__init__.py | 7 - kafka/consumer/fetcher.py | 1016 ---------------- kafka/consumer/group.py | 1225 -------------------- kafka/producer/__init__.py | 7 - kafka/producer/buffer.py | 115 -- kafka/producer/future.py | 71 -- kafka/producer/kafka.py | 752 ------------ kafka/producer/record_accumulator.py | 590 ---------- kafka/producer/sender.py | 517 --------- kafka/record/README | 8 - kafka/record/__init__.py | 3 - kafka/record/_crc32c.py | 145 --- kafka/record/abc.py | 124 -- kafka/record/default_records.py | 630 ---------- kafka/record/legacy_records.py | 548 --------- kafka/record/memory_records.py | 187 --- kafka/record/util.py | 135 --- tests/kafka/fixtures.py | 14 +- tests/kafka/record/__init__.py | 0 tests/kafka/record/test_default_records.py | 208 ---- tests/kafka/record/test_legacy_records.py | 197 ---- tests/kafka/record/test_records.py | 232 ---- tests/kafka/record/test_util.py | 96 -- tests/kafka/test_consumer.py | 26 - tests/kafka/test_consumer_group.py | 179 --- tests/kafka/test_fetcher.py | 553 --------- tests/kafka/test_package.py | 25 - tests/kafka/test_producer.py | 137 --- tests/kafka/test_sender.py | 53 - 31 files changed, 2 insertions(+), 7802 deletions(-) delete mode 100644 kafka/consumer/fetcher.py delete mode 100644 kafka/consumer/group.py delete mode 100644 kafka/producer/__init__.py delete mode 100644 kafka/producer/buffer.py delete mode 100644 kafka/producer/future.py delete mode 100644 kafka/producer/kafka.py delete mode 100644 kafka/producer/record_accumulator.py delete mode 100644 kafka/producer/sender.py delete mode 100644 kafka/record/README delete mode 100644 kafka/record/__init__.py delete mode 100644 kafka/record/_crc32c.py delete mode 100644 kafka/record/abc.py delete mode 100644 kafka/record/default_records.py delete mode 100644 kafka/record/legacy_records.py delete mode 100644 kafka/record/memory_records.py delete mode 100644 kafka/record/util.py delete mode 100644 tests/kafka/record/__init__.py delete mode 100644 tests/kafka/record/test_default_records.py delete mode 100644 tests/kafka/record/test_legacy_records.py delete mode 100644 tests/kafka/record/test_records.py delete mode 100644 tests/kafka/record/test_util.py delete mode 100644 tests/kafka/test_consumer.py delete mode 100644 tests/kafka/test_consumer_group.py delete mode 100644 tests/kafka/test_fetcher.py delete mode 100644 tests/kafka/test_package.py delete mode 100644 tests/kafka/test_producer.py delete mode 100644 tests/kafka/test_sender.py diff --git a/aiokafka/consumer/fetcher.py b/aiokafka/consumer/fetcher.py index 9db37129..2a3394b3 100644 --- a/aiokafka/consumer/fetcher.py +++ b/aiokafka/consumer/fetcher.py @@ -309,7 +309,7 @@ class Fetcher: Parameters: client (AIOKafkaClient): kafka client subscription (SubscriptionState): instance of SubscriptionState - located in kafka.consumer.subscription_state + located in aiokafka.consumer.subscription_state key_deserializer (callable): Any callable that takes a raw message key and returns a deserialized key. 
value_deserializer (callable, optional): Any callable that takes a diff --git a/kafka/__init__.py b/kafka/__init__.py index d5e30aff..c4308c5e 100644 --- a/kafka/__init__.py +++ b/kafka/__init__.py @@ -20,9 +20,7 @@ def emit(self, record): from kafka.admin import KafkaAdminClient from kafka.client_async import KafkaClient -from kafka.consumer import KafkaConsumer from kafka.consumer.subscription_state import ConsumerRebalanceListener -from kafka.producer import KafkaProducer from kafka.conn import BrokerConnection from kafka.serializer import Serializer, Deserializer from kafka.structs import TopicPartition, OffsetAndMetadata diff --git a/kafka/consumer/__init__.py b/kafka/consumer/__init__.py index e09bcc1b..e69de29b 100644 --- a/kafka/consumer/__init__.py +++ b/kafka/consumer/__init__.py @@ -1,7 +0,0 @@ -from __future__ import absolute_import - -from kafka.consumer.group import KafkaConsumer - -__all__ = [ - 'KafkaConsumer' -] diff --git a/kafka/consumer/fetcher.py b/kafka/consumer/fetcher.py deleted file mode 100644 index 7ff9daf7..00000000 --- a/kafka/consumer/fetcher.py +++ /dev/null @@ -1,1016 +0,0 @@ -from __future__ import absolute_import - -import collections -import copy -import logging -import random -import sys -import time - -from kafka.vendor import six - -import kafka.errors as Errors -from kafka.future import Future -from kafka.metrics.stats import Avg, Count, Max, Rate -from kafka.protocol.fetch import FetchRequest -from kafka.protocol.offset import ( - OffsetRequest, OffsetResetStrategy, UNKNOWN_OFFSET -) -from kafka.record import MemoryRecords -from kafka.serializer import Deserializer -from kafka.structs import TopicPartition, OffsetAndTimestamp - -log = logging.getLogger(__name__) - - -# Isolation levels -READ_UNCOMMITTED = 0 -READ_COMMITTED = 1 - -ConsumerRecord = collections.namedtuple("ConsumerRecord", - ["topic", "partition", "offset", "timestamp", "timestamp_type", - "key", "value", "headers", "checksum", "serialized_key_size", "serialized_value_size", "serialized_header_size"]) - - -CompletedFetch = collections.namedtuple("CompletedFetch", - ["topic_partition", "fetched_offset", "response_version", - "partition_data", "metric_aggregator"]) - - -class NoOffsetForPartitionError(Errors.KafkaError): - pass - - -class RecordTooLargeError(Errors.KafkaError): - pass - - -class Fetcher(six.Iterator): - DEFAULT_CONFIG = { - 'key_deserializer': None, - 'value_deserializer': None, - 'fetch_min_bytes': 1, - 'fetch_max_wait_ms': 500, - 'fetch_max_bytes': 52428800, - 'max_partition_fetch_bytes': 1048576, - 'max_poll_records': sys.maxsize, - 'check_crcs': True, - 'iterator_refetch_records': 1, # undocumented -- interface may change - 'metric_group_prefix': 'consumer', - 'api_version': (0, 8, 0), - 'retry_backoff_ms': 100 - } - - def __init__(self, client, subscriptions, metrics, **configs): - """Initialize a Kafka Message Fetcher. - - Keyword Arguments: - key_deserializer (callable): Any callable that takes a - raw message key and returns a deserialized key. - value_deserializer (callable, optional): Any callable that takes a - raw message value and returns a deserialized value. - fetch_min_bytes (int): Minimum amount of data the server should - return for a fetch request, otherwise wait up to - fetch_max_wait_ms for more data to accumulate. Default: 1. - fetch_max_wait_ms (int): The maximum amount of time in milliseconds - the server will block before answering the fetch request if - there isn't sufficient data to immediately satisfy the - requirement given by fetch_min_bytes. 
Default: 500. - fetch_max_bytes (int): The maximum amount of data the server should - return for a fetch request. This is not an absolute maximum, if - the first message in the first non-empty partition of the fetch - is larger than this value, the message will still be returned - to ensure that the consumer can make progress. NOTE: consumer - performs fetches to multiple brokers in parallel so memory - usage will depend on the number of brokers containing - partitions for the topic. - Supported Kafka version >= 0.10.1.0. Default: 52428800 (50 MB). - max_partition_fetch_bytes (int): The maximum amount of data - per-partition the server will return. The maximum total memory - used for a request = #partitions * max_partition_fetch_bytes. - This size must be at least as large as the maximum message size - the server allows or else it is possible for the producer to - send messages larger than the consumer can fetch. If that - happens, the consumer can get stuck trying to fetch a large - message on a certain partition. Default: 1048576. - check_crcs (bool): Automatically check the CRC32 of the records - consumed. This ensures no on-the-wire or on-disk corruption to - the messages occurred. This check adds some overhead, so it may - be disabled in cases seeking extreme performance. Default: True - """ - self.config = copy.copy(self.DEFAULT_CONFIG) - for key in self.config: - if key in configs: - self.config[key] = configs[key] - - self._client = client - self._subscriptions = subscriptions - self._completed_fetches = collections.deque() # Unparsed responses - self._next_partition_records = None # Holds a single PartitionRecords until fully consumed - self._iterator = None - self._fetch_futures = collections.deque() - self._sensors = FetchManagerMetrics(metrics, self.config['metric_group_prefix']) - self._isolation_level = READ_UNCOMMITTED - - def send_fetches(self): - """Send FetchRequests for all assigned partitions that do not already have - an in-flight fetch or pending fetch data. - - Returns: - List of Futures: each future resolves to a FetchResponse - """ - futures = [] - for node_id, request in six.iteritems(self._create_fetch_requests()): - if self._client.ready(node_id): - log.debug("Sending FetchRequest to node %s", node_id) - future = self._client.send(node_id, request, wakeup=False) - future.add_callback(self._handle_fetch_response, request, time.time()) - future.add_errback(log.error, 'Fetch to node %s failed: %s', node_id) - futures.append(future) - self._fetch_futures.extend(futures) - self._clean_done_fetch_futures() - return futures - - def reset_offsets_if_needed(self, partitions): - """Lookup and set offsets for any partitions which are awaiting an - explicit reset. - - Arguments: - partitions (set of TopicPartitions): the partitions to reset - """ - for tp in partitions: - # TODO: If there are several offsets to reset, we could submit offset requests in parallel - if self._subscriptions.is_assigned(tp) and self._subscriptions.is_offset_reset_needed(tp): - self._reset_offset(tp) - - def _clean_done_fetch_futures(self): - while True: - if not self._fetch_futures: - break - if not self._fetch_futures[0].is_done: - break - self._fetch_futures.popleft() - - def in_flight_fetches(self): - """Return True if there are any unprocessed FetchRequests in flight.""" - self._clean_done_fetch_futures() - return bool(self._fetch_futures) - - def update_fetch_positions(self, partitions): - """Update the fetch positions for the provided partitions. 
- - Arguments: - partitions (list of TopicPartitions): partitions to update - - Raises: - NoOffsetForPartitionError: if no offset is stored for a given - partition and no reset policy is available - """ - # reset the fetch position to the committed position - for tp in partitions: - if not self._subscriptions.is_assigned(tp): - log.warning("partition %s is not assigned - skipping offset" - " update", tp) - continue - elif self._subscriptions.is_fetchable(tp): - log.warning("partition %s is still fetchable -- skipping offset" - " update", tp) - continue - - if self._subscriptions.is_offset_reset_needed(tp): - self._reset_offset(tp) - elif self._subscriptions.assignment[tp].committed is None: - # there's no committed position, so we need to reset with the - # default strategy - self._subscriptions.need_offset_reset(tp) - self._reset_offset(tp) - else: - committed = self._subscriptions.assignment[tp].committed.offset - log.debug("Resetting offset for partition %s to the committed" - " offset %s", tp, committed) - self._subscriptions.seek(tp, committed) - - def get_offsets_by_times(self, timestamps, timeout_ms): - offsets = self._retrieve_offsets(timestamps, timeout_ms) - for tp in timestamps: - if tp not in offsets: - offsets[tp] = None - else: - offset, timestamp = offsets[tp] - offsets[tp] = OffsetAndTimestamp(offset, timestamp) - return offsets - - def beginning_offsets(self, partitions, timeout_ms): - return self.beginning_or_end_offset( - partitions, OffsetResetStrategy.EARLIEST, timeout_ms) - - def end_offsets(self, partitions, timeout_ms): - return self.beginning_or_end_offset( - partitions, OffsetResetStrategy.LATEST, timeout_ms) - - def beginning_or_end_offset(self, partitions, timestamp, timeout_ms): - timestamps = dict([(tp, timestamp) for tp in partitions]) - offsets = self._retrieve_offsets(timestamps, timeout_ms) - for tp in timestamps: - offsets[tp] = offsets[tp][0] - return offsets - - def _reset_offset(self, partition): - """Reset offsets for the given partition using the offset reset strategy. - - Arguments: - partition (TopicPartition): the partition that needs reset offset - - Raises: - NoOffsetForPartitionError: if no offset reset strategy is defined - """ - timestamp = self._subscriptions.assignment[partition].reset_strategy - if timestamp is OffsetResetStrategy.EARLIEST: - strategy = 'earliest' - elif timestamp is OffsetResetStrategy.LATEST: - strategy = 'latest' - else: - raise NoOffsetForPartitionError(partition) - - log.debug("Resetting offset for partition %s to %s offset.", - partition, strategy) - offsets = self._retrieve_offsets({partition: timestamp}) - - if partition in offsets: - offset = offsets[partition][0] - - # we might lose the assignment while fetching the offset, - # so check it is still active - if self._subscriptions.is_assigned(partition): - self._subscriptions.seek(partition, offset) - else: - log.debug("Could not find offset for partition %s since it is probably deleted" % (partition,)) - - def _retrieve_offsets(self, timestamps, timeout_ms=float("inf")): - """Fetch offset for each partition passed in ``timestamps`` map. - - Blocks until offsets are obtained, a non-retriable exception is raised - or ``timeout_ms`` passed. - - Arguments: - timestamps: {TopicPartition: int} dict with timestamps to fetch - offsets by. -1 for the latest available, -2 for the earliest - available. Otherwise timestamp is treated as epoch milliseconds. - - Returns: - {TopicPartition: (int, int)}: Mapping of partition to - retrieved offset and timestamp. 
If offset does not exist for - the provided timestamp, that partition will be missing from - this mapping. - """ - if not timestamps: - return {} - - start_time = time.time() - remaining_ms = timeout_ms - timestamps = copy.copy(timestamps) - while remaining_ms > 0: - if not timestamps: - return {} - - future = self._send_offset_requests(timestamps) - self._client.poll(future=future, timeout_ms=remaining_ms) - - if future.succeeded(): - return future.value - if not future.retriable(): - raise future.exception # pylint: disable-msg=raising-bad-type - - elapsed_ms = (time.time() - start_time) * 1000 - remaining_ms = timeout_ms - elapsed_ms - if remaining_ms < 0: - break - - if future.exception.invalid_metadata: - refresh_future = self._client.cluster.request_update() - self._client.poll(future=refresh_future, timeout_ms=remaining_ms) - - # Issue #1780 - # Recheck partition existence after after a successful metadata refresh - if refresh_future.succeeded() and isinstance(future.exception, Errors.StaleMetadata): - log.debug("Stale metadata was raised, and we now have an updated metadata. Rechecking partition existence") - unknown_partition = future.exception.args[0] # TopicPartition from StaleMetadata - if self._client.cluster.leader_for_partition(unknown_partition) is None: - log.debug("Removed partition %s from offsets retrieval" % (unknown_partition, )) - timestamps.pop(unknown_partition) - else: - time.sleep(self.config['retry_backoff_ms'] / 1000.0) - - elapsed_ms = (time.time() - start_time) * 1000 - remaining_ms = timeout_ms - elapsed_ms - - raise Errors.KafkaTimeoutError( - "Failed to get offsets by timestamps in %s ms" % (timeout_ms,)) - - def fetched_records(self, max_records=None, update_offsets=True): - """Returns previously fetched records and updates consumed offsets. - - Arguments: - max_records (int): Maximum number of records returned. Defaults - to max_poll_records configuration. - - Raises: - OffsetOutOfRangeError: if no subscription offset_reset_strategy - CorruptRecordException: if message crc validation fails (check_crcs - must be set to True) - RecordTooLargeError: if a message is larger than the currently - configured max_partition_fetch_bytes - TopicAuthorizationError: if consumer is not authorized to fetch - messages from the topic - - Returns: (records (dict), partial (bool)) - records: {TopicPartition: [messages]} - partial: True if records returned did not fully drain any pending - partition requests. This may be useful for choosing when to - pipeline additional fetch requests. 
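A sketch of how a caller consumes the (records, partial) pair returned here, mirroring the pipelining decision made in KafkaConsumer._poll_once; the `fetcher` instance and `handle` callback are hypothetical, and this is the internal API rather than the public consumer interface:

records, partial = fetcher.fetched_records(max_records=500)
for tp, messages in records.items():
    for message in messages:
        handle(message)          # hypothetical application callback
if not partial:
    # nothing left buffered for any partition, so pipeline the next round
    fetcher.send_fetches()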
- """ - if max_records is None: - max_records = self.config['max_poll_records'] - assert max_records > 0 - - drained = collections.defaultdict(list) - records_remaining = max_records - - while records_remaining > 0: - if not self._next_partition_records: - if not self._completed_fetches: - break - completion = self._completed_fetches.popleft() - self._next_partition_records = self._parse_fetched_data(completion) - else: - records_remaining -= self._append(drained, - self._next_partition_records, - records_remaining, - update_offsets) - return dict(drained), bool(self._completed_fetches) - - def _append(self, drained, part, max_records, update_offsets): - if not part: - return 0 - - tp = part.topic_partition - fetch_offset = part.fetch_offset - if not self._subscriptions.is_assigned(tp): - # this can happen when a rebalance happened before - # fetched records are returned to the consumer's poll call - log.debug("Not returning fetched records for partition %s" - " since it is no longer assigned", tp) - else: - # note that the position should always be available - # as long as the partition is still assigned - position = self._subscriptions.assignment[tp].position - if not self._subscriptions.is_fetchable(tp): - # this can happen when a partition is paused before - # fetched records are returned to the consumer's poll call - log.debug("Not returning fetched records for assigned partition" - " %s since it is no longer fetchable", tp) - - elif fetch_offset == position: - # we are ensured to have at least one record since we already checked for emptiness - part_records = part.take(max_records) - next_offset = part_records[-1].offset + 1 - - log.log(0, "Returning fetched records at offset %d for assigned" - " partition %s and update position to %s", position, - tp, next_offset) - - for record in part_records: - drained[tp].append(record) - - if update_offsets: - self._subscriptions.assignment[tp].position = next_offset - return len(part_records) - - else: - # these records aren't next in line based on the last consumed - # position, ignore them they must be from an obsolete request - log.debug("Ignoring fetched records for %s at offset %s since" - " the current position is %d", tp, part.fetch_offset, - position) - - part.discard() - return 0 - - def _message_generator(self): - """Iterate over fetched_records""" - while self._next_partition_records or self._completed_fetches: - - if not self._next_partition_records: - completion = self._completed_fetches.popleft() - self._next_partition_records = self._parse_fetched_data(completion) - continue - - # Send additional FetchRequests when the internal queue is low - # this should enable moderate pipelining - if len(self._completed_fetches) <= self.config['iterator_refetch_records']: - self.send_fetches() - - tp = self._next_partition_records.topic_partition - - # We can ignore any prior signal to drop pending message sets - # because we are starting from a fresh one where fetch_offset == position - # i.e., the user seek()'d to this position - self._subscriptions.assignment[tp].drop_pending_message_set = False - - for msg in self._next_partition_records.take(): - - # Because we are in a generator, it is possible for - # subscription state to change between yield calls - # so we need to re-check on each loop - # this should catch assignment changes, pauses - # and resets via seek_to_beginning / seek_to_end - if not self._subscriptions.is_fetchable(tp): - log.debug("Not returning fetched records for partition %s" - " since it is no longer fetchable", tp) 
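The generator above is what ultimately backs the consumer's iterator interface; a minimal usage sketch, assuming a reachable broker and an existing topic:

from kafka import KafkaConsumer

consumer = KafkaConsumer('my-topic',                       # hypothetical topic
                         bootstrap_servers='localhost:9092',
                         consumer_timeout_ms=10000)        # stop iterating after 10s idle
for message in consumer:
    print(message.topic, message.partition, message.offset, message.value)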
- self._next_partition_records = None - break - - # If there is a seek during message iteration, - # we should stop unpacking this message set and - # wait for a new fetch response that aligns with the - # new seek position - elif self._subscriptions.assignment[tp].drop_pending_message_set: - log.debug("Skipping remainder of message set for partition %s", tp) - self._subscriptions.assignment[tp].drop_pending_message_set = False - self._next_partition_records = None - break - - # Compressed messagesets may include earlier messages - elif msg.offset < self._subscriptions.assignment[tp].position: - log.debug("Skipping message offset: %s (expecting %s)", - msg.offset, - self._subscriptions.assignment[tp].position) - continue - - self._subscriptions.assignment[tp].position = msg.offset + 1 - yield msg - - self._next_partition_records = None - - def _unpack_message_set(self, tp, records): - try: - batch = records.next_batch() - while batch is not None: - - # LegacyRecordBatch cannot access either base_offset or last_offset_delta - try: - self._subscriptions.assignment[tp].last_offset_from_message_batch = batch.base_offset + \ - batch.last_offset_delta - except AttributeError: - pass - - for record in batch: - key_size = len(record.key) if record.key is not None else -1 - value_size = len(record.value) if record.value is not None else -1 - key = self._deserialize( - self.config['key_deserializer'], - tp.topic, record.key) - value = self._deserialize( - self.config['value_deserializer'], - tp.topic, record.value) - headers = record.headers - header_size = sum( - len(h_key.encode("utf-8")) + (len(h_val) if h_val is not None else 0) for h_key, h_val in - headers) if headers else -1 - yield ConsumerRecord( - tp.topic, tp.partition, record.offset, record.timestamp, - record.timestamp_type, key, value, headers, record.checksum, - key_size, value_size, header_size) - - batch = records.next_batch() - - # If unpacking raises StopIteration, it is erroneously - # caught by the generator. We want all exceptions to be raised - # back to the user. See Issue 545 - except StopIteration as e: - log.exception('StopIteration raised unpacking messageset') - raise RuntimeError('StopIteration raised unpacking messageset') - - def __iter__(self): # pylint: disable=non-iterator-returned - return self - - def __next__(self): - if not self._iterator: - self._iterator = self._message_generator() - try: - return next(self._iterator) - except StopIteration: - self._iterator = None - raise - - def _deserialize(self, f, topic, bytes_): - if not f: - return bytes_ - if isinstance(f, Deserializer): - return f.deserialize(topic, bytes_) - return f(bytes_) - - def _send_offset_requests(self, timestamps): - """Fetch offsets for each partition in timestamps dict. This may send - request to multiple nodes, based on who is Leader for partition. - - Arguments: - timestamps (dict): {TopicPartition: int} mapping of fetching - timestamps. 
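As _deserialize() above shows, the key/value deserializer hooks may be either a plain callable or a Deserializer subclass; a minimal sketch of both forms (the JSON encoding is illustrative):

import json

from kafka.serializer import Deserializer

def value_deserializer(raw):                  # plain-callable form
    return json.loads(raw.decode('utf-8'))

class JsonDeserializer(Deserializer):         # Deserializer-subclass form
    def deserialize(self, topic, bytes_):
        if bytes_ is None:
            return None
        return json.loads(bytes_.decode('utf-8'))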
- - Returns: - Future: resolves to a mapping of retrieved offsets - """ - timestamps_by_node = collections.defaultdict(dict) - for partition, timestamp in six.iteritems(timestamps): - node_id = self._client.cluster.leader_for_partition(partition) - if node_id is None: - self._client.add_topic(partition.topic) - log.debug("Partition %s is unknown for fetching offset," - " wait for metadata refresh", partition) - return Future().failure(Errors.StaleMetadata(partition)) - elif node_id == -1: - log.debug("Leader for partition %s unavailable for fetching " - "offset, wait for metadata refresh", partition) - return Future().failure( - Errors.LeaderNotAvailableError(partition)) - else: - timestamps_by_node[node_id][partition] = timestamp - - # Aggregate results until we have all - list_offsets_future = Future() - responses = [] - node_count = len(timestamps_by_node) - - def on_success(value): - responses.append(value) - if len(responses) == node_count: - offsets = {} - for r in responses: - offsets.update(r) - list_offsets_future.success(offsets) - - def on_fail(err): - if not list_offsets_future.is_done: - list_offsets_future.failure(err) - - for node_id, timestamps in six.iteritems(timestamps_by_node): - _f = self._send_offset_request(node_id, timestamps) - _f.add_callback(on_success) - _f.add_errback(on_fail) - return list_offsets_future - - def _send_offset_request(self, node_id, timestamps): - by_topic = collections.defaultdict(list) - for tp, timestamp in six.iteritems(timestamps): - if self.config['api_version'] >= (0, 10, 1): - data = (tp.partition, timestamp) - else: - data = (tp.partition, timestamp, 1) - by_topic[tp.topic].append(data) - - if self.config['api_version'] >= (0, 10, 1): - request = OffsetRequest[1](-1, list(six.iteritems(by_topic))) - else: - request = OffsetRequest[0](-1, list(six.iteritems(by_topic))) - - # Client returns a future that only fails on network issues - # so create a separate future and attach a callback to update it - # based on response error codes - future = Future() - - _f = self._client.send(node_id, request) - _f.add_callback(self._handle_offset_response, future) - _f.add_errback(lambda e: future.failure(e)) - return future - - def _handle_offset_response(self, future, response): - """Callback for the response of the list offset call above. - - Arguments: - future (Future): the future to update based on response - response (OffsetResponse): response from the server - - Raises: - AssertionError: if response does not match partition - """ - timestamp_offset_map = {} - for topic, part_data in response.topics: - for partition_info in part_data: - partition, error_code = partition_info[:2] - partition = TopicPartition(topic, partition) - error_type = Errors.for_code(error_code) - if error_type is Errors.NoError: - if response.API_VERSION == 0: - offsets = partition_info[2] - assert len(offsets) <= 1, 'Expected OffsetResponse with one offset' - if not offsets: - offset = UNKNOWN_OFFSET - else: - offset = offsets[0] - log.debug("Handling v0 ListOffsetResponse response for %s. " - "Fetched offset %s", partition, offset) - if offset != UNKNOWN_OFFSET: - timestamp_offset_map[partition] = (offset, None) - else: - timestamp, offset = partition_info[2:] - log.debug("Handling ListOffsetResponse response for %s. 
" - "Fetched offset %s, timestamp %s", - partition, offset, timestamp) - if offset != UNKNOWN_OFFSET: - timestamp_offset_map[partition] = (offset, timestamp) - elif error_type is Errors.UnsupportedForMessageFormatError: - # The message format on the broker side is before 0.10.0, - # we simply put None in the response. - log.debug("Cannot search by timestamp for partition %s because the" - " message format version is before 0.10.0", partition) - elif error_type is Errors.NotLeaderForPartitionError: - log.debug("Attempt to fetch offsets for partition %s failed due" - " to obsolete leadership information, retrying.", - partition) - future.failure(error_type(partition)) - return - elif error_type is Errors.UnknownTopicOrPartitionError: - log.warning("Received unknown topic or partition error in ListOffset " - "request for partition %s. The topic/partition " + - "may not exist or the user may not have Describe access " - "to it.", partition) - future.failure(error_type(partition)) - return - else: - log.warning("Attempt to fetch offsets for partition %s failed due to:" - " %s", partition, error_type) - future.failure(error_type(partition)) - return - if not future.is_done: - future.success(timestamp_offset_map) - - def _fetchable_partitions(self): - fetchable = self._subscriptions.fetchable_partitions() - # do not fetch a partition if we have a pending fetch response to process - current = self._next_partition_records - pending = copy.copy(self._completed_fetches) - if current: - fetchable.discard(current.topic_partition) - for fetch in pending: - fetchable.discard(fetch.topic_partition) - return fetchable - - def _create_fetch_requests(self): - """Create fetch requests for all assigned partitions, grouped by node. - - FetchRequests skipped if no leader, or node has requests in flight - - Returns: - dict: {node_id: FetchRequest, ...} (version depends on api_version) - """ - # create the fetch info as a dict of lists of partition info tuples - # which can be passed to FetchRequest() via .items() - fetchable = collections.defaultdict(lambda: collections.defaultdict(list)) - - for partition in self._fetchable_partitions(): - node_id = self._client.cluster.leader_for_partition(partition) - - # advance position for any deleted compacted messages if required - if self._subscriptions.assignment[partition].last_offset_from_message_batch: - next_offset_from_batch_header = self._subscriptions.assignment[partition].last_offset_from_message_batch + 1 - if next_offset_from_batch_header > self._subscriptions.assignment[partition].position: - log.debug( - "Advance position for partition %s from %s to %s (last message batch location plus one)" - " to correct for deleted compacted messages", - partition, self._subscriptions.assignment[partition].position, next_offset_from_batch_header) - self._subscriptions.assignment[partition].position = next_offset_from_batch_header - - position = self._subscriptions.assignment[partition].position - - # fetch if there is a leader and no in-flight requests - if node_id is None or node_id == -1: - log.debug("No leader found for partition %s." 
- " Requesting metadata update", partition) - self._client.cluster.request_update() - - elif self._client.in_flight_request_count(node_id) == 0: - partition_info = ( - partition.partition, - position, - self.config['max_partition_fetch_bytes'] - ) - fetchable[node_id][partition.topic].append(partition_info) - log.debug("Adding fetch request for partition %s at offset %d", - partition, position) - else: - log.log(0, "Skipping fetch for partition %s because there is an inflight request to node %s", - partition, node_id) - - if self.config['api_version'] >= (0, 11, 0): - version = 4 - elif self.config['api_version'] >= (0, 10, 1): - version = 3 - elif self.config['api_version'] >= (0, 10): - version = 2 - elif self.config['api_version'] == (0, 9): - version = 1 - else: - version = 0 - requests = {} - for node_id, partition_data in six.iteritems(fetchable): - if version < 3: - requests[node_id] = FetchRequest[version]( - -1, # replica_id - self.config['fetch_max_wait_ms'], - self.config['fetch_min_bytes'], - partition_data.items()) - else: - # As of version == 3 partitions will be returned in order as - # they are requested, so to avoid starvation with - # `fetch_max_bytes` option we need this shuffle - # NOTE: we do have partition_data in random order due to usage - # of unordered structures like dicts, but that does not - # guarantee equal distribution, and starting in Python3.6 - # dicts retain insert order. - partition_data = list(partition_data.items()) - random.shuffle(partition_data) - if version == 3: - requests[node_id] = FetchRequest[version]( - -1, # replica_id - self.config['fetch_max_wait_ms'], - self.config['fetch_min_bytes'], - self.config['fetch_max_bytes'], - partition_data) - else: - requests[node_id] = FetchRequest[version]( - -1, # replica_id - self.config['fetch_max_wait_ms'], - self.config['fetch_min_bytes'], - self.config['fetch_max_bytes'], - self._isolation_level, - partition_data) - return requests - - def _handle_fetch_response(self, request, send_time, response): - """The callback for fetch completion""" - fetch_offsets = {} - for topic, partitions in request.topics: - for partition_data in partitions: - partition, offset = partition_data[:2] - fetch_offsets[TopicPartition(topic, partition)] = offset - - partitions = set([TopicPartition(topic, partition_data[0]) - for topic, partitions in response.topics - for partition_data in partitions]) - metric_aggregator = FetchResponseMetricAggregator(self._sensors, partitions) - - # randomized ordering should improve balance for short-lived consumers - random.shuffle(response.topics) - for topic, partitions in response.topics: - random.shuffle(partitions) - for partition_data in partitions: - tp = TopicPartition(topic, partition_data[0]) - completed_fetch = CompletedFetch( - tp, fetch_offsets[tp], - response.API_VERSION, - partition_data[1:], - metric_aggregator - ) - self._completed_fetches.append(completed_fetch) - - if response.API_VERSION >= 1: - self._sensors.fetch_throttle_time_sensor.record(response.throttle_time_ms) - self._sensors.fetch_latency.record((time.time() - send_time) * 1000) - - def _parse_fetched_data(self, completed_fetch): - tp = completed_fetch.topic_partition - fetch_offset = completed_fetch.fetched_offset - num_bytes = 0 - records_count = 0 - parsed_records = None - - error_code, highwater = completed_fetch.partition_data[:2] - error_type = Errors.for_code(error_code) - - try: - if not self._subscriptions.is_fetchable(tp): - # this can happen when a rebalance happened or a partition - # consumption 
paused while fetch is still in-flight - log.debug("Ignoring fetched records for partition %s" - " since it is no longer fetchable", tp) - - elif error_type is Errors.NoError: - self._subscriptions.assignment[tp].highwater = highwater - - # we are interested in this fetch only if the beginning - # offset (of the *request*) matches the current consumed position - # Note that the *response* may return a messageset that starts - # earlier (e.g., compressed messages) or later (e.g., compacted topic) - position = self._subscriptions.assignment[tp].position - if position is None or position != fetch_offset: - log.debug("Discarding fetch response for partition %s" - " since its offset %d does not match the" - " expected offset %d", tp, fetch_offset, - position) - return None - - records = MemoryRecords(completed_fetch.partition_data[-1]) - if records.has_next(): - log.debug("Adding fetched record for partition %s with" - " offset %d to buffered record list", tp, - position) - unpacked = list(self._unpack_message_set(tp, records)) - parsed_records = self.PartitionRecords(fetch_offset, tp, unpacked) - if unpacked: - last_offset = unpacked[-1].offset - self._sensors.records_fetch_lag.record(highwater - last_offset) - num_bytes = records.valid_bytes() - records_count = len(unpacked) - elif records.size_in_bytes() > 0: - # we did not read a single message from a non-empty - # buffer because that message's size is larger than - # fetch size, in this case record this exception - record_too_large_partitions = {tp: fetch_offset} - raise RecordTooLargeError( - "There are some messages at [Partition=Offset]: %s " - " whose size is larger than the fetch size %s" - " and hence cannot be ever returned." - " Increase the fetch size, or decrease the maximum message" - " size the broker will allow." % ( - record_too_large_partitions, - self.config['max_partition_fetch_bytes']), - record_too_large_partitions) - self._sensors.record_topic_fetch_metrics(tp.topic, num_bytes, records_count) - - elif error_type in (Errors.NotLeaderForPartitionError, - Errors.UnknownTopicOrPartitionError): - self._client.cluster.request_update() - elif error_type is Errors.OffsetOutOfRangeError: - position = self._subscriptions.assignment[tp].position - if position is None or position != fetch_offset: - log.debug("Discarding stale fetch response for partition %s" - " since the fetched offset %d does not match the" - " current offset %d", tp, fetch_offset, position) - elif self._subscriptions.has_default_offset_reset_policy(): - log.info("Fetch offset %s is out of range for topic-partition %s", fetch_offset, tp) - self._subscriptions.need_offset_reset(tp) - else: - raise Errors.OffsetOutOfRangeError({tp: fetch_offset}) - - elif error_type is Errors.TopicAuthorizationFailedError: - log.warning("Not authorized to read from topic %s.", tp.topic) - raise Errors.TopicAuthorizationFailedError(set(tp.topic)) - elif error_type is Errors.UnknownError: - log.warning("Unknown error fetching data for topic-partition %s", tp) - else: - raise error_type('Unexpected error while fetching data') - - finally: - completed_fetch.metric_aggregator.record(tp, num_bytes, records_count) - - return parsed_records - - class PartitionRecords(object): - def __init__(self, fetch_offset, tp, messages): - self.fetch_offset = fetch_offset - self.topic_partition = tp - self.messages = messages - # When fetching an offset that is in the middle of a - # compressed batch, we will get all messages in the batch. 
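The OffsetOutOfRangeError branch above decides whether an out-of-range position is reset internally or surfaces to the caller; a sketch of the caller-side view when no default reset policy is configured (broker address, topic and the stale offset are hypothetical):

from kafka import KafkaConsumer
from kafka.errors import OffsetOutOfRangeError
from kafka.structs import TopicPartition

tp = TopicPartition('my-topic', 0)
consumer = KafkaConsumer(bootstrap_servers='localhost:9092',
                         auto_offset_reset='none')   # anything other than 'earliest'/'latest'
consumer.assign([tp])
consumer.seek(tp, 123456)        # offset already removed by retention
try:
    consumer.poll(timeout_ms=1000)
except OffsetOutOfRangeError:
    consumer.seek_to_beginning(tp)   # recover manually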
- # But we want to start 'take' at the fetch_offset - # (or the next highest offset in case the message was compacted) - for i, msg in enumerate(messages): - if msg.offset < fetch_offset: - log.debug("Skipping message offset: %s (expecting %s)", - msg.offset, fetch_offset) - else: - self.message_idx = i - break - - else: - self.message_idx = 0 - self.messages = None - - # For truthiness evaluation we need to define __len__ or __nonzero__ - def __len__(self): - if self.messages is None or self.message_idx >= len(self.messages): - return 0 - return len(self.messages) - self.message_idx - - def discard(self): - self.messages = None - - def take(self, n=None): - if not len(self): - return [] - if n is None or n > len(self): - n = len(self) - next_idx = self.message_idx + n - res = self.messages[self.message_idx:next_idx] - self.message_idx = next_idx - # fetch_offset should be incremented by 1 to parallel the - # subscription position (also incremented by 1) - self.fetch_offset = max(self.fetch_offset, res[-1].offset + 1) - return res - - -class FetchResponseMetricAggregator(object): - """ - Since we parse the message data for each partition from each fetch - response lazily, fetch-level metrics need to be aggregated as the messages - from each partition are parsed. This class is used to facilitate this - incremental aggregation. - """ - def __init__(self, sensors, partitions): - self.sensors = sensors - self.unrecorded_partitions = partitions - self.total_bytes = 0 - self.total_records = 0 - - def record(self, partition, num_bytes, num_records): - """ - After each partition is parsed, we update the current metric totals - with the total bytes and number of records parsed. After all partitions - have reported, we write the metric. - """ - self.unrecorded_partitions.remove(partition) - self.total_bytes += num_bytes - self.total_records += num_records - - # once all expected partitions from the fetch have reported in, record the metrics - if not self.unrecorded_partitions: - self.sensors.bytes_fetched.record(self.total_bytes) - self.sensors.records_fetched.record(self.total_records) - - -class FetchManagerMetrics(object): - def __init__(self, metrics, prefix): - self.metrics = metrics - self.group_name = '%s-fetch-manager-metrics' % (prefix,) - - self.bytes_fetched = metrics.sensor('bytes-fetched') - self.bytes_fetched.add(metrics.metric_name('fetch-size-avg', self.group_name, - 'The average number of bytes fetched per request'), Avg()) - self.bytes_fetched.add(metrics.metric_name('fetch-size-max', self.group_name, - 'The maximum number of bytes fetched per request'), Max()) - self.bytes_fetched.add(metrics.metric_name('bytes-consumed-rate', self.group_name, - 'The average number of bytes consumed per second'), Rate()) - - self.records_fetched = self.metrics.sensor('records-fetched') - self.records_fetched.add(metrics.metric_name('records-per-request-avg', self.group_name, - 'The average number of records in each request'), Avg()) - self.records_fetched.add(metrics.metric_name('records-consumed-rate', self.group_name, - 'The average number of records consumed per second'), Rate()) - - self.fetch_latency = metrics.sensor('fetch-latency') - self.fetch_latency.add(metrics.metric_name('fetch-latency-avg', self.group_name, - 'The average time taken for a fetch request.'), Avg()) - self.fetch_latency.add(metrics.metric_name('fetch-latency-max', self.group_name, - 'The max time taken for any fetch request.'), Max()) - self.fetch_latency.add(metrics.metric_name('fetch-rate', self.group_name, - 'The 
number of fetch requests per second.'), Rate(sampled_stat=Count())) - - self.records_fetch_lag = metrics.sensor('records-lag') - self.records_fetch_lag.add(metrics.metric_name('records-lag-max', self.group_name, - 'The maximum lag in terms of number of records for any partition in self window'), Max()) - - self.fetch_throttle_time_sensor = metrics.sensor('fetch-throttle-time') - self.fetch_throttle_time_sensor.add(metrics.metric_name('fetch-throttle-time-avg', self.group_name, - 'The average throttle time in ms'), Avg()) - self.fetch_throttle_time_sensor.add(metrics.metric_name('fetch-throttle-time-max', self.group_name, - 'The maximum throttle time in ms'), Max()) - - def record_topic_fetch_metrics(self, topic, num_bytes, num_records): - # record bytes fetched - name = '.'.join(['topic', topic, 'bytes-fetched']) - bytes_fetched = self.metrics.get_sensor(name) - if not bytes_fetched: - metric_tags = {'topic': topic.replace('.', '_')} - - bytes_fetched = self.metrics.sensor(name) - bytes_fetched.add(self.metrics.metric_name('fetch-size-avg', - self.group_name, - 'The average number of bytes fetched per request for topic %s' % (topic,), - metric_tags), Avg()) - bytes_fetched.add(self.metrics.metric_name('fetch-size-max', - self.group_name, - 'The maximum number of bytes fetched per request for topic %s' % (topic,), - metric_tags), Max()) - bytes_fetched.add(self.metrics.metric_name('bytes-consumed-rate', - self.group_name, - 'The average number of bytes consumed per second for topic %s' % (topic,), - metric_tags), Rate()) - bytes_fetched.record(num_bytes) - - # record records fetched - name = '.'.join(['topic', topic, 'records-fetched']) - records_fetched = self.metrics.get_sensor(name) - if not records_fetched: - metric_tags = {'topic': topic.replace('.', '_')} - - records_fetched = self.metrics.sensor(name) - records_fetched.add(self.metrics.metric_name('records-per-request-avg', - self.group_name, - 'The average number of records in each request for topic %s' % (topic,), - metric_tags), Avg()) - records_fetched.add(self.metrics.metric_name('records-consumed-rate', - self.group_name, - 'The average number of records consumed per second for topic %s' % (topic,), - metric_tags), Rate()) - records_fetched.record(num_records) diff --git a/kafka/consumer/group.py b/kafka/consumer/group.py deleted file mode 100644 index a1d1dfa3..00000000 --- a/kafka/consumer/group.py +++ /dev/null @@ -1,1225 +0,0 @@ -from __future__ import absolute_import, division - -import copy -import logging -import socket -import time - -from kafka.errors import KafkaConfigurationError, UnsupportedVersionError - -from kafka.vendor import six - -from kafka.client_async import KafkaClient, selectors -from kafka.consumer.fetcher import Fetcher -from kafka.consumer.subscription_state import SubscriptionState -from kafka.coordinator.consumer import ConsumerCoordinator -from kafka.coordinator.assignors.range import RangePartitionAssignor -from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor -from kafka.metrics import MetricConfig, Metrics -from kafka.protocol.offset import OffsetResetStrategy -from kafka.structs import TopicPartition -from kafka.version import __version__ - -log = logging.getLogger(__name__) - - -class KafkaConsumer(six.Iterator): - """Consume records from a Kafka cluster. - - The consumer will transparently handle the failure of servers in the Kafka - cluster, and adapt as topic-partitions are created or migrate between - brokers. 
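The sensor registrations in FetchManagerMetrics above follow the library's generic Metrics/Sensor pattern; a minimal standalone sketch of that pattern, with illustrative names and a made-up recorded value:

from kafka.metrics import Metrics
from kafka.metrics.stats import Avg, Max

metrics = Metrics()
sensor = metrics.sensor('example-fetch-size')
sensor.add(metrics.metric_name('fetch-size-avg', 'example-group',
                               'Average fetched bytes per request'), Avg())
sensor.add(metrics.metric_name('fetch-size-max', 'example-group',
                               'Largest fetched bytes per request'), Max())
sensor.record(4096)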
It also interacts with the assigned kafka Group Coordinator node - to allow multiple consumers to load balance consumption of topics (requires - kafka >= 0.9.0.0). - - The consumer is not thread safe and should not be shared across threads. - - Arguments: - *topics (str): optional list of topics to subscribe to. If not set, - call :meth:`~kafka.KafkaConsumer.subscribe` or - :meth:`~kafka.KafkaConsumer.assign` before consuming records. - - Keyword Arguments: - bootstrap_servers: 'host[:port]' string (or list of 'host[:port]' - strings) that the consumer should contact to bootstrap initial - cluster metadata. This does not have to be the full node list. - It just needs to have at least one broker that will respond to a - Metadata API Request. Default port is 9092. If no servers are - specified, will default to localhost:9092. - client_id (str): A name for this client. This string is passed in - each request to servers and can be used to identify specific - server-side log entries that correspond to this client. Also - submitted to GroupCoordinator for logging with respect to - consumer group administration. Default: 'kafka-python-{version}' - group_id (str or None): The name of the consumer group to join for dynamic - partition assignment (if enabled), and to use for fetching and - committing offsets. If None, auto-partition assignment (via - group coordinator) and offset commits are disabled. - Default: None - key_deserializer (callable): Any callable that takes a - raw message key and returns a deserialized key. - value_deserializer (callable): Any callable that takes a - raw message value and returns a deserialized value. - fetch_min_bytes (int): Minimum amount of data the server should - return for a fetch request, otherwise wait up to - fetch_max_wait_ms for more data to accumulate. Default: 1. - fetch_max_wait_ms (int): The maximum amount of time in milliseconds - the server will block before answering the fetch request if - there isn't sufficient data to immediately satisfy the - requirement given by fetch_min_bytes. Default: 500. - fetch_max_bytes (int): The maximum amount of data the server should - return for a fetch request. This is not an absolute maximum, if the - first message in the first non-empty partition of the fetch is - larger than this value, the message will still be returned to - ensure that the consumer can make progress. NOTE: consumer performs - fetches to multiple brokers in parallel so memory usage will depend - on the number of brokers containing partitions for the topic. - Supported Kafka version >= 0.10.1.0. Default: 52428800 (50 MB). - max_partition_fetch_bytes (int): The maximum amount of data - per-partition the server will return. The maximum total memory - used for a request = #partitions * max_partition_fetch_bytes. - This size must be at least as large as the maximum message size - the server allows or else it is possible for the producer to - send messages larger than the consumer can fetch. If that - happens, the consumer can get stuck trying to fetch a large - message on a certain partition. Default: 1048576. - request_timeout_ms (int): Client request timeout in milliseconds. - Default: 305000. - retry_backoff_ms (int): Milliseconds to backoff when retrying on - errors. Default: 100. - reconnect_backoff_ms (int): The amount of time in milliseconds to - wait before attempting to reconnect to a given host. - Default: 50. 
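An illustrative construction covering several of the options documented in this section (broker address, topic and group id are hypothetical):

from kafka import KafkaConsumer

consumer = KafkaConsumer('my-topic',
                         bootstrap_servers=['localhost:9092'],
                         group_id='my-group',
                         auto_offset_reset='earliest',
                         fetch_max_bytes=52428800,
                         max_partition_fetch_bytes=1048576,
                         value_deserializer=lambda v: v.decode('utf-8'))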
- reconnect_backoff_max_ms (int): The maximum amount of time in - milliseconds to backoff/wait when reconnecting to a broker that has - repeatedly failed to connect. If provided, the backoff per host - will increase exponentially for each consecutive connection - failure, up to this maximum. Once the maximum is reached, - reconnection attempts will continue periodically with this fixed - rate. To avoid connection storms, a randomization factor of 0.2 - will be applied to the backoff resulting in a random range between - 20% below and 20% above the computed value. Default: 1000. - max_in_flight_requests_per_connection (int): Requests are pipelined - to kafka brokers up to this number of maximum requests per - broker connection. Default: 5. - auto_offset_reset (str): A policy for resetting offsets on - OffsetOutOfRange errors: 'earliest' will move to the oldest - available message, 'latest' will move to the most recent. Any - other value will raise the exception. Default: 'latest'. - enable_auto_commit (bool): If True , the consumer's offset will be - periodically committed in the background. Default: True. - auto_commit_interval_ms (int): Number of milliseconds between automatic - offset commits, if enable_auto_commit is True. Default: 5000. - default_offset_commit_callback (callable): Called as - callback(offsets, response) response will be either an Exception - or an OffsetCommitResponse struct. This callback can be used to - trigger custom actions when a commit request completes. - check_crcs (bool): Automatically check the CRC32 of the records - consumed. This ensures no on-the-wire or on-disk corruption to - the messages occurred. This check adds some overhead, so it may - be disabled in cases seeking extreme performance. Default: True - metadata_max_age_ms (int): The period of time in milliseconds after - which we force a refresh of metadata, even if we haven't seen any - partition leadership changes to proactively discover any new - brokers or partitions. Default: 300000 - partition_assignment_strategy (list): List of objects to use to - distribute partition ownership amongst consumer instances when - group management is used. - Default: [RangePartitionAssignor, RoundRobinPartitionAssignor] - max_poll_records (int): The maximum number of records returned in a - single call to :meth:`~kafka.KafkaConsumer.poll`. Default: 500 - max_poll_interval_ms (int): The maximum delay between invocations of - :meth:`~kafka.KafkaConsumer.poll` when using consumer group - management. This places an upper bound on the amount of time that - the consumer can be idle before fetching more records. If - :meth:`~kafka.KafkaConsumer.poll` is not called before expiration - of this timeout, then the consumer is considered failed and the - group will rebalance in order to reassign the partitions to another - member. Default 300000 - session_timeout_ms (int): The timeout used to detect failures when - using Kafka's group management facilities. The consumer sends - periodic heartbeats to indicate its liveness to the broker. If - no heartbeats are received by the broker before the expiration of - this session timeout, then the broker will remove this consumer - from the group and initiate a rebalance. Note that the value must - be in the allowable range as configured in the broker configuration - by group.min.session.timeout.ms and group.max.session.timeout.ms. 
- Default: 10000 - heartbeat_interval_ms (int): The expected time in milliseconds - between heartbeats to the consumer coordinator when using - Kafka's group management facilities. Heartbeats are used to ensure - that the consumer's session stays active and to facilitate - rebalancing when new consumers join or leave the group. The - value must be set lower than session_timeout_ms, but typically - should be set no higher than 1/3 of that value. It can be - adjusted even lower to control the expected time for normal - rebalances. Default: 3000 - receive_buffer_bytes (int): The size of the TCP receive buffer - (SO_RCVBUF) to use when reading data. Default: None (relies on - system defaults). The java client defaults to 32768. - send_buffer_bytes (int): The size of the TCP send buffer - (SO_SNDBUF) to use when sending data. Default: None (relies on - system defaults). The java client defaults to 131072. - socket_options (list): List of tuple-arguments to socket.setsockopt - to apply to broker connection sockets. Default: - [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)] - consumer_timeout_ms (int): number of milliseconds to block during - message iteration before raising StopIteration (i.e., ending the - iterator). Default block forever [float('inf')]. - security_protocol (str): Protocol used to communicate with brokers. - Valid values are: PLAINTEXT, SSL, SASL_PLAINTEXT, SASL_SSL. - Default: PLAINTEXT. - ssl_context (ssl.SSLContext): Pre-configured SSLContext for wrapping - socket connections. If provided, all other ssl_* configurations - will be ignored. Default: None. - ssl_check_hostname (bool): Flag to configure whether ssl handshake - should verify that the certificate matches the brokers hostname. - Default: True. - ssl_cafile (str): Optional filename of ca file to use in certificate - verification. Default: None. - ssl_certfile (str): Optional filename of file in pem format containing - the client certificate, as well as any ca certificates needed to - establish the certificate's authenticity. Default: None. - ssl_keyfile (str): Optional filename containing the client private key. - Default: None. - ssl_password (str): Optional password to be used when loading the - certificate chain. Default: None. - ssl_crlfile (str): Optional filename containing the CRL to check for - certificate expiration. By default, no CRL check is done. When - providing a file, only the leaf certificate will be checked against - this CRL. The CRL can only be checked with Python 3.4+ or 2.7.9+. - Default: None. - ssl_ciphers (str): optionally set the available ciphers for ssl - connections. It should be a string in the OpenSSL cipher list - format. If no cipher can be selected (because compile-time options - or other configuration forbids use of all the specified ciphers), - an ssl.SSLError will be raised. See ssl.SSLContext.set_ciphers - api_version (tuple): Specify which Kafka API version to use. If set to - None, the client will attempt to infer the broker version by probing - various APIs. Different versions enable different functionality. - - Examples: - (0, 9) enables full group coordination features with automatic - partition assignment and rebalancing, - (0, 8, 2) enables kafka-storage offset commits with manual - partition assignment only, - (0, 8, 1) enables zookeeper-storage offset commits with manual - partition assignment only, - (0, 8, 0) enables basic functionality but requires manual - partition assignment and offset management. 
- - Default: None - api_version_auto_timeout_ms (int): number of milliseconds to throw a - timeout exception from the constructor when checking the broker - api version. Only applies if api_version set to None. - connections_max_idle_ms: Close idle connections after the number of - milliseconds specified by this config. The broker closes idle - connections after connections.max.idle.ms, so this avoids hitting - unexpected socket disconnected errors on the client. - Default: 540000 - metric_reporters (list): A list of classes to use as metrics reporters. - Implementing the AbstractMetricsReporter interface allows plugging - in classes that will be notified of new metric creation. Default: [] - metrics_num_samples (int): The number of samples maintained to compute - metrics. Default: 2 - metrics_sample_window_ms (int): The maximum age in milliseconds of - samples used to compute metrics. Default: 30000 - selector (selectors.BaseSelector): Provide a specific selector - implementation to use for I/O multiplexing. - Default: selectors.DefaultSelector - exclude_internal_topics (bool): Whether records from internal topics - (such as offsets) should be exposed to the consumer. If set to True - the only way to receive records from an internal topic is - subscribing to it. Requires 0.10+ Default: True - sasl_mechanism (str): Authentication mechanism when security_protocol - is configured for SASL_PLAINTEXT or SASL_SSL. Valid values are: - PLAIN, GSSAPI, OAUTHBEARER, SCRAM-SHA-256, SCRAM-SHA-512. - sasl_plain_username (str): username for sasl PLAIN and SCRAM authentication. - Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. - sasl_plain_password (str): password for sasl PLAIN and SCRAM authentication. - Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. - sasl_kerberos_service_name (str): Service name to include in GSSAPI - sasl mechanism handshake. Default: 'kafka' - sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI - sasl mechanism handshake. Default: one of bootstrap servers - sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider - instance. (See kafka.oauth.abstract). 
Default: None - kafka_client (callable): Custom class / callable for creating KafkaClient instances - - Note: - Configuration parameters are described in more detail at - https://kafka.apache.org/documentation/#consumerconfigs - """ - DEFAULT_CONFIG = { - 'bootstrap_servers': 'localhost', - 'client_id': 'kafka-python-' + __version__, - 'group_id': None, - 'key_deserializer': None, - 'value_deserializer': None, - 'fetch_max_wait_ms': 500, - 'fetch_min_bytes': 1, - 'fetch_max_bytes': 52428800, - 'max_partition_fetch_bytes': 1 * 1024 * 1024, - 'request_timeout_ms': 305000, # chosen to be higher than the default of max_poll_interval_ms - 'retry_backoff_ms': 100, - 'reconnect_backoff_ms': 50, - 'reconnect_backoff_max_ms': 1000, - 'max_in_flight_requests_per_connection': 5, - 'auto_offset_reset': 'latest', - 'enable_auto_commit': True, - 'auto_commit_interval_ms': 5000, - 'default_offset_commit_callback': lambda offsets, response: True, - 'check_crcs': True, - 'metadata_max_age_ms': 5 * 60 * 1000, - 'partition_assignment_strategy': (RangePartitionAssignor, RoundRobinPartitionAssignor), - 'max_poll_records': 500, - 'max_poll_interval_ms': 300000, - 'session_timeout_ms': 10000, - 'heartbeat_interval_ms': 3000, - 'receive_buffer_bytes': None, - 'send_buffer_bytes': None, - 'socket_options': [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)], - 'sock_chunk_bytes': 4096, # undocumented experimental option - 'sock_chunk_buffer_count': 1000, # undocumented experimental option - 'consumer_timeout_ms': float('inf'), - 'security_protocol': 'PLAINTEXT', - 'ssl_context': None, - 'ssl_check_hostname': True, - 'ssl_cafile': None, - 'ssl_certfile': None, - 'ssl_keyfile': None, - 'ssl_crlfile': None, - 'ssl_password': None, - 'ssl_ciphers': None, - 'api_version': None, - 'api_version_auto_timeout_ms': 2000, - 'connections_max_idle_ms': 9 * 60 * 1000, - 'metric_reporters': [], - 'metrics_num_samples': 2, - 'metrics_sample_window_ms': 30000, - 'metric_group_prefix': 'consumer', - 'selector': selectors.DefaultSelector, - 'exclude_internal_topics': True, - 'sasl_mechanism': None, - 'sasl_plain_username': None, - 'sasl_plain_password': None, - 'sasl_kerberos_service_name': 'kafka', - 'sasl_kerberos_domain_name': None, - 'sasl_oauth_token_provider': None, - 'legacy_iterator': False, # enable to revert to < 1.4.7 iterator - 'kafka_client': KafkaClient, - } - DEFAULT_SESSION_TIMEOUT_MS_0_9 = 30000 - - def __init__(self, *topics, **configs): - # Only check for extra config keys in top-level class - extra_configs = set(configs).difference(self.DEFAULT_CONFIG) - if extra_configs: - raise KafkaConfigurationError("Unrecognized configs: %s" % (extra_configs,)) - - self.config = copy.copy(self.DEFAULT_CONFIG) - self.config.update(configs) - - deprecated = {'smallest': 'earliest', 'largest': 'latest'} - if self.config['auto_offset_reset'] in deprecated: - new_config = deprecated[self.config['auto_offset_reset']] - log.warning('use auto_offset_reset=%s (%s is deprecated)', - new_config, self.config['auto_offset_reset']) - self.config['auto_offset_reset'] = new_config - - connections_max_idle_ms = self.config['connections_max_idle_ms'] - request_timeout_ms = self.config['request_timeout_ms'] - fetch_max_wait_ms = self.config['fetch_max_wait_ms'] - if not (fetch_max_wait_ms < request_timeout_ms < connections_max_idle_ms): - raise KafkaConfigurationError( - "connections_max_idle_ms ({}) must be larger than " - "request_timeout_ms ({}) which must be larger than " - "fetch_max_wait_ms ({})." 
- .format(connections_max_idle_ms, request_timeout_ms, fetch_max_wait_ms)) - - metrics_tags = {'client-id': self.config['client_id']} - metric_config = MetricConfig(samples=self.config['metrics_num_samples'], - time_window_ms=self.config['metrics_sample_window_ms'], - tags=metrics_tags) - reporters = [reporter() for reporter in self.config['metric_reporters']] - self._metrics = Metrics(metric_config, reporters) - # TODO _metrics likely needs to be passed to KafkaClient, etc. - - # api_version was previously a str. Accept old format for now - if isinstance(self.config['api_version'], str): - str_version = self.config['api_version'] - if str_version == 'auto': - self.config['api_version'] = None - else: - self.config['api_version'] = tuple(map(int, str_version.split('.'))) - log.warning('use api_version=%s [tuple] -- "%s" as str is deprecated', - str(self.config['api_version']), str_version) - - self._client = self.config['kafka_client'](metrics=self._metrics, **self.config) - - # Get auto-discovered version from client if necessary - if self.config['api_version'] is None: - self.config['api_version'] = self._client.config['api_version'] - - # Coordinator configurations are different for older brokers - # max_poll_interval_ms is not supported directly -- it must the be - # the same as session_timeout_ms. If the user provides one of them, - # use it for both. Otherwise use the old default of 30secs - if self.config['api_version'] < (0, 10, 1): - if 'session_timeout_ms' not in configs: - if 'max_poll_interval_ms' in configs: - self.config['session_timeout_ms'] = configs['max_poll_interval_ms'] - else: - self.config['session_timeout_ms'] = self.DEFAULT_SESSION_TIMEOUT_MS_0_9 - if 'max_poll_interval_ms' not in configs: - self.config['max_poll_interval_ms'] = self.config['session_timeout_ms'] - - if self.config['group_id'] is not None: - if self.config['request_timeout_ms'] <= self.config['session_timeout_ms']: - raise KafkaConfigurationError( - "Request timeout (%s) must be larger than session timeout (%s)" % - (self.config['request_timeout_ms'], self.config['session_timeout_ms'])) - - self._subscription = SubscriptionState(self.config['auto_offset_reset']) - self._fetcher = Fetcher( - self._client, self._subscription, self._metrics, **self.config) - self._coordinator = ConsumerCoordinator( - self._client, self._subscription, self._metrics, - assignors=self.config['partition_assignment_strategy'], - **self.config) - self._closed = False - self._iterator = None - self._consumer_timeout = float('inf') - - if topics: - self._subscription.subscribe(topics=topics) - self._client.set_topics(topics) - - def bootstrap_connected(self): - """Return True if the bootstrap is connected.""" - return self._client.bootstrap_connected() - - def assign(self, partitions): - """Manually assign a list of TopicPartitions to this consumer. - - Arguments: - partitions (list of TopicPartition): Assignment for this instance. - - Raises: - IllegalStateError: If consumer has already called - :meth:`~kafka.KafkaConsumer.subscribe`. - - Warning: - It is not possible to use both manual partition assignment with - :meth:`~kafka.KafkaConsumer.assign` and group assignment with - :meth:`~kafka.KafkaConsumer.subscribe`. - - Note: - This interface does not support incremental assignment and will - replace the previous assignment (if there was one). - - Note: - Manual topic assignment through this method does not use the - consumer's group management functionality. 
As such, there will be - no rebalance operation triggered when group membership or cluster - and topic metadata change. - """ - self._subscription.assign_from_user(partitions) - self._client.set_topics([tp.topic for tp in partitions]) - - def assignment(self): - """Get the TopicPartitions currently assigned to this consumer. - - If partitions were directly assigned using - :meth:`~kafka.KafkaConsumer.assign`, then this will simply return the - same partitions that were previously assigned. If topics were - subscribed using :meth:`~kafka.KafkaConsumer.subscribe`, then this will - give the set of topic partitions currently assigned to the consumer - (which may be None if the assignment hasn't happened yet, or if the - partitions are in the process of being reassigned). - - Returns: - set: {TopicPartition, ...} - """ - return self._subscription.assigned_partitions() - - def close(self, autocommit=True): - """Close the consumer, waiting indefinitely for any needed cleanup. - - Keyword Arguments: - autocommit (bool): If auto-commit is configured for this consumer, - this optional flag causes the consumer to attempt to commit any - pending consumed offsets prior to close. Default: True - """ - if self._closed: - return - log.debug("Closing the KafkaConsumer.") - self._closed = True - self._coordinator.close(autocommit=autocommit) - self._metrics.close() - self._client.close() - try: - self.config['key_deserializer'].close() - except AttributeError: - pass - try: - self.config['value_deserializer'].close() - except AttributeError: - pass - log.debug("The KafkaConsumer has closed.") - - def commit_async(self, offsets=None, callback=None): - """Commit offsets to kafka asynchronously, optionally firing callback. - - This commits offsets only to Kafka. The offsets committed using this API - will be used on the first fetch after every rebalance and also on - startup. As such, if you need to store offsets in anything other than - Kafka, this API should not be used. To avoid re-processing the last - message read if a consumer is restarted, the committed offset should be - the next message your application should consume, i.e.: last_offset + 1. - - This is an asynchronous call and will not block. Any errors encountered - are either passed to the callback (if provided) or discarded. - - Arguments: - offsets (dict, optional): {TopicPartition: OffsetAndMetadata} dict - to commit with the configured group_id. Defaults to currently - consumed offsets for all subscribed partitions. - callback (callable, optional): Called as callback(offsets, response) - with response as either an Exception or an OffsetCommitResponse - struct. This callback can be used to trigger custom actions when - a commit request completes. - - Returns: - kafka.future.Future - """ - assert self.config['api_version'] >= (0, 8, 1), 'Requires >= Kafka 0.8.1' - assert self.config['group_id'] is not None, 'Requires group_id' - if offsets is None: - offsets = self._subscription.all_consumed_offsets() - log.debug("Committing offsets: %s", offsets) - future = self._coordinator.commit_offsets_async( - offsets, callback=callback) - return future - - def commit(self, offsets=None): - """Commit offsets to kafka, blocking until success or error. - - This commits offsets only to Kafka. The offsets committed using this API - will be used on the first fetch after every rebalance and also on - startup. As such, if you need to store offsets in anything other than - Kafka, this API should not be used. 
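A sketch of the "commit last_offset + 1" convention described above, assuming a reachable broker and a hypothetical topic; auto-commit is disabled so the application controls the committed position:

from kafka import KafkaConsumer
from kafka.structs import OffsetAndMetadata, TopicPartition

consumer = KafkaConsumer(bootstrap_servers='localhost:9092',
                         group_id='my-group',
                         enable_auto_commit=False)
consumer.assign([TopicPartition('my-topic', 0)])
records = consumer.poll(timeout_ms=1000)
for tp, messages in records.items():
    # ... process messages ...
    consumer.commit({tp: OffsetAndMetadata(messages[-1].offset + 1, '')})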
To avoid re-processing the last - message read if a consumer is restarted, the committed offset should be - the next message your application should consume, i.e.: last_offset + 1. - - Blocks until either the commit succeeds or an unrecoverable error is - encountered (in which case it is thrown to the caller). - - Currently only supports kafka-topic offset storage (not zookeeper). - - Arguments: - offsets (dict, optional): {TopicPartition: OffsetAndMetadata} dict - to commit with the configured group_id. Defaults to currently - consumed offsets for all subscribed partitions. - """ - assert self.config['api_version'] >= (0, 8, 1), 'Requires >= Kafka 0.8.1' - assert self.config['group_id'] is not None, 'Requires group_id' - if offsets is None: - offsets = self._subscription.all_consumed_offsets() - self._coordinator.commit_offsets_sync(offsets) - - def committed(self, partition, metadata=False): - """Get the last committed offset for the given partition. - - This offset will be used as the position for the consumer - in the event of a failure. - - This call may block to do a remote call if the partition in question - isn't assigned to this consumer or if the consumer hasn't yet - initialized its cache of committed offsets. - - Arguments: - partition (TopicPartition): The partition to check. - metadata (bool, optional): If True, return OffsetAndMetadata struct - instead of offset int. Default: False. - - Returns: - The last committed offset (int or OffsetAndMetadata), or None if there was no prior commit. - """ - assert self.config['api_version'] >= (0, 8, 1), 'Requires >= Kafka 0.8.1' - assert self.config['group_id'] is not None, 'Requires group_id' - if not isinstance(partition, TopicPartition): - raise TypeError('partition must be a TopicPartition namedtuple') - if self._subscription.is_assigned(partition): - committed = self._subscription.assignment[partition].committed - if committed is None: - self._coordinator.refresh_committed_offsets_if_needed() - committed = self._subscription.assignment[partition].committed - else: - commit_map = self._coordinator.fetch_committed_offsets([partition]) - if partition in commit_map: - committed = commit_map[partition] - else: - committed = None - - if committed is not None: - if metadata: - return committed - else: - return committed.offset - - def _fetch_all_topic_metadata(self): - """A blocking call that fetches topic metadata for all topics in the - cluster that the user is authorized to view. - """ - cluster = self._client.cluster - if self._client._metadata_refresh_in_progress and self._client._topics: - future = cluster.request_update() - self._client.poll(future=future) - stash = cluster.need_all_topic_metadata - cluster.need_all_topic_metadata = True - future = cluster.request_update() - self._client.poll(future=future) - cluster.need_all_topic_metadata = stash - - def topics(self): - """Get all topics the user is authorized to view. - This will always issue a remote call to the cluster to fetch the latest - information. - - Returns: - set: topics - """ - self._fetch_all_topic_metadata() - return self._client.cluster.topics() - - def partitions_for_topic(self, topic): - """This method first checks the local metadata cache for information - about the topic. If the topic is not found (either because the topic - does not exist, the user is not authorized to view the topic, or the - metadata cache is not populated), then it will issue a metadata update - call to the cluster. - - Arguments: - topic (str): Topic to check. 
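Both calls documented here are blocking metadata lookups; illustrative usage against a hypothetical broker and topic:

from kafka import KafkaConsumer

consumer = KafkaConsumer(bootstrap_servers='localhost:9092')
print(consumer.topics())                          # set of all visible topic names
print(consumer.partitions_for_topic('my-topic'))  # set of partition ids, or None if unknown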
- - Returns: - set: Partition ids - """ - cluster = self._client.cluster - partitions = cluster.partitions_for_topic(topic) - if partitions is None: - self._fetch_all_topic_metadata() - partitions = cluster.partitions_for_topic(topic) - return partitions - - def poll(self, timeout_ms=0, max_records=None, update_offsets=True): - """Fetch data from assigned topics / partitions. - - Records are fetched and returned in batches by topic-partition. - On each poll, consumer will try to use the last consumed offset as the - starting offset and fetch sequentially. The last consumed offset can be - manually set through :meth:`~kafka.KafkaConsumer.seek` or automatically - set as the last committed offset for the subscribed list of partitions. - - Incompatible with iterator interface -- use one or the other, not both. - - Arguments: - timeout_ms (int, optional): Milliseconds spent waiting in poll if - data is not available in the buffer. If 0, returns immediately - with any records that are available currently in the buffer, - else returns empty. Must not be negative. Default: 0 - max_records (int, optional): The maximum number of records returned - in a single call to :meth:`~kafka.KafkaConsumer.poll`. - Default: Inherit value from max_poll_records. - - Returns: - dict: Topic to list of records since the last fetch for the - subscribed list of topics and partitions. - """ - # Note: update_offsets is an internal-use only argument. It is used to - # support the python iterator interface, and which wraps consumer.poll() - # and requires that the partition offsets tracked by the fetcher are not - # updated until the iterator returns each record to the user. As such, - # the argument is not documented and should not be relied on by library - # users to not break in the future. - assert timeout_ms >= 0, 'Timeout must not be negative' - if max_records is None: - max_records = self.config['max_poll_records'] - assert isinstance(max_records, int), 'max_records must be an integer' - assert max_records > 0, 'max_records must be positive' - assert not self._closed, 'KafkaConsumer is closed' - - # Poll for new data until the timeout expires - start = time.time() - remaining = timeout_ms - while not self._closed: - records = self._poll_once(remaining, max_records, update_offsets=update_offsets) - if records: - return records - - elapsed_ms = (time.time() - start) * 1000 - remaining = timeout_ms - elapsed_ms - - if remaining <= 0: - break - - return {} - - def _poll_once(self, timeout_ms, max_records, update_offsets=True): - """Do one round of polling. In addition to checking for new data, this does - any needed heart-beating, auto-commits, and offset updates. - - Arguments: - timeout_ms (int): The maximum time in milliseconds to block. - - Returns: - dict: Map of topic to list of records (may be empty). - """ - self._coordinator.poll() - - # Fetch positions if we have partitions we're subscribed to that we - # don't know the offset for - if not self._subscription.has_all_fetch_positions(): - self._update_fetch_positions(self._subscription.missing_fetch_positions()) - - # If data is available already, e.g. from a previous network client - # poll() call to commit, then just return it immediately - records, partial = self._fetcher.fetched_records(max_records, update_offsets=update_offsets) - if records: - # Before returning the fetched records, we can send off the - # next round of fetches and avoid block waiting for their - # responses to enable pipelining while the user is handling the - # fetched records. 
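A typical application loop built on the poll() semantics described above (topic, group and broker address are hypothetical):

from kafka import KafkaConsumer

consumer = KafkaConsumer('my-topic',
                         bootstrap_servers='localhost:9092',
                         group_id='my-group',
                         max_poll_records=500)
while True:
    batch = consumer.poll(timeout_ms=1000)        # {TopicPartition: [ConsumerRecord, ...]}
    for tp, messages in batch.items():
        for message in messages:
            print(tp.partition, message.offset, message.value)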
- if not partial: - futures = self._fetcher.send_fetches() - if len(futures): - self._client.poll(timeout_ms=0) - return records - - # Send any new fetches (won't resend pending fetches) - futures = self._fetcher.send_fetches() - if len(futures): - self._client.poll(timeout_ms=0) - - timeout_ms = min(timeout_ms, self._coordinator.time_to_next_poll() * 1000) - self._client.poll(timeout_ms=timeout_ms) - # after the long poll, we should check whether the group needs to rebalance - # prior to returning data so that the group can stabilize faster - if self._coordinator.need_rejoin(): - return {} - - records, _ = self._fetcher.fetched_records(max_records, update_offsets=update_offsets) - return records - - def position(self, partition): - """Get the offset of the next record that will be fetched - - Arguments: - partition (TopicPartition): Partition to check - - Returns: - int: Offset - """ - if not isinstance(partition, TopicPartition): - raise TypeError('partition must be a TopicPartition namedtuple') - assert self._subscription.is_assigned(partition), 'Partition is not assigned' - offset = self._subscription.assignment[partition].position - if offset is None: - self._update_fetch_positions([partition]) - offset = self._subscription.assignment[partition].position - return offset - - def highwater(self, partition): - """Last known highwater offset for a partition. - - A highwater offset is the offset that will be assigned to the next - message that is produced. It may be useful for calculating lag, by - comparing with the reported position. Note that both position and - highwater refer to the *next* offset -- i.e., highwater offset is - one greater than the newest available message. - - Highwater offsets are returned in FetchResponse messages, so will - not be available if no FetchRequests have been sent for this partition - yet. - - Arguments: - partition (TopicPartition): Partition to check - - Returns: - int or None: Offset if available - """ - if not isinstance(partition, TopicPartition): - raise TypeError('partition must be a TopicPartition namedtuple') - assert self._subscription.is_assigned(partition), 'Partition is not assigned' - return self._subscription.assignment[partition].highwater - - def pause(self, *partitions): - """Suspend fetching from the requested partitions. - - Future calls to :meth:`~kafka.KafkaConsumer.poll` will not return any - records from these partitions until they have been resumed using - :meth:`~kafka.KafkaConsumer.resume`. - - Note: This method does not affect partition subscription. In particular, - it does not cause a group rebalance when automatic assignment is used. - - Arguments: - *partitions (TopicPartition): Partitions to pause. - """ - if not all([isinstance(p, TopicPartition) for p in partitions]): - raise TypeError('partitions must be TopicPartition namedtuples') - for partition in partitions: - log.debug("Pausing partition %s", partition) - self._subscription.pause(partition) - # Because the iterator checks is_fetchable() on each iteration - # we expect pauses to get handled automatically and therefore - # we do not need to reset the full iterator (forcing a full refetch) - - def paused(self): - """Get the partitions that were previously paused using - :meth:`~kafka.KafkaConsumer.pause`. - - Returns: - set: {partition (TopicPartition), ...} - """ - return self._subscription.paused_partitions() - - def resume(self, *partitions): - """Resume fetching from the specified (paused) partitions. 
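Since position() and highwater() both refer to the next offset, their difference is the consumer's lag for a partition, and pause()/resume() stop and restart fetching without a rebalance. A rough sketch; the broker address and topic name are placeholders, and highwater() is None until at least one fetch response has arrived:

    from kafka import KafkaConsumer
    from kafka.structs import TopicPartition

    tp = TopicPartition('my-topic', 0)
    consumer = KafkaConsumer(bootstrap_servers='localhost:9092',
                             group_id='example-group')
    consumer.assign([tp])

    consumer.poll(timeout_ms=1000)       # issue at least one fetch
    position = consumer.position(tp)     # next offset this consumer will fetch
    highwater = consumer.highwater(tp)   # next offset the broker will assign
    if highwater is not None:
        print('lag for %s: %d' % (tp, highwater - position))

    consumer.pause(tp)                   # poll() now returns nothing for this partition
    assert tp in consumer.paused()
    consumer.resume(tp)                  # fetching continues from the saved position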
- - Arguments: - *partitions (TopicPartition): Partitions to resume. - """ - if not all([isinstance(p, TopicPartition) for p in partitions]): - raise TypeError('partitions must be TopicPartition namedtuples') - for partition in partitions: - log.debug("Resuming partition %s", partition) - self._subscription.resume(partition) - - def seek(self, partition, offset): - """Manually specify the fetch offset for a TopicPartition. - - Overrides the fetch offsets that the consumer will use on the next - :meth:`~kafka.KafkaConsumer.poll`. If this API is invoked for the same - partition more than once, the latest offset will be used on the next - :meth:`~kafka.KafkaConsumer.poll`. - - Note: You may lose data if this API is arbitrarily used in the middle of - consumption to reset the fetch offsets. - - Arguments: - partition (TopicPartition): Partition for seek operation - offset (int): Message offset in partition - - Raises: - AssertionError: If offset is not an int >= 0; or if partition is not - currently assigned. - """ - if not isinstance(partition, TopicPartition): - raise TypeError('partition must be a TopicPartition namedtuple') - assert isinstance(offset, int) and offset >= 0, 'Offset must be >= 0' - assert partition in self._subscription.assigned_partitions(), 'Unassigned partition' - log.debug("Seeking to offset %s for partition %s", offset, partition) - self._subscription.assignment[partition].seek(offset) - if not self.config['legacy_iterator']: - self._iterator = None - - def seek_to_beginning(self, *partitions): - """Seek to the oldest available offset for partitions. - - Arguments: - *partitions: Optionally provide specific TopicPartitions, otherwise - default to all assigned partitions. - - Raises: - AssertionError: If any partition is not currently assigned, or if - no partitions are assigned. - """ - if not all([isinstance(p, TopicPartition) for p in partitions]): - raise TypeError('partitions must be TopicPartition namedtuples') - if not partitions: - partitions = self._subscription.assigned_partitions() - assert partitions, 'No partitions are currently assigned' - else: - for p in partitions: - assert p in self._subscription.assigned_partitions(), 'Unassigned partition' - - for tp in partitions: - log.debug("Seeking to beginning of partition %s", tp) - self._subscription.need_offset_reset(tp, OffsetResetStrategy.EARLIEST) - if not self.config['legacy_iterator']: - self._iterator = None - - def seek_to_end(self, *partitions): - """Seek to the most recent available offset for partitions. - - Arguments: - *partitions: Optionally provide specific TopicPartitions, otherwise - default to all assigned partitions. - - Raises: - AssertionError: If any partition is not currently assigned, or if - no partitions are assigned. - """ - if not all([isinstance(p, TopicPartition) for p in partitions]): - raise TypeError('partitions must be TopicPartition namedtuples') - if not partitions: - partitions = self._subscription.assigned_partitions() - assert partitions, 'No partitions are currently assigned' - else: - for p in partitions: - assert p in self._subscription.assigned_partitions(), 'Unassigned partition' - - for tp in partitions: - log.debug("Seeking to end of partition %s", tp) - self._subscription.need_offset_reset(tp, OffsetResetStrategy.LATEST) - if not self.config['legacy_iterator']: - self._iterator = None - - def subscribe(self, topics=(), pattern=None, listener=None): - """Subscribe to a list of topics, or a topic regex pattern. 
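The seek family only takes effect on the next poll(); a short illustration with placeholder broker and topic names:

    from kafka import KafkaConsumer
    from kafka.structs import TopicPartition

    tp = TopicPartition('my-topic', 0)
    consumer = KafkaConsumer(bootstrap_servers='localhost:9092')
    consumer.assign([tp])

    consumer.seek_to_beginning(tp)   # replay from the oldest available offset
    consumer.poll(timeout_ms=1000)

    consumer.seek(tp, 42)            # or jump to an absolute offset (int >= 0, assigned partition)
    for record in consumer.poll(timeout_ms=1000).get(tp, []):
        print(record.offset, record.value)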
- - Partitions will be dynamically assigned via a group coordinator. - Topic subscriptions are not incremental: this list will replace the - current assignment (if there is one). - - This method is incompatible with :meth:`~kafka.KafkaConsumer.assign`. - - Arguments: - topics (list): List of topics for subscription. - pattern (str): Pattern to match available topics. You must provide - either topics or pattern, but not both. - listener (ConsumerRebalanceListener): Optionally include listener - callback, which will be called before and after each rebalance - operation. - - As part of group management, the consumer will keep track of the - list of consumers that belong to a particular group and will - trigger a rebalance operation if one of the following events - trigger: - - * Number of partitions change for any of the subscribed topics - * Topic is created or deleted - * An existing member of the consumer group dies - * A new member is added to the consumer group - - When any of these events are triggered, the provided listener - will be invoked first to indicate that the consumer's assignment - has been revoked, and then again when the new assignment has - been received. Note that this listener will immediately override - any listener set in a previous call to subscribe. It is - guaranteed, however, that the partitions revoked/assigned - through this interface are from topics subscribed in this call. - - Raises: - IllegalStateError: If called after previously calling - :meth:`~kafka.KafkaConsumer.assign`. - AssertionError: If neither topics or pattern is provided. - TypeError: If listener is not a ConsumerRebalanceListener. - """ - # SubscriptionState handles error checking - self._subscription.subscribe(topics=topics, - pattern=pattern, - listener=listener) - - # Regex will need all topic metadata - if pattern is not None: - self._client.cluster.need_all_topic_metadata = True - self._client.set_topics([]) - self._client.cluster.request_update() - log.debug("Subscribed to topic pattern: %s", pattern) - else: - self._client.cluster.need_all_topic_metadata = False - self._client.set_topics(self._subscription.group_subscription()) - log.debug("Subscribed to topic(s): %s", topics) - - def subscription(self): - """Get the current topic subscription. - - Returns: - set: {topic, ...} - """ - if self._subscription.subscription is None: - return None - return self._subscription.subscription.copy() - - def unsubscribe(self): - """Unsubscribe from all topics and clear all assigned partitions.""" - self._subscription.unsubscribe() - self._coordinator.close() - self._client.cluster.need_all_topic_metadata = False - self._client.set_topics([]) - log.debug("Unsubscribed all topics or patterns and assigned partitions") - if not self.config['legacy_iterator']: - self._iterator = None - - def metrics(self, raw=False): - """Get metrics on consumer performance. - - This is ported from the Java Consumer, for details see: - https://kafka.apache.org/documentation/#consumer_monitoring - - Warning: - This is an unstable interface. It may change in future - releases without warning. - """ - if raw: - return self._metrics.metrics.copy() - - metrics = {} - for k, v in six.iteritems(self._metrics.metrics.copy()): - if k.group not in metrics: - metrics[k.group] = {} - if k.name not in metrics[k.group]: - metrics[k.group][k.name] = {} - metrics[k.group][k.name] = v.value() - return metrics - - def offsets_for_times(self, timestamps): - """Look up the offsets for the given partitions by timestamp. 
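subscribe() takes either a topic list or a regex pattern, optionally with a rebalance listener; the listener class and pattern below are illustrative only:

    from kafka import KafkaConsumer, ConsumerRebalanceListener

    class LogRebalances(ConsumerRebalanceListener):
        def on_partitions_revoked(self, revoked):
            print('revoked:', revoked)

        def on_partitions_assigned(self, assigned):
            print('assigned:', assigned)

    consumer = KafkaConsumer(bootstrap_servers='localhost:9092',
                             group_id='example-group')
    consumer.subscribe(pattern='^metrics-.*', listener=LogRebalances())
    consumer.poll(timeout_ms=1000)      # join the group and receive an assignment
    print(consumer.subscription())      # current topic subscription (a set), or None
    consumer.unsubscribe()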
The - returned offset for each partition is the earliest offset whose - timestamp is greater than or equal to the given timestamp in the - corresponding partition. - - This is a blocking call. The consumer does not have to be assigned the - partitions. - - If the message format version in a partition is before 0.10.0, i.e. - the messages do not have timestamps, ``None`` will be returned for that - partition. ``None`` will also be returned for the partition if there - are no messages in it. - - Note: - This method may block indefinitely if the partition does not exist. - - Arguments: - timestamps (dict): ``{TopicPartition: int}`` mapping from partition - to the timestamp to look up. Unit should be milliseconds since - beginning of the epoch (midnight Jan 1, 1970 (UTC)) - - Returns: - ``{TopicPartition: OffsetAndTimestamp}``: mapping from partition - to the timestamp and offset of the first message with timestamp - greater than or equal to the target timestamp. - - Raises: - ValueError: If the target timestamp is negative - UnsupportedVersionError: If the broker does not support looking - up the offsets by timestamp. - KafkaTimeoutError: If fetch failed in request_timeout_ms - """ - if self.config['api_version'] <= (0, 10, 0): - raise UnsupportedVersionError( - "offsets_for_times API not supported for cluster version {}" - .format(self.config['api_version'])) - for tp, ts in six.iteritems(timestamps): - timestamps[tp] = int(ts) - if ts < 0: - raise ValueError( - "The target time for partition {} is {}. The target time " - "cannot be negative.".format(tp, ts)) - return self._fetcher.get_offsets_by_times( - timestamps, self.config['request_timeout_ms']) - - def beginning_offsets(self, partitions): - """Get the first offset for the given partitions. - - This method does not change the current consumer position of the - partitions. - - Note: - This method may block indefinitely if the partition does not exist. - - Arguments: - partitions (list): List of TopicPartition instances to fetch - offsets for. - - Returns: - ``{TopicPartition: int}``: The earliest available offsets for the - given partitions. - - Raises: - UnsupportedVersionError: If the broker does not support looking - up the offsets by timestamp. - KafkaTimeoutError: If fetch failed in request_timeout_ms. - """ - offsets = self._fetcher.beginning_offsets( - partitions, self.config['request_timeout_ms']) - return offsets - - def end_offsets(self, partitions): - """Get the last offset for the given partitions. The last offset of a - partition is the offset of the upcoming message, i.e. the offset of the - last available message + 1. - - This method does not change the current consumer position of the - partitions. - - Note: - This method may block indefinitely if the partition does not exist. - - Arguments: - partitions (list): List of TopicPartition instances to fetch - offsets for. - - Returns: - ``{TopicPartition: int}``: The end offsets for the given partitions. - - Raises: - UnsupportedVersionError: If the broker does not support looking - up the offsets by timestamp. 
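A common use of offsets_for_times() is "rewind to roughly one hour ago": look up the offset for a timestamp, then seek to it. A sketch with placeholder names; it needs brokers and a message format >= 0.10.1 so that records carry timestamps:

    import time
    from kafka import KafkaConsumer
    from kafka.structs import TopicPartition

    tp = TopicPartition('my-topic', 0)
    consumer = KafkaConsumer(bootstrap_servers='localhost:9092')
    consumer.assign([tp])

    one_hour_ago = int((time.time() - 3600) * 1000)        # milliseconds since the epoch
    offsets = consumer.offsets_for_times({tp: one_hour_ago})
    if offsets[tp] is not None:                            # None if no message at/after that time
        consumer.seek(tp, offsets[tp].offset)

    print(consumer.beginning_offsets([tp]))                # {tp: earliest available offset}
    print(consumer.end_offsets([tp]))                      # {tp: last offset + 1}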
- KafkaTimeoutError: If fetch failed in request_timeout_ms - """ - offsets = self._fetcher.end_offsets( - partitions, self.config['request_timeout_ms']) - return offsets - - def _use_consumer_group(self): - """Return True iff this consumer can/should join a broker-coordinated group.""" - if self.config['api_version'] < (0, 9): - return False - elif self.config['group_id'] is None: - return False - elif not self._subscription.partitions_auto_assigned(): - return False - return True - - def _update_fetch_positions(self, partitions): - """Set the fetch position to the committed position (if there is one) - or reset it using the offset reset policy the user has configured. - - Arguments: - partitions (List[TopicPartition]): The partitions that need - updating fetch positions. - - Raises: - NoOffsetForPartitionError: If no offset is stored for a given - partition and no offset reset policy is defined. - """ - # Lookup any positions for partitions which are awaiting reset (which may be the - # case if the user called :meth:`seek_to_beginning` or :meth:`seek_to_end`. We do - # this check first to avoid an unnecessary lookup of committed offsets (which - # typically occurs when the user is manually assigning partitions and managing - # their own offsets). - self._fetcher.reset_offsets_if_needed(partitions) - - if not self._subscription.has_all_fetch_positions(): - # if we still don't have offsets for all partitions, then we should either seek - # to the last committed position or reset using the auto reset policy - if (self.config['api_version'] >= (0, 8, 1) and - self.config['group_id'] is not None): - # first refresh commits for all assigned partitions - self._coordinator.refresh_committed_offsets_if_needed() - - # Then, do any offset lookups in case some positions are not known - self._fetcher.update_fetch_positions(partitions) - - def _message_generator_v2(self): - timeout_ms = 1000 * (self._consumer_timeout - time.time()) - record_map = self.poll(timeout_ms=timeout_ms, update_offsets=False) - for tp, records in six.iteritems(record_map): - # Generators are stateful, and it is possible that the tp / records - # here may become stale during iteration -- i.e., we seek to a - # different offset, pause consumption, or lose assignment. 
- for record in records: - # is_fetchable(tp) should handle assignment changes and offset - # resets; for all other changes (e.g., seeks) we'll rely on the - # outer function destroying the existing iterator/generator - # via self._iterator = None - if not self._subscription.is_fetchable(tp): - log.debug("Not returning fetched records for partition %s" - " since it is no longer fetchable", tp) - break - self._subscription.assignment[tp].position = record.offset + 1 - yield record - - def _message_generator(self): - assert self.assignment() or self.subscription() is not None, 'No topic subscription or manual partition assignment' - while time.time() < self._consumer_timeout: - - self._coordinator.poll() - - # Fetch offsets for any subscribed partitions that we arent tracking yet - if not self._subscription.has_all_fetch_positions(): - partitions = self._subscription.missing_fetch_positions() - self._update_fetch_positions(partitions) - - poll_ms = min((1000 * (self._consumer_timeout - time.time())), self.config['retry_backoff_ms']) - self._client.poll(timeout_ms=poll_ms) - - # after the long poll, we should check whether the group needs to rebalance - # prior to returning data so that the group can stabilize faster - if self._coordinator.need_rejoin(): - continue - - # We need to make sure we at least keep up with scheduled tasks, - # like heartbeats, auto-commits, and metadata refreshes - timeout_at = self._next_timeout() - - # Short-circuit the fetch iterator if we are already timed out - # to avoid any unintentional interaction with fetcher setup - if time.time() > timeout_at: - continue - - for msg in self._fetcher: - yield msg - if time.time() > timeout_at: - log.debug("internal iterator timeout - breaking for poll") - break - self._client.poll(timeout_ms=0) - - # An else block on a for loop only executes if there was no break - # so this should only be called on a StopIteration from the fetcher - # We assume that it is safe to init_fetches when fetcher is done - # i.e., there are no more records stored internally - else: - self._fetcher.send_fetches() - - def _next_timeout(self): - timeout = min(self._consumer_timeout, - self._client.cluster.ttl() / 1000.0 + time.time(), - self._coordinator.time_to_next_poll() + time.time()) - return timeout - - def __iter__(self): # pylint: disable=non-iterator-returned - return self - - def __next__(self): - if self._closed: - raise StopIteration('KafkaConsumer closed') - # Now that the heartbeat thread runs in the background - # there should be no reason to maintain a separate iterator - # but we'll keep it available for a few releases just in case - if self.config['legacy_iterator']: - return self.next_v1() - else: - return self.next_v2() - - def next_v2(self): - self._set_consumer_timeout() - while time.time() < self._consumer_timeout: - if not self._iterator: - self._iterator = self._message_generator_v2() - try: - return next(self._iterator) - except StopIteration: - self._iterator = None - raise StopIteration() - - def next_v1(self): - if not self._iterator: - self._iterator = self._message_generator() - - self._set_consumer_timeout() - try: - return next(self._iterator) - except StopIteration: - self._iterator = None - raise - - def _set_consumer_timeout(self): - # consumer_timeout_ms can be used to stop iteration early - if self.config['consumer_timeout_ms'] >= 0: - self._consumer_timeout = time.time() + ( - self.config['consumer_timeout_ms'] / 1000.0) diff --git a/kafka/producer/__init__.py b/kafka/producer/__init__.py deleted file mode 
100644 index 576c772a..00000000 --- a/kafka/producer/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from __future__ import absolute_import - -from kafka.producer.kafka import KafkaProducer - -__all__ = [ - 'KafkaProducer' -] diff --git a/kafka/producer/buffer.py b/kafka/producer/buffer.py deleted file mode 100644 index 10080170..00000000 --- a/kafka/producer/buffer.py +++ /dev/null @@ -1,115 +0,0 @@ -from __future__ import absolute_import, division - -import collections -import io -import threading -import time - -from kafka.metrics.stats import Rate - -import kafka.errors as Errors - - -class SimpleBufferPool(object): - """A simple pool of BytesIO objects with a weak memory ceiling.""" - def __init__(self, memory, poolable_size, metrics=None, metric_group_prefix='producer-metrics'): - """Create a new buffer pool. - - Arguments: - memory (int): maximum memory that this buffer pool can allocate - poolable_size (int): memory size per buffer to cache in the free - list rather than deallocating - """ - self._poolable_size = poolable_size - self._lock = threading.RLock() - - buffers = int(memory / poolable_size) if poolable_size else 0 - self._free = collections.deque([io.BytesIO() for _ in range(buffers)]) - - self._waiters = collections.deque() - self.wait_time = None - if metrics: - self.wait_time = metrics.sensor('bufferpool-wait-time') - self.wait_time.add(metrics.metric_name( - 'bufferpool-wait-ratio', metric_group_prefix, - 'The fraction of time an appender waits for space allocation.'), - Rate()) - - def allocate(self, size, max_time_to_block_ms): - """ - Allocate a buffer of the given size. This method blocks if there is not - enough memory and the buffer pool is configured with blocking mode. - - Arguments: - size (int): The buffer size to allocate in bytes [ignored] - max_time_to_block_ms (int): The maximum time in milliseconds to - block for buffer memory to be available - - Returns: - io.BytesIO - """ - with self._lock: - # check if we have a free buffer of the right size pooled - if self._free: - return self._free.popleft() - - elif self._poolable_size == 0: - return io.BytesIO() - - else: - # we are out of buffers and will have to block - buf = None - more_memory = threading.Condition(self._lock) - self._waiters.append(more_memory) - # loop over and over until we have a buffer or have reserved - # enough memory to allocate one - while buf is None: - start_wait = time.time() - more_memory.wait(max_time_to_block_ms / 1000.0) - end_wait = time.time() - if self.wait_time: - self.wait_time.record(end_wait - start_wait) - - if self._free: - buf = self._free.popleft() - else: - self._waiters.remove(more_memory) - raise Errors.KafkaTimeoutError( - "Failed to allocate memory within the configured" - " max blocking time") - - # remove the condition for this thread to let the next thread - # in line start getting memory - removed = self._waiters.popleft() - assert removed is more_memory, 'Wrong condition' - - # signal any additional waiters if there is more memory left - # over for them - if self._free and self._waiters: - self._waiters[0].notify() - - # unlock and return the buffer - return buf - - def deallocate(self, buf): - """ - Return buffers to the pool. If they are of the poolable size add them - to the free list, otherwise just mark the memory as free. - - Arguments: - buffer_ (io.BytesIO): The buffer to return - """ - with self._lock: - # BytesIO.truncate here makes the pool somewhat pointless - # but we stick with the BufferPool API until migrating to - # bytesarray / memoryview. 
The buffer we return must not - # expose any prior data on read(). - buf.truncate(0) - self._free.append(buf) - if self._waiters: - self._waiters[0].notify() - - def queued(self): - """The number of threads blocked waiting on memory.""" - with self._lock: - return len(self._waiters) diff --git a/kafka/producer/future.py b/kafka/producer/future.py deleted file mode 100644 index 07fa4adb..00000000 --- a/kafka/producer/future.py +++ /dev/null @@ -1,71 +0,0 @@ -from __future__ import absolute_import - -import collections -import threading - -from kafka import errors as Errors -from kafka.future import Future - - -class FutureProduceResult(Future): - def __init__(self, topic_partition): - super(FutureProduceResult, self).__init__() - self.topic_partition = topic_partition - self._latch = threading.Event() - - def success(self, value): - ret = super(FutureProduceResult, self).success(value) - self._latch.set() - return ret - - def failure(self, error): - ret = super(FutureProduceResult, self).failure(error) - self._latch.set() - return ret - - def wait(self, timeout=None): - # wait() on python2.6 returns None instead of the flag value - return self._latch.wait(timeout) or self._latch.is_set() - - -class FutureRecordMetadata(Future): - def __init__(self, produce_future, relative_offset, timestamp_ms, checksum, serialized_key_size, serialized_value_size, serialized_header_size): - super(FutureRecordMetadata, self).__init__() - self._produce_future = produce_future - # packing args as a tuple is a minor speed optimization - self.args = (relative_offset, timestamp_ms, checksum, serialized_key_size, serialized_value_size, serialized_header_size) - produce_future.add_callback(self._produce_success) - produce_future.add_errback(self.failure) - - def _produce_success(self, offset_and_timestamp): - offset, produce_timestamp_ms, log_start_offset = offset_and_timestamp - - # Unpacking from args tuple is minor speed optimization - (relative_offset, timestamp_ms, checksum, - serialized_key_size, serialized_value_size, serialized_header_size) = self.args - - # None is when Broker does not support the API (<0.10) and - # -1 is when the broker is configured for CREATE_TIME timestamps - if produce_timestamp_ms is not None and produce_timestamp_ms != -1: - timestamp_ms = produce_timestamp_ms - if offset != -1 and relative_offset is not None: - offset += relative_offset - tp = self._produce_future.topic_partition - metadata = RecordMetadata(tp[0], tp[1], tp, offset, timestamp_ms, log_start_offset, - checksum, serialized_key_size, - serialized_value_size, serialized_header_size) - self.success(metadata) - - def get(self, timeout=None): - if not self.is_done and not self._produce_future.wait(timeout): - raise Errors.KafkaTimeoutError( - "Timeout after waiting for %s secs." 
% (timeout,)) - assert self.is_done - if self.failed(): - raise self.exception # pylint: disable-msg=raising-bad-type - return self.value - - -RecordMetadata = collections.namedtuple( - 'RecordMetadata', ['topic', 'partition', 'topic_partition', 'offset', 'timestamp', 'log_start_offset', - 'checksum', 'serialized_key_size', 'serialized_value_size', 'serialized_header_size']) diff --git a/kafka/producer/kafka.py b/kafka/producer/kafka.py deleted file mode 100644 index dd1cc508..00000000 --- a/kafka/producer/kafka.py +++ /dev/null @@ -1,752 +0,0 @@ -from __future__ import absolute_import - -import atexit -import copy -import logging -import socket -import threading -import time -import weakref - -from kafka.vendor import six - -import kafka.errors as Errors -from kafka.client_async import KafkaClient, selectors -from kafka.codec import has_gzip, has_snappy, has_lz4, has_zstd -from kafka.metrics import MetricConfig, Metrics -from kafka.partitioner.default import DefaultPartitioner -from kafka.producer.future import FutureRecordMetadata, FutureProduceResult -from kafka.producer.record_accumulator import AtomicInteger, RecordAccumulator -from kafka.producer.sender import Sender -from kafka.record.default_records import DefaultRecordBatchBuilder -from kafka.record.legacy_records import LegacyRecordBatchBuilder -from kafka.serializer import Serializer -from kafka.structs import TopicPartition - - -log = logging.getLogger(__name__) -PRODUCER_CLIENT_ID_SEQUENCE = AtomicInteger() - - -class KafkaProducer(object): - """A Kafka client that publishes records to the Kafka cluster. - - The producer is thread safe and sharing a single producer instance across - threads will generally be faster than having multiple instances. - - The producer consists of a pool of buffer space that holds records that - haven't yet been transmitted to the server as well as a background I/O - thread that is responsible for turning these records into requests and - transmitting them to the cluster. - - :meth:`~kafka.KafkaProducer.send` is asynchronous. When called it adds the - record to a buffer of pending record sends and immediately returns. This - allows the producer to batch together individual records for efficiency. - - The 'acks' config controls the criteria under which requests are considered - complete. The "all" setting will result in blocking on the full commit of - the record, the slowest but most durable setting. - - If the request fails, the producer can automatically retry, unless - 'retries' is configured to 0. Enabling retries also opens up the - possibility of duplicates (see the documentation on message - delivery semantics for details: - https://kafka.apache.org/documentation.html#semantics - ). - - The producer maintains buffers of unsent records for each partition. These - buffers are of a size specified by the 'batch_size' config. Making this - larger can result in more batching, but requires more memory (since we will - generally have one of these buffers for each active partition). - - By default a buffer is available to send immediately even if there is - additional unused space in the buffer. However if you want to reduce the - number of requests you can set 'linger_ms' to something greater than 0. - This will instruct the producer to wait up to that number of milliseconds - before sending a request in hope that more records will arrive to fill up - the same batch. This is analogous to Nagle's algorithm in TCP. 
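The asynchronous send model described above is used either fire-and-forget with callbacks or by blocking on the returned future. A minimal sketch, with the broker address and topic name as placeholders:

    from kafka import KafkaProducer

    producer = KafkaProducer(
        bootstrap_servers='localhost:9092',
        acks='all',      # wait for the full ISR, the most durable setting
        retries=3,       # retry transient errors (see the ordering caveat below)
        linger_ms=5,     # trade up to 5 ms of latency for better batching
    )

    # Non-blocking: returns a FutureRecordMetadata immediately.
    future = producer.send('my-topic', b'payload')
    future.add_callback(lambda md: print('sent %s[%d]@%d' % (md.topic, md.partition, md.offset)))
    future.add_errback(lambda exc: print('send failed:', exc))

    # ...or block for the result (raises on failure).
    metadata = producer.send('my-topic', b'another payload').get(timeout=10)

    producer.flush()
    producer.close()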
Note that - records that arrive close together in time will generally batch together - even with linger_ms=0 so under heavy load batching will occur regardless of - the linger configuration; however setting this to something larger than 0 - can lead to fewer, more efficient requests when not under maximal load at - the cost of a small amount of latency. - - The buffer_memory controls the total amount of memory available to the - producer for buffering. If records are sent faster than they can be - transmitted to the server then this buffer space will be exhausted. When - the buffer space is exhausted additional send calls will block. - - The key_serializer and value_serializer instruct how to turn the key and - value objects the user provides into bytes. - - Keyword Arguments: - bootstrap_servers: 'host[:port]' string (or list of 'host[:port]' - strings) that the producer should contact to bootstrap initial - cluster metadata. This does not have to be the full node list. - It just needs to have at least one broker that will respond to a - Metadata API Request. Default port is 9092. If no servers are - specified, will default to localhost:9092. - client_id (str): a name for this client. This string is passed in - each request to servers and can be used to identify specific - server-side log entries that correspond to this client. - Default: 'kafka-python-producer-#' (appended with a unique number - per instance) - key_serializer (callable): used to convert user-supplied keys to bytes - If not None, called as f(key), should return bytes. Default: None. - value_serializer (callable): used to convert user-supplied message - values to bytes. If not None, called as f(value), should return - bytes. Default: None. - acks (0, 1, 'all'): The number of acknowledgments the producer requires - the leader to have received before considering a request complete. - This controls the durability of records that are sent. The - following settings are common: - - 0: Producer will not wait for any acknowledgment from the server. - The message will immediately be added to the socket - buffer and considered sent. No guarantee can be made that the - server has received the record in this case, and the retries - configuration will not take effect (as the client won't - generally know of any failures). The offset given back for each - record will always be set to -1. - 1: Wait for leader to write the record to its local log only. - Broker will respond without awaiting full acknowledgement from - all followers. In this case should the leader fail immediately - after acknowledging the record but before the followers have - replicated it then the record will be lost. - all: Wait for the full set of in-sync replicas to write the record. - This guarantees that the record will not be lost as long as at - least one in-sync replica remains alive. This is the strongest - available guarantee. - If unset, defaults to acks=1. - compression_type (str): The compression type for all data generated by - the producer. Valid values are 'gzip', 'snappy', 'lz4', 'zstd' or None. - Compression is of full batches of data, so the efficacy of batching - will also impact the compression ratio (more batching means better - compression). Default: None. - retries (int): Setting a value greater than zero will cause the client - to resend any record whose send fails with a potentially transient - error. Note that this retry is no different than if the client - resent the record upon receiving the error. 
Allowing retries - without setting max_in_flight_requests_per_connection to 1 will - potentially change the ordering of records because if two batches - are sent to a single partition, and the first fails and is retried - but the second succeeds, then the records in the second batch may - appear first. - Default: 0. - batch_size (int): Requests sent to brokers will contain multiple - batches, one for each partition with data available to be sent. - A small batch size will make batching less common and may reduce - throughput (a batch size of zero will disable batching entirely). - Default: 16384 - linger_ms (int): The producer groups together any records that arrive - in between request transmissions into a single batched request. - Normally this occurs only under load when records arrive faster - than they can be sent out. However in some circumstances the client - may want to reduce the number of requests even under moderate load. - This setting accomplishes this by adding a small amount of - artificial delay; that is, rather than immediately sending out a - record the producer will wait for up to the given delay to allow - other records to be sent so that the sends can be batched together. - This can be thought of as analogous to Nagle's algorithm in TCP. - This setting gives the upper bound on the delay for batching: once - we get batch_size worth of records for a partition it will be sent - immediately regardless of this setting, however if we have fewer - than this many bytes accumulated for this partition we will - 'linger' for the specified time waiting for more records to show - up. This setting defaults to 0 (i.e. no delay). Setting linger_ms=5 - would have the effect of reducing the number of requests sent but - would add up to 5ms of latency to records sent in the absence of - load. Default: 0. - partitioner (callable): Callable used to determine which partition - each message is assigned to. Called (after key serialization): - partitioner(key_bytes, all_partitions, available_partitions). - The default partitioner implementation hashes each non-None key - using the same murmur2 algorithm as the java client so that - messages with the same key are assigned to the same partition. - When a key is None, the message is delivered to a random partition - (filtered to partitions with available leaders only, if possible). - buffer_memory (int): The total bytes of memory the producer should use - to buffer records waiting to be sent to the server. If records are - sent faster than they can be delivered to the server the producer - will block up to max_block_ms, raising an exception on timeout. - In the current implementation, this setting is an approximation. - Default: 33554432 (32MB) - connections_max_idle_ms: Close idle connections after the number of - milliseconds specified by this config. The broker closes idle - connections after connections.max.idle.ms, so this avoids hitting - unexpected socket disconnected errors on the client. - Default: 540000 - max_block_ms (int): Number of milliseconds to block during - :meth:`~kafka.KafkaProducer.send` and - :meth:`~kafka.KafkaProducer.partitions_for`. These methods can be - blocked either because the buffer is full or metadata unavailable. - Blocking in the user-supplied serializers or partitioner will not be - counted against this timeout. Default: 60000. - max_request_size (int): The maximum size of a request. This is also - effectively a cap on the maximum record size. 
Note that the server - has its own cap on record size which may be different from this. - This setting will limit the number of record batches the producer - will send in a single request to avoid sending huge requests. - Default: 1048576. - metadata_max_age_ms (int): The period of time in milliseconds after - which we force a refresh of metadata even if we haven't seen any - partition leadership changes to proactively discover any new - brokers or partitions. Default: 300000 - retry_backoff_ms (int): Milliseconds to backoff when retrying on - errors. Default: 100. - request_timeout_ms (int): Client request timeout in milliseconds. - Default: 30000. - receive_buffer_bytes (int): The size of the TCP receive buffer - (SO_RCVBUF) to use when reading data. Default: None (relies on - system defaults). Java client defaults to 32768. - send_buffer_bytes (int): The size of the TCP send buffer - (SO_SNDBUF) to use when sending data. Default: None (relies on - system defaults). Java client defaults to 131072. - socket_options (list): List of tuple-arguments to socket.setsockopt - to apply to broker connection sockets. Default: - [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)] - reconnect_backoff_ms (int): The amount of time in milliseconds to - wait before attempting to reconnect to a given host. - Default: 50. - reconnect_backoff_max_ms (int): The maximum amount of time in - milliseconds to backoff/wait when reconnecting to a broker that has - repeatedly failed to connect. If provided, the backoff per host - will increase exponentially for each consecutive connection - failure, up to this maximum. Once the maximum is reached, - reconnection attempts will continue periodically with this fixed - rate. To avoid connection storms, a randomization factor of 0.2 - will be applied to the backoff resulting in a random range between - 20% below and 20% above the computed value. Default: 1000. - max_in_flight_requests_per_connection (int): Requests are pipelined - to kafka brokers up to this number of maximum requests per - broker connection. Note that if this setting is set to be greater - than 1 and there are failed sends, there is a risk of message - re-ordering due to retries (i.e., if retries are enabled). - Default: 5. - security_protocol (str): Protocol used to communicate with brokers. - Valid values are: PLAINTEXT, SSL, SASL_PLAINTEXT, SASL_SSL. - Default: PLAINTEXT. - ssl_context (ssl.SSLContext): pre-configured SSLContext for wrapping - socket connections. If provided, all other ssl_* configurations - will be ignored. Default: None. - ssl_check_hostname (bool): flag to configure whether ssl handshake - should verify that the certificate matches the brokers hostname. - default: true. - ssl_cafile (str): optional filename of ca file to use in certificate - verification. default: none. - ssl_certfile (str): optional filename of file in pem format containing - the client certificate, as well as any ca certificates needed to - establish the certificate's authenticity. default: none. - ssl_keyfile (str): optional filename containing the client private key. - default: none. - ssl_password (str): optional password to be used when loading the - certificate chain. default: none. - ssl_crlfile (str): optional filename containing the CRL to check for - certificate expiration. By default, no CRL check is done. When - providing a file, only the leaf certificate will be checked against - this CRL. The CRL can only be checked with Python 3.4+ or 2.7.9+. - default: none. 
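The ssl_* settings above are normally used together. A hedged sketch; the host name and certificate paths are placeholders:

    from kafka import KafkaProducer

    producer = KafkaProducer(
        bootstrap_servers='broker.example.com:9093',
        security_protocol='SSL',
        ssl_cafile='/etc/ssl/certs/ca.pem',         # CA used to verify the broker certificate
        ssl_certfile='/etc/ssl/certs/client.pem',   # client certificate chain
        ssl_keyfile='/etc/ssl/private/client.key',  # client private key
        ssl_check_hostname=True,
        max_in_flight_requests_per_connection=1,    # preserve ordering if retries are enabled
    )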
- ssl_ciphers (str): optionally set the available ciphers for ssl - connections. It should be a string in the OpenSSL cipher list - format. If no cipher can be selected (because compile-time options - or other configuration forbids use of all the specified ciphers), - an ssl.SSLError will be raised. See ssl.SSLContext.set_ciphers - api_version (tuple): Specify which Kafka API version to use. If set to - None, the client will attempt to infer the broker version by probing - various APIs. Example: (0, 10, 2). Default: None - api_version_auto_timeout_ms (int): number of milliseconds to throw a - timeout exception from the constructor when checking the broker - api version. Only applies if api_version set to None. - metric_reporters (list): A list of classes to use as metrics reporters. - Implementing the AbstractMetricsReporter interface allows plugging - in classes that will be notified of new metric creation. Default: [] - metrics_num_samples (int): The number of samples maintained to compute - metrics. Default: 2 - metrics_sample_window_ms (int): The maximum age in milliseconds of - samples used to compute metrics. Default: 30000 - selector (selectors.BaseSelector): Provide a specific selector - implementation to use for I/O multiplexing. - Default: selectors.DefaultSelector - sasl_mechanism (str): Authentication mechanism when security_protocol - is configured for SASL_PLAINTEXT or SASL_SSL. Valid values are: - PLAIN, GSSAPI, OAUTHBEARER, SCRAM-SHA-256, SCRAM-SHA-512. - sasl_plain_username (str): username for sasl PLAIN and SCRAM authentication. - Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. - sasl_plain_password (str): password for sasl PLAIN and SCRAM authentication. - Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. - sasl_kerberos_service_name (str): Service name to include in GSSAPI - sasl mechanism handshake. Default: 'kafka' - sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI - sasl mechanism handshake. Default: one of bootstrap servers - sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider - instance. (See kafka.oauth.abstract). 
Default: None - kafka_client (callable): Custom class / callable for creating KafkaClient instances - - Note: - Configuration parameters are described in more detail at - https://kafka.apache.org/0100/documentation/#producerconfigs - """ - DEFAULT_CONFIG = { - 'bootstrap_servers': 'localhost', - 'client_id': None, - 'key_serializer': None, - 'value_serializer': None, - 'acks': 1, - 'bootstrap_topics_filter': set(), - 'compression_type': None, - 'retries': 0, - 'batch_size': 16384, - 'linger_ms': 0, - 'partitioner': DefaultPartitioner(), - 'buffer_memory': 33554432, - 'connections_max_idle_ms': 9 * 60 * 1000, - 'max_block_ms': 60000, - 'max_request_size': 1048576, - 'metadata_max_age_ms': 300000, - 'retry_backoff_ms': 100, - 'request_timeout_ms': 30000, - 'receive_buffer_bytes': None, - 'send_buffer_bytes': None, - 'socket_options': [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)], - 'sock_chunk_bytes': 4096, # undocumented experimental option - 'sock_chunk_buffer_count': 1000, # undocumented experimental option - 'reconnect_backoff_ms': 50, - 'reconnect_backoff_max_ms': 1000, - 'max_in_flight_requests_per_connection': 5, - 'security_protocol': 'PLAINTEXT', - 'ssl_context': None, - 'ssl_check_hostname': True, - 'ssl_cafile': None, - 'ssl_certfile': None, - 'ssl_keyfile': None, - 'ssl_crlfile': None, - 'ssl_password': None, - 'ssl_ciphers': None, - 'api_version': None, - 'api_version_auto_timeout_ms': 2000, - 'metric_reporters': [], - 'metrics_num_samples': 2, - 'metrics_sample_window_ms': 30000, - 'selector': selectors.DefaultSelector, - 'sasl_mechanism': None, - 'sasl_plain_username': None, - 'sasl_plain_password': None, - 'sasl_kerberos_service_name': 'kafka', - 'sasl_kerberos_domain_name': None, - 'sasl_oauth_token_provider': None, - 'kafka_client': KafkaClient, - } - - _COMPRESSORS = { - 'gzip': (has_gzip, LegacyRecordBatchBuilder.CODEC_GZIP), - 'snappy': (has_snappy, LegacyRecordBatchBuilder.CODEC_SNAPPY), - 'lz4': (has_lz4, LegacyRecordBatchBuilder.CODEC_LZ4), - 'zstd': (has_zstd, DefaultRecordBatchBuilder.CODEC_ZSTD), - None: (lambda: True, LegacyRecordBatchBuilder.CODEC_NONE), - } - - def __init__(self, **configs): - log.debug("Starting the Kafka producer") # trace - self.config = copy.copy(self.DEFAULT_CONFIG) - for key in self.config: - if key in configs: - self.config[key] = configs.pop(key) - - # Only check for extra config keys in top-level class - assert not configs, 'Unrecognized configs: %s' % (configs,) - - if self.config['client_id'] is None: - self.config['client_id'] = 'kafka-python-producer-%s' % \ - (PRODUCER_CLIENT_ID_SEQUENCE.increment(),) - - if self.config['acks'] == 'all': - self.config['acks'] = -1 - - # api_version was previously a str. 
accept old format for now - if isinstance(self.config['api_version'], str): - deprecated = self.config['api_version'] - if deprecated == 'auto': - self.config['api_version'] = None - else: - self.config['api_version'] = tuple(map(int, deprecated.split('.'))) - log.warning('use api_version=%s [tuple] -- "%s" as str is deprecated', - str(self.config['api_version']), deprecated) - - # Configure metrics - metrics_tags = {'client-id': self.config['client_id']} - metric_config = MetricConfig(samples=self.config['metrics_num_samples'], - time_window_ms=self.config['metrics_sample_window_ms'], - tags=metrics_tags) - reporters = [reporter() for reporter in self.config['metric_reporters']] - self._metrics = Metrics(metric_config, reporters) - - client = self.config['kafka_client']( - metrics=self._metrics, metric_group_prefix='producer', - wakeup_timeout_ms=self.config['max_block_ms'], - **self.config) - - # Get auto-discovered version from client if necessary - if self.config['api_version'] is None: - self.config['api_version'] = client.config['api_version'] - - if self.config['compression_type'] == 'lz4': - assert self.config['api_version'] >= (0, 8, 2), 'LZ4 Requires >= Kafka 0.8.2 Brokers' - - if self.config['compression_type'] == 'zstd': - assert self.config['api_version'] >= (2, 1, 0), 'Zstd Requires >= Kafka 2.1.0 Brokers' - - # Check compression_type for library support - ct = self.config['compression_type'] - if ct not in self._COMPRESSORS: - raise ValueError("Not supported codec: {}".format(ct)) - else: - checker, compression_attrs = self._COMPRESSORS[ct] - assert checker(), "Libraries for {} compression codec not found".format(ct) - self.config['compression_attrs'] = compression_attrs - - message_version = self._max_usable_produce_magic() - self._accumulator = RecordAccumulator(message_version=message_version, metrics=self._metrics, **self.config) - self._metadata = client.cluster - guarantee_message_order = bool(self.config['max_in_flight_requests_per_connection'] == 1) - self._sender = Sender(client, self._metadata, - self._accumulator, self._metrics, - guarantee_message_order=guarantee_message_order, - **self.config) - self._sender.daemon = True - self._sender.start() - self._closed = False - - self._cleanup = self._cleanup_factory() - atexit.register(self._cleanup) - log.debug("Kafka producer started") - - def bootstrap_connected(self): - """Return True if the bootstrap is connected.""" - return self._sender.bootstrap_connected() - - def _cleanup_factory(self): - """Build a cleanup clojure that doesn't increase our ref count""" - _self = weakref.proxy(self) - def wrapper(): - try: - _self.close(timeout=0) - except (ReferenceError, AttributeError): - pass - return wrapper - - def _unregister_cleanup(self): - if getattr(self, '_cleanup', None): - if hasattr(atexit, 'unregister'): - atexit.unregister(self._cleanup) # pylint: disable=no-member - - # py2 requires removing from private attribute... - else: - - # ValueError on list.remove() if the exithandler no longer exists - # but that is fine here - try: - atexit._exithandlers.remove( # pylint: disable=no-member - (self._cleanup, (), {})) - except ValueError: - pass - self._cleanup = None - - def __del__(self): - # Disable logger during destruction to avoid touching dangling references - class NullLogger(object): - def __getattr__(self, name): - return lambda *args: None - - global log - log = NullLogger() - - self.close() - - def close(self, timeout=None): - """Close this producer. 
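As the constructor code above shows, compression_type is checked against the locally installed codec library (and, for lz4/zstd, the broker version). A short sketch; the broker, topic and the choice of 'snappy' are placeholders, and python-snappy must be installed:

    from kafka import KafkaProducer

    # Construction fails with AssertionError if the snappy library is missing.
    producer = KafkaProducer(
        bootstrap_servers='localhost:9092',
        compression_type='snappy',
    )
    producer.send('my-topic', b'compressed payload')
    producer.close(timeout=10)   # give in-flight batches up to 10 s before force-closing the sender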
- - Arguments: - timeout (float, optional): timeout in seconds to wait for completion. - """ - - # drop our atexit handler now to avoid leaks - self._unregister_cleanup() - - if not hasattr(self, '_closed') or self._closed: - log.info('Kafka producer closed') - return - if timeout is None: - # threading.TIMEOUT_MAX is available in Python3.3+ - timeout = getattr(threading, 'TIMEOUT_MAX', float('inf')) - if getattr(threading, 'TIMEOUT_MAX', False): - assert 0 <= timeout <= getattr(threading, 'TIMEOUT_MAX') - else: - assert timeout >= 0 - - log.info("Closing the Kafka producer with %s secs timeout.", timeout) - invoked_from_callback = bool(threading.current_thread() is self._sender) - if timeout > 0: - if invoked_from_callback: - log.warning("Overriding close timeout %s secs to 0 in order to" - " prevent useless blocking due to self-join. This" - " means you have incorrectly invoked close with a" - " non-zero timeout from the producer call-back.", - timeout) - else: - # Try to close gracefully. - if self._sender is not None: - self._sender.initiate_close() - self._sender.join(timeout) - - if self._sender is not None and self._sender.is_alive(): - log.info("Proceeding to force close the producer since pending" - " requests could not be completed within timeout %s.", - timeout) - self._sender.force_close() - - self._metrics.close() - try: - self.config['key_serializer'].close() - except AttributeError: - pass - try: - self.config['value_serializer'].close() - except AttributeError: - pass - self._closed = True - log.debug("The Kafka producer has closed.") - - def partitions_for(self, topic): - """Returns set of all known partitions for the topic.""" - max_wait = self.config['max_block_ms'] / 1000.0 - return self._wait_on_metadata(topic, max_wait) - - def _max_usable_produce_magic(self): - if self.config['api_version'] >= (0, 11): - return 2 - elif self.config['api_version'] >= (0, 10): - return 1 - else: - return 0 - - def _estimate_size_in_bytes(self, key, value, headers=[]): - magic = self._max_usable_produce_magic() - if magic == 2: - return DefaultRecordBatchBuilder.estimate_size_in_bytes( - key, value, headers) - else: - return LegacyRecordBatchBuilder.estimate_size_in_bytes( - magic, self.config['compression_type'], key, value) - - def send(self, topic, value=None, key=None, headers=None, partition=None, timestamp_ms=None): - """Publish a message to a topic. - - Arguments: - topic (str): topic where the message will be published - value (optional): message value. Must be type bytes, or be - serializable to bytes via configured value_serializer. If value - is None, key is required and message acts as a 'delete'. - See kafka compaction documentation for more details: - https://kafka.apache.org/documentation.html#compaction - (compaction requires kafka >= 0.8.1) - partition (int, optional): optionally specify a partition. If not - set, the partition will be selected using the configured - 'partitioner'. - key (optional): a key to associate with the message. Can be used to - determine which partition to send the message to. If partition - is None (and producer's partitioner config is left as default), - then messages with the same key will be delivered to the same - partition (but if key is None, partition is chosen randomly). - Must be type bytes, or be serializable to bytes via configured - key_serializer. - headers (optional): a list of header key value pairs. List items - are tuples of str key and bytes value. 
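Putting the send() arguments above together; the serializers, topic name and header values are illustrative, and headers need brokers >= 0.11 (message format v2):

    import json
    from kafka import KafkaProducer

    producer = KafkaProducer(
        bootstrap_servers='localhost:9092',
        key_serializer=str.encode,                          # str key -> bytes
        value_serializer=lambda v: json.dumps(v).encode(),  # dict value -> bytes
    )

    future = producer.send(
        'my-topic',
        value={'event': 'signup', 'user': 42},
        key='user-42',                       # same key -> same partition with the default partitioner
        headers=[('trace-id', b'abc123')],   # list of (str, bytes) tuples
    )
    metadata = future.get(timeout=10)
    print(metadata.topic, metadata.partition, metadata.offset)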
- timestamp_ms (int, optional): epoch milliseconds (from Jan 1 1970 UTC) - to use as the message timestamp. Defaults to current time. - - Returns: - FutureRecordMetadata: resolves to RecordMetadata - - Raises: - KafkaTimeoutError: if unable to fetch topic metadata, or unable - to obtain memory buffer prior to configured max_block_ms - """ - assert value is not None or self.config['api_version'] >= (0, 8, 1), ( - 'Null messages require kafka >= 0.8.1') - assert not (value is None and key is None), 'Need at least one: key or value' - key_bytes = value_bytes = None - try: - self._wait_on_metadata(topic, self.config['max_block_ms'] / 1000.0) - - key_bytes = self._serialize( - self.config['key_serializer'], - topic, key) - value_bytes = self._serialize( - self.config['value_serializer'], - topic, value) - assert type(key_bytes) in (bytes, bytearray, memoryview, type(None)) - assert type(value_bytes) in (bytes, bytearray, memoryview, type(None)) - - partition = self._partition(topic, partition, key, value, - key_bytes, value_bytes) - - if headers is None: - headers = [] - assert type(headers) == list - assert all(type(item) == tuple and len(item) == 2 and type(item[0]) == str and type(item[1]) == bytes for item in headers) - - message_size = self._estimate_size_in_bytes(key_bytes, value_bytes, headers) - self._ensure_valid_record_size(message_size) - - tp = TopicPartition(topic, partition) - log.debug("Sending (key=%r value=%r headers=%r) to %s", key, value, headers, tp) - result = self._accumulator.append(tp, timestamp_ms, - key_bytes, value_bytes, headers, - self.config['max_block_ms'], - estimated_size=message_size) - future, batch_is_full, new_batch_created = result - if batch_is_full or new_batch_created: - log.debug("Waking up the sender since %s is either full or" - " getting a new batch", tp) - self._sender.wakeup() - - return future - # handling exceptions and record the errors; - # for API exceptions return them in the future, - # for other exceptions raise directly - except Errors.BrokerResponseError as e: - log.debug("Exception occurred during message send: %s", e) - return FutureRecordMetadata( - FutureProduceResult(TopicPartition(topic, partition)), - -1, None, None, - len(key_bytes) if key_bytes is not None else -1, - len(value_bytes) if value_bytes is not None else -1, - sum(len(h_key.encode("utf-8")) + len(h_value) for h_key, h_value in headers) if headers else -1, - ).failure(e) - - def flush(self, timeout=None): - """ - Invoking this method makes all buffered records immediately available - to send (even if linger_ms is greater than 0) and blocks on the - completion of the requests associated with these records. The - post-condition of :meth:`~kafka.KafkaProducer.flush` is that any - previously sent record will have completed - (e.g. Future.is_done() == True). A request is considered completed when - either it is successfully acknowledged according to the 'acks' - configuration for the producer, or it results in an error. - - Other threads can continue sending messages while one thread is blocked - waiting for a flush call to complete; however, no guarantee is made - about the completion of messages sent after the flush call begins. - - Arguments: - timeout (float, optional): timeout in seconds to wait for completion. 
- - Raises: - KafkaTimeoutError: failure to flush buffered records within the - provided timeout - """ - log.debug("Flushing accumulated records in producer.") # trace - self._accumulator.begin_flush() - self._sender.wakeup() - self._accumulator.await_flush_completion(timeout=timeout) - - def _ensure_valid_record_size(self, size): - """Validate that the record size isn't too large.""" - if size > self.config['max_request_size']: - raise Errors.MessageSizeTooLargeError( - "The message is %d bytes when serialized which is larger than" - " the maximum request size you have configured with the" - " max_request_size configuration" % (size,)) - if size > self.config['buffer_memory']: - raise Errors.MessageSizeTooLargeError( - "The message is %d bytes when serialized which is larger than" - " the total memory buffer you have configured with the" - " buffer_memory configuration." % (size,)) - - def _wait_on_metadata(self, topic, max_wait): - """ - Wait for cluster metadata including partitions for the given topic to - be available. - - Arguments: - topic (str): topic we want metadata for - max_wait (float): maximum time in secs for waiting on the metadata - - Returns: - set: partition ids for the topic - - Raises: - KafkaTimeoutError: if partitions for topic were not obtained before - specified max_wait timeout - """ - # add topic to metadata topic list if it is not there already. - self._sender.add_topic(topic) - begin = time.time() - elapsed = 0.0 - metadata_event = None - while True: - partitions = self._metadata.partitions_for_topic(topic) - if partitions is not None: - return partitions - - if not metadata_event: - metadata_event = threading.Event() - - log.debug("Requesting metadata update for topic %s", topic) - - metadata_event.clear() - future = self._metadata.request_update() - future.add_both(lambda e, *args: e.set(), metadata_event) - self._sender.wakeup() - metadata_event.wait(max_wait - elapsed) - elapsed = time.time() - begin - if not metadata_event.is_set(): - raise Errors.KafkaTimeoutError( - "Failed to update metadata after %.1f secs." % (max_wait,)) - elif topic in self._metadata.unauthorized_topics: - raise Errors.TopicAuthorizationFailedError(topic) - else: - log.debug("_wait_on_metadata woke after %s secs.", elapsed) - - def _serialize(self, f, topic, data): - if not f: - return data - if isinstance(f, Serializer): - return f.serialize(topic, data) - return f(data) - - def _partition(self, topic, partition, key, value, - serialized_key, serialized_value): - if partition is not None: - assert partition >= 0 - assert partition in self._metadata.partitions_for_topic(topic), 'Unrecognized partition' - return partition - - all_partitions = sorted(self._metadata.partitions_for_topic(topic)) - available = list(self._metadata.available_partitions_for_topic(topic)) - return self.config['partitioner'](serialized_key, - all_partitions, - available) - - def metrics(self, raw=False): - """Get metrics on producer performance. - - This is ported from the Java Producer, for details see: - https://kafka.apache.org/documentation/#producer_monitoring - - Warning: - This is an unstable interface. It may change in future - releases without warning. 
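flush() is the usual way to bound how long buffered records may sit before the program moves on or exits, and metrics() returns the same nested group/name structure built in the code that follows. A sketch with placeholder broker and topic:

    from kafka import KafkaProducer

    producer = KafkaProducer(bootstrap_servers='localhost:9092', linger_ms=50)
    for i in range(1000):
        producer.send('my-topic', ('message %d' % i).encode())

    # Block until every record sent so far is acknowledged or has failed.
    producer.flush(timeout=30)

    # {metric_group: {metric_name: value}}
    for group, values in producer.metrics().items():
        print(group, values)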
- """ - if raw: - return self._metrics.metrics.copy() - - metrics = {} - for k, v in six.iteritems(self._metrics.metrics.copy()): - if k.group not in metrics: - metrics[k.group] = {} - if k.name not in metrics[k.group]: - metrics[k.group][k.name] = {} - metrics[k.group][k.name] = v.value() - return metrics diff --git a/kafka/producer/record_accumulator.py b/kafka/producer/record_accumulator.py deleted file mode 100644 index a2aa0e8e..00000000 --- a/kafka/producer/record_accumulator.py +++ /dev/null @@ -1,590 +0,0 @@ -from __future__ import absolute_import - -import collections -import copy -import logging -import threading -import time - -import kafka.errors as Errors -from kafka.producer.buffer import SimpleBufferPool -from kafka.producer.future import FutureRecordMetadata, FutureProduceResult -from kafka.record.memory_records import MemoryRecordsBuilder -from kafka.structs import TopicPartition - - -log = logging.getLogger(__name__) - - -class AtomicInteger(object): - def __init__(self, val=0): - self._lock = threading.Lock() - self._val = val - - def increment(self): - with self._lock: - self._val += 1 - return self._val - - def decrement(self): - with self._lock: - self._val -= 1 - return self._val - - def get(self): - return self._val - - -class ProducerBatch(object): - def __init__(self, tp, records, buffer): - self.max_record_size = 0 - now = time.time() - self.created = now - self.drained = None - self.attempts = 0 - self.last_attempt = now - self.last_append = now - self.records = records - self.topic_partition = tp - self.produce_future = FutureProduceResult(tp) - self._retry = False - self._buffer = buffer # We only save it, we don't write to it - - @property - def record_count(self): - return self.records.next_offset() - - def try_append(self, timestamp_ms, key, value, headers): - metadata = self.records.append(timestamp_ms, key, value, headers) - if metadata is None: - return None - - self.max_record_size = max(self.max_record_size, metadata.size) - self.last_append = time.time() - future = FutureRecordMetadata(self.produce_future, metadata.offset, - metadata.timestamp, metadata.crc, - len(key) if key is not None else -1, - len(value) if value is not None else -1, - sum(len(h_key.encode("utf-8")) + len(h_val) for h_key, h_val in headers) if headers else -1) - return future - - def done(self, base_offset=None, timestamp_ms=None, exception=None, log_start_offset=None, global_error=None): - level = logging.DEBUG if exception is None else logging.WARNING - log.log(level, "Produced messages to topic-partition %s with base offset" - " %s log start offset %s and error %s.", self.topic_partition, base_offset, - log_start_offset, global_error) # trace - if self.produce_future.is_done: - log.warning('Batch is already closed -- ignoring batch.done()') - return - elif exception is None: - self.produce_future.success((base_offset, timestamp_ms, log_start_offset)) - else: - self.produce_future.failure(exception) - - def maybe_expire(self, request_timeout_ms, retry_backoff_ms, linger_ms, is_full): - """Expire batches if metadata is not available - - A batch whose metadata is not available should be expired if one - of the following is true: - - * the batch is not in retry AND request timeout has elapsed after - it is ready (full or linger.ms has reached). - - * the batch is in retry AND request timeout has elapsed after the - backoff period ended. 
- """ - now = time.time() - since_append = now - self.last_append - since_ready = now - (self.created + linger_ms / 1000.0) - since_backoff = now - (self.last_attempt + retry_backoff_ms / 1000.0) - timeout = request_timeout_ms / 1000.0 - - error = None - if not self.in_retry() and is_full and timeout < since_append: - error = "%d seconds have passed since last append" % (since_append,) - elif not self.in_retry() and timeout < since_ready: - error = "%d seconds have passed since batch creation plus linger time" % (since_ready,) - elif self.in_retry() and timeout < since_backoff: - error = "%d seconds have passed since last attempt plus backoff time" % (since_backoff,) - - if error: - self.records.close() - self.done(-1, None, Errors.KafkaTimeoutError( - "Batch for %s containing %s record(s) expired: %s" % ( - self.topic_partition, self.records.next_offset(), error))) - return True - return False - - def in_retry(self): - return self._retry - - def set_retry(self): - self._retry = True - - def buffer(self): - return self._buffer - - def __str__(self): - return 'ProducerBatch(topic_partition=%s, record_count=%d)' % ( - self.topic_partition, self.records.next_offset()) - - -class RecordAccumulator(object): - """ - This class maintains a dequeue per TopicPartition that accumulates messages - into MessageSets to be sent to the server. - - The accumulator attempts to bound memory use, and append calls will block - when that memory is exhausted. - - Keyword Arguments: - batch_size (int): Requests sent to brokers will contain multiple - batches, one for each partition with data available to be sent. - A small batch size will make batching less common and may reduce - throughput (a batch size of zero will disable batching entirely). - Default: 16384 - buffer_memory (int): The total bytes of memory the producer should use - to buffer records waiting to be sent to the server. If records are - sent faster than they can be delivered to the server the producer - will block up to max_block_ms, raising an exception on timeout. - In the current implementation, this setting is an approximation. - Default: 33554432 (32MB) - compression_attrs (int): The compression type for all data generated by - the producer. Valid values are gzip(1), snappy(2), lz4(3), or - none(0). - Compression is of full batches of data, so the efficacy of batching - will also impact the compression ratio (more batching means better - compression). Default: None. - linger_ms (int): An artificial delay time to add before declaring a - messageset (that isn't full) ready for sending. This allows - time for more records to arrive. Setting a non-zero linger_ms - will trade off some latency for potentially better throughput - due to more batching (and hence fewer, larger requests). - Default: 0 - retry_backoff_ms (int): An artificial delay time to retry the - produce request upon receiving an error. This avoids exhausting - all retries in a short period of time. 
Default: 100 - """ - DEFAULT_CONFIG = { - 'buffer_memory': 33554432, - 'batch_size': 16384, - 'compression_attrs': 0, - 'linger_ms': 0, - 'retry_backoff_ms': 100, - 'message_version': 0, - 'metrics': None, - 'metric_group_prefix': 'producer-metrics', - } - - def __init__(self, **configs): - self.config = copy.copy(self.DEFAULT_CONFIG) - for key in self.config: - if key in configs: - self.config[key] = configs.pop(key) - - self._closed = False - self._flushes_in_progress = AtomicInteger() - self._appends_in_progress = AtomicInteger() - self._batches = collections.defaultdict(collections.deque) # TopicPartition: [ProducerBatch] - self._tp_locks = {None: threading.Lock()} # TopicPartition: Lock, plus a lock to add entries - self._free = SimpleBufferPool(self.config['buffer_memory'], - self.config['batch_size'], - metrics=self.config['metrics'], - metric_group_prefix=self.config['metric_group_prefix']) - self._incomplete = IncompleteProducerBatches() - # The following variables should only be accessed by the sender thread, - # so we don't need to protect them w/ locking. - self.muted = set() - self._drain_index = 0 - - def append(self, tp, timestamp_ms, key, value, headers, max_time_to_block_ms, - estimated_size=0): - """Add a record to the accumulator, return the append result. - - The append result will contain the future metadata, and flag for - whether the appended batch is full or a new batch is created - - Arguments: - tp (TopicPartition): The topic/partition to which this record is - being sent - timestamp_ms (int): The timestamp of the record (epoch ms) - key (bytes): The key for the record - value (bytes): The value for the record - headers (List[Tuple[str, bytes]]): The header fields for the record - max_time_to_block_ms (int): The maximum time in milliseconds to - block for buffer memory to be available - - Returns: - tuple: (future, batch_is_full, new_batch_created) - """ - assert isinstance(tp, TopicPartition), 'not TopicPartition' - assert not self._closed, 'RecordAccumulator is closed' - # We keep track of the number of appending thread to make sure we do - # not miss batches in abortIncompleteBatches(). - self._appends_in_progress.increment() - try: - if tp not in self._tp_locks: - with self._tp_locks[None]: - if tp not in self._tp_locks: - self._tp_locks[tp] = threading.Lock() - - with self._tp_locks[tp]: - # check if we have an in-progress batch - dq = self._batches[tp] - if dq: - last = dq[-1] - future = last.try_append(timestamp_ms, key, value, headers) - if future is not None: - batch_is_full = len(dq) > 1 or last.records.is_full() - return future, batch_is_full, False - - size = max(self.config['batch_size'], estimated_size) - log.debug("Allocating a new %d byte message buffer for %s", size, tp) # trace - buf = self._free.allocate(size, max_time_to_block_ms) - with self._tp_locks[tp]: - # Need to check if producer is closed again after grabbing the - # dequeue lock. - assert not self._closed, 'RecordAccumulator is closed' - - if dq: - last = dq[-1] - future = last.try_append(timestamp_ms, key, value, headers) - if future is not None: - # Somebody else found us a batch, return the one we - # waited for! Hopefully this doesn't happen often... 
- self._free.deallocate(buf) - batch_is_full = len(dq) > 1 or last.records.is_full() - return future, batch_is_full, False - - records = MemoryRecordsBuilder( - self.config['message_version'], - self.config['compression_attrs'], - self.config['batch_size'] - ) - - batch = ProducerBatch(tp, records, buf) - future = batch.try_append(timestamp_ms, key, value, headers) - if not future: - raise Exception() - - dq.append(batch) - self._incomplete.add(batch) - batch_is_full = len(dq) > 1 or batch.records.is_full() - return future, batch_is_full, True - finally: - self._appends_in_progress.decrement() - - def abort_expired_batches(self, request_timeout_ms, cluster): - """Abort the batches that have been sitting in RecordAccumulator for - more than the configured request_timeout due to metadata being - unavailable. - - Arguments: - request_timeout_ms (int): milliseconds to timeout - cluster (ClusterMetadata): current metadata for kafka cluster - - Returns: - list of ProducerBatch that were expired - """ - expired_batches = [] - to_remove = [] - count = 0 - for tp in list(self._batches.keys()): - assert tp in self._tp_locks, 'TopicPartition not in locks dict' - - # We only check if the batch should be expired if the partition - # does not have a batch in flight. This is to avoid the later - # batches get expired when an earlier batch is still in progress. - # This protection only takes effect when user sets - # max.in.flight.request.per.connection=1. Otherwise the expiration - # order is not guranteed. - if tp in self.muted: - continue - - with self._tp_locks[tp]: - # iterate over the batches and expire them if they have stayed - # in accumulator for more than request_timeout_ms - dq = self._batches[tp] - for batch in dq: - is_full = bool(bool(batch != dq[-1]) or batch.records.is_full()) - # check if the batch is expired - if batch.maybe_expire(request_timeout_ms, - self.config['retry_backoff_ms'], - self.config['linger_ms'], - is_full): - expired_batches.append(batch) - to_remove.append(batch) - count += 1 - self.deallocate(batch) - else: - # Stop at the first batch that has not expired. - break - - # Python does not allow us to mutate the dq during iteration - # Assuming expired batches are infrequent, this is better than - # creating a new copy of the deque for iteration on every loop - if to_remove: - for batch in to_remove: - dq.remove(batch) - to_remove = [] - - if expired_batches: - log.warning("Expired %d batches in accumulator", count) # trace - - return expired_batches - - def reenqueue(self, batch): - """Re-enqueue the given record batch in the accumulator to retry.""" - now = time.time() - batch.attempts += 1 - batch.last_attempt = now - batch.last_append = now - batch.set_retry() - assert batch.topic_partition in self._tp_locks, 'TopicPartition not in locks dict' - assert batch.topic_partition in self._batches, 'TopicPartition not in batches' - dq = self._batches[batch.topic_partition] - with self._tp_locks[batch.topic_partition]: - dq.appendleft(batch) - - def ready(self, cluster): - """ - Get a list of nodes whose partitions are ready to be sent, and the - earliest time at which any non-sendable partition will be ready; - Also return the flag for whether there are any unknown leaders for the - accumulated partition batches. 
- - A destination node is ready to send if: - - * There is at least one partition that is not backing off its send - * and those partitions are not muted (to prevent reordering if - max_in_flight_requests_per_connection is set to 1) - * and any of the following are true: - - * The record set is full - * The record set has sat in the accumulator for at least linger_ms - milliseconds - * The accumulator is out of memory and threads are blocking waiting - for data (in this case all partitions are immediately considered - ready). - * The accumulator has been closed - - Arguments: - cluster (ClusterMetadata): - - Returns: - tuple: - ready_nodes (set): node_ids that have ready batches - next_ready_check (float): secs until next ready after backoff - unknown_leaders_exist (bool): True if metadata refresh needed - """ - ready_nodes = set() - next_ready_check = 9999999.99 - unknown_leaders_exist = False - now = time.time() - - exhausted = bool(self._free.queued() > 0) - # several threads are accessing self._batches -- to simplify - # concurrent access, we iterate over a snapshot of partitions - # and lock each partition separately as needed - partitions = list(self._batches.keys()) - for tp in partitions: - leader = cluster.leader_for_partition(tp) - if leader is None or leader == -1: - unknown_leaders_exist = True - continue - elif leader in ready_nodes: - continue - elif tp in self.muted: - continue - - with self._tp_locks[tp]: - dq = self._batches[tp] - if not dq: - continue - batch = dq[0] - retry_backoff = self.config['retry_backoff_ms'] / 1000.0 - linger = self.config['linger_ms'] / 1000.0 - backing_off = bool(batch.attempts > 0 and - batch.last_attempt + retry_backoff > now) - waited_time = now - batch.last_attempt - time_to_wait = retry_backoff if backing_off else linger - time_left = max(time_to_wait - waited_time, 0) - full = bool(len(dq) > 1 or batch.records.is_full()) - expired = bool(waited_time >= time_to_wait) - - sendable = (full or expired or exhausted or self._closed or - self._flush_in_progress()) - - if sendable and not backing_off: - ready_nodes.add(leader) - else: - # Note that this results in a conservative estimate since - # an un-sendable partition may have a leader that will - # later be found to have sendable data. However, this is - # good enough since we'll just wake up and then sleep again - # for the remaining time. - next_ready_check = min(time_left, next_ready_check) - - return ready_nodes, next_ready_check, unknown_leaders_exist - - def has_unsent(self): - """Return whether there is any unsent record in the accumulator.""" - for tp in list(self._batches.keys()): - with self._tp_locks[tp]: - dq = self._batches[tp] - if len(dq): - return True - return False - - def drain(self, cluster, nodes, max_size): - """ - Drain all the data for the given nodes and collate them into a list of - batches that will fit within the specified size on a per-node basis. - This method attempts to avoid choosing the same topic-node repeatedly. - - Arguments: - cluster (ClusterMetadata): The current cluster metadata - nodes (list): list of node_ids to drain - max_size (int): maximum number of bytes to drain - - Returns: - dict: {node_id: list of ProducerBatch} with total size less than the - requested max_size. 
- """ - if not nodes: - return {} - - now = time.time() - batches = {} - for node_id in nodes: - size = 0 - partitions = list(cluster.partitions_for_broker(node_id)) - ready = [] - # to make starvation less likely this loop doesn't start at 0 - self._drain_index %= len(partitions) - start = self._drain_index - while True: - tp = partitions[self._drain_index] - if tp in self._batches and tp not in self.muted: - with self._tp_locks[tp]: - dq = self._batches[tp] - if dq: - first = dq[0] - backoff = ( - bool(first.attempts > 0) and - bool(first.last_attempt + - self.config['retry_backoff_ms'] / 1000.0 - > now) - ) - # Only drain the batch if it is not during backoff - if not backoff: - if (size + first.records.size_in_bytes() > max_size - and len(ready) > 0): - # there is a rare case that a single batch - # size is larger than the request size due - # to compression; in this case we will - # still eventually send this batch in a - # single request - break - else: - batch = dq.popleft() - batch.records.close() - size += batch.records.size_in_bytes() - ready.append(batch) - batch.drained = now - - self._drain_index += 1 - self._drain_index %= len(partitions) - if start == self._drain_index: - break - - batches[node_id] = ready - return batches - - def deallocate(self, batch): - """Deallocate the record batch.""" - self._incomplete.remove(batch) - self._free.deallocate(batch.buffer()) - - def _flush_in_progress(self): - """Are there any threads currently waiting on a flush?""" - return self._flushes_in_progress.get() > 0 - - def begin_flush(self): - """ - Initiate the flushing of data from the accumulator...this makes all - requests immediately ready - """ - self._flushes_in_progress.increment() - - def await_flush_completion(self, timeout=None): - """ - Mark all partitions as ready to send and block until the send is complete - """ - try: - for batch in self._incomplete.all(): - log.debug('Waiting on produce to %s', - batch.produce_future.topic_partition) - if not batch.produce_future.wait(timeout=timeout): - raise Errors.KafkaTimeoutError('Timeout waiting for future') - if not batch.produce_future.is_done: - raise Errors.UnknownError('Future not done') - - if batch.produce_future.failed(): - log.warning(batch.produce_future.exception) - finally: - self._flushes_in_progress.decrement() - - def abort_incomplete_batches(self): - """ - This function is only called when sender is closed forcefully. It will fail all the - incomplete batches and return. - """ - # We need to keep aborting the incomplete batch until no thread is trying to append to - # 1. Avoid losing batches. - # 2. Free up memory in case appending threads are blocked on buffer full. - # This is a tight loop but should be able to get through very quickly. - while True: - self._abort_batches() - if not self._appends_in_progress.get(): - break - # After this point, no thread will append any messages because they will see the close - # flag set. We need to do the last abort after no thread was appending in case the there was a new - # batch appended by the last appending thread. 
- self._abort_batches() - self._batches.clear() - - def _abort_batches(self): - """Go through incomplete batches and abort them.""" - error = Errors.IllegalStateError("Producer is closed forcefully.") - for batch in self._incomplete.all(): - tp = batch.topic_partition - # Close the batch before aborting - with self._tp_locks[tp]: - batch.records.close() - batch.done(exception=error) - self.deallocate(batch) - - def close(self): - """Close this accumulator and force all the record buffers to be drained.""" - self._closed = True - - -class IncompleteProducerBatches(object): - """A threadsafe helper class to hold ProducerBatches that haven't been ack'd yet""" - - def __init__(self): - self._incomplete = set() - self._lock = threading.Lock() - - def add(self, batch): - with self._lock: - return self._incomplete.add(batch) - - def remove(self, batch): - with self._lock: - return self._incomplete.remove(batch) - - def all(self): - with self._lock: - return list(self._incomplete) diff --git a/kafka/producer/sender.py b/kafka/producer/sender.py deleted file mode 100644 index 35688d3f..00000000 --- a/kafka/producer/sender.py +++ /dev/null @@ -1,517 +0,0 @@ -from __future__ import absolute_import, division - -import collections -import copy -import logging -import threading -import time - -from kafka.vendor import six - -from kafka import errors as Errors -from kafka.metrics.measurable import AnonMeasurable -from kafka.metrics.stats import Avg, Max, Rate -from kafka.protocol.produce import ProduceRequest -from kafka.structs import TopicPartition -from kafka.version import __version__ - -log = logging.getLogger(__name__) - - -class Sender(threading.Thread): - """ - The background thread that handles the sending of produce requests to the - Kafka cluster. This thread makes metadata requests to renew its view of the - cluster and then sends produce requests to the appropriate nodes. - """ - DEFAULT_CONFIG = { - 'max_request_size': 1048576, - 'acks': 1, - 'retries': 0, - 'request_timeout_ms': 30000, - 'guarantee_message_order': False, - 'client_id': 'kafka-python-' + __version__, - 'api_version': (0, 8, 0), - } - - def __init__(self, client, metadata, accumulator, metrics, **configs): - super(Sender, self).__init__() - self.config = copy.copy(self.DEFAULT_CONFIG) - for key in self.config: - if key in configs: - self.config[key] = configs.pop(key) - - self.name = self.config['client_id'] + '-network-thread' - self._client = client - self._accumulator = accumulator - self._metadata = client.cluster - self._running = True - self._force_close = False - self._topics_to_add = set() - self._sensors = SenderMetrics(metrics, self._client, self._metadata) - - def run(self): - """The main run loop for the sender thread.""" - log.debug("Starting Kafka producer I/O thread.") - - # main loop, runs until close is called - while self._running: - try: - self.run_once() - except Exception: - log.exception("Uncaught error in kafka producer I/O thread") - - log.debug("Beginning shutdown of Kafka producer I/O thread, sending" - " remaining records.") - - # okay we stopped accepting requests but there may still be - # requests in the accumulator or waiting for acknowledgment, - # wait until these are completed. 
- while (not self._force_close - and (self._accumulator.has_unsent() - or self._client.in_flight_request_count() > 0)): - try: - self.run_once() - except Exception: - log.exception("Uncaught error in kafka producer I/O thread") - - if self._force_close: - # We need to fail all the incomplete batches and wake up the - # threads waiting on the futures. - self._accumulator.abort_incomplete_batches() - - try: - self._client.close() - except Exception: - log.exception("Failed to close network client") - - log.debug("Shutdown of Kafka producer I/O thread has completed.") - - def run_once(self): - """Run a single iteration of sending.""" - while self._topics_to_add: - self._client.add_topic(self._topics_to_add.pop()) - - # get the list of partitions with data ready to send - result = self._accumulator.ready(self._metadata) - ready_nodes, next_ready_check_delay, unknown_leaders_exist = result - - # if there are any partitions whose leaders are not known yet, force - # metadata update - if unknown_leaders_exist: - log.debug('Unknown leaders exist, requesting metadata update') - self._metadata.request_update() - - # remove any nodes we aren't ready to send to - not_ready_timeout = float('inf') - for node in list(ready_nodes): - if not self._client.is_ready(node): - log.debug('Node %s not ready; delaying produce of accumulated batch', node) - self._client.maybe_connect(node, wakeup=False) - ready_nodes.remove(node) - not_ready_timeout = min(not_ready_timeout, - self._client.connection_delay(node)) - - # create produce requests - batches_by_node = self._accumulator.drain( - self._metadata, ready_nodes, self.config['max_request_size']) - - if self.config['guarantee_message_order']: - # Mute all the partitions drained - for batch_list in six.itervalues(batches_by_node): - for batch in batch_list: - self._accumulator.muted.add(batch.topic_partition) - - expired_batches = self._accumulator.abort_expired_batches( - self.config['request_timeout_ms'], self._metadata) - for expired_batch in expired_batches: - self._sensors.record_errors(expired_batch.topic_partition.topic, expired_batch.record_count) - - self._sensors.update_produce_request_metrics(batches_by_node) - requests = self._create_produce_requests(batches_by_node) - # If we have any nodes that are ready to send + have sendable data, - # poll with 0 timeout so this can immediately loop and try sending more - # data. Otherwise, the timeout is determined by nodes that have - # partitions with data that isn't yet sendable (e.g. lingering, backing - # off). Note that this specifically does not include nodes with - # sendable data that aren't ready to send since they would cause busy - # looping. 
- poll_timeout_ms = min(next_ready_check_delay * 1000, not_ready_timeout) - if ready_nodes: - log.debug("Nodes with data ready to send: %s", ready_nodes) # trace - log.debug("Created %d produce requests: %s", len(requests), requests) # trace - poll_timeout_ms = 0 - - for node_id, request in six.iteritems(requests): - batches = batches_by_node[node_id] - log.debug('Sending Produce Request: %r', request) - (self._client.send(node_id, request, wakeup=False) - .add_callback( - self._handle_produce_response, node_id, time.time(), batches) - .add_errback( - self._failed_produce, batches, node_id)) - - # if some partitions are already ready to be sent, the select time - # would be 0; otherwise if some partition already has some data - # accumulated but not ready yet, the select time will be the time - # difference between now and its linger expiry time; otherwise the - # select time will be the time difference between now and the - # metadata expiry time - self._client.poll(timeout_ms=poll_timeout_ms) - - def initiate_close(self): - """Start closing the sender (won't complete until all data is sent).""" - self._running = False - self._accumulator.close() - self.wakeup() - - def force_close(self): - """Closes the sender without sending out any pending messages.""" - self._force_close = True - self.initiate_close() - - def add_topic(self, topic): - # This is generally called from a separate thread - # so this needs to be a thread-safe operation - # we assume that checking set membership across threads - # is ok where self._client._topics should never - # remove topics for a producer instance, only add them. - if topic not in self._client._topics: - self._topics_to_add.add(topic) - self.wakeup() - - def _failed_produce(self, batches, node_id, error): - log.debug("Error sending produce request to node %d: %s", node_id, error) # trace - for batch in batches: - self._complete_batch(batch, error, -1, None) - - def _handle_produce_response(self, node_id, send_time, batches, response): - """Handle a produce response.""" - # if we have a response, parse it - log.debug('Parsing produce response: %r', response) - if response: - batches_by_partition = dict([(batch.topic_partition, batch) - for batch in batches]) - - for topic, partitions in response.topics: - for partition_info in partitions: - global_error = None - log_start_offset = None - if response.API_VERSION < 2: - partition, error_code, offset = partition_info - ts = None - elif 2 <= response.API_VERSION <= 4: - partition, error_code, offset, ts = partition_info - elif 5 <= response.API_VERSION <= 7: - partition, error_code, offset, ts, log_start_offset = partition_info - else: - # the ignored parameter is record_error of type list[(batch_index: int, error_message: str)] - partition, error_code, offset, ts, log_start_offset, _, global_error = partition_info - tp = TopicPartition(topic, partition) - error = Errors.for_code(error_code) - batch = batches_by_partition[tp] - self._complete_batch(batch, error, offset, ts, log_start_offset, global_error) - - if response.API_VERSION > 0: - self._sensors.record_throttle_time(response.throttle_time_ms, node=node_id) - - else: - # this is the acks = 0 case, just complete all requests - for batch in batches: - self._complete_batch(batch, None, -1, None) - - def _complete_batch(self, batch, error, base_offset, timestamp_ms=None, log_start_offset=None, global_error=None): - """Complete or retry the given batch of records. 
- - Arguments: - batch (RecordBatch): The record batch - error (Exception): The error (or None if none) - base_offset (int): The base offset assigned to the records if successful - timestamp_ms (int, optional): The timestamp returned by the broker for this batch - log_start_offset (int): The start offset of the log at the time this produce response was created - global_error (str): The summarising error message - """ - # Standardize no-error to None - if error is Errors.NoError: - error = None - - if error is not None and self._can_retry(batch, error): - # retry - log.warning("Got error produce response on topic-partition %s," - " retrying (%d attempts left). Error: %s", - batch.topic_partition, - self.config['retries'] - batch.attempts - 1, - global_error or error) - self._accumulator.reenqueue(batch) - self._sensors.record_retries(batch.topic_partition.topic, batch.record_count) - else: - if error is Errors.TopicAuthorizationFailedError: - error = error(batch.topic_partition.topic) - - # tell the user the result of their request - batch.done(base_offset, timestamp_ms, error, log_start_offset, global_error) - self._accumulator.deallocate(batch) - if error is not None: - self._sensors.record_errors(batch.topic_partition.topic, batch.record_count) - - if getattr(error, 'invalid_metadata', False): - self._metadata.request_update() - - # Unmute the completed partition. - if self.config['guarantee_message_order']: - self._accumulator.muted.remove(batch.topic_partition) - - def _can_retry(self, batch, error): - """ - We can retry a send if the error is transient and the number of - attempts taken is fewer than the maximum allowed - """ - return (batch.attempts < self.config['retries'] - and getattr(error, 'retriable', False)) - - def _create_produce_requests(self, collated): - """ - Transfer the record batches into a list of produce requests on a - per-node basis. - - Arguments: - collated: {node_id: [RecordBatch]} - - Returns: - dict: {node_id: ProduceRequest} (version depends on api_version) - """ - requests = {} - for node_id, batches in six.iteritems(collated): - requests[node_id] = self._produce_request( - node_id, self.config['acks'], - self.config['request_timeout_ms'], batches) - return requests - - def _produce_request(self, node_id, acks, timeout, batches): - """Create a produce request from the given record batches. 
- - Returns: - ProduceRequest (version depends on api_version) - """ - produce_records_by_partition = collections.defaultdict(dict) - for batch in batches: - topic = batch.topic_partition.topic - partition = batch.topic_partition.partition - - buf = batch.records.buffer() - produce_records_by_partition[topic][partition] = buf - - kwargs = {} - if self.config['api_version'] >= (2, 1): - version = 7 - elif self.config['api_version'] >= (2, 0): - version = 6 - elif self.config['api_version'] >= (1, 1): - version = 5 - elif self.config['api_version'] >= (1, 0): - version = 4 - elif self.config['api_version'] >= (0, 11): - version = 3 - kwargs = dict(transactional_id=None) - elif self.config['api_version'] >= (0, 10): - version = 2 - elif self.config['api_version'] == (0, 9): - version = 1 - else: - version = 0 - return ProduceRequest[version]( - required_acks=acks, - timeout=timeout, - topics=[(topic, list(partition_info.items())) - for topic, partition_info - in six.iteritems(produce_records_by_partition)], - **kwargs - ) - - def wakeup(self): - """Wake up the selector associated with this send thread.""" - self._client.wakeup() - - def bootstrap_connected(self): - return self._client.bootstrap_connected() - - -class SenderMetrics(object): - - def __init__(self, metrics, client, metadata): - self.metrics = metrics - self._client = client - self._metadata = metadata - - sensor_name = 'batch-size' - self.batch_size_sensor = self.metrics.sensor(sensor_name) - self.add_metric('batch-size-avg', Avg(), - sensor_name=sensor_name, - description='The average number of bytes sent per partition per-request.') - self.add_metric('batch-size-max', Max(), - sensor_name=sensor_name, - description='The max number of bytes sent per partition per-request.') - - sensor_name = 'compression-rate' - self.compression_rate_sensor = self.metrics.sensor(sensor_name) - self.add_metric('compression-rate-avg', Avg(), - sensor_name=sensor_name, - description='The average compression rate of record batches.') - - sensor_name = 'queue-time' - self.queue_time_sensor = self.metrics.sensor(sensor_name) - self.add_metric('record-queue-time-avg', Avg(), - sensor_name=sensor_name, - description='The average time in ms record batches spent in the record accumulator.') - self.add_metric('record-queue-time-max', Max(), - sensor_name=sensor_name, - description='The maximum time in ms record batches spent in the record accumulator.') - - sensor_name = 'produce-throttle-time' - self.produce_throttle_time_sensor = self.metrics.sensor(sensor_name) - self.add_metric('produce-throttle-time-avg', Avg(), - sensor_name=sensor_name, - description='The average throttle time in ms') - self.add_metric('produce-throttle-time-max', Max(), - sensor_name=sensor_name, - description='The maximum throttle time in ms') - - sensor_name = 'records-per-request' - self.records_per_request_sensor = self.metrics.sensor(sensor_name) - self.add_metric('record-send-rate', Rate(), - sensor_name=sensor_name, - description='The average number of records sent per second.') - self.add_metric('records-per-request-avg', Avg(), - sensor_name=sensor_name, - description='The average number of records per request.') - - sensor_name = 'bytes' - self.byte_rate_sensor = self.metrics.sensor(sensor_name) - self.add_metric('byte-rate', Rate(), - sensor_name=sensor_name, - description='The average number of bytes sent per second.') - - sensor_name = 'record-retries' - self.retry_sensor = self.metrics.sensor(sensor_name) - self.add_metric('record-retry-rate', Rate(), - 
sensor_name=sensor_name, - description='The average per-second number of retried record sends') - - sensor_name = 'errors' - self.error_sensor = self.metrics.sensor(sensor_name) - self.add_metric('record-error-rate', Rate(), - sensor_name=sensor_name, - description='The average per-second number of record sends that resulted in errors') - - sensor_name = 'record-size-max' - self.max_record_size_sensor = self.metrics.sensor(sensor_name) - self.add_metric('record-size-max', Max(), - sensor_name=sensor_name, - description='The maximum record size across all batches') - self.add_metric('record-size-avg', Avg(), - sensor_name=sensor_name, - description='The average maximum record size per batch') - - self.add_metric('requests-in-flight', - AnonMeasurable(lambda *_: self._client.in_flight_request_count()), - description='The current number of in-flight requests awaiting a response.') - - self.add_metric('metadata-age', - AnonMeasurable(lambda _, now: (now - self._metadata._last_successful_refresh_ms) / 1000), - description='The age in seconds of the current producer metadata being used.') - - def add_metric(self, metric_name, measurable, group_name='producer-metrics', - description=None, tags=None, - sensor_name=None): - m = self.metrics - metric = m.metric_name(metric_name, group_name, description, tags) - if sensor_name: - sensor = m.sensor(sensor_name) - sensor.add(metric, measurable) - else: - m.add_metric(metric, measurable) - - def maybe_register_topic_metrics(self, topic): - - def sensor_name(name): - return 'topic.{0}.{1}'.format(topic, name) - - # if one sensor of the metrics has been registered for the topic, - # then all other sensors should have been registered; and vice versa - if not self.metrics.get_sensor(sensor_name('records-per-batch')): - - self.add_metric('record-send-rate', Rate(), - sensor_name=sensor_name('records-per-batch'), - group_name='producer-topic-metrics.' + topic, - description= 'Records sent per second for topic ' + topic) - - self.add_metric('byte-rate', Rate(), - sensor_name=sensor_name('bytes'), - group_name='producer-topic-metrics.' + topic, - description='Bytes per second for topic ' + topic) - - self.add_metric('compression-rate', Avg(), - sensor_name=sensor_name('compression-rate'), - group_name='producer-topic-metrics.' + topic, - description='Average Compression ratio for topic ' + topic) - - self.add_metric('record-retry-rate', Rate(), - sensor_name=sensor_name('record-retries'), - group_name='producer-topic-metrics.' + topic, - description='Record retries per second for topic ' + topic) - - self.add_metric('record-error-rate', Rate(), - sensor_name=sensor_name('record-errors'), - group_name='producer-topic-metrics.' + topic, - description='Record errors per second for topic ' + topic) - - def update_produce_request_metrics(self, batches_map): - for node_batch in batches_map.values(): - records = 0 - total_bytes = 0 - for batch in node_batch: - # register all per-topic metrics at once - topic = batch.topic_partition.topic - self.maybe_register_topic_metrics(topic) - - # per-topic record send rate - topic_records_count = self.metrics.get_sensor( - 'topic.' + topic + '.records-per-batch') - topic_records_count.record(batch.record_count) - - # per-topic bytes send rate - topic_byte_rate = self.metrics.get_sensor( - 'topic.' + topic + '.bytes') - topic_byte_rate.record(batch.records.size_in_bytes()) - - # per-topic compression rate - topic_compression_rate = self.metrics.get_sensor( - 'topic.' 
+ topic + '.compression-rate') - topic_compression_rate.record(batch.records.compression_rate()) - - # global metrics - self.batch_size_sensor.record(batch.records.size_in_bytes()) - if batch.drained: - self.queue_time_sensor.record(batch.drained - batch.created) - self.compression_rate_sensor.record(batch.records.compression_rate()) - self.max_record_size_sensor.record(batch.max_record_size) - records += batch.record_count - total_bytes += batch.records.size_in_bytes() - - self.records_per_request_sensor.record(records) - self.byte_rate_sensor.record(total_bytes) - - def record_retries(self, topic, count): - self.retry_sensor.record(count) - sensor = self.metrics.get_sensor('topic.' + topic + '.record-retries') - if sensor: - sensor.record(count) - - def record_errors(self, topic, count): - self.error_sensor.record(count) - sensor = self.metrics.get_sensor('topic.' + topic + '.record-errors') - if sensor: - sensor.record(count) - - def record_throttle_time(self, throttle_time_ms, node=None): - self.produce_throttle_time_sensor.record(throttle_time_ms) diff --git a/kafka/record/README b/kafka/record/README deleted file mode 100644 index e4454554..00000000 --- a/kafka/record/README +++ /dev/null @@ -1,8 +0,0 @@ -Module structured mostly based on -kafka/clients/src/main/java/org/apache/kafka/common/record/ module of Java -Client. - -See abc.py for abstract declarations. `ABCRecords` is used as a facade to hide -version differences. `ABCRecordBatch` subclasses will implement actual parsers -for different versions (v0/v1 as LegacyBatch and v2 as DefaultBatch. Names -taken from Java). diff --git a/kafka/record/__init__.py b/kafka/record/__init__.py deleted file mode 100644 index 93936df4..00000000 --- a/kafka/record/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from kafka.record.memory_records import MemoryRecords, MemoryRecordsBuilder - -__all__ = ["MemoryRecords", "MemoryRecordsBuilder"] diff --git a/kafka/record/_crc32c.py b/kafka/record/_crc32c.py deleted file mode 100644 index 9b51ad8a..00000000 --- a/kafka/record/_crc32c.py +++ /dev/null @@ -1,145 +0,0 @@ -#!/usr/bin/env python -# -# Taken from https://cloud.google.com/appengine/docs/standard/python/refdocs/\ -# modules/google/appengine/api/files/crc32c?hl=ru -# -# Copyright 2007 Google Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -"""Implementation of CRC-32C checksumming as in rfc3720 section B.4. -See https://en.wikipedia.org/wiki/Cyclic_redundancy_check for details on CRC-32C -This code is a manual python translation of c code generated by -pycrc 0.7.1 (https://pycrc.org/). 
Command line used: -'./pycrc.py --model=crc-32c --generate c --algorithm=table-driven' -""" - -import array - -CRC_TABLE = ( - 0x00000000, 0xf26b8303, 0xe13b70f7, 0x1350f3f4, - 0xc79a971f, 0x35f1141c, 0x26a1e7e8, 0xd4ca64eb, - 0x8ad958cf, 0x78b2dbcc, 0x6be22838, 0x9989ab3b, - 0x4d43cfd0, 0xbf284cd3, 0xac78bf27, 0x5e133c24, - 0x105ec76f, 0xe235446c, 0xf165b798, 0x030e349b, - 0xd7c45070, 0x25afd373, 0x36ff2087, 0xc494a384, - 0x9a879fa0, 0x68ec1ca3, 0x7bbcef57, 0x89d76c54, - 0x5d1d08bf, 0xaf768bbc, 0xbc267848, 0x4e4dfb4b, - 0x20bd8ede, 0xd2d60ddd, 0xc186fe29, 0x33ed7d2a, - 0xe72719c1, 0x154c9ac2, 0x061c6936, 0xf477ea35, - 0xaa64d611, 0x580f5512, 0x4b5fa6e6, 0xb93425e5, - 0x6dfe410e, 0x9f95c20d, 0x8cc531f9, 0x7eaeb2fa, - 0x30e349b1, 0xc288cab2, 0xd1d83946, 0x23b3ba45, - 0xf779deae, 0x05125dad, 0x1642ae59, 0xe4292d5a, - 0xba3a117e, 0x4851927d, 0x5b016189, 0xa96ae28a, - 0x7da08661, 0x8fcb0562, 0x9c9bf696, 0x6ef07595, - 0x417b1dbc, 0xb3109ebf, 0xa0406d4b, 0x522bee48, - 0x86e18aa3, 0x748a09a0, 0x67dafa54, 0x95b17957, - 0xcba24573, 0x39c9c670, 0x2a993584, 0xd8f2b687, - 0x0c38d26c, 0xfe53516f, 0xed03a29b, 0x1f682198, - 0x5125dad3, 0xa34e59d0, 0xb01eaa24, 0x42752927, - 0x96bf4dcc, 0x64d4cecf, 0x77843d3b, 0x85efbe38, - 0xdbfc821c, 0x2997011f, 0x3ac7f2eb, 0xc8ac71e8, - 0x1c661503, 0xee0d9600, 0xfd5d65f4, 0x0f36e6f7, - 0x61c69362, 0x93ad1061, 0x80fde395, 0x72966096, - 0xa65c047d, 0x5437877e, 0x4767748a, 0xb50cf789, - 0xeb1fcbad, 0x197448ae, 0x0a24bb5a, 0xf84f3859, - 0x2c855cb2, 0xdeeedfb1, 0xcdbe2c45, 0x3fd5af46, - 0x7198540d, 0x83f3d70e, 0x90a324fa, 0x62c8a7f9, - 0xb602c312, 0x44694011, 0x5739b3e5, 0xa55230e6, - 0xfb410cc2, 0x092a8fc1, 0x1a7a7c35, 0xe811ff36, - 0x3cdb9bdd, 0xceb018de, 0xdde0eb2a, 0x2f8b6829, - 0x82f63b78, 0x709db87b, 0x63cd4b8f, 0x91a6c88c, - 0x456cac67, 0xb7072f64, 0xa457dc90, 0x563c5f93, - 0x082f63b7, 0xfa44e0b4, 0xe9141340, 0x1b7f9043, - 0xcfb5f4a8, 0x3dde77ab, 0x2e8e845f, 0xdce5075c, - 0x92a8fc17, 0x60c37f14, 0x73938ce0, 0x81f80fe3, - 0x55326b08, 0xa759e80b, 0xb4091bff, 0x466298fc, - 0x1871a4d8, 0xea1a27db, 0xf94ad42f, 0x0b21572c, - 0xdfeb33c7, 0x2d80b0c4, 0x3ed04330, 0xccbbc033, - 0xa24bb5a6, 0x502036a5, 0x4370c551, 0xb11b4652, - 0x65d122b9, 0x97baa1ba, 0x84ea524e, 0x7681d14d, - 0x2892ed69, 0xdaf96e6a, 0xc9a99d9e, 0x3bc21e9d, - 0xef087a76, 0x1d63f975, 0x0e330a81, 0xfc588982, - 0xb21572c9, 0x407ef1ca, 0x532e023e, 0xa145813d, - 0x758fe5d6, 0x87e466d5, 0x94b49521, 0x66df1622, - 0x38cc2a06, 0xcaa7a905, 0xd9f75af1, 0x2b9cd9f2, - 0xff56bd19, 0x0d3d3e1a, 0x1e6dcdee, 0xec064eed, - 0xc38d26c4, 0x31e6a5c7, 0x22b65633, 0xd0ddd530, - 0x0417b1db, 0xf67c32d8, 0xe52cc12c, 0x1747422f, - 0x49547e0b, 0xbb3ffd08, 0xa86f0efc, 0x5a048dff, - 0x8ecee914, 0x7ca56a17, 0x6ff599e3, 0x9d9e1ae0, - 0xd3d3e1ab, 0x21b862a8, 0x32e8915c, 0xc083125f, - 0x144976b4, 0xe622f5b7, 0xf5720643, 0x07198540, - 0x590ab964, 0xab613a67, 0xb831c993, 0x4a5a4a90, - 0x9e902e7b, 0x6cfbad78, 0x7fab5e8c, 0x8dc0dd8f, - 0xe330a81a, 0x115b2b19, 0x020bd8ed, 0xf0605bee, - 0x24aa3f05, 0xd6c1bc06, 0xc5914ff2, 0x37faccf1, - 0x69e9f0d5, 0x9b8273d6, 0x88d28022, 0x7ab90321, - 0xae7367ca, 0x5c18e4c9, 0x4f48173d, 0xbd23943e, - 0xf36e6f75, 0x0105ec76, 0x12551f82, 0xe03e9c81, - 0x34f4f86a, 0xc69f7b69, 0xd5cf889d, 0x27a40b9e, - 0x79b737ba, 0x8bdcb4b9, 0x988c474d, 0x6ae7c44e, - 0xbe2da0a5, 0x4c4623a6, 0x5f16d052, 0xad7d5351, -) - -CRC_INIT = 0 -_MASK = 0xFFFFFFFF - - -def crc_update(crc, data): - """Update CRC-32C checksum with data. - Args: - crc: 32-bit checksum to update as long. - data: byte array, string or iterable over bytes. 
- Returns: - 32-bit updated CRC-32C as long. - """ - if not isinstance(data, array.array) or data.itemsize != 1: - buf = array.array("B", data) - else: - buf = data - crc = crc ^ _MASK - for b in buf: - table_index = (crc ^ b) & 0xff - crc = (CRC_TABLE[table_index] ^ (crc >> 8)) & _MASK - return crc ^ _MASK - - -def crc_finalize(crc): - """Finalize CRC-32C checksum. - This function should be called as last step of crc calculation. - Args: - crc: 32-bit checksum as long. - Returns: - finalized 32-bit checksum as long - """ - return crc & _MASK - - -def crc(data): - """Compute CRC-32C checksum of the data. - Args: - data: byte array, string or iterable over bytes. - Returns: - 32-bit CRC-32C checksum of data as long. - """ - return crc_finalize(crc_update(CRC_INIT, data)) - - -if __name__ == "__main__": - import sys - # TODO remove the pylint disable once pylint fixes - # https://github.com/PyCQA/pylint/issues/2571 - data = sys.stdin.read() # pylint: disable=assignment-from-no-return - print(hex(crc(data))) diff --git a/kafka/record/abc.py b/kafka/record/abc.py deleted file mode 100644 index 8509e23e..00000000 --- a/kafka/record/abc.py +++ /dev/null @@ -1,124 +0,0 @@ -from __future__ import absolute_import -import abc - - -class ABCRecord(object): - __metaclass__ = abc.ABCMeta - __slots__ = () - - @abc.abstractproperty - def offset(self): - """ Absolute offset of record - """ - - @abc.abstractproperty - def timestamp(self): - """ Epoch milliseconds - """ - - @abc.abstractproperty - def timestamp_type(self): - """ CREATE_TIME(0) or APPEND_TIME(1) - """ - - @abc.abstractproperty - def key(self): - """ Bytes key or None - """ - - @abc.abstractproperty - def value(self): - """ Bytes value or None - """ - - @abc.abstractproperty - def checksum(self): - """ Prior to v2 format CRC was contained in every message. This will - be the checksum for v0 and v1 and None for v2 and above. - """ - - @abc.abstractproperty - def headers(self): - """ If supported by version list of key-value tuples, or empty list if - not supported by format. - """ - - -class ABCRecordBatchBuilder(object): - __metaclass__ = abc.ABCMeta - __slots__ = () - - @abc.abstractmethod - def append(self, offset, timestamp, key, value, headers=None): - """ Writes record to internal buffer. - - Arguments: - offset (int): Relative offset of record, starting from 0 - timestamp (int or None): Timestamp in milliseconds since beginning - of the epoch (midnight Jan 1, 1970 (UTC)). If omitted, will be - set to current time. - key (bytes or None): Key of the record - value (bytes or None): Value of the record - headers (List[Tuple[str, bytes]]): Headers of the record. Header - keys can not be ``None``. - - Returns: - (bytes, int): Checksum of the written record (or None for v2 and - above) and size of the written record. - """ - - @abc.abstractmethod - def size_in_bytes(self, offset, timestamp, key, value, headers): - """ Return the expected size change on buffer (uncompressed) if we add - this message. This will account for varint size changes and give a - reliable size. - """ - - @abc.abstractmethod - def build(self): - """ Close for append, compress if needed, write size and header and - return a ready to send buffer object. - - Return: - bytearray: finished batch, ready to send. - """ - - -class ABCRecordBatch(object): - """ For v2 encapsulates a RecordBatch, for v0/v1 a single (maybe - compressed) message. 
- """ - __metaclass__ = abc.ABCMeta - __slots__ = () - - @abc.abstractmethod - def __iter__(self): - """ Return iterator over records (ABCRecord instances). Will decompress - if needed. - """ - - -class ABCRecords(object): - __metaclass__ = abc.ABCMeta - __slots__ = () - - @abc.abstractmethod - def __init__(self, buffer): - """ Initialize with bytes-like object conforming to the buffer - interface (ie. bytes, bytearray, memoryview etc.). - """ - - @abc.abstractmethod - def size_in_bytes(self): - """ Returns the size of inner buffer. - """ - - @abc.abstractmethod - def next_batch(self): - """ Return next batch of records (ABCRecordBatch instances). - """ - - @abc.abstractmethod - def has_next(self): - """ True if there are more batches to read, False otherwise. - """ diff --git a/kafka/record/default_records.py b/kafka/record/default_records.py deleted file mode 100644 index a098c42a..00000000 --- a/kafka/record/default_records.py +++ /dev/null @@ -1,630 +0,0 @@ -# See: -# https://github.com/apache/kafka/blob/trunk/clients/src/main/java/org/\ -# apache/kafka/common/record/DefaultRecordBatch.java -# https://github.com/apache/kafka/blob/trunk/clients/src/main/java/org/\ -# apache/kafka/common/record/DefaultRecord.java - -# RecordBatch and Record implementation for magic 2 and above. -# The schema is given below: - -# RecordBatch => -# BaseOffset => Int64 -# Length => Int32 -# PartitionLeaderEpoch => Int32 -# Magic => Int8 -# CRC => Uint32 -# Attributes => Int16 -# LastOffsetDelta => Int32 // also serves as LastSequenceDelta -# FirstTimestamp => Int64 -# MaxTimestamp => Int64 -# ProducerId => Int64 -# ProducerEpoch => Int16 -# BaseSequence => Int32 -# Records => [Record] - -# Record => -# Length => Varint -# Attributes => Int8 -# TimestampDelta => Varlong -# OffsetDelta => Varint -# Key => Bytes -# Value => Bytes -# Headers => [HeaderKey HeaderValue] -# HeaderKey => String -# HeaderValue => Bytes - -# Note that when compression is enabled (see attributes below), the compressed -# record data is serialized directly following the count of the number of -# records. (ie Records => [Record], but without length bytes) - -# The CRC covers the data from the attributes to the end of the batch (i.e. all -# the bytes that follow the CRC). It is located after the magic byte, which -# means that clients must parse the magic byte before deciding how to interpret -# the bytes between the batch length and the magic byte. The partition leader -# epoch field is not included in the CRC computation to avoid the need to -# recompute the CRC when this field is assigned for every batch that is -# received by the broker. The CRC-32C (Castagnoli) polynomial is used for the -# computation. 
- -# The current RecordBatch attributes are given below: -# -# * Unused (6-15) -# * Control (5) -# * Transactional (4) -# * Timestamp Type (3) -# * Compression Type (0-2) - -import struct -import time -from kafka.record.abc import ABCRecord, ABCRecordBatch, ABCRecordBatchBuilder -from kafka.record.util import ( - decode_varint, encode_varint, calc_crc32c, size_of_varint -) -from kafka.errors import CorruptRecordException, UnsupportedCodecError -from kafka.codec import ( - gzip_encode, snappy_encode, lz4_encode, zstd_encode, - gzip_decode, snappy_decode, lz4_decode, zstd_decode -) -import kafka.codec as codecs - - -class DefaultRecordBase(object): - - __slots__ = () - - HEADER_STRUCT = struct.Struct( - ">q" # BaseOffset => Int64 - "i" # Length => Int32 - "i" # PartitionLeaderEpoch => Int32 - "b" # Magic => Int8 - "I" # CRC => Uint32 - "h" # Attributes => Int16 - "i" # LastOffsetDelta => Int32 // also serves as LastSequenceDelta - "q" # FirstTimestamp => Int64 - "q" # MaxTimestamp => Int64 - "q" # ProducerId => Int64 - "h" # ProducerEpoch => Int16 - "i" # BaseSequence => Int32 - "i" # Records count => Int32 - ) - # Byte offset in HEADER_STRUCT of attributes field. Used to calculate CRC - ATTRIBUTES_OFFSET = struct.calcsize(">qiibI") - CRC_OFFSET = struct.calcsize(">qiib") - AFTER_LEN_OFFSET = struct.calcsize(">qi") - - CODEC_MASK = 0x07 - CODEC_NONE = 0x00 - CODEC_GZIP = 0x01 - CODEC_SNAPPY = 0x02 - CODEC_LZ4 = 0x03 - CODEC_ZSTD = 0x04 - TIMESTAMP_TYPE_MASK = 0x08 - TRANSACTIONAL_MASK = 0x10 - CONTROL_MASK = 0x20 - - LOG_APPEND_TIME = 1 - CREATE_TIME = 0 - - def _assert_has_codec(self, compression_type): - if compression_type == self.CODEC_GZIP: - checker, name = codecs.has_gzip, "gzip" - elif compression_type == self.CODEC_SNAPPY: - checker, name = codecs.has_snappy, "snappy" - elif compression_type == self.CODEC_LZ4: - checker, name = codecs.has_lz4, "lz4" - elif compression_type == self.CODEC_ZSTD: - checker, name = codecs.has_zstd, "zstd" - if not checker(): - raise UnsupportedCodecError( - "Libraries for {} compression codec not found".format(name)) - - -class DefaultRecordBatch(DefaultRecordBase, ABCRecordBatch): - - __slots__ = ("_buffer", "_header_data", "_pos", "_num_records", - "_next_record_index", "_decompressed") - - def __init__(self, buffer): - self._buffer = bytearray(buffer) - self._header_data = self.HEADER_STRUCT.unpack_from(self._buffer) - self._pos = self.HEADER_STRUCT.size - self._num_records = self._header_data[12] - self._next_record_index = 0 - self._decompressed = False - - @property - def base_offset(self): - return self._header_data[0] - - @property - def magic(self): - return self._header_data[3] - - @property - def crc(self): - return self._header_data[4] - - @property - def attributes(self): - return self._header_data[5] - - @property - def last_offset_delta(self): - return self._header_data[6] - - @property - def compression_type(self): - return self.attributes & self.CODEC_MASK - - @property - def timestamp_type(self): - return int(bool(self.attributes & self.TIMESTAMP_TYPE_MASK)) - - @property - def is_transactional(self): - return bool(self.attributes & self.TRANSACTIONAL_MASK) - - @property - def is_control_batch(self): - return bool(self.attributes & self.CONTROL_MASK) - - @property - def first_timestamp(self): - return self._header_data[7] - - @property - def max_timestamp(self): - return self._header_data[8] - - def _maybe_uncompress(self): - if not self._decompressed: - compression_type = self.compression_type - if compression_type != self.CODEC_NONE: 
- self._assert_has_codec(compression_type) - data = memoryview(self._buffer)[self._pos:] - if compression_type == self.CODEC_GZIP: - uncompressed = gzip_decode(data) - if compression_type == self.CODEC_SNAPPY: - uncompressed = snappy_decode(data.tobytes()) - if compression_type == self.CODEC_LZ4: - uncompressed = lz4_decode(data.tobytes()) - if compression_type == self.CODEC_ZSTD: - uncompressed = zstd_decode(data.tobytes()) - self._buffer = bytearray(uncompressed) - self._pos = 0 - self._decompressed = True - - def _read_msg( - self, - decode_varint=decode_varint): - # Record => - # Length => Varint - # Attributes => Int8 - # TimestampDelta => Varlong - # OffsetDelta => Varint - # Key => Bytes - # Value => Bytes - # Headers => [HeaderKey HeaderValue] - # HeaderKey => String - # HeaderValue => Bytes - - buffer = self._buffer - pos = self._pos - length, pos = decode_varint(buffer, pos) - start_pos = pos - _, pos = decode_varint(buffer, pos) # attrs can be skipped for now - - ts_delta, pos = decode_varint(buffer, pos) - if self.timestamp_type == self.LOG_APPEND_TIME: - timestamp = self.max_timestamp - else: - timestamp = self.first_timestamp + ts_delta - - offset_delta, pos = decode_varint(buffer, pos) - offset = self.base_offset + offset_delta - - key_len, pos = decode_varint(buffer, pos) - if key_len >= 0: - key = bytes(buffer[pos: pos + key_len]) - pos += key_len - else: - key = None - - value_len, pos = decode_varint(buffer, pos) - if value_len >= 0: - value = bytes(buffer[pos: pos + value_len]) - pos += value_len - else: - value = None - - header_count, pos = decode_varint(buffer, pos) - if header_count < 0: - raise CorruptRecordException("Found invalid number of record " - "headers {}".format(header_count)) - headers = [] - while header_count: - # Header key is of type String, that can't be None - h_key_len, pos = decode_varint(buffer, pos) - if h_key_len < 0: - raise CorruptRecordException( - "Invalid negative header key size {}".format(h_key_len)) - h_key = buffer[pos: pos + h_key_len].decode("utf-8") - pos += h_key_len - - # Value is of type NULLABLE_BYTES, so it can be None - h_value_len, pos = decode_varint(buffer, pos) - if h_value_len >= 0: - h_value = bytes(buffer[pos: pos + h_value_len]) - pos += h_value_len - else: - h_value = None - - headers.append((h_key, h_value)) - header_count -= 1 - - # validate whether we have read all header bytes in the current record - if pos - start_pos != length: - raise CorruptRecordException( - "Invalid record size: expected to read {} bytes in record " - "payload, but instead read {}".format(length, pos - start_pos)) - self._pos = pos - - return DefaultRecord( - offset, timestamp, self.timestamp_type, key, value, headers) - - def __iter__(self): - self._maybe_uncompress() - return self - - def __next__(self): - if self._next_record_index >= self._num_records: - if self._pos != len(self._buffer): - raise CorruptRecordException( - "{} unconsumed bytes after all records consumed".format( - len(self._buffer) - self._pos)) - raise StopIteration - try: - msg = self._read_msg() - except (ValueError, IndexError) as err: - raise CorruptRecordException( - "Found invalid record structure: {!r}".format(err)) - else: - self._next_record_index += 1 - return msg - - next = __next__ - - def validate_crc(self): - assert self._decompressed is False, \ - "Validate should be called before iteration" - - crc = self.crc - data_view = memoryview(self._buffer)[self.ATTRIBUTES_OFFSET:] - verify_crc = calc_crc32c(data_view.tobytes()) - return crc == verify_crc - - 
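# A minimal standalone sketch (not taken from this patch) of the varint
# decoding that _read_msg() above relies on via kafka.record.util.decode_varint,
# which is defined elsewhere in this patch. It only illustrates the wire
# format: 7 payload bits per byte with the high bit as a continuation flag,
# followed by zigzag decoding so that small negative values (e.g. -1 for a
# null key or value) stay one byte long.

def decode_varint_sketch(buffer, pos=0):
    result = 0
    shift = 0
    while True:
        byte = buffer[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not (byte & 0x80):
            break
        shift += 7
    # zigzag: 0 -> 0, 1 -> -1, 2 -> 1, 3 -> -2, ...
    return (result >> 1) ^ -(result & 1), pos

# decode_varint_sketch(b"\x01")     -> (-1, 1)  i.e. a null key/value length
# decode_varint_sketch(b"\x96\x01") -> (75, 2)
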
-class DefaultRecord(ABCRecord): - - __slots__ = ("_offset", "_timestamp", "_timestamp_type", "_key", "_value", - "_headers") - - def __init__(self, offset, timestamp, timestamp_type, key, value, headers): - self._offset = offset - self._timestamp = timestamp - self._timestamp_type = timestamp_type - self._key = key - self._value = value - self._headers = headers - - @property - def offset(self): - return self._offset - - @property - def timestamp(self): - """ Epoch milliseconds - """ - return self._timestamp - - @property - def timestamp_type(self): - """ CREATE_TIME(0) or APPEND_TIME(1) - """ - return self._timestamp_type - - @property - def key(self): - """ Bytes key or None - """ - return self._key - - @property - def value(self): - """ Bytes value or None - """ - return self._value - - @property - def headers(self): - return self._headers - - @property - def checksum(self): - return None - - def __repr__(self): - return ( - "DefaultRecord(offset={!r}, timestamp={!r}, timestamp_type={!r}," - " key={!r}, value={!r}, headers={!r})".format( - self._offset, self._timestamp, self._timestamp_type, - self._key, self._value, self._headers) - ) - - -class DefaultRecordBatchBuilder(DefaultRecordBase, ABCRecordBatchBuilder): - - # excluding key, value and headers: - # 5 bytes length + 10 bytes timestamp + 5 bytes offset + 1 byte attributes - MAX_RECORD_OVERHEAD = 21 - - __slots__ = ("_magic", "_compression_type", "_batch_size", "_is_transactional", - "_producer_id", "_producer_epoch", "_base_sequence", - "_first_timestamp", "_max_timestamp", "_last_offset", "_num_records", - "_buffer") - - def __init__( - self, magic, compression_type, is_transactional, - producer_id, producer_epoch, base_sequence, batch_size): - assert magic >= 2 - self._magic = magic - self._compression_type = compression_type & self.CODEC_MASK - self._batch_size = batch_size - self._is_transactional = bool(is_transactional) - # KIP-98 fields for EOS - self._producer_id = producer_id - self._producer_epoch = producer_epoch - self._base_sequence = base_sequence - - self._first_timestamp = None - self._max_timestamp = None - self._last_offset = 0 - self._num_records = 0 - - self._buffer = bytearray(self.HEADER_STRUCT.size) - - def _get_attributes(self, include_compression_type=True): - attrs = 0 - if include_compression_type: - attrs |= self._compression_type - # Timestamp Type is set by Broker - if self._is_transactional: - attrs |= self.TRANSACTIONAL_MASK - # Control batches are only created by Broker - return attrs - - def append(self, offset, timestamp, key, value, headers, - # Cache for LOAD_FAST opcodes - encode_varint=encode_varint, size_of_varint=size_of_varint, - get_type=type, type_int=int, time_time=time.time, - byte_like=(bytes, bytearray, memoryview), - bytearray_type=bytearray, len_func=len, zero_len_varint=1 - ): - """ Write message to messageset buffer with MsgVersion 2 - """ - # Check types - if get_type(offset) != type_int: - raise TypeError(offset) - if timestamp is None: - timestamp = type_int(time_time() * 1000) - elif get_type(timestamp) != type_int: - raise TypeError(timestamp) - if not (key is None or get_type(key) in byte_like): - raise TypeError( - "Not supported type for key: {}".format(type(key))) - if not (value is None or get_type(value) in byte_like): - raise TypeError( - "Not supported type for value: {}".format(type(value))) - - # We will always add the first message, so those will be set - if self._first_timestamp is None: - self._first_timestamp = timestamp - self._max_timestamp = timestamp - 
timestamp_delta = 0 - first_message = 1 - else: - timestamp_delta = timestamp - self._first_timestamp - first_message = 0 - - # We can't write record right away to out buffer, we need to - # precompute the length as first value... - message_buffer = bytearray_type(b"\x00") # Attributes - write_byte = message_buffer.append - write = message_buffer.extend - - encode_varint(timestamp_delta, write_byte) - # Base offset is always 0 on Produce - encode_varint(offset, write_byte) - - if key is not None: - encode_varint(len_func(key), write_byte) - write(key) - else: - write_byte(zero_len_varint) - - if value is not None: - encode_varint(len_func(value), write_byte) - write(value) - else: - write_byte(zero_len_varint) - - encode_varint(len_func(headers), write_byte) - - for h_key, h_value in headers: - h_key = h_key.encode("utf-8") - encode_varint(len_func(h_key), write_byte) - write(h_key) - if h_value is not None: - encode_varint(len_func(h_value), write_byte) - write(h_value) - else: - write_byte(zero_len_varint) - - message_len = len_func(message_buffer) - main_buffer = self._buffer - - required_size = message_len + size_of_varint(message_len) - # Check if we can write this message - if (required_size + len_func(main_buffer) > self._batch_size and - not first_message): - return None - - # Those should be updated after the length check - if self._max_timestamp < timestamp: - self._max_timestamp = timestamp - self._num_records += 1 - self._last_offset = offset - - encode_varint(message_len, main_buffer.append) - main_buffer.extend(message_buffer) - - return DefaultRecordMetadata(offset, required_size, timestamp) - - def write_header(self, use_compression_type=True): - batch_len = len(self._buffer) - self.HEADER_STRUCT.pack_into( - self._buffer, 0, - 0, # BaseOffset, set by broker - batch_len - self.AFTER_LEN_OFFSET, # Size from here to end - 0, # PartitionLeaderEpoch, set by broker - self._magic, - 0, # CRC will be set below, as we need a filled buffer for it - self._get_attributes(use_compression_type), - self._last_offset, - self._first_timestamp, - self._max_timestamp, - self._producer_id, - self._producer_epoch, - self._base_sequence, - self._num_records - ) - crc = calc_crc32c(self._buffer[self.ATTRIBUTES_OFFSET:]) - struct.pack_into(">I", self._buffer, self.CRC_OFFSET, crc) - - def _maybe_compress(self): - if self._compression_type != self.CODEC_NONE: - self._assert_has_codec(self._compression_type) - header_size = self.HEADER_STRUCT.size - data = bytes(self._buffer[header_size:]) - if self._compression_type == self.CODEC_GZIP: - compressed = gzip_encode(data) - elif self._compression_type == self.CODEC_SNAPPY: - compressed = snappy_encode(data) - elif self._compression_type == self.CODEC_LZ4: - compressed = lz4_encode(data) - elif self._compression_type == self.CODEC_ZSTD: - compressed = zstd_encode(data) - compressed_size = len(compressed) - if len(data) <= compressed_size: - # We did not get any benefit from compression, lets send - # uncompressed - return False - else: - # Trim bytearray to the required size - needed_size = header_size + compressed_size - del self._buffer[needed_size:] - self._buffer[header_size:needed_size] = compressed - return True - return False - - def build(self): - send_compressed = self._maybe_compress() - self.write_header(send_compressed) - return self._buffer - - def size(self): - """ Return current size of data written to buffer - """ - return len(self._buffer) - - def size_in_bytes(self, offset, timestamp, key, value, headers): - if self._first_timestamp 
is not None: - timestamp_delta = timestamp - self._first_timestamp - else: - timestamp_delta = 0 - size_of_body = ( - 1 + # Attrs - size_of_varint(offset) + - size_of_varint(timestamp_delta) + - self.size_of(key, value, headers) - ) - return size_of_body + size_of_varint(size_of_body) - - @classmethod - def size_of(cls, key, value, headers): - size = 0 - # Key size - if key is None: - size += 1 - else: - key_len = len(key) - size += size_of_varint(key_len) + key_len - # Value size - if value is None: - size += 1 - else: - value_len = len(value) - size += size_of_varint(value_len) + value_len - # Header size - size += size_of_varint(len(headers)) - for h_key, h_value in headers: - h_key_len = len(h_key.encode("utf-8")) - size += size_of_varint(h_key_len) + h_key_len - - if h_value is None: - size += 1 - else: - h_value_len = len(h_value) - size += size_of_varint(h_value_len) + h_value_len - return size - - @classmethod - def estimate_size_in_bytes(cls, key, value, headers): - """ Get the upper bound estimate on the size of record - """ - return ( - cls.HEADER_STRUCT.size + cls.MAX_RECORD_OVERHEAD + - cls.size_of(key, value, headers) - ) - - -class DefaultRecordMetadata(object): - - __slots__ = ("_size", "_timestamp", "_offset") - - def __init__(self, offset, size, timestamp): - self._offset = offset - self._size = size - self._timestamp = timestamp - - @property - def offset(self): - return self._offset - - @property - def crc(self): - return None - - @property - def size(self): - return self._size - - @property - def timestamp(self): - return self._timestamp - - def __repr__(self): - return ( - "DefaultRecordMetadata(offset={!r}, size={!r}, timestamp={!r})" - .format(self._offset, self._size, self._timestamp) - ) diff --git a/kafka/record/legacy_records.py b/kafka/record/legacy_records.py deleted file mode 100644 index 2f8523fc..00000000 --- a/kafka/record/legacy_records.py +++ /dev/null @@ -1,548 +0,0 @@ -# See: -# https://github.com/apache/kafka/blob/trunk/clients/src/main/java/org/\ -# apache/kafka/common/record/LegacyRecord.java - -# Builder and reader implementation for V0 and V1 record versions. As of Kafka -# 0.11.0.0 those were replaced with V2, thus the Legacy naming. - -# The schema is given below (see -# https://kafka.apache.org/protocol#protocol_message_sets for more details): - -# MessageSet => [Offset MessageSize Message] -# Offset => int64 -# MessageSize => int32 - -# v0 -# Message => Crc MagicByte Attributes Key Value -# Crc => int32 -# MagicByte => int8 -# Attributes => int8 -# Key => bytes -# Value => bytes - -# v1 (supported since 0.10.0) -# Message => Crc MagicByte Attributes Key Value -# Crc => int32 -# MagicByte => int8 -# Attributes => int8 -# Timestamp => int64 -# Key => bytes -# Value => bytes - -# The message attribute bits are given below: -# * Unused (4-7) -# * Timestamp Type (3) (added in V1) -# * Compression Type (0-2) - -# Note that when compression is enabled (see attributes above), the whole -# array of MessageSet's is compressed and places into a message as the `value`. -# Only the parent message is marked with `compression` bits in attributes. - -# The CRC covers the data from the Magic byte to the end of the message. 
- - -import struct -import time - -from kafka.record.abc import ABCRecord, ABCRecordBatch, ABCRecordBatchBuilder -from kafka.record.util import calc_crc32 - -from kafka.codec import ( - gzip_encode, snappy_encode, lz4_encode, lz4_encode_old_kafka, - gzip_decode, snappy_decode, lz4_decode, lz4_decode_old_kafka, -) -import kafka.codec as codecs -from kafka.errors import CorruptRecordException, UnsupportedCodecError - - -class LegacyRecordBase(object): - - __slots__ = () - - HEADER_STRUCT_V0 = struct.Struct( - ">q" # BaseOffset => Int64 - "i" # Length => Int32 - "I" # CRC => Int32 - "b" # Magic => Int8 - "b" # Attributes => Int8 - ) - HEADER_STRUCT_V1 = struct.Struct( - ">q" # BaseOffset => Int64 - "i" # Length => Int32 - "I" # CRC => Int32 - "b" # Magic => Int8 - "b" # Attributes => Int8 - "q" # timestamp => Int64 - ) - - LOG_OVERHEAD = CRC_OFFSET = struct.calcsize( - ">q" # Offset - "i" # Size - ) - MAGIC_OFFSET = LOG_OVERHEAD + struct.calcsize( - ">I" # CRC - ) - # Those are used for fast size calculations - RECORD_OVERHEAD_V0 = struct.calcsize( - ">I" # CRC - "b" # magic - "b" # attributes - "i" # Key length - "i" # Value length - ) - RECORD_OVERHEAD_V1 = struct.calcsize( - ">I" # CRC - "b" # magic - "b" # attributes - "q" # timestamp - "i" # Key length - "i" # Value length - ) - - KEY_OFFSET_V0 = HEADER_STRUCT_V0.size - KEY_OFFSET_V1 = HEADER_STRUCT_V1.size - KEY_LENGTH = VALUE_LENGTH = struct.calcsize(">i") # Bytes length is Int32 - - CODEC_MASK = 0x07 - CODEC_NONE = 0x00 - CODEC_GZIP = 0x01 - CODEC_SNAPPY = 0x02 - CODEC_LZ4 = 0x03 - TIMESTAMP_TYPE_MASK = 0x08 - - LOG_APPEND_TIME = 1 - CREATE_TIME = 0 - - NO_TIMESTAMP = -1 - - def _assert_has_codec(self, compression_type): - if compression_type == self.CODEC_GZIP: - checker, name = codecs.has_gzip, "gzip" - elif compression_type == self.CODEC_SNAPPY: - checker, name = codecs.has_snappy, "snappy" - elif compression_type == self.CODEC_LZ4: - checker, name = codecs.has_lz4, "lz4" - if not checker(): - raise UnsupportedCodecError( - "Libraries for {} compression codec not found".format(name)) - - -class LegacyRecordBatch(ABCRecordBatch, LegacyRecordBase): - - __slots__ = ("_buffer", "_magic", "_offset", "_crc", "_timestamp", - "_attributes", "_decompressed") - - def __init__(self, buffer, magic): - self._buffer = memoryview(buffer) - self._magic = magic - - offset, length, crc, magic_, attrs, timestamp = self._read_header(0) - assert length == len(buffer) - self.LOG_OVERHEAD - assert magic == magic_ - - self._offset = offset - self._crc = crc - self._timestamp = timestamp - self._attributes = attrs - self._decompressed = False - - @property - def timestamp_type(self): - """0 for CreateTime; 1 for LogAppendTime; None if unsupported. 
- - Value is determined by broker; produced messages should always set to 0 - Requires Kafka >= 0.10 / message version >= 1 - """ - if self._magic == 0: - return None - elif self._attributes & self.TIMESTAMP_TYPE_MASK: - return 1 - else: - return 0 - - @property - def compression_type(self): - return self._attributes & self.CODEC_MASK - - def validate_crc(self): - crc = calc_crc32(self._buffer[self.MAGIC_OFFSET:]) - return self._crc == crc - - def _decompress(self, key_offset): - # Copy of `_read_key_value`, but uses memoryview - pos = key_offset - key_size = struct.unpack_from(">i", self._buffer, pos)[0] - pos += self.KEY_LENGTH - if key_size != -1: - pos += key_size - value_size = struct.unpack_from(">i", self._buffer, pos)[0] - pos += self.VALUE_LENGTH - if value_size == -1: - raise CorruptRecordException("Value of compressed message is None") - else: - data = self._buffer[pos:pos + value_size] - - compression_type = self.compression_type - self._assert_has_codec(compression_type) - if compression_type == self.CODEC_GZIP: - uncompressed = gzip_decode(data) - elif compression_type == self.CODEC_SNAPPY: - uncompressed = snappy_decode(data.tobytes()) - elif compression_type == self.CODEC_LZ4: - if self._magic == 0: - uncompressed = lz4_decode_old_kafka(data.tobytes()) - else: - uncompressed = lz4_decode(data.tobytes()) - return uncompressed - - def _read_header(self, pos): - if self._magic == 0: - offset, length, crc, magic_read, attrs = \ - self.HEADER_STRUCT_V0.unpack_from(self._buffer, pos) - timestamp = None - else: - offset, length, crc, magic_read, attrs, timestamp = \ - self.HEADER_STRUCT_V1.unpack_from(self._buffer, pos) - return offset, length, crc, magic_read, attrs, timestamp - - def _read_all_headers(self): - pos = 0 - msgs = [] - buffer_len = len(self._buffer) - while pos < buffer_len: - header = self._read_header(pos) - msgs.append((header, pos)) - pos += self.LOG_OVERHEAD + header[1] # length - return msgs - - def _read_key_value(self, pos): - key_size = struct.unpack_from(">i", self._buffer, pos)[0] - pos += self.KEY_LENGTH - if key_size == -1: - key = None - else: - key = self._buffer[pos:pos + key_size].tobytes() - pos += key_size - - value_size = struct.unpack_from(">i", self._buffer, pos)[0] - pos += self.VALUE_LENGTH - if value_size == -1: - value = None - else: - value = self._buffer[pos:pos + value_size].tobytes() - return key, value - - def __iter__(self): - if self._magic == 1: - key_offset = self.KEY_OFFSET_V1 - else: - key_offset = self.KEY_OFFSET_V0 - timestamp_type = self.timestamp_type - - if self.compression_type: - # In case we will call iter again - if not self._decompressed: - self._buffer = memoryview(self._decompress(key_offset)) - self._decompressed = True - - # If relative offset is used, we need to decompress the entire - # message first to compute the absolute offset. - headers = self._read_all_headers() - if self._magic > 0: - msg_header, _ = headers[-1] - absolute_base_offset = self._offset - msg_header[0] - else: - absolute_base_offset = -1 - - for header, msg_pos in headers: - offset, _, crc, _, attrs, timestamp = header - # There should only ever be a single layer of compression - assert not attrs & self.CODEC_MASK, ( - 'MessageSet at offset %d appears double-compressed. This ' - 'should not happen -- check your producers!' 
% (offset,)) - - # When magic value is greater than 0, the timestamp - # of a compressed message depends on the - # timestamp type of the wrapper message: - if timestamp_type == self.LOG_APPEND_TIME: - timestamp = self._timestamp - - if absolute_base_offset >= 0: - offset += absolute_base_offset - - key, value = self._read_key_value(msg_pos + key_offset) - yield LegacyRecord( - offset, timestamp, timestamp_type, - key, value, crc) - else: - key, value = self._read_key_value(key_offset) - yield LegacyRecord( - self._offset, self._timestamp, timestamp_type, - key, value, self._crc) - - -class LegacyRecord(ABCRecord): - - __slots__ = ("_offset", "_timestamp", "_timestamp_type", "_key", "_value", - "_crc") - - def __init__(self, offset, timestamp, timestamp_type, key, value, crc): - self._offset = offset - self._timestamp = timestamp - self._timestamp_type = timestamp_type - self._key = key - self._value = value - self._crc = crc - - @property - def offset(self): - return self._offset - - @property - def timestamp(self): - """ Epoch milliseconds - """ - return self._timestamp - - @property - def timestamp_type(self): - """ CREATE_TIME(0) or APPEND_TIME(1) - """ - return self._timestamp_type - - @property - def key(self): - """ Bytes key or None - """ - return self._key - - @property - def value(self): - """ Bytes value or None - """ - return self._value - - @property - def headers(self): - return [] - - @property - def checksum(self): - return self._crc - - def __repr__(self): - return ( - "LegacyRecord(offset={!r}, timestamp={!r}, timestamp_type={!r}," - " key={!r}, value={!r}, crc={!r})".format( - self._offset, self._timestamp, self._timestamp_type, - self._key, self._value, self._crc) - ) - - -class LegacyRecordBatchBuilder(ABCRecordBatchBuilder, LegacyRecordBase): - - __slots__ = ("_magic", "_compression_type", "_batch_size", "_buffer") - - def __init__(self, magic, compression_type, batch_size): - self._magic = magic - self._compression_type = compression_type - self._batch_size = batch_size - self._buffer = bytearray() - - def append(self, offset, timestamp, key, value, headers=None): - """ Append message to batch. - """ - assert not headers, "Headers not supported in v0/v1" - # Check types - if type(offset) != int: - raise TypeError(offset) - if self._magic == 0: - timestamp = self.NO_TIMESTAMP - elif timestamp is None: - timestamp = int(time.time() * 1000) - elif type(timestamp) != int: - raise TypeError( - "`timestamp` should be int, but {} provided".format( - type(timestamp))) - if not (key is None or - isinstance(key, (bytes, bytearray, memoryview))): - raise TypeError( - "Not supported type for key: {}".format(type(key))) - if not (value is None or - isinstance(value, (bytes, bytearray, memoryview))): - raise TypeError( - "Not supported type for value: {}".format(type(value))) - - # Check if we have room for another message - pos = len(self._buffer) - size = self.size_in_bytes(offset, timestamp, key, value) - # We always allow at least one record to be appended - if offset != 0 and pos + size >= self._batch_size: - return None - - # Allocate proper buffer length - self._buffer.extend(bytearray(size)) - - # Encode message - crc = self._encode_msg(pos, offset, timestamp, key, value) - - return LegacyRecordMetadata(offset, crc, size, timestamp) - - def _encode_msg(self, start_pos, offset, timestamp, key, value, - attributes=0): - """ Encode msg data into the `msg_buffer`, which should be allocated - to at least the size of this message. 
- """ - magic = self._magic - buf = self._buffer - pos = start_pos - - # Write key and value - pos += self.KEY_OFFSET_V0 if magic == 0 else self.KEY_OFFSET_V1 - - if key is None: - struct.pack_into(">i", buf, pos, -1) - pos += self.KEY_LENGTH - else: - key_size = len(key) - struct.pack_into(">i", buf, pos, key_size) - pos += self.KEY_LENGTH - buf[pos: pos + key_size] = key - pos += key_size - - if value is None: - struct.pack_into(">i", buf, pos, -1) - pos += self.VALUE_LENGTH - else: - value_size = len(value) - struct.pack_into(">i", buf, pos, value_size) - pos += self.VALUE_LENGTH - buf[pos: pos + value_size] = value - pos += value_size - length = (pos - start_pos) - self.LOG_OVERHEAD - - # Write msg header. Note, that Crc will be updated later - if magic == 0: - self.HEADER_STRUCT_V0.pack_into( - buf, start_pos, - offset, length, 0, magic, attributes) - else: - self.HEADER_STRUCT_V1.pack_into( - buf, start_pos, - offset, length, 0, magic, attributes, timestamp) - - # Calculate CRC for msg - crc_data = memoryview(buf)[start_pos + self.MAGIC_OFFSET:] - crc = calc_crc32(crc_data) - struct.pack_into(">I", buf, start_pos + self.CRC_OFFSET, crc) - return crc - - def _maybe_compress(self): - if self._compression_type: - self._assert_has_codec(self._compression_type) - data = bytes(self._buffer) - if self._compression_type == self.CODEC_GZIP: - compressed = gzip_encode(data) - elif self._compression_type == self.CODEC_SNAPPY: - compressed = snappy_encode(data) - elif self._compression_type == self.CODEC_LZ4: - if self._magic == 0: - compressed = lz4_encode_old_kafka(data) - else: - compressed = lz4_encode(data) - size = self.size_in_bytes( - 0, timestamp=0, key=None, value=compressed) - # We will try to reuse the same buffer if we have enough space - if size > len(self._buffer): - self._buffer = bytearray(size) - else: - del self._buffer[size:] - self._encode_msg( - start_pos=0, - offset=0, timestamp=0, key=None, value=compressed, - attributes=self._compression_type) - return True - return False - - def build(self): - """Compress batch to be ready for send""" - self._maybe_compress() - return self._buffer - - def size(self): - """ Return current size of data written to buffer - """ - return len(self._buffer) - - # Size calculations. Just copied Java's implementation - - def size_in_bytes(self, offset, timestamp, key, value, headers=None): - """ Actual size of message to add - """ - assert not headers, "Headers not supported in v0/v1" - magic = self._magic - return self.LOG_OVERHEAD + self.record_size(magic, key, value) - - @classmethod - def record_size(cls, magic, key, value): - message_size = cls.record_overhead(magic) - if key is not None: - message_size += len(key) - if value is not None: - message_size += len(value) - return message_size - - @classmethod - def record_overhead(cls, magic): - assert magic in [0, 1], "Not supported magic" - if magic == 0: - return cls.RECORD_OVERHEAD_V0 - else: - return cls.RECORD_OVERHEAD_V1 - - @classmethod - def estimate_size_in_bytes(cls, magic, compression_type, key, value): - """ Upper bound estimate of record size. 
- """ - assert magic in [0, 1], "Not supported magic" - # In case of compression we may need another overhead for inner msg - if compression_type: - return ( - cls.LOG_OVERHEAD + cls.record_overhead(magic) + - cls.record_size(magic, key, value) - ) - return cls.LOG_OVERHEAD + cls.record_size(magic, key, value) - - -class LegacyRecordMetadata(object): - - __slots__ = ("_crc", "_size", "_timestamp", "_offset") - - def __init__(self, offset, crc, size, timestamp): - self._offset = offset - self._crc = crc - self._size = size - self._timestamp = timestamp - - @property - def offset(self): - return self._offset - - @property - def crc(self): - return self._crc - - @property - def size(self): - return self._size - - @property - def timestamp(self): - return self._timestamp - - def __repr__(self): - return ( - "LegacyRecordMetadata(offset={!r}, crc={!r}, size={!r}," - " timestamp={!r})".format( - self._offset, self._crc, self._size, self._timestamp) - ) diff --git a/kafka/record/memory_records.py b/kafka/record/memory_records.py deleted file mode 100644 index fc2ef2d6..00000000 --- a/kafka/record/memory_records.py +++ /dev/null @@ -1,187 +0,0 @@ -# This class takes advantage of the fact that all formats v0, v1 and v2 of -# messages storage has the same byte offsets for Length and Magic fields. -# Lets look closely at what leading bytes all versions have: -# -# V0 and V1 (Offset is MessageSet part, other bytes are Message ones): -# Offset => Int64 -# BytesLength => Int32 -# CRC => Int32 -# Magic => Int8 -# ... -# -# V2: -# BaseOffset => Int64 -# Length => Int32 -# PartitionLeaderEpoch => Int32 -# Magic => Int8 -# ... -# -# So we can iterate over batches just by knowing offsets of Length. Magic is -# used to construct the correct class for Batch itself. -from __future__ import division - -import struct - -from kafka.errors import CorruptRecordException -from kafka.record.abc import ABCRecords -from kafka.record.legacy_records import LegacyRecordBatch, LegacyRecordBatchBuilder -from kafka.record.default_records import DefaultRecordBatch, DefaultRecordBatchBuilder - - -class MemoryRecords(ABCRecords): - - LENGTH_OFFSET = struct.calcsize(">q") - LOG_OVERHEAD = struct.calcsize(">qi") - MAGIC_OFFSET = struct.calcsize(">qii") - - # Minimum space requirements for Record V0 - MIN_SLICE = LOG_OVERHEAD + LegacyRecordBatch.RECORD_OVERHEAD_V0 - - __slots__ = ("_buffer", "_pos", "_next_slice", "_remaining_bytes") - - def __init__(self, bytes_data): - self._buffer = bytes_data - self._pos = 0 - # We keep one slice ahead so `has_next` will return very fast - self._next_slice = None - self._remaining_bytes = None - self._cache_next() - - def size_in_bytes(self): - return len(self._buffer) - - def valid_bytes(self): - # We need to read the whole buffer to get the valid_bytes. - # NOTE: in Fetcher we do the call after iteration, so should be fast - if self._remaining_bytes is None: - next_slice = self._next_slice - pos = self._pos - while self._remaining_bytes is None: - self._cache_next() - # Reset previous iterator position - self._next_slice = next_slice - self._pos = pos - return len(self._buffer) - self._remaining_bytes - - # NOTE: we cache offsets here as kwargs for a bit more speed, as cPython - # will use LOAD_FAST opcode in this case - def _cache_next(self, len_offset=LENGTH_OFFSET, log_overhead=LOG_OVERHEAD): - buffer = self._buffer - buffer_len = len(buffer) - pos = self._pos - remaining = buffer_len - pos - if remaining < log_overhead: - # Will be re-checked in Fetcher for remaining bytes. 
- self._remaining_bytes = remaining - self._next_slice = None - return - - length, = struct.unpack_from( - ">i", buffer, pos + len_offset) - - slice_end = pos + log_overhead + length - if slice_end > buffer_len: - # Will be re-checked in Fetcher for remaining bytes - self._remaining_bytes = remaining - self._next_slice = None - return - - self._next_slice = memoryview(buffer)[pos: slice_end] - self._pos = slice_end - - def has_next(self): - return self._next_slice is not None - - # NOTE: same cache for LOAD_FAST as above - def next_batch(self, _min_slice=MIN_SLICE, - _magic_offset=MAGIC_OFFSET): - next_slice = self._next_slice - if next_slice is None: - return None - if len(next_slice) < _min_slice: - raise CorruptRecordException( - "Record size is less than the minimum record overhead " - "({})".format(_min_slice - self.LOG_OVERHEAD)) - self._cache_next() - magic, = struct.unpack_from(">b", next_slice, _magic_offset) - if magic <= 1: - return LegacyRecordBatch(next_slice, magic) - else: - return DefaultRecordBatch(next_slice) - - -class MemoryRecordsBuilder(object): - - __slots__ = ("_builder", "_batch_size", "_buffer", "_next_offset", "_closed", - "_bytes_written") - - def __init__(self, magic, compression_type, batch_size): - assert magic in [0, 1, 2], "Not supported magic" - assert compression_type in [0, 1, 2, 3, 4], "Not valid compression type" - if magic >= 2: - self._builder = DefaultRecordBatchBuilder( - magic=magic, compression_type=compression_type, - is_transactional=False, producer_id=-1, producer_epoch=-1, - base_sequence=-1, batch_size=batch_size) - else: - self._builder = LegacyRecordBatchBuilder( - magic=magic, compression_type=compression_type, - batch_size=batch_size) - self._batch_size = batch_size - self._buffer = None - - self._next_offset = 0 - self._closed = False - self._bytes_written = 0 - - def append(self, timestamp, key, value, headers=[]): - """ Append a message to the buffer. - - Returns: RecordMetadata or None if unable to append - """ - if self._closed: - return None - - offset = self._next_offset - metadata = self._builder.append(offset, timestamp, key, value, headers) - # Return of None means there's no space to add a new message - if metadata is None: - return None - - self._next_offset += 1 - return metadata - - def close(self): - # This method may be called multiple times on the same batch - # i.e., on retries - # we need to make sure we only close it out once - # otherwise compressed messages may be double-compressed - # see Issue 718 - if not self._closed: - self._bytes_written = self._builder.size() - self._buffer = bytes(self._builder.build()) - self._builder = None - self._closed = True - - def size_in_bytes(self): - if not self._closed: - return self._builder.size() - else: - return len(self._buffer) - - def compression_rate(self): - assert self._closed - return self.size_in_bytes() / self._bytes_written - - def is_full(self): - if self._closed: - return True - else: - return self._builder.size() >= self._batch_size - - def next_offset(self): - return self._next_offset - - def buffer(self): - assert self._closed - return self._buffer diff --git a/kafka/record/util.py b/kafka/record/util.py deleted file mode 100644 index 3b712005..00000000 --- a/kafka/record/util.py +++ /dev/null @@ -1,135 +0,0 @@ -import binascii - -from kafka.record._crc32c import crc as crc32c_py -try: - from crc32c import crc32c as crc32c_c -except ImportError: - crc32c_c = None - - -def encode_varint(value, write): - """ Encode an integer to a varint presentation. 
See - https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints - on how those can be produced. - - Arguments: - value (int): Value to encode - write (function): Called per byte that needs to be writen - - Returns: - int: Number of bytes written - """ - value = (value << 1) ^ (value >> 63) - - if value <= 0x7f: # 1 byte - write(value) - return 1 - if value <= 0x3fff: # 2 bytes - write(0x80 | (value & 0x7f)) - write(value >> 7) - return 2 - if value <= 0x1fffff: # 3 bytes - write(0x80 | (value & 0x7f)) - write(0x80 | ((value >> 7) & 0x7f)) - write(value >> 14) - return 3 - if value <= 0xfffffff: # 4 bytes - write(0x80 | (value & 0x7f)) - write(0x80 | ((value >> 7) & 0x7f)) - write(0x80 | ((value >> 14) & 0x7f)) - write(value >> 21) - return 4 - if value <= 0x7ffffffff: # 5 bytes - write(0x80 | (value & 0x7f)) - write(0x80 | ((value >> 7) & 0x7f)) - write(0x80 | ((value >> 14) & 0x7f)) - write(0x80 | ((value >> 21) & 0x7f)) - write(value >> 28) - return 5 - else: - # Return to general algorithm - bits = value & 0x7f - value >>= 7 - i = 0 - while value: - write(0x80 | bits) - bits = value & 0x7f - value >>= 7 - i += 1 - write(bits) - return i - - -def size_of_varint(value): - """ Number of bytes needed to encode an integer in variable-length format. - """ - value = (value << 1) ^ (value >> 63) - if value <= 0x7f: - return 1 - if value <= 0x3fff: - return 2 - if value <= 0x1fffff: - return 3 - if value <= 0xfffffff: - return 4 - if value <= 0x7ffffffff: - return 5 - if value <= 0x3ffffffffff: - return 6 - if value <= 0x1ffffffffffff: - return 7 - if value <= 0xffffffffffffff: - return 8 - if value <= 0x7fffffffffffffff: - return 9 - return 10 - - -def decode_varint(buffer, pos=0): - """ Decode an integer from a varint presentation. See - https://developers.google.com/protocol-buffers/docs/encoding?csw=1#varints - on how those can be produced. - - Arguments: - buffer (bytearray): buffer to read from. 
- pos (int): optional position to read from - - Returns: - (int, int): Decoded int value and next read position - """ - result = buffer[pos] - if not (result & 0x81): - return (result >> 1), pos + 1 - if not (result & 0x80): - return (result >> 1) ^ (~0), pos + 1 - - result &= 0x7f - pos += 1 - shift = 7 - while 1: - b = buffer[pos] - result |= ((b & 0x7f) << shift) - pos += 1 - if not (b & 0x80): - return ((result >> 1) ^ -(result & 1), pos) - shift += 7 - if shift >= 64: - raise ValueError("Out of int64 range") - - -_crc32c = crc32c_py -if crc32c_c is not None: - _crc32c = crc32c_c - - -def calc_crc32c(memview, _crc32c=_crc32c): - """ Calculate CRC-32C (Castagnoli) checksum over a memoryview of data - """ - return _crc32c(memview) - - -def calc_crc32(memview): - """ Calculate simple CRC-32 checksum over a memoryview of data - """ - crc = binascii.crc32(memview) & 0xffffffff - return crc diff --git a/tests/kafka/fixtures.py b/tests/kafka/fixtures.py index 5299bf3e..f36dd7e8 100644 --- a/tests/kafka/fixtures.py +++ b/tests/kafka/fixtures.py @@ -13,7 +13,7 @@ from kafka.vendor.six.moves import urllib, range from kafka.vendor.six.moves.urllib.parse import urlparse # pylint: disable=E0611,F0401 -from kafka import errors, KafkaAdminClient, KafkaClient, KafkaConsumer, KafkaProducer +from kafka import errors, KafkaAdminClient, KafkaClient from kafka.errors import InvalidReplicationFactorError from kafka.protocol.admin import CreateTopicsRequest from kafka.protocol.metadata import MetadataRequest @@ -659,15 +659,3 @@ def get_admin_clients(self, cnt, **params): params = self._enrich_client_params(params, client_id='admin_client') for client in self._create_many_clients(cnt, KafkaAdminClient, **params): yield client - - def get_consumers(self, cnt, topics, **params): - params = self._enrich_client_params( - params, client_id='consumer', heartbeat_interval_ms=500, auto_offset_reset='earliest' - ) - for client in self._create_many_clients(cnt, KafkaConsumer, *topics, **params): - yield client - - def get_producers(self, cnt, **params): - params = self._enrich_client_params(params, client_id='producer') - for client in self._create_many_clients(cnt, KafkaProducer, **params): - yield client diff --git a/tests/kafka/record/__init__.py b/tests/kafka/record/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/kafka/record/test_default_records.py b/tests/kafka/record/test_default_records.py deleted file mode 100644 index 3c809ebc..00000000 --- a/tests/kafka/record/test_default_records.py +++ /dev/null @@ -1,208 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals -import pytest -from unittest.mock import patch -import kafka.codec -from kafka.record.default_records import ( - DefaultRecordBatch, DefaultRecordBatchBuilder -) -from kafka.errors import UnsupportedCodecError - - -@pytest.mark.parametrize("compression_type", [ - DefaultRecordBatch.CODEC_NONE, - DefaultRecordBatch.CODEC_GZIP, - DefaultRecordBatch.CODEC_SNAPPY, - DefaultRecordBatch.CODEC_LZ4 -]) -def test_read_write_serde_v2(compression_type): - builder = DefaultRecordBatchBuilder( - magic=2, compression_type=compression_type, is_transactional=1, - producer_id=123456, producer_epoch=123, base_sequence=9999, - batch_size=999999) - headers = [("header1", b"aaa"), ("header2", b"bbb")] - for offset in range(10): - builder.append( - offset, timestamp=9999999, key=b"test", value=b"Super", - headers=headers) - buffer = builder.build() - reader = DefaultRecordBatch(bytes(buffer)) - msgs = list(reader) - - 
assert reader.is_transactional is True - assert reader.compression_type == compression_type - assert reader.magic == 2 - assert reader.timestamp_type == 0 - assert reader.base_offset == 0 - for offset, msg in enumerate(msgs): - assert msg.offset == offset - assert msg.timestamp == 9999999 - assert msg.key == b"test" - assert msg.value == b"Super" - assert msg.headers == headers - - -def test_written_bytes_equals_size_in_bytes_v2(): - key = b"test" - value = b"Super" - headers = [("header1", b"aaa"), ("header2", b"bbb"), ("xx", None)] - builder = DefaultRecordBatchBuilder( - magic=2, compression_type=0, is_transactional=0, - producer_id=-1, producer_epoch=-1, base_sequence=-1, - batch_size=999999) - - size_in_bytes = builder.size_in_bytes( - 0, timestamp=9999999, key=key, value=value, headers=headers) - - pos = builder.size() - meta = builder.append( - 0, timestamp=9999999, key=key, value=value, headers=headers) - - assert builder.size() - pos == size_in_bytes - assert meta.size == size_in_bytes - - -def test_estimate_size_in_bytes_bigger_than_batch_v2(): - key = b"Super Key" - value = b"1" * 100 - headers = [("header1", b"aaa"), ("header2", b"bbb")] - estimate_size = DefaultRecordBatchBuilder.estimate_size_in_bytes( - key, value, headers) - - builder = DefaultRecordBatchBuilder( - magic=2, compression_type=0, is_transactional=0, - producer_id=-1, producer_epoch=-1, base_sequence=-1, - batch_size=999999) - builder.append( - 0, timestamp=9999999, key=key, value=value, headers=headers) - buf = builder.build() - assert len(buf) <= estimate_size, \ - "Estimate should always be upper bound" - - -def test_default_batch_builder_validates_arguments(): - builder = DefaultRecordBatchBuilder( - magic=2, compression_type=0, is_transactional=0, - producer_id=-1, producer_epoch=-1, base_sequence=-1, - batch_size=999999) - - # Key should not be str - with pytest.raises(TypeError): - builder.append( - 0, timestamp=9999999, key="some string", value=None, headers=[]) - - # Value should not be str - with pytest.raises(TypeError): - builder.append( - 0, timestamp=9999999, key=None, value="some string", headers=[]) - - # Timestamp should be of proper type - with pytest.raises(TypeError): - builder.append( - 0, timestamp="1243812793", key=None, value=b"some string", - headers=[]) - - # Offset of invalid type - with pytest.raises(TypeError): - builder.append( - "0", timestamp=9999999, key=None, value=b"some string", headers=[]) - - # Ok to pass value as None - builder.append( - 0, timestamp=9999999, key=b"123", value=None, headers=[]) - - # Timestamp can be None - builder.append( - 1, timestamp=None, key=None, value=b"some string", headers=[]) - - # Ok to pass offsets in not incremental order. 
This should not happen thou - builder.append( - 5, timestamp=9999999, key=b"123", value=None, headers=[]) - - # Check record with headers - builder.append( - 6, timestamp=9999999, key=b"234", value=None, headers=[("hkey", b"hval")]) - - # in case error handling code fails to fix inner buffer in builder - assert len(builder.build()) == 124 - - -def test_default_correct_metadata_response(): - builder = DefaultRecordBatchBuilder( - magic=2, compression_type=0, is_transactional=0, - producer_id=-1, producer_epoch=-1, base_sequence=-1, - batch_size=1024 * 1024) - meta = builder.append( - 0, timestamp=9999999, key=b"test", value=b"Super", headers=[]) - - assert meta.offset == 0 - assert meta.timestamp == 9999999 - assert meta.crc is None - assert meta.size == 16 - assert repr(meta) == ( - "DefaultRecordMetadata(offset=0, size={}, timestamp={})" - .format(meta.size, meta.timestamp) - ) - - -def test_default_batch_size_limit(): - # First message can be added even if it's too big - builder = DefaultRecordBatchBuilder( - magic=2, compression_type=0, is_transactional=0, - producer_id=-1, producer_epoch=-1, base_sequence=-1, - batch_size=1024) - - meta = builder.append( - 0, timestamp=None, key=None, value=b"M" * 2000, headers=[]) - assert meta.size > 0 - assert meta.crc is None - assert meta.offset == 0 - assert meta.timestamp is not None - assert len(builder.build()) > 2000 - - builder = DefaultRecordBatchBuilder( - magic=2, compression_type=0, is_transactional=0, - producer_id=-1, producer_epoch=-1, base_sequence=-1, - batch_size=1024) - meta = builder.append( - 0, timestamp=None, key=None, value=b"M" * 700, headers=[]) - assert meta is not None - meta = builder.append( - 1, timestamp=None, key=None, value=b"M" * 700, headers=[]) - assert meta is None - meta = builder.append( - 2, timestamp=None, key=None, value=b"M" * 700, headers=[]) - assert meta is None - assert len(builder.build()) < 1000 - - -@pytest.mark.parametrize("compression_type,name,checker_name", [ - (DefaultRecordBatch.CODEC_GZIP, "gzip", "has_gzip"), - (DefaultRecordBatch.CODEC_SNAPPY, "snappy", "has_snappy"), - (DefaultRecordBatch.CODEC_LZ4, "lz4", "has_lz4") -]) -@pytest.mark.parametrize("magic", [0, 1]) -def test_unavailable_codec(magic, compression_type, name, checker_name): - builder = DefaultRecordBatchBuilder( - magic=2, compression_type=compression_type, is_transactional=0, - producer_id=-1, producer_epoch=-1, base_sequence=-1, - batch_size=1024) - builder.append(0, timestamp=None, key=None, value=b"M" * 2000, headers=[]) - correct_buffer = builder.build() - - with patch.object(kafka.codec, checker_name) as mocked: - mocked.return_value = False - # Check that builder raises error - builder = DefaultRecordBatchBuilder( - magic=2, compression_type=compression_type, is_transactional=0, - producer_id=-1, producer_epoch=-1, base_sequence=-1, - batch_size=1024) - error_msg = "Libraries for {} compression codec not found".format(name) - with pytest.raises(UnsupportedCodecError, match=error_msg): - builder.append(0, timestamp=None, key=None, value=b"M", headers=[]) - builder.build() - - # Check that reader raises same error - batch = DefaultRecordBatch(bytes(correct_buffer)) - with pytest.raises(UnsupportedCodecError, match=error_msg): - list(batch) diff --git a/tests/kafka/record/test_legacy_records.py b/tests/kafka/record/test_legacy_records.py deleted file mode 100644 index 0c87ad9a..00000000 --- a/tests/kafka/record/test_legacy_records.py +++ /dev/null @@ -1,197 +0,0 @@ -from __future__ import unicode_literals -import pytest 
-from unittest.mock import patch -from kafka.record.legacy_records import ( - LegacyRecordBatch, LegacyRecordBatchBuilder -) -import kafka.codec -from kafka.errors import UnsupportedCodecError - - -@pytest.mark.parametrize("magic", [0, 1]) -def test_read_write_serde_v0_v1_no_compression(magic): - builder = LegacyRecordBatchBuilder( - magic=magic, compression_type=0, batch_size=9999999) - builder.append( - 0, timestamp=9999999, key=b"test", value=b"Super") - buffer = builder.build() - - batch = LegacyRecordBatch(bytes(buffer), magic) - msgs = list(batch) - assert len(msgs) == 1 - msg = msgs[0] - - assert msg.offset == 0 - assert msg.timestamp == (9999999 if magic else None) - assert msg.timestamp_type == (0 if magic else None) - assert msg.key == b"test" - assert msg.value == b"Super" - assert msg.checksum == (-2095076219 if magic else 278251978) & 0xffffffff - - -@pytest.mark.parametrize("compression_type", [ - LegacyRecordBatch.CODEC_GZIP, - LegacyRecordBatch.CODEC_SNAPPY, - LegacyRecordBatch.CODEC_LZ4 -]) -@pytest.mark.parametrize("magic", [0, 1]) -def test_read_write_serde_v0_v1_with_compression(compression_type, magic): - builder = LegacyRecordBatchBuilder( - magic=magic, compression_type=compression_type, batch_size=9999999) - for offset in range(10): - builder.append( - offset, timestamp=9999999, key=b"test", value=b"Super") - buffer = builder.build() - - batch = LegacyRecordBatch(bytes(buffer), magic) - msgs = list(batch) - - for offset, msg in enumerate(msgs): - assert msg.offset == offset - assert msg.timestamp == (9999999 if magic else None) - assert msg.timestamp_type == (0 if magic else None) - assert msg.key == b"test" - assert msg.value == b"Super" - assert msg.checksum == (-2095076219 if magic else 278251978) & \ - 0xffffffff - - -@pytest.mark.parametrize("magic", [0, 1]) -def test_written_bytes_equals_size_in_bytes(magic): - key = b"test" - value = b"Super" - builder = LegacyRecordBatchBuilder( - magic=magic, compression_type=0, batch_size=9999999) - - size_in_bytes = builder.size_in_bytes( - 0, timestamp=9999999, key=key, value=value) - - pos = builder.size() - builder.append(0, timestamp=9999999, key=key, value=value) - - assert builder.size() - pos == size_in_bytes - - -@pytest.mark.parametrize("magic", [0, 1]) -def test_estimate_size_in_bytes_bigger_than_batch(magic): - key = b"Super Key" - value = b"1" * 100 - estimate_size = LegacyRecordBatchBuilder.estimate_size_in_bytes( - magic, compression_type=0, key=key, value=value) - - builder = LegacyRecordBatchBuilder( - magic=magic, compression_type=0, batch_size=9999999) - builder.append( - 0, timestamp=9999999, key=key, value=value) - buf = builder.build() - assert len(buf) <= estimate_size, \ - "Estimate should always be upper bound" - - -@pytest.mark.parametrize("magic", [0, 1]) -def test_legacy_batch_builder_validates_arguments(magic): - builder = LegacyRecordBatchBuilder( - magic=magic, compression_type=0, batch_size=1024 * 1024) - - # Key should not be str - with pytest.raises(TypeError): - builder.append( - 0, timestamp=9999999, key="some string", value=None) - - # Value should not be str - with pytest.raises(TypeError): - builder.append( - 0, timestamp=9999999, key=None, value="some string") - - # Timestamp should be of proper type - if magic != 0: - with pytest.raises(TypeError): - builder.append( - 0, timestamp="1243812793", key=None, value=b"some string") - - # Offset of invalid type - with pytest.raises(TypeError): - builder.append( - "0", timestamp=9999999, key=None, value=b"some string") - - # Ok to pass 
value as None - builder.append( - 0, timestamp=9999999, key=b"123", value=None) - - # Timestamp can be None - builder.append( - 1, timestamp=None, key=None, value=b"some string") - - # Ok to pass offsets in not incremental order. This should not happen thou - builder.append( - 5, timestamp=9999999, key=b"123", value=None) - - # in case error handling code fails to fix inner buffer in builder - assert len(builder.build()) == 119 if magic else 95 - - -@pytest.mark.parametrize("magic", [0, 1]) -def test_legacy_correct_metadata_response(magic): - builder = LegacyRecordBatchBuilder( - magic=magic, compression_type=0, batch_size=1024 * 1024) - meta = builder.append( - 0, timestamp=9999999, key=b"test", value=b"Super") - - assert meta.offset == 0 - assert meta.timestamp == (9999999 if magic else -1) - assert meta.crc == (-2095076219 if magic else 278251978) & 0xffffffff - assert repr(meta) == ( - "LegacyRecordMetadata(offset=0, crc={!r}, size={}, " - "timestamp={})".format(meta.crc, meta.size, meta.timestamp) - ) - - -@pytest.mark.parametrize("magic", [0, 1]) -def test_legacy_batch_size_limit(magic): - # First message can be added even if it's too big - builder = LegacyRecordBatchBuilder( - magic=magic, compression_type=0, batch_size=1024) - meta = builder.append(0, timestamp=None, key=None, value=b"M" * 2000) - assert meta.size > 0 - assert meta.crc is not None - assert meta.offset == 0 - assert meta.timestamp is not None - assert len(builder.build()) > 2000 - - builder = LegacyRecordBatchBuilder( - magic=magic, compression_type=0, batch_size=1024) - meta = builder.append(0, timestamp=None, key=None, value=b"M" * 700) - assert meta is not None - meta = builder.append(1, timestamp=None, key=None, value=b"M" * 700) - assert meta is None - meta = builder.append(2, timestamp=None, key=None, value=b"M" * 700) - assert meta is None - assert len(builder.build()) < 1000 - - -@pytest.mark.parametrize("compression_type,name,checker_name", [ - (LegacyRecordBatch.CODEC_GZIP, "gzip", "has_gzip"), - (LegacyRecordBatch.CODEC_SNAPPY, "snappy", "has_snappy"), - (LegacyRecordBatch.CODEC_LZ4, "lz4", "has_lz4") -]) -@pytest.mark.parametrize("magic", [0, 1]) -def test_unavailable_codec(magic, compression_type, name, checker_name): - builder = LegacyRecordBatchBuilder( - magic=magic, compression_type=compression_type, batch_size=1024) - builder.append(0, timestamp=None, key=None, value=b"M") - correct_buffer = builder.build() - - with patch.object(kafka.codec, checker_name) as mocked: - mocked.return_value = False - # Check that builder raises error - builder = LegacyRecordBatchBuilder( - magic=magic, compression_type=compression_type, batch_size=1024) - error_msg = "Libraries for {} compression codec not found".format(name) - with pytest.raises(UnsupportedCodecError, match=error_msg): - builder.append(0, timestamp=None, key=None, value=b"M") - builder.build() - - # Check that reader raises same error - batch = LegacyRecordBatch(bytes(correct_buffer), magic) - with pytest.raises(UnsupportedCodecError, match=error_msg): - list(batch) diff --git a/tests/kafka/record/test_records.py b/tests/kafka/record/test_records.py deleted file mode 100644 index 9f72234a..00000000 --- a/tests/kafka/record/test_records.py +++ /dev/null @@ -1,232 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import unicode_literals -import pytest -from kafka.record import MemoryRecords, MemoryRecordsBuilder -from kafka.errors import CorruptRecordException - -# This is real live data from Kafka 11 broker -record_batch_data_v2 = [ - # First Batch 
value == "123" - b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00;\x00\x00\x00\x01\x02\x03' - b'\x18\xa2p\x00\x00\x00\x00\x00\x00\x00\x00\x01]\xff{\x06<\x00\x00\x01]' - b'\xff{\x06<\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00' - b'\x00\x00\x01\x12\x00\x00\x00\x01\x06123\x00', - # Second Batch value = "" and value = "". 2 records - b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00@\x00\x00\x00\x02\x02\xc8' - b'\\\xbd#\x00\x00\x00\x00\x00\x01\x00\x00\x01]\xff|\xddl\x00\x00\x01]\xff' - b'|\xde\x14\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00' - b'\x00\x00\x02\x0c\x00\x00\x00\x01\x00\x00\x0e\x00\xd0\x02\x02\x01\x00' - b'\x00', - # Third batch value = "123" - b'\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00;\x00\x00\x00\x02\x02.\x0b' - b'\x85\xb7\x00\x00\x00\x00\x00\x00\x00\x00\x01]\xff|\xe7\x9d\x00\x00\x01]' - b'\xff|\xe7\x9d\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff' - b'\x00\x00\x00\x01\x12\x00\x00\x00\x01\x06123\x00' - # Fourth batch value = "hdr" with header hkey=hval - b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00E\x00\x00\x00\x00\x02\\' - b'\xd8\xefR\x00\x00\x00\x00\x00\x00\x00\x00\x01e\x85\xb6\xf3\xc1\x00\x00' - b'\x01e\x85\xb6\xf3\xc1\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff' - b'\xff\xff\x00\x00\x00\x01&\x00\x00\x00\x01\x06hdr\x02\x08hkey\x08hval' -] - -record_batch_data_v1 = [ - # First Message value == "123" - b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x19G\x86(\xc2\x01\x00\x00' - b'\x00\x01^\x18g\xab\xae\xff\xff\xff\xff\x00\x00\x00\x03123', - # Second Message value == "" - b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x16\xef\x98\xc9 \x01\x00' - b'\x00\x00\x01^\x18g\xaf\xc0\xff\xff\xff\xff\x00\x00\x00\x00', - # Third Message value == "" - b'\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x16_\xaf\xfb^\x01\x00\x00' - b'\x00\x01^\x18g\xb0r\xff\xff\xff\xff\x00\x00\x00\x00', - # Fourth Message value = "123" - b'\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x19\xa8\x12W \x01\x00\x00' - b'\x00\x01^\x18g\xb8\x03\xff\xff\xff\xff\x00\x00\x00\x03123' -] - -# This is real live data from Kafka 10 broker -record_batch_data_v0 = [ - # First Message value == "123" - b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11\xfe\xb0\x1d\xbf\x00' - b'\x00\xff\xff\xff\xff\x00\x00\x00\x03123', - # Second Message value == "" - b'\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x0eyWH\xe0\x00\x00\xff' - b'\xff\xff\xff\x00\x00\x00\x00', - # Third Message value == "" - b'\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x0eyWH\xe0\x00\x00\xff' - b'\xff\xff\xff\x00\x00\x00\x00', - # Fourth Message value = "123" - b'\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x11\xfe\xb0\x1d\xbf\x00' - b'\x00\xff\xff\xff\xff\x00\x00\x00\x03123' -] - - -def test_memory_records_v2(): - data_bytes = b"".join(record_batch_data_v2) + b"\x00" * 4 - records = MemoryRecords(data_bytes) - - assert records.size_in_bytes() == 303 - assert records.valid_bytes() == 299 - - assert records.has_next() is True - batch = records.next_batch() - recs = list(batch) - assert len(recs) == 1 - assert recs[0].value == b"123" - assert recs[0].key is None - assert recs[0].timestamp == 1503229838908 - assert recs[0].timestamp_type == 0 - assert recs[0].checksum is None - assert recs[0].headers == [] - - assert records.next_batch() is not None - assert records.next_batch() is not None - - batch = records.next_batch() - recs = list(batch) - assert len(recs) == 1 - assert recs[0].value == b"hdr" - assert recs[0].headers == [('hkey', b'hval')] - - assert records.has_next() is False - assert records.next_batch() is None - 
assert records.next_batch() is None - - -def test_memory_records_v1(): - data_bytes = b"".join(record_batch_data_v1) + b"\x00" * 4 - records = MemoryRecords(data_bytes) - - assert records.size_in_bytes() == 146 - assert records.valid_bytes() == 142 - - assert records.has_next() is True - batch = records.next_batch() - recs = list(batch) - assert len(recs) == 1 - assert recs[0].value == b"123" - assert recs[0].key is None - assert recs[0].timestamp == 1503648000942 - assert recs[0].timestamp_type == 0 - assert recs[0].checksum == 1199974594 & 0xffffffff - - assert records.next_batch() is not None - assert records.next_batch() is not None - assert records.next_batch() is not None - - assert records.has_next() is False - assert records.next_batch() is None - assert records.next_batch() is None - - -def test_memory_records_v0(): - data_bytes = b"".join(record_batch_data_v0) - records = MemoryRecords(data_bytes + b"\x00" * 4) - - assert records.size_in_bytes() == 114 - assert records.valid_bytes() == 110 - - records = MemoryRecords(data_bytes) - - assert records.has_next() is True - batch = records.next_batch() - recs = list(batch) - assert len(recs) == 1 - assert recs[0].value == b"123" - assert recs[0].key is None - assert recs[0].timestamp is None - assert recs[0].timestamp_type is None - assert recs[0].checksum == -22012481 & 0xffffffff - - assert records.next_batch() is not None - assert records.next_batch() is not None - assert records.next_batch() is not None - - assert records.has_next() is False - assert records.next_batch() is None - assert records.next_batch() is None - - -def test_memory_records_corrupt(): - records = MemoryRecords(b"") - assert records.size_in_bytes() == 0 - assert records.valid_bytes() == 0 - assert records.has_next() is False - - records = MemoryRecords(b"\x00\x00\x00") - assert records.size_in_bytes() == 3 - assert records.valid_bytes() == 0 - assert records.has_next() is False - - records = MemoryRecords( - b"\x00\x00\x00\x00\x00\x00\x00\x03" # Offset=3 - b"\x00\x00\x00\x03" # Length=3 - b"\xfe\xb0\x1d", # Some random bytes - ) - with pytest.raises(CorruptRecordException): - records.next_batch() - - -@pytest.mark.parametrize("compression_type", [0, 1, 2, 3]) -@pytest.mark.parametrize("magic", [0, 1, 2]) -def test_memory_records_builder(magic, compression_type): - builder = MemoryRecordsBuilder( - magic=magic, compression_type=compression_type, batch_size=1024 * 10) - base_size = builder.size_in_bytes() # V2 has a header before - - msg_sizes = [] - for offset in range(10): - metadata = builder.append( - timestamp=10000 + offset, key=b"test", value=b"Super") - msg_sizes.append(metadata.size) - assert metadata.offset == offset - if magic > 0: - assert metadata.timestamp == 10000 + offset - else: - assert metadata.timestamp == -1 - assert builder.next_offset() == offset + 1 - - # Error appends should not leave junk behind, like null bytes or something - with pytest.raises(TypeError): - builder.append( - timestamp=None, key="test", value="Super") # Not bytes, but str - - assert not builder.is_full() - size_before_close = builder.size_in_bytes() - assert size_before_close == sum(msg_sizes) + base_size - - # Size should remain the same after closing. 
No trailing bytes - builder.close() - assert builder.compression_rate() > 0 - expected_size = size_before_close * builder.compression_rate() - assert builder.is_full() - assert builder.size_in_bytes() == expected_size - buffer = builder.buffer() - assert len(buffer) == expected_size - - # We can close second time, as in retry - builder.close() - assert builder.size_in_bytes() == expected_size - assert builder.buffer() == buffer - - # Can't append after close - meta = builder.append(timestamp=None, key=b"test", value=b"Super") - assert meta is None - - -@pytest.mark.parametrize("compression_type", [0, 1, 2, 3]) -@pytest.mark.parametrize("magic", [0, 1, 2]) -def test_memory_records_builder_full(magic, compression_type): - builder = MemoryRecordsBuilder( - magic=magic, compression_type=compression_type, batch_size=1024 * 10) - - # 1 message should always be appended - metadata = builder.append( - key=None, timestamp=None, value=b"M" * 10240) - assert metadata is not None - assert builder.is_full() - - metadata = builder.append( - key=None, timestamp=None, value=b"M") - assert metadata is None - assert builder.next_offset() == 1 diff --git a/tests/kafka/record/test_util.py b/tests/kafka/record/test_util.py deleted file mode 100644 index 0b2782e7..00000000 --- a/tests/kafka/record/test_util.py +++ /dev/null @@ -1,96 +0,0 @@ -import struct -import pytest -from kafka.record import util - - -varint_data = [ - (b"\x00", 0), - (b"\x01", -1), - (b"\x02", 1), - (b"\x7E", 63), - (b"\x7F", -64), - (b"\x80\x01", 64), - (b"\x81\x01", -65), - (b"\xFE\x7F", 8191), - (b"\xFF\x7F", -8192), - (b"\x80\x80\x01", 8192), - (b"\x81\x80\x01", -8193), - (b"\xFE\xFF\x7F", 1048575), - (b"\xFF\xFF\x7F", -1048576), - (b"\x80\x80\x80\x01", 1048576), - (b"\x81\x80\x80\x01", -1048577), - (b"\xFE\xFF\xFF\x7F", 134217727), - (b"\xFF\xFF\xFF\x7F", -134217728), - (b"\x80\x80\x80\x80\x01", 134217728), - (b"\x81\x80\x80\x80\x01", -134217729), - (b"\xFE\xFF\xFF\xFF\x7F", 17179869183), - (b"\xFF\xFF\xFF\xFF\x7F", -17179869184), - (b"\x80\x80\x80\x80\x80\x01", 17179869184), - (b"\x81\x80\x80\x80\x80\x01", -17179869185), - (b"\xFE\xFF\xFF\xFF\xFF\x7F", 2199023255551), - (b"\xFF\xFF\xFF\xFF\xFF\x7F", -2199023255552), - (b"\x80\x80\x80\x80\x80\x80\x01", 2199023255552), - (b"\x81\x80\x80\x80\x80\x80\x01", -2199023255553), - (b"\xFE\xFF\xFF\xFF\xFF\xFF\x7F", 281474976710655), - (b"\xFF\xFF\xFF\xFF\xFF\xFF\x7F", -281474976710656), - (b"\x80\x80\x80\x80\x80\x80\x80\x01", 281474976710656), - (b"\x81\x80\x80\x80\x80\x80\x80\x01", -281474976710657), - (b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\x7F", 36028797018963967), - (b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x7F", -36028797018963968), - (b"\x80\x80\x80\x80\x80\x80\x80\x80\x01", 36028797018963968), - (b"\x81\x80\x80\x80\x80\x80\x80\x80\x01", -36028797018963969), - (b"\xFE\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x7F", 4611686018427387903), - (b"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF\x7F", -4611686018427387904), - (b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01", 4611686018427387904), - (b"\x81\x80\x80\x80\x80\x80\x80\x80\x80\x01", -4611686018427387905), -] - - -@pytest.mark.parametrize("encoded, decoded", varint_data) -def test_encode_varint(encoded, decoded): - res = bytearray() - util.encode_varint(decoded, res.append) - assert res == encoded - - -@pytest.mark.parametrize("encoded, decoded", varint_data) -def test_decode_varint(encoded, decoded): - # We add a bit of bytes around just to check position is calculated - # correctly - value, pos = util.decode_varint( - bytearray(b"\x01\xf0" + encoded + b"\xff\x01"), 2) - assert value 
== decoded - assert pos - 2 == len(encoded) - - -@pytest.mark.parametrize("encoded, decoded", varint_data) -def test_size_of_varint(encoded, decoded): - assert util.size_of_varint(decoded) == len(encoded) - - -@pytest.mark.parametrize("crc32_func", [util.crc32c_c, util.crc32c_py]) -def test_crc32c(crc32_func): - def make_crc(data): - crc = crc32_func(data) - return struct.pack(">I", crc) - assert make_crc(b"") == b"\x00\x00\x00\x00" - assert make_crc(b"a") == b"\xc1\xd0\x43\x30" - - # Took from librdkafka testcase - long_text = b"""\ - This software is provided 'as-is', without any express or implied - warranty. In no event will the author be held liable for any damages - arising from the use of this software. - - Permission is granted to anyone to use this software for any purpose, - including commercial applications, and to alter it and redistribute it - freely, subject to the following restrictions: - - 1. The origin of this software must not be misrepresented; you must not - claim that you wrote the original software. If you use this software - in a product, an acknowledgment in the product documentation would be - appreciated but is not required. - 2. Altered source versions must be plainly marked as such, and must not be - misrepresented as being the original software. - 3. This notice may not be removed or altered from any source distribution.""" - assert make_crc(long_text) == b"\x7d\xcd\xe1\x13" diff --git a/tests/kafka/test_consumer.py b/tests/kafka/test_consumer.py deleted file mode 100644 index 436fe55c..00000000 --- a/tests/kafka/test_consumer.py +++ /dev/null @@ -1,26 +0,0 @@ -import pytest - -from kafka import KafkaConsumer -from kafka.errors import KafkaConfigurationError - - -class TestKafkaConsumer: - def test_session_timeout_larger_than_request_timeout_raises(self): - with pytest.raises(KafkaConfigurationError): - KafkaConsumer(bootstrap_servers='localhost:9092', api_version=(0, 9), group_id='foo', session_timeout_ms=50000, request_timeout_ms=40000) - - def test_fetch_max_wait_larger_than_request_timeout_raises(self): - with pytest.raises(KafkaConfigurationError): - KafkaConsumer(bootstrap_servers='localhost:9092', fetch_max_wait_ms=50000, request_timeout_ms=40000) - - def test_request_timeout_larger_than_connections_max_idle_ms_raises(self): - with pytest.raises(KafkaConfigurationError): - KafkaConsumer(bootstrap_servers='localhost:9092', api_version=(0, 9), request_timeout_ms=50000, connections_max_idle_ms=40000) - - def test_subscription_copy(self): - consumer = KafkaConsumer('foo', api_version=(0, 10)) - sub = consumer.subscription() - assert sub is not consumer.subscription() - assert sub == set(['foo']) - sub.add('fizz') - assert consumer.subscription() == set(['foo']) diff --git a/tests/kafka/test_consumer_group.py b/tests/kafka/test_consumer_group.py deleted file mode 100644 index 40dc9d70..00000000 --- a/tests/kafka/test_consumer_group.py +++ /dev/null @@ -1,179 +0,0 @@ -import collections -import logging -import threading -import time - -import pytest -from kafka.vendor import six - -from kafka.conn import ConnectionStates -from kafka.consumer.group import KafkaConsumer -from kafka.coordinator.base import MemberState -from kafka.structs import TopicPartition - -from tests.kafka.testutil import env_kafka_version, random_string - - -def get_connect_str(kafka_broker): - return kafka_broker.host + ':' + str(kafka_broker.port) - - -@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") -def test_consumer(kafka_broker, topic): - # The `topic` 
fixture is included because - # 0.8.2 brokers need a topic to function well - consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker)) - consumer.poll(500) - assert len(consumer._client._conns) > 0 - node_id = list(consumer._client._conns.keys())[0] - assert consumer._client._conns[node_id].state is ConnectionStates.CONNECTED - consumer.close() - - -@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") -def test_consumer_topics(kafka_broker, topic): - consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker)) - # Necessary to drive the IO - consumer.poll(500) - assert topic in consumer.topics() - assert len(consumer.partitions_for_topic(topic)) > 0 - consumer.close() - - -@pytest.mark.skipif(env_kafka_version() < (0, 9), reason='Unsupported Kafka Version') -def test_group(kafka_broker, topic): - num_partitions = 4 - connect_str = get_connect_str(kafka_broker) - consumers = {} - stop = {} - threads = {} - messages = collections.defaultdict(list) - group_id = 'test-group-' + random_string(6) - def consumer_thread(i): - assert i not in consumers - assert i not in stop - stop[i] = threading.Event() - consumers[i] = KafkaConsumer(topic, - bootstrap_servers=connect_str, - group_id=group_id, - heartbeat_interval_ms=500) - while not stop[i].is_set(): - for tp, records in six.itervalues(consumers[i].poll(100)): - messages[i][tp].extend(records) - consumers[i].close() - consumers[i] = None - stop[i] = None - - num_consumers = 4 - for i in range(num_consumers): - t = threading.Thread(target=consumer_thread, args=(i,)) - t.start() - threads[i] = t - - try: - timeout = time.time() + 35 - while True: - for c in range(num_consumers): - - # Verify all consumers have been created - if c not in consumers: - break - - # Verify all consumers have an assignment - elif not consumers[c].assignment(): - break - - # If all consumers exist and have an assignment - else: - - logging.info('All consumers have assignment... 
checking for stable group') - # Verify all consumers are in the same generation - # then log state and break while loop - generations = set([consumer._coordinator._generation.generation_id - for consumer in list(consumers.values())]) - - # New generation assignment is not complete until - # coordinator.rejoining = False - rejoining = any([consumer._coordinator.rejoining - for consumer in list(consumers.values())]) - - if not rejoining and len(generations) == 1: - for c, consumer in list(consumers.items()): - logging.info("[%s] %s %s: %s", c, - consumer._coordinator._generation.generation_id, - consumer._coordinator._generation.member_id, - consumer.assignment()) - break - else: - logging.info('Rejoining: %s, generations: %s', rejoining, generations) - time.sleep(1) - assert time.time() < timeout, "timeout waiting for assignments" - - logging.info('Group stabilized; verifying assignment') - group_assignment = set() - for c in range(num_consumers): - assert len(consumers[c].assignment()) != 0 - assert set.isdisjoint(consumers[c].assignment(), group_assignment) - group_assignment.update(consumers[c].assignment()) - - assert group_assignment == set([ - TopicPartition(topic, partition) - for partition in range(num_partitions)]) - logging.info('Assignment looks good!') - - finally: - logging.info('Shutting down %s consumers', num_consumers) - for c in range(num_consumers): - logging.info('Stopping consumer %s', c) - stop[c].set() - threads[c].join() - threads[c] = None - - -@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") -def test_paused(kafka_broker, topic): - consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker)) - topics = [TopicPartition(topic, 1)] - consumer.assign(topics) - assert set(topics) == consumer.assignment() - assert set() == consumer.paused() - - consumer.pause(topics[0]) - assert set([topics[0]]) == consumer.paused() - - consumer.resume(topics[0]) - assert set() == consumer.paused() - - consumer.unsubscribe() - assert set() == consumer.paused() - consumer.close() - - -@pytest.mark.skipif(env_kafka_version() < (0, 9), reason='Unsupported Kafka Version') -def test_heartbeat_thread(kafka_broker, topic): - group_id = 'test-group-' + random_string(6) - consumer = KafkaConsumer(topic, - bootstrap_servers=get_connect_str(kafka_broker), - group_id=group_id, - heartbeat_interval_ms=500) - - # poll until we have joined group / have assignment - while not consumer.assignment(): - consumer.poll(timeout_ms=100) - - assert consumer._coordinator.state is MemberState.STABLE - last_poll = consumer._coordinator.heartbeat.last_poll - last_beat = consumer._coordinator.heartbeat.last_send - - timeout = time.time() + 30 - while True: - if time.time() > timeout: - raise RuntimeError('timeout waiting for heartbeat') - if consumer._coordinator.heartbeat.last_send > last_beat: - break - time.sleep(0.5) - - assert consumer._coordinator.heartbeat.last_poll == last_poll - consumer.poll(timeout_ms=100) - assert consumer._coordinator.heartbeat.last_poll > last_poll - consumer.close() diff --git a/tests/kafka/test_fetcher.py b/tests/kafka/test_fetcher.py deleted file mode 100644 index 697f8be1..00000000 --- a/tests/kafka/test_fetcher.py +++ /dev/null @@ -1,553 +0,0 @@ -# pylint: skip-file -from __future__ import absolute_import - -import pytest - -from collections import OrderedDict -import itertools -import time - -from kafka.client_async import KafkaClient -from kafka.consumer.fetcher import ( - CompletedFetch, ConsumerRecord, Fetcher, NoOffsetForPartitionError 
-) -from kafka.consumer.subscription_state import SubscriptionState -from kafka.future import Future -from kafka.metrics import Metrics -from kafka.protocol.fetch import FetchRequest, FetchResponse -from kafka.protocol.offset import OffsetResponse -from kafka.errors import ( - StaleMetadata, LeaderNotAvailableError, NotLeaderForPartitionError, - UnknownTopicOrPartitionError, OffsetOutOfRangeError -) -from kafka.record.memory_records import MemoryRecordsBuilder, MemoryRecords -from kafka.structs import OffsetAndMetadata, TopicPartition - - -@pytest.fixture -def client(mocker): - return mocker.Mock(spec=KafkaClient(bootstrap_servers=(), api_version=(0, 9))) - - -@pytest.fixture -def subscription_state(): - return SubscriptionState() - - -@pytest.fixture -def topic(): - return 'foobar' - - -@pytest.fixture -def fetcher(client, subscription_state, topic): - subscription_state.subscribe(topics=[topic]) - assignment = [TopicPartition(topic, i) for i in range(3)] - subscription_state.assign_from_subscribed(assignment) - for tp in assignment: - subscription_state.seek(tp, 0) - return Fetcher(client, subscription_state, Metrics()) - - -def _build_record_batch(msgs, compression=0): - builder = MemoryRecordsBuilder( - magic=1, compression_type=0, batch_size=9999999) - for msg in msgs: - key, value, timestamp = msg - builder.append(key=key, value=value, timestamp=timestamp, headers=[]) - builder.close() - return builder.buffer() - - -def test_send_fetches(fetcher, topic, mocker): - fetch_requests = [ - FetchRequest[0]( - -1, fetcher.config['fetch_max_wait_ms'], - fetcher.config['fetch_min_bytes'], - [(topic, [ - (0, 0, fetcher.config['max_partition_fetch_bytes']), - (1, 0, fetcher.config['max_partition_fetch_bytes']), - ])]), - FetchRequest[0]( - -1, fetcher.config['fetch_max_wait_ms'], - fetcher.config['fetch_min_bytes'], - [(topic, [ - (2, 0, fetcher.config['max_partition_fetch_bytes']), - ])]) - ] - - mocker.patch.object(fetcher, '_create_fetch_requests', - return_value=dict(enumerate(fetch_requests))) - - ret = fetcher.send_fetches() - for node, request in enumerate(fetch_requests): - fetcher._client.send.assert_any_call(node, request, wakeup=False) - assert len(ret) == len(fetch_requests) - - -@pytest.mark.parametrize(("api_version", "fetch_version"), [ - ((0, 10, 1), 3), - ((0, 10, 0), 2), - ((0, 9), 1), - ((0, 8), 0) -]) -def test_create_fetch_requests(fetcher, mocker, api_version, fetch_version): - fetcher._client.in_flight_request_count.return_value = 0 - fetcher.config['api_version'] = api_version - by_node = fetcher._create_fetch_requests() - requests = by_node.values() - assert all([isinstance(r, FetchRequest[fetch_version]) for r in requests]) - - -def test_update_fetch_positions(fetcher, topic, mocker): - mocker.patch.object(fetcher, '_reset_offset') - partition = TopicPartition(topic, 0) - - # unassigned partition - fetcher.update_fetch_positions([TopicPartition('fizzbuzz', 0)]) - assert fetcher._reset_offset.call_count == 0 - - # fetchable partition (has offset, not paused) - fetcher.update_fetch_positions([partition]) - assert fetcher._reset_offset.call_count == 0 - - # partition needs reset, no committed offset - fetcher._subscriptions.need_offset_reset(partition) - fetcher._subscriptions.assignment[partition].awaiting_reset = False - fetcher.update_fetch_positions([partition]) - fetcher._reset_offset.assert_called_with(partition) - assert fetcher._subscriptions.assignment[partition].awaiting_reset is True - fetcher.update_fetch_positions([partition]) - 
fetcher._reset_offset.assert_called_with(partition) - - # partition needs reset, has committed offset - fetcher._reset_offset.reset_mock() - fetcher._subscriptions.need_offset_reset(partition) - fetcher._subscriptions.assignment[partition].awaiting_reset = False - fetcher._subscriptions.assignment[partition].committed = OffsetAndMetadata(123, b'') - mocker.patch.object(fetcher._subscriptions, 'seek') - fetcher.update_fetch_positions([partition]) - assert fetcher._reset_offset.call_count == 0 - fetcher._subscriptions.seek.assert_called_with(partition, 123) - - -def test__reset_offset(fetcher, mocker): - tp = TopicPartition("topic", 0) - fetcher._subscriptions.subscribe(topics="topic") - fetcher._subscriptions.assign_from_subscribed([tp]) - fetcher._subscriptions.need_offset_reset(tp) - mocked = mocker.patch.object(fetcher, '_retrieve_offsets') - - mocked.return_value = {tp: (1001, None)} - fetcher._reset_offset(tp) - assert not fetcher._subscriptions.assignment[tp].awaiting_reset - assert fetcher._subscriptions.assignment[tp].position == 1001 - - -def test__send_offset_requests(fetcher, mocker): - tp = TopicPartition("topic_send_offset", 1) - mocked_send = mocker.patch.object(fetcher, "_send_offset_request") - send_futures = [] - - def send_side_effect(*args, **kw): - f = Future() - send_futures.append(f) - return f - mocked_send.side_effect = send_side_effect - - mocked_leader = mocker.patch.object( - fetcher._client.cluster, "leader_for_partition") - # First we report unavailable leader 2 times different ways and later - # always as available - mocked_leader.side_effect = itertools.chain( - [None, -1], itertools.cycle([0])) - - # Leader == None - fut = fetcher._send_offset_requests({tp: 0}) - assert fut.failed() - assert isinstance(fut.exception, StaleMetadata) - assert not mocked_send.called - - # Leader == -1 - fut = fetcher._send_offset_requests({tp: 0}) - assert fut.failed() - assert isinstance(fut.exception, LeaderNotAvailableError) - assert not mocked_send.called - - # Leader == 0, send failed - fut = fetcher._send_offset_requests({tp: 0}) - assert not fut.is_done - assert mocked_send.called - # Check that we bound the futures correctly to chain failure - send_futures.pop().failure(NotLeaderForPartitionError(tp)) - assert fut.failed() - assert isinstance(fut.exception, NotLeaderForPartitionError) - - # Leader == 0, send success - fut = fetcher._send_offset_requests({tp: 0}) - assert not fut.is_done - assert mocked_send.called - # Check that we bound the futures correctly to chain success - send_futures.pop().success({tp: (10, 10000)}) - assert fut.succeeded() - assert fut.value == {tp: (10, 10000)} - - -def test__send_offset_requests_multiple_nodes(fetcher, mocker): - tp1 = TopicPartition("topic_send_offset", 1) - tp2 = TopicPartition("topic_send_offset", 2) - tp3 = TopicPartition("topic_send_offset", 3) - tp4 = TopicPartition("topic_send_offset", 4) - mocked_send = mocker.patch.object(fetcher, "_send_offset_request") - send_futures = [] - - def send_side_effect(node_id, timestamps): - f = Future() - send_futures.append((node_id, timestamps, f)) - return f - mocked_send.side_effect = send_side_effect - - mocked_leader = mocker.patch.object( - fetcher._client.cluster, "leader_for_partition") - mocked_leader.side_effect = itertools.cycle([0, 1]) - - # -- All node succeeded case - tss = OrderedDict([(tp1, 0), (tp2, 0), (tp3, 0), (tp4, 0)]) - fut = fetcher._send_offset_requests(tss) - assert not fut.is_done - assert mocked_send.call_count == 2 - - req_by_node = {} - second_future = None 
- for node, timestamps, f in send_futures: - req_by_node[node] = timestamps - if node == 0: - # Say tp3 does not have any messages so it's missing - f.success({tp1: (11, 1001)}) - else: - second_future = f - assert req_by_node == { - 0: {tp1: 0, tp3: 0}, - 1: {tp2: 0, tp4: 0} - } - - # We only resolved 1 future so far, so result future is not yet ready - assert not fut.is_done - second_future.success({tp2: (12, 1002), tp4: (14, 1004)}) - assert fut.succeeded() - assert fut.value == {tp1: (11, 1001), tp2: (12, 1002), tp4: (14, 1004)} - - # -- First succeeded second not - del send_futures[:] - fut = fetcher._send_offset_requests(tss) - assert len(send_futures) == 2 - send_futures[0][2].success({tp1: (11, 1001)}) - send_futures[1][2].failure(UnknownTopicOrPartitionError(tp1)) - assert fut.failed() - assert isinstance(fut.exception, UnknownTopicOrPartitionError) - - # -- First fails second succeeded - del send_futures[:] - fut = fetcher._send_offset_requests(tss) - assert len(send_futures) == 2 - send_futures[0][2].failure(UnknownTopicOrPartitionError(tp1)) - send_futures[1][2].success({tp1: (11, 1001)}) - assert fut.failed() - assert isinstance(fut.exception, UnknownTopicOrPartitionError) - - -def test__handle_offset_response(fetcher, mocker): - # Broker returns UnsupportedForMessageFormatError, will omit partition - fut = Future() - res = OffsetResponse[1]([ - ("topic", [(0, 43, -1, -1)]), - ("topic", [(1, 0, 1000, 9999)]) - ]) - fetcher._handle_offset_response(fut, res) - assert fut.succeeded() - assert fut.value == {TopicPartition("topic", 1): (9999, 1000)} - - # Broker returns NotLeaderForPartitionError - fut = Future() - res = OffsetResponse[1]([ - ("topic", [(0, 6, -1, -1)]), - ]) - fetcher._handle_offset_response(fut, res) - assert fut.failed() - assert isinstance(fut.exception, NotLeaderForPartitionError) - - # Broker returns UnknownTopicOrPartitionError - fut = Future() - res = OffsetResponse[1]([ - ("topic", [(0, 3, -1, -1)]), - ]) - fetcher._handle_offset_response(fut, res) - assert fut.failed() - assert isinstance(fut.exception, UnknownTopicOrPartitionError) - - # Broker returns many errors and 1 result - # Will fail on 1st error and return - fut = Future() - res = OffsetResponse[1]([ - ("topic", [(0, 43, -1, -1)]), - ("topic", [(1, 6, -1, -1)]), - ("topic", [(2, 3, -1, -1)]), - ("topic", [(3, 0, 1000, 9999)]) - ]) - fetcher._handle_offset_response(fut, res) - assert fut.failed() - assert isinstance(fut.exception, NotLeaderForPartitionError) - - -def test_fetched_records(fetcher, topic, mocker): - fetcher.config['check_crcs'] = False - tp = TopicPartition(topic, 0) - - msgs = [] - for i in range(10): - msgs.append((None, b"foo", None)) - completed_fetch = CompletedFetch( - tp, 0, 0, [0, 100, _build_record_batch(msgs)], - mocker.MagicMock() - ) - fetcher._completed_fetches.append(completed_fetch) - records, partial = fetcher.fetched_records() - assert tp in records - assert len(records[tp]) == len(msgs) - assert all(map(lambda x: isinstance(x, ConsumerRecord), records[tp])) - assert partial is False - - -@pytest.mark.parametrize(("fetch_request", "fetch_response", "num_partitions"), [ - ( - FetchRequest[0]( - -1, 100, 100, - [('foo', [(0, 0, 1000),])]), - FetchResponse[0]( - [("foo", [(0, 0, 1000, [(0, b'xxx'),])]),]), - 1, - ), - ( - FetchRequest[1]( - -1, 100, 100, - [('foo', [(0, 0, 1000), (1, 0, 1000),])]), - FetchResponse[1]( - 0, - [("foo", [ - (0, 0, 1000, [(0, b'xxx'),]), - (1, 0, 1000, [(0, b'xxx'),]), - ]),]), - 2, - ), - ( - FetchRequest[2]( - -1, 100, 100, - [('foo', 
[(0, 0, 1000),])]), - FetchResponse[2]( - 0, [("foo", [(0, 0, 1000, [(0, b'xxx'),])]),]), - 1, - ), - ( - FetchRequest[3]( - -1, 100, 100, 10000, - [('foo', [(0, 0, 1000),])]), - FetchResponse[3]( - 0, [("foo", [(0, 0, 1000, [(0, b'xxx'),])]),]), - 1, - ), - ( - FetchRequest[4]( - -1, 100, 100, 10000, 0, - [('foo', [(0, 0, 1000),])]), - FetchResponse[4]( - 0, [("foo", [(0, 0, 1000, 0, [], [(0, b'xxx'),])]),]), - 1, - ), - ( - # This may only be used in broker-broker api calls - FetchRequest[5]( - -1, 100, 100, 10000, 0, - [('foo', [(0, 0, 1000),])]), - FetchResponse[5]( - 0, [("foo", [(0, 0, 1000, 0, 0, [], [(0, b'xxx'),])]),]), - 1, - ), -]) -def test__handle_fetch_response(fetcher, fetch_request, fetch_response, num_partitions): - fetcher._handle_fetch_response(fetch_request, time.time(), fetch_response) - assert len(fetcher._completed_fetches) == num_partitions - - -def test__unpack_message_set(fetcher): - fetcher.config['check_crcs'] = False - tp = TopicPartition('foo', 0) - messages = [ - (None, b"a", None), - (None, b"b", None), - (None, b"c", None), - ] - memory_records = MemoryRecords(_build_record_batch(messages)) - records = list(fetcher._unpack_message_set(tp, memory_records)) - assert len(records) == 3 - assert all(map(lambda x: isinstance(x, ConsumerRecord), records)) - assert records[0].value == b'a' - assert records[1].value == b'b' - assert records[2].value == b'c' - assert records[0].offset == 0 - assert records[1].offset == 1 - assert records[2].offset == 2 - - -def test__message_generator(fetcher, topic, mocker): - fetcher.config['check_crcs'] = False - tp = TopicPartition(topic, 0) - msgs = [] - for i in range(10): - msgs.append((None, b"foo", None)) - completed_fetch = CompletedFetch( - tp, 0, 0, [0, 100, _build_record_batch(msgs)], - mocker.MagicMock() - ) - fetcher._completed_fetches.append(completed_fetch) - for i in range(10): - msg = next(fetcher) - assert isinstance(msg, ConsumerRecord) - assert msg.offset == i - assert msg.value == b'foo' - - -def test__parse_fetched_data(fetcher, topic, mocker): - fetcher.config['check_crcs'] = False - tp = TopicPartition(topic, 0) - msgs = [] - for i in range(10): - msgs.append((None, b"foo", None)) - completed_fetch = CompletedFetch( - tp, 0, 0, [0, 100, _build_record_batch(msgs)], - mocker.MagicMock() - ) - partition_record = fetcher._parse_fetched_data(completed_fetch) - assert isinstance(partition_record, fetcher.PartitionRecords) - assert len(partition_record) == 10 - - -def test__parse_fetched_data__paused(fetcher, topic, mocker): - fetcher.config['check_crcs'] = False - tp = TopicPartition(topic, 0) - msgs = [] - for i in range(10): - msgs.append((None, b"foo", None)) - completed_fetch = CompletedFetch( - tp, 0, 0, [0, 100, _build_record_batch(msgs)], - mocker.MagicMock() - ) - fetcher._subscriptions.pause(tp) - partition_record = fetcher._parse_fetched_data(completed_fetch) - assert partition_record is None - - -def test__parse_fetched_data__stale_offset(fetcher, topic, mocker): - fetcher.config['check_crcs'] = False - tp = TopicPartition(topic, 0) - msgs = [] - for i in range(10): - msgs.append((None, b"foo", None)) - completed_fetch = CompletedFetch( - tp, 10, 0, [0, 100, _build_record_batch(msgs)], - mocker.MagicMock() - ) - partition_record = fetcher._parse_fetched_data(completed_fetch) - assert partition_record is None - - -def test__parse_fetched_data__not_leader(fetcher, topic, mocker): - fetcher.config['check_crcs'] = False - tp = TopicPartition(topic, 0) - completed_fetch = CompletedFetch( - tp, 0, 0, 
[NotLeaderForPartitionError.errno, -1, None], - mocker.MagicMock() - ) - partition_record = fetcher._parse_fetched_data(completed_fetch) - assert partition_record is None - fetcher._client.cluster.request_update.assert_called_with() - - -def test__parse_fetched_data__unknown_tp(fetcher, topic, mocker): - fetcher.config['check_crcs'] = False - tp = TopicPartition(topic, 0) - completed_fetch = CompletedFetch( - tp, 0, 0, [UnknownTopicOrPartitionError.errno, -1, None], - mocker.MagicMock() - ) - partition_record = fetcher._parse_fetched_data(completed_fetch) - assert partition_record is None - fetcher._client.cluster.request_update.assert_called_with() - - -def test__parse_fetched_data__out_of_range(fetcher, topic, mocker): - fetcher.config['check_crcs'] = False - tp = TopicPartition(topic, 0) - completed_fetch = CompletedFetch( - tp, 0, 0, [OffsetOutOfRangeError.errno, -1, None], - mocker.MagicMock() - ) - partition_record = fetcher._parse_fetched_data(completed_fetch) - assert partition_record is None - assert fetcher._subscriptions.assignment[tp].awaiting_reset is True - - -def test_partition_records_offset(): - """Test that compressed messagesets are handle correctly - when fetch offset is in the middle of the message list - """ - batch_start = 120 - batch_end = 130 - fetch_offset = 123 - tp = TopicPartition('foo', 0) - messages = [ConsumerRecord(tp.topic, tp.partition, i, - None, None, 'key', 'value', [], 'checksum', 0, 0, -1) - for i in range(batch_start, batch_end)] - records = Fetcher.PartitionRecords(fetch_offset, None, messages) - assert len(records) > 0 - msgs = records.take(1) - assert msgs[0].offset == fetch_offset - assert records.fetch_offset == fetch_offset + 1 - msgs = records.take(2) - assert len(msgs) == 2 - assert len(records) > 0 - records.discard() - assert len(records) == 0 - - -def test_partition_records_empty(): - records = Fetcher.PartitionRecords(0, None, []) - assert len(records) == 0 - - -def test_partition_records_no_fetch_offset(): - batch_start = 0 - batch_end = 100 - fetch_offset = 123 - tp = TopicPartition('foo', 0) - messages = [ConsumerRecord(tp.topic, tp.partition, i, - None, None, 'key', 'value', None, 'checksum', 0, 0, -1) - for i in range(batch_start, batch_end)] - records = Fetcher.PartitionRecords(fetch_offset, None, messages) - assert len(records) == 0 - - -def test_partition_records_compacted_offset(): - """Test that messagesets are handle correctly - when the fetch offset points to a message that has been compacted - """ - batch_start = 0 - batch_end = 100 - fetch_offset = 42 - tp = TopicPartition('foo', 0) - messages = [ConsumerRecord(tp.topic, tp.partition, i, - None, None, 'key', 'value', None, 'checksum', 0, 0, -1) - for i in range(batch_start, batch_end) if i != fetch_offset] - records = Fetcher.PartitionRecords(fetch_offset, None, messages) - assert len(records) == batch_end - fetch_offset - 1 - msgs = records.take(1) - assert msgs[0].offset == fetch_offset + 1 diff --git a/tests/kafka/test_package.py b/tests/kafka/test_package.py deleted file mode 100644 index aa42c9ce..00000000 --- a/tests/kafka/test_package.py +++ /dev/null @@ -1,25 +0,0 @@ -class TestPackage: - def test_top_level_namespace(self): - import kafka as kafka1 - assert kafka1.KafkaConsumer.__name__ == "KafkaConsumer" - assert kafka1.consumer.__name__ == "kafka.consumer" - assert kafka1.codec.__name__ == "kafka.codec" - - def test_submodule_namespace(self): - import kafka.client_async as client1 - assert client1.__name__ == "kafka.client_async" - - from kafka import client_async 
as client2 - assert client2.__name__ == "kafka.client_async" - - from kafka.client_async import KafkaClient as KafkaClient1 - assert KafkaClient1.__name__ == "KafkaClient" - - from kafka import KafkaClient as KafkaClient2 - assert KafkaClient2.__name__ == "KafkaClient" - - from kafka.codec import gzip_encode as gzip_encode1 - assert gzip_encode1.__name__ == "gzip_encode" - - from kafka.codec import snappy_encode - assert snappy_encode.__name__ == "snappy_encode" diff --git a/tests/kafka/test_producer.py b/tests/kafka/test_producer.py deleted file mode 100644 index 97099e93..00000000 --- a/tests/kafka/test_producer.py +++ /dev/null @@ -1,137 +0,0 @@ -import gc -import platform -import time -import threading - -import pytest - -from kafka import KafkaConsumer, KafkaProducer, TopicPartition -from kafka.producer.buffer import SimpleBufferPool -from tests.kafka.testutil import env_kafka_version, random_string - - -def test_buffer_pool(): - pool = SimpleBufferPool(1000, 1000) - - buf1 = pool.allocate(1000, 1000) - message = ''.join(map(str, range(100))) - buf1.write(message.encode('utf-8')) - pool.deallocate(buf1) - - buf2 = pool.allocate(1000, 1000) - assert buf2.read() == b'' - - -@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") -@pytest.mark.parametrize("compression", [None, 'gzip', 'snappy', 'lz4', 'zstd']) -def test_end_to_end(kafka_broker, compression): - if compression == 'lz4': - if env_kafka_version() < (0, 8, 2): - pytest.skip('LZ4 requires 0.8.2') - elif platform.python_implementation() == 'PyPy': - pytest.skip('python-lz4 crashes on older versions of pypy') - - if compression == 'zstd' and env_kafka_version() < (2, 1, 0): - pytest.skip('zstd requires kafka 2.1.0 or newer') - - connect_str = ':'.join([kafka_broker.host, str(kafka_broker.port)]) - producer = KafkaProducer(bootstrap_servers=connect_str, - retries=5, - max_block_ms=30000, - compression_type=compression, - value_serializer=str.encode) - consumer = KafkaConsumer(bootstrap_servers=connect_str, - group_id=None, - consumer_timeout_ms=30000, - auto_offset_reset='earliest', - value_deserializer=bytes.decode) - - topic = random_string(5) - - messages = 100 - futures = [] - for i in range(messages): - futures.append(producer.send(topic, 'msg %d' % i)) - ret = [f.get(timeout=30) for f in futures] - assert len(ret) == messages - producer.close() - - consumer.subscribe([topic]) - msgs = set() - for i in range(messages): - try: - msgs.add(next(consumer).value) - except StopIteration: - break - - assert msgs == set(['msg %d' % (i,) for i in range(messages)]) - consumer.close() - - -@pytest.mark.skipif(platform.python_implementation() != 'CPython', - reason='Test relies on CPython-specific gc policies') -def test_kafka_producer_gc_cleanup(): - gc.collect() - threads = threading.active_count() - producer = KafkaProducer(api_version='0.9') # set api_version explicitly to avoid auto-detection - assert threading.active_count() == threads + 1 - del(producer) - gc.collect() - assert threading.active_count() == threads - - -@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") -@pytest.mark.parametrize("compression", [None, 'gzip', 'snappy', 'lz4', 'zstd']) -def test_kafka_producer_proper_record_metadata(kafka_broker, compression): - if compression == 'zstd' and env_kafka_version() < (2, 1, 0): - pytest.skip('zstd requires 2.1.0 or more') - connect_str = ':'.join([kafka_broker.host, str(kafka_broker.port)]) - producer = KafkaProducer(bootstrap_servers=connect_str, - retries=5, - 
max_block_ms=30000, - compression_type=compression) - magic = producer._max_usable_produce_magic() - - # record headers are supported in 0.11.0 - if env_kafka_version() < (0, 11, 0): - headers = None - else: - headers = [("Header Key", b"Header Value")] - - topic = random_string(5) - future = producer.send( - topic, - value=b"Simple value", key=b"Simple key", headers=headers, timestamp_ms=9999999, - partition=0) - record = future.get(timeout=5) - assert record is not None - assert record.topic == topic - assert record.partition == 0 - assert record.topic_partition == TopicPartition(topic, 0) - assert record.offset == 0 - if magic >= 1: - assert record.timestamp == 9999999 - else: - assert record.timestamp == -1 # NO_TIMESTAMP - - if magic >= 2: - assert record.checksum is None - elif magic == 1: - assert record.checksum == 1370034956 - else: - assert record.checksum == 3296137851 - - assert record.serialized_key_size == 10 - assert record.serialized_value_size == 12 - if headers: - assert record.serialized_header_size == 22 - - if magic == 0: - pytest.skip('generated timestamp case is skipped for broker 0.9 and below') - send_time = time.time() * 1000 - future = producer.send( - topic, - value=b"Simple value", key=b"Simple key", timestamp_ms=None, - partition=0) - record = future.get(timeout=5) - assert abs(record.timestamp - send_time) <= 1000 # Allow 1s deviation diff --git a/tests/kafka/test_sender.py b/tests/kafka/test_sender.py deleted file mode 100644 index 2a68defc..00000000 --- a/tests/kafka/test_sender.py +++ /dev/null @@ -1,53 +0,0 @@ -# pylint: skip-file -from __future__ import absolute_import - -import pytest -import io - -from kafka.client_async import KafkaClient -from kafka.cluster import ClusterMetadata -from kafka.metrics import Metrics -from kafka.protocol.produce import ProduceRequest -from kafka.producer.record_accumulator import RecordAccumulator, ProducerBatch -from kafka.producer.sender import Sender -from kafka.record.memory_records import MemoryRecordsBuilder -from kafka.structs import TopicPartition - - -@pytest.fixture -def client(mocker): - _cli = mocker.Mock(spec=KafkaClient(bootstrap_servers=(), api_version=(0, 9))) - _cli.cluster = mocker.Mock(spec=ClusterMetadata()) - return _cli - - -@pytest.fixture -def accumulator(): - return RecordAccumulator() - - -@pytest.fixture -def metrics(): - return Metrics() - - -@pytest.fixture -def sender(client, accumulator, metrics): - return Sender(client, client.cluster, accumulator, metrics) - - -@pytest.mark.parametrize(("api_version", "produce_version"), [ - ((0, 10), 2), - ((0, 9), 1), - ((0, 8), 0) -]) -def test_produce_request(sender, mocker, api_version, produce_version): - sender.config['api_version'] = api_version - tp = TopicPartition('foo', 0) - buffer = io.BytesIO() - records = MemoryRecordsBuilder( - magic=1, compression_type=0, batch_size=100000) - batch = ProducerBatch(tp, records, buffer) - records.close() - produce_request = sender._produce_request(0, 0, 0, [batch]) - assert isinstance(produce_request, ProduceRequest[produce_version]) From 29526ce5212de12240ba5d26f8d399986b50c38f Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Sun, 22 Oct 2023 16:03:37 +0300 Subject: [PATCH 07/20] Untie ConsumerRebalanceListener --- aiokafka/abc.py | 3 +- kafka/__init__.py | 1 - kafka/consumer/__init__.py | 0 kafka/consumer/subscription_state.py | 501 ------------------ setup.cfg | 1 + tests/kafka/conftest.py | 8 - tests/kafka/test_consumer_integration.py | 297 ----------- tests/kafka/test_coordinator.py | 638 
----------------------- tests/kafka/test_subscription_state.py | 25 - 9 files changed, 2 insertions(+), 1472 deletions(-) delete mode 100644 kafka/consumer/__init__.py delete mode 100644 kafka/consumer/subscription_state.py delete mode 100644 tests/kafka/test_consumer_integration.py delete mode 100644 tests/kafka/test_coordinator.py delete mode 100644 tests/kafka/test_subscription_state.py diff --git a/aiokafka/abc.py b/aiokafka/abc.py index 6d1815a6..d51b986e 100644 --- a/aiokafka/abc.py +++ b/aiokafka/abc.py @@ -1,8 +1,7 @@ import abc -from kafka import ConsumerRebalanceListener as BaseConsumerRebalanceListener -class ConsumerRebalanceListener(BaseConsumerRebalanceListener): +class ConsumerRebalanceListener(abc.ABC): """ A callback interface that the user can implement to trigger custom actions when the set of partitions assigned to the consumer changes. diff --git a/kafka/__init__.py b/kafka/__init__.py index c4308c5e..2a335d23 100644 --- a/kafka/__init__.py +++ b/kafka/__init__.py @@ -20,7 +20,6 @@ def emit(self, record): from kafka.admin import KafkaAdminClient from kafka.client_async import KafkaClient -from kafka.consumer.subscription_state import ConsumerRebalanceListener from kafka.conn import BrokerConnection from kafka.serializer import Serializer, Deserializer from kafka.structs import TopicPartition, OffsetAndMetadata diff --git a/kafka/consumer/__init__.py b/kafka/consumer/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/kafka/consumer/subscription_state.py b/kafka/consumer/subscription_state.py deleted file mode 100644 index 08842d13..00000000 --- a/kafka/consumer/subscription_state.py +++ /dev/null @@ -1,501 +0,0 @@ -from __future__ import absolute_import - -import abc -import logging -import re - -from kafka.vendor import six - -from kafka.errors import IllegalStateError -from kafka.protocol.offset import OffsetResetStrategy -from kafka.structs import OffsetAndMetadata - -log = logging.getLogger(__name__) - - -class SubscriptionState(object): - """ - A class for tracking the topics, partitions, and offsets for the consumer. - A partition is "assigned" either directly with assign_from_user() (manual - assignment) or with assign_from_subscribed() (automatic assignment from - subscription). - - Once assigned, the partition is not considered "fetchable" until its initial - position has been set with seek(). Fetchable partitions track a fetch - position which is used to set the offset of the next fetch, and a consumed - position which is the last offset that has been returned to the user. You - can suspend fetching from a partition through pause() without affecting the - fetched/consumed offsets. The partition will remain unfetchable until the - resume() is used. You can also query the pause state independently with - is_paused(). - - Note that pause state as well as fetch/consumed positions are not preserved - when partition assignment is changed whether directly by the user or - through a group rebalance. - - This class also maintains a cache of the latest commit position for each of - the assigned partitions. This is updated through committed() and can be used - to set the initial fetch position (e.g. Fetcher._reset_offset() ). 
- """ - _SUBSCRIPTION_EXCEPTION_MESSAGE = ( - "You must choose only one way to configure your consumer:" - " (1) subscribe to specific topics by name," - " (2) subscribe to topics matching a regex pattern," - " (3) assign itself specific topic-partitions.") - - # Taken from: https://github.com/apache/kafka/blob/39eb31feaeebfb184d98cc5d94da9148c2319d81/clients/src/main/java/org/apache/kafka/common/internals/Topic.java#L29 - _MAX_NAME_LENGTH = 249 - _TOPIC_LEGAL_CHARS = re.compile('^[a-zA-Z0-9._-]+$') - - def __init__(self, offset_reset_strategy='earliest'): - """Initialize a SubscriptionState instance - - Keyword Arguments: - offset_reset_strategy: 'earliest' or 'latest', otherwise - exception will be raised when fetching an offset that is no - longer available. Default: 'earliest' - """ - try: - offset_reset_strategy = getattr(OffsetResetStrategy, - offset_reset_strategy.upper()) - except AttributeError: - log.warning('Unrecognized offset_reset_strategy, using NONE') - offset_reset_strategy = OffsetResetStrategy.NONE - self._default_offset_reset_strategy = offset_reset_strategy - - self.subscription = None # set() or None - self.subscribed_pattern = None # regex str or None - self._group_subscription = set() - self._user_assignment = set() - self.assignment = dict() - self.listener = None - - # initialize to true for the consumers to fetch offset upon starting up - self.needs_fetch_committed_offsets = True - - def subscribe(self, topics=(), pattern=None, listener=None): - """Subscribe to a list of topics, or a topic regex pattern. - - Partitions will be dynamically assigned via a group coordinator. - Topic subscriptions are not incremental: this list will replace the - current assignment (if there is one). - - This method is incompatible with assign_from_user() - - Arguments: - topics (list): List of topics for subscription. - pattern (str): Pattern to match available topics. You must provide - either topics or pattern, but not both. - listener (ConsumerRebalanceListener): Optionally include listener - callback, which will be called before and after each rebalance - operation. - - As part of group management, the consumer will keep track of the - list of consumers that belong to a particular group and will - trigger a rebalance operation if one of the following events - trigger: - - * Number of partitions change for any of the subscribed topics - * Topic is created or deleted - * An existing member of the consumer group dies - * A new member is added to the consumer group - - When any of these events are triggered, the provided listener - will be invoked first to indicate that the consumer's assignment - has been revoked, and then again when the new assignment has - been received. Note that this listener will immediately override - any listener set in a previous call to subscribe. It is - guaranteed, however, that the partitions revoked/assigned - through this interface are from topics subscribed in this call. 
- """ - if self._user_assignment or (topics and pattern): - raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) - assert topics or pattern, 'Must provide topics or pattern' - - if pattern: - log.info('Subscribing to pattern: /%s/', pattern) - self.subscription = set() - self.subscribed_pattern = re.compile(pattern) - else: - self.change_subscription(topics) - - if listener and not isinstance(listener, ConsumerRebalanceListener): - raise TypeError('listener must be a ConsumerRebalanceListener') - self.listener = listener - - def _ensure_valid_topic_name(self, topic): - """ Ensures that the topic name is valid according to the kafka source. """ - - # See Kafka Source: - # https://github.com/apache/kafka/blob/39eb31feaeebfb184d98cc5d94da9148c2319d81/clients/src/main/java/org/apache/kafka/common/internals/Topic.java - if topic is None: - raise TypeError('All topics must not be None') - if not isinstance(topic, six.string_types): - raise TypeError('All topics must be strings') - if len(topic) == 0: - raise ValueError('All topics must be non-empty strings') - if topic == '.' or topic == '..': - raise ValueError('Topic name cannot be "." or ".."') - if len(topic) > self._MAX_NAME_LENGTH: - raise ValueError('Topic name is illegal, it can\'t be longer than {0} characters, topic: "{1}"'.format(self._MAX_NAME_LENGTH, topic)) - if not self._TOPIC_LEGAL_CHARS.match(topic): - raise ValueError('Topic name "{0}" is illegal, it contains a character other than ASCII alphanumerics, ".", "_" and "-"'.format(topic)) - - def change_subscription(self, topics): - """Change the topic subscription. - - Arguments: - topics (list of str): topics for subscription - - Raises: - IllegalStateError: if assign_from_user has been used already - TypeError: if a topic is None or a non-str - ValueError: if a topic is an empty string or - - a topic name is '.' or '..' or - - a topic name does not consist of ASCII-characters/'-'/'_'/'.' - """ - if self._user_assignment: - raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) - - if isinstance(topics, six.string_types): - topics = [topics] - - if self.subscription == set(topics): - log.warning("subscription unchanged by change_subscription(%s)", - topics) - return - - for t in topics: - self._ensure_valid_topic_name(t) - - log.info('Updating subscribed topics to: %s', topics) - self.subscription = set(topics) - self._group_subscription.update(topics) - - # Remove any assigned partitions which are no longer subscribed to - for tp in set(self.assignment.keys()): - if tp.topic not in self.subscription: - del self.assignment[tp] - - def group_subscribe(self, topics): - """Add topics to the current group subscription. - - This is used by the group leader to ensure that it receives metadata - updates for all topics that any member of the group is subscribed to. - - Arguments: - topics (list of str): topics to add to the group subscription - """ - if self._user_assignment: - raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) - self._group_subscription.update(topics) - - def reset_group_subscription(self): - """Reset the group's subscription to only contain topics subscribed by this consumer.""" - if self._user_assignment: - raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) - assert self.subscription is not None, 'Subscription required' - self._group_subscription.intersection_update(self.subscription) - - def assign_from_user(self, partitions): - """Manually assign a list of TopicPartitions to this consumer. 
- - This interface does not allow for incremental assignment and will - replace the previous assignment (if there was one). - - Manual topic assignment through this method does not use the consumer's - group management functionality. As such, there will be no rebalance - operation triggered when group membership or cluster and topic metadata - change. Note that it is not possible to use both manual partition - assignment with assign() and group assignment with subscribe(). - - Arguments: - partitions (list of TopicPartition): assignment for this instance. - - Raises: - IllegalStateError: if consumer has already called subscribe() - """ - if self.subscription is not None: - raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) - - if self._user_assignment != set(partitions): - self._user_assignment = set(partitions) - - for partition in partitions: - if partition not in self.assignment: - self._add_assigned_partition(partition) - - for tp in set(self.assignment.keys()) - self._user_assignment: - del self.assignment[tp] - - self.needs_fetch_committed_offsets = True - - def assign_from_subscribed(self, assignments): - """Update the assignment to the specified partitions - - This method is called by the coordinator to dynamically assign - partitions based on the consumer's topic subscription. This is different - from assign_from_user() which directly sets the assignment from a - user-supplied TopicPartition list. - - Arguments: - assignments (list of TopicPartition): partitions to assign to this - consumer instance. - """ - if not self.partitions_auto_assigned(): - raise IllegalStateError(self._SUBSCRIPTION_EXCEPTION_MESSAGE) - - for tp in assignments: - if tp.topic not in self.subscription: - raise ValueError("Assigned partition %s for non-subscribed topic." % (tp,)) - - # after rebalancing, we always reinitialize the assignment state - self.assignment.clear() - for tp in assignments: - self._add_assigned_partition(tp) - self.needs_fetch_committed_offsets = True - log.info("Updated partition assignment: %s", assignments) - - def unsubscribe(self): - """Clear all topic subscriptions and partition assignments""" - self.subscription = None - self._user_assignment.clear() - self.assignment.clear() - self.subscribed_pattern = None - - def group_subscription(self): - """Get the topic subscription for the group. - - For the leader, this will include the union of all member subscriptions. - For followers, it is the member's subscription only. - - This is used when querying topic metadata to detect metadata changes - that would require rebalancing (the leader fetches metadata for all - topics in the group so that it can do partition assignment). - - Returns: - set: topics - """ - return self._group_subscription - - def seek(self, partition, offset): - """Manually specify the fetch offset for a TopicPartition. - - Overrides the fetch offsets that the consumer will use on the next - poll(). If this API is invoked for the same partition more than once, - the latest offset will be used on the next poll(). Note that you may - lose data if this API is arbitrarily used in the middle of consumption, - to reset the fetch offsets. 
- - Arguments: - partition (TopicPartition): partition for seek operation - offset (int): message offset in partition - """ - self.assignment[partition].seek(offset) - - def assigned_partitions(self): - """Return set of TopicPartitions in current assignment.""" - return set(self.assignment.keys()) - - def paused_partitions(self): - """Return current set of paused TopicPartitions.""" - return set(partition for partition in self.assignment - if self.is_paused(partition)) - - def fetchable_partitions(self): - """Return set of TopicPartitions that should be Fetched.""" - fetchable = set() - for partition, state in six.iteritems(self.assignment): - if state.is_fetchable(): - fetchable.add(partition) - return fetchable - - def partitions_auto_assigned(self): - """Return True unless user supplied partitions manually.""" - return self.subscription is not None - - def all_consumed_offsets(self): - """Returns consumed offsets as {TopicPartition: OffsetAndMetadata}""" - all_consumed = {} - for partition, state in six.iteritems(self.assignment): - if state.has_valid_position: - all_consumed[partition] = OffsetAndMetadata(state.position, '') - return all_consumed - - def need_offset_reset(self, partition, offset_reset_strategy=None): - """Mark partition for offset reset using specified or default strategy. - - Arguments: - partition (TopicPartition): partition to mark - offset_reset_strategy (OffsetResetStrategy, optional) - """ - if offset_reset_strategy is None: - offset_reset_strategy = self._default_offset_reset_strategy - self.assignment[partition].await_reset(offset_reset_strategy) - - def has_default_offset_reset_policy(self): - """Return True if default offset reset policy is Earliest or Latest""" - return self._default_offset_reset_strategy != OffsetResetStrategy.NONE - - def is_offset_reset_needed(self, partition): - return self.assignment[partition].awaiting_reset - - def has_all_fetch_positions(self): - for state in self.assignment.values(): - if not state.has_valid_position: - return False - return True - - def missing_fetch_positions(self): - missing = set() - for partition, state in six.iteritems(self.assignment): - if not state.has_valid_position: - missing.add(partition) - return missing - - def is_assigned(self, partition): - return partition in self.assignment - - def is_paused(self, partition): - return partition in self.assignment and self.assignment[partition].paused - - def is_fetchable(self, partition): - return partition in self.assignment and self.assignment[partition].is_fetchable() - - def pause(self, partition): - self.assignment[partition].pause() - - def resume(self, partition): - self.assignment[partition].resume() - - def _add_assigned_partition(self, partition): - self.assignment[partition] = TopicPartitionState() - - -class TopicPartitionState(object): - def __init__(self): - self.committed = None # last committed OffsetAndMetadata - self.has_valid_position = False # whether we have valid position - self.paused = False # whether this partition has been paused by the user - self.awaiting_reset = False # whether we are awaiting reset - self.reset_strategy = None # the reset strategy if awaitingReset is set - self._position = None # offset exposed to the user - self.highwater = None - self.drop_pending_message_set = False - # The last message offset hint available from a message batch with - # magic=2 which includes deleted compacted messages - self.last_offset_from_message_batch = None - - def _set_position(self, offset): - assert self.has_valid_position, 'Valid 
position required' - self._position = offset - - def _get_position(self): - return self._position - - position = property(_get_position, _set_position, None, "last position") - - def await_reset(self, strategy): - self.awaiting_reset = True - self.reset_strategy = strategy - self._position = None - self.last_offset_from_message_batch = None - self.has_valid_position = False - - def seek(self, offset): - self._position = offset - self.awaiting_reset = False - self.reset_strategy = None - self.has_valid_position = True - self.drop_pending_message_set = True - self.last_offset_from_message_batch = None - - def pause(self): - self.paused = True - - def resume(self): - self.paused = False - - def is_fetchable(self): - return not self.paused and self.has_valid_position - - -class ConsumerRebalanceListener(object): - """ - A callback interface that the user can implement to trigger custom actions - when the set of partitions assigned to the consumer changes. - - This is applicable when the consumer is having Kafka auto-manage group - membership. If the consumer's directly assign partitions, those - partitions will never be reassigned and this callback is not applicable. - - When Kafka is managing the group membership, a partition re-assignment will - be triggered any time the members of the group changes or the subscription - of the members changes. This can occur when processes die, new process - instances are added or old instances come back to life after failure. - Rebalances can also be triggered by changes affecting the subscribed - topics (e.g. when then number of partitions is administratively adjusted). - - There are many uses for this functionality. One common use is saving offsets - in a custom store. By saving offsets in the on_partitions_revoked(), call we - can ensure that any time partition assignment changes the offset gets saved. - - Another use is flushing out any kind of cache of intermediate results the - consumer may be keeping. For example, consider a case where the consumer is - subscribed to a topic containing user page views, and the goal is to count - the number of page views per users for each five minute window. Let's say - the topic is partitioned by the user id so that all events for a particular - user will go to a single consumer instance. The consumer can keep in memory - a running tally of actions per user and only flush these out to a remote - data store when its cache gets too big. However if a partition is reassigned - it may want to automatically trigger a flush of this cache, before the new - owner takes over consumption. - - This callback will execute in the user thread as part of the Consumer.poll() - whenever partition assignment changes. - - It is guaranteed that all consumer processes will invoke - on_partitions_revoked() prior to any process invoking - on_partitions_assigned(). So if offsets or other state is saved in the - on_partitions_revoked() call, it should be saved by the time the process - taking over that partition has their on_partitions_assigned() callback - called to load the state. - """ - __metaclass__ = abc.ABCMeta - - @abc.abstractmethod - def on_partitions_revoked(self, revoked): - """ - A callback method the user can implement to provide handling of offset - commits to a customized store on the start of a rebalance operation. - This method will be called before a rebalance operation starts and - after the consumer stops fetching data. 
It is recommended that offsets - should be committed in this callback to either Kafka or a custom offset - store to prevent duplicate data. - - NOTE: This method is only called before rebalances. It is not called - prior to KafkaConsumer.close() - - Arguments: - revoked (list of TopicPartition): the partitions that were assigned - to the consumer on the last rebalance - """ - pass - - @abc.abstractmethod - def on_partitions_assigned(self, assigned): - """ - A callback method the user can implement to provide handling of - customized offsets on completion of a successful partition - re-assignment. This method will be called after an offset re-assignment - completes and before the consumer starts fetching data. - - It is guaranteed that all the processes in a consumer group will execute - their on_partitions_revoked() callback before any instance executes its - on_partitions_assigned() callback. - - Arguments: - assigned (list of TopicPartition): the partitions assigned to the - consumer (may include partitions that were previously assigned) - """ - pass diff --git a/setup.cfg b/setup.cfg index 6b23dfa0..ff457c50 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,6 +5,7 @@ exclude = venv __pycache__ .tox + tests/kafka/ [isort] line_length=88 diff --git a/tests/kafka/conftest.py b/tests/kafka/conftest.py index 0bbd1a2b..2fd11b40 100644 --- a/tests/kafka/conftest.py +++ b/tests/kafka/conftest.py @@ -42,14 +42,6 @@ def factory(**broker_params): broker.close() -@pytest.fixture -def kafka_client(kafka_broker, request): - """Return a KafkaClient fixture""" - (client,) = kafka_broker.get_clients(cnt=1, client_id='%s_client' % (request.node.name,)) - yield client - client.close() - - @pytest.fixture def kafka_consumer(kafka_consumer_factory): """Return a KafkaConsumer fixture""" diff --git a/tests/kafka/test_consumer_integration.py b/tests/kafka/test_consumer_integration.py deleted file mode 100644 index a2644bae..00000000 --- a/tests/kafka/test_consumer_integration.py +++ /dev/null @@ -1,297 +0,0 @@ -import logging -import time - -from unittest.mock import patch -import pytest -from kafka.vendor.six.moves import range - -import kafka.codec -from kafka.errors import UnsupportedCodecError, UnsupportedVersionError -from kafka.structs import TopicPartition, OffsetAndTimestamp - -from tests.kafka.testutil import Timer, assert_message_count, env_kafka_version, random_string - - -@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") -def test_kafka_version_infer(kafka_consumer_factory): - consumer = kafka_consumer_factory() - actual_ver_major_minor = env_kafka_version()[:2] - client = consumer._client - conn = list(client._conns.values())[0] - inferred_ver_major_minor = conn.check_version()[:2] - assert actual_ver_major_minor == inferred_ver_major_minor, \ - "Was expecting inferred broker version to be %s but was %s" % (actual_ver_major_minor, inferred_ver_major_minor) - - -@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") -def test_kafka_consumer(kafka_consumer_factory, send_messages): - """Test KafkaConsumer""" - consumer = kafka_consumer_factory(auto_offset_reset='earliest') - send_messages(range(0, 100), partition=0) - send_messages(range(0, 100), partition=1) - cnt = 0 - messages = {0: [], 1: []} - for message in consumer: - logging.debug("Consumed message %s", repr(message)) - cnt += 1 - messages[message.partition].append(message) - if cnt >= 200: - break - - assert_message_count(messages[0], 100) - assert_message_count(messages[1], 100) - - 
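The assert_message_count helper used in the calls above is imported from tests/kafka/testutil.py and is not part of this diff. A minimal sketch of such a helper, assuming it only verifies the expected message count and the absence of duplicates (an assumption about its behaviour, not the repository's actual implementation), could look like:

    def assert_message_count(messages, num_messages):
        # Expect exactly the requested number of consumed messages ...
        assert len(messages) == num_messages
        # ... and no duplicates among them, identified here by
        # (partition, key, value) of each ConsumerRecord.
        unique = {(m.partition, m.key, m.value) for m in messages}
        assert len(unique) == num_messages

Under that assumption, the two calls above would each verify that 100 distinct messages were read back from their respective partition.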
-@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") -def test_kafka_consumer_unsupported_encoding( - topic, kafka_producer_factory, kafka_consumer_factory): - # Send a compressed message - producer = kafka_producer_factory(compression_type="gzip") - fut = producer.send(topic, b"simple message" * 200) - fut.get(timeout=5) - producer.close() - - # Consume, but with the related compression codec not available - with patch.object(kafka.codec, "has_gzip") as mocked: - mocked.return_value = False - consumer = kafka_consumer_factory(auto_offset_reset='earliest') - error_msg = "Libraries for gzip compression codec not found" - with pytest.raises(UnsupportedCodecError, match=error_msg): - consumer.poll(timeout_ms=2000) - - -@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set") -def test_kafka_consumer__blocking(kafka_consumer_factory, topic, send_messages): - TIMEOUT_MS = 500 - consumer = kafka_consumer_factory(auto_offset_reset='earliest', - enable_auto_commit=False, - consumer_timeout_ms=TIMEOUT_MS) - - # Manual assignment avoids overhead of consumer group mgmt - consumer.unsubscribe() - consumer.assign([TopicPartition(topic, 0)]) - - # Ask for 5 messages, nothing in queue, block 500ms - with Timer() as t: - with pytest.raises(StopIteration): - msg = next(consumer) - assert t.interval >= (TIMEOUT_MS / 1000.0) - - send_messages(range(0, 10)) - - # Ask for 5 messages, 10 in queue. Get 5 back, no blocking - messages = [] - with Timer() as t: - for i in range(5): - msg = next(consumer) - messages.append(msg) - assert_message_count(messages, 5) - assert t.interval < (TIMEOUT_MS / 1000.0) - - # Ask for 10 messages, get 5 back, block 500ms - messages = [] - with Timer() as t: - with pytest.raises(StopIteration): - for i in range(10): - msg = next(consumer) - messages.append(msg) - assert_message_count(messages, 5) - assert t.interval >= (TIMEOUT_MS / 1000.0) - - -@pytest.mark.skipif(env_kafka_version() < (0, 8, 1), reason="Requires KAFKA_VERSION >= 0.8.1") -def test_kafka_consumer__offset_commit_resume(kafka_consumer_factory, send_messages): - GROUP_ID = random_string(10) - - send_messages(range(0, 100), partition=0) - send_messages(range(100, 200), partition=1) - - # Start a consumer and grab the first 180 messages - consumer1 = kafka_consumer_factory( - group_id=GROUP_ID, - enable_auto_commit=True, - auto_commit_interval_ms=100, - auto_offset_reset='earliest', - ) - output_msgs1 = [] - for _ in range(180): - m = next(consumer1) - output_msgs1.append(m) - assert_message_count(output_msgs1, 180) - - # Normally we let the pytest fixture `kafka_consumer_factory` handle - # closing as part of its teardown. Here we manually call close() to force - # auto-commit to occur before the second consumer starts. That way the - # second consumer only consumes previously unconsumed messages. 
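    # (With enable_auto_commit=True, KafkaConsumer.close() performs a final
    # synchronous auto-commit, so an explicit consumer1.commit() right before
    # closing would provide the same guarantee.)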
- consumer1.close() - - # Start a second consumer to grab 181-200 - consumer2 = kafka_consumer_factory( - group_id=GROUP_ID, - enable_auto_commit=True, - auto_commit_interval_ms=100, - auto_offset_reset='earliest', - ) - output_msgs2 = [] - for _ in range(20): - m = next(consumer2) - output_msgs2.append(m) - assert_message_count(output_msgs2, 20) - - # Verify the second consumer wasn't reconsuming messages that the first - # consumer already saw - assert_message_count(output_msgs1 + output_msgs2, 200) - - -@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1") -def test_kafka_consumer_max_bytes_simple(kafka_consumer_factory, topic, send_messages): - send_messages(range(100, 200), partition=0) - send_messages(range(200, 300), partition=1) - - # Start a consumer - consumer = kafka_consumer_factory( - auto_offset_reset='earliest', fetch_max_bytes=300) - seen_partitions = set() - for i in range(90): - poll_res = consumer.poll(timeout_ms=100) - for partition, msgs in poll_res.items(): - for msg in msgs: - seen_partitions.add(partition) - - # Check that we fetched at least 1 message from both partitions - assert seen_partitions == {TopicPartition(topic, 0), TopicPartition(topic, 1)} - - -@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1") -def test_kafka_consumer_max_bytes_one_msg(kafka_consumer_factory, send_messages): - # We send to only 1 partition so we don't have parallel requests to 2 - # nodes for data. - send_messages(range(100, 200)) - - # Start a consumer. FetchResponse_v3 should always include at least 1 - # full msg, so by setting fetch_max_bytes=1 we should get 1 msg at a time - # But 0.11.0.0 returns 1 MessageSet at a time when the messages are - # stored in the new v2 format by the broker. - # - # DP Note: This is a strange test. The consumer shouldn't care - # how many messages are included in a FetchResponse, as long as it is - # non-zero. I would not mind if we deleted this test. It caused - # a minor headache when testing 0.11.0.0. 
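    # (In other words, fetch_max_bytes is a soft limit: the broker returns at
    # least one complete message, or one complete batch for the v2 message
    # format, from the first non-empty partition, so the fetches below still
    # make progress even with fetch_max_bytes=1.)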
- group = 'test-kafka-consumer-max-bytes-one-msg-' + random_string(5) - consumer = kafka_consumer_factory( - group_id=group, - auto_offset_reset='earliest', - consumer_timeout_ms=5000, - fetch_max_bytes=1) - - fetched_msgs = [next(consumer) for i in range(10)] - assert_message_count(fetched_msgs, 10) - - -@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1") -def test_kafka_consumer_offsets_for_time(topic, kafka_consumer, kafka_producer): - late_time = int(time.time()) * 1000 - middle_time = late_time - 1000 - early_time = late_time - 2000 - tp = TopicPartition(topic, 0) - - timeout = 10 - early_msg = kafka_producer.send( - topic, partition=0, value=b"first", - timestamp_ms=early_time).get(timeout) - late_msg = kafka_producer.send( - topic, partition=0, value=b"last", - timestamp_ms=late_time).get(timeout) - - consumer = kafka_consumer - offsets = consumer.offsets_for_times({tp: early_time}) - assert len(offsets) == 1 - assert offsets[tp].offset == early_msg.offset - assert offsets[tp].timestamp == early_time - - offsets = consumer.offsets_for_times({tp: middle_time}) - assert offsets[tp].offset == late_msg.offset - assert offsets[tp].timestamp == late_time - - offsets = consumer.offsets_for_times({tp: late_time}) - assert offsets[tp].offset == late_msg.offset - assert offsets[tp].timestamp == late_time - - offsets = consumer.offsets_for_times({}) - assert offsets == {} - - # Out of bound timestamps check - - offsets = consumer.offsets_for_times({tp: 0}) - assert offsets[tp].offset == early_msg.offset - assert offsets[tp].timestamp == early_time - - offsets = consumer.offsets_for_times({tp: 9999999999999}) - assert offsets[tp] is None - - # Beginning/End offsets - - offsets = consumer.beginning_offsets([tp]) - assert offsets == {tp: early_msg.offset} - offsets = consumer.end_offsets([tp]) - assert offsets == {tp: late_msg.offset + 1} - - -@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1") -def test_kafka_consumer_offsets_search_many_partitions(kafka_consumer, kafka_producer, topic): - tp0 = TopicPartition(topic, 0) - tp1 = TopicPartition(topic, 1) - - send_time = int(time.time() * 1000) - timeout = 10 - p0msg = kafka_producer.send( - topic, partition=0, value=b"XXX", - timestamp_ms=send_time).get(timeout) - p1msg = kafka_producer.send( - topic, partition=1, value=b"XXX", - timestamp_ms=send_time).get(timeout) - - consumer = kafka_consumer - offsets = consumer.offsets_for_times({ - tp0: send_time, - tp1: send_time - }) - - assert offsets == { - tp0: OffsetAndTimestamp(p0msg.offset, send_time), - tp1: OffsetAndTimestamp(p1msg.offset, send_time) - } - - offsets = consumer.beginning_offsets([tp0, tp1]) - assert offsets == { - tp0: p0msg.offset, - tp1: p1msg.offset - } - - offsets = consumer.end_offsets([tp0, tp1]) - assert offsets == { - tp0: p0msg.offset + 1, - tp1: p1msg.offset + 1 - } - - -@pytest.mark.skipif(env_kafka_version() >= (0, 10, 1), reason="Requires KAFKA_VERSION < 0.10.1") -def test_kafka_consumer_offsets_for_time_old(kafka_consumer, topic): - consumer = kafka_consumer - tp = TopicPartition(topic, 0) - - with pytest.raises(UnsupportedVersionError): - consumer.offsets_for_times({tp: int(time.time())}) - - -@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1") -def test_kafka_consumer_offsets_for_times_errors(kafka_consumer_factory, topic): - consumer = kafka_consumer_factory(fetch_max_wait_ms=200, - request_timeout_ms=500) - tp = 
TopicPartition(topic, 0) - bad_tp = TopicPartition(topic, 100) - - with pytest.raises(ValueError): - consumer.offsets_for_times({tp: -1}) - - assert consumer.offsets_for_times({bad_tp: 0}) == {bad_tp: None} diff --git a/tests/kafka/test_coordinator.py b/tests/kafka/test_coordinator.py deleted file mode 100644 index a35cdd1a..00000000 --- a/tests/kafka/test_coordinator.py +++ /dev/null @@ -1,638 +0,0 @@ -# pylint: skip-file -from __future__ import absolute_import -import time - -import pytest - -from kafka.client_async import KafkaClient -from kafka.consumer.subscription_state import ( - SubscriptionState, ConsumerRebalanceListener) -from kafka.coordinator.assignors.range import RangePartitionAssignor -from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor -from kafka.coordinator.assignors.sticky.sticky_assignor import StickyPartitionAssignor -from kafka.coordinator.base import Generation, MemberState, HeartbeatThread -from kafka.coordinator.consumer import ConsumerCoordinator -from kafka.coordinator.protocol import ( - ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment) -import kafka.errors as Errors -from kafka.future import Future -from kafka.metrics import Metrics -from kafka.protocol.commit import ( - OffsetCommitRequest, OffsetCommitResponse, - OffsetFetchRequest, OffsetFetchResponse) -from kafka.protocol.metadata import MetadataResponse -from kafka.structs import OffsetAndMetadata, TopicPartition -from kafka.util import WeakMethod - - -@pytest.fixture -def client(conn): - return KafkaClient(api_version=(0, 9)) - -@pytest.fixture -def coordinator(client): - return ConsumerCoordinator(client, SubscriptionState(), Metrics()) - - -def test_init(client, coordinator): - # metadata update on init - assert client.cluster._need_update is True - assert WeakMethod(coordinator._handle_metadata_update) in client.cluster._listeners - - -@pytest.mark.parametrize("api_version", [(0, 8, 0), (0, 8, 1), (0, 8, 2), (0, 9)]) -def test_autocommit_enable_api_version(client, api_version): - coordinator = ConsumerCoordinator(client, SubscriptionState(), - Metrics(), - enable_auto_commit=True, - session_timeout_ms=30000, # session_timeout_ms and max_poll_interval_ms - max_poll_interval_ms=30000, # should be the same to avoid KafkaConfigurationError - group_id='foobar', - api_version=api_version) - if api_version < (0, 8, 1): - assert coordinator.config['enable_auto_commit'] is False - else: - assert coordinator.config['enable_auto_commit'] is True - - -def test_protocol_type(coordinator): - assert coordinator.protocol_type() == 'consumer' - - -def test_group_protocols(coordinator): - # Requires a subscription - try: - coordinator.group_protocols() - except Errors.IllegalStateError: - pass - else: - assert False, 'Exception not raised when expected' - - coordinator._subscription.subscribe(topics=['foobar']) - assert coordinator.group_protocols() == [ - ('range', ConsumerProtocolMemberMetadata( - RangePartitionAssignor.version, - ['foobar'], - b'')), - ('roundrobin', ConsumerProtocolMemberMetadata( - RoundRobinPartitionAssignor.version, - ['foobar'], - b'')), - ('sticky', ConsumerProtocolMemberMetadata( - StickyPartitionAssignor.version, - ['foobar'], - b'')), - ] - - -@pytest.mark.parametrize('api_version', [(0, 8, 0), (0, 8, 1), (0, 8, 2), (0, 9)]) -def test_pattern_subscription(coordinator, api_version): - coordinator.config['api_version'] = api_version - coordinator._subscription.subscribe(pattern='foo') - assert coordinator._subscription.subscription == set([]) - 
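    # (The pattern is resolved lazily: the subscription set stays empty until
    # cluster metadata is available, and only becomes {'foo1', 'foo2'} after
    # the update_metadata() call below.)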
assert coordinator._metadata_snapshot == coordinator._build_metadata_snapshot(coordinator._subscription, {}) - - cluster = coordinator._client.cluster - cluster.update_metadata(MetadataResponse[0]( - # brokers - [(0, 'foo', 12), (1, 'bar', 34)], - # topics - [(0, 'fizz', []), - (0, 'foo1', [(0, 0, 0, [], [])]), - (0, 'foo2', [(0, 0, 1, [], [])])])) - assert coordinator._subscription.subscription == {'foo1', 'foo2'} - - # 0.9 consumers should trigger dynamic partition assignment - if api_version >= (0, 9): - assert coordinator._subscription.assignment == {} - - # earlier consumers get all partitions assigned locally - else: - assert set(coordinator._subscription.assignment.keys()) == {TopicPartition('foo1', 0), - TopicPartition('foo2', 0)} - - -def test_lookup_assignor(coordinator): - assert coordinator._lookup_assignor('roundrobin') is RoundRobinPartitionAssignor - assert coordinator._lookup_assignor('range') is RangePartitionAssignor - assert coordinator._lookup_assignor('sticky') is StickyPartitionAssignor - assert coordinator._lookup_assignor('foobar') is None - - -def test_join_complete(mocker, coordinator): - coordinator._subscription.subscribe(topics=['foobar']) - assignor = RoundRobinPartitionAssignor() - coordinator.config['assignors'] = (assignor,) - mocker.spy(assignor, 'on_assignment') - assert assignor.on_assignment.call_count == 0 - assignment = ConsumerProtocolMemberAssignment(0, [('foobar', [0, 1])], b'') - coordinator._on_join_complete(0, 'member-foo', 'roundrobin', assignment.encode()) - assert assignor.on_assignment.call_count == 1 - assignor.on_assignment.assert_called_with(assignment) - - -def test_join_complete_with_sticky_assignor(mocker, coordinator): - coordinator._subscription.subscribe(topics=['foobar']) - assignor = StickyPartitionAssignor() - coordinator.config['assignors'] = (assignor,) - mocker.spy(assignor, 'on_assignment') - mocker.spy(assignor, 'on_generation_assignment') - assert assignor.on_assignment.call_count == 0 - assert assignor.on_generation_assignment.call_count == 0 - assignment = ConsumerProtocolMemberAssignment(0, [('foobar', [0, 1])], b'') - coordinator._on_join_complete(0, 'member-foo', 'sticky', assignment.encode()) - assert assignor.on_assignment.call_count == 1 - assert assignor.on_generation_assignment.call_count == 1 - assignor.on_assignment.assert_called_with(assignment) - assignor.on_generation_assignment.assert_called_with(0) - - -def test_subscription_listener(mocker, coordinator): - listener = mocker.MagicMock(spec=ConsumerRebalanceListener) - coordinator._subscription.subscribe( - topics=['foobar'], - listener=listener) - - coordinator._on_join_prepare(0, 'member-foo') - assert listener.on_partitions_revoked.call_count == 1 - listener.on_partitions_revoked.assert_called_with(set([])) - - assignment = ConsumerProtocolMemberAssignment(0, [('foobar', [0, 1])], b'') - coordinator._on_join_complete( - 0, 'member-foo', 'roundrobin', assignment.encode()) - assert listener.on_partitions_assigned.call_count == 1 - listener.on_partitions_assigned.assert_called_with({TopicPartition('foobar', 0), TopicPartition('foobar', 1)}) - - -def test_subscription_listener_failure(mocker, coordinator): - listener = mocker.MagicMock(spec=ConsumerRebalanceListener) - coordinator._subscription.subscribe( - topics=['foobar'], - listener=listener) - - # exception raised in listener should not be re-raised by coordinator - listener.on_partitions_revoked.side_effect = Exception('crash') - coordinator._on_join_prepare(0, 'member-foo') - assert 
listener.on_partitions_revoked.call_count == 1 - - assignment = ConsumerProtocolMemberAssignment(0, [('foobar', [0, 1])], b'') - coordinator._on_join_complete( - 0, 'member-foo', 'roundrobin', assignment.encode()) - assert listener.on_partitions_assigned.call_count == 1 - - -def test_perform_assignment(mocker, coordinator): - member_metadata = { - 'member-foo': ConsumerProtocolMemberMetadata(0, ['foo1'], b''), - 'member-bar': ConsumerProtocolMemberMetadata(0, ['foo1'], b'') - } - assignments = { - 'member-foo': ConsumerProtocolMemberAssignment( - 0, [('foo1', [0])], b''), - 'member-bar': ConsumerProtocolMemberAssignment( - 0, [('foo1', [1])], b'') - } - - mocker.patch.object(RoundRobinPartitionAssignor, 'assign') - RoundRobinPartitionAssignor.assign.return_value = assignments - - ret = coordinator._perform_assignment( - 'member-foo', 'roundrobin', - [(member, metadata.encode()) - for member, metadata in member_metadata.items()]) - - assert RoundRobinPartitionAssignor.assign.call_count == 1 - RoundRobinPartitionAssignor.assign.assert_called_with( - coordinator._client.cluster, member_metadata) - assert ret == assignments - - -def test_on_join_prepare(coordinator): - coordinator._subscription.subscribe(topics=['foobar']) - coordinator._on_join_prepare(0, 'member-foo') - - -def test_need_rejoin(coordinator): - # No subscription - no rejoin - assert coordinator.need_rejoin() is False - - coordinator._subscription.subscribe(topics=['foobar']) - assert coordinator.need_rejoin() is True - - -def test_refresh_committed_offsets_if_needed(mocker, coordinator): - mocker.patch.object(ConsumerCoordinator, 'fetch_committed_offsets', - return_value = { - TopicPartition('foobar', 0): OffsetAndMetadata(123, b''), - TopicPartition('foobar', 1): OffsetAndMetadata(234, b'')}) - coordinator._subscription.assign_from_user([TopicPartition('foobar', 0)]) - assert coordinator._subscription.needs_fetch_committed_offsets is True - coordinator.refresh_committed_offsets_if_needed() - assignment = coordinator._subscription.assignment - assert assignment[TopicPartition('foobar', 0)].committed == OffsetAndMetadata(123, b'') - assert TopicPartition('foobar', 1) not in assignment - assert coordinator._subscription.needs_fetch_committed_offsets is False - - -def test_fetch_committed_offsets(mocker, coordinator): - - # No partitions, no IO polling - mocker.patch.object(coordinator._client, 'poll') - assert coordinator.fetch_committed_offsets([]) == {} - assert coordinator._client.poll.call_count == 0 - - # general case -- send offset fetch request, get successful future - mocker.patch.object(coordinator, 'ensure_coordinator_ready') - mocker.patch.object(coordinator, '_send_offset_fetch_request', - return_value=Future().success('foobar')) - partitions = [TopicPartition('foobar', 0)] - ret = coordinator.fetch_committed_offsets(partitions) - assert ret == 'foobar' - coordinator._send_offset_fetch_request.assert_called_with(partitions) - assert coordinator._client.poll.call_count == 1 - - # Failed future is raised if not retriable - coordinator._send_offset_fetch_request.return_value = Future().failure(AssertionError) - coordinator._client.poll.reset_mock() - try: - coordinator.fetch_committed_offsets(partitions) - except AssertionError: - pass - else: - assert False, 'Exception not raised when expected' - assert coordinator._client.poll.call_count == 1 - - coordinator._client.poll.reset_mock() - coordinator._send_offset_fetch_request.side_effect = [ - Future().failure(Errors.RequestTimedOutError), - 
Future().success('fizzbuzz')] - - ret = coordinator.fetch_committed_offsets(partitions) - assert ret == 'fizzbuzz' - assert coordinator._client.poll.call_count == 2 # call + retry - - -def test_close(mocker, coordinator): - mocker.patch.object(coordinator, '_maybe_auto_commit_offsets_sync') - mocker.patch.object(coordinator, '_handle_leave_group_response') - mocker.patch.object(coordinator, 'coordinator_unknown', return_value=False) - coordinator.coordinator_id = 0 - coordinator._generation = Generation(1, 'foobar', b'') - coordinator.state = MemberState.STABLE - cli = coordinator._client - mocker.patch.object(cli, 'send', return_value=Future().success('foobar')) - mocker.patch.object(cli, 'poll') - - coordinator.close() - assert coordinator._maybe_auto_commit_offsets_sync.call_count == 1 - coordinator._handle_leave_group_response.assert_called_with('foobar') - - assert coordinator.generation() is None - assert coordinator._generation is Generation.NO_GENERATION - assert coordinator.state is MemberState.UNJOINED - assert coordinator.rejoin_needed is True - - -@pytest.fixture -def offsets(): - return { - TopicPartition('foobar', 0): OffsetAndMetadata(123, b''), - TopicPartition('foobar', 1): OffsetAndMetadata(234, b''), - } - - -def test_commit_offsets_async(mocker, coordinator, offsets): - mocker.patch.object(coordinator._client, 'poll') - mocker.patch.object(coordinator, 'coordinator_unknown', return_value=False) - mocker.patch.object(coordinator, 'ensure_coordinator_ready') - mocker.patch.object(coordinator, '_send_offset_commit_request', - return_value=Future().success('fizzbuzz')) - coordinator.commit_offsets_async(offsets) - assert coordinator._send_offset_commit_request.call_count == 1 - - -def test_commit_offsets_sync(mocker, coordinator, offsets): - mocker.patch.object(coordinator, 'ensure_coordinator_ready') - mocker.patch.object(coordinator, '_send_offset_commit_request', - return_value=Future().success('fizzbuzz')) - cli = coordinator._client - mocker.patch.object(cli, 'poll') - - # No offsets, no calls - assert coordinator.commit_offsets_sync({}) is None - assert coordinator._send_offset_commit_request.call_count == 0 - assert cli.poll.call_count == 0 - - ret = coordinator.commit_offsets_sync(offsets) - assert coordinator._send_offset_commit_request.call_count == 1 - assert cli.poll.call_count == 1 - assert ret == 'fizzbuzz' - - # Failed future is raised if not retriable - coordinator._send_offset_commit_request.return_value = Future().failure(AssertionError) - coordinator._client.poll.reset_mock() - try: - coordinator.commit_offsets_sync(offsets) - except AssertionError: - pass - else: - assert False, 'Exception not raised when expected' - assert coordinator._client.poll.call_count == 1 - - coordinator._client.poll.reset_mock() - coordinator._send_offset_commit_request.side_effect = [ - Future().failure(Errors.RequestTimedOutError), - Future().success('fizzbuzz')] - - ret = coordinator.commit_offsets_sync(offsets) - assert ret == 'fizzbuzz' - assert coordinator._client.poll.call_count == 2 # call + retry - - -@pytest.mark.parametrize( - 'api_version,group_id,enable,error,has_auto_commit,commit_offsets,warn,exc', [ - ((0, 8, 0), 'foobar', True, None, False, False, True, False), - ((0, 8, 1), 'foobar', True, None, True, True, False, False), - ((0, 8, 2), 'foobar', True, None, True, True, False, False), - ((0, 9), 'foobar', False, None, False, False, False, False), - ((0, 9), 'foobar', True, Errors.UnknownMemberIdError(), True, True, True, False), - ((0, 9), 'foobar', True, 
Errors.IllegalGenerationError(), True, True, True, False), - ((0, 9), 'foobar', True, Errors.RebalanceInProgressError(), True, True, True, False), - ((0, 9), 'foobar', True, Exception(), True, True, False, True), - ((0, 9), 'foobar', True, None, True, True, False, False), - ((0, 9), None, True, None, False, False, True, False), - ]) -def test_maybe_auto_commit_offsets_sync(mocker, api_version, group_id, enable, - error, has_auto_commit, commit_offsets, - warn, exc): - mock_warn = mocker.patch('kafka.coordinator.consumer.log.warning') - mock_exc = mocker.patch('kafka.coordinator.consumer.log.exception') - client = KafkaClient(api_version=api_version) - coordinator = ConsumerCoordinator(client, SubscriptionState(), - Metrics(), - api_version=api_version, - session_timeout_ms=30000, - max_poll_interval_ms=30000, - enable_auto_commit=enable, - group_id=group_id) - commit_sync = mocker.patch.object(coordinator, 'commit_offsets_sync', - side_effect=error) - if has_auto_commit: - assert coordinator.next_auto_commit_deadline is not None - else: - assert coordinator.next_auto_commit_deadline is None - - assert coordinator._maybe_auto_commit_offsets_sync() is None - - if has_auto_commit: - assert coordinator.next_auto_commit_deadline is not None - - assert commit_sync.call_count == (1 if commit_offsets else 0) - assert mock_warn.call_count == (1 if warn else 0) - assert mock_exc.call_count == (1 if exc else 0) - - -@pytest.fixture -def patched_coord(mocker, coordinator): - coordinator._subscription.subscribe(topics=['foobar']) - mocker.patch.object(coordinator, 'coordinator_unknown', return_value=False) - coordinator.coordinator_id = 0 - mocker.patch.object(coordinator, 'coordinator', return_value=0) - coordinator._generation = Generation(0, 'foobar', b'') - coordinator.state = MemberState.STABLE - coordinator.rejoin_needed = False - mocker.patch.object(coordinator, 'need_rejoin', return_value=False) - mocker.patch.object(coordinator._client, 'least_loaded_node', - return_value=1) - mocker.patch.object(coordinator._client, 'ready', return_value=True) - mocker.patch.object(coordinator._client, 'send') - mocker.patch.object(coordinator, '_heartbeat_thread') - mocker.spy(coordinator, '_failed_request') - mocker.spy(coordinator, '_handle_offset_commit_response') - mocker.spy(coordinator, '_handle_offset_fetch_response') - return coordinator - - -def test_send_offset_commit_request_fail(mocker, patched_coord, offsets): - patched_coord.coordinator_unknown.return_value = True - patched_coord.coordinator_id = None - patched_coord.coordinator.return_value = None - - # No offsets - ret = patched_coord._send_offset_commit_request({}) - assert isinstance(ret, Future) - assert ret.succeeded() - - # No coordinator - ret = patched_coord._send_offset_commit_request(offsets) - assert ret.failed() - assert isinstance(ret.exception, Errors.GroupCoordinatorNotAvailableError) - - -@pytest.mark.parametrize('api_version,req_type', [ - ((0, 8, 1), OffsetCommitRequest[0]), - ((0, 8, 2), OffsetCommitRequest[1]), - ((0, 9), OffsetCommitRequest[2])]) -def test_send_offset_commit_request_versions(patched_coord, offsets, - api_version, req_type): - expect_node = 0 - patched_coord.config['api_version'] = api_version - - patched_coord._send_offset_commit_request(offsets) - (node, request), _ = patched_coord._client.send.call_args - assert node == expect_node, 'Unexpected coordinator node' - assert isinstance(request, req_type) - - -def test_send_offset_commit_request_failure(patched_coord, offsets): - _f = Future() - 
patched_coord._client.send.return_value = _f - future = patched_coord._send_offset_commit_request(offsets) - (node, request), _ = patched_coord._client.send.call_args - error = Exception() - _f.failure(error) - patched_coord._failed_request.assert_called_with(0, request, future, error) - assert future.failed() - assert future.exception is error - - -def test_send_offset_commit_request_success(mocker, patched_coord, offsets): - _f = Future() - patched_coord._client.send.return_value = _f - future = patched_coord._send_offset_commit_request(offsets) - (node, request), _ = patched_coord._client.send.call_args - response = OffsetCommitResponse[0]([('foobar', [(0, 0), (1, 0)])]) - _f.success(response) - patched_coord._handle_offset_commit_response.assert_called_with( - offsets, future, mocker.ANY, response) - - -@pytest.mark.parametrize('response,error,dead', [ - (OffsetCommitResponse[0]([('foobar', [(0, 30), (1, 30)])]), - Errors.GroupAuthorizationFailedError, False), - (OffsetCommitResponse[0]([('foobar', [(0, 12), (1, 12)])]), - Errors.OffsetMetadataTooLargeError, False), - (OffsetCommitResponse[0]([('foobar', [(0, 28), (1, 28)])]), - Errors.InvalidCommitOffsetSizeError, False), - (OffsetCommitResponse[0]([('foobar', [(0, 14), (1, 14)])]), - Errors.GroupLoadInProgressError, False), - (OffsetCommitResponse[0]([('foobar', [(0, 15), (1, 15)])]), - Errors.GroupCoordinatorNotAvailableError, True), - (OffsetCommitResponse[0]([('foobar', [(0, 16), (1, 16)])]), - Errors.NotCoordinatorForGroupError, True), - (OffsetCommitResponse[0]([('foobar', [(0, 7), (1, 7)])]), - Errors.RequestTimedOutError, True), - (OffsetCommitResponse[0]([('foobar', [(0, 25), (1, 25)])]), - Errors.CommitFailedError, False), - (OffsetCommitResponse[0]([('foobar', [(0, 22), (1, 22)])]), - Errors.CommitFailedError, False), - (OffsetCommitResponse[0]([('foobar', [(0, 27), (1, 27)])]), - Errors.CommitFailedError, False), - (OffsetCommitResponse[0]([('foobar', [(0, 17), (1, 17)])]), - Errors.InvalidTopicError, False), - (OffsetCommitResponse[0]([('foobar', [(0, 29), (1, 29)])]), - Errors.TopicAuthorizationFailedError, False), -]) -def test_handle_offset_commit_response(mocker, patched_coord, offsets, - response, error, dead): - future = Future() - patched_coord._handle_offset_commit_response(offsets, future, time.time(), - response) - assert isinstance(future.exception, error) - assert patched_coord.coordinator_id is (None if dead else 0) - - -@pytest.fixture -def partitions(): - return [TopicPartition('foobar', 0), TopicPartition('foobar', 1)] - - -def test_send_offset_fetch_request_fail(mocker, patched_coord, partitions): - patched_coord.coordinator_unknown.return_value = True - patched_coord.coordinator_id = None - patched_coord.coordinator.return_value = None - - # No partitions - ret = patched_coord._send_offset_fetch_request([]) - assert isinstance(ret, Future) - assert ret.succeeded() - assert ret.value == {} - - # No coordinator - ret = patched_coord._send_offset_fetch_request(partitions) - assert ret.failed() - assert isinstance(ret.exception, Errors.GroupCoordinatorNotAvailableError) - - -@pytest.mark.parametrize('api_version,req_type', [ - ((0, 8, 1), OffsetFetchRequest[0]), - ((0, 8, 2), OffsetFetchRequest[1]), - ((0, 9), OffsetFetchRequest[1])]) -def test_send_offset_fetch_request_versions(patched_coord, partitions, - api_version, req_type): - # assuming fixture sets coordinator=0, least_loaded_node=1 - expect_node = 0 - patched_coord.config['api_version'] = api_version - - 
patched_coord._send_offset_fetch_request(partitions) - (node, request), _ = patched_coord._client.send.call_args - assert node == expect_node, 'Unexpected coordinator node' - assert isinstance(request, req_type) - - -def test_send_offset_fetch_request_failure(patched_coord, partitions): - _f = Future() - patched_coord._client.send.return_value = _f - future = patched_coord._send_offset_fetch_request(partitions) - (node, request), _ = patched_coord._client.send.call_args - error = Exception() - _f.failure(error) - patched_coord._failed_request.assert_called_with(0, request, future, error) - assert future.failed() - assert future.exception is error - - -def test_send_offset_fetch_request_success(patched_coord, partitions): - _f = Future() - patched_coord._client.send.return_value = _f - future = patched_coord._send_offset_fetch_request(partitions) - (node, request), _ = patched_coord._client.send.call_args - response = OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 0), (1, 234, b'', 0)])]) - _f.success(response) - patched_coord._handle_offset_fetch_response.assert_called_with( - future, response) - - -@pytest.mark.parametrize('response,error,dead', [ - (OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 14), (1, 234, b'', 14)])]), - Errors.GroupLoadInProgressError, False), - (OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 16), (1, 234, b'', 16)])]), - Errors.NotCoordinatorForGroupError, True), - (OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 25), (1, 234, b'', 25)])]), - Errors.UnknownMemberIdError, False), - (OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 22), (1, 234, b'', 22)])]), - Errors.IllegalGenerationError, False), - (OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 29), (1, 234, b'', 29)])]), - Errors.TopicAuthorizationFailedError, False), - (OffsetFetchResponse[0]([('foobar', [(0, 123, b'', 0), (1, 234, b'', 0)])]), - None, False), -]) -def test_handle_offset_fetch_response(patched_coord, offsets, - response, error, dead): - future = Future() - patched_coord._handle_offset_fetch_response(future, response) - if error is not None: - assert isinstance(future.exception, error) - else: - assert future.succeeded() - assert future.value == offsets - assert patched_coord.coordinator_id is (None if dead else 0) - - -def test_heartbeat(mocker, patched_coord): - heartbeat = HeartbeatThread(patched_coord) - - assert not heartbeat.enabled and not heartbeat.closed - - heartbeat.enable() - assert heartbeat.enabled - - heartbeat.disable() - assert not heartbeat.enabled - - # heartbeat disables when un-joined - heartbeat.enable() - patched_coord.state = MemberState.UNJOINED - heartbeat._run_once() - assert not heartbeat.enabled - - heartbeat.enable() - patched_coord.state = MemberState.STABLE - mocker.spy(patched_coord, '_send_heartbeat_request') - mocker.patch.object(patched_coord.heartbeat, 'should_heartbeat', return_value=True) - heartbeat._run_once() - assert patched_coord._send_heartbeat_request.call_count == 1 - - heartbeat.close() - assert heartbeat.closed - - -def test_lookup_coordinator_failure(mocker, coordinator): - - mocker.patch.object(coordinator, '_send_group_coordinator_request', - return_value=Future().failure(Exception('foobar'))) - future = coordinator.lookup_coordinator() - assert future.failed() - - -def test_ensure_active_group(mocker, coordinator): - coordinator._subscription.subscribe(topics=['foobar']) - mocker.patch.object(coordinator, 'coordinator_unknown', return_value=False) - mocker.patch.object(coordinator, '_send_join_group_request', 
return_value=Future().success(True)) - mocker.patch.object(coordinator, 'need_rejoin', side_effect=[True, False]) - mocker.patch.object(coordinator, '_on_join_complete') - mocker.patch.object(coordinator, '_heartbeat_thread') - - coordinator.ensure_active_group() - - coordinator._send_join_group_request.assert_called_once_with() diff --git a/tests/kafka/test_subscription_state.py b/tests/kafka/test_subscription_state.py deleted file mode 100644 index 9718f6af..00000000 --- a/tests/kafka/test_subscription_state.py +++ /dev/null @@ -1,25 +0,0 @@ -# pylint: skip-file -from __future__ import absolute_import - -import pytest - -from kafka.consumer.subscription_state import SubscriptionState - -@pytest.mark.parametrize(('topic_name', 'expectation'), [ - (0, pytest.raises(TypeError)), - (None, pytest.raises(TypeError)), - ('', pytest.raises(ValueError)), - ('.', pytest.raises(ValueError)), - ('..', pytest.raises(ValueError)), - ('a' * 250, pytest.raises(ValueError)), - ('abc/123', pytest.raises(ValueError)), - ('/abc/123', pytest.raises(ValueError)), - ('/abc123', pytest.raises(ValueError)), - ('name with space', pytest.raises(ValueError)), - ('name*with*stars', pytest.raises(ValueError)), - ('name+with+plus', pytest.raises(ValueError)), -]) -def test_topic_name_validation(topic_name, expectation): - state = SubscriptionState() - with expectation: - state._ensure_valid_topic_name(topic_name) From 97386fcba8f2b0960612c81a9d60939371984b65 Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Sun, 22 Oct 2023 16:48:50 +0300 Subject: [PATCH 08/20] Merge admin --- aiokafka/admin/__init__.py | 5 + aiokafka/{admin.py => admin/client.py} | 48 +- {kafka => aiokafka}/admin/config_resource.py | 13 +- aiokafka/admin/new_partitions.py | 19 + {kafka => aiokafka}/admin/new_topic.py | 14 +- kafka/__init__.py | 5 +- kafka/admin/__init__.py | 14 - kafka/admin/acl_resource.py | 244 ---- kafka/admin/client.py | 1347 ------------------ kafka/admin/new_partitions.py | 19 - kafka/client_async.py | 1077 -------------- tests/kafka/fixtures.py | 12 +- tests/kafka/test_acl_comparisons.py | 92 -- tests/kafka/test_admin.py | 78 - tests/kafka/test_admin_integration.py | 314 ---- tests/kafka/test_client_async.py | 409 ------ tests/kafka/test_sasl_integration.py | 80 -- tests/test_admin.py | 6 +- 18 files changed, 81 insertions(+), 3715 deletions(-) create mode 100644 aiokafka/admin/__init__.py rename aiokafka/{admin.py => admin/client.py} (94%) rename {kafka => aiokafka}/admin/config_resource.py (69%) create mode 100644 aiokafka/admin/new_partitions.py rename {kafka => aiokafka}/admin/new_topic.py (75%) delete mode 100644 kafka/admin/__init__.py delete mode 100644 kafka/admin/acl_resource.py delete mode 100644 kafka/admin/client.py delete mode 100644 kafka/admin/new_partitions.py delete mode 100644 kafka/client_async.py delete mode 100644 tests/kafka/test_acl_comparisons.py delete mode 100644 tests/kafka/test_admin.py delete mode 100644 tests/kafka/test_admin_integration.py delete mode 100644 tests/kafka/test_client_async.py delete mode 100644 tests/kafka/test_sasl_integration.py diff --git a/aiokafka/admin/__init__.py b/aiokafka/admin/__init__.py new file mode 100644 index 00000000..61913cc8 --- /dev/null +++ b/aiokafka/admin/__init__.py @@ -0,0 +1,5 @@ +from .client import AIOKafkaAdminClient +from .new_partitions import NewPartitions +from .new_topic import NewTopic + +__all__ = ["AIOKafkaAdminClient", "NewPartitions", "NewTopic"] diff --git a/aiokafka/admin.py b/aiokafka/admin/client.py similarity index 94% rename from 
aiokafka/admin.py rename to aiokafka/admin/client.py index 9720707a..392c93eb 100644 --- a/aiokafka/admin.py +++ b/aiokafka/admin/client.py @@ -18,12 +18,13 @@ ListGroupsRequest, ApiVersionRequest_v0) from kafka.structs import TopicPartition, OffsetAndMetadata -from kafka.admin import NewTopic, KafkaAdminClient as Admin -from kafka.admin.config_resource import ConfigResourceType, ConfigResource from aiokafka import __version__ from aiokafka.client import AIOKafkaClient +from .config_resource import ConfigResourceType, ConfigResource +from .new_topic import NewTopic + log = logging.getLogger(__name__) @@ -109,9 +110,9 @@ def __init__(self, *, loop=None, sasl_oauth_token_provider=sasl_oauth_token_provider) async def close(self): - """Close the KafkaAdminClient connection to the Kafka broker.""" + """Close the AIOKafkaAdminClient connection to the Kafka broker.""" if not hasattr(self, '_closed') or self._closed: - log.info("KafkaAdminClient already closed.") + log.info("AIOKafkaAdminClient already closed.") return await self._client.close() @@ -165,6 +166,22 @@ def _matching_api_version(self, operation: List[Request]) -> int: .format(operation[0].__name__)) return version + @staticmethod + def _convert_new_topic_request(new_topic): + return ( + new_topic.name, + new_topic.num_partitions, + new_topic.replication_factor, + [ + (partition_id, replicas) + for partition_id, replicas in new_topic.replica_assignments.items() + ], + [ + (config_key, config_value) + for config_key, config_value in new_topic.topic_configs.items() + ] + ) + async def create_topics( self, new_topics: List[NewTopic], @@ -181,7 +198,7 @@ async def create_topics( :return: Appropriate version of CreateTopicResponse class. """ version = self._matching_api_version(CreateTopicsRequest) - topics = [Admin._convert_new_topic_request(nt) for nt in new_topics] + topics = [self._convert_new_topic_request(nt) for nt in new_topics] log.debug("Attempting to send create topic request for %r", new_topics) timeout_ms = timeout_ms or self._request_timeout_ms if version == 0: @@ -320,15 +337,32 @@ async def alter_configs(self, config_resources: List[ConfigResource]) -> Respons return await asyncio.gather(*futures) @staticmethod + def _convert_describe_config_resource_request(config_resource): + return ( + config_resource.resource_type, + config_resource.name, + list(config_resource.configs.keys()) if config_resource.configs else None + ) + + @staticmethod + def _convert_alter_config_resource_request(config_resource): + return ( + config_resource.resource_type, + config_resource.name, + list(config_resource.configs.items()) + ) + + @classmethod def _convert_config_resources( + cls, config_resources: List[ConfigResource], op_type: str = "describe") -> Tuple[Dict[int, Any], List[Any]]: broker_resources = defaultdict(list) topic_resources = [] if op_type == "describe": - convert_func = Admin._convert_describe_config_resource_request + convert_func = cls._convert_describe_config_resource_request else: - convert_func = Admin._convert_alter_config_resource_request + convert_func = cls._convert_alter_config_resource_request for config_resource in config_resources: resource = convert_func(config_resource) if config_resource.resource_type == ConfigResourceType.BROKER: diff --git a/kafka/admin/config_resource.py b/aiokafka/admin/config_resource.py similarity index 69% rename from kafka/admin/config_resource.py rename to aiokafka/admin/config_resource.py index e3294c9c..4c67c5eb 100644 --- a/kafka/admin/config_resource.py +++ 
b/aiokafka/admin/config_resource.py @@ -1,11 +1,4 @@ -from __future__ import absolute_import - -# enum in stdlib as of py3.4 -try: - from enum import IntEnum # pylint: disable=import-error -except ImportError: - # vendored backport module - from kafka.vendor.enum34 import IntEnum +from enum import IntEnum class ConfigResourceType(IntEnum): @@ -15,7 +8,7 @@ class ConfigResourceType(IntEnum): TOPIC = 2 -class ConfigResource(object): +class ConfigResource: """A class for specifying config resources. Arguments: resource_type (ConfigResourceType): the type of kafka resource @@ -30,7 +23,7 @@ def __init__( configs=None ): if not isinstance(resource_type, (ConfigResourceType)): - resource_type = ConfigResourceType[str(resource_type).upper()] # pylint: disable-msg=unsubscriptable-object + resource_type = ConfigResourceType[str(resource_type).upper()] self.resource_type = resource_type self.name = name self.configs = configs diff --git a/aiokafka/admin/new_partitions.py b/aiokafka/admin/new_partitions.py new file mode 100644 index 00000000..7c452819 --- /dev/null +++ b/aiokafka/admin/new_partitions.py @@ -0,0 +1,19 @@ +class NewPartitions: + """A class for new partition creation on existing topics. Note that the + length of new_assignments, if specified, must be the difference between the + new total number of partitions and the existing number of partitions. + Arguments: + total_count (int): + the total number of partitions that should exist on the topic + new_assignments ([[int]]): + an array of arrays of replica assignments for new partitions. + If not set, broker assigns replicas per an internal algorithm. + """ + + def __init__( + self, + total_count, + new_assignments=None + ): + self.total_count = total_count + self.new_assignments = new_assignments diff --git a/kafka/admin/new_topic.py b/aiokafka/admin/new_topic.py similarity index 75% rename from kafka/admin/new_topic.py rename to aiokafka/admin/new_topic.py index 645ac383..4d00daed 100644 --- a/kafka/admin/new_topic.py +++ b/aiokafka/admin/new_topic.py @@ -1,9 +1,7 @@ -from __future__ import absolute_import - from kafka.errors import IllegalArgumentError -class NewTopic(object): +class NewTopic: """ A class for new topic creation Arguments: name (string): name of the topic @@ -25,8 +23,14 @@ def __init__( replica_assignments=None, topic_configs=None, ): - if not (num_partitions == -1 or replication_factor == -1) ^ (replica_assignments is None): - raise IllegalArgumentError('either num_partitions/replication_factor or replica_assignment must be specified') + if not ( + (num_partitions == -1 or replication_factor == -1) + ^ (replica_assignments is None) + ): + raise IllegalArgumentError( + "either num_partitions/replication_factor or replica_assignment " + "must be specified" + ) self.name = name self.num_partitions = num_partitions self.replication_factor = replication_factor diff --git a/kafka/__init__.py b/kafka/__init__.py index 2a335d23..976287b2 100644 --- a/kafka/__init__.py +++ b/kafka/__init__.py @@ -18,14 +18,11 @@ def emit(self, record): logging.getLogger(__name__).addHandler(NullHandler()) -from kafka.admin import KafkaAdminClient -from kafka.client_async import KafkaClient from kafka.conn import BrokerConnection from kafka.serializer import Serializer, Deserializer from kafka.structs import TopicPartition, OffsetAndMetadata __all__ = [ - 'BrokerConnection', 'ConsumerRebalanceListener', 'KafkaAdminClient', - 'KafkaClient', 'KafkaConsumer', 'KafkaProducer', + 'BrokerConnection', 'ConsumerRebalanceListener', ] diff --git 
a/kafka/admin/__init__.py b/kafka/admin/__init__.py deleted file mode 100644 index c240fc6d..00000000 --- a/kafka/admin/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import absolute_import - -from kafka.admin.config_resource import ConfigResource, ConfigResourceType -from kafka.admin.client import KafkaAdminClient -from kafka.admin.acl_resource import (ACL, ACLFilter, ResourcePattern, ResourcePatternFilter, ACLOperation, - ResourceType, ACLPermissionType, ACLResourcePatternType) -from kafka.admin.new_topic import NewTopic -from kafka.admin.new_partitions import NewPartitions - -__all__ = [ - 'ConfigResource', 'ConfigResourceType', 'KafkaAdminClient', 'NewTopic', 'NewPartitions', 'ACL', 'ACLFilter', - 'ResourcePattern', 'ResourcePatternFilter', 'ACLOperation', 'ResourceType', 'ACLPermissionType', - 'ACLResourcePatternType' -] diff --git a/kafka/admin/acl_resource.py b/kafka/admin/acl_resource.py deleted file mode 100644 index fd997a10..00000000 --- a/kafka/admin/acl_resource.py +++ /dev/null @@ -1,244 +0,0 @@ -from __future__ import absolute_import -from kafka.errors import IllegalArgumentError - -# enum in stdlib as of py3.4 -try: - from enum import IntEnum # pylint: disable=import-error -except ImportError: - # vendored backport module - from kafka.vendor.enum34 import IntEnum - - -class ResourceType(IntEnum): - """Type of kafka resource to set ACL for - - The ANY value is only valid in a filter context - """ - - UNKNOWN = 0, - ANY = 1, - CLUSTER = 4, - DELEGATION_TOKEN = 6, - GROUP = 3, - TOPIC = 2, - TRANSACTIONAL_ID = 5 - - -class ACLOperation(IntEnum): - """Type of operation - - The ANY value is only valid in a filter context - """ - - ANY = 1, - ALL = 2, - READ = 3, - WRITE = 4, - CREATE = 5, - DELETE = 6, - ALTER = 7, - DESCRIBE = 8, - CLUSTER_ACTION = 9, - DESCRIBE_CONFIGS = 10, - ALTER_CONFIGS = 11, - IDEMPOTENT_WRITE = 12 - - -class ACLPermissionType(IntEnum): - """An enumerated type of permissions - - The ANY value is only valid in a filter context - """ - - ANY = 1, - DENY = 2, - ALLOW = 3 - - -class ACLResourcePatternType(IntEnum): - """An enumerated type of resource patterns - - More details on the pattern types and how they work - can be found in KIP-290 (Support for prefixed ACLs) - https://cwiki.apache.org/confluence/display/KAFKA/KIP-290%3A+Support+for+Prefixed+ACLs - """ - - ANY = 1, - MATCH = 2, - LITERAL = 3, - PREFIXED = 4 - - -class ACLFilter(object): - """Represents a filter to use with describing and deleting ACLs - - The difference between this class and the ACL class is mainly that - we allow using ANY with the operation, permission, and resource type objects - to fetch ALCs matching any of the properties. 
- - To make a filter matching any principal, set principal to None - """ - - def __init__( - self, - principal, - host, - operation, - permission_type, - resource_pattern - ): - self.principal = principal - self.host = host - self.operation = operation - self.permission_type = permission_type - self.resource_pattern = resource_pattern - - self.validate() - - def validate(self): - if not isinstance(self.operation, ACLOperation): - raise IllegalArgumentError("operation must be an ACLOperation object, and cannot be ANY") - if not isinstance(self.permission_type, ACLPermissionType): - raise IllegalArgumentError("permission_type must be an ACLPermissionType object, and cannot be ANY") - if not isinstance(self.resource_pattern, ResourcePatternFilter): - raise IllegalArgumentError("resource_pattern must be a ResourcePatternFilter object") - - def __repr__(self): - return "".format( - principal=self.principal, - host=self.host, - operation=self.operation.name, - type=self.permission_type.name, - resource=self.resource_pattern - ) - - def __eq__(self, other): - return all(( - self.principal == other.principal, - self.host == other.host, - self.operation == other.operation, - self.permission_type == other.permission_type, - self.resource_pattern == other.resource_pattern - )) - - def __hash__(self): - return hash(( - self.principal, - self.host, - self.operation, - self.permission_type, - self.resource_pattern, - )) - - -class ACL(ACLFilter): - """Represents a concrete ACL for a specific ResourcePattern - - In kafka an ACL is a 4-tuple of (principal, host, operation, permission_type) - that limits who can do what on a specific resource (or since KIP-290 a resource pattern) - - Terminology: - Principal -> This is the identifier for the user. Depending on the authorization method used (SSL, SASL etc) - the principal will look different. See http://kafka.apache.org/documentation/#security_authz for details. - The principal must be on the format "User:" or kafka will treat it as invalid. It's possible to use - other principal types than "User" if using a custom authorizer for the cluster. - Host -> This must currently be an IP address. It cannot be a range, and it cannot be a domain name. - It can be set to "*", which is special cased in kafka to mean "any host" - Operation -> Which client operation this ACL refers to. Has different meaning depending - on the resource type the ACL refers to. See https://docs.confluent.io/current/kafka/authorization.html#acl-format - for a list of which combinations of resource/operation that unlocks which kafka APIs - Permission Type: Whether this ACL is allowing or denying access - Resource Pattern -> This is a representation of the resource or resource pattern that the ACL - refers to. See the ResourcePattern class for details. 
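To make the terminology above concrete, an ACL that lets one principal read every topic sharing a name prefix can be assembled from the classes in this module. The snippet is purely illustrative: the principal and the topic prefix are made-up values, and the imports use the kafka.admin path that this patch removes.

from kafka.admin import (
    ACL, ACLOperation, ACLPermissionType,
    ResourcePattern, ResourceType, ACLResourcePatternType,
)

# "User:payments" may READ any topic whose name starts with "orders-",
# from any host ("*" is special-cased by Kafka to mean "any host").
allow_read_orders = ACL(
    principal="User:payments",
    host="*",
    operation=ACLOperation.READ,
    permission_type=ACLPermissionType.ALLOW,
    resource_pattern=ResourcePattern(
        ResourceType.TOPIC,
        "orders-",
        ACLResourcePatternType.PREFIXED,
    ),
)

A filter for describe or delete calls is built the same way with ACLFilter and ResourcePatternFilter, where a None principal and the ANY enum members act as wildcards.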
- - """ - - def __init__( - self, - principal, - host, - operation, - permission_type, - resource_pattern - ): - super(ACL, self).__init__(principal, host, operation, permission_type, resource_pattern) - self.validate() - - def validate(self): - if self.operation == ACLOperation.ANY: - raise IllegalArgumentError("operation cannot be ANY") - if self.permission_type == ACLPermissionType.ANY: - raise IllegalArgumentError("permission_type cannot be ANY") - if not isinstance(self.resource_pattern, ResourcePattern): - raise IllegalArgumentError("resource_pattern must be a ResourcePattern object") - - -class ResourcePatternFilter(object): - def __init__( - self, - resource_type, - resource_name, - pattern_type - ): - self.resource_type = resource_type - self.resource_name = resource_name - self.pattern_type = pattern_type - - self.validate() - - def validate(self): - if not isinstance(self.resource_type, ResourceType): - raise IllegalArgumentError("resource_type must be a ResourceType object") - if not isinstance(self.pattern_type, ACLResourcePatternType): - raise IllegalArgumentError("pattern_type must be an ACLResourcePatternType object") - - def __repr__(self): - return "".format( - self.resource_type.name, - self.resource_name, - self.pattern_type.name - ) - - def __eq__(self, other): - return all(( - self.resource_type == other.resource_type, - self.resource_name == other.resource_name, - self.pattern_type == other.pattern_type, - )) - - def __hash__(self): - return hash(( - self.resource_type, - self.resource_name, - self.pattern_type - )) - - -class ResourcePattern(ResourcePatternFilter): - """A resource pattern to apply the ACL to - - Resource patterns are used to be able to specify which resources an ACL - describes in a more flexible way than just pointing to a literal topic name for example. - Since KIP-290 (kafka 2.0) it's possible to set an ACL for a prefixed resource name, which - can cut down considerably on the number of ACLs needed when the number of topics and - consumer groups start to grow. - The default pattern_type is LITERAL, and it describes a specific resource. This is also how - ACLs worked before the introduction of prefixed ACLs - """ - - def __init__( - self, - resource_type, - resource_name, - pattern_type=ACLResourcePatternType.LITERAL - ): - super(ResourcePattern, self).__init__(resource_type, resource_name, pattern_type) - self.validate() - - def validate(self): - if self.resource_type == ResourceType.ANY: - raise IllegalArgumentError("resource_type cannot be ANY") - if self.pattern_type in [ACLResourcePatternType.ANY, ACLResourcePatternType.MATCH]: - raise IllegalArgumentError( - "pattern_type cannot be {} on a concrete ResourcePattern".format(self.pattern_type.name) - ) diff --git a/kafka/admin/client.py b/kafka/admin/client.py deleted file mode 100644 index 8eb7504a..00000000 --- a/kafka/admin/client.py +++ /dev/null @@ -1,1347 +0,0 @@ -from __future__ import absolute_import - -from collections import defaultdict -import copy -import logging -import socket - -from . 
import ConfigResourceType -from kafka.vendor import six - -from kafka.admin.acl_resource import ACLOperation, ACLPermissionType, ACLFilter, ACL, ResourcePattern, ResourceType, \ - ACLResourcePatternType -from kafka.client_async import KafkaClient, selectors -from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment, ConsumerProtocol -import kafka.errors as Errors -from kafka.errors import ( - IncompatibleBrokerVersion, KafkaConfigurationError, NotControllerError, - UnrecognizedBrokerVersion, IllegalArgumentError) -from kafka.metrics import MetricConfig, Metrics -from kafka.protocol.admin import ( - CreateTopicsRequest, DeleteTopicsRequest, DescribeConfigsRequest, AlterConfigsRequest, CreatePartitionsRequest, - ListGroupsRequest, DescribeGroupsRequest, DescribeAclsRequest, CreateAclsRequest, DeleteAclsRequest, - DeleteGroupsRequest -) -from kafka.protocol.commit import GroupCoordinatorRequest, OffsetFetchRequest -from kafka.protocol.metadata import MetadataRequest -from kafka.protocol.types import Array -from kafka.structs import TopicPartition, OffsetAndMetadata, MemberInformation, GroupInformation -from kafka.version import __version__ - - -log = logging.getLogger(__name__) - - -class KafkaAdminClient(object): - """A class for administering the Kafka cluster. - - Warning: - This is an unstable interface that was recently added and is subject to - change without warning. In particular, many methods currently return - raw protocol tuples. In future releases, we plan to make these into - nicer, more pythonic objects. Unfortunately, this will likely break - those interfaces. - - The KafkaAdminClient class will negotiate for the latest version of each message - protocol format supported by both the kafka-python client library and the - Kafka broker. Usage of optional fields from protocol versions that are not - supported by the broker will result in IncompatibleBrokerVersion exceptions. - - Use of this class requires a minimum broker version >= 0.10.0.0. - - Keyword Arguments: - bootstrap_servers: 'host[:port]' string (or list of 'host[:port]' - strings) that the consumer should contact to bootstrap initial - cluster metadata. This does not have to be the full node list. - It just needs to have at least one broker that will respond to a - Metadata API Request. Default port is 9092. If no servers are - specified, will default to localhost:9092. - client_id (str): a name for this client. This string is passed in - each request to servers and can be used to identify specific - server-side log entries that correspond to this client. Also - submitted to GroupCoordinator for logging with respect to - consumer group administration. Default: 'kafka-python-{version}' - reconnect_backoff_ms (int): The amount of time in milliseconds to - wait before attempting to reconnect to a given host. - Default: 50. - reconnect_backoff_max_ms (int): The maximum amount of time in - milliseconds to backoff/wait when reconnecting to a broker that has - repeatedly failed to connect. If provided, the backoff per host - will increase exponentially for each consecutive connection - failure, up to this maximum. Once the maximum is reached, - reconnection attempts will continue periodically with this fixed - rate. To avoid connection storms, a randomization factor of 0.2 - will be applied to the backoff resulting in a random range between - 20% below and 20% above the computed value. Default: 1000. - request_timeout_ms (int): Client request timeout in milliseconds. 
- Default: 30000. - connections_max_idle_ms: Close idle connections after the number of - milliseconds specified by this config. The broker closes idle - connections after connections.max.idle.ms, so this avoids hitting - unexpected socket disconnected errors on the client. - Default: 540000 - retry_backoff_ms (int): Milliseconds to backoff when retrying on - errors. Default: 100. - max_in_flight_requests_per_connection (int): Requests are pipelined - to kafka brokers up to this number of maximum requests per - broker connection. Default: 5. - receive_buffer_bytes (int): The size of the TCP receive buffer - (SO_RCVBUF) to use when reading data. Default: None (relies on - system defaults). Java client defaults to 32768. - send_buffer_bytes (int): The size of the TCP send buffer - (SO_SNDBUF) to use when sending data. Default: None (relies on - system defaults). Java client defaults to 131072. - socket_options (list): List of tuple-arguments to socket.setsockopt - to apply to broker connection sockets. Default: - [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)] - metadata_max_age_ms (int): The period of time in milliseconds after - which we force a refresh of metadata even if we haven't seen any - partition leadership changes to proactively discover any new - brokers or partitions. Default: 300000 - security_protocol (str): Protocol used to communicate with brokers. - Valid values are: PLAINTEXT, SSL, SASL_PLAINTEXT, SASL_SSL. - Default: PLAINTEXT. - ssl_context (ssl.SSLContext): Pre-configured SSLContext for wrapping - socket connections. If provided, all other ssl_* configurations - will be ignored. Default: None. - ssl_check_hostname (bool): Flag to configure whether SSL handshake - should verify that the certificate matches the broker's hostname. - Default: True. - ssl_cafile (str): Optional filename of CA file to use in certificate - verification. Default: None. - ssl_certfile (str): Optional filename of file in PEM format containing - the client certificate, as well as any CA certificates needed to - establish the certificate's authenticity. Default: None. - ssl_keyfile (str): Optional filename containing the client private key. - Default: None. - ssl_password (str): Optional password to be used when loading the - certificate chain. Default: None. - ssl_crlfile (str): Optional filename containing the CRL to check for - certificate expiration. By default, no CRL check is done. When - providing a file, only the leaf certificate will be checked against - this CRL. The CRL can only be checked with Python 3.4+ or 2.7.9+. - Default: None. - api_version (tuple): Specify which Kafka API version to use. If set - to None, KafkaClient will attempt to infer the broker version by - probing various APIs. Example: (0, 10, 2). Default: None - api_version_auto_timeout_ms (int): number of milliseconds to throw a - timeout exception from the constructor when checking the broker - api version. Only applies if api_version is None - selector (selectors.BaseSelector): Provide a specific selector - implementation to use for I/O multiplexing. - Default: selectors.DefaultSelector - metrics (kafka.metrics.Metrics): Optionally provide a metrics - instance for capturing network IO stats. Default: None. - metric_group_prefix (str): Prefix for metric names. Default: '' - sasl_mechanism (str): Authentication mechanism when security_protocol - is configured for SASL_PLAINTEXT or SASL_SSL. Valid values are: - PLAIN, GSSAPI, OAUTHBEARER, SCRAM-SHA-256, SCRAM-SHA-512. 
- sasl_plain_username (str): username for sasl PLAIN and SCRAM authentication. - Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. - sasl_plain_password (str): password for sasl PLAIN and SCRAM authentication. - Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. - sasl_kerberos_service_name (str): Service name to include in GSSAPI - sasl mechanism handshake. Default: 'kafka' - sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI - sasl mechanism handshake. Default: one of bootstrap servers - sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider - instance. (See kafka.oauth.abstract). Default: None - kafka_client (callable): Custom class / callable for creating KafkaClient instances - - """ - DEFAULT_CONFIG = { - # client configs - 'bootstrap_servers': 'localhost', - 'client_id': 'kafka-python-' + __version__, - 'request_timeout_ms': 30000, - 'connections_max_idle_ms': 9 * 60 * 1000, - 'reconnect_backoff_ms': 50, - 'reconnect_backoff_max_ms': 1000, - 'max_in_flight_requests_per_connection': 5, - 'receive_buffer_bytes': None, - 'send_buffer_bytes': None, - 'socket_options': [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)], - 'sock_chunk_bytes': 4096, # undocumented experimental option - 'sock_chunk_buffer_count': 1000, # undocumented experimental option - 'retry_backoff_ms': 100, - 'metadata_max_age_ms': 300000, - 'security_protocol': 'PLAINTEXT', - 'ssl_context': None, - 'ssl_check_hostname': True, - 'ssl_cafile': None, - 'ssl_certfile': None, - 'ssl_keyfile': None, - 'ssl_password': None, - 'ssl_crlfile': None, - 'api_version': None, - 'api_version_auto_timeout_ms': 2000, - 'selector': selectors.DefaultSelector, - 'sasl_mechanism': None, - 'sasl_plain_username': None, - 'sasl_plain_password': None, - 'sasl_kerberos_service_name': 'kafka', - 'sasl_kerberos_domain_name': None, - 'sasl_oauth_token_provider': None, - - # metrics configs - 'metric_reporters': [], - 'metrics_num_samples': 2, - 'metrics_sample_window_ms': 30000, - 'kafka_client': KafkaClient, - } - - def __init__(self, **configs): - log.debug("Starting KafkaAdminClient with configuration: %s", configs) - extra_configs = set(configs).difference(self.DEFAULT_CONFIG) - if extra_configs: - raise KafkaConfigurationError("Unrecognized configs: {}".format(extra_configs)) - - self.config = copy.copy(self.DEFAULT_CONFIG) - self.config.update(configs) - - # Configure metrics - metrics_tags = {'client-id': self.config['client_id']} - metric_config = MetricConfig(samples=self.config['metrics_num_samples'], - time_window_ms=self.config['metrics_sample_window_ms'], - tags=metrics_tags) - reporters = [reporter() for reporter in self.config['metric_reporters']] - self._metrics = Metrics(metric_config, reporters) - - self._client = self.config['kafka_client']( - metrics=self._metrics, - metric_group_prefix='admin', - **self.config - ) - self._client.check_version(timeout=(self.config['api_version_auto_timeout_ms'] / 1000)) - - # Get auto-discovered version from client if necessary - if self.config['api_version'] is None: - self.config['api_version'] = self._client.config['api_version'] - - self._closed = False - self._refresh_controller_id() - log.debug("KafkaAdminClient started.") - - def close(self): - """Close the KafkaAdminClient connection to the Kafka broker.""" - if not hasattr(self, '_closed') or self._closed: - log.info("KafkaAdminClient already closed.") - return - - self._metrics.close() - self._client.close() - self._closed = True - log.debug("KafkaAdminClient 
is now closed.") - - def _matching_api_version(self, operation): - """Find the latest version of the protocol operation supported by both - this library and the broker. - - This resolves to the lesser of either the latest api version this - library supports, or the max version supported by the broker. - - :param operation: A list of protocol operation versions from kafka.protocol. - :return: The max matching version number between client and broker. - """ - broker_api_versions = self._client.get_api_versions() - api_key = operation[0].API_KEY - if broker_api_versions is None or api_key not in broker_api_versions: - raise IncompatibleBrokerVersion( - "Kafka broker does not support the '{}' Kafka protocol." - .format(operation[0].__name__)) - min_version, max_version = broker_api_versions[api_key] - version = min(len(operation) - 1, max_version) - if version < min_version: - # max library version is less than min broker version. Currently, - # no Kafka versions specify a min msg version. Maybe in the future? - raise IncompatibleBrokerVersion( - "No version of the '{}' Kafka protocol is supported by both the client and broker." - .format(operation[0].__name__)) - return version - - def _validate_timeout(self, timeout_ms): - """Validate the timeout is set or use the configuration default. - - :param timeout_ms: The timeout provided by api call, in milliseconds. - :return: The timeout to use for the operation. - """ - return timeout_ms or self.config['request_timeout_ms'] - - def _refresh_controller_id(self): - """Determine the Kafka cluster controller.""" - version = self._matching_api_version(MetadataRequest) - if 1 <= version <= 6: - request = MetadataRequest[version]() - future = self._send_request_to_node(self._client.least_loaded_node(), request) - - self._wait_for_futures([future]) - - response = future.value - controller_id = response.controller_id - # verify the controller is new enough to support our requests - controller_version = self._client.check_version(controller_id, timeout=(self.config['api_version_auto_timeout_ms'] / 1000)) - if controller_version < (0, 10, 0): - raise IncompatibleBrokerVersion( - "The controller appears to be running Kafka {}. KafkaAdminClient requires brokers >= 0.10.0.0." - .format(controller_version)) - self._controller_id = controller_id - else: - raise UnrecognizedBrokerVersion( - "Kafka Admin interface cannot determine the controller using MetadataRequest_v{}." - .format(version)) - - def _find_coordinator_id_send_request(self, group_id): - """Send a FindCoordinatorRequest to a broker. - - :param group_id: The consumer group ID. This is typically the group - name as a string. - :return: A message future - """ - # TODO add support for dynamically picking version of - # GroupCoordinatorRequest which was renamed to FindCoordinatorRequest. - # When I experimented with this, the coordinator value returned in - # GroupCoordinatorResponse_v1 didn't match the value returned by - # GroupCoordinatorResponse_v0 and I couldn't figure out why. - version = 0 - # version = self._matching_api_version(GroupCoordinatorRequest) - if version <= 0: - request = GroupCoordinatorRequest[version](group_id) - else: - raise NotImplementedError( - "Support for GroupCoordinatorRequest_v{} has not yet been added to KafkaAdminClient." - .format(version)) - return self._send_request_to_node(self._client.least_loaded_node(), request) - - def _find_coordinator_id_process_response(self, response): - """Process a FindCoordinatorResponse. - - :param response: a FindCoordinatorResponse. 
- :return: The node_id of the broker that is the coordinator. - """ - if response.API_VERSION <= 0: - error_type = Errors.for_code(response.error_code) - if error_type is not Errors.NoError: - # Note: When error_type.retriable, Java will retry... see - # KafkaAdminClient's handleFindCoordinatorError method - raise error_type( - "FindCoordinatorRequest failed with response '{}'." - .format(response)) - else: - raise NotImplementedError( - "Support for FindCoordinatorRequest_v{} has not yet been added to KafkaAdminClient." - .format(response.API_VERSION)) - return response.coordinator_id - - def _find_coordinator_ids(self, group_ids): - """Find the broker node_ids of the coordinators of the given groups. - - Sends a FindCoordinatorRequest message to the cluster for each group_id. - Will block until the FindCoordinatorResponse is received for all groups. - Any errors are immediately raised. - - :param group_ids: A list of consumer group IDs. This is typically the group - name as a string. - :return: A dict of {group_id: node_id} where node_id is the id of the - broker that is the coordinator for the corresponding group. - """ - groups_futures = { - group_id: self._find_coordinator_id_send_request(group_id) - for group_id in group_ids - } - self._wait_for_futures(groups_futures.values()) - groups_coordinators = { - group_id: self._find_coordinator_id_process_response(future.value) - for group_id, future in groups_futures.items() - } - return groups_coordinators - - def _send_request_to_node(self, node_id, request, wakeup=True): - """Send a Kafka protocol message to a specific broker. - - Returns a future that may be polled for status and results. - - :param node_id: The broker id to which to send the message. - :param request: The message to send. - :param wakeup: Optional flag to disable thread-wakeup. - :return: A future object that may be polled for status and results. - :exception: The exception if the message could not be sent. - """ - while not self._client.ready(node_id): - # poll until the connection to broker is ready, otherwise send() - # will fail with NodeNotReadyError - self._client.poll() - return self._client.send(node_id, request, wakeup) - - def _send_request_to_controller(self, request): - """Send a Kafka protocol message to the cluster controller. - - Will block until the message result is received. - - :param request: The message to send. - :return: The Kafka protocol response for the message. - """ - tries = 2 # in case our cached self._controller_id is outdated - while tries: - tries -= 1 - future = self._send_request_to_node(self._controller_id, request) - - self._wait_for_futures([future]) - - response = future.value - # In Java, the error field name is inconsistent: - # - CreateTopicsResponse / CreatePartitionsResponse uses topic_errors - # - DeleteTopicsResponse uses topic_error_codes - # So this is a little brittle in that it assumes all responses have - # one of these attributes and that they always unpack into - # (topic, error_code) tuples. - topic_error_tuples = (response.topic_errors if hasattr(response, 'topic_errors') - else response.topic_error_codes) - # Also small py2/py3 compatibility -- py3 can ignore extra values - # during unpack via: for x, y, *rest in list_of_values. py2 cannot. 
- # So for now we have to map across the list and explicitly drop any - # extra values (usually the error_message) - for topic, error_code in map(lambda e: e[:2], topic_error_tuples): - error_type = Errors.for_code(error_code) - if tries and error_type is NotControllerError: - # No need to inspect the rest of the errors for - # non-retriable errors because NotControllerError should - # either be thrown for all errors or no errors. - self._refresh_controller_id() - break - elif error_type is not Errors.NoError: - raise error_type( - "Request '{}' failed with response '{}'." - .format(request, response)) - else: - return response - raise RuntimeError("This should never happen, please file a bug with full stacktrace if encountered") - - @staticmethod - def _convert_new_topic_request(new_topic): - return ( - new_topic.name, - new_topic.num_partitions, - new_topic.replication_factor, - [ - (partition_id, replicas) for partition_id, replicas in new_topic.replica_assignments.items() - ], - [ - (config_key, config_value) for config_key, config_value in new_topic.topic_configs.items() - ] - ) - - def create_topics(self, new_topics, timeout_ms=None, validate_only=False): - """Create new topics in the cluster. - - :param new_topics: A list of NewTopic objects. - :param timeout_ms: Milliseconds to wait for new topics to be created - before the broker returns. - :param validate_only: If True, don't actually create new topics. - Not supported by all versions. Default: False - :return: Appropriate version of CreateTopicResponse class. - """ - version = self._matching_api_version(CreateTopicsRequest) - timeout_ms = self._validate_timeout(timeout_ms) - if version == 0: - if validate_only: - raise IncompatibleBrokerVersion( - "validate_only requires CreateTopicsRequest >= v1, which is not supported by Kafka {}." - .format(self.config['api_version'])) - request = CreateTopicsRequest[version]( - create_topic_requests=[self._convert_new_topic_request(new_topic) for new_topic in new_topics], - timeout=timeout_ms - ) - elif version <= 3: - request = CreateTopicsRequest[version]( - create_topic_requests=[self._convert_new_topic_request(new_topic) for new_topic in new_topics], - timeout=timeout_ms, - validate_only=validate_only - ) - else: - raise NotImplementedError( - "Support for CreateTopics v{} has not yet been added to KafkaAdminClient." - .format(version)) - # TODO convert structs to a more pythonic interface - # TODO raise exceptions if errors - return self._send_request_to_controller(request) - - def delete_topics(self, topics, timeout_ms=None): - """Delete topics from the cluster. - - :param topics: A list of topic name strings. - :param timeout_ms: Milliseconds to wait for topics to be deleted - before the broker returns. - :return: Appropriate version of DeleteTopicsResponse class. - """ - version = self._matching_api_version(DeleteTopicsRequest) - timeout_ms = self._validate_timeout(timeout_ms) - if version <= 3: - request = DeleteTopicsRequest[version]( - topics=topics, - timeout=timeout_ms - ) - response = self._send_request_to_controller(request) - else: - raise NotImplementedError( - "Support for DeleteTopics v{} has not yet been added to KafkaAdminClient." 
- .format(version)) - return response - - - def _get_cluster_metadata(self, topics=None, auto_topic_creation=False): - """ - topics == None means "get all topics" - """ - version = self._matching_api_version(MetadataRequest) - if version <= 3: - if auto_topic_creation: - raise IncompatibleBrokerVersion( - "auto_topic_creation requires MetadataRequest >= v4, which" - " is not supported by Kafka {}" - .format(self.config['api_version'])) - - request = MetadataRequest[version](topics=topics) - elif version <= 5: - request = MetadataRequest[version]( - topics=topics, - allow_auto_topic_creation=auto_topic_creation - ) - - future = self._send_request_to_node( - self._client.least_loaded_node(), - request - ) - self._wait_for_futures([future]) - return future.value - - def list_topics(self): - metadata = self._get_cluster_metadata(topics=None) - obj = metadata.to_object() - return [t['topic'] for t in obj['topics']] - - def describe_topics(self, topics=None): - metadata = self._get_cluster_metadata(topics=topics) - obj = metadata.to_object() - return obj['topics'] - - def describe_cluster(self): - metadata = self._get_cluster_metadata() - obj = metadata.to_object() - obj.pop('topics') # We have 'describe_topics' for this - return obj - - @staticmethod - def _convert_describe_acls_response_to_acls(describe_response): - version = describe_response.API_VERSION - - error = Errors.for_code(describe_response.error_code) - acl_list = [] - for resources in describe_response.resources: - if version == 0: - resource_type, resource_name, acls = resources - resource_pattern_type = ACLResourcePatternType.LITERAL.value - elif version <= 1: - resource_type, resource_name, resource_pattern_type, acls = resources - else: - raise NotImplementedError( - "Support for DescribeAcls Response v{} has not yet been added to KafkaAdmin." - .format(version) - ) - for acl in acls: - principal, host, operation, permission_type = acl - conv_acl = ACL( - principal=principal, - host=host, - operation=ACLOperation(operation), - permission_type=ACLPermissionType(permission_type), - resource_pattern=ResourcePattern( - ResourceType(resource_type), - resource_name, - ACLResourcePatternType(resource_pattern_type) - ) - ) - acl_list.append(conv_acl) - - return (acl_list, error,) - - def describe_acls(self, acl_filter): - """Describe a set of ACLs - - Used to return a set of ACLs matching the supplied ACLFilter. - The cluster must be configured with an authorizer for this to work, or - you will get a SecurityDisabledError - - :param acl_filter: an ACLFilter object - :return: tuple of a list of matching ACL objects and a KafkaError (NoError if successful) - """ - - version = self._matching_api_version(DescribeAclsRequest) - if version == 0: - request = DescribeAclsRequest[version]( - resource_type=acl_filter.resource_pattern.resource_type, - resource_name=acl_filter.resource_pattern.resource_name, - principal=acl_filter.principal, - host=acl_filter.host, - operation=acl_filter.operation, - permission_type=acl_filter.permission_type - ) - elif version <= 1: - request = DescribeAclsRequest[version]( - resource_type=acl_filter.resource_pattern.resource_type, - resource_name=acl_filter.resource_pattern.resource_name, - resource_pattern_type_filter=acl_filter.resource_pattern.pattern_type, - principal=acl_filter.principal, - host=acl_filter.host, - operation=acl_filter.operation, - permission_type=acl_filter.permission_type - - ) - else: - raise NotImplementedError( - "Support for DescribeAcls v{} has not yet been added to KafkaAdmin." 
- .format(version) - ) - - future = self._send_request_to_node(self._client.least_loaded_node(), request) - self._wait_for_futures([future]) - response = future.value - - error_type = Errors.for_code(response.error_code) - if error_type is not Errors.NoError: - # optionally we could retry if error_type.retriable - raise error_type( - "Request '{}' failed with response '{}'." - .format(request, response)) - - return self._convert_describe_acls_response_to_acls(response) - - @staticmethod - def _convert_create_acls_resource_request_v0(acl): - - return ( - acl.resource_pattern.resource_type, - acl.resource_pattern.resource_name, - acl.principal, - acl.host, - acl.operation, - acl.permission_type - ) - - @staticmethod - def _convert_create_acls_resource_request_v1(acl): - - return ( - acl.resource_pattern.resource_type, - acl.resource_pattern.resource_name, - acl.resource_pattern.pattern_type, - acl.principal, - acl.host, - acl.operation, - acl.permission_type - ) - - @staticmethod - def _convert_create_acls_response_to_acls(acls, create_response): - version = create_response.API_VERSION - - creations_error = [] - creations_success = [] - for i, creations in enumerate(create_response.creation_responses): - if version <= 1: - error_code, error_message = creations - acl = acls[i] - error = Errors.for_code(error_code) - else: - raise NotImplementedError( - "Support for DescribeAcls Response v{} has not yet been added to KafkaAdmin." - .format(version) - ) - - if error is Errors.NoError: - creations_success.append(acl) - else: - creations_error.append((acl, error,)) - - return {"succeeded": creations_success, "failed": creations_error} - - def create_acls(self, acls): - """Create a list of ACLs - - This endpoint only accepts a list of concrete ACL objects, no ACLFilters. - Throws TopicAlreadyExistsError if topic is already present. - - :param acls: a list of ACL objects - :return: dict of successes and failures - """ - - for acl in acls: - if not isinstance(acl, ACL): - raise IllegalArgumentError("acls must contain ACL objects") - - version = self._matching_api_version(CreateAclsRequest) - if version == 0: - request = CreateAclsRequest[version]( - creations=[self._convert_create_acls_resource_request_v0(acl) for acl in acls] - ) - elif version <= 1: - request = CreateAclsRequest[version]( - creations=[self._convert_create_acls_resource_request_v1(acl) for acl in acls] - ) - else: - raise NotImplementedError( - "Support for CreateAcls v{} has not yet been added to KafkaAdmin." 
- .format(version) - ) - - future = self._send_request_to_node(self._client.least_loaded_node(), request) - self._wait_for_futures([future]) - response = future.value - - return self._convert_create_acls_response_to_acls(acls, response) - - @staticmethod - def _convert_delete_acls_resource_request_v0(acl): - return ( - acl.resource_pattern.resource_type, - acl.resource_pattern.resource_name, - acl.principal, - acl.host, - acl.operation, - acl.permission_type - ) - - @staticmethod - def _convert_delete_acls_resource_request_v1(acl): - return ( - acl.resource_pattern.resource_type, - acl.resource_pattern.resource_name, - acl.resource_pattern.pattern_type, - acl.principal, - acl.host, - acl.operation, - acl.permission_type - ) - - @staticmethod - def _convert_delete_acls_response_to_matching_acls(acl_filters, delete_response): - version = delete_response.API_VERSION - filter_result_list = [] - for i, filter_responses in enumerate(delete_response.filter_responses): - filter_error_code, filter_error_message, matching_acls = filter_responses - filter_error = Errors.for_code(filter_error_code) - acl_result_list = [] - for acl in matching_acls: - if version == 0: - error_code, error_message, resource_type, resource_name, principal, host, operation, permission_type = acl - resource_pattern_type = ACLResourcePatternType.LITERAL.value - elif version == 1: - error_code, error_message, resource_type, resource_name, resource_pattern_type, principal, host, operation, permission_type = acl - else: - raise NotImplementedError( - "Support for DescribeAcls Response v{} has not yet been added to KafkaAdmin." - .format(version) - ) - acl_error = Errors.for_code(error_code) - conv_acl = ACL( - principal=principal, - host=host, - operation=ACLOperation(operation), - permission_type=ACLPermissionType(permission_type), - resource_pattern=ResourcePattern( - ResourceType(resource_type), - resource_name, - ACLResourcePatternType(resource_pattern_type) - ) - ) - acl_result_list.append((conv_acl, acl_error,)) - filter_result_list.append((acl_filters[i], acl_result_list, filter_error,)) - return filter_result_list - - def delete_acls(self, acl_filters): - """Delete a set of ACLs - - Deletes all ACLs matching the list of input ACLFilter - - :param acl_filters: a list of ACLFilter - :return: a list of 3-tuples corresponding to the list of input filters. - The tuples hold (the input ACLFilter, list of affected ACLs, KafkaError instance) - """ - - for acl in acl_filters: - if not isinstance(acl, ACLFilter): - raise IllegalArgumentError("acl_filters must contain ACLFilter type objects") - - version = self._matching_api_version(DeleteAclsRequest) - - if version == 0: - request = DeleteAclsRequest[version]( - filters=[self._convert_delete_acls_resource_request_v0(acl) for acl in acl_filters] - ) - elif version <= 1: - request = DeleteAclsRequest[version]( - filters=[self._convert_delete_acls_resource_request_v1(acl) for acl in acl_filters] - ) - else: - raise NotImplementedError( - "Support for DeleteAcls v{} has not yet been added to KafkaAdmin." 
- .format(version) - ) - - future = self._send_request_to_node(self._client.least_loaded_node(), request) - self._wait_for_futures([future]) - response = future.value - - return self._convert_delete_acls_response_to_matching_acls(acl_filters, response) - - @staticmethod - def _convert_describe_config_resource_request(config_resource): - return ( - config_resource.resource_type, - config_resource.name, - [ - config_key for config_key, config_value in config_resource.configs.items() - ] if config_resource.configs else None - ) - - def describe_configs(self, config_resources, include_synonyms=False): - """Fetch configuration parameters for one or more Kafka resources. - - :param config_resources: An list of ConfigResource objects. - Any keys in ConfigResource.configs dict will be used to filter the - result. Setting the configs dict to None will get all values. An - empty dict will get zero values (as per Kafka protocol). - :param include_synonyms: If True, return synonyms in response. Not - supported by all versions. Default: False. - :return: Appropriate version of DescribeConfigsResponse class. - """ - - # Break up requests by type - a broker config request must be sent to the specific broker. - # All other (currently just topic resources) can be sent to any broker. - broker_resources = [] - topic_resources = [] - - for config_resource in config_resources: - if config_resource.resource_type == ConfigResourceType.BROKER: - broker_resources.append(self._convert_describe_config_resource_request(config_resource)) - else: - topic_resources.append(self._convert_describe_config_resource_request(config_resource)) - - futures = [] - version = self._matching_api_version(DescribeConfigsRequest) - if version == 0: - if include_synonyms: - raise IncompatibleBrokerVersion( - "include_synonyms requires DescribeConfigsRequest >= v1, which is not supported by Kafka {}." 
- .format(self.config['api_version'])) - - if len(broker_resources) > 0: - for broker_resource in broker_resources: - try: - broker_id = int(broker_resource[1]) - except ValueError: - raise ValueError("Broker resource names must be an integer or a string represented integer") - - futures.append(self._send_request_to_node( - broker_id, - DescribeConfigsRequest[version](resources=[broker_resource]) - )) - - if len(topic_resources) > 0: - futures.append(self._send_request_to_node( - self._client.least_loaded_node(), - DescribeConfigsRequest[version](resources=topic_resources) - )) - - elif version <= 2: - if len(broker_resources) > 0: - for broker_resource in broker_resources: - try: - broker_id = int(broker_resource[1]) - except ValueError: - raise ValueError("Broker resource names must be an integer or a string represented integer") - - futures.append(self._send_request_to_node( - broker_id, - DescribeConfigsRequest[version]( - resources=[broker_resource], - include_synonyms=include_synonyms) - )) - - if len(topic_resources) > 0: - futures.append(self._send_request_to_node( - self._client.least_loaded_node(), - DescribeConfigsRequest[version](resources=topic_resources, include_synonyms=include_synonyms) - )) - else: - raise NotImplementedError( - "Support for DescribeConfigs v{} has not yet been added to KafkaAdminClient.".format(version)) - - self._wait_for_futures(futures) - return [f.value for f in futures] - - @staticmethod - def _convert_alter_config_resource_request(config_resource): - return ( - config_resource.resource_type, - config_resource.name, - [ - (config_key, config_value) for config_key, config_value in config_resource.configs.items() - ] - ) - - def alter_configs(self, config_resources): - """Alter configuration parameters of one or more Kafka resources. - - Warning: - This is currently broken for BROKER resources because those must be - sent to that specific broker, versus this always picks the - least-loaded node. See the comment in the source code for details. - We would happily accept a PR fixing this. - - :param config_resources: A list of ConfigResource objects. - :return: Appropriate version of AlterConfigsResponse class. - """ - version = self._matching_api_version(AlterConfigsRequest) - if version <= 1: - request = AlterConfigsRequest[version]( - resources=[self._convert_alter_config_resource_request(config_resource) for config_resource in config_resources] - ) - else: - raise NotImplementedError( - "Support for AlterConfigs v{} has not yet been added to KafkaAdminClient." - .format(version)) - # TODO the Java client has the note: - # // We must make a separate AlterConfigs request for every BROKER resource we want to alter - # // and send the request to that specific broker. Other resources are grouped together into - # // a single request that may be sent to any broker. 
- # - # So this is currently broken as it always sends to the least_loaded_node() - future = self._send_request_to_node(self._client.least_loaded_node(), request) - - self._wait_for_futures([future]) - response = future.value - return response - - # alter replica logs dir protocol not yet implemented - # Note: have to lookup the broker with the replica assignment and send the request to that broker - - # describe log dirs protocol not yet implemented - # Note: have to lookup the broker with the replica assignment and send the request to that broker - - @staticmethod - def _convert_create_partitions_request(topic_name, new_partitions): - return ( - topic_name, - ( - new_partitions.total_count, - new_partitions.new_assignments - ) - ) - - def create_partitions(self, topic_partitions, timeout_ms=None, validate_only=False): - """Create additional partitions for an existing topic. - - :param topic_partitions: A map of topic name strings to NewPartition objects. - :param timeout_ms: Milliseconds to wait for new partitions to be - created before the broker returns. - :param validate_only: If True, don't actually create new partitions. - Default: False - :return: Appropriate version of CreatePartitionsResponse class. - """ - version = self._matching_api_version(CreatePartitionsRequest) - timeout_ms = self._validate_timeout(timeout_ms) - if version <= 1: - request = CreatePartitionsRequest[version]( - topic_partitions=[self._convert_create_partitions_request(topic_name, new_partitions) for topic_name, new_partitions in topic_partitions.items()], - timeout=timeout_ms, - validate_only=validate_only - ) - else: - raise NotImplementedError( - "Support for CreatePartitions v{} has not yet been added to KafkaAdminClient." - .format(version)) - return self._send_request_to_controller(request) - - # delete records protocol not yet implemented - # Note: send the request to the partition leaders - - # create delegation token protocol not yet implemented - # Note: send the request to the least_loaded_node() - - # renew delegation token protocol not yet implemented - # Note: send the request to the least_loaded_node() - - # expire delegation_token protocol not yet implemented - # Note: send the request to the least_loaded_node() - - # describe delegation_token protocol not yet implemented - # Note: send the request to the least_loaded_node() - - def _describe_consumer_groups_send_request(self, group_id, group_coordinator_id, include_authorized_operations=False): - """Send a DescribeGroupsRequest to the group's coordinator. - - :param group_id: The group name as a string - :param group_coordinator_id: The node_id of the groups' coordinator - broker. - :return: A message future. - """ - version = self._matching_api_version(DescribeGroupsRequest) - if version <= 2: - if include_authorized_operations: - raise IncompatibleBrokerVersion( - "include_authorized_operations requests " - "DescribeGroupsRequest >= v3, which is not " - "supported by Kafka {}".format(version) - ) - # Note: KAFKA-6788 A potential optimization is to group the - # request per coordinator and send one request with a list of - # all consumer groups. Java still hasn't implemented this - # because the error checking is hard to get right when some - # groups error and others don't. 
- request = DescribeGroupsRequest[version](groups=(group_id,)) - elif version <= 3: - request = DescribeGroupsRequest[version]( - groups=(group_id,), - include_authorized_operations=include_authorized_operations - ) - else: - raise NotImplementedError( - "Support for DescribeGroupsRequest_v{} has not yet been added to KafkaAdminClient." - .format(version)) - return self._send_request_to_node(group_coordinator_id, request) - - def _describe_consumer_groups_process_response(self, response): - """Process a DescribeGroupsResponse into a group description.""" - if response.API_VERSION <= 3: - assert len(response.groups) == 1 - for response_field, response_name in zip(response.SCHEMA.fields, response.SCHEMA.names): - if isinstance(response_field, Array): - described_groups_field_schema = response_field.array_of - described_group = response.__dict__[response_name][0] - described_group_information_list = [] - protocol_type_is_consumer = False - for (described_group_information, group_information_name, group_information_field) in zip(described_group, described_groups_field_schema.names, described_groups_field_schema.fields): - if group_information_name == 'protocol_type': - protocol_type = described_group_information - protocol_type_is_consumer = (protocol_type == ConsumerProtocol.PROTOCOL_TYPE or not protocol_type) - if isinstance(group_information_field, Array): - member_information_list = [] - member_schema = group_information_field.array_of - for members in described_group_information: - member_information = [] - for (member, member_field, member_name) in zip(members, member_schema.fields, member_schema.names): - if protocol_type_is_consumer: - if member_name == 'member_metadata' and member: - member_information.append(ConsumerProtocolMemberMetadata.decode(member)) - elif member_name == 'member_assignment' and member: - member_information.append(ConsumerProtocolMemberAssignment.decode(member)) - else: - member_information.append(member) - member_info_tuple = MemberInformation._make(member_information) - member_information_list.append(member_info_tuple) - described_group_information_list.append(member_information_list) - else: - described_group_information_list.append(described_group_information) - # Version 3 of the DescribeGroups API introduced the "authorized_operations" field. - # This will cause the namedtuple to fail. - # Therefore, appending a placeholder of None in it. - if response.API_VERSION <=2: - described_group_information_list.append(None) - group_description = GroupInformation._make(described_group_information_list) - error_code = group_description.error_code - error_type = Errors.for_code(error_code) - # Java has the note: KAFKA-6789, we can retry based on the error code - if error_type is not Errors.NoError: - raise error_type( - "DescribeGroupsResponse failed with response '{}'." - .format(response)) - else: - raise NotImplementedError( - "Support for DescribeGroupsResponse_v{} has not yet been added to KafkaAdminClient." - .format(response.API_VERSION)) - return group_description - - def describe_consumer_groups(self, group_ids, group_coordinator_id=None, include_authorized_operations=False): - """Describe a set of consumer groups. - - Any errors are immediately raised. - - :param group_ids: A list of consumer group IDs. These are typically the - group names as strings. - :param group_coordinator_id: The node_id of the groups' coordinator - broker. If set to None, it will query the cluster for each group to - find that group's coordinator. 
Explicitly specifying this can be - useful for avoiding extra network round trips if you already know - the group coordinator. This is only useful when all the group_ids - have the same coordinator, otherwise it will error. Default: None. - :param include_authorized_operations: Whether or not to include - information about the operations a group is allowed to perform. - Only supported on API version >= v3. Default: False. - :return: A list of group descriptions. For now the group descriptions - are the raw results from the DescribeGroupsResponse. Long-term, we - plan to change this to return namedtuples as well as decoding the - partition assignments. - """ - group_descriptions = [] - - if group_coordinator_id is not None: - groups_coordinators = {group_id: group_coordinator_id for group_id in group_ids} - else: - groups_coordinators = self._find_coordinator_ids(group_ids) - - futures = [ - self._describe_consumer_groups_send_request( - group_id, - coordinator_id, - include_authorized_operations) - for group_id, coordinator_id in groups_coordinators.items() - ] - self._wait_for_futures(futures) - - for future in futures: - response = future.value - group_description = self._describe_consumer_groups_process_response(response) - group_descriptions.append(group_description) - - return group_descriptions - - def _list_consumer_groups_send_request(self, broker_id): - """Send a ListGroupsRequest to a broker. - - :param broker_id: The broker's node_id. - :return: A message future - """ - version = self._matching_api_version(ListGroupsRequest) - if version <= 2: - request = ListGroupsRequest[version]() - else: - raise NotImplementedError( - "Support for ListGroupsRequest_v{} has not yet been added to KafkaAdminClient." - .format(version)) - return self._send_request_to_node(broker_id, request) - - def _list_consumer_groups_process_response(self, response): - """Process a ListGroupsResponse into a list of groups.""" - if response.API_VERSION <= 2: - error_type = Errors.for_code(response.error_code) - if error_type is not Errors.NoError: - raise error_type( - "ListGroupsRequest failed with response '{}'." - .format(response)) - else: - raise NotImplementedError( - "Support for ListGroupsResponse_v{} has not yet been added to KafkaAdminClient." - .format(response.API_VERSION)) - return response.groups - - def list_consumer_groups(self, broker_ids=None): - """List all consumer groups known to the cluster. - - This returns a list of Consumer Group tuples. The tuples are - composed of the consumer group name and the consumer group protocol - type. - - Only consumer groups that store their offsets in Kafka are returned. - The protocol type will be an empty string for groups created using - Kafka < 0.9 APIs because, although they store their offsets in Kafka, - they don't use Kafka for group coordination. For groups created using - Kafka >= 0.9, the protocol type will typically be "consumer". - - As soon as any error is encountered, it is immediately raised. - - :param broker_ids: A list of broker node_ids to query for consumer - groups. If set to None, will query all brokers in the cluster. - Explicitly specifying broker(s) can be useful for determining which - consumer groups are coordinated by those broker(s). Default: None - :return list: List of tuples of Consumer Groups. - :exception GroupCoordinatorNotAvailableError: The coordinator is not - available, so cannot process requests. - :exception GroupLoadInProgressError: The coordinator is loading and - hence can't process requests. 
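As a hedged illustration of how the group APIs documented here fit together (not part of the original file): `admin` stands for an already-constructed KafkaAdminClient, and the group name is invented.

# list_consumer_groups() returns (group_id, protocol_type) tuples from every broker.
groups = admin.list_consumer_groups()
consumer_group_ids = [gid for gid, protocol in groups if protocol == "consumer"]

# Describe those groups, then fetch committed offsets for one hypothetical group.
descriptions = admin.describe_consumer_groups(consumer_group_ids)
offsets = admin.list_consumer_group_offsets("example-group")
for tp, om in offsets.items():
    print(tp.topic, tp.partition, om.offset)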
- """ - # While we return a list, internally use a set to prevent duplicates - # because if a group coordinator fails after being queried, and its - # consumer groups move to new brokers that haven't yet been queried, - # then the same group could be returned by multiple brokers. - consumer_groups = set() - if broker_ids is None: - broker_ids = [broker.nodeId for broker in self._client.cluster.brokers()] - futures = [self._list_consumer_groups_send_request(b) for b in broker_ids] - self._wait_for_futures(futures) - for f in futures: - response = f.value - consumer_groups.update(self._list_consumer_groups_process_response(response)) - return list(consumer_groups) - - def _list_consumer_group_offsets_send_request(self, group_id, - group_coordinator_id, partitions=None): - """Send an OffsetFetchRequest to a broker. - - :param group_id: The consumer group id name for which to fetch offsets. - :param group_coordinator_id: The node_id of the group's coordinator - broker. - :return: A message future - """ - version = self._matching_api_version(OffsetFetchRequest) - if version <= 3: - if partitions is None: - if version <= 1: - raise ValueError( - """OffsetFetchRequest_v{} requires specifying the - partitions for which to fetch offsets. Omitting the - partitions is only supported on brokers >= 0.10.2. - For details, see KIP-88.""".format(version)) - topics_partitions = None - else: - # transform from [TopicPartition("t1", 1), TopicPartition("t1", 2)] to [("t1", [1, 2])] - topics_partitions_dict = defaultdict(set) - for topic, partition in partitions: - topics_partitions_dict[topic].add(partition) - topics_partitions = list(six.iteritems(topics_partitions_dict)) - request = OffsetFetchRequest[version](group_id, topics_partitions) - else: - raise NotImplementedError( - "Support for OffsetFetchRequest_v{} has not yet been added to KafkaAdminClient." - .format(version)) - return self._send_request_to_node(group_coordinator_id, request) - - def _list_consumer_group_offsets_process_response(self, response): - """Process an OffsetFetchResponse. - - :param response: an OffsetFetchResponse. - :return: A dictionary composed of TopicPartition keys and - OffsetAndMetadata values. - """ - if response.API_VERSION <= 3: - - # OffsetFetchResponse_v1 lacks a top-level error_code - if response.API_VERSION > 1: - error_type = Errors.for_code(response.error_code) - if error_type is not Errors.NoError: - # optionally we could retry if error_type.retriable - raise error_type( - "OffsetFetchResponse failed with response '{}'." - .format(response)) - - # transform response into a dictionary with TopicPartition keys and - # OffsetAndMetadata values--this is what the Java AdminClient returns - offsets = {} - for topic, partitions in response.topics: - for partition, offset, metadata, error_code in partitions: - error_type = Errors.for_code(error_code) - if error_type is not Errors.NoError: - raise error_type( - "Unable to fetch consumer group offsets for topic {}, partition {}" - .format(topic, partition)) - offsets[TopicPartition(topic, partition)] = OffsetAndMetadata(offset, metadata) - else: - raise NotImplementedError( - "Support for OffsetFetchResponse_v{} has not yet been added to KafkaAdminClient." - .format(response.API_VERSION)) - return offsets - - def list_consumer_group_offsets(self, group_id, group_coordinator_id=None, - partitions=None): - """Fetch Consumer Offsets for a single consumer group. - - Note: - This does not verify that the group_id or partitions actually exist - in the cluster. 
- - As soon as any error is encountered, it is immediately raised. - - :param group_id: The consumer group id name for which to fetch offsets. - :param group_coordinator_id: The node_id of the group's coordinator - broker. If set to None, will query the cluster to find the group - coordinator. Explicitly specifying this can be useful to prevent - that extra network round trip if you already know the group - coordinator. Default: None. - :param partitions: A list of TopicPartitions for which to fetch - offsets. On brokers >= 0.10.2, this can be set to None to fetch all - known offsets for the consumer group. Default: None. - :return dictionary: A dictionary with TopicPartition keys and - OffsetAndMetada values. Partitions that are not specified and for - which the group_id does not have a recorded offset are omitted. An - offset value of `-1` indicates the group_id has no offset for that - TopicPartition. A `-1` can only happen for partitions that are - explicitly specified. - """ - if group_coordinator_id is None: - group_coordinator_id = self._find_coordinator_ids([group_id])[group_id] - future = self._list_consumer_group_offsets_send_request( - group_id, group_coordinator_id, partitions) - self._wait_for_futures([future]) - response = future.value - return self._list_consumer_group_offsets_process_response(response) - - def delete_consumer_groups(self, group_ids, group_coordinator_id=None): - """Delete Consumer Group Offsets for given consumer groups. - - Note: - This does not verify that the group ids actually exist and - group_coordinator_id is the correct coordinator for all these groups. - - The result needs checking for potential errors. - - :param group_ids: The consumer group ids of the groups which are to be deleted. - :param group_coordinator_id: The node_id of the broker which is the coordinator for - all the groups. Use only if all groups are coordinated by the same broker. - If set to None, will query the cluster to find the coordinator for every single group. - Explicitly specifying this can be useful to prevent - that extra network round trips if you already know the group - coordinator. Default: None. - :return: A list of tuples (group_id, KafkaError) - """ - if group_coordinator_id is not None: - futures = [self._delete_consumer_groups_send_request(group_ids, group_coordinator_id)] - else: - coordinators_groups = defaultdict(list) - for group_id, coordinator_id in self._find_coordinator_ids(group_ids).items(): - coordinators_groups[coordinator_id].append(group_id) - futures = [ - self._delete_consumer_groups_send_request(group_ids, coordinator_id) - for coordinator_id, group_ids in coordinators_groups.items() - ] - - self._wait_for_futures(futures) - - results = [] - for f in futures: - results.extend(self._convert_delete_groups_response(f.value)) - return results - - def _convert_delete_groups_response(self, response): - if response.API_VERSION <= 1: - results = [] - for group_id, error_code in response.results: - results.append((group_id, Errors.for_code(error_code))) - return results - else: - raise NotImplementedError( - "Support for DeleteGroupsResponse_v{} has not yet been added to KafkaAdminClient." - .format(response.API_VERSION)) - - def _delete_consumer_groups_send_request(self, group_ids, group_coordinator_id): - """Send a DeleteGroups request to a broker. - - :param group_ids: The consumer group ids of the groups which are to be deleted. - :param group_coordinator_id: The node_id of the broker which is the coordinator for - all the groups. 
- :return: A message future - """ - version = self._matching_api_version(DeleteGroupsRequest) - if version <= 1: - request = DeleteGroupsRequest[version](group_ids) - else: - raise NotImplementedError( - "Support for DeleteGroupsRequest_v{} has not yet been added to KafkaAdminClient." - .format(version)) - return self._send_request_to_node(group_coordinator_id, request) - - def _wait_for_futures(self, futures): - while not all(future.succeeded() for future in futures): - for future in futures: - self._client.poll(future=future) - - if future.failed(): - raise future.exception # pylint: disable-msg=raising-bad-type diff --git a/kafka/admin/new_partitions.py b/kafka/admin/new_partitions.py deleted file mode 100644 index 429b2e19..00000000 --- a/kafka/admin/new_partitions.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import absolute_import - - -class NewPartitions(object): - """A class for new partition creation on existing topics. Note that the length of new_assignments, if specified, - must be the difference between the new total number of partitions and the existing number of partitions. - Arguments: - total_count (int): the total number of partitions that should exist on the topic - new_assignments ([[int]]): an array of arrays of replica assignments for new partitions. - If not set, broker assigns replicas per an internal algorithm. - """ - - def __init__( - self, - total_count, - new_assignments=None - ): - self.total_count = total_count - self.new_assignments = new_assignments diff --git a/kafka/client_async.py b/kafka/client_async.py deleted file mode 100644 index 58f22d4e..00000000 --- a/kafka/client_async.py +++ /dev/null @@ -1,1077 +0,0 @@ -from __future__ import absolute_import, division - -import collections -import copy -import logging -import random -import socket -import threading -import time -import weakref - -# selectors in stdlib as of py3.4 -try: - import selectors # pylint: disable=import-error -except ImportError: - # vendored backport module - from kafka.vendor import selectors34 as selectors - -from kafka.vendor import six - -from kafka.cluster import ClusterMetadata -from kafka.conn import BrokerConnection, ConnectionStates, collect_hosts, get_ip_port_afi -from kafka import errors as Errors -from kafka.future import Future -from kafka.metrics import AnonMeasurable -from kafka.metrics.stats import Avg, Count, Rate -from kafka.metrics.stats.rate import TimeUnit -from kafka.protocol.metadata import MetadataRequest -from kafka.util import Dict, WeakMethod -# Although this looks unused, it actually monkey-patches socket.socketpair() -# and should be left in as long as we're using socket.socketpair() in this file -from kafka.vendor import socketpair -from kafka.version import __version__ - -if six.PY2: - ConnectionError = None - - -log = logging.getLogger('kafka.client') - - -class KafkaClient(object): - """ - A network client for asynchronous request/response network I/O. - - This is an internal class used to implement the user-facing producer and - consumer clients. - - This class is not thread-safe! - - Attributes: - cluster (:any:`ClusterMetadata`): Local cache of cluster metadata, retrieved - via MetadataRequests during :meth:`~kafka.KafkaClient.poll`. - - Keyword Arguments: - bootstrap_servers: 'host[:port]' string (or list of 'host[:port]' - strings) that the client should contact to bootstrap initial - cluster metadata. This does not have to be the full node list. - It just needs to have at least one broker that will respond to a - Metadata API Request. 
Default port is 9092. If no servers are - specified, will default to localhost:9092. - client_id (str): a name for this client. This string is passed in - each request to servers and can be used to identify specific - server-side log entries that correspond to this client. Also - submitted to GroupCoordinator for logging with respect to - consumer group administration. Default: 'kafka-python-{version}' - reconnect_backoff_ms (int): The amount of time in milliseconds to - wait before attempting to reconnect to a given host. - Default: 50. - reconnect_backoff_max_ms (int): The maximum amount of time in - milliseconds to backoff/wait when reconnecting to a broker that has - repeatedly failed to connect. If provided, the backoff per host - will increase exponentially for each consecutive connection - failure, up to this maximum. Once the maximum is reached, - reconnection attempts will continue periodically with this fixed - rate. To avoid connection storms, a randomization factor of 0.2 - will be applied to the backoff resulting in a random range between - 20% below and 20% above the computed value. Default: 1000. - request_timeout_ms (int): Client request timeout in milliseconds. - Default: 30000. - connections_max_idle_ms: Close idle connections after the number of - milliseconds specified by this config. The broker closes idle - connections after connections.max.idle.ms, so this avoids hitting - unexpected socket disconnected errors on the client. - Default: 540000 - retry_backoff_ms (int): Milliseconds to backoff when retrying on - errors. Default: 100. - max_in_flight_requests_per_connection (int): Requests are pipelined - to kafka brokers up to this number of maximum requests per - broker connection. Default: 5. - receive_buffer_bytes (int): The size of the TCP receive buffer - (SO_RCVBUF) to use when reading data. Default: None (relies on - system defaults). Java client defaults to 32768. - send_buffer_bytes (int): The size of the TCP send buffer - (SO_SNDBUF) to use when sending data. Default: None (relies on - system defaults). Java client defaults to 131072. - socket_options (list): List of tuple-arguments to socket.setsockopt - to apply to broker connection sockets. Default: - [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)] - metadata_max_age_ms (int): The period of time in milliseconds after - which we force a refresh of metadata even if we haven't seen any - partition leadership changes to proactively discover any new - brokers or partitions. Default: 300000 - security_protocol (str): Protocol used to communicate with brokers. - Valid values are: PLAINTEXT, SSL, SASL_PLAINTEXT, SASL_SSL. - Default: PLAINTEXT. - ssl_context (ssl.SSLContext): Pre-configured SSLContext for wrapping - socket connections. If provided, all other ssl_* configurations - will be ignored. Default: None. - ssl_check_hostname (bool): Flag to configure whether SSL handshake - should verify that the certificate matches the broker's hostname. - Default: True. - ssl_cafile (str): Optional filename of CA file to use in certificate - verification. Default: None. - ssl_certfile (str): Optional filename of file in PEM format containing - the client certificate, as well as any CA certificates needed to - establish the certificate's authenticity. Default: None. - ssl_keyfile (str): Optional filename containing the client private key. - Default: None. - ssl_password (str): Optional password to be used when loading the - certificate chain. Default: None. 
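For the ssl_* options above, passing a pre-built context is often simpler; this sketch is illustrative only (file paths and passphrase are placeholders) and relies on the documented behaviour that a supplied ssl_context makes the other ssl_* settings be ignored.

import ssl

# Stdlib equivalent of ssl_cafile / ssl_certfile / ssl_keyfile / ssl_password,
# bundled into one context that can be passed as ssl_context=ctx.
ctx = ssl.create_default_context(cafile="/etc/kafka/ca.pem")
ctx.load_cert_chain(
    certfile="/etc/kafka/client.pem",
    keyfile="/etc/kafka/client.key",
    password="key-passphrase",
)
ctx.check_hostname = True  # same intent as ssl_check_hostname=True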
- ssl_crlfile (str): Optional filename containing the CRL to check for - certificate expiration. By default, no CRL check is done. When - providing a file, only the leaf certificate will be checked against - this CRL. The CRL can only be checked with Python 3.4+ or 2.7.9+. - Default: None. - ssl_ciphers (str): optionally set the available ciphers for ssl - connections. It should be a string in the OpenSSL cipher list - format. If no cipher can be selected (because compile-time options - or other configuration forbids use of all the specified ciphers), - an ssl.SSLError will be raised. See ssl.SSLContext.set_ciphers - api_version (tuple): Specify which Kafka API version to use. If set - to None, KafkaClient will attempt to infer the broker version by - probing various APIs. Example: (0, 10, 2). Default: None - api_version_auto_timeout_ms (int): number of milliseconds to throw a - timeout exception from the constructor when checking the broker - api version. Only applies if api_version is None - selector (selectors.BaseSelector): Provide a specific selector - implementation to use for I/O multiplexing. - Default: selectors.DefaultSelector - metrics (kafka.metrics.Metrics): Optionally provide a metrics - instance for capturing network IO stats. Default: None. - metric_group_prefix (str): Prefix for metric names. Default: '' - sasl_mechanism (str): Authentication mechanism when security_protocol - is configured for SASL_PLAINTEXT or SASL_SSL. Valid values are: - PLAIN, GSSAPI, OAUTHBEARER, SCRAM-SHA-256, SCRAM-SHA-512. - sasl_plain_username (str): username for sasl PLAIN and SCRAM authentication. - Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. - sasl_plain_password (str): password for sasl PLAIN and SCRAM authentication. - Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. - sasl_kerberos_service_name (str): Service name to include in GSSAPI - sasl mechanism handshake. Default: 'kafka' - sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI - sasl mechanism handshake. Default: one of bootstrap servers - sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider - instance. (See kafka.oauth.abstract). 
Default: None - """ - - DEFAULT_CONFIG = { - 'bootstrap_servers': 'localhost', - 'bootstrap_topics_filter': set(), - 'client_id': 'kafka-python-' + __version__, - 'request_timeout_ms': 30000, - 'wakeup_timeout_ms': 3000, - 'connections_max_idle_ms': 9 * 60 * 1000, - 'reconnect_backoff_ms': 50, - 'reconnect_backoff_max_ms': 1000, - 'max_in_flight_requests_per_connection': 5, - 'receive_buffer_bytes': None, - 'send_buffer_bytes': None, - 'socket_options': [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)], - 'sock_chunk_bytes': 4096, # undocumented experimental option - 'sock_chunk_buffer_count': 1000, # undocumented experimental option - 'retry_backoff_ms': 100, - 'metadata_max_age_ms': 300000, - 'security_protocol': 'PLAINTEXT', - 'ssl_context': None, - 'ssl_check_hostname': True, - 'ssl_cafile': None, - 'ssl_certfile': None, - 'ssl_keyfile': None, - 'ssl_password': None, - 'ssl_crlfile': None, - 'ssl_ciphers': None, - 'api_version': None, - 'api_version_auto_timeout_ms': 2000, - 'selector': selectors.DefaultSelector, - 'metrics': None, - 'metric_group_prefix': '', - 'sasl_mechanism': None, - 'sasl_plain_username': None, - 'sasl_plain_password': None, - 'sasl_kerberos_service_name': 'kafka', - 'sasl_kerberos_domain_name': None, - 'sasl_oauth_token_provider': None - } - - def __init__(self, **configs): - self.config = copy.copy(self.DEFAULT_CONFIG) - for key in self.config: - if key in configs: - self.config[key] = configs[key] - - # these properties need to be set on top of the initialization pipeline - # because they are used when __del__ method is called - self._closed = False - self._wake_r, self._wake_w = socket.socketpair() - self._selector = self.config['selector']() - - self.cluster = ClusterMetadata(**self.config) - self._topics = set() # empty set will fetch all topic metadata - self._metadata_refresh_in_progress = False - self._conns = Dict() # object to support weakrefs - self._api_versions = None - self._connecting = set() - self._sending = set() - self._refresh_on_disconnects = True - self._last_bootstrap = 0 - self._bootstrap_fails = 0 - self._wake_r.setblocking(False) - self._wake_w.settimeout(self.config['wakeup_timeout_ms'] / 1000.0) - self._wake_lock = threading.Lock() - - self._lock = threading.RLock() - - # when requests complete, they are transferred to this queue prior to - # invocation. The purpose is to avoid invoking them while holding the - # lock above. 
- self._pending_completion = collections.deque() - - self._selector.register(self._wake_r, selectors.EVENT_READ) - self._idle_expiry_manager = IdleConnectionManager(self.config['connections_max_idle_ms']) - self._sensors = None - if self.config['metrics']: - self._sensors = KafkaClientMetrics(self.config['metrics'], - self.config['metric_group_prefix'], - weakref.proxy(self._conns)) - - self._num_bootstrap_hosts = len(collect_hosts(self.config['bootstrap_servers'])) - - # Check Broker Version if not set explicitly - if self.config['api_version'] is None: - check_timeout = self.config['api_version_auto_timeout_ms'] / 1000 - self.config['api_version'] = self.check_version(timeout=check_timeout) - - def _can_bootstrap(self): - effective_failures = self._bootstrap_fails // self._num_bootstrap_hosts - backoff_factor = 2 ** effective_failures - backoff_ms = min(self.config['reconnect_backoff_ms'] * backoff_factor, - self.config['reconnect_backoff_max_ms']) - - backoff_ms *= random.uniform(0.8, 1.2) - - next_at = self._last_bootstrap + backoff_ms / 1000.0 - now = time.time() - if next_at > now: - return False - return True - - def _can_connect(self, node_id): - if node_id not in self._conns: - if self.cluster.broker_metadata(node_id): - return True - return False - conn = self._conns[node_id] - return conn.disconnected() and not conn.blacked_out() - - def _conn_state_change(self, node_id, sock, conn): - with self._lock: - if conn.connecting(): - # SSL connections can enter this state 2x (second during Handshake) - if node_id not in self._connecting: - self._connecting.add(node_id) - try: - self._selector.register(sock, selectors.EVENT_WRITE, conn) - except KeyError: - self._selector.modify(sock, selectors.EVENT_WRITE, conn) - - if self.cluster.is_bootstrap(node_id): - self._last_bootstrap = time.time() - - elif conn.connected(): - log.debug("Node %s connected", node_id) - if node_id in self._connecting: - self._connecting.remove(node_id) - - try: - self._selector.modify(sock, selectors.EVENT_READ, conn) - except KeyError: - self._selector.register(sock, selectors.EVENT_READ, conn) - - if self._sensors: - self._sensors.connection_created.record() - - self._idle_expiry_manager.update(node_id) - - if self.cluster.is_bootstrap(node_id): - self._bootstrap_fails = 0 - - else: - for node_id in list(self._conns.keys()): - if self.cluster.is_bootstrap(node_id): - self._conns.pop(node_id).close() - - # Connection failures imply that our metadata is stale, so let's refresh - elif conn.state is ConnectionStates.DISCONNECTED: - if node_id in self._connecting: - self._connecting.remove(node_id) - try: - self._selector.unregister(sock) - except KeyError: - pass - - if self._sensors: - self._sensors.connection_closed.record() - - idle_disconnect = False - if self._idle_expiry_manager.is_expired(node_id): - idle_disconnect = True - self._idle_expiry_manager.remove(node_id) - - # If the connection has already by popped from self._conns, - # we can assume the disconnect was intentional and not a failure - if node_id not in self._conns: - pass - - elif self.cluster.is_bootstrap(node_id): - self._bootstrap_fails += 1 - - elif self._refresh_on_disconnects and not self._closed and not idle_disconnect: - log.warning("Node %s connection failed -- refreshing metadata", node_id) - self.cluster.request_update() - - def maybe_connect(self, node_id, wakeup=True): - """Queues a node for asynchronous connection during the next .poll()""" - if self._can_connect(node_id): - self._connecting.add(node_id) - # Wakeup signal is 
useful in case another thread is - # blocked waiting for incoming network traffic while holding - # the client lock in poll(). - if wakeup: - self.wakeup() - return True - return False - - def _should_recycle_connection(self, conn): - # Never recycle unless disconnected - if not conn.disconnected(): - return False - - # Otherwise, only recycle when broker metadata has changed - broker = self.cluster.broker_metadata(conn.node_id) - if broker is None: - return False - - host, _, afi = get_ip_port_afi(broker.host) - if conn.host != host or conn.port != broker.port: - log.info("Broker metadata change detected for node %s" - " from %s:%s to %s:%s", conn.node_id, conn.host, conn.port, - broker.host, broker.port) - return True - - return False - - def _maybe_connect(self, node_id): - """Idempotent non-blocking connection attempt to the given node id.""" - with self._lock: - conn = self._conns.get(node_id) - - if conn is None: - broker = self.cluster.broker_metadata(node_id) - assert broker, 'Broker id %s not in current metadata' % (node_id,) - - log.debug("Initiating connection to node %s at %s:%s", - node_id, broker.host, broker.port) - host, port, afi = get_ip_port_afi(broker.host) - cb = WeakMethod(self._conn_state_change) - conn = BrokerConnection(host, broker.port, afi, - state_change_callback=cb, - node_id=node_id, - **self.config) - self._conns[node_id] = conn - - # Check if existing connection should be recreated because host/port changed - elif self._should_recycle_connection(conn): - self._conns.pop(node_id) - return False - - elif conn.connected(): - return True - - conn.connect() - return conn.connected() - - def ready(self, node_id, metadata_priority=True): - """Check whether a node is connected and ok to send more requests. - - Arguments: - node_id (int): the id of the node to check - metadata_priority (bool): Mark node as not-ready if a metadata - refresh is required. Default: True - - Returns: - bool: True if we are ready to send to the given node - """ - self.maybe_connect(node_id) - return self.is_ready(node_id, metadata_priority=metadata_priority) - - def connected(self, node_id): - """Return True iff the node_id is connected.""" - conn = self._conns.get(node_id) - if conn is None: - return False - return conn.connected() - - def _close(self): - if not self._closed: - self._closed = True - self._wake_r.close() - self._wake_w.close() - self._selector.close() - - def close(self, node_id=None): - """Close one or all broker connections. - - Arguments: - node_id (int, optional): the id of the node to close - """ - with self._lock: - if node_id is None: - self._close() - conns = list(self._conns.values()) - self._conns.clear() - for conn in conns: - conn.close() - elif node_id in self._conns: - self._conns.pop(node_id).close() - else: - log.warning("Node %s not found in current connection list; skipping", node_id) - return - - def __del__(self): - self._close() - - def is_disconnected(self, node_id): - """Check whether the node connection has been disconnected or failed. - - A disconnected node has either been closed or has failed. Connection - failures are usually transient and can be resumed in the next ready() - call, but there are cases where transient failures need to be caught - and re-acted upon. 
- - Arguments: - node_id (int): the id of the node to check - - Returns: - bool: True iff the node exists and is disconnected - """ - conn = self._conns.get(node_id) - if conn is None: - return False - return conn.disconnected() - - def connection_delay(self, node_id): - """ - Return the number of milliseconds to wait, based on the connection - state, before attempting to send data. When disconnected, this respects - the reconnect backoff time. When connecting, returns 0 to allow - non-blocking connect to finish. When connected, returns a very large - number to handle slow/stalled connections. - - Arguments: - node_id (int): The id of the node to check - - Returns: - int: The number of milliseconds to wait. - """ - conn = self._conns.get(node_id) - if conn is None: - return 0 - return conn.connection_delay() - - def is_ready(self, node_id, metadata_priority=True): - """Check whether a node is ready to send more requests. - - In addition to connection-level checks, this method also is used to - block additional requests from being sent during a metadata refresh. - - Arguments: - node_id (int): id of the node to check - metadata_priority (bool): Mark node as not-ready if a metadata - refresh is required. Default: True - - Returns: - bool: True if the node is ready and metadata is not refreshing - """ - if not self._can_send_request(node_id): - return False - - # if we need to update our metadata now declare all requests unready to - # make metadata requests first priority - if metadata_priority: - if self._metadata_refresh_in_progress: - return False - if self.cluster.ttl() == 0: - return False - return True - - def _can_send_request(self, node_id): - conn = self._conns.get(node_id) - if not conn: - return False - return conn.connected() and conn.can_send_more() - - def send(self, node_id, request, wakeup=True): - """Send a request to a specific node. Bytes are placed on an - internal per-connection send-queue. Actual network I/O will be - triggered in a subsequent call to .poll() - - Arguments: - node_id (int): destination node - request (Struct): request object (not-encoded) - wakeup (bool): optional flag to disable thread-wakeup - - Raises: - AssertionError: if node_id is not in current cluster metadata - - Returns: - Future: resolves to Response struct or Error - """ - conn = self._conns.get(node_id) - if not conn or not self._can_send_request(node_id): - self.maybe_connect(node_id, wakeup=wakeup) - return Future().failure(Errors.NodeNotReadyError(node_id)) - - # conn.send will queue the request internally - # we will need to call send_pending_requests() - # to trigger network I/O - future = conn.send(request, blocking=False) - self._sending.add(conn) - - # Wakeup signal is useful in case another thread is - # blocked waiting for incoming network traffic while holding - # the client lock in poll(). - if wakeup: - self.wakeup() - - return future - - def poll(self, timeout_ms=None, future=None): - """Try to read and write to sockets. - - This method will also attempt to complete node connections, refresh - stale metadata, and run previously-scheduled tasks. - - Arguments: - timeout_ms (int, optional): maximum amount of time to wait (in ms) - for at least one response. Must be non-negative. The actual - timeout will be the minimum of timeout, request timeout and - metadata timeout. 
Default: request_timeout_ms - future (Future, optional): if provided, blocks until future.is_done - - Returns: - list: responses received (can be empty) - """ - if future is not None: - timeout_ms = 100 - elif timeout_ms is None: - timeout_ms = self.config['request_timeout_ms'] - elif not isinstance(timeout_ms, (int, float)): - raise TypeError('Invalid type for timeout: %s' % type(timeout_ms)) - - # Loop for futures, break after first loop if None - responses = [] - while True: - with self._lock: - if self._closed: - break - - # Attempt to complete pending connections - for node_id in list(self._connecting): - self._maybe_connect(node_id) - - # Send a metadata request if needed - metadata_timeout_ms = self._maybe_refresh_metadata() - - # If we got a future that is already done, don't block in _poll - if future is not None and future.is_done: - timeout = 0 - else: - idle_connection_timeout_ms = self._idle_expiry_manager.next_check_ms() - timeout = min( - timeout_ms, - metadata_timeout_ms, - idle_connection_timeout_ms, - self.config['request_timeout_ms']) - # if there are no requests in flight, do not block longer than the retry backoff - if self.in_flight_request_count() == 0: - timeout = min(timeout, self.config['retry_backoff_ms']) - timeout = max(0, timeout) # avoid negative timeouts - - self._poll(timeout / 1000) - - # called without the lock to avoid deadlock potential - # if handlers need to acquire locks - responses.extend(self._fire_pending_completed_requests()) - - # If all we had was a timeout (future is None) - only do one poll - # If we do have a future, we keep looping until it is done - if future is None or future.is_done: - break - - return responses - - def _register_send_sockets(self): - while self._sending: - conn = self._sending.pop() - try: - key = self._selector.get_key(conn._sock) - events = key.events | selectors.EVENT_WRITE - self._selector.modify(key.fileobj, events, key.data) - except KeyError: - self._selector.register(conn._sock, selectors.EVENT_WRITE, conn) - - def _poll(self, timeout): - # This needs to be locked, but since it is only called from within the - # locked section of poll(), there is no additional lock acquisition here - processed = set() - - # Send pending requests first, before polling for responses - self._register_send_sockets() - - start_select = time.time() - ready = self._selector.select(timeout) - end_select = time.time() - if self._sensors: - self._sensors.select_time.record((end_select - start_select) * 1000000000) - - for key, events in ready: - if key.fileobj is self._wake_r: - self._clear_wake_fd() - continue - - # Send pending requests if socket is ready to write - if events & selectors.EVENT_WRITE: - conn = key.data - if conn.connecting(): - conn.connect() - else: - if conn.send_pending_requests_v2(): - # If send is complete, we dont need to track write readiness - # for this socket anymore - if key.events ^ selectors.EVENT_WRITE: - self._selector.modify( - key.fileobj, - key.events ^ selectors.EVENT_WRITE, - key.data) - else: - self._selector.unregister(key.fileobj) - - if not (events & selectors.EVENT_READ): - continue - conn = key.data - processed.add(conn) - - if not conn.in_flight_requests: - # if we got an EVENT_READ but there were no in-flight requests, one of - # two things has happened: - # - # 1. The remote end closed the connection (because it died, or because - # a firewall timed out, or whatever) - # 2. The protocol is out of sync. 
- # - # either way, we can no longer safely use this connection - # - # Do a 1-byte read to check protocol didnt get out of sync, and then close the conn - try: - unexpected_data = key.fileobj.recv(1) - if unexpected_data: # anything other than a 0-byte read means protocol issues - log.warning('Protocol out of sync on %r, closing', conn) - except socket.error: - pass - conn.close(Errors.KafkaConnectionError('Socket EVENT_READ without in-flight-requests')) - continue - - self._idle_expiry_manager.update(conn.node_id) - self._pending_completion.extend(conn.recv()) - - # Check for additional pending SSL bytes - if self.config['security_protocol'] in ('SSL', 'SASL_SSL'): - # TODO: optimize - for conn in self._conns.values(): - if conn not in processed and conn.connected() and conn._sock.pending(): - self._pending_completion.extend(conn.recv()) - - for conn in six.itervalues(self._conns): - if conn.requests_timed_out(): - log.warning('%s timed out after %s ms. Closing connection.', - conn, conn.config['request_timeout_ms']) - conn.close(error=Errors.RequestTimedOutError( - 'Request timed out after %s ms' % - conn.config['request_timeout_ms'])) - - if self._sensors: - self._sensors.io_time.record((time.time() - end_select) * 1000000000) - - self._maybe_close_oldest_connection() - - def in_flight_request_count(self, node_id=None): - """Get the number of in-flight requests for a node or all nodes. - - Arguments: - node_id (int, optional): a specific node to check. If unspecified, - return the total for all nodes - - Returns: - int: pending in-flight requests for the node, or all nodes if None - """ - if node_id is not None: - conn = self._conns.get(node_id) - if conn is None: - return 0 - return len(conn.in_flight_requests) - else: - return sum([len(conn.in_flight_requests) - for conn in list(self._conns.values())]) - - def _fire_pending_completed_requests(self): - responses = [] - while True: - try: - # We rely on deque.popleft remaining threadsafe - # to allow both the heartbeat thread and the main thread - # to process responses - response, future = self._pending_completion.popleft() - except IndexError: - break - future.success(response) - responses.append(response) - return responses - - def least_loaded_node(self): - """Choose the node with fewest outstanding requests, with fallbacks. - - This method will prefer a node with an existing connection and no - in-flight-requests. If no such node is found, a node will be chosen - randomly from disconnected nodes that are not "blacked out" (i.e., - are not subject to a reconnect backoff). If no node metadata has been - obtained, will return a bootstrap node (subject to exponential backoff). - - Returns: - node_id or None if no suitable node was found - """ - nodes = [broker.nodeId for broker in self.cluster.brokers()] - random.shuffle(nodes) - - inflight = float('inf') - found = None - for node_id in nodes: - conn = self._conns.get(node_id) - connected = conn is not None and conn.connected() - blacked_out = conn is not None and conn.blacked_out() - curr_inflight = len(conn.in_flight_requests) if conn is not None else 0 - if connected and curr_inflight == 0: - # if we find an established connection - # with no in-flight requests, we can stop right away - return node_id - elif not blacked_out and curr_inflight < inflight: - # otherwise if this is the best we have found so far, record that - inflight = curr_inflight - found = node_id - - return found - - def set_topics(self, topics): - """Set specific topics to track for metadata. 
- - Arguments: - topics (list of str): topics to check for metadata - - Returns: - Future: resolves after metadata request/response - """ - if set(topics).difference(self._topics): - future = self.cluster.request_update() - else: - future = Future().success(set(topics)) - self._topics = set(topics) - return future - - def add_topic(self, topic): - """Add a topic to the list of topics tracked via metadata. - - Arguments: - topic (str): topic to track - - Returns: - Future: resolves after metadata request/response - """ - if topic in self._topics: - return Future().success(set(self._topics)) - - self._topics.add(topic) - return self.cluster.request_update() - - # This method should be locked when running multi-threaded - def _maybe_refresh_metadata(self, wakeup=False): - """Send a metadata request if needed. - - Returns: - int: milliseconds until next refresh - """ - ttl = self.cluster.ttl() - wait_for_in_progress_ms = self.config['request_timeout_ms'] if self._metadata_refresh_in_progress else 0 - metadata_timeout = max(ttl, wait_for_in_progress_ms) - - if metadata_timeout > 0: - return metadata_timeout - - # Beware that the behavior of this method and the computation of - # timeouts for poll() are highly dependent on the behavior of - # least_loaded_node() - node_id = self.least_loaded_node() - if node_id is None: - log.debug("Give up sending metadata request since no node is available"); - return self.config['reconnect_backoff_ms'] - - if self._can_send_request(node_id): - topics = list(self._topics) - if not topics and self.cluster.is_bootstrap(node_id): - topics = list(self.config['bootstrap_topics_filter']) - - if self.cluster.need_all_topic_metadata or not topics: - topics = [] if self.config['api_version'] < (0, 10) else None - api_version = 0 if self.config['api_version'] < (0, 10) else 1 - request = MetadataRequest[api_version](topics) - log.debug("Sending metadata request %s to node %s", request, node_id) - future = self.send(node_id, request, wakeup=wakeup) - future.add_callback(self.cluster.update_metadata) - future.add_errback(self.cluster.failed_update) - - self._metadata_refresh_in_progress = True - def refresh_done(val_or_error): - self._metadata_refresh_in_progress = False - future.add_callback(refresh_done) - future.add_errback(refresh_done) - return self.config['request_timeout_ms'] - - # If there's any connection establishment underway, wait until it completes. This prevents - # the client from unnecessarily connecting to additional nodes while a previous connection - # attempt has not been completed. - if self._connecting: - return self.config['reconnect_backoff_ms'] - - if self.maybe_connect(node_id, wakeup=wakeup): - log.debug("Initializing connection to node %s for metadata request", node_id) - return self.config['reconnect_backoff_ms'] - - # connected but can't send more, OR connecting - # In either case we just need to wait for a network event - # to let us know the selected connection might be usable again. - return float('inf') - - def get_api_versions(self): - """Return the ApiVersions map, if available. - - Note: A call to check_version must previously have succeeded and returned - version 0.10.0 or later - - Returns: a map of dict mapping {api_key : (min_version, max_version)}, - or None if ApiVersion is not supported by the kafka cluster. - """ - return self._api_versions - - def check_version(self, node_id=None, timeout=2, strict=False): - """Attempt to guess the version of a Kafka broker. 
- - Note: It is possible that this method blocks longer than the - specified timeout. This can happen if the entire cluster - is down and the client enters a bootstrap backoff sleep. - This is only possible if node_id is None. - - Returns: version tuple, i.e. (0, 10), (0, 9), (0, 8, 2), ... - - Raises: - NodeNotReadyError (if node_id is provided) - NoBrokersAvailable (if node_id is None) - UnrecognizedBrokerVersion: please file bug if seen! - AssertionError (if strict=True): please file bug if seen! - """ - self._lock.acquire() - end = time.time() + timeout - while time.time() < end: - - # It is possible that least_loaded_node falls back to bootstrap, - # which can block for an increasing backoff period - try_node = node_id or self.least_loaded_node() - if try_node is None: - self._lock.release() - raise Errors.NoBrokersAvailable() - self._maybe_connect(try_node) - conn = self._conns[try_node] - - # We will intentionally cause socket failures - # These should not trigger metadata refresh - self._refresh_on_disconnects = False - try: - remaining = end - time.time() - version = conn.check_version(timeout=remaining, strict=strict, topics=list(self.config['bootstrap_topics_filter'])) - if version >= (0, 10, 0): - # cache the api versions map if it's available (starting - # in 0.10 cluster version) - self._api_versions = conn.get_api_versions() - self._lock.release() - return version - except Errors.NodeNotReadyError: - # Only raise to user if this is a node-specific request - if node_id is not None: - self._lock.release() - raise - finally: - self._refresh_on_disconnects = True - - # Timeout - else: - self._lock.release() - raise Errors.NoBrokersAvailable() - - def wakeup(self): - with self._wake_lock: - try: - self._wake_w.sendall(b'x') - except socket.timeout: - log.warning('Timeout to send to wakeup socket!') - raise Errors.KafkaTimeoutError() - except socket.error: - log.warning('Unable to send to wakeup socket!') - - def _clear_wake_fd(self): - # reading from wake socket should only happen in a single thread - while True: - try: - self._wake_r.recv(1024) - except socket.error: - break - - def _maybe_close_oldest_connection(self): - expired_connection = self._idle_expiry_manager.poll_expired_connection() - if expired_connection: - conn_id, ts = expired_connection - idle_ms = (time.time() - ts) * 1000 - log.info('Closing idle connection %s, last active %d ms ago', conn_id, idle_ms) - self.close(node_id=conn_id) - - def bootstrap_connected(self): - """Return True if a bootstrap node is connected""" - for node_id in self._conns: - if not self.cluster.is_bootstrap(node_id): - continue - if self._conns[node_id].connected(): - return True - else: - return False - - -# OrderedDict requires python2.7+ -try: - from collections import OrderedDict -except ImportError: - # If we dont have OrderedDict, we'll fallback to dict with O(n) priority reads - OrderedDict = dict - - -class IdleConnectionManager(object): - def __init__(self, connections_max_idle_ms): - if connections_max_idle_ms > 0: - self.connections_max_idle = connections_max_idle_ms / 1000 - else: - self.connections_max_idle = float('inf') - self.next_idle_close_check_time = None - self.update_next_idle_close_check_time(time.time()) - self.lru_connections = OrderedDict() - - def update(self, conn_id): - # order should reflect last-update - if conn_id in self.lru_connections: - del self.lru_connections[conn_id] - self.lru_connections[conn_id] = time.time() - - def remove(self, conn_id): - if conn_id in self.lru_connections: - del 
self.lru_connections[conn_id] - - def is_expired(self, conn_id): - if conn_id not in self.lru_connections: - return None - return time.time() >= self.lru_connections[conn_id] + self.connections_max_idle - - def next_check_ms(self): - now = time.time() - if not self.lru_connections: - return float('inf') - elif self.next_idle_close_check_time <= now: - return 0 - else: - return int((self.next_idle_close_check_time - now) * 1000) - - def update_next_idle_close_check_time(self, ts): - self.next_idle_close_check_time = ts + self.connections_max_idle - - def poll_expired_connection(self): - if time.time() < self.next_idle_close_check_time: - return None - - if not len(self.lru_connections): - return None - - oldest_conn_id = None - oldest_ts = None - if OrderedDict is dict: - for conn_id, ts in self.lru_connections.items(): - if oldest_conn_id is None or ts < oldest_ts: - oldest_conn_id = conn_id - oldest_ts = ts - else: - (oldest_conn_id, oldest_ts) = next(iter(self.lru_connections.items())) - - self.update_next_idle_close_check_time(oldest_ts) - - if time.time() >= oldest_ts + self.connections_max_idle: - return (oldest_conn_id, oldest_ts) - else: - return None - - -class KafkaClientMetrics(object): - def __init__(self, metrics, metric_group_prefix, conns): - self.metrics = metrics - self.metric_group_name = metric_group_prefix + '-metrics' - - self.connection_closed = metrics.sensor('connections-closed') - self.connection_closed.add(metrics.metric_name( - 'connection-close-rate', self.metric_group_name, - 'Connections closed per second in the window.'), Rate()) - self.connection_created = metrics.sensor('connections-created') - self.connection_created.add(metrics.metric_name( - 'connection-creation-rate', self.metric_group_name, - 'New connections established per second in the window.'), Rate()) - - self.select_time = metrics.sensor('select-time') - self.select_time.add(metrics.metric_name( - 'select-rate', self.metric_group_name, - 'Number of times the I/O layer checked for new I/O to perform per' - ' second'), Rate(sampled_stat=Count())) - self.select_time.add(metrics.metric_name( - 'io-wait-time-ns-avg', self.metric_group_name, - 'The average length of time the I/O thread spent waiting for a' - ' socket ready for reads or writes in nanoseconds.'), Avg()) - self.select_time.add(metrics.metric_name( - 'io-wait-ratio', self.metric_group_name, - 'The fraction of time the I/O thread spent waiting.'), - Rate(time_unit=TimeUnit.NANOSECONDS)) - - self.io_time = metrics.sensor('io-time') - self.io_time.add(metrics.metric_name( - 'io-time-ns-avg', self.metric_group_name, - 'The average length of time for I/O per select call in nanoseconds.'), - Avg()) - self.io_time.add(metrics.metric_name( - 'io-ratio', self.metric_group_name, - 'The fraction of time the I/O thread spent doing I/O'), - Rate(time_unit=TimeUnit.NANOSECONDS)) - - metrics.add_metric(metrics.metric_name( - 'connection-count', self.metric_group_name, - 'The current number of active connections.'), AnonMeasurable( - lambda config, now: len(conns))) diff --git a/tests/kafka/fixtures.py b/tests/kafka/fixtures.py index f36dd7e8..45e2a053 100644 --- a/tests/kafka/fixtures.py +++ b/tests/kafka/fixtures.py @@ -13,7 +13,7 @@ from kafka.vendor.six.moves import urllib, range from kafka.vendor.six.moves.urllib.parse import urlparse # pylint: disable=E0611,F0401 -from kafka import errors, KafkaAdminClient, KafkaClient +from kafka import errors from kafka.errors import InvalidReplicationFactorError from kafka.protocol.admin import 
CreateTopicsRequest from kafka.protocol.metadata import MetadataRequest @@ -649,13 +649,3 @@ def _create_many_clients(cnt, cls, *args, **params): for _ in range(cnt): params['client_id'] = '%s_%s' % (client_id, random_string(4)) yield cls(*args, **params) - - def get_clients(self, cnt=1, **params): - params = self._enrich_client_params(params, client_id='client') - for client in self._create_many_clients(cnt, KafkaClient, **params): - yield client - - def get_admin_clients(self, cnt, **params): - params = self._enrich_client_params(params, client_id='admin_client') - for client in self._create_many_clients(cnt, KafkaAdminClient, **params): - yield client diff --git a/tests/kafka/test_acl_comparisons.py b/tests/kafka/test_acl_comparisons.py deleted file mode 100644 index 291bf0e2..00000000 --- a/tests/kafka/test_acl_comparisons.py +++ /dev/null @@ -1,92 +0,0 @@ -from kafka.admin.acl_resource import ACL -from kafka.admin.acl_resource import ACLOperation -from kafka.admin.acl_resource import ACLPermissionType -from kafka.admin.acl_resource import ResourcePattern -from kafka.admin.acl_resource import ResourceType -from kafka.admin.acl_resource import ACLResourcePatternType - - -def test_different_acls_are_different(): - one = ACL( - principal='User:A', - host='*', - operation=ACLOperation.ALL, - permission_type=ACLPermissionType.ALLOW, - resource_pattern=ResourcePattern( - resource_type=ResourceType.TOPIC, - resource_name='some-topic', - pattern_type=ACLResourcePatternType.LITERAL - ) - ) - - two = ACL( - principal='User:B', # Different principal - host='*', - operation=ACLOperation.ALL, - permission_type=ACLPermissionType.ALLOW, - resource_pattern=ResourcePattern( - resource_type=ResourceType.TOPIC, - resource_name='some-topic', - pattern_type=ACLResourcePatternType.LITERAL - ) - ) - - assert one != two - assert hash(one) != hash(two) - -def test_different_acls_are_different_with_glob_topics(): - one = ACL( - principal='User:A', - host='*', - operation=ACLOperation.ALL, - permission_type=ACLPermissionType.ALLOW, - resource_pattern=ResourcePattern( - resource_type=ResourceType.TOPIC, - resource_name='*', - pattern_type=ACLResourcePatternType.LITERAL - ) - ) - - two = ACL( - principal='User:B', # Different principal - host='*', - operation=ACLOperation.ALL, - permission_type=ACLPermissionType.ALLOW, - resource_pattern=ResourcePattern( - resource_type=ResourceType.TOPIC, - resource_name='*', - pattern_type=ACLResourcePatternType.LITERAL - ) - ) - - assert one != two - assert hash(one) != hash(two) - -def test_same_acls_are_same(): - one = ACL( - principal='User:A', - host='*', - operation=ACLOperation.ALL, - permission_type=ACLPermissionType.ALLOW, - resource_pattern=ResourcePattern( - resource_type=ResourceType.TOPIC, - resource_name='some-topic', - pattern_type=ACLResourcePatternType.LITERAL - ) - ) - - two = ACL( - principal='User:A', - host='*', - operation=ACLOperation.ALL, - permission_type=ACLPermissionType.ALLOW, - resource_pattern=ResourcePattern( - resource_type=ResourceType.TOPIC, - resource_name='some-topic', - pattern_type=ACLResourcePatternType.LITERAL - ) - ) - - assert one == two - assert hash(one) == hash(two) - assert len(set((one, two))) == 1 diff --git a/tests/kafka/test_admin.py b/tests/kafka/test_admin.py deleted file mode 100644 index 279f85ab..00000000 --- a/tests/kafka/test_admin.py +++ /dev/null @@ -1,78 +0,0 @@ -import pytest - -import kafka.admin -from kafka.errors import IllegalArgumentError - - -def test_config_resource(): - with pytest.raises(KeyError): - 
bad_resource = kafka.admin.ConfigResource('something', 'foo') - good_resource = kafka.admin.ConfigResource('broker', 'bar') - assert good_resource.resource_type == kafka.admin.ConfigResourceType.BROKER - assert good_resource.name == 'bar' - assert good_resource.configs is None - good_resource = kafka.admin.ConfigResource(kafka.admin.ConfigResourceType.TOPIC, 'baz', {'frob': 'nob'}) - assert good_resource.resource_type == kafka.admin.ConfigResourceType.TOPIC - assert good_resource.name == 'baz' - assert good_resource.configs == {'frob': 'nob'} - - -def test_new_partitions(): - good_partitions = kafka.admin.NewPartitions(6) - assert good_partitions.total_count == 6 - assert good_partitions.new_assignments is None - good_partitions = kafka.admin.NewPartitions(7, [[1, 2, 3]]) - assert good_partitions.total_count == 7 - assert good_partitions.new_assignments == [[1, 2, 3]] - - -def test_acl_resource(): - good_acl = kafka.admin.ACL( - "User:bar", - "*", - kafka.admin.ACLOperation.ALL, - kafka.admin.ACLPermissionType.ALLOW, - kafka.admin.ResourcePattern( - kafka.admin.ResourceType.TOPIC, - "foo", - kafka.admin.ACLResourcePatternType.LITERAL - ) - ) - - assert(good_acl.resource_pattern.resource_type == kafka.admin.ResourceType.TOPIC) - assert(good_acl.operation == kafka.admin.ACLOperation.ALL) - assert(good_acl.permission_type == kafka.admin.ACLPermissionType.ALLOW) - assert(good_acl.resource_pattern.pattern_type == kafka.admin.ACLResourcePatternType.LITERAL) - - with pytest.raises(IllegalArgumentError): - kafka.admin.ACL( - "User:bar", - "*", - kafka.admin.ACLOperation.ANY, - kafka.admin.ACLPermissionType.ANY, - kafka.admin.ResourcePattern( - kafka.admin.ResourceType.TOPIC, - "foo", - kafka.admin.ACLResourcePatternType.LITERAL - ) - ) - -def test_new_topic(): - with pytest.raises(IllegalArgumentError): - bad_topic = kafka.admin.NewTopic('foo', -1, -1) - with pytest.raises(IllegalArgumentError): - bad_topic = kafka.admin.NewTopic('foo', 1, -1) - with pytest.raises(IllegalArgumentError): - bad_topic = kafka.admin.NewTopic('foo', 1, 1, {1: [1, 1, 1]}) - good_topic = kafka.admin.NewTopic('foo', 1, 2) - assert good_topic.name == 'foo' - assert good_topic.num_partitions == 1 - assert good_topic.replication_factor == 2 - assert good_topic.replica_assignments == {} - assert good_topic.topic_configs == {} - good_topic = kafka.admin.NewTopic('bar', -1, -1, {1: [1, 2, 3]}, {'key': 'value'}) - assert good_topic.name == 'bar' - assert good_topic.num_partitions == -1 - assert good_topic.replication_factor == -1 - assert good_topic.replica_assignments == {1: [1, 2, 3]} - assert good_topic.topic_configs == {'key': 'value'} diff --git a/tests/kafka/test_admin_integration.py b/tests/kafka/test_admin_integration.py deleted file mode 100644 index 87ba289d..00000000 --- a/tests/kafka/test_admin_integration.py +++ /dev/null @@ -1,314 +0,0 @@ -import pytest - -from logging import info -from tests.kafka.testutil import env_kafka_version, random_string -from threading import Event, Thread -from time import time, sleep - -from kafka.admin import ( - ACLFilter, ACLOperation, ACLPermissionType, ResourcePattern, ResourceType, ACL, ConfigResource, ConfigResourceType) -from kafka.errors import (NoError, GroupCoordinatorNotAvailableError, NonEmptyGroupError, GroupIdNotFoundError) - - -@pytest.mark.skipif(env_kafka_version() < (0, 11), reason="ACL features require broker >=0.11") -def test_create_describe_delete_acls(kafka_admin_client): - """Tests that we can add, list and remove ACLs - """ - - # Check that we don't have any 
ACLs in the cluster - acls, error = kafka_admin_client.describe_acls( - ACLFilter( - principal=None, - host="*", - operation=ACLOperation.ANY, - permission_type=ACLPermissionType.ANY, - resource_pattern=ResourcePattern(ResourceType.TOPIC, "topic") - ) - ) - - assert error is NoError - assert len(acls) == 0 - - # Try to add an ACL - acl = ACL( - principal="User:test", - host="*", - operation=ACLOperation.READ, - permission_type=ACLPermissionType.ALLOW, - resource_pattern=ResourcePattern(ResourceType.TOPIC, "topic") - ) - result = kafka_admin_client.create_acls([acl]) - - assert len(result["failed"]) == 0 - assert len(result["succeeded"]) == 1 - - # Check that we can list the ACL we created - acl_filter = ACLFilter( - principal=None, - host="*", - operation=ACLOperation.ANY, - permission_type=ACLPermissionType.ANY, - resource_pattern=ResourcePattern(ResourceType.TOPIC, "topic") - ) - acls, error = kafka_admin_client.describe_acls(acl_filter) - - assert error is NoError - assert len(acls) == 1 - - # Remove the ACL - delete_results = kafka_admin_client.delete_acls( - [ - ACLFilter( - principal="User:test", - host="*", - operation=ACLOperation.READ, - permission_type=ACLPermissionType.ALLOW, - resource_pattern=ResourcePattern(ResourceType.TOPIC, "topic") - ) - ] - ) - - assert len(delete_results) == 1 - assert len(delete_results[0][1]) == 1 # Check number of affected ACLs - - # Make sure the ACL does not exist in the cluster anymore - acls, error = kafka_admin_client.describe_acls( - ACLFilter( - principal="*", - host="*", - operation=ACLOperation.ANY, - permission_type=ACLPermissionType.ANY, - resource_pattern=ResourcePattern(ResourceType.TOPIC, "topic") - ) - ) - - assert error is NoError - assert len(acls) == 0 - - -@pytest.mark.skipif(env_kafka_version() < (0, 11), reason="Describe config features require broker >=0.11") -def test_describe_configs_broker_resource_returns_configs(kafka_admin_client): - """Tests that describe config returns configs for broker - """ - broker_id = kafka_admin_client._client.cluster._brokers[0].nodeId - configs = kafka_admin_client.describe_configs([ConfigResource(ConfigResourceType.BROKER, broker_id)]) - - assert len(configs) == 1 - assert configs[0].resources[0][2] == ConfigResourceType.BROKER - assert configs[0].resources[0][3] == str(broker_id) - assert len(configs[0].resources[0][4]) > 1 - - -@pytest.mark.xfail(condition=True, - reason="https://github.com/dpkp/kafka-python/issues/1929", - raises=AssertionError) -@pytest.mark.skipif(env_kafka_version() < (0, 11), reason="Describe config features require broker >=0.11") -def test_describe_configs_topic_resource_returns_configs(topic, kafka_admin_client): - """Tests that describe config returns configs for topic - """ - configs = kafka_admin_client.describe_configs([ConfigResource(ConfigResourceType.TOPIC, topic)]) - - assert len(configs) == 1 - assert configs[0].resources[0][2] == ConfigResourceType.TOPIC - assert configs[0].resources[0][3] == topic - assert len(configs[0].resources[0][4]) > 1 - - -@pytest.mark.skipif(env_kafka_version() < (0, 11), reason="Describe config features require broker >=0.11") -def test_describe_configs_mixed_resources_returns_configs(topic, kafka_admin_client): - """Tests that describe config returns configs for mixed resource types (topic + broker) - """ - broker_id = kafka_admin_client._client.cluster._brokers[0].nodeId - configs = kafka_admin_client.describe_configs([ - ConfigResource(ConfigResourceType.TOPIC, topic), - ConfigResource(ConfigResourceType.BROKER, broker_id)]) - - 
assert len(configs) == 2 - - for config in configs: - assert (config.resources[0][2] == ConfigResourceType.TOPIC - and config.resources[0][3] == topic) or \ - (config.resources[0][2] == ConfigResourceType.BROKER - and config.resources[0][3] == str(broker_id)) - assert len(config.resources[0][4]) > 1 - - -@pytest.mark.skipif(env_kafka_version() < (0, 11), reason="Describe config features require broker >=0.11") -def test_describe_configs_invalid_broker_id_raises(kafka_admin_client): - """Tests that describe config raises exception on non-integer broker id - """ - broker_id = "str" - - with pytest.raises(ValueError): - configs = kafka_admin_client.describe_configs([ConfigResource(ConfigResourceType.BROKER, broker_id)]) - - -@pytest.mark.skipif(env_kafka_version() < (0, 11), reason='Describe consumer group requires broker >=0.11') -def test_describe_consumer_group_does_not_exist(kafka_admin_client): - """Tests that the describe consumer group call fails if the group coordinator is not available - """ - with pytest.raises(GroupCoordinatorNotAvailableError): - group_description = kafka_admin_client.describe_consumer_groups(['test']) - - -@pytest.mark.skipif(env_kafka_version() < (0, 11), reason='Describe consumer group requires broker >=0.11') -def test_describe_consumer_group_exists(kafka_admin_client, kafka_consumer_factory, topic): - """Tests that the describe consumer group call returns valid consumer group information - This test takes inspiration from the test 'test_group' in test_consumer_group.py. - """ - consumers = {} - stop = {} - threads = {} - random_group_id = 'test-group-' + random_string(6) - group_id_list = [random_group_id, random_group_id + '_2'] - generations = {group_id_list[0]: set(), group_id_list[1]: set()} - def consumer_thread(i, group_id): - assert i not in consumers - assert i not in stop - stop[i] = Event() - consumers[i] = kafka_consumer_factory(group_id=group_id) - while not stop[i].is_set(): - consumers[i].poll(20) - consumers[i].close() - consumers[i] = None - stop[i] = None - - num_consumers = 3 - for i in range(num_consumers): - group_id = group_id_list[i % 2] - t = Thread(target=consumer_thread, args=(i, group_id,)) - t.start() - threads[i] = t - - try: - timeout = time() + 35 - while True: - for c in range(num_consumers): - - # Verify all consumers have been created - if c not in consumers: - break - - # Verify all consumers have an assignment - elif not consumers[c].assignment(): - break - - # If all consumers exist and have an assignment - else: - - info('All consumers have assignment... 
checking for stable group') - # Verify all consumers are in the same generation - # then log state and break while loop - - for consumer in consumers.values(): - generations[consumer.config['group_id']].add(consumer._coordinator._generation.generation_id) - - is_same_generation = any([len(consumer_generation) == 1 for consumer_generation in generations.values()]) - - # New generation assignment is not complete until - # coordinator.rejoining = False - rejoining = any([consumer._coordinator.rejoining - for consumer in list(consumers.values())]) - - if not rejoining and is_same_generation: - break - else: - sleep(1) - assert time() < timeout, "timeout waiting for assignments" - - info('Group stabilized; verifying assignment') - output = kafka_admin_client.describe_consumer_groups(group_id_list) - assert len(output) == 2 - consumer_groups = set() - for consumer_group in output: - assert(consumer_group.group in group_id_list) - if consumer_group.group == group_id_list[0]: - assert(len(consumer_group.members) == 2) - else: - assert(len(consumer_group.members) == 1) - for member in consumer_group.members: - assert(member.member_metadata.subscription[0] == topic) - assert(member.member_assignment.assignment[0][0] == topic) - consumer_groups.add(consumer_group.group) - assert(sorted(list(consumer_groups)) == group_id_list) - finally: - info('Shutting down %s consumers', num_consumers) - for c in range(num_consumers): - info('Stopping consumer %s', c) - stop[c].set() - threads[c].join() - threads[c] = None - - -@pytest.mark.skipif(env_kafka_version() < (1, 1), reason="Delete consumer groups requires broker >=1.1") -def test_delete_consumergroups(kafka_admin_client, kafka_consumer_factory, send_messages): - random_group_id = 'test-group-' + random_string(6) - group1 = random_group_id + "_1" - group2 = random_group_id + "_2" - group3 = random_group_id + "_3" - - send_messages(range(0, 100), partition=0) - consumer1 = kafka_consumer_factory(group_id=group1) - next(consumer1) - consumer1.close() - - consumer2 = kafka_consumer_factory(group_id=group2) - next(consumer2) - consumer2.close() - - consumer3 = kafka_consumer_factory(group_id=group3) - next(consumer3) - consumer3.close() - - consumergroups = {group_id for group_id, _ in kafka_admin_client.list_consumer_groups()} - assert group1 in consumergroups - assert group2 in consumergroups - assert group3 in consumergroups - - delete_results = { - group_id: error - for group_id, error in kafka_admin_client.delete_consumer_groups([group1, group2]) - } - assert delete_results[group1] == NoError - assert delete_results[group2] == NoError - assert group3 not in delete_results - - consumergroups = {group_id for group_id, _ in kafka_admin_client.list_consumer_groups()} - assert group1 not in consumergroups - assert group2 not in consumergroups - assert group3 in consumergroups - - -@pytest.mark.skipif(env_kafka_version() < (1, 1), reason="Delete consumer groups requires broker >=1.1") -def test_delete_consumergroups_with_errors(kafka_admin_client, kafka_consumer_factory, send_messages): - random_group_id = 'test-group-' + random_string(6) - group1 = random_group_id + "_1" - group2 = random_group_id + "_2" - group3 = random_group_id + "_3" - - send_messages(range(0, 100), partition=0) - consumer1 = kafka_consumer_factory(group_id=group1) - next(consumer1) - consumer1.close() - - consumer2 = kafka_consumer_factory(group_id=group2) - next(consumer2) - - consumergroups = {group_id for group_id, _ in kafka_admin_client.list_consumer_groups()} - assert group1 in 
consumergroups - assert group2 in consumergroups - assert group3 not in consumergroups - - delete_results = { - group_id: error - for group_id, error in kafka_admin_client.delete_consumer_groups([group1, group2, group3]) - } - - assert delete_results[group1] == NoError - assert delete_results[group2] == NonEmptyGroupError - assert delete_results[group3] == GroupIdNotFoundError - - consumergroups = {group_id for group_id, _ in kafka_admin_client.list_consumer_groups()} - assert group1 not in consumergroups - assert group2 in consumergroups - assert group3 not in consumergroups diff --git a/tests/kafka/test_client_async.py b/tests/kafka/test_client_async.py deleted file mode 100644 index 74da66a3..00000000 --- a/tests/kafka/test_client_async.py +++ /dev/null @@ -1,409 +0,0 @@ -from __future__ import absolute_import, division - -# selectors in stdlib as of py3.4 -try: - import selectors # pylint: disable=import-error -except ImportError: - # vendored backport module - import kafka.vendor.selectors34 as selectors - -import socket -import time - -import pytest - -from kafka.client_async import KafkaClient, IdleConnectionManager -from kafka.cluster import ClusterMetadata -from kafka.conn import ConnectionStates -import kafka.errors as Errors -from kafka.future import Future -from kafka.protocol.metadata import MetadataRequest -from kafka.protocol.produce import ProduceRequest -from kafka.structs import BrokerMetadata - - -@pytest.fixture -def cli(mocker, conn): - client = KafkaClient(api_version=(0, 9)) - mocker.patch.object(client, '_selector') - client.poll(future=client.cluster.request_update()) - return client - - -def test_bootstrap(mocker, conn): - conn.state = ConnectionStates.CONNECTED - cli = KafkaClient(api_version=(0, 9)) - mocker.patch.object(cli, '_selector') - future = cli.cluster.request_update() - cli.poll(future=future) - - assert future.succeeded() - args, kwargs = conn.call_args - assert args == ('localhost', 9092, socket.AF_UNSPEC) - kwargs.pop('state_change_callback') - kwargs.pop('node_id') - assert kwargs == cli.config - conn.send.assert_called_once_with(MetadataRequest[0]([]), blocking=False) - assert cli._bootstrap_fails == 0 - assert cli.cluster.brokers() == set([BrokerMetadata(0, 'foo', 12, None), - BrokerMetadata(1, 'bar', 34, None)]) - - -def test_can_connect(cli, conn): - # Node is not in broker metadata - can't connect - assert not cli._can_connect(2) - - # Node is in broker metadata but not in _conns - assert 0 not in cli._conns - assert cli._can_connect(0) - - # Node is connected, can't reconnect - assert cli._maybe_connect(0) is True - assert not cli._can_connect(0) - - # Node is disconnected, can connect - cli._conns[0].state = ConnectionStates.DISCONNECTED - assert cli._can_connect(0) - - # Node is disconnected, but blacked out - conn.blacked_out.return_value = True - assert not cli._can_connect(0) - - -def test_maybe_connect(cli, conn): - try: - # Node not in metadata, raises AssertionError - cli._maybe_connect(2) - except AssertionError: - pass - else: - assert False, 'Exception not raised' - - # New node_id creates a conn object - assert 0 not in cli._conns - conn.state = ConnectionStates.DISCONNECTED - conn.connect.side_effect = lambda: conn._set_conn_state(ConnectionStates.CONNECTING) - assert cli._maybe_connect(0) is False - assert cli._conns[0] is conn - - -def test_conn_state_change(mocker, cli, conn): - sel = cli._selector - - node_id = 0 - cli._conns[node_id] = conn - conn.state = ConnectionStates.CONNECTING - sock = conn._sock - 
cli._conn_state_change(node_id, sock, conn) - assert node_id in cli._connecting - sel.register.assert_called_with(sock, selectors.EVENT_WRITE, conn) - - conn.state = ConnectionStates.CONNECTED - cli._conn_state_change(node_id, sock, conn) - assert node_id not in cli._connecting - sel.modify.assert_called_with(sock, selectors.EVENT_READ, conn) - - # Failure to connect should trigger metadata update - assert cli.cluster._need_update is False - conn.state = ConnectionStates.DISCONNECTED - cli._conn_state_change(node_id, sock, conn) - assert node_id not in cli._connecting - assert cli.cluster._need_update is True - sel.unregister.assert_called_with(sock) - - conn.state = ConnectionStates.CONNECTING - cli._conn_state_change(node_id, sock, conn) - assert node_id in cli._connecting - conn.state = ConnectionStates.DISCONNECTED - cli._conn_state_change(node_id, sock, conn) - assert node_id not in cli._connecting - - -def test_ready(mocker, cli, conn): - maybe_connect = mocker.patch.object(cli, 'maybe_connect') - node_id = 1 - cli.ready(node_id) - maybe_connect.assert_called_with(node_id) - - -def test_is_ready(mocker, cli, conn): - cli._maybe_connect(0) - cli._maybe_connect(1) - - # metadata refresh blocks ready nodes - assert cli.is_ready(0) - assert cli.is_ready(1) - cli._metadata_refresh_in_progress = True - assert not cli.is_ready(0) - assert not cli.is_ready(1) - - # requesting metadata update also blocks ready nodes - cli._metadata_refresh_in_progress = False - assert cli.is_ready(0) - assert cli.is_ready(1) - cli.cluster.request_update() - cli.cluster.config['retry_backoff_ms'] = 0 - assert not cli._metadata_refresh_in_progress - assert not cli.is_ready(0) - assert not cli.is_ready(1) - cli.cluster._need_update = False - - # if connection can't send more, not ready - assert cli.is_ready(0) - conn.can_send_more.return_value = False - assert not cli.is_ready(0) - conn.can_send_more.return_value = True - - # disconnected nodes, not ready - assert cli.is_ready(0) - conn.state = ConnectionStates.DISCONNECTED - assert not cli.is_ready(0) - - -def test_close(mocker, cli, conn): - mocker.patch.object(cli, '_selector') - - call_count = conn.close.call_count - - # Unknown node - silent - cli.close(2) - call_count += 0 - assert conn.close.call_count == call_count - - # Single node close - cli._maybe_connect(0) - assert conn.close.call_count == call_count - cli.close(0) - call_count += 1 - assert conn.close.call_count == call_count - - # All node close - cli._maybe_connect(1) - cli.close() - # +2 close: node 1, node bootstrap (node 0 already closed) - call_count += 2 - assert conn.close.call_count == call_count - - -def test_is_disconnected(cli, conn): - # False if not connected yet - conn.state = ConnectionStates.DISCONNECTED - assert not cli.is_disconnected(0) - - cli._maybe_connect(0) - assert cli.is_disconnected(0) - - conn.state = ConnectionStates.CONNECTING - assert not cli.is_disconnected(0) - - conn.state = ConnectionStates.CONNECTED - assert not cli.is_disconnected(0) - - -def test_send(cli, conn): - # Send to unknown node => raises AssertionError - try: - cli.send(2, None) - assert False, 'Exception not raised' - except AssertionError: - pass - - # Send to disconnected node => NodeNotReady - conn.state = ConnectionStates.DISCONNECTED - f = cli.send(0, None) - assert f.failed() - assert isinstance(f.exception, Errors.NodeNotReadyError) - - conn.state = ConnectionStates.CONNECTED - cli._maybe_connect(0) - # ProduceRequest w/ 0 required_acks -> no response - request = ProduceRequest[0](0, 0, []) 
- assert request.expect_response() is False - ret = cli.send(0, request) - assert conn.send.called_with(request) - assert isinstance(ret, Future) - - request = MetadataRequest[0]([]) - cli.send(0, request) - assert conn.send.called_with(request) - - -def test_poll(mocker): - metadata = mocker.patch.object(KafkaClient, '_maybe_refresh_metadata') - _poll = mocker.patch.object(KafkaClient, '_poll') - ifrs = mocker.patch.object(KafkaClient, 'in_flight_request_count') - ifrs.return_value = 1 - cli = KafkaClient(api_version=(0, 9)) - - # metadata timeout wins - metadata.return_value = 1000 - cli.poll() - _poll.assert_called_with(1.0) - - # user timeout wins - cli.poll(250) - _poll.assert_called_with(0.25) - - # default is request_timeout_ms - metadata.return_value = 1000000 - cli.poll() - _poll.assert_called_with(cli.config['request_timeout_ms'] / 1000.0) - - # If no in-flight-requests, drop timeout to retry_backoff_ms - ifrs.return_value = 0 - cli.poll() - _poll.assert_called_with(cli.config['retry_backoff_ms'] / 1000.0) - - -def test__poll(): - pass - - -def test_in_flight_request_count(): - pass - - -def test_least_loaded_node(): - pass - - -def test_set_topics(mocker): - request_update = mocker.patch.object(ClusterMetadata, 'request_update') - request_update.side_effect = lambda: Future() - cli = KafkaClient(api_version=(0, 10)) - - # replace 'empty' with 'non empty' - request_update.reset_mock() - fut = cli.set_topics(['t1', 't2']) - assert not fut.is_done - request_update.assert_called_with() - - # replace 'non empty' with 'same' - request_update.reset_mock() - fut = cli.set_topics(['t1', 't2']) - assert fut.is_done - assert fut.value == set(['t1', 't2']) - request_update.assert_not_called() - - # replace 'non empty' with 'empty' - request_update.reset_mock() - fut = cli.set_topics([]) - assert fut.is_done - assert fut.value == set() - request_update.assert_not_called() - - -@pytest.fixture -def client(mocker): - _poll = mocker.patch.object(KafkaClient, '_poll') - - cli = KafkaClient(request_timeout_ms=9999999, - reconnect_backoff_ms=2222, - connections_max_idle_ms=float('inf'), - api_version=(0, 9)) - - ttl = mocker.patch.object(cli.cluster, 'ttl') - ttl.return_value = 0 - return cli - - -def test_maybe_refresh_metadata_ttl(mocker, client): - client.cluster.ttl.return_value = 1234 - mocker.patch.object(KafkaClient, 'in_flight_request_count', return_value=1) - - client.poll(timeout_ms=12345678) - client._poll.assert_called_with(1.234) - - -def test_maybe_refresh_metadata_backoff(mocker, client): - mocker.patch.object(KafkaClient, 'in_flight_request_count', return_value=1) - now = time.time() - t = mocker.patch('time.time') - t.return_value = now - - client.poll(timeout_ms=12345678) - client._poll.assert_called_with(2.222) # reconnect backoff - - -def test_maybe_refresh_metadata_in_progress(mocker, client): - client._metadata_refresh_in_progress = True - mocker.patch.object(KafkaClient, 'in_flight_request_count', return_value=1) - - client.poll(timeout_ms=12345678) - client._poll.assert_called_with(9999.999) # request_timeout_ms - - -def test_maybe_refresh_metadata_update(mocker, client): - mocker.patch.object(client, 'least_loaded_node', return_value='foobar') - mocker.patch.object(client, '_can_send_request', return_value=True) - mocker.patch.object(KafkaClient, 'in_flight_request_count', return_value=1) - send = mocker.patch.object(client, 'send') - - client.poll(timeout_ms=12345678) - client._poll.assert_called_with(9999.999) # request_timeout_ms - assert 
client._metadata_refresh_in_progress - request = MetadataRequest[0]([]) - send.assert_called_once_with('foobar', request, wakeup=False) - - -def test_maybe_refresh_metadata_cant_send(mocker, client): - mocker.patch.object(client, 'least_loaded_node', return_value='foobar') - mocker.patch.object(client, '_can_connect', return_value=True) - mocker.patch.object(client, '_maybe_connect', return_value=True) - mocker.patch.object(client, 'maybe_connect', return_value=True) - mocker.patch.object(KafkaClient, 'in_flight_request_count', return_value=1) - - now = time.time() - t = mocker.patch('time.time') - t.return_value = now - - # first poll attempts connection - client.poll(timeout_ms=12345678) - client._poll.assert_called_with(2.222) # reconnect backoff - client.maybe_connect.assert_called_once_with('foobar', wakeup=False) - - # poll while connecting should not attempt a new connection - client._connecting.add('foobar') - client._can_connect.reset_mock() - client.poll(timeout_ms=12345678) - client._poll.assert_called_with(2.222) # connection timeout (reconnect timeout) - assert not client._can_connect.called - - assert not client._metadata_refresh_in_progress - - -def test_schedule(): - pass - - -def test_unschedule(): - pass - - -def test_idle_connection_manager(mocker): - t = mocker.patch.object(time, 'time') - t.return_value = 0 - - idle = IdleConnectionManager(100) - assert idle.next_check_ms() == float('inf') - - idle.update('foo') - assert not idle.is_expired('foo') - assert idle.poll_expired_connection() is None - assert idle.next_check_ms() == 100 - - t.return_value = 90 / 1000 - assert not idle.is_expired('foo') - assert idle.poll_expired_connection() is None - assert idle.next_check_ms() == 10 - - t.return_value = 100 / 1000 - assert idle.is_expired('foo') - assert idle.next_check_ms() == 0 - - conn_id, conn_ts = idle.poll_expired_connection() - assert conn_id == 'foo' - assert conn_ts == 0 - - idle.remove('foo') - assert idle.next_check_ms() == float('inf') diff --git a/tests/kafka/test_sasl_integration.py b/tests/kafka/test_sasl_integration.py deleted file mode 100644 index d66a7349..00000000 --- a/tests/kafka/test_sasl_integration.py +++ /dev/null @@ -1,80 +0,0 @@ -import logging -import uuid - -import pytest - -from kafka.admin import NewTopic -from kafka.protocol.metadata import MetadataRequest_v1 -from tests.kafka.testutil import assert_message_count, env_kafka_version, random_string, special_to_underscore - - -@pytest.fixture( - params=[ - pytest.param( - "PLAIN", marks=pytest.mark.skipif(env_kafka_version() < (0, 10), reason="Requires KAFKA_VERSION >= 0.10") - ), - pytest.param( - "SCRAM-SHA-256", - marks=pytest.mark.skipif(env_kafka_version() < (0, 10, 2), reason="Requires KAFKA_VERSION >= 0.10.2"), - ), - pytest.param( - "SCRAM-SHA-512", - marks=pytest.mark.skipif(env_kafka_version() < (0, 10, 2), reason="Requires KAFKA_VERSION >= 0.10.2"), - ), - ] -) -def sasl_kafka(request, kafka_broker_factory): - sasl_kafka = kafka_broker_factory(transport="SASL_PLAINTEXT", sasl_mechanism=request.param)[0] - yield sasl_kafka - sasl_kafka.child.dump_logs() - - -def test_admin(request, sasl_kafka): - topic_name = special_to_underscore(request.node.name + random_string(4)) - admin, = sasl_kafka.get_admin_clients(1) - admin.create_topics([NewTopic(topic_name, 1, 1)]) - assert topic_name in sasl_kafka.get_topic_names() - - -def test_produce_and_consume(request, sasl_kafka): - topic_name = special_to_underscore(request.node.name + random_string(4)) - sasl_kafka.create_topics([topic_name], 
num_partitions=2) - producer, = sasl_kafka.get_producers(1) - - messages_and_futures = [] # [(message, produce_future),] - for i in range(100): - encoded_msg = "{}-{}-{}".format(i, request.node.name, uuid.uuid4()).encode("utf-8") - future = producer.send(topic_name, value=encoded_msg, partition=i % 2) - messages_and_futures.append((encoded_msg, future)) - producer.flush() - - for (msg, f) in messages_and_futures: - assert f.succeeded() - - consumer, = sasl_kafka.get_consumers(1, [topic_name]) - messages = {0: [], 1: []} - for i, message in enumerate(consumer, 1): - logging.debug("Consumed message %s", repr(message)) - messages[message.partition].append(message) - if i >= 100: - break - - assert_message_count(messages[0], 50) - assert_message_count(messages[1], 50) - - -def test_client(request, sasl_kafka): - topic_name = special_to_underscore(request.node.name + random_string(4)) - sasl_kafka.create_topics([topic_name], num_partitions=1) - - client, = sasl_kafka.get_clients(1) - request = MetadataRequest_v1(None) - client.send(0, request) - for _ in range(10): - result = client.poll(timeout_ms=10000) - if len(result) > 0: - break - else: - raise RuntimeError("Couldn't fetch topic response from Broker.") - result = result[0] - assert topic_name in [t[1] for t in result.topics] diff --git a/tests/test_admin.py b/tests/test_admin.py index c7db2096..bead1ef6 100644 --- a/tests/test_admin.py +++ b/tests/test_admin.py @@ -1,9 +1,7 @@ import asyncio -from kafka.admin import NewTopic, NewPartitions -from kafka.admin.config_resource import ConfigResource, ConfigResourceType - -from aiokafka.admin import AIOKafkaAdminClient +from aiokafka.admin import AIOKafkaAdminClient, NewTopic, NewPartitions +from aiokafka.admin.config_resource import ConfigResource, ConfigResourceType from aiokafka.consumer import AIOKafkaConsumer from aiokafka.producer import AIOKafkaProducer from aiokafka.structs import TopicPartition From adde1521e38b12501bad609302ed10a443ba7f48 Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Sun, 22 Oct 2023 18:08:40 +0300 Subject: [PATCH 09/20] Merge errors --- aiokafka/admin/client.py | 2 +- aiokafka/admin/new_topic.py | 2 +- aiokafka/client.py | 11 +- aiokafka/conn.py | 120 +++- aiokafka/errors.py | 790 +++++++++++++++++++++------ docs/api.rst | 15 - kafka/cluster.py | 2 +- kafka/conn.py | 4 +- kafka/coordinator/consumer.py | 2 +- kafka/errors.py | 538 ------------------ kafka/metrics/stats/sensor.py | 2 +- kafka/protocol/parser.py | 2 +- tests/kafka/fixtures.py | 4 +- tests/kafka/test_conn.py | 2 +- tests/kafka/test_metrics.py | 2 +- tests/record/test_default_records.py | 3 +- tests/record/test_legacy.py | 4 +- tests/test_client.py | 8 +- tests/test_coordinator.py | 2 +- tests/test_message_accumulator.py | 6 +- 20 files changed, 755 insertions(+), 766 deletions(-) delete mode 100644 kafka/errors.py diff --git a/aiokafka/admin/client.py b/aiokafka/admin/client.py index 392c93eb..7ce2465a 100644 --- a/aiokafka/admin/client.py +++ b/aiokafka/admin/client.py @@ -4,7 +4,6 @@ from ssl import SSLContext from typing import List, Optional, Dict, Tuple, Any -from kafka.errors import IncompatibleBrokerVersion, for_code from kafka.protocol.api import Request, Response from kafka.protocol.metadata import MetadataRequest from kafka.protocol.commit import OffsetFetchRequest, GroupCoordinatorRequest @@ -20,6 +19,7 @@ from kafka.structs import TopicPartition, OffsetAndMetadata from aiokafka import __version__ +from aiokafka.errors import IncompatibleBrokerVersion, for_code from aiokafka.client 
import AIOKafkaClient from .config_resource import ConfigResourceType, ConfigResource diff --git a/aiokafka/admin/new_topic.py b/aiokafka/admin/new_topic.py index 4d00daed..f5155c33 100644 --- a/aiokafka/admin/new_topic.py +++ b/aiokafka/admin/new_topic.py @@ -1,4 +1,4 @@ -from kafka.errors import IllegalArgumentError +from aiokafka.errors import IllegalArgumentError class NewTopic: diff --git a/aiokafka/client.py b/aiokafka/client.py index c1bdf6d1..371e4668 100644 --- a/aiokafka/client.py +++ b/aiokafka/client.py @@ -3,7 +3,6 @@ import random import time -from kafka.conn import collect_hosts from kafka.protocol.admin import DescribeAclsRequest_v2 from kafka.protocol.commit import OffsetFetchRequest from kafka.protocol.fetch import FetchRequest @@ -13,7 +12,7 @@ import aiokafka.errors as Errors from aiokafka import __version__ -from aiokafka.conn import create_conn, CloseReason +from aiokafka.conn import collect_hosts, create_conn, CloseReason from aiokafka.cluster import ClusterMetadata from aiokafka.protocol.coordination import FindCoordinatorRequest from aiokafka.errors import ( @@ -482,10 +481,10 @@ async def send(self, node_id, request, *, group=ConnectionGroup.DEFAULT): request (Struct): request object (not-encoded) Raises: - kafka.errors.RequestTimedOutError - kafka.errors.NodeNotReadyError - kafka.errors.KafkaConnectionError - kafka.errors.CorrelationIdError + aiokafka.errors.RequestTimedOutError + aiokafka.errors.NodeNotReadyError + aiokafka.errors.KafkaConnectionError + aiokafka.errors.CorrelationIdError Returns: Future: resolves to Response struct diff --git a/aiokafka/conn.py b/aiokafka/conn.py index 650aa3ef..2be93dc3 100644 --- a/aiokafka/conn.py +++ b/aiokafka/conn.py @@ -5,6 +5,8 @@ import hashlib import hmac import logging +import random +import socket import struct import sys import time @@ -33,6 +35,9 @@ __all__ = ['AIOKafkaConnection', 'create_conn'] +log = logging.getLogger(__name__) + +DEFAULT_KAFKA_PORT = 9092 READER_LIMIT = 2 ** 16 SASL_QOP_AUTH = 1 @@ -113,8 +118,6 @@ def connection_lost(self, exc): class AIOKafkaConnection: """Class for manage connection to Kafka node""" - log = logging.getLogger(__name__) - _reader = None # For __del__ to work properly, just in case _source_traceback = None @@ -284,7 +287,7 @@ async def _do_sasl_handshake(self): ) if self._security_protocol == 'SASL_PLAINTEXT' and \ self._sasl_mechanism == 'PLAIN': - self.log.warning( + log.warning( 'Sending username and password in the clear') if self._sasl_mechanism == 'GSSAPI': @@ -329,15 +332,15 @@ async def _do_sasl_handshake(self): auth_bytes = resp.sasl_auth_bytes if self._sasl_mechanism == 'GSSAPI': - self.log.info( + log.info( 'Authenticated as %s via GSSAPI', self.sasl_principal) elif self._sasl_mechanism == 'OAUTHBEARER': - self.log.info( + log.info( 'Authenticated via OAUTHBEARER' ) else: - self.log.info( + log.info( 'Authenticated as %s via %s', self._sasl_plain_username, self._sasl_mechanism @@ -382,7 +385,7 @@ def _on_read_task_error(cls, self_ref, read_task): read_task.result() except Exception as exc: if not isinstance(exc, (OSError, EOFError, ConnectionError)): - cls.log.exception("Unexpected exception in AIOKafkaConnection") + log.exception("Unexpected exception in AIOKafkaConnection") self = self_ref() if self is not None: @@ -441,7 +444,7 @@ def send(self, request, expect_response=True): f"Connection at {self._host}:{self._port} broken: {err}" ) - self.log.debug( + log.debug( '%s Request %d: %s', self, correlation_id, request) if not expect_response: @@ -475,7 +478,7 @@ def 
connected(self): return bool(self._reader is not None and not self._reader.at_eof()) def close(self, reason=None, exc=None): - self.log.debug("Closing connection at %s:%s", self._host, self._port) + log.debug("Closing connection at %s:%s", self._host, self._port) if self._reader is not None: self._writer.close() self._writer = self._reader = None @@ -546,7 +549,7 @@ def _handle_frame(self, resp): if (self._api_version == (0, 8, 2) and resp_type is GroupCoordinatorResponse and correlation_id != 0 and recv_correlation_id == 0): - self.log.warning( + log.warning( 'Kafka 0.8.2 quirk -- GroupCoordinatorResponse' ' coorelation id does not match request. This' ' should go away once at least one topic has been' @@ -564,7 +567,7 @@ def _handle_frame(self, resp): if not fut.done(): response = resp_type.decode(resp[4:]) - self.log.debug( + log.debug( '%s Response %d: %s', self, correlation_id, response) fut.set_result(response) @@ -771,3 +774,98 @@ def _token_extensions(self): return "\x01" + msg return "" + + +def _address_family(address): + """ + Attempt to determine the family of an address (or hostname) + + :return: either socket.AF_INET or socket.AF_INET6 or socket.AF_UNSPEC + if the address family could not be determined + """ + if address.startswith('[') and address.endswith(']'): + return socket.AF_INET6 + for af in (socket.AF_INET, socket.AF_INET6): + try: + socket.inet_pton(af, address) + return af + except (ValueError, AttributeError, socket.error): + continue + return socket.AF_UNSPEC + + +def get_ip_port_afi(host_and_port_str): + """ + Parse the IP and port from a string in the format of: + + * host_or_ip <- Can be either IPv4 address literal or hostname/fqdn + * host_or_ipv4:port <- Can be either IPv4 address literal or hostname/fqdn + * [host_or_ip] <- IPv6 address literal + * [host_or_ip]:port. <- IPv6 address literal + + .. note:: IPv6 address literals with ports *must* be enclosed in brackets + + .. note:: If the port is not specified, default will be returned. + + :return: tuple (host, port, afi), afi will be socket.AF_INET or + socket.AF_INET6 or socket.AF_UNSPEC + """ + host_and_port_str = host_and_port_str.strip() + if host_and_port_str.startswith('['): + af = socket.AF_INET6 + host, rest = host_and_port_str[1:].split(']') + if rest: + port = int(rest[1:]) + else: + port = DEFAULT_KAFKA_PORT + return host, port, af + else: + if ':' not in host_and_port_str: + af = _address_family(host_and_port_str) + return host_and_port_str, DEFAULT_KAFKA_PORT, af + else: + # now we have something with a colon in it and no square brackets. It could + # be either an IPv6 address literal (e.g., "::1") or an IP:port pair or a + # host:port pair + try: + # if it decodes as an IPv6 address, use that + socket.inet_pton(socket.AF_INET6, host_and_port_str) + return host_and_port_str, DEFAULT_KAFKA_PORT, socket.AF_INET6 + except AttributeError: + log.warning('socket.inet_pton not available on this platform.' + ' consider `pip install win_inet_pton`') + pass + except (ValueError, socket.error): + # it's a host:port pair + pass + host, port = host_and_port_str.rsplit(':', 1) + port = int(port) + + af = _address_family(host) + return host, port, af + + +def collect_hosts(hosts, randomize=True): + """ + Collects a comma-separated set of hosts (host:port) and optionally + randomize the returned list. 
+ """ + + if isinstance(hosts, str): + hosts = hosts.strip().split(',') + + result = [] + afi = socket.AF_INET + for host_port in hosts: + + host, port, afi = get_ip_port_afi(host_port) + + if port < 0: + port = DEFAULT_KAFKA_PORT + + result.append((host, port, afi)) + + if randomize: + random.shuffle(result) + + return result diff --git a/aiokafka/errors.py b/aiokafka/errors.py index c000369b..d35c04b2 100644 --- a/aiokafka/errors.py +++ b/aiokafka/errors.py @@ -1,82 +1,13 @@ import inspect import sys -from kafka.errors import ( - KafkaError, - IllegalStateError, - IllegalArgumentError, - NoBrokersAvailable, - NodeNotReadyError, - KafkaProtocolError, - CorrelationIdError, - Cancelled, - TooManyInFlightRequests, - StaleMetadata, - UnrecognizedBrokerVersion, - CommitFailedError, - AuthenticationMethodNotSupported, - AuthenticationFailedError, - BrokerResponseError, - - # Numbered errors - NoError, # 0 - UnknownError, # -1 - OffsetOutOfRangeError, # 1 - CorruptRecordException, # 2 - UnknownTopicOrPartitionError, # 3 - InvalidFetchRequestError, # 4 - LeaderNotAvailableError, # 5 - NotLeaderForPartitionError, # 6 - RequestTimedOutError, # 7 - BrokerNotAvailableError, # 8 - ReplicaNotAvailableError, # 9 - MessageSizeTooLargeError, # 10 - StaleControllerEpochError, # 11 - OffsetMetadataTooLargeError, # 12 - StaleLeaderEpochCodeError, # 13 - GroupLoadInProgressError, # 14 - GroupCoordinatorNotAvailableError, # 15 - NotCoordinatorForGroupError, # 16 - InvalidTopicError, # 17 - RecordListTooLargeError, # 18 - NotEnoughReplicasError, # 19 - NotEnoughReplicasAfterAppendError, # 20 - InvalidRequiredAcksError, # 21 - IllegalGenerationError, # 22 - InconsistentGroupProtocolError, # 23 - InvalidGroupIdError, # 24 - UnknownMemberIdError, # 25 - InvalidSessionTimeoutError, # 26 - RebalanceInProgressError, # 27 - InvalidCommitOffsetSizeError, # 28 - TopicAuthorizationFailedError, # 29 - GroupAuthorizationFailedError, # 30 - ClusterAuthorizationFailedError, # 31 - InvalidTimestampError, # 32 - UnsupportedSaslMechanismError, # 33 - IllegalSaslStateError, # 34 - UnsupportedVersionError, # 35 - TopicAlreadyExistsError, # 36 - InvalidPartitionsError, # 37 - InvalidReplicationFactorError, # 38 - InvalidReplicationAssignmentError, # 39 - InvalidConfigurationError, # 40 - NotControllerError, # 41 - InvalidRequestError, # 42 - UnsupportedForMessageFormatError, # 43 - PolicyViolationError, # 44 - - KafkaUnavailableError, - KafkaTimeoutError, - KafkaConnectionError, - UnsupportedCodecError, -) __all__ = [ # aiokafka custom errors - "ConsumerStoppedError", "NoOffsetForPartitionError", "RecordTooLargeError", + "ConsumerStoppedError", + "NoOffsetForPartitionError", + "RecordTooLargeError", "ProducerClosed", - # Kafka Python errors "KafkaError", "IllegalStateError", @@ -89,11 +20,11 @@ "TooManyInFlightRequests", "StaleMetadata", "UnrecognizedBrokerVersion", + "IncompatibleBrokerVersion", "CommitFailedError", "AuthenticationMethodNotSupported", "AuthenticationFailedError", "BrokerResponseError", - # Numbered errors "NoError", # 0 "UnknownError", # -1 @@ -141,7 +72,6 @@ "InvalidRequestError", # 42 "UnsupportedForMessageFormatError", # 43 "PolicyViolationError", # 44 - "KafkaUnavailableError", "KafkaTimeoutError", "KafkaConnectionError", @@ -149,34 +79,130 @@ ] -class CoordinatorNotAvailableError(GroupCoordinatorNotAvailableError): - message = "COORDINATOR_NOT_AVAILABLE" +class KafkaError(RuntimeError): + retriable = False + # whether metadata should be refreshed on error + invalid_metadata = False + def __str__(self): + if 
not self.args: + return self.__class__.__name__ + return "{0}: {1}".format( + self.__class__.__name__, super(KafkaError, self).__str__() + ) -class NotCoordinatorError(NotCoordinatorForGroupError): - message = "NOT_COORDINATOR" +class IllegalStateError(KafkaError): + pass -class CoordinatorLoadInProgressError(GroupLoadInProgressError): - message = "COORDINATOR_LOAD_IN_PROGRESS" + +class IllegalArgumentError(KafkaError): + pass -InvalidMessageError = CorruptRecordException -GroupCoordinatorNotAvailableError = CoordinatorNotAvailableError -NotCoordinatorForGroupError = NotCoordinatorError -GroupLoadInProgressError = CoordinatorLoadInProgressError +class NoBrokersAvailable(KafkaError): + retriable = True + invalid_metadata = True + + +class NodeNotReadyError(KafkaError): + retriable = True + + +class KafkaProtocolError(KafkaError): + retriable = True + + +class CorrelationIdError(KafkaProtocolError): + retriable = True + + +class Cancelled(KafkaError): + retriable = True + + +class TooManyInFlightRequests(KafkaError): + retriable = True + + +class StaleMetadata(KafkaError): + retriable = True + invalid_metadata = True + + +class MetadataEmptyBrokerList(KafkaError): + retriable = True + + +class UnrecognizedBrokerVersion(KafkaError): + pass + + +class IncompatibleBrokerVersion(KafkaError): + pass + + +class CommitFailedError(KafkaError): + def __init__(self, *args, **kwargs): + super(CommitFailedError, self).__init__( + """Commit cannot be completed since the group has already + rebalanced and assigned the partitions to another member. + This means that the time between subsequent calls to poll() + was longer than the configured max_poll_interval_ms, which + typically implies that the poll loop is spending too much + time message processing. You can address this either by + increasing the rebalance timeout with max_poll_interval_ms, + or by reducing the maximum size of batches returned in poll() + with max_poll_records. + """, + *args, + **kwargs + ) + + +class AuthenticationMethodNotSupported(KafkaError): + pass + + +class AuthenticationFailedError(KafkaError): + retriable = False + + +class KafkaUnavailableError(KafkaError): + pass + + +class KafkaTimeoutError(KafkaError): + pass + + +class KafkaConnectionError(KafkaError): + retriable = True + invalid_metadata = True + + +class UnsupportedCodecError(KafkaError): + pass + + +class KafkaConfigurationError(KafkaError): + pass + + +class QuotaViolationError(KafkaError): + pass class ConsumerStoppedError(Exception): - """ Raised on `get*` methods of Consumer if it's cancelled, even pending - ones. + """Raised on `get*` methods of Consumer if it's cancelled, even pending + ones. """ class IllegalOperation(Exception): - """ Raised if you try to execute an operation, that is not available with - current configuration. For example trying to commit if no group_id was - given. + """Raised if you try to execute an operation, that is not available with + current configuration. For example trying to commit if no group_id was + given. """ @@ -194,231 +220,649 @@ class ProducerClosed(KafkaError): class ProducerFenced(KafkaError): """Another producer with the same transactional ID went online. - NOTE: As it seems this will be raised by Broker if transaction timeout - occurred also. + NOTE: As it seems this will be raised by Broker if transaction timeout + occurred also. 
""" def __init__( self, msg="There is a newer producer using the same transactional_id or" - "transaction timeout occurred (check that processing time is " - "below transaction_timeout_ms)" + "transaction timeout occurred (check that processing time is " + "below transaction_timeout_ms)", ): super().__init__(msg) +class BrokerResponseError(KafkaError): + errno = None + message = None + description = None + + def __str__(self): + """Add errno to standard KafkaError str""" + return "[Error {0}] {1}".format( + self.errno, super(BrokerResponseError, self).__str__() + ) + + +class NoError(BrokerResponseError): + errno = 0 + message = "NO_ERROR" + description = "No error--it worked!" + + +class UnknownError(BrokerResponseError): + errno = -1 + message = "UNKNOWN" + description = "An unexpected server error." + + +class OffsetOutOfRangeError(BrokerResponseError): + errno = 1 + message = "OFFSET_OUT_OF_RANGE" + description = ( + "The requested offset is outside the range of offsets" + " maintained by the server for the given topic/partition." + ) + + +class CorruptRecordException(BrokerResponseError): + errno = 2 + message = "CORRUPT_MESSAGE" + description = ( + "This message has failed its CRC checksum, exceeds the" + " valid size, or is otherwise corrupt." + ) + + +# Backward compatibility +InvalidMessageError = CorruptRecordException + + +class UnknownTopicOrPartitionError(BrokerResponseError): + errno = 3 + message = "UNKNOWN_TOPIC_OR_PARTITION" + description = ( + "This request is for a topic or partition that does not" + " exist on this broker." + ) + retriable = True + invalid_metadata = True + + +class InvalidFetchRequestError(BrokerResponseError): + errno = 4 + message = "INVALID_FETCH_SIZE" + description = "The message has a negative size." + + +class LeaderNotAvailableError(BrokerResponseError): + errno = 5 + message = "LEADER_NOT_AVAILABLE" + description = ( + "This error is thrown if we are in the middle of a" + " leadership election and there is currently no leader for" + " this partition and hence it is unavailable for writes." + ) + retriable = True + invalid_metadata = True + + +class NotLeaderForPartitionError(BrokerResponseError): + errno = 6 + message = "NOT_LEADER_FOR_PARTITION" + description = ( + "This error is thrown if the client attempts to send" + " messages to a replica that is not the leader for some" + " partition. It indicates that the clients metadata is out" + " of date." + ) + retriable = True + invalid_metadata = True + + +class RequestTimedOutError(BrokerResponseError): + errno = 7 + message = "REQUEST_TIMED_OUT" + description = ( + "This error is thrown if the request exceeds the" + " user-specified time limit in the request." + ) + retriable = True + + +class BrokerNotAvailableError(BrokerResponseError): + errno = 8 + message = "BROKER_NOT_AVAILABLE" + description = ( + "This is not a client facing error and is used mostly by" + " tools when a broker is not alive." + ) + + +class ReplicaNotAvailableError(BrokerResponseError): + errno = 9 + message = "REPLICA_NOT_AVAILABLE" + description = ( + "If replica is expected on a broker, but is not (this can be" + " safely ignored)." + ) + + +class MessageSizeTooLargeError(BrokerResponseError): + errno = 10 + message = "MESSAGE_SIZE_TOO_LARGE" + description = ( + "The server has a configurable maximum message size to avoid" + " unbounded memory allocation. This error is thrown if the" + " client attempt to produce a message larger than this" + " maximum." 
+ ) + + +class StaleControllerEpochError(BrokerResponseError): + errno = 11 + message = "STALE_CONTROLLER_EPOCH" + description = "Internal error code for broker-to-broker communication." + + +class OffsetMetadataTooLargeError(BrokerResponseError): + errno = 12 + message = "OFFSET_METADATA_TOO_LARGE" + description = ( + "If you specify a string larger than configured maximum for" " offset metadata." + ) + + +# TODO is this deprecated? +# https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-ErrorCodes +class StaleLeaderEpochCodeError(BrokerResponseError): + errno = 13 + message = "STALE_LEADER_EPOCH_CODE" + + +class GroupLoadInProgressError(BrokerResponseError): + errno = 14 + message = "COORDINATOR_LOAD_IN_PROGRESS" + description = ( + "The broker returns this error code for an offset fetch" + " request if it is still loading offsets (after a leader" + " change for that offsets topic partition), or in response" + " to group membership requests (such as heartbeats) when" + " group metadata is being loaded by the coordinator." + ) + retriable = True + + +CoordinatorLoadInProgressError = GroupLoadInProgressError + + +class GroupCoordinatorNotAvailableError(BrokerResponseError): + errno = 15 + message = "COORDINATOR_NOT_AVAILABLE" + description = ( + "The broker returns this error code for group coordinator" + " requests, offset commits, and most group management" + " requests if the offsets topic has not yet been created, or" + " if the group coordinator is not active." + ) + retriable = True + + +CoordinatorNotAvailableError = GroupCoordinatorNotAvailableError + + +class NotCoordinatorForGroupError(BrokerResponseError): + errno = 16 + message = "NOT_COORDINATOR" + description = ( + "The broker returns this error code if it receives an offset" + " fetch or commit request for a group that it is not a" + " coordinator for." + ) + retriable = True + + +NotCoordinatorError = NotCoordinatorForGroupError + + +class InvalidTopicError(BrokerResponseError): + errno = 17 + message = "INVALID_TOPIC" + description = ( + "For a request which attempts to access an invalid topic" + " (e.g. one which has an illegal name), or if an attempt" + " is made to write to an internal topic (such as the" + " consumer offsets topic)." + ) + + +class RecordListTooLargeError(BrokerResponseError): + errno = 18 + message = "RECORD_LIST_TOO_LARGE" + description = ( + "If a message batch in a produce request exceeds the maximum" + " configured segment size." + ) + + +class NotEnoughReplicasError(BrokerResponseError): + errno = 19 + message = "NOT_ENOUGH_REPLICAS" + description = ( + "Returned from a produce request when the number of in-sync" + " replicas is lower than the configured minimum and" + " requiredAcks is -1." + ) + retriable = True + + +class NotEnoughReplicasAfterAppendError(BrokerResponseError): + errno = 20 + message = "NOT_ENOUGH_REPLICAS_AFTER_APPEND" + description = ( + "Returned from a produce request when the message was" + " written to the log, but with fewer in-sync replicas than" + " required." + ) + retriable = True + + +class InvalidRequiredAcksError(BrokerResponseError): + errno = 21 + message = "INVALID_REQUIRED_ACKS" + description = ( + "Returned from a produce request if the requested" + " requiredAcks is invalid (anything other than -1, 1, or 0)." 
+ ) + + +class IllegalGenerationError(BrokerResponseError): + errno = 22 + message = "ILLEGAL_GENERATION" + description = ( + "Returned from group membership requests (such as heartbeats)" + " when the generation id provided in the request is not the" + " current generation." + ) + + +class InconsistentGroupProtocolError(BrokerResponseError): + errno = 23 + message = "INCONSISTENT_GROUP_PROTOCOL" + description = ( + "Returned in join group when the member provides a protocol" + " type or set of protocols which is not compatible with the" + " current group." + ) + + +class InvalidGroupIdError(BrokerResponseError): + errno = 24 + message = "INVALID_GROUP_ID" + description = "Returned in join group when the groupId is empty or null." + + +class UnknownMemberIdError(BrokerResponseError): + errno = 25 + message = "UNKNOWN_MEMBER_ID" + description = ( + "Returned from group requests (offset commits/fetches," + " heartbeats, etc) when the memberId is not in the current" + " generation." + ) + + +class InvalidSessionTimeoutError(BrokerResponseError): + errno = 26 + message = "INVALID_SESSION_TIMEOUT" + description = ( + "Return in join group when the requested session timeout is" + " outside of the allowed range on the broker" + ) + + +class RebalanceInProgressError(BrokerResponseError): + errno = 27 + message = "REBALANCE_IN_PROGRESS" + description = ( + "Returned in heartbeat requests when the coordinator has" + " begun rebalancing the group. This indicates to the client" + " that it should rejoin the group." + ) + + +class InvalidCommitOffsetSizeError(BrokerResponseError): + errno = 28 + message = "INVALID_COMMIT_OFFSET_SIZE" + description = ( + "This error indicates that an offset commit was rejected" + " because of oversize metadata." + ) + + +class TopicAuthorizationFailedError(BrokerResponseError): + errno = 29 + message = "TOPIC_AUTHORIZATION_FAILED" + description = ( + "Returned by the broker when the client is not authorized to" + " access the requested topic." + ) + + +class GroupAuthorizationFailedError(BrokerResponseError): + errno = 30 + message = "GROUP_AUTHORIZATION_FAILED" + description = ( + "Returned by the broker when the client is not authorized to" + " access a particular groupId." + ) + + +class ClusterAuthorizationFailedError(BrokerResponseError): + errno = 31 + message = "CLUSTER_AUTHORIZATION_FAILED" + description = ( + "Returned by the broker when the client is not authorized to" + " use an inter-broker or administrative API." + ) + + +class InvalidTimestampError(BrokerResponseError): + errno = 32 + message = "INVALID_TIMESTAMP" + description = "The timestamp of the message is out of acceptable range." + + +class UnsupportedSaslMechanismError(BrokerResponseError): + errno = 33 + message = "UNSUPPORTED_SASL_MECHANISM" + description = "The broker does not support the requested SASL mechanism." + + +class IllegalSaslStateError(BrokerResponseError): + errno = 34 + message = "ILLEGAL_SASL_STATE" + description = "Request is not valid given the current SASL state." + + +class UnsupportedVersionError(BrokerResponseError): + errno = 35 + message = "UNSUPPORTED_VERSION" + description = "The version of API is not supported." + + +class TopicAlreadyExistsError(BrokerResponseError): + errno = 36 + message = "TOPIC_ALREADY_EXISTS" + description = "Topic with this name already exists." + + +class InvalidPartitionsError(BrokerResponseError): + errno = 37 + message = "INVALID_PARTITIONS" + description = "Number of partitions is invalid." 
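For orientation while these classes move, here is a minimal, illustrative sketch of how the numbered errors are meant to be consumed. `raise_for_error_code` is a hypothetical helper name, and the sketch assumes the `for_code` lookup (which the admin client earlier in this patch imports from `aiokafka.errors`) keeps the behaviour of the removed `kafka.errors.for_code`:

    import aiokafka.errors as Errors

    def raise_for_error_code(error_code, context=""):
        # for_code() looks up the BrokerResponseError subclass registered for
        # this errno; unknown codes fall back to UnknownError.
        exc_class = Errors.for_code(error_code)
        if exc_class is not Errors.NoError:
            raise exc_class(context)

    # Error code 37 maps to the InvalidPartitionsError class defined just above.
    try:
        raise_for_error_code(37, "create_topics('demo', num_partitions=0)")
    except Errors.InvalidPartitionsError as exc:
        print(exc)  # [Error 37] InvalidPartitionsError: create_topics('demo', num_partitions=0)
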
+ + +class InvalidReplicationFactorError(BrokerResponseError): + errno = 38 + message = "INVALID_REPLICATION_FACTOR" + description = "Replication-factor is invalid." + + +class InvalidReplicationAssignmentError(BrokerResponseError): + errno = 39 + message = "INVALID_REPLICATION_ASSIGNMENT" + description = "Replication assignment is invalid." + + +class InvalidConfigurationError(BrokerResponseError): + errno = 40 + message = "INVALID_CONFIG" + description = "Configuration is invalid." + + +class NotControllerError(BrokerResponseError): + errno = 41 + message = "NOT_CONTROLLER" + description = "This is not the correct controller for this cluster." + retriable = True + + +class InvalidRequestError(BrokerResponseError): + errno = 42 + message = "INVALID_REQUEST" + description = ( + "This most likely occurs because of a request being" + " malformed by the client library or the message was" + " sent to an incompatible broker. See the broker logs" + " for more details." + ) + + +class UnsupportedForMessageFormatError(BrokerResponseError): + errno = 43 + message = "UNSUPPORTED_FOR_MESSAGE_FORMAT" + description = ( + "The message format version on the broker does not" " support this request." + ) + + +class PolicyViolationError(BrokerResponseError): + errno = 44 + message = "POLICY_VIOLATION" + description = "Request parameters do not satisfy the configured policy." + + class OutOfOrderSequenceNumber(BrokerResponseError): errno = 45 - message = 'OUT_OF_ORDER_SEQUENCE_NUMBER' - description = 'The broker received an out of order sequence number' + message = "OUT_OF_ORDER_SEQUENCE_NUMBER" + description = "The broker received an out of order sequence number" class DuplicateSequenceNumber(BrokerResponseError): errno = 46 - message = 'DUPLICATE_SEQUENCE_NUMBER' - description = 'The broker received a duplicate sequence number' + message = "DUPLICATE_SEQUENCE_NUMBER" + description = "The broker received a duplicate sequence number" class InvalidProducerEpoch(BrokerResponseError): errno = 47 - message = 'INVALID_PRODUCER_EPOCH' + message = "INVALID_PRODUCER_EPOCH" description = ( - 'Producer attempted an operation with an old epoch. Either ' - 'there is a newer producer with the same transactionalId, or the ' - 'producer\'s transaction has been expired by the broker.' + "Producer attempted an operation with an old epoch. Either " + "there is a newer producer with the same transactionalId, or the " + "producer's transaction has been expired by the broker." ) class InvalidTxnState(BrokerResponseError): errno = 48 - message = 'INVALID_TXN_STATE' - description = ( - 'The producer attempted a transactional operation in an invalid state' - ) + message = "INVALID_TXN_STATE" + description = "The producer attempted a transactional operation in an invalid state" class InvalidProducerIdMapping(BrokerResponseError): errno = 49 - message = 'INVALID_PRODUCER_ID_MAPPING' + message = "INVALID_PRODUCER_ID_MAPPING" description = ( - 'The producer attempted to use a producer id which is not currently ' - 'assigned to its transactional id' + "The producer attempted to use a producer id which is not currently " + "assigned to its transactional id" ) class InvalidTransactionTimeout(BrokerResponseError): errno = 50 - message = 'INVALID_TRANSACTION_TIMEOUT' + message = "INVALID_TRANSACTION_TIMEOUT" description = ( - 'The transaction timeout is larger than the maximum value allowed by' - ' the broker (as configured by transaction.max.timeout.ms).' 
+ "The transaction timeout is larger than the maximum value allowed by" + " the broker (as configured by transaction.max.timeout.ms)." ) class ConcurrentTransactions(BrokerResponseError): errno = 51 - message = 'CONCURRENT_TRANSACTIONS' + message = "CONCURRENT_TRANSACTIONS" description = ( - 'The producer attempted to update a transaction while another ' - 'concurrent operation on the same transaction was ongoing' + "The producer attempted to update a transaction while another " + "concurrent operation on the same transaction was ongoing" ) class TransactionCoordinatorFenced(BrokerResponseError): errno = 52 - message = 'TRANSACTION_COORDINATOR_FENCED' + message = "TRANSACTION_COORDINATOR_FENCED" description = ( - 'Indicates that the transaction coordinator sending a WriteTxnMarker' - ' is no longer the current coordinator for a given producer' + "Indicates that the transaction coordinator sending a WriteTxnMarker" + " is no longer the current coordinator for a given producer" ) class TransactionalIdAuthorizationFailed(BrokerResponseError): errno = 53 - message = 'TRANSACTIONAL_ID_AUTHORIZATION_FAILED' - description = 'Transactional Id authorization failed' + message = "TRANSACTIONAL_ID_AUTHORIZATION_FAILED" + description = "Transactional Id authorization failed" class SecurityDisabled(BrokerResponseError): errno = 54 - message = 'SECURITY_DISABLED' - description = 'Security features are disabled' + message = "SECURITY_DISABLED" + description = "Security features are disabled" class OperationNotAttempted(BrokerResponseError): errno = 55 - message = 'OPERATION_NOT_ATTEMPTED' + message = "OPERATION_NOT_ATTEMPTED" description = ( - 'The broker did not attempt to execute this operation. This may happen' - ' for batched RPCs where some operations in the batch failed, causing ' - 'the broker to respond without trying the rest.' + "The broker did not attempt to execute this operation. This may happen" + " for batched RPCs where some operations in the batch failed, causing " + "the broker to respond without trying the rest." ) class KafkaStorageError(BrokerResponseError): errno = 56 - message = 'KAFKA_STORAGE_ERROR' - description = ( - 'The user-specified log directory is not found in the broker config.' - ) + message = "KAFKA_STORAGE_ERROR" + description = "The user-specified log directory is not found in the broker config." class LogDirNotFound(BrokerResponseError): errno = 57 - message = 'LOG_DIR_NOT_FOUND' - description = ( - 'The user-specified log directory is not found in the broker config.' - ) + message = "LOG_DIR_NOT_FOUND" + description = "The user-specified log directory is not found in the broker config." class SaslAuthenticationFailed(BrokerResponseError): errno = 58 - message = 'SASL_AUTHENTICATION_FAILED' - description = 'SASL Authentication failed.' + message = "SASL_AUTHENTICATION_FAILED" + description = "SASL Authentication failed." class UnknownProducerId(BrokerResponseError): errno = 59 - message = 'UNKNOWN_PRODUCER_ID' + message = "UNKNOWN_PRODUCER_ID" description = ( - 'This exception is raised by the broker if it could not locate the ' - 'producer metadata associated with the producerId in question. This ' - 'could happen if, for instance, the producer\'s records were deleted ' - 'because their retention time had elapsed. Once the last records of ' - 'the producerId are removed, the producer\'s metadata is removed from' - ' the broker, and future appends by the producer will return this ' - 'exception.' 
+ "This exception is raised by the broker if it could not locate the " + "producer metadata associated with the producerId in question. This " + "could happen if, for instance, the producer's records were deleted " + "because their retention time had elapsed. Once the last records of " + "the producerId are removed, the producer's metadata is removed from" + " the broker, and future appends by the producer will return this " + "exception." ) class ReassignmentInProgress(BrokerResponseError): errno = 60 - message = 'REASSIGNMENT_IN_PROGRESS' - description = 'A partition reassignment is in progress' + message = "REASSIGNMENT_IN_PROGRESS" + description = "A partition reassignment is in progress" class DelegationTokenAuthDisabled(BrokerResponseError): errno = 61 - message = 'DELEGATION_TOKEN_AUTH_DISABLED' - description = 'Delegation Token feature is not enabled' + message = "DELEGATION_TOKEN_AUTH_DISABLED" + description = "Delegation Token feature is not enabled" class DelegationTokenNotFound(BrokerResponseError): errno = 62 - message = 'DELEGATION_TOKEN_NOT_FOUND' - description = 'Delegation Token is not found on server.' + message = "DELEGATION_TOKEN_NOT_FOUND" + description = "Delegation Token is not found on server." class DelegationTokenOwnerMismatch(BrokerResponseError): errno = 63 - message = 'DELEGATION_TOKEN_OWNER_MISMATCH' - description = 'Specified Principal is not valid Owner/Renewer.' + message = "DELEGATION_TOKEN_OWNER_MISMATCH" + description = "Specified Principal is not valid Owner/Renewer." class DelegationTokenRequestNotAllowed(BrokerResponseError): errno = 64 - message = 'DELEGATION_TOKEN_REQUEST_NOT_ALLOWED' + message = "DELEGATION_TOKEN_REQUEST_NOT_ALLOWED" description = ( - 'Delegation Token requests are not allowed on PLAINTEXT/1-way SSL ' - 'channels and on delegation token authenticated channels.' + "Delegation Token requests are not allowed on PLAINTEXT/1-way SSL " + "channels and on delegation token authenticated channels." ) class DelegationTokenAuthorizationFailed(BrokerResponseError): errno = 65 - message = 'DELEGATION_TOKEN_AUTHORIZATION_FAILED' - description = 'Delegation Token authorization failed.' + message = "DELEGATION_TOKEN_AUTHORIZATION_FAILED" + description = "Delegation Token authorization failed." class DelegationTokenExpired(BrokerResponseError): errno = 66 - message = 'DELEGATION_TOKEN_EXPIRED' - description = 'Delegation Token is expired.' + message = "DELEGATION_TOKEN_EXPIRED" + description = "Delegation Token is expired." 
class InvalidPrincipalType(BrokerResponseError): errno = 67 - message = 'INVALID_PRINCIPAL_TYPE' - description = 'Supplied principalType is not supported' + message = "INVALID_PRINCIPAL_TYPE" + description = "Supplied principalType is not supported" class NonEmptyGroup(BrokerResponseError): errno = 68 - message = 'NON_EMPTY_GROUP' - description = 'The group is not empty' + message = "NON_EMPTY_GROUP" + description = "The group is not empty" class GroupIdNotFound(BrokerResponseError): errno = 69 - message = 'GROUP_ID_NOT_FOUND' - description = 'The group id does not exist' + message = "GROUP_ID_NOT_FOUND" + description = "The group id does not exist" class FetchSessionIdNotFound(BrokerResponseError): errno = 70 - message = 'FETCH_SESSION_ID_NOT_FOUND' - description = 'The fetch session ID was not found' + message = "FETCH_SESSION_ID_NOT_FOUND" + description = "The fetch session ID was not found" class InvalidFetchSessionEpoch(BrokerResponseError): errno = 71 - message = 'INVALID_FETCH_SESSION_EPOCH' - description = 'The fetch session epoch is invalid' + message = "INVALID_FETCH_SESSION_EPOCH" + description = "The fetch session epoch is invalid" class ListenerNotFound(BrokerResponseError): errno = 72 - message = 'LISTENER_NOT_FOUND' + message = "LISTENER_NOT_FOUND" description = ( - 'There is no listener on the leader broker that matches the' - ' listener on which metadata request was processed' + "There is no listener on the leader broker that matches the" + " listener on which metadata request was processed" ) def _iter_broker_errors(): for name, obj in inspect.getmembers(sys.modules[__name__]): - if inspect.isclass(obj) and issubclass(obj, BrokerResponseError) and \ - obj != BrokerResponseError: + if ( + inspect.isclass(obj) + and issubclass(obj, BrokerResponseError) + and obj != BrokerResponseError + ): yield obj diff --git a/docs/api.rst b/docs/api.rst index ec616db9..1a404d2d 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -144,21 +144,6 @@ Errors :members: -.. autoclass:: aiokafka.errors.KafkaTimeoutError -.. autoclass:: aiokafka.errors.RequestTimedOutError -.. autoclass:: aiokafka.errors.NotEnoughReplicasError -.. autoclass:: aiokafka.errors.NotEnoughReplicasAfterAppendError -.. autoclass:: aiokafka.errors.KafkaError -.. autoclass:: aiokafka.errors.UnsupportedVersionError -.. autoclass:: aiokafka.errors.TopicAuthorizationFailedError -.. autoclass:: aiokafka.errors.OffsetOutOfRangeError -.. autoclass:: aiokafka.errors.CorruptRecordException -.. autoclass:: kafka.errors.CorruptRecordException -.. autoclass:: aiokafka.errors.InvalidMessageError -.. autoclass:: aiokafka.errors.IllegalStateError -.. 
autoclass:: aiokafka.errors.CommitFailedError - - Structs ^^^^^^^ diff --git a/kafka/cluster.py b/kafka/cluster.py index 438baf29..f6d5e510 100644 --- a/kafka/cluster.py +++ b/kafka/cluster.py @@ -8,7 +8,7 @@ from kafka.vendor import six -from kafka import errors as Errors +from aiokafka import errors as Errors from kafka.conn import collect_hosts from kafka.future import Future from kafka.structs import BrokerMetadata, PartitionMetadata, TopicPartition diff --git a/kafka/conn.py b/kafka/conn.py index cac35487..3edd1915 100644 --- a/kafka/conn.py +++ b/kafka/conn.py @@ -20,7 +20,7 @@ from kafka.vendor import six -import kafka.errors as Errors +import aiokafka.errors as Errors from kafka.future import Future from kafka.metrics.stats import Avg, Count, Max, Rate from kafka.oauth.abstract import AbstractTokenProvider @@ -909,7 +909,7 @@ def close(self, error=None): Arguments: error (Exception, optional): pending in-flight-requests will be failed with this exception. - Default: kafka.errors.KafkaConnectionError. + Default: aiokafka.errors.KafkaConnectionError. """ if self.state is ConnectionStates.DISCONNECTED: return diff --git a/kafka/coordinator/consumer.py b/kafka/coordinator/consumer.py index 971f5e80..6f0de2db 100644 --- a/kafka/coordinator/consumer.py +++ b/kafka/coordinator/consumer.py @@ -13,7 +13,7 @@ from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor from kafka.coordinator.assignors.sticky.sticky_assignor import StickyPartitionAssignor from kafka.coordinator.protocol import ConsumerProtocol -import kafka.errors as Errors +import aiokafka.errors as Errors from kafka.future import Future from kafka.metrics import AnonMeasurable from kafka.metrics.stats import Avg, Count, Max, Rate diff --git a/kafka/errors.py b/kafka/errors.py deleted file mode 100644 index b33cf51e..00000000 --- a/kafka/errors.py +++ /dev/null @@ -1,538 +0,0 @@ -from __future__ import absolute_import - -import inspect -import sys - - -class KafkaError(RuntimeError): - retriable = False - # whether metadata should be refreshed on error - invalid_metadata = False - - def __str__(self): - if not self.args: - return self.__class__.__name__ - return '{0}: {1}'.format(self.__class__.__name__, - super(KafkaError, self).__str__()) - - -class IllegalStateError(KafkaError): - pass - - -class IllegalArgumentError(KafkaError): - pass - - -class NoBrokersAvailable(KafkaError): - retriable = True - invalid_metadata = True - - -class NodeNotReadyError(KafkaError): - retriable = True - - -class KafkaProtocolError(KafkaError): - retriable = True - - -class CorrelationIdError(KafkaProtocolError): - retriable = True - - -class Cancelled(KafkaError): - retriable = True - - -class TooManyInFlightRequests(KafkaError): - retriable = True - - -class StaleMetadata(KafkaError): - retriable = True - invalid_metadata = True - - -class MetadataEmptyBrokerList(KafkaError): - retriable = True - - -class UnrecognizedBrokerVersion(KafkaError): - pass - - -class IncompatibleBrokerVersion(KafkaError): - pass - - -class CommitFailedError(KafkaError): - def __init__(self, *args, **kwargs): - super(CommitFailedError, self).__init__( - """Commit cannot be completed since the group has already - rebalanced and assigned the partitions to another member. - This means that the time between subsequent calls to poll() - was longer than the configured max_poll_interval_ms, which - typically implies that the poll loop is spending too much - time message processing. 
You can address this either by - increasing the rebalance timeout with max_poll_interval_ms, - or by reducing the maximum size of batches returned in poll() - with max_poll_records. - """, *args, **kwargs) - - -class AuthenticationMethodNotSupported(KafkaError): - pass - - -class AuthenticationFailedError(KafkaError): - retriable = False - - -class BrokerResponseError(KafkaError): - errno = None - message = None - description = None - - def __str__(self): - """Add errno to standard KafkaError str""" - return '[Error {0}] {1}'.format( - self.errno, - super(BrokerResponseError, self).__str__()) - - -class NoError(BrokerResponseError): - errno = 0 - message = 'NO_ERROR' - description = 'No error--it worked!' - - -class UnknownError(BrokerResponseError): - errno = -1 - message = 'UNKNOWN' - description = 'An unexpected server error.' - - -class OffsetOutOfRangeError(BrokerResponseError): - errno = 1 - message = 'OFFSET_OUT_OF_RANGE' - description = ('The requested offset is outside the range of offsets' - ' maintained by the server for the given topic/partition.') - - -class CorruptRecordException(BrokerResponseError): - errno = 2 - message = 'CORRUPT_MESSAGE' - description = ('This message has failed its CRC checksum, exceeds the' - ' valid size, or is otherwise corrupt.') - -# Backward compatibility -InvalidMessageError = CorruptRecordException - - -class UnknownTopicOrPartitionError(BrokerResponseError): - errno = 3 - message = 'UNKNOWN_TOPIC_OR_PARTITION' - description = ('This request is for a topic or partition that does not' - ' exist on this broker.') - retriable = True - invalid_metadata = True - - -class InvalidFetchRequestError(BrokerResponseError): - errno = 4 - message = 'INVALID_FETCH_SIZE' - description = 'The message has a negative size.' - - -class LeaderNotAvailableError(BrokerResponseError): - errno = 5 - message = 'LEADER_NOT_AVAILABLE' - description = ('This error is thrown if we are in the middle of a' - ' leadership election and there is currently no leader for' - ' this partition and hence it is unavailable for writes.') - retriable = True - invalid_metadata = True - - -class NotLeaderForPartitionError(BrokerResponseError): - errno = 6 - message = 'NOT_LEADER_FOR_PARTITION' - description = ('This error is thrown if the client attempts to send' - ' messages to a replica that is not the leader for some' - ' partition. It indicates that the clients metadata is out' - ' of date.') - retriable = True - invalid_metadata = True - - -class RequestTimedOutError(BrokerResponseError): - errno = 7 - message = 'REQUEST_TIMED_OUT' - description = ('This error is thrown if the request exceeds the' - ' user-specified time limit in the request.') - retriable = True - - -class BrokerNotAvailableError(BrokerResponseError): - errno = 8 - message = 'BROKER_NOT_AVAILABLE' - description = ('This is not a client facing error and is used mostly by' - ' tools when a broker is not alive.') - - -class ReplicaNotAvailableError(BrokerResponseError): - errno = 9 - message = 'REPLICA_NOT_AVAILABLE' - description = ('If replica is expected on a broker, but is not (this can be' - ' safely ignored).') - - -class MessageSizeTooLargeError(BrokerResponseError): - errno = 10 - message = 'MESSAGE_SIZE_TOO_LARGE' - description = ('The server has a configurable maximum message size to avoid' - ' unbounded memory allocation. 
This error is thrown if the' - ' client attempt to produce a message larger than this' - ' maximum.') - - -class StaleControllerEpochError(BrokerResponseError): - errno = 11 - message = 'STALE_CONTROLLER_EPOCH' - description = 'Internal error code for broker-to-broker communication.' - - -class OffsetMetadataTooLargeError(BrokerResponseError): - errno = 12 - message = 'OFFSET_METADATA_TOO_LARGE' - description = ('If you specify a string larger than configured maximum for' - ' offset metadata.') - - -# TODO is this deprecated? https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-ErrorCodes -class StaleLeaderEpochCodeError(BrokerResponseError): - errno = 13 - message = 'STALE_LEADER_EPOCH_CODE' - - -class GroupLoadInProgressError(BrokerResponseError): - errno = 14 - message = 'OFFSETS_LOAD_IN_PROGRESS' - description = ('The broker returns this error code for an offset fetch' - ' request if it is still loading offsets (after a leader' - ' change for that offsets topic partition), or in response' - ' to group membership requests (such as heartbeats) when' - ' group metadata is being loaded by the coordinator.') - retriable = True - - -class GroupCoordinatorNotAvailableError(BrokerResponseError): - errno = 15 - message = 'CONSUMER_COORDINATOR_NOT_AVAILABLE' - description = ('The broker returns this error code for group coordinator' - ' requests, offset commits, and most group management' - ' requests if the offsets topic has not yet been created, or' - ' if the group coordinator is not active.') - retriable = True - - -class NotCoordinatorForGroupError(BrokerResponseError): - errno = 16 - message = 'NOT_COORDINATOR_FOR_CONSUMER' - description = ('The broker returns this error code if it receives an offset' - ' fetch or commit request for a group that it is not a' - ' coordinator for.') - retriable = True - - -class InvalidTopicError(BrokerResponseError): - errno = 17 - message = 'INVALID_TOPIC' - description = ('For a request which attempts to access an invalid topic' - ' (e.g. 
one which has an illegal name), or if an attempt' - ' is made to write to an internal topic (such as the' - ' consumer offsets topic).') - - -class RecordListTooLargeError(BrokerResponseError): - errno = 18 - message = 'RECORD_LIST_TOO_LARGE' - description = ('If a message batch in a produce request exceeds the maximum' - ' configured segment size.') - - -class NotEnoughReplicasError(BrokerResponseError): - errno = 19 - message = 'NOT_ENOUGH_REPLICAS' - description = ('Returned from a produce request when the number of in-sync' - ' replicas is lower than the configured minimum and' - ' requiredAcks is -1.') - retriable = True - - -class NotEnoughReplicasAfterAppendError(BrokerResponseError): - errno = 20 - message = 'NOT_ENOUGH_REPLICAS_AFTER_APPEND' - description = ('Returned from a produce request when the message was' - ' written to the log, but with fewer in-sync replicas than' - ' required.') - retriable = True - - -class InvalidRequiredAcksError(BrokerResponseError): - errno = 21 - message = 'INVALID_REQUIRED_ACKS' - description = ('Returned from a produce request if the requested' - ' requiredAcks is invalid (anything other than -1, 1, or 0).') - - -class IllegalGenerationError(BrokerResponseError): - errno = 22 - message = 'ILLEGAL_GENERATION' - description = ('Returned from group membership requests (such as heartbeats)' - ' when the generation id provided in the request is not the' - ' current generation.') - - -class InconsistentGroupProtocolError(BrokerResponseError): - errno = 23 - message = 'INCONSISTENT_GROUP_PROTOCOL' - description = ('Returned in join group when the member provides a protocol' - ' type or set of protocols which is not compatible with the' - ' current group.') - - -class InvalidGroupIdError(BrokerResponseError): - errno = 24 - message = 'INVALID_GROUP_ID' - description = 'Returned in join group when the groupId is empty or null.' - - -class UnknownMemberIdError(BrokerResponseError): - errno = 25 - message = 'UNKNOWN_MEMBER_ID' - description = ('Returned from group requests (offset commits/fetches,' - ' heartbeats, etc) when the memberId is not in the current' - ' generation.') - - -class InvalidSessionTimeoutError(BrokerResponseError): - errno = 26 - message = 'INVALID_SESSION_TIMEOUT' - description = ('Return in join group when the requested session timeout is' - ' outside of the allowed range on the broker') - - -class RebalanceInProgressError(BrokerResponseError): - errno = 27 - message = 'REBALANCE_IN_PROGRESS' - description = ('Returned in heartbeat requests when the coordinator has' - ' begun rebalancing the group. 
This indicates to the client' - ' that it should rejoin the group.') - - -class InvalidCommitOffsetSizeError(BrokerResponseError): - errno = 28 - message = 'INVALID_COMMIT_OFFSET_SIZE' - description = ('This error indicates that an offset commit was rejected' - ' because of oversize metadata.') - - -class TopicAuthorizationFailedError(BrokerResponseError): - errno = 29 - message = 'TOPIC_AUTHORIZATION_FAILED' - description = ('Returned by the broker when the client is not authorized to' - ' access the requested topic.') - - -class GroupAuthorizationFailedError(BrokerResponseError): - errno = 30 - message = 'GROUP_AUTHORIZATION_FAILED' - description = ('Returned by the broker when the client is not authorized to' - ' access a particular groupId.') - - -class ClusterAuthorizationFailedError(BrokerResponseError): - errno = 31 - message = 'CLUSTER_AUTHORIZATION_FAILED' - description = ('Returned by the broker when the client is not authorized to' - ' use an inter-broker or administrative API.') - - -class InvalidTimestampError(BrokerResponseError): - errno = 32 - message = 'INVALID_TIMESTAMP' - description = 'The timestamp of the message is out of acceptable range.' - - -class UnsupportedSaslMechanismError(BrokerResponseError): - errno = 33 - message = 'UNSUPPORTED_SASL_MECHANISM' - description = 'The broker does not support the requested SASL mechanism.' - - -class IllegalSaslStateError(BrokerResponseError): - errno = 34 - message = 'ILLEGAL_SASL_STATE' - description = 'Request is not valid given the current SASL state.' - - -class UnsupportedVersionError(BrokerResponseError): - errno = 35 - message = 'UNSUPPORTED_VERSION' - description = 'The version of API is not supported.' - - -class TopicAlreadyExistsError(BrokerResponseError): - errno = 36 - message = 'TOPIC_ALREADY_EXISTS' - description = 'Topic with this name already exists.' - - -class InvalidPartitionsError(BrokerResponseError): - errno = 37 - message = 'INVALID_PARTITIONS' - description = 'Number of partitions is invalid.' - - -class InvalidReplicationFactorError(BrokerResponseError): - errno = 38 - message = 'INVALID_REPLICATION_FACTOR' - description = 'Replication-factor is invalid.' - - -class InvalidReplicationAssignmentError(BrokerResponseError): - errno = 39 - message = 'INVALID_REPLICATION_ASSIGNMENT' - description = 'Replication assignment is invalid.' - - -class InvalidConfigurationError(BrokerResponseError): - errno = 40 - message = 'INVALID_CONFIG' - description = 'Configuration is invalid.' - - -class NotControllerError(BrokerResponseError): - errno = 41 - message = 'NOT_CONTROLLER' - description = 'This is not the correct controller for this cluster.' - retriable = True - - -class InvalidRequestError(BrokerResponseError): - errno = 42 - message = 'INVALID_REQUEST' - description = ('This most likely occurs because of a request being' - ' malformed by the client library or the message was' - ' sent to an incompatible broker. See the broker logs' - ' for more details.') - - -class UnsupportedForMessageFormatError(BrokerResponseError): - errno = 43 - message = 'UNSUPPORTED_FOR_MESSAGE_FORMAT' - description = ('The message format version on the broker does not' - ' support this request.') - - -class PolicyViolationError(BrokerResponseError): - errno = 44 - message = 'POLICY_VIOLATION' - description = 'Request parameters do not satisfy the configured policy.' - - -class SecurityDisabledError(BrokerResponseError): - errno = 54 - message = 'SECURITY_DISABLED' - description = 'Security features are disabled.' 
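Because this module is removed outright, the surrounding hunks rewrite `kafka.errors` imports elsewhere in the tree to `aiokafka.errors`. A representative before/after sketch, using classes that exist in both the deleted and the new module:

    # Before this patch (kafka/errors.py, now deleted):
    # from kafka.errors import KafkaConnectionError, QuotaViolationError

    # After this patch the same classes come from aiokafka.errors:
    from aiokafka.errors import KafkaConnectionError, QuotaViolationError

    assert KafkaConnectionError.retriable
    assert KafkaConnectionError.invalid_metadata
    assert issubclass(QuotaViolationError, Exception)
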
- - -class NonEmptyGroupError(BrokerResponseError): - errno = 68 - message = 'NON_EMPTY_GROUP' - description = 'The group is not empty.' - - -class GroupIdNotFoundError(BrokerResponseError): - errno = 69 - message = 'GROUP_ID_NOT_FOUND' - description = 'The group id does not exist.' - - -class KafkaUnavailableError(KafkaError): - pass - - -class KafkaTimeoutError(KafkaError): - pass - - -class FailedPayloadsError(KafkaError): - def __init__(self, payload, *args): - super(FailedPayloadsError, self).__init__(*args) - self.payload = payload - - -class KafkaConnectionError(KafkaError): - retriable = True - invalid_metadata = True - - -class ProtocolError(KafkaError): - pass - - -class UnsupportedCodecError(KafkaError): - pass - - -class KafkaConfigurationError(KafkaError): - pass - - -class QuotaViolationError(KafkaError): - pass - - -class AsyncProducerQueueFull(KafkaError): - def __init__(self, failed_msgs, *args): - super(AsyncProducerQueueFull, self).__init__(*args) - self.failed_msgs = failed_msgs - - -def _iter_broker_errors(): - for name, obj in inspect.getmembers(sys.modules[__name__]): - if inspect.isclass(obj) and issubclass(obj, BrokerResponseError) and obj != BrokerResponseError: - yield obj - - -kafka_errors = dict([(x.errno, x) for x in _iter_broker_errors()]) - - -def for_code(error_code): - return kafka_errors.get(error_code, UnknownError) - - -def check_error(response): - if isinstance(response, Exception): - raise response - if response.error: - error_class = kafka_errors.get(response.error, UnknownError) - raise error_class(response) - - -RETRY_BACKOFF_ERROR_TYPES = ( - KafkaUnavailableError, LeaderNotAvailableError, - KafkaConnectionError, FailedPayloadsError -) - - -RETRY_REFRESH_ERROR_TYPES = ( - NotLeaderForPartitionError, UnknownTopicOrPartitionError, - LeaderNotAvailableError, KafkaConnectionError -) - - -RETRY_ERROR_TYPES = RETRY_BACKOFF_ERROR_TYPES + RETRY_REFRESH_ERROR_TYPES diff --git a/kafka/metrics/stats/sensor.py b/kafka/metrics/stats/sensor.py index 571723f9..a0dbe4c1 100644 --- a/kafka/metrics/stats/sensor.py +++ b/kafka/metrics/stats/sensor.py @@ -3,7 +3,7 @@ import threading import time -from kafka.errors import QuotaViolationError +from aiokafka.errors import QuotaViolationError from kafka.metrics import KafkaMetric diff --git a/kafka/protocol/parser.py b/kafka/protocol/parser.py index a9e76722..a872202d 100644 --- a/kafka/protocol/parser.py +++ b/kafka/protocol/parser.py @@ -3,7 +3,7 @@ import collections import logging -import kafka.errors as Errors +import aiokafka.errors as Errors from kafka.protocol.commit import GroupCoordinatorResponse from kafka.protocol.frame import KafkaBytes from kafka.protocol.types import Int32, TaggedFields diff --git a/tests/kafka/fixtures.py b/tests/kafka/fixtures.py index 45e2a053..76bde28f 100644 --- a/tests/kafka/fixtures.py +++ b/tests/kafka/fixtures.py @@ -13,8 +13,8 @@ from kafka.vendor.six.moves import urllib, range from kafka.vendor.six.moves.urllib.parse import urlparse # pylint: disable=E0611,F0401 -from kafka import errors -from kafka.errors import InvalidReplicationFactorError +from aiokafka import errors +from aiokafka.errors import InvalidReplicationFactorError from kafka.protocol.admin import CreateTopicsRequest from kafka.protocol.metadata import MetadataRequest from tests.kafka.testutil import env_kafka_version, random_string diff --git a/tests/kafka/test_conn.py b/tests/kafka/test_conn.py index b49a8bd3..6eb45f45 100644 --- a/tests/kafka/test_conn.py +++ b/tests/kafka/test_conn.py @@ -12,7 +12,7 @@ from 
kafka.protocol.metadata import MetadataRequest from kafka.protocol.produce import ProduceRequest -import kafka.errors as Errors +import aiokafka.errors as Errors @pytest.fixture diff --git a/tests/kafka/test_metrics.py b/tests/kafka/test_metrics.py index 308ea583..64cc1fc1 100644 --- a/tests/kafka/test_metrics.py +++ b/tests/kafka/test_metrics.py @@ -3,7 +3,7 @@ import pytest -from kafka.errors import QuotaViolationError +from aiokafka.errors import QuotaViolationError from kafka.metrics import DictReporter, MetricConfig, MetricName, Metrics, Quota from kafka.metrics.measurable import AbstractMeasurable from kafka.metrics.stats import (Avg, Count, Max, Min, Percentile, Percentiles, diff --git a/tests/record/test_default_records.py b/tests/record/test_default_records.py index f2d6d8cd..455590c9 100644 --- a/tests/record/test_default_records.py +++ b/tests/record/test_default_records.py @@ -1,8 +1,9 @@ from unittest import mock import kafka.codec -from kafka.errors import UnsupportedCodecError import pytest + +from aiokafka.errors import UnsupportedCodecError from aiokafka.record.default_records import ( DefaultRecordBatch, DefaultRecordBatchBuilder ) diff --git a/tests/record/test_legacy.py b/tests/record/test_legacy.py index b3ea4ff5..ee3c6a76 100644 --- a/tests/record/test_legacy.py +++ b/tests/record/test_legacy.py @@ -2,12 +2,12 @@ from unittest import mock import kafka.codec -from kafka.errors import UnsupportedCodecError import pytest + +from aiokafka.errors import CorruptRecordException, UnsupportedCodecError from aiokafka.record.legacy_records import ( LegacyRecordBatch, LegacyRecordBatchBuilder ) -from aiokafka.errors import CorruptRecordException @pytest.mark.parametrize("magic", [0, 1]) diff --git a/tests/test_client.py b/tests/test_client.py index 3fbec2ee..e9ceb517 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,10 +4,6 @@ from typing import Any from unittest import mock -from kafka.errors import ( - KafkaError, KafkaConnectionError, RequestTimedOutError, - NodeNotReadyError, UnrecognizedBrokerVersion -) from kafka.protocol.metadata import ( MetadataRequest_v0 as MetadataRequest, MetadataResponse_v0 as MetadataResponse) @@ -16,6 +12,10 @@ from aiokafka import __version__ from aiokafka.client import AIOKafkaClient, ConnectionGroup, CoordinationType from aiokafka.conn import AIOKafkaConnection, CloseReason +from aiokafka.errors import ( + KafkaError, KafkaConnectionError, RequestTimedOutError, + NodeNotReadyError, UnrecognizedBrokerVersion +) from aiokafka.util import create_task, get_running_loop from ._testutil import ( KafkaIntegrationTestCase, run_until_complete, kafka_versions diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index f8ebb6c0..59b26de9 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -12,12 +12,12 @@ OffsetCommitRequest, OffsetCommitResponse_v2, OffsetFetchRequest_v1 as OffsetFetchRequest ) -import kafka.errors as Errors from ._testutil import KafkaIntegrationTestCase, run_until_complete from aiokafka import ConsumerRebalanceListener from aiokafka.client import AIOKafkaClient +import aiokafka.errors as Errors from aiokafka.structs import OffsetAndMetadata, TopicPartition from aiokafka.consumer.group_coordinator import ( GroupCoordinator, CoordinatorGroupRebalance, NoGroupCoordinator) diff --git a/tests/test_message_accumulator.py b/tests/test_message_accumulator.py index 406526f1..bcce740b 100644 --- a/tests/test_message_accumulator.py +++ b/tests/test_message_accumulator.py @@ -4,11 +4,11 @@ from 
unittest import mock from kafka.cluster import ClusterMetadata -from kafka.errors import (KafkaTimeoutError, - NotLeaderForPartitionError, - LeaderNotAvailableError) from kafka.structs import TopicPartition from ._testutil import run_until_complete +from aiokafka.errors import ( + KafkaTimeoutError, NotLeaderForPartitionError, LeaderNotAvailableError +) from aiokafka.util import create_task, get_running_loop from aiokafka.producer.message_accumulator import ( MessageAccumulator, MessageBatch, BatchBuilder From 04b7bff5ca6acfed758b111053049cc89f2fea3f Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Sun, 22 Oct 2023 18:20:53 +0300 Subject: [PATCH 10/20] Move coordinator --- aiokafka/consumer/consumer.py | 3 +- aiokafka/consumer/group_coordinator.py | 4 +- {kafka => aiokafka}/coordinator/__init__.py | 0 .../coordinator/assignors/__init__.py | 0 .../coordinator/assignors/abstract.py | 7 +- .../coordinator/assignors/range.py | 29 +- .../coordinator/assignors/roundrobin.py | 26 +- .../coordinator/assignors/sticky/__init__.py | 0 .../assignors/sticky/partition_movements.py | 65 +- .../assignors/sticky/sorted_set.py | 8 +- .../assignors/sticky/sticky_assignor.py | 431 +++++++---- {kafka => aiokafka}/coordinator/base.py | 593 +++++++++------ {kafka => aiokafka}/coordinator/consumer.py | 616 +++++++++------ {kafka => aiokafka}/coordinator/heartbeat.py | 37 +- aiokafka/coordinator/protocol.py | 33 + docs/api.rst | 2 +- kafka/coordinator/protocol.py | 33 - tests/coordinator/__init__.py | 0 .../{kafka => coordinator}/test_assignors.py | 714 +++++++++++------- tests/coordinator/test_partition_movements.py | 23 + tests/kafka/test_partition_movements.py | 23 - 21 files changed, 1656 insertions(+), 991 deletions(-) rename {kafka => aiokafka}/coordinator/__init__.py (100%) rename {kafka => aiokafka}/coordinator/assignors/__init__.py (100%) rename {kafka => aiokafka}/coordinator/assignors/abstract.py (86%) rename {kafka => aiokafka}/coordinator/assignors/range.py (79%) rename {kafka => aiokafka}/coordinator/assignors/roundrobin.py (86%) rename {kafka => aiokafka}/coordinator/assignors/sticky/__init__.py (100%) rename {kafka => aiokafka}/coordinator/assignors/sticky/partition_movements.py (70%) rename {kafka => aiokafka}/coordinator/assignors/sticky/sorted_set.py (94%) rename {kafka => aiokafka}/coordinator/assignors/sticky/sticky_assignor.py (69%) rename {kafka => aiokafka}/coordinator/base.py (70%) rename {kafka => aiokafka}/coordinator/consumer.py (62%) rename {kafka => aiokafka}/coordinator/heartbeat.py (59%) create mode 100644 aiokafka/coordinator/protocol.py delete mode 100644 kafka/coordinator/protocol.py create mode 100644 tests/coordinator/__init__.py rename tests/{kafka => coordinator}/test_assignors.py (51%) create mode 100644 tests/coordinator/test_partition_movements.py delete mode 100644 tests/kafka/test_partition_movements.py diff --git a/aiokafka/consumer/consumer.py b/aiokafka/consumer/consumer.py index 27421690..c82f6fb0 100644 --- a/aiokafka/consumer/consumer.py +++ b/aiokafka/consumer/consumer.py @@ -6,10 +6,9 @@ import warnings from typing import Dict, List -from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor - from aiokafka.abc import ConsumerRebalanceListener from aiokafka.client import AIOKafkaClient +from aiokafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor from aiokafka.errors import ( TopicAuthorizationFailedError, OffsetOutOfRangeError, ConsumerStoppedError, IllegalOperation, UnsupportedVersionError, diff --git 
a/aiokafka/consumer/group_coordinator.py b/aiokafka/consumer/group_coordinator.py index 8d244288..8a8c76f4 100644 --- a/aiokafka/consumer/group_coordinator.py +++ b/aiokafka/consumer/group_coordinator.py @@ -4,8 +4,6 @@ import copy import time -from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor -from kafka.coordinator.protocol import ConsumerProtocol from kafka.protocol.commit import ( OffsetCommitRequest_v2 as OffsetCommitRequest, OffsetFetchRequest_v1 as OffsetFetchRequest) @@ -15,6 +13,8 @@ import aiokafka.errors as Errors from aiokafka.structs import OffsetAndMetadata, TopicPartition from aiokafka.client import ConnectionGroup, CoordinationType +from aiokafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor +from aiokafka.coordinator.protocol import ConsumerProtocol from aiokafka.util import create_future, create_task log = logging.getLogger(__name__) diff --git a/kafka/coordinator/__init__.py b/aiokafka/coordinator/__init__.py similarity index 100% rename from kafka/coordinator/__init__.py rename to aiokafka/coordinator/__init__.py diff --git a/kafka/coordinator/assignors/__init__.py b/aiokafka/coordinator/assignors/__init__.py similarity index 100% rename from kafka/coordinator/assignors/__init__.py rename to aiokafka/coordinator/assignors/__init__.py diff --git a/kafka/coordinator/assignors/abstract.py b/aiokafka/coordinator/assignors/abstract.py similarity index 86% rename from kafka/coordinator/assignors/abstract.py rename to aiokafka/coordinator/assignors/abstract.py index a1fef384..dc22342f 100644 --- a/kafka/coordinator/assignors/abstract.py +++ b/aiokafka/coordinator/assignors/abstract.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import abc import logging @@ -7,9 +5,8 @@ class AbstractPartitionAssignor(object): - """ - Abstract assignor implementation which does some common grunt work (in particular collecting - partition counts which are always needed in assignors). + """Abstract assignor implementation which does some common grunt work (in particular + collecting partition counts which are always needed in assignors). 
""" @abc.abstractproperty diff --git a/kafka/coordinator/assignors/range.py b/aiokafka/coordinator/assignors/range.py similarity index 79% rename from kafka/coordinator/assignors/range.py rename to aiokafka/coordinator/assignors/range.py index 299e39c4..38886f71 100644 --- a/kafka/coordinator/assignors/range.py +++ b/aiokafka/coordinator/assignors/range.py @@ -1,12 +1,11 @@ -from __future__ import absolute_import - import collections import logging -from kafka.vendor import six - -from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor -from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment +from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor +from aiokafka.coordinator.protocol import ( + ConsumerProtocolMemberMetadata, + ConsumerProtocolMemberAssignment, +) log = logging.getLogger(__name__) @@ -28,23 +27,24 @@ class RangePartitionAssignor(AbstractPartitionAssignor): C0: [t0p0, t0p1, t1p0, t1p1] C1: [t0p2, t1p2] """ - name = 'range' + + name = "range" version = 0 @classmethod def assign(cls, cluster, member_metadata): consumers_per_topic = collections.defaultdict(list) - for member, metadata in six.iteritems(member_metadata): + for member, metadata in member_metadata.items(): for topic in metadata.subscription: consumers_per_topic[topic].append(member) # construct {member_id: {topic: [partition, ...]}} assignment = collections.defaultdict(dict) - for topic, consumers_for_topic in six.iteritems(consumers_per_topic): + for topic, consumers_for_topic in consumers_per_topic.items(): partitions = cluster.partitions_for_topic(topic) if partitions is None: - log.warning('No partition metadata for topic %s', topic) + log.warning("No partition metadata for topic %s", topic) continue partitions = sorted(partitions) consumers_for_topic.sort() @@ -58,19 +58,18 @@ def assign(cls, cluster, member_metadata): length = partitions_per_consumer if not i + 1 > consumers_with_extra: length += 1 - assignment[member][topic] = partitions[start:start+length] + assignment[member][topic] = partitions[start:start + length] protocol_assignment = {} for member_id in member_metadata: protocol_assignment[member_id] = ConsumerProtocolMemberAssignment( - cls.version, - sorted(assignment[member_id].items()), - b'') + cls.version, sorted(assignment[member_id].items()), b"" + ) return protocol_assignment @classmethod def metadata(cls, topics): - return ConsumerProtocolMemberMetadata(cls.version, list(topics), b'') + return ConsumerProtocolMemberMetadata(cls.version, list(topics), b"") @classmethod def on_assignment(cls, assignment): diff --git a/kafka/coordinator/assignors/roundrobin.py b/aiokafka/coordinator/assignors/roundrobin.py similarity index 86% rename from kafka/coordinator/assignors/roundrobin.py rename to aiokafka/coordinator/assignors/roundrobin.py index 2d24a5c8..f3dd47f2 100644 --- a/kafka/coordinator/assignors/roundrobin.py +++ b/aiokafka/coordinator/assignors/roundrobin.py @@ -1,15 +1,15 @@ -from __future__ import absolute_import - import collections import itertools import logging -from kafka.vendor import six - -from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor -from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment from kafka.structs import TopicPartition +from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor +from aiokafka.coordinator.protocol import ( + ConsumerProtocolMemberMetadata, + 
ConsumerProtocolMemberAssignment, +) + log = logging.getLogger(__name__) @@ -45,20 +45,21 @@ class RoundRobinPartitionAssignor(AbstractPartitionAssignor): C1: [t1p0] C2: [t1p1, t2p0, t2p1, t2p2] """ - name = 'roundrobin' + + name = "roundrobin" version = 0 @classmethod def assign(cls, cluster, member_metadata): all_topics = set() - for metadata in six.itervalues(member_metadata): + for metadata in member_metadata.values(): all_topics.update(metadata.subscription) all_topic_partitions = [] for topic in all_topics: partitions = cluster.partitions_for_topic(topic) if partitions is None: - log.warning('No partition metadata for topic %s', topic) + log.warning("No partition metadata for topic %s", topic) continue for partition in partitions: all_topic_partitions.append(TopicPartition(topic, partition)) @@ -82,14 +83,13 @@ def assign(cls, cluster, member_metadata): protocol_assignment = {} for member_id in member_metadata: protocol_assignment[member_id] = ConsumerProtocolMemberAssignment( - cls.version, - sorted(assignment[member_id].items()), - b'') + cls.version, sorted(assignment[member_id].items()), b"" + ) return protocol_assignment @classmethod def metadata(cls, topics): - return ConsumerProtocolMemberMetadata(cls.version, list(topics), b'') + return ConsumerProtocolMemberMetadata(cls.version, list(topics), b"") @classmethod def on_assignment(cls, assignment): diff --git a/kafka/coordinator/assignors/sticky/__init__.py b/aiokafka/coordinator/assignors/sticky/__init__.py similarity index 100% rename from kafka/coordinator/assignors/sticky/__init__.py rename to aiokafka/coordinator/assignors/sticky/__init__.py diff --git a/kafka/coordinator/assignors/sticky/partition_movements.py b/aiokafka/coordinator/assignors/sticky/partition_movements.py similarity index 70% rename from kafka/coordinator/assignors/sticky/partition_movements.py rename to aiokafka/coordinator/assignors/sticky/partition_movements.py index 8851e4cd..47e7c71b 100644 --- a/kafka/coordinator/assignors/sticky/partition_movements.py +++ b/aiokafka/coordinator/assignors/sticky/partition_movements.py @@ -2,18 +2,17 @@ from collections import defaultdict, namedtuple from copy import deepcopy -from kafka.vendor import six - log = logging.getLogger(__name__) ConsumerPair = namedtuple("ConsumerPair", ["src_member_id", "dst_member_id"]) """ Represents a pair of Kafka consumer ids involved in a partition reassignment. -Each ConsumerPair corresponds to a particular partition or topic, indicates that the particular partition or some -partition of the particular topic was moved from the source consumer to the destination consumer -during the rebalance. This class helps in determining whether a partition reassignment results in cycles among -the generated graph of consumer pairs. +Each ConsumerPair corresponds to a particular partition or topic, indicates that the +particular partition or some partition of the particular topic was moved from the source +consumer to the destination consumer during the rebalance. This class helps in +determining whether a partition reassignment results in cycles among the generated graph +of consumer pairs. 
""" @@ -28,22 +27,21 @@ def is_sublist(source, target): true if target is in source; false otherwise """ for index in (i for i, e in enumerate(source) if e == target[0]): - if tuple(source[index: index + len(target)]) == target: + if tuple(source[index:index + len(target)]) == target: return True return False class PartitionMovements: """ - This class maintains some data structures to simplify lookup of partition movements among consumers. - At each point of time during a partition rebalance it keeps track of partition movements - corresponding to each topic, and also possible movement (in form a ConsumerPair object) for each partition. + This class maintains some data structures to simplify lookup of partition movements + among consumers. At each point of time during a partition rebalance it keeps track + of partition movements corresponding to each topic, and also possible movement (in + form a ConsumerPair object) for each partition. """ def __init__(self): - self.partition_movements_by_topic = defaultdict( - lambda: defaultdict(set) - ) + self.partition_movements_by_topic = defaultdict(lambda: defaultdict(set)) self.partition_movements = {} def move_partition(self, partition, old_consumer, new_consumer): @@ -55,7 +53,11 @@ def move_partition(self, partition, old_consumer, new_consumer): if existing_pair.src_member_id != new_consumer: # the partition is not moving back to its previous consumer self._add_partition_movement_record( - partition, ConsumerPair(src_member_id=existing_pair.src_member_id, dst_member_id=new_consumer) + partition, + ConsumerPair( + src_member_id=existing_pair.src_member_id, + dst_member_id=new_consumer, + ), ) else: self._add_partition_movement_record(partition, pair) @@ -67,19 +69,24 @@ def get_partition_to_be_moved(self, partition, old_consumer, new_consumer): # this partition has previously moved assert old_consumer == self.partition_movements[partition].dst_member_id old_consumer = self.partition_movements[partition].src_member_id - reverse_pair = ConsumerPair(src_member_id=new_consumer, dst_member_id=old_consumer) + reverse_pair = ConsumerPair( + src_member_id=new_consumer, dst_member_id=old_consumer + ) if reverse_pair not in self.partition_movements_by_topic[partition.topic]: return partition - return next(iter(self.partition_movements_by_topic[partition.topic][reverse_pair])) + return next( + iter(self.partition_movements_by_topic[partition.topic][reverse_pair]) + ) def are_sticky(self): - for topic, movements in six.iteritems(self.partition_movements_by_topic): + for topic, movements in self.partition_movements_by_topic.items(): movement_pairs = set(movements.keys()) if self._has_cycles(movement_pairs): log.error( "Stickiness is violated for topic {}\n" - "Partition movements for this topic occurred among the following consumer pairs:\n" + "Partition movements for this topic occurred among the following " + "consumer pairs:\n" "{}".format(topic, movement_pairs) ) return False @@ -107,15 +114,19 @@ def _has_cycles(self, consumer_pairs): reduced_pairs = deepcopy(consumer_pairs) reduced_pairs.remove(pair) path = [pair.src_member_id] - if self._is_linked(pair.dst_member_id, pair.src_member_id, reduced_pairs, path) and not self._is_subcycle( - path, cycles - ): + if self._is_linked( + pair.dst_member_id, pair.src_member_id, reduced_pairs, path + ) and not self._is_subcycle(path, cycles): cycles.add(tuple(path)) - log.error("A cycle of length {} was found: {}".format(len(path) - 1, path)) + log.error( + "A cycle of length {} was found: {}".format(len(path) - 
1, path) + ) - # for now we want to make sure there is no partition movements of the same topic between a pair of consumers. - # the odds of finding a cycle among more than two consumers seem to be very low (according to various randomized - # tests with the given sticky algorithm) that it should not worth the added complexity of handling those cases. + # for now we want to make sure there is no partition movements of the same topic + # between a pair of consumers. the odds of finding a cycle among more than two + # consumers seem to be very low (according to various randomized tests with the + # given sticky algorithm) that it should not worth the added complexity of + # handling those cases. for cycle in cycles: if len(cycle) == 3: # indicates a cycle of length 2 return True @@ -145,5 +156,7 @@ def _is_linked(self, src, dst, pairs, current_path): reduced_set = deepcopy(pairs) reduced_set.remove(pair) current_path.append(pair.src_member_id) - return self._is_linked(pair.dst_member_id, dst, reduced_set, current_path) + return self._is_linked( + pair.dst_member_id, dst, reduced_set, current_path + ) return False diff --git a/kafka/coordinator/assignors/sticky/sorted_set.py b/aiokafka/coordinator/assignors/sticky/sorted_set.py similarity index 94% rename from kafka/coordinator/assignors/sticky/sorted_set.py rename to aiokafka/coordinator/assignors/sticky/sorted_set.py index 6a454a42..7903f6ca 100644 --- a/kafka/coordinator/assignors/sticky/sorted_set.py +++ b/aiokafka/coordinator/assignors/sticky/sorted_set.py @@ -35,9 +35,13 @@ def pop_last(self): return value def add(self, value): - if self._cached_last is not None and self._key(value) > self._key(self._cached_last): + if self._cached_last is not None and self._key(value) > self._key( + self._cached_last + ): self._cached_last = value - if self._cached_first is not None and self._key(value) < self._key(self._cached_first): + if self._cached_first is not None and self._key(value) < self._key( + self._cached_first + ): self._cached_first = value return self._set.add(value) diff --git a/kafka/coordinator/assignors/sticky/sticky_assignor.py b/aiokafka/coordinator/assignors/sticky/sticky_assignor.py similarity index 69% rename from kafka/coordinator/assignors/sticky/sticky_assignor.py rename to aiokafka/coordinator/assignors/sticky/sticky_assignor.py index dce714f1..452b2fd6 100644 --- a/kafka/coordinator/assignors/sticky/sticky_assignor.py +++ b/aiokafka/coordinator/assignors/sticky/sticky_assignor.py @@ -2,20 +2,24 @@ from collections import defaultdict, namedtuple from copy import deepcopy -from kafka.cluster import ClusterMetadata -from kafka.coordinator.assignors.abstract import AbstractPartitionAssignor -from kafka.coordinator.assignors.sticky.partition_movements import PartitionMovements -from kafka.coordinator.assignors.sticky.sorted_set import SortedSet -from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment -from kafka.coordinator.protocol import Schema from kafka.protocol.struct import Struct from kafka.protocol.types import String, Array, Int32 from kafka.structs import TopicPartition -from kafka.vendor import six + +from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor +from aiokafka.coordinator.assignors.sticky.partition_movements import PartitionMovements +from aiokafka.coordinator.assignors.sticky.sorted_set import SortedSet +from aiokafka.coordinator.protocol import ( + ConsumerProtocolMemberMetadata, + ConsumerProtocolMemberAssignment, +) +from 
aiokafka.coordinator.protocol import Schema log = logging.getLogger(__name__) -ConsumerGenerationPair = namedtuple("ConsumerGenerationPair", ["consumer", "generation"]) +ConsumerGenerationPair = namedtuple( + "ConsumerGenerationPair", ["consumer", "generation"] +) def has_identical_list_elements(list_): @@ -50,8 +54,9 @@ def remove_if_present(collection, element): pass -StickyAssignorMemberMetadataV1 = namedtuple("StickyAssignorMemberMetadataV1", - ["subscription", "partitions", "generation"]) +StickyAssignorMemberMetadataV1 = namedtuple( + "StickyAssignorMemberMetadataV1", ["subscription", "partitions", "generation"] +) class StickyAssignorUserDataV1(Struct): @@ -61,14 +66,19 @@ class StickyAssignorUserDataV1(Struct): """ SCHEMA = Schema( - ("previous_assignment", Array(("topic", String("utf-8")), ("partitions", Array(Int32)))), ("generation", Int32) + ( + "previous_assignment", + Array(("topic", String("utf-8")), ("partitions", Array(Int32))), + ), + ("generation", Int32), ) class StickyAssignmentExecutor: def __init__(self, cluster, members): self.members = members - # a mapping between consumers and their assigned partitions that is updated during assignment procedure + # a mapping between consumers and their assigned partitions that is updated + # during assignment procedure self.current_assignment = defaultdict(list) # an assignment from a previous generation self.previous_assignment = {} @@ -76,18 +86,22 @@ def __init__(self, cluster, members): self.current_partition_consumer = {} # a flag indicating that there were no previous assignments performed ever self.is_fresh_assignment = False - # a mapping of all topic partitions to all consumers that can be assigned to them + # a mapping of all topic partitions to all consumers that can be assigned to + # them self.partition_to_all_potential_consumers = {} - # a mapping of all consumers to all potential topic partitions that can be assigned to them + # a mapping of all consumers to all potential topic partitions that can be + # assigned to them self.consumer_to_all_potential_partitions = {} - # an ascending sorted set of consumers based on how many topic partitions are already assigned to them + # an ascending sorted set of consumers based on how many topic partitions are + # already assigned to them self.sorted_current_subscriptions = SortedSet() - # an ascending sorted list of topic partitions based on how many consumers can potentially use them + # an ascending sorted list of topic partitions based on how many consumers can + # potentially use them self.sorted_partitions = [] # all partitions that need to be assigned self.unassigned_partitions = [] - # a flag indicating that a certain partition cannot remain assigned to its current consumer because the consumer - # is no longer subscribed to its topic + # a flag indicating that a certain partition cannot remain assigned to its + # current consumer because the consumer is no longer subscribed to its topic self.revocation_required = False self.partition_movements = PartitionMovements() @@ -99,7 +113,10 @@ def perform_initial_assignment(self): def balance(self): self._initialize_current_subscriptions() - initializing = len(self.current_assignment[self._get_consumer_with_most_subscriptions()]) == 0 + initializing = ( + len(self.current_assignment[self._get_consumer_with_most_subscriptions()]) + == 0 + ) # assign all unassigned partitions for partition in self.unassigned_partitions: @@ -108,20 +125,24 @@ def balance(self): continue self._assign_partition(partition) - # narrow down the 
reassignment scope to only those partitions that can actually be reassigned + # narrow down the reassignment scope to only those partitions that can actually + # be reassigned fixed_partitions = set() - for partition in six.iterkeys(self.partition_to_all_potential_consumers): + for partition in self.partition_to_all_potential_consumers.keys(): if not self._can_partition_participate_in_reassignment(partition): fixed_partitions.add(partition) for fixed_partition in fixed_partitions: remove_if_present(self.sorted_partitions, fixed_partition) remove_if_present(self.unassigned_partitions, fixed_partition) - # narrow down the reassignment scope to only those consumers that are subject to reassignment + # narrow down the reassignment scope to only those consumers that are subject to + # reassignment fixed_assignments = {} - for consumer in six.iterkeys(self.consumer_to_all_potential_partitions): + for consumer in self.consumer_to_all_potential_partitions.keys(): if not self._can_consumer_participate_in_reassignment(consumer): - self._remove_consumer_from_current_subscriptions_and_maintain_order(consumer) + self._remove_consumer_from_current_subscriptions_and_maintain_order( + consumer + ) fixed_assignments[consumer] = self.current_assignment[consumer] del self.current_assignment[consumer] @@ -136,19 +157,21 @@ def balance(self): self._perform_reassignments(self.unassigned_partitions) reassignment_performed = self._perform_reassignments(self.sorted_partitions) - # if we are not preserving existing assignments and we have made changes to the current assignment - # make sure we are getting a more balanced assignment; otherwise, revert to previous assignment + # if we are not preserving existing assignments and we have made changes to the + # current assignment make sure we are getting a more balanced assignment; + # otherwise, revert to previous assignment if ( not initializing and reassignment_performed - and self._get_balance_score(self.current_assignment) >= self._get_balance_score(prebalance_assignment) + and self._get_balance_score(self.current_assignment) + >= self._get_balance_score(prebalance_assignment) ): self.current_assignment = prebalance_assignment self.current_partition_consumer.clear() self.current_partition_consumer.update(prebalance_partition_consumers) # add the fixed assignments (those that could not change) back - for consumer, partitions in six.iteritems(fixed_assignments): + for consumer, partitions in fixed_assignments.items(): self.current_assignment[consumer] = partitions self._add_consumer_to_current_subscriptions_and_maintain_order(consumer) @@ -156,8 +179,8 @@ def get_final_assignment(self, member_id): assignment = defaultdict(list) for topic_partition in self.current_assignment[member_id]: assignment[topic_partition.topic].append(topic_partition.partition) - assignment = {k: sorted(v) for k, v in six.iteritems(assignment)} - return six.viewitems(assignment) + assignment = {k: sorted(v) for k, v in assignment.items()} + return assignment.items() def _initialize(self, cluster): self._init_current_assignments(self.members) @@ -170,7 +193,7 @@ def _initialize(self, cluster): for p in partitions: partition = TopicPartition(topic=topic, partition=p) self.partition_to_all_potential_consumers[partition] = [] - for consumer_id, member_metadata in six.iteritems(self.members): + for consumer_id, member_metadata in self.members.items(): self.consumer_to_all_potential_partitions[consumer_id] = [] for topic in member_metadata.subscription: if cluster.partitions_for_topic(topic) is 
None: @@ -178,38 +201,51 @@ def _initialize(self, cluster): continue for p in cluster.partitions_for_topic(topic): partition = TopicPartition(topic=topic, partition=p) - self.consumer_to_all_potential_partitions[consumer_id].append(partition) - self.partition_to_all_potential_consumers[partition].append(consumer_id) + self.consumer_to_all_potential_partitions[consumer_id].append( + partition + ) + self.partition_to_all_potential_consumers[partition].append( + consumer_id + ) if consumer_id not in self.current_assignment: self.current_assignment[consumer_id] = [] def _init_current_assignments(self, members): - # we need to process subscriptions' user data with each consumer's reported generation in mind - # higher generations overwrite lower generations in case of a conflict - # note that a conflict could exists only if user data is for different generations + # we need to process subscriptions' user data with each consumer's reported + # generation in mind higher generations overwrite lower generations in case of + # a conflict note that a conflict could exists only if user data is for + # different generations # for each partition we create a map of its consumers by generation sorted_partition_consumers_by_generation = {} - for consumer, member_metadata in six.iteritems(members): + for consumer, member_metadata in members.items(): for partitions in member_metadata.partitions: if partitions in sorted_partition_consumers_by_generation: consumers = sorted_partition_consumers_by_generation[partitions] - if member_metadata.generation and member_metadata.generation in consumers: - # same partition is assigned to two consumers during the same rebalance. - # log a warning and skip this record + if ( + member_metadata.generation + and member_metadata.generation in consumers + ): + # same partition is assigned to two consumers during the same + # rebalance. 
log a warning and skip this record log.warning( "Partition {} is assigned to multiple consumers " - "following sticky assignment generation {}.".format(partitions, member_metadata.generation) + "following sticky assignment generation {}.".format( + partitions, member_metadata.generation + ) ) else: consumers[member_metadata.generation] = consumer else: sorted_consumers = {member_metadata.generation: consumer} - sorted_partition_consumers_by_generation[partitions] = sorted_consumers - - # previous_assignment holds the prior ConsumerGenerationPair (before current) of each partition - # current and previous consumers are the last two consumers of each partition in the above sorted map - for partitions, consumers in six.iteritems(sorted_partition_consumers_by_generation): + sorted_partition_consumers_by_generation[ + partitions + ] = sorted_consumers + + # previous_assignment holds the prior ConsumerGenerationPair (before current) of + # each partition current and previous consumers are the last two consumers of + # each partition in the above sorted map + for partitions, consumers in sorted_partition_consumers_by_generation.items(): generations = sorted(consumers.keys(), reverse=True) self.current_assignment[consumers[generations[0]]].append(partitions) # now update previous assignment if any @@ -220,33 +256,42 @@ def _init_current_assignments(self, members): self.is_fresh_assignment = len(self.current_assignment) == 0 - for consumer_id, partitions in six.iteritems(self.current_assignment): + for consumer_id, partitions in self.current_assignment.items(): for partition in partitions: self.current_partition_consumer[partition] = consumer_id def _are_subscriptions_identical(self): """ Returns: - true, if both potential consumers of partitions and potential partitions that consumers can - consume are the same + true, if both potential consumers of partitions and potential partitions + that consumers can consume are the same """ - if not has_identical_list_elements(list(six.itervalues(self.partition_to_all_potential_consumers))): + if not has_identical_list_elements( + list(self.partition_to_all_potential_consumers.values()) + ): return False - return has_identical_list_elements(list(six.itervalues(self.consumer_to_all_potential_partitions))) + return has_identical_list_elements( + list(self.consumer_to_all_potential_partitions.values()) + ) def _populate_sorted_partitions(self): # set of topic partitions with their respective potential consumers - all_partitions = set((tp, tuple(consumers)) - for tp, consumers in six.iteritems(self.partition_to_all_potential_consumers)) - partitions_sorted_by_num_of_potential_consumers = sorted(all_partitions, key=partitions_comparator_key) + all_partitions = set( + (tp, tuple(consumers)) + for tp, consumers in self.partition_to_all_potential_consumers.items() + ) + partitions_sorted_by_num_of_potential_consumers = sorted( + all_partitions, key=partitions_comparator_key + ) self.sorted_partitions = [] if not self.is_fresh_assignment and self._are_subscriptions_identical(): - # if this is a reassignment and the subscriptions are identical (all consumers can consumer from all topics) - # then we just need to simply list partitions in a round robin fashion (from consumers with - # most assigned partitions to those with least) + # if this is a reassignment and the subscriptions are identical (all + # consumers can consumer from all topics) then we just need to simply list + # partitions in a round robin fashion (from consumers with most assigned + # partitions to 
those with least) assignments = deepcopy(self.current_assignment) - for consumer_id, partitions in six.iteritems(assignments): + for partitions in assignments.values(): to_remove = [] for partition in partitions: if partition not in self.partition_to_all_potential_consumers: @@ -255,11 +300,15 @@ def _populate_sorted_partitions(self): partitions.remove(partition) sorted_consumers = SortedSet( - iterable=[(consumer, tuple(partitions)) for consumer, partitions in six.iteritems(assignments)], + iterable=[ + (consumer, tuple(partitions)) + for consumer, partitions in assignments.items() + ], key=subscriptions_comparator_key, ) - # at this point, sorted_consumers contains an ascending-sorted list of consumers based on - # how many valid partitions are currently assigned to them + # at this point, sorted_consumers contains an ascending-sorted list of + # consumers based on how many valid partitions are currently assigned to + # them while sorted_consumers: # take the consumer with the most partitions consumer, _ = sorted_consumers.pop_last() @@ -267,16 +316,19 @@ def _populate_sorted_partitions(self): remaining_partitions = assignments[consumer] # from partitions that had a different consumer before, # keep only those that are assigned to this consumer now - previous_partitions = set(six.iterkeys(self.previous_assignment)).intersection(set(remaining_partitions)) + previous_partitions = set(self.previous_assignment.keys()).intersection( + set(remaining_partitions) + ) if previous_partitions: - # if there is a partition of this consumer that was assigned to another consumer before - # mark it as good options for reassignment + # if there is a partition of this consumer that was assigned to + # another consumer before mark it as good options for reassignment partition = previous_partitions.pop() remaining_partitions.remove(partition) self.sorted_partitions.append(partition) sorted_consumers.add((consumer, tuple(assignments[consumer]))) elif remaining_partitions: - # otherwise, mark any other one of the current partitions as a reassignment candidate + # otherwise, mark any other one of the current partitions as a + # reassignment candidate self.sorted_partitions.append(remaining_partitions.pop()) sorted_consumers.add((consumer, tuple(assignments[consumer]))) @@ -286,16 +338,18 @@ def _populate_sorted_partitions(self): self.sorted_partitions.append(partition) else: while partitions_sorted_by_num_of_potential_consumers: - self.sorted_partitions.append(partitions_sorted_by_num_of_potential_consumers.pop(0)[0]) + self.sorted_partitions.append( + partitions_sorted_by_num_of_potential_consumers.pop(0)[0] + ) def _populate_partitions_to_reassign(self): self.unassigned_partitions = deepcopy(self.sorted_partitions) assignments_to_remove = [] - for consumer_id, partitions in six.iteritems(self.current_assignment): + for consumer_id, partitions in self.current_assignment.items(): if consumer_id not in self.members: - # if a consumer that existed before (and had some partition assignments) is now removed, - # remove it from current_assignment + # if a consumer that existed before (and had some partition assignments) + # is now removed, remove it from current_assignment for partition in partitions: del self.current_partition_consumer[partition] assignments_to_remove.append(consumer_id) @@ -308,14 +362,16 @@ def _populate_partitions_to_reassign(self): # remove it from current_assignment of the consumer partitions_to_remove.append(partition) elif partition.topic not in self.members[consumer_id].subscription: - # 
if this partition cannot remain assigned to its current consumer because the consumer - # is no longer subscribed to its topic remove it from current_assignment of the consumer + # if this partition cannot remain assigned to its current + # consumer because the consumer is no longer subscribed to its + # topic remove it from current_assignment of the consumer partitions_to_remove.append(partition) self.revocation_required = True else: - # otherwise, remove the topic partition from those that need to be assigned only if - # its current consumer is still subscribed to its topic (because it is already assigned - # and we would want to preserve that assignment as much as possible) + # otherwise, remove the topic partition from those that need to + # be assigned only if its current consumer is still subscribed + # to its topic (because it is already assigned and we would want + # to preserve that assignment as much as possible) self.unassigned_partitions.remove(partition) for partition in partitions_to_remove: self.current_assignment[consumer_id].remove(partition) @@ -325,7 +381,10 @@ def _populate_partitions_to_reassign(self): def _initialize_current_subscriptions(self): self.sorted_current_subscriptions = SortedSet( - iterable=[(consumer, tuple(partitions)) for consumer, partitions in six.iteritems(self.current_assignment)], + iterable=[ + (consumer, tuple(partitions)) + for consumer, partitions in self.current_assignment.items() + ], key=subscriptions_comparator_key, ) @@ -336,42 +395,56 @@ def _get_consumer_with_most_subscriptions(self): return self.sorted_current_subscriptions.last()[0] def _remove_consumer_from_current_subscriptions_and_maintain_order(self, consumer): - self.sorted_current_subscriptions.remove((consumer, tuple(self.current_assignment[consumer]))) + self.sorted_current_subscriptions.remove( + (consumer, tuple(self.current_assignment[consumer])) + ) def _add_consumer_to_current_subscriptions_and_maintain_order(self, consumer): - self.sorted_current_subscriptions.add((consumer, tuple(self.current_assignment[consumer]))) + self.sorted_current_subscriptions.add( + (consumer, tuple(self.current_assignment[consumer])) + ) def _is_balanced(self): """Determines if the current assignment is a balanced one""" if ( len(self.current_assignment[self._get_consumer_with_least_subscriptions()]) - >= len(self.current_assignment[self._get_consumer_with_most_subscriptions()]) - 1 + >= len( + self.current_assignment[self._get_consumer_with_most_subscriptions()] + ) + - 1 ): - # if minimum and maximum numbers of partitions assigned to consumers differ by at most one return true + # if minimum and maximum numbers of partitions assigned to consumers differ + # by at most one return true return True # create a mapping from partitions to the consumer assigned to them all_assigned_partitions = {} - for consumer_id, consumer_partitions in six.iteritems(self.current_assignment): + for consumer_id, consumer_partitions in self.current_assignment.items(): for partition in consumer_partitions: if partition in all_assigned_partitions: - log.error("{} is assigned to more than one consumer.".format(partition)) + log.error( + "{} is assigned to more than one consumer.".format(partition) + ) all_assigned_partitions[partition] = consumer_id # for each consumer that does not have all the topic partitions it can get - # make sure none of the topic partitions it could but did not get cannot be moved to it - # (because that would break the balance) + # make sure none of the topic partitions it could but did not 
get cannot be + # moved to it (because that would break the balance) for consumer, _ in self.sorted_current_subscriptions: consumer_partition_count = len(self.current_assignment[consumer]) # skip if this consumer already has all the topic partitions it can get - if consumer_partition_count == len(self.consumer_to_all_potential_partitions[consumer]): + if consumer_partition_count == len( + self.consumer_to_all_potential_partitions[consumer] + ): continue # otherwise make sure it cannot get any more for partition in self.consumer_to_all_potential_partitions[consumer]: if partition not in self.current_assignment[consumer]: other_consumer = all_assigned_partitions[partition] - other_consumer_partition_count = len(self.current_assignment[other_consumer]) + other_consumer_partition_count = len( + self.current_assignment[other_consumer] + ) if consumer_partition_count < other_consumer_partition_count: return False return True @@ -379,7 +452,9 @@ def _is_balanced(self): def _assign_partition(self, partition): for consumer, _ in self.sorted_current_subscriptions: if partition in self.consumer_to_all_potential_partitions[consumer]: - self._remove_consumer_from_current_subscriptions_and_maintain_order(consumer) + self._remove_consumer_from_current_subscriptions_and_maintain_order( + consumer + ) self.current_assignment[consumer].append(partition) self.current_partition_consumer[partition] = consumer self._add_consumer_to_current_subscriptions_and_maintain_order(consumer) @@ -393,13 +468,19 @@ def _can_consumer_participate_in_reassignment(self, consumer): current_assignment_size = len(current_partitions) max_assignment_size = len(self.consumer_to_all_potential_partitions[consumer]) if current_assignment_size > max_assignment_size: - log.error("The consumer {} is assigned more partitions than the maximum possible.".format(consumer)) + log.error( + "The consumer {} is assigned more partitions than the maximum " + "possible.".format( + consumer + ) + ) if current_assignment_size < max_assignment_size: - # if a consumer is not assigned all its potential partitions it is subject to reassignment + # if a consumer is not assigned all its potential partitions it is subject + # to reassignment return True for partition in current_partitions: - # if any of the partitions assigned to a consumer is subject to reassignment the consumer itself - # is subject to reassignment + # if any of the partitions assigned to a consumer is subject to reassignment + # the consumer itself is subject to reassignment if self._can_partition_participate_in_reassignment(partition): return True return False @@ -410,34 +491,54 @@ def _perform_reassignments(self, reassignable_partitions): # repeat reassignment until no partition can be moved to improve the balance while True: modified = False - # reassign all reassignable partitions until the full list is processed or a balance is achieved - # (starting from the partition with least potential consumers and if needed) + # reassign all reassignable partitions until the full list is processed or + # a balance is achieved (starting from the partition with least potential + # consumers and if needed) for partition in reassignable_partitions: if self._is_balanced(): break # the partition must have at least two potential consumers if len(self.partition_to_all_potential_consumers[partition]) <= 1: - log.error("Expected more than one potential consumer for partition {}".format(partition)) + log.error( + "Expected more than one potential consumer for partition " + "{}".format(partition) + ) # the 
partition must have a current consumer consumer = self.current_partition_consumer.get(partition) if consumer is None: - log.error("Expected partition {} to be assigned to a consumer".format(partition)) + log.error( + "Expected partition {} to be assigned to a consumer".format( + partition + ) + ) if ( partition in self.previous_assignment and len(self.current_assignment[consumer]) - > len(self.current_assignment[self.previous_assignment[partition].consumer]) + 1 + > len( + self.current_assignment[ + self.previous_assignment[partition].consumer + ] + ) + + 1 ): self._reassign_partition_to_consumer( - partition, self.previous_assignment[partition].consumer, + partition, + self.previous_assignment[partition].consumer, ) reassignment_performed = True modified = True continue - # check if a better-suited consumer exist for the partition; if so, reassign it - for other_consumer in self.partition_to_all_potential_consumers[partition]: - if len(self.current_assignment[consumer]) > len(self.current_assignment[other_consumer]) + 1: + # check if a better-suited consumer exist for the partition; if so, + # reassign it + for other_consumer in self.partition_to_all_potential_consumers[ + partition + ]: + if ( + len(self.current_assignment[consumer]) + > len(self.current_assignment[other_consumer]) + 1 + ): self._reassign_partition(partition) reassignment_performed = True modified = True @@ -459,13 +560,19 @@ def _reassign_partition(self, partition): def _reassign_partition_to_consumer(self, partition, new_consumer): consumer = self.current_partition_consumer[partition] # find the correct partition movement considering the stickiness requirement - partition_to_be_moved = self.partition_movements.get_partition_to_be_moved(partition, consumer, new_consumer) + partition_to_be_moved = self.partition_movements.get_partition_to_be_moved( + partition, consumer, new_consumer + ) self._move_partition(partition_to_be_moved, new_consumer) def _move_partition(self, partition, new_consumer): old_consumer = self.current_partition_consumer[partition] - self._remove_consumer_from_current_subscriptions_and_maintain_order(old_consumer) - self._remove_consumer_from_current_subscriptions_and_maintain_order(new_consumer) + self._remove_consumer_from_current_subscriptions_and_maintain_order( + old_consumer + ) + self._remove_consumer_from_current_subscriptions_and_maintain_order( + new_consumer + ) self.partition_movements.move_partition(partition, old_consumer, new_consumer) @@ -480,8 +587,9 @@ def _move_partition(self, partition, new_consumer): def _get_balance_score(assignment): """Calculates a balance score of a give assignment as the sum of assigned partitions size difference of all consumer pairs. - A perfectly balanced assignment (with all consumers getting the same number of partitions) - has a balance score of 0. Lower balance score indicates a more balanced assignment. + A perfectly balanced assignment (with all consumers getting the same number of + partitions) has a balance score of 0. Lower balance score indicates a more + balanced assignment. 
Arguments: assignment (dict): {consumer: list of assigned topic partitions} @@ -491,7 +599,7 @@ def _get_balance_score(assignment): """ score = 0 consumer_to_assignment = {} - for consumer_id, partitions in six.iteritems(assignment): + for consumer_id, partitions in assignment.items(): consumer_to_assignment[consumer_id] = len(partitions) consumers_to_explore = set(consumer_to_assignment.keys()) @@ -499,50 +607,59 @@ def _get_balance_score(assignment): if consumer_id in consumers_to_explore: consumers_to_explore.remove(consumer_id) for other_consumer_id in consumers_to_explore: - score += abs(consumer_to_assignment[consumer_id] - consumer_to_assignment[other_consumer_id]) + score += abs( + consumer_to_assignment[consumer_id] + - consumer_to_assignment[other_consumer_id] + ) return score class StickyPartitionAssignor(AbstractPartitionAssignor): """ https://cwiki.apache.org/confluence/display/KAFKA/KIP-54+-+Sticky+Partition+Assignment+Strategy - - The sticky assignor serves two purposes. First, it guarantees an assignment that is as balanced as possible, meaning either: + + The sticky assignor serves two purposes. First, it guarantees an assignment that is + as balanced as possible, meaning either: - the numbers of topic partitions assigned to consumers differ by at most one; or - - each consumer that has 2+ fewer topic partitions than some other consumer cannot get any of those topic partitions transferred to it. - - Second, it preserved as many existing assignment as possible when a reassignment occurs. - This helps in saving some of the overhead processing when topic partitions move from one consumer to another. - - Starting fresh it would work by distributing the partitions over consumers as evenly as possible. - Even though this may sound similar to how round robin assignor works, the second example below shows that it is not. - During a reassignment it would perform the reassignment in such a way that in the new assignment + - each consumer that has 2+ fewer topic partitions than some other consumer cannot + get any of those topic partitions transferred to it. + + Second, it preserved as many existing assignment as possible when a reassignment + occurs. This helps in saving some of the overhead processing when topic partitions + move from one consumer to another. + + Starting fresh it would work by distributing the partitions over consumers as evenly + as possible. Even though this may sound similar to how round robin assignor works, + the second example below shows that it is not. During a reassignment it would + perform the reassignment in such a way that in the new assignment - topic partitions are still distributed as evenly as possible, and - - topic partitions stay with their previously assigned consumers as much as possible. - + - topic partitions stay with their previously assigned consumers as much as + possible. + The first goal above takes precedence over the second one. - + Example 1. Suppose there are three consumers C0, C1, C2, four topics t0, t1, t2, t3, and each topic has 2 partitions, resulting in partitions t0p0, t0p1, t1p0, t1p1, t2p0, t2p1, t3p0, t3p1. Each consumer is subscribed to all three topics. - + The assignment with both sticky and round robin assignors will be: - C0: [t0p0, t1p1, t3p0] - C1: [t0p1, t2p0, t3p1] - C2: [t1p0, t2p1] - - Now, let's assume C1 is removed and a reassignment is about to happen. The round robin assignor would produce: + + Now, let's assume C1 is removed and a reassignment is about to happen. 
The round + robin assignor would produce: - C0: [t0p0, t1p0, t2p0, t3p0] - C2: [t0p1, t1p1, t2p1, t3p1] - + while the sticky assignor would result in: - C0 [t0p0, t1p1, t3p0, t2p0] - C2 [t1p0, t2p1, t0p1, t3p1] preserving all the previous assignments (unlike the round robin assignor). - - + + Example 2. There are three consumers C0, C1, C2, and three topics t0, t1, t2, with 1, 2, and 3 partitions respectively. @@ -550,22 +667,22 @@ class StickyPartitionAssignor(AbstractPartitionAssignor): C0 is subscribed to t0; C1 is subscribed to t0, t1; and C2 is subscribed to t0, t1, t2. - + The round robin assignor would come up with the following assignment: - C0 [t0p0] - C1 [t1p0] - C2 [t1p1, t2p0, t2p1, t2p2] - + which is not as balanced as the assignment suggested by sticky assignor: - C0 [t0p0] - C1 [t1p0, t1p1] - C2 [t2p0, t2p1, t2p2] - - Now, if consumer C0 is removed, these two assignors would produce the following assignments. - Round Robin (preserves 3 partition assignments): + + Now, if consumer C0 is removed, these two assignors would produce the following + assignments. Round Robin (preserves 3 partition assignments): - C1 [t0p0, t1p1] - C2 [t1p0, t2p0, t2p1, t2p2] - + Sticky (preserves 5 partition assignments): - C1 [t1p0, t1p1, t0p0] - C2 [t2p0, t2p1, t2p2] @@ -587,13 +704,14 @@ def assign(cls, cluster, members): Arguments: cluster (ClusterMetadata): cluster metadata - members (dict of {member_id: MemberMetadata}): decoded metadata for each member in the group. + members (dict of {member_id: MemberMetadata}): decoded metadata for each + member in the group. Returns: dict: {member_id: MemberAssignment} """ members_metadata = {} - for consumer, member_metadata in six.iteritems(members): + for consumer, member_metadata in members.items(): members_metadata[consumer] = cls.parse_member_metadata(member_metadata) executor = StickyAssignmentExecutor(cluster, members_metadata) @@ -605,7 +723,7 @@ def assign(cls, cluster, members): assignment = {} for member_id in members: assignment[member_id] = ConsumerProtocolMemberAssignment( - cls.version, sorted(executor.get_final_assignment(member_id)), b'' + cls.version, sorted(executor.get_final_assignment(member_id)), b"" ) return assignment @@ -613,9 +731,10 @@ def assign(cls, cluster, members): def parse_member_metadata(cls, metadata): """ Parses member metadata into a python object. - This implementation only serializes and deserializes the StickyAssignorMemberMetadataV1 user data, - since no StickyAssignor written in Python was deployed ever in the wild with version V0, meaning that - there is no need to support backward compatibility with V0. + This implementation only serializes and deserializes the + StickyAssignorMemberMetadataV1 user data, since no StickyAssignor written in + Python was deployed ever in the wild with version V0, meaning that there is no + need to support backward compatibility with V0. Arguments: metadata (MemberMetadata): decoded metadata for a member of the group. 
@@ -626,24 +745,37 @@ def parse_member_metadata(cls, metadata): user_data = metadata.user_data if not user_data: return StickyAssignorMemberMetadataV1( - partitions=[], generation=cls.DEFAULT_GENERATION_ID, subscription=metadata.subscription + partitions=[], + generation=cls.DEFAULT_GENERATION_ID, + subscription=metadata.subscription, ) try: decoded_user_data = StickyAssignorUserDataV1.decode(user_data) except Exception as e: # ignore the consumer's previous assignment if it cannot be parsed - log.error("Could not parse member data", e) # pylint: disable=logging-too-many-args + log.error( + "Could not parse member data", e + ) # pylint: disable=logging-too-many-args return StickyAssignorMemberMetadataV1( - partitions=[], generation=cls.DEFAULT_GENERATION_ID, subscription=metadata.subscription + partitions=[], + generation=cls.DEFAULT_GENERATION_ID, + subscription=metadata.subscription, ) member_partitions = [] - for topic, partitions in decoded_user_data.previous_assignment: # pylint: disable=no-member - member_partitions.extend([TopicPartition(topic, partition) for partition in partitions]) + for ( + topic, + partitions, + ) in decoded_user_data.previous_assignment: # pylint: disable=no-member + member_partitions.extend( + [TopicPartition(topic, partition) for partition in partitions] + ) return StickyAssignorMemberMetadataV1( # pylint: disable=no-member - partitions=member_partitions, generation=decoded_user_data.generation, subscription=metadata.subscription + partitions=member_partitions, + generation=decoded_user_data.generation, + subscription=metadata.subscription, ) @classmethod @@ -654,13 +786,18 @@ def metadata(cls, topics): def _metadata(cls, topics, member_assignment_partitions, generation=-1): if member_assignment_partitions is None: log.debug("No member assignment available") - user_data = b'' + user_data = b"" else: - log.debug("Member assignment is available, generating the metadata: generation {}".format(cls.generation)) + log.debug( + "Member assignment is available, generating the metadata: " + "generation {}".format(cls.generation) + ) partitions_by_topic = defaultdict(list) for topic_partition in member_assignment_partitions: - partitions_by_topic[topic_partition.topic].append(topic_partition.partition) - data = StickyAssignorUserDataV1(six.viewitems(partitions_by_topic), generation) + partitions_by_topic[topic_partition.topic].append( + topic_partition.partition + ) + data = StickyAssignorUserDataV1(partitions_by_topic.items(), generation) user_data = data.encode() return ConsumerProtocolMemberMetadata(cls.version, list(topics), user_data) diff --git a/kafka/coordinator/base.py b/aiokafka/coordinator/base.py similarity index 70% rename from kafka/coordinator/base.py rename to aiokafka/coordinator/base.py index e7198410..89401c06 100644 --- a/kafka/coordinator/base.py +++ b/aiokafka/coordinator/base.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import, division - import abc import copy import logging @@ -7,24 +5,28 @@ import time import weakref -from kafka.vendor import six - -from kafka.coordinator.heartbeat import Heartbeat -from kafka import errors as Errors from kafka.future import Future from kafka.metrics import AnonMeasurable from kafka.metrics.stats import Avg, Count, Max, Rate from kafka.protocol.commit import GroupCoordinatorRequest, OffsetCommitRequest -from kafka.protocol.group import (HeartbeatRequest, JoinGroupRequest, - LeaveGroupRequest, SyncGroupRequest) +from kafka.protocol.group import ( + HeartbeatRequest, + JoinGroupRequest, + 
LeaveGroupRequest, + SyncGroupRequest, +) + +from aiokafka import errors as Errors -log = logging.getLogger('kafka.coordinator') +from .heartbeat import Heartbeat + +log = logging.getLogger("aiokafka.coordinator") class MemberState(object): - UNJOINED = '' # the client is not part of a group - REBALANCING = '' # the client has begun rebalancing - STABLE = '' # the client has joined and is sending heartbeats + UNJOINED = "" # the client is not part of a group + REBALANCING = "" # the client has begun rebalancing + STABLE = "" # the client has joined and is sending heartbeats class Generation(object): @@ -33,10 +35,12 @@ def __init__(self, generation_id, member_id, protocol): self.member_id = member_id self.protocol = protocol + Generation.NO_GENERATION = Generation( OffsetCommitRequest[2].DEFAULT_GENERATION_ID, JoinGroupRequest[0].UNKNOWN_MEMBER_ID, - None) + None, +) class UnjoinedGroupException(Errors.KafkaError): @@ -81,13 +85,13 @@ class BaseCoordinator(object): """ DEFAULT_CONFIG = { - 'group_id': 'kafka-python-default-group', - 'session_timeout_ms': 10000, - 'heartbeat_interval_ms': 3000, - 'max_poll_interval_ms': 300000, - 'retry_backoff_ms': 100, - 'api_version': (0, 10, 1), - 'metric_group_prefix': '', + "group_id": "kafka-python-default-group", + "session_timeout_ms": 10000, + "heartbeat_interval_ms": 3000, + "max_poll_interval_ms": 300000, + "retry_backoff_ms": 100, + "api_version": (0, 10, 1), + "metric_group_prefix": "", } def __init__(self, client, metrics, **configs): @@ -115,14 +119,16 @@ def __init__(self, client, metrics, **configs): if key in configs: self.config[key] = configs[key] - if self.config['api_version'] < (0, 10, 1): - if self.config['max_poll_interval_ms'] != self.config['session_timeout_ms']: - raise Errors.KafkaConfigurationError("Broker version %s does not support " - "different values for max_poll_interval_ms " - "and session_timeout_ms") + if self.config["api_version"] < (0, 10, 1): + if self.config["max_poll_interval_ms"] != self.config["session_timeout_ms"]: + raise Errors.KafkaConfigurationError( + "Broker version %s does not support " + "different values for max_poll_interval_ms " + "and session_timeout_ms" + ) self._client = client - self.group_id = self.config['group_id'] + self.group_id = self.config["group_id"] self.heartbeat = Heartbeat(**self.config) self._heartbeat_thread = None self._lock = threading.Condition() @@ -133,8 +139,9 @@ def __init__(self, client, metrics, **configs): self.coordinator_id = None self._find_coordinator_future = None self._generation = Generation.NO_GENERATION - self.sensors = GroupCoordinatorMetrics(self.heartbeat, metrics, - self.config['metric_group_prefix']) + self.sensors = GroupCoordinatorMetrics( + self.heartbeat, metrics, self.config["metric_group_prefix"] + ) @abc.abstractmethod def protocol_type(self): @@ -201,8 +208,9 @@ def _perform_assignment(self, leader_id, protocol, members): pass @abc.abstractmethod - def _on_join_complete(self, generation, member_id, protocol, - member_assignment_bytes): + def _on_join_complete( + self, generation, member_id, protocol, member_assignment_bytes + ): """Invoked when a group member has successfully joined a group. 
Arguments: @@ -233,7 +241,7 @@ def coordinator(self): if self.coordinator_id is None: return None elif self._client.is_disconnected(self.coordinator_id): - self.coordinator_dead('Node Disconnected') + self.coordinator_dead("Node Disconnected") return None else: return self.coordinator_id @@ -248,7 +256,7 @@ def ensure_coordinator_ready(self): # Prior to 0.8.2 there was no group coordinator # so we will just pick a node at random and treat # it as the "coordinator" - if self.config['api_version'] < (0, 8, 2): + if self.config["api_version"] < (0, 8, 2): self.coordinator_id = self._client.least_loaded_node() if self.coordinator_id is not None: self._client.maybe_connect(self.coordinator_id) @@ -259,12 +267,15 @@ def ensure_coordinator_ready(self): if future.failed(): if future.retriable(): - if getattr(future.exception, 'invalid_metadata', False): - log.debug('Requesting metadata for group coordinator request: %s', future.exception) + if getattr(future.exception, "invalid_metadata", False): + log.debug( + "Requesting metadata for group coordinator request: %s", + future.exception, + ) metadata_update = self._client.cluster.request_update() self._client.poll(future=metadata_update) else: - time.sleep(self.config['retry_backoff_ms'] / 1000) + time.sleep(self.config["retry_backoff_ms"] / 1000) else: raise future.exception # pylint: disable-msg=raising-bad-type @@ -327,13 +338,16 @@ def time_to_next_heartbeat(self): with self._lock: # if we have not joined the group, we don't need to send heartbeats if self.state is MemberState.UNJOINED: - return float('inf') + return float("inf") return self.heartbeat.time_to_next_heartbeat() def _handle_join_success(self, member_assignment_bytes): with self._lock: - log.info("Successfully joined group %s with generation %s", - self.group_id, self._generation.generation_id) + log.info( + "Successfully joined group %s with generation %s", + self.group_id, + self._generation.generation_id, + ) self.state = MemberState.STABLE self.rejoin_needed = False if self._heartbeat_thread: @@ -361,8 +375,9 @@ def ensure_active_group(self): # changes the matched subscription set) can occur # while another rebalance is still in progress. if not self.rejoining: - self._on_join_prepare(self._generation.generation_id, - self._generation.member_id) + self._on_join_prepare( + self._generation.generation_id, self._generation.member_id + ) self.rejoining = True # ensure that there are no pending requests to the coordinator. 
@@ -389,7 +404,9 @@ def ensure_active_group(self): self.state = MemberState.REBALANCING future = self._send_join_group_request() - self.join_future = future # this should happen before adding callbacks + self.join_future = ( + future # this should happen before adding callbacks + ) # handle join completion in the callback so that the # callback will be invoked even if the consumer is woken up @@ -407,23 +424,30 @@ def ensure_active_group(self): self._client.poll(future=future) if future.succeeded(): - self._on_join_complete(self._generation.generation_id, - self._generation.member_id, - self._generation.protocol, - future.value) + self._on_join_complete( + self._generation.generation_id, + self._generation.member_id, + self._generation.protocol, + future.value, + ) self.join_future = None self.rejoining = False else: self.join_future = None exception = future.exception - if isinstance(exception, (Errors.UnknownMemberIdError, - Errors.RebalanceInProgressError, - Errors.IllegalGenerationError)): + if isinstance( + exception, + ( + Errors.UnknownMemberIdError, + Errors.RebalanceInProgressError, + Errors.IllegalGenerationError, + ), + ): continue elif not future.retriable(): raise exception # pylint: disable-msg=raising-bad-type - time.sleep(self.config['retry_backoff_ms'] / 1000) + time.sleep(self.config["retry_backoff_ms"] / 1000) def _rejoin_incomplete(self): return self.join_future is not None @@ -452,59 +476,75 @@ def _send_join_group_request(self): (protocol, metadata if isinstance(metadata, bytes) else metadata.encode()) for protocol, metadata in self.group_protocols() ] - if self.config['api_version'] < (0, 9): - raise Errors.KafkaError('JoinGroupRequest api requires 0.9+ brokers') - elif (0, 9) <= self.config['api_version'] < (0, 10, 1): + if self.config["api_version"] < (0, 9): + raise Errors.KafkaError("JoinGroupRequest api requires 0.9+ brokers") + elif (0, 9) <= self.config["api_version"] < (0, 10, 1): request = JoinGroupRequest[0]( self.group_id, - self.config['session_timeout_ms'], + self.config["session_timeout_ms"], self._generation.member_id, self.protocol_type(), - member_metadata) - elif (0, 10, 1) <= self.config['api_version'] < (0, 11, 0): + member_metadata, + ) + elif (0, 10, 1) <= self.config["api_version"] < (0, 11, 0): request = JoinGroupRequest[1]( self.group_id, - self.config['session_timeout_ms'], - self.config['max_poll_interval_ms'], + self.config["session_timeout_ms"], + self.config["max_poll_interval_ms"], self._generation.member_id, self.protocol_type(), - member_metadata) + member_metadata, + ) else: request = JoinGroupRequest[2]( self.group_id, - self.config['session_timeout_ms'], - self.config['max_poll_interval_ms'], + self.config["session_timeout_ms"], + self.config["max_poll_interval_ms"], self._generation.member_id, self.protocol_type(), - member_metadata) + member_metadata, + ) # create the request for the coordinator - log.debug("Sending JoinGroup (%s) to coordinator %s", request, self.coordinator_id) + log.debug( + "Sending JoinGroup (%s) to coordinator %s", request, self.coordinator_id + ) future = Future() _f = self._client.send(self.coordinator_id, request) _f.add_callback(self._handle_join_group_response, future, time.time()) - _f.add_errback(self._failed_request, self.coordinator_id, - request, future) + _f.add_errback(self._failed_request, self.coordinator_id, request, future) return future def _failed_request(self, node_id, request, future, error): # Marking coordinator dead # unless the error is caused by internal client pipelining - if not 
isinstance(error, (Errors.NodeNotReadyError, - Errors.TooManyInFlightRequests)): - log.error('Error sending %s to node %s [%s]', - request.__class__.__name__, node_id, error) + if not isinstance( + error, (Errors.NodeNotReadyError, Errors.TooManyInFlightRequests) + ): + log.error( + "Error sending %s to node %s [%s]", + request.__class__.__name__, + node_id, + error, + ) self.coordinator_dead(error) else: - log.debug('Error sending %s to node %s [%s]', - request.__class__.__name__, node_id, error) + log.debug( + "Error sending %s to node %s [%s]", + request.__class__.__name__, + node_id, + error, + ) future.failure(error) def _handle_join_group_response(self, future, send_time, response): error_type = Errors.for_code(response.error_code) if error_type is Errors.NoError: - log.debug("Received successful JoinGroup response for group %s: %s", - self.group_id, response) + log.debug( + "Received successful JoinGroup response for group %s: %s", + self.group_id, + response, + ) self.sensors.join_latency.record((time.time() - send_time) * 1000) with self._lock: if self.state is not MemberState.REBALANCING: @@ -513,44 +553,65 @@ def _handle_join_group_response(self, future, send_time, response): # not want to continue with the sync group. future.failure(UnjoinedGroupException()) else: - self._generation = Generation(response.generation_id, - response.member_id, - response.group_protocol) + self._generation = Generation( + response.generation_id, + response.member_id, + response.group_protocol, + ) if response.leader_id == response.member_id: - log.info("Elected group leader -- performing partition" - " assignments using %s", self._generation.protocol) + log.info( + "Elected group leader -- performing partition" + " assignments using %s", + self._generation.protocol, + ) self._on_join_leader(response).chain(future) else: self._on_join_follower().chain(future) elif error_type is Errors.GroupLoadInProgressError: - log.debug("Attempt to join group %s rejected since coordinator %s" - " is loading the group.", self.group_id, self.coordinator_id) + log.debug( + "Attempt to join group %s rejected since coordinator %s" + " is loading the group.", + self.group_id, + self.coordinator_id, + ) # backoff and retry future.failure(error_type(response)) elif error_type is Errors.UnknownMemberIdError: # reset the member id and retry immediately error = error_type(self._generation.member_id) self.reset_generation() - log.debug("Attempt to join group %s failed due to unknown member id", - self.group_id) + log.debug( + "Attempt to join group %s failed due to unknown member id", + self.group_id, + ) future.failure(error) - elif error_type in (Errors.GroupCoordinatorNotAvailableError, - Errors.NotCoordinatorForGroupError): + elif error_type in ( + Errors.GroupCoordinatorNotAvailableError, + Errors.NotCoordinatorForGroupError, + ): # re-discover the coordinator and retry with backoff self.coordinator_dead(error_type()) - log.debug("Attempt to join group %s failed due to obsolete " - "coordinator information: %s", self.group_id, - error_type.__name__) + log.debug( + "Attempt to join group %s failed due to obsolete " + "coordinator information: %s", + self.group_id, + error_type.__name__, + ) future.failure(error_type()) - elif error_type in (Errors.InconsistentGroupProtocolError, - Errors.InvalidSessionTimeoutError, - Errors.InvalidGroupIdError): + elif error_type in ( + Errors.InconsistentGroupProtocolError, + Errors.InvalidSessionTimeoutError, + Errors.InvalidGroupIdError, + ): # log the error and re-throw the exception 
error = error_type(response) - log.error("Attempt to join group %s failed due to fatal error: %s", - self.group_id, error) + log.error( + "Attempt to join group %s failed due to fatal error: %s", + self.group_id, + error, + ) future.failure(error) elif error_type is Errors.GroupAuthorizationFailedError: future.failure(error_type(self.group_id)) @@ -562,14 +623,19 @@ def _handle_join_group_response(self, future, send_time, response): def _on_join_follower(self): # send follower's sync group with an empty assignment - version = 0 if self.config['api_version'] < (0, 11, 0) else 1 + version = 0 if self.config["api_version"] < (0, 11, 0) else 1 request = SyncGroupRequest[version]( self.group_id, self._generation.generation_id, self._generation.member_id, - {}) - log.debug("Sending follower SyncGroup for group %s to coordinator %s: %s", - self.group_id, self.coordinator_id, request) + {}, + ) + log.debug( + "Sending follower SyncGroup for group %s to coordinator %s: %s", + self.group_id, + self.coordinator_id, + request, + ) return self._send_sync_group_request(request) def _on_join_leader(self, response): @@ -584,23 +650,34 @@ def _on_join_leader(self, response): Future: resolves to member assignment encoded-bytes """ try: - group_assignment = self._perform_assignment(response.leader_id, - response.group_protocol, - response.members) + group_assignment = self._perform_assignment( + response.leader_id, response.group_protocol, response.members + ) except Exception as e: return Future().failure(e) - version = 0 if self.config['api_version'] < (0, 11, 0) else 1 + version = 0 if self.config["api_version"] < (0, 11, 0) else 1 request = SyncGroupRequest[version]( self.group_id, self._generation.generation_id, self._generation.member_id, - [(member_id, - assignment if isinstance(assignment, bytes) else assignment.encode()) - for member_id, assignment in six.iteritems(group_assignment)]) - - log.debug("Sending leader SyncGroup for group %s to coordinator %s: %s", - self.group_id, self.coordinator_id, request) + [ + ( + member_id, + assignment + if isinstance(assignment, bytes) + else assignment.encode(), + ) + for member_id, assignment in group_assignment.items() + ], + ) + + log.debug( + "Sending leader SyncGroup for group %s to coordinator %s: %s", + self.group_id, + self.coordinator_id, + request, + ) return self._send_sync_group_request(request) def _send_sync_group_request(self, request): @@ -617,8 +694,7 @@ def _send_sync_group_request(self, request): future = Future() _f = self._client.send(self.coordinator_id, request) _f.add_callback(self._handle_sync_group_response, future, time.time()) - _f.add_errback(self._failed_request, self.coordinator_id, - request, future) + _f.add_errback(self._failed_request, self.coordinator_id, request, future) return future def _handle_sync_group_response(self, future, send_time, response): @@ -633,17 +709,20 @@ def _handle_sync_group_response(self, future, send_time, response): if error_type is Errors.GroupAuthorizationFailedError: future.failure(error_type(self.group_id)) elif error_type is Errors.RebalanceInProgressError: - log.debug("SyncGroup for group %s failed due to coordinator" - " rebalance", self.group_id) + log.debug( + "SyncGroup for group %s failed due to coordinator" " rebalance", + self.group_id, + ) future.failure(error_type(self.group_id)) - elif error_type in (Errors.UnknownMemberIdError, - Errors.IllegalGenerationError): + elif error_type in (Errors.UnknownMemberIdError, Errors.IllegalGenerationError): error = error_type() log.debug("SyncGroup 
for group %s failed due to %s", self.group_id, error) self.reset_generation() future.failure(error) - elif error_type in (Errors.GroupCoordinatorNotAvailableError, - Errors.NotCoordinatorForGroupError): + elif error_type in ( + Errors.GroupCoordinatorNotAvailableError, + Errors.NotCoordinatorForGroupError, + ): error = error_type() log.debug("SyncGroup for group %s failed due to %s", self.group_id, error) self.coordinator_dead(error) @@ -667,8 +746,11 @@ def _send_group_coordinator_request(self): e = Errors.NodeNotReadyError(node_id) return Future().failure(e) - log.debug("Sending group coordinator request for group %s to broker %s", - self.group_id, node_id) + log.debug( + "Sending group coordinator request for group %s to broker %s", + self.group_id, + node_id, + ) request = GroupCoordinatorRequest[0](self.group_id) future = Future() _f = self._client.send(node_id, request) @@ -682,7 +764,9 @@ def _handle_group_coordinator_response(self, future, response): error_type = Errors.for_code(response.error_code) if error_type is Errors.NoError: with self._lock: - coordinator_id = self._client.cluster.add_group_coordinator(self.group_id, response) + coordinator_id = self._client.cluster.add_group_coordinator( + self.group_id, response + ) if not coordinator_id: # This could happen if coordinator metadata is different # than broker metadata @@ -690,8 +774,11 @@ def _handle_group_coordinator_response(self, future, response): return self.coordinator_id = coordinator_id - log.info("Discovered coordinator %s for group %s", - self.coordinator_id, self.group_id) + log.info( + "Discovered coordinator %s for group %s", + self.coordinator_id, + self.group_id, + ) self._client.maybe_connect(self.coordinator_id) self.heartbeat.reset_timeouts() future.success(self.coordinator_id) @@ -705,15 +792,20 @@ def _handle_group_coordinator_response(self, future, response): future.failure(error) else: error = error_type() - log.error("Group coordinator lookup for group %s failed: %s", - self.group_id, error) + log.error( + "Group coordinator lookup for group %s failed: %s", self.group_id, error + ) future.failure(error) def coordinator_dead(self, error): """Mark the current coordinator as dead.""" if self.coordinator_id is not None: - log.warning("Marking the coordinator dead (node %s) for group %s: %s.", - self.coordinator_id, self.group_id, error) + log.warning( + "Marking the coordinator dead (node %s) for group %s: %s.", + self.coordinator_id, + self.group_id, + error, + ) self.coordinator_id = None def generation(self): @@ -738,14 +830,14 @@ def request_rejoin(self): def _start_heartbeat_thread(self): if self._heartbeat_thread is None: - log.info('Starting new heartbeat thread') + log.info("Starting new heartbeat thread") self._heartbeat_thread = HeartbeatThread(weakref.proxy(self)) self._heartbeat_thread.daemon = True self._heartbeat_thread.start() def _close_heartbeat_thread(self): if self._heartbeat_thread is not None: - log.info('Stopping heartbeat thread') + log.info("Stopping heartbeat thread") try: self._heartbeat_thread.close() except ReferenceError: @@ -764,15 +856,19 @@ def close(self): def maybe_leave_group(self): """Leave the current group and reset local generation/memberId.""" with self._client._lock, self._lock: - if (not self.coordinator_unknown() + if ( + not self.coordinator_unknown() and self.state is not MemberState.UNJOINED - and self._generation is not Generation.NO_GENERATION): + and self._generation is not Generation.NO_GENERATION + ): # this is a minimal effort attempt to leave the group. 
we do not # attempt any resending if the request fails or times out. - log.info('Leaving consumer group (%s).', self.group_id) - version = 0 if self.config['api_version'] < (0, 11, 0) else 1 - request = LeaveGroupRequest[version](self.group_id, self._generation.member_id) + log.info("Leaving consumer group (%s).", self.group_id) + version = 0 if self.config["api_version"] < (0, 11, 0) else 1 + request = LeaveGroupRequest[version]( + self.group_id, self._generation.member_id + ) future = self._client.send(self.coordinator_id, request) future.add_callback(self._handle_leave_group_response) future.add_errback(log.error, "LeaveGroup request failed: %s") @@ -783,11 +879,15 @@ def maybe_leave_group(self): def _handle_leave_group_response(self, response): error_type = Errors.for_code(response.error_code) if error_type is Errors.NoError: - log.debug("LeaveGroup request for group %s returned successfully", - self.group_id) + log.debug( + "LeaveGroup request for group %s returned successfully", self.group_id + ) else: - log.error("LeaveGroup request for group %s failed with error: %s", - self.group_id, error_type()) + log.error( + "LeaveGroup request for group %s failed with error: %s", + self.group_id, + error_type(), + ) def _send_heartbeat_request(self): """Send a heartbeat request""" @@ -799,45 +899,61 @@ def _send_heartbeat_request(self): e = Errors.NodeNotReadyError(self.coordinator_id) return Future().failure(e) - version = 0 if self.config['api_version'] < (0, 11, 0) else 1 - request = HeartbeatRequest[version](self.group_id, - self._generation.generation_id, - self._generation.member_id) - log.debug("Heartbeat: %s[%s] %s", request.group, request.generation_id, request.member_id) # pylint: disable-msg=no-member + version = 0 if self.config["api_version"] < (0, 11, 0) else 1 + request = HeartbeatRequest[version]( + self.group_id, self._generation.generation_id, self._generation.member_id + ) + log.debug( + "Heartbeat: %s[%s] %s", + request.group, + request.generation_id, + request.member_id, + ) # pylint: disable-msg=no-member future = Future() _f = self._client.send(self.coordinator_id, request) _f.add_callback(self._handle_heartbeat_response, future, time.time()) - _f.add_errback(self._failed_request, self.coordinator_id, - request, future) + _f.add_errback(self._failed_request, self.coordinator_id, request, future) return future def _handle_heartbeat_response(self, future, send_time, response): self.sensors.heartbeat_latency.record((time.time() - send_time) * 1000) error_type = Errors.for_code(response.error_code) if error_type is Errors.NoError: - log.debug("Received successful heartbeat response for group %s", - self.group_id) + log.debug( + "Received successful heartbeat response for group %s", self.group_id + ) future.success(None) - elif error_type in (Errors.GroupCoordinatorNotAvailableError, - Errors.NotCoordinatorForGroupError): - log.warning("Heartbeat failed for group %s: coordinator (node %s)" - " is either not started or not valid", self.group_id, - self.coordinator()) + elif error_type in ( + Errors.GroupCoordinatorNotAvailableError, + Errors.NotCoordinatorForGroupError, + ): + log.warning( + "Heartbeat failed for group %s: coordinator (node %s)" + " is either not started or not valid", + self.group_id, + self.coordinator(), + ) self.coordinator_dead(error_type()) future.failure(error_type()) elif error_type is Errors.RebalanceInProgressError: - log.warning("Heartbeat failed for group %s because it is" - " rebalancing", self.group_id) + log.warning( + "Heartbeat failed for 
group %s because it is" " rebalancing", + self.group_id, + ) self.request_rejoin() future.failure(error_type()) elif error_type is Errors.IllegalGenerationError: - log.warning("Heartbeat failed for group %s: generation id is not " - " current.", self.group_id) + log.warning( + "Heartbeat failed for group %s: generation id is not " " current.", + self.group_id, + ) self.reset_generation() future.failure(error_type()) elif error_type is Errors.UnknownMemberIdError: - log.warning("Heartbeat: local member_id was not recognized;" - " this consumer needs to re-join") + log.warning( + "Heartbeat: local member_id was not recognized;" + " this consumer needs to re-join" + ) self.reset_generation() future.failure(error_type) elif error_type is Errors.GroupAuthorizationFailedError: @@ -856,55 +972,99 @@ def __init__(self, heartbeat, metrics, prefix, tags=None): self.metrics = metrics self.metric_group_name = prefix + "-coordinator-metrics" - self.heartbeat_latency = metrics.sensor('heartbeat-latency') - self.heartbeat_latency.add(metrics.metric_name( - 'heartbeat-response-time-max', self.metric_group_name, - 'The max time taken to receive a response to a heartbeat request', - tags), Max()) - self.heartbeat_latency.add(metrics.metric_name( - 'heartbeat-rate', self.metric_group_name, - 'The average number of heartbeats per second', - tags), Rate(sampled_stat=Count())) - - self.join_latency = metrics.sensor('join-latency') - self.join_latency.add(metrics.metric_name( - 'join-time-avg', self.metric_group_name, - 'The average time taken for a group rejoin', - tags), Avg()) - self.join_latency.add(metrics.metric_name( - 'join-time-max', self.metric_group_name, - 'The max time taken for a group rejoin', - tags), Max()) - self.join_latency.add(metrics.metric_name( - 'join-rate', self.metric_group_name, - 'The number of group joins per second', - tags), Rate(sampled_stat=Count())) - - self.sync_latency = metrics.sensor('sync-latency') - self.sync_latency.add(metrics.metric_name( - 'sync-time-avg', self.metric_group_name, - 'The average time taken for a group sync', - tags), Avg()) - self.sync_latency.add(metrics.metric_name( - 'sync-time-max', self.metric_group_name, - 'The max time taken for a group sync', - tags), Max()) - self.sync_latency.add(metrics.metric_name( - 'sync-rate', self.metric_group_name, - 'The number of group syncs per second', - tags), Rate(sampled_stat=Count())) - - metrics.add_metric(metrics.metric_name( - 'last-heartbeat-seconds-ago', self.metric_group_name, - 'The number of seconds since the last controller heartbeat was sent', - tags), AnonMeasurable( - lambda _, now: (now / 1000) - self.heartbeat.last_send)) + self.heartbeat_latency = metrics.sensor("heartbeat-latency") + self.heartbeat_latency.add( + metrics.metric_name( + "heartbeat-response-time-max", + self.metric_group_name, + "The max time taken to receive a response to a heartbeat request", + tags, + ), + Max(), + ) + self.heartbeat_latency.add( + metrics.metric_name( + "heartbeat-rate", + self.metric_group_name, + "The average number of heartbeats per second", + tags, + ), + Rate(sampled_stat=Count()), + ) + + self.join_latency = metrics.sensor("join-latency") + self.join_latency.add( + metrics.metric_name( + "join-time-avg", + self.metric_group_name, + "The average time taken for a group rejoin", + tags, + ), + Avg(), + ) + self.join_latency.add( + metrics.metric_name( + "join-time-max", + self.metric_group_name, + "The max time taken for a group rejoin", + tags, + ), + Max(), + ) + self.join_latency.add( + 
metrics.metric_name( + "join-rate", + self.metric_group_name, + "The number of group joins per second", + tags, + ), + Rate(sampled_stat=Count()), + ) + + self.sync_latency = metrics.sensor("sync-latency") + self.sync_latency.add( + metrics.metric_name( + "sync-time-avg", + self.metric_group_name, + "The average time taken for a group sync", + tags, + ), + Avg(), + ) + self.sync_latency.add( + metrics.metric_name( + "sync-time-max", + self.metric_group_name, + "The max time taken for a group sync", + tags, + ), + Max(), + ) + self.sync_latency.add( + metrics.metric_name( + "sync-rate", + self.metric_group_name, + "The number of group syncs per second", + tags, + ), + Rate(sampled_stat=Count()), + ) + + metrics.add_metric( + metrics.metric_name( + "last-heartbeat-seconds-ago", + self.metric_group_name, + "The number of seconds since the last controller heartbeat was sent", + tags, + ), + AnonMeasurable(lambda _, now: (now / 1000) - self.heartbeat.last_send), + ) class HeartbeatThread(threading.Thread): def __init__(self, coordinator): super(HeartbeatThread, self).__init__() - self.name = coordinator.group_id + '-heartbeat' + self.name = coordinator.group_id + "-heartbeat" self.coordinator = coordinator self.enabled = False self.closed = False @@ -924,26 +1084,29 @@ def close(self): with self.coordinator._lock: self.coordinator._lock.notify() if self.is_alive(): - self.join(self.coordinator.config['heartbeat_interval_ms'] / 1000) + self.join(self.coordinator.config["heartbeat_interval_ms"] / 1000) if self.is_alive(): log.warning("Heartbeat thread did not fully terminate during close") def run(self): try: - log.debug('Heartbeat thread started') + log.debug("Heartbeat thread started") while not self.closed: self._run_once() except ReferenceError: - log.debug('Heartbeat thread closed due to coordinator gc') + log.debug("Heartbeat thread closed due to coordinator gc") except RuntimeError as e: - log.error("Heartbeat thread for group %s failed due to unexpected error: %s", - self.coordinator.group_id, e) + log.error( + "Heartbeat thread for group %s failed due to unexpected error: %s", + self.coordinator.group_id, + e, + ) self.failed = e finally: - log.debug('Heartbeat thread closed') + log.debug("Heartbeat thread closed") def _run_once(self): with self.coordinator._client._lock, self.coordinator._lock: @@ -951,22 +1114,22 @@ def _run_once(self): # TODO: When consumer.wakeup() is implemented, we need to # disable here to prevent propagating an exception to this # heartbeat thread - # must get client._lock, or maybe deadlock at heartbeat + # must get client._lock, or maybe deadlock at heartbeat # failure callback in consumer poll self.coordinator._client.poll(timeout_ms=0) with self.coordinator._lock: if not self.enabled: - log.debug('Heartbeat disabled. Waiting') + log.debug("Heartbeat disabled. Waiting") self.coordinator._lock.wait() - log.debug('Heartbeat re-enabled.') + log.debug("Heartbeat re-enabled.") return if self.coordinator.state is not MemberState.STABLE: # the group is not stable (perhaps because we left the # group or because the coordinator kicked us out), so # disable heartbeats and wait for the main thread to rejoin. - log.debug('Group state is not stable, disabling heartbeats') + log.debug("Group state is not stable, disabling heartbeats") self.disable() return @@ -976,27 +1139,31 @@ def _run_once(self): # the immediate future check ensures that we backoff # properly in the case that no brokers are available # to connect to (and the future is automatically failed). 
- self.coordinator._lock.wait(self.coordinator.config['retry_backoff_ms'] / 1000) + self.coordinator._lock.wait( + self.coordinator.config["retry_backoff_ms"] / 1000 + ) elif self.coordinator.heartbeat.session_timeout_expired(): # the session timeout has expired without seeing a # successful heartbeat, so we should probably make sure # the coordinator is still healthy. - log.warning('Heartbeat session expired, marking coordinator dead') - self.coordinator.coordinator_dead('Heartbeat session expired') + log.warning("Heartbeat session expired, marking coordinator dead") + self.coordinator.coordinator_dead("Heartbeat session expired") elif self.coordinator.heartbeat.poll_timeout_expired(): # the poll timeout has expired, which means that the # foreground thread has stalled in between calls to # poll(), so we explicitly leave the group. - log.warning('Heartbeat poll expired, leaving group') + log.warning("Heartbeat poll expired, leaving group") self.coordinator.maybe_leave_group() elif not self.coordinator.heartbeat.should_heartbeat(): # poll again after waiting for the retry backoff in case # the heartbeat failed or the coordinator disconnected - log.log(0, 'Not ready to heartbeat, waiting') - self.coordinator._lock.wait(self.coordinator.config['retry_backoff_ms'] / 1000) + log.log(0, "Not ready to heartbeat, waiting") + self.coordinator._lock.wait( + self.coordinator.config["retry_backoff_ms"] / 1000 + ) else: self.coordinator.heartbeat.sent_heartbeat() diff --git a/kafka/coordinator/consumer.py b/aiokafka/coordinator/consumer.py similarity index 62% rename from kafka/coordinator/consumer.py rename to aiokafka/coordinator/consumer.py index 6f0de2db..60f922ff 100644 --- a/kafka/coordinator/consumer.py +++ b/aiokafka/coordinator/consumer.py @@ -1,19 +1,9 @@ -from __future__ import absolute_import, division - import collections import copy import functools import logging import time -from kafka.vendor import six - -from kafka.coordinator.base import BaseCoordinator, Generation -from kafka.coordinator.assignors.range import RangePartitionAssignor -from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor -from kafka.coordinator.assignors.sticky.sticky_assignor import StickyPartitionAssignor -from kafka.coordinator.protocol import ConsumerProtocol -import aiokafka.errors as Errors from kafka.future import Future from kafka.metrics import AnonMeasurable from kafka.metrics.stats import Avg, Count, Max, Rate @@ -21,25 +11,38 @@ from kafka.structs import OffsetAndMetadata, TopicPartition from kafka.util import WeakMethod +import aiokafka.errors as Errors + +from .base import BaseCoordinator, Generation +from .assignors.range import RangePartitionAssignor +from .assignors.roundrobin import RoundRobinPartitionAssignor +from .assignors.sticky.sticky_assignor import StickyPartitionAssignor +from .protocol import ConsumerProtocol + log = logging.getLogger(__name__) class ConsumerCoordinator(BaseCoordinator): """This class manages the coordination process with the consumer coordinator.""" + DEFAULT_CONFIG = { - 'group_id': 'kafka-python-default-group', - 'enable_auto_commit': True, - 'auto_commit_interval_ms': 5000, - 'default_offset_commit_callback': None, - 'assignors': (RangePartitionAssignor, RoundRobinPartitionAssignor, StickyPartitionAssignor), - 'session_timeout_ms': 10000, - 'heartbeat_interval_ms': 3000, - 'max_poll_interval_ms': 300000, - 'retry_backoff_ms': 100, - 'api_version': (0, 10, 1), - 'exclude_internal_topics': True, - 'metric_group_prefix': 'consumer' + "group_id": 
"kafka-python-default-group", + "enable_auto_commit": True, + "auto_commit_interval_ms": 5000, + "default_offset_commit_callback": None, + "assignors": ( + RangePartitionAssignor, + RoundRobinPartitionAssignor, + StickyPartitionAssignor, + ), + "session_timeout_ms": 10000, + "heartbeat_interval_ms": 3000, + "max_poll_interval_ms": 300000, + "retry_backoff_ms": 100, + "api_version": (0, 10, 1), + "exclude_internal_topics": True, + "metric_group_prefix": "consumer", } def __init__(self, client, subscription, metrics, **configs): @@ -88,46 +91,60 @@ def __init__(self, client, subscription, metrics, **configs): self._subscription = subscription self._is_leader = False self._joined_subscription = set() - self._metadata_snapshot = self._build_metadata_snapshot(subscription, client.cluster) + self._metadata_snapshot = self._build_metadata_snapshot( + subscription, client.cluster + ) self._assignment_snapshot = None self._cluster = client.cluster - self.auto_commit_interval = self.config['auto_commit_interval_ms'] / 1000 + self.auto_commit_interval = self.config["auto_commit_interval_ms"] / 1000 self.next_auto_commit_deadline = None self.completed_offset_commits = collections.deque() - if self.config['default_offset_commit_callback'] is None: - self.config['default_offset_commit_callback'] = self._default_offset_commit_callback - - if self.config['group_id'] is not None: - if self.config['api_version'] >= (0, 9): - if not self.config['assignors']: - raise Errors.KafkaConfigurationError('Coordinator requires assignors') - if self.config['api_version'] < (0, 10, 1): - if self.config['max_poll_interval_ms'] != self.config['session_timeout_ms']: - raise Errors.KafkaConfigurationError("Broker version %s does not support " - "different values for max_poll_interval_ms " - "and session_timeout_ms") - - if self.config['enable_auto_commit']: - if self.config['api_version'] < (0, 8, 1): - log.warning('Broker version (%s) does not support offset' - ' commits; disabling auto-commit.', - self.config['api_version']) - self.config['enable_auto_commit'] = False - elif self.config['group_id'] is None: - log.warning('group_id is None: disabling auto-commit.') - self.config['enable_auto_commit'] = False + if self.config["default_offset_commit_callback"] is None: + self.config[ + "default_offset_commit_callback" + ] = self._default_offset_commit_callback + + if self.config["group_id"] is not None: + if self.config["api_version"] >= (0, 9): + if not self.config["assignors"]: + raise Errors.KafkaConfigurationError( + "Coordinator requires assignors" + ) + if self.config["api_version"] < (0, 10, 1): + if ( + self.config["max_poll_interval_ms"] + != self.config["session_timeout_ms"] + ): + raise Errors.KafkaConfigurationError( + "Broker version %s does not support " + "different values for max_poll_interval_ms " + "and session_timeout_ms" + ) + + if self.config["enable_auto_commit"]: + if self.config["api_version"] < (0, 8, 1): + log.warning( + "Broker version (%s) does not support offset" + " commits; disabling auto-commit.", + self.config["api_version"], + ) + self.config["enable_auto_commit"] = False + elif self.config["group_id"] is None: + log.warning("group_id is None: disabling auto-commit.") + self.config["enable_auto_commit"] = False else: self.next_auto_commit_deadline = time.time() + self.auto_commit_interval self.consumer_sensors = ConsumerCoordinatorMetrics( - metrics, self.config['metric_group_prefix'], self._subscription) + metrics, self.config["metric_group_prefix"], self._subscription + ) 
self._cluster.request_update() self._cluster.add_listener(WeakMethod(self._handle_metadata_update)) def __del__(self): - if hasattr(self, '_cluster') and self._cluster: + if hasattr(self, "_cluster") and self._cluster: self._cluster.remove_listener(WeakMethod(self._handle_metadata_update)) super(ConsumerCoordinator, self).__del__() @@ -137,20 +154,20 @@ def protocol_type(self): def group_protocols(self): """Returns list of preferred (protocols, metadata)""" if self._subscription.subscription is None: - raise Errors.IllegalStateError('Consumer has not subscribed to topics') + raise Errors.IllegalStateError("Consumer has not subscribed to topics") # dpkp note: I really dislike this. # why? because we are using this strange method group_protocols, # which is seemingly innocuous, to set internal state (_joined_subscription) - # that is later used to check whether metadata has changed since we joined a group - # but there is no guarantee that this method, group_protocols, will get called - # in the correct sequence or that it will only be called when we want it to be. - # So this really should be moved elsewhere, but I don't have the energy to - # work that out right now. If you read this at some later date after the mutable - # state has bitten you... I'm sorry! It mimics the java client, and that's the - # best I've got for now. + # that is later used to check whether metadata has changed since we joined a + # group but there is no guarantee that this method, group_protocols, will get + # called in the correct sequence or that it will only be called when we want it + # to be. So this really should be moved elsewhere, but I don't have the energy + # to work that out right now. If you read this at some later date after the + # mutable state has bitten you... I'm sorry! It mimics the java client, and + # that's the best I've got for now. 
self._joined_subscription = set(self._subscription.subscription) metadata_list = [] - for assignor in self.config['assignors']: + for assignor in self.config["assignors"]: metadata = assignor.metadata(self._joined_subscription) group_protocol = (assignor.name, metadata) metadata_list.append(group_protocol) @@ -163,7 +180,7 @@ def _handle_metadata_update(self, cluster): if self._subscription.subscribed_pattern: topics = [] - for topic in cluster.topics(self.config['exclude_internal_topics']): + for topic in cluster.topics(self.config["exclude_internal_topics"]): if self._subscription.subscribed_pattern.match(topic): topics.append(topic) @@ -174,25 +191,29 @@ def _handle_metadata_update(self, cluster): # check if there are any changes to the metadata which should trigger # a rebalance if self._subscription.partitions_auto_assigned(): - metadata_snapshot = self._build_metadata_snapshot(self._subscription, cluster) + metadata_snapshot = self._build_metadata_snapshot( + self._subscription, cluster + ) if self._metadata_snapshot != metadata_snapshot: self._metadata_snapshot = metadata_snapshot # If we haven't got group coordinator support, # just assign all partitions locally if self._auto_assign_all_partitions(): - self._subscription.assign_from_subscribed([ - TopicPartition(topic, partition) - for topic in self._subscription.subscription - for partition in self._metadata_snapshot[topic] - ]) + self._subscription.assign_from_subscribed( + [ + TopicPartition(topic, partition) + for topic in self._subscription.subscription + for partition in self._metadata_snapshot[topic] + ] + ) def _auto_assign_all_partitions(self): # For users that use "subscribe" without group support, # we will simply assign all partitions to this consumer - if self.config['api_version'] < (0, 9): + if self.config["api_version"] < (0, 9): return True - elif self.config['group_id'] is None: + elif self.config["group_id"] is None: return True else: return False @@ -205,20 +226,23 @@ def _build_metadata_snapshot(self, subscription, cluster): return metadata_snapshot def _lookup_assignor(self, name): - for assignor in self.config['assignors']: + for assignor in self.config["assignors"]: if assignor.name == name: return assignor return None - def _on_join_complete(self, generation, member_id, protocol, - member_assignment_bytes): + def _on_join_complete( + self, generation, member_id, protocol, member_assignment_bytes + ): # only the leader is responsible for monitoring for metadata changes # (i.e. 
partition changes) if not self._is_leader: self._assignment_snapshot = None assignor = self._lookup_assignor(protocol) - assert assignor, 'Coordinator selected invalid assignment protocol: %s' % (protocol,) + assert assignor, "Coordinator selected invalid assignment protocol: %s" % ( + protocol, + ) assignment = ConsumerProtocol.ASSIGNMENT.decode(member_assignment_bytes) @@ -235,25 +259,29 @@ def _on_join_complete(self, generation, member_id, protocol, # give the assignor a chance to update internal state # based on the received assignment assignor.on_assignment(assignment) - if assignor.name == 'sticky': + if assignor.name == "sticky": assignor.on_generation_assignment(generation) # reschedule the auto commit starting from now self.next_auto_commit_deadline = time.time() + self.auto_commit_interval assigned = set(self._subscription.assigned_partitions()) - log.info("Setting newly assigned partitions %s for group %s", - assigned, self.group_id) + log.info( + "Setting newly assigned partitions %s for group %s", assigned, self.group_id + ) # execute the user's callback after rebalance if self._subscription.listener: try: self._subscription.listener.on_partitions_assigned(assigned) except Exception: - log.exception("User provided listener %s for group %s" - " failed on partition assignment: %s", - self._subscription.listener, self.group_id, - assigned) + log.exception( + "User provided listener %s for group %s" + " failed on partition assignment: %s", + self._subscription.listener, + self.group_id, + assigned, + ) def poll(self): """ @@ -269,7 +297,10 @@ def poll(self): self._invoke_completed_offset_commit_callbacks() self.ensure_coordinator_ready() - if self.config['api_version'] >= (0, 9) and self._subscription.partitions_auto_assigned(): + if ( + self.config["api_version"] >= (0, 9) + and self._subscription.partitions_auto_assigned() + ): if self.need_rejoin(): # due to a race condition between the initial metadata fetch and the # initial rebalance, we need to ensure that the metadata is fresh @@ -293,25 +324,30 @@ def poll(self): self._maybe_auto_commit_offsets_async() def time_to_next_poll(self): - """Return seconds (float) remaining until :meth:`.poll` should be called again""" - if not self.config['enable_auto_commit']: + """Return seconds (float) remaining until :meth:`.poll` should be called + again + """ + if not self.config["enable_auto_commit"]: return self.time_to_next_heartbeat() if time.time() > self.next_auto_commit_deadline: return 0 - return min(self.next_auto_commit_deadline - time.time(), - self.time_to_next_heartbeat()) + return min( + self.next_auto_commit_deadline - time.time(), self.time_to_next_heartbeat() + ) def _perform_assignment(self, leader_id, assignment_strategy, members): assignor = self._lookup_assignor(assignment_strategy) - assert assignor, 'Invalid assignment protocol: %s' % (assignment_strategy,) + assert assignor, "Invalid assignment protocol: %s" % (assignment_strategy,) member_metadata = {} all_subscribed_topics = set() for member_id, metadata_bytes in members: metadata = ConsumerProtocol.METADATA.decode(metadata_bytes) member_metadata[member_id] = metadata - all_subscribed_topics.update(metadata.subscription) # pylint: disable-msg=no-member + all_subscribed_topics.update( + metadata.subscription + ) # pylint: disable-msg=no-member # the leader will begin watching for changes to any of the topics # the group is interested in, which ensures that all metadata changes @@ -327,16 +363,20 @@ def _perform_assignment(self, leader_id, assignment_strategy, 
members): self._is_leader = True self._assignment_snapshot = self._metadata_snapshot - log.debug("Performing assignment for group %s using strategy %s" - " with subscriptions %s", self.group_id, assignor.name, - member_metadata) + log.debug( + "Performing assignment for group %s using strategy %s" + " with subscriptions %s", + self.group_id, + assignor.name, + member_metadata, + ) assignments = assignor.assign(self._cluster, member_metadata) log.debug("Finished assignment for group %s: %s", self.group_id, assignments) group_assignment = {} - for member_id, assignment in six.iteritems(assignments): + for member_id, assignment in assignments.items(): group_assignment[member_id] = assignment return group_assignment @@ -345,16 +385,22 @@ def _on_join_prepare(self, generation, member_id): self._maybe_auto_commit_offsets_sync() # execute the user's callback before rebalance - log.info("Revoking previously assigned partitions %s for group %s", - self._subscription.assigned_partitions(), self.group_id) + log.info( + "Revoking previously assigned partitions %s for group %s", + self._subscription.assigned_partitions(), + self.group_id, + ) if self._subscription.listener: try: revoked = set(self._subscription.assigned_partitions()) self._subscription.listener.on_partitions_revoked(revoked) except Exception: - log.exception("User provided subscription listener %s" - " for group %s failed on_partitions_revoked", - self._subscription.listener, self.group_id) + log.exception( + "User provided subscription listener %s" + " for group %s failed on_partitions_revoked", + self._subscription.listener, + self.group_id, + ) self._is_leader = False self._subscription.reset_group_subscription() @@ -372,13 +418,17 @@ def need_rejoin(self): return False # we need to rejoin if we performed the assignment and metadata has changed - if (self._assignment_snapshot is not None - and self._assignment_snapshot != self._metadata_snapshot): + if ( + self._assignment_snapshot is not None + and self._assignment_snapshot != self._metadata_snapshot + ): return True # we need to join if our subscription has changed since the last join - if (self._joined_subscription is not None - and self._joined_subscription != self._subscription.subscription): + if ( + self._joined_subscription is not None + and self._joined_subscription != self._subscription.subscription + ): return True return super(ConsumerCoordinator, self).need_rejoin() @@ -386,8 +436,10 @@ def need_rejoin(self): def refresh_committed_offsets_if_needed(self): """Fetch committed offsets for assigned partitions.""" if self._subscription.needs_fetch_committed_offsets: - offsets = self.fetch_committed_offsets(self._subscription.assigned_partitions()) - for partition, offset in six.iteritems(offsets): + offsets = self.fetch_committed_offsets( + self._subscription.assigned_partitions() + ) + for partition, offset in offsets.items(): # verify assignment is still active if self._subscription.is_assigned(partition): self._subscription.assignment[partition].committed = offset @@ -416,9 +468,9 @@ def fetch_committed_offsets(self, partitions): return future.value if not future.retriable(): - raise future.exception # pylint: disable-msg=raising-bad-type + raise future.exception # pylint: disable-msg=raising-bad-type - time.sleep(self.config['retry_backoff_ms'] / 1000) + time.sleep(self.config["retry_backoff_ms"] / 1000) def close(self, autocommit=True): """Close the coordinator, leave the current group, @@ -465,28 +517,39 @@ def commit_offsets_async(self, offsets, callback=None): # same 
order that they were added. Note also that BaseCoordinator # prevents multiple concurrent coordinator lookup requests. future = self.lookup_coordinator() - future.add_callback(lambda r: functools.partial(self._do_commit_offsets_async, offsets, callback)()) + future.add_callback( + lambda r: functools.partial( + self._do_commit_offsets_async, offsets, callback + )() + ) if callback: - future.add_errback(lambda e: self.completed_offset_commits.appendleft((callback, offsets, e))) + future.add_errback( + lambda e: self.completed_offset_commits.appendleft( + (callback, offsets, e) + ) + ) # ensure the commit has a chance to be transmitted (without blocking on # its completion). Note that commits are treated as heartbeats by the # coordinator, so there is no need to explicitly allow heartbeats # through delayed task execution. - self._client.poll(timeout_ms=0) # no wakeup if we add that feature + self._client.poll(timeout_ms=0) # no wakeup if we add that feature return future def _do_commit_offsets_async(self, offsets, callback=None): - assert self.config['api_version'] >= (0, 8, 1), 'Unsupported Broker API' + assert self.config["api_version"] >= (0, 8, 1), "Unsupported Broker API" assert all(map(lambda k: isinstance(k, TopicPartition), offsets)) - assert all(map(lambda v: isinstance(v, OffsetAndMetadata), - offsets.values())) + assert all(map(lambda v: isinstance(v, OffsetAndMetadata), offsets.values())) if callback is None: - callback = self.config['default_offset_commit_callback'] + callback = self.config["default_offset_commit_callback"] self._subscription.needs_fetch_committed_offsets = True future = self._send_offset_commit_request(offsets) - future.add_both(lambda res: self.completed_offset_commits.appendleft((callback, offsets, res))) + future.add_both( + lambda res: self.completed_offset_commits.appendleft( + (callback, offsets, res) + ) + ) return future def commit_offsets_sync(self, offsets): @@ -500,10 +563,9 @@ def commit_offsets_sync(self, offsets): Raises error on failure """ - assert self.config['api_version'] >= (0, 8, 1), 'Unsupported Broker API' + assert self.config["api_version"] >= (0, 8, 1), "Unsupported Broker API" assert all(map(lambda k: isinstance(k, TopicPartition), offsets)) - assert all(map(lambda v: isinstance(v, OffsetAndMetadata), - offsets.values())) + assert all(map(lambda v: isinstance(v, OffsetAndMetadata), offsets.values())) self._invoke_completed_offset_commit_callbacks() if not offsets: return @@ -518,26 +580,32 @@ def commit_offsets_sync(self, offsets): return future.value if not future.retriable(): - raise future.exception # pylint: disable-msg=raising-bad-type + raise future.exception # pylint: disable-msg=raising-bad-type - time.sleep(self.config['retry_backoff_ms'] / 1000) + time.sleep(self.config["retry_backoff_ms"] / 1000) def _maybe_auto_commit_offsets_sync(self): - if self.config['enable_auto_commit']: + if self.config["enable_auto_commit"]: try: self.commit_offsets_sync(self._subscription.all_consumed_offsets()) # The three main group membership errors are known and should not # require a stacktrace -- just a warning - except (Errors.UnknownMemberIdError, - Errors.IllegalGenerationError, - Errors.RebalanceInProgressError): - log.warning("Offset commit failed: group membership out of date" - " This is likely to cause duplicate message" - " delivery.") + except ( + Errors.UnknownMemberIdError, + Errors.IllegalGenerationError, + Errors.RebalanceInProgressError, + ): + log.warning( + "Offset commit failed: group membership out of date" + " This is 
likely to cause duplicate message" + " delivery." + ) except Exception: - log.exception("Offset commit failed: This is likely to cause" - " duplicate message delivery") + log.exception( + "Offset commit failed: This is likely to cause" + " duplicate message delivery" + ) def _send_offset_commit_request(self, offsets): """Commit offsets for the specified list of topics and partitions. @@ -553,22 +621,20 @@ def _send_offset_commit_request(self, offsets): Returns: Future: indicating whether the commit was successful or not """ - assert self.config['api_version'] >= (0, 8, 1), 'Unsupported Broker API' + assert self.config["api_version"] >= (0, 8, 1), "Unsupported Broker API" assert all(map(lambda k: isinstance(k, TopicPartition), offsets)) - assert all(map(lambda v: isinstance(v, OffsetAndMetadata), - offsets.values())) + assert all(map(lambda v: isinstance(v, OffsetAndMetadata), offsets.values())) if not offsets: - log.debug('No offsets to commit') + log.debug("No offsets to commit") return Future().success(None) node_id = self.coordinator() if node_id is None: return Future().failure(Errors.GroupCoordinatorNotAvailableError) - # create the offset commit request offset_data = collections.defaultdict(dict) - for tp, offset in six.iteritems(offsets): + for tp, offset in offsets.items(): offset_data[tp.topic][tp.partition] = offset if self._subscription.partitions_auto_assigned(): @@ -579,53 +645,69 @@ def _send_offset_commit_request(self, offsets): # if the generation is None, we are not part of an active group # (and we expect to be). The only thing we can do is fail the commit # and let the user rejoin the group in poll() - if self.config['api_version'] >= (0, 9) and generation is None: + if self.config["api_version"] >= (0, 9) and generation is None: return Future().failure(Errors.CommitFailedError()) - if self.config['api_version'] >= (0, 9): + if self.config["api_version"] >= (0, 9): request = OffsetCommitRequest[2]( self.group_id, generation.generation_id, generation.member_id, OffsetCommitRequest[2].DEFAULT_RETENTION_TIME, - [( - topic, [( - partition, - offset.offset, - offset.metadata - ) for partition, offset in six.iteritems(partitions)] - ) for topic, partitions in six.iteritems(offset_data)] + [ + ( + topic, + [ + (partition, offset.offset, offset.metadata) + for partition, offset in partitions.items() + ], + ) + for topic, partitions in offset_data.items() + ], ) - elif self.config['api_version'] >= (0, 8, 2): + elif self.config["api_version"] >= (0, 8, 2): request = OffsetCommitRequest[1]( - self.group_id, -1, '', - [( - topic, [( - partition, - offset.offset, - -1, - offset.metadata - ) for partition, offset in six.iteritems(partitions)] - ) for topic, partitions in six.iteritems(offset_data)] + self.group_id, + -1, + "", + [ + ( + topic, + [ + (partition, offset.offset, -1, offset.metadata) + for partition, offset in partitions.items() + ], + ) + for topic, partitions in offset_data.items() + ], ) - elif self.config['api_version'] >= (0, 8, 1): + elif self.config["api_version"] >= (0, 8, 1): request = OffsetCommitRequest[0]( self.group_id, - [( - topic, [( - partition, - offset.offset, - offset.metadata - ) for partition, offset in six.iteritems(partitions)] - ) for topic, partitions in six.iteritems(offset_data)] + [ + ( + topic, + [ + (partition, offset.offset, offset.metadata) + for partition, offset in partitions.items() + ], + ) + for topic, partitions in offset_data.items() + ], ) - log.debug("Sending offset-commit request with %s for group %s to %s", - offsets, 
self.group_id, node_id) + log.debug( + "Sending offset-commit request with %s for group %s to %s", + offsets, + self.group_id, + node_id, + ) future = Future() _f = self._client.send(node_id, request) - _f.add_callback(self._handle_offset_commit_response, offsets, future, time.time()) + _f.add_callback( + self._handle_offset_commit_response, offsets, future, time.time() + ) _f.add_errback(self._failed_request, node_id, request, future) return future @@ -641,58 +723,87 @@ def _handle_offset_commit_response(self, offsets, future, send_time, response): error_type = Errors.for_code(error_code) if error_type is Errors.NoError: - log.debug("Group %s committed offset %s for partition %s", - self.group_id, offset, tp) + log.debug( + "Group %s committed offset %s for partition %s", + self.group_id, + offset, + tp, + ) if self._subscription.is_assigned(tp): self._subscription.assignment[tp].committed = offset elif error_type is Errors.GroupAuthorizationFailedError: - log.error("Not authorized to commit offsets for group %s", - self.group_id) + log.error( + "Not authorized to commit offsets for group %s", self.group_id + ) future.failure(error_type(self.group_id)) return elif error_type is Errors.TopicAuthorizationFailedError: unauthorized_topics.add(topic) - elif error_type in (Errors.OffsetMetadataTooLargeError, - Errors.InvalidCommitOffsetSizeError): + elif error_type in ( + Errors.OffsetMetadataTooLargeError, + Errors.InvalidCommitOffsetSizeError, + ): # raise the error to the user - log.debug("OffsetCommit for group %s failed on partition %s" - " %s", self.group_id, tp, error_type.__name__) + log.debug( + "OffsetCommit for group %s failed on partition %s" " %s", + self.group_id, + tp, + error_type.__name__, + ) future.failure(error_type()) return elif error_type is Errors.GroupLoadInProgressError: # just retry - log.debug("OffsetCommit for group %s failed: %s", - self.group_id, error_type.__name__) + log.debug( + "OffsetCommit for group %s failed: %s", + self.group_id, + error_type.__name__, + ) future.failure(error_type(self.group_id)) return - elif error_type in (Errors.GroupCoordinatorNotAvailableError, - Errors.NotCoordinatorForGroupError, - Errors.RequestTimedOutError): - log.debug("OffsetCommit for group %s failed: %s", - self.group_id, error_type.__name__) + elif error_type in ( + Errors.GroupCoordinatorNotAvailableError, + Errors.NotCoordinatorForGroupError, + Errors.RequestTimedOutError, + ): + log.debug( + "OffsetCommit for group %s failed: %s", + self.group_id, + error_type.__name__, + ) self.coordinator_dead(error_type()) future.failure(error_type(self.group_id)) return - elif error_type in (Errors.UnknownMemberIdError, - Errors.IllegalGenerationError, - Errors.RebalanceInProgressError): + elif error_type in ( + Errors.UnknownMemberIdError, + Errors.IllegalGenerationError, + Errors.RebalanceInProgressError, + ): # need to re-join group error = error_type(self.group_id) - log.debug("OffsetCommit for group %s failed: %s", - self.group_id, error) + log.debug( + "OffsetCommit for group %s failed: %s", self.group_id, error + ) self.reset_generation() future.failure(Errors.CommitFailedError()) return else: - log.error("Group %s failed to commit partition %s at offset" - " %s: %s", self.group_id, tp, offset, - error_type.__name__) + log.error( + "Group %s failed to commit partition %s at offset" " %s: %s", + self.group_id, + tp, + offset, + error_type.__name__, + ) future.failure(error_type()) return if unauthorized_topics: - log.error("Not authorized to commit to topics %s for group %s", - 
unauthorized_topics, self.group_id) + log.error( + "Not authorized to commit to topics %s for group %s", + unauthorized_topics, + self.group_id, + ) future.failure(Errors.TopicAuthorizationFailedError(unauthorized_topics)) else: future.success(None) @@ -709,7 +820,7 @@ def _send_offset_fetch_request(self, partitions): Returns: Future: resolves to dict of offsets: {TopicPartition: OffsetAndMetadata} """ - assert self.config['api_version'] >= (0, 8, 1), 'Unsupported Broker API' + assert self.config["api_version"] >= (0, 8, 1), "Unsupported Broker API" assert all(map(lambda k: isinstance(k, TopicPartition), partitions)) if not partitions: return Future().success({}) @@ -720,26 +831,26 @@ def _send_offset_fetch_request(self, partitions): # Verify node is ready if not self._client.ready(node_id): - log.debug("Node %s not ready -- failing offset fetch request", - node_id) + log.debug("Node %s not ready -- failing offset fetch request", node_id) return Future().failure(Errors.NodeNotReadyError) - log.debug("Group %s fetching committed offsets for partitions: %s", - self.group_id, partitions) + log.debug( + "Group %s fetching committed offsets for partitions: %s", + self.group_id, + partitions, + ) # construct the request topic_partitions = collections.defaultdict(set) for tp in partitions: topic_partitions[tp.topic].add(tp.partition) - if self.config['api_version'] >= (0, 8, 2): + if self.config["api_version"] >= (0, 8, 2): request = OffsetFetchRequest[1]( - self.group_id, - list(topic_partitions.items()) + self.group_id, list(topic_partitions.items()) ) else: request = OffsetFetchRequest[0]( - self.group_id, - list(topic_partitions.items()) + self.group_id, list(topic_partitions.items()) ) # send the request with a callback @@ -757,8 +868,12 @@ def _handle_offset_fetch_response(self, future, response): error_type = Errors.for_code(error_code) if error_type is not Errors.NoError: error = error_type() - log.debug("Group %s failed to fetch offset for partition" - " %s: %s", self.group_id, tp, error) + log.debug( + "Group %s failed to fetch offset for partition" " %s: %s", + self.group_id, + tp, + error, + ) if error_type is Errors.GroupLoadInProgressError: # just retry future.failure(error) @@ -767,13 +882,16 @@ def _handle_offset_fetch_response(self, future, response): self.coordinator_dead(error_type()) future.failure(error) elif error_type is Errors.UnknownTopicOrPartitionError: - log.warning("OffsetFetchRequest -- unknown topic %s" - " (have you committed any offsets yet?)", - topic) + log.warning( + "OffsetFetchRequest -- unknown topic %s" + " (have you committed any offsets yet?)", + topic, + ) continue else: - log.error("Unknown error fetching offsets for %s: %s", - tp, error) + log.error( + "Unknown error fetching offsets for %s: %s", tp, error + ) future.failure(error) return elif offset >= 0: @@ -781,8 +899,11 @@ def _handle_offset_fetch_response(self, future, response): # (-1 indicates no committed offset to fetch) offsets[tp] = OffsetAndMetadata(offset, metadata) else: - log.debug("Group %s has no committed offset for partition" - " %s", self.group_id, tp) + log.debug( + "Group %s has no committed offset for partition" " %s", + self.group_id, + tp, + ) future.success(offsets) def _default_offset_commit_callback(self, offsets, exception): @@ -791,43 +912,74 @@ def _default_offset_commit_callback(self, offsets, exception): def _commit_offsets_async_on_complete(self, offsets, exception): if exception is not None: - log.warning("Auto offset commit failed for group %s: %s", - self.group_id, 
exception) - if getattr(exception, 'retriable', False): - self.next_auto_commit_deadline = min(time.time() + self.config['retry_backoff_ms'] / 1000, self.next_auto_commit_deadline) + log.warning( + "Auto offset commit failed for group %s: %s", self.group_id, exception + ) + if getattr(exception, "retriable", False): + self.next_auto_commit_deadline = min( + time.time() + self.config["retry_backoff_ms"] / 1000, + self.next_auto_commit_deadline, + ) else: - log.debug("Completed autocommit of offsets %s for group %s", - offsets, self.group_id) + log.debug( + "Completed autocommit of offsets %s for group %s", + offsets, + self.group_id, + ) def _maybe_auto_commit_offsets_async(self): - if self.config['enable_auto_commit']: + if self.config["enable_auto_commit"]: if self.coordinator_unknown(): - self.next_auto_commit_deadline = time.time() + self.config['retry_backoff_ms'] / 1000 + self.next_auto_commit_deadline = ( + time.time() + self.config["retry_backoff_ms"] / 1000 + ) elif time.time() > self.next_auto_commit_deadline: self.next_auto_commit_deadline = time.time() + self.auto_commit_interval - self.commit_offsets_async(self._subscription.all_consumed_offsets(), - self._commit_offsets_async_on_complete) + self.commit_offsets_async( + self._subscription.all_consumed_offsets(), + self._commit_offsets_async_on_complete, + ) class ConsumerCoordinatorMetrics(object): def __init__(self, metrics, metric_group_prefix, subscription): self.metrics = metrics - self.metric_group_name = '%s-coordinator-metrics' % (metric_group_prefix,) - - self.commit_latency = metrics.sensor('commit-latency') - self.commit_latency.add(metrics.metric_name( - 'commit-latency-avg', self.metric_group_name, - 'The average time taken for a commit request'), Avg()) - self.commit_latency.add(metrics.metric_name( - 'commit-latency-max', self.metric_group_name, - 'The max time taken for a commit request'), Max()) - self.commit_latency.add(metrics.metric_name( - 'commit-rate', self.metric_group_name, - 'The number of commit calls per second'), Rate(sampled_stat=Count())) - - num_parts = AnonMeasurable(lambda config, now: - len(subscription.assigned_partitions())) - metrics.add_metric(metrics.metric_name( - 'assigned-partitions', self.metric_group_name, - 'The number of partitions currently assigned to this consumer'), - num_parts) + self.metric_group_name = "%s-coordinator-metrics" % (metric_group_prefix,) + + self.commit_latency = metrics.sensor("commit-latency") + self.commit_latency.add( + metrics.metric_name( + "commit-latency-avg", + self.metric_group_name, + "The average time taken for a commit request", + ), + Avg(), + ) + self.commit_latency.add( + metrics.metric_name( + "commit-latency-max", + self.metric_group_name, + "The max time taken for a commit request", + ), + Max(), + ) + self.commit_latency.add( + metrics.metric_name( + "commit-rate", + self.metric_group_name, + "The number of commit calls per second", + ), + Rate(sampled_stat=Count()), + ) + + num_parts = AnonMeasurable( + lambda config, now: len(subscription.assigned_partitions()) + ) + metrics.add_metric( + metrics.metric_name( + "assigned-partitions", + self.metric_group_name, + "The number of partitions currently assigned to this consumer", + ), + num_parts, + ) diff --git a/kafka/coordinator/heartbeat.py b/aiokafka/coordinator/heartbeat.py similarity index 59% rename from kafka/coordinator/heartbeat.py rename to aiokafka/coordinator/heartbeat.py index 2f5930b6..b10a726d 100644 --- a/kafka/coordinator/heartbeat.py +++ b/aiokafka/coordinator/heartbeat.py @@ 
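A stdlib-only sketch of the auto-commit scheduling in _commit_offsets_async_on_complete and _maybe_auto_commit_offsets_async above; the backoff and interval values are illustrative, not taken from real configuration.

import time

retry_backoff_ms = 100                      # illustrative config values
auto_commit_interval = 5.0                  # seconds
next_auto_commit_deadline = time.time() + auto_commit_interval

# A retriable commit failure pulls the deadline forward to "now + backoff",
# but never pushes it later than what was already scheduled.
next_auto_commit_deadline = min(
    time.time() + retry_backoff_ms / 1000,
    next_auto_commit_deadline,
)

# When the coordinator is unknown, the next attempt is likewise scheduled a
# retry backoff away instead of a full auto-commit interval.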
-1,16 +1,14 @@ -from __future__ import absolute_import, division - import copy import time class Heartbeat(object): DEFAULT_CONFIG = { - 'group_id': None, - 'heartbeat_interval_ms': 3000, - 'session_timeout_ms': 10000, - 'max_poll_interval_ms': 300000, - 'retry_backoff_ms': 100, + "group_id": None, + "heartbeat_interval_ms": 3000, + "session_timeout_ms": 10000, + "max_poll_interval_ms": 300000, + "retry_backoff_ms": 100, } def __init__(self, **configs): @@ -19,14 +17,15 @@ def __init__(self, **configs): if key in configs: self.config[key] = configs[key] - if self.config['group_id'] is not None: - assert (self.config['heartbeat_interval_ms'] - <= self.config['session_timeout_ms']), ( - 'Heartbeat interval must be lower than the session timeout') + if self.config["group_id"] is not None: + assert ( + self.config["heartbeat_interval_ms"] + <= self.config["session_timeout_ms"] + ), "Heartbeat interval must be lower than the session timeout" - self.last_send = -1 * float('inf') - self.last_receive = -1 * float('inf') - self.last_poll = -1 * float('inf') + self.last_send = -1 * float("inf") + self.last_receive = -1 * float("inf") + self.last_poll = -1 * float("inf") self.last_reset = time.time() self.heartbeat_failed = None @@ -47,9 +46,9 @@ def time_to_next_heartbeat(self): """Returns seconds (float) remaining before next heartbeat should be sent""" time_since_last_heartbeat = time.time() - max(self.last_send, self.last_reset) if self.heartbeat_failed: - delay_to_next_heartbeat = self.config['retry_backoff_ms'] / 1000 + delay_to_next_heartbeat = self.config["retry_backoff_ms"] / 1000 else: - delay_to_next_heartbeat = self.config['heartbeat_interval_ms'] / 1000 + delay_to_next_heartbeat = self.config["heartbeat_interval_ms"] / 1000 return max(0, delay_to_next_heartbeat - time_since_last_heartbeat) def should_heartbeat(self): @@ -57,7 +56,7 @@ def should_heartbeat(self): def session_timeout_expired(self): last_recv = max(self.last_receive, self.last_reset) - return (time.time() - last_recv) > (self.config['session_timeout_ms'] / 1000) + return (time.time() - last_recv) > (self.config["session_timeout_ms"] / 1000) def reset_timeouts(self): self.last_reset = time.time() @@ -65,4 +64,6 @@ def reset_timeouts(self): self.heartbeat_failed = False def poll_timeout_expired(self): - return (time.time() - self.last_poll) > (self.config['max_poll_interval_ms'] / 1000) + return (time.time() - self.last_poll) > ( + self.config["max_poll_interval_ms"] / 1000 + ) diff --git a/aiokafka/coordinator/protocol.py b/aiokafka/coordinator/protocol.py new file mode 100644 index 00000000..87425007 --- /dev/null +++ b/aiokafka/coordinator/protocol.py @@ -0,0 +1,33 @@ +from kafka.protocol.struct import Struct +from kafka.protocol.types import Array, Bytes, Int16, Int32, Schema, String +from kafka.structs import TopicPartition + + +class ConsumerProtocolMemberMetadata(Struct): + SCHEMA = Schema( + ("version", Int16), + ("subscription", Array(String("utf-8"))), + ("user_data", Bytes), + ) + + +class ConsumerProtocolMemberAssignment(Struct): + SCHEMA = Schema( + ("version", Int16), + ("assignment", Array(("topic", String("utf-8")), ("partitions", Array(Int32)))), + ("user_data", Bytes), + ) + + def partitions(self): + return [ + TopicPartition(topic, partition) + for topic, partitions in self.assignment # pylint: disable-msg=no-member + for partition in partitions + ] + + +class ConsumerProtocol(object): + PROTOCOL_TYPE = "consumer" + ASSIGNMENT_STRATEGIES = ("range", "roundrobin") + METADATA = ConsumerProtocolMemberMetadata 
+ ASSIGNMENT = ConsumerProtocolMemberAssignment diff --git a/docs/api.rst b/docs/api.rst index 1a404d2d..2a8043c7 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -132,7 +132,7 @@ Other references .. autoclass:: aiokafka.producer.message_accumulator.BatchBuilder .. autoclass:: aiokafka.consumer.group_coordinator.GroupCoordinator -.. autoclass:: kafka.coordinator.assignors.roundrobin.RoundRobinPartitionAssignor +.. autoclass:: aiokafka.coordinator.assignors.roundrobin.RoundRobinPartitionAssignor Errors diff --git a/kafka/coordinator/protocol.py b/kafka/coordinator/protocol.py deleted file mode 100644 index 56a39015..00000000 --- a/kafka/coordinator/protocol.py +++ /dev/null @@ -1,33 +0,0 @@ -from __future__ import absolute_import - -from kafka.protocol.struct import Struct -from kafka.protocol.types import Array, Bytes, Int16, Int32, Schema, String -from kafka.structs import TopicPartition - - -class ConsumerProtocolMemberMetadata(Struct): - SCHEMA = Schema( - ('version', Int16), - ('subscription', Array(String('utf-8'))), - ('user_data', Bytes)) - - -class ConsumerProtocolMemberAssignment(Struct): - SCHEMA = Schema( - ('version', Int16), - ('assignment', Array( - ('topic', String('utf-8')), - ('partitions', Array(Int32)))), - ('user_data', Bytes)) - - def partitions(self): - return [TopicPartition(topic, partition) - for topic, partitions in self.assignment # pylint: disable-msg=no-member - for partition in partitions] - - -class ConsumerProtocol(object): - PROTOCOL_TYPE = 'consumer' - ASSIGNMENT_STRATEGIES = ('range', 'roundrobin') - METADATA = ConsumerProtocolMemberMetadata - ASSIGNMENT = ConsumerProtocolMemberAssignment diff --git a/tests/coordinator/__init__.py b/tests/coordinator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/kafka/test_assignors.py b/tests/coordinator/test_assignors.py similarity index 51% rename from tests/kafka/test_assignors.py rename to tests/coordinator/test_assignors.py index 858ef426..5fc8a5f5 100644 --- a/tests/kafka/test_assignors.py +++ b/tests/coordinator/test_assignors.py @@ -7,11 +7,12 @@ import pytest from kafka.structs import TopicPartition -from kafka.coordinator.assignors.range import RangePartitionAssignor -from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor -from kafka.coordinator.assignors.sticky.sticky_assignor import StickyPartitionAssignor, StickyAssignorUserDataV1 -from kafka.coordinator.protocol import ConsumerProtocolMemberAssignment, ConsumerProtocolMemberMetadata -from kafka.vendor import six +from aiokafka.coordinator.assignors.range import RangePartitionAssignor +from aiokafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor +from aiokafka.coordinator.assignors.sticky.sticky_assignor import ( + StickyPartitionAssignor, +) +from aiokafka.coordinator.protocol import ConsumerProtocolMemberAssignment @pytest.fixture(autouse=True) @@ -21,7 +22,9 @@ def reset_sticky_assignor(): StickyPartitionAssignor.generation = -1 -def create_cluster(mocker, topics, topics_partitions=None, topic_partitions_lambda=None): +def create_cluster( + mocker, topics, topics_partitions=None, topic_partitions_lambda=None +): cluster = mocker.MagicMock() cluster.topics.return_value = topics if topics_partitions is not None: @@ -35,17 +38,19 @@ def test_assignor_roundrobin(mocker): assignor = RoundRobinPartitionAssignor member_metadata = { - 'C0': assignor.metadata({'t0', 't1'}), - 'C1': assignor.metadata({'t0', 't1'}), + "C0": assignor.metadata({"t0", "t1"}), + "C1": assignor.metadata({"t0", 
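A short usage sketch (assumed, not part of the patch) for the structs defined in the new aiokafka/coordinator/protocol.py above, using the same positional-argument style as the assignor tests below.

from aiokafka.coordinator.protocol import ConsumerProtocolMemberAssignment

# Fields follow the Schema: version, [(topic, [partitions]), ...], user_data.
assignment = ConsumerProtocolMemberAssignment(0, [("t0", [0, 2]), ("t1", [1])], b"")

# partitions() flattens the nested assignment into TopicPartition tuples:
# [TopicPartition(topic='t0', partition=0), TopicPartition(topic='t0', partition=2),
#  TopicPartition(topic='t1', partition=1)]
print(assignment.partitions())

# encode() produces the wire representation used in the group protocol.
raw = assignment.encode()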
"t1"}), } - cluster = create_cluster(mocker, {'t0', 't1'}, topics_partitions={0, 1, 2}) + cluster = create_cluster(mocker, {"t0", "t1"}, topics_partitions={0, 1, 2}) ret = assignor.assign(cluster, member_metadata) expected = { - 'C0': ConsumerProtocolMemberAssignment( - assignor.version, [('t0', [0, 2]), ('t1', [1])], b''), - 'C1': ConsumerProtocolMemberAssignment( - assignor.version, [('t0', [1]), ('t1', [0, 2])], b'') + "C0": ConsumerProtocolMemberAssignment( + assignor.version, [("t0", [0, 2]), ("t1", [1])], b"" + ), + "C1": ConsumerProtocolMemberAssignment( + assignor.version, [("t0", [1]), ("t1", [0, 2])], b"" + ), } assert ret == expected assert set(ret) == set(expected) @@ -57,17 +62,19 @@ def test_assignor_range(mocker): assignor = RangePartitionAssignor member_metadata = { - 'C0': assignor.metadata({'t0', 't1'}), - 'C1': assignor.metadata({'t0', 't1'}), + "C0": assignor.metadata({"t0", "t1"}), + "C1": assignor.metadata({"t0", "t1"}), } - cluster = create_cluster(mocker, {'t0', 't1'}, topics_partitions={0, 1, 2}) + cluster = create_cluster(mocker, {"t0", "t1"}, topics_partitions={0, 1, 2}) ret = assignor.assign(cluster, member_metadata) expected = { - 'C0': ConsumerProtocolMemberAssignment( - assignor.version, [('t0', [0, 1]), ('t1', [0, 1])], b''), - 'C1': ConsumerProtocolMemberAssignment( - assignor.version, [('t0', [2]), ('t1', [2])], b'') + "C0": ConsumerProtocolMemberAssignment( + assignor.version, [("t0", [0, 1]), ("t1", [0, 1])], b"" + ), + "C1": ConsumerProtocolMemberAssignment( + assignor.version, [("t0", [2]), ("t1", [2])], b"" + ), } assert ret == expected assert set(ret) == set(expected) @@ -91,35 +98,53 @@ def test_sticky_assignor1(mocker): - C0 [t0p0, t1p1, t2p0, t3p0] - C2 [t0p1, t1p0, t2p1, t3p1] """ - cluster = create_cluster(mocker, topics={'t0', 't1', 't2', 't3'}, topics_partitions={0, 1}) + cluster = create_cluster( + mocker, topics={"t0", "t1", "t2", "t3"}, topics_partitions={0, 1} + ) subscriptions = { - 'C0': {'t0', 't1', 't2', 't3'}, - 'C1': {'t0', 't1', 't2', 't3'}, - 'C2': {'t0', 't1', 't2', 't3'}, + "C0": {"t0", "t1", "t2", "t3"}, + "C1": {"t0", "t1", "t2", "t3"}, + "C2": {"t0", "t1", "t2", "t3"}, } member_metadata = make_member_metadata(subscriptions) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C0': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t0', [0]), ('t1', [1]), ('t3', [0])], b''), - 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t0', [1]), ('t2', [0]), ('t3', [1])], b''), - 'C2': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [0]), ('t2', [1])], b''), + "C0": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, + [("t0", [0]), ("t1", [1]), ("t3", [0])], + b"", + ), + "C1": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, + [("t0", [1]), ("t2", [0]), ("t3", [1])], + b"", + ), + "C2": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t1", [0]), ("t2", [1])], b"" + ), } assert_assignment(sticky_assignment, expected_assignment) - del subscriptions['C1'] + del subscriptions["C1"] member_metadata = {} - for member, topics in six.iteritems(subscriptions): - member_metadata[member] = StickyPartitionAssignor._metadata(topics, sticky_assignment[member].partitions()) + for member, topics in subscriptions.items(): + member_metadata[member] = StickyPartitionAssignor._metadata( + topics, sticky_assignment[member].partitions() + ) sticky_assignment = 
StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C0': ConsumerProtocolMemberAssignment( - StickyPartitionAssignor.version, [('t0', [0]), ('t1', [1]), ('t2', [0]), ('t3', [0])], b'' + "C0": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, + [("t0", [0]), ("t1", [1]), ("t2", [0]), ("t3", [0])], + b"", ), - 'C2': ConsumerProtocolMemberAssignment( - StickyPartitionAssignor.version, [('t0', [1]), ('t1', [0]), ('t2', [1]), ('t3', [1])], b'' + "C2": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, + [("t0", [1]), ("t1", [0]), ("t2", [1]), ("t3", [1])], + b"", ), } assert_assignment(sticky_assignment, expected_assignment) @@ -144,35 +169,51 @@ def test_sticky_assignor2(mocker): - C2 [t2p0, t2p1, t2p2] """ - partitions = {'t0': {0}, 't1': {0, 1}, 't2': {0, 1, 2}} - cluster = create_cluster(mocker, topics={'t0', 't1', 't2'}, topic_partitions_lambda=lambda t: partitions[t]) + partitions = {"t0": {0}, "t1": {0, 1}, "t2": {0, 1, 2}} + cluster = create_cluster( + mocker, + topics={"t0", "t1", "t2"}, + topic_partitions_lambda=lambda t: partitions[t], + ) subscriptions = { - 'C0': {'t0'}, - 'C1': {'t0', 't1'}, - 'C2': {'t0', 't1', 't2'}, + "C0": {"t0"}, + "C1": {"t0", "t1"}, + "C2": {"t0", "t1", "t2"}, } member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, []) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C0': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t0', [0])], b''), - 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [0, 1])], b''), - 'C2': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t2', [0, 1, 2])], b''), + "C0": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t0", [0])], b"" + ), + "C1": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t1", [0, 1])], b"" + ), + "C2": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t2", [0, 1, 2])], b"" + ), } assert_assignment(sticky_assignment, expected_assignment) - del subscriptions['C0'] + del subscriptions["C0"] member_metadata = {} - for member, topics in six.iteritems(subscriptions): - member_metadata[member] = StickyPartitionAssignor._metadata(topics, sticky_assignment[member].partitions()) + for member, topics in subscriptions.items(): + member_metadata[member] = StickyPartitionAssignor._metadata( + topics, sticky_assignment[member].partitions() + ) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t0', [0]), ('t1', [0, 1])], b''), - 'C2': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t2', [0, 1, 2])], b''), + "C1": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t0", [0]), ("t1", [0, 1])], b"" + ), + "C2": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t2", [0, 1, 2])], b"" + ), } assert_assignment(sticky_assignment, expected_assignment) @@ -181,13 +222,13 @@ def test_sticky_one_consumer_no_topic(mocker): cluster = create_cluster(mocker, topics={}, topics_partitions={}) subscriptions = { - 'C': set(), + "C": set(), } member_metadata = make_member_metadata(subscriptions) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) 
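The remove-a-member-and-reassign pattern these tests exercise can be condensed into one runnable sketch. It assumes the same MagicMock cluster shape that the create_cluster helper above configures, and it calls the assignor's internal _metadata helper exactly as the tests do.

from unittest.mock import MagicMock

from aiokafka.coordinator.assignors.sticky.sticky_assignor import StickyPartitionAssignor

cluster = MagicMock()
cluster.topics.return_value = {"t"}
cluster.partitions_for_topic.return_value = {0, 1, 2}

member_metadata = {
    "C1": StickyPartitionAssignor._metadata({"t"}, []),
    "C2": StickyPartitionAssignor._metadata({"t"}, []),
}
first = StickyPartitionAssignor.assign(cluster, member_metadata)

# C2 leaves the group; C1 reports the partitions it already owned so the
# assignor can keep them in place while taking over the remainder.
member_metadata = {
    "C1": StickyPartitionAssignor._metadata({"t"}, first["C1"].partitions()),
}
second = StickyPartitionAssignor.assign(cluster, member_metadata)

assert len(second["C1"].partitions()) == 3
assert set(first["C1"].partitions()) <= set(second["C1"].partitions())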
expected_assignment = { - 'C': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [], b''), + "C": ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [], b""), } assert_assignment(sticky_assignment, expected_assignment) @@ -196,136 +237,162 @@ def test_sticky_one_consumer_nonexisting_topic(mocker): cluster = create_cluster(mocker, topics={}, topics_partitions={}) subscriptions = { - 'C': {'t'}, + "C": {"t"}, } member_metadata = make_member_metadata(subscriptions) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [], b''), + "C": ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [], b""), } assert_assignment(sticky_assignment, expected_assignment) def test_sticky_one_consumer_one_topic(mocker): - cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1, 2}) + cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1, 2}) subscriptions = { - 'C': {'t'}, + "C": {"t"}, } member_metadata = make_member_metadata(subscriptions) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t', [0, 1, 2])], b''), + "C": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t", [0, 1, 2])], b"" + ), } assert_assignment(sticky_assignment, expected_assignment) def test_sticky_should_only_assign_partitions_from_subscribed_topics(mocker): - cluster = create_cluster(mocker, topics={'t', 'other-t'}, topics_partitions={0, 1, 2}) + cluster = create_cluster( + mocker, topics={"t", "other-t"}, topics_partitions={0, 1, 2} + ) subscriptions = { - 'C': {'t'}, + "C": {"t"}, } member_metadata = make_member_metadata(subscriptions) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t', [0, 1, 2])], b''), + "C": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t", [0, 1, 2])], b"" + ), } assert_assignment(sticky_assignment, expected_assignment) def test_sticky_one_consumer_multiple_topics(mocker): - cluster = create_cluster(mocker, topics={'t1', 't2'}, topics_partitions={0, 1, 2}) + cluster = create_cluster(mocker, topics={"t1", "t2"}, topics_partitions={0, 1, 2}) subscriptions = { - 'C': {'t1', 't2'}, + "C": {"t1", "t2"}, } member_metadata = make_member_metadata(subscriptions) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [0, 1, 2]), ('t2', [0, 1, 2])], b''), + "C": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t1", [0, 1, 2]), ("t2", [0, 1, 2])], b"" + ), } assert_assignment(sticky_assignment, expected_assignment) def test_sticky_two_consumers_one_topic_one_partition(mocker): - cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0}) + cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0}) subscriptions = { - 'C1': {'t'}, - 'C2': {'t'}, + "C1": {"t"}, + "C2": {"t"}, } member_metadata = make_member_metadata(subscriptions) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t', [0])], b''), - 'C2': 
ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [], b''), + "C1": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t", [0])], b"" + ), + "C2": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [], b"" + ), } assert_assignment(sticky_assignment, expected_assignment) def test_sticky_two_consumers_one_topic_two_partitions(mocker): - cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1}) + cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1}) subscriptions = { - 'C1': {'t'}, - 'C2': {'t'}, + "C1": {"t"}, + "C2": {"t"}, } member_metadata = make_member_metadata(subscriptions) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t', [0])], b''), - 'C2': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t', [1])], b''), + "C1": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t", [0])], b"" + ), + "C2": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t", [1])], b"" + ), } assert_assignment(sticky_assignment, expected_assignment) def test_sticky_multiple_consumers_mixed_topic_subscriptions(mocker): - partitions = {'t1': {0, 1, 2}, 't2': {0, 1}} - cluster = create_cluster(mocker, topics={'t1', 't2'}, topic_partitions_lambda=lambda t: partitions[t]) + partitions = {"t1": {0, 1, 2}, "t2": {0, 1}} + cluster = create_cluster( + mocker, topics={"t1", "t2"}, topic_partitions_lambda=lambda t: partitions[t] + ) subscriptions = { - 'C1': {'t1'}, - 'C2': {'t1', 't2'}, - 'C3': {'t1'}, + "C1": {"t1"}, + "C2": {"t1", "t2"}, + "C3": {"t1"}, } member_metadata = make_member_metadata(subscriptions) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [0, 2])], b''), - 'C2': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t2', [0, 1])], b''), - 'C3': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [1])], b''), + "C1": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t1", [0, 2])], b"" + ), + "C2": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t2", [0, 1])], b"" + ), + "C3": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t1", [1])], b"" + ), } assert_assignment(sticky_assignment, expected_assignment) def test_sticky_add_remove_consumer_one_topic(mocker): - cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1, 2}) + cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1, 2}) subscriptions = { - 'C1': {'t'}, + "C1": {"t"}, } member_metadata = make_member_metadata(subscriptions) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t', [0, 1, 2])], b''), + "C1": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t", [0, 1, 2])], b"" + ), } assert_assignment(assignment, expected_assignment) subscriptions = { - 'C1': {'t'}, - 'C2': {'t'}, + "C1": {"t"}, + "C2": {"t"}, } member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata( topics, assignment[member].partitions() if member in assignment else [] ) @@ 
-334,86 +401,108 @@ def test_sticky_add_remove_consumer_one_topic(mocker): verify_validity_and_balance(subscriptions, assignment) subscriptions = { - 'C2': {'t'}, + "C2": {"t"}, } member_metadata = {} - for member, topics in six.iteritems(subscriptions): - member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions()) + for member, topics in subscriptions.items(): + member_metadata[member] = StickyPartitionAssignor._metadata( + topics, assignment[member].partitions() + ) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) - assert len(assignment['C2'].assignment[0][1]) == 3 + assert len(assignment["C2"].assignment[0][1]) == 3 def test_sticky_add_remove_topic_two_consumers(mocker): - cluster = create_cluster(mocker, topics={'t1', 't2'}, topics_partitions={0, 1, 2}) + cluster = create_cluster(mocker, topics={"t1", "t2"}, topics_partitions={0, 1, 2}) subscriptions = { - 'C1': {'t1'}, - 'C2': {'t1'}, + "C1": {"t1"}, + "C2": {"t1"}, } member_metadata = make_member_metadata(subscriptions) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [0, 2])], b''), - 'C2': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [1])], b''), + "C1": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t1", [0, 2])], b"" + ), + "C2": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t1", [1])], b"" + ), } assert_assignment(sticky_assignment, expected_assignment) subscriptions = { - 'C1': {'t1', 't2'}, - 'C2': {'t1', 't2'}, + "C1": {"t1", "t2"}, + "C2": {"t1", "t2"}, } member_metadata = {} - for member, topics in six.iteritems(subscriptions): - member_metadata[member] = StickyPartitionAssignor._metadata(topics, sticky_assignment[member].partitions()) + for member, topics in subscriptions.items(): + member_metadata[member] = StickyPartitionAssignor._metadata( + topics, sticky_assignment[member].partitions() + ) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [0, 2]), ('t2', [1])], b''), - 'C2': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [1]), ('t2', [0, 2])], b''), + "C1": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t1", [0, 2]), ("t2", [1])], b"" + ), + "C2": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t1", [1]), ("t2", [0, 2])], b"" + ), } assert_assignment(sticky_assignment, expected_assignment) subscriptions = { - 'C1': {'t2'}, - 'C2': {'t2'}, + "C1": {"t2"}, + "C2": {"t2"}, } member_metadata = {} - for member, topics in six.iteritems(subscriptions): - member_metadata[member] = StickyPartitionAssignor._metadata(topics, sticky_assignment[member].partitions()) + for member, topics in subscriptions.items(): + member_metadata[member] = StickyPartitionAssignor._metadata( + topics, sticky_assignment[member].partitions() + ) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C1': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t2', [1])], b''), - 'C2': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t2', [0, 2])], b''), + "C1": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t2", 
[1])], b"" + ), + "C2": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t2", [0, 2])], b"" + ), } assert_assignment(sticky_assignment, expected_assignment) def test_sticky_reassignment_after_one_consumer_leaves(mocker): - partitions = dict([('t{}'.format(i), set(range(i))) for i in range(1, 20)]) + partitions = dict([("t{}".format(i), set(range(i))) for i in range(1, 20)]) cluster = create_cluster( - mocker, topics=set(['t{}'.format(i) for i in range(1, 20)]), topic_partitions_lambda=lambda t: partitions[t] + mocker, + topics=set(["t{}".format(i) for i in range(1, 20)]), + topic_partitions_lambda=lambda t: partitions[t], ) subscriptions = {} for i in range(1, 20): topics = set() for j in range(1, i + 1): - topics.add('t{}'.format(j)) - subscriptions['C{}'.format(i)] = topics + topics.add("t{}".format(j)) + subscriptions["C{}".format(i)] = topics member_metadata = make_member_metadata(subscriptions) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) - del subscriptions['C10'] + del subscriptions["C10"] member_metadata = {} - for member, topics in six.iteritems(subscriptions): - member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions()) + for member, topics in subscriptions.items(): + member_metadata[member] = StickyPartitionAssignor._metadata( + topics, assignment[member].partitions() + ) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) @@ -421,20 +510,20 @@ def test_sticky_reassignment_after_one_consumer_leaves(mocker): def test_sticky_reassignment_after_one_consumer_added(mocker): - cluster = create_cluster(mocker, topics={'t'}, topics_partitions=set(range(20))) + cluster = create_cluster(mocker, topics={"t"}, topics_partitions=set(range(20))) subscriptions = defaultdict(set) for i in range(1, 10): - subscriptions['C{}'.format(i)] = {'t'} + subscriptions["C{}".format(i)] = {"t"} member_metadata = make_member_metadata(subscriptions) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) - subscriptions['C10'] = {'t'} + subscriptions["C10"] = {"t"} member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata( topics, assignment[member].partitions() if member in assignment else [] ) @@ -444,25 +533,29 @@ def test_sticky_reassignment_after_one_consumer_added(mocker): def test_sticky_same_subscriptions(mocker): - partitions = dict([('t{}'.format(i), set(range(i))) for i in range(1, 15)]) + partitions = dict([("t{}".format(i), set(range(i))) for i in range(1, 15)]) cluster = create_cluster( - mocker, topics=set(['t{}'.format(i) for i in range(1, 15)]), topic_partitions_lambda=lambda t: partitions[t] + mocker, + topics=set(["t{}".format(i) for i in range(1, 15)]), + topic_partitions_lambda=lambda t: partitions[t], ) subscriptions = defaultdict(set) for i in range(1, 9): - for j in range(1, len(six.viewkeys(partitions)) + 1): - subscriptions['C{}'.format(i)].add('t{}'.format(j)) + for j in range(1, len(partitions) + 1): + subscriptions["C{}".format(i)].add("t{}".format(j)) member_metadata = make_member_metadata(subscriptions) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) - del subscriptions['C5'] + del 
subscriptions["C5"] member_metadata = {} - for member, topics in six.iteritems(subscriptions): - member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions()) + for member, topics in subscriptions.items(): + member_metadata[member] = StickyPartitionAssignor._metadata( + topics, assignment[member].partitions() + ) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) assert StickyPartitionAssignor._latest_partition_movements.are_sticky() @@ -472,14 +565,16 @@ def test_sticky_large_assignment_with_multiple_consumers_leaving(mocker): n_topics = 40 n_consumers = 200 - all_topics = set(['t{}'.format(i) for i in range(1, n_topics + 1)]) + all_topics = set(["t{}".format(i) for i in range(1, n_topics + 1)]) partitions = dict([(t, set(range(1, randint(0, 10) + 1))) for t in all_topics]) - cluster = create_cluster(mocker, topics=all_topics, topic_partitions_lambda=lambda t: partitions[t]) + cluster = create_cluster( + mocker, topics=all_topics, topic_partitions_lambda=lambda t: partitions[t] + ) subscriptions = defaultdict(set) for i in range(1, n_consumers + 1): for j in range(0, randint(1, 20)): - subscriptions['C{}'.format(i)].add('t{}'.format(randint(1, n_topics))) + subscriptions["C{}".format(i)].add("t{}".format(randint(1, n_topics))) member_metadata = make_member_metadata(subscriptions) @@ -487,11 +582,13 @@ def test_sticky_large_assignment_with_multiple_consumers_leaving(mocker): verify_validity_and_balance(subscriptions, assignment) member_metadata = {} - for member, topics in six.iteritems(subscriptions): - member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions()) + for member, topics in subscriptions.items(): + member_metadata[member] = StickyPartitionAssignor._metadata( + topics, assignment[member].partitions() + ) for i in range(50): - member = 'C{}'.format(randint(1, n_consumers)) + member = "C{}".format(randint(1, n_consumers)) if member in subscriptions: del subscriptions[member] del member_metadata[member] @@ -502,21 +599,23 @@ def test_sticky_large_assignment_with_multiple_consumers_leaving(mocker): def test_new_subscription(mocker): - cluster = create_cluster(mocker, topics={'t1', 't2', 't3', 't4'}, topics_partitions={0}) + cluster = create_cluster( + mocker, topics={"t1", "t2", "t3", "t4"}, topics_partitions={0} + ) subscriptions = defaultdict(set) for i in range(3): for j in range(i, 3 * i - 2 + 1): - subscriptions['C{}'.format(i)].add('t{}'.format(j)) + subscriptions["C{}".format(i)].add("t{}".format(j)) member_metadata = make_member_metadata(subscriptions) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) - subscriptions['C0'].add('t1') + subscriptions["C0"].add("t1") member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, []) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) @@ -525,144 +624,176 @@ def test_new_subscription(mocker): def test_move_existing_assignments(mocker): - cluster = create_cluster(mocker, topics={'t1', 't2', 't3', 't4', 't5', 't6'}, topics_partitions={0}) + cluster = create_cluster( + mocker, topics={"t1", "t2", "t3", "t4", "t5", "t6"}, topics_partitions={0} + ) subscriptions = { - 'C1': {'t1', 't2'}, - 'C2': {'t1', 't2', 't3', 't4'}, - 'C3': {'t2', 't3', 't4', 't5', 't6'}, + 
"C1": {"t1", "t2"}, + "C2": {"t1", "t2", "t3", "t4"}, + "C3": {"t2", "t3", "t4", "t5", "t6"}, } member_assignments = { - 'C1': [TopicPartition('t1', 0)], - 'C2': [TopicPartition('t2', 0), TopicPartition('t3', 0)], - 'C3': [TopicPartition('t4', 0), TopicPartition('t5', 0), TopicPartition('t6', 0)], + "C1": [TopicPartition("t1", 0)], + "C2": [TopicPartition("t2", 0), TopicPartition("t3", 0)], + "C3": [ + TopicPartition("t4", 0), + TopicPartition("t5", 0), + TopicPartition("t6", 0), + ], } member_metadata = {} - for member, topics in six.iteritems(subscriptions): - member_metadata[member] = StickyPartitionAssignor._metadata(topics, member_assignments[member]) + for member, topics in subscriptions.items(): + member_metadata[member] = StickyPartitionAssignor._metadata( + topics, member_assignments[member] + ) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) def test_stickiness(mocker): - cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1, 2}) + cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1, 2}) subscriptions = { - 'C1': {'t'}, - 'C2': {'t'}, - 'C3': {'t'}, - 'C4': {'t'}, + "C1": {"t"}, + "C2": {"t"}, + "C3": {"t"}, + "C4": {"t"}, } member_metadata = make_member_metadata(subscriptions) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) partitions_assigned = {} - for consumer, consumer_assignment in six.iteritems(assignment): + for consumer, consumer_assignment in assignment.items(): assert ( len(consumer_assignment.partitions()) <= 1 - ), 'Consumer {} is assigned more topic partitions than expected.'.format(consumer) + ), "Consumer {} is assigned more topic partitions than expected.".format( + consumer + ) if len(consumer_assignment.partitions()) == 1: partitions_assigned[consumer] = consumer_assignment.partitions()[0] # removing the potential group leader - del subscriptions['C1'] + del subscriptions["C1"] member_metadata = {} - for member, topics in six.iteritems(subscriptions): - member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions()) + for member, topics in subscriptions.items(): + member_metadata[member] = StickyPartitionAssignor._metadata( + topics, assignment[member].partitions() + ) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) assert StickyPartitionAssignor._latest_partition_movements.are_sticky() - for consumer, consumer_assignment in six.iteritems(assignment): + for consumer, consumer_assignment in assignment.items(): assert ( len(consumer_assignment.partitions()) <= 1 - ), 'Consumer {} is assigned more topic partitions than expected.'.format(consumer) + ), "Consumer {} is assigned more topic partitions than expected.".format( + consumer + ) assert ( - consumer not in partitions_assigned or partitions_assigned[consumer] in consumer_assignment.partitions() - ), 'Stickiness was not honored for consumer {}'.format(consumer) + consumer not in partitions_assigned + or partitions_assigned[consumer] in consumer_assignment.partitions() + ), "Stickiness was not honored for consumer {}".format(consumer) def test_assignment_updated_for_deleted_topic(mocker): def topic_partitions(topic): - if topic == 't1': + if topic == "t1": return {0} - if topic == 't3': + if topic == "t3": return set(range(100)) - cluster = create_cluster(mocker, topics={'t1', 't3'}, 
topic_partitions_lambda=topic_partitions) + cluster = create_cluster( + mocker, topics={"t1", "t3"}, topic_partitions_lambda=topic_partitions + ) subscriptions = { - 'C': {'t1', 't2', 't3'}, + "C": {"t1", "t2", "t3"}, } member_metadata = make_member_metadata(subscriptions) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t1', [0]), ('t3', list(range(100)))], b''), + "C": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, + [("t1", [0]), ("t3", list(range(100)))], + b"", + ), } assert_assignment(sticky_assignment, expected_assignment) def test_no_exceptions_when_only_subscribed_topic_is_deleted(mocker): - cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1, 2}) + cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1, 2}) subscriptions = { - 'C': {'t'}, + "C": {"t"}, } member_metadata = make_member_metadata(subscriptions) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [('t', [0, 1, 2])], b''), + "C": ConsumerProtocolMemberAssignment( + StickyPartitionAssignor.version, [("t", [0, 1, 2])], b"" + ), } assert_assignment(sticky_assignment, expected_assignment) subscriptions = { - 'C': {}, + "C": {}, } member_metadata = {} - for member, topics in six.iteritems(subscriptions): - member_metadata[member] = StickyPartitionAssignor._metadata(topics, sticky_assignment[member].partitions()) + for member, topics in subscriptions.items(): + member_metadata[member] = StickyPartitionAssignor._metadata( + topics, sticky_assignment[member].partitions() + ) cluster = create_cluster(mocker, topics={}, topics_partitions={}) sticky_assignment = StickyPartitionAssignor.assign(cluster, member_metadata) expected_assignment = { - 'C': ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [], b''), + "C": ConsumerProtocolMemberAssignment(StickyPartitionAssignor.version, [], b""), } assert_assignment(sticky_assignment, expected_assignment) def test_conflicting_previous_assignments(mocker): - cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1}) + cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1}) subscriptions = { - 'C1': {'t'}, - 'C2': {'t'}, + "C1": {"t"}, + "C2": {"t"}, } member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): # assume both C1 and C2 have partition 1 assigned to them in generation 1 - member_metadata[member] = StickyPartitionAssignor._metadata(topics, [TopicPartition('t', 0), TopicPartition('t', 0)], 1) + member_metadata[member] = StickyPartitionAssignor._metadata( + topics, [TopicPartition("t", 0), TopicPartition("t", 0)], 1 + ) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) @pytest.mark.parametrize( - 'execution_number,n_topics,n_consumers', [(i, randint(10, 20), randint(20, 40)) for i in range(100)] + "execution_number,n_topics,n_consumers", + [(i, randint(10, 20), randint(20, 40)) for i in range(100)], ) -def test_reassignment_with_random_subscriptions_and_changes(mocker, execution_number, n_topics, n_consumers): - all_topics = sorted(['t{}'.format(i) for i in range(1, n_topics + 1)]) +def test_reassignment_with_random_subscriptions_and_changes( + mocker, execution_number, n_topics, n_consumers +): + 
all_topics = sorted(["t{}".format(i) for i in range(1, n_topics + 1)]) partitions = dict([(t, set(range(1, i + 1))) for i, t in enumerate(all_topics)]) - cluster = create_cluster(mocker, topics=all_topics, topic_partitions_lambda=lambda t: partitions[t]) + cluster = create_cluster( + mocker, topics=all_topics, topic_partitions_lambda=lambda t: partitions[t] + ) subscriptions = defaultdict(set) for i in range(n_consumers): topics_sample = sample(all_topics, randint(1, len(all_topics) - 1)) - subscriptions['C{}'.format(i)].update(topics_sample) + subscriptions["C{}".format(i)].update(topics_sample) member_metadata = make_member_metadata(subscriptions) @@ -672,11 +803,13 @@ def test_reassignment_with_random_subscriptions_and_changes(mocker, execution_nu subscriptions = defaultdict(set) for i in range(n_consumers): topics_sample = sample(all_topics, randint(1, len(all_topics) - 1)) - subscriptions['C{}'.format(i)].update(topics_sample) + subscriptions["C{}".format(i)].update(topics_sample) member_metadata = {} - for member, topics in six.iteritems(subscriptions): - member_metadata[member] = StickyPartitionAssignor._metadata(topics, assignment[member].partitions()) + for member, topics in subscriptions.items(): + member_metadata[member] = StickyPartitionAssignor._metadata( + topics, assignment[member].partitions() + ) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) verify_validity_and_balance(subscriptions, assignment) @@ -684,110 +817,145 @@ def test_reassignment_with_random_subscriptions_and_changes(mocker, execution_nu def test_assignment_with_multiple_generations1(mocker): - cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1, 2, 3, 4, 5}) + cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1, 2, 3, 4, 5}) member_metadata = { - 'C1': StickyPartitionAssignor._metadata({'t'}, []), - 'C2': StickyPartitionAssignor._metadata({'t'}, []), - 'C3': StickyPartitionAssignor._metadata({'t'}, []), + "C1": StickyPartitionAssignor._metadata({"t"}, []), + "C2": StickyPartitionAssignor._metadata({"t"}, []), + "C3": StickyPartitionAssignor._metadata({"t"}, []), } assignment1 = StickyPartitionAssignor.assign(cluster, member_metadata) - verify_validity_and_balance({'C1': {'t'}, 'C2': {'t'}, 'C3': {'t'}}, assignment1) - assert len(assignment1['C1'].assignment[0][1]) == 2 - assert len(assignment1['C2'].assignment[0][1]) == 2 - assert len(assignment1['C3'].assignment[0][1]) == 2 + verify_validity_and_balance({"C1": {"t"}, "C2": {"t"}, "C3": {"t"}}, assignment1) + assert len(assignment1["C1"].assignment[0][1]) == 2 + assert len(assignment1["C2"].assignment[0][1]) == 2 + assert len(assignment1["C3"].assignment[0][1]) == 2 member_metadata = { - 'C1': StickyPartitionAssignor._metadata({'t'}, assignment1['C1'].partitions()), - 'C2': StickyPartitionAssignor._metadata({'t'}, assignment1['C2'].partitions()), + "C1": StickyPartitionAssignor._metadata({"t"}, assignment1["C1"].partitions()), + "C2": StickyPartitionAssignor._metadata({"t"}, assignment1["C2"].partitions()), } assignment2 = StickyPartitionAssignor.assign(cluster, member_metadata) - verify_validity_and_balance({'C1': {'t'}, 'C2': {'t'}}, assignment2) - assert len(assignment2['C1'].assignment[0][1]) == 3 - assert len(assignment2['C2'].assignment[0][1]) == 3 - assert all([partition in assignment2['C1'].assignment[0][1] for partition in assignment1['C1'].assignment[0][1]]) - assert all([partition in assignment2['C2'].assignment[0][1] for partition in assignment1['C2'].assignment[0][1]]) + 
verify_validity_and_balance({"C1": {"t"}, "C2": {"t"}}, assignment2) + assert len(assignment2["C1"].assignment[0][1]) == 3 + assert len(assignment2["C2"].assignment[0][1]) == 3 + assert all( + [ + partition in assignment2["C1"].assignment[0][1] + for partition in assignment1["C1"].assignment[0][1] + ] + ) + assert all( + [ + partition in assignment2["C2"].assignment[0][1] + for partition in assignment1["C2"].assignment[0][1] + ] + ) assert StickyPartitionAssignor._latest_partition_movements.are_sticky() member_metadata = { - 'C2': StickyPartitionAssignor._metadata({'t'}, assignment2['C2'].partitions(), 2), - 'C3': StickyPartitionAssignor._metadata({'t'}, assignment1['C3'].partitions(), 1), + "C2": StickyPartitionAssignor._metadata( + {"t"}, assignment2["C2"].partitions(), 2 + ), + "C3": StickyPartitionAssignor._metadata( + {"t"}, assignment1["C3"].partitions(), 1 + ), } assignment3 = StickyPartitionAssignor.assign(cluster, member_metadata) - verify_validity_and_balance({'C2': {'t'}, 'C3': {'t'}}, assignment3) - assert len(assignment3['C2'].assignment[0][1]) == 3 - assert len(assignment3['C3'].assignment[0][1]) == 3 + verify_validity_and_balance({"C2": {"t"}, "C3": {"t"}}, assignment3) + assert len(assignment3["C2"].assignment[0][1]) == 3 + assert len(assignment3["C3"].assignment[0][1]) == 3 assert StickyPartitionAssignor._latest_partition_movements.are_sticky() def test_assignment_with_multiple_generations2(mocker): - cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1, 2, 3, 4, 5}) + cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1, 2, 3, 4, 5}) member_metadata = { - 'C1': StickyPartitionAssignor._metadata({'t'}, []), - 'C2': StickyPartitionAssignor._metadata({'t'}, []), - 'C3': StickyPartitionAssignor._metadata({'t'}, []), + "C1": StickyPartitionAssignor._metadata({"t"}, []), + "C2": StickyPartitionAssignor._metadata({"t"}, []), + "C3": StickyPartitionAssignor._metadata({"t"}, []), } assignment1 = StickyPartitionAssignor.assign(cluster, member_metadata) - verify_validity_and_balance({'C1': {'t'}, 'C2': {'t'}, 'C3': {'t'}}, assignment1) - assert len(assignment1['C1'].assignment[0][1]) == 2 - assert len(assignment1['C2'].assignment[0][1]) == 2 - assert len(assignment1['C3'].assignment[0][1]) == 2 + verify_validity_and_balance({"C1": {"t"}, "C2": {"t"}, "C3": {"t"}}, assignment1) + assert len(assignment1["C1"].assignment[0][1]) == 2 + assert len(assignment1["C2"].assignment[0][1]) == 2 + assert len(assignment1["C3"].assignment[0][1]) == 2 member_metadata = { - 'C2': StickyPartitionAssignor._metadata({'t'}, assignment1['C2'].partitions(), 1), + "C2": StickyPartitionAssignor._metadata( + {"t"}, assignment1["C2"].partitions(), 1 + ), } assignment2 = StickyPartitionAssignor.assign(cluster, member_metadata) - verify_validity_and_balance({'C2': {'t'}}, assignment2) - assert len(assignment2['C2'].assignment[0][1]) == 6 - assert all([partition in assignment2['C2'].assignment[0][1] for partition in assignment1['C2'].assignment[0][1]]) + verify_validity_and_balance({"C2": {"t"}}, assignment2) + assert len(assignment2["C2"].assignment[0][1]) == 6 + assert all( + [ + partition in assignment2["C2"].assignment[0][1] + for partition in assignment1["C2"].assignment[0][1] + ] + ) assert StickyPartitionAssignor._latest_partition_movements.are_sticky() member_metadata = { - 'C1': StickyPartitionAssignor._metadata({'t'}, assignment1['C1'].partitions(), 1), - 'C2': StickyPartitionAssignor._metadata({'t'}, assignment2['C2'].partitions(), 2), - 'C3': 
StickyPartitionAssignor._metadata({'t'}, assignment1['C3'].partitions(), 1), + "C1": StickyPartitionAssignor._metadata( + {"t"}, assignment1["C1"].partitions(), 1 + ), + "C2": StickyPartitionAssignor._metadata( + {"t"}, assignment2["C2"].partitions(), 2 + ), + "C3": StickyPartitionAssignor._metadata( + {"t"}, assignment1["C3"].partitions(), 1 + ), } assignment3 = StickyPartitionAssignor.assign(cluster, member_metadata) - verify_validity_and_balance({'C1': {'t'}, 'C2': {'t'}, 'C3': {'t'}}, assignment3) + verify_validity_and_balance({"C1": {"t"}, "C2": {"t"}, "C3": {"t"}}, assignment3) assert StickyPartitionAssignor._latest_partition_movements.are_sticky() - assert set(assignment3['C1'].assignment[0][1]) == set(assignment1['C1'].assignment[0][1]) - assert set(assignment3['C2'].assignment[0][1]) == set(assignment1['C2'].assignment[0][1]) - assert set(assignment3['C3'].assignment[0][1]) == set(assignment1['C3'].assignment[0][1]) + assert set(assignment3["C1"].assignment[0][1]) == set( + assignment1["C1"].assignment[0][1] + ) + assert set(assignment3["C2"].assignment[0][1]) == set( + assignment1["C2"].assignment[0][1] + ) + assert set(assignment3["C3"].assignment[0][1]) == set( + assignment1["C3"].assignment[0][1] + ) -@pytest.mark.parametrize('execution_number', range(50)) +@pytest.mark.parametrize("execution_number", range(50)) def test_assignment_with_conflicting_previous_generations(mocker, execution_number): - cluster = create_cluster(mocker, topics={'t'}, topics_partitions={0, 1, 2, 3, 4, 5}) + cluster = create_cluster(mocker, topics={"t"}, topics_partitions={0, 1, 2, 3, 4, 5}) member_assignments = { - 'C1': [TopicPartition('t', p) for p in {0, 1, 4}], - 'C2': [TopicPartition('t', p) for p in {0, 2, 3}], - 'C3': [TopicPartition('t', p) for p in {3, 4, 5}], + "C1": [TopicPartition("t", p) for p in {0, 1, 4}], + "C2": [TopicPartition("t", p) for p in {0, 2, 3}], + "C3": [TopicPartition("t", p) for p in {3, 4, 5}], } member_generations = { - 'C1': 1, - 'C2': 1, - 'C3': 2, + "C1": 1, + "C2": 1, + "C3": 2, } member_metadata = {} - for member in six.iterkeys(member_assignments): - member_metadata[member] = StickyPartitionAssignor._metadata({'t'}, member_assignments[member], member_generations[member]) + for member in member_assignments.keys(): + member_metadata[member] = StickyPartitionAssignor._metadata( + {"t"}, member_assignments[member], member_generations[member] + ) assignment = StickyPartitionAssignor.assign(cluster, member_metadata) - verify_validity_and_balance({'C1': {'t'}, 'C2': {'t'}, 'C3': {'t'}}, assignment) + verify_validity_and_balance({"C1": {"t"}, "C2": {"t"}, "C3": {"t"}}, assignment) assert StickyPartitionAssignor._latest_partition_movements.are_sticky() def make_member_metadata(subscriptions): member_metadata = {} - for member, topics in six.iteritems(subscriptions): + for member, topics in subscriptions.items(): member_metadata[member] = StickyPartitionAssignor._metadata(topics, []) return member_metadata @@ -796,7 +964,9 @@ def assert_assignment(result_assignment, expected_assignment): assert result_assignment == expected_assignment assert set(result_assignment) == set(expected_assignment) for member in result_assignment: - assert result_assignment[member].encode() == expected_assignment[member].encode() + assert ( + result_assignment[member].encode() == expected_assignment[member].encode() + ) def verify_validity_and_balance(subscriptions, assignment): @@ -806,24 +976,28 @@ def verify_validity_and_balance(subscriptions, assignment): - each consumer is subscribed to 
topics of all partitions assigned to it, and - each partition is assigned to no more than one consumer Balance requirements: - - the assignment is fully balanced (the numbers of topic partitions assigned to consumers differ by at most one), or - - there is no topic partition that can be moved from one consumer to another with 2+ fewer topic partitions + - the assignment is fully balanced (the numbers of topic partitions assigned to + consumers differ by at most one), or + - there is no topic partition that can be moved from one consumer to another with + 2+ fewer topic partitions :param subscriptions topic subscriptions of each consumer :param assignment: given assignment for balance check """ - assert six.viewkeys(subscriptions) == six.viewkeys(assignment) + assert subscriptions.keys() == assignment.keys() - consumers = sorted(six.viewkeys(assignment)) + consumers = sorted(assignment.keys()) for i in range(len(consumers)): consumer = consumers[i] partitions = assignment[consumer].partitions() for partition in partitions: assert partition.topic in subscriptions[consumer], ( - 'Error: Partition {} is assigned to consumer {}, ' - 'but it is not subscribed to topic {}\n' - 'Subscriptions: {}\n' - 'Assignments: {}'.format(partition, consumers[i], partition.topic, subscriptions, assignment) + "Error: Partition {} is assigned to consumer {}, " + "but it is not subscribed to topic {}\n" + "Subscriptions: {}\n" + "Assignments: {}".format( + partition, consumers[i], partition.topic, subscriptions, assignment + ) ) if i == len(consumers) - 1: continue @@ -831,12 +1005,20 @@ def verify_validity_and_balance(subscriptions, assignment): for j in range(i + 1, len(consumers)): other_consumer = consumers[j] other_partitions = assignment[other_consumer].partitions() - partitions_intersection = set(partitions).intersection(set(other_partitions)) + partitions_intersection = set(partitions).intersection( + set(other_partitions) + ) assert partitions_intersection == set(), ( - 'Error: Consumers {} and {} have common partitions ' - 'assigned to them: {}\n' - 'Subscriptions: {}\n' - 'Assignments: {}'.format(consumer, other_consumer, partitions_intersection, subscriptions, assignment) + "Error: Consumers {} and {} have common partitions " + "assigned to them: {}\n" + "Subscriptions: {}\n" + "Assignments: {}".format( + consumer, + other_consumer, + partitions_intersection, + subscriptions, + assignment, + ) ) if abs(len(partitions) - len(other_partitions)) <= 1: @@ -845,22 +1027,36 @@ def verify_validity_and_balance(subscriptions, assignment): assignments_by_topic = group_partitions_by_topic(partitions) other_assignments_by_topic = group_partitions_by_topic(other_partitions) if len(partitions) > len(other_partitions): - for topic in six.iterkeys(assignments_by_topic): + for topic in assignments_by_topic.keys(): assert topic not in other_assignments_by_topic, ( - 'Error: Some partitions can be moved from {} ({} partitions) ' - 'to {} ({} partitions) ' - 'to achieve a better balance\n' - 'Subscriptions: {}\n' - 'Assignments: {}'.format(consumer, len(partitions), other_consumer, len(other_partitions), subscriptions, assignment) + "Error: Some partitions can be moved from {} ({} partitions) " + "to {} ({} partitions) " + "to achieve a better balance\n" + "Subscriptions: {}\n" + "Assignments: {}".format( + consumer, + len(partitions), + other_consumer, + len(other_partitions), + subscriptions, + assignment, + ) ) if len(other_partitions) > len(partitions): - for topic in six.iterkeys(other_assignments_by_topic): + for 
topic in other_assignments_by_topic.keys(): assert topic not in assignments_by_topic, ( - 'Error: Some partitions can be moved from {} ({} partitions) ' - 'to {} ({} partitions) ' - 'to achieve a better balance\n' - 'Subscriptions: {}\n' - 'Assignments: {}'.format(other_consumer, len(other_partitions), consumer, len(partitions), subscriptions, assignment) + "Error: Some partitions can be moved from {} ({} partitions) " + "to {} ({} partitions) " + "to achieve a better balance\n" + "Subscriptions: {}\n" + "Assignments: {}".format( + other_consumer, + len(other_partitions), + consumer, + len(partitions), + subscriptions, + assignment, + ) ) diff --git a/tests/coordinator/test_partition_movements.py b/tests/coordinator/test_partition_movements.py new file mode 100644 index 00000000..d5da876b --- /dev/null +++ b/tests/coordinator/test_partition_movements.py @@ -0,0 +1,23 @@ +from kafka.structs import TopicPartition + +from aiokafka.coordinator.assignors.sticky.partition_movements import PartitionMovements + + +def test_empty_movements_are_sticky(): + partition_movements = PartitionMovements() + assert partition_movements.are_sticky() + + +def test_sticky_movements(): + partition_movements = PartitionMovements() + partition_movements.move_partition(TopicPartition("t", 1), "C1", "C2") + partition_movements.move_partition(TopicPartition("t", 1), "C2", "C3") + partition_movements.move_partition(TopicPartition("t", 1), "C3", "C1") + assert partition_movements.are_sticky() + + +def test_should_detect_non_sticky_assignment(): + partition_movements = PartitionMovements() + partition_movements.move_partition(TopicPartition("t", 1), "C1", "C2") + partition_movements.move_partition(TopicPartition("t", 2), "C2", "C1") + assert not partition_movements.are_sticky() diff --git a/tests/kafka/test_partition_movements.py b/tests/kafka/test_partition_movements.py deleted file mode 100644 index bc990bf3..00000000 --- a/tests/kafka/test_partition_movements.py +++ /dev/null @@ -1,23 +0,0 @@ -from kafka.structs import TopicPartition - -from kafka.coordinator.assignors.sticky.partition_movements import PartitionMovements - - -def test_empty_movements_are_sticky(): - partition_movements = PartitionMovements() - assert partition_movements.are_sticky() - - -def test_sticky_movements(): - partition_movements = PartitionMovements() - partition_movements.move_partition(TopicPartition('t', 1), 'C1', 'C2') - partition_movements.move_partition(TopicPartition('t', 1), 'C2', 'C3') - partition_movements.move_partition(TopicPartition('t', 1), 'C3', 'C1') - assert partition_movements.are_sticky() - - -def test_should_detect_non_sticky_assignment(): - partition_movements = PartitionMovements() - partition_movements.move_partition(TopicPartition('t', 1), 'C1', 'C2') - partition_movements.move_partition(TopicPartition('t', 2), 'C2', 'C1') - assert not partition_movements.are_sticky() From f311fb0126e6143ac682fd9cbe4ca442c2e49c3f Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Sun, 22 Oct 2023 18:52:30 +0300 Subject: [PATCH 11/20] Move partitioner --- .../default.py => aiokafka/partitioner.py | 64 +++++++++---------- aiokafka/producer/producer.py | 2 +- kafka/partitioner/__init__.py | 8 --- tests/{kafka => }/test_partitioner.py | 27 ++++---- 4 files changed, 46 insertions(+), 55 deletions(-) rename kafka/partitioner/default.py => aiokafka/partitioner.py (63%) delete mode 100644 kafka/partitioner/__init__.py rename tests/{kafka => }/test_partitioner.py (67%) diff --git a/kafka/partitioner/default.py b/aiokafka/partitioner.py 
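Condensed into one runnable snippet, this is the behaviour that the new tests/coordinator/test_partition_movements.py above pins down; assumed usage only, mirroring the test calls.

from kafka.structs import TopicPartition

from aiokafka.coordinator.assignors.sticky.partition_movements import PartitionMovements

# Moving the same partition around a closed cycle of consumers keeps the
# reassignment sticky...
pm = PartitionMovements()
pm.move_partition(TopicPartition("t", 1), "C1", "C2")
pm.move_partition(TopicPartition("t", 1), "C2", "C3")
pm.move_partition(TopicPartition("t", 1), "C3", "C1")
assert pm.are_sticky()

# ...whereas swapping two different partitions between the same pair of
# consumers is detected as non-sticky.
pm = PartitionMovements()
pm.move_partition(TopicPartition("t", 1), "C1", "C2")
pm.move_partition(TopicPartition("t", 2), "C2", "C1")
assert not pm.are_sticky()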
similarity index 63% rename from kafka/partitioner/default.py rename to aiokafka/partitioner.py index d0914c68..dabc3def 100644 --- a/kafka/partitioner/default.py +++ b/aiokafka/partitioner.py @@ -1,9 +1,5 @@ -from __future__ import absolute_import - import random -from kafka.vendor import six - class DefaultPartitioner(object): """Default partitioner. @@ -12,6 +8,7 @@ class DefaultPartitioner(object): If key is None, selects partition randomly from available, or from all partitions if none are currently available """ + @classmethod def __call__(cls, key, all_partitions, available): """ @@ -27,7 +24,7 @@ def __call__(cls, key, all_partitions, available): return random.choice(all_partitions) idx = murmur2(key) - idx &= 0x7fffffff + idx &= 0x7FFFFFFF idx %= len(all_partitions) return all_partitions[idx] @@ -43,16 +40,11 @@ def murmur2(data): Returns: MurmurHash2 of data """ - # Python2 bytes is really a str, causing the bitwise operations below to fail - # so convert to bytearray. - if six.PY2: - data = bytearray(bytes(data)) - length = len(data) - seed = 0x9747b28c + seed = 0x9747B28C # 'm' and 'r' are mixing constants generated offline. # They're not really 'magic', they just happen to work well. - m = 0x5bd1e995 + m = 0x5BD1E995 r = 24 # Initialize the hash to a random value @@ -61,42 +53,44 @@ def murmur2(data): for i in range(length4): i4 = i * 4 - k = ((data[i4 + 0] & 0xff) + - ((data[i4 + 1] & 0xff) << 8) + - ((data[i4 + 2] & 0xff) << 16) + - ((data[i4 + 3] & 0xff) << 24)) - k &= 0xffffffff + k = ( + (data[i4 + 0] & 0xFF) + + ((data[i4 + 1] & 0xFF) << 8) + + ((data[i4 + 2] & 0xFF) << 16) + + ((data[i4 + 3] & 0xFF) << 24) + ) + k &= 0xFFFFFFFF k *= m - k &= 0xffffffff - k ^= (k % 0x100000000) >> r # k ^= k >>> r - k &= 0xffffffff + k &= 0xFFFFFFFF + k ^= (k % 0x100000000) >> r # k ^= k >>> r + k &= 0xFFFFFFFF k *= m - k &= 0xffffffff + k &= 0xFFFFFFFF h *= m - h &= 0xffffffff + h &= 0xFFFFFFFF h ^= k - h &= 0xffffffff + h &= 0xFFFFFFFF # Handle the last few bytes of the input array extra_bytes = length % 4 if extra_bytes >= 3: - h ^= (data[(length & ~3) + 2] & 0xff) << 16 - h &= 0xffffffff + h ^= (data[(length & ~3) + 2] & 0xFF) << 16 + h &= 0xFFFFFFFF if extra_bytes >= 2: - h ^= (data[(length & ~3) + 1] & 0xff) << 8 - h &= 0xffffffff + h ^= (data[(length & ~3) + 1] & 0xFF) << 8 + h &= 0xFFFFFFFF if extra_bytes >= 1: - h ^= (data[length & ~3] & 0xff) - h &= 0xffffffff + h ^= data[length & ~3] & 0xFF + h &= 0xFFFFFFFF h *= m - h &= 0xffffffff + h &= 0xFFFFFFFF - h ^= (h % 0x100000000) >> 13 # h >>> 13; - h &= 0xffffffff + h ^= (h % 0x100000000) >> 13 # h >>> 13; + h &= 0xFFFFFFFF h *= m - h &= 0xffffffff - h ^= (h % 0x100000000) >> 15 # h >>> 15; - h &= 0xffffffff + h &= 0xFFFFFFFF + h ^= (h % 0x100000000) >> 15 # h >>> 15; + h &= 0xFFFFFFFF return h diff --git a/aiokafka/producer/producer.py b/aiokafka/producer/producer.py index 7bfd1089..c8d763b6 100644 --- a/aiokafka/producer/producer.py +++ b/aiokafka/producer/producer.py @@ -4,12 +4,12 @@ import traceback import warnings -from kafka.partitioner.default import DefaultPartitioner from kafka.codec import has_gzip, has_snappy, has_lz4, has_zstd from aiokafka.client import AIOKafkaClient from aiokafka.errors import ( MessageSizeTooLargeError, UnsupportedVersionError, IllegalOperation) +from aiokafka.partitioner import DefaultPartitioner from aiokafka.record.default_records import DefaultRecordBatch from aiokafka.record.legacy_records import LegacyRecordBatchBuilder from aiokafka.structs import TopicPartition diff --git 
a/kafka/partitioner/__init__.py b/kafka/partitioner/__init__.py deleted file mode 100644 index 21a3bbb6..00000000 --- a/kafka/partitioner/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from __future__ import absolute_import - -from kafka.partitioner.default import DefaultPartitioner, murmur2 - - -__all__ = [ - 'DefaultPartitioner', 'murmur2' -] diff --git a/tests/kafka/test_partitioner.py b/tests/test_partitioner.py similarity index 67% rename from tests/kafka/test_partitioner.py rename to tests/test_partitioner.py index 853fbf69..af0a7cb2 100644 --- a/tests/kafka/test_partitioner.py +++ b/tests/test_partitioner.py @@ -1,16 +1,14 @@ -from __future__ import absolute_import - import pytest -from kafka.partitioner import DefaultPartitioner, murmur2 +from aiokafka.partitioner import DefaultPartitioner, murmur2 def test_default_partitioner(): partitioner = DefaultPartitioner() all_partitions = available = list(range(100)) # partitioner should return the same partition for the same key - p1 = partitioner(b'foo', all_partitions, available) - p2 = partitioner(b'foo', all_partitions, available) + p1 = partitioner(b"foo", all_partitions, available) + p2 = partitioner(b"foo", all_partitions, available) assert p1 == p2 assert p1 in all_partitions @@ -21,10 +19,17 @@ def test_default_partitioner(): assert partitioner(None, all_partitions, []) in all_partitions -@pytest.mark.parametrize("bytes_payload,partition_number", [ - (b'', 681), (b'a', 524), (b'ab', 434), (b'abc', 107), (b'123456789', 566), - (b'\x00 ', 742) -]) +@pytest.mark.parametrize( + "bytes_payload,partition_number", + [ + (b"", 681), + (b"a", 524), + (b"ab", 434), + (b"abc", 107), + (b"123456789", 566), + (b"\x00 ", 742), + ], +) def test_murmur2_java_compatibility(bytes_payload, partition_number): partitioner = DefaultPartitioner() all_partitions = available = list(range(1000)) @@ -34,5 +39,5 @@ def test_murmur2_java_compatibility(bytes_payload, partition_number): def test_murmur2_not_ascii(): # Verify no regression of murmur2() bug encoding py2 bytes that don't ascii encode - murmur2(b'\xa4') - murmur2(b'\x81' * 1000) + murmur2(b"\xa4") + murmur2(b"\x81" * 1000) From 81e5bf23d2de526559f9a2ad1bf5f636b486f8cc Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Sun, 22 Oct 2023 21:01:46 +0300 Subject: [PATCH 12/20] Merge cluster --- aiokafka/cluster.py | 326 +++++- kafka/__init__.py | 1 - kafka/cluster.py | 397 -------- kafka/conn.py | 1534 ----------------------------- tests/kafka/conftest.py | 27 - tests/kafka/test_conn.py | 342 ------- tests/{kafka => }/test_cluster.py | 8 +- tests/test_message_accumulator.py | 6 +- tests/test_producer.py | 13 +- 9 files changed, 320 insertions(+), 2334 deletions(-) delete mode 100644 kafka/cluster.py delete mode 100644 kafka/conn.py delete mode 100644 tests/kafka/test_conn.py rename tests/{kafka => }/test_cluster.py (81%) diff --git a/aiokafka/cluster.py b/aiokafka/cluster.py index ccb38eb8..fd565422 100644 --- a/aiokafka/cluster.py +++ b/aiokafka/cluster.py @@ -1,35 +1,230 @@ import collections +import copy import logging +import threading import time -from kafka.cluster import ClusterMetadata as BaseClusterMetadata -from aiokafka.structs import BrokerMetadata, PartitionMetadata, TopicPartition +from kafka.future import Future + from aiokafka import errors as Errors +from aiokafka.conn import collect_hosts +from aiokafka.structs import BrokerMetadata, PartitionMetadata, TopicPartition log = logging.getLogger(__name__) -class ClusterMetadata(BaseClusterMetadata): +class ClusterMetadata: + """ + A class to 
manage kafka cluster metadata. + + This class does not perform any IO. It simply updates internal state + given API responses (MetadataResponse, GroupCoordinatorResponse). + + Keyword Arguments: + retry_backoff_ms (int): Milliseconds to backoff when retrying on + errors. Default: 100. + metadata_max_age_ms (int): The period of time in milliseconds after + which we force a refresh of metadata even if we haven't seen any + partition leadership changes to proactively discover any new + brokers or partitions. Default: 300000 + bootstrap_servers: 'host[:port]' string (or list of 'host[:port]' + strings) that the client should contact to bootstrap initial + cluster metadata. This does not have to be the full node list. + It just needs to have at least one broker that will respond to a + Metadata API Request. Default port is 9092. If no servers are + specified, will default to localhost:9092. + """ + DEFAULT_CONFIG = { + 'retry_backoff_ms': 100, + 'metadata_max_age_ms': 300000, + 'bootstrap_servers': [], + } + + def __init__(self, **configs): + self._brokers = {} # node_id -> BrokerMetadata + self._partitions = {} # topic -> partition -> PartitionMetadata + # node_id -> {TopicPartition...} + self._broker_partitions = collections.defaultdict(set) + self._groups = {} # group_name -> node_id + self._last_refresh_ms = 0 + self._last_successful_refresh_ms = 0 + self._need_update = True + self._future = None + self._listeners = set() + self._lock = threading.Lock() + self.need_all_topic_metadata = False + self.unauthorized_topics = set() + self.internal_topics = set() + self.controller = None + + self.config = copy.copy(self.DEFAULT_CONFIG) + for key in self.config: + if key in configs: + self.config[key] = configs[key] - def __init__(self, *args, **kw): - super().__init__(*args, **kw) + self._bootstrap_brokers = self._generate_bootstrap_brokers() + self._coordinator_brokers = {} self._coordinators = {} self._coordinator_by_key = {} - def coordinator_metadata(self, node_id): - return self._coordinators.get(node_id) + def _generate_bootstrap_brokers(self): + # collect_hosts does not perform DNS, so we should be fine to re-use + bootstrap_hosts = collect_hosts(self.config['bootstrap_servers']) - def add_coordinator(self, node_id, host, port, rack=None, *, purpose): - """ Keep track of all coordinator nodes separately and remove them if - a new one was elected for the same purpose (For example group - coordinator for group X). 
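As an orientation aid, a minimal usage sketch of the merged ClusterMetadata follows; it assumes the aiokafka.cluster module layout introduced by this patch and the BrokerMetadata field names from aiokafka.structs, and is illustrative rather than part of the change.

    # Illustrative sketch (assumes aiokafka.cluster and aiokafka.structs as merged here).
    from aiokafka.cluster import ClusterMetadata

    cluster = ClusterMetadata(bootstrap_servers=["broker1:9092", "broker2:9092"])

    # Before the first MetadataResponse, brokers() falls back to synthetic
    # bootstrap entries whose node ids are "bootstrap-<index>".
    print(sorted(b.nodeId for b in cluster.brokers()))  # ['bootstrap-0', 'bootstrap-1']
    print(cluster.is_bootstrap("bootstrap-0"))          # True
    print(cluster.ttl())                                # 0: a metadata refresh is due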
+ brokers = {} + for i, (host, port, _) in enumerate(bootstrap_hosts): + node_id = 'bootstrap-%s' % i + brokers[node_id] = BrokerMetadata(node_id, host, port, None) + return brokers + + def is_bootstrap(self, node_id): + return node_id in self._bootstrap_brokers + + def brokers(self): + """Get all BrokerMetadata + + Returns: + set: {BrokerMetadata, ...} """ - if purpose in self._coordinator_by_key: - old_id = self._coordinator_by_key.pop(purpose) - del self._coordinators[old_id] + return set(self._brokers.values()) or set(self._bootstrap_brokers.values()) - self._coordinators[node_id] = BrokerMetadata(node_id, host, port, rack) - self._coordinator_by_key[purpose] = node_id + def broker_metadata(self, broker_id): + """Get BrokerMetadata + + Arguments: + broker_id (int): node_id for a broker to check + + Returns: + BrokerMetadata or None if not found + """ + return ( + self._brokers.get(broker_id) or + self._bootstrap_brokers.get(broker_id) or + self._coordinator_brokers.get(broker_id) + ) + + def partitions_for_topic(self, topic): + """Return set of all partitions for topic (whether available or not) + + Arguments: + topic (str): topic to check for partitions + + Returns: + set: {partition (int), ...} + """ + if topic not in self._partitions: + return None + return set(self._partitions[topic].keys()) + + def available_partitions_for_topic(self, topic): + """Return set of partitions with known leaders + + Arguments: + topic (str): topic to check for partitions + + Returns: + set: {partition (int), ...} + None if topic not found. + """ + if topic not in self._partitions: + return None + return set([partition for partition, metadata + in self._partitions[topic].items() + if metadata.leader != -1]) + + def leader_for_partition(self, partition): + """Return node_id of leader, -1 unavailable, None if unknown.""" + if partition.topic not in self._partitions: + return None + elif partition.partition not in self._partitions[partition.topic]: + return None + return self._partitions[partition.topic][partition.partition].leader + + def partitions_for_broker(self, broker_id): + """Return TopicPartitions for which the broker is a leader. + + Arguments: + broker_id (int): node id for a broker + + Returns: + set: {TopicPartition, ...} + None if the broker either has no partitions or does not exist. + """ + return self._broker_partitions.get(broker_id) + + def coordinator_for_group(self, group): + """Return node_id of group coordinator. + + Arguments: + group (str): name of consumer group + + Returns: + int: node_id for group coordinator + None if the group does not exist. + """ + return self._groups.get(group) + + def ttl(self): + """Milliseconds until metadata should be refreshed""" + now = time.time() * 1000 + if self._need_update: + ttl = 0 + else: + metadata_age = now - self._last_successful_refresh_ms + ttl = self.config['metadata_max_age_ms'] - metadata_age + + retry_age = now - self._last_refresh_ms + next_retry = self.config['retry_backoff_ms'] - retry_age + + return max(ttl, next_retry, 0) + + def refresh_backoff(self): + """Return milliseconds to wait before attempting to retry after failure""" + return self.config['retry_backoff_ms'] + + def request_update(self): + """Flags metadata for update, return Future() + + Actual update must be handled separately. 
This method will only + change the reported ttl() + + Returns: + kafka.future.Future (value will be the cluster object after update) + """ + with self._lock: + self._need_update = True + if not self._future or self._future.is_done: + self._future = Future() + return self._future + + def topics(self, exclude_internal_topics=True): + """Get set of known topics. + + Arguments: + exclude_internal_topics (bool): Whether records from internal topics + (such as offsets) should be exposed to the consumer. If set to + True the only way to receive records from an internal topic is + subscribing to it. Default True + + Returns: + set: {topic (str), ...} + """ + topics = set(self._partitions.keys()) + if exclude_internal_topics: + return topics - self.internal_topics + else: + return topics + + def failed_update(self, exception): + """Update cluster state given a failed MetadataRequest.""" + f = None + with self._lock: + if self._future: + f = self._future + self._future = None + if f: + f.failure(exception) + self._last_refresh_ms = time.time() * 1000 def update_metadata(self, metadata): """Update cluster state given a MetadataResponse. @@ -39,9 +234,9 @@ def update_metadata(self, metadata): Returns: None """ - if not metadata.brokers: - log.warning("No broker metadata found in MetadataResponse") + log.warning("No broker metadata found in MetadataResponse -- ignoring.") + return self.failed_update(Errors.MetadataEmptyBrokerList(metadata)) _new_brokers = {} for broker in metadata.brokers: @@ -83,6 +278,10 @@ def update_metadata(self, metadata): _new_broker_partitions[leader].add( TopicPartition(topic, partition)) + # Specific topic errors can be ignored if this is a full metadata fetch + elif self.need_all_topic_metadata: + continue + elif error_type is Errors.LeaderNotAvailableError: log.warning("Topic %s is not available during auto-create" " initialization", topic) @@ -104,12 +303,103 @@ def update_metadata(self, metadata): self._broker_partitions = _new_broker_partitions self.unauthorized_topics = _new_unauthorized_topics self.internal_topics = _new_internal_topics + f = None + if self._future: + f = self._future + self._future = None + self._need_update = False now = time.time() * 1000 self._last_refresh_ms = now self._last_successful_refresh_ms = now + if f: + f.success(self) log.debug("Updated cluster metadata to %s", self) for listener in self._listeners: listener(self) + + if self.need_all_topic_metadata: + # the listener may change the interested topics, + # which could cause another metadata refresh. + # If we have already fetched all topics, however, + # another fetch should be unnecessary. 
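A hedged sketch of the request_update()/failed_update() contract implemented above (module paths assumed from this patch; the returned future is kafka.future.Future):

    # Sketch of the metadata-update future flow; the error below is illustrative.
    from aiokafka.cluster import ClusterMetadata
    from aiokafka.errors import KafkaError

    cluster = ClusterMetadata(bootstrap_servers=["localhost:9092"])
    fut = cluster.request_update()      # marks metadata stale, so ttl() reports 0
    assert cluster.ttl() == 0

    # A failed MetadataRequest fails the pending future with the given exception.
    cluster.failed_update(KafkaError("illustrative failure"))
    assert fut.failed()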
+ self._need_update = False + + def add_listener(self, listener): + """Add a callback function to be called on each metadata update""" + self._listeners.add(listener) + + def remove_listener(self, listener): + """Remove a previously added listener callback""" + self._listeners.remove(listener) + + def add_group_coordinator(self, group, response): + """Update with metadata for a group coordinator + + Arguments: + group (str): name of group from GroupCoordinatorRequest + response (GroupCoordinatorResponse): broker response + + Returns: + string: coordinator node_id if metadata is updated, None on error + """ + log.debug("Updating coordinator for %s: %s", group, response) + error_type = Errors.for_code(response.error_code) + if error_type is not Errors.NoError: + log.error("GroupCoordinatorResponse error: %s", error_type) + self._groups[group] = -1 + return + + # Use a coordinator-specific node id so that group requests + # get a dedicated connection + node_id = 'coordinator-{}'.format(response.coordinator_id) + coordinator = BrokerMetadata( + node_id, + response.host, + response.port, + None) + + log.info("Group coordinator for %s is %s", group, coordinator) + self._coordinator_brokers[node_id] = coordinator + self._groups[group] = node_id + return node_id + + def with_partitions(self, partitions_to_add): + """Returns a copy of cluster metadata with partitions added""" + new_metadata = ClusterMetadata(**self.config) + new_metadata._brokers = copy.deepcopy(self._brokers) + new_metadata._partitions = copy.deepcopy(self._partitions) + new_metadata._broker_partitions = copy.deepcopy(self._broker_partitions) + new_metadata._groups = copy.deepcopy(self._groups) + new_metadata.internal_topics = copy.deepcopy(self.internal_topics) + new_metadata.unauthorized_topics = copy.deepcopy(self.unauthorized_topics) + + for partition in partitions_to_add: + new_metadata._partitions[partition.topic][partition.partition] = partition + + if partition.leader is not None and partition.leader != -1: + new_metadata._broker_partitions[partition.leader].add( + TopicPartition(partition.topic, partition.partition)) + + return new_metadata + + def coordinator_metadata(self, node_id): + return self._coordinators.get(node_id) + + def add_coordinator(self, node_id, host, port, rack=None, *, purpose): + """ Keep track of all coordinator nodes separately and remove them if + a new one was elected for the same purpose (For example group + coordinator for group X). 
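A small self-contained sketch of the aiokafka-specific coordinator bookkeeping added above; the node ids, addresses, and the purpose key are hypothetical (any hashable purpose works, since it is only used as a dictionary key):

    # Illustrative use of add_coordinator()/coordinator_metadata().
    from aiokafka.cluster import ClusterMetadata

    cluster = ClusterMetadata()
    cluster.add_coordinator("coordinator-0", "10.0.0.5", 9092,
                            purpose=("group", "my_group"))
    assert cluster.coordinator_metadata("coordinator-0") is not None

    # A newly elected coordinator for the same purpose replaces the old entry.
    cluster.add_coordinator("coordinator-1", "10.0.0.6", 9092,
                            purpose=("group", "my_group"))
    assert cluster.coordinator_metadata("coordinator-0") is None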
+ """ + if purpose in self._coordinator_by_key: + old_id = self._coordinator_by_key.pop(purpose) + del self._coordinators[old_id] + + self._coordinators[node_id] = BrokerMetadata(node_id, host, port, rack) + self._coordinator_by_key[purpose] = node_id + + def __str__(self): + return 'ClusterMetadata(brokers: %d, topics: %d, groups: %d)' % \ + (len(self._brokers), len(self._partitions), len(self._groups)) diff --git a/kafka/__init__.py b/kafka/__init__.py index 976287b2..65ef3074 100644 --- a/kafka/__init__.py +++ b/kafka/__init__.py @@ -18,7 +18,6 @@ def emit(self, record): logging.getLogger(__name__).addHandler(NullHandler()) -from kafka.conn import BrokerConnection from kafka.serializer import Serializer, Deserializer from kafka.structs import TopicPartition, OffsetAndMetadata diff --git a/kafka/cluster.py b/kafka/cluster.py deleted file mode 100644 index f6d5e510..00000000 --- a/kafka/cluster.py +++ /dev/null @@ -1,397 +0,0 @@ -from __future__ import absolute_import - -import collections -import copy -import logging -import threading -import time - -from kafka.vendor import six - -from aiokafka import errors as Errors -from kafka.conn import collect_hosts -from kafka.future import Future -from kafka.structs import BrokerMetadata, PartitionMetadata, TopicPartition - -log = logging.getLogger(__name__) - - -class ClusterMetadata(object): - """ - A class to manage kafka cluster metadata. - - This class does not perform any IO. It simply updates internal state - given API responses (MetadataResponse, GroupCoordinatorResponse). - - Keyword Arguments: - retry_backoff_ms (int): Milliseconds to backoff when retrying on - errors. Default: 100. - metadata_max_age_ms (int): The period of time in milliseconds after - which we force a refresh of metadata even if we haven't seen any - partition leadership changes to proactively discover any new - brokers or partitions. Default: 300000 - bootstrap_servers: 'host[:port]' string (or list of 'host[:port]' - strings) that the client should contact to bootstrap initial - cluster metadata. This does not have to be the full node list. - It just needs to have at least one broker that will respond to a - Metadata API Request. Default port is 9092. If no servers are - specified, will default to localhost:9092. 
- """ - DEFAULT_CONFIG = { - 'retry_backoff_ms': 100, - 'metadata_max_age_ms': 300000, - 'bootstrap_servers': [], - } - - def __init__(self, **configs): - self._brokers = {} # node_id -> BrokerMetadata - self._partitions = {} # topic -> partition -> PartitionMetadata - self._broker_partitions = collections.defaultdict(set) # node_id -> {TopicPartition...} - self._groups = {} # group_name -> node_id - self._last_refresh_ms = 0 - self._last_successful_refresh_ms = 0 - self._need_update = True - self._future = None - self._listeners = set() - self._lock = threading.Lock() - self.need_all_topic_metadata = False - self.unauthorized_topics = set() - self.internal_topics = set() - self.controller = None - - self.config = copy.copy(self.DEFAULT_CONFIG) - for key in self.config: - if key in configs: - self.config[key] = configs[key] - - self._bootstrap_brokers = self._generate_bootstrap_brokers() - self._coordinator_brokers = {} - - def _generate_bootstrap_brokers(self): - # collect_hosts does not perform DNS, so we should be fine to re-use - bootstrap_hosts = collect_hosts(self.config['bootstrap_servers']) - - brokers = {} - for i, (host, port, _) in enumerate(bootstrap_hosts): - node_id = 'bootstrap-%s' % i - brokers[node_id] = BrokerMetadata(node_id, host, port, None) - return brokers - - def is_bootstrap(self, node_id): - return node_id in self._bootstrap_brokers - - def brokers(self): - """Get all BrokerMetadata - - Returns: - set: {BrokerMetadata, ...} - """ - return set(self._brokers.values()) or set(self._bootstrap_brokers.values()) - - def broker_metadata(self, broker_id): - """Get BrokerMetadata - - Arguments: - broker_id (int): node_id for a broker to check - - Returns: - BrokerMetadata or None if not found - """ - return ( - self._brokers.get(broker_id) or - self._bootstrap_brokers.get(broker_id) or - self._coordinator_brokers.get(broker_id) - ) - - def partitions_for_topic(self, topic): - """Return set of all partitions for topic (whether available or not) - - Arguments: - topic (str): topic to check for partitions - - Returns: - set: {partition (int), ...} - """ - if topic not in self._partitions: - return None - return set(self._partitions[topic].keys()) - - def available_partitions_for_topic(self, topic): - """Return set of partitions with known leaders - - Arguments: - topic (str): topic to check for partitions - - Returns: - set: {partition (int), ...} - None if topic not found. - """ - if topic not in self._partitions: - return None - return set([partition for partition, metadata - in six.iteritems(self._partitions[topic]) - if metadata.leader != -1]) - - def leader_for_partition(self, partition): - """Return node_id of leader, -1 unavailable, None if unknown.""" - if partition.topic not in self._partitions: - return None - elif partition.partition not in self._partitions[partition.topic]: - return None - return self._partitions[partition.topic][partition.partition].leader - - def partitions_for_broker(self, broker_id): - """Return TopicPartitions for which the broker is a leader. - - Arguments: - broker_id (int): node id for a broker - - Returns: - set: {TopicPartition, ...} - None if the broker either has no partitions or does not exist. - """ - return self._broker_partitions.get(broker_id) - - def coordinator_for_group(self, group): - """Return node_id of group coordinator. - - Arguments: - group (str): name of consumer group - - Returns: - int: node_id for group coordinator - None if the group does not exist. 
- """ - return self._groups.get(group) - - def ttl(self): - """Milliseconds until metadata should be refreshed""" - now = time.time() * 1000 - if self._need_update: - ttl = 0 - else: - metadata_age = now - self._last_successful_refresh_ms - ttl = self.config['metadata_max_age_ms'] - metadata_age - - retry_age = now - self._last_refresh_ms - next_retry = self.config['retry_backoff_ms'] - retry_age - - return max(ttl, next_retry, 0) - - def refresh_backoff(self): - """Return milliseconds to wait before attempting to retry after failure""" - return self.config['retry_backoff_ms'] - - def request_update(self): - """Flags metadata for update, return Future() - - Actual update must be handled separately. This method will only - change the reported ttl() - - Returns: - kafka.future.Future (value will be the cluster object after update) - """ - with self._lock: - self._need_update = True - if not self._future or self._future.is_done: - self._future = Future() - return self._future - - def topics(self, exclude_internal_topics=True): - """Get set of known topics. - - Arguments: - exclude_internal_topics (bool): Whether records from internal topics - (such as offsets) should be exposed to the consumer. If set to - True the only way to receive records from an internal topic is - subscribing to it. Default True - - Returns: - set: {topic (str), ...} - """ - topics = set(self._partitions.keys()) - if exclude_internal_topics: - return topics - self.internal_topics - else: - return topics - - def failed_update(self, exception): - """Update cluster state given a failed MetadataRequest.""" - f = None - with self._lock: - if self._future: - f = self._future - self._future = None - if f: - f.failure(exception) - self._last_refresh_ms = time.time() * 1000 - - def update_metadata(self, metadata): - """Update cluster state given a MetadataResponse. 
- - Arguments: - metadata (MetadataResponse): broker response to a metadata request - - Returns: None - """ - # In the common case where we ask for a single topic and get back an - # error, we should fail the future - if len(metadata.topics) == 1 and metadata.topics[0][0] != 0: - error_code, topic = metadata.topics[0][:2] - error = Errors.for_code(error_code)(topic) - return self.failed_update(error) - - if not metadata.brokers: - log.warning("No broker metadata found in MetadataResponse -- ignoring.") - return self.failed_update(Errors.MetadataEmptyBrokerList(metadata)) - - _new_brokers = {} - for broker in metadata.brokers: - if metadata.API_VERSION == 0: - node_id, host, port = broker - rack = None - else: - node_id, host, port, rack = broker - _new_brokers.update({ - node_id: BrokerMetadata(node_id, host, port, rack) - }) - - if metadata.API_VERSION == 0: - _new_controller = None - else: - _new_controller = _new_brokers.get(metadata.controller_id) - - _new_partitions = {} - _new_broker_partitions = collections.defaultdict(set) - _new_unauthorized_topics = set() - _new_internal_topics = set() - - for topic_data in metadata.topics: - if metadata.API_VERSION == 0: - error_code, topic, partitions = topic_data - is_internal = False - else: - error_code, topic, is_internal, partitions = topic_data - if is_internal: - _new_internal_topics.add(topic) - error_type = Errors.for_code(error_code) - if error_type is Errors.NoError: - _new_partitions[topic] = {} - for p_error, partition, leader, replicas, isr in partitions: - _new_partitions[topic][partition] = PartitionMetadata( - topic=topic, partition=partition, leader=leader, - replicas=replicas, isr=isr, error=p_error) - if leader != -1: - _new_broker_partitions[leader].add( - TopicPartition(topic, partition)) - - # Specific topic errors can be ignored if this is a full metadata fetch - elif self.need_all_topic_metadata: - continue - - elif error_type is Errors.LeaderNotAvailableError: - log.warning("Topic %s is not available during auto-create" - " initialization", topic) - elif error_type is Errors.UnknownTopicOrPartitionError: - log.error("Topic %s not found in cluster metadata", topic) - elif error_type is Errors.TopicAuthorizationFailedError: - log.error("Topic %s is not authorized for this client", topic) - _new_unauthorized_topics.add(topic) - elif error_type is Errors.InvalidTopicError: - log.error("'%s' is not a valid topic name", topic) - else: - log.error("Error fetching metadata for topic %s: %s", - topic, error_type) - - with self._lock: - self._brokers = _new_brokers - self.controller = _new_controller - self._partitions = _new_partitions - self._broker_partitions = _new_broker_partitions - self.unauthorized_topics = _new_unauthorized_topics - self.internal_topics = _new_internal_topics - f = None - if self._future: - f = self._future - self._future = None - self._need_update = False - - now = time.time() * 1000 - self._last_refresh_ms = now - self._last_successful_refresh_ms = now - - if f: - f.success(self) - log.debug("Updated cluster metadata to %s", self) - - for listener in self._listeners: - listener(self) - - if self.need_all_topic_metadata: - # the listener may change the interested topics, - # which could cause another metadata refresh. - # If we have already fetched all topics, however, - # another fetch should be unnecessary. 
- self._need_update = False - - def add_listener(self, listener): - """Add a callback function to be called on each metadata update""" - self._listeners.add(listener) - - def remove_listener(self, listener): - """Remove a previously added listener callback""" - self._listeners.remove(listener) - - def add_group_coordinator(self, group, response): - """Update with metadata for a group coordinator - - Arguments: - group (str): name of group from GroupCoordinatorRequest - response (GroupCoordinatorResponse): broker response - - Returns: - string: coordinator node_id if metadata is updated, None on error - """ - log.debug("Updating coordinator for %s: %s", group, response) - error_type = Errors.for_code(response.error_code) - if error_type is not Errors.NoError: - log.error("GroupCoordinatorResponse error: %s", error_type) - self._groups[group] = -1 - return - - # Use a coordinator-specific node id so that group requests - # get a dedicated connection - node_id = 'coordinator-{}'.format(response.coordinator_id) - coordinator = BrokerMetadata( - node_id, - response.host, - response.port, - None) - - log.info("Group coordinator for %s is %s", group, coordinator) - self._coordinator_brokers[node_id] = coordinator - self._groups[group] = node_id - return node_id - - def with_partitions(self, partitions_to_add): - """Returns a copy of cluster metadata with partitions added""" - new_metadata = ClusterMetadata(**self.config) - new_metadata._brokers = copy.deepcopy(self._brokers) - new_metadata._partitions = copy.deepcopy(self._partitions) - new_metadata._broker_partitions = copy.deepcopy(self._broker_partitions) - new_metadata._groups = copy.deepcopy(self._groups) - new_metadata.internal_topics = copy.deepcopy(self.internal_topics) - new_metadata.unauthorized_topics = copy.deepcopy(self.unauthorized_topics) - - for partition in partitions_to_add: - new_metadata._partitions[partition.topic][partition.partition] = partition - - if partition.leader is not None and partition.leader != -1: - new_metadata._broker_partitions[partition.leader].add( - TopicPartition(partition.topic, partition.partition)) - - return new_metadata - - def __str__(self): - return 'ClusterMetadata(brokers: %d, topics: %d, groups: %d)' % \ - (len(self._brokers), len(self._partitions), len(self._groups)) diff --git a/kafka/conn.py b/kafka/conn.py deleted file mode 100644 index 3edd1915..00000000 --- a/kafka/conn.py +++ /dev/null @@ -1,1534 +0,0 @@ -from __future__ import absolute_import, division - -import copy -import errno -import io -import logging -from random import shuffle, uniform - -# selectors in stdlib as of py3.4 -try: - import selectors # pylint: disable=import-error -except ImportError: - # vendored backport module - from kafka.vendor import selectors34 as selectors - -import socket -import struct -import threading -import time - -from kafka.vendor import six - -import aiokafka.errors as Errors -from kafka.future import Future -from kafka.metrics.stats import Avg, Count, Max, Rate -from kafka.oauth.abstract import AbstractTokenProvider -from kafka.protocol.admin import SaslHandShakeRequest, DescribeAclsRequest_v2, DescribeClientQuotasRequest -from kafka.protocol.commit import OffsetFetchRequest -from kafka.protocol.offset import OffsetRequest -from kafka.protocol.produce import ProduceRequest -from kafka.protocol.metadata import MetadataRequest -from kafka.protocol.fetch import FetchRequest -from kafka.protocol.parser import KafkaProtocol -from kafka.protocol.types import Int32, Int8 -from kafka.scram import ScramClient 
-from kafka.version import __version__ - - -if six.PY2: - ConnectionError = socket.error - TimeoutError = socket.error - BlockingIOError = Exception - -log = logging.getLogger(__name__) - -DEFAULT_KAFKA_PORT = 9092 - -SASL_QOP_AUTH = 1 -SASL_QOP_AUTH_INT = 2 -SASL_QOP_AUTH_CONF = 4 - -try: - import ssl - ssl_available = True - try: - SSLEOFError = ssl.SSLEOFError - SSLWantReadError = ssl.SSLWantReadError - SSLWantWriteError = ssl.SSLWantWriteError - SSLZeroReturnError = ssl.SSLZeroReturnError - except AttributeError: - # support older ssl libraries - log.warning('Old SSL module detected.' - ' SSL error handling may not operate cleanly.' - ' Consider upgrading to Python 3.3 or 2.7.9') - SSLEOFError = ssl.SSLError - SSLWantReadError = ssl.SSLError - SSLWantWriteError = ssl.SSLError - SSLZeroReturnError = ssl.SSLError -except ImportError: - # support Python without ssl libraries - ssl_available = False - class SSLWantReadError(Exception): - pass - class SSLWantWriteError(Exception): - pass - -# needed for SASL_GSSAPI authentication: -try: - import gssapi - from gssapi.raw.misc import GSSError -except ImportError: - #no gssapi available, will disable gssapi mechanism - gssapi = None - GSSError = None - - -AFI_NAMES = { - socket.AF_UNSPEC: "unspecified", - socket.AF_INET: "IPv4", - socket.AF_INET6: "IPv6", -} - - -class ConnectionStates(object): - DISCONNECTING = '' - DISCONNECTED = '' - CONNECTING = '' - HANDSHAKE = '' - CONNECTED = '' - AUTHENTICATING = '' - - -class BrokerConnection(object): - """Initialize a Kafka broker connection - - Keyword Arguments: - client_id (str): a name for this client. This string is passed in - each request to servers and can be used to identify specific - server-side log entries that correspond to this client. Also - submitted to GroupCoordinator for logging with respect to - consumer group administration. Default: 'kafka-python-{version}' - reconnect_backoff_ms (int): The amount of time in milliseconds to - wait before attempting to reconnect to a given host. - Default: 50. - reconnect_backoff_max_ms (int): The maximum amount of time in - milliseconds to backoff/wait when reconnecting to a broker that has - repeatedly failed to connect. If provided, the backoff per host - will increase exponentially for each consecutive connection - failure, up to this maximum. Once the maximum is reached, - reconnection attempts will continue periodically with this fixed - rate. To avoid connection storms, a randomization factor of 0.2 - will be applied to the backoff resulting in a random range between - 20% below and 20% above the computed value. Default: 1000. - request_timeout_ms (int): Client request timeout in milliseconds. - Default: 30000. - max_in_flight_requests_per_connection (int): Requests are pipelined - to kafka brokers up to this number of maximum requests per - broker connection. Default: 5. - receive_buffer_bytes (int): The size of the TCP receive buffer - (SO_RCVBUF) to use when reading data. Default: None (relies on - system defaults). Java client defaults to 32768. - send_buffer_bytes (int): The size of the TCP send buffer - (SO_SNDBUF) to use when sending data. Default: None (relies on - system defaults). Java client defaults to 131072. - socket_options (list): List of tuple-arguments to socket.setsockopt - to apply to broker connection sockets. Default: - [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)] - security_protocol (str): Protocol used to communicate with brokers. - Valid values are: PLAINTEXT, SSL, SASL_PLAINTEXT, SASL_SSL. - Default: PLAINTEXT. 
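The exponential reconnect backoff described in this docstring can be restated as a short standalone sketch; the helper name is ours and the parameters are the defaults quoted above:

    # Restatement of the reconnect backoff: base * 2**(failures - 1),
    # capped at the maximum, then jittered by +/-20%.
    from random import uniform

    def reconnect_backoff_s(failures, base_ms=50, max_ms=1000):
        backoff_ms = min(base_ms * 2 ** (failures - 1), max_ms)
        return backoff_ms * uniform(0.8, 1.2) / 1000.0

    print(reconnect_backoff_s(1))   # ~0.05 s after the first failure
    print(reconnect_backoff_s(6))   # capped around 1.0 s, +/- 20%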
- ssl_context (ssl.SSLContext): pre-configured SSLContext for wrapping - socket connections. If provided, all other ssl_* configurations - will be ignored. Default: None. - ssl_check_hostname (bool): flag to configure whether ssl handshake - should verify that the certificate matches the brokers hostname. - default: True. - ssl_cafile (str): optional filename of ca file to use in certificate - verification. default: None. - ssl_certfile (str): optional filename of file in pem format containing - the client certificate, as well as any ca certificates needed to - establish the certificate's authenticity. default: None. - ssl_keyfile (str): optional filename containing the client private key. - default: None. - ssl_password (callable, str, bytes, bytearray): optional password or - callable function that returns a password, for decrypting the - client private key. Default: None. - ssl_crlfile (str): optional filename containing the CRL to check for - certificate expiration. By default, no CRL check is done. When - providing a file, only the leaf certificate will be checked against - this CRL. The CRL can only be checked with Python 3.4+ or 2.7.9+. - default: None. - ssl_ciphers (str): optionally set the available ciphers for ssl - connections. It should be a string in the OpenSSL cipher list - format. If no cipher can be selected (because compile-time options - or other configuration forbids use of all the specified ciphers), - an ssl.SSLError will be raised. See ssl.SSLContext.set_ciphers - api_version (tuple): Specify which Kafka API version to use. - Accepted values are: (0, 8, 0), (0, 8, 1), (0, 8, 2), (0, 9), - (0, 10). Default: (0, 8, 2) - api_version_auto_timeout_ms (int): number of milliseconds to throw a - timeout exception from the constructor when checking the broker - api version. Only applies if api_version is None - selector (selectors.BaseSelector): Provide a specific selector - implementation to use for I/O multiplexing. - Default: selectors.DefaultSelector - state_change_callback (callable): function to be called when the - connection state changes from CONNECTING to CONNECTED etc. - metrics (kafka.metrics.Metrics): Optionally provide a metrics - instance for capturing network IO stats. Default: None. - metric_group_prefix (str): Prefix for metric names. Default: '' - sasl_mechanism (str): Authentication mechanism when security_protocol - is configured for SASL_PLAINTEXT or SASL_SSL. Valid values are: - PLAIN, GSSAPI, OAUTHBEARER, SCRAM-SHA-256, SCRAM-SHA-512. - sasl_plain_username (str): username for sasl PLAIN and SCRAM authentication. - Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. - sasl_plain_password (str): password for sasl PLAIN and SCRAM authentication. - Required if sasl_mechanism is PLAIN or one of the SCRAM mechanisms. - sasl_kerberos_service_name (str): Service name to include in GSSAPI - sasl mechanism handshake. Default: 'kafka' - sasl_kerberos_domain_name (str): kerberos domain name to use in GSSAPI - sasl mechanism handshake. Default: one of bootstrap servers - sasl_oauth_token_provider (AbstractTokenProvider): OAuthBearer token provider - instance. (See kafka.oauth.abstract). 
Default: None - """ - - DEFAULT_CONFIG = { - 'client_id': 'kafka-python-' + __version__, - 'node_id': 0, - 'request_timeout_ms': 30000, - 'reconnect_backoff_ms': 50, - 'reconnect_backoff_max_ms': 1000, - 'max_in_flight_requests_per_connection': 5, - 'receive_buffer_bytes': None, - 'send_buffer_bytes': None, - 'socket_options': [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)], - 'sock_chunk_bytes': 4096, # undocumented experimental option - 'sock_chunk_buffer_count': 1000, # undocumented experimental option - 'security_protocol': 'PLAINTEXT', - 'ssl_context': None, - 'ssl_check_hostname': True, - 'ssl_cafile': None, - 'ssl_certfile': None, - 'ssl_keyfile': None, - 'ssl_crlfile': None, - 'ssl_password': None, - 'ssl_ciphers': None, - 'api_version': (0, 8, 2), # default to most restrictive - 'selector': selectors.DefaultSelector, - 'state_change_callback': lambda node_id, sock, conn: True, - 'metrics': None, - 'metric_group_prefix': '', - 'sasl_mechanism': None, - 'sasl_plain_username': None, - 'sasl_plain_password': None, - 'sasl_kerberos_service_name': 'kafka', - 'sasl_kerberos_domain_name': None, - 'sasl_oauth_token_provider': None - } - SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL') - SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', "SCRAM-SHA-256", "SCRAM-SHA-512") - - def __init__(self, host, port, afi, **configs): - self.host = host - self.port = port - self.afi = afi - self._sock_afi = afi - self._sock_addr = None - self._api_versions = None - - self.config = copy.copy(self.DEFAULT_CONFIG) - for key in self.config: - if key in configs: - self.config[key] = configs[key] - - self.node_id = self.config.pop('node_id') - - if self.config['receive_buffer_bytes'] is not None: - self.config['socket_options'].append( - (socket.SOL_SOCKET, socket.SO_RCVBUF, - self.config['receive_buffer_bytes'])) - if self.config['send_buffer_bytes'] is not None: - self.config['socket_options'].append( - (socket.SOL_SOCKET, socket.SO_SNDBUF, - self.config['send_buffer_bytes'])) - - assert self.config['security_protocol'] in self.SECURITY_PROTOCOLS, ( - 'security_protocol must be in ' + ', '.join(self.SECURITY_PROTOCOLS)) - - if self.config['security_protocol'] in ('SSL', 'SASL_SSL'): - assert ssl_available, "Python wasn't built with SSL support" - - if self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL'): - assert self.config['sasl_mechanism'] in self.SASL_MECHANISMS, ( - 'sasl_mechanism must be in ' + ', '.join(self.SASL_MECHANISMS)) - if self.config['sasl_mechanism'] in ('PLAIN', 'SCRAM-SHA-256', 'SCRAM-SHA-512'): - assert self.config['sasl_plain_username'] is not None, ( - 'sasl_plain_username required for PLAIN or SCRAM sasl' - ) - assert self.config['sasl_plain_password'] is not None, ( - 'sasl_plain_password required for PLAIN or SCRAM sasl' - ) - if self.config['sasl_mechanism'] == 'GSSAPI': - assert gssapi is not None, 'GSSAPI lib not available' - assert self.config['sasl_kerberos_service_name'] is not None, 'sasl_kerberos_service_name required for GSSAPI sasl' - if self.config['sasl_mechanism'] == 'OAUTHBEARER': - token_provider = self.config['sasl_oauth_token_provider'] - assert token_provider is not None, 'sasl_oauth_token_provider required for OAUTHBEARER sasl' - assert callable(getattr(token_provider, "token", None)), 'sasl_oauth_token_provider must implement method #token()' - # This is not a general lock / this class is not generally thread-safe yet - # However, to avoid pushing responsibility for maintaining - # per-connection locks to the upstream client, 
we will use this lock to - # make sure that access to the protocol buffer is synchronized - # when sends happen on multiple threads - self._lock = threading.Lock() - - # the protocol parser instance manages actual tracking of the - # sequence of in-flight requests to responses, which should - # function like a FIFO queue. For additional request data, - # including tracking request futures and timestamps, we - # can use a simple dictionary of correlation_id => request data - self.in_flight_requests = dict() - - self._protocol = KafkaProtocol( - client_id=self.config['client_id'], - api_version=self.config['api_version']) - self.state = ConnectionStates.DISCONNECTED - self._reset_reconnect_backoff() - self._sock = None - self._send_buffer = b'' - self._ssl_context = None - if self.config['ssl_context'] is not None: - self._ssl_context = self.config['ssl_context'] - self._sasl_auth_future = None - self.last_attempt = 0 - self._gai = [] - self._sensors = None - if self.config['metrics']: - self._sensors = BrokerConnectionMetrics(self.config['metrics'], - self.config['metric_group_prefix'], - self.node_id) - - def _dns_lookup(self): - self._gai = dns_lookup(self.host, self.port, self.afi) - if not self._gai: - log.error('DNS lookup failed for %s:%i (%s)', - self.host, self.port, self.afi) - return False - return True - - def _next_afi_sockaddr(self): - if not self._gai: - if not self._dns_lookup(): - return - afi, _, __, ___, sockaddr = self._gai.pop(0) - return (afi, sockaddr) - - def connect_blocking(self, timeout=float('inf')): - if self.connected(): - return True - timeout += time.time() - # First attempt to perform dns lookup - # note that the underlying interface, socket.getaddrinfo, - # has no explicit timeout so we may exceed the user-specified timeout - self._dns_lookup() - - # Loop once over all returned dns entries - selector = None - while self._gai: - while time.time() < timeout: - self.connect() - if self.connected(): - if selector is not None: - selector.close() - return True - elif self.connecting(): - if selector is None: - selector = self.config['selector']() - selector.register(self._sock, selectors.EVENT_WRITE) - selector.select(1) - elif self.disconnected(): - if selector is not None: - selector.close() - selector = None - break - else: - break - return False - - def connect(self): - """Attempt to connect and return ConnectionState""" - if self.state is ConnectionStates.DISCONNECTED and not self.blacked_out(): - self.last_attempt = time.time() - next_lookup = self._next_afi_sockaddr() - if not next_lookup: - self.close(Errors.KafkaConnectionError('DNS failure')) - return self.state - else: - log.debug('%s: creating new socket', self) - assert self._sock is None - self._sock_afi, self._sock_addr = next_lookup - self._sock = socket.socket(self._sock_afi, socket.SOCK_STREAM) - - for option in self.config['socket_options']: - log.debug('%s: setting socket option %s', self, option) - self._sock.setsockopt(*option) - - self._sock.setblocking(False) - self.state = ConnectionStates.CONNECTING - self.config['state_change_callback'](self.node_id, self._sock, self) - log.info('%s: connecting to %s:%d [%s %s]', self, self.host, - self.port, self._sock_addr, AFI_NAMES[self._sock_afi]) - - if self.state is ConnectionStates.CONNECTING: - # in non-blocking mode, use repeated calls to socket.connect_ex - # to check connection status - ret = None - try: - ret = self._sock.connect_ex(self._sock_addr) - except socket.error as err: - ret = err.errno - - # Connection succeeded - if not ret or 
ret == errno.EISCONN: - log.debug('%s: established TCP connection', self) - - if self.config['security_protocol'] in ('SSL', 'SASL_SSL'): - log.debug('%s: initiating SSL handshake', self) - self.state = ConnectionStates.HANDSHAKE - self.config['state_change_callback'](self.node_id, self._sock, self) - # _wrap_ssl can alter the connection state -- disconnects on failure - self._wrap_ssl() - - elif self.config['security_protocol'] == 'SASL_PLAINTEXT': - log.debug('%s: initiating SASL authentication', self) - self.state = ConnectionStates.AUTHENTICATING - self.config['state_change_callback'](self.node_id, self._sock, self) - - else: - # security_protocol PLAINTEXT - log.info('%s: Connection complete.', self) - self.state = ConnectionStates.CONNECTED - self._reset_reconnect_backoff() - self.config['state_change_callback'](self.node_id, self._sock, self) - - # Connection failed - # WSAEINVAL == 10022, but errno.WSAEINVAL is not available on non-win systems - elif ret not in (errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK, 10022): - log.error('Connect attempt to %s returned error %s.' - ' Disconnecting.', self, ret) - errstr = errno.errorcode.get(ret, 'UNKNOWN') - self.close(Errors.KafkaConnectionError('{} {}'.format(ret, errstr))) - return self.state - - # Needs retry - else: - pass - - if self.state is ConnectionStates.HANDSHAKE: - if self._try_handshake(): - log.debug('%s: completed SSL handshake.', self) - if self.config['security_protocol'] == 'SASL_SSL': - log.debug('%s: initiating SASL authentication', self) - self.state = ConnectionStates.AUTHENTICATING - else: - log.info('%s: Connection complete.', self) - self.state = ConnectionStates.CONNECTED - self._reset_reconnect_backoff() - self.config['state_change_callback'](self.node_id, self._sock, self) - - if self.state is ConnectionStates.AUTHENTICATING: - assert self.config['security_protocol'] in ('SASL_PLAINTEXT', 'SASL_SSL') - if self._try_authenticate(): - # _try_authenticate has side-effects: possibly disconnected on socket errors - if self.state is ConnectionStates.AUTHENTICATING: - log.info('%s: Connection complete.', self) - self.state = ConnectionStates.CONNECTED - self._reset_reconnect_backoff() - self.config['state_change_callback'](self.node_id, self._sock, self) - - if self.state not in (ConnectionStates.CONNECTED, - ConnectionStates.DISCONNECTED): - # Connection timed out - request_timeout = self.config['request_timeout_ms'] / 1000.0 - if time.time() > request_timeout + self.last_attempt: - log.error('Connection attempt to %s timed out', self) - self.close(Errors.KafkaConnectionError('timeout')) - return self.state - - return self.state - - def _wrap_ssl(self): - assert self.config['security_protocol'] in ('SSL', 'SASL_SSL') - if self._ssl_context is None: - log.debug('%s: configuring default SSL Context', self) - self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) # pylint: disable=no-member - self._ssl_context.options |= ssl.OP_NO_SSLv2 # pylint: disable=no-member - self._ssl_context.options |= ssl.OP_NO_SSLv3 # pylint: disable=no-member - self._ssl_context.verify_mode = ssl.CERT_OPTIONAL - if self.config['ssl_check_hostname']: - self._ssl_context.check_hostname = True - if self.config['ssl_cafile']: - log.info('%s: Loading SSL CA from %s', self, self.config['ssl_cafile']) - self._ssl_context.load_verify_locations(self.config['ssl_cafile']) - self._ssl_context.verify_mode = ssl.CERT_REQUIRED - else: - log.info('%s: Loading system default SSL CAs from %s', self, ssl.get_default_verify_paths()) - 
self._ssl_context.load_default_certs() - if self.config['ssl_certfile'] and self.config['ssl_keyfile']: - log.info('%s: Loading SSL Cert from %s', self, self.config['ssl_certfile']) - log.info('%s: Loading SSL Key from %s', self, self.config['ssl_keyfile']) - self._ssl_context.load_cert_chain( - certfile=self.config['ssl_certfile'], - keyfile=self.config['ssl_keyfile'], - password=self.config['ssl_password']) - if self.config['ssl_crlfile']: - if not hasattr(ssl, 'VERIFY_CRL_CHECK_LEAF'): - raise RuntimeError('This version of Python does not support ssl_crlfile!') - log.info('%s: Loading SSL CRL from %s', self, self.config['ssl_crlfile']) - self._ssl_context.load_verify_locations(self.config['ssl_crlfile']) - # pylint: disable=no-member - self._ssl_context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF - if self.config['ssl_ciphers']: - log.info('%s: Setting SSL Ciphers: %s', self, self.config['ssl_ciphers']) - self._ssl_context.set_ciphers(self.config['ssl_ciphers']) - log.debug('%s: wrapping socket in ssl context', self) - try: - self._sock = self._ssl_context.wrap_socket( - self._sock, - server_hostname=self.host, - do_handshake_on_connect=False) - except ssl.SSLError as e: - log.exception('%s: Failed to wrap socket in SSLContext!', self) - self.close(e) - - def _try_handshake(self): - assert self.config['security_protocol'] in ('SSL', 'SASL_SSL') - try: - self._sock.do_handshake() - return True - # old ssl in python2.6 will swallow all SSLErrors here... - except (SSLWantReadError, SSLWantWriteError): - pass - except (SSLZeroReturnError, ConnectionError, TimeoutError, SSLEOFError): - log.warning('SSL connection closed by server during handshake.') - self.close(Errors.KafkaConnectionError('SSL connection closed by server during handshake')) - # Other SSLErrors will be raised to user - - return False - - def _try_authenticate(self): - assert self.config['api_version'] is None or self.config['api_version'] >= (0, 10) - - if self._sasl_auth_future is None: - # Build a SaslHandShakeRequest message - request = SaslHandShakeRequest[0](self.config['sasl_mechanism']) - future = Future() - sasl_response = self._send(request) - sasl_response.add_callback(self._handle_sasl_handshake_response, future) - sasl_response.add_errback(lambda f, e: f.failure(e), future) - self._sasl_auth_future = future - - for r, f in self.recv(): - f.success(r) - - # A connection error could trigger close() which will reset the future - if self._sasl_auth_future is None: - return False - elif self._sasl_auth_future.failed(): - ex = self._sasl_auth_future.exception - if not isinstance(ex, Errors.KafkaConnectionError): - raise ex # pylint: disable-msg=raising-bad-type - return self._sasl_auth_future.succeeded() - - def _handle_sasl_handshake_response(self, future, response): - error_type = Errors.for_code(response.error_code) - if error_type is not Errors.NoError: - error = error_type(self) - self.close(error=error) - return future.failure(error_type(self)) - - if self.config['sasl_mechanism'] not in response.enabled_mechanisms: - return future.failure( - Errors.UnsupportedSaslMechanismError( - 'Kafka broker does not support %s sasl mechanism. 
Enabled mechanisms are: %s' - % (self.config['sasl_mechanism'], response.enabled_mechanisms))) - elif self.config['sasl_mechanism'] == 'PLAIN': - return self._try_authenticate_plain(future) - elif self.config['sasl_mechanism'] == 'GSSAPI': - return self._try_authenticate_gssapi(future) - elif self.config['sasl_mechanism'] == 'OAUTHBEARER': - return self._try_authenticate_oauth(future) - elif self.config['sasl_mechanism'].startswith("SCRAM-SHA-"): - return self._try_authenticate_scram(future) - else: - return future.failure( - Errors.UnsupportedSaslMechanismError( - 'kafka-python does not support SASL mechanism %s' % - self.config['sasl_mechanism'])) - - def _send_bytes(self, data): - """Send some data via non-blocking IO - - Note: this method is not synchronized internally; you should - always hold the _lock before calling - - Returns: number of bytes - Raises: socket exception - """ - total_sent = 0 - while total_sent < len(data): - try: - sent_bytes = self._sock.send(data[total_sent:]) - total_sent += sent_bytes - except (SSLWantReadError, SSLWantWriteError): - break - except (ConnectionError, TimeoutError) as e: - if six.PY2 and e.errno == errno.EWOULDBLOCK: - break - raise - except BlockingIOError: - if six.PY3: - break - raise - return total_sent - - def _send_bytes_blocking(self, data): - self._sock.settimeout(self.config['request_timeout_ms'] / 1000) - total_sent = 0 - try: - while total_sent < len(data): - sent_bytes = self._sock.send(data[total_sent:]) - total_sent += sent_bytes - if total_sent != len(data): - raise ConnectionError('Buffer overrun during socket send') - return total_sent - finally: - self._sock.settimeout(0.0) - - def _recv_bytes_blocking(self, n): - self._sock.settimeout(self.config['request_timeout_ms'] / 1000) - try: - data = b'' - while len(data) < n: - fragment = self._sock.recv(n - len(data)) - if not fragment: - raise ConnectionError('Connection reset during recv') - data += fragment - return data - finally: - self._sock.settimeout(0.0) - - def _try_authenticate_plain(self, future): - if self.config['security_protocol'] == 'SASL_PLAINTEXT': - log.warning('%s: Sending username and password in the clear', self) - - data = b'' - # Send PLAIN credentials per RFC-4616 - msg = bytes('\0'.join([self.config['sasl_plain_username'], - self.config['sasl_plain_username'], - self.config['sasl_plain_password']]).encode('utf-8')) - size = Int32.encode(len(msg)) - - err = None - close = False - with self._lock: - if not self._can_send_recv(): - err = Errors.NodeNotReadyError(str(self)) - close = False - else: - try: - self._send_bytes_blocking(size + msg) - - # The server will send a zero sized message (that is Int32(0)) on success. 
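For reference, the SASL/PLAIN framing built in _try_authenticate_plain boils down to the following sketch; the credentials are fake, and, as in the code above, the username is sent for both the authorization and authentication identities:

    # RFC 4616 SASL/PLAIN frame: [authzid] NUL authcid NUL passwd,
    # length-prefixed with a 4-byte big-endian size (as Int32.encode does).
    import struct

    username, password = "alice", "secret"
    msg = "\0".join([username, username, password]).encode("utf-8")
    frame = struct.pack(">i", len(msg)) + msg
    # A zero-sized reply (b"\x00\x00\x00\x00") from the broker signals success.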
- # The connection is closed on failure - data = self._recv_bytes_blocking(4) - - except (ConnectionError, TimeoutError) as e: - log.exception("%s: Error receiving reply from server", self) - err = Errors.KafkaConnectionError("%s: %s" % (self, e)) - close = True - - if err is not None: - if close: - self.close(error=err) - return future.failure(err) - - if data != b'\x00\x00\x00\x00': - error = Errors.AuthenticationFailedError('Unrecognized response during authentication') - return future.failure(error) - - log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username']) - return future.success(True) - - def _try_authenticate_scram(self, future): - if self.config['security_protocol'] == 'SASL_PLAINTEXT': - log.warning('%s: Exchanging credentials in the clear', self) - - scram_client = ScramClient( - self.config['sasl_plain_username'], self.config['sasl_plain_password'], self.config['sasl_mechanism'] - ) - - err = None - close = False - with self._lock: - if not self._can_send_recv(): - err = Errors.NodeNotReadyError(str(self)) - close = False - else: - try: - client_first = scram_client.first_message().encode('utf-8') - size = Int32.encode(len(client_first)) - self._send_bytes_blocking(size + client_first) - - (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4)) - server_first = self._recv_bytes_blocking(data_len).decode('utf-8') - scram_client.process_server_first_message(server_first) - - client_final = scram_client.final_message().encode('utf-8') - size = Int32.encode(len(client_final)) - self._send_bytes_blocking(size + client_final) - - (data_len,) = struct.unpack('>i', self._recv_bytes_blocking(4)) - server_final = self._recv_bytes_blocking(data_len).decode('utf-8') - scram_client.process_server_final_message(server_final) - - except (ConnectionError, TimeoutError) as e: - log.exception("%s: Error receiving reply from server", self) - err = Errors.KafkaConnectionError("%s: %s" % (self, e)) - close = True - - if err is not None: - if close: - self.close(error=err) - return future.failure(err) - - log.info( - '%s: Authenticated as %s via %s', self, self.config['sasl_plain_username'], self.config['sasl_mechanism'] - ) - return future.success(True) - - def _try_authenticate_gssapi(self, future): - kerberos_damin_name = self.config['sasl_kerberos_domain_name'] or self.host - auth_id = self.config['sasl_kerberos_service_name'] + '@' + kerberos_damin_name - gssapi_name = gssapi.Name( - auth_id, - name_type=gssapi.NameType.hostbased_service - ).canonicalize(gssapi.MechType.kerberos) - log.debug('%s: GSSAPI name: %s', self, gssapi_name) - - err = None - close = False - with self._lock: - if not self._can_send_recv(): - err = Errors.NodeNotReadyError(str(self)) - close = False - else: - # Establish security context and negotiate protection level - # For reference RFC 2222, section 7.2.1 - try: - # Exchange tokens until authentication either succeeds or fails - client_ctx = gssapi.SecurityContext(name=gssapi_name, usage='initiate') - received_token = None - while not client_ctx.complete: - # calculate an output token from kafka token (or None if first iteration) - output_token = client_ctx.step(received_token) - - # pass output token to kafka, or send empty response if the security - # context is complete (output token is None in that case) - if output_token is None: - self._send_bytes_blocking(Int32.encode(0)) - else: - msg = output_token - size = Int32.encode(len(msg)) - self._send_bytes_blocking(size + msg) - - # The server will send a token back. 
Processing of this token either - # establishes a security context, or it needs further token exchange. - # The gssapi will be able to identify the needed next step. - # The connection is closed on failure. - header = self._recv_bytes_blocking(4) - (token_size,) = struct.unpack('>i', header) - received_token = self._recv_bytes_blocking(token_size) - - # Process the security layer negotiation token, sent by the server - # once the security context is established. - - # unwraps message containing supported protection levels and msg size - msg = client_ctx.unwrap(received_token).message - # Kafka currently doesn't support integrity or confidentiality security layers, so we - # simply set QoP to 'auth' only (first octet). We reuse the max message size proposed - # by the server - msg = Int8.encode(SASL_QOP_AUTH & Int8.decode(io.BytesIO(msg[0:1]))) + msg[1:] - # add authorization identity to the response, GSS-wrap and send it - msg = client_ctx.wrap(msg + auth_id.encode(), False).message - size = Int32.encode(len(msg)) - self._send_bytes_blocking(size + msg) - - except (ConnectionError, TimeoutError) as e: - log.exception("%s: Error receiving reply from server", self) - err = Errors.KafkaConnectionError("%s: %s" % (self, e)) - close = True - except Exception as e: - err = e - close = True - - if err is not None: - if close: - self.close(error=err) - return future.failure(err) - - log.info('%s: Authenticated as %s via GSSAPI', self, gssapi_name) - return future.success(True) - - def _try_authenticate_oauth(self, future): - data = b'' - - msg = bytes(self._build_oauth_client_request().encode("utf-8")) - size = Int32.encode(len(msg)) - - err = None - close = False - with self._lock: - if not self._can_send_recv(): - err = Errors.NodeNotReadyError(str(self)) - close = False - else: - try: - # Send SASL OAuthBearer request with OAuth token - self._send_bytes_blocking(size + msg) - - # The server will send a zero sized message (that is Int32(0)) on success. - # The connection is closed on failure - data = self._recv_bytes_blocking(4) - - except (ConnectionError, TimeoutError) as e: - log.exception("%s: Error receiving reply from server", self) - err = Errors.KafkaConnectionError("%s: %s" % (self, e)) - close = True - - if err is not None: - if close: - self.close(error=err) - return future.failure(err) - - if data != b'\x00\x00\x00\x00': - error = Errors.AuthenticationFailedError('Unrecognized response during authentication') - return future.failure(error) - - log.info('%s: Authenticated via OAuth', self) - return future.success(True) - - def _build_oauth_client_request(self): - token_provider = self.config['sasl_oauth_token_provider'] - return "n,,\x01auth=Bearer {}{}\x01\x01".format(token_provider.token(), self._token_extensions()) - - def _token_extensions(self): - """ - Return a string representation of the OPTIONAL key-value pairs that can be sent with an OAUTHBEARER - initial request. 
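A minimal sketch of the OAUTHBEARER client-first message that _build_oauth_client_request() and _token_extensions() above assemble (the standalone helper and sample values are illustrative only):

    def build_oauth_client_request(token, extensions=None):
        # GS2 header "n,,", then "auth=Bearer <token>" plus optional
        # \x01-separated key=value extensions, terminated by \x01\x01
        ext = ""
        if extensions:
            ext = "\x01" + "\x01".join(
                "{}={}".format(k, v) for k, v in extensions.items())
        return "n,,\x01auth=Bearer {}{}\x01\x01".format(token, ext)

    # build_oauth_client_request("my.jwt.token", {"traceId": "abc"})
    # -> 'n,,\x01auth=Bearer my.jwt.token\x01traceId=abc\x01\x01'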
- """ - token_provider = self.config['sasl_oauth_token_provider'] - - # Only run if the #extensions() method is implemented by the clients Token Provider class - # Builds up a string separated by \x01 via a dict of key value pairs - if callable(getattr(token_provider, "extensions", None)) and len(token_provider.extensions()) > 0: - msg = "\x01".join(["{}={}".format(k, v) for k, v in token_provider.extensions().items()]) - return "\x01" + msg - else: - return "" - - def blacked_out(self): - """ - Return true if we are disconnected from the given node and can't - re-establish a connection yet - """ - if self.state is ConnectionStates.DISCONNECTED: - if time.time() < self.last_attempt + self._reconnect_backoff: - return True - return False - - def connection_delay(self): - """ - Return the number of milliseconds to wait, based on the connection - state, before attempting to send data. When disconnected, this respects - the reconnect backoff time. When connecting or connected, returns a very - large number to handle slow/stalled connections. - """ - time_waited = time.time() - (self.last_attempt or 0) - if self.state is ConnectionStates.DISCONNECTED: - return max(self._reconnect_backoff - time_waited, 0) * 1000 - else: - # When connecting or connected, we should be able to delay - # indefinitely since other events (connection or data acked) will - # cause a wakeup once data can be sent. - return float('inf') - - def connected(self): - """Return True iff socket is connected.""" - return self.state is ConnectionStates.CONNECTED - - def connecting(self): - """Returns True if still connecting (this may encompass several - different states, such as SSL handshake, authorization, etc).""" - return self.state in (ConnectionStates.CONNECTING, - ConnectionStates.HANDSHAKE, - ConnectionStates.AUTHENTICATING) - - def disconnected(self): - """Return True iff socket is closed""" - return self.state is ConnectionStates.DISCONNECTED - - def _reset_reconnect_backoff(self): - self._failures = 0 - self._reconnect_backoff = self.config['reconnect_backoff_ms'] / 1000.0 - - def _update_reconnect_backoff(self): - # Do not mark as failure if there are more dns entries available to try - if len(self._gai) > 0: - return - if self.config['reconnect_backoff_max_ms'] > self.config['reconnect_backoff_ms']: - self._failures += 1 - self._reconnect_backoff = self.config['reconnect_backoff_ms'] * 2 ** (self._failures - 1) - self._reconnect_backoff = min(self._reconnect_backoff, self.config['reconnect_backoff_max_ms']) - self._reconnect_backoff *= uniform(0.8, 1.2) - self._reconnect_backoff /= 1000.0 - log.debug('%s: reconnect backoff %s after %s failures', self, self._reconnect_backoff, self._failures) - - def _close_socket(self): - if hasattr(self, '_sock') and self._sock is not None: - self._sock.close() - self._sock = None - - def __del__(self): - self._close_socket() - - def close(self, error=None): - """Close socket and fail all in-flight-requests. - - Arguments: - error (Exception, optional): pending in-flight-requests - will be failed with this exception. - Default: aiokafka.errors.KafkaConnectionError. - """ - if self.state is ConnectionStates.DISCONNECTED: - return - with self._lock: - if self.state is ConnectionStates.DISCONNECTED: - return - log.info('%s: Closing connection. 
%s', self, error or '') - self._update_reconnect_backoff() - self._sasl_auth_future = None - self._protocol = KafkaProtocol( - client_id=self.config['client_id'], - api_version=self.config['api_version']) - self._send_buffer = b'' - if error is None: - error = Errors.Cancelled(str(self)) - ifrs = list(self.in_flight_requests.items()) - self.in_flight_requests.clear() - self.state = ConnectionStates.DISCONNECTED - # To avoid race conditions and/or deadlocks - # keep a reference to the socket but leave it - # open until after the state_change_callback - # This should give clients a change to deregister - # the socket fd from selectors cleanly. - sock = self._sock - self._sock = None - - # drop lock before state change callback and processing futures - self.config['state_change_callback'](self.node_id, sock, self) - sock.close() - for (_correlation_id, (future, _timestamp)) in ifrs: - future.failure(error) - - def _can_send_recv(self): - """Return True iff socket is ready for requests / responses""" - return self.state in (ConnectionStates.AUTHENTICATING, - ConnectionStates.CONNECTED) - - def send(self, request, blocking=True): - """Queue request for async network send, return Future()""" - future = Future() - if self.connecting(): - return future.failure(Errors.NodeNotReadyError(str(self))) - elif not self.connected(): - return future.failure(Errors.KafkaConnectionError(str(self))) - elif not self.can_send_more(): - return future.failure(Errors.TooManyInFlightRequests(str(self))) - return self._send(request, blocking=blocking) - - def _send(self, request, blocking=True): - future = Future() - with self._lock: - if not self._can_send_recv(): - # In this case, since we created the future above, - # we know there are no callbacks/errbacks that could fire w/ - # lock. So failing + returning inline should be safe - return future.failure(Errors.NodeNotReadyError(str(self))) - - correlation_id = self._protocol.send_request(request) - - log.debug('%s Request %d: %s', self, correlation_id, request) - if request.expect_response(): - sent_time = time.time() - assert correlation_id not in self.in_flight_requests, 'Correlation ID already in-flight!' - self.in_flight_requests[correlation_id] = (future, sent_time) - else: - future.success(None) - - # Attempt to replicate behavior from prior to introduction of - # send_pending_requests() / async sends - if blocking: - self.send_pending_requests() - - return future - - def send_pending_requests(self): - """Attempts to send pending requests messages via blocking IO - If all requests have been sent, return True - Otherwise, if the socket is blocked and there are more bytes to send, - return False. - """ - try: - with self._lock: - if not self._can_send_recv(): - return False - data = self._protocol.send_bytes() - total_bytes = self._send_bytes_blocking(data) - - if self._sensors: - self._sensors.bytes_sent.record(total_bytes) - return True - - except (ConnectionError, TimeoutError) as e: - log.exception("Error sending request data to %s", self) - error = Errors.KafkaConnectionError("%s: %s" % (self, e)) - self.close(error=error) - return False - - def send_pending_requests_v2(self): - """Attempts to send pending requests messages via non-blocking IO - If all requests have been sent, return True - Otherwise, if the socket is blocked and there are more bytes to send, - return False. 
- """ - try: - with self._lock: - if not self._can_send_recv(): - return False - - # _protocol.send_bytes returns encoded requests to send - # we send them via _send_bytes() - # and hold leftover bytes in _send_buffer - if not self._send_buffer: - self._send_buffer = self._protocol.send_bytes() - - total_bytes = 0 - if self._send_buffer: - total_bytes = self._send_bytes(self._send_buffer) - self._send_buffer = self._send_buffer[total_bytes:] - - if self._sensors: - self._sensors.bytes_sent.record(total_bytes) - # Return True iff send buffer is empty - return len(self._send_buffer) == 0 - - except (ConnectionError, TimeoutError, Exception) as e: - log.exception("Error sending request data to %s", self) - error = Errors.KafkaConnectionError("%s: %s" % (self, e)) - self.close(error=error) - return False - - def can_send_more(self): - """Return True unless there are max_in_flight_requests_per_connection.""" - max_ifrs = self.config['max_in_flight_requests_per_connection'] - return len(self.in_flight_requests) < max_ifrs - - def recv(self): - """Non-blocking network receive. - - Return list of (response, future) tuples - """ - responses = self._recv() - if not responses and self.requests_timed_out(): - log.warning('%s timed out after %s ms. Closing connection.', - self, self.config['request_timeout_ms']) - self.close(error=Errors.RequestTimedOutError( - 'Request timed out after %s ms' % - self.config['request_timeout_ms'])) - return () - - # augment responses w/ correlation_id, future, and timestamp - for i, (correlation_id, response) in enumerate(responses): - try: - with self._lock: - (future, timestamp) = self.in_flight_requests.pop(correlation_id) - except KeyError: - self.close(Errors.KafkaConnectionError('Received unrecognized correlation id')) - return () - latency_ms = (time.time() - timestamp) * 1000 - if self._sensors: - self._sensors.request_time.record(latency_ms) - - log.debug('%s Response %d (%s ms): %s', self, correlation_id, latency_ms, response) - responses[i] = (response, future) - - return responses - - def _recv(self): - """Take all available bytes from socket, return list of any responses from parser""" - recvd = [] - err = None - with self._lock: - if not self._can_send_recv(): - log.warning('%s cannot recv: socket not connected', self) - return () - - while len(recvd) < self.config['sock_chunk_buffer_count']: - try: - data = self._sock.recv(self.config['sock_chunk_bytes']) - # We expect socket.recv to raise an exception if there are no - # bytes available to read from the socket in non-blocking mode. 
- # but if the socket is disconnected, we will get empty data - # without an exception raised - if not data: - log.error('%s: socket disconnected', self) - err = Errors.KafkaConnectionError('socket disconnected') - break - else: - recvd.append(data) - - except (SSLWantReadError, SSLWantWriteError): - break - except (ConnectionError, TimeoutError) as e: - if six.PY2 and e.errno == errno.EWOULDBLOCK: - break - log.exception('%s: Error receiving network data' - ' closing socket', self) - err = Errors.KafkaConnectionError(e) - break - except BlockingIOError: - if six.PY3: - break - # For PY2 this is a catchall and should be re-raised - raise - - # Only process bytes if there was no connection exception - if err is None: - recvd_data = b''.join(recvd) - if self._sensors: - self._sensors.bytes_received.record(len(recvd_data)) - - # We need to keep the lock through protocol receipt - # so that we ensure that the processed byte order is the - # same as the received byte order - try: - return self._protocol.receive_bytes(recvd_data) - except Errors.KafkaProtocolError as e: - err = e - - self.close(error=err) - return () - - def requests_timed_out(self): - with self._lock: - if self.in_flight_requests: - get_timestamp = lambda v: v[1] - oldest_at = min(map(get_timestamp, - self.in_flight_requests.values())) - timeout = self.config['request_timeout_ms'] / 1000.0 - if time.time() >= oldest_at + timeout: - return True - return False - - def _handle_api_version_response(self, response): - error_type = Errors.for_code(response.error_code) - assert error_type is Errors.NoError, "API version check failed" - self._api_versions = dict([ - (api_key, (min_version, max_version)) - for api_key, min_version, max_version in response.api_versions - ]) - return self._api_versions - - def get_api_versions(self): - if self._api_versions is not None: - return self._api_versions - - version = self.check_version() - if version < (0, 10, 0): - raise Errors.UnsupportedVersionError( - "ApiVersion not supported by cluster version {} < 0.10.0" - .format(version)) - # _api_versions is set as a side effect of check_versions() on a cluster - # that supports 0.10.0 or later - return self._api_versions - - def _infer_broker_version_from_api_versions(self, api_versions): - # The logic here is to check the list of supported request versions - # in reverse order. As soon as we find one that works, return it - test_cases = [ - # format (, ) - ((2, 6, 0), DescribeClientQuotasRequest[0]), - ((2, 5, 0), DescribeAclsRequest_v2), - ((2, 4, 0), ProduceRequest[8]), - ((2, 3, 0), FetchRequest[11]), - ((2, 2, 0), OffsetRequest[5]), - ((2, 1, 0), FetchRequest[10]), - ((2, 0, 0), FetchRequest[8]), - ((1, 1, 0), FetchRequest[7]), - ((1, 0, 0), MetadataRequest[5]), - ((0, 11, 0), MetadataRequest[4]), - ((0, 10, 2), OffsetFetchRequest[2]), - ((0, 10, 1), MetadataRequest[2]), - ] - - # Get the best match of test cases - for broker_version, struct in sorted(test_cases, reverse=True): - if struct.API_KEY not in api_versions: - continue - min_version, max_version = api_versions[struct.API_KEY] - if min_version <= struct.API_VERSION <= max_version: - return broker_version - - # We know that ApiVersionResponse is only supported in 0.10+ - # so if all else fails, choose that - return (0, 10, 0) - - def check_version(self, timeout=2, strict=False, topics=[]): - """Attempt to guess the broker version. - - Note: This is a blocking call. - - Returns: version tuple, i.e. (0, 10), (0, 9), (0, 8, 2), ... 
- """ - timeout_at = time.time() + timeout - log.info('Probing node %s broker version', self.node_id) - # Monkeypatch some connection configurations to avoid timeouts - override_config = { - 'request_timeout_ms': timeout * 1000, - 'max_in_flight_requests_per_connection': 5 - } - stashed = {} - for key in override_config: - stashed[key] = self.config[key] - self.config[key] = override_config[key] - - def reset_override_configs(): - for key in stashed: - self.config[key] = stashed[key] - - # kafka kills the connection when it doesn't recognize an API request - # so we can send a test request and then follow immediately with a - # vanilla MetadataRequest. If the server did not recognize the first - # request, both will be failed with a ConnectionError that wraps - # socket.error (32, 54, or 104) - from kafka.protocol.admin import ApiVersionRequest, ListGroupsRequest - from kafka.protocol.commit import OffsetFetchRequest, GroupCoordinatorRequest - - test_cases = [ - # All cases starting from 0.10 will be based on ApiVersionResponse - ((0, 10), ApiVersionRequest[0]()), - ((0, 9), ListGroupsRequest[0]()), - ((0, 8, 2), GroupCoordinatorRequest[0]('kafka-python-default-group')), - ((0, 8, 1), OffsetFetchRequest[0]('kafka-python-default-group', [])), - ((0, 8, 0), MetadataRequest[0](topics)), - ] - - for version, request in test_cases: - if not self.connect_blocking(timeout_at - time.time()): - reset_override_configs() - raise Errors.NodeNotReadyError() - f = self.send(request) - # HACK: sleeping to wait for socket to send bytes - time.sleep(0.1) - # when broker receives an unrecognized request API - # it abruptly closes our socket. - # so we attempt to send a second request immediately - # that we believe it will definitely recognize (metadata) - # the attempt to write to a disconnected socket should - # immediately fail and allow us to infer that the prior - # request was unrecognized - mr = self.send(MetadataRequest[0](topics)) - - selector = self.config['selector']() - selector.register(self._sock, selectors.EVENT_READ) - while not (f.is_done and mr.is_done): - selector.select(1) - for response, future in self.recv(): - future.success(response) - selector.close() - - if f.succeeded(): - if isinstance(request, ApiVersionRequest[0]): - # Starting from 0.10 kafka broker we determine version - # by looking at ApiVersionResponse - api_versions = self._handle_api_version_response(f.value) - version = self._infer_broker_version_from_api_versions(api_versions) - log.info('Broker version identified as %s', '.'.join(map(str, version))) - log.info('Set configuration api_version=%s to skip auto' - ' check_version requests on startup', version) - break - - # Only enable strict checking to verify that we understand failure - # modes. For most users, the fact that the request failed should be - # enough to rule out a particular broker version. - if strict: - # If the socket flush hack did not work (which should force the - # connection to close and fail all pending requests), then we - # get a basic Request Timeout. This is not ideal, but we'll deal - if isinstance(f.exception, Errors.RequestTimedOutError): - pass - - # 0.9 brokers do not close the socket on unrecognized api - # requests (bug...). 
In this case we expect to see a correlation - # id mismatch - elif (isinstance(f.exception, Errors.CorrelationIdError) and - version == (0, 10)): - pass - elif six.PY2: - assert isinstance(f.exception.args[0], socket.error) - assert f.exception.args[0].errno in (32, 54, 104) - else: - assert isinstance(f.exception.args[0], ConnectionError) - log.info("Broker is not v%s -- it did not recognize %s", - version, request.__class__.__name__) - else: - reset_override_configs() - raise Errors.UnrecognizedBrokerVersion() - - reset_override_configs() - return version - - def __str__(self): - return "" % ( - self.node_id, self.host, self.port, self.state, - AFI_NAMES[self._sock_afi], self._sock_addr) - - -class BrokerConnectionMetrics(object): - def __init__(self, metrics, metric_group_prefix, node_id): - self.metrics = metrics - - # Any broker may have registered summary metrics already - # but if not, we need to create them so we can set as parents below - all_conns_transferred = metrics.get_sensor('bytes-sent-received') - if not all_conns_transferred: - metric_group_name = metric_group_prefix + '-metrics' - - bytes_transferred = metrics.sensor('bytes-sent-received') - bytes_transferred.add(metrics.metric_name( - 'network-io-rate', metric_group_name, - 'The average number of network operations (reads or writes) on all' - ' connections per second.'), Rate(sampled_stat=Count())) - - bytes_sent = metrics.sensor('bytes-sent', - parents=[bytes_transferred]) - bytes_sent.add(metrics.metric_name( - 'outgoing-byte-rate', metric_group_name, - 'The average number of outgoing bytes sent per second to all' - ' servers.'), Rate()) - bytes_sent.add(metrics.metric_name( - 'request-rate', metric_group_name, - 'The average number of requests sent per second.'), - Rate(sampled_stat=Count())) - bytes_sent.add(metrics.metric_name( - 'request-size-avg', metric_group_name, - 'The average size of all requests in the window.'), Avg()) - bytes_sent.add(metrics.metric_name( - 'request-size-max', metric_group_name, - 'The maximum size of any request sent in the window.'), Max()) - - bytes_received = metrics.sensor('bytes-received', - parents=[bytes_transferred]) - bytes_received.add(metrics.metric_name( - 'incoming-byte-rate', metric_group_name, - 'Bytes/second read off all sockets'), Rate()) - bytes_received.add(metrics.metric_name( - 'response-rate', metric_group_name, - 'Responses received sent per second.'), - Rate(sampled_stat=Count())) - - request_latency = metrics.sensor('request-latency') - request_latency.add(metrics.metric_name( - 'request-latency-avg', metric_group_name, - 'The average request latency in ms.'), - Avg()) - request_latency.add(metrics.metric_name( - 'request-latency-max', metric_group_name, - 'The maximum request latency in ms.'), - Max()) - - # if one sensor of the metrics has been registered for the connection, - # then all other sensors should have been registered; and vice versa - node_str = 'node-{0}'.format(node_id) - node_sensor = metrics.get_sensor(node_str + '.bytes-sent') - if not node_sensor: - metric_group_name = metric_group_prefix + '-node-metrics.' 
+ node_str - - bytes_sent = metrics.sensor( - node_str + '.bytes-sent', - parents=[metrics.get_sensor('bytes-sent')]) - bytes_sent.add(metrics.metric_name( - 'outgoing-byte-rate', metric_group_name, - 'The average number of outgoing bytes sent per second.'), - Rate()) - bytes_sent.add(metrics.metric_name( - 'request-rate', metric_group_name, - 'The average number of requests sent per second.'), - Rate(sampled_stat=Count())) - bytes_sent.add(metrics.metric_name( - 'request-size-avg', metric_group_name, - 'The average size of all requests in the window.'), - Avg()) - bytes_sent.add(metrics.metric_name( - 'request-size-max', metric_group_name, - 'The maximum size of any request sent in the window.'), - Max()) - - bytes_received = metrics.sensor( - node_str + '.bytes-received', - parents=[metrics.get_sensor('bytes-received')]) - bytes_received.add(metrics.metric_name( - 'incoming-byte-rate', metric_group_name, - 'Bytes/second read off node-connection socket'), - Rate()) - bytes_received.add(metrics.metric_name( - 'response-rate', metric_group_name, - 'The average number of responses received per second.'), - Rate(sampled_stat=Count())) - - request_time = metrics.sensor( - node_str + '.latency', - parents=[metrics.get_sensor('request-latency')]) - request_time.add(metrics.metric_name( - 'request-latency-avg', metric_group_name, - 'The average request latency in ms.'), - Avg()) - request_time.add(metrics.metric_name( - 'request-latency-max', metric_group_name, - 'The maximum request latency in ms.'), - Max()) - - self.bytes_sent = metrics.sensor(node_str + '.bytes-sent') - self.bytes_received = metrics.sensor(node_str + '.bytes-received') - self.request_time = metrics.sensor(node_str + '.latency') - - -def _address_family(address): - """ - Attempt to determine the family of an address (or hostname) - - :return: either socket.AF_INET or socket.AF_INET6 or socket.AF_UNSPEC if the address family - could not be determined - """ - if address.startswith('[') and address.endswith(']'): - return socket.AF_INET6 - for af in (socket.AF_INET, socket.AF_INET6): - try: - socket.inet_pton(af, address) - return af - except (ValueError, AttributeError, socket.error): - continue - return socket.AF_UNSPEC - - -def get_ip_port_afi(host_and_port_str): - """ - Parse the IP and port from a string in the format of: - - * host_or_ip <- Can be either IPv4 address literal or hostname/fqdn - * host_or_ipv4:port <- Can be either IPv4 address literal or hostname/fqdn - * [host_or_ip] <- IPv6 address literal - * [host_or_ip]:port. <- IPv6 address literal - - .. note:: IPv6 address literals with ports *must* be enclosed in brackets - - .. note:: If the port is not specified, default will be returned. - - :return: tuple (host, port, afi), afi will be socket.AF_INET or socket.AF_INET6 or socket.AF_UNSPEC - """ - host_and_port_str = host_and_port_str.strip() - if host_and_port_str.startswith('['): - af = socket.AF_INET6 - host, rest = host_and_port_str[1:].split(']') - if rest: - port = int(rest[1:]) - else: - port = DEFAULT_KAFKA_PORT - return host, port, af - else: - if ':' not in host_and_port_str: - af = _address_family(host_and_port_str) - return host_and_port_str, DEFAULT_KAFKA_PORT, af - else: - # now we have something with a colon in it and no square brackets. 
It could be - # either an IPv6 address literal (e.g., "::1") or an IP:port pair or a host:port pair - try: - # if it decodes as an IPv6 address, use that - socket.inet_pton(socket.AF_INET6, host_and_port_str) - return host_and_port_str, DEFAULT_KAFKA_PORT, socket.AF_INET6 - except AttributeError: - log.warning('socket.inet_pton not available on this platform.' - ' consider `pip install win_inet_pton`') - pass - except (ValueError, socket.error): - # it's a host:port pair - pass - host, port = host_and_port_str.rsplit(':', 1) - port = int(port) - - af = _address_family(host) - return host, port, af - - -def collect_hosts(hosts, randomize=True): - """ - Collects a comma-separated set of hosts (host:port) and optionally - randomize the returned list. - """ - - if isinstance(hosts, six.string_types): - hosts = hosts.strip().split(',') - - result = [] - afi = socket.AF_INET - for host_port in hosts: - - host, port, afi = get_ip_port_afi(host_port) - - if port < 0: - port = DEFAULT_KAFKA_PORT - - result.append((host, port, afi)) - - if randomize: - shuffle(result) - - return result - - -def is_inet_4_or_6(gai): - """Given a getaddrinfo struct, return True iff ipv4 or ipv6""" - return gai[0] in (socket.AF_INET, socket.AF_INET6) - - -def dns_lookup(host, port, afi=socket.AF_UNSPEC): - """Returns a list of getaddrinfo structs, optionally filtered to an afi (ipv4 / ipv6)""" - # XXX: all DNS functions in Python are blocking. If we really - # want to be non-blocking here, we need to use a 3rd-party - # library like python-adns, or move resolution onto its - # own thread. This will be subject to the default libc - # name resolution timeout (5s on most Linux boxes) - try: - return list(filter(is_inet_4_or_6, - socket.getaddrinfo(host, port, afi, - socket.SOCK_STREAM))) - except socket.gaierror as ex: - log.warning('DNS lookup failed for %s:%d,' - ' exception was %s. 
Is your' - ' advertised.listeners (called' - ' advertised.host.name before Kafka 9)' - ' correct and resolvable?', - host, port, ex) - return [] diff --git a/tests/kafka/conftest.py b/tests/kafka/conftest.py index 2fd11b40..04aec4b8 100644 --- a/tests/kafka/conftest.py +++ b/tests/kafka/conftest.py @@ -116,33 +116,6 @@ def topic(kafka_broker, request): return topic_name -@pytest.fixture -def conn(mocker): - """Return a connection mocker fixture""" - from kafka.conn import ConnectionStates - from kafka.future import Future - from kafka.protocol.metadata import MetadataResponse - conn = mocker.patch('kafka.client_async.BrokerConnection') - conn.return_value = conn - conn.state = ConnectionStates.CONNECTED - conn.send.return_value = Future().success( - MetadataResponse[0]( - [(0, 'foo', 12), (1, 'bar', 34)], # brokers - [])) # topics - conn.blacked_out.return_value = False - def _set_conn_state(state): - conn.state = state - return state - conn._set_conn_state = _set_conn_state - conn.connect.side_effect = lambda: conn.state - conn.connect_blocking.return_value = True - conn.connecting = lambda: conn.state in (ConnectionStates.CONNECTING, - ConnectionStates.HANDSHAKE) - conn.connected = lambda: conn.state is ConnectionStates.CONNECTED - conn.disconnected = lambda: conn.state is ConnectionStates.DISCONNECTED - return conn - - @pytest.fixture() def send_messages(topic, kafka_producer, request): """A factory that returns a send_messages function with a pre-populated diff --git a/tests/kafka/test_conn.py b/tests/kafka/test_conn.py deleted file mode 100644 index 6eb45f45..00000000 --- a/tests/kafka/test_conn.py +++ /dev/null @@ -1,342 +0,0 @@ -# pylint: skip-file -from __future__ import absolute_import - -from errno import EALREADY, EINPROGRESS, EISCONN, ECONNRESET -import socket - -from unittest import mock -import pytest - -from kafka.conn import BrokerConnection, ConnectionStates, collect_hosts -from kafka.protocol.api import RequestHeader -from kafka.protocol.metadata import MetadataRequest -from kafka.protocol.produce import ProduceRequest - -import aiokafka.errors as Errors - - -@pytest.fixture -def dns_lookup(mocker): - return mocker.patch('kafka.conn.dns_lookup', - return_value=[(socket.AF_INET, - None, None, None, - ('localhost', 9092))]) - -@pytest.fixture -def _socket(mocker): - socket = mocker.MagicMock() - socket.connect_ex.return_value = 0 - mocker.patch('socket.socket', return_value=socket) - return socket - - -@pytest.fixture -def conn(_socket, dns_lookup): - conn = BrokerConnection('localhost', 9092, socket.AF_INET) - return conn - - -@pytest.mark.parametrize("states", [ - (([EINPROGRESS, EALREADY], ConnectionStates.CONNECTING),), - (([EALREADY, EALREADY], ConnectionStates.CONNECTING),), - (([0], ConnectionStates.CONNECTED),), - (([EINPROGRESS, EALREADY], ConnectionStates.CONNECTING), - ([ECONNRESET], ConnectionStates.DISCONNECTED)), - (([EINPROGRESS, EALREADY], ConnectionStates.CONNECTING), - ([EALREADY], ConnectionStates.CONNECTING), - ([EISCONN], ConnectionStates.CONNECTED)), -]) -def test_connect(_socket, conn, states): - assert conn.state is ConnectionStates.DISCONNECTED - - for errno, state in states: - _socket.connect_ex.side_effect = errno - conn.connect() - assert conn.state is state - - -def test_connect_timeout(_socket, conn): - assert conn.state is ConnectionStates.DISCONNECTED - - # Initial connect returns EINPROGRESS - # immediate inline connect returns EALREADY - # second explicit connect returns EALREADY - # third explicit connect returns EALREADY and times out via 
last_attempt - _socket.connect_ex.side_effect = [EINPROGRESS, EALREADY, EALREADY, EALREADY] - conn.connect() - assert conn.state is ConnectionStates.CONNECTING - conn.connect() - assert conn.state is ConnectionStates.CONNECTING - conn.last_attempt = 0 - conn.connect() - assert conn.state is ConnectionStates.DISCONNECTED - - -def test_blacked_out(conn): - with mock.patch("time.time", return_value=1000): - conn.last_attempt = 0 - assert conn.blacked_out() is False - conn.last_attempt = 1000 - assert conn.blacked_out() is True - - -def test_connection_delay(conn): - with mock.patch("time.time", return_value=1000): - conn.last_attempt = 1000 - assert conn.connection_delay() == conn.config['reconnect_backoff_ms'] - conn.state = ConnectionStates.CONNECTING - assert conn.connection_delay() == float('inf') - conn.state = ConnectionStates.CONNECTED - assert conn.connection_delay() == float('inf') - - -def test_connected(conn): - assert conn.connected() is False - conn.state = ConnectionStates.CONNECTED - assert conn.connected() is True - - -def test_connecting(conn): - assert conn.connecting() is False - conn.state = ConnectionStates.CONNECTING - assert conn.connecting() is True - conn.state = ConnectionStates.CONNECTED - assert conn.connecting() is False - - -def test_send_disconnected(conn): - conn.state = ConnectionStates.DISCONNECTED - f = conn.send('foobar') - assert f.failed() is True - assert isinstance(f.exception, Errors.KafkaConnectionError) - - -def test_send_connecting(conn): - conn.state = ConnectionStates.CONNECTING - f = conn.send('foobar') - assert f.failed() is True - assert isinstance(f.exception, Errors.NodeNotReadyError) - - -def test_send_max_ifr(conn): - conn.state = ConnectionStates.CONNECTED - max_ifrs = conn.config['max_in_flight_requests_per_connection'] - for i in range(max_ifrs): - conn.in_flight_requests[i] = 'foo' - f = conn.send('foobar') - assert f.failed() is True - assert isinstance(f.exception, Errors.TooManyInFlightRequests) - - -def test_send_no_response(_socket, conn): - conn.connect() - assert conn.state is ConnectionStates.CONNECTED - req = ProduceRequest[0](required_acks=0, timeout=0, topics=()) - header = RequestHeader(req, client_id=conn.config['client_id']) - payload_bytes = len(header.encode()) + len(req.encode()) - third = payload_bytes // 3 - remainder = payload_bytes % 3 - _socket.send.side_effect = [4, third, third, third, remainder] - - assert len(conn.in_flight_requests) == 0 - f = conn.send(req) - assert f.succeeded() is True - assert f.value is None - assert len(conn.in_flight_requests) == 0 - - -def test_send_response(_socket, conn): - conn.connect() - assert conn.state is ConnectionStates.CONNECTED - req = MetadataRequest[0]([]) - header = RequestHeader(req, client_id=conn.config['client_id']) - payload_bytes = len(header.encode()) + len(req.encode()) - third = payload_bytes // 3 - remainder = payload_bytes % 3 - _socket.send.side_effect = [4, third, third, third, remainder] - - assert len(conn.in_flight_requests) == 0 - f = conn.send(req) - assert f.is_done is False - assert len(conn.in_flight_requests) == 1 - - -def test_send_error(_socket, conn): - conn.connect() - assert conn.state is ConnectionStates.CONNECTED - req = MetadataRequest[0]([]) - try: - _socket.send.side_effect = ConnectionError - except NameError: - _socket.send.side_effect = socket.error - f = conn.send(req) - assert f.failed() is True - assert isinstance(f.exception, Errors.KafkaConnectionError) - assert _socket.close.call_count == 1 - assert conn.state is 
ConnectionStates.DISCONNECTED - - -def test_can_send_more(conn): - assert conn.can_send_more() is True - max_ifrs = conn.config['max_in_flight_requests_per_connection'] - for i in range(max_ifrs): - assert conn.can_send_more() is True - conn.in_flight_requests[i] = 'foo' - assert conn.can_send_more() is False - - -def test_recv_disconnected(_socket, conn): - conn.connect() - assert conn.connected() - - req = MetadataRequest[0]([]) - header = RequestHeader(req, client_id=conn.config['client_id']) - payload_bytes = len(header.encode()) + len(req.encode()) - _socket.send.side_effect = [4, payload_bytes] - conn.send(req) - - # Empty data on recv means the socket is disconnected - _socket.recv.return_value = b'' - - # Attempt to receive should mark connection as disconnected - assert conn.connected() - conn.recv() - assert conn.disconnected() - - -def test_recv(_socket, conn): - pass # TODO - - -def test_close(conn): - pass # TODO - - -def test_collect_hosts__happy_path(): - hosts = "127.0.0.1:1234,127.0.0.1" - results = collect_hosts(hosts) - assert set(results) == set([ - ('127.0.0.1', 1234, socket.AF_INET), - ('127.0.0.1', 9092, socket.AF_INET), - ]) - - -def test_collect_hosts__ipv6(): - hosts = "[localhost]:1234,[2001:1000:2000::1],[2001:1000:2000::1]:1234" - results = collect_hosts(hosts) - assert set(results) == set([ - ('localhost', 1234, socket.AF_INET6), - ('2001:1000:2000::1', 9092, socket.AF_INET6), - ('2001:1000:2000::1', 1234, socket.AF_INET6), - ]) - - -def test_collect_hosts__string_list(): - hosts = [ - 'localhost:1234', - 'localhost', - '[localhost]', - '2001::1', - '[2001::1]', - '[2001::1]:1234', - ] - results = collect_hosts(hosts) - assert set(results) == set([ - ('localhost', 1234, socket.AF_UNSPEC), - ('localhost', 9092, socket.AF_UNSPEC), - ('localhost', 9092, socket.AF_INET6), - ('2001::1', 9092, socket.AF_INET6), - ('2001::1', 9092, socket.AF_INET6), - ('2001::1', 1234, socket.AF_INET6), - ]) - - -def test_collect_hosts__with_spaces(): - hosts = "localhost:1234, localhost" - results = collect_hosts(hosts) - assert set(results) == set([ - ('localhost', 1234, socket.AF_UNSPEC), - ('localhost', 9092, socket.AF_UNSPEC), - ]) - - -def test_lookup_on_connect(): - hostname = 'example.org' - port = 9092 - conn = BrokerConnection(hostname, port, socket.AF_UNSPEC) - assert conn.host == hostname - assert conn.port == port - assert conn.afi == socket.AF_UNSPEC - afi1 = socket.AF_INET - sockaddr1 = ('127.0.0.1', 9092) - mock_return1 = [ - (afi1, socket.SOCK_STREAM, 6, '', sockaddr1), - ] - with mock.patch("socket.getaddrinfo", return_value=mock_return1) as m: - conn.connect() - m.assert_called_once_with(hostname, port, 0, socket.SOCK_STREAM) - assert conn._sock_afi == afi1 - assert conn._sock_addr == sockaddr1 - conn.close() - - afi2 = socket.AF_INET6 - sockaddr2 = ('::1', 9092, 0, 0) - mock_return2 = [ - (afi2, socket.SOCK_STREAM, 6, '', sockaddr2), - ] - - with mock.patch("socket.getaddrinfo", return_value=mock_return2) as m: - conn.last_attempt = 0 - conn.connect() - m.assert_called_once_with(hostname, port, 0, socket.SOCK_STREAM) - assert conn._sock_afi == afi2 - assert conn._sock_addr == sockaddr2 - conn.close() - - -def test_relookup_on_failure(): - hostname = 'example.org' - port = 9092 - conn = BrokerConnection(hostname, port, socket.AF_UNSPEC) - assert conn.host == hostname - mock_return1 = [] - with mock.patch("socket.getaddrinfo", return_value=mock_return1) as m: - last_attempt = conn.last_attempt - conn.connect() - m.assert_called_once_with(hostname, port, 0, 
socket.SOCK_STREAM) - assert conn.disconnected() - assert conn.last_attempt > last_attempt - - afi2 = socket.AF_INET - sockaddr2 = ('127.0.0.2', 9092) - mock_return2 = [ - (afi2, socket.SOCK_STREAM, 6, '', sockaddr2), - ] - - with mock.patch("socket.getaddrinfo", return_value=mock_return2) as m: - conn.last_attempt = 0 - conn.connect() - m.assert_called_once_with(hostname, port, 0, socket.SOCK_STREAM) - assert conn._sock_afi == afi2 - assert conn._sock_addr == sockaddr2 - conn.close() - - -def test_requests_timed_out(conn): - with mock.patch("time.time", return_value=0): - # No in-flight requests, not timed out - assert not conn.requests_timed_out() - - # Single request, timestamp = now (0) - conn.in_flight_requests[0] = ('foo', 0) - assert not conn.requests_timed_out() - - # Add another request w/ timestamp > request_timeout ago - request_timeout = conn.config['request_timeout_ms'] - expired_timestamp = 0 - request_timeout - 1 - conn.in_flight_requests[1] = ('bar', expired_timestamp) - assert conn.requests_timed_out() - - # Drop the expired request and we should be good to go again - conn.in_flight_requests.pop(1) - assert not conn.requests_timed_out() diff --git a/tests/kafka/test_cluster.py b/tests/test_cluster.py similarity index 81% rename from tests/kafka/test_cluster.py rename to tests/test_cluster.py index f010c4f7..0fad6e31 100644 --- a/tests/kafka/test_cluster.py +++ b/tests/test_cluster.py @@ -1,11 +1,7 @@ -# pylint: skip-file -from __future__ import absolute_import - -import pytest - -from kafka.cluster import ClusterMetadata from kafka.protocol.metadata import MetadataResponse +from aiokafka.cluster import ClusterMetadata + def test_empty_broker_list(): cluster = ClusterMetadata() diff --git a/tests/test_message_accumulator.py b/tests/test_message_accumulator.py index bcce740b..16a70e9f 100644 --- a/tests/test_message_accumulator.py +++ b/tests/test_message_accumulator.py @@ -3,9 +3,9 @@ import unittest from unittest import mock -from kafka.cluster import ClusterMetadata from kafka.structs import TopicPartition -from ._testutil import run_until_complete + +from aiokafka.cluster import ClusterMetadata from aiokafka.errors import ( KafkaTimeoutError, NotLeaderForPartitionError, LeaderNotAvailableError ) @@ -14,6 +14,8 @@ MessageAccumulator, MessageBatch, BatchBuilder ) +from ._testutil import run_until_complete + @pytest.mark.usefixtures('setup_test_class_serverless') class TestMessageAccumulator(unittest.TestCase): diff --git a/tests/test_producer.py b/tests/test_producer.py index 63460446..9cc6858b 100644 --- a/tests/test_producer.py +++ b/tests/test_producer.py @@ -6,23 +6,22 @@ import weakref from unittest import mock -from kafka.cluster import ClusterMetadata from kafka.protocol.produce import ProduceResponse -from ._testutil import ( - KafkaIntegrationTestCase, run_until_complete, run_in_thread, kafka_versions -) - from aiokafka.producer import AIOKafkaProducer from aiokafka.client import AIOKafkaClient +from aiokafka.cluster import ClusterMetadata from aiokafka.consumer import AIOKafkaConsumer -from aiokafka.util import create_future - from aiokafka.errors import ( KafkaTimeoutError, UnknownTopicOrPartitionError, MessageSizeTooLargeError, NotLeaderForPartitionError, LeaderNotAvailableError, RequestTimedOutError, UnsupportedVersionError, ProducerClosed, KafkaError) +from aiokafka.util import create_future + +from ._testutil import ( + KafkaIntegrationTestCase, run_until_complete, run_in_thread, kafka_versions +) LOG_APPEND_TIME = 1 From 
59fbd73bb030637ef0922e8a43e04ce9da65d835 Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Sun, 22 Oct 2023 21:29:13 +0300 Subject: [PATCH 13/20] Merge structs --- aiokafka/admin/client.py | 4 +- aiokafka/coordinator/assignors/roundrobin.py | 3 +- .../assignors/sticky/sticky_assignor.py | 2 +- aiokafka/coordinator/consumer.py | 2 +- aiokafka/coordinator/protocol.py | 3 +- aiokafka/structs.py | 77 ++++++++++++++-- docs/api.rst | 2 +- examples/ssl_consume_produce.py | 2 +- kafka/__init__.py | 1 - kafka/scram.py | 81 ----------------- kafka/structs.py | 87 ------------------- tests/coordinator/test_assignors.py | 2 +- tests/coordinator/test_partition_movements.py | 3 +- tests/test_message_accumulator.py | 3 +- 14 files changed, 82 insertions(+), 190 deletions(-) delete mode 100644 kafka/scram.py delete mode 100644 kafka/structs.py diff --git a/aiokafka/admin/client.py b/aiokafka/admin/client.py index 7ce2465a..32309d7b 100644 --- a/aiokafka/admin/client.py +++ b/aiokafka/admin/client.py @@ -16,11 +16,11 @@ AlterConfigsRequest, ListGroupsRequest, ApiVersionRequest_v0) -from kafka.structs import TopicPartition, OffsetAndMetadata from aiokafka import __version__ -from aiokafka.errors import IncompatibleBrokerVersion, for_code from aiokafka.client import AIOKafkaClient +from aiokafka.errors import IncompatibleBrokerVersion, for_code +from aiokafka.structs import TopicPartition, OffsetAndMetadata from .config_resource import ConfigResourceType, ConfigResource from .new_topic import NewTopic diff --git a/aiokafka/coordinator/assignors/roundrobin.py b/aiokafka/coordinator/assignors/roundrobin.py index f3dd47f2..3ee09ac0 100644 --- a/aiokafka/coordinator/assignors/roundrobin.py +++ b/aiokafka/coordinator/assignors/roundrobin.py @@ -2,13 +2,12 @@ import itertools import logging -from kafka.structs import TopicPartition - from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor from aiokafka.coordinator.protocol import ( ConsumerProtocolMemberMetadata, ConsumerProtocolMemberAssignment, ) +from aiokafka.structs import TopicPartition log = logging.getLogger(__name__) diff --git a/aiokafka/coordinator/assignors/sticky/sticky_assignor.py b/aiokafka/coordinator/assignors/sticky/sticky_assignor.py index 452b2fd6..05e14ef2 100644 --- a/aiokafka/coordinator/assignors/sticky/sticky_assignor.py +++ b/aiokafka/coordinator/assignors/sticky/sticky_assignor.py @@ -4,7 +4,6 @@ from kafka.protocol.struct import Struct from kafka.protocol.types import String, Array, Int32 -from kafka.structs import TopicPartition from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor from aiokafka.coordinator.assignors.sticky.partition_movements import PartitionMovements @@ -14,6 +13,7 @@ ConsumerProtocolMemberAssignment, ) from aiokafka.coordinator.protocol import Schema +from aiokafka.structs import TopicPartition log = logging.getLogger(__name__) diff --git a/aiokafka/coordinator/consumer.py b/aiokafka/coordinator/consumer.py index 60f922ff..47af8449 100644 --- a/aiokafka/coordinator/consumer.py +++ b/aiokafka/coordinator/consumer.py @@ -8,10 +8,10 @@ from kafka.metrics import AnonMeasurable from kafka.metrics.stats import Avg, Count, Max, Rate from kafka.protocol.commit import OffsetCommitRequest, OffsetFetchRequest -from kafka.structs import OffsetAndMetadata, TopicPartition from kafka.util import WeakMethod import aiokafka.errors as Errors +from aiokafka.structs import OffsetAndMetadata, TopicPartition from .base import BaseCoordinator, Generation from .assignors.range import 
RangePartitionAssignor diff --git a/aiokafka/coordinator/protocol.py b/aiokafka/coordinator/protocol.py index 87425007..0dfbe7f9 100644 --- a/aiokafka/coordinator/protocol.py +++ b/aiokafka/coordinator/protocol.py @@ -1,6 +1,7 @@ from kafka.protocol.struct import Struct from kafka.protocol.types import Array, Bytes, Int16, Int32, Schema, String -from kafka.structs import TopicPartition + +from aiokafka.structs import TopicPartition class ConsumerProtocolMemberMetadata(Struct): diff --git a/aiokafka/structs.py b/aiokafka/structs.py index f7303851..d51a0bcd 100644 --- a/aiokafka/structs.py +++ b/aiokafka/structs.py @@ -1,12 +1,7 @@ from dataclasses import dataclass -from typing import Generic, NamedTuple, Optional, Sequence, Tuple, TypeVar +from typing import Generic, List, NamedTuple, Optional, Sequence, Tuple, TypeVar -from kafka.structs import ( - BrokerMetadata, - OffsetAndMetadata, - PartitionMetadata, - TopicPartition, -) +from aiokafka.errors import KafkaError __all__ = [ @@ -19,6 +14,74 @@ ] +class TopicPartition(NamedTuple): + """A topic and partition tuple""" + + topic: str + "A topic name" + + partition: int + "A partition id" + + +class BrokerMetadata(NamedTuple): + """A Kafka broker metadata used by admin tools""" + + nodeId: int + "The Kafka broker id" + + host: str + "The Kafka broker hostname" + + port: int + "The Kafka broker port" + + rack: Optional[str] + """The rack of the broker, which is used to in rack aware partition + assignment for fault tolerance. + Examples: `RACK1`, `us-east-1d`. Default: None + """ + + +class PartitionMetadata(NamedTuple): + """A topic partition metadata describing the state in the MetadataResponse""" + + topic: str + "The topic name of the partition this metadata relates to" + + partition: int + "The id of the partition this metadata relates to" + + leader: int + "The id of the broker that is the leader for the partition" + + replicas: List[int] + "The ids of all brokers that contain replicas of the partition" + isr: List[int] + "The ids of all brokers that contain in-sync replicas of the partition" + + error: Optional[KafkaError] + "A KafkaError object associated with the request for this partition metadata" + + +class OffsetAndMetadata(NamedTuple): + """The Kafka offset commit API + + The Kafka offset commit API allows users to provide additional metadata + (in the form of a string) when an offset is committed. This can be useful + (for example) to store information about which node made the commit, + what time the commit was made, etc. + """ + + offset: int + "The offset to be committed" + + metadata: str + "Non-null metadata" + + # TODO add leaderEpoch: + + class RecordMetadata(NamedTuple): """Returned when a :class:`~.AIOKafkaProducer` sends a message""" diff --git a/docs/api.rst b/docs/api.rst index 2a8043c7..0c698409 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -149,7 +149,7 @@ Structs .. automodule:: aiokafka.structs -.. autoclass:: kafka.structs.TopicPartition +.. autoclass:: aiokafka.structs.TopicPartition :members: .. 
autoclass:: aiokafka.structs.RecordMetadata diff --git a/examples/ssl_consume_produce.py b/examples/ssl_consume_produce.py index bfc6d598..9f56f98b 100644 --- a/examples/ssl_consume_produce.py +++ b/examples/ssl_consume_produce.py @@ -1,7 +1,7 @@ import asyncio from aiokafka import AIOKafkaProducer, AIOKafkaConsumer from aiokafka.helpers import create_ssl_context -from kafka.structs import TopicPartition +from aiokafka.structs import TopicPartition context = create_ssl_context( cafile="./ca-cert", # CA used to sign certificate. diff --git a/kafka/__init__.py b/kafka/__init__.py index 65ef3074..08a919c1 100644 --- a/kafka/__init__.py +++ b/kafka/__init__.py @@ -19,7 +19,6 @@ def emit(self, record): from kafka.serializer import Serializer, Deserializer -from kafka.structs import TopicPartition, OffsetAndMetadata __all__ = [ diff --git a/kafka/scram.py b/kafka/scram.py deleted file mode 100644 index 7f003750..00000000 --- a/kafka/scram.py +++ /dev/null @@ -1,81 +0,0 @@ -from __future__ import absolute_import - -import base64 -import hashlib -import hmac -import uuid - -from kafka.vendor import six - - -if six.PY2: - def xor_bytes(left, right): - return bytearray(ord(lb) ^ ord(rb) for lb, rb in zip(left, right)) -else: - def xor_bytes(left, right): - return bytes(lb ^ rb for lb, rb in zip(left, right)) - - -class ScramClient: - MECHANISMS = { - 'SCRAM-SHA-256': hashlib.sha256, - 'SCRAM-SHA-512': hashlib.sha512 - } - - def __init__(self, user, password, mechanism): - self.nonce = str(uuid.uuid4()).replace('-', '') - self.auth_message = '' - self.salted_password = None - self.user = user - self.password = password.encode('utf-8') - self.hashfunc = self.MECHANISMS[mechanism] - self.hashname = ''.join(mechanism.lower().split('-')[1:3]) - self.stored_key = None - self.client_key = None - self.client_signature = None - self.client_proof = None - self.server_key = None - self.server_signature = None - - def first_message(self): - client_first_bare = 'n={},r={}'.format(self.user, self.nonce) - self.auth_message += client_first_bare - return 'n,,' + client_first_bare - - def process_server_first_message(self, server_first_message): - self.auth_message += ',' + server_first_message - params = dict(pair.split('=', 1) for pair in server_first_message.split(',')) - server_nonce = params['r'] - if not server_nonce.startswith(self.nonce): - raise ValueError("Server nonce, did not start with client nonce!") - self.nonce = server_nonce - self.auth_message += ',c=biws,r=' + self.nonce - - salt = base64.b64decode(params['s'].encode('utf-8')) - iterations = int(params['i']) - self.create_salted_password(salt, iterations) - - self.client_key = self.hmac(self.salted_password, b'Client Key') - self.stored_key = self.hashfunc(self.client_key).digest() - self.client_signature = self.hmac(self.stored_key, self.auth_message.encode('utf-8')) - self.client_proof = xor_bytes(self.client_key, self.client_signature) - self.server_key = self.hmac(self.salted_password, b'Server Key') - self.server_signature = self.hmac(self.server_key, self.auth_message.encode('utf-8')) - - def hmac(self, key, msg): - return hmac.new(key, msg, digestmod=self.hashfunc).digest() - - def create_salted_password(self, salt, iterations): - self.salted_password = hashlib.pbkdf2_hmac( - self.hashname, self.password, salt, iterations - ) - - def final_message(self): - return 'c=biws,r={},p={}'.format(self.nonce, base64.b64encode(self.client_proof).decode('utf-8')) - - def process_server_final_message(self, server_final_message): - params = 
dict(pair.split('=', 1) for pair in server_final_message.split(',')) - if self.server_signature != base64.b64decode(params['v'].encode('utf-8')): - raise ValueError("Server sent wrong signature!") - - diff --git a/kafka/structs.py b/kafka/structs.py deleted file mode 100644 index bcb02367..00000000 --- a/kafka/structs.py +++ /dev/null @@ -1,87 +0,0 @@ -""" Other useful structs """ -from __future__ import absolute_import - -from collections import namedtuple - - -"""A topic and partition tuple - -Keyword Arguments: - topic (str): A topic name - partition (int): A partition id -""" -TopicPartition = namedtuple("TopicPartition", - ["topic", "partition"]) - - -"""A Kafka broker metadata used by admin tools. - -Keyword Arguments: - nodeID (int): The Kafka broker id. - host (str): The Kafka broker hostname. - port (int): The Kafka broker port. - rack (str): The rack of the broker, which is used to in rack aware - partition assignment for fault tolerance. - Examples: `RACK1`, `us-east-1d`. Default: None -""" -BrokerMetadata = namedtuple("BrokerMetadata", - ["nodeId", "host", "port", "rack"]) - - -"""A topic partition metadata describing the state in the MetadataResponse. - -Keyword Arguments: - topic (str): The topic name of the partition this metadata relates to. - partition (int): The id of the partition this metadata relates to. - leader (int): The id of the broker that is the leader for the partition. - replicas (List[int]): The ids of all brokers that contain replicas of the - partition. - isr (List[int]): The ids of all brokers that contain in-sync replicas of - the partition. - error (KafkaError): A KafkaError object associated with the request for - this partition metadata. -""" -PartitionMetadata = namedtuple("PartitionMetadata", - ["topic", "partition", "leader", "replicas", "isr", "error"]) - - -"""The Kafka offset commit API - -The Kafka offset commit API allows users to provide additional metadata -(in the form of a string) when an offset is committed. This can be useful -(for example) to store information about which node made the commit, -what time the commit was made, etc. - -Keyword Arguments: - offset (int): The offset to be committed - metadata (str): Non-null metadata -""" -OffsetAndMetadata = namedtuple("OffsetAndMetadata", - # TODO add leaderEpoch: OffsetAndMetadata(offset, leaderEpoch, metadata) - ["offset", "metadata"]) - - -"""An offset and timestamp tuple - -Keyword Arguments: - offset (int): An offset - timestamp (int): The timestamp associated to the offset -""" -OffsetAndTimestamp = namedtuple("OffsetAndTimestamp", - ["offset", "timestamp"]) - -MemberInformation = namedtuple("MemberInformation", - ["member_id", "client_id", "client_host", "member_metadata", "member_assignment"]) - -GroupInformation = namedtuple("GroupInformation", - ["error_code", "group", "state", "protocol_type", "protocol", "members", "authorized_operations"]) - -"""Define retry policy for async producer - -Keyword Arguments: - Limit (int): Number of retries. limit >= 0, 0 means no retries - backoff_ms (int): Milliseconds to backoff. 
- retry_on_timeouts: -""" -RetryOptions = namedtuple("RetryOptions", - ["limit", "backoff_ms", "retry_on_timeouts"]) diff --git a/tests/coordinator/test_assignors.py b/tests/coordinator/test_assignors.py index 5fc8a5f5..9ba3171f 100644 --- a/tests/coordinator/test_assignors.py +++ b/tests/coordinator/test_assignors.py @@ -6,13 +6,13 @@ import pytest -from kafka.structs import TopicPartition from aiokafka.coordinator.assignors.range import RangePartitionAssignor from aiokafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor from aiokafka.coordinator.assignors.sticky.sticky_assignor import ( StickyPartitionAssignor, ) from aiokafka.coordinator.protocol import ConsumerProtocolMemberAssignment +from aiokafka.structs import TopicPartition @pytest.fixture(autouse=True) diff --git a/tests/coordinator/test_partition_movements.py b/tests/coordinator/test_partition_movements.py index d5da876b..d901e4fe 100644 --- a/tests/coordinator/test_partition_movements.py +++ b/tests/coordinator/test_partition_movements.py @@ -1,6 +1,5 @@ -from kafka.structs import TopicPartition - from aiokafka.coordinator.assignors.sticky.partition_movements import PartitionMovements +from aiokafka.structs import TopicPartition def test_empty_movements_are_sticky(): diff --git a/tests/test_message_accumulator.py b/tests/test_message_accumulator.py index 16a70e9f..1c3c5f2d 100644 --- a/tests/test_message_accumulator.py +++ b/tests/test_message_accumulator.py @@ -3,8 +3,6 @@ import unittest from unittest import mock -from kafka.structs import TopicPartition - from aiokafka.cluster import ClusterMetadata from aiokafka.errors import ( KafkaTimeoutError, NotLeaderForPartitionError, LeaderNotAvailableError @@ -13,6 +11,7 @@ from aiokafka.producer.message_accumulator import ( MessageAccumulator, MessageBatch, BatchBuilder ) +from aiokafka.structs import TopicPartition from ._testutil import run_until_complete From acff7a50be97a1ce68a1b591cff0430f659c5c0f Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Sun, 22 Oct 2023 22:00:22 +0300 Subject: [PATCH 14/20] Move oauth --- aiokafka/consumer/consumer.py | 2 +- kafka/oauth/abstract.py => aiokafka/oauth.py | 6 +----- aiokafka/producer/producer.py | 2 +- docs/api.rst | 2 +- kafka/oauth/__init__.py | 3 --- 5 files changed, 4 insertions(+), 11 deletions(-) rename kafka/oauth/abstract.py => aiokafka/oauth.py (85%) delete mode 100644 kafka/oauth/__init__.py diff --git a/aiokafka/consumer/consumer.py b/aiokafka/consumer/consumer.py index c82f6fb0..8b58fb4a 100644 --- a/aiokafka/consumer/consumer.py +++ b/aiokafka/consumer/consumer.py @@ -213,7 +213,7 @@ class AIOKafkaConsumer: sasl_plain_password (str): password for SASL ``PLAIN`` authentication. Default: None sasl_oauth_token_provider (~aiokafka.abc.AbstractTokenProvider): - OAuthBearer token provider instance. (See :mod:`kafka.oauth.abstract`). + OAuthBearer token provider instance. (See :mod:`aiokafka.oauth`). Default: None Note: diff --git a/kafka/oauth/abstract.py b/aiokafka/oauth.py similarity index 85% rename from kafka/oauth/abstract.py rename to aiokafka/oauth.py index 8d89ff51..cc416da1 100644 --- a/kafka/oauth/abstract.py +++ b/aiokafka/oauth.py @@ -1,11 +1,7 @@ -from __future__ import absolute_import - import abc -# This statement is compatible with both Python 2.7 & 3+ -ABC = abc.ABCMeta('ABC', (object,), {'__slots__': ()}) -class AbstractTokenProvider(ABC): +class AbstractTokenProvider(abc.ABC): """ A Token Provider must be used for the SASL OAuthBearer protocol. 
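As a rough illustration of the provider interface relied on by kafka.conn's _build_oauth_client_request (a synchronous token() plus optional extensions()), a hypothetical subclass of the AbstractTokenProvider moved to aiokafka.oauth could look like this; the class name and fixed-token behaviour are assumptions, not part of the change:

    from aiokafka.oauth import AbstractTokenProvider

    class StaticTokenProvider(AbstractTokenProvider):
        """Illustrative only: always serves the same bearer token."""

        def __init__(self, token, extensions=None):
            self._token = token
            self._extensions = extensions or {}

        def token(self):
            # Bearer token string used for the SASL/OAUTHBEARER handshake
            return self._token

        def extensions(self):
            # Optional key/value pairs appended to the initial client request
            return self._extensions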
diff --git a/aiokafka/producer/producer.py b/aiokafka/producer/producer.py index c8d763b6..f5a57b73 100644 --- a/aiokafka/producer/producer.py +++ b/aiokafka/producer/producer.py @@ -167,7 +167,7 @@ class AIOKafkaProducer: Default: :data:`None` sasl_oauth_token_provider (:class:`~aiokafka.abc.AbstractTokenProvider`): OAuthBearer token provider instance. (See - :mod:`kafka.oauth.abstract`). + :mod:`aiokafka.oauth`). Default: :data:`None` Note: diff --git a/docs/api.rst b/docs/api.rst index 0c698409..ba75a088 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -82,7 +82,7 @@ for setup instructions on Broker side. Client configuration is pretty much the same as Java's, consult the ``sasl_*`` options in Consumer and Producer API Reference for more details. -.. automodule:: kafka.oauth.abstract +.. automodule:: aiokafka.oauth Error handling diff --git a/kafka/oauth/__init__.py b/kafka/oauth/__init__.py deleted file mode 100644 index 8c834956..00000000 --- a/kafka/oauth/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from __future__ import absolute_import - -from kafka.oauth.abstract import AbstractTokenProvider From b99433925a48dcf9063b046df45989e01916aa8d Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Mon, 23 Oct 2023 09:24:29 +0300 Subject: [PATCH 15/20] Move metrics --- aiokafka/coordinator/base.py | 4 +- aiokafka/coordinator/consumer.py | 4 +- aiokafka/metrics/__init__.py | 19 + {kafka => aiokafka}/metrics/compound_stat.py | 5 +- {kafka => aiokafka}/metrics/dict_reporter.py | 28 +- {kafka => aiokafka}/metrics/kafka_metric.py | 6 +- {kafka => aiokafka}/metrics/measurable.py | 3 +- .../metrics/measurable_stat.py | 7 +- {kafka => aiokafka}/metrics/metric_config.py | 15 +- {kafka => aiokafka}/metrics/metric_name.py | 24 +- {kafka => aiokafka}/metrics/metrics.py | 60 +- .../metrics/metrics_reporter.py | 3 +- {kafka => aiokafka}/metrics/quota.py | 17 +- {kafka => aiokafka}/metrics/stat.py | 3 +- aiokafka/metrics/stats/__init__.py | 23 + {kafka => aiokafka}/metrics/stats/avg.py | 5 +- {kafka => aiokafka}/metrics/stats/count.py | 5 +- .../metrics/stats/histogram.py | 26 +- {kafka => aiokafka}/metrics/stats/max_stat.py | 9 +- {kafka => aiokafka}/metrics/stats/min_stat.py | 5 +- .../metrics/stats/percentile.py | 3 - .../metrics/stats/percentiles.py | 32 +- {kafka => aiokafka}/metrics/stats/rate.py | 44 +- .../metrics/stats/sampled_stat.py | 12 +- {kafka => aiokafka}/metrics/stats/sensor.py | 42 +- {kafka => aiokafka}/metrics/stats/total.py | 5 +- kafka/__init__.py | 8 - kafka/metrics/__init__.py | 15 - kafka/metrics/stats/__init__.py | 17 - kafka/serializer/__init__.py | 3 - kafka/serializer/abstract.py | 31 - tests/kafka/test_metrics.py | 499 --------------- tests/test_metrics.py | 585 ++++++++++++++++++ 33 files changed, 820 insertions(+), 747 deletions(-) create mode 100644 aiokafka/metrics/__init__.py rename {kafka => aiokafka}/metrics/compound_stat.py (89%) rename {kafka => aiokafka}/metrics/dict_reporter.py (74%) rename {kafka => aiokafka}/metrics/kafka_metric.py (82%) rename {kafka => aiokafka}/metrics/measurable.py (94%) rename {kafka => aiokafka}/metrics/measurable_stat.py (72%) rename {kafka => aiokafka}/metrics/metric_config.py (79%) rename {kafka => aiokafka}/metrics/metric_name.py (85%) rename {kafka => aiokafka}/metrics/metrics.py (86%) rename {kafka => aiokafka}/metrics/metrics_reporter.py (97%) rename {kafka => aiokafka}/metrics/quota.py (70%) rename {kafka => aiokafka}/metrics/stat.py (93%) create mode 100644 aiokafka/metrics/stats/__init__.py rename {kafka => 
aiokafka}/metrics/stats/avg.py (84%) rename {kafka => aiokafka}/metrics/stats/count.py (78%) rename {kafka => aiokafka}/metrics/stats/histogram.py (79%) rename {kafka => aiokafka}/metrics/stats/max_stat.py (65%) rename {kafka => aiokafka}/metrics/stats/min_stat.py (81%) rename {kafka => aiokafka}/metrics/stats/percentile.py (88%) rename {kafka => aiokafka}/metrics/stats/percentiles.py (69%) rename {kafka => aiokafka}/metrics/stats/rate.py (81%) rename {kafka => aiokafka}/metrics/stats/sampled_stat.py (92%) rename {kafka => aiokafka}/metrics/stats/sensor.py (78%) rename {kafka => aiokafka}/metrics/stats/total.py (75%) delete mode 100644 kafka/metrics/__init__.py delete mode 100644 kafka/metrics/stats/__init__.py delete mode 100644 kafka/serializer/__init__.py delete mode 100644 kafka/serializer/abstract.py delete mode 100644 tests/kafka/test_metrics.py create mode 100644 tests/test_metrics.py diff --git a/aiokafka/coordinator/base.py b/aiokafka/coordinator/base.py index 89401c06..4489884d 100644 --- a/aiokafka/coordinator/base.py +++ b/aiokafka/coordinator/base.py @@ -6,8 +6,6 @@ import weakref from kafka.future import Future -from kafka.metrics import AnonMeasurable -from kafka.metrics.stats import Avg, Count, Max, Rate from kafka.protocol.commit import GroupCoordinatorRequest, OffsetCommitRequest from kafka.protocol.group import ( HeartbeatRequest, @@ -17,6 +15,8 @@ ) from aiokafka import errors as Errors +from aiokafka.metrics import AnonMeasurable +from aiokafka.metrics.stats import Avg, Count, Max, Rate from .heartbeat import Heartbeat diff --git a/aiokafka/coordinator/consumer.py b/aiokafka/coordinator/consumer.py index 47af8449..2c7ebb4e 100644 --- a/aiokafka/coordinator/consumer.py +++ b/aiokafka/coordinator/consumer.py @@ -5,12 +5,12 @@ import time from kafka.future import Future -from kafka.metrics import AnonMeasurable -from kafka.metrics.stats import Avg, Count, Max, Rate from kafka.protocol.commit import OffsetCommitRequest, OffsetFetchRequest from kafka.util import WeakMethod import aiokafka.errors as Errors +from aiokafka.metrics import AnonMeasurable +from aiokafka.metrics.stats import Avg, Count, Max, Rate from aiokafka.structs import OffsetAndMetadata, TopicPartition from .base import BaseCoordinator, Generation diff --git a/aiokafka/metrics/__init__.py b/aiokafka/metrics/__init__.py new file mode 100644 index 00000000..ab17ecc9 --- /dev/null +++ b/aiokafka/metrics/__init__.py @@ -0,0 +1,19 @@ +from .compound_stat import NamedMeasurable +from .dict_reporter import DictReporter +from .kafka_metric import KafkaMetric +from .measurable import AnonMeasurable +from .metric_config import MetricConfig +from .metric_name import MetricName +from .metrics import Metrics +from .quota import Quota + +__all__ = [ + "AnonMeasurable", + "DictReporter", + "KafkaMetric", + "MetricConfig", + "MetricName", + "Metrics", + "NamedMeasurable", + "Quota", +] diff --git a/kafka/metrics/compound_stat.py b/aiokafka/metrics/compound_stat.py similarity index 89% rename from kafka/metrics/compound_stat.py rename to aiokafka/metrics/compound_stat.py index ac92480d..f119854d 100644 --- a/kafka/metrics/compound_stat.py +++ b/aiokafka/metrics/compound_stat.py @@ -1,8 +1,6 @@ -from __future__ import absolute_import - import abc -from kafka.metrics.stat import AbstractStat +from .stat import AbstractStat class AbstractCompoundStat(AbstractStat): @@ -11,6 +9,7 @@ class AbstractCompoundStat(AbstractStat): data structure feeds many metrics. 
This is the example for a histogram which has many associated percentiles. """ + __metaclass__ = abc.ABCMeta def stats(self): diff --git a/kafka/metrics/dict_reporter.py b/aiokafka/metrics/dict_reporter.py similarity index 74% rename from kafka/metrics/dict_reporter.py rename to aiokafka/metrics/dict_reporter.py index 0b98fe1e..6ef4defb 100644 --- a/kafka/metrics/dict_reporter.py +++ b/aiokafka/metrics/dict_reporter.py @@ -1,9 +1,7 @@ -from __future__ import absolute_import - import logging import threading -from kafka.metrics.metrics_reporter import AbstractMetricsReporter +from .metrics_reporter import AbstractMetricsReporter logger = logging.getLogger(__name__) @@ -13,9 +11,10 @@ class DictReporter(AbstractMetricsReporter): Store all metrics in a two level dictionary of category > name > metric. """ - def __init__(self, prefix=''): + + def __init__(self, prefix=""): self._lock = threading.Lock() - self._prefix = prefix if prefix else '' # never allow None + self._prefix = prefix if prefix else "" # never allow None self._store = {} def snapshot(self): @@ -29,10 +28,13 @@ def snapshot(self): } } """ - return dict((category, dict((name, metric.value()) - for name, metric in list(metrics.items()))) - for category, metrics in - list(self._store.items())) + return dict( + ( + category, + dict((name, metric.value()) for name, metric in list(metrics.items())), + ) + for category, metrics in list(self._store.items()) + ) def init(self, metrics): for metric in metrics: @@ -71,10 +73,10 @@ def get_category(self, metric): prefix = None, group = 'bar', tags = None returns: 'bar' """ - tags = ','.join('%s=%s' % (k, v) for k, v in - sorted(metric.metric_name.tags.items())) - return '.'.join(x for x in - [self._prefix, metric.metric_name.group, tags] if x) + tags = ",".join( + "%s=%s" % (k, v) for k, v in sorted(metric.metric_name.tags.items()) + ) + return ".".join(x for x in [self._prefix, metric.metric_name.group, tags] if x) def configure(self, configs): pass diff --git a/kafka/metrics/kafka_metric.py b/aiokafka/metrics/kafka_metric.py similarity index 82% rename from kafka/metrics/kafka_metric.py rename to aiokafka/metrics/kafka_metric.py index 9fb8d89f..b5fe4751 100644 --- a/kafka/metrics/kafka_metric.py +++ b/aiokafka/metrics/kafka_metric.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import time @@ -7,9 +5,9 @@ class KafkaMetric(object): # NOTE java constructor takes a lock instance def __init__(self, metric_name, measurable, config): if not metric_name: - raise ValueError('metric_name must be non-empty') + raise ValueError("metric_name must be non-empty") if not measurable: - raise ValueError('measurable must be non-empty') + raise ValueError("measurable must be non-empty") self._metric_name = metric_name self._measurable = measurable self._config = config diff --git a/kafka/metrics/measurable.py b/aiokafka/metrics/measurable.py similarity index 94% rename from kafka/metrics/measurable.py rename to aiokafka/metrics/measurable.py index b06d4d78..545587f5 100644 --- a/kafka/metrics/measurable.py +++ b/aiokafka/metrics/measurable.py @@ -1,10 +1,9 @@ -from __future__ import absolute_import - import abc class AbstractMeasurable(object): """A measurable quantity that can be registered as a metric""" + @abc.abstractmethod def measure(self, config, now): """ diff --git a/kafka/metrics/measurable_stat.py b/aiokafka/metrics/measurable_stat.py similarity index 72% rename from kafka/metrics/measurable_stat.py rename to aiokafka/metrics/measurable_stat.py index 4487adf6..3b6a9838 100644 --- 
a/kafka/metrics/measurable_stat.py +++ b/aiokafka/metrics/measurable_stat.py @@ -1,9 +1,7 @@ -from __future__ import absolute_import - import abc -from kafka.metrics.measurable import AbstractMeasurable -from kafka.metrics.stat import AbstractStat +from .measurable import AbstractMeasurable +from .stat import AbstractStat class AbstractMeasurableStat(AbstractStat, AbstractMeasurable): @@ -13,4 +11,5 @@ class AbstractMeasurableStat(AbstractStat, AbstractMeasurable): This is the interface used for most of the simple statistics such as Avg, Max, Count, etc. """ + __metaclass__ = abc.ABCMeta diff --git a/kafka/metrics/metric_config.py b/aiokafka/metrics/metric_config.py similarity index 79% rename from kafka/metrics/metric_config.py rename to aiokafka/metrics/metric_config.py index 2e55abfc..ddfc08f5 100644 --- a/kafka/metrics/metric_config.py +++ b/aiokafka/metrics/metric_config.py @@ -1,12 +1,17 @@ -from __future__ import absolute_import - import sys class MetricConfig(object): """Configuration values for metrics""" - def __init__(self, quota=None, samples=2, event_window=sys.maxsize, - time_window_ms=30 * 1000, tags=None): + + def __init__( + self, + quota=None, + samples=2, + event_window=sys.maxsize, + time_window_ms=30 * 1000, + tags=None, + ): """ Arguments: quota (Quota, optional): Upper or lower bound of a value. @@ -29,5 +34,5 @@ def samples(self): @samples.setter def samples(self, value): if value < 1: - raise ValueError('The number of samples must be at least 1.') + raise ValueError("The number of samples must be at least 1.") self._samples = value diff --git a/kafka/metrics/metric_name.py b/aiokafka/metrics/metric_name.py similarity index 85% rename from kafka/metrics/metric_name.py rename to aiokafka/metrics/metric_name.py index b5acd166..355446a4 100644 --- a/kafka/metrics/metric_name.py +++ b/aiokafka/metrics/metric_name.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import copy @@ -50,9 +48,9 @@ def __init__(self, name, group, description=None, tags=None): tags (dict, optional): Additional key/val attributes of the metric. 
""" if not (name and group): - raise ValueError('name and group must be non-empty.') + raise ValueError("name and group must be non-empty.") if tags is not None and not isinstance(tags, dict): - raise ValueError('tags must be a dict if present.') + raise ValueError("tags must be a dict if present.") self._name = name self._group = group @@ -93,14 +91,20 @@ def __eq__(self, other): return True if other is None: return False - return (type(self) == type(other) and - self.group == other.group and - self.name == other.name and - self.tags == other.tags) + return ( + type(self) == type(other) + and self.group == other.group + and self.name == other.name + and self.tags == other.tags + ) def __ne__(self, other): return not self.__eq__(other) def __str__(self): - return 'MetricName(name=%s, group=%s, description=%s, tags=%s)' % ( - self.name, self.group, self.description, self.tags) + return "MetricName(name=%s, group=%s, description=%s, tags=%s)" % ( + self.name, + self.group, + self.description, + self.tags, + ) diff --git a/kafka/metrics/metrics.py b/aiokafka/metrics/metrics.py similarity index 86% rename from kafka/metrics/metrics.py rename to aiokafka/metrics/metrics.py index 2c53488f..8a2de20e 100644 --- a/kafka/metrics/metrics.py +++ b/aiokafka/metrics/metrics.py @@ -1,12 +1,13 @@ -from __future__ import absolute_import - import logging import sys import time import threading -from kafka.metrics import AnonMeasurable, KafkaMetric, MetricConfig, MetricName -from kafka.metrics.stats import Sensor +from .kafka_metric import KafkaMetric +from .measurable import AnonMeasurable +from .metric_config import MetricConfig +from .metric_name import MetricName +from .stats import Sensor logger = logging.getLogger(__name__) @@ -34,8 +35,8 @@ class Metrics(object): # as messages are sent we record the sizes sensor.record(message_size); """ - def __init__(self, default_config=None, reporters=None, - enable_expiration=False): + + def __init__(self, default_config=None, reporters=None, enable_expiration=False): """ Create a metrics repository with a default config, given metric reporters and the ability to expire eligible sensors @@ -57,19 +58,24 @@ def __init__(self, default_config=None, reporters=None, reporter.init([]) if enable_expiration: + def expire_loop(): while True: # delay 30 seconds time.sleep(30) self.ExpireSensorTask.run(self) + metrics_scheduler = threading.Thread(target=expire_loop) # Creating a daemon thread to not block shutdown metrics_scheduler.daemon = True metrics_scheduler.start() - self.add_metric(self.metric_name('count', 'kafka-metrics-count', - 'total number of registered metrics'), - AnonMeasurable(lambda config, now: len(self._metrics))) + self.add_metric( + self.metric_name( + "count", "kafka-metrics-count", "total number of registered metrics" + ), + AnonMeasurable(lambda config, now: len(self._metrics)), + ) @property def config(self): @@ -82,7 +88,7 @@ def metrics(self): """ return self._metrics - def metric_name(self, name, group, description='', tags=None): + def metric_name(self, name, group, description="", tags=None): """ Create a MetricName with the given name, group, description and tags, plus default tags specified in the metric configuration. 
@@ -113,12 +119,16 @@ def get_sensor(self, name): Sensor: The sensor or None if no such sensor exists """ if not name: - raise ValueError('name must be non-empty') + raise ValueError("name must be non-empty") return self._sensors.get(name, None) - def sensor(self, name, config=None, - inactive_sensor_expiration_time_seconds=sys.maxsize, - parents=None): + def sensor( + self, + name, + config=None, + inactive_sensor_expiration_time_seconds=sys.maxsize, + parents=None, + ): """ Get or create a sensor with the given unique name and zero or more parent sensors. All parent sensors will receive every value @@ -143,8 +153,13 @@ def sensor(self, name, config=None, with self._lock: sensor = self.get_sensor(name) if not sensor: - sensor = Sensor(self, name, parents, config or self.config, - inactive_sensor_expiration_time_seconds) + sensor = Sensor( + self, + name, + parents, + config or self.config, + inactive_sensor_expiration_time_seconds, + ) self._sensors[name] = sensor if parents: for parent in parents: @@ -153,7 +168,7 @@ def sensor(self, name, config=None, children = [] self._children_sensors[parent] = children children.append(sensor) - logger.debug('Added sensor with name %s', name) + logger.debug("Added sensor with name %s", name) return sensor def remove_sensor(self, name): @@ -172,7 +187,7 @@ def remove_sensor(self, name): if val and val == sensor: for metric in sensor.metrics: self.remove_metric(metric.metric_name) - logger.debug('Removed sensor with name %s', name) + logger.debug("Removed sensor with name %s", name) child_sensors = self._children_sensors.pop(sensor, None) if child_sensors: for child_sensor in child_sensors: @@ -224,8 +239,10 @@ def add_reporter(self, reporter): def register_metric(self, metric): with self._lock: if metric.metric_name in self.metrics: - raise ValueError('A metric named "%s" already exists, cannot' - ' register another one.' % (metric.metric_name,)) + raise ValueError( + 'A metric named "%s" already exists, cannot' + " register another one." % (metric.metric_name,) + ) self.metrics[metric.metric_name] = metric for reporter in self._reporters: reporter.metric_change(metric) @@ -235,6 +252,7 @@ class ExpireSensorTask(object): This iterates over every Sensor and triggers a remove_sensor if it has expired. Package private for testing """ + @staticmethod def run(metrics): items = list(metrics._sensors.items()) @@ -250,7 +268,7 @@ def run(metrics): # concern and thus not necessary to optimize with sensor._lock: if sensor.has_expired(): - logger.debug('Removing expired sensor %s', name) + logger.debug("Removing expired sensor %s", name) metrics.remove_sensor(name) def close(self): diff --git a/kafka/metrics/metrics_reporter.py b/aiokafka/metrics/metrics_reporter.py similarity index 97% rename from kafka/metrics/metrics_reporter.py rename to aiokafka/metrics/metrics_reporter.py index d8bd12b3..9cc708ac 100644 --- a/kafka/metrics/metrics_reporter.py +++ b/aiokafka/metrics/metrics_reporter.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import abc @@ -8,6 +6,7 @@ class AbstractMetricsReporter(object): An abstract class to allow things to listen as new metrics are created so they can be reported. 
""" + __metaclass__ = abc.ABCMeta @abc.abstractmethod diff --git a/kafka/metrics/quota.py b/aiokafka/metrics/quota.py similarity index 70% rename from kafka/metrics/quota.py rename to aiokafka/metrics/quota.py index 4d1b0d6c..6220bcf4 100644 --- a/kafka/metrics/quota.py +++ b/aiokafka/metrics/quota.py @@ -1,8 +1,6 @@ -from __future__ import absolute_import - - class Quota(object): """An upper or lower bound for metrics""" + def __init__(self, bound, is_upper): self._bound = bound self._upper = is_upper @@ -23,8 +21,9 @@ def bound(self): return self._bound def is_acceptable(self, value): - return ((self.is_upper_bound() and value <= self.bound) or - (not self.is_upper_bound() and value >= self.bound)) + return (self.is_upper_bound() and value <= self.bound) or ( + not self.is_upper_bound() and value >= self.bound + ) def __hash__(self): prime = 31 @@ -34,9 +33,11 @@ def __hash__(self): def __eq__(self, other): if self is other: return True - return (type(self) == type(other) and - self.bound == other.bound and - self.is_upper_bound() == other.is_upper_bound()) + return ( + type(self) == type(other) + and self.bound == other.bound + and self.is_upper_bound() == other.is_upper_bound() + ) def __ne__(self, other): return not self.__eq__(other) diff --git a/kafka/metrics/stat.py b/aiokafka/metrics/stat.py similarity index 93% rename from kafka/metrics/stat.py rename to aiokafka/metrics/stat.py index 9fd2f01e..60d778fc 100644 --- a/kafka/metrics/stat.py +++ b/aiokafka/metrics/stat.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import abc @@ -8,6 +6,7 @@ class AbstractStat(object): An AbstractStat is a quantity such as average, max, etc that is computed off the stream of updates to a sensor """ + __metaclass__ = abc.ABCMeta @abc.abstractmethod diff --git a/aiokafka/metrics/stats/__init__.py b/aiokafka/metrics/stats/__init__.py new file mode 100644 index 00000000..678bb691 --- /dev/null +++ b/aiokafka/metrics/stats/__init__.py @@ -0,0 +1,23 @@ +from .avg import Avg +from .count import Count +from .histogram import Histogram +from .max_stat import Max +from .min_stat import Min +from .percentile import Percentile +from .percentiles import Percentiles +from .rate import Rate +from .sensor import Sensor +from .total import Total + +__all__ = [ + "Avg", + "Count", + "Histogram", + "Max", + "Min", + "Percentile", + "Percentiles", + "Rate", + "Sensor", + "Total", +] diff --git a/kafka/metrics/stats/avg.py b/aiokafka/metrics/stats/avg.py similarity index 84% rename from kafka/metrics/stats/avg.py rename to aiokafka/metrics/stats/avg.py index cfbaec30..7b7306c4 100644 --- a/kafka/metrics/stats/avg.py +++ b/aiokafka/metrics/stats/avg.py @@ -1,12 +1,11 @@ -from __future__ import absolute_import - -from kafka.metrics.stats.sampled_stat import AbstractSampledStat +from .sampled_stat import AbstractSampledStat class Avg(AbstractSampledStat): """ An AbstractSampledStat that maintains a simple average over its samples. 
""" + def __init__(self): super(Avg, self).__init__(0.0) diff --git a/kafka/metrics/stats/count.py b/aiokafka/metrics/stats/count.py similarity index 78% rename from kafka/metrics/stats/count.py rename to aiokafka/metrics/stats/count.py index 6e0a2d54..804bcdf0 100644 --- a/kafka/metrics/stats/count.py +++ b/aiokafka/metrics/stats/count.py @@ -1,12 +1,11 @@ -from __future__ import absolute_import - -from kafka.metrics.stats.sampled_stat import AbstractSampledStat +from .sampled_stat import AbstractSampledStat class Count(AbstractSampledStat): """ An AbstractSampledStat that maintains a simple count of what it has seen. """ + def __init__(self): super(Count, self).__init__(0.0) diff --git a/kafka/metrics/stats/histogram.py b/aiokafka/metrics/stats/histogram.py similarity index 79% rename from kafka/metrics/stats/histogram.py rename to aiokafka/metrics/stats/histogram.py index ecc6c9db..9f9f7cff 100644 --- a/kafka/metrics/stats/histogram.py +++ b/aiokafka/metrics/stats/histogram.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import math @@ -15,14 +13,14 @@ def record(self, value): def value(self, quantile): if self._count == 0.0: - return float('NaN') + return float("NaN") _sum = 0.0 quant = float(quantile) for i, value in enumerate(self._hist[:-1]): _sum += value if _sum / self._count > quant: return self._bin_scheme.from_bin(i) - return float('inf') + return float("inf") @property def counts(self): @@ -34,15 +32,17 @@ def clear(self): self._count = 0 def __str__(self): - values = ['%.10f:%.0f' % (self._bin_scheme.from_bin(i), value) for - i, value in enumerate(self._hist[:-1])] - values.append('%s:%s' % (float('inf'), self._hist[-1])) - return '{%s}' % ','.join(values) + values = [ + "%.10f:%.0f" % (self._bin_scheme.from_bin(i), value) + for i, value in enumerate(self._hist[:-1]) + ] + values.append("%s:%s" % (float("inf"), self._hist[-1])) + return "{%s}" % ",".join(values) class ConstantBinScheme(object): def __init__(self, bins, min_val, max_val): if bins < 2: - raise ValueError('Must have at least 2 bins.') + raise ValueError("Must have at least 2 bins.") self._min = float(min_val) self._max = float(max_val) self._bins = int(bins) @@ -54,9 +54,9 @@ def bins(self): def from_bin(self, b): if b == 0: - return float('-inf') + return float("-inf") elif b == self._bins - 1: - return float('inf') + return float("inf") else: return self._min + (b - 1) * self._bucket_width @@ -80,14 +80,14 @@ def bins(self): def from_bin(self, b): if b == self._bins - 1: - return float('inf') + return float("inf") else: unscaled = (b * (b + 1.0)) / 2.0 return unscaled * self._scale def to_bin(self, x): if x < 0.0: - raise ValueError('Values less than 0.0 not accepted.') + raise ValueError("Values less than 0.0 not accepted.") elif x > self._max: return self._bins - 1 else: diff --git a/kafka/metrics/stats/max_stat.py b/aiokafka/metrics/stats/max_stat.py similarity index 65% rename from kafka/metrics/stats/max_stat.py rename to aiokafka/metrics/stats/max_stat.py index 08aebddf..af7874d1 100644 --- a/kafka/metrics/stats/max_stat.py +++ b/aiokafka/metrics/stats/max_stat.py @@ -1,17 +1,16 @@ -from __future__ import absolute_import - -from kafka.metrics.stats.sampled_stat import AbstractSampledStat +from .sampled_stat import AbstractSampledStat class Max(AbstractSampledStat): """An AbstractSampledStat that gives the max over its samples.""" + def __init__(self): - super(Max, self).__init__(float('-inf')) + super(Max, self).__init__(float("-inf")) def update(self, sample, config, value, now): 
sample.value = max(sample.value, value) def combine(self, samples, config, now): if not samples: - return float('-inf') + return float("-inf") return float(max(sample.value for sample in samples)) diff --git a/kafka/metrics/stats/min_stat.py b/aiokafka/metrics/stats/min_stat.py similarity index 81% rename from kafka/metrics/stats/min_stat.py rename to aiokafka/metrics/stats/min_stat.py index 072106d8..e826afac 100644 --- a/kafka/metrics/stats/min_stat.py +++ b/aiokafka/metrics/stats/min_stat.py @@ -1,12 +1,11 @@ -from __future__ import absolute_import - import sys -from kafka.metrics.stats.sampled_stat import AbstractSampledStat +from .sampled_stat import AbstractSampledStat class Min(AbstractSampledStat): """An AbstractSampledStat that gives the min over its samples.""" + def __init__(self): super(Min, self).__init__(float(sys.maxsize)) diff --git a/kafka/metrics/stats/percentile.py b/aiokafka/metrics/stats/percentile.py similarity index 88% rename from kafka/metrics/stats/percentile.py rename to aiokafka/metrics/stats/percentile.py index 3a86a84a..723b9e6a 100644 --- a/kafka/metrics/stats/percentile.py +++ b/aiokafka/metrics/stats/percentile.py @@ -1,6 +1,3 @@ -from __future__ import absolute_import - - class Percentile(object): def __init__(self, metric_name, percentile): self._metric_name = metric_name diff --git a/kafka/metrics/stats/percentiles.py b/aiokafka/metrics/stats/percentiles.py similarity index 69% rename from kafka/metrics/stats/percentiles.py rename to aiokafka/metrics/stats/percentiles.py index 6d702e80..70e1ca6a 100644 --- a/kafka/metrics/stats/percentiles.py +++ b/aiokafka/metrics/stats/percentiles.py @@ -1,9 +1,8 @@ -from __future__ import absolute_import +from aiokafka.metrics.measurable import AnonMeasurable +from aiokafka.metrics.compound_stat import AbstractCompoundStat, NamedMeasurable -from kafka.metrics import AnonMeasurable, NamedMeasurable -from kafka.metrics.compound_stat import AbstractCompoundStat -from kafka.metrics.stats import Histogram -from kafka.metrics.stats.sampled_stat import AbstractSampledStat +from .histogram import Histogram +from .sampled_stat import AbstractSampledStat class BucketSizing(object): @@ -13,28 +12,29 @@ class BucketSizing(object): class Percentiles(AbstractSampledStat, AbstractCompoundStat): """A compound stat that reports one or more percentiles""" - def __init__(self, size_in_bytes, bucketing, max_val, min_val=0.0, - percentiles=None): + + def __init__( + self, size_in_bytes, bucketing, max_val, min_val=0.0, percentiles=None + ): super(Percentiles, self).__init__(0.0) self._percentiles = percentiles or [] self._buckets = int(size_in_bytes / 4) if bucketing == BucketSizing.CONSTANT: - self._bin_scheme = Histogram.ConstantBinScheme(self._buckets, - min_val, max_val) + self._bin_scheme = Histogram.ConstantBinScheme( + self._buckets, min_val, max_val + ) elif bucketing == BucketSizing.LINEAR: if min_val != 0.0: - raise ValueError('Linear bucket sizing requires min_val' - ' to be 0.0.') + raise ValueError("Linear bucket sizing requires min_val" " to be 0.0.") self.bin_scheme = Histogram.LinearBinScheme(self._buckets, max_val) else: - ValueError('Unknown bucket type: %s' % (bucketing,)) + ValueError("Unknown bucket type: %s" % (bucketing,)) def stats(self): measurables = [] def make_measure_fn(pct): - return lambda config, now: self.value(config, now, - pct / 100.0) + return lambda config, now: self.value(config, now, pct / 100.0) for percentile in self._percentiles: measure_fn = make_measure_fn(percentile.percentile) @@ -46,7 +46,7 
@@ def value(self, config, now, quantile): self.purge_obsolete_samples(config, now) count = sum(sample.event_count for sample in self._samples) if count == 0.0: - return float('NaN') + return float("NaN") sum_val = 0.0 quant = float(quantile) for b in range(self._buckets): @@ -56,7 +56,7 @@ def value(self, config, now, quantile): sum_val += hist[b] if sum_val / count > quant: return self._bin_scheme.from_bin(b) - return float('inf') + return float("inf") def combine(self, samples, config, now): return self.value(config, now, 0.5) diff --git a/kafka/metrics/stats/rate.py b/aiokafka/metrics/stats/rate.py similarity index 81% rename from kafka/metrics/stats/rate.py rename to aiokafka/metrics/stats/rate.py index 68393fbf..2722b3a7 100644 --- a/kafka/metrics/stats/rate.py +++ b/aiokafka/metrics/stats/rate.py @@ -1,27 +1,25 @@ -from __future__ import absolute_import - -from kafka.metrics.measurable_stat import AbstractMeasurableStat -from kafka.metrics.stats.sampled_stat import AbstractSampledStat +from aiokafka.metrics.measurable_stat import AbstractMeasurableStat +from aiokafka.metrics.stats.sampled_stat import AbstractSampledStat class TimeUnit(object): _names = { - 'nanosecond': 0, - 'microsecond': 1, - 'millisecond': 2, - 'second': 3, - 'minute': 4, - 'hour': 5, - 'day': 6, + "nanosecond": 0, + "microsecond": 1, + "millisecond": 2, + "second": 3, + "minute": 4, + "hour": 5, + "day": 6, } - NANOSECONDS = _names['nanosecond'] - MICROSECONDS = _names['microsecond'] - MILLISECONDS = _names['millisecond'] - SECONDS = _names['second'] - MINUTES = _names['minute'] - HOURS = _names['hour'] - DAYS = _names['day'] + NANOSECONDS = _names["nanosecond"] + MICROSECONDS = _names["microsecond"] + MILLISECONDS = _names["millisecond"] + SECONDS = _names["second"] + MINUTES = _names["minute"] + HOURS = _names["hour"] + DAYS = _names["day"] @staticmethod def get_name(time_unit): @@ -37,6 +35,7 @@ class Rate(AbstractMeasurableStat): occurrences (e.g. the count of values measured over the time interval) or other such values. 
""" + def __init__(self, time_unit=TimeUnit.SECONDS, sampled_stat=None): self._stat = sampled_stat or SampledTotal() self._unit = time_unit @@ -80,8 +79,9 @@ def window_size(self, config, now): # If the available windows are less than the minimum required, # add the difference to the totalElapsedTime if num_full_windows < min_full_windows: - total_elapsed_time_ms += ((min_full_windows - num_full_windows) * - config.time_window_ms) + total_elapsed_time_ms += ( + min_full_windows - num_full_windows + ) * config.time_window_ms return total_elapsed_time_ms @@ -101,13 +101,13 @@ def convert(self, time_ms): elif self._unit == TimeUnit.DAYS: return time_ms / (24.0 * 60.0 * 60.0 * 1000.0) else: - raise ValueError('Unknown unit: %s' % (self._unit,)) + raise ValueError("Unknown unit: %s" % (self._unit,)) class SampledTotal(AbstractSampledStat): def __init__(self, initial_value=None): if initial_value is not None: - raise ValueError('initial_value cannot be set on SampledTotal') + raise ValueError("initial_value cannot be set on SampledTotal") super(SampledTotal, self).__init__(0.0) def update(self, sample, config, value, time_ms): diff --git a/kafka/metrics/stats/sampled_stat.py b/aiokafka/metrics/stats/sampled_stat.py similarity index 92% rename from kafka/metrics/stats/sampled_stat.py rename to aiokafka/metrics/stats/sampled_stat.py index c41b14bb..7bb86f37 100644 --- a/kafka/metrics/stats/sampled_stat.py +++ b/aiokafka/metrics/stats/sampled_stat.py @@ -1,8 +1,6 @@ -from __future__ import absolute_import - import abc -from kafka.metrics.measurable_stat import AbstractMeasurableStat +from aiokafka.metrics.measurable_stat import AbstractMeasurableStat class AbstractSampledStat(AbstractMeasurableStat): @@ -20,6 +18,7 @@ class AbstractSampledStat(AbstractMeasurableStat): Subclasses of this class define different statistics measured using this basic pattern. """ + __metaclass__ = abc.ABCMeta def __init__(self, initial_value): @@ -84,7 +83,6 @@ def _advance(self, config, time_ms): return sample class Sample(object): - def __init__(self, initial_value, now): self.initial_value = initial_value self.event_count = 0 @@ -97,5 +95,7 @@ def reset(self, now): self.value = self.initial_value def is_complete(self, time_ms, config): - return (time_ms - self.last_window_ms >= config.time_window_ms or - self.event_count >= config.event_window) + return ( + time_ms - self.last_window_ms >= config.time_window_ms + or self.event_count >= config.event_window + ) diff --git a/kafka/metrics/stats/sensor.py b/aiokafka/metrics/stats/sensor.py similarity index 78% rename from kafka/metrics/stats/sensor.py rename to aiokafka/metrics/stats/sensor.py index a0dbe4c1..81dd9691 100644 --- a/kafka/metrics/stats/sensor.py +++ b/aiokafka/metrics/stats/sensor.py @@ -1,10 +1,8 @@ -from __future__ import absolute_import - import threading import time from aiokafka.errors import QuotaViolationError -from kafka.metrics import KafkaMetric +from aiokafka.metrics.kafka_metric import KafkaMetric class Sensor(object): @@ -15,10 +13,12 @@ class Sensor(object): the `record(double)` api and would maintain a set of metrics about request sizes such as the average or max. 
""" - def __init__(self, registry, name, parents, config, - inactive_sensor_expiration_time_seconds): + + def __init__( + self, registry, name, parents, config, inactive_sensor_expiration_time_seconds + ): if not name: - raise ValueError('name must be non-empty') + raise ValueError("name must be non-empty") self._lock = threading.RLock() self._registry = registry self._name = name @@ -27,15 +27,17 @@ def __init__(self, registry, name, parents, config, self._stats = [] self._config = config self._inactive_sensor_expiration_time_ms = ( - inactive_sensor_expiration_time_seconds * 1000) + inactive_sensor_expiration_time_seconds * 1000 + ) self._last_record_time = time.time() * 1000 self._check_forest(set()) def _check_forest(self, sensors): """Validate that this sensor doesn't end up referencing itself.""" if self in sensors: - raise ValueError('Circular dependency in sensors: %s is its own' - 'parent.' % (self.name,)) + raise ValueError( + "Circular dependency in sensors: %s is its own" "parent." % (self.name,) + ) sensors.add(self) for parent in self._parents: parent._check_forest(sensors) @@ -84,11 +86,11 @@ def _check_quotas(self, time_ms): if metric.config and metric.config.quota: value = metric.value(time_ms) if not metric.config.quota.is_acceptable(value): - raise QuotaViolationError("'%s' violated quota. Actual: " - "%d, Threshold: %d" % - (metric.metric_name, - value, - metric.config.quota.bound)) + raise QuotaViolationError( + "'%s' violated quota. Actual: " + "%d, Threshold: %d" + % (metric.metric_name, value, metric.config.quota.bound) + ) def add_compound(self, compound_stat, config=None): """ @@ -102,11 +104,12 @@ def add_compound(self, compound_stat, config=None): for this sensor. """ if not compound_stat: - raise ValueError('compound stat must be non-empty') + raise ValueError("compound stat must be non-empty") self._stats.append(compound_stat) for named_measurable in compound_stat.stats(): - metric = KafkaMetric(named_measurable.name, named_measurable.stat, - config or self._config) + metric = KafkaMetric( + named_measurable.name, named_measurable.stat, config or self._config + ) self._registry.register_metric(metric) self._metrics.append(metric) @@ -130,5 +133,6 @@ def has_expired(self): """ Return True if the Sensor is eligible for removal due to inactivity. 
""" - return ((time.time() * 1000 - self._last_record_time) > - self._inactive_sensor_expiration_time_ms) + return ( + time.time() * 1000 - self._last_record_time + ) > self._inactive_sensor_expiration_time_ms diff --git a/kafka/metrics/stats/total.py b/aiokafka/metrics/stats/total.py similarity index 75% rename from kafka/metrics/stats/total.py rename to aiokafka/metrics/stats/total.py index 5b3bb87f..fa44a807 100644 --- a/kafka/metrics/stats/total.py +++ b/aiokafka/metrics/stats/total.py @@ -1,10 +1,9 @@ -from __future__ import absolute_import - -from kafka.metrics.measurable_stat import AbstractMeasurableStat +from aiokafka.metrics.measurable_stat import AbstractMeasurableStat class Total(AbstractMeasurableStat): """An un-windowed cumulative total maintained over all time.""" + def __init__(self, value=0.0): self._total = value diff --git a/kafka/__init__.py b/kafka/__init__.py index 08a919c1..16d29b2e 100644 --- a/kafka/__init__.py +++ b/kafka/__init__.py @@ -16,11 +16,3 @@ def emit(self, record): pass logging.getLogger(__name__).addHandler(NullHandler()) - - -from kafka.serializer import Serializer, Deserializer - - -__all__ = [ - 'BrokerConnection', 'ConsumerRebalanceListener', -] diff --git a/kafka/metrics/__init__.py b/kafka/metrics/__init__.py deleted file mode 100644 index 2a62d633..00000000 --- a/kafka/metrics/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from __future__ import absolute_import - -from kafka.metrics.compound_stat import NamedMeasurable -from kafka.metrics.dict_reporter import DictReporter -from kafka.metrics.kafka_metric import KafkaMetric -from kafka.metrics.measurable import AnonMeasurable -from kafka.metrics.metric_config import MetricConfig -from kafka.metrics.metric_name import MetricName -from kafka.metrics.metrics import Metrics -from kafka.metrics.quota import Quota - -__all__ = [ - 'AnonMeasurable', 'DictReporter', 'KafkaMetric', 'MetricConfig', - 'MetricName', 'Metrics', 'NamedMeasurable', 'Quota' -] diff --git a/kafka/metrics/stats/__init__.py b/kafka/metrics/stats/__init__.py deleted file mode 100644 index a3d535df..00000000 --- a/kafka/metrics/stats/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import absolute_import - -from kafka.metrics.stats.avg import Avg -from kafka.metrics.stats.count import Count -from kafka.metrics.stats.histogram import Histogram -from kafka.metrics.stats.max_stat import Max -from kafka.metrics.stats.min_stat import Min -from kafka.metrics.stats.percentile import Percentile -from kafka.metrics.stats.percentiles import Percentiles -from kafka.metrics.stats.rate import Rate -from kafka.metrics.stats.sensor import Sensor -from kafka.metrics.stats.total import Total - -__all__ = [ - 'Avg', 'Count', 'Histogram', 'Max', 'Min', 'Percentile', 'Percentiles', - 'Rate', 'Sensor', 'Total' -] diff --git a/kafka/serializer/__init__.py b/kafka/serializer/__init__.py deleted file mode 100644 index 90cd93ab..00000000 --- a/kafka/serializer/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from __future__ import absolute_import - -from kafka.serializer.abstract import Serializer, Deserializer diff --git a/kafka/serializer/abstract.py b/kafka/serializer/abstract.py deleted file mode 100644 index 18ad8d69..00000000 --- a/kafka/serializer/abstract.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import absolute_import - -import abc - - -class Serializer(object): - __meta__ = abc.ABCMeta - - def __init__(self, **config): - pass - - @abc.abstractmethod - def serialize(self, topic, value): - pass - - def close(self): - pass - - -class 
Deserializer(object): - __meta__ = abc.ABCMeta - - def __init__(self, **config): - pass - - @abc.abstractmethod - def deserialize(self, topic, bytes_): - pass - - def close(self): - pass diff --git a/tests/kafka/test_metrics.py b/tests/kafka/test_metrics.py deleted file mode 100644 index 64cc1fc1..00000000 --- a/tests/kafka/test_metrics.py +++ /dev/null @@ -1,499 +0,0 @@ -import sys -import time - -import pytest - -from aiokafka.errors import QuotaViolationError -from kafka.metrics import DictReporter, MetricConfig, MetricName, Metrics, Quota -from kafka.metrics.measurable import AbstractMeasurable -from kafka.metrics.stats import (Avg, Count, Max, Min, Percentile, Percentiles, - Rate, Total) -from kafka.metrics.stats.percentiles import BucketSizing -from kafka.metrics.stats.rate import TimeUnit - -EPS = 0.000001 - - -@pytest.fixture -def time_keeper(): - return TimeKeeper() - - -@pytest.fixture -def config(): - return MetricConfig() - - -@pytest.fixture -def reporter(): - return DictReporter() - - -@pytest.fixture -def metrics(request, config, reporter): - metrics = Metrics(config, [reporter], enable_expiration=True) - yield metrics - metrics.close() - - -def test_MetricName(): - # The Java test only cover the differences between the deprecated - # constructors, so I'm skipping them but doing some other basic testing. - - # In short, metrics should be equal IFF their name, group, and tags are - # the same. Descriptions do not matter. - name1 = MetricName('name', 'group', 'A metric.', {'a': 1, 'b': 2}) - name2 = MetricName('name', 'group', 'A description.', {'a': 1, 'b': 2}) - assert name1 == name2 - - name1 = MetricName('name', 'group', tags={'a': 1, 'b': 2}) - name2 = MetricName('name', 'group', tags={'a': 1, 'b': 2}) - assert name1 == name2 - - name1 = MetricName('foo', 'group') - name2 = MetricName('name', 'group') - assert name1 != name2 - - name1 = MetricName('name', 'foo') - name2 = MetricName('name', 'group') - assert name1 != name2 - - # name and group must be non-empty. Everything else is optional. - with pytest.raises(Exception): - MetricName('', 'group') - with pytest.raises(Exception): - MetricName('name', None) - # tags must be a dict if supplied - with pytest.raises(Exception): - MetricName('name', 'group', tags=set()) - - # Because of the implementation of __eq__ and __hash__, the values of - # a MetricName cannot be mutable. 
- tags = {'a': 1} - name = MetricName('name', 'group', 'description', tags=tags) - with pytest.raises(AttributeError): - name.name = 'new name' - with pytest.raises(AttributeError): - name.group = 'new name' - with pytest.raises(AttributeError): - name.tags = {} - # tags is a copy, so the instance isn't altered - name.tags['b'] = 2 - assert name.tags == tags - - -def test_simple_stats(mocker, time_keeper, config, metrics): - mocker.patch('time.time', side_effect=time_keeper.time) - - measurable = ConstantMeasurable() - - metrics.add_metric(metrics.metric_name('direct.measurable', 'grp1', - 'The fraction of time an appender waits for space allocation.'), - measurable) - sensor = metrics.sensor('test.sensor') - sensor.add(metrics.metric_name('test.avg', 'grp1'), Avg()) - sensor.add(metrics.metric_name('test.max', 'grp1'), Max()) - sensor.add(metrics.metric_name('test.min', 'grp1'), Min()) - sensor.add(metrics.metric_name('test.rate', 'grp1'), Rate(TimeUnit.SECONDS)) - sensor.add(metrics.metric_name('test.occurences', 'grp1'),Rate(TimeUnit.SECONDS, Count())) - sensor.add(metrics.metric_name('test.count', 'grp1'), Count()) - percentiles = [Percentile(metrics.metric_name('test.median', 'grp1'), 50.0), - Percentile(metrics.metric_name('test.perc99_9', 'grp1'), 99.9)] - sensor.add_compound(Percentiles(100, BucketSizing.CONSTANT, 100, -100, - percentiles=percentiles)) - - sensor2 = metrics.sensor('test.sensor2') - sensor2.add(metrics.metric_name('s2.total', 'grp1'), Total()) - sensor2.record(5.0) - - sum_val = 0 - count = 10 - for i in range(count): - sensor.record(i) - sum_val += i - - # prior to any time passing - elapsed_secs = (config.time_window_ms * (config.samples - 1)) / 1000.0 - assert abs(count / elapsed_secs - - metrics.metrics.get(metrics.metric_name('test.occurences', 'grp1')).value()) \ - < EPS, 'Occurrences(0...%d) = %f' % (count, count / elapsed_secs) - - # pretend 2 seconds passed... 
- sleep_time_seconds = 2.0 - time_keeper.sleep(sleep_time_seconds) - elapsed_secs += sleep_time_seconds - - assert abs(5.0 - metrics.metrics.get(metrics.metric_name('s2.total', 'grp1')).value()) \ - < EPS, 's2 reflects the constant value' - assert abs(4.5 - metrics.metrics.get(metrics.metric_name('test.avg', 'grp1')).value()) \ - < EPS, 'Avg(0...9) = 4.5' - assert abs((count - 1) - metrics.metrics.get(metrics.metric_name('test.max', 'grp1')).value()) \ - < EPS, 'Max(0...9) = 9' - assert abs(0.0 - metrics.metrics.get(metrics.metric_name('test.min', 'grp1')).value()) \ - < EPS, 'Min(0...9) = 0' - assert abs((sum_val / elapsed_secs) - metrics.metrics.get(metrics.metric_name('test.rate', 'grp1')).value()) \ - < EPS, 'Rate(0...9) = 1.40625' - assert abs((count / elapsed_secs) - metrics.metrics.get(metrics.metric_name('test.occurences', 'grp1')).value()) \ - < EPS, 'Occurrences(0...%d) = %f' % (count, count / elapsed_secs) - assert abs(count - metrics.metrics.get(metrics.metric_name('test.count', 'grp1')).value()) \ - < EPS, 'Count(0...9) = 10' - - -def test_hierarchical_sensors(metrics): - parent1 = metrics.sensor('test.parent1') - parent1.add(metrics.metric_name('test.parent1.count', 'grp1'), Count()) - parent2 = metrics.sensor('test.parent2') - parent2.add(metrics.metric_name('test.parent2.count', 'grp1'), Count()) - child1 = metrics.sensor('test.child1', parents=[parent1, parent2]) - child1.add(metrics.metric_name('test.child1.count', 'grp1'), Count()) - child2 = metrics.sensor('test.child2', parents=[parent1]) - child2.add(metrics.metric_name('test.child2.count', 'grp1'), Count()) - grandchild = metrics.sensor('test.grandchild', parents=[child1]) - grandchild.add(metrics.metric_name('test.grandchild.count', 'grp1'), Count()) - - # increment each sensor one time - parent1.record() - parent2.record() - child1.record() - child2.record() - grandchild.record() - - p1 = parent1.metrics[0].value() - p2 = parent2.metrics[0].value() - c1 = child1.metrics[0].value() - c2 = child2.metrics[0].value() - gc = grandchild.metrics[0].value() - - # each metric should have a count equal to one + its children's count - assert 1.0 == gc - assert 1.0 + gc == c1 - assert 1.0 == c2 - assert 1.0 + c1 == p2 - assert 1.0 + c1 + c2 == p1 - assert [child1, child2] == metrics._children_sensors.get(parent1) - assert [child1] == metrics._children_sensors.get(parent2) - assert metrics._children_sensors.get(grandchild) is None - - -def test_bad_sensor_hierarchy(metrics): - parent = metrics.sensor('parent') - child1 = metrics.sensor('child1', parents=[parent]) - child2 = metrics.sensor('child2', parents=[parent]) - - with pytest.raises(ValueError): - metrics.sensor('gc', parents=[child1, child2]) - - -def test_remove_sensor(metrics): - size = len(metrics.metrics) - parent1 = metrics.sensor('test.parent1') - parent1.add(metrics.metric_name('test.parent1.count', 'grp1'), Count()) - parent2 = metrics.sensor('test.parent2') - parent2.add(metrics.metric_name('test.parent2.count', 'grp1'), Count()) - child1 = metrics.sensor('test.child1', parents=[parent1, parent2]) - child1.add(metrics.metric_name('test.child1.count', 'grp1'), Count()) - child2 = metrics.sensor('test.child2', parents=[parent2]) - child2.add(metrics.metric_name('test.child2.count', 'grp1'), Count()) - grandchild1 = metrics.sensor('test.gchild2', parents=[child2]) - grandchild1.add(metrics.metric_name('test.gchild2.count', 'grp1'), Count()) - - sensor = metrics.get_sensor('test.parent1') - assert sensor is not None - metrics.remove_sensor('test.parent1') - assert 
metrics.get_sensor('test.parent1') is None - assert metrics.metrics.get(metrics.metric_name('test.parent1.count', 'grp1')) is None - assert metrics.get_sensor('test.child1') is None - assert metrics._children_sensors.get(sensor) is None - assert metrics.metrics.get(metrics.metric_name('test.child1.count', 'grp1')) is None - - sensor = metrics.get_sensor('test.gchild2') - assert sensor is not None - metrics.remove_sensor('test.gchild2') - assert metrics.get_sensor('test.gchild2') is None - assert metrics._children_sensors.get(sensor) is None - assert metrics.metrics.get(metrics.metric_name('test.gchild2.count', 'grp1')) is None - - sensor = metrics.get_sensor('test.child2') - assert sensor is not None - metrics.remove_sensor('test.child2') - assert metrics.get_sensor('test.child2') is None - assert metrics._children_sensors.get(sensor) is None - assert metrics.metrics.get(metrics.metric_name('test.child2.count', 'grp1')) is None - - sensor = metrics.get_sensor('test.parent2') - assert sensor is not None - metrics.remove_sensor('test.parent2') - assert metrics.get_sensor('test.parent2') is None - assert metrics._children_sensors.get(sensor) is None - assert metrics.metrics.get(metrics.metric_name('test.parent2.count', 'grp1')) is None - - assert size == len(metrics.metrics) - - -def test_remove_inactive_metrics(mocker, time_keeper, metrics): - mocker.patch('time.time', side_effect=time_keeper.time) - - s1 = metrics.sensor('test.s1', None, 1) - s1.add(metrics.metric_name('test.s1.count', 'grp1'), Count()) - - s2 = metrics.sensor('test.s2', None, 3) - s2.add(metrics.metric_name('test.s2.count', 'grp1'), Count()) - - purger = Metrics.ExpireSensorTask - purger.run(metrics) - assert metrics.get_sensor('test.s1') is not None, \ - 'Sensor test.s1 must be present' - assert metrics.metrics.get(metrics.metric_name('test.s1.count', 'grp1')) is not None, \ - 'MetricName test.s1.count must be present' - assert metrics.get_sensor('test.s2') is not None, \ - 'Sensor test.s2 must be present' - assert metrics.metrics.get(metrics.metric_name('test.s2.count', 'grp1')) is not None, \ - 'MetricName test.s2.count must be present' - - time_keeper.sleep(1.001) - purger.run(metrics) - assert metrics.get_sensor('test.s1') is None, \ - 'Sensor test.s1 should have been purged' - assert metrics.metrics.get(metrics.metric_name('test.s1.count', 'grp1')) is None, \ - 'MetricName test.s1.count should have been purged' - assert metrics.get_sensor('test.s2') is not None, \ - 'Sensor test.s2 must be present' - assert metrics.metrics.get(metrics.metric_name('test.s2.count', 'grp1')) is not None, \ - 'MetricName test.s2.count must be present' - - # record a value in sensor s2. This should reset the clock for that sensor. 
- # It should not get purged at the 3 second mark after creation - s2.record() - - time_keeper.sleep(2) - purger.run(metrics) - assert metrics.get_sensor('test.s2') is not None, \ - 'Sensor test.s2 must be present' - assert metrics.metrics.get(metrics.metric_name('test.s2.count', 'grp1')) is not None, \ - 'MetricName test.s2.count must be present' - - # After another 1 second sleep, the metric should be purged - time_keeper.sleep(1) - purger.run(metrics) - assert metrics.get_sensor('test.s1') is None, \ - 'Sensor test.s2 should have been purged' - assert metrics.metrics.get(metrics.metric_name('test.s1.count', 'grp1')) is None, \ - 'MetricName test.s2.count should have been purged' - - # After purging, it should be possible to recreate a metric - s1 = metrics.sensor('test.s1', None, 1) - s1.add(metrics.metric_name('test.s1.count', 'grp1'), Count()) - assert metrics.get_sensor('test.s1') is not None, \ - 'Sensor test.s1 must be present' - assert metrics.metrics.get(metrics.metric_name('test.s1.count', 'grp1')) is not None, \ - 'MetricName test.s1.count must be present' - - -def test_remove_metric(metrics): - size = len(metrics.metrics) - metrics.add_metric(metrics.metric_name('test1', 'grp1'), Count()) - metrics.add_metric(metrics.metric_name('test2', 'grp1'), Count()) - - assert metrics.remove_metric(metrics.metric_name('test1', 'grp1')) is not None - assert metrics.metrics.get(metrics.metric_name('test1', 'grp1')) is None - assert metrics.metrics.get(metrics.metric_name('test2', 'grp1')) is not None - - assert metrics.remove_metric(metrics.metric_name('test2', 'grp1')) is not None - assert metrics.metrics.get(metrics.metric_name('test2', 'grp1')) is None - - assert size == len(metrics.metrics) - - -def test_event_windowing(mocker, time_keeper): - mocker.patch('time.time', side_effect=time_keeper.time) - - count = Count() - config = MetricConfig(event_window=1, samples=2) - count.record(config, 1.0, time_keeper.ms()) - count.record(config, 1.0, time_keeper.ms()) - assert 2.0 == count.measure(config, time_keeper.ms()) - count.record(config, 1.0, time_keeper.ms()) # first event times out - assert 2.0 == count.measure(config, time_keeper.ms()) - - -def test_time_windowing(mocker, time_keeper): - mocker.patch('time.time', side_effect=time_keeper.time) - - count = Count() - config = MetricConfig(time_window_ms=1, samples=2) - count.record(config, 1.0, time_keeper.ms()) - time_keeper.sleep(.001) - count.record(config, 1.0, time_keeper.ms()) - assert 2.0 == count.measure(config, time_keeper.ms()) - time_keeper.sleep(.001) - count.record(config, 1.0, time_keeper.ms()) # oldest event times out - assert 2.0 == count.measure(config, time_keeper.ms()) - - -def test_old_data_has_no_effect(mocker, time_keeper): - mocker.patch('time.time', side_effect=time_keeper.time) - - max_stat = Max() - min_stat = Min() - avg_stat = Avg() - count_stat = Count() - window_ms = 100 - samples = 2 - config = MetricConfig(time_window_ms=window_ms, samples=samples) - max_stat.record(config, 50, time_keeper.ms()) - min_stat.record(config, 50, time_keeper.ms()) - avg_stat.record(config, 50, time_keeper.ms()) - count_stat.record(config, 50, time_keeper.ms()) - - time_keeper.sleep(samples * window_ms / 1000.0) - assert float('-inf') == max_stat.measure(config, time_keeper.ms()) - assert float(sys.maxsize) == min_stat.measure(config, time_keeper.ms()) - assert 0.0 == avg_stat.measure(config, time_keeper.ms()) - assert 0 == count_stat.measure(config, time_keeper.ms()) - - -def test_duplicate_MetricName(metrics): - 
metrics.sensor('test').add(metrics.metric_name('test', 'grp1'), Avg()) - with pytest.raises(ValueError): - metrics.sensor('test2').add(metrics.metric_name('test', 'grp1'), Total()) - - -def test_Quotas(metrics): - sensor = metrics.sensor('test') - sensor.add(metrics.metric_name('test1.total', 'grp1'), Total(), - MetricConfig(quota=Quota.upper_bound(5.0))) - sensor.add(metrics.metric_name('test2.total', 'grp1'), Total(), - MetricConfig(quota=Quota.lower_bound(0.0))) - sensor.record(5.0) - with pytest.raises(QuotaViolationError): - sensor.record(1.0) - - assert abs(6.0 - metrics.metrics.get(metrics.metric_name('test1.total', 'grp1')).value()) \ - < EPS - - sensor.record(-6.0) - with pytest.raises(QuotaViolationError): - sensor.record(-1.0) - - -def test_Quotas_equality(): - quota1 = Quota.upper_bound(10.5) - quota2 = Quota.lower_bound(10.5) - assert quota1 != quota2, 'Quota with different upper values should not be equal' - - quota3 = Quota.lower_bound(10.5) - assert quota2 == quota3, 'Quota with same upper and bound values should be equal' - - -def test_Percentiles(metrics): - buckets = 100 - _percentiles = [ - Percentile(metrics.metric_name('test.p25', 'grp1'), 25), - Percentile(metrics.metric_name('test.p50', 'grp1'), 50), - Percentile(metrics.metric_name('test.p75', 'grp1'), 75), - ] - percs = Percentiles(4 * buckets, BucketSizing.CONSTANT, 100.0, 0.0, - percentiles=_percentiles) - config = MetricConfig(event_window=50, samples=2) - sensor = metrics.sensor('test', config) - sensor.add_compound(percs) - p25 = metrics.metrics.get(metrics.metric_name('test.p25', 'grp1')) - p50 = metrics.metrics.get(metrics.metric_name('test.p50', 'grp1')) - p75 = metrics.metrics.get(metrics.metric_name('test.p75', 'grp1')) - - # record two windows worth of sequential values - for i in range(buckets): - sensor.record(i) - - assert abs(p25.value() - 25) < 1.0 - assert abs(p50.value() - 50) < 1.0 - assert abs(p75.value() - 75) < 1.0 - - for i in range(buckets): - sensor.record(0.0) - - assert p25.value() < 1.0 - assert p50.value() < 1.0 - assert p75.value() < 1.0 - -def test_rate_windowing(mocker, time_keeper, metrics): - mocker.patch('time.time', side_effect=time_keeper.time) - - # Use the default time window. Set 3 samples - config = MetricConfig(samples=3) - sensor = metrics.sensor('test.sensor', config) - sensor.add(metrics.metric_name('test.rate', 'grp1'), Rate(TimeUnit.SECONDS)) - - sum_val = 0 - count = config.samples - 1 - # Advance 1 window after every record - for i in range(count): - sensor.record(100) - sum_val += 100 - time_keeper.sleep(config.time_window_ms / 1000.0) - - # Sleep for half the window. 
- time_keeper.sleep(config.time_window_ms / 2.0 / 1000.0) - - # prior to any time passing - elapsed_secs = (config.time_window_ms * (config.samples - 1) + config.time_window_ms / 2.0) / 1000.0 - - kafka_metric = metrics.metrics.get(metrics.metric_name('test.rate', 'grp1')) - assert abs((sum_val / elapsed_secs) - kafka_metric.value()) < EPS, \ - 'Rate(0...2) = 2.666' - assert abs(elapsed_secs - (kafka_metric.measurable.window_size(config, time.time() * 1000) / 1000.0)) \ - < EPS, 'Elapsed Time = 75 seconds' - - -def test_reporter(metrics): - reporter = DictReporter() - foo_reporter = DictReporter(prefix='foo') - metrics.add_reporter(reporter) - metrics.add_reporter(foo_reporter) - sensor = metrics.sensor('kafka.requests') - sensor.add(metrics.metric_name('pack.bean1.avg', 'grp1'), Avg()) - sensor.add(metrics.metric_name('pack.bean2.total', 'grp2'), Total()) - sensor2 = metrics.sensor('kafka.blah') - sensor2.add(metrics.metric_name('pack.bean1.some', 'grp1'), Total()) - sensor2.add(metrics.metric_name('pack.bean2.some', 'grp1', - tags={'a': 42, 'b': 'bar'}), Total()) - - # kafka-metrics-count > count is the total number of metrics and automatic - expected = { - 'kafka-metrics-count': {'count': 5.0}, - 'grp2': {'pack.bean2.total': 0.0}, - 'grp1': {'pack.bean1.avg': 0.0, 'pack.bean1.some': 0.0}, - 'grp1.a=42,b=bar': {'pack.bean2.some': 0.0}, - } - assert expected == reporter.snapshot() - - for key in list(expected.keys()): - metrics = expected.pop(key) - expected['foo.%s' % (key,)] = metrics - assert expected == foo_reporter.snapshot() - - -class ConstantMeasurable(AbstractMeasurable): - _value = 0.0 - - def measure(self, config, now): - return self._value - - -class TimeKeeper(object): - """ - A clock that you can manually advance by calling sleep - """ - def __init__(self, auto_tick_ms=0): - self._millis = time.time() * 1000 - self._auto_tick_ms = auto_tick_ms - - def time(self): - return self.ms() / 1000.0 - - def ms(self): - self.sleep(self._auto_tick_ms) - return self._millis - - def sleep(self, seconds): - self._millis += (seconds * 1000) diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 00000000..1e789d36 --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,585 @@ +import sys +import time + +import pytest + +from aiokafka.errors import QuotaViolationError +from aiokafka.metrics import DictReporter, MetricConfig, MetricName, Metrics, Quota +from aiokafka.metrics.measurable import AbstractMeasurable +from aiokafka.metrics.stats import ( + Avg, + Count, + Max, + Min, + Percentile, + Percentiles, + Rate, + Total, +) +from aiokafka.metrics.stats.percentiles import BucketSizing +from aiokafka.metrics.stats.rate import TimeUnit + +EPS = 0.000001 + + +@pytest.fixture +def time_keeper(): + return TimeKeeper() + + +@pytest.fixture +def config(): + return MetricConfig() + + +@pytest.fixture +def reporter(): + return DictReporter() + + +@pytest.fixture +def metrics(request, config, reporter): + metrics = Metrics(config, [reporter], enable_expiration=True) + yield metrics + metrics.close() + + +def test_MetricName(): + # The Java test only cover the differences between the deprecated + # constructors, so I'm skipping them but doing some other basic testing. + + # In short, metrics should be equal IFF their name, group, and tags are + # the same. Descriptions do not matter. 
+ name1 = MetricName("name", "group", "A metric.", {"a": 1, "b": 2}) + name2 = MetricName("name", "group", "A description.", {"a": 1, "b": 2}) + assert name1 == name2 + + name1 = MetricName("name", "group", tags={"a": 1, "b": 2}) + name2 = MetricName("name", "group", tags={"a": 1, "b": 2}) + assert name1 == name2 + + name1 = MetricName("foo", "group") + name2 = MetricName("name", "group") + assert name1 != name2 + + name1 = MetricName("name", "foo") + name2 = MetricName("name", "group") + assert name1 != name2 + + # name and group must be non-empty. Everything else is optional. + with pytest.raises(Exception): + MetricName("", "group") + with pytest.raises(Exception): + MetricName("name", None) + # tags must be a dict if supplied + with pytest.raises(Exception): + MetricName("name", "group", tags=set()) + + # Because of the implementation of __eq__ and __hash__, the values of + # a MetricName cannot be mutable. + tags = {"a": 1} + name = MetricName("name", "group", "description", tags=tags) + with pytest.raises(AttributeError): + name.name = "new name" + with pytest.raises(AttributeError): + name.group = "new name" + with pytest.raises(AttributeError): + name.tags = {} + # tags is a copy, so the instance isn't altered + name.tags["b"] = 2 + assert name.tags == tags + + +def test_simple_stats(mocker, time_keeper, config, metrics): + mocker.patch("time.time", side_effect=time_keeper.time) + + measurable = ConstantMeasurable() + + metrics.add_metric( + metrics.metric_name( + "direct.measurable", + "grp1", + "The fraction of time an appender waits for space allocation.", + ), + measurable, + ) + sensor = metrics.sensor("test.sensor") + sensor.add(metrics.metric_name("test.avg", "grp1"), Avg()) + sensor.add(metrics.metric_name("test.max", "grp1"), Max()) + sensor.add(metrics.metric_name("test.min", "grp1"), Min()) + sensor.add(metrics.metric_name("test.rate", "grp1"), Rate(TimeUnit.SECONDS)) + sensor.add( + metrics.metric_name("test.occurences", "grp1"), Rate(TimeUnit.SECONDS, Count()) + ) + sensor.add(metrics.metric_name("test.count", "grp1"), Count()) + percentiles = [ + Percentile(metrics.metric_name("test.median", "grp1"), 50.0), + Percentile(metrics.metric_name("test.perc99_9", "grp1"), 99.9), + ] + sensor.add_compound( + Percentiles(100, BucketSizing.CONSTANT, 100, -100, percentiles=percentiles) + ) + + sensor2 = metrics.sensor("test.sensor2") + sensor2.add(metrics.metric_name("s2.total", "grp1"), Total()) + sensor2.record(5.0) + + sum_val = 0 + count = 10 + for i in range(count): + sensor.record(i) + sum_val += i + + # prior to any time passing + elapsed_secs = (config.time_window_ms * (config.samples - 1)) / 1000.0 + assert ( + abs( + count / elapsed_secs + - metrics.metrics.get( + metrics.metric_name("test.occurences", "grp1") + ).value() + ) + < EPS + ), "Occurrences(0...%d) = %f" % (count, count / elapsed_secs) + + # pretend 2 seconds passed... 
+ sleep_time_seconds = 2.0 + time_keeper.sleep(sleep_time_seconds) + elapsed_secs += sleep_time_seconds + + assert ( + abs(5.0 - metrics.metrics.get(metrics.metric_name("s2.total", "grp1")).value()) + < EPS + ), "s2 reflects the constant value" + assert ( + abs(4.5 - metrics.metrics.get(metrics.metric_name("test.avg", "grp1")).value()) + < EPS + ), "Avg(0...9) = 4.5" + assert ( + abs( + (count - 1) + - metrics.metrics.get(metrics.metric_name("test.max", "grp1")).value() + ) + < EPS + ), "Max(0...9) = 9" + assert ( + abs(0.0 - metrics.metrics.get(metrics.metric_name("test.min", "grp1")).value()) + < EPS + ), "Min(0...9) = 0" + assert ( + abs( + (sum_val / elapsed_secs) + - metrics.metrics.get(metrics.metric_name("test.rate", "grp1")).value() + ) + < EPS + ), "Rate(0...9) = 1.40625" + assert ( + abs( + (count / elapsed_secs) + - metrics.metrics.get( + metrics.metric_name("test.occurences", "grp1") + ).value() + ) + < EPS + ), "Occurrences(0...%d) = %f" % (count, count / elapsed_secs) + assert ( + abs( + count + - metrics.metrics.get(metrics.metric_name("test.count", "grp1")).value() + ) + < EPS + ), "Count(0...9) = 10" + + +def test_hierarchical_sensors(metrics): + parent1 = metrics.sensor("test.parent1") + parent1.add(metrics.metric_name("test.parent1.count", "grp1"), Count()) + parent2 = metrics.sensor("test.parent2") + parent2.add(metrics.metric_name("test.parent2.count", "grp1"), Count()) + child1 = metrics.sensor("test.child1", parents=[parent1, parent2]) + child1.add(metrics.metric_name("test.child1.count", "grp1"), Count()) + child2 = metrics.sensor("test.child2", parents=[parent1]) + child2.add(metrics.metric_name("test.child2.count", "grp1"), Count()) + grandchild = metrics.sensor("test.grandchild", parents=[child1]) + grandchild.add(metrics.metric_name("test.grandchild.count", "grp1"), Count()) + + # increment each sensor one time + parent1.record() + parent2.record() + child1.record() + child2.record() + grandchild.record() + + p1 = parent1.metrics[0].value() + p2 = parent2.metrics[0].value() + c1 = child1.metrics[0].value() + c2 = child2.metrics[0].value() + gc = grandchild.metrics[0].value() + + # each metric should have a count equal to one + its children's count + assert 1.0 == gc + assert 1.0 + gc == c1 + assert 1.0 == c2 + assert 1.0 + c1 == p2 + assert 1.0 + c1 + c2 == p1 + assert [child1, child2] == metrics._children_sensors.get(parent1) + assert [child1] == metrics._children_sensors.get(parent2) + assert metrics._children_sensors.get(grandchild) is None + + +def test_bad_sensor_hierarchy(metrics): + parent = metrics.sensor("parent") + child1 = metrics.sensor("child1", parents=[parent]) + child2 = metrics.sensor("child2", parents=[parent]) + + with pytest.raises(ValueError): + metrics.sensor("gc", parents=[child1, child2]) + + +def test_remove_sensor(metrics): + size = len(metrics.metrics) + parent1 = metrics.sensor("test.parent1") + parent1.add(metrics.metric_name("test.parent1.count", "grp1"), Count()) + parent2 = metrics.sensor("test.parent2") + parent2.add(metrics.metric_name("test.parent2.count", "grp1"), Count()) + child1 = metrics.sensor("test.child1", parents=[parent1, parent2]) + child1.add(metrics.metric_name("test.child1.count", "grp1"), Count()) + child2 = metrics.sensor("test.child2", parents=[parent2]) + child2.add(metrics.metric_name("test.child2.count", "grp1"), Count()) + grandchild1 = metrics.sensor("test.gchild2", parents=[child2]) + grandchild1.add(metrics.metric_name("test.gchild2.count", "grp1"), Count()) + + sensor = 
metrics.get_sensor("test.parent1") + assert sensor is not None + metrics.remove_sensor("test.parent1") + assert metrics.get_sensor("test.parent1") is None + assert ( + metrics.metrics.get(metrics.metric_name("test.parent1.count", "grp1")) is None + ) + assert metrics.get_sensor("test.child1") is None + assert metrics._children_sensors.get(sensor) is None + assert metrics.metrics.get(metrics.metric_name("test.child1.count", "grp1")) is None + + sensor = metrics.get_sensor("test.gchild2") + assert sensor is not None + metrics.remove_sensor("test.gchild2") + assert metrics.get_sensor("test.gchild2") is None + assert metrics._children_sensors.get(sensor) is None + assert ( + metrics.metrics.get(metrics.metric_name("test.gchild2.count", "grp1")) is None + ) + + sensor = metrics.get_sensor("test.child2") + assert sensor is not None + metrics.remove_sensor("test.child2") + assert metrics.get_sensor("test.child2") is None + assert metrics._children_sensors.get(sensor) is None + assert metrics.metrics.get(metrics.metric_name("test.child2.count", "grp1")) is None + + sensor = metrics.get_sensor("test.parent2") + assert sensor is not None + metrics.remove_sensor("test.parent2") + assert metrics.get_sensor("test.parent2") is None + assert metrics._children_sensors.get(sensor) is None + assert ( + metrics.metrics.get(metrics.metric_name("test.parent2.count", "grp1")) is None + ) + + assert size == len(metrics.metrics) + + +def test_remove_inactive_metrics(mocker, time_keeper, metrics): + mocker.patch("time.time", side_effect=time_keeper.time) + + s1 = metrics.sensor("test.s1", None, 1) + s1.add(metrics.metric_name("test.s1.count", "grp1"), Count()) + + s2 = metrics.sensor("test.s2", None, 3) + s2.add(metrics.metric_name("test.s2.count", "grp1"), Count()) + + purger = Metrics.ExpireSensorTask + purger.run(metrics) + assert metrics.get_sensor("test.s1") is not None, "Sensor test.s1 must be present" + assert ( + metrics.metrics.get(metrics.metric_name("test.s1.count", "grp1")) is not None + ), "MetricName test.s1.count must be present" + assert metrics.get_sensor("test.s2") is not None, "Sensor test.s2 must be present" + assert ( + metrics.metrics.get(metrics.metric_name("test.s2.count", "grp1")) is not None + ), "MetricName test.s2.count must be present" + + time_keeper.sleep(1.001) + purger.run(metrics) + assert ( + metrics.get_sensor("test.s1") is None + ), "Sensor test.s1 should have been purged" + assert ( + metrics.metrics.get(metrics.metric_name("test.s1.count", "grp1")) is None + ), "MetricName test.s1.count should have been purged" + assert metrics.get_sensor("test.s2") is not None, "Sensor test.s2 must be present" + assert ( + metrics.metrics.get(metrics.metric_name("test.s2.count", "grp1")) is not None + ), "MetricName test.s2.count must be present" + + # record a value in sensor s2. This should reset the clock for that sensor. 
+ # It should not get purged at the 3 second mark after creation + s2.record() + + time_keeper.sleep(2) + purger.run(metrics) + assert metrics.get_sensor("test.s2") is not None, "Sensor test.s2 must be present" + assert ( + metrics.metrics.get(metrics.metric_name("test.s2.count", "grp1")) is not None + ), "MetricName test.s2.count must be present" + + # After another 1 second sleep, the metric should be purged + time_keeper.sleep(1) + purger.run(metrics) + assert ( + metrics.get_sensor("test.s1") is None + ), "Sensor test.s2 should have been purged" + assert ( + metrics.metrics.get(metrics.metric_name("test.s1.count", "grp1")) is None + ), "MetricName test.s2.count should have been purged" + + # After purging, it should be possible to recreate a metric + s1 = metrics.sensor("test.s1", None, 1) + s1.add(metrics.metric_name("test.s1.count", "grp1"), Count()) + assert metrics.get_sensor("test.s1") is not None, "Sensor test.s1 must be present" + assert ( + metrics.metrics.get(metrics.metric_name("test.s1.count", "grp1")) is not None + ), "MetricName test.s1.count must be present" + + +def test_remove_metric(metrics): + size = len(metrics.metrics) + metrics.add_metric(metrics.metric_name("test1", "grp1"), Count()) + metrics.add_metric(metrics.metric_name("test2", "grp1"), Count()) + + assert metrics.remove_metric(metrics.metric_name("test1", "grp1")) is not None + assert metrics.metrics.get(metrics.metric_name("test1", "grp1")) is None + assert metrics.metrics.get(metrics.metric_name("test2", "grp1")) is not None + + assert metrics.remove_metric(metrics.metric_name("test2", "grp1")) is not None + assert metrics.metrics.get(metrics.metric_name("test2", "grp1")) is None + + assert size == len(metrics.metrics) + + +def test_event_windowing(mocker, time_keeper): + mocker.patch("time.time", side_effect=time_keeper.time) + + count = Count() + config = MetricConfig(event_window=1, samples=2) + count.record(config, 1.0, time_keeper.ms()) + count.record(config, 1.0, time_keeper.ms()) + assert 2.0 == count.measure(config, time_keeper.ms()) + count.record(config, 1.0, time_keeper.ms()) # first event times out + assert 2.0 == count.measure(config, time_keeper.ms()) + + +def test_time_windowing(mocker, time_keeper): + mocker.patch("time.time", side_effect=time_keeper.time) + + count = Count() + config = MetricConfig(time_window_ms=1, samples=2) + count.record(config, 1.0, time_keeper.ms()) + time_keeper.sleep(0.001) + count.record(config, 1.0, time_keeper.ms()) + assert 2.0 == count.measure(config, time_keeper.ms()) + time_keeper.sleep(0.001) + count.record(config, 1.0, time_keeper.ms()) # oldest event times out + assert 2.0 == count.measure(config, time_keeper.ms()) + + +def test_old_data_has_no_effect(mocker, time_keeper): + mocker.patch("time.time", side_effect=time_keeper.time) + + max_stat = Max() + min_stat = Min() + avg_stat = Avg() + count_stat = Count() + window_ms = 100 + samples = 2 + config = MetricConfig(time_window_ms=window_ms, samples=samples) + max_stat.record(config, 50, time_keeper.ms()) + min_stat.record(config, 50, time_keeper.ms()) + avg_stat.record(config, 50, time_keeper.ms()) + count_stat.record(config, 50, time_keeper.ms()) + + time_keeper.sleep(samples * window_ms / 1000.0) + assert float("-inf") == max_stat.measure(config, time_keeper.ms()) + assert float(sys.maxsize) == min_stat.measure(config, time_keeper.ms()) + assert 0.0 == avg_stat.measure(config, time_keeper.ms()) + assert 0 == count_stat.measure(config, time_keeper.ms()) + + +def test_duplicate_MetricName(metrics): + 
metrics.sensor("test").add(metrics.metric_name("test", "grp1"), Avg()) + with pytest.raises(ValueError): + metrics.sensor("test2").add(metrics.metric_name("test", "grp1"), Total()) + + +def test_Quotas(metrics): + sensor = metrics.sensor("test") + sensor.add( + metrics.metric_name("test1.total", "grp1"), + Total(), + MetricConfig(quota=Quota.upper_bound(5.0)), + ) + sensor.add( + metrics.metric_name("test2.total", "grp1"), + Total(), + MetricConfig(quota=Quota.lower_bound(0.0)), + ) + sensor.record(5.0) + with pytest.raises(QuotaViolationError): + sensor.record(1.0) + + assert ( + abs( + 6.0 + - metrics.metrics.get(metrics.metric_name("test1.total", "grp1")).value() + ) + < EPS + ) + + sensor.record(-6.0) + with pytest.raises(QuotaViolationError): + sensor.record(-1.0) + + +def test_Quotas_equality(): + quota1 = Quota.upper_bound(10.5) + quota2 = Quota.lower_bound(10.5) + assert quota1 != quota2, "Quota with different upper values should not be equal" + + quota3 = Quota.lower_bound(10.5) + assert quota2 == quota3, "Quota with same upper and bound values should be equal" + + +def test_Percentiles(metrics): + buckets = 100 + _percentiles = [ + Percentile(metrics.metric_name("test.p25", "grp1"), 25), + Percentile(metrics.metric_name("test.p50", "grp1"), 50), + Percentile(metrics.metric_name("test.p75", "grp1"), 75), + ] + percs = Percentiles( + 4 * buckets, BucketSizing.CONSTANT, 100.0, 0.0, percentiles=_percentiles + ) + config = MetricConfig(event_window=50, samples=2) + sensor = metrics.sensor("test", config) + sensor.add_compound(percs) + p25 = metrics.metrics.get(metrics.metric_name("test.p25", "grp1")) + p50 = metrics.metrics.get(metrics.metric_name("test.p50", "grp1")) + p75 = metrics.metrics.get(metrics.metric_name("test.p75", "grp1")) + + # record two windows worth of sequential values + for i in range(buckets): + sensor.record(i) + + assert abs(p25.value() - 25) < 1.0 + assert abs(p50.value() - 50) < 1.0 + assert abs(p75.value() - 75) < 1.0 + + for i in range(buckets): + sensor.record(0.0) + + assert p25.value() < 1.0 + assert p50.value() < 1.0 + assert p75.value() < 1.0 + + +def test_rate_windowing(mocker, time_keeper, metrics): + mocker.patch("time.time", side_effect=time_keeper.time) + + # Use the default time window. Set 3 samples + config = MetricConfig(samples=3) + sensor = metrics.sensor("test.sensor", config) + sensor.add(metrics.metric_name("test.rate", "grp1"), Rate(TimeUnit.SECONDS)) + + sum_val = 0 + count = config.samples - 1 + # Advance 1 window after every record + for i in range(count): + sensor.record(100) + sum_val += 100 + time_keeper.sleep(config.time_window_ms / 1000.0) + + # Sleep for half the window. 
+ time_keeper.sleep(config.time_window_ms / 2.0 / 1000.0) + + # prior to any time passing + elapsed_secs = ( + config.time_window_ms * (config.samples - 1) + config.time_window_ms / 2.0 + ) / 1000.0 + + kafka_metric = metrics.metrics.get(metrics.metric_name("test.rate", "grp1")) + assert ( + abs((sum_val / elapsed_secs) - kafka_metric.value()) < EPS + ), "Rate(0...2) = 2.666" + assert ( + abs( + elapsed_secs + - (kafka_metric.measurable.window_size(config, time.time() * 1000) / 1000.0) + ) + < EPS + ), "Elapsed Time = 75 seconds" + + +def test_reporter(metrics): + reporter = DictReporter() + foo_reporter = DictReporter(prefix="foo") + metrics.add_reporter(reporter) + metrics.add_reporter(foo_reporter) + sensor = metrics.sensor("kafka.requests") + sensor.add(metrics.metric_name("pack.bean1.avg", "grp1"), Avg()) + sensor.add(metrics.metric_name("pack.bean2.total", "grp2"), Total()) + sensor2 = metrics.sensor("kafka.blah") + sensor2.add(metrics.metric_name("pack.bean1.some", "grp1"), Total()) + sensor2.add( + metrics.metric_name("pack.bean2.some", "grp1", tags={"a": 42, "b": "bar"}), + Total(), + ) + + # kafka-metrics-count > count is the total number of metrics and automatic + expected = { + "kafka-metrics-count": {"count": 5.0}, + "grp2": {"pack.bean2.total": 0.0}, + "grp1": {"pack.bean1.avg": 0.0, "pack.bean1.some": 0.0}, + "grp1.a=42,b=bar": {"pack.bean2.some": 0.0}, + } + assert expected == reporter.snapshot() + + for key in list(expected.keys()): + metrics = expected.pop(key) + expected["foo.%s" % (key,)] = metrics + assert expected == foo_reporter.snapshot() + + +class ConstantMeasurable(AbstractMeasurable): + _value = 0.0 + + def measure(self, config, now): + return self._value + + +class TimeKeeper(object): + """ + A clock that you can manually advance by calling sleep + """ + + def __init__(self, auto_tick_ms=0): + self._millis = time.time() * 1000 + self._auto_tick_ms = auto_tick_ms + + def time(self): + return self.ms() / 1000.0 + + def ms(self): + self.sleep(self._auto_tick_ms) + return self._millis + + def sleep(self, seconds): + self._millis += seconds * 1000 From 91759b81bb81dc11fd132bfabdd3e3b79f95ee3d Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Mon, 23 Oct 2023 09:56:22 +0300 Subject: [PATCH 16/20] Merge protocol --- aiokafka/admin/client.py | 17 +- aiokafka/client.py | 19 +- aiokafka/conn.py | 13 +- aiokafka/consumer/fetcher.py | 4 +- aiokafka/consumer/group_coordinator.py | 13 +- .../assignors/sticky/sticky_assignor.py | 5 +- aiokafka/coordinator/base.py | 12 +- aiokafka/coordinator/consumer.py | 2 +- aiokafka/coordinator/protocol.py | 5 +- aiokafka/producer/producer.py | 2 +- aiokafka/producer/sender.py | 3 +- aiokafka/protocol/__init__.py | 46 + {kafka => aiokafka}/protocol/abstract.py | 6 +- aiokafka/protocol/admin.py | 1278 +++++++++++++++++ {kafka => aiokafka}/protocol/api.py | 43 +- aiokafka/protocol/commit.py | 306 ++++ aiokafka/protocol/coordination.py | 39 +- aiokafka/protocol/fetch.py | 516 +++++++ {kafka => aiokafka}/protocol/frame.py | 2 +- aiokafka/protocol/group.py | 203 +++ {kafka => aiokafka}/protocol/message.py | 105 +- aiokafka/protocol/metadata.py | 260 ++++ aiokafka/protocol/offset.py | 246 ++++ {kafka => aiokafka}/protocol/parser.py | 84 +- {kafka => aiokafka}/protocol/pickle.py | 9 +- aiokafka/protocol/produce.py | 299 ++++ {kafka => aiokafka}/protocol/struct.py | 26 +- aiokafka/protocol/transaction.py | 167 +-- {kafka => aiokafka}/protocol/types.py | 140 +- docs/api.rst | 2 +- kafka/__init__.py | 1 - kafka/protocol/__init__.py | 49 - 
kafka/protocol/admin.py | 1054 -------------- kafka/protocol/commit.py | 255 ---- kafka/protocol/fetch.py | 386 ----- kafka/protocol/group.py | 230 --- kafka/protocol/metadata.py | 200 --- kafka/protocol/offset.py | 194 --- kafka/protocol/produce.py | 232 --- kafka/version.py | 1 - tests/kafka/fixtures.py | 4 +- tests/kafka/test_api_object_implementation.py | 18 - tests/kafka/test_object_conversion.py | 236 --- tests/kafka/test_protocol.py | 336 ----- tests/test_client.py | 9 +- tests/test_cluster.py | 3 +- tests/test_conn.py | 19 +- tests/test_coordinator.py | 21 +- tests/test_fetcher.py | 26 +- tests/test_producer.py | 5 +- tests/test_protocol.py | 376 +++++ tests/test_protocol_object_conversion.py | 251 ++++ tests/test_sender.py | 34 +- 53 files changed, 4196 insertions(+), 3616 deletions(-) rename {kafka => aiokafka}/protocol/abstract.py (57%) create mode 100644 aiokafka/protocol/admin.py rename {kafka => aiokafka}/protocol/api.py (74%) create mode 100644 aiokafka/protocol/commit.py create mode 100644 aiokafka/protocol/fetch.py rename {kafka => aiokafka}/protocol/frame.py (94%) create mode 100644 aiokafka/protocol/group.py rename {kafka => aiokafka}/protocol/message.py (73%) create mode 100644 aiokafka/protocol/metadata.py create mode 100644 aiokafka/protocol/offset.py rename {kafka => aiokafka}/protocol/parser.py (69%) rename {kafka => aiokafka}/protocol/pickle.py (80%) create mode 100644 aiokafka/protocol/produce.py rename {kafka => aiokafka}/protocol/struct.py (72%) rename {kafka => aiokafka}/protocol/types.py (68%) delete mode 100644 kafka/protocol/__init__.py delete mode 100644 kafka/protocol/admin.py delete mode 100644 kafka/protocol/commit.py delete mode 100644 kafka/protocol/fetch.py delete mode 100644 kafka/protocol/group.py delete mode 100644 kafka/protocol/metadata.py delete mode 100644 kafka/protocol/offset.py delete mode 100644 kafka/protocol/produce.py delete mode 100644 kafka/version.py delete mode 100644 tests/kafka/test_api_object_implementation.py delete mode 100644 tests/kafka/test_object_conversion.py delete mode 100644 tests/kafka/test_protocol.py create mode 100644 tests/test_protocol.py create mode 100644 tests/test_protocol_object_conversion.py diff --git a/aiokafka/admin/client.py b/aiokafka/admin/client.py index 32309d7b..cb436a47 100644 --- a/aiokafka/admin/client.py +++ b/aiokafka/admin/client.py @@ -4,10 +4,13 @@ from ssl import SSLContext from typing import List, Optional, Dict, Tuple, Any -from kafka.protocol.api import Request, Response -from kafka.protocol.metadata import MetadataRequest -from kafka.protocol.commit import OffsetFetchRequest, GroupCoordinatorRequest -from kafka.protocol.admin import ( +from aiokafka import __version__ +from aiokafka.client import AIOKafkaClient +from aiokafka.errors import IncompatibleBrokerVersion, for_code +from aiokafka.protocol.api import Request, Response +from aiokafka.protocol.metadata import MetadataRequest +from aiokafka.protocol.commit import OffsetFetchRequest, GroupCoordinatorRequest +from aiokafka.protocol.admin import ( CreatePartitionsRequest, CreateTopicsRequest, DeleteTopicsRequest, @@ -16,10 +19,6 @@ AlterConfigsRequest, ListGroupsRequest, ApiVersionRequest_v0) - -from aiokafka import __version__ -from aiokafka.client import AIOKafkaClient -from aiokafka.errors import IncompatibleBrokerVersion, for_code from aiokafka.structs import TopicPartition, OffsetAndMetadata from .config_resource import ConfigResourceType, ConfigResource @@ -149,7 +148,7 @@ def _matching_api_version(self, operation: 
List[Request]) -> int: supported by the broker. :param operation: A list of protocol operation versions from - kafka.protocol. + aiokafka.protocol. :return: The max matching version number between client and broker. """ api_key = operation[0].API_KEY diff --git a/aiokafka/client.py b/aiokafka/client.py index 371e4668..9ecd23cd 100644 --- a/aiokafka/client.py +++ b/aiokafka/client.py @@ -3,18 +3,17 @@ import random import time -from kafka.protocol.admin import DescribeAclsRequest_v2 -from kafka.protocol.commit import OffsetFetchRequest -from kafka.protocol.fetch import FetchRequest -from kafka.protocol.metadata import MetadataRequest -from kafka.protocol.offset import OffsetRequest -from kafka.protocol.produce import ProduceRequest - import aiokafka.errors as Errors from aiokafka import __version__ from aiokafka.conn import collect_hosts, create_conn, CloseReason from aiokafka.cluster import ClusterMetadata +from aiokafka.protocol.admin import DescribeAclsRequest_v2 +from aiokafka.protocol.commit import OffsetFetchRequest from aiokafka.protocol.coordination import FindCoordinatorRequest +from aiokafka.protocol.fetch import FetchRequest +from aiokafka.protocol.metadata import MetadataRequest +from aiokafka.protocol.offset import OffsetRequest +from aiokafka.protocol.produce import ProduceRequest from aiokafka.errors import ( KafkaError, KafkaConnectionError, @@ -525,11 +524,11 @@ async def check_version(self, node_id=None): assert self.cluster.brokers(), 'no brokers in metadata' node_id = list(self.cluster.brokers())[0].nodeId - from kafka.protocol.admin import ( + from aiokafka.protocol.admin import ( ListGroupsRequest_v0, ApiVersionRequest_v0) - from kafka.protocol.commit import ( + from aiokafka.protocol.commit import ( OffsetFetchRequest_v0, GroupCoordinatorRequest_v0) - from kafka.protocol.metadata import MetadataRequest_v0 + from aiokafka.protocol.metadata import MetadataRequest_v0 test_cases = [ ((0, 10), ApiVersionRequest_v0()), ((0, 9), ListGroupsRequest_v0()), diff --git a/aiokafka/conn.py b/aiokafka/conn.py index 2be93dc3..da27fd27 100644 --- a/aiokafka/conn.py +++ b/aiokafka/conn.py @@ -16,18 +16,17 @@ import weakref import async_timeout -from kafka.protocol.api import RequestHeader -from kafka.protocol.admin import ( + +import aiokafka.errors as Errors +from aiokafka.abc import AbstractTokenProvider +from aiokafka.protocol.api import RequestHeader +from aiokafka.protocol.admin import ( SaslHandShakeRequest, SaslAuthenticateRequest, ApiVersionRequest ) -from kafka.protocol.commit import ( +from aiokafka.protocol.commit import ( GroupCoordinatorResponse_v0 as GroupCoordinatorResponse) - -import aiokafka.errors as Errors from aiokafka.util import create_future, create_task, get_running_loop, wait_for -from aiokafka.abc import AbstractTokenProvider - try: import gssapi except ImportError: diff --git a/aiokafka/consumer/fetcher.py b/aiokafka/consumer/fetcher.py index 2a3394b3..6d08bd21 100644 --- a/aiokafka/consumer/fetcher.py +++ b/aiokafka/consumer/fetcher.py @@ -6,12 +6,12 @@ from itertools import chain import async_timeout -from kafka.protocol.offset import OffsetRequest -from kafka.protocol.fetch import FetchRequest import aiokafka.errors as Errors from aiokafka.errors import ( ConsumerStoppedError, RecordTooLargeError, KafkaTimeoutError) +from aiokafka.protocol.offset import OffsetRequest +from aiokafka.protocol.fetch import FetchRequest from aiokafka.record.memory_records import MemoryRecords from aiokafka.record.control_record import ControlRecord, ABORT_MARKER from 
aiokafka.structs import OffsetAndTimestamp, TopicPartition, ConsumerRecord diff --git a/aiokafka/consumer/group_coordinator.py b/aiokafka/consumer/group_coordinator.py index 8a8c76f4..963d558c 100644 --- a/aiokafka/consumer/group_coordinator.py +++ b/aiokafka/consumer/group_coordinator.py @@ -4,17 +4,16 @@ import copy import time -from kafka.protocol.commit import ( - OffsetCommitRequest_v2 as OffsetCommitRequest, - OffsetFetchRequest_v1 as OffsetFetchRequest) -from kafka.protocol.group import ( - HeartbeatRequest, JoinGroupRequest, LeaveGroupRequest, SyncGroupRequest) - import aiokafka.errors as Errors -from aiokafka.structs import OffsetAndMetadata, TopicPartition from aiokafka.client import ConnectionGroup, CoordinationType from aiokafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor from aiokafka.coordinator.protocol import ConsumerProtocol +from aiokafka.protocol.commit import ( + OffsetCommitRequest_v2 as OffsetCommitRequest, + OffsetFetchRequest_v1 as OffsetFetchRequest) +from aiokafka.protocol.group import ( + HeartbeatRequest, JoinGroupRequest, LeaveGroupRequest, SyncGroupRequest) +from aiokafka.structs import OffsetAndMetadata, TopicPartition from aiokafka.util import create_future, create_task log = logging.getLogger(__name__) diff --git a/aiokafka/coordinator/assignors/sticky/sticky_assignor.py b/aiokafka/coordinator/assignors/sticky/sticky_assignor.py index 05e14ef2..ae2235f5 100644 --- a/aiokafka/coordinator/assignors/sticky/sticky_assignor.py +++ b/aiokafka/coordinator/assignors/sticky/sticky_assignor.py @@ -2,9 +2,6 @@ from collections import defaultdict, namedtuple from copy import deepcopy -from kafka.protocol.struct import Struct -from kafka.protocol.types import String, Array, Int32 - from aiokafka.coordinator.assignors.abstract import AbstractPartitionAssignor from aiokafka.coordinator.assignors.sticky.partition_movements import PartitionMovements from aiokafka.coordinator.assignors.sticky.sorted_set import SortedSet @@ -13,6 +10,8 @@ ConsumerProtocolMemberAssignment, ) from aiokafka.coordinator.protocol import Schema +from aiokafka.protocol.struct import Struct +from aiokafka.protocol.types import String, Array, Int32 from aiokafka.structs import TopicPartition log = logging.getLogger(__name__) diff --git a/aiokafka/coordinator/base.py b/aiokafka/coordinator/base.py index 4489884d..ea6b4ccd 100644 --- a/aiokafka/coordinator/base.py +++ b/aiokafka/coordinator/base.py @@ -6,18 +6,18 @@ import weakref from kafka.future import Future -from kafka.protocol.commit import GroupCoordinatorRequest, OffsetCommitRequest -from kafka.protocol.group import ( + +from aiokafka import errors as Errors +from aiokafka.metrics import AnonMeasurable +from aiokafka.metrics.stats import Avg, Count, Max, Rate +from aiokafka.protocol.commit import GroupCoordinatorRequest, OffsetCommitRequest +from aiokafka.protocol.group import ( HeartbeatRequest, JoinGroupRequest, LeaveGroupRequest, SyncGroupRequest, ) -from aiokafka import errors as Errors -from aiokafka.metrics import AnonMeasurable -from aiokafka.metrics.stats import Avg, Count, Max, Rate - from .heartbeat import Heartbeat log = logging.getLogger("aiokafka.coordinator") diff --git a/aiokafka/coordinator/consumer.py b/aiokafka/coordinator/consumer.py index 2c7ebb4e..8f6cdaba 100644 --- a/aiokafka/coordinator/consumer.py +++ b/aiokafka/coordinator/consumer.py @@ -5,12 +5,12 @@ import time from kafka.future import Future -from kafka.protocol.commit import OffsetCommitRequest, OffsetFetchRequest from kafka.util import 
WeakMethod import aiokafka.errors as Errors from aiokafka.metrics import AnonMeasurable from aiokafka.metrics.stats import Avg, Count, Max, Rate +from aiokafka.protocol.commit import OffsetCommitRequest, OffsetFetchRequest from aiokafka.structs import OffsetAndMetadata, TopicPartition from .base import BaseCoordinator, Generation diff --git a/aiokafka/coordinator/protocol.py b/aiokafka/coordinator/protocol.py index 0dfbe7f9..aa86a7ff 100644 --- a/aiokafka/coordinator/protocol.py +++ b/aiokafka/coordinator/protocol.py @@ -1,6 +1,5 @@ -from kafka.protocol.struct import Struct -from kafka.protocol.types import Array, Bytes, Int16, Int32, Schema, String - +from aiokafka.protocol.struct import Struct +from aiokafka.protocol.types import Array, Bytes, Int16, Int32, Schema, String from aiokafka.structs import TopicPartition diff --git a/aiokafka/producer/producer.py b/aiokafka/producer/producer.py index f5a57b73..12a07e7f 100644 --- a/aiokafka/producer/producer.py +++ b/aiokafka/producer/producer.py @@ -128,7 +128,7 @@ class AIOKafkaProducer: brokers or partitions. Default: 300000 request_timeout_ms (int): Produce request timeout in milliseconds. As it's sent as part of - :class:`~kafka.protocol.produce.ProduceRequest` (it's a blocking + :class:`~aiokafka.protocol.produce.ProduceRequest` (it's a blocking call), maximum waiting time can be up to ``2 * request_timeout_ms``. Default: 40000. diff --git a/aiokafka/producer/sender.py b/aiokafka/producer/sender.py index b8faf1f6..bc6c8f4e 100644 --- a/aiokafka/producer/sender.py +++ b/aiokafka/producer/sender.py @@ -3,8 +3,6 @@ import logging import time -from kafka.protocol.produce import ProduceRequest - import aiokafka.errors as Errors from aiokafka.client import ConnectionGroup, CoordinationType from aiokafka.errors import ( @@ -16,6 +14,7 @@ OutOfOrderSequenceNumber, TopicAuthorizationFailedError, GroupAuthorizationFailedError, TransactionalIdAuthorizationFailed, OperationNotAttempted) +from aiokafka.protocol.produce import ProduceRequest from aiokafka.protocol.transaction import ( InitProducerIdRequest, AddPartitionsToTxnRequest, EndTxnRequest, AddOffsetsToTxnRequest, TxnOffsetCommitRequest diff --git a/aiokafka/protocol/__init__.py b/aiokafka/protocol/__init__.py index e69de29b..e001b571 100644 --- a/aiokafka/protocol/__init__.py +++ b/aiokafka/protocol/__init__.py @@ -0,0 +1,46 @@ +API_KEYS = { + 0: "Produce", + 1: "Fetch", + 2: "ListOffsets", + 3: "Metadata", + 4: "LeaderAndIsr", + 5: "StopReplica", + 6: "UpdateMetadata", + 7: "ControlledShutdown", + 8: "OffsetCommit", + 9: "OffsetFetch", + 10: "FindCoordinator", + 11: "JoinGroup", + 12: "Heartbeat", + 13: "LeaveGroup", + 14: "SyncGroup", + 15: "DescribeGroups", + 16: "ListGroups", + 17: "SaslHandshake", + 18: "ApiVersions", + 19: "CreateTopics", + 20: "DeleteTopics", + 21: "DeleteRecords", + 22: "InitProducerId", + 23: "OffsetForLeaderEpoch", + 24: "AddPartitionsToTxn", + 25: "AddOffsetsToTxn", + 26: "EndTxn", + 27: "WriteTxnMarkers", + 28: "TxnOffsetCommit", + 29: "DescribeAcls", + 30: "CreateAcls", + 31: "DeleteAcls", + 32: "DescribeConfigs", + 33: "AlterConfigs", + 36: "SaslAuthenticate", + 37: "CreatePartitions", + 38: "CreateDelegationToken", + 39: "RenewDelegationToken", + 40: "ExpireDelegationToken", + 41: "DescribeDelegationToken", + 42: "DeleteGroups", + 45: "AlterPartitionReassignments", + 46: "ListPartitionReassignments", + 48: "DescribeClientQuotas", +} diff --git a/kafka/protocol/abstract.py b/aiokafka/protocol/abstract.py similarity index 57% rename from 
kafka/protocol/abstract.py rename to aiokafka/protocol/abstract.py index 2de65c4b..b52a79ce 100644 --- a/kafka/protocol/abstract.py +++ b/aiokafka/protocol/abstract.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import abc @@ -7,11 +5,11 @@ class AbstractType(object): __metaclass__ = abc.ABCMeta @abc.abstractmethod - def encode(cls, value): # pylint: disable=no-self-argument + def encode(cls, value): # pylint: disable=no-self-argument pass @abc.abstractmethod - def decode(cls, data): # pylint: disable=no-self-argument + def decode(cls, data): # pylint: disable=no-self-argument pass @classmethod diff --git a/aiokafka/protocol/admin.py b/aiokafka/protocol/admin.py new file mode 100644 index 00000000..e1b0ffc4 --- /dev/null +++ b/aiokafka/protocol/admin.py @@ -0,0 +1,1278 @@ +from .api import Request, Response +from .types import ( + Array, + Boolean, + Bytes, + Int8, + Int16, + Int32, + Int64, + Schema, + String, + Float64, + CompactString, + CompactArray, + TaggedFields, +) + + +class ApiVersionResponse_v0(Response): + API_KEY = 18 + API_VERSION = 0 + SCHEMA = Schema( + ("error_code", Int16), + ( + "api_versions", + Array(("api_key", Int16), ("min_version", Int16), ("max_version", Int16)), + ), + ) + + +class ApiVersionResponse_v1(Response): + API_KEY = 18 + API_VERSION = 1 + SCHEMA = Schema( + ("error_code", Int16), + ( + "api_versions", + Array(("api_key", Int16), ("min_version", Int16), ("max_version", Int16)), + ), + ("throttle_time_ms", Int32), + ) + + +class ApiVersionResponse_v2(Response): + API_KEY = 18 + API_VERSION = 2 + SCHEMA = ApiVersionResponse_v1.SCHEMA + + +class ApiVersionRequest_v0(Request): + API_KEY = 18 + API_VERSION = 0 + RESPONSE_TYPE = ApiVersionResponse_v0 + SCHEMA = Schema() + + +class ApiVersionRequest_v1(Request): + API_KEY = 18 + API_VERSION = 1 + RESPONSE_TYPE = ApiVersionResponse_v1 + SCHEMA = ApiVersionRequest_v0.SCHEMA + + +class ApiVersionRequest_v2(Request): + API_KEY = 18 + API_VERSION = 2 + RESPONSE_TYPE = ApiVersionResponse_v1 + SCHEMA = ApiVersionRequest_v0.SCHEMA + + +ApiVersionRequest = [ + ApiVersionRequest_v0, + ApiVersionRequest_v1, + ApiVersionRequest_v2, +] +ApiVersionResponse = [ + ApiVersionResponse_v0, + ApiVersionResponse_v1, + ApiVersionResponse_v2, +] + + +class CreateTopicsResponse_v0(Response): + API_KEY = 19 + API_VERSION = 0 + SCHEMA = Schema( + ("topic_errors", Array(("topic", String("utf-8")), ("error_code", Int16))) + ) + + +class CreateTopicsResponse_v1(Response): + API_KEY = 19 + API_VERSION = 1 + SCHEMA = Schema( + ( + "topic_errors", + Array( + ("topic", String("utf-8")), + ("error_code", Int16), + ("error_message", String("utf-8")), + ), + ) + ) + + +class CreateTopicsResponse_v2(Response): + API_KEY = 19 + API_VERSION = 2 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "topic_errors", + Array( + ("topic", String("utf-8")), + ("error_code", Int16), + ("error_message", String("utf-8")), + ), + ), + ) + + +class CreateTopicsResponse_v3(Response): + API_KEY = 19 + API_VERSION = 3 + SCHEMA = CreateTopicsResponse_v2.SCHEMA + + +class CreateTopicsRequest_v0(Request): + API_KEY = 19 + API_VERSION = 0 + RESPONSE_TYPE = CreateTopicsResponse_v0 + SCHEMA = Schema( + ( + "create_topic_requests", + Array( + ("topic", String("utf-8")), + ("num_partitions", Int32), + ("replication_factor", Int16), + ( + "replica_assignment", + Array(("partition_id", Int32), ("replicas", Array(Int32))), + ), + ( + "configs", + Array( + ("config_key", String("utf-8")), + ("config_value", String("utf-8")), + ), + ), + ), + ), + ("timeout", 
Int32), + ) + + +class CreateTopicsRequest_v1(Request): + API_KEY = 19 + API_VERSION = 1 + RESPONSE_TYPE = CreateTopicsResponse_v1 + SCHEMA = Schema( + ( + "create_topic_requests", + Array( + ("topic", String("utf-8")), + ("num_partitions", Int32), + ("replication_factor", Int16), + ( + "replica_assignment", + Array(("partition_id", Int32), ("replicas", Array(Int32))), + ), + ( + "configs", + Array( + ("config_key", String("utf-8")), + ("config_value", String("utf-8")), + ), + ), + ), + ), + ("timeout", Int32), + ("validate_only", Boolean), + ) + + +class CreateTopicsRequest_v2(Request): + API_KEY = 19 + API_VERSION = 2 + RESPONSE_TYPE = CreateTopicsResponse_v2 + SCHEMA = CreateTopicsRequest_v1.SCHEMA + + +class CreateTopicsRequest_v3(Request): + API_KEY = 19 + API_VERSION = 3 + RESPONSE_TYPE = CreateTopicsResponse_v3 + SCHEMA = CreateTopicsRequest_v1.SCHEMA + + +CreateTopicsRequest = [ + CreateTopicsRequest_v0, + CreateTopicsRequest_v1, + CreateTopicsRequest_v2, + CreateTopicsRequest_v3, +] +CreateTopicsResponse = [ + CreateTopicsResponse_v0, + CreateTopicsResponse_v1, + CreateTopicsResponse_v2, + CreateTopicsResponse_v3, +] + + +class DeleteTopicsResponse_v0(Response): + API_KEY = 20 + API_VERSION = 0 + SCHEMA = Schema( + ("topic_error_codes", Array(("topic", String("utf-8")), ("error_code", Int16))) + ) + + +class DeleteTopicsResponse_v1(Response): + API_KEY = 20 + API_VERSION = 1 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ("topic_error_codes", Array(("topic", String("utf-8")), ("error_code", Int16))), + ) + + +class DeleteTopicsResponse_v2(Response): + API_KEY = 20 + API_VERSION = 2 + SCHEMA = DeleteTopicsResponse_v1.SCHEMA + + +class DeleteTopicsResponse_v3(Response): + API_KEY = 20 + API_VERSION = 3 + SCHEMA = DeleteTopicsResponse_v1.SCHEMA + + +class DeleteTopicsRequest_v0(Request): + API_KEY = 20 + API_VERSION = 0 + RESPONSE_TYPE = DeleteTopicsResponse_v0 + SCHEMA = Schema(("topics", Array(String("utf-8"))), ("timeout", Int32)) + + +class DeleteTopicsRequest_v1(Request): + API_KEY = 20 + API_VERSION = 1 + RESPONSE_TYPE = DeleteTopicsResponse_v1 + SCHEMA = DeleteTopicsRequest_v0.SCHEMA + + +class DeleteTopicsRequest_v2(Request): + API_KEY = 20 + API_VERSION = 2 + RESPONSE_TYPE = DeleteTopicsResponse_v2 + SCHEMA = DeleteTopicsRequest_v0.SCHEMA + + +class DeleteTopicsRequest_v3(Request): + API_KEY = 20 + API_VERSION = 3 + RESPONSE_TYPE = DeleteTopicsResponse_v3 + SCHEMA = DeleteTopicsRequest_v0.SCHEMA + + +DeleteTopicsRequest = [ + DeleteTopicsRequest_v0, + DeleteTopicsRequest_v1, + DeleteTopicsRequest_v2, + DeleteTopicsRequest_v3, +] +DeleteTopicsResponse = [ + DeleteTopicsResponse_v0, + DeleteTopicsResponse_v1, + DeleteTopicsResponse_v2, + DeleteTopicsResponse_v3, +] + + +class ListGroupsResponse_v0(Response): + API_KEY = 16 + API_VERSION = 0 + SCHEMA = Schema( + ("error_code", Int16), + ( + "groups", + Array(("group", String("utf-8")), ("protocol_type", String("utf-8"))), + ), + ) + + +class ListGroupsResponse_v1(Response): + API_KEY = 16 + API_VERSION = 1 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ("error_code", Int16), + ( + "groups", + Array(("group", String("utf-8")), ("protocol_type", String("utf-8"))), + ), + ) + + +class ListGroupsResponse_v2(Response): + API_KEY = 16 + API_VERSION = 2 + SCHEMA = ListGroupsResponse_v1.SCHEMA + + +class ListGroupsRequest_v0(Request): + API_KEY = 16 + API_VERSION = 0 + RESPONSE_TYPE = ListGroupsResponse_v0 + SCHEMA = Schema() + + +class ListGroupsRequest_v1(Request): + API_KEY = 16 + API_VERSION = 1 + RESPONSE_TYPE = 
ListGroupsResponse_v1 + SCHEMA = ListGroupsRequest_v0.SCHEMA + + +class ListGroupsRequest_v2(Request): + API_KEY = 16 + API_VERSION = 1 + RESPONSE_TYPE = ListGroupsResponse_v2 + SCHEMA = ListGroupsRequest_v0.SCHEMA + + +ListGroupsRequest = [ + ListGroupsRequest_v0, + ListGroupsRequest_v1, + ListGroupsRequest_v2, +] +ListGroupsResponse = [ + ListGroupsResponse_v0, + ListGroupsResponse_v1, + ListGroupsResponse_v2, +] + + +class DescribeGroupsResponse_v0(Response): + API_KEY = 15 + API_VERSION = 0 + SCHEMA = Schema( + ( + "groups", + Array( + ("error_code", Int16), + ("group", String("utf-8")), + ("state", String("utf-8")), + ("protocol_type", String("utf-8")), + ("protocol", String("utf-8")), + ( + "members", + Array( + ("member_id", String("utf-8")), + ("client_id", String("utf-8")), + ("client_host", String("utf-8")), + ("member_metadata", Bytes), + ("member_assignment", Bytes), + ), + ), + ), + ) + ) + + +class DescribeGroupsResponse_v1(Response): + API_KEY = 15 + API_VERSION = 1 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "groups", + Array( + ("error_code", Int16), + ("group", String("utf-8")), + ("state", String("utf-8")), + ("protocol_type", String("utf-8")), + ("protocol", String("utf-8")), + ( + "members", + Array( + ("member_id", String("utf-8")), + ("client_id", String("utf-8")), + ("client_host", String("utf-8")), + ("member_metadata", Bytes), + ("member_assignment", Bytes), + ), + ), + ), + ), + ) + + +class DescribeGroupsResponse_v2(Response): + API_KEY = 15 + API_VERSION = 2 + SCHEMA = DescribeGroupsResponse_v1.SCHEMA + + +class DescribeGroupsResponse_v3(Response): + API_KEY = 15 + API_VERSION = 3 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "groups", + Array( + ("error_code", Int16), + ("group", String("utf-8")), + ("state", String("utf-8")), + ("protocol_type", String("utf-8")), + ("protocol", String("utf-8")), + ( + "members", + Array( + ("member_id", String("utf-8")), + ("client_id", String("utf-8")), + ("client_host", String("utf-8")), + ("member_metadata", Bytes), + ("member_assignment", Bytes), + ), + ), + ), + ("authorized_operations", Int32), + ), + ) + + +class DescribeGroupsRequest_v0(Request): + API_KEY = 15 + API_VERSION = 0 + RESPONSE_TYPE = DescribeGroupsResponse_v0 + SCHEMA = Schema(("groups", Array(String("utf-8")))) + + +class DescribeGroupsRequest_v1(Request): + API_KEY = 15 + API_VERSION = 1 + RESPONSE_TYPE = DescribeGroupsResponse_v1 + SCHEMA = DescribeGroupsRequest_v0.SCHEMA + + +class DescribeGroupsRequest_v2(Request): + API_KEY = 15 + API_VERSION = 2 + RESPONSE_TYPE = DescribeGroupsResponse_v2 + SCHEMA = DescribeGroupsRequest_v0.SCHEMA + + +class DescribeGroupsRequest_v3(Request): + API_KEY = 15 + API_VERSION = 3 + RESPONSE_TYPE = DescribeGroupsResponse_v2 + SCHEMA = Schema( + ("groups", Array(String("utf-8"))), ("include_authorized_operations", Boolean) + ) + + +DescribeGroupsRequest = [ + DescribeGroupsRequest_v0, + DescribeGroupsRequest_v1, + DescribeGroupsRequest_v2, + DescribeGroupsRequest_v3, +] +DescribeGroupsResponse = [ + DescribeGroupsResponse_v0, + DescribeGroupsResponse_v1, + DescribeGroupsResponse_v2, + DescribeGroupsResponse_v3, +] + + +class SaslHandShakeResponse_v0(Response): + API_KEY = 17 + API_VERSION = 0 + SCHEMA = Schema( + ("error_code", Int16), ("enabled_mechanisms", Array(String("utf-8"))) + ) + + +class SaslHandShakeResponse_v1(Response): + API_KEY = 17 + API_VERSION = 1 + SCHEMA = SaslHandShakeResponse_v0.SCHEMA + + +class SaslHandShakeRequest_v0(Request): + API_KEY = 17 + API_VERSION = 0 + RESPONSE_TYPE = 
SaslHandShakeResponse_v0 + SCHEMA = Schema(("mechanism", String("utf-8"))) + + +class SaslHandShakeRequest_v1(Request): + API_KEY = 17 + API_VERSION = 1 + RESPONSE_TYPE = SaslHandShakeResponse_v1 + SCHEMA = SaslHandShakeRequest_v0.SCHEMA + + +SaslHandShakeRequest = [SaslHandShakeRequest_v0, SaslHandShakeRequest_v1] +SaslHandShakeResponse = [SaslHandShakeResponse_v0, SaslHandShakeResponse_v1] + + +class DescribeAclsResponse_v0(Response): + API_KEY = 29 + API_VERSION = 0 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ("error_code", Int16), + ("error_message", String("utf-8")), + ( + "resources", + Array( + ("resource_type", Int8), + ("resource_name", String("utf-8")), + ( + "acls", + Array( + ("principal", String("utf-8")), + ("host", String("utf-8")), + ("operation", Int8), + ("permission_type", Int8), + ), + ), + ), + ), + ) + + +class DescribeAclsResponse_v1(Response): + API_KEY = 29 + API_VERSION = 1 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ("error_code", Int16), + ("error_message", String("utf-8")), + ( + "resources", + Array( + ("resource_type", Int8), + ("resource_name", String("utf-8")), + ("resource_pattern_type", Int8), + ( + "acls", + Array( + ("principal", String("utf-8")), + ("host", String("utf-8")), + ("operation", Int8), + ("permission_type", Int8), + ), + ), + ), + ), + ) + + +class DescribeAclsResponse_v2(Response): + API_KEY = 29 + API_VERSION = 2 + SCHEMA = DescribeAclsResponse_v1.SCHEMA + + +class DescribeAclsRequest_v0(Request): + API_KEY = 29 + API_VERSION = 0 + RESPONSE_TYPE = DescribeAclsResponse_v0 + SCHEMA = Schema( + ("resource_type", Int8), + ("resource_name", String("utf-8")), + ("principal", String("utf-8")), + ("host", String("utf-8")), + ("operation", Int8), + ("permission_type", Int8), + ) + + +class DescribeAclsRequest_v1(Request): + API_KEY = 29 + API_VERSION = 1 + RESPONSE_TYPE = DescribeAclsResponse_v1 + SCHEMA = Schema( + ("resource_type", Int8), + ("resource_name", String("utf-8")), + ("resource_pattern_type_filter", Int8), + ("principal", String("utf-8")), + ("host", String("utf-8")), + ("operation", Int8), + ("permission_type", Int8), + ) + + +class DescribeAclsRequest_v2(Request): + """ + Enable flexible version + """ + + API_KEY = 29 + API_VERSION = 2 + RESPONSE_TYPE = DescribeAclsResponse_v2 + SCHEMA = DescribeAclsRequest_v1.SCHEMA + + +DescribeAclsRequest = [DescribeAclsRequest_v0, DescribeAclsRequest_v1] +DescribeAclsResponse = [DescribeAclsResponse_v0, DescribeAclsResponse_v1] + + +class CreateAclsResponse_v0(Response): + API_KEY = 30 + API_VERSION = 0 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "creation_responses", + Array(("error_code", Int16), ("error_message", String("utf-8"))), + ), + ) + + +class CreateAclsResponse_v1(Response): + API_KEY = 30 + API_VERSION = 1 + SCHEMA = CreateAclsResponse_v0.SCHEMA + + +class CreateAclsRequest_v0(Request): + API_KEY = 30 + API_VERSION = 0 + RESPONSE_TYPE = CreateAclsResponse_v0 + SCHEMA = Schema( + ( + "creations", + Array( + ("resource_type", Int8), + ("resource_name", String("utf-8")), + ("principal", String("utf-8")), + ("host", String("utf-8")), + ("operation", Int8), + ("permission_type", Int8), + ), + ) + ) + + +class CreateAclsRequest_v1(Request): + API_KEY = 30 + API_VERSION = 1 + RESPONSE_TYPE = CreateAclsResponse_v1 + SCHEMA = Schema( + ( + "creations", + Array( + ("resource_type", Int8), + ("resource_name", String("utf-8")), + ("resource_pattern_type", Int8), + ("principal", String("utf-8")), + ("host", String("utf-8")), + ("operation", Int8), + 
("permission_type", Int8), + ), + ) + ) + + +CreateAclsRequest = [CreateAclsRequest_v0, CreateAclsRequest_v1] +CreateAclsResponse = [CreateAclsResponse_v0, CreateAclsResponse_v1] + + +class DeleteAclsResponse_v0(Response): + API_KEY = 31 + API_VERSION = 0 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "filter_responses", + Array( + ("error_code", Int16), + ("error_message", String("utf-8")), + ( + "matching_acls", + Array( + ("error_code", Int16), + ("error_message", String("utf-8")), + ("resource_type", Int8), + ("resource_name", String("utf-8")), + ("principal", String("utf-8")), + ("host", String("utf-8")), + ("operation", Int8), + ("permission_type", Int8), + ), + ), + ), + ), + ) + + +class DeleteAclsResponse_v1(Response): + API_KEY = 31 + API_VERSION = 1 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "filter_responses", + Array( + ("error_code", Int16), + ("error_message", String("utf-8")), + ( + "matching_acls", + Array( + ("error_code", Int16), + ("error_message", String("utf-8")), + ("resource_type", Int8), + ("resource_name", String("utf-8")), + ("resource_pattern_type", Int8), + ("principal", String("utf-8")), + ("host", String("utf-8")), + ("operation", Int8), + ("permission_type", Int8), + ), + ), + ), + ), + ) + + +class DeleteAclsRequest_v0(Request): + API_KEY = 31 + API_VERSION = 0 + RESPONSE_TYPE = DeleteAclsResponse_v0 + SCHEMA = Schema( + ( + "filters", + Array( + ("resource_type", Int8), + ("resource_name", String("utf-8")), + ("principal", String("utf-8")), + ("host", String("utf-8")), + ("operation", Int8), + ("permission_type", Int8), + ), + ) + ) + + +class DeleteAclsRequest_v1(Request): + API_KEY = 31 + API_VERSION = 1 + RESPONSE_TYPE = DeleteAclsResponse_v1 + SCHEMA = Schema( + ( + "filters", + Array( + ("resource_type", Int8), + ("resource_name", String("utf-8")), + ("resource_pattern_type_filter", Int8), + ("principal", String("utf-8")), + ("host", String("utf-8")), + ("operation", Int8), + ("permission_type", Int8), + ), + ) + ) + + +DeleteAclsRequest = [DeleteAclsRequest_v0, DeleteAclsRequest_v1] +DeleteAclsResponse = [DeleteAclsResponse_v0, DeleteAclsResponse_v1] + + +class AlterConfigsResponse_v0(Response): + API_KEY = 33 + API_VERSION = 0 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "resources", + Array( + ("error_code", Int16), + ("error_message", String("utf-8")), + ("resource_type", Int8), + ("resource_name", String("utf-8")), + ), + ), + ) + + +class AlterConfigsResponse_v1(Response): + API_KEY = 33 + API_VERSION = 1 + SCHEMA = AlterConfigsResponse_v0.SCHEMA + + +class AlterConfigsRequest_v0(Request): + API_KEY = 33 + API_VERSION = 0 + RESPONSE_TYPE = AlterConfigsResponse_v0 + SCHEMA = Schema( + ( + "resources", + Array( + ("resource_type", Int8), + ("resource_name", String("utf-8")), + ( + "config_entries", + Array( + ("config_name", String("utf-8")), + ("config_value", String("utf-8")), + ), + ), + ), + ), + ("validate_only", Boolean), + ) + + +class AlterConfigsRequest_v1(Request): + API_KEY = 33 + API_VERSION = 1 + RESPONSE_TYPE = AlterConfigsResponse_v1 + SCHEMA = AlterConfigsRequest_v0.SCHEMA + + +AlterConfigsRequest = [AlterConfigsRequest_v0, AlterConfigsRequest_v1] +AlterConfigsResponse = [AlterConfigsResponse_v0, AlterConfigsRequest_v1] + + +class DescribeConfigsResponse_v0(Response): + API_KEY = 32 + API_VERSION = 0 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "resources", + Array( + ("error_code", Int16), + ("error_message", String("utf-8")), + ("resource_type", Int8), + ("resource_name", 
String("utf-8")), + ( + "config_entries", + Array( + ("config_names", String("utf-8")), + ("config_value", String("utf-8")), + ("read_only", Boolean), + ("is_default", Boolean), + ("is_sensitive", Boolean), + ), + ), + ), + ), + ) + + +class DescribeConfigsResponse_v1(Response): + API_KEY = 32 + API_VERSION = 1 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "resources", + Array( + ("error_code", Int16), + ("error_message", String("utf-8")), + ("resource_type", Int8), + ("resource_name", String("utf-8")), + ( + "config_entries", + Array( + ("config_names", String("utf-8")), + ("config_value", String("utf-8")), + ("read_only", Boolean), + ("is_default", Boolean), + ("is_sensitive", Boolean), + ( + "config_synonyms", + Array( + ("config_name", String("utf-8")), + ("config_value", String("utf-8")), + ("config_source", Int8), + ), + ), + ), + ), + ), + ), + ) + + +class DescribeConfigsResponse_v2(Response): + API_KEY = 32 + API_VERSION = 2 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "resources", + Array( + ("error_code", Int16), + ("error_message", String("utf-8")), + ("resource_type", Int8), + ("resource_name", String("utf-8")), + ( + "config_entries", + Array( + ("config_names", String("utf-8")), + ("config_value", String("utf-8")), + ("read_only", Boolean), + ("config_source", Int8), + ("is_sensitive", Boolean), + ( + "config_synonyms", + Array( + ("config_name", String("utf-8")), + ("config_value", String("utf-8")), + ("config_source", Int8), + ), + ), + ), + ), + ), + ), + ) + + +class DescribeConfigsRequest_v0(Request): + API_KEY = 32 + API_VERSION = 0 + RESPONSE_TYPE = DescribeConfigsResponse_v0 + SCHEMA = Schema( + ( + "resources", + Array( + ("resource_type", Int8), + ("resource_name", String("utf-8")), + ("config_names", Array(String("utf-8"))), + ), + ) + ) + + +class DescribeConfigsRequest_v1(Request): + API_KEY = 32 + API_VERSION = 1 + RESPONSE_TYPE = DescribeConfigsResponse_v1 + SCHEMA = Schema( + ( + "resources", + Array( + ("resource_type", Int8), + ("resource_name", String("utf-8")), + ("config_names", Array(String("utf-8"))), + ), + ), + ("include_synonyms", Boolean), + ) + + +class DescribeConfigsRequest_v2(Request): + API_KEY = 32 + API_VERSION = 2 + RESPONSE_TYPE = DescribeConfigsResponse_v2 + SCHEMA = DescribeConfigsRequest_v1.SCHEMA + + +DescribeConfigsRequest = [ + DescribeConfigsRequest_v0, + DescribeConfigsRequest_v1, + DescribeConfigsRequest_v2, +] +DescribeConfigsResponse = [ + DescribeConfigsResponse_v0, + DescribeConfigsResponse_v1, + DescribeConfigsResponse_v2, +] + + +class SaslAuthenticateResponse_v0(Response): + API_KEY = 36 + API_VERSION = 0 + SCHEMA = Schema( + ("error_code", Int16), + ("error_message", String("utf-8")), + ("sasl_auth_bytes", Bytes), + ) + + +class SaslAuthenticateResponse_v1(Response): + API_KEY = 36 + API_VERSION = 1 + SCHEMA = Schema( + ("error_code", Int16), + ("error_message", String("utf-8")), + ("sasl_auth_bytes", Bytes), + ("session_lifetime_ms", Int64), + ) + + +class SaslAuthenticateRequest_v0(Request): + API_KEY = 36 + API_VERSION = 0 + RESPONSE_TYPE = SaslAuthenticateResponse_v0 + SCHEMA = Schema(("sasl_auth_bytes", Bytes)) + + +class SaslAuthenticateRequest_v1(Request): + API_KEY = 36 + API_VERSION = 1 + RESPONSE_TYPE = SaslAuthenticateResponse_v1 + SCHEMA = SaslAuthenticateRequest_v0.SCHEMA + + +SaslAuthenticateRequest = [ + SaslAuthenticateRequest_v0, + SaslAuthenticateRequest_v1, +] +SaslAuthenticateResponse = [ + SaslAuthenticateResponse_v0, + SaslAuthenticateResponse_v1, +] + + +class 
CreatePartitionsResponse_v0(Response): + API_KEY = 37 + API_VERSION = 0 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "topic_errors", + Array( + ("topic", String("utf-8")), + ("error_code", Int16), + ("error_message", String("utf-8")), + ), + ), + ) + + +class CreatePartitionsResponse_v1(Response): + API_KEY = 37 + API_VERSION = 1 + SCHEMA = CreatePartitionsResponse_v0.SCHEMA + + +class CreatePartitionsRequest_v0(Request): + API_KEY = 37 + API_VERSION = 0 + RESPONSE_TYPE = CreatePartitionsResponse_v0 + SCHEMA = Schema( + ( + "topic_partitions", + Array( + ("topic", String("utf-8")), + ( + "new_partitions", + Schema(("count", Int32), ("assignment", Array(Array(Int32)))), + ), + ), + ), + ("timeout", Int32), + ("validate_only", Boolean), + ) + + +class CreatePartitionsRequest_v1(Request): + API_KEY = 37 + API_VERSION = 1 + SCHEMA = CreatePartitionsRequest_v0.SCHEMA + RESPONSE_TYPE = CreatePartitionsResponse_v1 + + +CreatePartitionsRequest = [ + CreatePartitionsRequest_v0, + CreatePartitionsRequest_v1, +] +CreatePartitionsResponse = [ + CreatePartitionsResponse_v0, + CreatePartitionsResponse_v1, +] + + +class DeleteGroupsResponse_v0(Response): + API_KEY = 42 + API_VERSION = 0 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ("results", Array(("group_id", String("utf-8")), ("error_code", Int16))), + ) + + +class DeleteGroupsResponse_v1(Response): + API_KEY = 42 + API_VERSION = 1 + SCHEMA = DeleteGroupsResponse_v0.SCHEMA + + +class DeleteGroupsRequest_v0(Request): + API_KEY = 42 + API_VERSION = 0 + RESPONSE_TYPE = DeleteGroupsResponse_v0 + SCHEMA = Schema(("groups_names", Array(String("utf-8")))) + + +class DeleteGroupsRequest_v1(Request): + API_KEY = 42 + API_VERSION = 1 + RESPONSE_TYPE = DeleteGroupsResponse_v1 + SCHEMA = DeleteGroupsRequest_v0.SCHEMA + + +DeleteGroupsRequest = [DeleteGroupsRequest_v0, DeleteGroupsRequest_v1] + +DeleteGroupsResponse = [DeleteGroupsResponse_v0, DeleteGroupsResponse_v1] + + +class DescribeClientQuotasResponse_v0(Request): + API_KEY = 48 + API_VERSION = 0 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ("error_code", Int16), + ("error_message", String("utf-8")), + ( + "entries", + Array( + ( + "entity", + Array( + ("entity_type", String("utf-8")), + ("entity_name", String("utf-8")), + ), + ), + ("values", Array(("name", String("utf-8")), ("value", Float64))), + ), + ), + ) + + +class DescribeClientQuotasRequest_v0(Request): + API_KEY = 48 + API_VERSION = 0 + RESPONSE_TYPE = DescribeClientQuotasResponse_v0 + SCHEMA = Schema( + ( + "components", + Array( + ("entity_type", String("utf-8")), + ("match_type", Int8), + ("match", String("utf-8")), + ), + ), + ("strict", Boolean), + ) + + +DescribeClientQuotasRequest = [ + DescribeClientQuotasRequest_v0, +] + +DescribeClientQuotasResponse = [ + DescribeClientQuotasResponse_v0, +] + + +class AlterPartitionReassignmentsResponse_v0(Response): + API_KEY = 45 + API_VERSION = 0 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ("error_code", Int16), + ("error_message", CompactString("utf-8")), + ( + "responses", + CompactArray( + ("name", CompactString("utf-8")), + ( + "partitions", + CompactArray( + ("partition_index", Int32), + ("error_code", Int16), + ("error_message", CompactString("utf-8")), + ("tags", TaggedFields), + ), + ), + ("tags", TaggedFields), + ), + ), + ("tags", TaggedFields), + ) + + +class AlterPartitionReassignmentsRequest_v0(Request): + FLEXIBLE_VERSION = True + API_KEY = 45 + API_VERSION = 0 + RESPONSE_TYPE = AlterPartitionReassignmentsResponse_v0 + SCHEMA = Schema( + ("timeout_ms", 
Int32), + ( + "topics", + CompactArray( + ("name", CompactString("utf-8")), + ( + "partitions", + CompactArray( + ("partition_index", Int32), + ("replicas", CompactArray(Int32)), + ("tags", TaggedFields), + ), + ), + ("tags", TaggedFields), + ), + ), + ("tags", TaggedFields), + ) + + +AlterPartitionReassignmentsRequest = [AlterPartitionReassignmentsRequest_v0] + +AlterPartitionReassignmentsResponse = [AlterPartitionReassignmentsResponse_v0] + + +class ListPartitionReassignmentsResponse_v0(Response): + API_KEY = 46 + API_VERSION = 0 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ("error_code", Int16), + ("error_message", CompactString("utf-8")), + ( + "topics", + CompactArray( + ("name", CompactString("utf-8")), + ( + "partitions", + CompactArray( + ("partition_index", Int32), + ("replicas", CompactArray(Int32)), + ("adding_replicas", CompactArray(Int32)), + ("removing_replicas", CompactArray(Int32)), + ("tags", TaggedFields), + ), + ), + ("tags", TaggedFields), + ), + ), + ("tags", TaggedFields), + ) + + +class ListPartitionReassignmentsRequest_v0(Request): + FLEXIBLE_VERSION = True + API_KEY = 46 + API_VERSION = 0 + RESPONSE_TYPE = ListPartitionReassignmentsResponse_v0 + SCHEMA = Schema( + ("timeout_ms", Int32), + ( + "topics", + CompactArray( + ("name", CompactString("utf-8")), + ("partition_index", CompactArray(Int32)), + ("tags", TaggedFields), + ), + ), + ("tags", TaggedFields), + ) + + +ListPartitionReassignmentsRequest = [ListPartitionReassignmentsRequest_v0] + +ListPartitionReassignmentsResponse = [ListPartitionReassignmentsResponse_v0] diff --git a/kafka/protocol/api.py b/aiokafka/protocol/api.py similarity index 74% rename from kafka/protocol/api.py rename to aiokafka/protocol/api.py index f12cb972..9eb2b6fe 100644 --- a/kafka/protocol/api.py +++ b/aiokafka/protocol/api.py @@ -1,20 +1,18 @@ -from __future__ import absolute_import - import abc -from kafka.protocol.struct import Struct -from kafka.protocol.types import Int16, Int32, String, Schema, Array, TaggedFields +from .struct import Struct +from .types import Int16, Int32, String, Schema, Array, TaggedFields class RequestHeader(Struct): SCHEMA = Schema( - ('api_key', Int16), - ('api_version', Int16), - ('correlation_id', Int32), - ('client_id', String('utf-8')) + ("api_key", Int16), + ("api_version", Int16), + ("correlation_id", Int32), + ("client_id", String("utf-8")), ) - def __init__(self, request, correlation_id=0, client_id='kafka-python'): + def __init__(self, request, correlation_id=0, client_id="kafka-python"): super(RequestHeader, self).__init__( request.API_KEY, request.API_VERSION, correlation_id, client_id ) @@ -23,14 +21,14 @@ def __init__(self, request, correlation_id=0, client_id='kafka-python'): class RequestHeaderV2(Struct): # Flexible response / request headers end in field buffer SCHEMA = Schema( - ('api_key', Int16), - ('api_version', Int16), - ('correlation_id', Int32), - ('client_id', String('utf-8')), - ('tags', TaggedFields), + ("api_key", Int16), + ("api_version", Int16), + ("correlation_id", Int32), + ("client_id", String("utf-8")), + ("tags", TaggedFields), ) - def __init__(self, request, correlation_id=0, client_id='kafka-python', tags=None): + def __init__(self, request, correlation_id=0, client_id="kafka-python", tags=None): super(RequestHeaderV2, self).__init__( request.API_KEY, request.API_VERSION, correlation_id, client_id, tags or {} ) @@ -38,14 +36,14 @@ def __init__(self, request, correlation_id=0, client_id='kafka-python', tags=Non class ResponseHeader(Struct): SCHEMA = Schema( - 
('correlation_id', Int32), + ("correlation_id", Int32), ) class ResponseHeaderV2(Struct): SCHEMA = Schema( - ('correlation_id', Int32), - ('tags', TaggedFields), + ("correlation_id", Int32), + ("tags", TaggedFields), ) @@ -83,7 +81,9 @@ def to_object(self): def build_request_header(self, correlation_id, client_id): if self.FLEXIBLE_VERSION: - return RequestHeaderV2(self, correlation_id=correlation_id, client_id=client_id) + return RequestHeaderV2( + self, correlation_id=correlation_id, client_id=client_id + ) return RequestHeader(self, correlation_id=correlation_id, client_id=client_id) def parse_response_header(self, read_buffer): @@ -126,10 +126,7 @@ def _to_object(schema, data): obj[name] = _to_object(_type, val) elif isinstance(_type, Array): if isinstance(_type.array_of, (Array, Schema)): - obj[name] = [ - _to_object(_type.array_of, x) - for x in val - ] + obj[name] = [_to_object(_type.array_of, x) for x in val] else: obj[name] = val else: diff --git a/aiokafka/protocol/commit.py b/aiokafka/protocol/commit.py new file mode 100644 index 00000000..81185397 --- /dev/null +++ b/aiokafka/protocol/commit.py @@ -0,0 +1,306 @@ +from .api import Request, Response +from .types import Array, Int8, Int16, Int32, Int64, Schema, String + + +class OffsetCommitResponse_v0(Response): + API_KEY = 8 + API_VERSION = 0 + SCHEMA = Schema( + ( + "topics", + Array( + ("topic", String("utf-8")), + ("partitions", Array(("partition", Int32), ("error_code", Int16))), + ), + ) + ) + + +class OffsetCommitResponse_v1(Response): + API_KEY = 8 + API_VERSION = 1 + SCHEMA = OffsetCommitResponse_v0.SCHEMA + + +class OffsetCommitResponse_v2(Response): + API_KEY = 8 + API_VERSION = 2 + SCHEMA = OffsetCommitResponse_v1.SCHEMA + + +class OffsetCommitResponse_v3(Response): + API_KEY = 8 + API_VERSION = 3 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "topics", + Array( + ("topic", String("utf-8")), + ("partitions", Array(("partition", Int32), ("error_code", Int16))), + ), + ), + ) + + +class OffsetCommitRequest_v0(Request): + API_KEY = 8 + API_VERSION = 0 # Zookeeper-backed storage + RESPONSE_TYPE = OffsetCommitResponse_v0 + SCHEMA = Schema( + ("consumer_group", String("utf-8")), + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("offset", Int64), + ("metadata", String("utf-8")), + ), + ), + ), + ), + ) + + +class OffsetCommitRequest_v1(Request): + API_KEY = 8 + API_VERSION = 1 # Kafka-backed storage + RESPONSE_TYPE = OffsetCommitResponse_v1 + SCHEMA = Schema( + ("consumer_group", String("utf-8")), + ("consumer_group_generation_id", Int32), + ("consumer_id", String("utf-8")), + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("offset", Int64), + ("timestamp", Int64), + ("metadata", String("utf-8")), + ), + ), + ), + ), + ) + + +class OffsetCommitRequest_v2(Request): + API_KEY = 8 + API_VERSION = 2 # added retention_time, dropped timestamp + RESPONSE_TYPE = OffsetCommitResponse_v2 + SCHEMA = Schema( + ("consumer_group", String("utf-8")), + ("consumer_group_generation_id", Int32), + ("consumer_id", String("utf-8")), + ("retention_time", Int64), + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("offset", Int64), + ("metadata", String("utf-8")), + ), + ), + ), + ), + ) + DEFAULT_GENERATION_ID = -1 + DEFAULT_RETENTION_TIME = -1 + + +class OffsetCommitRequest_v3(Request): + API_KEY = 8 + API_VERSION = 3 + RESPONSE_TYPE = OffsetCommitResponse_v3 
+    SCHEMA = OffsetCommitRequest_v2.SCHEMA
+
+
+OffsetCommitRequest = [
+    OffsetCommitRequest_v0,
+    OffsetCommitRequest_v1,
+    OffsetCommitRequest_v2,
+    OffsetCommitRequest_v3,
+]
+OffsetCommitResponse = [
+    OffsetCommitResponse_v0,
+    OffsetCommitResponse_v1,
+    OffsetCommitResponse_v2,
+    OffsetCommitResponse_v3,
+]
+
+
+class OffsetFetchResponse_v0(Response):
+    API_KEY = 9
+    API_VERSION = 0
+    SCHEMA = Schema(
+        (
+            "topics",
+            Array(
+                ("topic", String("utf-8")),
+                (
+                    "partitions",
+                    Array(
+                        ("partition", Int32),
+                        ("offset", Int64),
+                        ("metadata", String("utf-8")),
+                        ("error_code", Int16),
+                    ),
+                ),
+            ),
+        )
+    )
+
+
+class OffsetFetchResponse_v1(Response):
+    API_KEY = 9
+    API_VERSION = 1
+    SCHEMA = OffsetFetchResponse_v0.SCHEMA
+
+
+class OffsetFetchResponse_v2(Response):
+    # Added in KIP-88
+    API_KEY = 9
+    API_VERSION = 2
+    SCHEMA = Schema(
+        (
+            "topics",
+            Array(
+                ("topic", String("utf-8")),
+                (
+                    "partitions",
+                    Array(
+                        ("partition", Int32),
+                        ("offset", Int64),
+                        ("metadata", String("utf-8")),
+                        ("error_code", Int16),
+                    ),
+                ),
+            ),
+        ),
+        ("error_code", Int16),
+    )
+
+
+class OffsetFetchResponse_v3(Response):
+    API_KEY = 9
+    API_VERSION = 3
+    SCHEMA = Schema(
+        ("throttle_time_ms", Int32),
+        (
+            "topics",
+            Array(
+                ("topic", String("utf-8")),
+                (
+                    "partitions",
+                    Array(
+                        ("partition", Int32),
+                        ("offset", Int64),
+                        ("metadata", String("utf-8")),
+                        ("error_code", Int16),
+                    ),
+                ),
+            ),
+        ),
+        ("error_code", Int16),
+    )
+
+
+class OffsetFetchRequest_v0(Request):
+    API_KEY = 9
+    API_VERSION = 0  # zookeeper-backed storage
+    RESPONSE_TYPE = OffsetFetchResponse_v0
+    SCHEMA = Schema(
+        ("consumer_group", String("utf-8")),
+        ("topics", Array(("topic", String("utf-8")), ("partitions", Array(Int32)))),
+    )
+
+
+class OffsetFetchRequest_v1(Request):
+    API_KEY = 9
+    API_VERSION = 1  # kafka-backed storage
+    RESPONSE_TYPE = OffsetFetchResponse_v1
+    SCHEMA = OffsetFetchRequest_v0.SCHEMA
+
+
+class OffsetFetchRequest_v2(Request):
+    # KIP-88: Allows passing null topics to return offsets for all partitions
+    # that the consumer group has a stored offset for, even if no consumer in
+    # the group is currently consuming that partition.
+ API_KEY = 9 + API_VERSION = 2 + RESPONSE_TYPE = OffsetFetchResponse_v2 + SCHEMA = OffsetFetchRequest_v1.SCHEMA + + +class OffsetFetchRequest_v3(Request): + API_KEY = 9 + API_VERSION = 3 + RESPONSE_TYPE = OffsetFetchResponse_v3 + SCHEMA = OffsetFetchRequest_v2.SCHEMA + + +OffsetFetchRequest = [ + OffsetFetchRequest_v0, + OffsetFetchRequest_v1, + OffsetFetchRequest_v2, + OffsetFetchRequest_v3, +] +OffsetFetchResponse = [ + OffsetFetchResponse_v0, + OffsetFetchResponse_v1, + OffsetFetchResponse_v2, + OffsetFetchResponse_v3, +] + + +class GroupCoordinatorResponse_v0(Response): + API_KEY = 10 + API_VERSION = 0 + SCHEMA = Schema( + ("error_code", Int16), + ("coordinator_id", Int32), + ("host", String("utf-8")), + ("port", Int32), + ) + + +class GroupCoordinatorResponse_v1(Response): + API_KEY = 10 + API_VERSION = 1 + SCHEMA = Schema( + ("error_code", Int16), + ("error_message", String("utf-8")), + ("coordinator_id", Int32), + ("host", String("utf-8")), + ("port", Int32), + ) + + +class GroupCoordinatorRequest_v0(Request): + API_KEY = 10 + API_VERSION = 0 + RESPONSE_TYPE = GroupCoordinatorResponse_v0 + SCHEMA = Schema(("consumer_group", String("utf-8"))) + + +class GroupCoordinatorRequest_v1(Request): + API_KEY = 10 + API_VERSION = 1 + RESPONSE_TYPE = GroupCoordinatorResponse_v1 + SCHEMA = Schema(("coordinator_key", String("utf-8")), ("coordinator_type", Int8)) + + +GroupCoordinatorRequest = [GroupCoordinatorRequest_v0, GroupCoordinatorRequest_v1] +GroupCoordinatorResponse = [GroupCoordinatorResponse_v0, GroupCoordinatorResponse_v1] diff --git a/aiokafka/protocol/coordination.py b/aiokafka/protocol/coordination.py index 1690b9ae..9bf086ac 100644 --- a/aiokafka/protocol/coordination.py +++ b/aiokafka/protocol/coordination.py @@ -1,15 +1,15 @@ -from kafka.protocol.api import Request, Response -from kafka.protocol.types import Int8, Int16, Int32, Schema, String +from .api import Request, Response +from .types import Int8, Int16, Int32, Schema, String class FindCoordinatorResponse_v0(Response): API_KEY = 10 API_VERSION = 0 SCHEMA = Schema( - ('error_code', Int16), - ('coordinator_id', Int32), - ('host', String('utf-8')), - ('port', Int32) + ("error_code", Int16), + ("coordinator_id", Int32), + ("host", String("utf-8")), + ("port", Int32), ) @@ -17,12 +17,12 @@ class FindCoordinatorResponse_v1(Response): API_KEY = 10 API_VERSION = 1 SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16), - ('error_message', String('utf-8')), - ('coordinator_id', Int32), - ('host', String('utf-8')), - ('port', Int32) + ("throttle_time_ms", Int32), + ("error_code", Int16), + ("error_message", String("utf-8")), + ("coordinator_id", Int32), + ("host", String("utf-8")), + ("port", Int32), ) @@ -30,22 +30,15 @@ class FindCoordinatorRequest_v0(Request): API_KEY = 10 API_VERSION = 0 RESPONSE_TYPE = FindCoordinatorResponse_v0 - SCHEMA = Schema( - ('consumer_group', String('utf-8')) - ) + SCHEMA = Schema(("consumer_group", String("utf-8"))) class FindCoordinatorRequest_v1(Request): API_KEY = 10 API_VERSION = 1 RESPONSE_TYPE = FindCoordinatorResponse_v1 - SCHEMA = Schema( - ('coordinator_key', String('utf-8')), - ('coordinator_type', Int8) - ) + SCHEMA = Schema(("coordinator_key", String("utf-8")), ("coordinator_type", Int8)) -FindCoordinatorRequest = [ - FindCoordinatorRequest_v0, FindCoordinatorRequest_v1] -FindCoordinatorResponse = [ - FindCoordinatorResponse_v0, FindCoordinatorResponse_v1] +FindCoordinatorRequest = [FindCoordinatorRequest_v0, FindCoordinatorRequest_v1] +FindCoordinatorResponse = 
[FindCoordinatorResponse_v0, FindCoordinatorResponse_v1] diff --git a/aiokafka/protocol/fetch.py b/aiokafka/protocol/fetch.py new file mode 100644 index 00000000..be5518e2 --- /dev/null +++ b/aiokafka/protocol/fetch.py @@ -0,0 +1,516 @@ +from .api import Request, Response +from .types import Array, Int8, Int16, Int32, Int64, Schema, String, Bytes + + +class FetchResponse_v0(Response): + API_KEY = 1 + API_VERSION = 0 + SCHEMA = Schema( + ( + "topics", + Array( + ("topics", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("error_code", Int16), + ("highwater_offset", Int64), + ("message_set", Bytes), + ), + ), + ), + ) + ) + + +class FetchResponse_v1(Response): + API_KEY = 1 + API_VERSION = 1 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "topics", + Array( + ("topics", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("error_code", Int16), + ("highwater_offset", Int64), + ("message_set", Bytes), + ), + ), + ), + ), + ) + + +class FetchResponse_v2(Response): + API_KEY = 1 + API_VERSION = 2 + SCHEMA = FetchResponse_v1.SCHEMA # message format changed internally + + +class FetchResponse_v3(Response): + API_KEY = 1 + API_VERSION = 3 + SCHEMA = FetchResponse_v2.SCHEMA + + +class FetchResponse_v4(Response): + API_KEY = 1 + API_VERSION = 4 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "topics", + Array( + ("topics", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("error_code", Int16), + ("highwater_offset", Int64), + ("last_stable_offset", Int64), + ( + "aborted_transactions", + Array(("producer_id", Int64), ("first_offset", Int64)), + ), + ("message_set", Bytes), + ), + ), + ), + ), + ) + + +class FetchResponse_v5(Response): + API_KEY = 1 + API_VERSION = 5 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "topics", + Array( + ("topics", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("error_code", Int16), + ("highwater_offset", Int64), + ("last_stable_offset", Int64), + ("log_start_offset", Int64), + ( + "aborted_transactions", + Array(("producer_id", Int64), ("first_offset", Int64)), + ), + ("message_set", Bytes), + ), + ), + ), + ), + ) + + +class FetchResponse_v6(Response): + """ + Same as FetchResponse_v5. The version number is bumped up to indicate that the + client supports KafkaStorageException. 
The KafkaStorageException will be translated + to NotLeaderForPartitionException in the response if version <= 5 + """ + + API_KEY = 1 + API_VERSION = 6 + SCHEMA = FetchResponse_v5.SCHEMA + + +class FetchResponse_v7(Response): + """ + Add error_code and session_id to response + """ + + API_KEY = 1 + API_VERSION = 7 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ("error_code", Int16), + ("session_id", Int32), + ( + "topics", + Array( + ("topics", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("error_code", Int16), + ("highwater_offset", Int64), + ("last_stable_offset", Int64), + ("log_start_offset", Int64), + ( + "aborted_transactions", + Array(("producer_id", Int64), ("first_offset", Int64)), + ), + ("message_set", Bytes), + ), + ), + ), + ), + ) + + +class FetchResponse_v8(Response): + API_KEY = 1 + API_VERSION = 8 + SCHEMA = FetchResponse_v7.SCHEMA + + +class FetchResponse_v9(Response): + API_KEY = 1 + API_VERSION = 9 + SCHEMA = FetchResponse_v7.SCHEMA + + +class FetchResponse_v10(Response): + API_KEY = 1 + API_VERSION = 10 + SCHEMA = FetchResponse_v7.SCHEMA + + +class FetchResponse_v11(Response): + API_KEY = 1 + API_VERSION = 11 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ("error_code", Int16), + ("session_id", Int32), + ( + "topics", + Array( + ("topics", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("error_code", Int16), + ("highwater_offset", Int64), + ("last_stable_offset", Int64), + ("log_start_offset", Int64), + ( + "aborted_transactions", + Array(("producer_id", Int64), ("first_offset", Int64)), + ), + ("preferred_read_replica", Int32), + ("message_set", Bytes), + ), + ), + ), + ), + ) + + +class FetchRequest_v0(Request): + API_KEY = 1 + API_VERSION = 0 + RESPONSE_TYPE = FetchResponse_v0 + SCHEMA = Schema( + ("replica_id", Int32), + ("max_wait_time", Int32), + ("min_bytes", Int32), + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), ("offset", Int64), ("max_bytes", Int32) + ), + ), + ), + ), + ) + + +class FetchRequest_v1(Request): + API_KEY = 1 + API_VERSION = 1 + RESPONSE_TYPE = FetchResponse_v1 + SCHEMA = FetchRequest_v0.SCHEMA + + +class FetchRequest_v2(Request): + API_KEY = 1 + API_VERSION = 2 + RESPONSE_TYPE = FetchResponse_v2 + SCHEMA = FetchRequest_v1.SCHEMA + + +class FetchRequest_v3(Request): + API_KEY = 1 + API_VERSION = 3 + RESPONSE_TYPE = FetchResponse_v3 + SCHEMA = Schema( + ("replica_id", Int32), + ("max_wait_time", Int32), + ("min_bytes", Int32), + ("max_bytes", Int32), # This new field is only difference from FR_v2 + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), ("offset", Int64), ("max_bytes", Int32) + ), + ), + ), + ), + ) + + +class FetchRequest_v4(Request): + # Adds isolation_level field + API_KEY = 1 + API_VERSION = 4 + RESPONSE_TYPE = FetchResponse_v4 + SCHEMA = Schema( + ("replica_id", Int32), + ("max_wait_time", Int32), + ("min_bytes", Int32), + ("max_bytes", Int32), + ("isolation_level", Int8), + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), ("offset", Int64), ("max_bytes", Int32) + ), + ), + ), + ), + ) + + +class FetchRequest_v5(Request): + # This may only be used in broker-broker api calls + API_KEY = 1 + API_VERSION = 5 + RESPONSE_TYPE = FetchResponse_v5 + SCHEMA = Schema( + ("replica_id", Int32), + ("max_wait_time", Int32), + ("min_bytes", Int32), + ("max_bytes", Int32), + ("isolation_level", Int8), + ( + 
"topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("fetch_offset", Int64), + ("log_start_offset", Int64), + ("max_bytes", Int32), + ), + ), + ), + ), + ) + + +class FetchRequest_v6(Request): + """ + The body of FETCH_REQUEST_V6 is the same as FETCH_REQUEST_V5. The version number is + bumped up to indicate that the client supports KafkaStorageException. The + KafkaStorageException will be translated to NotLeaderForPartitionException in the + response if version <= 5 + """ + + API_KEY = 1 + API_VERSION = 6 + RESPONSE_TYPE = FetchResponse_v6 + SCHEMA = FetchRequest_v5.SCHEMA + + +class FetchRequest_v7(Request): + """ + Add incremental fetch requests + """ + + API_KEY = 1 + API_VERSION = 7 + RESPONSE_TYPE = FetchResponse_v7 + SCHEMA = Schema( + ("replica_id", Int32), + ("max_wait_time", Int32), + ("min_bytes", Int32), + ("max_bytes", Int32), + ("isolation_level", Int8), + ("session_id", Int32), + ("session_epoch", Int32), + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("fetch_offset", Int64), + ("log_start_offset", Int64), + ("max_bytes", Int32), + ), + ), + ), + ), + ( + "forgotten_topics_data", + Array(("topic", String), ("partitions", Array(Int32))), + ), + ) + + +class FetchRequest_v8(Request): + """ + bump used to indicate that on quota violation brokers send out responses before + throttling. + """ + + API_KEY = 1 + API_VERSION = 8 + RESPONSE_TYPE = FetchResponse_v8 + SCHEMA = FetchRequest_v7.SCHEMA + + +class FetchRequest_v9(Request): + """ + adds the current leader epoch (see KIP-320) + """ + + API_KEY = 1 + API_VERSION = 9 + RESPONSE_TYPE = FetchResponse_v9 + SCHEMA = Schema( + ("replica_id", Int32), + ("max_wait_time", Int32), + ("min_bytes", Int32), + ("max_bytes", Int32), + ("isolation_level", Int8), + ("session_id", Int32), + ("session_epoch", Int32), + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("current_leader_epoch", Int32), + ("fetch_offset", Int64), + ("log_start_offset", Int64), + ("max_bytes", Int32), + ), + ), + ), + ), + ( + "forgotten_topics_data", + Array( + ("topic", String), + ("partitions", Array(Int32)), + ), + ), + ) + + +class FetchRequest_v10(Request): + """ + bumped up to indicate ZStandard capability. 
(see KIP-110) + """ + + API_KEY = 1 + API_VERSION = 10 + RESPONSE_TYPE = FetchResponse_v10 + SCHEMA = FetchRequest_v9.SCHEMA + + +class FetchRequest_v11(Request): + """ + added rack ID to support read from followers (KIP-392) + """ + + API_KEY = 1 + API_VERSION = 11 + RESPONSE_TYPE = FetchResponse_v11 + SCHEMA = Schema( + ("replica_id", Int32), + ("max_wait_time", Int32), + ("min_bytes", Int32), + ("max_bytes", Int32), + ("isolation_level", Int8), + ("session_id", Int32), + ("session_epoch", Int32), + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("current_leader_epoch", Int32), + ("fetch_offset", Int64), + ("log_start_offset", Int64), + ("max_bytes", Int32), + ), + ), + ), + ), + ( + "forgotten_topics_data", + Array(("topic", String), ("partitions", Array(Int32))), + ), + ("rack_id", String("utf-8")), + ) + + +FetchRequest = [ + FetchRequest_v0, + FetchRequest_v1, + FetchRequest_v2, + FetchRequest_v3, + FetchRequest_v4, + FetchRequest_v5, + FetchRequest_v6, + FetchRequest_v7, + FetchRequest_v8, + FetchRequest_v9, + FetchRequest_v10, + FetchRequest_v11, +] +FetchResponse = [ + FetchResponse_v0, + FetchResponse_v1, + FetchResponse_v2, + FetchResponse_v3, + FetchResponse_v4, + FetchResponse_v5, + FetchResponse_v6, + FetchResponse_v7, + FetchResponse_v8, + FetchResponse_v9, + FetchResponse_v10, + FetchResponse_v11, +] diff --git a/kafka/protocol/frame.py b/aiokafka/protocol/frame.py similarity index 94% rename from kafka/protocol/frame.py rename to aiokafka/protocol/frame.py index 7b4a32bc..897e091b 100644 --- a/kafka/protocol/frame.py +++ b/aiokafka/protocol/frame.py @@ -24,7 +24,7 @@ def tell(self): return self._idx def __str__(self): - return 'KafkaBytes(%d)' % len(self) + return "KafkaBytes(%d)" % len(self) def __repr__(self): return str(self) diff --git a/aiokafka/protocol/group.py b/aiokafka/protocol/group.py new file mode 100644 index 00000000..a809738a --- /dev/null +++ b/aiokafka/protocol/group.py @@ -0,0 +1,203 @@ +from .api import Request, Response +from .struct import Struct +from .types import Array, Bytes, Int16, Int32, Schema, String + + +class JoinGroupResponse_v0(Response): + API_KEY = 11 + API_VERSION = 0 + SCHEMA = Schema( + ("error_code", Int16), + ("generation_id", Int32), + ("group_protocol", String("utf-8")), + ("leader_id", String("utf-8")), + ("member_id", String("utf-8")), + ("members", Array(("member_id", String("utf-8")), ("member_metadata", Bytes))), + ) + + +class JoinGroupResponse_v1(Response): + API_KEY = 11 + API_VERSION = 1 + SCHEMA = JoinGroupResponse_v0.SCHEMA + + +class JoinGroupResponse_v2(Response): + API_KEY = 11 + API_VERSION = 2 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ("error_code", Int16), + ("generation_id", Int32), + ("group_protocol", String("utf-8")), + ("leader_id", String("utf-8")), + ("member_id", String("utf-8")), + ("members", Array(("member_id", String("utf-8")), ("member_metadata", Bytes))), + ) + + +class JoinGroupRequest_v0(Request): + API_KEY = 11 + API_VERSION = 0 + RESPONSE_TYPE = JoinGroupResponse_v0 + SCHEMA = Schema( + ("group", String("utf-8")), + ("session_timeout", Int32), + ("member_id", String("utf-8")), + ("protocol_type", String("utf-8")), + ( + "group_protocols", + Array(("protocol_name", String("utf-8")), ("protocol_metadata", Bytes)), + ), + ) + UNKNOWN_MEMBER_ID = "" + + +class JoinGroupRequest_v1(Request): + API_KEY = 11 + API_VERSION = 1 + RESPONSE_TYPE = JoinGroupResponse_v1 + SCHEMA = Schema( + ("group", String("utf-8")), + ("session_timeout", 
Int32), + ("rebalance_timeout", Int32), + ("member_id", String("utf-8")), + ("protocol_type", String("utf-8")), + ( + "group_protocols", + Array(("protocol_name", String("utf-8")), ("protocol_metadata", Bytes)), + ), + ) + UNKNOWN_MEMBER_ID = "" + + +class JoinGroupRequest_v2(Request): + API_KEY = 11 + API_VERSION = 2 + RESPONSE_TYPE = JoinGroupResponse_v2 + SCHEMA = JoinGroupRequest_v1.SCHEMA + UNKNOWN_MEMBER_ID = "" + + +JoinGroupRequest = [JoinGroupRequest_v0, JoinGroupRequest_v1, JoinGroupRequest_v2] +JoinGroupResponse = [JoinGroupResponse_v0, JoinGroupResponse_v1, JoinGroupResponse_v2] + + +class ProtocolMetadata(Struct): + SCHEMA = Schema( + ("version", Int16), + ("subscription", Array(String("utf-8"))), # topics list + ("user_data", Bytes), + ) + + +class SyncGroupResponse_v0(Response): + API_KEY = 14 + API_VERSION = 0 + SCHEMA = Schema(("error_code", Int16), ("member_assignment", Bytes)) + + +class SyncGroupResponse_v1(Response): + API_KEY = 14 + API_VERSION = 1 + SCHEMA = Schema( + ("throttle_time_ms", Int32), ("error_code", Int16), ("member_assignment", Bytes) + ) + + +class SyncGroupRequest_v0(Request): + API_KEY = 14 + API_VERSION = 0 + RESPONSE_TYPE = SyncGroupResponse_v0 + SCHEMA = Schema( + ("group", String("utf-8")), + ("generation_id", Int32), + ("member_id", String("utf-8")), + ( + "group_assignment", + Array(("member_id", String("utf-8")), ("member_metadata", Bytes)), + ), + ) + + +class SyncGroupRequest_v1(Request): + API_KEY = 14 + API_VERSION = 1 + RESPONSE_TYPE = SyncGroupResponse_v1 + SCHEMA = SyncGroupRequest_v0.SCHEMA + + +SyncGroupRequest = [SyncGroupRequest_v0, SyncGroupRequest_v1] +SyncGroupResponse = [SyncGroupResponse_v0, SyncGroupResponse_v1] + + +class MemberAssignment(Struct): + SCHEMA = Schema( + ("version", Int16), + ("assignment", Array(("topic", String("utf-8")), ("partitions", Array(Int32)))), + ("user_data", Bytes), + ) + + +class HeartbeatResponse_v0(Response): + API_KEY = 12 + API_VERSION = 0 + SCHEMA = Schema(("error_code", Int16)) + + +class HeartbeatResponse_v1(Response): + API_KEY = 12 + API_VERSION = 1 + SCHEMA = Schema(("throttle_time_ms", Int32), ("error_code", Int16)) + + +class HeartbeatRequest_v0(Request): + API_KEY = 12 + API_VERSION = 0 + RESPONSE_TYPE = HeartbeatResponse_v0 + SCHEMA = Schema( + ("group", String("utf-8")), + ("generation_id", Int32), + ("member_id", String("utf-8")), + ) + + +class HeartbeatRequest_v1(Request): + API_KEY = 12 + API_VERSION = 1 + RESPONSE_TYPE = HeartbeatResponse_v1 + SCHEMA = HeartbeatRequest_v0.SCHEMA + + +HeartbeatRequest = [HeartbeatRequest_v0, HeartbeatRequest_v1] +HeartbeatResponse = [HeartbeatResponse_v0, HeartbeatResponse_v1] + + +class LeaveGroupResponse_v0(Response): + API_KEY = 13 + API_VERSION = 0 + SCHEMA = Schema(("error_code", Int16)) + + +class LeaveGroupResponse_v1(Response): + API_KEY = 13 + API_VERSION = 1 + SCHEMA = Schema(("throttle_time_ms", Int32), ("error_code", Int16)) + + +class LeaveGroupRequest_v0(Request): + API_KEY = 13 + API_VERSION = 0 + RESPONSE_TYPE = LeaveGroupResponse_v0 + SCHEMA = Schema(("group", String("utf-8")), ("member_id", String("utf-8"))) + + +class LeaveGroupRequest_v1(Request): + API_KEY = 13 + API_VERSION = 1 + RESPONSE_TYPE = LeaveGroupResponse_v1 + SCHEMA = LeaveGroupRequest_v0.SCHEMA + + +LeaveGroupRequest = [LeaveGroupRequest_v0, LeaveGroupRequest_v1] +LeaveGroupResponse = [LeaveGroupResponse_v0, LeaveGroupResponse_v1] diff --git a/kafka/protocol/message.py b/aiokafka/protocol/message.py similarity index 73% rename from kafka/protocol/message.py rename 
to aiokafka/protocol/message.py index 4c5c031b..d187b9bc 100644 --- a/kafka/protocol/message.py +++ b/aiokafka/protocol/message.py @@ -1,34 +1,40 @@ -from __future__ import absolute_import - import io import time -from kafka.codec import (has_gzip, has_snappy, has_lz4, has_zstd, - gzip_decode, snappy_decode, zstd_decode, - lz4_decode, lz4_decode_old_kafka) -from kafka.protocol.frame import KafkaBytes -from kafka.protocol.struct import Struct -from kafka.protocol.types import ( - Int8, Int32, Int64, Bytes, Schema, AbstractType +from kafka.codec import ( + has_gzip, + has_snappy, + has_lz4, + has_zstd, + gzip_decode, + snappy_decode, + zstd_decode, + lz4_decode, + lz4_decode_old_kafka, ) +from .frame import KafkaBytes +from .struct import Struct +from .types import Int8, Int32, Int64, Bytes, Schema, AbstractType from kafka.util import crc32, WeakMethod class Message(Struct): SCHEMAS = [ Schema( - ('crc', Int32), - ('magic', Int8), - ('attributes', Int8), - ('key', Bytes), - ('value', Bytes)), + ("crc", Int32), + ("magic", Int8), + ("attributes", Int8), + ("key", Bytes), + ("value", Bytes), + ), Schema( - ('crc', Int32), - ('magic', Int8), - ('attributes', Int8), - ('timestamp', Int64), - ('key', Bytes), - ('value', Bytes)), + ("crc", Int32), + ("magic", Int8), + ("attributes", Int8), + ("timestamp", Int64), + ("key", Bytes), + ("value", Bytes), + ), ] SCHEMA = SCHEMAS[1] CODEC_MASK = 0x07 @@ -37,13 +43,14 @@ class Message(Struct): CODEC_LZ4 = 0x03 CODEC_ZSTD = 0x04 TIMESTAMP_TYPE_MASK = 0x08 - HEADER_SIZE = 22 # crc(4), magic(1), attributes(1), timestamp(8), key+value size(4*2) + HEADER_SIZE = ( + 22 # crc(4), magic(1), attributes(1), timestamp(8), key+value size(4*2) + ) - def __init__(self, value, key=None, magic=0, attributes=0, crc=0, - timestamp=None): - assert value is None or isinstance(value, bytes), 'value must be bytes' - assert key is None or isinstance(key, bytes), 'key must be bytes' - assert magic > 0 or timestamp is None, 'timestamp not supported in v0' + def __init__(self, value, key=None, magic=0, attributes=0, crc=0, timestamp=None): + assert value is None or isinstance(value, bytes), "value must be bytes" + assert key is None or isinstance(key, bytes), "key must be bytes" + assert magic > 0 or timestamp is None, "timestamp not supported in v0" # Default timestamp to now for v1 messages if magic > 0 and timestamp is None: @@ -74,11 +81,18 @@ def timestamp_type(self): def _encode_self(self, recalc_crc=True): version = self.magic if version == 1: - fields = (self.crc, self.magic, self.attributes, self.timestamp, self.key, self.value) + fields = ( + self.crc, + self.magic, + self.attributes, + self.timestamp, + self.key, + self.value, + ) elif version == 0: fields = (self.crc, self.magic, self.attributes, self.key, self.value) else: - raise ValueError('Unrecognized message version: %s' % (version,)) + raise ValueError("Unrecognized message version: %s" % (version,)) message = Message.SCHEMAS[version].encode(fields) if not recalc_crc: return message @@ -101,9 +115,14 @@ def decode(cls, data): timestamp = fields[0] else: timestamp = None - msg = cls(fields[-1], key=fields[-2], - magic=magic, attributes=attributes, crc=crc, - timestamp=timestamp) + msg = cls( + fields[-1], + key=fields[-2], + magic=magic, + attributes=attributes, + crc=crc, + timestamp=timestamp, + ) msg._validated_crc = _validated_crc return msg @@ -120,15 +139,20 @@ def is_compressed(self): def decompress(self): codec = self.attributes & self.CODEC_MASK - assert codec in (self.CODEC_GZIP, self.CODEC_SNAPPY, 
self.CODEC_LZ4, self.CODEC_ZSTD) + assert codec in ( + self.CODEC_GZIP, + self.CODEC_SNAPPY, + self.CODEC_LZ4, + self.CODEC_ZSTD, + ) if codec == self.CODEC_GZIP: - assert has_gzip(), 'Gzip decompression unsupported' + assert has_gzip(), "Gzip decompression unsupported" raw_bytes = gzip_decode(self.value) elif codec == self.CODEC_SNAPPY: - assert has_snappy(), 'Snappy decompression unsupported' + assert has_snappy(), "Snappy decompression unsupported" raw_bytes = snappy_decode(self.value) elif codec == self.CODEC_LZ4: - assert has_lz4(), 'LZ4 decompression unsupported' + assert has_lz4(), "LZ4 decompression unsupported" if self.magic == 0: raw_bytes = lz4_decode_old_kafka(self.value) else: @@ -137,7 +161,7 @@ def decompress(self): assert has_zstd(), "ZSTD decompression unsupported" raw_bytes = zstd_decode(self.value) else: - raise Exception('This should be impossible') + raise Exception("This should be impossible") return MessageSet.decode(raw_bytes, bytes_to_read=len(raw_bytes)) @@ -147,14 +171,11 @@ def __hash__(self): class PartialMessage(bytes): def __repr__(self): - return 'PartialMessage(%s)' % (self,) + return "PartialMessage(%s)" % (self,) class MessageSet(AbstractType): - ITEM = Schema( - ('offset', Int64), - ('message', Bytes) - ) + ITEM = Schema(("offset", Int64), ("message", Bytes)) HEADER_SIZE = 12 # offset + message_size @classmethod @@ -172,7 +193,7 @@ def encode(cls, items, prepend_size=True): for (offset, message) in items: encoded_values.append(Int64.encode(offset)) encoded_values.append(Bytes.encode(message)) - encoded = b''.join(encoded_values) + encoded = b"".join(encoded_values) if prepend_size: return Bytes.encode(encoded) else: diff --git a/aiokafka/protocol/metadata.py b/aiokafka/protocol/metadata.py new file mode 100644 index 00000000..79a5600a --- /dev/null +++ b/aiokafka/protocol/metadata.py @@ -0,0 +1,260 @@ +from .api import Request, Response +from .types import Array, Boolean, Int16, Int32, Schema, String + + +class MetadataResponse_v0(Response): + API_KEY = 3 + API_VERSION = 0 + SCHEMA = Schema( + ( + "brokers", + Array(("node_id", Int32), ("host", String("utf-8")), ("port", Int32)), + ), + ( + "topics", + Array( + ("error_code", Int16), + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("error_code", Int16), + ("partition", Int32), + ("leader", Int32), + ("replicas", Array(Int32)), + ("isr", Array(Int32)), + ), + ), + ), + ), + ) + + +class MetadataResponse_v1(Response): + API_KEY = 3 + API_VERSION = 1 + SCHEMA = Schema( + ( + "brokers", + Array( + ("node_id", Int32), + ("host", String("utf-8")), + ("port", Int32), + ("rack", String("utf-8")), + ), + ), + ("controller_id", Int32), + ( + "topics", + Array( + ("error_code", Int16), + ("topic", String("utf-8")), + ("is_internal", Boolean), + ( + "partitions", + Array( + ("error_code", Int16), + ("partition", Int32), + ("leader", Int32), + ("replicas", Array(Int32)), + ("isr", Array(Int32)), + ), + ), + ), + ), + ) + + +class MetadataResponse_v2(Response): + API_KEY = 3 + API_VERSION = 2 + SCHEMA = Schema( + ( + "brokers", + Array( + ("node_id", Int32), + ("host", String("utf-8")), + ("port", Int32), + ("rack", String("utf-8")), + ), + ), + ("cluster_id", String("utf-8")), # <-- Added cluster_id field in v2 + ("controller_id", Int32), + ( + "topics", + Array( + ("error_code", Int16), + ("topic", String("utf-8")), + ("is_internal", Boolean), + ( + "partitions", + Array( + ("error_code", Int16), + ("partition", Int32), + ("leader", Int32), + ("replicas", Array(Int32)), + ("isr", Array(Int32)), + ), + 
), + ), + ), + ) + + +class MetadataResponse_v3(Response): + API_KEY = 3 + API_VERSION = 3 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "brokers", + Array( + ("node_id", Int32), + ("host", String("utf-8")), + ("port", Int32), + ("rack", String("utf-8")), + ), + ), + ("cluster_id", String("utf-8")), + ("controller_id", Int32), + ( + "topics", + Array( + ("error_code", Int16), + ("topic", String("utf-8")), + ("is_internal", Boolean), + ( + "partitions", + Array( + ("error_code", Int16), + ("partition", Int32), + ("leader", Int32), + ("replicas", Array(Int32)), + ("isr", Array(Int32)), + ), + ), + ), + ), + ) + + +class MetadataResponse_v4(Response): + API_KEY = 3 + API_VERSION = 4 + SCHEMA = MetadataResponse_v3.SCHEMA + + +class MetadataResponse_v5(Response): + API_KEY = 3 + API_VERSION = 5 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "brokers", + Array( + ("node_id", Int32), + ("host", String("utf-8")), + ("port", Int32), + ("rack", String("utf-8")), + ), + ), + ("cluster_id", String("utf-8")), + ("controller_id", Int32), + ( + "topics", + Array( + ("error_code", Int16), + ("topic", String("utf-8")), + ("is_internal", Boolean), + ( + "partitions", + Array( + ("error_code", Int16), + ("partition", Int32), + ("leader", Int32), + ("replicas", Array(Int32)), + ("isr", Array(Int32)), + ("offline_replicas", Array(Int32)), + ), + ), + ), + ), + ) + + +class MetadataRequest_v0(Request): + API_KEY = 3 + API_VERSION = 0 + RESPONSE_TYPE = MetadataResponse_v0 + SCHEMA = Schema(("topics", Array(String("utf-8")))) + ALL_TOPICS = None # Empty Array (len 0) for topics returns all topics + + +class MetadataRequest_v1(Request): + API_KEY = 3 + API_VERSION = 1 + RESPONSE_TYPE = MetadataResponse_v1 + SCHEMA = MetadataRequest_v0.SCHEMA + ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS = None # Empty array (len 0) for topics returns no topics + + +class MetadataRequest_v2(Request): + API_KEY = 3 + API_VERSION = 2 + RESPONSE_TYPE = MetadataResponse_v2 + SCHEMA = MetadataRequest_v1.SCHEMA + ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS = None # Empty array (len 0) for topics returns no topics + + +class MetadataRequest_v3(Request): + API_KEY = 3 + API_VERSION = 3 + RESPONSE_TYPE = MetadataResponse_v3 + SCHEMA = MetadataRequest_v1.SCHEMA + ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS = None # Empty array (len 0) for topics returns no topics + + +class MetadataRequest_v4(Request): + API_KEY = 3 + API_VERSION = 4 + RESPONSE_TYPE = MetadataResponse_v4 + SCHEMA = Schema( + ("topics", Array(String("utf-8"))), ("allow_auto_topic_creation", Boolean) + ) + ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS = None # Empty array (len 0) for topics returns no topics + + +class MetadataRequest_v5(Request): + """ + The v5 metadata request is the same as v4. 
+ An additional field for offline_replicas has been added to the v5 metadata response + """ + + API_KEY = 3 + API_VERSION = 5 + RESPONSE_TYPE = MetadataResponse_v5 + SCHEMA = MetadataRequest_v4.SCHEMA + ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics + NO_TOPICS = None # Empty array (len 0) for topics returns no topics + + +MetadataRequest = [ + MetadataRequest_v0, + MetadataRequest_v1, + MetadataRequest_v2, + MetadataRequest_v3, + MetadataRequest_v4, + MetadataRequest_v5, +] +MetadataResponse = [ + MetadataResponse_v0, + MetadataResponse_v1, + MetadataResponse_v2, + MetadataResponse_v3, + MetadataResponse_v4, + MetadataResponse_v5, +] diff --git a/aiokafka/protocol/offset.py b/aiokafka/protocol/offset.py new file mode 100644 index 00000000..11eef1e8 --- /dev/null +++ b/aiokafka/protocol/offset.py @@ -0,0 +1,246 @@ +from .api import Request, Response +from .types import Array, Int8, Int16, Int32, Int64, Schema, String + +UNKNOWN_OFFSET = -1 + + +class OffsetResetStrategy(object): + LATEST = -1 + EARLIEST = -2 + NONE = 0 + + +class OffsetResponse_v0(Response): + API_KEY = 2 + API_VERSION = 0 + SCHEMA = Schema( + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("error_code", Int16), + ("offsets", Array(Int64)), + ), + ), + ), + ) + ) + + +class OffsetResponse_v1(Response): + API_KEY = 2 + API_VERSION = 1 + SCHEMA = Schema( + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("error_code", Int16), + ("timestamp", Int64), + ("offset", Int64), + ), + ), + ), + ) + ) + + +class OffsetResponse_v2(Response): + API_KEY = 2 + API_VERSION = 2 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("error_code", Int16), + ("timestamp", Int64), + ("offset", Int64), + ), + ), + ), + ), + ) + + +class OffsetResponse_v3(Response): + """ + on quota violation, brokers send out responses before throttling + """ + + API_KEY = 2 + API_VERSION = 3 + SCHEMA = OffsetResponse_v2.SCHEMA + + +class OffsetResponse_v4(Response): + """ + Add leader_epoch to response + """ + + API_KEY = 2 + API_VERSION = 4 + SCHEMA = Schema( + ("throttle_time_ms", Int32), + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("error_code", Int16), + ("timestamp", Int64), + ("offset", Int64), + ("leader_epoch", Int32), + ), + ), + ), + ), + ) + + +class OffsetResponse_v5(Response): + """ + adds a new error code, OFFSET_NOT_AVAILABLE + """ + + API_KEY = 2 + API_VERSION = 5 + SCHEMA = OffsetResponse_v4.SCHEMA + + +class OffsetRequest_v0(Request): + API_KEY = 2 + API_VERSION = 0 + RESPONSE_TYPE = OffsetResponse_v0 + SCHEMA = Schema( + ("replica_id", Int32), + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("timestamp", Int64), + ("max_offsets", Int32), + ), + ), + ), + ), + ) + DEFAULTS = {"replica_id": -1} + + +class OffsetRequest_v1(Request): + API_KEY = 2 + API_VERSION = 1 + RESPONSE_TYPE = OffsetResponse_v1 + SCHEMA = Schema( + ("replica_id", Int32), + ( + "topics", + Array( + ("topic", String("utf-8")), + ("partitions", Array(("partition", Int32), ("timestamp", Int64))), + ), + ), + ) + DEFAULTS = {"replica_id": -1} + + +class OffsetRequest_v2(Request): + API_KEY = 2 + API_VERSION = 2 + RESPONSE_TYPE = OffsetResponse_v2 + SCHEMA = Schema( + ("replica_id", Int32), + 
("isolation_level", Int8), # <- added isolation_level + ( + "topics", + Array( + ("topic", String("utf-8")), + ("partitions", Array(("partition", Int32), ("timestamp", Int64))), + ), + ), + ) + DEFAULTS = {"replica_id": -1} + + +class OffsetRequest_v3(Request): + API_KEY = 2 + API_VERSION = 3 + RESPONSE_TYPE = OffsetResponse_v3 + SCHEMA = OffsetRequest_v2.SCHEMA + DEFAULTS = {"replica_id": -1} + + +class OffsetRequest_v4(Request): + """ + Add current_leader_epoch to request + """ + + API_KEY = 2 + API_VERSION = 4 + RESPONSE_TYPE = OffsetResponse_v4 + SCHEMA = Schema( + ("replica_id", Int32), + ("isolation_level", Int8), # <- added isolation_level + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("current_leader_epoch", Int64), + ("timestamp", Int64), + ), + ), + ), + ), + ) + DEFAULTS = {"replica_id": -1} + + +class OffsetRequest_v5(Request): + API_KEY = 2 + API_VERSION = 5 + RESPONSE_TYPE = OffsetResponse_v5 + SCHEMA = OffsetRequest_v4.SCHEMA + DEFAULTS = {"replica_id": -1} + + +OffsetRequest = [ + OffsetRequest_v0, + OffsetRequest_v1, + OffsetRequest_v2, + OffsetRequest_v3, + OffsetRequest_v4, + OffsetRequest_v5, +] +OffsetResponse = [ + OffsetResponse_v0, + OffsetResponse_v1, + OffsetResponse_v2, + OffsetResponse_v3, + OffsetResponse_v4, + OffsetResponse_v5, +] diff --git a/kafka/protocol/parser.py b/aiokafka/protocol/parser.py similarity index 69% rename from kafka/protocol/parser.py rename to aiokafka/protocol/parser.py index a872202d..c19dc4f1 100644 --- a/kafka/protocol/parser.py +++ b/aiokafka/protocol/parser.py @@ -1,13 +1,12 @@ -from __future__ import absolute_import - import collections import logging import aiokafka.errors as Errors -from kafka.protocol.commit import GroupCoordinatorResponse -from kafka.protocol.frame import KafkaBytes -from kafka.protocol.types import Int32, TaggedFields -from kafka.version import __version__ +from aiokafka import __version__ + +from .commit import GroupCoordinatorResponse +from .frame import KafkaBytes +from .types import Int32 log = logging.getLogger(__name__) @@ -24,6 +23,7 @@ class KafkaProtocol(object): Currently only used to check for 0.8.2 protocol quirks, but may be used for more in the future. """ + def __init__(self, client_id=None, api_version=None): if client_id is None: client_id = self._gen_client_id() @@ -41,7 +41,7 @@ def _next_correlation_id(self): return self._correlation_id def _gen_client_id(self): - return 'kafka-python' + __version__ + return "aiokafka" + __version__ def send_request(self, request, correlation_id=None): """Encode and queue a kafka api request for sending. 
@@ -55,12 +55,14 @@ def send_request(self, request, correlation_id=None): Returns: correlation_id """ - log.debug('Sending request %s', request) + log.debug("Sending request %s", request) if correlation_id is None: correlation_id = self._next_correlation_id() - header = request.build_request_header(correlation_id=correlation_id, client_id=self._client_id) - message = b''.join([header.encode(), request.encode()]) + header = request.build_request_header( + correlation_id=correlation_id, client_id=self._client_id + ) + message = b"".join([header.encode(), request.encode()]) size = Int32.encode(len(message)) data = size + message self.bytes_to_send.append(data) @@ -71,7 +73,7 @@ def send_request(self, request, correlation_id=None): def send_bytes(self): """Retrieve all pending bytes to send on the network""" - data = b''.join(self.bytes_to_send) + data = b"".join(self.bytes_to_send) self.bytes_to_send = [] return data @@ -99,7 +101,7 @@ def receive_bytes(self, data): # Not receiving is the state of reading the payload header if not self._receiving: bytes_to_read = min(4 - self._header.tell(), n - i) - self._header.write(data[i:i+bytes_to_read]) + self._header.write(data[i:i + bytes_to_read]) i += bytes_to_read if self._header.tell() == 4: @@ -109,18 +111,22 @@ def receive_bytes(self, data): self._rbuffer = KafkaBytes(nbytes) self._receiving = True elif self._header.tell() > 4: - raise Errors.KafkaError('this should not happen - are you threading?') + raise Errors.KafkaError( + "this should not happen - are you threading?" + ) if self._receiving: total_bytes = len(self._rbuffer) staged_bytes = self._rbuffer.tell() bytes_to_read = min(total_bytes - staged_bytes, n - i) - self._rbuffer.write(data[i:i+bytes_to_read]) + self._rbuffer.write(data[i:i + bytes_to_read]) i += bytes_to_read staged_bytes = self._rbuffer.tell() if staged_bytes > total_bytes: - raise Errors.KafkaError('Receive buffer has more bytes than expected?') + raise Errors.KafkaError( + "Receive buffer has more bytes than expected?" + ) if staged_bytes != total_bytes: break @@ -134,39 +140,51 @@ def receive_bytes(self, data): def _process_response(self, read_buffer): if not self.in_flight_requests: - raise Errors.CorrelationIdError('No in-flight-request found for server response') + raise Errors.CorrelationIdError( + "No in-flight-request found for server response" + ) (correlation_id, request) = self.in_flight_requests.popleft() response_header = request.parse_response_header(read_buffer) recv_correlation_id = response_header.correlation_id - log.debug('Received correlation id: %d', recv_correlation_id) + log.debug("Received correlation id: %d", recv_correlation_id) # 0.8.2 quirk - if (recv_correlation_id == 0 and - correlation_id != 0 and - request.RESPONSE_TYPE is GroupCoordinatorResponse[0] and - (self._api_version == (0, 8, 2) or self._api_version is None)): - log.warning('Kafka 0.8.2 quirk -- GroupCoordinatorResponse' - ' Correlation ID does not match request. This' - ' should go away once at least one topic has been' - ' initialized on the broker.') + if ( + recv_correlation_id == 0 + and correlation_id != 0 + and request.RESPONSE_TYPE is GroupCoordinatorResponse[0] + and (self._api_version == (0, 8, 2) or self._api_version is None) + ): + log.warning( + "Kafka 0.8.2 quirk -- GroupCoordinatorResponse" + " Correlation ID does not match request. This" + " should go away once at least one topic has been" + " initialized on the broker." + ) elif correlation_id != recv_correlation_id: # return or raise? 
raise Errors.CorrelationIdError( - 'Correlation IDs do not match: sent %d, recv %d' - % (correlation_id, recv_correlation_id)) + "Correlation IDs do not match: sent %d, recv %d" + % (correlation_id, recv_correlation_id) + ) # decode response - log.debug('Processing response %s', request.RESPONSE_TYPE.__name__) + log.debug("Processing response %s", request.RESPONSE_TYPE.__name__) try: response = request.RESPONSE_TYPE.decode(read_buffer) except ValueError: read_buffer.seek(0) buf = read_buffer.read() - log.error('Response %d [ResponseType: %s Request: %s]:' - ' Unable to decode %d-byte buffer: %r', - correlation_id, request.RESPONSE_TYPE, - request, len(buf), buf) - raise Errors.KafkaProtocolError('Unable to decode response') + log.error( + "Response %d [ResponseType: %s Request: %s]:" + " Unable to decode %d-byte buffer: %r", + correlation_id, + request.RESPONSE_TYPE, + request, + len(buf), + buf, + ) + raise Errors.KafkaProtocolError("Unable to decode response") return (correlation_id, response) diff --git a/kafka/protocol/pickle.py b/aiokafka/protocol/pickle.py similarity index 80% rename from kafka/protocol/pickle.py rename to aiokafka/protocol/pickle.py index d6e5fa74..780c4e88 100644 --- a/kafka/protocol/pickle.py +++ b/aiokafka/protocol/pickle.py @@ -1,10 +1,4 @@ -from __future__ import absolute_import - -try: - import copyreg # pylint: disable=import-error -except ImportError: - import copy_reg as copyreg # pylint: disable=import-error - +import copyreg import types @@ -31,5 +25,6 @@ def _unpickle_method(func_name, obj, cls): break return func.__get__(obj, cls) + # https://bytes.com/topic/python/answers/552476-why-cant-you-pickle-instancemethods copyreg.pickle(types.MethodType, _pickle_method, _unpickle_method) diff --git a/aiokafka/protocol/produce.py b/aiokafka/protocol/produce.py new file mode 100644 index 00000000..4083b941 --- /dev/null +++ b/aiokafka/protocol/produce.py @@ -0,0 +1,299 @@ +from .api import Request, Response +from .types import Int16, Int32, Int64, String, Array, Schema, Bytes + + +class ProduceResponse_v0(Response): + API_KEY = 0 + API_VERSION = 0 + SCHEMA = Schema( + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), ("error_code", Int16), ("offset", Int64) + ), + ), + ), + ) + ) + + +class ProduceResponse_v1(Response): + API_KEY = 0 + API_VERSION = 1 + SCHEMA = Schema( + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), ("error_code", Int16), ("offset", Int64) + ), + ), + ), + ), + ("throttle_time_ms", Int32), + ) + + +class ProduceResponse_v2(Response): + API_KEY = 0 + API_VERSION = 2 + SCHEMA = Schema( + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("error_code", Int16), + ("offset", Int64), + ("timestamp", Int64), + ), + ), + ), + ), + ("throttle_time_ms", Int32), + ) + + +class ProduceResponse_v3(Response): + API_KEY = 0 + API_VERSION = 3 + SCHEMA = ProduceResponse_v2.SCHEMA + + +class ProduceResponse_v4(Response): + """ + The version number is bumped up to indicate that the client supports + KafkaStorageException. 
The KafkaStorageException will be translated to + NotLeaderForPartitionException in the response if version <= 3 + """ + + API_KEY = 0 + API_VERSION = 4 + SCHEMA = ProduceResponse_v3.SCHEMA + + +class ProduceResponse_v5(Response): + API_KEY = 0 + API_VERSION = 5 + SCHEMA = Schema( + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("error_code", Int16), + ("offset", Int64), + ("timestamp", Int64), + ("log_start_offset", Int64), + ), + ), + ), + ), + ("throttle_time_ms", Int32), + ) + + +class ProduceResponse_v6(Response): + """ + The version number is bumped to indicate that on quota violation brokers send out + responses before throttling. + """ + + API_KEY = 0 + API_VERSION = 6 + SCHEMA = ProduceResponse_v5.SCHEMA + + +class ProduceResponse_v7(Response): + """ + V7 bumped up to indicate ZStandard capability. (see KIP-110) + """ + + API_KEY = 0 + API_VERSION = 7 + SCHEMA = ProduceResponse_v6.SCHEMA + + +class ProduceResponse_v8(Response): + """ + V8 bumped up to add two new fields record_errors offset list and error_message + (See KIP-467) + """ + + API_KEY = 0 + API_VERSION = 8 + SCHEMA = Schema( + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("error_code", Int16), + ("offset", Int64), + ("timestamp", Int64), + ("log_start_offset", Int64), + ), + ( + "record_errors", + ( + Array( + ("batch_index", Int32), + ("batch_index_error_message", String("utf-8")), + ) + ), + ), + ("error_message", String("utf-8")), + ), + ), + ), + ("throttle_time_ms", Int32), + ) + + +class ProduceRequest(Request): + API_KEY = 0 + + def expect_response(self): + if self.required_acks == 0: # pylint: disable=no-member + return False + return True + + +class ProduceRequest_v0(ProduceRequest): + API_VERSION = 0 + RESPONSE_TYPE = ProduceResponse_v0 + SCHEMA = Schema( + ("required_acks", Int16), + ("timeout", Int32), + ( + "topics", + Array( + ("topic", String("utf-8")), + ("partitions", Array(("partition", Int32), ("messages", Bytes))), + ), + ), + ) + + +class ProduceRequest_v1(ProduceRequest): + API_VERSION = 1 + RESPONSE_TYPE = ProduceResponse_v1 + SCHEMA = ProduceRequest_v0.SCHEMA + + +class ProduceRequest_v2(ProduceRequest): + API_VERSION = 2 + RESPONSE_TYPE = ProduceResponse_v2 + SCHEMA = ProduceRequest_v1.SCHEMA + + +class ProduceRequest_v3(ProduceRequest): + API_VERSION = 3 + RESPONSE_TYPE = ProduceResponse_v3 + SCHEMA = Schema( + ("transactional_id", String("utf-8")), + ("required_acks", Int16), + ("timeout", Int32), + ( + "topics", + Array( + ("topic", String("utf-8")), + ("partitions", Array(("partition", Int32), ("messages", Bytes))), + ), + ), + ) + + +class ProduceRequest_v4(ProduceRequest): + """ + The version number is bumped up to indicate that the client supports + KafkaStorageException. The KafkaStorageException will be translated to + NotLeaderForPartitionException in the response if version <= 3 + """ + + API_VERSION = 4 + RESPONSE_TYPE = ProduceResponse_v4 + SCHEMA = ProduceRequest_v3.SCHEMA + + +class ProduceRequest_v5(ProduceRequest): + """ + Same as v4. The version number is bumped since the v5 response includes an + additional partition level field: the log_start_offset. + """ + + API_VERSION = 5 + RESPONSE_TYPE = ProduceResponse_v5 + SCHEMA = ProduceRequest_v4.SCHEMA + + +class ProduceRequest_v6(ProduceRequest): + """ + The version number is bumped to indicate that on quota violation brokers send out + responses before throttling. 
+ """ + + API_VERSION = 6 + RESPONSE_TYPE = ProduceResponse_v6 + SCHEMA = ProduceRequest_v5.SCHEMA + + +class ProduceRequest_v7(ProduceRequest): + """ + V7 bumped up to indicate ZStandard capability. (see KIP-110) + """ + + API_VERSION = 7 + RESPONSE_TYPE = ProduceResponse_v7 + SCHEMA = ProduceRequest_v6.SCHEMA + + +class ProduceRequest_v8(ProduceRequest): + """ + V8 bumped up to add two new fields record_errors offset list and error_message to + PartitionResponse (See KIP-467) + """ + + API_VERSION = 8 + RESPONSE_TYPE = ProduceResponse_v8 + SCHEMA = ProduceRequest_v7.SCHEMA + + +ProduceRequest = [ + ProduceRequest_v0, + ProduceRequest_v1, + ProduceRequest_v2, + ProduceRequest_v3, + ProduceRequest_v4, + ProduceRequest_v5, + ProduceRequest_v6, + ProduceRequest_v7, + ProduceRequest_v8, +] +ProduceResponse = [ + ProduceResponse_v0, + ProduceResponse_v1, + ProduceResponse_v2, + ProduceResponse_v3, + ProduceResponse_v4, + ProduceResponse_v5, + ProduceResponse_v6, + ProduceResponse_v7, + ProduceResponse_v8, +] diff --git a/kafka/protocol/struct.py b/aiokafka/protocol/struct.py similarity index 72% rename from kafka/protocol/struct.py rename to aiokafka/protocol/struct.py index e9da6e6c..d7faa327 100644 --- a/kafka/protocol/struct.py +++ b/aiokafka/protocol/struct.py @@ -1,9 +1,7 @@ -from __future__ import absolute_import - from io import BytesIO -from kafka.protocol.abstract import AbstractType -from kafka.protocol.types import Schema +from .abstract import AbstractType +from .types import Schema from kafka.util import WeakMethod @@ -16,32 +14,30 @@ def __init__(self, *args, **kwargs): for i, name in enumerate(self.SCHEMA.names): self.__dict__[name] = args[i] elif len(args) > 0: - raise ValueError('Args must be empty or mirror schema') + raise ValueError("Args must be empty or mirror schema") else: for name in self.SCHEMA.names: self.__dict__[name] = kwargs.pop(name, None) if kwargs: - raise ValueError('Keyword(s) not in schema %s: %s' - % (list(self.SCHEMA.names), - ', '.join(kwargs.keys()))) + raise ValueError( + "Keyword(s) not in schema %s: %s" + % (list(self.SCHEMA.names), ", ".join(kwargs.keys())) + ) # overloading encode() to support both class and instance # Without WeakMethod() this creates circular ref, which # causes instances to "leak" to garbage self.encode = WeakMethod(self._encode_self) - @classmethod def encode(cls, item): # pylint: disable=E0202 bits = [] for i, field in enumerate(cls.SCHEMA.fields): bits.append(field.encode(item[i])) - return b''.join(bits) + return b"".join(bits) def _encode_self(self): - return self.SCHEMA.encode( - [self.__dict__[name] for name in self.SCHEMA.names] - ) + return self.SCHEMA.encode([self.__dict__[name] for name in self.SCHEMA.names]) @classmethod def decode(cls, data): @@ -57,8 +53,8 @@ def get_item(self, name): def __repr__(self): key_vals = [] for name, field in zip(self.SCHEMA.names, self.SCHEMA.fields): - key_vals.append('%s=%s' % (name, field.repr(self.__dict__[name]))) - return self.__class__.__name__ + '(' + ', '.join(key_vals) + ')' + key_vals.append("%s=%s" % (name, field.repr(self.__dict__[name]))) + return self.__class__.__name__ + "(" + ", ".join(key_vals) + ")" def __hash__(self): return hash(self.encode()) diff --git a/aiokafka/protocol/transaction.py b/aiokafka/protocol/transaction.py index c684bf3d..aa05d22d 100644 --- a/aiokafka/protocol/transaction.py +++ b/aiokafka/protocol/transaction.py @@ -1,17 +1,15 @@ -from kafka.protocol.api import Request, Response -from kafka.protocol.types import ( - Int16, Int32, Int64, Schema, 
String, Array, Boolean -) +from .api import Request, Response +from .types import Int16, Int32, Int64, Schema, String, Array, Boolean class InitProducerIdResponse_v0(Response): API_KEY = 22 API_VERSION = 0 SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16), - ('producer_id', Int64), - ('producer_epoch', Int16), + ("throttle_time_ms", Int32), + ("error_code", Int16), + ("producer_id", Int64), + ("producer_epoch", Int16), ) @@ -20,8 +18,7 @@ class InitProducerIdRequest_v0(Request): API_VERSION = 0 RESPONSE_TYPE = InitProducerIdResponse_v0 SCHEMA = Schema( - ('transactional_id', String('utf-8')), - ('transaction_timeout_ms', Int32) + ("transactional_id", String("utf-8")), ("transaction_timeout_ms", Int32) ) @@ -29,12 +26,17 @@ class AddPartitionsToTxnResponse_v0(Response): API_KEY = 24 API_VERSION = 0 SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('errors', Array( - ('topic', String('utf-8')), - ('partition_errors', Array( - ('partition', Int32), - ('error_code', Int16))))) + ("throttle_time_ms", Int32), + ( + "errors", + Array( + ("topic", String("utf-8")), + ( + "partition_errors", + Array(("partition", Int32), ("error_code", Int16)), + ), + ), + ), ) @@ -43,22 +45,17 @@ class AddPartitionsToTxnRequest_v0(Request): API_VERSION = 0 RESPONSE_TYPE = AddPartitionsToTxnResponse_v0 SCHEMA = Schema( - ('transactional_id', String('utf-8')), - ('producer_id', Int64), - ('producer_epoch', Int16), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array(Int32)))) + ("transactional_id", String("utf-8")), + ("producer_id", Int64), + ("producer_epoch", Int16), + ("topics", Array(("topic", String("utf-8")), ("partitions", Array(Int32)))), ) class AddOffsetsToTxnResponse_v0(Response): API_KEY = 25 API_VERSION = 0 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16) - ) + SCHEMA = Schema(("throttle_time_ms", Int32), ("error_code", Int16)) class AddOffsetsToTxnRequest_v0(Request): @@ -66,20 +63,17 @@ class AddOffsetsToTxnRequest_v0(Request): API_VERSION = 0 RESPONSE_TYPE = AddOffsetsToTxnResponse_v0 SCHEMA = Schema( - ('transactional_id', String('utf-8')), - ('producer_id', Int64), - ('producer_epoch', Int16), - ('group_id', String('utf-8')) + ("transactional_id", String("utf-8")), + ("producer_id", Int64), + ("producer_epoch", Int16), + ("group_id", String("utf-8")), ) class EndTxnResponse_v0(Response): API_KEY = 26 API_VERSION = 0 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16) - ) + SCHEMA = Schema(("throttle_time_ms", Int32), ("error_code", Int16)) class EndTxnRequest_v0(Request): @@ -87,10 +81,10 @@ class EndTxnRequest_v0(Request): API_VERSION = 0 RESPONSE_TYPE = EndTxnResponse_v0 SCHEMA = Schema( - ('transactional_id', String('utf-8')), - ('producer_id', Int64), - ('producer_epoch', Int16), - ('transaction_result', Boolean) + ("transactional_id", String("utf-8")), + ("producer_id", Int64), + ("producer_epoch", Int16), + ("transaction_result", Boolean), ) @@ -98,12 +92,17 @@ class TxnOffsetCommitResponse_v0(Response): API_KEY = 28 API_VERSION = 0 SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('errors', Array( - ('topic', String('utf-8')), - ('partition_errors', Array( - ('partition', Int32), - ('error_code', Int16))))) + ("throttle_time_ms", Int32), + ( + "errors", + Array( + ("topic", String("utf-8")), + ( + "partition_errors", + Array(("partition", Int32), ("error_code", Int16)), + ), + ), + ), ) @@ -112,52 +111,40 @@ class TxnOffsetCommitRequest_v0(Request): API_VERSION = 0 RESPONSE_TYPE = 
TxnOffsetCommitResponse_v0 SCHEMA = Schema( - ('transactional_id', String('utf-8')), - ('group_id', String('utf-8')), - ('producer_id', Int64), - ('producer_epoch', Int16), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('offset', Int64), - ('metadata', String('utf-8')))))) + ("transactional_id", String("utf-8")), + ("group_id", String("utf-8")), + ("producer_id", Int64), + ("producer_epoch", Int16), + ( + "topics", + Array( + ("topic", String("utf-8")), + ( + "partitions", + Array( + ("partition", Int32), + ("offset", Int64), + ("metadata", String("utf-8")), + ), + ), + ), + ), ) -InitProducerIdRequest = [ - InitProducerIdRequest_v0 -] -InitProducerIdResponse = [ - InitProducerIdResponse_v0 -] - -AddPartitionsToTxnRequest = [ - AddPartitionsToTxnRequest_v0 -] -AddPartitionsToTxnResponse = [ - AddPartitionsToTxnResponse_v0 -] - -AddOffsetsToTxnRequest = [ - AddOffsetsToTxnRequest_v0 -] -AddOffsetsToTxnResponse = [ - AddOffsetsToTxnResponse_v0 -] - -EndTxnRequest = [ - EndTxnRequest_v0 -] - -EndTxnResponse = [ - EndTxnResponse_v0 -] - -TxnOffsetCommitResponse = [ - TxnOffsetCommitResponse_v0 -] - -TxnOffsetCommitRequest = [ - TxnOffsetCommitRequest_v0 -] +InitProducerIdRequest = [InitProducerIdRequest_v0] +InitProducerIdResponse = [InitProducerIdResponse_v0] + +AddPartitionsToTxnRequest = [AddPartitionsToTxnRequest_v0] +AddPartitionsToTxnResponse = [AddPartitionsToTxnResponse_v0] + +AddOffsetsToTxnRequest = [AddOffsetsToTxnRequest_v0] +AddOffsetsToTxnResponse = [AddOffsetsToTxnResponse_v0] + +EndTxnRequest = [EndTxnRequest_v0] + +EndTxnResponse = [EndTxnResponse_v0] + +TxnOffsetCommitResponse = [TxnOffsetCommitResponse_v0] + +TxnOffsetCommitRequest = [TxnOffsetCommitRequest_v0] diff --git a/kafka/protocol/types.py b/aiokafka/protocol/types.py similarity index 68% rename from kafka/protocol/types.py rename to aiokafka/protocol/types.py index 0e3685d7..56613905 100644 --- a/kafka/protocol/types.py +++ b/aiokafka/protocol/types.py @@ -1,18 +1,17 @@ -from __future__ import absolute_import - import struct from struct import error -from kafka.protocol.abstract import AbstractType +from .abstract import AbstractType def _pack(f, value): try: return f(value) except error as e: - raise ValueError("Error encountered when attempting to convert value: " - "{!r} to struct format: '{}', hit error: {}" - .format(value, f, e)) + raise ValueError( + "Error encountered when attempting to convert value: " + "{!r} to struct format: '{}', hit error: {}".format(value, f, e) + ) def _unpack(f, data): @@ -20,14 +19,15 @@ def _unpack(f, data): (value,) = f(data) return value except error as e: - raise ValueError("Error encountered when attempting to convert value: " - "{!r} to struct format: '{}', hit error: {}" - .format(data, f, e)) + raise ValueError( + "Error encountered when attempting to convert value: " + "{!r} to struct format: '{}', hit error: {}".format(data, f, e) + ) class Int8(AbstractType): - _pack = struct.Struct('>b').pack - _unpack = struct.Struct('>b').unpack + _pack = struct.Struct(">b").pack + _unpack = struct.Struct(">b").unpack @classmethod def encode(cls, value): @@ -39,8 +39,8 @@ def decode(cls, data): class Int16(AbstractType): - _pack = struct.Struct('>h').pack - _unpack = struct.Struct('>h').unpack + _pack = struct.Struct(">h").pack + _unpack = struct.Struct(">h").unpack @classmethod def encode(cls, value): @@ -52,8 +52,8 @@ def decode(cls, data): class Int32(AbstractType): - _pack = struct.Struct('>i').pack - _unpack = 
struct.Struct('>i').unpack + _pack = struct.Struct(">i").pack + _unpack = struct.Struct(">i").unpack @classmethod def encode(cls, value): @@ -65,8 +65,8 @@ def decode(cls, data): class Int64(AbstractType): - _pack = struct.Struct('>q').pack - _unpack = struct.Struct('>q').unpack + _pack = struct.Struct(">q").pack + _unpack = struct.Struct(">q").unpack @classmethod def encode(cls, value): @@ -78,8 +78,8 @@ def decode(cls, data): class Float64(AbstractType): - _pack = struct.Struct('>d').pack - _unpack = struct.Struct('>d').unpack + _pack = struct.Struct(">d").pack + _unpack = struct.Struct(">d").unpack @classmethod def encode(cls, value): @@ -91,7 +91,7 @@ def decode(cls, data): class String(AbstractType): - def __init__(self, encoding='utf-8'): + def __init__(self, encoding="utf-8"): self.encoding = encoding def encode(self, value): @@ -106,7 +106,7 @@ def decode(self, data): return None value = data.read(length) if len(value) != length: - raise ValueError('Buffer underrun decoding string') + raise ValueError("Buffer underrun decoding string") return value.decode(self.encoding) @@ -125,17 +125,19 @@ def decode(cls, data): return None value = data.read(length) if len(value) != length: - raise ValueError('Buffer underrun decoding Bytes') + raise ValueError("Buffer underrun decoding Bytes") return value @classmethod def repr(cls, value): - return repr(value[:100] + b'...' if value is not None and len(value) > 100 else value) + return repr( + value[:100] + b"..." if value is not None and len(value) > 100 else value + ) class Boolean(AbstractType): - _pack = struct.Struct('>?').pack - _unpack = struct.Struct('>?').unpack + _pack = struct.Struct(">?").pack + _unpack = struct.Struct(">?").unpack @classmethod def encode(cls, value): @@ -155,11 +157,8 @@ def __init__(self, *fields): def encode(self, item): if len(item) != len(self.fields): - raise ValueError('Item field count does not match Schema') - return b''.join([ - field.encode(item[i]) - for i, field in enumerate(self.fields) - ]) + raise ValueError("Item field count does not match Schema") + return b"".join([field.encode(item[i]) for i, field in enumerate(self.fields)]) def decode(self, data): return tuple([field.decode(data) for field in self.fields]) @@ -175,8 +174,10 @@ def repr(self, value): field_val = getattr(value, self.names[i]) except AttributeError: field_val = value[i] - key_vals.append('%s=%s' % (self.names[i], self.fields[i].repr(field_val))) - return '(' + ', '.join(key_vals) + ')' + key_vals.append( + "%s=%s" % (self.names[i], self.fields[i].repr(field_val)) + ) + return "(" + ", ".join(key_vals) + ")" except Exception: return repr(value) @@ -185,20 +186,19 @@ class Array(AbstractType): def __init__(self, *array_of): if len(array_of) > 1: self.array_of = Schema(*array_of) - elif len(array_of) == 1 and (isinstance(array_of[0], AbstractType) or - issubclass(array_of[0], AbstractType)): + elif len(array_of) == 1 and ( + isinstance(array_of[0], AbstractType) + or issubclass(array_of[0], AbstractType) + ): self.array_of = array_of[0] else: - raise ValueError('Array instantiated with no array_of type') + raise ValueError("Array instantiated with no array_of type") def encode(self, items): if items is None: return Int32.encode(-1) encoded_items = [self.array_of.encode(item) for item in items] - return b''.join( - [Int32.encode(len(encoded_items))] + - encoded_items - ) + return b"".join([Int32.encode(len(encoded_items))] + encoded_items) def decode(self, data): length = Int32.decode(data) @@ -208,8 +208,10 @@ def decode(self, data): 
def repr(self, list_of_items): if list_of_items is None: - return 'NULL' - return '[' + ', '.join([self.array_of.repr(item) for item in list_of_items]) + ']' + return "NULL" + return ( + "[" + ", ".join([self.array_of.repr(item) for item in list_of_items]) + "]" + ) class UnsignedVarInt32(AbstractType): @@ -217,25 +219,25 @@ class UnsignedVarInt32(AbstractType): def decode(cls, data): value, i = 0, 0 while True: - b, = struct.unpack('B', data.read(1)) + (b,) = struct.unpack("B", data.read(1)) if not (b & 0x80): break - value |= (b & 0x7f) << i + value |= (b & 0x7F) << i i += 7 if i > 28: - raise ValueError('Invalid value {}'.format(value)) + raise ValueError("Invalid value {}".format(value)) value |= b << i return value @classmethod def encode(cls, value): - value &= 0xffffffff - ret = b'' - while (value & 0xffffff80) != 0: - b = (value & 0x7f) | 0x80 - ret += struct.pack('B', b) + value &= 0xFFFFFFFF + ret = b"" + while (value & 0xFFFFFF80) != 0: + b = (value & 0x7F) | 0x80 + ret += struct.pack("B", b) value >>= 7 - ret += struct.pack('B', value) + ret += struct.pack("B", value) return ret @@ -248,7 +250,7 @@ def decode(cls, data): @classmethod def encode(cls, value): # bring it in line with the java binary repr - value &= 0xffffffff + value &= 0xFFFFFFFF return UnsignedVarInt32.encode((value << 1) ^ (value >> 31)) @@ -260,24 +262,24 @@ def decode(cls, data): b = data.read(1) if not (b & 0x80): break - value |= (b & 0x7f) << i + value |= (b & 0x7F) << i i += 7 if i > 63: - raise ValueError('Invalid value {}'.format(value)) + raise ValueError("Invalid value {}".format(value)) value |= b << i return (value >> 1) ^ -(value & 1) @classmethod def encode(cls, value): # bring it in line with the java binary repr - value &= 0xffffffffffffffff + value &= 0xFFFFFFFFFFFFFFFF v = (value << 1) ^ (value >> 63) - ret = b'' - while (v & 0xffffffffffffff80) != 0: - b = (value & 0x7f) | 0x80 - ret += struct.pack('B', b) + ret = b"" + while (v & 0xFFFFFFFFFFFFFF80) != 0: + b = (value & 0x7F) | 0x80 + ret += struct.pack("B", b) v >>= 7 - ret += struct.pack('B', v) + ret += struct.pack("B", v) return ret @@ -288,7 +290,7 @@ def decode(self, data): return None value = data.read(length) if len(value) != length: - raise ValueError('Buffer underrun decoding string') + raise ValueError("Buffer underrun decoding string") return value.decode(self.encoding) def encode(self, value): @@ -309,7 +311,7 @@ def decode(cls, data): for i in range(num_fields): tag = UnsignedVarInt32.decode(data) if tag <= prev_tag: - raise ValueError('Invalid or out-of-order tag {}'.format(tag)) + raise ValueError("Invalid or out-of-order tag {}".format(tag)) prev_tag = tag size = UnsignedVarInt32.decode(data) val = data.read(size) @@ -321,8 +323,10 @@ def encode(cls, value): ret = UnsignedVarInt32.encode(len(value)) for k, v in value.items(): # do we allow for other data types ?? 
It could get complicated really fast - assert isinstance(v, bytes), 'Value {} is not a byte array'.format(v) - assert isinstance(k, int) and k > 0, 'Key {} is not a positive integer'.format(k) + assert isinstance(v, bytes), "Value {} is not a byte array".format(v) + assert ( + isinstance(k, int) and k > 0 + ), "Key {} is not a positive integer".format(k) ret += UnsignedVarInt32.encode(k) ret += v return ret @@ -336,7 +340,7 @@ def decode(cls, data): return None value = data.read(length) if len(value) != length: - raise ValueError('Buffer underrun decoding Bytes') + raise ValueError("Buffer underrun decoding Bytes") return value @classmethod @@ -348,13 +352,12 @@ def encode(cls, value): class CompactArray(Array): - def encode(self, items): if items is None: return UnsignedVarInt32.encode(0) - return b''.join( - [UnsignedVarInt32.encode(len(items) + 1)] + - [self.array_of.encode(item) for item in items] + return b"".join( + [UnsignedVarInt32.encode(len(items) + 1)] + + [self.array_of.encode(item) for item in items] ) def decode(self, data): @@ -362,4 +365,3 @@ def decode(self, data): if length == -1: return None return [self.array_of.decode(data) for _ in range(length)] - diff --git a/docs/api.rst b/docs/api.rst index ba75a088..9ca4f590 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -176,6 +176,6 @@ Structs Protocols ^^^^^^^^^ -.. autoclass:: kafka.protocol.produce.ProduceRequest +.. autoclass:: aiokafka.protocol.produce.ProduceRequest :member-order: alphabetical :members: diff --git a/kafka/__init__.py b/kafka/__init__.py index 16d29b2e..a40686e6 100644 --- a/kafka/__init__.py +++ b/kafka/__init__.py @@ -1,7 +1,6 @@ from __future__ import absolute_import __title__ = 'kafka' -from kafka.version import __version__ __author__ = 'Dana Powers' __license__ = 'Apache License 2.0' __copyright__ = 'Copyright 2016 Dana Powers, David Arthur, and Contributors' diff --git a/kafka/protocol/__init__.py b/kafka/protocol/__init__.py deleted file mode 100644 index 025447f9..00000000 --- a/kafka/protocol/__init__.py +++ /dev/null @@ -1,49 +0,0 @@ -from __future__ import absolute_import - - -API_KEYS = { - 0: 'Produce', - 1: 'Fetch', - 2: 'ListOffsets', - 3: 'Metadata', - 4: 'LeaderAndIsr', - 5: 'StopReplica', - 6: 'UpdateMetadata', - 7: 'ControlledShutdown', - 8: 'OffsetCommit', - 9: 'OffsetFetch', - 10: 'FindCoordinator', - 11: 'JoinGroup', - 12: 'Heartbeat', - 13: 'LeaveGroup', - 14: 'SyncGroup', - 15: 'DescribeGroups', - 16: 'ListGroups', - 17: 'SaslHandshake', - 18: 'ApiVersions', - 19: 'CreateTopics', - 20: 'DeleteTopics', - 21: 'DeleteRecords', - 22: 'InitProducerId', - 23: 'OffsetForLeaderEpoch', - 24: 'AddPartitionsToTxn', - 25: 'AddOffsetsToTxn', - 26: 'EndTxn', - 27: 'WriteTxnMarkers', - 28: 'TxnOffsetCommit', - 29: 'DescribeAcls', - 30: 'CreateAcls', - 31: 'DeleteAcls', - 32: 'DescribeConfigs', - 33: 'AlterConfigs', - 36: 'SaslAuthenticate', - 37: 'CreatePartitions', - 38: 'CreateDelegationToken', - 39: 'RenewDelegationToken', - 40: 'ExpireDelegationToken', - 41: 'DescribeDelegationToken', - 42: 'DeleteGroups', - 45: 'AlterPartitionReassignments', - 46: 'ListPartitionReassignments', - 48: 'DescribeClientQuotas', -} diff --git a/kafka/protocol/admin.py b/kafka/protocol/admin.py deleted file mode 100644 index f9d61e5c..00000000 --- a/kafka/protocol/admin.py +++ /dev/null @@ -1,1054 +0,0 @@ -from __future__ import absolute_import - -from kafka.protocol.api import Request, Response -from kafka.protocol.types import Array, Boolean, Bytes, Int8, Int16, Int32, Int64, Schema, String, Float64, 
CompactString, CompactArray, TaggedFields - - -class ApiVersionResponse_v0(Response): - API_KEY = 18 - API_VERSION = 0 - SCHEMA = Schema( - ('error_code', Int16), - ('api_versions', Array( - ('api_key', Int16), - ('min_version', Int16), - ('max_version', Int16))) - ) - - -class ApiVersionResponse_v1(Response): - API_KEY = 18 - API_VERSION = 1 - SCHEMA = Schema( - ('error_code', Int16), - ('api_versions', Array( - ('api_key', Int16), - ('min_version', Int16), - ('max_version', Int16))), - ('throttle_time_ms', Int32) - ) - - -class ApiVersionResponse_v2(Response): - API_KEY = 18 - API_VERSION = 2 - SCHEMA = ApiVersionResponse_v1.SCHEMA - - -class ApiVersionRequest_v0(Request): - API_KEY = 18 - API_VERSION = 0 - RESPONSE_TYPE = ApiVersionResponse_v0 - SCHEMA = Schema() - - -class ApiVersionRequest_v1(Request): - API_KEY = 18 - API_VERSION = 1 - RESPONSE_TYPE = ApiVersionResponse_v1 - SCHEMA = ApiVersionRequest_v0.SCHEMA - - -class ApiVersionRequest_v2(Request): - API_KEY = 18 - API_VERSION = 2 - RESPONSE_TYPE = ApiVersionResponse_v1 - SCHEMA = ApiVersionRequest_v0.SCHEMA - - -ApiVersionRequest = [ - ApiVersionRequest_v0, ApiVersionRequest_v1, ApiVersionRequest_v2, -] -ApiVersionResponse = [ - ApiVersionResponse_v0, ApiVersionResponse_v1, ApiVersionResponse_v2, -] - - -class CreateTopicsResponse_v0(Response): - API_KEY = 19 - API_VERSION = 0 - SCHEMA = Schema( - ('topic_errors', Array( - ('topic', String('utf-8')), - ('error_code', Int16))) - ) - - -class CreateTopicsResponse_v1(Response): - API_KEY = 19 - API_VERSION = 1 - SCHEMA = Schema( - ('topic_errors', Array( - ('topic', String('utf-8')), - ('error_code', Int16), - ('error_message', String('utf-8')))) - ) - - -class CreateTopicsResponse_v2(Response): - API_KEY = 19 - API_VERSION = 2 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('topic_errors', Array( - ('topic', String('utf-8')), - ('error_code', Int16), - ('error_message', String('utf-8')))) - ) - -class CreateTopicsResponse_v3(Response): - API_KEY = 19 - API_VERSION = 3 - SCHEMA = CreateTopicsResponse_v2.SCHEMA - - -class CreateTopicsRequest_v0(Request): - API_KEY = 19 - API_VERSION = 0 - RESPONSE_TYPE = CreateTopicsResponse_v0 - SCHEMA = Schema( - ('create_topic_requests', Array( - ('topic', String('utf-8')), - ('num_partitions', Int32), - ('replication_factor', Int16), - ('replica_assignment', Array( - ('partition_id', Int32), - ('replicas', Array(Int32)))), - ('configs', Array( - ('config_key', String('utf-8')), - ('config_value', String('utf-8')))))), - ('timeout', Int32) - ) - - -class CreateTopicsRequest_v1(Request): - API_KEY = 19 - API_VERSION = 1 - RESPONSE_TYPE = CreateTopicsResponse_v1 - SCHEMA = Schema( - ('create_topic_requests', Array( - ('topic', String('utf-8')), - ('num_partitions', Int32), - ('replication_factor', Int16), - ('replica_assignment', Array( - ('partition_id', Int32), - ('replicas', Array(Int32)))), - ('configs', Array( - ('config_key', String('utf-8')), - ('config_value', String('utf-8')))))), - ('timeout', Int32), - ('validate_only', Boolean) - ) - - -class CreateTopicsRequest_v2(Request): - API_KEY = 19 - API_VERSION = 2 - RESPONSE_TYPE = CreateTopicsResponse_v2 - SCHEMA = CreateTopicsRequest_v1.SCHEMA - - -class CreateTopicsRequest_v3(Request): - API_KEY = 19 - API_VERSION = 3 - RESPONSE_TYPE = CreateTopicsResponse_v3 - SCHEMA = CreateTopicsRequest_v1.SCHEMA - - -CreateTopicsRequest = [ - CreateTopicsRequest_v0, CreateTopicsRequest_v1, - CreateTopicsRequest_v2, CreateTopicsRequest_v3, -] -CreateTopicsResponse = [ - CreateTopicsResponse_v0, 
CreateTopicsResponse_v1, - CreateTopicsResponse_v2, CreateTopicsResponse_v3, -] - - -class DeleteTopicsResponse_v0(Response): - API_KEY = 20 - API_VERSION = 0 - SCHEMA = Schema( - ('topic_error_codes', Array( - ('topic', String('utf-8')), - ('error_code', Int16))) - ) - - -class DeleteTopicsResponse_v1(Response): - API_KEY = 20 - API_VERSION = 1 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('topic_error_codes', Array( - ('topic', String('utf-8')), - ('error_code', Int16))) - ) - - -class DeleteTopicsResponse_v2(Response): - API_KEY = 20 - API_VERSION = 2 - SCHEMA = DeleteTopicsResponse_v1.SCHEMA - - -class DeleteTopicsResponse_v3(Response): - API_KEY = 20 - API_VERSION = 3 - SCHEMA = DeleteTopicsResponse_v1.SCHEMA - - -class DeleteTopicsRequest_v0(Request): - API_KEY = 20 - API_VERSION = 0 - RESPONSE_TYPE = DeleteTopicsResponse_v0 - SCHEMA = Schema( - ('topics', Array(String('utf-8'))), - ('timeout', Int32) - ) - - -class DeleteTopicsRequest_v1(Request): - API_KEY = 20 - API_VERSION = 1 - RESPONSE_TYPE = DeleteTopicsResponse_v1 - SCHEMA = DeleteTopicsRequest_v0.SCHEMA - - -class DeleteTopicsRequest_v2(Request): - API_KEY = 20 - API_VERSION = 2 - RESPONSE_TYPE = DeleteTopicsResponse_v2 - SCHEMA = DeleteTopicsRequest_v0.SCHEMA - - -class DeleteTopicsRequest_v3(Request): - API_KEY = 20 - API_VERSION = 3 - RESPONSE_TYPE = DeleteTopicsResponse_v3 - SCHEMA = DeleteTopicsRequest_v0.SCHEMA - - -DeleteTopicsRequest = [ - DeleteTopicsRequest_v0, DeleteTopicsRequest_v1, - DeleteTopicsRequest_v2, DeleteTopicsRequest_v3, -] -DeleteTopicsResponse = [ - DeleteTopicsResponse_v0, DeleteTopicsResponse_v1, - DeleteTopicsResponse_v2, DeleteTopicsResponse_v3, -] - - -class ListGroupsResponse_v0(Response): - API_KEY = 16 - API_VERSION = 0 - SCHEMA = Schema( - ('error_code', Int16), - ('groups', Array( - ('group', String('utf-8')), - ('protocol_type', String('utf-8')))) - ) - - -class ListGroupsResponse_v1(Response): - API_KEY = 16 - API_VERSION = 1 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16), - ('groups', Array( - ('group', String('utf-8')), - ('protocol_type', String('utf-8')))) - ) - -class ListGroupsResponse_v2(Response): - API_KEY = 16 - API_VERSION = 2 - SCHEMA = ListGroupsResponse_v1.SCHEMA - - -class ListGroupsRequest_v0(Request): - API_KEY = 16 - API_VERSION = 0 - RESPONSE_TYPE = ListGroupsResponse_v0 - SCHEMA = Schema() - - -class ListGroupsRequest_v1(Request): - API_KEY = 16 - API_VERSION = 1 - RESPONSE_TYPE = ListGroupsResponse_v1 - SCHEMA = ListGroupsRequest_v0.SCHEMA - -class ListGroupsRequest_v2(Request): - API_KEY = 16 - API_VERSION = 1 - RESPONSE_TYPE = ListGroupsResponse_v2 - SCHEMA = ListGroupsRequest_v0.SCHEMA - - -ListGroupsRequest = [ - ListGroupsRequest_v0, ListGroupsRequest_v1, - ListGroupsRequest_v2, -] -ListGroupsResponse = [ - ListGroupsResponse_v0, ListGroupsResponse_v1, - ListGroupsResponse_v2, -] - - -class DescribeGroupsResponse_v0(Response): - API_KEY = 15 - API_VERSION = 0 - SCHEMA = Schema( - ('groups', Array( - ('error_code', Int16), - ('group', String('utf-8')), - ('state', String('utf-8')), - ('protocol_type', String('utf-8')), - ('protocol', String('utf-8')), - ('members', Array( - ('member_id', String('utf-8')), - ('client_id', String('utf-8')), - ('client_host', String('utf-8')), - ('member_metadata', Bytes), - ('member_assignment', Bytes))))) - ) - - -class DescribeGroupsResponse_v1(Response): - API_KEY = 15 - API_VERSION = 1 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('groups', Array( - ('error_code', Int16), - ('group', 
String('utf-8')), - ('state', String('utf-8')), - ('protocol_type', String('utf-8')), - ('protocol', String('utf-8')), - ('members', Array( - ('member_id', String('utf-8')), - ('client_id', String('utf-8')), - ('client_host', String('utf-8')), - ('member_metadata', Bytes), - ('member_assignment', Bytes))))) - ) - - -class DescribeGroupsResponse_v2(Response): - API_KEY = 15 - API_VERSION = 2 - SCHEMA = DescribeGroupsResponse_v1.SCHEMA - - -class DescribeGroupsResponse_v3(Response): - API_KEY = 15 - API_VERSION = 3 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('groups', Array( - ('error_code', Int16), - ('group', String('utf-8')), - ('state', String('utf-8')), - ('protocol_type', String('utf-8')), - ('protocol', String('utf-8')), - ('members', Array( - ('member_id', String('utf-8')), - ('client_id', String('utf-8')), - ('client_host', String('utf-8')), - ('member_metadata', Bytes), - ('member_assignment', Bytes)))), - ('authorized_operations', Int32)) - ) - - -class DescribeGroupsRequest_v0(Request): - API_KEY = 15 - API_VERSION = 0 - RESPONSE_TYPE = DescribeGroupsResponse_v0 - SCHEMA = Schema( - ('groups', Array(String('utf-8'))) - ) - - -class DescribeGroupsRequest_v1(Request): - API_KEY = 15 - API_VERSION = 1 - RESPONSE_TYPE = DescribeGroupsResponse_v1 - SCHEMA = DescribeGroupsRequest_v0.SCHEMA - - -class DescribeGroupsRequest_v2(Request): - API_KEY = 15 - API_VERSION = 2 - RESPONSE_TYPE = DescribeGroupsResponse_v2 - SCHEMA = DescribeGroupsRequest_v0.SCHEMA - - -class DescribeGroupsRequest_v3(Request): - API_KEY = 15 - API_VERSION = 3 - RESPONSE_TYPE = DescribeGroupsResponse_v2 - SCHEMA = Schema( - ('groups', Array(String('utf-8'))), - ('include_authorized_operations', Boolean) - ) - - -DescribeGroupsRequest = [ - DescribeGroupsRequest_v0, DescribeGroupsRequest_v1, - DescribeGroupsRequest_v2, DescribeGroupsRequest_v3, -] -DescribeGroupsResponse = [ - DescribeGroupsResponse_v0, DescribeGroupsResponse_v1, - DescribeGroupsResponse_v2, DescribeGroupsResponse_v3, -] - - -class SaslHandShakeResponse_v0(Response): - API_KEY = 17 - API_VERSION = 0 - SCHEMA = Schema( - ('error_code', Int16), - ('enabled_mechanisms', Array(String('utf-8'))) - ) - - -class SaslHandShakeResponse_v1(Response): - API_KEY = 17 - API_VERSION = 1 - SCHEMA = SaslHandShakeResponse_v0.SCHEMA - - -class SaslHandShakeRequest_v0(Request): - API_KEY = 17 - API_VERSION = 0 - RESPONSE_TYPE = SaslHandShakeResponse_v0 - SCHEMA = Schema( - ('mechanism', String('utf-8')) - ) - - -class SaslHandShakeRequest_v1(Request): - API_KEY = 17 - API_VERSION = 1 - RESPONSE_TYPE = SaslHandShakeResponse_v1 - SCHEMA = SaslHandShakeRequest_v0.SCHEMA - - -SaslHandShakeRequest = [SaslHandShakeRequest_v0, SaslHandShakeRequest_v1] -SaslHandShakeResponse = [SaslHandShakeResponse_v0, SaslHandShakeResponse_v1] - - -class DescribeAclsResponse_v0(Response): - API_KEY = 29 - API_VERSION = 0 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16), - ('error_message', String('utf-8')), - ('resources', Array( - ('resource_type', Int8), - ('resource_name', String('utf-8')), - ('acls', Array( - ('principal', String('utf-8')), - ('host', String('utf-8')), - ('operation', Int8), - ('permission_type', Int8))))) - ) - - -class DescribeAclsResponse_v1(Response): - API_KEY = 29 - API_VERSION = 1 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16), - ('error_message', String('utf-8')), - ('resources', Array( - ('resource_type', Int8), - ('resource_name', String('utf-8')), - ('resource_pattern_type', Int8), - ('acls', Array( 
- ('principal', String('utf-8')), - ('host', String('utf-8')), - ('operation', Int8), - ('permission_type', Int8))))) - ) - - -class DescribeAclsResponse_v2(Response): - API_KEY = 29 - API_VERSION = 2 - SCHEMA = DescribeAclsResponse_v1.SCHEMA - - -class DescribeAclsRequest_v0(Request): - API_KEY = 29 - API_VERSION = 0 - RESPONSE_TYPE = DescribeAclsResponse_v0 - SCHEMA = Schema( - ('resource_type', Int8), - ('resource_name', String('utf-8')), - ('principal', String('utf-8')), - ('host', String('utf-8')), - ('operation', Int8), - ('permission_type', Int8) - ) - - -class DescribeAclsRequest_v1(Request): - API_KEY = 29 - API_VERSION = 1 - RESPONSE_TYPE = DescribeAclsResponse_v1 - SCHEMA = Schema( - ('resource_type', Int8), - ('resource_name', String('utf-8')), - ('resource_pattern_type_filter', Int8), - ('principal', String('utf-8')), - ('host', String('utf-8')), - ('operation', Int8), - ('permission_type', Int8) - ) - - -class DescribeAclsRequest_v2(Request): - """ - Enable flexible version - """ - API_KEY = 29 - API_VERSION = 2 - RESPONSE_TYPE = DescribeAclsResponse_v2 - SCHEMA = DescribeAclsRequest_v1.SCHEMA - - -DescribeAclsRequest = [DescribeAclsRequest_v0, DescribeAclsRequest_v1] -DescribeAclsResponse = [DescribeAclsResponse_v0, DescribeAclsResponse_v1] - -class CreateAclsResponse_v0(Response): - API_KEY = 30 - API_VERSION = 0 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('creation_responses', Array( - ('error_code', Int16), - ('error_message', String('utf-8')))) - ) - -class CreateAclsResponse_v1(Response): - API_KEY = 30 - API_VERSION = 1 - SCHEMA = CreateAclsResponse_v0.SCHEMA - -class CreateAclsRequest_v0(Request): - API_KEY = 30 - API_VERSION = 0 - RESPONSE_TYPE = CreateAclsResponse_v0 - SCHEMA = Schema( - ('creations', Array( - ('resource_type', Int8), - ('resource_name', String('utf-8')), - ('principal', String('utf-8')), - ('host', String('utf-8')), - ('operation', Int8), - ('permission_type', Int8))) - ) - -class CreateAclsRequest_v1(Request): - API_KEY = 30 - API_VERSION = 1 - RESPONSE_TYPE = CreateAclsResponse_v1 - SCHEMA = Schema( - ('creations', Array( - ('resource_type', Int8), - ('resource_name', String('utf-8')), - ('resource_pattern_type', Int8), - ('principal', String('utf-8')), - ('host', String('utf-8')), - ('operation', Int8), - ('permission_type', Int8))) - ) - -CreateAclsRequest = [CreateAclsRequest_v0, CreateAclsRequest_v1] -CreateAclsResponse = [CreateAclsResponse_v0, CreateAclsResponse_v1] - -class DeleteAclsResponse_v0(Response): - API_KEY = 31 - API_VERSION = 0 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('filter_responses', Array( - ('error_code', Int16), - ('error_message', String('utf-8')), - ('matching_acls', Array( - ('error_code', Int16), - ('error_message', String('utf-8')), - ('resource_type', Int8), - ('resource_name', String('utf-8')), - ('principal', String('utf-8')), - ('host', String('utf-8')), - ('operation', Int8), - ('permission_type', Int8))))) - ) - -class DeleteAclsResponse_v1(Response): - API_KEY = 31 - API_VERSION = 1 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('filter_responses', Array( - ('error_code', Int16), - ('error_message', String('utf-8')), - ('matching_acls', Array( - ('error_code', Int16), - ('error_message', String('utf-8')), - ('resource_type', Int8), - ('resource_name', String('utf-8')), - ('resource_pattern_type', Int8), - ('principal', String('utf-8')), - ('host', String('utf-8')), - ('operation', Int8), - ('permission_type', Int8))))) - ) - -class DeleteAclsRequest_v0(Request): - API_KEY = 31 - 
API_VERSION = 0 - RESPONSE_TYPE = DeleteAclsResponse_v0 - SCHEMA = Schema( - ('filters', Array( - ('resource_type', Int8), - ('resource_name', String('utf-8')), - ('principal', String('utf-8')), - ('host', String('utf-8')), - ('operation', Int8), - ('permission_type', Int8))) - ) - -class DeleteAclsRequest_v1(Request): - API_KEY = 31 - API_VERSION = 1 - RESPONSE_TYPE = DeleteAclsResponse_v1 - SCHEMA = Schema( - ('filters', Array( - ('resource_type', Int8), - ('resource_name', String('utf-8')), - ('resource_pattern_type_filter', Int8), - ('principal', String('utf-8')), - ('host', String('utf-8')), - ('operation', Int8), - ('permission_type', Int8))) - ) - -DeleteAclsRequest = [DeleteAclsRequest_v0, DeleteAclsRequest_v1] -DeleteAclsResponse = [DeleteAclsResponse_v0, DeleteAclsResponse_v1] - -class AlterConfigsResponse_v0(Response): - API_KEY = 33 - API_VERSION = 0 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('resources', Array( - ('error_code', Int16), - ('error_message', String('utf-8')), - ('resource_type', Int8), - ('resource_name', String('utf-8')))) - ) - - -class AlterConfigsResponse_v1(Response): - API_KEY = 33 - API_VERSION = 1 - SCHEMA = AlterConfigsResponse_v0.SCHEMA - - -class AlterConfigsRequest_v0(Request): - API_KEY = 33 - API_VERSION = 0 - RESPONSE_TYPE = AlterConfigsResponse_v0 - SCHEMA = Schema( - ('resources', Array( - ('resource_type', Int8), - ('resource_name', String('utf-8')), - ('config_entries', Array( - ('config_name', String('utf-8')), - ('config_value', String('utf-8')))))), - ('validate_only', Boolean) - ) - -class AlterConfigsRequest_v1(Request): - API_KEY = 33 - API_VERSION = 1 - RESPONSE_TYPE = AlterConfigsResponse_v1 - SCHEMA = AlterConfigsRequest_v0.SCHEMA - -AlterConfigsRequest = [AlterConfigsRequest_v0, AlterConfigsRequest_v1] -AlterConfigsResponse = [AlterConfigsResponse_v0, AlterConfigsRequest_v1] - - -class DescribeConfigsResponse_v0(Response): - API_KEY = 32 - API_VERSION = 0 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('resources', Array( - ('error_code', Int16), - ('error_message', String('utf-8')), - ('resource_type', Int8), - ('resource_name', String('utf-8')), - ('config_entries', Array( - ('config_names', String('utf-8')), - ('config_value', String('utf-8')), - ('read_only', Boolean), - ('is_default', Boolean), - ('is_sensitive', Boolean))))) - ) - -class DescribeConfigsResponse_v1(Response): - API_KEY = 32 - API_VERSION = 1 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('resources', Array( - ('error_code', Int16), - ('error_message', String('utf-8')), - ('resource_type', Int8), - ('resource_name', String('utf-8')), - ('config_entries', Array( - ('config_names', String('utf-8')), - ('config_value', String('utf-8')), - ('read_only', Boolean), - ('is_default', Boolean), - ('is_sensitive', Boolean), - ('config_synonyms', Array( - ('config_name', String('utf-8')), - ('config_value', String('utf-8')), - ('config_source', Int8))))))) - ) - -class DescribeConfigsResponse_v2(Response): - API_KEY = 32 - API_VERSION = 2 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('resources', Array( - ('error_code', Int16), - ('error_message', String('utf-8')), - ('resource_type', Int8), - ('resource_name', String('utf-8')), - ('config_entries', Array( - ('config_names', String('utf-8')), - ('config_value', String('utf-8')), - ('read_only', Boolean), - ('config_source', Int8), - ('is_sensitive', Boolean), - ('config_synonyms', Array( - ('config_name', String('utf-8')), - ('config_value', String('utf-8')), - ('config_source', Int8))))))) - ) - 
-class DescribeConfigsRequest_v0(Request): - API_KEY = 32 - API_VERSION = 0 - RESPONSE_TYPE = DescribeConfigsResponse_v0 - SCHEMA = Schema( - ('resources', Array( - ('resource_type', Int8), - ('resource_name', String('utf-8')), - ('config_names', Array(String('utf-8'))))) - ) - -class DescribeConfigsRequest_v1(Request): - API_KEY = 32 - API_VERSION = 1 - RESPONSE_TYPE = DescribeConfigsResponse_v1 - SCHEMA = Schema( - ('resources', Array( - ('resource_type', Int8), - ('resource_name', String('utf-8')), - ('config_names', Array(String('utf-8'))))), - ('include_synonyms', Boolean) - ) - - -class DescribeConfigsRequest_v2(Request): - API_KEY = 32 - API_VERSION = 2 - RESPONSE_TYPE = DescribeConfigsResponse_v2 - SCHEMA = DescribeConfigsRequest_v1.SCHEMA - - -DescribeConfigsRequest = [ - DescribeConfigsRequest_v0, DescribeConfigsRequest_v1, - DescribeConfigsRequest_v2, -] -DescribeConfigsResponse = [ - DescribeConfigsResponse_v0, DescribeConfigsResponse_v1, - DescribeConfigsResponse_v2, -] - - -class SaslAuthenticateResponse_v0(Response): - API_KEY = 36 - API_VERSION = 0 - SCHEMA = Schema( - ('error_code', Int16), - ('error_message', String('utf-8')), - ('sasl_auth_bytes', Bytes) - ) - - -class SaslAuthenticateResponse_v1(Response): - API_KEY = 36 - API_VERSION = 1 - SCHEMA = Schema( - ('error_code', Int16), - ('error_message', String('utf-8')), - ('sasl_auth_bytes', Bytes), - ('session_lifetime_ms', Int64) - ) - - -class SaslAuthenticateRequest_v0(Request): - API_KEY = 36 - API_VERSION = 0 - RESPONSE_TYPE = SaslAuthenticateResponse_v0 - SCHEMA = Schema( - ('sasl_auth_bytes', Bytes) - ) - - -class SaslAuthenticateRequest_v1(Request): - API_KEY = 36 - API_VERSION = 1 - RESPONSE_TYPE = SaslAuthenticateResponse_v1 - SCHEMA = SaslAuthenticateRequest_v0.SCHEMA - - -SaslAuthenticateRequest = [ - SaslAuthenticateRequest_v0, SaslAuthenticateRequest_v1, -] -SaslAuthenticateResponse = [ - SaslAuthenticateResponse_v0, SaslAuthenticateResponse_v1, -] - - -class CreatePartitionsResponse_v0(Response): - API_KEY = 37 - API_VERSION = 0 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('topic_errors', Array( - ('topic', String('utf-8')), - ('error_code', Int16), - ('error_message', String('utf-8')))) - ) - - -class CreatePartitionsResponse_v1(Response): - API_KEY = 37 - API_VERSION = 1 - SCHEMA = CreatePartitionsResponse_v0.SCHEMA - - -class CreatePartitionsRequest_v0(Request): - API_KEY = 37 - API_VERSION = 0 - RESPONSE_TYPE = CreatePartitionsResponse_v0 - SCHEMA = Schema( - ('topic_partitions', Array( - ('topic', String('utf-8')), - ('new_partitions', Schema( - ('count', Int32), - ('assignment', Array(Array(Int32))))))), - ('timeout', Int32), - ('validate_only', Boolean) - ) - - -class CreatePartitionsRequest_v1(Request): - API_KEY = 37 - API_VERSION = 1 - SCHEMA = CreatePartitionsRequest_v0.SCHEMA - RESPONSE_TYPE = CreatePartitionsResponse_v1 - - -CreatePartitionsRequest = [ - CreatePartitionsRequest_v0, CreatePartitionsRequest_v1, -] -CreatePartitionsResponse = [ - CreatePartitionsResponse_v0, CreatePartitionsResponse_v1, -] - - -class DeleteGroupsResponse_v0(Response): - API_KEY = 42 - API_VERSION = 0 - SCHEMA = Schema( - ("throttle_time_ms", Int32), - ("results", Array( - ("group_id", String("utf-8")), - ("error_code", Int16))) - ) - - -class DeleteGroupsResponse_v1(Response): - API_KEY = 42 - API_VERSION = 1 - SCHEMA = DeleteGroupsResponse_v0.SCHEMA - - -class DeleteGroupsRequest_v0(Request): - API_KEY = 42 - API_VERSION = 0 - RESPONSE_TYPE = DeleteGroupsResponse_v0 - SCHEMA = Schema( - ("groups_names", 
Array(String("utf-8"))) - ) - - -class DeleteGroupsRequest_v1(Request): - API_KEY = 42 - API_VERSION = 1 - RESPONSE_TYPE = DeleteGroupsResponse_v1 - SCHEMA = DeleteGroupsRequest_v0.SCHEMA - - -DeleteGroupsRequest = [ - DeleteGroupsRequest_v0, DeleteGroupsRequest_v1 -] - -DeleteGroupsResponse = [ - DeleteGroupsResponse_v0, DeleteGroupsResponse_v1 -] - - -class DescribeClientQuotasResponse_v0(Request): - API_KEY = 48 - API_VERSION = 0 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16), - ('error_message', String('utf-8')), - ('entries', Array( - ('entity', Array( - ('entity_type', String('utf-8')), - ('entity_name', String('utf-8')))), - ('values', Array( - ('name', String('utf-8')), - ('value', Float64))))), - ) - - -class DescribeClientQuotasRequest_v0(Request): - API_KEY = 48 - API_VERSION = 0 - RESPONSE_TYPE = DescribeClientQuotasResponse_v0 - SCHEMA = Schema( - ('components', Array( - ('entity_type', String('utf-8')), - ('match_type', Int8), - ('match', String('utf-8')), - )), - ('strict', Boolean) - ) - - -DescribeClientQuotasRequest = [ - DescribeClientQuotasRequest_v0, -] - -DescribeClientQuotasResponse = [ - DescribeClientQuotasResponse_v0, -] - - -class AlterPartitionReassignmentsResponse_v0(Response): - API_KEY = 45 - API_VERSION = 0 - SCHEMA = Schema( - ("throttle_time_ms", Int32), - ("error_code", Int16), - ("error_message", CompactString("utf-8")), - ("responses", CompactArray( - ("name", CompactString("utf-8")), - ("partitions", CompactArray( - ("partition_index", Int32), - ("error_code", Int16), - ("error_message", CompactString("utf-8")), - ("tags", TaggedFields) - )), - ("tags", TaggedFields) - )), - ("tags", TaggedFields) - ) - - -class AlterPartitionReassignmentsRequest_v0(Request): - FLEXIBLE_VERSION = True - API_KEY = 45 - API_VERSION = 0 - RESPONSE_TYPE = AlterPartitionReassignmentsResponse_v0 - SCHEMA = Schema( - ("timeout_ms", Int32), - ("topics", CompactArray( - ("name", CompactString("utf-8")), - ("partitions", CompactArray( - ("partition_index", Int32), - ("replicas", CompactArray(Int32)), - ("tags", TaggedFields) - )), - ("tags", TaggedFields) - )), - ("tags", TaggedFields) - ) - - -AlterPartitionReassignmentsRequest = [AlterPartitionReassignmentsRequest_v0] - -AlterPartitionReassignmentsResponse = [AlterPartitionReassignmentsResponse_v0] - - -class ListPartitionReassignmentsResponse_v0(Response): - API_KEY = 46 - API_VERSION = 0 - SCHEMA = Schema( - ("throttle_time_ms", Int32), - ("error_code", Int16), - ("error_message", CompactString("utf-8")), - ("topics", CompactArray( - ("name", CompactString("utf-8")), - ("partitions", CompactArray( - ("partition_index", Int32), - ("replicas", CompactArray(Int32)), - ("adding_replicas", CompactArray(Int32)), - ("removing_replicas", CompactArray(Int32)), - ("tags", TaggedFields) - )), - ("tags", TaggedFields) - )), - ("tags", TaggedFields) - ) - - -class ListPartitionReassignmentsRequest_v0(Request): - FLEXIBLE_VERSION = True - API_KEY = 46 - API_VERSION = 0 - RESPONSE_TYPE = ListPartitionReassignmentsResponse_v0 - SCHEMA = Schema( - ("timeout_ms", Int32), - ("topics", CompactArray( - ("name", CompactString("utf-8")), - ("partition_index", CompactArray(Int32)), - ("tags", TaggedFields) - )), - ("tags", TaggedFields) - ) - - -ListPartitionReassignmentsRequest = [ListPartitionReassignmentsRequest_v0] - -ListPartitionReassignmentsResponse = [ListPartitionReassignmentsResponse_v0] diff --git a/kafka/protocol/commit.py b/kafka/protocol/commit.py deleted file mode 100644 index 31fc2370..00000000 --- 
a/kafka/protocol/commit.py +++ /dev/null @@ -1,255 +0,0 @@ -from __future__ import absolute_import - -from kafka.protocol.api import Request, Response -from kafka.protocol.types import Array, Int8, Int16, Int32, Int64, Schema, String - - -class OffsetCommitResponse_v0(Response): - API_KEY = 8 - API_VERSION = 0 - SCHEMA = Schema( - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16))))) - ) - - -class OffsetCommitResponse_v1(Response): - API_KEY = 8 - API_VERSION = 1 - SCHEMA = OffsetCommitResponse_v0.SCHEMA - - -class OffsetCommitResponse_v2(Response): - API_KEY = 8 - API_VERSION = 2 - SCHEMA = OffsetCommitResponse_v1.SCHEMA - - -class OffsetCommitResponse_v3(Response): - API_KEY = 8 - API_VERSION = 3 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16))))) - ) - - -class OffsetCommitRequest_v0(Request): - API_KEY = 8 - API_VERSION = 0 # Zookeeper-backed storage - RESPONSE_TYPE = OffsetCommitResponse_v0 - SCHEMA = Schema( - ('consumer_group', String('utf-8')), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('offset', Int64), - ('metadata', String('utf-8')))))) - ) - - -class OffsetCommitRequest_v1(Request): - API_KEY = 8 - API_VERSION = 1 # Kafka-backed storage - RESPONSE_TYPE = OffsetCommitResponse_v1 - SCHEMA = Schema( - ('consumer_group', String('utf-8')), - ('consumer_group_generation_id', Int32), - ('consumer_id', String('utf-8')), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('offset', Int64), - ('timestamp', Int64), - ('metadata', String('utf-8')))))) - ) - - -class OffsetCommitRequest_v2(Request): - API_KEY = 8 - API_VERSION = 2 # added retention_time, dropped timestamp - RESPONSE_TYPE = OffsetCommitResponse_v2 - SCHEMA = Schema( - ('consumer_group', String('utf-8')), - ('consumer_group_generation_id', Int32), - ('consumer_id', String('utf-8')), - ('retention_time', Int64), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('offset', Int64), - ('metadata', String('utf-8')))))) - ) - DEFAULT_GENERATION_ID = -1 - DEFAULT_RETENTION_TIME = -1 - - -class OffsetCommitRequest_v3(Request): - API_KEY = 8 - API_VERSION = 3 - RESPONSE_TYPE = OffsetCommitResponse_v3 - SCHEMA = OffsetCommitRequest_v2.SCHEMA - - -OffsetCommitRequest = [ - OffsetCommitRequest_v0, OffsetCommitRequest_v1, - OffsetCommitRequest_v2, OffsetCommitRequest_v3 -] -OffsetCommitResponse = [ - OffsetCommitResponse_v0, OffsetCommitResponse_v1, - OffsetCommitResponse_v2, OffsetCommitResponse_v3 -] - - -class OffsetFetchResponse_v0(Response): - API_KEY = 9 - API_VERSION = 0 - SCHEMA = Schema( - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('offset', Int64), - ('metadata', String('utf-8')), - ('error_code', Int16))))) - ) - - -class OffsetFetchResponse_v1(Response): - API_KEY = 9 - API_VERSION = 1 - SCHEMA = OffsetFetchResponse_v0.SCHEMA - - -class OffsetFetchResponse_v2(Response): - # Added in KIP-88 - API_KEY = 9 - API_VERSION = 2 - SCHEMA = Schema( - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('offset', Int64), - ('metadata', String('utf-8')), - ('error_code', Int16))))), - ('error_code', Int16) - ) - - -class OffsetFetchResponse_v3(Response): - API_KEY = 9 - API_VERSION = 3 - 
SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('offset', Int64), - ('metadata', String('utf-8')), - ('error_code', Int16))))), - ('error_code', Int16) - ) - - -class OffsetFetchRequest_v0(Request): - API_KEY = 9 - API_VERSION = 0 # zookeeper-backed storage - RESPONSE_TYPE = OffsetFetchResponse_v0 - SCHEMA = Schema( - ('consumer_group', String('utf-8')), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array(Int32)))) - ) - - -class OffsetFetchRequest_v1(Request): - API_KEY = 9 - API_VERSION = 1 # kafka-backed storage - RESPONSE_TYPE = OffsetFetchResponse_v1 - SCHEMA = OffsetFetchRequest_v0.SCHEMA - - -class OffsetFetchRequest_v2(Request): - # KIP-88: Allows passing null topics to return offsets for all partitions - # that the consumer group has a stored offset for, even if no consumer in - # the group is currently consuming that partition. - API_KEY = 9 - API_VERSION = 2 - RESPONSE_TYPE = OffsetFetchResponse_v2 - SCHEMA = OffsetFetchRequest_v1.SCHEMA - - -class OffsetFetchRequest_v3(Request): - API_KEY = 9 - API_VERSION = 3 - RESPONSE_TYPE = OffsetFetchResponse_v3 - SCHEMA = OffsetFetchRequest_v2.SCHEMA - - -OffsetFetchRequest = [ - OffsetFetchRequest_v0, OffsetFetchRequest_v1, - OffsetFetchRequest_v2, OffsetFetchRequest_v3, -] -OffsetFetchResponse = [ - OffsetFetchResponse_v0, OffsetFetchResponse_v1, - OffsetFetchResponse_v2, OffsetFetchResponse_v3, -] - - -class GroupCoordinatorResponse_v0(Response): - API_KEY = 10 - API_VERSION = 0 - SCHEMA = Schema( - ('error_code', Int16), - ('coordinator_id', Int32), - ('host', String('utf-8')), - ('port', Int32) - ) - - -class GroupCoordinatorResponse_v1(Response): - API_KEY = 10 - API_VERSION = 1 - SCHEMA = Schema( - ('error_code', Int16), - ('error_message', String('utf-8')), - ('coordinator_id', Int32), - ('host', String('utf-8')), - ('port', Int32) - ) - - -class GroupCoordinatorRequest_v0(Request): - API_KEY = 10 - API_VERSION = 0 - RESPONSE_TYPE = GroupCoordinatorResponse_v0 - SCHEMA = Schema( - ('consumer_group', String('utf-8')) - ) - - -class GroupCoordinatorRequest_v1(Request): - API_KEY = 10 - API_VERSION = 1 - RESPONSE_TYPE = GroupCoordinatorResponse_v1 - SCHEMA = Schema( - ('coordinator_key', String('utf-8')), - ('coordinator_type', Int8) - ) - - -GroupCoordinatorRequest = [GroupCoordinatorRequest_v0, GroupCoordinatorRequest_v1] -GroupCoordinatorResponse = [GroupCoordinatorResponse_v0, GroupCoordinatorResponse_v1] diff --git a/kafka/protocol/fetch.py b/kafka/protocol/fetch.py deleted file mode 100644 index f367848c..00000000 --- a/kafka/protocol/fetch.py +++ /dev/null @@ -1,386 +0,0 @@ -from __future__ import absolute_import - -from kafka.protocol.api import Request, Response -from kafka.protocol.types import Array, Int8, Int16, Int32, Int64, Schema, String, Bytes - - -class FetchResponse_v0(Response): - API_KEY = 1 - API_VERSION = 0 - SCHEMA = Schema( - ('topics', Array( - ('topics', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16), - ('highwater_offset', Int64), - ('message_set', Bytes))))) - ) - - -class FetchResponse_v1(Response): - API_KEY = 1 - API_VERSION = 1 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('topics', Array( - ('topics', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16), - ('highwater_offset', Int64), - ('message_set', Bytes))))) - ) - - -class FetchResponse_v2(Response): - API_KEY = 1 - API_VERSION = 2 - SCHEMA 
= FetchResponse_v1.SCHEMA # message format changed internally - - -class FetchResponse_v3(Response): - API_KEY = 1 - API_VERSION = 3 - SCHEMA = FetchResponse_v2.SCHEMA - - -class FetchResponse_v4(Response): - API_KEY = 1 - API_VERSION = 4 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('topics', Array( - ('topics', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16), - ('highwater_offset', Int64), - ('last_stable_offset', Int64), - ('aborted_transactions', Array( - ('producer_id', Int64), - ('first_offset', Int64))), - ('message_set', Bytes))))) - ) - - -class FetchResponse_v5(Response): - API_KEY = 1 - API_VERSION = 5 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('topics', Array( - ('topics', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16), - ('highwater_offset', Int64), - ('last_stable_offset', Int64), - ('log_start_offset', Int64), - ('aborted_transactions', Array( - ('producer_id', Int64), - ('first_offset', Int64))), - ('message_set', Bytes))))) - ) - - -class FetchResponse_v6(Response): - """ - Same as FetchResponse_v5. The version number is bumped up to indicate that the client supports KafkaStorageException. - The KafkaStorageException will be translated to NotLeaderForPartitionException in the response if version <= 5 - """ - API_KEY = 1 - API_VERSION = 6 - SCHEMA = FetchResponse_v5.SCHEMA - - -class FetchResponse_v7(Response): - """ - Add error_code and session_id to response - """ - API_KEY = 1 - API_VERSION = 7 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16), - ('session_id', Int32), - ('topics', Array( - ('topics', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16), - ('highwater_offset', Int64), - ('last_stable_offset', Int64), - ('log_start_offset', Int64), - ('aborted_transactions', Array( - ('producer_id', Int64), - ('first_offset', Int64))), - ('message_set', Bytes))))) - ) - - -class FetchResponse_v8(Response): - API_KEY = 1 - API_VERSION = 8 - SCHEMA = FetchResponse_v7.SCHEMA - - -class FetchResponse_v9(Response): - API_KEY = 1 - API_VERSION = 9 - SCHEMA = FetchResponse_v7.SCHEMA - - -class FetchResponse_v10(Response): - API_KEY = 1 - API_VERSION = 10 - SCHEMA = FetchResponse_v7.SCHEMA - - -class FetchResponse_v11(Response): - API_KEY = 1 - API_VERSION = 11 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16), - ('session_id', Int32), - ('topics', Array( - ('topics', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16), - ('highwater_offset', Int64), - ('last_stable_offset', Int64), - ('log_start_offset', Int64), - ('aborted_transactions', Array( - ('producer_id', Int64), - ('first_offset', Int64))), - ('preferred_read_replica', Int32), - ('message_set', Bytes))))) - ) - - -class FetchRequest_v0(Request): - API_KEY = 1 - API_VERSION = 0 - RESPONSE_TYPE = FetchResponse_v0 - SCHEMA = Schema( - ('replica_id', Int32), - ('max_wait_time', Int32), - ('min_bytes', Int32), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('offset', Int64), - ('max_bytes', Int32))))) - ) - - -class FetchRequest_v1(Request): - API_KEY = 1 - API_VERSION = 1 - RESPONSE_TYPE = FetchResponse_v1 - SCHEMA = FetchRequest_v0.SCHEMA - - -class FetchRequest_v2(Request): - API_KEY = 1 - API_VERSION = 2 - RESPONSE_TYPE = FetchResponse_v2 - SCHEMA = FetchRequest_v1.SCHEMA - - -class FetchRequest_v3(Request): - API_KEY = 1 - API_VERSION = 3 - 
RESPONSE_TYPE = FetchResponse_v3 - SCHEMA = Schema( - ('replica_id', Int32), - ('max_wait_time', Int32), - ('min_bytes', Int32), - ('max_bytes', Int32), # This new field is only difference from FR_v2 - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('offset', Int64), - ('max_bytes', Int32))))) - ) - - -class FetchRequest_v4(Request): - # Adds isolation_level field - API_KEY = 1 - API_VERSION = 4 - RESPONSE_TYPE = FetchResponse_v4 - SCHEMA = Schema( - ('replica_id', Int32), - ('max_wait_time', Int32), - ('min_bytes', Int32), - ('max_bytes', Int32), - ('isolation_level', Int8), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('offset', Int64), - ('max_bytes', Int32))))) - ) - - -class FetchRequest_v5(Request): - # This may only be used in broker-broker api calls - API_KEY = 1 - API_VERSION = 5 - RESPONSE_TYPE = FetchResponse_v5 - SCHEMA = Schema( - ('replica_id', Int32), - ('max_wait_time', Int32), - ('min_bytes', Int32), - ('max_bytes', Int32), - ('isolation_level', Int8), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('fetch_offset', Int64), - ('log_start_offset', Int64), - ('max_bytes', Int32))))) - ) - - -class FetchRequest_v6(Request): - """ - The body of FETCH_REQUEST_V6 is the same as FETCH_REQUEST_V5. - The version number is bumped up to indicate that the client supports KafkaStorageException. - The KafkaStorageException will be translated to NotLeaderForPartitionException in the response if version <= 5 - """ - API_KEY = 1 - API_VERSION = 6 - RESPONSE_TYPE = FetchResponse_v6 - SCHEMA = FetchRequest_v5.SCHEMA - - -class FetchRequest_v7(Request): - """ - Add incremental fetch requests - """ - API_KEY = 1 - API_VERSION = 7 - RESPONSE_TYPE = FetchResponse_v7 - SCHEMA = Schema( - ('replica_id', Int32), - ('max_wait_time', Int32), - ('min_bytes', Int32), - ('max_bytes', Int32), - ('isolation_level', Int8), - ('session_id', Int32), - ('session_epoch', Int32), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('fetch_offset', Int64), - ('log_start_offset', Int64), - ('max_bytes', Int32))))), - ('forgotten_topics_data', Array( - ('topic', String), - ('partitions', Array(Int32)) - )), - ) - - -class FetchRequest_v8(Request): - """ - bump used to indicate that on quota violation brokers send out responses before throttling. - """ - API_KEY = 1 - API_VERSION = 8 - RESPONSE_TYPE = FetchResponse_v8 - SCHEMA = FetchRequest_v7.SCHEMA - - -class FetchRequest_v9(Request): - """ - adds the current leader epoch (see KIP-320) - """ - API_KEY = 1 - API_VERSION = 9 - RESPONSE_TYPE = FetchResponse_v9 - SCHEMA = Schema( - ('replica_id', Int32), - ('max_wait_time', Int32), - ('min_bytes', Int32), - ('max_bytes', Int32), - ('isolation_level', Int8), - ('session_id', Int32), - ('session_epoch', Int32), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('current_leader_epoch', Int32), - ('fetch_offset', Int64), - ('log_start_offset', Int64), - ('max_bytes', Int32))))), - ('forgotten_topics_data', Array( - ('topic', String), - ('partitions', Array(Int32)), - )), - ) - - -class FetchRequest_v10(Request): - """ - bumped up to indicate ZStandard capability. 
(see KIP-110) - """ - API_KEY = 1 - API_VERSION = 10 - RESPONSE_TYPE = FetchResponse_v10 - SCHEMA = FetchRequest_v9.SCHEMA - - -class FetchRequest_v11(Request): - """ - added rack ID to support read from followers (KIP-392) - """ - API_KEY = 1 - API_VERSION = 11 - RESPONSE_TYPE = FetchResponse_v11 - SCHEMA = Schema( - ('replica_id', Int32), - ('max_wait_time', Int32), - ('min_bytes', Int32), - ('max_bytes', Int32), - ('isolation_level', Int8), - ('session_id', Int32), - ('session_epoch', Int32), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('current_leader_epoch', Int32), - ('fetch_offset', Int64), - ('log_start_offset', Int64), - ('max_bytes', Int32))))), - ('forgotten_topics_data', Array( - ('topic', String), - ('partitions', Array(Int32)) - )), - ('rack_id', String('utf-8')), - ) - - -FetchRequest = [ - FetchRequest_v0, FetchRequest_v1, FetchRequest_v2, - FetchRequest_v3, FetchRequest_v4, FetchRequest_v5, - FetchRequest_v6, FetchRequest_v7, FetchRequest_v8, - FetchRequest_v9, FetchRequest_v10, FetchRequest_v11, -] -FetchResponse = [ - FetchResponse_v0, FetchResponse_v1, FetchResponse_v2, - FetchResponse_v3, FetchResponse_v4, FetchResponse_v5, - FetchResponse_v6, FetchResponse_v7, FetchResponse_v8, - FetchResponse_v9, FetchResponse_v10, FetchResponse_v11, -] diff --git a/kafka/protocol/group.py b/kafka/protocol/group.py deleted file mode 100644 index bcb96553..00000000 --- a/kafka/protocol/group.py +++ /dev/null @@ -1,230 +0,0 @@ -from __future__ import absolute_import - -from kafka.protocol.api import Request, Response -from kafka.protocol.struct import Struct -from kafka.protocol.types import Array, Bytes, Int16, Int32, Schema, String - - -class JoinGroupResponse_v0(Response): - API_KEY = 11 - API_VERSION = 0 - SCHEMA = Schema( - ('error_code', Int16), - ('generation_id', Int32), - ('group_protocol', String('utf-8')), - ('leader_id', String('utf-8')), - ('member_id', String('utf-8')), - ('members', Array( - ('member_id', String('utf-8')), - ('member_metadata', Bytes))) - ) - - -class JoinGroupResponse_v1(Response): - API_KEY = 11 - API_VERSION = 1 - SCHEMA = JoinGroupResponse_v0.SCHEMA - - -class JoinGroupResponse_v2(Response): - API_KEY = 11 - API_VERSION = 2 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16), - ('generation_id', Int32), - ('group_protocol', String('utf-8')), - ('leader_id', String('utf-8')), - ('member_id', String('utf-8')), - ('members', Array( - ('member_id', String('utf-8')), - ('member_metadata', Bytes))) - ) - - -class JoinGroupRequest_v0(Request): - API_KEY = 11 - API_VERSION = 0 - RESPONSE_TYPE = JoinGroupResponse_v0 - SCHEMA = Schema( - ('group', String('utf-8')), - ('session_timeout', Int32), - ('member_id', String('utf-8')), - ('protocol_type', String('utf-8')), - ('group_protocols', Array( - ('protocol_name', String('utf-8')), - ('protocol_metadata', Bytes))) - ) - UNKNOWN_MEMBER_ID = '' - - -class JoinGroupRequest_v1(Request): - API_KEY = 11 - API_VERSION = 1 - RESPONSE_TYPE = JoinGroupResponse_v1 - SCHEMA = Schema( - ('group', String('utf-8')), - ('session_timeout', Int32), - ('rebalance_timeout', Int32), - ('member_id', String('utf-8')), - ('protocol_type', String('utf-8')), - ('group_protocols', Array( - ('protocol_name', String('utf-8')), - ('protocol_metadata', Bytes))) - ) - UNKNOWN_MEMBER_ID = '' - - -class JoinGroupRequest_v2(Request): - API_KEY = 11 - API_VERSION = 2 - RESPONSE_TYPE = JoinGroupResponse_v2 - SCHEMA = JoinGroupRequest_v1.SCHEMA - UNKNOWN_MEMBER_ID = '' - 
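For orientation, the join-group classes above follow the Struct convention used throughout this protocol package: each Schema field becomes a constructor keyword argument, and encode() serializes the struct. A minimal sketch of building and framing a v0 join request, using only the API visible in this patch (the group name, timeout, protocol name, and client id are hypothetical example values):

from aiokafka.protocol.api import RequestHeader
from aiokafka.protocol.group import JoinGroupRequest_v0

# Hypothetical example values; schema fields map 1:1 to keyword arguments.
req = JoinGroupRequest_v0(
    group="example-group",
    session_timeout=30000,                            # milliseconds
    member_id=JoinGroupRequest_v0.UNKNOWN_MEMBER_ID,  # '' on first join
    protocol_type="consumer",
    group_protocols=[["range", b""]],                 # [protocol_name, protocol_metadata]
)
header = RequestHeader(req, correlation_id=1, client_id="example-client")
header_bytes = header.encode()   # wire header for this request
body_bytes = req.encode()        # serialized request body
# The connection layer is then responsible for length-prefixing and sending
# header_bytes + body_bytes.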
- -JoinGroupRequest = [ - JoinGroupRequest_v0, JoinGroupRequest_v1, JoinGroupRequest_v2 -] -JoinGroupResponse = [ - JoinGroupResponse_v0, JoinGroupResponse_v1, JoinGroupResponse_v2 -] - - -class ProtocolMetadata(Struct): - SCHEMA = Schema( - ('version', Int16), - ('subscription', Array(String('utf-8'))), # topics list - ('user_data', Bytes) - ) - - -class SyncGroupResponse_v0(Response): - API_KEY = 14 - API_VERSION = 0 - SCHEMA = Schema( - ('error_code', Int16), - ('member_assignment', Bytes) - ) - - -class SyncGroupResponse_v1(Response): - API_KEY = 14 - API_VERSION = 1 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16), - ('member_assignment', Bytes) - ) - - -class SyncGroupRequest_v0(Request): - API_KEY = 14 - API_VERSION = 0 - RESPONSE_TYPE = SyncGroupResponse_v0 - SCHEMA = Schema( - ('group', String('utf-8')), - ('generation_id', Int32), - ('member_id', String('utf-8')), - ('group_assignment', Array( - ('member_id', String('utf-8')), - ('member_metadata', Bytes))) - ) - - -class SyncGroupRequest_v1(Request): - API_KEY = 14 - API_VERSION = 1 - RESPONSE_TYPE = SyncGroupResponse_v1 - SCHEMA = SyncGroupRequest_v0.SCHEMA - - -SyncGroupRequest = [SyncGroupRequest_v0, SyncGroupRequest_v1] -SyncGroupResponse = [SyncGroupResponse_v0, SyncGroupResponse_v1] - - -class MemberAssignment(Struct): - SCHEMA = Schema( - ('version', Int16), - ('assignment', Array( - ('topic', String('utf-8')), - ('partitions', Array(Int32)))), - ('user_data', Bytes) - ) - - -class HeartbeatResponse_v0(Response): - API_KEY = 12 - API_VERSION = 0 - SCHEMA = Schema( - ('error_code', Int16) - ) - - -class HeartbeatResponse_v1(Response): - API_KEY = 12 - API_VERSION = 1 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16) - ) - - -class HeartbeatRequest_v0(Request): - API_KEY = 12 - API_VERSION = 0 - RESPONSE_TYPE = HeartbeatResponse_v0 - SCHEMA = Schema( - ('group', String('utf-8')), - ('generation_id', Int32), - ('member_id', String('utf-8')) - ) - - -class HeartbeatRequest_v1(Request): - API_KEY = 12 - API_VERSION = 1 - RESPONSE_TYPE = HeartbeatResponse_v1 - SCHEMA = HeartbeatRequest_v0.SCHEMA - - -HeartbeatRequest = [HeartbeatRequest_v0, HeartbeatRequest_v1] -HeartbeatResponse = [HeartbeatResponse_v0, HeartbeatResponse_v1] - - -class LeaveGroupResponse_v0(Response): - API_KEY = 13 - API_VERSION = 0 - SCHEMA = Schema( - ('error_code', Int16) - ) - - -class LeaveGroupResponse_v1(Response): - API_KEY = 13 - API_VERSION = 1 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('error_code', Int16) - ) - - -class LeaveGroupRequest_v0(Request): - API_KEY = 13 - API_VERSION = 0 - RESPONSE_TYPE = LeaveGroupResponse_v0 - SCHEMA = Schema( - ('group', String('utf-8')), - ('member_id', String('utf-8')) - ) - - -class LeaveGroupRequest_v1(Request): - API_KEY = 13 - API_VERSION = 1 - RESPONSE_TYPE = LeaveGroupResponse_v1 - SCHEMA = LeaveGroupRequest_v0.SCHEMA - - -LeaveGroupRequest = [LeaveGroupRequest_v0, LeaveGroupRequest_v1] -LeaveGroupResponse = [LeaveGroupResponse_v0, LeaveGroupResponse_v1] diff --git a/kafka/protocol/metadata.py b/kafka/protocol/metadata.py deleted file mode 100644 index 414e5b84..00000000 --- a/kafka/protocol/metadata.py +++ /dev/null @@ -1,200 +0,0 @@ -from __future__ import absolute_import - -from kafka.protocol.api import Request, Response -from kafka.protocol.types import Array, Boolean, Int16, Int32, Schema, String - - -class MetadataResponse_v0(Response): - API_KEY = 3 - API_VERSION = 0 - SCHEMA = Schema( - ('brokers', Array( - ('node_id', Int32), - ('host', 
String('utf-8')), - ('port', Int32))), - ('topics', Array( - ('error_code', Int16), - ('topic', String('utf-8')), - ('partitions', Array( - ('error_code', Int16), - ('partition', Int32), - ('leader', Int32), - ('replicas', Array(Int32)), - ('isr', Array(Int32)))))) - ) - - -class MetadataResponse_v1(Response): - API_KEY = 3 - API_VERSION = 1 - SCHEMA = Schema( - ('brokers', Array( - ('node_id', Int32), - ('host', String('utf-8')), - ('port', Int32), - ('rack', String('utf-8')))), - ('controller_id', Int32), - ('topics', Array( - ('error_code', Int16), - ('topic', String('utf-8')), - ('is_internal', Boolean), - ('partitions', Array( - ('error_code', Int16), - ('partition', Int32), - ('leader', Int32), - ('replicas', Array(Int32)), - ('isr', Array(Int32)))))) - ) - - -class MetadataResponse_v2(Response): - API_KEY = 3 - API_VERSION = 2 - SCHEMA = Schema( - ('brokers', Array( - ('node_id', Int32), - ('host', String('utf-8')), - ('port', Int32), - ('rack', String('utf-8')))), - ('cluster_id', String('utf-8')), # <-- Added cluster_id field in v2 - ('controller_id', Int32), - ('topics', Array( - ('error_code', Int16), - ('topic', String('utf-8')), - ('is_internal', Boolean), - ('partitions', Array( - ('error_code', Int16), - ('partition', Int32), - ('leader', Int32), - ('replicas', Array(Int32)), - ('isr', Array(Int32)))))) - ) - - -class MetadataResponse_v3(Response): - API_KEY = 3 - API_VERSION = 3 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('brokers', Array( - ('node_id', Int32), - ('host', String('utf-8')), - ('port', Int32), - ('rack', String('utf-8')))), - ('cluster_id', String('utf-8')), - ('controller_id', Int32), - ('topics', Array( - ('error_code', Int16), - ('topic', String('utf-8')), - ('is_internal', Boolean), - ('partitions', Array( - ('error_code', Int16), - ('partition', Int32), - ('leader', Int32), - ('replicas', Array(Int32)), - ('isr', Array(Int32)))))) - ) - - -class MetadataResponse_v4(Response): - API_KEY = 3 - API_VERSION = 4 - SCHEMA = MetadataResponse_v3.SCHEMA - - -class MetadataResponse_v5(Response): - API_KEY = 3 - API_VERSION = 5 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('brokers', Array( - ('node_id', Int32), - ('host', String('utf-8')), - ('port', Int32), - ('rack', String('utf-8')))), - ('cluster_id', String('utf-8')), - ('controller_id', Int32), - ('topics', Array( - ('error_code', Int16), - ('topic', String('utf-8')), - ('is_internal', Boolean), - ('partitions', Array( - ('error_code', Int16), - ('partition', Int32), - ('leader', Int32), - ('replicas', Array(Int32)), - ('isr', Array(Int32)), - ('offline_replicas', Array(Int32)))))) - ) - - -class MetadataRequest_v0(Request): - API_KEY = 3 - API_VERSION = 0 - RESPONSE_TYPE = MetadataResponse_v0 - SCHEMA = Schema( - ('topics', Array(String('utf-8'))) - ) - ALL_TOPICS = None # Empty Array (len 0) for topics returns all topics - - -class MetadataRequest_v1(Request): - API_KEY = 3 - API_VERSION = 1 - RESPONSE_TYPE = MetadataResponse_v1 - SCHEMA = MetadataRequest_v0.SCHEMA - ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics - NO_TOPICS = None # Empty array (len 0) for topics returns no topics - - -class MetadataRequest_v2(Request): - API_KEY = 3 - API_VERSION = 2 - RESPONSE_TYPE = MetadataResponse_v2 - SCHEMA = MetadataRequest_v1.SCHEMA - ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics - NO_TOPICS = None # Empty array (len 0) for topics returns no topics - - -class MetadataRequest_v3(Request): - API_KEY = 3 - API_VERSION = 3 - RESPONSE_TYPE = 
MetadataResponse_v3 - SCHEMA = MetadataRequest_v1.SCHEMA - ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics - NO_TOPICS = None # Empty array (len 0) for topics returns no topics - - -class MetadataRequest_v4(Request): - API_KEY = 3 - API_VERSION = 4 - RESPONSE_TYPE = MetadataResponse_v4 - SCHEMA = Schema( - ('topics', Array(String('utf-8'))), - ('allow_auto_topic_creation', Boolean) - ) - ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics - NO_TOPICS = None # Empty array (len 0) for topics returns no topics - - -class MetadataRequest_v5(Request): - """ - The v5 metadata request is the same as v4. - An additional field for offline_replicas has been added to the v5 metadata response - """ - API_KEY = 3 - API_VERSION = 5 - RESPONSE_TYPE = MetadataResponse_v5 - SCHEMA = MetadataRequest_v4.SCHEMA - ALL_TOPICS = -1 # Null Array (len -1) for topics returns all topics - NO_TOPICS = None # Empty array (len 0) for topics returns no topics - - -MetadataRequest = [ - MetadataRequest_v0, MetadataRequest_v1, MetadataRequest_v2, - MetadataRequest_v3, MetadataRequest_v4, MetadataRequest_v5 -] -MetadataResponse = [ - MetadataResponse_v0, MetadataResponse_v1, MetadataResponse_v2, - MetadataResponse_v3, MetadataResponse_v4, MetadataResponse_v5 -] diff --git a/kafka/protocol/offset.py b/kafka/protocol/offset.py deleted file mode 100644 index 1ed382b0..00000000 --- a/kafka/protocol/offset.py +++ /dev/null @@ -1,194 +0,0 @@ -from __future__ import absolute_import - -from kafka.protocol.api import Request, Response -from kafka.protocol.types import Array, Int8, Int16, Int32, Int64, Schema, String - -UNKNOWN_OFFSET = -1 - - -class OffsetResetStrategy(object): - LATEST = -1 - EARLIEST = -2 - NONE = 0 - - -class OffsetResponse_v0(Response): - API_KEY = 2 - API_VERSION = 0 - SCHEMA = Schema( - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16), - ('offsets', Array(Int64)))))) - ) - -class OffsetResponse_v1(Response): - API_KEY = 2 - API_VERSION = 1 - SCHEMA = Schema( - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16), - ('timestamp', Int64), - ('offset', Int64))))) - ) - - -class OffsetResponse_v2(Response): - API_KEY = 2 - API_VERSION = 2 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16), - ('timestamp', Int64), - ('offset', Int64))))) - ) - - -class OffsetResponse_v3(Response): - """ - on quota violation, brokers send out responses before throttling - """ - API_KEY = 2 - API_VERSION = 3 - SCHEMA = OffsetResponse_v2.SCHEMA - - -class OffsetResponse_v4(Response): - """ - Add leader_epoch to response - """ - API_KEY = 2 - API_VERSION = 4 - SCHEMA = Schema( - ('throttle_time_ms', Int32), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16), - ('timestamp', Int64), - ('offset', Int64), - ('leader_epoch', Int32))))) - ) - - -class OffsetResponse_v5(Response): - """ - adds a new error code, OFFSET_NOT_AVAILABLE - """ - API_KEY = 2 - API_VERSION = 5 - SCHEMA = OffsetResponse_v4.SCHEMA - - -class OffsetRequest_v0(Request): - API_KEY = 2 - API_VERSION = 0 - RESPONSE_TYPE = OffsetResponse_v0 - SCHEMA = Schema( - ('replica_id', Int32), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('timestamp', Int64), - 
('max_offsets', Int32))))) - ) - DEFAULTS = { - 'replica_id': -1 - } - -class OffsetRequest_v1(Request): - API_KEY = 2 - API_VERSION = 1 - RESPONSE_TYPE = OffsetResponse_v1 - SCHEMA = Schema( - ('replica_id', Int32), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('timestamp', Int64))))) - ) - DEFAULTS = { - 'replica_id': -1 - } - - -class OffsetRequest_v2(Request): - API_KEY = 2 - API_VERSION = 2 - RESPONSE_TYPE = OffsetResponse_v2 - SCHEMA = Schema( - ('replica_id', Int32), - ('isolation_level', Int8), # <- added isolation_level - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('timestamp', Int64))))) - ) - DEFAULTS = { - 'replica_id': -1 - } - - -class OffsetRequest_v3(Request): - API_KEY = 2 - API_VERSION = 3 - RESPONSE_TYPE = OffsetResponse_v3 - SCHEMA = OffsetRequest_v2.SCHEMA - DEFAULTS = { - 'replica_id': -1 - } - - -class OffsetRequest_v4(Request): - """ - Add current_leader_epoch to request - """ - API_KEY = 2 - API_VERSION = 4 - RESPONSE_TYPE = OffsetResponse_v4 - SCHEMA = Schema( - ('replica_id', Int32), - ('isolation_level', Int8), # <- added isolation_level - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('current_leader_epoch', Int64), - ('timestamp', Int64))))) - ) - DEFAULTS = { - 'replica_id': -1 - } - - -class OffsetRequest_v5(Request): - API_KEY = 2 - API_VERSION = 5 - RESPONSE_TYPE = OffsetResponse_v5 - SCHEMA = OffsetRequest_v4.SCHEMA - DEFAULTS = { - 'replica_id': -1 - } - - -OffsetRequest = [ - OffsetRequest_v0, OffsetRequest_v1, OffsetRequest_v2, - OffsetRequest_v3, OffsetRequest_v4, OffsetRequest_v5, -] -OffsetResponse = [ - OffsetResponse_v0, OffsetResponse_v1, OffsetResponse_v2, - OffsetResponse_v3, OffsetResponse_v4, OffsetResponse_v5, -] diff --git a/kafka/protocol/produce.py b/kafka/protocol/produce.py deleted file mode 100644 index 9b3f6bf5..00000000 --- a/kafka/protocol/produce.py +++ /dev/null @@ -1,232 +0,0 @@ -from __future__ import absolute_import - -from kafka.protocol.api import Request, Response -from kafka.protocol.types import Int16, Int32, Int64, String, Array, Schema, Bytes - - -class ProduceResponse_v0(Response): - API_KEY = 0 - API_VERSION = 0 - SCHEMA = Schema( - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16), - ('offset', Int64))))) - ) - - -class ProduceResponse_v1(Response): - API_KEY = 0 - API_VERSION = 1 - SCHEMA = Schema( - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16), - ('offset', Int64))))), - ('throttle_time_ms', Int32) - ) - - -class ProduceResponse_v2(Response): - API_KEY = 0 - API_VERSION = 2 - SCHEMA = Schema( - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16), - ('offset', Int64), - ('timestamp', Int64))))), - ('throttle_time_ms', Int32) - ) - - -class ProduceResponse_v3(Response): - API_KEY = 0 - API_VERSION = 3 - SCHEMA = ProduceResponse_v2.SCHEMA - - -class ProduceResponse_v4(Response): - """ - The version number is bumped up to indicate that the client supports KafkaStorageException. 
- The KafkaStorageException will be translated to NotLeaderForPartitionException in the response if version <= 3 - """ - API_KEY = 0 - API_VERSION = 4 - SCHEMA = ProduceResponse_v3.SCHEMA - - -class ProduceResponse_v5(Response): - API_KEY = 0 - API_VERSION = 5 - SCHEMA = Schema( - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16), - ('offset', Int64), - ('timestamp', Int64), - ('log_start_offset', Int64))))), - ('throttle_time_ms', Int32) - ) - - -class ProduceResponse_v6(Response): - """ - The version number is bumped to indicate that on quota violation brokers send out responses before throttling. - """ - API_KEY = 0 - API_VERSION = 6 - SCHEMA = ProduceResponse_v5.SCHEMA - - -class ProduceResponse_v7(Response): - """ - V7 bumped up to indicate ZStandard capability. (see KIP-110) - """ - API_KEY = 0 - API_VERSION = 7 - SCHEMA = ProduceResponse_v6.SCHEMA - - -class ProduceResponse_v8(Response): - """ - V8 bumped up to add two new fields record_errors offset list and error_message - (See KIP-467) - """ - API_KEY = 0 - API_VERSION = 8 - SCHEMA = Schema( - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('error_code', Int16), - ('offset', Int64), - ('timestamp', Int64), - ('log_start_offset', Int64)), - ('record_errors', (Array( - ('batch_index', Int32), - ('batch_index_error_message', String('utf-8')) - ))), - ('error_message', String('utf-8')) - ))), - ('throttle_time_ms', Int32) - ) - - -class ProduceRequest(Request): - API_KEY = 0 - - def expect_response(self): - if self.required_acks == 0: # pylint: disable=no-member - return False - return True - - -class ProduceRequest_v0(ProduceRequest): - API_VERSION = 0 - RESPONSE_TYPE = ProduceResponse_v0 - SCHEMA = Schema( - ('required_acks', Int16), - ('timeout', Int32), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('messages', Bytes))))) - ) - - -class ProduceRequest_v1(ProduceRequest): - API_VERSION = 1 - RESPONSE_TYPE = ProduceResponse_v1 - SCHEMA = ProduceRequest_v0.SCHEMA - - -class ProduceRequest_v2(ProduceRequest): - API_VERSION = 2 - RESPONSE_TYPE = ProduceResponse_v2 - SCHEMA = ProduceRequest_v1.SCHEMA - - -class ProduceRequest_v3(ProduceRequest): - API_VERSION = 3 - RESPONSE_TYPE = ProduceResponse_v3 - SCHEMA = Schema( - ('transactional_id', String('utf-8')), - ('required_acks', Int16), - ('timeout', Int32), - ('topics', Array( - ('topic', String('utf-8')), - ('partitions', Array( - ('partition', Int32), - ('messages', Bytes))))) - ) - - -class ProduceRequest_v4(ProduceRequest): - """ - The version number is bumped up to indicate that the client supports KafkaStorageException. - The KafkaStorageException will be translated to NotLeaderForPartitionException in the response if version <= 3 - """ - API_VERSION = 4 - RESPONSE_TYPE = ProduceResponse_v4 - SCHEMA = ProduceRequest_v3.SCHEMA - - -class ProduceRequest_v5(ProduceRequest): - """ - Same as v4. The version number is bumped since the v5 response includes an additional - partition level field: the log_start_offset. - """ - API_VERSION = 5 - RESPONSE_TYPE = ProduceResponse_v5 - SCHEMA = ProduceRequest_v4.SCHEMA - - -class ProduceRequest_v6(ProduceRequest): - """ - The version number is bumped to indicate that on quota violation brokers send out responses before throttling. 
- """ - API_VERSION = 6 - RESPONSE_TYPE = ProduceResponse_v6 - SCHEMA = ProduceRequest_v5.SCHEMA - - -class ProduceRequest_v7(ProduceRequest): - """ - V7 bumped up to indicate ZStandard capability. (see KIP-110) - """ - API_VERSION = 7 - RESPONSE_TYPE = ProduceResponse_v7 - SCHEMA = ProduceRequest_v6.SCHEMA - - -class ProduceRequest_v8(ProduceRequest): - """ - V8 bumped up to add two new fields record_errors offset list and error_message to PartitionResponse - (See KIP-467) - """ - API_VERSION = 8 - RESPONSE_TYPE = ProduceResponse_v8 - SCHEMA = ProduceRequest_v7.SCHEMA - - -ProduceRequest = [ - ProduceRequest_v0, ProduceRequest_v1, ProduceRequest_v2, - ProduceRequest_v3, ProduceRequest_v4, ProduceRequest_v5, - ProduceRequest_v6, ProduceRequest_v7, ProduceRequest_v8, -] -ProduceResponse = [ - ProduceResponse_v0, ProduceResponse_v1, ProduceResponse_v2, - ProduceResponse_v3, ProduceResponse_v4, ProduceResponse_v5, - ProduceResponse_v6, ProduceResponse_v7, ProduceResponse_v8, -] diff --git a/kafka/version.py b/kafka/version.py deleted file mode 100644 index 06306bd1..00000000 --- a/kafka/version.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = '2.0.3-dev' diff --git a/tests/kafka/fixtures.py b/tests/kafka/fixtures.py index 76bde28f..b6854e54 100644 --- a/tests/kafka/fixtures.py +++ b/tests/kafka/fixtures.py @@ -15,8 +15,8 @@ from aiokafka import errors from aiokafka.errors import InvalidReplicationFactorError -from kafka.protocol.admin import CreateTopicsRequest -from kafka.protocol.metadata import MetadataRequest +from aiokafka.protocol.admin import CreateTopicsRequest +from aiokafka.protocol.metadata import MetadataRequest from tests.kafka.testutil import env_kafka_version, random_string from tests.kafka.service import ExternalService, SpawnedService diff --git a/tests/kafka/test_api_object_implementation.py b/tests/kafka/test_api_object_implementation.py deleted file mode 100644 index da80f148..00000000 --- a/tests/kafka/test_api_object_implementation.py +++ /dev/null @@ -1,18 +0,0 @@ -import abc -import pytest - -from kafka.protocol.api import Request -from kafka.protocol.api import Response - - -attr_names = [n for n in dir(Request) if isinstance(getattr(Request, n), abc.abstractproperty)] -@pytest.mark.parametrize('klass', Request.__subclasses__()) -@pytest.mark.parametrize('attr_name', attr_names) -def test_request_type_conformance(klass, attr_name): - assert hasattr(klass, attr_name) - -attr_names = [n for n in dir(Response) if isinstance(getattr(Response, n), abc.abstractproperty)] -@pytest.mark.parametrize('klass', Response.__subclasses__()) -@pytest.mark.parametrize('attr_name', attr_names) -def test_response_type_conformance(klass, attr_name): - assert hasattr(klass, attr_name) diff --git a/tests/kafka/test_object_conversion.py b/tests/kafka/test_object_conversion.py deleted file mode 100644 index 9b1ff213..00000000 --- a/tests/kafka/test_object_conversion.py +++ /dev/null @@ -1,236 +0,0 @@ -from kafka.protocol.admin import Request -from kafka.protocol.admin import Response -from kafka.protocol.types import Schema -from kafka.protocol.types import Array -from kafka.protocol.types import Int16 -from kafka.protocol.types import String - -import pytest - -@pytest.mark.parametrize('superclass', (Request, Response)) -class TestObjectConversion: - def test_get_item(self, superclass): - class TestClass(superclass): - API_KEY = 0 - API_VERSION = 0 - RESPONSE_TYPE = None # To satisfy the Request ABC - SCHEMA = Schema( - ('myobject', Int16)) - - tc = TestClass(myobject=0) - assert 
tc.get_item('myobject') == 0 - with pytest.raises(KeyError): - tc.get_item('does-not-exist') - - def test_with_empty_schema(self, superclass): - class TestClass(superclass): - API_KEY = 0 - API_VERSION = 0 - RESPONSE_TYPE = None # To satisfy the Request ABC - SCHEMA = Schema() - - tc = TestClass() - tc.encode() - assert tc.to_object() == {} - - def test_with_basic_schema(self, superclass): - class TestClass(superclass): - API_KEY = 0 - API_VERSION = 0 - RESPONSE_TYPE = None # To satisfy the Request ABC - SCHEMA = Schema( - ('myobject', Int16)) - - tc = TestClass(myobject=0) - tc.encode() - assert tc.to_object() == {'myobject': 0} - - def test_with_basic_array_schema(self, superclass): - class TestClass(superclass): - API_KEY = 0 - API_VERSION = 0 - RESPONSE_TYPE = None # To satisfy the Request ABC - SCHEMA = Schema( - ('myarray', Array(Int16))) - - tc = TestClass(myarray=[1,2,3]) - tc.encode() - assert tc.to_object()['myarray'] == [1, 2, 3] - - def test_with_complex_array_schema(self, superclass): - class TestClass(superclass): - API_KEY = 0 - API_VERSION = 0 - RESPONSE_TYPE = None # To satisfy the Request ABC - SCHEMA = Schema( - ('myarray', Array( - ('subobject', Int16), - ('othersubobject', String('utf-8'))))) - - tc = TestClass( - myarray=[[10, 'hello']] - ) - tc.encode() - obj = tc.to_object() - assert len(obj['myarray']) == 1 - assert obj['myarray'][0]['subobject'] == 10 - assert obj['myarray'][0]['othersubobject'] == 'hello' - - def test_with_array_and_other(self, superclass): - class TestClass(superclass): - API_KEY = 0 - API_VERSION = 0 - RESPONSE_TYPE = None # To satisfy the Request ABC - SCHEMA = Schema( - ('myarray', Array( - ('subobject', Int16), - ('othersubobject', String('utf-8')))), - ('notarray', Int16)) - - tc = TestClass( - myarray=[[10, 'hello']], - notarray=42 - ) - - obj = tc.to_object() - assert len(obj['myarray']) == 1 - assert obj['myarray'][0]['subobject'] == 10 - assert obj['myarray'][0]['othersubobject'] == 'hello' - assert obj['notarray'] == 42 - - def test_with_nested_array(self, superclass): - class TestClass(superclass): - API_KEY = 0 - API_VERSION = 0 - RESPONSE_TYPE = None # To satisfy the Request ABC - SCHEMA = Schema( - ('myarray', Array( - ('subarray', Array(Int16)), - ('otherobject', Int16)))) - - tc = TestClass( - myarray=[ - [[1, 2], 2], - [[2, 3], 4], - ] - ) - print(tc.encode()) - - - obj = tc.to_object() - assert len(obj['myarray']) == 2 - assert obj['myarray'][0]['subarray'] == [1, 2] - assert obj['myarray'][0]['otherobject'] == 2 - assert obj['myarray'][1]['subarray'] == [2, 3] - assert obj['myarray'][1]['otherobject'] == 4 - - def test_with_complex_nested_array(self, superclass): - class TestClass(superclass): - API_KEY = 0 - API_VERSION = 0 - RESPONSE_TYPE = None # To satisfy the Request ABC - SCHEMA = Schema( - ('myarray', Array( - ('subarray', Array( - ('innertest', String('utf-8')), - ('otherinnertest', String('utf-8')))), - ('othersubarray', Array(Int16)))), - ('notarray', String('utf-8'))) - - tc = TestClass( - myarray=[ - [[['hello', 'hello'], ['hello again', 'hello again']], [0]], - [[['hello', 'hello again']], [1]], - ], - notarray='notarray' - ) - tc.encode() - - obj = tc.to_object() - - assert obj['notarray'] == 'notarray' - myarray = obj['myarray'] - assert len(myarray) == 2 - - assert myarray[0]['othersubarray'] == [0] - assert len(myarray[0]['subarray']) == 2 - assert myarray[0]['subarray'][0]['innertest'] == 'hello' - assert myarray[0]['subarray'][0]['otherinnertest'] == 'hello' - assert myarray[0]['subarray'][1]['innertest'] == 
'hello again' - assert myarray[0]['subarray'][1]['otherinnertest'] == 'hello again' - - assert myarray[1]['othersubarray'] == [1] - assert len(myarray[1]['subarray']) == 1 - assert myarray[1]['subarray'][0]['innertest'] == 'hello' - assert myarray[1]['subarray'][0]['otherinnertest'] == 'hello again' - -def test_with_metadata_response(): - from kafka.protocol.metadata import MetadataResponse_v5 - tc = MetadataResponse_v5( - throttle_time_ms=0, - brokers=[ - [0, 'testhost0', 9092, 'testrack0'], - [1, 'testhost1', 9092, 'testrack1'], - ], - cluster_id='abcd', - controller_id=0, - topics=[ - [0, 'testtopic1', False, [ - [0, 0, 0, [0, 1], [0, 1], []], - [0, 1, 1, [1, 0], [1, 0], []], - ], - ], [0, 'other-test-topic', True, [ - [0, 0, 0, [0, 1], [0, 1], []], - ] - ]] - ) - tc.encode() # Make sure this object encodes successfully - - - obj = tc.to_object() - - assert obj['throttle_time_ms'] == 0 - - assert len(obj['brokers']) == 2 - assert obj['brokers'][0]['node_id'] == 0 - assert obj['brokers'][0]['host'] == 'testhost0' - assert obj['brokers'][0]['port'] == 9092 - assert obj['brokers'][0]['rack'] == 'testrack0' - assert obj['brokers'][1]['node_id'] == 1 - assert obj['brokers'][1]['host'] == 'testhost1' - assert obj['brokers'][1]['port'] == 9092 - assert obj['brokers'][1]['rack'] == 'testrack1' - - assert obj['cluster_id'] == 'abcd' - assert obj['controller_id'] == 0 - - assert len(obj['topics']) == 2 - assert obj['topics'][0]['error_code'] == 0 - assert obj['topics'][0]['topic'] == 'testtopic1' - assert obj['topics'][0]['is_internal'] == False - assert len(obj['topics'][0]['partitions']) == 2 - assert obj['topics'][0]['partitions'][0]['error_code'] == 0 - assert obj['topics'][0]['partitions'][0]['partition'] == 0 - assert obj['topics'][0]['partitions'][0]['leader'] == 0 - assert obj['topics'][0]['partitions'][0]['replicas'] == [0, 1] - assert obj['topics'][0]['partitions'][0]['isr'] == [0, 1] - assert obj['topics'][0]['partitions'][0]['offline_replicas'] == [] - assert obj['topics'][0]['partitions'][1]['error_code'] == 0 - assert obj['topics'][0]['partitions'][1]['partition'] == 1 - assert obj['topics'][0]['partitions'][1]['leader'] == 1 - assert obj['topics'][0]['partitions'][1]['replicas'] == [1, 0] - assert obj['topics'][0]['partitions'][1]['isr'] == [1, 0] - assert obj['topics'][0]['partitions'][1]['offline_replicas'] == [] - - assert obj['topics'][1]['error_code'] == 0 - assert obj['topics'][1]['topic'] == 'other-test-topic' - assert obj['topics'][1]['is_internal'] == True - assert len(obj['topics'][1]['partitions']) == 1 - assert obj['topics'][1]['partitions'][0]['error_code'] == 0 - assert obj['topics'][1]['partitions'][0]['partition'] == 0 - assert obj['topics'][1]['partitions'][0]['leader'] == 0 - assert obj['topics'][1]['partitions'][0]['replicas'] == [0, 1] - assert obj['topics'][1]['partitions'][0]['isr'] == [0, 1] - assert obj['topics'][1]['partitions'][0]['offline_replicas'] == [] - - tc.encode() diff --git a/tests/kafka/test_protocol.py b/tests/kafka/test_protocol.py deleted file mode 100644 index 6a77e19d..00000000 --- a/tests/kafka/test_protocol.py +++ /dev/null @@ -1,336 +0,0 @@ -#pylint: skip-file -import io -import struct - -import pytest - -from kafka.protocol.api import RequestHeader -from kafka.protocol.commit import GroupCoordinatorRequest -from kafka.protocol.fetch import FetchRequest, FetchResponse -from kafka.protocol.message import Message, MessageSet, PartialMessage -from kafka.protocol.metadata import MetadataRequest -from kafka.protocol.types import Int16, Int32, 
Int64, String, UnsignedVarInt32, CompactString, CompactArray, CompactBytes - - -def test_create_message(): - payload = b'test' - key = b'key' - msg = Message(payload, key=key) - assert msg.magic == 0 - assert msg.attributes == 0 - assert msg.key == key - assert msg.value == payload - - -def test_encode_message_v0(): - message = Message(b'test', key=b'key') - encoded = message.encode() - expect = b''.join([ - struct.pack('>i', -1427009701), # CRC - struct.pack('>bb', 0, 0), # Magic, flags - struct.pack('>i', 3), # Length of key - b'key', # key - struct.pack('>i', 4), # Length of value - b'test', # value - ]) - assert encoded == expect - - -def test_encode_message_v1(): - message = Message(b'test', key=b'key', magic=1, timestamp=1234) - encoded = message.encode() - expect = b''.join([ - struct.pack('>i', 1331087195), # CRC - struct.pack('>bb', 1, 0), # Magic, flags - struct.pack('>q', 1234), # Timestamp - struct.pack('>i', 3), # Length of key - b'key', # key - struct.pack('>i', 4), # Length of value - b'test', # value - ]) - assert encoded == expect - - -def test_decode_message(): - encoded = b''.join([ - struct.pack('>i', -1427009701), # CRC - struct.pack('>bb', 0, 0), # Magic, flags - struct.pack('>i', 3), # Length of key - b'key', # key - struct.pack('>i', 4), # Length of value - b'test', # value - ]) - decoded_message = Message.decode(encoded) - msg = Message(b'test', key=b'key') - msg.encode() # crc is recalculated during encoding - assert decoded_message == msg - - -def test_decode_message_validate_crc(): - encoded = b''.join([ - struct.pack('>i', -1427009701), # CRC - struct.pack('>bb', 0, 0), # Magic, flags - struct.pack('>i', 3), # Length of key - b'key', # key - struct.pack('>i', 4), # Length of value - b'test', # value - ]) - decoded_message = Message.decode(encoded) - assert decoded_message.validate_crc() is True - - encoded = b''.join([ - struct.pack('>i', 1234), # Incorrect CRC - struct.pack('>bb', 0, 0), # Magic, flags - struct.pack('>i', 3), # Length of key - b'key', # key - struct.pack('>i', 4), # Length of value - b'test', # value - ]) - decoded_message = Message.decode(encoded) - assert decoded_message.validate_crc() is False - - -def test_encode_message_set(): - messages = [ - Message(b'v1', key=b'k1'), - Message(b'v2', key=b'k2') - ] - encoded = MessageSet.encode([(0, msg.encode()) - for msg in messages]) - expect = b''.join([ - struct.pack('>q', 0), # MsgSet Offset - struct.pack('>i', 18), # Msg Size - struct.pack('>i', 1474775406), # CRC - struct.pack('>bb', 0, 0), # Magic, flags - struct.pack('>i', 2), # Length of key - b'k1', # Key - struct.pack('>i', 2), # Length of value - b'v1', # Value - - struct.pack('>q', 0), # MsgSet Offset - struct.pack('>i', 18), # Msg Size - struct.pack('>i', -16383415), # CRC - struct.pack('>bb', 0, 0), # Magic, flags - struct.pack('>i', 2), # Length of key - b'k2', # Key - struct.pack('>i', 2), # Length of value - b'v2', # Value - ]) - expect = struct.pack('>i', len(expect)) + expect - assert encoded == expect - - -def test_decode_message_set(): - encoded = b''.join([ - struct.pack('>q', 0), # MsgSet Offset - struct.pack('>i', 18), # Msg Size - struct.pack('>i', 1474775406), # CRC - struct.pack('>bb', 0, 0), # Magic, flags - struct.pack('>i', 2), # Length of key - b'k1', # Key - struct.pack('>i', 2), # Length of value - b'v1', # Value - - struct.pack('>q', 1), # MsgSet Offset - struct.pack('>i', 18), # Msg Size - struct.pack('>i', -16383415), # CRC - struct.pack('>bb', 0, 0), # Magic, flags - struct.pack('>i', 2), # Length of key - b'k2', 
# Key - struct.pack('>i', 2), # Length of value - b'v2', # Value - ]) - - msgs = MessageSet.decode(encoded, bytes_to_read=len(encoded)) - assert len(msgs) == 2 - msg1, msg2 = msgs - - returned_offset1, message1_size, decoded_message1 = msg1 - returned_offset2, message2_size, decoded_message2 = msg2 - - assert returned_offset1 == 0 - message1 = Message(b'v1', key=b'k1') - message1.encode() - assert decoded_message1 == message1 - - assert returned_offset2 == 1 - message2 = Message(b'v2', key=b'k2') - message2.encode() - assert decoded_message2 == message2 - - -def test_encode_message_header(): - expect = b''.join([ - struct.pack('>h', 10), # API Key - struct.pack('>h', 0), # API Version - struct.pack('>i', 4), # Correlation Id - struct.pack('>h', len('client3')), # Length of clientId - b'client3', # ClientId - ]) - - req = GroupCoordinatorRequest[0]('foo') - header = RequestHeader(req, correlation_id=4, client_id='client3') - assert header.encode() == expect - - -def test_decode_message_set_partial(): - encoded = b''.join([ - struct.pack('>q', 0), # Msg Offset - struct.pack('>i', 18), # Msg Size - struct.pack('>i', 1474775406), # CRC - struct.pack('>bb', 0, 0), # Magic, flags - struct.pack('>i', 2), # Length of key - b'k1', # Key - struct.pack('>i', 2), # Length of value - b'v1', # Value - - struct.pack('>q', 1), # Msg Offset - struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size) - struct.pack('>i', -16383415), # CRC - struct.pack('>bb', 0, 0), # Magic, flags - struct.pack('>i', 2), # Length of key - b'k2', # Key - struct.pack('>i', 8), # Length of value - b'ar', # Value (truncated) - ]) - - msgs = MessageSet.decode(encoded, bytes_to_read=len(encoded)) - assert len(msgs) == 2 - msg1, msg2 = msgs - - returned_offset1, message1_size, decoded_message1 = msg1 - returned_offset2, message2_size, decoded_message2 = msg2 - - assert returned_offset1 == 0 - message1 = Message(b'v1', key=b'k1') - message1.encode() - assert decoded_message1 == message1 - - assert returned_offset2 is None - assert message2_size is None - assert decoded_message2 == PartialMessage() - - -def test_decode_fetch_response_partial(): - encoded = b''.join([ - Int32.encode(1), # Num Topics (Array) - String('utf-8').encode('foobar'), - Int32.encode(2), # Num Partitions (Array) - Int32.encode(0), # Partition id - Int16.encode(0), # Error Code - Int64.encode(1234), # Highwater offset - Int32.encode(52), # MessageSet size - Int64.encode(0), # Msg Offset - Int32.encode(18), # Msg Size - struct.pack('>i', 1474775406), # CRC - struct.pack('>bb', 0, 0), # Magic, flags - struct.pack('>i', 2), # Length of key - b'k1', # Key - struct.pack('>i', 2), # Length of value - b'v1', # Value - - Int64.encode(1), # Msg Offset - struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size) - struct.pack('>i', -16383415), # CRC - struct.pack('>bb', 0, 0), # Magic, flags - struct.pack('>i', 2), # Length of key - b'k2', # Key - struct.pack('>i', 8), # Length of value - b'ar', # Value (truncated) - Int32.encode(1), - Int16.encode(0), - Int64.encode(2345), - Int32.encode(52), # MessageSet size - Int64.encode(0), # Msg Offset - Int32.encode(18), # Msg Size - struct.pack('>i', 1474775406), # CRC - struct.pack('>bb', 0, 0), # Magic, flags - struct.pack('>i', 2), # Length of key - b'k1', # Key - struct.pack('>i', 2), # Length of value - b'v1', # Value - - Int64.encode(1), # Msg Offset - struct.pack('>i', 24), # Msg Size (larger than remaining MsgSet size) - struct.pack('>i', -16383415), # CRC - struct.pack('>bb', 0, 0), # Magic, flags 
- struct.pack('>i', 2), # Length of key - b'k2', # Key - struct.pack('>i', 8), # Length of value - b'ar', # Value (truncated) - ]) - resp = FetchResponse[0].decode(io.BytesIO(encoded)) - assert len(resp.topics) == 1 - topic, partitions = resp.topics[0] - assert topic == 'foobar' - assert len(partitions) == 2 - - m1 = MessageSet.decode( - partitions[0][3], bytes_to_read=len(partitions[0][3])) - assert len(m1) == 2 - assert m1[1] == (None, None, PartialMessage()) - - -def test_struct_unrecognized_kwargs(): - try: - mr = MetadataRequest[0](topicz='foo') - assert False, 'Structs should not allow unrecognized kwargs' - except ValueError: - pass - - -def test_struct_missing_kwargs(): - fr = FetchRequest[0](max_wait_time=100) - assert fr.min_bytes is None - - -def test_unsigned_varint_serde(): - pairs = { - 0: [0], - -1: [0xff, 0xff, 0xff, 0xff, 0x0f], - 1: [1], - 63: [0x3f], - -64: [0xc0, 0xff, 0xff, 0xff, 0x0f], - 64: [0x40], - 8191: [0xff, 0x3f], - -8192: [0x80, 0xc0, 0xff, 0xff, 0x0f], - 8192: [0x80, 0x40], - -8193: [0xff, 0xbf, 0xff, 0xff, 0x0f], - 1048575: [0xff, 0xff, 0x3f], - - } - for value, expected_encoded in pairs.items(): - value &= 0xffffffff - encoded = UnsignedVarInt32.encode(value) - assert encoded == b''.join(struct.pack('>B', x) for x in expected_encoded) - assert value == UnsignedVarInt32.decode(io.BytesIO(encoded)) - - -def test_compact_data_structs(): - cs = CompactString() - encoded = cs.encode(None) - assert encoded == struct.pack('B', 0) - decoded = cs.decode(io.BytesIO(encoded)) - assert decoded is None - assert b'\x01' == cs.encode('') - assert '' == cs.decode(io.BytesIO(b'\x01')) - encoded = cs.encode("foobarbaz") - assert cs.decode(io.BytesIO(encoded)) == "foobarbaz" - - arr = CompactArray(CompactString()) - assert arr.encode(None) == b'\x00' - assert arr.decode(io.BytesIO(b'\x00')) is None - enc = arr.encode([]) - assert enc == b'\x01' - assert [] == arr.decode(io.BytesIO(enc)) - encoded = arr.encode(["foo", "bar", "baz", "quux"]) - assert arr.decode(io.BytesIO(encoded)) == ["foo", "bar", "baz", "quux"] - - enc = CompactBytes.encode(None) - assert enc == b'\x00' - assert CompactBytes.decode(io.BytesIO(b'\x00')) is None - enc = CompactBytes.encode(b'') - assert enc == b'\x01' - assert CompactBytes.decode(io.BytesIO(b'\x01')) is b'' - enc = CompactBytes.encode(b'foo') - assert CompactBytes.decode(io.BytesIO(enc)) == b'foo' diff --git a/tests/test_client.py b/tests/test_client.py index e9ceb517..f17ca673 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,11 +4,6 @@ from typing import Any from unittest import mock -from kafka.protocol.metadata import ( - MetadataRequest_v0 as MetadataRequest, - MetadataResponse_v0 as MetadataResponse) -from kafka.protocol.fetch import FetchRequest_v0 - from aiokafka import __version__ from aiokafka.client import AIOKafkaClient, ConnectionGroup, CoordinationType from aiokafka.conn import AIOKafkaConnection, CloseReason @@ -16,6 +11,10 @@ KafkaError, KafkaConnectionError, RequestTimedOutError, NodeNotReadyError, UnrecognizedBrokerVersion ) +from aiokafka.protocol.metadata import ( + MetadataRequest_v0 as MetadataRequest, + MetadataResponse_v0 as MetadataResponse) +from aiokafka.protocol.fetch import FetchRequest_v0 from aiokafka.util import create_task, get_running_loop from ._testutil import ( KafkaIntegrationTestCase, run_until_complete, kafka_versions diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 0fad6e31..84ff1549 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -1,6 +1,5 @@ -from 
kafka.protocol.metadata import MetadataResponse - from aiokafka.cluster import ClusterMetadata +from aiokafka.protocol.metadata import MetadataResponse def test_empty_broker_list(): diff --git a/tests/test_conn.py b/tests/test_conn.py index 3862c11c..a1cedc92 100644 --- a/tests/test_conn.py +++ b/tests/test_conn.py @@ -5,23 +5,22 @@ from typing import Any from unittest import mock -from kafka.protocol.metadata import ( +from aiokafka.conn import AIOKafkaConnection, create_conn, VersionInfo +from aiokafka.errors import ( + KafkaConnectionError, CorrelationIdError, KafkaError, NoError, + UnknownError, UnsupportedSaslMechanismError, IllegalSaslStateError +) +from aiokafka.protocol.metadata import ( MetadataRequest_v0 as MetadataRequest, MetadataResponse_v0 as MetadataResponse) -from kafka.protocol.commit import ( +from aiokafka.protocol.commit import ( GroupCoordinatorRequest_v0 as GroupCoordinatorRequest, GroupCoordinatorResponse_v0 as GroupCoordinatorResponse) -from kafka.protocol.admin import ( +from aiokafka.protocol.admin import ( SaslHandShakeRequest, SaslHandShakeResponse, SaslAuthenticateRequest, SaslAuthenticateResponse ) -from kafka.protocol.produce import ProduceRequest_v0 as ProduceRequest - -from aiokafka.conn import AIOKafkaConnection, create_conn, VersionInfo -from aiokafka.errors import ( - KafkaConnectionError, CorrelationIdError, KafkaError, NoError, - UnknownError, UnsupportedSaslMechanismError, IllegalSaslStateError -) +from aiokafka.protocol.produce import ProduceRequest_v0 as ProduceRequest from aiokafka.record.legacy_records import LegacyRecordBatchBuilder from ._testutil import KafkaIntegrationTestCase, run_until_complete from aiokafka.util import get_running_loop diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index 59b26de9..38bf12d9 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -2,28 +2,27 @@ import re from unittest import mock -from kafka.protocol.group import ( +from aiokafka import ConsumerRebalanceListener +from aiokafka.client import AIOKafkaClient +import aiokafka.errors as Errors +from aiokafka.structs import OffsetAndMetadata, TopicPartition +from aiokafka.consumer.group_coordinator import ( + GroupCoordinator, CoordinatorGroupRebalance, NoGroupCoordinator) +from aiokafka.consumer.subscription_state import SubscriptionState +from aiokafka.protocol.group import ( JoinGroupRequest_v0 as JoinGroupRequest, SyncGroupResponse_v0 as SyncGroupResponse, LeaveGroupRequest_v0 as LeaveGroupRequest, HeartbeatRequest_v0 as HeartbeatRequest, ) -from kafka.protocol.commit import ( +from aiokafka.protocol.commit import ( OffsetCommitRequest, OffsetCommitResponse_v2, OffsetFetchRequest_v1 as OffsetFetchRequest ) +from aiokafka.util import create_future, create_task, get_running_loop from ._testutil import KafkaIntegrationTestCase, run_until_complete -from aiokafka import ConsumerRebalanceListener -from aiokafka.client import AIOKafkaClient -import aiokafka.errors as Errors -from aiokafka.structs import OffsetAndMetadata, TopicPartition -from aiokafka.consumer.group_coordinator import ( - GroupCoordinator, CoordinatorGroupRebalance, NoGroupCoordinator) -from aiokafka.consumer.subscription_state import SubscriptionState -from aiokafka.util import create_future, create_task, get_running_loop - UNKNOWN_MEMBER_ID = JoinGroupRequest.UNKNOWN_MEMBER_ID diff --git a/tests/test_fetcher.py b/tests/test_fetcher.py index c6d8e23f..aff4d69c 100644 --- a/tests/test_fetcher.py +++ b/tests/test_fetcher.py @@ -3,30 +3,30 @@ import unittest from 
unittest import mock -from kafka.protocol.fetch import ( +from aiokafka.client import AIOKafkaClient +from aiokafka.consumer.fetcher import ( + Fetcher, FetchResult, FetchError, ConsumerRecord, OffsetResetStrategy, + PartitionRecords, READ_COMMITTED, READ_UNCOMMITTED +) +from aiokafka.consumer.subscription_state import SubscriptionState +from aiokafka.errors import ( + TopicAuthorizationFailedError, UnknownError, UnknownTopicOrPartitionError, + OffsetOutOfRangeError, KafkaTimeoutError, NotLeaderForPartitionError +) +from aiokafka.protocol.fetch import ( FetchRequest_v0 as FetchRequest, FetchResponse_v0 as FetchResponse ) -from kafka.protocol.offset import OffsetResponse +from aiokafka.protocol.offset import OffsetResponse from aiokafka.record.legacy_records import LegacyRecordBatchBuilder from aiokafka.record.default_records import ( # NB: test_solitary_abort_marker relies on implementation details _DefaultRecordBatchBuilderPy as DefaultRecordBatchBuilder) from aiokafka.record.memory_records import MemoryRecords - -from aiokafka.errors import ( - TopicAuthorizationFailedError, UnknownError, UnknownTopicOrPartitionError, - OffsetOutOfRangeError, KafkaTimeoutError, NotLeaderForPartitionError -) from aiokafka.structs import ( TopicPartition, OffsetAndTimestamp, OffsetAndMetadata ) -from aiokafka.client import AIOKafkaClient -from aiokafka.consumer.fetcher import ( - Fetcher, FetchResult, FetchError, ConsumerRecord, OffsetResetStrategy, - PartitionRecords, READ_COMMITTED, READ_UNCOMMITTED -) -from aiokafka.consumer.subscription_state import SubscriptionState from aiokafka.util import create_future, create_task, get_running_loop + from ._testutil import run_until_complete diff --git a/tests/test_producer.py b/tests/test_producer.py index 9cc6858b..a6004a4e 100644 --- a/tests/test_producer.py +++ b/tests/test_producer.py @@ -6,9 +6,6 @@ import weakref from unittest import mock -from kafka.protocol.produce import ProduceResponse - -from aiokafka.producer import AIOKafkaProducer from aiokafka.client import AIOKafkaClient from aiokafka.cluster import ClusterMetadata from aiokafka.consumer import AIOKafkaConsumer @@ -17,6 +14,8 @@ MessageSizeTooLargeError, NotLeaderForPartitionError, LeaderNotAvailableError, RequestTimedOutError, UnsupportedVersionError, ProducerClosed, KafkaError) +from aiokafka.producer import AIOKafkaProducer +from aiokafka.protocol.produce import ProduceResponse from aiokafka.util import create_future from ._testutil import ( diff --git a/tests/test_protocol.py b/tests/test_protocol.py new file mode 100644 index 00000000..240ca356 --- /dev/null +++ b/tests/test_protocol.py @@ -0,0 +1,376 @@ +import abc +import io +import struct + +import pytest + +from aiokafka.protocol.api import RequestHeader, Request, Response +from aiokafka.protocol.commit import GroupCoordinatorRequest +from aiokafka.protocol.fetch import FetchRequest, FetchResponse +from aiokafka.protocol.message import Message, MessageSet, PartialMessage +from aiokafka.protocol.metadata import MetadataRequest +from aiokafka.protocol.types import ( + Int16, + Int32, + Int64, + String, + UnsignedVarInt32, + CompactString, + CompactArray, + CompactBytes, +) + + +def test_create_message(): + payload = b"test" + key = b"key" + msg = Message(payload, key=key) + assert msg.magic == 0 + assert msg.attributes == 0 + assert msg.key == key + assert msg.value == payload + + +def test_encode_message_v0(): + message = Message(b"test", key=b"key") + encoded = message.encode() + expect = b"".join( + [ + struct.pack(">i", -1427009701), # 
CRC + struct.pack(">bb", 0, 0), # Magic, flags + struct.pack(">i", 3), # Length of key + b"key", # key + struct.pack(">i", 4), # Length of value + b"test", # value + ] + ) + assert encoded == expect + + +def test_encode_message_v1(): + message = Message(b"test", key=b"key", magic=1, timestamp=1234) + encoded = message.encode() + expect = b"".join( + [ + struct.pack(">i", 1331087195), # CRC + struct.pack(">bb", 1, 0), # Magic, flags + struct.pack(">q", 1234), # Timestamp + struct.pack(">i", 3), # Length of key + b"key", # key + struct.pack(">i", 4), # Length of value + b"test", # value + ] + ) + assert encoded == expect + + +def test_decode_message(): + encoded = b"".join( + [ + struct.pack(">i", -1427009701), # CRC + struct.pack(">bb", 0, 0), # Magic, flags + struct.pack(">i", 3), # Length of key + b"key", # key + struct.pack(">i", 4), # Length of value + b"test", # value + ] + ) + decoded_message = Message.decode(encoded) + msg = Message(b"test", key=b"key") + msg.encode() # crc is recalculated during encoding + assert decoded_message == msg + + +def test_decode_message_validate_crc(): + encoded = b"".join( + [ + struct.pack(">i", -1427009701), # CRC + struct.pack(">bb", 0, 0), # Magic, flags + struct.pack(">i", 3), # Length of key + b"key", # key + struct.pack(">i", 4), # Length of value + b"test", # value + ] + ) + decoded_message = Message.decode(encoded) + assert decoded_message.validate_crc() is True + + encoded = b"".join( + [ + struct.pack(">i", 1234), # Incorrect CRC + struct.pack(">bb", 0, 0), # Magic, flags + struct.pack(">i", 3), # Length of key + b"key", # key + struct.pack(">i", 4), # Length of value + b"test", # value + ] + ) + decoded_message = Message.decode(encoded) + assert decoded_message.validate_crc() is False + + +def test_encode_message_set(): + messages = [Message(b"v1", key=b"k1"), Message(b"v2", key=b"k2")] + encoded = MessageSet.encode([(0, msg.encode()) for msg in messages]) + expect = b"".join( + [ + struct.pack(">q", 0), # MsgSet Offset + struct.pack(">i", 18), # Msg Size + struct.pack(">i", 1474775406), # CRC + struct.pack(">bb", 0, 0), # Magic, flags + struct.pack(">i", 2), # Length of key + b"k1", # Key + struct.pack(">i", 2), # Length of value + b"v1", # Value + struct.pack(">q", 0), # MsgSet Offset + struct.pack(">i", 18), # Msg Size + struct.pack(">i", -16383415), # CRC + struct.pack(">bb", 0, 0), # Magic, flags + struct.pack(">i", 2), # Length of key + b"k2", # Key + struct.pack(">i", 2), # Length of value + b"v2", # Value + ] + ) + expect = struct.pack(">i", len(expect)) + expect + assert encoded == expect + + +def test_decode_message_set(): + encoded = b"".join( + [ + struct.pack(">q", 0), # MsgSet Offset + struct.pack(">i", 18), # Msg Size + struct.pack(">i", 1474775406), # CRC + struct.pack(">bb", 0, 0), # Magic, flags + struct.pack(">i", 2), # Length of key + b"k1", # Key + struct.pack(">i", 2), # Length of value + b"v1", # Value + struct.pack(">q", 1), # MsgSet Offset + struct.pack(">i", 18), # Msg Size + struct.pack(">i", -16383415), # CRC + struct.pack(">bb", 0, 0), # Magic, flags + struct.pack(">i", 2), # Length of key + b"k2", # Key + struct.pack(">i", 2), # Length of value + b"v2", # Value + ] + ) + + msgs = MessageSet.decode(encoded, bytes_to_read=len(encoded)) + assert len(msgs) == 2 + msg1, msg2 = msgs + + returned_offset1, message1_size, decoded_message1 = msg1 + returned_offset2, message2_size, decoded_message2 = msg2 + + assert returned_offset1 == 0 + message1 = Message(b"v1", key=b"k1") + message1.encode() + assert decoded_message1 == 
message1 + + assert returned_offset2 == 1 + message2 = Message(b"v2", key=b"k2") + message2.encode() + assert decoded_message2 == message2 + + +def test_encode_message_header(): + expect = b"".join( + [ + struct.pack(">h", 10), # API Key + struct.pack(">h", 0), # API Version + struct.pack(">i", 4), # Correlation Id + struct.pack(">h", len("client3")), # Length of clientId + b"client3", # ClientId + ] + ) + + req = GroupCoordinatorRequest[0]("foo") + header = RequestHeader(req, correlation_id=4, client_id="client3") + assert header.encode() == expect + + +def test_decode_message_set_partial(): + encoded = b"".join( + [ + struct.pack(">q", 0), # Msg Offset + struct.pack(">i", 18), # Msg Size + struct.pack(">i", 1474775406), # CRC + struct.pack(">bb", 0, 0), # Magic, flags + struct.pack(">i", 2), # Length of key + b"k1", # Key + struct.pack(">i", 2), # Length of value + b"v1", # Value + struct.pack(">q", 1), # Msg Offset + struct.pack(">i", 24), # Msg Size (larger than remaining MsgSet size) + struct.pack(">i", -16383415), # CRC + struct.pack(">bb", 0, 0), # Magic, flags + struct.pack(">i", 2), # Length of key + b"k2", # Key + struct.pack(">i", 8), # Length of value + b"ar", # Value (truncated) + ] + ) + + msgs = MessageSet.decode(encoded, bytes_to_read=len(encoded)) + assert len(msgs) == 2 + msg1, msg2 = msgs + + returned_offset1, message1_size, decoded_message1 = msg1 + returned_offset2, message2_size, decoded_message2 = msg2 + + assert returned_offset1 == 0 + message1 = Message(b"v1", key=b"k1") + message1.encode() + assert decoded_message1 == message1 + + assert returned_offset2 is None + assert message2_size is None + assert decoded_message2 == PartialMessage() + + +def test_decode_fetch_response_partial(): + encoded = b"".join( + [ + Int32.encode(1), # Num Topics (Array) + String("utf-8").encode("foobar"), + Int32.encode(2), # Num Partitions (Array) + Int32.encode(0), # Partition id + Int16.encode(0), # Error Code + Int64.encode(1234), # Highwater offset + Int32.encode(52), # MessageSet size + Int64.encode(0), # Msg Offset + Int32.encode(18), # Msg Size + struct.pack(">i", 1474775406), # CRC + struct.pack(">bb", 0, 0), # Magic, flags + struct.pack(">i", 2), # Length of key + b"k1", # Key + struct.pack(">i", 2), # Length of value + b"v1", # Value + Int64.encode(1), # Msg Offset + struct.pack(">i", 24), # Msg Size (larger than remaining MsgSet size) + struct.pack(">i", -16383415), # CRC + struct.pack(">bb", 0, 0), # Magic, flags + struct.pack(">i", 2), # Length of key + b"k2", # Key + struct.pack(">i", 8), # Length of value + b"ar", # Value (truncated) + Int32.encode(1), + Int16.encode(0), + Int64.encode(2345), + Int32.encode(52), # MessageSet size + Int64.encode(0), # Msg Offset + Int32.encode(18), # Msg Size + struct.pack(">i", 1474775406), # CRC + struct.pack(">bb", 0, 0), # Magic, flags + struct.pack(">i", 2), # Length of key + b"k1", # Key + struct.pack(">i", 2), # Length of value + b"v1", # Value + Int64.encode(1), # Msg Offset + struct.pack(">i", 24), # Msg Size (larger than remaining MsgSet size) + struct.pack(">i", -16383415), # CRC + struct.pack(">bb", 0, 0), # Magic, flags + struct.pack(">i", 2), # Length of key + b"k2", # Key + struct.pack(">i", 8), # Length of value + b"ar", # Value (truncated) + ] + ) + resp = FetchResponse[0].decode(io.BytesIO(encoded)) + assert len(resp.topics) == 1 + topic, partitions = resp.topics[0] + assert topic == "foobar" + assert len(partitions) == 2 + + m1 = MessageSet.decode(partitions[0][3], bytes_to_read=len(partitions[0][3])) + assert len(m1) 
== 2
+    assert m1[1] == (None, None, PartialMessage())
+
+
+def test_struct_unrecognized_kwargs():
+    try:
+        MetadataRequest[0](topicz="foo")
+        assert False, "Structs should not allow unrecognized kwargs"
+    except ValueError:
+        pass
+
+
+def test_struct_missing_kwargs():
+    fr = FetchRequest[0](max_wait_time=100)
+    assert fr.min_bytes is None
+
+
+def test_unsigned_varint_serde():
+    pairs = {
+        0: [0],
+        -1: [0xFF, 0xFF, 0xFF, 0xFF, 0x0F],
+        1: [1],
+        63: [0x3F],
+        -64: [0xC0, 0xFF, 0xFF, 0xFF, 0x0F],
+        64: [0x40],
+        8191: [0xFF, 0x3F],
+        -8192: [0x80, 0xC0, 0xFF, 0xFF, 0x0F],
+        8192: [0x80, 0x40],
+        -8193: [0xFF, 0xBF, 0xFF, 0xFF, 0x0F],
+        1048575: [0xFF, 0xFF, 0x3F],
+    }
+    for value, expected_encoded in pairs.items():
+        value &= 0xFFFFFFFF
+        encoded = UnsignedVarInt32.encode(value)
+        assert encoded == b"".join(struct.pack(">B", x) for x in expected_encoded)
+        assert value == UnsignedVarInt32.decode(io.BytesIO(encoded))
+
+
+def test_compact_data_structs():
+    cs = CompactString()
+    encoded = cs.encode(None)
+    assert encoded == struct.pack("B", 0)
+    decoded = cs.decode(io.BytesIO(encoded))
+    assert decoded is None
+    assert b"\x01" == cs.encode("")
+    assert "" == cs.decode(io.BytesIO(b"\x01"))
+    encoded = cs.encode("foobarbaz")
+    assert cs.decode(io.BytesIO(encoded)) == "foobarbaz"
+
+    arr = CompactArray(CompactString())
+    assert arr.encode(None) == b"\x00"
+    assert arr.decode(io.BytesIO(b"\x00")) is None
+    enc = arr.encode([])
+    assert enc == b"\x01"
+    assert [] == arr.decode(io.BytesIO(enc))
+    encoded = arr.encode(["foo", "bar", "baz", "quux"])
+    assert arr.decode(io.BytesIO(encoded)) == ["foo", "bar", "baz", "quux"]
+
+    enc = CompactBytes.encode(None)
+    assert enc == b"\x00"
+    assert CompactBytes.decode(io.BytesIO(b"\x00")) is None
+    enc = CompactBytes.encode(b"")
+    assert enc == b"\x01"
+    assert CompactBytes.decode(io.BytesIO(b"\x01")) == b""
+    enc = CompactBytes.encode(b"foo")
+    assert CompactBytes.decode(io.BytesIO(enc)) == b"foo"
+
+
+attr_names = [
+    n for n in dir(Request) if isinstance(getattr(Request, n), abc.abstractproperty)
+]
+
+
+@pytest.mark.parametrize("klass", Request.__subclasses__())
+@pytest.mark.parametrize("attr_name", attr_names)
+def test_request_type_conformance(klass, attr_name):
+    assert hasattr(klass, attr_name)
+
+
+attr_names = [
+    n for n in dir(Response) if isinstance(getattr(Response, n), abc.abstractproperty)
+]
+
+
+@pytest.mark.parametrize("klass", Response.__subclasses__())
+@pytest.mark.parametrize("attr_name", attr_names)
+def test_response_type_conformance(klass, attr_name):
+    assert hasattr(klass, attr_name)
diff --git a/tests/test_protocol_object_conversion.py b/tests/test_protocol_object_conversion.py
new file mode 100644
index 00000000..5a5317cf
--- /dev/null
+++ b/tests/test_protocol_object_conversion.py
@@ -0,0 +1,251 @@
+from aiokafka.protocol.admin import Request
+from aiokafka.protocol.admin import Response
+from aiokafka.protocol.types import Schema
+from aiokafka.protocol.types import Array
+from aiokafka.protocol.types import Int16
+from aiokafka.protocol.types import String
+
+import pytest
+
+
+@pytest.mark.parametrize("superclass", (Request, Response))
+class TestObjectConversion:
+    def test_get_item(self, superclass):
+        class TestClass(superclass):
+            API_KEY = 0
+            API_VERSION = 0
+            RESPONSE_TYPE = None  # To satisfy the Request ABC
+            SCHEMA = Schema(("myobject", Int16))
+
+        tc = TestClass(myobject=0)
+        assert tc.get_item("myobject") == 0
+        with pytest.raises(KeyError):
+            tc.get_item("does-not-exist")
+
+    def test_with_empty_schema(self, superclass):
+        class TestClass(superclass):
+            API_KEY = 0
+            API_VERSION = 0
+            RESPONSE_TYPE = None  # To satisfy the Request ABC
+            SCHEMA = Schema()
+
+        tc = TestClass()
+        tc.encode()
+        assert tc.to_object() == {}
+
+    def test_with_basic_schema(self, superclass):
+        class TestClass(superclass):
+            API_KEY = 0
+            API_VERSION = 0
+            RESPONSE_TYPE = None  # To satisfy the Request ABC
+            SCHEMA = Schema(("myobject", Int16))
+
+        tc = TestClass(myobject=0)
+        tc.encode()
+        assert tc.to_object() == {"myobject": 0}
+
+    def test_with_basic_array_schema(self, superclass):
+        class TestClass(superclass):
+            API_KEY = 0
+            API_VERSION = 0
+            RESPONSE_TYPE = None  # To satisfy the Request ABC
+            SCHEMA = Schema(("myarray", Array(Int16)))
+
+        tc = TestClass(myarray=[1, 2, 3])
+        tc.encode()
+        assert tc.to_object()["myarray"] == [1, 2, 3]
+
+    def test_with_complex_array_schema(self, superclass):
+        class TestClass(superclass):
+            API_KEY = 0
+            API_VERSION = 0
+            RESPONSE_TYPE = None  # To satisfy the Request ABC
+            SCHEMA = Schema(
+                (
+                    "myarray",
+                    Array(("subobject", Int16), ("othersubobject", String("utf-8"))),
+                )
+            )
+
+        tc = TestClass(myarray=[[10, "hello"]])
+        tc.encode()
+        obj = tc.to_object()
+        assert len(obj["myarray"]) == 1
+        assert obj["myarray"][0]["subobject"] == 10
+        assert obj["myarray"][0]["othersubobject"] == "hello"
+
+    def test_with_array_and_other(self, superclass):
+        class TestClass(superclass):
+            API_KEY = 0
+            API_VERSION = 0
+            RESPONSE_TYPE = None  # To satisfy the Request ABC
+            SCHEMA = Schema(
+                (
+                    "myarray",
+                    Array(("subobject", Int16), ("othersubobject", String("utf-8"))),
+                ),
+                ("notarray", Int16),
+            )
+
+        tc = TestClass(myarray=[[10, "hello"]], notarray=42)
+
+        obj = tc.to_object()
+        assert len(obj["myarray"]) == 1
+        assert obj["myarray"][0]["subobject"] == 10
+        assert obj["myarray"][0]["othersubobject"] == "hello"
+        assert obj["notarray"] == 42
+
+    def test_with_nested_array(self, superclass):
+        class TestClass(superclass):
+            API_KEY = 0
+            API_VERSION = 0
+            RESPONSE_TYPE = None  # To satisfy the Request ABC
+            SCHEMA = Schema(
+                ("myarray", Array(("subarray", Array(Int16)), ("otherobject", Int16)))
+            )
+
+        tc = TestClass(
+            myarray=[
+                [[1, 2], 2],
+                [[2, 3], 4],
+            ]
+        )
+        print(tc.encode())
+
+        obj = tc.to_object()
+        assert len(obj["myarray"]) == 2
+        assert obj["myarray"][0]["subarray"] == [1, 2]
+        assert obj["myarray"][0]["otherobject"] == 2
+        assert obj["myarray"][1]["subarray"] == [2, 3]
+        assert obj["myarray"][1]["otherobject"] == 4
+
+    def test_with_complex_nested_array(self, superclass):
+        class TestClass(superclass):
+            API_KEY = 0
+            API_VERSION = 0
+            RESPONSE_TYPE = None  # To satisfy the Request ABC
+            SCHEMA = Schema(
+                (
+                    "myarray",
+                    Array(
+                        (
+                            "subarray",
+                            Array(
+                                ("innertest", String("utf-8")),
+                                ("otherinnertest", String("utf-8")),
+                            ),
+                        ),
+                        ("othersubarray", Array(Int16)),
+                    ),
+                ),
+                ("notarray", String("utf-8")),
+            )
+
+        tc = TestClass(
+            myarray=[
+                [[["hello", "hello"], ["hello again", "hello again"]], [0]],
+                [[["hello", "hello again"]], [1]],
+            ],
+            notarray="notarray",
+        )
+        tc.encode()
+
+        obj = tc.to_object()
+
+        assert obj["notarray"] == "notarray"
+        myarray = obj["myarray"]
+        assert len(myarray) == 2
+
+        assert myarray[0]["othersubarray"] == [0]
+        assert len(myarray[0]["subarray"]) == 2
+        assert myarray[0]["subarray"][0]["innertest"] == "hello"
+        assert myarray[0]["subarray"][0]["otherinnertest"] == "hello"
+        assert myarray[0]["subarray"][1]["innertest"] == "hello again"
+        assert myarray[0]["subarray"][1]["otherinnertest"] == "hello again"
+
+        assert myarray[1]["othersubarray"] == [1]
+        assert len(myarray[1]["subarray"]) == 1
+        assert myarray[1]["subarray"][0]["innertest"] == "hello"
+        assert myarray[1]["subarray"][0]["otherinnertest"] == "hello again"
+
+
+def test_with_metadata_response():
+    from aiokafka.protocol.metadata import MetadataResponse_v5
+
+    tc = MetadataResponse_v5(
+        throttle_time_ms=0,
+        brokers=[
+            [0, "testhost0", 9092, "testrack0"],
+            [1, "testhost1", 9092, "testrack1"],
+        ],
+        cluster_id="abcd",
+        controller_id=0,
+        topics=[
+            [
+                0,
+                "testtopic1",
+                False,
+                [
+                    [0, 0, 0, [0, 1], [0, 1], []],
+                    [0, 1, 1, [1, 0], [1, 0], []],
+                ],
+            ],
+            [
+                0,
+                "other-test-topic",
+                True,
+                [
+                    [0, 0, 0, [0, 1], [0, 1], []],
+                ],
+            ],
+        ],
+    )
+    tc.encode()  # Make sure this object encodes successfully
+
+    obj = tc.to_object()
+
+    assert obj["throttle_time_ms"] == 0
+
+    assert len(obj["brokers"]) == 2
+    assert obj["brokers"][0]["node_id"] == 0
+    assert obj["brokers"][0]["host"] == "testhost0"
+    assert obj["brokers"][0]["port"] == 9092
+    assert obj["brokers"][0]["rack"] == "testrack0"
+    assert obj["brokers"][1]["node_id"] == 1
+    assert obj["brokers"][1]["host"] == "testhost1"
+    assert obj["brokers"][1]["port"] == 9092
+    assert obj["brokers"][1]["rack"] == "testrack1"
+
+    assert obj["cluster_id"] == "abcd"
+    assert obj["controller_id"] == 0
+
+    assert len(obj["topics"]) == 2
+    assert obj["topics"][0]["error_code"] == 0
+    assert obj["topics"][0]["topic"] == "testtopic1"
+    assert obj["topics"][0]["is_internal"] is False
+    assert len(obj["topics"][0]["partitions"]) == 2
+    assert obj["topics"][0]["partitions"][0]["error_code"] == 0
+    assert obj["topics"][0]["partitions"][0]["partition"] == 0
+    assert obj["topics"][0]["partitions"][0]["leader"] == 0
+    assert obj["topics"][0]["partitions"][0]["replicas"] == [0, 1]
+    assert obj["topics"][0]["partitions"][0]["isr"] == [0, 1]
+    assert obj["topics"][0]["partitions"][0]["offline_replicas"] == []
+    assert obj["topics"][0]["partitions"][1]["error_code"] == 0
+    assert obj["topics"][0]["partitions"][1]["partition"] == 1
+    assert obj["topics"][0]["partitions"][1]["leader"] == 1
+    assert obj["topics"][0]["partitions"][1]["replicas"] == [1, 0]
+    assert obj["topics"][0]["partitions"][1]["isr"] == [1, 0]
+    assert obj["topics"][0]["partitions"][1]["offline_replicas"] == []
+
+    assert obj["topics"][1]["error_code"] == 0
+    assert obj["topics"][1]["topic"] == "other-test-topic"
+    assert obj["topics"][1]["is_internal"] is True
+    assert len(obj["topics"][1]["partitions"]) == 1
+    assert obj["topics"][1]["partitions"][0]["error_code"] == 0
+    assert obj["topics"][1]["partitions"][0]["partition"] == 0
+    assert obj["topics"][1]["partitions"][0]["leader"] == 0
+    assert obj["topics"][1]["partitions"][0]["replicas"] == [0, 1]
+    assert obj["topics"][1]["partitions"][0]["isr"] == [0, 1]
+    assert obj["topics"][1]["partitions"][0]["offline_replicas"] == []
+
+    tc.encode()
diff --git a/tests/test_sender.py b/tests/test_sender.py
index 01fdb1d8..099965ee 100644
--- a/tests/test_sender.py
+++ b/tests/test_sender.py
@@ -1,11 +1,17 @@
 from unittest import mock
 
-from ._testutil import (
-    KafkaIntegrationTestCase, run_until_complete, kafka_versions
+from aiokafka.client import AIOKafkaClient, CoordinationType, ConnectionGroup
+from aiokafka.errors import (
+    NoError, UnknownError,
+    CoordinatorNotAvailableError, NotCoordinatorError,
+    CoordinatorLoadInProgressError, ConcurrentTransactions,
+    UnknownTopicOrPartitionError, InvalidProducerEpoch,
+    ProducerFenced, InvalidProducerIdMapping, InvalidTxnState,
+
RequestTimedOutError, DuplicateSequenceNumber, KafkaError, + TopicAuthorizationFailedError, OperationNotAttempted, + TransactionalIdAuthorizationFailed, GroupAuthorizationFailedError ) - -from kafka.protocol.produce import ProduceRequest, ProduceResponse - +from aiokafka.producer.message_accumulator import MessageAccumulator from aiokafka.producer.sender import ( Sender, InitPIDHandler, AddPartitionsToTxnHandler, AddOffsetsToTxnHandler, TxnOffsetCommitHandler, EndTxnHandler, @@ -14,6 +20,8 @@ from aiokafka.producer.transaction_manager import ( TransactionManager, TransactionState ) +from aiokafka.protocol.metadata import MetadataRequest +from aiokafka.protocol.produce import ProduceRequest, ProduceResponse from aiokafka.protocol.transaction import ( InitProducerIdRequest, InitProducerIdResponse, AddPartitionsToTxnRequest, AddPartitionsToTxnResponse, @@ -21,25 +29,13 @@ TxnOffsetCommitRequest, TxnOffsetCommitResponse, EndTxnRequest, EndTxnResponse ) -from aiokafka.producer.message_accumulator import MessageAccumulator -from aiokafka.client import AIOKafkaClient, CoordinationType, ConnectionGroup from aiokafka.structs import TopicPartition, OffsetAndMetadata from aiokafka.util import get_running_loop -from aiokafka.errors import ( - NoError, UnknownError, - CoordinatorNotAvailableError, NotCoordinatorError, - CoordinatorLoadInProgressError, ConcurrentTransactions, - UnknownTopicOrPartitionError, InvalidProducerEpoch, - ProducerFenced, InvalidProducerIdMapping, InvalidTxnState, - RequestTimedOutError, DuplicateSequenceNumber, KafkaError, - TopicAuthorizationFailedError, OperationNotAttempted, - TransactionalIdAuthorizationFailed, GroupAuthorizationFailedError +from ._testutil import ( + KafkaIntegrationTestCase, run_until_complete, kafka_versions ) -from kafka.protocol.metadata import MetadataRequest - - LOG_APPEND_TIME = 1 From 8c4cd4084e52bc3de955a2e20c6f3311f48b6e40 Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Mon, 23 Oct 2023 12:17:39 +0300 Subject: [PATCH 17/20] Move codec --- {kafka => aiokafka}/codec.py | 124 ++++++++-------- aiokafka/producer/producer.py | 3 +- aiokafka/protocol/message.py | 2 +- aiokafka/record/_crecords/default_records.pyx | 6 +- aiokafka/record/_crecords/legacy_records.pyx | 4 +- aiokafka/record/default_records.py | 7 +- aiokafka/record/legacy_records.py | 8 +- tests/conftest.py | 1 + tests/kafka/test_codec.py | 124 ---------------- tests/record/test_default_records.py | 4 +- tests/record/test_legacy.py | 4 +- tests/test_codec.py | 136 ++++++++++++++++++ 12 files changed, 214 insertions(+), 209 deletions(-) rename {kafka => aiokafka}/codec.py (67%) delete mode 100644 tests/kafka/test_codec.py create mode 100644 tests/test_codec.py diff --git a/kafka/codec.py b/aiokafka/codec.py similarity index 67% rename from kafka/codec.py rename to aiokafka/codec.py index c740a181..2e3ddaaf 100644 --- a/kafka/codec.py +++ b/aiokafka/codec.py @@ -1,15 +1,10 @@ -from __future__ import absolute_import - import gzip import io import platform import struct -from kafka.vendor import six -from kafka.vendor.six.moves import range - -_XERIAL_V1_HEADER = (-126, b'S', b'N', b'A', b'P', b'P', b'Y', 0, 1, 1) -_XERIAL_V1_FORMAT = 'bccccccBii' +_XERIAL_V1_HEADER = (-126, b"S", b"N", b"A", b"P", b"P", b"Y", 0, 1, 1) +_XERIAL_V1_FORMAT = "bccccccBii" ZSTD_MAX_OUTPUT_SIZE = 1024 * 1024 try: @@ -29,11 +24,11 @@ def _lz4_compress(payload, **kwargs): # Kafka does not support LZ4 dependent blocks try: # For lz4>=0.12.0 - kwargs.pop('block_linked', None) + kwargs.pop("block_linked", None) 
return lz4.compress(payload, block_linked=False, **kwargs) except TypeError: # For earlier versions of lz4 - kwargs.pop('block_mode', None) + kwargs.pop("block_mode", None) return lz4.compress(payload, block_mode=1, **kwargs) except ImportError: @@ -54,7 +49,8 @@ def _lz4_compress(payload, **kwargs): except ImportError: xxhash = None -PYPY = bool(platform.python_implementation() == 'PyPy') +PYPY = bool(platform.python_implementation() == "PyPy") + def has_gzip(): return True @@ -100,14 +96,14 @@ def gzip_decode(payload): # Gzip context manager introduced in python 2.7 # so old-fashioned way until we decide to not support 2.6 - gzipper = gzip.GzipFile(fileobj=buf, mode='r') + gzipper = gzip.GzipFile(fileobj=buf, mode="r") try: return gzipper.read() finally: gzipper.close() -def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32*1024): +def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32 * 1024): """Encodes the given data with snappy compression. If xerial_compatible is set then the stream is encoded in a fashion @@ -141,30 +137,30 @@ def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32*1024): out = io.BytesIO() for fmt, dat in zip(_XERIAL_V1_FORMAT, _XERIAL_V1_HEADER): - out.write(struct.pack('!' + fmt, dat)) + out.write(struct.pack("!" + fmt, dat)) # Chunk through buffers to avoid creating intermediate slice copies if PYPY: # on pypy, snappy.compress() on a sliced buffer consumes the entire # buffer... likely a python-snappy bug, so just use a slice copy - chunker = lambda payload, i, size: payload[i:size+i] + def chunker(payload, i, size): + return payload[i:size + i] - elif six.PY2: - # Sliced buffer avoids additional copies - # pylint: disable-msg=undefined-variable - chunker = lambda payload, i, size: buffer(payload, i, size) else: # snappy.compress does not like raw memoryviews, so we have to convert # tobytes, which is a copy... oh well. it's the thought that counts. # pylint: disable-msg=undefined-variable - chunker = lambda payload, i, size: memoryview(payload)[i:size+i].tobytes() + def chunker(payload, i, size): + return memoryview(payload)[i:size + i].tobytes() - for chunk in (chunker(payload, i, xerial_blocksize) - for i in range(0, len(payload), xerial_blocksize)): + for chunk in ( + chunker(payload, i, xerial_blocksize) + for i in range(0, len(payload), xerial_blocksize) + ): block = snappy.compress(chunk) block_size = len(block) - out.write(struct.pack('!i', block_size)) + out.write(struct.pack("!i", block_size)) out.write(block) return out.getvalue() @@ -172,28 +168,28 @@ def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32*1024): def _detect_xerial_stream(payload): """Detects if the data given might have been encoded with the blocking mode - of the xerial snappy library. - - This mode writes a magic header of the format: - +--------+--------------+------------+---------+--------+ - | Marker | Magic String | Null / Pad | Version | Compat | - +--------+--------------+------------+---------+--------+ - | byte | c-string | byte | int32 | int32 | - +--------+--------------+------------+---------+--------+ - | -126 | 'SNAPPY' | \0 | | | - +--------+--------------+------------+---------+--------+ - - The pad appears to be to ensure that SNAPPY is a valid cstring - The version is the version of this format as written by xerial, - in the wild this is currently 1 as such we only support v1. 
- - Compat is there to claim the minimum supported version that - can read a xerial block stream, presently in the wild this is - 1. + of the xerial snappy library. + + This mode writes a magic header of the format: + +--------+--------------+------------+---------+--------+ + | Marker | Magic String | Null / Pad | Version | Compat | + +--------+--------------+------------+---------+--------+ + | byte | c-string | byte | int32 | int32 | + +--------+--------------+------------+---------+--------+ + | -126 | 'SNAPPY' | \0 | | | + +--------+--------------+------------+---------+--------+ + + The pad appears to be to ensure that SNAPPY is a valid cstring + The version is the version of this format as written by xerial, + in the wild this is currently 1 as such we only support v1. + + Compat is there to claim the minimum supported version that + can read a xerial block stream, presently in the wild this is + 1. """ if len(payload) > 16: - header = struct.unpack('!' + _XERIAL_V1_FORMAT, bytes(payload)[:16]) + header = struct.unpack("!" + _XERIAL_V1_FORMAT, bytes(payload)[:16]) return header == _XERIAL_V1_HEADER return False @@ -210,7 +206,7 @@ def snappy_decode(payload): cursor = 0 while cursor < length: - block_size = struct.unpack_from('!i', byt[cursor:])[0] + block_size = struct.unpack_from("!i", byt[cursor:])[0] # Skip the block size cursor += 4 end = cursor + block_size @@ -224,11 +220,11 @@ def snappy_decode(payload): if lz4: - lz4_encode = _lz4_compress # pylint: disable-msg=no-member + lz4_encode = _lz4_compress # pylint: disable-msg=no-member elif lz4f: - lz4_encode = lz4f.compressFrame # pylint: disable-msg=no-member + lz4_encode = lz4f.compressFrame # pylint: disable-msg=no-member elif lz4framed: - lz4_encode = lz4framed.compress # pylint: disable-msg=no-member + lz4_encode = lz4framed.compress # pylint: disable-msg=no-member else: lz4_encode = None @@ -242,17 +238,17 @@ def lz4f_decode(payload): # lz4f python module does not expose how much of the payload was # actually read if the decompression was only partial. 
- if data['next'] != 0: - raise RuntimeError('lz4f unable to decompress full payload') - return data['decomp'] + if data["next"] != 0: + raise RuntimeError("lz4f unable to decompress full payload") + return data["decomp"] if lz4: - lz4_decode = lz4.decompress # pylint: disable-msg=no-member + lz4_decode = lz4.decompress # pylint: disable-msg=no-member elif lz4f: lz4_decode = lz4f_decode elif lz4framed: - lz4_decode = lz4framed.decompress # pylint: disable-msg=no-member + lz4_decode = lz4framed.decompress # pylint: disable-msg=no-member else: lz4_decode = None @@ -266,7 +262,7 @@ def lz4_encode_old_kafka(payload): if not isinstance(flg, int): flg = ord(flg) - content_size_bit = ((flg >> 3) & 1) + content_size_bit = (flg >> 3) & 1 if content_size_bit: # Old kafka does not accept the content-size field # so we need to discard it and reset the header flag @@ -274,18 +270,16 @@ def lz4_encode_old_kafka(payload): data = bytearray(data) data[4] = flg data = bytes(data) - payload = data[header_size+8:] + payload = data[header_size + 8:] else: payload = data[header_size:] # This is the incorrect hc - hc = xxhash.xxh32(data[0:header_size-1]).digest()[-2:-1] # pylint: disable-msg=no-member + hc = xxhash.xxh32(data[0:header_size - 1]).digest()[ + -2:-1 + ] # pylint: disable-msg=no-member - return b''.join([ - data[0:header_size-1], - hc, - payload - ]) + return b"".join([data[0:header_size - 1], hc, payload]) def lz4_decode_old_kafka(payload): @@ -296,18 +290,14 @@ def lz4_decode_old_kafka(payload): flg = payload[4] else: flg = ord(payload[4]) - content_size_bit = ((flg >> 3) & 1) + content_size_bit = (flg >> 3) & 1 if content_size_bit: header_size += 8 # This should be the correct hc - hc = xxhash.xxh32(payload[4:header_size-1]).digest()[-2:-1] # pylint: disable-msg=no-member + hc = xxhash.xxh32(payload[4:header_size - 1]).digest()[-2:-1] - munged_payload = b''.join([ - payload[0:header_size-1], - hc, - payload[header_size:] - ]) + munged_payload = b"".join([payload[0:header_size - 1], hc, payload[header_size:]]) return lz4_decode(munged_payload) @@ -323,4 +313,6 @@ def zstd_decode(payload): try: return zstd.ZstdDecompressor().decompress(payload) except zstd.ZstdError: - return zstd.ZstdDecompressor().decompress(payload, max_output_size=ZSTD_MAX_OUTPUT_SIZE) + return zstd.ZstdDecompressor().decompress( + payload, max_output_size=ZSTD_MAX_OUTPUT_SIZE + ) diff --git a/aiokafka/producer/producer.py b/aiokafka/producer/producer.py index 12a07e7f..3c5a096e 100644 --- a/aiokafka/producer/producer.py +++ b/aiokafka/producer/producer.py @@ -4,9 +4,8 @@ import traceback import warnings -from kafka.codec import has_gzip, has_snappy, has_lz4, has_zstd - from aiokafka.client import AIOKafkaClient +from aiokafka.codec import has_gzip, has_snappy, has_lz4, has_zstd from aiokafka.errors import ( MessageSizeTooLargeError, UnsupportedVersionError, IllegalOperation) from aiokafka.partitioner import DefaultPartitioner diff --git a/aiokafka/protocol/message.py b/aiokafka/protocol/message.py index d187b9bc..3fc665e2 100644 --- a/aiokafka/protocol/message.py +++ b/aiokafka/protocol/message.py @@ -1,7 +1,7 @@ import io import time -from kafka.codec import ( +from aiokafka.codec import ( has_gzip, has_snappy, has_lz4, diff --git a/aiokafka/record/_crecords/default_records.pyx b/aiokafka/record/_crecords/default_records.pyx index 73ec764a..ba49411d 100644 --- a/aiokafka/record/_crecords/default_records.pyx +++ b/aiokafka/record/_crecords/default_records.pyx @@ -55,12 +55,12 @@ # * Timestamp Type (3) # * Compression Type 
(0-2) -from aiokafka.errors import CorruptRecordException, UnsupportedCodecError -from kafka.codec import ( +import aiokafka.codec as codecs +from aiokafka.codec import ( gzip_encode, snappy_encode, lz4_encode, zstd_encode, gzip_decode, snappy_decode, lz4_decode, zstd_decode ) -import kafka.codec as codecs +from aiokafka.errors import CorruptRecordException, UnsupportedCodecError from cpython cimport PyObject_GetBuffer, PyBuffer_Release, PyBUF_WRITABLE, \ PyBUF_SIMPLE, PyBUF_READ, Py_buffer, \ diff --git a/aiokafka/record/_crecords/legacy_records.pyx b/aiokafka/record/_crecords/legacy_records.pyx index 3c2f366a..2406ef12 100644 --- a/aiokafka/record/_crecords/legacy_records.pyx +++ b/aiokafka/record/_crecords/legacy_records.pyx @@ -1,10 +1,10 @@ #cython: language_level=3 -from kafka.codec import ( +import aiokafka.codec as codecs +from aiokafka.codec import ( gzip_encode, snappy_encode, lz4_encode, lz4_encode_old_kafka, gzip_decode, snappy_decode, lz4_decode, lz4_decode_old_kafka ) -import kafka.codec as codecs from aiokafka.errors import CorruptRecordException, UnsupportedCodecError from zlib import crc32 as py_crc32 # needed for windows macro diff --git a/aiokafka/record/default_records.py b/aiokafka/record/default_records.py index e9ee69ee..4adaf871 100644 --- a/aiokafka/record/default_records.py +++ b/aiokafka/record/default_records.py @@ -56,15 +56,16 @@ import struct import time -from .util import decode_varint, encode_varint, calc_crc32c, size_of_varint +import aiokafka.codec as codecs from aiokafka.errors import CorruptRecordException, UnsupportedCodecError from aiokafka.util import NO_EXTENSIONS -from kafka.codec import ( +from aiokafka.codec import ( gzip_encode, snappy_encode, lz4_encode, zstd_encode, gzip_decode, snappy_decode, lz4_decode, zstd_decode ) -import kafka.codec as codecs + +from .util import decode_varint, encode_varint, calc_crc32c, size_of_varint class DefaultRecordBase: diff --git a/aiokafka/record/legacy_records.py b/aiokafka/record/legacy_records.py index e0364651..c1ae9480 100644 --- a/aiokafka/record/legacy_records.py +++ b/aiokafka/record/legacy_records.py @@ -3,13 +3,13 @@ from binascii import crc32 -from aiokafka.errors import CorruptRecordException, UnsupportedCodecError -from aiokafka.util import NO_EXTENSIONS -from kafka.codec import ( +import aiokafka.codec as codecs +from aiokafka.codec import ( gzip_encode, snappy_encode, lz4_encode, lz4_encode_old_kafka, gzip_decode, snappy_decode, lz4_decode, lz4_decode_old_kafka ) -import kafka.codec as codecs +from aiokafka.errors import CorruptRecordException, UnsupportedCodecError +from aiokafka.util import NO_EXTENSIONS NoneType = type(None) diff --git a/tests/conftest.py b/tests/conftest.py index 1cd91b22..66080e5e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,7 @@ from aiokafka.record.default_records import ( DefaultRecordBatchBuilder, _DefaultRecordBatchBuilderPy) from aiokafka.util import NO_EXTENSIONS + from ._testutil import wait_kafka diff --git a/tests/kafka/test_codec.py b/tests/kafka/test_codec.py deleted file mode 100644 index db6a14b6..00000000 --- a/tests/kafka/test_codec.py +++ /dev/null @@ -1,124 +0,0 @@ -from __future__ import absolute_import - -import platform -import struct - -import pytest -from kafka.vendor.six.moves import range - -from kafka.codec import ( - has_snappy, has_lz4, has_zstd, - gzip_encode, gzip_decode, - snappy_encode, snappy_decode, - lz4_encode, lz4_decode, - lz4_encode_old_kafka, lz4_decode_old_kafka, - zstd_encode, zstd_decode, -) - -from 
tests.kafka.testutil import random_string - - -def test_gzip(): - for i in range(1000): - b1 = random_string(100).encode('utf-8') - b2 = gzip_decode(gzip_encode(b1)) - assert b1 == b2 - - -@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") -def test_snappy(): - for i in range(1000): - b1 = random_string(100).encode('utf-8') - b2 = snappy_decode(snappy_encode(b1)) - assert b1 == b2 - - -@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") -def test_snappy_detect_xerial(): - import kafka as kafka1 - _detect_xerial_stream = kafka1.codec._detect_xerial_stream - - header = b'\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01Some extra bytes' - false_header = b'\x01SNAPPY\x00\x00\x00\x01\x00\x00\x00\x01' - default_snappy = snappy_encode(b'foobar' * 50) - random_snappy = snappy_encode(b'SNAPPY' * 50, xerial_compatible=False) - short_data = b'\x01\x02\x03\x04' - - assert _detect_xerial_stream(header) is True - assert _detect_xerial_stream(b'') is False - assert _detect_xerial_stream(b'\x00') is False - assert _detect_xerial_stream(false_header) is False - assert _detect_xerial_stream(default_snappy) is True - assert _detect_xerial_stream(random_snappy) is False - assert _detect_xerial_stream(short_data) is False - - -@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") -def test_snappy_decode_xerial(): - header = b'\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01' - random_snappy = snappy_encode(b'SNAPPY' * 50, xerial_compatible=False) - block_len = len(random_snappy) - random_snappy2 = snappy_encode(b'XERIAL' * 50, xerial_compatible=False) - block_len2 = len(random_snappy2) - - to_test = header \ - + struct.pack('!i', block_len) + random_snappy \ - + struct.pack('!i', block_len2) + random_snappy2 \ - - assert snappy_decode(to_test) == (b'SNAPPY' * 50) + (b'XERIAL' * 50) - - -@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") -def test_snappy_encode_xerial(): - to_ensure = ( - b'\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01' - b'\x00\x00\x00\x18' - b'\xac\x02\x14SNAPPY\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\x96\x06\x00' - b'\x00\x00\x00\x18' - b'\xac\x02\x14XERIAL\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\x96\x06\x00' - ) - - to_test = (b'SNAPPY' * 50) + (b'XERIAL' * 50) - - compressed = snappy_encode(to_test, xerial_compatible=True, xerial_blocksize=300) - assert compressed == to_ensure - - -@pytest.mark.skipif(not has_lz4() or platform.python_implementation() == 'PyPy', - reason="python-lz4 crashes on old versions of pypy") -def test_lz4(): - for i in range(1000): - b1 = random_string(100).encode('utf-8') - b2 = lz4_decode(lz4_encode(b1)) - assert len(b1) == len(b2) - assert b1 == b2 - - -@pytest.mark.skipif(not has_lz4() or platform.python_implementation() == 'PyPy', - reason="python-lz4 crashes on old versions of pypy") -def test_lz4_old(): - for i in range(1000): - b1 = random_string(100).encode('utf-8') - b2 = lz4_decode_old_kafka(lz4_encode_old_kafka(b1)) - assert len(b1) == len(b2) - assert b1 == b2 - - -@pytest.mark.skipif(not has_lz4() or platform.python_implementation() == 'PyPy', - reason="python-lz4 crashes on old versions of pypy") -def test_lz4_incremental(): - for i in range(1000): - # lz4 max single block size is 4MB - # make sure we test with multiple-blocks - b1 = random_string(100).encode('utf-8') * 50000 - b2 = lz4_decode(lz4_encode(b1)) - assert len(b1) == len(b2) - assert b1 == b2 - - -@pytest.mark.skipif(not has_zstd(), reason="Zstd not available") -def test_zstd(): - for _ in range(1000): - b1 = 
random_string(100).encode('utf-8') - b2 = zstd_decode(zstd_encode(b1)) - assert b1 == b2 diff --git a/tests/record/test_default_records.py b/tests/record/test_default_records.py index 455590c9..a79f2aef 100644 --- a/tests/record/test_default_records.py +++ b/tests/record/test_default_records.py @@ -1,6 +1,6 @@ from unittest import mock -import kafka.codec +import aiokafka.codec import pytest from aiokafka.errors import UnsupportedCodecError @@ -196,7 +196,7 @@ def test_unavailable_codec(compression_type, name, checker_name): builder.append(0, timestamp=None, key=None, value=b"M" * 2000, headers=[]) correct_buffer = builder.build() - with mock.patch.object(kafka.codec, checker_name, return_value=False): + with mock.patch.object(aiokafka.codec, checker_name, return_value=False): # Check that builder raises error builder = DefaultRecordBatchBuilder( magic=2, compression_type=compression_type, is_transactional=0, diff --git a/tests/record/test_legacy.py b/tests/record/test_legacy.py index ee3c6a76..26f3d4dc 100644 --- a/tests/record/test_legacy.py +++ b/tests/record/test_legacy.py @@ -1,7 +1,7 @@ import struct from unittest import mock -import kafka.codec +import aiokafka.codec import pytest from aiokafka.errors import CorruptRecordException, UnsupportedCodecError @@ -202,7 +202,7 @@ def test_unavailable_codec(compression_type, name, checker_name): builder.append(0, timestamp=None, key=None, value=b"M") correct_buffer = builder.build() - with mock.patch.object(kafka.codec, checker_name) as mocked: + with mock.patch.object(aiokafka.codec, checker_name) as mocked: mocked.return_value = False # Check that builder raises error builder = LegacyRecordBatchBuilder( diff --git a/tests/test_codec.py b/tests/test_codec.py new file mode 100644 index 00000000..9ae53487 --- /dev/null +++ b/tests/test_codec.py @@ -0,0 +1,136 @@ +import platform +import struct + +import pytest + +from aiokafka import codec as codecs +from aiokafka.codec import ( + has_snappy, + has_lz4, + has_zstd, + gzip_encode, + gzip_decode, + snappy_encode, + snappy_decode, + lz4_encode, + lz4_decode, + lz4_encode_old_kafka, + lz4_decode_old_kafka, + zstd_encode, + zstd_decode, +) + +from ._testutil import random_string + + +def test_gzip(): + for i in range(1000): + b1 = random_string(100) + b2 = gzip_decode(gzip_encode(b1)) + assert b1 == b2 + + +@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") +def test_snappy(): + for i in range(1000): + b1 = random_string(100) + b2 = snappy_decode(snappy_encode(b1)) + assert b1 == b2 + + +@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") +def test_snappy_detect_xerial(): + _detect_xerial_stream = codecs._detect_xerial_stream + + header = b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01Some extra bytes" + false_header = b"\x01SNAPPY\x00\x00\x00\x01\x00\x00\x00\x01" + default_snappy = snappy_encode(b"foobar" * 50) + random_snappy = snappy_encode(b"SNAPPY" * 50, xerial_compatible=False) + short_data = b"\x01\x02\x03\x04" + + assert _detect_xerial_stream(header) is True + assert _detect_xerial_stream(b"") is False + assert _detect_xerial_stream(b"\x00") is False + assert _detect_xerial_stream(false_header) is False + assert _detect_xerial_stream(default_snappy) is True + assert _detect_xerial_stream(random_snappy) is False + assert _detect_xerial_stream(short_data) is False + + +@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") +def test_snappy_decode_xerial(): + header = b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01" + random_snappy = 
snappy_encode(b"SNAPPY" * 50, xerial_compatible=False) + block_len = len(random_snappy) + random_snappy2 = snappy_encode(b"XERIAL" * 50, xerial_compatible=False) + block_len2 = len(random_snappy2) + + to_test = ( + header + + struct.pack("!i", block_len) + + random_snappy + + struct.pack("!i", block_len2) + + random_snappy2 + ) + assert snappy_decode(to_test) == (b"SNAPPY" * 50) + (b"XERIAL" * 50) + + +@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") +def test_snappy_encode_xerial(): + to_ensure = ( + b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01" + b"\x00\x00\x00\x18\xac\x02\x14SNAPPY\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00" + b"\xfe\x06\x00\x96\x06\x00\x00\x00\x00\x18\xac\x02" + b"\x14XERIAL\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\x96\x06\x00" + ) + + to_test = (b"SNAPPY" * 50) + (b"XERIAL" * 50) + + compressed = snappy_encode(to_test, xerial_compatible=True, xerial_blocksize=300) + assert compressed == to_ensure + + +@pytest.mark.skipif( + not has_lz4() or platform.python_implementation() == "PyPy", + reason="python-lz4 crashes on old versions of pypy", +) +def test_lz4(): + for i in range(1000): + b1 = random_string(100) + b2 = lz4_decode(lz4_encode(b1)) + assert len(b1) == len(b2) + assert b1 == b2 + + +@pytest.mark.skipif( + not has_lz4() or platform.python_implementation() == "PyPy", + reason="python-lz4 crashes on old versions of pypy", +) +def test_lz4_old(): + for i in range(1000): + b1 = random_string(100) + b2 = lz4_decode_old_kafka(lz4_encode_old_kafka(b1)) + assert len(b1) == len(b2) + assert b1 == b2 + + +@pytest.mark.skipif( + not has_lz4() or platform.python_implementation() == "PyPy", + reason="python-lz4 crashes on old versions of pypy", +) +def test_lz4_incremental(): + for i in range(1000): + # lz4 max single block size is 4MB + # make sure we test with multiple-blocks + b1 = random_string(100) * 50000 + b2 = lz4_decode(lz4_encode(b1)) + assert len(b1) == len(b2) + assert b1 == b2 + + +@pytest.mark.skipif(not has_zstd(), reason="Zstd not available") +def test_zstd(): + for _ in range(1000): + b1 = random_string(100) + b2 = zstd_decode(zstd_encode(b1)) + assert b1 == b2 From 6d7016499f3825f4a1b969990a8840ee6953bdd5 Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Mon, 23 Oct 2023 12:29:52 +0300 Subject: [PATCH 18/20] Merge util --- aiokafka/coordinator/consumer.py | 2 +- aiokafka/protocol/message.py | 10 +++-- aiokafka/protocol/struct.py | 4 +- aiokafka/protocol/types.py | 13 +++++++ aiokafka/util.py | 37 ++++++++++++++++++ kafka/util.py | 66 -------------------------------- 6 files changed, 59 insertions(+), 73 deletions(-) delete mode 100644 kafka/util.py diff --git a/aiokafka/coordinator/consumer.py b/aiokafka/coordinator/consumer.py index 8f6cdaba..7604e051 100644 --- a/aiokafka/coordinator/consumer.py +++ b/aiokafka/coordinator/consumer.py @@ -5,13 +5,13 @@ import time from kafka.future import Future -from kafka.util import WeakMethod import aiokafka.errors as Errors from aiokafka.metrics import AnonMeasurable from aiokafka.metrics.stats import Avg, Count, Max, Rate from aiokafka.protocol.commit import OffsetCommitRequest, OffsetFetchRequest from aiokafka.structs import OffsetAndMetadata, TopicPartition +from aiokafka.util import WeakMethod from .base import BaseCoordinator, Generation from .assignors.range import RangePartitionAssignor diff --git a/aiokafka/protocol/message.py b/aiokafka/protocol/message.py index 3fc665e2..a305f419 100644 --- a/aiokafka/protocol/message.py +++ b/aiokafka/protocol/message.py @@ -1,5 +1,6 @@ import io 
import time +from binascii import crc32 from aiokafka.codec import ( has_gzip, @@ -12,23 +13,24 @@ lz4_decode, lz4_decode_old_kafka, ) +from aiokafka.util import WeakMethod + from .frame import KafkaBytes from .struct import Struct -from .types import Int8, Int32, Int64, Bytes, Schema, AbstractType -from kafka.util import crc32, WeakMethod +from .types import Int8, Int32, UInt32, Int64, Bytes, Schema, AbstractType class Message(Struct): SCHEMAS = [ Schema( - ("crc", Int32), + ("crc", UInt32), ("magic", Int8), ("attributes", Int8), ("key", Bytes), ("value", Bytes), ), Schema( - ("crc", Int32), + ("crc", UInt32), ("magic", Int8), ("attributes", Int8), ("timestamp", Int64), diff --git a/aiokafka/protocol/struct.py b/aiokafka/protocol/struct.py index d7faa327..b24d7b2b 100644 --- a/aiokafka/protocol/struct.py +++ b/aiokafka/protocol/struct.py @@ -1,10 +1,10 @@ from io import BytesIO +from aiokafka.util import WeakMethod + from .abstract import AbstractType from .types import Schema -from kafka.util import WeakMethod - class Struct(AbstractType): SCHEMA = Schema() diff --git a/aiokafka/protocol/types.py b/aiokafka/protocol/types.py index 56613905..f1e106c5 100644 --- a/aiokafka/protocol/types.py +++ b/aiokafka/protocol/types.py @@ -64,6 +64,19 @@ def decode(cls, data): return _unpack(cls._unpack, data.read(4)) +class UInt32(AbstractType): + _pack = struct.Struct(">I").pack + _unpack = struct.Struct(">I").unpack + + @classmethod + def encode(cls, value): + return _pack(cls._pack, value) + + @classmethod + def decode(cls, data): + return _unpack(cls._unpack, data.read(4)) + + class Int64(AbstractType): _pack = struct.Struct(">q").pack _unpack = struct.Struct(">q").unpack diff --git a/aiokafka/util.py b/aiokafka/util.py index 38a08baf..6d7c5968 100644 --- a/aiokafka/util.py +++ b/aiokafka/util.py @@ -1,6 +1,8 @@ import asyncio import os +import weakref from asyncio import AbstractEventLoop +from types import MethodType from typing import Any, Awaitable, Coroutine, Dict, Tuple, TypeVar, Union, cast import async_timeout @@ -91,3 +93,38 @@ def get_running_loop() -> asyncio.AbstractEventLoop: INTEGER_MAX_VALUE = 2**31 - 1 INTEGER_MIN_VALUE = -(2**31) + + +class WeakMethod(object): + """ + Callable that weakly references a method and the object it is bound to. It + is based on https://stackoverflow.com/a/24287465. + + Arguments: + + object_dot_method: A bound instance method (i.e. 'object.method'). + """ + + def __init__(self, object_dot_method: MethodType) -> None: + self.target = weakref.ref(object_dot_method.__self__) + self._target_id = id(self.target()) + self.method = weakref.ref(object_dot_method.__func__) + self._method_id = id(self.method()) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + """ + Calls the method on target with args and kwargs. 
+ """ + method = self.method() + assert method is not None + return method(self.target(), *args, **kwargs) + + def __hash__(self) -> int: + return hash(self.target) ^ hash(self.method) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, WeakMethod): + return False + return ( + self._target_id == other._target_id and self._method_id == other._method_id + ) diff --git a/kafka/util.py b/kafka/util.py deleted file mode 100644 index e31d9930..00000000 --- a/kafka/util.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import absolute_import - -import binascii -import weakref - -from kafka.vendor import six - - -if six.PY3: - MAX_INT = 2 ** 31 - TO_SIGNED = 2 ** 32 - - def crc32(data): - crc = binascii.crc32(data) - # py2 and py3 behave a little differently - # CRC is encoded as a signed int in kafka protocol - # so we'll convert the py3 unsigned result to signed - if crc >= MAX_INT: - crc -= TO_SIGNED - return crc -else: - from binascii import crc32 - - -class WeakMethod(object): - """ - Callable that weakly references a method and the object it is bound to. It - is based on https://stackoverflow.com/a/24287465. - - Arguments: - - object_dot_method: A bound instance method (i.e. 'object.method'). - """ - def __init__(self, object_dot_method): - try: - self.target = weakref.ref(object_dot_method.__self__) - except AttributeError: - self.target = weakref.ref(object_dot_method.im_self) - self._target_id = id(self.target()) - try: - self.method = weakref.ref(object_dot_method.__func__) - except AttributeError: - self.method = weakref.ref(object_dot_method.im_func) - self._method_id = id(self.method()) - - def __call__(self, *args, **kwargs): - """ - Calls the method on target with args and kwargs. - """ - return self.method()(self.target(), *args, **kwargs) - - def __hash__(self): - return hash(self.target) ^ hash(self.method) - - def __eq__(self, other): - if not isinstance(other, WeakMethod): - return False - return self._target_id == other._target_id and self._method_id == other._method_id - - -class Dict(dict): - """Utility class to support passing weakrefs to dicts - - See: https://docs.python.org/2/library/weakref.html - """ - pass From a2ec341258e4f7588eb2632217f16f1abce9babe Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Mon, 23 Oct 2023 12:45:23 +0300 Subject: [PATCH 19/20] Replace future and final clean-up --- aiokafka/cluster.py | 4 +- aiokafka/coordinator/base.py | 3 +- aiokafka/coordinator/consumer.py | 5 +- docs/examples/manual_commit.rst | 2 +- docs/examples/serialize_and_compress.rst | 2 +- docs/examples/ssl_consume_produce.rst | 2 +- kafka/__init__.py | 17 - kafka/future.py | 83 --- kafka/vendor/__init__.py | 0 kafka/vendor/enum34.py | 841 --------------------- kafka/vendor/selectors34.py | 637 ---------------- kafka/vendor/six.py | 897 ----------------------- kafka/vendor/socketpair.py | 58 -- setup.py | 2 +- tests/kafka/__init__.py | 8 - tests/kafka/conftest.py | 140 ---- tests/kafka/fixtures.py | 651 ---------------- tests/kafka/service.py | 133 ---- tests/kafka/testutil.py | 46 -- 19 files changed, 9 insertions(+), 3522 deletions(-) delete mode 100644 kafka/__init__.py delete mode 100644 kafka/future.py delete mode 100644 kafka/vendor/__init__.py delete mode 100644 kafka/vendor/enum34.py delete mode 100644 kafka/vendor/selectors34.py delete mode 100644 kafka/vendor/six.py delete mode 100644 kafka/vendor/socketpair.py delete mode 100644 tests/kafka/__init__.py delete mode 100644 tests/kafka/conftest.py delete mode 100644 tests/kafka/fixtures.py delete mode 
100644 tests/kafka/service.py delete mode 100644 tests/kafka/testutil.py diff --git a/aiokafka/cluster.py b/aiokafka/cluster.py index fd565422..061db638 100644 --- a/aiokafka/cluster.py +++ b/aiokafka/cluster.py @@ -4,7 +4,7 @@ import threading import time -from kafka.future import Future +from concurrent.futures import Future from aiokafka import errors as Errors from aiokafka.conn import collect_hosts @@ -189,7 +189,7 @@ def request_update(self): change the reported ttl() Returns: - kafka.future.Future (value will be the cluster object after update) + Future (value will be the cluster object after update) """ with self._lock: self._need_update = True diff --git a/aiokafka/coordinator/base.py b/aiokafka/coordinator/base.py index ea6b4ccd..f1de92de 100644 --- a/aiokafka/coordinator/base.py +++ b/aiokafka/coordinator/base.py @@ -4,8 +4,7 @@ import threading import time import weakref - -from kafka.future import Future +from concurrent.futures import Future from aiokafka import errors as Errors from aiokafka.metrics import AnonMeasurable diff --git a/aiokafka/coordinator/consumer.py b/aiokafka/coordinator/consumer.py index 7604e051..dade9fcd 100644 --- a/aiokafka/coordinator/consumer.py +++ b/aiokafka/coordinator/consumer.py @@ -3,8 +3,7 @@ import functools import logging import time - -from kafka.future import Future +from concurrent.futures import Future import aiokafka.errors as Errors from aiokafka.metrics import AnonMeasurable @@ -503,7 +502,7 @@ def commit_offsets_async(self, offsets, callback=None): a commit request completes. Returns: - kafka.future.Future + Future """ self._invoke_completed_offset_commit_callbacks() if not self.coordinator_unknown(): diff --git a/docs/examples/manual_commit.rst b/docs/examples/manual_commit.rst index 30eca170..416f5ed5 100644 --- a/docs/examples/manual_commit.rst +++ b/docs/examples/manual_commit.rst @@ -22,7 +22,7 @@ Consumer: import json import asyncio - from kafka.common import KafkaError + from aiokafka.errors import KafkaError from aiokafka import AIOKafkaConsumer async def consume(): diff --git a/docs/examples/serialize_and_compress.rst b/docs/examples/serialize_and_compress.rst index 02c8dbdf..55d5a48d 100644 --- a/docs/examples/serialize_and_compress.rst +++ b/docs/examples/serialize_and_compress.rst @@ -49,7 +49,7 @@ Consumer import json import asyncio - from kafka.common import KafkaError + from aiokafka.errors import KafkaError from aiokafka import AIOKafkaConsumer def deserializer(serialized): diff --git a/docs/examples/ssl_consume_produce.rst b/docs/examples/ssl_consume_produce.rst index b3c0808f..b99d5e51 100644 --- a/docs/examples/ssl_consume_produce.rst +++ b/docs/examples/ssl_consume_produce.rst @@ -11,7 +11,7 @@ information. import asyncio from aiokafka import AIOKafkaProducer, AIOKafkaConsumer from aiokafka.helpers import create_ssl_context - from kafka.common import TopicPartition + from aiokafka.errors import TopicPartition context = create_ssl_context( cafile="./ca-cert", # CA used to sign certificate. diff --git a/kafka/__init__.py b/kafka/__init__.py deleted file mode 100644 index a40686e6..00000000 --- a/kafka/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from __future__ import absolute_import - -__title__ = 'kafka' -__author__ = 'Dana Powers' -__license__ = 'Apache License 2.0' -__copyright__ = 'Copyright 2016 Dana Powers, David Arthur, and Contributors' - -# Set default logging handler to avoid "No handler found" warnings. 
-import logging -try: # Python 2.7+ - from logging import NullHandler -except ImportError: - class NullHandler(logging.Handler): - def emit(self, record): - pass - -logging.getLogger(__name__).addHandler(NullHandler()) diff --git a/kafka/future.py b/kafka/future.py deleted file mode 100644 index d0f3c665..00000000 --- a/kafka/future.py +++ /dev/null @@ -1,83 +0,0 @@ -from __future__ import absolute_import - -import functools -import logging - -log = logging.getLogger(__name__) - - -class Future(object): - error_on_callbacks = False # and errbacks - - def __init__(self): - self.is_done = False - self.value = None - self.exception = None - self._callbacks = [] - self._errbacks = [] - - def succeeded(self): - return self.is_done and not bool(self.exception) - - def failed(self): - return self.is_done and bool(self.exception) - - def retriable(self): - try: - return self.exception.retriable - except AttributeError: - return False - - def success(self, value): - assert not self.is_done, 'Future is already complete' - self.value = value - self.is_done = True - if self._callbacks: - self._call_backs('callback', self._callbacks, self.value) - return self - - def failure(self, e): - assert not self.is_done, 'Future is already complete' - self.exception = e if type(e) is not type else e() - assert isinstance(self.exception, BaseException), ( - 'future failed without an exception') - self.is_done = True - self._call_backs('errback', self._errbacks, self.exception) - return self - - def add_callback(self, f, *args, **kwargs): - if args or kwargs: - f = functools.partial(f, *args, **kwargs) - if self.is_done and not self.exception: - self._call_backs('callback', [f], self.value) - else: - self._callbacks.append(f) - return self - - def add_errback(self, f, *args, **kwargs): - if args or kwargs: - f = functools.partial(f, *args, **kwargs) - if self.is_done and self.exception: - self._call_backs('errback', [f], self.exception) - else: - self._errbacks.append(f) - return self - - def add_both(self, f, *args, **kwargs): - self.add_callback(f, *args, **kwargs) - self.add_errback(f, *args, **kwargs) - return self - - def chain(self, future): - self.add_callback(future.success) - self.add_errback(future.failure) - return self - - def _call_backs(self, back_type, backs, value): - for f in backs: - try: - f(value) - except Exception as e: - log.exception('Error processing %s', back_type) - if self.error_on_callbacks: - raise e diff --git a/kafka/vendor/__init__.py b/kafka/vendor/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/kafka/vendor/enum34.py b/kafka/vendor/enum34.py deleted file mode 100644 index 5f64bd2d..00000000 --- a/kafka/vendor/enum34.py +++ /dev/null @@ -1,841 +0,0 @@ -# pylint: skip-file -# vendored from: -# https://bitbucket.org/stoneleaf/enum34/src/58c4cd7174ca35f164304c8a6f0a4d47b779c2a7/enum/__init__.py?at=1.1.6 - -"""Python Enumerations""" - -import sys as _sys - -__all__ = ['Enum', 'IntEnum', 'unique'] - -version = 1, 1, 6 - -pyver = float('%s.%s' % _sys.version_info[:2]) - -try: - any -except NameError: - def any(iterable): - for element in iterable: - if element: - return True - return False - -try: - from collections import OrderedDict -except ImportError: - OrderedDict = None - -try: - basestring -except NameError: - # In Python 2 basestring is the ancestor of both str and unicode - # in Python 3 it's just str, but was missing in 3.1 - basestring = str - -try: - unicode -except NameError: - # In Python 3 unicode no longer exists (it's just str) - unicode = str 
- -class _RouteClassAttributeToGetattr(object): - """Route attribute access on a class to __getattr__. - - This is a descriptor, used to define attributes that act differently when - accessed through an instance and through a class. Instance access remains - normal, but access to an attribute through a class will be routed to the - class's __getattr__ method; this is done by raising AttributeError. - - """ - def __init__(self, fget=None): - self.fget = fget - - def __get__(self, instance, ownerclass=None): - if instance is None: - raise AttributeError() - return self.fget(instance) - - def __set__(self, instance, value): - raise AttributeError("can't set attribute") - - def __delete__(self, instance): - raise AttributeError("can't delete attribute") - - -def _is_descriptor(obj): - """Returns True if obj is a descriptor, False otherwise.""" - return ( - hasattr(obj, '__get__') or - hasattr(obj, '__set__') or - hasattr(obj, '__delete__')) - - -def _is_dunder(name): - """Returns True if a __dunder__ name, False otherwise.""" - return (name[:2] == name[-2:] == '__' and - name[2:3] != '_' and - name[-3:-2] != '_' and - len(name) > 4) - - -def _is_sunder(name): - """Returns True if a _sunder_ name, False otherwise.""" - return (name[0] == name[-1] == '_' and - name[1:2] != '_' and - name[-2:-1] != '_' and - len(name) > 2) - - -def _make_class_unpicklable(cls): - """Make the given class un-picklable.""" - def _break_on_call_reduce(self, protocol=None): - raise TypeError('%r cannot be pickled' % self) - cls.__reduce_ex__ = _break_on_call_reduce - cls.__module__ = '' - - -class _EnumDict(dict): - """Track enum member order and ensure member names are not reused. - - EnumMeta will use the names found in self._member_names as the - enumeration member names. - - """ - def __init__(self): - super(_EnumDict, self).__init__() - self._member_names = [] - - def __setitem__(self, key, value): - """Changes anything not dundered or not a descriptor. - - If a descriptor is added with the same name as an enum member, the name - is removed from _member_names (this may leave a hole in the numerical - sequence of values). - - If an enum member name is used twice, an error is raised; duplicate - values are not checked for. - - Single underscore (sunder) names are reserved. - - Note: in 3.x __order__ is simply discarded as a not necessary piece - leftover from 2.x - - """ - if pyver >= 3.0 and key in ('_order_', '__order__'): - return - elif key == '__order__': - key = '_order_' - if _is_sunder(key): - if key != '_order_': - raise ValueError('_names_ are reserved for future Enum use') - elif _is_dunder(key): - pass - elif key in self._member_names: - # descriptor overwriting an enum? - raise TypeError('Attempted to reuse key: %r' % key) - elif not _is_descriptor(value): - if key in self: - # enum overwriting a descriptor? - raise TypeError('Key already defined as: %r' % self[key]) - self._member_names.append(key) - super(_EnumDict, self).__setitem__(key, value) - - -# Dummy value for Enum as EnumMeta explicity checks for it, but of course until -# EnumMeta finishes running the first time the Enum class doesn't exist. This -# is also why there are checks in EnumMeta like `if Enum is not None` -Enum = None - - -class EnumMeta(type): - """Metaclass for Enum""" - @classmethod - def __prepare__(metacls, cls, bases): - return _EnumDict() - - def __new__(metacls, cls, bases, classdict): - # an Enum class is final once enumeration items have been defined; it - # cannot be mixed with other types (int, float, etc.) 
if it has an - # inherited __new__ unless a new __new__ is defined (or the resulting - # class will fail). - if type(classdict) is dict: - original_dict = classdict - classdict = _EnumDict() - for k, v in original_dict.items(): - classdict[k] = v - - member_type, first_enum = metacls._get_mixins_(bases) - __new__, save_new, use_args = metacls._find_new_(classdict, member_type, - first_enum) - # save enum items into separate mapping so they don't get baked into - # the new class - members = dict((k, classdict[k]) for k in classdict._member_names) - for name in classdict._member_names: - del classdict[name] - - # py2 support for definition order - _order_ = classdict.get('_order_') - if _order_ is None: - if pyver < 3.0: - try: - _order_ = [name for (name, value) in sorted(members.items(), key=lambda item: item[1])] - except TypeError: - _order_ = [name for name in sorted(members.keys())] - else: - _order_ = classdict._member_names - else: - del classdict['_order_'] - if pyver < 3.0: - _order_ = _order_.replace(',', ' ').split() - aliases = [name for name in members if name not in _order_] - _order_ += aliases - - # check for illegal enum names (any others?) - invalid_names = set(members) & set(['mro']) - if invalid_names: - raise ValueError('Invalid enum member name(s): %s' % ( - ', '.join(invalid_names), )) - - # save attributes from super classes so we know if we can take - # the shortcut of storing members in the class dict - base_attributes = set([a for b in bases for a in b.__dict__]) - # create our new Enum type - enum_class = super(EnumMeta, metacls).__new__(metacls, cls, bases, classdict) - enum_class._member_names_ = [] # names in random order - if OrderedDict is not None: - enum_class._member_map_ = OrderedDict() - else: - enum_class._member_map_ = {} # name->value map - enum_class._member_type_ = member_type - - # Reverse value->name map for hashable values. - enum_class._value2member_map_ = {} - - # instantiate them, checking for duplicates as we go - # we instantiate first instead of checking for duplicates first in case - # a custom __new__ is doing something funky with the values -- such as - # auto-numbering ;) - if __new__ is None: - __new__ = enum_class.__new__ - for member_name in _order_: - value = members[member_name] - if not isinstance(value, tuple): - args = (value, ) - else: - args = value - if member_type is tuple: # special case for tuple enums - args = (args, ) # wrap it one more time - if not use_args or not args: - enum_member = __new__(enum_class) - if not hasattr(enum_member, '_value_'): - enum_member._value_ = value - else: - enum_member = __new__(enum_class, *args) - if not hasattr(enum_member, '_value_'): - enum_member._value_ = member_type(*args) - value = enum_member._value_ - enum_member._name_ = member_name - enum_member.__objclass__ = enum_class - enum_member.__init__(*args) - # If another member with the same value was already defined, the - # new member becomes an alias to the existing one. - for name, canonical_member in enum_class._member_map_.items(): - if canonical_member.value == enum_member._value_: - enum_member = canonical_member - break - else: - # Aliases don't appear in member names (only in __members__). 
- enum_class._member_names_.append(member_name) - # performance boost for any member that would not shadow - # a DynamicClassAttribute (aka _RouteClassAttributeToGetattr) - if member_name not in base_attributes: - setattr(enum_class, member_name, enum_member) - # now add to _member_map_ - enum_class._member_map_[member_name] = enum_member - try: - # This may fail if value is not hashable. We can't add the value - # to the map, and by-value lookups for this value will be - # linear. - enum_class._value2member_map_[value] = enum_member - except TypeError: - pass - - - # If a custom type is mixed into the Enum, and it does not know how - # to pickle itself, pickle.dumps will succeed but pickle.loads will - # fail. Rather than have the error show up later and possibly far - # from the source, sabotage the pickle protocol for this class so - # that pickle.dumps also fails. - # - # However, if the new class implements its own __reduce_ex__, do not - # sabotage -- it's on them to make sure it works correctly. We use - # __reduce_ex__ instead of any of the others as it is preferred by - # pickle over __reduce__, and it handles all pickle protocols. - unpicklable = False - if '__reduce_ex__' not in classdict: - if member_type is not object: - methods = ('__getnewargs_ex__', '__getnewargs__', - '__reduce_ex__', '__reduce__') - if not any(m in member_type.__dict__ for m in methods): - _make_class_unpicklable(enum_class) - unpicklable = True - - - # double check that repr and friends are not the mixin's or various - # things break (such as pickle) - for name in ('__repr__', '__str__', '__format__', '__reduce_ex__'): - class_method = getattr(enum_class, name) - obj_method = getattr(member_type, name, None) - enum_method = getattr(first_enum, name, None) - if name not in classdict and class_method is not enum_method: - if name == '__reduce_ex__' and unpicklable: - continue - setattr(enum_class, name, enum_method) - - # method resolution and int's are not playing nice - # Python's less than 2.6 use __cmp__ - - if pyver < 2.6: - - if issubclass(enum_class, int): - setattr(enum_class, '__cmp__', getattr(int, '__cmp__')) - - elif pyver < 3.0: - - if issubclass(enum_class, int): - for method in ( - '__le__', - '__lt__', - '__gt__', - '__ge__', - '__eq__', - '__ne__', - '__hash__', - ): - setattr(enum_class, method, getattr(int, method)) - - # replace any other __new__ with our own (as long as Enum is not None, - # anyway) -- again, this is to support pickle - if Enum is not None: - # if the user defined their own __new__, save it before it gets - # clobbered in case they subclass later - if save_new: - setattr(enum_class, '__member_new__', enum_class.__dict__['__new__']) - setattr(enum_class, '__new__', Enum.__dict__['__new__']) - return enum_class - - def __bool__(cls): - """ - classes/types should always be True. - """ - return True - - def __call__(cls, value, names=None, module=None, type=None, start=1): - """Either returns an existing member, or creates a new enum class. - - This method is used both when an enum class is given a value to match - to an enumeration member (i.e. Color(3)) and for the functional API - (i.e. Color = Enum('Color', names='red green blue')). - - When used for the functional API: `module`, if set, will be stored in - the new class' __module__ attribute; `type`, if set, will be mixed in - as the first base class. 
-
-        Note: if `module` is not set this routine will attempt to discover the
-        calling module by walking the frame stack; if this is unsuccessful
-        the resulting class will not be pickleable.
-
-        """
-        if names is None:  # simple value lookup
-            return cls.__new__(cls, value)
-        # otherwise, functional API: we're creating a new Enum type
-        return cls._create_(value, names, module=module, type=type, start=start)
-
-    def __contains__(cls, member):
-        return isinstance(member, cls) and member.name in cls._member_map_
-
-    def __delattr__(cls, attr):
-        # nicer error message when someone tries to delete an attribute
-        # (see issue19025).
-        if attr in cls._member_map_:
-            raise AttributeError(
-                    "%s: cannot delete Enum member." % cls.__name__)
-        super(EnumMeta, cls).__delattr__(attr)
-
-    def __dir__(self):
-        return (['__class__', '__doc__', '__members__', '__module__'] +
-                self._member_names_)
-
-    @property
-    def __members__(cls):
-        """Returns a mapping of member name->value.
-
-        This mapping lists all enum members, including aliases. Note that this
-        is a copy of the internal mapping.
-
-        """
-        return cls._member_map_.copy()
-
-    def __getattr__(cls, name):
-        """Return the enum member matching `name`
-
-        We use __getattr__ instead of descriptors or inserting into the enum
-        class' __dict__ in order to support `name` and `value` being both
-        properties for enum members (which live in the class' __dict__) and
-        enum members themselves.
-
-        """
-        if _is_dunder(name):
-            raise AttributeError(name)
-        try:
-            return cls._member_map_[name]
-        except KeyError:
-            raise AttributeError(name)
-
-    def __getitem__(cls, name):
-        return cls._member_map_[name]
-
-    def __iter__(cls):
-        return (cls._member_map_[name] for name in cls._member_names_)
-
-    def __reversed__(cls):
-        return (cls._member_map_[name] for name in reversed(cls._member_names_))
-
-    def __len__(cls):
-        return len(cls._member_names_)
-
-    __nonzero__ = __bool__
-
-    def __repr__(cls):
-        return "<enum %r>" % cls.__name__
-
-    def __setattr__(cls, name, value):
-        """Block attempts to reassign Enum members.
-
-        A simple assignment to the class namespace only changes one of the
-        several possible ways to get an Enum member from the Enum class,
-        resulting in an inconsistent Enumeration.
-
-        """
-        member_map = cls.__dict__.get('_member_map_', {})
-        if name in member_map:
-            raise AttributeError('Cannot reassign members.')
-        super(EnumMeta, cls).__setattr__(name, value)
-
-    def _create_(cls, class_name, names=None, module=None, type=None, start=1):
-        """Convenience method to create a new Enum class.
-
-        `names` can be:
-
-        * A string containing member names, separated either with spaces or
-          commas.  Values are auto-numbered from 1.
-        * An iterable of member names.  Values are auto-numbered from 1.
-        * An iterable of (member name, value) pairs.
-        * A mapping of member name -> value.
-
-        """
-        if pyver < 3.0:
-            # if class_name is unicode, attempt a conversion to ASCII
-            if isinstance(class_name, unicode):
-                try:
-                    class_name = class_name.encode('ascii')
-                except UnicodeEncodeError:
-                    raise TypeError('%r is not representable in ASCII' % class_name)
-        metacls = cls.__class__
-        if type is None:
-            bases = (cls, )
-        else:
-            bases = (type, cls)
-        classdict = metacls.__prepare__(class_name, bases)
-        _order_ = []
-
-        # special processing needed for names?
-        if isinstance(names, basestring):
-            names = names.replace(',', ' ').split()
-        if isinstance(names, (tuple, list)) and isinstance(names[0], basestring):
-            names = [(e, i+start) for (i, e) in enumerate(names)]
-
-        # Here, names is either an iterable of (name, value) or a mapping.
-        item = None  # in case names is empty
-        for item in names:
-            if isinstance(item, basestring):
-                member_name, member_value = item, names[item]
-            else:
-                member_name, member_value = item
-            classdict[member_name] = member_value
-            _order_.append(member_name)
-        # only set _order_ in classdict if name/value was not from a mapping
-        if not isinstance(item, basestring):
-            classdict['_order_'] = ' '.join(_order_)
-        enum_class = metacls.__new__(metacls, class_name, bases, classdict)
-
-        # TODO: replace the frame hack if a blessed way to know the calling
-        # module is ever developed
-        if module is None:
-            try:
-                module = _sys._getframe(2).f_globals['__name__']
-            except (AttributeError, ValueError):
-                pass
-        if module is None:
-            _make_class_unpicklable(enum_class)
-        else:
-            enum_class.__module__ = module
-
-        return enum_class
-
-    @staticmethod
-    def _get_mixins_(bases):
-        """Returns the type for creating enum members, and the first inherited
-        enum class.
-
-        bases: the tuple of bases that was given to __new__
-
-        """
-        if not bases or Enum is None:
-            return object, Enum
-
-
-        # double check that we are not subclassing a class with existing
-        # enumeration members; while we're at it, see if any other data
-        # type has been mixed in so we can use the correct __new__
-        member_type = first_enum = None
-        for base in bases:
-            if (base is not Enum and
-                    issubclass(base, Enum) and
-                    base._member_names_):
-                raise TypeError("Cannot extend enumerations")
-        # base is now the last base in bases
-        if not issubclass(base, Enum):
-            raise TypeError("new enumerations must be created as "
-                            "`ClassName([mixin_type,] enum_type)`")
-
-        # get correct mix-in type (either mix-in type of Enum subclass, or
-        # first base if last base is Enum)
-        if not issubclass(bases[0], Enum):
-            member_type = bases[0]  # first data type
-            first_enum = bases[-1]  # enum type
-        else:
-            for base in bases[0].__mro__:
-                # most common: (IntEnum, int, Enum, object)
-                # possible:    (<Enum 'AutoIntEnum'>, <Enum 'IntEnum'>,
-                #               <Enum 'Enum'>, <class 'int'>, <class 'object'>)
-                if issubclass(base, Enum):
-                    if first_enum is None:
-                        first_enum = base
-                else:
-                    if member_type is None:
-                        member_type = base
-
-        return member_type, first_enum
-
-    if pyver < 3.0:
-        @staticmethod
-        def _find_new_(classdict, member_type, first_enum):
-            """Returns the __new__ to be used for creating the enum members.
- - classdict: the class dictionary given to __new__ - member_type: the data type whose __new__ will be used by default - first_enum: enumeration to check for an overriding __new__ - - """ - # now find the correct __new__, checking to see of one was defined - # by the user; also check earlier enum classes in case a __new__ was - # saved as __member_new__ - __new__ = classdict.get('__new__', None) - if __new__: - return None, True, True # __new__, save_new, use_args - - N__new__ = getattr(None, '__new__') - O__new__ = getattr(object, '__new__') - if Enum is None: - E__new__ = N__new__ - else: - E__new__ = Enum.__dict__['__new__'] - # check all possibles for __member_new__ before falling back to - # __new__ - for method in ('__member_new__', '__new__'): - for possible in (member_type, first_enum): - try: - target = possible.__dict__[method] - except (AttributeError, KeyError): - target = getattr(possible, method, None) - if target not in [ - None, - N__new__, - O__new__, - E__new__, - ]: - if method == '__member_new__': - classdict['__new__'] = target - return None, False, True - if isinstance(target, staticmethod): - target = target.__get__(member_type) - __new__ = target - break - if __new__ is not None: - break - else: - __new__ = object.__new__ - - # if a non-object.__new__ is used then whatever value/tuple was - # assigned to the enum member name will be passed to __new__ and to the - # new enum member's __init__ - if __new__ is object.__new__: - use_args = False - else: - use_args = True - - return __new__, False, use_args - else: - @staticmethod - def _find_new_(classdict, member_type, first_enum): - """Returns the __new__ to be used for creating the enum members. - - classdict: the class dictionary given to __new__ - member_type: the data type whose __new__ will be used by default - first_enum: enumeration to check for an overriding __new__ - - """ - # now find the correct __new__, checking to see of one was defined - # by the user; also check earlier enum classes in case a __new__ was - # saved as __member_new__ - __new__ = classdict.get('__new__', None) - - # should __new__ be saved as __member_new__ later? - save_new = __new__ is not None - - if __new__ is None: - # check all possibles for __member_new__ before falling back to - # __new__ - for method in ('__member_new__', '__new__'): - for possible in (member_type, first_enum): - target = getattr(possible, method, None) - if target not in ( - None, - None.__new__, - object.__new__, - Enum.__new__, - ): - __new__ = target - break - if __new__ is not None: - break - else: - __new__ = object.__new__ - - # if a non-object.__new__ is used then whatever value/tuple was - # assigned to the enum member name will be passed to __new__ and to the - # new enum member's __init__ - if __new__ is object.__new__: - use_args = False - else: - use_args = True - - return __new__, save_new, use_args - - -######################################################## -# In order to support Python 2 and 3 with a single -# codebase we have to create the Enum methods separately -# and then use the `type(name, bases, dict)` method to -# create the class. -######################################################## -temp_enum_dict = {} -temp_enum_dict['__doc__'] = "Generic enumeration.\n\n Derive from this class to define new enumerations.\n\n" - -def __new__(cls, value): - # all enum instances are actually created during class construction - # without calling this method; this method is called by the metaclass' - # __call__ (i.e. 
Color(3) ), and by pickle - if type(value) is cls: - # For lookups like Color(Color.red) - value = value.value - #return value - # by-value search for a matching enum member - # see if it's in the reverse mapping (for hashable values) - try: - if value in cls._value2member_map_: - return cls._value2member_map_[value] - except TypeError: - # not there, now do long search -- O(n) behavior - for member in cls._member_map_.values(): - if member.value == value: - return member - raise ValueError("%s is not a valid %s" % (value, cls.__name__)) -temp_enum_dict['__new__'] = __new__ -del __new__ - -def __repr__(self): - return "<%s.%s: %r>" % ( - self.__class__.__name__, self._name_, self._value_) -temp_enum_dict['__repr__'] = __repr__ -del __repr__ - -def __str__(self): - return "%s.%s" % (self.__class__.__name__, self._name_) -temp_enum_dict['__str__'] = __str__ -del __str__ - -if pyver >= 3.0: - def __dir__(self): - added_behavior = [ - m - for cls in self.__class__.mro() - for m in cls.__dict__ - if m[0] != '_' and m not in self._member_map_ - ] - return (['__class__', '__doc__', '__module__', ] + added_behavior) - temp_enum_dict['__dir__'] = __dir__ - del __dir__ - -def __format__(self, format_spec): - # mixed-in Enums should use the mixed-in type's __format__, otherwise - # we can get strange results with the Enum name showing up instead of - # the value - - # pure Enum branch - if self._member_type_ is object: - cls = str - val = str(self) - # mix-in branch - else: - cls = self._member_type_ - val = self.value - return cls.__format__(val, format_spec) -temp_enum_dict['__format__'] = __format__ -del __format__ - - -#################################### -# Python's less than 2.6 use __cmp__ - -if pyver < 2.6: - - def __cmp__(self, other): - if type(other) is self.__class__: - if self is other: - return 0 - return -1 - return NotImplemented - raise TypeError("unorderable types: %s() and %s()" % (self.__class__.__name__, other.__class__.__name__)) - temp_enum_dict['__cmp__'] = __cmp__ - del __cmp__ - -else: - - def __le__(self, other): - raise TypeError("unorderable types: %s() <= %s()" % (self.__class__.__name__, other.__class__.__name__)) - temp_enum_dict['__le__'] = __le__ - del __le__ - - def __lt__(self, other): - raise TypeError("unorderable types: %s() < %s()" % (self.__class__.__name__, other.__class__.__name__)) - temp_enum_dict['__lt__'] = __lt__ - del __lt__ - - def __ge__(self, other): - raise TypeError("unorderable types: %s() >= %s()" % (self.__class__.__name__, other.__class__.__name__)) - temp_enum_dict['__ge__'] = __ge__ - del __ge__ - - def __gt__(self, other): - raise TypeError("unorderable types: %s() > %s()" % (self.__class__.__name__, other.__class__.__name__)) - temp_enum_dict['__gt__'] = __gt__ - del __gt__ - - -def __eq__(self, other): - if type(other) is self.__class__: - return self is other - return NotImplemented -temp_enum_dict['__eq__'] = __eq__ -del __eq__ - -def __ne__(self, other): - if type(other) is self.__class__: - return self is not other - return NotImplemented -temp_enum_dict['__ne__'] = __ne__ -del __ne__ - -def __hash__(self): - return hash(self._name_) -temp_enum_dict['__hash__'] = __hash__ -del __hash__ - -def __reduce_ex__(self, proto): - return self.__class__, (self._value_, ) -temp_enum_dict['__reduce_ex__'] = __reduce_ex__ -del __reduce_ex__ - -# _RouteClassAttributeToGetattr is used to provide access to the `name` -# and `value` properties of enum members while keeping some measure of -# protection from modification, while still allowing for an 
enumeration -# to have members named `name` and `value`. This works because enumeration -# members are not set directly on the enum class -- __getattr__ is -# used to look them up. - -@_RouteClassAttributeToGetattr -def name(self): - return self._name_ -temp_enum_dict['name'] = name -del name - -@_RouteClassAttributeToGetattr -def value(self): - return self._value_ -temp_enum_dict['value'] = value -del value - -@classmethod -def _convert(cls, name, module, filter, source=None): - """ - Create a new Enum subclass that replaces a collection of global constants - """ - # convert all constants from source (or module) that pass filter() to - # a new Enum called name, and export the enum and its members back to - # module; - # also, replace the __reduce_ex__ method so unpickling works in - # previous Python versions - module_globals = vars(_sys.modules[module]) - if source: - source = vars(source) - else: - source = module_globals - members = dict((name, value) for name, value in source.items() if filter(name)) - cls = cls(name, members, module=module) - cls.__reduce_ex__ = _reduce_ex_by_name - module_globals.update(cls.__members__) - module_globals[name] = cls - return cls -temp_enum_dict['_convert'] = _convert -del _convert - -Enum = EnumMeta('Enum', (object, ), temp_enum_dict) -del temp_enum_dict - -# Enum has now been created -########################### - -class IntEnum(int, Enum): - """Enum where members are also (and must be) ints""" - -def _reduce_ex_by_name(self, proto): - return self.name - -def unique(enumeration): - """Class decorator that ensures only unique members exist in an enumeration.""" - duplicates = [] - for name, member in enumeration.__members__.items(): - if name != member.name: - duplicates.append((name, member.name)) - if duplicates: - duplicate_names = ', '.join( - ["%s -> %s" % (alias, name) for (alias, name) in duplicates] - ) - raise ValueError('duplicate names found in %r: %s' % - (enumeration, duplicate_names) - ) - return enumeration diff --git a/kafka/vendor/selectors34.py b/kafka/vendor/selectors34.py deleted file mode 100644 index ebf5d515..00000000 --- a/kafka/vendor/selectors34.py +++ /dev/null @@ -1,637 +0,0 @@ -# pylint: skip-file -# vendored from https://github.com/berkerpeksag/selectors34 -# at commit ff61b82168d2cc9c4922ae08e2a8bf94aab61ea2 (unreleased, ~1.2) -# -# Original author: Charles-Francois Natali (c.f.natali[at]gmail.com) -# Maintainer: Berker Peksag (berker.peksag[at]gmail.com) -# Also see https://pypi.python.org/pypi/selectors34 -"""Selectors module. - -This module allows high-level and efficient I/O multiplexing, built upon the -`select` module primitives. - -The following code adapted from trollius.selectors. -""" -from __future__ import absolute_import - -from abc import ABCMeta, abstractmethod -from collections import namedtuple, Mapping -from errno import EINTR -import math -import select -import sys - -from kafka.vendor import six - - -def _wrap_error(exc, mapping, key): - if key not in mapping: - return - new_err_cls = mapping[key] - new_err = new_err_cls(*exc.args) - - # raise a new exception with the original traceback - if hasattr(exc, '__traceback__'): - traceback = exc.__traceback__ - else: - traceback = sys.exc_info()[2] - six.reraise(new_err_cls, new_err, traceback) - - -# generic events, that must be mapped to implementation-specific ones -EVENT_READ = (1 << 0) -EVENT_WRITE = (1 << 1) - - -def _fileobj_to_fd(fileobj): - """Return a file descriptor from a file object. 
- - Parameters: - fileobj -- file object or file descriptor - - Returns: - corresponding file descriptor - - Raises: - ValueError if the object is invalid - """ - if isinstance(fileobj, six.integer_types): - fd = fileobj - else: - try: - fd = int(fileobj.fileno()) - except (AttributeError, TypeError, ValueError): - raise ValueError("Invalid file object: " - "{0!r}".format(fileobj)) - if fd < 0: - raise ValueError("Invalid file descriptor: {0}".format(fd)) - return fd - - -SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) -"""Object used to associate a file object to its backing file descriptor, -selected event mask and attached data.""" - - -class _SelectorMapping(Mapping): - """Mapping of file objects to selector keys.""" - - def __init__(self, selector): - self._selector = selector - - def __len__(self): - return len(self._selector._fd_to_key) - - def __getitem__(self, fileobj): - try: - fd = self._selector._fileobj_lookup(fileobj) - return self._selector._fd_to_key[fd] - except KeyError: - raise KeyError("{0!r} is not registered".format(fileobj)) - - def __iter__(self): - return iter(self._selector._fd_to_key) - -# Using six.add_metaclass() decorator instead of six.with_metaclass() because -# the latter leaks temporary_class to garbage with gc disabled -@six.add_metaclass(ABCMeta) -class BaseSelector(object): - """Selector abstract base class. - - A selector supports registering file objects to be monitored for specific - I/O events. - - A file object is a file descriptor or any object with a `fileno()` method. - An arbitrary object can be attached to the file object, which can be used - for example to store context information, a callback, etc. - - A selector can use various implementations (select(), poll(), epoll()...) - depending on the platform. The default `Selector` class uses the most - efficient implementation on the current platform. - """ - - @abstractmethod - def register(self, fileobj, events, data=None): - """Register a file object. - - Parameters: - fileobj -- file object or file descriptor - events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) - data -- attached data - - Returns: - SelectorKey instance - - Raises: - ValueError if events is invalid - KeyError if fileobj is already registered - OSError if fileobj is closed or otherwise is unacceptable to - the underlying system call (if a system call is made) - - Note: - OSError may or may not be raised - """ - raise NotImplementedError - - @abstractmethod - def unregister(self, fileobj): - """Unregister a file object. - - Parameters: - fileobj -- file object or file descriptor - - Returns: - SelectorKey instance - - Raises: - KeyError if fileobj is not registered - - Note: - If fileobj is registered but has since been closed this does - *not* raise OSError (even if the wrapped syscall does) - """ - raise NotImplementedError - - def modify(self, fileobj, events, data=None): - """Change a registered file object monitored events or attached data. - - Parameters: - fileobj -- file object or file descriptor - events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) - data -- attached data - - Returns: - SelectorKey instance - - Raises: - Anything that unregister() or register() raises - """ - self.unregister(fileobj) - return self.register(fileobj, events, data) - - @abstractmethod - def select(self, timeout=None): - """Perform the actual selection, until some monitored file objects are - ready or a timeout expires. 
- - Parameters: - timeout -- if timeout > 0, this specifies the maximum wait time, in - seconds - if timeout <= 0, the select() call won't block, and will - report the currently ready file objects - if timeout is None, select() will block until a monitored - file object becomes ready - - Returns: - list of (key, events) for ready file objects - `events` is a bitwise mask of EVENT_READ|EVENT_WRITE - """ - raise NotImplementedError - - def close(self): - """Close the selector. - - This must be called to make sure that any underlying resource is freed. - """ - pass - - def get_key(self, fileobj): - """Return the key associated to a registered file object. - - Returns: - SelectorKey for this file object - """ - mapping = self.get_map() - if mapping is None: - raise RuntimeError('Selector is closed') - try: - return mapping[fileobj] - except KeyError: - raise KeyError("{0!r} is not registered".format(fileobj)) - - @abstractmethod - def get_map(self): - """Return a mapping of file objects to selector keys.""" - raise NotImplementedError - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() - - -class _BaseSelectorImpl(BaseSelector): - """Base selector implementation.""" - - def __init__(self): - # this maps file descriptors to keys - self._fd_to_key = {} - # read-only mapping returned by get_map() - self._map = _SelectorMapping(self) - - def _fileobj_lookup(self, fileobj): - """Return a file descriptor from a file object. - - This wraps _fileobj_to_fd() to do an exhaustive search in case - the object is invalid but we still have it in our map. This - is used by unregister() so we can unregister an object that - was previously registered even if it is closed. It is also - used by _SelectorMapping. - """ - try: - return _fileobj_to_fd(fileobj) - except ValueError: - # Do an exhaustive search. - for key in self._fd_to_key.values(): - if key.fileobj is fileobj: - return key.fd - # Raise ValueError after all. - raise - - def register(self, fileobj, events, data=None): - if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): - raise ValueError("Invalid events: {0!r}".format(events)) - - key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data) - - if key.fd in self._fd_to_key: - raise KeyError("{0!r} (FD {1}) is already registered" - .format(fileobj, key.fd)) - - self._fd_to_key[key.fd] = key - return key - - def unregister(self, fileobj): - try: - key = self._fd_to_key.pop(self._fileobj_lookup(fileobj)) - except KeyError: - raise KeyError("{0!r} is not registered".format(fileobj)) - return key - - def modify(self, fileobj, events, data=None): - # TODO: Subclasses can probably optimize this even further. - try: - key = self._fd_to_key[self._fileobj_lookup(fileobj)] - except KeyError: - raise KeyError("{0!r} is not registered".format(fileobj)) - if events != key.events: - self.unregister(fileobj) - key = self.register(fileobj, events, data) - elif data != key.data: - # Use a shortcut to update the data. - key = key._replace(data=data) - self._fd_to_key[key.fd] = key - return key - - def close(self): - self._fd_to_key.clear() - self._map = None - - def get_map(self): - return self._map - - def _key_from_fd(self, fd): - """Return the key associated to a given file descriptor. 
- - Parameters: - fd -- file descriptor - - Returns: - corresponding key, or None if not found - """ - try: - return self._fd_to_key[fd] - except KeyError: - return None - - -class SelectSelector(_BaseSelectorImpl): - """Select-based selector.""" - - def __init__(self): - super(SelectSelector, self).__init__() - self._readers = set() - self._writers = set() - - def register(self, fileobj, events, data=None): - key = super(SelectSelector, self).register(fileobj, events, data) - if events & EVENT_READ: - self._readers.add(key.fd) - if events & EVENT_WRITE: - self._writers.add(key.fd) - return key - - def unregister(self, fileobj): - key = super(SelectSelector, self).unregister(fileobj) - self._readers.discard(key.fd) - self._writers.discard(key.fd) - return key - - if sys.platform == 'win32': - def _select(self, r, w, _, timeout=None): - r, w, x = select.select(r, w, w, timeout) - return r, w + x, [] - else: - _select = staticmethod(select.select) - - def select(self, timeout=None): - timeout = None if timeout is None else max(timeout, 0) - ready = [] - try: - r, w, _ = self._select(self._readers, self._writers, [], timeout) - except select.error as exc: - if exc.args[0] == EINTR: - return ready - else: - raise - r = set(r) - w = set(w) - for fd in r | w: - events = 0 - if fd in r: - events |= EVENT_READ - if fd in w: - events |= EVENT_WRITE - - key = self._key_from_fd(fd) - if key: - ready.append((key, events & key.events)) - return ready - - -if hasattr(select, 'poll'): - - class PollSelector(_BaseSelectorImpl): - """Poll-based selector.""" - - def __init__(self): - super(PollSelector, self).__init__() - self._poll = select.poll() - - def register(self, fileobj, events, data=None): - key = super(PollSelector, self).register(fileobj, events, data) - poll_events = 0 - if events & EVENT_READ: - poll_events |= select.POLLIN - if events & EVENT_WRITE: - poll_events |= select.POLLOUT - self._poll.register(key.fd, poll_events) - return key - - def unregister(self, fileobj): - key = super(PollSelector, self).unregister(fileobj) - self._poll.unregister(key.fd) - return key - - def select(self, timeout=None): - if timeout is None: - timeout = None - elif timeout <= 0: - timeout = 0 - else: - # poll() has a resolution of 1 millisecond, round away from - # zero to wait *at least* timeout seconds. 
- timeout = int(math.ceil(timeout * 1e3)) - ready = [] - try: - fd_event_list = self._poll.poll(timeout) - except select.error as exc: - if exc.args[0] == EINTR: - return ready - else: - raise - for fd, event in fd_event_list: - events = 0 - if event & ~select.POLLIN: - events |= EVENT_WRITE - if event & ~select.POLLOUT: - events |= EVENT_READ - - key = self._key_from_fd(fd) - if key: - ready.append((key, events & key.events)) - return ready - - -if hasattr(select, 'epoll'): - - class EpollSelector(_BaseSelectorImpl): - """Epoll-based selector.""" - - def __init__(self): - super(EpollSelector, self).__init__() - self._epoll = select.epoll() - - def fileno(self): - return self._epoll.fileno() - - def register(self, fileobj, events, data=None): - key = super(EpollSelector, self).register(fileobj, events, data) - epoll_events = 0 - if events & EVENT_READ: - epoll_events |= select.EPOLLIN - if events & EVENT_WRITE: - epoll_events |= select.EPOLLOUT - self._epoll.register(key.fd, epoll_events) - return key - - def unregister(self, fileobj): - key = super(EpollSelector, self).unregister(fileobj) - try: - self._epoll.unregister(key.fd) - except IOError: - # This can happen if the FD was closed since it - # was registered. - pass - return key - - def select(self, timeout=None): - if timeout is None: - timeout = -1 - elif timeout <= 0: - timeout = 0 - else: - # epoll_wait() has a resolution of 1 millisecond, round away - # from zero to wait *at least* timeout seconds. - timeout = math.ceil(timeout * 1e3) * 1e-3 - - # epoll_wait() expects `maxevents` to be greater than zero; - # we want to make sure that `select()` can be called when no - # FD is registered. - max_ev = max(len(self._fd_to_key), 1) - - ready = [] - try: - fd_event_list = self._epoll.poll(timeout, max_ev) - except IOError as exc: - if exc.errno == EINTR: - return ready - else: - raise - for fd, event in fd_event_list: - events = 0 - if event & ~select.EPOLLIN: - events |= EVENT_WRITE - if event & ~select.EPOLLOUT: - events |= EVENT_READ - - key = self._key_from_fd(fd) - if key: - ready.append((key, events & key.events)) - return ready - - def close(self): - self._epoll.close() - super(EpollSelector, self).close() - - -if hasattr(select, 'devpoll'): - - class DevpollSelector(_BaseSelectorImpl): - """Solaris /dev/poll selector.""" - - def __init__(self): - super(DevpollSelector, self).__init__() - self._devpoll = select.devpoll() - - def fileno(self): - return self._devpoll.fileno() - - def register(self, fileobj, events, data=None): - key = super(DevpollSelector, self).register(fileobj, events, data) - poll_events = 0 - if events & EVENT_READ: - poll_events |= select.POLLIN - if events & EVENT_WRITE: - poll_events |= select.POLLOUT - self._devpoll.register(key.fd, poll_events) - return key - - def unregister(self, fileobj): - key = super(DevpollSelector, self).unregister(fileobj) - self._devpoll.unregister(key.fd) - return key - - def select(self, timeout=None): - if timeout is None: - timeout = None - elif timeout <= 0: - timeout = 0 - else: - # devpoll() has a resolution of 1 millisecond, round away from - # zero to wait *at least* timeout seconds. 
- timeout = math.ceil(timeout * 1e3) - ready = [] - try: - fd_event_list = self._devpoll.poll(timeout) - except OSError as exc: - if exc.errno == EINTR: - return ready - else: - raise - for fd, event in fd_event_list: - events = 0 - if event & ~select.POLLIN: - events |= EVENT_WRITE - if event & ~select.POLLOUT: - events |= EVENT_READ - - key = self._key_from_fd(fd) - if key: - ready.append((key, events & key.events)) - return ready - - def close(self): - self._devpoll.close() - super(DevpollSelector, self).close() - - -if hasattr(select, 'kqueue'): - - class KqueueSelector(_BaseSelectorImpl): - """Kqueue-based selector.""" - - def __init__(self): - super(KqueueSelector, self).__init__() - self._kqueue = select.kqueue() - - def fileno(self): - return self._kqueue.fileno() - - def register(self, fileobj, events, data=None): - key = super(KqueueSelector, self).register(fileobj, events, data) - if events & EVENT_READ: - kev = select.kevent(key.fd, select.KQ_FILTER_READ, - select.KQ_EV_ADD) - self._kqueue.control([kev], 0, 0) - if events & EVENT_WRITE: - kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, - select.KQ_EV_ADD) - self._kqueue.control([kev], 0, 0) - return key - - def unregister(self, fileobj): - key = super(KqueueSelector, self).unregister(fileobj) - if key.events & EVENT_READ: - kev = select.kevent(key.fd, select.KQ_FILTER_READ, - select.KQ_EV_DELETE) - try: - self._kqueue.control([kev], 0, 0) - except OSError: - # This can happen if the FD was closed since it - # was registered. - pass - if key.events & EVENT_WRITE: - kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, - select.KQ_EV_DELETE) - try: - self._kqueue.control([kev], 0, 0) - except OSError: - # See comment above. - pass - return key - - def select(self, timeout=None): - timeout = None if timeout is None else max(timeout, 0) - max_ev = len(self._fd_to_key) - ready = [] - try: - kev_list = self._kqueue.control(None, max_ev, timeout) - except OSError as exc: - if exc.errno == EINTR: - return ready - else: - raise - for kev in kev_list: - fd = kev.ident - flag = kev.filter - events = 0 - if flag == select.KQ_FILTER_READ: - events |= EVENT_READ - if flag == select.KQ_FILTER_WRITE: - events |= EVENT_WRITE - - key = self._key_from_fd(fd) - if key: - ready.append((key, events & key.events)) - return ready - - def close(self): - self._kqueue.close() - super(KqueueSelector, self).close() - - -# Choose the best implementation, roughly: -# epoll|kqueue|devpoll > poll > select. 
-# select() also can't accept a FD > FD_SETSIZE (usually around 1024) -if 'KqueueSelector' in globals(): - DefaultSelector = KqueueSelector -elif 'EpollSelector' in globals(): - DefaultSelector = EpollSelector -elif 'DevpollSelector' in globals(): - DefaultSelector = DevpollSelector -elif 'PollSelector' in globals(): - DefaultSelector = PollSelector -else: - DefaultSelector = SelectSelector diff --git a/kafka/vendor/six.py b/kafka/vendor/six.py deleted file mode 100644 index 3621a0ab..00000000 --- a/kafka/vendor/six.py +++ /dev/null @@ -1,897 +0,0 @@ -# pylint: skip-file - -# Copyright (c) 2010-2017 Benjamin Peterson -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -"""Utilities for writing code that runs on Python 2 and 3""" - -from __future__ import absolute_import - -import functools -import itertools -import operator -import sys -import types - -__author__ = "Benjamin Peterson " -__version__ = "1.11.0" - - -# Useful for very coarse version differentiation. -PY2 = sys.version_info[0] == 2 -PY3 = sys.version_info[0] == 3 -PY34 = sys.version_info[0:2] >= (3, 4) - -if PY3: - string_types = str, - integer_types = int, - class_types = type, - text_type = str - binary_type = bytes - - MAXSIZE = sys.maxsize -else: - string_types = basestring, - integer_types = (int, long) - class_types = (type, types.ClassType) - text_type = unicode - binary_type = str - - if sys.platform.startswith("java"): - # Jython always uses 32 bits. - MAXSIZE = int((1 << 31) - 1) - else: - # It's possible to have sizeof(long) != sizeof(Py_ssize_t). - class X(object): - - def __len__(self): - return 1 << 31 - try: - len(X()) - except OverflowError: - # 32-bit - MAXSIZE = int((1 << 31) - 1) - else: - # 64-bit - MAXSIZE = int((1 << 63) - 1) - - # Don't del it here, cause with gc disabled this "leaks" to garbage. - # Note: This is a kafka-python customization, details at: - # https://github.com/dpkp/kafka-python/pull/979#discussion_r100403389 - # del X - - -def _add_doc(func, doc): - """Add documentation to a function.""" - func.__doc__ = doc - - -def _import_module(name): - """Import module, returning the module after the last dot.""" - __import__(name) - return sys.modules[name] - - -class _LazyDescr(object): - - def __init__(self, name): - self.name = name - - def __get__(self, obj, tp): - result = self._resolve() - setattr(obj, self.name, result) # Invokes __set__. - try: - # This is a bit ugly, but it avoids running this again by - # removing this descriptor. 
- delattr(obj.__class__, self.name) - except AttributeError: - pass - return result - - -class MovedModule(_LazyDescr): - - def __init__(self, name, old, new=None): - super(MovedModule, self).__init__(name) - if PY3: - if new is None: - new = name - self.mod = new - else: - self.mod = old - - def _resolve(self): - return _import_module(self.mod) - - def __getattr__(self, attr): - _module = self._resolve() - value = getattr(_module, attr) - setattr(self, attr, value) - return value - - -class _LazyModule(types.ModuleType): - - def __init__(self, name): - super(_LazyModule, self).__init__(name) - self.__doc__ = self.__class__.__doc__ - - def __dir__(self): - attrs = ["__doc__", "__name__"] - attrs += [attr.name for attr in self._moved_attributes] - return attrs - - # Subclasses should override this - _moved_attributes = [] - - -class MovedAttribute(_LazyDescr): - - def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None): - super(MovedAttribute, self).__init__(name) - if PY3: - if new_mod is None: - new_mod = name - self.mod = new_mod - if new_attr is None: - if old_attr is None: - new_attr = name - else: - new_attr = old_attr - self.attr = new_attr - else: - self.mod = old_mod - if old_attr is None: - old_attr = name - self.attr = old_attr - - def _resolve(self): - module = _import_module(self.mod) - return getattr(module, self.attr) - - -class _SixMetaPathImporter(object): - - """ - A meta path importer to import six.moves and its submodules. - - This class implements a PEP302 finder and loader. It should be compatible - with Python 2.5 and all existing versions of Python3 - """ - - def __init__(self, six_module_name): - self.name = six_module_name - self.known_modules = {} - - def _add_module(self, mod, *fullnames): - for fullname in fullnames: - self.known_modules[self.name + "." + fullname] = mod - - def _get_module(self, fullname): - return self.known_modules[self.name + "." + fullname] - - def find_module(self, fullname, path=None): - if fullname in self.known_modules: - return self - return None - - def __get_module(self, fullname): - try: - return self.known_modules[fullname] - except KeyError: - raise ImportError("This loader does not know module " + fullname) - - def load_module(self, fullname): - try: - # in case of a reload - return sys.modules[fullname] - except KeyError: - pass - mod = self.__get_module(fullname) - if isinstance(mod, MovedModule): - mod = mod._resolve() - else: - mod.__loader__ = self - sys.modules[fullname] = mod - return mod - - def is_package(self, fullname): - """ - Return true, if the named module is a package. 
- - We need this method to get correct spec objects with - Python 3.4 (see PEP451) - """ - return hasattr(self.__get_module(fullname), "__path__") - - def get_code(self, fullname): - """Return None - - Required, if is_package is implemented""" - self.__get_module(fullname) # eventually raises ImportError - return None - get_source = get_code # same as get_code - -_importer = _SixMetaPathImporter(__name__) - - -class _MovedItems(_LazyModule): - - """Lazy loading of moved objects""" - __path__ = [] # mark as package - - -_moved_attributes = [ - MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), - MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), - MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"), - MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), - MovedAttribute("intern", "__builtin__", "sys"), - MovedAttribute("map", "itertools", "builtins", "imap", "map"), - MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"), - MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"), - MovedAttribute("getoutput", "commands", "subprocess"), - MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"), - MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"), - MovedAttribute("reduce", "__builtin__", "functools"), - MovedAttribute("shlex_quote", "pipes", "shlex", "quote"), - MovedAttribute("StringIO", "StringIO", "io"), - MovedAttribute("UserDict", "UserDict", "collections"), - MovedAttribute("UserList", "UserList", "collections"), - MovedAttribute("UserString", "UserString", "collections"), - MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), - MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), - MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"), - MovedModule("builtins", "__builtin__"), - MovedModule("configparser", "ConfigParser"), - MovedModule("copyreg", "copy_reg"), - MovedModule("dbm_gnu", "gdbm", "dbm.gnu"), - MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread"), - MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), - MovedModule("http_cookies", "Cookie", "http.cookies"), - MovedModule("html_entities", "htmlentitydefs", "html.entities"), - MovedModule("html_parser", "HTMLParser", "html.parser"), - MovedModule("http_client", "httplib", "http.client"), - MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"), - MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"), - MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"), - MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"), - MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"), - MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"), - MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"), - MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"), - MovedModule("cPickle", "cPickle", "pickle"), - MovedModule("queue", "Queue"), - MovedModule("reprlib", "repr"), - MovedModule("socketserver", "SocketServer"), - MovedModule("_thread", "thread", "_thread"), - MovedModule("tkinter", "Tkinter"), - MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"), - MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"), - MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"), - MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"), - 
MovedModule("tkinter_tix", "Tix", "tkinter.tix"), - MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"), - MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"), - MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"), - MovedModule("tkinter_colorchooser", "tkColorChooser", - "tkinter.colorchooser"), - MovedModule("tkinter_commondialog", "tkCommonDialog", - "tkinter.commondialog"), - MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"), - MovedModule("tkinter_font", "tkFont", "tkinter.font"), - MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"), - MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", - "tkinter.simpledialog"), - MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"), - MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"), - MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"), - MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"), - MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"), - MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"), -] -# Add windows specific modules. -if sys.platform == "win32": - _moved_attributes += [ - MovedModule("winreg", "_winreg"), - ] - -for attr in _moved_attributes: - setattr(_MovedItems, attr.name, attr) - if isinstance(attr, MovedModule): - _importer._add_module(attr, "moves." + attr.name) -del attr - -_MovedItems._moved_attributes = _moved_attributes - -moves = _MovedItems(__name__ + ".moves") -_importer._add_module(moves, "moves") - - -class Module_six_moves_urllib_parse(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_parse""" - - -_urllib_parse_moved_attributes = [ - MovedAttribute("ParseResult", "urlparse", "urllib.parse"), - MovedAttribute("SplitResult", "urlparse", "urllib.parse"), - MovedAttribute("parse_qs", "urlparse", "urllib.parse"), - MovedAttribute("parse_qsl", "urlparse", "urllib.parse"), - MovedAttribute("urldefrag", "urlparse", "urllib.parse"), - MovedAttribute("urljoin", "urlparse", "urllib.parse"), - MovedAttribute("urlparse", "urlparse", "urllib.parse"), - MovedAttribute("urlsplit", "urlparse", "urllib.parse"), - MovedAttribute("urlunparse", "urlparse", "urllib.parse"), - MovedAttribute("urlunsplit", "urlparse", "urllib.parse"), - MovedAttribute("quote", "urllib", "urllib.parse"), - MovedAttribute("quote_plus", "urllib", "urllib.parse"), - MovedAttribute("unquote", "urllib", "urllib.parse"), - MovedAttribute("unquote_plus", "urllib", "urllib.parse"), - MovedAttribute("unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"), - MovedAttribute("urlencode", "urllib", "urllib.parse"), - MovedAttribute("splitquery", "urllib", "urllib.parse"), - MovedAttribute("splittag", "urllib", "urllib.parse"), - MovedAttribute("splituser", "urllib", "urllib.parse"), - MovedAttribute("splitvalue", "urllib", "urllib.parse"), - MovedAttribute("uses_fragment", "urlparse", "urllib.parse"), - MovedAttribute("uses_netloc", "urlparse", "urllib.parse"), - MovedAttribute("uses_params", "urlparse", "urllib.parse"), - MovedAttribute("uses_query", "urlparse", "urllib.parse"), - MovedAttribute("uses_relative", "urlparse", "urllib.parse"), -] -for attr in _urllib_parse_moved_attributes: - setattr(Module_six_moves_urllib_parse, attr.name, attr) -del attr - -Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes - -_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"), - 
"moves.urllib_parse", "moves.urllib.parse") - - -class Module_six_moves_urllib_error(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_error""" - - -_urllib_error_moved_attributes = [ - MovedAttribute("URLError", "urllib2", "urllib.error"), - MovedAttribute("HTTPError", "urllib2", "urllib.error"), - MovedAttribute("ContentTooShortError", "urllib", "urllib.error"), -] -for attr in _urllib_error_moved_attributes: - setattr(Module_six_moves_urllib_error, attr.name, attr) -del attr - -Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes - -_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"), - "moves.urllib_error", "moves.urllib.error") - - -class Module_six_moves_urllib_request(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_request""" - - -_urllib_request_moved_attributes = [ - MovedAttribute("urlopen", "urllib2", "urllib.request"), - MovedAttribute("install_opener", "urllib2", "urllib.request"), - MovedAttribute("build_opener", "urllib2", "urllib.request"), - MovedAttribute("pathname2url", "urllib", "urllib.request"), - MovedAttribute("url2pathname", "urllib", "urllib.request"), - MovedAttribute("getproxies", "urllib", "urllib.request"), - MovedAttribute("Request", "urllib2", "urllib.request"), - MovedAttribute("OpenerDirector", "urllib2", "urllib.request"), - MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"), - MovedAttribute("ProxyHandler", "urllib2", "urllib.request"), - MovedAttribute("BaseHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"), - MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"), - MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"), - MovedAttribute("FileHandler", "urllib2", "urllib.request"), - MovedAttribute("FTPHandler", "urllib2", "urllib.request"), - MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"), - MovedAttribute("UnknownHandler", "urllib2", "urllib.request"), - MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"), - MovedAttribute("urlretrieve", "urllib", "urllib.request"), - MovedAttribute("urlcleanup", "urllib", "urllib.request"), - MovedAttribute("URLopener", "urllib", "urllib.request"), - MovedAttribute("FancyURLopener", "urllib", "urllib.request"), - MovedAttribute("proxy_bypass", "urllib", "urllib.request"), - MovedAttribute("parse_http_list", "urllib2", "urllib.request"), - MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"), -] -for attr in _urllib_request_moved_attributes: - setattr(Module_six_moves_urllib_request, attr.name, attr) -del attr - -Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes - -_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"), - "moves.urllib_request", "moves.urllib.request") - - 
-class Module_six_moves_urllib_response(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_response""" - - -_urllib_response_moved_attributes = [ - MovedAttribute("addbase", "urllib", "urllib.response"), - MovedAttribute("addclosehook", "urllib", "urllib.response"), - MovedAttribute("addinfo", "urllib", "urllib.response"), - MovedAttribute("addinfourl", "urllib", "urllib.response"), -] -for attr in _urllib_response_moved_attributes: - setattr(Module_six_moves_urllib_response, attr.name, attr) -del attr - -Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes - -_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"), - "moves.urllib_response", "moves.urllib.response") - - -class Module_six_moves_urllib_robotparser(_LazyModule): - - """Lazy loading of moved objects in six.moves.urllib_robotparser""" - - -_urllib_robotparser_moved_attributes = [ - MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"), -] -for attr in _urllib_robotparser_moved_attributes: - setattr(Module_six_moves_urllib_robotparser, attr.name, attr) -del attr - -Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes - -_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"), - "moves.urllib_robotparser", "moves.urllib.robotparser") - - -class Module_six_moves_urllib(types.ModuleType): - - """Create a six.moves.urllib namespace that resembles the Python 3 namespace""" - __path__ = [] # mark as package - parse = _importer._get_module("moves.urllib_parse") - error = _importer._get_module("moves.urllib_error") - request = _importer._get_module("moves.urllib_request") - response = _importer._get_module("moves.urllib_response") - robotparser = _importer._get_module("moves.urllib_robotparser") - - def __dir__(self): - return ['parse', 'error', 'request', 'response', 'robotparser'] - -_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"), - "moves.urllib") - - -def add_move(move): - """Add an item to six.moves.""" - setattr(_MovedItems, move.name, move) - - -def remove_move(name): - """Remove item from six.moves.""" - try: - delattr(_MovedItems, name) - except AttributeError: - try: - del moves.__dict__[name] - except KeyError: - raise AttributeError("no such move, %r" % (name,)) - - -if PY3: - _meth_func = "__func__" - _meth_self = "__self__" - - _func_closure = "__closure__" - _func_code = "__code__" - _func_defaults = "__defaults__" - _func_globals = "__globals__" -else: - _meth_func = "im_func" - _meth_self = "im_self" - - _func_closure = "func_closure" - _func_code = "func_code" - _func_defaults = "func_defaults" - _func_globals = "func_globals" - - -try: - advance_iterator = next -except NameError: - def advance_iterator(it): - return it.next() -next = advance_iterator - - -try: - callable = callable -except NameError: - def callable(obj): - return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) - - -if PY3: - def get_unbound_function(unbound): - return unbound - - create_bound_method = types.MethodType - - def create_unbound_method(func, cls): - return func - - Iterator = object -else: - def get_unbound_function(unbound): - return unbound.im_func - - def create_bound_method(func, obj): - return types.MethodType(func, obj, obj.__class__) - - def create_unbound_method(func, cls): - return types.MethodType(func, None, cls) - - class Iterator(object): - - def next(self): - return type(self).__next__(self) - - 
callable = callable -_add_doc(get_unbound_function, - """Get the function out of a possibly unbound function""") - - -get_method_function = operator.attrgetter(_meth_func) -get_method_self = operator.attrgetter(_meth_self) -get_function_closure = operator.attrgetter(_func_closure) -get_function_code = operator.attrgetter(_func_code) -get_function_defaults = operator.attrgetter(_func_defaults) -get_function_globals = operator.attrgetter(_func_globals) - - -if PY3: - def iterkeys(d, **kw): - return iter(d.keys(**kw)) - - def itervalues(d, **kw): - return iter(d.values(**kw)) - - def iteritems(d, **kw): - return iter(d.items(**kw)) - - def iterlists(d, **kw): - return iter(d.lists(**kw)) - - viewkeys = operator.methodcaller("keys") - - viewvalues = operator.methodcaller("values") - - viewitems = operator.methodcaller("items") -else: - def iterkeys(d, **kw): - return d.iterkeys(**kw) - - def itervalues(d, **kw): - return d.itervalues(**kw) - - def iteritems(d, **kw): - return d.iteritems(**kw) - - def iterlists(d, **kw): - return d.iterlists(**kw) - - viewkeys = operator.methodcaller("viewkeys") - - viewvalues = operator.methodcaller("viewvalues") - - viewitems = operator.methodcaller("viewitems") - -_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.") -_add_doc(itervalues, "Return an iterator over the values of a dictionary.") -_add_doc(iteritems, - "Return an iterator over the (key, value) pairs of a dictionary.") -_add_doc(iterlists, - "Return an iterator over the (key, [values]) pairs of a dictionary.") - - -if PY3: - def b(s): - return s.encode("latin-1") - - def u(s): - return s - unichr = chr - import struct - int2byte = struct.Struct(">B").pack - del struct - byte2int = operator.itemgetter(0) - indexbytes = operator.getitem - iterbytes = iter - import io - StringIO = io.StringIO - BytesIO = io.BytesIO - _assertCountEqual = "assertCountEqual" - if sys.version_info[1] <= 1: - _assertRaisesRegex = "assertRaisesRegexp" - _assertRegex = "assertRegexpMatches" - else: - _assertRaisesRegex = "assertRaisesRegex" - _assertRegex = "assertRegex" -else: - def b(s): - return s - # Workaround for standalone backslash - - def u(s): - return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape") - unichr = unichr - int2byte = chr - - def byte2int(bs): - return ord(bs[0]) - - def indexbytes(buf, i): - return ord(buf[i]) - iterbytes = functools.partial(itertools.imap, ord) - import StringIO - StringIO = BytesIO = StringIO.StringIO - _assertCountEqual = "assertItemsEqual" - _assertRaisesRegex = "assertRaisesRegexp" - _assertRegex = "assertRegexpMatches" -_add_doc(b, """Byte literal""") -_add_doc(u, """Text literal""") - - -def assertCountEqual(self, *args, **kwargs): - return getattr(self, _assertCountEqual)(*args, **kwargs) - - -def assertRaisesRegex(self, *args, **kwargs): - return getattr(self, _assertRaisesRegex)(*args, **kwargs) - - -def assertRegex(self, *args, **kwargs): - return getattr(self, _assertRegex)(*args, **kwargs) - - -if PY3: - exec_ = getattr(moves.builtins, "exec") - - def reraise(tp, value, tb=None): - try: - if value is None: - value = tp() - if value.__traceback__ is not tb: - raise value.with_traceback(tb) - raise value - finally: - value = None - tb = None - -else: - def exec_(_code_, _globs_=None, _locs_=None): - """Execute code in a namespace.""" - if _globs_ is None: - frame = sys._getframe(1) - _globs_ = frame.f_globals - if _locs_ is None: - _locs_ = frame.f_locals - del frame - elif _locs_ is None: - _locs_ = _globs_ - exec("""exec _code_ in _globs_, 
_locs_""") - - exec_("""def reraise(tp, value, tb=None): - try: - raise tp, value, tb - finally: - tb = None -""") - - -if sys.version_info[:2] == (3, 2): - exec_("""def raise_from(value, from_value): - try: - if from_value is None: - raise value - raise value from from_value - finally: - value = None -""") -elif sys.version_info[:2] > (3, 2): - exec_("""def raise_from(value, from_value): - try: - raise value from from_value - finally: - value = None -""") -else: - def raise_from(value, from_value): - raise value - - -print_ = getattr(moves.builtins, "print", None) -if print_ is None: - def print_(*args, **kwargs): - """The new-style print function for Python 2.4 and 2.5.""" - fp = kwargs.pop("file", sys.stdout) - if fp is None: - return - - def write(data): - if not isinstance(data, basestring): - data = str(data) - # If the file has an encoding, encode unicode with it. - if (isinstance(fp, file) and - isinstance(data, unicode) and - fp.encoding is not None): - errors = getattr(fp, "errors", None) - if errors is None: - errors = "strict" - data = data.encode(fp.encoding, errors) - fp.write(data) - want_unicode = False - sep = kwargs.pop("sep", None) - if sep is not None: - if isinstance(sep, unicode): - want_unicode = True - elif not isinstance(sep, str): - raise TypeError("sep must be None or a string") - end = kwargs.pop("end", None) - if end is not None: - if isinstance(end, unicode): - want_unicode = True - elif not isinstance(end, str): - raise TypeError("end must be None or a string") - if kwargs: - raise TypeError("invalid keyword arguments to print()") - if not want_unicode: - for arg in args: - if isinstance(arg, unicode): - want_unicode = True - break - if want_unicode: - newline = unicode("\n") - space = unicode(" ") - else: - newline = "\n" - space = " " - if sep is None: - sep = space - if end is None: - end = newline - for i, arg in enumerate(args): - if i: - write(sep) - write(arg) - write(end) -if sys.version_info[:2] < (3, 3): - _print = print_ - - def print_(*args, **kwargs): - fp = kwargs.get("file", sys.stdout) - flush = kwargs.pop("flush", False) - _print(*args, **kwargs) - if flush and fp is not None: - fp.flush() - -_add_doc(reraise, """Reraise an exception.""") - -if sys.version_info[0:2] < (3, 4): - def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, - updated=functools.WRAPPER_UPDATES): - def wrapper(f): - f = functools.wraps(wrapped, assigned, updated)(f) - f.__wrapped__ = wrapped - return f - return wrapper -else: - wraps = functools.wraps - - -def with_metaclass(meta, *bases): - """Create a base class with a metaclass.""" - # This requires a bit of explanation: the basic idea is to make a dummy - # metaclass for one level of class instantiation that replaces itself with - # the actual metaclass. 
- class metaclass(type): - - def __new__(cls, name, this_bases, d): - return meta(name, bases, d) - - @classmethod - def __prepare__(cls, name, this_bases): - return meta.__prepare__(name, bases) - return type.__new__(metaclass, 'temporary_class', (), {}) - - -def add_metaclass(metaclass): - """Class decorator for creating a class with a metaclass.""" - def wrapper(cls): - orig_vars = cls.__dict__.copy() - slots = orig_vars.get('__slots__') - if slots is not None: - if isinstance(slots, str): - slots = [slots] - for slots_var in slots: - orig_vars.pop(slots_var) - orig_vars.pop('__dict__', None) - orig_vars.pop('__weakref__', None) - return metaclass(cls.__name__, cls.__bases__, orig_vars) - return wrapper - - -def python_2_unicode_compatible(klass): - """ - A decorator that defines __unicode__ and __str__ methods under Python 2. - Under Python 3 it does nothing. - - To support Python 2 and 3 with a single code base, define a __str__ method - returning text and apply this decorator to the class. - """ - if PY2: - if '__str__' not in klass.__dict__: - raise ValueError("@python_2_unicode_compatible cannot be applied " - "to %s because it doesn't define __str__()." % - klass.__name__) - klass.__unicode__ = klass.__str__ - klass.__str__ = lambda self: self.__unicode__().encode('utf-8') - return klass - - -# Complete the moves implementation. -# This code is at the end of this module to speed up module loading. -# Turn this module into a package. -__path__ = [] # required for PEP 302 and PEP 451 -__package__ = __name__ # see PEP 366 @ReservedAssignment -if globals().get("__spec__") is not None: - __spec__.submodule_search_locations = [] # PEP 451 @UndefinedVariable -# Remove other six meta path importers, since they cause problems. This can -# happen if six is removed from sys.modules and then reloaded. (Setuptools does -# this for some reason.) -if sys.meta_path: - for i, importer in enumerate(sys.meta_path): - # Here's some real nastiness: Another "instance" of the six module might - # be floating around. Therefore, we can't use isinstance() to check for - # the six meta path importer, since the other six instance will have - # inserted an importer with different class. - if (type(importer).__name__ == "_SixMetaPathImporter" and - importer.name == __name__): - del sys.meta_path[i] - break - del i, importer -# Finally, add the importer to the meta path import hook. -sys.meta_path.append(_importer) diff --git a/kafka/vendor/socketpair.py b/kafka/vendor/socketpair.py deleted file mode 100644 index b55e629e..00000000 --- a/kafka/vendor/socketpair.py +++ /dev/null @@ -1,58 +0,0 @@ -# pylint: skip-file -# vendored from https://github.com/mhils/backports.socketpair -from __future__ import absolute_import - -import sys -import socket -import errno - -_LOCALHOST = '127.0.0.1' -_LOCALHOST_V6 = '::1' - -if not hasattr(socket, "socketpair"): - # Origin: https://gist.github.com/4325783, by Geert Jansen. Public domain. - def socketpair(family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0): - if family == socket.AF_INET: - host = _LOCALHOST - elif family == socket.AF_INET6: - host = _LOCALHOST_V6 - else: - raise ValueError("Only AF_INET and AF_INET6 socket address families " - "are supported") - if type != socket.SOCK_STREAM: - raise ValueError("Only SOCK_STREAM socket type is supported") - if proto != 0: - raise ValueError("Only protocol zero is supported") - - # We create a connected TCP socket. Note the trick with - # setblocking(False) that prevents us from having to create a thread. 
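For context, once this backport module has been imported, socket.socketpair behaves like the built-in; a tiny sketch of the end result (on platforms that already provide socketpair, e.g. POSIX, the stdlib version is used and this shim never runs):

import socket

left, right = socket.socketpair()   # connected pair, no listener/thread needed
left.sendall(b"ping")
assert right.recv(4) == b"ping"
left.close()
right.close()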
- lsock = socket.socket(family, type, proto) - try: - lsock.bind((host, 0)) - lsock.listen(min(socket.SOMAXCONN, 128)) - # On IPv6, ignore flow_info and scope_id - addr, port = lsock.getsockname()[:2] - csock = socket.socket(family, type, proto) - try: - csock.setblocking(False) - if sys.version_info >= (3, 0): - try: - csock.connect((addr, port)) - except (BlockingIOError, InterruptedError): - pass - else: - try: - csock.connect((addr, port)) - except socket.error as e: - if e.errno != errno.WSAEWOULDBLOCK: - raise - csock.setblocking(True) - ssock, _ = lsock.accept() - except Exception: - csock.close() - raise - finally: - lsock.close() - return (ssock, csock) - - socket.socketpair = socketpair diff --git a/setup.py b/setup.py index fc69494b..07d9710c 100644 --- a/setup.py +++ b/setup.py @@ -172,7 +172,7 @@ def read_version(): }, download_url="https://pypi.python.org/pypi/aiokafka", license="Apache 2", - packages=["aiokafka", "kafka"], + packages=["aiokafka"], python_requires=">=3.8", install_requires=install_requires, extras_require=extras_require, diff --git a/tests/kafka/__init__.py b/tests/kafka/__init__.py deleted file mode 100644 index 329277dc..00000000 --- a/tests/kafka/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from __future__ import absolute_import - -# Set default logging handler to avoid "No handler found" warnings. -import logging -logging.basicConfig(level=logging.INFO) - -from kafka.future import Future -Future.error_on_callbacks = True # always fail during testing diff --git a/tests/kafka/conftest.py b/tests/kafka/conftest.py deleted file mode 100644 index 04aec4b8..00000000 --- a/tests/kafka/conftest.py +++ /dev/null @@ -1,140 +0,0 @@ -from __future__ import absolute_import - -import uuid - -import pytest - -from tests.kafka.testutil import env_kafka_version, random_string -from tests.kafka.fixtures import KafkaFixture, ZookeeperFixture - -@pytest.fixture(scope="module") -def zookeeper(): - """Return a Zookeeper fixture""" - zk_instance = ZookeeperFixture.instance() - yield zk_instance - zk_instance.close() - - -@pytest.fixture(scope="module") -def kafka_broker(kafka_broker_factory): - """Return a Kafka broker fixture""" - return kafka_broker_factory()[0] - - -@pytest.fixture(scope="module") -def kafka_broker_factory(zookeeper): - """Return a Kafka broker fixture factory""" - assert env_kafka_version(), 'KAFKA_VERSION must be specified to run integration tests' - - _brokers = [] - def factory(**broker_params): - params = {} if broker_params is None else broker_params.copy() - params.setdefault('partitions', 4) - num_brokers = params.pop('num_brokers', 1) - brokers = tuple(KafkaFixture.instance(x, zookeeper, **params) - for x in range(num_brokers)) - _brokers.extend(brokers) - return brokers - - yield factory - - for broker in _brokers: - broker.close() - - -@pytest.fixture -def kafka_consumer(kafka_consumer_factory): - """Return a KafkaConsumer fixture""" - return kafka_consumer_factory() - - -@pytest.fixture -def kafka_consumer_factory(kafka_broker, topic, request): - """Return a KafkaConsumer factory fixture""" - _consumer = [None] - - def factory(**kafka_consumer_params): - params = {} if kafka_consumer_params is None else kafka_consumer_params.copy() - params.setdefault('client_id', 'consumer_%s' % (request.node.name,)) - params.setdefault('auto_offset_reset', 'earliest') - _consumer[0] = next(kafka_broker.get_consumers(cnt=1, topics=[topic], **params)) - return _consumer[0] - - yield factory - - if _consumer[0]: - _consumer[0].close() - - -@pytest.fixture -def 
kafka_producer(kafka_producer_factory): - """Return a KafkaProducer fixture""" - yield kafka_producer_factory() - - -@pytest.fixture -def kafka_producer_factory(kafka_broker, request): - """Return a KafkaProduce factory fixture""" - _producer = [None] - - def factory(**kafka_producer_params): - params = {} if kafka_producer_params is None else kafka_producer_params.copy() - params.setdefault('client_id', 'producer_%s' % (request.node.name,)) - _producer[0] = next(kafka_broker.get_producers(cnt=1, **params)) - return _producer[0] - - yield factory - - if _producer[0]: - _producer[0].close() - -@pytest.fixture -def kafka_admin_client(kafka_admin_client_factory): - """Return a KafkaAdminClient fixture""" - yield kafka_admin_client_factory() - -@pytest.fixture -def kafka_admin_client_factory(kafka_broker): - """Return a KafkaAdminClient factory fixture""" - _admin_client = [None] - - def factory(**kafka_admin_client_params): - params = {} if kafka_admin_client_params is None else kafka_admin_client_params.copy() - _admin_client[0] = next(kafka_broker.get_admin_clients(cnt=1, **params)) - return _admin_client[0] - - yield factory - - if _admin_client[0]: - _admin_client[0].close() - -@pytest.fixture -def topic(kafka_broker, request): - """Return a topic fixture""" - topic_name = '%s_%s' % (request.node.name, random_string(10)) - kafka_broker.create_topics([topic_name]) - return topic_name - - -@pytest.fixture() -def send_messages(topic, kafka_producer, request): - """A factory that returns a send_messages function with a pre-populated - topic topic / producer.""" - - def _send_messages(number_range, partition=0, topic=topic, producer=kafka_producer, request=request): - """ - messages is typically `range(0,100)` - partition is an int - """ - messages_and_futures = [] # [(message, produce_future),] - for i in number_range: - # request.node.name provides the test name (including parametrized values) - encoded_msg = '{}-{}-{}'.format(i, request.node.name, uuid.uuid4()).encode('utf-8') - future = kafka_producer.send(topic, value=encoded_msg, partition=partition) - messages_and_futures.append((encoded_msg, future)) - kafka_producer.flush() - for (msg, f) in messages_and_futures: - assert f.succeeded() - return [msg for (msg, f) in messages_and_futures] - - return _send_messages diff --git a/tests/kafka/fixtures.py b/tests/kafka/fixtures.py deleted file mode 100644 index b6854e54..00000000 --- a/tests/kafka/fixtures.py +++ /dev/null @@ -1,651 +0,0 @@ -from __future__ import absolute_import - -import atexit -import logging -import os -import os.path -import socket -import subprocess -import time -import uuid - -import py -from kafka.vendor.six.moves import urllib, range -from kafka.vendor.six.moves.urllib.parse import urlparse # pylint: disable=E0611,F0401 - -from aiokafka import errors -from aiokafka.errors import InvalidReplicationFactorError -from aiokafka.protocol.admin import CreateTopicsRequest -from aiokafka.protocol.metadata import MetadataRequest -from tests.kafka.testutil import env_kafka_version, random_string -from tests.kafka.service import ExternalService, SpawnedService - -log = logging.getLogger(__name__) - - -def get_open_port(): - sock = socket.socket() - sock.bind(("127.0.0.1", 0)) - port = sock.getsockname()[1] - sock.close() - return port - - -def gen_ssl_resources(directory): - os.system(""" - cd {0} - echo Generating SSL resources in {0} - - # Step 1 - keytool -keystore kafka.server.keystore.jks -alias localhost -validity 1 \ - -genkey -storepass foobar -keypass foobar \ - -dname 
"CN=localhost, OU=kafka-python, O=kafka-python, L=SF, ST=CA, C=US" \ - -ext SAN=dns:localhost - - # Step 2 - openssl genrsa -out ca-key 2048 - openssl req -new -x509 -key ca-key -out ca-cert -days 1 \ - -subj "/C=US/ST=CA/O=MyOrg, Inc./CN=mydomain.com" - keytool -keystore kafka.server.truststore.jks -alias CARoot -import \ - -file ca-cert -storepass foobar -noprompt - - # Step 3 - keytool -keystore kafka.server.keystore.jks -alias localhost -certreq \ - -file cert-file -storepass foobar - openssl x509 -req -CA ca-cert -CAkey ca-key -in cert-file -out cert-signed \ - -days 1 -CAcreateserial -passin pass:foobar - keytool -keystore kafka.server.keystore.jks -alias CARoot -import \ - -file ca-cert -storepass foobar -noprompt - keytool -keystore kafka.server.keystore.jks -alias localhost -import \ - -file cert-signed -storepass foobar -noprompt - """.format(directory)) - - -class Fixture(object): - kafka_version = os.environ.get('KAFKA_VERSION', '0.11.0.2') - scala_version = os.environ.get("SCALA_VERSION", '2.8.0') - project_root = os.environ.get('PROJECT_ROOT', - os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - kafka_root = os.environ.get("KAFKA_ROOT", - os.path.join(project_root, 'servers', kafka_version, "kafka-bin")) - - def __init__(self): - self.child = None - - @classmethod - def download_official_distribution(cls, - kafka_version=None, - scala_version=None, - output_dir=None): - if not kafka_version: - kafka_version = cls.kafka_version - if not scala_version: - scala_version = cls.scala_version - if not output_dir: - output_dir = os.path.join(cls.project_root, 'servers', 'dist') - - distfile = 'kafka_%s-%s' % (scala_version, kafka_version,) - url_base = 'https://archive.apache.org/dist/kafka/%s/' % (kafka_version,) - output_file = os.path.join(output_dir, distfile + '.tgz') - - if os.path.isfile(output_file): - log.info("Found file already on disk: %s", output_file) - return output_file - - # New tarballs are .tgz, older ones are sometimes .tar.gz - try: - url = url_base + distfile + '.tgz' - log.info("Attempting to download %s", url) - response = urllib.request.urlopen(url) - except urllib.error.HTTPError: - log.exception("HTTP Error") - url = url_base + distfile + '.tar.gz' - log.info("Attempting to download %s", url) - response = urllib.request.urlopen(url) - - log.info("Saving distribution file to %s", output_file) - with open(output_file, 'w') as output_file_fd: - output_file_fd.write(response.read()) - - return output_file - - @classmethod - def test_resource(cls, filename): - return os.path.join(cls.project_root, "servers", cls.kafka_version, "resources", filename) - - @classmethod - def kafka_run_class_args(cls, *args): - result = [os.path.join(cls.kafka_root, 'bin', 'kafka-run-class.sh')] - result.extend([str(arg) for arg in args]) - return result - - def kafka_run_class_env(self): - env = os.environ.copy() - env['KAFKA_LOG4J_OPTS'] = "-Dlog4j.configuration=file:%s" % \ - (self.test_resource("log4j.properties"),) - return env - - @classmethod - def render_template(cls, source_file, target_file, binding): - log.info('Rendering %s from template %s', target_file.strpath, source_file) - with open(source_file, "r") as handle: - template = handle.read() - assert len(template) > 0, 'Empty template %s' % (source_file,) - with open(target_file.strpath, "w") as handle: - handle.write(template.format(**binding)) - handle.flush() - os.fsync(handle) - - # fsync directory for durability - # https://blog.gocept.com/2013/07/15/reliable-file-updates-with-python/ - dirfd = 
os.open(os.path.dirname(target_file.strpath), os.O_DIRECTORY) - os.fsync(dirfd) - os.close(dirfd) - log.debug("Template string:") - for line in template.splitlines(): - log.debug(' ' + line.strip()) - log.debug("Rendered template:") - with open(target_file.strpath, 'r') as o: - for line in o: - log.debug(' ' + line.strip()) - log.debug("binding:") - for key, value in binding.items(): - log.debug(" {key}={value}".format(key=key, value=value)) - - def dump_logs(self): - self.child.dump_logs() - - -class ZookeeperFixture(Fixture): - @classmethod - def instance(cls): - if "ZOOKEEPER_URI" in os.environ: - parse = urlparse(os.environ["ZOOKEEPER_URI"]) - (host, port) = (parse.hostname, parse.port) - fixture = ExternalService(host, port) - else: - (host, port) = ("127.0.0.1", None) - fixture = cls(host, port) - - fixture.open() - return fixture - - def __init__(self, host, port, tmp_dir=None): - super(ZookeeperFixture, self).__init__() - self.host = host - self.port = port - - self.tmp_dir = tmp_dir - - def kafka_run_class_env(self): - env = super(ZookeeperFixture, self).kafka_run_class_env() - env['LOG_DIR'] = self.tmp_dir.join('logs').strpath - return env - - def out(self, message): - log.info("*** Zookeeper [%s:%s]: %s", self.host, self.port or '(auto)', message) - - def open(self): - if self.tmp_dir is None: - self.tmp_dir = py.path.local.mkdtemp() #pylint: disable=no-member - self.tmp_dir.ensure(dir=True) - - self.out("Running local instance...") - log.info(" host = %s", self.host) - log.info(" port = %s", self.port or '(auto)') - log.info(" tmp_dir = %s", self.tmp_dir.strpath) - - # Configure Zookeeper child process - template = self.test_resource("zookeeper.properties") - properties = self.tmp_dir.join("zookeeper.properties") - args = self.kafka_run_class_args("org.apache.zookeeper.server.quorum.QuorumPeerMain", - properties.strpath) - env = self.kafka_run_class_env() - - # Party! 
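The render_template helper above flushes the rendered config and then fsyncs its parent directory so the file is durable before the child process reads it; a standalone sketch of that pattern (POSIX-only; the path and contents are illustrative):

import os

def write_durably(path, text):
    # Persist both the file contents and the directory entry that names it.
    with open(path, "w") as fh:
        fh.write(text)
        fh.flush()
        os.fsync(fh.fileno())
    dirfd = os.open(os.path.dirname(path) or ".", os.O_DIRECTORY)
    try:
        os.fsync(dirfd)
    finally:
        os.close(dirfd)

write_durably("/tmp/zookeeper.properties", "clientPort=2181\n")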
- timeout = 5 - max_timeout = 120 - backoff = 1 - end_at = time.time() + max_timeout - tries = 1 - auto_port = (self.port is None) - while time.time() < end_at: - if auto_port: - self.port = get_open_port() - self.out('Attempting to start on port %d (try #%d)' % (self.port, tries)) - self.render_template(template, properties, vars(self)) - self.child = SpawnedService(args, env) - self.child.start() - timeout = min(timeout, max(end_at - time.time(), 0)) - if self.child.wait_for(r"binding to port", timeout=timeout): - break - self.child.dump_logs() - self.child.stop() - timeout *= 2 - time.sleep(backoff) - tries += 1 - backoff += 1 - else: - raise RuntimeError('Failed to start Zookeeper before max_timeout') - self.out("Done!") - atexit.register(self.close) - - def close(self): - if self.child is None: - return - self.out("Stopping...") - self.child.stop() - self.child = None - self.out("Done!") - self.tmp_dir.remove() - - def __del__(self): - self.close() - - -class KafkaFixture(Fixture): - broker_user = 'alice' - broker_password = 'alice-secret' - - @classmethod - def instance(cls, broker_id, zookeeper, zk_chroot=None, - host=None, port=None, - transport='PLAINTEXT', replicas=1, partitions=2, - sasl_mechanism=None, auto_create_topic=True, tmp_dir=None): - - if zk_chroot is None: - zk_chroot = "kafka-python_" + str(uuid.uuid4()).replace("-", "_") - if "KAFKA_URI" in os.environ: - parse = urlparse(os.environ["KAFKA_URI"]) - (host, port) = (parse.hostname, parse.port) - fixture = ExternalService(host, port) - else: - if host is None: - host = "localhost" - fixture = KafkaFixture(host, port, broker_id, - zookeeper, zk_chroot, - transport=transport, - replicas=replicas, partitions=partitions, - sasl_mechanism=sasl_mechanism, - auto_create_topic=auto_create_topic, - tmp_dir=tmp_dir) - - fixture.open() - return fixture - - def __init__(self, host, port, broker_id, zookeeper, zk_chroot, - replicas=1, partitions=2, transport='PLAINTEXT', - sasl_mechanism=None, auto_create_topic=True, - tmp_dir=None): - super(KafkaFixture, self).__init__() - - self.host = host - self.port = port - - self.broker_id = broker_id - self.auto_create_topic = auto_create_topic - self.transport = transport.upper() - if sasl_mechanism is not None: - self.sasl_mechanism = sasl_mechanism.upper() - else: - self.sasl_mechanism = None - self.ssl_dir = self.test_resource('ssl') - - # TODO: checking for port connection would be better than scanning logs - # until then, we need the pattern to work across all supported broker versions - # The logging format changed slightly in 1.0.0 - self.start_pattern = r"\[Kafka ?Server (id=)?%d\],? 
started" % (broker_id,) - # Need to wait until the broker has fetched user configs from zookeeper in case we use scram as sasl mechanism - self.scram_pattern = r"Removing Produce quota for user %s" % (self.broker_user) - - self.zookeeper = zookeeper - self.zk_chroot = zk_chroot - # Add the attributes below for the template binding - self.zk_host = self.zookeeper.host - self.zk_port = self.zookeeper.port - - self.replicas = replicas - self.partitions = partitions - - self.tmp_dir = tmp_dir - self.running = False - - self._client = None - self.sasl_config = '' - self.jaas_config = '' - - def _sasl_config(self): - if not self.sasl_enabled: - return '' - - sasl_config = ( - 'sasl.enabled.mechanisms={mechanism}\n' - 'sasl.mechanism.inter.broker.protocol={mechanism}\n' - ) - return sasl_config.format(mechanism=self.sasl_mechanism) - - def _jaas_config(self): - if not self.sasl_enabled: - return '' - - elif self.sasl_mechanism == 'PLAIN': - jaas_config = ( - 'org.apache.kafka.common.security.plain.PlainLoginModule required\n' - ' username="{user}" password="{password}" user_{user}="{password}";\n' - ) - elif self.sasl_mechanism in ("SCRAM-SHA-256", "SCRAM-SHA-512"): - jaas_config = ( - 'org.apache.kafka.common.security.scram.ScramLoginModule required\n' - ' username="{user}" password="{password}";\n' - ) - else: - raise ValueError("SASL mechanism {} currently not supported".format(self.sasl_mechanism)) - return jaas_config.format(user=self.broker_user, password=self.broker_password) - - def _add_scram_user(self): - self.out("Adding SCRAM credentials for user {} to zookeeper.".format(self.broker_user)) - args = self.kafka_run_class_args( - "kafka.admin.ConfigCommand", - "--zookeeper", - "%s:%d/%s" % (self.zookeeper.host, - self.zookeeper.port, - self.zk_chroot), - "--alter", - "--entity-type", "users", - "--entity-name", self.broker_user, - "--add-config", - "{}=[password={}]".format(self.sasl_mechanism, self.broker_password), - ) - env = self.kafka_run_class_env() - proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - - stdout, stderr = proc.communicate() - - if proc.returncode != 0: - self.out("Failed to save credentials to zookeeper!") - self.out(stdout) - self.out(stderr) - raise RuntimeError("Failed to save credentials to zookeeper!") - self.out("User created.") - - @property - def sasl_enabled(self): - return self.sasl_mechanism is not None - - def bootstrap_server(self): - return '%s:%d' % (self.host, self.port) - - def kafka_run_class_env(self): - env = super(KafkaFixture, self).kafka_run_class_env() - env['LOG_DIR'] = self.tmp_dir.join('logs').strpath - return env - - def out(self, message): - log.info("*** Kafka [%s:%s]: %s", self.host, self.port or '(auto)', message) - - def _create_zk_chroot(self): - self.out("Creating Zookeeper chroot node...") - args = self.kafka_run_class_args("org.apache.zookeeper.ZooKeeperMain", - "-server", - "%s:%d" % (self.zookeeper.host, - self.zookeeper.port), - "create", - "/%s" % (self.zk_chroot,), - "kafka-python") - env = self.kafka_run_class_env() - proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - - stdout, stderr = proc.communicate() - - if proc.returncode != 0: - self.out("Failed to create Zookeeper chroot node") - self.out(stdout) - self.out(stderr) - raise RuntimeError("Failed to create Zookeeper chroot node") - self.out("Kafka chroot created in Zookeeper!") - - def start(self): - # Configure Kafka child process - properties = self.tmp_dir.join("kafka.properties") - jaas_conf 
= self.tmp_dir.join("kafka_server_jaas.conf") - properties_template = self.test_resource("kafka.properties") - jaas_conf_template = self.test_resource("kafka_server_jaas.conf") - - args = self.kafka_run_class_args("kafka.Kafka", properties.strpath) - env = self.kafka_run_class_env() - if self.sasl_enabled: - opts = env.get('KAFKA_OPTS', '').strip() - opts += ' -Djava.security.auth.login.config={}'.format(jaas_conf.strpath) - env['KAFKA_OPTS'] = opts - self.render_template(jaas_conf_template, jaas_conf, vars(self)) - - timeout = 5 - max_timeout = 120 - backoff = 1 - end_at = time.time() + max_timeout - tries = 1 - auto_port = (self.port is None) - while time.time() < end_at: - # We have had problems with port conflicts on travis - # so we will try a different port on each retry - # unless the fixture was passed a specific port - if auto_port: - self.port = get_open_port() - self.out('Attempting to start on port %d (try #%d)' % (self.port, tries)) - self.render_template(properties_template, properties, vars(self)) - - self.child = SpawnedService(args, env) - self.child.start() - timeout = min(timeout, max(end_at - time.time(), 0)) - if self._broker_ready(timeout) and self._scram_user_present(timeout): - break - - self.child.dump_logs() - self.child.stop() - - timeout *= 2 - time.sleep(backoff) - tries += 1 - backoff += 1 - else: - raise RuntimeError('Failed to start KafkaInstance before max_timeout') - - (self._client,) = self.get_clients(1, client_id='_internal_client') - - self.out("Done!") - self.running = True - - def _broker_ready(self, timeout): - return self.child.wait_for(self.start_pattern, timeout=timeout) - - def _scram_user_present(self, timeout): - # no need to wait for scram user if scram is not used - if not self.sasl_enabled or not self.sasl_mechanism.startswith('SCRAM-SHA-'): - return True - return self.child.wait_for(self.scram_pattern, timeout=timeout) - - def open(self): - if self.running: - self.out("Instance already running") - return - - # Create directories - if self.tmp_dir is None: - self.tmp_dir = py.path.local.mkdtemp() #pylint: disable=no-member - self.tmp_dir.ensure(dir=True) - self.tmp_dir.ensure('logs', dir=True) - self.tmp_dir.ensure('data', dir=True) - - self.out("Running local instance...") - log.info(" host = %s", self.host) - log.info(" port = %s", self.port or '(auto)') - log.info(" transport = %s", self.transport) - log.info(" sasl_mechanism = %s", self.sasl_mechanism) - log.info(" broker_id = %s", self.broker_id) - log.info(" zk_host = %s", self.zookeeper.host) - log.info(" zk_port = %s", self.zookeeper.port) - log.info(" zk_chroot = %s", self.zk_chroot) - log.info(" replicas = %s", self.replicas) - log.info(" partitions = %s", self.partitions) - log.info(" tmp_dir = %s", self.tmp_dir.strpath) - - self._create_zk_chroot() - self.sasl_config = self._sasl_config() - self.jaas_config = self._jaas_config() - # add user to zookeeper for the first server - if self.sasl_enabled and self.sasl_mechanism.startswith("SCRAM-SHA") and self.broker_id == 0: - self._add_scram_user() - self.start() - - atexit.register(self.close) - - def __del__(self): - self.close() - - def stop(self): - if not self.running: - self.out("Instance already stopped") - return - - self.out("Stopping...") - self.child.stop() - self.child = None - self.running = False - self.out("Stopped!") - - def close(self): - self.stop() - if self.tmp_dir is not None: - self.tmp_dir.remove() - self.tmp_dir = None - self.out("Done!") - - def dump_logs(self): - super(KafkaFixture, self).dump_logs() - 
self.zookeeper.dump_logs() - - def _send_request(self, request, timeout=None): - def _failure(error): - raise error - retries = 10 - while True: - node_id = self._client.least_loaded_node() - for connect_retry in range(40): - self._client.maybe_connect(node_id) - if self._client.connected(node_id): - break - self._client.poll(timeout_ms=100) - else: - raise RuntimeError('Could not connect to broker with node id %d' % (node_id,)) - - try: - future = self._client.send(node_id, request) - future.error_on_callbacks = True - future.add_errback(_failure) - self._client.poll(future=future, timeout_ms=timeout) - return future.value - except Exception as exc: - time.sleep(1) - retries -= 1 - if retries == 0: - raise exc - else: - pass # retry - - def _create_topic(self, topic_name, num_partitions=None, replication_factor=None, timeout_ms=10000): - if num_partitions is None: - num_partitions = self.partitions - if replication_factor is None: - replication_factor = self.replicas - - # Try different methods to create a topic, from the fastest to the slowest - if self.auto_create_topic and num_partitions == self.partitions and replication_factor == self.replicas: - self._create_topic_via_metadata(topic_name, timeout_ms) - elif env_kafka_version() >= (0, 10, 1, 0): - try: - self._create_topic_via_admin_api(topic_name, num_partitions, replication_factor, timeout_ms) - except InvalidReplicationFactorError: - # wait and try again - # on travis the brokers sometimes take a while to find themselves - time.sleep(0.5) - self._create_topic_via_admin_api(topic_name, num_partitions, replication_factor, timeout_ms) - else: - self._create_topic_via_cli(topic_name, num_partitions, replication_factor) - - def _create_topic_via_metadata(self, topic_name, timeout_ms=10000): - self._send_request(MetadataRequest[0]([topic_name]), timeout_ms) - - def _create_topic_via_admin_api(self, topic_name, num_partitions, replication_factor, timeout_ms=10000): - request = CreateTopicsRequest[0]([(topic_name, num_partitions, - replication_factor, [], [])], timeout_ms) - response = self._send_request(request, timeout=timeout_ms) - for topic_result in response.topic_errors: - error_code = topic_result[1] - if error_code != 0: - raise errors.for_code(error_code) - - def _create_topic_via_cli(self, topic_name, num_partitions, replication_factor): - args = self.kafka_run_class_args('kafka.admin.TopicCommand', - '--zookeeper', '%s:%s/%s' % (self.zookeeper.host, - self.zookeeper.port, - self.zk_chroot), - '--create', - '--topic', topic_name, - '--partitions', self.partitions \ - if num_partitions is None else num_partitions, - '--replication-factor', self.replicas \ - if replication_factor is None \ - else replication_factor) - if env_kafka_version() >= (0, 10): - args.append('--if-not-exists') - env = self.kafka_run_class_env() - proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout, stderr = proc.communicate() - if proc.returncode != 0: - if 'kafka.common.TopicExistsException' not in stdout: - self.out("Failed to create topic %s" % (topic_name,)) - self.out(stdout) - self.out(stderr) - raise RuntimeError("Failed to create topic %s" % (topic_name,)) - - def get_topic_names(self): - args = self.kafka_run_class_args('kafka.admin.TopicCommand', - '--zookeeper', '%s:%s/%s' % (self.zookeeper.host, - self.zookeeper.port, - self.zk_chroot), - '--list' - ) - env = self.kafka_run_class_env() - env.pop('KAFKA_LOG4J_OPTS') - proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, 
stderr=subprocess.PIPE) - stdout, stderr = proc.communicate() - if proc.returncode != 0: - self.out("Failed to list topics!") - self.out(stdout) - self.out(stderr) - raise RuntimeError("Failed to list topics!") - return stdout.decode().splitlines(False) - - def create_topics(self, topic_names, num_partitions=None, replication_factor=None): - for topic_name in topic_names: - self._create_topic(topic_name, num_partitions, replication_factor) - - def _enrich_client_params(self, params, **defaults): - params = params.copy() - for key, value in defaults.items(): - params.setdefault(key, value) - params.setdefault('bootstrap_servers', self.bootstrap_server()) - if self.sasl_enabled: - params.setdefault('sasl_mechanism', self.sasl_mechanism) - params.setdefault('security_protocol', self.transport) - if self.sasl_mechanism in ('PLAIN', 'SCRAM-SHA-256', 'SCRAM-SHA-512'): - params.setdefault('sasl_plain_username', self.broker_user) - params.setdefault('sasl_plain_password', self.broker_password) - return params - - @staticmethod - def _create_many_clients(cnt, cls, *args, **params): - client_id = params['client_id'] - for _ in range(cnt): - params['client_id'] = '%s_%s' % (client_id, random_string(4)) - yield cls(*args, **params) diff --git a/tests/kafka/service.py b/tests/kafka/service.py deleted file mode 100644 index 045d780e..00000000 --- a/tests/kafka/service.py +++ /dev/null @@ -1,133 +0,0 @@ -from __future__ import absolute_import - -import logging -import os -import re -import select -import subprocess -import sys -import threading -import time - -__all__ = [ - 'ExternalService', - 'SpawnedService', -] - -log = logging.getLogger(__name__) - - -class ExternalService(object): - def __init__(self, host, port): - log.info("Using already running service at %s:%d", host, port) - self.host = host - self.port = port - - def open(self): - pass - - def close(self): - pass - - -class SpawnedService(threading.Thread): - def __init__(self, args=None, env=None): - super(SpawnedService, self).__init__() - - if args is None: - raise TypeError("args parameter is required") - self.args = args - self.env = env - self.captured_stdout = [] - self.captured_stderr = [] - - self.should_die = threading.Event() - self.child = None - self.alive = False - self.daemon = True - log.info("Created service for command:") - log.info(" "+' '.join(self.args)) - log.debug("With environment:") - for key, value in self.env.items(): - log.debug(" {key}={value}".format(key=key, value=value)) - - def _spawn(self): - if self.alive: return - if self.child and self.child.poll() is None: return - - self.child = subprocess.Popen( - self.args, - preexec_fn=os.setsid, # to avoid propagating signals - env=self.env, - bufsize=1, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - self.alive = self.child.poll() is None - - def _despawn(self): - if self.child.poll() is None: - self.child.terminate() - self.alive = False - for _ in range(50): - if self.child.poll() is not None: - self.child = None - break - time.sleep(0.1) - else: - self.child.kill() - - def run(self): - self._spawn() - while True: - try: - (rds, _, _) = select.select([self.child.stdout, self.child.stderr], [], [], 1) - except select.error as ex: - if ex.args[0] == 4: - continue - else: - raise - - if self.child.stdout in rds: - line = self.child.stdout.readline().decode('utf-8').rstrip() - if line: - self.captured_stdout.append(line) - - if self.child.stderr in rds: - line = self.child.stderr.readline().decode('utf-8').rstrip() - if line: - self.captured_stderr.append(line) 
- - if self.child.poll() is not None: - self.dump_logs() - break - - if self.should_die.is_set(): - self._despawn() - break - - def dump_logs(self): - sys.stderr.write('\n'.join(self.captured_stderr)) - sys.stdout.write('\n'.join(self.captured_stdout)) - - def wait_for(self, pattern, timeout=30): - start = time.time() - while True: - if not self.is_alive(): - raise RuntimeError("Child thread died already.") - - elapsed = time.time() - start - if elapsed >= timeout: - log.error("Waiting for %r timed out after %d seconds", pattern, timeout) - return False - - if re.search(pattern, '\n'.join(self.captured_stdout), re.IGNORECASE) is not None: - log.info("Found pattern %r in %d seconds via stdout", pattern, elapsed) - return True - if re.search(pattern, '\n'.join(self.captured_stderr), re.IGNORECASE) is not None: - log.info("Found pattern %r in %d seconds via stderr", pattern, elapsed) - return True - time.sleep(0.1) - - def stop(self): - self.should_die.set() - self.join() diff --git a/tests/kafka/testutil.py b/tests/kafka/testutil.py deleted file mode 100644 index ec4d70bf..00000000 --- a/tests/kafka/testutil.py +++ /dev/null @@ -1,46 +0,0 @@ -from __future__ import absolute_import - -import os -import random -import re -import string -import time - - -def special_to_underscore(string, _matcher=re.compile(r'[^a-zA-Z0-9_]+')): - return _matcher.sub('_', string) - - -def random_string(length): - return "".join(random.choice(string.ascii_letters) for i in range(length)) - - -def env_kafka_version(): - """Return the Kafka version set in the OS environment as a tuple. - - Example: '0.8.1.1' --> (0, 8, 1, 1) - """ - if 'KAFKA_VERSION' not in os.environ: - return () - return tuple(map(int, os.environ['KAFKA_VERSION'].split('.'))) - - -def assert_message_count(messages, num_messages): - """Check that we received the expected number of messages with no duplicates.""" - # Make sure we got them all - assert len(messages) == num_messages - # Make sure there are no duplicates - # Note: Currently duplicates are identified only using key/value. Other attributes like topic, partition, headers, - # timestamp, etc are ignored... this could be changed if necessary, but will be more tolerant of dupes. 
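The env_kafka_version helper above converts the KAFKA_VERSION environment variable into a comparable tuple; a small sketch of how the fixtures gate on it (the version values are illustrative, and env_kafka_version is assumed to be in scope):

import os

os.environ["KAFKA_VERSION"] = "2.8.1"        # illustrative value
assert env_kafka_version() == (2, 8, 1)      # '2.8.1' -> (2, 8, 1)
if env_kafka_version() >= (0, 10, 1, 0):
    # e.g. new enough to create topics via the admin API instead of the CLI
    pass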
- unique_messages = {(m.key, m.value) for m in messages} - assert len(unique_messages) == num_messages - - -class Timer(object): - def __enter__(self): - self.start = time.time() - return self - - def __exit__(self, *args): - self.end = time.time() - self.interval = self.end - self.start From 66e299900bbba770f5a9da5a963f02c4d5d44dbe Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Mon, 23 Oct 2023 15:04:20 +0300 Subject: [PATCH 20/20] Remove dead code --- aiokafka/cluster.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/aiokafka/cluster.py b/aiokafka/cluster.py index 061db638..ca4ad012 100644 --- a/aiokafka/cluster.py +++ b/aiokafka/cluster.py @@ -164,24 +164,6 @@ def coordinator_for_group(self, group): """ return self._groups.get(group) - def ttl(self): - """Milliseconds until metadata should be refreshed""" - now = time.time() * 1000 - if self._need_update: - ttl = 0 - else: - metadata_age = now - self._last_successful_refresh_ms - ttl = self.config['metadata_max_age_ms'] - metadata_age - - retry_age = now - self._last_refresh_ms - next_retry = self.config['retry_backoff_ms'] - retry_age - - return max(ttl, next_retry, 0) - - def refresh_backoff(self): - """Return milliseconds to wait before attempting to retry after failure""" - return self.config['retry_backoff_ms'] - def request_update(self): """Flags metadata for update, return Future()