Supported aten._to_copy in attention biases (fairinternal/xformers#1135)

__original_commit__ = fairinternal/xformers@55d7e78
awgu authored and xFormers Bot committed Jun 14, 2024
1 parent f5603a4 commit 96e5222
Showing 2 changed files with 32 additions and 4 deletions.
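
Background for this change: PyTorch lowers device and dtype conversions such as .cuda() and .to(torch.float16) to the single op aten._to_copy, so a __torch_dispatch__-based wrapper subclass only ever sees that op for those calls. The attention-bias wrappers forward an allow-list of aten ops to their inner _subtensor; adding aten._to_copy (and forwarding kwargs) is what lets a bias be moved and cast like an ordinary tensor. A minimal sketch of the mechanism, using a hypothetical WrapperBias class rather than the actual xformers classes:

import torch


class WrapperBias(torch.Tensor):
    # Hypothetical wrapper-tensor subclass: the real data lives in _subtensor,
    # and a small allow-list of ops is forwarded to it, mirroring the pattern
    # used by the attention-bias classes in this commit.
    def __new__(cls, subtensor):
        # Same private PyTorch helper the real classes use.
        tensor = torch.Tensor._make_wrapper_subclass(
            cls, [], device=subtensor.device, dtype=subtensor.dtype
        )
        tensor._subtensor = subtensor
        return tensor

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func._overloadpacket in [
            torch.ops.aten.clone,
            torch.ops.aten.detach,
            torch.ops.aten._to_copy,  # reached by .to(...), .cuda(), .float(), ...
        ]:
            # Run the op on the inner tensor and re-wrap the result.
            return cls(func(args[0]._subtensor, **kwargs))
        return NotImplemented


bias = WrapperBias(torch.zeros(3, 3))
bias_fp16 = bias.to(torch.float16)  # dispatched as aten._to_copy with dtype=torch.float16
assert bias_fp16._subtensor.dtype == torch.float16
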
18 changes: 18 additions & 0 deletions tests/test_mem_eff_attention.py
@@ -1587,6 +1587,24 @@ def test_attn_bias_padded() -> None:
     )
 
 
+@cuda_only
+def test_attn_bias_to_copy() -> None:
+    def _test_to_copy(attn_bias: torch.Tensor) -> None:
+        assert attn_bias.device.type == "cpu", f"{attn_bias.device}"
+        attn_bias_cuda = attn_bias.cuda()
+        assert attn_bias_cuda.device.type == "cuda", f"{attn_bias_cuda.device}"
+        attn_bias_fp16 = attn_bias.to(torch.float16)
+        assert attn_bias_fp16.device.type == "cpu", f"{attn_bias_fp16.device}"
+        assert attn_bias_fp16.dtype == torch.float16, f"{attn_bias_fp16.dtype}"
+
+    attn_bias = fmha.attn_bias.LowerTriangularMask()
+    _test_to_copy(attn_bias)
+
+    tensor_bias = torch.tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
+    attn_bias = fmha.attn_bias.LowerTriangularMaskWithTensorBias(tensor_bias)
+    _test_to_copy(attn_bias)
+
+
 def _kv_heads_label(kv_heads: Optional[int]) -> str:
     if kv_heads is None:
         return ""
18 changes: 14 additions & 4 deletions xformers/ops/fmha/attn_bias.py
@@ -1435,7 +1435,7 @@ def __new__(cls, *, _subtensor=None):
             cls,
             [],
             device=_subtensor.device,
-            dtype=torch.float32,
+            dtype=_subtensor.dtype,
             requires_grad=False,
         )
         tensor._subtensor = _subtensor
@@ -1449,8 +1449,13 @@ def __repr__(self):
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-        if func._overloadpacket in [torch.ops.aten.clone, torch.ops.aten.detach]:
-            return cls(_subtensor=func(args[0]._subtensor))
+        kwargs = kwargs or {}
+        if func._overloadpacket in [
+            torch.ops.aten.clone,
+            torch.ops.aten.detach,
+            torch.ops.aten._to_copy,
+        ]:
+            return cls(_subtensor=func(args[0]._subtensor, **kwargs))
         return NotImplemented
 
     def __tensor_flatten__(self):
@@ -1542,14 +1547,19 @@ def materialize(
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
         if func._overloadpacket in [
             torch.ops.aten.unsqueeze,
             torch.ops.aten.select,
             torch.ops.aten.slice,
             torch.ops.aten.clone,
             torch.ops.aten.detach,
+            torch.ops.aten._to_copy,
         ]:
-            output = func(*[a._subtensor if isinstance(a, cls) else a for a in args])
+            output = func(
+                *[a._subtensor if isinstance(a, cls) else a for a in args],
+                **kwargs,
+            )
             return cls(output)
         return NotImplemented
 
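
Note on the kwargs forwarding above: aten._to_copy takes its conversion targets as keyword-only arguments (dtype, device, and optional ones such as non_blocking and memory_format). Defaulting kwargs to an empty dict and passing it through to func is what carries the requested dtype and device to the inner _subtensor; without it, the call would amount to a plain copy with the original dtype and device.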
