diff --git a/build.gradle b/build.gradle index e6ad01275b..7d1a332888 100644 --- a/build.gradle +++ b/build.gradle @@ -85,6 +85,7 @@ ext.libraries = [ grpcServices: "io.grpc:grpc-services:${grpcVersion}", grpcStub: "io.grpc:grpc-stub:${grpcVersion}", hadoopCommon: "org.apache.hadoop:hadoop-common:${hadoopVersion}", + hadoopHdfs: "org.apache.hadoop:hadoop-hdfs:${hadoopVersion}", httpAsyncClient: 'org.apache.httpcomponents:httpasyncclient:4.1.5', httpClient5: 'org.apache.httpcomponents.client5:httpclient5:5.3', httpCore5: 'org.apache.httpcomponents.core5:httpcore5:5.2.4', diff --git a/clients/venice-push-job/build.gradle b/clients/venice-push-job/build.gradle index efc07f7068..1c8b4f98e6 100644 --- a/clients/venice-push-job/build.gradle +++ b/clients/venice-push-job/build.gradle @@ -27,6 +27,13 @@ dependencies { exclude group: 'javax.servlet' } + implementation (libraries.hadoopHdfs) { + // Exclude transitive dependency + exclude group: 'org.apache.avro' + exclude group: 'javax.servlet' + exclude group: 'com.fasterxml.jackson.core' + } + implementation (libraries.apacheSparkAvro) { // Spark 3.1 depends on Avro 1.8.2 - which uses avro-mapred with the hadoop2 classifier. Starting from Avro 1.9 // onwards, avro-mapred is no longer published with a hadoop2 classifier, but Gradle still looks for one. diff --git a/clients/venice-push-job/src/main/java/com/linkedin/venice/hadoop/task/datawriter/AbstractPartitionWriter.java b/clients/venice-push-job/src/main/java/com/linkedin/venice/hadoop/task/datawriter/AbstractPartitionWriter.java index f8df8c2575..661290ebb0 100644 --- a/clients/venice-push-job/src/main/java/com/linkedin/venice/hadoop/task/datawriter/AbstractPartitionWriter.java +++ b/clients/venice-push-job/src/main/java/com/linkedin/venice/hadoop/task/datawriter/AbstractPartitionWriter.java @@ -10,9 +10,7 @@ import static com.linkedin.venice.vpj.VenicePushJobConstants.TELEMETRY_MESSAGE_INTERVAL; import static com.linkedin.venice.vpj.VenicePushJobConstants.TOPIC_PROP; import static com.linkedin.venice.vpj.VenicePushJobConstants.VALUE_SCHEMA_ID_PROP; -import static com.linkedin.venice.vpj.VenicePushJobConstants.VSON_PUSH; -import com.linkedin.avroutil1.compatibility.AvroCompatibilityHelper; import com.linkedin.venice.ConfigKeys; import com.linkedin.venice.annotation.NotThreadsafe; import com.linkedin.venice.exceptions.RecordTooLargeException; @@ -21,17 +19,12 @@ import com.linkedin.venice.guid.GuidUtils; import com.linkedin.venice.hadoop.InputStorageQuotaTracker; import com.linkedin.venice.hadoop.engine.EngineTaskConfigProvider; -import com.linkedin.venice.hadoop.input.recordreader.AbstractVeniceRecordReader; -import com.linkedin.venice.hadoop.input.recordreader.avro.VeniceAvroRecordReader; -import com.linkedin.venice.hadoop.input.recordreader.vson.VeniceVsonRecordReader; import com.linkedin.venice.hadoop.task.TaskTracker; import com.linkedin.venice.meta.Store; import com.linkedin.venice.partitioner.VenicePartitioner; import com.linkedin.venice.pubsub.api.PubSubProduceResult; import com.linkedin.venice.pubsub.api.PubSubProducerCallback; import com.linkedin.venice.serialization.DefaultSerializer; -import com.linkedin.venice.serializer.FastSerializerDeserializerFactory; -import com.linkedin.venice.serializer.RecordDeserializer; import com.linkedin.venice.utils.ByteUtils; import com.linkedin.venice.utils.PartitionUtils; import com.linkedin.venice.utils.SystemTime; @@ -44,7 +37,6 @@ import com.linkedin.venice.writer.VeniceWriter; import com.linkedin.venice.writer.VeniceWriterFactory; import 
com.linkedin.venice.writer.VeniceWriterOptions; -import java.io.ByteArrayOutputStream; import java.io.Closeable; import java.io.IOException; import java.nio.ByteBuffer; @@ -56,9 +48,6 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; -import org.apache.avro.Schema; -import org.apache.avro.generic.GenericDatumWriter; -import org.apache.avro.io.Encoder; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -247,7 +236,7 @@ protected VeniceWriterMessage extract( if (duplicateKeyPrinter == null) { throw new VeniceException("'DuplicateKeyPrinter' is not initialized properly"); } - duplicateKeyPrinter.detectAndHandleDuplicateKeys(keyBytes, valueBytes, values, dataWriterTaskTracker); + duplicateKeyPrinter.detectAndHandleDuplicateKeys(valueBytes, values, dataWriterTaskTracker); return new VeniceWriterMessage( keyBytes, valueBytes, @@ -545,24 +534,13 @@ public static class DuplicateKeyPrinter implements AutoCloseable, Closeable { private final boolean isDupKeyAllowed; - private final Schema keySchema; - private final RecordDeserializer keyDeserializer; - private final GenericDatumWriter avroDatumWriter; private int numOfDupKey = 0; DuplicateKeyPrinter(VeniceProperties props) { this.isDupKeyAllowed = props.getBoolean(ALLOW_DUPLICATE_KEY, false); - - AbstractVeniceRecordReader schemaReader = props.getBoolean(VSON_PUSH, false) - ? new VeniceVsonRecordReader(props) - : VeniceAvroRecordReader.fromProps(props); - this.keySchema = schemaReader.getKeySchema(); - this.keyDeserializer = FastSerializerDeserializerFactory.getFastAvroGenericDeserializer(keySchema, keySchema); - this.avroDatumWriter = new GenericDatumWriter<>(keySchema); } protected void detectAndHandleDuplicateKeys( - byte[] keyBytes, byte[] valueBytes, Iterator values, DataWriterTaskTracker dataWriterTaskTracker) { @@ -579,7 +557,8 @@ protected void detectAndHandleDuplicateKeys( identicalValuesToKeyCount++; if (shouldPrint) { shouldPrint = false; - LOGGER.warn(printDuplicateKey(keyBytes)); + numOfDupKey++; + LOGGER.warn("There are multiple records for the same key"); } } else { // Distinct values map to the same key. E.g. 
key:[ value_1, value_2 ] @@ -588,7 +567,8 @@ protected void detectAndHandleDuplicateKeys( if (isDupKeyAllowed) { if (shouldPrint) { shouldPrint = false; - LOGGER.warn(printDuplicateKey(keyBytes)); + numOfDupKey++; + LOGGER.warn("There are multiple records for the same key"); } } } @@ -597,21 +577,6 @@ protected void detectAndHandleDuplicateKeys( dataWriterTaskTracker.trackDuplicateKeyWithDistinctValue(distinctValuesToKeyCount); } - private String printDuplicateKey(byte[] keyBytes) { - Object keyRecord = keyDeserializer.deserialize(keyBytes); - try (ByteArrayOutputStream output = new ByteArrayOutputStream()) { - Encoder jsonEncoder = AvroCompatibilityHelper.newJsonEncoder(keySchema, output, false); - avroDatumWriter.write(keyRecord, jsonEncoder); - jsonEncoder.flush(); - output.flush(); - - numOfDupKey++; - return String.format("There are multiple records for key:\n%s", new String(output.toByteArray())); - } catch (IOException exception) { - throw new VeniceException(exception); - } - } - @Override public void close() { // Nothing to do diff --git a/clients/venice-push-job/src/main/java/com/linkedin/venice/spark/datawriter/jobs/DataWriterSparkJob.java b/clients/venice-push-job/src/main/java/com/linkedin/venice/spark/datawriter/jobs/DataWriterSparkJob.java index 2288aa370b..4c46568b49 100644 --- a/clients/venice-push-job/src/main/java/com/linkedin/venice/spark/datawriter/jobs/DataWriterSparkJob.java +++ b/clients/venice-push-job/src/main/java/com/linkedin/venice/spark/datawriter/jobs/DataWriterSparkJob.java @@ -1,23 +1,40 @@ package com.linkedin.venice.spark.datawriter.jobs; +import static com.linkedin.venice.spark.SparkConstants.DEFAULT_SCHEMA; import static com.linkedin.venice.vpj.VenicePushJobConstants.ETL_VALUE_SCHEMA_TRANSFORMATION; import static com.linkedin.venice.vpj.VenicePushJobConstants.FILE_KEY_SCHEMA; import static com.linkedin.venice.vpj.VenicePushJobConstants.FILE_VALUE_SCHEMA; import static com.linkedin.venice.vpj.VenicePushJobConstants.GENERATE_PARTIAL_UPDATE_RECORD_FROM_INPUT; +import static com.linkedin.venice.vpj.VenicePushJobConstants.GLOB_FILTER_PATTERN; import static com.linkedin.venice.vpj.VenicePushJobConstants.INPUT_PATH_PROP; import static com.linkedin.venice.vpj.VenicePushJobConstants.KEY_FIELD_PROP; import static com.linkedin.venice.vpj.VenicePushJobConstants.SCHEMA_STRING_PROP; +import static com.linkedin.venice.vpj.VenicePushJobConstants.SPARK_NATIVE_INPUT_FORMAT_ENABLED; import static com.linkedin.venice.vpj.VenicePushJobConstants.UPDATE_SCHEMA_STRING_PROP; import static com.linkedin.venice.vpj.VenicePushJobConstants.VALUE_FIELD_PROP; import static com.linkedin.venice.vpj.VenicePushJobConstants.VSON_PUSH; +import com.linkedin.avroutil1.compatibility.AvroCompatibilityHelper; import com.linkedin.venice.hadoop.PushJobSetting; +import com.linkedin.venice.hadoop.input.recordreader.avro.VeniceAvroRecordReader; +import com.linkedin.venice.hadoop.input.recordreader.vson.VeniceVsonRecordReader; import com.linkedin.venice.spark.input.hdfs.VeniceHdfsSource; +import com.linkedin.venice.spark.utils.RowToAvroConverter; +import com.linkedin.venice.utils.VeniceProperties; +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.generic.IndexedRecord; +import org.apache.avro.mapred.AvroWrapper; import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.BytesWritable; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.DataFrameReader; import 
org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.catalyst.encoders.RowEncoder; +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; /** @@ -29,6 +46,19 @@ protected Dataset getUserInputDataFrame() { SparkSession sparkSession = getSparkSession(); PushJobSetting pushJobSetting = getPushJobSetting(); + VeniceProperties jobProps = getJobProperties(); + boolean useNativeInputFormat = jobProps.getBoolean(SPARK_NATIVE_INPUT_FORMAT_ENABLED, false); + + if (!useNativeInputFormat) { + return getDataFrameFromCustomInputFormat(sparkSession, pushJobSetting); + } else if (pushJobSetting.isAvro) { + return getAvroDataFrame(sparkSession, pushJobSetting); + } else { + return getVsonDataFrame(sparkSession, pushJobSetting); + } + } + + private Dataset getDataFrameFromCustomInputFormat(SparkSession sparkSession, PushJobSetting pushJobSetting) { DataFrameReader dataFrameReader = sparkSession.read(); dataFrameReader.format(VeniceHdfsSource.class.getCanonicalName()); setInputConf(sparkSession, dataFrameReader, INPUT_PATH_PROP, new Path(pushJobSetting.inputURI).toString()); @@ -55,4 +85,53 @@ protected Dataset getUserInputDataFrame() { } return dataFrameReader.load(); } + + private Dataset getAvroDataFrame(SparkSession sparkSession, PushJobSetting pushJobSetting) { + Dataset df = + sparkSession.read().format("avro").option("pathGlobFilter", GLOB_FILTER_PATTERN).load(pushJobSetting.inputURI); + + // Transforming the input data format + df = df.map((MapFunction) (record) -> { + Schema updateSchema = null; + if (pushJobSetting.generatePartialUpdateRecordFromInput) { + updateSchema = AvroCompatibilityHelper.parse(pushJobSetting.valueSchemaString); + } + + GenericRecord rowRecord = RowToAvroConverter.convert(record, pushJobSetting.inputDataSchema); + VeniceAvroRecordReader recordReader = new VeniceAvroRecordReader( + pushJobSetting.inputDataSchema, + pushJobSetting.keyField, + pushJobSetting.valueField, + pushJobSetting.etlValueSchemaTransformation, + updateSchema); + + AvroWrapper recordAvroWrapper = new AvroWrapper<>(rowRecord); + final byte[] inputKeyBytes = recordReader.getKeyBytes(recordAvroWrapper, null); + final byte[] inputValueBytes = recordReader.getValueBytes(recordAvroWrapper, null); + + return new GenericRowWithSchema(new Object[] { inputKeyBytes, inputValueBytes }, DEFAULT_SCHEMA); + }, RowEncoder.apply(DEFAULT_SCHEMA)); + + return df; + } + + @Deprecated + private Dataset getVsonDataFrame(SparkSession sparkSession, PushJobSetting pushJobSetting) { + JavaRDD rdd = sparkSession.sparkContext() + .sequenceFile(pushJobSetting.inputURI, BytesWritable.class, BytesWritable.class) + .toJavaRDD() + .map(record -> { + VeniceVsonRecordReader recordReader = new VeniceVsonRecordReader( + pushJobSetting.vsonInputKeySchemaString, + pushJobSetting.vsonInputValueSchemaString, + pushJobSetting.keyField, + pushJobSetting.valueField); + + final byte[] inputKeyBytes = recordReader.getKeyBytes(record._1, record._2); + final byte[] inputValueBytes = recordReader.getValueBytes(record._1, record._2); + + return new GenericRowWithSchema(new Object[] { inputKeyBytes, inputValueBytes }, DEFAULT_SCHEMA); + }); + return sparkSession.createDataFrame(rdd, DEFAULT_SCHEMA); + } } diff --git a/clients/venice-push-job/src/main/java/com/linkedin/venice/spark/utils/RowToAvroConverter.java b/clients/venice-push-job/src/main/java/com/linkedin/venice/spark/utils/RowToAvroConverter.java new file mode 100644 index 0000000000..8ba86c734a --- 
/dev/null +++ b/clients/venice-push-job/src/main/java/com/linkedin/venice/spark/utils/RowToAvroConverter.java @@ -0,0 +1,483 @@ +package com.linkedin.venice.spark.utils; + +import com.linkedin.avroutil1.compatibility.AvroCompatibilityHelper; +import com.linkedin.venice.utils.ByteUtils; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.sql.Date; +import java.time.Duration; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.Period; +import java.time.temporal.ChronoUnit; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import org.apache.avro.Conversions; +import org.apache.avro.LogicalType; +import org.apache.avro.LogicalTypes; +import org.apache.avro.Schema; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericEnumSymbol; +import org.apache.avro.generic.GenericFixed; +import org.apache.avro.generic.GenericRecord; +import org.apache.commons.lang3.Validate; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.BinaryType; +import org.apache.spark.sql.types.BooleanType; +import org.apache.spark.sql.types.ByteType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateType; +import org.apache.spark.sql.types.DayTimeIntervalType; +import org.apache.spark.sql.types.DecimalType; +import org.apache.spark.sql.types.DoubleType; +import org.apache.spark.sql.types.FloatType; +import org.apache.spark.sql.types.IntegerType; +import org.apache.spark.sql.types.LongType; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.ShortType; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampNTZType; +import org.apache.spark.sql.types.TimestampType; +import org.apache.spark.sql.types.YearMonthIntervalType; +import scala.collection.JavaConverters; + + +/** + * A utility class to convert Spark SQL Row to an Avro GenericRecord with the specified schema. This has been written in + * accordance with the following resources: + * + * + * Spark's implementation is not ideal to be used directly for two reasons: + *
+ * <ul>
+ *   <li>It cannot handle complex unions in the version of Spark that we use (3.3.3). The support was added in 3.4.0.</li>
+ *   <li>It converts directly to Avro binary that we need to deserialize, and that incurs an additional serde cost.</li>
+ * </ul>
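+ *
+ * <p>A minimal usage sketch showing how a Spark {@code Row} read from Avro files can be converted back into a
+ * {@code GenericRecord}. The {@code inputSchemaString}, {@code sparkSession}, and {@code inputPath} names are
+ * illustrative only, and the Row's Catalyst schema is assumed to match the Avro record schema field-for-field
+ * (which the converter validates):
+ * <pre>{@code
+ * Schema inputSchema = AvroCompatibilityHelper.parse(inputSchemaString);
+ * Dataset<Row> df = sparkSession.read().format("avro").load(inputPath);
+ * for (Row row : df.collectAsList()) {
+ *   // Each Row is converted field by field into an Avro record with the given schema
+ *   GenericRecord avroRecord = RowToAvroConverter.convert(row, inputSchema);
+ * }
+ * }</pre>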
+ */ +public final class RowToAvroConverter { + private RowToAvroConverter() { + } + + private static final Conversions.DecimalConversion DECIMAL_CONVERTER = new Conversions.DecimalConversion(); + + public static GenericRecord convert(Row row, Schema schema) { + Validate.notNull(row, "Row must not be null"); + Validate.notNull(schema, "Schema must not be null"); + Validate + .isTrue(schema.getType().equals(Schema.Type.RECORD), "Schema must be of type RECORD. Got: " + schema.getType()); + Validate.isInstanceOf(Row.class, row, "Row must be of type Row. Got: " + row.getClass().getName()); + + return convertToRecord(row, row.schema(), schema); + } + + static GenericRecord convertToRecord(Object o, DataType dataType, Schema schema) { + Validate.isInstanceOf(StructType.class, dataType, "Expected StructType, got: " + dataType.getClass().getName()); + Validate.isInstanceOf(Row.class, o, "Expected Row, got: " + o.getClass().getName()); + GenericRecord aResult = new GenericData.Record(schema); + + Row row = (Row) o; + + StructType sType = row.schema(); + StructField[] sFields = sType.fields(); + List aFields = schema.getFields(); + + Validate.isTrue( + sFields.length == aFields.size(), + "Row and Avro schema must have the same number of fields. Row: " + sFields.length + ", Avro: " + + aFields.size()); + + for (int i = 0; i < sFields.length; i++) { + StructField structField = sFields[i]; + Schema.Field avroField = aFields.get(i); + + // Spark field names are case-insensitive + Validate.isTrue( + structField.name().equalsIgnoreCase(avroField.name()), + "Field names must match. Row: " + structField.name() + ", Avro: " + avroField.name()); + + Object elem = row.get(i); + aResult.put(i, convertInternal(elem, structField.dataType(), avroField.schema())); + } + + return aResult; + } + + static Boolean convertToBoolean(Object o, DataType dataType) { + Validate.isInstanceOf(BooleanType.class, dataType, "Expected BooleanType, got: " + dataType.getClass().getName()); + Validate.isInstanceOf(Boolean.class, o, "Expected Boolean, got: " + o.getClass().getName()); + return ((Boolean) o); + } + + static Integer convertToInt(Object o, DataType dataType, Schema schema) { + // IntegerType + if (dataType instanceof IntegerType) { + Validate.isInstanceOf(Integer.class, o, "Expected Integer, got: " + o.getClass().getName()); + return ((Integer) o); + } + + // Avro logical type "date" is read as DateType in Spark + if (dataType instanceof DateType) { + validateLogicalType(schema, LogicalTypes.date()); + + LocalDate localDate; + + if (o instanceof LocalDate) { + localDate = ((LocalDate) o); + } else if (o instanceof Date) { + localDate = ((Date) o).toLocalDate(); + } else { + throw new IllegalArgumentException( + "Unsupported date type: " + o.getClass().getName() + ". 
Expected java.time.LocalDate or java.sql.Date"); + } + + // Long to int, but we are sure that it fits + return (int) localDate.toEpochDay(); + } + + if (dataType instanceof ByteType) { + Validate.isInstanceOf(Byte.class, o, "Expected Integer, got: " + o.getClass().getName()); + return ((Byte) o).intValue(); + } + + if (dataType instanceof ShortType) { + Validate.isInstanceOf(Short.class, o, "Expected Integer, got: " + o.getClass().getName()); + return ((Short) o).intValue(); + } + + // Spark default Avro converter converts YearMonthIntervalType to int type + // This is not the type read by Spark's native Avro reader, but added to support YearMonthIntervalType + if (dataType instanceof YearMonthIntervalType) { + Validate.isInstanceOf(Period.class, o, "Expected Period, got: " + o.getClass().getName()); + return ((Period) o).getMonths(); + } + + throw new IllegalArgumentException("Unsupported data type: " + dataType); + } + + static Long convertToLong(Object o, DataType dataType, Schema schema) { + // LongType + if (dataType instanceof LongType) { + Validate.isInstanceOf(Long.class, o, "Expected Long, got: " + o.getClass().getName()); + return ((Long) o); + } + + // Avro logical types "timestamp-millis" and "timestamp-micros" are read as LongType in Spark + if (dataType instanceof TimestampType) { + LogicalType logicalType = + validateLogicalType(schema, false, LogicalTypes.timestampMicros(), LogicalTypes.timestampMillis()); + + Instant instant; + if (o instanceof java.time.Instant) { + instant = ((java.time.Instant) o); + } else if (o instanceof java.sql.Timestamp) { + instant = ((java.sql.Timestamp) o).toInstant(); + } else { + throw new IllegalArgumentException( + "Unsupported timestamp type: " + o.getClass().getName() + + ". Expected java.time.Instant or java.sql.Timestamp"); + } + + if (logicalType == null || logicalType == LogicalTypes.timestampMillis()) { + return ChronoUnit.MILLIS.between(Instant.EPOCH, instant); + } + + return ChronoUnit.MICROS.between(Instant.EPOCH, instant); + } + + // Spark default Avro converter converts TimestampNTZType to int type + // This is not the type read by Spark's native Avro reader, but added to support TimestampNTZType + // Avro logical types "local-timestamp-millis" and "local-timestamp-micros" are read as LongType in Spark + if (dataType instanceof TimestampNTZType) { + LogicalType logicalType = + validateLogicalType(schema, false, LogicalTypes.localTimestampMicros(), LogicalTypes.localTimestampMillis()); + Validate.isInstanceOf(java.time.LocalDateTime.class, o, "Expected LocalDateTime, got: " + o.getClass().getName()); + + LocalDateTime localDateTime = ((java.time.LocalDateTime) o); + LocalDateTime epoch = LocalDateTime.of(1970, 1, 1, 0, 0, 0); + + if (logicalType == null || logicalType == LogicalTypes.localTimestampMillis()) { + return ChronoUnit.MILLIS.between(epoch, localDateTime); + } + + return ChronoUnit.MICROS.between(epoch, localDateTime); + } + + // Spark default Avro converter converts DayTimeIntervalType to long type + // This is not the type read by Spark's native Avro reader, but added to support DayTimeIntervalType + if (dataType instanceof DayTimeIntervalType) { + Validate.isInstanceOf(Duration.class, o, "Expected Duration, got: " + o.getClass().getName()); + Duration duration = (Duration) o; + return TimeUnit.SECONDS.toMicros(duration.getSeconds()) + TimeUnit.NANOSECONDS.toMicros(duration.getNano()); + } + + throw new IllegalArgumentException("Unsupported data type: " + dataType); + } + + static Float convertToFloat(Object o, 
DataType dataType) { + Validate.isInstanceOf(FloatType.class, dataType, "Expected FloatType, got: " + dataType); + Validate.isInstanceOf(Float.class, o, "Expected Float, got: " + o.getClass().getName()); + return ((Float) o); + } + + static Double convertToDouble(Object o, DataType dataType) { + Validate.isInstanceOf(DoubleType.class, dataType, "Expected DoubleType, got: " + dataType); + Validate.isInstanceOf(Double.class, o, "Expected Double, got: " + o.getClass().getName()); + return ((Double) o); + } + + static CharSequence convertToString(Object o, DataType dataType) { + Validate.isInstanceOf(StringType.class, dataType, "Expected StringType, got: " + dataType); + Validate.isInstanceOf(CharSequence.class, o, "Expected CharSequence, got: " + o.getClass().getName()); + return ((CharSequence) o); + } + + static ByteBuffer convertToBytes(Object o, DataType dataType, Schema schema) { + if (dataType instanceof BinaryType) { + if (o instanceof byte[]) { + return ByteBuffer.wrap((byte[]) o); + } + + if (o instanceof ByteBuffer) { + return (ByteBuffer) o; + } + + throw new IllegalArgumentException("Unsupported byte array type: " + o.getClass().getName()); + } + + if (dataType instanceof DecimalType) { + DecimalType decimalType = (DecimalType) dataType; + validateLogicalType(schema, LogicalTypes.decimal(decimalType.precision(), decimalType.scale())); + Validate.isInstanceOf(BigDecimal.class, o, "Expected BigDecimal, got: " + o.getClass().getName()); + BigDecimal decimal = (BigDecimal) o; + LogicalTypes.Decimal l = (LogicalTypes.Decimal) schema.getLogicalType(); + return DECIMAL_CONVERTER.toBytes(decimal, schema, l); + } + + throw new IllegalArgumentException("Unsupported data type: " + dataType); + } + + static GenericFixed convertToFixed(Object o, DataType dataType, Schema schema) { + if (dataType instanceof BinaryType) { + if (o instanceof byte[]) { + byte[] bytes = (byte[]) o; + Validate.isTrue( + bytes.length == schema.getFixedSize(), + "Fixed size mismatch. Expected: " + schema.getFixedSize() + ", got: " + bytes.length); + return AvroCompatibilityHelper.newFixed(schema, bytes); + } + + if (o instanceof ByteBuffer) { + ByteBuffer bytes = (ByteBuffer) o; + Validate.isTrue( + bytes.remaining() == schema.getFixedSize(), + "Fixed size mismatch. 
Expected: " + schema.getFixedSize() + ", got: " + bytes.remaining()); + return AvroCompatibilityHelper.newFixed(schema, ByteUtils.extractByteArray(bytes)); + } + + throw new IllegalArgumentException("Unsupported byte array type: " + o.getClass().getName()); + } + + if (dataType instanceof DecimalType) { + DecimalType decimalType = (DecimalType) dataType; + validateLogicalType(schema, LogicalTypes.decimal(decimalType.precision(), decimalType.scale())); + Validate.isInstanceOf(BigDecimal.class, o, "Expected BigDecimal, got: " + o.getClass().getName()); + BigDecimal decimal = (BigDecimal) o; + Conversions.DecimalConversion DECIMAL_CONVERTER = new Conversions.DecimalConversion(); + LogicalTypes.Decimal l = (LogicalTypes.Decimal) schema.getLogicalType(); + return DECIMAL_CONVERTER.toFixed(decimal, schema, l); + } + + throw new IllegalArgumentException("Unsupported data type: " + dataType); + } + + static GenericEnumSymbol convertToEnum(Object o, DataType dataType, Schema schema) { + Validate.isInstanceOf(StringType.class, dataType, "Expected StringType, got: " + dataType); + Validate.isInstanceOf(CharSequence.class, o, "Expected CharSequence, got: " + o.getClass().getName()); + Validate.isTrue(schema.getEnumSymbols().contains(o.toString()), "Enum symbol not found: " + o); + return AvroCompatibilityHelper.newEnumSymbol(schema, ((CharSequence) o).toString()); + } + + static List convertToArray(Object o, DataType dataType, Schema schema) { + Validate.isInstanceOf(ArrayType.class, dataType, "Expected ArrayType, got: " + dataType); + + // Type of elements in the array + Schema elementType = schema.getElementType(); + + List inputList; + if (o instanceof List) { + inputList = (List) o; + } else if (o instanceof scala.collection.Seq) { + // If the input is a scala.collection.Seq, convert it to a List + inputList = JavaConverters.seqAsJavaList((scala.collection.Seq) o); + } else { + throw new IllegalArgumentException("Unsupported array type: " + o.getClass().getName()); + } + + List outputList = new ArrayList<>(inputList.size()); + + for (Object element: inputList) { + outputList.add(convertInternal(element, ((ArrayType) dataType).elementType(), elementType)); + } + + return outputList; + } + + static Map convertToMap(Object o, DataType dataType, Schema schema) { + Validate.isInstanceOf(MapType.class, dataType, "Expected MapType, got: " + dataType.getClass().getName()); + + MapType sType = ((MapType) dataType); + + Map inputMap; + if (o instanceof Map) { + inputMap = (Map) o; + } else if (o instanceof scala.collection.Map) { + inputMap = JavaConverters.mapAsJavaMap((scala.collection.Map) o); + } else { + throw new IllegalArgumentException("Unsupported map type: " + o.getClass().getName()); + } + + Map outputMap = new HashMap<>(inputMap.size()); + + for (Object entryObj: inputMap.entrySet()) { + Validate.isInstanceOf(Map.Entry.class, entryObj, "Expected Map.Entry, got: " + entryObj.getClass().getName()); + Map.Entry entry = (Map.Entry) entryObj; + outputMap.put( + // Key is always a String in Avro + convertToString(entry.getKey(), sType.keyType()), + convertInternal(entry.getValue(), sType.valueType(), schema.getValueType())); + } + + return outputMap; + } + + static Object convertToUnion(Object o, DataType dataType, Schema schema) { + if (o == null) { + Validate.isTrue(schema.isNullable(), "Field is not nullable: " + schema.getName()); + return null; + } + + // Now that we've checked for null explicitly, we should process everything else as a non-null value. 
+ // This is consistent with the way Spark handles unions. + List types = + schema.getTypes().stream().filter(s -> s.getType() != Schema.Type.NULL).collect(Collectors.toList()); + Schema first = types.get(0); + // If there's only one branch, Spark will use that as the data type + if (types.size() == 1) { + return convertInternal(o, dataType, first); + } + + Schema second = types.get(1); + if (types.size() == 2) { + // A union of int and long is read as LongType. + // This is lossy because we cannot know what type was provided in the input + if ((first.getType() == Schema.Type.INT && second.getType() == Schema.Type.LONG) + || (first.getType() == Schema.Type.LONG && second.getType() == Schema.Type.INT)) { + return convertToLong(o, dataType, schema); + } + + // A union of float and double is read as DoubleType. + // This is lossy because we cannot know what type was provided in the input + if ((first.getType() == Schema.Type.FLOAT && second.getType() == Schema.Type.DOUBLE) + || (first.getType() == Schema.Type.DOUBLE && second.getType() == Schema.Type.FLOAT)) { + return convertToDouble(o, dataType); + } + } + + // Now, handle complex unions: member0, member1, ... + // If a branch of the union is "null", then it is skipped in the Catalyst schema. + // So, [ "null", "int", "string" ], [ "int", "null", "string" ], [ "int", "string", "null" ], will all be parsed as + // StructType { member0 -> IntegerType, member1 -> StringType }. + Validate.isInstanceOf(StructType.class, dataType, "Expected StructType, got: " + dataType.getClass().getName()); + Validate.isInstanceOf(Row.class, o, "Expected Row, got: " + o.getClass().getName()); + Row row = (Row) o; + + StructType structType = (StructType) dataType; + StructField[] structFields = structType.fields(); + int structFieldIndex = 0; + for (Schema type: types) { + Validate.isTrue(type.getType() != Schema.Type.NULL); + + Object unionField = row.get(structFieldIndex); + if (unionField != null) { + return convertInternal(unionField, structFields[structFieldIndex].dataType(), type); + } + structFieldIndex++; + } + + throw new IllegalArgumentException("At least one field of complex union must be non-null: " + types); + } + + private static Object convertInternal(Object o, DataType dataType, Schema schema) { + if (o == null) { + Validate.isTrue(schema.isNullable(), "Field is not nullable: " + schema.getName()); + return null; + } + + switch (schema.getType()) { + case BOOLEAN: + return convertToBoolean(o, dataType); + case INT: + return convertToInt(o, dataType, schema); + case LONG: + return convertToLong(o, dataType, schema); + case FLOAT: + return convertToFloat(o, dataType); + case DOUBLE: + return convertToDouble(o, dataType); + case STRING: + return convertToString(o, dataType); + case BYTES: + return convertToBytes(o, dataType, schema); + case FIXED: + return convertToFixed(o, dataType, schema); + case ENUM: + return convertToEnum(o, dataType, schema); + case ARRAY: + return convertToArray(o, dataType, schema); + case MAP: + return convertToMap(o, dataType, schema); + case RECORD: + return convertToRecord(o, dataType, schema); + case UNION: + return convertToUnion(o, dataType, schema); + default: + throw new IllegalArgumentException("Unsupported Avro type: " + schema.getType()); + } + } + + static LogicalType validateLogicalType(Schema schema, LogicalType... expectedTypes) { + return validateLogicalType(schema, true, expectedTypes); + } + + static LogicalType validateLogicalType(Schema schema, boolean needLogicalType, LogicalType... 
expectedTypes) { + LogicalType logicalType = schema.getLogicalType(); + if (logicalType == null) { + if (needLogicalType) { + throw new IllegalArgumentException("Expected Avro logical type to be present, got schema: " + schema); + } else { + return null; + } + } + + for (LogicalType expectedType: expectedTypes) { + if (logicalType.equals(expectedType)) { + return expectedType; + } + } + + throw new IllegalArgumentException( + "Expected Avro logical type to be one of: " + Arrays.toString(expectedTypes) + ", got: " + logicalType); + } +} diff --git a/clients/venice-push-job/src/main/java/com/linkedin/venice/vpj/VenicePushJobConstants.java b/clients/venice-push-job/src/main/java/com/linkedin/venice/vpj/VenicePushJobConstants.java index c6f62f4ee2..3e172755f5 100644 --- a/clients/venice-push-job/src/main/java/com/linkedin/venice/vpj/VenicePushJobConstants.java +++ b/clients/venice-push-job/src/main/java/com/linkedin/venice/vpj/VenicePushJobConstants.java @@ -33,6 +33,9 @@ private VenicePushJobConstants() { public static final boolean DEFAULT_EXTENDED_SCHEMA_VALIDITY_CHECK_ENABLED = true; public static final String UPDATE_SCHEMA_STRING_PROP = "update.schema"; + // This is a temporary config used to rollout the native input format for Spark. This will be removed soon + public static final String SPARK_NATIVE_INPUT_FORMAT_ENABLED = "spark.native.input.format.enabled"; + // Vson input configs // Vson files store key/value schema on file header. key / value fields are optional // and should be specified only when key / value schema is the partial of the files. @@ -224,6 +227,7 @@ private VenicePushJobConstants() { * ignore hdfs files with prefix "_" and "." */ public static final PathFilter PATH_FILTER = p -> !p.getName().startsWith("_") && !p.getName().startsWith("."); + public static final String GLOB_FILTER_PATTERN = "[^_.]*"; // Configs to control temp paths and their permissions public static final String HADOOP_TMP_DIR = "hadoop.tmp.dir"; diff --git a/clients/venice-push-job/src/test/java/com/linkedin/venice/spark/utils/RowToAvroConverterTest.java b/clients/venice-push-job/src/test/java/com/linkedin/venice/spark/utils/RowToAvroConverterTest.java new file mode 100644 index 0000000000..da28c8982d --- /dev/null +++ b/clients/venice-push-job/src/test/java/com/linkedin/venice/spark/utils/RowToAvroConverterTest.java @@ -0,0 +1,1214 @@ +package com.linkedin.venice.spark.utils; + +import static org.apache.spark.sql.types.DataTypes.BinaryType; +import static org.apache.spark.sql.types.DataTypes.BooleanType; +import static org.apache.spark.sql.types.DataTypes.ByteType; +import static org.apache.spark.sql.types.DataTypes.DateType; +import static org.apache.spark.sql.types.DataTypes.DoubleType; +import static org.apache.spark.sql.types.DataTypes.FloatType; +import static org.apache.spark.sql.types.DataTypes.IntegerType; +import static org.apache.spark.sql.types.DataTypes.LongType; +import static org.apache.spark.sql.types.DataTypes.ShortType; +import static org.apache.spark.sql.types.DataTypes.StringType; +import static org.apache.spark.sql.types.DataTypes.TimestampType; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; + +import com.linkedin.avroutil1.compatibility.AvroCompatibilityHelper; +import com.linkedin.venice.utils.Time; +import java.math.BigDecimal; +import 
java.math.RoundingMode; +import java.nio.ByteBuffer; +import java.sql.Date; +import java.sql.Timestamp; +import java.time.Duration; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.Period; +import java.time.ZoneOffset; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.avro.LogicalTypes; +import org.apache.avro.Schema; +import org.apache.avro.SchemaBuilder; +import org.apache.avro.generic.GenericData; +import org.apache.avro.generic.GenericEnumSymbol; +import org.apache.avro.generic.GenericFixed; +import org.apache.avro.generic.GenericRecord; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.TimestampNTZType$; +import org.testng.annotations.Test; +import scala.collection.JavaConverters; + + +public class RowToAvroConverterTest { + private static final StructType COMPLEX_SUB_SCHEMA = DataTypes.createStructType( + new StructField[] { new StructField("int", IntegerType, false, Metadata.empty()), + new StructField("string", StringType, false, Metadata.empty()) }); + + private static final StructType UNION_STRUCT_STRING_INT = DataTypes.createStructType( + new StructField[] { new StructField("member0", StringType, true, Metadata.empty()), + new StructField("member1", IntegerType, true, Metadata.empty()) }); + + private static final StructType UNION_STRUCT_DOUBLE_FLOAT_STRING = DataTypes.createStructType( + new StructField[] { new StructField("member0", DoubleType, true, Metadata.empty()), + new StructField("member1", FloatType, true, Metadata.empty()), + new StructField("member2", StringType, true, Metadata.empty()) }); + + private static final StructType SPARK_STRUCT_SCHEMA = new StructType( + new StructField[] { new StructField("byteArr", BinaryType, false, Metadata.empty()), + new StructField("byteBuffer", BinaryType, false, Metadata.empty()), + new StructField("decimalBytes", DataTypes.createDecimalType(3, 2), false, Metadata.empty()), + new StructField("booleanTrue", BooleanType, false, Metadata.empty()), + new StructField("booleanFalse", BooleanType, false, Metadata.empty()), + new StructField("float", FloatType, false, Metadata.empty()), + new StructField("double", DoubleType, false, Metadata.empty()), + new StructField("string", StringType, false, Metadata.empty()), + new StructField("byteArrFixed", BinaryType, false, Metadata.empty()), + new StructField("byteBufferFixed", BinaryType, false, Metadata.empty()), + new StructField("decimalFixed", DataTypes.createDecimalType(3, 2), false, Metadata.empty()), + new StructField("enumType", StringType, false, Metadata.empty()), + new StructField("int", IntegerType, false, Metadata.empty()), + new StructField("date", DateType, false, Metadata.empty()), + new StructField("dateLocal", DateType, false, Metadata.empty()), + new StructField("byte", ByteType, false, Metadata.empty()), + new StructField("short", ShortType, false, Metadata.empty()), + new StructField("yearMonthInterval", DataTypes.createYearMonthIntervalType(), false, Metadata.empty()), + new StructField("long", LongType, false, Metadata.empty()), + new StructField("instantMicros", TimestampType, false, Metadata.empty()), + new StructField("instantMillis", TimestampType, false, 
Metadata.empty()), + new StructField("timestampMicros", TimestampType, false, Metadata.empty()), + new StructField("timestampMillis", TimestampType, false, Metadata.empty()), + new StructField("timestampNoLogical", TimestampType, false, Metadata.empty()), + new StructField("localTimestampMicros", TimestampNTZType$.MODULE$, false, Metadata.empty()), + new StructField("localTimestampMillis", TimestampNTZType$.MODULE$, false, Metadata.empty()), + new StructField("localTimestampNoLogical", TimestampNTZType$.MODULE$, false, Metadata.empty()), + new StructField("dayTimeInterval", DataTypes.createDayTimeIntervalType(), false, Metadata.empty()), + new StructField("arrayIntList", DataTypes.createArrayType(IntegerType), false, Metadata.empty()), + new StructField("arrayIntSeq", DataTypes.createArrayType(IntegerType), false, Metadata.empty()), + new StructField("arrayComplex", DataTypes.createArrayType(COMPLEX_SUB_SCHEMA), false, Metadata.empty()), + new StructField("mapIntJavaMap", DataTypes.createMapType(StringType, IntegerType), false, Metadata.empty()), + new StructField("mapIntScalaMap", DataTypes.createMapType(StringType, IntegerType), false, Metadata.empty()), + new StructField( + "mapComplex", + DataTypes.createMapType(StringType, COMPLEX_SUB_SCHEMA), + false, + Metadata.empty()), + new StructField("nullableUnion", IntegerType, true, Metadata.empty()), + new StructField("nullableUnion2", IntegerType, true, Metadata.empty()), + new StructField("singleElementUnion", IntegerType, false, Metadata.empty()), + new StructField("intLongUnion", LongType, false, Metadata.empty()), + new StructField("longIntUnion", LongType, false, Metadata.empty()), + new StructField("floatDoubleUnion", DoubleType, false, Metadata.empty()), + new StructField("doubleFloatUnion", DoubleType, false, Metadata.empty()), + new StructField("complexNonNullableUnion", UNION_STRUCT_DOUBLE_FLOAT_STRING, false, Metadata.empty()), + new StructField("complexNullableUnion1", UNION_STRUCT_STRING_INT, true, Metadata.empty()), + new StructField("complexNullableUnion2", UNION_STRUCT_STRING_INT, true, Metadata.empty()), + new StructField("complexNullableUnion3", UNION_STRUCT_STRING_INT, true, Metadata.empty()), }); + + private static final Schema DECIMAL_TYPE = LogicalTypes.decimal(3, 2).addToSchema(Schema.create(Schema.Type.BYTES)); + private static final Schema DECIMAL_FIXED_TYPE = + LogicalTypes.decimal(3, 2).addToSchema(Schema.createFixed("decimalFixed", null, null, 3)); + private static final Schema FIXED_TYPE_3 = Schema.createFixed("decimalFixed", null, null, 3); + private static final String STRING_VALUE = "PAX TIBI MARCE EVANGELISTA MEVS"; + private static final String STRING_VALUE_2 = + "It’s temples and palaces did seem like fabrics of enchantment piled to heaven"; + private static final String STRING_VALUE_3 = "Like eating an entire box of chocolate liqueurs in one go"; + private static final Schema DATE_TYPE = LogicalTypes.date().addToSchema(Schema.create(Schema.Type.INT)); + private static final Schema TIMESTAMP_MICROS = + LogicalTypes.timestampMicros().addToSchema(Schema.create(Schema.Type.LONG)); + private static final Schema TIMESTAMP_MILLIS = + LogicalTypes.timestampMillis().addToSchema(Schema.create(Schema.Type.LONG)); + private static final long TEST_EPOCH_MILLIS = 1718860000000L; + private static final Instant TEST_EPOCH_INSTANT = Instant.ofEpochMilli(TEST_EPOCH_MILLIS); + private static final Schema LOCAL_TIMESTAMP_MICROS = + LogicalTypes.localTimestampMicros().addToSchema(Schema.create(Schema.Type.LONG)); + private 
static final Schema LOCAL_TIMESTAMP_MILLIS = + LogicalTypes.localTimestampMillis().addToSchema(Schema.create(Schema.Type.LONG)); + private static final LocalDateTime TEST_LOCAL_DATE_TIME = + LocalDateTime.ofEpochSecond(TEST_EPOCH_MILLIS / 1000, 0, ZoneOffset.of("+02:00")); + // 2 hour offset to account for the local timezone + private static final long TEST_LOCAL_TIMESTAMP_MILLIS = TEST_EPOCH_MILLIS + 2 * Time.MS_PER_HOUR; + + private static final Schema COMPLEX_SUB_SCHEMA_AVRO = + SchemaBuilder.record("arrayComplex").fields().requiredInt("int").requiredString("string").endRecord(); + + private static final Schema AVRO_SCHEMA = SchemaBuilder.record("test") + .fields() + .name("byteArr") + .type() + .bytesType() + .noDefault() + .name("byteBuffer") + .type() + .bytesType() + .noDefault() + .name("decimalBytes") + .type(DECIMAL_TYPE) + .noDefault() + .name("booleanTrue") + .type() + .booleanType() + .noDefault() + .name("booleanFalse") + .type() + .booleanType() + .noDefault() + .name("float") + .type() + .floatType() + .noDefault() + .name("double") + .type() + .doubleType() + .noDefault() + .name("string") + .type() + .stringType() + .noDefault() + .name("byteArrFixed") + .type(FIXED_TYPE_3) + .noDefault() + .name("byteBufferFixed") + .type(FIXED_TYPE_3) + .noDefault() + .name("decimalFixed") + .type(DECIMAL_FIXED_TYPE) + .noDefault() + .name("enumType") + .type() + .enumeration("enumType") + .symbols("A", "B", "C") + .noDefault() + .name("int") + .type() + .intType() + .noDefault() + .name("date") + .type(DATE_TYPE) + .noDefault() + .name("dateLocal") + .type(DATE_TYPE) + .noDefault() + .name("byte") + .type() + .intType() + .noDefault() + .name("short") + .type() + .intType() + .noDefault() + .name("yearMonthInterval") + .type() + .intType() + .noDefault() + .name("long") + .type() + .longType() + .noDefault() + .name("instantMicros") + .type(TIMESTAMP_MICROS) + .noDefault() + .name("instantMillis") + .type(TIMESTAMP_MILLIS) + .noDefault() + .name("timestampMicros") + .type(TIMESTAMP_MICROS) + .noDefault() + .name("timestampMillis") + .type(TIMESTAMP_MILLIS) + .noDefault() + .name("timestampNoLogical") + .type() + .longType() + .noDefault() + .name("localTimestampMicros") + .type(LOCAL_TIMESTAMP_MICROS) + .noDefault() + .name("localTimestampMillis") + .type(LOCAL_TIMESTAMP_MILLIS) + .noDefault() + .name("localTimestampNoLogical") + .type() + .longType() + .noDefault() + .name("dayTimeInterval") + .type() + .longType() + .noDefault() + .name("arrayIntList") + .type() + .array() + .items() + .intType() + .noDefault() + .name("arrayIntSeq") + .type() + .array() + .items() + .intType() + .noDefault() + .name("arrayComplex") + .type() + .array() + .items(COMPLEX_SUB_SCHEMA_AVRO) + .noDefault() + .name("mapIntJavaMap") + .type() + .map() + .values() + .intType() + .noDefault() + .name("mapIntScalaMap") + .type() + .map() + .values() + .intType() + .noDefault() + .name("mapComplex") + .type() + .map() + .values(COMPLEX_SUB_SCHEMA_AVRO) + .noDefault() + .name("nullableUnion") + .type() + .unionOf() + .nullType() + .and() + .intType() + .endUnion() + .noDefault() + .name("nullableUnion2") + .type() + .unionOf() + .intType() + .and() + .nullType() + .endUnion() + .noDefault() + .name("singleElementUnion") + .type() + .unionOf() + .intType() + .endUnion() + .noDefault() + .name("intLongUnion") + .type() + .unionOf() + .intType() + .and() + .longType() + .endUnion() + .noDefault() + .name("longIntUnion") + .type() + .unionOf() + .longType() + .and() + .intType() + .endUnion() + .noDefault() + 
.name("floatDoubleUnion") + .type() + .unionOf() + .floatType() + .and() + .doubleType() + .endUnion() + .noDefault() + .name("doubleFloatUnion") + .type() + .unionOf() + .doubleType() + .and() + .floatType() + .endUnion() + .noDefault() + .name("complexNonNullableUnion") + .type() + .unionOf() + .doubleType() + .and() + .floatType() + .and() + .stringType() + .endUnion() + .noDefault() + .name("complexNullableUnion1") + .type() + .unionOf() + .nullType() + .and() + .stringType() + .and() + .intType() + .endUnion() + .noDefault() + .name("complexNullableUnion2") + .type() + .unionOf() + .stringType() + .and() + .nullType() + .and() + .intType() + .endUnion() + .noDefault() + .name("complexNullableUnion3") + .type() + .unionOf() + .stringType() + .and() + .intType() + .and() + .nullType() + .endUnion() + .noDefault() + .endRecord(); + + private static final Row SPARK_ROW = new GenericRowWithSchema( + new Object[] { new byte[] { 0x01, 0x02, 0x03 }, // byteArr + ByteBuffer.wrap(new byte[] { 0x04, 0x05, 0x06 }), // byteBuffer + new BigDecimal("0.456").setScale(2, RoundingMode.HALF_UP), // decimalBytes + true, // booleanTrue + false, // booleanFalse + 0.5f, // float + 0.7, // double + STRING_VALUE, // string + new byte[] { 0x01, 0x02, 0x03 }, // byteArrFixed + ByteBuffer.wrap(new byte[] { 0x04, 0x05, 0x06 }), // byteBufferFixed + new BigDecimal("0.456").setScale(2, RoundingMode.HALF_UP), // decimalFixed + "A", // enumType + 100, // int + Date.valueOf(LocalDate.of(2024, 6, 18)), // date + LocalDate.of(2024, 6, 18), // dateLocal + (byte) 100, // byte + (short) 100, // short + Period.ofMonths(5), // yearMonthInterval + 100L, // long + TEST_EPOCH_INSTANT, // instantMicros + TEST_EPOCH_INSTANT, // instantMillis + Timestamp.from(TEST_EPOCH_INSTANT), // timestampMicros + Timestamp.from(TEST_EPOCH_INSTANT), // timestampMillis + TEST_EPOCH_INSTANT, // timestampNoLogical + TEST_LOCAL_DATE_TIME, // localTimestampMicros + TEST_LOCAL_DATE_TIME, // localTimestampMillis + TEST_LOCAL_DATE_TIME, // localTimestampNoLogical + Duration.ofSeconds(100), // dayTimeInterval + Arrays.asList(1, 2, 3), // arrayIntList + JavaConverters.asScalaBuffer(Arrays.asList(1, 2, 3)).toList(), // arrayIntSeq + JavaConverters.asScalaBuffer( + Arrays.asList( + new GenericRowWithSchema(new Object[] { 10, STRING_VALUE_2 }, COMPLEX_SUB_SCHEMA), + new GenericRowWithSchema(new Object[] { 20, STRING_VALUE_3 }, COMPLEX_SUB_SCHEMA))) + .toList(), // arrayComplex + new HashMap() { + { + put("key1", 10); + put("key2", 20); + } + }, // mapIntJavaMap + JavaConverters.mapAsScalaMap(new HashMap() { + { + put("key1", 10); + put("key2", 20); + } + }), // mapIntScalaMap + new HashMap() { + { + put("key1", new GenericRowWithSchema(new Object[] { 10, STRING_VALUE_2 }, COMPLEX_SUB_SCHEMA)); + put("key2", new GenericRowWithSchema(new Object[] { 20, STRING_VALUE_3 }, COMPLEX_SUB_SCHEMA)); + } + }, // mapComplex + 10, // nullableUnion + null, // nullableUnion2 + 10, // singleElementUnion + 10L, // intLongUnion + 10L, // longIntUnion + 0.5, // floatDoubleUnion + 0.5, // doubleFloatUnion + new GenericRowWithSchema(new Object[] { null, 0.5f, null }, UNION_STRUCT_DOUBLE_FLOAT_STRING), // complexNonNullableUnion + new GenericRowWithSchema(new Object[] { null, 10 }, UNION_STRUCT_STRING_INT), // complexNullableUnion1 + new GenericRowWithSchema(new Object[] { STRING_VALUE, null }, UNION_STRUCT_STRING_INT), // complexNullableUnion2 + null, // complexNullableUnion3 + }, + SPARK_STRUCT_SCHEMA); + + @Test + public void testConvertToRecord() { + GenericRecord record = 
RowToAvroConverter.convertToRecord(SPARK_ROW, SPARK_STRUCT_SCHEMA, AVRO_SCHEMA); + assertEquals(record.get("byteArr"), ByteBuffer.wrap(new byte[] { 0x01, 0x02, 0x03 })); + assertEquals(record.get("byteBuffer"), ByteBuffer.wrap(new byte[] { 0x04, 0x05, 0x06 })); + assertEquals(record.get("decimalBytes"), ByteBuffer.wrap(new byte[] { 46 })); + assertEquals(record.get("booleanTrue"), true); + assertEquals(record.get("booleanFalse"), false); + assertEquals(record.get("float"), 0.5f); + assertEquals(record.get("double"), 0.7); + assertEquals(record.get("string"), STRING_VALUE); + assertEquals( + record.get("byteArrFixed"), + AvroCompatibilityHelper.newFixed(FIXED_TYPE_3, new byte[] { 0x01, 0x02, 0x03 })); + assertEquals( + record.get("byteBufferFixed"), + AvroCompatibilityHelper.newFixed(FIXED_TYPE_3, new byte[] { 0x04, 0x05, 0x06 })); + assertEquals(record.get("decimalFixed"), AvroCompatibilityHelper.newFixed(FIXED_TYPE_3, new byte[] { 0, 0, 46 })); + assertEquals( + record.get("enumType"), + AvroCompatibilityHelper.newEnumSymbol(AVRO_SCHEMA.getField("enumType").schema(), "A")); + assertEquals(record.get("int"), 100); + assertEquals(record.get("date"), (int) LocalDate.of(2024, 6, 18).toEpochDay()); + assertEquals(record.get("dateLocal"), (int) LocalDate.of(2024, 6, 18).toEpochDay()); + assertEquals(record.get("byte"), 100); + assertEquals(record.get("short"), 100); + assertEquals(record.get("yearMonthInterval"), 5); + assertEquals(record.get("long"), 100L); + assertEquals(record.get("instantMicros"), TEST_EPOCH_MILLIS * 1000); + assertEquals(record.get("instantMillis"), TEST_EPOCH_MILLIS); + assertEquals(record.get("timestampMicros"), TEST_EPOCH_MILLIS * 1000); + assertEquals(record.get("timestampMillis"), TEST_EPOCH_MILLIS); + assertEquals(record.get("timestampNoLogical"), TEST_EPOCH_MILLIS); + assertEquals(record.get("localTimestampMicros"), TEST_LOCAL_TIMESTAMP_MILLIS * 1000); + assertEquals(record.get("localTimestampMillis"), TEST_LOCAL_TIMESTAMP_MILLIS); + assertEquals(record.get("localTimestampNoLogical"), TEST_LOCAL_TIMESTAMP_MILLIS); + assertEquals(record.get("dayTimeInterval"), 100L * 1000 * 1000); + assertEquals(record.get("arrayIntList"), Arrays.asList(1, 2, 3)); + assertEquals(record.get("arrayIntSeq"), Arrays.asList(1, 2, 3)); + + GenericRecord complex_record_1 = new GenericData.Record(COMPLEX_SUB_SCHEMA_AVRO); + complex_record_1.put("int", 10); + complex_record_1.put("string", STRING_VALUE_2); + + GenericRecord complex_record_2 = new GenericData.Record(COMPLEX_SUB_SCHEMA_AVRO); + complex_record_2.put("int", 20); + complex_record_2.put("string", STRING_VALUE_3); + + assertEquals(record.get("arrayComplex"), Arrays.asList(complex_record_1, complex_record_2)); + + Map expectedIntMap = new HashMap() { + { + put("key1", 10); + put("key2", 20); + } + }; + assertEquals(record.get("mapIntJavaMap"), expectedIntMap); + assertEquals(record.get("mapIntScalaMap"), expectedIntMap); + + Map expectedComplexMap = new HashMap() { + { + put("key1", complex_record_1); + put("key2", complex_record_2); + } + }; + assertEquals(record.get("mapComplex"), expectedComplexMap); + + assertEquals(record.get("nullableUnion"), 10); + + assertNull(record.get("nullableUnion2")); + + assertEquals(record.get("singleElementUnion"), 10); + + assertEquals(record.get("intLongUnion"), 10L); + + assertEquals(record.get("longIntUnion"), 10L); + + assertEquals(record.get("floatDoubleUnion"), 0.5); + + assertEquals(record.get("doubleFloatUnion"), 0.5); + + Object complexNonNullableUnion = 
record.get("complexNonNullableUnion"); + assertTrue(complexNonNullableUnion instanceof Float); + assertEquals((Float) complexNonNullableUnion, 0.5f, 0.001f); + + Object complexNullableUnion1 = record.get("complexNullableUnion1"); + assertTrue(complexNullableUnion1 instanceof Integer); + assertEquals(((Integer) complexNullableUnion1).intValue(), 10); + + Object complexNullableUnion2 = record.get("complexNullableUnion2"); + assertTrue(complexNullableUnion2 instanceof CharSequence); + assertEquals(complexNullableUnion2, STRING_VALUE); + + assertNull(record.get("complexNullableUnion3")); + } + + @Test + public void testConvertToBoolean() { + Boolean trueObj = RowToAvroConverter.convertToBoolean(true, BooleanType); + assertNotNull(trueObj); + assertTrue(trueObj); + + Boolean falseObj = RowToAvroConverter.convertToBoolean(false, BooleanType); + assertNotNull(falseObj); + assertFalse(falseObj); + + // Type must be BooleanType + assertThrows(IllegalArgumentException.class, () -> RowToAvroConverter.convertToBoolean(true, ByteType)); + + // Data must be Boolean + assertThrows(IllegalArgumentException.class, () -> RowToAvroConverter.convertToBoolean(10, BooleanType)); + } + + @Test + public void testConvertToInt() { + Integer integer = + RowToAvroConverter.convertToInt(SPARK_ROW.getAs("int"), IntegerType, AVRO_SCHEMA.getField("int").schema()); + assertNotNull(integer); + assertEquals(integer.intValue(), 100); + + Integer date = + RowToAvroConverter.convertToInt(SPARK_ROW.getAs("date"), DateType, AVRO_SCHEMA.getField("date").schema()); + assertNotNull(date); + assertEquals(date.intValue(), (int) LocalDate.of(2024, 6, 18).toEpochDay()); + + Integer dateLocal = RowToAvroConverter + .convertToInt(SPARK_ROW.getAs("dateLocal"), DateType, AVRO_SCHEMA.getField("dateLocal").schema()); + assertNotNull(dateLocal); + assertEquals(dateLocal.intValue(), (int) LocalDate.of(2024, 6, 18).toEpochDay()); + + Integer byteInt = + RowToAvroConverter.convertToInt(SPARK_ROW.getAs("byte"), ByteType, AVRO_SCHEMA.getField("byte").schema()); + assertNotNull(byteInt); + assertEquals(byteInt.intValue(), 100); + + Integer shortInt = + RowToAvroConverter.convertToInt(SPARK_ROW.getAs("short"), ShortType, AVRO_SCHEMA.getField("short").schema()); + assertNotNull(shortInt); + assertEquals(shortInt.intValue(), 100); + + Integer yearMonthInterval = RowToAvroConverter.convertToInt( + SPARK_ROW.getAs("yearMonthInterval"), + DataTypes.createYearMonthIntervalType(), + AVRO_SCHEMA.getField("yearMonthInterval").schema()); + assertNotNull(yearMonthInterval); + assertEquals(yearMonthInterval.intValue(), 5); + + // Type must be IntegerType, ByteType, ShortType, DateType or YearMonthIntervalType + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToInt(0.5f, StringType, AVRO_SCHEMA.getField("int").schema())); + + // When using IntegerType, data must be an Integer + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToInt(10.0, IntegerType, AVRO_SCHEMA.getField("int").schema())); + + // When using ByteType, data must be a Byte + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToInt(10.0, ByteType, AVRO_SCHEMA.getField("byte").schema())); + + // When using ShortType, data must be a Short + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToInt(10.0, ShortType, AVRO_SCHEMA.getField("short").schema())); + + // When using DateType, data must be a java.time.LocalDate or java.sql.Date + assertThrows( + 
IllegalArgumentException.class, + () -> RowToAvroConverter.convertToInt(10.0, DateType, AVRO_SCHEMA.getField("date").schema())); + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToInt(10.0, DateType, AVRO_SCHEMA.getField("dateLocal").schema())); + + // When using DateType, the Avro schema must have a logical type of Date + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter + .convertToInt(LocalDate.of(2024, 6, 18), DateType, AVRO_SCHEMA.getField("int").schema())); + + // When using YearMonthIntervalType, data must be a Period + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToInt( + 10.0, + DataTypes.createYearMonthIntervalType(), + AVRO_SCHEMA.getField("yearMonthInterval").schema())); + } + + @Test + public void testConvertToLong() { + Long longType = + RowToAvroConverter.convertToLong(SPARK_ROW.getAs("long"), LongType, AVRO_SCHEMA.getField("long").schema()); + assertNotNull(longType); + assertEquals(longType.intValue(), 100); + + Long instantMicros = RowToAvroConverter + .convertToLong(SPARK_ROW.getAs("instantMicros"), TimestampType, AVRO_SCHEMA.getField("instantMicros").schema()); + assertNotNull(instantMicros); + assertEquals(instantMicros.longValue(), TEST_EPOCH_MILLIS * 1000); + + Long instantMillis = RowToAvroConverter + .convertToLong(SPARK_ROW.getAs("instantMillis"), TimestampType, AVRO_SCHEMA.getField("instantMillis").schema()); + assertNotNull(instantMillis); + assertEquals(instantMillis.longValue(), TEST_EPOCH_MILLIS); + + Long timestampMicros = RowToAvroConverter.convertToLong( + SPARK_ROW.getAs("timestampMicros"), + TimestampType, + AVRO_SCHEMA.getField("timestampMicros").schema()); + assertNotNull(timestampMicros); + assertEquals(timestampMicros.longValue(), TEST_EPOCH_MILLIS * 1000); + + Long timestampMillis = RowToAvroConverter.convertToLong( + SPARK_ROW.getAs("timestampMillis"), + TimestampType, + AVRO_SCHEMA.getField("timestampMillis").schema()); + assertNotNull(timestampMillis); + assertEquals(timestampMillis.longValue(), TEST_EPOCH_MILLIS); + + // When using TimestampType, and there is no logical type on the Avro schema, convert to millis by default + Long timestampNoLogical = RowToAvroConverter.convertToLong( + SPARK_ROW.getAs("timestampNoLogical"), + TimestampType, + AVRO_SCHEMA.getField("timestampNoLogical").schema()); + assertNotNull(timestampNoLogical); + assertEquals(timestampNoLogical.longValue(), TEST_EPOCH_MILLIS); + + Long localTimestampMicros = RowToAvroConverter.convertToLong( + SPARK_ROW.getAs("localTimestampMicros"), + TimestampNTZType$.MODULE$, + AVRO_SCHEMA.getField("localTimestampMicros").schema()); + assertNotNull(localTimestampMicros); + assertEquals(localTimestampMicros.longValue(), TEST_LOCAL_TIMESTAMP_MILLIS * 1000); + + Long localTimestampMillis = RowToAvroConverter.convertToLong( + SPARK_ROW.getAs("localTimestampMillis"), + TimestampNTZType$.MODULE$, + AVRO_SCHEMA.getField("localTimestampMillis").schema()); + assertNotNull(localTimestampMillis); + assertEquals(localTimestampMillis.longValue(), TEST_LOCAL_TIMESTAMP_MILLIS); + + // When using TimestampNTZType, and there is no logical type on the Avro schema, convert to millis by default + Long localTimestampNoLogical = RowToAvroConverter.convertToLong( + SPARK_ROW.getAs("localTimestampNoLogical"), + TimestampNTZType$.MODULE$, + AVRO_SCHEMA.getField("localTimestampNoLogical").schema()); + assertNotNull(localTimestampNoLogical); + assertEquals(localTimestampNoLogical.longValue(), 
TEST_LOCAL_TIMESTAMP_MILLIS); + + Long dayTimeInterval = RowToAvroConverter.convertToLong( + SPARK_ROW.getAs("dayTimeInterval"), + DataTypes.createDayTimeIntervalType(), + AVRO_SCHEMA.getField("dayTimeInterval").schema()); + assertNotNull(dayTimeInterval); + assertEquals(dayTimeInterval.longValue(), 100L * 1000 * 1000); + + // Type must be LongType, TimestampType, TimestampNTZType or DayTimeIntervalType + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToLong(0.5f, StringType, AVRO_SCHEMA.getField("long").schema())); + + // When using LongType, data must be a Long + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToLong(10.0, LongType, AVRO_SCHEMA.getField("long").schema())); + + // When using TimestampType, data must be a java.time.Instant or java.sql.Timestamp + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToLong(10.0, TimestampType, AVRO_SCHEMA.getField("instantMicros").schema())); + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToLong(10.0, TimestampType, AVRO_SCHEMA.getField("instantMillis").schema())); + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToLong(10.0, TimestampType, AVRO_SCHEMA.getField("timestampMicros").schema())); + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToLong(10.0, TimestampType, AVRO_SCHEMA.getField("timestampMillis").schema())); + + // When using TimestampNTZType, data must be a java.time.LocalDateTime + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter + .convertToLong(10.0, TimestampNTZType$.MODULE$, AVRO_SCHEMA.getField("localTimestampNoLogical").schema())); + + // When using DayTimeIntervalType, data must be a Duration + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToLong( + 10.0, + DataTypes.createDayTimeIntervalType(), + AVRO_SCHEMA.getField("dayTimeInterval").schema())); + } + + @Test + public void testConvertToFloat() { + Float floatObj = RowToAvroConverter.convertToFloat(0.5f, FloatType); + assertNotNull(floatObj); + assertEquals(floatObj, 0.5f, 0.0001f); + + // Type must be FloatType + assertThrows(IllegalArgumentException.class, () -> RowToAvroConverter.convertToFloat(0.5f, ByteType)); + + // Data must be Float + assertThrows(IllegalArgumentException.class, () -> RowToAvroConverter.convertToFloat(10, FloatType)); + } + + @Test + public void testConvertToDouble() { + Double doubleObj = RowToAvroConverter.convertToDouble(0.7, DoubleType); + assertNotNull(doubleObj); + assertEquals(doubleObj, 0.7, 0.0001); + + // Type must be DoubleType + assertThrows(IllegalArgumentException.class, () -> RowToAvroConverter.convertToDouble(0.7, ByteType)); + + // Data must be Double + assertThrows(IllegalArgumentException.class, () -> RowToAvroConverter.convertToDouble(true, DoubleType)); + } + + @Test + public void testConvertToString() { + CharSequence strObj = RowToAvroConverter.convertToString(STRING_VALUE, StringType); + assertNotNull(strObj); + assertEquals(strObj, STRING_VALUE); + + // Type must be StringType + assertThrows(IllegalArgumentException.class, () -> RowToAvroConverter.convertToString(STRING_VALUE, ByteType)); + + // Data must be String + assertThrows(IllegalArgumentException.class, () -> RowToAvroConverter.convertToString(100, StringType)); + } + + @Test + public void testConvertToBytes() { + ByteBuffer byteArrObj = RowToAvroConverter + .convertToBytes(SPARK_ROW.getAs("byteArr"), 
BinaryType, AVRO_SCHEMA.getField("byteArr").schema()); + assertNotNull(byteArrObj); + assertEquals(byteArrObj, ByteBuffer.wrap(new byte[] { 0x01, 0x02, 0x03 })); + + ByteBuffer byteBufferObj = RowToAvroConverter + .convertToBytes(SPARK_ROW.getAs("byteBuffer"), BinaryType, AVRO_SCHEMA.getField("byteBuffer").schema()); + assertNotNull(byteBufferObj); + assertEquals(byteBufferObj, ByteBuffer.wrap(new byte[] { 0x04, 0x05, 0x06 })); + + ByteBuffer decimalObj = RowToAvroConverter.convertToBytes( + SPARK_ROW.getAs("decimalBytes"), + DataTypes.createDecimalType(3, 2), + AVRO_SCHEMA.getField("decimalBytes").schema()); + assertNotNull(decimalObj); + assertEquals(decimalObj, ByteBuffer.wrap(new byte[] { 46 })); + + // The scale of the actual BigDecimal object shouldn't matter + ByteBuffer decimalObj2 = RowToAvroConverter.convertToBytes( + new BigDecimal("0.456").setScale(1, RoundingMode.HALF_UP), + DataTypes.createDecimalType(3, 2), + DECIMAL_TYPE); + assertNotNull(decimalObj2); + assertEquals(decimalObj2, ByteBuffer.wrap(new byte[] { 50 })); + + // Type must be BinaryType or DecimalType + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter + .convertToBytes(SPARK_ROW.getAs("byteArr"), ByteType, AVRO_SCHEMA.getField("byteArr").schema())); + + // Data must be byte[], ByteBuffer or BigDecimal + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter + .convertToBytes(SPARK_ROW.getAs("booleanTrue"), BinaryType, AVRO_SCHEMA.getField("booleanTrue").schema())); + + // Logical type scale must match the Spark type scale + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToBytes( + new BigDecimal("0.456").setScale(2, RoundingMode.HALF_UP), + DataTypes.createDecimalType(3, 2), + LogicalTypes.decimal(3, 3).addToSchema(Schema.create(Schema.Type.BYTES)))); + + // Logical type precision must match the Spark type precision + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToBytes( + new BigDecimal("0.456").setScale(4, RoundingMode.HALF_UP), + DataTypes.createDecimalType(4, 3), + LogicalTypes.decimal(3, 3).addToSchema(Schema.create(Schema.Type.BYTES)))); + } + + @Test + public void testConvertToFixed() { + GenericFixed byteArrObj = RowToAvroConverter + .convertToFixed(SPARK_ROW.getAs("byteArrFixed"), BinaryType, AVRO_SCHEMA.getField("byteArrFixed").schema()); + assertNotNull(byteArrObj); + assertNotNull(byteArrObj.bytes()); + assertEquals(byteArrObj.bytes().length, 3); + assertEquals(byteArrObj.bytes(), new byte[] { 0x01, 0x02, 0x03 }); + + GenericFixed byteBufferObj = RowToAvroConverter.convertToFixed( + SPARK_ROW.getAs("byteBufferFixed"), + BinaryType, + AVRO_SCHEMA.getField("byteBufferFixed").schema()); + assertNotNull(byteBufferObj); + assertNotNull(byteBufferObj.bytes()); + assertEquals(byteBufferObj.bytes().length, 3); + assertEquals(byteBufferObj.bytes(), new byte[] { 0x04, 0x05, 0x06 }); + + GenericFixed decimalObj = RowToAvroConverter.convertToFixed( + SPARK_ROW.getAs("decimalFixed"), + DataTypes.createDecimalType(3, 2), + AVRO_SCHEMA.getField("decimalFixed").schema()); + assertNotNull(decimalObj); + assertNotNull(decimalObj.bytes()); + assertEquals(decimalObj.bytes().length, 3); + assertEquals(decimalObj.bytes(), new byte[] { 0, 0, 46 }); + + // The scale of the actual BigDecimal object shouldn't matter + GenericFixed decimalObj2 = RowToAvroConverter.convertToFixed( + new BigDecimal("0.456").setScale(1, RoundingMode.HALF_UP), + DataTypes.createDecimalType(3, 2), + DECIMAL_FIXED_TYPE); + 
assertNotNull(decimalObj2); + assertNotNull(decimalObj2.bytes()); + assertEquals(decimalObj2.bytes().length, 3); + assertEquals(decimalObj2.bytes(), new byte[] { 0, 0, 50 }); + + // The byte array must have the correct length + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter + .convertToFixed(new byte[] { 0x00, 0x01 }, BinaryType, AVRO_SCHEMA.getField("byteArrFixed").schema())); + + // The byte buffer must have the correct length + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToFixed( + ByteBuffer.wrap(new byte[] { 0x00, 0x01 }), + BinaryType, + AVRO_SCHEMA.getField("byteBufferFixed").schema())); + + // Type must be BinaryType or DecimalType + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter + .convertToFixed(SPARK_ROW.getAs("byteArrFixed"), ByteType, AVRO_SCHEMA.getField("byteArrFixed").schema())); + + // Data must be byte[], ByteBuffer or BigDecimal + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter + .convertToFixed(SPARK_ROW.getAs("booleanTrue"), BinaryType, AVRO_SCHEMA.getField("booleanTrue").schema())); + + // Logical type scale must match the Spark type scale + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToFixed( + new BigDecimal("0.456").setScale(2, RoundingMode.HALF_UP), + DataTypes.createDecimalType(3, 2), + LogicalTypes.decimal(3, 3).addToSchema(Schema.createFixed("decimalFixed", null, null, 3)))); + + // Logical type precision must match the Spark type precision + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToFixed( + new BigDecimal("0.456").setScale(4, RoundingMode.HALF_UP), + DataTypes.createDecimalType(4, 3), + LogicalTypes.decimal(3, 3).addToSchema(Schema.createFixed("decimalFixed", null, null, 3)))); + } + + @Test + public void testConvertToEnum() { + GenericEnumSymbol enumObj = RowToAvroConverter + .convertToEnum(SPARK_ROW.getAs("enumType"), StringType, AVRO_SCHEMA.getField("enumType").schema()); + assertNotNull(enumObj); + assertEquals(enumObj, AvroCompatibilityHelper.newEnumSymbol(AVRO_SCHEMA.getField("enumType").schema(), "A")); + + // String value must be a valid symbol + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToEnum("D", StringType, AVRO_SCHEMA.getField("enumType").schema())); + + // Type must be StringType + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter + .convertToEnum(SPARK_ROW.getAs("enumType"), ByteType, AVRO_SCHEMA.getField("enumType").schema())); + + // Data must be String + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter + .convertToEnum(SPARK_ROW.getAs("booleanTrue"), StringType, AVRO_SCHEMA.getField("enumType").schema())); + } + + @Test + public void testConvertToArray() { + Schema arrayIntSchema = SchemaBuilder.array().items().intType(); + List arrayIntList = RowToAvroConverter + .convertToArray(SPARK_ROW.getAs("arrayIntList"), DataTypes.createArrayType(IntegerType), arrayIntSchema); + assertNotNull(arrayIntList); + assertEquals(arrayIntList, Arrays.asList(1, 2, 3)); + + List arrayIntSeq = RowToAvroConverter + .convertToArray(SPARK_ROW.getAs("arrayIntSeq"), DataTypes.createArrayType(IntegerType), arrayIntSchema); + assertNotNull(arrayIntSeq); + assertEquals(arrayIntSeq, Arrays.asList(1, 2, 3)); + + Schema arrayComplexSchema = SchemaBuilder.array().items(COMPLEX_SUB_SCHEMA_AVRO); + List arrayComplex = RowToAvroConverter.convertToArray( + SPARK_ROW.getAs("arrayComplex"), + 
DataTypes.createArrayType(COMPLEX_SUB_SCHEMA), + arrayComplexSchema); + assertNotNull(arrayComplex); + + GenericRecord complex_record_1 = new GenericData.Record(COMPLEX_SUB_SCHEMA_AVRO); + complex_record_1.put("int", 10); + complex_record_1.put("string", STRING_VALUE_2); + + GenericRecord complex_record_2 = new GenericData.Record(COMPLEX_SUB_SCHEMA_AVRO); + complex_record_2.put("int", 20); + complex_record_2.put("string", STRING_VALUE_3); + + assertEquals(arrayComplex, Arrays.asList(complex_record_1, complex_record_2)); + + // Type must be ArrayType + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToArray(SPARK_ROW.getAs("arrayIntList"), ByteType, arrayIntSchema)); + + // Data must be scala.collection.Seq or java.util.List + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToArray(100, DataTypes.createArrayType(IntegerType), arrayIntSchema)); + } + + @Test + public void testConvertToMap() { + Schema mapIntSchema = SchemaBuilder.map().values().intType(); + Map expectedIntMap = new HashMap() { + { + put("key1", 10); + put("key2", 20); + } + }; + + Map mapIntJavaMap = RowToAvroConverter + .convertToMap(SPARK_ROW.getAs("mapIntJavaMap"), DataTypes.createMapType(StringType, IntegerType), mapIntSchema); + assertNotNull(mapIntJavaMap); + assertEquals(mapIntJavaMap, expectedIntMap); + + Map mapIntScalaMap = RowToAvroConverter.convertToMap( + SPARK_ROW.getAs("mapIntScalaMap"), + DataTypes.createMapType(StringType, IntegerType), + mapIntSchema); + assertNotNull(mapIntScalaMap); + assertEquals(mapIntScalaMap, expectedIntMap); + + Schema mapComplexSchema = SchemaBuilder.map().values(COMPLEX_SUB_SCHEMA_AVRO); + Map mapComplex = RowToAvroConverter.convertToMap( + SPARK_ROW.getAs("mapComplex"), + DataTypes.createMapType(StringType, COMPLEX_SUB_SCHEMA), + mapComplexSchema); + assertNotNull(mapComplex); + + GenericRecord complex_record_1 = new GenericData.Record(COMPLEX_SUB_SCHEMA_AVRO); + complex_record_1.put("int", 10); + complex_record_1.put("string", STRING_VALUE_2); + + GenericRecord complex_record_2 = new GenericData.Record(COMPLEX_SUB_SCHEMA_AVRO); + complex_record_2.put("int", 20); + complex_record_2.put("string", STRING_VALUE_3); + + Map expectedComplexMap = new HashMap() { + { + put("key1", complex_record_1); + put("key2", complex_record_2); + } + }; + + assertEquals(mapComplex, expectedComplexMap); + + // Maps with keys that are not String are not supported + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToMap( + SPARK_ROW.getAs("mapIntJavaMap"), + DataTypes.createMapType(ByteType, IntegerType), + mapIntSchema)); + + // Type must be MapType + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToMap(SPARK_ROW.getAs("mapIntJavaMap"), ByteType, mapIntSchema)); + + // Data must be scala.collection.Map or java.util.Map + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToMap(100, DataTypes.createMapType(StringType, IntegerType), mapIntSchema)); + } + + @Test + public void testConvertToUnion() { + // null is allowed for nullable unions + assertNull(RowToAvroConverter.convertToUnion(null, IntegerType, AVRO_SCHEMA.getField("nullableUnion").schema())); + + // null is not allowed for non-nullable unions + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToUnion(null, LongType, AVRO_SCHEMA.getField("longIntUnion").schema())); + + // Test union with only 1 branch + Object singleElementUnion = + 
RowToAvroConverter.convertToUnion(10, IntegerType, AVRO_SCHEMA.getField("singleElementUnion").schema()); + assertTrue(singleElementUnion instanceof Integer); + assertEquals(((Integer) singleElementUnion).intValue(), 10); + + // Test union with two branches: null + something else + assertNull(RowToAvroConverter.convertToUnion(null, IntegerType, AVRO_SCHEMA.getField("nullableUnion").schema())); + Object nullableUnionValue = + RowToAvroConverter.convertToUnion(10, IntegerType, AVRO_SCHEMA.getField("nullableUnion").schema()); + assertTrue(nullableUnionValue instanceof Integer); + assertEquals(((Integer) nullableUnionValue).intValue(), 10); + + // Test union with two branches: something else + null + assertNull(RowToAvroConverter.convertToUnion(null, IntegerType, AVRO_SCHEMA.getField("nullableUnion2").schema())); + Object nullableUnion2Value = + RowToAvroConverter.convertToUnion(10, IntegerType, AVRO_SCHEMA.getField("nullableUnion2").schema()); + assertTrue(nullableUnion2Value instanceof Integer); + assertEquals(((Integer) nullableUnion2Value).intValue(), 10); + + // Test union with two branches: int + long + Object intLongUnion = + RowToAvroConverter.convertToUnion(10L, LongType, AVRO_SCHEMA.getField("intLongUnion").schema()); + assertTrue(intLongUnion instanceof Long); + assertEquals(((Long) intLongUnion).longValue(), 10L); + + // Test union with two branches: long + int + Object longIntUnion = + RowToAvroConverter.convertToUnion(10L, LongType, AVRO_SCHEMA.getField("intLongUnion").schema()); + assertTrue(longIntUnion instanceof Long); + assertEquals(((Long) longIntUnion).longValue(), 10L); + + // Test union with two branches: float + double + Object floatDoubleUnion = + RowToAvroConverter.convertToUnion(0.5, DoubleType, AVRO_SCHEMA.getField("floatDoubleUnion").schema()); + assertTrue(floatDoubleUnion instanceof Double); + assertEquals((Double) floatDoubleUnion, 0.5, 0.001); + + // Test union with two branches: double + float + Object doubleFloatUnion = + RowToAvroConverter.convertToUnion(0.5, DoubleType, AVRO_SCHEMA.getField("doubleFloatUnion").schema()); + assertTrue(doubleFloatUnion instanceof Double); + assertEquals((Double) doubleFloatUnion, 0.5, 0.001); + + // Test complex union without null + Object complexNonNullableUnion1 = RowToAvroConverter.convertToUnion( + new GenericRowWithSchema(new Object[] { null, 0.5f, null }, UNION_STRUCT_DOUBLE_FLOAT_STRING), + UNION_STRUCT_DOUBLE_FLOAT_STRING, + AVRO_SCHEMA.getField("complexNonNullableUnion").schema()); + assertTrue(complexNonNullableUnion1 instanceof Float); + assertEquals((Float) complexNonNullableUnion1, 0.5f, 0.001f); + + Object complexNonNullableUnion2 = RowToAvroConverter.convertToUnion( + new GenericRowWithSchema(new Object[] { 0.5, null, null }, UNION_STRUCT_DOUBLE_FLOAT_STRING), + UNION_STRUCT_DOUBLE_FLOAT_STRING, + AVRO_SCHEMA.getField("complexNonNullableUnion").schema()); + assertTrue(complexNonNullableUnion2 instanceof Double); + assertEquals((Double) complexNonNullableUnion2, 0.5, 0.001); + + Object complexNonNullableUnion3 = RowToAvroConverter.convertToUnion( + new GenericRowWithSchema(new Object[] { null, null, STRING_VALUE }, UNION_STRUCT_DOUBLE_FLOAT_STRING), + UNION_STRUCT_DOUBLE_FLOAT_STRING, + AVRO_SCHEMA.getField("complexNonNullableUnion").schema()); + assertTrue(complexNonNullableUnion3 instanceof String); + assertEquals(complexNonNullableUnion3, STRING_VALUE); + + // Test complex union with null in first branch + Object complexNullableUnion1_1 = RowToAvroConverter.convertToUnion( + new GenericRowWithSchema(new 
Object[] { null, 10 }, UNION_STRUCT_STRING_INT), + UNION_STRUCT_STRING_INT, + AVRO_SCHEMA.getField("complexNullableUnion1").schema()); + assertTrue(complexNullableUnion1_1 instanceof Integer); + assertEquals(((Integer) complexNullableUnion1_1).intValue(), 10); + + Object complexNullableUnion1_2 = RowToAvroConverter.convertToUnion( + new GenericRowWithSchema(new Object[] { STRING_VALUE, null }, UNION_STRUCT_STRING_INT), + UNION_STRUCT_STRING_INT, + AVRO_SCHEMA.getField("complexNullableUnion1").schema()); + assertTrue(complexNullableUnion1_2 instanceof String); + assertEquals(complexNullableUnion1_2, STRING_VALUE); + + Object complexNullableUnion1_3 = RowToAvroConverter + .convertToUnion(null, UNION_STRUCT_STRING_INT, AVRO_SCHEMA.getField("complexNullableUnion1").schema()); + assertNull(complexNullableUnion1_3); + + // Test complex union with null in second branch + Object complexNullableUnion2_1 = RowToAvroConverter.convertToUnion( + new GenericRowWithSchema(new Object[] { null, 10 }, UNION_STRUCT_STRING_INT), + UNION_STRUCT_STRING_INT, + AVRO_SCHEMA.getField("complexNullableUnion2").schema()); + assertTrue(complexNullableUnion2_1 instanceof Integer); + assertEquals(((Integer) complexNullableUnion2_1).intValue(), 10); + + Object complexNullableUnion2_2 = RowToAvroConverter.convertToUnion( + new GenericRowWithSchema(new Object[] { STRING_VALUE, null }, UNION_STRUCT_STRING_INT), + UNION_STRUCT_STRING_INT, + AVRO_SCHEMA.getField("complexNullableUnion2").schema()); + assertTrue(complexNullableUnion2_2 instanceof String); + assertEquals(complexNullableUnion2_2, STRING_VALUE); + + Object complexNullableUnion2_3 = RowToAvroConverter + .convertToUnion(null, UNION_STRUCT_STRING_INT, AVRO_SCHEMA.getField("complexNullableUnion2").schema()); + assertNull(complexNullableUnion2_3); + + // Test complex union with null in third branch + Object complexNullableUnion3_1 = RowToAvroConverter.convertToUnion( + new GenericRowWithSchema(new Object[] { null, 10 }, UNION_STRUCT_STRING_INT), + UNION_STRUCT_STRING_INT, + AVRO_SCHEMA.getField("complexNullableUnion3").schema()); + assertTrue(complexNullableUnion3_1 instanceof Integer); + assertEquals(((Integer) complexNullableUnion3_1).intValue(), 10); + + Object complexNullableUnion3_2 = RowToAvroConverter.convertToUnion( + new GenericRowWithSchema(new Object[] { STRING_VALUE, null }, UNION_STRUCT_STRING_INT), + UNION_STRUCT_STRING_INT, + AVRO_SCHEMA.getField("complexNullableUnion3").schema()); + assertTrue(complexNullableUnion3_2 instanceof String); + assertEquals(complexNullableUnion3_2, STRING_VALUE); + + Object complexNullableUnion3_3 = RowToAvroConverter + .convertToUnion(null, UNION_STRUCT_STRING_INT, AVRO_SCHEMA.getField("complexNullableUnion3").schema()); + assertNull(complexNullableUnion3_3); + + // At least one branch must be non-null + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.convertToUnion( + new GenericRowWithSchema(new Object[] { null, null }, UNION_STRUCT_STRING_INT), + UNION_STRUCT_STRING_INT, + AVRO_SCHEMA.getField("complexNullableUnion3").schema())); + } + + @Test + public void testValidateLogicalType() { + assertEquals( + RowToAvroConverter.validateLogicalType(DATE_TYPE, LogicalTypes.date(), LogicalTypes.timeMillis()), + LogicalTypes.date()); + + // Logical type must match the Spark type if it is mandatory + assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.validateLogicalType(DATE_TYPE, LogicalTypes.timeMillis())); + + // Logical type must be present in the Avro schema if it is mandatory + 
assertThrows( + IllegalArgumentException.class, + () -> RowToAvroConverter.validateLogicalType(Schema.create(Schema.Type.LONG), LogicalTypes.timeMillis())); + + // Logical type might not be present in the Avro schema if it is optional + assertNull( + RowToAvroConverter.validateLogicalType(Schema.create(Schema.Type.LONG), false, LogicalTypes.timeMillis())); + } +} diff --git a/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/endToEnd/PartialUpdateTest.java b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/endToEnd/PartialUpdateTest.java index b6930ba933..b6dd7bd164 100644 --- a/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/endToEnd/PartialUpdateTest.java +++ b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/endToEnd/PartialUpdateTest.java @@ -34,6 +34,7 @@ import static com.linkedin.venice.vpj.VenicePushJobConstants.REPUSH_TTL_START_TIMESTAMP; import static com.linkedin.venice.vpj.VenicePushJobConstants.REWIND_TIME_IN_SECONDS_OVERRIDE; import static com.linkedin.venice.vpj.VenicePushJobConstants.SOURCE_KAFKA; +import static com.linkedin.venice.vpj.VenicePushJobConstants.SPARK_NATIVE_INPUT_FORMAT_ENABLED; import static com.linkedin.venice.vpj.VenicePushJobConstants.VENICE_STORE_NAME_PROP; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -399,6 +400,7 @@ public void testIncrementalPushPartialUpdateNewFormat(boolean useSparkCompute) t vpjProperties.put(INCREMENTAL_PUSH, true); if (useSparkCompute) { vpjProperties.setProperty(DATA_WRITER_COMPUTE_JOB_CLASS, DataWriterSparkJob.class.getCanonicalName()); + vpjProperties.setProperty(SPARK_NATIVE_INPUT_FORMAT_ENABLED, String.valueOf(true)); } try (ControllerClient parentControllerClient = new ControllerClient(CLUSTER_NAME, parentControllerUrl)) { diff --git a/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/endToEnd/TestBatch.java b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/endToEnd/TestBatch.java index 81a6945c05..d95accd2d2 100644 --- a/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/endToEnd/TestBatch.java +++ b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/endToEnd/TestBatch.java @@ -47,6 +47,7 @@ import static com.linkedin.venice.vpj.VenicePushJobConstants.SEND_CONTROL_MESSAGES_DIRECTLY; import static com.linkedin.venice.vpj.VenicePushJobConstants.SOURCE_ETL; import static com.linkedin.venice.vpj.VenicePushJobConstants.SOURCE_KAFKA; +import static com.linkedin.venice.vpj.VenicePushJobConstants.SPARK_NATIVE_INPUT_FORMAT_ENABLED; import static com.linkedin.venice.vpj.VenicePushJobConstants.USE_MAPPER_TO_BUILD_DICTIONARY; import static com.linkedin.venice.vpj.VenicePushJobConstants.VENICE_STORE_NAME_PROP; import static com.linkedin.venice.vpj.VenicePushJobConstants.ZSTD_COMPRESSION_LEVEL; @@ -769,7 +770,7 @@ public void testBatchFromETLWithNullDefaultValue() throws Exception { } @Test(timeOut = TEST_TIMEOUT) - public void testBatchFromETLWithForUnionWithNullSchema() throws Exception { + public void testBatchFromETLForUnionWithNullSchema() throws Exception { testBatchStore(inputDir -> { writeETLFileWithUnionWithNullSchema(inputDir); return new KeyAndValueSchemas(ETL_KEY_SCHEMA, ETL_UNION_VALUE_WITH_NULL_SCHEMA); @@ -796,7 +797,7 @@ public void testBatchFromETLWithForUnionWithNullSchema() throws Exception { } @Test(timeOut = TEST_TIMEOUT) - public void testBatchFromETLWithForUnionWithoutNullSchema() throws 
Exception { + public void testBatchFromETLForUnionWithoutNullSchema() throws Exception { testBatchStore(inputDir -> { writeETLFileWithUnionWithoutNullSchema(inputDir); return new KeyAndValueSchemas(ETL_KEY_SCHEMA, ETL_UNION_VALUE_WITHOUT_NULL_SCHEMA); @@ -895,6 +896,7 @@ private String testBatchStore( String inputDirPath = "file://" + inputDir.getAbsolutePath(); Properties props = defaultVPJProps(veniceCluster, inputDirPath, storeName); props.setProperty(DATA_WRITER_COMPUTE_JOB_CLASS, DataWriterSparkJob.class.getCanonicalName()); + props.setProperty(SPARK_NATIVE_INPUT_FORMAT_ENABLED, String.valueOf(true)); extraProps.accept(props); if (StringUtils.isEmpty(existingStore)) { @@ -1419,7 +1421,7 @@ public void testKafkaInputBatchJobSucceedsWhenSourceTopicIsEmpty() throws Except } @Test(timeOut = TEST_TIMEOUT, dataProvider = "True-and-False", dataProviderClass = DataProviderUtils.class) - public void testBatchJobSnapshots(Boolean isKakfaPush) throws Exception { + public void testBatchJobSnapshots(Boolean isKafkaPush) throws Exception { VPJValidator validator = (avroClient, vsonClient, metricsRepository) -> { for (int i = 1; i <= 100; i++) { @@ -1437,7 +1439,7 @@ public void testBatchJobSnapshots(Boolean isKakfaPush) throws Exception { deleteDirectory(Paths.get(BASE_DATA_PATH_1).toFile()); deleteDirectory(Paths.get(BASE_DATA_PATH_2).toFile()); - if (isKakfaPush) { + if (isKafkaPush) { testRepush(storeName, validator); } else { testBatchStore( diff --git a/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/endToEnd/TestVsonStoreBatch.java b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/endToEnd/TestVsonStoreBatch.java index 2b2642c730..16f97ddb98 100644 --- a/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/endToEnd/TestVsonStoreBatch.java +++ b/internal/venice-test-common/src/integrationTest/java/com/linkedin/venice/endToEnd/TestVsonStoreBatch.java @@ -15,6 +15,7 @@ import static com.linkedin.venice.vpj.VenicePushJobConstants.KAFKA_INPUT_TOPIC; import static com.linkedin.venice.vpj.VenicePushJobConstants.KEY_FIELD_PROP; import static com.linkedin.venice.vpj.VenicePushJobConstants.SOURCE_KAFKA; +import static com.linkedin.venice.vpj.VenicePushJobConstants.SPARK_NATIVE_INPUT_FORMAT_ENABLED; import static com.linkedin.venice.vpj.VenicePushJobConstants.VALUE_FIELD_PROP; import com.linkedin.venice.client.store.AvroGenericStoreClient; @@ -322,6 +323,7 @@ private String testBatchStore( String inputDirPath = "file://" + inputDir.getAbsolutePath(); Properties props = defaultVPJPropsWithoutD2Routing(veniceCluster, inputDirPath, storeName); props.setProperty(DATA_WRITER_COMPUTE_JOB_CLASS, DataWriterSparkJob.class.getCanonicalName()); + props.setProperty(SPARK_NATIVE_INPUT_FORMAT_ENABLED, String.valueOf(true)); extraProps.accept(props); if (!storeNameOptional.isPresent()) {