Commit

use skip decorator for cutlass_pt (facebookresearch#1126)
lvaleriu authored Jun 7, 2024
1 parent 8c1d0bd commit 7d111fd
Showing 1 changed file with 5 additions and 15 deletions.
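For reference, the commit replaces per-test hasattr(...) checks followed by pytest.skip(...) with a single module-level pytest.mark.skipif marker applied as a decorator (see the diff below). A minimal, self-contained sketch of the pattern — the boolean stand-in for fmha.cutlass.USE_TORCH_CUTLASS and the test name are hypothetical, only the marker itself is taken from the diff:

    import pytest

    # Stand-in for fmha.cutlass.USE_TORCH_CUTLASS; in the real suite the flag
    # comes from xformers itself.
    USE_TORCH_CUTLASS = True

    # The condition is evaluated once at collection time, and pytest reports
    # the skip with the given reason instead of running the test body.
    skip_if_pt_cutlass = pytest.mark.skipif(
        USE_TORCH_CUTLASS, reason="using PT cutlass"
    )

    @skip_if_pt_cutlass
    def test_requires_xformers_cutlass_ops() -> None:
        # Would only run when xformers ships its own CUTLASS kernels.
        assert True

With the marker, the skip condition lives in one place and shows up in pytest's skip report, instead of being repeated inside each test body.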
20 changes: 5 additions & 15 deletions tests/test_mem_eff_attention.py
@@ -48,6 +48,9 @@
 skip_if_rocm = pytest.mark.skipif(
     torch.version.hip is not None, reason="not supported on ROCm"
 )
+skip_if_pt_cutlass = pytest.mark.skipif(
+    fmha.cutlass.USE_TORCH_CUTLASS, reason="using PT cutlass"
+)
 _devices = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
 
 T = TypeVar(
@@ -1933,6 +1936,7 @@ def test_permuted_attn_bias(self) -> None:
 
 @cuda_only
 @disable_on_rocm
+@skip_if_pt_cutlass
 @pytest.mark.parametrize("dtype_str", ["f32", "f16", "bf16"])
 @pytest.mark.parametrize(
     "sm_shmem",
@@ -1945,16 +1949,6 @@ def test_has_kernel_for(sm_shmem: Tuple[int, int], dtype_str: str) -> None:
     if sm < 80 and dtype_str == "bf16":
         return
 
-    if not hasattr(torch.ops.xformers, "_has_cutlassF_kernel_for"):
-        pytest.skip(
-            "xformers doesnt have any _has_cutlassF_kernel_for implementation since it uses torch CUTLASS"
-        )
-
-    if not hasattr(torch.ops.xformers, "_has_cutlassB_kernel_for"):
-        pytest.skip(
-            "xformers doesnt have any _has_cutlassB_kernel_for implementation since it uses torch CUTLASS"
-        )
-
     for k in [16, 32, 64, 128, 256]:
         assert torch.ops.xformers._has_cutlassF_kernel_for(
             dtype, sm, shmem_kbytes * 1024, k
@@ -2354,6 +2348,7 @@ def test_local_attn_bias() -> None:
 
 @cuda_only
 @disable_on_rocm
+@skip_if_pt_cutlass
 @pytest.mark.parametrize("cc", [60, 70, 80])
 @pytest.mark.parametrize("maxK", [32, 64, 128, 256])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
@@ -2405,11 +2400,6 @@ def test_cutlassB_iter_order(
     .. and we test this across variable causal masks+local attention combinations
     """
 
-    if not hasattr(torch.ops.xformers, "_cutlassB_iteration_data"):
-        pytest.skip(
-            "xformers doesnt have any _cutlassB_iteration_data implementation since it uses torch CUTLASS"
-        )
-
     if (
         window_size > 0
         and custom_mask_type == fmha.cutlass._CustomMaskType.NoCustomMask
