From e8621c0b50bbd87713766d85446f291062e07bdf Mon Sep 17 00:00:00 2001
From: Marshall Crumiller
Date: Wed, 1 Jan 2025 21:07:09 -0500
Subject: [PATCH] Add mean_horizontal for temporals

---
 .../polars-ops/src/series/ops/horizontal.rs   | 180 ++++++++++++++----
 .../src/dsl/function_expr/schema.rs           |  12 +-
 .../operations/aggregation/test_horizontal.py |  89 ++++++++-
 3 files changed, 240 insertions(+), 41 deletions(-)

diff --git a/crates/polars-ops/src/series/ops/horizontal.rs b/crates/polars-ops/src/series/ops/horizontal.rs
index 025779e77349..385db5d91919 100644
--- a/crates/polars-ops/src/series/ops/horizontal.rs
+++ b/crates/polars-ops/src/series/ops/horizontal.rs
@@ -1,5 +1,6 @@
 use std::borrow::Cow;
 
+use arrow::temporal_conversions::MILLISECONDS_IN_DAY;
 use polars_core::chunked_array::cast::CastOptions;
 use polars_core::prelude::*;
 use polars_core::series::arithmetic::coerce_lhs_rhs;
@@ -190,10 +191,44 @@ pub fn min_horizontal(columns: &[Column]) -> PolarsResult<Option<Column>> {
     }
 }
 
