From 66795f89921adb9c3b2b8f31410aedfbf5cf70db Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Wed, 8 Jan 2025 16:36:17 -0500 Subject: [PATCH] Improve `test_convert_mma2mma` coverage (#3115) When the parameter set `['mma_pair']` is empty, test case is considered as skipped like below: ``` language/test_core.py::test_convert_mma2mma[mma_pair0-float16-64-1] SKIPPED (got empty parameter set ['mma_pair'], function test_...) language/test_core.py::test_convert_mma2mma[mma_pair0-float16-1-64] SKIPPED (got empty parameter set ['mma_pair'], function test_...) language/test_core.py::test_convert_mma2mma[mma_pair0-float16-64-64] SKIPPED (got empty parameter set ['mma_pair'], function test...) language/test_core.py::test_convert_mma2mma[mma_pair0-float16-128-128] SKIPPED (got empty parameter set ['mma_pair'], function te...) language/test_core.py::test_convert_mma2mma[mma_pair0-float16-256-256] SKIPPED (got empty parameter set ['mma_pair'], function te...) ``` Before: ``` language: passed: 11964, failed: 0, skipped: 7, xfailed: 547, total: 12518, fixme: 0, pass rate (w/o xfailed): 99.94% all: passed: 18664, failed: 0, skipped: 28, xfailed: 1309, total: 20001, fixme: 48, pass rate (w/o xfailed): 99.85% ``` After: ``` language: passed: 11969, failed: 0, skipped: 2, xfailed: 547, total: 12518, fixme: 0, pass rate (w/o xfailed): 99.98% all: passed: 18669, failed: 0, skipped: 23, xfailed: 1309, total: 20001, fixme: 48, pass rate (w/o xfailed): 99.88% ``` Signed-off-by: Whitney Tsang --- python/test/unit/language/test_core.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 848ace41f2..922c8880ea 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -266,6 +266,10 @@ def filter_layouts(layouts): return [l for l in layouts if is_layout_applicable(l)] +def filter_layout_pairs(pairs): + return [p for p in pairs if is_layout_applicable(p[0]) and is_layout_applicable(p[1])] + + @pytest.mark.interpreter @pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"]) def test_empty_kernel(dtype_x, device): @@ -5770,12 +5774,18 @@ def test_local_load_store_mma(M, N, mma_layout, shared_layout, device, tmp_path: MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 16]), MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]), ], + [ + DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32, + warps_per_cta=[4, 1], rep_cluster=[1, 1]), + DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=2, threads_per_warp=32, + warps_per_cta=[2, 2], rep_cluster=[1, 1]), + ], ] @pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]]) @pytest.mark.parametrize("dtype", ['float16']) -@pytest.mark.parametrize("mma_pair", filter_layouts(mma_pairs)) +@pytest.mark.parametrize("mma_pair", filter_layout_pairs(mma_pairs)) def test_convert_mma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path): src_layout, _ = mma_pair num_warps = np.prod(src_layout.warps_per_cta)