Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: return float dtype in interpolate (for method="linear") for numeric dtypes #11624

Merged
merged 3 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 35 additions & 46 deletions crates/polars-ops/src/chunked_array/interpolate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,29 +60,6 @@ where
}
}

#[inline]
fn unsigned_interp<T>(low: T, high: T, steps: IdxSize, steps_n: T, av: &mut Vec<T>)
where
T: Sub<Output = T>
+ Mul<Output = T>
+ Add<Output = T>
+ Div<Output = T>
+ NumCast
+ PartialOrd
+ Copy,
{
if high >= low {
signed_interp::<T>(low, high, steps, steps_n, av)
} else {
let diff = low - high;
for step_i in (1..steps).rev() {
let step_i: T = NumCast::from(step_i).unwrap();
let v = linear_itp(high, step_i, diff, steps_n);
av.push(v)
}
}
}

fn interpolate_impl<T, I>(chunked_arr: &ChunkedArray<T>, interpolation_branch: I) -> ChunkedArray<T>
where
T: PolarsNumericType,
Expand Down Expand Up @@ -196,34 +173,46 @@ fn interpolate_linear(s: &Series) -> Series {
let logical = s.dtype();

let s = s.to_physical_repr();
let out = match s.dtype() {
#[cfg(feature = "dtype-i8")]
DataType::Int8 => linear_interp_signed(s.i8().unwrap()),
#[cfg(feature = "dtype-i16")]
DataType::Int16 => linear_interp_signed(s.i16().unwrap()),
DataType::Int32 => linear_interp_signed(s.i32().unwrap()),
DataType::Int64 => linear_interp_signed(s.i64().unwrap()),
#[cfg(feature = "dtype-u8")]
DataType::UInt8 => linear_interp_unsigned(s.u8().unwrap()),
#[cfg(feature = "dtype-u16")]
DataType::UInt16 => linear_interp_unsigned(s.u16().unwrap()),
DataType::UInt32 => linear_interp_unsigned(s.u32().unwrap()),
DataType::UInt64 => linear_interp_unsigned(s.u64().unwrap()),
DataType::Float32 => linear_interp_unsigned(s.f32().unwrap()),
DataType::Float64 => linear_interp_unsigned(s.f64().unwrap()),
_ => s.as_ref().clone(),

let out = if matches!(
logical,
DataType::Date | DataType::Datetime(_, _) | DataType::Duration(_) | DataType::Time
) {
match s.dtype() {
// Datetime, Time, or Duration
DataType::Int64 => linear_interp_signed(s.i64().unwrap()),
// Date
DataType::Int32 => linear_interp_signed(s.i32().unwrap()),
_ => unreachable!(),
}
} else {
match s.dtype() {
DataType::Float32 => linear_interp_signed(s.f32().unwrap()),
DataType::Float64 => linear_interp_signed(s.f64().unwrap()),
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64 => {
linear_interp_signed(s.cast(&DataType::Float64).unwrap().f64().unwrap())
},
_ => s.as_ref().clone(),
}
};
out.cast(logical).unwrap()
match logical {
DataType::Date
| DataType::Datetime(_, _)
| DataType::Duration(_)
| DataType::Time => out.cast(logical).unwrap(),
_ => out,
}
},
}
}

fn linear_interp_unsigned<T: PolarsNumericType>(ca: &ChunkedArray<T>) -> Series
where
ChunkedArray<T>: IntoSeries,
{
interpolate_impl(ca, unsigned_interp::<T::Native>).into_series()
}
fn linear_interp_signed<T: PolarsNumericType>(ca: &ChunkedArray<T>) -> Series
where
ChunkedArray<T>: IntoSeries,
Expand Down
19 changes: 18 additions & 1 deletion crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,10 @@ impl FunctionExpr {
dt => dt.clone(),
}),
#[cfg(feature = "interpolate")]
Interpolate(_) => mapper.with_same_dtype(),
Interpolate(method) => match method {
InterpolationMethod::Linear => mapper.map_numeric_to_float_dtype(),
InterpolationMethod::Nearest => mapper.with_same_dtype(),
},
ShrinkType => {
// we return the smallest type this can return
// this might not be correct once the actual data
Expand Down Expand Up @@ -328,6 +331,20 @@ impl<'a> FieldsMapper<'a> {
})
}

