Skip to content

Commit

Permalink
fixed import formatting, black and flake8 checks
Browse files Browse the repository at this point in the history
  • Loading branch information
Chaluvadi authored and roaffix committed Mar 2, 2024
1 parent 886f411 commit 21f72a2
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions tests/test_upper.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import pytest

import arrayfire_wrapper.dtypes as dtypes
from arrayfire_wrapper.lib.create_and_modify_array.create_array.constant import constant
from arrayfire_wrapper.lib.create_and_modify_array.create_array.diag import diag_extract
from arrayfire_wrapper.lib.create_and_modify_array.create_array.upper import upper
from arrayfire_wrapper.lib.create_and_modify_array.manage_array import get_scalar
import arrayfire_wrapper.lib as wrapper


@pytest.mark.parametrize(
Expand All @@ -18,11 +15,11 @@
def test_diag_is_unit(shape: tuple) -> None:
"""Test if when is_unit_diag in lower returns an array with a unit diagonal"""
dtype = dtypes.s64
constant_array = constant(3, shape, dtype)
constant_array = wrapper.constant(3, shape, dtype)

lower_array = upper(constant_array, True)
diagonal = diag_extract(lower_array, 0)
diagonal_value = get_scalar(diagonal, dtype)
lower_array = wrapper.upper(constant_array, True)
diagonal = wrapper.diag_extract(lower_array, 0)
diagonal_value = wrapper.get_scalar(diagonal, dtype)

assert diagonal_value == 1

Expand All @@ -38,11 +35,11 @@ def test_diag_is_unit(shape: tuple) -> None:
def test_is_original(shape: tuple) -> None:
"""Test if is_original keeps the diagonal the same as the original array"""
dtype = dtypes.s64
constant_array = constant(3, shape, dtype)
original_value = get_scalar(constant_array, dtype)
constant_array = wrapper.constant(3, shape, dtype)
original_value = wrapper.get_scalar(constant_array, dtype)

lower_array = upper(constant_array, False)
diagonal = diag_extract(lower_array, 0)
diagonal_value = get_scalar(diagonal, dtype)
lower_array = wrapper.upper(constant_array, False)
diagonal = wrapper.diag_extract(lower_array, 0)
diagonal_value = wrapper.get_scalar(diagonal, dtype)

assert original_value == diagonal_value

0 comments on commit 21f72a2

Please sign in to comment.