
Commit

Add torch compile support for BlockDiagonalMask (fairinternal/xformers#1145)

* Add torch compile support for BlockDiagonalMask

* Update AttentionBias.to method to match torch.Tensor

* Comment from bottler

* Construct biases on the CUDA device by default; do not convert to CUDA when calling mem-eff, raise an error instead

* Update decoder.py and rope_padded

* Fix mypy

__original_commit__ = fairinternal/xformers@4d1eb10
danthe3rd authored and xFormers Bot committed Jul 2, 2024
1 parent 691b03d commit a9e2e7b
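
The sketch below (not part of the commit) illustrates the headline change: running `memory_efficient_attention` with a `BlockDiagonalMask` under `torch.compile`. The shapes, sequence lengths and the explicit flash operator are illustrative assumptions; per the changelog below, this bias under compile may require PyTorch 2.4.

```python
# Illustrative sketch, not part of the commit: shapes, sequence lengths and the
# explicit flash operator choice are assumptions. BlockDiagonalMask under
# torch.compile may require PyTorch 2.4 (see the CHANGELOG entry below).
import torch

from xformers.ops import fmha


@torch.compile
def attention(q, k, v, bias):
    return fmha.memory_efficient_attention(
        q, k, v, attn_bias=bias, op=(fmha.flash.FwOp, fmha.flash.BwOp)
    )


B, H, K = 1, 2, 64
seqlens = [128, 64, 64]  # three sequences packed along the sequence dimension
q, k, v = [
    torch.randn([B, sum(seqlens), H, K], device="cuda", dtype=torch.float16)
    for _ in range(3)
]
# Biases are now constructed on the CUDA device by default and must live on the
# same device as q/k/v; they are no longer converted implicitly.
bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(seqlens)
out = attention(q, k, v, bias)
```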
Showing 10 changed files with 276 additions and 110 deletions.
CHANGELOG.md: 10 changes (6 additions, 4 deletions)
@@ -6,12 +6,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [0.0.27] - TBD
### Added
- fMHA: PagedBlockDiagonalGappyKeysMask
- fMHA: heterogeneous queries in triton_splitk
- fMHA: `PagedBlockDiagonalGappyKeysMask`
- fMHA: heterogeneous queries in `triton_splitk`
- fMHA: support for paged attention in flash
- backwards pass for merge_attentions
- fMHA: Added `torch.compile` support for 2 biases (`LowerTriangularMask` and `LowerTriangularMaskWithTensorBias`)
- fMHA: Added backwards pass for `merge_attentions`
- fMHA: Added `torch.compile` support for 3 biases (`LowerTriangularMask`, `LowerTriangularMaskWithTensorBias` and `BlockDiagonalMask`) - some might require PyTorch 2.4
- fMHA: Added `torch.compile` support in `memory_efficient_attention` when passing the flash operator explicitly (e.g. `memory_efficient_attention(..., op=(flash.FwOp, flash.BwOp))`)
- fMHA: `memory_efficient_attention` now expects its `attn_bias` argument to be on the same device as the other input tensors. Previously, it would convert the bias to the right device.
- fMHA: `AttentionBias` subclasses are now constructed by default on the `cuda` device if available - they used to be created on the CPU device
- 2:4 sparsity: Added `xformers.ops.sp24.sparsify24_ste` for Straight Through Estimator (STE) with options to rescale the gradient differently for masked out/kept values
### Improved
- fMHA: Fixed out-of-bounds reading for Split-K triton implementation
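
A hedged sketch of the device-handling entries under "### Added" above (it assumes a machine with CUDA; the exact exception raised on a device mismatch is not part of this diff):

```python
# Sketch of the new device semantics (assumes CUDA is available; the exact error
# type raised on a device mismatch is not shown in this diff).
import torch

from xformers.ops import fmha

q = torch.randn([1, 256, 2, 64], device="cuda", dtype=torch.float16)
k, v = q.clone(), q.clone()

bias = fmha.attn_bias.LowerTriangularMask()  # now created on "cuda" by default
out = fmha.memory_efficient_attention(q, k, v, attn_bias=bias)  # same device: OK

cpu_bias = fmha.attn_bias.LowerTriangularMask().to("cpu")  # .to mirrors torch.Tensor
# fmha.memory_efficient_attention(q, k, v, attn_bias=cpu_bias)
# ^ previously the bias was moved to the inputs' device; this now raises an error.
```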
tests/test_mem_eff_attention.py: 73 changes (48 additions, 25 deletions)
@@ -7,7 +7,7 @@
import math
import random
from functools import partial
from typing import List, Optional, Sequence, Tuple, Type, TypeVar, Union
from typing import Any, List, Optional, Sequence, Tuple, Type, TypeVar, Union

import pytest
import torch
@@ -289,7 +289,7 @@ def get_bias_grad(attn_bias, clear: bool = False) -> Optional[torch.Tensor]:


def create_tensors(
    op: Type[AttentionOpBase],
    op: Optional[Type[AttentionOpBase]],
    device,
    dtype,
    attn_bias_type,
@@ -303,7 +303,7 @@ def create_tensors(
    attn_bias_requires_grad: bool = False,
    fmt: str = "BMK",
    g: int = 1,
):
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
    torch.manual_seed(B * q_len + kv_len * k + kv)

    mask_is_bottom_right = attn_bias_type is not None and issubclass(
@@ -329,7 +329,7 @@ def create_tensors(
        ),
    ):
        page_size_choices = [256, 512]
        if issubclass(op, fmha.triton_splitk.FwOp):
        if op is not None and issubclass(op, fmha.triton_splitk.FwOp):
            # TODO: enable small pages for flash attention when that's implemented
            page_size_choices.extend([64, 128])
        page_size = random.choice(page_size_choices)
@@ -394,12 +394,13 @@ def create_tensors(
    ]

    inputs = fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias)
    reasons = op.not_supported_reasons(inputs)
    if reasons:
        err_msg = f"{op.NAME}: unsupported ({'/'.join(reasons)})"
        # Ensure we free memory to avoid OOMs
        del query, key, value, attn_bias, inputs
        pytest.skip(err_msg)
    if op is not None:
        reasons = op.not_supported_reasons(inputs)
        if reasons:
            err_msg = f"{op.NAME}: unsupported ({'/'.join(reasons)})"
            # Ensure we free memory to avoid OOMs
            del query, key, value, attn_bias, inputs
            pytest.skip(err_msg)
    return query, key, value, attn_bias


@@ -1645,11 +1646,11 @@ def _test_to_copy(attn_bias: torch.Tensor) -> None:
        assert attn_bias_fp16.device.type == "cpu", f"{attn_bias_fp16.device}"
        assert attn_bias_fp16.dtype == torch.float16, f"{attn_bias_fp16.dtype}"

    attn_bias = fmha.attn_bias.LowerTriangularMask()
    attn_bias = fmha.attn_bias.LowerTriangularMask().to("cpu")
    _test_to_copy(attn_bias)

    tensor_bias = torch.tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
    attn_bias = fmha.attn_bias.LowerTriangularMaskWithTensorBias(tensor_bias)
    attn_bias = fmha.attn_bias.LowerTriangularMaskWithTensorBias(tensor_bias).to("cpu")
    _test_to_copy(attn_bias)
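
The hunk above exercises the updated `AttentionBias.to`, which now follows the `torch.Tensor` signature. A small sketch of what that allows; the `.to("cuda")` call is an assumption extrapolated from the `.to("cpu")` and dtype conversions tested here:

```python
# Sketch based on the test above: AttentionBias.to accepts devices and dtypes like
# torch.Tensor.to. The .to("cuda") call is an extrapolation, not shown in the hunk.
import torch

from xformers.ops import fmha

tensor_bias = torch.tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
attn_bias = fmha.attn_bias.LowerTriangularMaskWithTensorBias(tensor_bias).to("cpu")
attn_bias_fp16 = attn_bias.to(torch.float16)  # dtype conversion, as asserted above
attn_bias_cuda = attn_bias.to("cuda")  # device conversion (assumes CUDA is available)
```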


@@ -3414,26 +3415,45 @@ def _merge_attentions_ref(attn_split, lse_split):

@sm80_or_better_only
@skip_if_rocm # rocm doesn't support backward yet
@pytest.mark.parametrize("bias_t", [None, fmha.attn_bias.LowerTriangularMask])
@pytest.mark.parametrize(
"bias_t",
[None, fmha.attn_bias.LowerTriangularMask, fmha.attn_bias.BlockDiagonalMask],
)
@pytest.mark.parametrize("create_bias_inside_compiled", [False, True])
@pytest.mark.parametrize("op", [None, (fmha.flash.FwOp, fmha.flash.BwOp)])
@pytest.mark.parametrize(
"op",
[None, (fmha.flash.FwOp, fmha.flash.BwOp), (fmha.cutlass.FwOp, fmha.flash.BwOp)],
)
def test_memeff_compile(bias_t, create_bias_inside_compiled: bool, op) -> None:
    torch.manual_seed(0)
    dtype = torch.float16
    torch._dynamo.reset_code_caches()  # avoids hitting recompilation limit
    B, M, H, K = 1, 256, 2, 64
    q, k, v = [
        3 * torch.randn([B, M, H, K], device="cuda", dtype=dtype) for _ in range(3)
    ]
    q, k, v, bias = create_tensors(
        op if op is None else op[0],
        "cuda",
        torch.float16,
        bias_t,
        B,
        M,
        M,
        H,
        K,
        K,
        fmt="BMHK",
    )
    grad = torch.randn_like(q)
    bias = None
    if not create_bias_inside_compiled and bias_t is not None:
        bias = bias_t()
    if create_bias_inside_compiled:
        bias = None
        if bias_t not in [None, fmha.attn_bias.LowerTriangularMask]:
            pytest.skip("Can't create this mask inside compile")
    if bias is not None:
        bias.to(q.device)
    q.requires_grad_(True)
    k.requires_grad_(True)
    v.requires_grad_(True)

    def fmha_fn(q, k, v, bias):
        if bias is None and bias_t is not None:
        if create_bias_inside_compiled and bias_t is not None:
            bias = bias_t()
        return fmha.memory_efficient_attention(q, k, v, attn_bias=bias, op=op)

@@ -3452,10 +3472,13 @@ def fmha_fn(q, k, v, bias):
        out,
        out_ref,
        "out",
        atol=fmha.flash.FwOp.ERROR_ATOL[dtype],
        rtol=fmha.flash.FwOp.ERROR_RTOL[dtype],
        atol=fmha.flash.FwOp.ERROR_ATOL[q.dtype],
        rtol=fmha.flash.FwOp.ERROR_RTOL[q.dtype],
    )
    atol, rtol = (
        fmha.flash.BwOp.ERROR_ATOL[q.dtype],
        fmha.flash.BwOp.ERROR_RTOL[q.dtype],
    )
    atol, rtol = fmha.flash.BwOp.ERROR_ATOL[dtype], fmha.flash.BwOp.ERROR_RTOL[dtype]
    assert_allclose(q.grad, dq_ref, "dq", atol=atol, rtol=rtol)
    assert_allclose(k.grad, dk_ref, "dk", atol=atol, rtol=rtol)
    assert_allclose(v.grad, dv_ref, "dv", atol=atol, rtol=rtol)
xformers/attn_bias_utils.py: 4 changes (2 additions, 2 deletions)
@@ -39,7 +39,7 @@ def create_attn_bias(
    dtype,
    requires_grad: bool,
    fmt: str,
    op: Type[AttentionOpBase],
    op: Optional[Type[AttentionOpBase]] = None,
    page_size: Optional[int] = None,
):
    if bias_type is None or isinstance(None, bias_type):
@@ -59,7 +59,7 @@
                * 3
            )
            attn_bias = attn_bias.expand(batch_size, num_heads, q_len, kv_len)
        elif issubclass(op, fmha.triton_splitk.FwOp):
        elif op is not None and issubclass(op, fmha.triton_splitk.FwOp):
            attn_bias = (
                torch.randn(
                    (batch_size, num_heads_groups, num_heads, q_len, kv_len),
(Diffs for the remaining 7 changed files are not shown.)
