Skip to content

Commit

Permalink
Added tests for Schema normalization. Partial tests for RecordBatch.
Browse files Browse the repository at this point in the history
  • Loading branch information
nglime committed Nov 25, 2024
1 parent 30d6294 commit 0ed979d
Show file tree
Hide file tree
Showing 2 changed files with 183 additions and 12 deletions.
114 changes: 102 additions & 12 deletions arrow-array/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,14 +432,8 @@ impl RecordBatch {
match f.data_type() {
DataType::Struct(ff) => {
// Need to zip these in reverse to maintain original order
for (cff, fff) in c
.as_struct()
.columns()
.iter()
.rev()
.zip(ff.into_iter().rev())
{
let new_key = format!("{}{separator}{}", f.name(), fff.name());
for (cff, fff) in c.as_struct().columns().iter().zip(ff.into_iter()).rev() {
let new_key = format!("{}{}{}", f.name(), separator, fff.name());
let updated_field = Field::new(
new_key.as_str(),
fff.data_type().clone(),
Expand Down Expand Up @@ -1291,10 +1285,10 @@ mod tests {
Field::new("month", DataType::Int64, true),
]);

let record_batch = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
.expect("valid conversion");

let normalized = record_batch.normalize(".", 0).expect("valid normalization");
let normalized = RecordBatch::try_new(Arc::new(schema), vec![a, month.clone()])
.expect("valid conversion")
.normalize(".", 0)
.expect("valid normalization");

let expected = RecordBatch::try_from_iter_with_nullable(vec![
("a.animals", animals.clone(), true),
Expand All @@ -1307,6 +1301,102 @@ mod tests {
assert_eq!(expected, normalized);
}

#[test]
fn normalize_nested() {
// Initialize schema
let a = Arc::new(Field::new("a", DataType::Utf8, true));
let b = Arc::new(Field::new("b", DataType::Int64, false));
let c = Arc::new(Field::new("c", DataType::Int64, true));

let d = Arc::new(Field::new("d", DataType::Utf8, true));
let e = Arc::new(Field::new("e", DataType::Int64, false));
let f = Arc::new(Field::new("f", DataType::Int64, true));

let one = Arc::new(Field::new(
"1",
DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
false,
));
let two = Arc::new(Field::new(
"2",
DataType::Struct(Fields::from(vec![d.clone(), e.clone(), f.clone()])),
true,
));

let exclamation = Arc::new(Field::new(
"!",
DataType::Struct(Fields::from(vec![one, two])),
false,
));

// Initialize fields
let a_field: ArrayRef = Arc::new(StringArray::from(vec!["a1_field_data", "a1_field_data"]));
let b_field: ArrayRef = Arc::new(Int64Array::from(vec![Some(0), Some(1)]));
let c_field: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2)]));

let d_field: ArrayRef = Arc::new(StringArray::from(vec!["d1_field_data", "d2_field_data"]));
let e_field: ArrayRef = Arc::new(Int64Array::from(vec![Some(3), Some(4)]));
let f_field: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(5)]));

let one_field = Arc::new(StructArray::from(vec![
(a.clone(), Arc::new(a_field.clone()) as ArrayRef),
(b.clone(), Arc::new(b_field.clone()) as ArrayRef),
(c.clone(), Arc::new(c_field.clone()) as ArrayRef),
]));
let two_field = Arc::new(StructArray::from(vec![
(a.clone(), Arc::new(a_field.clone()) as ArrayRef),
(b.clone(), Arc::new(b_field.clone()) as ArrayRef),
(c.clone(), Arc::new(c_field.clone()) as ArrayRef),
]));

/*let exclamation_field = Arc::new(StructArray::from(vec![
(one.clone(), Arc::new(one_field.clone()) as ArrayRef),
(two.clone(), Arc::new(two_field.clone()) as ArrayRef),
]));*/

let schema = Schema::new(vec![exclamation.clone()]);
/*let normalized = RecordBatch::try_new(Arc::new(schema), vec![exclamation_field])
.expect("valid conversion");*/
//.normalize(".", 0)
//.expect("valid normalization");

/*let expected = RecordBatch::try_from_iter_with_nullable(vec![
("a.animals", animals.clone(), true),
("a.n_legs", n_legs.clone(), true),
("a.year", year.clone(), true),
("month", month.clone(), true),
])
.expect("valid conversion");*/

//assert_eq!(expected, normalized);
}

