Skip to content

Commit

Permalink
Merge pull request #516 from martindurant/scalar_binop
Browse files Browse the repository at this point in the history
fix: Respect inverse in scalar binop
  • Loading branch information
martindurant authored Jun 19, 2024
2 parents 9fb24dc + 442d441 commit 1c4b97a
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 12 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ repos:
rev: v1.10.0
hooks:
- id: mypy
files: "src/"
args: [--ignore-missing-imports]
additional_dependencies:
- dask
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ src_paths = ["src", "tests"]

[tool.mypy]
python_version = "3.9"
files = ["src", "tests"]
files = ["src"]
exclude = ["tests/"]
strict = false
warn_unused_configs = true
show_error_codes = true
Expand Down
24 changes: 19 additions & 5 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,17 @@ def f(self, other):
if is_dask_collection(other):
task = (op, self.key, *other.__dask_keys__())
deps.append(other)
plns.append(other.name)
if inv:
plns.insert(0, other.name)
else:
plns.append(other.name)
else:
task = (op, self.key, other)
if inv:
task = (op, other, self.key)
else:
task = (op, self.key, other)
if inv:
plns.reverse()
graph = HighLevelGraph.from_collections(
name,
layer=AwkwardMaterializedLayer(
Expand All @@ -514,10 +522,16 @@ def f(self, other):
),
dependencies=tuple(deps),
)
if isinstance(other, Scalar):
meta = op(self._meta, other._meta)
if isinstance(other, (Scalar, Array)):
if inv:
meta = op(other._meta, self._meta)
else:
meta = op(self._meta, other._meta)
else:
meta = op(self._meta, other)
if inv:
meta = op(other, self._meta)
else:
meta = op(self._meta, other)
return new_scalar_object(graph, name, meta=meta)

return f
Expand Down
20 changes: 14 additions & 6 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_len(ndjson_points_file: str) -> None:
assert len(daa) == 10
daa.eager_compute_divisions()
assert daa.known_divisions
assert len(daa) == 10 # type: ignore
assert len(daa) == 10


def test_meta_exists(daa: Array) -> None:
Expand Down Expand Up @@ -157,7 +157,7 @@ def test_partitions_divisions(ndjson_points_file: str) -> None:
assert not t1.known_divisions
t2 = daa.partitions[1]
assert t2.known_divisions
assert t2.divisions == (0, divs[2] - divs[1]) # type: ignore
assert t2.divisions == (0, divs[2] - divs[1])


def test_array_rebuild(ndjson_points_file: str) -> None:
Expand Down Expand Up @@ -537,7 +537,7 @@ def test_compatible_partitions_after_slice() -> None:
assert_eq(lazy, ccrt)

# sanity
assert dak.compatible_partitions(lazy, lazy + 2) # type: ignore
assert dak.compatible_partitions(lazy, lazy + 2)
assert dak.compatible_partitions(lazy, dak.num(lazy, axis=1) > 2)

assert not dak.compatible_partitions(lazy[:-2], lazy)
Expand Down Expand Up @@ -646,6 +646,14 @@ def test_scalar_divisions(daa: Array) -> None:
assert s.divisions == (None, None)


def test_scalar_binop_inv() -> None:
# GH #515
x = dak.from_lists([[1]])
y = x[0] # scalar
assert (0 - y) == -1
assert (y - 0) == 1


def test_array_persist(daa: Array) -> None:
daa2 = daa["points"]["x"].persist()
assert_eq(daa["points"]["x"], daa2)
Expand Down Expand Up @@ -886,7 +894,7 @@ def test_shape_only_ops(fn: Callable, tmp_path_factory: pytest.TempPathFactory)
p = tmp_path_factory.mktemp("zeros-like-flat")
ak.to_parquet(a, str(p / "file.parquet"))
lazy = dak.from_parquet(str(p))
result = fn(lazy.b) # type: ignore
result = fn(lazy.b)
with dask.config.set({"awkward.optimization.enabled": True}):
result.compute()

Expand All @@ -898,7 +906,7 @@ def test_assign_behavior() -> None:
with pytest.raises(
TypeError, match="'mappingproxy' object does not support item assignment"
):
dx.behavior["should_fail"] = None # type: ignore
dx.behavior["should_fail"] = None
assert dx.behavior == behavior


Expand All @@ -909,7 +917,7 @@ def test_assign_attrs() -> None:
with pytest.raises(
TypeError, match="'mappingproxy' object does not support item assignment"
):
dx.attrs["should_fail"] = None # type: ignore
dx.attrs["should_fail"] = None
assert dx.attrs == attrs


Expand Down

0 comments on commit 1c4b97a

Please sign in to comment.