Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Change join_where semantics #18640

Merged
merged 2 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions crates/polars-plan/src/dsl/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,29 @@ impl Operator {
)
}

pub fn swap_operands(self) -> Self {
match self {
Operator::Eq => Operator::Eq,
Operator::Gt => Operator::Lt,
Operator::GtEq => Operator::LtEq,
Operator::LtEq => Operator::GtEq,
Operator::Or => Operator::Or,
Operator::LogicalAnd => Operator::LogicalAnd,
Operator::LogicalOr => Operator::LogicalOr,
Operator::Xor => Operator::Xor,
Operator::NotEq => Operator::NotEq,
Operator::EqValidity => Operator::EqValidity,
Operator::NotEqValidity => Operator::NotEqValidity,
Operator::Divide => Operator::Multiply,
Operator::Multiply => Operator::Divide,
Operator::And => Operator::And,
Operator::Plus => Operator::Minus,
Operator::Minus => Operator::Plus,
Operator::Lt => Operator::Gt,
_ => unimplemented!(),
}
}

/// Returns `true` exactly when `is_comparison()` returns `false` for this operator.
pub fn is_arithmetic(&self) -> bool {
    let is_cmp = self.is_comparison();
    !is_cmp
}
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -558,8 +558,8 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult
options,
} => {
return join::resolve_join(
input_left,
input_right,
Either::Left(input_left),
Either::Left(input_right),
left_on,
right_on,
predicates,
Expand Down
190 changes: 155 additions & 35 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use arrow::legacy::error::PolarsResult;
use either::Either;

use super::*;
use crate::dsl::Expr;
Expand All @@ -16,8 +17,8 @@ fn check_join_keys(keys: &[Expr]) -> PolarsResult<()> {
Ok(())
}
pub fn resolve_join(
input_left: Arc<DslPlan>,
input_right: Arc<DslPlan>,
input_left: Either<Arc<DslPlan>, Node>,
input_right: Either<Arc<DslPlan>, Node>,
left_on: Vec<Expr>,
right_on: Vec<Expr>,
predicates: Vec<Expr>,
Expand All @@ -26,7 +27,13 @@ pub fn resolve_join(
) -> PolarsResult<Node> {
if !predicates.is_empty() {
debug_assert!(left_on.is_empty() && right_on.is_empty());
return resolve_join_where(input_left, input_right, predicates, options, ctxt);
return resolve_join_where(
input_left.unwrap_left(),
input_right.unwrap_left(),
predicates,
options,
ctxt,
);
}

let owned = Arc::unwrap_or_clone;
Expand Down Expand Up @@ -62,10 +69,12 @@ pub fn resolve_join(
);
}

let input_left =
to_alp_impl(owned(input_left), ctxt).map_err(|e| e.context(failed_input!(join left)))?;
let input_right =
to_alp_impl(owned(input_right), ctxt).map_err(|e| e.context(failed_input!(join, right)))?;
let input_left = input_left.map_right(Ok).right_or_else(|input| {
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(join left)))
})?;
let input_right = input_right.map_right(Ok).right_or_else(|input| {
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_input!(join right)))
})?;

let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);
Expand Down Expand Up @@ -129,7 +138,6 @@ fn resolve_join_where(
ctxt: &mut DslConversionContext,
) -> PolarsResult<Node> {
check_join_keys(&predicates)?;

for e in &predicates {
let no_binary_comparisons = e
.into_iter()
Expand All @@ -138,15 +146,40 @@ fn resolve_join_where(
_ => false,
})
.count();
polars_ensure!(no_binary_comparisons == 1, InvalidOperation: "only 1 binary comparison allowed as join condition")
polars_ensure!(no_binary_comparisons == 1, InvalidOperation: "only 1 binary comparison allowed as join condition");
}
let input_left = to_alp_impl(Arc::unwrap_or_clone(input_left), ctxt)
.map_err(|e| e.context(failed_input!(join left)))?;
let input_right = to_alp_impl(Arc::unwrap_or_clone(input_right), ctxt)
.map_err(|e| e.context(failed_input!(join left)))?;

let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
let schema_right = ctxt
.lp_arena
.get(input_right)
.schema(ctxt.lp_arena)
.into_owned();

let owned = |e: Arc<Expr>| (*e).clone();

// Partition to:
// We do a few things
// First we partition to:
ritchie46 marked this conversation as resolved.
Show resolved Hide resolved
// - IEjoin supported inequality predicates
// - equality predicates
// - remaining predicates
// And then decide to which join we dispatch.
// The remaining predicates will be applied as filter.

// What makes things a bit complicated is that duplicate join names
// are referred to in the query with the name post-join, but on joins
// we refer to the names pre-join (e.g. without suffix). So there is some
// bookkeeping.
//
// - First we determine which side of the binary expression refers to the left table and which to the right table,
// and make sure that the lhs of the binary expr maps to the lhs of the join tables and vice versa.
// - Next we ensure the suffixes are removed when we partition.
//
// If a predicate has to be applied as post-join filter, we put the suffixes back if needed.
let mut ie_left_on = vec![];
let mut ie_right_on = vec![];
let mut ie_op = vec![];
Expand All @@ -166,68 +199,150 @@ fn resolve_join_where(
}
}

