Skip to content

Commit

Permalink
added unit tests for range function
Browse files Browse the repository at this point in the history
  • Loading branch information
Chaluvadi committed Mar 28, 2024
1 parent 8c7d809 commit 2fbaace
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions tests/test_range.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import random

import pytest

import arrayfire_wrapper.dtypes as dtypes
import arrayfire_wrapper.lib as wrapper


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10), 1),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
def test_range_shape(shape: tuple) -> None:
"""Test if the range function output an AFArray with the correct shape"""
dim = 2
dtype = dtypes.s16

result = wrapper.range(shape, dim, dtype)

assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203


def test_range_invalid_shape() -> None:
"""Test if range function correctly handles an invalid shape"""
with pytest.raises(TypeError):
shape = (
random.randint(1, 10),
random.randint(1, 10),
random.randint(1, 10),
random.randint(1, 10),
random.randint(1, 10),
)
dim = 2
dtype = dtypes.s16

wrapper.range(shape, dim, dtype)


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10), 1),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
def test_range_invalid_dim(shape: tuple) -> None:
"""Test if the range function can properly handle and invalid dimension given"""
with pytest.raises(RuntimeError):
dim = random.randint(4, 10)
dtype = dtypes.s16

wrapper.range(shape, dim, dtype)

0 comments on commit 2fbaace

Please sign in to comment.