From a45826e9daa0364b49f353c1c85cf2a9800bc1d9 Mon Sep 17 00:00:00 2001 From: Xinhao Cheng <99570243+xinhaoc@users.noreply.github.com> Date: Tue, 2 Jan 2024 15:28:52 -0500 Subject: [PATCH] SpecInfer fix corner case (#1258) * init * fix speculative * fix speculative * bitmap+tree verify * fix. * fix * multi batch * copy metadata once * fix some corner cases * Replicate load_token tasks so that it can be fused with other compute tasks; this eliminates Replicate and enables a larger fused op * more fix. * clean up * . * load batchconfig * clean * hip * hip * embedding return when no token * use arg topk instead of beam topk * embedding * fmt * hip * fix corner case --------- Co-authored-by: Zhihao Jia --- include/flexflow/batch_config.h | 14 ++- include/flexflow/config.h | 3 +- include/flexflow/model.h | 1 + .../inc_multihead_self_attention_utils.cuh | 2 +- .../ops/spec_inc_multihead_self_attention.h | 1 + .../ops/tree_inc_multihead_self_attention.h | 1 + include/flexflow/request_manager.h | 2 + inference/models/falcon.cc | 5 +- inference/models/llama.cc | 5 +- inference/models/mpt.cc | 5 +- inference/models/opt.cc | 5 +- inference/models/starcoder.cc | 5 +- src/ops/arg_topk.cu | 11 ++- src/ops/inc_multihead_self_attention.cu | 4 +- src/ops/spec_inc_multihead_self_attention.cu | 60 +++++++----- src/ops/tree_inc_multihead_self_attention.cu | 62 +++++++------ src/runtime/batch_config.cc | 6 ++ src/runtime/beam_search_batch_config.cc | 4 + src/runtime/model.cc | 14 +++ src/runtime/request_manager.cc | 93 +++++++++++-------- src/runtime/request_manager.cu | 28 +++++- 21 files changed, 224 insertions(+), 107 deletions(-) diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index 13904aaa46..ef17ef43ed 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -45,6 +45,7 @@ class BatchConfig { int num_active_tokens() const; static int max_requests_per_batch(); static int max_tokens_per_batch(); + static int max_verify_tokens_per_batch(); static int max_sequence_length(); friend std::ostream &operator<<(std::ostream &os, BatchConfig const &bc); void print() const; @@ -72,6 +73,7 @@ class BatchConfig { // request id in batch config: int batch_config_request_id; + bool prompt_phase = false; RequestGuid request_guid; }; struct PerTokenInfo { @@ -85,15 +87,15 @@ class BatchConfig { // how many tokens before the tree, every sub requests need this part of // cache - int non_tree_cache_size; + int non_tree_cache_size = 0; // current tree size - int tree_size; + int tree_size = 0; - int this_layer_size; + int this_layer_size = 0; // input length-> prompt/root - int prompt_size; + int prompt_size = 0; }; BitMask causalMask[MAX_NUM_REQUESTS]; @@ -145,9 +147,13 @@ class BeamSearchBatchConfig : public BatchConfig { bool done() const; int max_beam_depth_all_requests() const; int current_depth_all_requests() const; + int get_speculative_request_num() const; size_t beam_width; size_t target_iterations; + + // how many requests is in speculative phase + int speculative_request_num = 0; inline static int const MAX_BEAM_WIDTH = 3; inline static int const MAX_BEAM_DEPTH = 8; diff --git a/include/flexflow/config.h b/include/flexflow/config.h index e1480264cc..17a3f59e29 100644 --- a/include/flexflow/config.h +++ b/include/flexflow/config.h @@ -84,7 +84,8 @@ struct FFHandler { sizeof(BeamSearchBatchConfig::beamTokenInfo) + sizeof(BeamSearchBatchConfig::beamRequestsInfo) + sizeof(BatchConfig::causalMask) + - sizeof(TreeVerifyBatchConfig::committed_tokens); + sizeof(TreeVerifyBatchConfig::committed_tokens) + + sizeof(BatchConfig::request_completed); void *offload_reserve_space; size_t offload_reserve_space_size; DataType quantization_type; diff --git a/include/flexflow/model.h b/include/flexflow/model.h index cf7bb3dd2d..6f805e21bd 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -73,6 +73,7 @@ enum TaskIDs { DROPOUT_BWD_TASK_ID, EMBED_INIT_TASK_ID, EMBED_FWD_TASK_ID, + EMBED_INF_TASK_ID, EMBED_BWD_TASK_ID, GATHER_INIT_TASK_ID, GATHER_FWD_TASK_ID, diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh index c128c1a126..d1e0e050b2 100644 --- a/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh +++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh @@ -493,7 +493,7 @@ inline void smem_size_in_bytes_tree(int hidden_size_per_head, } // todo fix this - int max_qk_length = max_query_length * max_total_length; + int max_qk_length = max_query_length; // The amount of shared memory needed to store the Q*K^T values in float. size_t qk_sz = div_up(max_qk_length + 1, 4) * 16; diff --git a/include/flexflow/ops/spec_inc_multihead_self_attention.h b/include/flexflow/ops/spec_inc_multihead_self_attention.h index a306f7985a..a0d01092bf 100644 --- a/include/flexflow/ops/spec_inc_multihead_self_attention.h +++ b/include/flexflow/ops/spec_inc_multihead_self_attention.h @@ -142,6 +142,7 @@ class SpecIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta { Realm::RegionInstance beam_search_reserve_inst; BeamSearchBatchConfig::BeamSearchPerTokenInfo *beam_token_infos; BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos; + bool *request_completed; BatchConfig::BitMask *causalMask; }; diff --git a/include/flexflow/ops/tree_inc_multihead_self_attention.h b/include/flexflow/ops/tree_inc_multihead_self_attention.h index d160da4a72..02df0c0137 100644 --- a/include/flexflow/ops/tree_inc_multihead_self_attention.h +++ b/include/flexflow/ops/tree_inc_multihead_self_attention.h @@ -147,6 +147,7 @@ class TreeIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta { int num_active_tokens; Realm::RegionInstance committed_token_reserve_inst; TreeVerifyBatchConfig::CommittedTokensInfo *committed_token_infos; + bool *request_completed; BatchConfig::BitMask *causalMask; }; diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index 1c4b0b2a2f..33714c106e 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -103,6 +103,7 @@ class RequestManager { int get_max_requests_per_batch(); void set_max_tokens_per_batch(int max_num_tokens); int get_max_tokens_per_batch(); + int get_max_verify_tokens_per_batch(); void set_max_sequence_length(int max_seq_length); void push_spec_infer_tree_width(int tree_width); int get_max_sequence_length(); @@ -113,6 +114,7 @@ class RequestManager { std::string const &path); void register_output_filepath(std::string const &); void initBitMask(BatchConfig::BitMask &bitmask, int initLength); + void appendPendingRequest(BatchConfig::BitMask &bitmask, int initLength); void appendBitMask(BatchConfig::BitMask &bitmask, int newNodes, int preBeamSize, diff --git a/inference/models/falcon.cc b/inference/models/falcon.cc index bfcec847b9..999ca37037 100644 --- a/inference/models/falcon.cc +++ b/inference/models/falcon.cc @@ -39,7 +39,10 @@ void FALCON::create_falcon_model(FFModel &ff, Tensor input; { // assert(falcon_config.max_num_tokens <= BatchConfig::MAX_NUM_TOKENS); - int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1}; + int const token_dims[] = {mode == TREE_VERIFY_MODE + ? BatchConfig::max_verify_tokens_per_batch() + : BatchConfig::max_tokens_per_batch(), + 1}; input = ff.create_tensor<2>(token_dims, DT_INT32); } diff --git a/inference/models/llama.cc b/inference/models/llama.cc index e9c84efe90..e54d6d8811 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -41,7 +41,10 @@ void LLAMA::create_llama_model(FFModel &ff, Tensor input; { - int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1}; + int const token_dims[] = {mode == TREE_VERIFY_MODE + ? BatchConfig::max_verify_tokens_per_batch() + : BatchConfig::max_tokens_per_batch(), + 1}; input = ff.create_tensor<2>(token_dims, DT_INT32); } diff --git a/inference/models/mpt.cc b/inference/models/mpt.cc index b074d332ed..3df67b264c 100644 --- a/inference/models/mpt.cc +++ b/inference/models/mpt.cc @@ -40,7 +40,10 @@ void MPT::create_mpt_model(FFModel &ff, //------------------------------ build the model -------------------------- Tensor input; { - int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1}; + int const token_dims[] = {mode == TREE_VERIFY_MODE + ? BatchConfig::max_verify_tokens_per_batch() + : BatchConfig::max_tokens_per_batch(), + 1}; input = ff.create_tensor<2>(token_dims, DT_INT32); } diff --git a/inference/models/opt.cc b/inference/models/opt.cc index 9b29ae5410..0279f83239 100644 --- a/inference/models/opt.cc +++ b/inference/models/opt.cc @@ -42,7 +42,10 @@ void OPT::create_opt_model(FFModel &ff, Tensor position_input; ff.set_position_offset(2); { - int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1}; + int const token_dims[] = {mode == TREE_VERIFY_MODE + ? BatchConfig::max_verify_tokens_per_batch() + : BatchConfig::max_tokens_per_batch(), + 1}; input = ff.create_tensor<2>(token_dims, DT_INT32); position_input = ff.create_tensor<2>(token_dims, DT_INT32); } diff --git a/inference/models/starcoder.cc b/inference/models/starcoder.cc index ba7b2cb43a..e683376e47 100644 --- a/inference/models/starcoder.cc +++ b/inference/models/starcoder.cc @@ -48,7 +48,10 @@ void STARCODER::create_starcoder_model( ff.set_position_offset(0); { // assert(startcoder_config.max_num_tokens <= BatchConfig::MAX_NUM_TOKENS); - int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1}; + int const token_dims[] = {mode == TREE_VERIFY_MODE + ? BatchConfig::max_verify_tokens_per_batch() + : BatchConfig::max_tokens_per_batch(), + 1}; input = ff.create_tensor<2>(token_dims, DT_INT32); position_input = ff.create_tensor<2>(token_dims, DT_INT32); } diff --git a/src/ops/arg_topk.cu b/src/ops/arg_topk.cu index 0b8bb8b563..5b7978812c 100644 --- a/src/ops/arg_topk.cu +++ b/src/ops/arg_topk.cu @@ -404,17 +404,22 @@ void ArgTopK::forward_kernel(ArgTopKMeta const *m, assert(bc->num_active_requests() >= 0); // check + // allow last request different with others int beam_size = -1; - for (int i = 1; i < bc->max_requests_per_batch(); i++) { + int num_activate_requests = bc->num_active_requests(); + int last_request_idx = + bc->requestsInfo[num_activate_requests - 1].batch_config_request_id; + for (int i = 0; i < bc->max_requests_per_batch(); i++) { if (bc->request_completed[i]) { continue; } else if (beam_size == -1) { beam_size = bc->beamRequestsInfo[i].beam_size; - } else { + + } else if (i != last_request_idx) { assert(beam_size == bc->beamRequestsInfo[i].beam_size); + } else if (i == last_request_idx) { } } - assert(num_shards >= (size_t)beam_size); num_shards = k; arg_topk_forward_kernel<<>>( diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index db64868cb9..7c8601d3c8 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -1349,7 +1349,9 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( // allocate memory for the seqArray and reserve space { - int max_tokens_per_batch = BatchConfig::max_tokens_per_batch(); + int max_tokens_per_batch = infer_mode == TREE_VERIFY_MODE + ? BatchConfig::max_verify_tokens_per_batch() + : BatchConfig::max_tokens_per_batch(); size_t qkv_max_proj_size = max_tokens_per_batch * (qProjSize * num_q_heads + kProjSize * num_q_heads + vProjSize * num_q_heads); diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 88dd3f92e4..b31e5d0994 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -50,7 +50,8 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel( int hidden_size, BatchConfig::PerRequestInfo *request_infos, BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos, - BatchConfig::BitMask *causalMask) { + BatchConfig::BitMask *causalMask, + bool *request_completed) { // q, k using Q_vec = typename VEC_K::Type; @@ -86,11 +87,12 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel( // request_infos[batch_config_request_id].first_token_depth_in_request + // request_infos[batch_config_request_id].num_tokens_in_batch; - int const totalCacheSize = bitmask.non_tree_cache_size + bitmask.tree_size; + int const totalCacheSize = + bitmask.non_tree_cache_size + bitmask.tree_size + bitmask.prompt_size - 1; int first_token_idx = 0; - for (int r = 0; r < request_idx; r++) { - first_token_idx += causalMask[r].this_layer_size; + for (int r = 0; r < batch_config_request_id; r++) { + first_token_idx += request_completed[r] ? 0 : causalMask[r].this_layer_size; } int const tree_branch_num = @@ -138,7 +140,8 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel( ii * THREADS_PER_KEY * K_VEC_SIZE); } - int const query_token = bitmask.tree_size - tree_branch_num + qi; + int const query_token = + bitmask.prompt_size + bitmask.tree_size - 1 - tree_branch_num + qi; __syncthreads(); for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { @@ -163,8 +166,12 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel( (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << query_token)))); - // if (blockIdx.y == 0 && blockIdx.x == 0 && !mask) { - // printf("spec inc attn qkqkqk %d, %.10f, %d\n", ti, qk, qi); + // if (head_idx == 0 && ti == 0 && request_idx == 15 && !mask) { + // printf("spec inc attn qkqkqk request id %d, %.10f, %d\n", + // batch_config_request_id, + // ti, + // qk, + // qi); // } qk_max = mask ? qk_max : fmaxf(qk_max, qk); qk_smem[ti - first_step] = mask ? 0.f : qk; @@ -336,17 +343,12 @@ __global__ void spec_inc_store_kv_cache( BatchConfig::BitMask bitmask = causalMask[req_id]; - // int const tree_branch_num = beamRequestInfos[req_id].sub_request_num; - - // int const query_token = bitmask.non_tree_cache_size + bitmask.tree_size - - // tree_branch_num + sub_req_id + tok_id; - // bitmask.tree_size - tree_branch_num + sub_req_id; - // if prompt token -> token id // if tree token: - int const cache_idx = bitmask.non_tree_cache_size + bitmask.tree_size - - bitmask.this_layer_size + token_idx - - request_token_offset; + + int const cache_idx = bitmask.prompt_size + bitmask.non_tree_cache_size + + bitmask.tree_size - 1 - bitmask.this_layer_size + + token_idx - request_token_offset; kCache_ptr[req_id * (hidden_size * max_seq_len) + (cache_idx)*hidden_size + offset] = kVal; @@ -411,7 +413,8 @@ void update_kv_cache_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, m->hidden_size, \ m->request_infos, \ m->beam_request_infos, \ - m->causalMask) + m->causalMask, \ + m->request_completed) template void compute_spec_inc_attention_kernel_generation( @@ -420,7 +423,8 @@ void compute_spec_inc_attention_kernel_generation( DT *output_ptr, cudaStream_t stream) { // one block == one head per request - dim3 grid(m->num_q_heads, bc->num_active_requests()); + // how many generation requests + dim3 grid(m->num_q_heads, bc->get_speculative_request_num()); int const per_head_size = m->qProjSize; float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f; size_t smem_sz; @@ -499,11 +503,10 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, for (int i = 0; i < bc->max_requests_per_batch(); i++) { if (bc->request_completed[i]) { continue; + } else if (tokens_previous_requests < bc->num_generation_tokens) { + tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; + continue; } - // else if (tokens_previous_requests < bc->num_generation_tokens) { - // tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; - // continue; - // } // all requests in prompt phase should only have one sub requests; assert(bc->sub_requests[i] == 1); @@ -659,10 +662,10 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous // requests - // print_tensor((float*)C_softmax, 32, "C_softmax"); + int token_offset = bc->requestsInfo[i].first_token_offset_in_batch; + C = static_cast
(m->attn_heads) + - (tokens_previous_requests + bc->num_generation_tokens) * - m->num_q_heads * m->vProjSize; + (token_offset)*m->num_q_heads * m->vProjSize; checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, CUBLAS_OP_N, CUBLAS_OP_T, @@ -860,6 +863,13 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta( sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + sizeof(BeamSearchBatchConfig::beamTokenInfo) + sizeof(BeamSearchBatchConfig::beamRequestsInfo)); + + request_completed = reinterpret_cast( + reinterpret_cast(handler.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + + sizeof(BeamSearchBatchConfig::beamTokenInfo) + + sizeof(BeamSearchBatchConfig::beamRequestsInfo) + + sizeof(BatchConfig::causalMask)); } cudaStreamSynchronize(stream); diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index b4af80976f..fc86e1498e 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -54,6 +54,7 @@ __global__ void compute_attention_kernel_fused_kernel( int num_heads, int num_requests, BatchConfig::BitMask *causalMask, + bool *request_completed, int qk_smem_sz) { // q, k @@ -90,13 +91,14 @@ __global__ void compute_attention_kernel_fused_kernel( BatchConfig::BitMask bitmask = causalMask[batch_config_request_id]; int first_token_idx = 0; - for (int r = 0; r < request_idx; r++) { - first_token_idx += request_infos[r].num_tokens_in_batch; + for (int r = 0; r < batch_config_request_id; r++) { + first_token_idx += + request_completed[r] ? 0 : request_infos[r].num_tokens_in_batch; } - // if(tidx == 0 && head_idx == 0){ - // printf("tree req: %d, %d\n", request_idx, first_token_idx); - // } + bool prompt_phase = request_infos[batch_config_request_id].prompt_phase; + int q_start = + request_infos[batch_config_request_id].first_token_depth_in_request; // shared memory objects extern __shared__ char smem_[]; @@ -139,7 +141,7 @@ __global__ void compute_attention_kernel_fused_kernel( q_ptr + (hidden_size * QKV_WEIGHT_NUM * qi) + ki + ii * THREADS_PER_KEY * K_VEC_SIZE); - // if (head_idx == 0 && qi == 1 && tidx == 0) { + // if (head_idx == 0 && request_idx == 1 && tidx == 0) { // printf("laod q %d, %d %.10f\n", // request_idx, // qi,q_vecs[ki_o][ii].x); @@ -163,19 +165,23 @@ __global__ void compute_attention_kernel_fused_kernel( if (ti < tlength && tidx % THREADS_PER_KEY == 0) { bool const mask = - (ti >= bitmask.non_tree_cache_size && - (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); + prompt_phase ? (qi + q_start < ti) + : (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << qi)))); qk_max = mask ? qk_max : fmaxf(qk_max, qk); - // if (head_idx == 0 && qi == 0 && !mask) { - // printf("tree attn qkqkqkqk request id %d, %d %.10f, %.10f, %.10f\n - // ", + // if (head_idx == 0 && !mask) { + // printf("tree attn qkqkqkqk request id %d qi%d, ti %d, %.10f, %.10f, + // %.10f, %d\n", // request_idx, + // qi, // ti, // qk, // q_vecs[ki_o][0].x, - // k[0].x); + // k[0].x, + // bitmask.non_tree_cache_size); // } qk_smem[ti - first_step] = mask ? 0.0f : qk; } @@ -217,8 +223,10 @@ __global__ void compute_attention_kernel_fused_kernel( float exp_sum = 0.f; for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { bool const mask = - (ti >= bitmask.non_tree_cache_size && - (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); + prompt_phase ? (q_start + qi < ti) + : (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << qi)))); float logit = mask ? 0.0f : __expf(qk_smem[ti - first_step] - qk_max); exp_sum += logit; qk_smem[ti - first_step] = mask ? 0.0f : logit; @@ -265,8 +273,11 @@ __global__ void compute_attention_kernel_fused_kernel( if (ti < tlength) { bool const mask = - (ti >= bitmask.non_tree_cache_size && - (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); + prompt_phase + ? (q_start + qi < ti) + : (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << qi)))); float logit = mask ? 0.0f : qk_smem[ti - first_step]; out = FlexFlow::fma(logit, cast_to_float(v), out); } @@ -810,6 +821,7 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, m->num_q_heads, \ bc->num_active_requests(), \ m->causalMask, \ + m->request_completed, \ smem_sz[0]) template @@ -841,7 +853,6 @@ void compute_attention_kernel_fused(TreeIncMultiHeadSelfAttentionMeta const *m, dim3 grid(m->num_q_heads, bc->num_active_requests()); int const per_head_size = m->qProjSize; float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f; - // 0->qk production size, 1->total shared size int smem_sz[2]; if (per_head_size == 64) { @@ -890,17 +901,6 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, // std::cout << "tokens to be committed: " << bc->num_tokens_to_commit << // "\n"; - cudaMemcpyAsync(m->committed_token_infos, - &(bc->committed_tokens), - bc->num_tokens_to_commit * - sizeof(TreeVerifyBatchConfig::CommittedTokensInfo), - cudaMemcpyHostToDevice, - stream); - cudaMemcpyAsync(m->causalMask, - &(bc->causalMask), - bc->num_active_requests() * sizeof(BatchConfig::BitMask), - cudaMemcpyHostToDevice, - stream); commit_tokens
(m, bc, stream); // After commit we update m->num_active_tokens to be the number of active @@ -1068,6 +1068,12 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + sizeof(BatchConfig::causalMask)); + + request_completed = reinterpret_cast( + reinterpret_cast(handler.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + + sizeof(BatchConfig::causalMask) + + sizeof(TreeVerifyBatchConfig::committed_tokens)); } cudaStreamSynchronize(stream); diff --git a/src/runtime/batch_config.cc b/src/runtime/batch_config.cc index d2fbc0883f..c432208eca 100644 --- a/src/runtime/batch_config.cc +++ b/src/runtime/batch_config.cc @@ -84,6 +84,12 @@ int BatchConfig::max_tokens_per_batch() { return RequestManager::get_request_manager()->get_max_tokens_per_batch(); } +/*static*/ +int BatchConfig::max_verify_tokens_per_batch() { + return RequestManager::get_request_manager() + ->get_max_verify_tokens_per_batch(); +} + /*static*/ int BatchConfig::max_sequence_length() { return RequestManager::get_request_manager()->get_max_sequence_length(); diff --git a/src/runtime/beam_search_batch_config.cc b/src/runtime/beam_search_batch_config.cc index 74843e9460..ff7bf1a819 100644 --- a/src/runtime/beam_search_batch_config.cc +++ b/src/runtime/beam_search_batch_config.cc @@ -85,6 +85,10 @@ int BeamSearchBatchConfig::max_beam_depth_all_requests() const { return max_depth_all_requests; } +int BeamSearchBatchConfig::get_speculative_request_num() const { + return speculative_request_num; +} + int BeamSearchBatchConfig::current_depth_all_requests() const { int current_depth = 0; for (int i = 0; i < BeamSearchBatchConfig::max_requests_per_batch(); i++) { diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 32b524f643..76bed36bda 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -4805,6 +4805,20 @@ void register_flexflow_internal_tasks(Runtime *runtime, runtime->register_task_variant(registrar); } } + { + TaskVariantRegistrar registrar(EMBED_INF_TASK_ID, "Embedding Inference"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "Embedding Inference Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant(registrar); + } + } { TaskVariantRegistrar registrar(EMBED_BWD_TASK_ID, "Embedding Backward"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 89d4ddaed4..88754f5a82 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -97,6 +97,12 @@ int RequestManager::get_max_tokens_per_batch() { return max_tokens_per_batch; } +int RequestManager::get_max_verify_tokens_per_batch() { + assert(max_tokens_per_batch > 0); + return max_tokens_per_batch + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM * max_requests_per_batch; +} + void RequestManager::set_max_sequence_length(int max_seq_length) { assert(max_sequence_length == -1 || max_sequence_length == max_seq_length); max_sequence_length = max_seq_length; @@ -1126,7 +1132,6 @@ BeamSearchBatchConfig old_bc.beamRequestsInfo[i].sub_request_num, tree, old_bc.beamRequestsInfo[i].current_depth); - // assert(false); for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j; for (int k = 0; k < new_bc.beamRequestsInfo[i].sub_request_num; k++) { @@ -1146,6 +1151,9 @@ BeamSearchBatchConfig } } + // how many requests is in speculative phase + new_bc.speculative_request_num = num_active_req + 1; + // Add prompt tokens to the batch for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) { if (old_bc.request_completed[i] || old_bc.request_running[i]) { @@ -1184,13 +1192,14 @@ BeamSearchBatchConfig spec_infer_tree_width.size() > ssm_decoding_steps ? spec_infer_tree_width[ssm_decoding_steps] : 1; - printf("beam size: %d, %d\n", - new_bc.beamRequestsInfo[i].beam_size, - ssm_decoding_steps); + // printf("beam size: %d, %d\n", + // new_bc.beamRequestsInfo[i].beam_size, + // ssm_decoding_steps); new_bc.beamRequestsInfo[i].max_depth = old_bc.beamRequestsInfo[i].max_depth; - new_bc.sub_requests[i] = - old_bc.sub_requests[i] * new_bc.beamRequestsInfo[i].beam_size; + // new_bc.sub_requests[i] = + // old_bc.sub_requests[i] * new_bc.beamRequestsInfo[i].beam_size; + new_bc.sub_requests[i] = 1; new_bc.beamRequestsInfo[i].sub_request_num = old_bc.beamRequestsInfo[i].sub_request_num; @@ -1218,6 +1227,9 @@ BeamSearchBatchConfig request.tokens.size()) { // request is done new_bc.requestsInfo[i].num_tokens_in_batch = 0; + new_bc.causalMask[i].this_layer_size = 0; + new_bc.beamRequestsInfo[i].sub_request_num = 0; + new_bc.beamRequestsInfo[i].beam_size = 1; } else { // Prompt phase new_bc.requestsInfo[i].num_tokens_in_batch = @@ -1227,12 +1239,8 @@ BeamSearchBatchConfig new_bc.requestsInfo[i].first_token_depth_in_request); request.ssm_cache_size += new_bc.requestsInfo[i].num_tokens_in_batch; BeamTree tree = request.beam_trees[old_bc.model_id]; - appendBitMask(new_bc.causalMask[i], - new_bc.beamRequestsInfo[i].sub_request_num, - old_bc.beamRequestsInfo[i].beam_size, - old_bc.beamRequestsInfo[i].sub_request_num, - tree, - old_bc.beamRequestsInfo[i].current_depth); + appendPendingRequest(new_bc.causalMask[i], + new_bc.requestsInfo[i].num_tokens_in_batch); } if (verbose) { @@ -1258,11 +1266,11 @@ BeamSearchBatchConfig // get value from requestinfo new_bc.tokensInfo[new_bc.num_tokens].token_id = - request.tokens[request.tokens.size() - 1]; + request.tokens[request.tokens.size() - + new_bc.requestsInfo[i].num_tokens_in_batch + j]; new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = k; new_bc.num_tokens++; - num_generation_tokens++; } } } @@ -1319,7 +1327,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.num_tokens_to_commit = 0; new_bc.num_tokens = 0; - int max_prompt_load_size = get_max_tokens_per_batch(); + int max_prompt_load_size = get_max_verify_tokens_per_batch(); for (int i = 0; i < TreeVerifyBatchConfig::max_requests_per_batch(); i++) { if (old_batches.at(0).request_completed[i]) { continue; @@ -1427,7 +1435,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.num_tokens++; new_bc.requestsInfo[i].num_tokens_in_batch++; - if (new_bc.num_tokens > get_max_tokens_per_batch()) { + if (new_bc.num_tokens > get_max_verify_tokens_per_batch()) { assert(false && "Exceeding the space available in the TreeVerify batch"); break; @@ -1453,7 +1461,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.num_tokens++; new_bc.requestsInfo[i].num_tokens_in_batch++; - if (new_bc.num_tokens == get_max_tokens_per_batch() && + if (new_bc.num_tokens == get_max_verify_tokens_per_batch() && (j != dfs_tree_inputs.size() - 1)) { cutLayer = true; break; @@ -1542,7 +1550,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.num_tokens++; } - if (new_bc.num_tokens > get_max_tokens_per_batch()) { + if (new_bc.num_tokens > get_max_verify_tokens_per_batch()) { assert(false && "Exceeding the space available in the TreeVerify batch"); break; @@ -1555,15 +1563,17 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( request.status = Request::RUNNING; new_bc.request_running[i] = true; - std::cout << "new_bc.requestsInfo[i].num_tokens_in_batch: " - << new_bc.requestsInfo[i].num_tokens_in_batch << std::endl; + // std::cout << "new_bc.requestsInfo[i].num_tokens_in_batch: " + // << new_bc.requestsInfo[i].num_tokens_in_batch << + // std::endl; + new_bc.requestsInfo[i].prompt_phase = true; dfs_tree_inputs[guid] = std::vector>{std::make_pair( request.tokens.back(), request.tokens.size() - 1)}; } } else { // launch the request into running phase after loading all prompt - if (get_max_tokens_per_batch() - new_bc.num_tokens > 0) { + if (get_max_verify_tokens_per_batch() - new_bc.num_tokens > 0) { // std::cout << "Initialization running phase: " // << new_bc.requestsInfo[i].num_tokens_in_batch << "\n"; request.status = Request::RUNNING; @@ -1576,9 +1586,11 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.num_tokens++; new_bc.requestsInfo[i].num_tokens_in_batch++; - std::cout << "new_bc.requestsInfo[i].num_tokens_in_batch: " - << new_bc.requestsInfo[i].num_tokens_in_batch << std::endl; + // std::cout << "new_bc.requestsInfo[i].num_tokens_in_batch2: " + // << new_bc.requestsInfo[i].num_tokens_in_batch << + // std::endl; + new_bc.requestsInfo[i].prompt_phase = true; dfs_tree_inputs[guid] = std::vector>{std::make_pair( request.tokens.back(), request.tokens.size() - 1)}; @@ -1760,20 +1772,14 @@ void RequestManager::update_beam_metadata(BeamSearchBatchConfig &new_bc, // prompt phase, init task void RequestManager::initBitMask(BatchConfig::BitMask &bitmask, int initLength) { - assert(initLength <= BatchConfig::MAX_SPEC_TREE_TOKEN_NUM && - "do not support tree size > 64"); + assert(initLength > 0); // eg. 4 tokens: t1: 0000000..1111, t2: 0000000..1110, t3: 0000000..1100, t4: // 0000000..1000 bitmask.non_tree_cache_size = 0; - bitmask.tree_size = initLength; + bitmask.tree_size = 1; bitmask.prompt_size = initLength; bitmask.this_layer_size = initLength; - for (int i = 0; i < bitmask.prompt_size; i++) { - for (int j = i; j < bitmask.prompt_size; j++) { - bitmask.mask[i] |= (1 << j); - } - } // std::cout << "see bit mask" << bitmask.prompt_size << "\n"; // std::cout << "see bit mask" << std::bitset<64>(bitmask.mask[0]) << "\n"; // std::cout << "see bit mask" << std::bitset<64>(bitmask.mask[1]) << "\n"; @@ -1810,6 +1816,25 @@ void RequestManager::updateBitMask(BatchConfig::BitMask &bitmask, // << "\n"; } +// prompt phase, init task +void RequestManager::appendPendingRequest(BatchConfig::BitMask &bitmask, + int initLength) { + assert(initLength > 0); + std::cout << "append pending bit mask: " << initLength << "\n"; + // eg. 4 tokens: t1: 0000000..1111, t2: 0000000..1110, t3: 0000000..1100, t4: + // 0000000..1000 + bitmask.non_tree_cache_size = 0; + bitmask.tree_size = 1; + bitmask.prompt_size += initLength; + bitmask.this_layer_size = initLength; + + // for (int i = 0; i < bitmask.prompt_size; i++) { + // for (int j = i; j < bitmask.prompt_size; j++) { + // bitmask.mask[i] |= (1 << j); + // } + // } +} + // prepare next beam, append layers to the tree void RequestManager::appendBitMask(BatchConfig::BitMask &bitmask, int newNodes, @@ -1862,12 +1887,6 @@ void RequestManager::appendBitMask(BatchConfig::BitMask &bitmask, } } - // std::cout << "token idx: " << token_idx << ", " << pre_tree_size << ", " - // << new_nodes_start_idx << ", " << newNodes - // << "current depth: " << currentDepth << "\n"; - // std::cout << "new nodes end " << new_nodes_start_idx << "\n"; - - // std::cout << "tree size: " << bitmask.tree_size << "\n"; assert(token_idx == pre_tree_size); assert(currentDepth <= 1 || new_nodes_start_idx == bitmask.tree_size); diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu index 51c52c3026..8380d6be73 100644 --- a/src/runtime/request_manager.cu +++ b/src/runtime/request_manager.cu @@ -35,10 +35,17 @@ void RequestManager::load_tokens_task( // Extreme long prompts are not supported, only load up to // BatchConfig::max_tokens_per_batch() as prompt - if (batch_config->num_tokens > BatchConfig::max_tokens_per_batch()) { + if (batch_config->num_tokens > BatchConfig::max_tokens_per_batch() && + batch_config->get_mode() == INC_DECODING_MODE) { printf("Warning: too many tokens in prompt, only load up to %d tokens\n", BatchConfig::max_tokens_per_batch()); printf("Got: %d tokens\n", batch_config->num_tokens); + } else if (batch_config->num_tokens > + BatchConfig::max_verify_tokens_per_batch()) { + printf("Warning: Speculative decoding. too many tokens in prompt, only " + "load up to %d tokens\n", + BatchConfig::max_verify_tokens_per_batch()); + printf("Got: %d tokens\n", batch_config->num_tokens); } for (int i = 0; i < batch_config->num_tokens; i++) { @@ -117,8 +124,16 @@ void RequestManager::load_batch_config_task( sizeof(BatchConfig::causalMask), cudaMemcpyHostToDevice, stream)); - total_copy_size += sizeof(BatchConfig::causalMask); + + checkCUDA(cudaMemcpyAsync( + static_cast(handle.batch_config_metadata) + total_copy_size, + &(batch_config->request_completed), + sizeof(BatchConfig::request_completed), + cudaMemcpyHostToDevice, + stream)); + + total_copy_size += sizeof(BatchConfig::request_completed); } else if (batch_config->get_mode() == TREE_VERIFY_MODE) { TreeVerifyBatchConfig const *tree_batch_config = static_cast(batch_config); @@ -137,6 +152,15 @@ void RequestManager::load_batch_config_task( cudaMemcpyHostToDevice, stream)); total_copy_size += sizeof(TreeVerifyBatchConfig::committed_tokens); + + checkCUDA(cudaMemcpyAsync( + static_cast(handle.batch_config_metadata) + total_copy_size, + &(batch_config->request_completed), + sizeof(BatchConfig::request_completed), + cudaMemcpyHostToDevice, + stream)); + + total_copy_size += sizeof(BatchConfig::request_completed); } // add a size check