From 8e9dc325e8726aeeac153b5593d4d1fede88aa7c Mon Sep 17 00:00:00 2001
From: Qianfeng Zhang
Date: Thu, 20 Jun 2024 14:56:11 +0000
Subject: [PATCH] Relax the atol for test_forward and test_dropout due to the
 use of packed fp16_2_fp32 conversion in ck_tile

---
 tests/test_mem_eff_attention.py | 2 +-
 xformers/ops/fmha/ck.py         | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index 1b022b4aee..0e58c742e1 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -910,7 +910,7 @@ def test_dropout(op, q_len, kv_len, batch_size, k_len, p, seed, attn_bias):
     if dtype is torch.float:
         assert_allclose(out, ref, atol=2e-4), f"{(out - ref).abs().max()}"
     else:
-        assert_allclose(out.float(), ref, atol=2.2e-2), f"{(out - ref).abs().max()}"
+        assert_allclose(out.float(), ref, atol=2.8e-2), f"{(out - ref).abs().max()}"
 
     num_trials = 1000
     p_val_tol = 1e-6
diff --git a/xformers/ops/fmha/ck.py b/xformers/ops/fmha/ck.py
index 5d94ff5a23..39a0895533 100644
--- a/xformers/ops/fmha/ck.py
+++ b/xformers/ops/fmha/ck.py
@@ -183,7 +183,7 @@ class FwOp(AttentionFwOpBase):
 
     ERROR_ATOL: Mapping[torch.dtype, float] = {
         torch.float: 3e-4,
-        torch.half: 4e-3,
+        torch.half: 6e-3,
         torch.bfloat16: 2.8e-2,
     }
     ERROR_RTOL: Mapping[torch.dtype, float] = {
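
For reference, a minimal sketch of the kind of tolerance check these changes relax, assuming torch.testing.assert_close as a stand-in for the test suite's assert_allclose helper; the tensors and the check_fp16_output name below are hypothetical and for illustration only, not part of the patch:

    import torch

    def check_fp16_output(out: torch.Tensor, ref: torch.Tensor,
                          atol: float = 2.8e-2, rtol: float = 2e-2) -> None:
        # Compare a half-precision attention output against a float32 reference.
        # The packed fp16_2_fp32 conversion in ck_tile can shift half-precision
        # results slightly, hence the looser absolute tolerance for torch.half.
        torch.testing.assert_close(out.float(), ref, atol=atol, rtol=rtol)

    # Hypothetical usage: emulate precision loss from a half-precision path.
    ref = torch.randn(4, 8, dtype=torch.float32)
    out = ref.to(torch.half)
    check_fp16_output(out, ref)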