Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow more group_by agg expressions in the new streaming engine #20663

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion crates/polars-mem-engine/src/planner/lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ fn partitionable_gb(
// Complex expressions in the group_by keys themselves are also not partitionable;
// in this case that means anything more than a plain col("foo").
for key in keys {
if (expr_arena).iter(key.node()).count() > 1 {
if (expr_arena).iter(key.node()).count() > 1
|| has_aexpr(key.node(), expr_arena, |ae| {
matches!(ae, AExpr::Literal(LiteralValue::Series(_)))
})
{
return false;
}
}
Expand Down
34 changes: 22 additions & 12 deletions crates/polars-stream/src/physical_plan/lower_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ use slotmap::SlotMap;

use super::{PhysNode, PhysNodeKey, PhysNodeKind, PhysStream};

type IRNodeKey = Node;
type ExprNodeKey = Node;

fn unique_column_name() -> PlSmallStr {
pub fn unique_column_name() -> PlSmallStr {
static COUNTER: AtomicU64 = AtomicU64::new(0);
let idx = COUNTER.fetch_add(1, Ordering::Relaxed);
format_pl_smallstr!("__POLARS_STMP_{idx}")
Expand All @@ -48,7 +48,7 @@ struct LowerExprContext<'a> {
}

pub(crate) fn is_elementwise_rec_cached(
expr_key: IRNodeKey,
expr_key: ExprNodeKey,
arena: &Arena<AExpr>,
cache: &mut ExprCache,
) -> bool {
Expand Down Expand Up @@ -97,10 +97,10 @@ pub(crate) fn is_elementwise_rec_cached(
}

#[recursive::recursive]
fn is_input_independent_rec(
expr_key: IRNodeKey,
pub fn is_input_independent_rec(
expr_key: ExprNodeKey,
arena: &Arena<AExpr>,
cache: &mut PlHashMap<IRNodeKey, bool>,
cache: &mut PlHashMap<ExprNodeKey, bool>,
) -> bool {
if let Some(ret) = cache.get(&expr_key) {
return *ret;
Expand Down Expand Up @@ -207,7 +207,15 @@ fn is_input_independent_rec(
ret
}

fn is_input_independent(expr_key: IRNodeKey, ctx: &mut LowerExprContext) -> bool {
/// Reports whether the expression rooted at `expr_key` is input-independent.
///
/// Thin public entry point: it forwards to [`is_input_independent_rec`], handing it
/// the `is_input_independent` map stored inside `cache` so that results computed by
/// the recursive check are looked up and recorded there.
///
/// * `expr_key` - node of the expression to inspect.
/// * `expr_arena` - arena holding the `AExpr` nodes that `expr_key` refers to.
/// * `cache` - expression cache whose `is_input_independent` map is used.
pub fn is_input_independent(
    expr_key: ExprNodeKey,
    expr_arena: &Arena<AExpr>,
    cache: &mut ExprCache,
) -> bool {
    let memo = &mut cache.is_input_independent;
    is_input_independent_rec(expr_key, expr_arena, memo)
}

fn is_input_independent_ctx(expr_key: ExprNodeKey, ctx: &mut LowerExprContext) -> bool {
is_input_independent_rec(
expr_key,
ctx.expr_arena,
Expand Down Expand Up @@ -359,7 +367,7 @@ fn lower_exprs_with_ctx(
) -> PolarsResult<(PhysStream, Vec<Node>)> {
// We have to catch this case separately, in case all the input independent expressions are elementwise.
// TODO: we shouldn't always do this when recursing, e.g. pl.col.a.sum() + 1 will still hit this in the recursion.
if exprs.iter().all(|e| is_input_independent(*e, ctx)) {
if exprs.iter().all(|e| is_input_independent_ctx(*e, ctx)) {
let expr_irs = exprs
.iter()
.map(|e| ExprIR::new(*e, OutputName::Alias(unique_column_name())))
Expand All @@ -384,7 +392,7 @@ fn lower_exprs_with_ctx(

for expr in exprs.iter().copied() {
if is_elementwise_rec_cached(expr, ctx.expr_arena, ctx.cache) {
if !is_input_independent(expr, ctx) {
if !is_input_independent_ctx(expr, ctx) {
input_streams.insert(input);
}
transformed_exprs.push(expr);
Expand Down Expand Up @@ -679,7 +687,10 @@ fn build_select_stream_with_ctx(
exprs: &[ExprIR],
ctx: &mut LowerExprContext,
) -> PolarsResult<PhysStream> {
if exprs.iter().all(|e| is_input_independent(e.node(), ctx)) {
if exprs
.iter()
.all(|e| is_input_independent_ctx(e.node(), ctx))
{
return Ok(PhysStream::first(build_input_independent_node_with_ctx(
exprs, ctx,
)?));
Expand All @@ -696,8 +707,7 @@ fn build_select_stream_with_ctx(

if let Some(columns) = all_simple_columns {
let input_schema = ctx.phys_sm[input.node].output_schema.clone();
if !cfg!(debug_assertions)
&& input_schema.len() == columns.len()
if input_schema.len() == columns.len()
&& input_schema.iter_names().zip(&columns).all(|(l, r)| l == r)
{
// Input node already has the correct schema, just pass through.
Expand Down
Loading
Loading