From cb22d77516abdc6578475c7e8bc011a77f45e58c Mon Sep 17 00:00:00 2001 From: Angus Hollands Date: Sat, 11 Nov 2023 12:00:08 +0000 Subject: [PATCH] refactor: enforce more type hints --- src/dask_awkward/lib/structure.py | 303 +++++++++++++++++++----------- 1 file changed, 189 insertions(+), 114 deletions(-) diff --git a/src/dask_awkward/lib/structure.py b/src/dask_awkward/lib/structure.py index e781985a..fc0bbf09 100644 --- a/src/dask_awkward/lib/structure.py +++ b/src/dask_awkward/lib/structure.py @@ -2,7 +2,7 @@ import builtins import warnings -from collections.abc import Sequence +from collections.abc import Iterable, Mapping, Sequence from numbers import Number from typing import TYPE_CHECKING, Any @@ -87,27 +87,19 @@ def __call__(self, *arrays): @borrow_docstring(ak.argcartesian) def argcartesian( - arrays, - axis=1, - nested=None, - parameters=None, - with_name=None, - highlevel=True, - behavior=None, -): + arrays: Sequence[Array] | Mapping[str, Array], + axis: int = 1, + nested: bool | Iterable[str | int] | None = None, + parameters: dict[str, Any] | None = None, + with_name: str | None = None, + highlevel: bool = True, + behavior: Mapping | None = None, +) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") - if axis == 1: - meta = ak.cartesian( - [array._meta for array in arrays], - axis=axis, - nested=nested, - parameters=parameters, - with_name=with_name, - highlevel=highlevel, - behavior=behavior, - ) + # FIXME: resolve negative axis + if axis >= 1: fn = _ArgCartesianFn( axis=axis, nested=nested, @@ -116,9 +108,7 @@ def argcartesian( highlevel=highlevel, behavior=behavior, ) - return map_partitions( - fn, *arrays, label="argcartesian", output_divisions=1, meta=meta - ) + return map_partitions(fn, *arrays, label="argcartesian", output_divisions=1) raise DaskAwkwardNotImplemented("TODO") @@ -134,26 +124,27 @@ def __call__(self, array): @borrow_docstring(ak.argcombinations) def argcombinations( - array, - n, - replacement=False, - axis=1, - fields=None, - parameters=None, - with_name=None, - highlevel=True, - behavior=None, -): + array: Array, + n: int, + replacement: bool = False, + axis: int = 1, + fields: Sequence[str] | None = None, + parameters: dict[str, Any] | None = None, + with_name: str | None = None, + highlevel: bool = True, + behavior: Mapping | None = None, +) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") if fields is not None and len(fields) != n: raise ValueError("if provided, the length of 'fields' must be 'n'") + # FIXME: resolve negative axis if axis < 0: raise ValueError("the 'axis' for argcombinations must be non-negative") - if axis != 0: + if axis >= 0: fn = _ArgCombinationsFn( n=n, replacement=replacement, @@ -183,13 +174,15 @@ def __call__(self, array): @borrow_docstring(ak.argsort) def argsort( - array, - axis=-1, - ascending=True, - stable=True, - highlevel=True, - behavior=None, -): + array: Array, + axis: int = -1, + ascending: bool = True, + stable: bool = True, + highlevel: bool = True, + behavior: Mapping | None = None, +) -> Array: + if not highlevel: + raise ValueError("Only highlevel=True is supported") if axis == 0: raise DaskAwkwardNotImplemented("TODO") fn = _ArgsortFn( @@ -211,7 +204,9 @@ def __call__(self, *arrays): @borrow_docstring(ak.broadcast_arrays) -def broadcast_arrays(*arrays, highlevel=True, **kwargs): +def broadcast_arrays( + *arrays: Array, highlevel: bool = True, **kwargs: Any +) -> list[Array]: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -247,14 +242,14 @@ def __call__(self, *arrays): @borrow_docstring(ak.cartesian) def cartesian( - arrays, - axis=1, - nested=None, - parameters=None, - with_name=None, - highlevel=True, - behavior=None, -): + arrays: Sequence[Array] | Mapping[str, Array], + axis: int = 1, + nested: bool | Iterable[str | int] | None = None, + parameters: dict[str, Any] | None = None, + with_name: str | None = None, + highlevel: bool = True, + behavior: Mapping | None = None, +) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") if axis == 1: @@ -287,10 +282,10 @@ def combinations( replacement: bool = False, axis: int = 1, fields: list[str] | None = None, - parameters: dict | None = None, + parameters: Mapping[str, Any] | None = None, with_name: str | None = None, highlevel: bool = True, - behavior: dict | None = None, + behavior: Mapping | None = None, ) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -319,10 +314,8 @@ def combinations( @borrow_docstring(ak.copy) -def copy(array): - raise DaskAwkwardNotImplemented( - "This function is not necessary in the context of dask-awkward." - ) +def copy(array: Array) -> Array: + return array class _FillNoneFn: @@ -338,9 +331,9 @@ def __call__(self, arr): def fill_none( array: Array, value: Any, - axis: int = -1, + axis: int | None = -1, highlevel: bool = True, - behavior: dict | None = None, + behavior: Mapping | None = None, ) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -362,7 +355,7 @@ def drop_none( array: Array, axis: int | None = None, highlevel: bool = True, - behavior: dict | None = None, + behavior: Mapping | None = None, ) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -384,8 +377,8 @@ def firsts( array: Array, axis: int = 1, highlevel: bool = True, - behavior: dict | None = None, -) -> Any: + behavior: Mapping | None = None, +) -> Array: if axis == 1: return map_partitions( _FirstsFn( @@ -415,7 +408,7 @@ def flatten( array: Array, axis: int | None = 1, highlevel: bool = True, - behavior: dict | None = None, + behavior: Mapping | None = None, ) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -432,7 +425,9 @@ def flatten( @borrow_docstring(ak.from_regular) -def from_regular(array, axis=1, highlevel=True, behavior=None): +def from_regular( + array: Array, axis: int = 1, highlevel: bool = True, behavior: Mapping | None = None +) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -450,7 +445,13 @@ def from_regular(array, axis=1, highlevel=True, behavior=None): @borrow_docstring(ak.full_like) -def full_like(array, fill_value, highlevel=True, behavior=None, dtype=None): +def full_like( + array: Array, + fill_value: Any, + highlevel: bool = True, + behavior: Mapping | None = None, + dtype: np.dtype | str | None = None, +) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -474,8 +475,14 @@ def full_like(array, fill_value, highlevel=True, behavior=None, dtype=None): @borrow_docstring(ak.isclose) def isclose( - a, b, rtol=1e-05, atol=1e-08, equal_nan=False, highlevel=True, behavior=None -): + a: Array, + b: Array, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, + highlevel: bool = True, + behavior: Mapping | None = None, +) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -505,29 +512,41 @@ def __call__(self, array): @borrow_docstring(ak.is_none) -def is_none(array, axis=0, highlevel=True, behavior=None): +def is_none( + array: Array, axis: int = 0, highlevel: bool = True, behavior: Mapping | None = None +) -> Array: fn = _IsNoneFn(axis=axis, highlevel=highlevel, behavior=behavior) return map_partitions(fn, array, label="is-none", output_divisions=1) @borrow_docstring(ak.local_index) -def local_index(array, axis=-1, highlevel=True, behavior=None): +def local_index( + array: Array, + axis: int = -1, + highlevel: bool = True, + behavior: Mapping | None = None, +) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") - if axis == 0: - DaskAwkwardNotImplemented("axis=0 for local_index is not supported") - if axis and axis != 0: - return map_partitions( - ak.local_index, - array, - axis=axis, - highlevel=highlevel, - behavior=behavior, - ) + if axis <= 0: + DaskAwkwardNotImplemented("axis<=0 for local_index is not supported") + return map_partitions( + ak.local_index, + array, + axis=axis, + highlevel=highlevel, + behavior=behavior, + ) @borrow_docstring(ak.mask) -def mask(array, mask, valid_when=True, highlevel=True, behavior=None): +def mask( + array: Array, + mask: Array, + valid_when: bool = True, + highlevel: bool = True, + behavior: Mapping | None = None, +) -> Array: if partition_compatibility(array, mask) == PartitionCompatibility.NO: raise IncompatiblePartitions("mask", array, mask) return map_partitions( @@ -573,7 +592,7 @@ def num( array: Any, axis: int = 1, highlevel: bool = True, - behavior: dict | None = None, + behavior: Mapping | None = None, ) -> Any: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -613,7 +632,7 @@ def num( def ones_like( array: Array, highlevel: bool = True, - behavior: dict | None = None, + behavior: Mapping | None = None, dtype: DTypeLike | None = None, ) -> Array: if not highlevel: @@ -629,7 +648,9 @@ def ones_like( @borrow_docstring(ak.to_packed) -def to_packed(array, highlevel=True, behavior=None): +def to_packed( + array: Array, highlevel: bool = True, behavior: Mapping | None = None +) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -658,13 +679,16 @@ def __call__(self, array): @borrow_docstring(ak.pad_none) def pad_none( - array, - target, - axis=1, - clip=False, - highlevel=True, - behavior=None, -): + array: Array, + target: bool, + axis: int = 1, + clip: bool = False, + highlevel: bool = True, + behavior: Mapping | None = None, +) -> Array: + if not highlevel: + raise ValueError("Only highlevel=True is supported") + if axis == 0: DaskAwkwardNotImplemented("axis=0 for pad_none is not supported") return map_partitions( @@ -681,7 +705,9 @@ def pad_none( @borrow_docstring(ak.ravel) -def ravel(array, highlevel=True, behavior=None): +def ravel( + array: Array, highlevel: bool = True, behavior: Mapping | None = None +) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -698,7 +724,9 @@ def ravel(array, highlevel=True, behavior=None): @borrow_docstring(ak.run_lengths) -def run_lengths(array, highlevel=True, behavior=None): +def run_lengths( + array: Array, highlevel: bool = True, behavior: Mapping | None = None +) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -728,7 +756,9 @@ def __call__(self, array): @borrow_docstring(ak.singletons) -def singletons(array, axis=0, highlevel=True, behavior=None): +def singletons( + array: Array, axis: int = 0, highlevel: bool = True, behavior: Mapping | None = None +) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -748,7 +778,16 @@ def __call__(self, array): @borrow_docstring(ak.sort) -def sort(array, axis=-1, ascending=True, stable=True, highlevel=True, behavior=None): +def sort( + array: Array, + axis: int = -1, + ascending: bool = True, + stable: bool = True, + highlevel: bool = True, + behavior: Mapping | None = None, +) -> Array: + if not highlevel: + raise ValueError("Only highlevel=True is supported") if axis == 0: raise DaskAwkwardNotImplemented("TODO") fn = _SortFn( @@ -761,12 +800,19 @@ def sort(array, axis=-1, ascending=True, stable=True, highlevel=True, behavior=N @borrow_docstring(ak.strings_astype) -def strings_astype(array, to, highlevel=True, behavior=None): +def strings_astype( + array: Array, + to: np.dtype | str, + highlevel: bool = True, + behavior: Mapping | None = None, +) -> Array: raise DaskAwkwardNotImplemented("TODO") @borrow_docstring(ak.to_regular) -def to_regular(array, axis=1, highlevel=True, behavior=None): +def to_regular( + array: Array, axis: int = 1, highlevel: bool = True, behavior: Mapping | None = None +) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -787,7 +833,13 @@ def to_regular(array, axis=1, highlevel=True, behavior=None): @borrow_docstring(ak.unflatten) -def unflatten(array, counts, axis=0, highlevel=True, behavior=None): +def unflatten( + array: Array, + counts: int | Array, + axis: int = 0, + highlevel: bool = True, + behavior: Mapping | None = None, +) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -809,21 +861,34 @@ def unflatten(array, counts, axis=0, highlevel=True, behavior=None): ) +def _array_with_behavior(array: Array, behavior: Mapping | None) -> Array: + if behavior is None: + new_meta = array._meta + else: + new_meta = ak.Array(array._meta, behavior=behavior) + return Array(array.dask, array.name, new_meta, array.divisions) + + @borrow_docstring(ak.unzip) def unzip( - array: Array, highlevel: bool = True, behavior: dict | None = None + array: Array, highlevel: bool = True, behavior: Mapping | None = None ) -> tuple[Array, ...]: if not highlevel: raise ValueError("Only highlevel=True is supported") fields = ak.fields(array._meta) if len(fields) == 0: - return (array,) + return (_array_with_behavior(array, behavior),) else: - return tuple(array[field] for field in fields) + return tuple(_array_with_behavior(array[field], behavior) for field in fields) @borrow_docstring(ak.values_astype) -def values_astype(array, to, highlevel=True, behavior=None): +def values_astype( + array: Array, + to: np.dtype | str, + highlevel: bool = True, + behavior: Mapping | None = None, +) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") return map_partitions( @@ -840,7 +905,7 @@ def __init__( self, mergebool: bool = True, highlevel: bool = True, - behavior: dict | None = None, + behavior: Mapping | None = None, ) -> None: self.mergebool = mergebool self.highlevel = highlevel @@ -864,7 +929,7 @@ def where( y: Array, mergebool: bool = True, highlevel: bool = True, - behavior: dict | None = None, + behavior: Mapping | None = None, ) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -892,9 +957,9 @@ def where( class _WithFieldFn: def __init__( self, - where: str | None = None, + where: str | Sequence[str] | None = None, highlevel: bool = True, - behavior: dict | None = None, + behavior: Mapping | None = None, ) -> None: self.where = where self.highlevel = highlevel @@ -911,7 +976,13 @@ def __call__(self, base: ak.Array, what: ak.Array) -> ak.Array: @borrow_docstring(ak.with_field) -def with_field(base, what, where=None, highlevel=True, behavior=None): +def with_field( + base: Array, + what: Array | int | float | complex | bool, + where: str | Sequence[str] | None = None, + highlevel: bool = True, + behavior: Mapping | None = None, +) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -938,7 +1009,7 @@ def with_field(base, what, where=None, highlevel=True, behavior=None): class _WithNameFn: - def __init__(self, name: str, behavior: dict | None = None) -> None: + def __init__(self, name: str | None, behavior: Mapping | None = None) -> None: self.name = name self.behavior = behavior @@ -949,9 +1020,9 @@ def __call__(self, array: ak.Array) -> ak.Array: @borrow_docstring(ak.with_name) def with_name( array: Array, - name: str, + name: str | None, highlevel: bool = True, - behavior: dict | None = None, + behavior: Mapping | None = None, ) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -985,8 +1056,10 @@ def with_parameter( parameter: str, value: Any, highlevel: bool = True, - behavior: dict | None = None, + behavior: Mapping | None = None, ) -> Array: + if not highlevel: + raise ValueError("Only highlevel=True is supported") return map_partitions( _WithParameterFn(parameter=parameter, value=value, behavior=behavior), array, @@ -1007,8 +1080,10 @@ def __call__(self, array): def without_parameters( array: Array, highlevel: bool = True, - behavior: dict | None = None, + behavior: Mapping | None = None, ) -> Array: + if not highlevel: + raise ValueError("Only highlevel=True is supported") return map_partitions( _WithoutParameterFn(behavior=behavior), array, @@ -1021,8 +1096,8 @@ def without_parameters( def zeros_like( array: Array, highlevel: bool = True, - behavior: dict | None = None, - dtype: DTypeLike | None = None, + behavior: Mapping | None = None, + dtype: np.dtype | str | None = None, ) -> Array: if not highlevel: raise ValueError("Only highlevel=True is supported") @@ -1058,12 +1133,12 @@ def __call__(self, *parts: Any) -> ak.Array: @borrow_docstring(ak.zip) def zip( - arrays: dict | list | tuple, + arrays: Sequence[Array] | Mapping[str, Array], depth_limit: int | None = None, - parameters: dict | None = None, + parameters: Mapping[str, Any] | None = None, with_name: str | None = None, highlevel: bool = True, - behavior: dict | None = None, + behavior: Mapping | None = None, right_broadcast: bool = False, optiontype_outside_record: bool = False, ) -> Array: