Skip to content

Commit

Permalink
fixed a bug with shard alignment (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
maximedion2 authored Nov 28, 2024
1 parent 17d4af2 commit 743c452
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 19 deletions.
58 changes: 52 additions & 6 deletions src/async_reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,28 @@ mod zarr_async_reader_tests {
assert!(matched);
}

// Asserts that two primitive columns in a record batch hold identical
// values. Panics with "columns not found" if either named column is
// missing from the batch's schema.
fn compare_values<T>(col_name1: &str, col_name2: &str, rec: &RecordBatch)
where
    T: ArrowPrimitiveType,
{
    // resolve a column index by name within the record's schema.
    let index_of = |name: &str| {
        rec.schema()
            .fields
            .iter()
            .position(|field| field.name().as_str() == name)
    };

    match (index_of(col_name1), index_of(col_name2)) {
        (Some(idx1), Some(idx2)) => {
            assert_eq!(
                rec.column(idx1).as_primitive::<T>().values(),
                rec.column(idx2).as_primitive::<T>().values(),
            );
        }
        _ => panic!("columns not found"),
    }
}

// create a test filter
fn create_filter() -> ZarrChunkFilter {
let mut filters: Vec<Box<dyn ZarrArrowPredicate>> = Vec::new();
Expand Down Expand Up @@ -1047,15 +1069,15 @@ mod zarr_async_reader_tests {
"float_data",
rec,
&[
32.0, 33.0, 40.0, 41.0, 34.0, 35.0, 42.0, 43.0, 48.0, 49.0, 56.0, 57.0, 50.0, 51.0,
32.0, 33.0, 34.0, 35.0, 40.0, 41.0, 42.0, 43.0, 48.0, 49.0, 50.0, 51.0, 56.0, 57.0,
58.0, 59.0,
],
);
validate_primitive_column::<Int64Type, i64>(
"int_data",
rec,
&[
32, 33, 40, 41, 34, 35, 42, 43, 48, 49, 56, 57, 50, 51, 58, 59,
32, 33, 34, 35, 40, 41, 42, 43, 48, 49, 50, 51, 56, 57, 58, 59,
],
);
}
Expand All @@ -1076,10 +1098,10 @@ mod zarr_async_reader_tests {
"float_data",
rec,
&[
1020.0, 1021.0, 1031.0, 1032.0, 1141.0, 1142.0, 1152.0, 1153.0, 1022.0, 1033.0,
1143.0, 1154.0, 1042.0, 1043.0, 1053.0, 1054.0, 1163.0, 1164.0, 1174.0, 1175.0,
1044.0, 1055.0, 1165.0, 1176.0, 1262.0, 1263.0, 1273.0, 1274.0, 1264.0, 1275.0,
1284.0, 1285.0, 1295.0, 1296.0, 1286.0, 1297.0,
1020.0, 1021.0, 1022.0, 1031.0, 1032.0, 1033.0, 1042.0, 1043.0, 1044.0, 1053.0,
1054.0, 1055.0, 1141.0, 1142.0, 1143.0, 1152.0, 1153.0, 1154.0, 1163.0, 1164.0,
1165.0, 1174.0, 1175.0, 1176.0, 1262.0, 1263.0, 1264.0, 1273.0, 1274.0, 1275.0,
1284.0, 1285.0, 1286.0, 1295.0, 1296.0, 1297.0,
],
);
}
Expand All @@ -1102,4 +1124,28 @@ mod zarr_async_reader_tests {
&[4, 5, 6, 7, 20, 21, 22, 23, 36, 37, 38, 39, 52, 53, 54, 55],
);
}

#[tokio::test]
async fn with_partial_sharding_tests() {
    // read a store holding the same data both sharded and not sharded,
    // and check that the two columns decode to identical values.
    let zarr_path = get_v3_test_zarr_path("with_partial_sharding.zarr".to_string());
    let records: Vec<_> = ZarrRecordBatchStreamBuilder::new(zarr_path)
        .build()
        .await
        .unwrap()
        .try_collect()
        .await
        .unwrap();

    for rec in &records {
        compare_values::<Float64Type>("float_data_not_sharded", "float_data_sharded", rec);
    }
}

