Skip to content

Commit

Permalink
Merge pull request #480 from lgray/remove_lambdas
Browse files Browse the repository at this point in the history
chore: various bits of housekeeping
  • Loading branch information
martindurant authored Mar 3, 2024
2 parents d10d240 + 4c6db51 commit c8ba5b1
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
5 changes: 5 additions & 0 deletions src/dask_awkward/awkward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,8 @@ awkward:
# continue on to be computed without dask-awkward specific
# optimizations.
on-fail: raise

aggregation:
# For tree reductions in dask-awkward, control how many partitions
# are aggregated per non-leaf tree node.
split-every: 8
2 changes: 1 addition & 1 deletion src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2246,7 +2246,7 @@ def non_trivial_reduction(
finalize_fn = partial(finalize_fn, **finalize_kwargs)

if split_every is None:
split_every = 8
split_every = dask.config.get("awkward.aggregation.split-every", 8)
elif split_every is False:
split_every = sys.maxsize
else:
Expand Down
8 changes: 5 additions & 3 deletions src/dask_awkward/lib/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from typing import TYPE_CHECKING, Any, cast

import awkward as ak
import dask.config
import numpy as np
from awkward.types.numpytype import primitive_to_dtype
from awkward.typetracer import length_zero_if_typetracer
from dask.base import flatten, tokenize
from dask.highlevelgraph import HighLevelGraph
from dask.local import identity
from dask.utils import funcname, is_integer, parse_bytes
from fsspec.utils import infer_compression

Expand Down Expand Up @@ -651,7 +653,7 @@ def from_map(
axis=0,
)

split_every = 8
split_every = dask.config.get("awkward.aggregation.split-every", 8)

rep_trl_label = f"{label}-report"
rep_trl_token = tokenize(result, second, concat_fn, split_every)
Expand All @@ -667,8 +669,8 @@ def from_map(
name_input=rep_part.name,
npartitions_input=rep_part.npartitions,
concat_func=concat_fn,
tree_node_func=lambda x: x,
finalize_func=lambda x: x,
tree_node_func=identity,
finalize_func=identity,
split_every=split_every,
tree_node_name=rep_trl_tree_node_name,
)
Expand Down

0 comments on commit c8ba5b1

Please sign in to comment.