+fn null_with_supertype(
+    columns: Vec<&Series>,
+    date_to_datetime: bool,
+) -> PolarsResult<Option<Column>> {
+    // We must first determine the correct return dtype.
+    let mut return_dtype = dtypes_to_supertype(columns.iter().map(|c| c.dtype()))?;
+    if return_dtype == DataType::Boolean {
+        return_dtype = IDX_DTYPE;
+    } else if date_to_datetime && return_dtype == DataType::Date {
+        return_dtype = DataType::Datetime(TimeUnit::Milliseconds, None);
+    }
+    Ok(Some(Column::full_null(
+        columns[0].name().clone(),
+        columns[0].len(),
+        &return_dtype,
+    )))
+}
+
+fn null_with_supertype_from_series(
+    columns: &[Column],
+    date_to_datetime: bool,
+) -> PolarsResult<Option<Column>> {
+    null_with_supertype(
+        columns
+            .iter()
+            .map(|c| c.as_materialized_series())
+            .collect::<Vec<_>>(),
+        date_to_datetime,
+    )
+}
+
 pub fn sum_horizontal(
     columns: &[Column],
     null_strategy: NullStrategy,
 ) -> PolarsResult<Option<Column>> {
+    if columns.is_empty() {
+        return Ok(None);
+    }
     validate_column_lengths(columns)?;
     let ignore_nulls = null_strategy == NullStrategy::Ignore;
 
@@ -220,17 +255,12 @@ pub fn sum_horizontal(
         .collect::<Vec<_>>();
 
     // If we have any null columns and null strategy is not `Ignore`, we can return immediately.
-    if !ignore_nulls && non_null_cols.len() < columns.len() {
-        // We must first determine the correct return dtype.
-        let return_dtype = match dtypes_to_supertype(non_null_cols.iter().map(|c| c.dtype()))? {
-            DataType::Boolean => DataType::UInt32,
-            dt => dt,
-        };
-        return Ok(Some(Column::full_null(
-            columns[0].name().clone(),
-            columns[0].len(),
-            &return_dtype,
-        )));
+    let num_cols = non_null_cols.len();
+    let name = columns[0].name();
+    if num_cols == 0 {
+        return Ok(Some(columns[0].clone().with_name(name.clone())));
+    } else if !ignore_nulls && non_null_cols.len() < columns.len() {
+        return null_with_supertype(non_null_cols, false);
     }
 
     match non_null_cols.len() {
         0 => {
             if columns.is_empty() {
                 Ok(None)
             } else {
                 // all columns are null dtype, so result is null dtype
-                Ok(Some(columns[0].clone()))
+                Ok(Some(columns[0].clone().with_name(name.clone())))
             }
         },
         1 => Ok(Some(
             apply_null_strategy(if non_null_cols[0].dtype() == &DataType::Boolean {
-                non_null_cols[0].cast(&DataType::UInt32)?
+                non_null_cols[0].cast(&IDX_DTYPE)?
             } else {
                 non_null_cols[0].clone()
             })?
+            .with_name(name.clone())
             .into(),
         )),
         2 => sum_fn(non_null_cols[0].clone(), non_null_cols[1].clone())
@@ -274,33 +305,98 @@ pub fn mean_horizontal(
     columns: &[Column],
     null_strategy: NullStrategy,
 ) -> PolarsResult<Option<Column>> {
+    if columns.is_empty() {
+        return Ok(None);
+    }
+    let ignore_nulls = null_strategy == NullStrategy::Ignore;
     validate_column_lengths(columns)?;
+    let name = columns[0].name().clone();
 
-    let (numeric_columns, non_numeric_columns): (Vec<_>, Vec<_>) = columns.iter().partition(|s| {
-        let dtype = s.dtype();
-        dtype.is_numeric() || dtype.is_decimal() || dtype.is_bool() || dtype.is_null()
-    });
+    let Some(first_non_null_idx) = columns.iter().position(|c| !c.dtype().is_null()) else {
+        // All columns are null; return f64 nulls.
+        return Ok(Some(
+            Float64Chunked::full_null(name, columns[0].len()).into_column(),
+        ));
+    };
 
-    if !non_numeric_columns.is_empty() {
-        let col = non_numeric_columns.first().cloned();
+    if first_non_null_idx > 0 && !ignore_nulls {
+        return null_with_supertype_from_series(columns, true);
+    }
+
+    // Ensure the column dtypes are all valid.
+    let first_dtype = columns[first_non_null_idx].dtype();
+    let is_temporal = first_dtype.is_temporal();
+    let columns = if is_temporal {
+        // All remaining dtypes must be the same temporal dtype (or null).
+        for col in &columns[first_non_null_idx + 1..] {
+            let dtype = col.dtype();
+            if !ignore_nulls && dtype == &DataType::Null {
+                // A null column guarantees null output.
+                return null_with_supertype_from_series(columns, true);
+            } else if dtype != first_dtype && dtype != &DataType::Null {
+                polars_bail!(
+                    InvalidOperation: "'horizontal_mean' expects all numeric or all temporal expressions, found {:?} (dtype={}) and {:?} (dtype={})",
+                    columns[first_non_null_idx].name(),
+                    first_dtype,
+                    col.name(),
+                    dtype,
+                );
+            };
+        }
+        columns[first_non_null_idx..]
+            .iter()
+            .map(|c| c.cast(&DataType::Int64).unwrap())
+            .collect::<Vec<_>>()
+    } else if first_dtype.is_numeric()
+        || first_dtype.is_decimal()
+        || first_dtype.is_bool()
+        || first_dtype.is_null()
+        || first_dtype.is_temporal()
+    {
+        // All remaining must be numeric (or null).
+        for col in &columns[first_non_null_idx + 1..] {
+            let dtype = col.dtype();
+            if !(dtype.is_numeric()
+                || dtype.is_decimal()
+                || dtype.is_bool()
+                || dtype.is_temporal()
+                || dtype.is_null())
+            {
+                polars_bail!(
+                    InvalidOperation: "'horizontal_mean' expects all numeric or all temporal expressions, found {:?} (dtype={}) and {:?} (dtype={})",
+                    columns[first_non_null_idx].name(),
+                    first_dtype,
+                    col.name(),
+                    dtype,
+                );
+            }
+        }
+        columns[first_non_null_idx..].to_vec()
+    } else {
         polars_bail!(
-            InvalidOperation: "'horizontal_mean' expects numeric expressions, found {:?} (dtype={})",
-            col.unwrap().name(),
-            col.unwrap().dtype(),
+            InvalidOperation: "'horizontal_mean' expects all numeric or all temporal expressions, found {:?} (dtype={})",
+            columns[first_non_null_idx].name(),
+            first_dtype,
         );
-    }
-    let columns = numeric_columns.into_iter().cloned().collect::<Vec<_>>();
+    };
+
     let num_rows = columns.len();
     match num_rows {
-        0 => Ok(None),
-        1 => Ok(Some(match columns[0].dtype() {
-            dt if dt != &DataType::Float32 && !dt.is_decimal() => {
-                columns[0].cast(&DataType::Float64)?
+        1 => Ok(Some(match first_dtype {
+            dt if dt != &DataType::Float32 && !is_temporal && !dt.is_decimal() => {
+                columns[0].cast(&DataType::Float64)?.with_name(name)
+            },
+            _ => match first_dtype {
+                DataType::Date => (&columns[0] * MILLISECONDS_IN_DAY)
+                    .with_name(name)
+                    .cast(&DataType::Datetime(TimeUnit::Milliseconds, None))?,
+                dt if is_temporal => columns[0].cast(dt)?.with_name(name),
+                _ => columns[0].clone().with_name(name),
             },
-            _ => columns[0].clone(),
         })),
         _ => {
-            let sum = || sum_horizontal(columns.as_slice(), null_strategy);
+            let sum = || sum_horizontal(&columns, null_strategy);
             let null_count = || {
                 columns
                     .par_iter()
@@ -321,7 +417,7 @@ pub fn mean_horizontal(
             };
 
             let (sum, null_count) = POOL.install(|| rayon::join(sum, null_count));
-            let sum = sum?;
+            let sum = sum?.map(|c| c.with_name(name));
             let null_count = null_count?;
 
             // value lengths: len - null_count
@@ -349,8 +445,26 @@ pub fn mean_horizontal(
                 .into_column()
                 .cast(dt)?;
 
-            sum.map(|sum| std::ops::Div::div(&sum, &value_length))
-                .transpose()
+            let out = sum.map(|sum| std::ops::Div::div(&sum, &value_length));
+
+            let x = out.map(|opt| {
+                opt.and_then(|value| {
+                    if is_temporal {
+                        if first_dtype == &DataType::Date {
+                            // Cast to Datetime(ms)
+                            (value * MILLISECONDS_IN_DAY)
+                                .cast(&DataType::Datetime(TimeUnit::Milliseconds, None))
+                        } else {
+                            // Cast to the original temporal dtype
+                            value.cast(first_dtype)
+                        }
+                    } else {
+                        Ok(value)
+                    }
+                })
+            });
+
+            x.transpose()
         },
     }
 }
diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs
index d45f75c01e9d..9bd0f4fbaf0c 100644
--- a/crates/polars-plan/src/dsl/function_expr/schema.rs
+++ b/crates/polars-plan/src/dsl/function_expr/schema.rs
@@ -331,17 +331,19 @@ impl FunctionExpr {
             MinHorizontal => mapper.map_to_supertype(),
             SumHorizontal { .. } => {
                 mapper.map_to_supertype().map(|mut f| {
-                    match f.dtype {
-                        // Booleans sum to UInt32.
-                        DataType::Boolean => { f.dtype = DataType::UInt32; f},
-                        _ => f,
+                    if f.dtype == DataType::Boolean {
+                        f.dtype = IDX_DTYPE;
                     }
+                    f
                 })
             },
             MeanHorizontal { .. } => {
                 mapper.map_to_supertype().map(|mut f| {
                     match f.dtype {
-                        dt @ DataType::Float32 => { f.dtype = dt; },
+                        DataType::Boolean => { f.dtype = DataType::Float64; },
+                        DataType::Float32 => { f.dtype = DataType::Float32; },
+                        DataType::Date => { f.dtype = DataType::Datetime(TimeUnit::Milliseconds, None); },
+                        dt if dt.is_temporal() => { f.dtype = dt; },
                         _ => { f.dtype = DataType::Float64; },
                     };
                     f
diff --git a/py-polars/tests/unit/operations/aggregation/test_horizontal.py b/py-polars/tests/unit/operations/aggregation/test_horizontal.py
index 3959e15e22ed..0e8d9311c4be 100644
--- a/py-polars/tests/unit/operations/aggregation/test_horizontal.py
+++ b/py-polars/tests/unit/operations/aggregation/test_horizontal.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
-import datetime
 from collections import OrderedDict
+from datetime import date, datetime, time, timedelta
 from typing import TYPE_CHECKING, Any
 
 import pytest
@@ -12,7 +12,7 @@
 from polars.testing import assert_frame_equal, assert_series_equal
 
 if TYPE_CHECKING:
-    from polars._typing import PolarsDataType
+    from polars._typing import PolarsDataType, TimeUnit
 
 
 def test_any_expr(fruits_cars: pl.DataFrame) -> None:
@@ -340,7 +340,7 @@ def test_sum_dtype_12028() -> None:
         [
             pl.Series(
                 "sum_duration",
-                [datetime.timedelta(seconds=10)],
+                [timedelta(seconds=10)],
                 dtype=pl.Duration(time_unit="us"),
             ),
         ]
@@ -431,6 +431,7 @@ def test_mean_horizontal_no_rows() -> None:
    expected = pl.LazyFrame({"a": []}, schema={"a": pl.Float64})
 
     assert_frame_equal(result, expected)
+    assert result.collect_schema() == expected.collect_schema()
 
 
 def test_mean_horizontal_all_null() -> None:
@@ -442,6 +443,85 @@ def test_mean_horizontal_all_null() -> None:
     assert_frame_equal(result, expected)
 
 
+@pytest.mark.parametrize("tz", [None, "UTC", "Asia/Kathmandu"])
+@pytest.mark.parametrize("tu", ["ms", "us", "ns"])
+@pytest.mark.parametrize("ignore_nulls", [False, True])
+def test_mean_horizontal_temporal(tu: TimeUnit, tz: str | None, ignore_nulls: bool) -> None:
+    dt1 = [date(2024, 1, 1), date(2024, 1, 3)]
+    dt2 = [date(2024, 1, 2), date(2024, 1, 4)]
+    dur1 = [timedelta(hours=1), timedelta(hours=3)]
+    dur2 = [timedelta(hours=2), timedelta(hours=4)]
+    lf = pl.LazyFrame(
+        {
+            "date1": pl.Series(dt1, dtype=pl.Date),
+            "date2": pl.Series(dt2, dtype=pl.Date),
+            "datetime1": pl.Series(dt1, dtype=pl.Datetime(time_unit=tu, time_zone=tz)),
+            "datetime2": pl.Series(dt2, dtype=pl.Datetime(time_unit=tu, time_zone=tz)),
+            "time1": [time(1), time(3)],
+            "time2": [time(2), time(4)],
+            "duration1": pl.Series(dur1, dtype=pl.Duration(time_unit=tu)),
+            "duration2": pl.Series(dur2, dtype=pl.Duration(time_unit=tu)),
+            "null": pl.Series([None, None], dtype=pl.Null),
+        }
+    )
+    null_method = {"ignore_nulls": ignore_nulls}
+    out = lf.select(
+        pl.mean_horizontal("date1", "date2", **null_method).alias("date"),
+        pl.mean_horizontal("datetime1", "datetime2", **null_method).alias("datetime"),
+        pl.mean_horizontal("time1", "time2", **null_method).alias("time"),
+        pl.mean_horizontal("duration1", "duration2", **null_method).alias("duration"),
+        pl.mean_horizontal("date1", "null", **null_method).alias("null_date1"),
+        pl.mean_horizontal("null", "date2", **null_method).alias("null_date2"),
+        pl.mean_horizontal("datetime1", "null", **null_method).alias("null_datetime1"),
+        pl.mean_horizontal("null", "datetime2", **null_method).alias("null_datetime2"),
+    )
+
+    expected = pl.DataFrame(
+        {
+            "date": pl.Series(
+                [datetime(2024, 1, 1, 12), datetime(2024, 1, 3, 12)],
+                dtype=pl.Datetime("ms"),
+            ),
+            "datetime": pl.Series(
+                [datetime(2024, 1, 1, 12), datetime(2024, 1, 3, 12)],
+                dtype=pl.Datetime(tu, tz),
+            ),
+            "time": [time(hour=1, minute=30), time(hour=3, minute=30)],
+            "duration": pl.Series(
+                [timedelta(hours=1, minutes=30), timedelta(hours=3, minutes=30)],
+                dtype=pl.Duration(time_unit=tu),
+            ),
+            "null_date1": (
+                pl.Series(dt1, dtype=pl.Datetime(time_unit="ms"))
+                if ignore_nulls
+                else pl.Series([None, None], dtype=pl.Datetime(time_unit="ms"))
+            ),
+            "null_date2": (
+                pl.Series(dt2, dtype=pl.Datetime(time_unit="ms"))
+                if ignore_nulls
+                else pl.Series([None, None], dtype=pl.Datetime(time_unit="ms"))
+            ),
+            "null_datetime1": (
+                pl.Series(dt1, dtype=pl.Datetime(time_unit=tu, time_zone=tz))
+                if ignore_nulls
+                else pl.Series(
+                    [None, None], dtype=pl.Datetime(time_unit=tu, time_zone=tz)
+                )
+            ),
+            "null_datetime2": (
+                pl.Series(dt2, dtype=pl.Datetime(time_unit=tu, time_zone=tz))
+                if ignore_nulls
+                else pl.Series(
+                    [None, None], dtype=pl.Datetime(time_unit=tu, time_zone=tz)
+                )
+            ),
+        }
+    )
+
+    assert_frame_equal(out.collect(), expected)
+    assert out.collect_schema() == expected.collect_schema()
+
+
 @pytest.mark.parametrize(
     ("in_dtype", "out_dtype"),
     [
@@ -556,6 +636,7 @@ def test_horizontal_sum_boolean_with_null() -> None:
     )
 
     assert_frame_equal(out.collect(), expected_df)
+    assert out.collect_schema() == expected_df.collect_schema()
 
 
 @pytest.mark.parametrize("ignore_nulls", [True, False])
@@ -589,6 +670,7 @@ def test_horizontal_sum_with_null_col_ignore_strategy(
         values = [None, None, None]  # type: ignore[list-item]
     expected = pl.LazyFrame(pl.Series("null", values, dtype=dtype_out))
     assert_frame_equal(result, expected)
+    assert result.collect_schema() == expected.collect_schema()
 
 
 @pytest.mark.parametrize("ignore_nulls", [True, False])
@@ -621,3 +703,4 @@ def test_horizontal_mean_with_null_col_ignore_strategy(
         values = [None, None, None]  # type: ignore[list-item]
     expected = pl.LazyFrame(pl.Series("null", values, dtype=dtype_out))
     assert_frame_equal(result, expected)
+    assert result.collect_schema() == expected.collect_schema()