/// Returns `e` with every `Expr::Column` whose name equals `old` replaced by
/// a column named `new`. All other nodes handed to the closure are returned
/// untouched.
fn rename_expr(e: Expr, old: &str, new: &str) -> Expr {
    e.map_expr(|node| {
        if let Expr::Column(name) = &node {
            if name.as_str() == old {
                return Expr::Column(new.into());
            }
        }
        node
    })
}

/// Orients a binary predicate so that its left operand refers to the left
/// join input, and rewrites suffixed column names in the right operand to
/// their pre-join (unsuffixed) form.
///
/// Returns the (possibly swapped) `(left, op, right)` triple. When the
/// operands are swapped, the operator is adjusted with `swap_operands`.
///
/// # Errors
///
/// Returns `InvalidOperation` when neither operand consists solely of columns
/// of `schema_left`, or when the same column name occurs in both operands.
fn determine_order_and_pre_join_names(
    left: Expr,
    op: Operator,
    right: Expr,
    schema_left: &Schema,
    schema_right: &Schema,
    suffix: &str,
) -> PolarsResult<(Expr, Operator, Expr)> {
    let left_names = expr_to_leaf_column_names_iter(&left).collect::<PlHashSet<_>>();
    let right_names = expr_to_leaf_column_names_iter(&right).collect::<PlHashSet<_>>();

    // The lhs operand must be fully resolvable against the left schema.
    // If it is not, the only accepted alternative is that the rhs operand is,
    // in which case both operands (and the operator direction) are swapped.
    let lhs_in_left_table = left_names.iter().all(|n| schema_left.contains(n));
    let (left_names, right_names, left, op, mut right) = if lhs_in_left_table {
        (left_names, right_names, left, op, right)
    } else {
        polars_ensure!(
            right_names.iter().all(|n| schema_left.contains(n)),
            InvalidOperation: "got ambiguous column names in 'join_where'"
        );
        (right_names, left_names, right, op.swap_operands(), left)
    };

    // A name that shows up on both sides cannot be attributed to one table.
    polars_ensure!(
        !left_names.iter().any(|name| right_names.contains(name.as_str())),
        InvalidOperation: "got ambiguous column names in 'join_where'\n\n\
        Note that you should refer to the column names as they are post-join operation."
    );

    // From here on `left` belongs to the left schema; strip the join suffix
    // from rhs column names where the join itself would have added it.
    for post_join_name in right_names {
        let Some(pre_join_name) = post_join_name.strip_suffix(suffix) else {
            continue;
        };
        // The suffix is only added by the join when the unsuffixed name is
        // present in both inputs, so only those names are renamed.
        let on_both_sides =
            schema_right.contains(pre_join_name) && schema_left.contains(pre_join_name);
        if on_both_sides {
            right = rename_expr(right, &post_join_name, pre_join_name);
        }
    }
    Ok((left, op, right))
}

/// Assembles `l op r` into an `Expr::BinaryExpr`, first renaming leaf columns
/// of `r` to their suffixed (post-join) name where needed.
///
/// A leaf column of `r` gets `suffix` appended (via `_join_suffix_name`) when
/// its name is not found in `schema_right`.
// NOTE(review): the suffix is added when the name is *absent* from
// `schema_right`; confirm this is the intended condition.
fn to_binary_post_join(
    l: Expr,
    op: Operator,
    mut r: Expr,
    schema_right: &Schema,
    suffix: &str,
) -> Expr {
    // Collected up front because `r` is reassigned inside the loop.
    let leaf_names: Vec<_> = expr_to_leaf_column_names_iter(&r).collect();
    for pre_join_name in &leaf_names {
        if schema_right.contains(pre_join_name) {
            continue;
        }
        let post_join_name = _join_suffix_name(pre_join_name, suffix);
        r = rename_expr(r, pre_join_name, post_join_name.as_str());
    }

    let left = Arc::new(l);
    let right = Arc::new(r);
    Expr::BinaryExpr { left, op, right }
}

let suffix = options.args.suffix().clone();
for pred in predicates.into_iter() {
let Expr::BinaryExpr { left, op, right } = pred.clone() else {
polars_bail!(InvalidOperation: "can only join on binary expressions")
};
polars_ensure!(op.is_comparison(), InvalidOperation: "expected comparison in join predicate");
let (left, op, right) = determine_order_and_pre_join_names(
owned(left),
op,
owned(right),
&schema_left,
&schema_right,
&suffix,
)?;

if let Some(ie_op_) = to_inequality_operator(&op) {
// We already have an IEjoin or an Inner join, push to remaining
if ie_op.len() >= 2 || !eq_right_on.is_empty() {
remaining_preds.push(Expr::BinaryExpr { left, op, right })
remaining_preds.push(to_binary_post_join(left, op, right, &schema_right, &suffix))
} else {
ie_left_on.push(owned(left));
ie_right_on.push(owned(right));
ie_left_on.push(left);
ie_right_on.push(right);
ie_op.push(ie_op_)
}
} else if matches!(op, Operator::Eq) {
eq_left_on.push(owned(left));
eq_right_on.push(owned(right));
eq_left_on.push(left);
eq_right_on.push(right);
} else {
remaining_preds.push(pred);
remaining_preds.push(to_binary_post_join(left, op, right, &schema_right, &suffix));
}
}