/// Map to a float supertype if numeric, else preserve
pub fn map_numeric_to_float_dtype(&self) -> PolarsResult<Field> {
self.map_dtype(|dtype| {
if dtype.is_numeric() {
match dtype {
DataType::Float32 => DataType::Float32,
_ => DataType::Float64,
}
} else {
dtype.clone()
}
})
}

/// Map to a physical type.
pub fn to_physical_type(&self) -> PolarsResult<Field> {
self.map_dtype(|dtype| dtype.to_physical())
Expand Down
20 changes: 10 additions & 10 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9477,16 +9477,16 @@ def interpolate(self) -> DataFrame:
... )
>>> df.interpolate()
shape: (4, 3)
┌─────┬──────┬─────┐
│ foo ┆ bar ┆ baz │
│ --- ┆ --- ┆ --- │
i64 ┆ i64i64
╞═════╪══════╪═════╡
│ 1 ┆ 6 ┆ 1 │
│ 5 ┆ 7 ┆ 3
│ 9 ┆ 9 ┆ 6
│ 10 ┆ null ┆ 9 │
└─────┴──────┴─────┘
┌─────┬──────┬──────────┐
│ foo ┆ bar ┆ baz
│ --- ┆ --- ┆ ---
f64 ┆ f64f64
╞═════╪══════╪══════════╡
│ 1.0 ┆ 6.0 ┆ 1.0
│ 5.0 ┆ 7.0 ┆ 3.666667
│ 9.0 ┆ 9.0 ┆ 6.333333
│ 10.0 ┆ null ┆ 9.0
└─────┴──────┴──────────┘

"""
return self.select(F.col("*").interpolate())
Expand Down
8 changes: 4 additions & 4 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5390,11 +5390,11 @@ def interpolate(self, method: InterpolationMethod = "linear") -> Self:
┌─────┬─────┐
│ a ┆ b │
│ --- ┆ --- │
i64 ┆ f64 │
f64 ┆ f64 │
╞═════╪═════╡
│ 1 ┆ 1.0 │
│ 2 ┆ NaN │
│ 3 ┆ 3.0 │
│ 1.0 ┆ 1.0 │
│ 2.0 ┆ NaN │
│ 3.0 ┆ 3.0 │
└─────┴─────┘

Fill null values using nearest interpolation.
Expand Down
20 changes: 10 additions & 10 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5326,16 +5326,16 @@ def interpolate(self) -> Self:
... )
>>> lf.interpolate().collect()
shape: (4, 3)
┌─────┬──────┬─────┐
│ foo ┆ bar ┆ baz │
│ --- ┆ --- ┆ --- │
i64 ┆ i64i64
╞═════╪══════╪═════╡
│ 1 ┆ 6 ┆ 1 │
│ 5 ┆ 7 ┆ 3
│ 9 ┆ 9 ┆ 6
│ 10 ┆ null ┆ 9 │
└─────┴──────┴─────┘
┌─────┬──────┬──────────┐
│ foo ┆ bar ┆ baz
│ --- ┆ --- ┆ ---
f64 ┆ f64f64
╞═════╪══════╪══════════╡
│ 1.0 ┆ 6.0 ┆ 1.0
│ 5.0 ┆ 7.0 ┆ 3.666667
│ 9.0 ┆ 9.0 ┆ 6.333333
│ 10.0 ┆ null ┆ 9.0
└─────┴──────┴──────────┘

"""
return self.select(F.col("*").interpolate())
Expand Down
12 changes: 6 additions & 6 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5867,13 +5867,13 @@ def interpolate(self, method: InterpolationMethod = "linear") -> Series:
>>> s = pl.Series("a", [1, 2, None, None, 5])
>>> s.interpolate()
shape: (5,)
Series: 'a' [i64]
Series: 'a' [f64]
[
1
2
3
4
5
1.0
2.0
3.0
4.0
5.0
]

