diff --git a/avro-fastserde/src/main/java/com/linkedin/avro/api/PrimitiveFloatList.java b/avro-fastserde/src/main/java/com/linkedin/avro/api/PrimitiveFloatList.java new file mode 100644 index 000000000..8e87a9e48 --- /dev/null +++ b/avro-fastserde/src/main/java/com/linkedin/avro/api/PrimitiveFloatList.java @@ -0,0 +1,11 @@ +package com.linkedin.avro.api; + +import java.util.List; + +/** + * A {@link List} implementation with additional functions to prevent boxing. + */ +public interface PrimitiveFloatList extends List { + float getPrimitive(int index); + boolean addPrimitive(float o); +} diff --git a/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/PrimitiveFloatList.java b/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/ByteBufferBackedPrimitiveFloatList.java similarity index 87% rename from avro-fastserde/src/main/java/com/linkedin/avro/fastserde/PrimitiveFloatList.java rename to avro-fastserde/src/main/java/com/linkedin/avro/fastserde/ByteBufferBackedPrimitiveFloatList.java index d80548d40..7effd4f0a 100644 --- a/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/PrimitiveFloatList.java +++ b/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/ByteBufferBackedPrimitiveFloatList.java @@ -1,5 +1,6 @@ package com.linkedin.avro.fastserde; +import com.linkedin.avro.api.PrimitiveFloatList; import java.io.IOException; import java.nio.ByteBuffer; import java.util.AbstractList; @@ -33,8 +34,8 @@ * * TODO: Provide arrays for other primitive types. */ -public class PrimitiveFloatList extends AbstractList - implements GenericArray, Comparable> { +public class ByteBufferBackedPrimitiveFloatList extends AbstractList + implements GenericArray, Comparable>, PrimitiveFloatList { private static final float[] EMPTY = new float[0]; private static final int FLOAT_SIZE = Float.BYTES; private static final Schema FLOAT_SCHEMA = Schema.create(Schema.Type.FLOAT); @@ -44,7 +45,7 @@ public class PrimitiveFloatList extends AbstractList private boolean isCached = false; private CompositeByteBuffer byteBuffer; - public PrimitiveFloatList(int capacity) { + public ByteBufferBackedPrimitiveFloatList(int capacity) { if (capacity != 0) { elements = new float[capacity]; } @@ -52,7 +53,7 @@ public PrimitiveFloatList(int capacity) { byteBuffer = new CompositeByteBuffer(capacity != 0); } - public PrimitiveFloatList(Collection c) { + public ByteBufferBackedPrimitiveFloatList(Collection c) { if (c != null) { elements = new float[c.size()]; addAll(c); @@ -61,13 +62,13 @@ public PrimitiveFloatList(Collection c) { } /** - * Instantiate (or re-use) and populate a {@link PrimitiveFloatList} from a {@link org.apache.avro.io.Decoder}. + * Instantiate (or re-use) and populate a {@link ByteBufferBackedPrimitiveFloatList} from a {@link org.apache.avro.io.Decoder}. * * N.B.: the caller must ensure the data is of the appropriate type by calling {@link #isFloatArray(Schema)}. * - * @param old old {@link PrimitiveFloatList} to reuse + * @param old old {@link ByteBufferBackedPrimitiveFloatList} to reuse * @param in {@link org.apache.avro.io.Decoder} to read new list from - * @return a {@link PrimitiveFloatList} with data, possibly the old argument reused + * @return a {@link ByteBufferBackedPrimitiveFloatList} with data, possibly the old argument reused * @throws IOException on io errors */ public static Object readPrimitiveFloatArray(Object old, Decoder in) throws IOException { @@ -75,7 +76,7 @@ public static Object readPrimitiveFloatArray(Object old, Decoder in) throws IOEx long totalLength = 0; if (length > 0) { - PrimitiveFloatList array = (PrimitiveFloatList) newPrimitiveFloatArray(old); + ByteBufferBackedPrimitiveFloatList array = (ByteBufferBackedPrimitiveFloatList) newPrimitiveFloatArray(old); int index = 0; do { @@ -90,11 +91,11 @@ public static Object readPrimitiveFloatArray(Object old, Decoder in) throws IOEx setupElements(array, (int)totalLength); return array; } else { - return new PrimitiveFloatList(0); + return new ByteBufferBackedPrimitiveFloatList(0); } } - private static void setupElements(PrimitiveFloatList list, int totalSize) { + private static void setupElements(ByteBufferBackedPrimitiveFloatList list, int totalSize) { if (list.elements.length != 0) { if (totalSize <= list.getCapacity()) { // reuse the float array directly @@ -111,7 +112,7 @@ private static void setupElements(PrimitiveFloatList list, int totalSize) { /** * @param expected {@link Schema} to inspect - * @return true if the {@code expected} SCHEMA is of the right type to decode as a {@link PrimitiveFloatList} + * @return true if the {@code expected} SCHEMA is of the right type to decode as a {@link ByteBufferBackedPrimitiveFloatList} * false otherwise */ public static boolean isFloatArray(Schema expected) { @@ -120,15 +121,15 @@ public static boolean isFloatArray(Schema expected) { } private static Object newPrimitiveFloatArray(Object old) { - if (old instanceof PrimitiveFloatList) { - PrimitiveFloatList oldFloatList = (PrimitiveFloatList) old; + if (old instanceof ByteBufferBackedPrimitiveFloatList) { + ByteBufferBackedPrimitiveFloatList oldFloatList = (ByteBufferBackedPrimitiveFloatList) old; oldFloatList.byteBuffer.clear(); oldFloatList.isCached = false; oldFloatList.size = 0; return oldFloatList; } else { // Just a place holder, will set up the elements later. - return new PrimitiveFloatList(0); + return new ByteBufferBackedPrimitiveFloatList(0); } } @@ -282,8 +283,8 @@ public Float peek() { @Override public int compareTo(GenericArray that) { cacheFromByteBuffer(); - if (that instanceof PrimitiveFloatList) { - PrimitiveFloatList thatPrimitiveList = (PrimitiveFloatList) that; + if (that instanceof ByteBufferBackedPrimitiveFloatList) { + ByteBufferBackedPrimitiveFloatList thatPrimitiveList = (ByteBufferBackedPrimitiveFloatList) that; if (this.size == thatPrimitiveList.size) { for (int i = 0; i < this.size; i++) { int compare = Float.compare(this.elements[i], thatPrimitiveList.elements[i]); diff --git a/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/FastDeserializerGenerator.java b/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/FastDeserializerGenerator.java index 7a6925518..8733c161d 100644 --- a/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/FastDeserializerGenerator.java +++ b/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/FastDeserializerGenerator.java @@ -589,10 +589,10 @@ private void processArray(JVar arraySchemaVar, final String name, final Schema a final JVar arrayVar = action.getShouldRead() ? declareValueVar(name, arraySchema, parentBody) : null; /** - * Special optimization for float array by leveraging {@link PrimitiveFloatList}. + * Special optimization for float array by leveraging {@link ByteBufferBackedPrimitiveFloatList}. */ if (action.getShouldRead() && arraySchema.getElementType().getType().equals(Schema.Type.FLOAT)) { - JClass primitiveFloatList = codeModel.ref(PrimitiveFloatList.class); + JClass primitiveFloatList = codeModel.ref(ByteBufferBackedPrimitiveFloatList.class); JExpression readPrimitiveFloatArrayInvocation = primitiveFloatList.staticInvoke("readPrimitiveFloatArray"). arg(reuseSupplier.get()).arg(JExpr.direct(DECODER)); JExpression castedResult = diff --git a/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/FastSerdeCache.java b/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/FastSerdeCache.java index f6e922880..bbfb22676 100644 --- a/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/FastSerdeCache.java +++ b/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/FastSerdeCache.java @@ -1,5 +1,7 @@ package com.linkedin.avro.fastserde; +import org.apache.avro.generic.ColdGenericDatumReader; +import org.apache.avro.generic.ColdSpecificDatumReader; import java.io.File; import java.io.IOException; import java.lang.reflect.ParameterizedType; @@ -478,7 +480,7 @@ public static class FastDeserializerWithAvroSpecificImpl implements FastDeser private final SpecificDatumReader datumReader; public FastDeserializerWithAvroSpecificImpl(Schema writerSchema, Schema readerSchema) { - this.datumReader = new SpecificDatumReader<>(writerSchema, readerSchema); + this.datumReader = new ColdSpecificDatumReader<>(writerSchema, readerSchema); } @Override @@ -491,7 +493,7 @@ public static class FastDeserializerWithAvroGenericImpl implements FastDeseri private final GenericDatumReader datumReader; public FastDeserializerWithAvroGenericImpl(Schema writerSchema, Schema readerSchema) { - this.datumReader = new GenericDatumReader<>(writerSchema, readerSchema); + this.datumReader = new ColdGenericDatumReader<>(writerSchema, readerSchema); } @Override diff --git a/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/coldstart/ColdPrimitiveFloatList.java b/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/coldstart/ColdPrimitiveFloatList.java new file mode 100644 index 000000000..b389bd8bb --- /dev/null +++ b/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/coldstart/ColdPrimitiveFloatList.java @@ -0,0 +1,22 @@ +package com.linkedin.avro.fastserde.coldstart; + +import com.linkedin.avro.api.PrimitiveFloatList; +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericData; + +public class ColdPrimitiveFloatList extends GenericData.Array implements PrimitiveFloatList { + private static final Schema SCHEMA = Schema.createArray(Schema.create(Schema.Type.FLOAT)); + public ColdPrimitiveFloatList(int capacity) { + super(capacity, SCHEMA); + } + + @Override + public float getPrimitive(int index) { + return get(index); + } + + @Override + public boolean addPrimitive(float o) { + return add(o); + } +} diff --git a/avro-fastserde/src/main/java/org/apache/avro/generic/ColdDatumReaderMixIn.java b/avro-fastserde/src/main/java/org/apache/avro/generic/ColdDatumReaderMixIn.java new file mode 100644 index 000000000..dcba7c532 --- /dev/null +++ b/avro-fastserde/src/main/java/org/apache/avro/generic/ColdDatumReaderMixIn.java @@ -0,0 +1,27 @@ +package org.apache.avro.generic; + +import com.linkedin.avro.fastserde.coldstart.ColdPrimitiveFloatList; +import org.apache.avro.Schema; + + +/** + * An interface with default implementation in order to defeat the lack of multiple inheritance. + */ +public interface ColdDatumReaderMixIn { + default Object newArray(Object old, int size, Schema schema, NewArrayFunction fallBackFunction) { + switch (schema.getElementType().getType()) { + case FLOAT: + if (null == old || !(old instanceof ColdPrimitiveFloatList)) { + old = new ColdPrimitiveFloatList(size); + } + return old; + // TODO: Add more cases when we support more primitive array types + default: + return fallBackFunction.newArray(old, size, schema); + } + } + + interface NewArrayFunction { + Object newArray(Object old, int size, Schema schema); + } +} diff --git a/avro-fastserde/src/main/java/org/apache/avro/generic/ColdGenericDatumReader.java b/avro-fastserde/src/main/java/org/apache/avro/generic/ColdGenericDatumReader.java new file mode 100644 index 000000000..2ecb60769 --- /dev/null +++ b/avro-fastserde/src/main/java/org/apache/avro/generic/ColdGenericDatumReader.java @@ -0,0 +1,21 @@ +package org.apache.avro.generic; + +import org.apache.avro.Schema; + + +/** + * A light-weight extension of {@link GenericDatumReader} which merely ensures that the types of the + * extended API are always returned. + * + * This class needs to be in the org.apache.avro.generic package in order to access protected methods. + */ +public class ColdGenericDatumReader extends GenericDatumReader implements ColdDatumReaderMixIn { + public ColdGenericDatumReader(Schema writerSchema, Schema readerSchema) { + super(writerSchema, readerSchema); + } + + @Override + protected Object newArray(Object old, int size, Schema schema) { + return newArray(old, size, schema, super::newArray); + } +} diff --git a/avro-fastserde/src/main/java/org/apache/avro/generic/ColdSpecificDatumReader.java b/avro-fastserde/src/main/java/org/apache/avro/generic/ColdSpecificDatumReader.java new file mode 100644 index 000000000..008b97eb7 --- /dev/null +++ b/avro-fastserde/src/main/java/org/apache/avro/generic/ColdSpecificDatumReader.java @@ -0,0 +1,22 @@ +package org.apache.avro.generic; + +import org.apache.avro.Schema; +import org.apache.avro.specific.SpecificDatumReader; + + +/** + * A light-weight extension of {@link SpecificDatumReader} which merely ensures that the types of + * the extended API are always returned. + * + * This class needs to be in the org.apache.avro.generic package in order to access protected methods. + */ +public class ColdSpecificDatumReader extends SpecificDatumReader implements ColdDatumReaderMixIn { + public ColdSpecificDatumReader(Schema writerSchema, Schema readerSchema) { + super(writerSchema, readerSchema); + } + + @Override + protected Object newArray(Object old, int size, Schema schema) { + return newArray(old, size, schema, super::newArray); + } +} diff --git a/avro-fastserde/src/test/java/com/linkedin/avro/fastserde/FastDeserializerDefaultsTest.java b/avro-fastserde/src/test/java/com/linkedin/avro/fastserde/FastDeserializerDefaultsTest.java index c490c9d75..94df3d737 100644 --- a/avro-fastserde/src/test/java/com/linkedin/avro/fastserde/FastDeserializerDefaultsTest.java +++ b/avro-fastserde/src/test/java/com/linkedin/avro/fastserde/FastDeserializerDefaultsTest.java @@ -87,7 +87,7 @@ public void testPrimitiveFloatListAddPrimitive() { long startTime = System.currentTimeMillis(); for (int i = 0; i < iteration; i++) { - PrimitiveFloatList list = new PrimitiveFloatList(array_size); + ByteBufferBackedPrimitiveFloatList list = new ByteBufferBackedPrimitiveFloatList(array_size); for (int l = 0; l < array_size; l++) { list.addPrimitive((float) l); diff --git a/avro-fastserde/src/test/java/com/linkedin/avro/fastserde/FastGenericDeserializerGeneratorTest.java b/avro-fastserde/src/test/java/com/linkedin/avro/fastserde/FastGenericDeserializerGeneratorTest.java index 7fdc80913..1ac208931 100644 --- a/avro-fastserde/src/test/java/com/linkedin/avro/fastserde/FastGenericDeserializerGeneratorTest.java +++ b/avro-fastserde/src/test/java/com/linkedin/avro/fastserde/FastGenericDeserializerGeneratorTest.java @@ -1,5 +1,6 @@ package com.linkedin.avro.fastserde; +import com.linkedin.avro.api.PrimitiveFloatList; import com.linkedin.avroutil1.compatibility.AvroCompatibilityHelper; import java.io.File; import java.io.IOException; @@ -13,6 +14,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import org.apache.avro.Schema; import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericDatumReader; @@ -20,6 +22,7 @@ import org.apache.avro.io.Decoder; import org.apache.avro.util.Utf8; import org.testng.Assert; +import org.testng.SkipException; import org.testng.annotations.BeforeTest; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; @@ -29,9 +32,43 @@ public class FastGenericDeserializerGeneratorTest { - private File tempDir; - private ClassLoader classLoader; + private static File tempDir; + private static ClassLoader classLoader; + enum Implementation { + VANILLA_AVRO(false, FastGenericDeserializerGeneratorTest::decodeRecordSlow), + COLD_FAST_AVRO(true, FastGenericDeserializerGeneratorTest::decodeRecordColdFast), + WARM_FAST_AVRO(true, FastGenericDeserializerGeneratorTest::decodeRecordWarmFast); + + boolean isFast; + DecodeFunction decodeFunction; + Implementation(boolean isFast, DecodeFunction decodeFunction) { + this.isFast = isFast; + this.decodeFunction = decodeFunction; + } + + interface DecodeFunction { + T decode(Schema writerSchema, Schema readerSchema, Decoder decoder); + } + + T decode(Schema writerSchema, Schema readerSchema, Decoder decoder) { + return (T) decodeFunction.decode(writerSchema, readerSchema, decoder); + } + } + + @DataProvider(name = "Implementation") + public static Object[][] implementations() { + return new Object[][]{ + {Implementation.VANILLA_AVRO}, + {Implementation.COLD_FAST_AVRO}, + {Implementation.WARM_FAST_AVRO} + }; + } + + /** + * @deprecated TODO Migrate to {@link #implementations()} + */ + @Deprecated @DataProvider(name = "SlowFastDeserializer") public static Object[][] deserializers() { return new Object[][]{{true}, {false}}; @@ -46,8 +83,8 @@ public void prepare() throws Exception { FastGenericDeserializerGeneratorTest.class.getClassLoader()); } - @Test(groups = {"deserializationTest"}, dataProvider = "SlowFastDeserializer") - public void shouldReadPrimitives(Boolean whetherUseFastDeserializer) { + @Test(groups = {"deserializationTest"}, dataProvider = "Implementation") + public void shouldReadPrimitives(Implementation implementation) { // given Schema recordSchema = createRecord("testRecord", createField("testInt", Schema.create(Schema.Type.INT)), createPrimitiveUnionFieldSchema("testIntUnion", Schema.Type.INT), @@ -81,12 +118,7 @@ public void shouldReadPrimitives(Boolean whetherUseFastDeserializer) { record.put("testBytesUnion", ByteBuffer.wrap(new byte[]{0x01, 0x02})); // when - GenericRecord decodedRecord = null; - if (whetherUseFastDeserializer) { - decodedRecord = decodeRecordFast(recordSchema, recordSchema, genericDataAsDecoder(record)); - } else { - decodedRecord = decodeRecordSlow(recordSchema, recordSchema, genericDataAsDecoder(record)); - } + GenericRecord decodedRecord = implementation.decode(recordSchema, recordSchema, genericDataAsDecoder(record)); // then Assert.assertEquals(1, decodedRecord.get("testInt")); @@ -111,8 +143,8 @@ public GenericData.Fixed newFixed(Schema fixedSchema, byte[] bytes) { return fixed; } - @Test(groups = {"deserializationTest"}, dataProvider = "SlowFastDeserializer") - public void shouldReadFixed(Boolean whetherUseFastDeserializer) { + @Test(groups = {"deserializationTest"}, dataProvider = "Implementation") + public void shouldReadFixed(Implementation implementation) { // given Schema fixedSchema = createFixedSchema("testFixed", 2); Schema recordSchema = createRecord("testRecord", createField("testFixed", fixedSchema), @@ -126,12 +158,7 @@ public void shouldReadFixed(Boolean whetherUseFastDeserializer) { originalRecord.put("testFixedUnionArray", Arrays.asList(newFixed(fixedSchema, new byte[]{0x07, 0x08}))); // when - GenericRecord record = null; - if (whetherUseFastDeserializer) { - record = decodeRecordFast(recordSchema, recordSchema, genericDataAsDecoder(originalRecord)); - } else { - record = decodeRecordFast(recordSchema, recordSchema, genericDataAsDecoder(originalRecord)); - } + GenericRecord record = implementation.decode(recordSchema, recordSchema, genericDataAsDecoder(originalRecord)); // then Assert.assertEquals(new byte[]{0x01, 0x02}, ((GenericData.Fixed) record.get("testFixed")).bytes()); @@ -142,8 +169,8 @@ record = decodeRecordFast(recordSchema, recordSchema, genericDataAsDecoder(origi ((List) record.get("testFixedUnionArray")).get(0).bytes()); } - @Test(groups = {"deserializationTest"}, dataProvider = "SlowFastDeserializer") - public void shouldReadEnum(Boolean whetherUseFastDeserializer) { + @Test(groups = {"deserializationTest"}, dataProvider = "Implementation") + public void shouldReadEnum(Implementation implementation) { // given Schema enumSchema = createEnumSchema("testEnum", new String[]{"A", "B"}); Schema recordSchema = @@ -162,12 +189,7 @@ public void shouldReadEnum(Boolean whetherUseFastDeserializer) { Arrays.asList(AvroCompatibilityHelper.newEnumSymbol(enumSchema, "A")));//new GenericData.EnumSymbol("A"))); // when - GenericRecord record = null; - if (whetherUseFastDeserializer) { - record = decodeRecordFast(recordSchema, recordSchema, genericDataAsDecoder(originalRecord)); - } else { - record = decodeRecordFast(recordSchema, recordSchema, genericDataAsDecoder(originalRecord)); - } + GenericRecord record = implementation.decode(recordSchema, recordSchema, genericDataAsDecoder(originalRecord)); // then Assert.assertEquals("A", record.get("testEnum").toString()); @@ -176,8 +198,8 @@ record = decodeRecordFast(recordSchema, recordSchema, genericDataAsDecoder(origi Assert.assertEquals("A", ((List) record.get("testEnumUnionArray")).get(0).toString()); } - @Test(groups = {"deserializationTest"}, dataProvider = "SlowFastDeserializer") - public void shouldReadPermutatedEnum(Boolean whetherUseFastDeserializer) { + @Test(groups = {"deserializationTest"}, dataProvider = "Implementation") + public void shouldReadPermutatedEnum(Implementation implementation) { // given Schema enumSchema = createEnumSchema("testEnum", new String[]{"A", "B", "C", "D", "E"}); Schema recordSchema = @@ -202,12 +224,7 @@ public void shouldReadPermutatedEnum(Boolean whetherUseFastDeserializer) { createArrayFieldSchema("testEnumUnionArray", createUnionSchema(enumSchema1))); // when - GenericRecord record = null; - if (whetherUseFastDeserializer) { - record = decodeRecordFast(recordSchema, recordSchema1, genericDataAsDecoder(originalRecord)); - } else { - record = decodeRecordFast(recordSchema, recordSchema1, genericDataAsDecoder(originalRecord)); - } + GenericRecord record = implementation.decode(recordSchema, recordSchema1, genericDataAsDecoder(originalRecord)); // then Assert.assertEquals("A", record.get("testEnum").toString()); @@ -230,11 +247,11 @@ public void shouldNotReadStrippedEnum() { Schema recordSchema1 = createRecord("testRecord", createField("testEnum", enumSchema1)); // when - GenericRecord record = decodeRecordFast(recordSchema, recordSchema1, genericDataAsDecoder(originalRecord)); + GenericRecord record = decodeRecordWarmFast(recordSchema, recordSchema1, genericDataAsDecoder(originalRecord)); } - @Test(groups = {"deserializationTest"}, dataProvider = "SlowFastDeserializer") - public void shouldReadSubRecordField(Boolean whetherUseFastDeserializer) { + @Test(groups = {"deserializationTest"}, dataProvider = "Implementation") + public void shouldReadSubRecordField(Implementation implementation) { // given Schema subRecordSchema = createRecord("subRecord", createPrimitiveUnionFieldSchema("subField", Schema.Type.STRING)); @@ -251,12 +268,7 @@ public void shouldReadSubRecordField(Boolean whetherUseFastDeserializer) { builder.put("field", "abc"); // when - GenericRecord record = null; - if (whetherUseFastDeserializer) { - record = decodeRecordFast(recordSchema, recordSchema, genericDataAsDecoder(builder)); - } else { - record = decodeRecordSlow(recordSchema, recordSchema, genericDataAsDecoder(builder)); - } + GenericRecord record = implementation.decode(recordSchema, recordSchema, genericDataAsDecoder(builder)); // then Assert.assertEquals(new Utf8("abc"), ((GenericRecord) record.get("record")).get("subField")); @@ -266,8 +278,8 @@ record = decodeRecordSlow(recordSchema, recordSchema, genericDataAsDecoder(build Assert.assertEquals(new Utf8("abc"), record.get("field")); } - @Test(groups = {"deserializationTest"}, dataProvider = "SlowFastDeserializer") - public void shouldReadSubRecordCollectionsField(Boolean whetherUseFastDeserializer) { + @Test(groups = {"deserializationTest"}, dataProvider = "Implementation") + public void shouldReadSubRecordCollectionsField(Implementation implementation) { // given Schema subRecordSchema = createRecord("subRecord", createPrimitiveUnionFieldSchema("subField", Schema.Type.STRING)); Schema recordSchema = createRecord("test", createArrayFieldSchema("recordsArray", subRecordSchema), @@ -289,12 +301,7 @@ public void shouldReadSubRecordCollectionsField(Boolean whetherUseFastDeserializ builder.put("recordsMapUnion", recordsMap); // when - GenericRecord record = null; - if (whetherUseFastDeserializer) { - record = decodeRecordFast(recordSchema, recordSchema, genericDataAsDecoder(builder)); - } else { - record = decodeRecordFast(recordSchema, recordSchema, genericDataAsDecoder(builder)); - } + GenericRecord record = implementation.decode(recordSchema, recordSchema, genericDataAsDecoder(builder)); // then Assert.assertEquals(new Utf8("abc"), @@ -307,8 +314,8 @@ record = decodeRecordFast(recordSchema, recordSchema, genericDataAsDecoder(build ((Map) record.get("recordsMapUnion")).get(new Utf8("1")).get("subField")); } - @Test(groups = {"deserializationTest"}, dataProvider = "SlowFastDeserializer") - public void shouldReadSubRecordComplexCollectionsField(Boolean whetherUseFastDeserializer) { + @Test(groups = {"deserializationTest"}, dataProvider = "Implementation") + public void shouldReadSubRecordComplexCollectionsField(Implementation implementation) { // given Schema subRecordSchema = createRecord("subRecord", createPrimitiveUnionFieldSchema("subField", Schema.Type.STRING)); Schema recordSchema = createRecord("test", @@ -340,12 +347,7 @@ public void shouldReadSubRecordComplexCollectionsField(Boolean whetherUseFastDes builder.put("recordsMapArrayUnion", recordsMapArray); // when - GenericRecord record = null; - if (whetherUseFastDeserializer) { - record = decodeRecordFast(recordSchema, recordSchema, genericDataAsDecoder(builder)); - } else { - record = decodeRecordSlow(recordSchema, recordSchema, genericDataAsDecoder(builder)); - } + GenericRecord record = implementation.decode(recordSchema, recordSchema, genericDataAsDecoder(builder)); // then Assert.assertEquals(new Utf8("abc"), @@ -361,36 +363,56 @@ record = decodeRecordSlow(recordSchema, recordSchema, genericDataAsDecoder(build .get("subField")); } - @Test(groups = {"deserializationTest"}, dataProvider = "SlowFastDeserializer") - public void shouldReadAliasedField(Boolean whetherUseFastDeserializer) { + @Test(groups = {"deserializationTest"}, dataProvider = "Implementation") + public void shouldReadAliasedField(Implementation implementation) { + if (Utils.isAvro14()) { + /** + * The rest of the code in this function has been "adapted" to work with 1.4's bugs, but the end result is so + * contrived that it's probably better to not run this at all. Feel free to comment the Skip and try it out... + */ + throw new SkipException("Aliases are not properly supported in Avro 1.4"); + } + // given - Schema record1Schema = createRecord("test", createPrimitiveUnionFieldSchema("testString", Schema.Type.STRING), - createPrimitiveUnionFieldSchema("testStringUnion", Schema.Type.STRING)); - Schema record2Schema = createRecord("test", createPrimitiveUnionFieldSchema("testString", Schema.Type.STRING), - addAliases(createPrimitiveUnionFieldSchema("testStringUnionAlias", Schema.Type.STRING), "testStringUnion")); + String field1Name = "testString"; + String field1Value = "abc"; + /** A Supplier is needed because Avro doesn't tolerate reusing {@link Field} instances in multiple {@link Schema}... */ + Supplier field1Supplier = () -> createPrimitiveUnionFieldSchema(field1Name, Schema.Type.STRING); + /** + * Avro 1.4's support for aliases is so broken that if any of the fields have an alias, then ALL fields must have one. + * Therefore, we are "fixing" the new schema in this weird way by having an alias which is the same as the original + * name, otherwise that field is completely gone when decoding. This is due to a bug in the 1.4 implementation of + * {@link Schema#getFieldAlias(Schema.Name, String, Map)}. + */ + Schema.Field newField1 = Utils.isAvro14() + ? addAliases(createPrimitiveUnionFieldSchema(field1Name, Schema.Type.STRING), field1Name) + : field1Supplier.get(); + + String originalField2Name = "testStringUnion"; + String newField2Name = "testStringUnionAlias"; + String field2Value = "def"; + Schema.Field originalField2 = createPrimitiveUnionFieldSchema(originalField2Name, Schema.Type.STRING); + Schema.Field newField2 = addAliases(createPrimitiveUnionFieldSchema(newField2Name, Schema.Type.STRING), originalField2Name); + + Schema record1Schema = createRecord("test", field1Supplier.get(), originalField2); + Schema record2Schema = createRecord("test", newField1, newField2); GenericData.Record builder = new GenericData.Record(record1Schema); - builder.put("testString", "abc"); - builder.put("testStringUnion", "def"); + builder.put(field1Name, field1Value); + builder.put(originalField2Name, field2Value); // when - GenericRecord record = null; - if (whetherUseFastDeserializer) { - record = decodeRecordFast(record1Schema, record2Schema, genericDataAsDecoder(builder)); - } else { - record = decodeRecordFast(record1Schema, record2Schema, genericDataAsDecoder(builder)); - } + GenericRecord record = implementation.decode(record1Schema, record2Schema, genericDataAsDecoder(builder)); // then - Assert.assertEquals(new Utf8("abc"), record.get("testString")); - // Alias is not well supported in avro-1.4 + Assert.assertEquals(record.get(field1Name), new Utf8(field1Value)); if (!Utils.isAvro14()) { - Assert.assertEquals(new Utf8("def"), record.get("testStringUnionAlias")); + Assert.assertEquals(record.get(newField2Name), new Utf8(field2Value)); } } - @Test(groups = {"deserializationTest"}, dataProvider = "SlowFastDeserializer") - public void shouldSkipRemovedField(Boolean whetherUseFastDeserializer) { + @Test(groups = {"deserializationTest"}, dataProvider = "Implementation") + public void shouldSkipRemovedField(Implementation implementation) { // given Schema subRecord1Schema = createRecord("subRecord", createPrimitiveUnionFieldSchema("testNotRemoved", Schema.Type.STRING), @@ -426,12 +448,7 @@ public void shouldSkipRemovedField(Boolean whetherUseFastDeserializer) { builder.put("subRecordMap", recordsMap); // when - GenericRecord record = null; - if (whetherUseFastDeserializer) { - record = decodeRecordFast(record1Schema, record2Schema, genericDataAsDecoder(builder)); - } else { - record = decodeRecordFast(record1Schema, record2Schema, genericDataAsDecoder(builder)); - } + GenericRecord record = implementation.decode(record1Schema, record2Schema, genericDataAsDecoder(builder)); // then Assert.assertEquals(new Utf8("abc"), record.get("testNotRemoved")); @@ -444,8 +461,8 @@ record = decodeRecordFast(record1Schema, record2Schema, genericDataAsDecoder(bui ((Map) record.get("subRecordMap")).get(new Utf8("1")).get("testNotRemoved2")); } - @Test(groups = {"deserializationTest"}, dataProvider = "SlowFastDeserializer") - public void shouldSkipRemovedRecord(Boolean whetherUseFastDeserializer) { + @Test(groups = {"deserializationTest"}, dataProvider = "Implementation") + public void shouldSkipRemovedRecord(Implementation implementation) { // given Schema subRecord1Schema = createRecord("subRecord", createPrimitiveFieldSchema("test1", Schema.Type.STRING), createPrimitiveFieldSchema("test2", Schema.Type.STRING)); @@ -474,12 +491,7 @@ public void shouldSkipRemovedRecord(Boolean whetherUseFastDeserializer) { builder.put("subRecord4", subRecordBuilder); // when - GenericRecord record = null; - if (whetherUseFastDeserializer) { - record = decodeRecordFast(record1Schema, record2Schema, genericDataAsDecoder(builder)); - } else { - record = decodeRecordFast(record1Schema, record2Schema, genericDataAsDecoder(builder)); - } + GenericRecord record = implementation.decode(record1Schema, record2Schema, genericDataAsDecoder(builder)); // then Assert.assertEquals(new Utf8("abc"), ((GenericRecord) record.get("subRecord1")).get("test1")); @@ -488,8 +500,8 @@ record = decodeRecordFast(record1Schema, record2Schema, genericDataAsDecoder(bui Assert.assertEquals(new Utf8("def"), ((GenericRecord) record.get("subRecord4")).get("test2")); } - @Test(groups = {"deserializationTest"}, dataProvider = "SlowFastDeserializer") - public void shouldSkipRemovedNestedRecord(Boolean whetherUseFastDeserializer) { + @Test(groups = {"deserializationTest"}, dataProvider = "Implementation") + public void shouldSkipRemovedNestedRecord(Implementation implementation) { // given Schema subSubRecordSchema = createRecord("subSubRecord", createPrimitiveFieldSchema("test1", Schema.Type.STRING), createPrimitiveFieldSchema("test2", Schema.Type.STRING)); @@ -517,20 +529,15 @@ public void shouldSkipRemovedNestedRecord(Boolean whetherUseFastDeserializer) { builder.put("subRecord", subRecordBuilder); // when - GenericRecord record = null; - if (whetherUseFastDeserializer) { - record = decodeRecordFast(record1Schema, record2Schema, genericDataAsDecoder(builder)); - } else { - record = decodeRecordSlow(record2Schema, record1Schema, genericDataAsDecoder(builder)); - } + GenericRecord record = implementation.decode(record1Schema, record2Schema, genericDataAsDecoder(builder)); // then Assert.assertEquals(new Utf8("abc"), ((GenericRecord) record.get("subRecord")).get("test1")); Assert.assertEquals(new Utf8("def"), ((GenericRecord) record.get("subRecord")).get("test4")); } - @Test(groups = {"deserializationTest"}, dataProvider = "SlowFastDeserializer") - public void shouldReadMultipleChoiceUnion(Boolean whetherUseFastDeserializer) { + @Test(groups = {"deserializationTest"}, dataProvider = "Implementation") + public void shouldReadMultipleChoiceUnion(Implementation implementation) { // given Schema subRecordSchema = createRecord("subRecord", createPrimitiveUnionFieldSchema("subField", Schema.Type.STRING)); @@ -544,12 +551,7 @@ public void shouldReadMultipleChoiceUnion(Boolean whetherUseFastDeserializer) { builder.put("union", subRecordBuilder); // when - GenericRecord record = null; - if (whetherUseFastDeserializer) { - record = decodeRecordFast(recordSchema, recordSchema, genericDataAsDecoder(builder)); - } else { - record = decodeRecordSlow(recordSchema, recordSchema, genericDataAsDecoder(builder)); - } + GenericRecord record = implementation.decode(recordSchema, recordSchema, genericDataAsDecoder(builder)); // then Assert.assertEquals(new Utf8("abc"), ((GenericData.Record) record.get("union")).get("subField")); @@ -559,11 +561,8 @@ record = decodeRecordSlow(recordSchema, recordSchema, genericDataAsDecoder(build builder.put("union", "abc"); // when - if (whetherUseFastDeserializer) { - record = decodeRecordFast(recordSchema, recordSchema, genericDataAsDecoder(builder)); - } else { - record = decodeRecordSlow(recordSchema, recordSchema, genericDataAsDecoder(builder)); - } + record = implementation.decode(recordSchema, recordSchema, genericDataAsDecoder(builder)); + // then Assert.assertEquals(new Utf8("abc"), record.get("union")); @@ -572,14 +571,14 @@ record = decodeRecordSlow(recordSchema, recordSchema, genericDataAsDecoder(build builder.put("union", 1); // when - record = decodeRecordFast(recordSchema, recordSchema, genericDataAsDecoder(builder)); + record = implementation.decode(recordSchema, recordSchema, genericDataAsDecoder(builder)); // then Assert.assertEquals(1, record.get("union")); } - @Test(groups = {"deserializationTest"}, dataProvider = "SlowFastDeserializer") - public void shouldReadArrayOfRecords(Boolean whetherUseFastDeserializer) { + @Test(groups = {"deserializationTest"}, dataProvider = "Implementation") + public void shouldReadArrayOfRecords(Implementation implementation) { // given Schema recordSchema = createRecord("record", createPrimitiveUnionFieldSchema("field", Schema.Type.STRING)); @@ -593,12 +592,7 @@ public void shouldReadArrayOfRecords(Boolean whetherUseFastDeserializer) { recordsArray.add(subRecordBuilder); // when - GenericData.Array array = null; - if (whetherUseFastDeserializer) { - array = decodeRecordFast(arrayRecordSchema, arrayRecordSchema, genericDataAsDecoder(recordsArray)); - } else { - array = decodeRecordSlow(arrayRecordSchema, arrayRecordSchema, genericDataAsDecoder(recordsArray)); - } + GenericData.Array array = implementation.decode(arrayRecordSchema, arrayRecordSchema, genericDataAsDecoder(recordsArray)); // then Assert.assertEquals(2, array.size()); @@ -617,19 +611,46 @@ public void shouldReadArrayOfRecords(Boolean whetherUseFastDeserializer) { recordsArray.add(subRecordBuilder); // when - if (whetherUseFastDeserializer) { - array = decodeRecordFast(arrayRecordSchema, arrayRecordSchema, genericDataAsDecoder(recordsArray)); - } else { - array = decodeRecordSlow(arrayRecordSchema, arrayRecordSchema, genericDataAsDecoder(recordsArray)); - } + array = implementation.decode(arrayRecordSchema, arrayRecordSchema, genericDataAsDecoder(recordsArray)); + // then Assert.assertEquals(2, array.size()); Assert.assertEquals(new Utf8("abc"), array.get(0).get("field")); Assert.assertEquals(new Utf8("abc"), array.get(1).get("field")); } - @Test(groups = {"deserializationTest"}, dataProvider = "SlowFastDeserializer") - public void shouldReadMapOfRecords(Boolean whetherUseFastDeserializer) { + @Test(groups = {"deserializationTest"}, dataProvider = "Implementation") + public void shouldReadArrayOfFloats(Implementation implementation) { + // given + Schema elementSchema = Schema.create(Schema.Type.FLOAT); + Schema arraySchema = Schema.createArray(elementSchema); + + List data = new ArrayList<>(2); + data.add(1.0f); + data.add(2.0f); + + GenericData.Array avroArray = new GenericData.Array<>(0, arraySchema); + for (float f: data) { + avroArray.add(f); + } + + // when + List array = implementation.decode(arraySchema, arraySchema, genericDataAsDecoder(avroArray)); + + // then + if (implementation.isFast) { + // The extended API should always be available, regardless of whether warm or cold + Assert.assertTrue(Arrays.stream(array.getClass().getInterfaces()).anyMatch(c -> c.equals(PrimitiveFloatList.class)), + "The returned type should implement " + PrimitiveFloatList.class.getSimpleName()); + } + Assert.assertEquals(array.size(), data.size()); + for (int i = 0; i < data.size(); i++) { + Assert.assertEquals(array.get(i), data.get(i)); + } + } + + @Test(groups = {"deserializationTest"}, dataProvider = "Implementation") + public void shouldReadMapOfRecords(Implementation implementation) { // given Schema recordSchema = createRecord("record", createPrimitiveUnionFieldSchema("field", Schema.Type.STRING)); @@ -643,14 +664,8 @@ public void shouldReadMapOfRecords(Boolean whetherUseFastDeserializer) { recordsMap.put("2", subRecordBuilder); // when - Map map = null; - if (whetherUseFastDeserializer) { - map = decodeRecordFast(mapRecordSchema, mapRecordSchema, - FastSerdeTestsSupport.genericDataAsDecoder(recordsMap, mapRecordSchema)); - } else { - map = decodeRecordSlow(mapRecordSchema, mapRecordSchema, - FastSerdeTestsSupport.genericDataAsDecoder(recordsMap, mapRecordSchema)); - } + Map map = implementation.decode(mapRecordSchema, mapRecordSchema, + FastSerdeTestsSupport.genericDataAsDecoder(recordsMap, mapRecordSchema)); // then Assert.assertEquals(2, map.size()); @@ -668,7 +683,7 @@ public void shouldReadMapOfRecords(Boolean whetherUseFastDeserializer) { recordsMap.put("2", subRecordBuilder); // when - map = decodeRecordFast(mapRecordSchema, mapRecordSchema, + map = decodeRecordWarmFast(mapRecordSchema, mapRecordSchema, FastSerdeTestsSupport.genericDataAsDecoder(recordsMap, mapRecordSchema)); // then @@ -677,8 +692,8 @@ public void shouldReadMapOfRecords(Boolean whetherUseFastDeserializer) { Assert.assertEquals(new Utf8("abc"), map.get(new Utf8("2")).get("field")); } - @Test(groups = {"deserializationTest"}, dataProvider = "SlowFastDeserializer") - public void shouldReadNestedMap(Boolean whetherUseFastDeserializer) { + @Test(groups = {"deserializationTest"}, dataProvider = "Implementation") + public void shouldReadNestedMap(Implementation implementation) { // given Schema nestedMapSchema = createRecord("record", createMapFieldSchema( "mapField", Schema.createMap(Schema.createArray(Schema.create(Schema.Type.INT))))); @@ -693,14 +708,8 @@ public void shouldReadNestedMap(Boolean whetherUseFastDeserializer) { recordData.put("mapField", mapField); // when - GenericData.Record decodedRecord = null; - if (whetherUseFastDeserializer) { - decodedRecord = decodeRecordFast(nestedMapSchema, nestedMapSchema, - FastSerdeTestsSupport.genericDataAsDecoder(recordData, nestedMapSchema)); - } else { - decodedRecord = decodeRecordSlow(nestedMapSchema, nestedMapSchema, - FastSerdeTestsSupport.genericDataAsDecoder(recordData, nestedMapSchema)); - } + GenericData.Record decodedRecord = implementation.decode(nestedMapSchema, nestedMapSchema, + FastSerdeTestsSupport.genericDataAsDecoder(recordData, nestedMapSchema)); // then Object decodedMapField = decodedRecord.get("mapField"); @@ -712,8 +721,8 @@ public void shouldReadNestedMap(Boolean whetherUseFastDeserializer) { Assert.assertEquals(Arrays.asList(2), ((List) ((Map) subMap).get(new Utf8("subKey2")))); } - @Test(groups = {"deserializationTest"}, dataProvider = "SlowFastDeserializer") - public void shouldReadRecursiveUnionRecord(Boolean whetherUseFastDeserializer) { + @Test(groups = {"deserializationTest"}, dataProvider = "Implementation") + public void shouldReadRecursiveUnionRecord(Implementation implementation) { // given Schema unionRecordSchema = Schema.parse("{\"type\":\"record\",\"name\":\"recordName\",\"namespace\":\"com.linkedin.avro.fastserde.generated.avro\",\"fields\":[{\"name\":\"strField\",\"type\":\"string\"},{\"name\":\"unionField\",\"type\":[\"null\",\"recordName\"]}]}"); @@ -725,14 +734,8 @@ public void shouldReadRecursiveUnionRecord(Boolean whetherUseFastDeserializer) { recordData.put("unionField", unionField); // when - GenericData.Record decodedRecord = null; - if (whetherUseFastDeserializer) { - decodedRecord = decodeRecordFast(unionRecordSchema, unionRecordSchema, - FastSerdeTestsSupport.genericDataAsDecoder(recordData, unionRecordSchema)); - } else { - decodedRecord = decodeRecordSlow(unionRecordSchema, unionRecordSchema, - FastSerdeTestsSupport.genericDataAsDecoder(recordData, unionRecordSchema)); - } + GenericData.Record decodedRecord = implementation.decode(unionRecordSchema, unionRecordSchema, + FastSerdeTestsSupport.genericDataAsDecoder(recordData, unionRecordSchema)); // then Assert.assertEquals(new Utf8("foo"), decodedRecord.get("strField")); @@ -741,11 +744,22 @@ public void shouldReadRecursiveUnionRecord(Boolean whetherUseFastDeserializer) { Assert.assertEquals(new Utf8("bar"), ((GenericData.Record) decodedUnionField).get("strField")); } - public T decodeRecordFast(Schema writerSchema, Schema readerSchema, Decoder decoder) { + private static T decodeRecordColdFast(Schema writerSchema, Schema readerSchema, Decoder decoder) { + FastDeserializer deserializer = + new FastSerdeCache.FastDeserializerWithAvroGenericImpl<>(writerSchema, readerSchema); + + return decodeRecordFast(deserializer, decoder); + } + + private static T decodeRecordWarmFast(Schema writerSchema, Schema readerSchema, Decoder decoder) { FastDeserializer deserializer = new FastGenericDeserializerGenerator(writerSchema, readerSchema, tempDir, classLoader, null).generateDeserializer(); + return decodeRecordFast(deserializer, decoder); + } + + private static T decodeRecordFast(FastDeserializer deserializer, Decoder decoder) { try { return deserializer.deserialize(null, decoder); } catch (Exception e) { @@ -754,7 +768,7 @@ public T decodeRecordFast(Schema writerSchema, Schema readerSchema, Decoder } @SuppressWarnings("unchecked") - private T decodeRecordSlow(Schema readerSchema, Schema writerSchema, Decoder decoder) { + private static T decodeRecordSlow(Schema writerSchema, Schema readerSchema, Decoder decoder) { org.apache.avro.io.DatumReader datumReader = new GenericDatumReader<>(writerSchema, readerSchema); try { return (T) datumReader.read(null, decoder);