diff --git a/xformers/ops/fmha/ck.py b/xformers/ops/fmha/ck.py
index 39a0895533..f1d99f96dc 100644
--- a/xformers/ops/fmha/ck.py
+++ b/xformers/ops/fmha/ck.py
@@ -101,22 +101,6 @@ def _check_bias_alignment(
                 "you should call `.contiguous()` on the bias"
             )
 
-
-def _check_large_shapes(reasons: List[str], inp: Inputs) -> None:
-    """CK kernel throws "Memory access fault by GPU node-2" when B * T >= 2**20, might be some index overflow.
-    To reproduce, remove this function and run benchmark_mem_eff_attention with ParlAI model shape (256, 4096, 16, 64).
-    This needs further debugging, for now let's not support such shapes.
-    """
-    b_t_limit = 1024**2
-    q_too_large = inp.query.shape[0] * inp.query.shape[1] >= b_t_limit
-    k_too_large = inp.key.shape[0] * inp.key.shape[1] >= b_t_limit
-    v_too_large = inp.value.shape[0] * inp.value.shape[1] >= b_t_limit
-    if q_too_large or k_too_large or v_too_large:
-        reasons.append(
-            "Input is too large: product of first two dimensions of q/k/v must be < 2**20"
-        )
-
-
 class _CustomMaskType(int, Enum):
     """
     (Matches CustomMaskType in C++.)
@@ -325,7 +309,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
         check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn)
         check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn)
         _check_bias_alignment(reasons, d.attn_bias)
-        _check_large_shapes(reasons, d)
         return reasons
 
     @classmethod
@@ -416,7 +399,6 @@ def not_supported_reasons(cls, d: Inputs) -> List[str]:
                 f"(shape: {tuple(attn_bias_tensor.shape)}"
                 f"/ expected: {expected_bias_shape})"
             )
-        _check_large_shapes(reasons, d)
         return reasons
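
For reference, below is a minimal standalone sketch of the check this patch removes, rewritten outside the fmha Inputs struct. The check_large_shapes name, the loop-based formulation, and the meta-device tensors in the usage example are illustrative assumptions, not part of the patch; only the 2**20 limit and the ParlAI reproduction shape come from the removed docstring.

from typing import List

import torch

# Threshold from the removed _check_large_shapes helper: the CK kernel was
# reported to fault ("Memory access fault by GPU node-2") once B * T >= 2**20.
B_T_LIMIT = 1024**2


def check_large_shapes(reasons: List[str], query, key, value) -> None:
    # Flag any tensor whose batch * sequence-length product reaches the limit.
    for name, t in (("query", query), ("key", key), ("value", value)):
        if t.shape[0] * t.shape[1] >= B_T_LIMIT:
            reasons.append(
                f"{name} is too large: product of first two dimensions must be < 2**20"
            )


# The ParlAI benchmark shape (B=256, T=4096, H=16, K=64) trips the check,
# since 256 * 4096 == 2**20. Meta tensors carry shape only, no storage.
reasons: List[str] = []
q = torch.empty(256, 4096, 16, 64, device="meta")
check_large_shapes(reasons, q, q, q)
print(reasons)  # one reason each for query, key and value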