Skip to content

Commit

Permalink
add some docuementation and delete print
Browse files Browse the repository at this point in the history
  • Loading branch information
Bob-Chen222 committed Oct 21, 2024
1 parent 945dee9 commit 19e41d6
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 131 deletions.
20 changes: 2 additions & 18 deletions include/flexflow/page_manager.h
Original file line number Diff line number Diff line change
@@ -1,18 +1,3 @@
/* Copyright 2023 CMU, Stanford, Facebook, LANL
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "flexflow/batch_config.h"
Expand All @@ -31,7 +16,7 @@ using TokenId = BatchConfig::TokenId;

/**
* @class LogicalTokenBlock
* @brief A class to represent a logical block of tokens similar to virtual memory address
* @brief A class to represent a sequence of tokens for each request
*/
class LogicalTokenBlock {
public:
Expand Down Expand Up @@ -70,7 +55,6 @@ class LogicalTokenBlock {
int num_tokens; // the number of tokens currently stored in the block
int num_commit_tokens; // the number of tokens inside this block that are already committed
int num_spec_tokens; // the number of tokens inside this block that are speculative tokens, which is stored temporarily

std::vector<TokenId> token_ids; //store the token ids in a order that corresponds to the inference sequence
};

Expand Down Expand Up @@ -132,9 +116,9 @@ class PageManager {
using RequestGuid = BatchConfig::RequestGuid;
PageManager(int block_size, int num_total_blocks);


int allocate_one_block(const RequestGuid& request_guid);
void free_request(const RequestGuid& request_guid);
//used for the case that we want to free the last num_blocks that stores spec tokens(which are the tokens are not yet committed)
void free_multiple_blocks(const RequestGuid& request_guid, int num_blocks);
std::vector<int> get_block_table_indices(const RequestGuid& request_guid) const;

Expand Down
3 changes: 0 additions & 3 deletions src/ops/tree_inc_multihead_self_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,6 @@ __global__ void
if (offset < kv_hidden_size) {
int start = kv_indptr[req_idx_compact];
int end = kv_indptr[req_idx_compact + 1] - 1;
if (start > end) {
printf("Invalid kv_indptr: %d %d\n", start, end);
}
assert(start <= end && "Invalid kv_indptr");
assert(start + (token_abs_idx / kPagesize) <= end &&
"Invalid page index");
Expand Down
5 changes: 1 addition & 4 deletions src/runtime/page_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ void BlockAllocator::free(PhysicalTokenBlock& block) {
}
block.decr_ref_count();
if (block.ref_count == 0) {
printf("put block number: %d back to free_blocks\n", block.get_block_number());
// printf("put block number: %d back to free_blocks\n", block.get_block_number());
free_blocks.push_back(block);
}else{
// in current implementation this should not be the case
Expand All @@ -128,7 +128,6 @@ int BlockAllocator::get_num_free_blocks() const {
PageManager::PageManager(int block_size, int num_total_blocks)
: block_size(block_size), num_total_blocks(num_total_blocks),
block_allocator(block_size, num_total_blocks) {
printf("page manager init with block_size: %d, num_total_blocks: %d\n", block_size, num_total_blocks);
}

//return the physical number of this block
Expand All @@ -138,7 +137,6 @@ int PageManager::allocate_one_block(const RequestGuid& request_guid) {
PhysicalTokenBlock block = block_allocator.allocate();
block_table.push_back(block);
block_tables[request_guid] = block_table;
printf("request_guid: %d, block_number: %d\n", request_guid, block.get_block_number());
return block.get_block_number();
}

Expand Down Expand Up @@ -184,7 +182,6 @@ std::vector<int> PageManager::get_block_table_indices(const RequestGuid& request
std::vector<int> indices;
const auto& it = block_tables.find(request_guid);
if (it == block_tables.end()) {
printf("page manager not found request_guid: %d\n", request_guid);
return indices;
}
const auto& block_table = it->second;
Expand Down
53 changes: 9 additions & 44 deletions src/runtime/request_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,8 @@ void RequestManager::set_spec_infer_old_version(bool spec_infer_old_version_) {
spec_infer_old_version = spec_infer_old_version_;
}

void RequestManager::set_greedy_scheduler(bool greedy_scheduler_) {
greedy_scheduler = greedy_scheduler_;
void RequestManager::set_greedy_schedule(bool greedy_scheduler_) {
greedy_schedule = greedy_scheduler_;
}

void RequestManager::set_equal_schedule(bool equal_schedule_) {
Expand All @@ -341,8 +341,8 @@ bool RequestManager::get_spec_infer_old_version() {
return spec_infer_old_version;
}

bool RequestManager::get_greedy_scheduler() {
return greedy_scheduler;
bool RequestManager::get_greedy_schedule() {
return greedy_schedule;
}

bool RequestManager::get_equal_schedule() {
Expand Down Expand Up @@ -683,7 +683,6 @@ void RequestManager::request_complete_clean_up(int batch_index) {

// page attention: free the pages
PageManager *page_manager = PageManager::get_page_manager();
printf("free request %d\n", guid);
page_manager->free_request(guid);

// Find the sos and eos in the sequence
Expand Down Expand Up @@ -1148,14 +1147,11 @@ BatchConfig RequestManager::prepare_llm_prefilling_batch() {
}
//update related page info in batch config
bc.requestsInfo[request_index].num_kv_pages = get_num_blocks_allocated(*request);
printf("request: %d has %d kv pages after prefilling\n", request->guid, bc.requestsInfo[request_index].num_kv_pages);
// WARNING: it is possible that it has no tokens allocated!! but not allowed for flashinfer
if (bc.requestsInfo[request_index].num_kv_pages == 0) {
// turn this request into not available
// turn this request into not available for one round
bc.request_available[request_index] = false;
}
bc.requestsInfo[request_index].kv_last_page_len = get_len_last_block(*request);
printf("request: %d has %d kv last page len after prefilling\n", request->guid, bc.requestsInfo[request_index].kv_last_page_len);
bc.requestsInfo[request_index].request_guid = request->guid;
}
bc.num_tokens = num_tokens;
Expand All @@ -1164,7 +1160,6 @@ BatchConfig RequestManager::prepare_llm_prefilling_batch() {
std::cout << "prepare_llm_prefilling_batch NEW batchconfig:" << std::endl;
bc.print();
}
printf("there are %d requests in the batch in prefilling stage\n", bc.num_available_requests);
return bc;
}

Expand Down Expand Up @@ -1550,7 +1545,7 @@ BatchConfig RequestManager::prepare_verify_batch_config() {
Request &request = all_requests[guid];
assert(request.status == Request::RUNNING);

//page attention: before commit token, reset the pages assigned by cleaning all the tokens
//before commit token, reset the pages assigned by cleaning all the tokens
std::vector<int> block_table_before_commit = page_manager->get_block_table_indices(guid);
// also need to reset the pages
reset_block_table(request);
Expand All @@ -1576,17 +1571,11 @@ BatchConfig RequestManager::prepare_verify_batch_config() {
Request::CommittedToken &committed_token =
committed_tokens.at(committed_token_index);

// assert(request.page_last_committed < request.blocks.size());
printf("in verify: page_last_committed: %d, request->blocks.size(): %d\n", request.page_last_committed, request.blocks.size());
int idx_to_physical = append_token_to_block(request, committed_token.token_id, true);
int idx_from_logical = committed_token.from_index;
if (idx_from_logical < 0) {
printf("idx_from_logical: %d, from_index: %d, first_token_offset_in_batch: %d\n", idx_from_logical, committed_token.from_index, request.first_token_offset_in_batch);
}
assert(idx_from_logical >= 0);
assert(idx_from_logical / kPagesize < block_table_before_commit.size());
int idx_from_physical = block_table_before_commit[idx_from_logical / kPagesize] * kPagesize + committed_token.from_index % kPagesize;
printf("id to physical: %d, from physical: %d\n", idx_to_physical, idx_from_physical);


new_bc.committed_tokens[new_bc.num_tokens_to_commit].request_index =
Expand Down Expand Up @@ -1615,15 +1604,11 @@ BatchConfig RequestManager::prepare_verify_batch_config() {
token_tree_index++;

// Append the token to the block
// printf("in verify spec tree: page_last_committed: %d, request->blocks.size(): %d ", request.page_last_committed, request.blocks.size());
// printf("in verify spec tree: last page len: %d\n", get_len_last_block(request));
// assert(request.page_last_committed < request.blocks.size());
append_token_to_block(request, tree_node->id, false);
}
}
layer_index++;
}
printf("there are %d tokens in the token tree\n", token_tree_index);
new_bc.requestsInfo[request_index].num_tokens_in_batch = token_tree_index;

request.first_token_offset_in_batch = new_bc.num_tokens - token_tree_index;
Expand Down Expand Up @@ -1973,6 +1958,7 @@ int RequestManager::get_len_last_block(Request &request) const {
return request.blocks.back().get_num_tokens();
}

// get the index of the last token in the request
int RequestManager::get_idx_last_logical_token(Request &request) const {
if (request.blocks.empty()) {
printf("Error: request.blocks is empty\n");
Expand All @@ -1987,8 +1973,6 @@ int RequestManager::idx_logical_to_physical(Request &request, int idx_logical) {
PageManager *page_manager = PageManager::get_page_manager();
std::vector<int> block_table_indices = page_manager->get_block_table_indices(request.guid);
if (request.blocks.size() != block_table_indices.size()) {
printf("page manager get block table indices: %d, request.blocks.size(): %d\n", page_manager->get_block_table_indices(request.guid).size(), request.blocks.size());
printf("request.blocks.size(): %d, block_table_indices.size(): %d\n", request.blocks.size(), block_table_indices.size());
assert(request.blocks.size() == block_table_indices.size());
}
return block_table_indices[idx_logical / kPagesize] * kPagesize + idx_logical % kPagesize;
Expand All @@ -2007,25 +1991,18 @@ void RequestManager::_append_block_to_request(
request.blocks.push_back(block);
page_manager->allocate_one_block(request.guid);
std::vector<int> block_table_indices = page_manager->get_block_table_indices(request.guid);
// for (int i = 0; i < block_table_indices.size(); i++) {
// printf("block table indices: %d\n", block_table_indices[i]);
// }
assert(request.blocks.size() == page_manager->get_block_table_indices(request.guid).size());
// update page_id_commit
if (is_commit) {
request.page_last_committed++;
int size_blocks = request.blocks.size();
if (request.page_last_committed >= size_blocks) {
printf("request page_last_committed: %d, size_blocks) {: %d\n", request.page_last_committed, size_blocks);
assert(request.page_last_committed < static_cast<int>(request.blocks.size()));
}
assert(request.page_last_committed < static_cast<int>(request.blocks.size()));
}
}

//this function is used for appending a token to the last logical block and also the last physical block
//it will return the physical position of this token
int RequestManager::append_token_to_block(Request &request, TokenId token, bool is_commit) {
// assert(request.page_last_committed < request.blocks.size());
PageManager *page_manager = PageManager::get_page_manager();
if (request.blocks.empty() ||
request.blocks.back().is_full()) {
Expand Down Expand Up @@ -2058,16 +2035,8 @@ void RequestManager::reset_block_table(Request &request){
request.blocks.erase(request.blocks.begin() + request.page_last_committed + 1, request.blocks.end());
}
request.blocks.back().reset_num_spec_tokens();
printf("after reset, block now has %d tokens\n", request.blocks.back().get_num_tokens());
printf("number of pages allocated: %d\n", page_manager->get_block_table_indices(request.guid).size());
// printf("number of blocks: %d\n", request.blocks.size());
// printf("num spec tokens: %d\n", request.blocks.back().get_num_spec_tokens());
// printf("num committed tokens: %d\n", request.blocks.back().get_num_commit_tokens());
// the indices of block table should be the same as the number of blocks
std::vector<int> block_table = page_manager->get_block_table_indices(request.guid);
// for (int i = 0; i < request.blocks.size(); i++) {
// printf("block table indices: %d\n", block_table[i]);
// }

assert(request.blocks.size() == page_manager->get_block_table_indices(request.guid).size());
return;
Expand Down Expand Up @@ -2469,17 +2438,13 @@ void RequestManager::start_background_server(FFModel *model) {
background_server_handler = runtime->execute_task(ctx, launcher);
// Register callbacks for normal exit
{
printf("called exit\n");
int ret = std::atexit(RequestManager::terminate_background_server_at_exit);
printf("return from exit\n");
assert(ret == 0); // make sure the callback is successfully registered
}
// Register callbacks for termination
{
printf("called terminate\n");
std::set_terminate([]() {
RequestManager::terminate_background_server_at_exit();
printf("return from terminate\n");
printStackTrace();
std::abort();
});
Expand Down Expand Up @@ -3395,7 +3360,7 @@ void RequestManager::add_tokens_toward_goodput_per_request(int budget,
Request &request = all_requests[guid];
assert(request.status == Request::RUNNING);
if (request.token_tree_nodes_acc_prob_pair_pq.empty()) {
continue;
return;
}

auto &pq = request.token_tree_nodes_acc_prob_pair_pq;
Expand Down
65 changes: 3 additions & 62 deletions src/runtime/request_manager.cu
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ void prepare_inference_params_kernel_h(BatchConfig const *batch_config,
q_indptr_h[0] = 0;
kv_indptr_h[0] = 0;
qk_indptr_h[0] = 0;
int cnt_1 = 0, q_lens = 0, qk_lens = 0;
int indices_offset = 0, indices_lens = 0, kv_len = 0;
int q_lens = 0, qk_lens = 0;
int indices_offset = 0, indices_lens = 0;
for (int req_idx = 0, indptr_idx = 0; req_idx < batch_config->max_requests_per_batch(); req_idx++) {
if (batch_config->request_available[req_idx]) {
int q_len = batch_config->requestsInfo[req_idx].num_tokens_in_batch;
Expand All @@ -110,10 +110,6 @@ void prepare_inference_params_kernel_h(BatchConfig const *batch_config,
assert(batch_config->requestsInfo[req_idx].num_kv_pages == (kv_len + kPagesize - 1) / kPagesize);
assert(batch_config->requestsInfo[req_idx].kv_last_page_len <= kPagesize);
std::vector<int32_t> kv_indices = pm -> get_block_table_indices(batch_config->requestsInfo[req_idx].request_guid);
// printf("request_guid: %d\n", batch_config->requestsInfo[req_idx].request_guid);
// printf("kv_indices.size() = %d, kv_len = %d\n", kv_indices.size(), kv_len);
// printf("kv last page len = %d\n", batch_config->requestsInfo[req_idx].kv_last_page_len);
// printf("num_kv_pages = %d\n", batch_config->requestsInfo[req_idx].num_kv_pages);
assert(kv_indices.size() == (kv_len + kPagesize - 1) / kPagesize);
for (int i = indices_offset; i < indices_lens; i++) {
kv_indices_h[i] = kv_indices[i - indices_offset];
Expand All @@ -122,14 +118,6 @@ void prepare_inference_params_kernel_h(BatchConfig const *batch_config,
kv_last_page_len_h[indptr_idx] = batch_config->requestsInfo[req_idx].kv_last_page_len;
indptr_idx++;
}
// }else{
// q_indptr_h[indptr_idx + 1] = q_indptr_h[indptr_idx];
// q_indptr_h[indptr_idx + 1] = q_indptr_h[indptr_idx];
// kv_indptr_h[indptr_idx + 1] = kv_indptr_h[indptr_idx];
// qk_indptr_h[indptr_idx + 1] = 0;
// kv_last_page_len_h[indptr_idx] = 0;
// indptr_idx++;
// }
}

// do the copy
Expand Down Expand Up @@ -405,12 +393,6 @@ void RequestManager::load_batch_config_task(
handle.incr_attention_metadata->kv_indices,
handle.incr_attention_metadata->kv_last_page_len,
handle.incr_attention_metadata->qk_indptr);
// check on error
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) {
printf("CUDA error in prepare_inference_params_kernel: %s\n", cudaGetErrorString(error));
}
}
// prepare attention forward handler
Expand Down Expand Up @@ -637,8 +619,6 @@ void RequestManager::load_batch_config_task(
}
} else if (batch_config->get_mode() == TREE_VERIFY_MODE) {
PageManager *pm = PageManager::get_page_manager();
// hardcode request
// printf("request has allocated %d pages\n", pm -> get_block_table_indices(1000000).size());
static int32_t q_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1], kv_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1];
static int32_t kv_indices_h[BatchConfig::MAX_NUM_REQUESTS * BatchConfig::MAX_NUM_TOKENS];
static int32_t qk_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1];
Expand Down Expand Up @@ -666,20 +646,7 @@ void RequestManager::load_batch_config_task(
round_up_pages(BatchConfig::max_sequence_length() +
BatchConfig::max_spec_tree_token_num());
int parallelism = batch_size;
// prepare_inference_params_kernel<<<GET_BLOCKS(parallelism),
// min(CUDA_NUM_THREADS, parallelism),
// 0,
// stream>>>(
// batch_size,
// request_infos,
// request_available,
// max_num_pages,
// handle.tree_verify_attention_metadata->q_indptr,
// handle.tree_verify_attention_metadata->kv_indptr,
// handle.tree_verify_attention_metadata->kv_indices,
// handle.tree_verify_attention_metadata->kv_last_page_len,
// handle.tree_verify_attention_metadata->qk_indptr);
// int parallelism = batch_size;
prepare_inference_params_kernel_h(batch_config,
pm,
handle,
Expand Down Expand Up @@ -747,26 +714,6 @@ void RequestManager::load_batch_config_task(
->prompt_handler_collections[batch_size]);
}
// static int32_t q_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1],
// kv_indptr_h[BatchConfig::MAX_NUM_REQUESTS + 1];
// q_indptr_h[0] = 0;
// kv_indptr_h[0] = 0;
// for (int req_idx = 0, indptr_idx = 0;
// req_idx < batch_config->max_requests_per_batch();
// req_idx++) {
// if (batch_config->request_available[req_idx]) {
// int q_len = batch_config->requestsInfo[req_idx].num_tokens_in_batch;
// int kv_len =
// batch_config->requestsInfo[req_idx].num_tokens_in_batch +
// batch_config->requestsInfo[req_idx]
// .first_token_index_in_request;
// q_indptr_h[indptr_idx + 1] = q_indptr_h[indptr_idx] + q_len;
// kv_indptr_h[indptr_idx + 1] =
// kv_indptr_h[indptr_idx] + round_up_pages(kv_len);
// indptr_idx++;
// }
// }
handler->SetCUDAStream(stream);
handler->BeginForward<half, int32_t>(
static_cast<void *>(
Expand All @@ -782,12 +729,6 @@ void RequestManager::load_batch_config_task(
handle.tree_verify_attention_metadata->num_kv_heads(),
handle.tree_verify_attention_metadata->head_dim(),
kPagesize);
cudaError_t syncErr = cudaDeviceSynchronize();
if (syncErr != cudaSuccess) {
printf("Kernel execution error: %s\n", cudaGetErrorString(syncErr));
assert(false);
}
}
}
}
Expand Down

0 comments on commit 19e41d6

Please sign in to comment.