From 96e52222bef4c9a1cb34c886a0886cca9d5589e2 Mon Sep 17 00:00:00 2001
From: Andrew Gu <31054793+awgu@users.noreply.github.com>
Date: Fri, 14 Jun 2024 18:07:15 +0000
Subject: [PATCH] Supported `aten._to_copy` in attention biases
 (fairinternal/xformers#1135)

__original_commit__ = fairinternal/xformers@55d7e785bec2490984410ce5b2760aab657c28b2
---
 tests/test_mem_eff_attention.py | 18 ++++++++++++++++++
 xformers/ops/fmha/attn_bias.py  | 18 ++++++++++++++----
 2 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index 5f5babf97e..34239ced0c 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -1587,6 +1587,24 @@ def test_attn_bias_padded() -> None:
     )
 
 
+@cuda_only
+def test_attn_bias_to_copy() -> None:
+    def _test_to_copy(attn_bias: torch.Tensor) -> None:
+        assert attn_bias.device.type == "cpu", f"{attn_bias.device}"
+        attn_bias_cuda = attn_bias.cuda()
+        assert attn_bias_cuda.device.type == "cuda", f"{attn_bias_cuda.device}"
+        attn_bias_fp16 = attn_bias.to(torch.float16)
+        assert attn_bias_fp16.device.type == "cpu", f"{attn_bias_fp16.device}"
+        assert attn_bias_fp16.dtype == torch.float16, f"{attn_bias_fp16.dtype}"
+
+    attn_bias = fmha.attn_bias.LowerTriangularMask()
+    _test_to_copy(attn_bias)
+
+    tensor_bias = torch.tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
+    attn_bias = fmha.attn_bias.LowerTriangularMaskWithTensorBias(tensor_bias)
+    _test_to_copy(attn_bias)
+
+
 def _kv_heads_label(kv_heads: Optional[int]) -> str:
     if kv_heads is None:
         return ""
diff --git a/xformers/ops/fmha/attn_bias.py b/xformers/ops/fmha/attn_bias.py
index 02c66672af..2f095df22e 100644
--- a/xformers/ops/fmha/attn_bias.py
+++ b/xformers/ops/fmha/attn_bias.py
@@ -1435,7 +1435,7 @@ def __new__(cls, *, _subtensor=None):
             cls,
             [],
             device=_subtensor.device,
-            dtype=torch.float32,
+            dtype=_subtensor.dtype,
             requires_grad=False,
         )
         tensor._subtensor = _subtensor
@@ -1449,8 +1449,13 @@ def __repr__(self):
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-        if func._overloadpacket in [torch.ops.aten.clone, torch.ops.aten.detach]:
-            return cls(_subtensor=func(args[0]._subtensor))
+        kwargs = kwargs or {}
+        if func._overloadpacket in [
+            torch.ops.aten.clone,
+            torch.ops.aten.detach,
+            torch.ops.aten._to_copy,
+        ]:
+            return cls(_subtensor=func(args[0]._subtensor, **kwargs))
         return NotImplemented
 
     def __tensor_flatten__(self):
@@ -1542,14 +1547,19 @@ def materialize(
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
         if func._overloadpacket in [
             torch.ops.aten.unsqueeze,
             torch.ops.aten.select,
             torch.ops.aten.slice,
             torch.ops.aten.clone,
             torch.ops.aten.detach,
+            torch.ops.aten._to_copy,
         ]:
-            output = func(*[a._subtensor if isinstance(a, cls) else a for a in args])
+            output = func(
+                *[a._subtensor if isinstance(a, cls) else a for a in args],
+                **kwargs,
+            )
             return cls(output)
         return NotImplemented
 
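
Note: a minimal usage sketch of what this change enables, mirroring the new test above
(assumes this patch is applied and, for the .cuda() call, a CUDA device is available).
Both attention-bias tensor subclasses now forward aten._to_copy through
__torch_dispatch__, so .to(dtype) / .cuda() work directly on the bias objects, and the
wrapper's reported dtype now follows its _subtensor instead of being hard-coded to
float32:

    import torch
    from xformers.ops import fmha

    # Tensor-subclass bias: dtype/device copies now dispatch aten._to_copy
    # onto the wrapped _subtensor instead of falling through to NotImplemented.
    bias = fmha.attn_bias.LowerTriangularMask()
    bias_fp16 = bias.to(torch.float16)   # dtype-only copy, stays on CPU
    bias_cuda = bias.cuda()              # device copy, requires a CUDA device

    # Same for the variant carrying an additive tensor bias.
    tensor_bias = torch.tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
    biased = fmha.attn_bias.LowerTriangularMaskWithTensorBias(tensor_bias)
    biased_fp16 = biased.to(torch.float16)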