Skip to content

Commit

Permalink
[test_core/test_convert2d] Skip tests if scratch buffer exceeds avail… (
Browse files Browse the repository at this point in the history
#906)

…able shared memory
  • Loading branch information
alexbaden authored Apr 23, 2024
1 parent f32dbb7 commit 71cb08b
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4814,17 +4814,21 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
pytest.xfail("Out of bound access when maxPhase > 1")
if str(src_layout) == str(dst_layout):
pytest.xfail("Do not convert same layout")
if is_hip():
if is_hip() or is_xpu():
try:
scratch_shape = compute_scratch_buffer_shape(src_layout, dst_layout, (M, N))
except AssertionError:
if is_xpu():
# expect compute scratch buffer to not error on xpu
raise
pytest.skip("Can't compute scratch buffer size")
lds_size = 65536
shared_mem_size = triton.runtime.driver.active.utils.get_device_properties(
triton.runtime.driver.active.get_current_device())["max_shared_mem"] if is_xpu() else 65536
# consider int32 dtype in scratch buffer size,
# because it is the largest dtype used in convert_layout in this test
int32_size = 4
# skip even if scratch buffer equal to lds_size, because real scratch buffer is typically larger due to padding
if scratch_shape[0] * scratch_shape[1] * int32_size >= lds_size:
# skip even if scratch buffer equal to shared mem size, because real scratch buffer is typically larger due to padding
if scratch_shape[0] * scratch_shape[1] * int32_size >= shared_mem_size:
pytest.skip("Scratch buffer is too large")

layouts = f"""
Expand Down

0 comments on commit 71cb08b

Please sign in to comment.