// Now choose a primary join and do the remaining predicates as filters
/// Wraps two operand expressions and an operator into an `Expr::BinaryExpr`.
fn to_binary(l: Expr, op: Operator, r: Expr) -> Expr {
    let (left, right) = (Arc::new(l), Arc::new(r));
    Expr::BinaryExpr { left, op, right }
}
// Move the collected inequality (IE) predicates into the remaining-predicates buffer so that
// they will be executed in the filter node.
//
// `ie_left_on`, `ie_op` and `ie_right_on` are consumed in lockstep: the i-th entries of the
// three vectors are combined into one binary expression that is pushed onto `remaining_preds`.
// `schema_right` and `suffix` are forwarded to `to_binary_post_join`.
fn ie_predicates_to_remaining(
remaining_preds: &mut Vec<Expr>,
ie_left_on: Vec<Expr>,
ie_right_on: Vec<Expr>,
ie_op: Vec<InequalityOperator>,
schema_right: &Schema,
suffix: &str,
) {
// Zip order is (left, op) first, then right, which gives the `((l, op), r)` pattern below.
for ((l, op), r) in ie_left_on
.into_iter()
.zip(ie_op.into_iter())
.zip(ie_right_on.into_iter())
{
// NOTE(review): the two consecutive `push` lines appear to be the removed/added pair of
// this diff view (`to_binary` -> `to_binary_post_join`); only one of them belongs in the
// final code.
remaining_preds.push(to_binary(l, op.into(), r))
remaining_preds.push(to_binary_post_join(l, op.into(), r, schema_right, suffix))
}
}

let join_node = if !eq_left_on.is_empty() {
// We found one or more equality predicates. Go into a default equi join
// as those are cheapest on avg.
let join_node = resolve_join(
input_left,
input_right,
Either::Right(input_left),
Either::Right(input_right),
eq_left_on,
eq_right_on,
vec![],
options.clone(),
ctxt,
)?;

ie_predicates_to_remaining(&mut remaining_preds, ie_left_on, ie_right_on, ie_op);
ie_predicates_to_remaining(
&mut remaining_preds,
ie_left_on,
ie_right_on,
ie_op,
&schema_right,
&suffix,
);
join_node
}
// TODO! once we support single IEjoin predicates, we must add a branch for the single ie_pred case.
Expand All @@ -240,8 +355,8 @@ fn resolve_join_where(
});

let join_node = resolve_join(
input_left,
input_right,
Either::Right(input_left),
Either::Right(input_right),
ie_left_on[..2].to_vec(),
ie_right_on[..2].to_vec(),
vec![],
Expand All @@ -258,7 +373,7 @@ fn resolve_join_where(
let r = ie_left_on.pop().unwrap();
let op = ie_op.pop().unwrap();

remaining_preds.push(to_binary(l, op.into(), r))
remaining_preds.push(to_binary_post_join(l, op.into(), r, &schema_right, &suffix))
}
join_node
} else {
Expand All @@ -268,16 +383,23 @@ fn resolve_join_where(
opts.args.how = JoinType::Cross;

let join_node = resolve_join(
input_left,
input_right,
Either::Right(input_left),
Either::Right(input_right),
vec![],
vec![],
vec![],
options.clone(),
ctxt,
)?;
// TODO: This can be removed once we support the single IEjoin.
ie_predicates_to_remaining(&mut remaining_preds, ie_left_on, ie_right_on, ie_op);
ie_predicates_to_remaining(
&mut remaining_preds,
ie_left_on,
ie_right_on,
ie_op,
&schema_right,
&suffix,
);
join_node
};

Expand All @@ -301,8 +423,6 @@ fn resolve_join_where(
.schema(ctxt.lp_arena)
.into_owned();

let suffix = options.args.suffix();

let mut last_node = join_node;

// Ensure that the predicates use the proper suffix
Expand Down
13 changes: 2 additions & 11 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7118,20 +7118,11 @@ def join_where(
DataFrame to join with.
*predicates
(In)Equality condition to join the two table on.
The left `pl.col(..)` will refer to the left table
and the right `pl.col(..)`
to the right table.
For example: `pl.col("time") >= pl.col("duration")`
When a column name occurs in both tables, the proper suffix must
be applied in the predicate.
suffix
Suffix to append to columns with a duplicate name.

Notes
-----
This method is strict about its equality expressions.
Only 1 equality expression is allowed per predicate, where
the lhs `pl.col` refers to the left table in the join, and the
rhs `pl.col` refers to the right table.

Examples
--------
>>> east = pl.DataFrame(
Expand Down
Loading
Loading