Skip to content

Commit

Permalink
Update _test_stacks to use updated ndindex behavior
Browse files Browse the repository at this point in the history
This requires Quansight-Labs/ndindex#155 which is not
yet released.
  • Loading branch information
asmeurer committed Feb 3, 2024
1 parent 3501116 commit 5c1aa45
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions array_api_tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def assert_equal(x, y, msg_extra=None):

def _test_stacks(f, *args, res=None, dims=2, true_val=None,
matrix_axes=(-2, -1),
res_axes=None,
assert_equal=assert_equal, **kw):
"""
Test that f(*args, **kw) maps across stacks of matrices
Expand All @@ -84,7 +85,10 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None,

# Assume the result is stacked along the last 'dims' axes of matrix_axes.
# This holds for all the functions tested in this file
res_axes = matrix_axes[::-1][:dims]
if res_axes is None:
if not isinstance(matrix_axes, tuple) and all(isinstance(x, int) for x in matrix_axes):
raise ValueError("res_axes must be specified if matrix_axes is not a tuple of integers")
res_axes = matrix_axes[::-1][:dims]

for (x_idxes, (res_idx,)) in zip(
iter_indices(*shapes, skip_axes=matrix_axes),
Expand Down Expand Up @@ -330,10 +334,12 @@ def test_matmul(x1, x2):
assert res.shape == ()
elif len(x1.shape) == 1:
assert res.shape == x2.shape[:-2] + x2.shape[-1:]
_test_stacks(_array_module.matmul, x1, x2, res=res, dims=1)
_test_stacks(_array_module.matmul, x1, x2, res=res, dims=1,
matrix_axes=[(0,), (-2, -1)], res_axes=[-1])
elif len(x2.shape) == 1:
assert res.shape == x1.shape[:-1]
_test_stacks(_array_module.matmul, x1, x2, res=res, dims=1)
_test_stacks(_array_module.matmul, x1, x2, res=res, dims=1,
matrix_axes=[(-2, -1), (0,)], res_axes=[-1])
else:
stack_shape = sh.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
assert res.shape == stack_shape + (x1.shape[-2], x2.shape[-1])
Expand Down Expand Up @@ -546,10 +552,11 @@ def test_solve(x1, x2):
# TODO: This requires an upstream fix to ndindex
# (https://github.com/Quansight-Labs/ndindex/pull/131)

# if x2.ndim == 1:
# _test_stacks(linalg.solve, x1, x2, res=res, dims=1)
# else:
# _test_stacks(linalg.solve, x1, x2, res=res, dims=2)
if x2.ndim == 1:
_test_stacks(linalg.solve, x1, x2, res=res, dims=1,
matrix_axes=[(-2, -1), (0,)], res_axes=[-1])
else:
_test_stacks(linalg.solve, x1, x2, res=res, dims=2)

@pytest.mark.xp_extension('linalg')
@given(
Expand Down

0 comments on commit 5c1aa45

Please sign in to comment.