Skip to content

Commit

Permalink
Merge pull request data-apis#266 from asmeurer/skip_dtypes
Browse files Browse the repository at this point in the history
Support skipping dtypes by setting ARRAY_API_TESTS_SKIP_DTYPES
  • Loading branch information
asmeurer authored May 23, 2024
2 parents f9022a1 + 9816630 commit 05ab7f1
Show file tree
Hide file tree
Showing 19 changed files with 202 additions and 189 deletions.
2 changes: 2 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[flake8]
select = F
15 changes: 14 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
This is the test suite for array libraries adopting the [Python Array API
standard](https://data-apis.org/array-api/latest).

Keeping full coverage of the spec is an on-going priority as the Array API evolves.
Keeping full coverage of the spec is an on-going priority as the Array API evolves.
Feedback and contributions are welcome!

## Quickstart
Expand Down Expand Up @@ -285,6 +285,19 @@ values should result in more rigorous runs. For example, `--max-examples
10_000` may find bugs where default runs don't but will take much longer to
run.

#### Skipping Dtypes

The test suite will automatically skip testing of inessential dtypes if they
are not present on the array module namespace, but dtypes can also be skipped
manually by setting the environment variable `ARRAY_API_TESTS_SKIP_DTYPES` to
a comma separated list of dtypes to skip. For example

```
ARRAY_API_TESTS_SKIP_DTYPES=uint16,uint32,uint64 pytest array_api_tests/
```
Note that skipping certain essential dtypes such as `bool` and the default
floating-point dtype is not supported.
## Contributing
Expand Down
36 changes: 20 additions & 16 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import re
from collections import defaultdict
from collections.abc import Mapping
Expand Down Expand Up @@ -104,9 +105,18 @@ def __repr__(self):
numeric_names = real_names + complex_names
dtype_names = ("bool",) + numeric_names

_skip_dtypes = os.getenv("ARRAY_API_TESTS_SKIP_DTYPES", '')
_skip_dtypes = _skip_dtypes.split(',')
skip_dtypes = []
for dtype in _skip_dtypes:
if dtype and dtype not in dtype_names:
raise ValueError(f"Invalid dtype name in ARRAY_API_TESTS_SKIP_DTYPES: {dtype}")
skip_dtypes.append(dtype)

_name_to_dtype = {}
for name in dtype_names:
if name in skip_dtypes:
continue
try:
dtype = getattr(xp, name)
except AttributeError:
Expand Down Expand Up @@ -184,9 +194,9 @@ def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
dtype_value_pairs = []
for name, value in mapping.items():
assert isinstance(name, str) and name in dtype_names # sanity check
try:
dtype = getattr(xp, name)
except AttributeError:
if name in _name_to_dtype:
dtype = _name_to_dtype[name]
else:
continue
dtype_value_pairs.append((dtype, value))
return EqualityMapping(dtype_value_pairs)
Expand Down Expand Up @@ -313,9 +323,9 @@ def accumulation_result_dtype(x_dtype, dtype_kwarg):
else:
default_complex = None
if dtype_nbits[default_int] == 32:
default_uint = getattr(xp, "uint32", None)
default_uint = _name_to_dtype.get("uint32")
else:
default_uint = getattr(xp, "uint64", None)
default_uint = _name_to_dtype.get("uint64")

_promotion_table: Dict[Tuple[str, str], str] = {
("bool", "bool"): "bool",
Expand Down Expand Up @@ -366,18 +376,12 @@ def accumulation_result_dtype(x_dtype, dtype_kwarg):
_promotion_table.update({(d2, d1): res for (d1, d2), res in _promotion_table.items()})
_promotion_table_pairs: List[Tuple[Tuple[DataType, DataType], DataType]] = []
for (in_name1, in_name2), res_name in _promotion_table.items():
try:
in_dtype1 = getattr(xp, in_name1)
except AttributeError:
continue
try:
in_dtype2 = getattr(xp, in_name2)
except AttributeError:
continue
try:
res_dtype = getattr(xp, res_name)
except AttributeError:
if in_name1 not in _name_to_dtype or in_name2 not in _name_to_dtype or res_name not in _name_to_dtype:
continue
in_dtype1 = _name_to_dtype[in_name1]
in_dtype2 = _name_to_dtype[in_name2]
res_dtype = _name_to_dtype[res_name]

_promotion_table_pairs.append(((in_dtype1, in_dtype2), res_dtype))
promotion_table = EqualityMapping(_promotion_table_pairs)

Expand Down
34 changes: 24 additions & 10 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,24 @@ def oneway_broadcastable_shapes(draw) -> OnewayBroadcastableShapes:
return OnewayBroadcastableShapes(input_shape, result_shape)


# Use these instead of xps.scalar_dtypes, etc. because it skips dtypes from
# ARRAY_API_TESTS_SKIP_DTYPES
all_dtypes = sampled_from(_sorted_dtypes)
int_dtypes = sampled_from(dh.int_dtypes)
uint_dtypes = sampled_from(dh.uint_dtypes)
real_dtypes = sampled_from(dh.real_dtypes)
# Warning: The hypothesis "floating_dtypes" is what we call
# "real_floating_dtypes"
floating_dtypes = sampled_from(dh.all_float_dtypes)
real_floating_dtypes = sampled_from(dh.real_float_dtypes)
numeric_dtypes = sampled_from(dh.numeric_dtypes)
# Note: this always returns complex dtypes, even if api_version < 2022.12
complex_dtypes = sampled_from(dh.complex_dtypes)

def all_floating_dtypes() -> SearchStrategy[DataType]:
strat = xps.floating_dtypes()
strat = floating_dtypes
if api_version >= "2022.12":
strat |= xps.complex_dtypes()
strat |= complex_dtypes
return strat


Expand Down Expand Up @@ -236,7 +250,7 @@ def matrix_shapes(draw, stack_shapes=shapes()):

@composite
def finite_matrices(draw, shape=matrix_shapes()):
return draw(arrays(dtype=xps.floating_dtypes(),
return draw(arrays(dtype=floating_dtypes,
shape=shape,
elements=dict(allow_nan=False,
allow_infinity=False)))
Expand All @@ -245,7 +259,7 @@ def finite_matrices(draw, shape=matrix_shapes()):
# Should we set a max_value here?
_rtol_float_kw = dict(allow_nan=False, allow_infinity=False, min_value=0)
rtols = one_of(floats(**_rtol_float_kw),
arrays(dtype=xps.floating_dtypes(),
arrays(dtype=real_floating_dtypes,
shape=rtol_shared_matrix_shapes.map(lambda shape: shape[:-2]),
elements=_rtol_float_kw))

Expand Down Expand Up @@ -280,9 +294,9 @@ def mutually_broadcastable_shapes(

two_mutually_broadcastable_shapes = mutually_broadcastable_shapes(2)

# Note: This should become hermitian_matrices when complex dtypes are added
# TODO: Add support for complex Hermitian matrices
@composite
def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True, bound=10.):
def symmetric_matrices(draw, dtypes=real_floating_dtypes, finite=True, bound=10.):
shape = draw(square_matrix_shapes)
dtype = draw(dtypes)
if not isinstance(finite, bool):
Expand All @@ -297,7 +311,7 @@ def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True, bound=10
return H

@composite
def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
def positive_definite_matrices(draw, dtypes=floating_dtypes):
# For now just generate stacks of identity matrices
# TODO: Generate arbitrary positive definite matrices, for instance, by
# using something like
Expand All @@ -310,7 +324,7 @@ def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
return broadcast_to(eye(n, dtype=dtype), shape)

@composite
def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes()):
def invertible_matrices(draw, dtypes=floating_dtypes, stack_shapes=shapes()):
# For now, just generate stacks of diagonal matrices.
stack_shape = draw(stack_shapes)
n = draw(integers(0, SQRT_MAX_ARRAY_SIZE // max(math.prod(stack_shape), 1)),)
Expand Down Expand Up @@ -344,7 +358,7 @@ def two_broadcastable_shapes(draw):
sqrt_sizes = integers(0, SQRT_MAX_ARRAY_SIZE)

numeric_arrays = arrays(
dtype=shared(xps.floating_dtypes(), key='dtypes'),
dtype=shared(floating_dtypes, key='dtypes'),
shape=shared(xps.array_shapes(), key='shapes'),
)

Expand Down Expand Up @@ -388,7 +402,7 @@ def python_integer_indices(draw, sizes):
def integer_indices(draw, sizes):
# Return either a Python integer or a 0-D array with some integer dtype
idx = draw(python_integer_indices(sizes))
dtype = draw(xps.integer_dtypes() | xps.unsigned_integer_dtypes())
dtype = draw(int_dtypes | uint_dtypes)
m, M = dh.dtype_ranges[dtype]
if m <= idx <= M:
return draw(one_of(just(idx),
Expand Down
28 changes: 28 additions & 0 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,34 @@ def assert_dtype(
assert out_dtype == expected, msg


def assert_float_to_complex_dtype(
func_name: str, *, in_dtype: DataType, out_dtype: DataType
):
if in_dtype == xp.float32:
expected = xp.complex64
else:
assert in_dtype == xp.float64 # sanity check
expected = xp.complex128
assert_dtype(
func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected
)


def assert_complex_to_float_dtype(
func_name: str, *, in_dtype: DataType, out_dtype: DataType, repr_name: str = "out.dtype"
):
if in_dtype == xp.complex64:
expected = xp.float32
elif in_dtype == xp.complex128:
expected = xp.float64
else:
assert in_dtype in (xp.float32, xp.float64) # sanity check
expected = in_dtype
assert_dtype(
func_name, in_dtype=in_dtype, out_dtype=out_dtype, expected=expected, repr_name=repr_name
)


def assert_kw_dtype(
func_name: str,
*,
Expand Down
23 changes: 9 additions & 14 deletions array_api_tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from . import pytest_helpers as ph
from . import shape_helpers as sh
from . import xps
from . import xp as _xp
from .typing import DataType, Index, Param, Scalar, ScalarType, Shape


Expand Down Expand Up @@ -75,7 +74,7 @@ def get_indexed_axes_and_out_shape(
return tuple(axes_indices), tuple(out_shape)


@given(shape=hh.shapes(), dtype=xps.scalar_dtypes(), data=st.data())
@given(shape=hh.shapes(), dtype=hh.all_dtypes, data=st.data())
def test_getitem(shape, dtype, data):
zero_sided = any(side == 0 for side in shape)
if zero_sided:
Expand Down Expand Up @@ -157,7 +156,7 @@ def test_setitem(shape, dtypes, data):
@pytest.mark.data_dependent_shapes
@given(hh.shapes(), st.data())
def test_getitem_masking(shape, data):
x = data.draw(hh.arrays(xps.scalar_dtypes(), shape=shape), label="x")
x = data.draw(hh.arrays(hh.all_dtypes, shape=shape), label="x")
mask_shapes = st.one_of(
st.sampled_from([x.shape, ()]),
st.lists(st.booleans(), min_size=x.ndim, max_size=x.ndim).map(
Expand Down Expand Up @@ -202,7 +201,7 @@ def test_getitem_masking(shape, data):
@pytest.mark.unvectorized
@given(hh.shapes(), st.data())
def test_setitem_masking(shape, data):
x = data.draw(hh.arrays(xps.scalar_dtypes(), shape=shape), label="x")
x = data.draw(hh.arrays(hh.all_dtypes, shape=shape), label="x")
key = data.draw(hh.arrays(dtype=xp.bool, shape=shape), label="key")
value = data.draw(
hh.from_dtype(x.dtype) | hh.arrays(dtype=x.dtype, shape=()), label="value"
Expand Down Expand Up @@ -252,18 +251,14 @@ def make_scalar_casting_param(


@pytest.mark.parametrize(
"method_name, dtype_name, stype",
[make_scalar_casting_param("__bool__", "bool", bool)]
+ [make_scalar_casting_param("__int__", n, int) for n in dh.all_int_names]
+ [make_scalar_casting_param("__index__", n, int) for n in dh.all_int_names]
+ [make_scalar_casting_param("__float__", n, float) for n in dh.real_float_names],
"method_name, dtype, stype",
[make_scalar_casting_param("__bool__", xp.bool, bool)]
+ [make_scalar_casting_param("__int__", n, int) for n in dh.all_int_dtypes]
+ [make_scalar_casting_param("__index__", n, int) for n in dh.all_int_dtypes]
+ [make_scalar_casting_param("__float__", n, float) for n in dh.real_float_dtypes],
)
@given(data=st.data())
def test_scalar_casting(method_name, dtype_name, stype, data):
try:
dtype = getattr(_xp, dtype_name)
except AttributeError as e:
pytest.skip(str(e))
def test_scalar_casting(method_name, dtype, stype, data):
x = data.draw(hh.arrays(dtype, shape=()), label="x")
method = getattr(x, method_name)
out = method()
Expand Down
Loading

0 comments on commit 05ab7f1

Please sign in to comment.