Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix take return dtype in group context. #11949

Merged
merged 4 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 47 additions & 19 deletions crates/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,48 @@ use crate::physical_plan::state::ExecutionState;
use crate::prelude::*;

pub struct ApplyExpr {
pub inputs: Vec<Arc<dyn PhysicalExpr>>,
pub function: SpecialEq<Arc<dyn SeriesUdf>>,
pub expr: Expr,
pub collect_groups: ApplyOptions,
pub auto_explode: bool,
pub allow_rename: bool,
pub pass_name_to_apply: bool,
pub input_schema: Option<SchemaRef>,
pub allow_threading: bool,
pub check_lengths: bool,
pub allow_group_aware: bool,
inputs: Vec<Arc<dyn PhysicalExpr>>,
function: SpecialEq<Arc<dyn SeriesUdf>>,
expr: Expr,
collect_groups: ApplyOptions,
returns_scalar: bool,
allow_rename: bool,
pass_name_to_apply: bool,
input_schema: Option<SchemaRef>,
allow_threading: bool,
check_lengths: bool,
allow_group_aware: bool,
}

impl ApplyExpr {
/// Builds an `ApplyExpr` from already-planned inputs and the function's `FunctionOptions`.
///
/// The per-function flags (`collect_groups`, `returns_scalar`, `allow_rename`,
/// `pass_name_to_apply`, `allow_group_aware`, and the result of `check_lengths()`) are
/// copied out of `options`; `allow_threading` and `input_schema` are stored as given.
///
/// # Panics
///
/// In builds with debug assertions enabled, panics when `options` asks for
/// `ApplyOptions::ElementWise` together with `returns_scalar`, since the two are
/// mutually exclusive.
pub(crate) fn new(
    inputs: Vec<Arc<dyn PhysicalExpr>>,
    function: SpecialEq<Arc<dyn SeriesUdf>>,
    expr: Expr,
    options: FunctionOptions,
    allow_threading: bool,
    input_schema: Option<SchemaRef>,
) -> Self {
    // Sanity check on the expression definition itself; only active in debug builds.
    debug_assert!(
        !(matches!(options.collect_groups, ApplyOptions::ElementWise) && options.returns_scalar),
        "expr {} is not implemented correctly. 'returns_scalar' and 'elementwise' are mutually exclusive",
        expr
    );

    let check_lengths = options.check_lengths();

    Self {
        // What is being applied, and to which inputs.
        inputs,
        function,
        expr,
        input_schema,
        // Behavioural flags taken from the function options.
        collect_groups: options.collect_groups,
        returns_scalar: options.returns_scalar,
        allow_rename: options.allow_rename,
        pass_name_to_apply: options.pass_name_to_apply,
        allow_group_aware: options.allow_group_aware,
        check_lengths,
        // Decided by the caller.
        allow_threading,
    }
}

pub(crate) fn new_minimal(
inputs: Vec<Arc<dyn PhysicalExpr>>,
function: SpecialEq<Arc<dyn SeriesUdf>>,
Expand All @@ -39,7 +67,7 @@ impl ApplyExpr {
function,
expr,
collect_groups,
auto_explode: false,
returns_scalar: false,
allow_rename: false,
pass_name_to_apply: false,
input_schema: None,
Expand Down Expand Up @@ -70,7 +98,7 @@ impl ApplyExpr {
ca: ListChunked,
) -> PolarsResult<AggregationContext<'a>> {
let all_unit_len = all_unit_length(&ca);
if all_unit_len && self.auto_explode {
if all_unit_len && self.returns_scalar {
ac.with_series(ca.explode().unwrap().into_series(), true, Some(&self.expr))?;
ac.update_groups = UpdateGroups::No;
} else {
Expand Down Expand Up @@ -289,8 +317,8 @@ impl PhysicalExpr for ApplyExpr {
ac.with_series(s, true, Some(&self.expr))?;
Ok(ac)
},
ApplyOptions::ApplyGroups => self.apply_single_group_aware(ac),
ApplyOptions::ApplyFlat => self.apply_single_elementwise(ac),
ApplyOptions::GroupWise => self.apply_single_group_aware(ac),
ApplyOptions::ElementWise => self.apply_single_elementwise(ac),
}
} else {
let mut acs = self.prepare_multiple_inputs(df, groups, state)?;
Expand All @@ -305,8 +333,8 @@ impl PhysicalExpr for ApplyExpr {
ac.with_series(s, true, Some(&self.expr))?;
Ok(ac)
},
ApplyOptions::ApplyGroups => self.apply_multiple_group_aware(acs, df),
ApplyOptions::ApplyFlat => {
ApplyOptions::GroupWise => self.apply_multiple_group_aware(acs, df),
ApplyOptions::ElementWise => {
if acs
.iter()
.any(|ac| matches!(ac.agg_state(), AggState::AggregatedList(_)))
Expand All @@ -328,7 +356,7 @@ impl PhysicalExpr for ApplyExpr {
self.expr.to_field(input_schema, Context::Default)
}
fn is_valid_aggregation(&self) -> bool {
matches!(self.collect_groups, ApplyOptions::ApplyGroups)
matches!(self.collect_groups, ApplyOptions::GroupWise)
}
#[cfg(feature = "parquet")]
fn as_stats_evaluator(&self) -> Option<&dyn polars_io::predicates::StatsEvaluator> {
Expand All @@ -345,7 +373,7 @@ impl PhysicalExpr for ApplyExpr {
}
}
fn as_partitioned_aggregator(&self) -> Option<&dyn PartitionedAggregation> {
if self.inputs.len() == 1 && matches!(self.collect_groups, ApplyOptions::ApplyFlat) {
if self.inputs.len() == 1 && matches!(self.collect_groups, ApplyOptions::ElementWise) {
Some(self)
} else {
None
Expand Down
21 changes: 20 additions & 1 deletion crates/polars-lazy/src/physical_plan/expressions/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub struct TakeExpr {
pub(crate) phys_expr: Arc<dyn PhysicalExpr>,
pub(crate) idx: Arc<dyn PhysicalExpr>,
pub(crate) expr: Expr,
pub(crate) returns_scalar: bool,
}

impl TakeExpr {
Expand Down Expand Up @@ -101,12 +102,23 @@ impl PhysicalExpr for TakeExpr {
},
};
let taken = ac.flat_naive().take(&idx)?;

let taken = if self.returns_scalar {
taken
} else {
taken.as_list().into_series()
};

ac.with_series(taken, true, Some(&self.expr))?;
return Ok(ac);
},
AggState::AggregatedList(s) => s.list().unwrap().clone(),
AggState::AggregatedList(s) => {
polars_ensure!(!self.returns_scalar, ComputeError: "expected single index");
s.list().unwrap().clone()
},
// Maybe a literal as well, this needs a different path.
AggState::NotAggregated(_) => {
polars_ensure!(!self.returns_scalar, ComputeError: "expected single index");
let s = idx.aggregated();
s.list().unwrap().clone()
},
Expand Down Expand Up @@ -144,6 +156,13 @@ impl PhysicalExpr for TakeExpr {
},
};
let taken = ac.flat_naive().take(&idx.into_inner())?;

let taken = if self.returns_scalar {
taken
} else {
taken.as_list().into_series()
};

ac.with_series(taken, true, Some(&self.expr))?;
ac.with_update_groups(UpdateGroups::WithGroupsLen);
Ok(ac)
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/physical_plan/expressions/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ impl WindowExpr {
},
Expr::Function { options, .. }
| Expr::AnonymousFunction { options, .. } => {
if options.auto_explode
&& matches!(options.collect_groups, ApplyOptions::ApplyGroups)
if options.returns_scalar
&& matches!(options.collect_groups, ApplyOptions::GroupWise)
{
agg_col = true;
}
Expand Down
55 changes: 25 additions & 30 deletions crates/polars-lazy/src/physical_plan/planner/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,18 @@ pub(crate) fn create_physical_expr(
node_to_expr(expression, expr_arena),
)))
},
Take { expr, idx } => {
Take {
expr,
idx,
returns_scalar,
} => {
let phys_expr = create_physical_expr(expr, ctxt, expr_arena, schema, state)?;
let phys_idx = create_physical_expr(idx, ctxt, expr_arena, schema, state)?;
Ok(Arc::new(TakeExpr {
phys_expr,
idx: phys_idx,
expr: node_to_expr(expression, expr_arena),
returns_scalar,
}))
},
SortBy {
Expand Down Expand Up @@ -391,7 +396,7 @@ pub(crate) fn create_physical_expr(
vec![input],
function,
node_to_expr(expression, expr_arena),
ApplyOptions::ApplyFlat,
ApplyOptions::ElementWise,
)))
},
_ => {
Expand Down Expand Up @@ -463,7 +468,7 @@ pub(crate) fn create_physical_expr(
options,
} => {
let is_reducing_aggregation =
options.auto_explode && matches!(options.collect_groups, ApplyOptions::ApplyGroups);
options.returns_scalar && matches!(options.collect_groups, ApplyOptions::GroupWise);
// will be reset in the function so get that here
let has_window = state.local.has_window;
let input = create_physical_expressions_check_state(
Expand All @@ -478,19 +483,14 @@ pub(crate) fn create_physical_expr(
},
)?;

Ok(Arc::new(ApplyExpr {
inputs: input,
Ok(Arc::new(ApplyExpr::new(
input,
function,
expr: node_to_expr(expression, expr_arena),
collect_groups: options.collect_groups,
auto_explode: options.auto_explode,
allow_rename: options.allow_rename,
pass_name_to_apply: options.pass_name_to_apply,
input_schema: schema.cloned(),
allow_threading: !state.has_cache,
check_lengths: options.check_lengths(),
allow_group_aware: options.allow_group_aware,
}))
node_to_expr(expression, expr_arena),
options,
!state.has_cache,
schema.cloned(),
)))
},
Function {
input,
Expand All @@ -499,7 +499,7 @@ pub(crate) fn create_physical_expr(
..
} => {
let is_reducing_aggregation =
options.auto_explode && matches!(options.collect_groups, ApplyOptions::ApplyGroups);
options.returns_scalar && matches!(options.collect_groups, ApplyOptions::GroupWise);
// will be reset in the function so get that here
let has_window = state.local.has_window;
let input = create_physical_expressions_check_state(
Expand All @@ -514,19 +514,14 @@ pub(crate) fn create_physical_expr(
},
)?;

Ok(Arc::new(ApplyExpr {
inputs: input,
function: function.into(),
expr: node_to_expr(expression, expr_arena),
collect_groups: options.collect_groups,
auto_explode: options.auto_explode,
allow_rename: options.allow_rename,
pass_name_to_apply: options.pass_name_to_apply,
input_schema: schema.cloned(),
allow_threading: !state.has_cache,
check_lengths: options.check_lengths(),
allow_group_aware: options.allow_group_aware,
}))
Ok(Arc::new(ApplyExpr::new(
input,
function.into(),
node_to_expr(expression, expr_arena),
options,
!state.has_cache,
schema.cloned(),
)))
},
Slice {
input,
Expand All @@ -553,7 +548,7 @@ pub(crate) fn create_physical_expr(
vec![input],
function,
node_to_expr(expression, expr_arena),
ApplyOptions::ApplyGroups,
ApplyOptions::GroupWise,
)))
},
Wildcard => panic!("should be no wildcard at this point"),
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/physical_plan/planner/lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ fn partitionable_gb(
)
},
Function {input, options, ..} => {
matches!(options.collect_groups, ApplyOptions::ApplyFlat) && input.len() == 1 &&
matches!(options.collect_groups, ApplyOptions::ElementWise) && input.len() == 1 &&
!has_aggregation(input[0])
}
BinaryExpr {left, right, ..} => {
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/physical_plan/streaming/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ pub(super) fn is_streamable(node: Node, expr_arena: &Arena<AExpr>, context: Cont
{
Context::Default => matches!(
options.collect_groups,
ApplyOptions::ApplyFlat | ApplyOptions::ApplyList
ApplyOptions::ElementWise | ApplyOptions::ApplyList
),
Context::Aggregation => matches!(options.collect_groups, ApplyOptions::ApplyFlat),
Context::Aggregation => matches!(options.collect_groups, ApplyOptions::ElementWise),
},
AExpr::Column(_) => {
seen_column = true;
Expand Down
19 changes: 8 additions & 11 deletions crates/polars-lazy/src/tests/aggregations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ fn take_aggregations() -> PolarsResult<()> {
.clone()
.lazy()
.group_by([col("user")])
.agg([col("book").take(col("count").arg_max()).alias("fav_book")])
.agg([col("book").get(col("count").arg_max()).alias("fav_book")])
.sort("user", Default::default())
.collect()?;

Expand Down Expand Up @@ -460,7 +460,7 @@ fn take_aggregations() -> PolarsResult<()> {
let out = df
.lazy()
.group_by([col("user")])
.agg([col("book").take(lit(0)).alias("take_lit")])
.agg([col("book").get(lit(0)).alias("take_lit")])
.sort("user", Default::default())
.collect()?;

Expand All @@ -484,7 +484,7 @@ fn test_take_consistency() -> PolarsResult<()> {
multithreaded: true,
maintain_order: false,
})
.take(lit(0))])
.get(lit(0))])
.collect()?;

let a = out.column("A")?;
Expand All @@ -502,7 +502,7 @@ fn test_take_consistency() -> PolarsResult<()> {
multithreaded: true,
maintain_order: false,
})
.take(lit(0))])
.get(lit(0))])
.collect()?;

let out = out.column("A")?;
Expand All @@ -521,18 +521,18 @@ fn test_take_consistency() -> PolarsResult<()> {
multithreaded: true,
maintain_order: false,
})
.take(lit(0))
.get(lit(0))
.alias("1"),
col("A")
.take(
.get(
col("A")
.arg_sort(SortOptions {
descending: true,
nulls_last: false,
multithreaded: true,
maintain_order: false,
})
.take(lit(0)),
.get(lit(0)),
)
.alias("2"),
])
Expand All @@ -556,10 +556,7 @@ fn test_take_in_groups() -> PolarsResult<()> {
let out = df
.lazy()
.sort("fruits", Default::default())
.select([col("B")
.take(lit(Series::new("", &[0u32])))
.over([col("fruits")])
.alias("taken")])
.select([col("B").get(lit(0u32)).over([col("fruits")]).alias("taken")])
.collect()?;

assert_eq!(
Expand Down
Loading