From ec2876a88e2a4de6f2b54458a21bedfa98c96700 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Sat, 28 Oct 2023 14:19:10 +0800 Subject: [PATCH] refactor(rust): reshape and repeat_by non-anonymous (#12064) --- .../src/dsl/function_expr/dispatch.rs | 12 ++++++ .../polars-plan/src/dsl/function_expr/mod.rs | 14 +++++++ .../src/dsl/function_expr/schema.rs | 10 +++++ crates/polars-plan/src/dsl/mod.rs | 38 +------------------ 4 files changed, 38 insertions(+), 36 deletions(-) diff --git a/crates/polars-plan/src/dsl/function_expr/dispatch.rs b/crates/polars-plan/src/dsl/function_expr/dispatch.rs index b87c87040b4e..bf10270035e9 100644 --- a/crates/polars-plan/src/dsl/function_expr/dispatch.rs +++ b/crates/polars-plan/src/dsl/function_expr/dispatch.rs @@ -53,6 +53,18 @@ pub(super) fn unique_counts(s: &Series) -> PolarsResult { polars_ops::prelude::unique_counts(s) } +pub(super) fn reshape(s: &Series, dims: Vec) -> PolarsResult { + s.reshape(&dims) +} + +#[cfg(feature = "repeat_by")] +pub(super) fn repeat_by(s: &[Series]) -> PolarsResult { + let by = &s[1]; + let s = &s[0]; + let by = by.cast(&IDX_DTYPE)?; + polars_ops::chunked_array::repeat_by(s, by.idx()?).map(|ok| ok.into_series()) +} + pub(super) fn backward_fill(s: &Series, limit: FillNullLimit) -> PolarsResult { s.fill_null(FillNullStrategy::Backward(limit)) } diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 3199958ca88e..b80dfccc4904 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -144,6 +144,9 @@ pub enum FunctionExpr { Skew(bool), #[cfg(feature = "moment")] Kurtosis(bool, bool), + Reshape(Vec), + #[cfg(feature = "repeat_by")] + RepeatBy, ArgUnique, #[cfg(feature = "rank")] Rank { @@ -455,6 +458,11 @@ impl Hash for FunctionExpr { left_closed.hash(state); include_breaks.hash(state); }, + FunctionExpr::Reshape(dims) => { + dims.hash(state); + }, + #[cfg(feature = "repeat_by")] + 
FunctionExpr::RepeatBy => {}, #[cfg(feature = "cutqcut")] FunctionExpr::QCut { probs, @@ -620,6 +628,9 @@ impl Display for FunctionExpr { Cut { .. } => "cut", #[cfg(feature = "cutqcut")] QCut { .. } => "qcut", + Reshape(_) => "reshape", + #[cfg(feature = "repeat_by")] + RepeatBy => "repeat_by", #[cfg(feature = "rle")] RLE => "rle", #[cfg(feature = "rle")] @@ -934,6 +945,9 @@ impl From for SpecialEq> { PeakMin => map!(peaks::peak_min), #[cfg(feature = "peaks")] PeakMax => map!(peaks::peak_max), + #[cfg(feature = "repeat_by")] + RepeatBy => map_as_slice!(dispatch::repeat_by), + Reshape(dims) => map!(dispatch::reshape, dims.clone()), #[cfg(feature = "cutqcut")] Cut { breaks, diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 7024b197500c..b69749e5fb8f 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -227,6 +227,16 @@ impl FunctionExpr { ]); mapper.with_dtype(struct_dt) }, + #[cfg(feature = "repeat_by")] + RepeatBy => mapper.map_dtype(|dt| DataType::List(dt.clone().into())), + Reshape(dims) => mapper.map_dtype(|dt| { + let dtype = dt.inner_dtype().unwrap_or(dt).clone(); + if dims.len() == 1 { + dtype + } else { + DataType::List(Box::new(dtype)) + } + }), #[cfg(feature = "cutqcut")] QCut { include_breaks: false, diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 2e37dd5f8b1c..01e3d6a29618 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1111,19 +1111,7 @@ impl Expr { #[cfg(feature = "repeat_by")] fn repeat_by_impl(self, by: Expr) -> Expr { - let function = |s: &mut [Series]| { - let by = &s[1]; - let s = &s[0]; - let by = by.cast(&IDX_DTYPE)?; - Ok(Some(repeat_by(s, by.idx()?)?.into_series())) - }; - - self.apply_many( - function, - &[by], - GetOutput::map_dtype(|dt| DataType::List(dt.clone().into())), - ) - .with_fmt("repeat_by") + 
self.apply_many_private(FunctionExpr::RepeatBy, &[by], false, false) } #[cfg(feature = "repeat_by")] @@ -1574,29 +1562,7 @@ impl Expr { pub fn reshape(self, dims: &[i64]) -> Self { let dims = dims.to_vec(); - let output_type = if dims.len() == 1 { - GetOutput::map_field(|fld| { - Field::new( - fld.name(), - fld.data_type() - .inner_dtype() - .unwrap_or_else(|| fld.data_type()) - .clone(), - ) - }) - } else { - GetOutput::map_field(|fld| { - let dtype = fld - .data_type() - .inner_dtype() - .unwrap_or_else(|| fld.data_type()) - .clone(); - - Field::new(fld.name(), DataType::List(Box::new(dtype))) - }) - }; - self.apply(move |s| s.reshape(&dims).map(Some), output_type) - .with_fmt("reshape") + self.apply_private(FunctionExpr::Reshape(dims)) } #[cfg(feature = "ewma")]