Skip to content

Commit

Permalink
readability changes pt.2
Browse files Browse the repository at this point in the history
  • Loading branch information
Chaluvadi committed Mar 12, 2024
1 parent 1f8d475 commit 527b807
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions tests/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)


types = [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64]
all_types = [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64]


@pytest.mark.parametrize(
Expand Down Expand Up @@ -66,7 +66,6 @@ def test_constant_complex_shape(shape: tuple) -> None:
"""Test if constant_complex creates an array with the correct shape."""
dtype = c32

dtype = c32
rand_array = wrapper.randu((1, 1), dtype)
number = wrapper.get_scalar(rand_array, dtype)

Expand Down Expand Up @@ -167,11 +166,11 @@ def test_constant_ulong_shape_invalid() -> None:

@pytest.mark.parametrize(
"dtype",
types,
all_types,
)
def test_constant_dtype(dtype: Dtype) -> None:
"""Test if constant creates an array with the correct dtype."""
if dtype in [c32, c64] or (dtype == f64 and not wrapper.get_dbl_support()):
if is_cmplx_type(dtype) or not is_system_supported(dtype):
pytest.skip()

rand_array = wrapper.randu((1, 1), dtype)
Expand All @@ -186,11 +185,11 @@ def test_constant_dtype(dtype: Dtype) -> None:

@pytest.mark.parametrize(
"dtype",
types,
all_types,
)
def test_constant_complex_dtype(dtype: Dtype) -> None:
"""Test if constant_complex creates an array with the correct dtype."""
if dtype not in [c32, c64] or (dtype == c64 and not wrapper.get_dbl_support()):
if not is_cmplx_type(dtype) or not is_system_supported(dtype):
pytest.skip()

rand_array = wrapper.randu((1, 1), dtype)
Expand Down Expand Up @@ -234,3 +233,14 @@ def test_constant_ulong_dtype() -> None:
assert c_api_value_to_dtype(wrapper.get_type(result)) == dtype
else:
pytest.skip()


def is_cmplx_type(dtype: Dtype) -> bool:
return dtype == c32 or dtype == c64


def is_system_supported(dtype: Dtype) -> bool:
if dtype in [f64, c64] and not wrapper.get_dbl_support():
return False

return True

0 comments on commit 527b807

Please sign in to comment.