From 0485685a590682c20767db647e5377664623f542 Mon Sep 17 00:00:00 2001
From: coastalwhite
Date: Mon, 16 Sep 2024 14:03:10 +0200
Subject: [PATCH] fix: Broadcast zip_with for structs

---
 .../polars-core/src/chunked_array/ops/zip.rs | 49 ++++++++++++++++++-
 .../tests/unit/functions/test_when_then.py | 12 +++++
 2 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/crates/polars-core/src/chunked_array/ops/zip.rs b/crates/polars-core/src/chunked_array/ops/zip.rs
index 7fe09ba2c7d1..c9c39ff6e8c3 100644
--- a/crates/polars-core/src/chunked_array/ops/zip.rs
+++ b/crates/polars-core/src/chunked_array/ops/zip.rs
@@ -214,7 +214,54 @@ impl ChunkZip<StructType> for StructChunked {
         mask: &BooleanChunked,
         other: &ChunkedArray<StructType>,
     ) -> PolarsResult<ChunkedArray<StructType>> {
-        let (l, r, mask) = align_chunks_ternary(self, other, mask);
+        let length = self.length.max(mask.length).max(other.length);
+
+        debug_assert!(self.length == 1 || self.length == length);
+        debug_assert!(mask.length == 1 || mask.length == length);
+        debug_assert!(other.length == 1 || other.length == length);
+
+        let length = length as usize;
+
+        let mut if_true = self;
+        let mut if_false = other;
+        let mask = mask;
+
+        // align_chunks_ternary can only align chunks if:
+        // - Each chunkedarray only has 1 chunk
+        // - Each chunkedarray has an equal length (i.e. is broadcasted)
+        //
+        // Therefore, we broadcast only those that are necessary to be broadcasted.
+        let needs_broadcast = if_true.chunks() > 1 || if_false.chunks() > 1 || mask.chunks() > 1;
+        if needs_broadcast && length > 1 {
+            // Special case. In this case, we know what to do.
+            if mask.length == 1 {
+                // pl.when(None) <=> pl.when(False)
+                return Ok(if mask.get(0).unwrap_or(false) {
+                    if self.length == 1 {
+                        self.new_from_index(0, length)
+                    } else {
+                        self.clone()
+                    }
+                } else {
+                    if other.length == 1 {
+                        other.new_from_index(0, length)
+                    } else {
+                        other.clone()
+                    }
+                });
+            }
+
+            if self.length == 1 {
+                let broadcasted = self.new_from_index(0, length);
+                if_true = &broadcasted;
+            }
+            if other.length == 1 {
+                let broadcasted = other.new_from_index(0, length);
+                if_false = &broadcasted;
+            }
+        }
+
+        let (l, r, mask) = align_chunks_ternary(if_true, if_false, mask);
 
         // Prepare the boolean arrays such that Null maps to false.
         // This prevents every field doing that.
diff --git a/py-polars/tests/unit/functions/test_when_then.py b/py-polars/tests/unit/functions/test_when_then.py
index 6f70b9c1b9d1..36b36a56aef8 100644
--- a/py-polars/tests/unit/functions/test_when_then.py
+++ b/py-polars/tests/unit/functions/test_when_then.py
@@ -630,3 +630,15 @@ def test_chained_when_no_subclass_17142() -> None:
 
     assert not isinstance(when, pl.Expr)
     assert "<polars.expr.whenthen.ChainedWhen object at" in str(when)
+
+
+def test_when_then_chunked_structs_18673() -> None:
+    df = pl.DataFrame([
+        pl.Series('x', [{ 'a': 1 }]),
+        pl.Series('b', [None], dtype=pl.Boolean),
+    ])
+
+    df = df.vstack(df)
+
+    # This used to panic
+    df.select(pl.when(pl.col.b).then(pl.first("x")).otherwise(pl.first("x")))