From f7374e56770cd468ac8293d47fb7d9b50a41d708 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Fri, 13 Oct 2023 08:50:24 +0300 Subject: [PATCH] fix(rust, python)!: return float dtype in interpolate (for method="linear") for numeric dtypes (#11624) --- .../src/chunked_array/interpolate.rs | 81 +++++------ .../src/dsl/function_expr/schema.rs | 19 ++- py-polars/polars/dataframe/frame.py | 20 +-- py-polars/polars/expr/expr.py | 8 +- py-polars/polars/lazyframe/frame.py | 20 +-- py-polars/polars/series/series.py | 12 +- .../tests/unit/operations/test_interpolate.py | 134 ++++++++++++++++++ 7 files changed, 217 insertions(+), 77 deletions(-) create mode 100644 py-polars/tests/unit/operations/test_interpolate.py diff --git a/crates/polars-ops/src/chunked_array/interpolate.rs b/crates/polars-ops/src/chunked_array/interpolate.rs index 1d824200f824..b06c06574960 100644 --- a/crates/polars-ops/src/chunked_array/interpolate.rs +++ b/crates/polars-ops/src/chunked_array/interpolate.rs @@ -60,29 +60,6 @@ where } } -#[inline] -fn unsigned_interp(low: T, high: T, steps: IdxSize, steps_n: T, av: &mut Vec) -where - T: Sub - + Mul - + Add - + Div - + NumCast - + PartialOrd - + Copy, -{ - if high >= low { - signed_interp::(low, high, steps, steps_n, av) - } else { - let diff = low - high; - for step_i in (1..steps).rev() { - let step_i: T = NumCast::from(step_i).unwrap(); - let v = linear_itp(high, step_i, diff, steps_n); - av.push(v) - } - } -} - fn interpolate_impl(chunked_arr: &ChunkedArray, interpolation_branch: I) -> ChunkedArray where T: PolarsNumericType, @@ -196,34 +173,46 @@ fn interpolate_linear(s: &Series) -> Series { let logical = s.dtype(); let s = s.to_physical_repr(); - let out = match s.dtype() { - #[cfg(feature = "dtype-i8")] - DataType::Int8 => linear_interp_signed(s.i8().unwrap()), - #[cfg(feature = "dtype-i16")] - DataType::Int16 => linear_interp_signed(s.i16().unwrap()), - DataType::Int32 => linear_interp_signed(s.i32().unwrap()), - DataType::Int64 => linear_interp_signed(s.i64().unwrap()), - #[cfg(feature = "dtype-u8")] - DataType::UInt8 => linear_interp_unsigned(s.u8().unwrap()), - #[cfg(feature = "dtype-u16")] - DataType::UInt16 => linear_interp_unsigned(s.u16().unwrap()), - DataType::UInt32 => linear_interp_unsigned(s.u32().unwrap()), - DataType::UInt64 => linear_interp_unsigned(s.u64().unwrap()), - DataType::Float32 => linear_interp_unsigned(s.f32().unwrap()), - DataType::Float64 => linear_interp_unsigned(s.f64().unwrap()), - _ => s.as_ref().clone(), + + let out = if matches!( + logical, + DataType::Date | DataType::Datetime(_, _) | DataType::Duration(_) | DataType::Time + ) { + match s.dtype() { + // Datetime, Time, or Duration + DataType::Int64 => linear_interp_signed(s.i64().unwrap()), + // Date + DataType::Int32 => linear_interp_signed(s.i32().unwrap()), + _ => unreachable!(), + } + } else { + match s.dtype() { + DataType::Float32 => linear_interp_signed(s.f32().unwrap()), + DataType::Float64 => linear_interp_signed(s.f64().unwrap()), + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => { + linear_interp_signed(s.cast(&DataType::Float64).unwrap().f64().unwrap()) + }, + _ => s.as_ref().clone(), + } }; - out.cast(logical).unwrap() + match logical { + DataType::Date + | DataType::Datetime(_, _) + | DataType::Duration(_) + | DataType::Time => out.cast(logical).unwrap(), + _ => out, + } }, } } -fn linear_interp_unsigned(ca: &ChunkedArray) -> Series -where - ChunkedArray: IntoSeries, -{ - interpolate_impl(ca, unsigned_interp::).into_series() -} fn linear_interp_signed(ca: &ChunkedArray) -> Series where ChunkedArray: IntoSeries, diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index c27f745a6fd2..59e7b567ee53 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -143,7 +143,10 @@ impl FunctionExpr { dt => dt.clone(), }), #[cfg(feature = "interpolate")] - Interpolate(_) => mapper.with_same_dtype(), + Interpolate(method) => match method { + InterpolationMethod::Linear => mapper.map_numeric_to_float_dtype(), + InterpolationMethod::Nearest => mapper.with_same_dtype(), + }, ShrinkType => { // we return the smallest type this can return // this might not be correct once the actual data @@ -286,6 +289,20 @@ impl<'a> FieldsMapper<'a> { }) } + /// Map to a float supertype if numeric, else preserve + pub fn map_numeric_to_float_dtype(&self) -> PolarsResult { + self.map_dtype(|dtype| { + if dtype.is_numeric() { + match dtype { + DataType::Float32 => DataType::Float32, + _ => DataType::Float64, + } + } else { + dtype.clone() + } + }) + } + /// Map to a physical type. pub fn to_physical_type(&self) -> PolarsResult { self.map_dtype(|dtype| dtype.to_physical()) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 58e7b23a3412..bfd38de96663 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -9480,16 +9480,16 @@ def interpolate(self) -> DataFrame: ... ) >>> df.interpolate() shape: (4, 3) - ┌─────┬──────┬─────┐ - │ foo ┆ bar ┆ baz │ - │ --- ┆ --- ┆ --- │ - │ i64 ┆ i64 ┆ i64 │ - ╞═════╪══════╪═════╡ - │ 1 ┆ 6 ┆ 1 │ - │ 5 ┆ 7 ┆ 3 │ - │ 9 ┆ 9 ┆ 6 │ - │ 10 ┆ null ┆ 9 │ - └─────┴──────┴─────┘ + ┌──────┬──────┬──────────┐ + │ foo ┆ bar ┆ baz │ + │ --- ┆ --- ┆ --- │ + │ f64 ┆ f64 ┆ f64 │ + ╞══════╪══════╪══════════╡ + │ 1.0 ┆ 6.0 ┆ 1.0 │ + │ 5.0 ┆ 7.0 ┆ 3.666667 │ + │ 9.0 ┆ 9.0 ┆ 6.333333 │ + │ 10.0 ┆ null ┆ 9.0 │ + └──────┴──────┴──────────┘ """ return self.select(F.col("*").interpolate()) diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 4a80eb1a2642..d1c963b345ff 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -5387,11 +5387,11 @@ def interpolate(self, method: InterpolationMethod = "linear") -> Self: ┌─────┬─────┐ │ a ┆ b │ │ --- ┆ --- │ - │ i64 ┆ f64 │ + │ f64 ┆ f64 │ ╞═════╪═════╡ - │ 1 ┆ 1.0 │ - │ 2 ┆ NaN │ - │ 3 ┆ 3.0 │ + │ 1.0 ┆ 1.0 │ + │ 2.0 ┆ NaN │ + │ 3.0 ┆ 3.0 │ └─────┴─────┘ Fill null values using nearest interpolation. diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 879808823f4f..ab866dc57126 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -5395,16 +5395,16 @@ def interpolate(self) -> Self: ... ) >>> lf.interpolate().collect() shape: (4, 3) - ┌─────┬──────┬─────┐ - │ foo ┆ bar ┆ baz │ - │ --- ┆ --- ┆ --- │ - │ i64 ┆ i64 ┆ i64 │ - ╞═════╪══════╪═════╡ - │ 1 ┆ 6 ┆ 1 │ - │ 5 ┆ 7 ┆ 3 │ - │ 9 ┆ 9 ┆ 6 │ - │ 10 ┆ null ┆ 9 │ - └─────┴──────┴─────┘ + ┌──────┬──────┬──────────┐ + │ foo ┆ bar ┆ baz │ + │ --- ┆ --- ┆ --- │ + │ f64 ┆ f64 ┆ f64 │ + ╞══════╪══════╪══════════╡ + │ 1.0 ┆ 6.0 ┆ 1.0 │ + │ 5.0 ┆ 7.0 ┆ 3.666667 │ + │ 9.0 ┆ 9.0 ┆ 6.333333 │ + │ 10.0 ┆ null ┆ 9.0 │ + └──────┴──────┴──────────┘ """ return self.select(F.col("*").interpolate()) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index b6f368e5d443..fc3927395cfb 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -5875,13 +5875,13 @@ def interpolate(self, method: InterpolationMethod = "linear") -> Series: >>> s = pl.Series("a", [1, 2, None, None, 5]) >>> s.interpolate() shape: (5,) - Series: 'a' [i64] + Series: 'a' [f64] [ - 1 - 2 - 3 - 4 - 5 + 1.0 + 2.0 + 3.0 + 4.0 + 5.0 ] """ diff --git a/py-polars/tests/unit/operations/test_interpolate.py b/py-polars/tests/unit/operations/test_interpolate.py new file mode 100644 index 000000000000..904fe23b41a8 --- /dev/null +++ b/py-polars/tests/unit/operations/test_interpolate.py @@ -0,0 +1,134 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from typing import TYPE_CHECKING, Any + +import pytest + +import polars as pl +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from polars.type_aliases import PolarsDataType, PolarsTemporalType + + +@pytest.mark.parametrize( + ("input_dtype", "output_dtype"), + [ + (pl.Int8, pl.Float64), + (pl.Int16, pl.Float64), + (pl.Int32, pl.Float64), + (pl.Int64, pl.Float64), + (pl.UInt8, pl.Float64), + (pl.UInt16, pl.Float64), + (pl.UInt32, pl.Float64), + (pl.UInt64, pl.Float64), + (pl.Float32, pl.Float32), + (pl.Float64, pl.Float64), + ], +) +def test_interpolate_linear( + input_dtype: PolarsDataType, output_dtype: PolarsDataType +) -> None: + df = pl.LazyFrame({"a": [1, None, 2, None, 3]}, schema={"a": input_dtype}) + result = df.with_columns(pl.all().interpolate(method="linear")) + assert result.schema["a"] == output_dtype + expected = pl.DataFrame( + {"a": [1.0, 1.5, 2.0, 2.5, 3.0]}, schema={"a": output_dtype} + ) + assert_frame_equal(result.collect(), expected) + + +@pytest.mark.parametrize( + ("input", "input_dtype", "output"), + [ + ( + [date(2020, 1, 1), None, date(2020, 1, 2)], + pl.Date, + [date(2020, 1, 1), date(2020, 1, 1), date(2020, 1, 2)], + ), + ( + [datetime(2020, 1, 1), None, datetime(2020, 1, 2)], + pl.Datetime("ms"), + [datetime(2020, 1, 1), datetime(2020, 1, 1, 12), datetime(2020, 1, 2)], + ), + ( + [datetime(2020, 1, 1), None, datetime(2020, 1, 2)], + pl.Datetime("us", "Asia/Kathmandu"), + [datetime(2020, 1, 1), datetime(2020, 1, 1, 12), datetime(2020, 1, 2)], + ), + ([time(1), None, time(2)], pl.Time, [time(1), time(1, 30), time(2)]), + ( + [timedelta(1), None, timedelta(2)], + pl.Duration("ms"), + [timedelta(1), timedelta(1, hours=12), timedelta(2)], + ), + ], +) +def test_interpolate_temporal_linear( + input: list[Any], input_dtype: PolarsTemporalType, output: list[Any] +) -> None: + df = pl.LazyFrame({"a": input}, schema={"a": input_dtype}) + result = df.with_columns(pl.all().interpolate(method="linear")) + assert result.schema["a"] == input_dtype + expected = pl.DataFrame({"a": output}, schema={"a": input_dtype}) + assert_frame_equal(result.collect(), expected) + + +@pytest.mark.parametrize( + "input_dtype", + [ + pl.Int8, + pl.Int16, + pl.Int32, + pl.Int64, + pl.UInt8, + pl.UInt16, + pl.UInt32, + pl.UInt64, + pl.Float32, + pl.Float64, + ], +) +def test_interpolate_nearest(input_dtype: PolarsDataType) -> None: + df = pl.LazyFrame({"a": [1, None, 2, None, 3]}, schema={"a": input_dtype}) + result = df.with_columns(pl.all().interpolate(method="nearest")) + assert result.schema["a"] == input_dtype + expected = pl.DataFrame({"a": [1, 2, 2, 3, 3]}, schema={"a": input_dtype}) + assert_frame_equal(result.collect(), expected) + + +@pytest.mark.parametrize( + ("input", "input_dtype", "output"), + [ + ( + [date(2020, 1, 1), None, date(2020, 1, 2)], + pl.Date, + [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 2)], + ), + ( + [datetime(2020, 1, 1), None, datetime(2020, 1, 2)], + pl.Datetime("ms"), + [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 2)], + ), + ( + [datetime(2020, 1, 1), None, datetime(2020, 1, 2)], + pl.Datetime("us", "Asia/Kathmandu"), + [datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 2)], + ), + ([time(1), None, time(2)], pl.Time, [time(1), time(2), time(2)]), + ( + [timedelta(1), None, timedelta(2)], + pl.Duration("ms"), + [timedelta(1), timedelta(2), timedelta(2)], + ), + ], +) +def test_interpolate_temporal_nearest( + input: list[Any], input_dtype: PolarsTemporalType, output: list[Any] +) -> None: + df = pl.LazyFrame({"a": input}, schema={"a": input_dtype}) + result = df.with_columns(pl.all().interpolate(method="nearest")) + assert result.schema["a"] == input_dtype + expected = pl.DataFrame({"a": output}, schema={"a": input_dtype}) + assert_frame_equal(result.collect(), expected)