Skip to content

Commit

Permalink
Fix test_nan_propagation for immutable arrays.
Browse files Browse the repository at this point in the history
  • Loading branch information
hameerabbasi committed Jun 3, 2024
1 parent 33f2d2e commit 7b074eb
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions array_api_tests/test_special_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@
from warnings import warn

import pytest
from hypothesis import given, note, settings
from hypothesis import given, note, settings, assume
from hypothesis import strategies as st
from hypothesis.strategies import composite

from array_api_tests.typing import Array, DataType

Expand Down Expand Up @@ -1321,6 +1322,11 @@ def test_empty_arrays(func_name, expected): # TODO: parse docstrings to get exp
else:
assert out == expected, msg

@composite
def not_all_false(draw, shape):
ret = draw(hh.arrays(dtype=hh.bool_dtype, shape=shape))
assume(ret.any())
return ret

@pytest.mark.parametrize(
"func_name", [f.__name__ for f in category_to_funcs["statistical"]]
Expand All @@ -1331,10 +1337,8 @@ def test_empty_arrays(func_name, expected): # TODO: parse docstrings to get exp
)
def test_nan_propagation(func_name, x, data):
func = getattr(xp, func_name)
set_idx = data.draw(
xps.indices(x.shape, max_dims=0, allow_ellipsis=False), label="set idx"
)
x[set_idx] = float("nan")
nan_positions = data.draw(not_all_false(x.shape))
x = xp.where(nan_positions, float("nan"), x)
note(f"{x=}")

out = func(x)
Expand Down

0 comments on commit 7b074eb

Please sign in to comment.