Skip to content

Commit

Permalink
ckpt single request
Browse files: browse the repository at this point in the history
  • Loading branch information
Bob-Chen222 committed Oct 11, 2024
1 parent b12df8c commit 6298f2a
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 5 deletions.
2 changes: 2 additions & 0 deletions include/flexflow/page_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class LogicalTokenBlock {
void append_tokens(const std::vector<TokenId>& token_ids_to_append, bool committed);

// Accessor: returns the current value of the num_tokens member.
int get_num_tokens() const {
  return num_tokens;
}
// Accessor: returns the current value of the num_commit_tokens member.
// NOTE(review): presumably the count of tokens appended with committed == true
// — confirm against append_tokens.
int get_num_commit_tokens() const {
  return num_commit_tokens;
}
// Accessor: returns the current value of the num_spec_tokens member.
// NOTE(review): presumably the count of speculative (not yet committed) tokens
// — confirm against append_tokens.
int get_num_spec_tokens() const {
  return num_spec_tokens;
}

std::vector<TokenId> get_token_ids() const;

Expand Down
1 change: 1 addition & 0 deletions include/flexflow/request_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ class RequestManager {
Request &request, bool is_commit);
// Appends `token` to the request's block list and returns an index for it.
// The visible tail of the definition appends the token to
// request.blocks.back() (passing `is_commit` through to append_tokens), then
// returns idx_logical_to_physical() applied to the request's last logical
// token index. NOTE(review): the start of the definition is not visible here;
// confirm whether a new block is allocated first when the last one is full.
int append_token_to_block(Request &request, TokenId token, bool is_commit);
// NOTE(review): only the end of the definition is visible here; it prints
// per-block token counts and asserts that request.blocks.size() matches the
// page manager's block-table size for request.guid. Confirm the actual reset
// logic against the full definition.
void reset_block_table(Request &request);
// NOTE(review): definition not visible in this view; presumably a debugging
// helper that prints token counts for `request` — confirm.
void print_num_tokens(Request &request);

// Token tree related
void init_token_tree(RequestGuid guid);
Expand Down
4 changes: 2 additions & 2 deletions src/ops/tree_inc_multihead_self_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ __global__ void commit_tokens_kernel(
}
}
int start = kv_indptr[requext_idx_in_batch];
int end = kv_indptr[requext_idx_in_batch + 1] - 1;
// int start = kv_indptr[requext_idx_in_batch];
// int end = kv_indptr[requext_idx_in_batch + 1] - 1;
for (int i = 0; i < *num_committed_tokens; i++) {
if (committedTokenInfos[i].request_index == requext_idx_in_batch) {
Expand Down
14 changes: 13 additions & 1 deletion src/runtime/request_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1529,10 +1529,11 @@ BatchConfig RequestManager::prepare_verify_batch_config() {
committed_tokens.at(committed_token_index);

// assert(request.page_last_committed < request.blocks.size());
// printf("in verify: page_last_committed: %d, request->blocks.size(): %d\n", request.page_last_committed, request.blocks.size());
printf("in verify: page_last_committed: %d, request->blocks.size(): %d\n", request.page_last_committed, request.blocks.size());
int idx_to_physical = append_token_to_block(request, committed_token.token_id, true);
int idx_from_logical = committed_token.from_index - request.first_token_offset_in_batch;
int idx_from_physical = block_table_before_commit[idx_from_logical / kPagesize] * kPagesize + committed_token.from_index % kPagesize;
printf("id to physical: %d, from physical: %d\n", idx_to_physical, idx_from_physical);


new_bc.committed_tokens[new_bc.num_tokens_to_commit].request_index =
Expand Down Expand Up @@ -1941,6 +1942,10 @@ void RequestManager::_append_block_to_request(
kPagesize);
request.blocks.push_back(block);
page_manager->allocate_one_block(request.guid);
std::vector<int> block_table_indices = page_manager->get_block_table_indices(request.guid);
for (int i = 0; i < block_table_indices.size(); i++) {
printf("block table indices: %d\n", block_table_indices[i]);
}
assert(request.blocks.size() == page_manager->get_block_table_indices(request.guid).size());
// update page_id_commit
if (is_commit) {
Expand Down Expand Up @@ -1969,7 +1974,9 @@ int RequestManager::append_token_to_block(Request &request, TokenId token, bool
request.blocks.back().append_tokens({token}, is_commit);
assert(request.blocks.size() == page_manager->get_block_table_indices(request.guid).size());
int idx_logical = get_idx_last_logical_token(request);
// printf("idx_logical: %d\n", idx_logical);
int idx_physical = idx_logical_to_physical(request, idx_logical);
// printf("idx_physical: %d\n", idx_physical);
return idx_physical;
}

Expand All @@ -1992,6 +1999,11 @@ void RequestManager::reset_block_table(Request &request){
printf("number of blocks: %d\n", request.blocks.size());
printf("num spec tokens: %d\n", request.blocks.back().get_num_spec_tokens());
printf("num committed tokens: %d\n", request.blocks.back().get_num_commit_tokens());
// the indices of block table should be the same as the number of blocks
std::vector<int> block_table = page_manager->get_block_table_indices(request.guid);
for (int i = 0; i < request.blocks.size(); i++) {
printf("block table indices: %d\n", block_table[i]);
}

assert(request.blocks.size() == page_manager->get_block_table_indices(request.guid).size());
return;
Expand Down
17 changes: 15 additions & 2 deletions src/runtime/request_manager.cu
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ void prepare_inference_params_kernel_h(BatchConfig const *batch_config,
int q_len = batch_config->requestsInfo[req_idx].num_tokens_in_batch;
int kv_len = batch_config->requestsInfo[req_idx].num_tokens_in_batch +
batch_config->requestsInfo[req_idx].first_token_index_in_request;

q_lens += q_len;
qk_lens += (q_len * kv_len + 7) / 8;
indices_offset = indices_lens;
Expand All @@ -106,8 +107,12 @@ void prepare_inference_params_kernel_h(BatchConfig const *batch_config,
kv_indptr_h[indptr_idx + 1] = batch_config->requestsInfo[req_idx].num_kv_pages + kv_indptr_h[indptr_idx];

assert(batch_config->requestsInfo[req_idx].num_kv_pages == (kv_len + kPagesize - 1) / kPagesize);
assert(batch_config->requestsInfo[req_idx].kv_last_page_len <= 64);
assert(batch_config->requestsInfo[req_idx].kv_last_page_len <= kPagesize);
std::vector<int32_t> kv_indices = pm -> get_block_table_indices(batch_config->requestsInfo[req_idx].request_guid);
printf("request_guid: %d\n", batch_config->requestsInfo[req_idx].request_guid);
printf("kv_indices.size() = %d, kv_len = %d\n", kv_indices.size(), kv_len);
printf("kv last page len = %d\n", batch_config->requestsInfo[req_idx].kv_last_page_len);
printf("num_kv_pages = %d\n", batch_config->requestsInfo[req_idx].num_kv_pages);
assert(kv_indices.size() == (kv_len + kPagesize - 1) / kPagesize);
for (int i = indices_offset; i < indices_lens; i++) {
kv_indices_h[i] = kv_indices[i - indices_offset];
Expand Down Expand Up @@ -616,7 +621,9 @@ void RequestManager::load_batch_config_task(
}
}
} else if (batch_config->get_mode() == TREE_VERIFY_MODE) {
static PageManager *pm = PageManager::get_page_manager();
PageManager *pm = PageManager::get_page_manager();
// hardcode request
// printf("request has allocated %d pages\n", pm -> get_block_table_indices(1000000).size());
static int32_t q_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1], kv_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1];
static int32_t kv_indices_h[BatchConfig::MAX_NUM_REQUESTS * BatchConfig::MAX_NUM_TOKENS];
static int32_t qk_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1];
Expand Down Expand Up @@ -760,6 +767,12 @@ void RequestManager::load_batch_config_task(
handle.tree_verify_attention_metadata->num_kv_heads(),
handle.tree_verify_attention_metadata->head_dim(),
kPagesize);
cudaError_t syncErr = cudaDeviceSynchronize();
if (syncErr != cudaSuccess) {
printf("Kernel execution error: %s\n", cudaGetErrorString(syncErr));
assert(false);
}
}
}
}
Expand Down

0 comments on commit 6298f2a

Please sign in to comment.