Skip to content

Commit

Permalink
fix: Broadcast zip_with for structs
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite committed Sep 16, 2024
1 parent 962b576 commit 0485685
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 1 deletion.
49 changes: 48 additions & 1 deletion crates/polars-core/src/chunked_array/ops/zip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,54 @@ impl ChunkZip<StructType> for StructChunked {
mask: &BooleanChunked,
other: &ChunkedArray<StructType>,
) -> PolarsResult<ChunkedArray<StructType>> {
let (l, r, mask) = align_chunks_ternary(self, other, mask);
let length = self.length.max(mask.length).max(other.length);

debug_assert!(self.length == 1 || self.length == length);
debug_assert!(mask.length == 1 || mask.length == length);
debug_assert!(other.length == 1 || other.length == length);

let length = length as usize;

let mut if_true = self;
let mut if_false = other;
let mask = mask;

// align_chunks_ternary can only align chunks if:
// - Each chunkedarray only has 1 chunk
// - Each chunkedarray has an equal length (i.e. is broadcasted)
//
// Therefore, we broadcast only those that are necessary to be broadcasted.
let needs_broadcast = if_true.chunks() > 1 || if_false.chunks() > 1 || mask.chunks() > 1;
if needs_broadcast && length > 1 {
// Special case. In this case, we know what to do.
if mask.length == 1 {
// pl.when(None) <=> pl.when(False)
return Ok(if mask.get(0).unwrap_or(false) {
if self.length == 1 {
self.new_from_index(0, length)
} else {
self.clone()
}
} else {
if other.length == 1 {
other.new_from_index(0, length)
} else {
other.clone()
}
});
}

if self.length == 1 {
let broadcasted = self.new_from_index(0, length);
if_true = &broadcasted;
}
if other.length == 1 {
let broadcasted = other.new_from_index(0, length);
if_false = &broadcasted;
}
}

let (l, r, mask) = align_chunks_ternary(if_true, if_false, mask);

// Prepare the boolean arrays such that Null maps to false.
// This prevents every field doing that.
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/functions/test_when_then.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,3 +630,15 @@ def test_chained_when_no_subclass_17142() -> None:

assert not isinstance(when, pl.Expr)
assert "<polars.expr.whenthen.ChainedWhen object at" in str(when)


def test_when_then_chunked_structs_18673() -> None:
    """Regression test (issue number taken from the test name: 18673).

    Evaluates a when/then/otherwise expression whose branches are struct
    values and whose predicate is an all-null boolean column. The test only
    checks that evaluation completes; the result is not inspected.
    """
    struct_col = pl.Series('x', [{ 'a': 1 }])
    null_mask = pl.Series('b', [None], dtype=pl.Boolean)
    single_row = pl.DataFrame([struct_col, null_mask])

    # Stack the frame onto itself; judging by the test name, this is what
    # produces the chunked struct column that triggered the failure.
    df = single_row.vstack(single_row)

    # Both branches pick the first "x" value; the predicate is column "b".
    # This used to panic.
    expr = pl.when(pl.col.b).then(pl.first("x")).otherwise(pl.first("x"))
    df.select(expr)

0 comments on commit 0485685

Please sign in to comment.