From df7485865dd36b0e7cccc07ef45ed9d2836dfd38 Mon Sep 17 00:00:00 2001 From: Marshall Crumiller Date: Thu, 2 Jan 2025 10:23:42 -0500 Subject: [PATCH 1/3] Fix IdxType for sum_horizontal --- .../polars-ops/src/series/ops/horizontal.rs | 47 ++++++++++++++----- .../src/dsl/function_expr/schema.rs | 7 ++- .../operations/aggregation/test_horizontal.py | 11 +++-- 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/crates/polars-ops/src/series/ops/horizontal.rs b/crates/polars-ops/src/series/ops/horizontal.rs index 025779e77349..e51bddc948db 100644 --- a/crates/polars-ops/src/series/ops/horizontal.rs +++ b/crates/polars-ops/src/series/ops/horizontal.rs @@ -190,6 +190,40 @@ pub fn min_horizontal(columns: &[Column]) -> PolarsResult> { } } +// Given a Series sequence, determines the correct supertype and returns full-null Column. +// The `date_to_datetime` flag indicates that Date supertypes should be converted to Datetime("ms"). +fn null_with_supertype( + columns: Vec<&Series>, + date_to_datetime: bool, +) -> PolarsResult> { + // We must first determine the correct return dtype. + let mut return_dtype = dtypes_to_supertype(columns.iter().map(|c| c.dtype()))?; + if return_dtype == DataType::Boolean { + return_dtype = IDX_DTYPE; + } else if date_to_datetime && return_dtype == DataType::Date { + return_dtype = DataType::Datetime(TimeUnit::Milliseconds, None); + } + Ok(Some(Column::full_null( + columns[0].name().clone(), + columns[0].len(), + &return_dtype, + ))) +} + +// Apply `null_with_supertype` to a sequence of Columns. +fn null_with_supertype_from_columns( + columns: &[Column], + date_to_datetime: bool, +) -> PolarsResult> { + null_with_supertype( + columns + .iter() + .map(|c| c.as_materialized_series()) + .collect::>(), + date_to_datetime, + ) +} + pub fn sum_horizontal( columns: &[Column], null_strategy: NullStrategy, @@ -221,16 +255,7 @@ pub fn sum_horizontal( // If we have any null columns and null strategy is not `Ignore`, we can return immediately. if !ignore_nulls && non_null_cols.len() < columns.len() { - // We must first determine the correct return dtype. - let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? { - DataType::Boolean => DataType::UInt32, - dt => dt, - }; - return Ok(Some(Column::full_null( - columns[0].name().clone(), - columns[0].len(), - &return_dtype, - ))); + return null_with_supertype_from_columns(columns, false); } match non_null_cols.len() { @@ -244,7 +269,7 @@ pub fn sum_horizontal( }, 1 => Ok(Some( apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean { - non_null_cols[0].cast(&DataType::UInt32)? + non_null_cols[0].cast(&IDX_DTYPE)? } else { non_null_cols[0].clone() })? diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index d45f75c01e9d..beaacac49942 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -331,11 +331,10 @@ impl FunctionExpr { MinHorizontal => mapper.map_to_supertype(), SumHorizontal { .. } => { mapper.map_to_supertype().map(|mut f| { - match f.dtype { - // Booleans sum to UInt32. - DataType::Boolean => { f.dtype = DataType::UInt32; f}, - _ => f, + if f.dtype == DataType::Boolean { + f.dtype = IDX_DTYPE; } + f }) }, MeanHorizontal { .. } => { diff --git a/py-polars/tests/unit/operations/aggregation/test_horizontal.py b/py-polars/tests/unit/operations/aggregation/test_horizontal.py index 3959e15e22ed..305db1d8a4b1 100644 --- a/py-polars/tests/unit/operations/aggregation/test_horizontal.py +++ b/py-polars/tests/unit/operations/aggregation/test_horizontal.py @@ -541,8 +541,8 @@ def test_horizontal_sum_boolean_with_null() -> None: expected_schema = pl.Schema( { - "null_first": pl.UInt32, - "bool_first": pl.UInt32, + "null_first": pl.get_index_type(), + "bool_first": pl.get_index_type(), } ) @@ -550,8 +550,8 @@ def test_horizontal_sum_boolean_with_null() -> None: expected_df = pl.DataFrame( { - "null_first": pl.Series([1, 0], dtype=pl.UInt32), - "bool_first": pl.Series([1, 0], dtype=pl.UInt32), + "null_first": pl.Series([1, 0], dtype=pl.get_index_type()), + "bool_first": pl.Series([1, 0], dtype=pl.get_index_type()), } ) @@ -563,7 +563,7 @@ def test_horizontal_sum_boolean_with_null() -> None: ("dtype_in", "dtype_out"), [ (pl.Null, pl.Null), - (pl.Boolean, pl.UInt32), + (pl.Boolean, pl.get_index_type()), (pl.UInt8, pl.UInt8), (pl.Float32, pl.Float32), (pl.Float64, pl.Float64), @@ -589,6 +589,7 @@ def test_horizontal_sum_with_null_col_ignore_strategy( values = [None, None, None] # type: ignore[list-item] expected = pl.LazyFrame(pl.Series("null", values, dtype=dtype_out)) assert_frame_equal(result, expected) + assert result.collect_schema() == expected.collect_schema() @pytest.mark.parametrize("ignore_nulls", [True, False]) From a2b1bc5a0ceef7c3c1644167803800f63abcf3d6 Mon Sep 17 00:00:00 2001 From: Marshall Crumiller Date: Fri, 3 Jan 2025 10:09:28 -0500 Subject: [PATCH 2/3] Cleanup --- .../polars-ops/src/series/ops/horizontal.rs | 45 +++++-------------- .../operations/aggregation/test_horizontal.py | 33 ++++++++++++++ 2 files changed, 43 insertions(+), 35 deletions(-) diff --git a/crates/polars-ops/src/series/ops/horizontal.rs b/crates/polars-ops/src/series/ops/horizontal.rs index e51bddc948db..6a6960480c47 100644 --- a/crates/polars-ops/src/series/ops/horizontal.rs +++ b/crates/polars-ops/src/series/ops/horizontal.rs @@ -190,40 +190,6 @@ pub fn min_horizontal(columns: &[Column]) -> PolarsResult> { } } -// Given a Series sequence, determines the correct supertype and returns full-null Column. -// The `date_to_datetime` flag indicates that Date supertypes should be converted to Datetime("ms"). -fn null_with_supertype( - columns: Vec<&Series>, - date_to_datetime: bool, -) -> PolarsResult> { - // We must first determine the correct return dtype. - let mut return_dtype = dtypes_to_supertype(columns.iter().map(|c| c.dtype()))?; - if return_dtype == DataType::Boolean { - return_dtype = IDX_DTYPE; - } else if date_to_datetime && return_dtype == DataType::Date { - return_dtype = DataType::Datetime(TimeUnit::Milliseconds, None); - } - Ok(Some(Column::full_null( - columns[0].name().clone(), - columns[0].len(), - &return_dtype, - ))) -} - -// Apply `null_with_supertype` to a sequence of Columns. -fn null_with_supertype_from_columns( - columns: &[Column], - date_to_datetime: bool, -) -> PolarsResult> { - null_with_supertype( - columns - .iter() - .map(|c| c.as_materialized_series()) - .collect::>(), - date_to_datetime, - ) -} - pub fn sum_horizontal( columns: &[Column], null_strategy: NullStrategy, @@ -255,7 +221,16 @@ pub fn sum_horizontal( // If we have any null columns and null strategy is not `Ignore`, we can return immediately. if !ignore_nulls && non_null_cols.len() < columns.len() { - return null_with_supertype_from_columns(columns, false); + // We must determine the correct return dtype. + let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? { + DataType::Boolean => IDX_DTYPE, + dt => dt, + }; + return Ok(Some(Column::full_null( + columns[0].name().clone(), + columns[0].len(), + &return_dtype, + ))); } match non_null_cols.len() { diff --git a/py-polars/tests/unit/operations/aggregation/test_horizontal.py b/py-polars/tests/unit/operations/aggregation/test_horizontal.py index 305db1d8a4b1..e6511924c3d4 100644 --- a/py-polars/tests/unit/operations/aggregation/test_horizontal.py +++ b/py-polars/tests/unit/operations/aggregation/test_horizontal.py @@ -319,6 +319,39 @@ def test_sum_single_col() -> None: ) +@pytest.mark.parametrize("ignore_nulls", [False, True]) +def test_sum_correct_supertype(ignore_nulls: bool) -> None: + values = [1, 2] if ignore_nulls else [None, None] + lf = pl.LazyFrame( + { + "null": [None, None], + "int": pl.Series(values, dtype=pl.Int32), + "float": pl.Series(values, dtype=pl.Float32), + } + ) + + # null + int32 should produce int32 + out = lf.select(pl.sum_horizontal("null", "int", ignore_nulls=ignore_nulls)) + expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Int32)}) + assert_frame_equal(out.collect(), expected.collect()) + assert out.collect_schema() == expected.collect_schema() + + # null + float32 should produce float32 + out = lf.select(pl.sum_horizontal("null", "float", ignore_nulls=ignore_nulls)) + expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Float32)}) + assert_frame_equal(out.collect(), expected.collect()) + assert out.collect_schema() == expected.collect_schema() + + # null + int32 + float32 should produce float64 + values = [2, 4] if ignore_nulls else [None, None] + out = lf.select( + pl.sum_horizontal("null", "int", "float", ignore_nulls=ignore_nulls) + ) + expected = pl.LazyFrame({"null": pl.Series(values, dtype=pl.Float64)}) + assert_frame_equal(out.collect(), expected.collect()) + assert out.collect_schema() == expected.collect_schema() + + def test_cum_sum_horizontal() -> None: df = pl.DataFrame( { From 259e101f1ac351c58bfdb20b50438a3222b62e8b Mon Sep 17 00:00:00 2001 From: Marshall Crumiller Date: Fri, 3 Jan 2025 10:17:43 -0500 Subject: [PATCH 3/3] Fix lint --- .../tests/unit/operations/aggregation/test_horizontal.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py-polars/tests/unit/operations/aggregation/test_horizontal.py b/py-polars/tests/unit/operations/aggregation/test_horizontal.py index e6511924c3d4..bc557a231d75 100644 --- a/py-polars/tests/unit/operations/aggregation/test_horizontal.py +++ b/py-polars/tests/unit/operations/aggregation/test_horizontal.py @@ -321,7 +321,7 @@ def test_sum_single_col() -> None: @pytest.mark.parametrize("ignore_nulls", [False, True]) def test_sum_correct_supertype(ignore_nulls: bool) -> None: - values = [1, 2] if ignore_nulls else [None, None] + values = [1, 2] if ignore_nulls else [None, None] # type: ignore[list-item] lf = pl.LazyFrame( { "null": [None, None], @@ -343,7 +343,7 @@ def test_sum_correct_supertype(ignore_nulls: bool) -> None: assert out.collect_schema() == expected.collect_schema() # null + int32 + float32 should produce float64 - values = [2, 4] if ignore_nulls else [None, None] + values = [2, 4] if ignore_nulls else [None, None] # type: ignore[list-item] out = lf.select( pl.sum_horizontal("null", "int", "float", ignore_nulls=ignore_nulls) )