SpecInfer fix corner case (#1258)
* init

* fix speculative

* fix speculative

* bitmap+tree verify

* fix.

* fix

* multi batch

* copy metadata once

* fix some corner cases

* Replicate load_token tasks so that they can be fused with other compute tasks; this eliminates Replicate and enables a larger fused op

* more fix.

* clean up

* .

* load batchconfig

* clean

* hip

* hip

* embedding return when no token

* use arg topk instead of beam topk

* embedding

* fmt

* hip

* fix corner case

---------

Co-authored-by: Zhihao Jia <zhihao@cmu.edu>
xinhaoc and jiazhihao authored Jan 2, 2024
1 parent 25097e0 commit a45826e
Showing 21 changed files with 224 additions and 107 deletions.
14 changes: 10 additions & 4 deletions include/flexflow/batch_config.h
@@ -45,6 +45,7 @@ class BatchConfig {
int num_active_tokens() const;
static int max_requests_per_batch();
static int max_tokens_per_batch();
static int max_verify_tokens_per_batch();
static int max_sequence_length();
friend std::ostream &operator<<(std::ostream &os, BatchConfig const &bc);
void print() const;
@@ -72,6 +73,7 @@ class BatchConfig {

// request id in batch config:
int batch_config_request_id;
bool prompt_phase = false;
RequestGuid request_guid;
};
struct PerTokenInfo {
@@ -85,15 +87,15 @@ class BatchConfig {

// how many tokens come before the tree; every sub-request needs this part of
// the cache
int non_tree_cache_size;
int non_tree_cache_size = 0;

// current tree size
int tree_size;
int tree_size = 0;

int this_layer_size;
int this_layer_size = 0;

// input length-> prompt/root
int prompt_size;
int prompt_size = 0;
};

BitMask causalMask[MAX_NUM_REQUESTS];
@@ -145,9 +147,13 @@ class BeamSearchBatchConfig : public BatchConfig {
bool done() const;
int max_beam_depth_all_requests() const;
int current_depth_all_requests() const;
int get_speculative_request_num() const;

size_t beam_width;
size_t target_iterations;

// how many requests are in the speculative phase
int speculative_request_num = 0;
inline static int const MAX_BEAM_WIDTH = 3;
inline static int const MAX_BEAM_DEPTH = 8;

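The two new accessors declared above are thin. A minimal sketch of what their definitions might look like (the bodies live in batch_config.cc / request_manager.cc, which this page does not show, so the MAX_SPEC_TREE_TOKEN_NUM constant and the RequestManager indirection are assumptions):

// Sketch only, not the committed implementation.
// Assumes a per-request cap on speculative-tree tokens (e.g. a constant such as
// BatchConfig::MAX_SPEC_TREE_TOKEN_NUM), so a verify batch can hold every
// request's ordinary tokens plus its full speculation tree.
int RequestManager::get_max_verify_tokens_per_batch() {
  return get_max_tokens_per_batch() +
         BatchConfig::MAX_SPEC_TREE_TOKEN_NUM * get_max_requests_per_batch();
}

// The beam-search config simply exposes the counter the request manager fills
// in while building the batch.
int BeamSearchBatchConfig::get_speculative_request_num() const {
  return speculative_request_num;
}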
3 changes: 2 additions & 1 deletion include/flexflow/config.h
@@ -84,7 +84,8 @@ struct FFHandler {
sizeof(BeamSearchBatchConfig::beamTokenInfo) +
sizeof(BeamSearchBatchConfig::beamRequestsInfo) +
sizeof(BatchConfig::causalMask) +
sizeof(TreeVerifyBatchConfig::committed_tokens);
sizeof(TreeVerifyBatchConfig::committed_tokens) +
sizeof(BatchConfig::request_completed);
void *offload_reserve_space;
size_t offload_reserve_space_size;
DataType quantization_type;
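The reserved metadata region is one flat block; each consumer recovers a field by summing the sizes of everything laid out in front of it, in the same order as this sum. A sketch of that offset arithmetic for the field this commit appends, mirroring the pointer setup added to SpecIncMultiHeadSelfAttentionMeta later in this diff:

// Sketch: walking the metadata block the way the speculative attention meta does.
char *base = reinterpret_cast<char *>(handler.batch_config_metadata);
size_t off = sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) +
             sizeof(BeamSearchBatchConfig::beamTokenInfo) +
             sizeof(BeamSearchBatchConfig::beamRequestsInfo);
auto *causalMask = reinterpret_cast<BatchConfig::BitMask *>(base + off);
off += sizeof(BatchConfig::causalMask);
// request_completed is the region this commit adds to the reserved block.
bool *request_completed = reinterpret_cast<bool *>(base + off);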
1 change: 1 addition & 0 deletions include/flexflow/model.h
@@ -73,6 +73,7 @@ enum TaskIDs {
DROPOUT_BWD_TASK_ID,
EMBED_INIT_TASK_ID,
EMBED_FWD_TASK_ID,
EMBED_INF_TASK_ID,
EMBED_BWD_TASK_ID,
GATHER_INIT_TASK_ID,
GATHER_FWD_TASK_ID,
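EMBED_INF_TASK_ID registers a dedicated inference-time embedding task; per the "embedding return when no token" bullet in the commit message, its job includes bailing out when the batch carries no tokens instead of launching kernels over an empty input. A hedged sketch of that guard (the signature and BatchConfig::from_future call follow the usual FlexFlow inference-task pattern, but the committed body is not shown on this page):

// Illustrative sketch of the early-return corner case, not the committed code.
void Embedding::inference_task(Task const *task,
                               std::vector<PhysicalRegion> const &regions,
                               Context ctx,
                               Runtime *runtime) {
  BatchConfig const *bc = BatchConfig::from_future(task->futures[0]);
  if (bc->num_active_tokens() == 0) {
    // Empty batch: nothing to embed, so skip the lookup kernels entirely.
    return;
  }
  // ... normal embedding lookup over bc->num_active_tokens() tokens ...
}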
@@ -493,7 +493,7 @@ inline void smem_size_in_bytes_tree(int hidden_size_per_head,
}

// todo fix this
int max_qk_length = max_query_length * max_total_length;
int max_qk_length = max_query_length;

// The amount of shared memory needed to store the Q*K^T values in float.
size_t qk_sz = div_up(max_qk_length + 1, 4) * 16;
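The old bound sized the Q*K^T scratch as query length times total sequence length, which balloons shared memory for long contexts; the fix scales it with the query (tree) length alone. A quick standalone check of the qk_sz term using the div_up formula visible above, with hypothetical sizes:

// Standalone check of the shared-memory term (hypothetical sizes).
#include <cstdio>

constexpr int div_up(int a, int b) { return (a + b - 1) / b; }

int main() {
  int max_query_length = 64;    // e.g. tokens in the speculation tree
  int max_total_length = 2048;  // e.g. prompt plus generated tokens
  size_t qk_old = div_up(max_query_length * max_total_length + 1, 4) * 16;  // old bound
  size_t qk_new = div_up(max_query_length + 1, 4) * 16;                     // new bound
  printf("old qk_sz = %zu bytes, new qk_sz = %zu bytes\n", qk_old, qk_new);
  // old: 524304 bytes (~512 KiB, far beyond one SM's shared memory); new: 272 bytes.
  return 0;
}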
1 change: 1 addition & 0 deletions include/flexflow/ops/spec_inc_multihead_self_attention.h
@@ -142,6 +142,7 @@ class SpecIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta {
Realm::RegionInstance beam_search_reserve_inst;
BeamSearchBatchConfig::BeamSearchPerTokenInfo *beam_token_infos;
BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos;
bool *request_completed;
BatchConfig::BitMask *causalMask;
};

1 change: 1 addition & 0 deletions include/flexflow/ops/tree_inc_multihead_self_attention.h
@@ -147,6 +147,7 @@ class TreeIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta {
int num_active_tokens;
Realm::RegionInstance committed_token_reserve_inst;
TreeVerifyBatchConfig::CommittedTokensInfo *committed_token_infos;
bool *request_completed;
BatchConfig::BitMask *causalMask;
};

2 changes: 2 additions & 0 deletions include/flexflow/request_manager.h
@@ -103,6 +103,7 @@ class RequestManager {
int get_max_requests_per_batch();
void set_max_tokens_per_batch(int max_num_tokens);
int get_max_tokens_per_batch();
int get_max_verify_tokens_per_batch();
void set_max_sequence_length(int max_seq_length);
void push_spec_infer_tree_width(int tree_width);
int get_max_sequence_length();
@@ -113,6 +114,7 @@
std::string const &path);
void register_output_filepath(std::string const &);
void initBitMask(BatchConfig::BitMask &bitmask, int initLength);
void appendPendingRequest(BatchConfig::BitMask &bitmask, int initLength);
void appendBitMask(BatchConfig::BitMask &bitmask,
int newNodes,
int preBeamSize,
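appendPendingRequest complements initBitMask: where initBitMask sets up the mask for a request entering speculation, the new helper registers a request that is still in its prompt (pending) phase, so its prompt tokens are reflected in the causal mask before any tree tokens exist. Only the declaration appears on this page; the following is an assumption-heavy sketch of the bookkeeping it implies, using the BitMask fields given defaults in batch_config.h above:

// Assumption-laden sketch; the committed body is in request_manager.cc, not shown here.
void RequestManager::appendPendingRequest(BatchConfig::BitMask &bitmask,
                                          int initLength) {
  assert(initLength > 0);
  // A pending (prompt-phase) request has no committed cache and no tree yet,
  // so only the prompt bookkeeping grows; the tree fields keep their defaults.
  bitmask.non_tree_cache_size = 0;
  bitmask.this_layer_size = initLength;
  bitmask.prompt_size += initLength;  // a long prompt may arrive over several batches
}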
5 changes: 4 additions & 1 deletion inference/models/falcon.cc
@@ -39,7 +39,10 @@ void FALCON::create_falcon_model(FFModel &ff,
Tensor input;
{
// assert(falcon_config.max_num_tokens <= BatchConfig::MAX_NUM_TOKENS);
int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1};
int const token_dims[] = {mode == TREE_VERIFY_MODE
? BatchConfig::max_verify_tokens_per_batch()
: BatchConfig::max_tokens_per_batch(),
1};
input = ff.create_tensor<2>(token_dims, DT_INT32);
}

5 changes: 4 additions & 1 deletion inference/models/llama.cc
@@ -41,7 +41,10 @@ void LLAMA::create_llama_model(FFModel &ff,

Tensor input;
{
int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1};
int const token_dims[] = {mode == TREE_VERIFY_MODE
? BatchConfig::max_verify_tokens_per_batch()
: BatchConfig::max_tokens_per_batch(),
1};
input = ff.create_tensor<2>(token_dims, DT_INT32);
}

5 changes: 4 additions & 1 deletion inference/models/mpt.cc
@@ -40,7 +40,10 @@ void MPT::create_mpt_model(FFModel &ff,
//------------------------------ build the model --------------------------
Tensor input;
{
int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1};
int const token_dims[] = {mode == TREE_VERIFY_MODE
? BatchConfig::max_verify_tokens_per_batch()
: BatchConfig::max_tokens_per_batch(),
1};
input = ff.create_tensor<2>(token_dims, DT_INT32);
}

5 changes: 4 additions & 1 deletion inference/models/opt.cc
@@ -42,7 +42,10 @@ void OPT::create_opt_model(FFModel &ff,
Tensor position_input;
ff.set_position_offset(2);
{
int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1};
int const token_dims[] = {mode == TREE_VERIFY_MODE
? BatchConfig::max_verify_tokens_per_batch()
: BatchConfig::max_tokens_per_batch(),
1};
input = ff.create_tensor<2>(token_dims, DT_INT32);
position_input = ff.create_tensor<2>(token_dims, DT_INT32);
}
5 changes: 4 additions & 1 deletion inference/models/starcoder.cc
@@ -48,7 +48,10 @@ void STARCODER::create_starcoder_model(
ff.set_position_offset(0);
{
// assert(startcoder_config.max_num_tokens <= BatchConfig::MAX_NUM_TOKENS);
int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1};
int const token_dims[] = {mode == TREE_VERIFY_MODE
? BatchConfig::max_verify_tokens_per_batch()
: BatchConfig::max_tokens_per_batch(),
1};
input = ff.create_tensor<2>(token_dims, DT_INT32);
position_input = ff.create_tensor<2>(token_dims, DT_INT32);
}
11 changes: 8 additions & 3 deletions src/ops/arg_topk.cu
@@ -404,17 +404,22 @@ void ArgTopK::forward_kernel(ArgTopKMeta const *m,
assert(bc->num_active_requests() >= 0);

// check beam-width consistency;
// allow the last active request to differ from the others
int beam_size = -1;
for (int i = 1; i < bc->max_requests_per_batch(); i++) {
int num_activate_requests = bc->num_active_requests();
int last_request_idx =
bc->requestsInfo[num_activate_requests - 1].batch_config_request_id;
for (int i = 0; i < bc->max_requests_per_batch(); i++) {
if (bc->request_completed[i]) {
continue;
} else if (beam_size == -1) {
beam_size = bc->beamRequestsInfo[i].beam_size;
} else {

} else if (i != last_request_idx) {
assert(beam_size == bc->beamRequestsInfo[i].beam_size);
} else if (i == last_request_idx) {
}
}

assert(num_shards >= (size_t)beam_size);
num_shards = k;
arg_topk_forward_kernel<<<num_blocks, num_shards, 0, stream>>>(
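The intent of the reworked check: every active request must share one beam width, except possibly the last active request, which can be a freshly admitted prompt-phase request whose beam width is not yet aligned with the rest. An equivalent, slightly more compact restatement of the committed branches (a sketch, shown only to make the control flow explicit):

// Equivalent restatement of the beam-width check above (sketch, not the committed code).
int beam_size = -1;
int last_request_idx =
    bc->requestsInfo[bc->num_active_requests() - 1].batch_config_request_id;
for (int i = 0; i < bc->max_requests_per_batch(); i++) {
  if (bc->request_completed[i]) {
    continue;
  }
  if (beam_size == -1) {
    beam_size = bc->beamRequestsInfo[i].beam_size;  // first active request sets the width
  } else if (i != last_request_idx) {
    // every other active request must match; the last one is allowed to differ
    assert(beam_size == bc->beamRequestsInfo[i].beam_size);
  }
}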
4 changes: 3 additions & 1 deletion src/ops/inc_multihead_self_attention.cu
@@ -1349,7 +1349,9 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
// allocate memory for the seqArray and reserve space
{
int max_tokens_per_batch = BatchConfig::max_tokens_per_batch();
int max_tokens_per_batch = infer_mode == TREE_VERIFY_MODE
? BatchConfig::max_verify_tokens_per_batch()
: BatchConfig::max_tokens_per_batch();
size_t qkv_max_proj_size = max_tokens_per_batch * (qProjSize * num_q_heads +
kProjSize * num_q_heads +
vProjSize * num_q_heads);
60 changes: 35 additions & 25 deletions src/ops/spec_inc_multihead_self_attention.cu
@@ -50,7 +50,8 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel(
int hidden_size,
BatchConfig::PerRequestInfo *request_infos,
BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos,
BatchConfig::BitMask *causalMask) {
BatchConfig::BitMask *causalMask,
bool *request_completed) {

// q, k
using Q_vec = typename VEC_K<DT, THREADS_PER_KEY>::Type;
@@ -86,11 +87,12 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel(
// request_infos[batch_config_request_id].first_token_depth_in_request +
// request_infos[batch_config_request_id].num_tokens_in_batch;

int const totalCacheSize = bitmask.non_tree_cache_size + bitmask.tree_size;
int const totalCacheSize =
bitmask.non_tree_cache_size + bitmask.tree_size + bitmask.prompt_size - 1;

int first_token_idx = 0;
for (int r = 0; r < request_idx; r++) {
first_token_idx += causalMask[r].this_layer_size;
for (int r = 0; r < batch_config_request_id; r++) {
first_token_idx += request_completed[r] ? 0 : causalMask[r].this_layer_size;
}

int const tree_branch_num =
@@ -138,7 +140,8 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel(
ii * THREADS_PER_KEY * K_VEC_SIZE);
}

int const query_token = bitmask.tree_size - tree_branch_num + qi;
int const query_token =
bitmask.prompt_size + bitmask.tree_size - 1 - tree_branch_num + qi;

__syncthreads();
for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
@@ -163,8 +166,12 @@
(!(bitmask.mask[ti - bitmask.non_tree_cache_size] &
(1 << query_token))));

// if (blockIdx.y == 0 && blockIdx.x == 0 && !mask) {
// printf("spec inc attn qkqkqk %d, %.10f, %d\n", ti, qk, qi);
// if (head_idx == 0 && ti == 0 && request_idx == 15 && !mask) {
// printf("spec inc attn qkqkqk request id %d, %.10f, %d\n",
// batch_config_request_id,
// ti,
// qk,
// qi);
// }
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
qk_smem[ti - first_step] = mask ? 0.f : qk;
@@ -336,17 +343,12 @@ __global__ void spec_inc_store_kv_cache(

BatchConfig::BitMask bitmask = causalMask[req_id];

// int const tree_branch_num = beamRequestInfos[req_id].sub_request_num;

// int const query_token = bitmask.non_tree_cache_size + bitmask.tree_size -
// tree_branch_num + sub_req_id + tok_id;
// bitmask.tree_size - tree_branch_num + sub_req_id;

// if prompt token -> token id
// if tree token:
int const cache_idx = bitmask.non_tree_cache_size + bitmask.tree_size -
bitmask.this_layer_size + token_idx -
request_token_offset;

int const cache_idx = bitmask.prompt_size + bitmask.non_tree_cache_size +
bitmask.tree_size - 1 - bitmask.this_layer_size +
token_idx - request_token_offset;

kCache_ptr[req_id * (hidden_size * max_seq_len) + (cache_idx)*hidden_size +
offset] = kVal;
@@ -411,7 +413,8 @@ void update_kv_cache_kernel(SpecIncMultiHeadSelfAttentionMeta const *m,
m->hidden_size, \
m->request_infos, \
m->beam_request_infos, \
m->causalMask)
m->causalMask, \
m->request_completed)
template <typename DT>
void compute_spec_inc_attention_kernel_generation(
@@ -420,7 +423,8 @@ void compute_spec_inc_attention_kernel_generation(
DT *output_ptr,
cudaStream_t stream) {
// one block == one head per request
dim3 grid(m->num_q_heads, bc->num_active_requests());
// grid.y counts only the requests currently in the speculative phase
dim3 grid(m->num_q_heads, bc->get_speculative_request_num());
int const per_head_size = m->qProjSize;
float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f;
size_t smem_sz;
@@ -499,11 +503,10 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m,
for (int i = 0; i < bc->max_requests_per_batch(); i++) {
if (bc->request_completed[i]) {
continue;
} else if (tokens_previous_requests < bc->num_generation_tokens) {
tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch;
continue;
}
// else if (tokens_previous_requests < bc->num_generation_tokens) {
// tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch;
// continue;
// }
// all requests in the prompt phase should have exactly one sub-request;
assert(bc->sub_requests[i] == 1);
@@ -659,10 +662,10 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m,
// To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous
// requests
// print_tensor<float>((float*)C_softmax, 32, "C_softmax");
int token_offset = bc->requestsInfo[i].first_token_offset_in_batch;
C = static_cast<DT *>(m->attn_heads) +
(tokens_previous_requests + bc->num_generation_tokens) *
m->num_q_heads * m->vProjSize;
(token_offset)*m->num_q_heads * m->vProjSize;
checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas,
CUBLAS_OP_N,
CUBLAS_OP_T,
@@ -860,6 +863,13 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta(
sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) +
sizeof(BeamSearchBatchConfig::beamTokenInfo) +
sizeof(BeamSearchBatchConfig::beamRequestsInfo));
request_completed = reinterpret_cast<bool *>(
reinterpret_cast<char *>(handler.batch_config_metadata) +
sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) +
sizeof(BeamSearchBatchConfig::beamTokenInfo) +
sizeof(BeamSearchBatchConfig::beamRequestsInfo) +
sizeof(BatchConfig::causalMask));
}
cudaStreamSynchronize(stream);
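The index changes in this file all follow one layout per request's cache segment: non_tree_cache_size committed tokens, then the prompt, then the speculation tree, with the recurring -1 apparently accounting for one slot shared between the prompt and the tree root. A small standalone trace of the three formulas above, with hypothetical sizes, to make the arithmetic concrete:

// Trace of the new index formulas with made-up per-request values.
#include <cstdio>

int main() {
  int non_tree_cache_size = 100;  // committed tokens already in the KV cache
  int prompt_size = 8;            // prompt/root tokens of this request
  int tree_size = 4;              // current speculation-tree size
  int this_layer_size = 2;        // tokens appended by the newest tree layer
  int tree_branch_num = 2;        // branches expanded this step
  int qi = 0;                     // first query handled this step
  int token_idx = 0, request_token_offset = 0;  // first token local to this request

  int totalCacheSize = non_tree_cache_size + tree_size + prompt_size - 1;
  int query_token = prompt_size + tree_size - 1 - tree_branch_num + qi;
  int cache_idx = prompt_size + non_tree_cache_size + tree_size - 1 -
                  this_layer_size + token_idx - request_token_offset;

  // Prints 111 9 109: new K/V entries land in the last this_layer_size slots of
  // the request's segment, and row `query_token` is the bit tested against
  // bitmask.mask[ti - non_tree_cache_size] when masking tree positions.
  printf("%d %d %d\n", totalCacheSize, query_token, cache_idx);
  return 0;
}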