From a9e2e7b331cc08552835ca0acbc0e15f69b42b30 Mon Sep 17 00:00:00 2001 From: dan_the_3rd <43445237+danthe3rd@users.noreply.github.com> Date: Tue, 2 Jul 2024 15:39:30 +0000 Subject: [PATCH] Add torch compile support for BlockDiagonalMask (fairinternal/xformers#1145) * Add torch compile support for BlockDiagonalMask * Update AttentionBias.to method to match torch.Tensor * Comment from bottler * Construct biases by default on CUDA device, do not convert to CUDA when calling mem-eff - error instead * Update decoder.py and rope_padded * Fix mypy __original_commit__ = fairinternal/xformers@4d1eb107ce11d215921f074f5b28b3a5c46157c5 --- CHANGELOG.md | 10 +- tests/test_mem_eff_attention.py | 73 ++++++++---- xformers/attn_bias_utils.py | 4 +- xformers/ops/fmha/attn_bias.py | 184 +++++++++++++++++++++++------ xformers/ops/fmha/common.py | 22 ++++ xformers/ops/fmha/cutlass.py | 3 +- xformers/ops/fmha/decoder.py | 4 +- xformers/ops/fmha/flash.py | 31 ++--- xformers/ops/fmha/triton_splitk.py | 46 +++++--- xformers/ops/rope_padded.py | 9 +- 10 files changed, 276 insertions(+), 110 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b6ba704b42..141c475de2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,12 +6,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.0.27] - TBD ### Added -- fMHA: PagedBlockDiagonalGappyKeysMask -- fMHA: heterogeneous queries in triton_splitk +- fMHA: `PagedBlockDiagonalGappyKeysMask` +- fMHA: heterogeneous queries in `triton_splitk` - fMHA: support for paged attention in flash -- backwards pass for merge_attentions -- fMHA: Added `torch.compile` support for 2 biases (`LowerTriangularMask` and `LowerTriangularMaskWithTensorBias`) +- fMHA: Added backwards pass for `merge_attentions` +- fMHA: Added `torch.compile` support for 3 biases (`LowerTriangularMask`, `LowerTriangularMaskWithTensorBias` and `BlockDiagonalMask`) - some might require PyTorch 2.4 - fMHA: Added `torch.compile` support in `memory_efficient_attention` when passing the flash operator explicitely (eg `memory_efficient_attention(..., op=(flash.FwOp, flash.BwOp))`) +- fMHA: `memory_efficient_attention` now expects its `attn_bias` argument to be on the same device as the other input tensor. Previously, it would convert the bias to the right device. 
+- fMHA: `AttentionBias` subclasses are now constructed by default on the `cuda` device if available - they used to be created on the CPU device - 2:4 sparsity: Added `xformers.ops.sp24.sparsify24_ste` for Straight Through Estimator (STE) with options to rescale the gradient differently for masked out/kept values ### Improved - fMHA: Fixed out-of-bounds reading for Split-K triton implementation diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py index 918fdacdee..dce31201e1 100644 --- a/tests/test_mem_eff_attention.py +++ b/tests/test_mem_eff_attention.py @@ -7,7 +7,7 @@ import math import random from functools import partial -from typing import List, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import Any, List, Optional, Sequence, Tuple, Type, TypeVar, Union import pytest import torch @@ -289,7 +289,7 @@ def get_bias_grad(attn_bias, clear: bool = False) -> Optional[torch.Tensor]: def create_tensors( - op: Type[AttentionOpBase], + op: Optional[Type[AttentionOpBase]], device, dtype, attn_bias_type, @@ -303,7 +303,7 @@ def create_tensors( attn_bias_requires_grad: bool = False, fmt: str = "BMK", g: int = 1, -): +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]: torch.manual_seed(B * q_len + kv_len * k + kv) mask_is_bottom_right = attn_bias_type is not None and issubclass( @@ -329,7 +329,7 @@ def create_tensors( ), ): page_size_choices = [256, 512] - if issubclass(op, fmha.triton_splitk.FwOp): + if op is not None and issubclass(op, fmha.triton_splitk.FwOp): # TODO: enable small pages for flash attention when that's implemented page_size_choices.extend([64, 128]) page_size = random.choice(page_size_choices) @@ -394,12 +394,13 @@ def create_tensors( ] inputs = fmha.Inputs(query=query, key=key, value=value, attn_bias=attn_bias) - reasons = op.not_supported_reasons(inputs) - if reasons: - err_msg = f"{op.NAME}: unsupported ({'/'.join(reasons)})" - # Ensure we free memory to avoid OOMs - del query, key, value, attn_bias, inputs - pytest.skip(err_msg) + if op is not None: + reasons = op.not_supported_reasons(inputs) + if reasons: + err_msg = f"{op.NAME}: unsupported ({'/'.join(reasons)})" + # Ensure we free memory to avoid OOMs + del query, key, value, attn_bias, inputs + pytest.skip(err_msg) return query, key, value, attn_bias @@ -1645,11 +1646,11 @@ def _test_to_copy(attn_bias: torch.Tensor) -> None: assert attn_bias_fp16.device.type == "cpu", f"{attn_bias_fp16.device}" assert attn_bias_fp16.dtype == torch.float16, f"{attn_bias_fp16.dtype}" - attn_bias = fmha.attn_bias.LowerTriangularMask() + attn_bias = fmha.attn_bias.LowerTriangularMask().to("cpu") _test_to_copy(attn_bias) tensor_bias = torch.tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]) - attn_bias = fmha.attn_bias.LowerTriangularMaskWithTensorBias(tensor_bias) + attn_bias = fmha.attn_bias.LowerTriangularMaskWithTensorBias(tensor_bias).to("cpu") _test_to_copy(attn_bias) @@ -3414,26 +3415,45 @@ def _merge_attentions_ref(attn_split, lse_split): @sm80_or_better_only @skip_if_rocm # rocm doesn't support backward yet -@pytest.mark.parametrize("bias_t", [None, fmha.attn_bias.LowerTriangularMask]) +@pytest.mark.parametrize( + "bias_t", + [None, fmha.attn_bias.LowerTriangularMask, fmha.attn_bias.BlockDiagonalMask], +) @pytest.mark.parametrize("create_bias_inside_compiled", [False, True]) -@pytest.mark.parametrize("op", [None, (fmha.flash.FwOp, fmha.flash.BwOp)]) +@pytest.mark.parametrize( + "op", + [None, (fmha.flash.FwOp, fmha.flash.BwOp), (fmha.cutlass.FwOp, fmha.flash.BwOp)], +) def 
test_memeff_compile(bias_t, create_bias_inside_compiled: bool, op) -> None: torch.manual_seed(0) - dtype = torch.float16 + torch._dynamo.reset_code_caches() # avoids hitting recompilation limit B, M, H, K = 1, 256, 2, 64 - q, k, v = [ - 3 * torch.randn([B, M, H, K], device="cuda", dtype=dtype) for _ in range(3) - ] + q, k, v, bias = create_tensors( + op if op is None else op[0], + "cuda", + torch.float16, + bias_t, + B, + M, + M, + H, + K, + K, + fmt="BMHK", + ) grad = torch.randn_like(q) - bias = None - if not create_bias_inside_compiled and bias_t is not None: - bias = bias_t() + if create_bias_inside_compiled: + bias = None + if bias_t not in [None, fmha.attn_bias.LowerTriangularMask]: + pytest.skip("Can't create this mask inside compile") + if bias is not None: + bias.to(q.device) q.requires_grad_(True) k.requires_grad_(True) v.requires_grad_(True) def fmha_fn(q, k, v, bias): - if bias is None and bias_t is not None: + if create_bias_inside_compiled and bias_t is not None: bias = bias_t() return fmha.memory_efficient_attention(q, k, v, attn_bias=bias, op=op) @@ -3452,10 +3472,13 @@ def fmha_fn(q, k, v, bias): out, out_ref, "out", - atol=fmha.flash.FwOp.ERROR_ATOL[dtype], - rtol=fmha.flash.FwOp.ERROR_RTOL[dtype], + atol=fmha.flash.FwOp.ERROR_ATOL[q.dtype], + rtol=fmha.flash.FwOp.ERROR_RTOL[q.dtype], + ) + atol, rtol = ( + fmha.flash.BwOp.ERROR_ATOL[q.dtype], + fmha.flash.BwOp.ERROR_RTOL[q.dtype], ) - atol, rtol = fmha.flash.BwOp.ERROR_ATOL[dtype], fmha.flash.BwOp.ERROR_RTOL[dtype] assert_allclose(q.grad, dq_ref, "dq", atol=atol, rtol=rtol) assert_allclose(k.grad, dk_ref, "dk", atol=atol, rtol=rtol) assert_allclose(v.grad, dv_ref, "dv", atol=atol, rtol=rtol) diff --git a/xformers/attn_bias_utils.py b/xformers/attn_bias_utils.py index 224302c4f8..fb8d8207f2 100644 --- a/xformers/attn_bias_utils.py +++ b/xformers/attn_bias_utils.py @@ -39,7 +39,7 @@ def create_attn_bias( dtype, requires_grad: bool, fmt: str, - op: Type[AttentionOpBase], + op: Optional[Type[AttentionOpBase]] = None, page_size: Optional[int] = None, ): if bias_type is None or isinstance(None, bias_type): @@ -59,7 +59,7 @@ def create_attn_bias( * 3 ) attn_bias = attn_bias.expand(batch_size, num_heads, q_len, kv_len) - elif issubclass(op, fmha.triton_splitk.FwOp): + elif op is not None and issubclass(op, fmha.triton_splitk.FwOp): attn_bias = ( torch.randn( (batch_size, num_heads_groups, num_heads, q_len, kv_len), diff --git a/xformers/ops/fmha/attn_bias.py b/xformers/ops/fmha/attn_bias.py index 2f095df22e..d3c5f9487a 100644 --- a/xformers/ops/fmha/attn_bias.py +++ b/xformers/ops/fmha/attn_bias.py @@ -87,6 +87,14 @@ def materialize( raise NotImplementedError() +def _get_default_bias_device(device: Optional[torch.device] = None) -> torch.device: + if device is None: + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + return device + + def _materialize_causal_mask( shape: Tuple[int, ...], dtype: torch.dtype = torch.float32, @@ -218,6 +226,9 @@ class LowerTriangularFromBottomRightMask(AttentionBias): equivalent if the number of queries equals the number of keys. 
""" + def to(self, device: torch.device) -> "LowerTriangularFromBottomRightMask": + return self + def materialize( self, shape: Tuple[int, ...], @@ -302,15 +313,22 @@ class _SeqLenInfo: min_seqlen: int seqstart_py: List[int] - def to(self, device: torch.device) -> None: - self.seqstart = self.seqstart.to(device, non_blocking=True) + def to(self, device: torch.device) -> "_SeqLenInfo": + if self.seqstart.device == device: + return self + return _SeqLenInfo( + seqstart=self.seqstart.to(device), + max_seqlen=self.max_seqlen, + min_seqlen=self.min_seqlen, + seqstart_py=self.seqstart_py, + ) def intervals(self) -> Iterable[Tuple[int, int]]: yield from zip(self.seqstart_py, self.seqstart_py[1:]) @classmethod def _get_seqstart( - cls, seqlens: Iterable[int] + cls, seqlens: Iterable[int], *, device: torch.device ) -> Tuple[int, int, List[int], torch.Tensor]: """ Given sequence lengths, returns the min/max value and the sequence start @@ -325,16 +343,21 @@ def _get_seqstart( min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen max_seqlen = max(max_seqlen, seqlen) seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen) - seqstart = torch.tensor(seqstart_py, dtype=torch.int32) + seqstart = torch.tensor(seqstart_py, dtype=torch.int32, device=device) return (min_seqlen, max_seqlen, seqstart_py, seqstart) @classmethod - def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo": + def from_seqlens( + cls, seqlens: Iterable[int], *, device: Optional[torch.device] = None + ) -> "_SeqLenInfo": """ Input tensors are assumed to be in shape [B, M, *] """ - min_seqlen, max_seqlen, seqstart_py, seqstart = cls._get_seqstart(seqlens) + device = _get_default_bias_device(device) + min_seqlen, max_seqlen, seqstart_py, seqstart = cls._get_seqstart( + seqlens, device=device + ) return cls( max_seqlen=max_seqlen, @@ -413,23 +436,40 @@ class _PaddedSeqLenInfo(_SeqLenInfo): def __post_init__(self) -> None: assert len(self.seqstart_py) == len(self.seqlen_py) + 1 - def to(self, device: torch.device) -> None: - self.seqlen = self.seqlen.to(device, non_blocking=True) - super().to(device) + def to(self, device: torch.device) -> "_PaddedSeqLenInfo": + if self.seqlen.device == device: + return self + return _PaddedSeqLenInfo( + # _SeqLenInfo + seqstart=self.seqstart.to(device), + max_seqlen=self.max_seqlen, + min_seqlen=self.min_seqlen, + seqstart_py=self.seqstart_py, + # _PaddedSeqLenInfo + seqlen=self.seqlen.to(device), + seqlen_py=self.seqlen_py, + padding=self.padding, + ) def intervals(self) -> Iterable[Tuple[int, int]]: for (start, _), length in zip(super().intervals(), self.seqlen_py): yield start, start + length @classmethod - def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo": + def from_seqlens( + cls, seqlens: Iterable[int], *, device: Optional[torch.device] = None + ) -> "_SeqLenInfo": raise RuntimeError( "Use either `_SeqLenInfo.from_seqlens` or `_PaddedSeqLenInfo.from_seqlens_padded`" ) @classmethod def from_seqlens_padded( - cls, seqlens: Sequence[int], padding: int + cls, + seqlens: Sequence[int], + padding: int, + *, + device: Optional[torch.device] = None, ) -> "_PaddedSeqLenInfo": """ Input tensors are assumed to be in shape [B, M, *] @@ -439,14 +479,15 @@ def from_seqlens_padded( assert all( seqlen <= padding for seqlen in seqlens ), f"Seqlens {seqlens} Padding {padding}" + device = _get_default_bias_device(device) seqstart_py = list(range(0, len(seqlens) * padding + 1, padding)) - seqlen = torch.tensor(seqlens, dtype=torch.int32) + seqlen = torch.tensor(seqlens, 
dtype=torch.int32, device=device) return cls( seqlen=seqlen, seqlen_py=seqlens, max_seqlen=max(seqlens), min_seqlen=min(seqlens), - seqstart=torch.tensor(seqstart_py, dtype=torch.int32), + seqstart=torch.tensor(seqstart_py, dtype=torch.int32, device=device), seqstart_py=seqstart_py, padding=padding, ) @@ -510,21 +551,38 @@ class _GappySeqInfo(_SeqLenInfo): # of the i-th sequence # seqstart: torch.Tensor - def to(self, device: torch.device) -> None: - self.seqlen = self.seqlen.to(device, non_blocking=True) - super().to(device) + def to(self, device: torch.device) -> "_GappySeqInfo": + if self.seqlen.device == device: + return self + return _GappySeqInfo( + # _SeqLenInfo + seqstart=self.seqstart.to(device), + max_seqlen=self.max_seqlen, + min_seqlen=self.min_seqlen, + seqstart_py=self.seqstart_py, + # _GappySeqInfo + seqlen=self.seqlen.to(device), + seqlen_py=self.seqlen_py, + ) def intervals(self) -> Iterable[Tuple[int, int]]: for (start, _), length in zip(super().intervals(), self.seqlen_py): yield start, start + length @classmethod - def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo": + def from_seqlens( + cls, seqlens: Iterable[int], *, device: Optional[torch.device] = None + ) -> "_SeqLenInfo": raise NotImplementedError() @classmethod def from_seqlens_gappy( - cls, seqstarts: Sequence[int], seqlens: Sequence[int], paged: bool + cls, + seqstarts: Sequence[int], + seqlens: Sequence[int], + paged: bool, + *, + device: torch.device, ) -> "_GappySeqInfo": assert not isinstance(seqlens, torch.Tensor) seqstart_py = list(seqstarts) @@ -535,13 +593,13 @@ def from_seqlens_gappy( raise ValueError( f"len(seqstarts)={seqstarts} should be {extra}len(seqlens)={seqlens}" ) - seqlen = torch.tensor(seqlens, dtype=torch.int32) + seqlen = torch.tensor(seqlens, dtype=torch.int32, device=device) return cls( seqlen=seqlen, seqlen_py=seqlens, max_seqlen=max(seqlens), min_seqlen=min(seqlens), - seqstart=torch.tensor(seqstart_py, dtype=torch.int32), + seqstart=torch.tensor(seqstart_py, dtype=torch.int32, device=device), seqstart_py=seqstart_py, ) @@ -595,6 +653,13 @@ class BlockDiagonalMask(AttentionBias): k_seqinfo: _SeqLenInfo _batch_sizes: Optional[Sequence[int]] = None + def to(self, device) -> "BlockDiagonalMask": + return BlockDiagonalMask( + q_seqinfo=self.q_seqinfo.to(device), + k_seqinfo=self.k_seqinfo.to(device), + _batch_sizes=self._batch_sizes, + ) + def _create_block_mask( self, shape: Tuple[int, ...], @@ -644,6 +709,8 @@ def from_seqlens( cls, q_seqlen: Sequence[int], kv_seqlen: Optional[Sequence[int]] = None, + *, + device: Optional[torch.device] = None, ) -> "BlockDiagonalMask": """Creates a :attr:`BlockDiagonalMask` from a list of tensors lengths for query and key/value. 
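
The hunks above establish the pattern repeated for every sequence-length container in `attn_bias.py`: `to()` now returns a new (or the same) object instead of mutating in place, and the `from_seqlens*` constructors take a keyword-only `device` that falls back to `_get_default_bias_device()` (CUDA when available). Below is a minimal usage sketch of that API as added by this patch; the sequence lengths are arbitrary and the example assumes a machine with CUDA available.

```python
import torch
from xformers.ops import fmha

# Build the bias directly on the device where Q/K/V will live. With this
# patch, omitting `device` defaults to CUDA when available, so passing it
# explicitly is only needed to target a different device.
bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(
    q_seqlen=[3, 5, 2],            # illustrative lengths
    kv_seqlen=[7, 5, 4],
    device=torch.device("cuda"),
)

# `to()` now follows torch.Tensor semantics: it returns a new bias rather
# than mutating the existing one in place.
bias_cpu = bias.to(torch.device("cpu"))
assert bias_cpu.q_seqinfo.seqstart.device.type == "cpu"
assert bias.q_seqinfo.seqstart.device.type == "cuda"
```
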
@@ -654,12 +721,13 @@ def from_seqlens( Returns: BlockDiagonalMask """ + device = _get_default_bias_device(device) assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen) - q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device) if kv_seqlen is None or q_seqlen == kv_seqlen: k_seqinfo = q_seqinfo else: - k_seqinfo = _SeqLenInfo.from_seqlens(kv_seqlen) + k_seqinfo = _SeqLenInfo.from_seqlens(kv_seqlen, device=device) return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo) @classmethod @@ -864,6 +932,12 @@ class BlockDiagonalPaddedKeysMask(AttentionBias): q_seqinfo: _SeqLenInfo k_seqinfo: _PaddedSeqLenInfo + def to(self, device) -> "BlockDiagonalPaddedKeysMask": + return BlockDiagonalPaddedKeysMask( + q_seqinfo=self.q_seqinfo.to(device), + k_seqinfo=self.k_seqinfo.to(device), + ) + def _create_block_mask( self, shape: Tuple[int, ...], @@ -907,6 +981,8 @@ def from_seqlens( kv_padding: int, kv_seqlen: Sequence[int], causal_diagonal: Any = None, + *, + device: Optional[torch.device] = None, ) -> "BlockDiagonalPaddedKeysMask": """Creates a :attr:`BlockDiagonalPaddedKeysMask` from a list of tensor lengths for query and key/value. @@ -919,12 +995,15 @@ def from_seqlens( Returns: BlockDiagonalPaddedKeysMask """ + device = _get_default_bias_device(device) assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), ( q_seqlen, kv_seqlen, ) - q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen) - k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device) + k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded( + kv_seqlen, kv_padding, device=device + ) return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo) def make_paged( @@ -982,6 +1061,8 @@ def from_seqlens( kv_padding: int, kv_seqlen: Sequence[int], causal_diagonal: Any = None, + *, + device: Optional[torch.device] = None, ) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask": """Creates a :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask` from a list of tensor lengths for query and key/value. @@ -998,8 +1079,11 @@ def from_seqlens( q_seqlen, kv_seqlen, ) - q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen) - k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding) + device = _get_default_bias_device(device) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device) + k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded( + kv_seqlen, kv_padding, device=device + ) return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo) @@ -1021,6 +1105,14 @@ class PagedBlockDiagonalPaddedKeysMask(AttentionBias): Type[BlockDiagonalPaddedKeysMask] ] = BlockDiagonalPaddedKeysMask + def to(self, device: torch.device) -> "PagedBlockDiagonalPaddedKeysMask": + return PagedBlockDiagonalPaddedKeysMask( + q_seqinfo=self.q_seqinfo.to(device), + k_seqinfo=self.k_seqinfo.to(device), + block_tables=self.block_tables.to(device), + page_size=self.page_size, + ) + def materialize( self, shape: Tuple[int, ...], @@ -1067,6 +1159,8 @@ def from_seqlens( kv_seqlen: Sequence[int], block_tables: torch.Tensor, page_size: int, + *, + device: Optional[torch.device] = None, ) -> "PagedBlockDiagonalPaddedKeysMask": """Creates a :attr:`PagedBlockDiagonalPaddedKeysMask` from a list of tensor lengths for query and key/value. 
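
The padded-keys masks gain the same `device` keyword, which matters most for decode-style workloads where the bias is rebuilt every step over a fixed-size KV cache. A short sketch of constructing such a bias with the new keyword; the padding and lengths below are illustrative only, and the example assumes CUDA is available.

```python
import torch
from xformers.ops import fmha

# Decode-style bias over a padded KV cache: each sequence owns a slot of
# `kv_padding` key/value positions, of which only `kv_seqlen[i]` are valid.
bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
    q_seqlen=[1, 1, 1],          # one new query token per sequence
    kv_padding=128,              # slot size of the padded KV cache
    kv_seqlen=[17, 33, 128],     # valid keys per sequence (<= kv_padding)
    device=torch.device("cuda"),
)

# The seqstart/seqlen tensors consumed by the kernels now already live on
# that device, so no host-to-device copy happens inside the attention call.
assert bias.k_seqinfo.seqlen.device.type == "cuda"
assert bias.k_seqinfo.seqstart.device.type == "cuda"
```
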
@@ -1083,9 +1177,10 @@ def from_seqlens( q_seqlen, kv_seqlen, ) - q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen) + device = _get_default_bias_device(device) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device) k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded( - kv_seqlen, padding=block_tables.shape[1] * page_size + kv_seqlen, padding=block_tables.shape[1] * page_size, device=device ) return cls( q_seqinfo=q_seqinfo, @@ -1121,6 +1216,12 @@ class BlockDiagonalGappyKeysMask(AttentionBias): q_seqinfo: _SeqLenInfo k_seqinfo: _GappySeqInfo + def to(self, device: torch.device) -> "BlockDiagonalGappyKeysMask": + return BlockDiagonalGappyKeysMask( + q_seqinfo=self.q_seqinfo.to(device), + k_seqinfo=self.k_seqinfo.to(device), + ) + def materialize( self, shape: Tuple[int, ...], @@ -1151,6 +1252,8 @@ def from_seqlens( q_seqlen: Sequence[int], kv_seqstarts: Sequence[int], kv_seqlen: Sequence[int], + *, + device: Optional[torch.device] = None, ) -> "BlockDiagonalGappyKeysMask": """Creates a :attr:`BlockDiagonalGappyKeysMask` from a list of tensor lengths for query and key/value. @@ -1159,8 +1262,11 @@ def from_seqlens( q_seqlen, kv_seqlen, ) - q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen) - k_seqinfo = _GappySeqInfo.from_seqlens_gappy(kv_seqstarts, kv_seqlen, False) + device = _get_default_bias_device(device) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device) + k_seqinfo = _GappySeqInfo.from_seqlens_gappy( + kv_seqstarts, kv_seqlen, False, device=device + ) return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo) def make_paged( @@ -1183,7 +1289,7 @@ def make_paged( ] assert all(0 <= i < max_row_len for i in new_seqstarts) k_seqinfo = _GappySeqInfo.from_seqlens_gappy( - new_seqstarts, self.k_seqinfo.seqlen_py, True + new_seqstarts, self.k_seqinfo.seqlen_py, True, device=block_tables.device ) assert self.k_seqinfo.max_seqlen <= max_row_len paged_bias = paged_type( @@ -1272,7 +1378,10 @@ def materialize( bias_nonpaged = self._UNPAGED_TYPE( q_seqinfo=self.q_seqinfo, k_seqinfo=_GappySeqInfo.from_seqlens_gappy( - new_seqstarts, self.k_seqinfo.seqlen_py, False + new_seqstarts, + self.k_seqinfo.seqlen_py, + False, + device=torch.device(device), ), ) mask_nonpaged = bias_nonpaged.materialize(shape, dtype, device) @@ -1305,6 +1414,8 @@ def from_seqlens( kv_seqlen: Sequence[int], block_tables: torch.Tensor, page_size: int, + *, + device: Optional[torch.device] = None, ) -> "PagedBlockDiagonalGappyKeysMask": """Creates a :attr:`PagedBlockDiagonalGappyKeysMask` from a list of tensor lengths for query and key/value. 
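
The `validate_inputs` check added in `common.py` above is the user-visible half of the behaviour change noted in the changelog: a bias whose sequence-start tensors sit on the wrong device is now rejected with a `ValueError` instead of being silently copied to the query's device. A small sketch of the failure mode and the fix; shapes are arbitrary and the example assumes a CUDA machine.

```python
import torch
from xformers.ops import fmha

q = k = v = torch.randn(1, 8, 2, 64, device="cuda", dtype=torch.float16)

# Bias deliberately built on the CPU while Q/K/V live on CUDA.
bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(
    [3, 5], device=torch.device("cpu")
)

try:
    fmha.memory_efficient_attention(q, k, v, attn_bias=bias)
except ValueError:
    # Previously the bias would have been moved to CUDA on the fly;
    # with this patch the caller must move it explicitly.
    out = fmha.memory_efficient_attention(q, k, v, attn_bias=bias.to(q.device))
```
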
@@ -1323,8 +1434,11 @@ def from_seqlens( kv_seqlen, kv_seqstarts, ) - q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen) - k_seqinfo = _GappySeqInfo.from_seqlens_gappy(kv_seqstarts, kv_seqlen, True) + device = block_tables.device if device is None else device + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen, device=device) + k_seqinfo = _GappySeqInfo.from_seqlens_gappy( + kv_seqstarts, kv_seqlen, True, device=device + ) return cls( q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo, @@ -1430,7 +1544,7 @@ class AttentionBiasSubTensor(torch.Tensor, AttentionBias): @staticmethod def __new__(cls, *, _subtensor=None): if _subtensor is None: - _subtensor = torch.empty((0,), device="cpu") + _subtensor = torch.empty((0,), device=_get_default_bias_device()) tensor = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] cls, [], diff --git a/xformers/ops/fmha/common.py b/xformers/ops/fmha/common.py index 734c44d018..03632cd633 100644 --- a/xformers/ops/fmha/common.py +++ b/xformers/ops/fmha/common.py @@ -26,8 +26,12 @@ from .attn_bias import ( AttentionBias, AttentionBiasSubTensor, + BlockDiagonalGappyKeysMask, BlockDiagonalMask, + BlockDiagonalPaddedKeysMask, LowerTriangularMask, + PagedBlockDiagonalGappyKeysMask, + PagedBlockDiagonalPaddedKeysMask, ) @@ -126,6 +130,24 @@ def validate_inputs(self) -> None: ) if any(x.device != self.query.device for x in qkv): raise ValueError("Query/Key/Value should all be on the same device") + if isinstance( + self.attn_bias, + ( + BlockDiagonalMask, + BlockDiagonalPaddedKeysMask, + PagedBlockDiagonalPaddedKeysMask, + BlockDiagonalGappyKeysMask, + PagedBlockDiagonalGappyKeysMask, + ), + ): + bias_device = self.attn_bias.q_seqinfo.seqstart.device + if bias_device != self.query.device: + raise ValueError( + f"Attention bias and Query/Key/Value should be on the same device\n" + f" query.device: {self.query.device}\n" + f" attn_bias : {bias_device}\n" + ) + quantized_dtypes = self.key.dtype == self.value.dtype == torch.int32 non_quantized_dtypes = all(x.dtype == self.query.dtype for x in qkv) if not (quantized_dtypes or non_quantized_dtypes): diff --git a/xformers/ops/fmha/cutlass.py b/xformers/ops/fmha/cutlass.py index 78556d4173..e95b42321c 100644 --- a/xformers/ops/fmha/cutlass.py +++ b/xformers/ops/fmha/cutlass.py @@ -70,8 +70,7 @@ def _get_seqlen_info( if isinstance( attn_bias, (BlockDiagonalMask, BlockDiagonalCausalWithOffsetPaddedKeysMask) ): - attn_bias.k_seqinfo.to(inp.query.device) - attn_bias.q_seqinfo.to(inp.query.device) + assert attn_bias.k_seqinfo.seqstart.device == inp.query.device seqstart_k = attn_bias.k_seqinfo.seqstart seqstart_q = attn_bias.q_seqinfo.seqstart max_seqlen_q = attn_bias.q_seqinfo.max_seqlen diff --git a/xformers/ops/fmha/decoder.py b/xformers/ops/fmha/decoder.py index 286a87d9c2..a3090c2539 100644 --- a/xformers/ops/fmha/decoder.py +++ b/xformers/ops/fmha/decoder.py @@ -76,9 +76,7 @@ def apply( raise NotImplementedError("gradient") attn_bias = inp.attn_bias assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask) - - attn_bias.k_seqinfo.to(inp.query.device) - attn_bias.q_seqinfo.to(inp.query.device) + assert attn_bias.k_seqinfo.seqlen.device == inp.query.device padding = attn_bias.k_seqinfo.padding query, key, value = inp.get_qkv_in_bmghk() diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py index 49e708dc28..4b85ebecf0 100644 --- a/xformers/ops/fmha/flash.py +++ b/xformers/ops/fmha/flash.py @@ -227,9 +227,15 @@ def _flash_fwd_abstract( block_tables, unpadded_lse, ): - B, M, H, K = query.shape out = 
torch.empty_like(query) - lse_shape = [H, B * M] if unpadded_lse else [B, H, M] + if cu_seq_lens_q is None: + B, M, H, K = query.shape + lse_shape = [H, B * M] if unpadded_lse else [B, H, M] + else: + assert unpadded_lse is False + M, H, K = query.shape + B = cu_seq_lens_q.shape[0] - 1 + lse_shape = [B, H, max_seq_len_q] softmax_lse = torch.empty(lse_shape, device=query.device, dtype=torch.float32) rng_state = torch.empty([2], device=query.device, dtype=torch.int64) return out, softmax_lse, rng_state @@ -387,14 +393,7 @@ def _convert_input_format( attn_bias = inp.attn_bias if isinstance(attn_bias, BlockDiagonalMask): - # BlockDiagonalMask or BlockDiagonalCausalMask - attn_bias.k_seqinfo.seqstart = attn_bias.k_seqinfo.seqstart.to( - inp.query.device, non_blocking=True - ) - attn_bias.q_seqinfo.seqstart = attn_bias.q_seqinfo.seqstart.to( - inp.query.device, non_blocking=True - ) - + assert attn_bias.k_seqinfo.seqstart.device == inp.query.device cu_seqlen_k = attn_bias.k_seqinfo.seqstart cu_seqlen_q = attn_bias.q_seqinfo.seqstart max_seqlen_q = attn_bias.q_seqinfo.max_seqlen @@ -408,15 +407,7 @@ def _convert_input_format( PagedBlockDiagonalPaddedKeysMask, ), ): - attn_bias.k_seqinfo.seqstart = attn_bias.k_seqinfo.seqstart.to( - inp.query.device, non_blocking=True - ) - attn_bias.q_seqinfo.seqstart = attn_bias.q_seqinfo.seqstart.to( - inp.query.device, non_blocking=True - ) - attn_bias.k_seqinfo.seqlen = attn_bias.k_seqinfo.seqlen.to( - inp.query.device, non_blocking=True - ) + assert attn_bias.k_seqinfo.seqstart.device == inp.query.device cu_seqlen_k = attn_bias.k_seqinfo.seqstart cu_seqlen_q = attn_bias.q_seqinfo.seqstart max_seqlen_q = attn_bias.q_seqinfo.max_seqlen @@ -469,7 +460,7 @@ def fold(x): query=query, key=key, value=value, - attn_bias=inp.attn_bias, + attn_bias=attn_bias, p=inp.p, scale=inp.scale, output_dtype=inp.output_dtype, diff --git a/xformers/ops/fmha/triton_splitk.py b/xformers/ops/fmha/triton_splitk.py index 60fe792f94..740589d10c 100644 --- a/xformers/ops/fmha/triton_splitk.py +++ b/xformers/ops/fmha/triton_splitk.py @@ -16,6 +16,8 @@ Sequence, Tuple, Type, + Union, + cast, ) import torch @@ -23,7 +25,6 @@ from ... import _is_triton_available from ..common import register_operator from .attn_bias import ( - AttentionBias, BlockDiagonalCausalWithOffsetGappyKeysMask, BlockDiagonalCausalWithOffsetPaddedKeysMask, BlockDiagonalGappyKeysMask, @@ -1316,12 +1317,26 @@ def apply( cls, inp: Inputs, needs_gradient: bool ) -> Tuple[torch.Tensor, Optional[Context]]: output_dtype = inp.get_output_dtype() - if isinstance(inp.attn_bias, torch.Tensor): - attn_bias_tensor = inp.attn_bias - attn_bias: Optional[AttentionBias] = None - else: + if not isinstance(inp.attn_bias, torch.Tensor): attn_bias_tensor = None - attn_bias = inp.attn_bias + attn_bias = cast( + Optional[ + Union[ + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalGappyKeysMask, + BlockDiagonalCausalWithOffsetGappyKeysMask, + BlockDiagonalPaddedKeysMask, + PagedBlockDiagonalCausalWithOffsetPaddedKeysMask, + PagedBlockDiagonalGappyKeysMask, + PagedBlockDiagonalPaddedKeysMask, + ] + ], + inp.attn_bias, + ) + else: + attn_bias_tensor = inp.attn_bias + attn_bias = None + seq_len = None seq_starts_k = None seq_starts_q = None @@ -1336,26 +1351,23 @@ def apply( is_paged = _is_supported_paged_bias(attn_bias) if attn_bias is not None: assert is_paged or is_block_diagonal or is_gappy - # TODO: do we really need to do this cast? 
seems fishy but - # I just copied it from the decoder.py - attn_bias.k_seqinfo.to(inp.query.device) # type: ignore - attn_bias.q_seqinfo.to(inp.query.device) # type: ignore - seq_len = attn_bias.k_seqinfo.seqlen # type: ignore + assert attn_bias.k_seqinfo.seqlen.device == inp.query.device + seq_len = attn_bias.k_seqinfo.seqlen assert seq_len.stride(0) == 1 if is_gappy: - seq_starts_k = attn_bias.k_seqinfo.seqstart # type: ignore + seq_starts_k = attn_bias.k_seqinfo.seqstart assert seq_starts_k.stride(0) == 1 assert q.shape[0] == 1 B = len(seq_len) G, Hq, Kq = q.shape[-3:] # force a bool because triton cannot take np.bool_ - multiple_q = bool(attn_bias.q_seqinfo.max_seqlen > 1) # type: ignore + multiple_q = bool(attn_bias.q_seqinfo.max_seqlen > 1) IS_CAUSAL = multiple_q and _is_supported_causal_bias(attn_bias) variable_q = multiple_q and not IS_CAUSAL Kkv = v.shape[-1] if variable_q: - seq_starts_q = attn_bias.q_seqinfo.seqstart # type: ignore + seq_starts_q = attn_bias.q_seqinfo.seqstart seq_starts_q_multiplier = 1 assert seq_starts_q.stride(0) == 1 else: @@ -1408,7 +1420,9 @@ def apply( Bqq, Mqq, G, H, Kq = q.shape assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}" if variable_q: - M = attn_bias.q_seqinfo.max_seqlen * seq_starts_q_multiplier # type: ignore + assert attn_bias is not None + assert seq_starts_q_multiplier is not None + M = attn_bias.q_seqinfo.max_seqlen * seq_starts_q_multiplier else: M = Mqq page_size = inp.attn_bias.page_size if is_paged else 0 # type: ignore @@ -1419,7 +1433,7 @@ def apply( kv_cache_blocks_per_row = block_tables.shape[1] Mk = block_tables.shape[1] * page_size elif attn_bias is not None: - Mk = min(Mk, attn_bias.k_seqinfo.max_seqlen) # type: ignore + Mk = min(Mk, attn_bias.k_seqinfo.max_seqlen) if cls.SPLIT_K is not None: split_k = cls.SPLIT_K diff --git a/xformers/ops/rope_padded.py b/xformers/ops/rope_padded.py index 18b474abc2..3aad1d64cc 100644 --- a/xformers/ops/rope_padded.py +++ b/xformers/ops/rope_padded.py @@ -219,12 +219,15 @@ def rope_padded( # heuristics for number of warps num_warps = min(max(BLOCK_SIZE // 256, 1), 8) device = xq.device - # Move these to the right device, like fmha does. - attn_bias.k_seqinfo.to(device) - attn_bias.q_seqinfo.to(device) seqstartq = attn_bias.q_seqinfo.seqstart seqstartk = attn_bias.k_seqinfo.seqstart seqlenk = attn_bias.k_seqinfo.seqlen + if ( + seqstartq.device != device + or seqstartk.device != device + or seqlenk.device != device + ): + raise ValueError("`attn_bias` must be on the same device as the other inputs") assert internal_dtype in ["", "f32", "f64"] # experiment with the order of dims here. with torch.cuda.device(xq.device):
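
Taken together, these changes are what the updated `test_memeff_compile` cases exercise: with the seqstart/seqlen tensors already on the right device and the flash meta kernel aware of varlen inputs, `torch.compile` can trace `memory_efficient_attention` with a `BlockDiagonalMask` without host-side `.to()` calls in the hot path. A condensed sketch of that usage, loosely following the updated test; shapes are illustrative, the example assumes CUDA, and (per the changelog) some biases may require PyTorch 2.4.

```python
import torch
from xformers.ops import fmha

B, M, H, K = 1, 256, 2, 64
q, k, v = (
    torch.randn(B, M, H, K, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
)

# Block-diagonal bias over two sequences packed into the single batch row.
bias = fmha.attn_bias.BlockDiagonalMask.from_seqlens(
    [128, 128], device=torch.device("cuda")
)

def fmha_fn(q, k, v, bias):
    return fmha.memory_efficient_attention(
        q, k, v, attn_bias=bias, op=(fmha.flash.FwOp, fmha.flash.BwOp)
    )

compiled_fn = torch.compile(fmha_fn)
out = compiled_fn(q, k, v, bias)
out.backward(torch.randn_like(out))
```
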