Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust, python): std when ddof>=n_values returns None even in rolling context #11750

Merged
merged 21 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
39c6c7d
fix(rust, python): std when ddof>=n_values returns None even in rolling context
MarcoGorelli Oct 16, 2023
8254439
doctests
MarcoGorelli Oct 16, 2023
a3a3346
Merge remote-tracking branch 'upstream/main' into ddof-le-nvalues-rol…
MarcoGorelli Nov 12, 2023
1f6ed72
wip
MarcoGorelli Nov 12, 2023
a5c4fe8
chore(rust): use unwrap_or_else
MarcoGorelli Nov 12, 2023
1c29190
use unwrap_or_else
MarcoGorelli Nov 12, 2023
d65169a
use get_unchecked_release
MarcoGorelli Nov 12, 2023
8ded05a
use get_unchecked_release
MarcoGorelli Nov 12, 2023
cd1d340
Merge remote-tracking branch 'upstream/main' into ddof-le-nvalues-rol…
MarcoGorelli Nov 13, 2023
1d4a23f
Merge remote-tracking branch 'upstream/main' into ddof-le-nvalues-rol…
MarcoGorelli Dec 17, 2023
5ccd6f4
Merge remote-tracking branch 'upstream/main' into ddof-le-nvalues-rol…
MarcoGorelli Dec 19, 2023
617421e
rust tests
MarcoGorelli Dec 19, 2023
c3dc96a
Merge remote-tracking branch 'upstream/main' into ddof-le-nvalues-rol…
MarcoGorelli Dec 30, 2023
5f98db4
Merge remote-tracking branch 'upstream/main' into ddof-le-nvalues-rol…
MarcoGorelli Mar 1, 2024
b0eef23
dont create validity if you dont have to
MarcoGorelli Mar 1, 2024
74b07d8
make doctest
MarcoGorelli Mar 1, 2024
788221b
Merge remote-tracking branch 'upstream/main' into ddof-le-nvalues-rol…
MarcoGorelli Mar 5, 2024
30fd107
Merge remote-tracking branch 'upstream/main' into ddof-le-nvalues-rol…
MarcoGorelli Mar 8, 2024
eae7780
Merge remote-tracking branch 'upstream/main' into ddof-le-nvalues-rol…
MarcoGorelli Mar 8, 2024
27f4914
use new_null
MarcoGorelli Mar 8, 2024
dceff5b
Merge remote-tracking branch 'upstream/main' into ddof-le-nvalues-rol…
MarcoGorelli Mar 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ impl<
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
let sum = self.sum.update(start, end);
sum / NumCast::from(end - start).unwrap()
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
let sum = self.sum.update(start, end).unwrap_unchecked();
Some(sum / NumCast::from(end - start).unwrap())
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ macro_rules! minmax_window {
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
//For details see: https://github.com/pola-rs/polars/pull/9277#issuecomment-1581401692
self.last_start = start; // Don't care where the last one started
let old_last_end = self.last_end; // But we need this
Expand All @@ -168,10 +168,10 @@ macro_rules! minmax_window {
if entering.map(|em| $new_is_m(&self.m, em.1) || empty_overlap) == Some(true) {
// The entering extremum "beats" the previous extremum so we can ignore the overlap
self.update_m_and_m_idx(entering.unwrap());
return self.m;
return Some(self.m);
} else if self.m_idx >= start || empty_overlap {
// The previous extremum didn't drop off. Keep it
return self.m;
return Some(self.m);
}
// Otherwise get the min of the overlapping window and the entering min
match (
Expand All @@ -191,7 +191,7 @@ macro_rules! minmax_window {
(None, None) => unreachable!(),
}

self.m
Some(self.m)
}
}
};
Expand Down Expand Up @@ -241,7 +241,7 @@ macro_rules! rolling_minmax_func {
_params: DynArgs,
) -> PolarsResult<ArrayRef>
where
T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul<Output = T>,
T: NativeType + PartialOrd + IsFloat + Bounded + NumCast + Mul<Output = T> + Num,
{
let offset_fn = match center {
true => det_offsets_center,
Expand Down
34 changes: 20 additions & 14 deletions crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ mod min_max;
mod quantile;
mod sum;
mod variance;

use std::fmt::Debug;

pub use mean::*;
pub use min_max::*;
use num_traits::{Float, NumCast};
use num_traits::{Float, Num, NumCast};
pub use quantile::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand All @@ -28,7 +27,7 @@ pub trait RollingAggWindowNoNulls<'a, T: NativeType> {
///
/// # Safety
/// `start` and `end` must be within the windows bounds
unsafe fn update(&mut self, start: usize, end: usize) -> T;
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T>;
}

// Use an aggregation window that maintains the state
Expand All @@ -42,27 +41,34 @@ pub(super) fn rolling_apply_agg_window<'a, Agg, T, Fo>(
where
Fo: Fn(Idx, WindowSize, Len) -> (Start, End),
Agg: RollingAggWindowNoNulls<'a, T>,
T: Debug + NativeType,
T: Debug + NativeType + Num,
{
let len = values.len();
let (start, end) = det_offsets_fn(0, window_size, len);
let mut agg_window = Agg::new(values, start, end, params);
if let Some(validity) = create_validity(min_periods, len, window_size, &det_offsets_fn) {
if validity.iter().all(|x| !x) {
return Ok(Box::new(PrimitiveArray::<T>::new_null(
T::PRIMITIVE.into(),
len,
)));
}
}

let out = (0..len)
.map(|idx| {
let (start, end) = det_offsets_fn(idx, window_size, len);
// SAFETY:
// we are in bounds
unsafe { agg_window.update(start, end) }
if end - start < min_periods {
None
} else {
// SAFETY:
// we are in bounds
unsafe { agg_window.update(start, end) }
}
})
.collect_trusted::<Vec<_>>();

let validity = create_validity(min_periods, len, window_size, det_offsets_fn);
Ok(Box::new(PrimitiveArray::new(
T::PRIMITIVE.into(),
out.into(),
validity.map(|b| b.into()),
)))
let arr = PrimitiveArray::from(out);
Ok(Box::new(arr))
}

#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl<
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
let vals = self.sorted.update(start, end);
let length = vals.len();

Expand All @@ -48,13 +48,13 @@ impl<
let float_idx_top = (length_f - 1.0) * self.prob;
let top_idx = float_idx_top.ceil() as usize;
return if idx == top_idx {
unsafe { *vals.get_unchecked_release(idx) }
Some(unsafe { *vals.get_unchecked_release(idx) })
} else {
let proportion = T::from(float_idx_top - idx as f64).unwrap();
let vi = unsafe { *vals.get_unchecked_release(idx) };
let vj = unsafe { *vals.get_unchecked_release(top_idx) };

proportion * (vj - vi) + vi
Some(proportion * (vj - vi) + vi)
};
},
Midpoint => {
Expand All @@ -66,7 +66,7 @@ impl<
return if top_idx == idx {
// SAFETY:
// we are in bounds
unsafe { *vals.get_unchecked_release(idx) }
Some(unsafe { *vals.get_unchecked_release(idx) })
} else {
// SAFETY:
// we are in bounds
Expand All @@ -77,7 +77,7 @@ impl<
)
};

