Fixed a few of the attention test failures (#634)
Browse files Browse the repository at this point in the history
PR fixes #598

---------

Signed-off-by: Whitney Tsang <whitney.tsang@intel.com>
Co-authored-by: Ettore Tiotto <ettore.tiotto@intel.com>
Co-authored-by: Whitney Tsang <whitney.tsang@intel.com>
Co-authored-by: Pavel Chekin <pavel.chekin@intel.com>
4 people authored Mar 11, 2024
1 parent 0f97829 commit dc0dcdb
Showing 3 changed files with 14 additions and 13 deletions.
9 changes: 1 addition & 8 deletions .github/workflows/build_and_test.yml
@@ -168,14 +168,7 @@ jobs:
run: |
cd python/test/unit
python3 -m pytest --junitxml=~/reports/interpreter_core.xml -vvv -n 4 -m interpreter language/test_core.py --device cpu
# TODO: merge the two interpreter tests with env TRITON_INTERPRET=1 and device cpu
- name: Run interpreter tests (test_flash_attention.py)
# env:
# TRITON_INTERPRET: "1"
run: |
cd python/test/unit
python3 -m pytest --junitxml=~/reports/flash_attention.xml -n 8 -m interpreter -vvv -s operators/test_flash_attention.py::test_op --device xpu
python3 -m pytest --junitxml=~/reports/interpreter_flash_attention.xml -n 8 -m interpreter -vvv -s operators/test_flash_attention.py::test_op --device cpu
- name: Run partial operators tests
run: |
16 changes: 12 additions & 4 deletions python/test/unit/operators/test_flash_attention.py
@@ -18,15 +18,23 @@
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('seq_par', [True, False])
def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par, device):
    if D_HEAD != 16:
        pytest.skip("FIXME: Enable larger problem sizes when tl.dot uses DPAS")
    if torch.xpu.is_available():
        if D_HEAD != 16:
            pytest.skip("FIXME: Enable larger problem sizes when tl.dot uses DPAS")

    if torch.cuda.is_available():
        capability = torch.cuda.get_device_capability()
        if capability[0] < 8:
            pytest.skip("Flash attention only supported for compute capability >= 80")
    if dtype == torch.bfloat16 and os.environ.get("TRITON_INTERPRET", "0") == "1":
        pytest.skip("Flash attention bfloat16 not supported in interpreter mode")

    # Pytorch does not support Half data type for matmul operation hence the skip
    if device == 'cpu':
        if dtype == torch.float16 and os.environ.get("TRITON_INTERPRET", "0") == "1":
            pytest.skip("FIXME: Half is not implemented in Pytorch")

    if dtype == torch.bfloat16 and os.environ.get("TRITON_INTERPRET", "0") == "1":
        pytest.xfail("Flash attention bfloat16 not supported in interpreter mode")

    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=device).normal_(mean=0., std=0.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device=device).normal_(mean=0., std=0.5).requires_grad_()
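The updated gating above skips configurations a backend cannot run at all (the XPU head-size limit, float16 on CPU under the interpreter) and marks known-unsupported combinations as expected failures with pytest.xfail, so they stay visible in reports instead of disappearing silently. Below is a minimal standalone sketch of the same pattern; the test name and the placeholder matmul body are illustrative assumptions, not code from this commit.

# Illustrative sketch of the skip/xfail gating pattern (hypothetical test,
# placeholder matmul body -- not part of this commit).
import os

import pytest
import torch


@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('device', ['cpu', 'cuda', 'xpu'])
def test_gating_sketch(dtype, device):
    interpret = os.environ.get("TRITON_INTERPRET", "0") == "1"

    # skip: the configuration cannot run on this backend at all.
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    if device == 'xpu' and not (hasattr(torch, 'xpu') and torch.xpu.is_available()):
        pytest.skip("XPU not available")
    if device == 'cpu' and dtype == torch.float16 and interpret:
        pytest.skip("half-precision matmul unavailable in this configuration")

    # xfail: the configuration is expected to fail today; it is reported as
    # XFAIL rather than being hidden the way a skip would be.
    if dtype == torch.bfloat16 and interpret:
        pytest.xfail("bfloat16 not supported in interpreter mode")

    a = torch.randn(8, 8, dtype=dtype, device=device)
    b = torch.randn(8, 8, dtype=dtype, device=device)
    torch.testing.assert_close(a @ b, torch.matmul(a, b))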
2 changes: 1 addition & 1 deletion scripts/test-triton.sh
@@ -128,7 +128,7 @@ run_core_tests() {
echo "FAILED: return code $?" ; exit $?
fi

TRITON_DISABLE_LINE_INFO=1 python3 -m pytest -n 8 -m interpreter -vvv -s operators/test_flash_attention.py::test_op --device xpu
TRITON_INTERPRET=1 TRITON_DISABLE_LINE_INFO=1 python3 -m pytest -n 8 -m interpreter -vvv -s operators/test_flash_attention.py::test_op --device cpu
if [ $? -ne 0 ]; then
echo "FAILED: return code $?" ; exit $?
fi
