Skip to content

Commit

Permalink
fix(rust, python)!: return float dtype in interpolate (for method="li…
Browse files Browse the repository at this point in the history
…near") for numeric dtypes (#11624)
  • Loading branch information
MarcoGorelli authored Oct 13, 2023
1 parent 8bb0b93 commit f7374e5
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 77 deletions.
81 changes: 35 additions & 46 deletions crates/polars-ops/src/chunked_array/interpolate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,29 +60,6 @@ where
}
}

#[inline]
fn unsigned_interp<T>(low: T, high: T, steps: IdxSize, steps_n: T, av: &mut Vec<T>)
where
T: Sub<Output = T>
+ Mul<Output = T>
+ Add<Output = T>
+ Div<Output = T>
+ NumCast
+ PartialOrd
+ Copy,
{
if high >= low {
signed_interp::<T>(low, high, steps, steps_n, av)
} else {
let diff = low - high;
for step_i in (1..steps).rev() {
let step_i: T = NumCast::from(step_i).unwrap();
let v = linear_itp(high, step_i, diff, steps_n);
av.push(v)
}
}
}

fn interpolate_impl<T, I>(chunked_arr: &ChunkedArray<T>, interpolation_branch: I) -> ChunkedArray<T>
where
T: PolarsNumericType,
Expand Down Expand Up @@ -196,34 +173,46 @@ fn interpolate_linear(s: &Series) -> Series {
let logical = s.dtype();

let s = s.to_physical_repr();
let out = match s.dtype() {
#[cfg(feature = "dtype-i8")]
DataType::Int8 => linear_interp_signed(s.i8().unwrap()),
#[cfg(feature = "dtype-i16")]
DataType::Int16 => linear_interp_signed(s.i16().unwrap()),
DataType::Int32 => linear_interp_signed(s.i32().unwrap()),
DataType::Int64 => linear_interp_signed(s.i64().unwrap()),
#[cfg(feature = "dtype-u8")]
DataType::UInt8 => linear_interp_unsigned(s.u8().unwrap()),
#[cfg(feature = "dtype-u16")]
DataType::UInt16 => linear_interp_unsigned(s.u16().unwrap()),
DataType::UInt32 => linear_interp_unsigned(s.u32().unwrap()),
DataType::UInt64 => linear_interp_unsigned(s.u64().unwrap()),
DataType::Float32 => linear_interp_unsigned(s.f32().unwrap()),
DataType::Float64 => linear_interp_unsigned(s.f64().unwrap()),
_ => s.as_ref().clone(),

let out = if matches!(
logical,
DataType::Date | DataType::Datetime(_, _) | DataType::Duration(_) | DataType::Time
) {
match s.dtype() {
// Datetime, Time, or Duration
DataType::Int64 => linear_interp_signed(s.i64().unwrap()),
// Date
DataType::Int32 => linear_interp_signed(s.i32().unwrap()),
_ => unreachable!(),
}
} else {
match s.dtype() {
DataType::Float32 => linear_interp_signed(s.f32().unwrap()),
DataType::Float64 => linear_interp_signed(s.f64().unwrap()),
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64 => {
linear_interp_signed(s.cast(&DataType::Float64).unwrap().f64().unwrap())
},
_ => s.as_ref().clone(),
}
};
out.cast(logical).unwrap()
match logical {
DataType::Date
| DataType::Datetime(_, _)
| DataType::Duration(_)
| DataType::Time => out.cast(logical).unwrap(),
_ => out,
}
},
}
}

fn linear_interp_unsigned<T: PolarsNumericType>(ca: &ChunkedArray<T>) -> Series
where
ChunkedArray<T>: IntoSeries,
{
interpolate_impl(ca, unsigned_interp::<T::Native>).into_series()
}
fn linear_interp_signed<T: PolarsNumericType>(ca: &ChunkedArray<T>) -> Series
where
ChunkedArray<T>: IntoSeries,
Expand Down
19 changes: 18 additions & 1 deletion crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,10 @@ impl FunctionExpr {
dt => dt.clone(),
}),
#[cfg(feature = "interpolate")]
Interpolate(_) => mapper.with_same_dtype(),
Interpolate(method) => match method {
InterpolationMethod::Linear => mapper.map_numeric_to_float_dtype(),
InterpolationMethod::Nearest => mapper.with_same_dtype(),
},
ShrinkType => {
// we return the smallest type this can return
// this might not be correct once the actual data
Expand Down Expand Up @@ -286,6 +289,20 @@ impl<'a> FieldsMapper<'a> {
})
}

