From 27a4fe2a5c387a7e6e5661e19e265efc259f9396 Mon Sep 17 00:00:00 2001
From: Ritchie Vink
Date: Fri, 20 Oct 2023 08:15:42 +0200
Subject: [PATCH] fix: recursively check allowed streaming dtypes (#11879)

---
 .../src/physical_plan/streaming/convert_alp.rs    |  4 ++++
 .../unit/streaming/test_streaming_categoricals.py | 14 ++++++++++++++
 2 files changed, 18 insertions(+)
 create mode 100644 py-polars/tests/unit/streaming/test_streaming_categoricals.py

diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs
index e23357f58d40..bc2af56f18d7 100644
--- a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs
+++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs
@@ -357,6 +357,10 @@ pub(crate) fn insert_streaming_nodes(
                         DataType::Object(_) => false,
                         #[cfg(feature = "dtype-categorical")]
                         DataType::Categorical(_) => string_cache,
+                        DataType::List(inner) => allowed_dtype(inner, string_cache),
+                        DataType::Struct(fields) => fields
+                            .iter()
+                            .all(|fld| allowed_dtype(fld.data_type(), string_cache)),
                         _ => true,
                     }
                 }
diff --git a/py-polars/tests/unit/streaming/test_streaming_categoricals.py b/py-polars/tests/unit/streaming/test_streaming_categoricals.py
new file mode 100644
index 000000000000..776e0c0ce377
--- /dev/null
+++ b/py-polars/tests/unit/streaming/test_streaming_categoricals.py
@@ -0,0 +1,14 @@
+import polars as pl
+
+
+def test_streaming_nested_categorical() -> None:
+    assert (
+        pl.LazyFrame({"numbers": [1, 1, 2], "cat": [["str"], ["foo"], ["bar"]]})
+        .with_columns(pl.col("cat").cast(pl.List(pl.Categorical)))
+        .group_by("numbers")
+        .agg(pl.col("cat").first())
+        .sort("numbers")
+    ).collect(streaming=True).to_dict(False) == {
+        "numbers": [1, 2],
+        "cat": [["str"], ["bar"]],
+    }