Skip to content

Commit

Permalink
fix(python): fix as_dict with include_key=False for partition_by (#9865)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller authored Oct 8, 2023
1 parent 969a416 commit 740cf48
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
15 changes: 12 additions & 3 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7409,10 +7409,19 @@ def partition_by(
]

if as_dict:
if len(by) == 1:
return {df[by][0, 0]: df for df in partitions}
df = self._from_pydf(self._df)
if include_key:
if len(by) == 1:
names = [p[by[0]][0] for p in partitions]
else:
names = [p.select(by).row(0) for p in partitions]
else:
return {df[by].row(0): df for df in partitions}
if len(by) == 1:
names = df[by[0]].unique(maintain_order=True).to_list()
else:
names = df.select(by).unique(maintain_order=True).rows()

return dict(zip(names, partitions))

return partitions

Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -2610,6 +2610,19 @@ def test_partition_by() -> None:
"b": [1, 3],
}

# test with both as_dict and include_key=False
df = pl.DataFrame(
{
"a": pl.int_range(0, 100, dtype=pl.UInt8, eager=True),
"b": pl.int_range(0, 100, dtype=pl.UInt8, eager=True),
"c": pl.int_range(0, 100, dtype=pl.UInt8, eager=True),
"d": pl.int_range(0, 100, dtype=pl.UInt8, eager=True),
}
).sample(n=100_000, with_replacement=True, shuffle=True)

partitions = df.partition_by(["a", "b"], as_dict=True, include_key=False)
assert all(key == value.row(0) for key, value in partitions.items())


def test_list_of_list_of_struct() -> None:
expected = [{"list_of_list_of_struct": [[{"a": 1}, {"a": 2}]]}]
Expand Down

0 comments on commit 740cf48

Please sign in to comment.