From 8b65fe3bad10856ae9e618523caab774cd735ccb Mon Sep 17 00:00:00 2001 From: barak1412 Date: Fri, 30 Aug 2024 11:48:30 +0300 Subject: [PATCH] perf: Added optimizer rules for `is_null().all()` and similar expressions to use `null_count()` (#18359) --- crates/polars-plan/src/plans/aexpr/mod.rs | 13 ++ crates/polars-plan/src/plans/lit.rs | 11 + .../src/plans/optimizer/simplify_expr.rs | 69 ++++++ .../src/plans/optimizer/simplify_functions.rs | 81 +++++++ .../unit/lazyframe/test_optimizations.py | 206 ++++++++++++++++++ 5 files changed, 380 insertions(+) create mode 100644 py-polars/tests/unit/lazyframe/test_optimizations.py diff --git a/crates/polars-plan/src/plans/aexpr/mod.rs b/crates/polars-plan/src/plans/aexpr/mod.rs index 486dfa6c19e5..3cbf9e1519be 100644 --- a/crates/polars-plan/src/plans/aexpr/mod.rs +++ b/crates/polars-plan/src/plans/aexpr/mod.rs @@ -427,6 +427,19 @@ impl AExpr { pub(crate) fn is_leaf(&self) -> bool { matches!(self, AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len) } + pub(crate) fn new_null_count(input: &[ExprIR]) -> Self { + AExpr::Function { + input: input.to_vec(), + function: FunctionExpr::NullCount, + options: FunctionOptions { + collect_groups: ApplyOptions::GroupWise, + fmt_str: "", + cast_to_supertypes: None, + check_lengths: UnsafeBool::default(), + flags: FunctionFlags::ALLOW_GROUP_AWARE | FunctionFlags::RETURNS_SCALAR, + }, + } + } } impl IRAggExpr { diff --git a/crates/polars-plan/src/plans/lit.rs b/crates/polars-plan/src/plans/lit.rs index 060bbf0fd460..9489415f5cda 100644 --- a/crates/polars-plan/src/plans/lit.rs +++ b/crates/polars-plan/src/plans/lit.rs @@ -220,6 +220,17 @@ impl LiteralValue { LiteralValue::StrCat(_) => DataType::Unknown(UnknownKind::Str), } } + + pub(crate) fn new_idxsize(value: IdxSize) -> Self { + #[cfg(feature = "bigidx")] + { + LiteralValue::UInt64(value) + } + #[cfg(not(feature = "bigidx"))] + { + LiteralValue::UInt32(value) + } + } } pub trait Literal { diff --git a/crates/polars-plan/src/plans/optimizer/simplify_expr.rs b/crates/polars-plan/src/plans/optimizer/simplify_expr.rs index db03f516506c..aec74543a5dc 100644 --- a/crates/polars-plan/src/plans/optimizer/simplify_expr.rs +++ b/crates/polars-plan/src/plans/optimizer/simplify_expr.rs @@ -440,6 +440,75 @@ impl OptimizationRule for SimplifyExprRule { let expr = expr_arena.get(expr_node).clone(); let out = match &expr { + // drop_nulls().len() -> len() - null_count() + // drop_nulls().count() -> len() - null_count() + AExpr::Agg(IRAggExpr::Count(input, _)) => { + let input_expr = expr_arena.get(*input); + match input_expr { + AExpr::Function { + input, + function: FunctionExpr::DropNulls, + options: _, + } => { + // we should perform optimization only if the original expression is a column + // so in case of disabled CSE, we will not suffer from performance regression + if input.len() == 1 { + let drop_nulls_input_node = input[0].node(); + match expr_arena.get(drop_nulls_input_node) { + AExpr::Column(_) => Some(AExpr::BinaryExpr { + op: Operator::Minus, + right: expr_arena.add(AExpr::new_null_count(input)), + left: expr_arena.add(AExpr::Agg(IRAggExpr::Count( + drop_nulls_input_node, + true, + ))), + }), + _ => None, + } + } else { + None + } + }, + _ => None, + } + }, + // is_null().sum() -> null_count() + // is_not_null().sum() -> len() - null_count() + AExpr::Agg(IRAggExpr::Sum(input)) => { + let input_expr = expr_arena.get(*input); + match input_expr { + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNull), + options: _, + } => Some(AExpr::new_null_count(input)), + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNotNull), + options: _, + } => { + // we should perform optimization only if the original expression is a column + // so in case of disabled CSE, we will not suffer from performance regression + if input.len() == 1 { + let is_not_null_input_node = input[0].node(); + match expr_arena.get(is_not_null_input_node) { + AExpr::Column(_) => Some(AExpr::BinaryExpr { + op: Operator::Minus, + right: expr_arena.add(AExpr::new_null_count(input)), + left: expr_arena.add(AExpr::Agg(IRAggExpr::Count( + is_not_null_input_node, + true, + ))), + }), + _ => None, + } + } else { + None + } + }, + _ => None, + } + }, // lit(left) + lit(right) => lit(left + right) // and null propagation AExpr::BinaryExpr { left, op, right } => { diff --git a/crates/polars-plan/src/plans/optimizer/simplify_functions.rs b/crates/polars-plan/src/plans/optimizer/simplify_functions.rs index 504af2e517f9..3d9518193276 100644 --- a/crates/polars-plan/src/plans/optimizer/simplify_functions.rs +++ b/crates/polars-plan/src/plans/optimizer/simplify_functions.rs @@ -7,6 +7,87 @@ pub(super) fn optimize_functions( expr_arena: &mut Arena, ) -> PolarsResult> { let out = match function { + // is_null().any() -> null_count() > 0 + // is_not_null().any() -> null_count() < len() + // CORRECTNESS: we can ignore 'ignore_nulls' since is_null/is_not_null never produces NULLS + FunctionExpr::Boolean(BooleanFunction::Any { ignore_nulls: _ }) => { + let input_node = expr_arena.get(input[0].node()); + match input_node { + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNull), + options: _, + } => Some(AExpr::BinaryExpr { + left: expr_arena.add(AExpr::new_null_count(input)), + op: Operator::Gt, + right: expr_arena.add(AExpr::Literal(LiteralValue::new_idxsize(0))), + }), + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNotNull), + options: _, + } => { + // we should perform optimization only if the original expression is a column + // so in case of disabled CSE, we will not suffer from performance regression + if input.len() == 1 { + let is_not_null_input_node = input[0].node(); + match expr_arena.get(is_not_null_input_node) { + AExpr::Column(_) => Some(AExpr::BinaryExpr { + op: Operator::Lt, + left: expr_arena.add(AExpr::new_null_count(input)), + right: expr_arena.add(AExpr::Agg(IRAggExpr::Count( + is_not_null_input_node, + true, + ))), + }), + _ => None, + } + } else { + None + } + }, + _ => None, + } + }, + // is_null().all() -> null_count() == len() + // is_not_null().all() -> null_count() == 0 + FunctionExpr::Boolean(BooleanFunction::All { ignore_nulls: _ }) => { + let input_node = expr_arena.get(input[0].node()); + match input_node { + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNull), + options: _, + } => { + // we should perform optimization only if the original expression is a column + // so in case of disabled CSE, we will not suffer from performance regression + if input.len() == 1 { + let is_null_input_node = input[0].node(); + match expr_arena.get(is_null_input_node) { + AExpr::Column(_) => Some(AExpr::BinaryExpr { + op: Operator::Eq, + right: expr_arena.add(AExpr::new_null_count(input)), + left: expr_arena + .add(AExpr::Agg(IRAggExpr::Count(is_null_input_node, true))), + }), + _ => None, + } + } else { + None + } + }, + AExpr::Function { + input, + function: FunctionExpr::Boolean(BooleanFunction::IsNotNull), + options: _, + } => Some(AExpr::BinaryExpr { + left: expr_arena.add(AExpr::new_null_count(input)), + op: Operator::Eq, + right: expr_arena.add(AExpr::Literal(LiteralValue::new_idxsize(0))), + }), + _ => None, + } + }, // sort().reverse() -> sort(reverse) // sort_by().reverse() -> sort_by(reverse) FunctionExpr::Reverse => { diff --git a/py-polars/tests/unit/lazyframe/test_optimizations.py b/py-polars/tests/unit/lazyframe/test_optimizations.py new file mode 100644 index 000000000000..648bd3123787 --- /dev/null +++ b/py-polars/tests/unit/lazyframe/test_optimizations.py @@ -0,0 +1,206 @@ +import polars as pl +from polars.testing import assert_frame_equal + + +def test_is_null_followed_by_all() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1], "val": [6, 0, None, None]}) + + expected_df = pl.DataFrame({"group": [0, 1], "val": [False, True]}) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").is_null().all() + ) + + assert ( + r'[[(col("val").count()) == (col("val").null_count())]]' in result_lf.explain() + ) + assert "is_null" not in result_lf + assert_frame_equal(expected_df, result_lf.collect()) + + # verify we don't optimize on chained expressions when last one is not col + non_optimized_result_plan = ( + lf.group_by("group", maintain_order=True) + .agg(pl.col("val").abs().is_null().all()) + .explain() + ) + assert "null_count" not in non_optimized_result_plan + assert "is_null" in non_optimized_result_plan + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = pl.DataFrame({"val": [True]}) + result_df = lf.select(pl.col("val").is_null().all()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_is_null_followed_by_any() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame({"group": [0, 1, 2], "val": [True, True, False]}) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").is_null().any() + ) + assert_frame_equal(expected_df, result_lf.collect()) + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = pl.DataFrame({"val": [False]}) + result_df = lf.select(pl.col("val").is_null().any()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_is_not_null_followed_by_all() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1], "val": [6, 0, 5, None]}) + + expected_df = pl.DataFrame({"group": [0, 1], "val": [True, False]}) + result_df = ( + lf.group_by("group", maintain_order=True) + .agg(pl.col("val").is_not_null().all()) + .collect() + ) + + assert_frame_equal(expected_df, result_df) + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = pl.DataFrame({"val": [True]}) + result_df = lf.select(pl.col("val").is_not_null().all()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_is_not_null_followed_by_any() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame({"group": [0, 1, 2], "val": [True, False, True]}) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").is_not_null().any() + ) + + assert ( + r'[[(col("val").null_count()) < (col("val").count())]]' in result_lf.explain() + ) + assert "is_not_null" not in result_lf.explain() + assert_frame_equal(expected_df, result_lf.collect()) + + # verify we don't optimize on chained expressions when last one is not col + non_optimized_result_plan = ( + lf.group_by("group", maintain_order=True) + .agg(pl.col("val").abs().is_not_null().any()) + .explain() + ) + assert "null_count" not in non_optimized_result_plan + assert "is_not_null" in non_optimized_result_plan + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = pl.DataFrame({"val": [False]}) + result_df = lf.select(pl.col("val").is_not_null().any()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_is_null_followed_by_sum() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame( + {"group": [0, 1, 2], "val": [1, 1, 0]}, schema_overrides={"val": pl.UInt32} + ) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").is_null().sum() + ) + + assert r'[col("val").null_count()]' in result_lf.explain() + assert "is_null" not in result_lf.explain() + assert_frame_equal(expected_df, result_lf.collect()) + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = pl.DataFrame({"val": [0]}, schema={"val": pl.UInt32}) + result_df = lf.select(pl.col("val").is_null().sum()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_is_not_null_followed_by_sum() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame( + {"group": [0, 1, 2], "val": [2, 0, 1]}, schema_overrides={"val": pl.UInt32} + ) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").is_not_null().sum() + ) + + assert ( + r'[[(col("val").count()) - (col("val").null_count())]]' in result_lf.explain() + ) + assert "is_not_null" not in result_lf.explain() + assert_frame_equal(expected_df, result_lf.collect()) + + # verify we don't optimize on chained expressions when last one is not col + non_optimized_result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").abs().is_not_null().sum() + ) + assert "null_count" not in non_optimized_result_lf.explain() + assert "is_not_null" in non_optimized_result_lf.explain() + + # edge case of empty series + lf = pl.LazyFrame({"val": []}, schema={"val": pl.Int32}) + + expected_df = pl.DataFrame({"val": [0]}, schema={"val": pl.UInt32}) + result_df = lf.select(pl.col("val").is_not_null().sum()).collect() + assert_frame_equal(expected_df, result_df) + + +def test_drop_nulls_followed_by_len() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame( + {"group": [0, 1, 2], "val": [2, 0, 1]}, schema_overrides={"val": pl.UInt32} + ) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").drop_nulls().len() + ) + + assert ( + r'[[(col("val").count()) - (col("val").null_count())]]' in result_lf.explain() + ) + assert "drop_nulls" not in result_lf.explain() + assert_frame_equal(expected_df, result_lf.collect()) + + # verify we don't optimize on chained expressions when last one is not col + non_optimized_result_plan = ( + lf.group_by("group", maintain_order=True) + .agg(pl.col("val").abs().drop_nulls().len()) + .explain() + ) + assert "null_count" not in non_optimized_result_plan + assert "drop_nulls" in non_optimized_result_plan + + +def test_drop_nulls_followed_by_count() -> None: + lf = pl.LazyFrame({"group": [0, 0, 0, 1, 2], "val": [6, 0, None, None, 5]}) + + expected_df = pl.DataFrame( + {"group": [0, 1, 2], "val": [2, 0, 1]}, schema_overrides={"val": pl.UInt32} + ) + result_lf = lf.group_by("group", maintain_order=True).agg( + pl.col("val").drop_nulls().count() + ) + + assert ( + r'[[(col("val").count()) - (col("val").null_count())]]' in result_lf.explain() + ) + assert "drop_nulls" not in result_lf.explain() + assert_frame_equal(expected_df, result_lf.collect()) + + # verify we don't optimize on chained expressions when last one is not col + non_optimized_result_plan = ( + lf.group_by("group", maintain_order=True) + .agg(pl.col("val").abs().drop_nulls().count()) + .explain() + ) + assert "null_count" not in non_optimized_result_plan + assert "drop_nulls" in non_optimized_result_plan