#[tokio::test]
async fn with_partial_sharding_3d_tests() {
    // same as the 2D partial sharding test, but against a 3D store.
    let zarr_path = get_v3_test_zarr_path("with_partial_sharding_3D.zarr".to_string());
    let records: Vec<_> = ZarrRecordBatchStreamBuilder::new(zarr_path)
        .build()
        .await
        .unwrap()
        .try_collect()
        .await
        .unwrap();

    for rec in &records {
        compare_values::<Float64Type>("float_data_not_sharded", "float_data_sharded", rec);
    }
}
}
122 changes: 117 additions & 5 deletions src/reader/codecs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ use std::str::FromStr;
use std::sync::Arc;
use std::vec;

// Sentinel values marking an empty inner chunk in a shard's index: per the
// zarr v3 sharding codec spec, an offset and byte length of 2^64 - 1 mean
// the chunk is not present. Written here as `usize::MAX`, which is the same
// value on 64-bit targets and, unlike `usize::pow(2, 63) + ...`, does not
// overflow at compile time on 32-bit ones.
const NULL_OFFSET: usize = usize::MAX;
const NULL_NBYTES: usize = usize::MAX;

// Type enum for the various supported zarr types
#[derive(Debug, PartialEq, Clone)]
pub(crate) enum ZarrDataType {
Expand Down Expand Up @@ -592,6 +596,96 @@ fn broadcast_array<T: Clone>(
}
}

// The logic here gets a little messy, but the goal is to fill in the
// data for an outer chunk from the data within the inner shards of that
// chunk. The function is called one inner shard at a time; the "flat"
// decoded data from the inner shard is read in contiguous runs and
// written to the correct positions in the array for the entire outer
// chunk.
//
//   data            - output buffer for the whole outer chunk, sized by
//                     its real dimensions.
//   shard_data      - flat decoded data for one inner shard, sized by
//                     that shard's real dimensions.
//   chunk_real_dims - actual dims of the outer chunk (smaller than the
//                     nominal dims at the edges of the array).
//   chunk_dims      - nominal dims of the outer chunk.
//   inner_real_dims - actual dims of this inner shard.
//   inner_dims      - nominal dims of an inner shard.
//   pos             - flat (row-major) index of this shard within the
//                     outer chunk.
fn fill_data_from_shard<T: Copy>(
    data: &mut [T],
    shard_data: &[T],
    chunk_real_dims: &[usize],
    chunk_dims: &[usize],
    inner_real_dims: &[usize],
    inner_dims: &[usize],
    pos: usize,
) -> ZarrResult<()> {
    let l = inner_real_dims.len();
    let err = "mismatch between inner and outer chunks dimensions";
    if l != chunk_real_dims.len() || l != inner_dims.len() {
        return Err(throw_invalid_meta(err));
    }

    match l {
        // the 1D array case is easy, there's just one offset to compute,
        // from the start of the 1D outer chunk. all shards before `pos`
        // are full (nominal) size, so the write offset uses inner_dims,
        // but the copy length must match the shard's real size, which can
        // be smaller for a partial shard at the end of the array (using
        // the nominal size here would make copy_from_slice panic on a
        // length mismatch).
        1 => {
            let start = pos * inner_dims[0];
            data[start..start + inner_real_dims[0]].copy_from_slice(shard_data);
        }
        // The 2D case is trickier, we need to keep track of the inner shard
        // position within the outer chunk, as well as where we're reading
        // in the inner shard.
        2 => {
            // first, find the position of the shard within the chunk.
            let n_shards_1 = chunk_dims[1] / inner_dims[1];
            let i = pos / n_shards_1;
            let j = pos - i * n_shards_1;
            let shard_pos = [i, j];

            // then loop over each row in the shard and write to the
            // data array for the chunk.
            let stride = inner_real_dims[1];
            for row_idx in 0..inner_real_dims[0] {
                // read the row from the 1D shard array
                let shard_row = &shard_data[row_idx * stride..(row_idx + 1) * stride];

                // determine where in the chunk data array the shard row
                // should go. the first component of the offset is the
                // "vertical" position of the shard within the chunk times
                // the nominal rows per shard, plus the row within the shard,
                // all of that times the chunk's real row length. the second
                // component is the "horizontal" position of the shard within
                // the chunk, times the nominal columns per shard.
                let chunk_offset = (shard_pos[0] * inner_dims[0] + row_idx) * chunk_real_dims[1]
                    + shard_pos[1] * inner_dims[1];
                data[chunk_offset..chunk_offset + stride].copy_from_slice(shard_row);
            }
        }
        // similar to the 2D case, but a bit more complicated, for 3D arrays:
        // the shard position becomes (depth, row, column) and each contiguous
        // run is one row at a fixed depth.
        3 => {
            let n_shards_1 = chunk_dims[1] / inner_dims[1];
            let n_shards_2 = chunk_dims[2] / inner_dims[2];
            let i = pos / (n_shards_1 * n_shards_2);
            let j = (pos - i * (n_shards_1 * n_shards_2)) / n_shards_2;
            let k = pos - i * (n_shards_1 * n_shards_2) - j * n_shards_2;
            let shard_pos = [i, j, k];

            let stride = inner_real_dims[2];
            for depth_idx in 0..inner_real_dims[0] {
                for row_idx in 0..inner_real_dims[1] {
                    let shard_start = depth_idx * inner_real_dims[1] * stride + row_idx * stride;
                    let shard_row = &shard_data[shard_start..shard_start + stride];
                    let chunk_offset = (shard_pos[0] * inner_dims[0] + depth_idx)
                        * chunk_real_dims[1]
                        * chunk_real_dims[2]
                        + (shard_pos[1] * inner_dims[1] + row_idx) * chunk_real_dims[2]
                        + shard_pos[2] * inner_dims[2];

                    data[chunk_offset..chunk_offset + stride].copy_from_slice(shard_row);
                }
            }
        }
        // only 1D, 2D and 3D sharded arrays are supported.
        _ => {
            return Err(throw_invalid_meta(
                "too many dimensions when processing shards",
            ));
        }
    }

    Ok(())
}

