Skip to content

Commit

Permalink
before cleaning
Browse files: browse the repository at this point in the history
  • Loading branch information
Bob-Chen222 committed Oct 29, 2024
1 parent a87aaf3 commit 412fad5
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 12 deletions.
9 changes: 1 addition & 8 deletions src/ops/inc_multihead_self_attention.cu
Original file line number Diff line number Diff line change
Expand Up @@ -503,24 +503,17 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta(
query_tmp_size = num_q_heads * qk_dim * max_tokens_per_batch;
// a K-ary tree max node is (k^n - 1) / 2
if (total_size == -1){
printf("we should be here this time\n");
printf("num_hidden layers: %d\n", _num_hidden_layers);
// fall back to the default value
key_cache_size = num_kv_heads * qk_dim * BatchConfig::max_requests_per_batch() *
max_num_pages * kPagesize;
value_cache_size = num_kv_heads * v_dim * BatchConfig::max_requests_per_batch() *
max_num_pages * kPagesize;
}else{
printf("we should be here this time\n");
key_cache_size = total_size / 2 / num_hidden_layers;
value_cache_size = total_size / 2 / num_hidden_layers;
assert(key_cache_size > 0 && value_cache_size > 0 && "Invalid kvcache size");
}
printf("key_cache_size: %lu\n", key_cache_size);
printf("value_cache_size: %lu\n", value_cache_size);
printf("size_of_dt: %lu\n", size_of_dt);
printf("num_kv_heads: %d\n", num_kv_heads);
printf("qk_dim: %d\n", qk_dim);
printf("v_dim: %d\n", v_dim);
PageManager::get_page_manager(size_of_dt, num_kv_heads, qk_dim + v_dim,
(key_cache_size + value_cache_size), total_size);

Expand Down
3 changes: 1 addition & 2 deletions src/ops/tree_inc_multihead_self_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,6 @@ void TreeIncMultiHeadSelfAttention::init_inference(
std::vector<ParallelTensor> const &batch_inputs,
std::vector<ParallelTensor> const &batch_outputs,
MachineView const *mv) {
printf("start inference num_hidden_layers %d\n", num_hidden_layers);
assert(check_output_input_weight_same_parallel_is());
parallel_is = batch_outputs[0]->parallel_is;
ArgumentMap argmap;
Expand Down Expand Up @@ -720,7 +719,7 @@ OpMeta *TreeIncMultiHeadSelfAttention::init_task(
handle.offload_reserve_space, handle.offload_reserve_space_size);
}
// assert(attn->num_hidden_layers != 0);
printf("num_hidden_layers = %d\n", attn->num_hidden_layers);
printf("p_layers = %d\n", attn->num_hidden_layers);
TreeIncMultiHeadSelfAttentionMeta *m =
new TreeIncMultiHeadSelfAttentionMeta(handle,
attn,
Expand Down
3 changes: 1 addition & 2 deletions src/runtime/page_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,7 @@ PageManager *PageManager::get_page_manager(size_t size_of_dt, int num_kv_heads,
if (page_manager_singleton != nullptr) {
return page_manager_singleton;
}
int num_total_blocks = kv_cache_size_per_layer * 1024 * 1024 / kPagesize / size_of_dt / num_kv_heads / qkv_dim;
printf("num_total_blocks assigned: %d\n", num_total_blocks);
int num_total_blocks = kv_cache_size_per_layer / kPagesize / size_of_dt / num_kv_heads / qkv_dim;
page_manager_singleton = new PageManager(kPagesize, num_total_blocks);
return page_manager_singleton;
}
Expand Down

0 comments on commit 412fad5

Please sign in to comment.