From 7b074eb42a62ab290ce8fb325e49bc32e0e2cba8 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 3 Jun 2024 15:11:33 +0200 Subject: [PATCH] Fix `test_nan_propagation` for immutable arrays. --- array_api_tests/test_special_cases.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/array_api_tests/test_special_cases.py b/array_api_tests/test_special_cases.py index 07ab3616..f2d2d154 100644 --- a/array_api_tests/test_special_cases.py +++ b/array_api_tests/test_special_cases.py @@ -23,8 +23,9 @@ from warnings import warn import pytest -from hypothesis import given, note, settings +from hypothesis import given, note, settings, assume from hypothesis import strategies as st +from hypothesis.strategies import composite from array_api_tests.typing import Array, DataType @@ -1321,6 +1322,11 @@ def test_empty_arrays(func_name, expected): # TODO: parse docstrings to get exp else: assert out == expected, msg +@composite +def not_all_false(draw, shape): + ret = draw(hh.arrays(dtype=hh.bool_dtype, shape=shape)) + assume(ret.any()) + return ret @pytest.mark.parametrize( "func_name", [f.__name__ for f in category_to_funcs["statistical"]] @@ -1331,10 +1337,8 @@ def test_empty_arrays(func_name, expected): # TODO: parse docstrings to get exp ) def test_nan_propagation(func_name, x, data): func = getattr(xp, func_name) - set_idx = data.draw( - xps.indices(x.shape, max_dims=0, allow_ellipsis=False), label="set idx" - ) - x[set_idx] = float("nan") + nan_positions = data.draw(not_all_false(x.shape)) + x = xp.where(nan_positions, float("nan"), x) note(f"{x=}") out = func(x)