Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Split replace functionality into two separate methods #16921

Merged
merged 5 commits into from
Jun 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 177 additions & 39 deletions crates/polars-ops/src/series/ops/replace.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,48 @@
use std::ops::BitOr;

use polars_core::prelude::*;
use polars_core::utils::try_get_supertype;
use polars_error::polars_ensure;

use crate::frame::join::*;
use crate::prelude::*;

pub fn replace(
/// Replace values by different values of the same data type.
///
/// Both `old` and `new` are strictly cast to the dtype of `s`, and values of `s` that
/// are not matched fall back to `s` itself (it is passed as the default).
///
/// # Errors
///
/// Returns an error if `old` contains duplicates or if one of the casts fails.
pub fn replace(s: &Series, old: &Series, new: &Series) -> PolarsResult<Series> {
    // Nothing to look up: hand back a clone of the input.
    if old.len() == 0 {
        return Ok(s.clone());
    }
    validate_old(old)?;

    let target_dtype = s.dtype();
    let old_cast = cast_old_to_series_dtype(old, target_dtype)?;
    let new_cast = new.strict_cast(target_dtype)?;

    // A single replacement value takes the fast path.
    match new_cast.len() {
        1 => replace_by_single(s, &old_cast, &new_cast, s),
        _ => replace_by_multiple(s, old_cast, new_cast, s),
    }
}

/// Replace all values by different values.
///
/// Unmatched values are replaced by a default value.
pub fn replace_or_default(
s: &Series,
old: &Series,
new: &Series,
default: &Series,
return_dtype: Option<DataType>,
) -> PolarsResult<Series> {
polars_ensure!(
old.n_unique()? == old.len(),
ComputeError: "`old` input for `replace` must not contain duplicates"
default.len() == s.len() || default.len() == 1,
InvalidOperation: "`default` input for `replace_strict` must have the same length as the input or have length 1"
);
validate_old(old)?;

let return_dtype = match return_dtype {
Some(dtype) => dtype,
None => try_get_supertype(new.dtype(), default.dtype())?,
};

polars_ensure!(
default.len() == s.len() || default.len() == 1,
ComputeError: "`default` input for `replace` must have the same length as the input or have length 1"
);

let default = default.cast(&return_dtype)?;

if old.len() == 0 {
Expand All @@ -37,18 +51,10 @@ pub fn replace(
} else {
default
};

return Ok(out);
}

let old = match (s.dtype(), old.dtype()) {
#[cfg(feature = "dtype-categorical")]
(DataType::Categorical(_, ord), DataType::String) => {
let dt = DataType::Categorical(None, *ord);
old.strict_cast(&dt)?
},
_ => old.strict_cast(s.dtype())?,
};
let old = cast_old_to_series_dtype(old, s.dtype())?;
let new = new.cast(&return_dtype)?;

if new.len() == 1 {
Expand All @@ -58,26 +64,95 @@ pub fn replace(
}
}

/// Replace all values by different values.
///
/// When `return_dtype` is given, `new` is strictly cast to it; otherwise `new` is used
/// with the dtype it already has.
///
/// # Errors
///
/// Raises an error if not all values were replaced. Also fails when `old` is empty
/// while `s` holds non-null values, when `old` contains duplicates, or when a cast fails.
pub fn replace_strict(
    s: &Series,
    old: &Series,
    new: &Series,
    return_dtype: Option<DataType>,
) -> PolarsResult<Series> {
    if old.len() == 0 {
        // An empty mapping is only accepted when every value of `s` is null.
        let only_nulls = s.len() == s.null_count();
        polars_ensure!(
            only_nulls,
            InvalidOperation: "must specify which values to replace"
        );
        return Ok(s.clone());
    }
    validate_old(old)?;

    let old_cast = cast_old_to_series_dtype(old, s.dtype())?;
    let new_cast = if let Some(dtype) = return_dtype {
        new.strict_cast(&dtype)?
    } else {
        new.clone()
    };

    // A single replacement value takes the fast path.
    match new_cast.len() {
        1 => replace_by_single_strict(s, &old_cast, &new_cast),
        _ => replace_by_multiple_strict(s, old_cast, new_cast),
    }
}

/// Check that the `old` input is usable as a lookup key.
///
/// # Errors
///
/// Returns an `InvalidOperation` error when the number of unique values in `old`
/// differs from its length, i.e. when it contains duplicates.
fn validate_old(old: &Series) -> PolarsResult<()> {
    let distinct = old.n_unique()?;
    let total = old.len();
    polars_ensure!(
        distinct == total,
        InvalidOperation: "`old` input for `replace` must not contain duplicates"
    );
    Ok(())
}

/// Cast `old` input while enabling String to Categorical casts.
///
/// Every cast goes through `strict_cast`; its result, including any error, is returned
/// to the caller as is.
fn cast_old_to_series_dtype(old: &Series, dtype: &DataType) -> PolarsResult<Series> {
match (old.dtype(), dtype) {
// This arm only exists when the `dtype-categorical` feature is enabled.
#[cfg(feature = "dtype-categorical")]
(DataType::String, DataType::Categorical(_, ord)) => {
// Build a new Categorical dtype that keeps only the ordering of the target; the first
// field is set to `None` rather than copied from `dtype`.
let empty_categorical_dtype = DataType::Categorical(None, *ord);
old.strict_cast(&empty_categorical_dtype)
},
// Every other combination is cast directly to the requested dtype.
_ => old.strict_cast(dtype),
}
}

// Fast path for replacing by a single value
fn replace_by_single(
s: &Series,
old: &Series,
new: &Series,
default: &Series,
) -> PolarsResult<Series> {
let mask = if old.null_count() == old.len() {
s.is_null()
} else {
let mask = is_in(s, old)?;

if old.null_count() == 0 {
mask
} else {
mask.bitor(s.is_null())
}
};
let mut mask = get_replacement_mask(s, old)?;
if old.null_count() > 0 {
mask = mask.fill_null_with_values(true)?;
}
new.zip_with(&mask, default)
}
/// Fast path for replacing by a single value in strict mode
///
/// Returns an error when `ensure_all_replaced` rejects the replacement mask; otherwise
/// the output is built from the value at index 0 of `new`, at the length of `s`.
fn replace_by_single_strict(s: &Series, old: &Series, new: &Series) -> PolarsResult<Series> {
let mask = get_replacement_mask(s, old)?;
// `check_all` is `true` here: the boolean values of the mask are verified as well.
ensure_all_replaced(&mask, s, old.null_count() > 0, true)?;

// Start from the value at index 0 of `new`, expanded to the length of `s`.
let mut out = new.new_from_index(0, s.len());

// Transfer validity from `mask` to `out`.
if mask.null_count() > 0 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For a future PR we must optimize this. Adding a validity is almost free, but here it takes a very expensive branchy path.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I wasn't happy about this either - but I couldn't find a way to do this without a lot of code... I figure we're probably missing some API building blocks here.

out = out.zip_with(&mask, &Series::new_null("", s.len()))?
}
Ok(out)
}
/// Build the boolean mask marking which values of `s` will be replaced.
///
/// Null values are propagated to the mask.
fn get_replacement_mask(s: &Series, old: &Series) -> PolarsResult<BooleanChunked> {
    let old_is_all_null = old.null_count() == old.len();
    if !old_is_all_null {
        return is_in(s, old);
    }
    // Fast path for when users are using `replace(None, ...)` instead of `fill_null`.
    Ok(s.is_null())
}

/// General case for replacing by multiple values
fn replace_by_multiple(
Expand All @@ -86,20 +161,19 @@ fn replace_by_multiple(
new: Series,
default: &Series,
) -> PolarsResult<Series> {
polars_ensure!(
new.len() == old.len(),
ComputeError: "`new` input for `replace` must have the same length as `old` or have length 1"
);
validate_new(&new, &old)?;

let df = s.clone().into_frame();
let replacer = create_replacer(old, new)?;
let add_replacer_mask = new.null_count() > 0;
let replacer = create_replacer(old, new, add_replacer_mask)?;

let joined = df.join(
&replacer,
[s.name()],
["__POLARS_REPLACE_OLD"],
JoinArgs {
how: JoinType::Left,
coalesce: JoinCoalesce::CoalesceColumns,
join_nulls: true,
..Default::default()
},
Expand All @@ -123,12 +197,44 @@ fn replace_by_multiple(
}
}

// Build replacer dataframe
fn create_replacer(mut old: Series, mut new: Series) -> PolarsResult<DataFrame> {
/// Strict-mode replacement for the general case of multiple replacement values.
///
/// `s` is left-joined against a replacer frame built from `old` and `new`. The replacer
/// is always created with its mask column, which is then used to verify that every
/// value was replaced.
///
/// # Errors
///
/// Returns an error if `new` and `old` differ in length, if the join fails, or if
/// `ensure_all_replaced` rejects the mask.
fn replace_by_multiple_strict(s: &Series, old: Series, new: Series) -> PolarsResult<Series> {
    validate_new(&new, &old)?;

    let input = s.clone().into_frame();
    // Must be read before `old` is moved into the replacer frame.
    let old_contains_null = old.null_count() > 0;
    let lookup = create_replacer(old, new, true)?;

    let join_args = JoinArgs {
        how: JoinType::Left,
        coalesce: JoinCoalesce::CoalesceColumns,
        join_nulls: true,
        ..Default::default()
    };
    let joined = input.join(&lookup, [s.name()], ["__POLARS_REPLACE_OLD"], join_args)?;

    let replaced = joined.column("__POLARS_REPLACE_NEW").unwrap();

    let mask_column = joined.column("__POLARS_REPLACE_MASK").unwrap();
    let mask = mask_column.bool().unwrap();
    // Only the null counts are checked on this path (`check_all` is `false`).
    ensure_all_replaced(mask, s, old_contains_null, false)?;

    Ok(replaced.clone())
}

// Build replacer dataframe.
fn create_replacer(mut old: Series, mut new: Series, add_mask: bool) -> PolarsResult<DataFrame> {
old.rename("__POLARS_REPLACE_OLD");
new.rename("__POLARS_REPLACE_NEW");

let cols = if new.null_count() > 0 {
let cols = if add_mask {
let mask = Series::new("__POLARS_REPLACE_MASK", &[true]).new_from_index(0, new.len());
vec![old, new, mask]
} else {
Expand All @@ -137,3 +243,35 @@ fn create_replacer(mut old: Series, mut new: Series) -> PolarsResult<DataFrame>
let out = unsafe { DataFrame::new_no_checks(cols) };
Ok(out)
}

/// Check that the `new` input lines up with `old`.
///
/// # Errors
///
/// Returns an `InvalidOperation` error when the two Series differ in length.
fn validate_new(new: &Series, old: &Series) -> PolarsResult<()> {
    let lengths_match = new.len() == old.len();
    polars_ensure!(
        lengths_match,
        InvalidOperation: "`new` input for `replace` must have the same length as `old` or have length 1"
    );
    Ok(())
}

/// Verify that the replacement mask shows every value was replaced.
///
/// The mask must contain no nulls when `old` has a null, and exactly as many nulls as
/// `s` otherwise. When `check_all` is set, `mask.all()` must hold as well.
///
/// # Errors
///
/// Returns an `InvalidOperation` error when either check fails.
fn ensure_all_replaced(
    mask: &BooleanChunked,
    s: &Series,
    old_has_null: bool,
    check_all: bool,
) -> PolarsResult<()> {
    let expected_nulls = if old_has_null { 0 } else { s.null_count() };
    let nulls_ok = mask.null_count() == expected_nulls;

    // Checking booleans is only relevant for the 'replace_by_single' path.
    let bools_ok = if check_all { mask.all() } else { true };

    polars_ensure!(
        bools_ok && nulls_ok,
        InvalidOperation: "incomplete mapping specified for `replace_strict`\n\nHint: Pass a `default` value to set unmapped values."
    );
    Ok(())
}
15 changes: 12 additions & 3 deletions crates/polars-plan/src/dsl/function_expr/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,18 @@ pub(super) fn hist(
}

#[cfg(feature = "replace")]
pub(super) fn replace(s: &[Series], return_dtype: Option<DataType>) -> PolarsResult<Series> {
let default = if let Some(s) = s.get(3) { s } else { &s[0] };
polars_ops::series::replace(&s[0], &s[1], &s[2], default, return_dtype)
pub(super) fn replace(s: &[Series]) -> PolarsResult<Series> {
polars_ops::series::replace(&s[0], &s[1], &s[2])
}

/// Dispatch a strict replace, choosing the implementation by whether a default exists.
///
/// `s` holds the input, `old` and `new` at indices 0, 1 and 2. When a fourth Series is
/// present it is used as the default and the call goes to `replace_or_default`;
/// otherwise `replace_strict` is called.
///
/// # Panics
///
/// Panics if `s` has fewer than three elements.
#[cfg(feature = "replace")]
pub(super) fn replace_strict(s: &[Series], return_dtype: Option<DataType>) -> PolarsResult<Series> {
    let (input, old, new) = (&s[0], &s[1], &s[2]);
    if let Some(default) = s.get(3) {
        polars_ops::series::replace_or_default(input, old, new, default, return_dtype)
    } else {
        polars_ops::series::replace_strict(input, old, new, return_dtype)
    }
}

pub(super) fn fill_null_with_strategy(
Expand Down
21 changes: 16 additions & 5 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,9 @@ pub enum FunctionExpr {
options: EWMOptions,
},
#[cfg(feature = "replace")]
Replace {
Replace,
#[cfg(feature = "replace")]
ReplaceStrict {
return_dtype: Option<DataType>,
},
GatherEvery {
Expand Down Expand Up @@ -567,7 +569,9 @@ impl Hash for FunctionExpr {
include_breakpoint.hash(state);
},
#[cfg(feature = "replace")]
Replace { return_dtype } => return_dtype.hash(state),
Replace => {},
#[cfg(feature = "replace")]
ReplaceStrict { return_dtype } => return_dtype.hash(state),
FillNullWithStrategy(strategy) => strategy.hash(state),
GatherEvery { n, offset } => (n, offset).hash(state),
#[cfg(feature = "reinterpret")]
Expand Down Expand Up @@ -752,7 +756,9 @@ impl Display for FunctionExpr {
#[cfg(feature = "hist")]
Hist { .. } => "hist",
#[cfg(feature = "replace")]
Replace { .. } => "replace",
Replace => "replace",
#[cfg(feature = "replace")]
ReplaceStrict { .. } => "replace_strict",
FillNullWithStrategy(_) => "fill_null_with_strategy",
GatherEvery { .. } => "gather_every",
#[cfg(feature = "reinterpret")]
Expand Down Expand Up @@ -1138,9 +1144,14 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
#[cfg(feature = "ewma")]
EwmVar { options } => map!(ewm::ewm_var, options),
#[cfg(feature = "replace")]
Replace { return_dtype } => {
map_as_slice!(dispatch::replace, return_dtype.clone())
Replace => {
map_as_slice!(dispatch::replace)
},
#[cfg(feature = "replace")]
ReplaceStrict { return_dtype } => {
map_as_slice!(dispatch::replace_strict, return_dtype.clone())
},

FillNullWithStrategy(strategy) => map!(dispatch::fill_null_with_strategy, strategy),
GatherEvery { n, offset } => map!(dispatch::gather_every, n, offset),
#[cfg(feature = "reinterpret")]
Expand Down
12 changes: 8 additions & 4 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ impl FunctionExpr {
#[cfg(feature = "ewma")]
EwmVar { .. } => mapper.map_to_float_dtype(),
#[cfg(feature = "replace")]
Replace { return_dtype } => mapper.replace_dtype(return_dtype.clone()),
Replace => mapper.with_same_dtype(),
#[cfg(feature = "replace")]
ReplaceStrict { return_dtype } => mapper.replace_dtype(return_dtype.clone()),
FillNullWithStrategy(_) => mapper.with_same_dtype(),
GatherEvery { .. } => mapper.with_same_dtype(),
#[cfg(feature = "reinterpret")]
Expand Down Expand Up @@ -532,11 +534,13 @@ impl<'a> FieldsMapper<'a> {
pub fn replace_dtype(&self, return_dtype: Option<DataType>) -> PolarsResult<Field> {
let dtype = match return_dtype {
Some(dtype) => dtype,
// Supertype of `new` and `default`
None => {
let column_type = &self.fields[0];
let new = &self.fields[2];
try_get_supertype(column_type.data_type(), new.data_type())?
let default = self.fields.get(3);
match default {
Some(default) => try_get_supertype(default.data_type(), new.data_type())?,
None => new.data_type().clone(),
}
},
};
self.with_dtype(dtype)
Expand Down
Loading