Skip to content

Commit

Permalink
some further cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
Bob-Chen222 committed Sep 25, 2024
1 parent 311ed13 commit 1297687
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 104 deletions.
5 changes: 2 additions & 3 deletions include/flexflow/batch_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,8 @@ class BatchConfig {
int num_tokens_in_batch = 0;

// page attention: we need some additional attention information here to allocate physical blocks in load_batch_config
// TODO: might need to add more fields here
int32_t num_kv_pages; //number of kv pages used
int32_t kv_last_page_len;
int32_t kv_last_page_len; //last page length of kv
RequestGuid request_guid;
};

Expand All @@ -88,7 +87,7 @@ class BatchConfig {
int request_index = -1;
};

std::vector<int32_t> page_indices; //the indices for each page
std::vector<int32_t> page_indices; //the physical block indices for each page

struct CommittedTokensInfo {
int index_in_kv_cache = -1; // the index in the temporary key-value cache
Expand Down
14 changes: 0 additions & 14 deletions src/ops/tree_inc_multihead_self_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ using flashinfer::QKVLayout;
__device__ __forceinline__ size_t get_k_entry_offset(int const token_idx,
int const page_idx,
int const hidden_size) {
// page attention: changed
size_t index = ((page_idx) * kPagesize * 2 + (token_idx % kPagesize)) * hidden_size;
return index;
}
Expand Down Expand Up @@ -138,7 +137,6 @@ __global__ void commit_tokens_kernel(
}
}
// get the starting index of kv page
// page attention: WARNING: this implicitly assume that the kv page is stored in the same order as the available requests
int start = kv_indptr[requext_idx_in_batch];
int end = kv_indptr[requext_idx_in_batch + 1] - 1;
for (int i = 0; i < num_committed_tokens; i++) {
Expand All @@ -154,8 +152,6 @@ __global__ void commit_tokens_kernel(
int const page_from_idx = kv_page_indices[start + (tok_id / kPagesize)];

// page attention: since we cannot store temporary tokens in the cache, we need to figure out another way
// WARNING: we assume that index_in_kv_cache is flattened index in gpu memory
// index_in_kv_cache is actually the flattened index
size_t from_k_idx = get_k_entry_offset(index_in_kv_cache, page_from_idx, hidden_size),
from_v_idx = get_v_entry_offset(index_in_kv_cache, page_from_idx, hidden_size);

Expand Down Expand Up @@ -309,7 +305,6 @@ __global__ void
size_t from_idx = token_idx * QKV_WEIGHT_NUM * hidden_size;
size_t to_k_idx = get_k_entry_offset(token_abs_idx, page_idx, hidden_size),
to_v_idx = get_v_entry_offset(token_abs_idx, page_idx, hidden_size);
// printf("to_k_idx: %lu, to_v_idx: %lu\n", to_k_idx, to_v_idx);
// key and value cache should be stored interleaved
kCache_ptr[to_k_idx + offset] =
Expand Down Expand Up @@ -471,15 +466,6 @@ void tree_verify_attention(TreeIncMultiHeadSelfAttentionMeta const *m,
m->handle.tree_verify_attention_metadata->kv_last_page_len,
sizeof(int32_t) * BatchConfig::MAX_NUM_REQUESTS,
cudaMemcpyDeviceToHost);
//print the request information
// for (int i = 0; i < bc->num_active_requests(); i++) {
// printf("request %d: ", i);
// for (int j = kv_indptr_tmp[i]; j < kv_indptr_tmp[i + 1]; j++) {
// printf("%d ", kv_indices_tmp[j]);
// }
// printf("\n");
// }
// printf("last page length: %d\n", kv_last_page_len_tmp[0]);
paged_kv_t<PageStorage::kIndices, QKVLayout::kNHD, half, int32_t> paged_kv(
num_kv_heads,
kPagesize,
Expand Down
70 changes: 0 additions & 70 deletions src/runtime/request_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -899,34 +899,19 @@ BatchConfig RequestManager::prepare_llm_prefilling_batch() {
std::cerr << "llm_cache_size: " << prefill_request->llm_cache_size << std::endl;

PageManager *page_manager = nullptr;
// page attention: add logical blocks here
// TODO: currently only support specinfer, might need to support incremental
if (decoding_mode == SPECULATIVE_DECODING) {
std::cerr << "number of tokens in prefilling request: " << prefill_request->tokens.size() << std::endl;
int start = prefill_request->llm_cache_size;
int end = prefill_request->llm_cache_size + prefill_request->num_tokens_in_batch;
_append_tokens_to_blocks(*prefill_request, prefill_request->tokens, true, start, end);
// printf("append block\n");
// printf("prefilling request num_tokens: %d\n", prefill_request->tokens.size());
page_manager = PageManager::get_page_manager();
// printf("page manager address prepare: %p\n", page_manager);
assert(page_manager != nullptr);
// we first need to update the physical block numbers
int num_allocated_blocks = page_manager->get_num_allocated_blocks(guid);
std::cerr << "called prefiling num_allocated_blocks: " << num_allocated_blocks << std::endl;
std::cerr << "num_allocated_blocks: " << num_allocated_blocks << std::endl;
std::cerr << "request.blocks.size(): " << request.blocks.size() << std::endl;
int diff_block = request.blocks.size() - num_allocated_blocks;
std::cerr << "diff_block: " << diff_block << std::endl;
assert(diff_block >= 0);
for (int i = 0; i < diff_block; i++) {
assert(false);
page_manager->allocate(guid);
}
bc.requestsInfo[request_index].kv_last_page_len = request.blocks.back().get_num_alloc_slots();
assert(bc.requestsInfo[request_index].kv_last_page_len <= 64);
bc.requestsIndices[request_index] = page_manager->get_block_table_indices(guid);
// update the num kv pages
bc.requestsInfo[request_index].num_kv_pages = bc.requestsIndices[request_index].size();
bc.requestsInfo[request_index].request_guid = guid;
}
Expand All @@ -949,19 +934,7 @@ BatchConfig RequestManager::prepare_llm_prefilling_batch() {
if (verbose) {
std::cout << "prepare_llm_prefilling_batch NEW batchconfig:" << std::endl;
bc.print();
// also print the page indices
if (decoding_mode == SPECULATIVE_DECODING) {
std::cout << "page indices are: " << std::endl;
std::vector<int> page_indices = page_manager -> get_block_table_indices(1000000);
for (int i = 0; i < page_indices.size(); i++) {
std::cout << page_indices[i] << " ";
}
std::cout << "last page len: " << request.blocks.back().get_num_alloc_slots() << std::endl;
std::cout << "last page commit token: " << request.blocks.back().num_commit_tokens << std::endl;
std::cout << "last page spec token: " << request.blocks.back().num_spec_tokens << std::endl;
}
}
// printf("end of prepare_llm_prefilling_batch\n");
return bc;
}

Expand Down Expand Up @@ -1263,7 +1236,6 @@ BatchConfig RequestManager::prepare_verify_batch_config() {
std::cout
<< "\n############### prepare_verify_batch_config ###############\n";
}
// printf("prepare_verify_batch_config_lalala\n");
// This method does the following:
// 1. Commit the verified tokens in the last iteration through the
// BatchConfig. We can do this request by request.
Expand Down Expand Up @@ -1306,19 +1278,16 @@ BatchConfig RequestManager::prepare_verify_batch_config() {

//page attention: delete the spec tokens in the logical block
assert(request.blocks.size() == page_manager->get_num_allocated_blocks(guid));
// get a copy of guid's physical blocks table
std::vector<int> block_table = page_manager->get_block_table_indices(guid);
std::vector<int> block_table_copy = block_table;
if (request.page_id_commit + 1 < request.blocks.size()) {
request.blocks.erase(request.blocks.begin() + request.page_id_commit + 1, request.blocks.end());
// std::cerr << "page_id_commit: " << request.page_id_commit << std::endl;
page_manager->erase_last_pages(guid, request.page_id_commit);
}
request.blocks.back().reset_num_spec_tokens();
block_table = page_manager->get_block_table_indices(guid);

// we still need to assure that number of logical blocks is the same as the number of physical blocks
// std::cerr << "number of logical blocks after: " << request.blocks.size() << std::endl;
assert(request.blocks.size() == page_manager->get_num_allocated_blocks(guid));


Expand All @@ -1344,19 +1313,7 @@ BatchConfig RequestManager::prepare_verify_batch_config() {
// page attention: add to request's logical block
_append_tokens_to_blocks(request, {committed_token.token_id}, true);
new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_depth = page_manager->get_block_table_indices(guid).back() * kPagesize + request.blocks.back().get_num_alloc_slots() - 1;
printf("token depth: %d\n", new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_depth);
printf("back index: %d\n", page_manager->get_block_table_indices(guid).back());
printf("last commit page: %d\n", request.page_id_commit);
// printf("index_to_kv_cache: %d\n", new_bc.committed_tokens[new_bc.num_tokens_to_commit].index_to_kv_cache);
}

// printf("num tokens currently in the last page: %d\n", request.blocks.back().get_num_alloc_slots());


// Load the tokens on the token tree that are not yet pruned to
// BatchConfig.tokensInfo.
// page attention: we should also add these tokens to the logical blocks
// std:cerr << "here1\n" << std::endl;
TokenTree &token_tree = request.speculative_token_trees[0];
int token_tree_index = 0;
int layer_index = 0;
Expand All @@ -1376,36 +1333,20 @@ BatchConfig RequestManager::prepare_verify_batch_config() {
}
layer_index++;
}
// std::cerr << "here2\n" << std::endl;
assert(token_tree_index == token_tree.tree_size);
// page attention: add metadata here
// I think we are now already have updated logical block data in update_results, and
// we need to update the block table here

// get page manager
// we first need to update the physical block numbers
int diff_block = request.blocks.size() - page_manager->get_num_allocated_blocks(guid);
assert(diff_block >= 0);
for (int i = 0; i < diff_block; i++) {
std::cerr << "allocate new block\n";
assert(false);
page_manager->allocate(guid);
}

std::cerr << "number of physical blocks after: " << page_manager->get_num_allocated_blocks(guid) << std::endl;
for (int i = 0; i < page_manager->get_block_table_indices(guid).size(); i++) {
std::cerr << page_manager->get_block_table_indices(guid)[i] << " ";
}
std::cerr << std::endl;
// update last kv len
new_bc.requestsInfo[request_index].kv_last_page_len = request.blocks.back().get_num_alloc_slots();
// update the block table
new_bc.requestsIndices[request_index] = page_manager->get_block_table_indices(guid);
// update the num kv pages
new_bc.requestsInfo[request_index].num_kv_pages = new_bc.requestsIndices[request_index].size();

new_bc.requestsInfo[request_index].request_guid = guid;

new_bc.requestsInfo[request_index].num_tokens_in_batch = token_tree_index;

request.first_token_offset_in_batch = new_bc.num_tokens - token_tree_index;
Expand All @@ -1419,14 +1360,6 @@ BatchConfig RequestManager::prepare_verify_batch_config() {
if (verbose) {
std::cout << "prepare_verify_batch_config NEW batchconfig:" << std::endl;
new_bc.print();
std::cout << "page indices are: " << std::endl;
std::vector<int> page_indices = page_manager -> get_block_table_indices(1000000);
for (int i = 0; i < page_indices.size(); i++) {
std::cout << page_indices[i] << " ";
}
std::cout << "last page len: " << all_requests[1000000].blocks.back().get_num_alloc_slots() << std::endl;
std::cout << "last page commit token: " << all_requests[1000000].blocks.back().num_commit_tokens << std::endl;
std::cout << "last page spec token: " << all_requests[1000000].blocks.back().num_spec_tokens << std::endl;
}
profiling.llm_step_start = Realm::Clock::current_time_in_microseconds();
return new_bc;
Expand Down Expand Up @@ -1594,13 +1527,10 @@ void RequestManager::_append_logical_block_to_request(
// update page_id_commit
if (is_commit) {
request.page_id_commit++;
// printf("page_id_commit: %d\n", request.page_id_commit);
// printf("blocks size: %d\n", request.blocks.size());
assert(request.page_id_commit < request.blocks.size());
}
}

// [start, end) is the number of tokens that we want to extract
void RequestManager::_append_tokens_to_blocks(Request &request, std::vector<TokenId> const &tokens, bool is_commit, int start, int end) {
assert(start >= 0 && start < tokens.size());
int cursor = start;
Expand Down
18 changes: 1 addition & 17 deletions src/runtime/request_manager.cu
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void RequestManager::load_tokens_task(
}
}

// TODO: add detailed documentation
// pass the kv page related information to the handle located at GPU
void prepare_inference_params_kernel_h(BatchConfig const *batch_config,
PageManager *pm,
FFHandler handle,
Expand All @@ -80,8 +80,6 @@ void prepare_inference_params_kernel_h(BatchConfig const *batch_config,
int32_t *qk_indptr_h) {
int batch_size = batch_config->num_active_requests();
// we just search for the page number for each request
// kv_last_page_len can be handled
// kv_indices can be handled
q_indptr_h[0] = 0;
kv_indptr_h[0] = 0;
qk_indptr_h[0] = 0;
Expand Down Expand Up @@ -543,20 +541,6 @@ void RequestManager::load_batch_config_task(
handle.tree_verify_attention_metadata->prompt_handler_collections[batch_size]);
}
// static int32_t q_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1], kv_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1];
// q_indptr_h[0] = 0;
// kv_indptr_h[0] = 0;
// for (int req_idx = 0, indptr_idx = 0; req_idx < batch_config->max_requests_per_batch(); req_idx++) {
// if (batch_config->request_available[req_idx]) {
// int q_len = batch_config->requestsInfo[req_idx].num_tokens_in_batch;
// int kv_len = batch_config->requestsInfo[req_idx].num_tokens_in_batch +
// batch_config->requestsInfo[req_idx].first_token_index_in_request;
// q_indptr_h[indptr_idx + 1] = q_indptr_h[indptr_idx] + q_len;
// kv_indptr_h[indptr_idx + 1] = kv_indptr_h[indptr_idx] + (kv_len + kPagesize - 1) / kPagesize;
// indptr_idx++;
// }
// }
handler->SetCUDAStream(stream);
handler->BeginForward<half, int32_t>(static_cast<void*>(
static_cast<char*>(handle.tree_verify_attention_metadata->workspace) +
Expand Down

0 comments on commit 1297687

Please sign in to comment.