Skip to content

Commit

Permalink
Add missing tests for unstack()
Browse files Browse the repository at this point in the history
  • Loading branch information
asmeurer committed May 30, 2024
1 parent a04ff8f commit 6362204
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions array_api_tests/test_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,22 @@ def test_tile(x, data):
def test_unstack(x, data):
axis = data.draw(st.integers(min_value=-x.ndim, max_value=x.ndim - 1), label="axis")
kw = data.draw(hh.specified_kwargs(("axis", axis, 0)), label="kw")
out = xp.asarray(xp.unstack(x, **kw), dtype=x.dtype)
ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=out.dtype)
# TODO: shapes and values testing
out = xp.unstack(x, **kw)

assert isinstance(out, tuple)
assert len(out) == x.shape[axis]
expected_shape = list(x.shape)
expected_shape.pop(axis)
expected_shape = tuple(expected_shape)
for i in range(x.shape[axis]):
arr = out[i]
ph.assert_result_shape("unstack", in_shapes=[x.shape],
out_shape=arr.shape, expected=expected_shape,
kw=kw, repr_name=f"out[{i}].shape")

ph.assert_dtype("unstack", in_dtype=x.dtype, out_dtype=arr.dtype,
repr_name=f"out[{i}].dtype")

idx = [slice(None)] * x.ndim
idx[axis] = i
ph.assert_array_elements("unstack", out=arr, expected=x[tuple(idx)], kw=kw, out_repr=f"out[{i}]")

0 comments on commit 6362204

Please sign in to comment.