(mid + mid_plus_1) / (T::one() + T::one())
Some((mid + mid_plus_1) / (T::one() + T::one()))
};
},
Nearest => {
Expand All @@ -93,7 +93,7 @@ impl<

// SAFETY:
// we are in bounds
unsafe { *vals.get_unchecked_release(idx) }
Some(unsafe { *vals.get_unchecked_release(idx) })
}
}

Expand Down
13 changes: 10 additions & 3 deletions crates/polars-arrow/src/legacy/kernels/rolling/no_nulls/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign>
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
// if we exceed the end, we have a completely new window
// so we recompute
let recompute_sum = if start >= self.last_end {
Expand Down Expand Up @@ -60,7 +60,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign>
}
}
self.last_end = end;
self.sum
Some(self.sum)
}
}

Expand All @@ -73,7 +73,14 @@ pub fn rolling_sum<T>(
_params: DynArgs,
) -> PolarsResult<ArrayRef>
where
T: NativeType + std::iter::Sum + NumCast + Mul<Output = T> + AddAssign + SubAssign + IsFloat,
T: NativeType
+ std::iter::Sum
+ NumCast
+ Mul<Output = T>
+ AddAssign
+ SubAssign
+ IsFloat
+ Num,
{
match (center, weights) {
(true, None) => rolling_apply_agg_window::<SumWindow<_>, _, _>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul<
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
// if we exceed the end, we have a completely new window
// so we recompute
let recompute_sum = if start >= self.last_end || self.last_recompute > 128 {
Expand Down Expand Up @@ -68,7 +68,7 @@ impl<'a, T: NativeType + IsFloat + std::iter::Sum + AddAssign + SubAssign + Mul<
}
}
self.last_end = end;
self.sum_of_squares
Some(self.sum_of_squares)
}
}

