Skip to content

Commit

Permalink
add new sum type for partition compatibility checks
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis committed Jun 29, 2023
1 parent 99f3dd9 commit fa99273
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 18 deletions.
66 changes: 60 additions & 6 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
import warnings
from collections.abc import Callable, Hashable, Mapping, Sequence
from enum import IntEnum
from functools import cached_property, partial
from numbers import Number
from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload
Expand Down Expand Up @@ -818,7 +819,7 @@ def _getitem_outer_bool_or_int_lazy_array(
self, where: Array | tuple[Any, ...]
) -> Any:
ba = where if isinstance(where, Array) else where[0]
if not compatible_partitions(self, ba):
if partition_compatibility(self, ba) == PartitionCompatibility.NO:
raise IncompatiblePartitions("getitem", self, ba)

new_meta: Any | None = None
Expand Down Expand Up @@ -1279,8 +1280,8 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
new_meta = ufunc(*inputs_meta)

dak_arrays = tuple(a for a in inputs if isinstance(a, Array))
if not compatible_partitions(*dak_arrays):
raise IncompatiblePartitions(*dak_arrays)
if partition_compatibility(*dak_arrays) == PartitionCompatibility.NO:
raise IncompatiblePartitions(ufunc.__name__, *dak_arrays)

return map_partitions(
ufunc,
Expand All @@ -1290,7 +1291,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
**kwargs,
)

def __array__(self, *args, **kwargs):
def __array__(self, *_, **__):
raise NotImplementedError

def to_delayed(self, optimize_graph: bool = True) -> list[Delayed]:
Expand Down Expand Up @@ -2136,6 +2137,12 @@ def compatible_partitions(*args: Array) -> bool:
``True`` if the collections appear to be equally partitioned.
"""

warnings.warn(
"dak.compatible_partitions is deprecated, please use dak.partition_compatibility.",
DeprecationWarning,
)

# first check to see if all arguments have the same number of
# partitions; this is _always_ defined.

Expand All @@ -2153,8 +2160,8 @@ def compatible_partitions(*args: Array) -> bool:
if arg.known_divisions:
refarr = arg
break
# if we never hit the break just return True because we have no
# known division Arrays.
# if we never hit the break just return True because we have no
# known division Arrays.
else:
return True

Expand Down Expand Up @@ -2271,3 +2278,50 @@ def make_unknown_length(array: ak.Array) -> ak.Array:
"""
return ak.Array(ak.to_layout(array).to_typetracer(forget_length=True))


class PartitionCompatibility(IntEnum):
    """Three-valued result of comparing how collections are partitioned.

    ``YES``
        Every argument has known divisions and they are all equal.
    ``NO``
        The partition counts differ, or two arguments with known
        divisions disagree.
    ``MAYBE``
        The partition counts agree, but at least one argument (possibly
        all of them) has unknown divisions, so equality cannot be
        confirmed.
    """

    YES = 0
    NO = 1
    MAYBE = 2

    @staticmethod
    def check(*args: Array) -> PartitionCompatibility:
        """Classify the partition compatibility of ``args``.

        Parameters
        ----------
        *args : Array
            Collections to compare. Only the ``npartitions``,
            ``known_divisions`` and ``divisions`` attributes are read.

        Returns
        -------
        PartitionCompatibility
            ``NO``, ``MAYBE`` or ``YES`` as described on the class.

        """
        # The number of partitions is always defined, so a mismatch
        # there settles the question immediately.
        if any(other.npartitions != args[0].npartitions for other in args[1:]):
            return PartitionCompatibility.NO

        # Pick the first argument with known divisions as the reference.
        # With no such argument there is nothing to compare against, so
        # the best available answer is "maybe".
        reference = next((arr for arr in args if arr.known_divisions), None)
        if reference is None:
            return PartitionCompatibility.MAYBE

        # Compare each argument with known divisions to the reference,
        # counting how many of them match.
        n_matching = 0
        for arr in args:
            if not arr.known_divisions:
                continue
            if arr.divisions != reference.divisions:
                return PartitionCompatibility.NO
            n_matching += 1

        # Only when every argument had known, equal divisions is the
        # answer a definite yes; otherwise fall back on maybe.
        if n_matching == len(args):
            return PartitionCompatibility.YES
        return PartitionCompatibility.MAYBE


def partition_compatibility(*args: Array) -> PartitionCompatibility:
    """Report whether the given collections are compatibly partitioned.

    Function-style entry point that forwards all arguments to
    ``PartitionCompatibility.check``.

    Parameters
    ----------
    *args : Array
        Collections to compare.

    Returns
    -------
    PartitionCompatibility
        The value produced by ``PartitionCompatibility.check(*args)``.

    """
    verdict = PartitionCompatibility.check(*args)
    return verdict
