Skip to content

Commit

Permalink
Merge pull request #229 from honno/remaining-fft-tests
Browse files Browse the repository at this point in the history
Remaining FFT tests
  • Loading branch information
honno authored Jan 16, 2024
2 parents ae0017a + 6e806ec commit d0d9696
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions array_api_tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,26 @@ def test_ihfft(x, data):
assert_n_axis_shape("ihfft", x=x, n=n, axis=axis, out=out, size_gt_1=True)


# TODO:
# fftfreq
# rfftfreq
# fftshift
# ifftshift
@given( n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5)))
def test_fftfreq(n, kw):
out = xp.fft.fftfreq(n, **kw)
ph.assert_shape("fftfreq", out_shape=out.shape, expected=(n,), kw={"n": n})


@given(n=st.integers(1, 100), kw=hh.kwargs(d=st.floats(0.1, 5)))
def test_rfftfreq(n, kw):
out = xp.fft.rfftfreq(n, **kw)
ph.assert_shape("rfftfreq", out_shape=out.shape, expected=(n // 2 + 1,), kw={"n": n})


@pytest.mark.parametrize("func_name", ["fftshift", "ifftshift"])
@given(x=hh.arrays(xps.floating_dtypes(), fft_shapes_strat), data=st.data())
def test_shift_func(func_name, x, data):
func = getattr(xp.fft, func_name)
axes = data.draw(
st.none() | st.lists(st.sampled_from(list(range(x.ndim))), min_size=1, unique=True),
label="axes",
)
out = func(x, axes=axes)
ph.assert_dtype(func_name, in_dtype=x.dtype, out_dtype=out.dtype)
ph.assert_shape(func_name, out_shape=out.shape, expected=x.shape)

0 comments on commit d0d9696

Please sign in to comment.