diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index c36c195769..e2aa23873f 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -89,12 +89,20 @@ def test_matmul_simple(dtype): skip_if_dtype_not_supported(dtype, q) n, m = 235, 17 - m1 = dpt.ones((m, n), dtype=dtype) - m2 = dpt.ones((n, m), dtype=dtype) + m1 = dpt.zeros((m, n), dtype=dtype) + m2 = dpt.zeros((n, m), dtype=dtype) + + dt = m1.dtype + if dt.kind in "ui": + n1 = min(n, dpt.iinfo(dt).max) + else: + n1 = n + m1[:, :n1] = dpt.ones((m, n1), dtype=dt) + m2[:n1, :] = dpt.ones((n1, m), dtype=dt) for k in [1, 2, 3, 4, 7, 8, 9, 15, 16, 17]: r = dpt.matmul(m1[:k, :], m2[:, :k]) - assert dpt.all(r == dpt.full((k, k), n, dtype=dtype)) + assert dpt.all(r == dpt.full((k, k), fill_value=n1, dtype=dt)) @pytest.mark.parametrize("dtype", _numeric_types)