Skip to content

Commit

Permalink
[ROCM] adjust test_flash_attn_rocm test tolerance (#21379)
Browse files Browse the repository at this point in the history
The test_flash_attn_rocm.py from
#21032 failed frequently.
For example, I saw two failed jobs today:
E           Max absolute difference: 0.002167
E           Max absolute difference: 0.002686

Adjust the abs threshold from 0.002 to 0.005, and use default relative tolerance rtol=0.001.
  • Loading branch information
tianleiwu authored Jul 17, 2024
1 parent fa28704 commit 0f4c39e
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions onnxruntime/test/python/transformers/test_flash_attn_rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
rtol=0.002,
atol=0.002,
rtol=0.001,
atol=0.005,
)
parity_check_gqa_prompt_no_buff(
config,
Expand All @@ -45,8 +45,8 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
rtol=0.002,
atol=0.002,
rtol=0.001,
atol=0.005,
)

@parameterized.expand(gqa_past_flash_attention_test_cases())
Expand All @@ -67,8 +67,8 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
rtol=0.002,
atol=0.002,
rtol=0.001,
atol=0.005,
)
parity_check_gqa_past_no_buff(
config,
Expand All @@ -77,8 +77,8 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
rtol=0.002,
atol=0.002,
rtol=0.001,
atol=0.005,
)


Expand Down

0 comments on commit 0f4c39e

Please sign in to comment.