diff --git a/array_api_tests/__init__.py b/array_api_tests/__init__.py index 9af6796b..fd05697d 100644 --- a/array_api_tests/__init__.py +++ b/array_api_tests/__init__.py @@ -11,7 +11,7 @@ # You can comment the following out and instead import the specific array module -# you want to test, e.g. `import numpy.array_api as xp`. +# you want to test, e.g. `import array_api_strict as xp`. if "ARRAY_API_TESTS_MODULE" in os.environ: xp_name = os.environ["ARRAY_API_TESTS_MODULE"] _module, _sub = xp_name, None @@ -33,6 +33,17 @@ ) +# If xp.bool is not available, like in some versions of NumPy and CuPy, try +# patching in xp.bool_. +try: + xp.bool +except AttributeError as e: + if hasattr(xp, "bool_"): + xp.bool = xp.bool_ + else: + raise e + + # We monkey patch floats() to always disable subnormals as they are out-of-scope _floats = st.floats