Alex's work
This is a combination of 11 commits.

save

fix: dropout=0.0 works

feat: dropout restrictions removed. failing tests

test: reduced tests to simple cases

test: failure is due to query + key padding mask NOT varlen itself

feat: varlen dropout fwd passes

fix: varlen bwd dropout works!

test: discovered bwd error for non-dropout cases for large seqlen

save

save

use triton commit 3ca2f498e98ed7249b82722587c511a5610e00c4 -- now batched layout passes
alexkranias-amd authored and micmelesse committed Dec 5, 2024
1 parent c57b3d0 commit a99d51c
Showing 10 changed files with 931 additions and 68 deletions.
72 changes: 72 additions & 0 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -3,6 +3,31 @@
import triton.language as tl
from .utils import get_shape_from_layout, get_strides_from_layout, DEBUG, PERF

@triton.jit
def cdiv_fn(x, y):
return (x + y - 1) // y

@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
# tl.device_print('bwd_philox_offset:', philox_offset)
ms = tl.arange(0, m)
ns = tl.arange(0, n)
return philox_offset + ms[:, None] * stride + ns[None, :]


@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32)
# TODO: use tl.randint for better performance
return tl.rand(philox_seed, rng_offsets)


@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
rng_keep = rng_output > dropout_p
return rng_keep

@triton.jit
def _bwd_preprocess_use_o(
Out,
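The keep-mask produced by these new helpers is a pure function of (philox_seed, per-element offset): dropout_offsets builds an (m, n) grid of counters from the base offset and row stride, dropout_rng draws one uniform per counter, and dropout_mask keeps the elements whose draw exceeds dropout_p. A rough NumPy sketch of those semantics, hypothetical standalone code that is not bit-exact with Triton's Philox stream:

import numpy as np

def dropout_mask_ref(philox_seed, philox_offset, dropout_p, m, n, stride):
    # Same counter grid as dropout_offsets: one offset per (row, col) element.
    ms = np.arange(m)[:, None]
    ns = np.arange(n)[None, :]
    offsets = philox_offset + ms * stride + ns
    # Stand-in for tl.rand(philox_seed, offsets): each uniform depends only on
    # (seed, offset), which is what lets the backward pass regenerate the exact
    # mask the forward pass used.
    uniforms = np.vectorize(
        lambda o: np.random.default_rng((philox_seed, int(o))).random()
    )(offsets)
    return uniforms > dropout_p  # True where the element is kept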
@@ -117,12 +142,14 @@ def _bwd_kernel_one_col_block(
start_n,
num_block_m,
num_block_n,
dropout_p, philox_seed, philox_offset_base,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
DROPOUT: tl.constexpr,
USE_EXP2: tl.constexpr,
GROUP_SIZE: tl.constexpr,
):
@@ -194,12 +221,31 @@ def _bwd_kernel_one_col_block(
p = tl.where(p_mask, p, 0.0)
p = p.to(tl.float16)

# NOTE: must create a new var p_drop to prevent p (which is used later to compute ds) from changing
if DROPOUT:
philox_offset = philox_offset_base + start_m * N_CTX_K + start_n * BLOCK_N
keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, N_CTX_K)
p_drop = tl.where(keep, p, 0.0)

p_drop = p_drop / (1 - dropout_p)
p_drop = p_drop.to(Q.dtype.element_ty)
else:
p_drop = p

# compute dv
dv += tl.dot(tl.trans(p), do)

# compute dp
dp = tl.dot(do, tl.trans(v))

if DROPOUT:
philox_offset = philox_offset_base + start_m * N_CTX_K + start_n * BLOCK_N
keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, N_CTX_K)
dp = tl.where(keep, dp, 0.0)

dp = dp / (1 - dropout_p)
dp = dp.to(Q.dtype.element_ty)

# compute ds , ds = p * (dp - delta[:, None])
d_ptrs = d_offset + offs_m * stride_deltam
Di = tl.load(d_ptrs, mask=mask_m)
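Both branches follow the standard inverted-dropout rule: the forward scales kept probabilities by 1/(1 - dropout_p), so the backward regenerates the same keep-mask and applies the identical mask-and-rescale to dp before forming ds. A small PyTorch check of that gradient rule, illustrative only and not taken from this commit:

import torch

dropout_p = 0.2
p = torch.rand(4, 4, requires_grad=True)
keep = torch.rand(4, 4) > dropout_p          # stand-in for dropout_mask(...)
p_drop = torch.where(keep, p, torch.zeros_like(p)) / (1 - dropout_p)
p_drop.sum().backward()
# The gradient of the dropout step is the same mask-and-rescale, applied upstream.
assert torch.allclose(p.grad, keep.float() / (1 - dropout_p))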
@@ -269,12 +315,14 @@ def _bwd_kernel(
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p, philox_seed, philox_offset,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
DROPOUT: tl.constexpr,
USE_EXP2: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
@@ -291,6 +339,11 @@
else:
off_hk = off_hq

if DROPOUT:
batch_philox_offset = philox_offset + off_hq * max_seqlen_q * max_seqlen_k
else:
batch_philox_offset = 0

if IS_VARLEN:
# Compute sequence lengths for the current batch
q_start = tl.load(cu_seqlens_q + off_z)
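The offset arithmetic gives every (batch, head) program its own disjoint slice of the Philox counter space, sized max_seqlen_q * max_seqlen_k, and the per-tile offsets inside _bwd_kernel_one_col_block then index into that slice. A hypothetical host-side sketch of the same mapping, assuming the score matrix is laid out row-major with stride max_seqlen_k (i.e. N_CTX_K == max_seqlen_k):

def head_philox_base(philox_offset, off_hq, max_seqlen_q, max_seqlen_k):
    # Mirrors batch_philox_offset above: each off_hq owns a disjoint counter block,
    # so dropout masks never overlap across heads and are reproducible in fwd/bwd.
    return philox_offset + off_hq * max_seqlen_q * max_seqlen_k

def element_philox_counter(head_base, row, col, max_seqlen_k):
    # One counter per score-matrix element (row-major, stride max_seqlen_k),
    # matching dropout_offsets' ms * stride + ns term. Hypothetical helper.
    return head_base + row * max_seqlen_k + col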
@@ -368,12 +421,14 @@ def _bwd_kernel(
start_n,
num_block_m,
num_block_n,
dropout_p, philox_seed, batch_philox_offset,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
DROPOUT=DROPOUT,
USE_EXP2=USE_EXP2,
GROUP_SIZE=GROUP_SIZE
)
@@ -421,12 +476,14 @@ def _bwd_kernel(
start_n,
num_block_m,
num_block_n,
dropout_p, philox_seed, batch_philox_offset,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
DROPOUT=DROPOUT,
USE_EXP2=USE_EXP2,
GROUP_SIZE=GROUP_SIZE
)
@@ -446,12 +503,14 @@ def attention_prefill_backward_triton_impl(
sm_scale: float,
alibi_slopes,
causal,
dropout_p,
layout: str,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q: int,
max_seqlen_k: int,
use_exp2: bool,
rng_state: torch.Tensor,
sequence_parallel = True,
):
if DEBUG:
@@ -475,6 +534,7 @@ def attention_prefill_backward_triton_impl(
print("max_seqlen_q:", max_seqlen_q)
print("max_seqlen_k:", max_seqlen_k)
print("use_exp2:", use_exp2)
print("rng_state", rng_state)
print("sequence_parallel:", sequence_parallel)

# make contiguous
@@ -491,6 +551,13 @@
stride_vz, stride_vh, stride_vn, stride_vk = v_strides
stride_oz, stride_oh, stride_om, stride_ok = o_strides
is_varlen = layout == "thd"


# get dropout metadata
if dropout_p > 0.0:
philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item()
else:
philox_seed, philox_offset = None, None

# FIXME: some configs lead to oom for some reason when using 64 x 64 blocks
if max_seqlen_q <= 32 or max_seqlen_k <= 32:
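On the wrapper side, rng_state is only consulted when dropout_p > 0.0 and is expected to carry the Philox seed and offset that the forward pass used, so the backward kernels can rebuild the identical dropout mask. A hypothetical caller-side sketch, with placeholder values and the dtype/device left to whatever the forward pass actually saved:

import torch

philox_seed, philox_offset = 0x1BF52, 0x1D4B42   # placeholder values
rng_state = torch.tensor([philox_seed, philox_offset], dtype=torch.int64)

# The backward wrapper then recovers them exactly as above:
seed, offset = rng_state[0].item(), rng_state[1].item()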
@@ -619,6 +686,9 @@ def attention_prefill_backward_triton_impl(
print("heads_q:",nheads_q)
print("max_seqlen_q:",max_seqlen_q)
print("max_seqlen_k:",max_seqlen_k)
print("dropout_p:",dropout_p)
print("philox_seed:", philox_seed)
print("philox_offset:",philox_offset)
print("BLOCK_M:",BLOCK_M)
print("BLOCK_N:",BLOCK_M)
print("BLOCK_DMODEL:",BLOCK_DMODEL)
@@ -657,12 +727,14 @@ def attention_prefill_backward_triton_impl(
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p, philox_seed, philox_offset,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
SEQUENCE_PARALLEL=sequence_parallel,
CAUSAL=causal,
DROPOUT=dropout_p>0.0,
USE_EXP2=use_exp2,
num_warps=num_warps,
num_stages=num_stages,
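DROPOUT is passed as a tl.constexpr flag (dropout_p > 0.0), so when dropout is disabled the dropout branches in the kernel are specialized away at compile time rather than checked at runtime. A minimal standalone Triton sketch of that pattern, not taken from this repo:

import torch
import triton
import triton.language as tl

@triton.jit
def dropout_scale_kernel(x_ptr, out_ptr, n, dropout_p, seed,
                         DROPOUT: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if DROPOUT:  # resolved at compile time; removed entirely when DROPOUT=False
        keep = tl.rand(seed, offs) > dropout_p
        x = tl.where(keep, x / (1 - dropout_p), 0.0)
    tl.store(out_ptr + offs, x, mask=mask)

x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
dropout_scale_kernel[grid](x, out, x.numel(), 0.1, 42, DROPOUT=True, BLOCK=256)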
4 changes: 3 additions & 1 deletion flash_attn/flash_attn_triton_amd/bwd_ref.py
@@ -359,12 +359,14 @@ def attention_backward_pytorch_ref_impl(
softmax_lse,
sm_scale,
causal,
dropout_p,
layout,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
use_exp2
use_exp2,
rng_state
):

if DEBUG:
(Diffs for the remaining 8 changed files not shown.)