// a macro that instantiates functions to decode different data types
macro_rules! create_decode_function {
($func_name: tt, $type: ty, $byte_size: tt) => {
Expand All @@ -607,7 +701,7 @@ macro_rules! create_decode_function {
(bytes, array_to_bytes_codec, array_to_array_codec) =
decode_bytes_to_bytes(&codecs, &bytes[..], &sharding_params)?;

let mut data = Vec::new();
let mut data;
if let Some(sharding_params) = sharding_params.as_ref() {
let mut index_size: usize =
2 * 8 * sharding_params.n_chunks.iter().fold(1, |mult, x| mult * x);
Expand All @@ -622,16 +716,34 @@ macro_rules! create_decode_function {
let (offsets, nbytes) =
extract_sharding_index(&sharding_params.index_codecs, index_bytes)?;

data = vec![<$type>::default(); bytes.len()];
let mut total_length = 0;
for (pos, (o, n)) in offsets.iter().zip(nbytes.iter()).enumerate() {
// the below condition indicates an empty shard
if o == &NULL_OFFSET && n == &NULL_NBYTES {
continue;
}
let inner_real_dims =
get_inner_chunk_real_dims(&sharding_params, &real_dims, pos);
let inner_data = $func_name(
bytes[*o..o + n].to_vec(),
&sharding_params.chunk_shape,
&get_inner_chunk_real_dims(&sharding_params, &real_dims, pos), // TODO: fix this to real dims
&inner_real_dims,
&sharding_params.codecs,
None,
)?;
data.extend(inner_data);
total_length += inner_data.len();
fill_data_from_shard(
&mut data,
&inner_data,
&real_dims,
&chunk_dims,
&inner_real_dims,
&sharding_params.chunk_shape,
pos,
)?;
}
data = data[0..total_length].to_vec();
} else {
if let Some(ZarrCodec::Bytes(e)) = array_to_bytes_codec {
data = convert_bytes!(bytes, &e, $type, $byte_size);
Expand Down Expand Up @@ -983,7 +1095,7 @@ mod zarr_codecs_tests {
Arc::new(Field::new("float_data", DataType::Float64, false))
);
let target_arr: Float64Array = vec![
36.0, 37.0, 44.0, 45.0, 38.0, 39.0, 46.0, 47.0, 52.0, 53.0, 60.0, 61.0, 54.0, 55.0,
36.0, 37.0, 38.0, 39.0, 44.0, 45.0, 46.0, 47.0, 52.0, 53.0, 54.0, 55.0, 60.0, 61.0,
62.0, 63.0,
]
.into();
Expand Down Expand Up @@ -1034,7 +1146,7 @@ mod zarr_codecs_tests {
field,
Arc::new(Field::new("uint_data", DataType::UInt16, false))
);
let target_arr: UInt16Array = vec![32, 33, 39, 40, 34, 41, 46, 47, 48].into();
let target_arr: UInt16Array = vec![32, 33, 34, 39, 40, 41, 46, 47, 48].into();
assert_eq!(*arr, target_arr);
}

Expand Down
62 changes: 55 additions & 7 deletions src/reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,28 @@ mod zarr_reader_tests {
assert!(matched);
}

// Asserts that two primitive columns in a record batch hold identical
// values. Panics with "columns not found" if either named column is
// missing from the batch's schema.
fn compare_values<T>(col_name1: &str, col_name2: &str, rec: &RecordBatch)
where
    T: ArrowPrimitiveType,
{
    // resolve a column index by name within the record's schema.
    let index_of = |name: &str| {
        rec.schema()
            .fields
            .iter()
            .position(|field| field.name().as_str() == name)
    };

    match (index_of(col_name1), index_of(col_name2)) {
        (Some(idx1), Some(idx2)) => {
            assert_eq!(
                rec.column(idx1).as_primitive::<T>().values(),
                rec.column(idx2).as_primitive::<T>().values(),
            );
        }
        _ => panic!("columns not found"),
    }
}

fn validate_string_column(col_name: &str, rec: &RecordBatch, targets: &[&str]) {
let mut matched = false;
for (idx, col) in enumerate(rec.schema().fields.iter()) {
Expand Down Expand Up @@ -1060,6 +1082,32 @@ mod zarr_reader_tests {
);
}

#[test]
fn with_partial_sharding_tests() {
    // read a store holding the same data both sharded and not sharded,
    // and check that the two columns decode to identical values.
    let path = get_test_v3_data_path("with_partial_sharding.zarr".to_string());
    let reader = ZarrRecordBatchReaderBuilder::new(path).build().unwrap();
    let records: Vec<RecordBatch> = reader.map(|x| x.unwrap()).collect();

    for rec in &records {
        compare_values::<Float64Type>("float_data_not_sharded", "float_data_sharded", rec);
    }
}

#[test]
fn with_partial_sharding_3d_tests() {
    // same as the 2D partial sharding test, but against a 3D store.
    let path = get_test_v3_data_path("with_partial_sharding_3D.zarr".to_string());
    let reader = ZarrRecordBatchReaderBuilder::new(path).build().unwrap();
    let records: Vec<RecordBatch> = reader.map(|x| x.unwrap()).collect();

    for rec in &records {
        compare_values::<Float64Type>("float_data_not_sharded", "float_data_sharded", rec);
    }
}

#[test]
fn with_sharding_tests() {
let p = get_test_v3_data_path("with_sharding.zarr".to_string());
Expand All @@ -1079,15 +1127,15 @@ mod zarr_reader_tests {
"float_data",
rec,
&[
32.0, 33.0, 40.0, 41.0, 34.0, 35.0, 42.0, 43.0, 48.0, 49.0, 56.0, 57.0, 50.0, 51.0,
32.0, 33.0, 34.0, 35.0, 40.0, 41.0, 42.0, 43.0, 48.0, 49.0, 50.0, 51.0, 56.0, 57.0,
58.0, 59.0,
],
);
validate_primitive_column::<Int64Type, i64>(
"int_data",
rec,
&[
32, 33, 40, 41, 34, 35, 42, 43, 48, 49, 56, 57, 50, 51, 58, 59,
32, 33, 34, 35, 40, 41, 42, 43, 48, 49, 50, 51, 56, 57, 58, 59,
],
);
}
Expand All @@ -1107,7 +1155,7 @@ mod zarr_reader_tests {
validate_primitive_column::<UInt16Type, u16>(
"uint_data",
rec,
&[4, 5, 11, 12, 6, 13, 18, 19, 25, 26, 20, 27],
&[4, 5, 6, 11, 12, 13, 18, 19, 20, 25, 26, 27],
);
}

Expand Down Expand Up @@ -1142,10 +1190,10 @@ mod zarr_reader_tests {
"float_data",
rec,
&[
1020.0, 1021.0, 1031.0, 1032.0, 1141.0, 1142.0, 1152.0, 1153.0, 1022.0, 1033.0,
1143.0, 1154.0, 1042.0, 1043.0, 1053.0, 1054.0, 1163.0, 1164.0, 1174.0, 1175.0,
1044.0, 1055.0, 1165.0, 1176.0, 1262.0, 1263.0, 1273.0, 1274.0, 1264.0, 1275.0,
1284.0, 1285.0, 1295.0, 1296.0, 1286.0, 1297.0,
1020.0, 1021.0, 1022.0, 1031.0, 1032.0, 1033.0, 1042.0, 1043.0, 1044.0, 1053.0,
1054.0, 1055.0, 1141.0, 1142.0, 1143.0, 1152.0, 1153.0, 1154.0, 1163.0, 1164.0,
1165.0, 1174.0, 1175.0, 1176.0, 1262.0, 1263.0, 1264.0, 1273.0, 1274.0, 1275.0,
1284.0, 1285.0, 1286.0, 1295.0, 1296.0, 1297.0,
],
);
}
Expand Down
2 changes: 1 addition & 1 deletion test-data
Submodule test-data updated 40 files
+ data/zarr/v3_data/with_partial_sharding.zarr/float_data_not_sharded/0.0
+ data/zarr/v3_data/with_partial_sharding.zarr/float_data_not_sharded/0.1
+ data/zarr/v3_data/with_partial_sharding.zarr/float_data_not_sharded/0.2
+ data/zarr/v3_data/with_partial_sharding.zarr/float_data_not_sharded/1.0
+ data/zarr/v3_data/with_partial_sharding.zarr/float_data_not_sharded/1.1
+ data/zarr/v3_data/with_partial_sharding.zarr/float_data_not_sharded/1.2
+1 −0 data/zarr/v3_data/with_partial_sharding.zarr/float_data_not_sharded/zarr.json
+ data/zarr/v3_data/with_partial_sharding.zarr/float_data_sharded/0.0
+ data/zarr/v3_data/with_partial_sharding.zarr/float_data_sharded/0.1
+ data/zarr/v3_data/with_partial_sharding.zarr/float_data_sharded/0.2
+ data/zarr/v3_data/with_partial_sharding.zarr/float_data_sharded/1.0
+ data/zarr/v3_data/with_partial_sharding.zarr/float_data_sharded/1.1
+ data/zarr/v3_data/with_partial_sharding.zarr/float_data_sharded/1.2
+1 −0 data/zarr/v3_data/with_partial_sharding.zarr/float_data_sharded/zarr.json
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_not_sharded/0.0.0
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_not_sharded/0.0.1
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_not_sharded/0.1.0
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_not_sharded/0.1.1
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_not_sharded/0.2.0
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_not_sharded/0.2.1
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_not_sharded/1.0.0
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_not_sharded/1.0.1
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_not_sharded/1.1.0
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_not_sharded/1.1.1
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_not_sharded/1.2.0
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_not_sharded/1.2.1
+1 −0 data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_not_sharded/zarr.json
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_sharded/0.0.0
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_sharded/0.0.1
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_sharded/0.1.0
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_sharded/0.1.1
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_sharded/0.2.0
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_sharded/0.2.1
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_sharded/1.0.0
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_sharded/1.0.1
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_sharded/1.1.0
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_sharded/1.1.1
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_sharded/1.2.0
+ data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_sharded/1.2.1
+1 −0 data/zarr/v3_data/with_partial_sharding_3D.zarr/float_data_sharded/zarr.json

0 comments on commit 743c452

Please sign in to comment.