Skip to content

Commit

Permalink
fix(rust): Correctly display multilevel nested Arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
eitsupi committed Sep 11, 2024
1 parent d8acacf commit f984d7a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
17 changes: 16 additions & 1 deletion crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,21 @@ impl DataType {
prev
}

#[cfg(feature = "dtype-array")]
/// Get the inner data type of a multidimensional array.
pub fn array_leaf_dtype(&self) -> Option<&DataType> {
let mut prev = self;
match prev {
DataType::Array(_, _) => {
while let DataType::Array(inner, _) = &prev {
prev = &inner;
}
Some(prev)
},
_ => None,
}
}

/// Cast the leaf types of Lists/Arrays and keep the nesting.
pub fn cast_leaf(&self, to: DataType) -> DataType {
use DataType::*;
Expand Down Expand Up @@ -717,7 +732,7 @@ impl Display for DataType {
DataType::Time => "time",
#[cfg(feature = "dtype-array")]
DataType::Array(_, _) => {
let tp = self.leaf_dtype();
let tp = self.array_leaf_dtype().unwrap();

let dims = self.get_shape().unwrap();
let shape = if dims.len() == 1 {
Expand Down
4 changes: 4 additions & 0 deletions py-polars/tests/unit/datatypes/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,10 @@ def test_recursive_array_dtype() -> None:
s = pl.Series(np.arange(6).reshape((2, 3)), dtype=dtype)
assert s.dtype == dtype
assert s.len() == 2
dtype = pl.Array(pl.List(pl.Array(pl.Int8, (2, 2))), 2)
s = pl.Series(dtype=dtype)
assert s.dtype == dtype
assert str(s) == "shape: (0,)\nSeries: '' [array[list[array[i8, (2, 2)]], 2]]\n[\n]"


def test_ndarray_construction() -> None:
Expand Down

0 comments on commit f984d7a

Please sign in to comment.