From c21269907d18192ee6a8aaf6a591915a3e72d582 Mon Sep 17 00:00:00 2001
From: Qinghan Chen
Date: Sat, 10 Aug 2024 00:14:45 +0000
Subject: [PATCH] ckpt before adding tmp block to batchconfig

---
 src/runtime/request_manager.cc |  1 +
 src/runtime/request_manager.cu | 57 +++++++++++++++++++++++++++-------
 2 files changed, 47 insertions(+), 11 deletions(-)

diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc
index 98f40242bc..866a112f5a 100644
--- a/src/runtime/request_manager.cc
+++ b/src/runtime/request_manager.cc
@@ -1327,6 +1327,7 @@ BatchConfig RequestManager::prepare_verify_batch_config() {
 
     // Load the tokens on the token tree that are not yet pruned to
     // BatchConfig.tokensInfo.
+    // page attention: we should also add these tokens to the logical blocks
     TokenTree &token_tree = request.speculative_token_trees[0];
     int token_tree_index = 0;
     int layer_index = 0;
diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu
index fea9670069..15e487dbf9 100644
--- a/src/runtime/request_manager.cu
+++ b/src/runtime/request_manager.cu
@@ -467,16 +467,16 @@ void RequestManager::load_batch_config_task(
 
   int parallelism = batch_size;
 
-  prepare_inference_params_kernel_h(batch_config,
-                                    pm,
-                                    handle,
-                                    stream,
-                                    max_num_pages,
-                                    q_indptr_h,
-                                    kv_indptr_h,
-                                    kv_indices_h,
-                                    kv_last_page_len_h,
-                                    qk_indptr_h);
+  // prepare_inference_params_kernel_h(batch_config,
+  //                                   pm,
+  //                                   handle,
+  //                                   stream,
+  //                                   max_num_pages,
+  //                                   q_indptr_h,
+  //                                   kv_indptr_h,
+  //                                   kv_indices_h,
+  //                                   kv_last_page_len_h,
+  //                                   qk_indptr_h);
   prepare_inference_params_kernel<<<GET_BLOCKS(parallelism),
                                     min(CUDA_NUM_THREADS, parallelism),
                                     0,
@@ ... @@ void RequestManager::load_batch_config_task(
+    checkCUDA(cudaMemcpy(custom_mask,
+                         handle.tree_verify_attention_metadata->custom_mask,
+                         sizeof(uint8_t) * batch_size * max_num_pages,
+                         cudaMemcpyDeviceToHost));
+    printf("------------------------updated mask------------------------\n");
+    for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) {
+      if (batch_config->request_available[i]) {
+        for (int j = 0; j < BatchConfig::max_spec_tree_token_num(); j++) {
+          printf("%d ", custom_mask[i * BatchConfig::max_spec_tree_token_num() + j]);
+        }
+        printf("\n");
+      }
+    }
   }
 }
@@ -552,8 +574,21 @@ void RequestManager::load_batch_config_task(
         handle.tree_verify_attention_metadata->prompt_handler_collections[batch_size]);
   }
 
+  static int32_t q_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1], kv_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1];
+  q_indptr_h[0] = 0;
+  kv_indptr_h[0] = 0;
+  for (int req_idx = 0, indptr_idx = 0; req_idx < batch_config->max_requests_per_batch(); req_idx++) {
+    if (batch_config->request_available[req_idx]) {
+      int q_len = batch_config->requestsInfo[req_idx].num_tokens_in_batch;
+      int kv_len = batch_config->requestsInfo[req_idx].num_tokens_in_batch +
+                   batch_config->requestsInfo[req_idx].first_token_index_in_request;
+      q_indptr_h[indptr_idx + 1] = q_indptr_h[indptr_idx] + q_len;
+      kv_indptr_h[indptr_idx + 1] = kv_indptr_h[indptr_idx] + (kv_len + kPagesize - 1) / kPagesize;
+      indptr_idx++;
+    }
+  }
+
   handler->SetCUDAStream(stream);
-  // printf("here??\n");
   handler->BeginForward<half, int32_t>(static_cast<void *>(
       static_cast<char *>(handle.tree_verify_attention_metadata->workspace) +
       handle.tree_verify_attention_metadata->workspace_block * batch_size),
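
Reviewer note on the new indptr hunk in load_batch_config_task: the handler's
BeginForward consumes CSR-style offset arrays, where q_indptr partitions the
flat token stream by request and kv_indptr counts whole KV-cache pages, so each
request's kv length is rounded up via (kv_len + kPagesize - 1) / kPagesize.
Below is a minimal standalone sketch of that packing logic, not part of the
patch; the kPagesize value and the RequestInfo struct are illustrative
stand-ins for the patch's BatchConfig fields.

    // Sketch: build FlashInfer-style q/kv indptr arrays for a paged KV cache.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    constexpr int kPagesize = 64; // assumed tokens per page (stand-in value)

    struct RequestInfo {
      bool available;                   // mirrors request_available[i]
      int num_tokens_in_batch;          // query tokens added this step
      int first_token_index_in_request; // tokens already in the KV cache
    };

    int main() {
      std::vector<RequestInfo> reqs = {
          {true, 4, 100}, {false, 0, 0}, {true, 1, 63}};

      // q_indptr[i]..q_indptr[i+1] spans request i's query tokens;
      // kv_indptr counts whole pages, so kv lengths round up to a page.
      std::vector<int32_t> q_indptr = {0}, kv_indptr = {0};
      for (auto const &r : reqs) {
        if (!r.available) {
          continue; // unavailable requests get no indptr slot, as in the patch
        }
        int q_len = r.num_tokens_in_batch;
        int kv_len = r.num_tokens_in_batch + r.first_token_index_in_request;
        q_indptr.push_back(q_indptr.back() + q_len);
        kv_indptr.push_back(kv_indptr.back() +
                            (kv_len + kPagesize - 1) / kPagesize);
      }

      for (size_t i = 0; i + 1 < q_indptr.size(); i++) {
        printf("req %zu: q tokens [%d, %d), kv pages [%d, %d)\n", i,
               (int)q_indptr[i], (int)q_indptr[i + 1], (int)kv_indptr[i],
               (int)kv_indptr[i + 1]);
      }
      return 0;
    }

With kPagesize = 64, the first request (kv_len = 104) occupies pages [0, 2) and
the third (kv_len = 64) exactly one page, which is the ceiling-division
behavior the hunk relies on when sizing kv_indptr_h.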