
Commit

Relax the atol for test_forward and test_dropout due to the use of packed fp16_2_fp32 conversion in ck_tile
qianfengz committed Jun 20, 2024
1 parent 2655be6 commit 8e9dc32
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tests/test_mem_eff_attention.py
@@ -910,7 +910,7 @@ def test_dropout(op, q_len, kv_len, batch_size, k_len, p, seed, attn_bias):
     if dtype is torch.float:
         assert_allclose(out, ref, atol=2e-4), f"{(out - ref).abs().max()}"
     else:
-        assert_allclose(out.float(), ref, atol=2.2e-2), f"{(out - ref).abs().max()}"
+        assert_allclose(out.float(), ref, atol=2.8e-2), f"{(out - ref).abs().max()}"
 
     num_trials = 1000
     p_val_tol = 1e-6
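For context, the relaxed atol only widens the maximum elementwise deviation tolerated between the half-precision kernel output and the float32 reference; the packed fp16_2_fp32 conversion in ck_tile can shift individual elements slightly more than the previous 2.2e-2 bound allowed. Below is a minimal sketch of such an absolute-tolerance check, assuming a plain max-absolute-difference comparison; check_within_atol is a hypothetical helper, not the assert_allclose used by the test suite, which also applies a relative tolerance.

import torch

def check_within_atol(out: torch.Tensor, ref: torch.Tensor, atol: float) -> None:
    # Upcast the (possibly fp16) kernel output and compare against the fp32 reference.
    max_abs_diff = (out.float() - ref).abs().max().item()
    assert max_abs_diff <= atol, f"max abs diff {max_abs_diff} exceeds atol {atol}"

# A small perturbation on the order of fp16 rounding error passes the 2.8e-2 bound.
out = torch.randn(8, 16, dtype=torch.half)
ref = out.float() + 1e-3 * torch.randn(8, 16)
check_within_atol(out, ref, atol=2.8e-2)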
2 changes: 1 addition & 1 deletion xformers/ops/fmha/ck.py
@@ -183,7 +183,7 @@ class FwOp(AttentionFwOpBase):
 
     ERROR_ATOL: Mapping[torch.dtype, float] = {
         torch.float: 3e-4,
-        torch.half: 4e-3,
+        torch.half: 6e-3,
         torch.bfloat16: 2.8e-2,
     }
     ERROR_RTOL: Mapping[torch.dtype, float] = {
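ERROR_ATOL and ERROR_RTOL give the per-dtype tolerances that the xformers test harness applies when comparing this operator's output against the reference attention implementation, so bumping torch.half from 4e-3 to 6e-3 loosens only the fp16 comparisons. The snippet below is an illustrative sketch of how such a mapping can be consulted; FwOpSketch and atol_for are hypothetical stand-ins, not the real ck.FwOp class or the test-harness code.

from typing import Mapping

import torch

class FwOpSketch:
    # Illustrative stand-in for ck.FwOp; only the absolute-tolerance table is shown.
    ERROR_ATOL: Mapping[torch.dtype, float] = {
        torch.float: 3e-4,
        torch.half: 6e-3,  # relaxed from 4e-3 for the packed fp16_2_fp32 conversion
        torch.bfloat16: 2.8e-2,
    }

def atol_for(op, dtype: torch.dtype, default: float = 1e-5) -> float:
    # Pick the op-specific absolute tolerance for the dtype under test,
    # falling back to a tight default for dtypes the op does not list.
    return op.ERROR_ATOL.get(dtype, default)

print(atol_for(FwOpSketch, torch.half))   # 0.006
print(atol_for(FwOpSketch, torch.float))  # 0.0003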
