Update CUTLASS to v3.5.0
ghstack-source-id: 887e806ca45188df99edb4265efab1aae1714608
Pull Request resolved: fairinternal/xformers#1147

__original_commit__ = fairinternal/xformers@6f0c5d0
lw authored and xFormers Bot committed Jul 2, 2024
1 parent b63143b commit 89b9128
Showing 7 changed files with 20 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/conda.yml
@@ -19,7 +19,7 @@ on:

 env:
   # you need at least cuda 5.0 for some of the stuff compiled here.
-  TORCH_CUDA_ARCH_LIST: "5.0+PTX 6.0 6.1 7.0 7.5 8.0+PTX"
+  TORCH_CUDA_ARCH_LIST: "6.0+PTX 6.1 7.0 7.5 8.0+PTX"
   MAX_JOBS: 3 # Avoids OOMs
   XFORMERS_BUILD_TYPE: "Release"
   XFORMERS_PACKAGE_FROM: "conda-${{ github.ref_name }}"
2 changes: 1 addition & 1 deletion .github/workflows/wheels_build.yml
@@ -26,7 +26,7 @@ on:

 env:
   # you need at least cuda 5.0 for some of the stuff compiled here.
-  TORCH_CUDA_ARCH_LIST: "5.0+PTX 6.0 6.1 7.0 7.5 8.0+PTX"
+  TORCH_CUDA_ARCH_LIST: "6.0+PTX 6.1 7.0 7.5 8.0+PTX"
   MAX_JOBS: 4
   DISTUTILS_USE_SDK: 1 # otherwise distutils will complain on windows about multiple versions of msvc
   XFORMERS_BUILD_TYPE: "Release"
6 changes: 5 additions & 1 deletion setup.py
@@ -316,7 +316,9 @@ def get_extensions():
     extra_compile_args = {"cxx": ["-O3", "-std=c++17"]}
     if sys.platform == "win32":
         define_macros += [("xformers_EXPORTS", None)]
-        extra_compile_args["cxx"].extend(["/MP", "/Zc:lambda", "/Zc:preprocessor"])
+        extra_compile_args["cxx"].extend(
+            ["/MP", "/Zc:lambda", "/Zc:preprocessor", "/Zc:__cplusplus"]
+        )
     elif "OpenMP not found" not in torch.__config__.parallel_info():
         extra_compile_args["cxx"].append("-fopenmp")

@@ -360,6 +362,8 @@ def get_extensions():
                 "/Zc:lambda",
                 "-Xcompiler",
                 "/Zc:preprocessor",
+                "-Xcompiler",
+                "/Zc:__cplusplus",
             ]
         extra_compile_args["nvcc"] = nvcc_flags
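For context, a minimal sketch (not part of this commit) of why the extra MSVC flag is needed: by default MSVC keeps __cplusplus at 199711L even under /std:c++17, so language-version guards of the kind newer CUTLASS headers presumably use would reject the compiler; /Zc:__cplusplus makes the macro report the real language version. The guard below is an illustrative assumption, not copied from CUTLASS.

    // Illustrative only: a C++17 guard that fails when MSVC reports
    // __cplusplus as 199711L (its default without /Zc:__cplusplus).
    #if __cplusplus < 201703L
    #error "This header requires C++17"
    #endif

    #include <cstdio>

    int main() {
      // Build with e.g.: cl /std:c++17 /Zc:__cplusplus example.cpp
      std::printf("__cplusplus = %ldL\n", static_cast<long>(__cplusplus));
    }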
2 changes: 1 addition & 1 deletion third_party/cutlass
Submodule cutlass updated 2032 files
2 changes: 1 addition & 1 deletion xformers/csrc/attention/cuda/fmha/kernel_backward.h
@@ -1441,7 +1441,7 @@ struct AttentionBackwardKernel {
       uint8_t lane_id) {
     cutlass::Array<cutlass::uint1b_t, MatmulDOIVJ::Mma::FragmentC::kElements>
         dropout_keep_mask_doivj;
-    dropout_keep_mask_doivj.fill(1);
+    dropout_keep_mask_doivj.fill(cutlass::uint1b_t{1});
     const float dropout_scale =
         kApplyDropout ? 1.0 / (1.0 - p.dropout_prob) : 1.0f;
6 changes: 3 additions & 3 deletions xformers/csrc/sparse24/compute_sparse_tile.h
@@ -85,8 +85,8 @@ struct LargestValuesGreedy {
     for (int j = 0; j < 4; ++j) {
       TileValueOrdered& v = values_ordered[i * 4 + j];
       v.parts.value = values.at(i, j).get();
-      v.parts.col = j;
-      v.parts.row = i;
+      v.parts.col = uint2b_t{j};
+      v.parts.row = uint2b_t{i};
     }
   }
   // Use a sorting network (aka without branches) to avoid

@@ -150,7 +150,7 @@ struct Causal1122 {
     for (int col = 0; col < 4; ++col) {
       TileValueOrdered& v = values_ordered[col];
       v.parts.value = values.at(row, col).get();
-      v.parts.col = col;
+      v.parts.col = uint2b_t{col};
     }
     // Use a sorting network (aka without branches) to avoid
     // warp divergence
16 changes: 8 additions & 8 deletions xformers/csrc/sparse24/sparse24_pack.h
@@ -77,9 +77,9 @@ warp_shuffle_meta(uint32_t meta_ab, bool transposed = false) {
       uint8b_t(meta_ab >> (8 * (thread_left + 2)))};
   // shfl t0-t4 / t1-t5
   stage0_data[0] =
-      __shfl_xor_sync(0xffffffff, stage0_data[0], transposed ? 1 : 4);
+      uint8b_t{__shfl_xor_sync(0xffffffff, stage0_data[0], transposed ? 1 : 4)};
   stage0_data[1] =
-      __shfl_xor_sync(0xffffffff, stage0_data[1], transposed ? 1 : 4);
+      uint8b_t{__shfl_xor_sync(0xffffffff, stage0_data[1], transposed ? 1 : 4)};

   uint16_t line0 = int(uint8b_t(meta_ab >> (8 * (1 - thread_left))))
       << ((1 - thread_left) * 8);

@@ -229,24 +229,24 @@ struct KernelTypes {
     // We know that col0 is always packed to position 0 if it's there
     // and col1 is packed to pos 0 or 1 (depending if col0 is selected)
     if (isSelected(1)) {
-      packValue(0, 1);
+      packValue(uint2b_t{0}, uint2b_t{1});
     }
     if (isSelected(0)) {
-      packValue(0, 0);
+      packValue(uint2b_t{0}, uint2b_t{0});
     }
     if (isSelected(0) && isSelected(1)) {
-      packValue(1, 1);
+      packValue(uint2b_t{1}, uint2b_t{1});
     }
     // Process cols 2/3
     // same sort of heuristic
     if (isSelected(2)) {
-      packValue(1, 2);
+      packValue(uint2b_t{1}, uint2b_t{2});
     }
     if (isSelected(3)) {
-      packValue(1, 3);
+      packValue(uint2b_t{1}, uint2b_t{3});
     }
     if (isSelected(2) && isSelected(3)) {
-      packValue(0, 2);
+      packValue(uint2b_t{0}, uint2b_t{2});
     }
     int add_mask = (col0_from | (col1_from << 2)) << (8 * row + meta_pos);
     meta |= add_mask;
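The C++/CUDA changes above follow one pattern: values assigned to CUTLASS sub-byte integer types (uint1b_t, uint2b_t, uint8b_t) are now wrapped in an explicit constructor call, since the updated headers appear to no longer accept a plain int implicitly at these call sites. A self-contained sketch with a hypothetical stand-in type (not the real CUTLASS classes) illustrating the call-site change:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-in for a CUTLASS-style sub-byte integer wrapper;
    // the explicit constructor forbids implicit conversion from int.
    struct uint2b_t {
      uint8_t bits;
      explicit constexpr uint2b_t(int v) : bits(static_cast<uint8_t>(v & 0x3)) {}
    };

    // Hypothetical helper mirroring the packValue(...) call sites in the diff.
    void packValue(uint2b_t dst_pos, uint2b_t src_col) {
      std::printf("pack dst=%u src=%u\n",
                  static_cast<unsigned>(dst_pos.bits),
                  static_cast<unsigned>(src_col.bits));
    }

    int main() {
      // packValue(0, 1);                   // no longer compiles: constructor is explicit
      packValue(uint2b_t{0}, uint2b_t{1});  // explicit construction, as in this commit
    }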
