Skip to content

Commit

Permalink
Add array and axis testing to repeat()
Browse files Browse the repository at this point in the history
Still need to add values testing.
  • Loading branch information
asmeurer committed Sep 24, 2024
1 parent b4c0823 commit e46e978
Showing 1 changed file with 39 additions and 14 deletions.
53 changes: 39 additions & 14 deletions array_api_tests/test_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,45 @@ def test_permute_dims(x, axes):
out_indices=permuted_indices)


@pytest.mark.min_version("2023.12")
@given(
x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes(min_dims=1)),
kw=hh.kwargs(
axis=st.none() | shared_shapes(min_dims=1).flatmap(
lambda s: st.integers(-len(s), len(s) - 1)
)
),
data=st.data(),
)
def test_repeat(x, kw, data):
shape = x.shape
axis = kw.get("axis", None)
dim = math.prod(shape) if axis is None else shape[axis]
repeat_strat = st.integers(1, 4)
repeats = data.draw(repeat_strat
| hh.arrays(dtype=hh.int_dtypes, elements=repeat_strat,
shape=st.sampled_from([(1,), (dim,)])),
label="repeats")
if isinstance(repeats, int):
n_repitions = dim*repeats
else:
if repeats.shape == (1,):
n_repitions = dim*repeats[0]
else:
n_repitions = int(xp.sum(repeats))

out = xp.repeat(x, repeats, **kw)
ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype)
if axis is None:
expected_shape = (n_repitions,)
else:
expected_shape = list(shape)
expected_shape[axis] = n_repitions
expected_shape = tuple(expected_shape)
ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape)
# TODO: values testing


@st.composite
def reshape_shapes(draw, shape):
size = 1 if len(shape) == 0 else math.prod(shape)
Expand All @@ -298,20 +337,6 @@ def reshape_shapes(draw, shape):
return tuple(rshape)


@pytest.mark.min_version("2023.12")
@given(
x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1)),
repeats=st.integers(1, 4),
)
def test_repeat(x, repeats):
# TODO: test array repeats and non-None axis, adjust shape and value testing accordingly
out = xp.repeat(x, repeats)
ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype)
expected_shape = (math.prod(x.shape) * repeats,)
ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape)
# TODO: values testing


@pytest.mark.unvectorized
@pytest.mark.skip("flaky") # TODO: fix!
@given(
Expand Down

0 comments on commit e46e978

Please sign in to comment.