Skip to content

Commit

Permalink
fix spec token num
Browse files Browse the repository at this point in the history
  • Loading branch information
chenzhuofu committed Oct 11, 2024
1 parent 8394f15 commit 2ec8b5b
Showing 1 changed file with 35 additions and 7 deletions.
42 changes: 35 additions & 7 deletions src/runtime/request_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1529,7 +1529,7 @@ BatchConfig RequestManager::prepare_verify_batch_config() {
committed_tokens.at(committed_token_index);

// assert(request.page_last_committed < request.blocks.size());
printf("in verify: page_last_committed: %d, request->blocks.size(): %d\n", request.page_last_committed, request.blocks.size());
// printf("in verify: page_last_committed: %d, request->blocks.size(): %d\n", request.page_last_committed, request.blocks.size());
int idx_to_physical = append_token_to_block(request, committed_token.token_id, true);
int idx_from_logical = committed_token.from_index - request.first_token_offset_in_batch;
int idx_from_physical = block_table_before_commit[idx_from_logical / kPagesize] * kPagesize + committed_token.from_index % kPagesize;
Expand Down Expand Up @@ -1561,13 +1561,15 @@ BatchConfig RequestManager::prepare_verify_batch_config() {
token_tree_index++;

// Append the token to the block
printf("in verify spec tree: page_last_committed: %d, request->blocks.size(): %d\n", request.page_last_committed, request.blocks.size());
// printf("in verify spec tree: page_last_committed: %d, request->blocks.size(): %d ", request.page_last_committed, request.blocks.size());
// printf("in verify spec tree: last page len: %d\n", get_len_last_block(request));
// assert(request.page_last_committed < request.blocks.size());
append_token_to_block(request, tree_node->id, false);
}
}
layer_index++;
}
printf("there are %d tokens in the token tree\n", token_tree_index);
new_bc.requestsInfo[request_index].num_tokens_in_batch = token_tree_index;

request.first_token_offset_in_batch = new_bc.num_tokens - token_tree_index;
Expand Down Expand Up @@ -1919,7 +1921,11 @@ int RequestManager::idx_logical_to_physical(Request &request, int idx_logical) {
// get physical indices
PageManager *page_manager = PageManager::get_page_manager();
std::vector<int> block_table_indices = page_manager->get_block_table_indices(request.guid);
assert(request.blocks.size() == block_table_indices.size());
if (request.blocks.size() != block_table_indices.size()) {
printf("page manager get block table indices: %d, request.blocks.size(): %d\n", page_manager->get_block_table_indices(request.guid).size(), request.blocks.size());
printf("request.blocks.size(): %d, block_table_indices.size(): %d\n", request.blocks.size(), block_table_indices.size());
assert(request.blocks.size() == block_table_indices.size());
}
return block_table_indices[idx_logical / kPagesize] * kPagesize + idx_logical % kPagesize;
}

Expand Down Expand Up @@ -1951,15 +1957,17 @@ void RequestManager::_append_block_to_request(
//it will return the physical position of this token
int RequestManager::append_token_to_block(Request &request, TokenId token, bool is_commit) {
// assert(request.page_last_committed < request.blocks.size());
PageManager *page_manager = PageManager::get_page_manager();
if (request.blocks.empty() ||
request.blocks.back().is_full()) {
PageManager *page_manager = PageManager::get_page_manager();
// Append a new logical block
_append_block_to_request(request, is_commit);
assert(request.blocks.size() == page_manager->get_block_table_indices(request.guid).size());
// also allocate one physical page
}
// insert token to both logical block and physical block
request.blocks.back().append_tokens({token}, is_commit);
assert(request.blocks.size() == page_manager->get_block_table_indices(request.guid).size());
int idx_logical = get_idx_last_logical_token(request);
int idx_physical = idx_logical_to_physical(request, idx_logical);
return idx_physical;
Expand All @@ -1972,14 +1980,34 @@ void RequestManager::reset_block_table(Request &request){
assert(request.blocks.size() == page_manager->get_block_table_indices(request.guid).size());
std::vector<int> block_table_indices = page_manager->get_block_table_indices(request.guid);
// reset the block table according to the request's page_last_commit
page_manager->free_multiple_blocks(request.guid, block_table_indices.size() - request.page_last_committed);
assert(block_table_indices.size() > request.page_last_committed);
page_manager->free_multiple_blocks(request.guid, block_table_indices.size() - request.page_last_committed - 1);
// reset this request's logical block table
request.blocks.erase(request.blocks.begin() + request.page_last_committed + 1, request.blocks.end());
if (request.page_last_committed < static_cast<int>(request.blocks.size())) {
request.blocks.erase(request.blocks.begin() + request.page_last_committed + 1, request.blocks.end());
}
request.blocks.back().reset_num_spec_tokens();
printf("after reset, block now has %d tokens\n", request.blocks.back().get_num_tokens());
printf("number of pages allocated: %d\n", page_manager->get_block_table_indices(request.guid).size());
printf("number of blocks: %d\n", request.blocks.size());
printf("num spec tokens: %d\n", request.blocks.back().get_num_spec_tokens());
printf("num committed tokens: %d\n", request.blocks.back().get_num_commit_tokens());

assert(request.blocks.size() == block_table_indices.size());
assert(request.blocks.size() == page_manager->get_block_table_indices(request.guid).size());
return;
}

// Debug helper: prints a one-line summary of a request's paging state to
// stdout -- the number of logical blocks held by the request, the number of
// entries in the page manager's block table for the request's guid, and the
// token counters (total / spec / commit) reported by the last logical block.
//
// If the request has no logical blocks yet, the per-block counters are
// replaced by a "last page: <none>" marker instead of dereferencing
// blocks.back() on an empty container.
void RequestManager::print_num_tokens(Request &request) {
  PageManager *page_manager = PageManager::get_page_manager();
  std::vector<int> block_table_indices =
      page_manager->get_block_table_indices(request.guid);

  // Container sizes are size_t; %zu is the matching conversion (passing a
  // size_t to %d is a format/argument mismatch).
  printf("number of blocks: %zu", request.blocks.size());
  printf(" number of pages allocated: %zu", block_table_indices.size());

  // back() on an empty container is undefined behavior, so bail out early
  // while still terminating the output line.
  if (request.blocks.empty()) {
    printf(" last page: <none>\n");
    return;
  }

  // NOTE(review): the getters below are assumed to return int, as the
  // original %d conversions imply -- confirm against the block class.
  auto &last_block = request.blocks.back();
  printf(" last page length: %d", last_block.get_num_tokens());
  printf(" last page spec tokens: %d", last_block.get_num_spec_tokens());
  printf(" last page commit tokens: %d\n", last_block.get_num_commit_tokens());
}

/* --------- Bitmask Related Functions --------- */
void RequestManager::gumbel_conditioned_on_max(
double target_max, std::vector<std::pair<double, int>> &logits) {
Expand Down

0 comments on commit 2ec8b5b

Please sign in to comment.