Expand Down Expand Up @@ -108,25 +108,24 @@ impl<
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> T {
unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
let count: T = NumCast::from(end - start).unwrap();
let sum_of_squares = self.sum_of_squares.update(start, end);
let mean = self.mean.update(start, end);
let sum_of_squares = self.sum_of_squares.update(start, end).unwrap_unchecked();
let mean = self.mean.update(start, end).unwrap_unchecked();

let denom = count - NumCast::from(self.ddof).unwrap();
if end - start == 1 {
T::zero()
} else if denom <= T::zero() {
//ddof would be greater than # of observations
T::infinity()
if denom <= T::zero() {
None
} else if end - start == 1 {
Some(T::zero())
} else {
let out = (sum_of_squares - count * mean * mean) / denom;
// variance cannot be negative.
// if it is negative it is due to numeric instability
if out < T::zero() {
T::zero()
Some(T::zero())
} else {
out
Some(out)
}
}
}
Expand Down Expand Up @@ -208,14 +207,11 @@ mod test {

let out = rolling_var(values, 2, 1, false, None, None).unwrap();
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
let out = out
.into_iter()
.map(|v| v.copied().unwrap())
.collect::<Vec<_>>();
let out = out.into_iter().map(|v| v.copied()).collect::<Vec<_>>();
// we cannot compare nans, so we compare the string values
assert_eq!(
format!("{:?}", out.as_slice()),
format!("{:?}", &[0.0, 8.0, 2.0, 0.5])
format!("{:?}", &[None, Some(8.0), Some(2.0), Some(0.5)])
);
// test nan handling.
let values = &[-10.0, 2.0, 3.0, f64::nan(), 5.0, 6.0, 7.0];
Expand Down
34 changes: 29 additions & 5 deletions crates/polars-core/src/frame/group_by/aggregations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ where
None
} else {
// SAFETY: we are in bounds.
Some(unsafe { agg_window.update(start as usize, end as usize) })
unsafe { agg_window.update(start as usize, end as usize) }
}
})
.collect::<PrimitiveArray<T>>()
Expand Down Expand Up @@ -799,7 +799,13 @@ where
debug_assert!(len <= self.len() as IdxSize);
match len {
0 => None,
1 => NumCast::from(0),
1 => {
if ddof == 0 {
NumCast::from(0)
} else {
None
}
},
_ => {
let arr_group = _slice_from_offsets(self, first, len);
arr_group.var(ddof).map(|flt| NumCast::from(flt).unwrap())
Expand Down Expand Up @@ -861,7 +867,13 @@ where
debug_assert!(len <= self.len() as IdxSize);
match len {
0 => None,
1 => NumCast::from(0),
1 => {
if ddof == 0 {
NumCast::from(0)
} else {
None
}
},
_ => {
let arr_group = _slice_from_offsets(self, first, len);
arr_group.std(ddof).map(|flt| NumCast::from(flt).unwrap())
Expand Down Expand Up @@ -1012,7 +1024,13 @@ where
debug_assert!(first + len <= self.len() as IdxSize);
match len {
0 => None,
1 => NumCast::from(0),
1 => {
if ddof == 0 {
NumCast::from(0)
} else {
None
}
},
_ => {
let arr_group = _slice_from_offsets(self, first, len);
arr_group.var(ddof)
Expand Down Expand Up @@ -1054,7 +1072,13 @@ where
debug_assert!(first + len <= self.len() as IdxSize);
match len {
0 => None,
1 => NumCast::from(0),
1 => {
if ddof == 0 {
NumCast::from(0)
} else {
None
}
},
_ => {
let arr_group = _slice_from_offsets(self, first, len);
arr_group.std(ddof)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ where
} else {
// SAFETY:
// we are in bounds
Some(unsafe { agg_window.update(start as usize, end as usize) })
unsafe { agg_window.update(start as usize, end as usize) }
}
})
})
Expand Down
8 changes: 4 additions & 4 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6834,7 +6834,7 @@ def rolling_std(
│ u32 ┆ datetime[μs] ┆ f64 │
╞═══════╪═════════════════════╪═════════════════╡
│ 0 ┆ 2001-01-01 00:00:00 ┆ null │
│ 1 ┆ 2001-01-01 01:00:00 ┆ 0.0
│ 1 ┆ 2001-01-01 01:00:00 ┆ null
│ 2 ┆ 2001-01-01 02:00:00 ┆ 0.707107 │
│ 3 ┆ 2001-01-01 03:00:00 ┆ 0.707107 │
│ 4 ┆ 2001-01-01 04:00:00 ┆ 0.707107 │
Expand All @@ -6859,7 +6859,7 @@ def rolling_std(
│ --- ┆ --- ┆ --- │
│ u32 ┆ datetime[μs] ┆ f64 │
╞═══════╪═════════════════════╪═════════════════╡
│ 0 ┆ 2001-01-01 00:00:00 ┆ 0.0
│ 0 ┆ 2001-01-01 00:00:00 ┆ null
│ 1 ┆ 2001-01-01 01:00:00 ┆ 0.707107 │
│ 2 ┆ 2001-01-01 02:00:00 ┆ 1.0 │
│ 3 ┆ 2001-01-01 03:00:00 ┆ 1.0 │
Expand Down Expand Up @@ -7081,7 +7081,7 @@ def rolling_var(
│ u32 ┆ datetime[μs] ┆ f64 │
╞═══════╪═════════════════════╪═════════════════╡
│ 0 ┆ 2001-01-01 00:00:00 ┆ null │
│ 1 ┆ 2001-01-01 01:00:00 ┆ 0.0
│ 1 ┆ 2001-01-01 01:00:00 ┆ null
│ 2 ┆ 2001-01-01 02:00:00 ┆ 0.5 │
│ 3 ┆ 2001-01-01 03:00:00 ┆ 0.5 │
│ 4 ┆ 2001-01-01 04:00:00 ┆ 0.5 │
Expand All @@ -7106,7 +7106,7 @@ def rolling_var(
│ --- ┆ --- ┆ --- │
│ u32 ┆ datetime[μs] ┆ f64 │
╞═══════╪═════════════════════╪═════════════════╡
│ 0 ┆ 2001-01-01 00:00:00 ┆ 0.0
│ 0 ┆ 2001-01-01 00:00:00 ┆ null
│ 1 ┆ 2001-01-01 01:00:00 ┆ 0.5 │
│ 2 ┆ 2001-01-01 02:00:00 ┆ 1.0 │
│ 3 ┆ 2001-01-01 03:00:00 ┆ 1.0 │
Expand Down
Loading
Loading