Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
ivanyu committed Oct 21, 2023
1 parent 078ca37 commit 7d01b02
Show file tree
Hide file tree
Showing 20 changed files with 876 additions and 8 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,5 @@ dmypy.json

# Project specific
/schema/

/.idea
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ include LICENSE
include NOTICE
include README.md
include pyproject.toml
recursive-exclude java_tester *
6 changes: 5 additions & 1 deletion codegen/generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from kio.static.primitive import i32Timedelta
from kio.static.primitive import i64Timedelta
from kio.static.primitive import TZAware
from kio.static.protocol import ApiMessage
from kio.static.constants import ErrorCode
'''

Expand Down Expand Up @@ -396,15 +397,17 @@ def generate_dataclass( # noqa: C901
name: str,
fields: Sequence[Field],
version: int,
top_level: bool = False,
) -> Iterator[str | CustomTypeDef | ExportName]:
if (name, version) in seen:
return
seen.add((name, version))

class_parent_str = "(ApiMessage)" if top_level else ""
class_start = textwrap.dedent(
f"""\
@dataclass(frozen=True, slots=True, kw_only=True)
class {name}:
class {name}{class_parent_str}:
"""
)
class_fields = []
Expand Down Expand Up @@ -506,6 +509,7 @@ def generate_models(
name=schema.name,
fields=schema.fields,
version=version,
top_level=True,
):
match item:
case CustomTypeDef() as instruction:
Expand Down
23 changes: 22 additions & 1 deletion codegen/generate_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from types import ModuleType

import kio.schema
from kio.static.protocol import ApiMessage

from .case import to_snake_case

Expand Down Expand Up @@ -48,7 +49,7 @@ def get_entities() -> Iterator[tuple[type, str]]:
from hypothesis import given, settings
from hypothesis.strategies import from_type
from kio.serial import entity_writer
from tests.conftest import setup_buffer
from tests.conftest import setup_buffer, JavaTester
from kio.serial import entity_reader
from typing import Final
"""
Expand All @@ -72,6 +73,12 @@ def test_{entity_snake_case}_roundtrip(instance: {entity_type}) -> None:
assert instance == result
"""

test_code_java = """\
@given(instance=from_type({entity_type}))
def test_{entity_snake_case}_java(instance: {entity_type}, java_tester: JavaTester) -> None:
java_tester.test(instance)
"""


base_dir = Path(__file__).parent.parent.resolve()
generated_tests_module = base_dir / "tests" / "generated"
Expand Down Expand Up @@ -106,6 +113,20 @@ def main() -> None:
entity_snake_case=to_snake_case(entity_type.__name__),
)
)
if issubclass(entity_type, ApiMessage) and entity_type.__name__ not in {
"ProduceRequest", # Records
"FetchResponse", # Records
"FetchSnapshotResponse", # Records
"CreateTopicsResponse", # Should not output tagged field if its value equals to default
"FetchRequest", # Should not output tagged field if its value equals to default (presumably)
"ConsumerGroupHeartbeatResponse", # Nullable `assignment` field
}:
module_code[module_path].append(
test_code_java.format(
entity_type=entity_type.__name__,
entity_snake_case=to_snake_case(entity_type.__name__),
)
)

for module_path, entity_imports in module_imports.items():
with module_path.open("w") as fd:
Expand Down
11 changes: 11 additions & 0 deletions java_tester/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Ignore Gradle project-specific cache directory
.gradle

# Ignore Gradle build output directory
build

.idea

/gradle/wrapper
/gradlew
/gradlew.bat
13 changes: 13 additions & 0 deletions java_tester/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
FROM gradle:jdk17-alpine AS build

RUN mkdir /java_tester
WORKDIR /java_tester
ADD src src
ADD build.gradle build.gradle

RUN gradle --no-daemon distTar && tar -xvf build/distributions/java_tester.tar

FROM eclipse-temurin:17-jdk
COPY --from=build /java_tester/java_tester /java_tester

CMD /java_tester/bin/java_tester
3 changes: 3 additions & 0 deletions java_tester/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
.PHONY: docker_image
docker_image:
docker build -t java_tester .
24 changes: 24 additions & 0 deletions java_tester/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
plugins {
id 'application'
}

java {
toolchain {
languageVersion = JavaLanguageVersion.of(17)
}
}

repositories {
mavenCentral()
}

dependencies {
implementation 'org.apache.kafka:kafka-clients:3.5.1'
implementation 'com.fasterxml.jackson.core:jackson-databind:2.15.2'
implementation 'org.apache.commons:commons-text:1.9'
}

application {
mainClass = 'io.aiven.kio.java_tester.JavaTester'
}
run { standardInput = System.in }
1 change: 1 addition & 0 deletions java_tester/settings.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
rootProject.name = 'java_tester'
118 changes: 118 additions & 0 deletions java_tester/src/main/java/io/aiven/kio/java_tester/BaseCreator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package io.aiven.kio.java_tester;

import java.nio.ByteBuffer;
import java.util.Base64;
import java.util.Map;
import java.util.UUID;

import org.apache.kafka.common.Uuid;

import com.fasterxml.jackson.databind.JsonNode;

abstract class BaseCreator<RootT> {
protected final Class<RootT> rootClazz;
protected final short version;
protected final Map<String, JsonNode> commonStructs;

BaseCreator(Class<RootT> rootClazz, short version, Map<String, JsonNode> commonStructs) {
this.rootClazz = rootClazz;
this.version = version;
this.commonStructs = commonStructs;
}

protected static byte getByte(JsonNode fieldValue, String fieldName) throws Exception {
if (!fieldValue.isIntegralNumber()) {
throw new Exception("Expected byte in field " + fieldName + " but got " + fieldValue);
}

long longValue = fieldValue.asLong();
if ((byte) longValue != longValue) {
throw new Exception("Invalid byte value in field " + fieldName + ": " + longValue);
}

return (byte) longValue;
}

protected static short getShort(JsonNode fieldValue, String fieldName) throws Exception {
if (!fieldValue.isIntegralNumber()) {
throw new Exception("Expected short in field " + fieldName + " but got " + fieldValue);
}

long longValue = fieldValue.asLong();
if ((short) longValue != longValue) {
throw new Exception("Invalid short value in field " + fieldName + ": " + longValue);
}

return (short) longValue;
}

protected static int getInt(JsonNode fieldValue, String fieldName) throws Exception {
if (!fieldValue.isIntegralNumber()) {
throw new Exception("Expected int in field " + fieldName + " but got " + fieldValue);
}

long longValue = fieldValue.asLong();
if ((int) longValue != longValue) {
throw new Exception("Invalid int value in field " + fieldName + ": " + longValue);
}

return (int) longValue;
}

protected static long getLong(JsonNode fieldValue, String fieldName) throws Exception {
if (!fieldValue.isIntegralNumber()) {
throw new Exception("Expected int in field " + fieldName + " but got " + fieldValue);
}
return fieldValue.asLong();
}

protected static double getDouble(JsonNode fieldValue, String fieldName) throws Exception {
if (!fieldValue.isFloatingPointNumber()) {
throw new Exception("Expected double in field " + fieldName + " but got " + fieldValue);
}
return fieldValue.asDouble();
}

protected static boolean getBoolean(JsonNode fieldValue, String fieldName) throws Exception {
if (!fieldValue.isBoolean()) {
throw new Exception("Expected boolean in field " + fieldName + " but got " + fieldValue);
}
return fieldValue.asBoolean();
}

protected static String getString(JsonNode fieldValue, String fieldName) throws Exception {
if (fieldValue.isNull()) {
return null;
}

if (!fieldValue.isTextual()) {
throw new Exception("Expected string in field " + fieldName + " but got " + fieldValue);
}
return fieldValue.asText();
}

protected static Uuid getUuid(JsonNode fieldValue, String fieldName) throws Exception {
String str = getString(fieldValue, fieldName);
if (str == null) {
return null;
}
UUID tmpUuid = UUID.fromString(str);
return new Uuid(tmpUuid.getMostSignificantBits(), tmpUuid.getLeastSignificantBits());
}

protected static ByteBuffer getByteBuffer(JsonNode fieldValue, String fieldName) throws Exception {
String str = getString(fieldValue, fieldName);
if (str == null) {
return null;
}
return ByteBuffer.wrap(Base64.getDecoder().decode(str));
}

protected static byte[] getBytes(JsonNode fieldValue, String fieldName) throws Exception {
String str = getString(fieldValue, fieldName);
if (str == null) {
return null;
}
return Base64.getDecoder().decode(str);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package io.aiven.kio.java_tester;

import java.lang.reflect.Method;
import java.util.AbstractCollection;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import com.fasterxml.jackson.databind.JsonNode;

class CollectionCreator<RootT> extends BaseCreator<RootT> {
private final JsonNode fieldValue;
private final String fieldName;
private final JsonNode fieldSchema;

private CollectionCreator(Class<RootT> rootClazz, short version, Map<String, JsonNode> commonStructs,
JsonNode fieldValue, String fieldName, JsonNode fieldSchema) {
super(rootClazz, version, commonStructs);
this.fieldValue = fieldValue;
this.fieldName = fieldName;
this.fieldSchema = fieldSchema;
}

static <RootT> AbstractCollection<Object> createAbstractCollection(
Class<RootT> rootClazz, short version, Map<String, JsonNode> commonStructs,
Class<AbstractCollection<Object>> collectionClazz, JsonNode fieldValue, String fieldName, JsonNode fieldSchema
) throws Exception {
if (!fieldValue.isArray()) {
throw new Exception("The value of " + fieldName + " must be array but was " + fieldValue);
}
CollectionCreator<RootT> creator =
new CollectionCreator<>(rootClazz, version, commonStructs, fieldValue, fieldName, fieldSchema);
return creator.createAbstractCollection(collectionClazz);
}

private AbstractCollection<Object> createAbstractCollection(
Class<AbstractCollection<Object>> collectionClazz
) throws Exception {
AbstractCollection<Object> collection = Instantiator.instantiate(collectionClazz);
Class<?> elementClazz = getCollectionElementClass(collectionClazz);
fillCollectionFromChildren(elementClazz, collection);
return collectionClazz.cast(collection);
}

private static Class<Object> getCollectionElementClass(Class<AbstractCollection<Object>> collectionClazz)
throws Exception {
// Try to estimate the element class based on the `find` method.
// `find` is expected to be present in all `AbstractCollection`s of interest.
for (Method method : collectionClazz.getDeclaredMethods()) {
if (method.getName().equals("find")) {
return (Class<Object>) method.getReturnType();
}
}
throw new Exception("No 'find' method for " + collectionClazz);
}

static <RootT> List<?> createList(
Class<RootT> rootClazz, short version, Map<String, JsonNode> commonStructs,
JsonNode fieldValue, String fieldName, JsonNode fieldSchema
) throws Exception {
if (!fieldValue.isArray()) {
throw new Exception("The value of " + fieldName + " must be array but was " + fieldValue);
}
CollectionCreator<RootT> creator =
new CollectionCreator<>(rootClazz, version, commonStructs, fieldValue, fieldName, fieldSchema);
return creator.createList();
}

private List<?> createList() throws Exception {
final String elementTypeInSchema;
{
String tmp = fieldSchema.get("type").asText();
if (!tmp.startsWith("[]")) {
throw new Exception("Unexpected type " + tmp);
}
elementTypeInSchema = tmp.substring(2);
}

Class<?> elementClazz = switch (elementTypeInSchema) {
case "int8" -> Byte.class;
case "int16" -> Short.class;
case "int32" -> Integer.class;
case "int64" -> Long.class;
case "string" -> String.class;
default -> Arrays.stream(rootClazz.getDeclaredClasses())
.filter(c -> c.getName().endsWith("$" + elementTypeInSchema))
.findFirst().get();
};

List<Object> list = new ArrayList<>();
fillCollectionFromChildren(elementClazz, list);
return list;
}

private void fillCollectionFromChildren(
Class<?> elementClazz, Collection<Object> collection
) throws Exception {
if (!fieldValue.isArray()) {
throw new Exception("The value of " + fieldName + " must be array but was " + fieldValue);
}

Iterator<JsonNode> elements = fieldValue.elements();
while (elements.hasNext()) {
JsonNode elementValue = elements.next();
Object elementObj;
if (elementClazz.equals(Byte.class)) {
elementObj = getByte(elementValue, fieldName);
} else if (elementClazz.equals(Short.class)) {
elementObj = getShort(elementValue, fieldName);
} else if (elementClazz.equals(Integer.class)) {
elementObj = getInt(elementValue, fieldName);
} else if (elementClazz.equals(Long.class)) {
elementObj = getLong(elementValue, fieldName);
} else if (elementClazz.equals(String.class)) {
elementObj = getString(elementValue, fieldName);
} else {
elementObj = ObjectCreator.create(
rootClazz, version, commonStructs, elementClazz, fieldSchema, elementValue);
}
collection.add(elementObj);
}
}
}
Loading

0 comments on commit 7d01b02

Please sign in to comment.