From ed5a2e07fdc9285612f167c150f8d138e51895f7 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Mon, 25 Dec 2023 12:17:48 -0500 Subject: [PATCH 01/30] init --- include/flexflow/batch_config.h | 12 + include/flexflow/config.h | 9 + include/flexflow/ffconst.h | 1 + include/flexflow/model.h | 45 + include/flexflow/operator_params.h | 2 + .../specinfer_inc_multihead_self_attention.h | 150 +++ ...nfer_inc_multihead_self_attention_params.h | 33 + include/flexflow/request_manager.h | 14 +- inference/file_loader.cc | 3 +- inference/models/llama.cc | 5 +- inference/spec_infer/spec_infer.cc | 3 + src/ops/inc_multihead_self_attention.cpp | 19 + src/ops/inc_multihead_self_attention.cu | 61 +- .../specinfer_inc_multihead_self_attention.cc | 883 +++++++++++++++++ .../specinfer_inc_multihead_self_attention.cu | 890 ++++++++++++++++++ src/ops/tree_inc_multihead_self_attention.cu | 24 +- src/runtime/ffconst_utils.cc | 2 + src/runtime/graph.cc | 71 +- src/runtime/inference_manager.cc | 13 +- src/runtime/model.cc | 149 ++- src/runtime/model.cpp | 48 + src/runtime/model.cu | 28 +- src/runtime/request_manager.cc | 250 +++-- src/runtime/request_manager.cpp | 16 + src/runtime/request_manager.cu | 50 + 25 files changed, 2589 insertions(+), 192 deletions(-) create mode 100644 include/flexflow/ops/specinfer_inc_multihead_self_attention.h create mode 100644 include/flexflow/ops/specinfer_inc_multihead_self_attention_params.h create mode 100644 src/ops/specinfer_inc_multihead_self_attention.cc create mode 100644 src/ops/specinfer_inc_multihead_self_attention.cu diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index e2903c4d11..c33c3558cc 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -129,6 +129,9 @@ class BeamSearchBatchConfig : public BatchConfig { inline static int const MAX_BEAM_WIDTH = 1; inline static int const MAX_BEAM_DEPTH = 8; + // maximum tree branches for a request + inline static int const MAX_SPECULATIVE_TREE_BRANCHES = 9; + int model_id; struct BeamSearchPerRequestInfo { @@ -139,14 +142,23 @@ class BeamSearchBatchConfig : public BatchConfig { BatchConfig::TokenId tokens[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; float probs[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; int parent_id[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; + int sub_request_num; }; struct BeamSearchPerTokenInfo { int sub_request_index; }; + struct SpecInferTopology { + int real_token_pos[MAX_SPECULATIVE_TREE_BRANCHES][MAX_NUM_TOKENS]; + int allocated_tokens; + }; + + BeamSearchPerRequestInfo beamRequestsInfo[MAX_NUM_REQUESTS]; BeamSearchPerTokenInfo beamTokenInfo[MAX_NUM_TOKENS * MAX_BEAM_WIDTH]; + SpecInferTopology topology_mask[MAX_NUM_REQUESTS]; + // why is this == MAX_NUM_REQUESTS * MAX_BEAM_WIDTH? 
int sub_requests[MAX_NUM_REQUESTS * MAX_BEAM_WIDTH]; diff --git a/include/flexflow/config.h b/include/flexflow/config.h index c2af6d707c..321d14961b 100644 --- a/include/flexflow/config.h +++ b/include/flexflow/config.h @@ -16,6 +16,7 @@ #ifndef _FLEXFLOW_CONFIG_H_ #define _FLEXFLOW_CONFIG_H_ #include "ffconst.h" +#include "flexflow/batch_config.h" #include "legion.h" #include #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) @@ -75,6 +76,14 @@ struct FFHandler { #endif void *workSpace; size_t workSpaceSize; + void *batch_config_metadata; + + // request info + token info + topolopgy mask info + size_t batch_config_metadata_size = + sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + + sizeof(BeamSearchBatchConfig::topology_mask) + + sizeof(BeamSearchBatchConfig::beamTokenInfo) + + sizeof(BeamSearchBatchConfig::beamRequestsInfo); void *offload_reserve_space; size_t offload_reserve_space_size; DataType quantization_type; diff --git a/include/flexflow/ffconst.h b/include/flexflow/ffconst.h index 512645e624..ef0003b08e 100644 --- a/include/flexflow/ffconst.h +++ b/include/flexflow/ffconst.h @@ -171,6 +171,7 @@ enum OperatorType { OP_INC_MULTIHEAD_SELF_ATTENTION, OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION, OP_TREE_INC_MULTIHEAD_SELF_ATTENTION, + OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION, OP_SAMPLING, // Parallel Ops OP_REPARTITION, diff --git a/include/flexflow/model.h b/include/flexflow/model.h index d8402ba622..3602cb108b 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -172,6 +172,8 @@ enum TaskIDs { SPEC_INC_MULTIHEAD_SELF_ATTENTION_INF_TASK_ID, TREE_INC_MULTIHEAD_SELF_ATTENTION_INIT_TASK_ID, TREE_INC_MULTIHEAD_SELF_ATTENTION_INF_TASK_ID, + SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_INIT_TASK_ID, + SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_INF_TASK_ID, MSELOSS_BWD_TASK_ID, FUSEDOP_INIT_TASK_ID, FUSEDOP_FWD_TASK_ID, @@ -324,6 +326,7 @@ class Linear; class MultiHeadAttention; class IncMultiHeadSelfAttention; class TreeIncMultiHeadSelfAttention; +class SpecInferIncMultiHeadSelfAttention; class Pool2D; class Reduce; class Reshape; @@ -743,6 +746,25 @@ class FFModel { bool qk_prod_scaling = true, bool position_bias = false, char const *name = NULL); + +Tensor specinfer_inc_multihead_self_attention( + const Tensor input, + int embed_dim, + int num_heads, + int kdim = 0, + int vdim = 0, + float dropout = 0.0f, + bool bias = false, + bool add_bias_kv = false, + bool add_zero_attn = false, + DataType data_type = DT_NONE, + Initializer *kernel_initializer = NULL, + bool apply_rotary_embedding = false, + bool scaling_query = false, + float scaling_factor = 1.0f, + bool qk_prod_scaling = true, + bool position_bias = false, + char const *name = NULL); Tensor inc_multiquery_self_attention(const Tensor input, int embed_dim, int num_q_heads, @@ -799,6 +821,26 @@ class FFModel { bool qk_prod_scaling = true, bool position_bias = false, char const *name = NULL); + + Tensor specinfer_inc_multiquery_self_attention( + const Tensor input, + int embed_dim, + int num_q_heads, + int num_kv_heads, + int kdim = 0, + int vdim = 0, + float dropout = 0.0f, + bool bias = false, + bool add_bias_kv = false, + bool add_zero_attn = false, + DataType data_type = DT_NONE, + Initializer *kernel_initializer = NULL, + bool apply_rotary_embedding = false, + bool scaling_query = false, + float scaling_factor = 1.0f, + bool qk_prod_scaling = true, + bool position_bias = false, + char const *name = NULL); // ======================================== // Inference APIs // 
======================================== @@ -1200,6 +1242,9 @@ class FFModel { std::unordered_map< std::pair, TreeIncMultiHeadSelfAttention *>, + std::unordered_map< + std::pair, + SpecInferIncMultiHeadSelfAttention *>, std::unordered_map, Reduce *>, std::unordered_map, diff --git a/include/flexflow/operator_params.h b/include/flexflow/operator_params.h index 5b187839ef..cee2ae95a4 100644 --- a/include/flexflow/operator_params.h +++ b/include/flexflow/operator_params.h @@ -37,6 +37,7 @@ #include "flexflow/ops/topk_params.h" #include "flexflow/ops/transpose_params.h" #include "flexflow/ops/tree_inc_multihead_self_attention_params.h" +#include "flexflow/ops/specinfer_inc_multihead_self_attention_params.h" #include "flexflow/parallel_ops/allreduce_params.h" #include "flexflow/parallel_ops/combine_params.h" #include "flexflow/parallel_ops/fused_parallel_op_params.h" @@ -72,6 +73,7 @@ using OperatorParameters = mp::variant +#include + +namespace FlexFlow { + +class SpecInferIncMultiHeadSelfAttentionMeta; + +class SpecInferIncMultiHeadSelfAttention : public Op { +public: + using Params = SpecInferIncMultiHeadSelfAttentionParams; + using Input = ParallelTensor; + + SpecInferIncMultiHeadSelfAttention(FFModel &model, + LayerID const &layer_guid, + const ParallelTensor _input, + int _embed_dim, + int _num_q_heads, + int _num_kv_heads, + int _kdim, + int _vdim, + float _dropout, + bool _qkv_bias, + bool _final_bias, + bool _add_zero_attn, + bool _apply_rotary_embedding, + bool _scaling_query, + float _scaling_factor, + bool _qk_prod_scaling, + bool _position_bias, + bool allocate_weights, + char const *name); + SpecInferIncMultiHeadSelfAttention(FFModel &model, + const ParallelTensor _input, + const ParallelTensor _weight, + int _embed_dim, + int _num_q_heads, + int _num_kv_heads, + int _kdim, + int _vdim, + float _dropout, + bool _qkv_bias, + bool _final_bias, + bool _add_zero_attn, + bool _apply_rotary_embedding, + bool _scaling_query, + float _scaling_factor, + bool _qk_prod_scaling, + bool _position_bias, + bool allocate_weights, + char const *name); + SpecInferIncMultiHeadSelfAttention(FFModel &model, + SpecInferIncMultiHeadSelfAttention const &other, + const ParallelTensor input, + bool allocate_weights); + SpecInferIncMultiHeadSelfAttention(FFModel &model, + Params const ¶ms, + Input const &inputs, + bool allocate_weights = false, + char const *name = nullptr); + static Op * + create_operator_from_layer(FFModel &model, + Layer const *layer, + std::vector const &inputs); + void init(FFModel const &) override; + void init_inference(FFModel const &, + std::vector const &, + std::vector const &, + MachineView const *mv = nullptr) override; + void forward(FFModel const &) override; + void backward(FFModel const &) override; + Legion::FutureMap inference(FFModel const &, + BatchConfigFuture const &, + std::vector const &, + std::vector const &, + MachineView const *mv = nullptr) override; + void print_layer(FFModel const &model) override { + assert(0); + } + bool get_int_parameter(PMParameter, int *) const override; + + static OpMeta *init_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + static void inference_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + Op *materialize(FFModel &ff, + ParallelTensor inputs[], + int num_inputs) const override; + bool measure_operator_cost(Simulator *sim, + MachineView const &mv, + CostMetrics &cost_metrics) const override; + + static void + 
inference_kernel_wrapper(SpecInferIncMultiHeadSelfAttentionMeta const *m, + BeamSearchBatchConfig const *bc, + int shard_id, + GenericTensorAccessorR const &input, + GenericTensorAccessorR const &weight, + GenericTensorAccessorW const &output, + GenericTensorAccessorR const &bias); + Params get_params() const; + +public: + int num_q_heads, num_kv_heads, tensor_parallelism_degree; + float dropout, scaling_factor; + bool qkv_bias; + bool final_bias, add_zero_attn, apply_rotary_embedding, scaling_query, + qk_prod_scaling, position_bias; + int qSize, kSize, vSize, qProjSize, kProjSize, vProjSize, oProjSize; + int qoSeqLength, kvSeqLength; +}; + +class SpecInferIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta { +public: + SpecInferIncMultiHeadSelfAttentionMeta(FFHandler handler, + SpecInferIncMultiHeadSelfAttention const *attn, + GenericTensorAccessorR const &weight, + MemoryAllocator &gpu_mem_allocator, + int num_samples, + int _num_q_heads, + int _num_kv_heads); + ~SpecInferIncMultiHeadSelfAttentionMeta(void); + +public: + Realm::RegionInstance beam_search_reserve_inst; + BeamSearchBatchConfig::BeamSearchPerTokenInfo *beam_token_infos; + BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos; + BeamSearchBatchConfig::SpecInferTopology *beam_topology_mask; +}; + +}; // namespace FlexFlow + +#endif // _FLEXFLOW_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_H diff --git a/include/flexflow/ops/specinfer_inc_multihead_self_attention_params.h b/include/flexflow/ops/specinfer_inc_multihead_self_attention_params.h new file mode 100644 index 0000000000..b57b06a7f7 --- /dev/null +++ b/include/flexflow/ops/specinfer_inc_multihead_self_attention_params.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_PARAMS_H +#define _FLEXFLOW_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_PARAMS_H + +#include "flexflow/ffconst.h" +#include "flexflow/fftype.h" +#include "flexflow/parallel_tensor.h" + +namespace FlexFlow { + +struct SpecInferIncMultiHeadSelfAttentionParams { + LayerID layer_guid; + int embed_dim, num_q_heads, num_kv_heads, kdim, vdim; + float dropout, scaling_factor; + bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, + scaling_query, qk_prod_scaling, position_bias; + + bool is_valid(ParallelTensorShape const &) const; +}; + +bool operator==(SpecInferIncMultiHeadSelfAttentionParams const &, + SpecInferIncMultiHeadSelfAttentionParams const &); + +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t + operator()(FlexFlow::SpecInferIncMultiHeadSelfAttentionParams const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_PARAMS_H diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index baf6844801..e67888d2d6 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -38,7 +38,8 @@ class InferenceManager { Legion::FutureMap inference(FFModel *model, int index, BatchConfigFuture const &bc); void load_input_tokens_from_batch_config(BatchConfigFuture const &bc, - ParallelTensor const input); + ParallelTensor const input, + FFHandler *handlers); void load_positions(BatchConfigFuture const &bc, ParallelTensor position_input, int offset); @@ -72,9 +73,10 @@ struct Request { struct BeamTree { struct treeLayer { BeamSearchBatchConfig::TokenId - tokens[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; + tokens[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES]; int parent_ids[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; - float 
probs[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; + float probs[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES]; + int nodes_num_this_layer = 0; }; treeLayer treeLayers[BeamSearchBatchConfig::MAX_BEAM_DEPTH + 1]; }; @@ -100,6 +102,7 @@ class RequestManager { void set_max_tokens_per_batch(int max_num_tokens); int get_max_tokens_per_batch(); void set_max_sequence_length(int max_seq_length); + void push_spec_infer_tree_width(int tree_width); int get_max_sequence_length(); int register_ssm_model(FFModel *model); void register_tokenizer(ModelType model_type, @@ -148,6 +151,7 @@ class RequestManager { void store_beam_metadata(BeamSearchBatchConfig const &old_bc, BeamInferenceResult const &result); void update_beam_metadata(BeamSearchBatchConfig &new_bc, + BeamSearchBatchConfig const &old_bc, BeamTree &tree, int request_index); @@ -210,6 +214,7 @@ class RequestManager { int max_requests_per_batch; int max_tokens_per_batch; int max_sequence_length; + std::vector spec_infer_tree_width; // private fields std::unique_ptr tokenizer_; bool verbose; @@ -243,7 +248,8 @@ class RequestManager { private: struct ProfileInfo { - int decoding_steps; + int llm_decoding_steps; + int ssm_decoding_steps; double start_time, finish_time; }; std::unordered_map profiling_requests; diff --git a/inference/file_loader.cc b/inference/file_loader.cc index 7c6870d439..3f70ddf488 100644 --- a/inference/file_loader.cc +++ b/inference/file_loader.cc @@ -726,7 +726,8 @@ void FileDataLoader::load_single_weight_tensor(FFModel *ff, if (l->op_type == OP_INC_MULTIHEAD_SELF_ATTENTION || l->op_type == OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION || - l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION) { + l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION || + l->op_type == OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION) { if (weight_filename.find("self_attention") != std::string::npos) { load_attention_weights_multi_query( data, weight_filename, weights_folder, hidden_dim, num_heads); diff --git a/inference/models/llama.cc b/inference/models/llama.cc index b8fe70526d..f62df1b1d7 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -90,7 +90,7 @@ void LLAMA::create_llama_model(FFModel &ff, Tensor mha; switch (mode) { case BEAM_SEARCH_MODE: { - mha = ff.spec_inc_multihead_self_attention( + mha = ff.specinfer_inc_multihead_self_attention( att_norm, llama_config.hidden_size, llama_config.num_attention_heads, @@ -246,7 +246,8 @@ void LLAMA::create_llama_model(FFModel &ff, if (mode == BEAM_SEARCH_MODE) { Tensor softmax = ff.softmax(dense, -1); // output = ff.beam_top_k(softmax, llama_config.max_beam_width, false); - output = ff.argmax(softmax, /*beam_Search*/ true); + // output = ff.argmax(softmax, /*beam_Search*/ true); + output = ff.beam_top_k(softmax, llama_config.max_beam_width, false); } else { // Tensor softmax = ff.softmax(dense, -1); if (generation_config.do_sample) { diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index 8b0eb926d9..e2594ba87f 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -302,6 +302,9 @@ void FlexFlow::top_level_task(Task const *task, model_metadata.llm_tokenizer_path); rm->register_output_filepath(file_paths.output_file_path); + //first decoding step: 3 results + rm->push_spec_infer_tree_width(1); + // Create LLM model FFModel tree_model(ffconfig, ffconfig.cpu_offload); if (model_metadata.llm_model_type == ModelType::LLAMA) { diff --git a/src/ops/inc_multihead_self_attention.cpp b/src/ops/inc_multihead_self_attention.cpp 
index d60386f927..a59740f4a3 100644 --- a/src/ops/inc_multihead_self_attention.cpp +++ b/src/ops/inc_multihead_self_attention.cpp @@ -1098,4 +1098,23 @@ template void Kernels::IncMultiHeadAttention::pre_build_weight_kernel( DataType data_type, hipStream_t stream); +template void Kernels::IncMultiHeadAttention::compute_o_prod_bias( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + int shard_id, + float *output_ptr, + float const *weight_ptr, + float const *bias_ptr, + int num_tokens, + cudaStream_t stream); +template void Kernels::IncMultiHeadAttention::compute_o_prod_bias( + IncMultiHeadSelfAttentionMeta const *m, + BatchConfig const *bc, + int shard_id, + half *output_ptr, + half const *weight_ptr, + half const *bias_ptr, + int num_tokens, + cudaStream_t stream); + }; // namespace FlexFlow diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 695f4b13b9..4c184acb3c 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -826,17 +826,17 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta const *m, } // todo Xinhao copy how many requests if requests are not continous? - cudaMemcpyAsync(m->token_infos, - &(bc->tokensInfo), - bc->num_active_tokens() * sizeof(BatchConfig::PerTokenInfo), - cudaMemcpyHostToDevice, - stream); - cudaMemcpyAsync(m->request_infos, - &(bc->requestsInfo), - bc->max_requests_per_batch() * - sizeof(BatchConfig::PerRequestInfo), - cudaMemcpyHostToDevice, - stream); + // cudaMemcpyAsync(m->token_infos, + // &(bc->tokensInfo), + // bc->num_active_tokens() * + // sizeof(BatchConfig::PerTokenInfo), cudaMemcpyHostToDevice, + // stream); + // cudaMemcpyAsync(m->request_infos, + // &(bc->requestsInfo), + // bc->max_requests_per_batch() * + // sizeof(BatchConfig::PerRequestInfo), + // cudaMemcpyHostToDevice, + // stream); // phase 1: Implement kernel to compute KQV for input tokens compute_qkv_kernel(m, @@ -1375,14 +1375,15 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( break; } case BEAM_SEARCH_MODE: { + // a K-ary tree max node is (k^n - 1) / 2 key_cache_size = num_q_heads * kProjSize * BeamSearchBatchConfig::max_requests_per_batch() * BatchConfig::max_sequence_length() * - BeamSearchBatchConfig::MAX_BEAM_WIDTH; + BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES; value_cache_size = num_q_heads * vProjSize * BeamSearchBatchConfig::max_requests_per_batch() * BatchConfig::max_sequence_length() * - BeamSearchBatchConfig::MAX_BEAM_WIDTH; + BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES; break; } default: @@ -1400,10 +1401,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( (qkv_max_proj_size + key_cache_size + value_cache_size + 2 * qk_prod_size + attn_heads_size) * size_of_dt + - tokeninfo_size * sizeof(BatchConfig::PerTokenInfo) + - complex_size * sizeof(cuFloatComplex) + - requestinfo_size * - sizeof(BatchConfig::PerRequestInfo); // more components will + complex_size * sizeof(cuFloatComplex); // more components will // be added here later if (offload) { // assert that we have enough reserved work space left @@ -1447,10 +1445,15 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( valueCache = gpu_mem_allocator.allocate_instance_untyped(value_cache_size * size_of_dt); + token_infos = + static_cast(handler.batch_config_metadata); + request_infos = static_cast( + handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo)); + if (offload) { - token_infos = - gpu_mem_allocator.allocate_reserved( - 
tokeninfo_size); + // token_infos = + // gpu_mem_allocator.allocate_reserved( + // tokeninfo_size); // offset += sizeof(BatchConfig::PerTokenInfo) * tokeninfo_size; qk_prods = gpu_mem_allocator.allocate_reserved_untyped(qk_prod_size * size_of_dt); @@ -1464,13 +1467,13 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( complex_input = gpu_mem_allocator.allocate_reserved(complex_size); // offset += complex_size * sizeof(cuFloatComplex); - request_infos = - gpu_mem_allocator.allocate_reserved( - requestinfo_size); + // request_infos = + // gpu_mem_allocator.allocate_reserved( + // requestinfo_size); } else { - token_infos = - gpu_mem_allocator.allocate_instance( - tokeninfo_size); + // token_infos = + // gpu_mem_allocator.allocate_instance( + // tokeninfo_size); qk_prods = gpu_mem_allocator.allocate_instance_untyped(qk_prod_size * size_of_dt); qk_prods_softmax = gpu_mem_allocator.allocate_instance_untyped( @@ -1479,9 +1482,9 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( size_of_dt); complex_input = gpu_mem_allocator.allocate_instance(complex_size); - request_infos = - gpu_mem_allocator.allocate_instance( - requestinfo_size); + // request_infos = + // gpu_mem_allocator.allocate_instance( + // requestinfo_size); } // allocate more size for quantization data diff --git a/src/ops/specinfer_inc_multihead_self_attention.cc b/src/ops/specinfer_inc_multihead_self_attention.cc new file mode 100644 index 0000000000..42074f39e4 --- /dev/null +++ b/src/ops/specinfer_inc_multihead_self_attention.cc @@ -0,0 +1,883 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "flexflow/ops/specinfer_inc_multihead_self_attention.h" +#include "flexflow/ffconst_utils.h" +#include "flexflow/model.h" +#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) +#include "flexflow/utils/cuda_helper.h" +#else +#include "flexflow/utils/hip_helper.h" +#endif +#include "flexflow/utils/hash_utils.h" +#include "legion/legion_utilities.h" + +namespace FlexFlow { + +// declare Legion names +using Legion::ArgumentMap; +using Legion::Context; +using Legion::coord_t; +using Legion::Domain; +using Legion::Future; +using Legion::FutureMap; +using Legion::IndexLauncher; +using Legion::Machine; +using Legion::Memory; +using Legion::PhysicalRegion; +using Legion::Predicate; +using Legion::Rect; +using Legion::RegionRequirement; +using Legion::Runtime; +using Legion::Task; +using Legion::TaskArgument; +using Legion::TaskLauncher; +using PCG::Node; + +bool SpecInferIncMultiHeadSelfAttentionParams::is_valid( + ParallelTensorShape const &input) const { + bool is_valid = input.is_valid(); + return is_valid; +} + +Tensor FFModel::specinfer_inc_multihead_self_attention( + Tensor const input, + int embed_dim, + int num_heads, + int kdim, + int vdim, + float dropout, + bool qkv_bias, + bool final_bias, + bool add_zero_attn, + DataType data_type, + Initializer *kernel_initializer, + bool apply_rotary_embedding, + bool scaling_query, + float scaling_factor, + bool qk_prod_scaling, + bool position_bias, + char const *name) { + return specinfer_inc_multiquery_self_attention(input, + embed_dim, + num_heads, + num_heads, + kdim, + vdim, + dropout, + qkv_bias, + final_bias, + add_zero_attn, + data_type, + kernel_initializer, + apply_rotary_embedding, + scaling_query, + scaling_factor, + qk_prod_scaling, + position_bias, + name); +} + +Tensor FFModel::specinfer_inc_multiquery_self_attention( + Tensor const input, + int embed_dim, + int num_q_heads, + int num_kv_heads, + int kdim, + int vdim, + float dropout, + bool qkv_bias, + bool final_bias, + bool add_zero_attn, + DataType data_type, + Initializer *kernel_initializer, + bool apply_rotary_embedding, + bool scaling_query, + float scaling_factor, + bool qk_prod_scaling, + bool position_bias, + char const *name) { + if (data_type == DT_NONE) { + data_type = input->data_type; + } + Layer *li = nullptr; + int weight_num = (qkv_bias || final_bias) ? 2 : 1; + if (data_type != input->data_type) { + Tensor casted_input = cast(input, data_type, "type cast for IncMHA"); + li = new Layer(this, + OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION, + data_type, + name, + 1 /*inputs*/, + weight_num /*weights*/, + 1 /*outputs*/, + casted_input); + } else { + li = new Layer(this, + OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION, + data_type, + name, + 1 /*inputs*/, + weight_num /*weights*/, + 1 /*outputs*/, + input); + } + { + int numdims = input->num_dims; + int dims[MAX_TENSOR_DIM]; + for (int i = 0; i < numdims; i++) { + dims[i] = input->dims[i]; + } + dims[0] = embed_dim; + li->outputs[0] = create_tensor_legion_ordering( + numdims, dims, data_type, li, 0, true /*create_grad*/); + } + // Compute weight size + int qProjSize = kdim, kProjSize = kdim, vProjSize = kdim, + oProjSize = embed_dim; + int qSize = input->dims[0], kSize = input->dims[0], vSize = input->dims[0]; + int qParas = qProjSize * qSize; + int kParas = kProjSize * kSize; + int vParas = vProjSize * vSize; + int oParas = oProjSize * (vProjSize > 0 ? 
vProjSize : vSize); + int weight_size = qParas * num_q_heads + kParas * num_q_heads + + vParas * num_q_heads + oParas * num_q_heads; + { + int dims[1] = {weight_size}; + li->weights[0] = create_weight_legion_ordering(1, + dims, + data_type, + li, + true /*create_grad*/, + kernel_initializer, + CHOSEN_SYNC_TYPE); + } + if (qkv_bias || final_bias) { + // q, k, v, o + int qkv_bias_size = + qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads; + int dims[1] = {(qkv_bias ? qkv_bias_size : 0) + + (final_bias ? oProjSize : 0)}; + li->weights[1] = create_weight_legion_ordering(1, + dims, + data_type, + li, + true /*create_grad*/, + kernel_initializer, + CHOSEN_SYNC_TYPE); + } + li->data_type = data_type; + li->add_int_property("embed_dim", embed_dim); + li->add_int_property("num_q_heads", num_q_heads); + li->add_int_property("num_kv_heads", num_kv_heads); + li->add_int_property("kdim", kdim); + li->add_int_property("vdim", vdim); + li->add_int_property("qkv_bias", qkv_bias); + li->add_int_property("final_bias", final_bias); + li->add_int_property("add_zero_attn", add_zero_attn); + li->add_float_property("dropout", dropout); + li->add_int_property("apply_rotary_embedding", apply_rotary_embedding); + li->add_int_property("scaling_query", scaling_query); + li->add_float_property("scaling_factor", scaling_factor); + li->add_int_property("qk_prod_scaling", qk_prod_scaling); + li->add_int_property("position_bias", position_bias); + layers.push_back(li); + return li->outputs[0]; +} + +Op *SpecInferIncMultiHeadSelfAttention::create_operator_from_layer( + FFModel &model, + Layer const *layer, + std::vector const &inputs) { + + std::cout << "spec create operator: " << layer->name << "\n"; + long long value; + layer->get_int_property("embed_dim", value); + int embed_dim = value; + layer->get_int_property("num_q_heads", value); + int num_q_heads = value; + layer->get_int_property("num_kv_heads", value); + int num_kv_heads = value; + layer->get_int_property("kdim", value); + int kdim = value; + layer->get_int_property("vdim", value); + int vdim = value; + float dropout; + layer->get_float_property("dropout", dropout); + layer->get_int_property("qkv_bias", value); + bool qkv_bias = (bool)value; + layer->get_int_property("final_bias", value); + bool final_bias = (bool)value; + layer->get_int_property("add_zero_attn", value); + bool add_zero_attn = (bool)value; + layer->get_int_property("apply_rotary_embedding", value); + bool apply_rotary_embedding = (bool)value; + layer->get_int_property("scaling_query", value); + bool scaling_query = (bool)value; + float scaling_factor; + layer->get_float_property("scaling_factor", scaling_factor); + layer->get_int_property("qk_prod_scaling", value); + bool qk_prod_scaling = (bool)value; + layer->get_int_property("position_bias", value); + bool position_bias = (bool)value; + + return new SpecInferIncMultiHeadSelfAttention(model, + layer->layer_guid, + inputs[0], + embed_dim, + num_q_heads, + num_kv_heads, + kdim, + vdim, + dropout, + qkv_bias, + final_bias, + add_zero_attn, + apply_rotary_embedding, + scaling_query, + scaling_factor, + qk_prod_scaling, + position_bias, + false /*allocate_weights*/, + layer->name); +} + +SpecInferIncMultiHeadSelfAttention::SpecInferIncMultiHeadSelfAttention( + FFModel &model, + LayerID const &_layer_guid, + ParallelTensor const _input, + int _embed_dim, + int _num_q_heads, + int _num_kv_heads, + int _kdim, + int _vdim, + float _dropout, + bool _qkv_bias, + bool _final_bias, + bool _add_zero_attn, + bool _apply_rotary_embedding, + 
bool _scaling_query, + float _scaling_factor, + bool _qk_prod_scaling, + bool _position_bias, + bool allocate_weights, + char const *name) + // Initializer* _bias_initializer) + : Op(model, + OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION, + _input->data_type, + name, + 1 /*inputs*/, + (_qkv_bias || _final_bias ? 2 : 1) /*weights*/, + 1 /*outputs*/, + _input), + num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), + qkv_bias(_qkv_bias), final_bias(_final_bias), + add_zero_attn(_add_zero_attn), + apply_rotary_embedding(_apply_rotary_embedding), + qSize(_input->dims[0].size), kSize(_input->dims[0].size), + vSize(_input->dims[0].size), qProjSize(_kdim), kProjSize(_kdim), + vProjSize(_vdim), oProjSize(_embed_dim), + qoSeqLength(_input->dims[1].size), kvSeqLength(_input->dims[1].size), + scaling_query(_scaling_query), scaling_factor(_scaling_factor), + qk_prod_scaling(_qk_prod_scaling), position_bias(_position_bias) { + // overwrite layer_guid + layer_guid = _layer_guid; + + numOutputs = 1; + int numdim = _input->num_dims; + ParallelDim dims[MAX_TENSOR_DIM]; + for (int i = 0; i < numdim; i++) { + dims[i] = _input->dims[i]; + } + dims[0].size = _embed_dim; + // Currently require no parallelism along this dim + assert(dims[0].degree == 1); + if (allocate_weights) { + // Create weight tensor + int num_dims = inputs[0]->num_dims; + // Compute weight size + int qParas = this->qProjSize * this->qSize; + int kParas = this->kProjSize * this->kSize; + int vParas = this->vProjSize * this->vSize; + int oParas = + this->oProjSize * (this->vProjSize > 0 ? this->vProjSize : this->vSize); + ParallelDim dims[2]; + dims[0] = inputs[0]->dims[num_dims - 2]; + dims[0].size = dims[0].degree; + dims[1] = inputs[0]->dims[num_dims - 1]; + dims[1].size = this->num_q_heads * (qParas + oParas) + + this->num_q_heads * (kParas + vParas); + dims[1].is_replica_dim = false; + int seed = std::rand(); + Initializer *initializer = new GlorotUniform(seed); + weights[0] = model.create_parallel_weight<2>(dims, + this->data_type, + NULL /*owner_op*/, + true /*create_grad*/, + initializer, + CHOSEN_SYNC_TYPE); + if (qkv_bias || final_bias) { + ParallelTensorShape bias_shape = _input->get_shape(); + int qkv_bias_size = + qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads; + bias_shape.dims[0].size = + (qkv_bias ? qkv_bias_size : 0) + (final_bias ? 
oProjSize : 0); + bias_shape.dims[1].size = bias_shape.dims[2].size = 1; + weights[1] = + model.create_parallel_weight_legion_ordering(bias_shape.num_dims, + bias_shape.dims, + this->data_type, + nullptr /*owner_op*/, + true /*create_grad*/, + initializer, + CHOSEN_SYNC_TYPE); + } + } + + outputs[0] = model.create_parallel_tensor_legion_ordering( + _input->num_dims, dims, this->data_type, this); + /* for (int i = 0; i < numdim; i++) { */ + /* register_output_input_parallel_dims(outputs[0], i, inputs[0], i); */ + /* } */ + /* // Check correctness */ + /* assert(check_output_input_weight_parallel_dims()); */ +} + +SpecInferIncMultiHeadSelfAttention::SpecInferIncMultiHeadSelfAttention( + FFModel &model, + ParallelTensor const _input, + ParallelTensor const _weight, + int _embed_dim, + int _num_q_heads, + int _num_kv_heads, + int _kdim, + int _vdim, + float _dropout, + bool _qkv_bias, + bool _final_bias, + bool _add_zero_attn, + bool _apply_rotary_embedding, + bool _scaling_query, + float _scaling_factor, + bool _qk_prod_scaling, + bool _position_bias, + bool allocate_weights, + char const *name) + // Initializer* _bias_initializer) + : Op(model, + OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION, + _input->data_type, + name, + 1 /*inputs*/, + (_qkv_bias || _final_bias ? 2 : 1) /*weights*/, + 1 /*outputs*/, + _input, + _weight), + num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), + qkv_bias(_qkv_bias), final_bias(_final_bias), + add_zero_attn(_add_zero_attn), + apply_rotary_embedding(_apply_rotary_embedding), + qSize(_input->dims[0].size), kSize(_input->dims[0].size), + vSize(_input->dims[0].size), qProjSize(_kdim), kProjSize(_kdim), + vProjSize(_vdim), oProjSize(_embed_dim), + qoSeqLength(_input->dims[1].size), kvSeqLength(_input->dims[1].size), + scaling_query(_scaling_query), scaling_factor(_scaling_factor), + qk_prod_scaling(_qk_prod_scaling), position_bias(_position_bias) +// bias_initializer(_bias_initializer) +{ + numOutputs = 1; + int numdim = _input->num_dims; + ParallelDim dims[MAX_TENSOR_DIM]; + for (int i = 0; i < numdim; i++) { + dims[i] = _input->dims[i]; + } + dims[0].size = _embed_dim; + // Currently require no parallelism along this dim + assert(dims[0].degree == 1); + if (allocate_weights) { + // Create weight tensor + int num_dims = inputs[0]->num_dims; + // Compute weight size + int qParas = this->qProjSize * this->qSize; + int kParas = this->kProjSize * this->kSize; + int vParas = this->vProjSize * this->vSize; + int oParas = + this->oProjSize * (this->vProjSize > 0 ? this->vProjSize : this->vSize); + ParallelDim dims[2]; + dims[0] = inputs[0]->dims[num_dims - 2]; + dims[0].size = dims[0].degree; + dims[1] = inputs[0]->dims[num_dims - 1]; + dims[1].size = this->num_q_heads * (qParas + oParas) + + this->num_q_heads * (kParas + vParas); + dims[1].is_replica_dim = false; + // dims[2].size = qParas + kParas + vParas + oParas; + int seed = std::rand(); + Initializer *initializer = new GlorotUniform(seed); + weights[0] = model.create_parallel_weight<2>(dims, + this->data_type, + NULL /*owner_op*/, + true /*create_grad*/, + initializer, + CHOSEN_SYNC_TYPE); + if (qkv_bias || final_bias) { + ParallelTensorShape bias_shape = _input->get_shape(); + int qkv_bias_size = + qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads; + bias_shape.dims[0].size = + (qkv_bias ? qkv_bias_size : 0) + (final_bias ? 
oProjSize : 0); + bias_shape.dims[1].size = bias_shape.dims[2].size = 1; + weights[1] = + model.create_parallel_weight_legion_ordering(bias_shape.num_dims, + bias_shape.dims, + this->data_type, + nullptr /*owner_op*/, + true /*create_grad*/, + initializer, + CHOSEN_SYNC_TYPE); + } + } + + outputs[0] = model.create_parallel_tensor_legion_ordering( + _input->num_dims, dims, this->data_type, this); + + /* for (int i = 0; i < numdim; i++) { */ + /* register_output_input_parallel_dims(outputs[0], i, inputs[0], i); */ + /* } */ + /* register_output_weight_parallel_dims(outputs[0], numdim-1, _weight, 1); */ + /* register_output_weight_parallel_dims(outputs[0], numdim-2, _weight, 2); */ + // Check correctness + /* assert(check_output_input_weight_parallel_dims()); */ +} + +SpecInferIncMultiHeadSelfAttention::SpecInferIncMultiHeadSelfAttention( + FFModel &model, + SpecInferIncMultiHeadSelfAttention const &other, + ParallelTensor const input, + bool allocate_weights) + : SpecInferIncMultiHeadSelfAttention(model, + other.layer_guid, + input, + other.oProjSize, + other.num_q_heads, + other.num_kv_heads, + other.qProjSize, + other.vProjSize, + other.dropout, + other.qkv_bias, + other.final_bias, + other.add_zero_attn, + other.apply_rotary_embedding, + other.scaling_query, + other.scaling_factor, + other.qk_prod_scaling, + other.position_bias, + allocate_weights, + other.name) {} + +SpecInferIncMultiHeadSelfAttention::SpecInferIncMultiHeadSelfAttention( + FFModel &model, + SpecInferIncMultiHeadSelfAttentionParams const ¶ms, + ParallelTensor const &input, + bool allocate_weights, + char const *name) + : SpecInferIncMultiHeadSelfAttention(model, + params.layer_guid, + input, + params.embed_dim, + params.num_q_heads, + params.num_kv_heads, + params.kdim, + params.vdim, + params.dropout, + params.qkv_bias, + params.final_bias, + params.add_zero_attn, + params.apply_rotary_embedding, + params.scaling_query, + params.scaling_factor, + params.qk_prod_scaling, + params.position_bias, + allocate_weights, + name) {} + +void SpecInferIncMultiHeadSelfAttention::init_inference( + FFModel const &ff, + std::vector const &batch_inputs, + std::vector const &batch_outputs, + MachineView const *mv) { + assert(check_output_input_weight_same_parallel_is()); + parallel_is = batch_outputs[0]->parallel_is; + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + MachineView const *view = mv ? 
mv : &batch_outputs[0]->machine_view; + size_t machine_view_hash = view->hash(); + set_argumentmap_for_init_inference(ff, argmap, batch_outputs[0]); + IndexLauncher launcher( + SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_INIT_TASK_ID, + parallel_is, + TaskArgument(this, sizeof(SpecInferIncMultiHeadSelfAttention)), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + batch_inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(weights[0]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + weights[0]->region)); + launcher.add_field(1, FID_DATA); + launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[0]->region)); + launcher.add_field(2, FID_DATA); + FutureMap fm = runtime->execute_index_space(ctx, launcher); + fm.wait_all_results(); + set_opmeta_from_futuremap_inference(ff, fm, batch_outputs[0]); +} + +void SpecInferIncMultiHeadSelfAttention::init(FFModel const &ff) { + assert(check_output_input_weight_same_parallel_is()); + parallel_is = outputs[0]->parallel_is; + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + set_argumentmap_for_init(ff, argmap); + IndexLauncher launcher( + SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_INIT_TASK_ID, + parallel_is, + TaskArgument(this, sizeof(SpecInferIncMultiHeadSelfAttention)), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + outputs[0]->machine_view.hash()); + launcher.add_region_requirement(RegionRequirement(inputs[0]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(weights[0]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + weights[0]->region)); + launcher.add_field(1, FID_DATA); + launcher.add_region_requirement(RegionRequirement(outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + outputs[0]->region)); + launcher.add_field(2, FID_DATA); + FutureMap fm = runtime->execute_index_space(ctx, launcher); + fm.wait_all_results(); + set_opmeta_from_futuremap(ff, fm); +} + +/* + regions[0](I): input + regions[1](I): weight + regions[2](O): output +*/ +OpMeta *SpecInferIncMultiHeadSelfAttention::init_task( + Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + SpecInferIncMultiHeadSelfAttention const *attn = + (SpecInferIncMultiHeadSelfAttention *)task->args; + FFHandler handle = *((FFHandler const *)task->local_args); + + GenericTensorAccessorR input = + helperGetGenericTensorAccessorRO(attn->inputs[0]->data_type, + regions[0], + task->regions[0], + FID_DATA, + ctx, + runtime); + GenericTensorAccessorR weight = + helperGetGenericTensorAccessorRO(attn->weights[0]->data_type, + regions[1], + task->regions[1], + FID_DATA, + ctx, + runtime); + GenericTensorAccessorW output = + helperGetGenericTensorAccessorWO(attn->outputs[0]->data_type, + regions[2], + task->regions[2], + FID_DATA, + ctx, + runtime); + + int num_samples = input.domain.hi()[2] - input.domain.lo()[2] + 1; + assert(attn->qoSeqLength == input.domain.hi()[1] - input.domain.lo()[1] + 1); + assert(attn->kvSeqLength == input.domain.hi()[1] - input.domain.lo()[1] + 1); + int num_q_heads = attn->num_q_heads; + int num_kv_heads = attn->num_kv_heads; + 
assert(attn->oProjSize == output.domain.hi()[0] - output.domain.lo()[0] + 1); + + Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) + .only_kind(Memory::GPU_FB_MEM) + .best_affinity_to(task->target_proc) + .first(); + MemoryAllocator gpu_mem_allocator(gpu_mem); + // We don't do offloading for SSMs (small speculative models) + SpecInferIncMultiHeadSelfAttentionMeta *m = + new SpecInferIncMultiHeadSelfAttentionMeta(handle, + attn, + weight, + gpu_mem_allocator, + num_samples, + num_q_heads, + num_kv_heads); + // assert that we didn't over allocate memory + assert(gpu_mem_allocator.instance_allocated_size == + gpu_mem_allocator.instance_total_size); + m->profiling = attn->profiling; + m->inference_debugging = attn->inference_debugging; + std::strcpy(m->op_name, attn->name); + m->layer_guid = attn->layer_guid; + assert(weight.domain.get_volume() * data_type_size(weight.data_type) == + m->weightSize); + return m; +} + +void SpecInferIncMultiHeadSelfAttention::forward(FFModel const &ff) { + // SpecInferIncMultiHeadSelfAttention doesn't support forward + assert(false); +} + +FutureMap SpecInferIncMultiHeadSelfAttention::inference( + FFModel const &ff, + BatchConfigFuture const &bc, + std::vector const &batch_inputs, + std::vector const &batch_outputs, + MachineView const *mv) { + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + parallel_is = batch_outputs[0]->parallel_is; + MachineView const *view = mv ? mv : &batch_outputs[0]->machine_view; + set_argumentmap_for_inference(ff, argmap, batch_outputs[0]); + size_t machine_view_hash = view->hash(); + int idx = 0; + IndexLauncher launcher(SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_INF_TASK_ID, + parallel_is, + TaskArgument(nullptr, 0), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + launcher.add_future(bc); + launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + batch_inputs[0]->region)); + launcher.add_field(idx++, FID_DATA); + launcher.add_region_requirement(RegionRequirement(weights[0]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + weights[0]->region)); + launcher.add_field(idx++, FID_DATA); + launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[0]->region)); + launcher.add_field(idx++, FID_DATA); + + if (qkv_bias || final_bias) { + launcher.add_region_requirement(RegionRequirement(weights[1]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + weights[1]->region)); + launcher.add_field(idx++, FID_DATA); + } + return runtime->execute_index_space(ctx, launcher); +} + +/* + regions[0](I): input + regions[3](I): weight + regions[4](O): output +*/ +void SpecInferIncMultiHeadSelfAttention::inference_task( + Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + assert(task->regions.size() == regions.size()); + + BeamSearchBatchConfig const &bc = + Future(task->futures[0]).get_result(); + if (bc.num_tokens == 0) { + return; + } + + SpecInferIncMultiHeadSelfAttentionMeta *m = + *((SpecInferIncMultiHeadSelfAttentionMeta **)task->local_args); + assert(((*m->qkv_bias || *m->final_bias) ? 
regions.size() == 4 + : regions.size() == 3)); + + GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( + m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorR weight = helperGetGenericTensorAccessorRO( + m->weight_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); + GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( + m->output_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); + GenericTensorAccessorR biases; + if (*m->qkv_bias || *m->final_bias) { + biases = helperGetGenericTensorAccessorRO(m->weight_type[1], + regions[3], + task->regions[3], + FID_DATA, + ctx, + runtime); + Domain bias_domain = runtime->get_index_space_domain( + ctx, task->regions[3].region.get_index_space()); + assert(bias_domain.get_dim() == 4); + } + Domain input_domain = runtime->get_index_space_domain( + ctx, task->regions[0].region.get_index_space()); + Domain weight_domain = runtime->get_index_space_domain( + ctx, task->regions[1].region.get_index_space()); + Domain output_domain = runtime->get_index_space_domain( + ctx, task->regions[2].region.get_index_space()); + + assert(input_domain.get_dim() == 4); + assert(weight_domain.get_dim() == 2); + assert(output_domain.get_dim() == 4); + + assert(task->index_point.get_dim() == 1); + SpecInferIncMultiHeadSelfAttention::inference_kernel_wrapper( + m, &bc, task->index_point.point_data[0], input, weight, output, biases); + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + std::vector weights_accessors; + weights_accessors.push_back(weight); + if (*m->qkv_bias || *m->final_bias) { + weights_accessors.push_back(biases); + } + SpecInferIncMultiHeadSelfAttention::save_inference_tensors_to_file( + m, shard_id, &bc, {input}, weights_accessors, {output}); + } +} + +void SpecInferIncMultiHeadSelfAttention::backward(FFModel const &ff) { + // SpecInferIncMultiHeadSelfAttention does not support backward + assert(false); +} + +bool SpecInferIncMultiHeadSelfAttention::get_int_parameter(PMParameter para, + int *value) const { + switch (para) { + case PM_NUM_HEADS: + *value = num_q_heads; + return true; + default: + return Op::get_int_parameter(para, value); + } +} + +Op *SpecInferIncMultiHeadSelfAttention::materialize(FFModel &ff, + ParallelTensor inputs[], + int num_inputs) const { + SpecInferIncMultiHeadSelfAttentionParams params = get_params(); + return new SpecInferIncMultiHeadSelfAttention( + ff, params, inputs[0], true, this->name); +} + +bool SpecInferIncMultiHeadSelfAttention::measure_operator_cost( + Simulator *sim, MachineView const &mv, CostMetrics &cost_metrics) const { + return false; +} + +bool operator==(SpecInferIncMultiHeadSelfAttentionParams const &lhs, + SpecInferIncMultiHeadSelfAttentionParams const &rhs) { + return lhs.layer_guid == rhs.layer_guid && lhs.embed_dim == rhs.embed_dim && + lhs.num_q_heads == rhs.num_q_heads && lhs.kdim == rhs.kdim && + lhs.vdim == rhs.vdim && lhs.dropout == rhs.dropout && + lhs.qkv_bias == rhs.qkv_bias && lhs.final_bias == rhs.final_bias && + lhs.add_zero_attn == rhs.add_zero_attn && + lhs.apply_rotary_embedding == rhs.apply_rotary_embedding && + lhs.scaling_query == rhs.scaling_query && + lhs.scaling_factor == rhs.scaling_factor && + lhs.qk_prod_scaling == rhs.qk_prod_scaling && + lhs.position_bias == rhs.position_bias; +} + +SpecInferIncMultiHeadSelfAttentionParams + SpecInferIncMultiHeadSelfAttention::get_params() const { + 
SpecInferIncMultiHeadSelfAttentionParams params; + params.layer_guid = this->layer_guid; + params.embed_dim = this->oProjSize; + params.num_q_heads = this->num_q_heads; + params.num_kv_heads = this->num_kv_heads; + params.kdim = this->kProjSize; + params.vdim = this->vProjSize; + params.dropout = this->dropout; + params.qkv_bias = this->qkv_bias; + params.final_bias = this->final_bias; + params.add_zero_attn = this->add_zero_attn; + params.apply_rotary_embedding = this->apply_rotary_embedding; + params.scaling_query = this->scaling_query; + params.scaling_factor = this->scaling_factor; + params.qk_prod_scaling = this->qk_prod_scaling; + params.position_bias = this->position_bias; + + return params; +} + +}; // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::SpecInferIncMultiHeadSelfAttentionParams const ¶ms) const { + size_t key = 0; + hash_combine(key, params.layer_guid.id); + hash_combine(key, params.embed_dim); + hash_combine(key, params.num_q_heads); + hash_combine(key, params.num_kv_heads); + hash_combine(key, params.kdim); + hash_combine(key, params.vdim); + hash_combine(key, params.dropout); + hash_combine(key, params.qkv_bias); + hash_combine(key, params.final_bias); + hash_combine(key, params.add_zero_attn); + hash_combine(key, params.apply_rotary_embedding); + hash_combine(key, params.scaling_query); + hash_combine(key, params.scaling_factor); + hash_combine(key, params.qk_prod_scaling); + hash_combine(key, params.position_bias); + return key; +} +}; // namespace std diff --git a/src/ops/specinfer_inc_multihead_self_attention.cu b/src/ops/specinfer_inc_multihead_self_attention.cu new file mode 100644 index 0000000000..0bdf07a9d7 --- /dev/null +++ b/src/ops/specinfer_inc_multihead_self_attention.cu @@ -0,0 +1,890 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) +#include "cuComplex.h" +#endif +#include "flexflow/ffconst_utils.h" +#include "flexflow/ops/kernels/inc_multihead_self_attention_kernels.h" +#include "flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh" +#include "flexflow/ops/specinfer_inc_multihead_self_attention.h" +#include "flexflow/utils/cuda_helper.h" + +namespace FlexFlow { + +#define WARP_SIZE 32 + +// declare Legion names +using Legion::coord_t; +using Legion::Memory; +using namespace Kernels::IncMultiHeadAttention; + +namespace Kernels { +namespace SpecInferIncMultiHeadAttention { + +template +__global__ void compute_specinfer_attention_kernel_generation_kernel( + DT const *query, + DT const *key_cache, + DT const *value_cache, + DT *output_ptr, + float const scale, + int const max_seq_length, + int per_head_size, + int hidden_size, + BatchConfig::PerRequestInfo *request_infos, + BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos, + BeamSearchBatchConfig::SpecInferTopology *topology_mask, + int max_tree_branches) { + + // q, k + using Q_vec = typename VEC_K::Type; + using K_vec = typename VEC_K::Type; + using V_vec = typename VEC_V
::Type; + using Out_sum = typename Vec_fp32_::Type; + + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(DT); + constexpr int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY; + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + // constexpr int QK_ELTS_IN_16B = 16 / sizeof(DT); + + // thread id + int const tidx = threadIdx.x; + // head id + int const head_idx = blockIdx.x; + // request idx + int const request_idx = blockIdx.y; + + BeamSearchBatchConfig::SpecInferTopology topology = + topology_mask[request_idx]; + + int const first_step = 0; + + int const tlength = request_infos[request_idx].first_token_depth_in_request + + request_infos[request_idx].num_tokens_in_batch; + // int const qlength = request_infos[request_idx].num_tokens_in_batch; + int const tree_branch_num = beam_request_infos[request_idx].sub_request_num; + + // will decode qlength tokens in this thread block + // int const qlength = tree_branch_num; + + int first_token_idx = 0; + for (int r = 0; r < request_idx; r++) { + first_token_idx += request_infos[request_idx].num_tokens_in_batch; + } + + // shared memory objects + extern __shared__ char smem_[]; + + float *qk_smem = reinterpret_cast(smem_); + float *out_smem = reinterpret_cast(smem_); + + float qk_max = -FLT_MAX; + + // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + + const DT *q_ptr = query + first_token_idx * hidden_size * QKV_WEIGHT_NUM + + head_idx * per_head_size; + __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD]; + + // the start offset of the element eg. (0, 1, 2, 3) * K_VEC_SIZE + int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; + int ki_o = tidx % THREADS_PER_KEY; + // the first key's offset for this thread + // ko = 0, 0, 0, 0, 1, 1, 1, 1, .... + int ko = tidx / THREADS_PER_KEY; + // load q tensor + Q_vec q_vec[K_VECS_PER_THREAD]; + + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + // The number of keys per warp. + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + DT const *k_cache_batch = + key_cache + + request_idx * max_seq_length * hidden_size * max_tree_branches + ki; + + int ti_end = + div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; + + for (int sub_req_idx = 0; sub_req_idx < tree_branch_num; sub_req_idx += 1) { +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + q_vecs[ki_o][ii] = *reinterpret_cast( + q_ptr + (hidden_size * QKV_WEIGHT_NUM * sub_req_idx) + ki + + ii * THREADS_PER_KEY * K_VEC_SIZE); + } + __syncthreads(); + for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { + K_vec k[K_VECS_PER_THREAD]; + int const ti_circ = ti % max_seq_length; + + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * THREADS_PER_KEY * K_VEC_SIZE; + if (ti < tlength) { + // find the real position of the cache; + // depth: 0, 1, 2, 3, 4, 4, 5, 5 ,5, 5, + int const real_cache_idx = topology.real_token_pos[sub_req_idx][ti]; + k[ii] = *reinterpret_cast( + k_cache_batch + real_cache_idx * hidden_size + + head_idx * per_head_size + jj); + } + } + float qk = scale * Qk_dot::dot(q_vecs[ki_o], k); + + if (ti < tlength && tidx % THREADS_PER_KEY == 0) { + // todo add alobi here + bool const mask = ti_circ >= tlength; + if (mask) { + assert(false); + } + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + qk_smem[ti - first_step] = mask ? 
0.f : qk; + } + } + + __syncthreads(); + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Decompose the thread index into warp and lane. + int const warp = tidx / WARP_SIZE; + int const lane = tidx % WARP_SIZE; + + // The warp leader writes the max to shared memory. + if (lane == 0) { + red_smem[warp] = qk_max; + } + + // Make sure the products are in shared memory. + __syncthreads(); + + // The warps finalize the reduction. + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Broadcast to all the threads in the warp. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + float exp_sum = 0.f; + for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { + float logit = __expf(qk_smem[ti - first_step] - qk_max); + exp_sum += logit; + qk_smem[ti - first_step] = logit; + } + + // Compute the sum. + exp_sum = block_sum(&red_smem[WARPS_PER_BLOCK], exp_sum); + + // softmax + float inv_sum = __fdividef(1.f, exp_sum + 1.e-6); + for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { + qk_smem[ti - first_step] *= inv_sum; + } + + __syncthreads(); + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + // printf("softmax %.10f\n", qk_smem[0]); + // } + + // value projection + constexpr int V_VEC_SIZE = 16 / sizeof(DT); + // A vector of V elements for the current timestep. + // using V_vec_k = typename V_vec_k_::Type; + // using V_vec_acum = typename V_vec_acum_fp32_::Type; + + // The value computed by this thread. + int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + + Out_sum out; + zero(out); + + // The base pointer for the value in the cache buffer. + DT const *v_cache_batch = + value_cache + + request_idx * max_seq_length * hidden_size * max_tree_branches + vi; + // DT const *v_cache_batch = + // value_cache + + // (beam_request_idx * max_beam_width + beam_sub_request_idx) * + // max_seq_length * hidden_size + + // vi; + + if (Dh == Dh_MAX || vi < Dh) { + for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { + // Load the values from the cache. + int const ti_circ = ti % max_seq_length; + int const real_cache_idx = topology.real_token_pos[sub_req_idx][ti]; + V_vec v = *reinterpret_cast( + v_cache_batch + real_cache_idx * hidden_size + + head_idx * per_head_size); + float logit = qk_smem[ti - first_step]; + out = FlexFlow::fma(logit, cast_to_float(v), out); + } + } + + // // Make sure we can start writing to shared memory. + __syncthreads(); + + // Run the final reduction amongst the different groups computing different + // partial outputs. + if (Dh == Dh_MAX || vi < Dh) { +#pragma unroll + for (int active_groups = V_PER_ITER; active_groups >= 2; + active_groups /= 2) { + + // The midpoint in the number of active groups. + int midpoint = active_groups / 2; + + // The upper part of active threads store to shared memory. + if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { + *reinterpret_cast(out_smem + (vo - midpoint) * Dh + vi) = + out; + } + __syncthreads(); + + // The bottom warps update their values. 
+ if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { + out = add(*reinterpret_cast(out_smem + vo * Dh + vi), + out); + } + __syncthreads(); + } + } + + // Output the final values. + if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { + convert_from_float( + *reinterpret_cast(output_ptr + request_idx * hidden_size + + head_idx * per_head_size + vi), + out); + } + } +} + +template +__global__ void specinfer_store_kv_cache( + DT const *devQKVProjArray, + DT *kCache_ptr, + DT *vCache_ptr, + BatchConfig::PerTokenInfo *tokenInfos, + BatchConfig::PerRequestInfo *requestInfo, + BeamSearchBatchConfig::BeamSearchPerTokenInfo *beamTokenInfos, + BeamSearchBatchConfig::BeamSearchPerRequestInfo *beamRequestInfos, + BeamSearchBatchConfig::SpecInferTopology *beam_topology_mask, + int qProjSize, + int kProjSize, + int vProjSize, + int num_tokens, + int max_seq_len, + int max_tree_branches, + bool is_root, + int hidden_size) { + CUDA_KERNEL_LOOP(i, num_tokens * hidden_size * 2) { + int token_idx = i / (hidden_size * KV_WEIGHT_NUM); + int offset = i % hidden_size; + + size_t val_idx = + token_idx * QKV_WEIGHT_NUM * hidden_size + hidden_size + offset; + + DT kVal = devQKVProjArray[val_idx]; + DT vVal = devQKVProjArray[val_idx + hidden_size]; + + // above no need to be changed + // int const req_id = id_map[token_idx].request_index; + // int const tok_id = id_map[token_idx].token_position; + // int const sub_req_id = id_map[token_idx].sub_request_index; + // int const parent_id = id_map[token_idx].parent_id; + // int const beam_depth = id_map[token_idx].beam_depth; + // int const beam_width = id_map[token_idx].beam_width; + + int const req_id = tokenInfos[token_idx].request_index; + int const tok_id = tokenInfos[token_idx].abs_depth_in_request; + int const sub_req_id = beamTokenInfos[token_idx].sub_request_index; + // int const parent_id = beamRequestInfos[req_id].parent_id[sub_req_id]; + // int const beam_depth = beamRequestInfos[req_id].current_depth; + // int const beam_width = beamRequestInfos[req_id].beam_size; + int const allocated_tokens = beam_topology_mask[req_id].allocated_tokens; + + kCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + + (allocated_tokens + sub_req_id) * hidden_size + offset] = kVal; + vCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + + (allocated_tokens + sub_req_id) * hidden_size + offset] = vVal; + } +} + +template +void update_kv_cache_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, + BeamSearchBatchConfig const *bc, + cudaStream_t stream) { + int num_tokens = bc->num_active_tokens(); + int curr_depth = bc->beamRequestsInfo[0].current_depth; + // printf("curr depth: %d\n", curr_depth); + // assert(curr_depth < 3); + if (num_tokens > 0) { + int parallelism = m->hidden_size * KV_WEIGHT_NUM * num_tokens; + specinfer_store_kv_cache<<>>( + static_cast
<DT *>(m->devQKVProjArray),
+        static_cast<DT *>
(m->keyCache),
+        static_cast<DT *>
(m->valueCache), + m->token_infos, + m->request_infos, + m->beam_token_infos, + m->beam_request_infos, + m->beam_topology_mask, + m->qProjSize, + m->kProjSize, + m->vProjSize, + num_tokens, + BatchConfig::max_sequence_length(), + BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES, + /*root*/ curr_depth == 0, + m->hidden_size); + } +} + +#define LAUNCH_SPECINFER_ATTENTION_SCORE_KERNEL( \ + DT, Dh, Dh_MAX, THDS_PER_KEY, THREADS_PER_VALUE, THDS_PER_BLOCK, stream) \ + smem_sz = smem_size_in_bytes
<DT>(m->qProjSize,                                                     \
+                                  BatchConfig::max_sequence_length(),        \
+                                  THREADS_PER_VALUE,                         \
+                                  THDS_PER_BLOCK);                           \
+  compute_specinfer_attention_kernel_generation_kernel<DT,                   \
+                                                       THDS_PER_BLOCK,       \
+                                                       Dh,                   \
+                                                       Dh_MAX,               \
+                                                       THDS_PER_KEY,         \
+                                                       THREADS_PER_VALUE>    \
+      <<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(                           \
+          static_cast<DT *>
(m->devQKVProjArray),                                       \
+          static_cast<DT *>
(m->keyCache),                                              \
+          static_cast<DT *>
(m->valueCache), \ + output_ptr, \ + scale, \ + BatchConfig::max_sequence_length(), \ + m->qProjSize, \ + m->hidden_size, \ + m->request_infos, \ + m->beam_request_infos, \ + m->beam_topology_mask, \ + BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES) + +template +void compute_specinfer_attention_kernel_generation( + SpecInferIncMultiHeadSelfAttentionMeta const *m, + BeamSearchBatchConfig const *bc, + DT *output_ptr, + cudaStream_t stream) { + // one block == one head per request + dim3 grid(m->num_q_heads, bc->num_active_requests()); + int const per_head_size = m->qProjSize; + float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f; + size_t smem_sz; + if (per_head_size == 64) { + constexpr int THREADS_PER_VALUE_64 = threads_per_value_t::value; + LAUNCH_SPECINFER_ATTENTION_SCORE_KERNEL( + DT, 64, 64, 4, THREADS_PER_VALUE_64, 128, stream); + } else if (per_head_size == 128) { + constexpr int THREADS_PER_VALUE_128 = threads_per_value_t::value; + LAUNCH_SPECINFER_ATTENTION_SCORE_KERNEL( + DT, 128, 128, 4, THREADS_PER_VALUE_128, 128, stream); + } else { + assert(false && "a unsupported head size"); + } +} + +template +__global__ void spec_fill_entries_above_diagonal(DT *matrix, + size_t new_tokens, + size_t total_tokens_in_request, + size_t num_q_heads, + DT value) { + CUDA_KERNEL_LOOP(i, new_tokens * total_tokens_in_request * num_q_heads) { + // size_t head_idx = i / (new_tokens * total_tokens_in_request); + size_t src_idx = (i / new_tokens) % total_tokens_in_request; + size_t dst_idx = i % new_tokens + total_tokens_in_request - new_tokens; + // Casual Mask + if (src_idx > dst_idx) { + matrix[i] = value; + } + } +} + +template +void compute_attention_kernel_prompt( + SpecInferIncMultiHeadSelfAttentionMeta const *m, + BeamSearchBatchConfig const *bc, + int shard_id, + DT *output_ptr, + DT const *bias_ptr, + DT const *weight_ptr, + cudaStream_t stream) { + checkCUDA(cublasSetStream(m->handle.blas, stream)); + checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); + cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); + cudnnDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]); + assert(data_type_size(m->output_type[0]) == sizeof(DT)); +#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) + cudaDataType_t compute_type = cublas_data_type; +#else + // For best performance, set the default cublas compute type to + // CUBLAS_COMPUTE_16F for half precision and to + // CUBLAS_COMPUTE_32F_FAST_16F for full precision + cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; + if (m->output_type[0] == DT_FLOAT) { + compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + } +#endif + // int num_requests = bc->num_active_requests(); + int num_tokens = bc->num_active_tokens(); + int tokens_previous_requests = 0; + int tokens_prev_requests_squares = 0; + // int qkv_block_size = + // (m->qProjSize + m->kProjSize + m->vProjSize) * num_tokens; + int q_block_size = m->qProjSize; + + int kt_block_size = m->kProjSize; + int kt_req_block_size = + kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + int vt_block_size = m->vProjSize; + int vt_req_block_size = + vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + assert(m->qProjSize == m->kProjSize); + + for (int i = 0; i < bc->max_requests_per_batch(); i++) { + if (bc->request_completed[i]) { + continue; + } else if (tokens_previous_requests < bc->num_generation_tokens) { + tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; + continue; + } + + // all requests in prompt 
phase should only have one sub request.
+    assert(bc->sub_requests[i] == 1);
+    // int num_new_tokens = bc->num_processing_tokens[i];
+    // int total_tokens = bc->token_last_available_idx[i] + 1;
+
+    int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch;
+    int total_tokens = bc->requestsInfo[i].first_token_depth_in_request +
+                       bc->requestsInfo[i].num_tokens_in_batch;
+
+    if (num_new_tokens <= 0) {
+      continue;
+    }
+
+    // Compute (QK^T/sqrt(d_k))
+    int m_ = num_new_tokens;
+    int n = total_tokens;
+    int k = m->qProjSize;
+    int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads,
+        ldc = m_;
+    int strideA = q_block_size;
+    int strideB = kt_block_size;
+    int strideC = num_new_tokens * total_tokens;
+
+    // scale QK^T by 1/sqrt(d_k) only when qk_prod_scaling is enabled
+    DT alpha = 1.0f, beta = 0.0f;
+    if (*m->qk_prod_scaling) {
+      alpha = static_cast<DT>
(1.0f / sqrt(m->kProjSize));
+    }
+    // To get A, skip over Q entries from previous requests (same head)
+    DT const *A = static_cast<DT *>
(m->devQKVProjArray) +
+                  bc->requestsInfo[i].first_token_offset_in_batch *
+                      m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM;
+    // To get B, skip over K entries from previous requests (all heads +
+    // padding)
+    DT const *B = static_cast<DT *>
(m->keyCache) +
+                  (i * bc->MAX_SPECULATIVE_TREE_BRANCHES) * kt_req_block_size;
+
+    // if (i == 0 && sub_req_id == 0 &&
+    //     bc->beam_slots.at(0).current_depth == 1) {
+    //   int offset = (float *)B - m->keyCache;
+    //   printf("key cache offset %d\n", kt_req_block_size);
+    // }
+    // To get C, skip over QK^T products from previous requests
+    DT *C = static_cast<DT *>
(m->qk_prods) + + m->num_q_heads * tokens_prev_requests_squares; + checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, + CUBLAS_OP_T, + CUBLAS_OP_N, + m_, + n, + k, + &alpha, + A, + cublas_data_type, + lda, + strideA, + B, + cublas_data_type, + ldb, + strideB, + &beta, + C, + cublas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + // add alibi position bias to qk production + // add alibi position bias to qk production + if (*m->position_bias) { + size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; + apply_position_bias_qkprd<<>>(C, + num_new_tokens, + total_tokens, + m->num_q_heads, + m->global_num_q_heads, + shard_id); + } + // Fill all elements above diagonal in qk prods with -inf to force + // causal attention. + assert(num_new_tokens <= total_tokens); + if (num_new_tokens > 1) { + size_t parallelism = m->num_q_heads * num_new_tokens * total_tokens; + spec_fill_entries_above_diagonal<<>>(C, + num_new_tokens, + total_tokens, + m->num_q_heads, + static_cast
<DT>(-INFINITY));
+    }
+    // Compute Softmax(QK^T/sqrt(d_k))
+    // Before modifying the parameters below, make sure to read the following
+    // description of the CUDNN_TENSOR_NCHW tensor layout, from
+    // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t:
+    // This tensor format specifies that the data is laid out in the following
+    // order: batch size, feature maps, rows, columns. The strides are
+    // implicitly defined in such a way that the data are contiguous in memory
+    // with no padding between images, feature maps, rows, and columns; the
+    // columns are the inner dimension and the images are the outermost
+    // dimension.
+    int n_param = m->num_q_heads;
+    int c_param = total_tokens;
+    int h_param = 1;
+    int w_param = num_new_tokens;
+    checkCUDNN(cudnnSetTensor4dDescriptor(m->qk_tensor,
+                                          CUDNN_TENSOR_NCHW,
+                                          cudnn_data_type,
+                                          n_param,
+                                          c_param,
+                                          h_param,
+                                          w_param));
+    float softmax_alpha = 1.0f, softmax_beta = 0.0f;
+    DT *C_softmax = static_cast<DT *>
(m->qk_prods_softmax) + + m->num_q_heads * tokens_prev_requests_squares; + // The softmax operation below is executed according to the + // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The + // softmax operation is computed per spatial location (H,W) per image (N) + // across dimension C. + checkCUDNN(cudnnSoftmaxForward(m->handle.dnn, + CUDNN_SOFTMAX_ACCURATE, + CUDNN_SOFTMAX_MODE_CHANNEL, + &softmax_alpha, + m->qk_tensor, + C, + &softmax_beta, + m->qk_tensor, + C_softmax)); + // Matmul softmax(QK^T/sqrt(d_k)) by V + alpha = 1.0f, beta = 0.0f; + m_ = m->vProjSize; + n = num_new_tokens; + k = total_tokens; + lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads; + strideA = vt_block_size; + strideB = num_new_tokens * total_tokens; + strideC = m->vProjSize; + // To get A, skip over V^T entries from previous requests (all heads + + // padding) + A = static_cast
<DT *>(m->valueCache) +
+        (i * bc->MAX_SPECULATIVE_TREE_BRANCHES) * vt_req_block_size;
+    // To get B, skip over softmax(QK^T/sqrt(d_k)) entries from previous
+    // requests (all heads)
+    B = C_softmax;
+    // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous
+    // requests
+    C = static_cast<DT *>
(m->attn_heads) + + (tokens_previous_requests + bc->num_generation_tokens) * + m->num_q_heads * m->vProjSize; + checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, + CUBLAS_OP_N, + CUBLAS_OP_T, + m_, + n, + k, + &alpha, + A, + cublas_data_type, + lda, + strideA, + B, + cublas_data_type, + ldb, + strideB, + &beta, + C, + cublas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + tokens_previous_requests += num_new_tokens; + tokens_prev_requests_squares += num_new_tokens * total_tokens; + } + + // assert(tokens_previous_requests == num_tokens); +} + +template +void inference_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, + BeamSearchBatchConfig const *bc, + int shard_id, + DT const *input_ptr, + DT const *weight_ptr, + DT *output_ptr, + DT const *bias_ptr, + cudaStream_t stream) { + // phase 1: Implement kernel to compute KQV for input tokens + compute_qkv_kernel(m, + bc, + shard_id, + input_ptr, + weight_ptr, + static_cast
<DT *>(m->devQKVProjArray),
+                     bias_ptr,
+                     stream);
+  // phase 2: Update key/val cache
+  update_kv_cache_kernel<DT>
(m, bc, stream);
+  if (bc->num_generation_tokens > 0) {
+    compute_specinfer_attention_kernel_generation<DT>
(
+        m, bc, static_cast<DT *>
(m->attn_heads), stream); + } + // phase 3: Compute attention score + // 3 kernels for pahse 3: matmul1 - softmax - matmal2 + if (bc->num_tokens > bc->num_generation_tokens) { + compute_attention_kernel_prompt( + m, bc, shard_id, output_ptr, bias_ptr, weight_ptr, stream); + } + + // compute output production and bias together for all tokens + int num_tokens = bc->num_active_tokens(); + + compute_o_prod_bias( + m, bc, shard_id, output_ptr, weight_ptr, bias_ptr, num_tokens, stream); +} + +} // namespace SpecInferIncMultiHeadAttention +} // namespace Kernels + +/*static*/ +void SpecInferIncMultiHeadSelfAttention::inference_kernel_wrapper( + SpecInferIncMultiHeadSelfAttentionMeta const *m, + BeamSearchBatchConfig const *bc, + int shard_id, + GenericTensorAccessorR const &input, + GenericTensorAccessorR const &weight, + GenericTensorAccessorW const &output, + GenericTensorAccessorR const &bias) { + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); + bool use_bias = *m->qkv_bias || *m->final_bias; + + cudaEvent_t t_start, t_end; + if (m->profiling) { + cudaEventCreate(&t_start); + cudaEventCreate(&t_end); + cudaEventRecord(t_start, stream); + } + + assert(input.data_type == weight.data_type); + assert(input.data_type == output.data_type); + if (use_bias) { + assert(input.data_type == bias.data_type); + } + + if (input.data_type == DT_HALF) { + half const *bias_ptr = + use_bias ? bias.get_half_ptr() : static_cast(nullptr); + Kernels::SpecInferIncMultiHeadAttention::inference_kernel( + m, + bc, + shard_id, + input.get_half_ptr(), + weight.get_half_ptr(), + output.get_half_ptr(), + bias_ptr, + stream); + } else if (input.data_type == DT_FLOAT) { + float const *bias_ptr = + use_bias ? bias.get_float_ptr() : static_cast(nullptr); + Kernels::SpecInferIncMultiHeadAttention::inference_kernel( + m, + bc, + shard_id, + input.get_float_ptr(), + weight.get_float_ptr(), + output.get_float_ptr(), + bias_ptr, + stream); + } else { + assert(false && "Unspported data type"); + } + + if (m->profiling) { + cudaEventRecord(t_end, stream); + checkCUDA(cudaEventSynchronize(t_end)); + float elapsed = 0; + checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); + cudaEventDestroy(t_start); + cudaEventDestroy(t_end); + printf("SpecInferIncMultiHeadSelfAttention forward time = %.2fms\n", + elapsed); + // print_tensor<3, float>(acc_query.ptr, acc_query.rect, + // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr, + // acc_output.rect, "[Attention:forward:output]"); + } +} + +SpecInferIncMultiHeadSelfAttentionMeta::SpecInferIncMultiHeadSelfAttentionMeta( + FFHandler handler, + SpecInferIncMultiHeadSelfAttention const *attn, + GenericTensorAccessorR const &weight, + MemoryAllocator &gpu_mem_allocator, + int num_samples, + int _num_q_heads, + int _num_kv_heads) + : IncMultiHeadSelfAttentionMeta(handler, + BEAM_SEARCH_MODE, + attn, + attn->qSize, + attn->kSize, + attn->vSize, + attn->qProjSize, + attn->kProjSize, + attn->vProjSize, + attn->oProjSize, + attn->apply_rotary_embedding, + attn->qkv_bias, + attn->scaling_query, + attn->qk_prod_scaling, + attn->position_bias, + attn->final_bias, + attn->scaling_factor, + weight, + gpu_mem_allocator, + num_samples, + attn->num_q_heads, + attn->num_kv_heads, + _num_q_heads, + _num_kv_heads, + DT_NONE, + false) { + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); + checkCUDNN(cudnnSetStream(handler.dnn, stream)); + + // allocate memory for the seqArray and reserve space + { + // int max_tokens_per_batch = 
BatchConfig::max_tokens_per_batch(); + // size_t beam_tokeninfo_size = + // max_tokens_per_batch * BeamSearchBatchConfig::MAX_BEAM_WIDTH; + // size_t requestinfo_size = + // BeamSearchBatchConfig::max_requests_per_batch(); size_t + // beam_requestinfo_size = + // BeamSearchBatchConfig::max_requests_per_batch(); + // size_t total_size = + // beam_tokeninfo_size * + // sizeof(BeamSearchBatchConfig::BeamSearchPerTokenInfo) + + // beam_requestinfo_size * + // sizeof(BeamSearchBatchConfig:: + // BeamSearchPerRequestInfo); // more components will + // // be added here later + + // We always directly allocate memory for small speculative models + // gpu_mem_allocator.create_legion_instance(beam_search_reserve_inst, + // total_size); + beam_topology_mask = + static_cast( + handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + + sizeof(BatchConfig::requestsInfo)); + + beam_token_infos = + static_cast( + handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + + sizeof(BatchConfig::requestsInfo) + + sizeof(BeamSearchBatchConfig::topology_mask)); + + beam_request_infos = + static_cast( + handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + + sizeof(BatchConfig::requestsInfo) + + sizeof(BeamSearchBatchConfig::topology_mask) + + sizeof(BeamSearchBatchConfig::beamTokenInfo)); + // beam_token_infos = + // gpu_mem_allocator + // .allocate_instance( + // beam_tokeninfo_size); + // offset += beam_tokeninfo_size * + // sizeof(BeamSearchBatchConfig::BeamSearchPerTokenInfo); + // beam_request_infos = + // gpu_mem_allocator + // .allocate_instance( + // beam_requestinfo_size); + // offset += beam_requestinfo_size * + // sizeof(BeamSearchBatchConfig::BeamSearchPerRequestInfo); + // assert(offset == total_size); + // assert(gpu_mem_allocator.instance_total_size == + // gpu_mem_allocator.instance_allocated_size); + } + + cudaStreamSynchronize(stream); +} + +SpecInferIncMultiHeadSelfAttentionMeta::~SpecInferIncMultiHeadSelfAttentionMeta( + void) { + if (beam_search_reserve_inst != Realm::RegionInstance::NO_INST) { + beam_search_reserve_inst.destroy(); + } +} + +}; // namespace FlexFlow diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index bc7d1017b7..1da56e383a 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -834,18 +834,18 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, m->bias_ptr, bias_ptr, m->biasSize, cudaMemcpyHostToDevice, stream); bias_ptr = static_cast
(m->bias_ptr); } - cudaMemcpyAsync(m->token_infos, - &(bc->tokensInfo), - bc->num_active_tokens() * - sizeof(TreeVerifyBatchConfig::PerTokenInfo), - cudaMemcpyHostToDevice, - stream); - cudaMemcpyAsync(m->request_infos, - &(bc->requestsInfo), - bc->max_requests_per_batch() * - sizeof(BatchConfig::PerRequestInfo), - cudaMemcpyHostToDevice, - stream); + // cudaMemcpyAsync(m->token_infos, + // &(bc->tokensInfo), + // bc->num_active_tokens() * + // sizeof(TreeVerifyBatchConfig::PerTokenInfo), + // cudaMemcpyHostToDevice, + // stream); + // cudaMemcpyAsync(m->request_infos, + // &(bc->requestsInfo), + // bc->max_requests_per_batch() * + // sizeof(BatchConfig::PerRequestInfo), + // cudaMemcpyHostToDevice, + // stream); // phase 1: Implement kernel to compute KQV for input tokens compute_qkv_kernel(m, bc, diff --git a/src/runtime/ffconst_utils.cc b/src/runtime/ffconst_utils.cc index c7b6e1257a..904bfbcaff 100644 --- a/src/runtime/ffconst_utils.cc +++ b/src/runtime/ffconst_utils.cc @@ -154,6 +154,8 @@ std::string get_operator_type_name(OperatorType type) { return "SpecIncMultiHeadSelfAttention"; case OP_TREE_INC_MULTIHEAD_SELF_ATTENTION: return "TreeIncMultiHeadSelfAttention"; + case OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION: + return "SpecInferPgraoIncMultiHeadSelfAttention"; case OP_INPUT: return "Input"; case OP_WEIGHT: diff --git a/src/runtime/graph.cc b/src/runtime/graph.cc index 6d33dd9f27..46f7cc0f29 100644 --- a/src/runtime/graph.cc +++ b/src/runtime/graph.cc @@ -51,6 +51,7 @@ #include "flexflow/ops/topk.h" #include "flexflow/ops/transpose.h" #include "flexflow/ops/tree_inc_multihead_self_attention.h" +#include "flexflow/ops/specinfer_inc_multihead_self_attention.h" #include "flexflow/parallel_ops/allreduce.h" #include "flexflow/parallel_ops/combine.h" #include "flexflow/parallel_ops/fused_parallel_op.h" @@ -69,7 +70,7 @@ using FlexFlow::MachineView; LegionRuntime::Logger::Category log_graph("graph"); LegionRuntime::Logger::Category log_simplify("graph_simplify"); -const Node Node::INVALID_NODE = Node(); +Node const Node::INVALID_NODE = Node(); Node::Node(void) : guid(0), ptr(NULL) {} @@ -2384,6 +2385,28 @@ GraphOptimalViewSerialized sez.serialize(attn->tensor_parallelism_degree); break; } + case OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION: { + SpecInferIncMultiHeadSelfAttention *attn = + (SpecInferIncMultiHeadSelfAttention *)op; + sez.serialize(attn->layer_guid.id); + sez.serialize(attn->layer_guid.transformer_layer_id); + sez.serialize(attn->layer_guid.model_id); + sez.serialize(attn->oProjSize); + sez.serialize(attn->num_q_heads); + sez.serialize(attn->qProjSize); + sez.serialize(attn->vProjSize); + sez.serialize(attn->dropout); + sez.serialize(attn->qkv_bias); + sez.serialize(attn->final_bias); + sez.serialize(attn->add_zero_attn); + sez.serialize(attn->apply_rotary_embedding); + sez.serialize(attn->scaling_query); + sez.serialize(attn->scaling_factor); + sez.serialize(attn->qk_prod_scaling); + sez.serialize(attn->position_bias); + sez.serialize(attn->num_kv_heads); + break; + } case OP_SOFTMAX: { Softmax *softmax = (Softmax *)op; sez.serialize(softmax->dim); @@ -2914,6 +2937,52 @@ void FFModel::deserialize_graph_optimal_view( params); break; } + case OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION: { + assert(num_inputs == 1); + int embed_dim, num_q_heads, k_dim, v_dim, num_kv_heads; + float dropout, scaling_factor; + bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, + scaling_query, qk_prod_scaling, position_bias; + size_t id, transformer_layer_id, deserialized_model_id; + 
dez.deserialize(id); + dez.deserialize(transformer_layer_id); + dez.deserialize(deserialized_model_id); + LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); + dez.deserialize(embed_dim); + dez.deserialize(num_q_heads); + dez.deserialize(k_dim); + dez.deserialize(v_dim); + dez.deserialize(dropout); + dez.deserialize(qkv_bias); + dez.deserialize(final_bias); + dez.deserialize(add_zero_attn); + dez.deserialize(apply_rotary_embedding); + dez.deserialize(scaling_query); + dez.deserialize(scaling_factor); + dez.deserialize(qk_prod_scaling); + dez.deserialize(position_bias); + dez.deserialize(num_kv_heads); + + SpecInferIncMultiHeadSelfAttentionParams params; + params.embed_dim = embed_dim; + params.num_q_heads = num_q_heads; + params.kdim = k_dim; + params.vdim = v_dim; + params.dropout = dropout; + params.qkv_bias = qkv_bias; + params.final_bias = final_bias; + params.add_zero_attn = add_zero_attn; + params.layer_guid = layer_guid; + params.apply_rotary_embedding = apply_rotary_embedding; + params.scaling_query = scaling_query; + params.scaling_factor = scaling_factor; + params.qk_prod_scaling = qk_prod_scaling; + params.position_bias = position_bias; + params.num_kv_heads = num_kv_heads; + node = get_or_create_node(inputs[0], + params); + break; + } case OP_TOPK: { node = TopK::deserialize(*this, dez, inputs, num_inputs); break; diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index eb045e8159..fb978adfff 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -318,7 +318,7 @@ FutureMap InferenceManager::inference(FFModel *model, found_input_operator = true; assert(op->numOutputs == 1); ParallelTensor pt = tensor_buffer[op->outputs[0]][batch_index]; - load_input_tokens_from_batch_config(bc, pt); + load_input_tokens_from_batch_config(bc, pt, model->handlers); } } @@ -348,11 +348,20 @@ FutureMap InferenceManager::inference(FFModel *model, }; void InferenceManager::load_input_tokens_from_batch_config( - BatchConfigFuture const &bc, ParallelTensor const input) { + BatchConfigFuture const &bc, ParallelTensor const input, FFHandler *handlers) { Context ctx = ff_config.lg_ctx; Runtime *runtime = ff_config.lg_hlr; size_t machine_view_hash = input->machine_view.hash(); ArgumentMap argmap; + Rect<1> task_rect(Point<1>(0), + Point<1>(ff_config.workersPerNode * ff_config.numNodes - 1)); + IndexSpaceT<1> task_is = runtime->create_index_space(ctx, task_rect); + MachineView view = input->machine_view; + for (PointInRectIterator<1> it(task_rect); it(); it++) { + FFHandler handle = handlers[view.get_device_id(*it)]; + argmap.set_point(*it, TaskArgument(&handle, sizeof(FFHandler))); + } + IndexLauncher launcher(RM_LOAD_TOKENS_TASK_ID, input->parallel_is, TaskArgument(nullptr, 0), diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 92f0cff472..8bda9016c3 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -59,6 +59,7 @@ #include "flexflow/ops/sigmoid_silu_multi.h" #include "flexflow/ops/softmax.h" #include "flexflow/ops/spec_inc_multihead_self_attention.h" +#include "flexflow/ops/specinfer_inc_multihead_self_attention.h" #include "flexflow/ops/split.h" #include "flexflow/ops/topk.h" #include "flexflow/ops/transpose.h" @@ -93,10 +94,10 @@ Op::Op(FFModel &model, int numWeights, bool allocate_weights, int numOutputs, - const ParallelTensor input1, - const ParallelTensor input2, - const ParallelTensor input3, - const ParallelTensor input4) + ParallelTensor const input1, + ParallelTensor const input2, + 
ParallelTensor const input3, + ParallelTensor const input4) : Op(model, otype, dtype, @@ -116,10 +117,10 @@ Op::Op(FFModel &model, int _numInputs, int _numWeights, int _numOutputs, - const ParallelTensor _input1, - const ParallelTensor _input2, - const ParallelTensor _input3, - const ParallelTensor _input4) + ParallelTensor const _input1, + ParallelTensor const _input2, + ParallelTensor const _input3, + ParallelTensor const _input4) : op_type(_otype), data_type(_dtype), op_guid(model.op_global_guid++), numInputs(_numInputs), numWeights(_numWeights), numOutputs(_numOutputs), profiling(model.config.profiling), @@ -1024,9 +1025,9 @@ void Op::register_output_parallel_dims( operation); } -int Op::get_output_to_input_dim_mapping(const ParallelTensor output, +int Op::get_output_to_input_dim_mapping(ParallelTensor const output, int output_dim, - const ParallelTensor input) { + ParallelTensor const input) { int output_idx = -1, input_idx = -1; for (int i = 0; i < numOutputs; i++) { if (output == outputs[i]) { @@ -1059,9 +1060,9 @@ int Op::get_output_to_input_dim_mapping(const ParallelTensor output, return -1; } -int Op::get_output_to_weight_dim_mapping(const ParallelTensor output, +int Op::get_output_to_weight_dim_mapping(ParallelTensor const output, int output_dim, - const ParallelTensor weight) { + ParallelTensor const weight) { int output_idx = -1, weight_idx = -1; for (int i = 0; i < numOutputs; i++) { if (output == outputs[i]) { @@ -1658,7 +1659,7 @@ Tensor FFModel::create_tensor(int numdim, } ParallelTensor FFModel::create_parallel_tensor(int numdim, - const ParallelDim dims[], + ParallelDim const dims[], DataType data_type, Op const *op, int idx, @@ -1691,7 +1692,7 @@ Tensor FFModel::create_tensor_legion_ordering(int numdim, ParallelTensor FFModel::create_parallel_tensor_legion_ordering(int numdim, - const ParallelDim dims[], + ParallelDim const dims[], DataType data_type, Op const *op, int idx, @@ -1741,7 +1742,7 @@ Tensor FFModel::create_tensor(int const dims[], } template -ParallelTensor FFModel::create_parallel_tensor(const ParallelDim dims[], +ParallelTensor FFModel::create_parallel_tensor(ParallelDim const dims[], DataType data_type, Op const *owner_op, int owner_idx, @@ -1822,7 +1823,7 @@ Parameter FFModel::create_weight(int numdim, } template -ParallelParameter FFModel::create_parallel_weight(const ParallelDim dims[], +ParallelParameter FFModel::create_parallel_weight(ParallelDim const dims[], DataType data_type, Op const *owner_op, bool create_grad, @@ -1853,7 +1854,7 @@ ParallelParameter FFModel::create_parallel_weight(const ParallelDim dims[], } ParallelParameter FFModel::create_parallel_weight(int numdim, - const ParallelDim dims[], + ParallelDim const dims[], DataType data_type, Op const *owner_op, bool create_grad, @@ -1873,7 +1874,7 @@ ParallelParameter FFModel::create_parallel_weight(int numdim, ParallelParameter FFModel::create_parallel_weight_legion_ordering( int numdim, - const ParallelDim dims[], + ParallelDim const dims[], DataType data_type, Op const *owner_op, bool create_grad, @@ -2087,7 +2088,7 @@ void FFModel::map_weight_with_dim(ParallelTensor weight, } bool FFModel::get_parallel_tensor_from_tensor( - const Tensor tensor, ParallelTensor ¶llel_tensor) const { + Tensor const tensor, ParallelTensor ¶llel_tensor) const { // check if tensor->parallel_tensor is already set if (tensor->parallel_tensor != nullptr) { parallel_tensor = tensor->parallel_tensor; @@ -2124,7 +2125,7 @@ bool FFModel::get_parallel_tensor_from_tensor( } void FFModel::create_disjoint_partition(int 
num_dims, - const ParallelDim dims[], + ParallelDim const dims[], IndexSpace const &part_is, LogicalRegion const ®ion, LogicalPartition &part) { @@ -2147,7 +2148,7 @@ void FFModel::create_disjoint_partition(int num_dims, template void FFModel::create_disjoint_partition_with_dim2( - const ParallelDim dims[], + ParallelDim const dims[], IndexSpaceT const &part_is, LogicalRegion const ®ion, LogicalPartition &part) { @@ -2180,7 +2181,7 @@ void FFModel::create_disjoint_partition_with_dim2( } void FFModel::create_aliased_partition(int num_dims, - const ParallelDim dims[], + ParallelDim const dims[], int aliased_dim, IndexSpace const &part_is, LogicalRegion const ®ion, @@ -2204,7 +2205,7 @@ void FFModel::create_aliased_partition(int num_dims, template void FFModel::create_aliased_partition_with_dim2( - const ParallelDim dims[], + ParallelDim const dims[], int aliased_dim, IndexSpaceT const &part_is, LogicalRegion const ®ion, @@ -2241,7 +2242,7 @@ void FFModel::create_aliased_partition_with_dim2( } template -void FFModel::create_disjoint_partition(const ParallelTensor tensor, +void FFModel::create_disjoint_partition(ParallelTensor const tensor, IndexSpaceT const &part_is, LogicalPartition &part_fwd, LogicalPartition &part_bwd) { @@ -2289,7 +2290,7 @@ void FFModel::create_disjoint_partition(const ParallelTensor tensor, template void FFModel::create_data_parallel_partition_with_diff_dims( - const ParallelTensor tensor, + ParallelTensor const tensor, IndexSpaceT const &part_is, LogicalPartition &part_fwd, LogicalPartition &part_bwd) { @@ -2671,7 +2672,7 @@ IndexSpace FFModel::get_task_is(ParallelConfig const &pc) const { return get_task_is(view); } -IndexSpace FFModel::get_or_create_task_is(const ParallelTensor tensor) { +IndexSpace FFModel::get_or_create_task_is(ParallelTensor const tensor) { MachineView view; view.ndims = 0; for (int i = 0; i < tensor->num_dims; i++) { @@ -3038,6 +3039,12 @@ Op *FFModel::create_operator_from_layer( operators.push_back(op); return op; } + case OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION: { + Op *op = SpecInferIncMultiHeadSelfAttention::create_operator_from_layer( + *this, layer, inputs); + operators.push_back(op); + return op; + } case OP_BATCHMATMUL: { Op *op = BatchMatmul::create_operator_from_layer(*this, layer, inputs); operators.push_back(op); @@ -3227,7 +3234,7 @@ Op *FFModel::create_operator_from_layer( } void FFModel::create_operators_from_layers() { - std::map tensors_to_parallel_tensors; + std::map tensors_to_parallel_tensors; // for (auto const &l : layers) { for (int layer_idx = 0; layer_idx < layers.size(); layer_idx++) { auto const &l = layers[layer_idx]; @@ -3973,38 +3980,38 @@ void FFIterationConfig::reset() { // Default Config Parameters struct DefaultConfig { - const static int epochs = 1; + static int const epochs = 1; // const static int iterations = 1; - const static int batchSize = 64; - const static bool profiling = false; - const static bool inference_debugging = false; + static int const batchSize = 64; + static bool const profiling = false; + static bool const inference_debugging = false; constexpr static float learningRate = 0.01f; constexpr static float weightDecay = 0.0001f; - const static size_t workSpaceSize = (size_t)128 * 1024 * 1024; // 128 MB - const static int numNodes = 1; - const static int workersPerNode = 0; - const static int cpusPerNode = 0; - const static size_t searchBudget = -1; - const static size_t simulatorWorkSpaceSize = + static size_t const workSpaceSize = (size_t)128 * 1024 * 1024; // 128 MB + static int const 
numNodes = 1; + static int const workersPerNode = 0; + static int const cpusPerNode = 0; + static size_t const searchBudget = -1; + static size_t const simulatorWorkSpaceSize = (size_t)2 * 1024 * 1024 * 1024; // 2 GB constexpr static float searchAlpha = 1.2f; - const static bool searchOverlapBackwardUpdate = false; - const static size_t offloadReserveSpaceSize = + static bool const searchOverlapBackwardUpdate = false; + static size_t const offloadReserveSpaceSize = (size_t)8 * 1024 * 1024 * 1024; // 8 GB - const static bool cpuOffload = false; - const static bool onlyDataParallel = true; - const static bool enableSampleParallel = true; - const static bool enableParameterParallel = false; - const static bool enableAttributeParallel = false; - const static bool enableInplaceOptimizations = false; - const static bool allowTensorOpMathConversion = false; - const static int machine_model_version = 0; - const static int simulator_segment_size = 16777216; // 16 MB - const static int simulator_max_num_segments = 1; - const static int base_optimize_threshold = 10; - const static bool enable_control_replication = true; + static bool const cpuOffload = false; + static bool const onlyDataParallel = true; + static bool const enableSampleParallel = true; + static bool const enableParameterParallel = false; + static bool const enableAttributeParallel = false; + static bool const enableInplaceOptimizations = false; + static bool const allowTensorOpMathConversion = false; + static int const machine_model_version = 0; + static int const simulator_segment_size = 16777216; // 16 MB + static int const simulator_max_num_segments = 1; + static int const base_optimize_threshold = 10; + static bool const enable_control_replication = true; // The default python data loader type is 2 to enable control replication - const static int python_data_loader_type = 2; + static int const python_data_loader_type = 2; }; FFConfig::FFConfig() { @@ -6209,6 +6216,44 @@ void register_flexflow_internal_tasks(Runtime *runtime, TreeIncMultiHeadSelfAttention::inference_task>(registrar); } } + { + TaskVariantRegistrar registrar( + SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_INIT_TASK_ID, + "SpecInferIncMultiHeadSelfAttention Init"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant< + OpMeta *, + SpecInferIncMultiHeadSelfAttention::init_task>( + registrar, "SpecInferIncMultiHeadSelfAttention Init Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant< + OpMeta *, + SpecInferIncMultiHeadSelfAttention::init_task>(registrar); + } + } + { + TaskVariantRegistrar registrar( + SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_INF_TASK_ID, + "SpecInferIncMultiHeadSelfAttention Inference"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant< + SpecInferIncMultiHeadSelfAttention::inference_task>( + registrar, "SpecInferIncMultiHeadSelfAttention Inference Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant< + SpecInferIncMultiHeadSelfAttention::inference_task>(registrar); + } + } // NoOp { TaskVariantRegistrar registrar(NOOP_INIT_TASK_ID, "Weight NCCL Init"); diff --git a/src/runtime/model.cpp b/src/runtime/model.cpp index 6c482426eb..b51ab83091 100644 --- a/src/runtime/model.cpp +++ b/src/runtime/model.cpp 
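// Layout sketch (inferred from the offsets used throughout this patch, not
// spelled out anywhere in it): batch_config_metadata is one GPU buffer
// addressed by fixed sizeof() offsets, in the order summed by
// batch_config_metadata_size in config.h:
//   base                                        -> BatchConfig::tokensInfo
//   base + sizeof(BatchConfig::tokensInfo)      -> BatchConfig::requestsInfo
//   ... + sizeof(BatchConfig::requestsInfo)     -> BeamSearchBatchConfig::topology_mask
//   ... + sizeof(topology_mask)                 -> beamTokenInfo / beamRequestsInfo
// RequestManager::load_tokens_task fills these slots and the attention metas
// read them, so both sides must agree on identical offsets.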
@@ -131,6 +131,54 @@ FFHandler .wait(); handle.workSpace = workspaceInst.pointer_untyped(0, sizeof(char)); } + if (handle.offload_reserve_space_size > 0) { + // allocate memory for offload reserve space + Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) + .only_kind(Memory::GPU_FB_MEM) + .best_affinity_to(task->target_proc) + .first(); + Realm::Rect<1, coord_t> bounds( + Realm::Point<1, coord_t>(0), + Realm::Point<1, coord_t>(handle.offload_reserve_space_size - 1)); + std::vector field_sizes; + field_sizes.push_back(sizeof(char)); + Realm::RegionInstance workspaceInst; + Realm::RegionInstance::create_instance(workspaceInst, + gpu_mem, + bounds, + field_sizes, + 0, + Realm::ProfilingRequestSet()) + .wait(); + handle.offload_reserve_space = + workspaceInst.pointer_untyped(0, sizeof(char)); + }else { + handle.offload_reserve_space = nullptr; + } + if (handle.batch_config_metadata_size > 0) { + // allocate memory for offload reserve space + Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) + .only_kind(Memory::GPU_FB_MEM) + .best_affinity_to(task->target_proc) + .first(); + Realm::Rect<1, coord_t> bounds( + Realm::Point<1, coord_t>(0), + Realm::Point<1, coord_t>(handle.batch_config_metadata_size - 1)); + std::vector field_sizes; + field_sizes.push_back(sizeof(char)); + Realm::RegionInstance workspaceInst; + Realm::RegionInstance::create_instance(workspaceInst, + gpu_mem, + bounds, + field_sizes, + 0, + Realm::ProfilingRequestSet()) + .wait(); + handle.batch_config_metadata = + workspaceInst.pointer_untyped(0, sizeof(char)); + }else { + handle.batch_config_metadata = nullptr; + } // checkCUDA(hipMalloc(&handle.workSpace, handle.workSpaceSize)); #ifdef FF_USE_NCCL handle.ncclComm = NULL; diff --git a/src/runtime/model.cu b/src/runtime/model.cu index 17401a0f14..523b3c76f3 100644 --- a/src/runtime/model.cu +++ b/src/runtime/model.cu @@ -148,9 +148,35 @@ FFHandler .wait(); handle.offload_reserve_space = workspaceInst.pointer_untyped(0, sizeof(char)); - } else { + }else { handle.offload_reserve_space = nullptr; } + if (handle.batch_config_metadata_size > 0) { + printf("allocate instance for metadata %d\n", handle.batch_config_metadata_size); + // allocate memory for offload reserve space + Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) + .only_kind(Memory::GPU_FB_MEM) + .best_affinity_to(task->target_proc) + .first(); + Realm::Rect<1, coord_t> bounds( + Realm::Point<1, coord_t>(0), + Realm::Point<1, coord_t>(handle.batch_config_metadata_size - 1)); + std::vector field_sizes; + field_sizes.push_back(sizeof(char)); + Realm::RegionInstance workspaceInst; + Realm::RegionInstance::create_instance(workspaceInst, + gpu_mem, + bounds, + field_sizes, + 0, + Realm::ProfilingRequestSet()) + .wait(); + handle.batch_config_metadata = + workspaceInst.pointer_untyped(0, sizeof(char)); + }else { + handle.batch_config_metadata = nullptr; + } + // checkCUDA(cudaMalloc(&handle.workSpace, handle.workSpaceSize)); #ifdef FF_USE_NCCL diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 7c37f3391e..e1b591c320 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -106,6 +106,11 @@ int RequestManager::get_max_sequence_length() { return max_sequence_length; } +void RequestManager::push_spec_infer_tree_width(int tree_width) { + assert(tree_width <= BeamSearchBatchConfig::MAX_BEAM_WIDTH); + spec_infer_tree_width.emplace_back(tree_width); +} + void RequestManager::register_tokenizer(ModelType type, int bos_token_id, int 
eos_token_id, @@ -176,7 +181,7 @@ size_t RequestManager::get_num_ssms() { RequestManager::RequestGuid RequestManager::register_new_request(std::vector const &prompt, int max_sequence_length) { - const std::lock_guard lock(request_queue_mutex); + std::lock_guard const lock(request_queue_mutex); // Add a new request Request request; @@ -232,7 +237,7 @@ RequestManager::RequestGuid RequestManager::RequestGuid RequestManager::register_new_request(std::string const &prompt, int max_sequence_length) { - const std::lock_guard lock(request_queue_mutex); + std::lock_guard const lock(request_queue_mutex); // Add a new request Request request; request.status = Request::PENDING; @@ -290,7 +295,7 @@ RequestManager::RequestGuid } bool RequestManager::is_request_completed(RequestGuid const &guid) { - const std::lock_guard lock(request_queue_mutex); + std::lock_guard const lock(request_queue_mutex); assert(all_requests.find(guid) != all_requests.end()); Request const &request = all_requests[guid]; // return request.tokens.size() >= request.max_sequence_length; @@ -299,7 +304,7 @@ bool RequestManager::is_request_completed(RequestGuid const &guid) { GenerationResult RequestManager::get_generation_result(RequestGuid const &guid) { - const std::lock_guard lock(request_queue_mutex); + std::lock_guard const lock(request_queue_mutex); assert(request_generation_results.find(guid) != request_generation_results.end()); return request_generation_results[guid]; @@ -337,7 +342,7 @@ BatchConfig RequestManager::prepare_next_batch_task( BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, InferenceResult const &result) { - const std::lock_guard lock(request_queue_mutex); + std::lock_guard const lock(request_queue_mutex); // Step 1: append result from previous iteration to request's tokens for (int i = 0; i < old_bc.num_tokens; i++) { @@ -406,13 +411,14 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, total_request_run_time += profile_info.finish_time - profile_info.start_time; profiling_requests[request.guid] = profile_info; - log_req_mgr.print("[Profile] guid(%zu) decoding_steps(%d) start(%.1lf) " - "finish(%.1lf) latency(%.1lf)", - request.guid, - profile_info.decoding_steps, - profile_info.start_time, - profile_info.finish_time, - profile_info.finish_time - profile_info.start_time); + log_req_mgr.print( + "[Profile] guid(%zu) llm_decoding_steps(%d) start(%.1lf) " + "finish(%.1lf) latency(%.1lf)", + request.guid, + profile_info.llm_decoding_steps, + profile_info.start_time, + profile_info.finish_time, + profile_info.finish_time - profile_info.start_time); // Write output to file if needed: if (!output_filepath.empty()) { std::ofstream outputFile(output_filepath, std::ios::app); @@ -420,8 +426,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, outputFile << "end-to-end latency: " << std::fixed << std::setprecision(3) << total_request_run_time << std::endl; - outputFile << "num decoding steps: " << profile_info.decoding_steps - << std::endl; + outputFile << "num decoding steps: " + << profile_info.llm_decoding_steps << std::endl; outputFile << "token IDs: "; for (int i = 0; i < request.tokens.size(); i++) { outputFile << request.tokens[i]; @@ -469,7 +475,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, } // Update profiling profiling_requests[new_bc.requestsInfo[i].request_guid] - .decoding_steps++; + .llm_decoding_steps++; } } } @@ -494,7 +500,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const 
&old_bc, new_bc.request_completed[i] = false; // add profile_info for the new request ProfileInfo profile_info; - profile_info.decoding_steps = 1; + profile_info.llm_decoding_steps = 1; profile_info.start_time = Realm::Clock::current_time_in_microseconds(); profiling_requests[new_request.guid] = profile_info; for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { @@ -553,7 +559,7 @@ BeamSearchBatchConfig RequestManager::prepare_next_batch_init(TreeVerifyBatchConfig const &old_bc, InferenceResult const &result, int model_id) { - const std::lock_guard lock(request_queue_mutex); + std::lock_guard const lock(request_queue_mutex); if (verbose) { std::cout << "\n############### prepare_next_batch_init ###############\n"; } @@ -664,16 +670,18 @@ BeamSearchBatchConfig // Log profiling info ProfileInfo profile_info = profiling_requests[request.guid]; profile_info.finish_time = Realm::Clock::current_time_in_microseconds(); + profile_info.ssm_decoding_steps = 0; total_request_run_time += profile_info.finish_time - profile_info.start_time; profiling_requests[request.guid] = profile_info; - log_req_mgr.print("[Profile] guid(%zu) decoding_steps(%d) start(%.1lf) " - "finish(%.1lf) latency(%.1lf)", - request.guid, - profile_info.decoding_steps, - profile_info.start_time, - profile_info.finish_time, - profile_info.finish_time - profile_info.start_time); + log_req_mgr.print( + "[Profile] guid(%zu) llm_decoding_steps(%d) start(%.1lf) " + "finish(%.1lf) latency(%.1lf)", + request.guid, + profile_info.llm_decoding_steps, + profile_info.start_time, + profile_info.finish_time, + profile_info.finish_time - profile_info.start_time); // Write output to file if needed: if (!output_filepath.empty()) { @@ -682,8 +690,8 @@ BeamSearchBatchConfig outputFile << "end-to-end latency: " << std::fixed << std::setprecision(3) << total_request_run_time << std::endl; - outputFile << "num decoding steps: " << profile_info.decoding_steps - << std::endl; + outputFile << "num decoding steps: " + << profile_info.llm_decoding_steps << std::endl; outputFile << "token IDs: "; for (int i = 0; i < request.tokens.size(); i++) { outputFile << request.tokens[i]; @@ -726,8 +734,14 @@ BeamSearchBatchConfig new_bc.requestsInfo[i].first_token_depth_in_request - verified_tokens.size(); new_bc.beamRequestsInfo[i].current_depth = 1; + + profiling_requests[request.guid].ssm_decoding_steps = 0; + + int ssm_decoding_steps = 0; new_bc.beamRequestsInfo[i].beam_size = - BeamSearchBatchConfig::MAX_BEAM_WIDTH; + spec_infer_tree_width.size() > ssm_decoding_steps + ? 
spec_infer_tree_width[ssm_decoding_steps] + : 1; new_bc.beamRequestsInfo[i].max_depth = std::min(new_max_depth, BeamSearchBatchConfig::MAX_BEAM_DEPTH); for (int j = 0; j < BeamSearchBatchConfig::MAX_BEAM_WIDTH; j++) { @@ -735,6 +749,8 @@ BeamSearchBatchConfig new_bc.beamRequestsInfo[i].probs[j] = 1; } + new_bc.beamRequestsInfo[i].sub_request_num = 1; + new_bc.sub_requests[i] = 1; // Token Info @@ -746,6 +762,8 @@ BeamSearchBatchConfig new_bc.tokensInfo[new_bc.num_tokens].token_id = token.first; new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = token.second; + new_bc.topology_mask[i].real_token_pos[0][token.second] = + new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request; // Beam Token Info new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = 0; @@ -786,14 +804,20 @@ BeamSearchBatchConfig // TODO: Beam Request Info, missing from VerifyTreeBatchConfig new_bc.beamRequestsInfo[i].current_depth = 1; + int ssm_decoding_steps = + profiling_requests[request.guid].ssm_decoding_steps; new_bc.beamRequestsInfo[i].beam_size = - BeamSearchBatchConfig::MAX_BEAM_WIDTH; + spec_infer_tree_width.size() > ssm_decoding_steps + ? spec_infer_tree_width[ssm_decoding_steps] + : 1; new_bc.beamRequestsInfo[i].max_depth = 0; for (int j = 0; j < BeamSearchBatchConfig::MAX_BEAM_WIDTH; j++) { new_bc.beamRequestsInfo[i].parent_id[j] = 0; new_bc.beamRequestsInfo[i].probs[j] = 1; } + new_bc.beamRequestsInfo[i].sub_request_num = 1; + new_bc.sub_requests[i] = 1; // Token Info @@ -829,12 +853,17 @@ BeamSearchBatchConfig // add profile_info for the new request ProfileInfo profile_info; - profile_info.decoding_steps = 0; + profile_info.llm_decoding_steps = 0; + profile_info.ssm_decoding_steps = 0; profile_info.start_time = Realm::Clock::current_time_in_microseconds(); profiling_requests[new_request.guid] = profile_info; // init the beam search metadata per request + int ssm_decoding_steps = profile_info.ssm_decoding_steps; + new_bc.beamRequestsInfo[i].beam_size = - BeamSearchBatchConfig::MAX_BEAM_WIDTH; + spec_infer_tree_width.size() > ssm_decoding_steps + ? 
spec_infer_tree_width[ssm_decoding_steps] + : 1; new_bc.beamRequestsInfo[i].current_depth = 1; new_bc.beamRequestsInfo[i].max_depth = std::min(BeamSearchBatchConfig::MAX_BEAM_DEPTH, @@ -846,6 +875,7 @@ BeamSearchBatchConfig } new_bc.request_completed[i] = false; + new_bc.beamRequestsInfo[i].sub_request_num = 1; new_bc.sub_requests[i] = 1; for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { @@ -855,6 +885,7 @@ BeamSearchBatchConfig assert(depth < new_request.tokens.size()); new_bc.tokensInfo[new_bc.num_tokens].token_id = new_request.tokens[depth]; + new_bc.topology_mask[i].real_token_pos[0][depth] = depth; // beam search meta data, indicate which sub request this token // belongs to, init to 0; @@ -937,7 +968,7 @@ BeamSearchBatchConfig RequestManager::prepare_next_batch_beam_task( BeamSearchBatchConfig RequestManager::prepare_next_batch_beam(BeamSearchBatchConfig const &old_bc, BeamInferenceResult const &result) { - const std::lock_guard lock(request_queue_mutex); + std::lock_guard const lock(request_queue_mutex); if (verbose) { std::cout << "\n############### prepare_next_batch_beam ###############\n"; } @@ -1005,25 +1036,38 @@ BeamSearchBatchConfig new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = old_bc.requestsInfo[i].max_sequence_length; - + profiling_requests[request.guid].ssm_decoding_steps += 1; // update the beam search metadata // how many sub request in current request // why is sub_requests has max_requests_per_batch() * MAX_BEAM_WIDTH // entries? - new_bc.sub_requests[i] = old_bc.beamRequestsInfo[i].beam_size; - // update the parentid, accumalated_probs, depth, and token_ids + int ssm_decoding_steps = + profiling_requests[request.guid].ssm_decoding_steps; + new_bc.beamRequestsInfo[i].beam_size = - old_bc.beamRequestsInfo[i].beam_size; + spec_infer_tree_width.size() > ssm_decoding_steps + ? spec_infer_tree_width[ssm_decoding_steps] + : 1; new_bc.beamRequestsInfo[i].max_depth = old_bc.beamRequestsInfo[i].max_depth; + new_bc.sub_requests[i] = + old_bc.sub_requests[i] * new_bc.beamRequestsInfo[i].beam_size; + new_bc.beamRequestsInfo[i].sub_request_num = + old_bc.beamRequestsInfo[i].sub_request_num * + new_bc.beamRequestsInfo[i].beam_size; + + assert(new_bc.beamRequestsInfo[i].sub_request_num <= + BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES); + if (request.status == Request::RUNNING) { new_bc.beamRequestsInfo[i].current_depth = old_bc.beamRequestsInfo[i].current_depth + 1; new_bc.request_running[i] = true; // do the slot exchange to minimize the cache exchange in kernel. - update_beam_metadata(new_bc, request.beam_trees.at(old_bc.model_id), i); + update_beam_metadata( + new_bc, old_bc, request.beam_trees.at(old_bc.model_id), i); } else { assert(false && "Request should not be pending in beam search phase"); } @@ -1059,7 +1103,7 @@ BeamSearchBatchConfig // register more tokens due to the beam width for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j; - for (int k = 0; k < new_bc.sub_requests[i]; k++) { + for (int k = 0; k < new_bc.beamRequestsInfo[i].sub_request_num; k++) { new_bc.tokensInfo[new_bc.num_tokens].request_index = i; new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth; @@ -1103,13 +1147,24 @@ BeamSearchBatchConfig // how many sub request in current request // why is sub_requests has max_requests_per_batch() * MAX_BEAM_WIDTH // entries? 
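+      // The update below multiplies the previous sub_request_num by the
+      // beam_size picked for this ssm_decoding_step; e.g. a request with one
+      // sub-request expanded with beam_size 3 now carries three sub-requests,
+      // and the assert keeps that product within
+      // MAX_SPECULATIVE_TREE_BRANCHES.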
- new_bc.sub_requests[i] = old_bc.beamRequestsInfo[i].beam_size; + int ssm_decoding_steps = + profiling_requests[request.guid].ssm_decoding_steps; - // update the parentid, accumalated_probs, depth, and token_ids new_bc.beamRequestsInfo[i].beam_size = - old_bc.beamRequestsInfo[i].beam_size; + spec_infer_tree_width.size() > ssm_decoding_steps + ? spec_infer_tree_width[ssm_decoding_steps] + : 1; new_bc.beamRequestsInfo[i].max_depth = old_bc.beamRequestsInfo[i].max_depth; + new_bc.sub_requests[i] = + old_bc.sub_requests[i] * new_bc.beamRequestsInfo[i].beam_size; + new_bc.beamRequestsInfo[i].sub_request_num = + old_bc.beamRequestsInfo[i].sub_request_num * + new_bc.beamRequestsInfo[i].beam_size; + assert(new_bc.beamRequestsInfo[i].sub_request_num <= + BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES); + + // update the parentid, accumalated_probs, depth, and token_ids if (request.status == Request::PENDING) { // if the request is pending, we need to update the beam search @@ -1152,7 +1207,7 @@ BeamSearchBatchConfig // register more tokens due to the beam width for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j; - for (int k = 0; k < new_bc.sub_requests[i]; k++) { + for (int k = 0; k < new_bc.beamRequestsInfo[i].sub_request_num; k++) { new_bc.tokensInfo[new_bc.num_tokens].request_index = i; new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = depth; @@ -1209,7 +1264,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify_task( TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( std::vector const &old_batches) { - const std::lock_guard lock(request_queue_mutex); + std::lock_guard const lock(request_queue_mutex); std::cout << "\n############### prepare_next_batch_verify ###############\n"; @@ -1238,7 +1293,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( Request &request = all_requests[guid]; // Profiling - profiling_requests[request.guid].decoding_steps += 1; + profiling_requests[request.guid].llm_decoding_steps += 1; if (request.status == Request::RUNNING) { new_bc.request_running[i] = true; @@ -1478,16 +1533,19 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, int index = old_bc.tokensInfo[i - 1].request_index; int beam_size = old_bc.beamRequestsInfo[index].beam_size; + + // int leaf_node_num = old_bc.sub_requests[index]; + int leaf_node_num = old_bc.beamRequestsInfo[i].sub_request_num; int depth = old_bc.beamRequestsInfo[index].current_depth; // Each token yields (beam_width) results - int beam_width = old_bc.beamRequestsInfo[index].beam_size; + // int beam_width = old_bc.beamRequestsInfo[index].beam_size; // Count tokens sent to model in this request to find the final token's // index result_index += (old_bc.tokensInfo[i - 1].abs_depth_in_request - start_depth) * - beam_width; + leaf_node_num; if (verbose) { std::cout << "i = " << i << ", result index = " << result_index @@ -1514,7 +1572,7 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, } } - for (int beam_id = 0; beam_id < beam_width; beam_id++) { + for (int beam_id = 0; beam_id < leaf_node_num; beam_id++) { request.beam_trees.at(old_bc.model_id) .treeLayers[depth] .tokens[beam_id] = result.token_ids[result_index]; @@ -1546,6 +1604,7 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, // for updating the beam search metadata in requests in incremental phase void 
RequestManager::update_beam_metadata(BeamSearchBatchConfig &new_bc, + BeamSearchBatchConfig const &old_bc, BeamTree &tree, int request_index) { @@ -1556,6 +1615,9 @@ void RequestManager::update_beam_metadata(BeamSearchBatchConfig &new_bc, int depth = new_bc.beamRequestsInfo[request_index].current_depth - 1; int beam_size = new_bc.beamRequestsInfo[request_index].beam_size; + // int leaf_node_num = old_bc.sub_requests[request_index]; + int leaf_node_num = old_bc.beamRequestsInfo[request_index].sub_request_num; + if (new_bc.beamRequestsInfo[request_index].current_depth == 1) { // TODO: check if this is correct // for (int j = 0; j < beam_size; j++) { @@ -1568,49 +1630,61 @@ void RequestManager::update_beam_metadata(BeamSearchBatchConfig &new_bc, // Do nothing // assert(false); } else { - std::set parents; - std::set childs; - // cache stealing - for (int j = 0; j < beam_size; j++) { - int parent_id = tree.treeLayers[depth].parent_ids[j]; - if (childs.find(parent_id) == childs.end()) { - // copy beam slot - new_bc.beamRequestsInfo[request_index].parent_id[parent_id] = - tree.treeLayers[depth].parent_ids[j]; - new_bc.beamRequestsInfo[request_index].probs[parent_id] = - tree.treeLayers[depth].probs[j]; - new_bc.beamRequestsInfo[request_index].tokens[parent_id] = - tree.treeLayers[depth].tokens[j]; - parents.emplace(j); - childs.emplace(parent_id); - } - } - if (parents.size() < beam_size) { - for (int j = 0; j < beam_size; j++) { - if (parents.find(j) == parents.end()) { - // this slot has not been assigned - // find the smallest not assigned child and put in - if (verbose) { - std::cout << "request_index" << request_index - << ", miss slot: " << j << "\n"; - } - for (int k = 0; k < beam_size; k++) { - if (childs.find(k) == childs.end()) { - // parent -> j to child k; - new_bc.beamRequestsInfo[request_index].parent_id[k] = - tree.treeLayers[depth].parent_ids[j]; - new_bc.beamRequestsInfo[request_index].probs[k] = - tree.treeLayers[depth].probs[j]; - new_bc.beamRequestsInfo[request_index].tokens[k] = - tree.treeLayers[depth].tokens[j]; - parents.emplace(j); - childs.emplace(k); - break; - } - } - } - } + for (int j = 0; j < leaf_node_num; j++) { + new_bc.beamRequestsInfo[request_index].parent_id[j] = + tree.treeLayers[depth].parent_ids[j]; + new_bc.beamRequestsInfo[request_index].probs[j] = + tree.treeLayers[depth].probs[j]; + new_bc.beamRequestsInfo[request_index].tokens[j] = + tree.treeLayers[depth].tokens[j]; + + // new_bc.topology_mask[request_index].real_token_pos[j] = } + assert(false); + + // std::set parents; + // std::set childs; + // // cache stealing + // for (int j = 0; j < beam_size; j++) { + // int parent_id = tree.treeLayers[depth].parent_ids[j]; + // if (childs.find(parent_id) == childs.end()) { + // // copy beam slot + // new_bc.beamRequestsInfo[request_index].parent_id[parent_id] = + // tree.treeLayers[depth].parent_ids[j]; + // new_bc.beamRequestsInfo[request_index].probs[parent_id] = + // tree.treeLayers[depth].probs[j]; + // new_bc.beamRequestsInfo[request_index].tokens[parent_id] = + // tree.treeLayers[depth].tokens[j]; + // parents.emplace(j); + // childs.emplace(parent_id); + // } + // } + // if (parents.size() < beam_size) { + // for (int j = 0; j < beam_size; j++) { + // if (parents.find(j) == parents.end()) { + // // this slot has not been assigned + // // find the smallest not assigned child and put in + // if (verbose) { + // std::cout << "request_index" << request_index + // << ", miss slot: " << j << "\n"; + // } + // for (int k = 0; k < beam_size; k++) { + // if 
(childs.find(k) == childs.end()) { + // // parent -> j to child k; + // new_bc.beamRequestsInfo[request_index].parent_id[k] = + // tree.treeLayers[depth].parent_ids[j]; + // new_bc.beamRequestsInfo[request_index].probs[k] = + // tree.treeLayers[depth].probs[j]; + // new_bc.beamRequestsInfo[request_index].tokens[k] = + // tree.treeLayers[depth].tokens[j]; + // parents.emplace(j); + // childs.emplace(k); + // break; + // } + // } + // } + // } + // } } if (verbose) { std::cout << "-----------after parent id exchange-----------" << std::endl; diff --git a/src/runtime/request_manager.cpp b/src/runtime/request_manager.cpp index 1e756606f8..9635b3bc1e 100644 --- a/src/runtime/request_manager.cpp +++ b/src/runtime/request_manager.cpp @@ -56,6 +56,22 @@ void RequestManager::load_tokens_task( sizeof(TokenId) * batch_config->num_tokens, hipMemcpyHostToDevice, stream)); + + // copy meta data to workSpace + FFHandler handle = *((FFHandler const *)task->local_args); + cudaMemcpyAsync(handle.batch_config_metadata, + &(batch_config->tokensInfo), + batch_config->num_active_tokens() * + sizeof(BatchConfig::PerTokenInfo), + cudaMemcpyHostToDevice, + stream); + cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo), + &(batch_config->requestsInfo), + batch_config->max_requests_per_batch() * + sizeof(BatchConfig::PerRequestInfo), + cudaMemcpyHostToDevice, + stream); } void RequestManager::load_positions_task( diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu index cd3e03fff6..f4500d152d 100644 --- a/src/runtime/request_manager.cu +++ b/src/runtime/request_manager.cu @@ -30,6 +30,7 @@ void RequestManager::load_tokens_task( // BatchConfig const batch_config = *((BatchConfig *)task->args); BatchConfig const *batch_config = BatchConfig::from_future(task->futures[0]); + BatchConfig::TokenId dram_copy[BatchConfig::MAX_NUM_TOKENS]; // Extreme long prompts are not supported, only load up to @@ -55,6 +56,55 @@ void RequestManager::load_tokens_task( sizeof(TokenId) * batch_config->num_tokens, cudaMemcpyHostToDevice, stream)); + + // copy meta data to workSpace + FFHandler handle = *((FFHandler const *)task->local_args); + cudaMemcpyAsync(handle.batch_config_metadata, + &(batch_config->tokensInfo), + batch_config->num_active_tokens() * + sizeof(BatchConfig::PerTokenInfo), + cudaMemcpyHostToDevice, + stream); + cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo), + &(batch_config->requestsInfo), + batch_config->max_requests_per_batch() * + sizeof(BatchConfig::PerRequestInfo), + cudaMemcpyHostToDevice, + stream); + + + // load speculative metadata + if (batch_config->get_mode() == BEAM_SEARCH_MODE) { + BeamSearchBatchConfig const *beam_batch_config = + static_cast(batch_config); + + cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo) + + sizeof(BatchConfig::requestsInfo), + &(beam_batch_config->topology_mask), + sizeof(BeamSearchBatchConfig::topology_mask), + cudaMemcpyHostToDevice, + stream); + + cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo) + + sizeof(BatchConfig::requestsInfo) + + sizeof(BeamSearchBatchConfig::topology_mask), + &(beam_batch_config->beamRequestsInfo), + sizeof(BeamSearchBatchConfig::beamRequestsInfo), + cudaMemcpyHostToDevice, + stream); + cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo) + + sizeof(BatchConfig::requestsInfo) + + 
sizeof(BeamSearchBatchConfig::topology_mask) + + sizeof(BeamSearchBatchConfig::beamRequestsInfo), + &(beam_batch_config->beamTokenInfo), + sizeof(BeamSearchBatchConfig::beamTokenInfo), + cudaMemcpyHostToDevice, + stream); + } } void RequestManager::load_positions_task( From d3a57cb22b080741d9677d82701f035ccd33f8da Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Tue, 26 Dec 2023 03:09:33 -0500 Subject: [PATCH 02/30] fix speculative --- include/flexflow/batch_config.h | 4 +- inference/models/llama.cc | 1 + inference/spec_infer/spec_infer.cc | 4 +- src/ops/beam_topk.cc | 11 ++- src/ops/beam_topk.cu | 61 ++++++------ .../specinfer_inc_multihead_self_attention.cu | 91 +++++++++++------- src/runtime/inference_manager.cc | 1 + src/runtime/request_manager.cc | 93 +++++++++++++++---- src/runtime/request_manager.cu | 10 +- 9 files changed, 185 insertions(+), 91 deletions(-) diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index c33c3558cc..dd947bbd85 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -126,11 +126,11 @@ class BeamSearchBatchConfig : public BatchConfig { size_t beam_width; size_t target_iterations; - inline static int const MAX_BEAM_WIDTH = 1; + inline static int const MAX_BEAM_WIDTH = 3; inline static int const MAX_BEAM_DEPTH = 8; // maximum tree branches for a request - inline static int const MAX_SPECULATIVE_TREE_BRANCHES = 9; + inline static int const MAX_SPECULATIVE_TREE_BRANCHES = 3; int model_id; diff --git a/inference/models/llama.cc b/inference/models/llama.cc index f62df1b1d7..4f76e9e0fa 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -248,6 +248,7 @@ void LLAMA::create_llama_model(FFModel &ff, // output = ff.beam_top_k(softmax, llama_config.max_beam_width, false); // output = ff.argmax(softmax, /*beam_Search*/ true); output = ff.beam_top_k(softmax, llama_config.max_beam_width, false); + // output = ff.top_k(softmax, ) } else { // Tensor softmax = ff.softmax(dense, -1); if (generation_config.do_sample) { diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index e2594ba87f..2ccdfd388d 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -303,7 +303,7 @@ void FlexFlow::top_level_task(Task const *task, rm->register_output_filepath(file_paths.output_file_path); //first decoding step: 3 results - rm->push_spec_infer_tree_width(1); + rm->push_spec_infer_tree_width(3); // Create LLM model FFModel tree_model(ffconfig, ffconfig.cpu_offload); @@ -404,7 +404,7 @@ void FlexFlow::top_level_task(Task const *task, prompts.push_back(text); // tree_model.generate(text, 128 /*max_sequence_length*/); } - tree_model.generate(prompts, 128 /*max_sequence_length*/); + tree_model.generate(prompts, 15 /*max_sequence_length*/); } // Execution fence diff --git a/src/ops/beam_topk.cc b/src/ops/beam_topk.cc index 2883428254..3f636c2c98 100644 --- a/src/ops/beam_topk.cc +++ b/src/ops/beam_topk.cc @@ -366,14 +366,18 @@ BeamInferenceResult GenericTensorAccessorW value = helperGetGenericTensorAccessorWO( DT_FLOAT, regions[2], task->regions[2], FID_DATA, ctx, runtime); GenericTensorAccessorW parent = helperGetGenericTensorAccessorWO( - DT_FLOAT, regions[3], task->regions[3], FID_DATA, ctx, runtime); + DT_INT32, regions[3], task->regions[3], FID_DATA, ctx, runtime); Domain input_domain = runtime->get_index_space_domain( ctx, task->regions[0].region.get_index_space()); - + + printf("----------1-----------\n"); int *index_ptr = index.get_int32_ptr(); + 
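
As context for the tree-width change above: push_spec_infer_tree_width registers how many candidate tokens the draft model keeps at each of its decoding steps, and any step without a registered entry falls back to a single branch. A minimal host-side sketch of that lookup (lookup_tree_width and step_widths are illustrative names, not FlexFlow APIs):

    // Sketch only: mirrors the width-schedule lookup used by the request manager.
    #include <cstdio>
    #include <vector>

    static int lookup_tree_width(std::vector<int> const &step_widths,
                                 int ssm_decoding_steps) {
      // Steps with a registered entry use it; later steps keep a single branch.
      return static_cast<size_t>(ssm_decoding_steps) < step_widths.size()
                 ? step_widths[ssm_decoding_steps]
                 : 1;
    }

    int main() {
      std::vector<int> step_widths = {3}; // e.g. push_spec_infer_tree_width(3)
      for (int step = 0; step < 4; step++) {
        printf("ssm step %d -> beam_size %d\n",
               step, lookup_tree_width(step_widths, step));
      }
      return 0;
    }
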
printf("----------2-----------\n"); float *value_ptr = value.get_float_ptr(); + printf("----------3-----------\n"); int *parent_ptr = parent.get_int32_ptr(); + printf("----------4-----------\n"); // embedding size: eg. 4096 int length = input_domain.hi()[0] - input_domain.lo()[0] + 1; @@ -398,6 +402,9 @@ BeamInferenceResult download_tensor( parent_ptr, ir.parent_id, batch_size * m->max_beam_width); + print_tensor(index_ptr, 32, "indexxxxxxx"); + printf("max beam width %d\n", m->max_beam_width); + if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); int shard_id = task->index_point.point_data[0]; diff --git a/src/ops/beam_topk.cu b/src/ops/beam_topk.cu index 72ab7862a6..515bba4bc0 100644 --- a/src/ops/beam_topk.cu +++ b/src/ops/beam_topk.cu @@ -379,9 +379,9 @@ template __global__ void mergeSubRequestsKernel(int64_t N, T const *X, T const *rstd, T *Y) { using T_ACC = T; - const int64_t i = blockIdx.x; + int64_t const i = blockIdx.x; for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { - const int64_t index = i * N + j; + int64_t const index = i * N + j; Y[index] = static_cast(X[index]) * static_cast(rstd[i]); } } @@ -556,8 +556,7 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, int beam_size = bc->beamRequestsInfo[i].beam_size; // initial request - log_beam_topk.debug() << "sub_requests: " << i << ", " << sub_requests[i] - << "\n"; + std::cout << "sub_requests: " << i << ", " << sub_requests[i] << "\n"; assert(sub_requests[i] > 0); // process sub requests for (int j = 0; j < sub_requests[i]; j++) { @@ -565,12 +564,12 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, // beam_slots[i].parent_id[j]; acc_probs[req_index * BeamSearchBatchConfig::MAX_BEAM_WIDTH + j] = bc->beamRequestsInfo[i].probs[j]; - log_beam_topk.debug() - << "probbbb req: " << i - << ", sub req probability : " << bc->beamRequestsInfo[i].probs[j] - << ", sub request id " << j << ", parent id " - << bc->beamRequestsInfo[i].parent_id[j] << ", data inddd" - << req_index * BeamSearchBatchConfig::MAX_BEAM_WIDTH + j << "\n"; + std::cout << "probbbb req: " << i << ", sub req probability : " + << bc->beamRequestsInfo[i].probs[j] << ", sub request id " << j + << ", parent id " << bc->beamRequestsInfo[i].parent_id[j] + << ", data inddd" + << req_index * BeamSearchBatchConfig::MAX_BEAM_WIDTH + j + << "\n"; } // process tokens @@ -584,6 +583,8 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, max_heap_size = std::max(max_heap_size, beam_size * sub_requests[i]); max_beam_width = std::max(max_beam_width, beam_size); + + std::cout << "max beam width: " << max_beam_width << "\n"; req_index += 1; block_start_index += (sub_requests[i] - 1) * num_new_tokens * length; } @@ -613,26 +614,34 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, assert(num_shards >= (size_t)max_heap_size); num_shards = max_heap_size; - checkCUDA(cudaMemcpy(m->parent_ids, - parent_ids, - sizeof(int) * max_total_requests, - cudaMemcpyHostToDevice)); - checkCUDA(cudaMemcpy(m->acc_probs, - acc_probs, - sizeof(DT) * max_total_requests, - cudaMemcpyHostToDevice)); - checkCUDA(cudaMemcpy(m->block_start_index, - beam_block_start_index.data(), - sizeof(int) * beam_num_blocks, - cudaMemcpyHostToDevice)); - checkCUDA(cudaMemcpy(m->request_id, + checkCUDA(cudaMemcpyAsync(m->parent_ids, + parent_ids, + sizeof(int) * max_total_requests, + cudaMemcpyHostToDevice, + stream)); + checkCUDA(cudaMemcpyAsync(m->acc_probs, + acc_probs, + sizeof(DT) * max_total_requests, + cudaMemcpyHostToDevice, + stream)); + // trick, set acc_probs to 0; + 
checkCUDA( + cudaMemsetAsync(m->acc_probs, 1.0, batch_size * sizeof(DT), stream)); + checkCUDA(cudaMemcpyAsync(m->block_start_index, + beam_block_start_index.data(), + sizeof(int) * beam_num_blocks, + cudaMemcpyHostToDevice, + stream)); + checkCUDA(cudaMemcpyAsync(m->request_id, request_id.data(), sizeof(int) * beam_num_blocks, - cudaMemcpyHostToDevice)); - checkCUDA(cudaMemcpy(m->tokens_per_request, + cudaMemcpyHostToDevice, + stream)); + checkCUDA(cudaMemcpyAsync(m->tokens_per_request, tokens_per_request.data(), sizeof(int) * beam_num_blocks, - cudaMemcpyHostToDevice)); + cudaMemcpyHostToDevice, + stream)); // int depth = // bc->beamRequestsInfo[bc->tokensInfo[0].request_index].current_depth; beam_topk_forward_kernel<<>>( diff --git a/src/ops/specinfer_inc_multihead_self_attention.cu b/src/ops/specinfer_inc_multihead_self_attention.cu index 0bdf07a9d7..9d6f70d5ba 100644 --- a/src/ops/specinfer_inc_multihead_self_attention.cu +++ b/src/ops/specinfer_inc_multihead_self_attention.cu @@ -133,6 +133,13 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( q_ptr + (hidden_size * QKV_WEIGHT_NUM * sub_req_idx) + ki + ii * THREADS_PER_KEY * K_VEC_SIZE); } + + if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + printf("cacheposssss %d, %d\n", tree_branch_num, topology.real_token_pos[0][0]); + printf("cacheposssss %d, %d\n", tree_branch_num, topology.real_token_pos[0][1]); + printf("cacheposssss %d, %d\n", tree_branch_num, topology.real_token_pos[0][2]); + printf("cacheposssss %d, %d\n", tree_branch_num, topology.real_token_pos[0][10]); + } __syncthreads(); for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { K_vec k[K_VECS_PER_THREAD]; @@ -317,26 +324,38 @@ __global__ void specinfer_store_kv_cache( DT kVal = devQKVProjArray[val_idx]; DT vVal = devQKVProjArray[val_idx + hidden_size]; - // above no need to be changed - // int const req_id = id_map[token_idx].request_index; - // int const tok_id = id_map[token_idx].token_position; - // int const sub_req_id = id_map[token_idx].sub_request_index; - // int const parent_id = id_map[token_idx].parent_id; - // int const beam_depth = id_map[token_idx].beam_depth; - // int const beam_width = id_map[token_idx].beam_width; - int const req_id = tokenInfos[token_idx].request_index; int const tok_id = tokenInfos[token_idx].abs_depth_in_request; + int const first_token_in_req = requestInfo[req_id].first_token_depth_in_request; int const sub_req_id = beamTokenInfos[token_idx].sub_request_index; - // int const parent_id = beamRequestInfos[req_id].parent_id[sub_req_id]; - // int const beam_depth = beamRequestInfos[req_id].current_depth; - // int const beam_width = beamRequestInfos[req_id].beam_size; int const allocated_tokens = beam_topology_mask[req_id].allocated_tokens; + int const beam_size = beamRequestInfos[req_id].sub_request_num; + + int real_idx = tok_id - first_token_in_req + allocated_tokens; + + if (i == 0) { + printf("ffasdasds%d, %d, %d, %d, %d, %d\n", + beamTokenInfos[0].sub_request_index, + allocated_tokens, + sub_req_id, + tok_id, + first_token_in_req, + real_idx); + } + // }else if(i == hidden_size * 2){ + // printf("ffasdasdskkkk%d, %d, %d\n", allocated_tokens, tok_id, + // sub_req_id); + // } + + + kCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + - (allocated_tokens + sub_req_id) * hidden_size + offset] = kVal; + (real_idx) * hidden_size + + offset] = kVal; vCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + - (allocated_tokens + sub_req_id) * hidden_size + offset] = vVal; + 
(real_idx) * hidden_size + + offset] = vVal; } } @@ -350,6 +369,9 @@ void update_kv_cache_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, // assert(curr_depth < 3); if (num_tokens > 0) { int parallelism = m->hidden_size * KV_WEIGHT_NUM * num_tokens; + printf("tokenInfo %d, %d\n", + bc->beamTokenInfo[0].sub_request_index, + num_tokens); specinfer_store_kv_cache<<max_requests_per_batch(); i++) { if (bc->request_completed[i]) { continue; - } else if (tokens_previous_requests < bc->num_generation_tokens) { - tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; - continue; - } + } + // else if (tokens_previous_requests < bc->num_generation_tokens) { + // tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; + // continue; + // } // all requests in prompt phase should only have one sub requests; assert(bc->sub_requests[i] == 1); @@ -523,6 +546,9 @@ void compute_attention_kernel_prompt( m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM; // To get B, skip over K entries from previous requests (all heads + // padding) + + print_tensor((float*)A, 32, "A"); + std::cout << "meta: " << num_new_tokens << ", " << total_tokens << "\n"; DT const *B = static_cast
(m->keyCache) + (i * bc->MAX_SPECULATIVE_TREE_BRANCHES) * kt_req_block_size; @@ -557,6 +583,7 @@ void compute_attention_kernel_prompt( m->num_q_heads, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + print_tensor((float*)C, 32, "C"); // add alibi position bias to qk production // add alibi position bias to qk production if (*m->position_bias) { @@ -641,6 +668,8 @@ void compute_attention_kernel_prompt( B = C_softmax; // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous // requests + + print_tensor((float*)C_softmax, 32, "C_softmax"); C = static_cast
(m->attn_heads) + (tokens_previous_requests + bc->num_generation_tokens) * m->num_q_heads * m->vProjSize; @@ -695,6 +724,8 @@ void inference_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, stream); // phase 2: Update key/val cache update_kv_cache_kernel
(m, bc, stream); + std::cout << "specinfer kernel token num: " << bc->num_generation_tokens + << ", " << bc->num_tokens << "\n"; if (bc->num_generation_tokens > 0) { compute_specinfer_attention_kernel_generation
( m, bc, static_cast
(m->attn_heads), stream); @@ -705,6 +736,8 @@ void inference_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, compute_attention_kernel_prompt( m, bc, shard_id, output_ptr, bias_ptr, weight_ptr, stream); } + // compute_attention_kernel_prompt( + // m, bc, shard_id, output_ptr, bias_ptr, weight_ptr, stream); // compute output production and bias together for all tokens int num_tokens = bc->num_active_tokens(); @@ -783,6 +816,12 @@ void SpecInferIncMultiHeadSelfAttention::inference_kernel_wrapper( // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr, // acc_output.rect, "[Attention:forward:output]"); } + + // if(bc->num_tokens == 1){ + // print_tensor(input.get_float_ptr(), 32, "specinc input"); + // print_tensor(output.get_float_ptr(), 32, "specinc output"); + // assert(false); + // } } SpecInferIncMultiHeadSelfAttentionMeta::SpecInferIncMultiHeadSelfAttentionMeta( @@ -825,24 +864,6 @@ SpecInferIncMultiHeadSelfAttentionMeta::SpecInferIncMultiHeadSelfAttentionMeta( // allocate memory for the seqArray and reserve space { - // int max_tokens_per_batch = BatchConfig::max_tokens_per_batch(); - // size_t beam_tokeninfo_size = - // max_tokens_per_batch * BeamSearchBatchConfig::MAX_BEAM_WIDTH; - // size_t requestinfo_size = - // BeamSearchBatchConfig::max_requests_per_batch(); size_t - // beam_requestinfo_size = - // BeamSearchBatchConfig::max_requests_per_batch(); - // size_t total_size = - // beam_tokeninfo_size * - // sizeof(BeamSearchBatchConfig::BeamSearchPerTokenInfo) + - // beam_requestinfo_size * - // sizeof(BeamSearchBatchConfig:: - // BeamSearchPerRequestInfo); // more components will - // // be added here later - - // We always directly allocate memory for small speculative models - // gpu_mem_allocator.create_legion_instance(beam_search_reserve_inst, - // total_size); beam_topology_mask = static_cast( handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index fb978adfff..52fd64c606 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -257,6 +257,7 @@ void InferenceManager::init_operators_inference(FFModel *model) { ((ParallelOp *)op) ->create_input_partition_inference(*model, inputs, outputs); } + printf("init op %s\n", op->name); op->init_inference(*model, inputs, outputs); } } diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index e1b591c320..845a580c13 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -714,7 +714,8 @@ BeamSearchBatchConfig dfs_tree_inputs.erase(request.guid); } else { // Request not finished, pass verified_tokens to next iteration - + + std::cout << "parse to next iteration: " << "\n"; new_bc.request_completed[i] = false; new_bc.request_running[i] = true; @@ -752,6 +753,12 @@ BeamSearchBatchConfig new_bc.beamRequestsInfo[i].sub_request_num = 1; new_bc.sub_requests[i] = 1; + new_bc.topology_mask[i].allocated_tokens = request.tokens.size(); + + //assign new kv cache position + for(int j = 0; j < request.tokens.size(); j++){ + new_bc.topology_mask[i].real_token_pos[0][j] = j; + } // Token Info for (int j = 0; j < verified_tokens.size(); j++) { @@ -768,6 +775,8 @@ BeamSearchBatchConfig // Beam Token Info new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = 0; new_bc.num_tokens++; + std::cout << "num_gen ++ " << "\n"; + num_generation_tokens++; // Add verified token to request's token list request.tokens.push_back(token.first); @@ -776,6 +785,8 @@ 
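
The pointer arithmetic on handler.batch_config_metadata above has to agree, byte for byte, with the order in which load_tokens_task copies the host-side structs (tokensInfo, requestsInfo, topology_mask, beamTokenInfo, beamRequestsInfo). One way to keep that contract in a single place is an offset table like the sketch below (MetadataOffsets and make_offsets are illustrative helpers, not part of the patch):

    // Sketch: packed layout shared by the producer (load_tokens_task) and the
    // consumers (the attention metas); mirrors the sizeof sums used above.
    #include <cstddef>
    #include "flexflow/batch_config.h"

    struct MetadataOffsets {
      size_t tokens_info;
      size_t requests_info;
      size_t topology_mask;
      size_t beam_token_info;
      size_t beam_request_info;
    };

    inline MetadataOffsets make_offsets() {
      MetadataOffsets o{};
      o.tokens_info = 0;
      o.requests_info =
          o.tokens_info + sizeof(FlexFlow::BatchConfig::tokensInfo);
      o.topology_mask =
          o.requests_info + sizeof(FlexFlow::BatchConfig::requestsInfo);
      o.beam_token_info =
          o.topology_mask +
          sizeof(FlexFlow::BeamSearchBatchConfig::topology_mask);
      o.beam_request_info =
          o.beam_token_info +
          sizeof(FlexFlow::BeamSearchBatchConfig::beamTokenInfo);
      return o;
    }
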
BeamSearchBatchConfig break; } } + + std::string output = this->tokenizer_->Decode(request.tokens); // Unlike Huggingface, the sentencepiece C++ library automatically // removes the BOS token @@ -817,6 +828,7 @@ BeamSearchBatchConfig } new_bc.beamRequestsInfo[i].sub_request_num = 1; + new_bc.topology_mask[i].allocated_tokens = 0; new_bc.sub_requests[i] = 1; @@ -875,7 +887,11 @@ BeamSearchBatchConfig } new_bc.request_completed[i] = false; + new_bc.beamRequestsInfo[i].sub_request_num = 1; + printf("sub request num == 1, %d \n", + new_bc.beamRequestsInfo[i].beam_size); + new_bc.sub_requests[i] = 1; for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { @@ -892,6 +908,7 @@ BeamSearchBatchConfig new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = 0; new_bc.num_tokens++; } + new_bc.topology_mask[i].allocated_tokens = 0; // if (new_bc.requestsInfo[i].num_tokens_in_batch < // new_request.initial_len) { @@ -927,6 +944,8 @@ BeamSearchBatchConfig } new_bc.num_generation_tokens = num_generation_tokens; + std::cout << "prepare next batch init gen tokens: " << new_bc.num_generation_tokens << "\n"; + if (verbose) { std::cout << "prepare_next_batch_init OLD vs NEW batchconfigs below:" << std::endl; @@ -969,10 +988,10 @@ BeamSearchBatchConfig RequestManager::prepare_next_batch_beam(BeamSearchBatchConfig const &old_bc, BeamInferenceResult const &result) { std::lock_guard const lock(request_queue_mutex); - if (verbose) { + if (true) { std::cout << "\n############### prepare_next_batch_beam ###############\n"; } - if (verbose) { + if (true) { std::cout << "print all results" << "\n"; for (int i = 0; i < 40; i++) { @@ -980,6 +999,8 @@ BeamSearchBatchConfig } std::cout << "Current Beam Depth: " << old_bc.beamRequestsInfo[0].current_depth << "\n"; + std::cout << "Current sub request num: " + << old_bc.beamRequestsInfo[0].sub_request_num << "\n"; } // Step 1: Store result to the beam tree struct store_beam_metadata(old_bc, result); @@ -1049,6 +1070,7 @@ BeamSearchBatchConfig spec_infer_tree_width.size() > ssm_decoding_steps ? spec_infer_tree_width[ssm_decoding_steps] : 1; + new_bc.beamRequestsInfo[i].max_depth = old_bc.beamRequestsInfo[i].max_depth; @@ -1154,13 +1176,16 @@ BeamSearchBatchConfig spec_infer_tree_width.size() > ssm_decoding_steps ? 
spec_infer_tree_width[ssm_decoding_steps] : 1; + printf("beam size: %d, %d\n", + new_bc.beamRequestsInfo[i].beam_size, + ssm_decoding_steps); new_bc.beamRequestsInfo[i].max_depth = old_bc.beamRequestsInfo[i].max_depth; new_bc.sub_requests[i] = old_bc.sub_requests[i] * new_bc.beamRequestsInfo[i].beam_size; new_bc.beamRequestsInfo[i].sub_request_num = - old_bc.beamRequestsInfo[i].sub_request_num * - new_bc.beamRequestsInfo[i].beam_size; + old_bc.beamRequestsInfo[i].sub_request_num; + assert(new_bc.beamRequestsInfo[i].sub_request_num <= BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES); @@ -1230,6 +1255,16 @@ BeamSearchBatchConfig old_bc.print(); new_bc.print(); } + + if (true) { + std::cout << "print all resultsBBB" + << "\n"; + for (int i = 0; i < 40; i++) { + std::cout << result.token_ids[i] << ", "; + } + std::cout << "Current Beam DepthBBB: " + << old_bc.beamRequestsInfo[0].current_depth << "\n"; + } return new_bc; } @@ -1296,6 +1331,8 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( profiling_requests[request.guid].llm_decoding_steps += 1; if (request.status == Request::RUNNING) { + std::cout << "prepare next batch running: pending\n" + << "\n"; new_bc.request_running[i] = true; std::cout << "[Verify] Request " << request.guid << " is running" << std::endl; @@ -1401,6 +1438,8 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( } } else if (request.status == Request::PENDING) { + std::cout << "prepare next batch verify: pending\n" + << "\n"; new_bc.request_running[i] = false; if (verbose) { std::cout << "[Verify] Request " << request.guid @@ -1450,6 +1489,9 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( << std::endl; if (request.llm_cache_size < request.initial_len) { + std::cout << "Initialization (prompt) phase: " + << new_bc.requestsInfo[i].num_tokens_in_batch << ", " + << old_batches.at(0).beamRequestsInfo[i].beam_size << "\n"; // Initialization (prompt) phase for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { new_bc.tokensInfo[new_bc.num_tokens].request_index = i; @@ -1457,7 +1499,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( request.tokens[request.llm_cache_size + j]; new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = request.llm_cache_size + j; - + std::cout << "load prompt tokens: " << j << ": " << new_bc.tokensInfo[new_bc.num_tokens].token_id << "\n"; new_bc.num_tokens++; } @@ -1483,6 +1525,8 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( } } else { // launch the request into running phase after loading all prompt if (get_max_tokens_per_batch() - new_bc.num_tokens > 0) { + std::cout << "Initialization running phase: " + << new_bc.requestsInfo[i].num_tokens_in_batch << "\n"; request.status = Request::RUNNING; new_bc.request_running[i] = true; @@ -1521,7 +1565,7 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, auto start_depth = old_bc.tokensInfo[0].abs_depth_in_request; int result_index = 0; - if (verbose) { + if (true) { std::cout << "Store total of " << old_bc.num_tokens << " tokens in the current batch.\n"; } @@ -1535,7 +1579,8 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, int beam_size = old_bc.beamRequestsInfo[index].beam_size; // int leaf_node_num = old_bc.sub_requests[index]; - int leaf_node_num = old_bc.beamRequestsInfo[i].sub_request_num; + int leaf_node_num = + old_bc.beamRequestsInfo[index].sub_request_num * beam_size; int depth = 
old_bc.beamRequestsInfo[index].current_depth; // Each token yields (beam_width) results @@ -1545,18 +1590,26 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, // index result_index += (old_bc.tokensInfo[i - 1].abs_depth_in_request - start_depth) * - leaf_node_num; + beam_size; - if (verbose) { + // result_index += old_bc.topology_mask[index].allocated_tokens; + + if (true) { std::cout << "i = " << i << ", result index = " << result_index - << ", value: " << result.token_ids[result_index] << "\n"; + << ", value: " << result.token_ids[result_index] + << ", leaf node num: " << leaf_node_num << ", depth" << depth + << ", beam size: " << beam_size << "\n"; } Request &request = all_requests[old_bc.requestsInfo[index].request_guid]; + if (old_bc.requestsInfo[index].num_tokens_in_batch == 0) { + continue; + } + if (depth == 1) { // store the last input into the tree; - if (verbose) { + if (true) { std::cout << "try to store the input" << "\n"; } @@ -1566,7 +1619,7 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, request.beam_trees.at(old_bc.model_id).treeLayers[0].probs[0] = 1; request.beam_trees.at(old_bc.model_id).treeLayers[0].parent_ids[0] = -1; - if (verbose) { + if (true) { std::cout << "Store the previous last token to the tree root: " << request.tokens.back() << "\n"; } @@ -1583,7 +1636,7 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, .treeLayers[depth] .parent_ids[beam_id] = result.parent_id[result_index]; - if (verbose) { + if (true) { std::cout << "tree value: " << depth << "token: " << request.beam_trees.at(old_bc.model_id) .treeLayers[depth] @@ -1592,7 +1645,6 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, } result_index += 1; } - // update the guid and start_depth for current request if (i < old_bc.num_tokens) { guid = old_bc.requestsInfo[index].request_guid; @@ -1600,6 +1652,10 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, } } } + + if (old_bc.num_tokens != 10) { + assert(false); + } } // for updating the beam search metadata in requests in incremental phase @@ -1638,7 +1694,6 @@ void RequestManager::update_beam_metadata(BeamSearchBatchConfig &new_bc, new_bc.beamRequestsInfo[request_index].tokens[j] = tree.treeLayers[depth].tokens[j]; - // new_bc.topology_mask[request_index].real_token_pos[j] = } assert(false); @@ -1784,7 +1839,7 @@ std::vector> // depth) pairs for (auto const &pair : inputSerializedTree) { oss << " " << pair.second << ":" << pair.first; - // log_req_mgr.print("(%d, %d)", pair.first, pair.second); + log_req_mgr.print("(%d, %d)", pair.first, pair.second); } log_req_mgr.print("Input tree:%s", oss.str().c_str()); } @@ -1793,7 +1848,7 @@ std::vector> // outputSerializedTree is an array of (token id, depth + 1) pairs std::ostringstream oss; for (auto const &pair : outputSerializedTree) { - // log_req_mgr.print("(%d, %d)", pair.first, pair.second); + log_req_mgr.print("(%d, %d)", pair.first, pair.second); oss << " " << pair.second << ":" << pair.first; } log_req_mgr.print("Output tree:%s", oss.str().c_str()); @@ -1847,7 +1902,7 @@ std::vector> // log_req_mgr.print("========Verified============"); std::ostringstream oss; for (auto const &pair : verifiedTree) { - // log_req_mgr.print("(%d, %d)", pair.first, pair.second); + log_req_mgr.print("(%d, %d)", pair.first, pair.second); oss << " " << pair.second << ":" << pair.first; } log_req_mgr.print("Verified:%s", oss.str().c_str()); diff --git 
a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu index f4500d152d..b76c5c326e 100644 --- a/src/runtime/request_manager.cu +++ b/src/runtime/request_manager.cu @@ -91,17 +91,17 @@ void RequestManager::load_tokens_task( sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + sizeof(BeamSearchBatchConfig::topology_mask), - &(beam_batch_config->beamRequestsInfo), - sizeof(BeamSearchBatchConfig::beamRequestsInfo), + &(beam_batch_config->beamTokenInfo), + sizeof(BeamSearchBatchConfig::beamTokenInfo), cudaMemcpyHostToDevice, stream); cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + sizeof(BeamSearchBatchConfig::topology_mask) + - sizeof(BeamSearchBatchConfig::beamRequestsInfo), - &(beam_batch_config->beamTokenInfo), - sizeof(BeamSearchBatchConfig::beamTokenInfo), + sizeof(BeamSearchBatchConfig::beamTokenInfo), + &(beam_batch_config->beamRequestsInfo), + sizeof(BeamSearchBatchConfig::beamRequestsInfo), cudaMemcpyHostToDevice, stream); } From 617a29fdda4e79d0d9c7bbcc1455ed447c42988f Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Tue, 26 Dec 2023 13:43:49 -0500 Subject: [PATCH 03/30] fix speculative --- .../specinfer_inc_multihead_self_attention.cu | 42 ++++--- src/runtime/request_manager.cc | 107 +++++++++++++----- 2 files changed, 109 insertions(+), 40 deletions(-) diff --git a/src/ops/specinfer_inc_multihead_self_attention.cu b/src/ops/specinfer_inc_multihead_self_attention.cu index 9d6f70d5ba..63cd90f44f 100644 --- a/src/ops/specinfer_inc_multihead_self_attention.cu +++ b/src/ops/specinfer_inc_multihead_self_attention.cu @@ -134,11 +134,20 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( ii * THREADS_PER_KEY * K_VEC_SIZE); } - if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { - printf("cacheposssss %d, %d\n", tree_branch_num, topology.real_token_pos[0][0]); - printf("cacheposssss %d, %d\n", tree_branch_num, topology.real_token_pos[0][1]); - printf("cacheposssss %d, %d\n", tree_branch_num, topology.real_token_pos[0][2]); - printf("cacheposssss %d, %d\n", tree_branch_num, topology.real_token_pos[0][10]); + if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && sub_req_idx == 0) { + printf("cacheposssssA %d, %d\n", tree_branch_num, topology.real_token_pos[0][0]); + printf("cacheposssssB %d, %d\n", tree_branch_num, topology.real_token_pos[0][1]); + printf("cacheposssssC %d, %d\n", tree_branch_num, topology.real_token_pos[0][2]); + printf("cacheposssssD %d, %d\n", tree_branch_num, topology.real_token_pos[0][11]); + printf("cacheposssssD %d, %d\n", tree_branch_num, topology.real_token_pos[0][12]); + printf("cacheposssssD %d, %d\n", tree_branch_num, topology.real_token_pos[0][13]); + }else if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && sub_req_idx == 1) { + printf("cacheposssssE %d, %d\n", tree_branch_num, topology.real_token_pos[sub_req_idx][0]); + printf("cacheposssssF %d, %d\n", tree_branch_num, topology.real_token_pos[sub_req_idx][1]); + printf("cacheposssssG %d, %d\n", tree_branch_num, topology.real_token_pos[sub_req_idx][2]); + printf("cacheposssssH %d, %d\n", tree_branch_num, topology.real_token_pos[sub_req_idx][11]); + printf("cacheposssssH %d, %d\n", tree_branch_num, topology.real_token_pos[sub_req_idx][12]); + printf("cacheposssssH %d, %d\n", tree_branch_num, topology.real_token_pos[sub_req_idx][13]); } __syncthreads(); for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { @@ -289,7 +298,7 @@ __global__ void 
compute_specinfer_attention_kernel_generation_kernel( // Output the final values. if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { convert_from_float( - *reinterpret_cast(output_ptr + request_idx * hidden_size + + *reinterpret_cast(output_ptr + (request_idx + sub_req_idx) * hidden_size + head_idx * per_head_size + vi), out); } @@ -332,7 +341,7 @@ __global__ void specinfer_store_kv_cache( int const beam_size = beamRequestInfos[req_id].sub_request_num; - int real_idx = tok_id - first_token_in_req + allocated_tokens; + int real_idx = tok_id - first_token_in_req + allocated_tokens + sub_req_id; if (i == 0) { printf("ffasdasds%d, %d, %d, %d, %d, %d\n", @@ -343,10 +352,15 @@ __global__ void specinfer_store_kv_cache( first_token_in_req, real_idx); } - // }else if(i == hidden_size * 2){ - // printf("ffasdasdskkkk%d, %d, %d\n", allocated_tokens, tok_id, - // sub_req_id); - // } + else if(i == hidden_size * 2){ + printf("hshddhdhdsdaww%d, %d, %d, %d, %d, %d\n", + beamTokenInfos[0].sub_request_index, + allocated_tokens, + sub_req_id, + tok_id, + first_token_in_req, + real_idx); + } @@ -547,7 +561,7 @@ void compute_attention_kernel_prompt( // To get B, skip over K entries from previous requests (all heads + // padding) - print_tensor((float*)A, 32, "A"); + // print_tensor((float*)A, 32, "A"); std::cout << "meta: " << num_new_tokens << ", " << total_tokens << "\n"; DT const *B = static_cast
(m->keyCache) + (i * bc->MAX_SPECULATIVE_TREE_BRANCHES) * kt_req_block_size; @@ -583,7 +597,7 @@ void compute_attention_kernel_prompt( m->num_q_heads, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - print_tensor((float*)C, 32, "C"); + // print_tensor((float*)C, 32, "C"); // add alibi position bias to qk production // add alibi position bias to qk production if (*m->position_bias) { @@ -669,7 +683,7 @@ void compute_attention_kernel_prompt( // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous // requests - print_tensor((float*)C_softmax, 32, "C_softmax"); + // print_tensor((float*)C_softmax, 32, "C_softmax"); C = static_cast
(m->attn_heads) + (tokens_previous_requests + bc->num_generation_tokens) * m->num_q_heads * m->vProjSize; diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 845a580c13..775280e2cf 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -714,8 +714,9 @@ BeamSearchBatchConfig dfs_tree_inputs.erase(request.guid); } else { // Request not finished, pass verified_tokens to next iteration - - std::cout << "parse to next iteration: " << "\n"; + + std::cout << "parse to next iteration: " + << "\n"; new_bc.request_completed[i] = false; new_bc.request_running[i] = true; @@ -755,8 +756,8 @@ BeamSearchBatchConfig new_bc.sub_requests[i] = 1; new_bc.topology_mask[i].allocated_tokens = request.tokens.size(); - //assign new kv cache position - for(int j = 0; j < request.tokens.size(); j++){ + // assign new kv cache position + for (int j = 0; j < request.tokens.size(); j++) { new_bc.topology_mask[i].real_token_pos[0][j] = j; } @@ -775,7 +776,8 @@ BeamSearchBatchConfig // Beam Token Info new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = 0; new_bc.num_tokens++; - std::cout << "num_gen ++ " << "\n"; + std::cout << "num_gen ++ " + << "\n"; num_generation_tokens++; // Add verified token to request's token list @@ -785,7 +787,6 @@ BeamSearchBatchConfig break; } } - std::string output = this->tokenizer_->Decode(request.tokens); // Unlike Huggingface, the sentencepiece C++ library automatically @@ -944,7 +945,8 @@ BeamSearchBatchConfig } new_bc.num_generation_tokens = num_generation_tokens; - std::cout << "prepare next batch init gen tokens: " << new_bc.num_generation_tokens << "\n"; + std::cout << "prepare next batch init gen tokens: " + << new_bc.num_generation_tokens << "\n"; if (verbose) { std::cout << "prepare_next_batch_init OLD vs NEW batchconfigs below:" @@ -1078,7 +1080,14 @@ BeamSearchBatchConfig old_bc.sub_requests[i] * new_bc.beamRequestsInfo[i].beam_size; new_bc.beamRequestsInfo[i].sub_request_num = old_bc.beamRequestsInfo[i].sub_request_num * - new_bc.beamRequestsInfo[i].beam_size; + old_bc.beamRequestsInfo[i].beam_size; + + std::cout << "oldbc : " << old_bc.beamRequestsInfo[i].sub_request_num + << ", " << old_bc.beamRequestsInfo[i].beam_size << "\n"; + + // if (old_bc.beamRequestsInfo[i].current_depth == 3) { + // assert(false); + // } assert(new_bc.beamRequestsInfo[i].sub_request_num <= BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES); @@ -1090,6 +1099,10 @@ BeamSearchBatchConfig // do the slot exchange to minimize the cache exchange in kernel. 
update_beam_metadata( new_bc, old_bc, request.beam_trees.at(old_bc.model_id), i); + + new_bc.topology_mask[i].allocated_tokens = + old_bc.topology_mask[i].allocated_tokens + + old_bc.beamRequestsInfo[i].sub_request_num; } else { assert(false && "Request should not be pending in beam search phase"); } @@ -1101,6 +1114,7 @@ BeamSearchBatchConfig request.tokens.size()) { // Incremental phase if (request.status == Request::RUNNING) { + // todo check it new_bc.requestsInfo[i].num_tokens_in_batch = 1; } else { assert(false && "Request should be done"); @@ -1122,7 +1136,31 @@ BeamSearchBatchConfig << std::endl; } + // for (int j = 0; j < request.tokens.size(); j++) { + // new_bc.topology_mask[i].real_token_pos[0][j] = j; + // } + // register more tokens due to the beam width + std::cout << "register more tokens: " + << new_bc.beamRequestsInfo[i].sub_request_num << ", " + << new_bc.requestsInfo[i].num_tokens_in_batch << ", " + << new_bc.topology_mask[i].allocated_tokens << "\n"; + + // copy meta data and replicate + int replicate_num = new_bc.beamRequestsInfo[i].sub_request_num / + old_bc.beamRequestsInfo[i].sub_request_num; + + for (int j = 0; j < old_bc.beamRequestsInfo[i].sub_request_num; j++) { + int old_idx = j; + for (int k = 0; k < replicate_num; k++) { + int new_idx = j * replicate_num + k; + std::cout << "copy from " << old_idx << "to: " << new_idx << "\n"; + memcpy(new_bc.topology_mask[i].real_token_pos[new_idx], + old_bc.topology_mask[i].real_token_pos[old_idx], + sizeof(int) * BatchConfig::MAX_NUM_TOKENS); + } + } + for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j; for (int k = 0; k < new_bc.beamRequestsInfo[i].sub_request_num; k++) { @@ -1135,6 +1173,15 @@ BeamSearchBatchConfig new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = k; new_bc.num_tokens++; + + // width first + new_bc.topology_mask[i].real_token_pos[k][depth] = + new_bc.topology_mask[i].allocated_tokens + num_generation_tokens; + + std::cout << "topology: sub request: " << k << ", " + << ", " << depth << ", " + << new_bc.topology_mask[i].real_token_pos[k][depth] << "\n"; + num_generation_tokens++; } } } @@ -1331,6 +1378,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( profiling_requests[request.guid].llm_decoding_steps += 1; if (request.status == Request::RUNNING) { + std::cout << "prepare next batch running: pending\n" << "\n"; new_bc.request_running[i] = true; @@ -1415,11 +1463,12 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.requestsInfo[i].first_token_depth_in_request = request.tokens.size() - 1; - + + std::cout << "prepare next batch verify: " << dfs_tree_inputs.size() << "\n"; // Add Tokens from the DFS Tree to the next batch for (int j = 1; j < dfs_tree_inputs.size(); j++) { auto token = dfs_tree_inputs.at(j); - if (verbose) { + if (true) { std::cout << "[" << j << "] Token: " << token.first << ", Depth:" << token.second << std::endl; } @@ -1436,6 +1485,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( break; } } + assert(false); } else if (request.status == Request::PENDING) { std::cout << "prepare next batch verify: pending\n" @@ -1499,7 +1549,8 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( request.tokens[request.llm_cache_size + j]; new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = request.llm_cache_size + j; - std::cout << "load prompt tokens: " << j << ": " << new_bc.tokensInfo[new_bc.num_tokens].token_id << "\n"; 
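
The replication loop above can be read as: every existing branch becomes the parent of replicate_num children, and each child starts from a copy of the parent's token-position row before the new layer's slots are filled in. A standalone sketch of that copy, with kMaxTokens standing in for BatchConfig::MAX_NUM_TOKENS (illustrative, not the patch's in-place code):

    #include <cstring>

    constexpr int kMaxTokens = 1024; // stands in for BatchConfig::MAX_NUM_TOKENS

    void replicate_topology_rows(int const (*old_rows)[kMaxTokens],
                                 int (*new_rows)[kMaxTokens],
                                 int old_branches,
                                 int replicate_num) {
      // Old branch j becomes the parent of children
      // j*replicate_num .. j*replicate_num + replicate_num - 1; each child
      // inherits the parent's row of real token positions.
      for (int j = 0; j < old_branches; j++) {
        for (int k = 0; k < replicate_num; k++) {
          std::memcpy(new_rows[j * replicate_num + k],
                      old_rows[j],
                      sizeof(int) * kMaxTokens);
        }
      }
    }
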
+ std::cout << "load prompt tokens: " << j << ": " + << new_bc.tokensInfo[new_bc.num_tokens].token_id << "\n"; new_bc.num_tokens++; } @@ -1625,7 +1676,10 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, } } + std::cout << "leaffffff: " << leaf_node_num << "\n"; + for (int beam_id = 0; beam_id < leaf_node_num; beam_id++) { + request.beam_trees.at(old_bc.model_id) .treeLayers[depth] .tokens[beam_id] = result.token_ids[result_index]; @@ -1635,14 +1689,19 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, request.beam_trees.at(old_bc.model_id) .treeLayers[depth] .parent_ids[beam_id] = result.parent_id[result_index]; - - if (true) { - std::cout << "tree value: " << depth << "token: " - << request.beam_trees.at(old_bc.model_id) - .treeLayers[depth] - .tokens[beam_id] - << "result tokens: " << result.token_ids[result_index]; - } + std::cout << "??????? beam id: " << beam_id << ", token: " + << request.beam_trees.at(old_bc.model_id) + .treeLayers[depth] + .tokens[beam_id] + << "\n"; + + // if (true) { + // std::cout << "tree value: " << depth << "token: " + // << request.beam_trees.at(old_bc.model_id) + // .treeLayers[depth] + // .tokens[beam_id] + // << "result tokens: " << result.token_ids[result_index]; + // } result_index += 1; } // update the guid and start_depth for current request @@ -1652,10 +1711,6 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, } } } - - if (old_bc.num_tokens != 10) { - assert(false); - } } // for updating the beam search metadata in requests in incremental phase @@ -1672,7 +1727,7 @@ void RequestManager::update_beam_metadata(BeamSearchBatchConfig &new_bc, int beam_size = new_bc.beamRequestsInfo[request_index].beam_size; // int leaf_node_num = old_bc.sub_requests[request_index]; - int leaf_node_num = old_bc.beamRequestsInfo[request_index].sub_request_num; + int leaf_node_num = new_bc.beamRequestsInfo[request_index].sub_request_num; if (new_bc.beamRequestsInfo[request_index].current_depth == 1) { // TODO: check if this is correct @@ -1693,9 +1748,9 @@ void RequestManager::update_beam_metadata(BeamSearchBatchConfig &new_bc, tree.treeLayers[depth].probs[j]; new_bc.beamRequestsInfo[request_index].tokens[j] = tree.treeLayers[depth].tokens[j]; - + std::cout << "token: " << j << ": " + << new_bc.beamRequestsInfo[request_index].tokens[j] << "\n"; } - assert(false); // std::set parents; // std::set childs; From b5f9d5d2d5eea50951a466d339bdc47910e69e07 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Thu, 28 Dec 2023 01:57:39 -0500 Subject: [PATCH 04/30] bitmap+tree verify --- include/flexflow/batch_config.h | 20 +- include/flexflow/config.h | 3 +- .../inc_multihead_self_attention_utils.cuh | 2 +- .../specinfer_inc_multihead_self_attention.h | 1 + .../ops/tree_inc_multihead_self_attention.h | 1 + include/flexflow/request_manager.h | 10 + src/ops/argmax.cc | 2 + src/ops/inc_multihead_self_attention.cu | 8 +- src/ops/kernels/embedding_kernels.cu | 1 + .../specinfer_inc_multihead_self_attention.cu | 202 ++++++++---- src/ops/tree_inc_multihead_self_attention.cu | 197 ++++++++---- src/runtime/request_manager.cc | 291 ++++++++++++++---- src/runtime/request_manager.cu | 12 + 13 files changed, 562 insertions(+), 188 deletions(-) diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index dd947bbd85..db5d4a8e48 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -56,6 +56,7 @@ class BatchConfig { // across workers static int const 
MAX_NUM_REQUESTS = 64; static int const MAX_NUM_TOKENS = 1024; + static int const MAX_SPEC_TREE_TOKEN_NUM = 64; // Set by update int num_tokens; @@ -75,6 +76,24 @@ class BatchConfig { int request_index; TokenId token_id; }; + + struct BitMask { + unsigned long long mask[MAX_SPEC_TREE_TOKEN_NUM] = {0}; + + // how many tokens before the tree, every sub requests need this part of + // cache + int non_tree_cache_size; + + // current tree size + int tree_size; + + int this_layer_size; + + // input length-> prompt/root + int prompt_size; + }; + + BitMask causalMask[MAX_NUM_REQUESTS]; PerRequestInfo requestsInfo[MAX_NUM_REQUESTS]; PerTokenInfo tokensInfo[MAX_NUM_TOKENS]; @@ -154,7 +173,6 @@ class BeamSearchBatchConfig : public BatchConfig { int allocated_tokens; }; - BeamSearchPerRequestInfo beamRequestsInfo[MAX_NUM_REQUESTS]; BeamSearchPerTokenInfo beamTokenInfo[MAX_NUM_TOKENS * MAX_BEAM_WIDTH]; SpecInferTopology topology_mask[MAX_NUM_REQUESTS]; diff --git a/include/flexflow/config.h b/include/flexflow/config.h index 321d14961b..fe261dfb48 100644 --- a/include/flexflow/config.h +++ b/include/flexflow/config.h @@ -83,7 +83,8 @@ struct FFHandler { sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + sizeof(BeamSearchBatchConfig::topology_mask) + sizeof(BeamSearchBatchConfig::beamTokenInfo) + - sizeof(BeamSearchBatchConfig::beamRequestsInfo); + sizeof(BeamSearchBatchConfig::beamRequestsInfo) + + sizeof(BatchConfig::causalMask); void *offload_reserve_space; size_t offload_reserve_space_size; DataType quantization_type; diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh index c128c1a126..0c065b6b0e 100644 --- a/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh +++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh @@ -456,7 +456,7 @@ inline size_t smem_size_in_bytes(int hidden_size_per_head, int threads_per_block) { // The amount of shared memory needed to store the Q*K^T values in float. - size_t qk_sz = div_up(max_sequence_length + 1, 4) * 16; + size_t qk_sz = div_up(1000 + 1, 4) * 16; size_t logits_sz = qk_sz; // The total size needed during softmax. 
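
For reference, the convention implied by how the attention kernels later test this mask is one 64-bit word per token in the speculative tree, where bit q of mask[k] says whether tree query q may attend to tree token k; every position before non_tree_cache_size is visible to all queries. A host-side sketch under that assumption (TreeMask and the helpers are illustrative names; note the 1ULL shifts, which matter once indices reach 32):

    #include <cstdint>

    constexpr int kMaxTreeTokens = 64; // mirrors MAX_SPEC_TREE_TOKEN_NUM

    struct TreeMask {
      uint64_t mask[kMaxTreeTokens] = {0};
      int non_tree_cache_size = 0;
      int tree_size = 0;
    };

    // Prompt phase: token q can attend to every prompt token k <= q.
    inline void init_prompt_mask(TreeMask &m, int prompt_len) {
      m.tree_size = prompt_len;
      for (int k = 0; k < prompt_len; k++) {
        for (int q = k; q < prompt_len; q++) {
          m.mask[k] |= (1ULL << q);
        }
      }
    }

    inline bool visible(TreeMask const &m, int query_in_tree, int key_pos) {
      if (key_pos < m.non_tree_cache_size) {
        return true; // the committed (non-tree) cache is visible to every query
      }
      int key_in_tree = key_pos - m.non_tree_cache_size;
      return (m.mask[key_in_tree] >> query_in_tree) & 1ULL;
    }
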
diff --git a/include/flexflow/ops/specinfer_inc_multihead_self_attention.h b/include/flexflow/ops/specinfer_inc_multihead_self_attention.h index 6e5dc73b5c..eb1b2882c3 100644 --- a/include/flexflow/ops/specinfer_inc_multihead_self_attention.h +++ b/include/flexflow/ops/specinfer_inc_multihead_self_attention.h @@ -143,6 +143,7 @@ class SpecInferIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionM BeamSearchBatchConfig::BeamSearchPerTokenInfo *beam_token_infos; BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos; BeamSearchBatchConfig::SpecInferTopology *beam_topology_mask; + BatchConfig::BitMask *causalMask; }; }; // namespace FlexFlow diff --git a/include/flexflow/ops/tree_inc_multihead_self_attention.h b/include/flexflow/ops/tree_inc_multihead_self_attention.h index 6e2da19ce9..d160da4a72 100644 --- a/include/flexflow/ops/tree_inc_multihead_self_attention.h +++ b/include/flexflow/ops/tree_inc_multihead_self_attention.h @@ -147,6 +147,7 @@ class TreeIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta { int num_active_tokens; Realm::RegionInstance committed_token_reserve_inst; TreeVerifyBatchConfig::CommittedTokensInfo *committed_token_infos; + BatchConfig::BitMask *causalMask; }; }; // namespace FlexFlow diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index e67888d2d6..dc1939c74b 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -110,6 +110,16 @@ class RequestManager { int eos_token_id, std::string const &path); void register_output_filepath(std::string const &); + void initBitMask(BatchConfig::BitMask &bitmask, int initLength); + void appendBitMask(BatchConfig::BitMask &bitmask, + int newNodes, + int preBeamSize, + int old_sub_num, + BeamTree const tree, + int currentDepth); + void updateBitMask(BatchConfig::BitMask &bitmask, + int initLength, + int non_tree_size); FFModel *get_model(int model_id); diff --git a/src/ops/argmax.cc b/src/ops/argmax.cc index f336c843e8..0344c707fc 100644 --- a/src/ops/argmax.cc +++ b/src/ops/argmax.cc @@ -398,6 +398,8 @@ InferenceResult ArgMax::save_inference_tensors_to_file( m, shard_id, bc, {}, {}, {input, indices}); } + + print_tensor(indices.get_int32_ptr(), 32, "tree attn output"); download_tensor( indices.get_int32_ptr(), ir.token_ids, batch_size); return ir; diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 4c184acb3c..a05dbbf919 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -1364,8 +1364,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( vProjSize * num_q_heads); size_t key_cache_size = 0, value_cache_size = 0; switch (infer_mode) { - case INC_DECODING_MODE: - case TREE_VERIFY_MODE: { + case INC_DECODING_MODE: { key_cache_size = num_q_heads * kProjSize * BatchConfig::max_requests_per_batch() * BatchConfig::max_sequence_length(); @@ -1374,7 +1373,8 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( BatchConfig::max_sequence_length(); break; } - case BEAM_SEARCH_MODE: { + case BEAM_SEARCH_MODE: + case TREE_VERIFY_MODE: { // a K-ary tree max node is (k^n - 1) / 2 key_cache_size = num_q_heads * kProjSize * BeamSearchBatchConfig::max_requests_per_batch() * @@ -1402,7 +1402,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( 2 * qk_prod_size + attn_heads_size) * size_of_dt + complex_size * sizeof(cuFloatComplex); // more components will - // be added here later + // be added here later 
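
As a rough sense of why the speculative modes above reserve a larger key/value cache than incremental decoding: each request may need one cache slot per token per speculative branch. The helper below is only an illustration of that sizing, with made-up example numbers; the exact expression the patch uses lives in IncMultiHeadSelfAttentionMeta.

    #include <cstddef>
    #include <cstdio>

    size_t kv_cache_elems(size_t num_heads,
                          size_t proj_size,
                          size_t max_requests,
                          size_t max_seq_len,
                          size_t branches) {
      return num_heads * proj_size * max_requests * max_seq_len * branches;
    }

    int main() {
      // Example numbers for a small draft model (assumed, not from the patch):
      size_t elems = kv_cache_elems(/*num_heads=*/32, /*proj_size=*/128,
                                    /*max_requests=*/64, /*max_seq_len=*/256,
                                    /*branches=*/3);
      printf("key cache: %zu elements (%zu MB at fp16)\n",
             elems, elems * 2 / (1024 * 1024));
      return 0;
    }
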
if (offload) { // assert that we have enough reserved work space left size_t totalSharedSize = diff --git a/src/ops/kernels/embedding_kernels.cu b/src/ops/kernels/embedding_kernels.cu index 22d8161ff1..91f5d60e85 100644 --- a/src/ops/kernels/embedding_kernels.cu +++ b/src/ops/kernels/embedding_kernels.cu @@ -118,6 +118,7 @@ void forward_kernel_wrapper(EmbeddingMeta const *m, // print_tensor(output_ptr, output_domain.get_volume(), // "[Embedding:forward:output]"); } + print_tensor(input.get_int32_ptr(), 32, "embeddinginput"); } /*static*/ diff --git a/src/ops/specinfer_inc_multihead_self_attention.cu b/src/ops/specinfer_inc_multihead_self_attention.cu index 63cd90f44f..e8ac1d980c 100644 --- a/src/ops/specinfer_inc_multihead_self_attention.cu +++ b/src/ops/specinfer_inc_multihead_self_attention.cu @@ -51,6 +51,7 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( BatchConfig::PerRequestInfo *request_infos, BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos, BeamSearchBatchConfig::SpecInferTopology *topology_mask, + BatchConfig::BitMask *causalMask, int max_tree_branches) { // q, k @@ -75,11 +76,18 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( BeamSearchBatchConfig::SpecInferTopology topology = topology_mask[request_idx]; + BatchConfig::BitMask bitmask = causalMask[request_idx]; int const first_step = 0; int const tlength = request_infos[request_idx].first_token_depth_in_request + request_infos[request_idx].num_tokens_in_batch; + + if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + printf("specinfer attn fused kernel %lld\n", bitmask.mask[1]); + } + + int const totalCacheSize = bitmask.non_tree_cache_size + bitmask.tree_size; // int const qlength = request_infos[request_idx].num_tokens_in_batch; int const tree_branch_num = beam_request_infos[request_idx].sub_request_num; @@ -88,7 +96,8 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( int first_token_idx = 0; for (int r = 0; r < request_idx; r++) { - first_token_idx += request_infos[request_idx].num_tokens_in_batch; + // first_token_idx += request_infos[request_idx].num_tokens_in_batch; + first_token_idx += bitmask.this_layer_size; } // shared memory objects @@ -124,7 +133,7 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( request_idx * max_seq_length * hidden_size * max_tree_branches + ki; int ti_end = - div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; + div_up(totalCacheSize - first_step, K_PER_WARP) * K_PER_WARP + first_step; for (int sub_req_idx = 0; sub_req_idx < tree_branch_num; sub_req_idx += 1) { #pragma unroll @@ -134,21 +143,25 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( ii * THREADS_PER_KEY * K_VEC_SIZE); } - if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && sub_req_idx == 0) { - printf("cacheposssssA %d, %d\n", tree_branch_num, topology.real_token_pos[0][0]); - printf("cacheposssssB %d, %d\n", tree_branch_num, topology.real_token_pos[0][1]); - printf("cacheposssssC %d, %d\n", tree_branch_num, topology.real_token_pos[0][2]); - printf("cacheposssssD %d, %d\n", tree_branch_num, topology.real_token_pos[0][11]); - printf("cacheposssssD %d, %d\n", tree_branch_num, topology.real_token_pos[0][12]); - printf("cacheposssssD %d, %d\n", tree_branch_num, topology.real_token_pos[0][13]); - }else if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && sub_req_idx == 1) { - printf("cacheposssssE %d, %d\n", tree_branch_num, topology.real_token_pos[sub_req_idx][0]); - 
printf("cacheposssssF %d, %d\n", tree_branch_num, topology.real_token_pos[sub_req_idx][1]); - printf("cacheposssssG %d, %d\n", tree_branch_num, topology.real_token_pos[sub_req_idx][2]); - printf("cacheposssssH %d, %d\n", tree_branch_num, topology.real_token_pos[sub_req_idx][11]); - printf("cacheposssssH %d, %d\n", tree_branch_num, topology.real_token_pos[sub_req_idx][12]); - printf("cacheposssssH %d, %d\n", tree_branch_num, topology.real_token_pos[sub_req_idx][13]); - } + int const query_token = bitmask.tree_size - tree_branch_num + sub_req_idx; + + if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && sub_req_idx == 0) { + // printf("fuckmasksss %d, %d, %d, %d, %d\n", + // bitmask.prompt_size, + // bitmask.non_tree_cache_size, + // tree_branch_num, + // bitmask.tree_size, + // tlength); + // printf("cacheposssssB %d, %d\n", tree_branch_num, + // topology.real_token_pos[0][1]); + // printf("cacheposssssC %d, %d\n", tree_branch_num, + // topology.real_token_pos[0][2]); + // printf("cacheposssssD %d, %d\n", tree_branch_num, + // topology.real_token_pos[0][11]); printf("cacheposssssD %d, %d\n", + // tree_branch_num, topology.real_token_pos[0][12]); + // printf("cacheposssssD %d, %d\n", tree_branch_num, + // topology.real_token_pos[0][13]); + } __syncthreads(); for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { K_vec k[K_VECS_PER_THREAD]; @@ -156,22 +169,33 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { int jj = ii * THREADS_PER_KEY * K_VEC_SIZE; - if (ti < tlength) { + if (ti < totalCacheSize) { // find the real position of the cache; // depth: 0, 1, 2, 3, 4, 4, 5, 5 ,5, 5, - int const real_cache_idx = topology.real_token_pos[sub_req_idx][ti]; + // int const real_cache_idx = + // topology.real_token_pos[sub_req_idx][ti]; k[ii] = *reinterpret_cast( - k_cache_batch + real_cache_idx * hidden_size + - head_idx * per_head_size + jj); + k_cache_batch + ti_circ * hidden_size + head_idx * per_head_size + + jj); } } float qk = scale * Qk_dot::dot(q_vecs[ki_o], k); - if (ti < tlength && tidx % THREADS_PER_KEY == 0) { + if (ti < totalCacheSize && tidx % THREADS_PER_KEY == 0) { // todo add alobi here - bool const mask = ti_circ >= tlength; - if (mask) { - assert(false); + // bool const mask = ti_circ >= totalCacheSize; + bool const mask = (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << query_token)))); + + if (blockIdx.y == 0 && blockIdx.x == 0 && mask && sub_req_idx == 0) { + // printf("specinfer mask: ti:%d, %d, %d, %d, %lld\n", + // ti, + // totalCacheSize, + // ti - bitmask.non_tree_cache_size, + // query_token, + // bitmask.mask[ti - bitmask.non_tree_cache_size]); + // assert(false); } qk_max = mask ? qk_max : fmaxf(qk_max, qk); qk_smem[ti - first_step] = mask ? 0.f : qk; @@ -208,10 +232,14 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); float exp_sum = 0.f; - for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { - float logit = __expf(qk_smem[ti - first_step] - qk_max); + for (int ti = first_step + tidx; ti < totalCacheSize; + ti += THREADS_PER_BLOCK) { + bool const mask = (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << query_token)))); + float logit = mask ? 0.0f : __expf(qk_smem[ti - first_step] - qk_max); exp_sum += logit; - qk_smem[ti - first_step] = logit; + qk_smem[ti - first_step] = mask ? 0.0f : logit; } // Compute the sum. 
@@ -219,7 +247,8 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( // softmax float inv_sum = __fdividef(1.f, exp_sum + 1.e-6); - for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { + for (int ti = first_step + tidx; ti < totalCacheSize; + ti += THREADS_PER_BLOCK) { qk_smem[ti - first_step] *= inv_sum; } @@ -254,14 +283,17 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( // vi; if (Dh == Dh_MAX || vi < Dh) { - for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { + for (int ti = first_step + vo; ti < totalCacheSize; ti += V_PER_ITER) { // Load the values from the cache. int const ti_circ = ti % max_seq_length; - int const real_cache_idx = topology.real_token_pos[sub_req_idx][ti]; + // int const real_cache_idx = topology.real_token_pos[sub_req_idx][ti]; V_vec v = *reinterpret_cast( - v_cache_batch + real_cache_idx * hidden_size + - head_idx * per_head_size); - float logit = qk_smem[ti - first_step]; + v_cache_batch + ti_circ * hidden_size + head_idx * per_head_size); + + bool const mask = (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << query_token)))); + float logit = mask ? 0.0f : qk_smem[ti - first_step]; out = FlexFlow::fma(logit, cast_to_float(v), out); } } @@ -298,7 +330,8 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( // Output the final values. if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { convert_from_float( - *reinterpret_cast(output_ptr + (request_idx + sub_req_idx) * hidden_size + + *reinterpret_cast(output_ptr + + (request_idx + sub_req_idx) * hidden_size + head_idx * per_head_size + vi), out); } @@ -315,6 +348,7 @@ __global__ void specinfer_store_kv_cache( BeamSearchBatchConfig::BeamSearchPerTokenInfo *beamTokenInfos, BeamSearchBatchConfig::BeamSearchPerRequestInfo *beamRequestInfos, BeamSearchBatchConfig::SpecInferTopology *beam_topology_mask, + BatchConfig::BitMask *causalMask, int qProjSize, int kProjSize, int vProjSize, @@ -335,41 +369,57 @@ __global__ void specinfer_store_kv_cache( int const req_id = tokenInfos[token_idx].request_index; int const tok_id = tokenInfos[token_idx].abs_depth_in_request; - int const first_token_in_req = requestInfo[req_id].first_token_depth_in_request; + int const first_token_in_req = + requestInfo[req_id].first_token_depth_in_request; int const sub_req_id = beamTokenInfos[token_idx].sub_request_index; int const allocated_tokens = beam_topology_mask[req_id].allocated_tokens; + int const total_token = requestInfo[req_id].num_tokens_in_batch; + + BatchConfig::BitMask bitmask = causalMask[req_id]; + + int const sub_request_num = beamRequestInfos[req_id].sub_request_num; - int const beam_size = beamRequestInfos[req_id].sub_request_num; + int const tree_branch_num = beamRequestInfos[req_id].sub_request_num; + + // int const query_token = bitmask.non_tree_cache_size + bitmask.tree_size - + // tree_branch_num + sub_req_id + tok_id; + // bitmask.tree_size - tree_branch_num + sub_req_id; + + // if prompt token -> token id + // if tree token: + int const cache_idx = bitmask.non_tree_cache_size + bitmask.tree_size - + bitmask.this_layer_size + token_idx; int real_idx = tok_id - first_token_in_req + allocated_tokens + sub_req_id; - if (i == 0) { - printf("ffasdasds%d, %d, %d, %d, %d, %d\n", - beamTokenInfos[0].sub_request_index, - allocated_tokens, - sub_req_id, - tok_id, - first_token_in_req, - real_idx); - } - else if(i == hidden_size * 2){ - printf("hshddhdhdsdaww%d, %d, %d, %d, %d, %d\n", - 
beamTokenInfos[0].sub_request_index, - allocated_tokens, - sub_req_id, - tok_id, - first_token_in_req, - real_idx); - } - - + // if (i == 0) { + // printf("ffasdasds%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d\n", + // beamTokenInfos[0].sub_request_index, + // allocated_tokens, + // sub_req_id, + // tok_id, + // first_token_in_req, + // real_idx, + // cache_idx, + // bitmask.non_tree_cache_size, + // bitmask.tree_size, + // sub_request_num, + // token_idx ); + // } else if (i == hidden_size * 2) { + // printf("hshddhdhdsdaww%d, %d, %d, %d, %d, %d, %d\n", + // beamTokenInfos[0].sub_request_index, + // allocated_tokens, + // sub_req_id, + // tok_id, + // first_token_in_req, + // real_idx, + // cache_idx); + // } kCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + - (real_idx) * hidden_size + - offset] = kVal; + (cache_idx)*hidden_size + offset] = kVal; vCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + - (real_idx) * hidden_size + - offset] = vVal; + (cache_idx)*hidden_size + offset] = vVal; } } @@ -398,6 +448,7 @@ void update_kv_cache_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, m->beam_token_infos, m->beam_request_infos, m->beam_topology_mask, + m->causalMask, m->qProjSize, m->kProjSize, m->vProjSize, @@ -433,6 +484,7 @@ void update_kv_cache_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, m->request_infos, \ m->beam_request_infos, \ m->beam_topology_mask, \ + m->causalMask, \ BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES) template @@ -520,7 +572,7 @@ void compute_attention_kernel_prompt( for (int i = 0; i < bc->max_requests_per_batch(); i++) { if (bc->request_completed[i]) { continue; - } + } // else if (tokens_previous_requests < bc->num_generation_tokens) { // tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; // continue; @@ -728,6 +780,16 @@ void inference_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, DT const *bias_ptr, cudaStream_t stream) { // phase 1: Implement kernel to compute KQV for input tokens + + cudaMemcpyAsync(m->causalMask, + &(bc->causalMask), + bc->num_active_requests() * sizeof(BatchConfig::BitMask), + cudaMemcpyHostToDevice, + stream); + std::cout << "kernel bit mask: " << bc->causalMask[0].prompt_size << ", " + << bc->causalMask[0].non_tree_cache_size << ", " + << bc->causalMask[0].mask[0] << ", " << sizeof(BatchConfig::BitMask) + << "\n"; compute_qkv_kernel(m, bc, shard_id, @@ -830,6 +892,7 @@ void SpecInferIncMultiHeadSelfAttention::inference_kernel_wrapper( // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr, // acc_output.rect, "[Attention:forward:output]"); } + // print_tensor(output.get_float_ptr(), 32, "specinc output"); // if(bc->num_tokens == 1){ // print_tensor(input.get_float_ptr(), 32, "specinc input"); @@ -878,6 +941,11 @@ SpecInferIncMultiHeadSelfAttentionMeta::SpecInferIncMultiHeadSelfAttentionMeta( // allocate memory for the seqArray and reserve space { + size_t causal_mask_size = BatchConfig::MAX_NUM_REQUESTS; + size_t total_size = causal_mask_size * sizeof(BatchConfig::BitMask); + gpu_mem_allocator.create_legion_instance(beam_search_reserve_inst, + total_size); + beam_topology_mask = static_cast( handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + @@ -895,6 +963,16 @@ SpecInferIncMultiHeadSelfAttentionMeta::SpecInferIncMultiHeadSelfAttentionMeta( sizeof(BatchConfig::requestsInfo) + sizeof(BeamSearchBatchConfig::topology_mask) + sizeof(BeamSearchBatchConfig::beamTokenInfo)); + // causalMask = + // static_cast( + // 
handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + + // sizeof(BatchConfig::requestsInfo) + + // sizeof(BeamSearchBatchConfig::topology_mask) + + // sizeof(BeamSearchBatchConfig::beamTokenInfo)) + + // sizeof(BeamSearchBatchConfig::beamRequestsInfo); + + causalMask = gpu_mem_allocator.allocate_instance( + causal_mask_size); // beam_token_infos = // gpu_mem_allocator // .allocate_instance( diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index 1da56e383a..a3e3adcc30 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -53,6 +53,8 @@ __global__ void compute_attention_kernel_fused_kernel( BatchConfig::PerRequestInfo *request_infos, int num_heads, int num_requests, + int max_tree_branches, + BatchConfig::BitMask *causalMask, int qk_smem_sz) { // q, k @@ -81,6 +83,17 @@ __global__ void compute_attention_kernel_fused_kernel( request_infos[request_idx].num_tokens_in_batch; int const qlength = request_infos[request_idx].num_tokens_in_batch; + BatchConfig::BitMask bitmask = causalMask[request_idx]; + + // bitmask.mask[1] = 3; + if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + printf("tree attn fused kernel %d, %d, %d, %lld\n", + tlength, + qlength, + bitmask.non_tree_cache_size, + bitmask.mask[1]); + } + int first_token_idx = 0; for (int r = 0; r < request_idx; r++) { first_token_idx += request_infos[request_idx].num_tokens_in_batch; @@ -115,7 +128,8 @@ __global__ void compute_attention_kernel_fused_kernel( constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; DT const *k_cache_batch = - key_cache + request_idx * max_seq_length * hidden_size + ki; + key_cache + + request_idx * max_tree_branches * max_seq_length * hidden_size + ki; int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; @@ -127,10 +141,12 @@ __global__ void compute_attention_kernel_fused_kernel( q_ptr + (hidden_size * QKV_WEIGHT_NUM * qi) + ki + ii * THREADS_PER_KEY * K_VEC_SIZE); } + __syncthreads(); for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { K_vec k[K_VECS_PER_THREAD]; int const ti_circ = ti % max_seq_length; + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { int jj = ii * THREADS_PER_KEY * K_VEC_SIZE; if (ti < tlength) { @@ -142,22 +158,35 @@ __global__ void compute_attention_kernel_fused_kernel( float qk = scale * Qk_dot::dot(q_vecs[ki_o], k); if (ti < tlength && tidx % THREADS_PER_KEY == 0) { - bool const mask = ti_circ >= tlength; - if (mask) { - assert(false); + bool const mask = + (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); + + if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 0 && mask) { + printf("tree attn mask for first token %d, %lld, %d, %d\n", + ti, + bitmask.mask[ti - bitmask.non_tree_cache_size], + bitmask.non_tree_cache_size, + qi); } - int pos = ti * qlength + qi; - if (((pos / qlength) % tlength) > (pos % qlength + tlength - qlength)) { - qk = -FLT_MAX; - } qk_max = mask ? qk_max : fmaxf(qk_max, qk); - qk_smem[pos] = mask ? 0.f : qk; + if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 1 && !mask) { + printf("tree attn mask for second token %d, %lld, %d, %d, %.10f\n", + ti, + bitmask.mask[ti - bitmask.non_tree_cache_size], + bitmask.non_tree_cache_size, + qi, + qk); + } + qk_smem[ti - first_step] = mask ? 
0.0f : qk; } } + __syncthreads(); +#pragma unroll for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); } @@ -176,66 +205,97 @@ __global__ void compute_attention_kernel_fused_kernel( // The warps finalize the reduction. qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; - +#pragma unroll for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); } // Broadcast to all the threads in the warp. qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 0 && tidx == 0) { + printf("tree attn first token qk_max %f\n", + qk_max); + } - float exp_sum = 0.f; + float exp_sum = 0.f; for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { - float logit = __expf(qk_smem[ti * qlength + qi] - qk_max); + bool const mask = + (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); + float logit = mask ? 0.0f : __expf(qk_smem[ti - first_step] - qk_max); exp_sum += logit; - qk_smem[ti * qlength + qi] = logit; + qk_smem[ti - first_step] = mask ? 0.0f : logit; } // Compute the sum. exp_sum = block_sum(&red_smem[WARPS_PER_BLOCK], exp_sum); + if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { + printf("expsum %.10f\n", exp_sum); + } + // softmax float inv_sum = __fdividef(1.f, exp_sum + 1.e-6); - for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { - qk_smem[ti * qlength + qi] *= inv_sum; + qk_smem[ti - first_step] *= inv_sum; } __syncthreads(); - } + if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { + printf("softmax %.10f\n", qk_smem[0]); + } - // value projection - constexpr int V_VEC_SIZE = 16 / sizeof(DT); - // The value computed by this thread. - int vo = tidx / THREADS_PER_VALUE; - // The hidden dimensions computed by this particular thread. - int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; - constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + // value projection + constexpr int V_VEC_SIZE = 16 / sizeof(DT); + // A vector of V elements for the current timestep. + // using V_vec_k = typename V_vec_k_::Type; + // using V_vec_acum = typename V_vec_acum_fp32_::Type; - Out_sum out; - // The base pointer for the value in the cache buffer. - DT const *v_cache_batch = - value_cache + request_idx * max_seq_length * hidden_size + vi; + // The value computed by this thread. + int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; - for (int qi = 0; qi < qlength; qi++) { + Out_sum out; zero(out); - __syncthreads(); + + // The base pointer for the value in the cache buffer. + DT const *v_cache_batch = + value_cache + + request_idx * max_seq_length * hidden_size * max_tree_branches + vi; + // DT const *v_cache_batch = + // value_cache + + // (beam_request_idx * max_beam_width + beam_sub_request_idx) * + // max_seq_length * hidden_size + + // vi; + if (Dh == Dh_MAX || vi < Dh) { for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { // Load the values from the cache. 
int const ti_circ = ti % max_seq_length; - + // int const real_cache_idx = topology.real_token_pos[sub_req_idx][ti]; V_vec v = *reinterpret_cast( v_cache_batch + ti_circ * hidden_size + head_idx * per_head_size); - float logit = qk_smem[ti * qlength + qi]; + + bool const mask = + (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); + float logit = mask ? 0.0f : qk_smem[ti - first_step]; out = FlexFlow::fma(logit, cast_to_float(v), out); + } } - // Make sure we can start writing to shared memory. + // // Make sure we can start writing to shared memory. __syncthreads(); + if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { + printf("valueX %.10f\n", out.x); + } + // Run the final reduction amongst the different groups computing different // partial outputs. if (Dh == Dh_MAX || vi < Dh) { @@ -268,6 +328,11 @@ __global__ void compute_attention_kernel_fused_kernel( output_ptr + (first_token_idx + qi) * hidden_size + head_idx * per_head_size + vi), out); + if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { + printf("tree attn final value, %.9f, %.9f, %.9f, %.9f, %d, %d\n", + out.x, out.y, out.z, out.w, vi, (first_token_idx + qi) * hidden_size + + head_idx * per_head_size + vi); + } } } } @@ -380,7 +445,9 @@ __global__ void update_tree_branch_kv_cache_fused( int vProjSize, int num_new_tokens, int max_seq_len, - int hidden_size) { + int hidden_size, + int max_tree_branches, + int first_token_depth) { CUDA_KERNEL_LOOP(i, num_new_tokens * hidden_size) { int token_idx = i / hidden_size; @@ -393,10 +460,10 @@ __global__ void update_tree_branch_kv_cache_fused( int const req_id = tokenInfos[token_idx].request_index; int const tok_id = tokenInfos[token_idx].abs_depth_in_request; - kCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + - offset] = kVal; - vCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + - offset] = vVal; + kCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + + (token_idx + first_token_depth) * hidden_size + offset] = kVal; + vCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + + (token_idx + first_token_depth) * hidden_size + offset] = vVal; } } @@ -473,7 +540,6 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, } std::cout << "num_new_tokens: " << num_new_tokens << "\n"; - assert(false); int total_tokens_in_request = bc->tokensInfo[j].abs_depth_in_request + 1; assert(num_new_tokens >= 1 && total_tokens_in_request >= num_new_tokens); @@ -728,22 +794,11 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, THDS_PER_KEY, \ THDS_PER_VALUE> \ <<>>( \ - static_cast
<DT *>(m->devQKVProjArray), \ - static_cast<DT *>
(m->keyCache), \ - static_cast<DT *>
(m->valueCache), \ - output_ptr, \ - scale, \ - BatchConfig::max_sequence_length(), \ - BatchConfig::max_tokens_per_batch(), \ - m->qProjSize, \ - m->hidden_size, \ - m->request_infos, \ - m->num_q_heads, \ - bc->num_active_requests(), \ + static_cast<DT *>
(m->devQKVProjArray), static_cast<DT *>
(m->keyCache), static_cast<DT *>
(m->valueCache), output_ptr, scale, BatchConfig::max_sequence_length(), BatchConfig::max_tokens_per_batch(), m->qProjSize, m->hidden_size, m->request_infos, m->num_q_heads, bc->num_active_requests(), BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES, m->causalMask, \ smem_sz[0]) template -void compute_attention_kernel_fused(IncMultiHeadSelfAttentionMeta const *m, +void compute_attention_kernel_fused(TreeIncMultiHeadSelfAttentionMeta const *m, TreeVerifyBatchConfig const *bc, DT *output_ptr, cudaStream_t stream) { @@ -752,6 +807,12 @@ void compute_attention_kernel_fused(IncMultiHeadSelfAttentionMeta const *m, // update K-V cache int num_new_tokens = bc->num_active_tokens(); int parallelism = m->hidden_size * num_new_tokens; + printf("update KV cache %d, idx: %d\n", + num_new_tokens, + bc->requestsInfo[0].first_token_depth_in_request); + for (int i = 0; i < num_new_tokens; i++) { + printf("abs depth:%d\n", bc->tokensInfo[i].abs_depth_in_request); + } update_tree_branch_kv_cache_fused<<vProjSize, num_new_tokens, BatchConfig::max_sequence_length(), - m->hidden_size); + m->hidden_size, + BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES, + bc->requestsInfo[0].first_token_depth_in_request); dim3 grid(m->num_q_heads, bc->num_active_requests()); int const per_head_size = m->qProjSize; @@ -816,12 +879,19 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, // Note that m->num_active_tokens stores the number of active // tokens in the previous batch, which is needed for committing // keys/values to the key-value cache + std::cout << "tokens to be committed: " << bc->num_tokens_to_commit << "\n"; + cudaMemcpyAsync(m->committed_token_infos, &(bc->committed_tokens), bc->num_tokens_to_commit * sizeof(TreeVerifyBatchConfig::CommittedTokensInfo), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(m->causalMask, + &(bc->causalMask), + bc->num_active_requests() * sizeof(BatchConfig::BitMask), + cudaMemcpyHostToDevice, + stream); commit_tokens
(m, bc, stream); // After commit we update m->num_active_tokens to be the number of active @@ -948,6 +1018,20 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( cudaEventDestroy(t_start); cudaEventDestroy(t_end); } + + // print_tensor(output.get_float_ptr(), 32, "tree attn kernel"); + + // save_tensor( + // input.get_float_ptr(), + // 768 * bc->num_active_tokens(), + // "/home/xinhaoc/FlexFlow/inference/output/Newtreeinput.txt"); + // save_tensor( + // output.get_float_ptr(), + // 768 * bc->num_active_tokens(), + // "/home/xinhaoc/FlexFlow/inference/output/Newtreeoutput.txt"); + // std::cout << "new tokens: " << bc->num_active_tokens() << "\n"; + + // assert(bc->num_tokens_to_commit == 0); } TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( @@ -993,8 +1077,11 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( { int max_tokens_per_batch = BatchConfig::max_tokens_per_batch(); size_t committed_tokeninfo_size = max_tokens_per_batch; + size_t causal_mask_size = BatchConfig::MAX_NUM_REQUESTS; + size_t total_size = committed_tokeninfo_size * - sizeof(TreeVerifyBatchConfig::CommittedTokensInfo); + sizeof(TreeVerifyBatchConfig::CommittedTokensInfo) + + causal_mask_size * sizeof(BatchConfig::BitMask); if (offload) { // assert that we have enough reserved work space left assert(gpu_mem_allocator.reserved_total_size - @@ -1004,6 +1091,8 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( gpu_mem_allocator .allocate_reserved( committed_tokeninfo_size); + causalMask = gpu_mem_allocator.allocate_instance( + causal_mask_size); } else { gpu_mem_allocator.create_legion_instance(committed_token_reserve_inst, total_size); @@ -1011,6 +1100,8 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( gpu_mem_allocator .allocate_instance( committed_tokeninfo_size); + causalMask = gpu_mem_allocator.allocate_instance( + causal_mask_size); } } diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 775280e2cf..8a7cea1cc3 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -16,6 +16,7 @@ #include "flexflow/request_manager.h" #include "flexflow/parallel_ops/parallel_op.h" // #include "flexflow/tokenizers.h" +#include #include #include #include @@ -735,6 +736,11 @@ BeamSearchBatchConfig new_bc.requestsInfo[i].max_sequence_length - new_bc.requestsInfo[i].first_token_depth_in_request - verified_tokens.size(); + // std::cout << "max depth: " << new_max_depth << ", " + // << new_bc.requestsInfo[i].first_token_depth_in_request << + // ", " + // << verified_tokens.size() << "\n"; + // assert(false); new_bc.beamRequestsInfo[i].current_depth = 1; profiling_requests[request.guid].ssm_decoding_steps = 0; @@ -761,6 +767,10 @@ BeamSearchBatchConfig new_bc.topology_mask[i].real_token_pos[0][j] = j; } + updateBitMask(new_bc.causalMask[i], + verified_tokens.size(), + request.tokens.size()); + // Token Info for (int j = 0; j < verified_tokens.size(); j++) { auto token = verified_tokens.at(j); @@ -910,6 +920,11 @@ BeamSearchBatchConfig new_bc.num_tokens++; } new_bc.topology_mask[i].allocated_tokens = 0; + new_bc.causalMask[i].non_tree_cache_size = 0; + new_bc.causalMask[i].tree_size = + new_bc.requestsInfo[i].num_tokens_in_batch; + initBitMask(new_bc.causalMask[i], + new_bc.requestsInfo[i].num_tokens_in_batch); // if (new_bc.requestsInfo[i].num_tokens_in_batch < // new_request.initial_len) { @@ -1161,6 +1176,27 @@ BeamSearchBatchConfig } } + memcpy(&new_bc.causalMask[i], + 
&old_bc.causalMask[i], + sizeof(BatchConfig::BitMask)); + // sub_request_num -> nodes of input next iteration + // beam_size replicate num + + std::cout << "print beam tree: " + << old_bc.beamRequestsInfo[i].current_depth << "\n"; + BeamTree tree = request.beam_trees[old_bc.model_id]; + for (int k = 0; k <= old_bc.beamRequestsInfo[i].current_depth; k++) { + std::cout << "layer: " << k << "\n"; + std::cout << "nodes: " << tree.treeLayers[k].nodes_num_this_layer + << "\n"; + } + appendBitMask(new_bc.causalMask[i], + new_bc.beamRequestsInfo[i].sub_request_num, + old_bc.beamRequestsInfo[i].beam_size, + old_bc.beamRequestsInfo[i].sub_request_num, + tree, + old_bc.beamRequestsInfo[i].current_depth); + // assert(false); for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j; for (int k = 0; k < new_bc.beamRequestsInfo[i].sub_request_num; k++) { @@ -1248,6 +1284,10 @@ BeamSearchBatchConfig assert(false && "Request should be pending"); } + memcpy(&new_bc.causalMask[i], + &old_bc.causalMask[i], + sizeof(BatchConfig::BitMask)); + if (new_bc.requestsInfo[i].first_token_depth_in_request >= request.tokens.size()) { // request is done @@ -1260,6 +1300,13 @@ BeamSearchBatchConfig (int)request.tokens.size() - new_bc.requestsInfo[i].first_token_depth_in_request); request.ssm_cache_size += new_bc.requestsInfo[i].num_tokens_in_batch; + BeamTree tree = request.beam_trees[old_bc.model_id]; + appendBitMask(new_bc.causalMask[i], + new_bc.beamRequestsInfo[i].sub_request_num, + old_bc.beamRequestsInfo[i].beam_size, + old_bc.beamRequestsInfo[i].sub_request_num, + tree, + old_bc.beamRequestsInfo[i].current_depth); } if (verbose) { @@ -1378,7 +1425,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( profiling_requests[request.guid].llm_decoding_steps += 1; if (request.status == Request::RUNNING) { - + std::cout << "prepare next batch running: pending\n" << "\n"; new_bc.request_running[i] = true; @@ -1398,7 +1445,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( std::vector> dfs_tree_inputs = merge_dfs_trees(all_dfs_trees, request.tokens.size() - 1, guid); - if (verbose) { + if (true) { std::cout << "Request Tokens Size: " << request.tokens.size() << std::endl; for (int k = 0; k < request.tokens.size(); k++) { @@ -1414,6 +1461,13 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( old_batches.at(0).requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = old_batches.at(0).requestsInfo[i].max_sequence_length; + + // copy bitmask to verify batchconfig + memcpy(&(new_bc.causalMask[i]), + &(old_batches.at(0).causalMask[i]), + sizeof(BatchConfig::BitMask)); + // std::cout << "bitmask: " << new_bc.causalMask[i].mask[0] << "\n"; + // assert(false); // TODO: Check this new_bc.requestsInfo[i].num_tokens_in_batch = 0; new_bc.request_completed[i] = false; @@ -1429,7 +1483,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( i; new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_depth = committed_token.first; - if (verbose) { + if (true) { std::cout << new_bc.num_tokens_to_commit << "- committed_token.token_depth: " << committed_token.first @@ -1441,7 +1495,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( } } } - if (verbose) { + if (true) { std::cout << "new_bc.num_tokens_to_commit: " << new_bc.num_tokens_to_commit << std::endl; } @@ -1463,8 +1517,10 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( 
new_bc.requestsInfo[i].first_token_depth_in_request = request.tokens.size() - 1; - - std::cout << "prepare next batch verify: " << dfs_tree_inputs.size() << "\n"; + + std::cout << "prepare next batch verify: " << dfs_tree_inputs.size() + << "\n"; + // Add Tokens from the DFS Tree to the next batch for (int j = 1; j < dfs_tree_inputs.size(); j++) { auto token = dfs_tree_inputs.at(j); @@ -1485,7 +1541,6 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( break; } } - assert(false); } else if (request.status == Request::PENDING) { std::cout << "prepare next batch verify: pending\n" @@ -1518,6 +1573,12 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( << new_bc.num_tokens_to_commit << std::endl; } + memcpy(&(new_bc.causalMask[i]), + &(old_batches.at(0).causalMask[i]), + sizeof(BatchConfig::BitMask)); + // std::cout << "bitmask: " << new_bc.causalMask[i].mask[0] << "\n"; + // assert(false); + // Normal Request Info new_bc.requestsInfo[i].first_token_depth_in_request = request.llm_cache_size; @@ -1643,8 +1704,6 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, (old_bc.tokensInfo[i - 1].abs_depth_in_request - start_depth) * beam_size; - // result_index += old_bc.topology_mask[index].allocated_tokens; - if (true) { std::cout << "i = " << i << ", result index = " << result_index << ", value: " << result.token_ids[result_index] @@ -1669,6 +1728,9 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, request.tokens.back(); request.beam_trees.at(old_bc.model_id).treeLayers[0].probs[0] = 1; request.beam_trees.at(old_bc.model_id).treeLayers[0].parent_ids[0] = -1; + request.beam_trees.at(old_bc.model_id) + .treeLayers[0] + .nodes_num_this_layer = 1; if (true) { std::cout << "Store the previous last token to the tree root: " @@ -1677,7 +1739,9 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, } std::cout << "leaffffff: " << leaf_node_num << "\n"; - + request.beam_trees.at(old_bc.model_id) + .treeLayers[depth] + .nodes_num_this_layer = leaf_node_num; for (int beam_id = 0; beam_id < leaf_node_num; beam_id++) { request.beam_trees.at(old_bc.model_id) @@ -1751,50 +1815,6 @@ void RequestManager::update_beam_metadata(BeamSearchBatchConfig &new_bc, std::cout << "token: " << j << ": " << new_bc.beamRequestsInfo[request_index].tokens[j] << "\n"; } - - // std::set parents; - // std::set childs; - // // cache stealing - // for (int j = 0; j < beam_size; j++) { - // int parent_id = tree.treeLayers[depth].parent_ids[j]; - // if (childs.find(parent_id) == childs.end()) { - // // copy beam slot - // new_bc.beamRequestsInfo[request_index].parent_id[parent_id] = - // tree.treeLayers[depth].parent_ids[j]; - // new_bc.beamRequestsInfo[request_index].probs[parent_id] = - // tree.treeLayers[depth].probs[j]; - // new_bc.beamRequestsInfo[request_index].tokens[parent_id] = - // tree.treeLayers[depth].tokens[j]; - // parents.emplace(j); - // childs.emplace(parent_id); - // } - // } - // if (parents.size() < beam_size) { - // for (int j = 0; j < beam_size; j++) { - // if (parents.find(j) == parents.end()) { - // // this slot has not been assigned - // // find the smallest not assigned child and put in - // if (verbose) { - // std::cout << "request_index" << request_index - // << ", miss slot: " << j << "\n"; - // } - // for (int k = 0; k < beam_size; k++) { - // if (childs.find(k) == childs.end()) { - // // parent -> j to child k; - // new_bc.beamRequestsInfo[request_index].parent_id[k] = - // 
tree.treeLayers[depth].parent_ids[j]; - // new_bc.beamRequestsInfo[request_index].probs[k] = - // tree.treeLayers[depth].probs[j]; - // new_bc.beamRequestsInfo[request_index].tokens[k] = - // tree.treeLayers[depth].tokens[j]; - // parents.emplace(j); - // childs.emplace(k); - // break; - // } - // } - // } - // } - // } } if (verbose) { std::cout << "-----------after parent id exchange-----------" << std::endl; @@ -1809,6 +1829,128 @@ void RequestManager::update_beam_metadata(BeamSearchBatchConfig &new_bc, } } +// bit mask related function + +// prompt phase, init task +void RequestManager::initBitMask(BatchConfig::BitMask &bitmask, + int initLength) { + assert(initLength <= BatchConfig::MAX_SPEC_TREE_TOKEN_NUM && + "do not support tree size > 64"); + // eg. 4 tokens: t1: 0000000..1111, t2: 0000000..1110, t3: 0000000..1100, t4: + // 0000000..1000 + + bitmask.prompt_size = initLength; + bitmask.this_layer_size = initLength; + bitmask.tree_size = initLength; + for (int i = 0; i < bitmask.prompt_size; i++) { + for (int j = i; j < bitmask.prompt_size; j++) { + bitmask.mask[i] |= (1 << j); + } + } + std::cout << "see bit mask" << bitmask.prompt_size << "\n"; + std::cout << "see bit mask" << std::bitset<64>(bitmask.mask[0]) << "\n"; + std::cout << "see bit mask" << std::bitset<64>(bitmask.mask[1]) << "\n"; + std::cout << "see bit mask" << std::bitset<64>(bitmask.mask[2]) << "\n"; +} + +// prepare next init +void RequestManager::updateBitMask(BatchConfig::BitMask &bitmask, + int initLength, + int non_tree_size) { + // assert(initLength == 1); + // eg. 4 tokens: t1: 0000000..1111, t2: 0000000..1110, t3: 0000000..1100, t4: + // 0000000..1000 + assert(initLength <= BatchConfig::MAX_SPEC_TREE_TOKEN_NUM && + "do not support tree size > 64"); + bitmask.non_tree_cache_size = non_tree_size; + bitmask.tree_size = initLength; + bitmask.this_layer_size = initLength; + std::cout << "non_tree_size: " << non_tree_size << "\n"; + bitmask.prompt_size = initLength; + for (int i = 0; i < bitmask.prompt_size; i++) { + for (int j = i; j < bitmask.prompt_size; j++) { + bitmask.mask[i] |= (1 << j); + } + } + + std::cout << "see bit mask update" << bitmask.prompt_size << "\n"; + std::cout << "see bit mask update" << std::bitset<64>(bitmask.mask[0]) + << "\n"; + std::cout << "see bit mask update" << std::bitset<64>(bitmask.mask[1]) + << "\n"; + std::cout << "see bit mask update" << std::bitset<64>(bitmask.mask[2]) + << "\n"; +} + +// prepare next beam, append layers to the tree +void RequestManager::appendBitMask(BatchConfig::BitMask &bitmask, + int newNodes, + int preBeamSize, + int old_sub_num, + BeamTree const tree, + int currentDepth) { + int pre_tree_size = bitmask.tree_size; + bitmask.tree_size += newNodes; + bitmask.this_layer_size = newNodes; + assert(bitmask.tree_size <= BatchConfig::MAX_SPEC_TREE_TOKEN_NUM && + "do not support tree size > 64"); + // preBeamSize: replicate num + + // add relationship with input/prompt + for (int i = 0; i < bitmask.prompt_size; i++) { + for (int j = pre_tree_size; j < bitmask.tree_size; j++) { + bitmask.mask[i] |= (1 << j); + std::cout << "see bit mask append: " << i << ", to" << j + << std::bitset<64>(bitmask.mask[i]) << "\n"; + } + } + + std::cout << "bitmask.tree_size: " << bitmask.tree_size << ", " + << pre_tree_size << ", " << bitmask.prompt_size << ", " + << preBeamSize << "\n"; + + // int num_groups = newNodes / preBeamSize; + // int group_size = newNodes / num_groups; + // add relations to branch + // requests in same groups share same relations, except the last token. 
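  // Worked example (illustrative only, not part of the patch): with
  // prompt_size = 2 and one appended layer of newNodes = 2 branches,
  // tree positions are 0,1 = prompt tokens and 2,3 = the new layer.
  // initBitMask/updateBitMask give mask[0] = 0b0011 and mask[1] = 0b0010;
  // the prompt loop above then sets bits 2 and 3 in both prompt words
  // (mask[0] = 0b1111, mask[1] = 0b1110). With currentDepth = 1 the
  // middle-layer loop below is skipped, and the last-layer loop sets
  // mask[2] = 0b0100 and mask[3] = 0b1000, so branch token 2 may attend
  // to positions {0, 1, 2} and branch token 3 to {0, 1, 3}.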
+ + // set middle layers + // skip the root prompt/tokens + int token_idx = bitmask.prompt_size; + int new_nodes_start_idx = pre_tree_size; + std::cout << "new nodes start " << new_nodes_start_idx << "\n"; + for (int i = 1; i < currentDepth; i++) { + new_nodes_start_idx = pre_tree_size; + int nodes_this_layer = tree.treeLayers[i].nodes_num_this_layer; + std::cout << "tree layer: " << i << " nodes:" << nodes_this_layer + << "group size: " << newNodes / nodes_this_layer << "\n"; + for (int j = 0; j < nodes_this_layer; j++) { + int group_size = newNodes / nodes_this_layer; + for (int k = 0; k < group_size; k++) { + bitmask.mask[token_idx] |= (1 << new_nodes_start_idx); + new_nodes_start_idx += 1; + } + token_idx += 1; + } + } + + std::cout << "token idx: " << token_idx << ", " << pre_tree_size << ", " + << new_nodes_start_idx << ", " << newNodes + << "current depth: " << currentDepth << "\n"; + std::cout << "new nodes end " << new_nodes_start_idx << "\n"; + + std::cout << "tree size: " << bitmask.tree_size << "\n"; + assert(token_idx == pre_tree_size); + assert(currentDepth <= 1 || new_nodes_start_idx == bitmask.tree_size); + + // assert(currentDepth <= 2); + // set last layer, all tokens are only relevant to it self; + for (int i = token_idx; i < bitmask.tree_size; i++) { + bitmask.mask[i] |= (1 << i); + std::cout << "set rel: " << i << "to: " << i << "\n"; + } +} + bool PreOrder( BeamTree const &tree, int max_depth, @@ -1979,7 +2121,7 @@ std::vector> RequestManager::traverse_beam_tree(BeamSearchBatchConfig const &old_bc, int request_index, int first_token_depth_in_request) { - if (verbose) { + if (true) { std::cout << "[Traverse Beam Tree] request_index: " << request_index << "\n"; std::cout << "[Traverse Beam Tree] max_depth: " @@ -1988,6 +2130,8 @@ std::vector> << old_bc.beamRequestsInfo[request_index].current_depth << "\n"; std::cout << "[Traverse Beam Tree] beam_width: " << old_bc.beamRequestsInfo[request_index].beam_size << "\n"; + std::cout << "[Traverse Beam Tree] start index: " + << first_token_depth_in_request << "\n"; } auto guid = old_bc.requestsInfo[request_index].request_guid; @@ -1995,27 +2139,39 @@ std::vector> // std::cout << "request.beam_trees.size(): " << request.beam_trees.size() // << std::endl; BeamTree tree = request.beam_trees.at(old_bc.model_id); - // std::cout << "\n\n"; + std::cout << "print beam tree: " + << "\n"; + std::vector> serializedTree; + for (int i = 0; i <= old_bc.beamRequestsInfo[request_index].max_depth; i++) { + std::cout << "tree layer: " << i + << ", num_nodes: " << tree.treeLayers[i].nodes_num_this_layer + << "\n"; + // push tokens into tree + for (int j = 0; j < tree.treeLayers[i].nodes_num_this_layer; j++) { + std::cout << "token: " << tree.treeLayers[i].tokens[j] << "\n"; + serializedTree.push_back(std::make_pair(tree.treeLayers[i].tokens[j], i)); + } + } // token, index // todo make this one global for different stages - std::vector> serializedTree; - PreOrder(tree, - old_bc.beamRequestsInfo[request_index].max_depth, - 0, - old_bc.beamRequestsInfo[request_index].beam_size, - 0, - serializedTree, - verbose); + + // PreOrder(tree, + // old_bc.beamRequestsInfo[request_index].max_depth, + // 0, + // old_bc.beamRequestsInfo[request_index].beam_size, + // 0, + // serializedTree, + // verbose); // print it - if (verbose) { + if (true) { std::cout << "Print serialized tree: size:" << request_index << serializedTree.size() << "\n"; } for (int k = 0; k < serializedTree.size(); k++) { serializedTree.at(k).second += first_token_depth_in_request; - if 
(verbose) { + if (true) { std::cout << "token id: " << serializedTree.at(k).first << ", depth: " << serializedTree.at(k).second << "\n"; } @@ -2041,6 +2197,9 @@ std::vector> input_trees, int root_depth, RequestGuid guid) { + assert(input_trees.size() == 1 && "currently using one ssm"); + return input_trees.at(0); + std::vector> merged_tree; std::unordered_map> childrens; diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu index b76c5c326e..4d7e2c8806 100644 --- a/src/runtime/request_manager.cu +++ b/src/runtime/request_manager.cu @@ -104,6 +104,18 @@ void RequestManager::load_tokens_task( sizeof(BeamSearchBatchConfig::beamRequestsInfo), cudaMemcpyHostToDevice, stream); + + // cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + // sizeof(BatchConfig::tokensInfo) + + // sizeof(BatchConfig::requestsInfo) + + // sizeof(BeamSearchBatchConfig::topology_mask) + + // sizeof(BeamSearchBatchConfig::beamTokenInfo) + + // sizeof(BeamSearchBatchConfig::beamRequestsInfo), + // &(beam_batch_config->causalMask), + // sizeof(BatchConfig::causalMask), + // cudaMemcpyHostToDevice, + // stream); + // std::cout << "copy calsual mask info: " << beam_batch_config->causalMask[0].prompt_size << "\n"; } } From 945268f1a56e804b62b731c136bf8358c47b765f Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Thu, 28 Dec 2023 11:19:16 -0500 Subject: [PATCH 05/30] fix. --- inference/spec_infer/spec_infer.cc | 2 +- src/ops/tree_inc_multihead_self_attention.cu | 78 ++++++++++---------- src/runtime/request_manager.cc | 11 ++- 3 files changed, 50 insertions(+), 41 deletions(-) diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index 2ccdfd388d..e4fa71a1d5 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -404,7 +404,7 @@ void FlexFlow::top_level_task(Task const *task, prompts.push_back(text); // tree_model.generate(text, 128 /*max_sequence_length*/); } - tree_model.generate(prompts, 15 /*max_sequence_length*/); + tree_model.generate(prompts, 23 /*max_sequence_length*/); } // Execution fence diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index a3e3adcc30..3d5ccf9431 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -162,24 +162,24 @@ __global__ void compute_attention_kernel_fused_kernel( (ti >= bitmask.non_tree_cache_size && (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); - if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 0 && mask) { - printf("tree attn mask for first token %d, %lld, %d, %d\n", - ti, - bitmask.mask[ti - bitmask.non_tree_cache_size], - bitmask.non_tree_cache_size, - qi); - } + // if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 0 && mask) { + // printf("tree attn mask for first token %d, %lld, %d, %d\n", + // ti, + // bitmask.mask[ti - bitmask.non_tree_cache_size], + // bitmask.non_tree_cache_size, + // qi); + // } qk_max = mask ? 
qk_max : fmaxf(qk_max, qk); - if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 1 && !mask) { - printf("tree attn mask for second token %d, %lld, %d, %d, %.10f\n", - ti, - bitmask.mask[ti - bitmask.non_tree_cache_size], - bitmask.non_tree_cache_size, - qi, - qk); - } + // if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 1 && !mask) { + // printf("tree attn mask for second token %d, %lld, %d, %d, %.10f\n", + // ti, + // bitmask.mask[ti - bitmask.non_tree_cache_size], + // bitmask.non_tree_cache_size, + // qi, + // qk); + // } qk_smem[ti - first_step] = mask ? 0.0f : qk; } } @@ -213,10 +213,10 @@ __global__ void compute_attention_kernel_fused_kernel( // Broadcast to all the threads in the warp. qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 0 && tidx == 0) { - printf("tree attn first token qk_max %f\n", - qk_max); - } + // if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 0 && tidx == 0) { + // printf("tree attn first token qk_max %f\n", + // qk_max); + // } float exp_sum = 0.f; @@ -232,9 +232,9 @@ __global__ void compute_attention_kernel_fused_kernel( // Compute the sum. exp_sum = block_sum(&red_smem[WARPS_PER_BLOCK], exp_sum); - if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { - printf("expsum %.10f\n", exp_sum); - } + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { + // printf("expsum %.10f\n", exp_sum); + // } // softmax float inv_sum = __fdividef(1.f, exp_sum + 1.e-6); @@ -243,9 +243,9 @@ __global__ void compute_attention_kernel_fused_kernel( } __syncthreads(); - if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { - printf("softmax %.10f\n", qk_smem[0]); - } + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { + // printf("softmax %.10f\n", qk_smem[0]); + // } // value projection constexpr int V_VEC_SIZE = 16 / sizeof(DT); @@ -292,9 +292,9 @@ __global__ void compute_attention_kernel_fused_kernel( // // Make sure we can start writing to shared memory. __syncthreads(); - if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { - printf("valueX %.10f\n", out.x); - } + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { + // printf("valueX %.10f\n", out.x); + // } // Run the final reduction amongst the different groups computing different // partial outputs. 
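The masked max/exp-sum/normalize sequence in the surrounding hunks can be summarized by the host-side sketch below; it is an assumption-laden illustration, not code from this patch, and BitMaskView is a hypothetical stand-in for the relevant BatchConfig::BitMask fields so the example stays self-contained:

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <vector>

struct BitMaskView {
  int non_tree_cache_size;
  unsigned long long const *mask; // one 64-bit word per tree position
};

// Committed (non-tree) cache positions always contribute; tree positions
// contribute only when the bit for query token qi is set.
std::vector<float> masked_softmax(std::vector<float> const &qk,
                                  BitMaskView const &bm,
                                  int qi) {
  std::vector<bool> visible(qk.size());
  float qk_max = -FLT_MAX;
  for (size_t ti = 0; ti < qk.size(); ti++) {
    visible[ti] = ti < (size_t)bm.non_tree_cache_size ||
                  ((bm.mask[ti - bm.non_tree_cache_size] >> qi) & 1ULL);
    if (visible[ti]) {
      qk_max = std::max(qk_max, qk[ti]);
    }
  }
  float exp_sum = 0.f;
  std::vector<float> probs(qk.size(), 0.f);
  for (size_t ti = 0; ti < qk.size(); ti++) {
    if (visible[ti]) {
      probs[ti] = std::exp(qk[ti] - qk_max);
      exp_sum += probs[ti];
    }
  }
  float inv_sum = 1.f / (exp_sum + 1e-6f);
  for (float &p : probs) {
    p *= inv_sum; // masked positions stay exactly 0, as in the kernel
  }
  return probs;
}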
@@ -328,11 +328,11 @@ __global__ void compute_attention_kernel_fused_kernel( output_ptr + (first_token_idx + qi) * hidden_size + head_idx * per_head_size + vi), out); - if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { - printf("tree attn final value, %.9f, %.9f, %.9f, %.9f, %d, %d\n", - out.x, out.y, out.z, out.w, vi, (first_token_idx + qi) * hidden_size + - head_idx * per_head_size + vi); - } + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { + // printf("tree attn final value, %.9f, %.9f, %.9f, %.9f, %d, %d\n", + // out.x, out.y, out.z, out.w, vi, (first_token_idx + qi) * hidden_size + + // head_idx * per_head_size + vi); + // } } } } @@ -807,12 +807,12 @@ void compute_attention_kernel_fused(TreeIncMultiHeadSelfAttentionMeta const *m, // update K-V cache int num_new_tokens = bc->num_active_tokens(); int parallelism = m->hidden_size * num_new_tokens; - printf("update KV cache %d, idx: %d\n", - num_new_tokens, - bc->requestsInfo[0].first_token_depth_in_request); - for (int i = 0; i < num_new_tokens; i++) { - printf("abs depth:%d\n", bc->tokensInfo[i].abs_depth_in_request); - } + // printf("update KV cache %d, idx: %d\n", + // num_new_tokens, + // bc->requestsInfo[0].first_token_depth_in_request); + // for (int i = 0; i < num_new_tokens; i++) { + // printf("abs depth:%d\n", bc->tokensInfo[i].abs_depth_in_request); + // } update_tree_branch_kv_cache_fused<<> verified_tokens = traverse_verify_tree(guid, dfs_tree_inputs.at(guid), tree_outputs); + log_req_mgr.print("Number of Verified Tokens = %zu", verified_tokens.size()); @@ -1426,7 +1429,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( if (request.status == Request::RUNNING) { - std::cout << "prepare next batch running: pending\n" + std::cout << "prepare next batch running:\n" << "\n"; new_bc.request_running[i] = true; std::cout << "[Verify] Request " << request.guid << " is running" @@ -1663,6 +1666,9 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( } } + std::cout << "check dfs tree input size: " << dfs_tree_inputs[1000000].size() + << "\n"; + return new_bc; } @@ -2198,6 +2204,7 @@ std::vector> int root_depth, RequestGuid guid) { assert(input_trees.size() == 1 && "currently using one ssm"); + dfs_tree_inputs[guid] = input_trees.at(0); return input_trees.at(0); std::vector> merged_tree; @@ -2249,6 +2256,8 @@ std::vector> } dfs_tree_inputs[guid] = merged_tree; + // std::cout << "assign dfr tree: " << guid << ", " << merged_tree.size() << ", " + // << dfs_tree_inputs[guid].size() << "\n"; return merged_tree; } From ce95127aecaf553679539310574b48417609efa2 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Fri, 29 Dec 2023 03:41:26 -0500 Subject: [PATCH 06/30] fix --- inference/spec_infer/spec_infer.cc | 4 +- src/ops/kernels/embedding_kernels.cu | 2 +- .../specinfer_inc_multihead_self_attention.cu | 76 ++++--- src/ops/tree_inc_multihead_self_attention.cu | 114 ++++++---- src/runtime/request_manager.cc | 198 +++++++++++------- 5 files changed, 246 insertions(+), 148 deletions(-) diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index e4fa71a1d5..9af3e12e5a 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -402,9 +402,9 @@ void FlexFlow::top_level_task(Task const *task, printf("Prompt[%d]: %s\n", total_num_requests, text.c_str()); total_num_requests++; prompts.push_back(text); - // tree_model.generate(text, 128 /*max_sequence_length*/); + // tree_model.generate(text, 128 /*max_sequence_length*/); } - 
tree_model.generate(prompts, 23 /*max_sequence_length*/); + tree_model.generate(prompts, 128 /*max_sequence_length*/); } // Execution fence diff --git a/src/ops/kernels/embedding_kernels.cu b/src/ops/kernels/embedding_kernels.cu index 91f5d60e85..0cde42de56 100644 --- a/src/ops/kernels/embedding_kernels.cu +++ b/src/ops/kernels/embedding_kernels.cu @@ -118,7 +118,7 @@ void forward_kernel_wrapper(EmbeddingMeta const *m, // print_tensor(output_ptr, output_domain.get_volume(), // "[Embedding:forward:output]"); } - print_tensor(input.get_int32_ptr(), 32, "embeddinginput"); + // print_tensor(input.get_int32_ptr(), 32, "embeddinginput"); } /*static*/ diff --git a/src/ops/specinfer_inc_multihead_self_attention.cu b/src/ops/specinfer_inc_multihead_self_attention.cu index e8ac1d980c..f2ea63d904 100644 --- a/src/ops/specinfer_inc_multihead_self_attention.cu +++ b/src/ops/specinfer_inc_multihead_self_attention.cu @@ -83,9 +83,9 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( int const tlength = request_infos[request_idx].first_token_depth_in_request + request_infos[request_idx].num_tokens_in_batch; - if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { - printf("specinfer attn fused kernel %lld\n", bitmask.mask[1]); - } + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + // printf("specinfer attn fused kernel %lld\n", bitmask.mask[1]); + // } int const totalCacheSize = bitmask.non_tree_cache_size + bitmask.tree_size; // int const qlength = request_infos[request_idx].num_tokens_in_batch; @@ -181,6 +181,10 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( } float qk = scale * Qk_dot::dot(q_vecs[ki_o], k); + // if (blockIdx.y == 0 && blockIdx.x == 0) { + // printf("spec inc attn qkqkqk %d, %.10f, %d\n", ti, qk, sub_req_idx); + // } + if (ti < totalCacheSize && tidx % THREADS_PER_KEY == 0) { // todo add alobi here // bool const mask = ti_circ >= totalCacheSize; @@ -188,15 +192,15 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << query_token)))); - if (blockIdx.y == 0 && blockIdx.x == 0 && mask && sub_req_idx == 0) { - // printf("specinfer mask: ti:%d, %d, %d, %d, %lld\n", - // ti, - // totalCacheSize, - // ti - bitmask.non_tree_cache_size, - // query_token, - // bitmask.mask[ti - bitmask.non_tree_cache_size]); - // assert(false); - } + // if (blockIdx.y == 0 && blockIdx.x == 0 && sub_req_idx == 0) { + // printf("specinfer mask: ti:%d, %d, %d, %d, %lld\n", + // ti, + // totalCacheSize, + // bitmask.non_tree_cache_size, + // query_token, + // bitmask.mask[ti - bitmask.non_tree_cache_size]); + // // assert(false); + // } qk_max = mask ? qk_max : fmaxf(qk_max, qk); qk_smem[ti - first_step] = mask ? 0.f : qk; } @@ -231,6 +235,10 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( // Broadcast to all the threads in the warp. qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + // printf("spec inc attn first token qk_max %.10f\n", qk_max); + // } + float exp_sum = 0.f; for (int ti = first_step + tidx; ti < totalCacheSize; ti += THREADS_PER_BLOCK) { @@ -245,6 +253,10 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( // Compute the sum. 
exp_sum = block_sum(&red_smem[WARPS_PER_BLOCK], exp_sum); + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + // printf("spec inc attn exp_sum %.10f\n", exp_sum); + // } + // softmax float inv_sum = __fdividef(1.f, exp_sum + 1.e-6); for (int ti = first_step + tidx; ti < totalCacheSize; @@ -301,6 +313,10 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( // // Make sure we can start writing to shared memory. __syncthreads(); + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + // printf("valueX %.10f\n", out.x); + // } + // Run the final reduction amongst the different groups computing different // partial outputs. if (Dh == Dh_MAX || vi < Dh) { @@ -357,8 +373,8 @@ __global__ void specinfer_store_kv_cache( int max_tree_branches, bool is_root, int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size * 2) { - int token_idx = i / (hidden_size * KV_WEIGHT_NUM); + CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { + int token_idx = i / (hidden_size); int offset = i % hidden_size; size_t val_idx = @@ -416,6 +432,16 @@ __global__ void specinfer_store_kv_cache( // cache_idx); // } + // if (i % hidden_size == 0) { + // printf("update cache: %d, %d, %d, %d, %d, %d\n", + // cache_idx, + // num_tokens, + // bitmask.non_tree_cache_size, + // bitmask.tree_size, + // bitmask.this_layer_size, + // token_idx); + // } + kCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + (cache_idx)*hidden_size + offset] = kVal; vCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + @@ -433,9 +459,9 @@ void update_kv_cache_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, // assert(curr_depth < 3); if (num_tokens > 0) { int parallelism = m->hidden_size * KV_WEIGHT_NUM * num_tokens; - printf("tokenInfo %d, %d\n", - bc->beamTokenInfo[0].sub_request_index, - num_tokens); + // printf("tokenInfo %d, %d\n", + // bc->beamTokenInfo[0].sub_request_index, + // num_tokens); specinfer_store_kv_cache<<num_active_requests() * sizeof(BatchConfig::BitMask), cudaMemcpyHostToDevice, stream); - std::cout << "kernel bit mask: " << bc->causalMask[0].prompt_size << ", " - << bc->causalMask[0].non_tree_cache_size << ", " - << bc->causalMask[0].mask[0] << ", " << sizeof(BatchConfig::BitMask) - << "\n"; + // std::cout << "kernel bit mask: " << bc->causalMask[0].prompt_size << ", " + // << bc->causalMask[0].non_tree_cache_size << ", " + // << bc->causalMask[0].mask[0] << ", " << + // sizeof(BatchConfig::BitMask) + // << "\n"; compute_qkv_kernel(m, bc, shard_id, @@ -800,8 +827,8 @@ void inference_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, stream); // phase 2: Update key/val cache update_kv_cache_kernel
(m, bc, stream); - std::cout << "specinfer kernel token num: " << bc->num_generation_tokens - << ", " << bc->num_tokens << "\n"; + // std::cout << "specinfer kernel token num: " << bc->num_generation_tokens + // << ", " << bc->num_tokens << "\n"; if (bc->num_generation_tokens > 0) { compute_specinfer_attention_kernel_generation
<DT>( m, bc, static_cast<DT *>
(m->attn_heads), stream); @@ -809,6 +836,7 @@ void inference_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, // phase 3: Compute attention score // 3 kernels for pahse 3: matmul1 - softmax - matmal2 if (bc->num_tokens > bc->num_generation_tokens) { + // printf("spec inc prompt decoding\n"); compute_attention_kernel_prompt( m, bc, shard_id, output_ptr, bias_ptr, weight_ptr, stream); } @@ -892,7 +920,7 @@ void SpecInferIncMultiHeadSelfAttention::inference_kernel_wrapper( // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr, // acc_output.rect, "[Attention:forward:output]"); } - // print_tensor(output.get_float_ptr(), 32, "specinc output"); + // print_tensor(output.get_float_ptr(), 32, "specinc output"); // if(bc->num_tokens == 1){ // print_tensor(input.get_float_ptr(), 32, "specinc input"); diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index 3d5ccf9431..180a165451 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -86,13 +86,13 @@ __global__ void compute_attention_kernel_fused_kernel( BatchConfig::BitMask bitmask = causalMask[request_idx]; // bitmask.mask[1] = 3; - if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { - printf("tree attn fused kernel %d, %d, %d, %lld\n", - tlength, - qlength, - bitmask.non_tree_cache_size, - bitmask.mask[1]); - } + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + // printf("tree attn fused kernel %d, %d, %d, %lld\n", + // tlength, + // qlength, + // bitmask.non_tree_cache_size, + // bitmask.mask[3]); + // } int first_token_idx = 0; for (int r = 0; r < request_idx; r++) { @@ -161,7 +161,7 @@ __global__ void compute_attention_kernel_fused_kernel( bool const mask = (ti >= bitmask.non_tree_cache_size && (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); - + // if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 0 && mask) { // printf("tree attn mask for first token %d, %lld, %d, %d\n", // ti, @@ -169,16 +169,22 @@ __global__ void compute_attention_kernel_fused_kernel( // bitmask.non_tree_cache_size, // qi); // } + // if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 3 && mask) { + // printf("tree attn mask for third token %d, %lld, %d, %d\n", + // ti, + // bitmask.mask[ti - bitmask.non_tree_cache_size], + // bitmask.non_tree_cache_size, + // qi); + // } qk_max = mask ? qk_max : fmaxf(qk_max, qk); // if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 1 && !mask) { - // printf("tree attn mask for second token %d, %lld, %d, %d, %.10f\n", + // printf("tree attn qkqkqkqk %d %.10f, %.10f, %.10f\n", // ti, - // bitmask.mask[ti - bitmask.non_tree_cache_size], - // bitmask.non_tree_cache_size, - // qi, - // qk); + // qk, + // q_vecs[ki_o][0].x, + // k[0].x); // } qk_smem[ti - first_step] = mask ? 0.0f : qk; } @@ -212,12 +218,10 @@ __global__ void compute_attention_kernel_fused_kernel( // Broadcast to all the threads in the warp. 
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - - // if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 0 && tidx == 0) { - // printf("tree attn first token qk_max %f\n", - // qk_max); - // } + // if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 1 && tidx == 0) { + // printf("tree attn first token qk_max %f\n", qk_max); + // } float exp_sum = 0.f; for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { @@ -244,7 +248,7 @@ __global__ void compute_attention_kernel_fused_kernel( __syncthreads(); // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { - // printf("softmax %.10f\n", qk_smem[0]); + // printf("softmax %.10f\n", qk_smem[1]); // } // value projection @@ -280,12 +284,13 @@ __global__ void compute_attention_kernel_fused_kernel( V_vec v = *reinterpret_cast( v_cache_batch + ti_circ * hidden_size + head_idx * per_head_size); - bool const mask = - (ti >= bitmask.non_tree_cache_size && - (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); - float logit = mask ? 0.0f : qk_smem[ti - first_step]; - out = FlexFlow::fma(logit, cast_to_float(v), out); - + if (ti < tlength) { + bool const mask = + (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); + float logit = mask ? 0.0f : qk_smem[ti - first_step]; + out = FlexFlow::fma(logit, cast_to_float(v), out); + } } } @@ -328,11 +333,16 @@ __global__ void compute_attention_kernel_fused_kernel( output_ptr + (first_token_idx + qi) * hidden_size + head_idx * per_head_size + vi), out); - // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { - // printf("tree attn final value, %.9f, %.9f, %.9f, %.9f, %d, %d\n", - // out.x, out.y, out.z, out.w, vi, (first_token_idx + qi) * hidden_size + - // head_idx * per_head_size + vi); - // } + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { + // printf("tree attn final value, %.9f, %.9f, %.9f, %.9f, %d, %d\n", + // out.x, + // out.y, + // out.z, + // out.w, + // vi, + // (first_token_idx + qi) * hidden_size + head_idx * per_head_size + + // vi); + // } } } } @@ -349,11 +359,12 @@ __global__ void commit_tokens_kernel( int num_tokens_to_commit, int num_active_tokens_in_last_batch, int max_seq_len, - int hidden_size) { + int hidden_size, + int max_tree_branches) { - CUDA_KERNEL_LOOP(i, num_tokens_to_commit * hidden_size * 2) { + CUDA_KERNEL_LOOP(i, num_tokens_to_commit * hidden_size) { - int token_pos = i / (hidden_size * KV_WEIGHT_NUM); + int token_pos = i / (hidden_size); int token_idx_in_last_batch = committedTokenInfos[token_pos].token_index; int offset = i % hidden_size; assert(token_idx_in_last_batch < num_active_tokens_in_last_batch); @@ -367,10 +378,23 @@ __global__ void commit_tokens_kernel( int const req_id = committedTokenInfos[token_pos].request_index; int const tok_id = committedTokenInfos[token_pos].token_depth; - kCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + - offset] = kVal; - vCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + - offset] = vVal; + // if(i == 0){ + // printf("commit token: %d %d %f\n", token_idx_in_last_batch, tok_id, + // kVal); + // } + // if(i == hidden_size){ + // printf("commit token 1: %d %d %f\n", token_idx_in_last_batch, tok_id, + // kVal); + // } + // if(i == 2 * hidden_size){ + // printf("commit token 2: %d %d %f\n", token_idx_in_last_batch, tok_id, + // kVal); + // } + + kCache_ptr[req_id * max_tree_branches * (hidden_size * max_seq_len) + + tok_id * hidden_size + offset] = kVal; + vCache_ptr[req_id * 
max_tree_branches * (hidden_size * max_seq_len) + + tok_id * hidden_size + offset] = vVal; } } @@ -395,7 +419,8 @@ void commit_tokens(TreeIncMultiHeadSelfAttentionMeta const *m, num_tokens_to_commit, m->num_active_tokens, // number of active tokens in previous batch BatchConfig::max_sequence_length(), - m->hidden_size); + m->hidden_size, + BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES); } } @@ -413,9 +438,9 @@ __global__ void update_tree_branch_kv_cache( int total_tokens_in_batch, int max_seq_len, int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens_in_branch * hidden_size * 2) { + CUDA_KERNEL_LOOP(i, num_tokens_in_branch * hidden_size) { - int token_idx = i / (hidden_size * KV_WEIGHT_NUM); + int token_idx = i / (hidden_size); int offset = i % hidden_size; token_idx += processed_tokens_in_batch; // get index in the whole batch @@ -460,6 +485,11 @@ __global__ void update_tree_branch_kv_cache_fused( int const req_id = tokenInfos[token_idx].request_index; int const tok_id = tokenInfos[token_idx].abs_depth_in_request; + + // if(i % hidden_size == 0){ + // printf("update token id: %d, %d\n", token_idx, token_idx + + // first_token_depth); + // } kCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + (token_idx + first_token_depth) * hidden_size + offset] = kVal; vCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + @@ -879,7 +909,8 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, // Note that m->num_active_tokens stores the number of active // tokens in the previous batch, which is needed for committing // keys/values to the key-value cache - std::cout << "tokens to be committed: " << bc->num_tokens_to_commit << "\n"; + // std::cout << "tokens to be committed: " << bc->num_tokens_to_commit << + // "\n"; cudaMemcpyAsync(m->committed_token_infos, &(bc->committed_tokens), @@ -925,6 +956,7 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, static_cast
(m->devQKVProjArray), bias_ptr, stream); + // print_tensor((float *)m->devQKVProjArray, 32, "qkvtenor"); // phase 2: No need to update key/val cache // IncMultiHeadSelfAttention::update_kv_cache_kernel( diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index e7b08f653d..d5c2b7392d 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -609,6 +609,8 @@ BeamSearchBatchConfig committed_tokens[guid].emplace_back(abs_depth, result_index); } else if (abs_depth >= root_abs_depth) { tree_outputs.emplace_back(token_id, abs_depth + 1); + std::cout << "committred tokens push: " << abs_depth + << " ,result index: " << result_index << "\n"; committed_tokens[guid].emplace_back(abs_depth, result_index); if (verbose) { @@ -789,9 +791,9 @@ BeamSearchBatchConfig // Beam Token Info new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = 0; new_bc.num_tokens++; - std::cout << "num_gen ++ " - << "\n"; - num_generation_tokens++; + // std::cout << "num_gen ++ " + // << "\n"; + // num_generation_tokens++; // Add verified token to request's token list request.tokens.push_back(token.first); @@ -923,9 +925,7 @@ BeamSearchBatchConfig new_bc.num_tokens++; } new_bc.topology_mask[i].allocated_tokens = 0; - new_bc.causalMask[i].non_tree_cache_size = 0; - new_bc.causalMask[i].tree_size = - new_bc.requestsInfo[i].num_tokens_in_batch; + initBitMask(new_bc.causalMask[i], new_bc.requestsInfo[i].num_tokens_in_batch); @@ -1185,14 +1185,14 @@ BeamSearchBatchConfig // sub_request_num -> nodes of input next iteration // beam_size replicate num - std::cout << "print beam tree: " - << old_bc.beamRequestsInfo[i].current_depth << "\n"; + // std::cout << "print beam tree: " + // << old_bc.beamRequestsInfo[i].current_depth << "\n"; BeamTree tree = request.beam_trees[old_bc.model_id]; - for (int k = 0; k <= old_bc.beamRequestsInfo[i].current_depth; k++) { - std::cout << "layer: " << k << "\n"; - std::cout << "nodes: " << tree.treeLayers[k].nodes_num_this_layer - << "\n"; - } + // for (int k = 0; k <= old_bc.beamRequestsInfo[i].current_depth; k++) { + // std::cout << "layer: " << k << "\n"; + // std::cout << "nodes: " << tree.treeLayers[k].nodes_num_this_layer + // << "\n"; + // } appendBitMask(new_bc.causalMask[i], new_bc.beamRequestsInfo[i].sub_request_num, old_bc.beamRequestsInfo[i].beam_size, @@ -1217,9 +1217,10 @@ BeamSearchBatchConfig new_bc.topology_mask[i].real_token_pos[k][depth] = new_bc.topology_mask[i].allocated_tokens + num_generation_tokens; - std::cout << "topology: sub request: " << k << ", " - << ", " << depth << ", " - << new_bc.topology_mask[i].real_token_pos[k][depth] << "\n"; + // std::cout << "topology: sub request: " << k << ", " + // << ", " << depth << ", " + // << new_bc.topology_mask[i].real_token_pos[k][depth] << + // "\n"; num_generation_tokens++; } } @@ -1354,13 +1355,13 @@ BeamSearchBatchConfig } if (true) { - std::cout << "print all resultsBBB" - << "\n"; - for (int i = 0; i < 40; i++) { - std::cout << result.token_ids[i] << ", "; - } - std::cout << "Current Beam DepthBBB: " - << old_bc.beamRequestsInfo[0].current_depth << "\n"; + // std::cout << "print all resultsBBB" + // << "\n"; + // for (int i = 0; i < 40; i++) { + // std::cout << result.token_ids[i] << ", "; + // } + // std::cout << "Current Beam DepthBBB: " + // << old_bc.beamRequestsInfo[0].current_depth << "\n"; } return new_bc; } @@ -1449,11 +1450,11 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( merge_dfs_trees(all_dfs_trees, request.tokens.size() - 1, guid); if 
(true) { - std::cout << "Request Tokens Size: " << request.tokens.size() - << std::endl; - for (int k = 0; k < request.tokens.size(); k++) { - std::cout << k << ": " << request.tokens[k] << std::endl; - } + // std::cout << "Request Tokens Size: " << request.tokens.size() + // << std::endl; + // for (int k = 0; k < request.tokens.size(); k++) { + // std::cout << k << ": " << request.tokens[k] << std::endl; + // } } // Normal Request Info @@ -1475,27 +1476,42 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.requestsInfo[i].num_tokens_in_batch = 0; new_bc.request_completed[i] = false; + std::cout << "dfs_tree_inputs: " << dfs_tree_inputs.size() << ", " + << new_bc.causalMask[i].tree_size << ", " + << new_bc.causalMask[i].non_tree_cache_size << "\n"; + std::cout << "mask: " << std::bitset<64>(new_bc.causalMask[i].mask[0]) + << "\n"; + std::cout << "mask: " << std::bitset<64>(new_bc.causalMask[i].mask[1]) + << "\n"; + std::cout << "mask: " << std::bitset<64>(new_bc.causalMask[i].mask[2]) + << "\n"; + std::cout << "mask: " << std::bitset<64>(new_bc.causalMask[i].mask[3]) + << "\n"; + std::cout << "mask: " << std::bitset<64>(new_bc.causalMask[i].mask[4]) + << "\n"; + // Committed Tokens if (committed_tokens.find(guid) != committed_tokens.end()) { - for (int j = 0; j < dfs_tree_inputs.size(); j++) { - if (j < committed_tokens.at(guid).size()) { - auto committed_token = committed_tokens.at(guid).at(j); - new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_index = - committed_token.second; - new_bc.committed_tokens[new_bc.num_tokens_to_commit].request_index = - i; - new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_depth = - committed_token.first; - if (true) { - std::cout << new_bc.num_tokens_to_commit - << "- committed_token.token_depth: " - << committed_token.first - << ", token_index: " << committed_token.second - << std::endl; - } - new_bc.num_tokens_to_commit++; - request.llm_cache_size++; + for (int j = 0; j < committed_tokens.at(guid).size(); j++) { + // if (j < committed_tokens.at(guid).size()) { + + auto committed_token = committed_tokens.at(guid).at(j); + new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_index = + committed_token.second; + new_bc.committed_tokens[new_bc.num_tokens_to_commit].request_index = + i; + new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_depth = + committed_token.first; + if (true) { + std::cout << new_bc.num_tokens_to_commit + << "- committed_token.token_depth: " + << committed_token.first + << ", token_index: " << committed_token.second + << std::endl; } + new_bc.num_tokens_to_commit++; + request.llm_cache_size++; + // } } } if (true) { @@ -1759,11 +1775,11 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, request.beam_trees.at(old_bc.model_id) .treeLayers[depth] .parent_ids[beam_id] = result.parent_id[result_index]; - std::cout << "??????? beam id: " << beam_id << ", token: " - << request.beam_trees.at(old_bc.model_id) - .treeLayers[depth] - .tokens[beam_id] - << "\n"; + // std::cout << "??????? beam id: " << beam_id << ", token: " + // << request.beam_trees.at(old_bc.model_id) + // .treeLayers[depth] + // .tokens[beam_id] + // << "\n"; // if (true) { // std::cout << "tree value: " << depth << "token: " @@ -1844,19 +1860,20 @@ void RequestManager::initBitMask(BatchConfig::BitMask &bitmask, "do not support tree size > 64"); // eg. 
4 tokens: t1: 0000000..1111, t2: 0000000..1110, t3: 0000000..1100, t4: // 0000000..1000 + bitmask.non_tree_cache_size = 0; + bitmask.tree_size = initLength; bitmask.prompt_size = initLength; bitmask.this_layer_size = initLength; - bitmask.tree_size = initLength; for (int i = 0; i < bitmask.prompt_size; i++) { for (int j = i; j < bitmask.prompt_size; j++) { bitmask.mask[i] |= (1 << j); } } - std::cout << "see bit mask" << bitmask.prompt_size << "\n"; - std::cout << "see bit mask" << std::bitset<64>(bitmask.mask[0]) << "\n"; - std::cout << "see bit mask" << std::bitset<64>(bitmask.mask[1]) << "\n"; - std::cout << "see bit mask" << std::bitset<64>(bitmask.mask[2]) << "\n"; + // std::cout << "see bit mask" << bitmask.prompt_size << "\n"; + // std::cout << "see bit mask" << std::bitset<64>(bitmask.mask[0]) << "\n"; + // std::cout << "see bit mask" << std::bitset<64>(bitmask.mask[1]) << "\n"; + // std::cout << "see bit mask" << std::bitset<64>(bitmask.mask[2]) << "\n"; } // prepare next init @@ -1868,11 +1885,16 @@ void RequestManager::updateBitMask(BatchConfig::BitMask &bitmask, // 0000000..1000 assert(initLength <= BatchConfig::MAX_SPEC_TREE_TOKEN_NUM && "do not support tree size > 64"); - bitmask.non_tree_cache_size = non_tree_size; - bitmask.tree_size = initLength; + assert(initLength >= 1 && "verified token num should >= 1"); + + std::cout << "non tree size: " << non_tree_size << ", " + << bitmask.non_tree_cache_size << "\n"; + + bitmask.non_tree_cache_size = non_tree_size + initLength - 1; + bitmask.tree_size = 1; bitmask.this_layer_size = initLength; std::cout << "non_tree_size: " << non_tree_size << "\n"; - bitmask.prompt_size = initLength; + bitmask.prompt_size = 1; for (int i = 0; i < bitmask.prompt_size; i++) { for (int j = i; j < bitmask.prompt_size; j++) { bitmask.mask[i] |= (1 << j); @@ -1906,14 +1928,14 @@ void RequestManager::appendBitMask(BatchConfig::BitMask &bitmask, for (int i = 0; i < bitmask.prompt_size; i++) { for (int j = pre_tree_size; j < bitmask.tree_size; j++) { bitmask.mask[i] |= (1 << j); - std::cout << "see bit mask append: " << i << ", to" << j - << std::bitset<64>(bitmask.mask[i]) << "\n"; + // std::cout << "see bit mask append: " << i << ", to" << j + // << std::bitset<64>(bitmask.mask[i]) << "\n"; } } - std::cout << "bitmask.tree_size: " << bitmask.tree_size << ", " - << pre_tree_size << ", " << bitmask.prompt_size << ", " - << preBeamSize << "\n"; + // std::cout << "bitmask.tree_size: " << bitmask.tree_size << ", " + // << pre_tree_size << ", " << bitmask.prompt_size << ", " + // << preBeamSize << "\n"; // int num_groups = newNodes / preBeamSize; // int group_size = newNodes / num_groups; @@ -1924,12 +1946,12 @@ void RequestManager::appendBitMask(BatchConfig::BitMask &bitmask, // skip the root prompt/tokens int token_idx = bitmask.prompt_size; int new_nodes_start_idx = pre_tree_size; - std::cout << "new nodes start " << new_nodes_start_idx << "\n"; + // std::cout << "new nodes start " << new_nodes_start_idx << "\n"; for (int i = 1; i < currentDepth; i++) { new_nodes_start_idx = pre_tree_size; int nodes_this_layer = tree.treeLayers[i].nodes_num_this_layer; - std::cout << "tree layer: " << i << " nodes:" << nodes_this_layer - << "group size: " << newNodes / nodes_this_layer << "\n"; + // std::cout << "tree layer: " << i << " nodes:" << nodes_this_layer + // << "group size: " << newNodes / nodes_this_layer << "\n"; for (int j = 0; j < nodes_this_layer; j++) { int group_size = newNodes / nodes_this_layer; for (int k = 0; k < group_size; k++) { @@ -1940,12 +1962,12 
@@ void RequestManager::appendBitMask(BatchConfig::BitMask &bitmask, } } - std::cout << "token idx: " << token_idx << ", " << pre_tree_size << ", " - << new_nodes_start_idx << ", " << newNodes - << "current depth: " << currentDepth << "\n"; - std::cout << "new nodes end " << new_nodes_start_idx << "\n"; + // std::cout << "token idx: " << token_idx << ", " << pre_tree_size << ", " + // << new_nodes_start_idx << ", " << newNodes + // << "current depth: " << currentDepth << "\n"; + // std::cout << "new nodes end " << new_nodes_start_idx << "\n"; - std::cout << "tree size: " << bitmask.tree_size << "\n"; + // std::cout << "tree size: " << bitmask.tree_size << "\n"; assert(token_idx == pre_tree_size); assert(currentDepth <= 1 || new_nodes_start_idx == bitmask.tree_size); @@ -1953,8 +1975,23 @@ void RequestManager::appendBitMask(BatchConfig::BitMask &bitmask, // set last layer, all tokens are only relevant to it self; for (int i = token_idx; i < bitmask.tree_size; i++) { bitmask.mask[i] |= (1 << i); - std::cout << "set rel: " << i << "to: " << i << "\n"; + // std::cout << "set rel: " << i << "to: " << i << "\n"; } + + // if(bitmask.non_tree_cache_size == 19 && bitmask.tree_size > 2){ + // assert(false); + // } + + std::cout << "see bit mask append" << bitmask.prompt_size << "\n"; + std::cout << "see bit mask append" << bitmask.non_tree_cache_size << "\n"; + std::cout << "see bit mask append" << std::bitset<64>(bitmask.mask[0]) + << "\n"; + std::cout << "see bit mask append" << std::bitset<64>(bitmask.mask[1]) + << "\n"; + std::cout << "see bit mask append" << std::bitset<64>(bitmask.mask[2]) + << "\n"; + std::cout << "see bit mask append" << std::bitset<64>(bitmask.mask[3]) + << "\n"; } bool PreOrder( @@ -2146,16 +2183,16 @@ std::vector> // << std::endl; BeamTree tree = request.beam_trees.at(old_bc.model_id); - std::cout << "print beam tree: " - << "\n"; + // std::cout << "print beam tree: " + // << "\n"; std::vector> serializedTree; for (int i = 0; i <= old_bc.beamRequestsInfo[request_index].max_depth; i++) { - std::cout << "tree layer: " << i - << ", num_nodes: " << tree.treeLayers[i].nodes_num_this_layer - << "\n"; + // std::cout << "tree layer: " << i + // << ", num_nodes: " << tree.treeLayers[i].nodes_num_this_layer + // << "\n"; // push tokens into tree for (int j = 0; j < tree.treeLayers[i].nodes_num_this_layer; j++) { - std::cout << "token: " << tree.treeLayers[i].tokens[j] << "\n"; + // std::cout << "token: " << tree.treeLayers[i].tokens[j] << "\n"; serializedTree.push_back(std::make_pair(tree.treeLayers[i].tokens[j], i)); } } @@ -2256,7 +2293,8 @@ std::vector> } dfs_tree_inputs[guid] = merged_tree; - // std::cout << "assign dfr tree: " << guid << ", " << merged_tree.size() << ", " + // std::cout << "assign dfr tree: " << guid << ", " << merged_tree.size() << + // ", " // << dfs_tree_inputs[guid].size() << "\n"; return merged_tree; From 3ed25d681127d742770776b8d07d9771e0e19f79 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Fri, 29 Dec 2023 16:10:16 -0500 Subject: [PATCH 07/30] multi batch --- src/ops/beam_topk.cc | 3 +- src/ops/beam_topk.cu | 3 +- .../specinfer_inc_multihead_self_attention.cu | 66 +++++++------------ .../tree attn kernel, 0----> -0.029753357172 | 1 + src/ops/tree_inc_multihead_self_attention.cu | 45 +++++++++---- src/runtime/request_manager.cc | 37 ++++++++--- 6 files changed, 89 insertions(+), 66 deletions(-) create mode 100644 src/ops/tree attn kernel, 0----> -0.029753357172 diff --git a/src/ops/beam_topk.cc b/src/ops/beam_topk.cc index 3f636c2c98..20d019eec3 100644 --- 
a/src/ops/beam_topk.cc +++ b/src/ops/beam_topk.cc @@ -402,8 +402,7 @@ BeamInferenceResult download_tensor( parent_ptr, ir.parent_id, batch_size * m->max_beam_width); - print_tensor(index_ptr, 32, "indexxxxxxx"); - printf("max beam width %d\n", m->max_beam_width); + // print_tensor(index_ptr, 32, "indexxxxxxx"); if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); diff --git a/src/ops/beam_topk.cu b/src/ops/beam_topk.cu index 515bba4bc0..d647fe9ed7 100644 --- a/src/ops/beam_topk.cu +++ b/src/ops/beam_topk.cu @@ -626,7 +626,7 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, stream)); // trick, set acc_probs to 0; checkCUDA( - cudaMemsetAsync(m->acc_probs, 1.0, batch_size * sizeof(DT), stream)); + cudaMemsetAsync(m->acc_probs, 1.0, max_total_requests * sizeof(DT), stream)); checkCUDA(cudaMemcpyAsync(m->block_start_index, beam_block_start_index.data(), sizeof(int) * beam_num_blocks, @@ -644,6 +644,7 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, stream)); // int depth = // bc->beamRequestsInfo[bc->tokensInfo[0].request_index].current_depth; + beam_num_blocks = bc->num_active_tokens(); beam_topk_forward_kernel<<>>( input_ptr, shared_memory_size, diff --git a/src/ops/specinfer_inc_multihead_self_attention.cu b/src/ops/specinfer_inc_multihead_self_attention.cu index f2ea63d904..3fdd1ab554 100644 --- a/src/ops/specinfer_inc_multihead_self_attention.cu +++ b/src/ops/specinfer_inc_multihead_self_attention.cu @@ -100,6 +100,10 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( first_token_idx += bitmask.this_layer_size; } + // if (tidx == 0 && head_idx == 0) { + // printf("spec req: %d, %d\n", request_idx, first_token_idx); + // } + // shared memory objects extern __shared__ char smem_[]; @@ -135,17 +139,16 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( int ti_end = div_up(totalCacheSize - first_step, K_PER_WARP) * K_PER_WARP + first_step; - for (int sub_req_idx = 0; sub_req_idx < tree_branch_num; sub_req_idx += 1) { + for (int qi = 0; qi < tree_branch_num; qi += 1) { #pragma unroll for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { q_vecs[ki_o][ii] = *reinterpret_cast( - q_ptr + (hidden_size * QKV_WEIGHT_NUM * sub_req_idx) + ki + + q_ptr + (hidden_size * QKV_WEIGHT_NUM * qi) + ki + ii * THREADS_PER_KEY * K_VEC_SIZE); } - int const query_token = bitmask.tree_size - tree_branch_num + sub_req_idx; - - if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && sub_req_idx == 0) { + int const query_token = bitmask.tree_size - tree_branch_num + qi; + if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 0) { // printf("fuckmasksss %d, %d, %d, %d, %d\n", // bitmask.prompt_size, // bitmask.non_tree_cache_size, @@ -345,11 +348,10 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( // Output the final values. 
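// A minimal sketch of the output indexing introduced in the hunk below,
// assuming first_token_idx is this request's token offset in the flattened
// batch (accumulated from the preceding requests' this_layer_size):
//   int const out_off = (first_token_idx + qi) * hidden_size
//                     + head_idx * per_head_size + vi;
// i.e. the qi-th speculative token of a request writes its attention output
// into its own hidden_size-sized slot of output_ptr, rather than being
// addressed by (request_idx + sub_req_idx) as before.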
if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { - convert_from_float( - *reinterpret_cast(output_ptr + - (request_idx + sub_req_idx) * hidden_size + - head_idx * per_head_size + vi), - out); + convert_from_float(*reinterpret_cast( + output_ptr + (first_token_idx + qi) * hidden_size + + head_idx * per_head_size + vi), + out); } } } @@ -391,6 +393,9 @@ __global__ void specinfer_store_kv_cache( int const allocated_tokens = beam_topology_mask[req_id].allocated_tokens; int const total_token = requestInfo[req_id].num_tokens_in_batch; + int const request_token_offset = + requestInfo[req_id].first_token_offset_in_batch; + BatchConfig::BitMask bitmask = causalMask[req_id]; int const sub_request_num = beamRequestInfos[req_id].sub_request_num; @@ -404,42 +409,18 @@ __global__ void specinfer_store_kv_cache( // if prompt token -> token id // if tree token: int const cache_idx = bitmask.non_tree_cache_size + bitmask.tree_size - - bitmask.this_layer_size + token_idx; + bitmask.this_layer_size + token_idx - + request_token_offset; int real_idx = tok_id - first_token_in_req + allocated_tokens + sub_req_id; - // if (i == 0) { - // printf("ffasdasds%d, %d, %d, %d, %d, %d, %d, %d, %d, %d, %d\n", - // beamTokenInfos[0].sub_request_index, - // allocated_tokens, - // sub_req_id, - // tok_id, - // first_token_in_req, + // if (i % hidden_size == 0) { + // printf("ffasdasds request %d, real idx %d, cache idx %d token id %d, kval %.10f\n", + // req_id, // real_idx, // cache_idx, - // bitmask.non_tree_cache_size, - // bitmask.tree_size, - // sub_request_num, - // token_idx ); - // } else if (i == hidden_size * 2) { - // printf("hshddhdhdsdaww%d, %d, %d, %d, %d, %d, %d\n", - // beamTokenInfos[0].sub_request_index, - // allocated_tokens, - // sub_req_id, // tok_id, - // first_token_in_req, - // real_idx, - // cache_idx); - // } - - // if (i % hidden_size == 0) { - // printf("update cache: %d, %d, %d, %d, %d, %d\n", - // cache_idx, - // num_tokens, - // bitmask.non_tree_cache_size, - // bitmask.tree_size, - // bitmask.this_layer_size, - // token_idx); + // kVal); // } kCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + @@ -846,6 +827,8 @@ void inference_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, // compute output production and bias together for all tokens int num_tokens = bc->num_active_tokens(); + // std::cout << "specinfer num tokens: " << num_tokens; + compute_o_prod_bias( m, bc, shard_id, output_ptr, weight_ptr, bias_ptr, num_tokens, stream); } @@ -920,7 +903,8 @@ void SpecInferIncMultiHeadSelfAttention::inference_kernel_wrapper( // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr, // acc_output.rect, "[Attention:forward:output]"); } - // print_tensor(output.get_float_ptr(), 32, "specinc output"); + // save_tensor(output.get_float_ptr(), 768 * 3, "/home/xinhaoc/FlexFlow/inference/output/fk1.txt"); + // save_tensor(output.get_float_ptr() + 768 * 3, 768 * 3, "/home/xinhaoc/FlexFlow/inference/output/fk2.txt"); // if(bc->num_tokens == 1){ // print_tensor(input.get_float_ptr(), 32, "specinc input"); diff --git a/src/ops/tree attn kernel, 0----> -0.029753357172 b/src/ops/tree attn kernel, 0----> -0.029753357172 new file mode 100644 index 0000000000..e4f14ee757 --- /dev/null +++ b/src/ops/tree attn kernel, 0----> -0.029753357172 @@ -0,0 +1 @@ +tree attn kernel, 0----> -0.02975335717201232910 0.01930358447134494781 0.03780741989612579346 0.11878532171249389648 -0.03523746877908706665 0.02421043440699577332 0.03719477355480194092 -0.00304851122200489044 0.02062662504613399506 
0.06683708727359771729 -0.00642335414886474609 -0.00504039414227008820 0.02955199964344501495 0.00648811273276805878 0.00558663159608840942 0.02003456838428974152 -0.04041406139731407166 0.00736814411357045174 -0.04575226455926895142 0.03949077427387237549 0.05742383748292922974 0.04866250604391098022 0.04687267541885375977 -0.00701304525136947632 -0.03712264448404312134 -0.02175992354750633240 -0.03979443758726119995 0.03961737453937530518 -0.07450901716947555542 0.02090370282530784607 -0.03487894684076309204 0.01653470844030380249 \ No newline at end of file diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index 180a165451..11169fa36d 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -99,6 +99,10 @@ __global__ void compute_attention_kernel_fused_kernel( first_token_idx += request_infos[request_idx].num_tokens_in_batch; } + // if(tidx == 0 && head_idx == 0){ + // printf("tree req: %d, %d\n", request_idx, first_token_idx); + // } + // shared memory objects extern __shared__ char smem_[]; @@ -140,6 +144,12 @@ __global__ void compute_attention_kernel_fused_kernel( q_vecs[ki_o][ii] = *reinterpret_cast( q_ptr + (hidden_size * QKV_WEIGHT_NUM * qi) + ki + ii * THREADS_PER_KEY * K_VEC_SIZE); + + // if (head_idx == 0 && qi == 1 && tidx == 0) { + // printf("laod q %d, %d %.10f\n", + // request_idx, + // qi,q_vecs[ki_o][ii].x); + // } } __syncthreads(); @@ -162,11 +172,12 @@ __global__ void compute_attention_kernel_fused_kernel( (ti >= bitmask.non_tree_cache_size && (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); - // if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 0 && mask) { - // printf("tree attn mask for first token %d, %lld, %d, %d\n", + // if (head_idx == 0 && qi == 9 && mask) { + // printf("tree attn mask for first token %d, %lld, %d, %d, %d\n", // ti, // bitmask.mask[ti - bitmask.non_tree_cache_size], // bitmask.non_tree_cache_size, + // request_idx, // qi); // } // if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 3 && mask) { @@ -179,11 +190,15 @@ __global__ void compute_attention_kernel_fused_kernel( qk_max = mask ? qk_max : fmaxf(qk_max, qk); - // if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 1 && !mask) { - // printf("tree attn qkqkqkqk %d %.10f, %.10f, %.10f\n", + // if (head_idx == 0 && qi == 1 && !mask && tidx == 0) { + // printf("tree attn qkqkqkqk request id %d, %d %.10f, %.10f, %.10f\n", + // request_idx, // ti, // qk, // q_vecs[ki_o][0].x, + // q_vecs[ki_o][1].x, + // q_vecs[ki_o][2].x, + // q_vecs[ki_o][3].x, // k[0].x); // } qk_smem[ti - first_step] = mask ? 0.0f : qk; @@ -219,7 +234,7 @@ __global__ void compute_attention_kernel_fused_kernel( // Broadcast to all the threads in the warp. qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - // if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 1 && tidx == 0) { + // if (head_idx == 0 && qi == 9 && tidx == 0) { // printf("tree attn first token qk_max %f\n", qk_max); // } @@ -236,7 +251,7 @@ __global__ void compute_attention_kernel_fused_kernel( // Compute the sum. 
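// A minimal sketch of the masked softmax around this reduction, assuming the
// usual pattern where each unmasked score contributes expf(qk - qk_max):
//   bool const masked =
//       ti >= bitmask.non_tree_cache_size &&
//       !(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi));
//   // masked tree positions are stored as 0.0f in qk_smem and skipped when
//   // computing qk_max, so they add nothing to exp_sum; block_sum() below
//   // then combines the per-thread partial sums before normalization.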
exp_sum = block_sum(&red_smem[WARPS_PER_BLOCK], exp_sum); - // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { + // if (head_idx == 0 && tidx == 0 && qi == 9) { // printf("expsum %.10f\n", exp_sum); // } @@ -247,7 +262,7 @@ __global__ void compute_attention_kernel_fused_kernel( } __syncthreads(); - // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { + // if (head_idx == 0 && tidx == 0 && qi == 9) { // printf("softmax %.10f\n", qk_smem[1]); // } @@ -465,6 +480,7 @@ __global__ void update_tree_branch_kv_cache_fused( DT *kCache_ptr, DT *vCache_ptr, TreeVerifyBatchConfig::PerTokenInfo const *tokenInfos, + BatchConfig::PerRequestInfo *request_infos, int qProjSize, int kProjSize, int vProjSize, @@ -486,14 +502,15 @@ __global__ void update_tree_branch_kv_cache_fused( int const req_id = tokenInfos[token_idx].request_index; int const tok_id = tokenInfos[token_idx].abs_depth_in_request; + int const request_token_offset = request_infos[req_id].first_token_offset_in_batch; + // if(i % hidden_size == 0){ - // printf("update token id: %d, %d\n", token_idx, token_idx + - // first_token_depth); + // printf("update token request id: %d, %d, %d value%.10f\n", req_id, token_idx, request_token_offset, kVal); // } kCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + - (token_idx + first_token_depth) * hidden_size + offset] = kVal; + (token_idx + first_token_depth - request_token_offset) * hidden_size + offset] = kVal; vCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + - (token_idx + first_token_depth) * hidden_size + offset] = vVal; + (token_idx + first_token_depth - request_token_offset) * hidden_size + offset] = vVal; } } @@ -851,6 +868,7 @@ void compute_attention_kernel_fused(TreeIncMultiHeadSelfAttentionMeta const *m, static_cast
(m->keyCache), static_cast
(m->valueCache), m->token_infos, + m->request_infos, m->qProjSize, m->kProjSize, m->vProjSize, @@ -956,7 +974,8 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, static_cast
(m->devQKVProjArray), bias_ptr, stream); - // print_tensor((float *)m->devQKVProjArray, 32, "qkvtenor"); + // print_tensor((float *)m->devQKVProjArray + 768 * 8 * 3 + 768, 32, "qkvtenor1"); + // print_tensor((float *)m->devQKVProjArray + 768 * 18 * 3 + 768, 32, "qkvtenor2"); // phase 2: No need to update key/val cache // IncMultiHeadSelfAttention::update_kv_cache_kernel( @@ -1000,6 +1019,8 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( cudaEventRecord(t_start, stream); } + std::cout << "tree input tokens: " <num_active_tokens() << "\n"; + // assert(input.data_type == weight.data_type); assert(input.data_type == output.data_type); if (use_bias) { diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index d5c2b7392d..ab062a4610 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -634,6 +634,7 @@ BeamSearchBatchConfig if (request.status == Request::RUNNING) { std::cout << "verify running: " << dfs_tree_inputs.at(guid).size() << ", " << tree_outputs.size() << "\n"; + std::vector> verified_tokens = traverse_verify_tree(guid, dfs_tree_inputs.at(guid), tree_outputs); @@ -812,6 +813,7 @@ BeamSearchBatchConfig } log_req_mgr.print("Output: %s", output.c_str()); } + } else if (request.status == Request::PENDING) { new_bc.request_completed[i] = false; new_bc.request_running[i] = false; @@ -1185,8 +1187,8 @@ BeamSearchBatchConfig // sub_request_num -> nodes of input next iteration // beam_size replicate num - // std::cout << "print beam tree: " - // << old_bc.beamRequestsInfo[i].current_depth << "\n"; + std::cout << "print beam tree: " + << old_bc.beamRequestsInfo[i].current_depth << "\n"; BeamTree tree = request.beam_trees[old_bc.model_id]; // for (int k = 0; k <= old_bc.beamRequestsInfo[i].current_depth; k++) { // std::cout << "layer: " << k << "\n"; @@ -1224,6 +1226,12 @@ BeamSearchBatchConfig num_generation_tokens++; } } + // if(new_bc.beamRequestsInfo[i].current_depth >= 3 && i > 0){ + // assert(false); + // } + + + } } @@ -1709,6 +1717,8 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, old_bc.requestsInfo[old_bc.tokensInfo[i].request_index].request_guid != guid) { + std::cout << "i is: " << i << "old guid" << guid << " new guid" << old_bc.requestsInfo[old_bc.tokensInfo[i].request_index].request_guid <<"\n"; + int index = old_bc.tokensInfo[i - 1].request_index; int beam_size = old_bc.beamRequestsInfo[index].beam_size; @@ -1722,16 +1732,21 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, // Count tokens sent to model in this request to find the final token's // index + + std::cout << "previous result index: "<< result_index; + result_index += (old_bc.tokensInfo[i - 1].abs_depth_in_request - start_depth) * beam_size; - - if (true) { - std::cout << "i = " << i << ", result index = " << result_index - << ", value: " << result.token_ids[result_index] - << ", leaf node num: " << leaf_node_num << ", depth" << depth - << ", beam size: " << beam_size << "\n"; - } + + std::cout << "after result index: "<< result_index; + + // if (true) { + // std::cout << "i = " << i << ", result index = " << result_index + // << ", value: " << result.token_ids[result_index] + // << ", leaf node num: " << leaf_node_num << ", depth" << depth + // << ", beam size: " << beam_size << "\n"; + // } Request &request = all_requests[old_bc.requestsInfo[index].request_guid]; @@ -1792,7 +1807,9 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, } // update the guid and 
start_depth for current request if (i < old_bc.num_tokens) { - guid = old_bc.requestsInfo[index].request_guid; + int new_req_idx = old_bc.tokensInfo[i].request_index; + guid = old_bc.requestsInfo[new_req_idx].request_guid; + std::cout << "update guid: " << guid << ", request idx: " << index<< "\n"; start_depth = old_bc.tokensInfo[i].abs_depth_in_request; } } From 5c3ad3592f7b71dc705466fa24cb7c7c1e179deb Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Fri, 29 Dec 2023 17:37:28 -0500 Subject: [PATCH 08/30] copy metadata once --- include/flexflow/batch_config.h | 6 -- include/flexflow/config.h | 4 +- .../specinfer_inc_multihead_self_attention.h | 1 - src/ops/inc_multihead_self_attention.cu | 13 --- .../specinfer_inc_multihead_self_attention.cu | 94 ++++--------------- src/ops/tree_inc_multihead_self_attention.cu | 65 ++++++------- src/runtime/request_manager.cc | 46 +-------- src/runtime/request_manager.cu | 74 ++++++++------- 8 files changed, 89 insertions(+), 214 deletions(-) diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index db5d4a8e48..c3a75e59a4 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -168,14 +168,8 @@ class BeamSearchBatchConfig : public BatchConfig { int sub_request_index; }; - struct SpecInferTopology { - int real_token_pos[MAX_SPECULATIVE_TREE_BRANCHES][MAX_NUM_TOKENS]; - int allocated_tokens; - }; - BeamSearchPerRequestInfo beamRequestsInfo[MAX_NUM_REQUESTS]; BeamSearchPerTokenInfo beamTokenInfo[MAX_NUM_TOKENS * MAX_BEAM_WIDTH]; - SpecInferTopology topology_mask[MAX_NUM_REQUESTS]; // why is this == MAX_NUM_REQUESTS * MAX_BEAM_WIDTH? int sub_requests[MAX_NUM_REQUESTS * MAX_BEAM_WIDTH]; diff --git a/include/flexflow/config.h b/include/flexflow/config.h index fe261dfb48..1526b9291f 100644 --- a/include/flexflow/config.h +++ b/include/flexflow/config.h @@ -81,10 +81,10 @@ struct FFHandler { // request info + token info + topolopgy mask info size_t batch_config_metadata_size = sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + - sizeof(BeamSearchBatchConfig::topology_mask) + sizeof(BeamSearchBatchConfig::beamTokenInfo) + sizeof(BeamSearchBatchConfig::beamRequestsInfo) + - sizeof(BatchConfig::causalMask); + sizeof(BatchConfig::causalMask) + + sizeof(TreeVerifyBatchConfig::committed_tokens); void *offload_reserve_space; size_t offload_reserve_space_size; DataType quantization_type; diff --git a/include/flexflow/ops/specinfer_inc_multihead_self_attention.h b/include/flexflow/ops/specinfer_inc_multihead_self_attention.h index eb1b2882c3..b6fed1ae25 100644 --- a/include/flexflow/ops/specinfer_inc_multihead_self_attention.h +++ b/include/flexflow/ops/specinfer_inc_multihead_self_attention.h @@ -142,7 +142,6 @@ class SpecInferIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionM Realm::RegionInstance beam_search_reserve_inst; BeamSearchBatchConfig::BeamSearchPerTokenInfo *beam_token_infos; BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos; - BeamSearchBatchConfig::SpecInferTopology *beam_topology_mask; BatchConfig::BitMask *causalMask; }; diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index a05dbbf919..a084f216e9 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -825,19 +825,6 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta const *m, bias_ptr = static_cast
(m->bias_ptr); } - // todo Xinhao copy how many requests if requests are not continous? - // cudaMemcpyAsync(m->token_infos, - // &(bc->tokensInfo), - // bc->num_active_tokens() * - // sizeof(BatchConfig::PerTokenInfo), cudaMemcpyHostToDevice, - // stream); - // cudaMemcpyAsync(m->request_infos, - // &(bc->requestsInfo), - // bc->max_requests_per_batch() * - // sizeof(BatchConfig::PerRequestInfo), - // cudaMemcpyHostToDevice, - // stream); - // phase 1: Implement kernel to compute KQV for input tokens compute_qkv_kernel(m, bc, diff --git a/src/ops/specinfer_inc_multihead_self_attention.cu b/src/ops/specinfer_inc_multihead_self_attention.cu index 3fdd1ab554..4d4afd28e4 100644 --- a/src/ops/specinfer_inc_multihead_self_attention.cu +++ b/src/ops/specinfer_inc_multihead_self_attention.cu @@ -50,7 +50,6 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( int hidden_size, BatchConfig::PerRequestInfo *request_infos, BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos, - BeamSearchBatchConfig::SpecInferTopology *topology_mask, BatchConfig::BitMask *causalMask, int max_tree_branches) { @@ -74,8 +73,6 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( // request idx int const request_idx = blockIdx.y; - BeamSearchBatchConfig::SpecInferTopology topology = - topology_mask[request_idx]; BatchConfig::BitMask bitmask = causalMask[request_idx]; int const first_step = 0; @@ -148,23 +145,7 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( } int const query_token = bitmask.tree_size - tree_branch_num + qi; - if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 0) { - // printf("fuckmasksss %d, %d, %d, %d, %d\n", - // bitmask.prompt_size, - // bitmask.non_tree_cache_size, - // tree_branch_num, - // bitmask.tree_size, - // tlength); - // printf("cacheposssssB %d, %d\n", tree_branch_num, - // topology.real_token_pos[0][1]); - // printf("cacheposssssC %d, %d\n", tree_branch_num, - // topology.real_token_pos[0][2]); - // printf("cacheposssssD %d, %d\n", tree_branch_num, - // topology.real_token_pos[0][11]); printf("cacheposssssD %d, %d\n", - // tree_branch_num, topology.real_token_pos[0][12]); - // printf("cacheposssssD %d, %d\n", tree_branch_num, - // topology.real_token_pos[0][13]); - } + __syncthreads(); for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { K_vec k[K_VECS_PER_THREAD]; @@ -173,10 +154,7 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { int jj = ii * THREADS_PER_KEY * K_VEC_SIZE; if (ti < totalCacheSize) { - // find the real position of the cache; - // depth: 0, 1, 2, 3, 4, 4, 5, 5 ,5, 5, - // int const real_cache_idx = - // topology.real_token_pos[sub_req_idx][ti]; + k[ii] = *reinterpret_cast( k_cache_batch + ti_circ * hidden_size + head_idx * per_head_size + jj); @@ -291,17 +269,12 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( DT const *v_cache_batch = value_cache + request_idx * max_seq_length * hidden_size * max_tree_branches + vi; - // DT const *v_cache_batch = - // value_cache + - // (beam_request_idx * max_beam_width + beam_sub_request_idx) * - // max_seq_length * hidden_size + - // vi; + if (Dh == Dh_MAX || vi < Dh) { for (int ti = first_step + vo; ti < totalCacheSize; ti += V_PER_ITER) { // Load the values from the cache. 
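// A minimal sketch of how this loop reads the value cache at this point in
// the series (a later commit drops the max_tree_branches factor), with V_vec
// assumed to be the kernel's vectorized value type:
//   DT const *v_cache_batch =
//       value_cache + request_idx * max_seq_length * hidden_size
//                     * max_tree_branches + vi;      // per-request region
//   int const ti_circ = ti % max_seq_length;         // circular timestep
//   V_vec v = *reinterpret_cast<V_vec const *>(
//       v_cache_batch + ti_circ * hidden_size + head_idx * per_head_size);
// Each timestep's value vector is then weighted by its softmax logit and
// accumulated into `out` before the final write-back.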
int const ti_circ = ti % max_seq_length; - // int const real_cache_idx = topology.real_token_pos[sub_req_idx][ti]; V_vec v = *reinterpret_cast( v_cache_batch + ti_circ * hidden_size + head_idx * per_head_size); @@ -365,7 +338,6 @@ __global__ void specinfer_store_kv_cache( BatchConfig::PerRequestInfo *requestInfo, BeamSearchBatchConfig::BeamSearchPerTokenInfo *beamTokenInfos, BeamSearchBatchConfig::BeamSearchPerRequestInfo *beamRequestInfos, - BeamSearchBatchConfig::SpecInferTopology *beam_topology_mask, BatchConfig::BitMask *causalMask, int qProjSize, int kProjSize, @@ -390,7 +362,6 @@ __global__ void specinfer_store_kv_cache( int const first_token_in_req = requestInfo[req_id].first_token_depth_in_request; int const sub_req_id = beamTokenInfos[token_idx].sub_request_index; - int const allocated_tokens = beam_topology_mask[req_id].allocated_tokens; int const total_token = requestInfo[req_id].num_tokens_in_batch; int const request_token_offset = @@ -412,17 +383,6 @@ __global__ void specinfer_store_kv_cache( bitmask.this_layer_size + token_idx - request_token_offset; - int real_idx = tok_id - first_token_in_req + allocated_tokens + sub_req_id; - - // if (i % hidden_size == 0) { - // printf("ffasdasds request %d, real idx %d, cache idx %d token id %d, kval %.10f\n", - // req_id, - // real_idx, - // cache_idx, - // tok_id, - // kVal); - // } - kCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + (cache_idx)*hidden_size + offset] = kVal; vCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + @@ -454,7 +414,6 @@ void update_kv_cache_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, m->request_infos, m->beam_token_infos, m->beam_request_infos, - m->beam_topology_mask, m->causalMask, m->qProjSize, m->kProjSize, @@ -490,7 +449,6 @@ void update_kv_cache_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, m->hidden_size, \ m->request_infos, \ m->beam_request_infos, \ - m->beam_topology_mask, \ m->causalMask, \ BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES) @@ -788,16 +746,6 @@ void inference_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, cudaStream_t stream) { // phase 1: Implement kernel to compute KQV for input tokens - cudaMemcpyAsync(m->causalMask, - &(bc->causalMask), - bc->num_active_requests() * sizeof(BatchConfig::BitMask), - cudaMemcpyHostToDevice, - stream); - // std::cout << "kernel bit mask: " << bc->causalMask[0].prompt_size << ", " - // << bc->causalMask[0].non_tree_cache_size << ", " - // << bc->causalMask[0].mask[0] << ", " << - // sizeof(BatchConfig::BitMask) - // << "\n"; compute_qkv_kernel(m, bc, shard_id, @@ -953,38 +901,30 @@ SpecInferIncMultiHeadSelfAttentionMeta::SpecInferIncMultiHeadSelfAttentionMeta( // allocate memory for the seqArray and reserve space { - size_t causal_mask_size = BatchConfig::MAX_NUM_REQUESTS; - size_t total_size = causal_mask_size * sizeof(BatchConfig::BitMask); - gpu_mem_allocator.create_legion_instance(beam_search_reserve_inst, - total_size); + // size_t causal_mask_size = BatchConfig::MAX_NUM_REQUESTS; + // size_t total_size = causal_mask_size * sizeof(BatchConfig::BitMask); + // gpu_mem_allocator.create_legion_instance(beam_search_reserve_inst, + // total_size); - beam_topology_mask = - static_cast( - handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + - sizeof(BatchConfig::requestsInfo)); beam_token_infos = static_cast( handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + - sizeof(BatchConfig::requestsInfo) + - sizeof(BeamSearchBatchConfig::topology_mask)); + 
sizeof(BatchConfig::requestsInfo)); beam_request_infos = static_cast( handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + - sizeof(BatchConfig::requestsInfo) + - sizeof(BeamSearchBatchConfig::topology_mask) + + sizeof(BatchConfig::requestsInfo) + sizeof(BeamSearchBatchConfig::beamTokenInfo)); - // causalMask = - // static_cast( - // handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + - // sizeof(BatchConfig::requestsInfo) + - // sizeof(BeamSearchBatchConfig::topology_mask) + - // sizeof(BeamSearchBatchConfig::beamTokenInfo)) + - // sizeof(BeamSearchBatchConfig::beamRequestsInfo); - - causalMask = gpu_mem_allocator.allocate_instance( - causal_mask_size); + causalMask = static_cast( + handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + + sizeof(BatchConfig::requestsInfo) + + sizeof(BeamSearchBatchConfig::beamTokenInfo) + + sizeof(BeamSearchBatchConfig::beamRequestsInfo)); + + // causalMask = gpu_mem_allocator.allocate_instance( + // causal_mask_size); // beam_token_infos = // gpu_mem_allocator // .allocate_instance( diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index 11169fa36d..ebbfac23ea 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -191,8 +191,8 @@ __global__ void compute_attention_kernel_fused_kernel( qk_max = mask ? qk_max : fmaxf(qk_max, qk); // if (head_idx == 0 && qi == 1 && !mask && tidx == 0) { - // printf("tree attn qkqkqkqk request id %d, %d %.10f, %.10f, %.10f\n", - // request_idx, + // printf("tree attn qkqkqkqk request id %d, %d %.10f, %.10f, + // %.10f\n", request_idx, // ti, // qk, // q_vecs[ki_o][0].x, @@ -355,7 +355,8 @@ __global__ void compute_attention_kernel_fused_kernel( // out.z, // out.w, // vi, - // (first_token_idx + qi) * hidden_size + head_idx * per_head_size + + // (first_token_idx + qi) * hidden_size + head_idx * + // per_head_size + // vi); // } } @@ -502,15 +503,21 @@ __global__ void update_tree_branch_kv_cache_fused( int const req_id = tokenInfos[token_idx].request_index; int const tok_id = tokenInfos[token_idx].abs_depth_in_request; - int const request_token_offset = request_infos[req_id].first_token_offset_in_batch; + int const request_token_offset = + request_infos[req_id].first_token_offset_in_batch; // if(i % hidden_size == 0){ - // printf("update token request id: %d, %d, %d value%.10f\n", req_id, token_idx, request_token_offset, kVal); + // printf("update token request id: %d, %d, %d value%.10f\n", req_id, + // token_idx, request_token_offset, kVal); // } kCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + - (token_idx + first_token_depth - request_token_offset) * hidden_size + offset] = kVal; + (token_idx + first_token_depth - request_token_offset) * + hidden_size + + offset] = kVal; vCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + - (token_idx + first_token_depth - request_token_offset) * hidden_size + offset] = vVal; + (token_idx + first_token_depth - request_token_offset) * + hidden_size + + offset] = vVal; } } @@ -974,8 +981,9 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, static_cast
(m->devQKVProjArray), bias_ptr, stream); - // print_tensor((float *)m->devQKVProjArray + 768 * 8 * 3 + 768, 32, "qkvtenor1"); - // print_tensor((float *)m->devQKVProjArray + 768 * 18 * 3 + 768, 32, "qkvtenor2"); + // print_tensor((float *)m->devQKVProjArray + 768 * 8 * 3 + 768, 32, + // "qkvtenor1"); print_tensor((float *)m->devQKVProjArray + 768 * 18 * + // 3 + 768, 32, "qkvtenor2"); // phase 2: No need to update key/val cache // IncMultiHeadSelfAttention::update_kv_cache_kernel( @@ -1019,7 +1027,7 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( cudaEventRecord(t_start, stream); } - std::cout << "tree input tokens: " <num_active_tokens() << "\n"; + std::cout << "tree input tokens: " << bc->num_active_tokens() << "\n"; // assert(input.data_type == weight.data_type); assert(input.data_type == output.data_type); @@ -1128,34 +1136,15 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( // allocate memory for the seqArray and reserve space { - int max_tokens_per_batch = BatchConfig::max_tokens_per_batch(); - size_t committed_tokeninfo_size = max_tokens_per_batch; - size_t causal_mask_size = BatchConfig::MAX_NUM_REQUESTS; - - size_t total_size = committed_tokeninfo_size * - sizeof(TreeVerifyBatchConfig::CommittedTokensInfo) + - causal_mask_size * sizeof(BatchConfig::BitMask); - if (offload) { - // assert that we have enough reserved work space left - assert(gpu_mem_allocator.reserved_total_size - - gpu_mem_allocator.reserved_allocated_size >= - total_size); - committed_token_infos = - gpu_mem_allocator - .allocate_reserved( - committed_tokeninfo_size); - causalMask = gpu_mem_allocator.allocate_instance( - causal_mask_size); - } else { - gpu_mem_allocator.create_legion_instance(committed_token_reserve_inst, - total_size); - committed_token_infos = - gpu_mem_allocator - .allocate_instance( - committed_tokeninfo_size); - causalMask = gpu_mem_allocator.allocate_instance( - causal_mask_size); - } + + causalMask = static_cast( + handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + + sizeof(BatchConfig::requestsInfo)); + committed_token_infos = + static_cast( + handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + + sizeof(BatchConfig::requestsInfo) + + sizeof(BatchConfig::causalMask)); } cudaStreamSynchronize(stream); diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index ab062a4610..670db1ab0e 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -766,12 +766,6 @@ BeamSearchBatchConfig new_bc.beamRequestsInfo[i].sub_request_num = 1; new_bc.sub_requests[i] = 1; - new_bc.topology_mask[i].allocated_tokens = request.tokens.size(); - - // assign new kv cache position - for (int j = 0; j < request.tokens.size(); j++) { - new_bc.topology_mask[i].real_token_pos[0][j] = j; - } updateBitMask(new_bc.causalMask[i], verified_tokens.size(), @@ -786,8 +780,6 @@ BeamSearchBatchConfig new_bc.tokensInfo[new_bc.num_tokens].token_id = token.first; new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = token.second; - new_bc.topology_mask[i].real_token_pos[0][token.second] = - new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request; // Beam Token Info new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = 0; @@ -846,7 +838,6 @@ BeamSearchBatchConfig } new_bc.beamRequestsInfo[i].sub_request_num = 1; - new_bc.topology_mask[i].allocated_tokens = 0; new_bc.sub_requests[i] = 1; @@ -919,14 +910,12 @@ BeamSearchBatchConfig assert(depth < new_request.tokens.size()); 
new_bc.tokensInfo[new_bc.num_tokens].token_id = new_request.tokens[depth]; - new_bc.topology_mask[i].real_token_pos[0][depth] = depth; // beam search meta data, indicate which sub request this token // belongs to, init to 0; new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = 0; new_bc.num_tokens++; } - new_bc.topology_mask[i].allocated_tokens = 0; initBitMask(new_bc.causalMask[i], new_bc.requestsInfo[i].num_tokens_in_batch); @@ -1120,9 +1109,6 @@ BeamSearchBatchConfig update_beam_metadata( new_bc, old_bc, request.beam_trees.at(old_bc.model_id), i); - new_bc.topology_mask[i].allocated_tokens = - old_bc.topology_mask[i].allocated_tokens + - old_bc.beamRequestsInfo[i].sub_request_num; } else { assert(false && "Request should not be pending in beam search phase"); } @@ -1156,31 +1142,9 @@ BeamSearchBatchConfig << std::endl; } - // for (int j = 0; j < request.tokens.size(); j++) { - // new_bc.topology_mask[i].real_token_pos[0][j] = j; - // } - // register more tokens due to the beam width - std::cout << "register more tokens: " - << new_bc.beamRequestsInfo[i].sub_request_num << ", " - << new_bc.requestsInfo[i].num_tokens_in_batch << ", " - << new_bc.topology_mask[i].allocated_tokens << "\n"; - - // copy meta data and replicate - int replicate_num = new_bc.beamRequestsInfo[i].sub_request_num / - old_bc.beamRequestsInfo[i].sub_request_num; - - for (int j = 0; j < old_bc.beamRequestsInfo[i].sub_request_num; j++) { - int old_idx = j; - for (int k = 0; k < replicate_num; k++) { - int new_idx = j * replicate_num + k; - std::cout << "copy from " << old_idx << "to: " << new_idx << "\n"; - memcpy(new_bc.topology_mask[i].real_token_pos[new_idx], - old_bc.topology_mask[i].real_token_pos[old_idx], - sizeof(int) * BatchConfig::MAX_NUM_TOKENS); - } - } + //copy metadata memcpy(&new_bc.causalMask[i], &old_bc.causalMask[i], sizeof(BatchConfig::BitMask)); @@ -1215,14 +1179,6 @@ BeamSearchBatchConfig new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = k; new_bc.num_tokens++; - // width first - new_bc.topology_mask[i].real_token_pos[k][depth] = - new_bc.topology_mask[i].allocated_tokens + num_generation_tokens; - - // std::cout << "topology: sub request: " << k << ", " - // << ", " << depth << ", " - // << new_bc.topology_mask[i].real_token_pos[k][depth] << - // "\n"; num_generation_tokens++; } } diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu index 4d7e2c8806..e8824feda5 100644 --- a/src/runtime/request_manager.cu +++ b/src/runtime/request_manager.cu @@ -59,64 +59,74 @@ void RequestManager::load_tokens_task( // copy meta data to workSpace FFHandler handle = *((FFHandler const *)task->local_args); + size_t total_copy_size = 0; cudaMemcpyAsync(handle.batch_config_metadata, &(batch_config->tokensInfo), - batch_config->num_active_tokens() * - sizeof(BatchConfig::PerTokenInfo), + sizeof(BatchConfig::tokensInfo), cudaMemcpyHostToDevice, stream); + total_copy_size += sizeof(BatchConfig::tokensInfo); + cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - sizeof(BatchConfig::tokensInfo), + total_copy_size, &(batch_config->requestsInfo), - batch_config->max_requests_per_batch() * - sizeof(BatchConfig::PerRequestInfo), + sizeof(BatchConfig::requestsInfo), cudaMemcpyHostToDevice, stream); + total_copy_size += sizeof(BatchConfig::requestsInfo); - // load speculative metadata if (batch_config->get_mode() == BEAM_SEARCH_MODE) { BeamSearchBatchConfig const *beam_batch_config = static_cast(batch_config); cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - 
sizeof(BatchConfig::tokensInfo) + - sizeof(BatchConfig::requestsInfo), - &(beam_batch_config->topology_mask), - sizeof(BeamSearchBatchConfig::topology_mask), - cudaMemcpyHostToDevice, - stream); - - cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - sizeof(BatchConfig::tokensInfo) + - sizeof(BatchConfig::requestsInfo) + - sizeof(BeamSearchBatchConfig::topology_mask), + total_copy_size, &(beam_batch_config->beamTokenInfo), sizeof(BeamSearchBatchConfig::beamTokenInfo), cudaMemcpyHostToDevice, stream); + + total_copy_size += sizeof(BeamSearchBatchConfig::beamTokenInfo); + cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - sizeof(BatchConfig::tokensInfo) + - sizeof(BatchConfig::requestsInfo) + - sizeof(BeamSearchBatchConfig::topology_mask) + - sizeof(BeamSearchBatchConfig::beamTokenInfo), + total_copy_size, &(beam_batch_config->beamRequestsInfo), sizeof(BeamSearchBatchConfig::beamRequestsInfo), cudaMemcpyHostToDevice, stream); + total_copy_size += sizeof(BeamSearchBatchConfig::beamRequestsInfo); - // cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - // sizeof(BatchConfig::tokensInfo) + - // sizeof(BatchConfig::requestsInfo) + - // sizeof(BeamSearchBatchConfig::topology_mask) + - // sizeof(BeamSearchBatchConfig::beamTokenInfo) + - // sizeof(BeamSearchBatchConfig::beamRequestsInfo), - // &(beam_batch_config->causalMask), - // sizeof(BatchConfig::causalMask), - // cudaMemcpyHostToDevice, - // stream); - // std::cout << "copy calsual mask info: " << beam_batch_config->causalMask[0].prompt_size << "\n"; + cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(beam_batch_config->causalMask), + sizeof(BatchConfig::causalMask), + cudaMemcpyHostToDevice, + stream); + + total_copy_size += sizeof(BatchConfig::causalMask); + } else if (batch_config->get_mode() == TREE_VERIFY_MODE) { + TreeVerifyBatchConfig const *tree_batch_config = + static_cast(batch_config); + + cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(tree_batch_config->causalMask), + sizeof(BatchConfig::causalMask), + cudaMemcpyHostToDevice, + stream); + total_copy_size += sizeof(BatchConfig::causalMask); + cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(tree_batch_config->committed_tokens), + sizeof(TreeVerifyBatchConfig::committed_tokens), + cudaMemcpyHostToDevice, + stream); + total_copy_size += sizeof(TreeVerifyBatchConfig::committed_tokens); } + + // add a size check + assert(total_copy_size <= handle.batch_config_metadata_size); } void RequestManager::load_positions_task( From fae148da9a4b495d26642c1929ebe9f25cdf3b1d Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Sat, 30 Dec 2023 05:11:38 -0500 Subject: [PATCH 09/30] fix some corner cases --- include/flexflow/model.h | 1 + .../inc_multihead_self_attention_utils.cuh | 4 +- include/flexflow/request_manager.h | 7 + inference/spec_infer/spec_infer.cc | 6 +- src/ops/argmax.cc | 2 +- src/ops/beam_topk.cc | 1 + src/ops/inc_multihead_self_attention.cu | 8 +- src/ops/kernels/embedding_kernels.cu | 2 +- src/ops/spec_inc_multihead_self_attention.cu | 18 +-- .../specinfer_inc_multihead_self_attention.cu | 75 +++++----- src/ops/tree_inc_multihead_self_attention.cu | 94 ++++++------ src/runtime/cuda_helper.cu | 2 +- src/runtime/inference_manager.cc | 61 +++++++- src/runtime/model.cc | 17 +++ src/runtime/request_manager.cc | 141 ++++++++++++++---- src/runtime/request_manager.cu | 87 +++++++++++ 16 files changed, 389 insertions(+), 137 deletions(-) diff --git 
a/include/flexflow/model.h b/include/flexflow/model.h index 3602cb108b..9cdbec64a9 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -242,6 +242,7 @@ enum TaskIDs { // InferenceManager & RequestManager RM_LOAD_TOKENS_TASK_ID, RM_LOAD_POSITION_TASK_ID, + RM_LOAD_BATCH_CONFIG_TASK_ID, RM_PREPARE_NEXT_BATCH_TASK_ID, RM_PREPARE_NEXT_BATCH_INIT_TASK_ID, RM_PREPARE_NEXT_BATCH_BEAM_TASK_ID, diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh index 0c065b6b0e..1b21a80dc9 100644 --- a/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh +++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh @@ -456,7 +456,7 @@ inline size_t smem_size_in_bytes(int hidden_size_per_head, int threads_per_block) { // The amount of shared memory needed to store the Q*K^T values in float. - size_t qk_sz = div_up(1000 + 1, 4) * 16; + size_t qk_sz = div_up(2000 + 1, 4) * 16; size_t logits_sz = qk_sz; // The total size needed during softmax. @@ -493,7 +493,7 @@ inline void smem_size_in_bytes_tree(int hidden_size_per_head, } // todo fix this - int max_qk_length = max_query_length * max_total_length; + int max_qk_length = max_query_length * max_total_length + 1000; // The amount of shared memory needed to store the Q*K^T values in float. size_t qk_sz = div_up(max_qk_length + 1, 4) * 16; diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index dc1939c74b..8cb45e55b4 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -43,6 +43,8 @@ class InferenceManager { void load_positions(BatchConfigFuture const &bc, ParallelTensor position_input, int offset); + void load_inference_metadata_batch_config(BatchConfigFuture const &bc, + FFHandler *handlers); public: FFConfig ff_config; @@ -195,6 +197,11 @@ class RequestManager { Legion::Context ctx, Legion::Runtime *runtime); + static void + load_batch_config_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); static BatchConfig prepare_next_batch_task( Legion::Task const *task, std::vector const ®ions, diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index 9af3e12e5a..258b2d78eb 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -266,9 +266,9 @@ void FlexFlow::top_level_task(Task const *task, ModelMeta model_metadata; bool use_full_precision = false; bool verbose = false; - int max_requests_per_batch = 16; - int max_tokens_per_batch = 256; - int max_sequence_length = 1024; + int max_requests_per_batch = 10; + int max_tokens_per_batch = 199; + int max_sequence_length = 200; InputArgs const &command_args = HighLevelRuntime::get_input_args(); char **argv = command_args.argv; diff --git a/src/ops/argmax.cc b/src/ops/argmax.cc index 0344c707fc..d195a5af75 100644 --- a/src/ops/argmax.cc +++ b/src/ops/argmax.cc @@ -399,7 +399,7 @@ InferenceResult m, shard_id, bc, {}, {}, {input, indices}); } - print_tensor(indices.get_int32_ptr(), 32, "tree attn output"); + // print_tensor(indices.get_int32_ptr(), 199, "tree attn output"); download_tensor( indices.get_int32_ptr(), ir.token_ids, batch_size); return ir; diff --git a/src/ops/beam_topk.cc b/src/ops/beam_topk.cc index 20d019eec3..5dfaae41ee 100644 --- a/src/ops/beam_topk.cc +++ b/src/ops/beam_topk.cc @@ -404,6 +404,7 @@ BeamInferenceResult // print_tensor(index_ptr, 32, "indexxxxxxx"); + if 
(m->inference_debugging) { assert(task->index_point.get_dim() == 1); int shard_id = task->index_point.point_data[0]; diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index a084f216e9..2f16dd71c2 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -1365,12 +1365,12 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( // a K-ary tree max node is (k^n - 1) / 2 key_cache_size = num_q_heads * kProjSize * BeamSearchBatchConfig::max_requests_per_batch() * - BatchConfig::max_sequence_length() * - BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES; + (BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); value_cache_size = num_q_heads * vProjSize * BeamSearchBatchConfig::max_requests_per_batch() * - BatchConfig::max_sequence_length() * - BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES; + (BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); break; } default: diff --git a/src/ops/kernels/embedding_kernels.cu b/src/ops/kernels/embedding_kernels.cu index 0cde42de56..3085fdb6ba 100644 --- a/src/ops/kernels/embedding_kernels.cu +++ b/src/ops/kernels/embedding_kernels.cu @@ -118,7 +118,7 @@ void forward_kernel_wrapper(EmbeddingMeta const *m, // print_tensor(output_ptr, output_domain.get_volume(), // "[Embedding:forward:output]"); } - // print_tensor(input.get_int32_ptr(), 32, "embeddinginput"); + print_tensor(input.get_int32_ptr(), 200, "embeddinginput"); } /*static*/ diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 562dee4d93..29e3d9a48d 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -111,15 +111,15 @@ __global__ void spec_store_kv_cache( // naive cache stealing if (sub_req_id != parent_id) { - if (offset == 0 && tok_id == 0) { - printf("cache stealing!, depth %d req_id %d sub_req_id %d, parentid " - "%d, tok_id %d\n", - beam_depth, - req_id, - sub_req_id, - parent_id, - tok_id); - } + // if (offset == 0 && tok_id == 0) { + // printf("cache stealing!, depth %d req_id %d sub_req_id %d, parentid " + // "%d, tok_id %d\n", + // beam_depth, + // req_id, + // sub_req_id, + // parent_id, + // tok_id); + // } for (int depth = 0; depth < beam_depth; depth++) { int steal_token_idx = tok_id - beam_depth + depth; diff --git a/src/ops/specinfer_inc_multihead_self_attention.cu b/src/ops/specinfer_inc_multihead_self_attention.cu index 4d4afd28e4..e84ec3095c 100644 --- a/src/ops/specinfer_inc_multihead_self_attention.cu +++ b/src/ops/specinfer_inc_multihead_self_attention.cu @@ -50,8 +50,7 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( int hidden_size, BatchConfig::PerRequestInfo *request_infos, BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos, - BatchConfig::BitMask *causalMask, - int max_tree_branches) { + BatchConfig::BitMask *causalMask) { // q, k using Q_vec = typename VEC_K::Type; @@ -83,8 +82,14 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { // printf("specinfer attn fused kernel %lld\n", bitmask.mask[1]); // } + int const totalCacheSize = bitmask.non_tree_cache_size + bitmask.tree_size; + + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + // printf("specinfer attn fused kernel %d, %d\n", + // totalCacheSize,request_infos[request_idx].num_tokens_in_batch); + // } // int const qlength = 
request_infos[request_idx].num_tokens_in_batch; int const tree_branch_num = beam_request_infos[request_idx].sub_request_num; @@ -94,7 +99,7 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( int first_token_idx = 0; for (int r = 0; r < request_idx; r++) { // first_token_idx += request_infos[request_idx].num_tokens_in_batch; - first_token_idx += bitmask.this_layer_size; + first_token_idx += causalMask[r].this_layer_size; } // if (tidx == 0 && head_idx == 0) { @@ -130,8 +135,7 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; DT const *k_cache_batch = - key_cache + - request_idx * max_seq_length * hidden_size * max_tree_branches + ki; + key_cache + request_idx * max_seq_length * hidden_size + ki; int ti_end = div_up(totalCacheSize - first_step, K_PER_WARP) * K_PER_WARP + first_step; @@ -267,9 +271,7 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( // The base pointer for the value in the cache buffer. DT const *v_cache_batch = - value_cache + - request_idx * max_seq_length * hidden_size * max_tree_branches + vi; - + value_cache + request_idx * max_seq_length * hidden_size + vi; if (Dh == Dh_MAX || vi < Dh) { for (int ti = first_step + vo; ti < totalCacheSize; ti += V_PER_ITER) { @@ -344,7 +346,6 @@ __global__ void specinfer_store_kv_cache( int vProjSize, int num_tokens, int max_seq_len, - int max_tree_branches, bool is_root, int hidden_size) { CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { @@ -383,10 +384,10 @@ __global__ void specinfer_store_kv_cache( bitmask.this_layer_size + token_idx - request_token_offset; - kCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + - (cache_idx)*hidden_size + offset] = kVal; - vCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + - (cache_idx)*hidden_size + offset] = vVal; + kCache_ptr[req_id * (hidden_size * max_seq_len) + (cache_idx)*hidden_size + + offset] = kVal; + vCache_ptr[req_id * (hidden_size * max_seq_len) + (cache_idx)*hidden_size + + offset] = vVal; } } @@ -419,8 +420,8 @@ void update_kv_cache_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, m->kProjSize, m->vProjSize, num_tokens, - BatchConfig::max_sequence_length(), - BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES, + BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, /*root*/ curr_depth == 0, m->hidden_size); } @@ -429,7 +430,8 @@ void update_kv_cache_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, #define LAUNCH_SPECINFER_ATTENTION_SCORE_KERNEL( \ DT, Dh, Dh_MAX, THDS_PER_KEY, THREADS_PER_VALUE, THDS_PER_BLOCK, stream) \ smem_sz = smem_size_in_bytes
(m->qProjSize, \ - BatchConfig::max_sequence_length(), \ + BatchConfig::max_sequence_length() + \ + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, \ THREADS_PER_VALUE, \ THDS_PER_BLOCK); \ compute_specinfer_attention_kernel_generation_kernel(m->valueCache), \ output_ptr, \ scale, \ - BatchConfig::max_sequence_length(), \ + BatchConfig::max_sequence_length() + \ + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, \ m->qProjSize, \ m->hidden_size, \ m->request_infos, \ m->beam_request_infos, \ - m->causalMask, \ - BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES) + m->causalMask) template void compute_specinfer_attention_kernel_generation( @@ -527,11 +529,13 @@ void compute_attention_kernel_prompt( int q_block_size = m->qProjSize; int kt_block_size = m->kProjSize; - int kt_req_block_size = - kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + int kt_req_block_size = kt_block_size * m->num_q_heads * + (BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); int vt_block_size = m->vProjSize; - int vt_req_block_size = - vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + int vt_req_block_size = vt_block_size * m->num_q_heads * + (BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); assert(m->qProjSize == m->kProjSize); for (int i = 0; i < bc->max_requests_per_batch(); i++) { @@ -580,8 +584,7 @@ void compute_attention_kernel_prompt( // print_tensor((float*)A, 32, "A"); std::cout << "meta: " << num_new_tokens << ", " << total_tokens << "\n"; - DT const *B = static_cast
<DT *>(m->keyCache) + - (i * bc->MAX_SPECULATIVE_TREE_BRANCHES) * kt_req_block_size; + DT const *B = static_cast<DT *>
(m->keyCache) + i * kt_req_block_size; // if (i == 0 && sub_req_id == 0 && // bc->beam_slots.at(0).current_depth == 1) { @@ -692,8 +695,7 @@ void compute_attention_kernel_prompt( strideC = m->vProjSize; // To get A, skip over V^T entries from previous requests (all heads + // padding) - A = static_cast
<DT *>(m->valueCache) + - (i * bc->MAX_SPECULATIVE_TREE_BRANCHES) * vt_req_block_size; + A = static_cast<DT *>
(m->valueCache) + i * vt_req_block_size; // To get B, skip over softmax(QK^T/sqrt(d_k)) entries from previous // requests (all heads) B = C_softmax; @@ -851,8 +853,10 @@ void SpecInferIncMultiHeadSelfAttention::inference_kernel_wrapper( // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr, // acc_output.rect, "[Attention:forward:output]"); } - // save_tensor(output.get_float_ptr(), 768 * 3, "/home/xinhaoc/FlexFlow/inference/output/fk1.txt"); - // save_tensor(output.get_float_ptr() + 768 * 3, 768 * 3, "/home/xinhaoc/FlexFlow/inference/output/fk2.txt"); + // save_tensor(output.get_float_ptr(), 768 * 3, + // "/home/xinhaoc/FlexFlow/inference/output/fk1.txt"); + // save_tensor(output.get_float_ptr() + 768 * 3, 768 * 3, + // "/home/xinhaoc/FlexFlow/inference/output/fk2.txt"); // if(bc->num_tokens == 1){ // print_tensor(input.get_float_ptr(), 32, "specinc input"); @@ -906,7 +910,6 @@ SpecInferIncMultiHeadSelfAttentionMeta::SpecInferIncMultiHeadSelfAttentionMeta( // gpu_mem_allocator.create_legion_instance(beam_search_reserve_inst, // total_size); - beam_token_infos = static_cast( handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + @@ -915,13 +918,13 @@ SpecInferIncMultiHeadSelfAttentionMeta::SpecInferIncMultiHeadSelfAttentionMeta( beam_request_infos = static_cast( handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + - sizeof(BatchConfig::requestsInfo) + + sizeof(BatchConfig::requestsInfo) + sizeof(BeamSearchBatchConfig::beamTokenInfo)); - causalMask = static_cast( - handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + - sizeof(BatchConfig::requestsInfo) + - sizeof(BeamSearchBatchConfig::beamTokenInfo) - + sizeof(BeamSearchBatchConfig::beamRequestsInfo)); + causalMask = static_cast( + handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + + sizeof(BatchConfig::requestsInfo) + + sizeof(BeamSearchBatchConfig::beamTokenInfo) + + sizeof(BeamSearchBatchConfig::beamRequestsInfo)); // causalMask = gpu_mem_allocator.allocate_instance( // causal_mask_size); diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index ebbfac23ea..8641e63e38 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -53,7 +53,6 @@ __global__ void compute_attention_kernel_fused_kernel( BatchConfig::PerRequestInfo *request_infos, int num_heads, int num_requests, - int max_tree_branches, BatchConfig::BitMask *causalMask, int qk_smem_sz) { @@ -86,8 +85,9 @@ __global__ void compute_attention_kernel_fused_kernel( BatchConfig::BitMask bitmask = causalMask[request_idx]; // bitmask.mask[1] = 3; - // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { - // printf("tree attn fused kernel %d, %d, %d, %lld\n", + // if (head_idx == 0 && tidx == 0) { + // printf("tree attn fused kernel req id %d %d, %d, %d, %lld\n", + // request_idx, // tlength, // qlength, // bitmask.non_tree_cache_size, @@ -96,12 +96,12 @@ __global__ void compute_attention_kernel_fused_kernel( int first_token_idx = 0; for (int r = 0; r < request_idx; r++) { - first_token_idx += request_infos[request_idx].num_tokens_in_batch; + first_token_idx += request_infos[r].num_tokens_in_batch; } - // if(tidx == 0 && head_idx == 0){ - // printf("tree req: %d, %d\n", request_idx, first_token_idx); - // } + if(tidx == 0 && head_idx == 0){ + printf("tree req: %d, %d\n", request_idx, first_token_idx); + } // shared memory objects extern __shared__ char smem_[]; @@ -132,8 +132,7 @@ __global__ void 
compute_attention_kernel_fused_kernel( constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; DT const *k_cache_batch = - key_cache + - request_idx * max_tree_branches * max_seq_length * hidden_size + ki; + key_cache + request_idx * max_seq_length * hidden_size + ki; int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; @@ -190,17 +189,14 @@ __global__ void compute_attention_kernel_fused_kernel( qk_max = mask ? qk_max : fmaxf(qk_max, qk); - // if (head_idx == 0 && qi == 1 && !mask && tidx == 0) { - // printf("tree attn qkqkqkqk request id %d, %d %.10f, %.10f, - // %.10f\n", request_idx, - // ti, - // qk, - // q_vecs[ki_o][0].x, - // q_vecs[ki_o][1].x, - // q_vecs[ki_o][2].x, - // q_vecs[ki_o][3].x, - // k[0].x); - // } + if (head_idx == 0 && qi == 0 && !mask) { + printf("tree attn qkqkqkqk request id %d, %d %.10f, %.10f, %.10f\n ", + request_idx, + ti, + qk, + q_vecs[ki_o][0].x, + k[0].x); + } qk_smem[ti - first_step] = mask ? 0.0f : qk; } } @@ -283,8 +279,7 @@ __global__ void compute_attention_kernel_fused_kernel( // The base pointer for the value in the cache buffer. DT const *v_cache_batch = - value_cache + - request_idx * max_seq_length * hidden_size * max_tree_branches + vi; + value_cache + request_idx * max_seq_length * hidden_size + vi; // DT const *v_cache_batch = // value_cache + // (beam_request_idx * max_beam_width + beam_sub_request_idx) * @@ -375,8 +370,7 @@ __global__ void commit_tokens_kernel( int num_tokens_to_commit, int num_active_tokens_in_last_batch, int max_seq_len, - int hidden_size, - int max_tree_branches) { + int hidden_size) { CUDA_KERNEL_LOOP(i, num_tokens_to_commit * hidden_size) { @@ -407,10 +401,10 @@ __global__ void commit_tokens_kernel( // kVal); // } - kCache_ptr[req_id * max_tree_branches * (hidden_size * max_seq_len) + - tok_id * hidden_size + offset] = kVal; - vCache_ptr[req_id * max_tree_branches * (hidden_size * max_seq_len) + - tok_id * hidden_size + offset] = vVal; + kCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + + offset] = kVal; + vCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + + offset] = vVal; } } @@ -434,9 +428,9 @@ void commit_tokens(TreeIncMultiHeadSelfAttentionMeta const *m, m->vProjSize, num_tokens_to_commit, m->num_active_tokens, // number of active tokens in previous batch - BatchConfig::max_sequence_length(), - m->hidden_size, - BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES); + BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, + m->hidden_size); } } @@ -488,7 +482,6 @@ __global__ void update_tree_branch_kv_cache_fused( int num_new_tokens, int max_seq_len, int hidden_size, - int max_tree_branches, int first_token_depth) { CUDA_KERNEL_LOOP(i, num_new_tokens * hidden_size) { @@ -510,11 +503,11 @@ __global__ void update_tree_branch_kv_cache_fused( // printf("update token request id: %d, %d, %d value%.10f\n", req_id, // token_idx, request_token_offset, kVal); // } - kCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + + kCache_ptr[req_id * (hidden_size * max_seq_len) + (token_idx + first_token_depth - request_token_offset) * hidden_size + offset] = kVal; - vCache_ptr[(req_id * max_tree_branches) * (hidden_size * max_seq_len) + + vCache_ptr[req_id * (hidden_size * max_seq_len) + (token_idx + first_token_depth - request_token_offset) * hidden_size + offset] = vVal; @@ -569,10 +562,12 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, int q_block_size = m->qProjSize; int kt_block_size = 
m->kProjSize; int kt_req_block_size = - kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM; int vt_block_size = m->vProjSize; int vt_req_block_size = - vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM; assert(m->qProjSize == m->kProjSize); for (int i = 0; i < bc->max_requests_per_batch(); i++) { @@ -836,7 +831,8 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, #define LAUNCH_TREE_VERIFY_ATTENTION_SCORE_KERNEL( \ DT, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, stream) \ smem_size_in_bytes_tree
(m->qProjSize, \ - BatchConfig::max_sequence_length(), \ + BatchConfig::max_sequence_length() + \ + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, \ THDS_PER_VALUE, \ THDS_PER_BLOCK, \ bc, \ @@ -848,7 +844,20 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, THDS_PER_KEY, \ THDS_PER_VALUE> \ <<>>( \ - static_cast
<DT *>(m->devQKVProjArray), static_cast
<DT *>(m->keyCache), static_cast<DT *>
(m->valueCache), output_ptr, scale, BatchConfig::max_sequence_length(), BatchConfig::max_tokens_per_batch(), m->qProjSize, m->hidden_size, m->request_infos, m->num_q_heads, bc->num_active_requests(), BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES, m->causalMask, \ + static_cast
<DT *>(m->devQKVProjArray), \ + static_cast
<DT *>(m->keyCache), \ + static_cast<DT *>
(m->valueCache), \ + output_ptr, \ + scale, \ + BatchConfig::max_sequence_length() + \ + BatchConfig::BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, \ + BatchConfig::max_tokens_per_batch(), \ + m->qProjSize, \ + m->hidden_size, \ + m->request_infos, \ + m->num_q_heads, \ + bc->num_active_requests(), \ + m->causalMask, \ smem_sz[0]) template @@ -880,9 +889,8 @@ void compute_attention_kernel_fused(TreeIncMultiHeadSelfAttentionMeta const *m, m->kProjSize, m->vProjSize, num_new_tokens, - BatchConfig::max_sequence_length(), + BatchConfig::max_sequence_length() + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, m->hidden_size, - BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES, bc->requestsInfo[0].first_token_depth_in_request); dim3 grid(m->num_q_heads, bc->num_active_requests()); @@ -981,9 +989,9 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, static_cast
(m->devQKVProjArray), bias_ptr, stream); - // print_tensor((float *)m->devQKVProjArray + 768 * 8 * 3 + 768, 32, - // "qkvtenor1"); print_tensor((float *)m->devQKVProjArray + 768 * 18 * - // 3 + 768, 32, "qkvtenor2"); + + // print_tensor((float *)m->devQKVProjArray, 32, "qkvtenor1"); + // print_tensor((float *)m->devQKVProjArray + 768 * (25 * 7) * 3, 32, "qkvtenor2"); // phase 2: No need to update key/val cache // IncMultiHeadSelfAttention::update_kv_cache_kernel( diff --git a/src/runtime/cuda_helper.cu b/src/runtime/cuda_helper.cu index fa6bf55fe5..398ed7f3cd 100644 --- a/src/runtime/cuda_helper.cu +++ b/src/runtime/cuda_helper.cu @@ -226,7 +226,7 @@ __host__ void print_tensor(T const *ptr, printf("%s, %d---->", prefix, shard_id); for (idx = 0; idx < num_elements; idx++) { printf(" %.20lf", (float)host_ptr[idx]); - if (idx >= 100) { + if (idx >= 200) { break; } } diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index 52fd64c606..e7f7c5f52d 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -320,6 +320,7 @@ FutureMap InferenceManager::inference(FFModel *model, assert(op->numOutputs == 1); ParallelTensor pt = tensor_buffer[op->outputs[0]][batch_index]; load_input_tokens_from_batch_config(bc, pt, model->handlers); + load_inference_metadata_batch_config(bc, model->handlers); } } @@ -349,18 +350,32 @@ FutureMap InferenceManager::inference(FFModel *model, }; void InferenceManager::load_input_tokens_from_batch_config( - BatchConfigFuture const &bc, ParallelTensor const input, FFHandler *handlers) { + BatchConfigFuture const &bc, + ParallelTensor const input, + FFHandler *handlers) { Context ctx = ff_config.lg_ctx; Runtime *runtime = ff_config.lg_hlr; size_t machine_view_hash = input->machine_view.hash(); ArgumentMap argmap; - Rect<1> task_rect(Point<1>(0), - Point<1>(ff_config.workersPerNode * ff_config.numNodes - 1)); - IndexSpaceT<1> task_is = runtime->create_index_space(ctx, task_rect); - MachineView view = input->machine_view; - for (PointInRectIterator<1> it(task_rect); it(); it++) { - FFHandler handle = handlers[view.get_device_id(*it)]; - argmap.set_point(*it, TaskArgument(&handle, sizeof(FFHandler))); + Domain domain = runtime->get_index_space_domain(ctx, input->parallel_is); + + switch (domain.get_dim()) { +#define DIMFUNC(DIM) \ + case DIM: { \ + Rect rect = domain; \ + MachineView view = input->machine_view; \ + int idx = 0; \ + for (PointInRectIterator it(rect); it(); it++) { \ + argmap.set_point(*it, \ + TaskArgument(&handlers[view.get_device_id(*it)], \ + sizeof(FFHandler))); \ + } \ + break; \ + } + LEGION_FOREACH_N(DIMFUNC) +#undef DIMFUNC + default: + assert(false); } IndexLauncher launcher(RM_LOAD_TOKENS_TASK_ID, @@ -378,6 +393,36 @@ void InferenceManager::load_input_tokens_from_batch_config( runtime->execute_index_space(ctx, launcher); } +void InferenceManager::load_inference_metadata_batch_config( + BatchConfigFuture const &bc, + FFHandler *handlers) { + Context ctx = ff_config.lg_ctx; + Runtime *runtime = ff_config.lg_hlr; + ArgumentMap argmap; + + Rect<1> task_rect(Point<1>(0), + Point<1>(ff_config.workersPerNode * ff_config.numNodes - 1)); + IndexSpaceT<1> task_is = runtime->create_index_space(ctx, task_rect); + + // int rank = 0; + int idx = 0; + for (PointInRectIterator<1> it(task_rect); it(); it++) { + FFHandler handler = handlers[idx++]; + argmap.set_point(*it, TaskArgument(&handler, sizeof(FFHandler))); + } + + IndexLauncher launcher(RM_LOAD_BATCH_CONFIG_TASK_ID, + task_is, + TaskArgument(nullptr, 0), 
+ argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + FFConfig::DataParallelism_GPU); + launcher.add_future(bc); + runtime->execute_index_space(ctx, launcher); +} + void InferenceManager::load_positions(BatchConfigFuture const &bc, ParallelTensor position_input, int offset) { diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 8bda9016c3..cf72f2d40b 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -4344,6 +4344,23 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar); } } + // RequestManager load metadata + { + TaskVariantRegistrar registrar(RM_LOAD_BATCH_CONFIG_TASK_ID, + "RequestManager Load meta data"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "RequestManager Load metadata Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant( + registrar); + } + } // RequestManager prepare_next_batch { TaskVariantRegistrar registrar(RM_PREPARE_NEXT_BATCH_TASK_ID, diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 670db1ab0e..5c3262eb27 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -609,8 +609,8 @@ BeamSearchBatchConfig committed_tokens[guid].emplace_back(abs_depth, result_index); } else if (abs_depth >= root_abs_depth) { tree_outputs.emplace_back(token_id, abs_depth + 1); - std::cout << "committred tokens push: " << abs_depth - << " ,result index: " << result_index << "\n"; + // std::cout << "committred tokens push: " << abs_depth + // << " ,result index: " << result_index << "\n"; committed_tokens[guid].emplace_back(abs_depth, result_index); if (verbose) { @@ -621,12 +621,12 @@ BeamSearchBatchConfig tree_outputs.back().second, token_id); } - std::cout << "Index within old batch: " << result_index << std::endl; - printf(" Input: [%d] %d ---> [%d] %d \n", - abs_depth, - old_bc.tokensInfo[result_index].token_id, - tree_outputs.back().second, - token_id); + // std::cout << "Index within old batch: " << result_index << std::endl; + // printf(" Input: [%d] %d ---> [%d] %d \n", + // abs_depth, + // old_bc.tokensInfo[result_index].token_id, + // tree_outputs.back().second, + // token_id); } result_index++; } @@ -634,13 +634,12 @@ BeamSearchBatchConfig if (request.status == Request::RUNNING) { std::cout << "verify running: " << dfs_tree_inputs.at(guid).size() << ", " << tree_outputs.size() << "\n"; - + std::vector> verified_tokens = traverse_verify_tree(guid, dfs_tree_inputs.at(guid), tree_outputs); log_req_mgr.print("Number of Verified Tokens = %zu", verified_tokens.size()); - // check if the request is finished if (verified_tokens.size() + request.tokens.size() >= request.max_sequence_length) { @@ -805,7 +804,12 @@ BeamSearchBatchConfig } log_req_mgr.print("Output: %s", output.c_str()); } - + + if (request.tokens.size() > 19 && i >= 7) { + std::cout << request.tokens.size() << "\n"; + assert(false); + } + } else if (request.status == Request::PENDING) { new_bc.request_completed[i] = false; new_bc.request_running[i] = false; @@ -1099,7 +1103,8 @@ BeamSearchBatchConfig // } assert(new_bc.beamRequestsInfo[i].sub_request_num <= - BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES); + BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES && + "exceed maximum nodes per layer"); if (request.status == Request::RUNNING) { new_bc.beamRequestsInfo[i].current_depth = @@ -1144,7 +1149,7 @@ 
BeamSearchBatchConfig // register more tokens due to the beam width - //copy metadata + // copy metadata memcpy(&new_bc.causalMask[i], &old_bc.causalMask[i], sizeof(BatchConfig::BitMask)); @@ -1185,9 +1190,6 @@ BeamSearchBatchConfig // if(new_bc.beamRequestsInfo[i].current_depth >= 3 && i > 0){ // assert(false); // } - - - } } @@ -1238,7 +1240,8 @@ BeamSearchBatchConfig old_bc.beamRequestsInfo[i].sub_request_num; assert(new_bc.beamRequestsInfo[i].sub_request_num <= - BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES); + BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES && + "exceed maximum nodes per layer"); // update the parentid, accumalated_probs, depth, and token_ids @@ -1504,6 +1507,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( std::cout << "prepare next batch verify: " << dfs_tree_inputs.size() << "\n"; + bool cutLayer = false; // Add Tokens from the DFS Tree to the next batch for (int j = 1; j < dfs_tree_inputs.size(); j++) { auto token = dfs_tree_inputs.at(j); @@ -1520,11 +1524,27 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.num_tokens++; new_bc.requestsInfo[i].num_tokens_in_batch++; - if (new_bc.num_tokens == get_max_tokens_per_batch() - 1) { + if (new_bc.num_tokens == get_max_tokens_per_batch() && + (j != dfs_tree_inputs.size() - 1)) { + cutLayer = true; break; } } + // delete the last incomplete layer + if (cutLayer) { + int total_tokens = new_bc.num_tokens; + for (int j = total_tokens - 1; j >= 1; j--) { + new_bc.num_tokens--; + new_bc.requestsInfo[i].num_tokens_in_batch--; + std::cout << "cut: " << j << "\n"; + if (new_bc.tokensInfo[j].abs_depth_in_request != + new_bc.tokensInfo[j - 1].abs_depth_in_request) { + break; + } + } + } + } else if (request.status == Request::PENDING) { std::cout << "prepare next batch verify: pending\n" << "\n"; @@ -1646,6 +1666,8 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( } } + std::cout << "how many tokens in verify? 
" << new_bc.num_tokens << "\n"; + std::cout << "check dfs tree input size: " << dfs_tree_inputs[1000000].size() << "\n"; @@ -1673,7 +1695,10 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, old_bc.requestsInfo[old_bc.tokensInfo[i].request_index].request_guid != guid) { - std::cout << "i is: " << i << "old guid" << guid << " new guid" << old_bc.requestsInfo[old_bc.tokensInfo[i].request_index].request_guid <<"\n"; + std::cout << "i is: " << i << "old guid" << guid << " new guid" + << old_bc.requestsInfo[old_bc.tokensInfo[i].request_index] + .request_guid + << "\n"; int index = old_bc.tokensInfo[i - 1].request_index; int beam_size = old_bc.beamRequestsInfo[index].beam_size; @@ -1689,18 +1714,19 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, // Count tokens sent to model in this request to find the final token's // index - std::cout << "previous result index: "<< result_index; + std::cout << "previous result index: " << result_index; result_index += (old_bc.tokensInfo[i - 1].abs_depth_in_request - start_depth) * beam_size; - - std::cout << "after result index: "<< result_index; + + std::cout << "after result index: " << result_index; // if (true) { // std::cout << "i = " << i << ", result index = " << result_index // << ", value: " << result.token_ids[result_index] - // << ", leaf node num: " << leaf_node_num << ", depth" << depth + // << ", leaf node num: " << leaf_node_num << ", depth" << + // depth // << ", beam size: " << beam_size << "\n"; // } @@ -1765,7 +1791,8 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, if (i < old_bc.num_tokens) { int new_req_idx = old_bc.tokensInfo[i].request_index; guid = old_bc.requestsInfo[new_req_idx].request_guid; - std::cout << "update guid: " << guid << ", request idx: " << index<< "\n"; + std::cout << "update guid: " << guid << ", request idx: " << index + << "\n"; start_depth = old_bc.tokensInfo[i].abs_depth_in_request; } } @@ -2082,12 +2109,42 @@ std::vector> // In this case the inputSeriedTree ends with padding 0s assert(inputSerializedTree.size() >= outputSerializedTree.size()); + int *treeLayers = new int[inputSerializedTree.size()]; + int node_num = 1; + int layer_num = 0; + for (int token_id = 0; token_id < inputSerializedTree.size(); token_id++) { + if (token_id == (inputSerializedTree.size() - 1) || + inputSerializedTree.at(token_id + 1).second != + inputSerializedTree.at(token_id).second) { + treeLayers[layer_num] = node_num; + layer_num += 1; + node_num = 1; + } else { + node_num++; + } + } + + // to avoid branch switch when same tokens in input tree. + + bool findFirst = false; + layer_num = -1; + int first_layer_slot = 0; + int first_layer_slot_total = 0; + int processed_whole_layer_tokens = 0; + for (int i = 0; i < outputSerializedTree.size(); i++) { auto input = inputSerializedTree.at(i); auto output = outputSerializedTree.at(i); + if (i == 0 || inputSerializedTree.at(i - 1).second != + inputSerializedTree.at(i).second) { + layer_num += 1; + processed_whole_layer_tokens += i == 0 ? 
0 : treeLayers[layer_num - 1]; + } + if (i == 0) { verifiedTree.push_back(output); + new_committed_tokens.push_back(std::make_pair( input.second, committed_tokens.at(guid).at(i).second)); // > if (input.first == verifiedTree.back().first && input.second == verifiedTree.back().second) { - verifiedTree.push_back(output); - new_committed_tokens.push_back(std::make_pair( - input.second, - committed_tokens.at(guid).at(i).second)); // + if (findFirst) { + // must in this branch. + int layer_slot = i - processed_whole_layer_tokens; + int layer_slot_total = treeLayers[layer_num]; + if ((first_layer_slot == layer_slot)) { + verifiedTree.push_back(output); + new_committed_tokens.push_back(std::make_pair( + input.second, committed_tokens.at(guid).at(i).second)); + // at this point, you'll not go other branches + std::cout << "verify tree push back: " << output.first + << ", tree size is: " << verifiedTree.size() + << ", ??: " << input.first << ", " << input.second << "\n"; + + } else { + printf("not correct slot\n"); + } + } else { + verifiedTree.push_back(output); + first_layer_slot = i - processed_whole_layer_tokens; + first_layer_slot_total = treeLayers[layer_num]; + findFirst = true; + new_committed_tokens.push_back(std::make_pair( + input.second, + committed_tokens.at(guid).at(i).second)); // + // at this point, you'll not go other branches + std::cout << "verify tree push back: " << output.first + << ", tree size is: " << verifiedTree.size() + << ", ??: " << input.first << ", " << input.second << "\n"; + } + assert(committed_tokens.at(guid).at(i).first == input.second); } } diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu index e8824feda5..bb6b6030aa 100644 --- a/src/runtime/request_manager.cu +++ b/src/runtime/request_manager.cu @@ -57,6 +57,92 @@ void RequestManager::load_tokens_task( cudaMemcpyHostToDevice, stream)); + // // copy meta data to workSpace + // FFHandler handle = *((FFHandler const *)task->local_args); + // size_t total_copy_size = 0; + // cudaMemcpyAsync(handle.batch_config_metadata, + // &(batch_config->tokensInfo), + // sizeof(BatchConfig::tokensInfo), + // cudaMemcpyHostToDevice, + // stream); + // total_copy_size += sizeof(BatchConfig::tokensInfo); + + // cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + // total_copy_size, + // &(batch_config->requestsInfo), + // sizeof(BatchConfig::requestsInfo), + // cudaMemcpyHostToDevice, + // stream); + // total_copy_size += sizeof(BatchConfig::requestsInfo); + + // // load speculative metadata + // if (batch_config->get_mode() == BEAM_SEARCH_MODE) { + // BeamSearchBatchConfig const *beam_batch_config = + // static_cast(batch_config); + + // cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + // total_copy_size, + // &(beam_batch_config->beamTokenInfo), + // sizeof(BeamSearchBatchConfig::beamTokenInfo), + // cudaMemcpyHostToDevice, + // stream); + + // total_copy_size += sizeof(BeamSearchBatchConfig::beamTokenInfo); + + // cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + // total_copy_size, + // &(beam_batch_config->beamRequestsInfo), + // sizeof(BeamSearchBatchConfig::beamRequestsInfo), + // cudaMemcpyHostToDevice, + // stream); + // total_copy_size += sizeof(BeamSearchBatchConfig::beamRequestsInfo); + + // cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + // total_copy_size, + // &(beam_batch_config->causalMask), + // sizeof(BatchConfig::causalMask), + // cudaMemcpyHostToDevice, + // stream); + + // total_copy_size += sizeof(BatchConfig::causalMask); + // } 
else if (batch_config->get_mode() == TREE_VERIFY_MODE) { + // TreeVerifyBatchConfig const *tree_batch_config = + // static_cast(batch_config); + + // cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + // total_copy_size, + // &(tree_batch_config->causalMask), + // sizeof(BatchConfig::causalMask), + // cudaMemcpyHostToDevice, + // stream); + // total_copy_size += sizeof(BatchConfig::causalMask); + // cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + // total_copy_size, + // &(tree_batch_config->committed_tokens), + // sizeof(TreeVerifyBatchConfig::committed_tokens), + // cudaMemcpyHostToDevice, + // stream); + // total_copy_size += sizeof(TreeVerifyBatchConfig::committed_tokens); + // } + + // // add a size check + // std::cout << "handle.batch_config_metadata_size: " << handle.batch_config_metadata_size << ", "<< total_copy_size << "\n"; + // assert(total_copy_size <= handle.batch_config_metadata_size); +} + +void RequestManager::load_batch_config_task( + Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + assert(regions.size() == 0); + assert(task->regions.size() == 0); + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); + + // BatchConfig const batch_config = *((BatchConfig *)task->args); + BatchConfig const *batch_config = BatchConfig::from_future(task->futures[0]); + // copy meta data to workSpace FFHandler handle = *((FFHandler const *)task->local_args); size_t total_copy_size = 0; @@ -126,6 +212,7 @@ void RequestManager::load_tokens_task( } // add a size check + std::cout << "hahaha handle.batch_config_metadata_size: " << handle.batch_config_metadata_size << ", "<< total_copy_size << "\n"; assert(total_copy_size <= handle.batch_config_metadata_size); } From 6c442593976ebc7efa6a50087a486ee613616a74 Mon Sep 17 00:00:00 2001 From: Zhihao Jia Date: Sat, 30 Dec 2023 13:06:37 -0500 Subject: [PATCH 10/30] Replicate load_token tasks so that it can be fused with other compute tasks; this eliminates Replicate and enables a larger fused op --- include/flexflow/config.h | 1 + src/ops/embedding.cc | 18 ++++++------------ src/runtime/model.cc | 31 ++++++++++++++++++++----------- 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/include/flexflow/config.h b/include/flexflow/config.h index c2af6d707c..01f318c6d5 100644 --- a/include/flexflow/config.h +++ b/include/flexflow/config.h @@ -132,6 +132,7 @@ class FFConfig { size_t workSpaceSize; Legion::Context lg_ctx; Legion::Runtime *lg_hlr; + Legion::IndexSpaceT<1> all_gpu_task_is; // Legion::FieldSpace field_space; bool syntheticInput, profiling, perform_fusion; bool inference_debugging; diff --git a/src/ops/embedding.cc b/src/ops/embedding.cc index 007e799fe0..76236e65ff 100644 --- a/src/ops/embedding.cc +++ b/src/ops/embedding.cc @@ -155,11 +155,8 @@ int Embedding::output_size(ParallelDim output_dims[MAX_TENSOR_DIM]) { output_dims[OUT_CHANNELS].size = this->out_channels; output_dims[OUT_CHANNELS].degree = 1; output_dims[OUT_CHANNELS].parallel_idx = -1; - // Currently do not support parallelizing over the replica dim - output_dims[num_dims - 1].size = 1; - output_dims[num_dims - 1].degree = 1; - output_dims[num_dims - 1].parallel_idx = -1; - output_dims[num_dims - 1].is_replica_dim = true; + // Copy replica dim + output_dims[num_dims - 1] = input->dims[input->num_dims - 1]; return num_dims; } else { int num_dims = input->num_dims; @@ -170,11 +167,8 @@ int Embedding::output_size(ParallelDim output_dims[MAX_TENSOR_DIM]) { output_dims[OUT_CHANNELS].size = this->out_channels; 
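
Taken together, the embedding.cc and model.cc hunks of this commit implement one idea: rather than inserting an explicit Replicate operator after the embedding, the token input and the embedding output carry a replica dimension whose degree already equals the tensor-parallelism degree, so the replicated load_token work can live in one larger fused op with the downstream tensor-parallel layers. A minimal sketch of that dimension bookkeeping follows; Dim and set_replica_dim are simplified stand-ins for illustration, not FlexFlow's ParallelDim API.

    // Sketch only: assumed, simplified stand-in for FlexFlow's ParallelDim.
    struct Dim {
      int size = 1;
      int degree = 1;
      int parallel_idx = -1;
      bool is_replica_dim = false;
    };

    // Give the outermost (replica) dimension of an inference input a degree equal
    // to the tensor-parallelism degree, so every consumer already sees a tensor
    // replicated across the TP shards and no separate Replicate op is required.
    void set_replica_dim(Dim *dims, int num_dims, int tensor_parallelism_degree) {
      Dim &replica = dims[num_dims - 1];
      replica.is_replica_dim = true;
      replica.size *= tensor_parallelism_degree;   // one logical copy per shard
      replica.degree *= tensor_parallelism_degree; // partitioned across the shards
      replica.parallel_idx = 0;                    // assumed machine-view mapping
    }

    // Downstream operators then propagate this dimension instead of resetting it
    // to degree 1, e.g. output_dims[num_dims - 1] = input->dims[input->num_dims - 1];
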
output_dims[OUT_CHANNELS].degree = 1; output_dims[OUT_CHANNELS].parallel_idx = -1; - // Currently do not support parallelizing over the replica dim - output_dims[num_dims - 1].size = 1; - output_dims[num_dims - 1].degree = 1; - output_dims[num_dims - 1].parallel_idx = -1; - output_dims[num_dims - 1].is_replica_dim = true; + // Copy replica dim + output_dims[num_dims - 1] = input->dims[input->num_dims - 1]; return num_dims; } // const int REPLICA = this->output_vocab_size_replica_dim(); @@ -189,13 +183,13 @@ int Embedding::weight_size(ParallelDim weight_dims[MAX_TENSOR_DIM]) { weight_dims[Weight::VOCAB_SIZE].size = this->num_entries; weight_dims[Weight::VOCAB_SIZE].degree = 1; weight_dims[Weight::VOCAB_SIZE].parallel_idx = -1; - for (int i = 2; i < input->num_dims; i++) { + for (int i = 2; i < input->num_dims + 1; i++) { weight_dims[i].size = input->dims[i - 1].degree; weight_dims[i].degree = weight_dims[i].size; weight_dims[i].parallel_idx = input->dims[i - 1].parallel_idx; weight_dims[i].is_replica_dim = true; } - return input->num_dims; + return input->num_dims + 1; } void Embedding::register_output_mappings() { diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 92f0cff472..975045cd3b 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -1499,10 +1499,8 @@ FFRuntime::FFRuntime(FFConfig &config) { Context ctx = config.lg_ctx; ArgumentMap argmap; - Rect<1> task_rect(Point<1>(0), - Point<1>(config.workersPerNode * config.numNodes - 1)); - IndexSpaceT<1> task_is = runtime->create_index_space(ctx, task_rect); - + Domain domain = runtime->get_index_space_domain(ctx, config.all_gpu_task_is); + Rect<1> task_rect = domain; // int rank = 0; for (PointInRectIterator<1> it(task_rect); it(); it++) { FFInitInfo info; @@ -1518,7 +1516,7 @@ FFRuntime::FFRuntime(FFConfig &config) { // Init CUDA library on each worker IndexLauncher initLauncher(FF_INIT_TASK_ID, - task_is, + config.all_gpu_task_is, TaskArgument(NULL, 0), argmap, Predicate::TRUE_PRED, @@ -2993,6 +2991,12 @@ Op *FFModel::create_operator_from_layer( dims[num_dims].degree = 1; dims[num_dims].parallel_idx = -1; dims[num_dims].is_replica_dim = true; + if (config.computationMode == COMP_MODE_INFERENCE && + config.tensor_parallelism_degree > 1) { + dims[num_dims].size *= config.tensor_parallelism_degree; + dims[num_dims].degree *= config.tensor_parallelism_degree; + dims[num_dims].parallel_idx = 0; + } // create_parallel_tensor adds an NoOp into operators ParallelTensor pt = create_parallel_tensor_legion_ordering(num_dims + 1, @@ -3002,6 +3006,7 @@ Op *FFModel::create_operator_from_layer( 0, true /*gradients*/, tensor->tensor_guid); + assert(pt->get_shape().is_valid()); // assert that this tensor hasn't been mapped before assert(tensor->parallel_tensor == nullptr); tensor->parallel_tensor = pt; @@ -3260,12 +3265,12 @@ void FFModel::create_operators_from_layers() { if (config.computationMode == COMP_MODE_INFERENCE && config.tensor_parallelism_degree > 1 && l->op_type == OP_EMBEDDING) { assert(op->numOutputs == 1); - Replicate *repl = new Replicate(*this, - op->outputs[0], - op->outputs[0]->num_dims - 1, - config.tensor_parallelism_degree); - operators.push_back(repl); - op = repl; + // Replicate *repl = new Replicate(*this, + // op->outputs[0], + // op->outputs[0]->num_dims - 1, + // config.tensor_parallelism_degree); + // operators.push_back(repl); + // op = repl; } else if (config.computationMode == COMP_MODE_INFERENCE && config.tensor_parallelism_degree > 1 && (l->op_type == OP_INC_MULTIHEAD_SELF_ATTENTION || @@ -4076,6 
+4081,10 @@ FFConfig::FFConfig() { Runtime *runtime = Runtime::get_runtime(); lg_hlr = runtime; lg_ctx = Runtime::get_context(); + Rect<1> task_rect(Point<1>(0), Point<1>(workersPerNode * numNodes - 1)); + // Create an index space for tasks running on all GPUs + all_gpu_task_is = runtime->create_index_space(lg_ctx, task_rect); + // field_space = runtime->create_field_space(lg_ctx); } From ac112037a8e88193d3377684ae2821d253551c2d Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Sat, 30 Dec 2023 15:09:19 -0500 Subject: [PATCH 11/30] more fix. --- include/flexflow/batch_config.h | 3 + src/ops/inc_multihead_self_attention.cu | 13 ++-- src/ops/kernels/embedding_kernels.cu | 2 +- .../specinfer_inc_multihead_self_attention.cu | 58 ++++++++--------- src/ops/tree_inc_multihead_self_attention.cu | 42 ++++++------ src/runtime/request_manager.cc | 65 ++++++++++--------- 6 files changed, 98 insertions(+), 85 deletions(-) diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index c3a75e59a4..8065e0f038 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -69,6 +69,9 @@ class BatchConfig { int first_token_offset_in_batch; int num_tokens_in_batch; int max_sequence_length; + + //request id in batch config: + int batch_config_request_id; RequestGuid request_guid; }; struct PerTokenInfo { diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 2f16dd71c2..3b3879e8e5 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -82,6 +82,9 @@ __global__ void compute_attention_kernel_generation_kernel( // request idx int const request_idx = blockIdx.y; + int const batch_config_request_id = + request_infos[request_idx].batch_config_request_id; + int const beam_request_idx = is_beam ? request_idx / max_beam_width : request_idx; int const beam_sub_request_idx = is_beam ? request_idx % max_beam_width : 0; @@ -89,8 +92,8 @@ __global__ void compute_attention_kernel_generation_kernel( int const first_step = 0; int const tlength = - request_infos[beam_request_idx].first_token_depth_in_request + - request_infos[beam_request_idx].num_tokens_in_batch; + request_infos[batch_config_request_id].first_token_depth_in_request + + request_infos[batch_config_request_id].num_tokens_in_batch; // shared memory objects extern __shared__ char smem_[]; @@ -103,7 +106,7 @@ __global__ void compute_attention_kernel_generation_kernel( // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - const DT *q_ptr = query + beam_request_idx * hidden_size * QKV_WEIGHT_NUM + + const DT *q_ptr = query + batch_config_request_id * hidden_size * QKV_WEIGHT_NUM + head_idx * per_head_size; __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD]; // DT const *q_ptr = @@ -139,7 +142,7 @@ __global__ void compute_attention_kernel_generation_kernel( DT const *k_cache_batch = key_cache + - (beam_request_idx * max_beam_width + beam_sub_request_idx) * + (batch_config_request_id * max_beam_width + beam_sub_request_idx) * max_seq_length * hidden_size + ki; @@ -245,7 +248,7 @@ __global__ void compute_attention_kernel_generation_kernel( // The base pointer for the value in the cache buffer. 
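
The batch_config_request_id field introduced in this commit decouples a kernel's dense request index (blockIdx.y enumerates only active requests) from that request's slot in the BatchConfig arrays, which may have gaps. A small CUDA sketch of the lookup follows; per_request_kernel and the trimmed PerRequestInfoLite struct are illustrative assumptions, the real PerRequestInfo carries more fields.

    #include <cstdio>

    // Trimmed, assumed stand-in for BatchConfig::PerRequestInfo.
    struct PerRequestInfoLite {
      int first_token_depth_in_request;
      int num_tokens_in_batch;
      int batch_config_request_id; // slot of this request in the batch config
    };

    __global__ void per_request_kernel(PerRequestInfoLite const *request_infos) {
      // blockIdx.y runs densely over active requests: 0 .. num_active_requests - 1.
      int const request_idx = blockIdx.y;
      // Resolve the dense index to its (possibly sparse) batch-config slot.
      int const slot = request_infos[request_idx].batch_config_request_id;
      // All per-request metadata (and the KV-cache base offset) is read via slot.
      int const tlength = request_infos[slot].first_token_depth_in_request +
                          request_infos[slot].num_tokens_in_batch;
      if (threadIdx.x == 0 && blockIdx.x == 0) {
        printf("request %d -> slot %d, tlength %d\n", request_idx, slot, tlength);
      }
    }
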
DT const *v_cache_batch = value_cache + - (beam_request_idx * max_beam_width + beam_sub_request_idx) * + (batch_config_request_id * max_beam_width + beam_sub_request_idx) * max_seq_length * hidden_size + vi; diff --git a/src/ops/kernels/embedding_kernels.cu b/src/ops/kernels/embedding_kernels.cu index 3085fdb6ba..6947be432e 100644 --- a/src/ops/kernels/embedding_kernels.cu +++ b/src/ops/kernels/embedding_kernels.cu @@ -118,7 +118,7 @@ void forward_kernel_wrapper(EmbeddingMeta const *m, // print_tensor(output_ptr, output_domain.get_volume(), // "[Embedding:forward:output]"); } - print_tensor(input.get_int32_ptr(), 200, "embeddinginput"); + // print_tensor(input.get_int32_ptr(), 200, "embeddinginput"); } /*static*/ diff --git a/src/ops/specinfer_inc_multihead_self_attention.cu b/src/ops/specinfer_inc_multihead_self_attention.cu index e84ec3095c..8340519ff3 100644 --- a/src/ops/specinfer_inc_multihead_self_attention.cu +++ b/src/ops/specinfer_inc_multihead_self_attention.cu @@ -69,36 +69,43 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( int const tidx = threadIdx.x; // head id int const head_idx = blockIdx.x; - // request idx + // nth request idx int const request_idx = blockIdx.y; - BatchConfig::BitMask bitmask = causalMask[request_idx]; + // request id in batch config + int const batch_config_request_id = + request_infos[request_idx].batch_config_request_id; + + // request_idx = re + + BatchConfig::BitMask bitmask = causalMask[batch_config_request_id]; int const first_step = 0; - int const tlength = request_infos[request_idx].first_token_depth_in_request + - request_infos[request_idx].num_tokens_in_batch; + int const tlength = + request_infos[batch_config_request_id].first_token_depth_in_request + + request_infos[batch_config_request_id].num_tokens_in_batch; - // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { - // printf("specinfer attn fused kernel %lld\n", bitmask.mask[1]); - // } - + if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + printf("specinfer attn fused kernel!!!\n"); + } int const totalCacheSize = bitmask.non_tree_cache_size + bitmask.tree_size; - // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { - // printf("specinfer attn fused kernel %d, %d\n", - // totalCacheSize,request_infos[request_idx].num_tokens_in_batch); - // } + if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + printf("specinfer attn fused kernel %d, %d\n", + totalCacheSize, + request_infos[batch_config_request_id].num_tokens_in_batch); + } // int const qlength = request_infos[request_idx].num_tokens_in_batch; - int const tree_branch_num = beam_request_infos[request_idx].sub_request_num; + int const tree_branch_num = + beam_request_infos[batch_config_request_id].sub_request_num; // will decode qlength tokens in this thread block // int const qlength = tree_branch_num; int first_token_idx = 0; for (int r = 0; r < request_idx; r++) { - // first_token_idx += request_infos[request_idx].num_tokens_in_batch; first_token_idx += causalMask[r].this_layer_size; } @@ -135,7 +142,7 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; DT const *k_cache_batch = - key_cache + request_idx * max_seq_length * hidden_size + ki; + key_cache + batch_config_request_id * max_seq_length * hidden_size + ki; int ti_end = div_up(totalCacheSize - first_step, K_PER_WARP) * K_PER_WARP + first_step; @@ -166,10 +173,6 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( } float qk = scale * 
Qk_dot::dot(q_vecs[ki_o], k); - // if (blockIdx.y == 0 && blockIdx.x == 0) { - // printf("spec inc attn qkqkqk %d, %.10f, %d\n", ti, qk, sub_req_idx); - // } - if (ti < totalCacheSize && tidx % THREADS_PER_KEY == 0) { // todo add alobi here // bool const mask = ti_circ >= totalCacheSize; @@ -177,14 +180,8 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << query_token)))); - // if (blockIdx.y == 0 && blockIdx.x == 0 && sub_req_idx == 0) { - // printf("specinfer mask: ti:%d, %d, %d, %d, %lld\n", - // ti, - // totalCacheSize, - // bitmask.non_tree_cache_size, - // query_token, - // bitmask.mask[ti - bitmask.non_tree_cache_size]); - // // assert(false); + // if (blockIdx.y == 0 && blockIdx.x == 0 && !mask) { + // printf("spec inc attn qkqkqk %d, %.10f, %d\n", ti, qk, qi); // } qk_max = mask ? qk_max : fmaxf(qk_max, qk); qk_smem[ti - first_step] = mask ? 0.f : qk; @@ -271,7 +268,8 @@ __global__ void compute_specinfer_attention_kernel_generation_kernel( // The base pointer for the value in the cache buffer. DT const *v_cache_batch = - value_cache + request_idx * max_seq_length * hidden_size + vi; + value_cache + batch_config_request_id * max_seq_length * hidden_size + + vi; if (Dh == Dh_MAX || vi < Dh) { for (int ti = first_step + vo; ti < totalCacheSize; ti += V_PER_ITER) { @@ -461,6 +459,7 @@ void compute_specinfer_attention_kernel_generation( DT *output_ptr, cudaStream_t stream) { // one block == one head per request + printf("??? at here: %d\n", bc->num_active_requests()); dim3 grid(m->num_q_heads, bc->num_active_requests()); int const per_head_size = m->qProjSize; float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f; @@ -761,13 +760,14 @@ void inference_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, // std::cout << "specinfer kernel token num: " << bc->num_generation_tokens // << ", " << bc->num_tokens << "\n"; if (bc->num_generation_tokens > 0) { + printf("spec inc generation decoding\n"); compute_specinfer_attention_kernel_generation
<DT>( m, bc, static_cast<DT *>
(m->attn_heads), stream); } // phase 3: Compute attention score // 3 kernels for pahse 3: matmul1 - softmax - matmal2 if (bc->num_tokens > bc->num_generation_tokens) { - // printf("spec inc prompt decoding\n"); + printf("spec inc prompt decoding\n"); compute_attention_kernel_prompt( m, bc, shard_id, output_ptr, bias_ptr, weight_ptr, stream); } diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index 8641e63e38..a4329f52db 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -76,13 +76,16 @@ __global__ void compute_attention_kernel_fused_kernel( // request idx int const request_idx = blockIdx.y; + int const batch_config_request_id = + request_infos[request_idx].batch_config_request_id; + int const first_step = 0; - int const tlength = request_infos[request_idx].first_token_depth_in_request + - request_infos[request_idx].num_tokens_in_batch; - int const qlength = request_infos[request_idx].num_tokens_in_batch; + int const tlength = request_infos[batch_config_request_id].first_token_depth_in_request + + request_infos[batch_config_request_id].num_tokens_in_batch; + int const qlength = request_infos[batch_config_request_id].num_tokens_in_batch; - BatchConfig::BitMask bitmask = causalMask[request_idx]; + BatchConfig::BitMask bitmask = causalMask[batch_config_request_id]; // bitmask.mask[1] = 3; // if (head_idx == 0 && tidx == 0) { @@ -132,7 +135,7 @@ __global__ void compute_attention_kernel_fused_kernel( constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; DT const *k_cache_batch = - key_cache + request_idx * max_seq_length * hidden_size + ki; + key_cache + batch_config_request_id * max_seq_length * hidden_size + ki; int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; @@ -189,14 +192,14 @@ __global__ void compute_attention_kernel_fused_kernel( qk_max = mask ? qk_max : fmaxf(qk_max, qk); - if (head_idx == 0 && qi == 0 && !mask) { - printf("tree attn qkqkqkqk request id %d, %d %.10f, %.10f, %.10f\n ", - request_idx, - ti, - qk, - q_vecs[ki_o][0].x, - k[0].x); - } + // if (head_idx == 0 && qi == 0 && !mask) { + // printf("tree attn qkqkqkqk request id %d, %d %.10f, %.10f, %.10f\n ", + // request_idx, + // ti, + // qk, + // q_vecs[ki_o][0].x, + // k[0].x); + // } qk_smem[ti - first_step] = mask ? 0.0f : qk; } } @@ -279,7 +282,7 @@ __global__ void compute_attention_kernel_fused_kernel( // The base pointer for the value in the cache buffer. 
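
The mask test above encodes the speculative tree's causal structure: positions below non_tree_cache_size form the committed prefix and are visible to every query, while a tree position is visible to a query only if that query's bit is set in the position's mask word. A hedged sketch of the same test follows; attention_masked is an illustrative helper, and the layout of one mask word per tree token with bit q meaning "visible to query q" is inferred from the kernel code.

    // Illustrative device helper mirroring the mask test in the kernel above
    // (assumed layout: mask[t] has bit q set iff tree token t may be attended
    // to by query token q of the same request).
    __device__ __forceinline__ bool
        attention_masked(unsigned long long const *mask,
                         int ti,                  // key/value position
                         int total_cache_size,    // committed prefix + tree size
                         int non_tree_cache_size, // committed prefix length
                         int query_token) {       // query's index inside the tree
      if (ti >= total_cache_size) {
        return true; // out of range, contributes nothing
      }
      if (ti < non_tree_cache_size) {
        return false; // committed prefix is visible to every query
      }
      unsigned long long const word = mask[ti - non_tree_cache_size];
      return (word & (1ULL << query_token)) == 0; // bit cleared: not on this query's path
    }
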
DT const *v_cache_batch = - value_cache + request_idx * max_seq_length * hidden_size + vi; + value_cache + batch_config_request_id * max_seq_length * hidden_size + vi; // DT const *v_cache_batch = // value_cache + // (beam_request_idx * max_beam_width + beam_sub_request_idx) * @@ -481,8 +484,7 @@ __global__ void update_tree_branch_kv_cache_fused( int vProjSize, int num_new_tokens, int max_seq_len, - int hidden_size, - int first_token_depth) { + int hidden_size) { CUDA_KERNEL_LOOP(i, num_new_tokens * hidden_size) { int token_idx = i / hidden_size; @@ -498,10 +500,11 @@ __global__ void update_tree_branch_kv_cache_fused( int const request_token_offset = request_infos[req_id].first_token_offset_in_batch; + int const first_token_depth = request_infos[req_id].first_token_depth_in_request; // if(i % hidden_size == 0){ - // printf("update token request id: %d, %d, %d value%.10f\n", req_id, - // token_idx, request_token_offset, kVal); + // printf("update token request id: %d, %d, %d real id %d, value%.10f\n", req_id, + // token_idx, request_token_offset,(token_idx + first_token_depth - request_token_offset), kVal); // } kCache_ptr[req_id * (hidden_size * max_seq_len) + (token_idx + first_token_depth - request_token_offset) * @@ -890,8 +893,7 @@ void compute_attention_kernel_fused(TreeIncMultiHeadSelfAttentionMeta const *m, m->vProjSize, num_new_tokens, BatchConfig::max_sequence_length() + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, - m->hidden_size, - bc->requestsInfo[0].first_token_depth_in_request); + m->hidden_size); dim3 grid(m->num_q_heads, bc->num_active_requests()); int const per_head_size = m->qProjSize; diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 5c3262eb27..e30a7ee478 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -364,6 +364,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, } } int num_generation_tokens = 0; + int num_active_req = -1; // Step 2: prepare the next batch for existing requests BatchConfig new_bc; @@ -454,6 +455,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, old_bc.requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = old_bc.requestsInfo[i].max_sequence_length; + num_active_req++; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; if (new_bc.requestsInfo[i].first_token_depth_in_request + 1 == request.tokens.size()) { // Incremental phase @@ -490,6 +493,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, Request new_request = pending_request_queue.front(); pending_request_queue.pop(); // all_requests[new_request.guid] = new_request; + new_bc.requestsInfo[i].first_token_depth_in_request = 0; new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens; new_bc.requestsInfo[i].request_guid = new_request.guid; @@ -499,6 +503,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, new_bc.requestsInfo[i].max_sequence_length = new_request.max_sequence_length; new_bc.request_completed[i] = false; + num_active_req++; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // add profile_info for the new request ProfileInfo profile_info; profile_info.llm_decoding_steps = 1; @@ -574,6 +580,7 @@ BeamSearchBatchConfig int result_index = 0; int num_generation_tokens = 0; + int num_active_req = -1; for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) { if (old_bc.request_completed[i]) { @@ -596,10 +603,11 @@ BeamSearchBatchConfig } else { 
committed_tokens[guid].clear(); } + // iterate through all the tokens that belong to request i int root_abs_depth = request.tokens.size() - 1; - + while (result_index < old_bc.num_tokens && old_bc.tokensInfo[result_index].request_index == i) { int abs_depth = old_bc.tokensInfo[result_index].abs_depth_in_request; @@ -639,7 +647,7 @@ BeamSearchBatchConfig traverse_verify_tree(guid, dfs_tree_inputs.at(guid), tree_outputs); log_req_mgr.print("Number of Verified Tokens = %zu", - verified_tokens.size()); + verified_tokens.size()); // check if the request is finished if (verified_tokens.size() + request.tokens.size() >= request.max_sequence_length) { @@ -723,8 +731,10 @@ BeamSearchBatchConfig std::cout << "parse to next iteration: " << "\n"; + new_bc.request_completed[i] = false; new_bc.request_running[i] = true; + num_active_req++; // Normal Request Info new_bc.requestsInfo[i].first_token_depth_in_request = @@ -735,6 +745,7 @@ BeamSearchBatchConfig new_bc.requestsInfo[i].max_sequence_length = old_bc.requestsInfo[i].max_sequence_length; new_bc.requestsInfo[i].num_tokens_in_batch = verified_tokens.size(); + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // TODO: Beam Request Info, missing from VerifyTreeBatchConfig int new_max_depth = @@ -805,14 +816,15 @@ BeamSearchBatchConfig log_req_mgr.print("Output: %s", output.c_str()); } - if (request.tokens.size() > 19 && i >= 7) { - std::cout << request.tokens.size() << "\n"; - assert(false); - } + // if (request.tokens.size() > 19 && i >= 7) { + // std::cout << request.tokens.size() << "\n"; + // assert(false); + // } } else if (request.status == Request::PENDING) { new_bc.request_completed[i] = false; new_bc.request_running[i] = false; + num_active_req++; std::cout << "ssm_cache_size: " << request.ssm_cache_size << ", " << "initial_len: " << request.initial_len << std::endl; @@ -826,6 +838,7 @@ BeamSearchBatchConfig new_bc.requestsInfo[i].max_sequence_length = old_bc.requestsInfo[i].max_sequence_length; new_bc.requestsInfo[i].num_tokens_in_batch = 0; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // TODO: Beam Request Info, missing from VerifyTreeBatchConfig new_bc.beamRequestsInfo[i].current_depth = 1; @@ -867,6 +880,7 @@ BeamSearchBatchConfig Request new_request = pending_request_queue.front(); pending_request_queue.pop(); // all_requests[new_request.guid] = new_request; + num_active_req++; new_bc.requestsInfo[i].first_token_depth_in_request = 0; new_bc.requestsInfo[i].first_token_offset_in_batch = new_bc.num_tokens; new_bc.requestsInfo[i].request_guid = new_request.guid; @@ -875,6 +889,7 @@ BeamSearchBatchConfig (int)new_request.tokens.size()); new_bc.requestsInfo[i].max_sequence_length = new_request.max_sequence_length; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // add profile_info for the new request ProfileInfo profile_info; @@ -967,6 +982,8 @@ BeamSearchBatchConfig old_bc.print(); new_bc.print(); } + std::cout << "prepare next batch init active tokens: " + << new_bc.num_tokens << "\n"; return new_bc; } @@ -1027,10 +1044,12 @@ BeamSearchBatchConfig int num_generation_tokens = 0; // Add incremental tokens to the batch + int num_active_req = -1; for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) { if (old_bc.request_completed[i] || !old_bc.request_running[i]) { continue; } + num_active_req ++; // Comment out this assertion since num_tokens_in_batch can be // zero when beam search has reached required sequence length // assert(old_bc.requestsInfo[i].num_tokens_in_batch > 0); 
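
The request_manager.cc hunks in this commit repeat one bookkeeping pattern: while scanning the sparse request slots, a dense counter num_active_req advances for every live request and the slot index is written to requestsInfo[num_active_req].batch_config_request_id, which is exactly what the attention kernels read back on the device. A condensed host-side sketch follows; BatchLite and record_active_slots are simplified stand-ins, not the full prepare_next_batch logic.

    #include <array>

    // Assumed, simplified stand-ins; the real BatchConfig carries far more state.
    constexpr int kMaxRequests = 8;
    struct RequestInfoLite {
      int batch_config_request_id = -1;
    };
    struct BatchLite {
      std::array<bool, kMaxRequests> request_completed{};
      std::array<RequestInfoLite, kMaxRequests> requestsInfo{};
    };

    // Record, for each live slot i, which dense position it occupies in the batch.
    int record_active_slots(BatchLite &new_bc) {
      int num_active_req = -1;
      for (int i = 0; i < kMaxRequests; i++) {
        if (new_bc.request_completed[i]) {
          continue; // skip finished or empty slots
        }
        num_active_req++; // dense index of this active request
        new_bc.requestsInfo[num_active_req].batch_config_request_id = i;
      }
      return num_active_req + 1; // number of active requests in the batch
    }
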
@@ -1040,29 +1059,6 @@ BeamSearchBatchConfig // assert(processed_tokens < request.tokens.size()); log_req_mgr.debug() << "processed_tokens: " << processed_tokens << "\n"; - // if (processed_tokens > - // old_bc.beamRequestsInfo[i].max_depth + request.tokens.size() && - // request.status == Request::RUNNING - // // || ir.results[t] == 0 TODO: replace this with - // ) { - // // log_req_mgr.print("[Done] guid(%zu) with spec_tree_depth(%d)", - // // old_bc.requestsInfo[i].request_guid, - // // old_bc.beamRequestsInfo[i].max_depth); - // // // new_bc.request_completed[i] = true; - // // new_bc.request_completed[i] = false; - // // new_bc.requestsInfo[i].first_token_depth_in_request = - // processed_tokens; - // // new_bc.requestsInfo[i].request_guid = - // // old_bc.requestsInfo[i].request_guid; - // // new_bc.requestsInfo[i].max_sequence_length = - // // old_bc.requestsInfo[i].max_sequence_length; - // // new_bc.beamRequestsInfo[i].current_depth = - // // old_bc.beamRequestsInfo[i].current_depth; - // // new_bc.request_running[i] = false; - // std::cout << "beam search end:" << request.status << i << ", " - // << new_bc.requestsInfo[i].num_tokens_in_batch << "\n"; - // } - // else { log_req_mgr.debug() << "num tokens: " << old_bc.num_tokens << ", " << new_bc.num_tokens; @@ -1073,6 +1069,7 @@ BeamSearchBatchConfig new_bc.requestsInfo[i].max_sequence_length = old_bc.requestsInfo[i].max_sequence_length; profiling_requests[request.guid].ssm_decoding_steps += 1; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // update the beam search metadata // how many sub request in current request // why is sub_requests has max_requests_per_batch() * MAX_BEAM_WIDTH @@ -1164,6 +1161,7 @@ BeamSearchBatchConfig // std::cout << "nodes: " << tree.treeLayers[k].nodes_num_this_layer // << "\n"; // } + std::cout << "append bit mask: "<< i << "\n"; appendBitMask(new_bc.causalMask[i], new_bc.beamRequestsInfo[i].sub_request_num, old_bc.beamRequestsInfo[i].beam_size, @@ -1198,6 +1196,7 @@ BeamSearchBatchConfig if (old_bc.request_completed[i] || old_bc.request_running[i]) { continue; } + num_active_req++; // Comment out this assertion since num_tokens_in_batch can be // zero when beam search has reached required sequence length // assert(old_bc.requestsInfo[i].num_tokens_in_batch > 0); @@ -1217,6 +1216,7 @@ BeamSearchBatchConfig new_bc.requestsInfo[i].request_guid = old_bc.requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = old_bc.requestsInfo[i].max_sequence_length; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // update the beam search metadata // how many sub request in current request @@ -1330,6 +1330,8 @@ BeamSearchBatchConfig // std::cout << "Current Beam DepthBBB: " // << old_bc.beamRequestsInfo[0].current_depth << "\n"; } + std::cout << "prepare next batch beam total tokens: " << new_bc.num_tokens + << "gneration tokens: " << new_bc.num_generation_tokens << "\n"; return new_bc; } @@ -1384,11 +1386,12 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( max_prompt_load_size -= 1; } } - + int num_active_req = -1; for (int i = 0; i < TreeVerifyBatchConfig::max_requests_per_batch(); i++) { if (old_batches.at(0).request_completed[i]) { continue; } + num_active_req++; size_t guid = old_batches.at(0).requestsInfo[i].request_guid; Request &request = all_requests[guid]; @@ -1432,6 +1435,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( old_batches.at(0).requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = 
old_batches.at(0).requestsInfo[i].max_sequence_length; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // copy bitmask to verify batchconfig memcpy(&(new_bc.causalMask[i]), @@ -1590,6 +1594,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( old_batches.at(0).requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = old_batches.at(0).requestsInfo[i].max_sequence_length; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; new_bc.request_completed[i] = false; From 7eaffbc480b05d674bbf465c903b2277f6240e0b Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Sat, 30 Dec 2023 17:24:08 -0500 Subject: [PATCH 12/30] clean up --- include/flexflow/batch_config.h | 2 +- include/flexflow/ffconst.h | 1 - include/flexflow/model.h | 45 - include/flexflow/operator_params.h | 2 - .../ops/spec_inc_multihead_self_attention.h | 1 + .../specinfer_inc_multihead_self_attention.h | 150 --- ...nfer_inc_multihead_self_attention_params.h | 33 - include/flexflow/request_manager.h | 2 + inference/file_loader.cc | 3 +- inference/models/llama.cc | 2 +- src/ops/argmax.cc | 1 - src/ops/beam_topk.cc | 7 +- src/ops/beam_topk.cu | 39 +- src/ops/inc_multihead_self_attention.cu | 3 +- src/ops/kernels/embedding_kernels.cu | 1 - src/ops/spec_inc_multihead_self_attention.cc | 12 +- src/ops/spec_inc_multihead_self_attention.cu | 1011 +++++++++++------ .../specinfer_inc_multihead_self_attention.cc | 883 -------------- .../specinfer_inc_multihead_self_attention.cu | 958 ---------------- .../tree attn kernel, 0----> -0.029753357172 | 1 - src/ops/tree_inc_multihead_self_attention.cu | 122 +- src/runtime/ffconst_utils.cc | 2 - src/runtime/graph.cc | 71 +- src/runtime/inference_manager.cc | 8 +- src/runtime/model.cc | 149 +-- src/runtime/model.cpp | 4 +- src/runtime/model.cu | 5 +- src/runtime/request_manager.cc | 288 ++--- src/runtime/request_manager.cu | 1 - 29 files changed, 835 insertions(+), 2972 deletions(-) delete mode 100644 include/flexflow/ops/specinfer_inc_multihead_self_attention.h delete mode 100644 include/flexflow/ops/specinfer_inc_multihead_self_attention_params.h delete mode 100644 src/ops/specinfer_inc_multihead_self_attention.cc delete mode 100644 src/ops/specinfer_inc_multihead_self_attention.cu delete mode 100644 src/ops/tree attn kernel, 0----> -0.029753357172 diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index 8065e0f038..13904aaa46 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -70,7 +70,7 @@ class BatchConfig { int num_tokens_in_batch; int max_sequence_length; - //request id in batch config: + // request id in batch config: int batch_config_request_id; RequestGuid request_guid; }; diff --git a/include/flexflow/ffconst.h b/include/flexflow/ffconst.h index ef0003b08e..512645e624 100644 --- a/include/flexflow/ffconst.h +++ b/include/flexflow/ffconst.h @@ -171,7 +171,6 @@ enum OperatorType { OP_INC_MULTIHEAD_SELF_ATTENTION, OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION, OP_TREE_INC_MULTIHEAD_SELF_ATTENTION, - OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION, OP_SAMPLING, // Parallel Ops OP_REPARTITION, diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 9cdbec64a9..16df99ab1a 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -172,8 +172,6 @@ enum TaskIDs { SPEC_INC_MULTIHEAD_SELF_ATTENTION_INF_TASK_ID, TREE_INC_MULTIHEAD_SELF_ATTENTION_INIT_TASK_ID, TREE_INC_MULTIHEAD_SELF_ATTENTION_INF_TASK_ID, - SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_INIT_TASK_ID, - 
SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_INF_TASK_ID, MSELOSS_BWD_TASK_ID, FUSEDOP_INIT_TASK_ID, FUSEDOP_FWD_TASK_ID, @@ -327,7 +325,6 @@ class Linear; class MultiHeadAttention; class IncMultiHeadSelfAttention; class TreeIncMultiHeadSelfAttention; -class SpecInferIncMultiHeadSelfAttention; class Pool2D; class Reduce; class Reshape; @@ -747,25 +744,6 @@ class FFModel { bool qk_prod_scaling = true, bool position_bias = false, char const *name = NULL); - -Tensor specinfer_inc_multihead_self_attention( - const Tensor input, - int embed_dim, - int num_heads, - int kdim = 0, - int vdim = 0, - float dropout = 0.0f, - bool bias = false, - bool add_bias_kv = false, - bool add_zero_attn = false, - DataType data_type = DT_NONE, - Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, - bool scaling_query = false, - float scaling_factor = 1.0f, - bool qk_prod_scaling = true, - bool position_bias = false, - char const *name = NULL); Tensor inc_multiquery_self_attention(const Tensor input, int embed_dim, int num_q_heads, @@ -822,26 +800,6 @@ Tensor specinfer_inc_multihead_self_attention( bool qk_prod_scaling = true, bool position_bias = false, char const *name = NULL); - - Tensor specinfer_inc_multiquery_self_attention( - const Tensor input, - int embed_dim, - int num_q_heads, - int num_kv_heads, - int kdim = 0, - int vdim = 0, - float dropout = 0.0f, - bool bias = false, - bool add_bias_kv = false, - bool add_zero_attn = false, - DataType data_type = DT_NONE, - Initializer *kernel_initializer = NULL, - bool apply_rotary_embedding = false, - bool scaling_query = false, - float scaling_factor = 1.0f, - bool qk_prod_scaling = true, - bool position_bias = false, - char const *name = NULL); // ======================================== // Inference APIs // ======================================== @@ -1243,9 +1201,6 @@ Tensor specinfer_inc_multihead_self_attention( std::unordered_map< std::pair, TreeIncMultiHeadSelfAttention *>, - std::unordered_map< - std::pair, - SpecInferIncMultiHeadSelfAttention *>, std::unordered_map, Reduce *>, std::unordered_map, diff --git a/include/flexflow/operator_params.h b/include/flexflow/operator_params.h index cee2ae95a4..5b187839ef 100644 --- a/include/flexflow/operator_params.h +++ b/include/flexflow/operator_params.h @@ -37,7 +37,6 @@ #include "flexflow/ops/topk_params.h" #include "flexflow/ops/transpose_params.h" #include "flexflow/ops/tree_inc_multihead_self_attention_params.h" -#include "flexflow/ops/specinfer_inc_multihead_self_attention_params.h" #include "flexflow/parallel_ops/allreduce_params.h" #include "flexflow/parallel_ops/combine_params.h" #include "flexflow/parallel_ops/fused_parallel_op_params.h" @@ -73,7 +72,6 @@ using OperatorParameters = mp::variant -#include - -namespace FlexFlow { - -class SpecInferIncMultiHeadSelfAttentionMeta; - -class SpecInferIncMultiHeadSelfAttention : public Op { -public: - using Params = SpecInferIncMultiHeadSelfAttentionParams; - using Input = ParallelTensor; - - SpecInferIncMultiHeadSelfAttention(FFModel &model, - LayerID const &layer_guid, - const ParallelTensor _input, - int _embed_dim, - int _num_q_heads, - int _num_kv_heads, - int _kdim, - int _vdim, - float _dropout, - bool _qkv_bias, - bool _final_bias, - bool _add_zero_attn, - bool _apply_rotary_embedding, - bool _scaling_query, - float _scaling_factor, - bool _qk_prod_scaling, - bool _position_bias, - bool allocate_weights, - char const *name); - SpecInferIncMultiHeadSelfAttention(FFModel &model, - const ParallelTensor _input, - const ParallelTensor 
_weight, - int _embed_dim, - int _num_q_heads, - int _num_kv_heads, - int _kdim, - int _vdim, - float _dropout, - bool _qkv_bias, - bool _final_bias, - bool _add_zero_attn, - bool _apply_rotary_embedding, - bool _scaling_query, - float _scaling_factor, - bool _qk_prod_scaling, - bool _position_bias, - bool allocate_weights, - char const *name); - SpecInferIncMultiHeadSelfAttention(FFModel &model, - SpecInferIncMultiHeadSelfAttention const &other, - const ParallelTensor input, - bool allocate_weights); - SpecInferIncMultiHeadSelfAttention(FFModel &model, - Params const ¶ms, - Input const &inputs, - bool allocate_weights = false, - char const *name = nullptr); - static Op * - create_operator_from_layer(FFModel &model, - Layer const *layer, - std::vector const &inputs); - void init(FFModel const &) override; - void init_inference(FFModel const &, - std::vector const &, - std::vector const &, - MachineView const *mv = nullptr) override; - void forward(FFModel const &) override; - void backward(FFModel const &) override; - Legion::FutureMap inference(FFModel const &, - BatchConfigFuture const &, - std::vector const &, - std::vector const &, - MachineView const *mv = nullptr) override; - void print_layer(FFModel const &model) override { - assert(0); - } - bool get_int_parameter(PMParameter, int *) const override; - - static OpMeta *init_task(Legion::Task const *task, - std::vector const ®ions, - Legion::Context ctx, - Legion::Runtime *runtime); - static void inference_task(Legion::Task const *task, - std::vector const ®ions, - Legion::Context ctx, - Legion::Runtime *runtime); - Op *materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const override; - bool measure_operator_cost(Simulator *sim, - MachineView const &mv, - CostMetrics &cost_metrics) const override; - - static void - inference_kernel_wrapper(SpecInferIncMultiHeadSelfAttentionMeta const *m, - BeamSearchBatchConfig const *bc, - int shard_id, - GenericTensorAccessorR const &input, - GenericTensorAccessorR const &weight, - GenericTensorAccessorW const &output, - GenericTensorAccessorR const &bias); - Params get_params() const; - -public: - int num_q_heads, num_kv_heads, tensor_parallelism_degree; - float dropout, scaling_factor; - bool qkv_bias; - bool final_bias, add_zero_attn, apply_rotary_embedding, scaling_query, - qk_prod_scaling, position_bias; - int qSize, kSize, vSize, qProjSize, kProjSize, vProjSize, oProjSize; - int qoSeqLength, kvSeqLength; -}; - -class SpecInferIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta { -public: - SpecInferIncMultiHeadSelfAttentionMeta(FFHandler handler, - SpecInferIncMultiHeadSelfAttention const *attn, - GenericTensorAccessorR const &weight, - MemoryAllocator &gpu_mem_allocator, - int num_samples, - int _num_q_heads, - int _num_kv_heads); - ~SpecInferIncMultiHeadSelfAttentionMeta(void); - -public: - Realm::RegionInstance beam_search_reserve_inst; - BeamSearchBatchConfig::BeamSearchPerTokenInfo *beam_token_infos; - BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos; - BatchConfig::BitMask *causalMask; -}; - -}; // namespace FlexFlow - -#endif // _FLEXFLOW_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_H diff --git a/include/flexflow/ops/specinfer_inc_multihead_self_attention_params.h b/include/flexflow/ops/specinfer_inc_multihead_self_attention_params.h deleted file mode 100644 index b57b06a7f7..0000000000 --- a/include/flexflow/ops/specinfer_inc_multihead_self_attention_params.h +++ /dev/null @@ -1,33 +0,0 @@ -#ifndef 
_FLEXFLOW_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_PARAMS_H -#define _FLEXFLOW_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_PARAMS_H - -#include "flexflow/ffconst.h" -#include "flexflow/fftype.h" -#include "flexflow/parallel_tensor.h" - -namespace FlexFlow { - -struct SpecInferIncMultiHeadSelfAttentionParams { - LayerID layer_guid; - int embed_dim, num_q_heads, num_kv_heads, kdim, vdim; - float dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, position_bias; - - bool is_valid(ParallelTensorShape const &) const; -}; - -bool operator==(SpecInferIncMultiHeadSelfAttentionParams const &, - SpecInferIncMultiHeadSelfAttentionParams const &); - -} // namespace FlexFlow - -namespace std { -template <> -struct hash { - size_t - operator()(FlexFlow::SpecInferIncMultiHeadSelfAttentionParams const &) const; -}; -} // namespace std - -#endif // _FLEXFLOW_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_PARAMS_H diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index 8cb45e55b4..1c4b0b2a2f 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -231,6 +231,8 @@ class RequestManager { int max_requests_per_batch; int max_tokens_per_batch; int max_sequence_length; + + // tree width in each speculative step, if not specified 1 std::vector spec_infer_tree_width; // private fields std::unique_ptr tokenizer_; diff --git a/inference/file_loader.cc b/inference/file_loader.cc index 3f70ddf488..7c6870d439 100644 --- a/inference/file_loader.cc +++ b/inference/file_loader.cc @@ -726,8 +726,7 @@ void FileDataLoader::load_single_weight_tensor(FFModel *ff, if (l->op_type == OP_INC_MULTIHEAD_SELF_ATTENTION || l->op_type == OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION || - l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION || - l->op_type == OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION) { + l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION) { if (weight_filename.find("self_attention") != std::string::npos) { load_attention_weights_multi_query( data, weight_filename, weights_folder, hidden_dim, num_heads); diff --git a/inference/models/llama.cc b/inference/models/llama.cc index 4f76e9e0fa..10001ee916 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -90,7 +90,7 @@ void LLAMA::create_llama_model(FFModel &ff, Tensor mha; switch (mode) { case BEAM_SEARCH_MODE: { - mha = ff.specinfer_inc_multihead_self_attention( + mha = ff.spec_inc_multihead_self_attention( att_norm, llama_config.hidden_size, llama_config.num_attention_heads, diff --git a/src/ops/argmax.cc b/src/ops/argmax.cc index d195a5af75..c3bb3d493e 100644 --- a/src/ops/argmax.cc +++ b/src/ops/argmax.cc @@ -399,7 +399,6 @@ InferenceResult m, shard_id, bc, {}, {}, {input, indices}); } - // print_tensor(indices.get_int32_ptr(), 199, "tree attn output"); download_tensor( indices.get_int32_ptr(), ir.token_ids, batch_size); return ir; diff --git a/src/ops/beam_topk.cc b/src/ops/beam_topk.cc index 5dfaae41ee..87d357b535 100644 --- a/src/ops/beam_topk.cc +++ b/src/ops/beam_topk.cc @@ -370,14 +370,10 @@ BeamInferenceResult Domain input_domain = runtime->get_index_space_domain( ctx, task->regions[0].region.get_index_space()); - - printf("----------1-----------\n"); + int *index_ptr = index.get_int32_ptr(); - printf("----------2-----------\n"); float *value_ptr = value.get_float_ptr(); - printf("----------3-----------\n"); int *parent_ptr = parent.get_int32_ptr(); - printf("----------4-----------\n"); // embedding size: eg. 
4096 int length = input_domain.hi()[0] - input_domain.lo()[0] + 1; @@ -404,7 +400,6 @@ BeamInferenceResult // print_tensor(index_ptr, 32, "indexxxxxxx"); - if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); int shard_id = task->index_point.point_data[0]; diff --git a/src/ops/beam_topk.cu b/src/ops/beam_topk.cu index d647fe9ed7..a958786be3 100644 --- a/src/ops/beam_topk.cu +++ b/src/ops/beam_topk.cu @@ -379,9 +379,9 @@ template __global__ void mergeSubRequestsKernel(int64_t N, T const *X, T const *rstd, T *Y) { using T_ACC = T; - int64_t const i = blockIdx.x; + const int64_t i = blockIdx.x; for (int64_t j = threadIdx.x; j < N; j += blockDim.x) { - int64_t const index = i * N + j; + const int64_t index = i * N + j; Y[index] = static_cast(X[index]) * static_cast(rstd[i]); } } @@ -556,7 +556,6 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, int beam_size = bc->beamRequestsInfo[i].beam_size; // initial request - std::cout << "sub_requests: " << i << ", " << sub_requests[i] << "\n"; assert(sub_requests[i] > 0); // process sub requests for (int j = 0; j < sub_requests[i]; j++) { @@ -564,12 +563,13 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, // beam_slots[i].parent_id[j]; acc_probs[req_index * BeamSearchBatchConfig::MAX_BEAM_WIDTH + j] = bc->beamRequestsInfo[i].probs[j]; - std::cout << "probbbb req: " << i << ", sub req probability : " - << bc->beamRequestsInfo[i].probs[j] << ", sub request id " << j - << ", parent id " << bc->beamRequestsInfo[i].parent_id[j] - << ", data inddd" - << req_index * BeamSearchBatchConfig::MAX_BEAM_WIDTH + j - << "\n"; + // std::cout << "probbbb req: " << i << ", sub req probability : " + // << bc->beamRequestsInfo[i].probs[j] << ", sub request id " << + // j + // << ", parent id " << bc->beamRequestsInfo[i].parent_id[j] + // << ", data inddd" + // << req_index * BeamSearchBatchConfig::MAX_BEAM_WIDTH + j + // << "\n"; } // process tokens @@ -584,7 +584,6 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, max_heap_size = std::max(max_heap_size, beam_size * sub_requests[i]); max_beam_width = std::max(max_beam_width, beam_size); - std::cout << "max beam width: " << max_beam_width << "\n"; req_index += 1; block_start_index += (sub_requests[i] - 1) * num_new_tokens * length; } @@ -625,23 +624,23 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, cudaMemcpyHostToDevice, stream)); // trick, set acc_probs to 0; - checkCUDA( - cudaMemsetAsync(m->acc_probs, 1.0, max_total_requests * sizeof(DT), stream)); + checkCUDA(cudaMemsetAsync( + m->acc_probs, 1.0, max_total_requests * sizeof(DT), stream)); checkCUDA(cudaMemcpyAsync(m->block_start_index, beam_block_start_index.data(), sizeof(int) * beam_num_blocks, cudaMemcpyHostToDevice, stream)); checkCUDA(cudaMemcpyAsync(m->request_id, - request_id.data(), - sizeof(int) * beam_num_blocks, - cudaMemcpyHostToDevice, - stream)); + request_id.data(), + sizeof(int) * beam_num_blocks, + cudaMemcpyHostToDevice, + stream)); checkCUDA(cudaMemcpyAsync(m->tokens_per_request, - tokens_per_request.data(), - sizeof(int) * beam_num_blocks, - cudaMemcpyHostToDevice, - stream)); + tokens_per_request.data(), + sizeof(int) * beam_num_blocks, + cudaMemcpyHostToDevice, + stream)); // int depth = // bc->beamRequestsInfo[bc->tokensInfo[0].request_index].current_depth; beam_num_blocks = bc->num_active_tokens(); diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 3b3879e8e5..cca0b230c3 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ 
b/src/ops/inc_multihead_self_attention.cu @@ -106,7 +106,8 @@ __global__ void compute_attention_kernel_generation_kernel( // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - const DT *q_ptr = query + batch_config_request_id * hidden_size * QKV_WEIGHT_NUM + + const DT *q_ptr = query + + batch_config_request_id * hidden_size * QKV_WEIGHT_NUM + head_idx * per_head_size; __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD]; // DT const *q_ptr = diff --git a/src/ops/kernels/embedding_kernels.cu b/src/ops/kernels/embedding_kernels.cu index 6947be432e..22d8161ff1 100644 --- a/src/ops/kernels/embedding_kernels.cu +++ b/src/ops/kernels/embedding_kernels.cu @@ -118,7 +118,6 @@ void forward_kernel_wrapper(EmbeddingMeta const *m, // print_tensor(output_ptr, output_domain.get_volume(), // "[Embedding:forward:output]"); } - // print_tensor(input.get_int32_ptr(), 200, "embeddinginput"); } /*static*/ diff --git a/src/ops/spec_inc_multihead_self_attention.cc b/src/ops/spec_inc_multihead_self_attention.cc index eb6fd721e6..5d234df822 100644 --- a/src/ops/spec_inc_multihead_self_attention.cc +++ b/src/ops/spec_inc_multihead_self_attention.cc @@ -53,7 +53,7 @@ bool SpecIncMultiHeadSelfAttentionParams::is_valid( } Tensor - FFModel::spec_inc_multihead_self_attention(const Tensor input, + FFModel::spec_inc_multihead_self_attention(Tensor const input, int embed_dim, int num_heads, int kdim, @@ -91,7 +91,7 @@ Tensor } Tensor - FFModel::spec_inc_multiquery_self_attention(const Tensor input, + FFModel::spec_inc_multiquery_self_attention(Tensor const input, int embed_dim, int num_q_heads, int num_kv_heads, @@ -257,7 +257,7 @@ Op *SpecIncMultiHeadSelfAttention::create_operator_from_layer( SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( FFModel &model, LayerID const &_layer_guid, - const ParallelTensor _input, + ParallelTensor const _input, int _embed_dim, int _num_q_heads, int _num_kv_heads, @@ -358,8 +358,8 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( FFModel &model, - const ParallelTensor _input, - const ParallelTensor _weight, + ParallelTensor const _input, + ParallelTensor const _weight, int _embed_dim, int _num_q_heads, int _num_kv_heads, @@ -465,7 +465,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( FFModel &model, SpecIncMultiHeadSelfAttention const &other, - const ParallelTensor input, + ParallelTensor const input, bool allocate_weights) : SpecIncMultiHeadSelfAttention(model, other.layer_guid, diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 29e3d9a48d..b3a87fe244 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -23,16 +23,295 @@ namespace FlexFlow { +#define WARP_SIZE 32 + // declare Legion names using Legion::coord_t; using Legion::Memory; using namespace Kernels::IncMultiHeadAttention; namespace Kernels { -namespace SpecIncMultiHeadAttention { +namespace SpecIncMultiHeadSelfAttention { + +template +__global__ void compute_spec_inc_attention_kernel_generation_kernel( + DT const *query, + DT const *key_cache, + DT const *value_cache, + DT *output_ptr, + float const scale, + int const max_seq_length, + int per_head_size, + int hidden_size, + BatchConfig::PerRequestInfo *request_infos, + BeamSearchBatchConfig::BeamSearchPerRequestInfo 
*beam_request_infos, + BatchConfig::BitMask *causalMask) { + + // q, k + using Q_vec = typename VEC_K::Type; + using K_vec = typename VEC_K::Type; + using V_vec = typename VEC_V
::Type; + using Out_sum = typename Vec_fp32_::Type; + + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(DT); + constexpr int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY; + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + // constexpr int QK_ELTS_IN_16B = 16 / sizeof(DT); + + // thread id + int const tidx = threadIdx.x; + // head id + int const head_idx = blockIdx.x; + // nth request idx + int const request_idx = blockIdx.y; + + // request id in batch config + int const batch_config_request_id = + request_infos[request_idx].batch_config_request_id; + + // request_idx = re + + BatchConfig::BitMask bitmask = causalMask[batch_config_request_id]; + + int const first_step = 0; + + int const tlength = + request_infos[batch_config_request_id].first_token_depth_in_request + + request_infos[batch_config_request_id].num_tokens_in_batch; + + int const totalCacheSize = bitmask.non_tree_cache_size + bitmask.tree_size; + + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + // printf("spec inc attn fused kernel %d, %d\n", + // totalCacheSize, + // request_infos[batch_config_request_id].num_tokens_in_batch); + // } + // int const qlength = request_infos[request_idx].num_tokens_in_batch; + int const tree_branch_num = + beam_request_infos[batch_config_request_id].sub_request_num; + + // will decode qlength tokens in this thread block + // int const qlength = tree_branch_num; + + int first_token_idx = 0; + for (int r = 0; r < request_idx; r++) { + first_token_idx += causalMask[r].this_layer_size; + } + + // shared memory objects + extern __shared__ char smem_[]; + + float *qk_smem = reinterpret_cast(smem_); + float *out_smem = reinterpret_cast(smem_); + + float qk_max = -FLT_MAX; + + // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + + const DT *q_ptr = query + first_token_idx * hidden_size * QKV_WEIGHT_NUM + + head_idx * per_head_size; + __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD]; + + // the start offset of the element eg. (0, 1, 2, 3) * K_VEC_SIZE + int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; + int ki_o = tidx % THREADS_PER_KEY; + // the first key's offset for this thread + // ko = 0, 0, 0, 0, 1, 1, 1, 1, .... + int ko = tidx / THREADS_PER_KEY; + // load q tensor + Q_vec q_vec[K_VECS_PER_THREAD]; + + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + // The number of keys per warp. 
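  // Each group of THREADS_PER_KEY threads cooperates on one cached key: the
  // group's threads load K_VECS_PER_THREAD chunks of that key and their
  // partial dot products with q_vecs are reduced inside Qk_dot. Consecutive
  // groups take consecutive timesteps, so a block covers K_PER_ITER keys per
  // loop iteration and a single warp covers K_PER_WARP of them.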
+ constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + DT const *k_cache_batch = + key_cache + batch_config_request_id * max_seq_length * hidden_size + ki; + + int ti_end = + div_up(totalCacheSize - first_step, K_PER_WARP) * K_PER_WARP + first_step; + + for (int qi = 0; qi < tree_branch_num; qi += 1) { +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + q_vecs[ki_o][ii] = *reinterpret_cast( + q_ptr + (hidden_size * QKV_WEIGHT_NUM * qi) + ki + + ii * THREADS_PER_KEY * K_VEC_SIZE); + } + + int const query_token = bitmask.tree_size - tree_branch_num + qi; + + __syncthreads(); + for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { + K_vec k[K_VECS_PER_THREAD]; + int const ti_circ = ti % max_seq_length; + + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * THREADS_PER_KEY * K_VEC_SIZE; + if (ti < totalCacheSize) { + + k[ii] = *reinterpret_cast( + k_cache_batch + ti_circ * hidden_size + head_idx * per_head_size + + jj); + } + } + float qk = scale * Qk_dot::dot(q_vecs[ki_o], k); + + if (ti < totalCacheSize && tidx % THREADS_PER_KEY == 0) { + // todo add alobi here + // bool const mask = ti_circ >= totalCacheSize; + bool const mask = (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << query_token)))); + + // if (blockIdx.y == 0 && blockIdx.x == 0 && !mask) { + // printf("spec inc attn qkqkqk %d, %.10f, %d\n", ti, qk, qi); + // } + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + qk_smem[ti - first_step] = mask ? 0.f : qk; + } + } + + __syncthreads(); + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Decompose the thread index into warp and lane. + int const warp = tidx / WARP_SIZE; + int const lane = tidx % WARP_SIZE; + + // The warp leader writes the max to shared memory. + if (lane == 0) { + red_smem[warp] = qk_max; + } + + // Make sure the products are in shared memory. + __syncthreads(); + + // The warps finalize the reduction. + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Broadcast to all the threads in the warp. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { + // printf("spec inc attn first token qk_max %.10f\n", qk_max); + // } + + float exp_sum = 0.f; + for (int ti = first_step + tidx; ti < totalCacheSize; + ti += THREADS_PER_BLOCK) { + bool const mask = (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << query_token)))); + float logit = mask ? 0.0f : __expf(qk_smem[ti - first_step] - qk_max); + exp_sum += logit; + qk_smem[ti - first_step] = mask ? 0.0f : logit; + } + + // Compute the sum. + exp_sum = block_sum(&red_smem[WARPS_PER_BLOCK], exp_sum); + + // softmax + float inv_sum = __fdividef(1.f, exp_sum + 1.e-6); + for (int ti = first_step + tidx; ti < totalCacheSize; + ti += THREADS_PER_BLOCK) { + qk_smem[ti - first_step] *= inv_sum; + } + + __syncthreads(); + + // value projection + constexpr int V_VEC_SIZE = 16 / sizeof(DT); + // A vector of V elements for the current timestep. + // using V_vec_k = typename V_vec_k_::Type; + // using V_vec_acum = typename V_vec_acum_fp32_::Type; + + // The value computed by this thread. 
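  // For the value pass the block is regrouped: THREADS_PER_VALUE threads
  // share one timestep, each thread covering V_VEC_SIZE elements of the value
  // vector. vo and vi below pick this thread's timestep group and its slice
  // of the head dimension, and each group then strides over the cache in
  // steps of V_PER_ITER, accumulating softmax-weighted values into out.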
+ int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + + Out_sum out; + zero(out); + + // The base pointer for the value in the cache buffer. + DT const *v_cache_batch = + value_cache + batch_config_request_id * max_seq_length * hidden_size + + vi; + + if (Dh == Dh_MAX || vi < Dh) { + for (int ti = first_step + vo; ti < totalCacheSize; ti += V_PER_ITER) { + // Load the values from the cache. + int const ti_circ = ti % max_seq_length; + V_vec v = *reinterpret_cast( + v_cache_batch + ti_circ * hidden_size + head_idx * per_head_size); + + bool const mask = (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << query_token)))); + float logit = mask ? 0.0f : qk_smem[ti - first_step]; + out = FlexFlow::fma(logit, cast_to_float(v), out); + } + } + + // // Make sure we can start writing to shared memory. + __syncthreads(); + + // Run the final reduction amongst the different groups computing different + // partial outputs. + if (Dh == Dh_MAX || vi < Dh) { +#pragma unroll + for (int active_groups = V_PER_ITER; active_groups >= 2; + active_groups /= 2) { + + // The midpoint in the number of active groups. + int midpoint = active_groups / 2; + + // The upper part of active threads store to shared memory. + if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { + *reinterpret_cast(out_smem + (vo - midpoint) * Dh + vi) = + out; + } + __syncthreads(); + + // The bottom warps update their values. + if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { + out = add(*reinterpret_cast(out_smem + vo * Dh + vi), + out); + } + __syncthreads(); + } + } + + // Output the final values. 
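  // After the reduction above, the vo == 0 group holds the fully reduced row.
  // Query token qi of this request writes its result at row
  // (first_token_idx + qi) of the token-major output, offset by
  // head_idx * per_head_size within the row, so a request's tree-branch
  // tokens land in consecutive rows.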
+ if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { + convert_from_float(*reinterpret_cast( + output_ptr + (first_token_idx + qi) * hidden_size + + head_idx * per_head_size + vi), + out); + } + } +} template -__global__ void spec_store_kv_cache( +__global__ void spec_inc_store_kv_cache( DT const *devQKVProjArray, DT *kCache_ptr, DT *vCache_ptr, @@ -40,16 +319,16 @@ __global__ void spec_store_kv_cache( BatchConfig::PerRequestInfo *requestInfo, BeamSearchBatchConfig::BeamSearchPerTokenInfo *beamTokenInfos, BeamSearchBatchConfig::BeamSearchPerRequestInfo *beamRequestInfos, + BatchConfig::BitMask *causalMask, int qProjSize, int kProjSize, int vProjSize, int num_tokens, int max_seq_len, - int max_beam_width, bool is_root, int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size * 2) { - int token_idx = i / (hidden_size * KV_WEIGHT_NUM); + CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { + int token_idx = i / (hidden_size); int offset = i % hidden_size; size_t val_idx = @@ -58,100 +337,36 @@ __global__ void spec_store_kv_cache( DT kVal = devQKVProjArray[val_idx]; DT vVal = devQKVProjArray[val_idx + hidden_size]; - // above no need to be changed - // int const req_id = id_map[token_idx].request_index; - // int const tok_id = id_map[token_idx].token_position; - // int const sub_req_id = id_map[token_idx].sub_request_index; - // int const parent_id = id_map[token_idx].parent_id; - // int const beam_depth = id_map[token_idx].beam_depth; - // int const beam_width = id_map[token_idx].beam_width; - int const req_id = tokenInfos[token_idx].request_index; int const tok_id = tokenInfos[token_idx].abs_depth_in_request; + int const first_token_in_req = + requestInfo[req_id].first_token_depth_in_request; int const sub_req_id = beamTokenInfos[token_idx].sub_request_index; - int const parent_id = beamRequestInfos[req_id].parent_id[sub_req_id]; - int const beam_depth = beamRequestInfos[req_id].current_depth; - int const beam_width = beamRequestInfos[req_id].beam_size; - - kCache_ptr[(req_id * max_beam_width + sub_req_id) * - (hidden_size * max_seq_len) + - tok_id * hidden_size + offset] = kVal; - vCache_ptr[(req_id * max_beam_width + sub_req_id) * - (hidden_size * max_seq_len) + - tok_id * hidden_size + offset] = vVal; - - // replica in the root iteration - if (beam_depth == 1) { - for (int i = 1; i < beam_width; i++) { - kCache_ptr[(req_id * max_beam_width + i) * (hidden_size * max_seq_len) + - tok_id * hidden_size + offset] = kVal; - vCache_ptr[(req_id * max_beam_width + i) * (hidden_size * max_seq_len) + - tok_id * hidden_size + offset] = vVal; - } - } + int const total_token = requestInfo[req_id].num_tokens_in_batch; - // if (head_idx == 0 && beam_depth == 0 && token_idx == 8 && k_cache) { - // // printf("token idx %d\n", token_idx); - // printf("data idx: %d, tok_id %d, new_token_cache_idx %d, parent_id %d, - // " - // "sub_req_id %d, num_tokens %d, kProjSize %d, num_kv_heads %d, - // val " - // "%f, beam_width %d\n", - // data_idx, - // tok_id, - // new_token_cache_idx, - // parent_id, - // sub_req_id, - // num_tokens, - // kProjSize, - // num_kv_heads, - // val, - // beam_width); - // } + int const request_token_offset = + requestInfo[req_id].first_token_offset_in_batch; - // naive cache stealing - if (sub_req_id != parent_id) { - // if (offset == 0 && tok_id == 0) { - // printf("cache stealing!, depth %d req_id %d sub_req_id %d, parentid " - // "%d, tok_id %d\n", - // beam_depth, - // req_id, - // sub_req_id, - // parent_id, - // tok_id); - // } - - for (int depth = 0; depth < beam_depth; depth++) { - 
int steal_token_idx = tok_id - beam_depth + depth; - int steal_from_idx = (req_id * max_beam_width + parent_id) * - (hidden_size * max_seq_len) + - steal_token_idx * hidden_size + offset; - int steal_to_idx = (req_id * max_beam_width + sub_req_id) * - (hidden_size * max_seq_len) + - steal_token_idx * hidden_size + offset; - kCache_ptr[steal_to_idx] = kCache_ptr[steal_from_idx]; - vCache_ptr[steal_to_idx] = vCache_ptr[steal_from_idx]; - - // if(data_idx == 0 && head_idx == 0 && k_cache && req_id == 1){ - // printf("cache stealing kernel!, steal_token_idx %d\n", - // steal_token_idx); - // } - } - } + BatchConfig::BitMask bitmask = causalMask[req_id]; - // parallel cache stealing not yet implemented - // logic shld be - // launch spec_store_kv_cache with parallelism * current depth - // from the i here, get depth index - // if depth index not the current one, check if we need to steal - // steal if needed - - // cache stealing theory - // identify which sub request does this token come from - // for initial token, 0 - // for other, may 0,0,1/ 0,1,2/ 1,1,1 to get which cache to be reuse and - // which to be delete copy beam_size bunch of blocks when sub_req_id == - // parent_id : like 0 -> 0, 1->1, 2->2, do nothing, just append the new k/v + int const sub_request_num = beamRequestInfos[req_id].sub_request_num; + + int const tree_branch_num = beamRequestInfos[req_id].sub_request_num; + + // int const query_token = bitmask.non_tree_cache_size + bitmask.tree_size - + // tree_branch_num + sub_req_id + tok_id; + // bitmask.tree_size - tree_branch_num + sub_req_id; + + // if prompt token -> token id + // if tree token: + int const cache_idx = bitmask.non_tree_cache_size + bitmask.tree_size - + bitmask.this_layer_size + token_idx - + request_token_offset; + + kCache_ptr[req_id * (hidden_size * max_seq_len) + (cache_idx)*hidden_size + + offset] = kVal; + vCache_ptr[req_id * (hidden_size * max_seq_len) + (cache_idx)*hidden_size + + offset] = vVal; } } @@ -161,28 +376,79 @@ void update_kv_cache_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, cudaStream_t stream) { int num_tokens = bc->num_active_tokens(); int curr_depth = bc->beamRequestsInfo[0].current_depth; - // printf("curr depth: %d\n", curr_depth); - // assert(curr_depth < 3); if (num_tokens > 0) { int parallelism = m->hidden_size * KV_WEIGHT_NUM * num_tokens; - spec_store_kv_cache<<>>(static_cast
<DT *>(m->devQKVProjArray), - static_cast<DT *>
(m->keyCache), - static_cast<DT *>
(m->valueCache), - m->token_infos, - m->request_infos, - m->beam_token_infos, - m->beam_request_infos, - m->qProjSize, - m->kProjSize, - m->vProjSize, - num_tokens, - BatchConfig::max_sequence_length(), - BeamSearchBatchConfig::MAX_BEAM_WIDTH, - /*root*/ curr_depth == 0, - m->hidden_size); + spec_inc_store_kv_cache<<>>( + static_cast
<DT *>(m->devQKVProjArray), + static_cast<DT *>
(m->keyCache), + static_cast<DT *>
(m->valueCache), + m->token_infos, + m->request_infos, + m->beam_token_infos, + m->beam_request_infos, + m->causalMask, + m->qProjSize, + m->kProjSize, + m->vProjSize, + num_tokens, + BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, + /*root*/ curr_depth == 0, + m->hidden_size); + } +} + +#define LAUNCH_SPEC_INC_ATTENTION_SCORE_KERNEL( \ + DT, Dh, Dh_MAX, THDS_PER_KEY, THREADS_PER_VALUE, THDS_PER_BLOCK, stream) \ + smem_sz = smem_size_in_bytes
(m->qProjSize, \ + BatchConfig::max_sequence_length() + \ + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, \ + THREADS_PER_VALUE, \ + THDS_PER_BLOCK); \ + compute_spec_inc_attention_kernel_generation_kernel \ + <<>>( \ + static_cast
<DT *>(m->devQKVProjArray), \ + static_cast<DT *>
(m->keyCache), \ + static_cast<DT *>
(m->valueCache), \ + output_ptr, \ + scale, \ + BatchConfig::max_sequence_length() + \ + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, \ + m->qProjSize, \ + m->hidden_size, \ + m->request_infos, \ + m->beam_request_infos, \ + m->causalMask) + +template +void compute_spec_inc_attention_kernel_generation( + SpecIncMultiHeadSelfAttentionMeta const *m, + BeamSearchBatchConfig const *bc, + DT *output_ptr, + cudaStream_t stream) { + // one block == one head per request + dim3 grid(m->num_q_heads, bc->num_active_requests()); + int const per_head_size = m->qProjSize; + float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f; + size_t smem_sz; + if (per_head_size == 64) { + constexpr int THREADS_PER_VALUE_64 = threads_per_value_t::value; + LAUNCH_SPEC_INC_ATTENTION_SCORE_KERNEL( + DT, 64, 64, 4, THREADS_PER_VALUE_64, 128, stream); + } else if (per_head_size == 128) { + constexpr int THREADS_PER_VALUE_128 = threads_per_value_t::value; + LAUNCH_SPEC_INC_ATTENTION_SCORE_KERNEL( + DT, 128, 128, 4, THREADS_PER_VALUE_128, 128, stream); + } else { + assert(false && "a unsupported head size"); } } @@ -204,13 +470,14 @@ __global__ void spec_fill_entries_above_diagonal(DT *matrix, } template -void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, - BeamSearchBatchConfig const *bc, - int shard_id, - DT *output_ptr, - DT const *bias_ptr, - DT const *weight_ptr, - cudaStream_t stream) { +void compute_attention_kernel_prompt( + SpecIncMultiHeadSelfAttentionMeta const *m, + BeamSearchBatchConfig const *bc, + int shard_id, + DT *output_ptr, + DT const *bias_ptr, + DT const *weight_ptr, + cudaStream_t stream) { checkCUDA(cublasSetStream(m->handle.blas, stream)); checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); @@ -236,199 +503,208 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, int q_block_size = m->qProjSize; int kt_block_size = m->kProjSize; - int kt_req_block_size = - kt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + int kt_req_block_size = kt_block_size * m->num_q_heads * + (BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); int vt_block_size = m->vProjSize; - int vt_req_block_size = - vt_block_size * m->num_q_heads * BatchConfig::max_sequence_length(); + int vt_req_block_size = vt_block_size * m->num_q_heads * + (BatchConfig::max_sequence_length() + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); assert(m->qProjSize == m->kProjSize); for (int i = 0; i < bc->max_requests_per_batch(); i++) { if (bc->request_completed[i]) { continue; } - for (int sub_req_id = 0; sub_req_id < bc->sub_requests[i]; sub_req_id++) { - // int num_new_tokens = bc->num_processing_tokens[i]; - // int total_tokens = bc->token_last_available_idx[i] + 1; + // else if (tokens_previous_requests < bc->num_generation_tokens) { + // tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; + // continue; + // } - int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch; - int total_tokens = bc->requestsInfo[i].first_token_depth_in_request + - bc->requestsInfo[i].num_tokens_in_batch; + // all requests in prompt phase should only have one sub requests; + assert(bc->sub_requests[i] == 1); + // int num_new_tokens = bc->num_processing_tokens[i]; + // int total_tokens = bc->token_last_available_idx[i] + 1; - if (num_new_tokens <= 0) { - continue; - } + int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch; + int total_tokens = 
bc->requestsInfo[i].first_token_depth_in_request + + bc->requestsInfo[i].num_tokens_in_batch; - // Compute (QK^T/sqrt(d_k)) - int m_ = num_new_tokens; - int n = total_tokens; - int k = m->qProjSize; - int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads, - ldc = m_; - int strideA = q_block_size; - int strideB = kt_block_size; - int strideC = num_new_tokens * total_tokens; - - // a flag of using this scaling alpha - DT alpha = 1.0f, beta = 0.0f; - if (*m->qk_prod_scaling) { - alpha = static_cast
(1.0f / sqrt(m->kProjSize)); - } - // To get A, skip over Q entries from previous requests (same head) - DT const *A = static_cast
(m->devQKVProjArray) + - bc->requestsInfo[i].first_token_offset_in_batch * - m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM; - // To get B, skip over K entries from previous requests (all heads + - // padding) - DT const *B = static_cast
(m->keyCache) + - (i * bc->MAX_BEAM_WIDTH + sub_req_id) * kt_req_block_size; - - // if (i == 0 && sub_req_id == 0 && - // bc->beam_slots.at(0).current_depth == 1) { - // int offset = (float *)B - m->keyCache; - // printf("key cache offset %d\n", kt_req_block_size); - // } - // To get C, skip over QK^T products from previous requests - DT *C = static_cast
(m->qk_prods) + - m->num_q_heads * tokens_prev_requests_squares; - checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, - CUBLAS_OP_T, - CUBLAS_OP_N, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - // add alibi position bias to qk production - // add alibi position bias to qk production - if (*m->position_bias) { - size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; - apply_position_bias_qkprd<<>>(C, - num_new_tokens, - total_tokens, - m->num_q_heads, - m->global_num_q_heads, - shard_id); - } - // Fill all elements above diagonal in qk prods with -inf to force - // causal attention. - assert(num_new_tokens <= total_tokens); - if (num_new_tokens > 1) { - size_t parallelism = m->num_q_heads * num_new_tokens * total_tokens; - spec_fill_entries_above_diagonal<<>>( - C, - num_new_tokens, - total_tokens, - m->num_q_heads, - static_cast
(-INFINITY)); - } - // Compute Softmax(QK^T/sqrt(d_k)) - // Before modifying the parameters below, make sure to read the following - // description of the CUDNN_TENSOR_NCHW tensor layout, from - // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t: - // This tensor format specifies that the data is laid out in the following - // order: batch size, feature maps, rows, columns. The strides are - // implicitly defined in such a way that the data are contiguous in memory - // with no padding between images, feature maps, rows, and columns; the - // columns are the inner dimension and the images are the outermost - // dimension. - int n_param = m->num_q_heads; - int c_param = total_tokens; - int h_param = 1; - int w_param = num_new_tokens; - checkCUDNN(cudnnSetTensor4dDescriptor(m->qk_tensor, - CUDNN_TENSOR_NCHW, - cudnn_data_type, - n_param, - c_param, - h_param, - w_param)); - float softmax_alpha = 1.0f, softmax_beta = 0.0f; - DT *C_softmax = static_cast
(m->qk_prods_softmax) + - m->num_q_heads * tokens_prev_requests_squares; - // The softmax operation below is executed according to the - // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The - // softmax operation is computed per spatial location (H,W) per image (N) - // across dimension C. - checkCUDNN(cudnnSoftmaxForward(m->handle.dnn, - CUDNN_SOFTMAX_ACCURATE, - CUDNN_SOFTMAX_MODE_CHANNEL, - &softmax_alpha, - m->qk_tensor, - C, - &softmax_beta, - m->qk_tensor, - C_softmax)); - // Matmul softmax(QK^T/sqrt(d_k)) by V - alpha = 1.0f, beta = 0.0f; - m_ = m->vProjSize; - n = num_new_tokens; - k = total_tokens; - lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads; - strideA = vt_block_size; - strideB = num_new_tokens * total_tokens; - strideC = m->vProjSize; - // To get A, skip over V^T entries from previous requests (all heads + - // padding) - A = static_cast
(m->valueCache) + - (i * bc->MAX_BEAM_WIDTH + sub_req_id) * vt_req_block_size; - // To get B, skip over softmax(QK^T/sqrt(d_k)) entries from previous - // requests (all heads) - B = C_softmax; - // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous - // requests - C = static_cast
(m->attn_heads) + - (tokens_previous_requests + bc->num_generation_tokens) * - m->num_q_heads * m->vProjSize; - checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, - CUBLAS_OP_N, - CUBLAS_OP_T, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - tokens_previous_requests += num_new_tokens; - tokens_prev_requests_squares += num_new_tokens * total_tokens; + if (num_new_tokens <= 0) { + continue; + } + + // Compute (QK^T/sqrt(d_k)) + int m_ = num_new_tokens; + int n = total_tokens; + int k = m->qProjSize; + int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads, + ldc = m_; + int strideA = q_block_size; + int strideB = kt_block_size; + int strideC = num_new_tokens * total_tokens; + + // a flag of using this scaling alpha + DT alpha = 1.0f, beta = 0.0f; + if (*m->qk_prod_scaling) { + alpha = static_cast
(1.0f / sqrt(m->kProjSize)); + } + // To get A, skip over Q entries from previous requests (same head) + DT const *A = static_cast
(m->devQKVProjArray) + + bc->requestsInfo[i].first_token_offset_in_batch * + m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM; + // To get B, skip over K entries from previous requests (all heads + + // padding) + + // print_tensor((float*)A, 32, "A"); + DT const *B = static_cast
(m->keyCache) + i * kt_req_block_size; + + // if (i == 0 && sub_req_id == 0 && + // bc->beam_slots.at(0).current_depth == 1) { + // int offset = (float *)B - m->keyCache; + // printf("key cache offset %d\n", kt_req_block_size); + // } + // To get C, skip over QK^T products from previous requests + DT *C = static_cast
(m->qk_prods) + + m->num_q_heads * tokens_prev_requests_squares; + checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, + CUBLAS_OP_T, + CUBLAS_OP_N, + m_, + n, + k, + &alpha, + A, + cublas_data_type, + lda, + strideA, + B, + cublas_data_type, + ldb, + strideB, + &beta, + C, + cublas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + // print_tensor((float*)C, 32, "C"); + // add alibi position bias to qk production + // add alibi position bias to qk production + if (*m->position_bias) { + size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; + apply_position_bias_qkprd<<>>(C, + num_new_tokens, + total_tokens, + m->num_q_heads, + m->global_num_q_heads, + shard_id); } + // Fill all elements above diagonal in qk prods with -inf to force + // causal attention. + assert(num_new_tokens <= total_tokens); + if (num_new_tokens > 1) { + size_t parallelism = m->num_q_heads * num_new_tokens * total_tokens; + spec_fill_entries_above_diagonal<<>>(C, + num_new_tokens, + total_tokens, + m->num_q_heads, + static_cast
(-INFINITY)); + } + // Compute Softmax(QK^T/sqrt(d_k)) + // Before modifying the parameters below, make sure to read the following + // description of the CUDNN_TENSOR_NCHW tensor layout, from + // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t: + // This tensor format specifies that the data is laid out in the following + // order: batch size, feature maps, rows, columns. The strides are + // implicitly defined in such a way that the data are contiguous in memory + // with no padding between images, feature maps, rows, and columns; the + // columns are the inner dimension and the images are the outermost + // dimension. + int n_param = m->num_q_heads; + int c_param = total_tokens; + int h_param = 1; + int w_param = num_new_tokens; + checkCUDNN(cudnnSetTensor4dDescriptor(m->qk_tensor, + CUDNN_TENSOR_NCHW, + cudnn_data_type, + n_param, + c_param, + h_param, + w_param)); + float softmax_alpha = 1.0f, softmax_beta = 0.0f; + DT *C_softmax = static_cast
(m->qk_prods_softmax) + + m->num_q_heads * tokens_prev_requests_squares; + // The softmax operation below is executed according to the + // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The + // softmax operation is computed per spatial location (H,W) per image (N) + // across dimension C. + checkCUDNN(cudnnSoftmaxForward(m->handle.dnn, + CUDNN_SOFTMAX_ACCURATE, + CUDNN_SOFTMAX_MODE_CHANNEL, + &softmax_alpha, + m->qk_tensor, + C, + &softmax_beta, + m->qk_tensor, + C_softmax)); + // Matmul softmax(QK^T/sqrt(d_k)) by V + alpha = 1.0f, beta = 0.0f; + m_ = m->vProjSize; + n = num_new_tokens; + k = total_tokens; + lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads; + strideA = vt_block_size; + strideB = num_new_tokens * total_tokens; + strideC = m->vProjSize; + // To get A, skip over V^T entries from previous requests (all heads + + // padding) + A = static_cast
(m->valueCache) + i * vt_req_block_size; + // To get B, skip over softmax(QK^T/sqrt(d_k)) entries from previous + // requests (all heads) + B = C_softmax; + // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous + // requests + + // print_tensor((float*)C_softmax, 32, "C_softmax"); + C = static_cast
(m->attn_heads) + + (tokens_previous_requests + bc->num_generation_tokens) * + m->num_q_heads * m->vProjSize; + checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, + CUBLAS_OP_N, + CUBLAS_OP_T, + m_, + n, + k, + &alpha, + A, + cublas_data_type, + lda, + strideA, + B, + cublas_data_type, + ldb, + strideB, + &beta, + C, + cublas_data_type, + ldc, + strideC, + m->num_q_heads, + compute_type, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + tokens_previous_requests += num_new_tokens; + tokens_prev_requests_squares += num_new_tokens * total_tokens; } // assert(tokens_previous_requests == num_tokens); @@ -443,31 +719,8 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, DT *output_ptr, DT const *bias_ptr, cudaStream_t stream) { - // here because we need postion info in infernece 1 - cudaMemcpyAsync(m->token_infos, - &(bc->tokensInfo), - bc->num_active_tokens() * sizeof(BatchConfig::PerTokenInfo), - cudaMemcpyHostToDevice, - stream); - cudaMemcpyAsync(m->request_infos, - &(bc->requestsInfo), - bc->max_requests_per_batch() * - sizeof(BatchConfig::PerRequestInfo), - cudaMemcpyHostToDevice, - stream); - cudaMemcpyAsync(m->beam_token_infos, - &(bc->beamTokenInfo), - bc->num_active_tokens() * bc->MAX_BEAM_WIDTH * - sizeof(BeamSearchBatchConfig::BeamSearchPerTokenInfo), - cudaMemcpyHostToDevice, - stream); - cudaMemcpyAsync(m->beam_request_infos, - &(bc->beamRequestsInfo), - bc->max_requests_per_batch() * - sizeof(BeamSearchBatchConfig::BeamSearchPerRequestInfo), - cudaMemcpyHostToDevice, - stream); // phase 1: Implement kernel to compute KQV for input tokens + compute_qkv_kernel(m, bc, shard_id, @@ -479,7 +732,7 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, // phase 2: Update key/val cache update_kv_cache_kernel
<DT>(m, bc, stream); if (bc->num_generation_tokens > 0) { - compute_attention_kernel_generation<DT>(
+ compute_spec_inc_attention_kernel_generation<DT>(
m, bc, static_cast<DT *>
(m->attn_heads), stream); } // phase 3: Compute attention score @@ -488,16 +741,14 @@ void inference_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, compute_attention_kernel_prompt( m, bc, shard_id, output_ptr, bias_ptr, weight_ptr, stream); } - // compute output production and bias together for all tokens - int num_tokens = - bc->num_active_tokens() * BeamSearchBatchConfig::MAX_BEAM_WIDTH; + int num_tokens = bc->num_active_tokens(); compute_o_prod_bias( m, bc, shard_id, output_ptr, weight_ptr, bias_ptr, num_tokens, stream); } -} // namespace SpecIncMultiHeadAttention +} // namespace SpecIncMultiHeadSelfAttention } // namespace Kernels /*static*/ @@ -529,25 +780,27 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( if (input.data_type == DT_HALF) { half const *bias_ptr = use_bias ? bias.get_half_ptr() : static_cast(nullptr); - Kernels::SpecIncMultiHeadAttention::inference_kernel(m, - bc, - shard_id, - input.get_half_ptr(), - weight.get_half_ptr(), - output.get_half_ptr(), - bias_ptr, - stream); + Kernels::SpecIncMultiHeadSelfAttention::inference_kernel( + m, + bc, + shard_id, + input.get_half_ptr(), + weight.get_half_ptr(), + output.get_half_ptr(), + bias_ptr, + stream); } else if (input.data_type == DT_FLOAT) { float const *bias_ptr = use_bias ? bias.get_float_ptr() : static_cast(nullptr); - Kernels::SpecIncMultiHeadAttention::inference_kernel(m, - bc, - shard_id, - input.get_float_ptr(), - weight.get_float_ptr(), - output.get_float_ptr(), - bias_ptr, - stream); + Kernels::SpecIncMultiHeadSelfAttention::inference_kernel( + m, + bc, + shard_id, + input.get_float_ptr(), + weight.get_float_ptr(), + output.get_float_ptr(), + bias_ptr, + stream); } else { assert(false && "Unspported data type"); } @@ -559,7 +812,8 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); cudaEventDestroy(t_start); cudaEventDestroy(t_end); - printf("SpecIncMultiHeadSelfAttention forward time = %.2fms\n", elapsed); + printf("SpecIncMultiHeadSelfAttention forward time = %.2fms\n", + elapsed); // print_tensor<3, float>(acc_query.ptr, acc_query.rect, // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr, // acc_output.rect, "[Attention:forward:output]"); @@ -606,44 +860,51 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta( // allocate memory for the seqArray and reserve space { - int max_tokens_per_batch = BatchConfig::max_tokens_per_batch(); - size_t beam_tokeninfo_size = - max_tokens_per_batch * BeamSearchBatchConfig::MAX_BEAM_WIDTH; - size_t requestinfo_size = BeamSearchBatchConfig::max_requests_per_batch(); - size_t beam_requestinfo_size = - BeamSearchBatchConfig::max_requests_per_batch(); - size_t total_size = - beam_tokeninfo_size * - sizeof(BeamSearchBatchConfig::BeamSearchPerTokenInfo) + - beam_requestinfo_size * - sizeof(BeamSearchBatchConfig:: - BeamSearchPerRequestInfo); // more components will - // be added here later - - // We always directly allocate memory for small speculative models - gpu_mem_allocator.create_legion_instance(beam_search_reserve_inst, - total_size); + // size_t causal_mask_size = BatchConfig::MAX_NUM_REQUESTS; + // size_t total_size = causal_mask_size * sizeof(BatchConfig::BitMask); + // gpu_mem_allocator.create_legion_instance(beam_search_reserve_inst, + // total_size); + beam_token_infos = - gpu_mem_allocator - .allocate_instance( - beam_tokeninfo_size); + static_cast( + handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + + 
sizeof(BatchConfig::requestsInfo)); + + beam_request_infos = + static_cast( + handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + + sizeof(BatchConfig::requestsInfo) + + sizeof(BeamSearchBatchConfig::beamTokenInfo)); + causalMask = static_cast( + handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + + sizeof(BatchConfig::requestsInfo) + + sizeof(BeamSearchBatchConfig::beamTokenInfo) + + sizeof(BeamSearchBatchConfig::beamRequestsInfo)); + + // causalMask = gpu_mem_allocator.allocate_instance( + // causal_mask_size); + // beam_token_infos = + // gpu_mem_allocator + // .allocate_instance( + // beam_tokeninfo_size); // offset += beam_tokeninfo_size * // sizeof(BeamSearchBatchConfig::BeamSearchPerTokenInfo); - beam_request_infos = - gpu_mem_allocator - .allocate_instance( - beam_requestinfo_size); + // beam_request_infos = + // gpu_mem_allocator + // .allocate_instance( + // beam_requestinfo_size); // offset += beam_requestinfo_size * // sizeof(BeamSearchBatchConfig::BeamSearchPerRequestInfo); // assert(offset == total_size); - assert(gpu_mem_allocator.instance_total_size == - gpu_mem_allocator.instance_allocated_size); + // assert(gpu_mem_allocator.instance_total_size == + // gpu_mem_allocator.instance_allocated_size); } cudaStreamSynchronize(stream); } -SpecIncMultiHeadSelfAttentionMeta::~SpecIncMultiHeadSelfAttentionMeta(void) { +SpecIncMultiHeadSelfAttentionMeta::~SpecIncMultiHeadSelfAttentionMeta( + void) { if (beam_search_reserve_inst != Realm::RegionInstance::NO_INST) { beam_search_reserve_inst.destroy(); } diff --git a/src/ops/specinfer_inc_multihead_self_attention.cc b/src/ops/specinfer_inc_multihead_self_attention.cc deleted file mode 100644 index 42074f39e4..0000000000 --- a/src/ops/specinfer_inc_multihead_self_attention.cc +++ /dev/null @@ -1,883 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
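For reference, the pointers assigned above all index into the single contiguous batch_config_metadata block owned by FFHandler; each sub-array simply begins where the previous one's sizeof() ends. A standalone C++ sketch of that running-offset arithmetic (the byte counts are passed in rather than taken from the real headers, so this is only an illustration):

#include <cstddef>

// Field order mirrors the sizeof() sums in the attention meta constructors:
// tokensInfo, requestsInfo, beamTokenInfo, beamRequestsInfo, causal bit masks.
struct MetadataLayout {
  size_t tokens_info_off;
  size_t requests_info_off;
  size_t beam_token_info_off;
  size_t beam_request_info_off;
  size_t causal_mask_off;
  size_t total_bytes;
};

MetadataLayout make_layout(size_t tokens_info_bytes,
                           size_t requests_info_bytes,
                           size_t beam_token_info_bytes,
                           size_t beam_request_info_bytes,
                           size_t causal_mask_bytes) {
  MetadataLayout l;
  size_t off = 0;
  l.tokens_info_off = off;
  off += tokens_info_bytes;
  l.requests_info_off = off;
  off += requests_info_bytes;
  l.beam_token_info_off = off;
  off += beam_token_info_bytes;
  l.beam_request_info_off = off;
  off += beam_request_info_bytes;
  l.causal_mask_off = off;
  off += causal_mask_bytes;
  l.total_bytes = off;
  return l;
}

With this layout, beam_request_infos above is simply base + beam_request_info_off and causalMask is base + causal_mask_off, where base is handler.batch_config_metadata.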
- */ - -#include "flexflow/ops/specinfer_inc_multihead_self_attention.h" -#include "flexflow/ffconst_utils.h" -#include "flexflow/model.h" -#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) -#include "flexflow/utils/cuda_helper.h" -#else -#include "flexflow/utils/hip_helper.h" -#endif -#include "flexflow/utils/hash_utils.h" -#include "legion/legion_utilities.h" - -namespace FlexFlow { - -// declare Legion names -using Legion::ArgumentMap; -using Legion::Context; -using Legion::coord_t; -using Legion::Domain; -using Legion::Future; -using Legion::FutureMap; -using Legion::IndexLauncher; -using Legion::Machine; -using Legion::Memory; -using Legion::PhysicalRegion; -using Legion::Predicate; -using Legion::Rect; -using Legion::RegionRequirement; -using Legion::Runtime; -using Legion::Task; -using Legion::TaskArgument; -using Legion::TaskLauncher; -using PCG::Node; - -bool SpecInferIncMultiHeadSelfAttentionParams::is_valid( - ParallelTensorShape const &input) const { - bool is_valid = input.is_valid(); - return is_valid; -} - -Tensor FFModel::specinfer_inc_multihead_self_attention( - Tensor const input, - int embed_dim, - int num_heads, - int kdim, - int vdim, - float dropout, - bool qkv_bias, - bool final_bias, - bool add_zero_attn, - DataType data_type, - Initializer *kernel_initializer, - bool apply_rotary_embedding, - bool scaling_query, - float scaling_factor, - bool qk_prod_scaling, - bool position_bias, - char const *name) { - return specinfer_inc_multiquery_self_attention(input, - embed_dim, - num_heads, - num_heads, - kdim, - vdim, - dropout, - qkv_bias, - final_bias, - add_zero_attn, - data_type, - kernel_initializer, - apply_rotary_embedding, - scaling_query, - scaling_factor, - qk_prod_scaling, - position_bias, - name); -} - -Tensor FFModel::specinfer_inc_multiquery_self_attention( - Tensor const input, - int embed_dim, - int num_q_heads, - int num_kv_heads, - int kdim, - int vdim, - float dropout, - bool qkv_bias, - bool final_bias, - bool add_zero_attn, - DataType data_type, - Initializer *kernel_initializer, - bool apply_rotary_embedding, - bool scaling_query, - float scaling_factor, - bool qk_prod_scaling, - bool position_bias, - char const *name) { - if (data_type == DT_NONE) { - data_type = input->data_type; - } - Layer *li = nullptr; - int weight_num = (qkv_bias || final_bias) ? 2 : 1; - if (data_type != input->data_type) { - Tensor casted_input = cast(input, data_type, "type cast for IncMHA"); - li = new Layer(this, - OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION, - data_type, - name, - 1 /*inputs*/, - weight_num /*weights*/, - 1 /*outputs*/, - casted_input); - } else { - li = new Layer(this, - OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION, - data_type, - name, - 1 /*inputs*/, - weight_num /*weights*/, - 1 /*outputs*/, - input); - } - { - int numdims = input->num_dims; - int dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdims; i++) { - dims[i] = input->dims[i]; - } - dims[0] = embed_dim; - li->outputs[0] = create_tensor_legion_ordering( - numdims, dims, data_type, li, 0, true /*create_grad*/); - } - // Compute weight size - int qProjSize = kdim, kProjSize = kdim, vProjSize = kdim, - oProjSize = embed_dim; - int qSize = input->dims[0], kSize = input->dims[0], vSize = input->dims[0]; - int qParas = qProjSize * qSize; - int kParas = kProjSize * kSize; - int vParas = vProjSize * vSize; - int oParas = oProjSize * (vProjSize > 0 ? 
vProjSize : vSize); - int weight_size = qParas * num_q_heads + kParas * num_q_heads + - vParas * num_q_heads + oParas * num_q_heads; - { - int dims[1] = {weight_size}; - li->weights[0] = create_weight_legion_ordering(1, - dims, - data_type, - li, - true /*create_grad*/, - kernel_initializer, - CHOSEN_SYNC_TYPE); - } - if (qkv_bias || final_bias) { - // q, k, v, o - int qkv_bias_size = - qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads; - int dims[1] = {(qkv_bias ? qkv_bias_size : 0) + - (final_bias ? oProjSize : 0)}; - li->weights[1] = create_weight_legion_ordering(1, - dims, - data_type, - li, - true /*create_grad*/, - kernel_initializer, - CHOSEN_SYNC_TYPE); - } - li->data_type = data_type; - li->add_int_property("embed_dim", embed_dim); - li->add_int_property("num_q_heads", num_q_heads); - li->add_int_property("num_kv_heads", num_kv_heads); - li->add_int_property("kdim", kdim); - li->add_int_property("vdim", vdim); - li->add_int_property("qkv_bias", qkv_bias); - li->add_int_property("final_bias", final_bias); - li->add_int_property("add_zero_attn", add_zero_attn); - li->add_float_property("dropout", dropout); - li->add_int_property("apply_rotary_embedding", apply_rotary_embedding); - li->add_int_property("scaling_query", scaling_query); - li->add_float_property("scaling_factor", scaling_factor); - li->add_int_property("qk_prod_scaling", qk_prod_scaling); - li->add_int_property("position_bias", position_bias); - layers.push_back(li); - return li->outputs[0]; -} - -Op *SpecInferIncMultiHeadSelfAttention::create_operator_from_layer( - FFModel &model, - Layer const *layer, - std::vector const &inputs) { - - std::cout << "spec create operator: " << layer->name << "\n"; - long long value; - layer->get_int_property("embed_dim", value); - int embed_dim = value; - layer->get_int_property("num_q_heads", value); - int num_q_heads = value; - layer->get_int_property("num_kv_heads", value); - int num_kv_heads = value; - layer->get_int_property("kdim", value); - int kdim = value; - layer->get_int_property("vdim", value); - int vdim = value; - float dropout; - layer->get_float_property("dropout", dropout); - layer->get_int_property("qkv_bias", value); - bool qkv_bias = (bool)value; - layer->get_int_property("final_bias", value); - bool final_bias = (bool)value; - layer->get_int_property("add_zero_attn", value); - bool add_zero_attn = (bool)value; - layer->get_int_property("apply_rotary_embedding", value); - bool apply_rotary_embedding = (bool)value; - layer->get_int_property("scaling_query", value); - bool scaling_query = (bool)value; - float scaling_factor; - layer->get_float_property("scaling_factor", scaling_factor); - layer->get_int_property("qk_prod_scaling", value); - bool qk_prod_scaling = (bool)value; - layer->get_int_property("position_bias", value); - bool position_bias = (bool)value; - - return new SpecInferIncMultiHeadSelfAttention(model, - layer->layer_guid, - inputs[0], - embed_dim, - num_q_heads, - num_kv_heads, - kdim, - vdim, - dropout, - qkv_bias, - final_bias, - add_zero_attn, - apply_rotary_embedding, - scaling_query, - scaling_factor, - qk_prod_scaling, - position_bias, - false /*allocate_weights*/, - layer->name); -} - -SpecInferIncMultiHeadSelfAttention::SpecInferIncMultiHeadSelfAttention( - FFModel &model, - LayerID const &_layer_guid, - ParallelTensor const _input, - int _embed_dim, - int _num_q_heads, - int _num_kv_heads, - int _kdim, - int _vdim, - float _dropout, - bool _qkv_bias, - bool _final_bias, - bool _add_zero_attn, - bool _apply_rotary_embedding, - 
bool _scaling_query, - float _scaling_factor, - bool _qk_prod_scaling, - bool _position_bias, - bool allocate_weights, - char const *name) - // Initializer* _bias_initializer) - : Op(model, - OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION, - _input->data_type, - name, - 1 /*inputs*/, - (_qkv_bias || _final_bias ? 2 : 1) /*weights*/, - 1 /*outputs*/, - _input), - num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), - qkv_bias(_qkv_bias), final_bias(_final_bias), - add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), - qSize(_input->dims[0].size), kSize(_input->dims[0].size), - vSize(_input->dims[0].size), qProjSize(_kdim), kProjSize(_kdim), - vProjSize(_vdim), oProjSize(_embed_dim), - qoSeqLength(_input->dims[1].size), kvSeqLength(_input->dims[1].size), - scaling_query(_scaling_query), scaling_factor(_scaling_factor), - qk_prod_scaling(_qk_prod_scaling), position_bias(_position_bias) { - // overwrite layer_guid - layer_guid = _layer_guid; - - numOutputs = 1; - int numdim = _input->num_dims; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdim; i++) { - dims[i] = _input->dims[i]; - } - dims[0].size = _embed_dim; - // Currently require no parallelism along this dim - assert(dims[0].degree == 1); - if (allocate_weights) { - // Create weight tensor - int num_dims = inputs[0]->num_dims; - // Compute weight size - int qParas = this->qProjSize * this->qSize; - int kParas = this->kProjSize * this->kSize; - int vParas = this->vProjSize * this->vSize; - int oParas = - this->oProjSize * (this->vProjSize > 0 ? this->vProjSize : this->vSize); - ParallelDim dims[2]; - dims[0] = inputs[0]->dims[num_dims - 2]; - dims[0].size = dims[0].degree; - dims[1] = inputs[0]->dims[num_dims - 1]; - dims[1].size = this->num_q_heads * (qParas + oParas) + - this->num_q_heads * (kParas + vParas); - dims[1].is_replica_dim = false; - int seed = std::rand(); - Initializer *initializer = new GlorotUniform(seed); - weights[0] = model.create_parallel_weight<2>(dims, - this->data_type, - NULL /*owner_op*/, - true /*create_grad*/, - initializer, - CHOSEN_SYNC_TYPE); - if (qkv_bias || final_bias) { - ParallelTensorShape bias_shape = _input->get_shape(); - int qkv_bias_size = - qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads; - bias_shape.dims[0].size = - (qkv_bias ? qkv_bias_size : 0) + (final_bias ? 
oProjSize : 0); - bias_shape.dims[1].size = bias_shape.dims[2].size = 1; - weights[1] = - model.create_parallel_weight_legion_ordering(bias_shape.num_dims, - bias_shape.dims, - this->data_type, - nullptr /*owner_op*/, - true /*create_grad*/, - initializer, - CHOSEN_SYNC_TYPE); - } - } - - outputs[0] = model.create_parallel_tensor_legion_ordering( - _input->num_dims, dims, this->data_type, this); - /* for (int i = 0; i < numdim; i++) { */ - /* register_output_input_parallel_dims(outputs[0], i, inputs[0], i); */ - /* } */ - /* // Check correctness */ - /* assert(check_output_input_weight_parallel_dims()); */ -} - -SpecInferIncMultiHeadSelfAttention::SpecInferIncMultiHeadSelfAttention( - FFModel &model, - ParallelTensor const _input, - ParallelTensor const _weight, - int _embed_dim, - int _num_q_heads, - int _num_kv_heads, - int _kdim, - int _vdim, - float _dropout, - bool _qkv_bias, - bool _final_bias, - bool _add_zero_attn, - bool _apply_rotary_embedding, - bool _scaling_query, - float _scaling_factor, - bool _qk_prod_scaling, - bool _position_bias, - bool allocate_weights, - char const *name) - // Initializer* _bias_initializer) - : Op(model, - OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION, - _input->data_type, - name, - 1 /*inputs*/, - (_qkv_bias || _final_bias ? 2 : 1) /*weights*/, - 1 /*outputs*/, - _input, - _weight), - num_q_heads(_num_q_heads), num_kv_heads(_num_kv_heads), dropout(_dropout), - qkv_bias(_qkv_bias), final_bias(_final_bias), - add_zero_attn(_add_zero_attn), - apply_rotary_embedding(_apply_rotary_embedding), - qSize(_input->dims[0].size), kSize(_input->dims[0].size), - vSize(_input->dims[0].size), qProjSize(_kdim), kProjSize(_kdim), - vProjSize(_vdim), oProjSize(_embed_dim), - qoSeqLength(_input->dims[1].size), kvSeqLength(_input->dims[1].size), - scaling_query(_scaling_query), scaling_factor(_scaling_factor), - qk_prod_scaling(_qk_prod_scaling), position_bias(_position_bias) -// bias_initializer(_bias_initializer) -{ - numOutputs = 1; - int numdim = _input->num_dims; - ParallelDim dims[MAX_TENSOR_DIM]; - for (int i = 0; i < numdim; i++) { - dims[i] = _input->dims[i]; - } - dims[0].size = _embed_dim; - // Currently require no parallelism along this dim - assert(dims[0].degree == 1); - if (allocate_weights) { - // Create weight tensor - int num_dims = inputs[0]->num_dims; - // Compute weight size - int qParas = this->qProjSize * this->qSize; - int kParas = this->kProjSize * this->kSize; - int vParas = this->vProjSize * this->vSize; - int oParas = - this->oProjSize * (this->vProjSize > 0 ? this->vProjSize : this->vSize); - ParallelDim dims[2]; - dims[0] = inputs[0]->dims[num_dims - 2]; - dims[0].size = dims[0].degree; - dims[1] = inputs[0]->dims[num_dims - 1]; - dims[1].size = this->num_q_heads * (qParas + oParas) + - this->num_q_heads * (kParas + vParas); - dims[1].is_replica_dim = false; - // dims[2].size = qParas + kParas + vParas + oParas; - int seed = std::rand(); - Initializer *initializer = new GlorotUniform(seed); - weights[0] = model.create_parallel_weight<2>(dims, - this->data_type, - NULL /*owner_op*/, - true /*create_grad*/, - initializer, - CHOSEN_SYNC_TYPE); - if (qkv_bias || final_bias) { - ParallelTensorShape bias_shape = _input->get_shape(); - int qkv_bias_size = - qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads; - bias_shape.dims[0].size = - (qkv_bias ? qkv_bias_size : 0) + (final_bias ? 
oProjSize : 0); - bias_shape.dims[1].size = bias_shape.dims[2].size = 1; - weights[1] = - model.create_parallel_weight_legion_ordering(bias_shape.num_dims, - bias_shape.dims, - this->data_type, - nullptr /*owner_op*/, - true /*create_grad*/, - initializer, - CHOSEN_SYNC_TYPE); - } - } - - outputs[0] = model.create_parallel_tensor_legion_ordering( - _input->num_dims, dims, this->data_type, this); - - /* for (int i = 0; i < numdim; i++) { */ - /* register_output_input_parallel_dims(outputs[0], i, inputs[0], i); */ - /* } */ - /* register_output_weight_parallel_dims(outputs[0], numdim-1, _weight, 1); */ - /* register_output_weight_parallel_dims(outputs[0], numdim-2, _weight, 2); */ - // Check correctness - /* assert(check_output_input_weight_parallel_dims()); */ -} - -SpecInferIncMultiHeadSelfAttention::SpecInferIncMultiHeadSelfAttention( - FFModel &model, - SpecInferIncMultiHeadSelfAttention const &other, - ParallelTensor const input, - bool allocate_weights) - : SpecInferIncMultiHeadSelfAttention(model, - other.layer_guid, - input, - other.oProjSize, - other.num_q_heads, - other.num_kv_heads, - other.qProjSize, - other.vProjSize, - other.dropout, - other.qkv_bias, - other.final_bias, - other.add_zero_attn, - other.apply_rotary_embedding, - other.scaling_query, - other.scaling_factor, - other.qk_prod_scaling, - other.position_bias, - allocate_weights, - other.name) {} - -SpecInferIncMultiHeadSelfAttention::SpecInferIncMultiHeadSelfAttention( - FFModel &model, - SpecInferIncMultiHeadSelfAttentionParams const ¶ms, - ParallelTensor const &input, - bool allocate_weights, - char const *name) - : SpecInferIncMultiHeadSelfAttention(model, - params.layer_guid, - input, - params.embed_dim, - params.num_q_heads, - params.num_kv_heads, - params.kdim, - params.vdim, - params.dropout, - params.qkv_bias, - params.final_bias, - params.add_zero_attn, - params.apply_rotary_embedding, - params.scaling_query, - params.scaling_factor, - params.qk_prod_scaling, - params.position_bias, - allocate_weights, - name) {} - -void SpecInferIncMultiHeadSelfAttention::init_inference( - FFModel const &ff, - std::vector const &batch_inputs, - std::vector const &batch_outputs, - MachineView const *mv) { - assert(check_output_input_weight_same_parallel_is()); - parallel_is = batch_outputs[0]->parallel_is; - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - MachineView const *view = mv ? 
mv : &batch_outputs[0]->machine_view; - size_t machine_view_hash = view->hash(); - set_argumentmap_for_init_inference(ff, argmap, batch_outputs[0]); - IndexLauncher launcher( - SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_INIT_TASK_ID, - parallel_is, - TaskArgument(this, sizeof(SpecInferIncMultiHeadSelfAttention)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - machine_view_hash); - launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - batch_inputs[0]->region)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(weights[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - weights[0]->region)); - launcher.add_field(1, FID_DATA); - launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - batch_outputs[0]->region)); - launcher.add_field(2, FID_DATA); - FutureMap fm = runtime->execute_index_space(ctx, launcher); - fm.wait_all_results(); - set_opmeta_from_futuremap_inference(ff, fm, batch_outputs[0]); -} - -void SpecInferIncMultiHeadSelfAttention::init(FFModel const &ff) { - assert(check_output_input_weight_same_parallel_is()); - parallel_is = outputs[0]->parallel_is; - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - set_argumentmap_for_init(ff, argmap); - IndexLauncher launcher( - SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_INIT_TASK_ID, - parallel_is, - TaskArgument(this, sizeof(SpecInferIncMultiHeadSelfAttention)), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - outputs[0]->machine_view.hash()); - launcher.add_region_requirement(RegionRequirement(inputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - inputs[0]->region)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(weights[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - weights[0]->region)); - launcher.add_field(1, FID_DATA); - launcher.add_region_requirement(RegionRequirement(outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - outputs[0]->region)); - launcher.add_field(2, FID_DATA); - FutureMap fm = runtime->execute_index_space(ctx, launcher); - fm.wait_all_results(); - set_opmeta_from_futuremap(ff, fm); -} - -/* - regions[0](I): input - regions[1](I): weight - regions[2](O): output -*/ -OpMeta *SpecInferIncMultiHeadSelfAttention::init_task( - Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - SpecInferIncMultiHeadSelfAttention const *attn = - (SpecInferIncMultiHeadSelfAttention *)task->args; - FFHandler handle = *((FFHandler const *)task->local_args); - - GenericTensorAccessorR input = - helperGetGenericTensorAccessorRO(attn->inputs[0]->data_type, - regions[0], - task->regions[0], - FID_DATA, - ctx, - runtime); - GenericTensorAccessorR weight = - helperGetGenericTensorAccessorRO(attn->weights[0]->data_type, - regions[1], - task->regions[1], - FID_DATA, - ctx, - runtime); - GenericTensorAccessorW output = - helperGetGenericTensorAccessorWO(attn->outputs[0]->data_type, - regions[2], - task->regions[2], - FID_DATA, - ctx, - runtime); - - int num_samples = input.domain.hi()[2] - input.domain.lo()[2] + 1; - assert(attn->qoSeqLength == input.domain.hi()[1] - input.domain.lo()[1] + 1); - assert(attn->kvSeqLength == input.domain.hi()[1] - input.domain.lo()[1] + 1); - int num_q_heads = attn->num_q_heads; - int num_kv_heads = attn->num_kv_heads; - 
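As a side note on the sizing logic in the constructors above: every query head carries q, k, v and output projections against the same input dimension, so the flat weight tensor holds num_q_heads * (qParas + oParas) + num_q_heads * (kParas + vParas) elements, and the optional bias holds qProjSize * num_q_heads + (kProjSize + vProjSize) * num_q_heads entries plus oProjSize for the final bias. A standalone sketch with an illustrative 768-dim, 12-head configuration (the numbers are assumed for the example only):

#include <cassert>
#include <cstdio>

// Mirrors the weight/bias sizing of the SpecInfer attention constructors;
// here qSize == kSize == vSize == in_dim and kdim == vdim == proj_dim.
int attention_weight_count(int embed_dim, int num_q_heads, int proj_dim,
                           int in_dim) {
  int qParas = proj_dim * in_dim;
  int kParas = proj_dim * in_dim;
  int vParas = proj_dim * in_dim;
  int oParas = embed_dim * (proj_dim > 0 ? proj_dim : in_dim);
  return num_q_heads * (qParas + oParas) + num_q_heads * (kParas + vParas);
}

int attention_bias_count(int embed_dim, int num_q_heads, int proj_dim,
                         bool qkv_bias, bool final_bias) {
  int qkv_bias_size =
      proj_dim * num_q_heads + (proj_dim + proj_dim) * num_q_heads;
  return (qkv_bias ? qkv_bias_size : 0) + (final_bias ? embed_dim : 0);
}

int main() {
  // Illustrative small-model shape: embed_dim 768, 12 heads, 64-dim heads.
  int w = attention_weight_count(768, 12, 64, 768);
  int b = attention_bias_count(768, 12, 64, true, true);
  assert(w == 4 * 768 * 768); // q, k, v and o projections over all heads
  printf("weight elements: %d, bias elements: %d\n", w, b);
  return 0;
}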
assert(attn->oProjSize == output.domain.hi()[0] - output.domain.lo()[0] + 1); - - Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) - .only_kind(Memory::GPU_FB_MEM) - .best_affinity_to(task->target_proc) - .first(); - MemoryAllocator gpu_mem_allocator(gpu_mem); - // We don't do offloading for SSMs (small speculative models) - SpecInferIncMultiHeadSelfAttentionMeta *m = - new SpecInferIncMultiHeadSelfAttentionMeta(handle, - attn, - weight, - gpu_mem_allocator, - num_samples, - num_q_heads, - num_kv_heads); - // assert that we didn't over allocate memory - assert(gpu_mem_allocator.instance_allocated_size == - gpu_mem_allocator.instance_total_size); - m->profiling = attn->profiling; - m->inference_debugging = attn->inference_debugging; - std::strcpy(m->op_name, attn->name); - m->layer_guid = attn->layer_guid; - assert(weight.domain.get_volume() * data_type_size(weight.data_type) == - m->weightSize); - return m; -} - -void SpecInferIncMultiHeadSelfAttention::forward(FFModel const &ff) { - // SpecInferIncMultiHeadSelfAttention doesn't support forward - assert(false); -} - -FutureMap SpecInferIncMultiHeadSelfAttention::inference( - FFModel const &ff, - BatchConfigFuture const &bc, - std::vector const &batch_inputs, - std::vector const &batch_outputs, - MachineView const *mv) { - ArgumentMap argmap; - Context ctx = ff.config.lg_ctx; - Runtime *runtime = ff.config.lg_hlr; - parallel_is = batch_outputs[0]->parallel_is; - MachineView const *view = mv ? mv : &batch_outputs[0]->machine_view; - set_argumentmap_for_inference(ff, argmap, batch_outputs[0]); - size_t machine_view_hash = view->hash(); - int idx = 0; - IndexLauncher launcher(SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_INF_TASK_ID, - parallel_is, - TaskArgument(nullptr, 0), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - machine_view_hash); - launcher.add_future(bc); - launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - batch_inputs[0]->region)); - launcher.add_field(idx++, FID_DATA); - launcher.add_region_requirement(RegionRequirement(weights[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - weights[0]->region)); - launcher.add_field(idx++, FID_DATA); - launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - batch_outputs[0]->region)); - launcher.add_field(idx++, FID_DATA); - - if (qkv_bias || final_bias) { - launcher.add_region_requirement(RegionRequirement(weights[1]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - weights[1]->region)); - launcher.add_field(idx++, FID_DATA); - } - return runtime->execute_index_space(ctx, launcher); -} - -/* - regions[0](I): input - regions[3](I): weight - regions[4](O): output -*/ -void SpecInferIncMultiHeadSelfAttention::inference_task( - Task const *task, - std::vector const ®ions, - Context ctx, - Runtime *runtime) { - assert(task->regions.size() == regions.size()); - - BeamSearchBatchConfig const &bc = - Future(task->futures[0]).get_result(); - if (bc.num_tokens == 0) { - return; - } - - SpecInferIncMultiHeadSelfAttentionMeta *m = - *((SpecInferIncMultiHeadSelfAttentionMeta **)task->local_args); - assert(((*m->qkv_bias || *m->final_bias) ? 
regions.size() == 4 - : regions.size() == 3)); - - GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( - m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); - GenericTensorAccessorR weight = helperGetGenericTensorAccessorRO( - m->weight_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); - GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( - m->output_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); - GenericTensorAccessorR biases; - if (*m->qkv_bias || *m->final_bias) { - biases = helperGetGenericTensorAccessorRO(m->weight_type[1], - regions[3], - task->regions[3], - FID_DATA, - ctx, - runtime); - Domain bias_domain = runtime->get_index_space_domain( - ctx, task->regions[3].region.get_index_space()); - assert(bias_domain.get_dim() == 4); - } - Domain input_domain = runtime->get_index_space_domain( - ctx, task->regions[0].region.get_index_space()); - Domain weight_domain = runtime->get_index_space_domain( - ctx, task->regions[1].region.get_index_space()); - Domain output_domain = runtime->get_index_space_domain( - ctx, task->regions[2].region.get_index_space()); - - assert(input_domain.get_dim() == 4); - assert(weight_domain.get_dim() == 2); - assert(output_domain.get_dim() == 4); - - assert(task->index_point.get_dim() == 1); - SpecInferIncMultiHeadSelfAttention::inference_kernel_wrapper( - m, &bc, task->index_point.point_data[0], input, weight, output, biases); - if (m->inference_debugging) { - assert(task->index_point.get_dim() == 1); - int shard_id = task->index_point.point_data[0]; - std::vector weights_accessors; - weights_accessors.push_back(weight); - if (*m->qkv_bias || *m->final_bias) { - weights_accessors.push_back(biases); - } - SpecInferIncMultiHeadSelfAttention::save_inference_tensors_to_file( - m, shard_id, &bc, {input}, weights_accessors, {output}); - } -} - -void SpecInferIncMultiHeadSelfAttention::backward(FFModel const &ff) { - // SpecInferIncMultiHeadSelfAttention does not support backward - assert(false); -} - -bool SpecInferIncMultiHeadSelfAttention::get_int_parameter(PMParameter para, - int *value) const { - switch (para) { - case PM_NUM_HEADS: - *value = num_q_heads; - return true; - default: - return Op::get_int_parameter(para, value); - } -} - -Op *SpecInferIncMultiHeadSelfAttention::materialize(FFModel &ff, - ParallelTensor inputs[], - int num_inputs) const { - SpecInferIncMultiHeadSelfAttentionParams params = get_params(); - return new SpecInferIncMultiHeadSelfAttention( - ff, params, inputs[0], true, this->name); -} - -bool SpecInferIncMultiHeadSelfAttention::measure_operator_cost( - Simulator *sim, MachineView const &mv, CostMetrics &cost_metrics) const { - return false; -} - -bool operator==(SpecInferIncMultiHeadSelfAttentionParams const &lhs, - SpecInferIncMultiHeadSelfAttentionParams const &rhs) { - return lhs.layer_guid == rhs.layer_guid && lhs.embed_dim == rhs.embed_dim && - lhs.num_q_heads == rhs.num_q_heads && lhs.kdim == rhs.kdim && - lhs.vdim == rhs.vdim && lhs.dropout == rhs.dropout && - lhs.qkv_bias == rhs.qkv_bias && lhs.final_bias == rhs.final_bias && - lhs.add_zero_attn == rhs.add_zero_attn && - lhs.apply_rotary_embedding == rhs.apply_rotary_embedding && - lhs.scaling_query == rhs.scaling_query && - lhs.scaling_factor == rhs.scaling_factor && - lhs.qk_prod_scaling == rhs.qk_prod_scaling && - lhs.position_bias == rhs.position_bias; -} - -SpecInferIncMultiHeadSelfAttentionParams - SpecInferIncMultiHeadSelfAttention::get_params() const { - 
SpecInferIncMultiHeadSelfAttentionParams params; - params.layer_guid = this->layer_guid; - params.embed_dim = this->oProjSize; - params.num_q_heads = this->num_q_heads; - params.num_kv_heads = this->num_kv_heads; - params.kdim = this->kProjSize; - params.vdim = this->vProjSize; - params.dropout = this->dropout; - params.qkv_bias = this->qkv_bias; - params.final_bias = this->final_bias; - params.add_zero_attn = this->add_zero_attn; - params.apply_rotary_embedding = this->apply_rotary_embedding; - params.scaling_query = this->scaling_query; - params.scaling_factor = this->scaling_factor; - params.qk_prod_scaling = this->qk_prod_scaling; - params.position_bias = this->position_bias; - - return params; -} - -}; // namespace FlexFlow - -namespace std { -size_t hash::operator()( - FlexFlow::SpecInferIncMultiHeadSelfAttentionParams const ¶ms) const { - size_t key = 0; - hash_combine(key, params.layer_guid.id); - hash_combine(key, params.embed_dim); - hash_combine(key, params.num_q_heads); - hash_combine(key, params.num_kv_heads); - hash_combine(key, params.kdim); - hash_combine(key, params.vdim); - hash_combine(key, params.dropout); - hash_combine(key, params.qkv_bias); - hash_combine(key, params.final_bias); - hash_combine(key, params.add_zero_attn); - hash_combine(key, params.apply_rotary_embedding); - hash_combine(key, params.scaling_query); - hash_combine(key, params.scaling_factor); - hash_combine(key, params.qk_prod_scaling); - hash_combine(key, params.position_bias); - return key; -} -}; // namespace std diff --git a/src/ops/specinfer_inc_multihead_self_attention.cu b/src/ops/specinfer_inc_multihead_self_attention.cu deleted file mode 100644 index 8340519ff3..0000000000 --- a/src/ops/specinfer_inc_multihead_self_attention.cu +++ /dev/null @@ -1,958 +0,0 @@ -/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA) -#include "cuComplex.h" -#endif -#include "flexflow/ffconst_utils.h" -#include "flexflow/ops/kernels/inc_multihead_self_attention_kernels.h" -#include "flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh" -#include "flexflow/ops/specinfer_inc_multihead_self_attention.h" -#include "flexflow/utils/cuda_helper.h" - -namespace FlexFlow { - -#define WARP_SIZE 32 - -// declare Legion names -using Legion::coord_t; -using Legion::Memory; -using namespace Kernels::IncMultiHeadAttention; - -namespace Kernels { -namespace SpecInferIncMultiHeadAttention { - -template -__global__ void compute_specinfer_attention_kernel_generation_kernel( - DT const *query, - DT const *key_cache, - DT const *value_cache, - DT *output_ptr, - float const scale, - int const max_seq_length, - int per_head_size, - int hidden_size, - BatchConfig::PerRequestInfo *request_infos, - BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos, - BatchConfig::BitMask *causalMask) { - - // q, k - using Q_vec = typename VEC_K::Type; - using K_vec = typename VEC_K::Type; - using V_vec = typename VEC_V
::Type; - using Out_sum = typename Vec_fp32_::Type; - - constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; - - constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(DT); - constexpr int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY; - constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; - // constexpr int QK_ELTS_IN_16B = 16 / sizeof(DT); - - // thread id - int const tidx = threadIdx.x; - // head id - int const head_idx = blockIdx.x; - // nth request idx - int const request_idx = blockIdx.y; - - // request id in batch config - int const batch_config_request_id = - request_infos[request_idx].batch_config_request_id; - - // request_idx = re - - BatchConfig::BitMask bitmask = causalMask[batch_config_request_id]; - - int const first_step = 0; - - int const tlength = - request_infos[batch_config_request_id].first_token_depth_in_request + - request_infos[batch_config_request_id].num_tokens_in_batch; - - if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { - printf("specinfer attn fused kernel!!!\n"); - } - - int const totalCacheSize = bitmask.non_tree_cache_size + bitmask.tree_size; - - if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { - printf("specinfer attn fused kernel %d, %d\n", - totalCacheSize, - request_infos[batch_config_request_id].num_tokens_in_batch); - } - // int const qlength = request_infos[request_idx].num_tokens_in_batch; - int const tree_branch_num = - beam_request_infos[batch_config_request_id].sub_request_num; - - // will decode qlength tokens in this thread block - // int const qlength = tree_branch_num; - - int first_token_idx = 0; - for (int r = 0; r < request_idx; r++) { - first_token_idx += causalMask[r].this_layer_size; - } - - // if (tidx == 0 && head_idx == 0) { - // printf("spec req: %d, %d\n", request_idx, first_token_idx); - // } - - // shared memory objects - extern __shared__ char smem_[]; - - float *qk_smem = reinterpret_cast(smem_); - float *out_smem = reinterpret_cast(smem_); - - float qk_max = -FLT_MAX; - - // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum - __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - - const DT *q_ptr = query + first_token_idx * hidden_size * QKV_WEIGHT_NUM + - head_idx * per_head_size; - __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD]; - - // the start offset of the element eg. (0, 1, 2, 3) * K_VEC_SIZE - int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; - int ki_o = tidx % THREADS_PER_KEY; - // the first key's offset for this thread - // ko = 0, 0, 0, 0, 1, 1, 1, 1, .... - int ko = tidx / THREADS_PER_KEY; - // load q tensor - Q_vec q_vec[K_VECS_PER_THREAD]; - - constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; - // The number of keys per warp. 
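To make the index arithmetic above concrete: threads are grouped THREADS_PER_KEY at a time, each group walks the key cache in strides of K_PER_ITER timesteps, and each thread of a group loads K_VECS_PER_THREAD vectors of its key. A host-side sketch with one assumed instantiation (128 threads per block, 4 threads per key, 16-byte vectors of half, Dh = 128); the actual values are picked by the launch macro below:

#include <cstdio>

int main() {
  const int THREADS_PER_BLOCK = 128, THREADS_PER_KEY = 4, WARP_SIZE = 32;
  const int Dh = 128;
  const int K_VEC_SIZE = 8; // 16-byte vector / sizeof(half)
  const int K_ELTS_PER_THREAD = Dh / THREADS_PER_KEY;           // 32
  const int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; // 4
  const int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;   // 32
  const int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;           // 8

  for (int tidx = 0; tidx < 8; tidx++) {
    int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; // element offset in the key
    int ko = tidx / THREADS_PER_KEY;              // first key timestep
    printf("tidx %d: element offset %d, timesteps %d, %d, %d, ... "
           "(%d vectors per key)\n",
           tidx, ki, ko, ko + K_PER_ITER, ko + 2 * K_PER_ITER,
           K_VECS_PER_THREAD);
  }
  printf("keys handled per warp per iteration: %d\n", K_PER_WARP);
  return 0;
}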
- constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; - - DT const *k_cache_batch = - key_cache + batch_config_request_id * max_seq_length * hidden_size + ki; - - int ti_end = - div_up(totalCacheSize - first_step, K_PER_WARP) * K_PER_WARP + first_step; - - for (int qi = 0; qi < tree_branch_num; qi += 1) { -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - q_vecs[ki_o][ii] = *reinterpret_cast( - q_ptr + (hidden_size * QKV_WEIGHT_NUM * qi) + ki + - ii * THREADS_PER_KEY * K_VEC_SIZE); - } - - int const query_token = bitmask.tree_size - tree_branch_num + qi; - - __syncthreads(); - for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { - K_vec k[K_VECS_PER_THREAD]; - int const ti_circ = ti % max_seq_length; - - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * THREADS_PER_KEY * K_VEC_SIZE; - if (ti < totalCacheSize) { - - k[ii] = *reinterpret_cast( - k_cache_batch + ti_circ * hidden_size + head_idx * per_head_size + - jj); - } - } - float qk = scale * Qk_dot::dot(q_vecs[ki_o], k); - - if (ti < totalCacheSize && tidx % THREADS_PER_KEY == 0) { - // todo add alobi here - // bool const mask = ti_circ >= totalCacheSize; - bool const mask = (ti >= bitmask.non_tree_cache_size && - (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & - (1 << query_token)))); - - // if (blockIdx.y == 0 && blockIdx.x == 0 && !mask) { - // printf("spec inc attn qkqkqk %d, %.10f, %d\n", ti, qk, qi); - // } - qk_max = mask ? qk_max : fmaxf(qk_max, qk); - qk_smem[ti - first_step] = mask ? 0.f : qk; - } - } - - __syncthreads(); - -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Decompose the thread index into warp and lane. - int const warp = tidx / WARP_SIZE; - int const lane = tidx % WARP_SIZE; - - // The warp leader writes the max to shared memory. - if (lane == 0) { - red_smem[warp] = qk_max; - } - - // Make sure the products are in shared memory. - __syncthreads(); - - // The warps finalize the reduction. - qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Broadcast to all the threads in the warp. - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - - // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { - // printf("spec inc attn first token qk_max %.10f\n", qk_max); - // } - - float exp_sum = 0.f; - for (int ti = first_step + tidx; ti < totalCacheSize; - ti += THREADS_PER_BLOCK) { - bool const mask = (ti >= bitmask.non_tree_cache_size && - (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & - (1 << query_token)))); - float logit = mask ? 0.0f : __expf(qk_smem[ti - first_step] - qk_max); - exp_sum += logit; - qk_smem[ti - first_step] = mask ? 0.0f : logit; - } - - // Compute the sum. 
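The mask test used in the score loop above (and again in the value loop further down) encodes the speculative tree: positions in the committed, non-tree part of the cache are always visible, while a tree position is visible to the current query only if that query's bit is set in the position's mask row (query_token is bitmask.tree_size - tree_branch_num + qi). A CPU reference of the per-query masked softmax, assuming a 64-bit mask word; a sketch for comparison against the kernel, not a drop-in replacement:

#include <cmath>
#include <cstdint>
#include <vector>

std::vector<float> masked_softmax(std::vector<float> const &scores,
                                  std::vector<uint64_t> const &tree_mask,
                                  int non_tree_cache_size,
                                  int query_token) {
  int total = static_cast<int>(scores.size());
  auto visible = [&](int ti) {
    if (ti < non_tree_cache_size) {
      return true; // committed prefix is always attended to
    }
    return ((tree_mask[ti - non_tree_cache_size] >> query_token) & 1) != 0;
  };
  float qk_max = -INFINITY;
  for (int ti = 0; ti < total; ti++) {
    if (visible(ti)) {
      qk_max = std::fmax(qk_max, scores[ti]);
    }
  }
  float exp_sum = 0.f;
  std::vector<float> probs(total, 0.f);
  for (int ti = 0; ti < total; ti++) {
    if (visible(ti)) {
      probs[ti] = std::exp(scores[ti] - qk_max);
      exp_sum += probs[ti];
    }
  }
  float inv_sum = 1.f / (exp_sum + 1e-6f); // same epsilon as the kernel
  for (int ti = 0; ti < total; ti++) {
    probs[ti] *= inv_sum;
  }
  return probs;
}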
- exp_sum = block_sum(&red_smem[WARPS_PER_BLOCK], exp_sum); - - // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { - // printf("spec inc attn exp_sum %.10f\n", exp_sum); - // } - - // softmax - float inv_sum = __fdividef(1.f, exp_sum + 1.e-6); - for (int ti = first_step + tidx; ti < totalCacheSize; - ti += THREADS_PER_BLOCK) { - qk_smem[ti - first_step] *= inv_sum; - } - - __syncthreads(); - // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { - // printf("softmax %.10f\n", qk_smem[0]); - // } - - // value projection - constexpr int V_VEC_SIZE = 16 / sizeof(DT); - // A vector of V elements for the current timestep. - // using V_vec_k = typename V_vec_k_::Type; - // using V_vec_acum = typename V_vec_acum_fp32_::Type; - - // The value computed by this thread. - int vo = tidx / THREADS_PER_VALUE; - // The hidden dimensions computed by this particular thread. - int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; - constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; - - Out_sum out; - zero(out); - - // The base pointer for the value in the cache buffer. - DT const *v_cache_batch = - value_cache + batch_config_request_id * max_seq_length * hidden_size + - vi; - - if (Dh == Dh_MAX || vi < Dh) { - for (int ti = first_step + vo; ti < totalCacheSize; ti += V_PER_ITER) { - // Load the values from the cache. - int const ti_circ = ti % max_seq_length; - V_vec v = *reinterpret_cast( - v_cache_batch + ti_circ * hidden_size + head_idx * per_head_size); - - bool const mask = (ti >= bitmask.non_tree_cache_size && - (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & - (1 << query_token)))); - float logit = mask ? 0.0f : qk_smem[ti - first_step]; - out = FlexFlow::fma(logit, cast_to_float(v), out); - } - } - - // // Make sure we can start writing to shared memory. - __syncthreads(); - - // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { - // printf("valueX %.10f\n", out.x); - // } - - // Run the final reduction amongst the different groups computing different - // partial outputs. - if (Dh == Dh_MAX || vi < Dh) { -#pragma unroll - for (int active_groups = V_PER_ITER; active_groups >= 2; - active_groups /= 2) { - - // The midpoint in the number of active groups. - int midpoint = active_groups / 2; - - // The upper part of active threads store to shared memory. - if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { - *reinterpret_cast(out_smem + (vo - midpoint) * Dh + vi) = - out; - } - __syncthreads(); - - // The bottom warps update their values. - if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { - out = add(*reinterpret_cast(out_smem + vo * Dh + vi), - out); - } - __syncthreads(); - } - } - - // Output the final values. 
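The loop above that repeatedly halves active_groups is a tree reduction over the V_PER_ITER partial value sums: the upper half of the groups stages its vectors in shared memory, the lower half adds them in, and __syncthreads() separates the two steps of each round. A scalar sketch of the same pattern (assumes a power-of-two group count, as the kernel does):

#include <vector>

float reduce_partial_outputs(std::vector<float> partials) {
  if (partials.empty()) {
    return 0.f;
  }
  int active_groups = static_cast<int>(partials.size()); // V_PER_ITER
  while (active_groups >= 2) {
    int midpoint = active_groups / 2;
    // Upper half folds into the lower half; in the kernel this is the
    // shared-memory store followed by the add after a barrier.
    for (int vo = midpoint; vo < active_groups; vo++) {
      partials[vo - midpoint] += partials[vo];
    }
    active_groups = midpoint;
  }
  return partials[0];
}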
- if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { - convert_from_float(*reinterpret_cast( - output_ptr + (first_token_idx + qi) * hidden_size + - head_idx * per_head_size + vi), - out); - } - } -} - -template -__global__ void specinfer_store_kv_cache( - DT const *devQKVProjArray, - DT *kCache_ptr, - DT *vCache_ptr, - BatchConfig::PerTokenInfo *tokenInfos, - BatchConfig::PerRequestInfo *requestInfo, - BeamSearchBatchConfig::BeamSearchPerTokenInfo *beamTokenInfos, - BeamSearchBatchConfig::BeamSearchPerRequestInfo *beamRequestInfos, - BatchConfig::BitMask *causalMask, - int qProjSize, - int kProjSize, - int vProjSize, - int num_tokens, - int max_seq_len, - bool is_root, - int hidden_size) { - CUDA_KERNEL_LOOP(i, num_tokens * hidden_size) { - int token_idx = i / (hidden_size); - int offset = i % hidden_size; - - size_t val_idx = - token_idx * QKV_WEIGHT_NUM * hidden_size + hidden_size + offset; - - DT kVal = devQKVProjArray[val_idx]; - DT vVal = devQKVProjArray[val_idx + hidden_size]; - - int const req_id = tokenInfos[token_idx].request_index; - int const tok_id = tokenInfos[token_idx].abs_depth_in_request; - int const first_token_in_req = - requestInfo[req_id].first_token_depth_in_request; - int const sub_req_id = beamTokenInfos[token_idx].sub_request_index; - int const total_token = requestInfo[req_id].num_tokens_in_batch; - - int const request_token_offset = - requestInfo[req_id].first_token_offset_in_batch; - - BatchConfig::BitMask bitmask = causalMask[req_id]; - - int const sub_request_num = beamRequestInfos[req_id].sub_request_num; - - int const tree_branch_num = beamRequestInfos[req_id].sub_request_num; - - // int const query_token = bitmask.non_tree_cache_size + bitmask.tree_size - - // tree_branch_num + sub_req_id + tok_id; - // bitmask.tree_size - tree_branch_num + sub_req_id; - - // if prompt token -> token id - // if tree token: - int const cache_idx = bitmask.non_tree_cache_size + bitmask.tree_size - - bitmask.this_layer_size + token_idx - - request_token_offset; - - kCache_ptr[req_id * (hidden_size * max_seq_len) + (cache_idx)*hidden_size + - offset] = kVal; - vCache_ptr[req_id * (hidden_size * max_seq_len) + (cache_idx)*hidden_size + - offset] = vVal; - } -} - -template -void update_kv_cache_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, - BeamSearchBatchConfig const *bc, - cudaStream_t stream) { - int num_tokens = bc->num_active_tokens(); - int curr_depth = bc->beamRequestsInfo[0].current_depth; - // printf("curr depth: %d\n", curr_depth); - // assert(curr_depth < 3); - if (num_tokens > 0) { - int parallelism = m->hidden_size * KV_WEIGHT_NUM * num_tokens; - // printf("tokenInfo %d, %d\n", - // bc->beamTokenInfo[0].sub_request_index, - // num_tokens); - specinfer_store_kv_cache<<>>( - static_cast
<DT *>(m->devQKVProjArray), - static_cast<DT *>
(m->keyCache), - static_cast<DT *>
(m->valueCache), - m->token_infos, - m->request_infos, - m->beam_token_infos, - m->beam_request_infos, - m->causalMask, - m->qProjSize, - m->kProjSize, - m->vProjSize, - num_tokens, - BatchConfig::max_sequence_length() + - BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, - /*root*/ curr_depth == 0, - m->hidden_size); - } -} - -#define LAUNCH_SPECINFER_ATTENTION_SCORE_KERNEL( \ - DT, Dh, Dh_MAX, THDS_PER_KEY, THREADS_PER_VALUE, THDS_PER_BLOCK, stream) \ - smem_sz = smem_size_in_bytes
(m->qProjSize, \ - BatchConfig::max_sequence_length() + \ - BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, \ - THREADS_PER_VALUE, \ - THDS_PER_BLOCK); \ - compute_specinfer_attention_kernel_generation_kernel \ - <<>>( \ - static_cast
<DT *>(m->devQKVProjArray), \ - static_cast<DT *>
(m->keyCache), \ - static_cast<DT *>
(m->valueCache), \ - output_ptr, \ - scale, \ - BatchConfig::max_sequence_length() + \ - BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, \ - m->qProjSize, \ - m->hidden_size, \ - m->request_infos, \ - m->beam_request_infos, \ - m->causalMask) - -template -void compute_specinfer_attention_kernel_generation( - SpecInferIncMultiHeadSelfAttentionMeta const *m, - BeamSearchBatchConfig const *bc, - DT *output_ptr, - cudaStream_t stream) { - // one block == one head per request - printf("??? at here: %d\n", bc->num_active_requests()); - dim3 grid(m->num_q_heads, bc->num_active_requests()); - int const per_head_size = m->qProjSize; - float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f; - size_t smem_sz; - if (per_head_size == 64) { - constexpr int THREADS_PER_VALUE_64 = threads_per_value_t::value; - LAUNCH_SPECINFER_ATTENTION_SCORE_KERNEL( - DT, 64, 64, 4, THREADS_PER_VALUE_64, 128, stream); - } else if (per_head_size == 128) { - constexpr int THREADS_PER_VALUE_128 = threads_per_value_t::value; - LAUNCH_SPECINFER_ATTENTION_SCORE_KERNEL( - DT, 128, 128, 4, THREADS_PER_VALUE_128, 128, stream); - } else { - assert(false && "a unsupported head size"); - } -} - -template -__global__ void spec_fill_entries_above_diagonal(DT *matrix, - size_t new_tokens, - size_t total_tokens_in_request, - size_t num_q_heads, - DT value) { - CUDA_KERNEL_LOOP(i, new_tokens * total_tokens_in_request * num_q_heads) { - // size_t head_idx = i / (new_tokens * total_tokens_in_request); - size_t src_idx = (i / new_tokens) % total_tokens_in_request; - size_t dst_idx = i % new_tokens + total_tokens_in_request - new_tokens; - // Casual Mask - if (src_idx > dst_idx) { - matrix[i] = value; - } - } -} - -template -void compute_attention_kernel_prompt( - SpecInferIncMultiHeadSelfAttentionMeta const *m, - BeamSearchBatchConfig const *bc, - int shard_id, - DT *output_ptr, - DT const *bias_ptr, - DT const *weight_ptr, - cudaStream_t stream) { - checkCUDA(cublasSetStream(m->handle.blas, stream)); - checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); - cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); - cudnnDataType_t cudnn_data_type = ff_to_cudnn_datatype(m->output_type[0]); - assert(data_type_size(m->output_type[0]) == sizeof(DT)); -#if defined(CUDA_VERSION) && (CUDA_VERSION < 11000) - cudaDataType_t compute_type = cublas_data_type; -#else - // For best performance, set the default cublas compute type to - // CUBLAS_COMPUTE_16F for half precision and to - // CUBLAS_COMPUTE_32F_FAST_16F for full precision - cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; - if (m->output_type[0] == DT_FLOAT) { - compute_type = CUBLAS_COMPUTE_32F_FAST_16F; - } -#endif - // int num_requests = bc->num_active_requests(); - int num_tokens = bc->num_active_tokens(); - int tokens_previous_requests = 0; - int tokens_prev_requests_squares = 0; - // int qkv_block_size = - // (m->qProjSize + m->kProjSize + m->vProjSize) * num_tokens; - int q_block_size = m->qProjSize; - - int kt_block_size = m->kProjSize; - int kt_req_block_size = kt_block_size * m->num_q_heads * - (BatchConfig::max_sequence_length() + - BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); - int vt_block_size = m->vProjSize; - int vt_req_block_size = vt_block_size * m->num_q_heads * - (BatchConfig::max_sequence_length() + - BatchConfig::MAX_SPEC_TREE_TOKEN_NUM); - assert(m->qProjSize == m->kProjSize); - - for (int i = 0; i < bc->max_requests_per_batch(); i++) { - if (bc->request_completed[i]) { - continue; - } - // else if (tokens_previous_requests < 
bc->num_generation_tokens) { - // tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; - // continue; - // } - - // all requests in prompt phase should only have one sub requests; - assert(bc->sub_requests[i] == 1); - // int num_new_tokens = bc->num_processing_tokens[i]; - // int total_tokens = bc->token_last_available_idx[i] + 1; - - int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch; - int total_tokens = bc->requestsInfo[i].first_token_depth_in_request + - bc->requestsInfo[i].num_tokens_in_batch; - - if (num_new_tokens <= 0) { - continue; - } - - // Compute (QK^T/sqrt(d_k)) - int m_ = num_new_tokens; - int n = total_tokens; - int k = m->qProjSize; - int lda = k * m->num_q_heads * QKV_WEIGHT_NUM, ldb = k * m->num_q_heads, - ldc = m_; - int strideA = q_block_size; - int strideB = kt_block_size; - int strideC = num_new_tokens * total_tokens; - - // a flag of using this scaling alpha - DT alpha = 1.0f, beta = 0.0f; - if (*m->qk_prod_scaling) { - alpha = static_cast
<DT>(1.0f / sqrt(m->kProjSize)); - } - // To get A, skip over Q entries from previous requests (same head) - DT const *A = static_cast<DT *>
(m->devQKVProjArray) + - bc->requestsInfo[i].first_token_offset_in_batch * - m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM; - // To get B, skip over K entries from previous requests (all heads + - // padding) - - // print_tensor((float*)A, 32, "A"); - std::cout << "meta: " << num_new_tokens << ", " << total_tokens << "\n"; - DT const *B = static_cast
(m->keyCache) + i * kt_req_block_size; - - // if (i == 0 && sub_req_id == 0 && - // bc->beam_slots.at(0).current_depth == 1) { - // int offset = (float *)B - m->keyCache; - // printf("key cache offset %d\n", kt_req_block_size); - // } - // To get C, skip over QK^T products from previous requests - DT *C = static_cast
(m->qk_prods) + - m->num_q_heads * tokens_prev_requests_squares; - checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, - CUBLAS_OP_T, - CUBLAS_OP_N, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - // print_tensor((float*)C, 32, "C"); - // add alibi position bias to qk production - // add alibi position bias to qk production - if (*m->position_bias) { - size_t parallelism = m->num_q_heads * total_tokens * num_new_tokens; - apply_position_bias_qkprd<<>>(C, - num_new_tokens, - total_tokens, - m->num_q_heads, - m->global_num_q_heads, - shard_id); - } - // Fill all elements above diagonal in qk prods with -inf to force - // causal attention. - assert(num_new_tokens <= total_tokens); - if (num_new_tokens > 1) { - size_t parallelism = m->num_q_heads * num_new_tokens * total_tokens; - spec_fill_entries_above_diagonal<<>>(C, - num_new_tokens, - total_tokens, - m->num_q_heads, - static_cast
(-INFINITY)); - } - // Compute Softmax(QK^T/sqrt(d_k)) - // Before modifying the parameters below, make sure to read the following - // description of the CUDNN_TENSOR_NCHW tensor layout, from - // https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnTensorFormat_t: - // This tensor format specifies that the data is laid out in the following - // order: batch size, feature maps, rows, columns. The strides are - // implicitly defined in such a way that the data are contiguous in memory - // with no padding between images, feature maps, rows, and columns; the - // columns are the inner dimension and the images are the outermost - // dimension. - int n_param = m->num_q_heads; - int c_param = total_tokens; - int h_param = 1; - int w_param = num_new_tokens; - checkCUDNN(cudnnSetTensor4dDescriptor(m->qk_tensor, - CUDNN_TENSOR_NCHW, - cudnn_data_type, - n_param, - c_param, - h_param, - w_param)); - float softmax_alpha = 1.0f, softmax_beta = 0.0f; - DT *C_softmax = static_cast
(m->qk_prods_softmax) + - m->num_q_heads * tokens_prev_requests_squares; - // The softmax operation below is executed according to the - // CUDNN_SOFTMAX_MODE_CHANNEL, which is also described in the docs: The - // softmax operation is computed per spatial location (H,W) per image (N) - // across dimension C. - checkCUDNN(cudnnSoftmaxForward(m->handle.dnn, - CUDNN_SOFTMAX_ACCURATE, - CUDNN_SOFTMAX_MODE_CHANNEL, - &softmax_alpha, - m->qk_tensor, - C, - &softmax_beta, - m->qk_tensor, - C_softmax)); - // Matmul softmax(QK^T/sqrt(d_k)) by V - alpha = 1.0f, beta = 0.0f; - m_ = m->vProjSize; - n = num_new_tokens; - k = total_tokens; - lda = m_ * m->num_q_heads, ldb = n, ldc = m_ * m->num_q_heads; - strideA = vt_block_size; - strideB = num_new_tokens * total_tokens; - strideC = m->vProjSize; - // To get A, skip over V^T entries from previous requests (all heads + - // padding) - A = static_cast
<DT *>(m->valueCache) + i * vt_req_block_size; - // To get B, skip over softmax(QK^T/sqrt(d_k)) entries from previous - // requests (all heads) - B = C_softmax; - // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous - // requests - - // print_tensor((float*)C_softmax, 32, "C_softmax"); - C = static_cast<DT *>
(m->attn_heads) + - (tokens_previous_requests + bc->num_generation_tokens) * - m->num_q_heads * m->vProjSize; - checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, - CUBLAS_OP_N, - CUBLAS_OP_T, - m_, - n, - k, - &alpha, - A, - cublas_data_type, - lda, - strideA, - B, - cublas_data_type, - ldb, - strideB, - &beta, - C, - cublas_data_type, - ldc, - strideC, - m->num_q_heads, - compute_type, - CUBLAS_GEMM_DEFAULT_TENSOR_OP)); - - tokens_previous_requests += num_new_tokens; - tokens_prev_requests_squares += num_new_tokens * total_tokens; - } - - // assert(tokens_previous_requests == num_tokens); -} - -template -void inference_kernel(SpecInferIncMultiHeadSelfAttentionMeta const *m, - BeamSearchBatchConfig const *bc, - int shard_id, - DT const *input_ptr, - DT const *weight_ptr, - DT *output_ptr, - DT const *bias_ptr, - cudaStream_t stream) { - // phase 1: Implement kernel to compute KQV for input tokens - - compute_qkv_kernel(m, - bc, - shard_id, - input_ptr, - weight_ptr, - static_cast
<DT *>(m->devQKVProjArray), - bias_ptr, - stream); - // phase 2: Update key/val cache - update_kv_cache_kernel<DT>
(m, bc, stream); - // std::cout << "specinfer kernel token num: " << bc->num_generation_tokens - // << ", " << bc->num_tokens << "\n"; - if (bc->num_generation_tokens > 0) { - printf("spec inc generation decoding\n"); - compute_specinfer_attention_kernel_generation
<DT>( - m, bc, static_cast<DT *>
(m->attn_heads), stream); - } - // phase 3: Compute attention score - // 3 kernels for pahse 3: matmul1 - softmax - matmal2 - if (bc->num_tokens > bc->num_generation_tokens) { - printf("spec inc prompt decoding\n"); - compute_attention_kernel_prompt( - m, bc, shard_id, output_ptr, bias_ptr, weight_ptr, stream); - } - // compute_attention_kernel_prompt( - // m, bc, shard_id, output_ptr, bias_ptr, weight_ptr, stream); - - // compute output production and bias together for all tokens - int num_tokens = bc->num_active_tokens(); - - // std::cout << "specinfer num tokens: " << num_tokens; - - compute_o_prod_bias( - m, bc, shard_id, output_ptr, weight_ptr, bias_ptr, num_tokens, stream); -} - -} // namespace SpecInferIncMultiHeadAttention -} // namespace Kernels - -/*static*/ -void SpecInferIncMultiHeadSelfAttention::inference_kernel_wrapper( - SpecInferIncMultiHeadSelfAttentionMeta const *m, - BeamSearchBatchConfig const *bc, - int shard_id, - GenericTensorAccessorR const &input, - GenericTensorAccessorR const &weight, - GenericTensorAccessorW const &output, - GenericTensorAccessorR const &bias) { - cudaStream_t stream; - checkCUDA(get_legion_stream(&stream)); - bool use_bias = *m->qkv_bias || *m->final_bias; - - cudaEvent_t t_start, t_end; - if (m->profiling) { - cudaEventCreate(&t_start); - cudaEventCreate(&t_end); - cudaEventRecord(t_start, stream); - } - - assert(input.data_type == weight.data_type); - assert(input.data_type == output.data_type); - if (use_bias) { - assert(input.data_type == bias.data_type); - } - - if (input.data_type == DT_HALF) { - half const *bias_ptr = - use_bias ? bias.get_half_ptr() : static_cast(nullptr); - Kernels::SpecInferIncMultiHeadAttention::inference_kernel( - m, - bc, - shard_id, - input.get_half_ptr(), - weight.get_half_ptr(), - output.get_half_ptr(), - bias_ptr, - stream); - } else if (input.data_type == DT_FLOAT) { - float const *bias_ptr = - use_bias ? 
bias.get_float_ptr() : static_cast(nullptr); - Kernels::SpecInferIncMultiHeadAttention::inference_kernel( - m, - bc, - shard_id, - input.get_float_ptr(), - weight.get_float_ptr(), - output.get_float_ptr(), - bias_ptr, - stream); - } else { - assert(false && "Unspported data type"); - } - - if (m->profiling) { - cudaEventRecord(t_end, stream); - checkCUDA(cudaEventSynchronize(t_end)); - float elapsed = 0; - checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); - cudaEventDestroy(t_start); - cudaEventDestroy(t_end); - printf("SpecInferIncMultiHeadSelfAttention forward time = %.2fms\n", - elapsed); - // print_tensor<3, float>(acc_query.ptr, acc_query.rect, - // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr, - // acc_output.rect, "[Attention:forward:output]"); - } - // save_tensor(output.get_float_ptr(), 768 * 3, - // "/home/xinhaoc/FlexFlow/inference/output/fk1.txt"); - // save_tensor(output.get_float_ptr() + 768 * 3, 768 * 3, - // "/home/xinhaoc/FlexFlow/inference/output/fk2.txt"); - - // if(bc->num_tokens == 1){ - // print_tensor(input.get_float_ptr(), 32, "specinc input"); - // print_tensor(output.get_float_ptr(), 32, "specinc output"); - // assert(false); - // } -} - -SpecInferIncMultiHeadSelfAttentionMeta::SpecInferIncMultiHeadSelfAttentionMeta( - FFHandler handler, - SpecInferIncMultiHeadSelfAttention const *attn, - GenericTensorAccessorR const &weight, - MemoryAllocator &gpu_mem_allocator, - int num_samples, - int _num_q_heads, - int _num_kv_heads) - : IncMultiHeadSelfAttentionMeta(handler, - BEAM_SEARCH_MODE, - attn, - attn->qSize, - attn->kSize, - attn->vSize, - attn->qProjSize, - attn->kProjSize, - attn->vProjSize, - attn->oProjSize, - attn->apply_rotary_embedding, - attn->qkv_bias, - attn->scaling_query, - attn->qk_prod_scaling, - attn->position_bias, - attn->final_bias, - attn->scaling_factor, - weight, - gpu_mem_allocator, - num_samples, - attn->num_q_heads, - attn->num_kv_heads, - _num_q_heads, - _num_kv_heads, - DT_NONE, - false) { - cudaStream_t stream; - checkCUDA(get_legion_stream(&stream)); - checkCUDNN(cudnnSetStream(handler.dnn, stream)); - - // allocate memory for the seqArray and reserve space - { - // size_t causal_mask_size = BatchConfig::MAX_NUM_REQUESTS; - // size_t total_size = causal_mask_size * sizeof(BatchConfig::BitMask); - // gpu_mem_allocator.create_legion_instance(beam_search_reserve_inst, - // total_size); - - beam_token_infos = - static_cast( - handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + - sizeof(BatchConfig::requestsInfo)); - - beam_request_infos = - static_cast( - handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + - sizeof(BatchConfig::requestsInfo) + - sizeof(BeamSearchBatchConfig::beamTokenInfo)); - causalMask = static_cast( - handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + - sizeof(BatchConfig::requestsInfo) + - sizeof(BeamSearchBatchConfig::beamTokenInfo) + - sizeof(BeamSearchBatchConfig::beamRequestsInfo)); - - // causalMask = gpu_mem_allocator.allocate_instance( - // causal_mask_size); - // beam_token_infos = - // gpu_mem_allocator - // .allocate_instance( - // beam_tokeninfo_size); - // offset += beam_tokeninfo_size * - // sizeof(BeamSearchBatchConfig::BeamSearchPerTokenInfo); - // beam_request_infos = - // gpu_mem_allocator - // .allocate_instance( - // beam_requestinfo_size); - // offset += beam_requestinfo_size * - // sizeof(BeamSearchBatchConfig::BeamSearchPerRequestInfo); - // assert(offset == total_size); - // assert(gpu_mem_allocator.instance_total_size == - 
// gpu_mem_allocator.instance_allocated_size); - } - - cudaStreamSynchronize(stream); -} - -SpecInferIncMultiHeadSelfAttentionMeta::~SpecInferIncMultiHeadSelfAttentionMeta( - void) { - if (beam_search_reserve_inst != Realm::RegionInstance::NO_INST) { - beam_search_reserve_inst.destroy(); - } -} - -}; // namespace FlexFlow diff --git a/src/ops/tree attn kernel, 0----> -0.029753357172 b/src/ops/tree attn kernel, 0----> -0.029753357172 deleted file mode 100644 index e4f14ee757..0000000000 --- a/src/ops/tree attn kernel, 0----> -0.029753357172 +++ /dev/null @@ -1 +0,0 @@ -tree attn kernel, 0----> -0.02975335717201232910 0.01930358447134494781 0.03780741989612579346 0.11878532171249389648 -0.03523746877908706665 0.02421043440699577332 0.03719477355480194092 -0.00304851122200489044 0.02062662504613399506 0.06683708727359771729 -0.00642335414886474609 -0.00504039414227008820 0.02955199964344501495 0.00648811273276805878 0.00558663159608840942 0.02003456838428974152 -0.04041406139731407166 0.00736814411357045174 -0.04575226455926895142 0.03949077427387237549 0.05742383748292922974 0.04866250604391098022 0.04687267541885375977 -0.00701304525136947632 -0.03712264448404312134 -0.02175992354750633240 -0.03979443758726119995 0.03961737453937530518 -0.07450901716947555542 0.02090370282530784607 -0.03487894684076309204 0.01653470844030380249 \ No newline at end of file diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index a4329f52db..5c6527baf9 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -81,30 +81,22 @@ __global__ void compute_attention_kernel_fused_kernel( int const first_step = 0; - int const tlength = request_infos[batch_config_request_id].first_token_depth_in_request + - request_infos[batch_config_request_id].num_tokens_in_batch; - int const qlength = request_infos[batch_config_request_id].num_tokens_in_batch; + int const tlength = + request_infos[batch_config_request_id].first_token_depth_in_request + + request_infos[batch_config_request_id].num_tokens_in_batch; + int const qlength = + request_infos[batch_config_request_id].num_tokens_in_batch; BatchConfig::BitMask bitmask = causalMask[batch_config_request_id]; - // bitmask.mask[1] = 3; - // if (head_idx == 0 && tidx == 0) { - // printf("tree attn fused kernel req id %d %d, %d, %d, %lld\n", - // request_idx, - // tlength, - // qlength, - // bitmask.non_tree_cache_size, - // bitmask.mask[3]); - // } - int first_token_idx = 0; for (int r = 0; r < request_idx; r++) { first_token_idx += request_infos[r].num_tokens_in_batch; } - if(tidx == 0 && head_idx == 0){ - printf("tree req: %d, %d\n", request_idx, first_token_idx); - } + // if(tidx == 0 && head_idx == 0){ + // printf("tree req: %d, %d\n", request_idx, first_token_idx); + // } // shared memory objects extern __shared__ char smem_[]; @@ -174,26 +166,11 @@ __global__ void compute_attention_kernel_fused_kernel( (ti >= bitmask.non_tree_cache_size && (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); - // if (head_idx == 0 && qi == 9 && mask) { - // printf("tree attn mask for first token %d, %lld, %d, %d, %d\n", - // ti, - // bitmask.mask[ti - bitmask.non_tree_cache_size], - // bitmask.non_tree_cache_size, - // request_idx, - // qi); - // } - // if (blockIdx.y == 0 && blockIdx.x == 0 && qi == 3 && mask) { - // printf("tree attn mask for third token %d, %lld, %d, %d\n", - // ti, - // bitmask.mask[ti - bitmask.non_tree_cache_size], - // bitmask.non_tree_cache_size, - 
// qi); - // } - qk_max = mask ? qk_max : fmaxf(qk_max, qk); // if (head_idx == 0 && qi == 0 && !mask) { - // printf("tree attn qkqkqkqk request id %d, %d %.10f, %.10f, %.10f\n ", + // printf("tree attn qkqkqkqk request id %d, %d %.10f, %.10f, %.10f\n + // ", // request_idx, // ti, // qk, @@ -250,10 +227,6 @@ __global__ void compute_attention_kernel_fused_kernel( // Compute the sum. exp_sum = block_sum(&red_smem[WARPS_PER_BLOCK], exp_sum); - // if (head_idx == 0 && tidx == 0 && qi == 9) { - // printf("expsum %.10f\n", exp_sum); - // } - // softmax float inv_sum = __fdividef(1.f, exp_sum + 1.e-6); for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { @@ -261,9 +234,6 @@ __global__ void compute_attention_kernel_fused_kernel( } __syncthreads(); - // if (head_idx == 0 && tidx == 0 && qi == 9) { - // printf("softmax %.10f\n", qk_smem[1]); - // } // value projection constexpr int V_VEC_SIZE = 16 / sizeof(DT); @@ -282,12 +252,8 @@ __global__ void compute_attention_kernel_fused_kernel( // The base pointer for the value in the cache buffer. DT const *v_cache_batch = - value_cache + batch_config_request_id * max_seq_length * hidden_size + vi; - // DT const *v_cache_batch = - // value_cache + - // (beam_request_idx * max_beam_width + beam_sub_request_idx) * - // max_seq_length * hidden_size + - // vi; + value_cache + batch_config_request_id * max_seq_length * hidden_size + + vi; if (Dh == Dh_MAX || vi < Dh) { for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { @@ -310,10 +276,6 @@ __global__ void compute_attention_kernel_fused_kernel( // // Make sure we can start writing to shared memory. __syncthreads(); - // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0 && qi == 1) { - // printf("valueX %.10f\n", out.x); - // } - // Run the final reduction amongst the different groups computing different // partial outputs. 
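// The masking and normalization above are easier to follow in scalar form.
// The sketch below is illustrative only (host-side, no warp reductions); it
// assumes a BitMask holding non_tree_cache_size plus one 64-bit mask word per
// speculative-tree position, matching the fields referenced above.

#include <cfloat>
#include <cmath>
#include <cstdint>

// Visibility rule mirrored from the kernel: positions already committed to the
// KV cache (ti < non_tree_cache_size) are always visible; a tree position is
// visible to query qi only if bit qi of its mask word is set.
inline bool is_visible(uint64_t const *mask, int non_tree_cache_size,
                       int ti, int qi) {
  return ti < non_tree_cache_size ||
         ((mask[ti - non_tree_cache_size] >> qi) & 1);
}

// Masked softmax over one query's scores qk[0..tlength): subtract the running
// max, exponentiate with masked positions contributing zero weight, then apply
// the same 1/(exp_sum + 1e-6) normalization used above.
inline void masked_softmax(float *qk, int tlength, int qi,
                           uint64_t const *mask, int non_tree_cache_size) {
  float qk_max = -FLT_MAX;
  for (int ti = 0; ti < tlength; ti++) {
    if (is_visible(mask, non_tree_cache_size, ti, qi)) {
      qk_max = fmaxf(qk_max, qk[ti]);
    }
  }
  float exp_sum = 0.f;
  for (int ti = 0; ti < tlength; ti++) {
    qk[ti] = is_visible(mask, non_tree_cache_size, ti, qi)
                 ? expf(qk[ti] - qk_max)
                 : 0.f;
    exp_sum += qk[ti];
  }
  float inv_sum = 1.f / (exp_sum + 1e-6f);
  for (int ti = 0; ti < tlength; ti++) {
    qk[ti] *= inv_sum;
  }
}

// The kernel itself resumes below with the cross-group reduction of the
// partial value outputs.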
if (Dh == Dh_MAX || vi < Dh) { @@ -391,19 +353,6 @@ __global__ void commit_tokens_kernel( int const req_id = committedTokenInfos[token_pos].request_index; int const tok_id = committedTokenInfos[token_pos].token_depth; - // if(i == 0){ - // printf("commit token: %d %d %f\n", token_idx_in_last_batch, tok_id, - // kVal); - // } - // if(i == hidden_size){ - // printf("commit token 1: %d %d %f\n", token_idx_in_last_batch, tok_id, - // kVal); - // } - // if(i == 2 * hidden_size){ - // printf("commit token 2: %d %d %f\n", token_idx_in_last_batch, tok_id, - // kVal); - // } - kCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + offset] = kVal; vCache_ptr[req_id * (hidden_size * max_seq_len) + tok_id * hidden_size + @@ -500,11 +449,13 @@ __global__ void update_tree_branch_kv_cache_fused( int const request_token_offset = request_infos[req_id].first_token_offset_in_batch; - int const first_token_depth = request_infos[req_id].first_token_depth_in_request; + int const first_token_depth = + request_infos[req_id].first_token_depth_in_request; // if(i % hidden_size == 0){ - // printf("update token request id: %d, %d, %d real id %d, value%.10f\n", req_id, - // token_idx, request_token_offset,(token_idx + first_token_depth - request_token_offset), kVal); + // printf("update token request id: %d, %d, %d real id %d, value%.10f\n", + // req_id, token_idx, request_token_offset,(token_idx + first_token_depth + // - request_token_offset), kVal); // } kCache_ptr[req_id * (hidden_size * max_seq_len) + (token_idx + first_token_depth - request_token_offset) * @@ -591,8 +542,6 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, num_new_tokens++; } - std::cout << "num_new_tokens: " << num_new_tokens << "\n"; - int total_tokens_in_request = bc->tokensInfo[j].abs_depth_in_request + 1; assert(num_new_tokens >= 1 && total_tokens_in_request >= num_new_tokens); { @@ -873,12 +822,6 @@ void compute_attention_kernel_fused(TreeIncMultiHeadSelfAttentionMeta const *m, // update K-V cache int num_new_tokens = bc->num_active_tokens(); int parallelism = m->hidden_size * num_new_tokens; - // printf("update KV cache %d, idx: %d\n", - // num_new_tokens, - // bc->requestsInfo[0].first_token_depth_in_request); - // for (int i = 0; i < num_new_tokens; i++) { - // printf("abs depth:%d\n", bc->tokensInfo[i].abs_depth_in_request); - // } update_tree_branch_kv_cache_fused<<bias_ptr, bias_ptr, m->biasSize, cudaMemcpyHostToDevice, stream); bias_ptr = static_cast
(m->bias_ptr); } - // cudaMemcpyAsync(m->token_infos, - // &(bc->tokensInfo), - // bc->num_active_tokens() * - // sizeof(TreeVerifyBatchConfig::PerTokenInfo), - // cudaMemcpyHostToDevice, - // stream); - // cudaMemcpyAsync(m->request_infos, - // &(bc->requestsInfo), - // bc->max_requests_per_batch() * - // sizeof(BatchConfig::PerRequestInfo), - // cudaMemcpyHostToDevice, - // stream); // phase 1: Implement kernel to compute KQV for input tokens compute_qkv_kernel(m, bc, @@ -992,9 +923,6 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, bias_ptr, stream); - // print_tensor((float *)m->devQKVProjArray, 32, "qkvtenor1"); - // print_tensor((float *)m->devQKVProjArray + 768 * (25 * 7) * 3, 32, "qkvtenor2"); - // phase 2: No need to update key/val cache // IncMultiHeadSelfAttention::update_kv_cache_kernel( // m, bc, stream); @@ -1037,8 +965,6 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( cudaEventRecord(t_start, stream); } - std::cout << "tree input tokens: " << bc->num_active_tokens() << "\n"; - // assert(input.data_type == weight.data_type); assert(input.data_type == output.data_type); if (use_bias) { @@ -1089,20 +1015,6 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( cudaEventDestroy(t_start); cudaEventDestroy(t_end); } - - // print_tensor(output.get_float_ptr(), 32, "tree attn kernel"); - - // save_tensor( - // input.get_float_ptr(), - // 768 * bc->num_active_tokens(), - // "/home/xinhaoc/FlexFlow/inference/output/Newtreeinput.txt"); - // save_tensor( - // output.get_float_ptr(), - // 768 * bc->num_active_tokens(), - // "/home/xinhaoc/FlexFlow/inference/output/Newtreeoutput.txt"); - // std::cout << "new tokens: " << bc->num_active_tokens() << "\n"; - - // assert(bc->num_tokens_to_commit == 0); } TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( diff --git a/src/runtime/ffconst_utils.cc b/src/runtime/ffconst_utils.cc index 904bfbcaff..c7b6e1257a 100644 --- a/src/runtime/ffconst_utils.cc +++ b/src/runtime/ffconst_utils.cc @@ -154,8 +154,6 @@ std::string get_operator_type_name(OperatorType type) { return "SpecIncMultiHeadSelfAttention"; case OP_TREE_INC_MULTIHEAD_SELF_ATTENTION: return "TreeIncMultiHeadSelfAttention"; - case OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION: - return "SpecInferPgraoIncMultiHeadSelfAttention"; case OP_INPUT: return "Input"; case OP_WEIGHT: diff --git a/src/runtime/graph.cc b/src/runtime/graph.cc index 46f7cc0f29..6d33dd9f27 100644 --- a/src/runtime/graph.cc +++ b/src/runtime/graph.cc @@ -51,7 +51,6 @@ #include "flexflow/ops/topk.h" #include "flexflow/ops/transpose.h" #include "flexflow/ops/tree_inc_multihead_self_attention.h" -#include "flexflow/ops/specinfer_inc_multihead_self_attention.h" #include "flexflow/parallel_ops/allreduce.h" #include "flexflow/parallel_ops/combine.h" #include "flexflow/parallel_ops/fused_parallel_op.h" @@ -70,7 +69,7 @@ using FlexFlow::MachineView; LegionRuntime::Logger::Category log_graph("graph"); LegionRuntime::Logger::Category log_simplify("graph_simplify"); -Node const Node::INVALID_NODE = Node(); +const Node Node::INVALID_NODE = Node(); Node::Node(void) : guid(0), ptr(NULL) {} @@ -2385,28 +2384,6 @@ GraphOptimalViewSerialized sez.serialize(attn->tensor_parallelism_degree); break; } - case OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION: { - SpecInferIncMultiHeadSelfAttention *attn = - (SpecInferIncMultiHeadSelfAttention *)op; - sez.serialize(attn->layer_guid.id); - sez.serialize(attn->layer_guid.transformer_layer_id); - sez.serialize(attn->layer_guid.model_id); - 
sez.serialize(attn->oProjSize); - sez.serialize(attn->num_q_heads); - sez.serialize(attn->qProjSize); - sez.serialize(attn->vProjSize); - sez.serialize(attn->dropout); - sez.serialize(attn->qkv_bias); - sez.serialize(attn->final_bias); - sez.serialize(attn->add_zero_attn); - sez.serialize(attn->apply_rotary_embedding); - sez.serialize(attn->scaling_query); - sez.serialize(attn->scaling_factor); - sez.serialize(attn->qk_prod_scaling); - sez.serialize(attn->position_bias); - sez.serialize(attn->num_kv_heads); - break; - } case OP_SOFTMAX: { Softmax *softmax = (Softmax *)op; sez.serialize(softmax->dim); @@ -2937,52 +2914,6 @@ void FFModel::deserialize_graph_optimal_view( params); break; } - case OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION: { - assert(num_inputs == 1); - int embed_dim, num_q_heads, k_dim, v_dim, num_kv_heads; - float dropout, scaling_factor; - bool qkv_bias, final_bias, add_zero_attn, apply_rotary_embedding, - scaling_query, qk_prod_scaling, position_bias; - size_t id, transformer_layer_id, deserialized_model_id; - dez.deserialize(id); - dez.deserialize(transformer_layer_id); - dez.deserialize(deserialized_model_id); - LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); - dez.deserialize(embed_dim); - dez.deserialize(num_q_heads); - dez.deserialize(k_dim); - dez.deserialize(v_dim); - dez.deserialize(dropout); - dez.deserialize(qkv_bias); - dez.deserialize(final_bias); - dez.deserialize(add_zero_attn); - dez.deserialize(apply_rotary_embedding); - dez.deserialize(scaling_query); - dez.deserialize(scaling_factor); - dez.deserialize(qk_prod_scaling); - dez.deserialize(position_bias); - dez.deserialize(num_kv_heads); - - SpecInferIncMultiHeadSelfAttentionParams params; - params.embed_dim = embed_dim; - params.num_q_heads = num_q_heads; - params.kdim = k_dim; - params.vdim = v_dim; - params.dropout = dropout; - params.qkv_bias = qkv_bias; - params.final_bias = final_bias; - params.add_zero_attn = add_zero_attn; - params.layer_guid = layer_guid; - params.apply_rotary_embedding = apply_rotary_embedding; - params.scaling_query = scaling_query; - params.scaling_factor = scaling_factor; - params.qk_prod_scaling = qk_prod_scaling; - params.position_bias = position_bias; - params.num_kv_heads = num_kv_heads; - node = get_or_create_node(inputs[0], - params); - break; - } case OP_TOPK: { node = TopK::deserialize(*this, dez, inputs, num_inputs); break; diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index e7f7c5f52d..52a1efc2ab 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -257,7 +257,6 @@ void InferenceManager::init_operators_inference(FFModel *model) { ((ParallelOp *)op) ->create_input_partition_inference(*model, inputs, outputs); } - printf("init op %s\n", op->name); op->init_inference(*model, inputs, outputs); } } @@ -394,14 +393,13 @@ void InferenceManager::load_input_tokens_from_batch_config( } void InferenceManager::load_inference_metadata_batch_config( - BatchConfigFuture const &bc, - FFHandler *handlers) { + BatchConfigFuture const &bc, FFHandler *handlers) { Context ctx = ff_config.lg_ctx; Runtime *runtime = ff_config.lg_hlr; ArgumentMap argmap; - Rect<1> task_rect(Point<1>(0), - Point<1>(ff_config.workersPerNode * ff_config.numNodes - 1)); + Rect<1> task_rect( + Point<1>(0), Point<1>(ff_config.workersPerNode * ff_config.numNodes - 1)); IndexSpaceT<1> task_is = runtime->create_index_space(ctx, task_rect); // int rank = 0; diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 
cf72f2d40b..c3ee73d78c 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -59,7 +59,6 @@ #include "flexflow/ops/sigmoid_silu_multi.h" #include "flexflow/ops/softmax.h" #include "flexflow/ops/spec_inc_multihead_self_attention.h" -#include "flexflow/ops/specinfer_inc_multihead_self_attention.h" #include "flexflow/ops/split.h" #include "flexflow/ops/topk.h" #include "flexflow/ops/transpose.h" @@ -94,10 +93,10 @@ Op::Op(FFModel &model, int numWeights, bool allocate_weights, int numOutputs, - ParallelTensor const input1, - ParallelTensor const input2, - ParallelTensor const input3, - ParallelTensor const input4) + const ParallelTensor input1, + const ParallelTensor input2, + const ParallelTensor input3, + const ParallelTensor input4) : Op(model, otype, dtype, @@ -117,10 +116,10 @@ Op::Op(FFModel &model, int _numInputs, int _numWeights, int _numOutputs, - ParallelTensor const _input1, - ParallelTensor const _input2, - ParallelTensor const _input3, - ParallelTensor const _input4) + const ParallelTensor _input1, + const ParallelTensor _input2, + const ParallelTensor _input3, + const ParallelTensor _input4) : op_type(_otype), data_type(_dtype), op_guid(model.op_global_guid++), numInputs(_numInputs), numWeights(_numWeights), numOutputs(_numOutputs), profiling(model.config.profiling), @@ -1025,9 +1024,9 @@ void Op::register_output_parallel_dims( operation); } -int Op::get_output_to_input_dim_mapping(ParallelTensor const output, +int Op::get_output_to_input_dim_mapping(const ParallelTensor output, int output_dim, - ParallelTensor const input) { + const ParallelTensor input) { int output_idx = -1, input_idx = -1; for (int i = 0; i < numOutputs; i++) { if (output == outputs[i]) { @@ -1060,9 +1059,9 @@ int Op::get_output_to_input_dim_mapping(ParallelTensor const output, return -1; } -int Op::get_output_to_weight_dim_mapping(ParallelTensor const output, +int Op::get_output_to_weight_dim_mapping(const ParallelTensor output, int output_dim, - ParallelTensor const weight) { + const ParallelTensor weight) { int output_idx = -1, weight_idx = -1; for (int i = 0; i < numOutputs; i++) { if (output == outputs[i]) { @@ -1659,7 +1658,7 @@ Tensor FFModel::create_tensor(int numdim, } ParallelTensor FFModel::create_parallel_tensor(int numdim, - ParallelDim const dims[], + const ParallelDim dims[], DataType data_type, Op const *op, int idx, @@ -1692,7 +1691,7 @@ Tensor FFModel::create_tensor_legion_ordering(int numdim, ParallelTensor FFModel::create_parallel_tensor_legion_ordering(int numdim, - ParallelDim const dims[], + const ParallelDim dims[], DataType data_type, Op const *op, int idx, @@ -1742,7 +1741,7 @@ Tensor FFModel::create_tensor(int const dims[], } template -ParallelTensor FFModel::create_parallel_tensor(ParallelDim const dims[], +ParallelTensor FFModel::create_parallel_tensor(const ParallelDim dims[], DataType data_type, Op const *owner_op, int owner_idx, @@ -1823,7 +1822,7 @@ Parameter FFModel::create_weight(int numdim, } template -ParallelParameter FFModel::create_parallel_weight(ParallelDim const dims[], +ParallelParameter FFModel::create_parallel_weight(const ParallelDim dims[], DataType data_type, Op const *owner_op, bool create_grad, @@ -1854,7 +1853,7 @@ ParallelParameter FFModel::create_parallel_weight(ParallelDim const dims[], } ParallelParameter FFModel::create_parallel_weight(int numdim, - ParallelDim const dims[], + const ParallelDim dims[], DataType data_type, Op const *owner_op, bool create_grad, @@ -1874,7 +1873,7 @@ ParallelParameter FFModel::create_parallel_weight(int numdim, 
ParallelParameter FFModel::create_parallel_weight_legion_ordering( int numdim, - ParallelDim const dims[], + const ParallelDim dims[], DataType data_type, Op const *owner_op, bool create_grad, @@ -2088,7 +2087,7 @@ void FFModel::map_weight_with_dim(ParallelTensor weight, } bool FFModel::get_parallel_tensor_from_tensor( - Tensor const tensor, ParallelTensor ¶llel_tensor) const { + const Tensor tensor, ParallelTensor ¶llel_tensor) const { // check if tensor->parallel_tensor is already set if (tensor->parallel_tensor != nullptr) { parallel_tensor = tensor->parallel_tensor; @@ -2125,7 +2124,7 @@ bool FFModel::get_parallel_tensor_from_tensor( } void FFModel::create_disjoint_partition(int num_dims, - ParallelDim const dims[], + const ParallelDim dims[], IndexSpace const &part_is, LogicalRegion const ®ion, LogicalPartition &part) { @@ -2148,7 +2147,7 @@ void FFModel::create_disjoint_partition(int num_dims, template void FFModel::create_disjoint_partition_with_dim2( - ParallelDim const dims[], + const ParallelDim dims[], IndexSpaceT const &part_is, LogicalRegion const ®ion, LogicalPartition &part) { @@ -2181,7 +2180,7 @@ void FFModel::create_disjoint_partition_with_dim2( } void FFModel::create_aliased_partition(int num_dims, - ParallelDim const dims[], + const ParallelDim dims[], int aliased_dim, IndexSpace const &part_is, LogicalRegion const ®ion, @@ -2205,7 +2204,7 @@ void FFModel::create_aliased_partition(int num_dims, template void FFModel::create_aliased_partition_with_dim2( - ParallelDim const dims[], + const ParallelDim dims[], int aliased_dim, IndexSpaceT const &part_is, LogicalRegion const ®ion, @@ -2242,7 +2241,7 @@ void FFModel::create_aliased_partition_with_dim2( } template -void FFModel::create_disjoint_partition(ParallelTensor const tensor, +void FFModel::create_disjoint_partition(const ParallelTensor tensor, IndexSpaceT const &part_is, LogicalPartition &part_fwd, LogicalPartition &part_bwd) { @@ -2290,7 +2289,7 @@ void FFModel::create_disjoint_partition(ParallelTensor const tensor, template void FFModel::create_data_parallel_partition_with_diff_dims( - ParallelTensor const tensor, + const ParallelTensor tensor, IndexSpaceT const &part_is, LogicalPartition &part_fwd, LogicalPartition &part_bwd) { @@ -2672,7 +2671,7 @@ IndexSpace FFModel::get_task_is(ParallelConfig const &pc) const { return get_task_is(view); } -IndexSpace FFModel::get_or_create_task_is(ParallelTensor const tensor) { +IndexSpace FFModel::get_or_create_task_is(const ParallelTensor tensor) { MachineView view; view.ndims = 0; for (int i = 0; i < tensor->num_dims; i++) { @@ -3039,12 +3038,6 @@ Op *FFModel::create_operator_from_layer( operators.push_back(op); return op; } - case OP_SPECINFER_INC_MULTIHEAD_SELF_ATTENTION: { - Op *op = SpecInferIncMultiHeadSelfAttention::create_operator_from_layer( - *this, layer, inputs); - operators.push_back(op); - return op; - } case OP_BATCHMATMUL: { Op *op = BatchMatmul::create_operator_from_layer(*this, layer, inputs); operators.push_back(op); @@ -3234,7 +3227,7 @@ Op *FFModel::create_operator_from_layer( } void FFModel::create_operators_from_layers() { - std::map tensors_to_parallel_tensors; + std::map tensors_to_parallel_tensors; // for (auto const &l : layers) { for (int layer_idx = 0; layer_idx < layers.size(); layer_idx++) { auto const &l = layers[layer_idx]; @@ -3980,38 +3973,38 @@ void FFIterationConfig::reset() { // Default Config Parameters struct DefaultConfig { - static int const epochs = 1; + const static int epochs = 1; // const static int iterations = 1; - static int 
const batchSize = 64; - static bool const profiling = false; - static bool const inference_debugging = false; + const static int batchSize = 64; + const static bool profiling = false; + const static bool inference_debugging = false; constexpr static float learningRate = 0.01f; constexpr static float weightDecay = 0.0001f; - static size_t const workSpaceSize = (size_t)128 * 1024 * 1024; // 128 MB - static int const numNodes = 1; - static int const workersPerNode = 0; - static int const cpusPerNode = 0; - static size_t const searchBudget = -1; - static size_t const simulatorWorkSpaceSize = + const static size_t workSpaceSize = (size_t)128 * 1024 * 1024; // 128 MB + const static int numNodes = 1; + const static int workersPerNode = 0; + const static int cpusPerNode = 0; + const static size_t searchBudget = -1; + const static size_t simulatorWorkSpaceSize = (size_t)2 * 1024 * 1024 * 1024; // 2 GB constexpr static float searchAlpha = 1.2f; - static bool const searchOverlapBackwardUpdate = false; - static size_t const offloadReserveSpaceSize = + const static bool searchOverlapBackwardUpdate = false; + const static size_t offloadReserveSpaceSize = (size_t)8 * 1024 * 1024 * 1024; // 8 GB - static bool const cpuOffload = false; - static bool const onlyDataParallel = true; - static bool const enableSampleParallel = true; - static bool const enableParameterParallel = false; - static bool const enableAttributeParallel = false; - static bool const enableInplaceOptimizations = false; - static bool const allowTensorOpMathConversion = false; - static int const machine_model_version = 0; - static int const simulator_segment_size = 16777216; // 16 MB - static int const simulator_max_num_segments = 1; - static int const base_optimize_threshold = 10; - static bool const enable_control_replication = true; + const static bool cpuOffload = false; + const static bool onlyDataParallel = true; + const static bool enableSampleParallel = true; + const static bool enableParameterParallel = false; + const static bool enableAttributeParallel = false; + const static bool enableInplaceOptimizations = false; + const static bool allowTensorOpMathConversion = false; + const static int machine_model_version = 0; + const static int simulator_segment_size = 16777216; // 16 MB + const static int simulator_max_num_segments = 1; + const static int base_optimize_threshold = 10; + const static bool enable_control_replication = true; // The default python data loader type is 2 to enable control replication - static int const python_data_loader_type = 2; + const static int python_data_loader_type = 2; }; FFConfig::FFConfig() { @@ -6233,44 +6226,6 @@ void register_flexflow_internal_tasks(Runtime *runtime, TreeIncMultiHeadSelfAttention::inference_task>(registrar); } } - { - TaskVariantRegistrar registrar( - SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_INIT_TASK_ID, - "SpecInferIncMultiHeadSelfAttention Init"); - registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); - registrar.set_leaf(); - if (pre_register) { - Runtime::preregister_task_variant< - OpMeta *, - SpecInferIncMultiHeadSelfAttention::init_task>( - registrar, "SpecInferIncMultiHeadSelfAttention Init Task"); - } else { - if (enable_control_replication) { - registrar.global_registration = false; - } - runtime->register_task_variant< - OpMeta *, - SpecInferIncMultiHeadSelfAttention::init_task>(registrar); - } - } - { - TaskVariantRegistrar registrar( - SPECINFER_INC_MULTIHEAD_SELF_ATTENTION_INF_TASK_ID, - "SpecInferIncMultiHeadSelfAttention Inference"); - 
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); - registrar.set_leaf(); - if (pre_register) { - Runtime::preregister_task_variant< - SpecInferIncMultiHeadSelfAttention::inference_task>( - registrar, "SpecInferIncMultiHeadSelfAttention Inference Task"); - } else { - if (enable_control_replication) { - registrar.global_registration = false; - } - runtime->register_task_variant< - SpecInferIncMultiHeadSelfAttention::inference_task>(registrar); - } - } // NoOp { TaskVariantRegistrar registrar(NOOP_INIT_TASK_ID, "Weight NCCL Init"); diff --git a/src/runtime/model.cpp b/src/runtime/model.cpp index b51ab83091..5499a280a8 100644 --- a/src/runtime/model.cpp +++ b/src/runtime/model.cpp @@ -152,7 +152,7 @@ FFHandler .wait(); handle.offload_reserve_space = workspaceInst.pointer_untyped(0, sizeof(char)); - }else { + } else { handle.offload_reserve_space = nullptr; } if (handle.batch_config_metadata_size > 0) { @@ -176,7 +176,7 @@ FFHandler .wait(); handle.batch_config_metadata = workspaceInst.pointer_untyped(0, sizeof(char)); - }else { + } else { handle.batch_config_metadata = nullptr; } // checkCUDA(hipMalloc(&handle.workSpace, handle.workSpaceSize)); diff --git a/src/runtime/model.cu b/src/runtime/model.cu index 523b3c76f3..c885b29db2 100644 --- a/src/runtime/model.cu +++ b/src/runtime/model.cu @@ -148,11 +148,10 @@ FFHandler .wait(); handle.offload_reserve_space = workspaceInst.pointer_untyped(0, sizeof(char)); - }else { + } else { handle.offload_reserve_space = nullptr; } if (handle.batch_config_metadata_size > 0) { - printf("allocate instance for metadata %d\n", handle.batch_config_metadata_size); // allocate memory for offload reserve space Memory gpu_mem = Machine::MemoryQuery(Machine::get_machine()) .only_kind(Memory::GPU_FB_MEM) @@ -173,7 +172,7 @@ FFHandler .wait(); handle.batch_config_metadata = workspaceInst.pointer_untyped(0, sizeof(char)); - }else { + } else { handle.batch_config_metadata = nullptr; } diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index e30a7ee478..89d4ddaed4 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -182,7 +182,7 @@ size_t RequestManager::get_num_ssms() { RequestManager::RequestGuid RequestManager::register_new_request(std::vector const &prompt, int max_sequence_length) { - std::lock_guard const lock(request_queue_mutex); + const std::lock_guard lock(request_queue_mutex); // Add a new request Request request; @@ -238,7 +238,7 @@ RequestManager::RequestGuid RequestManager::RequestGuid RequestManager::register_new_request(std::string const &prompt, int max_sequence_length) { - std::lock_guard const lock(request_queue_mutex); + const std::lock_guard lock(request_queue_mutex); // Add a new request Request request; request.status = Request::PENDING; @@ -296,7 +296,7 @@ RequestManager::RequestGuid } bool RequestManager::is_request_completed(RequestGuid const &guid) { - std::lock_guard const lock(request_queue_mutex); + const std::lock_guard lock(request_queue_mutex); assert(all_requests.find(guid) != all_requests.end()); Request const &request = all_requests[guid]; // return request.tokens.size() >= request.max_sequence_length; @@ -305,7 +305,7 @@ bool RequestManager::is_request_completed(RequestGuid const &guid) { GenerationResult RequestManager::get_generation_result(RequestGuid const &guid) { - std::lock_guard const lock(request_queue_mutex); + const std::lock_guard lock(request_queue_mutex); assert(request_generation_results.find(guid) != request_generation_results.end()); return 
request_generation_results[guid]; @@ -343,7 +343,7 @@ BatchConfig RequestManager::prepare_next_batch_task( BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, InferenceResult const &result) { - std::lock_guard const lock(request_queue_mutex); + const std::lock_guard lock(request_queue_mutex); // Step 1: append result from previous iteration to request's tokens for (int i = 0; i < old_bc.num_tokens; i++) { @@ -456,7 +456,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, new_bc.requestsInfo[i].max_sequence_length = old_bc.requestsInfo[i].max_sequence_length; num_active_req++; - new_bc.requestsInfo[num_active_req].batch_config_request_id = i; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; if (new_bc.requestsInfo[i].first_token_depth_in_request + 1 == request.tokens.size()) { // Incremental phase @@ -504,7 +504,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, new_request.max_sequence_length; new_bc.request_completed[i] = false; num_active_req++; - new_bc.requestsInfo[num_active_req].batch_config_request_id = i; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // add profile_info for the new request ProfileInfo profile_info; profile_info.llm_decoding_steps = 1; @@ -566,7 +566,7 @@ BeamSearchBatchConfig RequestManager::prepare_next_batch_init(TreeVerifyBatchConfig const &old_bc, InferenceResult const &result, int model_id) { - std::lock_guard const lock(request_queue_mutex); + const std::lock_guard lock(request_queue_mutex); if (verbose) { std::cout << "\n############### prepare_next_batch_init ###############\n"; } @@ -603,11 +603,10 @@ BeamSearchBatchConfig } else { committed_tokens[guid].clear(); } - // iterate through all the tokens that belong to request i int root_abs_depth = request.tokens.size() - 1; - + while (result_index < old_bc.num_tokens && old_bc.tokensInfo[result_index].request_index == i) { int abs_depth = old_bc.tokensInfo[result_index].abs_depth_in_request; @@ -640,14 +639,12 @@ BeamSearchBatchConfig } if (request.status == Request::RUNNING) { - std::cout << "verify running: " << dfs_tree_inputs.at(guid).size() << ", " - << tree_outputs.size() << "\n"; std::vector> verified_tokens = traverse_verify_tree(guid, dfs_tree_inputs.at(guid), tree_outputs); log_req_mgr.print("Number of Verified Tokens = %zu", - verified_tokens.size()); + verified_tokens.size()); // check if the request is finished if (verified_tokens.size() + request.tokens.size() >= request.max_sequence_length) { @@ -729,9 +726,6 @@ BeamSearchBatchConfig } else { // Request not finished, pass verified_tokens to next iteration - std::cout << "parse to next iteration: " - << "\n"; - new_bc.request_completed[i] = false; new_bc.request_running[i] = true; num_active_req++; @@ -745,18 +739,13 @@ BeamSearchBatchConfig new_bc.requestsInfo[i].max_sequence_length = old_bc.requestsInfo[i].max_sequence_length; new_bc.requestsInfo[i].num_tokens_in_batch = verified_tokens.size(); - new_bc.requestsInfo[num_active_req].batch_config_request_id = i; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // TODO: Beam Request Info, missing from VerifyTreeBatchConfig int new_max_depth = new_bc.requestsInfo[i].max_sequence_length - new_bc.requestsInfo[i].first_token_depth_in_request - verified_tokens.size(); - // std::cout << "max depth: " << new_max_depth << ", " - // << new_bc.requestsInfo[i].first_token_depth_in_request << - // ", " - // << verified_tokens.size() << "\n"; - // assert(false); 
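// A quick worked example of the budget computed above (illustrative numbers,
// not values from this patch): with max_sequence_length = 1024, a request
// whose first_token_depth_in_request is 100, and 5 newly verified tokens,
// new_max_depth = 1024 - 100 - 5 = 919, i.e. at most 919 further levels of
// speculation fit before the length limit; the depth used per step is
// presumably still bounded elsewhere by MAX_BEAM_DEPTH.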
new_bc.beamRequestsInfo[i].current_depth = 1; profiling_requests[request.guid].ssm_decoding_steps = 0; @@ -794,9 +783,6 @@ BeamSearchBatchConfig // Beam Token Info new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = 0; new_bc.num_tokens++; - // std::cout << "num_gen ++ " - // << "\n"; - // num_generation_tokens++; // Add verified token to request's token list request.tokens.push_back(token.first); @@ -816,11 +802,6 @@ BeamSearchBatchConfig log_req_mgr.print("Output: %s", output.c_str()); } - // if (request.tokens.size() > 19 && i >= 7) { - // std::cout << request.tokens.size() << "\n"; - // assert(false); - // } - } else if (request.status == Request::PENDING) { new_bc.request_completed[i] = false; new_bc.request_running[i] = false; @@ -838,7 +819,7 @@ BeamSearchBatchConfig new_bc.requestsInfo[i].max_sequence_length = old_bc.requestsInfo[i].max_sequence_length; new_bc.requestsInfo[i].num_tokens_in_batch = 0; - new_bc.requestsInfo[num_active_req].batch_config_request_id = i; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // TODO: Beam Request Info, missing from VerifyTreeBatchConfig new_bc.beamRequestsInfo[i].current_depth = 1; @@ -889,7 +870,7 @@ BeamSearchBatchConfig (int)new_request.tokens.size()); new_bc.requestsInfo[i].max_sequence_length = new_request.max_sequence_length; - new_bc.requestsInfo[num_active_req].batch_config_request_id = i; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // add profile_info for the new request ProfileInfo profile_info; @@ -973,17 +954,12 @@ BeamSearchBatchConfig } new_bc.num_generation_tokens = num_generation_tokens; - std::cout << "prepare next batch init gen tokens: " - << new_bc.num_generation_tokens << "\n"; - if (verbose) { std::cout << "prepare_next_batch_init OLD vs NEW batchconfigs below:" << std::endl; old_bc.print(); new_bc.print(); } - std::cout << "prepare next batch init active tokens: " - << new_bc.num_tokens << "\n"; return new_bc; } @@ -1019,11 +995,11 @@ BeamSearchBatchConfig RequestManager::prepare_next_batch_beam_task( BeamSearchBatchConfig RequestManager::prepare_next_batch_beam(BeamSearchBatchConfig const &old_bc, BeamInferenceResult const &result) { - std::lock_guard const lock(request_queue_mutex); - if (true) { + const std::lock_guard lock(request_queue_mutex); + if (verbose) { std::cout << "\n############### prepare_next_batch_beam ###############\n"; } - if (true) { + if (verbose) { std::cout << "print all results" << "\n"; for (int i = 0; i < 40; i++) { @@ -1049,7 +1025,7 @@ BeamSearchBatchConfig if (old_bc.request_completed[i] || !old_bc.request_running[i]) { continue; } - num_active_req ++; + num_active_req++; // Comment out this assertion since num_tokens_in_batch can be // zero when beam search has reached required sequence length // assert(old_bc.requestsInfo[i].num_tokens_in_batch > 0); @@ -1092,13 +1068,6 @@ BeamSearchBatchConfig old_bc.beamRequestsInfo[i].sub_request_num * old_bc.beamRequestsInfo[i].beam_size; - std::cout << "oldbc : " << old_bc.beamRequestsInfo[i].sub_request_num - << ", " << old_bc.beamRequestsInfo[i].beam_size << "\n"; - - // if (old_bc.beamRequestsInfo[i].current_depth == 3) { - // assert(false); - // } - assert(new_bc.beamRequestsInfo[i].sub_request_num <= BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES && "exceed maximum nodes per layer"); @@ -1122,7 +1091,7 @@ BeamSearchBatchConfig request.tokens.size()) { // Incremental phase if (request.status == Request::RUNNING) { - // todo check it + // todo this is replaced by this_layer_size, but should 
check it new_bc.requestsInfo[i].num_tokens_in_batch = 1; } else { assert(false && "Request should be done"); @@ -1150,18 +1119,7 @@ BeamSearchBatchConfig memcpy(&new_bc.causalMask[i], &old_bc.causalMask[i], sizeof(BatchConfig::BitMask)); - // sub_request_num -> nodes of input next iteration - // beam_size replicate num - - std::cout << "print beam tree: " - << old_bc.beamRequestsInfo[i].current_depth << "\n"; BeamTree tree = request.beam_trees[old_bc.model_id]; - // for (int k = 0; k <= old_bc.beamRequestsInfo[i].current_depth; k++) { - // std::cout << "layer: " << k << "\n"; - // std::cout << "nodes: " << tree.treeLayers[k].nodes_num_this_layer - // << "\n"; - // } - std::cout << "append bit mask: "<< i << "\n"; appendBitMask(new_bc.causalMask[i], new_bc.beamRequestsInfo[i].sub_request_num, old_bc.beamRequestsInfo[i].beam_size, @@ -1185,9 +1143,6 @@ BeamSearchBatchConfig num_generation_tokens++; } } - // if(new_bc.beamRequestsInfo[i].current_depth >= 3 && i > 0){ - // assert(false); - // } } } @@ -1320,18 +1275,6 @@ BeamSearchBatchConfig old_bc.print(); new_bc.print(); } - - if (true) { - // std::cout << "print all resultsBBB" - // << "\n"; - // for (int i = 0; i < 40; i++) { - // std::cout << result.token_ids[i] << ", "; - // } - // std::cout << "Current Beam DepthBBB: " - // << old_bc.beamRequestsInfo[0].current_depth << "\n"; - } - std::cout << "prepare next batch beam total tokens: " << new_bc.num_tokens - << "gneration tokens: " << new_bc.num_generation_tokens << "\n"; return new_bc; } @@ -1366,7 +1309,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify_task( TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( std::vector const &old_batches) { - std::lock_guard const lock(request_queue_mutex); + const std::lock_guard lock(request_queue_mutex); std::cout << "\n############### prepare_next_batch_verify ###############\n"; @@ -1399,12 +1342,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( profiling_requests[request.guid].llm_decoding_steps += 1; if (request.status == Request::RUNNING) { - - std::cout << "prepare next batch running:\n" - << "\n"; new_bc.request_running[i] = true; - std::cout << "[Verify] Request " << request.guid << " is running" - << std::endl; // Get the dfs tree std::vector>> @@ -1419,12 +1357,12 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( std::vector> dfs_tree_inputs = merge_dfs_trees(all_dfs_trees, request.tokens.size() - 1, guid); - if (true) { - // std::cout << "Request Tokens Size: " << request.tokens.size() - // << std::endl; - // for (int k = 0; k < request.tokens.size(); k++) { - // std::cout << k << ": " << request.tokens[k] << std::endl; - // } + if (verbose) { + std::cout << "Request Tokens Size: " << request.tokens.size() + << std::endl; + for (int k = 0; k < request.tokens.size(); k++) { + std::cout << k << ": " << request.tokens[k] << std::endl; + } } // Normal Request Info @@ -1435,31 +1373,21 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( old_batches.at(0).requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = old_batches.at(0).requestsInfo[i].max_sequence_length; - new_bc.requestsInfo[num_active_req].batch_config_request_id = i; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // copy bitmask to verify batchconfig memcpy(&(new_bc.causalMask[i]), &(old_batches.at(0).causalMask[i]), sizeof(BatchConfig::BitMask)); - // std::cout << "bitmask: " << new_bc.causalMask[i].mask[0] << "\n"; - // assert(false); // TODO: Check this 
new_bc.requestsInfo[i].num_tokens_in_batch = 0; new_bc.request_completed[i] = false; - std::cout << "dfs_tree_inputs: " << dfs_tree_inputs.size() << ", " - << new_bc.causalMask[i].tree_size << ", " - << new_bc.causalMask[i].non_tree_cache_size << "\n"; - std::cout << "mask: " << std::bitset<64>(new_bc.causalMask[i].mask[0]) - << "\n"; - std::cout << "mask: " << std::bitset<64>(new_bc.causalMask[i].mask[1]) - << "\n"; - std::cout << "mask: " << std::bitset<64>(new_bc.causalMask[i].mask[2]) - << "\n"; - std::cout << "mask: " << std::bitset<64>(new_bc.causalMask[i].mask[3]) - << "\n"; - std::cout << "mask: " << std::bitset<64>(new_bc.causalMask[i].mask[4]) - << "\n"; + // std::cout << "dfs_tree_inputs: " << dfs_tree_inputs.size() << ", " + // << new_bc.causalMask[i].tree_size << ", " + // << new_bc.causalMask[i].non_tree_cache_size << "\n"; + // std::cout << "mask: " << std::bitset<64>(new_bc.causalMask[i].mask[0]) + // << "\n"; // Committed Tokens if (committed_tokens.find(guid) != committed_tokens.end()) { @@ -1473,7 +1401,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( i; new_bc.committed_tokens[new_bc.num_tokens_to_commit].token_depth = committed_token.first; - if (true) { + if (verbose) { std::cout << new_bc.num_tokens_to_commit << "- committed_token.token_depth: " << committed_token.first @@ -1485,7 +1413,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( // } } } - if (true) { + if (verbose) { std::cout << "new_bc.num_tokens_to_commit: " << new_bc.num_tokens_to_commit << std::endl; } @@ -1508,14 +1436,11 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.requestsInfo[i].first_token_depth_in_request = request.tokens.size() - 1; - std::cout << "prepare next batch verify: " << dfs_tree_inputs.size() - << "\n"; - bool cutLayer = false; // Add Tokens from the DFS Tree to the next batch for (int j = 1; j < dfs_tree_inputs.size(); j++) { auto token = dfs_tree_inputs.at(j); - if (true) { + if (verbose) { std::cout << "[" << j << "] Token: " << token.first << ", Depth:" << token.second << std::endl; } @@ -1541,7 +1466,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( for (int j = total_tokens - 1; j >= 1; j--) { new_bc.num_tokens--; new_bc.requestsInfo[i].num_tokens_in_batch--; - std::cout << "cut: " << j << "\n"; + // std::cout << "cut: " << j << "\n"; if (new_bc.tokensInfo[j].abs_depth_in_request != new_bc.tokensInfo[j - 1].abs_depth_in_request) { break; @@ -1550,8 +1475,6 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( } } else if (request.status == Request::PENDING) { - std::cout << "prepare next batch verify: pending\n" - << "\n"; new_bc.request_running[i] = false; if (verbose) { std::cout << "[Verify] Request " << request.guid @@ -1583,8 +1506,6 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( memcpy(&(new_bc.causalMask[i]), &(old_batches.at(0).causalMask[i]), sizeof(BatchConfig::BitMask)); - // std::cout << "bitmask: " << new_bc.causalMask[i].mask[0] << "\n"; - // assert(false); // Normal Request Info new_bc.requestsInfo[i].first_token_depth_in_request = @@ -1594,7 +1515,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( old_batches.at(0).requestsInfo[i].request_guid; new_bc.requestsInfo[i].max_sequence_length = old_batches.at(0).requestsInfo[i].max_sequence_length; - new_bc.requestsInfo[num_active_req].batch_config_request_id = i; + new_bc.requestsInfo[num_active_req].batch_config_request_id = i; new_bc.request_completed[i] = false; @@ -1608,9 
+1529,9 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( << std::endl; if (request.llm_cache_size < request.initial_len) { - std::cout << "Initialization (prompt) phase: " - << new_bc.requestsInfo[i].num_tokens_in_batch << ", " - << old_batches.at(0).beamRequestsInfo[i].beam_size << "\n"; + // std::cout << "Initialization (prompt) phase: " + // << new_bc.requestsInfo[i].num_tokens_in_batch << ", " + // << old_batches.at(0).beamRequestsInfo[i].beam_size << "\n"; // Initialization (prompt) phase for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { new_bc.tokensInfo[new_bc.num_tokens].request_index = i; @@ -1618,8 +1539,6 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( request.tokens[request.llm_cache_size + j]; new_bc.tokensInfo[new_bc.num_tokens].abs_depth_in_request = request.llm_cache_size + j; - std::cout << "load prompt tokens: " << j << ": " - << new_bc.tokensInfo[new_bc.num_tokens].token_id << "\n"; new_bc.num_tokens++; } @@ -1645,8 +1564,8 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( } } else { // launch the request into running phase after loading all prompt if (get_max_tokens_per_batch() - new_bc.num_tokens > 0) { - std::cout << "Initialization running phase: " - << new_bc.requestsInfo[i].num_tokens_in_batch << "\n"; + // std::cout << "Initialization running phase: " + // << new_bc.requestsInfo[i].num_tokens_in_batch << "\n"; request.status = Request::RUNNING; new_bc.request_running[i] = true; @@ -1671,11 +1590,6 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( } } - std::cout << "how many tokens in verify? " << new_bc.num_tokens << "\n"; - - std::cout << "check dfs tree input size: " << dfs_tree_inputs[1000000].size() - << "\n"; - return new_bc; } @@ -1690,7 +1604,7 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, auto start_depth = old_bc.tokensInfo[0].abs_depth_in_request; int result_index = 0; - if (true) { + if (verbose) { std::cout << "Store total of " << old_bc.num_tokens << " tokens in the current batch.\n"; } @@ -1700,10 +1614,10 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, old_bc.requestsInfo[old_bc.tokensInfo[i].request_index].request_guid != guid) { - std::cout << "i is: " << i << "old guid" << guid << " new guid" - << old_bc.requestsInfo[old_bc.tokensInfo[i].request_index] - .request_guid - << "\n"; + // std::cout << "i is: " << i << "old guid" << guid << " new guid" + // << old_bc.requestsInfo[old_bc.tokensInfo[i].request_index] + // .request_guid + // << "\n"; int index = old_bc.tokensInfo[i - 1].request_index; int beam_size = old_bc.beamRequestsInfo[index].beam_size; @@ -1718,22 +1632,16 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, // Count tokens sent to model in this request to find the final token's // index - - std::cout << "previous result index: " << result_index; - result_index += (old_bc.tokensInfo[i - 1].abs_depth_in_request - start_depth) * beam_size; - std::cout << "after result index: " << result_index; - - // if (true) { - // std::cout << "i = " << i << ", result index = " << result_index - // << ", value: " << result.token_ids[result_index] - // << ", leaf node num: " << leaf_node_num << ", depth" << - // depth - // << ", beam size: " << beam_size << "\n"; - // } + if (verbose) { + std::cout << "i = " << i << ", result index = " << result_index + << ", value: " << result.token_ids[result_index] + << ", leaf node num: " << leaf_node_num << ", depth" << 
depth + << ", beam size: " << beam_size << "\n"; + } Request &request = all_requests[old_bc.requestsInfo[index].request_guid]; @@ -1743,7 +1651,7 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, if (depth == 1) { // store the last input into the tree; - if (true) { + if (verbose) { std::cout << "try to store the input" << "\n"; } @@ -1756,13 +1664,11 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, .treeLayers[0] .nodes_num_this_layer = 1; - if (true) { + if (verbose) { std::cout << "Store the previous last token to the tree root: " << request.tokens.back() << "\n"; } } - - std::cout << "leaffffff: " << leaf_node_num << "\n"; request.beam_trees.at(old_bc.model_id) .treeLayers[depth] .nodes_num_this_layer = leaf_node_num; @@ -1777,27 +1683,20 @@ void RequestManager::store_beam_metadata(BeamSearchBatchConfig const &old_bc, request.beam_trees.at(old_bc.model_id) .treeLayers[depth] .parent_ids[beam_id] = result.parent_id[result_index]; - // std::cout << "??????? beam id: " << beam_id << ", token: " - // << request.beam_trees.at(old_bc.model_id) - // .treeLayers[depth] - // .tokens[beam_id] - // << "\n"; - - // if (true) { - // std::cout << "tree value: " << depth << "token: " - // << request.beam_trees.at(old_bc.model_id) - // .treeLayers[depth] - // .tokens[beam_id] - // << "result tokens: " << result.token_ids[result_index]; - // } + + if (verbose) { + std::cout << "tree value: " << depth << "token: " + << request.beam_trees.at(old_bc.model_id) + .treeLayers[depth] + .tokens[beam_id] + << "result tokens: " << result.token_ids[result_index]; + } result_index += 1; } // update the guid and start_depth for current request if (i < old_bc.num_tokens) { int new_req_idx = old_bc.tokensInfo[i].request_index; guid = old_bc.requestsInfo[new_req_idx].request_guid; - std::cout << "update guid: " << guid << ", request idx: " << index - << "\n"; start_depth = old_bc.tokensInfo[i].abs_depth_in_request; } } @@ -1839,8 +1738,8 @@ void RequestManager::update_beam_metadata(BeamSearchBatchConfig &new_bc, tree.treeLayers[depth].probs[j]; new_bc.beamRequestsInfo[request_index].tokens[j] = tree.treeLayers[depth].tokens[j]; - std::cout << "token: " << j << ": " - << new_bc.beamRequestsInfo[request_index].tokens[j] << "\n"; + // std::cout << "token: " << j << ": " + // << new_bc.beamRequestsInfo[request_index].tokens[j] << "\n"; } } if (verbose) { @@ -1892,13 +1791,13 @@ void RequestManager::updateBitMask(BatchConfig::BitMask &bitmask, "do not support tree size > 64"); assert(initLength >= 1 && "verified token num should >= 1"); - std::cout << "non tree size: " << non_tree_size << ", " - << bitmask.non_tree_cache_size << "\n"; + // std::cout << "non tree size: " << non_tree_size << ", " + // << bitmask.non_tree_cache_size << "\n"; bitmask.non_tree_cache_size = non_tree_size + initLength - 1; bitmask.tree_size = 1; bitmask.this_layer_size = initLength; - std::cout << "non_tree_size: " << non_tree_size << "\n"; + // std::cout << "non_tree_size: " << non_tree_size << "\n"; bitmask.prompt_size = 1; for (int i = 0; i < bitmask.prompt_size; i++) { for (int j = i; j < bitmask.prompt_size; j++) { @@ -1906,13 +1805,9 @@ void RequestManager::updateBitMask(BatchConfig::BitMask &bitmask, } } - std::cout << "see bit mask update" << bitmask.prompt_size << "\n"; - std::cout << "see bit mask update" << std::bitset<64>(bitmask.mask[0]) - << "\n"; - std::cout << "see bit mask update" << std::bitset<64>(bitmask.mask[1]) - << "\n"; - std::cout << "see bit mask update" << 
std::bitset<64>(bitmask.mask[2]) - << "\n"; + // std::cout << "see bit mask update" << bitmask.prompt_size << "\n"; + // std::cout << "see bit mask update" << std::bitset<64>(bitmask.mask[0]) + // << "\n"; } // prepare next beam, append layers to the tree @@ -1987,16 +1882,10 @@ void RequestManager::appendBitMask(BatchConfig::BitMask &bitmask, // assert(false); // } - std::cout << "see bit mask append" << bitmask.prompt_size << "\n"; - std::cout << "see bit mask append" << bitmask.non_tree_cache_size << "\n"; - std::cout << "see bit mask append" << std::bitset<64>(bitmask.mask[0]) - << "\n"; - std::cout << "see bit mask append" << std::bitset<64>(bitmask.mask[1]) - << "\n"; - std::cout << "see bit mask append" << std::bitset<64>(bitmask.mask[2]) - << "\n"; - std::cout << "see bit mask append" << std::bitset<64>(bitmask.mask[3]) - << "\n"; + // std::cout << "see bit mask append" << bitmask.prompt_size << "\n"; + // std::cout << "see bit mask append" << bitmask.non_tree_cache_size << "\n"; + // std::cout << "see bit mask append" << std::bitset<64>(bitmask.mask[0]) + // << "\n"; } bool PreOrder( @@ -2084,7 +1973,7 @@ std::vector> // depth) pairs for (auto const &pair : inputSerializedTree) { oss << " " << pair.second << ":" << pair.first; - log_req_mgr.print("(%d, %d)", pair.first, pair.second); + // log_req_mgr.print("(%d, %d)", pair.first, pair.second); } log_req_mgr.print("Input tree:%s", oss.str().c_str()); } @@ -2093,7 +1982,7 @@ std::vector> // outputSerializedTree is an array of (token id, depth + 1) pairs std::ostringstream oss; for (auto const &pair : outputSerializedTree) { - log_req_mgr.print("(%d, %d)", pair.first, pair.second); + // log_req_mgr.print("(%d, %d)", pair.first, pair.second); oss << " " << pair.second << ":" << pair.first; } log_req_mgr.print("Output tree:%s", oss.str().c_str()); @@ -2130,6 +2019,7 @@ std::vector> } // to avoid branch switch when same tokens in input tree. 
+ // todo, only checked for N->1->1->1 cases bool findFirst = false; layer_num = -1; @@ -2173,9 +2063,10 @@ std::vector> new_committed_tokens.push_back(std::make_pair( input.second, committed_tokens.at(guid).at(i).second)); // at this point, you'll not go other branches - std::cout << "verify tree push back: " << output.first - << ", tree size is: " << verifiedTree.size() - << ", ??: " << input.first << ", " << input.second << "\n"; + // std::cout << "verify tree push back: " << output.first + // << ", tree size is: " << verifiedTree.size() + // << ", ??: " << input.first << ", " << input.second << + // "\n"; } else { printf("not correct slot\n"); @@ -2190,9 +2081,9 @@ std::vector> committed_tokens.at(guid).at(i).second)); // // at this point, you'll not go other branches - std::cout << "verify tree push back: " << output.first - << ", tree size is: " << verifiedTree.size() - << ", ??: " << input.first << ", " << input.second << "\n"; + // std::cout << "verify tree push back: " << output.first + // << ", tree size is: " << verifiedTree.size() + // << ", ??: " << input.first << ", " << input.second << "\n"; } assert(committed_tokens.at(guid).at(i).first == input.second); @@ -2203,7 +2094,7 @@ std::vector> // log_req_mgr.print("========Verified============"); std::ostringstream oss; for (auto const &pair : verifiedTree) { - log_req_mgr.print("(%d, %d)", pair.first, pair.second); + // log_req_mgr.print("(%d, %d)", pair.first, pair.second); oss << " " << pair.second << ":" << pair.first; } log_req_mgr.print("Verified:%s", oss.str().c_str()); @@ -2225,7 +2116,7 @@ std::vector> RequestManager::traverse_beam_tree(BeamSearchBatchConfig const &old_bc, int request_index, int first_token_depth_in_request) { - if (true) { + if (verbose) { std::cout << "[Traverse Beam Tree] request_index: " << request_index << "\n"; std::cout << "[Traverse Beam Tree] max_depth: " @@ -2269,13 +2160,13 @@ std::vector> // verbose); // print it - if (true) { + if (verbose) { std::cout << "Print serialized tree: size:" << request_index << serializedTree.size() << "\n"; } for (int k = 0; k < serializedTree.size(); k++) { serializedTree.at(k).second += first_token_depth_in_request; - if (true) { + if (verbose) { std::cout << "token id: " << serializedTree.at(k).first << ", depth: " << serializedTree.at(k).second << "\n"; } @@ -2354,9 +2245,6 @@ std::vector> } dfs_tree_inputs[guid] = merged_tree; - // std::cout << "assign dfr tree: " << guid << ", " << merged_tree.size() << - // ", " - // << dfs_tree_inputs[guid].size() << "\n"; return merged_tree; } diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu index bb6b6030aa..bb20fb263f 100644 --- a/src/runtime/request_manager.cu +++ b/src/runtime/request_manager.cu @@ -212,7 +212,6 @@ void RequestManager::load_batch_config_task( } // add a size check - std::cout << "hahaha handle.batch_config_metadata_size: " << handle.batch_config_metadata_size << ", "<< total_copy_size << "\n"; assert(total_copy_size <= handle.batch_config_metadata_size); } From b621f2a9f62f24a8112df7af3850dc3bdb494dc7 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Sat, 30 Dec 2023 17:25:28 -0500 Subject: [PATCH 13/30] . 
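For context on the size assertion kept in load_batch_config_task above: the
device-side metadata buffer is sized once per FFHandler, and every section
copied for a batch has to fit inside it. A minimal sketch of that accounting,
assuming the same section order the attention metas use when they offset into
handler.batch_config_metadata (the causal-mask note is an assumption about
later commits in this series, not code from this patch):

    // requires "flexflow/batch_config.h" and <cassert>
    size_t total_copy_size = 0;
    total_copy_size += sizeof(BatchConfig::tokensInfo);
    total_copy_size += sizeof(BatchConfig::requestsInfo);
    total_copy_size += sizeof(BeamSearchBatchConfig::beamTokenInfo);
    total_copy_size += sizeof(BeamSearchBatchConfig::beamRequestsInfo);
    // later commits also stage the per-request causal bitmasks here,
    // growing the running offset by the same pattern
    assert(total_copy_size <= handle.batch_config_metadata_size);

Each section is presumably staged with its own async copy at the running
offset, which is why the attention metas recover beamTokenInfo and
beamRequestsInfo by adding the identical sizeof() terms to
handler.batch_config_metadata.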
--- inference/spec_infer/spec_infer.cc | 10 +++++----- src/runtime/cuda_helper.cu | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index 258b2d78eb..b369a13c1d 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -266,9 +266,9 @@ void FlexFlow::top_level_task(Task const *task, ModelMeta model_metadata; bool use_full_precision = false; bool verbose = false; - int max_requests_per_batch = 10; - int max_tokens_per_batch = 199; - int max_sequence_length = 200; + int max_requests_per_batch = 16; + int max_tokens_per_batch = 256; + int max_sequence_length = 1024; InputArgs const &command_args = HighLevelRuntime::get_input_args(); char **argv = command_args.argv; @@ -302,7 +302,7 @@ void FlexFlow::top_level_task(Task const *task, model_metadata.llm_tokenizer_path); rm->register_output_filepath(file_paths.output_file_path); - //first decoding step: 3 results + // first decoding step: 3 results rm->push_spec_infer_tree_width(3); // Create LLM model @@ -402,7 +402,7 @@ void FlexFlow::top_level_task(Task const *task, printf("Prompt[%d]: %s\n", total_num_requests, text.c_str()); total_num_requests++; prompts.push_back(text); - // tree_model.generate(text, 128 /*max_sequence_length*/); + // tree_model.generate(text, 128 /*max_sequence_length*/); } tree_model.generate(prompts, 128 /*max_sequence_length*/); } diff --git a/src/runtime/cuda_helper.cu b/src/runtime/cuda_helper.cu index 398ed7f3cd..fa6bf55fe5 100644 --- a/src/runtime/cuda_helper.cu +++ b/src/runtime/cuda_helper.cu @@ -226,7 +226,7 @@ __host__ void print_tensor(T const *ptr, printf("%s, %d---->", prefix, shard_id); for (idx = 0; idx < num_elements; idx++) { printf(" %.20lf", (float)host_ptr[idx]); - if (idx >= 200) { + if (idx >= 100) { break; } } From 8a0b007bfe20b50302ad201c01c7ac1dfb30a25a Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Sat, 30 Dec 2023 18:49:19 -0500 Subject: [PATCH 14/30] load batchconfig --- src/ops/inc_multihead_self_attention.cpp | 4 ++-- src/runtime/inference_manager.cc | 9 ++++----- src/runtime/model.cpp | 2 +- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/ops/inc_multihead_self_attention.cpp b/src/ops/inc_multihead_self_attention.cpp index a59740f4a3..00cc4d8868 100644 --- a/src/ops/inc_multihead_self_attention.cpp +++ b/src/ops/inc_multihead_self_attention.cpp @@ -1106,7 +1106,7 @@ template void Kernels::IncMultiHeadAttention::compute_o_prod_bias( float const *weight_ptr, float const *bias_ptr, int num_tokens, - cudaStream_t stream); + hipStream_t stream); template void Kernels::IncMultiHeadAttention::compute_o_prod_bias( IncMultiHeadSelfAttentionMeta const *m, BatchConfig const *bc, @@ -1115,6 +1115,6 @@ template void Kernels::IncMultiHeadAttention::compute_o_prod_bias( half const *weight_ptr, half const *bias_ptr, int num_tokens, - cudaStream_t stream); + hipStream_t stream); }; // namespace FlexFlow diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index 52a1efc2ab..8af0ed8978 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -398,11 +398,10 @@ void InferenceManager::load_inference_metadata_batch_config( Runtime *runtime = ff_config.lg_hlr; ArgumentMap argmap; - Rect<1> task_rect( - Point<1>(0), Point<1>(ff_config.workersPerNode * ff_config.numNodes - 1)); - IndexSpaceT<1> task_is = runtime->create_index_space(ctx, task_rect); + Domain domain = + runtime->get_index_space_domain(ctx, 
ff_config.all_gpu_task_is); + Rect<1> task_rect = domain; - // int rank = 0; int idx = 0; for (PointInRectIterator<1> it(task_rect); it(); it++) { FFHandler handler = handlers[idx++]; @@ -410,7 +409,7 @@ void InferenceManager::load_inference_metadata_batch_config( } IndexLauncher launcher(RM_LOAD_BATCH_CONFIG_TASK_ID, - task_is, + ff_config.all_gpu_task_is, TaskArgument(nullptr, 0), argmap, Predicate::TRUE_PRED, diff --git a/src/runtime/model.cpp b/src/runtime/model.cpp index 5499a280a8..ad2b781567 100644 --- a/src/runtime/model.cpp +++ b/src/runtime/model.cpp @@ -152,7 +152,7 @@ FFHandler .wait(); handle.offload_reserve_space = workspaceInst.pointer_untyped(0, sizeof(char)); - } else { + } else { handle.offload_reserve_space = nullptr; } if (handle.batch_config_metadata_size > 0) { From 17a718f95523ed3892d0324ed493ef6043607b13 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Sat, 30 Dec 2023 19:18:22 -0500 Subject: [PATCH 15/30] clean --- .../inc_multihead_self_attention_utils.cuh | 4 +- src/ops/argmax.cc | 1 - src/ops/beam_topk.cc | 2 - src/ops/inc_multihead_self_attention.cu | 7 +- src/ops/spec_inc_multihead_self_attention.cu | 111 ++++++------------ src/ops/tree_inc_multihead_self_attention.cu | 13 +- 6 files changed, 49 insertions(+), 89 deletions(-) diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh index 1b21a80dc9..c128c1a126 100644 --- a/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh +++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh @@ -456,7 +456,7 @@ inline size_t smem_size_in_bytes(int hidden_size_per_head, int threads_per_block) { // The amount of shared memory needed to store the Q*K^T values in float. - size_t qk_sz = div_up(2000 + 1, 4) * 16; + size_t qk_sz = div_up(max_sequence_length + 1, 4) * 16; size_t logits_sz = qk_sz; // The total size needed during softmax. @@ -493,7 +493,7 @@ inline void smem_size_in_bytes_tree(int hidden_size_per_head, } // todo fix this - int max_qk_length = max_query_length * max_total_length + 1000; + int max_qk_length = max_query_length * max_total_length; // The amount of shared memory needed to store the Q*K^T values in float. 
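  // Put differently, (max_qk_length + 1) float scores are rounded up to a
  // multiple of 4 and stored as 16-byte groups of 4 floats each. Purely as an
  // illustrative example (not a value taken from the code), max_qk_length = 255
  // would reserve div_up(255 + 1, 4) * 16 = 64 * 16 = 1024 bytes per Q*K^T row.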
size_t qk_sz = div_up(max_qk_length + 1, 4) * 16; diff --git a/src/ops/argmax.cc b/src/ops/argmax.cc index c3bb3d493e..dc7e4ea3b3 100644 --- a/src/ops/argmax.cc +++ b/src/ops/argmax.cc @@ -352,7 +352,6 @@ BeamInferenceResult GenericTensorAccessorW parent = helperGetGenericTensorAccessorWO( DT_INT32, regions[2], task->regions[2], FID_DATA, ctx, runtime); ArgMax::forward_kernel_wrapper(m, input, indices, parent, batch_size); - BeamInferenceResult ir; download_tensor( indices.get_int32_ptr(), ir.token_ids, batch_size); diff --git a/src/ops/beam_topk.cc b/src/ops/beam_topk.cc index 87d357b535..18d0ec1587 100644 --- a/src/ops/beam_topk.cc +++ b/src/ops/beam_topk.cc @@ -398,8 +398,6 @@ BeamInferenceResult download_tensor( parent_ptr, ir.parent_id, batch_size * m->max_beam_width); - // print_tensor(index_ptr, 32, "indexxxxxxx"); - if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); int shard_id = task->index_point.point_data[0]; diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index cca0b230c3..da70e23f87 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -1381,7 +1381,7 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( assert(false && "Unkown inference mode"); } size_t requestinfo_size = BatchConfig::max_requests_per_batch(); - size_t tokeninfo_size = max_tokens_per_batch; + // size_t tokeninfo_size = max_tokens_per_batch; size_t qk_prod_size = max_tokens_per_batch * BatchConfig::max_sequence_length() * num_q_heads; size_t attn_heads_size = max_tokens_per_batch * num_q_heads * vProjSize; @@ -1438,8 +1438,9 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( token_infos = static_cast(handler.batch_config_metadata); - request_infos = static_cast( - handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo)); + request_infos = reinterpret_cast( + reinterpret_cast(handler.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo)); if (offload) { // token_infos = diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index b3a87fe244..88dd3f92e4 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -82,29 +82,20 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel( int const first_step = 0; - int const tlength = - request_infos[batch_config_request_id].first_token_depth_in_request + - request_infos[batch_config_request_id].num_tokens_in_batch; + // int const tlength = + // request_infos[batch_config_request_id].first_token_depth_in_request + + // request_infos[batch_config_request_id].num_tokens_in_batch; int const totalCacheSize = bitmask.non_tree_cache_size + bitmask.tree_size; - // if (blockIdx.y == 0 && blockIdx.x == 0 && tidx == 0) { - // printf("spec inc attn fused kernel %d, %d\n", - // totalCacheSize, - // request_infos[batch_config_request_id].num_tokens_in_batch); - // } - // int const qlength = request_infos[request_idx].num_tokens_in_batch; - int const tree_branch_num = - beam_request_infos[batch_config_request_id].sub_request_num; - - // will decode qlength tokens in this thread block - // int const qlength = tree_branch_num; - int first_token_idx = 0; for (int r = 0; r < request_idx; r++) { first_token_idx += causalMask[r].this_layer_size; } + int const tree_branch_num = + beam_request_infos[batch_config_request_id].sub_request_num; + // shared memory objects extern __shared__ char smem_[]; @@ -338,20 +329,14 @@ 
__global__ void spec_inc_store_kv_cache( DT vVal = devQKVProjArray[val_idx + hidden_size]; int const req_id = tokenInfos[token_idx].request_index; - int const tok_id = tokenInfos[token_idx].abs_depth_in_request; - int const first_token_in_req = - requestInfo[req_id].first_token_depth_in_request; - int const sub_req_id = beamTokenInfos[token_idx].sub_request_index; - int const total_token = requestInfo[req_id].num_tokens_in_batch; + // int const tok_id = tokenInfos[token_idx].abs_depth_in_request; int const request_token_offset = requestInfo[req_id].first_token_offset_in_batch; BatchConfig::BitMask bitmask = causalMask[req_id]; - int const sub_request_num = beamRequestInfos[req_id].sub_request_num; - - int const tree_branch_num = beamRequestInfos[req_id].sub_request_num; + // int const tree_branch_num = beamRequestInfos[req_id].sub_request_num; // int const query_token = bitmask.non_tree_cache_size + bitmask.tree_size - // tree_branch_num + sub_req_id + tok_id; @@ -379,9 +364,9 @@ void update_kv_cache_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, if (num_tokens > 0) { int parallelism = m->hidden_size * KV_WEIGHT_NUM * num_tokens; spec_inc_store_kv_cache<<>>( + min(CUDA_NUM_THREADS, parallelism), + 0, + stream>>>( static_cast
<DT *>(m->devQKVProjArray),
        static_cast<DT *>(m->keyCache),
        static_cast<DT *>
(m->valueCache), @@ -401,19 +386,19 @@ void update_kv_cache_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, } } -#define LAUNCH_SPEC_INC_ATTENTION_SCORE_KERNEL( \ +#define LAUNCH_SPEC_INC_ATTENTION_SCORE_KERNEL( \ DT, Dh, Dh_MAX, THDS_PER_KEY, THREADS_PER_VALUE, THDS_PER_BLOCK, stream) \ smem_sz = smem_size_in_bytes
(m->qProjSize, \ BatchConfig::max_sequence_length() + \ BatchConfig::MAX_SPEC_TREE_TOKEN_NUM, \ THREADS_PER_VALUE, \ THDS_PER_BLOCK); \ - compute_spec_inc_attention_kernel_generation_kernel \ + compute_spec_inc_attention_kernel_generation_kernel \ <<>>( \ static_cast
<DT *>(m->devQKVProjArray),                               \
          static_cast<DT *>
(m->keyCache), \ @@ -470,14 +455,13 @@ __global__ void spec_fill_entries_above_diagonal(DT *matrix, } template -void compute_attention_kernel_prompt( - SpecIncMultiHeadSelfAttentionMeta const *m, - BeamSearchBatchConfig const *bc, - int shard_id, - DT *output_ptr, - DT const *bias_ptr, - DT const *weight_ptr, - cudaStream_t stream) { +void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, + BeamSearchBatchConfig const *bc, + int shard_id, + DT *output_ptr, + DT const *bias_ptr, + DT const *weight_ptr, + cudaStream_t stream) { checkCUDA(cublasSetStream(m->handle.blas, stream)); checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); cudaDataType_t cublas_data_type = ff_to_cuda_datatype(m->output_type[0]); @@ -812,8 +796,7 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( checkCUDA(cudaEventElapsedTime(&elapsed, t_start, t_end)); cudaEventDestroy(t_start); cudaEventDestroy(t_end); - printf("SpecIncMultiHeadSelfAttention forward time = %.2fms\n", - elapsed); + printf("SpecIncMultiHeadSelfAttention forward time = %.2fms\n", elapsed); // print_tensor<3, float>(acc_query.ptr, acc_query.rect, // "[Attention:forward:query]"); print_tensor<3, float>(acc_output.ptr, // acc_output.rect, "[Attention:forward:output]"); @@ -860,51 +843,29 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta( // allocate memory for the seqArray and reserve space { - // size_t causal_mask_size = BatchConfig::MAX_NUM_REQUESTS; - // size_t total_size = causal_mask_size * sizeof(BatchConfig::BitMask); - // gpu_mem_allocator.create_legion_instance(beam_search_reserve_inst, - // total_size); - beam_token_infos = - static_cast( - handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + + reinterpret_cast( + reinterpret_cast(handler.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo)); beam_request_infos = - static_cast( - handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + + reinterpret_cast( + reinterpret_cast(handler.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + sizeof(BeamSearchBatchConfig::beamTokenInfo)); - causalMask = static_cast( - handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + - sizeof(BatchConfig::requestsInfo) + + causalMask = reinterpret_cast( + reinterpret_cast(handler.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + sizeof(BeamSearchBatchConfig::beamTokenInfo) + sizeof(BeamSearchBatchConfig::beamRequestsInfo)); - - // causalMask = gpu_mem_allocator.allocate_instance( - // causal_mask_size); - // beam_token_infos = - // gpu_mem_allocator - // .allocate_instance( - // beam_tokeninfo_size); - // offset += beam_tokeninfo_size * - // sizeof(BeamSearchBatchConfig::BeamSearchPerTokenInfo); - // beam_request_infos = - // gpu_mem_allocator - // .allocate_instance( - // beam_requestinfo_size); - // offset += beam_requestinfo_size * - // sizeof(BeamSearchBatchConfig::BeamSearchPerRequestInfo); - // assert(offset == total_size); - // assert(gpu_mem_allocator.instance_total_size == - // gpu_mem_allocator.instance_allocated_size); } cudaStreamSynchronize(stream); } -SpecIncMultiHeadSelfAttentionMeta::~SpecIncMultiHeadSelfAttentionMeta( - void) { +SpecIncMultiHeadSelfAttentionMeta::~SpecIncMultiHeadSelfAttentionMeta(void) { if (beam_search_reserve_inst != Realm::RegionInstance::NO_INST) { beam_search_reserve_inst.destroy(); } diff --git a/src/ops/tree_inc_multihead_self_attention.cu 
b/src/ops/tree_inc_multihead_self_attention.cu index 5c6527baf9..b4af80976f 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -445,7 +445,7 @@ __global__ void update_tree_branch_kv_cache_fused( DT vVal = devQKVProjArray[val_idx + hidden_size]; int const req_id = tokenInfos[token_idx].request_index; - int const tok_id = tokenInfos[token_idx].abs_depth_in_request; + // int const tok_id = tokenInfos[token_idx].abs_depth_in_request; int const request_token_offset = request_infos[req_id].first_token_offset_in_batch; @@ -1059,12 +1059,13 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( // allocate memory for the seqArray and reserve space { - causalMask = static_cast( - handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + - sizeof(BatchConfig::requestsInfo)); + causalMask = reinterpret_cast( + reinterpret_cast(handler.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo)); committed_token_infos = - static_cast( - handler.batch_config_metadata + sizeof(BatchConfig::tokensInfo) + + reinterpret_cast( + reinterpret_cast(handler.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + sizeof(BatchConfig::causalMask)); } From c8a107b1b75e5c90a9c7329ab2618b940a4b260f Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Sat, 30 Dec 2023 19:19:45 -0500 Subject: [PATCH 16/30] hip --- src/ops/inc_multihead_self_attention.cpp | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/src/ops/inc_multihead_self_attention.cpp b/src/ops/inc_multihead_self_attention.cpp index 00cc4d8868..d60386f927 100644 --- a/src/ops/inc_multihead_self_attention.cpp +++ b/src/ops/inc_multihead_self_attention.cpp @@ -1098,23 +1098,4 @@ template void Kernels::IncMultiHeadAttention::pre_build_weight_kernel( DataType data_type, hipStream_t stream); -template void Kernels::IncMultiHeadAttention::compute_o_prod_bias( - IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - int shard_id, - float *output_ptr, - float const *weight_ptr, - float const *bias_ptr, - int num_tokens, - hipStream_t stream); -template void Kernels::IncMultiHeadAttention::compute_o_prod_bias( - IncMultiHeadSelfAttentionMeta const *m, - BatchConfig const *bc, - int shard_id, - half *output_ptr, - half const *weight_ptr, - half const *bias_ptr, - int num_tokens, - hipStream_t stream); - }; // namespace FlexFlow From 42e1b5d92cf3e93e3f56d3d18d3fb68803b6caaf Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Sat, 30 Dec 2023 20:42:49 -0500 Subject: [PATCH 17/30] hip --- src/runtime/request_manager.cpp | 95 +++++++++++++++++--- src/runtime/request_manager.cu | 154 +++++++++----------------------- 2 files changed, 123 insertions(+), 126 deletions(-) diff --git a/src/runtime/request_manager.cpp b/src/runtime/request_manager.cpp index 9635b3bc1e..fadbf80d6d 100644 --- a/src/runtime/request_manager.cpp +++ b/src/runtime/request_manager.cpp @@ -56,22 +56,91 @@ void RequestManager::load_tokens_task( sizeof(TokenId) * batch_config->num_tokens, hipMemcpyHostToDevice, stream)); +} + +void RequestManager::load_batch_config_task( + Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + assert(regions.size() == 0); + assert(task->regions.size() == 0); + hipStream_t stream; + checkCUDA(get_legion_stream(&stream)); + + // BatchConfig const batch_config = *((BatchConfig *)task->args); + BatchConfig const *batch_config = BatchConfig::from_future(task->futures[0]); 
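  // Layout note (sketch only, mirroring the consumers shown earlier in this
  // series): batch_config_metadata is a single flat device buffer. The copies
  // below append the mode-specific fields after a shared prefix of tokensInfo
  // and requestsInfo, and the attention metas recover their pointers with the
  // same running offsets, e.g.
  //   char *base = reinterpret_cast<char *>(handle.batch_config_metadata);
  //   auto *tokens   = reinterpret_cast<BatchConfig::PerTokenInfo *>(base);
  //   auto *requests = reinterpret_cast<BatchConfig::PerRequestInfo *>(
  //       base + sizeof(BatchConfig::tokensInfo));
  //   // beam/tree fields follow, each at the previous offset plus the
  //   // sizeof() of the field copied before it.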
// copy meta data to workSpace FFHandler handle = *((FFHandler const *)task->local_args); - cudaMemcpyAsync(handle.batch_config_metadata, - &(batch_config->tokensInfo), - batch_config->num_active_tokens() * - sizeof(BatchConfig::PerTokenInfo), - cudaMemcpyHostToDevice, - stream); - cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - sizeof(BatchConfig::tokensInfo), - &(batch_config->requestsInfo), - batch_config->max_requests_per_batch() * - sizeof(BatchConfig::PerRequestInfo), - cudaMemcpyHostToDevice, - stream); + size_t total_copy_size = 0; + checkCUDA(hipMemcpyAsync(handle.batch_config_metadata, + &(batch_config->tokensInfo), + sizeof(BatchConfig::tokensInfo), + hipMemcpyHostToDevice, + stream)); + total_copy_size += sizeof(BatchConfig::tokensInfo); + + checkCUDA(hipMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(batch_config->requestsInfo), + sizeof(BatchConfig::requestsInfo), + hipMemcpyHostToDevice, + stream)); + total_copy_size += sizeof(BatchConfig::requestsInfo); + + // load speculative metadata + if (batch_config->get_mode() == BEAM_SEARCH_MODE) { + BeamSearchBatchConfig const *beam_batch_config = + static_cast(batch_config); + + checkCUDA(hipMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(beam_batch_config->beamTokenInfo), + sizeof(BeamSearchBatchConfig::beamTokenInfo), + hipMemcpyHostToDevice, + stream)); + + total_copy_size += sizeof(BeamSearchBatchConfig::beamTokenInfo); + + checkCUDA(hipMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(beam_batch_config->beamRequestsInfo), + sizeof(BeamSearchBatchConfig::beamRequestsInfo), + hipMemcpyHostToDevice, + stream)); + total_copy_size += sizeof(BeamSearchBatchConfig::beamRequestsInfo); + + checkCUDA(hipMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(beam_batch_config->causalMask), + sizeof(BatchConfig::causalMask), + hipMemcpyHostToDevice, + stream)); + + total_copy_size += sizeof(BatchConfig::causalMask); + } else if (batch_config->get_mode() == TREE_VERIFY_MODE) { + TreeVerifyBatchConfig const *tree_batch_config = + static_cast(batch_config); + + checkCUDA(hipMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(tree_batch_config->causalMask), + sizeof(BatchConfig::causalMask), + hipMemcpyHostToDevice, + stream)); + total_copy_size += sizeof(BatchConfig::causalMask); + checkCUDA(hipMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(tree_batch_config->committed_tokens), + sizeof(TreeVerifyBatchConfig::committed_tokens), + hipMemcpyHostToDevice, + stream)); + total_copy_size += sizeof(TreeVerifyBatchConfig::committed_tokens); + } + + // add a size check + assert(total_copy_size <= handle.batch_config_metadata_size); } void RequestManager::load_positions_task( diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu index bb20fb263f..51c52c3026 100644 --- a/src/runtime/request_manager.cu +++ b/src/runtime/request_manager.cu @@ -56,78 +56,6 @@ void RequestManager::load_tokens_task( sizeof(TokenId) * batch_config->num_tokens, cudaMemcpyHostToDevice, stream)); - - // // copy meta data to workSpace - // FFHandler handle = *((FFHandler const *)task->local_args); - // size_t total_copy_size = 0; - // cudaMemcpyAsync(handle.batch_config_metadata, - // &(batch_config->tokensInfo), - // sizeof(BatchConfig::tokensInfo), - // cudaMemcpyHostToDevice, - // stream); - // total_copy_size += sizeof(BatchConfig::tokensInfo); - - // 
cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - // total_copy_size, - // &(batch_config->requestsInfo), - // sizeof(BatchConfig::requestsInfo), - // cudaMemcpyHostToDevice, - // stream); - // total_copy_size += sizeof(BatchConfig::requestsInfo); - - // // load speculative metadata - // if (batch_config->get_mode() == BEAM_SEARCH_MODE) { - // BeamSearchBatchConfig const *beam_batch_config = - // static_cast(batch_config); - - // cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - // total_copy_size, - // &(beam_batch_config->beamTokenInfo), - // sizeof(BeamSearchBatchConfig::beamTokenInfo), - // cudaMemcpyHostToDevice, - // stream); - - // total_copy_size += sizeof(BeamSearchBatchConfig::beamTokenInfo); - - // cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - // total_copy_size, - // &(beam_batch_config->beamRequestsInfo), - // sizeof(BeamSearchBatchConfig::beamRequestsInfo), - // cudaMemcpyHostToDevice, - // stream); - // total_copy_size += sizeof(BeamSearchBatchConfig::beamRequestsInfo); - - // cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - // total_copy_size, - // &(beam_batch_config->causalMask), - // sizeof(BatchConfig::causalMask), - // cudaMemcpyHostToDevice, - // stream); - - // total_copy_size += sizeof(BatchConfig::causalMask); - // } else if (batch_config->get_mode() == TREE_VERIFY_MODE) { - // TreeVerifyBatchConfig const *tree_batch_config = - // static_cast(batch_config); - - // cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - // total_copy_size, - // &(tree_batch_config->causalMask), - // sizeof(BatchConfig::causalMask), - // cudaMemcpyHostToDevice, - // stream); - // total_copy_size += sizeof(BatchConfig::causalMask); - // cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - // total_copy_size, - // &(tree_batch_config->committed_tokens), - // sizeof(TreeVerifyBatchConfig::committed_tokens), - // cudaMemcpyHostToDevice, - // stream); - // total_copy_size += sizeof(TreeVerifyBatchConfig::committed_tokens); - // } - - // // add a size check - // std::cout << "handle.batch_config_metadata_size: " << handle.batch_config_metadata_size << ", "<< total_copy_size << "\n"; - // assert(total_copy_size <= handle.batch_config_metadata_size); } void RequestManager::load_batch_config_task( @@ -146,19 +74,19 @@ void RequestManager::load_batch_config_task( // copy meta data to workSpace FFHandler handle = *((FFHandler const *)task->local_args); size_t total_copy_size = 0; - cudaMemcpyAsync(handle.batch_config_metadata, - &(batch_config->tokensInfo), - sizeof(BatchConfig::tokensInfo), - cudaMemcpyHostToDevice, - stream); + checkCUDA(cudaMemcpyAsync(handle.batch_config_metadata, + &(batch_config->tokensInfo), + sizeof(BatchConfig::tokensInfo), + cudaMemcpyHostToDevice, + stream)); total_copy_size += sizeof(BatchConfig::tokensInfo); - cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - total_copy_size, - &(batch_config->requestsInfo), - sizeof(BatchConfig::requestsInfo), - cudaMemcpyHostToDevice, - stream); + checkCUDA(cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + + total_copy_size, + &(batch_config->requestsInfo), + sizeof(BatchConfig::requestsInfo), + cudaMemcpyHostToDevice, + stream)); total_copy_size += sizeof(BatchConfig::requestsInfo); // load speculative metadata @@ -166,48 +94,48 @@ void RequestManager::load_batch_config_task( BeamSearchBatchConfig const *beam_batch_config = static_cast(batch_config); - cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - total_copy_size, - 
&(beam_batch_config->beamTokenInfo), - sizeof(BeamSearchBatchConfig::beamTokenInfo), - cudaMemcpyHostToDevice, - stream); + checkCUDA(cudaMemcpyAsync( + static_cast(handle.batch_config_metadata) + total_copy_size, + &(beam_batch_config->beamTokenInfo), + sizeof(BeamSearchBatchConfig::beamTokenInfo), + cudaMemcpyHostToDevice, + stream)); total_copy_size += sizeof(BeamSearchBatchConfig::beamTokenInfo); - cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - total_copy_size, - &(beam_batch_config->beamRequestsInfo), - sizeof(BeamSearchBatchConfig::beamRequestsInfo), - cudaMemcpyHostToDevice, - stream); + checkCUDA(cudaMemcpyAsync( + static_cast(handle.batch_config_metadata) + total_copy_size, + &(beam_batch_config->beamRequestsInfo), + sizeof(BeamSearchBatchConfig::beamRequestsInfo), + cudaMemcpyHostToDevice, + stream)); total_copy_size += sizeof(BeamSearchBatchConfig::beamRequestsInfo); - cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - total_copy_size, - &(beam_batch_config->causalMask), - sizeof(BatchConfig::causalMask), - cudaMemcpyHostToDevice, - stream); + checkCUDA(cudaMemcpyAsync( + static_cast(handle.batch_config_metadata) + total_copy_size, + &(beam_batch_config->causalMask), + sizeof(BatchConfig::causalMask), + cudaMemcpyHostToDevice, + stream)); total_copy_size += sizeof(BatchConfig::causalMask); } else if (batch_config->get_mode() == TREE_VERIFY_MODE) { TreeVerifyBatchConfig const *tree_batch_config = static_cast(batch_config); - cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - total_copy_size, - &(tree_batch_config->causalMask), - sizeof(BatchConfig::causalMask), - cudaMemcpyHostToDevice, - stream); + checkCUDA(cudaMemcpyAsync( + static_cast(handle.batch_config_metadata) + total_copy_size, + &(tree_batch_config->causalMask), + sizeof(BatchConfig::causalMask), + cudaMemcpyHostToDevice, + stream)); total_copy_size += sizeof(BatchConfig::causalMask); - cudaMemcpyAsync(static_cast(handle.batch_config_metadata) + - total_copy_size, - &(tree_batch_config->committed_tokens), - sizeof(TreeVerifyBatchConfig::committed_tokens), - cudaMemcpyHostToDevice, - stream); + checkCUDA(cudaMemcpyAsync( + static_cast(handle.batch_config_metadata) + total_copy_size, + &(tree_batch_config->committed_tokens), + sizeof(TreeVerifyBatchConfig::committed_tokens), + cudaMemcpyHostToDevice, + stream)); total_copy_size += sizeof(TreeVerifyBatchConfig::committed_tokens); } From 1901f65bc2045860d4c26c26c2a158b270cb300a Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Sun, 31 Dec 2023 23:25:21 -0500 Subject: [PATCH 18/30] embedding return when no token --- src/ops/embedding.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/ops/embedding.cc b/src/ops/embedding.cc index 76236e65ff..3be3eac618 100644 --- a/src/ops/embedding.cc +++ b/src/ops/embedding.cc @@ -478,6 +478,7 @@ FutureMap Embedding::inference(FFModel const &ff, 0 /*mapper_id*/, machine_view_hash); // regions[0]: input + launcher.add_future(bc); launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, 0 /*projection*/, READ_ONLY, @@ -516,6 +517,10 @@ void Embedding::forward_task(Task const *task, assert(task->regions.size() == 3); // Assert that weight and output must have the same data type // otherwise, a cast operator should be inserted + BatchConfig const *bc = BatchConfig::from_future(task->futures[0]); + if (bc->num_active_tokens() == 0) { + return; + } assert(m->weight_type[0] == m->output_type[0]); assert(m->input_type[0] == DT_INT32 || m->input_type[0] == DT_INT64); GenericTensorAccessorR 
input = helperGetGenericTensorAccessorRO( From 130ad92f8369d6ba39dd470dafd160b844e49e99 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Mon, 1 Jan 2024 01:39:41 -0500 Subject: [PATCH 19/30] use arg topk instead of beam topk --- include/flexflow/flexflow_c.h | 1 + include/flexflow/model.h | 2 + include/flexflow/ops/arg_topk.h | 16 ++- include/flexflow/ops/arg_topk_params.h | 1 + inference/models/llama.cc | 2 +- python/flexflow/core/flexflow_cffi.py | 5 +- src/c/flexflow_c.cc | 4 +- src/ops/arg_topk.cc | 185 +++++++++++++++++++------ src/ops/arg_topk.cu | 91 +++++++++--- src/runtime/model.cc | 18 +++ 10 files changed, 258 insertions(+), 67 deletions(-) diff --git a/include/flexflow/flexflow_c.h b/include/flexflow/flexflow_c.h index 01a2818a2b..305c8da513 100644 --- a/include/flexflow/flexflow_c.h +++ b/include/flexflow/flexflow_c.h @@ -571,6 +571,7 @@ flexflow_tensor_t flexflow_model_add_arg_top_k(flexflow_model_t handle_, const flexflow_tensor_t input_, int k, bool sorted, + bool speculative_decoding, char const *name); flexflow_tensor_t flexflow_model_add_beam_top_k(flexflow_model_t handle_, diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 16df99ab1a..01244a371b 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -146,6 +146,7 @@ enum TaskIDs { TOPK_BWD_TASK_ID, ARG_TOPK_INIT_TASK_ID, ARG_TOPK_INF_TASK_ID, + ARG_TOPK_INF_SPECULATIVE_TASK_ID, SAMPLING_INIT_TASK_ID, SAMPLING_INF_TASK_ID, ARGMAX_INIT_TASK_ID, @@ -674,6 +675,7 @@ class FFModel { // Tensor *outputs, int k, bool sorted, + bool speculative_decoding, char const *name = NULL); Tensor argmax(const Tensor input, bool beam_search, char const *name = NULL); Tensor sampling(const Tensor input, float top_p, char const *name = NULL); diff --git a/include/flexflow/ops/arg_topk.h b/include/flexflow/ops/arg_topk.h index 8b2d2aa11c..3822a5e41e 100644 --- a/include/flexflow/ops/arg_topk.h +++ b/include/flexflow/ops/arg_topk.h @@ -12,6 +12,8 @@ class ArgTopKMeta : public OpMeta { public: ArgTopKMeta(FFHandler handle, Op const *op); bool sorted; + int k; + bool speculative_decoding; }; class ArgTopK : public Op { @@ -23,6 +25,7 @@ class ArgTopK : public Op { const ParallelTensor input, int k, bool sorted, + bool speculative_decoding, char const *name); ArgTopK(FFModel &model, LayerID const &layer_guid, @@ -61,6 +64,11 @@ class ArgTopK : public Op { std::vector const ®ions, Legion::Context ctx, Legion::Runtime *runtime); + static BeamInferenceResult inference_speculative_task( + Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); void serialize(Legion::Serializer &s) const override; static PCG::Node deserialize(FFModel &ff, Legion::Deserializer &d, @@ -75,22 +83,26 @@ class ArgTopK : public Op { template static void forward_kernel(ArgTopKMeta const *m, DT const *input_ptr, - // float *output_ptr, + float *output_ptr, int *indices_ptr, size_t batch_size, int length, int k, bool sorted, + BeamSearchBatchConfig const *bc, ffStream_t stream); static void forward_kernel_wrapper(ArgTopKMeta const *m, GenericTensorAccessorR const &input, + GenericTensorAccessorW const &prob, GenericTensorAccessorW const &indices, - int batch_size); + int batch_size, + BeamSearchBatchConfig const *bc); Params get_params() const; public: int k; bool sorted; + bool speculative_decoding; }; }; // namespace FlexFlow diff --git a/include/flexflow/ops/arg_topk_params.h b/include/flexflow/ops/arg_topk_params.h index 9d2a21034f..bd9c38e2a9 100644 --- 
a/include/flexflow/ops/arg_topk_params.h +++ b/include/flexflow/ops/arg_topk_params.h @@ -11,6 +11,7 @@ struct ArgTopKParams { LayerID layer_guid; int k; bool sorted; + bool speculative_decoding; bool is_valid(ParallelTensorShape const &) const; }; bool operator==(ArgTopKParams const &, ArgTopKParams const &); diff --git a/inference/models/llama.cc b/inference/models/llama.cc index 10001ee916..e9c84efe90 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -247,7 +247,7 @@ void LLAMA::create_llama_model(FFModel &ff, Tensor softmax = ff.softmax(dense, -1); // output = ff.beam_top_k(softmax, llama_config.max_beam_width, false); // output = ff.argmax(softmax, /*beam_Search*/ true); - output = ff.beam_top_k(softmax, llama_config.max_beam_width, false); + output = ff.arg_top_k(softmax, llama_config.max_beam_width, false, true); // output = ff.top_k(softmax, ) } else { // Tensor softmax = ff.softmax(dense, -1); diff --git a/python/flexflow/core/flexflow_cffi.py b/python/flexflow/core/flexflow_cffi.py index de3f7e6929..a3c221474d 100644 --- a/python/flexflow/core/flexflow_cffi.py +++ b/python/flexflow/core/flexflow_cffi.py @@ -3349,7 +3349,7 @@ def residual_rms_norm(self, input1, input2, eps, dim, name=None): handles_array[1], owner_op_type=OpType.RESIDUAL_RMS_NORM ) - def arg_top_k(self, input, k, sorted, name=None): + def arg_top_k(self, input, k, sorted, speculative_decoding, name=None): """Defines the Arg TopK layer. :param input: the input Tensor. @@ -3361,6 +3361,9 @@ def arg_top_k(self, input, k, sorted, name=None): :param sorted: Whether the entries should be sorted :type sorted: bool + :param speculative_decoding: Whether you need to perform beam search + :type speculative_decoding: bool + :param name: the name of the layer. Default is None. :type name: string diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index 80202f6f99..579fc5e2d1 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -1489,10 +1489,12 @@ flexflow_tensor_t flexflow_model_add_arg_top_k(flexflow_model_t handle_, const flexflow_tensor_t input_, int k, bool sorted, + bool speculative_decoding, char const *name) { FFModel *handle = FFCObjectWrapper::unwrap(handle_); Tensor input = FFCObjectWrapper::unwrap(input_); - Tensor tensor = handle->arg_top_k(input, k, sorted, name); + Tensor tensor = + handle->arg_top_k(input, k, sorted, speculative_decoding, name); return FFCObjectWrapper::wrap(tensor); } diff --git a/src/ops/arg_topk.cc b/src/ops/arg_topk.cc index a06b89de07..2727a1d249 100644 --- a/src/ops/arg_topk.cc +++ b/src/ops/arg_topk.cc @@ -51,6 +51,7 @@ using PCG::Node; Tensor FFModel::arg_top_k(const Tensor input, int k, bool sorted, + bool speculative_decoding, char const *name) { Layer *li = new Layer(this, OP_ARG_TOPK, @@ -58,7 +59,7 @@ Tensor FFModel::arg_top_k(const Tensor input, name, 1 /*inputs*/, 0 /*weights*/, - 1 /*outputs*/, + speculative_decoding ? 
2 : 1 /*outputs*/, input); { int numdims = input->num_dims; @@ -71,9 +72,14 @@ Tensor FFModel::arg_top_k(const Tensor input, // numdims, dims, input->data_type, li, 0, true /*create_grad*/); li->outputs[0] = create_tensor_legion_ordering( numdims, dims, DT_INT32, li, 0, false /*create_grad*/); + if (speculative_decoding) { + li->outputs[1] = create_tensor_legion_ordering( + numdims, dims, DT_FLOAT, li, 1, false /*create_grad*/); + } } li->add_int_property("k", k); li->add_int_property("sorted", sorted); + li->add_int_property("speculative_decoding", speculative_decoding); layers.push_back(li); // outputs[0] = li->outputs[0]; // outputs[1] = li->outputs[1]; @@ -89,14 +95,23 @@ Op *ArgTopK::create_operator_from_layer( int k = value; layer->get_int_property("sorted", value); bool sorted = (bool)value; - return new ArgTopK( - model, layer->layer_guid, inputs[0], k, sorted, layer->name); + layer->get_int_property("speculative_decoding", value); + bool speculative_decoding = (bool)value; + + return new ArgTopK(model, + layer->layer_guid, + inputs[0], + k, + sorted, + speculative_decoding, + layer->name); } ArgTopKParams ArgTopK::get_params() const { ArgTopKParams params; params.k = this->k; params.sorted = this->sorted; + params.speculative_decoding = this->speculative_decoding; return params; } @@ -106,7 +121,8 @@ bool ArgTopKParams::is_valid(ParallelTensorShape const &) const { } bool operator==(ArgTopKParams const &lhs, ArgTopKParams const &rhs) { - return lhs.k == rhs.k && lhs.sorted == rhs.sorted; + return lhs.k == rhs.k && lhs.sorted == rhs.sorted && + lhs.speculative_decoding == rhs.speculative_decoding; } ArgTopK::ArgTopK(FFModel &model, @@ -114,6 +130,7 @@ ArgTopK::ArgTopK(FFModel &model, const ParallelTensor _input, int _k, bool _sorted, + bool _speculative_decoding, char const *name) : Op(model, OP_ARG_TOPK, @@ -121,9 +138,9 @@ ArgTopK::ArgTopK(FFModel &model, name, 1 /*inputs*/, 0 /*weights*/, - 1 /*outputs*/, + _speculative_decoding ? 
2 : 1 /*outputs*/, _input), - k(_k), sorted(_sorted) { + k(_k), sorted(_sorted), speculative_decoding(_speculative_decoding) { // overwrite layer_guid layer_guid = _layer_guid; int numdim = inputs[0]->num_dims; @@ -131,26 +148,42 @@ ArgTopK::ArgTopK(FFModel &model, for (int i = 0; i < numdim; i++) { dims[i] = inputs[0]->dims[i]; } + dims[0].size = k; assert(inputs[0]->dims[0].degree == 1); assert(inputs[0]->dims[0].parallel_idx == -1); - // outputs[0] = model.create_parallel_tensor_legion_ordering( - // numdim, dims, _input->data_type, this, 0 /*owner_idx*/); + outputs[0] = model.create_parallel_tensor_legion_ordering( numdim, dims, DT_INT32, this, 0 /*owner_idx*/); + if (_speculative_decoding) { + outputs[1] = model.create_parallel_tensor_legion_ordering( + numdim, dims, DT_FLOAT, this, 1 /*owner_idx*/); + } } ArgTopK::ArgTopK(FFModel &model, LayerID const &layer_guid, ArgTopK const &other, const ParallelTensor input) - : ArgTopK(model, layer_guid, input, other.k, other.sorted, other.name) {} + : ArgTopK(model, + layer_guid, + input, + other.k, + other.sorted, + other.speculative_decoding, + other.name) {} ArgTopK::ArgTopK(FFModel &model, ArgTopKParams const ¶ms, - const ParallelTensor input, + ParallelTensor const input, char const *name) - : ArgTopK(model, params.layer_guid, input, params.k, params.sorted, name) {} + : ArgTopK(model, + params.layer_guid, + input, + params.k, + params.sorted, + params.speculative_decoding, + name) {} void ArgTopK::init_inference(FFModel const &ff, std::vector const &batch_inputs, @@ -243,8 +276,10 @@ OpMeta *ArgTopK::init_task(Task const *task, m->profiling = topk->profiling; m->inference_debugging = topk->inference_debugging; m->sorted = topk->sorted; + m->k = topk->k; std::strcpy(m->op_name, topk->name); m->layer_guid = topk->layer_guid; + m->speculative_decoding = topk->speculative_decoding; return m; } @@ -267,34 +302,64 @@ FutureMap ArgTopK::inference(FFModel const &ff, size_t machine_view_hash = view->hash(); /* std::cout << "ArgTopK op machine_view: " << *(MachineView const *)mv << std::endl; */ - IndexLauncher launcher(ARG_TOPK_INF_TASK_ID, - parallel_is, - TaskArgument(nullptr, 0), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - machine_view_hash); - launcher.add_future(bc); - launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, - 0 /*projection id*/, - READ_ONLY, - EXCLUSIVE, - batch_inputs[0]->region)); - launcher.add_field(0, FID_DATA); - launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, - 0 /*projection id*/, - WRITE_ONLY, - EXCLUSIVE, - batch_outputs[0]->region)); - launcher.add_field(1, FID_DATA); - // launcher.add_region_requirement(RegionRequirement(batch_outputs[1]->part, - // 0 /*projection id*/, - // WRITE_ONLY, - // EXCLUSIVE, - // batch_outputs[1]->region)); - // launcher.add_field(2, FID_DATA); - return runtime->execute_index_space(ctx, launcher); + if (speculative_decoding) { + IndexLauncher launcher(ARG_TOPK_INF_SPECULATIVE_TASK_ID, + parallel_is, + TaskArgument(nullptr, 0), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + launcher.add_future(bc); + launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + batch_inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement( + RegionRequirement(batch_outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[0]->region)); + launcher.add_field(1, 
FID_DATA); + + launcher.add_region_requirement( + RegionRequirement(batch_outputs[1]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[1]->region)); + launcher.add_field(2, FID_DATA); + return runtime->execute_index_space(ctx, launcher); + + } else { + IndexLauncher launcher(ARG_TOPK_INF_TASK_ID, + parallel_is, + TaskArgument(nullptr, 0), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + launcher.add_future(bc); + launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + batch_inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement( + RegionRequirement(batch_outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[0]->region)); + launcher.add_field(1, FID_DATA); + return runtime->execute_index_space(ctx, launcher); + } } InferenceResult @@ -317,9 +382,11 @@ InferenceResult m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); GenericTensorAccessorW indices = helperGetGenericTensorAccessorWO( DT_INT32, regions[1], task->regions[1], FID_DATA, ctx, runtime); + GenericTensorAccessorW probs; int batch_size = bc->num_active_tokens(); - ArgTopK::forward_kernel_wrapper(m, input, indices, batch_size); + ArgTopK::forward_kernel_wrapper( + m, input, probs, indices, batch_size, nullptr); if (m->inference_debugging) { assert(task->index_point.get_dim() == 1); @@ -334,6 +401,39 @@ InferenceResult return ir; } +BeamInferenceResult ArgTopK::inference_speculative_task( + Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + assert(regions.size() == 3); + assert(task->regions.size() == 3); + BeamSearchBatchConfig const &bc = + Future(task->futures[0]).get_result(); + if (bc.num_active_tokens() == 0) { + // Directly return for empty batch config + BeamInferenceResult ir; + return ir; + } + ArgTopKMeta *m = *((ArgTopKMeta **)task->local_args); + + GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( + m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorW indices = helperGetGenericTensorAccessorWO( + DT_INT32, regions[1], task->regions[1], FID_DATA, ctx, runtime); + GenericTensorAccessorW probs = helperGetGenericTensorAccessorWO( + DT_FLOAT, regions[2], task->regions[2], FID_DATA, ctx, runtime); + + int batch_size = bc.num_active_tokens(); + ArgTopK::forward_kernel_wrapper(m, input, probs, indices, batch_size, &bc); + + BeamInferenceResult ir; + download_tensor( + indices.get_int32_ptr(), ir.token_ids, batch_size * m->k); + download_tensor(probs.get_float_ptr(), ir.probs, batch_size * m->k); + return ir; +} + void ArgTopK::backward(FFModel const &ff) { // ArgTopK does not support backward assert(false); @@ -345,6 +445,7 @@ void ArgTopK::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.model_id); sez.serialize(this->k); sez.serialize(this->sorted); + sez.serialize(this->speculative_decoding); } Node ArgTopK::deserialize(FFModel &ff, @@ -359,12 +460,15 @@ Node ArgTopK::deserialize(FFModel &ff, LayerID layer_guid(id, transformer_layer_id, deserialized_model_id); int k; bool sorted; + bool speculative_decoding; dez.deserialize(k); dez.deserialize(sorted); + dez.deserialize(speculative_decoding); ArgTopKParams params; params.layer_guid = layer_guid; params.k = k; params.sorted = sorted; + params.speculative_decoding = speculative_decoding; return ff.get_or_create_node(inputs[0], params); } @@ 
-390,6 +494,7 @@ size_t hash::operator()( hash_combine(key, params.layer_guid.id); hash_combine(key, params.k); hash_combine(key, params.sorted); + hash_combine(key, params.speculative_decoding); return key; } }; // namespace std diff --git a/src/ops/arg_topk.cu b/src/ops/arg_topk.cu index 575e0183b4..0b8bb8b563 100644 --- a/src/ops/arg_topk.cu +++ b/src/ops/arg_topk.cu @@ -262,8 +262,9 @@ __device__ void mergeShards(int num_shards, int k, Entry *__restrict__ entries, Entry *__restrict__ top_k_heap, - // T *top_k_values, - int *top_k_indices) { + float *top_k_values, + int *top_k_indices, + bool speculative_decoding) { // If k < num_shards, we can use a min-heap with k elements to get the top k // of the sorted blocks. // If k > num_shards, we can initialize a min-heap with the top element from @@ -313,7 +314,11 @@ __device__ void mergeShards(int num_shards, int const last_k = k - 1; for (int rank = 0; rank < last_k; rank++) { Entry const &max_element = max_heap.root(); - // top_k_values[rank] = max_element.value; + if (speculative_decoding) { + assert(top_k_values != nullptr); + top_k_values[rank] = static_cast(max_element.value); + } + int shard_index = max_element.index; top_k_indices[rank] = entries[shard_index].index; int next_shard_index = shard_index + num_shards; @@ -337,8 +342,9 @@ __global__ void arg_topk_forward_kernel(T const *__restrict__ input, int length, int k, bool sorted, - // T *__restrict__ output, - int *__restrict__ indices) { + float *__restrict__ output, + int *__restrict__ indices, + bool speculative_decoding) { __shared__ char shared_memory[48 << 10]; int const batch_index = blockIdx.x; T const *batch_input = input + batch_index * length; @@ -350,15 +356,16 @@ __global__ void arg_topk_forward_kernel(T const *__restrict__ input, __syncthreads(); if (thread_index == 0) { int const offset = batch_index * k; - // auto batch_output = output + offset; + auto batch_output = output + offset; auto batch_indices = indices + offset; Entry *top_k_heap = shared_entries + thread_count * k; mergeShards(thread_count, k, shared_entries, top_k_heap, - // batch_output, - batch_indices); + batch_output, + batch_indices, + speculative_decoding); } } @@ -366,12 +373,13 @@ __global__ void arg_topk_forward_kernel(T const *__restrict__ input, template void ArgTopK::forward_kernel(ArgTopKMeta const *m, DT const *input_ptr, - // float *output_ptr, + float *output_ptr, int *indices_ptr, size_t batch_size, int length, int k, bool sorted, + BeamSearchBatchConfig const *bc, cudaStream_t stream) { // Adopted from TensorFlow's ArgTopK implementation // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/topk_op_gpu.h @@ -390,24 +398,58 @@ void ArgTopK::forward_kernel(ArgTopKMeta const *m, size_t shared_memory_size = (num_shards + 1) * k * sizeof(Entry
); // size_t num_blocks = (batch_size + num_shards - 1) / num_shards; size_t num_blocks = batch_size; - assert(num_shards >= (size_t)k); - num_shards = k; - arg_topk_forward_kernel<<>>( - input_ptr, - shared_memory_size, - length, - k, - sorted, - // output_ptr, - indices_ptr); + + // all requests are in the same beam stages + if (m->speculative_decoding) { + assert(bc->num_active_requests() >= 0); + + // check + int beam_size = -1; + for (int i = 1; i < bc->max_requests_per_batch(); i++) { + if (bc->request_completed[i]) { + continue; + } else if (beam_size == -1) { + beam_size = bc->beamRequestsInfo[i].beam_size; + } else { + assert(beam_size == bc->beamRequestsInfo[i].beam_size); + } + } + + assert(num_shards >= (size_t)beam_size); + num_shards = k; + arg_topk_forward_kernel<<>>( + input_ptr, + shared_memory_size, + length, + beam_size, + sorted, + output_ptr, + indices_ptr, + m->speculative_decoding); + } else { + + assert(num_shards >= (size_t)k); + num_shards = k; + arg_topk_forward_kernel<<>>( + input_ptr, + shared_memory_size, + length, + k, + sorted, + nullptr, + indices_ptr, + false); + } } /*static*/ void ArgTopK::forward_kernel_wrapper(ArgTopKMeta const *m, GenericTensorAccessorR const &input, // float *output_ptr, + GenericTensorAccessorW const &probs, GenericTensorAccessorW const &indices, - int batch_size) { + int batch_size, + BeamSearchBatchConfig const *bc) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); @@ -439,6 +481,7 @@ void ArgTopK::forward_kernel_wrapper(ArgTopKMeta const *m, int length = input.domain.hi()[0] - input.domain.lo()[0] + 1; int k = indices.domain.hi()[0] - indices.domain.lo()[0] + 1; /*TODO: This prints to 5*/ + // batch_size = input.domain.get_volume() / length; // assert(indices.domain.get_volume() / k == batch_size); cudaEvent_t t_start, t_end; @@ -451,22 +494,26 @@ void ArgTopK::forward_kernel_wrapper(ArgTopKMeta const *m, if (input.data_type == DT_HALF) { ArgTopK::forward_kernel(m, input.get_half_ptr(), - // output_ptr, + m->speculative_decoding ? probs.get_float_ptr() + : nullptr, indices.get_int32_ptr(), batch_size, length, k, m->sorted, + m->speculative_decoding ? bc : nullptr, stream); } else if (input.data_type == DT_FLOAT) { ArgTopK::forward_kernel(m, input.get_float_ptr(), - // output_ptr, + m->speculative_decoding ? probs.get_float_ptr() + : nullptr, indices.get_int32_ptr(), batch_size, length, k, m->sorted, + m->speculative_decoding ? 
bc : nullptr, stream); } else { assert(false && "Unsupported data type"); diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 37605c44a4..f72d320bc8 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -5917,6 +5917,24 @@ void register_flexflow_internal_tasks(Runtime *runtime, registrar); } } + { + TaskVariantRegistrar registrar(ARG_TOPK_INF_SPECULATIVE_TASK_ID, + "ArgTopK Speculative Inference"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "ArgTopK Speculative Inference Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant( + registrar); + } + } // BeamTopk task { TaskVariantRegistrar registrar(BEAM_TOPK_INIT_TASK_ID, "BeamTopK Init"); From 4259d2dfa5c42488dad76d511517e45c0ad438c7 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Mon, 1 Jan 2024 10:08:38 -0500 Subject: [PATCH 20/30] embedding --- include/flexflow/ops/embedding.h | 4 ++ src/ops/embedding.cc | 64 ++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/include/flexflow/ops/embedding.h b/include/flexflow/ops/embedding.h index ae93ef4d1d..0f1b1335d4 100644 --- a/include/flexflow/ops/embedding.h +++ b/include/flexflow/ops/embedding.h @@ -80,6 +80,10 @@ class Embedding : public Op { std::vector const ®ions, Legion::Context ctx, Legion::Runtime *runtime); + static void inference_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); static void backward_task(Legion::Task const *task, std::vector const ®ions, Legion::Context ctx, diff --git a/src/ops/embedding.cc b/src/ops/embedding.cc index 3be3eac618..40d5b600be 100644 --- a/src/ops/embedding.cc +++ b/src/ops/embedding.cc @@ -517,6 +517,70 @@ void Embedding::forward_task(Task const *task, assert(task->regions.size() == 3); // Assert that weight and output must have the same data type // otherwise, a cast operator should be inserted + assert(m->weight_type[0] == m->output_type[0]); + assert(m->input_type[0] == DT_INT32 || m->input_type[0] == DT_INT64); + GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( + m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( + m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); + GenericTensorAccessorR kernel = helperGetGenericTensorAccessorRO( + m->weight_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); + if (m->aggr == AGGR_MODE_NONE) { + // assert(kernel_domain.get_dim() == 2); + assert(input.domain.get_dim() + 1 == output.domain.get_dim()); + for (size_t i = 0; i < input.domain.get_dim(); i++) { + assert(input.domain.hi()[i] == output.domain.hi()[i + 1]); + assert(input.domain.lo()[i] == output.domain.lo()[i + 1]); + } + assert(kernel.domain.hi()[0] - kernel.domain.lo()[0] == + output.domain.hi()[0] - output.domain.lo()[0]); + } else { + // assert(kernel_domain.get_dim() == 2); + assert(input.domain.get_dim() == output.domain.get_dim()); + for (size_t i = 1; i < input.domain.get_dim(); i++) { + assert(input.domain.hi()[i] == output.domain.hi()[i]); + assert(input.domain.lo()[i] == output.domain.lo()[i]); + } + assert(kernel.domain.hi()[0] - kernel.domain.lo()[0] == + output.domain.hi()[0] - output.domain.lo()[0]); + } + + int in_dim, out_dim, effective_batch_size; + if (m->aggr == AGGR_MODE_NONE) { + in_dim = 1; + 
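    // With AGGR_MODE_NONE each input index selects exactly one embedding row,
    // so in_dim is 1, out_dim is the embedding dimension (the extent of the
    // output's innermost dim), and effective_batch_size is the number of
    // looked-up rows (output volume / out_dim), which the assert below checks
    // against the number of input indices.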
out_dim = output.domain.hi()[0] - output.domain.lo()[0] + 1; + effective_batch_size = output.domain.get_volume() / out_dim; + assert(effective_batch_size * in_dim == input.domain.get_volume()); + } else { + in_dim = input.domain.hi()[0] - input.domain.lo()[0] + 1; + out_dim = output.domain.hi()[0] - output.domain.lo()[0] + 1; + effective_batch_size = output.domain.get_volume() / out_dim; + assert(effective_batch_size * in_dim == input.domain.get_volume()); + } + forward_kernel_wrapper( + m, input, output, kernel, in_dim, out_dim, effective_batch_size); + if (m->inference_debugging) { + assert(task->index_point.get_dim() == 1); + int shard_id = task->index_point.point_data[0]; + Embedding::save_inference_tensors_to_file( + m, shard_id, nullptr, {input}, {kernel}, {output}); + } +} + +/* + regions[0](I): input + regions[1](O): output + regions[2](I): kernel +*/ +void Embedding::inference_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + EmbeddingMeta *m = *((EmbeddingMeta **)task->local_args); + assert(regions.size() == 3); + assert(task->regions.size() == 3); + // Assert that weight and output must have the same data type + // otherwise, a cast operator should be inserted BatchConfig const *bc = BatchConfig::from_future(task->futures[0]); if (bc->num_active_tokens() == 0) { return; From fae7fba1994aaf3c04da250a04bec3beb217236e Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Mon, 1 Jan 2024 10:13:30 -0500 Subject: [PATCH 21/30] fmt --- include/flexflow/ops/embedding.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/flexflow/ops/embedding.h b/include/flexflow/ops/embedding.h index 0f1b1335d4..ed89fcf37a 100644 --- a/include/flexflow/ops/embedding.h +++ b/include/flexflow/ops/embedding.h @@ -83,7 +83,7 @@ class Embedding : public Op { static void inference_task(Legion::Task const *task, std::vector const ®ions, Legion::Context ctx, - Legion::Runtime *runtime); + Legion::Runtime *runtime); static void backward_task(Legion::Task const *task, std::vector const ®ions, Legion::Context ctx, From 8d1d5842253a0b6c894bec14550dd1e88eb9c4fd Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Mon, 1 Jan 2024 12:05:12 -0500 Subject: [PATCH 22/30] hip --- src/ops/arg_topk.cpp | 90 ++++++++++++++++++++++++++++++++------------ 1 file changed, 66 insertions(+), 24 deletions(-) diff --git a/src/ops/arg_topk.cpp b/src/ops/arg_topk.cpp index 6db8abb8c4..f431d3d4bf 100644 --- a/src/ops/arg_topk.cpp +++ b/src/ops/arg_topk.cpp @@ -263,8 +263,9 @@ __device__ void mergeShards(int num_shards, int k, Entry *__restrict__ entries, Entry *__restrict__ top_k_heap, - // T *top_k_values, - int *top_k_indices) { + float *top_k_values, + int *top_k_indices, + bool speculative_decoding) { // If k < num_shards, we can use a min-heap with k elements to get the top k // of the sorted blocks. 
// If k > num_shards, we can initialize a min-heap with the top element from @@ -314,7 +315,10 @@ __device__ void mergeShards(int num_shards, int const last_k = k - 1; for (int rank = 0; rank < last_k; rank++) { Entry const &max_element = max_heap.root(); - // top_k_values[rank] = max_element.value; + if (speculative_decoding) { + assert(top_k_values != nullptr); + top_k_values[rank] = static_cast(max_element.value); + } int shard_index = max_element.index; top_k_indices[rank] = entries[shard_index].index; int next_shard_index = shard_index + num_shards; @@ -338,8 +342,9 @@ __global__ void arg_topk_forward_kernel(T const *__restrict__ input, int length, int k, bool sorted, - // T *__restrict__ output, - int *__restrict__ indices) { + float *__restrict__ output, + int *__restrict__ indices, + bool speculative_decoding) { __shared__ char shared_memory[48 << 10]; int const batch_index = blockIdx.x; T const *batch_input = input + batch_index * length; @@ -351,15 +356,16 @@ __global__ void arg_topk_forward_kernel(T const *__restrict__ input, __syncthreads(); if (thread_index == 0) { int const offset = batch_index * k; - // auto batch_output = output + offset; + auto batch_output = output + offset; auto batch_indices = indices + offset; Entry *top_k_heap = shared_entries + thread_count * k; mergeShards(thread_count, k, shared_entries, top_k_heap, - // batch_output, - batch_indices); + batch_output, + batch_indices, + speculative_decoding); } } @@ -367,12 +373,13 @@ __global__ void arg_topk_forward_kernel(T const *__restrict__ input, template void ArgTopK::forward_kernel(ArgTopKMeta const *m, DT const *input_ptr, - // float *output_ptr, + float *output_ptr, int *indices_ptr, size_t batch_size, int length, int k, bool sorted, + BeamSearchBatchConfig const *bc, hipStream_t stream) { // Adopted from TensorFlow's ArgTopK implementation // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/topk_op_gpu.h @@ -391,28 +398,57 @@ void ArgTopK::forward_kernel(ArgTopKMeta const *m, size_t shared_memory_size = (num_shards + 1) * k * sizeof(Entry
); // size_t num_blocks = (batch_size + num_shards - 1) / num_shards; size_t num_blocks = batch_size; - assert(num_shards >= (size_t)k); - num_shards = k; - hipLaunchKernelGGL(arg_topk_forward_kernel, - num_blocks, - num_shards, - 0, - stream, - input_ptr, - shared_memory_size, - length, - k, - sorted, - // output_ptr, - indices_ptr); + // all requests are in the same beam stages + if (m->speculative_decoding) { + assert(bc->num_active_requests() >= 0); + + // check + int beam_size = -1; + for (int i = 1; i < bc->max_requests_per_batch(); i++) { + if (bc->request_completed[i]) { + continue; + } else if (beam_size == -1) { + beam_size = bc->beamRequestsInfo[i].beam_size; + } else { + assert(beam_size == bc->beamRequestsInfo[i].beam_size); + } + } + + assert(num_shards >= (size_t)beam_size); + num_shards = k; + arg_topk_forward_kernel<<>>( + input_ptr, + shared_memory_size, + length, + beam_size, + sorted, + output_ptr, + indices_ptr, + m->speculative_decoding); + } else { + + assert(num_shards >= (size_t)k); + num_shards = k; + arg_topk_forward_kernel<<>>( + input_ptr, + shared_memory_size, + length, + k, + sorted, + nullptr, + indices_ptr, + false); + } } /*static*/ void ArgTopK::forward_kernel_wrapper(ArgTopKMeta const *m, GenericTensorAccessorR const &input, + GenericTensorAccessorW const &probs, // float *output_ptr, GenericTensorAccessorW const &indices, - int batch_size) { + int batch_size, + BeamSearchBatchConfig const *bc) { hipStream_t stream; checkCUDA(get_legion_stream(&stream)); // Domain in1_domain = runtime->get_index_space_domain( @@ -457,21 +493,27 @@ void ArgTopK::forward_kernel_wrapper(ArgTopKMeta const *m, ArgTopK::forward_kernel(m, input.get_half_ptr(), // output_ptr, + m->speculative_decoding ? probs.get_float_ptr() + : nullptr, indices.get_int32_ptr(), batch_size, length, k, m->sorted, + m->speculative_decoding ? bc : nullptr, stream); } else if (input.data_type == DT_FLOAT) { ArgTopK::forward_kernel(m, input.get_float_ptr(), // output_ptr, + m->speculative_decoding ? probs.get_float_ptr() + : nullptr, indices.get_int32_ptr(), batch_size, length, k, m->sorted, + m->speculative_decoding ? 
bc : nullptr, stream); } else { assert(false && "Unsupported data type"); From d7e8d728b67557bebbf9f76de9b806575b8a4cc2 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Tue, 2 Jan 2024 13:54:29 -0500 Subject: [PATCH 23/30] fix corner case --- include/flexflow/batch_config.h | 14 ++- include/flexflow/config.h | 3 +- include/flexflow/model.h | 1 + .../inc_multihead_self_attention_utils.cuh | 2 +- .../ops/spec_inc_multihead_self_attention.h | 1 + .../ops/tree_inc_multihead_self_attention.h | 1 + include/flexflow/request_manager.h | 2 + inference/models/falcon.cc | 5 +- inference/models/llama.cc | 5 +- inference/models/mpt.cc | 5 +- inference/models/opt.cc | 5 +- inference/models/starcoder.cc | 5 +- src/ops/arg_topk.cu | 11 ++- src/ops/inc_multihead_self_attention.cu | 4 +- src/ops/spec_inc_multihead_self_attention.cu | 60 +++++++----- src/ops/tree_inc_multihead_self_attention.cu | 62 +++++++------ src/runtime/batch_config.cc | 6 ++ src/runtime/beam_search_batch_config.cc | 4 + src/runtime/model.cc | 14 +++ src/runtime/request_manager.cc | 93 +++++++++++-------- src/runtime/request_manager.cu | 28 +++++- 21 files changed, 225 insertions(+), 106 deletions(-) diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index 13904aaa46..ef17ef43ed 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -45,6 +45,7 @@ class BatchConfig { int num_active_tokens() const; static int max_requests_per_batch(); static int max_tokens_per_batch(); + static int max_verify_tokens_per_batch(); static int max_sequence_length(); friend std::ostream &operator<<(std::ostream &os, BatchConfig const &bc); void print() const; @@ -72,6 +73,7 @@ class BatchConfig { // request id in batch config: int batch_config_request_id; + bool prompt_phase = false; RequestGuid request_guid; }; struct PerTokenInfo { @@ -85,15 +87,15 @@ class BatchConfig { // how many tokens before the tree, every sub requests need this part of // cache - int non_tree_cache_size; + int non_tree_cache_size = 0; // current tree size - int tree_size; + int tree_size = 0; - int this_layer_size; + int this_layer_size = 0; // input length-> prompt/root - int prompt_size; + int prompt_size = 0; }; BitMask causalMask[MAX_NUM_REQUESTS]; @@ -145,9 +147,13 @@ class BeamSearchBatchConfig : public BatchConfig { bool done() const; int max_beam_depth_all_requests() const; int current_depth_all_requests() const; + int get_speculative_request_num() const; size_t beam_width; size_t target_iterations; + + // how many requests is in speculative phase + int speculative_request_num = 0; inline static int const MAX_BEAM_WIDTH = 3; inline static int const MAX_BEAM_DEPTH = 8; diff --git a/include/flexflow/config.h b/include/flexflow/config.h index e1480264cc..17a3f59e29 100644 --- a/include/flexflow/config.h +++ b/include/flexflow/config.h @@ -84,7 +84,8 @@ struct FFHandler { sizeof(BeamSearchBatchConfig::beamTokenInfo) + sizeof(BeamSearchBatchConfig::beamRequestsInfo) + sizeof(BatchConfig::causalMask) + - sizeof(TreeVerifyBatchConfig::committed_tokens); + sizeof(TreeVerifyBatchConfig::committed_tokens) + + sizeof(BatchConfig::request_completed); void *offload_reserve_space; size_t offload_reserve_space_size; DataType quantization_type; diff --git a/include/flexflow/model.h b/include/flexflow/model.h index cf7bb3dd2d..6f805e21bd 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -73,6 +73,7 @@ enum TaskIDs { DROPOUT_BWD_TASK_ID, EMBED_INIT_TASK_ID, EMBED_FWD_TASK_ID, + EMBED_INF_TASK_ID, EMBED_BWD_TASK_ID, 
GATHER_INIT_TASK_ID, GATHER_FWD_TASK_ID, diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh index c128c1a126..d1e0e050b2 100644 --- a/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh +++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_utils.cuh @@ -493,7 +493,7 @@ inline void smem_size_in_bytes_tree(int hidden_size_per_head, } // todo fix this - int max_qk_length = max_query_length * max_total_length; + int max_qk_length = max_query_length; // The amount of shared memory needed to store the Q*K^T values in float. size_t qk_sz = div_up(max_qk_length + 1, 4) * 16; diff --git a/include/flexflow/ops/spec_inc_multihead_self_attention.h b/include/flexflow/ops/spec_inc_multihead_self_attention.h index a306f7985a..a0d01092bf 100644 --- a/include/flexflow/ops/spec_inc_multihead_self_attention.h +++ b/include/flexflow/ops/spec_inc_multihead_self_attention.h @@ -142,6 +142,7 @@ class SpecIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta { Realm::RegionInstance beam_search_reserve_inst; BeamSearchBatchConfig::BeamSearchPerTokenInfo *beam_token_infos; BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos; + bool *request_completed; BatchConfig::BitMask *causalMask; }; diff --git a/include/flexflow/ops/tree_inc_multihead_self_attention.h b/include/flexflow/ops/tree_inc_multihead_self_attention.h index d160da4a72..02df0c0137 100644 --- a/include/flexflow/ops/tree_inc_multihead_self_attention.h +++ b/include/flexflow/ops/tree_inc_multihead_self_attention.h @@ -147,6 +147,7 @@ class TreeIncMultiHeadSelfAttentionMeta : public IncMultiHeadSelfAttentionMeta { int num_active_tokens; Realm::RegionInstance committed_token_reserve_inst; TreeVerifyBatchConfig::CommittedTokensInfo *committed_token_infos; + bool *request_completed; BatchConfig::BitMask *causalMask; }; diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index 1c4b0b2a2f..33714c106e 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -103,6 +103,7 @@ class RequestManager { int get_max_requests_per_batch(); void set_max_tokens_per_batch(int max_num_tokens); int get_max_tokens_per_batch(); + int get_max_verify_tokens_per_batch(); void set_max_sequence_length(int max_seq_length); void push_spec_infer_tree_width(int tree_width); int get_max_sequence_length(); @@ -113,6 +114,7 @@ class RequestManager { std::string const &path); void register_output_filepath(std::string const &); void initBitMask(BatchConfig::BitMask &bitmask, int initLength); + void appendPendingRequest(BatchConfig::BitMask &bitmask, int initLength); void appendBitMask(BatchConfig::BitMask &bitmask, int newNodes, int preBeamSize, diff --git a/inference/models/falcon.cc b/inference/models/falcon.cc index bfcec847b9..999ca37037 100644 --- a/inference/models/falcon.cc +++ b/inference/models/falcon.cc @@ -39,7 +39,10 @@ void FALCON::create_falcon_model(FFModel &ff, Tensor input; { // assert(falcon_config.max_num_tokens <= BatchConfig::MAX_NUM_TOKENS); - int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1}; + int const token_dims[] = {mode == TREE_VERIFY_MODE + ? 
BatchConfig::max_verify_tokens_per_batch() + : BatchConfig::max_tokens_per_batch(), + 1}; input = ff.create_tensor<2>(token_dims, DT_INT32); } diff --git a/inference/models/llama.cc b/inference/models/llama.cc index e9c84efe90..e54d6d8811 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -41,7 +41,10 @@ void LLAMA::create_llama_model(FFModel &ff, Tensor input; { - int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1}; + int const token_dims[] = {mode == TREE_VERIFY_MODE + ? BatchConfig::max_verify_tokens_per_batch() + : BatchConfig::max_tokens_per_batch(), + 1}; input = ff.create_tensor<2>(token_dims, DT_INT32); } diff --git a/inference/models/mpt.cc b/inference/models/mpt.cc index b074d332ed..3df67b264c 100644 --- a/inference/models/mpt.cc +++ b/inference/models/mpt.cc @@ -40,7 +40,10 @@ void MPT::create_mpt_model(FFModel &ff, //------------------------------ build the model -------------------------- Tensor input; { - int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1}; + int const token_dims[] = {mode == TREE_VERIFY_MODE + ? BatchConfig::max_verify_tokens_per_batch() + : BatchConfig::max_tokens_per_batch(), + 1}; input = ff.create_tensor<2>(token_dims, DT_INT32); } diff --git a/inference/models/opt.cc b/inference/models/opt.cc index 9b29ae5410..0279f83239 100644 --- a/inference/models/opt.cc +++ b/inference/models/opt.cc @@ -42,7 +42,10 @@ void OPT::create_opt_model(FFModel &ff, Tensor position_input; ff.set_position_offset(2); { - int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1}; + int const token_dims[] = {mode == TREE_VERIFY_MODE + ? BatchConfig::max_verify_tokens_per_batch() + : BatchConfig::max_tokens_per_batch(), + 1}; input = ff.create_tensor<2>(token_dims, DT_INT32); position_input = ff.create_tensor<2>(token_dims, DT_INT32); } diff --git a/inference/models/starcoder.cc b/inference/models/starcoder.cc index ba7b2cb43a..e683376e47 100644 --- a/inference/models/starcoder.cc +++ b/inference/models/starcoder.cc @@ -48,7 +48,10 @@ void STARCODER::create_starcoder_model( ff.set_position_offset(0); { // assert(startcoder_config.max_num_tokens <= BatchConfig::MAX_NUM_TOKENS); - int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1}; + int const token_dims[] = {mode == TREE_VERIFY_MODE + ? 
BatchConfig::max_verify_tokens_per_batch() + : BatchConfig::max_tokens_per_batch(), + 1}; input = ff.create_tensor<2>(token_dims, DT_INT32); position_input = ff.create_tensor<2>(token_dims, DT_INT32); } diff --git a/src/ops/arg_topk.cu b/src/ops/arg_topk.cu index 0b8bb8b563..3302178728 100644 --- a/src/ops/arg_topk.cu +++ b/src/ops/arg_topk.cu @@ -405,13 +405,20 @@ void ArgTopK::forward_kernel(ArgTopKMeta const *m, // check int beam_size = -1; - for (int i = 1; i < bc->max_requests_per_batch(); i++) { + + // allow last request different with others + int num_activate_requests = bc->num_active_requests(); + int last_request_idx = + bc->requestsInfo[num_activate_requests - 1].batch_config_request_id; + for (int i = 0; i < bc->max_requests_per_batch(); i++) { if (bc->request_completed[i]) { continue; } else if (beam_size == -1) { beam_size = bc->beamRequestsInfo[i].beam_size; - } else { + + } else if (i != last_request_idx) { assert(beam_size == bc->beamRequestsInfo[i].beam_size); + } else if (i == last_request_idx) { } } diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index db64868cb9..7c8601d3c8 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -1349,7 +1349,9 @@ IncMultiHeadSelfAttentionMeta::IncMultiHeadSelfAttentionMeta( // allocate memory for the seqArray and reserve space { - int max_tokens_per_batch = BatchConfig::max_tokens_per_batch(); + int max_tokens_per_batch = infer_mode == TREE_VERIFY_MODE + ? BatchConfig::max_verify_tokens_per_batch() + : BatchConfig::max_tokens_per_batch(); size_t qkv_max_proj_size = max_tokens_per_batch * (qProjSize * num_q_heads + kProjSize * num_q_heads + vProjSize * num_q_heads); diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index 88dd3f92e4..b31e5d0994 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -50,7 +50,8 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel( int hidden_size, BatchConfig::PerRequestInfo *request_infos, BeamSearchBatchConfig::BeamSearchPerRequestInfo *beam_request_infos, - BatchConfig::BitMask *causalMask) { + BatchConfig::BitMask *causalMask, + bool *request_completed) { // q, k using Q_vec = typename VEC_K::Type; @@ -86,11 +87,12 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel( // request_infos[batch_config_request_id].first_token_depth_in_request + // request_infos[batch_config_request_id].num_tokens_in_batch; - int const totalCacheSize = bitmask.non_tree_cache_size + bitmask.tree_size; + int const totalCacheSize = + bitmask.non_tree_cache_size + bitmask.tree_size + bitmask.prompt_size - 1; int first_token_idx = 0; - for (int r = 0; r < request_idx; r++) { - first_token_idx += causalMask[r].this_layer_size; + for (int r = 0; r < batch_config_request_id; r++) { + first_token_idx += request_completed[r] ? 
0 : causalMask[r].this_layer_size; } int const tree_branch_num = @@ -138,7 +140,8 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel( ii * THREADS_PER_KEY * K_VEC_SIZE); } - int const query_token = bitmask.tree_size - tree_branch_num + qi; + int const query_token = + bitmask.prompt_size + bitmask.tree_size - 1 - tree_branch_num + qi; __syncthreads(); for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { @@ -163,8 +166,12 @@ __global__ void compute_spec_inc_attention_kernel_generation_kernel( (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << query_token)))); - // if (blockIdx.y == 0 && blockIdx.x == 0 && !mask) { - // printf("spec inc attn qkqkqk %d, %.10f, %d\n", ti, qk, qi); + // if (head_idx == 0 && ti == 0 && request_idx == 15 && !mask) { + // printf("spec inc attn qkqkqk request id %d, %.10f, %d\n", + // batch_config_request_id, + // ti, + // qk, + // qi); // } qk_max = mask ? qk_max : fmaxf(qk_max, qk); qk_smem[ti - first_step] = mask ? 0.f : qk; @@ -336,17 +343,12 @@ __global__ void spec_inc_store_kv_cache( BatchConfig::BitMask bitmask = causalMask[req_id]; - // int const tree_branch_num = beamRequestInfos[req_id].sub_request_num; - - // int const query_token = bitmask.non_tree_cache_size + bitmask.tree_size - - // tree_branch_num + sub_req_id + tok_id; - // bitmask.tree_size - tree_branch_num + sub_req_id; - // if prompt token -> token id // if tree token: - int const cache_idx = bitmask.non_tree_cache_size + bitmask.tree_size - - bitmask.this_layer_size + token_idx - - request_token_offset; + + int const cache_idx = bitmask.prompt_size + bitmask.non_tree_cache_size + + bitmask.tree_size - 1 - bitmask.this_layer_size + + token_idx - request_token_offset; kCache_ptr[req_id * (hidden_size * max_seq_len) + (cache_idx)*hidden_size + offset] = kVal; @@ -411,7 +413,8 @@ void update_kv_cache_kernel(SpecIncMultiHeadSelfAttentionMeta const *m, m->hidden_size, \ m->request_infos, \ m->beam_request_infos, \ - m->causalMask) + m->causalMask, \ + m->request_completed) template void compute_spec_inc_attention_kernel_generation( @@ -420,7 +423,8 @@ void compute_spec_inc_attention_kernel_generation( DT *output_ptr, cudaStream_t stream) { // one block == one head per request - dim3 grid(m->num_q_heads, bc->num_active_requests()); + // how many generation requests + dim3 grid(m->num_q_heads, bc->get_speculative_request_num()); int const per_head_size = m->qProjSize; float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f; size_t smem_sz; @@ -499,11 +503,10 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, for (int i = 0; i < bc->max_requests_per_batch(); i++) { if (bc->request_completed[i]) { continue; + } else if (tokens_previous_requests < bc->num_generation_tokens) { + tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; + continue; } - // else if (tokens_previous_requests < bc->num_generation_tokens) { - // tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; - // continue; - // } // all requests in prompt phase should only have one sub requests; assert(bc->sub_requests[i] == 1); @@ -659,10 +662,10 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, // To get C, skip over softmax(QK^T/sqrt(d_k))V products from previous // requests - // print_tensor((float*)C_softmax, 32, "C_softmax"); + int token_offset = bc->requestsInfo[i].first_token_offset_in_batch; + C = static_cast
(m->attn_heads) + - (tokens_previous_requests + bc->num_generation_tokens) * - m->num_q_heads * m->vProjSize; + (token_offset)*m->num_q_heads * m->vProjSize; checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, CUBLAS_OP_N, CUBLAS_OP_T, @@ -860,6 +863,13 @@ SpecIncMultiHeadSelfAttentionMeta::SpecIncMultiHeadSelfAttentionMeta( sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + sizeof(BeamSearchBatchConfig::beamTokenInfo) + sizeof(BeamSearchBatchConfig::beamRequestsInfo)); + + request_completed = reinterpret_cast( + reinterpret_cast(handler.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + + sizeof(BeamSearchBatchConfig::beamTokenInfo) + + sizeof(BeamSearchBatchConfig::beamRequestsInfo) + + sizeof(BatchConfig::causalMask)); } cudaStreamSynchronize(stream); diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index b4af80976f..fc86e1498e 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -54,6 +54,7 @@ __global__ void compute_attention_kernel_fused_kernel( int num_heads, int num_requests, BatchConfig::BitMask *causalMask, + bool *request_completed, int qk_smem_sz) { // q, k @@ -90,13 +91,14 @@ __global__ void compute_attention_kernel_fused_kernel( BatchConfig::BitMask bitmask = causalMask[batch_config_request_id]; int first_token_idx = 0; - for (int r = 0; r < request_idx; r++) { - first_token_idx += request_infos[r].num_tokens_in_batch; + for (int r = 0; r < batch_config_request_id; r++) { + first_token_idx += + request_completed[r] ? 0 : request_infos[r].num_tokens_in_batch; } - // if(tidx == 0 && head_idx == 0){ - // printf("tree req: %d, %d\n", request_idx, first_token_idx); - // } + bool prompt_phase = request_infos[batch_config_request_id].prompt_phase; + int q_start = + request_infos[batch_config_request_id].first_token_depth_in_request; // shared memory objects extern __shared__ char smem_[]; @@ -139,7 +141,7 @@ __global__ void compute_attention_kernel_fused_kernel( q_ptr + (hidden_size * QKV_WEIGHT_NUM * qi) + ki + ii * THREADS_PER_KEY * K_VEC_SIZE); - // if (head_idx == 0 && qi == 1 && tidx == 0) { + // if (head_idx == 0 && request_idx == 1 && tidx == 0) { // printf("laod q %d, %d %.10f\n", // request_idx, // qi,q_vecs[ki_o][ii].x); @@ -163,19 +165,23 @@ __global__ void compute_attention_kernel_fused_kernel( if (ti < tlength && tidx % THREADS_PER_KEY == 0) { bool const mask = - (ti >= bitmask.non_tree_cache_size && - (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); + prompt_phase ? (qi + q_start < ti) + : (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << qi)))); qk_max = mask ? qk_max : fmaxf(qk_max, qk); - // if (head_idx == 0 && qi == 0 && !mask) { - // printf("tree attn qkqkqkqk request id %d, %d %.10f, %.10f, %.10f\n - // ", + // if (head_idx == 0 && !mask) { + // printf("tree attn qkqkqkqk request id %d qi%d, ti %d, %.10f, %.10f, + // %.10f, %d\n", // request_idx, + // qi, // ti, // qk, // q_vecs[ki_o][0].x, - // k[0].x); + // k[0].x, + // bitmask.non_tree_cache_size); // } qk_smem[ti - first_step] = mask ? 0.0f : qk; } @@ -217,8 +223,10 @@ __global__ void compute_attention_kernel_fused_kernel( float exp_sum = 0.f; for (int ti = first_step + tidx; ti < tlength; ti += THREADS_PER_BLOCK) { bool const mask = - (ti >= bitmask.non_tree_cache_size && - (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); + prompt_phase ? 
(q_start + qi < ti) + : (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << qi)))); float logit = mask ? 0.0f : __expf(qk_smem[ti - first_step] - qk_max); exp_sum += logit; qk_smem[ti - first_step] = mask ? 0.0f : logit; @@ -265,8 +273,11 @@ __global__ void compute_attention_kernel_fused_kernel( if (ti < tlength) { bool const mask = - (ti >= bitmask.non_tree_cache_size && - (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & (1 << qi)))); + prompt_phase + ? (q_start + qi < ti) + : (ti >= bitmask.non_tree_cache_size && + (!(bitmask.mask[ti - bitmask.non_tree_cache_size] & + (1 << qi)))); float logit = mask ? 0.0f : qk_smem[ti - first_step]; out = FlexFlow::fma(logit, cast_to_float(v), out); } @@ -810,6 +821,7 @@ void compute_attention_kernel(TreeIncMultiHeadSelfAttentionMeta const *m, m->num_q_heads, \ bc->num_active_requests(), \ m->causalMask, \ + m->request_completed, \ smem_sz[0]) template @@ -841,7 +853,6 @@ void compute_attention_kernel_fused(TreeIncMultiHeadSelfAttentionMeta const *m, dim3 grid(m->num_q_heads, bc->num_active_requests()); int const per_head_size = m->qProjSize; float scale = (*m->qk_prod_scaling) ? 1.0f / sqrt(m->kProjSize) : 1.0f; - // 0->qk production size, 1->total shared size int smem_sz[2]; if (per_head_size == 64) { @@ -890,17 +901,6 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, // std::cout << "tokens to be committed: " << bc->num_tokens_to_commit << // "\n"; - cudaMemcpyAsync(m->committed_token_infos, - &(bc->committed_tokens), - bc->num_tokens_to_commit * - sizeof(TreeVerifyBatchConfig::CommittedTokensInfo), - cudaMemcpyHostToDevice, - stream); - cudaMemcpyAsync(m->causalMask, - &(bc->causalMask), - bc->num_active_requests() * sizeof(BatchConfig::BitMask), - cudaMemcpyHostToDevice, - stream); commit_tokens
(m, bc, stream); // After commit we update m->num_active_tokens to be the number of active @@ -1068,6 +1068,12 @@ TreeIncMultiHeadSelfAttentionMeta::TreeIncMultiHeadSelfAttentionMeta( sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + sizeof(BatchConfig::causalMask)); + + request_completed = reinterpret_cast( + reinterpret_cast(handler.batch_config_metadata) + + sizeof(BatchConfig::tokensInfo) + sizeof(BatchConfig::requestsInfo) + + sizeof(BatchConfig::causalMask) + + sizeof(TreeVerifyBatchConfig::committed_tokens)); } cudaStreamSynchronize(stream); diff --git a/src/runtime/batch_config.cc b/src/runtime/batch_config.cc index d2fbc0883f..c432208eca 100644 --- a/src/runtime/batch_config.cc +++ b/src/runtime/batch_config.cc @@ -84,6 +84,12 @@ int BatchConfig::max_tokens_per_batch() { return RequestManager::get_request_manager()->get_max_tokens_per_batch(); } +/*static*/ +int BatchConfig::max_verify_tokens_per_batch() { + return RequestManager::get_request_manager() + ->get_max_verify_tokens_per_batch(); +} + /*static*/ int BatchConfig::max_sequence_length() { return RequestManager::get_request_manager()->get_max_sequence_length(); diff --git a/src/runtime/beam_search_batch_config.cc b/src/runtime/beam_search_batch_config.cc index 74843e9460..ff7bf1a819 100644 --- a/src/runtime/beam_search_batch_config.cc +++ b/src/runtime/beam_search_batch_config.cc @@ -85,6 +85,10 @@ int BeamSearchBatchConfig::max_beam_depth_all_requests() const { return max_depth_all_requests; } +int BeamSearchBatchConfig::get_speculative_request_num() const { + return speculative_request_num; +} + int BeamSearchBatchConfig::current_depth_all_requests() const { int current_depth = 0; for (int i = 0; i < BeamSearchBatchConfig::max_requests_per_batch(); i++) { diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 32b524f643..76bed36bda 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -4805,6 +4805,20 @@ void register_flexflow_internal_tasks(Runtime *runtime, runtime->register_task_variant(registrar); } } + { + TaskVariantRegistrar registrar(EMBED_INF_TASK_ID, "Embedding Inference"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + if (pre_register) { + Runtime::preregister_task_variant( + registrar, "Embedding Inference Task"); + } else { + if (enable_control_replication) { + registrar.global_registration = false; + } + runtime->register_task_variant(registrar); + } + } { TaskVariantRegistrar registrar(EMBED_BWD_TASK_ID, "Embedding Backward"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 89d4ddaed4..88754f5a82 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -97,6 +97,12 @@ int RequestManager::get_max_tokens_per_batch() { return max_tokens_per_batch; } +int RequestManager::get_max_verify_tokens_per_batch() { + assert(max_tokens_per_batch > 0); + return max_tokens_per_batch + + BatchConfig::MAX_SPEC_TREE_TOKEN_NUM * max_requests_per_batch; +} + void RequestManager::set_max_sequence_length(int max_seq_length) { assert(max_sequence_length == -1 || max_sequence_length == max_seq_length); max_sequence_length = max_seq_length; @@ -1126,7 +1132,6 @@ BeamSearchBatchConfig old_bc.beamRequestsInfo[i].sub_request_num, tree, old_bc.beamRequestsInfo[i].current_depth); - // assert(false); for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { int depth = 
new_bc.requestsInfo[i].first_token_depth_in_request + j; for (int k = 0; k < new_bc.beamRequestsInfo[i].sub_request_num; k++) { @@ -1146,6 +1151,9 @@ BeamSearchBatchConfig } } + // how many requests is in speculative phase + new_bc.speculative_request_num = num_active_req + 1; + // Add prompt tokens to the batch for (int i = 0; i < BatchConfig::max_requests_per_batch(); i++) { if (old_bc.request_completed[i] || old_bc.request_running[i]) { @@ -1184,13 +1192,14 @@ BeamSearchBatchConfig spec_infer_tree_width.size() > ssm_decoding_steps ? spec_infer_tree_width[ssm_decoding_steps] : 1; - printf("beam size: %d, %d\n", - new_bc.beamRequestsInfo[i].beam_size, - ssm_decoding_steps); + // printf("beam size: %d, %d\n", + // new_bc.beamRequestsInfo[i].beam_size, + // ssm_decoding_steps); new_bc.beamRequestsInfo[i].max_depth = old_bc.beamRequestsInfo[i].max_depth; - new_bc.sub_requests[i] = - old_bc.sub_requests[i] * new_bc.beamRequestsInfo[i].beam_size; + // new_bc.sub_requests[i] = + // old_bc.sub_requests[i] * new_bc.beamRequestsInfo[i].beam_size; + new_bc.sub_requests[i] = 1; new_bc.beamRequestsInfo[i].sub_request_num = old_bc.beamRequestsInfo[i].sub_request_num; @@ -1218,6 +1227,9 @@ BeamSearchBatchConfig request.tokens.size()) { // request is done new_bc.requestsInfo[i].num_tokens_in_batch = 0; + new_bc.causalMask[i].this_layer_size = 0; + new_bc.beamRequestsInfo[i].sub_request_num = 0; + new_bc.beamRequestsInfo[i].beam_size = 1; } else { // Prompt phase new_bc.requestsInfo[i].num_tokens_in_batch = @@ -1227,12 +1239,8 @@ BeamSearchBatchConfig new_bc.requestsInfo[i].first_token_depth_in_request); request.ssm_cache_size += new_bc.requestsInfo[i].num_tokens_in_batch; BeamTree tree = request.beam_trees[old_bc.model_id]; - appendBitMask(new_bc.causalMask[i], - new_bc.beamRequestsInfo[i].sub_request_num, - old_bc.beamRequestsInfo[i].beam_size, - old_bc.beamRequestsInfo[i].sub_request_num, - tree, - old_bc.beamRequestsInfo[i].current_depth); + appendPendingRequest(new_bc.causalMask[i], + new_bc.requestsInfo[i].num_tokens_in_batch); } if (verbose) { @@ -1258,11 +1266,11 @@ BeamSearchBatchConfig // get value from requestinfo new_bc.tokensInfo[new_bc.num_tokens].token_id = - request.tokens[request.tokens.size() - 1]; + request.tokens[request.tokens.size() - + new_bc.requestsInfo[i].num_tokens_in_batch + j]; new_bc.beamTokenInfo[new_bc.num_tokens].sub_request_index = k; new_bc.num_tokens++; - num_generation_tokens++; } } } @@ -1319,7 +1327,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.num_tokens_to_commit = 0; new_bc.num_tokens = 0; - int max_prompt_load_size = get_max_tokens_per_batch(); + int max_prompt_load_size = get_max_verify_tokens_per_batch(); for (int i = 0; i < TreeVerifyBatchConfig::max_requests_per_batch(); i++) { if (old_batches.at(0).request_completed[i]) { continue; @@ -1427,7 +1435,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.num_tokens++; new_bc.requestsInfo[i].num_tokens_in_batch++; - if (new_bc.num_tokens > get_max_tokens_per_batch()) { + if (new_bc.num_tokens > get_max_verify_tokens_per_batch()) { assert(false && "Exceeding the space available in the TreeVerify batch"); break; @@ -1453,7 +1461,7 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.num_tokens++; new_bc.requestsInfo[i].num_tokens_in_batch++; - if (new_bc.num_tokens == get_max_tokens_per_batch() && + if (new_bc.num_tokens == get_max_verify_tokens_per_batch() && (j != dfs_tree_inputs.size() - 1)) { cutLayer = true; break; @@ -1542,7 +1550,7 
@@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.num_tokens++; } - if (new_bc.num_tokens > get_max_tokens_per_batch()) { + if (new_bc.num_tokens > get_max_verify_tokens_per_batch()) { assert(false && "Exceeding the space available in the TreeVerify batch"); break; @@ -1555,15 +1563,17 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( request.status = Request::RUNNING; new_bc.request_running[i] = true; - std::cout << "new_bc.requestsInfo[i].num_tokens_in_batch: " - << new_bc.requestsInfo[i].num_tokens_in_batch << std::endl; + // std::cout << "new_bc.requestsInfo[i].num_tokens_in_batch: " + // << new_bc.requestsInfo[i].num_tokens_in_batch << + // std::endl; + new_bc.requestsInfo[i].prompt_phase = true; dfs_tree_inputs[guid] = std::vector>{std::make_pair( request.tokens.back(), request.tokens.size() - 1)}; } } else { // launch the request into running phase after loading all prompt - if (get_max_tokens_per_batch() - new_bc.num_tokens > 0) { + if (get_max_verify_tokens_per_batch() - new_bc.num_tokens > 0) { // std::cout << "Initialization running phase: " // << new_bc.requestsInfo[i].num_tokens_in_batch << "\n"; request.status = Request::RUNNING; @@ -1576,9 +1586,11 @@ TreeVerifyBatchConfig RequestManager::prepare_next_batch_verify( new_bc.num_tokens++; new_bc.requestsInfo[i].num_tokens_in_batch++; - std::cout << "new_bc.requestsInfo[i].num_tokens_in_batch: " - << new_bc.requestsInfo[i].num_tokens_in_batch << std::endl; + // std::cout << "new_bc.requestsInfo[i].num_tokens_in_batch2: " + // << new_bc.requestsInfo[i].num_tokens_in_batch << + // std::endl; + new_bc.requestsInfo[i].prompt_phase = true; dfs_tree_inputs[guid] = std::vector>{std::make_pair( request.tokens.back(), request.tokens.size() - 1)}; @@ -1760,20 +1772,14 @@ void RequestManager::update_beam_metadata(BeamSearchBatchConfig &new_bc, // prompt phase, init task void RequestManager::initBitMask(BatchConfig::BitMask &bitmask, int initLength) { - assert(initLength <= BatchConfig::MAX_SPEC_TREE_TOKEN_NUM && - "do not support tree size > 64"); + assert(initLength > 0); // eg. 4 tokens: t1: 0000000..1111, t2: 0000000..1110, t3: 0000000..1100, t4: // 0000000..1000 bitmask.non_tree_cache_size = 0; - bitmask.tree_size = initLength; + bitmask.tree_size = 1; bitmask.prompt_size = initLength; bitmask.this_layer_size = initLength; - for (int i = 0; i < bitmask.prompt_size; i++) { - for (int j = i; j < bitmask.prompt_size; j++) { - bitmask.mask[i] |= (1 << j); - } - } // std::cout << "see bit mask" << bitmask.prompt_size << "\n"; // std::cout << "see bit mask" << std::bitset<64>(bitmask.mask[0]) << "\n"; // std::cout << "see bit mask" << std::bitset<64>(bitmask.mask[1]) << "\n"; @@ -1810,6 +1816,25 @@ void RequestManager::updateBitMask(BatchConfig::BitMask &bitmask, // << "\n"; } +// prompt phase, init task +void RequestManager::appendPendingRequest(BatchConfig::BitMask &bitmask, + int initLength) { + assert(initLength > 0); + std::cout << "append pending bit mask: " << initLength << "\n"; + // eg. 
4 tokens: t1: 0000000..1111, t2: 0000000..1110, t3: 0000000..1100, t4: + // 0000000..1000 + bitmask.non_tree_cache_size = 0; + bitmask.tree_size = 1; + bitmask.prompt_size += initLength; + bitmask.this_layer_size = initLength; + + // for (int i = 0; i < bitmask.prompt_size; i++) { + // for (int j = i; j < bitmask.prompt_size; j++) { + // bitmask.mask[i] |= (1 << j); + // } + // } +} + // prepare next beam, append layers to the tree void RequestManager::appendBitMask(BatchConfig::BitMask &bitmask, int newNodes, @@ -1862,12 +1887,6 @@ void RequestManager::appendBitMask(BatchConfig::BitMask &bitmask, } } - // std::cout << "token idx: " << token_idx << ", " << pre_tree_size << ", " - // << new_nodes_start_idx << ", " << newNodes - // << "current depth: " << currentDepth << "\n"; - // std::cout << "new nodes end " << new_nodes_start_idx << "\n"; - - // std::cout << "tree size: " << bitmask.tree_size << "\n"; assert(token_idx == pre_tree_size); assert(currentDepth <= 1 || new_nodes_start_idx == bitmask.tree_size); diff --git a/src/runtime/request_manager.cu b/src/runtime/request_manager.cu index 51c52c3026..8380d6be73 100644 --- a/src/runtime/request_manager.cu +++ b/src/runtime/request_manager.cu @@ -35,10 +35,17 @@ void RequestManager::load_tokens_task( // Extreme long prompts are not supported, only load up to // BatchConfig::max_tokens_per_batch() as prompt - if (batch_config->num_tokens > BatchConfig::max_tokens_per_batch()) { + if (batch_config->num_tokens > BatchConfig::max_tokens_per_batch() && + batch_config->get_mode() == INC_DECODING_MODE) { printf("Warning: too many tokens in prompt, only load up to %d tokens\n", BatchConfig::max_tokens_per_batch()); printf("Got: %d tokens\n", batch_config->num_tokens); + } else if (batch_config->num_tokens > + BatchConfig::max_verify_tokens_per_batch()) { + printf("Warning: Speculative decoding. 
too many tokens in prompt, only " + "load up to %d tokens\n", + BatchConfig::max_verify_tokens_per_batch()); + printf("Got: %d tokens\n", batch_config->num_tokens); } for (int i = 0; i < batch_config->num_tokens; i++) { @@ -117,8 +124,16 @@ void RequestManager::load_batch_config_task( sizeof(BatchConfig::causalMask), cudaMemcpyHostToDevice, stream)); - total_copy_size += sizeof(BatchConfig::causalMask); + + checkCUDA(cudaMemcpyAsync( + static_cast(handle.batch_config_metadata) + total_copy_size, + &(batch_config->request_completed), + sizeof(BatchConfig::request_completed), + cudaMemcpyHostToDevice, + stream)); + + total_copy_size += sizeof(BatchConfig::request_completed); } else if (batch_config->get_mode() == TREE_VERIFY_MODE) { TreeVerifyBatchConfig const *tree_batch_config = static_cast(batch_config); @@ -137,6 +152,15 @@ void RequestManager::load_batch_config_task( cudaMemcpyHostToDevice, stream)); total_copy_size += sizeof(TreeVerifyBatchConfig::committed_tokens); + + checkCUDA(cudaMemcpyAsync( + static_cast(handle.batch_config_metadata) + total_copy_size, + &(batch_config->request_completed), + sizeof(BatchConfig::request_completed), + cudaMemcpyHostToDevice, + stream)); + + total_copy_size += sizeof(BatchConfig::request_completed); } // add a size check From 8490e50d5744b6731df9fdc4147b2a6ebd4f2d71 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Tue, 2 Jan 2024 16:20:24 -0500 Subject: [PATCH 24/30] fix --- src/runtime/request_manager.cc | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 88754f5a82..a285932b7f 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -1188,10 +1188,7 @@ BeamSearchBatchConfig int ssm_decoding_steps = profiling_requests[request.guid].ssm_decoding_steps; - new_bc.beamRequestsInfo[i].beam_size = - spec_infer_tree_width.size() > ssm_decoding_steps - ? spec_infer_tree_width[ssm_decoding_steps] - : 1; + new_bc.beamRequestsInfo[i].beam_size = 1; // printf("beam size: %d, %d\n", // new_bc.beamRequestsInfo[i].beam_size, // ssm_decoding_steps); @@ -1820,7 +1817,7 @@ void RequestManager::updateBitMask(BatchConfig::BitMask &bitmask, void RequestManager::appendPendingRequest(BatchConfig::BitMask &bitmask, int initLength) { assert(initLength > 0); - std::cout << "append pending bit mask: " << initLength << "\n"; + // std::cout << "append pending bit mask: " << initLength << "\n"; // eg. 
4 tokens: t1: 0000000..1111, t2: 0000000..1110, t3: 0000000..1100, t4: // 0000000..1000 bitmask.non_tree_cache_size = 0; From c12f0c6ddaea6629214278167b047ffa3158b491 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Wed, 3 Jan 2024 00:28:15 -0500 Subject: [PATCH 25/30] fix request id issue --- src/ops/inc_multihead_self_attention.cu | 42 +++++--------------- src/ops/spec_inc_multihead_self_attention.cu | 8 ++-- src/runtime/request_manager.cc | 6 +++ 3 files changed, 20 insertions(+), 36 deletions(-) diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 7c8601d3c8..42933cee27 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -52,9 +52,7 @@ __global__ void compute_attention_kernel_generation_kernel( int max_seq_length, int per_head_size, int hidden_size, - BatchConfig::PerRequestInfo *request_infos, - bool is_beam, - int max_beam_width) { + BatchConfig::PerRequestInfo *request_infos) { // q, k using Q_vec = typename VEC_K::Type; @@ -85,10 +83,6 @@ __global__ void compute_attention_kernel_generation_kernel( int const batch_config_request_id = request_infos[request_idx].batch_config_request_id; - int const beam_request_idx = - is_beam ? request_idx / max_beam_width : request_idx; - int const beam_sub_request_idx = is_beam ? request_idx % max_beam_width : 0; - int const first_step = 0; int const tlength = @@ -106,8 +100,7 @@ __global__ void compute_attention_kernel_generation_kernel( // first WARPS_PER_BLOCK for store qk_max, second WARPS_PER_BLOCK for sum __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - const DT *q_ptr = query + - batch_config_request_id * hidden_size * QKV_WEIGHT_NUM + + const DT *q_ptr = query + request_idx * hidden_size * QKV_WEIGHT_NUM + head_idx * per_head_size; __shared__ Q_vec q_vecs[THREADS_PER_KEY][K_VECS_PER_THREAD]; // DT const *q_ptr = @@ -142,10 +135,7 @@ __global__ void compute_attention_kernel_generation_kernel( constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; DT const *k_cache_batch = - key_cache + - (batch_config_request_id * max_beam_width + beam_sub_request_idx) * - max_seq_length * hidden_size + - ki; + key_cache + batch_config_request_id * max_seq_length * hidden_size + ki; int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; @@ -248,10 +238,7 @@ __global__ void compute_attention_kernel_generation_kernel( // The base pointer for the value in the cache buffer. DT const *v_cache_batch = - value_cache + - (batch_config_request_id * max_beam_width + beam_sub_request_idx) * - max_seq_length * hidden_size + - vi; + value_cache + batch_config_request_id * max_seq_length * hidden_size + vi; if (Dh == Dh_MAX || vi < Dh) { for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { @@ -297,7 +284,7 @@ __global__ void compute_attention_kernel_generation_kernel( // Output the final values. 
if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { convert_from_float( - *reinterpret_cast(output_ptr + beam_request_idx * hidden_size + + *reinterpret_cast(output_ptr + request_idx * hidden_size + head_idx * per_head_size + vi), out); } @@ -727,9 +714,7 @@ void compute_o_prod_bias(IncMultiHeadSelfAttentionMeta const *m, BatchConfig::max_sequence_length(), \ m->qProjSize, \ m->hidden_size, \ - m->request_infos, \ - false, \ - 0) + m->request_infos) template void compute_attention_kernel_generation(IncMultiHeadSelfAttentionMeta const *m, @@ -944,14 +929,9 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m, assert(m->qProjSize == m->kProjSize); for (int i = 0; i < bc->max_requests_per_batch(); i++) { - if (bc->request_completed[i]) { - continue; - } else if (tokens_previous_requests < bc->num_generation_tokens) { - tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; + if (bc->request_completed[i] || (!bc->requestsInfo[i].prompt_phase)) { continue; } - assert(tokens_previous_requests == - bc->requestsInfo[i].first_token_offset_in_batch); int num_new_tokens = bc->requestsInfo[i].num_tokens_in_batch; int total_tokens = bc->requestsInfo[i].first_token_depth_in_request + bc->requestsInfo[i].num_tokens_in_batch; @@ -978,8 +958,8 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m, // matrix A's layout: [qProjSize, num_heads, 3, num_new_tokens] // To get query projection, skip over Q entries from previous requests DT const *A = static_cast
(m->devQKVProjArray) + - tokens_previous_requests * m->qProjSize * m->num_q_heads * - QKV_WEIGHT_NUM; + bc->requestsInfo[i].first_token_offset_in_batch * + m->qProjSize * m->num_q_heads * QKV_WEIGHT_NUM; // matrix B: key cache // matrix B's layout: [kProjSize * num_heads, total_tokens] // To get B, skip over K entries from previous requests (all heads + @@ -1117,7 +1097,7 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m, // requests // store the result attn heads, also skip the genration tokens DT *C = static_cast
(m->attn_heads) + - (tokens_previous_requests + bc->num_generation_tokens) * + (bc->requestsInfo[i].first_token_offset_in_batch) * m->num_q_heads * m->vProjSize; checkCUDA(cublasGemmStridedBatchedEx(m->handle.blas, CUBLAS_OP_N, @@ -1145,7 +1125,7 @@ void compute_attention_kernel_prompt(IncMultiHeadSelfAttentionMeta const *m, } tokens_previous_requests += num_new_tokens; } - assert(tokens_previous_requests == num_tokens); + assert(tokens_previous_requests == (num_tokens - bc->num_generation_tokens)); } /*static*/ diff --git a/src/ops/spec_inc_multihead_self_attention.cu b/src/ops/spec_inc_multihead_self_attention.cu index b31e5d0994..a63417de51 100644 --- a/src/ops/spec_inc_multihead_self_attention.cu +++ b/src/ops/spec_inc_multihead_self_attention.cu @@ -501,10 +501,8 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, assert(m->qProjSize == m->kProjSize); for (int i = 0; i < bc->max_requests_per_batch(); i++) { - if (bc->request_completed[i]) { - continue; - } else if (tokens_previous_requests < bc->num_generation_tokens) { - tokens_previous_requests += bc->requestsInfo[i].num_tokens_in_batch; + if (bc->request_completed[i] || (!bc->requestsInfo[i].prompt_phase) || + (bc->requestsInfo[i].num_tokens_in_batch == 0)) { continue; } @@ -694,7 +692,7 @@ void compute_attention_kernel_prompt(SpecIncMultiHeadSelfAttentionMeta const *m, tokens_prev_requests_squares += num_new_tokens * total_tokens; } - // assert(tokens_previous_requests == num_tokens); + assert(tokens_previous_requests == (num_tokens - bc->num_generation_tokens)); } template diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index a285932b7f..c867d2a979 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -468,12 +468,14 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, // Incremental phase new_bc.requestsInfo[i].num_tokens_in_batch = 1; num_generation_tokens++; + new_bc.requestsInfo[i].prompt_phase = false; } else { // Prompt phase new_bc.requestsInfo[i].num_tokens_in_batch = std::min(get_max_tokens_per_batch() - new_bc.num_tokens, (int)request.tokens.size() - new_bc.requestsInfo[i].first_token_depth_in_request); + new_bc.requestsInfo[i].prompt_phase = true; } for (int j = 0; j < new_bc.requestsInfo[i].num_tokens_in_batch; j++) { int depth = new_bc.requestsInfo[i].first_token_depth_in_request + j; @@ -509,6 +511,7 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, new_bc.requestsInfo[i].max_sequence_length = new_request.max_sequence_length; new_bc.request_completed[i] = false; + new_bc.requestsInfo[i].prompt_phase = true; num_active_req++; new_bc.requestsInfo[num_active_req].batch_config_request_id = i; // add profile_info for the new request @@ -755,6 +758,7 @@ BeamSearchBatchConfig new_bc.beamRequestsInfo[i].current_depth = 1; profiling_requests[request.guid].ssm_decoding_steps = 0; + new_bc.requestsInfo[i].prompt_phase = true; int ssm_decoding_steps = 0; new_bc.beamRequestsInfo[i].beam_size = @@ -902,6 +906,7 @@ BeamSearchBatchConfig } new_bc.request_completed[i] = false; + new_bc.requestsInfo[i].prompt_phase = true; new_bc.beamRequestsInfo[i].sub_request_num = 1; printf("sub request num == 1, %d \n", @@ -1220,6 +1225,7 @@ BeamSearchBatchConfig &old_bc.causalMask[i], sizeof(BatchConfig::BitMask)); + new_bc.requestsInfo[i].prompt_phase = true; if (new_bc.requestsInfo[i].first_token_depth_in_request >= request.tokens.size()) { // request is done From e17fb8d923b38221d3ab8ba52677505c2c4a9f93 
Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Wed, 3 Jan 2024 23:32:45 -0500 Subject: [PATCH 26/30] change MAX_SPECULATIVE_TREE_BRANCHES --- include/flexflow/batch_config.h | 23 ++++++++++++++--------- include/flexflow/request_manager.h | 2 +- src/runtime/request_manager.cc | 11 ++++++++--- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index ef17ef43ed..3dcae464cc 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -167,9 +167,10 @@ class BeamSearchBatchConfig : public BatchConfig { int current_depth = -1; int max_depth = MAX_BEAM_DEPTH; - BatchConfig::TokenId tokens[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; - float probs[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; - int parent_id[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; + BatchConfig::TokenId + tokens[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES]; + float probs[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES]; + int parent_id[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES]; int sub_request_num; }; @@ -178,10 +179,11 @@ class BeamSearchBatchConfig : public BatchConfig { }; BeamSearchPerRequestInfo beamRequestsInfo[MAX_NUM_REQUESTS]; - BeamSearchPerTokenInfo beamTokenInfo[MAX_NUM_TOKENS * MAX_BEAM_WIDTH]; + BeamSearchPerTokenInfo + beamTokenInfo[MAX_NUM_TOKENS + + MAX_SPEC_TREE_TOKEN_NUM * MAX_NUM_REQUESTS]; - // why is this == MAX_NUM_REQUESTS * MAX_BEAM_WIDTH? - int sub_requests[MAX_NUM_REQUESTS * MAX_BEAM_WIDTH]; + int sub_requests[MAX_SPECULATIVE_TREE_BRANCHES]; private: size_t current_iteration; @@ -190,9 +192,12 @@ class BeamSearchBatchConfig : public BatchConfig { struct BeamInferenceResult { static int const MAX_NUM_TOKENS = BatchConfig::MAX_NUM_TOKENS; BatchConfig::TokenId - token_ids[MAX_NUM_TOKENS * BeamSearchBatchConfig::MAX_BEAM_WIDTH]; - float probs[MAX_NUM_TOKENS * BeamSearchBatchConfig::MAX_BEAM_WIDTH]; - int parent_id[MAX_NUM_TOKENS * BeamSearchBatchConfig::MAX_BEAM_WIDTH]; + token_ids[MAX_NUM_TOKENS * + BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES]; + float probs[MAX_NUM_TOKENS * + BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES]; + int parent_id[MAX_NUM_TOKENS * + BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES]; }; }; // namespace FlexFlow diff --git a/include/flexflow/request_manager.h b/include/flexflow/request_manager.h index 33714c106e..f74b6c5b9f 100644 --- a/include/flexflow/request_manager.h +++ b/include/flexflow/request_manager.h @@ -76,7 +76,7 @@ struct BeamTree { struct treeLayer { BeamSearchBatchConfig::TokenId tokens[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES]; - int parent_ids[BeamSearchBatchConfig::MAX_BEAM_WIDTH]; + int parent_ids[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES]; float probs[BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES]; int nodes_num_this_layer = 0; }; diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index c867d2a979..91a5d3be86 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -767,7 +767,9 @@ BeamSearchBatchConfig : 1; new_bc.beamRequestsInfo[i].max_depth = std::min(new_max_depth, BeamSearchBatchConfig::MAX_BEAM_DEPTH); - for (int j = 0; j < BeamSearchBatchConfig::MAX_BEAM_WIDTH; j++) { + for (int j = 0; + j < BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES; + j++) { new_bc.beamRequestsInfo[i].parent_id[j] = 0; new_bc.beamRequestsInfo[i].probs[j] = 1; } @@ -840,7 +842,8 @@ BeamSearchBatchConfig ? 
spec_infer_tree_width[ssm_decoding_steps] : 1; new_bc.beamRequestsInfo[i].max_depth = 0; - for (int j = 0; j < BeamSearchBatchConfig::MAX_BEAM_WIDTH; j++) { + for (int j = 0; j < BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES; + j++) { new_bc.beamRequestsInfo[i].parent_id[j] = 0; new_bc.beamRequestsInfo[i].probs[j] = 1; } @@ -900,7 +903,9 @@ BeamSearchBatchConfig std::min(BeamSearchBatchConfig::MAX_BEAM_DEPTH, get_max_tokens_per_batch() - new_bc.requestsInfo[i].num_tokens_in_batch - 1); - for (int j = 0; j < BeamSearchBatchConfig::MAX_BEAM_WIDTH; j++) { + for (int j = 0; + j < BeamSearchBatchConfig::MAX_SPECULATIVE_TREE_BRANCHES; + j++) { new_bc.beamRequestsInfo[i].parent_id[j] = 0; new_bc.beamRequestsInfo[i].probs[j] = 1; } From 429ddb59073f3155acd7f255c97f2153f99d130b Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Thu, 4 Jan 2024 00:06:48 -0500 Subject: [PATCH 27/30] =?UTF-8?q?=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/flexflow/batch_config.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index 3dcae464cc..5c126293cf 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -183,7 +183,7 @@ class BeamSearchBatchConfig : public BatchConfig { beamTokenInfo[MAX_NUM_TOKENS + MAX_SPEC_TREE_TOKEN_NUM * MAX_NUM_REQUESTS]; - int sub_requests[MAX_SPECULATIVE_TREE_BRANCHES]; + int sub_requests[MAX_NUM_REQUESTS]; private: size_t current_iteration; From 4f61b9f348094f87cc4d32625a65ffb64156d325 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Mon, 8 Jan 2024 19:31:20 +0000 Subject: [PATCH 28/30] fix --- src/runtime/request_manager.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 91a5d3be86..56a2c122d3 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -43,7 +43,8 @@ std::string LoadBytesFromFile(std::string const &path) { } RequestManager::RequestManager() - : verbose(false), next_available_guid(1000000), num_processed_requests(0) { + : verbose(false), next_available_guid(1000000), num_processed_requests(0), + total_request_run_time(0.0f) { // The following config parameters are set // during ffmodel.compile() // Initialize them to -1 to make sure no one From 29735f2432efd8290bf4ebb301fa96cbb5530eff Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Mon, 8 Jan 2024 22:33:22 +0000 Subject: [PATCH 29/30] fixes to run chatgpt.json prompt dataset in python --- .dockerignore | 2 ++ .gitignore | 3 ++- python/flexflow/core/flexflow_cffi.py | 2 +- src/c/flexflow_c.cc | 6 +++++- src/runtime/model.cu | 1 - tests/inference/python_inference_tests.sh | 3 ++- 6 files changed, 12 insertions(+), 5 deletions(-) diff --git a/.dockerignore b/.dockerignore index a7470203e3..b9f228c009 100644 --- a/.dockerignore +++ b/.dockerignore @@ -17,3 +17,5 @@ python/flexflow/core/legion_cffi_header.py /inference/tokenizer/* /inference/prompt/* /inference/output/* + +/tests/inference/python_test_configs/*.json diff --git a/.gitignore b/.gitignore index 8fcc105f01..7f6a3c4137 100644 --- a/.gitignore +++ b/.gitignore @@ -186,4 +186,5 @@ gpt_tokenizer # pip version python/flexflow/version.txt -inference_tensors \ No newline at end of file +inference_tensors +tests/inference/python_test_configs/*.json diff --git a/python/flexflow/core/flexflow_cffi.py b/python/flexflow/core/flexflow_cffi.py index a3c221474d..00133dacb4 
100644 --- a/python/flexflow/core/flexflow_cffi.py +++ b/python/flexflow/core/flexflow_cffi.py @@ -56,7 +56,7 @@ def get_c_name(name): if name is None: return ffi.NULL else: - return ffi.new("char[]", name.encode("ascii")) + return ffi.new("char[]", name.encode("utf-8")) def get_datatype_size(datatype): diff --git a/src/c/flexflow_c.cc b/src/c/flexflow_c.cc index 579fc5e2d1..82a37a9736 100644 --- a/src/c/flexflow_c.cc +++ b/src/c/flexflow_c.cc @@ -1596,7 +1596,11 @@ flexflow_generation_result_t GenerationResult result = handle->generate(prompts, max_seq_length); DEBUG_PRINT( "[Model] generate %p %s %i", handle, text_str.c_str(), max_seq_length); - assert(result.output_tokens.size() <= max_seq_length); + // If the prompt exceeds max seq len, check that we return the prompt with no + // additional token. Otherwise, check that the output does not exceed the max + // sequence length. + assert(result.output_tokens.size() <= max_seq_length || + result.output_tokens.size() == result.input_tokens.size()); output_length_and_tokens[0] = result.output_tokens.size(); std::copy(result.output_tokens.begin(), result.output_tokens.end(), diff --git a/src/runtime/model.cu b/src/runtime/model.cu index c885b29db2..23b7f0efbe 100644 --- a/src/runtime/model.cu +++ b/src/runtime/model.cu @@ -175,7 +175,6 @@ FFHandler } else { handle.batch_config_metadata = nullptr; } - // checkCUDA(cudaMalloc(&handle.workSpace, handle.workSpaceSize)); #ifdef FF_USE_NCCL diff --git a/tests/inference/python_inference_tests.sh b/tests/inference/python_inference_tests.sh index 3544f58e26..10c0821835 100755 --- a/tests/inference/python_inference_tests.sh +++ b/tests/inference/python_inference_tests.sh @@ -6,11 +6,12 @@ set -e cd "${BASH_SOURCE[0]%/*}" # Generate test configs +rm -rf python_test_configs/*.json python python_test_configs/generate_configs.py # Run all tests # Loop through .json files in the ./python_test_configs dir -for file in ./python_test_configs/*.json; do +for file in ./python_test_configs/*"llama"*.json; do # Check filename prefix if [[ $file == *"incr_dec"* ]]; then script="../../inference/python/incr_decoding.py" From ba4af39404bb92af10926222ceb6d9e88a147fb9 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Tue, 9 Jan 2024 06:56:36 +0000 Subject: [PATCH 30/30] fix --- tests/inference/python_inference_tests.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/inference/python_inference_tests.sh b/tests/inference/python_inference_tests.sh index 10c0821835..a1ee281914 100755 --- a/tests/inference/python_inference_tests.sh +++ b/tests/inference/python_inference_tests.sh @@ -11,7 +11,7 @@ python python_test_configs/generate_configs.py # Run all tests # Loop through .json files in the ./python_test_configs dir -for file in ./python_test_configs/*"llama"*.json; do +for file in ./python_test_configs/*.json; do # Check filename prefix if [[ $file == *"incr_dec"* ]]; then script="../../inference/python/incr_decoding.py"
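--- Editorial sketch (appended; not part of the patches above) ---

The spec_inc and tree_inc attention kernels changed in this series all hinge on one predicate: during the prompt phase a plain causal rule applies, and during the tree phase visibility is read from the per-request BatchConfig::BitMask. Below is a minimal, host-side C++ sketch of that predicate. The field names follow BatchConfig::BitMask from the diffs, but ToyBitMask, is_masked(), and the toy two-branch tree in main() are illustrative assumptions for this note, not FlexFlow APIs, and the exact bit/index conventions are my reading of the kernels.

// Minimal host-side rendering of the causal-mask predicate used by the
// speculative / tree-verify attention kernels in this patch series.
#include <cassert>
#include <cstdint>
#include <cstdio>

struct ToyBitMask {            // simplified stand-in for BatchConfig::BitMask
  int non_tree_cache_size = 0; // committed KV-cache entries before the tree
  int tree_size = 1;           // nodes currently in the speculative tree
  int this_layer_size = 0;     // nodes appended by the latest beam layer
  int prompt_size = 0;         // prompt length of this request
  uint64_t mask[64] = {0};     // mask[k] has bit q set iff query token q may
                               // attend to tree key position k
};

// Mirrors the kernel predicate: true means key position `ti` is masked out
// for query token `qi`; q_start is the query's depth offset in the request.
bool is_masked(ToyBitMask const &bm, bool prompt_phase, int q_start, int qi,
               int ti) {
  if (prompt_phase) {
    return q_start + qi < ti; // plain causal mask while loading the prompt
  }
  // Tree phase: the committed cache is always visible; inside the tree the
  // per-node bitset decides visibility (ancestor-or-self relationships).
  return ti >= bm.non_tree_cache_size &&
         !(bm.mask[ti - bm.non_tree_cache_size] & (uint64_t(1) << qi));
}

int main() {
  // Toy request: 4 committed tokens, then a root with two sibling children.
  ToyBitMask bm;
  bm.non_tree_cache_size = 4;
  bm.tree_size = 3;       // root + child A + child B
  bm.this_layer_size = 2; // the two children were added this step
  bm.mask[0] = 0b111;     // root is an ancestor of every query token
  bm.mask[1] = 0b010;     // child A is visible only to itself (query 1)
  bm.mask[2] = 0b100;     // child B is visible only to itself (query 2)

  assert(!is_masked(bm, false, 0, 2, 3)); // committed cache: always visible
  assert(!is_masked(bm, false, 0, 1, 5)); // child A sees its own key
  assert(is_masked(bm, false, 0, 1, 6));  // child A must not see sibling B
  assert(!is_masked(bm, true, 4, 0, 3));  // prompt phase: causal rule only
  std::printf("toy causal-mask checks passed\n");
  return 0;
}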
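--- Editorial sketch (appended; not part of the patches above) ---

Several hunks in this series carve device pointers out of the single handler.batch_config_metadata buffer by summing sizeof(...) of each section, and RequestManager::load_batch_config_task copies the sections host-to-device in the same order; the Meta constructors must accumulate identical sizes in identical order, which is exactly what the additions of causalMask, committed_tokens, and request_completed adjust. The sketch below only illustrates that offset discipline: the section names mirror the diffs, but the byte sizes are placeholder assumptions, not FlexFlow values.

// Illustrative offset layout for the beam-search metadata buffer.
#include <cstddef>
#include <cstdio>

int main() {
  struct Section { char const *name; size_t size; };
  Section const beam_layout[] = {
      {"BatchConfig::tokensInfo", 1024},
      {"BatchConfig::requestsInfo", 512},
      {"BeamSearchBatchConfig::beamTokenInfo", 256},
      {"BeamSearchBatchConfig::beamRequestsInfo", 256},
      {"BatchConfig::causalMask", 2048},
      {"BatchConfig::request_completed", 64},
  };
  size_t offset = 0;
  for (Section const &s : beam_layout) {
    std::printf("%-42s -> byte offset %zu\n", s.name, offset);
    offset += s.size; // consumers must add the same sizes in the same order
  }
  std::printf("total batch_config_metadata_size = %zu bytes\n", offset);
  return 0;
}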