From 0803b43a1a545d9e383d3c6a7764f75a9d6d930d Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Wed, 19 Jun 2024 10:36:03 -0400 Subject: [PATCH 1/2] Fix n_to_one repartition --- src/dask_awkward/lib/structure.py | 9 +++++++-- tests/test_structure.py | 13 +++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/dask_awkward/lib/structure.py b/src/dask_awkward/lib/structure.py index 7afba822..af72d527 100644 --- a/src/dask_awkward/lib/structure.py +++ b/src/dask_awkward/lib/structure.py @@ -1412,12 +1412,17 @@ def simple_repartition_layer( layer: dict[tuple[str, int], tuple[Any, ...]] = {} new_divisions: tuple[Any, ...] if n_to_one: - for i in range(0, arr.npartitions, n_to_one): - layer[(key, i)] = (_subcat,) + tuple( + for i0, i in enumerate(range(0, arr.npartitions, n_to_one)): + layer[(key, i0)] = (_subcat,) + tuple( (arr.name, part) for part in range(i, min(i + n_to_one, arr.npartitions)) ) new_divisions = arr.divisions[::n_to_one] + if arr.npartitions % n_to_one: + new_divisions = new_divisions + (arr.divisions[-1],) + layer[(key, i0 + 1)] = (_subcat,) + tuple( + (arr.name, part) for part in range(new_divisions[-2], new_divisions[-1]) + ) elif one_to_n: for i in range(arr.npartitions): for part in range(one_to_n): diff --git a/tests/test_structure.py b/tests/test_structure.py index 8e003d97..22a3d500 100644 --- a/tests/test_structure.py +++ b/tests/test_structure.py @@ -554,6 +554,19 @@ def test_repartition_one_to_n(daa): assert_eq(daa, daa1, check_divisions=False) +def test_repartition_n_to_one(): + daa = dak.from_lists([[[1, 2, 3], [], [4, 5]]] * 52) + daa2 = daa.repartition(n_to_one=52) + assert daa2.npartitions == 1 + assert daa.compute().to_list() == daa2.compute().to_list() + daa2 = daa.repartition(n_to_one=53) + assert daa2.npartitions == 1 + assert daa.compute().to_list() == daa2.compute().to_list() + daa2 = daa.repartition(n_to_one=2) + assert daa2.npartitions == 26 + assert daa.compute().to_list() == daa2.compute().to_list() + 
+ def test_repartition_no_change(daa): daa1 = daa.repartition(divisions=(0, 5, 10, 15)) assert daa1.npartitions == 3 From 5a0af7b9f4f82470c66534f4313e4511960e754b Mon Sep 17 00:00:00 2001 From: Martin Durant Date: Wed, 19 Jun 2024 10:38:06 -0400 Subject: [PATCH 2/2] test uneven case --- tests/test_structure.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_structure.py b/tests/test_structure.py index 22a3d500..9b761339 100644 --- a/tests/test_structure.py +++ b/tests/test_structure.py @@ -565,6 +565,9 @@ def test_repartition_n_to_one(): daa2 = daa.repartition(n_to_one=2) assert daa2.npartitions == 26 assert daa.compute().to_list() == daa2.compute().to_list() + daa2 = daa.repartition(n_to_one=10) + assert daa2.npartitions == 6 + assert daa.compute().to_list() == daa2.compute().to_list() def test_repartition_no_change(daa):