Skip to content

Commit

Permalink
Fix test_matmul_simple to avoid use of out-of-bounds Python scalars
Browse files Browse the repository at this point in the history
Change the test so that input matrices that get multiplied only have
blocks of ones no larger than the max integer for the type, rest is
populated with zeros. This change applies to integral types only.
  • Loading branch information
oleksandr-pavlyk committed Aug 20, 2024
1 parent 2fe2791 commit 16e6c1e
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions dpctl/tests/test_usm_ndarray_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,20 @@ def test_matmul_simple(dtype):
skip_if_dtype_not_supported(dtype, q)

n, m = 235, 17
m1 = dpt.ones((m, n), dtype=dtype)
m2 = dpt.ones((n, m), dtype=dtype)
m1 = dpt.zeros((m, n), dtype=dtype)
m2 = dpt.zeros((n, m), dtype=dtype)

dt = m1.dtype
if dt.kind in "ui":
n1 = min(n, dpt.iinfo(dt).max)
else:
n1 = n
m1[:, :n1] = dpt.ones((m, n1), dtype=dt)
m2[:n1, :] = dpt.ones((n1, m), dtype=dt)

for k in [1, 2, 3, 4, 7, 8, 9, 15, 16, 17]:
r = dpt.matmul(m1[:k, :], m2[:, :k])
assert dpt.all(r == dpt.full((k, k), n, dtype=dtype))
assert dpt.all(r == dpt.full((k, k), fill_value=n1, dtype=dt))


@pytest.mark.parametrize("dtype", _numeric_types)
Expand Down

0 comments on commit 16e6c1e

Please sign in to comment.