From f692b98d805850983f14deec7a9104583c58b107 Mon Sep 17 00:00:00 2001 From: Ivan Komarov Date: Fri, 5 Apr 2024 22:40:41 +0200 Subject: [PATCH 01/17] Fix spurious re-compilations of `rotary_kernel` (#911) All integer parameters are specialized by default, so the two parameters removed in this commit could lead to kernel re-compilation, even if they were completely unused. --- flash_attn/ops/triton/rotary.py | 13 ------------ tests/test_rotary.py | 37 +++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 13 deletions(-) diff --git a/flash_attn/ops/triton/rotary.py b/flash_attn/ops/triton/rotary.py index 8d2e09b0c..6c04a523e 100644 --- a/flash_attn/ops/triton/rotary.py +++ b/flash_attn/ops/triton/rotary.py @@ -8,15 +8,6 @@ import triton.language as tl -# @triton.autotune( -# configs=[ -# triton.Config({"BLOCK_M": 2}), -# triton.Config({"BLOCK_M": 4}), -# triton.Config({"BLOCK_M": 8}), -# triton.Config({"BLOCK_M": 16}), -# ], -# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"], -# ) @triton.jit def rotary_kernel( OUT, # Pointers to matrices @@ -27,10 +18,8 @@ def rotary_kernel( SEQLEN_OFFSETS, # this could be int or a pointer # Matrix dimensions seqlen, - nheads, rotary_dim, seqlen_ro, - CACHE_KEY_SEQLEN, # strides stride_out_batch, stride_out_seqlen, @@ -218,10 +207,8 @@ def apply_rotary( cu_seqlens, seqlen_offsets, seqlen, # shapes - nheads, rotary_dim, seqlen_ro, - seqlen // 128, # key for triton cache (limit number of compilations) output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 output.stride(-3), # seqlen_stride or total_seqlen_stride output.stride(-2), # nheads_stride diff --git a/tests/test_rotary.py b/tests/test_rotary.py index 574d0526b..6f2a5fae7 100644 --- a/tests/test_rotary.py +++ b/tests/test_rotary.py @@ -252,3 +252,40 @@ def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_of assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol) atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item() assert torch.allclose(x_grad, x_pt.grad, rtol=rtol, atol=2 * atol) + + +def test_compilation_count(): + batch_size = 1 + headdim = 128 + device = "cuda" + dtype = torch.float16 + torch.manual_seed(42) + + from triton.runtime.jit import JITFunction + from flash_attn.ops.triton.rotary import rotary_kernel + compilation_count = 0 + + def count_compilations(*args, **kwargs): + nonlocal compilation_count + compilation_count += 1 + + old_cache_func = JITFunction.cache_hook + + try: + rotary_kernel.cache.clear() + JITFunction.cache_hook = count_compilations + + for seqlen in (128, 256): + for nheads in (4, 32): + x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device) + x.requires_grad_() + cos, sin = generate_cos_sin(seqlen, headdim, device, dtype) + out = apply_rotary_emb(x, cos, sin) + out.backward(torch.randn_like(out)) + + # Only two kernels are expected to be compiled: + # * for the forward pass (conjugate=False) + # * for the backward pass (conjugate=True) + assert compilation_count == 2 + finally: + JITFunction.cache_hook = old_cache_func From 9eb3d099c1fb806b4f6351a7b18f2e7731bcec64 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 7 Apr 2024 20:04:39 -0700 Subject: [PATCH 02/17] Transpose out when swapping seqlen_q and num_groups --- csrc/flash_attn/flash_api.cpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 001acacaf..ac753af2c 100644 --- a/csrc/flash_attn/flash_api.cpp +++ 
b/csrc/flash_attn/flash_api.cpp @@ -282,7 +282,8 @@ void set_params_splitkv(Flash_fwd_params ¶ms, const int batch_size, params.num_splits = num_splits; if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout if (num_splits < 1) { - params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128); + // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block. + params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128); } if (params.num_splits > 1) { at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); @@ -372,8 +373,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); + const int ngroups = num_heads / num_heads_k; if (seqlenq_ngroups_swapped) { - const int ngroups = num_heads / num_heads_k; q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); seqlen_q = ngroups; num_heads = num_heads_k; @@ -400,7 +401,10 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); + CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og); + if (seqlenq_ngroups_swapped) { + out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2); + } if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } } else { out = torch::empty_like(q_padded); @@ -571,8 +575,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value(); + const int ngroups = num_heads / num_heads_k; if (seqlenq_ngroups_swapped) { - const int ngroups = num_heads / num_heads_k; q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og}); max_seqlen_q = ngroups; num_heads = num_heads_k; @@ -627,6 +631,10 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); CHECK_SHAPE(out, total_q, num_heads, head_size_og); + CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og); + if (seqlenq_ngroups_swapped) { + out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og}); + } if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } } else { out = torch::empty_like(q_padded); From 656daef4eace3b626e299a102ce111bc95385060 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 7 Apr 2024 20:09:55 -0700 
Subject: [PATCH 03/17] Use Cute's local_tile to get gQ, gK, gV --- csrc/flash_attn/src/flash_fwd_kernel.h | 105 ++++++++++++------------- 1 file changed, 51 insertions(+), 54 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index ab9f36743..104e16419 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -68,14 +68,16 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. // Otherwise we might read OOB elements from gK and gV. if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) - + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), - Shape, Int>{}, - make_stride(params.o_row_stride, _1{})); - Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - Shape>{}, Stride<_1>{}); + Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.o_row_stride, params.o_head_stride, _1{})); + Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), + make_shape(params.b, params.h, params.seqlen_q), + make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); + Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); @@ -108,25 +110,27 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // that needs masking when we read K and V from global memory. Moreover, iterating in reverse // might save us 1 register (we just need n_block instead of both n_block and n_block_max). - const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) - + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; - // We move K and V to the last block. 
- const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) - + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; - const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) - + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; - Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - Shape, Int>{}, - make_stride(params.q_row_stride, _1{})); - Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), - Shape, Int>{}, - make_stride(params.k_row_stride, _1{})); - Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), - Shape, Int>{}, - make_stride(params.v_row_stride, _1{})); + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.q_row_stride, params.q_head_stride, _1{})); + Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.k_row_stride, params.k_head_stride, _1{})); + Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, + make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) + Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.v_row_stride, params.v_head_stride, _1{})); + Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape, Int>{}, + make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), Shape, Int>{}, make_stride(params.seqlen_k_rounded, _1{})); @@ -145,9 +149,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); - Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); - Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); typename Kernel_traits::TiledMma tiled_mma; @@ -240,7 +244,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi int n_block = n_block_max - 1; // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. 
- flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); cute::cp_async_fence(); // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } @@ -281,12 +285,11 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Advance gV if (masking_step > 0) { - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); } else { // Clear the smem tiles to account for predicated off loads flash::copy( - gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN ); } cute::cp_async_fence(); @@ -304,9 +307,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi flash::cp_async_wait<0>(); __syncthreads(); if (n_block > n_block_min) { - // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. cute::cp_async_fence(); @@ -354,9 +355,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi clear(acc_s); flash::cp_async_wait<0>(); __syncthreads(); - // Advance gV - tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); cute::cp_async_fence(); flash::gemm( @@ -367,9 +366,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi flash::cp_async_wait<0>(); __syncthreads(); if (n_block > n_block_min) { - // Advance gK - tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); - flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + flash::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); // This cp_async_fence needs to be in the if block, otherwise the synchronization // isn't right and we get race conditions. 
cute::cp_async_fence(); @@ -421,14 +418,16 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); - const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) - + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; - Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), - Shape, Int>{}, - make_stride(params.o_row_stride, _1{})); - Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), - Shape>{}, Stride<_1>{}); + Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.o_row_stride, params.o_head_stride, _1{})); + Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), + make_shape(params.b, params.h, params.seqlen_q), + make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); + Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); @@ -555,8 +554,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // that needs masking when we read K and V from global memory. Moreover, iterating in reverse // might save us 1 register (we just need n_block instead of both n_block and n_block_max). - const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) - + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; // We move K and V to the last block. const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; const int *block_table = params.block_table == nullptr ? 
nullptr : params.block_table + bidb * params.block_table_batch_stride; @@ -571,9 +568,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; - Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), - Shape, Int>{}, - make_stride(params.q_row_stride, _1{})); + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.q_row_stride, params.q_head_stride, _1{})); + Tensor gQ = local_tile(mQ(_, bidh, _), Shape, Int>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), Shape, Int>{}, make_stride(params.k_row_stride, _1{})); @@ -1033,8 +1032,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons flash::copy( gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM ); - // __syncthreads(); - // if (cute::thread0()) { print(tOgOaccum); } } //////////////////////////////////////////////////////////////////////////////////////////////////// From 2aea958f8988e7497c2aa638f7c859cfbf9576d6 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 7 Apr 2024 20:11:52 -0700 Subject: [PATCH 04/17] [CI] Compile with torch 2.3.0.dev20240207 --- .github/workflows/publish.yml | 4 ++-- setup.py | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 47cd2589c..2413d3e96 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -44,7 +44,7 @@ jobs: # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] - torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240105'] + torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240207'] cuda-version: ['11.8.0', '12.2.2'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. 
@@ -63,7 +63,7 @@ jobs: python-version: '3.7' - torch-version: '2.2.0' python-version: '3.7' - - torch-version: '2.3.0.dev20240105' + - torch-version: '2.3.0.dev20240207' python-version: '3.7' # Pytorch <= 2.0 only supports CUDA <= 11.8 - torch-version: '1.12.1' diff --git a/setup.py b/setup.py index 6978dd6f3..54f88dbfd 100644 --- a/setup.py +++ b/setup.py @@ -200,6 +200,11 @@ def append_nvcc_threads(nvcc_extra_args): # "--ptxas-options=-v", # "--ptxas-options=-O2", # "-lineinfo", + # "-DFLASHATTENTION_DISABLE_BACKWARD", + # "-DFLASHATTENTION_DISABLE_DROPOUT", + # "-DFLASHATTENTION_DISABLE_ALIBI", + # "-DFLASHATTENTION_DISABLE_UNEVEN_K", + # "-DFLASHATTENTION_DISABLE_LOCAL", ] + generator_flag + cc_flag @@ -345,4 +350,4 @@ def __init__(self, *args, **kwargs) -> None: setup_requires=[ "psutil" ], -) \ No newline at end of file +) From 85881f547fd1053a7b4a2c3faad6690cca969279 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 7 Apr 2024 20:13:05 -0700 Subject: [PATCH 05/17] Bump to v2.5.7 --- flash_attn/__init__.py | 2 +- training/Dockerfile | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 756253685..2cb147527 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.5.6" +__version__ = "2.5.7" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/training/Dockerfile b/training/Dockerfile index b753cd1c7..33b396b5b 100644 --- a/training/Dockerfile +++ b/training/Dockerfile @@ -85,7 +85,7 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 # Install FlashAttention -RUN pip install flash-attn==2.5.6 +RUN pip install flash-attn==2.5.7 # Install CUDA extensions for fused dense -RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.6#subdirectory=csrc/fused_dense_lib +RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.7#subdirectory=csrc/fused_dense_lib From ec6d22143b5d375e253b2ebfc563b26a43f43684 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 26 Apr 2024 10:50:41 -0700 Subject: [PATCH 06/17] [CrossEntropy] Change ignored_index -> ignore_index --- flash_attn/losses/cross_entropy.py | 4 ++-- flash_attn/ops/triton/cross_entropy.py | 26 +++++++++++++------------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/flash_attn/losses/cross_entropy.py b/flash_attn/losses/cross_entropy.py index 2a1b77a34..2c5032c77 100644 --- a/flash_attn/losses/cross_entropy.py +++ b/flash_attn/losses/cross_entropy.py @@ -20,7 +20,7 @@ def __init__( ): """ Arguments: - ignored_index: int. If labels == ignored_index, the loss is set to 0.0. + ignore_index: int. If labels == ignore_index, the loss is set to 0.0. label_smoothing: float lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. This is also referred to as "z-loss". 
@@ -60,7 +60,7 @@ def forward(self, input, target): label_smoothing=self.label_smoothing, logit_scale=self.logit_scale, lse_square_scale=self.lse_square_scale, - ignored_index=self.ignore_index, + ignore_index=self.ignore_index, inplace_backward=self.inplace_backward, process_group=self.process_group, ) diff --git a/flash_attn/ops/triton/cross_entropy.py b/flash_attn/ops/triton/cross_entropy.py index c8111ca54..1f895d7db 100644 --- a/flash_attn/ops/triton/cross_entropy.py +++ b/flash_attn/ops/triton/cross_entropy.py @@ -32,7 +32,7 @@ def cross_entropy_fwd_kernel( smoothing, logit_scale, lse_square_scale, - ignored_index, + ignore_index, total_classes, class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes n_cols, # shapes @@ -56,7 +56,7 @@ def cross_entropy_fwd_kernel( sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0) lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse) - if label_idx == ignored_index: + if label_idx == ignore_index: loss = 0.0 z_loss = 0.0 else: @@ -104,7 +104,7 @@ def cross_entropy_bwd_kernel( smoothing, logit_scale, lse_square_scale, - ignored_index, + ignore_index, total_classes, class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes n_cols, # shapes @@ -120,7 +120,7 @@ def cross_entropy_bwd_kernel( dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64) col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) label_idx = tl.load(labels_ptr + row_idx) - if label_idx != ignored_index: + if label_idx != ignore_index: dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride) else: dloss = 0.0 @@ -150,7 +150,7 @@ def forward( smoothing=0.0, logit_scale=1.0, lse_square_scale=0.0, - ignored_index=-100, + ignore_index=-100, inplace_backward=False, process_group=None, ): @@ -192,7 +192,7 @@ def forward( smoothing, logit_scale, lse_square_scale, - ignored_index, + ignore_index, total_classes, class_start_idx, n_cols, # shapes @@ -229,18 +229,18 @@ def forward( losses += lse if lse_square_scale != 0.0: z_losses = lse_square_scale * lse.square() - z_losses.masked_fill_(labels == ignored_index, 0.0) + z_losses.masked_fill_(labels == ignore_index, 0.0) losses += z_losses else: z_losses = torch.zeros_like(losses) - losses.masked_fill_(labels == ignored_index, 0.0) + losses.masked_fill_(labels == ignore_index, 0.0) ctx.save_for_backward(logits, lse, labels) ctx.mark_non_differentiable(z_losses) ctx.smoothing = smoothing ctx.logit_scale = logit_scale ctx.lse_square_scale = lse_square_scale - ctx.ignored_index = ignored_index + ctx.ignore_index = ignore_index ctx.total_classes = total_classes ctx.class_start_idx = class_start_idx ctx.inplace_backward = inplace_backward @@ -269,7 +269,7 @@ def backward(ctx, grad_losses, grad_z_losses): ctx.smoothing, ctx.logit_scale, ctx.lse_square_scale, - ctx.ignored_index, + ctx.ignore_index, ctx.total_classes, ctx.class_start_idx, n_cols, # shapes @@ -287,7 +287,7 @@ def cross_entropy_loss( label_smoothing: float = 0.0, logit_scale: float = 1.0, lse_square_scale: float = 0.0, - ignored_index=-100, + ignore_index=-100, inplace_backward: bool = False, process_group=None, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -299,7 +299,7 @@ def cross_entropy_loss( logit_scale: float. Multiply logits by this scale before calculating the loss. lse_square_scale: float. If > 0, we add lse_square_scale * lse(logits) ^ 2 to the loss. This is also referred to as "z-loss". 
- ignored_index: int. If labels == ignored_index, the loss is set to 0.0. + ignore_index: int. If labels == ignore_index, the loss is set to 0.0. inplace_backward: bool. If True, we do the backward pass in-place by modifying the logits. This saves memory. process_group: if not None, we're doing Tensor Parallel: each process is responsible for @@ -314,7 +314,7 @@ def cross_entropy_loss( label_smoothing, logit_scale, lse_square_scale, - ignored_index, + ignore_index, inplace_backward, process_group, ) From 35060e74504d9d555492028d5492f9c3f2f02a41 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 26 Apr 2024 10:53:24 -0700 Subject: [PATCH 07/17] [CI] Compile for pytorch 2.2.2 and 2.3.0 --- .github/workflows/publish.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 2413d3e96..f75e41e21 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -44,7 +44,7 @@ jobs: # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] - torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240207'] + torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.2', '2.3.0'] cuda-version: ['11.8.0', '12.2.2'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. @@ -61,9 +61,9 @@ jobs: python-version: '3.7' - torch-version: '2.1.2' python-version: '3.7' - - torch-version: '2.2.0' + - torch-version: '2.2.2' python-version: '3.7' - - torch-version: '2.3.0.dev20240207' + - torch-version: '2.3.0' python-version: '3.7' # Pytorch <= 2.0 only supports CUDA <= 11.8 - torch-version: '1.12.1' From 9a11f440d3a34f618b4ba814c825b109c6d7e8f5 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Fri, 26 Apr 2024 10:54:52 -0700 Subject: [PATCH 08/17] Bump to v2.5.8 --- flash_attn/__init__.py | 2 +- training/Dockerfile | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 2cb147527..7b26bc096 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.5.7" +__version__ = "2.5.8" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/training/Dockerfile b/training/Dockerfile index 33b396b5b..a4d12bd76 100644 --- a/training/Dockerfile +++ b/training/Dockerfile @@ -85,7 +85,7 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 # Install FlashAttention -RUN pip install flash-attn==2.5.7 +RUN pip install flash-attn==2.5.8 # Install CUDA extensions for fused dense -RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.7#subdirectory=csrc/fused_dense_lib +RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib From 9c0e9ee86d0e0022b60deddb405c20ab77481582 Mon Sep 17 00:00:00 2001 From: Wei Ji <23487320+weiji14@users.noreply.github.com> Date: Tue, 7 May 2024 04:45:54 +1200 Subject: [PATCH 09/17] Move packaging and ninja from install_requires to setup_requires (#937) Set `packaging` and `ninja` as build time dependencies rather than runtime dependencies. 
--- setup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 54f88dbfd..f872b8e42 100644 --- a/setup.py +++ b/setup.py @@ -344,10 +344,10 @@ def __init__(self, *args, **kwargs) -> None: install_requires=[ "torch", "einops", - "packaging", - "ninja", ], setup_requires=[ - "psutil" + "packaging", + "psutil", + "ninja", ], ) From 22339db185027324f334a7f59e2584da266bfd4c Mon Sep 17 00:00:00 2001 From: lancerts Date: Thu, 23 May 2024 11:12:31 -0700 Subject: [PATCH 10/17] remove an unused import (#960) --- flash_attn/ops/triton/cross_entropy.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/flash_attn/ops/triton/cross_entropy.py b/flash_attn/ops/triton/cross_entropy.py index 1f895d7db..178233813 100644 --- a/flash_attn/ops/triton/cross_entropy.py +++ b/flash_attn/ops/triton/cross_entropy.py @@ -4,8 +4,6 @@ import torch -from einops import rearrange - import triton import triton.language as tl From beb8b8ba9f69475a3fe97a076825f3eea5f537b4 Mon Sep 17 00:00:00 2001 From: Corey James Levinson Date: Sun, 26 May 2024 15:33:03 -0400 Subject: [PATCH 11/17] add exception to Timeout Error (#963) When the connection times out, urllib raises `URLError` rather than `HTTPError`; in that case, build from source. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f872b8e42..edbe13b02 100644 --- a/setup.py +++ b/setup.py @@ -282,7 +282,7 @@ def run(self): wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") print("Raw wheel path", wheel_path) os.rename(wheel_filename, wheel_path) - except urllib.error.HTTPError: + except (urllib.error.HTTPError, urllib.error.URLError): print("Precompiled wheel not found. Building from source...") # If the wheel could not be downloaded, build from source super().run() From 40e667236ce9e2b80513b9bf8d1fe93960c322d7 Mon Sep 17 00:00:00 2001 From: Wongboo <44860323+Wongboo@users.noreply.github.com> Date: Mon, 27 May 2024 03:34:49 +0800 Subject: [PATCH 12/17] Update for python3.12 (#870) --- .github/workflows/publish.yml | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index f75e41e21..ab1cce822 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -43,7 +43,7 @@ jobs: # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.2', '2.3.0'] cuda-version: ['11.8.0', '12.2.2'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
@@ -53,6 +53,15 @@ jobs: cxx11_abi: ['FALSE', 'TRUE'] exclude: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix + # Pytorch < 2.2 does not support Python 3.12 + - torch-version: '1.12.1' + python-version: '3.12' + - torch-version: '1.13.1' + python-version: '3.12' + - torch-version: '2.0.1' + python-version: '3.12' + - torch-version: '2.1.2' + python-version: '3.12' # Pytorch <= 1.12 does not support Python 3.11 - torch-version: '1.12.1' python-version: '3.11' @@ -123,6 +132,8 @@ jobs: # If we don't install before installing Pytorch, we get error for torch 2.0.1 # ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none) pip install lit + # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools + pip install setuptools # We want to figure out the CUDA version to download pytorch # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix From af627063e3387e4f7517e0eb8cf428ae912a300c Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 26 May 2024 12:41:17 -0700 Subject: [PATCH 13/17] [CI] Compile for pytorch 2.4.0.dev20240407 (for nvcr 24.05) --- .github/workflows/publish.yml | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index ab1cce822..88aa16768 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -44,7 +44,7 @@ jobs: # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] - torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.2', '2.3.0'] + torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.2', '2.3.0', '2.4.0.dev20240407'] cuda-version: ['11.8.0', '12.2.2'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. @@ -74,6 +74,8 @@ jobs: python-version: '3.7' - torch-version: '2.3.0' python-version: '3.7' + - torch-version: '2.4.0.dev20240407' + python-version: '3.7' # Pytorch <= 2.0 only supports CUDA <= 11.8 - torch-version: '1.12.1' cuda-version: '12.2.2' @@ -139,18 +141,12 @@ jobs: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix # This code is ugly, maybe there's a better way to do this. 
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ - minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118}[env['MATRIX_TORCH_VERSION']]; \ - maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121}[env['MATRIX_TORCH_VERSION']]; \ + minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \ + maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}[env['MATRIX_TORCH_VERSION']]; \ print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \ ) if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then - if [[ ${MATRIX_TORCH_VERSION} == "2.2" ]]; then - # --no-deps because we can't install old versions of pytorch-triton - pip install typing-extensions jinja2 - pip install --no-cache-dir --no-deps --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl - else - pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} - fi + pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} else pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} fi From d732be1e67ec517572e85360248ac3c6d2cc2ae8 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 26 May 2024 12:49:33 -0700 Subject: [PATCH 14/17] Update to Cutlass 3.5 --- csrc/cutlass | 2 +- csrc/flash_attn/src/flash_bwd_kernel.h | 2 +- csrc/flash_attn/src/flash_bwd_preprocess_kernel.h | 2 +- csrc/flash_attn/src/flash_fwd_kernel.h | 2 +- csrc/flash_attn/src/kernel_traits.h | 2 +- csrc/flash_attn/src/rotary.h | 2 +- csrc/flash_attn/src/utils.h | 3 +-- 7 files changed, 7 insertions(+), 8 deletions(-) diff --git a/csrc/cutlass b/csrc/cutlass index bbe579a9e..7d49e6c7e 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit bbe579a9e3beb6ea6626d9227ec32d0dae119a49 +Subproject commit 7d49e6c7e2f8896c47f586706e67e1fb215529dc diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 6f89c2137..7d35209c0 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -4,7 +4,7 @@ #pragma once -#include +#include #include #include diff --git a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h index 6582d814c..aa0641530 100644 --- a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h @@ -4,7 +4,7 @@ #pragma once -#include +#include #include #include diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index 104e16419..fd68cec12 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -4,7 +4,7 @@ #pragma once -#include +#include #include #include diff --git a/csrc/flash_attn/src/kernel_traits.h b/csrc/flash_attn/src/kernel_traits.h index a7a5cf1ed..5a7b74911 100644 --- a/csrc/flash_attn/src/kernel_traits.h +++ b/csrc/flash_attn/src/kernel_traits.h @@ -4,7 +4,7 @@ #pragma once -#include "cute/algorithm/copy.hpp" +#include "cute/tensor.hpp" #include "cutlass/cutlass.h" #include "cutlass/layout/layout.h" diff --git 
a/csrc/flash_attn/src/rotary.h b/csrc/flash_attn/src/rotary.h index dc2825be7..7f1614ad2 100644 --- a/csrc/flash_attn/src/rotary.h +++ b/csrc/flash_attn/src/rotary.h @@ -4,7 +4,7 @@ #pragma once -#include +#include #include "utils.h" diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index 2b45e87b2..708aeddfa 100644 --- a/csrc/flash_attn/src/utils.h +++ b/csrc/flash_attn/src/utils.h @@ -14,8 +14,7 @@ #include #endif -#include -#include +#include #include #include From ce7350357869bab7a2d8665c37bdf326c9e98b61 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 26 May 2024 14:02:11 -0700 Subject: [PATCH 15/17] Bump to 2.5.9 --- flash_attn/__init__.py | 2 +- training/Dockerfile | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 7b26bc096..a461e8ac6 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.5.8" +__version__ = "2.5.9" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/training/Dockerfile b/training/Dockerfile index a4d12bd76..2c68bd1ea 100644 --- a/training/Dockerfile +++ b/training/Dockerfile @@ -85,7 +85,7 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 # Install FlashAttention -RUN pip install flash-attn==2.5.8 +RUN pip install flash-attn==2.5.9 # Install CUDA extensions for fused dense -RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib +RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.9#subdirectory=csrc/fused_dense_lib From e2e4333c955b829d0e6087d27ee435f55c80d3a5 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 26 May 2024 15:35:21 -0700 Subject: [PATCH 16/17] Limit to MAX_JOBS=1 with CUDA 12.2 --- .github/workflows/publish.yml | 3 ++- flash_attn/__init__.py | 2 +- training/Dockerfile | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 88aa16768..020c1371a 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -168,7 +168,8 @@ jobs: export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH # Limit MAX_JOBS otherwise the github runner goes OOM - MAX_JOBS=2 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist + # CUDA 11.8 can compile with 2 jobs, but CUDA 12.2 goes OOM + MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "122" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }} wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index a461e8ac6..242022d6a 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.5.9" +__version__ = "2.5.9.post1" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/training/Dockerfile b/training/Dockerfile index 2c68bd1ea..0baec9278 100644 --- a/training/Dockerfile +++ b/training/Dockerfile @@ -85,7 +85,7 @@ RUN pip install transformers==4.25.1 
datasets==2.8.0 pytorch-lightning==1.8.6 tr RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 # Install FlashAttention -RUN pip install flash-attn==2.5.9 +RUN pip install flash-attn==2.5.9.post1 # Install CUDA extensions for fused dense -RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.9#subdirectory=csrc/fused_dense_lib +RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.9.post1#subdirectory=csrc/fused_dense_lib From 320fb59487658f033f56711efd3d61b7c7a6f8f3 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 26 May 2024 16:09:03 -0700 Subject: [PATCH 17/17] Update citation --- README.md | 7 ++++--- flash_attn/utils/generation.py | 7 ++++++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index aaab7e1f1..53c9fd3af 100644 --- a/README.md +++ b/README.md @@ -400,12 +400,13 @@ If you use this codebase, or otherwise found our work valuable, please cite: @inproceedings{dao2022flashattention, title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness}, author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, - booktitle={Advances in Neural Information Processing Systems}, + booktitle={Advances in Neural Information Processing Systems (NeurIPS)}, year={2022} } -@article{dao2023flashattention2, +@inproceedings{dao2023flashattention2, title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning}, author={Dao, Tri}, - year={2023} + booktitle={International Conference on Learning Representations (ICLR)}, + year={2024} } ``` diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py index d5d113903..0d9120c38 100644 --- a/flash_attn/utils/generation.py +++ b/flash_attn/utils/generation.py @@ -12,7 +12,12 @@ from einops import rearrange, repeat from torch import Tensor from torch.profiler import ProfilerActivity, profile, record_function -from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput + +try: + from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput +except ImportError: + GreedySearchDecoderOnlyOutput = namedtuple("GreedySearchDecoderOnlyOutput", ["sequences", "scores"]) + SampleDecoderOnlyOutput = namedtuple("SampleDecoderOnlyOutput", ["sequences", "scores"]) @dataclass
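
A note on the final hunk in PATCH 17: wrapping the `transformers` import in `try`/`except ImportError` makes `transformers` an optional dependency of `flash_attn/utils/generation.py`, with `namedtuple` stand-ins that only expose the two fields the generation code actually reads (`sequences` and `scores`). The following is a minimal, self-contained sketch of that pattern, not part of the patch; the toy usage at the end is purely illustrative.

```python
# Sketch of the optional-import fallback used above. If transformers is
# installed, its output classes are used; otherwise lightweight namedtuples
# with the same two fields stand in for them.
from collections import namedtuple

try:
    from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
except ImportError:
    GreedySearchDecoderOnlyOutput = namedtuple(
        "GreedySearchDecoderOnlyOutput", ["sequences", "scores"]
    )
    SampleDecoderOnlyOutput = namedtuple("SampleDecoderOnlyOutput", ["sequences", "scores"])

# Either way, callers construct and read the result the same way:
out = GreedySearchDecoderOnlyOutput(sequences=[[1, 2, 3]], scores=None)
print(out.sequences, out.scores)
```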