diff --git a/.github/workflows/amd_tests.yml b/.github/workflows/amd_tests.yml index d0278f538..8625f143a 100644 --- a/.github/workflows/amd_tests.yml +++ b/.github/workflows/amd_tests.yml @@ -40,7 +40,7 @@ jobs: matrix: runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}} container: - image: rocm/pytorch:rocm6.1.3_ubuntu22.04_py3.10_pytorch_release-2.1.2 + image: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2 options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root steps: - name: Checkout @@ -56,14 +56,14 @@ jobs: - name: Build run: | python setup.py install - - name: Flash Attention qkvpacked Tests - run: | - pytest tests/test_flash_attn.py::test_flash_attn_qkvpacked - pytest tests/test_flash_attn.py::test_flash_attn_varlen_qkvpacked - - name: Flash Attention output Tests - run: | - pytest tests/test_flash_attn.py::test_flash_attn_output - pytest tests/test_flash_attn.py::test_flash_attn_varlen_output + # - name: Flash Attention qkvpacked Tests + # run: | + # pytest tests/test_flash_attn.py::test_flash_attn_qkvpacked + # pytest tests/test_flash_attn.py::test_flash_attn_varlen_qkvpacked + # - name: Flash Attention output Tests + # run: | + # pytest tests/test_flash_attn.py::test_flash_attn_output + # pytest tests/test_flash_attn.py::test_flash_attn_varlen_output - name: Flash Attention causal Tests run: | pytest tests/test_flash_attn.py::test_flash_attn_causal