Skip to content

Commit

Permalink
fix(python): fix sort_by regression (#11679)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Oct 12, 2023
1 parent 4a368f5 commit 98ee5a3
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 2 deletions.
50 changes: 48 additions & 2 deletions crates/polars-lazy/src/physical_plan/expressions/sortby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,40 @@ fn sort_by_groups_single_by(
Ok((*first, new_idx))
}

fn sort_by_groups_no_match_single<'a>(
mut ac_in: AggregationContext<'a>,
mut ac_by: AggregationContext<'a>,
descending: bool,
expr: &Expr,
) -> PolarsResult<AggregationContext<'a>> {
let s_in = ac_in.aggregated();
let s_by = ac_by.aggregated();
let mut s_in = s_in.list().unwrap().clone();
let mut s_by = s_by.list().unwrap().clone();

let ca: PolarsResult<ListChunked> = POOL.install(|| {
s_in.par_iter_indexed()
.zip(s_by.par_iter_indexed())
.map(|(opt_s, s_sort_by)| match (opt_s, s_sort_by) {
(Some(s), Some(s_sort_by)) => {
polars_ensure!(s.len() == s_sort_by.len(), ComputeError: "series lengths don't match in 'sort_by' expression");
let idx = s_sort_by.arg_sort(SortOptions {
descending,
// We are already in par iter.
multithreaded: false,
..Default::default()
});
Ok(Some(unsafe { s.take_unchecked(&idx) }))
},
_ => Ok(None),
})
.collect()
});
let s = ca?.with_name(s_in.name()).into_series();
ac_in.with_series(s, true, Some(expr))?;
Ok(ac_in)
}

fn sort_by_groups_multiple_by(
indicator: GroupsIndicator,
sort_by_s: &[Series],
Expand Down Expand Up @@ -198,7 +232,7 @@ impl PhysicalExpr for SortByExpr {
.iter()
.map(|e| e.evaluate_on_groups(df, groups, state))
.collect::<PolarsResult<Vec<_>>>()?;
let sort_by_s = ac_sort_by
let mut sort_by_s = ac_sort_by
.iter()
.map(|s| {
let s = s.flat_naive();
Expand All @@ -225,7 +259,19 @@ impl PhysicalExpr for SortByExpr {

let groups = if self.by.len() == 1 {
let mut ac_sort_by = ac_sort_by.pop().unwrap();
let sort_by_s = ac_sort_by.flat_naive().into_owned();

// The groups of the lhs of the expressions do not match the series values,
// we must take the slower path.
if !matches!(ac_in.update_groups, UpdateGroups::No) {
return sort_by_groups_no_match_single(
ac_in,
ac_sort_by,
self.descending[0],
&self.expr,
);
};

let sort_by_s = sort_by_s.pop().unwrap();
let groups = ac_sort_by.groups();

let (check, groups) = POOL.join(
Expand Down
17 changes: 17 additions & 0 deletions py-polars/tests/unit/operations/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,3 +711,20 @@ def test_sorted_update_flags_10327() -> None:
pl.Series("a", [], dtype=pl.Int64).to_frame(),
]
)["a"].to_list() == [1, 2]


def test_sort_by_11653() -> None:
df = pl.DataFrame(
{
"id": [0, 0, 0, 0, 0, 1],
"weights": [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
"other": [0.8, 0.4, 0.5, 0.6, 0.7, 0.8],
}
)

assert df.group_by("id").agg(
(pl.col("weights") / pl.col("weights").sum())
.sort_by("other")
.sum()
.alias("sort_by"),
).sort("id").to_dict(False) == {"id": [0, 1], "sort_by": [1.0, 1.0]}

0 comments on commit 98ee5a3

Please sign in to comment.