diff --git a/crates/polars-core/src/frame/column/scalar.rs b/crates/polars-core/src/frame/column/scalar.rs index 1906b2c7a472..a454e35d0353 100644 --- a/crates/polars-core/src/frame/column/scalar.rs +++ b/crates/polars-core/src/frame/column/scalar.rs @@ -143,7 +143,11 @@ impl ScalarColumn { pub fn from_single_value_series(series: Series, length: usize) -> Self { debug_assert!(series.len() <= 1); - let value = series.get(0).map_or(AnyValue::Null, |av| av.into_static()); + let value = if series.is_empty() { + AnyValue::Null + } else { + unsafe { series.get_unchecked(0) }.into_static() + }; let value = Scalar::new(series.dtype().clone(), value); ScalarColumn::new(series.name().clone(), value, length) } diff --git a/crates/polars-plan/src/constants.rs b/crates/polars-plan/src/constants.rs index e63ad1193774..562cc91c4ef1 100644 --- a/crates/polars-plan/src/constants.rs +++ b/crates/polars-plan/src/constants.rs @@ -4,6 +4,7 @@ use polars_utils::pl_str::PlSmallStr; pub static MAP_LIST_NAME: &str = "map_list"; pub static CSE_REPLACED: &str = "__POLARS_CSER_"; +pub static POLARS_TMP_PREFIX: &str = "_POLARS_"; pub const LEN: &str = "len"; const LITERAL_NAME: &str = "literal"; pub const UNLIMITED_CACHE: u32 = u32::MAX; diff --git a/crates/polars-plan/src/plans/builder_ir.rs b/crates/polars-plan/src/plans/builder_ir.rs index f60c1d5ec95a..645c20f9d464 100644 --- a/crates/polars-plan/src/plans/builder_ir.rs +++ b/crates/polars-plan/src/plans/builder_ir.rs @@ -269,21 +269,13 @@ impl<'a> IRBuilder<'a> { let schema_left = self.schema(); let schema_right = self.lp_arena.get(other).schema(self.lp_arena); - let left_on_exprs = left_on - .iter() - .map(|e| e.to_expr(self.expr_arena)) - .collect::>(); - let right_on_exprs = right_on - .iter() - .map(|e| e.to_expr(self.expr_arena)) - .collect::>(); - let schema = det_join_schema( &schema_left, &schema_right, - &left_on_exprs, - &right_on_exprs, + &left_on, + &right_on, &options, + self.expr_arena, ) .unwrap(); diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index 90fa93c5a120..6312a31d7853 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -647,6 +647,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult ctxt, ) .map_err(|e| e.context(failed_here!(join))) + .map(|t| t.0) }, DslPlan::HStack { input, diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index ba0eb2670b14..2d3e2a0f483d 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -3,8 +3,11 @@ use either::Either; use polars_core::chunked_array::cast::CastOptions; use polars_core::error::feature_gated; use polars_core::utils::get_numeric_upcast_supertype_lossless; +use polars_utils::format_pl_smallstr; +use polars_utils::itertools::Itertools; use super::*; +use crate::constants::POLARS_TMP_PREFIX; use crate::dsl::Expr; #[cfg(feature = "iejoin")] use crate::plans::AExpr; @@ -20,6 +23,8 @@ fn check_join_keys(keys: &[Expr]) -> PolarsResult<()> { } Ok(()) } + +/// Returns: left: join_node, right: last_node (often both the same) pub fn resolve_join( input_left: Either, Node>, input_right: Either, Node>, @@ -28,7 +33,7 @@ pub fn resolve_join( predicates: Vec, mut options: Arc, ctxt: &mut DslConversionContext, -) -> PolarsResult { +) -> PolarsResult<(Node, Node)> { if !predicates.is_empty() { feature_gated!("iejoin", { debug_assert!(left_on.is_empty() && right_on.is_empty()); @@ -101,9 +106,6 @@ pub fn resolve_join( ); } - let schema = det_join_schema(&schema_left, &schema_right, &left_on, &right_on, &options) - .map_err(|e| e.context(failed_here!(join schema resolving)))?; - let mut left_on = to_expr_irs_ignore_alias(left_on, ctxt.expr_arena)?; let mut right_on = to_expr_irs_ignore_alias(right_on, ctxt.expr_arena)?; let mut joined_on = PlHashSet::new(); @@ -147,15 +149,102 @@ pub fn resolve_join( .get_type($schema, Context::Default, ctxt.expr_arena) }; } + // # Resolve scalars + // + // Scalars need to be expanded. We translate them to temporary columns added with + // `with_columns` and remove them later with `project` + // This way the backends don't have to expand the literals in the join implementation + + let has_scalars = left_on + .iter() + .chain(right_on.iter()) + .any(|e| e.is_scalar(ctxt.expr_arena)); + + let (schema_left, schema_right) = if has_scalars { + let mut as_with_columns_l = vec![]; + let mut as_with_columns_r = vec![]; + for (i, e) in left_on.iter().enumerate() { + if e.is_scalar(ctxt.expr_arena) { + as_with_columns_l.push((i, e.clone())); + } + } + for (i, e) in right_on.iter().enumerate() { + if e.is_scalar(ctxt.expr_arena) { + as_with_columns_r.push((i, e.clone())); + } + } + let mut count = 0; + let get_tmp_name = |i| format_pl_smallstr!("{POLARS_TMP_PREFIX}{i}"); + + // Early clone because of bck. + let mut schema_right_new = if !as_with_columns_r.is_empty() { + (**schema_right).clone() + } else { + Default::default() + }; + if !as_with_columns_l.is_empty() { + let mut schema_left_new = (**schema_left).clone(); + + let mut exprs = Vec::with_capacity(as_with_columns_l.len()); + for (i, mut e) in as_with_columns_l { + let tmp_name = get_tmp_name(count); + count += 1; + e.set_alias(tmp_name.clone()); + let dtype = e.dtype(&schema_left_new, Context::Default, ctxt.expr_arena)?; + schema_left_new.with_column(tmp_name.clone(), dtype.clone()); + + let col = ctxt.expr_arena.add(AExpr::Column(tmp_name)); + left_on[i] = ExprIR::from_node(col, ctxt.expr_arena); + exprs.push(e); + } + input_left = ctxt.lp_arena.add(IR::HStack { + input: input_left, + exprs, + schema: Arc::new(schema_left_new), + options: ProjectionOptions::default(), + }) + } + if !as_with_columns_r.is_empty() { + let mut exprs = Vec::with_capacity(as_with_columns_r.len()); + for (i, mut e) in as_with_columns_r { + let tmp_name = get_tmp_name(count); + count += 1; + e.set_alias(tmp_name.clone()); + let dtype = e.dtype(&schema_right_new, Context::Default, ctxt.expr_arena)?; + schema_right_new.with_column(tmp_name.clone(), dtype.clone()); + + let col = ctxt.expr_arena.add(AExpr::Column(tmp_name)); + right_on[i] = ExprIR::from_node(col, ctxt.expr_arena); + exprs.push(e); + } + input_right = ctxt.lp_arena.add(IR::HStack { + input: input_right, + exprs, + schema: Arc::new(schema_right_new), + options: ProjectionOptions::default(), + }) + } + + ( + ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena), + ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena), + ) + } else { + (schema_left, schema_right) + }; + + // # Cast lossless + // // If we do a full join and keys are coalesced, the casted keys must be added up front. let key_cols_coalesced = options.args.should_coalesce() && matches!(&options.args.how, JoinType::Full); - let mut to_cast_left = vec![]; - let mut to_cast_right = vec![]; - let mut to_cast_indices = vec![]; + let mut as_with_columns_l = vec![]; + let mut as_with_columns_r = vec![]; + for (lnode, rnode) in left_on.iter_mut().zip(right_on.iter_mut()) { + //polars_ensure!(!lnode.is_scalar(&ctxt.expr_arena), InvalidOperation: "joining on scalars is not allowed, consider using 'join_where'"); + //polars_ensure!(!rnode.is_scalar(&ctxt.expr_arena), InvalidOperation: "joining on scalars is not allowed, consider using 'join_where'"); - for (i, (lnode, rnode)) in left_on.iter_mut().zip(right_on.iter_mut()).enumerate() { let ltype = get_dtype!(lnode, &schema_left)?; let rtype = get_dtype!(rnode, &schema_right)?; @@ -186,9 +275,8 @@ pub fn resolve_join( lnode.set_node(casted_l); rnode.set_node(casted_r); - to_cast_indices.push(i); - to_cast_right.push(rnode); - to_cast_left.push(lnode); + as_with_columns_r.push(rnode); + as_with_columns_l.push(lnode); } else { lnode.set_node(casted_l); rnode.set_node(casted_r); @@ -214,42 +302,70 @@ pub fn resolve_join( let schema_left = schema_left.into_owned(); let schema_right = schema_right.into_owned(); - let key_cols_coalesced = - options.args.should_coalesce() && matches!(&options.args.how, JoinType::Full); + let join_schema = det_join_schema( + &schema_left, + &schema_right, + &left_on, + &right_on, + &options, + ctxt.expr_arena, + ) + .map_err(|e| e.context(failed_here!(join schema resolving)))?; if key_cols_coalesced { - input_left = if to_cast_left.is_empty() { + input_left = if as_with_columns_l.is_empty() { input_left } else { ctxt.lp_arena.add(IR::HStack { input: input_left, - exprs: to_cast_left, + exprs: as_with_columns_l, schema: schema_left, options: ProjectionOptions::default(), }) }; - input_right = if to_cast_right.is_empty() { + input_right = if as_with_columns_r.is_empty() { input_right } else { ctxt.lp_arena.add(IR::HStack { input: input_right, - exprs: to_cast_right, + exprs: as_with_columns_r, schema: schema_right, options: ProjectionOptions::default(), }) }; } - let lp = IR::Join { + let ir = IR::Join { input_left, input_right, - schema, + schema: join_schema.clone(), left_on, right_on, options, }; - Ok(ctxt.lp_arena.add(lp)) + let join_node = ctxt.lp_arena.add(ir); + + if has_scalars { + let names = join_schema + .iter_names() + .filter_map(|n| { + if n.starts_with(POLARS_TMP_PREFIX) { + None + } else { + Some(n.clone()) + } + }) + .collect_vec(); + + let builder = IRBuilder::new(join_node, ctxt.expr_arena, ctxt.lp_arena); + let ir = builder.project_simple(names).map(|b| b.build())?; + let select_node = ctxt.lp_arena.add(ir); + + Ok((select_node, join_node)) + } else { + Ok((join_node, join_node)) + } } #[cfg(feature = "iejoin")] @@ -265,13 +381,14 @@ impl From for Operator { } #[cfg(feature = "iejoin")] +/// Returns: left: join_node, right: last_node (often both the same) fn resolve_join_where( input_left: Arc, input_right: Arc, predicates: Vec, mut options: Arc, ctxt: &mut DslConversionContext, -) -> PolarsResult { +) -> PolarsResult<(Node, Node)> { check_join_keys(&predicates)?; let input_left = to_alp_impl(Arc::unwrap_or_clone(input_left), ctxt) .map_err(|e| e.context(failed_here!(join left)))?; @@ -499,10 +616,10 @@ fn resolve_join_where( } } - let join_node = if !eq_left_on.is_empty() { + let (mut last_node, join_node) = if !eq_left_on.is_empty() { // We found one or more equality predicates. Go into a default equi join // as those are cheapest on avg. - let join_node = resolve_join( + let (last_node, join_node) = resolve_join( Either::Right(input_left), Either::Right(input_right), eq_left_on, @@ -520,7 +637,7 @@ fn resolve_join_where( &schema_right, &suffix, ); - join_node + (last_node, join_node) } else if ie_right_on.len() >= 2 { // Do an IEjoin. let opts = Arc::make_mut(&mut options); @@ -529,7 +646,7 @@ fn resolve_join_where( operator2: Some(ie_op[1]), }); - let join_node = resolve_join( + let (last_node, join_node) = resolve_join( Either::Right(input_left), Either::Right(input_right), ie_left_on[..2].to_vec(), @@ -550,7 +667,7 @@ fn resolve_join_where( remaining_preds.push(to_binary_post_join(l, op.into(), r, &schema_right, &suffix)) } - join_node + (last_node, join_node) } else if ie_right_on.len() == 1 { // For a single inequality comparison, we use the piecewise merge join algorithm let opts = Arc::make_mut(&mut options); @@ -605,8 +722,6 @@ fn resolve_join_where( .schema(ctxt.lp_arena) .into_owned(); - let mut last_node = join_node; - // Ensure that the predicates use the proper suffix for e in remaining_preds { let predicate = to_expr_ir_ignore_alias(e, ctxt.expr_arena)?; @@ -637,5 +752,5 @@ fn resolve_join_where( }; last_node = ctxt.lp_arena.add(ir); } - Ok(last_node) + Ok((last_node, join_node)) } diff --git a/crates/polars-plan/src/plans/expr_ir.rs b/crates/polars-plan/src/plans/expr_ir.rs index 01bb76a2de5c..15a25286895b 100644 --- a/crates/polars-plan/src/plans/expr_ir.rs +++ b/crates/polars-plan/src/plans/expr_ir.rs @@ -193,7 +193,6 @@ impl ExprIR { self.output_dtype = OnceLock::new(); } - #[cfg(feature = "cse")] pub(crate) fn set_alias(&mut self, name: PlSmallStr) { self.output_name = OutputName::Alias(name) } diff --git a/crates/polars-plan/src/plans/schema.rs b/crates/polars-plan/src/plans/schema.rs index ac129afc3703..0f19dbbfb31b 100644 --- a/crates/polars-plan/src/plans/schema.rs +++ b/crates/polars-plan/src/plans/schema.rs @@ -240,9 +240,10 @@ pub fn set_estimated_row_counts( pub(crate) fn det_join_schema( schema_left: &SchemaRef, schema_right: &SchemaRef, - left_on: &[Expr], - right_on: &[Expr], + left_on: &[ExprIR], + right_on: &[ExprIR], options: &JoinOptions, + expr_arena: &Arena, ) -> PolarsResult { match &options.args.how { // semi and anti joins are just filtering operations @@ -251,16 +252,15 @@ pub(crate) fn det_join_schema( JoinType::Semi | JoinType::Anti => Ok(schema_left.clone()), JoinType::Right => { // Get join names. - let mut arena = Arena::with_capacity(8); let mut join_on_left: PlHashSet<_> = PlHashSet::with_capacity(left_on.len()); for e in left_on { - let field = e.to_field_amortized(schema_left, Context::Default, &mut arena)?; + let field = e.field(schema_left, Context::Default, expr_arena)?; join_on_left.insert(field.name); } let mut join_on_right: PlHashSet<_> = PlHashSet::with_capacity(right_on.len()); for e in right_on { - let field = e.to_field_amortized(schema_right, Context::Default, &mut arena)?; + let field = e.field(schema_right, Context::Default, expr_arena)?; join_on_right.insert(field.name); } @@ -303,8 +303,6 @@ pub(crate) fn det_join_schema( } let should_coalesce = options.args.should_coalesce(); - let mut arena = Arena::with_capacity(8); - // Handles coalescing of asof-joins. // Asof joins are not equi-joins // so the columns that are joined on, may have different @@ -312,10 +310,9 @@ pub(crate) fn det_join_schema( #[cfg(feature = "asof_join")] if matches!(_how, JoinType::AsOf(_)) { for (left_on, right_on) in left_on.iter().zip(right_on) { - let field_left = - left_on.to_field_amortized(schema_left, Context::Default, &mut arena)?; - let field_right = - right_on.to_field_amortized(schema_right, Context::Default, &mut arena)?; + let field_left = left_on.field(schema_left, Context::Default, expr_arena)?; + let field_right = right_on.field(schema_right, Context::Default, expr_arena)?; + if should_coalesce && field_left.name != field_right.name { if schema_left.contains(&field_right.name) { new_schema.with_column( @@ -330,7 +327,7 @@ pub(crate) fn det_join_schema( } let mut join_on_right: PlHashSet<_> = PlHashSet::with_capacity(right_on.len()); for e in right_on { - let field = e.to_field_amortized(schema_right, Context::Default, &mut arena)?; + let field = e.field(schema_right, Context::Default, expr_arena)?; join_on_right.insert(field.name); } diff --git a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py index d306dd173431..87a76c985b13 100644 --- a/py-polars/tests/unit/operations/test_inequality_join.py +++ b/py-polars/tests/unit/operations/test_inequality_join.py @@ -655,3 +655,30 @@ def test_join_predicate_pushdown_19580() -> None: ) assert_frame_equal(expect, q.collect(), check_row_order=False) + + +def test_join_where_literal_20061() -> None: + df_left = pl.DataFrame( + {"id": [1, 2, 3], "value_left": [10, 20, 30], "flag": [1, 0, 1]} + ) + + df_right = pl.DataFrame( + { + "id": [1, 2, 3], + "value_right": [5, 5, 25], + "flag": [1, 0, 1], + } + ) + + assert df_left.join_where( + df_right, + pl.col("value_left") > pl.col("value_right"), + pl.col("flag_right").cast(pl.Int32) == 1, + ).sort("id").to_dict(as_series=False) == { + "id": [1, 2, 3, 3], + "value_left": [10, 20, 30, 30], + "flag": [1, 0, 1, 1], + "id_right": [1, 1, 1, 3], + "value_right": [5, 5, 5, 25], + "flag_right": [1, 1, 1, 1], + } diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py index fd384408e247..bfa34e5c9018 100644 --- a/py-polars/tests/unit/operations/test_join.py +++ b/py-polars/tests/unit/operations/test_join.py @@ -1014,9 +1014,10 @@ def test_join_lit_panic_11410() -> None: df = pl.LazyFrame({"date": [1, 2, 3], "symbol": [4, 5, 6]}) dates = df.select("date").unique(maintain_order=True) symbols = df.select("symbol").unique(maintain_order=True) + assert symbols.join(dates, left_on=pl.lit(1), right_on=pl.lit(1)).collect().to_dict( as_series=False - ) == {"symbol": [4], "date": [1]} + ) == {"symbol": [4, 4, 4, 5, 5, 5, 6, 6, 6], "date": [1, 2, 3, 1, 2, 3, 1, 2, 3]} def test_join_empty_literal_17027() -> None: @@ -1320,48 +1321,48 @@ def test_join_preserve_order_full() -> None: @pytest.mark.parametrize( "dtypes", [ - ["Int128" , "Int128" , "Int64" ], - ["Int128" , "Int128" , "Int32" ], - ["Int128" , "Int128" , "Int16" ], - ["Int128" , "Int128" , "Int8" ], - ["Int128" , "UInt64" , "Int128" ], - ["Int128" , "UInt64" , "Int64" ], - ["Int128" , "UInt64" , "Int32" ], - ["Int128" , "UInt64" , "Int16" ], - ["Int128" , "UInt64" , "Int8" ], - ["Int128" , "UInt32" , "Int128" ], - ["Int128" , "UInt16" , "Int128" ], - ["Int128" , "UInt8" , "Int128" ], - - ["Int64" , "Int64" , "Int32" ], - ["Int64" , "Int64" , "Int16" ], - ["Int64" , "Int64" , "Int8" ], - ["Int64" , "UInt32" , "Int64" ], - ["Int64" , "UInt32" , "Int32" ], - ["Int64" , "UInt32" , "Int16" ], - ["Int64" , "UInt32" , "Int8" ], - ["Int64" , "UInt16" , "Int64" ], - ["Int64" , "UInt8" , "Int64" ], - - ["Int32" , "Int32" , "Int16" ], - ["Int32" , "Int32" , "Int8" ], - ["Int32" , "UInt16" , "Int32" ], - ["Int32" , "UInt16" , "Int16" ], - ["Int32" , "UInt16" , "Int8" ], - ["Int32" , "UInt8" , "Int32" ], - - ["Int16" , "Int16" , "Int8" ], - ["Int16" , "UInt8" , "Int16" ], - ["Int16" , "UInt8" , "Int8" ], - - ["UInt64" , "UInt64" , "UInt32" ], - ["UInt64" , "UInt64" , "UInt16" ], - ["UInt64" , "UInt64" , "UInt8" ], - - ["UInt32" , "UInt32" , "UInt16" ], - ["UInt32" , "UInt32" , "UInt8" ], - - ["UInt16" , "UInt16" , "UInt8" ], + ["Int128", "Int128", "Int64"], + ["Int128", "Int128", "Int32"], + ["Int128", "Int128", "Int16"], + ["Int128", "Int128", "Int8"], + ["Int128", "UInt64", "Int128"], + ["Int128", "UInt64", "Int64"], + ["Int128", "UInt64", "Int32"], + ["Int128", "UInt64", "Int16"], + ["Int128", "UInt64", "Int8"], + ["Int128", "UInt32", "Int128"], + ["Int128", "UInt16", "Int128"], + ["Int128", "UInt8", "Int128"], + + ["Int64", "Int64", "Int32"], + ["Int64", "Int64", "Int16"], + ["Int64", "Int64", "Int8"], + ["Int64", "UInt32", "Int64"], + ["Int64", "UInt32", "Int32"], + ["Int64", "UInt32", "Int16"], + ["Int64", "UInt32", "Int8"], + ["Int64", "UInt16", "Int64"], + ["Int64", "UInt8", "Int64"], + + ["Int32", "Int32", "Int16"], + ["Int32", "Int32", "Int8"], + ["Int32", "UInt16", "Int32"], + ["Int32", "UInt16", "Int16"], + ["Int32", "UInt16", "Int8"], + ["Int32", "UInt8", "Int32"], + + ["Int16", "Int16", "Int8"], + ["Int16", "UInt8", "Int16"], + ["Int16", "UInt8", "Int8"], + + ["UInt64", "UInt64", "UInt32"], + ["UInt64", "UInt64", "UInt16"], + ["UInt64", "UInt64", "UInt8"], + + ["UInt32", "UInt32", "UInt16"], + ["UInt32", "UInt32", "UInt8"], + + ["UInt16", "UInt16", "UInt8"], ["Float64", "Float64", "Float32"], ],