/// Map to a float supertype if numeric, else preserve
pub fn map_numeric_to_float_dtype(&self) -> PolarsResult<Field> {
self.map_dtype(|dtype| {
if dtype.is_numeric() {
match dtype {
DataType::Float32 => DataType::Float32,
_ => DataType::Float64,
}
} else {
dtype.clone()
}
})
}

/// Map to a physical type.
pub fn to_physical_type(&self) -> PolarsResult<Field> {
self.map_dtype(|dtype| dtype.to_physical())
Expand Down
20 changes: 10 additions & 10 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9480,16 +9480,16 @@ def interpolate(self) -> DataFrame:
... )
>>> df.interpolate()
shape: (4, 3)
┌─────┬──────┬─────┐
│ foo ┆ bar ┆ baz │
│ --- ┆ --- ┆ --- │
i64 ┆ i64i64
╞═════╪══════╪═════╡
│ 1 ┆ 6 ┆ 1 │
│ 5 ┆ 7 ┆ 3
│ 9 ┆ 9 ┆ 6
│ 10 ┆ null ┆ 9 │
└─────┴──────┴─────┘
┌─────┬──────┬──────────┐
│ foo ┆ bar ┆ baz
│ --- ┆ --- ┆ ---
f64 ┆ f64f64
╞═════╪══════╪══════════╡
│ 1.0 ┆ 6.0 ┆ 1.0
│ 5.0 ┆ 7.0 ┆ 3.666667
│ 9.0 ┆ 9.0 ┆ 6.333333
│ 10.0 ┆ null ┆ 9.0
└─────┴──────┴──────────┘
"""
return self.select(F.col("*").interpolate())
Expand Down
8 changes: 4 additions & 4 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5387,11 +5387,11 @@ def interpolate(self, method: InterpolationMethod = "linear") -> Self:
┌─────┬─────┐
│ a ┆ b │
│ --- ┆ --- │
i64 ┆ f64 │
f64 ┆ f64 │
╞═════╪═════╡
│ 1 ┆ 1.0 │
│ 2 ┆ NaN │
│ 3 ┆ 3.0 │
│ 1.0 ┆ 1.0 │
│ 2.0 ┆ NaN │
│ 3.0 ┆ 3.0 │
└─────┴─────┘
Fill null values using nearest interpolation.
Expand Down
20 changes: 10 additions & 10 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5395,16 +5395,16 @@ def interpolate(self) -> Self:
... )
>>> lf.interpolate().collect()
shape: (4, 3)
┌─────┬──────┬─────┐
│ foo ┆ bar ┆ baz │
│ --- ┆ --- ┆ --- │
i64 ┆ i64i64
╞═════╪══════╪═════╡
│ 1 ┆ 6 ┆ 1 │
│ 5 ┆ 7 ┆ 3
│ 9 ┆ 9 ┆ 6
│ 10 ┆ null ┆ 9 │
└─────┴──────┴─────┘
┌─────┬──────┬──────────┐
│ foo ┆ bar ┆ baz
│ --- ┆ --- ┆ ---
f64 ┆ f64f64
╞═════╪══════╪══════════╡
│ 1.0 ┆ 6.0 ┆ 1.0
│ 5.0 ┆ 7.0 ┆ 3.666667
│ 9.0 ┆ 9.0 ┆ 6.333333
│ 10.0 ┆ null ┆ 9.0
└─────┴──────┴──────────┘
"""
return self.select(F.col("*").interpolate())
Expand Down
12 changes: 6 additions & 6 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5875,13 +5875,13 @@ def interpolate(self, method: InterpolationMethod = "linear") -> Series:
>>> s = pl.Series("a", [1, 2, None, None, 5])
>>> s.interpolate()
shape: (5,)
Series: 'a' [i64]
Series: 'a' [f64]
[
1
2
3
4
5
1.0
2.0
3.0
4.0
5.0
]
"""
Expand Down
134 changes: 134 additions & 0 deletions py-polars/tests/unit/operations/test_interpolate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from __future__ import annotations

from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, Any

import pytest

import polars as pl
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
from polars.type_aliases import PolarsDataType, PolarsTemporalType


@pytest.mark.parametrize(
("input_dtype", "output_dtype"),
[
(pl.Int8, pl.Float64),
(pl.Int16, pl.Float64),
(pl.Int32, pl.Float64),
(pl.Int64, pl.Float64),
(pl.UInt8, pl.Float64),
(pl.UInt16, pl.Float64),
(pl.UInt32, pl.Float64),
(pl.UInt64, pl.Float64),
(pl.Float32, pl.Float32),
(pl.Float64, pl.Float64),
],
)
def test_interpolate_linear(
input_dtype: PolarsDataType, output_dtype: PolarsDataType
) -> None:
df = pl.LazyFrame({"a": [1, None, 2, None, 3]}, schema={"a": input_dtype})
result = df.with_columns(pl.all().interpolate(method="linear"))
assert result.schema["a"] == output_dtype
expected = pl.DataFrame(
{"a": [1.0, 1.5, 2.0, 2.5, 3.0]}, schema={"a": output_dtype}
)
assert_frame_equal(result.collect(), expected)


@pytest.mark.parametrize(
("input", "input_dtype", "output"),
[
(
[date(2020, 1, 1), None, date(2020, 1, 2)],
pl.Date,
[date(2020, 1, 1), date(2020, 1, 1), date(2020, 1, 2)],
),
(
[datetime(2020, 1, 1), None, datetime(2020, 1, 2)],
pl.Datetime("ms"),
[datetime(2020, 1, 1), datetime(2020, 1, 1, 12), datetime(2020, 1, 2)],
),
(
[datetime(2020, 1, 1), None, datetime(2020, 1, 2)],
pl.Datetime("us", "Asia/Kathmandu"),
[datetime(2020, 1, 1), datetime(2020, 1, 1, 12), datetime(2020, 1, 2)],
),
([time(1), None, time(2)], pl.Time, [time(1), time(1, 30), time(2)]),
(
[timedelta(1), None, timedelta(2)],
pl.Duration("ms"),
[timedelta(1), timedelta(1, hours=12), timedelta(2)],
),
],
)
def test_interpolate_temporal_linear(
input: list[Any], input_dtype: PolarsTemporalType, output: list[Any]
) -> None:
df = pl.LazyFrame({"a": input}, schema={"a": input_dtype})
result = df.with_columns(pl.all().interpolate(method="linear"))
assert result.schema["a"] == input_dtype
expected = pl.DataFrame({"a": output}, schema={"a": input_dtype})
assert_frame_equal(result.collect(), expected)


@pytest.mark.parametrize(
"input_dtype",
[
pl.Int8,
pl.Int16,
pl.Int32,
pl.Int64,
pl.UInt8,
pl.UInt16,
pl.UInt32,
pl.UInt64,
pl.Float32,
pl.Float64,
],
)
def test_interpolate_nearest(input_dtype: PolarsDataType) -> None:
df = pl.LazyFrame({"a": [1, None, 2, None, 3]}, schema={"a": input_dtype})
result = df.with_columns(pl.all().interpolate(method="nearest"))
assert result.schema["a"] == input_dtype
expected = pl.DataFrame({"a": [1, 2, 2, 3, 3]}, schema={"a": input_dtype})
assert_frame_equal(result.collect(), expected)


@pytest.mark.parametrize(
("input", "input_dtype", "output"),
[
(
[date(2020, 1, 1), None, date(2020, 1, 2)],
pl.Date,
[date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 2)],
),
(
[datetime(2020, 1, 1), None, datetime(2020, 1, 2)],
pl.Datetime("ms"),
[datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 2)],
),
(
[datetime(2020, 1, 1), None, datetime(2020, 1, 2)],
pl.Datetime("us", "Asia/Kathmandu"),
[datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 2)],
),
([time(1), None, time(2)], pl.Time, [time(1), time(2), time(2)]),
(
[timedelta(1), None, timedelta(2)],
pl.Duration("ms"),
[timedelta(1), timedelta(2), timedelta(2)],
),
],
)
def test_interpolate_temporal_nearest(
input: list[Any], input_dtype: PolarsTemporalType, output: list[Any]
) -> None:
df = pl.LazyFrame({"a": input}, schema={"a": input_dtype})
result = df.with_columns(pl.all().interpolate(method="nearest"))
assert result.schema["a"] == input_dtype
expected = pl.DataFrame({"a": output}, schema={"a": input_dtype})
assert_frame_equal(result.collect(), expected)

0 comments on commit f7374e5

Please sign in to comment.