diff --git a/.gitignore b/.gitignore index 32cd4c8b..84454c77 100644 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,5 @@ dmypy.json # Project specific /schema/ + +/.idea diff --git a/MANIFEST.in b/MANIFEST.in index fea6e1db..1ae9e4f2 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -15,3 +15,4 @@ include LICENSE include NOTICE include README.md include pyproject.toml +recursive-exclude java_tester * diff --git a/codegen/generate_schema.py b/codegen/generate_schema.py index 1fcce206..e9185f86 100644 --- a/codegen/generate_schema.py +++ b/codegen/generate_schema.py @@ -68,6 +68,7 @@ from kio.static.primitive import i32Timedelta from kio.static.primitive import i64Timedelta from kio.static.primitive import TZAware +from kio.static.protocol import ApiMessage from kio.static.constants import ErrorCode ''' @@ -396,15 +397,17 @@ def generate_dataclass( # noqa: C901 name: str, fields: Sequence[Field], version: int, + top_level: bool = False, ) -> Iterator[str | CustomTypeDef | ExportName]: if (name, version) in seen: return seen.add((name, version)) + class_parent_str = "(ApiMessage)" if top_level else "" class_start = textwrap.dedent( f"""\ @dataclass(frozen=True, slots=True, kw_only=True) - class {name}: + class {name}{class_parent_str}: """ ) class_fields = [] @@ -506,6 +509,7 @@ def generate_models( name=schema.name, fields=schema.fields, version=version, + top_level=True, ): match item: case CustomTypeDef() as instruction: diff --git a/codegen/generate_tests.py b/codegen/generate_tests.py index 1174a3f1..938e024d 100644 --- a/codegen/generate_tests.py +++ b/codegen/generate_tests.py @@ -9,6 +9,7 @@ from types import ModuleType import kio.schema +from kio.static.protocol import ApiMessage from .case import to_snake_case @@ -48,7 +49,7 @@ def get_entities() -> Iterator[tuple[type, str]]: from hypothesis import given, settings from hypothesis.strategies import from_type from kio.serial import entity_writer -from tests.conftest import setup_buffer +from tests.conftest import setup_buffer, JavaTester from kio.serial import entity_reader from typing import Final """ @@ -72,6 +73,12 @@ def test_{entity_snake_case}_roundtrip(instance: {entity_type}) -> None: assert instance == result """ +test_code_java = """\ +@given(instance=from_type({entity_type})) +def test_{entity_snake_case}_java(instance: {entity_type}, java_tester: JavaTester) -> None: + java_tester.test(instance) +""" + base_dir = Path(__file__).parent.parent.resolve() generated_tests_module = base_dir / "tests" / "generated" @@ -106,6 +113,20 @@ def main() -> None: entity_snake_case=to_snake_case(entity_type.__name__), ) ) + if issubclass(entity_type, ApiMessage) and entity_type.__name__ not in { + "ProduceRequest", # Records + "FetchResponse", # Records + "FetchSnapshotResponse", # Records + "CreateTopicsResponse", # Should not output tagged field if its value equals to default + "FetchRequest", # Should not output tagged field if its value equals to default (presumably) + "ConsumerGroupHeartbeatResponse", # Nullable `assignment` field + }: + module_code[module_path].append( + test_code_java.format( + entity_type=entity_type.__name__, + entity_snake_case=to_snake_case(entity_type.__name__), + ) + ) for module_path, entity_imports in module_imports.items(): with module_path.open("w") as fd: diff --git a/java_tester/.gitignore b/java_tester/.gitignore new file mode 100644 index 00000000..d070b266 --- /dev/null +++ b/java_tester/.gitignore @@ -0,0 +1,11 @@ +# Ignore Gradle project-specific cache directory +.gradle + +# Ignore Gradle build output directory +build + +.idea + +/gradle/wrapper +/gradlew +/gradlew.bat diff --git a/java_tester/Dockerfile b/java_tester/Dockerfile new file mode 100644 index 00000000..1f155d9c --- /dev/null +++ b/java_tester/Dockerfile @@ -0,0 +1,13 @@ +FROM gradle:jdk17-alpine AS build + +RUN mkdir /java_tester +WORKDIR /java_tester +ADD src src +ADD build.gradle build.gradle + +RUN gradle --no-daemon distTar && tar -xvf build/distributions/java_tester.tar + +FROM eclipse-temurin:17-jdk +COPY --from=build /java_tester/java_tester /java_tester + +CMD /java_tester/bin/java_tester diff --git a/java_tester/Makefile b/java_tester/Makefile new file mode 100644 index 00000000..b9a6d43a --- /dev/null +++ b/java_tester/Makefile @@ -0,0 +1,3 @@ +.PHONY: docker_image +docker_image: + docker build -t java_tester . diff --git a/java_tester/build.gradle b/java_tester/build.gradle new file mode 100644 index 00000000..93ec79eb --- /dev/null +++ b/java_tester/build.gradle @@ -0,0 +1,24 @@ +plugins { + id 'application' +} + +java { + toolchain { + languageVersion = JavaLanguageVersion.of(17) + } +} + +repositories { + mavenCentral() +} + +dependencies { + implementation 'org.apache.kafka:kafka-clients:3.5.1' + implementation 'com.fasterxml.jackson.core:jackson-databind:2.15.2' + implementation 'org.apache.commons:commons-text:1.9' +} + +application { + mainClass = 'io.aiven.kio.java_tester.JavaTester' +} +run { standardInput = System.in } diff --git a/java_tester/settings.gradle b/java_tester/settings.gradle new file mode 100644 index 00000000..1ae018e3 --- /dev/null +++ b/java_tester/settings.gradle @@ -0,0 +1 @@ +rootProject.name = 'java_tester' diff --git a/java_tester/src/main/java/io/aiven/kio/java_tester/BaseCreator.java b/java_tester/src/main/java/io/aiven/kio/java_tester/BaseCreator.java new file mode 100644 index 00000000..27adca5f --- /dev/null +++ b/java_tester/src/main/java/io/aiven/kio/java_tester/BaseCreator.java @@ -0,0 +1,118 @@ +package io.aiven.kio.java_tester; + +import java.nio.ByteBuffer; +import java.util.Base64; +import java.util.Map; +import java.util.UUID; + +import org.apache.kafka.common.Uuid; + +import com.fasterxml.jackson.databind.JsonNode; + +abstract class BaseCreator { + protected final Class rootClazz; + protected final short version; + protected final Map commonStructs; + + BaseCreator(Class rootClazz, short version, Map commonStructs) { + this.rootClazz = rootClazz; + this.version = version; + this.commonStructs = commonStructs; + } + + protected static byte getByte(JsonNode fieldValue, String fieldName) throws Exception { + if (!fieldValue.isIntegralNumber()) { + throw new Exception("Expected byte in field " + fieldName + " but got " + fieldValue); + } + + long longValue = fieldValue.asLong(); + if ((byte) longValue != longValue) { + throw new Exception("Invalid byte value in field " + fieldName + ": " + longValue); + } + + return (byte) longValue; + } + + protected static short getShort(JsonNode fieldValue, String fieldName) throws Exception { + if (!fieldValue.isIntegralNumber()) { + throw new Exception("Expected short in field " + fieldName + " but got " + fieldValue); + } + + long longValue = fieldValue.asLong(); + if ((short) longValue != longValue) { + throw new Exception("Invalid short value in field " + fieldName + ": " + longValue); + } + + return (short) longValue; + } + + protected static int getInt(JsonNode fieldValue, String fieldName) throws Exception { + if (!fieldValue.isIntegralNumber()) { + throw new Exception("Expected int in field " + fieldName + " but got " + fieldValue); + } + + long longValue = fieldValue.asLong(); + if ((int) longValue != longValue) { + throw new Exception("Invalid int value in field " + fieldName + ": " + longValue); + } + + return (int) longValue; + } + + protected static long getLong(JsonNode fieldValue, String fieldName) throws Exception { + if (!fieldValue.isIntegralNumber()) { + throw new Exception("Expected int in field " + fieldName + " but got " + fieldValue); + } + return fieldValue.asLong(); + } + + protected static double getDouble(JsonNode fieldValue, String fieldName) throws Exception { + if (!fieldValue.isFloatingPointNumber()) { + throw new Exception("Expected double in field " + fieldName + " but got " + fieldValue); + } + return fieldValue.asDouble(); + } + + protected static boolean getBoolean(JsonNode fieldValue, String fieldName) throws Exception { + if (!fieldValue.isBoolean()) { + throw new Exception("Expected boolean in field " + fieldName + " but got " + fieldValue); + } + return fieldValue.asBoolean(); + } + + protected static String getString(JsonNode fieldValue, String fieldName) throws Exception { + if (fieldValue.isNull()) { + return null; + } + + if (!fieldValue.isTextual()) { + throw new Exception("Expected string in field " + fieldName + " but got " + fieldValue); + } + return fieldValue.asText(); + } + + protected static Uuid getUuid(JsonNode fieldValue, String fieldName) throws Exception { + String str = getString(fieldValue, fieldName); + if (str == null) { + return null; + } + UUID tmpUuid = UUID.fromString(str); + return new Uuid(tmpUuid.getMostSignificantBits(), tmpUuid.getLeastSignificantBits()); + } + + protected static ByteBuffer getByteBuffer(JsonNode fieldValue, String fieldName) throws Exception { + String str = getString(fieldValue, fieldName); + if (str == null) { + return null; + } + return ByteBuffer.wrap(Base64.getDecoder().decode(str)); + } + + protected static byte[] getBytes(JsonNode fieldValue, String fieldName) throws Exception { + String str = getString(fieldValue, fieldName); + if (str == null) { + return null; + } + return Base64.getDecoder().decode(str); + } +} diff --git a/java_tester/src/main/java/io/aiven/kio/java_tester/CollectionCreator.java b/java_tester/src/main/java/io/aiven/kio/java_tester/CollectionCreator.java new file mode 100644 index 00000000..db928d1c --- /dev/null +++ b/java_tester/src/main/java/io/aiven/kio/java_tester/CollectionCreator.java @@ -0,0 +1,126 @@ +package io.aiven.kio.java_tester; + +import java.lang.reflect.Method; +import java.util.AbstractCollection; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.databind.JsonNode; + +class CollectionCreator extends BaseCreator { + private final JsonNode fieldValue; + private final String fieldName; + private final JsonNode fieldSchema; + + private CollectionCreator(Class rootClazz, short version, Map commonStructs, + JsonNode fieldValue, String fieldName, JsonNode fieldSchema) { + super(rootClazz, version, commonStructs); + this.fieldValue = fieldValue; + this.fieldName = fieldName; + this.fieldSchema = fieldSchema; + } + + static AbstractCollection createAbstractCollection( + Class rootClazz, short version, Map commonStructs, + Class> collectionClazz, JsonNode fieldValue, String fieldName, JsonNode fieldSchema + ) throws Exception { + if (!fieldValue.isArray()) { + throw new Exception("The value of " + fieldName + " must be array but was " + fieldValue); + } + CollectionCreator creator = + new CollectionCreator<>(rootClazz, version, commonStructs, fieldValue, fieldName, fieldSchema); + return creator.createAbstractCollection(collectionClazz); + } + + private AbstractCollection createAbstractCollection( + Class> collectionClazz + ) throws Exception { + AbstractCollection collection = Instantiator.instantiate(collectionClazz); + Class elementClazz = getCollectionElementClass(collectionClazz); + fillCollectionFromChildren(elementClazz, collection); + return collectionClazz.cast(collection); + } + + private static Class getCollectionElementClass(Class> collectionClazz) + throws Exception { + // Try to estimate the element class based on the `find` method. + // `find` is expected to be present in all `AbstractCollection`s of interest. + for (Method method : collectionClazz.getDeclaredMethods()) { + if (method.getName().equals("find")) { + return (Class) method.getReturnType(); + } + } + throw new Exception("No 'find' method for " + collectionClazz); + } + + static List createList( + Class rootClazz, short version, Map commonStructs, + JsonNode fieldValue, String fieldName, JsonNode fieldSchema + ) throws Exception { + if (!fieldValue.isArray()) { + throw new Exception("The value of " + fieldName + " must be array but was " + fieldValue); + } + CollectionCreator creator = + new CollectionCreator<>(rootClazz, version, commonStructs, fieldValue, fieldName, fieldSchema); + return creator.createList(); + } + + private List createList() throws Exception { + final String elementTypeInSchema; + { + String tmp = fieldSchema.get("type").asText(); + if (!tmp.startsWith("[]")) { + throw new Exception("Unexpected type " + tmp); + } + elementTypeInSchema = tmp.substring(2); + } + + Class elementClazz = switch (elementTypeInSchema) { + case "int8" -> Byte.class; + case "int16" -> Short.class; + case "int32" -> Integer.class; + case "int64" -> Long.class; + case "string" -> String.class; + default -> Arrays.stream(rootClazz.getDeclaredClasses()) + .filter(c -> c.getName().endsWith("$" + elementTypeInSchema)) + .findFirst().get(); + }; + + List list = new ArrayList<>(); + fillCollectionFromChildren(elementClazz, list); + return list; + } + + private void fillCollectionFromChildren( + Class elementClazz, Collection collection + ) throws Exception { + if (!fieldValue.isArray()) { + throw new Exception("The value of " + fieldName + " must be array but was " + fieldValue); + } + + Iterator elements = fieldValue.elements(); + while (elements.hasNext()) { + JsonNode elementValue = elements.next(); + Object elementObj; + if (elementClazz.equals(Byte.class)) { + elementObj = getByte(elementValue, fieldName); + } else if (elementClazz.equals(Short.class)) { + elementObj = getShort(elementValue, fieldName); + } else if (elementClazz.equals(Integer.class)) { + elementObj = getInt(elementValue, fieldName); + } else if (elementClazz.equals(Long.class)) { + elementObj = getLong(elementValue, fieldName); + } else if (elementClazz.equals(String.class)) { + elementObj = getString(elementValue, fieldName); + } else { + elementObj = ObjectCreator.create( + rootClazz, version, commonStructs, elementClazz, fieldSchema, elementValue); + } + collection.add(elementObj); + } + } +} diff --git a/java_tester/src/main/java/io/aiven/kio/java_tester/Instantiator.java b/java_tester/src/main/java/io/aiven/kio/java_tester/Instantiator.java new file mode 100644 index 00000000..7c4e87f8 --- /dev/null +++ b/java_tester/src/main/java/io/aiven/kio/java_tester/Instantiator.java @@ -0,0 +1,11 @@ +package io.aiven.kio.java_tester; + +import java.lang.reflect.Constructor; + +class Instantiator { + static ObjT instantiate(Class clazz) throws ReflectiveOperationException { + Constructor constructor = clazz.getDeclaredConstructor(); + Object o = constructor.newInstance(); + return clazz.cast(o); + } +} diff --git a/java_tester/src/main/java/io/aiven/kio/java_tester/JavaTester.java b/java_tester/src/main/java/io/aiven/kio/java_tester/JavaTester.java new file mode 100644 index 00000000..64ffca75 --- /dev/null +++ b/java_tester/src/main/java/io/aiven/kio/java_tester/JavaTester.java @@ -0,0 +1,17 @@ +package io.aiven.kio.java_tester; + +import java.io.BufferedReader; +import java.io.InputStreamReader; + +public class JavaTester { + public static void main(String[] args) throws Exception { + try (BufferedReader reader = new BufferedReader(new InputStreamReader(System.in))) { + String caseStr; + while ((caseStr = reader.readLine()) != null) { + String result = new TestCase(caseStr).process(); + System.out.println(result); + System.out.flush(); + } + } + } +} diff --git a/java_tester/src/main/java/io/aiven/kio/java_tester/ObjectCreator.java b/java_tester/src/main/java/io/aiven/kio/java_tester/ObjectCreator.java new file mode 100644 index 00000000..b52de71a --- /dev/null +++ b/java_tester/src/main/java/io/aiven/kio/java_tester/ObjectCreator.java @@ -0,0 +1,237 @@ +package io.aiven.kio.java_tester; + +import java.lang.reflect.Method; +import java.nio.ByteBuffer; +import java.util.AbstractCollection; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.record.BaseRecords; + +import com.fasterxml.jackson.databind.JsonNode; +import org.apache.commons.text.CaseUtils; + +class ObjectCreator extends BaseCreator { + private final java.lang.Class clazz; + + private final JsonNode schema; + + private final JsonNode json; + + public static ObjT create( + java.lang.Class rootClazz, short version, Map commonStructs, + java.lang.Class clazz, JsonNode schema, JsonNode json + ) throws Exception { + ObjectCreator creator = new ObjectCreator<>( + rootClazz, version, commonStructs, clazz, schema, json); + return creator.create(); + } + + private ObjectCreator(java.lang.Class rootClazz, short version, Map commonStructs, + java.lang.Class clazz, JsonNode schema, JsonNode json) { + super(rootClazz, version, commonStructs); + this.clazz = clazz; + this.schema = schema; + this.json = json; + } + + public ObjT create() throws Exception { + ObjT instance = Instantiator.instantiate(clazz); + + Map setters = getSetters(); + List knownFieldNames = getKnownFieldNames(); + + Iterator fieldNames = json.fieldNames(); + while (fieldNames.hasNext()) { + String fieldName = fieldNames.next(); + String kafkaesqueFieldName = kafkaesqueFieldName(fieldName, knownFieldNames); + JsonNode fieldSchema = getFieldSchema(schema, kafkaesqueFieldName); + JsonNode fieldValue = json.get(fieldName); + String setterName = setterNameFromFieldName(kafkaesqueFieldName); + Method setter = getSetter(setters, setterName); + java.lang.Class parameterType = setter.getParameterTypes()[0]; + + if (fieldValue == null) { + if (parameterType.isPrimitive()) { + throw new Exception("The parameter is primitive (" + parameterType + "), " + + "but value of field " + fieldName + " is null"); + } + setter.invoke(instance, (Object) null); + } else if (parameterType.equals(byte.class)) { + setter.invoke(instance, getByte(fieldValue, fieldName)); + } else if (parameterType.equals(short.class)) { + if (fieldValue.isNull()) { + if (isTaggedVersion(fieldSchema)) { + // Do nothing for the tagged version. + continue; + } else { + throw new Exception("Unexpected null for non-tagged field " + fieldName); + } + } + setter.invoke(instance, getShort(fieldValue, fieldName)); + } else if (parameterType.equals(int.class)) { + setter.invoke(instance, getInt(fieldValue, fieldName)); + } else if (parameterType.equals(long.class)) { + setter.invoke(instance, getLong(fieldValue, fieldName)); + } else if (parameterType.equals(double.class)) { + setter.invoke(instance, getDouble(fieldValue, fieldName)); + } else if (parameterType.equals(boolean.class)) { + setter.invoke(instance, getBoolean(fieldValue, fieldName)); + } else if (parameterType.isPrimitive()) { + // Add the handling fo this type if you face this. + throw new Exception("Unsupported primitive type " + parameterType); + } else if (parameterType.isArray()) { + if (parameterType.componentType().equals(byte.class)) { + setter.invoke(instance, (Object) getBytes(fieldValue, fieldName)); + } else { + throw new Exception("Unsupported array type " + parameterType); + } + } else if (parameterType.equals(String.class)) { + setter.invoke(instance, getString(fieldValue, fieldName)); + } else if (parameterType.equals(Uuid.class)) { + Uuid uuid = getUuid(fieldValue, fieldName); + if (uuid != null) { + setter.invoke(instance, uuid); + } + } else if (parameterType.equals(ByteBuffer.class)) { + setter.invoke(instance, getByteBuffer(fieldValue, fieldName)); + } else if (AbstractCollection.class.isAssignableFrom(parameterType)) { + AbstractCollection collection = CollectionCreator.createAbstractCollection( + rootClazz, version, commonStructs, + (Class>) parameterType, + fieldValue, fieldName, fieldSchema + ); + setter.invoke(instance, collection); + } else if (List.class.isAssignableFrom(parameterType)) { + List list = CollectionCreator.createList( + rootClazz, version, commonStructs, + fieldValue, fieldName, fieldSchema); + setter.invoke(instance, list); + } else if (BaseRecords.class.isAssignableFrom(parameterType)) { + throw new Exception("Not implemented"); + } else { + Object o = ObjectCreator.create( + rootClazz, version, commonStructs, parameterType, fieldSchema, fieldValue); + setter.invoke(instance, o); + } + } + + return instance; + } + + private boolean isTaggedVersion(JsonNode fieldSchema) { + JsonNode taggedVersions = fieldSchema.get("taggedVersions"); + if (taggedVersions == null) { + return false; + } + short taggedVersionFrom = Short.parseShort(taggedVersions.asText().replace("+", "")); + return version >= taggedVersionFrom; + } + + private Map getSetters() { + return Arrays.stream(clazz.getDeclaredMethods()) + .filter(m -> m.getName().startsWith("set")) + .collect(Collectors.toMap(Method::getName, m -> m)); + } + + public List getKnownFieldNames() { + Iterator fields; + if (schema.get("fields") != null) { + fields = schema.get("fields").elements(); + } else { + String type = schema.get("type").asText().replace("[]", ""); + fields = commonStructs.get(type).get("fields").elements(); + } + + List knownNames = new ArrayList<>(); + + while (fields.hasNext()) { + JsonNode field = fields.next(); + String name = field.get("name").asText(); + knownNames.add(name); + } + return knownNames; + } + + private static String kafkaesqueFieldName(String fieldName, List knownFieldNames) { + switch (fieldName) { + case "timeout" -> fieldName = "timeout_ms"; + case "throttle_time" -> fieldName = "throttle_time_ms"; + case "max_wait" -> fieldName = "max_wait_ms"; + case "session_lifetime" -> fieldName = "session_lifetime_ms"; + case "transaction_timeout" -> fieldName = "transaction_timeout_ms"; + case "max_lifetime" -> fieldName = "max_lifetime_ms"; + case "session_timeout" -> fieldName = "session_timeout_ms"; + case "rebalance_timeout" -> fieldName = "rebalance_timeout_ms"; + case "expiry_time_period" -> fieldName = "expiry_time_period_ms"; + case "renew_period" -> fieldName = "renew_period_ms"; + case "retention_time" -> fieldName = "retention_time_ms"; + case "heartbeat_interval" -> fieldName = "heartbeat_interval_ms"; + case "issue_timestamp" -> fieldName = "issue_timestamp_ms"; + case "expiry_timestamp" -> fieldName = "expiry_timestamp_ms"; + case "max_timestamp" -> fieldName = "max_timestamp_ms"; + case "transaction_start_time" -> fieldName = "transaction_start_time_ms"; + case "log_append_time" -> fieldName = "log_append_time_ms"; + } + + fieldName = CaseUtils.toCamelCase(fieldName, true, '_'); + + if (!knownFieldNames.contains(fieldName)) { + switch (fieldName) { + case "IssueTimestampMs" -> fieldName = "IssueTimestamp"; + case "ExpiryTimestampMs" -> fieldName = "ExpiryTimestamp"; + case "MaxTimestampMs" -> fieldName = "MaxTimestamp"; + } + } + + return fieldName; + } + + private static String setterNameFromFieldName(String kafkaesqueFieldName) { + return "set" + kafkaesqueFieldName; + } + + private static Method getSetter(Map setters, String setterName) throws Exception { + Method method = setters.get(setterName); + if (method == null) { + throw new Exception("Setter method " + setterName + " not found"); + } + + if (method.getParameterCount() != 1) { + throw new Exception("Invalid number of parameters in " + setterName); + } + + return method; + } + + private JsonNode getFieldSchema(JsonNode schema, String fieldName) throws Exception { + Iterator fields; + if (schema.get("fields") != null) { + fields = schema.get("fields").elements(); + } else { + String type = schema.get("type").asText().replace("[]", ""); + fields = commonStructs.get(type).get("fields").elements(); + } + + while (fields.hasNext()) { + JsonNode field = fields.next(); + if (field.get("name").asText().equals(fieldName)) { + return field; + } + } + + return switch (fieldName) { + case "TimeoutMs" -> getFieldSchema(schema, "timeoutMs"); + case "ValidateOnly" -> getFieldSchema(schema, "validateOnly"); + case "MemberAssignment" -> getFieldSchema(schema, "memberAssignment"); + case "GroupId" -> getFieldSchema(schema, "groupId"); + case "IsKRaftController" -> getFieldSchema(schema, "isKRaftController"); + default -> throw new Exception("field " + fieldName + " not found"); + }; + } +} diff --git a/java_tester/src/main/java/io/aiven/kio/java_tester/TestCase.java b/java_tester/src/main/java/io/aiven/kio/java_tester/TestCase.java new file mode 100644 index 00000000..406c07bb --- /dev/null +++ b/java_tester/src/main/java/io/aiven/kio/java_tester/TestCase.java @@ -0,0 +1,158 @@ +package io.aiven.kio.java_tester; + +import java.io.IOException; +import java.io.InputStream; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Base64; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; + +import org.apache.kafka.common.protocol.ApiMessage; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.protocol.Readable; + +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; + +class TestCase { + private static final String SUCCESS_RESPONSE = "{\"success\": true}"; + + private static final ObjectMapper OBJECT_MAPPER; + + static { + JsonFactory jsonFactory = new JsonFactory(); + jsonFactory.enable(JsonParser.Feature.ALLOW_COMMENTS); + OBJECT_MAPPER = new ObjectMapper(jsonFactory); + } + + private final String caseStr; + private JsonNode caseNode; + private short version; + private String shortClassName; + + private JsonNode rootSchema; + private Map commonStructs; + private Class rootClazz; + + TestCase(String caseStr) { + this.caseStr = caseStr; + } + + String process() { + try { + caseNode = OBJECT_MAPPER.readTree(caseStr); + version = caseNode.get("version").shortValue(); + shortClassName = caseNode.get("class").asText(); + + rootSchema = getRootSchema(); + commonStructs = getCommonStructs(); + rootClazz = getRootClass(); + + ObjectSerializationCache objectSerializationCache = new ObjectSerializationCache(); + ApiMessage constructedMessage = + ObjectCreator.create(rootClazz, version, commonStructs, rootClazz, rootSchema, caseNode.get("json")); + int size = constructedMessage.size(objectSerializationCache, version); + ByteBufferAccessor writer = new ByteBufferAccessor(ByteBuffer.allocate(size)); + constructedMessage.write(writer, objectSerializationCache, version); + + byte[] serializedFromPython = Base64.getDecoder().decode(caseNode.get("serialized").asText()); + Readable readable = new ByteBufferAccessor(ByteBuffer.wrap(serializedFromPython)); + ApiMessage messageDeserializedFromPython = Instantiator.instantiate(rootClazz); + messageDeserializedFromPython.read(readable, version); + + if (!messageDeserializedFromPython.equals(constructedMessage)) { + String message = "Deserialized message is not equal to constructed\n" + + "Input: " + caseStr + "\n" + + "Deserialized: " + messageDeserializedFromPython + "\n" + + "Constructed: " + constructedMessage; + return failureResponse(message); + } else { + byte[] serializedInJava = writer.buffer().array(); + if (!Arrays.equals(serializedFromPython, serializedInJava)) { + String message = "Message serialized in Java is not equal to message serialized in Python\n" + + "Input: " + caseStr + "\n" + + "Deserialized: " + messageDeserializedFromPython + "\n" + + "Constructed: " + constructedMessage; + return failureResponse(message); + } else { + return SUCCESS_RESPONSE; + } + } + } catch (Exception e) { + return exceptionResponse(e); + } + } + + private JsonNode getRootSchema() throws IOException { + String schemaResource = "common/message/" + shortClassName + ".json"; + try (InputStream resource = TestCase.class.getClassLoader().getResourceAsStream(schemaResource)) { + return OBJECT_MAPPER.readTree(resource); + } + } + + private Map getCommonStructs() { + Map result = new HashMap<>(); + + if (rootSchema.get("commonStructs") != null) { + Iterator elements = rootSchema.get("commonStructs").elements(); + while (elements.hasNext()) { + JsonNode struct = elements.next(); + result.put(struct.get("name").asText(), struct); + } + } + + return result; + } + + private Class getRootClass() throws ClassNotFoundException { + String className = "org.apache.kafka.common.message." + shortClassName; + if (!shortClassName.equals("SnapshotFooterRecord") + && !shortClassName.equals("SnapshotHeaderRecord") + && !shortClassName.equals("ConsumerProtocolAssignment") + && !shortClassName.equals("ConsumerProtocolSubscription") + && !shortClassName.equals("LeaderChangeMessage") + && !shortClassName.equals("DefaultPrincipalData") + ) { + className += "Data"; + } + return (Class) getClass().getClassLoader().loadClass(className); + } + + private static String failureResponse(String message) { + ObjectNode objectNode = OBJECT_MAPPER.createObjectNode(); + objectNode.put("success", false); + objectNode.put("message", message); + try { + return OBJECT_MAPPER.writeValueAsString(objectNode); + } catch (JsonProcessingException e) { + // this shouldn't happen + throw new RuntimeException(e); + } + } + + private String exceptionResponse(Exception exception) { + ObjectNode objectNode = OBJECT_MAPPER.createObjectNode(); + objectNode.put("success", false); + + StringWriter stringWriter = new StringWriter(); + PrintWriter printWriter = new PrintWriter(stringWriter); + exception.printStackTrace(printWriter); + objectNode.put("exception", stringWriter.toString() + "\n" + "Case: " + caseStr); + + try { + return OBJECT_MAPPER.writeValueAsString(objectNode); + } catch (JsonProcessingException e) { + // this shouldn't happen + throw new RuntimeException(e); + } + } +} diff --git a/src/kio/serial/writers.py b/src/kio/serial/writers.py index 2df47a0e..6935d0eb 100644 --- a/src/kio/serial/writers.py +++ b/src/kio/serial/writers.py @@ -129,18 +129,18 @@ def write_nullable_legacy_string(buffer: Writable, value: str | None) -> None: write_int16(buffer, i16(-1)) return - value = value.encode() + value_b = value.encode() try: - length = i16(len(value)) + length = i16(len(value_b)) except TypeError as exception: raise OutOfBoundValue( - f"String is too long for legacy string format ({len(value)} > " + f"String is too long for legacy string format ({len(value_b)} > " f"{i16.__high__})" ) from exception write_int16(buffer, length) - buffer.write(value) + buffer.write(value_b) def write_nullable_legacy_bytes(buffer: Writable, value: bytes | None) -> None: diff --git a/src/kio/static/protocol.py b/src/kio/static/protocol.py index 08bb5d51..9260c253 100644 --- a/src/kio/static/protocol.py +++ b/src/kio/static/protocol.py @@ -12,7 +12,11 @@ class DataclassInstance(Protocol): ... -__all__ = ("Entity", "Payload") +__all__ = ("ApiMessage", "Entity", "Payload") + + +class ApiMessage: + pass class Entity(DataclassInstance, Protocol): diff --git a/tests/conftest.py b/tests/conftest.py index 34930005..589d02eb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,33 @@ +# ruff: noqa: S603 import asyncio +import base64 import contextlib +import dataclasses import io +import json import os from collections.abc import AsyncIterator from collections.abc import Iterator from contextlib import closing +from datetime import datetime +from datetime import timedelta +from json import JSONEncoder +from pathlib import Path +from subprocess import PIPE +from subprocess import Popen +from subprocess import TimeoutExpired +from typing import Any +from uuid import UUID import pytest import pytest_asyncio +from hypothesis import settings + +from kio.serial import entity_writer +from kio.static.protocol import Entity + +settings.register_profile("test", deadline=timedelta(seconds=60)) +settings.load_profile("test") def buffer() -> Iterator[io.BytesIO]: @@ -68,3 +88,94 @@ async def stream_writer( async_buffers: tuple[object, asyncio.StreamWriter], ) -> asyncio.StreamWriter: return async_buffers[1] + + +class JavaTester: + class _Encoder(JSONEncoder): + def default(self, o: Any) -> Any: + if dataclasses.is_dataclass(o): + return self._replace_tzaware_nulls(dataclasses.asdict(o)) + if isinstance(o, timedelta): + return round(o.total_seconds() * 1000) + if isinstance(o, datetime): + return round(o.timestamp() * 1000) + if isinstance(o, UUID): + return str(o) + if isinstance(o, bytes): + return base64.b64encode(o).decode("utf-8") + return super().default(o) + + def _replace_tzaware_nulls(self, o: Any) -> Any: + if isinstance(o, dict): + result = {} + for k, v in o.items(): + if k == "log_append_time" and v is None: + result[k] = -1 + else: + result[k] = self._replace_tzaware_nulls(v) + return result + elif isinstance(o, list): + return [self._replace_tzaware_nulls(e) for e in o] + elif isinstance(o, tuple): + return tuple(self._replace_tzaware_nulls(e) for e in o) + else: + return o + + def __init__(self) -> None: + cmd = [ + "docker", + "compose", + "-f", + str(Path(__file__).parent / "docker-compose-java-tester.yaml"), + "run", + "--rm", + "-i", + "java_tester", + ] + self._p = Popen(cmd, stdin=PIPE, stdout=PIPE, shell=False, text=True) + + def test(self, instance: Entity) -> None: + instance_type = type(instance) + buffer = io.BytesIO() + writer = entity_writer(instance_type) + writer(buffer, instance) + buffer.seek(0) + + case = { + "class": instance_type.__name__, + "version": instance_type.__version__, + "json": instance, + "serialized": buffer.getvalue(), + } + case_str = json.dumps(case, cls=self._Encoder) + "\n" + + assert self._p.stdin is not None + assert self._p.stdout is not None + + self._p.stdin.write(case_str) + self._p.stdin.flush() + + line = self._p.stdout.readline() + response = json.loads(line) + assert response["success"], response.get("message") or response.get("exception") + + def close(self) -> None: + if self._p.stdin is not None: + self._p.stdin.flush() + self._p.stdin.close() + if self._p.stdout is not None: + self._p.stdout.close() + try: + self._p.wait(timeout=10) + except TimeoutExpired: + self._p.kill() + self._p.wait(timeout=10) + + +@pytest.fixture(name="java_tester", scope="session") +def fixture_java_tester() -> Iterator[JavaTester]: + jt = JavaTester() + try: + yield jt + finally: + jt.close() diff --git a/tests/docker-compose-java-tester.yaml b/tests/docker-compose-java-tester.yaml new file mode 100644 index 00000000..7a7db3c4 --- /dev/null +++ b/tests/docker-compose-java-tester.yaml @@ -0,0 +1,5 @@ +version: "3.8" +services: + java_tester: + build: + context: ../java_tester diff --git a/tests/serial/test_roundtrips.py b/tests/serial/test_roundtrips.py index 22439dd2..ec5f3de2 100644 --- a/tests/serial/test_roundtrips.py +++ b/tests/serial/test_roundtrips.py @@ -35,7 +35,7 @@ from kio.serial.readers import read_uint64 from kio.serial.readers import read_unsigned_varint from kio.serial.readers import read_uuid -from kio.serial.writers import Writer, write_legacy_bytes +from kio.serial.writers import Writer from kio.serial.writers import write_boolean from kio.serial.writers import write_compact_array_length from kio.serial.writers import write_compact_string @@ -44,6 +44,7 @@ from kio.serial.writers import write_int32 from kio.serial.writers import write_int64 from kio.serial.writers import write_legacy_array_length +from kio.serial.writers import write_legacy_bytes from kio.serial.writers import write_legacy_string from kio.serial.writers import write_nullable_compact_string from kio.serial.writers import write_nullable_legacy_string