From 621316bcaf8e58dd146f2ccec5b767371f089981 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Thu, 12 Sep 2024 16:23:07 -0600 Subject: [PATCH] Use condition numbers to filter out ill-conditioned value testing for sum/cumulative_sum/prod This isn't completely rigorous (I haven't tweaked the tolerances used in isclose from the generous ones we were using before), but I haven't gotten hypothesis to find any bad corner cases for this yet. If any crop up we can easily tweak the values. Fixes #168 --- array_api_tests/test_statistical_functions.py | 52 ++++++++++++++++--- 1 file changed, 45 insertions(+), 7 deletions(-) diff --git a/array_api_tests/test_statistical_functions.py b/array_api_tests/test_statistical_functions.py index d533e51b..1778b5d0 100644 --- a/array_api_tests/test_statistical_functions.py +++ b/array_api_tests/test_statistical_functions.py @@ -17,6 +17,7 @@ @pytest.mark.min_version("2023.12") +@pytest.mark.unvectorized @given( x=hh.arrays( dtype=hh.numeric_dtypes, @@ -80,10 +81,15 @@ def test_cumulative_sum(x, data): if dh.is_int_dtype(out.dtype): m, M = dh.dtype_ranges[out.dtype] assume(m <= expected <= M) - ph.assert_scalar_equals("cumulative_sum", type_=scalar_type, - idx=out_idx.raw, out=out_val, - expected=expected) - + ph.assert_scalar_equals("cumulative_sum", type_=scalar_type, + idx=out_idx.raw, out=out_val, + expected=expected) + else: + condition_number = _sum_condition_number(elements) + assume(condition_number < 1e6) + ph.assert_scalar_isclose("cumulative_sum", type_=scalar_type, + idx=out_idx.raw, out=out_val, + expected=expected) def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]: dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype] @@ -176,6 +182,16 @@ def test_min(x, data): ph.assert_scalar_equals("min", type_=scalar_type, idx=out_idx, out=min_, expected=expected) +def _prod_condition_number(elements): + # Relative condition number using the infinity norm + abs_max = max([abs(i) for i in elements]) + abs_min = min([abs(i) for i in elements]) + + if abs_min == 0: + return float('inf') + + return abs_max / abs_min + @pytest.mark.unvectorized @given( x=hh.arrays( @@ -225,7 +241,13 @@ def test_prod(x, data): if dh.is_int_dtype(out.dtype): m, M = dh.dtype_ranges[out.dtype] assume(m <= expected <= M) - ph.assert_scalar_equals("prod", type_=scalar_type, idx=out_idx, out=prod, expected=expected) + ph.assert_scalar_equals("prod", type_=scalar_type, idx=out_idx, + out=prod, expected=expected) + else: + condition_number = _prod_condition_number(elements) + assume(condition_number < 1e15) + ph.assert_scalar_isclose("prod", type_=scalar_type, idx=out_idx, + out=prod, expected=expected) @pytest.mark.skip(reason="flaky") # TODO: fix! @@ -264,8 +286,16 @@ def test_std(x, data): ) # We can't easily test the result(s) as standard deviation methods vary a lot +def _sum_condition_number(elements): + sum_abs = sum([abs(i) for i in elements]) + abs_sum = abs(sum(elements)) -@pytest.mark.unvectorized + if abs_sum == 0: + return float('inf') + + return sum_abs / abs_sum + +# @pytest.mark.unvectorized @given( x=hh.arrays( dtype=hh.numeric_dtypes, @@ -314,7 +344,15 @@ def test_sum(x, data): if dh.is_int_dtype(out.dtype): m, M = dh.dtype_ranges[out.dtype] assume(m <= expected <= M) - ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected) + ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx, + out=sum_, expected=expected) + else: + # Avoid value testing for ill conditioned summations. See + # https://en.wikipedia.org/wiki/Kahan_summation_algorithm#Accuracy and + # https://en.wikipedia.org/wiki/Condition_number. + condition_number = _sum_condition_number(elements) + assume(condition_number < 1e6) + ph.assert_scalar_isclose("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected) @pytest.mark.unvectorized