Skip to content

Commit

Permalink
feat: Support order-by in window functions (#16743)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jun 5, 2024
1 parent 0b48a93 commit 6f3fd8e
Show file tree
Hide file tree
Showing 21 changed files with 246 additions and 59 deletions.
1 change: 1 addition & 0 deletions crates/polars-expr/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub(crate) use slice::*;
pub(crate) use sort::*;
pub(crate) use sortby::*;
pub(crate) use ternary::*;
pub use window::window_function_format_order_by;
pub(crate) use window::*;

use crate::state::ExecutionState;
Expand Down
6 changes: 2 additions & 4 deletions crates/polars-expr/src/expressions/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use polars_core::prelude::*;
use polars_core::POOL;
use polars_ops::chunked_array::ListNameSpaceImpl;
use polars_utils::idx_vec::IdxVec;
use polars_utils::slice::GetSaferUnchecked;
use rayon::prelude::*;

use super::*;
Expand Down Expand Up @@ -29,10 +30,7 @@ pub(crate) fn map_sorted_indices_to_group_idx(sorted_idx: &IdxCa, idx: &[IdxSize
.cont_slice()
.unwrap()
.iter()
.map(|&i| {
debug_assert!(idx.get(i as usize).is_some());
unsafe { *idx.get_unchecked(i as usize) }
})
.map(|&i| unsafe { *idx.get_unchecked_release(i as usize) })
.collect()
}

Expand Down
52 changes: 32 additions & 20 deletions crates/polars-expr/src/expressions/sortby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,32 +53,42 @@ fn check_groups(a: &GroupsProxy, b: &GroupsProxy) -> PolarsResult<()> {
Ok(())
}

/// Re-sorts the row indices of every group in `groups` by the values of `sort_by_s`.
///
/// Each group is handled independently (in parallel via `par_iter`) by
/// `sort_by_groups_single_by`, using the ordering requested in `options`.
///
/// # Errors
///
/// Returns the first error produced while sorting any single group.
pub(super) fn update_groups_sort_by(
    groups: &GroupsProxy,
    sort_by_s: &Series,
    options: &SortOptions,
) -> PolarsResult<GroupsProxy> {
    groups
        .par_iter()
        .map(|group| sort_by_groups_single_by(group, sort_by_s, options))
        .collect::<PolarsResult<_>>()
        // The per-group results are index based, so wrap them in the `Idx` variant.
        .map(GroupsProxy::Idx)
}

fn sort_by_groups_single_by(
indicator: GroupsIndicator,
sort_by_s: &Series,
descending: &[bool],
options: &SortOptions,
) -> PolarsResult<(IdxSize, IdxVec)> {
let options = SortOptions {
descending: options.descending,
nulls_last: options.nulls_last,
// We are already in par iter.
multithreaded: false,
..Default::default()
};
let new_idx = match indicator {
GroupsIndicator::Idx((_, idx)) => {
// SAFETY: group tuples are always in bounds.
let group = unsafe { sort_by_s.take_slice_unchecked(idx) };

let sorted_idx = group.arg_sort(SortOptions {
descending: descending[0],
// We are already in par iter.
multithreaded: false,
..Default::default()
});
let sorted_idx = group.arg_sort(options);
map_sorted_indices_to_group_idx(&sorted_idx, idx)
},
GroupsIndicator::Slice([first, len]) => {
let group = sort_by_s.slice(first as i64, len as usize);
let sorted_idx = group.arg_sort(SortOptions {
descending: descending[0],
// We are already in par iter.
multithreaded: false,
..Default::default()
});
let sorted_idx = group.arg_sort(options);
map_sorted_indices_to_group_slice(&sorted_idx, first)
},
};
Expand Down Expand Up @@ -283,17 +293,19 @@ impl PhysicalExpr for SortByExpr {
let (check, groups) = POOL.join(
|| check_groups(groups, ac_in.groups()),
|| {
groups
.par_iter()
.map(|indicator| {
sort_by_groups_single_by(indicator, &sort_by_s, &descending)
})
.collect::<PolarsResult<_>>()
update_groups_sort_by(
groups,
&sort_by_s,
&SortOptions {
descending: descending[0],
..Default::default()
},
)
},
);
check?;

GroupsProxy::Idx(groups?)
groups?
} else {
let groups = ac_sort_by[0].groups();

Expand Down
25 changes: 24 additions & 1 deletion crates/polars-expr/src/expressions/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub struct WindowExpr {
/// the root column that the Function will be applied on.
/// This will be used to create a smaller DataFrame to prevent taking unneeded columns by index
pub(crate) group_by: Vec<Arc<dyn PhysicalExpr>>,
pub(crate) order_by: Option<(Arc<dyn PhysicalExpr>, SortOptions)>,
pub(crate) apply_columns: Vec<Arc<str>>,
pub(crate) out_name: Option<Arc<str>>,
/// A function Expr. i.e. Mean, Median, Max, etc.
Expand Down Expand Up @@ -366,6 +367,11 @@ impl WindowExpr {
}
}

// Utility to create partitions and cache keys
pub fn window_function_format_order_by(to: &mut String, e: &Expr, k: &SortOptions) {
write!(to, "_PL_{:?}{}_{}", e, k.descending, k.nulls_last).unwrap();
}

impl PhysicalExpr for WindowExpr {
// Note: this was first implemented with expression evaluation, but that performed really badly.
// Therefore we chose the group_by -> apply -> self-join approach.
Expand Down Expand Up @@ -439,7 +445,15 @@ impl PhysicalExpr for WindowExpr {

let create_groups = || {
let gb = df.group_by_with_series(group_by_columns.clone(), true, sort_groups)?;
let out: PolarsResult<GroupsProxy> = Ok(gb.take_groups());
let mut groups = gb.take_groups();

if let Some((order_by, options)) = &self.order_by {
let order_by = order_by.evaluate(df, state)?;
polars_ensure!(order_by.len() == df.height(), ShapeMismatch: "the order by expression evaluated to a length: {} that doesn't match the input DataFrame: {}", order_by.len(), df.height());
groups = update_groups_sort_by(&groups, &order_by, options)?
}

let out: PolarsResult<GroupsProxy> = Ok(groups);
out
};

Expand All @@ -450,6 +464,15 @@ impl PhysicalExpr for WindowExpr {
for s in &group_by_columns {
cache_key.push_str(s.name());
}
if let Some((e, options)) = &self.order_by {
let e = match e.as_expression() {
Some(e) => e,
None => {
polars_bail!(InvalidOperation: "cannot order by this expression in window function")
},
};
window_function_format_order_by(&mut cache_key, e, options)
}

let mut gt_map_guard = state.group_tuples.write().unwrap();
// we run sequential and partitioned
Expand Down
17 changes: 17 additions & 0 deletions crates/polars-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ fn create_physical_expr_inner(
Window {
mut function,
partition_by,
order_by,
options,
} => {
state.set_window();
Expand All @@ -208,6 +209,21 @@ fn create_physical_expr_inner(
state,
)?;

let order_by = order_by
.map(|(node, options)| {
PolarsResult::Ok((
create_physical_expr_inner(
node,
Context::Aggregation,
expr_arena,
schema,
state,
)?,
options,
))
})
.transpose()?;

let mut out_name = None;
if let Alias(expr, name) = expr_arena.get(function) {
function = *expr;
Expand Down Expand Up @@ -250,6 +266,7 @@ fn create_physical_expr_inner(

Ok(Arc::new(WindowExpr {
group_by,
order_by,
apply_columns,
out_name,
function: function_expr,
Expand Down
17 changes: 12 additions & 5 deletions crates/polars-lazy/src/physical_plan/executors/projection_utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use polars_utils::format_smartstring;
use polars_utils::iter::EnumerateIdxTrait;
use smartstring::alias::String as SmartString;

Expand Down Expand Up @@ -51,7 +50,7 @@ fn rolling_evaluate(
fn window_evaluate(
df: &DataFrame,
state: &ExecutionState,
window: PlHashMap<SmartString, Vec<IdAndExpression>>,
window: PlHashMap<String, Vec<IdAndExpression>>,
) -> PolarsResult<Vec<Vec<(u32, Series)>>> {
POOL.install(|| {
window
Expand Down Expand Up @@ -111,7 +110,7 @@ fn execute_projection_cached_window_fns(
#[allow(clippy::type_complexity)]
// String: partition_name,
// u32: index,
let mut windows: PlHashMap<SmartString, Vec<IdAndExpression>> = PlHashMap::default();
let mut windows: PlHashMap<String, Vec<IdAndExpression>> = PlHashMap::default();
#[cfg(feature = "dynamic_group_by")]
let mut rolling: PlHashMap<&RollingGroupOptions, Vec<IdAndExpression>> = PlHashMap::default();
let mut other = Vec::with_capacity(exprs.len());
Expand All @@ -126,13 +125,21 @@ fn execute_projection_cached_window_fns(
if let Expr::Window {
partition_by,
options,
order_by,
..
} = e
{
let entry = match options {
WindowType::Over(_) => {
let group_by = format_smartstring!("{:?}", partition_by.as_slice());
windows.entry(group_by).or_insert_with(Vec::new)
let mut key = format!("{:?}", partition_by.as_slice());
if let Some((e, k)) = order_by {
polars_expr::prelude::window_function_format_order_by(
&mut key,
e.as_ref(),
k,
)
}
windows.entry(key).or_insert_with(Vec::new)
},
#[cfg(feature = "dynamic_group_by")]
WindowType::Rolling(options) => rolling.entry(options).or_insert_with(Vec::new),
Expand Down
12 changes: 7 additions & 5 deletions crates/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1143,9 +1143,11 @@ fn test_fill_forward() -> PolarsResult<()> {

let out = df
.lazy()
.select([col("b")
.forward_fill(None)
.over_with_options([col("a")], WindowMapping::Join)])
.select([col("b").forward_fill(None).over_with_options(
[col("a")],
None,
WindowMapping::Join,
)])
.collect()?;
let agg = out.column("b")?.list()?;

Expand Down Expand Up @@ -1305,7 +1307,7 @@ fn test_filter_after_shift_in_groups() -> PolarsResult<()> {
col("B")
.shift(lit(1))
.filter(col("B").shift(lit(1)).gt(lit(4)))
.over_with_options([col("fruits")], WindowMapping::Join)
.over_with_options([col("fruits")], None, WindowMapping::Join)
.alias("filtered"),
])
.collect()?;
Expand Down Expand Up @@ -1664,7 +1666,7 @@ fn test_single_ranked_group() -> PolarsResult<()> {
},
None,
)
.over_with_options([col("group")], WindowMapping::Join)])
.over_with_options([col("group")], None, WindowMapping::Join)])
.collect()?;

let out = out.column("value")?.explode()?;
Expand Down
5 changes: 4 additions & 1 deletion crates/polars-plan/src/dsl/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,12 @@ pub enum Expr {
input: Arc<Expr>,
by: Arc<Expr>,
},
/// See postgres window functions
/// Polars flavored window functions.
Window {
/// Also has the input. i.e. avg("foo")
function: Arc<Expr>,
partition_by: Vec<Expr>,
order_by: Option<(Arc<Expr>, SortOptions)>,
options: WindowType,
},
Wildcard,
Expand Down Expand Up @@ -249,10 +250,12 @@ impl Hash for Expr {
Expr::Window {
function,
partition_by,
order_by,
options,
} => {
function.hash(state);
partition_by.hash(state);
order_by.hash(state);
options.hash(state);
},
Expr::Slice {
Expand Down
20 changes: 19 additions & 1 deletion crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ pub use list::*;
pub use meta::*;
pub use name::*;
pub use options::*;
use polars_core::error::feature_gated;
use polars_core::prelude::*;
#[cfg(feature = "diff")]
use polars_core::series::ops::NullBehavior;
Expand Down Expand Up @@ -952,22 +953,38 @@ impl Expr {
/// ╰────────┴────────╯
/// ```
pub fn over<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(self, partition_by: E) -> Self {
self.over_with_options(partition_by, Default::default())
self.over_with_options(partition_by, None, Default::default())
}

pub fn over_with_options<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(
self,
partition_by: E,
order_by: Option<(E, SortOptions)>,
options: WindowMapping,
) -> Self {
let partition_by = partition_by
.as_ref()
.iter()
.map(|e| e.clone().into())
.collect();

let order_by = order_by.map(|(e, options)| {
let e = e.as_ref();
let e = if e.len() == 1 {
Arc::new(e[0].clone().into())
} else {
feature_gated!["dtype-struct", {
let e = e.iter().map(|e| e.clone().into()).collect::<Vec<_>>();
Arc::new(as_struct(e))
}]
};
(e, options)
});

Expr::Window {
function: Arc::new(self),
partition_by,
order_by,
options: options.into(),
}
}
Expand All @@ -980,6 +997,7 @@ impl Expr {
Expr::Window {
function: Arc::new(self),
partition_by: vec![index_col],
order_by: None,
options: WindowType::Rolling(options),
}
}
Expand Down
Loading

0 comments on commit 6f3fd8e

Please sign in to comment.