#[test]
fn normalize_empty() {
let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
let year_field = Arc::new(Field::new("year", DataType::Int64, true));

let schema = Schema::new(vec![
Field::new(
"a",
DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
false,
),
Field::new("month", DataType::Int64, true),
]);

let normalized = RecordBatch::new_empty(Arc::new(schema.clone()))
.normalize(".", 0)
.expect("valid normalization");

let expected = RecordBatch::new_empty(Arc::new(
schema.normalize(".", 0).expect("valid normalization"),
));

assert_eq!(expected, normalized);
}

#[test]
fn project() {
let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
Expand Down
81 changes: 81 additions & 0 deletions arrow-schema/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,87 @@ mod tests {
schema.index_of("nickname").unwrap();
}

#[test]
fn normalize() {
let schema = Schema::new(vec![
Field::new(
"a",
DataType::Struct(Fields::from(vec![
Arc::new(Field::new("animals", DataType::Utf8, true)),
Arc::new(Field::new("n_legs", DataType::Int64, true)),
Arc::new(Field::new("year", DataType::Int64, true)),
])),
false,
),
Field::new("month", DataType::Int64, true),
])
.normalize(".", 0)
.expect("valid normalization");

let expected = Schema::new(vec![
Field::new("a.animals", DataType::Utf8, true),
Field::new("a.n_legs", DataType::Int64, true),
Field::new("a.year", DataType::Int64, true),
Field::new("month", DataType::Int64, true),
]);

assert_eq!(schema, expected);
}

#[test]
fn normalize_nested() {
let a = Arc::new(Field::new("a", DataType::Utf8, true));
let b = Arc::new(Field::new("b", DataType::Int64, false));
let c = Arc::new(Field::new("c", DataType::Int64, true));

let d = Arc::new(Field::new("d", DataType::Utf8, true));
let e = Arc::new(Field::new("e", DataType::Int64, false));
let f = Arc::new(Field::new("f", DataType::Int64, true));

let one = Arc::new(Field::new(
"1",
DataType::Struct(Fields::from(vec![a.clone(), b.clone(), c.clone()])),
false,
));
let two = Arc::new(Field::new(
"2",
DataType::Struct(Fields::from(vec![d.clone(), e.clone(), f.clone()])),
true,
));

let exclamation = Arc::new(Field::new(
"!",
DataType::Struct(Fields::from(vec![one, two])),
false,
));

let normalize_all = Schema::new(vec![exclamation.clone()])
.normalize(".", 0)
.expect("valid normalization");

let expected = Schema::new(vec![
Field::new("!.1.a", DataType::Utf8, true),
Field::new("!.1.b", DataType::Int64, false),
Field::new("!.1.c", DataType::Int64, true),
Field::new("!.2.d", DataType::Utf8, true),
Field::new("!.2.e", DataType::Int64, false),
Field::new("!.2.f", DataType::Int64, true),
]);

assert_eq!(normalize_all, expected);

let normalize_depth_one = Schema::new(vec![exclamation])
.normalize(".", 1)
.expect("valid normalization");

let expected = Schema::new(vec![
Field::new("!.1", DataType::Struct(Fields::from(vec![a, b, c])), false),
Field::new("!.2", DataType::Struct(Fields::from(vec![d, e, f])), true),
]);

assert_eq!(normalize_depth_one, expected);
}

#[test]
#[should_panic(
expected = "Unable to get field named \\\"nickname\\\". Valid fields: [\\\"first_name\\\", \\\"last_name\\\", \\\"address\\\", \\\"interests\\\"]"
Expand Down

0 comments on commit 0ed979d

Please sign in to comment.