From 8c4cd4084e52bc3de955a2e20c6f3311f48b6e40 Mon Sep 17 00:00:00 2001 From: Denis Otkidach Date: Mon, 23 Oct 2023 12:17:39 +0300 Subject: [PATCH] Move codec --- {kafka => aiokafka}/codec.py | 124 ++++++++-------- aiokafka/producer/producer.py | 3 +- aiokafka/protocol/message.py | 2 +- aiokafka/record/_crecords/default_records.pyx | 6 +- aiokafka/record/_crecords/legacy_records.pyx | 4 +- aiokafka/record/default_records.py | 7 +- aiokafka/record/legacy_records.py | 8 +- tests/conftest.py | 1 + tests/kafka/test_codec.py | 124 ---------------- tests/record/test_default_records.py | 4 +- tests/record/test_legacy.py | 4 +- tests/test_codec.py | 136 ++++++++++++++++++ 12 files changed, 214 insertions(+), 209 deletions(-) rename {kafka => aiokafka}/codec.py (67%) delete mode 100644 tests/kafka/test_codec.py create mode 100644 tests/test_codec.py diff --git a/kafka/codec.py b/aiokafka/codec.py similarity index 67% rename from kafka/codec.py rename to aiokafka/codec.py index c740a181..2e3ddaaf 100644 --- a/kafka/codec.py +++ b/aiokafka/codec.py @@ -1,15 +1,10 @@ -from __future__ import absolute_import - import gzip import io import platform import struct -from kafka.vendor import six -from kafka.vendor.six.moves import range - -_XERIAL_V1_HEADER = (-126, b'S', b'N', b'A', b'P', b'P', b'Y', 0, 1, 1) -_XERIAL_V1_FORMAT = 'bccccccBii' +_XERIAL_V1_HEADER = (-126, b"S", b"N", b"A", b"P", b"P", b"Y", 0, 1, 1) +_XERIAL_V1_FORMAT = "bccccccBii" ZSTD_MAX_OUTPUT_SIZE = 1024 * 1024 try: @@ -29,11 +24,11 @@ def _lz4_compress(payload, **kwargs): # Kafka does not support LZ4 dependent blocks try: # For lz4>=0.12.0 - kwargs.pop('block_linked', None) + kwargs.pop("block_linked", None) return lz4.compress(payload, block_linked=False, **kwargs) except TypeError: # For earlier versions of lz4 - kwargs.pop('block_mode', None) + kwargs.pop("block_mode", None) return lz4.compress(payload, block_mode=1, **kwargs) except ImportError: @@ -54,7 +49,8 @@ def _lz4_compress(payload, **kwargs): except ImportError: xxhash = None -PYPY = bool(platform.python_implementation() == 'PyPy') +PYPY = bool(platform.python_implementation() == "PyPy") + def has_gzip(): return True @@ -100,14 +96,14 @@ def gzip_decode(payload): # Gzip context manager introduced in python 2.7 # so old-fashioned way until we decide to not support 2.6 - gzipper = gzip.GzipFile(fileobj=buf, mode='r') + gzipper = gzip.GzipFile(fileobj=buf, mode="r") try: return gzipper.read() finally: gzipper.close() -def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32*1024): +def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32 * 1024): """Encodes the given data with snappy compression. If xerial_compatible is set then the stream is encoded in a fashion @@ -141,30 +137,30 @@ def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32*1024): out = io.BytesIO() for fmt, dat in zip(_XERIAL_V1_FORMAT, _XERIAL_V1_HEADER): - out.write(struct.pack('!' + fmt, dat)) + out.write(struct.pack("!" + fmt, dat)) # Chunk through buffers to avoid creating intermediate slice copies if PYPY: # on pypy, snappy.compress() on a sliced buffer consumes the entire # buffer... 
likely a python-snappy bug, so just use a slice copy - chunker = lambda payload, i, size: payload[i:size+i] + def chunker(payload, i, size): + return payload[i:size + i] - elif six.PY2: - # Sliced buffer avoids additional copies - # pylint: disable-msg=undefined-variable - chunker = lambda payload, i, size: buffer(payload, i, size) else: # snappy.compress does not like raw memoryviews, so we have to convert # tobytes, which is a copy... oh well. it's the thought that counts. # pylint: disable-msg=undefined-variable - chunker = lambda payload, i, size: memoryview(payload)[i:size+i].tobytes() + def chunker(payload, i, size): + return memoryview(payload)[i:size + i].tobytes() - for chunk in (chunker(payload, i, xerial_blocksize) - for i in range(0, len(payload), xerial_blocksize)): + for chunk in ( + chunker(payload, i, xerial_blocksize) + for i in range(0, len(payload), xerial_blocksize) + ): block = snappy.compress(chunk) block_size = len(block) - out.write(struct.pack('!i', block_size)) + out.write(struct.pack("!i", block_size)) out.write(block) return out.getvalue() @@ -172,28 +168,28 @@ def snappy_encode(payload, xerial_compatible=True, xerial_blocksize=32*1024): def _detect_xerial_stream(payload): """Detects if the data given might have been encoded with the blocking mode - of the xerial snappy library. - - This mode writes a magic header of the format: - +--------+--------------+------------+---------+--------+ - | Marker | Magic String | Null / Pad | Version | Compat | - +--------+--------------+------------+---------+--------+ - | byte | c-string | byte | int32 | int32 | - +--------+--------------+------------+---------+--------+ - | -126 | 'SNAPPY' | \0 | | | - +--------+--------------+------------+---------+--------+ - - The pad appears to be to ensure that SNAPPY is a valid cstring - The version is the version of this format as written by xerial, - in the wild this is currently 1 as such we only support v1. - - Compat is there to claim the minimum supported version that - can read a xerial block stream, presently in the wild this is - 1. + of the xerial snappy library. + + This mode writes a magic header of the format: + +--------+--------------+------------+---------+--------+ + | Marker | Magic String | Null / Pad | Version | Compat | + +--------+--------------+------------+---------+--------+ + | byte | c-string | byte | int32 | int32 | + +--------+--------------+------------+---------+--------+ + | -126 | 'SNAPPY' | \0 | | | + +--------+--------------+------------+---------+--------+ + + The pad appears to be to ensure that SNAPPY is a valid cstring + The version is the version of this format as written by xerial, + in the wild this is currently 1 as such we only support v1. + + Compat is there to claim the minimum supported version that + can read a xerial block stream, presently in the wild this is + 1. """ if len(payload) > 16: - header = struct.unpack('!' + _XERIAL_V1_FORMAT, bytes(payload)[:16]) + header = struct.unpack("!" 
+ _XERIAL_V1_FORMAT, bytes(payload)[:16]) return header == _XERIAL_V1_HEADER return False @@ -210,7 +206,7 @@ def snappy_decode(payload): cursor = 0 while cursor < length: - block_size = struct.unpack_from('!i', byt[cursor:])[0] + block_size = struct.unpack_from("!i", byt[cursor:])[0] # Skip the block size cursor += 4 end = cursor + block_size @@ -224,11 +220,11 @@ def snappy_decode(payload): if lz4: - lz4_encode = _lz4_compress # pylint: disable-msg=no-member + lz4_encode = _lz4_compress # pylint: disable-msg=no-member elif lz4f: - lz4_encode = lz4f.compressFrame # pylint: disable-msg=no-member + lz4_encode = lz4f.compressFrame # pylint: disable-msg=no-member elif lz4framed: - lz4_encode = lz4framed.compress # pylint: disable-msg=no-member + lz4_encode = lz4framed.compress # pylint: disable-msg=no-member else: lz4_encode = None @@ -242,17 +238,17 @@ def lz4f_decode(payload): # lz4f python module does not expose how much of the payload was # actually read if the decompression was only partial. - if data['next'] != 0: - raise RuntimeError('lz4f unable to decompress full payload') - return data['decomp'] + if data["next"] != 0: + raise RuntimeError("lz4f unable to decompress full payload") + return data["decomp"] if lz4: - lz4_decode = lz4.decompress # pylint: disable-msg=no-member + lz4_decode = lz4.decompress # pylint: disable-msg=no-member elif lz4f: lz4_decode = lz4f_decode elif lz4framed: - lz4_decode = lz4framed.decompress # pylint: disable-msg=no-member + lz4_decode = lz4framed.decompress # pylint: disable-msg=no-member else: lz4_decode = None @@ -266,7 +262,7 @@ def lz4_encode_old_kafka(payload): if not isinstance(flg, int): flg = ord(flg) - content_size_bit = ((flg >> 3) & 1) + content_size_bit = (flg >> 3) & 1 if content_size_bit: # Old kafka does not accept the content-size field # so we need to discard it and reset the header flag @@ -274,18 +270,16 @@ def lz4_encode_old_kafka(payload): data = bytearray(data) data[4] = flg data = bytes(data) - payload = data[header_size+8:] + payload = data[header_size + 8:] else: payload = data[header_size:] # This is the incorrect hc - hc = xxhash.xxh32(data[0:header_size-1]).digest()[-2:-1] # pylint: disable-msg=no-member + hc = xxhash.xxh32(data[0:header_size - 1]).digest()[ + -2:-1 + ] # pylint: disable-msg=no-member - return b''.join([ - data[0:header_size-1], - hc, - payload - ]) + return b"".join([data[0:header_size - 1], hc, payload]) def lz4_decode_old_kafka(payload): @@ -296,18 +290,14 @@ def lz4_decode_old_kafka(payload): flg = payload[4] else: flg = ord(payload[4]) - content_size_bit = ((flg >> 3) & 1) + content_size_bit = (flg >> 3) & 1 if content_size_bit: header_size += 8 # This should be the correct hc - hc = xxhash.xxh32(payload[4:header_size-1]).digest()[-2:-1] # pylint: disable-msg=no-member + hc = xxhash.xxh32(payload[4:header_size - 1]).digest()[-2:-1] - munged_payload = b''.join([ - payload[0:header_size-1], - hc, - payload[header_size:] - ]) + munged_payload = b"".join([payload[0:header_size - 1], hc, payload[header_size:]]) return lz4_decode(munged_payload) @@ -323,4 +313,6 @@ def zstd_decode(payload): try: return zstd.ZstdDecompressor().decompress(payload) except zstd.ZstdError: - return zstd.ZstdDecompressor().decompress(payload, max_output_size=ZSTD_MAX_OUTPUT_SIZE) + return zstd.ZstdDecompressor().decompress( + payload, max_output_size=ZSTD_MAX_OUTPUT_SIZE + ) diff --git a/aiokafka/producer/producer.py b/aiokafka/producer/producer.py index 12a07e7f..3c5a096e 100644 --- a/aiokafka/producer/producer.py +++ 
b/aiokafka/producer/producer.py @@ -4,9 +4,8 @@ import traceback import warnings -from kafka.codec import has_gzip, has_snappy, has_lz4, has_zstd - from aiokafka.client import AIOKafkaClient +from aiokafka.codec import has_gzip, has_snappy, has_lz4, has_zstd from aiokafka.errors import ( MessageSizeTooLargeError, UnsupportedVersionError, IllegalOperation) from aiokafka.partitioner import DefaultPartitioner diff --git a/aiokafka/protocol/message.py b/aiokafka/protocol/message.py index d187b9bc..3fc665e2 100644 --- a/aiokafka/protocol/message.py +++ b/aiokafka/protocol/message.py @@ -1,7 +1,7 @@ import io import time -from kafka.codec import ( +from aiokafka.codec import ( has_gzip, has_snappy, has_lz4, diff --git a/aiokafka/record/_crecords/default_records.pyx b/aiokafka/record/_crecords/default_records.pyx index 73ec764a..ba49411d 100644 --- a/aiokafka/record/_crecords/default_records.pyx +++ b/aiokafka/record/_crecords/default_records.pyx @@ -55,12 +55,12 @@ # * Timestamp Type (3) # * Compression Type (0-2) -from aiokafka.errors import CorruptRecordException, UnsupportedCodecError -from kafka.codec import ( +import aiokafka.codec as codecs +from aiokafka.codec import ( gzip_encode, snappy_encode, lz4_encode, zstd_encode, gzip_decode, snappy_decode, lz4_decode, zstd_decode ) -import kafka.codec as codecs +from aiokafka.errors import CorruptRecordException, UnsupportedCodecError from cpython cimport PyObject_GetBuffer, PyBuffer_Release, PyBUF_WRITABLE, \ PyBUF_SIMPLE, PyBUF_READ, Py_buffer, \ diff --git a/aiokafka/record/_crecords/legacy_records.pyx b/aiokafka/record/_crecords/legacy_records.pyx index 3c2f366a..2406ef12 100644 --- a/aiokafka/record/_crecords/legacy_records.pyx +++ b/aiokafka/record/_crecords/legacy_records.pyx @@ -1,10 +1,10 @@ #cython: language_level=3 -from kafka.codec import ( +import aiokafka.codec as codecs +from aiokafka.codec import ( gzip_encode, snappy_encode, lz4_encode, lz4_encode_old_kafka, gzip_decode, snappy_decode, lz4_decode, lz4_decode_old_kafka ) -import kafka.codec as codecs from aiokafka.errors import CorruptRecordException, UnsupportedCodecError from zlib import crc32 as py_crc32 # needed for windows macro diff --git a/aiokafka/record/default_records.py b/aiokafka/record/default_records.py index e9ee69ee..4adaf871 100644 --- a/aiokafka/record/default_records.py +++ b/aiokafka/record/default_records.py @@ -56,15 +56,16 @@ import struct import time -from .util import decode_varint, encode_varint, calc_crc32c, size_of_varint +import aiokafka.codec as codecs from aiokafka.errors import CorruptRecordException, UnsupportedCodecError from aiokafka.util import NO_EXTENSIONS -from kafka.codec import ( +from aiokafka.codec import ( gzip_encode, snappy_encode, lz4_encode, zstd_encode, gzip_decode, snappy_decode, lz4_decode, zstd_decode ) -import kafka.codec as codecs + +from .util import decode_varint, encode_varint, calc_crc32c, size_of_varint class DefaultRecordBase: diff --git a/aiokafka/record/legacy_records.py b/aiokafka/record/legacy_records.py index e0364651..c1ae9480 100644 --- a/aiokafka/record/legacy_records.py +++ b/aiokafka/record/legacy_records.py @@ -3,13 +3,13 @@ from binascii import crc32 -from aiokafka.errors import CorruptRecordException, UnsupportedCodecError -from aiokafka.util import NO_EXTENSIONS -from kafka.codec import ( +import aiokafka.codec as codecs +from aiokafka.codec import ( gzip_encode, snappy_encode, lz4_encode, lz4_encode_old_kafka, gzip_decode, snappy_decode, lz4_decode, lz4_decode_old_kafka ) -import kafka.codec as codecs +from 
aiokafka.errors import CorruptRecordException, UnsupportedCodecError +from aiokafka.util import NO_EXTENSIONS NoneType = type(None) diff --git a/tests/conftest.py b/tests/conftest.py index 1cd91b22..66080e5e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,7 @@ from aiokafka.record.default_records import ( DefaultRecordBatchBuilder, _DefaultRecordBatchBuilderPy) from aiokafka.util import NO_EXTENSIONS + from ._testutil import wait_kafka diff --git a/tests/kafka/test_codec.py b/tests/kafka/test_codec.py deleted file mode 100644 index db6a14b6..00000000 --- a/tests/kafka/test_codec.py +++ /dev/null @@ -1,124 +0,0 @@ -from __future__ import absolute_import - -import platform -import struct - -import pytest -from kafka.vendor.six.moves import range - -from kafka.codec import ( - has_snappy, has_lz4, has_zstd, - gzip_encode, gzip_decode, - snappy_encode, snappy_decode, - lz4_encode, lz4_decode, - lz4_encode_old_kafka, lz4_decode_old_kafka, - zstd_encode, zstd_decode, -) - -from tests.kafka.testutil import random_string - - -def test_gzip(): - for i in range(1000): - b1 = random_string(100).encode('utf-8') - b2 = gzip_decode(gzip_encode(b1)) - assert b1 == b2 - - -@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") -def test_snappy(): - for i in range(1000): - b1 = random_string(100).encode('utf-8') - b2 = snappy_decode(snappy_encode(b1)) - assert b1 == b2 - - -@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") -def test_snappy_detect_xerial(): - import kafka as kafka1 - _detect_xerial_stream = kafka1.codec._detect_xerial_stream - - header = b'\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01Some extra bytes' - false_header = b'\x01SNAPPY\x00\x00\x00\x01\x00\x00\x00\x01' - default_snappy = snappy_encode(b'foobar' * 50) - random_snappy = snappy_encode(b'SNAPPY' * 50, xerial_compatible=False) - short_data = b'\x01\x02\x03\x04' - - assert _detect_xerial_stream(header) is True - assert _detect_xerial_stream(b'') is False - assert _detect_xerial_stream(b'\x00') is False - assert _detect_xerial_stream(false_header) is False - assert _detect_xerial_stream(default_snappy) is True - assert _detect_xerial_stream(random_snappy) is False - assert _detect_xerial_stream(short_data) is False - - -@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") -def test_snappy_decode_xerial(): - header = b'\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01' - random_snappy = snappy_encode(b'SNAPPY' * 50, xerial_compatible=False) - block_len = len(random_snappy) - random_snappy2 = snappy_encode(b'XERIAL' * 50, xerial_compatible=False) - block_len2 = len(random_snappy2) - - to_test = header \ - + struct.pack('!i', block_len) + random_snappy \ - + struct.pack('!i', block_len2) + random_snappy2 \ - - assert snappy_decode(to_test) == (b'SNAPPY' * 50) + (b'XERIAL' * 50) - - -@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") -def test_snappy_encode_xerial(): - to_ensure = ( - b'\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01' - b'\x00\x00\x00\x18' - b'\xac\x02\x14SNAPPY\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\x96\x06\x00' - b'\x00\x00\x00\x18' - b'\xac\x02\x14XERIAL\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\x96\x06\x00' - ) - - to_test = (b'SNAPPY' * 50) + (b'XERIAL' * 50) - - compressed = snappy_encode(to_test, xerial_compatible=True, xerial_blocksize=300) - assert compressed == to_ensure - - -@pytest.mark.skipif(not has_lz4() or platform.python_implementation() == 'PyPy', - reason="python-lz4 crashes on old versions of pypy") -def 
test_lz4(): - for i in range(1000): - b1 = random_string(100).encode('utf-8') - b2 = lz4_decode(lz4_encode(b1)) - assert len(b1) == len(b2) - assert b1 == b2 - - -@pytest.mark.skipif(not has_lz4() or platform.python_implementation() == 'PyPy', - reason="python-lz4 crashes on old versions of pypy") -def test_lz4_old(): - for i in range(1000): - b1 = random_string(100).encode('utf-8') - b2 = lz4_decode_old_kafka(lz4_encode_old_kafka(b1)) - assert len(b1) == len(b2) - assert b1 == b2 - - -@pytest.mark.skipif(not has_lz4() or platform.python_implementation() == 'PyPy', - reason="python-lz4 crashes on old versions of pypy") -def test_lz4_incremental(): - for i in range(1000): - # lz4 max single block size is 4MB - # make sure we test with multiple-blocks - b1 = random_string(100).encode('utf-8') * 50000 - b2 = lz4_decode(lz4_encode(b1)) - assert len(b1) == len(b2) - assert b1 == b2 - - -@pytest.mark.skipif(not has_zstd(), reason="Zstd not available") -def test_zstd(): - for _ in range(1000): - b1 = random_string(100).encode('utf-8') - b2 = zstd_decode(zstd_encode(b1)) - assert b1 == b2 diff --git a/tests/record/test_default_records.py b/tests/record/test_default_records.py index 455590c9..a79f2aef 100644 --- a/tests/record/test_default_records.py +++ b/tests/record/test_default_records.py @@ -1,6 +1,6 @@ from unittest import mock -import kafka.codec +import aiokafka.codec import pytest from aiokafka.errors import UnsupportedCodecError @@ -196,7 +196,7 @@ def test_unavailable_codec(compression_type, name, checker_name): builder.append(0, timestamp=None, key=None, value=b"M" * 2000, headers=[]) correct_buffer = builder.build() - with mock.patch.object(kafka.codec, checker_name, return_value=False): + with mock.patch.object(aiokafka.codec, checker_name, return_value=False): # Check that builder raises error builder = DefaultRecordBatchBuilder( magic=2, compression_type=compression_type, is_transactional=0, diff --git a/tests/record/test_legacy.py b/tests/record/test_legacy.py index ee3c6a76..26f3d4dc 100644 --- a/tests/record/test_legacy.py +++ b/tests/record/test_legacy.py @@ -1,7 +1,7 @@ import struct from unittest import mock -import kafka.codec +import aiokafka.codec import pytest from aiokafka.errors import CorruptRecordException, UnsupportedCodecError @@ -202,7 +202,7 @@ def test_unavailable_codec(compression_type, name, checker_name): builder.append(0, timestamp=None, key=None, value=b"M") correct_buffer = builder.build() - with mock.patch.object(kafka.codec, checker_name) as mocked: + with mock.patch.object(aiokafka.codec, checker_name) as mocked: mocked.return_value = False # Check that builder raises error builder = LegacyRecordBatchBuilder( diff --git a/tests/test_codec.py b/tests/test_codec.py new file mode 100644 index 00000000..9ae53487 --- /dev/null +++ b/tests/test_codec.py @@ -0,0 +1,136 @@ +import platform +import struct + +import pytest + +from aiokafka import codec as codecs +from aiokafka.codec import ( + has_snappy, + has_lz4, + has_zstd, + gzip_encode, + gzip_decode, + snappy_encode, + snappy_decode, + lz4_encode, + lz4_decode, + lz4_encode_old_kafka, + lz4_decode_old_kafka, + zstd_encode, + zstd_decode, +) + +from ._testutil import random_string + + +def test_gzip(): + for i in range(1000): + b1 = random_string(100) + b2 = gzip_decode(gzip_encode(b1)) + assert b1 == b2 + + +@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") +def test_snappy(): + for i in range(1000): + b1 = random_string(100) + b2 = snappy_decode(snappy_encode(b1)) + assert b1 == b2 + + 
+@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") +def test_snappy_detect_xerial(): + _detect_xerial_stream = codecs._detect_xerial_stream + + header = b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01Some extra bytes" + false_header = b"\x01SNAPPY\x00\x00\x00\x01\x00\x00\x00\x01" + default_snappy = snappy_encode(b"foobar" * 50) + random_snappy = snappy_encode(b"SNAPPY" * 50, xerial_compatible=False) + short_data = b"\x01\x02\x03\x04" + + assert _detect_xerial_stream(header) is True + assert _detect_xerial_stream(b"") is False + assert _detect_xerial_stream(b"\x00") is False + assert _detect_xerial_stream(false_header) is False + assert _detect_xerial_stream(default_snappy) is True + assert _detect_xerial_stream(random_snappy) is False + assert _detect_xerial_stream(short_data) is False + + +@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") +def test_snappy_decode_xerial(): + header = b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01" + random_snappy = snappy_encode(b"SNAPPY" * 50, xerial_compatible=False) + block_len = len(random_snappy) + random_snappy2 = snappy_encode(b"XERIAL" * 50, xerial_compatible=False) + block_len2 = len(random_snappy2) + + to_test = ( + header + + struct.pack("!i", block_len) + + random_snappy + + struct.pack("!i", block_len2) + + random_snappy2 + ) + assert snappy_decode(to_test) == (b"SNAPPY" * 50) + (b"XERIAL" * 50) + + +@pytest.mark.skipif(not has_snappy(), reason="Snappy not available") +def test_snappy_encode_xerial(): + to_ensure = ( + b"\x82SNAPPY\x00\x00\x00\x00\x01\x00\x00\x00\x01" + b"\x00\x00\x00\x18\xac\x02\x14SNAPPY\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00" + b"\xfe\x06\x00\x96\x06\x00\x00\x00\x00\x18\xac\x02" + b"\x14XERIAL\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\xfe\x06\x00\x96\x06\x00" + ) + + to_test = (b"SNAPPY" * 50) + (b"XERIAL" * 50) + + compressed = snappy_encode(to_test, xerial_compatible=True, xerial_blocksize=300) + assert compressed == to_ensure + + +@pytest.mark.skipif( + not has_lz4() or platform.python_implementation() == "PyPy", + reason="python-lz4 crashes on old versions of pypy", +) +def test_lz4(): + for i in range(1000): + b1 = random_string(100) + b2 = lz4_decode(lz4_encode(b1)) + assert len(b1) == len(b2) + assert b1 == b2 + + +@pytest.mark.skipif( + not has_lz4() or platform.python_implementation() == "PyPy", + reason="python-lz4 crashes on old versions of pypy", +) +def test_lz4_old(): + for i in range(1000): + b1 = random_string(100) + b2 = lz4_decode_old_kafka(lz4_encode_old_kafka(b1)) + assert len(b1) == len(b2) + assert b1 == b2 + + +@pytest.mark.skipif( + not has_lz4() or platform.python_implementation() == "PyPy", + reason="python-lz4 crashes on old versions of pypy", +) +def test_lz4_incremental(): + for i in range(1000): + # lz4 max single block size is 4MB + # make sure we test with multiple-blocks + b1 = random_string(100) * 50000 + b2 = lz4_decode(lz4_encode(b1)) + assert len(b1) == len(b2) + assert b1 == b2 + + +@pytest.mark.skipif(not has_zstd(), reason="Zstd not available") +def test_zstd(): + for _ in range(1000): + b1 = random_string(100) + b2 = zstd_decode(zstd_encode(b1)) + assert b1 == b2
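--

Usage after this move: downstream code now imports the codec helpers from
aiokafka.codec rather than kafka.codec. A minimal round-trip sketch of the
moved API (the payload bytes below are made up for illustration; snappy
support still depends on the optional python-snappy package, which is why
the producer imports the has_*() checks):

    from aiokafka.codec import (
        has_snappy,
        gzip_decode,
        gzip_encode,
        snappy_decode,
        snappy_encode,
    )

    payload = b"some payload" * 100

    # gzip is backed by the stdlib, so no availability check is needed
    assert gzip_decode(gzip_encode(payload)) == payload

    # snappy is optional; guard on has_snappy() before using it
    if has_snappy():
        framed = snappy_encode(
            payload, xerial_compatible=True, xerial_blocksize=32 * 1024
        )
        assert snappy_decode(framed) == payload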
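The _detect_xerial_stream() docstring in the moved module describes the
16-byte blocking-mode header (marker byte -126, the magic c-string
b"SNAPPY", a NUL pad, then big-endian int32 version and compat fields,
both 1 in the wild). The sketch below packs and recognizes such a header
with the same struct format the module uses; looks_like_xerial() is a
hypothetical stand-in for the private helper:

    import struct

    _XERIAL_V1_FORMAT = "bccccccBii"
    _XERIAL_V1_HEADER = (-126, b"S", b"N", b"A", b"P", b"P", b"Y", 0, 1, 1)

    def looks_like_xerial(payload):
        # mirror _detect_xerial_stream(): a bare snappy block can start
        # with anything, so only a full 16-byte header match identifies
        # the xerial blocked format
        if len(payload) > 16:
            return (
                struct.unpack("!" + _XERIAL_V1_FORMAT, bytes(payload)[:16])
                == _XERIAL_V1_HEADER
            )
        return False

    header = struct.pack("!" + _XERIAL_V1_FORMAT, *_XERIAL_V1_HEADER)
    assert looks_like_xerial(header + b"block data follows")
    assert not looks_like_xerial(b"\x01\x02\x03\x04")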
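zstd_decode() in the moved module retries with an explicit output cap
because zstd frames written by a streaming compressor may omit the
decompressed size from the frame header. A sketch of that fallback,
assuming the module's optional zstandard dependency (imported as zstd);
decode() here is a hypothetical mirror of zstd_decode(), not the function
itself:

    import zstandard as zstd

    ZSTD_MAX_OUTPUT_SIZE = 1024 * 1024

    def decode(payload):
        try:
            # works when the frame header carries the content size
            return zstd.ZstdDecompressor().decompress(payload)
        except zstd.ZstdError:
            # size-less frames need an explicit upper bound
            return zstd.ZstdDecompressor().decompress(
                payload, max_output_size=ZSTD_MAX_OUTPUT_SIZE
            )

    assert decode(zstd.ZstdCompressor().compress(b"x" * 4096)) == b"x" * 4096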