Skip to content

Commit

Permalink
Smoke axes argument in FFT shift tests
Browse files Browse the repository at this point in the history
  • Loading branch information
honno committed Jan 16, 2024
1 parent 674dd0a commit 6e806ec
Showing 1 changed file with 11 additions and 12 deletions.
23 changes: 11 additions & 12 deletions array_api_tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,15 +303,14 @@ def test_rfftfreq(n, kw):
ph.assert_shape("rfftfreq", out_shape=out.shape, expected=(n // 2 + 1,), kw={"n": n})


@given(x=hh.arrays(xps.floating_dtypes(), fft_shapes_strat))
def test_fftshift(x):
out = xp.fft.fftshift(x)
ph.assert_dtype("fftshift", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape("fftshift", out_shape=out.shape, expected=x.shape)


@given(x=hh.arrays(xps.floating_dtypes(), fft_shapes_strat))
def test_ifftshift(x):
out = xp.fft.ifftshift(x)
ph.assert_dtype("ifftshift", in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape("ifftshift", out_shape=out.shape, expected=x.shape)
@pytest.mark.parametrize("func_name", ["fftshift", "ifftshift"])
@given(x=hh.arrays(xps.floating_dtypes(), fft_shapes_strat), data=st.data())
def test_shift_func(func_name, x, data):
func = getattr(xp.fft, func_name)
axes = data.draw(
st.none() | st.lists(st.sampled_from(list(range(x.ndim))), min_size=1, unique=True),
label="axes",
)
out = func(x, axes=axes)
ph.assert_dtype(func_name, in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape(func_name, out_shape=out.shape, expected=x.shape)

0 comments on commit 6e806ec

Please sign in to comment.