5 changes: 3 additions & 2 deletions src/dask_awkward/lib/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from dask_awkward.layers import AwkwardMaterializedLayer
from dask_awkward.lib.core import (
Array,
compatible_partitions,
PartitionCompatibility,
map_partitions,
new_array_object,
partition_compatibility,
)
from dask_awkward.utils import DaskAwkwardNotImplemented, IncompatiblePartitions

Expand Down Expand Up @@ -60,7 +61,7 @@ def concatenate(
return new_array_object(hlg, name, meta=meta, npartitions=npartitions)

if axis > 0:
if not compatible_partitions(*arrays):
if partition_compatibility(*arrays) == PartitionCompatibility.NO:
raise IncompatiblePartitions("concatenate", *arrays)

fn = _ConcatenateFnAxisGT0(axis=axis)
Expand Down
19 changes: 12 additions & 7 deletions src/dask_awkward/lib/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
import awkward as ak
from dask.base import is_dask_collection

from dask_awkward.lib.core import Array, compatible_partitions, map_partitions
from dask_awkward.lib.core import (
Array,
PartitionCompatibility,
map_partitions,
partition_compatibility,
)
from dask_awkward.utils import (
DaskAwkwardNotImplemented,
IncompatiblePartitions,
Expand Down Expand Up @@ -201,7 +206,7 @@ def broadcast_arrays(*arrays, highlevel=True, **kwargs):
if not highlevel:
raise ValueError("Only highlevel=True is supported")

if not compatible_partitions(*arrays):
if partition_compatibility(*arrays) == PartitionCompatibility.NO:
raise IncompatiblePartitions("broadcast_arrays", *arrays)

array_metas = (array._meta for array in arrays)
Expand Down Expand Up @@ -465,7 +470,7 @@ def isclose(
if not highlevel:
raise ValueError("Only highlevel=True is supported")

if not compatible_partitions(a, b):
if partition_compatibility(a, b) == PartitionCompatibility.NO:
raise IncompatiblePartitions("isclose", a, b)

return map_partitions(
Expand Down Expand Up @@ -514,7 +519,7 @@ def local_index(array, axis=-1, highlevel=True, behavior=None):

@borrow_docstring(ak.mask)
def mask(array, mask, valid_when=True, highlevel=True, behavior=None):
if not compatible_partitions(array, mask):
if partition_compatibility(array, mask) == PartitionCompatibility.NO:
raise IncompatiblePartitions("mask", array, mask)
return map_partitions(
ak.mask,
Expand Down Expand Up @@ -841,7 +846,7 @@ def where(
"The condition argugment to where must be a dask_awkward.Array"
)

if not compatible_partitions(*dask_args):
if partition_compatibility(*dask_args) == PartitionCompatibility.NO:
raise IncompatiblePartitions("where", *dask_args)

return map_partitions(
Expand Down Expand Up @@ -890,8 +895,8 @@ def with_field(base, what, where=None, highlevel=True, behavior=None):
maybe_dask_args = [base, what]
dask_args = tuple(arg for arg in maybe_dask_args if is_dask_collection(arg))

if not compatible_partitions(*dask_args):
raise IncompatiblePartitions("with_field", base, what)
if partition_compatibility(*dask_args) == PartitionCompatibility.NO:
raise IncompatiblePartitions("with_field", *dask_args)
return map_partitions(
_WithFieldFn(where=where, highlevel=highlevel, behavior=behavior),
base,
Expand Down
10 changes: 7 additions & 3 deletions src/dask_awkward/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from __future__ import annotations

from collections.abc import Callable, Mapping
from typing import Any, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar

if TYPE_CHECKING:
from dask_awkward.lib.core import Array


T = TypeVar("T")

Expand All @@ -19,12 +23,12 @@ def __init__(self, msg: str | None = None) -> None:


class IncompatiblePartitions(ValueError):
def __init__(self, name, *args):
def __init__(self, name: str, *args: Array) -> None:
msg = self.divisions_msg(name, *args)
super().__init__(msg)

@staticmethod
def divisions_msg(name: str, *args: Any) -> str:
def divisions_msg(name: str, *args: Array) -> str:
msg = f"The inputs to {name} are incompatibly partitioned\n"
for i, arg in enumerate(args):
msg += f"- arg{i} divisions: {arg.divisions}\n"
Expand Down

0 comments on commit fa99273

Please sign in to comment.