diff --git a/crates/polars-core/src/datatypes/dtype.rs b/crates/polars-core/src/datatypes/dtype.rs index 9cdc5620ed07..eeff4e2a2e1f 100644 --- a/crates/polars-core/src/datatypes/dtype.rs +++ b/crates/polars-core/src/datatypes/dtype.rs @@ -228,6 +228,21 @@ impl DataType { prev } + #[cfg(feature = "dtype-array")] + /// Get the inner data type of a multidimensional array. + pub fn array_leaf_dtype(&self) -> Option<&DataType> { + let mut prev = self; + match prev { + DataType::Array(_, _) => { + while let DataType::Array(inner, _) = &prev { + prev = &inner; + } + Some(prev) + }, + _ => None, + } + } + /// Cast the leaf types of Lists/Arrays and keep the nesting. pub fn cast_leaf(&self, to: DataType) -> DataType { use DataType::*; @@ -717,7 +732,7 @@ impl Display for DataType { DataType::Time => "time", #[cfg(feature = "dtype-array")] DataType::Array(_, _) => { - let tp = self.leaf_dtype(); + let tp = self.array_leaf_dtype().unwrap(); let dims = self.get_shape().unwrap(); let shape = if dims.len() == 1 { diff --git a/py-polars/tests/unit/datatypes/test_array.py b/py-polars/tests/unit/datatypes/test_array.py index 38e5e5c154fe..03f92bd68d11 100644 --- a/py-polars/tests/unit/datatypes/test_array.py +++ b/py-polars/tests/unit/datatypes/test_array.py @@ -291,6 +291,10 @@ def test_recursive_array_dtype() -> None: s = pl.Series(np.arange(6).reshape((2, 3)), dtype=dtype) assert s.dtype == dtype assert s.len() == 2 + dtype = pl.Array(pl.List(pl.Array(pl.Int8, (2, 2))), 2) + s = pl.Series(dtype=dtype) + assert s.dtype == dtype + assert str(s) == "shape: (0,)\nSeries: '' [array[list[array[i8, (2, 2)]], 2]]\n[\n]" def test_ndarray_construction() -> None: