Skip to content

Commit

Permalink
have Scalar.__getitem__ raise and instruct to use to_delayed
Browse files Browse the repository at this point in the history
  • Loading branch information
douglasdavis committed Nov 13, 2023
1 parent 9d3433d commit c2670a8
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
12 changes: 4 additions & 8 deletions src/dask_awkward/lib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,15 +193,11 @@ def __str__(self) -> str:
return f"dask.awkward<{key_split(self.name)}, type=Scalar, dtype={dt}>"

def __getitem__(self, where: Any) -> Any:
token = tokenize(self, operator.getitem, where)
label = "getitem"
name = f"{label}-{token}"
task = AwkwardMaterializedLayer(
{(name, 0): (operator.getitem, self.key, where)},
previous_layer_names=[self.name],
msg = (
"__getitem__ access on Scalars should be done after converting "
"the Scalar collection to delayed with the to_delayed method."
)
hlg = HighLevelGraph.from_collections(name, task, dependencies=[self])
return new_scalar_object(hlg, name, meta=None)
raise RuntimeError(msg)

@property
def known_value(self) -> Any | None:
Expand Down
4 changes: 3 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ def test_scalar_collection(daa: Array) -> None:
def test_scalar_getitem_getattr() -> None:
d = {"a": 5}
s = new_known_scalar(d)
assert s["a"].compute() == d["a"]
with pytest.raises(RuntimeError, match="should be done after converting"):
s["a"].compute() == d["a"]
s.to_delayed()["a"].compute() == d["a"]
Thing = namedtuple("Thing", "a b c")
t = Thing(c=3, b=2, a=1)
s = new_known_scalar(t)
Expand Down

0 comments on commit c2670a8

Please sign in to comment.