Skip to content

Commit

Permalink
Fix the scripts to pass flake8 checking
Browse files Browse the repository at this point in the history
  • Loading branch information
qianfengz committed Dec 21, 2024
1 parent 8d3dad1 commit 870fefc
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
4 changes: 2 additions & 2 deletions tests/test_mem_eff_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,8 +672,8 @@ def test_backward(
if op_bw == fmha.ck.BwOp:
op_fw = fmha.ck.FwOp
if dtype == torch.bfloat16:
## bfloat16 testing can be enabled by export ENABLE_HIP_FMHA_RTN_BF16_CONVERT=1 when
## building xformers and get accurate results
# bfloat16 testing can be enabled by export ENABLE_HIP_FMHA_RTN_BF16_CONVERT=1 when
# building xformers and get accurate results
pytest.skip(
"CK Fmha backward for bfloat16 currently is not very accurate for some cases!"
)
Expand Down
4 changes: 1 addition & 3 deletions xformers/ops/fmha/ck.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from dataclasses import replace
from enum import Enum
from functools import partial
from typing import Any, Iterable, List, Mapping, Optional, Set, Tuple, Union

import torch
Expand Down Expand Up @@ -38,7 +37,6 @@
Context,
Gradients,
Inputs,
_attn_bias_apply,
check_lastdim_alignment_stride1,
)

Expand Down Expand Up @@ -218,7 +216,7 @@ def apply(
assert inp.query.ndim == 5, f"query has shape {inp.query.shape}"
ctx: Optional[Context] = None

## consider for expanded 5-D inputted
# consider for expanded 5-D inputted
if inp.key.stride()[3] == 0:
assert (
inp.value.stride()[3] == 0
Expand Down

0 comments on commit 870fefc

Please sign in to comment.