Remove _check_large_shapes checking in fmha/ck.py
qianfengz committed Jul 11, 2024
1 parent 184b280 commit fd237e9
Showing 1 changed file with 0 additions and 18 deletions.
18 changes: 0 additions & 18 deletions xformers/ops/fmha/ck.py
@@ -101,22 +101,6 @@ def _check_bias_alignment(
"you should call `.contiguous()` on the bias"
)


def _check_large_shapes(reasons: List[str], inp: Inputs) -> None:
    """CK kernel throws "Memory access fault by GPU node-2" when B * T >= 2**20, might be some index overflow.
    To reproduce, remove this function and run benchmark_mem_eff_attention with ParlAI model shape (256, 4096, 16, 64).
    This needs further debugging, for now let's not support such shapes.
    """
    b_t_limit = 1024**2
    q_too_large = inp.query.shape[0] * inp.query.shape[1] >= b_t_limit
    k_too_large = inp.key.shape[0] * inp.key.shape[1] >= b_t_limit
    v_too_large = inp.value.shape[0] * inp.value.shape[1] >= b_t_limit
    if q_too_large or k_too_large or v_too_large:
        reasons.append(
            "Input is too large: product of first two dimensions of q/k/v must be < 2**20"
        )


class _CustomMaskType(int, Enum):
"""
(Matches CustomMaskType in C++.)
@@ -325,7 +309,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
_check_bias_alignment(reasons, d.attn_bias)
_check_large_shapes(reasons, d)
return reasons

@classmethod
@@ -416,7 +399,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
f"(shape: {tuple(attn_bias_tensor.shape)}"
f"/ expected: {expected_bias_shape})"
)
_check_large_shapes(reasons, d)

return reasons
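
Side note (not part of the commit): a minimal arithmetic sketch of why the ParlAI benchmark shape cited in the removed docstring hit the removed `B * T >= 2**20` limit. The shape and limit come from the deleted code above; the variable names here are illustrative only.

```python
# Sketch: the removed check rejected inputs whose first two dimensions
# (batch B and sequence length T) multiplied to at least 2**20.
B, T, H, K = 256, 4096, 16, 64  # ParlAI model shape from the removed docstring

b_t_limit = 1024**2              # 2**20 == 1_048_576
print(B * T)                     # 1048576 -- exactly at the limit
print(B * T >= b_t_limit)        # True, so the CK backend used to report
                                 # "Input is too large" for this shape
```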

