Skip to content

Commit

Permalink
fix: allow for Nones in repartition n_to_one
Browse files Browse the repository at this point in the history
  • Loading branch information
martindurant committed Jun 19, 2024
1 parent 38c065b commit 188483e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/dask_awkward/lib/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,7 +1421,7 @@ def simple_repartition_layer(
if arr.npartitions % n_to_one:
new_divisions = new_divisions + (arr.divisions[-1],)
layer[(key, i0 + 1)] = (_subcat,) + tuple(
(arr.name, part) for part in range(new_divisions[-2], new_divisions[-1])
(arr.name, part0) for part0 in range(len(layer), arr.npartitions)
)
elif one_to_n:
for i in range(arr.npartitions):
Expand Down
5 changes: 4 additions & 1 deletion tests/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ def test_repartition_one_to_n(daa):


def test_repartition_n_to_one():
daa = dak.from_lists([[[1, 2, 3], [], [4, 5]]] * 52)
daa = dak.from_lists([[[1, 2, 3], [], [4, 5]] * 2] * 52)
daa2 = daa.repartition(n_to_one=52)
assert daa2.npartitions == 1
assert daa.compute().to_list() == daa2.compute().to_list()
Expand All @@ -568,6 +568,9 @@ def test_repartition_n_to_one():
daa2 = daa.repartition(n_to_one=10)
assert daa2.npartitions == 6
assert daa.compute().to_list() == daa2.compute().to_list()
daa._divisions = (None,) * len(daa.divisions)
assert daa2.npartitions == 6
assert daa.compute().to_list() == daa2.compute().to_list()


def test_repartition_no_change(daa):
Expand Down

0 comments on commit 188483e

Please sign in to comment.