From 743c452945411ad31c0d3aa6deda179d7dd007bf Mon Sep 17 00:00:00 2001 From: maximedion2 <125930903+maximedion2@users.noreply.github.com> Date: Thu, 28 Nov 2024 11:39:27 -0600 Subject: [PATCH] fixed a bug with shard alignment (#32) --- src/async_reader/mod.rs | 58 +++++++++++++++++-- src/reader/codecs.rs | 122 ++++++++++++++++++++++++++++++++++++++-- src/reader/mod.rs | 62 +++++++++++++++++--- test-data | 2 +- 4 files changed, 225 insertions(+), 19 deletions(-) diff --git a/src/async_reader/mod.rs b/src/async_reader/mod.rs index 2fbb0d3..29be626 100644 --- a/src/async_reader/mod.rs +++ b/src/async_reader/mod.rs @@ -775,6 +775,28 @@ mod zarr_async_reader_tests { assert!(matched); } + fn compare_values(col_name1: &str, col_name2: &str, rec: &RecordBatch) + where + T: ArrowPrimitiveType, + { + let mut vals1 = None; + let mut vals2 = None; + for (idx, col) in enumerate(rec.schema().fields.iter()) { + if col.name().as_str() == col_name1 { + vals1 = Some(rec.column(idx).as_primitive::().values()) + } else if col.name().as_str() == col_name2 { + vals2 = Some(rec.column(idx).as_primitive::().values()) + } + } + + if let (Some(vals1), Some(vals2)) = (vals1, vals2) { + assert_eq!(vals1, vals2); + return; + } + + panic!("columns not found"); + } + // create a test filter fn create_filter() -> ZarrChunkFilter { let mut filters: Vec> = Vec::new(); @@ -1047,7 +1069,7 @@ mod zarr_async_reader_tests { "float_data", rec, &[ - 32.0, 33.0, 40.0, 41.0, 34.0, 35.0, 42.0, 43.0, 48.0, 49.0, 56.0, 57.0, 50.0, 51.0, + 32.0, 33.0, 34.0, 35.0, 40.0, 41.0, 42.0, 43.0, 48.0, 49.0, 50.0, 51.0, 56.0, 57.0, 58.0, 59.0, ], ); @@ -1055,7 +1077,7 @@ mod zarr_async_reader_tests { "int_data", rec, &[ - 32, 33, 40, 41, 34, 35, 42, 43, 48, 49, 56, 57, 50, 51, 58, 59, + 32, 33, 34, 35, 40, 41, 42, 43, 48, 49, 50, 51, 56, 57, 58, 59, ], ); } @@ -1076,10 +1098,10 @@ mod zarr_async_reader_tests { "float_data", rec, &[ - 1020.0, 1021.0, 1031.0, 1032.0, 1141.0, 1142.0, 1152.0, 1153.0, 1022.0, 1033.0, - 1143.0, 1154.0, 1042.0, 1043.0, 1053.0, 1054.0, 1163.0, 1164.0, 1174.0, 1175.0, - 1044.0, 1055.0, 1165.0, 1176.0, 1262.0, 1263.0, 1273.0, 1274.0, 1264.0, 1275.0, - 1284.0, 1285.0, 1295.0, 1296.0, 1286.0, 1297.0, + 1020.0, 1021.0, 1022.0, 1031.0, 1032.0, 1033.0, 1042.0, 1043.0, 1044.0, 1053.0, + 1054.0, 1055.0, 1141.0, 1142.0, 1143.0, 1152.0, 1153.0, 1154.0, 1163.0, 1164.0, + 1165.0, 1174.0, 1175.0, 1176.0, 1262.0, 1263.0, 1264.0, 1273.0, 1274.0, 1275.0, + 1284.0, 1285.0, 1286.0, 1295.0, 1296.0, 1297.0, ], ); } @@ -1102,4 +1124,28 @@ mod zarr_async_reader_tests { &[4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55], ); } + + #[tokio::test] + async fn with_partial_sharding_tests() { + let zp = get_v3_test_zarr_path("with_partial_sharding.zarr".to_string()); + let stream_builder = ZarrRecordBatchStreamBuilder::new(zp); + + let stream = stream_builder.build().await.unwrap(); + let records: Vec<_> = stream.try_collect().await.unwrap(); + for rec in records { + compare_values::("float_data_not_sharded", "float_data_sharded", &rec); + } + } + + #[tokio::test] + async fn with_partial_sharding_3d_tests() { + let zp = get_v3_test_zarr_path("with_partial_sharding_3D.zarr".to_string()); + let stream_builder = ZarrRecordBatchStreamBuilder::new(zp); + + let stream = stream_builder.build().await.unwrap(); + let records: Vec<_> = stream.try_collect().await.unwrap(); + for rec in records { + compare_values::("float_data_not_sharded", "float_data_sharded", &rec); + } + } } diff --git a/src/reader/codecs.rs b/src/reader/codecs.rs index 2eaf2cd..b4d2ab9 100644 --- a/src/reader/codecs.rs +++ b/src/reader/codecs.rs @@ -27,6 +27,10 @@ use std::str::FromStr; use std::sync::Arc; use std::vec; +// couple useful constant for empty shards +const NULL_OFFSET: usize = usize::pow(2, 63) + (usize::pow(2, 63) - 1); +const NULL_NBYTES: usize = usize::pow(2, 63) + (usize::pow(2, 63) - 1); + // Type enum and for the various support zarr types #[derive(Debug, PartialEq, Clone)] pub(crate) enum ZarrDataType { @@ -592,6 +596,96 @@ fn broadcast_array( } } +// the logic here gets a little messy, but the goal is to fill in the +// the data for the outer chunks from the data within the inner shards +// in the outer chunk. the below function is called one inner shard at +// a time, and the "flat" data from the inner shard is read in small +// chunks and written to the correct position in the array for the entire +// outer chunk. +fn fill_data_from_shard( + data: &mut [T], + shard_data: &[T], + chunk_real_dims: &[usize], + chunk_dims: &[usize], + inner_real_dims: &[usize], + inner_dims: &[usize], + pos: usize, +) -> ZarrResult<()> { + let l = inner_real_dims.len(); + let err = "mismatch between inner and outer chunks dimensions"; + if l != chunk_real_dims.len() || l != inner_dims.len() { + return Err(throw_invalid_meta(err)); + } + + match l { + // the 1D array case is easy, there's just one offset to compute, + // from the start of the 1D outer chunk. + 1 => { + let stride = inner_dims[0]; + data[pos * stride..(pos + 1) * stride].copy_from_slice(shard_data); + } + // The 2D case is trickier, we need to keep track of the inner shard + // position within the outer chunk, as well the where we're reading + // in the inner shard. + 2 => { + // first, find the position of the shard within the chunk. + let n_shards_1 = chunk_dims[1] / inner_dims[1]; + let i = pos / n_shards_1; + let j = pos - i * n_shards_1; + let shard_pos = [i, j]; + + // then loop over each row in the shard and write to the + // data array for the chunk. + let stride = inner_real_dims[1]; + for row_idx in 0..inner_real_dims[0] { + // read the row from the 1D shard array + let shard_row = &shard_data[row_idx * stride..(row_idx + 1) * stride]; + + // determine where in the chunk data array the shard array should go. + // the first component of the offsets is the "vertrical" position of + // the shard within the chunk times the rows per shard, plus the row + // within the shard, all that times the number of rows per shards. + // the second component is the "horizontal" position of the shard + // within the chunk, times the number of columns per shard. + let chunk_offset = (shard_pos[0] * inner_dims[0] + row_idx) * chunk_real_dims[1] + + shard_pos[1] * inner_dims[1]; + data[chunk_offset..chunk_offset + stride].copy_from_slice(shard_row); + } + } + // similar to the 2D case, but a but more complicated, for 3D arrays. + 3 => { + let n_shards_1 = chunk_dims[1] / inner_dims[1]; + let n_shards_2 = chunk_dims[2] / inner_dims[2]; + let i = pos / (n_shards_1 * n_shards_2); + let j = (pos - i * (n_shards_1 * n_shards_2)) / n_shards_2; + let k = pos - i * (n_shards_1 * n_shards_2) - j * n_shards_2; + let shard_pos = [i, j, k]; + + let stride = inner_real_dims[2]; + for depth_idx in 0..inner_real_dims[0] { + for row_idx in 0..inner_real_dims[1] { + let shard_start = depth_idx * inner_real_dims[1] * stride + row_idx * stride; + let shard_row = &shard_data[shard_start..shard_start + stride]; + let chunk_offset = (shard_pos[0] * inner_dims[0] + depth_idx) + * chunk_real_dims[1] + * chunk_real_dims[2] + + (shard_pos[1] * inner_dims[1] + row_idx) * chunk_real_dims[2] + + shard_pos[2] * inner_dims[2]; + + data[chunk_offset..chunk_offset + stride].copy_from_slice(shard_row); + } + } + } + _ => { + return Err(throw_invalid_meta( + "too many dimensions when processing shards", + )); + } + } + + Ok(()) +} + // a macro that instantiates functions to decode different data types macro_rules! create_decode_function { ($func_name: tt, $type: ty, $byte_size: tt) => { @@ -607,7 +701,7 @@ macro_rules! create_decode_function { (bytes, array_to_bytes_codec, array_to_array_codec) = decode_bytes_to_bytes(&codecs, &bytes[..], &sharding_params)?; - let mut data = Vec::new(); + let mut data; if let Some(sharding_params) = sharding_params.as_ref() { let mut index_size: usize = 2 * 8 * sharding_params.n_chunks.iter().fold(1, |mult, x| mult * x); @@ -622,16 +716,34 @@ macro_rules! create_decode_function { let (offsets, nbytes) = extract_sharding_index(&sharding_params.index_codecs, index_bytes)?; + data = vec![<$type>::default(); bytes.len()]; + let mut total_length = 0; for (pos, (o, n)) in offsets.iter().zip(nbytes.iter()).enumerate() { + // the below condition indicates an empty shard + if o == &NULL_OFFSET && n == &NULL_NBYTES { + continue; + } + let inner_real_dims = + get_inner_chunk_real_dims(&sharding_params, &real_dims, pos); let inner_data = $func_name( bytes[*o..o + n].to_vec(), &sharding_params.chunk_shape, - &get_inner_chunk_real_dims(&sharding_params, &real_dims, pos), // TODO: fix this to real dims + &inner_real_dims, &sharding_params.codecs, None, )?; - data.extend(inner_data); + total_length += inner_data.len(); + fill_data_from_shard( + &mut data, + &inner_data, + &real_dims, + &chunk_dims, + &inner_real_dims, + &sharding_params.chunk_shape, + pos, + )?; } + data = data[0..total_length].to_vec(); } else { if let Some(ZarrCodec::Bytes(e)) = array_to_bytes_codec { data = convert_bytes!(bytes, &e, $type, $byte_size); @@ -983,7 +1095,7 @@ mod zarr_codecs_tests { Arc::new(Field::new("float_data", DataType::Float64, false)) ); let target_arr: Float64Array = vec![ - 36.0, 37.0, 44.0, 45.0, 38.0, 39.0, 46.0, 47.0, 52.0, 53.0, 60.0, 61.0, 54.0, 55.0, + 36.0, 37.0, 38.0, 39.0, 44.0, 45.0, 46.0, 47.0, 52.0, 53.0, 54.0, 55.0, 60.0, 61.0, 62.0, 63.0, ] .into(); @@ -1034,7 +1146,7 @@ mod zarr_codecs_tests { field, Arc::new(Field::new("uint_data", DataType::UInt16, false)) ); - let target_arr: UInt16Array = vec![32, 33, 39, 40, 34, 41, 46, 47, 48].into(); + let target_arr: UInt16Array = vec![32, 33, 34, 39, 40, 41, 46, 47, 48].into(); assert_eq!(*arr, target_arr); } diff --git a/src/reader/mod.rs b/src/reader/mod.rs index 2023656..cca087d 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -456,6 +456,28 @@ mod zarr_reader_tests { assert!(matched); } + fn compare_values(col_name1: &str, col_name2: &str, rec: &RecordBatch) + where + T: ArrowPrimitiveType, + { + let mut vals1 = None; + let mut vals2 = None; + for (idx, col) in enumerate(rec.schema().fields.iter()) { + if col.name().as_str() == col_name1 { + vals1 = Some(rec.column(idx).as_primitive::().values()) + } else if col.name().as_str() == col_name2 { + vals2 = Some(rec.column(idx).as_primitive::().values()) + } + } + + if let (Some(vals1), Some(vals2)) = (vals1, vals2) { + assert_eq!(vals1, vals2); + return; + } + + panic!("columns not found"); + } + fn validate_string_column(col_name: &str, rec: &RecordBatch, targets: &[&str]) { let mut matched = false; for (idx, col) in enumerate(rec.schema().fields.iter()) { @@ -1060,6 +1082,32 @@ mod zarr_reader_tests { ); } + #[test] + fn with_partial_sharding_tests() { + let p = get_test_v3_data_path("with_partial_sharding.zarr".to_string()); + let builder = ZarrRecordBatchReaderBuilder::new(p); + + let reader = builder.build().unwrap(); + let records: Vec = reader.map(|x| x.unwrap()).collect(); + + for rec in records { + compare_values::("float_data_not_sharded", "float_data_sharded", &rec); + } + } + + #[test] + fn with_partial_sharding_3d_tests() { + let p = get_test_v3_data_path("with_partial_sharding_3D.zarr".to_string()); + let builder = ZarrRecordBatchReaderBuilder::new(p); + + let reader = builder.build().unwrap(); + let records: Vec = reader.map(|x| x.unwrap()).collect(); + + for rec in records { + compare_values::("float_data_not_sharded", "float_data_sharded", &rec); + } + } + #[test] fn with_sharding_tests() { let p = get_test_v3_data_path("with_sharding.zarr".to_string()); @@ -1079,7 +1127,7 @@ mod zarr_reader_tests { "float_data", rec, &[ - 32.0, 33.0, 40.0, 41.0, 34.0, 35.0, 42.0, 43.0, 48.0, 49.0, 56.0, 57.0, 50.0, 51.0, + 32.0, 33.0, 34.0, 35.0, 40.0, 41.0, 42.0, 43.0, 48.0, 49.0, 50.0, 51.0, 56.0, 57.0, 58.0, 59.0, ], ); @@ -1087,7 +1135,7 @@ mod zarr_reader_tests { "int_data", rec, &[ - 32, 33, 40, 41, 34, 35, 42, 43, 48, 49, 56, 57, 50, 51, 58, 59, + 32, 33, 34, 35, 40, 41, 42, 43, 48, 49, 50, 51, 56, 57, 58, 59, ], ); } @@ -1107,7 +1155,7 @@ mod zarr_reader_tests { validate_primitive_column::( "uint_data", rec, - &[4, 5, 11, 12, 6, 13, 18, 19, 25, 26, 20, 27], + &[4, 5, 6, 11, 12, 13, 18, 19, 20, 25, 26, 27], ); } @@ -1142,10 +1190,10 @@ mod zarr_reader_tests { "float_data", rec, &[ - 1020.0, 1021.0, 1031.0, 1032.0, 1141.0, 1142.0, 1152.0, 1153.0, 1022.0, 1033.0, - 1143.0, 1154.0, 1042.0, 1043.0, 1053.0, 1054.0, 1163.0, 1164.0, 1174.0, 1175.0, - 1044.0, 1055.0, 1165.0, 1176.0, 1262.0, 1263.0, 1273.0, 1274.0, 1264.0, 1275.0, - 1284.0, 1285.0, 1295.0, 1296.0, 1286.0, 1297.0, + 1020.0, 1021.0, 1022.0, 1031.0, 1032.0, 1033.0, 1042.0, 1043.0, 1044.0, 1053.0, + 1054.0, 1055.0, 1141.0, 1142.0, 1143.0, 1152.0, 1153.0, 1154.0, 1163.0, 1164.0, + 1165.0, 1174.0, 1175.0, 1176.0, 1262.0, 1263.0, 1264.0, 1273.0, 1274.0, 1275.0, + 1284.0, 1285.0, 1286.0, 1295.0, 1296.0, 1297.0, ], ); } diff --git a/test-data b/test-data index d131c99..7809aec 160000 --- a/test-data +++ b/test-data @@ -1 +1 @@ -Subproject commit d131c9998aabba6382fc6435416c2a78d5a46fe0 +Subproject commit 7809aeccd2d460bce819d94ec6cc09a6c48068d0