diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index a38a3b2671..042fb5dc81 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -63,6 +63,7 @@ struct Request { RUNNING = 102, // running inference COMPLETED = 103, // finished and verified FINISHING = 104, // finishing request, but not yet verified + PREFILLING = 105 // prefilling the tree }; BatchConfig::RequestGuid guid; int max_sequence_length; @@ -307,6 +308,7 @@ class RequestManager { }; std::unordered_map<RequestGuid, ProfileInfo> profiling_requests; double total_request_run_time; + BatchConfig *buffer_bc = nullptr; }; }; // namespace FlexFlow diff --git a/src/ops/fused.cu b/src/ops/fused.cu index 483028599e..5b85798c7a 100644 --- a/src/ops/fused.cu +++ b/src/ops/fused.cu @@ -816,324 +816,865 @@ __host__ void my_input_accessor[0].domain.get_volume()); } - assert(my_input_accessor[0].data_type == DT_INT32 || - my_input_accessor[0].data_type == DT_INT64); - Kernels::Embedding::forward_kernel_wrapper(m, - my_input_accessor[0], - my_output_accessor[0], - my_weight_accessor[0], - in_dim, - out_dim, - effective_batch_size); - break; - } - case OP_GELU: - case OP_RELU: - case OP_SIGMOID: - case OP_TANH: - case OP_ELU: - case OP_SCALAR_TRUE_DIV: { - assert(fused->op_num_inputs[op] == 1); - assert(fused->op_num_weights[op] == 0); - assert(fused->op_num_outputs[op] == 1); - assert(my_input_accessor[0].domain == my_output_accessor[0].domain); - ElementUnaryMeta *m = (ElementUnaryMeta *)metas->meta[op]; - if (m->data_type == DT_HALF) { - ElementUnary::forward_kernel_wrapper( - m, - my_input_accessor[0].get_half_ptr(), - my_output_accessor[0].get_half_ptr(), - my_input_accessor[0].domain.get_volume()); - } else if (m->data_type == DT_FLOAT) { - ElementUnary::forward_kernel_wrapper( - m, - my_input_accessor[0].get_float_ptr(), - my_output_accessor[0].get_float_ptr(), - my_input_accessor[0].domain.get_volume()); - } else { - assert(false && "Unsupported data type in ElementUnary forward"); - } - break; - } - case OP_RMS_NORM: { - assert(fused->op_num_inputs[op] == 1); - assert(fused->op_num_weights[op] == 1); - assert(fused->op_num_outputs[op] == 1); - RMSNormMeta const *m = (RMSNormMeta *)metas->meta[op]; - Kernels::RMSNorm::forward_kernel_wrapper(m, - my_input_accessor[0], - my_weight_accessor[0], - my_output_accessor[0]); - break; - } - case OP_RESIDUAL_RMS_NORM: { - assert(fused->op_num_inputs[op] == 2); - assert(fused->op_num_weights[op] == 1); - assert(fused->op_num_outputs[op] == 2); - ResidualRMSNormMeta const *m = (ResidualRMSNormMeta *)metas->meta[op]; - Kernels::ResidualRMSNorm::forward_kernel_wrapper(m, - my_input_accessor[0], - my_input_accessor[1], - my_weight_accessor[0], - my_output_accessor[0], - my_output_accessor[1]); - break; - } - case OP_INC_MULTIHEAD_SELF_ATTENTION: { - assert(fused->op_num_inputs[op] == 1); - assert(fused->op_num_outputs[op] == 1); - IncMultiHeadSelfAttentionMeta const *m = - (IncMultiHeadSelfAttentionMeta *)metas->meta[op]; - assert(fused->op_num_weights[op] == - (1 + (int)(*m->qkv_bias || *m->final_bias))); - GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - assert(fused->op_num_weights[op] == 2); - biases = my_weight_accessor[1]; - } - IncMultiHeadSelfAttention::inference_kernel_wrapper( - m, - bc, - task->index_point.point_data[0], - my_input_accessor[0], - my_weight_accessor[0], - my_output_accessor[0], - biases); - break; - } - case OP_TREE_INC_MULTIHEAD_SELF_ATTENTION: { - assert(fused->op_num_inputs[op] == 1); - 
assert(fused->op_num_outputs[op] == 1); - TreeIncMultiHeadSelfAttentionMeta *m = - (TreeIncMultiHeadSelfAttentionMeta *)metas->meta[op]; - // TreeVerifyBatchConfig const *tree_bc = - // (TreeVerifyBatchConfig *)task->args; - TreeVerifyBatchConfig const &tree_bc = - Future(task->futures[0]).get_result(); - assert(fused->op_num_weights[op] == - (1 + (int)(*m->qkv_bias || *m->final_bias))); - GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - assert(fused->op_num_weights[op] == 2); - biases = my_weight_accessor[1]; - } - TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( - m, - &tree_bc, - task->index_point.point_data[0], - my_input_accessor[0], - my_weight_accessor[0], - my_output_accessor[0], - biases); - break; - } - case OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION: { - assert(fused->op_num_inputs[op] == 1); - assert(fused->op_num_outputs[op] == 1); - SpecIncMultiHeadSelfAttentionMeta const *m = - (SpecIncMultiHeadSelfAttentionMeta *)metas->meta[op]; - // BeamSearchBatchConfig const *beam_bc = - // (BeamSearchBatchConfig *)task->args; - BeamSearchBatchConfig const &beam_bc = - Future(task->futures[0]).get_result(); - assert(fused->op_num_weights[op] == - (1 + (int)(*m->qkv_bias || *m->final_bias))); - GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - assert(fused->op_num_weights[op] == 2); - biases = my_weight_accessor[1]; - } - SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( - m, - &beam_bc, - task->index_point.point_data[0], - my_input_accessor[0], - my_weight_accessor[0], - my_output_accessor[0], - biases); - break; - } - case OP_LAYERNORM: { - assert(fused->op_num_inputs[op] == 1); - assert(fused->op_num_outputs[op] == 1); - LayerNormMeta const *m = (LayerNormMeta *)metas->meta[op]; - if (m->elementwise_affine) { - assert(fused->op_num_weights[op] == 1 + (int)(m->use_bias)); - } - GenericTensorAccessorR gamma, beta; - if (m->elementwise_affine) { - gamma = my_weight_accessor[0]; - if (m->use_bias) { - beta = my_weight_accessor[1]; - } - } - LayerNorm::forward_kernel_wrapper( - m, my_input_accessor[0], my_output_accessor[0], gamma, beta); - break; - } - case OP_RESIDUAL_LAYERNORM: { - assert(fused->op_num_outputs[op] == 2); - ResidualLayerNormMeta const *m = - (ResidualLayerNormMeta *)metas->meta[op]; - if (m->use_two_residuals) { - assert(fused->op_num_inputs[op] == 3); - } else { - assert(fused->op_num_inputs[op] == 2); - } - if (!m->elementwise_affine) { - assert(fused->op_num_weights[op] == 0); - } else { - if (!m->use_bias) { - assert(fused->op_num_weights[op] == 1); // weight - } else { - assert(fused->op_num_weights[op] == 2); // weight + bias + assert(my_input_accessor[0].data_type == DT_INT32 || + my_input_accessor[0].data_type == DT_INT64); + Kernels::Embedding::forward_kernel_wrapper(m, + my_input_accessor[0], + my_output_accessor[0], + my_weight_accessor[0], + in_dim, + out_dim, + effective_batch_size); + break; + } + case OP_GELU: + case OP_RELU: + case OP_SIGMOID: + case OP_TANH: + case OP_ELU: + case OP_SCALAR_TRUE_DIV: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_weights[op] == 0); + assert(fused->op_num_outputs[op] == 1); + assert(my_input_accessor[0].domain == my_output_accessor[0].domain); + ElementUnaryMeta *m = (ElementUnaryMeta *)metas->meta[op]; + if (m->data_type == DT_HALF) { + ElementUnary::forward_kernel_wrapper( + m, + my_input_accessor[0].get_half_ptr(), + my_output_accessor[0].get_half_ptr(), + my_input_accessor[0].domain.get_volume()); + } else if (m->data_type == DT_FLOAT) { + 
ElementUnary::forward_kernel_wrapper( + m, + my_input_accessor[0].get_float_ptr(), + my_output_accessor[0].get_float_ptr(), + my_input_accessor[0].domain.get_volume()); + } else { + assert(false && "Unsupported data type in ElementUnary forward"); + } + break; + } + case OP_RMS_NORM: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_weights[op] == 1); + assert(fused->op_num_outputs[op] == 1); + RMSNormMeta const *m = (RMSNormMeta *)metas->meta[op]; + Kernels::RMSNorm::forward_kernel_wrapper(m, + my_input_accessor[0], + my_weight_accessor[0], + my_output_accessor[0]); + break; + } + case OP_RESIDUAL_RMS_NORM: { + assert(fused->op_num_inputs[op] == 2); + assert(fused->op_num_weights[op] == 1); + assert(fused->op_num_outputs[op] == 2); + ResidualRMSNormMeta const *m = (ResidualRMSNormMeta *)metas->meta[op]; + Kernels::ResidualRMSNorm::forward_kernel_wrapper(m, + my_input_accessor[0], + my_input_accessor[1], + my_weight_accessor[0], + my_output_accessor[0], + my_output_accessor[1]); + break; + } + case OP_INC_MULTIHEAD_SELF_ATTENTION: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + IncMultiHeadSelfAttentionMeta const *m = + (IncMultiHeadSelfAttentionMeta *)metas->meta[op]; + assert(fused->op_num_weights[op] == + (1 + (int)(*m->qkv_bias || *m->final_bias))); + GenericTensorAccessorR biases; + if (*m->qkv_bias || *m->final_bias) { + assert(fused->op_num_weights[op] == 2); + biases = my_weight_accessor[1]; + } + IncMultiHeadSelfAttention::inference_kernel_wrapper( + m, + bc, + task->index_point.point_data[0], + my_input_accessor[0], + my_weight_accessor[0], + my_output_accessor[0], + biases); + break; + } + case OP_TREE_INC_MULTIHEAD_SELF_ATTENTION: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + TreeIncMultiHeadSelfAttentionMeta *m = + (TreeIncMultiHeadSelfAttentionMeta *)metas->meta[op]; + // TreeVerifyBatchConfig const *tree_bc = + // (TreeVerifyBatchConfig *)task->args; + TreeVerifyBatchConfig const &tree_bc = + Future(task->futures[0]).get_result(); + assert(fused->op_num_weights[op] == + (1 + (int)(*m->qkv_bias || *m->final_bias))); + GenericTensorAccessorR biases; + if (*m->qkv_bias || *m->final_bias) { + assert(fused->op_num_weights[op] == 2); + biases = my_weight_accessor[1]; + } + TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( + m, + &tree_bc, + task->index_point.point_data[0], + my_input_accessor[0], + my_weight_accessor[0], + my_output_accessor[0], + biases); + break; + } + case OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + SpecIncMultiHeadSelfAttentionMeta const *m = + (SpecIncMultiHeadSelfAttentionMeta *)metas->meta[op]; + // BeamSearchBatchConfig const *beam_bc = + // (BeamSearchBatchConfig *)task->args; + BeamSearchBatchConfig const &beam_bc = + Future(task->futures[0]).get_result(); + assert(fused->op_num_weights[op] == + (1 + (int)(*m->qkv_bias || *m->final_bias))); + GenericTensorAccessorR biases; + if (*m->qkv_bias || *m->final_bias) { + assert(fused->op_num_weights[op] == 2); + biases = my_weight_accessor[1]; + } + SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( + m, + &beam_bc, + task->index_point.point_data[0], + my_input_accessor[0], + my_weight_accessor[0], + my_output_accessor[0], + biases); + break; + } + case OP_LAYERNORM: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + LayerNormMeta const *m = (LayerNormMeta *)metas->meta[op]; + if 
(m->elementwise_affine) { + assert(fused->op_num_weights[op] == 1 + (int)(m->use_bias)); + } + GenericTensorAccessorR gamma, beta; + if (m->elementwise_affine) { + gamma = my_weight_accessor[0]; + if (m->use_bias) { + beta = my_weight_accessor[1]; + } + } + LayerNorm::forward_kernel_wrapper( + m, my_input_accessor[0], my_output_accessor[0], gamma, beta); + break; + } + case OP_RESIDUAL_LAYERNORM: { + assert(fused->op_num_outputs[op] == 2); + ResidualLayerNormMeta const *m = + (ResidualLayerNormMeta *)metas->meta[op]; + if (m->use_two_residuals) { + assert(fused->op_num_inputs[op] == 3); + } else { + assert(fused->op_num_inputs[op] == 2); + } + if (!m->elementwise_affine) { + assert(fused->op_num_weights[op] == 0); + } else { + if (!m->use_bias) { + assert(fused->op_num_weights[op] == 1); // weight + } else { + assert(fused->op_num_weights[op] == 2); // weight + bias + } + } + GenericTensorAccessorR residual2; + if (m->use_two_residuals) { + residual2 = my_input_accessor[2]; + } + GenericTensorAccessorR gamma, beta; + if (m->elementwise_affine) { + gamma = my_weight_accessor[0]; + if (m->use_bias) { + beta = my_weight_accessor[1]; + } + } + ResidualLayerNorm::inference_kernel_wrapper(m, + my_input_accessor[0], + my_input_accessor[1], + residual2, + my_output_accessor[0], + my_output_accessor[1], + gamma, + beta); + break; + } + case OP_ADD_BIAS_RESIDUAL_LAYERNORM: { + assert(fused->op_num_inputs[op] == 2); + assert(fused->op_num_outputs[op] == 2); + AddBiasResidualLayerNormMeta const *m = + (AddBiasResidualLayerNormMeta *)metas->meta[op]; + if (!m->elementwise_affine) { + assert(fused->op_num_weights[op] == 1); // attn bias + } else { + if (!m->use_bias) { + assert(fused->op_num_weights[op] == 2); // attn bias + weight + } else { + assert(fused->op_num_weights[op] == 3); // attn bias + weight + bias + } + } + GenericTensorAccessorR gamma, beta; + if (m->elementwise_affine) { + gamma = my_weight_accessor[1]; + if (m->use_bias) { + beta = my_weight_accessor[2]; + } + } + Domain attn_bias_domain = my_weight_accessor[0].domain; + Domain residual_domain = my_input_accessor[1].domain; + int attn_bias_dim = + attn_bias_domain.hi()[0] - attn_bias_domain.lo()[0] + 1; + int residual_volume = residual_domain.get_volume(); + AddBiasResidualLayerNorm::inference_kernel_wrapper( + m, + attn_bias_dim, + residual_volume, + my_input_accessor[0], + my_output_accessor[0], + my_output_accessor[1], + my_input_accessor[1], + my_weight_accessor[0], + gamma, + beta); + break; + } + case OP_SIGMOID_SILU_MULTI: { + assert(fused->op_num_inputs[op] == 2); + assert(fused->op_num_outputs[op] == 1); + SigmoidSiluMultiMeta const *m = (SigmoidSiluMultiMeta *)metas->meta[op]; + SigmoidSiluMulti::inference_kernel_wrapper(m, + my_input_accessor[0], + my_input_accessor[1], + my_output_accessor[0]); + break; + } + case OP_SOFTMAX: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_weights[op] == 0); + assert(fused->op_num_outputs[op] == 1); + assert(my_input_accessor[0].domain.get_volume() == + my_output_accessor[0].domain.get_volume()); + SoftmaxMeta *m = (SoftmaxMeta *)metas->meta[op]; + if (m->input_type == DT_HALF) { + Kernels::Softmax::forward_kernel_wrapper( + m, + my_input_accessor[0].get_half_ptr(), + my_output_accessor[0].get_half_ptr()); + } else if (m->input_type == DT_FLOAT) { + Kernels::Softmax::forward_kernel_wrapper( + m, + my_input_accessor[0].get_float_ptr(), + my_output_accessor[0].get_float_ptr()); + } + break; + } + case OP_ALLREDUCE: { + assert(fused->op_num_inputs[op] == 1); + 
assert(fused->op_num_outputs[op] == 1); + AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op]; + Kernels::AllReduce::inference_kernel_wrapper( + m, bc, my_input_accessor[0], my_output_accessor[0]); + break; + } + default: { + fprintf(stderr, + "Fusion currently does not support type = %d\n", + fused->op_op_type[op]); + assert(false && "Fusion currently does not support type"); + } + } + if (metas->meta[op]->inference_debugging) { + std::vector input_accessors_to_save; + std::vector weight_accessors_to_save; + std::vector output_accessors_to_save; + for (int i = 0; i < fused->op_num_inputs[op]; i++) { + int my_off = fused->op_input_idx[i + ioff]; + if (fused->op_input_source[i + ioff] == SOURCE_INPUT) { + input_accessors_to_save.push_back(input_accessor[my_off]); + } else if (fused->op_input_source[i + ioff] == SOURCE_OUTPUT) { + input_accessors_to_save.push_back(output_accessor[my_off]); + } else { + assert(false); + } + } + for (int i = 0; i < fused->op_num_weights[op]; i++) { + assert(fused->op_weight_source[i + woff] == SOURCE_WEIGHT); + weight_accessors_to_save.push_back( + weight_accessor[fused->op_weight_idx[i + woff]]); + } + for (int i = 0; i < fused->op_num_outputs[op]; i++) { + output_accessors_to_save.push_back(output_accessor[i + ooff]); + } + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + FusedOp::save_inference_tensors_to_file(metas->meta[op], + shard_id, + bc, + input_accessors_to_save, + weight_accessors_to_save, + output_accessors_to_save); + } + ioff += fused->op_num_inputs[op]; + woff += fused->op_num_weights[op]; + ooff += fused->op_num_outputs[op]; + + } + cudaStreamEndCapture(stream, &graph); } - } - GenericTensorAccessorR residual2; - if (m->use_two_residuals) { - residual2 = my_input_accessor[2]; - } - GenericTensorAccessorR gamma, beta; - if (m->elementwise_affine) { - gamma = my_weight_accessor[0]; - if (m->use_bias) { - beta = my_weight_accessor[1]; - } - } - ResidualLayerNorm::inference_kernel_wrapper(m, - my_input_accessor[0], - my_input_accessor[1], - residual2, - my_output_accessor[0], - my_output_accessor[1], - gamma, - beta); - break; - } - case OP_ADD_BIAS_RESIDUAL_LAYERNORM: { - assert(fused->op_num_inputs[op] == 2); - assert(fused->op_num_outputs[op] == 2); - AddBiasResidualLayerNormMeta const *m = - (AddBiasResidualLayerNormMeta *)metas->meta[op]; - if (!m->elementwise_affine) { - assert(fused->op_num_weights[op] == 1); // attn bias - } else { - if (!m->use_bias) { - assert(fused->op_num_weights[op] == 2); // attn bias + weight - } else { - assert(fused->op_num_weights[op] == 3); // attn bias + weight + bias - } - } - GenericTensorAccessorR gamma, beta; - if (m->elementwise_affine) { - gamma = my_weight_accessor[1]; - if (m->use_bias) { - beta = my_weight_accessor[2]; - } - } - Domain attn_bias_domain = my_weight_accessor[0].domain; - Domain residual_domain = my_input_accessor[1].domain; - int attn_bias_dim = - attn_bias_domain.hi()[0] - attn_bias_domain.lo()[0] + 1; - int residual_volume = residual_domain.get_volume(); - AddBiasResidualLayerNorm::inference_kernel_wrapper( - m, - attn_bias_dim, - residual_volume, - my_input_accessor[0], - my_output_accessor[0], - my_output_accessor[1], - my_input_accessor[1], - my_weight_accessor[0], - gamma, - beta); - break; - } - case OP_SIGMOID_SILU_MULTI: { - assert(fused->op_num_inputs[op] == 2); - assert(fused->op_num_outputs[op] == 1); - SigmoidSiluMultiMeta const *m = (SigmoidSiluMultiMeta *)metas->meta[op]; - SigmoidSiluMulti::inference_kernel_wrapper(m, - 
my_input_accessor[0], - my_input_accessor[1], - my_output_accessor[0]); - break; - } - case OP_SOFTMAX: { - assert(fused->op_num_inputs[op] == 1); - assert(fused->op_num_weights[op] == 0); - assert(fused->op_num_outputs[op] == 1); - assert(my_input_accessor[0].domain.get_volume() == - my_output_accessor[0].domain.get_volume()); - SoftmaxMeta *m = (SoftmaxMeta *)metas->meta[op]; - if (m->input_type == DT_HALF) { - Kernels::Softmax::forward_kernel_wrapper( - m, - my_input_accessor[0].get_half_ptr(), - my_output_accessor[0].get_half_ptr()); - } else if (m->input_type == DT_FLOAT) { - Kernels::Softmax::forward_kernel_wrapper( - m, - my_input_accessor[0].get_float_ptr(), - my_output_accessor[0].get_float_ptr()); - } - break; - } - case OP_ALLREDUCE: { - assert(fused->op_num_inputs[op] == 1); - assert(fused->op_num_outputs[op] == 1); - AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op]; - Kernels::AllReduce::inference_kernel_wrapper( - m, bc, my_input_accessor[0], my_output_accessor[0]); - break; - } - default: { - fprintf(stderr, - "Fusion currently does not support type = %d\n", - fused->op_op_type[op]); - assert(false && "Fusion currently does not support type"); - } - } - if (metas->meta[op]->inference_debugging) { - std::vector input_accessors_to_save; - std::vector weight_accessors_to_save; - std::vector output_accessors_to_save; - for (int i = 0; i < fused->op_num_inputs[op]; i++) { - int my_off = fused->op_input_idx[i + ioff]; - if (fused->op_input_source[i + ioff] == SOURCE_INPUT) { - input_accessors_to_save.push_back(input_accessor[my_off]); - } else if (fused->op_input_source[i + ioff] == SOURCE_OUTPUT) { - input_accessors_to_save.push_back(output_accessor[my_off]); - } else { - assert(false); - } - } - for (int i = 0; i < fused->op_num_weights[op]; i++) { - assert(fused->op_weight_source[i + woff] == SOURCE_WEIGHT); - weight_accessors_to_save.push_back( - weight_accessor[fused->op_weight_idx[i + woff]]); - } - for (int i = 0; i < fused->op_num_outputs[op]; i++) { - output_accessors_to_save.push_back(output_accessor[i + ooff]); + // if(shard_id == 0) { + // printf("*************start cudaGraphInstantiate**********\n"); + // graph_params.Print(); + // bc->print(); + // printf("*************end cudaGraphInstantiate**********\n"); + // } + + cudaGraphInstantiate(&instance, graph, NULL, NULL, 0); + metas->graph_collections[graph_params] = instance; } - assert(task->index_point.get_dim() == 1); - int shard_id = task->index_point.point_data[0]; - FusedOp::save_inference_tensors_to_file(metas->meta[op], - shard_id, - bc, - input_accessors_to_save, - weight_accessors_to_save, - output_accessors_to_save); - } - ioff += fused->op_num_inputs[op]; - woff += fused->op_num_weights[op]; - ooff += fused->op_num_outputs[op]; + assert(metas->graph_collections.find(graph_params) != + metas->graph_collections.end()); + cudaGraphLaunch(instance, stream); + } else { + //mixed batch + int ioff = 0, woff = 0, ooff = 0; + for (int op = 0; op < fused->numOperators; op++) { + clock_t last_timer = clock(); + + // Domain my_id[MAX_NUM_INPUTS]; + // Domain my_wd[MAX_NUM_WEIGHTS]; + // Domain my_od[MAX_NUM_OUTPUTS]; + GenericTensorAccessorR my_input_accessor[MAX_NUM_INPUTS]; + GenericTensorAccessorR my_weight_accessor[MAX_NUM_WEIGHTS]; + GenericTensorAccessorW my_output_accessor[MAX_NUM_OUTPUTS]; + for (int i = 0; i < fused->op_num_inputs[op]; i++) { + int my_off = fused->op_input_idx[i + ioff]; + if (fused->op_input_source[i + ioff] == SOURCE_INPUT) { + // my_id[i] = input_domain[my_off]; + assert(my_off < 
fused->numInputs); + my_input_accessor[i] = input_accessor[my_off]; + } else if (fused->op_input_source[i + ioff] == SOURCE_OUTPUT) { + // my_id[i] = output_domain[my_off]; + assert(my_off < fused->numOutputs); + my_input_accessor[i] = output_accessor[my_off]; + } else { + assert(false); + } + } + for (int i = 0; i < fused->op_num_weights[op]; i++) { + assert(fused->op_weight_source[i + woff] == SOURCE_WEIGHT); + // my_wd[i] = weight_domain[fused->op_weight_idx[i + woff]]; + // my_wp[i] = weight_ptr[fused->op_weight_idx[i + woff]]; + assert(fused->op_weight_idx[i + woff] < fused->numWeights); + my_weight_accessor[i] = weight_accessor[fused->op_weight_idx[i + woff]]; + } + for (int i = 0; i < fused->op_num_outputs[op]; i++) { + int my_off = fused->op_output_idx[i + ooff]; + assert(fused->op_output_source[i + ooff] == SOURCE_OUTPUT); + assert(my_off < fused->numOutputs); + // my_od[i] = output_domain[fused->op_output_idx[i + ooff]]; + // my_op[i] = output_ptr[fused->op_output_idx[i + ooff]]; + my_output_accessor[i] = output_accessor[my_off]; + } + switch (fused->op_op_type[op]) { + case OP_CONCAT: { + assert(fused->op_num_weights[op] == 0); + assert(fused->op_num_outputs[op] == 1); + ConcatMeta *m = (ConcatMeta *)metas->meta[op]; + int num_inputs = fused->op_num_inputs[op]; + Kernels::Concat::forward_kernel_wrapper(m, + my_output_accessor[0], + my_input_accessor, + num_inputs, + m->legion_axis); + break; + } + case OP_BATCHNORM: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + assert(my_input_accessor[0].domain.get_dim() == 5); + assert(my_output_accessor[0].domain.get_dim() == 5); + assert(my_weight_accessor[0].domain.get_dim() == 2); + assert(my_weight_accessor[1].domain.get_dim() == 2); + BatchNormMeta *m = (BatchNormMeta *)metas->meta[op]; + BatchNorm::forward_kernel(m, + my_input_accessor[0].get_float_ptr(), + my_output_accessor[0].get_float_ptr(), + my_weight_accessor[0].get_float_ptr(), + my_weight_accessor[1].get_float_ptr()); + break; + } + case OP_LINEAR: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + Domain kernel_domain = my_weight_accessor[0].domain; + int in_dim = kernel_domain.hi()[0] - kernel_domain.lo()[0] + 1; + int out_dim = kernel_domain.hi()[1] - kernel_domain.lo()[1] + 1; + int batch_size = my_input_accessor[0].domain.get_volume() / in_dim; + assert(my_output_accessor[0].domain.get_volume() == + out_dim * batch_size); + assert(my_input_accessor[0].domain.get_volume() == in_dim * batch_size); + void const *bias_ptr = nullptr; + LinearMeta *m = (LinearMeta *)metas->meta[op]; + if (fused->op_num_weights[op] == 2) { + assert(my_weight_accessor[1].domain.get_volume() == out_dim); + if (!m->add_bias_only_once || task->index_point.point_data[0] == 0) { + bias_ptr = my_weight_accessor[1].ptr; + } + } else { + assert(fused->op_num_weights[op] == 1); + } + assert(m->input_type[0] == my_input_accessor[0].data_type); + assert(m->input_type[0] == my_output_accessor[0].data_type); + batch_size = bc->num_active_tokens(); + Kernels::Linear::forward_kernel_wrapper(m, + my_input_accessor[0].ptr, + my_output_accessor[0].ptr, + my_weight_accessor[0].ptr, + bias_ptr, + in_dim, + out_dim, + batch_size); + break; + } + case OP_BATCHMATMUL: { + assert(fused->op_num_inputs[op] == 2); + assert(fused->op_num_weights[op] == 0); + assert(fused->op_num_outputs[op] == 1); + Domain out_domain = my_output_accessor[0].domain; + Domain a_domain = my_input_accessor[0].domain; + Domain b_domain = my_input_accessor[1].domain; + 
int m = b_domain.hi()[0] - b_domain.lo()[0] + 1; + assert(m == out_domain.hi()[0] - out_domain.lo()[0] + 1); + int n = a_domain.hi()[1] - a_domain.lo()[1] + 1; + assert(n == out_domain.hi()[1] - out_domain.lo()[1] + 1); + int k = a_domain.hi()[0] - a_domain.lo()[0] + 1; + assert(k == b_domain.hi()[1] - b_domain.lo()[1] + 1); + assert(a_domain.get_dim() == b_domain.get_dim()); + assert(a_domain.get_dim() == out_domain.get_dim()); + int batch = 1; + for (int i = 2; i < a_domain.get_dim(); i++) { + int dim_size = a_domain.hi()[i] - a_domain.lo()[i] + 1; + assert(dim_size == b_domain.hi()[i] - b_domain.lo()[i] + 1); + assert(dim_size == out_domain.hi()[i] - out_domain.lo()[i] + 1); + batch *= dim_size; + } + BatchMatmulMeta *meta = (BatchMatmulMeta *)metas->meta[op]; + Kernels::BatchMatmul::forward_kernel_wrapper( + meta, + my_output_accessor[0].get_float_ptr(), + my_input_accessor[0].get_float_ptr(), + my_input_accessor[1].get_float_ptr(), + (float const *)nullptr, + m, + n, + k, + batch, + meta->a_seq_length_dim, + meta->b_seq_length_dim, + fused->iter_config.seq_length); + break; + } + case OP_EW_ADD: + case OP_EW_SUB: + case OP_EW_MUL: + case OP_EW_DIV: + case OP_EW_MAX: + case OP_EW_MIN: { + assert(fused->op_num_inputs[op] == 2); + assert(fused->op_num_weights[op] == 0); + assert(fused->op_num_outputs[op] == 1); + assert(my_input_accessor[0].domain == my_input_accessor[1].domain); + assert(my_input_accessor[0].domain == my_output_accessor[0].domain); + ElementBinaryMeta *m = (ElementBinaryMeta *)metas->meta[op]; + Kernels::ElementBinary::forward_kernel_wrapper(m, + my_input_accessor[0], + my_input_accessor[1], + my_output_accessor[0]); + break; + } + case OP_EMBEDDING: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_weights[op] == 1); + assert(fused->op_num_outputs[op] == 1); + EmbeddingMeta *m = (EmbeddingMeta *)metas->meta[op]; + if (m->aggr == AGGR_MODE_NONE) { + // assert(kernel_domain.get_dim() == 2); + assert(my_input_accessor[0].domain.get_dim() + 1 == + my_output_accessor[0].domain.get_dim()); + for (size_t i = 0; i < my_input_accessor[0].domain.get_dim(); i++) { + assert(my_input_accessor[0].domain.hi()[i] == + my_output_accessor[0].domain.hi()[i + 1]); + assert(my_input_accessor[0].domain.lo()[i] == + my_output_accessor[0].domain.lo()[i + 1]); + } + assert(my_weight_accessor[0].domain.hi()[0] - + my_weight_accessor[0].domain.lo()[0] == + my_output_accessor[0].domain.hi()[0] - + my_output_accessor[0].domain.lo()[0]); + } else { + assert(my_input_accessor[0].domain.get_dim() == + my_output_accessor[0].domain.get_dim()); + for (size_t i = 1; i < my_input_accessor[0].domain.get_dim(); i++) { + assert(my_input_accessor[0].domain.hi()[i] == + my_output_accessor[0].domain.hi()[i]); + assert(my_input_accessor[0].domain.lo()[i] == + my_output_accessor[0].domain.lo()[i]); + } + assert(my_weight_accessor[0].domain.hi()[0] - + my_weight_accessor[0].domain.lo()[0] == + my_output_accessor[0].domain.hi()[0] - + my_output_accessor[0].domain.lo()[0]); + } + int in_dim, out_dim, effective_batch_size; + if (m->aggr == AGGR_MODE_NONE) { + in_dim = 1; + out_dim = my_output_accessor[0].domain.hi()[0] - + my_output_accessor[0].domain.lo()[0] + 1; + effective_batch_size = + my_output_accessor[0].domain.get_volume() / out_dim; + assert(effective_batch_size * in_dim == + my_input_accessor[0].domain.get_volume()); + } else { + assert(m->aggr == AGGR_MODE_AVG || m->aggr == AGGR_MODE_SUM); + in_dim = my_input_accessor[0].domain.hi()[0] - + my_input_accessor[0].domain.lo()[0] + 1; + out_dim 
= my_output_accessor[0].domain.hi()[0] - + my_output_accessor[0].domain.lo()[0] + 1; + effective_batch_size = + my_output_accessor[0].domain.get_volume() / out_dim; + assert(effective_batch_size * in_dim == + my_input_accessor[0].domain.get_volume()); + } + + assert(my_input_accessor[0].data_type == DT_INT32 || + my_input_accessor[0].data_type == DT_INT64); + Kernels::Embedding::forward_kernel_wrapper(m, + my_input_accessor[0], + my_output_accessor[0], + my_weight_accessor[0], + in_dim, + out_dim, + effective_batch_size); + break; + } + case OP_GELU: + case OP_RELU: + case OP_SIGMOID: + case OP_TANH: + case OP_ELU: + case OP_SCALAR_TRUE_DIV: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_weights[op] == 0); + assert(fused->op_num_outputs[op] == 1); + assert(my_input_accessor[0].domain == my_output_accessor[0].domain); + ElementUnaryMeta *m = (ElementUnaryMeta *)metas->meta[op]; + if (m->data_type == DT_HALF) { + ElementUnary::forward_kernel_wrapper( + m, + my_input_accessor[0].get_half_ptr(), + my_output_accessor[0].get_half_ptr(), + my_input_accessor[0].domain.get_volume()); + } else if (m->data_type == DT_FLOAT) { + ElementUnary::forward_kernel_wrapper( + m, + my_input_accessor[0].get_float_ptr(), + my_output_accessor[0].get_float_ptr(), + my_input_accessor[0].domain.get_volume()); + } else { + assert(false && "Unsupported data type in ElementUnary forward"); + } + break; + } + case OP_RMS_NORM: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_weights[op] == 1); + assert(fused->op_num_outputs[op] == 1); + RMSNormMeta const *m = (RMSNormMeta *)metas->meta[op]; + Kernels::RMSNorm::forward_kernel_wrapper(m, + my_input_accessor[0], + my_weight_accessor[0], + my_output_accessor[0]); + break; + } + case OP_RESIDUAL_RMS_NORM: { + assert(fused->op_num_inputs[op] == 2); + assert(fused->op_num_weights[op] == 1); + assert(fused->op_num_outputs[op] == 2); + ResidualRMSNormMeta const *m = (ResidualRMSNormMeta *)metas->meta[op]; + Kernels::ResidualRMSNorm::forward_kernel_wrapper(m, + my_input_accessor[0], + my_input_accessor[1], + my_weight_accessor[0], + my_output_accessor[0], + my_output_accessor[1]); + break; + } + case OP_INC_MULTIHEAD_SELF_ATTENTION: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + IncMultiHeadSelfAttentionMeta const *m = + (IncMultiHeadSelfAttentionMeta *)metas->meta[op]; + assert(fused->op_num_weights[op] == + (1 + (int)(*m->qkv_bias || *m->final_bias))); + GenericTensorAccessorR biases; + if (*m->qkv_bias || *m->final_bias) { + assert(fused->op_num_weights[op] == 2); + biases = my_weight_accessor[1]; + } + IncMultiHeadSelfAttention::inference_kernel_wrapper( + m, + bc, + task->index_point.point_data[0], + my_input_accessor[0], + my_weight_accessor[0], + my_output_accessor[0], + biases); + break; + } + case OP_TREE_INC_MULTIHEAD_SELF_ATTENTION: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + TreeIncMultiHeadSelfAttentionMeta *m = + (TreeIncMultiHeadSelfAttentionMeta *)metas->meta[op]; + // TreeVerifyBatchConfig const *tree_bc = + // (TreeVerifyBatchConfig *)task->args; + TreeVerifyBatchConfig const &tree_bc = + Future(task->futures[0]).get_result(); + assert(fused->op_num_weights[op] == + (1 + (int)(*m->qkv_bias || *m->final_bias))); + GenericTensorAccessorR biases; + if (*m->qkv_bias || *m->final_bias) { + assert(fused->op_num_weights[op] == 2); + biases = my_weight_accessor[1]; + } + TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( + m, + &tree_bc, + 
task->index_point.point_data[0], + my_input_accessor[0], + my_weight_accessor[0], + my_output_accessor[0], + biases); + break; + } + case OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + SpecIncMultiHeadSelfAttentionMeta const *m = + (SpecIncMultiHeadSelfAttentionMeta *)metas->meta[op]; + // BeamSearchBatchConfig const *beam_bc = + // (BeamSearchBatchConfig *)task->args; + BeamSearchBatchConfig const &beam_bc = + Future(task->futures[0]).get_result(); + assert(fused->op_num_weights[op] == + (1 + (int)(*m->qkv_bias || *m->final_bias))); + GenericTensorAccessorR biases; + if (*m->qkv_bias || *m->final_bias) { + assert(fused->op_num_weights[op] == 2); + biases = my_weight_accessor[1]; + } + SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( + m, + &beam_bc, + task->index_point.point_data[0], + my_input_accessor[0], + my_weight_accessor[0], + my_output_accessor[0], + biases); + break; + } + case OP_LAYERNORM: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + LayerNormMeta const *m = (LayerNormMeta *)metas->meta[op]; + if (m->elementwise_affine) { + assert(fused->op_num_weights[op] == 1 + (int)(m->use_bias)); + } + GenericTensorAccessorR gamma, beta; + if (m->elementwise_affine) { + gamma = my_weight_accessor[0]; + if (m->use_bias) { + beta = my_weight_accessor[1]; + } + } + LayerNorm::forward_kernel_wrapper( + m, my_input_accessor[0], my_output_accessor[0], gamma, beta); + break; + } + case OP_RESIDUAL_LAYERNORM: { + assert(fused->op_num_outputs[op] == 2); + ResidualLayerNormMeta const *m = + (ResidualLayerNormMeta *)metas->meta[op]; + if (m->use_two_residuals) { + assert(fused->op_num_inputs[op] == 3); + } else { + assert(fused->op_num_inputs[op] == 2); + } + if (!m->elementwise_affine) { + assert(fused->op_num_weights[op] == 0); + } else { + if (!m->use_bias) { + assert(fused->op_num_weights[op] == 1); // weight + } else { + assert(fused->op_num_weights[op] == 2); // weight + bias + } + } + GenericTensorAccessorR residual2; + if (m->use_two_residuals) { + residual2 = my_input_accessor[2]; + } + GenericTensorAccessorR gamma, beta; + if (m->elementwise_affine) { + gamma = my_weight_accessor[0]; + if (m->use_bias) { + beta = my_weight_accessor[1]; + } + } + ResidualLayerNorm::inference_kernel_wrapper(m, + my_input_accessor[0], + my_input_accessor[1], + residual2, + my_output_accessor[0], + my_output_accessor[1], + gamma, + beta); + break; + } + case OP_ADD_BIAS_RESIDUAL_LAYERNORM: { + assert(fused->op_num_inputs[op] == 2); + assert(fused->op_num_outputs[op] == 2); + AddBiasResidualLayerNormMeta const *m = + (AddBiasResidualLayerNormMeta *)metas->meta[op]; + if (!m->elementwise_affine) { + assert(fused->op_num_weights[op] == 1); // attn bias + } else { + if (!m->use_bias) { + assert(fused->op_num_weights[op] == 2); // attn bias + weight + } else { + assert(fused->op_num_weights[op] == 3); // attn bias + weight + bias + } + } + GenericTensorAccessorR gamma, beta; + if (m->elementwise_affine) { + gamma = my_weight_accessor[1]; + if (m->use_bias) { + beta = my_weight_accessor[2]; + } + } + Domain attn_bias_domain = my_weight_accessor[0].domain; + Domain residual_domain = my_input_accessor[1].domain; + int attn_bias_dim = + attn_bias_domain.hi()[0] - attn_bias_domain.lo()[0] + 1; + int residual_volume = residual_domain.get_volume(); + AddBiasResidualLayerNorm::inference_kernel_wrapper( + m, + attn_bias_dim, + residual_volume, + my_input_accessor[0], + my_output_accessor[0], + 
my_output_accessor[1], + my_input_accessor[1], + my_weight_accessor[0], + gamma, + beta); + break; + } + case OP_SIGMOID_SILU_MULTI: { + assert(fused->op_num_inputs[op] == 2); + assert(fused->op_num_outputs[op] == 1); + SigmoidSiluMultiMeta const *m = (SigmoidSiluMultiMeta *)metas->meta[op]; + SigmoidSiluMulti::inference_kernel_wrapper(m, + my_input_accessor[0], + my_input_accessor[1], + my_output_accessor[0]); + break; + } + case OP_SOFTMAX: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_weights[op] == 0); + assert(fused->op_num_outputs[op] == 1); + assert(my_input_accessor[0].domain.get_volume() == + my_output_accessor[0].domain.get_volume()); + SoftmaxMeta *m = (SoftmaxMeta *)metas->meta[op]; + if (m->input_type == DT_HALF) { + Kernels::Softmax::forward_kernel_wrapper( + m, + my_input_accessor[0].get_half_ptr(), + my_output_accessor[0].get_half_ptr()); + } else if (m->input_type == DT_FLOAT) { + Kernels::Softmax::forward_kernel_wrapper( + m, + my_input_accessor[0].get_float_ptr(), + my_output_accessor[0].get_float_ptr()); + } + break; + } + case OP_ALLREDUCE: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op]; + Kernels::AllReduce::inference_kernel_wrapper( + m, bc, my_input_accessor[0], my_output_accessor[0]); + break; + } + default: { + fprintf(stderr, + "Fusion currently does not support type = %d\n", + fused->op_op_type[op]); + assert(false && "Fusion currently does not support type"); + } + } + if (metas->meta[op]->inference_debugging) { + std::vector input_accessors_to_save; + std::vector weight_accessors_to_save; + std::vector output_accessors_to_save; + for (int i = 0; i < fused->op_num_inputs[op]; i++) { + int my_off = fused->op_input_idx[i + ioff]; + if (fused->op_input_source[i + ioff] == SOURCE_INPUT) { + input_accessors_to_save.push_back(input_accessor[my_off]); + } else if (fused->op_input_source[i + ioff] == SOURCE_OUTPUT) { + input_accessors_to_save.push_back(output_accessor[my_off]); + } else { + assert(false); + } + } + for (int i = 0; i < fused->op_num_weights[op]; i++) { + assert(fused->op_weight_source[i + woff] == SOURCE_WEIGHT); + weight_accessors_to_save.push_back( + weight_accessor[fused->op_weight_idx[i + woff]]); + } + for (int i = 0; i < fused->op_num_outputs[op]; i++) { + output_accessors_to_save.push_back(output_accessor[i + ooff]); + } + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + FusedOp::save_inference_tensors_to_file(metas->meta[op], + shard_id, + bc, + input_accessors_to_save, + weight_accessors_to_save, + output_accessors_to_save); + } + ioff += fused->op_num_inputs[op]; + woff += fused->op_num_weights[op]; + ooff += fused->op_num_outputs[op]; + + } } - // for (int i = 0; i < fused->numOutputs; i++) - // print_tensor(output_ptr[i], output_domain[i].get_volume(), - // "[Fused:forward:output]"); } + /* regions[...](I): input regions[...](I): weight diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 40f758282c..bbd8b33bf1 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -4063,7 +4063,7 @@ void FFIterationConfig::reset() { struct DefaultConfig { const static int epochs = 1; // const static int iterations = 1; - const static int batchSize = 64; + const static int batchSize = 2; const static bool profiling = false; const static bool inference_debugging = false; constexpr static float learningRate = 0.01f; diff --git a/src/runtime/request_manager.cc 
b/src/runtime/request_manager.cc index 16513e918a..92984cedb8 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -54,7 +54,6 @@ RequestManager::RequestManager() // ffmodel.compile() max_requests_per_batch = -1; max_tokens_per_batch = -1; - max_spec_tree_token_num = -1; max_sequence_length = -1; } @@ -76,27 +75,15 @@ void RequestManager::set_max_tokens_per_batch(int max_num_tokens) { assert(max_tokens_per_batch <= BatchConfig::MAX_NUM_TOKENS); } -void RequestManager::set_max_spec_tree_token_num(int max_num_tokens) { - assert(max_spec_tree_token_num == -1 || - max_spec_tree_token_num == max_num_tokens); - max_spec_tree_token_num = max_num_tokens; - assert(max_spec_tree_token_num <= BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); -} - int RequestManager::get_max_tokens_per_batch() { assert(max_tokens_per_batch > 0); return max_tokens_per_batch; } -int RequestManager::get_max_spec_tree_token_num() { - assert(max_spec_tree_token_num > 0); - return max_spec_tree_token_num; -} - int RequestManager::get_max_verify_tokens_per_batch() { assert(max_tokens_per_batch > 0); return max_tokens_per_batch + - max_spec_tree_token_num * max_requests_per_batch; + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM * max_requests_per_batch; } void RequestManager::set_max_sequence_length(int max_seq_length) { @@ -363,6 +350,57 @@ BatchConfig RequestManager::prepare_next_batch_task( return rm->prepare_next_batch(*bc, result); } +BatchConfig RequestManager::prepare_prefilling_batch(int i) { + const std::lock_guard<std::mutex> lock(request_queue_mutex); + + BatchConfig new_bc; + + // mark all slots except slot i as completed (empty) + for (int j = 0; j < BatchConfig::max_requests_per_batch(); j++) { + if (j == i) { + new_bc.request_completed[j] = false; + } else { + new_bc.request_completed[j] = true; + } + } + + // pop top request from the queue + Request new_request = pending_request_queue.front(); + pending_request_queue.pop(); + new_request.status = Request::PREFILLING; + all_requests[new_request.guid] = new_request; + + new_bc.requestsInfo[i].first_token_depth_in_request = 0; + new_bc.requestsInfo[i].first_token_offset_in_batch = 0; + new_bc.requestsInfo[i].request_guid = new_request.guid; + new_bc.requestsInfo[i].num_tokens_in_batch = + std::min(get_max_tokens_per_batch(), + (int)new_request.tokens.size()); + new_bc.requestsInfo[i].max_sequence_length = + new_request.max_sequence_length; + new_bc.request_completed[i] = false; + new_bc.requestsInfo[i].prompt_phase = true; + new_bc.requestsInfo[0].batch_config_request_id = i; + + // add profile_info for the new request + ProfileInfo profile_info; + profile_info.llm_decoding_steps = 1; + profile_info.start_time = Realm::Clock::current_time_in_microseconds(); + profiling_requests[new_request.guid] = profile_info; + + // add tokens to the batch + for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { + int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j; + new_bc.tokensInfo[new_bc.num_tokens].request_index = i; + new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth; + assert(depth < new_request.tokens.size()); + new_bc.tokensInfo[new_bc.num_tokens].token_id = + new_request.tokens[depth]; + new_bc.num_tokens++; + } + return new_bc; +} + BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, InferenceResult const &result) { const std::lock_guard<std::mutex> lock(request_queue_mutex); @@ -385,11 +423,21 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, // log_req_mgr.print("Output: %s", output.c_str()); } } + 
int num_generation_tokens = 0; int num_active_req = -1; // Step 2: prepare the next batch for existing requests BatchConfig new_bc; + if (buffer_bc != nullptr) { + new_bc = *buffer_bc; + delete buffer_bc; + buffer_bc = nullptr; + for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) { + if (!new_bc.request_completed[i]) { + num_active_req++; + } + } + } + for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) { if (old_bc.request_completed[i]) { // add new requests to the next batch continue; } @@ -424,6 +472,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, gr.output_text = output; } request.status = Request::COMPLETED; + new_bc.request_completed[i] = true; trigger_request_completion_future(request.guid); log_req_mgr.print("[Done] guid(%zu) final_length(%zu)", old_bc.requestsInfo[i].request_guid, @@ -448,10 +497,10 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, std::ofstream outputFile(output_filepath, std::ios::app); if (outputFile.is_open()) { outputFile << "end-to-end latency: " << std::fixed - << std::setprecision(3) << total_request_run_time - << std::endl; + << std::setprecision(3) << total_request_run_time + << std::endl; outputFile << "num decoding steps: " - << profile_info.llm_decoding_steps << std::endl; + << profile_info.llm_decoding_steps << std::endl; outputFile << "token IDs: "; for (int i = 0; i < request.tokens.size(); i++) { outputFile << request.tokens[i]; @@ -489,8 +538,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, // Prompt phase new_bc.requestsInfo[i].num_tokens_in_batch = std::min(get_max_tokens_per_batch() - new_bc.num_tokens, - (int)request.tokens.size() - - new_bc.requestsInfo[i].first_token_depth_in_request); + (int)request.tokens.size() - + new_bc.requestsInfo[i].first_token_depth_in_request); new_bc.requestsInfo[i].prompt_phase = true; } for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { @@ -514,39 +563,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, if (new_bc.request_completed[i]) { if (!pending_request_queue.empty() && new_bc.num_tokens < get_max_tokens_per_batch()) { - Request new_request = pending_request_queue.front(); - pending_request_queue.pop(); - // all_requests[new_request.guid] = new_request; - - new_bc.requestsInfo[i].first_token_depth_in_request = 0; - new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens; - new_bc.requestsInfo[i].request_guid = new_request.guid; - new_bc.requestsInfo[i].num_tokens_in_batch = - std::min(get_max_tokens_per_batch() - new_bc.num_tokens, - (int)new_request.tokens.size()); - new_bc.requestsInfo[i].max_sequence_length = - new_request.max_sequence_length; - new_bc.request_completed[i] = false; - new_bc.requestsInfo[i].prompt_phase = true; - num_active_req++; - new_bc.requestsInfo[num_active_req].batch_config_request_id = i; - // add profile_info for the new request - ProfileInfo profile_info; - profile_info.llm_decoding_steps = 1; - profile_info.start_time = Realm::Clock::current_time_in_microseconds(); - profiling_requests[new_request.guid] = profile_info; - for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { - int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j; - new_bc.tokensInfo[new_bc.num_tokens].request_index = i; - new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth; - assert(depth < new_request.tokens.size()); - new_bc.tokensInfo[new_bc.num_tokens].token_id = - new_request.tokens[depth]; - new_bc.num_tokens++; - 
} - if (new_bc.num_tokens == get_max_tokens_per_batch()) { - break; - } + buffer_bc = new BatchConfig(new_bc); + new_bc = prepare_prefilling_batch(i); } } } @@ -1577,11 +1595,9 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( } if (new_bc.num_tokens > get_max_verify_tokens_per_batch()) { - printf("Exceeding (%i) the space available (%i) in the TreeVerify " - "batch\n", - new_bc.num_tokens, - get_max_verify_tokens_per_batch()); - assert(false); + assert(false && + "Exceeding the space available in the TreeVerify batch"); + break; } if (new_bc.requestsInfo[i].num_tokens_in_batch +