diff --git a/crates/polars-plan/src/plans/aexpr/mod.rs b/crates/polars-plan/src/plans/aexpr/mod.rs index 95dee57901e4..e2a686c37bb4 100644 --- a/crates/polars-plan/src/plans/aexpr/mod.rs +++ b/crates/polars-plan/src/plans/aexpr/mod.rs @@ -263,4 +263,8 @@ impl AExpr { pub(crate) fn is_leaf(&self) -> bool { matches!(self, AExpr::Column(_) | AExpr::Literal(_) | AExpr::Len) } + + pub(crate) fn is_col(&self) -> bool { + matches!(self, AExpr::Column(_)) + } } diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 9c973d446ead..ba0eb2670b14 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -148,55 +148,60 @@ pub fn resolve_join( }; } + // If we do a full join and keys are coalesced, the casted keys must be added up front. + let key_cols_coalesced = + options.args.should_coalesce() && matches!(&options.args.how, JoinType::Full); let mut to_cast_left = vec![]; let mut to_cast_right = vec![]; let mut to_cast_indices = vec![]; - for (i, (lnode, rnode)) in left_on.iter().zip(right_on.iter()).enumerate() { + for (i, (lnode, rnode)) in left_on.iter_mut().zip(right_on.iter_mut()).enumerate() { let ltype = get_dtype!(lnode, &schema_left)?; let rtype = get_dtype!(rnode, &schema_right)?; - if let (AExpr::Column(lname), AExpr::Column(rname)) = ( - ctxt.expr_arena.get(lnode.node()).clone(), - ctxt.expr_arena.get(rnode.node()).clone(), - ) { - if let Some(dtype) = get_numeric_upcast_supertype_lossless(<ype, &rtype) { - { - let expr = ctxt.expr_arena.add(AExpr::Column(lname.clone())); - to_cast_left.push(ExprIR::new( - ctxt.expr_arena.add(AExpr::Cast { - expr, - dtype: dtype.clone(), - options: CastOptions::Strict, - }), - OutputName::ColumnLhs(lname), - )) - }; - - { - let expr = ctxt.expr_arena.add(AExpr::Column(rname.clone())); - to_cast_right.push(ExprIR::new( - ctxt.expr_arena.add(AExpr::Cast { - expr, - dtype: dtype.clone(), - options: CastOptions::Strict, - }), - OutputName::ColumnLhs(rname), - )) - }; + if let Some(dtype) = get_numeric_upcast_supertype_lossless(<ype, &rtype) { + let casted_l = ctxt.expr_arena.add(AExpr::Cast { + expr: lnode.node(), + dtype: dtype.clone(), + options: CastOptions::Strict, + }); + let casted_r = ctxt.expr_arena.add(AExpr::Cast { + expr: rnode.node(), + dtype, + options: CastOptions::Strict, + }); + + if key_cols_coalesced { + let mut lnode = lnode.clone(); + let mut rnode = rnode.clone(); + + let ae_l = ctxt.expr_arena.get(lnode.node()); + let ae_r = ctxt.expr_arena.get(rnode.node()); + + polars_ensure!( + ae_l.is_col() && ae_r.is_col(), + SchemaMismatch: "can only 'coalesce' full join if join keys are column expressions", + ); - to_cast_indices.push(i); + lnode.set_node(casted_l); + rnode.set_node(casted_r); - continue; + to_cast_indices.push(i); + to_cast_right.push(rnode); + to_cast_left.push(lnode); + } else { + lnode.set_node(casted_l); + rnode.set_node(casted_r); } + } else { + polars_ensure!( + ltype == rtype, + SchemaMismatch: "datatypes of join keys don't match - `{}`: {} on left does not match `{}`: {} on right", + lnode.output_name(), ltype, rnode.output_name(), rtype + ) } - - polars_ensure!( - ltype == rtype, - SchemaMismatch: "datatypes of join keys don't match - `{}`: {} on left does not match `{}`: {} on right", - lnode.output_name(), ltype, rnode.output_name(), rtype - ) } + // Every expression must be elementwise so that we are // guaranteed the keys for a join are all the same length. @@ -234,15 +239,6 @@ pub fn resolve_join( options: ProjectionOptions::default(), }) }; - } else { - for ((i, ir_left), ir_right) in to_cast_indices - .into_iter() - .zip(to_cast_left) - .zip(to_cast_right) - { - left_on[i] = ir_left; - right_on[i] = ir_right; - } } let lp = IR::Join {