clean up
micmelesse committed Dec 5, 2024
1 parent 68eae16 commit 6d9ed27
Showing 7 changed files with 37 additions and 55 deletions.
13 changes: 5 additions & 8 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -1,7 +1,9 @@
import torch
import triton
import triton.language as tl
from .utils import get_shape_from_layout, get_strides_from_layout, DEBUG, write_dropout_mask
from .utils import DEBUG, get_shape_from_layout, get_strides_from_layout, write_dropout_mask

DEBUG_DROPOUT: tl.constexpr = False

@triton.jit
def _bwd_preprocess_use_p(
@@ -329,9 +331,6 @@ def _bwd_kernel_one_col_block(
USE_EXP2: tl.constexpr,
GROUP_SIZE: tl.constexpr,
):
DEBUG_DROPOUT = False

# causal
if CAUSAL:
# TODO: Causal can skip more blocks with something like lo = start_m * BLOCK_M
lo = 0
@@ -358,9 +357,6 @@ def _bwd_kernel_one_col_block(
k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
v = tl.load(v_ptrs, mask=kv_mask, other=0.0)

if DROPOUT:
dropout_scale = 1/ (1 - dropout_p)

# loop over rows
for start_m in range(lo, num_block_m):
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
@@ -409,6 +405,7 @@ def _bwd_kernel_one_col_block(
# print("philox_offset:", philox_offset)
rand_vals = tl.rand(philox_seed, philox_offset)
dropout_mask = rand_vals > dropout_p
dropout_scale = 1/ (1 - dropout_p)

if DEBUG_DROPOUT:
dropout_ptrs = dropout_offset + offs_m[:, None] * stride_dropoutm + offs_n[None, :] * stride_dropoutn
@@ -420,7 +417,7 @@ def _bwd_kernel_one_col_block(
p_drop_scaled = p_drop_scaled.to(tl.float16)

# compute dv
dv += tl.dot(tl.trans(p_drop_scaled), do) # dropout scale is applied at the end
dv += tl.dot(tl.trans(p_drop_scaled), do)

# compute dp
dp_drop_scaled = tl.dot(do, tl.trans(v))
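Note on the dropout_scale = 1 / (1 - dropout_p) factor that the hunks above move inside the dropout branch: this is standard inverted dropout, where surviving probabilities are rescaled so the expected value is unchanged. A minimal standalone PyTorch sketch, not the kernel code; the sizes and the 0.17 rate are illustrative:

import torch

torch.manual_seed(0)
dropout_p = 0.17
p = torch.softmax(torch.randn(1024, 1024), dim=-1)        # attention probabilities

rand_vals = torch.rand_like(p)
dropout_mask = rand_vals > dropout_p                       # True = keep the element
dropout_scale = 1.0 / (1.0 - dropout_p)                    # rescale the survivors
p_drop_scaled = torch.where(dropout_mask, p * dropout_scale, torch.zeros_like(p))

# The rescaling keeps the expected row sums close to the original value of 1.0.
print(p_drop_scaled.sum(dim=-1).mean())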
38 changes: 13 additions & 25 deletions flash_attn/flash_attn_triton_amd/bwd_ref.py
@@ -86,7 +86,7 @@ def attention_backward_core_ref_impl(
print("p_drop:", p_drop, p_drop.shape)
print("p_drop_scaled:", p_drop_scaled, p_drop_scaled.shape)

# compute gradient wrt v
# compute dv
dv = torch.matmul(p_drop_scaled.transpose(-2, -1), do)
if DEBUG_CORE:
print("dv:", dv, dv.shape)
@@ -99,25 +99,14 @@ def attention_backward_core_ref_impl(
print("dp:", dp, dp.shape)

# calculate ds
if False:
if True:
delta = torch.sum(o * do, axis=-1).unsqueeze(-1)
else:
delta = torch.sum(p * dp, axis=-1).unsqueeze(-1)
dscores_scaled = p * (dp - delta)
ds = dscores_scaled * sm_scale
if DEBUG_CORE:
print("delta:", delta, delta.shape)
print("dscores_scaled:", dscores_scaled, dscores_scaled.shape)
print("ds:", ds, ds.shape)

# compute gradient wrt k & q
dk = torch.matmul(ds.transpose(-2, -1), q)
dq = torch.matmul(ds, k)
if DEBUG_CORE:
print("dk:", dk, dk.shape)
print("dq:", dq, dq.shape)
else:
# compute gradient wrt v
# compute dv
dv = torch.matmul(p.transpose(-2, -1), do)
if DEBUG_CORE:
print("dv:", dv, dv.shape)
@@ -131,18 +120,17 @@ def attention_backward_core_ref_impl(
delta = torch.sum(o * do, axis=-1).unsqueeze(-1)
dscores_scaled = p * (dp - delta)
ds = dscores_scaled * sm_scale
if DEBUG_CORE:
print("delta:", delta, delta.shape)
print("dscores_scaled:", dscores_scaled, dscores_scaled.shape)
print("ds:", ds, ds.shape)

if DEBUG_CORE:
print("delta:", delta, delta.shape)
print("dscores_scaled:", dscores_scaled, dscores_scaled.shape)
print("ds:", ds, ds.shape)

# compute gradient wrt k & q
dk = torch.matmul(ds.transpose(-2, -1), q)
dq = torch.matmul(ds, k)
if DEBUG_CORE:
print("dk:", dk, dk.shape)
print("dq:", dq, dq.shape)
# compute gradient wrt k & q
dk = torch.matmul(ds.transpose(-2, -1), q)
dq = torch.matmul(ds, k)
if DEBUG_CORE:
print("dk:", dk, dk.shape)
print("dq:", dq, dq.shape)

# cast back to original dtype
dq = dq.to(torch.float16)
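The if True: branch kept above uses delta = sum(o * do, -1) rather than sum(p * dp, -1). The two forms agree when there is no dropout, because dp = do @ v^T and o = p @ v. A small self-contained PyTorch check; the shapes are illustrative and not taken from the repo:

import torch

torch.manual_seed(0)
p  = torch.softmax(torch.randn(2, 4, 16, 16), dim=-1)   # attention probabilities
v  = torch.randn(2, 4, 16, 32)
do = torch.randn(2, 4, 16, 32)                           # upstream gradient of o

o  = torch.matmul(p, v)                                  # forward output
dp = torch.matmul(do, v.transpose(-2, -1))               # gradient wrt p

delta_from_o = torch.sum(o * do, dim=-1)
delta_from_p = torch.sum(p * dp, dim=-1)
assert torch.allclose(delta_from_o, delta_from_p, atol=1e-5)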
7 changes: 0 additions & 7 deletions flash_attn/flash_attn_triton_amd/common.py

This file was deleted.

5 changes: 3 additions & 2 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -1,7 +1,9 @@
import torch
import triton
import triton.language as tl
from .utils import get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, DEBUG, AUTOTUNE, write_dropout_mask
from .utils import DEBUG, AUTOTUNE, get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, write_dropout_mask

DEBUG_DROPOUT: tl.constexpr = False

# Convenience function to load with optional boundary checks.
# "First" is the major dim, "second" is the minor dim.
@@ -64,7 +66,6 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
ENABLE_DROPOUT: tl.constexpr, PADDED_HEAD: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr, SM_SCALE: tl.constexpr, USE_EXP2: tl.constexpr,
RETURN_SCORES: tl.constexpr):
DEBUG_DROPOUT = False
if USE_EXP2:
RCP_LN2: tl.constexpr = 1.4426950408889634

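The change above replaces a DEBUG_DROPOUT variable defined inside the kernel with a module-level tl.constexpr flag, so the debug path is resolved at compile time. A minimal sketch of that pattern using a toy kernel rather than the attention kernels in this file; the kernel name, sizes, and block size are made up for illustration, and a GPU is required:

import torch
import triton
import triton.language as tl

DEBUG_SCALE: tl.constexpr = False   # module-level flag, baked in at compile time

@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    if DEBUG_SCALE:                 # dead code unless the flag is flipped to True
        tl.device_print("x", x)
    tl.store(out_ptr + offs, x * 2.0, mask=mask)

x = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
scale_kernel[(triton.cdiv(x.numel(), 1024),)](x, out, x.numel(), BLOCK=1024)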
3 changes: 1 addition & 2 deletions flash_attn/flash_attn_triton_amd/test.py
@@ -1,8 +1,7 @@
import torch
import pytest

from .utils import MetaData, get_input_shapes, input_helper, varlen_input_helper, DEBUG
from .common import compute_alibi_tensor_ref
from .utils import DEBUG, MetaData, get_input_shapes, input_helper, varlen_input_helper, compute_alibi_tensor_ref
from .interface_torch import attention_prefill, attention_decode
from .fwd_ref import attention_forward_pytorch_ref_impl
from .fwd_prefill import attention_prefill_forward_triton_impl
11 changes: 11 additions & 0 deletions flash_attn/flash_attn_triton_amd/utils.py
@@ -4,11 +4,16 @@
import math
import torch
import os
import random
import triton
import triton.language as tl

AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '0').lower() in ('1', 'true', 'yes')
DEBUG = os.environ.get('FLASH_ATTENTION_TRITON_AMD_DEBUG', '0').lower() in ('1', 'true', 'yes')
PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes')
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
if USE_TRITON_ROCM: # TODO remove this
random.seed(42)

class MetaData():
cu_seqlens_q = None
@@ -260,6 +265,12 @@ def get_padded_headsize(size):
padded_d_model = max(padded_d_model, 16)
return padded_d_model

def compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k):
q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1)
k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K)
relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K)
return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K)

def write_dropout_mask(x, tensor_name = "tensor"):
batch, head, seqlen_m, seqlen_n = x.shape
x = x.tolist()
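A possible usage sketch for the compute_alibi_tensor_ref helper added above; the shapes and slope values are hypothetical, and since the helper allocates its index tensors on "cuda", a GPU is required:

import torch
from flash_attn.flash_attn_triton_amd.utils import compute_alibi_tensor_ref

batch, nheads, seqlen_q, seqlen_k = 2, 4, 128, 128
alibi_slopes = torch.rand(batch, nheads, device="cuda") * 0.3      # (Z, H)
bias = compute_alibi_tensor_ref(alibi_slopes, seqlen_q, seqlen_k)
assert bias.shape == (batch, nheads, seqlen_q, seqlen_k)           # (Z, H, N_CTX_Q, N_CTX_K)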
15 changes: 4 additions & 11 deletions tests/test_flash_attn_triton_amd.py
@@ -18,12 +18,7 @@
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import _get_block_size_n
from flash_attn.layers.rotary import apply_rotary_emb
from flash_attn.flash_attn_triton_amd.utils import DEBUG, is_rdna

# Test ROCM Triton Backend
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
if USE_TRITON_ROCM:
random.seed(42)
from flash_attn.flash_attn_triton_amd.utils import USE_TRITON_ROCM, DEBUG, is_rdna

MAX_HEADDIM_SM8x = 192

@@ -590,7 +585,7 @@ def get_dropout_fraction(
@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize("seqlen", [128])
@pytest.mark.parametrize("dropout_p", [0.0, 0.17])
# @pytest.mark.parametrize("dropout_p", [0.0])
# @pytest.mark.parametrize("dropout_p", [0.17])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
if USE_TRITON_ROCM:
if local == True:
@@ -601,8 +596,8 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
device = "cuda"
# set seed
torch.random.manual_seed(0)
batch_size = 1
nheads = 1
batch_size = 4
nheads = 9
window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,))
qkv = torch.randn(
batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True
@@ -932,7 +927,6 @@ def test_flash_attn_output(
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap
):
if USE_TRITON_ROCM:

if softcap != 0.0:
pytest.skip("softcap not supported on AMD's Triton Backend yet")

@@ -1216,7 +1210,6 @@ def test_flash_attn_output(
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
# (32, 32),
(1, 147),
(113, 203),
(128, 217),
