Skip to content

Commit

Permalink
Improve test_convert_mma2mma coverage (#3115)
Browse files Browse the repository at this point in the history
When the parameter set `['mma_pair']` is empty, test case is considered
as skipped like below:
```
language/test_core.py::test_convert_mma2mma[mma_pair0-float16-64-1] SKIPPED (got empty parameter set ['mma_pair'], function test_...)
language/test_core.py::test_convert_mma2mma[mma_pair0-float16-1-64] SKIPPED (got empty parameter set ['mma_pair'], function test_...)
language/test_core.py::test_convert_mma2mma[mma_pair0-float16-64-64] SKIPPED (got empty parameter set ['mma_pair'], function test...)
language/test_core.py::test_convert_mma2mma[mma_pair0-float16-128-128] SKIPPED (got empty parameter set ['mma_pair'], function te...)
language/test_core.py::test_convert_mma2mma[mma_pair0-float16-256-256] SKIPPED (got empty parameter set ['mma_pair'], function te...)
```

Before:
```
language: passed: 11964, failed: 0, skipped: 7, xfailed: 547, total: 12518, fixme: 0, pass rate (w/o xfailed): 99.94%
all: passed: 18664, failed: 0, skipped: 28, xfailed: 1309, total: 20001, fixme: 48, pass rate (w/o xfailed): 99.85%
```
After:
```
language: passed: 11969, failed: 0, skipped: 2, xfailed: 547, total: 12518, fixme: 0, pass rate (w/o xfailed): 99.98%
all: passed: 18669, failed: 0, skipped: 23, xfailed: 1309, total: 20001, fixme: 48, pass rate (w/o xfailed): 99.88%
```

Signed-off-by: Whitney Tsang <whitney.tsang@intel.com>
  • Loading branch information
whitneywhtsang authored Jan 8, 2025
1 parent b9da9cc commit 66795f8
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ def filter_layouts(layouts):
return [l for l in layouts if is_layout_applicable(l)]


def filter_layout_pairs(pairs):
return [p for p in pairs if is_layout_applicable(p[0]) and is_layout_applicable(p[1])]


@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"])
def test_empty_kernel(dtype_x, device):
Expand Down Expand Up @@ -5770,12 +5774,18 @@ def test_local_load_store_mma(M, N, mma_layout, shared_layout, device, tmp_path:
MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 64, 16]),
MmaLayout((3, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 128, 16]),
],
[
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=1, threads_per_warp=32,
warps_per_cta=[4, 1], rep_cluster=[1, 1]),
DpasLayout(repeatCount=8, systolic_depth=8, execution_size=8, ops_per_chan=2, threads_per_warp=32,
warps_per_cta=[2, 2], rep_cluster=[1, 1]),
],
]


@pytest.mark.parametrize("M, N", [[64, 1], [1, 64], [64, 64], [128, 128], [256, 256]])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("mma_pair", filter_layouts(mma_pairs))
@pytest.mark.parametrize("mma_pair", filter_layout_pairs(mma_pairs))
def test_convert_mma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path):
src_layout, _ = mma_pair
num_warps = np.prod(src_layout.warps_per_cta)
Expand Down

0 comments on commit 66795f8

Please sign in to comment.