Merge pull request #55 from Dao-AILab/main
update to upstream
rocking5566 authored May 28, 2024
2 parents 23e8fa5 + 320fb59 commit bc191ef
Showing 18 changed files with 167 additions and 122 deletions.
36 changes: 22 additions & 14 deletions .github/workflows/publish.yml
@@ -43,8 +43,8 @@ jobs:
# Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ubuntu-20.04]
- python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
- torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240105']
+ python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12']
+ torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.2', '2.3.0', '2.4.0.dev20240407']
cuda-version: ['11.8.0', '12.2.2']
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
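A quick way to tell which of the two wheel flavors a given environment needs (not part of this workflow; assumes any working PyTorch install): PyTorch reports the ABI it was compiled with.

```python
import torch

# True typically means a build compiled with -D_GLIBCXX_USE_CXX11_ABI=1 (e.g. nvcr/NGC images);
# False matches the stock PyPI wheels mentioned in the comment above.
print(torch.compiled_with_cxx11_abi())
```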
@@ -53,6 +53,15 @@ jobs:
cxx11_abi: ['FALSE', 'TRUE']
exclude:
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
+ # Pytorch < 2.2 does not support Python 3.12
+ - torch-version: '1.12.1'
+ python-version: '3.12'
+ - torch-version: '1.13.1'
+ python-version: '3.12'
+ - torch-version: '2.0.1'
+ python-version: '3.12'
+ - torch-version: '2.1.2'
+ python-version: '3.12'
# Pytorch <= 1.12 does not support Python 3.11
- torch-version: '1.12.1'
python-version: '3.11'
@@ -61,9 +70,11 @@ jobs:
python-version: '3.7'
- torch-version: '2.1.2'
python-version: '3.7'
- - torch-version: '2.2.0'
+ - torch-version: '2.2.2'
python-version: '3.7'
+ - torch-version: '2.3.0'
+ python-version: '3.7'
- - torch-version: '2.3.0.dev20240105'
+ - torch-version: '2.4.0.dev20240407'
python-version: '3.7'
# Pytorch <= 2.0 only supports CUDA <= 11.8
- torch-version: '1.12.1'
@@ -123,23 +134,19 @@ jobs:
# If we don't install before installing Pytorch, we get error for torch 2.0.1
# ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
pip install lit
+ # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools
+ pip install setuptools
# We want to figure out the CUDA version to download pytorch
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# This code is ugly, maybe there's a better way to do this.
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
- minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118}[env['MATRIX_TORCH_VERSION']]; \
- maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121}[env['MATRIX_TORCH_VERSION']]; \
+ minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \
+ maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}[env['MATRIX_TORCH_VERSION']]; \
print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
)
if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
- if [[ ${MATRIX_TORCH_VERSION} == "2.2" ]]; then
- # --no-deps because we can't install old versions of pytorch-triton
- pip install typing-extensions jinja2
- pip install --no-cache-dir --no-deps --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
- else
- pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
- fi
+ pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
else
pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
fi
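For reference, a standalone Python restatement of the TORCH_CUDA_VERSION selection above (tables copied from the diff; the example inputs are illustrative): the system CUDA version is clamped into the range of CUDA builds published for that torch minor version, so CUDA 12.2 with torch 2.3 falls back to the cu121 wheel index.

```python
# Illustrative restatement of the TORCH_CUDA_VERSION computation in the workflow step above.
# Inputs mirror MATRIX_CUDA_VERSION (e.g. "122") and MATRIX_TORCH_VERSION (e.g. "2.3").
minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}
maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}

def torch_cuda_version(cuda: str, torch_ver: str) -> int:
    # Clamp the system CUDA version into [minv, maxv] for the given torch minor version.
    return max(min(int(cuda), maxv[torch_ver]), minv[torch_ver])

print(torch_cuda_version("122", "2.3"))   # 121 -> install from whl/cu121
print(torch_cuda_version("118", "1.12"))  # 116 -> install from whl/cu116
```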
@@ -161,7 +168,8 @@ jobs:
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Limit MAX_JOBS otherwise the github runner goes OOM
- MAX_JOBS=2 FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
+ # CUDA 11.8 can compile with 2 jobs, but CUDA 12.2 goes OOM
+ MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "122" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist
tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
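The sed expression above splices the local build tag into the wheel filename by replacing the second '-'. A small Python sketch with a hypothetical wheel name (the 2.5.8 version string is made up for illustration):

```python
# Hypothetical example of the renaming done by `sed "s/-/+$tmpname-/2"` above.
wheel = "flash_attn-2.5.8-cp310-cp310-linux_x86_64.whl"   # illustrative filename, not from this commit
tmpname = "cu122torch2.3cxx11abiFALSE"

name, version, rest = wheel.split("-", 2)        # split at the first two dashes only
renamed = f"{name}-{version}+{tmpname}-{rest}"   # the second dash becomes "+<tag>-"
print(renamed)  # flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
```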
7 changes: 4 additions & 3 deletions README.md
@@ -400,12 +400,13 @@ If you use this codebase, or otherwise found our work valuable, please cite:
@inproceedings{dao2022flashattention,
title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
- booktitle={Advances in Neural Information Processing Systems},
+ booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
year={2022}
}
- @article{dao2023flashattention2,
+ @inproceedings{dao2023flashattention2,
title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
author={Dao, Tri},
- year={2023}
+ booktitle={International Conference on Learning Representations (ICLR)},
+ year={2024}
}
```
2 changes: 1 addition & 1 deletion csrc/cutlass
Submodule cutlass updated 548 files
16 changes: 12 additions & 4 deletions csrc/flash_attn/flash_api.cpp
@@ -282,7 +282,8 @@ void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
params.num_splits = num_splits;
if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
if (num_splits < 1) {
- params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
+ // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
+ params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
}
if (params.num_splits > 1) {
at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
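A rough, illustrative view of why the SM count is doubled in this change (this is not the num_splits_heuristic implementation, which is outside this diff): with 128-thread blocks the assumption is roughly two resident blocks per SM, so the heuristic is told it has twice the SM count in block slots to fill when choosing the number of KV splits.

```python
# Back-of-the-envelope occupancy arithmetic only; not the actual num_splits_heuristic.
# Assumption hard-coded by the "* 2" above: 128-thread blocks, about two resident blocks per SM.
num_sms = 108                     # e.g. an A100-sized GPU (illustrative)
block_slots = num_sms * 2         # the SM budget passed to num_splits_heuristic in the change above
batch_nheads_mblocks = 8          # a small decode-style workload (illustrative)

for num_splits in (1, 8, 27):
    blocks = batch_nheads_mblocks * num_splits
    print(f"num_splits={num_splits:2d}: {min(blocks / block_slots, 1.0):.0%} of block slots used")
```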
@@ -372,8 +373,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
- const int ngroups = num_heads / num_heads_k;
if (seqlenq_ngroups_swapped) {
+ const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
seqlen_q = ngroups;
num_heads = num_heads_k;
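For context on the comment above, a small PyTorch sketch of the shape trick with hypothetical decode-time GQA sizes (not code from this repository): a single query position with many query heads is reinterpreted as ngroups query positions with nheads_kv heads, which is what the reshape/transpose in this hunk does.

```python
import torch

# Hypothetical decode step with grouped-query attention: seqlen_q == 1, more q heads than kv heads.
batch_size, num_heads, num_heads_k, head_size = 2, 32, 8, 128
ngroups = num_heads // num_heads_k                                     # 4

q = torch.randn(batch_size, 1, num_heads, head_size)                   # (b, 1, nheads_kv * ngroups, d)
q_swapped = q.reshape(batch_size, num_heads_k, ngroups, head_size).transpose(1, 2)
print(q_swapped.shape)                                                  # (2, 4, 8, 128) = (b, ngroups, nheads_kv, d)
# The kernel then runs with seqlen_q = ngroups and num_heads = num_heads_k,
# matching the two assignments in the hunk above.
```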
@@ -400,7 +401,10 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
- CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
+ CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
+ if (seqlenq_ngroups_swapped) {
+ out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
+ }
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
} else {
out = torch::empty_like(q_padded);
@@ -571,8 +575,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
- const int ngroups = num_heads / num_heads_k;
if (seqlenq_ngroups_swapped) {
+ const int ngroups = num_heads / num_heads_k;
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
max_seqlen_q = ngroups;
num_heads = num_heads_k;
@@ -627,6 +631,10 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
CHECK_DEVICE(out);
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
- CHECK_SHAPE(out, total_q, num_heads, head_size_og);
+ CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
+ if (seqlenq_ngroups_swapped) {
+ out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
+ }
if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
} else {
out = torch::empty_like(q_padded);
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/flash_bwd_kernel.h
@@ -4,7 +4,7 @@

#pragma once

- #include <cute/algorithm/copy.hpp>
+ #include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/flash_bwd_preprocess_kernel.h
@@ -4,7 +4,7 @@

#pragma once

- #include <cute/algorithm/copy.hpp>
+ #include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