"""
Expand Down
134 changes: 134 additions & 0 deletions py-polars/tests/unit/operations/test_interpolate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from __future__ import annotations

from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, Any

import pytest

import polars as pl
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
from polars.type_aliases import PolarsDataType, PolarsTemporalType


@pytest.mark.parametrize(
("input_dtype", "output_dtype"),
[
(pl.Int8, pl.Float64),
(pl.Int16, pl.Float64),
(pl.Int32, pl.Float64),
(pl.Int64, pl.Float64),
(pl.UInt8, pl.Float64),
(pl.UInt16, pl.Float64),
(pl.UInt32, pl.Float64),
(pl.UInt64, pl.Float64),
(pl.Float32, pl.Float32),
(pl.Float64, pl.Float64),
],
)
def test_interpolate_linear(
input_dtype: PolarsDataType, output_dtype: PolarsDataType
) -> None:
df = pl.LazyFrame({"a": [1, None, 2, None, 3]}, schema={"a": input_dtype})
result = df.with_columns(pl.all().interpolate(method="linear"))
assert result.schema["a"] == output_dtype
expected = pl.DataFrame(
{"a": [1.0, 1.5, 2.0, 2.5, 3.0]}, schema={"a": output_dtype}
)
assert_frame_equal(result.collect(), expected)


@pytest.mark.parametrize(
("input", "input_dtype", "output"),
[
(
[date(2020, 1, 1), None, date(2020, 1, 2)],
pl.Date,
[date(2020, 1, 1), date(2020, 1, 1), date(2020, 1, 2)],
),
(
[datetime(2020, 1, 1), None, datetime(2020, 1, 2)],
pl.Datetime("ms"),
[datetime(2020, 1, 1), datetime(2020, 1, 1, 12), datetime(2020, 1, 2)],
),
(
[datetime(2020, 1, 1), None, datetime(2020, 1, 2)],
pl.Datetime("us", "Asia/Kathmandu"),
[datetime(2020, 1, 1), datetime(2020, 1, 1, 12), datetime(2020, 1, 2)],
),
([time(1), None, time(2)], pl.Time, [time(1), time(1, 30), time(2)]),
(
[timedelta(1), None, timedelta(2)],
pl.Duration("ms"),
[timedelta(1), timedelta(1, hours=12), timedelta(2)],
),
],
)
def test_interpolate_temporal_linear(
input: list[Any], input_dtype: PolarsTemporalType, output: list[Any]
) -> None:
df = pl.LazyFrame({"a": input}, schema={"a": input_dtype})
result = df.with_columns(pl.all().interpolate(method="linear"))
assert result.schema["a"] == input_dtype
expected = pl.DataFrame({"a": output}, schema={"a": input_dtype})
assert_frame_equal(result.collect(), expected)


@pytest.mark.parametrize(
"input_dtype",
[
pl.Int8,
pl.Int16,
pl.Int32,
pl.Int64,
pl.UInt8,
pl.UInt16,
pl.UInt32,
pl.UInt64,
pl.Float32,
pl.Float64,
],
)
def test_interpolate_nearest(input_dtype: PolarsDataType) -> None:
df = pl.LazyFrame({"a": [1, None, 2, None, 3]}, schema={"a": input_dtype})
result = df.with_columns(pl.all().interpolate(method="nearest"))
assert result.schema["a"] == input_dtype
expected = pl.DataFrame({"a": [1, 2, 2, 3, 3]}, schema={"a": input_dtype})
assert_frame_equal(result.collect(), expected)


@pytest.mark.parametrize(
("input", "input_dtype", "output"),
[
(
[date(2020, 1, 1), None, date(2020, 1, 2)],
pl.Date,
[date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 2)],
),
(
[datetime(2020, 1, 1), None, datetime(2020, 1, 2)],
pl.Datetime("ms"),
[datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 2)],
),
(
[datetime(2020, 1, 1), None, datetime(2020, 1, 2)],
pl.Datetime("us", "Asia/Kathmandu"),
[datetime(2020, 1, 1), datetime(2020, 1, 2), datetime(2020, 1, 2)],
),
([time(1), None, time(2)], pl.Time, [time(1), time(2), time(2)]),
(
[timedelta(1), None, timedelta(2)],
pl.Duration("ms"),
[timedelta(1), timedelta(2), timedelta(2)],
),
],
)
def test_interpolate_temporal_nearest(
input: list[Any], input_dtype: PolarsTemporalType, output: list[Any]
) -> None:
df = pl.LazyFrame({"a": input}, schema={"a": input_dtype})
result = df.with_columns(pl.all().interpolate(method="nearest"))
assert result.schema["a"] == input_dtype
expected = pl.DataFrame({"a": output}, schema={"a": input_dtype})
assert_frame_equal(result.collect(), expected)