Skip to content

Commit

Permalink
Update num_splits heuristics
Browse files Browse the repository at this point in the history
  • Loading branch information
poyenc committed Dec 17, 2024
1 parent 3c655c5 commit e5c5435
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 37 deletions.
51 changes: 44 additions & 7 deletions csrc/flash_attn_ck/flash_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,65 @@
#include "flash_common.hpp"

namespace flash {
int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
int override_num_splits_if_necessary(int batch,
int nhead,
int max_seqlen_q,
int hdim_q,
int hdim_v,
float p_drop,
bool is_prefill,
int num_splits)
{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
{
return num_splits;
}

hipDeviceProp_t props{};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess)
{
return num_splits;
}

// TODO - tile size should match the TileFmhaShape, hardcode for now
const int kM0 = 128;
const int kN1 = hdim_v;
const int kM0 = [&] {
// get kM0 for prefill phase
if(is_prefill)
{
return 128;
}

// get kM0 for decode phase
/// TODO: take dtype=fp8/bf8 into consideration
const std::map<int, int> hdim_to_m0 = {
{32, 32},
{64, 64},
// {96, 64},
{128, 64},
{256, 64},
};

for(auto [hdim, m0] : hdim_to_m0)
{
if(hdim_q <= hdim && hdim_v <= hdim)
{
return m0;
}
}

return 64; // meet unsupported hdim_q/hdim_v
}();
// const int kN1 = hdim_v;

const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const int num_n_blocks = (hdim_v + kN1 - 1) / kN1;
// const int num_n_blocks = (hdim_v + kN1 - 1) / kN1; // always 1

if(num_splits < 1 && p_drop == 0.0f)
return num_splits_heuristic_ck(
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
{
return num_splits_heuristic_ck(batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 8);
}

return num_splits;
}
Expand Down
63 changes: 35 additions & 28 deletions csrc/flash_attn_ck/flash_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,42 +35,49 @@ inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* r
}
}

inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
inline int num_splits_heuristic_ck(int batch_nhead_mblocks, int num_SMs, int max_splits)
{
// If we have enough to almost fill the SMs, then just use 1 split
if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
max_splits = std::min({max_splits, num_SMs, num_n_blocks});
if(batch_nhead_mblocks >= 0.8f * num_SMs)
{
return 1;
}

max_splits = std::min({max_splits, num_SMs});

constexpr std::array<int, 5> num_splits_array = {1, 2, 4, 8, 16};

float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
// (i.e. it's 11 splits anyway).
// So we check if the number of blocks per split is the same as the previous num_splits.
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
};
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
if (!is_split_eligible(num_splits)) {
efficiency.push_back(0.f);
} else {
float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
float eff = n_waves / ceil(n_waves);
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
if (eff > max_efficiency) { max_efficiency = eff; }
efficiency.push_back(eff);
std::array<float, num_splits_array.size()> efficiency;

for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx)
{
float n_blocks = float(batch_nhead_mblocks * num_splits_array[idx]) / num_SMs;
float eff = n_blocks / std::ceil(n_blocks);

if(eff > max_efficiency)
{
max_efficiency = eff;
}
efficiency[idx] = eff;
}
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
if (!is_split_eligible(num_splits)) { continue; }
if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
// printf("num_splits chosen = %d\n", num_splits);
return num_splits;
for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx)
{
if(efficiency[idx] >= 0.85 * max_efficiency)
{
return num_splits_array[idx];
}
}
return 1;
}

int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits);
int override_num_splits_if_necessary(int batch,
int nhead,
int max_seqlen_q,
int hdim_q,
int hdim_v,
float p_drop,
bool is_prefill,
int num_splits);

} // namespace flash
3 changes: 2 additions & 1 deletion csrc/flash_attn_ck/mha_fwd_kvcache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_siz
TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
}

num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, seqlen_q, head_size_8x, 0, num_splits);
num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, seqlen_q, head_size_8x, head_size_8x,
/*p_drop=*/0, /*is_prefill=*/false, num_splits);
TORCH_CHECK(num_splits > 0, "num_splits should greater than 0");
TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported");

Expand Down
3 changes: 2 additions & 1 deletion csrc/flash_attn_ck/mha_varlen_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
}

int num_splits = 0;
num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size, 0, num_splits);
num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size, head_size,
/*p_drop=*/0, /*is_prefill=*/true, num_splits);
TORCH_CHECK(num_splits > 0, "num_splits should greater than 0");
TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported");

Expand Down

0 comments on commit e5c5435

Please sign in to comment.