From 760ab20367a39529e9bcb37e24f595d0f71d2828 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Tue, 10 Sep 2024 14:15:07 +0200 Subject: [PATCH] chore: Don't raise on multiple same names in ie_join (#18658) --- .../polars-plan/src/plans/conversion/join.rs | 21 ++++++++++++------- .../unit/operations/test_inequality_join.py | 16 ++++++++++++++ 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 0a073934b5d3..e7199b2c13d2 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -90,13 +90,20 @@ pub fn resolve_join( let left_on = to_expr_irs_ignore_alias(left_on, ctxt.expr_arena)?; let right_on = to_expr_irs_ignore_alias(right_on, ctxt.expr_arena)?; let mut joined_on = PlHashSet::new(); - for (l, r) in left_on.iter().zip(right_on.iter()) { - polars_ensure!( - joined_on.insert((l.output_name(), r.output_name())), - InvalidOperation: "joining with repeated key names; already joined on {} and {}", - l.output_name(), - r.output_name() - ) + + #[cfg(feature = "iejoin")] + let check = !matches!(options.args.how, JoinType::IEJoin(_)); + #[cfg(not(feature = "iejoin"))] + let check = true; + if check { + for (l, r) in left_on.iter().zip(right_on.iter()) { + polars_ensure!( + joined_on.insert((l.output_name(), r.output_name())), + InvalidOperation: "joining with repeated key names; already joined on {} and {}", + l.output_name(), + r.output_name() + ) + } } drop(joined_on); diff --git a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py index 94839c39fb4f..cc4ea5c6bb02 100644 --- a/py-polars/tests/unit/operations/test_inequality_join.py +++ b/py-polars/tests/unit/operations/test_inequality_join.py @@ -466,3 +466,19 @@ def test_raise_invalid_input_join_where() -> None: df = pl.DataFrame({"id": [1, 2]}) with pytest.raises(pl.exceptions.InvalidOperationError): df.join_where(df) + + +def test_ie_join_use_keys_multiple() -> None: + a = pl.LazyFrame({"a": [1, 2, 3], "x": [7, 2, 1]}) + b = pl.LazyFrame({"b": [2, 2, 2], "x": [7, 1, 3]}) + + assert a.join_where( + b, + pl.col.a >= pl.col.b, + pl.col.a <= pl.col.b, + ).collect().sort("x_right").to_dict(as_series=False) == { + "a": [2, 2, 2], + "x": [2, 2, 2], + "b": [2, 2, 2], + "x_right": [1, 3, 7], + }