Skip to content

Commit

Permalink
refactor(rust): reshape and repeat_by non-anoymous (#12064)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored Oct 28, 2023
1 parent 41a8f67 commit ec2876a
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 36 deletions.
12 changes: 12 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ pub(super) fn unique_counts(s: &Series) -> PolarsResult<Series> {
polars_ops::prelude::unique_counts(s)
}

pub(super) fn reshape(s: &Series, dims: Vec<i64>) -> PolarsResult<Series> {
s.reshape(&dims)
}

#[cfg(feature = "repeat_by")]
pub(super) fn repeat_by(s: &[Series]) -> PolarsResult<Series> {
let by = &s[1];
let s = &s[0];
let by = by.cast(&IDX_DTYPE)?;
polars_ops::chunked_array::repeat_by(s, by.idx()?).map(|ok| ok.into_series())
}

pub(super) fn backward_fill(s: &Series, limit: FillNullLimit) -> PolarsResult<Series> {
s.fill_null(FillNullStrategy::Backward(limit))
}
Expand Down
14 changes: 14 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ pub enum FunctionExpr {
Skew(bool),
#[cfg(feature = "moment")]
Kurtosis(bool, bool),
Reshape(Vec<i64>),
#[cfg(feature = "repeat_by")]
RepeatBy,
ArgUnique,
#[cfg(feature = "rank")]
Rank {
Expand Down Expand Up @@ -455,6 +458,11 @@ impl Hash for FunctionExpr {
left_closed.hash(state);
include_breaks.hash(state);
},
FunctionExpr::Reshape(dims) => {
dims.hash(state);
},
#[cfg(feature = "repeat_by")]
FunctionExpr::RepeatBy => {},
#[cfg(feature = "cutqcut")]
FunctionExpr::QCut {
probs,
Expand Down Expand Up @@ -620,6 +628,9 @@ impl Display for FunctionExpr {
Cut { .. } => "cut",
#[cfg(feature = "cutqcut")]
QCut { .. } => "qcut",
Reshape(_) => "reshape",
#[cfg(feature = "repeat_by")]
RepeatBy => "repeat_by",
#[cfg(feature = "rle")]
RLE => "rle",
#[cfg(feature = "rle")]
Expand Down Expand Up @@ -934,6 +945,9 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
PeakMin => map!(peaks::peak_min),
#[cfg(feature = "peaks")]
PeakMax => map!(peaks::peak_max),
#[cfg(feature = "repeat_by")]
RepeatBy => map_as_slice!(dispatch::repeat_by),
Reshape(dims) => map!(dispatch::reshape, dims.clone()),
#[cfg(feature = "cutqcut")]
Cut {
breaks,
Expand Down
10 changes: 10 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,16 @@ impl FunctionExpr {
]);
mapper.with_dtype(struct_dt)
},
#[cfg(feature = "repeat_by")]
RepeatBy => mapper.map_dtype(|dt| DataType::List(dt.clone().into())),
Reshape(dims) => mapper.map_dtype(|dt| {
let dtype = dt.inner_dtype().unwrap_or(dt).clone();
if dims.len() == 1 {
dtype
} else {
DataType::List(Box::new(dtype))
}
}),
#[cfg(feature = "cutqcut")]
QCut {
include_breaks: false,
Expand Down
38 changes: 2 additions & 36 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1111,19 +1111,7 @@ impl Expr {

#[cfg(feature = "repeat_by")]
fn repeat_by_impl(self, by: Expr) -> Expr {
let function = |s: &mut [Series]| {
let by = &s[1];
let s = &s[0];
let by = by.cast(&IDX_DTYPE)?;
Ok(Some(repeat_by(s, by.idx()?)?.into_series()))
};

self.apply_many(
function,
&[by],
GetOutput::map_dtype(|dt| DataType::List(dt.clone().into())),
)
.with_fmt("repeat_by")
self.apply_many_private(FunctionExpr::RepeatBy, &[by], false, false)
}

#[cfg(feature = "repeat_by")]
Expand Down Expand Up @@ -1574,29 +1562,7 @@ impl Expr {

pub fn reshape(self, dims: &[i64]) -> Self {
let dims = dims.to_vec();
let output_type = if dims.len() == 1 {
GetOutput::map_field(|fld| {
Field::new(
fld.name(),
fld.data_type()
.inner_dtype()
.unwrap_or_else(|| fld.data_type())
.clone(),
)
})
} else {
GetOutput::map_field(|fld| {
let dtype = fld
.data_type()
.inner_dtype()
.unwrap_or_else(|| fld.data_type())
.clone();

Field::new(fld.name(), DataType::List(Box::new(dtype)))
})
};
self.apply(move |s| s.reshape(&dims).map(Some), output_type)
.with_fmt("reshape")
self.apply_private(FunctionExpr::Reshape(dims))
}

#[cfg(feature = "ewma")]
Expand Down

0 comments on commit ec2876a

Please sign in to comment.