Skip to content

Commit

Permalink
fix: Exclude index from expansion in rolling/group_by_dynamic (#17086)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jun 20, 2024
1 parent be392c4 commit 742c54e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 7 deletions.
15 changes: 13 additions & 2 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -733,21 +733,27 @@ fn resolve_group_by(
) -> PolarsResult<(Vec<ExprIR>, Vec<ExprIR>, SchemaRef)> {
let current_schema = lp_arena.get(input).schema(lp_arena);
let current_schema = current_schema.as_ref();
let keys = rewrite_projections(keys, current_schema, &[])?;
let aggs = rewrite_projections(aggs, current_schema, &keys)?;
let mut keys = rewrite_projections(keys, current_schema, &[])?;

// Initialize schema from keys
let mut schema = expressions_to_schema(&keys, current_schema, Context::Default)?;

#[allow(unused_mut)]
let mut pop_keys = false;
// Add dynamic groupby index column(s)
// Also temporarily add the index column to the keys so that it is excluded when the aggregation expressions are expanded.
#[cfg(feature = "dynamic_group_by")]
{
if let Some(options) = _options.rolling.as_ref() {
let name = &options.index_column;
let dtype = current_schema.try_get(name)?;
keys.push(col(name));
pop_keys = true;
schema.with_column(name.clone(), dtype.clone());
} else if let Some(options) = _options.dynamic.as_ref() {
let name = &options.index_column;
keys.push(col(name));
pop_keys = true;
let dtype = current_schema.try_get(name)?;
if options.include_boundaries {
schema.with_column("_lower_boundary".into(), dtype.clone());
Expand All @@ -758,6 +764,11 @@ fn resolve_group_by(
}
let keys_index_len = schema.len();

let aggs = rewrite_projections(aggs, current_schema, &keys)?;
if pop_keys {
let _ = keys.pop();
}

// Add aggregation column(s)
let aggs_schema = expressions_to_schema(&aggs, current_schema, Context::Aggregation)?;
schema.merge(aggs_schema);
Expand Down
36 changes: 31 additions & 5 deletions py-polars/tests/unit/operations/test_group_by_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,15 +450,13 @@ def test_no_sorted_no_error() -> None:
"dt": [datetime(2001, 1, 1), datetime(2001, 1, 2)],
}
)
result = df.group_by_dynamic("dt", every="1h").agg(
pl.all().count().name.suffix("_foo")
)
result = df.group_by_dynamic("dt", every="1h").agg(pl.len().alias("count"))
expected = pl.DataFrame(
{
"dt": [datetime(2001, 1, 1), datetime(2001, 1, 2)],
"dt_foo": [1, 1],
"count": [1, 1],
},
schema_overrides={"dt_foo": pl.UInt32},
schema_overrides={"count": pl.get_index_type()},
)
assert_frame_equal(result, expected)

Expand Down Expand Up @@ -1017,3 +1015,31 @@ def test_group_by_dynamic_get() -> None:
],
"get": [1, 3, 5, 7],
}


def test_group_by_dynamic_exclude_index_from_expansion_17075() -> None:
    """Check that `.last()` on a dynamic group-by leaves the index column alone.

    The output must contain the `time` index column once, as the window
    key, followed by the last value of each remaining column (`n`, `m`)
    per window.
    """
    # Seven timestamps, 30 minutes apart, from 00:00 through 03:00.
    timestamps = pl.datetime_range(
        start=datetime(2021, 12, 16),
        end=datetime(2021, 12, 16, 3),
        interval="30m",
        eager=True,
    )
    lf = pl.LazyFrame({"time": timestamps, "n": range(7), "m": range(7)})

    result = lf.group_by_dynamic("time", every="1h", closed="right").last().collect()

    # Windows are closed on the right, so the first row (00:00) lands in the
    # window that starts at 23:00 on the previous day.
    expected = {
        "time": [
            datetime(2021, 12, 15, 23, 0),
            datetime(2021, 12, 16, 0, 0),
            datetime(2021, 12, 16, 1, 0),
            datetime(2021, 12, 16, 2, 0),
        ],
        "n": [0, 2, 4, 6],
        "m": [0, 2, 4, 6],
    }
    assert result.to_dict(as_series=False) == expected

0 comments on commit 742c54e

Please sign in to comment.