From b2ad521b6bae7af102b7220d5b0768506c0e9270 Mon Sep 17 00:00:00 2001 From: Eduard Karacharov Date: Sun, 22 Dec 2024 19:32:46 +0200 Subject: [PATCH] fix: column writer for dictionary decimal primitive type --- .../src/arrow/array_reader/primitive_array.rs | 51 +++++++++++ parquet/src/arrow/arrow_writer/mod.rs | 90 ++++++++++++++++++- 2 files changed, 140 insertions(+), 1 deletion(-) diff --git a/parquet/src/arrow/array_reader/primitive_array.rs b/parquet/src/arrow/array_reader/primitive_array.rs index a952e00e12ef..735fcf2e10bb 100644 --- a/parquet/src/arrow/array_reader/primitive_array.rs +++ b/parquet/src/arrow/array_reader/primitive_array.rs @@ -270,6 +270,57 @@ where Arc::new(array) as ArrayRef } + ArrowType::Dictionary(_, value_type) => match value_type.as_ref() { + ArrowType::Decimal128(p, s) => { + let array = match array.data_type() { + ArrowType::Int32 => array + .as_any() + .downcast_ref::() + .unwrap() + .unary(|i| i as i128) + as Decimal128Array, + ArrowType::Int64 => array + .as_any() + .downcast_ref::() + .unwrap() + .unary(|i| i as i128) + as Decimal128Array, + _ => { + return Err(arrow_err!( + "Cannot convert {:?} to decimal dictionary", + array.data_type() + )); + } + }.with_precision_and_scale(*p, *s)?; + + arrow_cast::cast(&array, target_type)? + }, + ArrowType::Decimal256(p, s) => { + let array = match array.data_type() { + ArrowType::Int32 => array + .as_any() + .downcast_ref::() + .unwrap() + .unary(i256::from) + as Decimal256Array, + ArrowType::Int64 => array + .as_any() + .downcast_ref::() + .unwrap() + .unary(i256::from) + as Decimal256Array, + _ => { + return Err(arrow_err!( + "Cannot convert {:?} to decimal dictionary", + array.data_type() + )); + } + }.with_precision_and_scale(*p, *s)?; + + arrow_cast::cast(&array, target_type)? + }, + _ => arrow_cast::cast(&array, target_type)? + } _ => arrow_cast::cast(&array, target_type)?, }; diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index 41f15569fda0..716abf415c0f 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -843,6 +843,27 @@ fn write_leaf(writer: &mut ColumnWriter<'_>, levels: &ArrayLevels) -> Result(|v| v.as_i128() as i32); write_primitive(typed, array.values(), levels) } + ArrowDataType::Dictionary(_, value_type) => match value_type.as_ref() { + ArrowDataType::Decimal128(_, _) => { + let array = arrow_cast::cast(column, value_type)?; + let array = array + .as_primitive::() + .unary::<_, Int32Type>(|v| v as i32); + write_primitive(typed, array.values(), levels) + } + ArrowDataType::Decimal256(_, _) => { + let array = arrow_cast::cast(column, value_type)?; + let array = array + .as_primitive::() + .unary::<_, Int32Type>(|v| v.as_i128() as i32); + write_primitive(typed, array.values(), levels) + } + _ => { + let array = arrow_cast::cast(column, &ArrowDataType::Int32)?; + let array = array.as_primitive::(); + write_primitive(typed, array.values(), levels) + } + }, _ => { let array = arrow_cast::cast(column, &ArrowDataType::Int32)?; let array = array.as_primitive::(); @@ -891,6 +912,27 @@ fn write_leaf(writer: &mut ColumnWriter<'_>, levels: &ArrayLevels) -> Result(|v| v.as_i128() as i64); write_primitive(typed, array.values(), levels) } + ArrowDataType::Dictionary(_, value_type) => match value_type.as_ref() { + ArrowDataType::Decimal128(_, _) => { + let array = arrow_cast::cast(column, value_type)?; + let array = array + .as_primitive::() + .unary::<_, Int64Type>(|v| v as i64); + write_primitive(typed, array.values(), levels) + } + ArrowDataType::Decimal256(_, _) => { + let array = arrow_cast::cast(column, value_type)?; + let array = array + .as_primitive::() + .unary::<_, Int64Type>(|v| v.as_i128() as i64); + write_primitive(typed, array.values(), levels) + } + _ => { + let array = arrow_cast::cast(column, &ArrowDataType::Int64)?; + let array = array.as_primitive::(); + write_primitive(typed, array.values(), levels) + } + }, _ => { let array = arrow_cast::cast(column, &ArrowDataType::Int64)?; let array = array.as_primitive::(); @@ -1093,7 +1135,7 @@ mod tests { use arrow::util::data_gen::create_random_array; use arrow::util::pretty::pretty_format_batches; use arrow::{array::*, buffer::Buffer}; - use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, NullBuffer}; + use arrow_buffer::{i256, IntervalDayTime, IntervalMonthDayNano, NullBuffer}; use arrow_schema::Fields; use half::f16; @@ -2670,6 +2712,52 @@ mod tests { one_column_roundtrip_with_schema(Arc::new(d), schema); } + #[test] + fn arrow_writer_decimal128_dictionary() { + let integers = vec![12345, 56789, 34567]; + + let keys = UInt8Array::from(vec![Some(0), None, Some(1), Some(2), Some(1)]); + + let values = Decimal128Array::from(integers.clone()) + .with_precision_and_scale(5, 2) + .unwrap(); + + let array = DictionaryArray::new(keys, Arc::new(values)); + one_column_roundtrip(Arc::new(array.clone()), true); + + let values = Decimal128Array::from(integers) + .with_precision_and_scale(12, 2) + .unwrap(); + + let array = array.with_values(Arc::new(values)); + one_column_roundtrip(Arc::new(array), true); + } + + #[test] + fn arrow_writer_decimal256_dictionary() { + let integers = vec![ + i256::from_i128(12345), + i256::from_i128(56789), + i256::from_i128(34567), + ]; + + let keys = UInt8Array::from(vec![Some(0), None, Some(1), Some(2), Some(1)]); + + let values = Decimal256Array::from(integers.clone()) + .with_precision_and_scale(5, 2) + .unwrap(); + + let array = DictionaryArray::new(keys, Arc::new(values)); + one_column_roundtrip(Arc::new(array.clone()), true); + + let values = Decimal256Array::from(integers) + .with_precision_and_scale(12, 2) + .unwrap(); + + let array = array.with_values(Arc::new(values)); + one_column_roundtrip(Arc::new(array), true); + } + #[test] fn arrow_writer_string_dictionary_unsigned_index() { // define schema