diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index ced2de5e7eb5..a8c48cd17fb8 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -393,6 +393,29 @@ impl Operator { ) } + pub fn swap_operands(self) -> Self { + match self { + Operator::Eq => Operator::Eq, + Operator::Gt => Operator::Lt, + Operator::GtEq => Operator::LtEq, + Operator::LtEq => Operator::GtEq, + Operator::Or => Operator::Or, + Operator::LogicalAnd => Operator::LogicalAnd, + Operator::LogicalOr => Operator::LogicalOr, + Operator::Xor => Operator::Xor, + Operator::NotEq => Operator::NotEq, + Operator::EqValidity => Operator::EqValidity, + Operator::NotEqValidity => Operator::NotEqValidity, + Operator::Divide => Operator::Multiply, + Operator::Multiply => Operator::Divide, + Operator::And => Operator::And, + Operator::Plus => Operator::Minus, + Operator::Minus => Operator::Plus, + Operator::Lt => Operator::Gt, + _ => unimplemented!(), + } + } + pub fn is_arithmetic(&self) -> bool { !(self.is_comparison()) } diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index 5ab23be19a14..658dc9989eb5 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -558,8 +558,8 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult options, } => { return join::resolve_join( - input_left, - input_right, + Either::Left(input_left), + Either::Left(input_right), left_on, right_on, predicates, diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 6c3e28bb6c7a..012a801b883a 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -1,4 +1,5 @@ use arrow::legacy::error::PolarsResult; +use either::Either; use super::*; use crate::dsl::Expr; @@ -16,8 +17,8 @@ fn check_join_keys(keys: &[Expr]) -> PolarsResult<()> { Ok(()) } pub fn resolve_join( - input_left: Arc, - input_right: Arc, + input_left: Either, Node>, + input_right: Either, Node>, left_on: Vec, right_on: Vec, predicates: Vec, @@ -26,7 +27,13 @@ pub fn resolve_join( ) -> PolarsResult { if !predicates.is_empty() { debug_assert!(left_on.is_empty() && right_on.is_empty()); - return resolve_join_where(input_left, input_right, predicates, options, ctxt); + return resolve_join_where( + input_left.unwrap_left(), + input_right.unwrap_left(), + predicates, + options, + ctxt, + ); } let owned = Arc::unwrap_or_clone; @@ -62,10 +69,12 @@ pub fn resolve_join( ); } - let input_left = - to_alp_impl(owned(input_left), ctxt).map_err(|e| e.context(failed_input!(join left)))?; - let input_right = - to_alp_impl(owned(input_right), ctxt).map_err(|e| e.context(failed_input!(join, right)))?; + let input_left = input_left.map_right(Ok).right_or_else(|input| { + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(join left))) + })?; + let input_right = input_right.map_right(Ok).right_or_else(|input| { + to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(join right))) + })?; let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena); let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena); @@ -129,7 +138,6 @@ fn resolve_join_where( ctxt: &mut DslConversionContext, ) -> PolarsResult { check_join_keys(&predicates)?; - for e in &predicates { let no_binary_comparisons = e .into_iter() @@ -138,15 +146,40 @@ fn resolve_join_where( _ => false, }) .count(); - polars_ensure!(no_binary_comparisons == 1, InvalidOperation: "only 1 binary comparison allowed as join condition") + polars_ensure!(no_binary_comparisons == 1, InvalidOperation: "only 1 binary comparison allowed as join condition"); } + let input_left = to_alp_impl(Arc::unwrap_or_clone(input_left), ctxt) + .map_err(|e| e.context(failed_input!(join left)))?; + let input_right = to_alp_impl(Arc::unwrap_or_clone(input_right), ctxt) + .map_err(|e| e.context(failed_input!(join left)))?; + + let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena); + let schema_right = ctxt + .lp_arena + .get(input_right) + .schema(ctxt.lp_arena) + .into_owned(); let owned = |e: Arc| (*e).clone(); - // Partition to: + // We do a few things + // First weartition to: // - IEjoin supported inequality predicates // - equality predicates // - remaining predicates + // And then decide to which join we dispatch. + // The remaining predicates will be applied as filter. + + // What make things a bit complicated is that duplicate join names + // are referred to in the query with the name post-join, but on joins + // we refer to the names pre-join (e.g. without suffix). So there is some + // bookkeeping. + // + // - First we determine which side of the binary expression refers to the left and right table + // and make sure that lhs of the binary expr, maps to the lhs of the join tables and vice versa. + // Next we ensure the suffixes are removed when we partition. + // + // If a predicate has to be applied as post-join filter, we put the suffixes back if needed. let mut ie_left_on = vec![]; let mut ie_right_on = vec![]; let mut ie_op = vec![]; @@ -166,37 +199,110 @@ fn resolve_join_where( } } + fn rename_expr(e: Expr, old: &str, new: &str) -> Expr { + e.map_expr(|e| match e { + Expr::Column(name) if name.as_str() == old => Expr::Column(new.into()), + e => e, + }) + } + + fn determine_order_and_pre_join_names( + left: Expr, + op: Operator, + right: Expr, + schema_left: &Schema, + schema_right: &Schema, + suffix: &str, + ) -> PolarsResult<(Expr, Operator, Expr)> { + let left_names = expr_to_leaf_column_names_iter(&left).collect::>(); + let right_names = expr_to_leaf_column_names_iter(&right).collect::>(); + + // All left should be in the left schema. + let (left_names, right_names, left, op, mut right) = + if !left_names.iter().all(|n| schema_left.contains(n)) { + // If all right names are in left schema -> swap + if right_names.iter().all(|n| schema_left.contains(n)) { + (right_names, left_names, right, op.swap_operands(), left) + } else { + polars_bail!(InvalidOperation: "got ambiguous column names in 'join_where'") + } + } else { + (left_names, right_names, left, op, right) + }; + for name in &left_names { + polars_ensure!(!right_names.contains(name.as_str()), InvalidOperation: "got ambiguous column names in 'join_where'\n\n\ + Note that you should refer to the column names as they are post-join operation.") + } + + // Now we know left belongs to the left schema, rhs suffixes are dealt with. + for post_join_name in right_names { + if let Some(pre_join_name) = post_join_name.strip_suffix(suffix) { + // Name is both sides, so a suffix will be added by the join. + // We rename + if schema_right.contains(pre_join_name) && schema_left.contains(pre_join_name) { + right = rename_expr(right, &post_join_name, pre_join_name); + } + } + } + Ok((left, op, right)) + } + + // Make it a binary comparison and ensure the columns refer to post join names. + fn to_binary_post_join( + l: Expr, + op: Operator, + mut r: Expr, + schema_right: &Schema, + suffix: &str, + ) -> Expr { + let names = expr_to_leaf_column_names_iter(&r).collect::>(); + for pre_join_name in &names { + if !schema_right.contains(pre_join_name) { + let post_join_name = _join_suffix_name(pre_join_name, suffix); + r = rename_expr(r, pre_join_name, post_join_name.as_str()); + } + } + + Expr::BinaryExpr { + left: Arc::from(l), + op, + right: Arc::from(r), + } + } + + let suffix = options.args.suffix().clone(); for pred in predicates.into_iter() { let Expr::BinaryExpr { left, op, right } = pred.clone() else { polars_bail!(InvalidOperation: "can only join on binary expressions") }; polars_ensure!(op.is_comparison(), InvalidOperation: "expected comparison in join predicate"); + let (left, op, right) = determine_order_and_pre_join_names( + owned(left), + op, + owned(right), + &schema_left, + &schema_right, + &suffix, + )?; if let Some(ie_op_) = to_inequality_operator(&op) { // We already have an IEjoin or an Inner join, push to remaining if ie_op.len() >= 2 || !eq_right_on.is_empty() { - remaining_preds.push(Expr::BinaryExpr { left, op, right }) + remaining_preds.push(to_binary_post_join(left, op, right, &schema_right, &suffix)) } else { - ie_left_on.push(owned(left)); - ie_right_on.push(owned(right)); + ie_left_on.push(left); + ie_right_on.push(right); ie_op.push(ie_op_) } } else if matches!(op, Operator::Eq) { - eq_left_on.push(owned(left)); - eq_right_on.push(owned(right)); + eq_left_on.push(left); + eq_right_on.push(right); } else { - remaining_preds.push(pred); + remaining_preds.push(to_binary_post_join(left, op, right, &schema_right, &suffix)); } } // Now choose a primary join and do the remaining predicates as filters - fn to_binary(l: Expr, op: Operator, r: Expr) -> Expr { - Expr::BinaryExpr { - left: Arc::from(l), - op, - right: Arc::from(r), - } - } // Add the ie predicates to the remaining predicates buffer so that they will be executed in the // filter node. fn ie_predicates_to_remaining( @@ -204,13 +310,15 @@ fn resolve_join_where( ie_left_on: Vec, ie_right_on: Vec, ie_op: Vec, + schema_right: &Schema, + suffix: &str, ) { for ((l, op), r) in ie_left_on .into_iter() .zip(ie_op.into_iter()) .zip(ie_right_on.into_iter()) { - remaining_preds.push(to_binary(l, op.into(), r)) + remaining_preds.push(to_binary_post_join(l, op.into(), r, schema_right, suffix)) } } @@ -218,8 +326,8 @@ fn resolve_join_where( // We found one or more equality predicates. Go into a default equi join // as those are cheapest on avg. let join_node = resolve_join( - input_left, - input_right, + Either::Right(input_left), + Either::Right(input_right), eq_left_on, eq_right_on, vec![], @@ -227,7 +335,14 @@ fn resolve_join_where( ctxt, )?; - ie_predicates_to_remaining(&mut remaining_preds, ie_left_on, ie_right_on, ie_op); + ie_predicates_to_remaining( + &mut remaining_preds, + ie_left_on, + ie_right_on, + ie_op, + &schema_right, + &suffix, + ); join_node } // TODO! once we support single IEjoin predicates, we must add a branch for the singe ie_pred case. @@ -240,8 +355,8 @@ fn resolve_join_where( }); let join_node = resolve_join( - input_left, - input_right, + Either::Right(input_left), + Either::Right(input_right), ie_left_on[..2].to_vec(), ie_right_on[..2].to_vec(), vec![], @@ -258,7 +373,7 @@ fn resolve_join_where( let r = ie_left_on.pop().unwrap(); let op = ie_op.pop().unwrap(); - remaining_preds.push(to_binary(l, op.into(), r)) + remaining_preds.push(to_binary_post_join(l, op.into(), r, &schema_right, &suffix)) } join_node } else { @@ -268,8 +383,8 @@ fn resolve_join_where( opts.args.how = JoinType::Cross; let join_node = resolve_join( - input_left, - input_right, + Either::Right(input_left), + Either::Right(input_right), vec![], vec![], vec![], @@ -277,7 +392,14 @@ fn resolve_join_where( ctxt, )?; // TODO: This can be removed once we support the single IEjoin. - ie_predicates_to_remaining(&mut remaining_preds, ie_left_on, ie_right_on, ie_op); + ie_predicates_to_remaining( + &mut remaining_preds, + ie_left_on, + ie_right_on, + ie_op, + &schema_right, + &suffix, + ); join_node }; @@ -301,8 +423,6 @@ fn resolve_join_where( .schema(ctxt.lp_arena) .into_owned(); - let suffix = options.args.suffix(); - let mut last_node = join_node; // Ensure that the predicates use the proper suffix diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 8a023b4e01ae..2d8710b9f431 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -7118,20 +7118,11 @@ def join_where( DataFrame to join with. *predicates (In)Equality condition to join the two table on. - The left `pl.col(..)` will refer to the left table - and the right `pl.col(..)` - to the right table. - For example: `pl.col("time") >= pl.col("duration")` + When a column name occurs in both tables, the proper suffix must + be applied in the predicate. suffix Suffix to append to columns with a duplicate name. - Notes - ----- - This method is strict about its equality expressions. - Only 1 equality expression is allowed per predicate, where - the lhs `pl.col` refers to the left table in the join, and the - rhs `pl.col` refers to the right table. - Examples -------- >>> east = pl.DataFrame( diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 44978528e272..ec329898441a 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -4581,23 +4581,14 @@ def join_where( Parameters ---------- other - LazyFrame to join with. + DataFrame to join with. *predicates (In)Equality condition to join the two table on. - The left `pl.col(..)` will refer to the left table - and the right `pl.col(..)` - to the right table. - For example: `pl.col("time") >= pl.col("duration")` + When a column name occurs in both tables, the proper suffix must + be applied in the predicate. suffix Suffix to append to columns with a duplicate name. - Notes - ----- - This method is strict about its equality expressions. - Only 1 equality expression is allowed per predicate, where - the lhs `pl.col` refers to the left table in the join, and the - rhs `pl.col` refers to the right table. - Examples -------- >>> east = pl.LazyFrame( diff --git a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py index 7b3ddb279fec..242a0e1f5e8e 100644 --- a/py-polars/tests/unit/operations/test_inequality_join.py +++ b/py-polars/tests/unit/operations/test_inequality_join.py @@ -15,7 +15,14 @@ from hypothesis.strategies import DrawFn, SearchStrategy -def test_self_join() -> None: +@pytest.mark.parametrize( + ("pred_1", "pred_2"), + [ + (pl.col("time") > pl.col("time_right"), pl.col("cost") < pl.col("cost_right")), + (pl.col("time_right") < pl.col("time"), pl.col("cost_right") > pl.col("cost")), + ], +) +def test_self_join(pred_1: pl.Expr, pred_2: pl.Expr) -> None: west = pl.DataFrame( { "t_id": [404, 498, 676, 742], @@ -25,9 +32,7 @@ def test_self_join() -> None: } ) - actual = west.join_where( - west, pl.col("time") > pl.col("time"), pl.col("cost") < pl.col("cost") - ) + actual = west.join_where(west, pred_1, pred_2) expected = pl.DataFrame( { @@ -223,7 +228,7 @@ def test_join_where_predicates() -> None: right.lazy(), pl.col("time") >= pl.col("start_time"), pl.col("time") < pl.col("end_time"), - pl.col("group") == pl.col("group"), + pl.col("group_right") == pl.col("group"), ) .select("id", "id_right", "group") .sort("id") @@ -252,7 +257,7 @@ def test_join_where_predicates() -> None: right.lazy(), pl.col("time") >= pl.col("start_time"), pl.col("time") < pl.col("end_time"), - pl.col("group") != pl.col("group"), + pl.col("group") != pl.col("group_right"), ) .select("id", "id_right", "group") .sort("id") @@ -279,7 +284,7 @@ def test_join_where_predicates() -> None: left.lazy() .join_where( right.lazy(), - pl.col("group") != pl.col("group"), + pl.col("group") != pl.col("group_right"), ) .select("id", "group", "group_right") .sort("id") @@ -443,10 +448,10 @@ def test_ie_join_with_floats( assert_frame_equal(actual, expected, check_row_order=False, check_exact=True) -def test_raise_on_suffixed_predicate_18604() -> None: +def test_raise_on_ambiguous_name() -> None: df = pl.DataFrame({"id": [1, 2]}) - with pytest.raises(pl.exceptions.ColumnNotFoundError): - df.join_where(df, pl.col("id") >= pl.col("id_right")) + with pytest.raises(pl.exceptions.InvalidOperationError): + df.join_where(df, pl.col("id") >= pl.col("id")) def test_raise_on_multiple_binary_comparisons() -> None: