add padding via pytorch
micmelesse committed Aug 21, 2024
1 parent ac3bc2c commit 305d30d
Showing 2 changed files with 45 additions and 6 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/amd_tests.yml
@@ -64,14 +64,14 @@ jobs:
# run: |
# pytest tests/test_flash_attn.py::test_flash_attn_output
# pytest tests/test_flash_attn.py::test_flash_attn_varlen_output
- name: Flash Attention causal Tests
run: |
pytest tests/test_flash_attn.py::test_flash_attn_causal
pytest tests/test_flash_attn.py::test_flash_attn_varlen_causal
# - name: Flash Attention causal Tests
# run: |
# pytest tests/test_flash_attn.py::test_flash_attn_causal
# pytest tests/test_flash_attn.py::test_flash_attn_varlen_causal
- name: Flash Attention kvcache Tests
run: |
pytest tests/test_flash_attn.py::test_flash_attn_splitkv
pytest tests/test_flash_attn.py::test_flash_attn_kvcache
pytest tests/test_flash_attn.py::test_flash_attn_splitkv
- name: Flash Attention race condition Tests
run: |
pytest tests/test_flash_attn.py::test_flash_attn_race_condition
41 changes: 40 additions & 1 deletion flash_attn/flash_attn_triton_kernel_decode_amd.py
@@ -1,3 +1,4 @@
import math
from typing import Optional
import pytest
import torch
@@ -638,6 +639,29 @@ def get_split_k(B: int, G: int, H: int, Mk: int) -> int:
split_k = max(split_k, 1)
return split_k


def pad_to_power_of_2(tensor, dim=-1):
"""Pad the last dimension of the tensor to the next power of 2."""
current_size = tensor.size(dim)
next_power_of_2 = 2 ** math.ceil(math.log2(current_size))
pad_size = next_power_of_2 - current_size

if pad_size == 0:
return tensor

pad_shape = list(tensor.shape)
pad_shape[dim] = pad_size
padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
return torch.cat([tensor, padding], dim=dim)

def unpad_from_power_of_2(tensor, original_size, dim=-1):
"""Remove padding from the last dimension of the tensor."""
return tensor.narrow(dim, 0, original_size)

def needs_padding(size):
"""Check if the given size needs padding to the next power of 2."""
return size & (size - 1) != 0
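# Illustrative note (not part of the commit): a size is a power of 2 exactly when
# size & (size - 1) == 0 for size > 0, e.g. needs_padding(64) -> False,
# needs_padding(72) -> True (72 & 71 == 64 != 0).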

class _attention(torch.autograd.Function):

OPERATOR = _fwd_kernel_splitK
@@ -693,6 +717,18 @@ def forward(cls, q, k, v, input_metadata):

assert input_metadata.layout == "bsghd"

# Store original dmodel size
original_dmodel = q.shape[-1]

# Check if padding is needed
if needs_padding(original_dmodel):
# Pad q, k, and v to the next power of 2
q = pad_to_power_of_2(q)
k = pad_to_power_of_2(k)
v = pad_to_power_of_2(v)

input_metadata.dmodel = q.shape[-1]

# get dims
batch_size, seqlen_q, n_group_q, heads_per_group_q, dim_q = q.shape
_, seqlen_k, n_group_k, heads_per_group_k, dim_k = k.shape
@@ -898,6 +934,9 @@ def forward(cls, q, k, v, input_metadata):
# the data is laid out properly. Just need to reshape dims
out = out.reshape(batch_size, seqlen_q, -1, dim_q)

if needs_padding(original_dmodel):
out = unpad_from_power_of_2(out, original_dmodel)

return out


@@ -1055,4 +1094,4 @@ def main():


if __name__ == '__main__':
sys.exit(main())
sys.exit(main())
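
For reference, below is a minimal, self-contained sketch of the padding round-trip this commit introduces. The three helpers are copied from the diff above; the toy "bsghd" tensor shape, the head dimension of 72, and the stand-in for the kernel call are assumptions for illustration only, not part of the commit.

import math
import torch

def pad_to_power_of_2(tensor, dim=-1):
    # Copied from the diff: zero-pad `dim` up to the next power of 2.
    current_size = tensor.size(dim)
    next_power_of_2 = 2 ** math.ceil(math.log2(current_size))
    pad_size = next_power_of_2 - current_size
    if pad_size == 0:
        return tensor
    pad_shape = list(tensor.shape)
    pad_shape[dim] = pad_size
    padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
    return torch.cat([tensor, padding], dim=dim)

def unpad_from_power_of_2(tensor, original_size, dim=-1):
    # Copied from the diff: slice `dim` back to its original size.
    return tensor.narrow(dim, 0, original_size)

def needs_padding(size):
    # Copied from the diff: True when `size` is not a power of 2.
    return size & (size - 1) != 0

# Toy (batch, seqlen, groups, heads, dim) tensor with head dim 72 -- assumed shape.
q = torch.randn(2, 1, 1, 8, 72)
original_dmodel = q.shape[-1]

if needs_padding(original_dmodel):
    q = pad_to_power_of_2(q)          # (..., 72) -> (..., 128)
assert q.shape[-1] == 128

out = q                               # stand-in for the attention kernel output
out = unpad_from_power_of_2(out, original_dmodel)
assert out.shape[-1] == original_dmodel

Zero-padding the head dimension is numerically a no-op for attention: the padded channels contribute nothing to the Q·Kᵀ dot products, and the corresponding zero channels of the output are sliced off afterwards.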
