From 7d111fd0b453755eeb523e0e7c04b9ecc25e1f87 Mon Sep 17 00:00:00 2001
From: Valeriu Lacatusu
Date: Fri, 7 Jun 2024 16:07:41 +0200
Subject: [PATCH] use skip decorator for cutlass_pt (#1126)

---
 tests/test_mem_eff_attention.py | 20 +++++---------------
 1 file changed, 5 insertions(+), 15 deletions(-)

diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index 7f1187108b..618e41e08f 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -48,6 +48,9 @@
 skip_if_rocm = pytest.mark.skipif(
     torch.version.hip is not None, reason="not supported on ROCm"
 )
+skip_if_pt_cutlass = pytest.mark.skipif(
+    fmha.cutlass.USE_TORCH_CUTLASS, reason="using PT cutlass"
+)
 
 _devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
 T = TypeVar(
@@ -1933,6 +1936,7 @@ def test_permuted_attn_bias(self) -> None:
 
 @cuda_only
 @disable_on_rocm
+@skip_if_pt_cutlass
 @pytest.mark.parametrize("dtype_str", ["f32", "f16", "bf16"])
 @pytest.mark.parametrize(
     "sm_shmem",
@@ -1945,16 +1949,6 @@ def test_has_kernel_for(sm_shmem: Tuple[int, int], dtype_str: str) -> None:
     if sm < 80 and dtype_str == "bf16":
         return
 
-    if hasattr(torch.ops.xformers, "_has_cutlassF_kernel_for"):
-        pytest.skip(
-            "xformers doesnt have any _has_cutlassF_kernel_for implementation since it uses torch CUTLASS"
-        )
-
-    if hasattr(torch.ops.xformers, "_has_cutlassB_kernel_for"):
-        pytest.skip(
-            "xformers doesnt have any _has_cutlassB_kernel_for implementation since it uses torch CUTLASS"
-        )
-
     for k in [16, 32, 64, 128, 256]:
         assert torch.ops.xformers._has_cutlassF_kernel_for(
             dtype, sm, shmem_kbytes * 1024, k
@@ -2354,6 +2348,7 @@ def test_local_attn_bias() -> None:
 
 @cuda_only
 @disable_on_rocm
+@skip_if_pt_cutlass
 @pytest.mark.parametrize("cc", [60, 70, 80])
 @pytest.mark.parametrize("maxK", [32, 64, 128, 256])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
@@ -2405,11 +2400,6 @@ def test_cutlassB_iter_order(
     .. and we test this across variable causal masks+local attention combinations
     """
 
-    if hasattr(torch.ops.xformers, "_cutlassB_iteration_data"):
-        pytest.skip(
-            "xformers doesnt have any _cutlassB_iteration_data implementation since it uses torch CUTLASS"
-        )
-
     if (
         window_size > 0
         and custom_mask_type == fmha.cutlass._CustomMaskType.NoCustomMask