diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index f33c7f456..14027a164 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -672,6 +672,8 @@ def test_backward(
     if op_bw == fmha.ck.BwOp:
         op_fw = fmha.ck.FwOp
         if dtype == torch.bfloat16:
+            ## bfloat16 testing can be enabled by exporting ENABLE_HIP_FMHA_RTN_BF16_CONVERT=1 when
+            ## building xformers, which produces accurate results
            pytest.skip(
                "CK Fmha backward for bfloat16 currently is not very accurate for some cases!"
            )