Skip to content

Commit

Permalink
feat!: Split replace functionality into two separate functions (#16921
Browse files Browse the repository at this point in the history
)
  • Loading branch information
stinodego authored Jun 22, 2024
1 parent 88cfb23 commit 0d3a285
Show file tree
Hide file tree
Showing 16 changed files with 1,031 additions and 419 deletions.
216 changes: 177 additions & 39 deletions crates/polars-ops/src/series/ops/replace.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,48 @@
use std::ops::BitOr;

use polars_core::prelude::*;
use polars_core::utils::try_get_supertype;
use polars_error::polars_ensure;

use crate::frame::join::*;
use crate::prelude::*;

pub fn replace(
/// Replace values by different values of the same data type.
pub fn replace(s: &Series, old: &Series, new: &Series) -> PolarsResult<Series> {
if old.len() == 0 {
return Ok(s.clone());
}
validate_old(old)?;

let dtype = s.dtype();
let old = cast_old_to_series_dtype(old, dtype)?;
let new = new.strict_cast(dtype)?;

if new.len() == 1 {
replace_by_single(s, &old, &new, s)
} else {
replace_by_multiple(s, old, new, s)
}
}

/// Replace all values by different values.
///
/// Unmatched values are replaced by a default value.
pub fn replace_or_default(
s: &Series,
old: &Series,
new: &Series,
default: &Series,
return_dtype: Option<DataType>,
) -> PolarsResult<Series> {
polars_ensure!(
old.n_unique()? == old.len(),
ComputeError: "`old` input for `replace` must not contain duplicates"
default.len() == s.len() || default.len() == 1,
InvalidOperation: "`default` input for `replace_strict` must have the same length as the input or have length 1"
);
validate_old(old)?;

let return_dtype = match return_dtype {
Some(dtype) => dtype,
None => try_get_supertype(new.dtype(), default.dtype())?,
};

polars_ensure!(
default.len() == s.len() || default.len() == 1,
ComputeError: "`default` input for `replace` must have the same length as the input or have length 1"
);

let default = default.cast(&return_dtype)?;

if old.len() == 0 {
Expand All @@ -37,18 +51,10 @@ pub fn replace(
} else {
default
};

return Ok(out);
}

let old = match (s.dtype(), old.dtype()) {
#[cfg(feature = "dtype-categorical")]
(DataType::Categorical(_, ord), DataType::String) => {
let dt = DataType::Categorical(None, *ord);
old.strict_cast(&dt)?
},
_ => old.strict_cast(s.dtype())?,
};
let old = cast_old_to_series_dtype(old, s.dtype())?;
let new = new.cast(&return_dtype)?;

if new.len() == 1 {
Expand All @@ -58,26 +64,95 @@ pub fn replace(
}
}

/// Replace all values by different values.
///
/// Raises an error if not all values were replaced.
pub fn replace_strict(
s: &Series,
old: &Series,
new: &Series,
return_dtype: Option<DataType>,
) -> PolarsResult<Series> {
if old.len() == 0 {
polars_ensure!(
s.len() == s.null_count(),
InvalidOperation: "must specify which values to replace"
);
return Ok(s.clone());
}
validate_old(old)?;

let old = cast_old_to_series_dtype(old, s.dtype())?;
let new = match return_dtype {
Some(dtype) => new.strict_cast(&dtype)?,
None => new.clone(),
};

if new.len() == 1 {
replace_by_single_strict(s, &old, &new)
} else {
replace_by_multiple_strict(s, old, new)
}
}

/// Validate the `old` input.
fn validate_old(old: &Series) -> PolarsResult<()> {
polars_ensure!(
old.n_unique()? == old.len(),
InvalidOperation: "`old` input for `replace` must not contain duplicates"
);
Ok(())
}

/// Cast `old` input while enabling String to Categorical casts.
fn cast_old_to_series_dtype(old: &Series, dtype: &DataType) -> PolarsResult<Series> {
match (old.dtype(), dtype) {
#[cfg(feature = "dtype-categorical")]
(DataType::String, DataType::Categorical(_, ord)) => {
let empty_categorical_dtype = DataType::Categorical(None, *ord);
old.strict_cast(&empty_categorical_dtype)
},
_ => old.strict_cast(dtype),
}
}

// Fast path for replacing by a single value
fn replace_by_single(
s: &Series,
old: &Series,
new: &Series,
default: &Series,
) -> PolarsResult<Series> {
let mask = if old.null_count() == old.len() {
s.is_null()
} else {
let mask = is_in(s, old)?;

if old.null_count() == 0 {
mask
} else {
mask.bitor(s.is_null())
}
};
let mut mask = get_replacement_mask(s, old)?;
if old.null_count() > 0 {
mask = mask.fill_null_with_values(true)?;
}
new.zip_with(&mask, default)
}
/// Fast path for replacing by a single value in strict mode
fn replace_by_single_strict(s: &Series, old: &Series, new: &Series) -> PolarsResult<Series> {
let mask = get_replacement_mask(s, old)?;
ensure_all_replaced(&mask, s, old.null_count() > 0, true)?;

let mut out = new.new_from_index(0, s.len());

// Transfer validity from `mask` to `out`.
if mask.null_count() > 0 {
out = out.zip_with(&mask, &Series::new_null("", s.len()))?
}
Ok(out)
}
/// Get a boolean mask of which values in the original Series will be replaced.
///
/// Null values are propagated to the mask.
fn get_replacement_mask(s: &Series, old: &Series) -> PolarsResult<BooleanChunked> {
if old.null_count() == old.len() {
// Fast path for when users are using `replace(None, ...)` instead of `fill_null`.
Ok(s.is_null())
} else {
is_in(s, old)
}
}

/// General case for replacing by multiple values
fn replace_by_multiple(
Expand All @@ -86,20 +161,19 @@ fn replace_by_multiple(
new: Series,
default: &Series,
) -> PolarsResult<Series> {
polars_ensure!(
new.len() == old.len(),
ComputeError: "`new` input for `replace` must have the same length as `old` or have length 1"
);
validate_new(&new, &old)?;

let df = s.clone().into_frame();
let replacer = create_replacer(old, new)?;
let add_replacer_mask = new.null_count() > 0;
let replacer = create_replacer(old, new, add_replacer_mask)?;

let joined = df.join(
&replacer,
[s.name()],
["__POLARS_REPLACE_OLD"],
JoinArgs {
how: JoinType::Left,
coalesce: JoinCoalesce::CoalesceColumns,
join_nulls: true,
..Default::default()
},
Expand All @@ -123,12 +197,44 @@ fn replace_by_multiple(
}
}

// Build replacer dataframe
fn create_replacer(mut old: Series, mut new: Series) -> PolarsResult<DataFrame> {
/// General case for replacing by multiple values in strict mode
fn replace_by_multiple_strict(s: &Series, old: Series, new: Series) -> PolarsResult<Series> {
validate_new(&new, &old)?;

let df = s.clone().into_frame();
let old_has_null = old.null_count() > 0;
let replacer = create_replacer(old, new, true)?;

let joined = df.join(
&replacer,
[s.name()],
["__POLARS_REPLACE_OLD"],
JoinArgs {
how: JoinType::Left,
coalesce: JoinCoalesce::CoalesceColumns,
join_nulls: true,
..Default::default()
},
)?;

let replaced = joined.column("__POLARS_REPLACE_NEW").unwrap();

let mask = joined
.column("__POLARS_REPLACE_MASK")
.unwrap()
.bool()
.unwrap();
ensure_all_replaced(mask, s, old_has_null, false)?;

Ok(replaced.clone())
}

// Build replacer dataframe.
fn create_replacer(mut old: Series, mut new: Series, add_mask: bool) -> PolarsResult<DataFrame> {
old.rename("__POLARS_REPLACE_OLD");
new.rename("__POLARS_REPLACE_NEW");

let cols = if new.null_count() > 0 {
let cols = if add_mask {
let mask = Series::new("__POLARS_REPLACE_MASK", &[true]).new_from_index(0, new.len());
vec![old, new, mask]
} else {
Expand All @@ -137,3 +243,35 @@ fn create_replacer(mut old: Series, mut new: Series) -> PolarsResult<DataFrame>
let out = unsafe { DataFrame::new_no_checks(cols) };
Ok(out)
}

/// Validate the `new` input.
fn validate_new(new: &Series, old: &Series) -> PolarsResult<()> {
polars_ensure!(
new.len() == old.len(),
InvalidOperation: "`new` input for `replace` must have the same length as `old` or have length 1"
);
Ok(())
}

/// Ensure that all values were replaced.
fn ensure_all_replaced(
mask: &BooleanChunked,
s: &Series,
old_has_null: bool,
check_all: bool,
) -> PolarsResult<()> {
let nulls_check = if old_has_null {
mask.null_count() == 0
} else {
mask.null_count() == s.null_count()
};
// Checking booleans is only relevant for the 'replace_by_single' path.
let bools_check = !check_all || mask.all();

let all_replaced = bools_check && nulls_check;
polars_ensure!(
all_replaced,
InvalidOperation: "incomplete mapping specified for `replace_strict`\n\nHint: Pass a `default` value to set unmapped values."
);
Ok(())
}
15 changes: 12 additions & 3 deletions crates/polars-plan/src/dsl/function_expr/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,18 @@ pub(super) fn hist(
}

#[cfg(feature = "replace")]
pub(super) fn replace(s: &[Series], return_dtype: Option<DataType>) -> PolarsResult<Series> {
let default = if let Some(s) = s.get(3) { s } else { &s[0] };
polars_ops::series::replace(&s[0], &s[1], &s[2], default, return_dtype)
pub(super) fn replace(s: &[Series]) -> PolarsResult<Series> {
polars_ops::series::replace(&s[0], &s[1], &s[2])
}

#[cfg(feature = "replace")]
pub(super) fn replace_strict(s: &[Series], return_dtype: Option<DataType>) -> PolarsResult<Series> {
match s.get(3) {
Some(default) => {
polars_ops::series::replace_or_default(&s[0], &s[1], &s[2], default, return_dtype)
},
None => polars_ops::series::replace_strict(&s[0], &s[1], &s[2], return_dtype),
}
}

pub(super) fn fill_null_with_strategy(
Expand Down
21 changes: 16 additions & 5 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,9 @@ pub enum FunctionExpr {
options: EWMOptions,
},
#[cfg(feature = "replace")]
Replace {
Replace,
#[cfg(feature = "replace")]
ReplaceStrict {
return_dtype: Option<DataType>,
},
GatherEvery {
Expand Down Expand Up @@ -567,7 +569,9 @@ impl Hash for FunctionExpr {
include_breakpoint.hash(state);
},
#[cfg(feature = "replace")]
Replace { return_dtype } => return_dtype.hash(state),
Replace => {},
#[cfg(feature = "replace")]
ReplaceStrict { return_dtype } => return_dtype.hash(state),
FillNullWithStrategy(strategy) => strategy.hash(state),
GatherEvery { n, offset } => (n, offset).hash(state),
#[cfg(feature = "reinterpret")]
Expand Down Expand Up @@ -752,7 +756,9 @@ impl Display for FunctionExpr {
#[cfg(feature = "hist")]
Hist { .. } => "hist",
#[cfg(feature = "replace")]
Replace { .. } => "replace",
Replace => "replace",
#[cfg(feature = "replace")]
ReplaceStrict { .. } => "replace_strict",
FillNullWithStrategy(_) => "fill_null_with_strategy",
GatherEvery { .. } => "gather_every",
#[cfg(feature = "reinterpret")]
Expand Down Expand Up @@ -1138,9 +1144,14 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
#[cfg(feature = "ewma")]
EwmVar { options } => map!(ewm::ewm_var, options),
#[cfg(feature = "replace")]
Replace { return_dtype } => {
map_as_slice!(dispatch::replace, return_dtype.clone())
Replace => {
map_as_slice!(dispatch::replace)
},
#[cfg(feature = "replace")]
ReplaceStrict { return_dtype } => {
map_as_slice!(dispatch::replace_strict, return_dtype.clone())
},

FillNullWithStrategy(strategy) => map!(dispatch::fill_null_with_strategy, strategy),
GatherEvery { n, offset } => map!(dispatch::gather_every, n, offset),
#[cfg(feature = "reinterpret")]
Expand Down
12 changes: 8 additions & 4 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ impl FunctionExpr {
#[cfg(feature = "ewma")]
EwmVar { .. } => mapper.map_to_float_dtype(),
#[cfg(feature = "replace")]
Replace { return_dtype } => mapper.replace_dtype(return_dtype.clone()),
Replace => mapper.with_same_dtype(),
#[cfg(feature = "replace")]
ReplaceStrict { return_dtype } => mapper.replace_dtype(return_dtype.clone()),
FillNullWithStrategy(_) => mapper.with_same_dtype(),
GatherEvery { .. } => mapper.with_same_dtype(),
#[cfg(feature = "reinterpret")]
Expand Down Expand Up @@ -532,11 +534,13 @@ impl<'a> FieldsMapper<'a> {
pub fn replace_dtype(&self, return_dtype: Option<DataType>) -> PolarsResult<Field> {
let dtype = match return_dtype {
Some(dtype) => dtype,
// Supertype of `new` and `default`
None => {
let column_type = &self.fields[0];
let new = &self.fields[2];
try_get_supertype(column_type.data_type(), new.data_type())?
let default = self.fields.get(3);
match default {
Some(default) => try_get_supertype(default.data_type(), new.data_type())?,
None => new.data_type().clone(),
}
},
};
self.with_dtype(dtype)
Expand Down
Loading

0 comments on commit 0d3a285

Please sign in to comment.