From 98ee5a32d0baf003ef64fc744f87e6042072124b Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Thu, 12 Oct 2023 10:05:01 +0200 Subject: [PATCH] fix(python): fix sort_by regression (#11679) --- .../src/physical_plan/expressions/sortby.rs | 50 ++++++++++++++++++- py-polars/tests/unit/operations/test_sort.py | 17 +++++++ 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/crates/polars-lazy/src/physical_plan/expressions/sortby.rs b/crates/polars-lazy/src/physical_plan/expressions/sortby.rs index 008d714c09d2..feb398432b64 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/sortby.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/sortby.rs @@ -87,6 +87,40 @@ fn sort_by_groups_single_by( Ok((*first, new_idx)) } +fn sort_by_groups_no_match_single<'a>( + mut ac_in: AggregationContext<'a>, + mut ac_by: AggregationContext<'a>, + descending: bool, + expr: &Expr, +) -> PolarsResult> { + let s_in = ac_in.aggregated(); + let s_by = ac_by.aggregated(); + let mut s_in = s_in.list().unwrap().clone(); + let mut s_by = s_by.list().unwrap().clone(); + + let ca: PolarsResult = POOL.install(|| { + s_in.par_iter_indexed() + .zip(s_by.par_iter_indexed()) + .map(|(opt_s, s_sort_by)| match (opt_s, s_sort_by) { + (Some(s), Some(s_sort_by)) => { + polars_ensure!(s.len() == s_sort_by.len(), ComputeError: "series lengths don't match in 'sort_by' expression"); + let idx = s_sort_by.arg_sort(SortOptions { + descending, + // We are already in par iter. + multithreaded: false, + ..Default::default() + }); + Ok(Some(unsafe { s.take_unchecked(&idx) })) + }, + _ => Ok(None), + }) + .collect() + }); + let s = ca?.with_name(s_in.name()).into_series(); + ac_in.with_series(s, true, Some(expr))?; + Ok(ac_in) +} + fn sort_by_groups_multiple_by( indicator: GroupsIndicator, sort_by_s: &[Series], @@ -198,7 +232,7 @@ impl PhysicalExpr for SortByExpr { .iter() .map(|e| e.evaluate_on_groups(df, groups, state)) .collect::>>()?; - let sort_by_s = ac_sort_by + let mut sort_by_s = ac_sort_by .iter() .map(|s| { let s = s.flat_naive(); @@ -225,7 +259,19 @@ impl PhysicalExpr for SortByExpr { let groups = if self.by.len() == 1 { let mut ac_sort_by = ac_sort_by.pop().unwrap(); - let sort_by_s = ac_sort_by.flat_naive().into_owned(); + + // The groups of the lhs of the expressions do not match the series values, + // we must take the slower path. + if !matches!(ac_in.update_groups, UpdateGroups::No) { + return sort_by_groups_no_match_single( + ac_in, + ac_sort_by, + self.descending[0], + &self.expr, + ); + }; + + let sort_by_s = sort_by_s.pop().unwrap(); let groups = ac_sort_by.groups(); let (check, groups) = POOL.join( diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index 8bc3a1bc6751..884a43656472 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -711,3 +711,20 @@ def test_sorted_update_flags_10327() -> None: pl.Series("a", [], dtype=pl.Int64).to_frame(), ] )["a"].to_list() == [1, 2] + + +def test_sort_by_11653() -> None: + df = pl.DataFrame( + { + "id": [0, 0, 0, 0, 0, 1], + "weights": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + "other": [0.8, 0.4, 0.5, 0.6, 0.7, 0.8], + } + ) + + assert df.group_by("id").agg( + (pl.col("weights") / pl.col("weights").sum()) + .sort_by("other") + .sum() + .alias("sort_by"), + ).sort("id").to_dict(False) == {"id": [0, 1], "sort_by": [1.0, 1.0]}