Skip to content

Commit

Permalink
Use condition numbers to filter out ill-conditioned value testing for…
Browse files Browse the repository at this point in the history
… sum/cumulative_sum/prod

This isn't completely rigorous (I haven't tweaked the tolerances used in
isclose from the generous ones we were using before), but I haven't gotten
Hypothesis to find any bad corner cases for this yet. If any crop up, we can
easily tweak the values.

Fixes data-apis#168
  • Loading branch information
asmeurer committed Sep 12, 2024
1 parent 8f240f6 commit 621316b
Showing 1 changed file with 45 additions and 7 deletions.
52 changes: 45 additions & 7 deletions array_api_tests/test_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@


@pytest.mark.min_version("2023.12")
@pytest.mark.unvectorized
@given(
x=hh.arrays(
dtype=hh.numeric_dtypes,
Expand Down Expand Up @@ -80,10 +81,15 @@ def test_cumulative_sum(x, data):
if dh.is_int_dtype(out.dtype):
m, M = dh.dtype_ranges[out.dtype]
assume(m <= expected <= M)
ph.assert_scalar_equals("cumulative_sum", type_=scalar_type,
idx=out_idx.raw, out=out_val,
expected=expected)

ph.assert_scalar_equals("cumulative_sum", type_=scalar_type,
idx=out_idx.raw, out=out_val,
expected=expected)
else:
condition_number = _sum_condition_number(elements)
assume(condition_number < 1e6)
ph.assert_scalar_isclose("cumulative_sum", type_=scalar_type,
idx=out_idx.raw, out=out_val,
expected=expected)

def kwarg_dtypes(dtype: DataType) -> st.SearchStrategy[Optional[DataType]]:
dtypes = [d2 for d1, d2 in dh.promotion_table if d1 == dtype]
Expand Down Expand Up @@ -176,6 +182,16 @@ def test_min(x, data):
ph.assert_scalar_equals("min", type_=scalar_type, idx=out_idx, out=min_, expected=expected)


def _prod_condition_number(elements):
# Relative condition number using the infinity norm
abs_max = max([abs(i) for i in elements])
abs_min = min([abs(i) for i in elements])

if abs_min == 0:
return float('inf')

return abs_max / abs_min

@pytest.mark.unvectorized
@given(
x=hh.arrays(
Expand Down Expand Up @@ -225,7 +241,13 @@ def test_prod(x, data):
if dh.is_int_dtype(out.dtype):
m, M = dh.dtype_ranges[out.dtype]
assume(m <= expected <= M)
ph.assert_scalar_equals("prod", type_=scalar_type, idx=out_idx, out=prod, expected=expected)
ph.assert_scalar_equals("prod", type_=scalar_type, idx=out_idx,
out=prod, expected=expected)
else:
condition_number = _prod_condition_number(elements)
assume(condition_number < 1e15)
ph.assert_scalar_isclose("prod", type_=scalar_type, idx=out_idx,
out=prod, expected=expected)


@pytest.mark.skip(reason="flaky") # TODO: fix!
Expand Down Expand Up @@ -264,8 +286,16 @@ def test_std(x, data):
)
# We can't easily test the result(s) as standard deviation methods vary a lot

def _sum_condition_number(elements):
sum_abs = sum([abs(i) for i in elements])
abs_sum = abs(sum(elements))

@pytest.mark.unvectorized
if abs_sum == 0:
return float('inf')

return sum_abs / abs_sum

# @pytest.mark.unvectorized
@given(
x=hh.arrays(
dtype=hh.numeric_dtypes,
Expand Down Expand Up @@ -314,7 +344,15 @@ def test_sum(x, data):
if dh.is_int_dtype(out.dtype):
m, M = dh.dtype_ranges[out.dtype]
assume(m <= expected <= M)
ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected)
ph.assert_scalar_equals("sum", type_=scalar_type, idx=out_idx,
out=sum_, expected=expected)
else:
# Avoid value testing for ill-conditioned summations. See
# https://en.wikipedia.org/wiki/Kahan_summation_algorithm#Accuracy and
# https://en.wikipedia.org/wiki/Condition_number.
condition_number = _sum_condition_number(elements)
assume(condition_number < 1e6)
ph.assert_scalar_isclose("sum", type_=scalar_type, idx=out_idx, out=sum_, expected=expected)


@pytest.mark.unvectorized
Expand Down

0 comments on commit 621316b

Please sign in to comment.