Skip to content

Commit

Permalink
feat(rust, python): in rolling aggregation functions, sort under-the-hood for user if data is unsorted (with warning)
Browse files
Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Sep 16, 2023
1 parent cbf30ce commit 9f8d828
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 27 deletions.
53 changes: 47 additions & 6 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#![allow(ambiguous_glob_reexports)]
//! Domain specific language for the Lazy API.
#[cfg(feature = "rolling_window")]
use polars_core::utils::ensure_sorted_arg;
#[cfg(feature = "dtype-categorical")]
pub mod cat;
#[cfg(feature = "dtype-categorical")]
Expand Down Expand Up @@ -1230,19 +1228,52 @@ impl Expr {
move |s| {
let mut by = s[1].clone();
by = by.rechunk();
let s = &s[0];
let series: Series;

polars_ensure!(
options.weights.is_none(),
ComputeError: "`weights` is not supported in 'rolling by' expression"
);
let (by, tz) = match by.dtype() {
let (mut by, tz) = match by.dtype() {
DataType::Datetime(tu, tz) => {
(by.cast(&DataType::Datetime(*tu, None))?, tz)
},
_ => (by.clone(), &None),
};
ensure_sorted_arg(&by, expr_name)?;
let sorting_indices;
let original_indices;
let by_flag = by.is_sorted_flag();
match by_flag {
IsSorted::Ascending => {
series = s[0].clone();
original_indices = None;
},
IsSorted::Descending => {
series = s[0].reverse();
by = by.reverse();
original_indices = None;
},
IsSorted::Not => {
eprintln!(
"PolarsPerformanceWarning: Series is not known to be \
sorted by `by` column, so Polars is temporarily \
sorting it for you.\n\
You can silence this warning by:\n\
- passing `warn_if_unsorted=False`;\n\
- sorting your data by your `by` column beforehand;\n\
- setting `.set_sorted()` if you already know your data is sorted\n\
before passing it to the rolling aggregation function"
);
sorting_indices = by.arg_sort(Default::default());
unsafe { by = by.take_unchecked(&sorting_indices)? };
unsafe { series = s[0].take_unchecked(&sorting_indices)? };
let int_range =
UInt32Chunked::from_iter_values("", 0..s[0].len() as u32)
.into_series();
original_indices =
unsafe { int_range.take_unchecked(&sorting_indices) }.ok()
},
};
let by = by.datetime().unwrap();
let by_values = by.cont_slice().map_err(|_| {
polars_err!(
Expand All @@ -1264,7 +1295,17 @@ impl Expr {
fn_params: options.fn_params.clone(),
};

rolling_fn(s, options).map(Some)
match by_flag {
IsSorted::Ascending => rolling_fn(&series, options).map(Some),
IsSorted::Descending => {
Ok(rolling_fn(&series, options)?.reverse()).map(Some)
},
IsSorted::Not => {
let res = rolling_fn(&series, options)?;
let indices = &original_indices.unwrap().arg_sort(Default::default());
unsafe { res.take_unchecked(indices) }.map(Some)
},
}
},
&[col(by)],
output_type,
Expand Down
3 changes: 3 additions & 0 deletions crates/polars-time/src/chunkedarray/rolling_window/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ pub struct RollingOptions {
pub closed_window: Option<ClosedWindow>,
/// Optional parameters for the rolling function
pub fn_params: DynArgs,
/// Warn if data is not known to be sorted by `by` column (if passed)
pub warn_if_unsorted: bool,
}

#[cfg(feature = "rolling_window")]
Expand All @@ -51,6 +53,7 @@ impl Default for RollingOptions {
by: None,
closed_window: None,
fn_params: None,
warn_if_unsorted: true,
}
}
}
Expand Down
51 changes: 45 additions & 6 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5279,6 +5279,7 @@ def rolling_min(
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "left",
warn_if_unsorted: bool = True,
) -> Self:
"""
Apply a rolling min (moving min) over the values in this array.
Expand Down Expand Up @@ -5350,6 +5351,8 @@ def rolling_min(
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).
Warnings
--------
Expand Down Expand Up @@ -5475,7 +5478,7 @@ def rolling_min(
)
return self._from_pyexpr(
self._pyexpr.rolling_min(
window_size, weights, min_periods, center, by, closed
window_size, weights, min_periods, center, by, closed, warn_if_unsorted
)
)

Expand All @@ -5489,6 +5492,7 @@ def rolling_max(
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "left",
warn_if_unsorted: bool = True,
) -> Self:
"""
Apply a rolling max (moving max) over the values in this array.
Expand Down Expand Up @@ -5556,6 +5560,8 @@ def rolling_max(
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).
Warnings
--------
Expand Down Expand Up @@ -5708,7 +5714,7 @@ def rolling_max(
)
return self._from_pyexpr(
self._pyexpr.rolling_max(
window_size, weights, min_periods, center, by, closed
window_size, weights, min_periods, center, by, closed, warn_if_unsorted
)
)

Expand All @@ -5722,6 +5728,7 @@ def rolling_mean(
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "left",
warn_if_unsorted: bool = True,
) -> Self:
"""
Apply a rolling mean (moving mean) over the values in this array.
Expand Down Expand Up @@ -5793,6 +5800,8 @@ def rolling_mean(
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).
Warnings
--------
Expand Down Expand Up @@ -5945,7 +5954,13 @@ def rolling_mean(
)
return self._from_pyexpr(
self._pyexpr.rolling_mean(
window_size, weights, min_periods, center, by, closed
window_size,
weights,
min_periods,
center,
by,
closed,
warn_if_unsorted,
)
)

Expand All @@ -5959,6 +5974,7 @@ def rolling_sum(
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "left",
warn_if_unsorted: bool = True,
) -> Self:
"""
Apply a rolling sum (moving sum) over the values in this array.
Expand Down Expand Up @@ -6026,6 +6042,8 @@ def rolling_sum(
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).
Warnings
--------
Expand Down Expand Up @@ -6178,7 +6196,7 @@ def rolling_sum(
)
return self._from_pyexpr(
self._pyexpr.rolling_sum(
window_size, weights, min_periods, center, by, closed
window_size, weights, min_periods, center, by, closed, warn_if_unsorted
)
)

Expand All @@ -6193,6 +6211,7 @@ def rolling_std(
by: str | None = None,
closed: ClosedInterval = "left",
ddof: int = 1,
warn_if_unsorted: bool = True,
) -> Self:
"""
Compute a rolling standard deviation.
Expand Down Expand Up @@ -6262,6 +6281,8 @@ def rolling_std(
applicable if `by` has been set.
ddof
"Delta Degrees of Freedom": The divisor for a length N window is N - ddof
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).
Warnings
--------
Expand Down Expand Up @@ -6414,7 +6435,14 @@ def rolling_std(
)
return self._from_pyexpr(
self._pyexpr.rolling_std(
window_size, weights, min_periods, center, by, closed, ddof
window_size,
weights,
min_periods,
center,
by,
closed,
ddof,
warn_if_unsorted,
)
)

Expand All @@ -6429,6 +6457,7 @@ def rolling_var(
by: str | None = None,
closed: ClosedInterval = "left",
ddof: int = 1,
warn_if_unsorted: bool = True,
) -> Self:
"""
Compute a rolling variance.
Expand Down Expand Up @@ -6498,6 +6527,8 @@ def rolling_var(
applicable if `by` has been set.
ddof
"Delta Degrees of Freedom": The divisor for a length N window is N - ddof
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).
Warnings
--------
Expand Down Expand Up @@ -6657,6 +6688,7 @@ def rolling_var(
by,
closed,
ddof,
warn_if_unsorted,
)
)

Expand All @@ -6670,6 +6702,7 @@ def rolling_median(
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "left",
warn_if_unsorted: bool = True,
) -> Self:
"""
Compute a rolling median.
Expand Down Expand Up @@ -6737,6 +6770,8 @@ def rolling_median(
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).
Warnings
--------
Expand Down Expand Up @@ -6815,7 +6850,7 @@ def rolling_median(
)
return self._from_pyexpr(
self._pyexpr.rolling_median(
window_size, weights, min_periods, center, by, closed
window_size, weights, min_periods, center, by, closed, warn_if_unsorted
)
)

Expand All @@ -6831,6 +6866,7 @@ def rolling_quantile(
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "left",
warn_if_unsorted: bool = True,
) -> Self:
"""
Compute a rolling quantile.
Expand Down Expand Up @@ -6902,6 +6938,8 @@ def rolling_quantile(
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).
Warnings
--------
Expand Down Expand Up @@ -7016,6 +7054,7 @@ def rolling_quantile(
center,
by,
closed,
warn_if_unsorted,
)
)

Expand Down
Loading

0 comments on commit 9f8d828

Please sign in to comment.