From 2f2ee7808918d3f084e297c90ddaded829083378 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Fri, 7 Jun 2024 10:23:10 +0200 Subject: [PATCH 1/5] Add doc entries --- py-polars/docs/source/reference/expressions/modify_select.rst | 1 + py-polars/docs/source/reference/series/computation.rst | 1 + 2 files changed, 2 insertions(+) diff --git a/py-polars/docs/source/reference/expressions/modify_select.rst b/py-polars/docs/source/reference/expressions/modify_select.rst index 53437ea96fc9..73b9aaee9b5f 100644 --- a/py-polars/docs/source/reference/expressions/modify_select.rst +++ b/py-polars/docs/source/reference/expressions/modify_select.rst @@ -41,6 +41,7 @@ Manipulation/selection Expr.reinterpret Expr.repeat_by Expr.replace + Expr.replace_strict Expr.reshape Expr.reverse Expr.rle diff --git a/py-polars/docs/source/reference/series/computation.rst b/py-polars/docs/source/reference/series/computation.rst index 9710a23dc782..9e3edb3ac0f6 100644 --- a/py-polars/docs/source/reference/series/computation.rst +++ b/py-polars/docs/source/reference/series/computation.rst @@ -45,6 +45,7 @@ Computation Series.peak_min Series.rank Series.replace + Series.replace_strict Series.rolling_map Series.rolling_max Series.rolling_mean From 3ba88c581133af09a581920176686f81f6142f87 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Thu, 13 Jun 2024 10:16:58 +0200 Subject: [PATCH 2/5] Update Python implementation --- py-polars/polars/_utils/udfs.py | 4 +- py-polars/polars/_utils/various.py | 5 +- py-polars/polars/expr/expr.py | 217 +++++++++++++++++++++++++++-- py-polars/polars/series/series.py | 155 ++++++++++++++++++--- 4 files changed, 341 insertions(+), 40 deletions(-) diff --git a/py-polars/polars/_utils/udfs.py b/py-polars/polars/_utils/udfs.py index 381bc0104630..6a94c5bf6b9f 100644 --- a/py-polars/polars/_utils/udfs.py +++ b/py-polars/polars/_utils/udfs.py @@ -601,7 +601,7 @@ def op(inst: Instruction) -> str: elif inst.opname in OpNames.UNARY: return OpNames.UNARY[inst.opname] 
elif inst.opname == "BINARY_SUBSCR": - return "replace" + return "replace_strict" else: msg = ( "unrecognized opname" @@ -654,7 +654,7 @@ def _expr(self, value: StackEntry, col: str, param_name: str, depth: int) -> str if " " in e1 else f"{not_}{e1}.is_in({e2})" ) - elif op == "replace": + elif op == "replace_strict": if not self._caller_variables: self._caller_variables = _get_all_caller_variables() if not isinstance(self._caller_variables.get(e1, None), dict): diff --git a/py-polars/polars/_utils/various.py b/py-polars/polars/_utils/various.py index 09701fd677cc..d9a7dc54eeab 100644 --- a/py-polars/polars/_utils/various.py +++ b/py-polars/polars/_utils/various.py @@ -321,10 +321,7 @@ def str_duration_(td: str | None) -> int | None: .cast(tp) ) elif tp == Boolean: - cast_cols[c] = F.col(c).replace( - {"true": True, "false": False}, - default=None, - ) + cast_cols[c] = F.col(c).replace_strict({"true": True, "false": False}) elif tp in INTEGER_DTYPES: int_string = F.col(c).str.replace_all(r"[^\d+-]", "") cast_cols[c] = ( diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index f0dbff7c8b5d..69141750b280 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -10023,7 +10023,7 @@ def replace( return_dtype: PolarsDataType | None = None, ) -> Expr: """ - Replace values by different values. + Replace the given values by different values of the same data type. Parameters ---------- @@ -10038,16 +10038,27 @@ def replace( Accepts expression input. Sequences are parsed as Series, other non-expression inputs are parsed as literals. Length must match the length of `old` or have length 1. + default Set values that were not replaced to this value. Defaults to keeping the original value. Accepts expression input. Non-expression inputs are parsed as literals. + + .. deprecated:: 1.0.0 + Use :meth:`replace_strict` instead to set a default while replacing + values. + return_dtype The data type of the resulting expression. 
If set to `None` (default), - the data type is determined automatically based on the other inputs. + the data type of the original column is preserved. + + .. deprecated:: 1.0.0 + Use :meth:`replace_strict` instead to set a return data type while + replacing values, or explicitly call :meth:`cast` on the output. See Also -------- + replace_strict str.replace Notes @@ -10089,25 +10100,24 @@ def replace( └─────┴──────────┘ Passing a mapping with replacements is also supported as syntactic sugar. - Specify a default to set all values that were not matched. >>> mapping = {2: 100, 3: 200} - >>> df.with_columns(replaced=pl.col("a").replace(mapping, default=-1)) + >>> df.with_columns(replaced=pl.col("a").replace(mapping)) shape: (4, 2) ┌─────┬──────────┐ │ a ┆ replaced │ │ --- ┆ --- │ │ i64 ┆ i64 │ ╞═════╪══════════╡ - │ 1 ┆ -1 │ + │ 1 ┆ 1 │ │ 2 ┆ 100 │ │ 2 ┆ 100 │ │ 3 ┆ 200 │ └─────┴──────────┘ - Replacing by values of a different data type sets the return type based on - a combination of the `new` data type and either the original data type or the - default data type if it was set. + The original data type is preserved when replacing by values of a different + data type. Use :meth:`replace_strict` to replace and change the return data + type. >>> df = pl.DataFrame({"a": ["x", "y", "z"]}) >>> mapping = {"x": 1, "y": 2, "z": 3} @@ -10122,7 +10132,175 @@ def replace( │ y ┆ 2 │ │ z ┆ 3 │ └─────┴──────────┘ - >>> df.with_columns(replaced=pl.col("a").replace(mapping, default=None)) + + Expression input is supported. + + >>> df = pl.DataFrame({"a": [1, 2, 2, 3], "b": [1.5, 2.5, 5.0, 1.0]}) + >>> df.with_columns( + ... replaced=pl.col("a").replace( + ... old=pl.col("a").max(), + ... new=pl.col("b").sum(), + ... ) + ... 
) + shape: (4, 3) + ┌─────┬─────┬──────────┐ + │ a ┆ b ┆ replaced │ + │ --- ┆ --- ┆ --- │ + │ i64 ┆ f64 ┆ i64 │ + ╞═════╪═════╪══════════╡ + │ 1 ┆ 1.5 ┆ 1 │ + │ 2 ┆ 2.5 ┆ 2 │ + │ 2 ┆ 5.0 ┆ 2 │ + │ 3 ┆ 1.0 ┆ 10 │ + └─────┴─────┴──────────┘ + """ + if return_dtype is not None: + issue_deprecation_warning( + "The `return_dtype` parameter for `replace` is deprecated." + " Use `replace_strict` instead to set a return data type while replacing values.", + version="1.0.0", + ) + if default is not no_default: + issue_deprecation_warning( + "The `default` parameter for `replace` is deprecated." + " Use `replace_strict` instead to set a default while replacing values.", + version="1.0.0", + ) + return self.replace_strict( + old, new, default=default, return_dtype=return_dtype + ) + + if new is no_default and isinstance(old, Mapping): + new = pl.Series(old.values()) + old = pl.Series(old.keys()) + else: + if isinstance(old, Sequence) and not isinstance(old, (str, pl.Series)): + old = pl.Series(old) + if isinstance(new, Sequence) and not isinstance(new, (str, pl.Series)): + new = pl.Series(new) + + old = parse_into_expression(old, str_as_lit=True) # type: ignore[arg-type] + new = parse_into_expression(new, str_as_lit=True) # type: ignore[arg-type] + + result = self._from_pyexpr(self._pyexpr.replace(old, new)) + + if return_dtype is not None: + result = result.cast(return_dtype) + + return result + + def replace_strict( + self, + old: IntoExpr | Sequence[Any] | Mapping[Any, Any], + new: IntoExpr | Sequence[Any] | NoDefault = no_default, + *, + default: IntoExpr | NoDefault = no_default, + return_dtype: PolarsDataType | None = None, + ) -> Expr: + """ + Replace all values by different values. + + Parameters + ---------- + old + Value or sequence of values to replace. + Accepts expression input. Sequences are parsed as Series, + other non-expression inputs are parsed as literals. 
+ Also accepts a mapping of values to their replacement as syntactic sugar for + `replace_strict(old=Series(mapping.keys()), new=Series(mapping.values()))`. + new + Value or sequence of values to replace by. + Accepts expression input. Sequences are parsed as Series, + other non-expression inputs are parsed as literals. + Length must match the length of `old` or have length 1. + default + Set values that were not replaced to this value. If no default is specified + (default), an error is raised if any values were not replaced. + Accepts expression input. Non-expression inputs are parsed as literals. + return_dtype + The data type of the resulting expression. If set to `None` (default), + the data type is determined automatically based on the other inputs. + + Raises + ------ + InvalidOperationError + If any non-null values in the original column were not replaced, and no + `default` was specified. + + See Also + -------- + replace + str.replace + + Notes + ----- + The global string cache must be enabled when replacing categorical values. + + Examples + -------- + Replace values by passing sequences to the `old` and `new` parameters. + + >>> df = pl.DataFrame({"a": [1, 2, 2, 3]}) + >>> df.with_columns( + ... replaced=pl.col("a").replace_strict([1, 2, 3], [100, 200, 300]) + ... ) + shape: (4, 2) + ┌─────┬──────────┐ + │ a ┆ replaced │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪══════════╡ + │ 1 ┆ 100 │ + │ 2 ┆ 200 │ + │ 2 ┆ 200 │ + │ 3 ┆ 300 │ + └─────┴──────────┘ + + Passing a mapping with replacements is also supported as syntactic sugar. + + >>> mapping = {1: 100, 2: 200, 3: 300} + >>> df.with_columns(replaced=pl.col("a").replace_strict(mapping)) + shape: (4, 2) + ┌─────┬──────────┐ + │ a ┆ replaced │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪══════════╡ + │ 1 ┆ 100 │ + │ 2 ┆ 200 │ + │ 2 ┆ 200 │ + │ 3 ┆ 300 │ + └─────┴──────────┘ + + By default, an error is raised if any non-null values were not replaced. + Specify a default to set all values that were not matched. 
+ + >>> mapping = {2: 200, 3: 300} + >>> df.with_columns( + ... replaced=pl.col("a").replace_strict(mapping) + ... ) # doctest: +SKIP + Traceback (most recent call last): + ... + polars.exceptions.InvalidOperationError: incomplete mapping specified for `replace_strict` + >>> df.with_columns(replaced=pl.col("a").replace_strict(mapping, default=-1)) + shape: (4, 2) + ┌─────┬──────────┐ + │ a ┆ replaced │ + │ --- ┆ --- │ + │ i64 ┆ i64 │ + ╞═════╪══════════╡ + │ 1 ┆ -1 │ + │ 2 ┆ 200 │ + │ 2 ┆ 200 │ + │ 3 ┆ 300 │ + └─────┴──────────┘ + + Replacing by values of a different data type sets the return type based on + a combination of the `new` data type and the `default` data type. + + >>> df = pl.DataFrame({"a": ["x", "y", "z"]}) + >>> mapping = {"x": 1, "y": 2, "z": 3} + >>> df.with_columns(replaced=pl.col("a").replace_strict(mapping)) shape: (3, 2) ┌─────┬──────────┐ │ a ┆ replaced │ @@ -10133,11 +10311,22 @@ def replace( │ y ┆ 2 │ │ z ┆ 3 │ └─────┴──────────┘ + >>> df.with_columns(replaced=pl.col("a").replace_strict(mapping, default="x")) + shape: (3, 2) + ┌─────┬──────────┐ + │ a ┆ replaced │ + │ --- ┆ --- │ + │ str ┆ str │ + ╞═════╪══════════╡ + │ x ┆ 1 │ + │ y ┆ 2 │ + │ z ┆ 3 │ + └─────┴──────────┘ Set the `return_dtype` parameter to control the resulting data type directly. >>> df.with_columns( - ... replaced=pl.col("a").replace(mapping, return_dtype=pl.UInt8) + ... replaced=pl.col("a").replace_strict(mapping, return_dtype=pl.UInt8) ... ) shape: (3, 2) ┌─────┬──────────┐ @@ -10154,7 +10343,7 @@ def replace( >>> df = pl.DataFrame({"a": [1, 2, 2, 3], "b": [1.5, 2.5, 5.0, 1.0]}) >>> df.with_columns( - ... replaced=pl.col("a").replace( + ... replaced=pl.col("a").replace_strict( ... old=pl.col("a").max(), ... new=pl.col("b").sum(), ... 
default=pl.col("b"), @@ -10171,7 +10360,7 @@ def replace( │ 2 ┆ 5.0 ┆ 5.0 │ │ 3 ┆ 1.0 ┆ 10.0 │ └─────┴─────┴──────────┘ - """ + """ # noqa: W505 if new is no_default and isinstance(old, Mapping): new = pl.Series(old.values()) old = pl.Series(old.keys()) @@ -10185,7 +10374,9 @@ def replace( else parse_into_expression(default, str_as_lit=True) ) - return self._from_pyexpr(self._pyexpr.replace(old, new, default, return_dtype)) + return self._from_pyexpr( + self._pyexpr.replace_strict(old, new, default, return_dtype) + ) @deprecate_function( "Use `polars.plugins.register_plugin_function` instead.", version="0.20.16" diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 7ba1f9dab824..7aead22f05f7 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -6493,7 +6493,7 @@ def replace( return_dtype: PolarsDataType | None = None, ) -> Self: """ - Replace values by different values. + Replace values by different values of the same data type. Parameters ---------- @@ -6504,16 +6504,27 @@ def replace( new Value or sequence of values to replace by. Length must match the length of `old` or have length 1. + default Set values that were not replaced to this value. Defaults to keeping the original value. Accepts expression input. Non-expression inputs are parsed as literals. + + .. deprecated:: 1.0.0 + Use :meth:`replace_strict` instead to set a default while replacing values. + return_dtype - The data type of the resulting Series. If set to `None` (default), + The data type of the resulting expression. If set to `None` (default), the data type is determined automatically based on the other inputs. + .. deprecated:: 1.0.0 + Use :meth:`replace_strict` instead to set a return data type while + replacing values. + + See Also -------- + replace_strict str.replace Notes @@ -6549,48 +6560,142 @@ def replace( ] Passing a mapping with replacements is also supported as syntactic sugar. 
- Specify a default to set all values that were not matched. >>> mapping = {2: 100, 3: 200} - >>> s.replace(mapping, default=-1) + >>> s.replace(mapping) shape: (4,) Series: '' [i64] [ - -1 + 1 100 100 200 ] + The original data type is preserved when replacing by values of a different + data type. Use :meth:`replace_strict` to replace and change the return data + type. + + >>> s = pl.Series(["x", "y", "z"]) + >>> mapping = {"x": 1, "y": 2, "z": 3} + >>> s.replace(mapping) + shape: (3,) + Series: '' [str] + [ + "1" + "2" + "3" + ] + """ + + def replace_strict( + self, + old: IntoExpr | Sequence[Any] | Mapping[Any, Any], + new: IntoExpr | Sequence[Any] | NoDefault = no_default, + *, + default: IntoExpr | NoDefault = no_default, + return_dtype: PolarsDataType | None = None, + ) -> Self: + """ + Replace all values by different values. + + Parameters + ---------- + old + Value or sequence of values to replace. + Also accepts a mapping of values to their replacement as syntactic sugar for + `replace_strict(old=Series(mapping.keys()), new=Series(mapping.values()))`. + new + Value or sequence of values to replace by. + Length must match the length of `old` or have length 1. + default + Set values that were not replaced to this value. If no default is specified + (default), an error is raised if any values were not replaced. + Accepts expression input. Non-expression inputs are parsed as literals. + return_dtype + The data type of the resulting Series. If set to `None` (default), + the data type is determined automatically based on the other inputs. + + Raises + ------ + InvalidOperationError + If any non-null values in the original column were not replaced, and no + `default` was specified. + + See Also + -------- + replace + str.replace + + Notes + ----- + The global string cache must be enabled when replacing categorical values. + + Examples + -------- + Replace values by passing sequences to the `old` and `new` parameters. 
+ + >>> s = pl.Series([1, 2, 2, 3]) + >>> s.replace_strict([1, 2, 3], [100, 200, 300]) + shape: (4,) + Series: '' [i64] + [ + 100 + 200 + 200 + 300 + ] + + Passing a mapping with replacements is also supported as syntactic sugar. + + >>> mapping = {1: 100, 2: 200, 3: 300} + >>> s.replace_strict(mapping) + shape: (4,) + Series: '' [i64] + [ + 100 + 200 + 200 + 300 + ] + + By default, an error is raised if any non-null values were not replaced. + Specify a default to set all values that were not matched. + + >>> mapping = {2: 200, 3: 300} + >>> s.replace_strict(mapping) # doctest: +SKIP + Traceback (most recent call last): + ... + polars.exceptions.InvalidOperationError: incomplete mapping specified for `replace_strict` + >>> s.replace_strict(mapping, default=-1) + shape: (4,) + Series: '' [i64] + [ + -1 + 200 + 200 + 300 + ] The default can be another Series. >>> default = pl.Series([2.5, 5.0, 7.5, 10.0]) - >>> s.replace(2, 100, default=default) + >>> s.replace_strict(2, 200, default=default) shape: (4,) Series: '' [f64] [ 2.5 - 100.0 - 100.0 + 200.0 + 200.0 10.0 ] Replacing by values of a different data type sets the return type based on - a combination of the `new` data type and either the original data type or the - default data type if it was set. + a combination of the `new` data type and the `default` data type. >>> s = pl.Series(["x", "y", "z"]) >>> mapping = {"x": 1, "y": 2, "z": 3} - >>> s.replace(mapping) - shape: (3,) - Series: '' [str] - [ - "1" - "2" - "3" - ] - >>> s.replace(mapping, default=None) + >>> s.replace_strict(mapping) shape: (3,) Series: '' [i64] [ @@ -6598,10 +6703,18 @@ def replace( 2 3 ] + >>> s.replace_strict(mapping, default="x") + shape: (3,) + Series: '' [str] + [ + "1" + "2" + "3" + ] Set the `return_dtype` parameter to control the resulting data type directly. 
- >>> s.replace(mapping, return_dtype=pl.UInt8) + >>> s.replace_strict(mapping, return_dtype=pl.UInt8) shape: (3,) Series: '' [u8] [ @@ -6609,7 +6722,7 @@ def replace( 2 3 ] - """ + """ # noqa: W505 def reshape(self, dimensions: tuple[int, ...]) -> Series: """ From d1e868a8f9cd4c11deb3c99b59d712a579991e62 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Thu, 13 Jun 2024 10:17:15 +0200 Subject: [PATCH 3/5] Update Rust side --- crates/polars-ops/src/series/ops/replace.rs | 207 ++++++++++++++---- .../src/dsl/function_expr/dispatch.rs | 15 +- .../polars-plan/src/dsl/function_expr/mod.rs | 21 +- .../src/dsl/function_expr/schema.rs | 12 +- crates/polars-plan/src/dsl/mod.rs | 37 +++- py-polars/src/expr/general.rs | 8 +- py-polars/src/lazyframe/visitor/expr_nodes.rs | 5 +- 7 files changed, 246 insertions(+), 59 deletions(-) diff --git a/crates/polars-ops/src/series/ops/replace.rs b/crates/polars-ops/src/series/ops/replace.rs index c169bff7f70d..54ee17e99576 100644 --- a/crates/polars-ops/src/series/ops/replace.rs +++ b/crates/polars-ops/src/series/ops/replace.rs @@ -1,5 +1,3 @@ -use std::ops::BitOr; - use polars_core::prelude::*; use polars_core::utils::try_get_supertype; use polars_error::polars_ensure; @@ -7,7 +5,28 @@ use polars_error::polars_ensure; use crate::frame::join::*; use crate::prelude::*; -pub fn replace( +/// Replace values by different values of the same data type. +pub fn replace(s: &Series, old: &Series, new: &Series) -> PolarsResult { + if old.len() == 0 { + return Ok(s.clone()); + } + validate_old(old)?; + + let dtype = s.dtype(); + let old = cast_old_to_series_dtype(old, dtype)?; + let new = new.strict_cast(dtype)?; + + if new.len() == 1 { + replace_by_single(s, &old, &new, s) + } else { + replace_by_multiple(s, old, new, s) + } +} + +/// Replace all values by different values. +/// +/// Unmatched values are replaced by a default value. 
+pub fn replace_or_default( s: &Series, old: &Series, new: &Series, @@ -15,20 +34,15 @@ pub fn replace( return_dtype: Option, ) -> PolarsResult { polars_ensure!( - old.n_unique()? == old.len(), - ComputeError: "`old` input for `replace` must not contain duplicates" + default.len() == s.len() || default.len() == 1, + InvalidOperation: "`default` input for `replace_strict` must have the same length as the input or have length 1" ); + validate_old(old)?; let return_dtype = match return_dtype { Some(dtype) => dtype, None => try_get_supertype(new.dtype(), default.dtype())?, }; - - polars_ensure!( - default.len() == s.len() || default.len() == 1, - ComputeError: "`default` input for `replace` must have the same length as the input or have length 1" - ); - let default = default.cast(&return_dtype)?; if old.len() == 0 { @@ -37,18 +51,10 @@ pub fn replace( } else { default }; - return Ok(out); } - let old = match (s.dtype(), old.dtype()) { - #[cfg(feature = "dtype-categorical")] - (DataType::Categorical(_, ord), DataType::String) => { - let dt = DataType::Categorical(None, *ord); - old.strict_cast(&dt)? - }, - _ => old.strict_cast(s.dtype())?, - }; + let old = cast_old_to_series_dtype(old, s.dtype())?; let new = new.cast(&return_dtype)?; if new.len() == 1 { @@ -58,6 +64,58 @@ pub fn replace( } } +/// Replace all values by different values. +/// +/// Raises an error if not all values were replaced. 
+pub fn replace_strict( + s: &Series, + old: &Series, + new: &Series, + return_dtype: Option, +) -> PolarsResult { + if old.len() == 0 { + polars_ensure!( + s.len() == s.null_count(), + InvalidOperation: "must specify which values to replace" + ); + return Ok(s.clone()); + } + validate_old(old)?; + + let old = cast_old_to_series_dtype(old, s.dtype())?; + let new = match return_dtype { + Some(dtype) => new.strict_cast(&dtype)?, + None => new.clone(), + }; + + if new.len() == 1 { + replace_by_single_strict(s, &old, &new) + } else { + replace_by_multiple_strict(s, old, new) + } +} + +/// Validate the `old` input. +fn validate_old(old: &Series) -> PolarsResult<()> { + polars_ensure!( + old.n_unique()? == old.len(), + InvalidOperation: "`old` input for `replace` must not contain duplicates" + ); + Ok(()) +} + +/// Cast `old` input while enabling String to Categorical casts. +fn cast_old_to_series_dtype(old: &Series, dtype: &DataType) -> PolarsResult { + match (old.dtype(), dtype) { + #[cfg(feature = "dtype-categorical")] + (DataType::String, DataType::Categorical(_, ord)) => { + let empty_categorical_dtype = DataType::Categorical(None, *ord); + old.strict_cast(&empty_categorical_dtype) + }, + _ => old.strict_cast(dtype), + } +} + // Fast path for replacing by a single value fn replace_by_single( s: &Series, @@ -65,19 +123,36 @@ fn replace_by_single( new: &Series, default: &Series, ) -> PolarsResult { - let mask = if old.null_count() == old.len() { - s.is_null() - } else { - let mask = is_in(s, old)?; - - if old.null_count() == 0 { - mask - } else { - mask.bitor(s.is_null()) - } - }; + let mut mask = get_replacement_mask(s, old)?; + if old.null_count() > 0 { + mask = mask.fill_null_with_values(true)?; + } new.zip_with(&mask, default) } +/// Fast path for replacing by a single value in strict mode +fn replace_by_single_strict(s: &Series, old: &Series, new: &Series) -> PolarsResult { + let mask = get_replacement_mask(s, old)?; + ensure_all_replaced(&mask, s, 
old.null_count() > 0)?; + + let mut out = new.new_from_index(0, s.len()); + + // Transfer validity from `mask` to `out`. + if mask.null_count() > 0 { + out = out.zip_with(&mask, &Series::new_null("", s.len()))? + } + Ok(out) +} +/// Get a boolean mask of which values in the original Series will be replaced. +/// +/// Null values are propagated to the mask. +fn get_replacement_mask(s: &Series, old: &Series) -> PolarsResult { + if old.null_count() == old.len() { + // Fast path for when users are using `replace(None, ...)` instead of `fill_null`. + Ok(s.is_null()) + } else { + is_in(s, old) + } +} /// General case for replacing by multiple values fn replace_by_multiple( @@ -86,13 +161,11 @@ fn replace_by_multiple( new: Series, default: &Series, ) -> PolarsResult { - polars_ensure!( - new.len() == old.len(), - ComputeError: "`new` input for `replace` must have the same length as `old` or have length 1" - ); + validate_new(&new, &old)?; let df = s.clone().into_frame(); - let replacer = create_replacer(old, new)?; + let add_replacer_mask = new.null_count() > 0; + let replacer = create_replacer(old, new, add_replacer_mask)?; let joined = df.join( &replacer, @@ -100,6 +173,7 @@ fn replace_by_multiple( ["__POLARS_REPLACE_OLD"], JoinArgs { how: JoinType::Left, + coalesce: JoinCoalesce::CoalesceColumns, join_nulls: true, ..Default::default() }, @@ -123,12 +197,44 @@ fn replace_by_multiple( } } -// Build replacer dataframe -fn create_replacer(mut old: Series, mut new: Series) -> PolarsResult { +/// General case for replacing by multiple values in strict mode +fn replace_by_multiple_strict(s: &Series, old: Series, new: Series) -> PolarsResult { + validate_new(&new, &old)?; + + let df = s.clone().into_frame(); + let old_has_null = old.null_count() > 0; + let replacer = create_replacer(old, new, true)?; + + let joined = df.join( + &replacer, + [s.name()], + ["__POLARS_REPLACE_OLD"], + JoinArgs { + how: JoinType::Left, + coalesce: JoinCoalesce::CoalesceColumns, + join_nulls: true, 
+ ..Default::default() + }, + )?; + + let replaced = joined.column("__POLARS_REPLACE_NEW").unwrap(); + + let mask = joined + .column("__POLARS_REPLACE_MASK") + .unwrap() + .bool() + .unwrap(); + ensure_all_replaced(mask, s, old_has_null)?; + + Ok(replaced.clone()) +} + +// Build replacer dataframe. +fn create_replacer(mut old: Series, mut new: Series, add_mask: bool) -> PolarsResult { old.rename("__POLARS_REPLACE_OLD"); new.rename("__POLARS_REPLACE_NEW"); - let cols = if new.null_count() > 0 { + let cols = if add_mask { let mask = Series::new("__POLARS_REPLACE_MASK", &[true]).new_from_index(0, new.len()); vec![old, new, mask] } else { @@ -137,3 +243,26 @@ fn create_replacer(mut old: Series, mut new: Series) -> PolarsResult let out = unsafe { DataFrame::new_no_checks(cols) }; Ok(out) } + +/// Validate the `new` input. +fn validate_new(new: &Series, old: &Series) -> PolarsResult<()> { + polars_ensure!( + new.len() == old.len(), + InvalidOperation: "`new` input for `replace` must have the same length as `old` or have length 1" + ); + Ok(()) +} + +/// Ensure that all values were replaced. +fn ensure_all_replaced(mask: &BooleanChunked, s: &Series, old_has_null: bool) -> PolarsResult<()> { + let all_replaced = if old_has_null { + mask.null_count() == 0 + } else { + mask.null_count() == s.null_count() + }; + polars_ensure!( + all_replaced, + InvalidOperation: "incomplete mapping specified for `replace_strict`\n\nHint: Pass a `default` value to set unmapped values." 
+ ); + Ok(()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/dispatch.rs b/crates/polars-plan/src/dsl/function_expr/dispatch.rs index 0be7f34ff86e..cd82ae4251d8 100644 --- a/crates/polars-plan/src/dsl/function_expr/dispatch.rs +++ b/crates/polars-plan/src/dsl/function_expr/dispatch.rs @@ -156,9 +156,18 @@ pub(super) fn hist( } #[cfg(feature = "replace")] -pub(super) fn replace(s: &[Series], return_dtype: Option) -> PolarsResult { - let default = if let Some(s) = s.get(3) { s } else { &s[0] }; - polars_ops::series::replace(&s[0], &s[1], &s[2], default, return_dtype) +pub(super) fn replace(s: &[Series]) -> PolarsResult { + polars_ops::series::replace(&s[0], &s[1], &s[2]) +} + +#[cfg(feature = "replace")] +pub(super) fn replace_strict(s: &[Series], return_dtype: Option) -> PolarsResult { + match s.get(3) { + Some(default) => { + polars_ops::series::replace_or_default(&s[0], &s[1], &s[2], default, return_dtype) + }, + None => polars_ops::series::replace_strict(&s[0], &s[1], &s[2], return_dtype), + } } pub(super) fn fill_null_with_strategy( diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index 0dd79141ea02..dfc390d0b1dc 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -340,7 +340,9 @@ pub enum FunctionExpr { options: EWMOptions, }, #[cfg(feature = "replace")] - Replace { + Replace, + #[cfg(feature = "replace")] + ReplaceStrict { return_dtype: Option, }, GatherEvery { @@ -567,7 +569,9 @@ impl Hash for FunctionExpr { include_breakpoint.hash(state); }, #[cfg(feature = "replace")] - Replace { return_dtype } => return_dtype.hash(state), + Replace => {}, + #[cfg(feature = "replace")] + ReplaceStrict { return_dtype } => return_dtype.hash(state), FillNullWithStrategy(strategy) => strategy.hash(state), GatherEvery { n, offset } => (n, offset).hash(state), #[cfg(feature = "reinterpret")] @@ -752,7 +756,9 @@ impl Display for FunctionExpr { 
#[cfg(feature = "hist")] Hist { .. } => "hist", #[cfg(feature = "replace")] - Replace { .. } => "replace", + Replace => "replace", + #[cfg(feature = "replace")] + ReplaceStrict { .. } => "replace_strict", FillNullWithStrategy(_) => "fill_null_with_strategy", GatherEvery { .. } => "gather_every", #[cfg(feature = "reinterpret")] @@ -1138,9 +1144,14 @@ impl From for SpecialEq> { #[cfg(feature = "ewma")] EwmVar { options } => map!(ewm::ewm_var, options), #[cfg(feature = "replace")] - Replace { return_dtype } => { - map_as_slice!(dispatch::replace, return_dtype.clone()) + Replace => { + map_as_slice!(dispatch::replace) }, + #[cfg(feature = "replace")] + ReplaceStrict { return_dtype } => { + map_as_slice!(dispatch::replace_strict, return_dtype.clone()) + }, + FillNullWithStrategy(strategy) => map!(dispatch::fill_null_with_strategy, strategy), GatherEvery { n, offset } => map!(dispatch::gather_every, n, offset), #[cfg(feature = "reinterpret")] diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index 686aea7f59c1..a385c27820d6 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -314,7 +314,9 @@ impl FunctionExpr { #[cfg(feature = "ewma")] EwmVar { .. } => mapper.map_to_float_dtype(), #[cfg(feature = "replace")] - Replace { return_dtype } => mapper.replace_dtype(return_dtype.clone()), + Replace => mapper.with_same_dtype(), + #[cfg(feature = "replace")] + ReplaceStrict { return_dtype } => mapper.replace_dtype(return_dtype.clone()), FillNullWithStrategy(_) => mapper.with_same_dtype(), GatherEvery { .. 
} => mapper.with_same_dtype(), #[cfg(feature = "reinterpret")] @@ -532,11 +534,13 @@ impl<'a> FieldsMapper<'a> { pub fn replace_dtype(&self, return_dtype: Option) -> PolarsResult { let dtype = match return_dtype { Some(dtype) => dtype, - // Supertype of `new` and `default` None => { - let column_type = &self.fields[0]; let new = &self.fields[2]; - try_get_supertype(column_type.data_type(), new.data_type())? + let default = self.fields.get(3); + match default { + Some(default) => try_get_supertype(default.data_type(), new.data_type())?, + None => new.data_type().clone(), + } }, }; self.with_dtype(dtype) diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 1e3c71460905..5c591f1e0e03 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -1495,7 +1495,25 @@ impl Expr { #[cfg(feature = "replace")] /// Replace the given values with other values. - pub fn replace>( + pub fn replace>(self, old: E, new: E) -> Expr { + let old = old.into(); + let new = new.into(); + + // If we search and replace by literals, we can run on batches. + let literal_searchers = matches!(&old, Expr::Literal(_)) & matches!(&new, Expr::Literal(_)); + + let args = [old, new]; + + if literal_searchers { + self.map_many_private(FunctionExpr::Replace, &args, false, false) + } else { + self.apply_many_private(FunctionExpr::Replace, &args, false, false) + } + } + + #[cfg(feature = "replace")] + /// Replace the given values with other values. + pub fn replace_strict>( self, old: E, new: E, @@ -1504,7 +1522,8 @@ impl Expr { ) -> Expr { let old = old.into(); let new = new.into(); - // If we search and replace by literals, we can run on batches. + + // If we replace by literals, we can run on batches. 
let literal_searchers = matches!(&old, Expr::Literal(_)) & matches!(&new, Expr::Literal(_)); let mut args = vec![old, new]; @@ -1513,9 +1532,19 @@ impl Expr { } if literal_searchers { - self.map_many_private(FunctionExpr::Replace { return_dtype }, &args, false, false) + self.map_many_private( + FunctionExpr::ReplaceStrict { return_dtype }, + &args, + false, + false, + ) } else { - self.apply_many_private(FunctionExpr::Replace { return_dtype }, &args, false, false) + self.apply_many_private( + FunctionExpr::ReplaceStrict { return_dtype }, + &args, + false, + false, + ) } } diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index d4f7c5cd5d83..7028d434ae7e 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -924,7 +924,11 @@ impl PyExpr { self.inner.clone().set_sorted_flag(is_sorted).into() } - fn replace( + fn replace(&self, old: PyExpr, new: PyExpr) -> Self { + self.inner.clone().replace(old.inner, new.inner).into() + } + + fn replace_strict( &self, old: PyExpr, new: PyExpr, @@ -933,7 +937,7 @@ impl PyExpr { ) -> Self { self.inner .clone() - .replace( + .replace_strict( old.inner, new.inner, default.map(|e| e.inner), diff --git a/py-polars/src/lazyframe/visitor/expr_nodes.rs b/py-polars/src/lazyframe/visitor/expr_nodes.rs index 51aed93d2fc1..691be602ee72 100644 --- a/py-polars/src/lazyframe/visitor/expr_nodes.rs +++ b/py-polars/src/lazyframe/visitor/expr_nodes.rs @@ -1218,8 +1218,9 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { FunctionExpr::EwmVar { options: _ } => { return Err(PyNotImplementedError::new_err("ewm var")) }, - FunctionExpr::Replace { return_dtype: _ } => { - return Err(PyNotImplementedError::new_err("replace")) + FunctionExpr::Replace => return Err(PyNotImplementedError::new_err("replace")), + FunctionExpr::ReplaceStrict { return_dtype: _ } => { + return Err(PyNotImplementedError::new_err("replace_strict")) }, FunctionExpr::Negate => return 
Err(PyNotImplementedError::new_err("negate")), FunctionExpr::FillNullWithStrategy(_) => { From 62463f8cd7b44cff800a4fd3087504bfffe56d72 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Thu, 13 Jun 2024 10:17:20 +0200 Subject: [PATCH 4/5] Update tests --- .../map/test_inefficient_map_warning.py | 4 +- .../tests/unit/operations/test_replace.py | 349 ++--------------- .../unit/operations/test_replace_strict.py | 358 ++++++++++++++++++ 3 files changed, 391 insertions(+), 320 deletions(-) create mode 100644 py-polars/tests/unit/operations/test_replace_strict.py diff --git a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py index adf75e8fbe85..4c6877d08694 100644 --- a/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py +++ b/py-polars/tests/unit/operations/map/test_inefficient_map_warning.py @@ -184,11 +184,11 @@ # --------------------------------------------- # replace # --------------------------------------------- - ("a", "lambda x: MY_DICT[x]", 'pl.col("a").replace(MY_DICT)'), + ("a", "lambda x: MY_DICT[x]", 'pl.col("a").replace_strict(MY_DICT)'), ( "a", "lambda x: MY_DICT[x - 1] + MY_DICT[1 + x]", - '(pl.col("a") - 1).replace(MY_DICT) + (1 + pl.col("a")).replace(MY_DICT)', + '(pl.col("a") - 1).replace_strict(MY_DICT) + (1 + pl.col("a")).replace_strict(MY_DICT)', ), # --------------------------------------------- # standard library datetime parsing diff --git a/py-polars/tests/unit/operations/test_replace.py b/py-polars/tests/unit/operations/test_replace.py index f52bbc5b9488..03d1feb2681c 100644 --- a/py-polars/tests/unit/operations/test_replace.py +++ b/py-polars/tests/unit/operations/test_replace.py @@ -1,16 +1,11 @@ from __future__ import annotations -import contextlib from typing import Any import pytest import polars as pl -from polars.exceptions import ( - CategoricalRemappingWarning, - ComputeError, - InvalidOperationError, -) +from polars.exceptions 
import InvalidOperationError from polars.testing import assert_frame_equal, assert_series_equal @@ -31,44 +26,6 @@ def test_replace_str_to_str(str_mapping: dict[str | None, str]) -> None: assert_frame_equal(result, expected) -def test_replace_str_to_str_default_self(str_mapping: dict[str | None, str]) -> None: - df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) - result = df.select( - replaced=pl.col("country_code").replace( - str_mapping, default=pl.col("country_code") - ) - ) - expected = pl.DataFrame({"replaced": ["France", "Not specified", "ES", "Germany"]}) - assert_frame_equal(result, expected) - - -def test_replace_str_to_str_default_null(str_mapping: dict[str | None, str]) -> None: - df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) - result = df.select( - replaced=pl.col("country_code").replace(str_mapping, default=None) - ) - expected = pl.DataFrame({"replaced": ["France", "Not specified", None, "Germany"]}) - assert_frame_equal(result, expected) - - -def test_replace_str_to_str_default_other(str_mapping: dict[str | None, str]) -> None: - df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) - - result = df.with_row_index().select( - replaced=pl.col("country_code").replace(str_mapping, default=pl.col("index")) - ) - expected = pl.DataFrame({"replaced": ["France", "Not specified", "2", "Germany"]}) - assert_frame_equal(result, expected) - - -def test_replace_str_to_cat() -> None: - s = pl.Series(["a", "b", "c"]) - mapping = {"a": "c", "b": "d"} - result = s.replace(mapping, return_dtype=pl.Categorical) - expected = pl.Series(["c", "d", "c"], dtype=pl.Categorical) - assert_series_equal(result, expected, categorical_as_str=True) - - def test_replace_enum() -> None: dtype = pl.Enum(["a", "b", "c", "d"]) s = pl.Series(["a", "b", "c"], dtype=dtype) @@ -87,20 +44,7 @@ def test_replace_enum_to_str() -> None: result = s.replace({"a": "c", "b": "d"}) - expected = pl.Series(["c", "d", "c"], dtype=pl.String) - 
assert_series_equal(result, expected) - - -def test_replace_enum_to_new_enum() -> None: - s = pl.Series(["a", "b", "c"], dtype=pl.Enum(["a", "b", "c", "d"])) - old = ["a", "b"] - - new_dtype = pl.Enum(["a", "b", "c", "d", "e"]) - new = pl.Series(["c", "e"], dtype=new_dtype) - - result = s.replace(old, new, return_dtype=new_dtype) - - expected = pl.Series(["c", "e", "c"], dtype=new_dtype) + expected = pl.Series(["c", "d", "c"], dtype=dtype) assert_series_equal(result, expected) @@ -138,7 +82,7 @@ def test_replace_int_to_int() -> None: mapping = {1: 5, 3: 7} result = df.select(replaced=pl.col("int").replace(mapping)) expected = pl.DataFrame( - {"replaced": [None, 5, None, 7]}, schema={"replaced": pl.Int64} + {"replaced": [None, 5, None, 7]}, schema={"replaced": pl.Int16} ) assert_frame_equal(result, expected) @@ -155,56 +99,22 @@ def test_replace_int_to_int_keep_dtype() -> None: assert_frame_equal(result, expected) -def test_replace_int_to_str2() -> None: +def test_replace_int_to_str() -> None: df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) mapping = {1: "b", 3: "d"} - result = df.select(replaced=pl.col("int").replace(mapping)) - expected = pl.DataFrame({"replaced": [None, "b", None, "d"]}) - assert_frame_equal(result, expected) + with pytest.raises( + InvalidOperationError, match="conversion from `str` to `i16` failed" + ): + df.select(replaced=pl.col("int").replace(mapping)) def test_replace_int_to_str_with_null() -> None: df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) mapping = {1: "b", 3: "d", None: "e"} - result = df.select(replaced=pl.col("int").replace(mapping)) - expected = pl.DataFrame({"replaced": ["e", "b", "e", "d"]}) - assert_frame_equal(result, expected) - - -def test_replace_int_to_int_null() -> None: - df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) - mapping = {3: None} - result = df.select( - replaced=pl.col("int").replace(mapping, default=pl.lit(6).cast(pl.Int16)) - ) - 
expected = pl.DataFrame( - {"replaced": [6, 6, 6, None]}, schema={"replaced": pl.Int16} - ) - assert_frame_equal(result, expected) - - -def test_replace_int_to_int_null_default_null() -> None: - df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) - mapping = {3: None} - result = df.select(replaced=pl.col("int").replace(mapping, default=None)) - expected = pl.DataFrame( - {"replaced": [None, None, None, None]}, schema={"replaced": pl.Null} - ) - assert_frame_equal(result, expected) - - -def test_replace_int_to_int_null_return_dtype() -> None: - df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) - mapping = {3: None} - - result = df.select( - replaced=pl.col("int").replace(mapping, default=6, return_dtype=pl.Int32) - ) - - expected = pl.DataFrame( - {"replaced": [6, 6, 6, None]}, schema={"replaced": pl.Int32} - ) - assert_frame_equal(result, expected) + with pytest.raises( + InvalidOperationError, match="conversion from `str` to `i16` failed" + ): + df.select(replaced=pl.col("int").replace(mapping)) def test_replace_empty_mapping() -> None: @@ -214,14 +124,6 @@ def test_replace_empty_mapping() -> None: assert_frame_equal(result, df) -def test_replace_empty_mapping_default() -> None: - df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) - mapping: dict[Any, Any] = {} - result = df.select(pl.col("int").replace(mapping, default=pl.lit("A"))) - expected = pl.DataFrame({"int": ["A", "A", "A", "A"]}) - assert_frame_equal(result, expected) - - def test_replace_mapping_different_dtype_str_int() -> None: df = pl.DataFrame({"int": [None, "1", None, "3"]}) mapping = {1: "b", 3: "d"} @@ -256,60 +158,6 @@ def test_replace_str_to_str_replace_all() -> None: assert_frame_equal(result, expected) -def test_replace_int_to_int_df() -> None: - lf = pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt8}) - mapping = {1: 11, 2: 22} - - result = lf.select( - pl.col("a").replace( - old=pl.Series(mapping.keys()), - 
new=pl.Series(mapping.values(), dtype=pl.UInt8), - default=pl.lit(99).cast(pl.UInt8), - ) - ) - expected = pl.LazyFrame({"a": [11, 22, 99]}, schema_overrides={"a": pl.UInt8}) - assert_frame_equal(result, expected) - - -def test_replace_str_to_int_fill_null() -> None: - lf = pl.LazyFrame({"a": ["one", "two"]}) - mapping = {"one": 1} - - result = lf.select( - pl.col("a") - .replace(mapping, default=None, return_dtype=pl.UInt32) - .fill_null(999) - ) - - expected = pl.LazyFrame({"a": pl.Series([1, 999], dtype=pl.UInt32)}) - assert_frame_equal(result, expected) - - -def test_replace_mix() -> None: - df = pl.DataFrame( - [ - pl.Series("float_to_boolean", [1.0, None]), - pl.Series("boolean_to_int", [True, False]), - pl.Series("boolean_to_str", [True, False]), - ] - ) - - result = df.with_columns( - pl.col("float_to_boolean").replace({1.0: True}, default=None), - pl.col("boolean_to_int").replace({True: 1, False: 0}), - pl.col("boolean_to_str").replace({True: "1", False: "0"}), - ) - - expected = pl.DataFrame( - [ - pl.Series("float_to_boolean", [True, None], dtype=pl.Boolean), - pl.Series("boolean_to_int", [1, 0], dtype=pl.Int64), - pl.Series("boolean_to_str", ["1", "0"], dtype=pl.String), - ] - ) - assert_frame_equal(result, expected) - - @pytest.fixture(scope="module") def int_mapping() -> dict[int, int]: return {1: 11, 2: 22, 3: 33, 4: 44, 5: 55} @@ -322,20 +170,6 @@ def test_replace_int_to_int1(int_mapping: dict[int, int]) -> None: assert_series_equal(result, expected) -def test_replace_int_to_int2(int_mapping: dict[int, int]) -> None: - s = pl.Series([1, 22, None, 44, -5]) - result = s.replace(int_mapping, default=None) - expected = pl.Series([11, None, None, None, None], dtype=pl.Int64) - assert_series_equal(result, expected) - - -def test_replace_int_to_int3(int_mapping: dict[int, int]) -> None: - s = pl.Series([1, 22, None, 44, -5], dtype=pl.Int16) - result = s.replace(int_mapping, default=9) - expected = pl.Series([11, 9, 9, 9, 9], dtype=pl.Int64) - 
assert_series_equal(result, expected) - - def test_replace_int_to_int4(int_mapping: dict[int, int]) -> None: s = pl.Series([-1, 22, None, 44, -5]) result = s.replace(int_mapping) @@ -343,72 +177,6 @@ def test_replace_int_to_int4(int_mapping: dict[int, int]) -> None: assert_series_equal(result, expected) -def test_replace_int_to_int4_return_dtype(int_mapping: dict[int, int]) -> None: - s = pl.Series([-1, 22, None, 44, -5], dtype=pl.Int16) - result = s.replace(int_mapping, return_dtype=pl.Float32) - expected = pl.Series([-1.0, 22.0, None, 44.0, -5.0], dtype=pl.Float32) - assert_series_equal(result, expected) - - -def test_replace_int_to_int5_return_dtype(int_mapping: dict[int, int]) -> None: - s = pl.Series([1, 22, None, 44, -5], dtype=pl.Int16) - result = s.replace(int_mapping, default=9, return_dtype=pl.Float32) - expected = pl.Series([11.0, 9.0, 9.0, 9.0, 9.0], dtype=pl.Float32) - assert_series_equal(result, expected) - - -def test_replace_bool_to_int() -> None: - s = pl.Series([True, False, False, None]) - mapping = {True: 1, False: 0} - result = s.replace(mapping) - expected = pl.Series([1, 0, 0, None]) - assert_series_equal(result, expected) - - -def test_replace_bool_to_str() -> None: - s = pl.Series([True, False, False, None]) - mapping = {True: "1", False: "0"} - result = s.replace(mapping) - expected = pl.Series(["1", "0", "0", None]) - assert_series_equal(result, expected) - - -def test_replace_str_to_bool_without_default() -> None: - s = pl.Series(["True", "False", "False", None]) - mapping = {"True": True, "False": False} - result = s.replace(mapping) - expected = pl.Series(["true", "false", "false", None]) - assert_series_equal(result, expected) - - -def test_replace_str_to_bool_with_default() -> None: - s = pl.Series(["True", "False", "False", None]) - mapping = {"True": True, "False": False} - result = s.replace(mapping, default=None) - expected = pl.Series([True, False, False, None]) - assert_series_equal(result, expected) - - -def 
test_replace_int_to_str() -> None: - s = pl.Series("a", [-1, 2, None, 4, -5]) - mapping = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} - - result = s.replace(mapping) - - expected = pl.Series("a", ["-1", "two", None, "four", "-5"]) - assert_series_equal(result, expected) - - -def test_replace_int_to_str_with_default() -> None: - s = pl.Series("a", [1, 2, None, 4, 5]) - mapping = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} - - result = s.replace(mapping, default="?") - - expected = pl.Series("a", ["one", "two", "?", "four", "five"]) - assert_series_equal(result, expected) - - # https://github.com/pola-rs/polars/issues/12728 def test_replace_str_to_int2() -> None: s = pl.Series(["a", "b"]) @@ -418,11 +186,11 @@ def test_replace_str_to_int2() -> None: assert_series_equal(result, expected) -def test_replace_str_to_int_with_default() -> None: - s = pl.Series(["a", "b"]) - mapping = {"a": 1, "b": 2} - result = s.replace(mapping, default=None) - expected = pl.Series([1, 2]) +def test_replace_str_to_bool_without_default() -> None: + s = pl.Series(["True", "False", "False", None]) + mapping = {"True": True, "False": False} + result = s.replace(mapping) + expected = pl.Series(["true", "false", "false", None]) assert_series_equal(result, expected) @@ -442,7 +210,7 @@ def test_replace_old_new_many_to_one() -> None: def test_replace_old_new_mismatched_lengths() -> None: s = pl.Series([1, 2, 2, 3, 4]) - with pytest.raises(ComputeError): + with pytest.raises(InvalidOperationError): s.replace([2, 3, 4], [8, 9]) @@ -475,20 +243,6 @@ def test_replace_fast_path_many_to_one() -> None: assert_frame_equal(result, expected) -def test_replace_fast_path_many_to_one_default() -> None: - lf = pl.LazyFrame({"a": [1, 2, 2, 3]}) - result = lf.select(pl.col("a").replace([2, 3], 100, default=-1)) - expected = pl.LazyFrame({"a": [-1, 100, 100, 100]}, schema={"a": pl.Int64}) - assert_frame_equal(result, expected) - - -def test_replace_fast_path_many_to_one_null() -> None: - lf 
= pl.LazyFrame({"a": [1, 2, 2, 3]}) - result = lf.select(pl.col("a").replace([2, 3], None, default=-1)) - expected = pl.LazyFrame({"a": [-1, None, None, None]}, schema={"a": pl.Int64}) - assert_frame_equal(result, expected) - - @pytest.mark.parametrize( ("old", "new"), [ @@ -500,7 +254,7 @@ def test_replace_fast_path_many_to_one_null() -> None: def test_replace_duplicates_old(old: list[int], new: int | list[int]) -> None: s = pl.Series([1, 2, 3, 2, 3]) with pytest.raises( - ComputeError, + InvalidOperationError, match="`old` input for `replace` must not contain duplicates", ): s.replace(old, new) @@ -513,58 +267,17 @@ def test_replace_duplicates_new() -> None: assert_series_equal(result, expected) -@pytest.mark.parametrize( - ("context", "dtype"), - [ - (pl.StringCache(), pl.Categorical), - (pytest.warns(CategoricalRemappingWarning), pl.Categorical), - (contextlib.nullcontext(), pl.Enum(["a", "b", "OTHER"])), - ], -) -def test_replace_cat_str( - context: contextlib.AbstractContextManager, # type: ignore[type-arg] - dtype: pl.DataType, -) -> None: - with context: - for old, new, expected in [ - ("a", "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), - (["a", "b"], ["c", "d"], pl.Series("s", ["c", "d"], dtype=pl.Utf8)), - (pl.lit("a", dtype=dtype), "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), - ( - pl.Series(["a", "b"], dtype=dtype), - ["c", "d"], - pl.Series("s", ["c", "d"], dtype=pl.Utf8), - ), - ]: - s = pl.Series("s", ["a", "b"], dtype=dtype) - s_replaced = s.replace(old, new, default=None) # type: ignore[arg-type] - assert_series_equal(s_replaced, expected) - - s = pl.Series("s", ["a", "b"], dtype=dtype) - s_replaced = s.replace(old, new, default="OTHER") # type: ignore[arg-type] - assert_series_equal(s_replaced, expected.fill_null("OTHER")) +def test_replace_return_dtype_deprecated() -> None: + s = pl.Series([1, 2, 3]) + with pytest.deprecated_call(): + result = s.replace(1, 10, return_dtype=pl.Int8) + expected = pl.Series([10, 2, 3], dtype=pl.Int8) + 
assert_series_equal(result, expected) -@pytest.mark.parametrize( - "context", [pl.StringCache(), pytest.warns(CategoricalRemappingWarning)] -) -def test_replace_cat_cat( - context: contextlib.AbstractContextManager, # type: ignore[type-arg] -) -> None: - with context: - dt = pl.Categorical - for old, new, expected in [ - ("a", pl.lit("c", dtype=dt), pl.Series("s", ["c", None], dtype=dt)), - ( - ["a", "b"], - pl.Series(["c", "d"], dtype=dt), - pl.Series("s", ["c", "d"], dtype=dt), - ), - ]: - s = pl.Series("s", ["a", "b"], dtype=dt) - s_replaced = s.replace(old, new, default=None) # type: ignore[arg-type] - assert_series_equal(s_replaced, expected) - - s = pl.Series("s", ["a", "b"], dtype=dt) - s_replaced = s.replace(old, new, default=pl.lit("OTHER", dtype=dt)) # type: ignore[arg-type] - assert_series_equal(s_replaced, expected.fill_null("OTHER")) +def test_replace_default_deprecated() -> None: + s = pl.Series([1, 2, 3]) + with pytest.deprecated_call(): + result = s.replace(1, 10, default=None) + expected = pl.Series([10, None, None], dtype=pl.Int32) + assert_series_equal(result, expected) diff --git a/py-polars/tests/unit/operations/test_replace_strict.py b/py-polars/tests/unit/operations/test_replace_strict.py new file mode 100644 index 000000000000..1ed0684599ed --- /dev/null +++ b/py-polars/tests/unit/operations/test_replace_strict.py @@ -0,0 +1,358 @@ +from __future__ import annotations + +import contextlib +from typing import Any + +import pytest + +import polars as pl +from polars.exceptions import CategoricalRemappingWarning, InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal + + +def test_replace_strict_incomplete_mapping() -> None: + lf = pl.LazyFrame({"a": [1, 2, 2, 3]}) + + with pytest.raises(InvalidOperationError, match="incomplete mapping"): + lf.select(pl.col("a").replace_strict({2: 200, 3: 300})).collect() + + +def test_replace_strict_empty() -> None: + lf = pl.LazyFrame({"a": [None, None]}) + result = 
lf.select(pl.col("a").replace_strict({})) + assert_frame_equal(lf, result) + + +def test_replace_strict_fast_path_many_to_one_default() -> None: + lf = pl.LazyFrame({"a": [1, 2, 2, 3]}) + result = lf.select(pl.col("a").replace_strict([2, 3], 100, default=-1)) + expected = pl.LazyFrame({"a": [-1, 100, 100, 100]}, schema={"a": pl.Int32}) + assert_frame_equal(result, expected) + + +def test_replace_strict_fast_path_many_to_one_null() -> None: + lf = pl.LazyFrame({"a": [1, 2, 2, 3]}) + result = lf.select(pl.col("a").replace_strict([2, 3], None, default=-1)) + expected = pl.LazyFrame({"a": [-1, None, None, None]}, schema={"a": pl.Int32}) + assert_frame_equal(result, expected) + + +@pytest.fixture(scope="module") +def str_mapping() -> dict[str | None, str]: + return { + "CA": "Canada", + "DE": "Germany", + "FR": "France", + None: "Not specified", + } + + +def test_replace_strict_str_to_str_default_self( + str_mapping: dict[str | None, str], +) -> None: + df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) + result = df.select( + replaced=pl.col("country_code").replace_strict( + str_mapping, default=pl.col("country_code") + ) + ) + expected = pl.DataFrame({"replaced": ["France", "Not specified", "ES", "Germany"]}) + assert_frame_equal(result, expected) + + +def test_replace_strict_str_to_str_default_null( + str_mapping: dict[str | None, str], +) -> None: + df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) + result = df.select( + replaced=pl.col("country_code").replace_strict(str_mapping, default=None) + ) + expected = pl.DataFrame({"replaced": ["France", "Not specified", None, "Germany"]}) + assert_frame_equal(result, expected) + + +def test_replace_strict_str_to_str_default_other( + str_mapping: dict[str | None, str], +) -> None: + df = pl.DataFrame({"country_code": ["FR", None, "ES", "DE"]}) + + result = df.with_row_index().select( + replaced=pl.col("country_code").replace_strict( + str_mapping, default=pl.col("index") + ) + ) + expected = 
pl.DataFrame({"replaced": ["France", "Not specified", "2", "Germany"]}) + assert_frame_equal(result, expected) + + +def test_replace_strict_str_to_cat() -> None: + s = pl.Series(["a", "b", "c"]) + mapping = {"a": "c", "b": "d"} + result = s.replace_strict(mapping, default=None, return_dtype=pl.Categorical) + expected = pl.Series(["c", "d", None], dtype=pl.Categorical) + assert_series_equal(result, expected, categorical_as_str=True) + + +def test_replace_strict_enum_to_new_enum() -> None: + s = pl.Series(["a", "b", "c"], dtype=pl.Enum(["a", "b", "c", "d"])) + old = ["a", "b"] + + new_dtype = pl.Enum(["a", "b", "c", "d", "e"]) + new = pl.Series(["c", "e"], dtype=new_dtype) + + result = s.replace_strict(old, new, default=None, return_dtype=new_dtype) + + expected = pl.Series(["c", "e", None], dtype=new_dtype) + assert_series_equal(result, expected) + + +def test_replace_strict_int_to_int_null() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping = {3: None} + result = df.select( + replaced=pl.col("int").replace_strict(mapping, default=pl.lit(6).cast(pl.Int16)) + ) + expected = pl.DataFrame( + {"replaced": [6, 6, 6, None]}, schema={"replaced": pl.Int16} + ) + assert_frame_equal(result, expected) + + +def test_replace_strict_int_to_int_null_default_null() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping = {3: None} + result = df.select(replaced=pl.col("int").replace_strict(mapping, default=None)) + expected = pl.DataFrame( + {"replaced": [None, None, None, None]}, schema={"replaced": pl.Null} + ) + assert_frame_equal(result, expected) + + +def test_replace_strict_int_to_int_null_return_dtype() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping = {3: None} + + result = df.select( + replaced=pl.col("int").replace_strict(mapping, default=6, return_dtype=pl.Int32) + ) + + expected = pl.DataFrame( + {"replaced": [6, 6, 6, None]}, 
schema={"replaced": pl.Int32} + ) + assert_frame_equal(result, expected) + + +def test_replace_strict_empty_mapping_default() -> None: + df = pl.DataFrame({"int": [None, 1, None, 3]}, schema={"int": pl.Int16}) + mapping: dict[Any, Any] = {} + result = df.select(pl.col("int").replace_strict(mapping, default=pl.lit("A"))) + expected = pl.DataFrame({"int": ["A", "A", "A", "A"]}) + assert_frame_equal(result, expected) + + +def test_replace_strict_int_to_int_df() -> None: + lf = pl.LazyFrame({"a": [1, 2, 3]}, schema={"a": pl.UInt8}) + mapping = {1: 11, 2: 22} + + result = lf.select( + pl.col("a").replace_strict( + old=pl.Series(mapping.keys()), + new=pl.Series(mapping.values(), dtype=pl.UInt8), + default=pl.lit(99).cast(pl.UInt8), + ) + ) + expected = pl.LazyFrame({"a": [11, 22, 99]}, schema_overrides={"a": pl.UInt8}) + assert_frame_equal(result, expected) + + +def test_replace_strict_str_to_int_fill_null() -> None: + lf = pl.LazyFrame({"a": ["one", "two"]}) + mapping = {"one": 1} + + result = lf.select( + pl.col("a") + .replace_strict(mapping, default=None, return_dtype=pl.UInt32) + .fill_null(999) + ) + + expected = pl.LazyFrame({"a": pl.Series([1, 999], dtype=pl.UInt32)}) + assert_frame_equal(result, expected) + + +def test_replace_strict_mix() -> None: + df = pl.DataFrame( + [ + pl.Series("float_to_boolean", [1.0, None]), + pl.Series("boolean_to_int", [True, False]), + pl.Series("boolean_to_str", [True, False]), + ] + ) + + result = df.with_columns( + pl.col("float_to_boolean").replace_strict({1.0: True}), + pl.col("boolean_to_int").replace_strict({True: 1, False: 0}), + pl.col("boolean_to_str").replace_strict({True: "1", False: "0"}), + ) + + expected = pl.DataFrame( + [ + pl.Series("float_to_boolean", [True, None], dtype=pl.Boolean), + pl.Series("boolean_to_int", [1, 0], dtype=pl.Int64), + pl.Series("boolean_to_str", ["1", "0"], dtype=pl.String), + ] + ) + assert_frame_equal(result, expected) + + +@pytest.fixture(scope="module") +def int_mapping() -> dict[int, 
int]: + return {1: 11, 2: 22, 3: 33, 4: 44, 5: 55} + + +def test_replace_strict_int_to_int2(int_mapping: dict[int, int]) -> None: + s = pl.Series([1, 22, None, 44, -5]) + result = s.replace_strict(int_mapping, default=None) + expected = pl.Series([11, None, None, None, None], dtype=pl.Int64) + assert_series_equal(result, expected) + + +def test_replace_strict_int_to_int3(int_mapping: dict[int, int]) -> None: + s = pl.Series([1, 22, None, 44, -5], dtype=pl.Int16) + result = s.replace_strict(int_mapping, default=9) + expected = pl.Series([11, 9, 9, 9, 9], dtype=pl.Int64) + assert_series_equal(result, expected) + + +def test_replace_strict_int_to_int4_return_dtype(int_mapping: dict[int, int]) -> None: + s = pl.Series([-1, 22, None, 44, -5], dtype=pl.Int16) + result = s.replace_strict(int_mapping, default=s, return_dtype=pl.Float32) + expected = pl.Series([-1.0, 22.0, None, 44.0, -5.0], dtype=pl.Float32) + assert_series_equal(result, expected) + + +def test_replace_strict_int_to_int5_return_dtype(int_mapping: dict[int, int]) -> None: + s = pl.Series([1, 22, None, 44, -5], dtype=pl.Int16) + result = s.replace_strict(int_mapping, default=9, return_dtype=pl.Float32) + expected = pl.Series([11.0, 9.0, 9.0, 9.0, 9.0], dtype=pl.Float32) + assert_series_equal(result, expected) + + +def test_replace_strict_bool_to_int() -> None: + s = pl.Series([True, False, False, None]) + mapping = {True: 1, False: 0} + result = s.replace_strict(mapping) + expected = pl.Series([1, 0, 0, None]) + assert_series_equal(result, expected) + + +def test_replace_strict_bool_to_str() -> None: + s = pl.Series([True, False, False, None]) + mapping = {True: "1", False: "0"} + result = s.replace_strict(mapping) + expected = pl.Series(["1", "0", "0", None]) + assert_series_equal(result, expected) + + +def test_replace_strict_str_to_bool() -> None: + s = pl.Series(["True", "False", "False", None]) + mapping = {"True": True, "False": False} + result = s.replace_strict(mapping) + expected = pl.Series([True, 
False, False, None]) + assert_series_equal(result, expected) + + +def test_replace_strict_int_to_str() -> None: + s = pl.Series("a", [-1, 2, None, 4, -5]) + mapping = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} + + with pytest.raises(InvalidOperationError, match="incomplete mapping"): + s.replace_strict(mapping) + result = s.replace_strict(mapping, default=None) + + expected = pl.Series("a", [None, "two", None, "four", None]) + assert_series_equal(result, expected) + + +def test_replace_strict_int_to_str2() -> None: + s = pl.Series("a", [1, 2, None, 4, 5]) + mapping = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} + + result = s.replace_strict(mapping) + + expected = pl.Series("a", ["one", "two", None, "four", "five"]) + assert_series_equal(result, expected) + + +def test_replace_strict_int_to_str_with_default() -> None: + s = pl.Series("a", [1, 2, None, 4, 5]) + mapping = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five"} + + result = s.replace_strict(mapping, default="?") + + expected = pl.Series("a", ["one", "two", "?", "four", "five"]) + assert_series_equal(result, expected) + + +def test_replace_strict_str_to_int() -> None: + s = pl.Series(["a", "b"]) + mapping = {"a": 1, "b": 2} + result = s.replace_strict(mapping) + expected = pl.Series([1, 2]) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("context", "dtype"), + [ + (pl.StringCache(), pl.Categorical), + (pytest.warns(CategoricalRemappingWarning), pl.Categorical), + (contextlib.nullcontext(), pl.Enum(["a", "b", "OTHER"])), + ], +) +def test_replace_strict_cat_str( + context: contextlib.AbstractContextManager, # type: ignore[type-arg] + dtype: pl.DataType, +) -> None: + with context: + for old, new, expected in [ + ("a", "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), + (["a", "b"], ["c", "d"], pl.Series("s", ["c", "d"], dtype=pl.Utf8)), + (pl.lit("a", dtype=dtype), "c", pl.Series("s", ["c", None], dtype=pl.Utf8)), + ( + pl.Series(["a", "b"], dtype=dtype), + 
["c", "d"], + pl.Series("s", ["c", "d"], dtype=pl.Utf8), + ), + ]: + s = pl.Series("s", ["a", "b"], dtype=dtype) + s_replaced = s.replace_strict(old, new, default=None) # type: ignore[arg-type] + assert_series_equal(s_replaced, expected) + + s = pl.Series("s", ["a", "b"], dtype=dtype) + s_replaced = s.replace_strict(old, new, default="OTHER") # type: ignore[arg-type] + assert_series_equal(s_replaced, expected.fill_null("OTHER")) + + +@pytest.mark.parametrize( + "context", [pl.StringCache(), pytest.warns(CategoricalRemappingWarning)] +) +def test_replace_strict_cat_cat( + context: contextlib.AbstractContextManager, # type: ignore[type-arg] +) -> None: + with context: + dt = pl.Categorical + for old, new, expected in [ + ("a", pl.lit("c", dtype=dt), pl.Series("s", ["c", None], dtype=dt)), + ( + ["a", "b"], + pl.Series(["c", "d"], dtype=dt), + pl.Series("s", ["c", "d"], dtype=dt), + ), + ]: + s = pl.Series("s", ["a", "b"], dtype=dt) + s_replaced = s.replace_strict(old, new, default=None) # type: ignore[arg-type] + assert_series_equal(s_replaced, expected) + + s = pl.Series("s", ["a", "b"], dtype=dt) + s_replaced = s.replace_strict(old, new, default=pl.lit("OTHER", dtype=dt)) # type: ignore[arg-type] + assert_series_equal(s_replaced, expected.fill_null("OTHER")) From 2dc0d1faa53bd0b3294e2885e0aea4d72ecc611f Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Thu, 20 Jun 2024 21:39:20 +0200 Subject: [PATCH 5/5] Fix strictness check in single case --- crates/polars-ops/src/series/ops/replace.rs | 17 ++++++-- .../unit/operations/test_replace_strict.py | 42 +++++++++++++++++++ 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/crates/polars-ops/src/series/ops/replace.rs b/crates/polars-ops/src/series/ops/replace.rs index 54ee17e99576..a331078318ea 100644 --- a/crates/polars-ops/src/series/ops/replace.rs +++ b/crates/polars-ops/src/series/ops/replace.rs @@ -132,7 +132,7 @@ fn replace_by_single( /// Fast path for replacing by a single value in strict mode fn 
replace_by_single_strict(s: &Series, old: &Series, new: &Series) -> PolarsResult { let mask = get_replacement_mask(s, old)?; - ensure_all_replaced(&mask, s, old.null_count() > 0)?; + ensure_all_replaced(&mask, s, old.null_count() > 0, true)?; let mut out = new.new_from_index(0, s.len()); @@ -224,7 +224,7 @@ fn replace_by_multiple_strict(s: &Series, old: Series, new: Series) -> PolarsRes .unwrap() .bool() .unwrap(); - ensure_all_replaced(mask, s, old_has_null)?; + ensure_all_replaced(mask, s, old_has_null, false)?; Ok(replaced.clone()) } @@ -254,12 +254,21 @@ fn validate_new(new: &Series, old: &Series) -> PolarsResult<()> { } /// Ensure that all values were replaced. -fn ensure_all_replaced(mask: &BooleanChunked, s: &Series, old_has_null: bool) -> PolarsResult<()> { - let all_replaced = if old_has_null { +fn ensure_all_replaced( + mask: &BooleanChunked, + s: &Series, + old_has_null: bool, + check_all: bool, +) -> PolarsResult<()> { + let nulls_check = if old_has_null { mask.null_count() == 0 } else { mask.null_count() == s.null_count() }; + // Checking booleans is only relevant for the 'replace_by_single' path. + let bools_check = !check_all || mask.all(); + + let all_replaced = bools_check && nulls_check; polars_ensure!( all_replaced, InvalidOperation: "incomplete mapping specified for `replace_strict`\n\nHint: Pass a `default` value to set unmapped values." 
diff --git a/py-polars/tests/unit/operations/test_replace_strict.py b/py-polars/tests/unit/operations/test_replace_strict.py index 1ed0684599ed..d72f0c7968d6 100644 --- a/py-polars/tests/unit/operations/test_replace_strict.py +++ b/py-polars/tests/unit/operations/test_replace_strict.py @@ -17,6 +17,48 @@ def test_replace_strict_incomplete_mapping() -> None: lf.select(pl.col("a").replace_strict({2: 200, 3: 300})).collect() +def test_replace_strict_incomplete_mapping_null_raises() -> None: + s = pl.Series("a", [1, 2, 2, None, None]) + with pytest.raises(InvalidOperationError): + s.replace_strict({1: 10}) + + +def test_replace_strict_mapping_null_not_specified() -> None: + s = pl.Series("a", [1, 2, 2, None, None]) + + result = s.replace_strict({1: 10, 2: 20}) + + expected = pl.Series("a", [10, 20, 20, None, None]) + assert_series_equal(result, expected) + + +def test_replace_strict_mapping_null_specified() -> None: + s = pl.Series("a", [1, 2, 2, None, None]) + + result = s.replace_strict({1: 10, 2: 20, None: 0}) + + expected = pl.Series("a", [10, 20, 20, 0, 0]) + assert_series_equal(result, expected) + + +def test_replace_strict_mapping_null_replace_by_null() -> None: + s = pl.Series("a", [1, 2, 2, None]) + + result = s.replace_strict({1: 10, 2: None, None: 0}) + + expected = pl.Series("a", [10, None, None, 0]) + assert_series_equal(result, expected) + + +def test_replace_strict_mapping_null_with_default() -> None: + s = pl.Series("a", [1, 2, 2, None, None]) + + result = s.replace_strict({1: 10}, default=0) + + expected = pl.Series("a", [10, 0, 0, 0, 0]) + assert_series_equal(result, expected) + + def test_replace_strict_empty() -> None: lf = pl.LazyFrame({"a": [None, None]}) result = lf.select(pl.col("a").replace_strict({}))