From d8072ab6efe7bae43058c6a3ffeb94499c804124 Mon Sep 17 00:00:00 2001 From: xinhaoc Date: Thu, 1 Jun 2023 16:09:09 +0000 Subject: [PATCH 01/12] fix --- include/flexflow/inference.h | 1 + inference/spec_infer/spec_infer.cc | 4 +++- src/runtime/request_manager.cc | 23 +++++++++++++++++++++-- 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/include/flexflow/inference.h b/include/flexflow/inference.h index 8825a79283..8ba110583c 100644 --- a/include/flexflow/inference.h +++ b/include/flexflow/inference.h @@ -117,6 +117,7 @@ class RequestManager { &inputSerializedTree, std::vector> const &outputSerializedTree); + int get_requests_init_length(BeamSearchBatchConfig const &old_bc); // TreeVerifyBatchConfig // convert_beam_to_tree_batch_config(BeamSearchBatchConfig const diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index 8df4cf4028..3f08bf27fb 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -257,7 +257,9 @@ void FlexFlow::top_level_task(Task const *task, assert(fm.get_future_map_domain().get_volume() == 1); Future future = fm.get_future(0); BeamInferenceResult beam_ir = future.get_result(); - if (depth - 1 >= beam_bc.max_beam_depth_all_requests()) { + if (depth - 1 >= beam_bc.max_beam_depth_all_requests() || + depth + 1 + rm.get_requests_init_length(beam_bc) >= + BatchConfig::MAX_NUM_TOKENS) { break; } else { beam_bc = rm.prepare_next_batch_beam(beam_bc, beam_ir); diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index cf0aeb94de..c5b874c798 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -236,6 +236,25 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, /* ----- Speculative Inference Specific functions ----- */ +int RequestManager::get_requests_init_length( + BeamSearchBatchConfig const &old_bc) { + int init_length = 0; + for (int i = 0; i < BatchConfig::MAX_NUM_REQUESTS; i++) { + if (old_bc.request_completed[i]) { + continue; + } + Request &request = + running_request_queue[old_bc.requestsInfo[i].request_guid]; + if (old_bc.requestsInfo[i].token_start_offset + 1 >= + request.tokens.size()) { + init_length = 0; + } else if (request.initial_len > init_length) { + init_length = request.initial_len; + } + } + return init_length; +} + // update beam search metadata BeamSearchBatchConfig RequestManager::prepare_next_batch_beam(BeamSearchBatchConfig const &old_bc, @@ -247,7 +266,7 @@ BeamSearchBatchConfig if (verbose) { std::cout << "print all results" << "\n"; - for (int i = 0; i < 40; i++) { + for (int i = 0; i < 64; i++) { std::cout << result.token_ids[i] << ", "; } std::cout << "Current Beam Depth: " @@ -304,7 +323,7 @@ BeamSearchBatchConfig new_bc.beamRequestsInfo[i].beam_size = old_bc.beamRequestsInfo[i].beam_size; new_bc.beamRequestsInfo[i].max_depth = - old_bc.beamRequestsInfo[i].max_depth; + old_bc.beamRequestsInfo[i].current_depth; // do the slot exchange to minimize the cache exchange in kernel. // std::cout << "update metadata" << std::endl; From f74377afa8c029fe0b87e8efb08bc36adbde7237 Mon Sep 17 00:00:00 2001 From: Zeyu Wang Date: Tue, 27 Jun 2023 01:43:12 +0000 Subject: [PATCH 02/12] Formatting. --- examples/cpp/inference/mixture_of_experts/moe.h | 4 ++-- include/flexflow/inference.h | 2 +- inference/spec_infer/spec_infer.cc | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/cpp/inference/mixture_of_experts/moe.h b/examples/cpp/inference/mixture_of_experts/moe.h index 183229bc07..4fdd3b2e3f 100644 --- a/examples/cpp/inference/mixture_of_experts/moe.h +++ b/examples/cpp/inference/mixture_of_experts/moe.h @@ -22,9 +22,9 @@ struct MoeConfig : InferenceConfig { MoeConfig(void) : InferenceConfig() { //----------------------- MoE layer -------------------------------- // total number of experts - num_exp = 128; + num_exp = 64; // number of experts in each block of fused experts - experts_per_block = 32; + experts_per_block = 16; // number of experts to route each token to num_select = 2; // expert capacity parameters diff --git a/include/flexflow/inference.h b/include/flexflow/inference.h index 1d3b62fb00..5cf9926cff 100644 --- a/include/flexflow/inference.h +++ b/include/flexflow/inference.h @@ -133,7 +133,7 @@ class RequestManager { &inputSerializedTree, std::vector> const &outputSerializedTree); - int get_requests_init_length(BeamSearchBatchConfig const &old_bc); + int get_requests_init_length(BeamSearchBatchConfig const &old_bc); static void load_tokens_task(Legion::Task const *task, diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index 3cf568fe17..b532f7318d 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -294,7 +294,7 @@ void FlexFlow::top_level_task(Task const *task, if (depth - 1 >= beam_bc_vec[i].max_beam_depth_all_requests() || depth + 1 + rm.get_requests_init_length(beam_bc_vec[i]) >= - BatchConfig::MAX_NUM_TOKENS) { + BatchConfig::MAX_NUM_TOKENS) { break; } else { beam_bc_vec[i] = rm.prepare_next_batch_beam(beam_bc_vec[i], beam_ir); From 3a87e02a9e6ffca9ebfc4dc3694dba4aeea929aa Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Thu, 29 Jun 2023 15:55:03 +0800 Subject: [PATCH 03/12] [Inference] - Fix Multiple-GPUs CI test (#804) * fix linear region requirement * fix set tensor issue --- src/ops/inc_multihead_self_attention.cc | 2 ++ src/ops/inc_multiquery_self_attention.cc | 2 ++ src/ops/linear.cc | 4 ++-- src/ops/spec_inc_multihead_self_attention.cc | 2 ++ src/ops/tree_inc_multihead_self_attention.cc | 2 ++ src/runtime/parallel_tensor.cc | 6 ++++-- 6 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/ops/inc_multihead_self_attention.cc b/src/ops/inc_multihead_self_attention.cc index 765b3c5bfc..07598f99ea 100644 --- a/src/ops/inc_multihead_self_attention.cc +++ b/src/ops/inc_multihead_self_attention.cc @@ -285,6 +285,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( dims[0].size = dims[0].degree; dims[1] = inputs[0]->dims[num_dims - 1]; dims[1].size = this->num_heads; + dims[1].is_replica_dim = false; dims[2].size = qParas + kParas + vParas + oParas; if (quantization_type != DT_NONE) { @@ -392,6 +393,7 @@ IncMultiHeadSelfAttention::IncMultiHeadSelfAttention( dims[0].size = dims[0].degree; dims[1] = inputs[0]->dims[num_dims - 1]; dims[1].size = this->num_heads; + dims[1].is_replica_dim = false; dims[2].size = qParas + kParas + vParas + oParas; if (quantization_type != DT_NONE) { dims[2].size = get_quantization_to_byte_size( diff --git a/src/ops/inc_multiquery_self_attention.cc b/src/ops/inc_multiquery_self_attention.cc index 05c57af2ff..6ce448c9ec 100644 --- a/src/ops/inc_multiquery_self_attention.cc +++ b/src/ops/inc_multiquery_self_attention.cc @@ -228,6 +228,7 @@ IncMultiQuerySelfAttention::IncMultiQuerySelfAttention( dims[0] = inputs[0]->dims[num_dims - 2]; dims[0].size = dims[0].degree; dims[1] = inputs[0]->dims[num_dims - 1]; + dims[1].is_replica_dim = false; dims[1].size = this->embed_dim; dims[2].size = this->embed_dim + this->kProjSize + this->vProjSize + this->oProjSize; @@ -308,6 +309,7 @@ IncMultiQuerySelfAttention::IncMultiQuerySelfAttention( dims[0].size = dims[0].degree; dims[1] = inputs[0]->dims[num_dims - 1]; dims[1].size = this->embed_dim; + dims[1].is_replica_dim = false; dims[2].size = this->embed_dim + this->kProjSize + this->vProjSize + this->oProjSize; int seed = std::rand(); diff --git a/src/ops/linear.cc b/src/ops/linear.cc index e3204c01d9..cca92f014f 100644 --- a/src/ops/linear.cc +++ b/src/ops/linear.cc @@ -306,7 +306,7 @@ void Linear::init(FFModel const &ff) { // launcher.add_field(0, FID_DATA); launcher.add_region_requirement(RegionRequirement(inputs[0]->part, 0 /*projection id*/, - READ_ONLY, + WRITE_ONLY, EXCLUSIVE, inputs[0]->region)); launcher.add_field(0, FID_DATA); @@ -365,7 +365,7 @@ void Linear::init_inference(FFModel const &ff, // launcher.add_field(0, FID_DATA); launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, 0 /*projection id*/, - READ_ONLY, + WRITE_ONLY, EXCLUSIVE, batch_inputs[0]->region)); launcher.add_field(0, FID_DATA); diff --git a/src/ops/spec_inc_multihead_self_attention.cc b/src/ops/spec_inc_multihead_self_attention.cc index b9dedda418..e765960985 100644 --- a/src/ops/spec_inc_multihead_self_attention.cc +++ b/src/ops/spec_inc_multihead_self_attention.cc @@ -265,6 +265,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( dims[0].size = dims[0].degree; dims[1] = inputs[0]->dims[num_dims - 1]; dims[1].size = this->num_heads; + dims[1].is_replica_dim = false; dims[2].size = qParas + kParas + vParas + oParas; dims[2].degree = 1; dims[2].parallel_idx = -1; @@ -363,6 +364,7 @@ SpecIncMultiHeadSelfAttention::SpecIncMultiHeadSelfAttention( dims[0].size = dims[0].degree; dims[1] = inputs[0]->dims[num_dims - 1]; dims[1].size = this->num_heads; + dims[1].is_replica_dim = false; dims[2].size = qParas + kParas + vParas + oParas; int seed = std::rand(); Initializer *initializer = new GlorotUniform(seed); diff --git a/src/ops/tree_inc_multihead_self_attention.cc b/src/ops/tree_inc_multihead_self_attention.cc index d0bf1d5675..105bd41647 100644 --- a/src/ops/tree_inc_multihead_self_attention.cc +++ b/src/ops/tree_inc_multihead_self_attention.cc @@ -286,6 +286,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( dims[0].size = dims[0].degree; dims[1] = inputs[0]->dims[num_dims - 1]; dims[1].size = this->num_heads; + dims[1].is_replica_dim = false; dims[2].size = qParas + kParas + vParas + oParas; if (quantization_type != DT_NONE) { dims[2].size = get_quantization_to_byte_size( @@ -392,6 +393,7 @@ TreeIncMultiHeadSelfAttention::TreeIncMultiHeadSelfAttention( dims[0].size = dims[0].degree; dims[1] = inputs[0]->dims[num_dims - 1]; dims[1].size = this->num_heads; + dims[1].is_replica_dim = false; dims[2].size = qParas + kParas + vParas + oParas; if (quantization_type != DT_NONE) { dims[2].size = get_quantization_to_byte_size( diff --git a/src/runtime/parallel_tensor.cc b/src/runtime/parallel_tensor.cc index 0ed594fd7e..8f1be15fd1 100644 --- a/src/runtime/parallel_tensor.cc +++ b/src/runtime/parallel_tensor.cc @@ -660,8 +660,10 @@ bool ParallelTensorBase::set_tensor(FFModel const *ff, if (sync_type == ParameterSyncType::NCCL) { // Domain domain = runtime->get_index_space_domain(ctx, parallel_is); // num_replicas = domain.get_volume(); - if (this->num_dims >= 2 && this->dims[this->num_dims - 1].is_replica_dim) { - num_replicas = this->dims[this->num_dims - 1].size; + for (int i = 0; i < this->num_dims; i++) { + if (this->dims[i].is_replica_dim) { + num_replicas *= this->dims[i].size; + } } } else if (sync_type == ParameterSyncType::PS) { num_replicas = 1; From f02c9a0e870129c2cde0ef064405883a06f8d4ac Mon Sep 17 00:00:00 2001 From: DerrickYLJ <99985904+DerrickYLJ@users.noreply.github.com> Date: Thu, 29 Jun 2023 04:12:56 -0400 Subject: [PATCH 04/12] Update README.md (#814) Update links/names of docker container from flexflow-{cuda, hip_rocm} to specinfer-{cuda, hip_rocm} with the disclaimer of CUDA version. Co-authored-by: Gabriele Oliaro --- .github/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/README.md b/.github/README.md index 010d7c07bb..576b1ca84e 100644 --- a/.github/README.md +++ b/.github/README.md @@ -29,7 +29,7 @@ for serving generative LLMs while provably preserving model quality.

## Build/Install SpecInfer -SpecInfer is built on top of FlexFlow. You can build/install SpecInfer by building the inference branch of FlexFlow. Please read the [instructions](../INSTALL.md) for building/installing FlexFlow from source code. If you would like to quickly try SpecInfer, we also provide pre-built Docker packages ([flexflow-cuda](https://github.com/flexflow/FlexFlow/pkgs/container/flexflow-cuda) with a CUDA backend, [flexflow-hip_rocm](https://github.com/flexflow/FlexFlow/pkgs/container/flexflow-hip_rocm) with a HIP-ROCM backend) with all dependencies pre-installed (N.B.: currently, the CUDA pre-built containers are only fully compatible with host machines that have CUDA 11.7 installed), together with [Dockerfiles](./docker) if you wish to build the containers manually. +SpecInfer is built on top of FlexFlow. You can build/install SpecInfer by building the inference branch of FlexFlow. Please read the [instructions](../INSTALL.md) for building/installing FlexFlow from source code. If you would like to quickly try SpecInfer, we also provide pre-built Docker packages ([specinfer-cuda](https://github.com/flexflow/FlexFlow/pkgs/container/specinfer-cuda) with a CUDA backend, [specinfer-hip_rocm](https://github.com/flexflow/FlexFlow/pkgs/container/specinfer-hip_rocm) with a HIP-ROCM backend) with all dependencies pre-installed (N.B.: currently, the CUDA pre-built containers are only fully compatible with host machines that have CUDA 11.7 installed), together with [Dockerfiles](./docker) if you wish to build the containers manually. ## Run SpecInfer The source code of the SpecInfer pipeline is available at [this folder](../inference/spec_infer/). The SpecInfer executable will be available at `/build_dir/inference/spec_infer/spec_infer` at compilation. You can use the following command-line arguments to run SpecInfer: From 08bda773c8dd968e75c6fbbf2bfa8a902197874e Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Thu, 29 Jun 2023 21:57:02 +0800 Subject: [PATCH 05/12] [Inference] - Better device placement in tensor model parallelism (#805) * add data parallelism degree setting * compute multi-device machines views * fix bugs * fix and linting * update inference test, comment out print statements * fix --- .github/workflows/gpu-ci.yml | 1 + .../cpp/inference/mixture_of_experts/moe.cc | 9 ++- .../inference/transformers/transformers.cc | 10 ++-- include/flexflow/config.h | 4 +- include/flexflow/inference.h | 5 +- inference/incr_decoding/incr_decoding.cc | 29 +++++++-- inference/models/llama.cc | 59 +++++++++++++++---- inference/models/opt.cc | 49 +++++++++++++-- inference/spec_infer/spec_infer.cc | 29 +++++++-- src/runtime/inference_manager.cc | 47 +++++++++++---- tests/inference_tests.sh | 16 ++++- 11 files changed, 203 insertions(+), 55 deletions(-) diff --git a/.github/workflows/gpu-ci.yml b/.github/workflows/gpu-ci.yml index 7f83fb2691..bdbb8a751b 100644 --- a/.github/workflows/gpu-ci.yml +++ b/.github/workflows/gpu-ci.yml @@ -180,6 +180,7 @@ jobs: ./tests/gpt_tokenizer_test.sh # Inference tests + export TENSOR_PARALLELISM_TESTS=ON ./tests/inference_tests.sh gpu-ci-flexflow: diff --git a/examples/cpp/inference/mixture_of_experts/moe.cc b/examples/cpp/inference/mixture_of_experts/moe.cc index 0c94452ec1..39459d63ac 100644 --- a/examples/cpp/inference/mixture_of_experts/moe.cc +++ b/examples/cpp/inference/mixture_of_experts/moe.cc @@ -139,8 +139,7 @@ void FlexFlow::top_level_task(Task const *task, Tensor output = ff.arg_top_k(t, /*k=*/1, /*sorted=*/false); //------------------- Initialize the inference manager ------------------ - InferenceManager im( - ff.config, moeConfig.batch_size, moeConfig.num_inflight_batches); + InferenceManager im(ff.config, moeConfig.batch_size); std::unordered_map> mapping; im.compile_model_and_allocate_buffer(&ff, mapping); im.init_operators_inference(&ff); @@ -162,7 +161,7 @@ void FlexFlow::top_level_task(Task const *task, ParallelTensor input_pt; ff.get_parallel_tensor_from_tensor(input, input_pt); assert(im.tensor_buffer.find(input_pt) != im.tensor_buffer.end()); - assert(im.tensor_buffer[input_pt].size() == im.max_num_inflight_batches); + assert(im.tensor_buffer[input_pt].size() == ffConfig.data_parallelism_degree); DataLoader data_loader( ff, moeConfig, data_generator, im.tensor_buffer[input_pt]); @@ -184,13 +183,13 @@ void FlexFlow::top_level_task(Task const *task, std::map batch_configs; std::pair new_prompts; BatchConfig *bc = nullptr; - std::map batch_predictions[im.max_num_inflight_batches]; + std::map batch_predictions[ffConfig.data_parallelism_degree]; assert(im.max_num_tokens_per_batch == moeConfig.batch_size); // simulation loop. For deployment, we will use a while(true) while (processed_requests < moeConfig.total_requests) { - for (int bid = 0; bid < im.max_num_inflight_batches; bid++) { + for (int bid = 0; bid < ffConfig.data_parallelism_degree; bid++) { size_t max_reqs, max_tkns; if (future_handlers.find(bid) == future_handlers.end()) { max_reqs = moeConfig.incremental_mode ? bc->MAX_NUM_REQUESTS diff --git a/examples/cpp/inference/transformers/transformers.cc b/examples/cpp/inference/transformers/transformers.cc index d416fdca3c..d56473c8bd 100644 --- a/examples/cpp/inference/transformers/transformers.cc +++ b/examples/cpp/inference/transformers/transformers.cc @@ -114,9 +114,7 @@ void FlexFlow::top_level_task(Task const *task, Tensor output = ff.arg_top_k(t, /*k=*/1, false); //------------------- Initialize the inference manager ------------------ - InferenceManager im(ff.config, - transformerConfig.batch_size, - transformerConfig.num_inflight_batches); + InferenceManager im(ff.config, transformerConfig.batch_size); std::unordered_map> mapping; im.compile_model_and_allocate_buffer(&ff, mapping); im.init_operators_inference(&ff); @@ -138,7 +136,7 @@ void FlexFlow::top_level_task(Task const *task, ParallelTensor input_pt; ff.get_parallel_tensor_from_tensor(input, input_pt); assert(im.tensor_buffer.find(input_pt) != im.tensor_buffer.end()); - assert(im.tensor_buffer[input_pt].size() == im.max_num_inflight_batches); + assert(im.tensor_buffer[input_pt].size() == ffConfig.data_parallelism_degree); DataLoader data_loader( ff, transformerConfig, data_generator, im.tensor_buffer[input_pt]); @@ -160,14 +158,14 @@ void FlexFlow::top_level_task(Task const *task, std::map batch_configs; std::pair new_prompts; BatchConfig *bc = nullptr; - std::map batch_predictions[im.max_num_inflight_batches]; + std::map batch_predictions[ffConfig.data_parallelism_degree]; assert(im.max_num_tokens_per_batch == transformerConfig.batch_size); // assert(transformerConfig.batch_size <= BatchConfig::MAX_NUM_REQUESTS); // simulation loop. For deployment, we will use a while(true) while (processed_requests < transformerConfig.total_requests) { - for (int bid = 0; bid < im.max_num_inflight_batches; bid++) { + for (int bid = 0; bid < ffConfig.data_parallelism_degree; bid++) { size_t max_reqs, max_tkns; if (future_handlers.find(bid) == future_handlers.end()) { max_reqs = transformerConfig.incremental_mode diff --git a/include/flexflow/config.h b/include/flexflow/config.h index f7c59f7b58..f1b218e50f 100644 --- a/include/flexflow/config.h +++ b/include/flexflow/config.h @@ -143,8 +143,10 @@ class FFConfig { bool enable_parameter_parallel; bool enable_attribute_parallel; bool enable_inplace_optimizations; - // Control tensor model parallelism degree in inference + // Control parallelism degrees in inference + int data_parallelism_degree; int tensor_parallelism_degree; + int pipeline_parallelism_degree; // Control Tensor Op Math Conversion bool allow_tensor_op_math_conversion; std::string dataset_path; diff --git a/include/flexflow/inference.h b/include/flexflow/inference.h index 4da8dbaf20..1fd2fdff78 100644 --- a/include/flexflow/inference.h +++ b/include/flexflow/inference.h @@ -28,9 +28,7 @@ using tokenizers::Tokenizer; class InferenceManager { public: - InferenceManager(FFConfig const &config, - int max_num_tokens_per_batch, - int max_num_inflight_batches); + InferenceManager(FFConfig const &config, int max_num_tokens_per_batch); void compile_model_and_allocate_buffer( FFModel *model, std::unordered_map> const &mapping); @@ -45,7 +43,6 @@ class InferenceManager { FFConfig ff_config; std::unordered_map> tensor_buffer; int max_num_tokens_per_batch; - int max_num_inflight_batches; int num_devices; std::vector machine_views; }; diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc index a281f52853..d43cab17f9 100644 --- a/inference/incr_decoding/incr_decoding.cc +++ b/inference/incr_decoding/incr_decoding.cc @@ -38,7 +38,9 @@ void parse_input_args(char **argv, ModelType &llm_model_type, bool &use_full_precision, bool &verbose, - int &tensor_parallelism_degree) { + int &data_parallelism_degree, + int &tensor_parallelism_degree, + int &pipeline_parallelism_degree) { for (int i = 1; i < argc; i++) { // llm model type if (!strcmp(argv[i], "-llm-model")) { @@ -83,11 +85,21 @@ void parse_input_args(char **argv, paths.output_file_path = std::string(argv[++i]); continue; } + // data parallelism degree + if (!strcmp(argv[i], "-data-parallelism-degree")) { + data_parallelism_degree = std::stoi(argv[++i]); + continue; + } // tensor parallelism degree if (!strcmp(argv[i], "-tensor-parallelism-degree")) { tensor_parallelism_degree = std::stoi(argv[++i]); continue; } + // pipeline parallelism degree + if (!strcmp(argv[i], "-pipeline-parallelism-degree")) { + pipeline_parallelism_degree = std::stoi(argv[++i]); + continue; + } if (!strcmp(argv[i], "--use-full-precision")) { use_full_precision = true; continue; @@ -112,7 +124,9 @@ void FlexFlow::top_level_task(Task const *task, ModelType model_type; bool use_full_precision = false; bool verbose = false; - int tensor_parallelism_degree = 1; + size_t num_devices = ffconfig.workersPerNode * ffconfig.numNodes; + int data_parallelism_degree = 1, tensor_parallelism_degree = 1, + pipeline_parallelism_degree = -1; InputArgs const &command_args = HighLevelRuntime::get_input_args(); char **argv = command_args.argv; @@ -123,13 +137,20 @@ void FlexFlow::top_level_task(Task const *task, model_type, use_full_precision, verbose, - tensor_parallelism_degree); + data_parallelism_degree, + tensor_parallelism_degree, + pipeline_parallelism_degree); + ffconfig.data_parallelism_degree = data_parallelism_degree; ffconfig.tensor_parallelism_degree = tensor_parallelism_degree; + ffconfig.pipeline_parallelism_degree = + pipeline_parallelism_degree == -1 + ? num_devices / (tensor_parallelism_degree * data_parallelism_degree) + : pipeline_parallelism_degree; assert(model_type != ModelType::UNKNOWN && "Invalid LLM model type passed (or no type was passed)."); - InferenceManager im(ffconfig, BatchConfig::MAX_NUM_TOKENS, 1); + InferenceManager im(ffconfig, BatchConfig::MAX_NUM_TOKENS); RequestManager rm(model_type, file_paths.tokenizer_file_path, /*verbose*/ verbose, diff --git a/inference/models/llama.cc b/inference/models/llama.cc index f7c1563095..1e61f43a98 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -29,8 +29,27 @@ void LLAMA::create_llama_model(FFModel &ff, // do not apply cpu offload in beam search model. Config llama_config(model_config_file_path); llama_config.printConfig(); - //------------------------------compute machine views ------------------ + //---------------------- parallelization setup work ---------------------- int num_devices = ff.config.workersPerNode * ff.config.numNodes; + int num_transformer_layers = llama_config.n_layers; + assert(num_transformer_layers % ff.config.pipeline_parallelism_degree == 0); + int num_layers_per_pp_block = + num_transformer_layers / ff.config.pipeline_parallelism_degree; + int num_devices_per_data_parallelism_line = + num_devices / ff.config.data_parallelism_degree; + + // std::cout << "dp: " << ff.config.data_parallelism_degree + // << " tp: " << ff.config.tensor_parallelism_degree + // << " pp: " << ff.config.pipeline_parallelism_degree << std::endl; + // std::cout << "num_devices: " << num_devices << std::endl; + // std::cout << "num_transformer_layers: " << num_transformer_layers + // << std::endl; + // std::cout << "num_devices_per_data_parallelism_line: " + // << num_devices_per_data_parallelism_line << std::endl; + // std::cout << "num layers: " << llama_config.n_layers << std::endl; + + //------------------------------compute machine views ------------------ + // single device std::vector machine_views; for (int i = 0; i < num_devices; i++) { MachineView view; @@ -41,6 +60,7 @@ void LLAMA::create_llama_model(FFModel &ff, view.start_device_id = i; machine_views.push_back(view); } + assert(machine_views.size() == num_devices); std::unordered_map> mapping; std::unordered_map weights_layers; @@ -51,7 +71,10 @@ void LLAMA::create_llama_model(FFModel &ff, int const token_dims[] = {BatchConfig::MAX_NUM_TOKENS, 1}; input = ff.create_tensor<2>(token_dims, DT_INT32); } - mapping[input].push_back(machine_views[0]); + for (int i = 0; i < ff.config.data_parallelism_degree; i++) { + mapping[input].push_back( + machine_views[i * num_devices_per_data_parallelism_line]); + } Initializer *embed_init = new UniformInitializer(std::rand(), 0, 0); @@ -78,9 +101,10 @@ void LLAMA::create_llama_model(FFModel &ff, Layer *embedding = ff.layers.back(); weights_layers.emplace("tok_embeddings_weight", embedding); - int num_transformer_layers = llama_config.n_layers; - int num_transformer_layers_per_stage = - (num_transformer_layers + num_pipeline_stages - 1) / num_pipeline_stages; + // int num_transformer_layers = llama_config.n_layers; + // int num_transformer_layers_per_stage = + // (num_transformer_layers + num_pipeline_stages - 1) / + // num_pipeline_stages; for (int i = 0; i < num_transformer_layers; i++) { // step 1: attention @@ -89,12 +113,25 @@ void LLAMA::create_llama_model(FFModel &ff, ff.rms_norm(token, llama_config.norm_eps, llama_config.dim); Layer *attention_norm = ff.layers.back(); - if (i % num_transformer_layers_per_stage == 0) { - // Map att_norm to the next GPU - // since the size of att_norm is minimum across - // all tensors - mapping[att_norm].push_back( - machine_views[i / num_transformer_layers_per_stage]); + // if (i % num_transformer_layers_per_stage == 0) { + // // Map att_norm to the next GPU + // // since the size of att_norm is minimum across + // // all tensors + // mapping[att_norm].push_back( + // machine_views[i / num_transformer_layers_per_stage]); + // } + for (int dp_index = 0; dp_index < ff.config.data_parallelism_degree; + dp_index++) { + int pp_block_idx = i / num_layers_per_pp_block; + int first_device_idx = dp_index * num_devices_per_data_parallelism_line + + ff.config.tensor_parallelism_degree * pp_block_idx; + // std::cout << "assigning layer " << i << " to devices " << + // first_device_idx + // << "-" + // << first_device_idx + ff.config.tensor_parallelism_degree - 1 + // << std::endl; + assert(first_device_idx < num_devices); + mapping[att_norm].push_back(machine_views[first_device_idx]); } weights_layers.emplace("layers_" + std::to_string(i) + diff --git a/inference/models/opt.cc b/inference/models/opt.cc index 1e81e4eba7..499eb92642 100644 --- a/inference/models/opt.cc +++ b/inference/models/opt.cc @@ -28,8 +28,27 @@ void OPT::create_opt_model(FFModel &ff, bool use_full_precision) { Config opt_config(model_config_file_path); opt_config.printConfig(); - //------------------------------compute machine views ------------------ + //---------------------- parallelization setup work ---------------------- int num_devices = ff.config.workersPerNode * ff.config.numNodes; + int num_transformer_layers = opt_config.num_hidden_layers; + assert(num_transformer_layers % ff.config.pipeline_parallelism_degree == 0); + int num_layers_per_pp_block = + num_transformer_layers / ff.config.pipeline_parallelism_degree; + int num_devices_per_data_parallelism_line = + num_devices / ff.config.data_parallelism_degree; + + // std::cout << "dp: " << ff.config.data_parallelism_degree + // << " tp: " << ff.config.tensor_parallelism_degree + // << " pp: " << ff.config.pipeline_parallelism_degree << std::endl; + // std::cout << "num_devices: " << num_devices << std::endl; + // std::cout << "num_transformer_layers: " << num_transformer_layers + // << std::endl; + // std::cout << "num_devices_per_data_parallelism_line: " + // << num_devices_per_data_parallelism_line << std::endl; + // std::cout << "num layers: " << opt_config.num_hidden_layers << std::endl; + + //------------------------------compute machine views ------------------ + // single device std::vector machine_views; for (int i = 0; i < num_devices; i++) { MachineView view; @@ -40,6 +59,7 @@ void OPT::create_opt_model(FFModel &ff, view.start_device_id = i; machine_views.push_back(view); } + assert(machine_views.size() == num_devices); std::unordered_map> mapping; std::unordered_map weights_layers; @@ -52,8 +72,12 @@ void OPT::create_opt_model(FFModel &ff, input = ff.create_tensor<2>(token_dims, DT_INT32); position_input = ff.create_tensor<2>(token_dims, DT_INT32); } - mapping[input].push_back(machine_views[0]); - mapping[position_input].push_back(machine_views[0]); + for (int i = 0; i < ff.config.data_parallelism_degree; i++) { + mapping[input].push_back( + machine_views[i * num_devices_per_data_parallelism_line]); + mapping[position_input].push_back( + machine_views[i * num_devices_per_data_parallelism_line]); + } Initializer *embed_init = new UniformInitializer(std::rand(), 0, 0); std::vector axes = {0}; @@ -118,10 +142,23 @@ void OPT::create_opt_model(FFModel &ff, "_attention_layer_norm_weight", self_attn_layer_norm); - if (i % num_transformer_layers_per_stage == 0) { - mapping[hidden_states].push_back( - machine_views[i / num_transformer_layers_per_stage]); + for (int dp_index = 0; dp_index < ff.config.data_parallelism_degree; + dp_index++) { + int pp_block_idx = i / num_layers_per_pp_block; + int first_device_idx = dp_index * num_devices_per_data_parallelism_line + + ff.config.tensor_parallelism_degree * pp_block_idx; + // std::cout << "assigning layer " << i << " to devices " << + // first_device_idx + // << "-" + // << first_device_idx + ff.config.tensor_parallelism_degree - 1 + // << std::endl; + assert(first_device_idx < num_devices); + mapping[hidden_states].push_back(machine_views[first_device_idx]); } + // if (i % num_transformer_layers_per_stage == 0) { + // mapping[hidden_states].push_back( + // machine_views[i / num_transformer_layers_per_stage]); + // } Tensor mha; switch (mode) { diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index 72666ed312..fbb07b2b25 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -44,7 +44,9 @@ void parse_input_args(char **argv, ModelTypes &model_types, bool &use_full_precision, bool &verbose, - int &tensor_parallelism_degree) { + int &data_parallelism_degree, + int &tensor_parallelism_degree, + int &pipeline_parallelism_degree) { for (int i = 1; i < argc; i++) { // llm model type if (!strcmp(argv[i], "-llm-model")) { @@ -115,11 +117,21 @@ void parse_input_args(char **argv, paths.output_file_path = std::string(argv[++i]); continue; } + // data parallelism degree + if (!strcmp(argv[i], "-data-parallelism-degree")) { + data_parallelism_degree = std::stoi(argv[++i]); + continue; + } // tensor parallelism degree if (!strcmp(argv[i], "-tensor-parallelism-degree")) { tensor_parallelism_degree = std::stoi(argv[++i]); continue; } + // pipeline parallelism degree + if (!strcmp(argv[i], "-pipeline-parallelism-degree")) { + pipeline_parallelism_degree = std::stoi(argv[++i]); + continue; + } if (!strcmp(argv[i], "--use-full-precision")) { use_full_precision = true; continue; @@ -141,7 +153,9 @@ void FlexFlow::top_level_task(Task const *task, ModelTypes model_types; bool use_full_precision = false; bool verbose = false; - int tensor_parallelism_degree = 1; + size_t num_devices = ffconfig.workersPerNode * ffconfig.numNodes; + int data_parallelism_degree = 1, tensor_parallelism_degree = 1, + pipeline_parallelism_degree = -1; InputArgs const &command_args = HighLevelRuntime::get_input_args(); char **argv = command_args.argv; @@ -152,8 +166,15 @@ void FlexFlow::top_level_task(Task const *task, model_types, use_full_precision, verbose, - tensor_parallelism_degree); + data_parallelism_degree, + tensor_parallelism_degree, + pipeline_parallelism_degree); + ffconfig.data_parallelism_degree = data_parallelism_degree; ffconfig.tensor_parallelism_degree = tensor_parallelism_degree; + ffconfig.pipeline_parallelism_degree = + pipeline_parallelism_degree == -1 + ? num_devices / (tensor_parallelism_degree * data_parallelism_degree) + : pipeline_parallelism_degree; if (file_paths.ssm_weight_file_paths.size() == 0) { assert(false && @@ -178,7 +199,7 @@ void FlexFlow::top_level_task(Task const *task, } // Create SentencePiece tokenizer or OPT tokenizer - InferenceManager im(ffconfig, BatchConfig::MAX_NUM_TOKENS, 1); + InferenceManager im(ffconfig, BatchConfig::MAX_NUM_TOKENS); RequestManager rm(model_types.llm_model_type, file_paths.tokenizer_file_path, /*verbose*/ verbose, diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index f844834761..67a78f9700 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -29,12 +29,32 @@ LegionRuntime::Logger::Category log_inf_mgr("InferenceManager"); LegionRuntime::Logger::Category log_offload("Offloading"); InferenceManager::InferenceManager(FFConfig const &_config, - int _max_num_tokens_per_batch, - int _max_num_inflight_batches) - : ff_config(_config), max_num_tokens_per_batch(_max_num_tokens_per_batch), - max_num_inflight_batches(_max_num_inflight_batches) { - // populate array of valid single-device machine views + int _max_num_tokens_per_batch) + : ff_config(_config), max_num_tokens_per_batch(_max_num_tokens_per_batch) { num_devices = ff_config.workersPerNode * ff_config.numNodes; + // Check parallelization degrees + assert(ff_config.data_parallelism_degree <= num_devices && + "Data parallelism degree exceeds number of available devices"); + assert(num_devices % ff_config.data_parallelism_degree == 0 && + "Number of available devices is not divisible by data parallelism " + "degree"); + assert(ff_config.tensor_parallelism_degree <= num_devices && + "Tensor parallelism degree exceeds number of available devices"); + assert(num_devices % ff_config.tensor_parallelism_degree == 0 && + "Number of available devices is not divisible by tensor parallelism " + "degree"); + assert(ff_config.pipeline_parallelism_degree <= num_devices && + "Pipeline parallelism degree exceeds number of available devices"); + assert(num_devices % ff_config.pipeline_parallelism_degree == 0 && + "Number of available devices is not divisible by pipeline parallelism " + "degree"); + assert(ff_config.data_parallelism_degree * + ff_config.tensor_parallelism_degree * + ff_config.pipeline_parallelism_degree == + num_devices && + "Product of data, tensor, and pipeline parallelism degrees does not " + "match the number of available devices"); + // populate array of valid single-device machine views for (int i = 0; i < num_devices; i++) { MachineView view; view.device_type = MachineView::GPU; @@ -90,6 +110,7 @@ void InferenceManager::compile_model_and_allocate_buffer( assert(pt->owner_op != nullptr); mapping[pt->owner_op] = it.second; } + // std::cout << std::endl << std::endl << "Operators MVs:" << std::endl; for (int op_idx = 0; op_idx < model->operators.size(); op_idx++) { Op const *op = model->operators[op_idx]; // Skip weight operators @@ -100,12 +121,12 @@ void InferenceManager::compile_model_and_allocate_buffer( std::vector machine_views; if (mapping.find(op) != mapping.end()) { machine_views = mapping[op]; - assert(machine_views.size() == max_num_inflight_batches); + assert(machine_views.size() == ff_config.data_parallelism_degree); } else { // Mapping the current operator using the same machine // view as the inputs assert(op->numInputs > 0); - for (int j = 0; j < max_num_inflight_batches; j++) { + for (int j = 0; j < ff_config.data_parallelism_degree; j++) { MachineView mv = tensor_buffer[op->inputs[0]][j]->machine_view; for (int k = 1; k < op->numInputs; k++) { if (mv != tensor_buffer[op->inputs[k]][j]->machine_view) { @@ -143,14 +164,14 @@ void InferenceManager::compile_model_and_allocate_buffer( assert(mv.start_device_id + mv.dim[0] <= num_devices); machine_views.push_back(mv); } - assert(machine_views.size() == max_num_inflight_batches); + assert(machine_views.size() == ff_config.data_parallelism_degree); } // std::cout << "operator: " << op->name << std::endl; // for (int i = 0; i < op->numInputs; i++) { // op->inputs[i]->print("input pt"); // std::cout << "input mv: " << op->inputs[i]->machine_view << std::endl; // } - + // std::cout << "Op " << op->name << ": "; for (int i = 0; i < op->numOutputs; i++) { ParallelTensor pt_base = op->outputs[i]; assert(tensor_buffer.find(pt_base) == tensor_buffer.end()); @@ -211,7 +232,7 @@ void InferenceManager::compile_model_and_allocate_buffer( } } if (!found_parallel_tensor) { - for (int j = 0; j < max_num_inflight_batches; j++) { + for (int j = 0; j < ff_config.data_parallelism_degree; j++) { // Copy the metadata from pt_base to pt ParallelTensor pt = new ParallelTensorBase(*pt_base); pt->region = @@ -221,6 +242,7 @@ void InferenceManager::compile_model_and_allocate_buffer( pt->part = runtime->get_logical_partition( ctx, pt->region, pt_base->part.get_index_partition()); pt->machine_view = machine_views[j]; + // std::cout << "output mv: " << pt->machine_view << std::endl; Domain part_domain = runtime->get_index_space_domain(ctx, pt_base->parallel_is); assert(pt->machine_view.get_domain() == part_domain); @@ -230,11 +252,12 @@ void InferenceManager::compile_model_and_allocate_buffer( assert(tensor_buffer.find(pt_base) == tensor_buffer.end()); tensor_buffer[pt_base] = list; } + // std::cout << std::endl; } } void InferenceManager::init_operators_inference(FFModel *model) { - for (int batch_index = 0; batch_index < max_num_inflight_batches; + for (int batch_index = 0; batch_index < ff_config.data_parallelism_degree; batch_index++) { int expert_device_index = 0; int device_index = batch_index % num_devices; @@ -290,7 +313,7 @@ FutureMap InferenceManager::inference(FFModel *model, assert(bc.num_active_tokens() > 0 && bc.num_active_requests() > 0); // We currently assume that the index-th batch will be placed // on the device_index-th device (except for the experts layers) - int batch_index = index % max_num_inflight_batches; + int batch_index = index % ff_config.data_parallelism_degree; FutureMap fm; bool found_input_operator = false; for (size_t o = 0; o < model->operators.size(); o++) { diff --git a/tests/inference_tests.sh b/tests/inference_tests.sh index 1262ec21d5..3e0d7cac53 100755 --- a/tests/inference_tests.sh +++ b/tests/inference_tests.sh @@ -48,9 +48,13 @@ mkdir -p ../inference/output # Tensor parallelism tests if [ "$TENSOR_PARALLELISM_TESTS" = "ON" ]; then + # LLAMA + ../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model llama -llm-weight ../inference/weights/llama_7B_weights/ -llm-config ../inference/models/configs/llama_7B.json -ssm-model llama -ssm-weight ../inference/weights/llama_160M_weights/ -ssm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_llama_tp.txt -tensor-parallelism-degree 2 # LLAMA (half precision) ../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight ../inference/weights/llama_7B_weights_half/ -llm-config ../inference/models/configs/llama_7B.json -ssm-model llama -ssm-weight ../inference/weights/llama_160M_weights_half/ -ssm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_llama_half_tp.txt -tensor-parallelism-degree 2 + # OPT + ../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model opt -llm-weight ../inference/weights/opt_6B_weights/ -llm-config ../inference/models/configs/opt_6B.json -ssm-model opt -ssm-weight ../inference/weights/opt_125M_weights/ -ssm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_opt_tp.txt -tensor-parallelism-degree 2 # OPT (half precision) ../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model opt -llm-weight ../inference/weights/opt_6B_weights_half/ -llm-config ../inference/models/configs/opt_6B.json -ssm-model opt -ssm-weight ../inference/weights/opt_125M_weights_half/ -ssm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_opt_half_tp.txt -tensor-parallelism-degree 2 fi @@ -86,6 +90,8 @@ if [ "$TENSOR_PARALLELISM_TESTS" = "ON" ]; then # LLAMA (small model, half precision) ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight ../inference/weights/llama_160M_weights_half/ -llm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_160M_half_tp.txt -tensor-parallelism-degree 2 + # LLAMA (big model) + ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model llama -llm-weight ../inference/weights/llama_7B_weights/ -llm-config ../inference/models/configs/llama_7B.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_7B_tp.txt -tensor-parallelism-degree 2 # LLAMA (big model, half precision) ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight ../inference/weights/llama_7B_weights_half/ -llm-config ../inference/models/configs/llama_7B.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_7B_half_tp.txt -tensor-parallelism-degree 2 @@ -94,6 +100,8 @@ if [ "$TENSOR_PARALLELISM_TESTS" = "ON" ]; then # OPT (small model, half precision) ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model opt -llm-weight ../inference/weights/opt_125M_weights_half/ -llm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_125M_half_tp.txt -tensor-parallelism-degree 2 + # OPT (big model) + ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model opt -llm-weight ../inference/weights/opt_6B_weights/ -llm-config ../inference/models/configs/opt_6B.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_6B_tp.txt -tensor-parallelism-degree 2 # OPT (big model, half precision) ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model opt -llm-weight ../inference/weights/opt_6B_weights_half/ -llm-config ../inference/models/configs/opt_6B.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_6B_half_tp.txt -tensor-parallelism-degree 2 fi @@ -143,13 +151,17 @@ compare_speed_spec_infer_incr_decoding "../inference/output/incr_decoding_opt_6B ############ Alignment between tensor model parallelism and pipeline parallelism only ################# if [ "$TENSOR_PARALLELISM_TESTS" = "ON" ]; then - # diff <(tail -n +2 "../inference/output/spec_inference_llama_half_tp.txt") <(tail -n +2 "../inference/output/spec_inference_llama_half.txt") + diff <(tail -n +2 "../inference/output/spec_inference_llama_tp.txt") <(tail -n +2 "../inference/output/spec_inference_llama.txt") + diff <(tail -n +2 "../inference/output/spec_inference_opt_tp.txt") <(tail -n +2 "../inference/output/spec_inference_opt.txt") + diff <(tail -n +2 "../inference/output/spec_inference_llama_half_tp.txt") <(tail -n +2 "../inference/output/spec_inference_llama_half.txt") diff <(tail -n +2 "../inference/output/spec_inference_opt_half_tp.txt") <(tail -n +2 "../inference/output/spec_inference_opt_half.txt") diff <(tail -n +2 "../inference/output/incr_decoding_llama_160M_tp.txt") <(tail -n +2 "../inference/output/incr_decoding_llama_160M.txt") - diff <(tail -n +2 "../inference/output/incr_decoding_llama_160M_half_tp.txt") <(tail -n +2 "../inference/output/incr_decoding_llama_160M_half.txt") + # diff <(tail -n +2 "../inference/output/incr_decoding_llama_160M_half_tp.txt") <(tail -n +2 "../inference/output/incr_decoding_llama_160M_half.txt") + diff <(tail -n +2 "../inference/output/incr_decoding_llama_7B_tp.txt") <(tail -n +2 "../inference/output/incr_decoding_llama_7B.txt") diff <(tail -n +2 "../inference/output/incr_decoding_llama_7B_half_tp.txt") <(tail -n +2 "../inference/output/incr_decoding_llama_7B_half.txt") diff <(tail -n +2 "../inference/output/incr_decoding_opt_125M_tp.txt") <(tail -n +2 "../inference/output/incr_decoding_opt_125M.txt") diff <(tail -n +2 "../inference/output/incr_decoding_opt_125M_half_tp.txt") <(tail -n +2 "../inference/output/incr_decoding_opt_125M_half.txt") + diff <(tail -n +2 "../inference/output/incr_decoding_opt_6B_tp.txt") <(tail -n +2 "../inference/output/incr_decoding_opt_6B.txt") diff <(tail -n +2 "../inference/output/incr_decoding_opt_6B_half_tp.txt") <(tail -n +2 "../inference/output/incr_decoding_opt_6B_half.txt") fi From e47a1795045c2fc4a0fe4fe54ab87bd601069d55 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Fri, 30 Jun 2023 15:01:33 +0800 Subject: [PATCH 06/12] Revert "[Inference] fix bug when init_length + beam_depth > max_num_tokens" (#821) --- .../cpp/inference/mixture_of_experts/moe.h | 4 ++-- include/flexflow/inference.h | 1 - inference/spec_infer/spec_infer.cc | 7 +++--- src/runtime/request_manager.cc | 23 ++----------------- 4 files changed, 7 insertions(+), 28 deletions(-) diff --git a/examples/cpp/inference/mixture_of_experts/moe.h b/examples/cpp/inference/mixture_of_experts/moe.h index 4fdd3b2e3f..183229bc07 100644 --- a/examples/cpp/inference/mixture_of_experts/moe.h +++ b/examples/cpp/inference/mixture_of_experts/moe.h @@ -22,9 +22,9 @@ struct MoeConfig : InferenceConfig { MoeConfig(void) : InferenceConfig() { //----------------------- MoE layer -------------------------------- // total number of experts - num_exp = 64; + num_exp = 128; // number of experts in each block of fused experts - experts_per_block = 16; + experts_per_block = 32; // number of experts to route each token to num_select = 2; // expert capacity parameters diff --git a/include/flexflow/inference.h b/include/flexflow/inference.h index ca3a61592f..1fd2fdff78 100644 --- a/include/flexflow/inference.h +++ b/include/flexflow/inference.h @@ -130,7 +130,6 @@ class RequestManager { &inputSerializedTree, std::vector> const &outputSerializedTree); - int get_requests_init_length(BeamSearchBatchConfig const &old_bc); static void load_tokens_task(Legion::Task const *task, diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index 2f581b7c34..fbb07b2b25 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -306,16 +306,15 @@ void FlexFlow::top_level_task(Task const *task, for (int i = 0; i < num_ssms; i++) { while (true) { - depth = beam_bc_vec[i].current_depth_all_requests(); + beam_bc = beam_bc_vec[i]; + depth = beam_bc.beamRequestsInfo[0].current_depth; FutureMap fm = im.inference(rm.get_model(0), 0, beam_bc_vec[i]); assert(fm.get_future_map_domain().get_volume() == 1); Future future = fm.get_future(0); BeamInferenceResult beam_ir = future.get_result(); - if (depth - 1 >= beam_bc_vec[i].max_beam_depth_all_requests() || - depth + 1 + rm.get_requests_init_length(beam_bc_vec[i]) >= - BatchConfig::MAX_NUM_TOKENS) { + if (depth - 1 >= BeamSearchBatchConfig::MAX_BEAM_DEPTH) { break; } else { beam_bc_vec[i] = rm.prepare_next_batch_beam(beam_bc_vec[i], beam_ir); diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 2211a8df78..56b9bf6241 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -334,25 +334,6 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, /* ----- Speculative Inference Specific functions ----- */ -int RequestManager::get_requests_init_length( - BeamSearchBatchConfig const &old_bc) { - int init_length = 0; - for (int i = 0; i < BatchConfig::MAX_NUM_REQUESTS; i++) { - if (old_bc.request_completed[i]) { - continue; - } - Request &request = - running_request_queue[old_bc.requestsInfo[i].request_guid]; - if (old_bc.requestsInfo[i].token_start_offset + 1 >= - request.tokens.size()) { - init_length = 0; - } else if (request.initial_len > init_length) { - init_length = request.initial_len; - } - } - return init_length; -} - // update beam search metadata BeamSearchBatchConfig RequestManager::prepare_next_batch_beam(BeamSearchBatchConfig const &old_bc, @@ -364,7 +345,7 @@ BeamSearchBatchConfig if (verbose) { std::cout << "print all results" << "\n"; - for (int i = 0; i < 64; i++) { + for (int i = 0; i < 40; i++) { std::cout << result.token_ids[i] << ", "; } std::cout << "Current Beam Depth: " @@ -423,7 +404,7 @@ BeamSearchBatchConfig new_bc.beamRequestsInfo[i].beam_size = old_bc.beamRequestsInfo[i].beam_size; new_bc.beamRequestsInfo[i].max_depth = - old_bc.beamRequestsInfo[i].current_depth; + old_bc.beamRequestsInfo[i].max_depth; // do the slot exchange to minimize the cache exchange in kernel. std::cout << "update metadata" << std::endl; From d038e946e4e0dd5fdf4048e698767b447425dda0 Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Wed, 5 Jul 2023 14:41:31 -0400 Subject: [PATCH 07/12] Merge `master` branch into `inference` (#835) * Fix directory in python example in INSTALL.md (#783) * Remove incomplete sentence in readme (#784) * Fix Code Color in README (#822) Specify code block is Python to have correct coloring in second code block in README.md * Update README.md (#824) Co-authored-by: Zhihao Jia * fix-link (#829) Co-authored-by: Kate Unger * Fix CUDA version in Docker image (11.7.0 to 11.7.1) (#833) --------- Co-authored-by: Colin Unger Co-authored-by: Kate Unger <32380357+KateUnger@users.noreply.github.com> Co-authored-by: Zhihao Jia Co-authored-by: Kate Unger Co-authored-by: DerrickYLJ <99985904+DerrickYLJ@users.noreply.github.com> --- INSTALL.md | 6 +++--- README.md | 9 ++++----- docker/flexflow-environment/Dockerfile | 2 +- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/INSTALL.md b/INSTALL.md index b0f8133483..4165683370 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -90,11 +90,11 @@ To run the Python examples, you have two options: you can use the `flexflow_pyth * `export PYTHONPATH="${FF_HOME}/python:${FF_HOME}/build/python"` * `export FF_USE_NATIVE_PYTHON=1` -**We recommend that you run the `mnist_mlp` test under `native` using the following cmd to check if FlexFlow has been installed correctly:** +**We recommend that you run the** `mnist_mlp` **test under** `native` **using the following cmd to check if FlexFlow has been installed correctly:** ``` -cd python -./flexflow_python examples/python/native/mnist_mlp.py -ll:py 1 -ll:gpu 1 -ll:fsize -ll:zsize +cd "$FF_HOME" +./python/flexflow_python examples/python/native/mnist_mlp.py -ll:py 1 -ll:gpu 1 -ll:fsize -ll:zsize ``` A script to run all the Python examples is available at `tests/multi_gpu_tests.sh` diff --git a/README.md b/README.md index 0420f8f902..c26904749d 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ FlexFlow is a deep learning framework that accelerates distributed DNN training by automatically searching for efficient parallelization strategies. FlexFlow provides a drop-in replacement for PyTorch and TensorFlow Keras. Running existing PyTorch and Keras programs in FlexFlow only requires [a few lines of changes to the program](https://flexflow.ai/keras). ## Install FlexFlow -To install FlexFlow from source code, please read the [instructions](INSTALL.md). If you would like to quickly try FlexFlow, we also provide pre-built Docker packages ([flexflow-cuda](https://github.com/flexflow/FlexFlow/pkgs/container/flexflow-cuda) with a CUDA backend, [flexflow-hip_rocm](https://github.com/flexflow/FlexFlow/pkgs/container/flexflow-hip_rocm) with a HIP-ROCM backend) with all dependencies pre-installed (N.B.: currently, the CUDA pre-built containers are only fully compatible with host machines that have CUDA 11.7 installed), together with [Dockerfiles](./docker) if you wish to build the containers manually. You can also use `conda` to install the FlexFlow Python package (coming soon). +To install FlexFlow from source code, please read the [instructions](https://flexflow.readthedocs.io/en/latest/installation.html). If you would like to quickly try FlexFlow, we also provide pre-built Docker packages ([flexflow-cuda](https://github.com/flexflow/FlexFlow/pkgs/container/flexflow-cuda) with a CUDA backend, [flexflow-hip_rocm](https://github.com/flexflow/FlexFlow/pkgs/container/flexflow-hip_rocm) with a HIP-ROCM backend) with all dependencies pre-installed (N.B.: currently, the CUDA pre-built containers are only fully compatible with host machines that have CUDA 11.7 installed), together with [Dockerfiles](./docker) if you wish to build the containers manually. You can also use `conda` to install the FlexFlow Python package (coming soon). ## PyTorch Support Users can also use FlexFlow to optimize the parallelization performance of existing PyTorch models in two steps. First, a PyTorch model can be exported to the FlexFlow model format using `flexflow.torch.fx.torch_to_flexflow`. @@ -18,7 +18,7 @@ fx.torch_to_flexflow(model, "mymodel.ff") Second, a FlexFlow program can directly import a previously saved PyTorch model and [autotune](https://www.usenix.org/conference/osdi22/presentation/unger) the parallelization performance for a given parallel machine. -``` +```python from flexflow.pytorch.model import PyTorchModel def top_level_task(): @@ -39,7 +39,7 @@ FlexFlow prioritizes PyTorch compatibility, but also includes frontends for [Ten ## C++ Interface For users that prefer to program in C/C++. FlexFlow supports a C++ program inference that is equivalent to its Python APIs. -**More FlexFlow C++ examples**: see the [C++ examples folder](https://github.com/flexflow/FlexFlow/tree/master/examples/c++). +**More FlexFlow C++ examples**: see the [C++ examples folder](https://github.com/flexflow/FlexFlow/tree/master/examples/cpp). ## Command-Line Flags @@ -69,12 +69,11 @@ Performance auto-tuning flags: For performance tuning related flags: see [performance autotuning](https://flexflow.ai/search). ## Contributing + Please let us know if you encounter any bugs or have any suggestions by [submitting an issue](https://github.com/flexflow/flexflow/issues). We welcome all contributions to FlexFlow from bug fixes to new features and extensions. -Please subscribe to the FlexFlow users mailing list for - ## Citations * Colin Unger, Zhihao Jia, Wei Wu, Sina Lin, Mandeep Baines, Carlos Efrain Quintero Narvaez, Vinay Ramakrishnaiah, Nirmal Prajapati, Pat McCormick, Jamaludin Mohd-Yusof, Xi Luo, Dheevatsa Mudigere, Jongsoo Park, Misha Smelyanskiy, and Alex Aiken. [Unity: Accelerating DNN Training Through Joint Optimization of Algebraic Transformations and Parallelization](https://www.usenix.org/conference/osdi22/presentation/unger). In Proceedings of the Symposium on Operating Systems Design and Implementation (OSDI), July 2022. diff --git a/docker/flexflow-environment/Dockerfile b/docker/flexflow-environment/Dockerfile index 598690a8a7..43c1599d0f 100644 --- a/docker/flexflow-environment/Dockerfile +++ b/docker/flexflow-environment/Dockerfile @@ -1,4 +1,4 @@ -FROM nvidia/cuda:11.7.0-cudnn8-devel-ubuntu20.04 +FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04 LABEL org.opencontainers.image.source=https://github.com/flexflow/FlexFlow LABEL org.opencontainers.image.description="FlexFlow environment container" From 869d166916c7167eb9dea39d63419e4163990453 Mon Sep 17 00:00:00 2001 From: zwang86 <46699021+zwang86@users.noreply.github.com> Date: Fri, 7 Jul 2023 20:11:07 -0400 Subject: [PATCH 08/12] Fixation. (#840) --- include/flexflow/batch_config.h | 1 + inference/spec_infer/spec_infer.cc | 6 +++++- src/runtime/request_manager.cc | 4 ++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/include/flexflow/batch_config.h b/include/flexflow/batch_config.h index b56466bfe5..61a1e345ae 100644 --- a/include/flexflow/batch_config.h +++ b/include/flexflow/batch_config.h @@ -116,6 +116,7 @@ class BeamSearchBatchConfig : public BatchConfig { inline static int const MAX_BEAM_DEPTH = 8; int model_id; + int max_init_length = 0; struct BeamSearchPerRequestInfo { int beam_size; diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index fbb07b2b25..e5a6c8d5e6 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -314,7 +314,11 @@ void FlexFlow::top_level_task(Task const *task, Future future = fm.get_future(0); BeamInferenceResult beam_ir = future.get_result(); - if (depth - 1 >= BeamSearchBatchConfig::MAX_BEAM_DEPTH) { + int iteration = + std::min(BeamSearchBatchConfig::MAX_BEAM_DEPTH, + BatchConfig::MAX_SEQ_LENGTH - beam_bc.max_init_length); + + if (depth - 1 >= iteration) { break; } else { beam_bc_vec[i] = rm.prepare_next_batch_beam(beam_bc_vec[i], beam_ir); diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index 56b9bf6241..b47b17ad12 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -357,6 +357,7 @@ BeamSearchBatchConfig // Step 2: preparing the next batch for existing requests BeamSearchBatchConfig new_bc; + new_bc.max_init_length = 0; new_bc.model_id = old_bc.model_id; std::cout << "old_bc.model_id: " << old_bc.model_id << "\n"; @@ -634,12 +635,15 @@ BeamSearchBatchConfig } // Step 2: Initialize new request + new_bc.max_init_length = 0; for (int i = 0; i < BeamSearchBatchConfig::MAX_NUM_REQUESTS; i++) { if (new_bc.request_completed[i]) { if (!pending_request_queue.empty() && new_bc.num_tokens < BeamSearchBatchConfig::MAX_NUM_TOKENS) { Request new_request = pending_request_queue.front(); pending_request_queue.pop(); + new_bc.max_init_length = + std::max(new_bc.max_init_length, new_request.initial_len); running_request_queue[new_request.guid] = new_request; new_bc.requestsInfo[i].token_start_offset = 0; new_bc.requestsInfo[i].request_guid = new_request.guid; From 93e3896d219496fee4b2b3c4518e20b32c51748f Mon Sep 17 00:00:00 2001 From: Gabriele Oliaro Date: Sat, 8 Jul 2023 09:21:04 -0400 Subject: [PATCH 09/12] [Inference] - Save output of inference test as an artifact (#845) --- .github/workflows/gpu-ci.yml | 9 +++++++++ tests/inference_tests.sh | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/.github/workflows/gpu-ci.yml b/.github/workflows/gpu-ci.yml index bdbb8a751b..699ca9fc11 100644 --- a/.github/workflows/gpu-ci.yml +++ b/.github/workflows/gpu-ci.yml @@ -182,6 +182,15 @@ jobs: # Inference tests export TENSOR_PARALLELISM_TESTS=ON ./tests/inference_tests.sh + cd inference + tar -zcvf output.tar.gz ./output + cd .. + + - name: Save inference output as an artifact + uses: actions/upload-artifact@v3 + with: + name: output + path: inference/output.tar.gz gpu-ci-flexflow: name: Single Machine, Multiple GPUs Tests diff --git a/tests/inference_tests.sh b/tests/inference_tests.sh index 3e0d7cac53..761c6cf332 100755 --- a/tests/inference_tests.sh +++ b/tests/inference_tests.sh @@ -207,4 +207,4 @@ diff <(tail -n +2 "../inference/output/huggingface_opt_125M_half.txt") <(tail -n ############################################################################################### # Clean up after test -cleanup +# cleanup From 53c5617a8e5149ca1475978f391e4eb73c3434c5 Mon Sep 17 00:00:00 2001 From: Zhihao Jia Date: Sun, 9 Jul 2023 21:53:47 -0500 Subject: [PATCH 10/12] Using AllReduce instead of Reduce + Replicate when tensor model parallelism is enabled (#813) * [AllReduce] initial implementation * checkpoint * format * fusion * support half precision in fusedop * format * checkpoint * bug fixes * fix a performance issue in linear inference * fix * fix * fix specinfer and incr decoding * update readme * default data_parallelism_degree=1 * fix fusion * reduce unnecessary calculation. * makefile & rocm cmake fixes * only compare first 30 tokens in half precision * fix test script * check incr decoding steps instead of latency * hip rocm fix * makefile fix * more inference test fixes * update fusedop to support specinfer * fix rocm linking issue --------- Co-authored-by: Gabriele Oliaro Co-authored-by: xinhaoc --- .github/README.md | 5 +- .../cpp/inference/mixture_of_experts/moe.cc | 3 +- .../inference/transformers/transformers.cc | 3 +- include/flexflow/config.h | 11 +- include/flexflow/ffconst.h | 2 + include/flexflow/fftype.h | 7 +- include/flexflow/inference.h | 4 +- include/flexflow/model.h | 13 + include/flexflow/operator_params.h | 2 + include/flexflow/ops/arg_topk.h | 9 +- include/flexflow/ops/arg_topk_params.h | 2 + include/flexflow/ops/element_binary.h | 10 +- include/flexflow/ops/element_binary_params.h | 2 + include/flexflow/ops/fused.h | 13 + include/flexflow/ops/kernels/linear_kernels.h | 1 - include/flexflow/ops/layer_norm.h | 8 +- include/flexflow/ops/linear.h | 4 + include/flexflow/parallel_ops/allreduce.h | 70 +++ .../flexflow/parallel_ops/allreduce_params.h | 21 + include/flexflow/parallel_ops/combine.h | 14 + .../parallel_ops/kernels/allreduce_kernels.h | 31 ++ .../parallel_ops/kernels/combine_kernels.h | 1 + include/flexflow/utils/cuda_helper.h | 6 + inference/incr_decoding/incr_decoding.cc | 14 +- inference/models/falcon.cc | 27 +- inference/models/llama.cc | 37 +- inference/models/llama.h | 1 - inference/models/opt.cc | 47 +- inference/models/opt.h | 1 - inference/spec_infer/spec_infer.cc | 23 +- src/ops/arg_topk.cc | 27 +- src/ops/arg_topk.cpp | 7 +- src/ops/arg_topk.cu | 8 +- src/ops/beam_topk.cc | 6 +- src/ops/conv_2d.cc | 6 +- src/ops/element_binary.cc | 61 ++- src/ops/element_unary.cc | 6 +- src/ops/experts.cc | 6 +- src/ops/fused.cc | 155 +++++- src/ops/fused.cpp | 417 +++++++++++++++++ src/ops/fused.cu | 442 +++++++++++++++++- src/ops/gather.cc | 6 +- src/ops/inc_multihead_self_attention.cu | 2 +- src/ops/kernels/linear_kernels.cpp | 37 +- src/ops/kernels/linear_kernels.cu | 39 +- src/ops/layer_norm.cc | 74 +-- src/ops/layer_norm.cpp | 8 +- src/ops/layer_norm.cu | 8 +- src/ops/linear.cc | 83 +++- src/ops/reduce.cc | 6 +- src/ops/reshape.cc | 6 +- src/ops/rms_norm.cc | 23 +- src/ops/tree_inc_multihead_self_attention.cu | 2 +- src/parallel_ops/allreduce.cc | 362 ++++++++++++++ src/parallel_ops/combine.cc | 121 ++++- .../kernels/allreduce_kernels.cpp | 46 ++ src/parallel_ops/kernels/allreduce_kernels.cu | 56 +++ src/parallel_ops/kernels/combine_kernels.cpp | 6 + src/parallel_ops/kernels/combine_kernels.cu | 6 + src/runtime/cuda_helper.cu | 24 + src/runtime/ffconst_utils.cc | 2 + src/runtime/fftype.cc | 14 +- src/runtime/graph.cc | 124 +++-- src/runtime/hip_helper.cpp | 17 +- src/runtime/inference_manager.cc | 151 +++--- src/runtime/layer.cc | 10 +- src/runtime/model.cc | 128 ++++- src/runtime/operator_params.cc | 3 + src/runtime/request_manager.cc | 4 + src/runtime/substitution.cc | 21 +- tests/inference_tests.sh | 169 ++++--- 71 files changed, 2605 insertions(+), 486 deletions(-) create mode 100644 include/flexflow/parallel_ops/allreduce.h create mode 100644 include/flexflow/parallel_ops/allreduce_params.h create mode 100644 include/flexflow/parallel_ops/kernels/allreduce_kernels.h create mode 100644 src/parallel_ops/allreduce.cc create mode 100644 src/parallel_ops/kernels/allreduce_kernels.cpp create mode 100644 src/parallel_ops/kernels/allreduce_kernels.cu diff --git a/.github/README.md b/.github/README.md index 576b1ca84e..c4f220e222 100644 --- a/.github/README.md +++ b/.github/README.md @@ -44,7 +44,10 @@ The source code of the SpecInfer pipeline is available at [this folder](../infer * `-ssm-weight`: path to the folder that stores the small speculative models' weights. The number of `-ssm-weight`s must match the number of `-ssm-model`s and `-ssm-config`s. * `-ssm-config`: path to the json file that stores the SSM model configs. The number of `-ssm-config`s must match the number of `-ssm-model`s and `-ssm-weight`s. * `-tokenizer`: path to the tokenizer file (see [Tokenizers](#tokenizers) for preparing a tokenizer for SpecInfer). +* `-data-parallelism-degree`, `-tensor-parallelism-degree` and `-pipeline-parallelism-degree`: parallelization degrees in the data, tensor, and pipeline dimensions. Their product must equal the number of GPUs available on the machine. When any of the three parallelism degree arguments is omitted, a default value of 1 will be used. * `-prompt`: (optional) path to the prompt file. SpecInfer expects a json format file for prompts, all of which will be served by SpecInfer. In addition, users can also use the following API for registering requests: +* `-output-file`: (optional) filepath to use to save the output of the model, together with the generation latency + ```c++ class RequestManager { @@ -54,7 +57,7 @@ class RequestManager { For example, you can use the following command line to serve a LLaMA-7B or LLaMA-13B model on 4 GPUs and use two collectively boost-tuned LLaMA-190M models for speculative inference. ```bash -./inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight /path/to/llm/weights -llm-config /path/to/llm/config.json -ssm-model llama -ssm-weight /path/to/ssm1/weights -ssm-config /path/to/ssm/config.json -ssm-model llama -smm-weight /path/to/ssm2/weights -ssm-config /path/to/ssm2/config.json -tokenizer /path/to/tokenizer.model -prompt /path/to/prompt.json --use-full-precision +./inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight /path/to/llm/weights -llm-config /path/to/llm/config.json -ssm-model llama -ssm-weight /path/to/ssm1/weights -ssm-config /path/to/ssm/config.json -ssm-model llama -smm-weight /path/to/ssm2/weights -ssm-config /path/to/ssm2/config.json -tokenizer /path/to/tokenizer.model -prompt /path/to/prompt.json --use-full-precision -tensor-parallelism-degree 2 -pipeline-parallelism-degree 2 ``` ### Tokenizers diff --git a/examples/cpp/inference/mixture_of_experts/moe.cc b/examples/cpp/inference/mixture_of_experts/moe.cc index 39459d63ac..ff3f6bb53a 100644 --- a/examples/cpp/inference/mixture_of_experts/moe.cc +++ b/examples/cpp/inference/mixture_of_experts/moe.cc @@ -140,8 +140,7 @@ void FlexFlow::top_level_task(Task const *task, //------------------- Initialize the inference manager ------------------ InferenceManager im(ff.config, moeConfig.batch_size); - std::unordered_map> mapping; - im.compile_model_and_allocate_buffer(&ff, mapping); + im.compile_model_and_allocate_buffer(&ff); im.init_operators_inference(&ff); //------------ Initialize the data loader and data generator ------------ diff --git a/examples/cpp/inference/transformers/transformers.cc b/examples/cpp/inference/transformers/transformers.cc index d56473c8bd..074e832d47 100644 --- a/examples/cpp/inference/transformers/transformers.cc +++ b/examples/cpp/inference/transformers/transformers.cc @@ -115,8 +115,7 @@ void FlexFlow::top_level_task(Task const *task, //------------------- Initialize the inference manager ------------------ InferenceManager im(ff.config, transformerConfig.batch_size); - std::unordered_map> mapping; - im.compile_model_and_allocate_buffer(&ff, mapping); + im.compile_model_and_allocate_buffer(&ff); im.init_operators_inference(&ff); //------------ Initialize the data loader and data generator ------------ diff --git a/include/flexflow/config.h b/include/flexflow/config.h index f1b218e50f..be6c0d21da 100644 --- a/include/flexflow/config.h +++ b/include/flexflow/config.h @@ -37,14 +37,15 @@ namespace FlexFlow { // ======================================================== // Define Runtime Constants // ======================================================== -#define MAX_NUM_INPUTS 256 -#define MAX_NUM_WEIGHTS 64 -#define MAX_NUM_OUTPUTS 256 -#define MAX_NUM_FUSED_OPERATORS 64 -#define MAX_NUM_FUSED_TENSORS 64 +#define MAX_NUM_INPUTS 2048 +#define MAX_NUM_WEIGHTS 2048 +#define MAX_NUM_OUTPUTS 2048 +#define MAX_NUM_FUSED_OPERATORS 2048 +#define MAX_NUM_FUSED_TENSORS 2048 #define MAX_NUM_WORKERS 1024 #define MAX_FILENAME 200 #define MAX_OPNAME 128 +#define MAX_NUM_TRANSFORMER_LAYERS 100 // DataLoader #define MAX_SAMPLES_PER_LOAD 64 #define MAX_FILE_LENGTH 128 diff --git a/include/flexflow/ffconst.h b/include/flexflow/ffconst.h index 0b572a9674..3d899ac91d 100644 --- a/include/flexflow/ffconst.h +++ b/include/flexflow/ffconst.h @@ -173,6 +173,7 @@ enum OperatorType { OP_REPLICATE, OP_REDUCTION, OP_PIPELINE, + OP_ALLREDUCE, OP_FUSED_PARALLEL, OP_INVALID, }; @@ -207,6 +208,7 @@ enum PMParameter { PM_COMBINE_DEGREE, // Combine PM_REDUCTION_DIM, // Reduction PM_REDUCTION_DEGREE, // Reduction + PM_ALLREDUCE_DIM, // AllReduce PM_SOFTMAX_DIM, // Softmax PM_NUM_HEADS, // MultiHeadAttention PM_INVALID, diff --git a/include/flexflow/fftype.h b/include/flexflow/fftype.h index a71c85dbc8..18ed6b8100 100644 --- a/include/flexflow/fftype.h +++ b/include/flexflow/fftype.h @@ -8,15 +8,16 @@ namespace FlexFlow { class LayerID { public: + static const LayerID NO_ID; LayerID(); - LayerID(size_t id); + LayerID(size_t id, size_t transformer_layer_id); bool is_valid_id() const; friend bool operator==(LayerID const &lhs, LayerID const &rhs); public: - size_t id; + size_t id, transformer_layer_id; }; }; // namespace FlexFlow -#endif // _FF_TYPE_H \ No newline at end of file +#endif // _FF_TYPE_H diff --git a/include/flexflow/inference.h b/include/flexflow/inference.h index 1fd2fdff78..a1846c96dc 100644 --- a/include/flexflow/inference.h +++ b/include/flexflow/inference.h @@ -29,9 +29,7 @@ using tokenizers::Tokenizer; class InferenceManager { public: InferenceManager(FFConfig const &config, int max_num_tokens_per_batch); - void compile_model_and_allocate_buffer( - FFModel *model, - std::unordered_map> const &mapping); + void compile_model_and_allocate_buffer(FFModel *model); void init_operators_inference(FFModel *model); MachineView *get_machine_view(int mv_id); Legion::FutureMap inference(FFModel *model, int index, BatchConfig const &bc); diff --git a/include/flexflow/model.h b/include/flexflow/model.h index 1277b29b3d..2b95eecac0 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -104,6 +104,7 @@ enum TaskIDs { LAYERNORM_BWD_TASK_ID, LINEAR_INIT_TASK_ID, LINEAR_INIT_PARA_TASK_ID, + LINEAR_INF_TASK_ID, LINEAR_FWD_TASK_ID, LINEAR_BWD_TASK_ID, LINEAR_BWD2_TASK_ID, @@ -159,6 +160,7 @@ enum TaskIDs { FUSEDOP_INIT_TASK_ID, FUSEDOP_FWD_TASK_ID, FUSEDOP_BWD_TASK_ID, + FUSEDOP_INF_TASK_ID, NOOP_INIT_TASK_ID, // Metrics tasks METRICS_COMP_TASK_ID, @@ -212,6 +214,9 @@ enum TaskIDs { PIPELINE_INIT_TASK_ID, PIPELINE_FWD_TASK_ID, PIPELINE_BWD_TASK_ID, + ALLREDUCE_INIT_TASK_ID, + ALLREDUCE_FWD_TASK_ID, + ALLREDUCE_BWD_TASK_ID, FUSED_PARALLELOP_INIT_TASK_ID, FUSED_PARALLELOP_FWD_TASK_ID, FUSED_PARALLELOP_BWD_TASK_ID, @@ -311,6 +316,7 @@ class Combine; class Repartition; class Reduction; class Replicate; +class AllReduce; class FusedParallelOp; class ParallelOpInfo; @@ -897,6 +903,9 @@ class FFModel { std::vector const ®ions, Legion::Context ctx, Legion::Runtime *runtime); + // ======================================== + // Internal APIs that should not be invoked from applications + // ======================================== void reset_metrics(); void init_operators(); void init_operators_inference( @@ -919,6 +928,7 @@ class FFModel { std::vector const &metrics, CompMode comp_mode = COMP_MODE_TRAINING); void compile_inference(); + void set_transformer_layer_id(int id); void graph_optimize(size_t budget, bool only_data_parallel, std::unique_ptr &best_graph, @@ -975,6 +985,7 @@ class FFModel { public: size_t op_global_guid, layer_global_guid; size_t tensor_global_guid, parallel_tensor_global_guid, node_global_guid; + size_t current_transformer_layer_id; FFConfig config; FFIterationConfig iter_config; Optimizer *optimizer; @@ -1078,6 +1089,8 @@ class FFModel { Reduction *>, std::unordered_map, Combine *>, + std::unordered_map, + AllReduce *>, std::unordered_map, FusedParallelOp *>> cached_ops; diff --git a/include/flexflow/operator_params.h b/include/flexflow/operator_params.h index 8c52dfb584..f6918ff581 100644 --- a/include/flexflow/operator_params.h +++ b/include/flexflow/operator_params.h @@ -32,6 +32,7 @@ #include "flexflow/ops/topk_params.h" #include "flexflow/ops/transpose_params.h" #include "flexflow/ops/tree_inc_multihead_self_attention_params.h" +#include "flexflow/parallel_ops/allreduce_params.h" #include "flexflow/parallel_ops/combine_params.h" #include "flexflow/parallel_ops/fused_parallel_op_params.h" #include "flexflow/parallel_ops/partition_params.h" @@ -76,6 +77,7 @@ using OperatorParameters = mp::variant; tl::optional get_op_parameters(Op const *op); diff --git a/include/flexflow/ops/arg_topk.h b/include/flexflow/ops/arg_topk.h index a00ab76385..ed92200fbe 100644 --- a/include/flexflow/ops/arg_topk.h +++ b/include/flexflow/ops/arg_topk.h @@ -19,11 +19,15 @@ class ArgTopK : public Op { using Params = ArgTopKParams; using Input = ParallelTensor; ArgTopK(FFModel &model, + LayerID const &layer_guid, const ParallelTensor input, int k, bool sorted, char const *name); - ArgTopK(FFModel &model, ArgTopK const &other, const ParallelTensor input); + ArgTopK(FFModel &model, + LayerID const &layer_guid, + ArgTopK const &other, + const ParallelTensor input); ArgTopK(FFModel &model, Params const ¶ms, Input const input, @@ -80,7 +84,8 @@ class ArgTopK : public Op { ffStream_t stream); static void forward_kernel_wrapper(ArgTopKMeta const *m, GenericTensorAccessorR const &input, - GenericTensorAccessorW const &indices); + GenericTensorAccessorW const &indices, + int batch_size); Params get_params() const; public: diff --git a/include/flexflow/ops/arg_topk_params.h b/include/flexflow/ops/arg_topk_params.h index ca88a5b9be..9d2a21034f 100644 --- a/include/flexflow/ops/arg_topk_params.h +++ b/include/flexflow/ops/arg_topk_params.h @@ -2,11 +2,13 @@ #define _FLEXFLOW_ARG_TOPK_PARAMS_H #include "flexflow/ffconst.h" +#include "flexflow/fftype.h" #include "flexflow/parallel_tensor.h" namespace FlexFlow { struct ArgTopKParams { + LayerID layer_guid; int k; bool sorted; bool is_valid(ParallelTensorShape const &) const; diff --git a/include/flexflow/ops/element_binary.h b/include/flexflow/ops/element_binary.h index 9c2e6c1252..fe7dc2602c 100644 --- a/include/flexflow/ops/element_binary.h +++ b/include/flexflow/ops/element_binary.h @@ -15,6 +15,7 @@ class ElementBinary : public Op { using Input = std::pair; ElementBinary(FFModel &model, + LayerID const &layer_guid, OperatorType type, const ParallelTensor x, const ParallelTensor y, @@ -23,8 +24,7 @@ class ElementBinary : public Op { ElementBinary(FFModel &model, Params const ¶ms, Input const &inputs, - char const *name = nullptr, - bool inplace_a = false); + char const *name = nullptr); void init(FFModel const &) override; void init_inference(FFModel const &, std::vector const &, @@ -63,6 +63,12 @@ class ElementBinary : public Op { bool measure_operator_cost(Simulator *sim, MachineView const &pc, CostMetrics &cost_metrics) const override; + + void serialize(Legion::Serializer &) const override; + static PCG::Node deserialize(FFModel &ff, + Legion::Deserializer &d, + ParallelTensor inputs[], + int num_inputs); Params get_params() const; public: diff --git a/include/flexflow/ops/element_binary_params.h b/include/flexflow/ops/element_binary_params.h index 5aa20e25a5..8b26877af2 100644 --- a/include/flexflow/ops/element_binary_params.h +++ b/include/flexflow/ops/element_binary_params.h @@ -7,7 +7,9 @@ namespace FlexFlow { struct ElementBinaryParams { + LayerID layer_guid; OperatorType type; + bool inplace_a; bool is_valid( std::pair const &) const; diff --git a/include/flexflow/ops/fused.h b/include/flexflow/ops/fused.h index 87d35da902..87e562d143 100644 --- a/include/flexflow/ops/fused.h +++ b/include/flexflow/ops/fused.h @@ -29,8 +29,17 @@ class FusedOp : public Op { return ParallelTensor(); } void init(FFModel const &) override; + void init_inference(FFModel const &, + std::vector const &, + std::vector const &, + MachineView const *mv = nullptr) override; void forward(FFModel const &) override; void backward(FFModel const &) override; + Legion::FutureMap inference(FFModel const &, + BatchConfig const &, + std::vector const &, + std::vector const &, + MachineView const *mv = nullptr) override; void print_layer(FFModel const &model) override { assert(0); } @@ -38,6 +47,10 @@ class FusedOp : public Op { std::vector const ®ions, Legion::Context ctx, Legion::Runtime *runtime); + static void inference_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); static void forward_task(Legion::Task const *task, std::vector const ®ions, Legion::Context ctx, diff --git a/include/flexflow/ops/kernels/linear_kernels.h b/include/flexflow/ops/kernels/linear_kernels.h index 9644fd9c8f..29791b53ff 100644 --- a/include/flexflow/ops/kernels/linear_kernels.h +++ b/include/flexflow/ops/kernels/linear_kernels.h @@ -33,7 +33,6 @@ class LinearMeta : public OpMeta { RegularizerMode kernel_reg_type; float kernel_reg_lambda; bool use_bias, add_bias_only_once; - DataType input_type, weight_type, output_type; char op_name[MAX_OPNAME]; }; diff --git a/include/flexflow/ops/layer_norm.h b/include/flexflow/ops/layer_norm.h index b962edf326..b5a36262b4 100644 --- a/include/flexflow/ops/layer_norm.h +++ b/include/flexflow/ops/layer_norm.h @@ -72,14 +72,14 @@ class LayerNorm : public Op { static void forward_kernel(LayerNormMeta const *m, T const *input_ptr, T *output_ptr, - T *gamma_ptr, - T *beta_ptr, + T const *gamma_ptr, + T const *beta_ptr, ffStream_t stream); static void forward_kernel_wrapper(LayerNormMeta const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW &output, - GenericTensorAccessorW &gamma, - GenericTensorAccessorW &beta); + GenericTensorAccessorR const &gamma, + GenericTensorAccessorR const &beta); template static void backward_kernel(LayerNormMeta const *m, T const *output_grad_ptr, diff --git a/include/flexflow/ops/linear.h b/include/flexflow/ops/linear.h index 7b134502b7..ff6ba1ef90 100644 --- a/include/flexflow/ops/linear.h +++ b/include/flexflow/ops/linear.h @@ -62,6 +62,10 @@ class Linear : public Op { std::vector const ®ions, Legion::Context ctx, Legion::Runtime *runtime); + static void inference_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); static void forward_task(Legion::Task const *task, std::vector const ®ions, Legion::Context ctx, diff --git a/include/flexflow/parallel_ops/allreduce.h b/include/flexflow/parallel_ops/allreduce.h new file mode 100644 index 0000000000..2faf128d93 --- /dev/null +++ b/include/flexflow/parallel_ops/allreduce.h @@ -0,0 +1,70 @@ +#ifndef _FLEXFLOW_ALLREDUCE_H +#define _FLEXFLOW_ALLREDUCE_H + +#include "flexflow/layer.h" +#include "flexflow/node.h" +#include "flexflow/op_meta.h" +#include "flexflow/operator.h" +#include "flexflow/parallel_ops/allreduce_params.h" +#include "parallel_op.h" + +namespace FlexFlow { + +class AllReduce : public ParallelOp { +public: + using Params = AllReduceParams; + using Input = ParallelTensor; + + AllReduce(FFModel &model, + const ParallelTensor input, + int allreduce_legion_dim, + char const *name = NULL); + AllReduce(FFModel &model, + Params const ¶ms, + Input const input, + char const *name = nullptr); + void create_input_partition(FFModel &model) override; + void create_input_partition_inference( + FFModel &model, + std::vector const &batch_inputs, + std::vector const &batch_outputs) override; + void init(FFModel const &) override; + void init_inference(FFModel const &, + std::vector const &, + std::vector const &, + MachineView const *mv = nullptr) override; + void forward(FFModel const &) override; + Legion::FutureMap inference(FFModel const &, + BatchConfig const &bc, + std::vector const &, + std::vector const &, + MachineView const *mv = nullptr) override; + void backward(FFModel const &) override; + bool get_int_parameter(PMParameter, int *) const override; + bool append_parallel_op_info( + std::vector ¶llel_ops) const override; + static OpMeta *init_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + static void forward_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + static void backward_task(Legion::Task const *task, + std::vector const ®ions, + Legion::Context ctx, + Legion::Runtime *runtime); + bool measure_operator_cost(Simulator *sim, + MachineView const &pc, + CostMetrics &cost_metrics) const override; + + Params get_params() const; + +public: + int allreduce_dim; +}; + +}; // namespace FlexFlow + +#endif // _FLEXFLOW_ALLREDUCE_H diff --git a/include/flexflow/parallel_ops/allreduce_params.h b/include/flexflow/parallel_ops/allreduce_params.h new file mode 100644 index 0000000000..c04676ffeb --- /dev/null +++ b/include/flexflow/parallel_ops/allreduce_params.h @@ -0,0 +1,21 @@ +#ifndef _FLEXFLOW_ALLREDUCE_PARAMS_H +#define _FLEXFLOW_ALLREDUCE_PARAMS_H + +namespace FlexFlow { + +struct AllReduceParams { + int allreduce_legion_dim; + bool is_valid(ParallelTensorShape const &) const; +}; +bool operator==(AllReduceParams const &, AllReduceParams const &); + +} // namespace FlexFlow + +namespace std { +template <> +struct hash { + size_t operator()(FlexFlow::AllReduceParams const &) const; +}; +} // namespace std + +#endif // _FLEXFLOW_ALLREDUCE_PARAMS_H diff --git a/include/flexflow/parallel_ops/combine.h b/include/flexflow/parallel_ops/combine.h index 310e599f54..d09a789de2 100644 --- a/include/flexflow/parallel_ops/combine.h +++ b/include/flexflow/parallel_ops/combine.h @@ -3,6 +3,7 @@ #include "flexflow/layer.h" #include "flexflow/node.h" +#include "flexflow/op_meta.h" #include "flexflow/operator.h" #include "flexflow/parallel_ops/combine_params.h" #include "parallel_op.h" @@ -24,8 +25,21 @@ class Combine : public ParallelOp { Input const input, char const *name = nullptr); void create_input_partition(FFModel &model) override; + void create_input_partition_inference( + FFModel &model, + std::vector const &batch_inputs, + std::vector const &batch_outputs) override; void init(FFModel const &) override; + void init_inference(FFModel const &, + std::vector const &, + std::vector const &, + MachineView const *mv = nullptr) override; void forward(FFModel const &) override; + Legion::FutureMap inference(FFModel const &, + BatchConfig const &bc, + std::vector const &, + std::vector const &, + MachineView const *mv = nullptr) override; void backward(FFModel const &) override; bool get_int_parameter(PMParameter, int *) const override; bool append_parallel_op_info( diff --git a/include/flexflow/parallel_ops/kernels/allreduce_kernels.h b/include/flexflow/parallel_ops/kernels/allreduce_kernels.h new file mode 100644 index 0000000000..02a5026fcf --- /dev/null +++ b/include/flexflow/parallel_ops/kernels/allreduce_kernels.h @@ -0,0 +1,31 @@ +#ifndef _FLEXFLOW_OPS_KERNELS_ALLREDUCE_KERNELS_H +#define _FLEXFLOW_OPS_KERNELS_ALLREDUCE_KERNELS_H + +#include "flexflow/device.h" +#include "flexflow/fftype.h" +#include "flexflow/op_meta.h" +#include "flexflow/parallel_ops/allreduce.h" + +namespace FlexFlow { + +class AllReduceMeta : public OpMeta { +public: + AllReduceMeta(FFHandler handle, AllReduce const *reduct); +}; + +namespace Kernels { +namespace AllReduce { + +void forward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output); + +void backward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output_grad); + +} // namespace AllReduce +} // namespace Kernels +} // namespace FlexFlow + +#endif // _FLEXFLOW_OPS_KERNELS_ALLREDUCE_KERNELS_H diff --git a/include/flexflow/parallel_ops/kernels/combine_kernels.h b/include/flexflow/parallel_ops/kernels/combine_kernels.h index 6f540679a2..456013cd81 100644 --- a/include/flexflow/parallel_ops/kernels/combine_kernels.h +++ b/include/flexflow/parallel_ops/kernels/combine_kernels.h @@ -4,6 +4,7 @@ #include "flexflow/device.h" #include "flexflow/fftype.h" #include "flexflow/op_meta.h" +#include "flexflow/parallel_ops/combine.h" namespace FlexFlow { diff --git a/include/flexflow/utils/cuda_helper.h b/include/flexflow/utils/cuda_helper.h index 5ac4571118..1787c5a0b7 100644 --- a/include/flexflow/utils/cuda_helper.h +++ b/include/flexflow/utils/cuda_helper.h @@ -4,6 +4,9 @@ #include "legion.h" #include #include +#ifdef FF_USE_NCCL +#include +#endif #define FatalError(s) \ do { \ @@ -165,6 +168,9 @@ cudnnStatus_t cudaDataType_t ff_to_cuda_datatype(DataType type); cudnnDataType_t ff_to_cudnn_datatype(DataType type); +#ifdef FF_USE_NCCL +ncclDataType_t ff_to_nccl_datatype(DataType type); +#endif cudaDataType_t cudnn_to_cuda_datatype(cudnnDataType_t type); cudnnDataType_t cuda_to_cudnn_datatype(cudaDataType_t type); diff --git a/inference/incr_decoding/incr_decoding.cc b/inference/incr_decoding/incr_decoding.cc index d43cab17f9..68a8e10042 100644 --- a/inference/incr_decoding/incr_decoding.cc +++ b/inference/incr_decoding/incr_decoding.cc @@ -126,7 +126,7 @@ void FlexFlow::top_level_task(Task const *task, bool verbose = false; size_t num_devices = ffconfig.workersPerNode * ffconfig.numNodes; int data_parallelism_degree = 1, tensor_parallelism_degree = 1, - pipeline_parallelism_degree = -1; + pipeline_parallelism_degree = 1; InputArgs const &command_args = HighLevelRuntime::get_input_args(); char **argv = command_args.argv; @@ -142,10 +142,10 @@ void FlexFlow::top_level_task(Task const *task, pipeline_parallelism_degree); ffconfig.data_parallelism_degree = data_parallelism_degree; ffconfig.tensor_parallelism_degree = tensor_parallelism_degree; - ffconfig.pipeline_parallelism_degree = - pipeline_parallelism_degree == -1 - ? num_devices / (tensor_parallelism_degree * data_parallelism_degree) - : pipeline_parallelism_degree; + ffconfig.pipeline_parallelism_degree = pipeline_parallelism_degree; + assert(data_parallelism_degree * tensor_parallelism_degree * + pipeline_parallelism_degree == + ffconfig.numNodes * ffconfig.workersPerNode); assert(model_type != ModelType::UNKNOWN && "Invalid LLM model type passed (or no type was passed)."); @@ -162,8 +162,6 @@ void FlexFlow::top_level_task(Task const *task, im, file_paths.llm_config_file_path, file_paths.llm_weight_file_path, - ffconfig.workersPerNode * ffconfig.numNodes / - tensor_parallelism_degree, INC_DECODING_MODE, use_full_precision); } else if (model_type == ModelType::OPT) { @@ -171,8 +169,6 @@ void FlexFlow::top_level_task(Task const *task, im, file_paths.llm_config_file_path, file_paths.llm_weight_file_path, - ffconfig.workersPerNode * ffconfig.numNodes / - tensor_parallelism_degree, INC_DECODING_MODE, use_full_precision); } else if (model_type == ModelType::FALCON) { diff --git a/inference/models/falcon.cc b/inference/models/falcon.cc index 7fc3124278..bced5dc1e0 100644 --- a/inference/models/falcon.cc +++ b/inference/models/falcon.cc @@ -28,20 +28,6 @@ void FALCON::create_falcon_model(FFModel &ff, bool use_full_precision) { Config falcon_config(model_config_file_path); falcon_config.printConfig(); - //------------------------------compute machine views ------------------ - int num_devices = ff.config.workersPerNode * ff.config.numNodes; - std::vector machine_views; - for (int i = 0; i < num_devices; i++) { - MachineView view; - view.device_type = MachineView::GPU; - view.ndims = 1; - view.dim[0] = 1; - view.stride[0] = 0; - view.start_device_id = i; - machine_views.push_back(view); - } - - std::unordered_map> mapping; std::unordered_map weights_layers; Tensor input; @@ -50,7 +36,6 @@ void FALCON::create_falcon_model(FFModel &ff, int const token_dims[] = {BatchConfig::MAX_NUM_TOKENS, 1}; input = ff.create_tensor<2>(token_dims, DT_INT32); } - mapping[input].push_back(machine_views[0]); Initializer *embed_init = new UniformInitializer(std::rand(), 0, 0); @@ -83,18 +68,12 @@ void FALCON::create_falcon_model(FFModel &ff, (num_transformer_layers + num_pipeline_stages - 1) / num_pipeline_stages; for (int i = 0; i < num_transformer_layers; i++) { + // set transformer layer id + ff.set_transformer_layer_id(i); // step 1: attention Tensor att_norm = ff.layer_norm(token, axes, true, falcon_config.norm_eps); Layer *attention_norm = ff.layers.back(); - if (i % num_transformer_layers_per_stage == 0) { - // Map att_norm to the next GPU - // since the size of att_norm is minimum across - // all tensors - mapping[att_norm].push_back( - machine_views[i / num_transformer_layers_per_stage]); - } - weights_layers.emplace("layers_" + std::to_string(i) + "_input_layernorm_weight", attention_norm); @@ -162,7 +141,7 @@ void FALCON::create_falcon_model(FFModel &ff, // Compile the model std::cout << "------start compile ----------" << std::endl; - im.compile_model_and_allocate_buffer(&ff, mapping); + im.compile_model_and_allocate_buffer(&ff); FileDataLoader fileloader("", weight_file_path, falcon_config.n_heads, diff --git a/inference/models/llama.cc b/inference/models/llama.cc index 1e61f43a98..e54ec13147 100644 --- a/inference/models/llama.cc +++ b/inference/models/llama.cc @@ -23,7 +23,6 @@ void LLAMA::create_llama_model(FFModel &ff, InferenceManager &im, std::string const &model_config_file_path, std::string const &weight_file_path, - int num_pipeline_stages, InferenceMode mode, bool use_full_precision) { // do not apply cpu offload in beam search model. @@ -62,7 +61,6 @@ void LLAMA::create_llama_model(FFModel &ff, } assert(machine_views.size() == num_devices); - std::unordered_map> mapping; std::unordered_map weights_layers; Tensor input; @@ -71,10 +69,6 @@ void LLAMA::create_llama_model(FFModel &ff, int const token_dims[] = {BatchConfig::MAX_NUM_TOKENS, 1}; input = ff.create_tensor<2>(token_dims, DT_INT32); } - for (int i = 0; i < ff.config.data_parallelism_degree; i++) { - mapping[input].push_back( - machine_views[i * num_devices_per_data_parallelism_line]); - } Initializer *embed_init = new UniformInitializer(std::rand(), 0, 0); @@ -101,39 +95,14 @@ void LLAMA::create_llama_model(FFModel &ff, Layer *embedding = ff.layers.back(); weights_layers.emplace("tok_embeddings_weight", embedding); - // int num_transformer_layers = llama_config.n_layers; - // int num_transformer_layers_per_stage = - // (num_transformer_layers + num_pipeline_stages - 1) / - // num_pipeline_stages; - for (int i = 0; i < num_transformer_layers; i++) { + // set transformer layer id + ff.set_transformer_layer_id(i); // step 1: attention std::vector axes = {2}; Tensor att_norm = ff.rms_norm(token, llama_config.norm_eps, llama_config.dim); Layer *attention_norm = ff.layers.back(); - - // if (i % num_transformer_layers_per_stage == 0) { - // // Map att_norm to the next GPU - // // since the size of att_norm is minimum across - // // all tensors - // mapping[att_norm].push_back( - // machine_views[i / num_transformer_layers_per_stage]); - // } - for (int dp_index = 0; dp_index < ff.config.data_parallelism_degree; - dp_index++) { - int pp_block_idx = i / num_layers_per_pp_block; - int first_device_idx = dp_index * num_devices_per_data_parallelism_line + - ff.config.tensor_parallelism_degree * pp_block_idx; - // std::cout << "assigning layer " << i << " to devices " << - // first_device_idx - // << "-" - // << first_device_idx + ff.config.tensor_parallelism_degree - 1 - // << std::endl; - assert(first_device_idx < num_devices); - mapping[att_norm].push_back(machine_views[first_device_idx]); - } - weights_layers.emplace("layers_" + std::to_string(i) + "_attention_norm_weight", attention_norm); @@ -246,7 +215,7 @@ void LLAMA::create_llama_model(FFModel &ff, // Compile the model std::cout << "------start compile ----------" << std::endl; - im.compile_model_and_allocate_buffer(&ff, mapping); + im.compile_model_and_allocate_buffer(&ff); FileDataLoader fileloader("", weight_file_path, llama_config.n_heads, diff --git a/inference/models/llama.h b/inference/models/llama.h index 11fc354a2c..ab9bd4c7f3 100644 --- a/inference/models/llama.h +++ b/inference/models/llama.h @@ -106,7 +106,6 @@ class LLAMA { InferenceManager &im, std::string const &model_config_file_path, std::string const &weight_file_path, - int num_pipeline_stages, InferenceMode mode, bool use_full_precision = false); }; diff --git a/inference/models/opt.cc b/inference/models/opt.cc index 499eb92642..503be39672 100644 --- a/inference/models/opt.cc +++ b/inference/models/opt.cc @@ -23,7 +23,6 @@ void OPT::create_opt_model(FFModel &ff, InferenceManager &im, std::string const &model_config_file_path, std::string const &weight_file_path, - int num_pipeline_stages, InferenceMode mode, bool use_full_precision) { Config opt_config(model_config_file_path); @@ -47,21 +46,6 @@ void OPT::create_opt_model(FFModel &ff, // << num_devices_per_data_parallelism_line << std::endl; // std::cout << "num layers: " << opt_config.num_hidden_layers << std::endl; - //------------------------------compute machine views ------------------ - // single device - std::vector machine_views; - for (int i = 0; i < num_devices; i++) { - MachineView view; - view.device_type = MachineView::GPU; - view.ndims = 1; - view.dim[0] = 1; - view.stride[0] = 0; - view.start_device_id = i; - machine_views.push_back(view); - } - assert(machine_views.size() == num_devices); - - std::unordered_map> mapping; std::unordered_map weights_layers; //------------------------------ build the model -------------------------- @@ -72,12 +56,6 @@ void OPT::create_opt_model(FFModel &ff, input = ff.create_tensor<2>(token_dims, DT_INT32); position_input = ff.create_tensor<2>(token_dims, DT_INT32); } - for (int i = 0; i < ff.config.data_parallelism_degree; i++) { - mapping[input].push_back( - machine_views[i * num_devices_per_data_parallelism_line]); - mapping[position_input].push_back( - machine_views[i * num_devices_per_data_parallelism_line]); - } Initializer *embed_init = new UniformInitializer(std::rand(), 0, 0); std::vector axes = {0}; @@ -127,9 +105,10 @@ void OPT::create_opt_model(FFModel &ff, Tensor residual = ff.add(token, positional_embedding); - int num_transformer_layers_per_stage = - (32 + num_pipeline_stages - 1) / num_pipeline_stages; for (int i = 0; i < opt_config.num_hidden_layers; i++) { + // set transformer layer id + ff.set_transformer_layer_id(i); + // 125m, 1.7B, ..., 175B applies layer norm BEFORE attention, // 350m applies layer norm AFTER attention // https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#LL324C1-L325C1 @@ -142,24 +121,6 @@ void OPT::create_opt_model(FFModel &ff, "_attention_layer_norm_weight", self_attn_layer_norm); - for (int dp_index = 0; dp_index < ff.config.data_parallelism_degree; - dp_index++) { - int pp_block_idx = i / num_layers_per_pp_block; - int first_device_idx = dp_index * num_devices_per_data_parallelism_line + - ff.config.tensor_parallelism_degree * pp_block_idx; - // std::cout << "assigning layer " << i << " to devices " << - // first_device_idx - // << "-" - // << first_device_idx + ff.config.tensor_parallelism_degree - 1 - // << std::endl; - assert(first_device_idx < num_devices); - mapping[hidden_states].push_back(machine_views[first_device_idx]); - } - // if (i % num_transformer_layers_per_stage == 0) { - // mapping[hidden_states].push_back( - // machine_views[i / num_transformer_layers_per_stage]); - // } - Tensor mha; switch (mode) { case BEAM_SEARCH_MODE: { @@ -279,7 +240,7 @@ void OPT::create_opt_model(FFModel &ff, //------------------- compile the model -------------------------------- std::cout << "------start compile ----------" << std::endl; - im.compile_model_and_allocate_buffer(&ff, mapping); + im.compile_model_and_allocate_buffer(&ff); FileDataLoader fileloader("", weight_file_path, opt_config.num_attention_heads, diff --git a/inference/models/opt.h b/inference/models/opt.h index 77d9aae962..d5fa845cd5 100644 --- a/inference/models/opt.h +++ b/inference/models/opt.h @@ -108,7 +108,6 @@ class OPT { InferenceManager &im, std::string const &model_config_file_path, std::string const &weight_file_path, - int num_pipeline_stages, InferenceMode mode, bool use_full_precision = false); }; diff --git a/inference/spec_infer/spec_infer.cc b/inference/spec_infer/spec_infer.cc index e5a6c8d5e6..9cdcb454a2 100644 --- a/inference/spec_infer/spec_infer.cc +++ b/inference/spec_infer/spec_infer.cc @@ -155,7 +155,7 @@ void FlexFlow::top_level_task(Task const *task, bool verbose = false; size_t num_devices = ffconfig.workersPerNode * ffconfig.numNodes; int data_parallelism_degree = 1, tensor_parallelism_degree = 1, - pipeline_parallelism_degree = -1; + pipeline_parallelism_degree = 1; InputArgs const &command_args = HighLevelRuntime::get_input_args(); char **argv = command_args.argv; @@ -171,10 +171,10 @@ void FlexFlow::top_level_task(Task const *task, pipeline_parallelism_degree); ffconfig.data_parallelism_degree = data_parallelism_degree; ffconfig.tensor_parallelism_degree = tensor_parallelism_degree; - ffconfig.pipeline_parallelism_degree = - pipeline_parallelism_degree == -1 - ? num_devices / (tensor_parallelism_degree * data_parallelism_degree) - : pipeline_parallelism_degree; + ffconfig.pipeline_parallelism_degree = pipeline_parallelism_degree; + assert(data_parallelism_degree * tensor_parallelism_degree * + pipeline_parallelism_degree == + ffconfig.numNodes * ffconfig.workersPerNode); if (file_paths.ssm_weight_file_paths.size() == 0) { assert(false && @@ -212,8 +212,6 @@ void FlexFlow::top_level_task(Task const *task, im, file_paths.llm_config_file_path, file_paths.llm_weight_file_path, - ffconfig.workersPerNode * ffconfig.numNodes / - tensor_parallelism_degree, TREE_VERIFY_MODE, use_full_precision); } else if (model_types.llm_model_type == ModelType::OPT) { @@ -221,8 +219,6 @@ void FlexFlow::top_level_task(Task const *task, im, file_paths.llm_config_file_path, file_paths.llm_weight_file_path, - ffconfig.workersPerNode * ffconfig.numNodes / - tensor_parallelism_degree, TREE_VERIFY_MODE, use_full_precision); } else { @@ -233,8 +229,11 @@ void FlexFlow::top_level_task(Task const *task, int num_ssms = model_types.ssm_model_types.size(); std::vector ssm_model_ids; std::vector ssm_models; + FFConfig bm_config = ffconfig; + bm_config.data_parallelism_degree = bm_config.tensor_parallelism_degree = + bm_config.pipeline_parallelism_degree = 1; for (int ssm_id = 0; ssm_id < num_ssms; ssm_id++) { - FFModel beam_model(ffconfig); + FFModel beam_model(bm_config); ssm_models.push_back(beam_model); } @@ -245,7 +244,6 @@ void FlexFlow::top_level_task(Task const *task, im, file_paths.ssm_config_file_paths[ssm_id], file_paths.ssm_weight_file_paths[ssm_id], - 1, BEAM_SEARCH_MODE, use_full_precision); } else if (model_types.ssm_model_types[ssm_id] == ModelType::OPT) { @@ -253,7 +251,6 @@ void FlexFlow::top_level_task(Task const *task, im, file_paths.ssm_config_file_paths[ssm_id], file_paths.ssm_weight_file_paths[ssm_id], - 1, BEAM_SEARCH_MODE, use_full_precision); } else { @@ -352,4 +349,4 @@ void FlexFlow::top_level_task(Task const *task, std::cout << "----------inference finished--------------" << std::endl; } -void FlexFlow::register_custom_tasks() {} \ No newline at end of file +void FlexFlow::register_custom_tasks() {} diff --git a/src/ops/arg_topk.cc b/src/ops/arg_topk.cc index eedd89bd5f..a604c016d2 100644 --- a/src/ops/arg_topk.cc +++ b/src/ops/arg_topk.cc @@ -88,7 +88,8 @@ Op *ArgTopK::create_operator_from_layer( int k = value; layer->get_int_property("sorted", value); bool sorted = (bool)value; - return new ArgTopK(model, inputs[0], k, sorted, layer->name); + return new ArgTopK( + model, layer->layer_guid, inputs[0], k, sorted, layer->name); } ArgTopKParams ArgTopK::get_params() const { @@ -108,6 +109,7 @@ bool operator==(ArgTopKParams const &lhs, ArgTopKParams const &rhs) { } ArgTopK::ArgTopK(FFModel &model, + LayerID const &_layer_guid, const ParallelTensor _input, int _k, bool _sorted, @@ -121,6 +123,8 @@ ArgTopK::ArgTopK(FFModel &model, 1 /*outputs*/, _input), k(_k), sorted(_sorted) { + // overwrite layer_guid + layer_guid = _layer_guid; int numdim = inputs[0]->num_dims; ParallelDim dims[MAX_TENSOR_DIM]; for (int i = 0; i < numdim; i++) { @@ -136,15 +140,16 @@ ArgTopK::ArgTopK(FFModel &model, } ArgTopK::ArgTopK(FFModel &model, + LayerID const &layer_guid, ArgTopK const &other, const ParallelTensor input) - : ArgTopK(model, input, other.k, other.sorted, other.name) {} + : ArgTopK(model, layer_guid, input, other.k, other.sorted, other.name) {} ArgTopK::ArgTopK(FFModel &model, ArgTopKParams const ¶ms, const ParallelTensor input, char const *name) - : ArgTopK(model, input, params.k, params.sorted, name) {} + : ArgTopK(model, params.layer_guid, input, params.k, params.sorted, name) {} void ArgTopK::init_inference(FFModel const &ff, std::vector const &batch_inputs, @@ -260,7 +265,7 @@ FutureMap ArgTopK::inference(FFModel const &ff, << std::endl; */ IndexLauncher launcher(ARG_TOPK_INF_TASK_ID, parallel_is, - TaskArgument(NULL, 0), + TaskArgument(&bc, sizeof(BatchConfig)), argmap, Predicate::TRUE_PRED, false /*must*/, @@ -295,6 +300,7 @@ InferenceResult assert(regions.size() == 2); assert(task->regions.size() == 2); // const ArgTopK* topk = (const ArgTopK*) task->args; + BatchConfig const *bc = (BatchConfig *)task->args; ArgTopKMeta const *m = *((ArgTopKMeta **)task->local_args); GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( @@ -302,10 +308,11 @@ InferenceResult GenericTensorAccessorW indices = helperGetGenericTensorAccessorWO( DT_INT32, regions[1], task->regions[1], FID_DATA, ctx, runtime); - ArgTopK::forward_kernel_wrapper(m, input, indices); + int batch_size = bc->num_active_tokens(); + ArgTopK::forward_kernel_wrapper(m, input, indices, batch_size); int length = input.domain.hi()[0] - input.domain.lo()[0] + 1; - int batch_size = input.domain.get_volume() / length; + batch_size = input.domain.get_volume() / length; InferenceResult ir; download_tensor( @@ -319,6 +326,8 @@ void ArgTopK::backward(FFModel const &ff) { } void ArgTopK::serialize(Legion::Serializer &sez) const { + sez.serialize(this->layer_guid.id); + sez.serialize(this->layer_guid.transformer_layer_id); sez.serialize(this->k); sez.serialize(this->sorted); } @@ -328,11 +337,16 @@ Node ArgTopK::deserialize(FFModel &ff, ParallelTensor inputs[], int num_inputs) { assert(num_inputs == 1); + size_t id, transformer_layer_id; + dez.deserialize(id); + dez.deserialize(transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id); int k; bool sorted; dez.deserialize(k); dez.deserialize(sorted); ArgTopKParams params; + params.layer_guid = layer_guid; params.k = k; params.sorted = sorted; return ff.get_or_create_node(inputs[0], params); @@ -357,6 +371,7 @@ namespace std { size_t hash::operator()( FlexFlow::ArgTopKParams const ¶ms) const { size_t key = 0; + hash_combine(key, params.layer_guid.id); hash_combine(key, params.k); hash_combine(key, params.sorted); return key; diff --git a/src/ops/arg_topk.cpp b/src/ops/arg_topk.cpp index d055e09def..4937166b66 100644 --- a/src/ops/arg_topk.cpp +++ b/src/ops/arg_topk.cpp @@ -411,7 +411,8 @@ void ArgTopK::forward_kernel(ArgTopKMeta const *m, void ArgTopK::forward_kernel_wrapper(ArgTopKMeta const *m, GenericTensorAccessorR const &input, // float *output_ptr, - GenericTensorAccessorW const &indices) { + GenericTensorAccessorW const &indices, + int batch_size) { hipStream_t stream; checkCUDA(get_legion_stream(&stream)); // Domain in1_domain = runtime->get_index_space_domain( @@ -442,8 +443,8 @@ void ArgTopK::forward_kernel_wrapper(ArgTopKMeta const *m, int length = input.domain.hi()[0] - input.domain.lo()[0] + 1; int k = indices.domain.hi()[0] - indices.domain.lo()[0] + 1; /*TODO: This prints to 5*/ - size_t batch_size = input.domain.get_volume() / length; - assert(indices.domain.get_volume() / k == batch_size); + // size_t batch_size = input.domain.get_volume() / length; + // assert(indices.domain.get_volume() / k == batch_size); hipEvent_t t_start, t_end; if (m->profiling) { diff --git a/src/ops/arg_topk.cu b/src/ops/arg_topk.cu index 9583af525e..575e0183b4 100644 --- a/src/ops/arg_topk.cu +++ b/src/ops/arg_topk.cu @@ -406,7 +406,8 @@ void ArgTopK::forward_kernel(ArgTopKMeta const *m, void ArgTopK::forward_kernel_wrapper(ArgTopKMeta const *m, GenericTensorAccessorR const &input, // float *output_ptr, - GenericTensorAccessorW const &indices) { + GenericTensorAccessorW const &indices, + int batch_size) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); @@ -438,9 +439,8 @@ void ArgTopK::forward_kernel_wrapper(ArgTopKMeta const *m, int length = input.domain.hi()[0] - input.domain.lo()[0] + 1; int k = indices.domain.hi()[0] - indices.domain.lo()[0] + 1; /*TODO: This prints to 5*/ - size_t batch_size = input.domain.get_volume() / length; - assert(indices.domain.get_volume() / k == batch_size); - + // batch_size = input.domain.get_volume() / length; + // assert(indices.domain.get_volume() / k == batch_size); cudaEvent_t t_start, t_end; if (m->profiling) { cudaEventCreate(&t_start); diff --git a/src/ops/beam_topk.cc b/src/ops/beam_topk.cc index d67c84a9df..db507c1729 100644 --- a/src/ops/beam_topk.cc +++ b/src/ops/beam_topk.cc @@ -429,6 +429,7 @@ void BeamTopK::backward(FFModel const &ff) { void BeamTopK::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); + sez.serialize(this->layer_guid.transformer_layer_id); sez.serialize(this->sorted); sez.serialize(this->max_beam_width); } @@ -439,10 +440,11 @@ Node BeamTopK::deserialize(FFModel &ff, int num_inputs) { assert(num_inputs == 1); bool sorted; - size_t id; + size_t id, transformer_layer_id; int max_beam_width; dez.deserialize(id); - LayerID layer_guid(id); + dez.deserialize(transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id); dez.deserialize(sorted); dez.deserialize(max_beam_width); BeamTopKParams params; diff --git a/src/ops/conv_2d.cc b/src/ops/conv_2d.cc index 786c3427e9..ce7b6ebc01 100644 --- a/src/ops/conv_2d.cc +++ b/src/ops/conv_2d.cc @@ -1012,6 +1012,7 @@ bool Conv2D::estimate_sync_cost(Simulator *sim, void Conv2D::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); + sez.serialize(this->layer_guid.transformer_layer_id); sez.serialize(this->out_channels); sez.serialize(this->kernel_h); sez.serialize(this->kernel_w); @@ -1036,9 +1037,10 @@ Node Conv2D::deserialize(FFModel &ff, padding_w, groups; bool use_bias; ActiMode activation; - size_t id; + size_t id, transformer_layer_id; dez.deserialize(id); - LayerID layer_guid(id); + dez.deserialize(transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id); dez.deserialize(out_channels); dez.deserialize(kernel_h); dez.deserialize(kernel_w); diff --git a/src/ops/element_binary.cc b/src/ops/element_binary.cc index cf90919e6b..7562a727d7 100644 --- a/src/ops/element_binary.cc +++ b/src/ops/element_binary.cc @@ -97,8 +97,13 @@ Op *ElementBinary::create_operator_from_layer( long long value; layer->get_int_property("inplace_a", value); bool inplace_a = (bool)value; - return new ElementBinary( - model, layer->op_type, inputs[0], inputs[1], inplace_a, layer->name); + return new ElementBinary(model, + layer->layer_guid, + layer->op_type, + inputs[0], + inputs[1], + inplace_a, + layer->name); } Tensor FFModel::add(const Tensor in1, @@ -166,10 +171,12 @@ bool ElementBinaryParams::is_valid( bool operator==(ElementBinaryParams const &lhs, ElementBinaryParams const &rhs) { - return lhs.type == rhs.type; + return lhs.type == rhs.type && lhs.layer_guid == rhs.layer_guid && + lhs.inplace_a == rhs.inplace_a; } ElementBinary::ElementBinary(FFModel &model, + LayerID const &_layer_guid, OperatorType _op_type, const ParallelTensor in1, const ParallelTensor in2, @@ -185,6 +192,8 @@ ElementBinary::ElementBinary(FFModel &model, in1, in2), inplace_a(_inplace_a) { + // overwrite layer_guid + layer_guid = _layer_guid; numOutputs = 1; numWeights = 0; assert(in1->data_type == in2->data_type); @@ -217,10 +226,14 @@ ElementBinary::ElementBinary( FFModel &model, ElementBinaryParams const ¶ms, std::pair const &inputs, - char const *name, - bool inplace_a) - : ElementBinary( - model, params.type, inputs.first, inputs.second, inplace_a, name) {} + char const *name) + : ElementBinary(model, + params.layer_guid, + params.type, + inputs.first, + inputs.second, + params.inplace_a, + name) {} void ElementBinary::map_output_tensors(FFModel &ff) { if (has_inplace_output()) { @@ -975,9 +988,41 @@ bool ElementBinary::measure_operator_cost(Simulator *sim, return true; } +void ElementBinary::serialize(Legion::Serializer &sez) const { + sez.serialize(this->layer_guid.id); + sez.serialize(this->layer_guid.transformer_layer_id); + sez.serialize(this->op_type); + sez.serialize(this->inplace_a); +} + +using PCG::Node; +/*static*/ +Node ElementBinary::deserialize(FFModel &ff, + Legion::Deserializer &dez, + ParallelTensor inputs[], + int num_inputs) { + assert(num_inputs == 2); + OperatorType op_type; + size_t id, transformer_layer_id; + bool inplace_a; + dez.deserialize(id); + dez.deserialize(transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id); + dez.deserialize(op_type); + dez.deserialize(inplace_a); + + ElementBinaryParams params; + params.layer_guid = layer_guid; + params.type = op_type; + params.inplace_a = inplace_a; + return ff.get_or_create_node({inputs[0], inputs[1]}, params); +} + ElementBinaryParams ElementBinary::get_params() const { ElementBinaryParams params; + params.layer_guid = this->layer_guid; params.type = this->op_type; + params.inplace_a = this->inplace_a; return params; } @@ -987,7 +1032,9 @@ namespace std { size_t hash::operator()( FlexFlow::ElementBinaryParams const ¶ms) const { size_t key = 0; + hash_combine(key, params.layer_guid.id); hash_combine(key, params.type); + hash_combine(key, params.inplace_a); return key; } }; // namespace std diff --git a/src/ops/element_unary.cc b/src/ops/element_unary.cc index f0713dd0a1..69533db53d 100644 --- a/src/ops/element_unary.cc +++ b/src/ops/element_unary.cc @@ -672,6 +672,7 @@ void ElementUnary::serialize(Legion::Serializer &sez) const { sez.serialize(this->inplace); sez.serialize(scalar); sez.serialize(this->layer_guid.id); + sez.serialize(this->layer_guid.transformer_layer_id); } bool ElementUnary::measure_operator_cost(Simulator *sim, @@ -782,9 +783,10 @@ Node ElementUnary::deserialize(FFModel &ff, dez.deserialize(op_type); dez.deserialize(inplace); dez.deserialize(scalar); - size_t id; + size_t id, transformer_layer_id; dez.deserialize(id); - LayerID layer_guid(id); + dez.deserialize(transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id); ElementUnaryParams params; params.op_type = op_type; diff --git a/src/ops/experts.cc b/src/ops/experts.cc index 77cd748f9c..06e007abef 100644 --- a/src/ops/experts.cc +++ b/src/ops/experts.cc @@ -396,6 +396,7 @@ Experts::Experts(FFModel &model, void Experts::serialize(Legion::Serializer &sez) const { ExpertsParams params = get_params(); sez.serialize(params.layer_guid.id); + sez.serialize(params.layer_guid.transformer_layer_id); sez.serialize(params.num_experts); sez.serialize(params.experts_start_idx); sez.serialize(params.experts_output_dim_size); @@ -416,9 +417,10 @@ Node Experts::deserialize(FFModel &ff, float alpha; ActiMode activation; bool use_bias; - size_t id; + size_t id, transformer_layer_id; dez.deserialize(id); - LayerID layer_guid(id); + dez.deserialize(transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id); dez.deserialize(num_experts); dez.deserialize(experts_start_idx); dez.deserialize(experts_output_dim_size); diff --git a/src/ops/fused.cc b/src/ops/fused.cc index 3dc442708f..cf01f5bd1e 100644 --- a/src/ops/fused.cc +++ b/src/ops/fused.cc @@ -100,6 +100,7 @@ FusedOp::FusedOp(FFModel &model, Op *op) op_num_outputs[0] = op->numOutputs; op_op_type[0] = op->op_type; operators[0] = op; + layer_guid = op->layer_guid; // for (int i = 0; i < numInputs; i++) { // op_input_source[i] = SOURCE_INPUT; // op_input_idx[i] = i; @@ -127,9 +128,9 @@ bool FusedOp::add_operator(FFModel &model, Op *op) { // assert(model.config.find_parallel_config(my_domain.get_dim(), name, // my_config)); assert(model.config.find_parallel_config(op_domain.get_dim(), // op->name, op_config)); - // Cannot fuse parallel operators since they have different paralel_is - // in forward and backward - assert(!op->is_parallel_op()); + // Cannot fuse parallel operators (except allreduce) since they have different + // paralel_is in forward and backward + assert(!op->is_parallel_op() || op->op_type == OP_ALLREDUCE); // Currently don't consider nested fusion assert(op->op_type != OP_FUSED); MachineView my_view = outputs[0]->machine_view; @@ -149,12 +150,14 @@ bool FusedOp::add_operator(FFModel &model, Op *op) { (weight_offset + op->numWeights > MAX_NUM_FUSED_TENSORS) || (output_offset + op->numOutputs > MAX_NUM_FUSED_TENSORS)) { fprintf(stderr, "Cannot fuse. Consider increase MAX_NUM_FUSED_TENSORS\n"); + assert(false); return false; } if (numOperators + 1 > MAX_NUM_FUSED_OPERATORS) { fprintf( stderr, "Reach to the fusion limit. Consider increase MAX_NUM_FUSED_OPERATORS"); + assert(false); return false; } // Set inputs @@ -331,6 +334,92 @@ void FusedOp::init(FFModel const &ff) { } } +void FusedOp::init_inference(FFModel const &ff, + std::vector const &batch_inputs, + std::vector const &batch_outputs, + MachineView const *mv) { + assert(check_output_input_weight_same_parallel_is()); + parallel_is = batch_outputs[0]->parallel_is; + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + // Call init methods in individual operators + Domain domain = runtime->get_index_space_domain(ctx, parallel_is); + int ioff = 0, ooff = 0; + for (int op = 0; op < numOperators; op++) { + // prepare batch_inputs, batch_outputs for operators[i] + std::vector my_batch_inputs; + std::vector my_batch_outputs; + for (int i = 0; i < op_num_inputs[op]; i++) { + int my_off = op_input_idx[i + ioff]; + if (op_input_source[i + ioff] == SOURCE_INPUT) { + my_batch_inputs.push_back(batch_inputs[my_off]); + } else if (op_input_source[i + ioff] == SOURCE_OUTPUT) { + my_batch_inputs.push_back(batch_outputs[my_off]); + } else { + assert(false); + } + } + for (int i = 0; i < op_num_outputs[op]; i++) { + assert(op_output_source[i + ooff] == SOURCE_OUTPUT); + my_batch_outputs.push_back(batch_outputs[i + ooff]); + } + ioff += op_num_inputs[op]; + ooff += op_num_outputs[op]; + operators[op]->init_inference(ff, my_batch_inputs, my_batch_outputs, mv); + for (size_t j = 0; j < domain.get_volume(); j++) { + fused_meta[j].meta[op] = + operators[op]->inference_meta[my_batch_outputs[0]][j]; + } + } + for (size_t j = 0; j < domain.get_volume(); j++) { + fused_meta[j].numOperators = numOperators; + } + switch (domain.get_dim()) { +#define DIMFUNC(DIM) \ + case DIM: { \ + Rect rect = domain; \ + int idx = 0; \ + for (PointInRectIterator it(rect); it(); it++) { \ + argmap.set_point(*it, \ + TaskArgument(&fused_meta[idx++], sizeof(FusedOpMeta))); \ + } \ + break; \ + } + LEGION_FOREACH_N(DIMFUNC) +#undef DIMFUNC + default: + assert(false); + } + MachineView const *view = mv ? mv : &batch_outputs[0]->machine_view; + size_t machine_view_hash = view->hash(); + IndexLauncher launcher(FUSEDOP_INIT_TASK_ID, + parallel_is, + TaskArgument(this, sizeof(FusedOp)), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + FutureMap fm = runtime->execute_index_space(ctx, launcher); + fm.wait_all_results(); + switch (domain.get_dim()) { +#define DIMFUNC(DIM) \ + case DIM: { \ + Rect rect = domain; \ + int idx = 0; \ + for (PointInRectIterator it(rect); it(); it++) { \ + inference_meta[batch_outputs[0]][idx++] = fm.get_result(*it); \ + } \ + break; \ + } + LEGION_FOREACH_N(DIMFUNC) +#undef DIMFUNC + default: + assert(false); + } +} + void FusedOp::forward(FFModel const &ff) { // Set iter_config iter_config = ff.iter_config; @@ -380,6 +469,66 @@ void FusedOp::forward(FFModel const &ff) { runtime->execute_index_space(ctx, launcher); } +FutureMap FusedOp::inference(FFModel const &ff, + BatchConfig const &bc, + std::vector const &batch_inputs, + std::vector const &batch_outputs, + MachineView const *mv) { + // Set iter_config + iter_config = ff.iter_config; + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + set_argumentmap_for_inference(ff, argmap, batch_outputs[0]); + MachineView const *view = mv ? mv : &batch_outputs[0]->machine_view; + size_t machine_view_hash = view->hash(); + // bc is one of BatchConfig, TreeVerifyBatchConfig, and BeamSearchBatchConfig + // so we transfer the maximum of them + size_t batch_config_size = + std::max(sizeof(TreeVerifyBatchConfig), sizeof(BeamSearchBatchConfig)); + IndexLauncher launcher(FUSEDOP_INF_TASK_ID, + parallel_is, + TaskArgument(&bc, batch_config_size), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + int offset = 0; + for (int i = 0; i < numInputs; i++) { + assert(inputs[i]->part != LogicalPartition::NO_PART); + assert(inputs[i]->region != LogicalRegion::NO_REGION); + launcher.add_region_requirement(RegionRequirement(batch_inputs[i]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + batch_inputs[i]->region)); + launcher.add_field(offset + i, FID_DATA); + } + offset += numInputs; + for (int i = 0; i < numWeights; i++) { + assert(weights[i]->region != LogicalRegion::NO_REGION); + launcher.add_region_requirement(RegionRequirement(weights[i]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + weights[i]->region)); + launcher.add_field(offset + i, FID_DATA); + } + offset += numWeights; + for (int i = 0; i < numOutputs; i++) { + assert(outputs[i]->region != LogicalRegion::NO_REGION); + launcher.add_region_requirement( + RegionRequirement(batch_outputs[i]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[i]->region)); + launcher.add_field(offset + i, FID_DATA); + } + return runtime->execute_index_space(ctx, launcher); +} + void FusedOp::backward(FFModel const &ff) { // Set iter_config iter_config = ff.iter_config; diff --git a/src/ops/fused.cpp b/src/ops/fused.cpp index 712ed143b1..c717881e66 100644 --- a/src/ops/fused.cpp +++ b/src/ops/fused.cpp @@ -14,20 +14,29 @@ */ #include "flexflow/ops/fused.h" +#include "flexflow/accessor.h" #include "flexflow/model.h" #include "flexflow/ops/batch_norm.h" #include "flexflow/ops/element_unary.h" +#include "flexflow/ops/embedding.h" +#include "flexflow/ops/inc_multihead_self_attention.h" #include "flexflow/ops/kernels/batch_matmul_kernels.h" #include "flexflow/ops/kernels/concat_kernels.h" #include "flexflow/ops/kernels/conv_2d_kernels.h" #include "flexflow/ops/kernels/dropout_kernels.h" #include "flexflow/ops/kernels/element_binary_kernels.h" +#include "flexflow/ops/kernels/embedding_kernels.h" #include "flexflow/ops/kernels/flat_kernels.h" #include "flexflow/ops/kernels/linear_kernels.h" #include "flexflow/ops/kernels/pool_2d_kernels.h" #include "flexflow/ops/kernels/reshape_kernels.h" +#include "flexflow/ops/kernels/rms_norm_kernels.h" #include "flexflow/ops/kernels/transpose_kernels.h" +#include "flexflow/ops/layer_norm.h" #include "flexflow/ops/linear.h" +#include "flexflow/ops/spec_inc_multihead_self_attention.h" +#include "flexflow/ops/tree_inc_multihead_self_attention.h" +#include "flexflow/parallel_ops/kernels/allreduce_kernels.h" #include "flexflow/utils/hip_helper.h" #include @@ -373,6 +382,414 @@ __host__ void FusedOp::forward_task(Task const *task, // "[Fused:forward:output]"); } +/* + regions[...](I): inputs + regions[...](I): weights + regions[...](I): outputs +*/ +__host__ void + FusedOp::inference_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + // const FusedOp* fused = (FusedOp*) task->args; + FusedOpMeta const *metas = *((FusedOpMeta **)task->local_args); + FusedOp const *fused = metas->fused_op; + BatchConfig const *bc = (BatchConfig *)task->args; + assert(metas->numOperators == fused->numOperators); + assert(regions.size() == task->regions.size()); + assert((int)regions.size() == + fused->numInputs + fused->numWeights + fused->numOutputs); + GenericTensorAccessorR input_accessor[MAX_NUM_INPUTS]; + GenericTensorAccessorR weight_accessor[MAX_NUM_WEIGHTS]; + GenericTensorAccessorW output_accessor[MAX_NUM_OUTPUTS]; + assert(fused->numInputs <= MAX_NUM_INPUTS); + for (int i = 0; i < fused->numInputs; i++) { + input_accessor[i] = + helperGetGenericTensorAccessorRO(fused->input_data_types[i], + regions[i], + task->regions[i], + FID_DATA, + ctx, + runtime); + } + int roff = fused->numInputs; + assert(fused->numWeights <= MAX_NUM_WEIGHTS); + for (int i = 0; i < fused->numWeights; i++) { + weight_accessor[i] = + helperGetGenericTensorAccessorRO(fused->weight_data_types[i], + regions[i + roff], + task->regions[i + roff], + FID_DATA, + ctx, + runtime); + } + roff += fused->numWeights; + assert(fused->numOutputs <= MAX_NUM_OUTPUTS); + for (int i = 0; i < fused->numOutputs; i++) { + output_accessor[i] = + helperGetGenericTensorAccessorWO(fused->output_data_types[i], + regions[i + roff], + task->regions[i + roff], + FID_DATA, + ctx, + runtime); + } + // Assert that all meta share the same dnn/blas handler + int start = 0; + for (start = 0; start < fused->numOperators; start++) { + if (metas->meta[start] != NULL) { + break; + } + } + for (int op = start + 1; op < fused->numOperators; op++) { + if (metas->meta[op] != NULL) { + assert(metas->meta[start]->handle.blas == metas->meta[op]->handle.blas); + assert(metas->meta[start]->handle.dnn == metas->meta[op]->handle.dnn); + } + } + + hipStream_t stream; + if (start < fused->numOperators) { + checkCUDA(get_legion_stream(&stream)); + } + + int ioff = 0, woff = 0, ooff = 0; + for (int op = 0; op < fused->numOperators; op++) { + GenericTensorAccessorR my_input_accessor[MAX_NUM_INPUTS]; + GenericTensorAccessorR my_weight_accessor[MAX_NUM_WEIGHTS]; + GenericTensorAccessorW my_output_accessor[MAX_NUM_OUTPUTS]; + for (int i = 0; i < fused->op_num_inputs[op]; i++) { + int my_off = fused->op_input_idx[i + ioff]; + if (fused->op_input_source[i + ioff] == SOURCE_INPUT) { + my_input_accessor[i] = input_accessor[my_off]; + } else if (fused->op_input_source[i + ioff] == SOURCE_OUTPUT) { + my_input_accessor[i] = output_accessor[my_off]; + } else { + assert(false); + } + } + for (int i = 0; i < fused->op_num_weights[op]; i++) { + assert(fused->op_weight_source[i + woff] == SOURCE_WEIGHT); + my_weight_accessor[i] = weight_accessor[fused->op_weight_idx[i + woff]]; + } + for (int i = 0; i < fused->op_num_outputs[op]; i++) { + assert(fused->op_output_source[i + ooff] == SOURCE_OUTPUT); + my_output_accessor[i] = output_accessor[i + ooff]; + } + switch (fused->op_op_type[op]) { + case OP_CONCAT: { + assert(fused->op_num_weights[op] == 0); + assert(fused->op_num_outputs[op] == 1); + ConcatMeta *m = (ConcatMeta *)metas->meta[op]; + int num_inputs = fused->op_num_inputs[op]; + Kernels::Concat::forward_kernel_wrapper(m, + my_output_accessor[0], + my_input_accessor, + num_inputs, + m->legion_axis); + break; + } + case OP_BATCHNORM: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + assert(my_input_accessor[0].domain.get_dim() == 5); + assert(my_output_accessor[0].domain.get_dim() == 5); + assert(my_weight_accessor[0].domain.get_dim() == 2); + assert(my_weight_accessor[1].domain.get_dim() == 2); + BatchNormMeta *m = (BatchNormMeta *)metas->meta[op]; + BatchNorm::forward_kernel(m, + my_input_accessor[0].get_float_ptr(), + my_output_accessor[0].get_float_ptr(), + my_weight_accessor[0].get_float_ptr(), + my_weight_accessor[1].get_float_ptr()); + break; + } + case OP_LINEAR: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + Domain kernel_domain = my_weight_accessor[0].domain; + int in_dim = kernel_domain.hi()[0] - kernel_domain.lo()[0] + 1; + int out_dim = kernel_domain.hi()[1] - kernel_domain.lo()[1] + 1; + int batch_size = my_input_accessor[0].domain.get_volume() / in_dim; + assert(my_output_accessor[0].domain.get_volume() == + out_dim * batch_size); + assert(my_input_accessor[0].domain.get_volume() == in_dim * batch_size); + void const *bias_ptr = nullptr; + if (fused->op_num_weights[op] == 2) { + assert(my_weight_accessor[1].domain.get_volume() == out_dim); + bias_ptr = my_weight_accessor[1].ptr; + } else { + assert(fused->op_num_weights[op] == 1); + } + LinearMeta *m = (LinearMeta *)metas->meta[op]; + assert(m->input_type[0] == my_input_accessor[0].data_type); + assert(m->input_type[0] == my_output_accessor[0].data_type); + batch_size = bc->num_active_tokens(); + Kernels::Linear::forward_kernel_wrapper(m, + my_input_accessor[0].ptr, + my_output_accessor[0].ptr, + my_weight_accessor[0].ptr, + bias_ptr, + in_dim, + out_dim, + batch_size); + break; + } + case OP_BATCHMATMUL: { + assert(fused->op_num_inputs[op] == 2); + assert(fused->op_num_weights[op] == 0); + assert(fused->op_num_outputs[op] == 1); + Domain out_domain = my_output_accessor[0].domain; + Domain a_domain = my_input_accessor[0].domain; + Domain b_domain = my_input_accessor[1].domain; + int m = b_domain.hi()[0] - b_domain.lo()[0] + 1; + assert(m == out_domain.hi()[0] - out_domain.lo()[0] + 1); + int n = a_domain.hi()[1] - a_domain.lo()[1] + 1; + assert(n == out_domain.hi()[1] - out_domain.lo()[1] + 1); + int k = a_domain.hi()[0] - a_domain.lo()[0] + 1; + assert(k == b_domain.hi()[1] - b_domain.lo()[1] + 1); + assert(a_domain.get_dim() == b_domain.get_dim()); + assert(a_domain.get_dim() == out_domain.get_dim()); + int batch = 1; + for (int i = 2; i < a_domain.get_dim(); i++) { + int dim_size = a_domain.hi()[i] - a_domain.lo()[i] + 1; + assert(dim_size == b_domain.hi()[i] - b_domain.lo()[i] + 1); + assert(dim_size == out_domain.hi()[i] - out_domain.lo()[i] + 1); + batch *= dim_size; + } + BatchMatmulMeta *meta = (BatchMatmulMeta *)metas->meta[op]; + Kernels::BatchMatmul::forward_kernel_wrapper( + meta, + my_output_accessor[0].get_float_ptr(), + my_input_accessor[0].get_float_ptr(), + my_input_accessor[1].get_float_ptr(), + (float const *)nullptr, + m, + n, + k, + batch, + meta->a_seq_length_dim, + meta->b_seq_length_dim, + fused->iter_config.seq_length); + break; + } + case OP_EW_ADD: + case OP_EW_SUB: + case OP_EW_MUL: + case OP_EW_DIV: + case OP_EW_MAX: + case OP_EW_MIN: { + assert(fused->op_num_inputs[op] == 2); + assert(fused->op_num_weights[op] == 0); + assert(fused->op_num_outputs[op] == 1); + assert(my_input_accessor[0].domain == my_input_accessor[1].domain); + assert(my_input_accessor[0].domain == my_output_accessor[0].domain); + ElementBinaryMeta *m = (ElementBinaryMeta *)metas->meta[op]; + Kernels::ElementBinary::forward_kernel_wrapper(m, + my_input_accessor[0], + my_input_accessor[1], + my_output_accessor[0]); + break; + break; + } + case OP_EMBEDDING: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_weights[op] == 1); + assert(fused->op_num_outputs[op] == 1); + EmbeddingMeta *m = (EmbeddingMeta *)metas->meta[op]; + if (m->aggr == AGGR_MODE_NONE) { + // assert(kernel_domain.get_dim() == 2); + assert(my_input_accessor[0].domain.get_dim() + 1 == + my_output_accessor[0].domain.get_dim()); + for (size_t i = 0; i < my_input_accessor[0].domain.get_dim(); i++) { + assert(my_input_accessor[0].domain.hi()[i] == + my_output_accessor[0].domain.hi()[i + 1]); + assert(my_input_accessor[0].domain.lo()[i] == + my_output_accessor[0].domain.lo()[i + 1]); + } + assert(my_weight_accessor[0].domain.hi()[0] - + my_weight_accessor[0].domain.lo()[0] == + my_output_accessor[0].domain.hi()[0] - + my_output_accessor[0].domain.lo()[0]); + } else { + assert(my_input_accessor[0].domain.get_dim() == + my_output_accessor[0].domain.get_dim()); + for (size_t i = 1; i < my_input_accessor[0].domain.get_dim(); i++) { + assert(my_input_accessor[0].domain.hi()[i] == + my_output_accessor[0].domain.hi()[i]); + assert(my_input_accessor[0].domain.lo()[i] == + my_output_accessor[0].domain.lo()[i]); + } + assert(my_weight_accessor[0].domain.hi()[0] - + my_weight_accessor[0].domain.lo()[0] == + my_output_accessor[0].domain.hi()[0] - + my_output_accessor[0].domain.lo()[0]); + } + int in_dim, out_dim, effective_batch_size; + if (m->aggr == AGGR_MODE_NONE) { + in_dim = 1; + out_dim = my_output_accessor[0].domain.hi()[0] - + my_output_accessor[0].domain.lo()[0] + 1; + effective_batch_size = + my_output_accessor[0].domain.get_volume() / out_dim; + assert(effective_batch_size * in_dim == + my_input_accessor[0].domain.get_volume()); + } else { + assert(m->aggr == AGGR_MODE_AVG || m->aggr == AGGR_MODE_SUM); + in_dim = my_input_accessor[0].domain.hi()[0] - + my_input_accessor[0].domain.lo()[0] + 1; + out_dim = my_output_accessor[0].domain.hi()[0] - + my_output_accessor[0].domain.lo()[0] + 1; + effective_batch_size = + my_output_accessor[0].domain.get_volume() / out_dim; + assert(effective_batch_size * in_dim == + my_input_accessor[0].domain.get_volume()); + } + + assert(my_input_accessor[0].data_type == DT_INT32 || + my_input_accessor[0].data_type == DT_INT64); + Kernels::Embedding::forward_kernel_wrapper(m, + my_input_accessor[0], + my_output_accessor[0], + my_weight_accessor[0], + in_dim, + out_dim, + effective_batch_size); + break; + } + case OP_RELU: + case OP_SIGMOID: + case OP_TANH: + case OP_ELU: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_weights[op] == 0); + assert(fused->op_num_outputs[op] == 1); + assert(my_input_accessor[0].domain == my_output_accessor[0].domain); + ElementUnaryMeta *m = (ElementUnaryMeta *)metas->meta[op]; + ElementUnary::forward_kernel_wrapper( + m, + my_input_accessor[0].get_float_ptr(), + my_output_accessor[0].get_float_ptr(), + my_input_accessor[0].domain.get_volume()); + break; + } + case OP_RMS_NORM: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_weights[op] == 1); + assert(fused->op_num_outputs[op] == 1); + RMSNormMeta const *m = (RMSNormMeta *)metas->meta[op]; + Kernels::RMSNorm::forward_kernel_wrapper(m, + my_input_accessor[0], + my_weight_accessor[0], + my_output_accessor[0]); + break; + } + case OP_INC_MULTIHEAD_SELF_ATTENTION: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + IncMultiHeadSelfAttentionMeta const *m = + (IncMultiHeadSelfAttentionMeta *)metas->meta[op]; + assert(fused->op_num_weights[op] == (1 + (int)(*m->bias))); + GenericTensorAccessorR biases; + if (*m->bias) { + assert(fused->op_num_weights[op] == 2); + biases = my_weight_accessor[1]; + } + IncMultiHeadSelfAttention::inference_kernel_wrapper( + m, + bc, + task->index_point.point_data[0], + my_input_accessor[0], + my_weight_accessor[0], + my_output_accessor[0], + biases); + break; + } + case OP_TREE_INC_MULTIHEAD_SELF_ATTENTION: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + TreeIncMultiHeadSelfAttentionMeta *m = + (TreeIncMultiHeadSelfAttentionMeta *)metas->meta[op]; + TreeVerifyBatchConfig const *tree_bc = + (TreeVerifyBatchConfig *)task->args; + assert(fused->op_num_weights[op] == (1 + (int)(*m->bias))); + GenericTensorAccessorR biases; + if (*m->bias) { + assert(fused->op_num_weights[op] == 2); + biases = my_weight_accessor[1]; + } + TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( + m, + tree_bc, + task->index_point.point_data[0], + my_input_accessor[0], + my_weight_accessor[0], + my_output_accessor[0], + biases); + break; + } + case OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + SpecIncMultiHeadSelfAttentionMeta const *m = + (SpecIncMultiHeadSelfAttentionMeta *)metas->meta[op]; + BeamSearchBatchConfig const *beam_bc = + (BeamSearchBatchConfig *)task->args; + assert(fused->op_num_weights[op] == (1 + (int)(*m->bias))); + GenericTensorAccessorR biases; + if (*m->bias) { + assert(fused->op_num_weights[op] == 2); + biases = my_weight_accessor[1]; + } + SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( + m, + beam_bc, + task->index_point.point_data[0], + my_input_accessor[0], + my_weight_accessor[0], + my_output_accessor[0], + biases); + break; + } + case OP_LAYERNORM: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + LayerNormMeta const *m = (LayerNormMeta *)metas->meta[op]; + assert(fused->op_num_weights[op] == 2 * (int)(m->elementwise_affine)); + GenericTensorAccessorR gamma, beta; + if (m->elementwise_affine) { + gamma = my_weight_accessor[0]; + beta = my_weight_accessor[1]; + } + LayerNorm::forward_kernel_wrapper( + m, my_input_accessor[0], my_output_accessor[0], gamma, beta); + break; + } + case OP_ALLREDUCE: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op]; + Kernels::AllReduce::forward_kernel_wrapper( + m, my_input_accessor[0], my_output_accessor[0]); + break; + } + default: { + fprintf(stderr, + "Fusion currently does not support type = %d\n", + fused->op_op_type[op]); + assert(false && "Fusion currently does not support type"); + } + } + ioff += fused->op_num_inputs[op]; + woff += fused->op_num_weights[op]; + ooff += fused->op_num_outputs[op]; + } + // for (int i = 0; i < fused->numOutputs; i++) + // print_tensor(output_ptr[i], output_domain[i].get_volume(), + // "[Fused:forward:output]"); +} + /* regions[...](I): input regions[...](I): weight diff --git a/src/ops/fused.cu b/src/ops/fused.cu index 17b0f9616d..2f84100554 100644 --- a/src/ops/fused.cu +++ b/src/ops/fused.cu @@ -20,6 +20,7 @@ #include "flexflow/ops/embedding.h" #include "flexflow/ops/flat.h" #include "flexflow/ops/fused.h" +#include "flexflow/ops/inc_multihead_self_attention.h" #include "flexflow/ops/kernels/batch_matmul_kernels.h" #include "flexflow/ops/kernels/concat_kernels.h" #include "flexflow/ops/kernels/conv_2d_kernels.h" @@ -30,7 +31,12 @@ #include "flexflow/ops/kernels/linear_kernels.h" #include "flexflow/ops/kernels/pool_2d_kernels.h" #include "flexflow/ops/kernels/reshape_kernels.h" +#include "flexflow/ops/kernels/rms_norm_kernels.h" #include "flexflow/ops/kernels/transpose_kernels.h" +#include "flexflow/ops/layer_norm.h" +#include "flexflow/ops/spec_inc_multihead_self_attention.h" +#include "flexflow/ops/tree_inc_multihead_self_attention.h" +#include "flexflow/parallel_ops/kernels/allreduce_kernels.h" #include "flexflow/utils/cuda_helper.h" namespace FlexFlow { @@ -62,7 +68,7 @@ OpMeta *FusedOp::init_task(Task const *task, /* regions[...](I): inputs regions[...](I): weights - regions[...](I): outputs + regions[...](O): outputs */ __host__ void FusedOp::forward_task(Task const *task, std::vector const ®ions, @@ -357,7 +363,8 @@ __host__ void FusedOp::forward_task(Task const *task, my_input_accessor[0].domain.get_volume()); } - assert(my_input_accessor[0].data_type == DT_INT64); + assert(my_input_accessor[0].data_type == DT_INT32 || + my_input_accessor[0].data_type == DT_INT64); Kernels::Embedding::forward_kernel_wrapper(m, my_input_accessor[0], my_output_accessor[0], @@ -450,6 +457,436 @@ __host__ void FusedOp::forward_task(Task const *task, // "[Fused:forward:output]"); } +/* + regions[...](I): inputs + regions[...](I): weights + regions[...](O): outputs +*/ +__host__ void + FusedOp::inference_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + // const FusedOp* fused = (FusedOp*) task->args; + FusedOpMeta const *metas = *((FusedOpMeta **)task->local_args); + FusedOp const *fused = metas->fused_op; + BatchConfig const *bc = (BatchConfig *)task->args; + assert(metas->numOperators == fused->numOperators); + assert(regions.size() == task->regions.size()); + assert((int)regions.size() == + fused->numInputs + fused->numWeights + fused->numOutputs); + // Domain input_domain[MAX_NUM_INPUTS]; + // Domain weight_domain[MAX_NUM_WEIGHTS]; + // Domain output_domain[MAX_NUM_OUTPUTS]; + GenericTensorAccessorR input_accessor[MAX_NUM_INPUTS]; + GenericTensorAccessorR weight_accessor[MAX_NUM_WEIGHTS]; + GenericTensorAccessorW output_accessor[MAX_NUM_OUTPUTS]; + assert(fused->numInputs <= MAX_NUM_INPUTS); + for (int i = 0; i < fused->numInputs; i++) { + // input_domain[i] = runtime->get_index_space_domain( + // ctx, task->regions[i].region.get_index_space()); + input_accessor[i] = + helperGetGenericTensorAccessorRO(fused->input_data_types[i], + regions[i], + task->regions[i], + FID_DATA, + ctx, + runtime); + } + int roff = fused->numInputs; + assert(fused->numWeights <= MAX_NUM_WEIGHTS); + for (int i = 0; i < fused->numWeights; i++) { + // weight_domain[i] = runtime->get_index_space_domain( + // ctx, task->regions[i + roff].region.get_index_space()); + weight_accessor[i] = + helperGetGenericTensorAccessorRO(fused->weight_data_types[i], + regions[i + roff], + task->regions[i + roff], + FID_DATA, + ctx, + runtime); + } + roff += fused->numWeights; + assert(fused->numOutputs <= MAX_NUM_OUTPUTS); + for (int i = 0; i < fused->numOutputs; i++) { + // output_domain[i] = runtime->get_index_space_domain( + // ctx, task->regions[i + roff].region.get_index_space()); + output_accessor[i] = + helperGetGenericTensorAccessorWO(fused->output_data_types[i], + regions[i + roff], + task->regions[i + roff], + FID_DATA, + ctx, + runtime); + } + // Assert that all meta share the same dnn/blas handler + int start = 0; + for (start = 0; start < fused->numOperators; start++) { + if (metas->meta[start] != NULL) { + break; + } + } + for (int op = start + 1; op < fused->numOperators; op++) { + if (metas->meta[op] != NULL) { + assert(metas->meta[start]->handle.blas == metas->meta[op]->handle.blas); + assert(metas->meta[start]->handle.dnn == metas->meta[op]->handle.dnn); + } + } + + int ioff = 0, woff = 0, ooff = 0; + for (int op = 0; op < fused->numOperators; op++) { + // Domain my_id[MAX_NUM_INPUTS]; + // Domain my_wd[MAX_NUM_WEIGHTS]; + // Domain my_od[MAX_NUM_OUTPUTS]; + GenericTensorAccessorR my_input_accessor[MAX_NUM_INPUTS]; + GenericTensorAccessorR my_weight_accessor[MAX_NUM_WEIGHTS]; + GenericTensorAccessorW my_output_accessor[MAX_NUM_OUTPUTS]; + for (int i = 0; i < fused->op_num_inputs[op]; i++) { + int my_off = fused->op_input_idx[i + ioff]; + if (fused->op_input_source[i + ioff] == SOURCE_INPUT) { + // my_id[i] = input_domain[my_off]; + my_input_accessor[i] = input_accessor[my_off]; + } else if (fused->op_input_source[i + ioff] == SOURCE_OUTPUT) { + // my_id[i] = output_domain[my_off]; + my_input_accessor[i] = output_accessor[my_off]; + } else { + assert(false); + } + } + for (int i = 0; i < fused->op_num_weights[op]; i++) { + assert(fused->op_weight_source[i + woff] == SOURCE_WEIGHT); + // my_wd[i] = weight_domain[fused->op_weight_idx[i + woff]]; + // my_wp[i] = weight_ptr[fused->op_weight_idx[i + woff]]; + my_weight_accessor[i] = weight_accessor[fused->op_weight_idx[i + woff]]; + } + for (int i = 0; i < fused->op_num_outputs[op]; i++) { + assert(fused->op_output_source[i + ooff] == SOURCE_OUTPUT); + // my_od[i] = output_domain[fused->op_output_idx[i + ooff]]; + // my_op[i] = output_ptr[fused->op_output_idx[i + ooff]]; + my_output_accessor[i] = output_accessor[i + ooff]; + } + switch (fused->op_op_type[op]) { + case OP_CONCAT: { + assert(fused->op_num_weights[op] == 0); + assert(fused->op_num_outputs[op] == 1); + ConcatMeta *m = (ConcatMeta *)metas->meta[op]; + int num_inputs = fused->op_num_inputs[op]; + Kernels::Concat::forward_kernel_wrapper(m, + my_output_accessor[0], + my_input_accessor, + num_inputs, + m->legion_axis); + break; + } + case OP_BATCHNORM: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + assert(my_input_accessor[0].domain.get_dim() == 5); + assert(my_output_accessor[0].domain.get_dim() == 5); + assert(my_weight_accessor[0].domain.get_dim() == 2); + assert(my_weight_accessor[1].domain.get_dim() == 2); + BatchNormMeta *m = (BatchNormMeta *)metas->meta[op]; + BatchNorm::forward_kernel(m, + my_input_accessor[0].get_float_ptr(), + my_output_accessor[0].get_float_ptr(), + my_weight_accessor[0].get_float_ptr(), + my_weight_accessor[1].get_float_ptr()); + break; + } + case OP_LINEAR: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + Domain kernel_domain = my_weight_accessor[0].domain; + int in_dim = kernel_domain.hi()[0] - kernel_domain.lo()[0] + 1; + int out_dim = kernel_domain.hi()[1] - kernel_domain.lo()[1] + 1; + int batch_size = my_input_accessor[0].domain.get_volume() / in_dim; + assert(my_output_accessor[0].domain.get_volume() == + out_dim * batch_size); + assert(my_input_accessor[0].domain.get_volume() == in_dim * batch_size); + void const *bias_ptr = nullptr; + if (fused->op_num_weights[op] == 2) { + assert(my_weight_accessor[1].domain.get_volume() == out_dim); + bias_ptr = my_weight_accessor[1].ptr; + } else { + assert(fused->op_num_weights[op] == 1); + } + LinearMeta *m = (LinearMeta *)metas->meta[op]; + assert(m->input_type[0] == my_input_accessor[0].data_type); + assert(m->input_type[0] == my_output_accessor[0].data_type); + batch_size = bc->num_active_tokens(); + Kernels::Linear::forward_kernel_wrapper(m, + my_input_accessor[0].ptr, + my_output_accessor[0].ptr, + my_weight_accessor[0].ptr, + bias_ptr, + in_dim, + out_dim, + batch_size); + break; + } + case OP_BATCHMATMUL: { + assert(fused->op_num_inputs[op] == 2); + assert(fused->op_num_weights[op] == 0); + assert(fused->op_num_outputs[op] == 1); + Domain out_domain = my_output_accessor[0].domain; + Domain a_domain = my_input_accessor[0].domain; + Domain b_domain = my_input_accessor[1].domain; + int m = b_domain.hi()[0] - b_domain.lo()[0] + 1; + assert(m == out_domain.hi()[0] - out_domain.lo()[0] + 1); + int n = a_domain.hi()[1] - a_domain.lo()[1] + 1; + assert(n == out_domain.hi()[1] - out_domain.lo()[1] + 1); + int k = a_domain.hi()[0] - a_domain.lo()[0] + 1; + assert(k == b_domain.hi()[1] - b_domain.lo()[1] + 1); + assert(a_domain.get_dim() == b_domain.get_dim()); + assert(a_domain.get_dim() == out_domain.get_dim()); + int batch = 1; + for (int i = 2; i < a_domain.get_dim(); i++) { + int dim_size = a_domain.hi()[i] - a_domain.lo()[i] + 1; + assert(dim_size == b_domain.hi()[i] - b_domain.lo()[i] + 1); + assert(dim_size == out_domain.hi()[i] - out_domain.lo()[i] + 1); + batch *= dim_size; + } + BatchMatmulMeta *meta = (BatchMatmulMeta *)metas->meta[op]; + Kernels::BatchMatmul::forward_kernel_wrapper( + meta, + my_output_accessor[0].get_float_ptr(), + my_input_accessor[0].get_float_ptr(), + my_input_accessor[1].get_float_ptr(), + (float const *)nullptr, + m, + n, + k, + batch, + meta->a_seq_length_dim, + meta->b_seq_length_dim, + fused->iter_config.seq_length); + break; + } + case OP_EW_ADD: + case OP_EW_SUB: + case OP_EW_MUL: + case OP_EW_DIV: + case OP_EW_MAX: + case OP_EW_MIN: { + assert(fused->op_num_inputs[op] == 2); + assert(fused->op_num_weights[op] == 0); + assert(fused->op_num_outputs[op] == 1); + assert(my_input_accessor[0].domain == my_input_accessor[1].domain); + assert(my_input_accessor[0].domain == my_output_accessor[0].domain); + ElementBinaryMeta *m = (ElementBinaryMeta *)metas->meta[op]; + Kernels::ElementBinary::forward_kernel_wrapper(m, + my_input_accessor[0], + my_input_accessor[1], + my_output_accessor[0]); + break; + } + case OP_EMBEDDING: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_weights[op] == 1); + assert(fused->op_num_outputs[op] == 1); + EmbeddingMeta *m = (EmbeddingMeta *)metas->meta[op]; + if (m->aggr == AGGR_MODE_NONE) { + // assert(kernel_domain.get_dim() == 2); + assert(my_input_accessor[0].domain.get_dim() + 1 == + my_output_accessor[0].domain.get_dim()); + for (size_t i = 0; i < my_input_accessor[0].domain.get_dim(); i++) { + assert(my_input_accessor[0].domain.hi()[i] == + my_output_accessor[0].domain.hi()[i + 1]); + assert(my_input_accessor[0].domain.lo()[i] == + my_output_accessor[0].domain.lo()[i + 1]); + } + assert(my_weight_accessor[0].domain.hi()[0] - + my_weight_accessor[0].domain.lo()[0] == + my_output_accessor[0].domain.hi()[0] - + my_output_accessor[0].domain.lo()[0]); + } else { + assert(my_input_accessor[0].domain.get_dim() == + my_output_accessor[0].domain.get_dim()); + for (size_t i = 1; i < my_input_accessor[0].domain.get_dim(); i++) { + assert(my_input_accessor[0].domain.hi()[i] == + my_output_accessor[0].domain.hi()[i]); + assert(my_input_accessor[0].domain.lo()[i] == + my_output_accessor[0].domain.lo()[i]); + } + assert(my_weight_accessor[0].domain.hi()[0] - + my_weight_accessor[0].domain.lo()[0] == + my_output_accessor[0].domain.hi()[0] - + my_output_accessor[0].domain.lo()[0]); + } + int in_dim, out_dim, effective_batch_size; + if (m->aggr == AGGR_MODE_NONE) { + in_dim = 1; + out_dim = my_output_accessor[0].domain.hi()[0] - + my_output_accessor[0].domain.lo()[0] + 1; + effective_batch_size = + my_output_accessor[0].domain.get_volume() / out_dim; + assert(effective_batch_size * in_dim == + my_input_accessor[0].domain.get_volume()); + } else { + assert(m->aggr == AGGR_MODE_AVG || m->aggr == AGGR_MODE_SUM); + in_dim = my_input_accessor[0].domain.hi()[0] - + my_input_accessor[0].domain.lo()[0] + 1; + out_dim = my_output_accessor[0].domain.hi()[0] - + my_output_accessor[0].domain.lo()[0] + 1; + effective_batch_size = + my_output_accessor[0].domain.get_volume() / out_dim; + assert(effective_batch_size * in_dim == + my_input_accessor[0].domain.get_volume()); + } + + assert(my_input_accessor[0].data_type == DT_INT32 || + my_input_accessor[0].data_type == DT_INT64); + Kernels::Embedding::forward_kernel_wrapper(m, + my_input_accessor[0], + my_output_accessor[0], + my_weight_accessor[0], + in_dim, + out_dim, + effective_batch_size); + break; + } + case OP_RELU: + case OP_SIGMOID: + case OP_TANH: + case OP_ELU: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_weights[op] == 0); + assert(fused->op_num_outputs[op] == 1); + assert(my_input_accessor[0].domain == my_output_accessor[0].domain); + ElementUnaryMeta *m = (ElementUnaryMeta *)metas->meta[op]; + if (m->data_type == DT_HALF) { + ElementUnary::forward_kernel_wrapper( + m, + my_input_accessor[0].get_half_ptr(), + my_output_accessor[0].get_half_ptr(), + my_input_accessor[0].domain.get_volume()); + } else if (m->data_type == DT_FLOAT) { + ElementUnary::forward_kernel_wrapper( + m, + my_input_accessor[0].get_float_ptr(), + my_output_accessor[0].get_float_ptr(), + my_input_accessor[0].domain.get_volume()); + } else { + assert(false && "Unsupported data type in ElementUnary forward"); + } + break; + } + case OP_RMS_NORM: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_weights[op] == 1); + assert(fused->op_num_outputs[op] == 1); + RMSNormMeta const *m = (RMSNormMeta *)metas->meta[op]; + Kernels::RMSNorm::forward_kernel_wrapper(m, + my_input_accessor[0], + my_weight_accessor[0], + my_output_accessor[0]); + break; + } + case OP_INC_MULTIHEAD_SELF_ATTENTION: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + IncMultiHeadSelfAttentionMeta const *m = + (IncMultiHeadSelfAttentionMeta *)metas->meta[op]; + assert(fused->op_num_weights[op] == (1 + (int)(*m->bias))); + GenericTensorAccessorR biases; + if (*m->bias) { + assert(fused->op_num_weights[op] == 2); + biases = my_weight_accessor[1]; + } + IncMultiHeadSelfAttention::inference_kernel_wrapper( + m, + bc, + task->index_point.point_data[0], + my_input_accessor[0], + my_weight_accessor[0], + my_output_accessor[0], + biases); + break; + } + case OP_TREE_INC_MULTIHEAD_SELF_ATTENTION: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + TreeIncMultiHeadSelfAttentionMeta *m = + (TreeIncMultiHeadSelfAttentionMeta *)metas->meta[op]; + TreeVerifyBatchConfig const *tree_bc = + (TreeVerifyBatchConfig *)task->args; + assert(fused->op_num_weights[op] == (1 + (int)(*m->bias))); + GenericTensorAccessorR biases; + if (*m->bias) { + assert(fused->op_num_weights[op] == 2); + biases = my_weight_accessor[1]; + } + TreeIncMultiHeadSelfAttention::inference_kernel_wrapper( + m, + tree_bc, + task->index_point.point_data[0], + my_input_accessor[0], + my_weight_accessor[0], + my_output_accessor[0], + biases); + break; + } + case OP_SPEC_INC_MULTIHEAD_SELF_ATTENTION: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + SpecIncMultiHeadSelfAttentionMeta const *m = + (SpecIncMultiHeadSelfAttentionMeta *)metas->meta[op]; + BeamSearchBatchConfig const *beam_bc = + (BeamSearchBatchConfig *)task->args; + assert(fused->op_num_weights[op] == (1 + (int)(*m->bias))); + GenericTensorAccessorR biases; + if (*m->bias) { + assert(fused->op_num_weights[op] == 2); + biases = my_weight_accessor[1]; + } + SpecIncMultiHeadSelfAttention::inference_kernel_wrapper( + m, + beam_bc, + task->index_point.point_data[0], + my_input_accessor[0], + my_weight_accessor[0], + my_output_accessor[0], + biases); + break; + } + case OP_LAYERNORM: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + LayerNormMeta const *m = (LayerNormMeta *)metas->meta[op]; + assert(fused->op_num_weights[op] == 2 * (int)(m->elementwise_affine)); + GenericTensorAccessorR gamma, beta; + if (m->elementwise_affine) { + gamma = my_weight_accessor[0]; + beta = my_weight_accessor[1]; + } + LayerNorm::forward_kernel_wrapper( + m, my_input_accessor[0], my_output_accessor[0], gamma, beta); + break; + } + case OP_ALLREDUCE: { + assert(fused->op_num_inputs[op] == 1); + assert(fused->op_num_outputs[op] == 1); + AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op]; + Kernels::AllReduce::forward_kernel_wrapper( + m, my_input_accessor[0], my_output_accessor[0]); + break; + } + default: { + fprintf(stderr, + "Fusion currently does not support type = %d\n", + fused->op_op_type[op]); + assert(false && "Fusion currently does not support type"); + } + } + ioff += fused->op_num_inputs[op]; + woff += fused->op_num_weights[op]; + ooff += fused->op_num_outputs[op]; + } + // for (int i = 0; i < fused->numOutputs; i++) + // print_tensor(output_ptr[i], output_domain[i].get_volume(), + // "[Fused:forward:output]"); +} + /* regions[...](I): input regions[...](I): weight @@ -458,7 +895,6 @@ __host__ void FusedOp::forward_task(Task const *task, regions[...](I/O): weight_grad regions[...](I/O): output_grad */ - __host__ void FusedOp::backward_task(Task const *task, std::vector const ®ions, Context ctx, diff --git a/src/ops/gather.cc b/src/ops/gather.cc index f094fe38b0..635c741d8b 100644 --- a/src/ops/gather.cc +++ b/src/ops/gather.cc @@ -166,6 +166,7 @@ void Gather::serialize(Legion::Serializer &sez) const { GatherParams params = get_params(); sez.serialize(params.legion_dim); sez.serialize(this->layer_guid.id); + sez.serialize(this->layer_guid.transformer_layer_id); } using PCG::Node; @@ -177,9 +178,10 @@ Node Gather::deserialize(FFModel &ff, assert(num_inputs == 2); int legion_dim; dez.deserialize(legion_dim); - size_t id; + size_t id, transformer_layer_id; dez.deserialize(id); - LayerID layer_guid(id); + dez.deserialize(transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id); GatherParams params; params.legion_dim = legion_dim; diff --git a/src/ops/inc_multihead_self_attention.cu b/src/ops/inc_multihead_self_attention.cu index 991b6d2236..f5b72b9ac8 100644 --- a/src/ops/inc_multihead_self_attention.cu +++ b/src/ops/inc_multihead_self_attention.cu @@ -452,7 +452,7 @@ void inference_kernel(IncMultiHeadSelfAttentionMeta const *m, } cudaMemcpyAsync(m->token_infos, &(bc->tokensInfo), - bc->MAX_NUM_TOKENS * sizeof(BatchConfig::PerTokenInfo), + bc->num_active_tokens() * sizeof(BatchConfig::PerTokenInfo), cudaMemcpyHostToDevice, stream); // phase 1: Implement kernel to compute KQV for input tokens diff --git a/src/ops/kernels/linear_kernels.cpp b/src/ops/kernels/linear_kernels.cpp index 55a47d7108..41b9912702 100644 --- a/src/ops/kernels/linear_kernels.cpp +++ b/src/ops/kernels/linear_kernels.cpp @@ -75,12 +75,13 @@ void Linear::init_kernel(LinearMeta *m, int batch_size, int channel) { assert(false); } checkCUDNN(miopenSetActivationDescriptor(m->actiDesc, mode, 0.0, 0.0, 0.0)); - checkCUDNN(miopenSet4dTensorDescriptor(m->outputTensor, - ff_to_cudnn_datatype(m->output_type), - batch_size, - channel, - 1, - 1)); + checkCUDNN( + miopenSet4dTensorDescriptor(m->outputTensor, + ff_to_cudnn_datatype(m->output_type[0]), + batch_size, + channel, + 1, + 1)); } } @@ -102,7 +103,7 @@ void forward_kernel_wrapper(LinearMeta const *m, hipEventRecord(t_start, stream); } - if (m->input_type == DT_FLOAT) { + if (m->input_type[0] == DT_FLOAT) { Internal::forward_kernel(m, input_ptr, output_ptr, @@ -112,7 +113,7 @@ void forward_kernel_wrapper(LinearMeta const *m, out_dim, batch_size, stream); - } else if (m->input_type == DT_HALF) { + } else if (m->input_type[0] == DT_HALF) { Internal::forward_kernel(m, input_ptr, output_ptr, @@ -161,7 +162,7 @@ void backward_kernel_wrapper(LinearMeta const *m, hipEventCreate(&t_end); hipEventRecord(t_start, stream); } - if (m->input_type == DT_FLOAT) { + if (m->input_type[0] == DT_FLOAT) { Internal::backward_kernel(m, input_ptr, input_grad_ptr, @@ -174,7 +175,7 @@ void backward_kernel_wrapper(LinearMeta const *m, out_dim, batch_size, stream); - } else if (m->input_type == DT_HALF) { + } else if (m->input_type[0] == DT_HALF) { Internal::backward_kernel(m, input_ptr, input_grad_ptr, @@ -236,9 +237,9 @@ void forward_kernel(LinearMeta const *m, checkCUDA(hipblasSetStream(m->handle.blas, stream)); checkCUDNN(miopenSetStream(m->handle.dnn, stream)); DT alpha = 1.0f, beta = 0.0f; - hipblasDatatype_t input_type = ff_to_cuda_datatype(m->input_type); - hipblasDatatype_t weight_type = ff_to_cuda_datatype(m->weight_type); - hipblasDatatype_t output_type = ff_to_cuda_datatype(m->output_type); + hipblasDatatype_t input_type = ff_to_cuda_datatype(m->input_type[0]); + hipblasDatatype_t weight_type = ff_to_cuda_datatype(m->weight_type[0]); + hipblasDatatype_t output_type = ff_to_cuda_datatype(m->output_type[0]); #if CUDA_VERSION >= 11000 // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; @@ -332,9 +333,9 @@ void backward_kernel(LinearMeta const *m, checkCUDNN(miopenSetStream(m->handle.dnn, stream)); DT alpha = 1.0f; - hipblasDatatype_t input_type = ff_to_cuda_datatype(m->input_type); - hipblasDatatype_t weight_type = ff_to_cuda_datatype(m->weight_type); - hipblasDatatype_t output_type = ff_to_cuda_datatype(m->output_type); + hipblasDatatype_t input_type = ff_to_cuda_datatype(m->input_type[0]); + hipblasDatatype_t weight_type = ff_to_cuda_datatype(m->weight_type[0]); + hipblasDatatype_t output_type = ff_to_cuda_datatype(m->output_type[0]); #if CUDA_VERSION >= 11000 // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; @@ -344,10 +345,10 @@ void backward_kernel(LinearMeta const *m, int output_size = out_dim * batch_size; if (m->activation == AC_MODE_RELU) { relu_backward_kernel( - m->output_type, output_grad_ptr, output_ptr, output_size, stream); + m->output_type[0], output_grad_ptr, output_ptr, output_size, stream); } else if (m->activation == AC_MODE_SIGMOID) { sigmoid_backward_kernel( - m->output_type, output_grad_ptr, output_ptr, output_size, stream); + m->output_type[0], output_grad_ptr, output_ptr, output_size, stream); } else { // TODO: only support relu and sigmoid for now assert(m->activation == AC_MODE_NONE); diff --git a/src/ops/kernels/linear_kernels.cu b/src/ops/kernels/linear_kernels.cu index 3f806dd4f5..06677f86e6 100644 --- a/src/ops/kernels/linear_kernels.cu +++ b/src/ops/kernels/linear_kernels.cu @@ -96,13 +96,14 @@ void init_kernel(LinearMeta *m, int batch_size, int channel) { } checkCUDNN(cudnnSetActivationDescriptor( m->actiDesc, mode, CUDNN_PROPAGATE_NAN, 0.0)); - checkCUDNN(cudnnSetTensor4dDescriptor(m->outputTensor, - CUDNN_TENSOR_NCHW, - ff_to_cudnn_datatype(m->output_type), - batch_size, - channel, - 1, - 1)); + checkCUDNN( + cudnnSetTensor4dDescriptor(m->outputTensor, + CUDNN_TENSOR_NCHW, + ff_to_cudnn_datatype(m->output_type[0]), + batch_size, + channel, + 1, + 1)); } } @@ -122,7 +123,7 @@ void forward_kernel_wrapper(LinearMeta const *m, cudaEventCreate(&t_end); cudaEventRecord(t_start, stream); } - if (m->input_type == DT_FLOAT) { + if (m->input_type[0] == DT_FLOAT) { Internal::forward_kernel(m, input_ptr, output_ptr, @@ -132,7 +133,7 @@ void forward_kernel_wrapper(LinearMeta const *m, out_dim, batch_size, stream); - } else if (m->input_type == DT_HALF) { + } else if (m->input_type[0] == DT_HALF) { Internal::forward_kernel(m, input_ptr, output_ptr, @@ -180,7 +181,7 @@ void backward_kernel_wrapper(LinearMeta const *m, cudaEventCreate(&t_end); cudaEventRecord(t_start, stream); } - if (m->input_type == DT_FLOAT) { + if (m->input_type[0] == DT_FLOAT) { Internal::backward_kernel(m, input_ptr, input_grad_ptr, @@ -193,7 +194,7 @@ void backward_kernel_wrapper(LinearMeta const *m, out_dim, batch_size, stream); - } else if (m->input_type == DT_HALF) { + } else if (m->input_type[0] == DT_HALF) { Internal::backward_kernel(m, input_ptr, input_grad_ptr, @@ -295,11 +296,11 @@ void forward_kernel(LinearMeta const *m, checkCUDA(cublasSetStream(m->handle.blas, stream)); checkCUDNN(cudnnSetStream(m->handle.dnn, stream)); DT alpha = 1.0f, beta = 0.0f; - cudaDataType_t input_type = ff_to_cuda_datatype(m->input_type); + cudaDataType_t input_type = ff_to_cuda_datatype(m->input_type[0]); cudaDataType_t weight_type = m->offload ? ff_to_cuda_datatype(m->weight_ptr_type) - : ff_to_cuda_datatype(m->weight_type); - cudaDataType_t output_type = ff_to_cuda_datatype(m->output_type); + : ff_to_cuda_datatype(m->weight_type[0]); + cudaDataType_t output_type = ff_to_cuda_datatype(m->output_type[0]); assert(input_type == weight_type && weight_type == output_type); #if CUDA_VERSION >= 11000 // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance @@ -388,9 +389,9 @@ void backward_kernel(LinearMeta const *m, DT alpha = 1.0f; float sgeam_alpha = 1.0f; - cudaDataType_t input_type = ff_to_cuda_datatype(m->input_type); - cudaDataType_t weight_type = ff_to_cuda_datatype(m->weight_type); - cudaDataType_t output_type = ff_to_cuda_datatype(m->output_type); + cudaDataType_t input_type = ff_to_cuda_datatype(m->input_type[0]); + cudaDataType_t weight_type = ff_to_cuda_datatype(m->weight_type[0]); + cudaDataType_t output_type = ff_to_cuda_datatype(m->output_type[0]); #if CUDA_VERSION >= 11000 // TODO: currently set the default to CUBLAS_COMPUTE_16F for best performance cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; @@ -400,10 +401,10 @@ void backward_kernel(LinearMeta const *m, int output_size = out_dim * batch_size; if (m->activation == AC_MODE_RELU) { relu_backward_kernel( - m->output_type, output_grad_ptr, output_ptr, output_size, stream); + m->output_type[0], output_grad_ptr, output_ptr, output_size, stream); } else if (m->activation == AC_MODE_SIGMOID) { sigmoid_backward_kernel( - m->output_type, output_grad_ptr, output_ptr, output_size, stream); + m->output_type[0], output_grad_ptr, output_ptr, output_size, stream); } else { // TODO: only support relu and sigmoid for now assert(m->activation == AC_MODE_NONE); diff --git a/src/ops/layer_norm.cc b/src/ops/layer_norm.cc index 0c08a2426f..0124c827f3 100644 --- a/src/ops/layer_norm.cc +++ b/src/ops/layer_norm.cc @@ -216,33 +216,39 @@ LayerNorm::LayerNorm(FFModel &model, for (int i = 0; i < axes.size(); i++) { M *= inputs[0]->dims[axes[i]].size; } + int num_replicas = 1; + for (int i = 0; i < inputs[0]->num_dims; i++) { + if (inputs[0]->dims[i].is_replica_dim) { + num_replicas *= inputs[0]->dims[i].size; + } + } effective_num_elements = M; - effective_batch_size = inputs[0]->get_volume() / M; + effective_batch_size = (inputs[0]->get_volume() / num_replicas) / M; assert(elementwise_affine == (numWeights == 2)); if (numWeights > 0 && allocate_weights) { - ParallelDim dims[axes.size()]; - for (int i = 0; i < axes.size(); i++) { - dims[i] = inputs[0]->dims[i]; + ParallelTensorShape beta_gamma_shape = _input->get_shape(); + for (int i = axes.size(); i < beta_gamma_shape.num_dims - 1; i++) { + beta_gamma_shape.dims[i].size = 1; } int seed = std::rand(); Initializer *gamma_initializer = new UniformInitializer(seed, 1.0f, 1.0f); Initializer *beta_initializer = new UniformInitializer(seed, 0.0f, 0.0f); - weights[0] = - model.create_parallel_weight_legion_ordering(axes.size(), - dims, - _input->data_type, - NULL /*owner_op*/, - true /*create_grad*/, - gamma_initializer, - CHOSEN_SYNC_TYPE); - weights[1] = - model.create_parallel_weight_legion_ordering(axes.size(), - dims, - _input->data_type, - NULL /*owner_op*/, - true /*create_grad*/, - beta_initializer, - CHOSEN_SYNC_TYPE); + weights[0] = model.create_parallel_weight_legion_ordering( + beta_gamma_shape.num_dims, // axes.size(), + beta_gamma_shape.dims, + _input->data_type, + NULL /*owner_op*/, + true /*create_grad*/, + gamma_initializer, + CHOSEN_SYNC_TYPE); + weights[1] = model.create_parallel_weight_legion_ordering( + beta_gamma_shape.num_dims, //.size(), + beta_gamma_shape.dims, + _input->data_type, + NULL /*owner_op*/, + true /*create_grad*/, + beta_initializer, + CHOSEN_SYNC_TYPE); } } @@ -383,13 +389,13 @@ void LayerNorm::forward(FFModel const &ff) { if (elementwise_affine) { launcher.add_region_requirement(RegionRequirement(weights[0]->part, 0 /*projection id*/, - READ_WRITE, + READ_ONLY, EXCLUSIVE, weights[0]->region)); launcher.add_field(2, FID_DATA); launcher.add_region_requirement(RegionRequirement(weights[1]->part, 0 /*projection id*/, - READ_WRITE, + READ_ONLY, EXCLUSIVE, weights[1]->region)); launcher.add_field(3, FID_DATA); @@ -434,13 +440,13 @@ FutureMap LayerNorm::inference(FFModel const &ff, if (elementwise_affine) { launcher.add_region_requirement(RegionRequirement(weights[0]->part, 0 /*projection id*/, - READ_WRITE, + READ_ONLY, EXCLUSIVE, weights[0]->region)); launcher.add_field(2, FID_DATA); launcher.add_region_requirement(RegionRequirement(weights[1]->part, 0 /*projection id*/, - READ_WRITE, + READ_ONLY, EXCLUSIVE, weights[1]->region)); launcher.add_field(3, FID_DATA); @@ -462,8 +468,8 @@ void LayerNorm::forward_task(Task const *task, assert(task->regions.size() == regions.size()); float const *in_ptr = NULL; float *out_ptr = NULL, *gamma_ptr = NULL, *beta_ptr = NULL; - GenericTensorAccessorR in; - GenericTensorAccessorW out, gamma, beta; + GenericTensorAccessorR in, gamma, beta; + GenericTensorAccessorW out; Domain in_domain = runtime->get_index_space_domain( ctx, task->regions[0].region.get_index_space()); @@ -486,21 +492,25 @@ void LayerNorm::forward_task(Task const *task, ctx, task->regions[2].region.get_index_space()); // gamma_ptr = helperGetTensorPointerRW( // regions[2], task->regions[2], FID_DATA, ctx, runtime); - gamma = helperGetGenericTensorAccessorRW( + gamma = helperGetGenericTensorAccessorRO( m->input_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); Domain beta_domain = runtime->get_index_space_domain( ctx, task->regions[3].region.get_index_space()); // beta_ptr = helperGetTensorPointerRW( // regions[3], task->regions[3], FID_DATA, ctx, runtime); - beta = helperGetGenericTensorAccessorRW( + beta = helperGetGenericTensorAccessorRO( m->input_type[0], regions[3], task->regions[3], FID_DATA, ctx, runtime); assert(gamma_domain == beta_domain); assert(gamma_domain.get_volume() == m->effective_num_elements); int numdims = gamma_domain.get_dim(); - for (int i = 0; i < numdims; i++) { + size_t vol = 1; + int i = 0; + while (vol < gamma_domain.get_volume()) { int g_d = gamma_domain.hi()[i] - gamma_domain.lo()[i] + 1; int in_d = in_domain.hi()[i] - in_domain.lo()[i] + 1; assert(g_d == in_d); + vol *= g_d; + i++; } } else { assert(regions.size() == 2); @@ -730,6 +740,7 @@ bool LayerNorm::measure_operator_cost(Simulator *sim, void LayerNorm::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); + sez.serialize(this->layer_guid.transformer_layer_id); sez.serialize(this->axes.size()); for (size_t i = 0; i < this->axes.size(); i++) { sez.serialize(this->axes[i]); @@ -749,9 +760,10 @@ Node LayerNorm::deserialize(FFModel &ff, std::vector axes; bool elementwise_affine; float eps; - size_t id; + size_t id, transformer_layer_id; dez.deserialize(id); - LayerID layer_guid(id); + dez.deserialize(transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id); dez.deserialize(num_axes); for (size_t i = 0; i < num_axes; i++) { int axis_idx; diff --git a/src/ops/layer_norm.cpp b/src/ops/layer_norm.cpp index 3f1c621e71..fc6be70c74 100644 --- a/src/ops/layer_norm.cpp +++ b/src/ops/layer_norm.cpp @@ -129,8 +129,8 @@ template void LayerNorm::forward_kernel(LayerNormMeta const *m, T const *in_ptr, T *out_ptr, - T *gamma_ptr, - T *beta_ptr, + T const *gamma_ptr, + T const *beta_ptr, hipStream_t stream) { hipLaunchKernelGGL(HIP_KERNEL_NAME(RowwiseMomentsCUDAKernel), m->effective_batch_size, @@ -160,8 +160,8 @@ void LayerNorm::forward_kernel(LayerNormMeta const *m, void LayerNorm::forward_kernel_wrapper(LayerNormMeta const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW &output, - GenericTensorAccessorW &gamma, - GenericTensorAccessorW &beta) { + GenericTensorAccessorR const &gamma, + GenericTensorAccessorR const &beta) { hipStream_t stream; checkCUDA(get_legion_stream(&stream)); if (m->input_type[0] == DT_FLOAT) { diff --git a/src/ops/layer_norm.cu b/src/ops/layer_norm.cu index 35616de980..1f4e7d3933 100644 --- a/src/ops/layer_norm.cu +++ b/src/ops/layer_norm.cu @@ -135,8 +135,8 @@ template void LayerNorm::forward_kernel(LayerNormMeta const *m, T const *in_ptr, T *out_ptr, - T *gamma_ptr, - T *beta_ptr, + T const *gamma_ptr, + T const *beta_ptr, cudaStream_t stream) { RowwiseMomentsCUDAKernel <<effective_batch_size, kCUDABlockReduceNumThreads, 0, stream>>>( @@ -160,8 +160,8 @@ void LayerNorm::forward_kernel(LayerNormMeta const *m, void LayerNorm::forward_kernel_wrapper(LayerNormMeta const *m, GenericTensorAccessorR const &input, GenericTensorAccessorW &output, - GenericTensorAccessorW &gamma, - GenericTensorAccessorW &beta) { + GenericTensorAccessorR const &gamma, + GenericTensorAccessorR const &beta) { cudaStream_t stream; checkCUDA(get_legion_stream(&stream)); diff --git a/src/ops/linear.cc b/src/ops/linear.cc index cca92f014f..c5903c1e74 100644 --- a/src/ops/linear.cc +++ b/src/ops/linear.cc @@ -504,10 +504,7 @@ OpMeta *Linear::init_task_with_dim(Task const *task, m->add_bias_only_once = linear->add_bias_only_once; m->profiling = linear->profiling; m->trainableInputs[0] = linear->trainableInputs[0]; - m->input_type = linear->inputs[0]->data_type; - m->weight_type = linear->weights[0]->data_type; - m->output_type = linear->outputs[0]->data_type; - m->weight_ptr_type = m->input_type; + m->weight_ptr_type = m->input_type[0]; m->quantization_type = linear->quantization_type; m->offload = linear->offload; std::strcpy(m->op_name, linear->name); @@ -573,9 +570,9 @@ FutureMap Linear::inference(FFModel const &ff, size_t machine_view_hash = view->hash(); /* std::cout << "Linear op machine_view: " << *(MachineView const *)mv << std::endl; */ - IndexLauncher launcher(LINEAR_FWD_TASK_ID, + IndexLauncher launcher(LINEAR_INF_TASK_ID, parallel_is, - TaskArgument(nullptr, 0), + TaskArgument(&bc, sizeof(BatchConfig)), argmap, Predicate::TRUE_PRED, false /*must*/, @@ -612,6 +609,52 @@ FutureMap Linear::inference(FFModel const &ff, return runtime->execute_index_space(ctx, launcher); } +void Linear::inference_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + Domain input_domain = runtime->get_index_space_domain( + ctx, task->regions[0].region.get_index_space()); + LinearMeta const *m = *((LinearMeta **)task->local_args); + BatchConfig const *bc = (BatchConfig *)task->args; + assert(regions.size() == (3 + static_cast(m->use_bias))); + assert(task->regions.size() == (3 + static_cast(m->use_bias))); + if (m->quantization_type == DT_NONE) { + assert(m->input_type[0] == m->weight_type[0]); + } + assert(m->input_type[0] == m->output_type[0]); + + GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( + m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( + m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); + GenericTensorAccessorR weight = helperGetGenericTensorAccessorRO( + m->weight_type[0], regions[2], task->regions[2], FID_DATA, ctx, runtime); + int in_dim = input.domain.hi()[0] - input.domain.lo()[0] + 1; + int out_dim = output.domain.hi()[0] - output.domain.lo()[0] + 1; + + int batch_size = bc->num_active_tokens(); + GenericTensorAccessorR bias; + if (m->use_bias && + !(m->add_bias_only_once && task->index_point.point_data[0] != 0)) { + bias = helperGetGenericTensorAccessorRO(m->weight_type[1], + regions[3], + task->regions[3], + FID_DATA, + ctx, + runtime); + assert(bias.domain.get_volume() == static_cast(out_dim)); + } + forward_kernel_wrapper(m, + input.ptr, + output.ptr, + weight.ptr, + bias.ptr, + in_dim, + out_dim, + batch_size); +} + void Linear::forward_task(Task const *task, std::vector const ®ions, Context ctx, @@ -620,13 +663,13 @@ void Linear::forward_task(Task const *task, ctx, task->regions[0].region.get_index_space()); LinearMeta const *m = *((LinearMeta **)task->local_args); if (m->quantization_type == DT_NONE) { - assert(m->input_type == m->weight_type); + assert(m->input_type[0] == m->weight_type[0]); } - assert(m->input_type == m->output_type); + assert(m->input_type[0] == m->output_type[0]); switch (input_domain.get_dim()) { #define DIMFUNC(DIM) \ case DIM: \ - if (m->output_type == DT_HALF) { \ + if (m->output_type[0] == DT_HALF) { \ if (m->quantization_type != DT_NONE) { \ return forward_task_with_dim( \ task, regions, ctx, runtime); \ @@ -634,7 +677,7 @@ void Linear::forward_task(Task const *task, return forward_task_with_dim( \ task, regions, ctx, runtime); \ } \ - } else if (m->output_type == DT_FLOAT) { \ + } else if (m->output_type[0] == DT_FLOAT) { \ if (m->quantization_type != DT_NONE) { \ return forward_task_with_dim( \ task, regions, ctx, runtime); \ @@ -787,15 +830,15 @@ void Linear::backward_task(Task const *task, ctx, task->regions[0].region.get_index_space()); LinearMeta const *m = *((LinearMeta **)task->local_args); if (m->quantization_type == DT_NONE) { - assert(m->input_type == m->weight_type); + assert(m->input_type[0] == m->weight_type[0]); } - assert(m->input_type == m->output_type); + assert(m->input_type[0] == m->output_type[0]); switch (in_domain.get_dim()) { #define DIMFUNC(DIM) \ case DIM: \ - if (m->output_type == DT_HALF) { \ + if (m->output_type[0] == DT_HALF) { \ return backward_task_with_dim(task, regions, ctx, runtime); \ - } else if (m->output_type == DT_FLOAT) { \ + } else if (m->output_type[0] == DT_FLOAT) { \ return backward_task_with_dim(task, regions, ctx, runtime); \ } else { \ assert(false && "Unsupported data type"); \ @@ -1068,9 +1111,9 @@ bool Linear::measure_operator_cost(Simulator *sim, m->activation = activation; m->kernel_reg_type = kernel_reg_type; m->kernel_reg_lambda = kernel_reg_lambda; - m->input_type = inputs[0]->data_type; - m->weight_type = this->data_type; - m->output_type = outputs[0]->data_type; + m->input_type[0] = inputs[0]->data_type; + m->weight_type[0] = this->data_type; + m->output_type[0] = outputs[0]->data_type; assert(m->profiling == false); init_kernel(m, output_n, output_c); @@ -1186,6 +1229,7 @@ bool operator==(LinearParams const &lhs, LinearParams const &rhs) { void Linear::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); + sez.serialize(this->layer_guid.transformer_layer_id); sez.serialize(this->out_channels); sez.serialize(this->activation); sez.serialize(this->kernel_reg_type); @@ -1211,9 +1255,10 @@ Node Linear::deserialize(FFModel &ff, DataType data_type; DataType quantization_type; bool offload; - size_t id; + size_t id, transformer_layer_id; dez.deserialize(id); - LayerID layer_guid(id); + dez.deserialize(transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id); dez.deserialize(out_channels); dez.deserialize(activation); dez.deserialize(kernel_reg_type); diff --git a/src/ops/reduce.cc b/src/ops/reduce.cc index 5761281686..36112b0812 100644 --- a/src/ops/reduce.cc +++ b/src/ops/reduce.cc @@ -374,6 +374,7 @@ void Reduce::serialize(Legion::Serializer &sez) const { } sez.serialize(params.keepdims); sez.serialize(this->layer_guid.id); + sez.serialize(this->layer_guid.transformer_layer_id); } using PCG::Node; @@ -392,9 +393,10 @@ Node Reduce::deserialize(FFModel &ff, axes.push_back(dim_idx); } dez.deserialize(keepdims); - size_t id; + size_t id, transformer_layer_id; dez.deserialize(id); - LayerID layer_guid(id); + dez.deserialize(transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id); return ff.get_or_create_node(inputs[0], {axes, keepdims, layer_guid}); } diff --git a/src/ops/reshape.cc b/src/ops/reshape.cc index 2b8a60bf21..41c3fcdbf1 100644 --- a/src/ops/reshape.cc +++ b/src/ops/reshape.cc @@ -410,6 +410,7 @@ void Reshape::serialize(Legion::Serializer &sez) const { sez.serialize(this->shape_array[i]); } sez.serialize(this->layer_guid.id); + sez.serialize(this->layer_guid.transformer_layer_id); } using PCG::Node; @@ -427,9 +428,10 @@ Node Reshape::deserialize(FFModel &ff, dez.deserialize(value); shape.push_back(value); } - size_t id; + size_t id, transformer_layer_id; dez.deserialize(id); - LayerID layer_guid(id); + dez.deserialize(transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id); ReshapeParams params; params.shape = shape; diff --git a/src/ops/rms_norm.cc b/src/ops/rms_norm.cc index a926fd3b22..e0076b5202 100644 --- a/src/ops/rms_norm.cc +++ b/src/ops/rms_norm.cc @@ -165,7 +165,11 @@ RMSNorm::RMSNorm(FFModel &model, for (int i = 1; i <= num_dims - 2; i++) { effective_batch_size *= _input->dims[i].size; } - + // Currently assert that all non-replica dims are not parallelized + // We only support parallelism along the replica dim now + for (int i = 0; i < _input->num_dims - 1; i++) { + assert(_input->dims[i].degree == 1); + } // output has the same parallel dims as input ParallelDim output_dims[MAX_TENSOR_DIM]; for (int i = 0; i < _input->num_dims; i++) { @@ -173,15 +177,14 @@ RMSNorm::RMSNorm(FFModel &model, } outputs[0] = model.create_parallel_tensor_legion_ordering( _input->num_dims, output_dims, _input->data_type, this); - if (allocate_weights) { // weights should have the shape of (data_dim, data_dim) ParallelDim new_weight_dims[MAX_TENSOR_DIM]; - new_weight_dims[0] = _input->dims[_input->num_dims - 1]; - new_weight_dims[1].size = dim; - new_weight_dims[1].degree = 1; - new_weight_dims[1].parallel_idx = -1; + new_weight_dims[0].size = dim; + new_weight_dims[0].degree = 1; + new_weight_dims[0].parallel_idx = -1; + new_weight_dims[1] = _input->dims[_input->num_dims - 1]; // replica dim // weights Initializer *kernel_initializer = new GlorotUniform(std::rand() /*seed*/); @@ -189,7 +192,7 @@ RMSNorm::RMSNorm(FFModel &model, model.create_parallel_weight_legion_ordering(2, new_weight_dims, _input->data_type, - NULL /*owner_op*/, + nullptr /*owner_op*/, false /*create_grad*/, kernel_initializer, CHOSEN_SYNC_TYPE); @@ -389,6 +392,7 @@ void RMSNorm::forward_task(Task const *task, void RMSNorm::serialize(Legion::Serializer &sez) const { sez.serialize(this->layer_guid.id); + sez.serialize(this->layer_guid.transformer_layer_id); sez.serialize(this->eps); sez.serialize(this->dim); } @@ -401,11 +405,12 @@ Node RMSNorm::deserialize(FFModel &ff, int num_inputs) { assert(num_inputs == 1); float eps; - size_t id; + size_t id, transformer_layer_id; int dim; dez.deserialize(id); + dez.deserialize(transformer_layer_id); - LayerID layer_guid(id); + LayerID layer_guid(id, transformer_layer_id); dez.deserialize(eps); dez.deserialize(dim); RMSNormParams params; diff --git a/src/ops/tree_inc_multihead_self_attention.cu b/src/ops/tree_inc_multihead_self_attention.cu index 541322efc4..b46ccb4853 100644 --- a/src/ops/tree_inc_multihead_self_attention.cu +++ b/src/ops/tree_inc_multihead_self_attention.cu @@ -498,7 +498,7 @@ void inference_kernel(TreeIncMultiHeadSelfAttentionMeta *m, // keys/values to the key-value cache cudaMemcpyAsync(m->committed_token_infos, &(bc->committed_tokens), - bc->MAX_NUM_TOKENS * + bc->num_tokens_to_commit * sizeof(TreeVerifyBatchConfig::CommittedTokensInfo), cudaMemcpyHostToDevice, stream); diff --git a/src/parallel_ops/allreduce.cc b/src/parallel_ops/allreduce.cc new file mode 100644 index 0000000000..123e85c7c5 --- /dev/null +++ b/src/parallel_ops/allreduce.cc @@ -0,0 +1,362 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "flexflow/parallel_ops/allreduce.h" +#include "flexflow/ffconst_utils.h" +#include "flexflow/model.h" +#include "flexflow/parallel_ops/kernels/allreduce_kernels.h" +#include "flexflow/utils/hash_utils.h" + +namespace FlexFlow { +// declare Legion names +using Legion::ArgumentMap; +using Legion::Context; +using Legion::coord_t; +using Legion::Domain; +using Legion::FutureMap; +using Legion::IndexLauncher; +using Legion::LogicalPartition; +using Legion::LogicalRegion; +using Legion::Machine; +using Legion::Memory; +using Legion::PhysicalRegion; +using Legion::Predicate; +using Legion::Rect; +using Legion::RegionRequirement; +using Legion::Runtime; +using Legion::Task; +using Legion::TaskArgument; +using Legion::TaskLauncher; + +using namespace FlexFlow::Kernels::AllReduce; + +/* Params */ +bool operator==(AllReduceParams const &lhs, AllReduceParams const &rhs) { + return lhs.allreduce_legion_dim == rhs.allreduce_legion_dim; +} + +bool AllReduceParams::is_valid(ParallelTensorShape const &input) const { + return input.is_valid(); +} + +AllReduceParams AllReduce::get_params() const { + AllReduceParams params; + params.allreduce_legion_dim = this->allreduce_dim; + return params; +} + +AllReduce::AllReduce(FFModel &model, + const ParallelTensor _input, + int _allreduce_legion_dim, + char const *name) + : ParallelOp(model, OP_ALLREDUCE, name, _input), + allreduce_dim(_allreduce_legion_dim) { + int numdim = _input->num_dims; + ParallelDim dims[MAX_TENSOR_DIM]; + for (int i = 0; i < numdim; i++) { + dims[i] = _input->dims[i]; + } + assert(dims[allreduce_dim].degree > 1); + // ParallelTensorBase::update_parallel_ids(numdim, dims); + outputs[0] = model.create_parallel_tensor_legion_ordering( + numdim, dims, _input->data_type, this); +} + +AllReduce::AllReduce(FFModel &model, + AllReduceParams const ¶ms, + ParallelTensor const input, + char const *name) + : AllReduce(model, input, params.allreduce_legion_dim, name) {} + +void AllReduce::create_input_partition(FFModel &ff) { + // Do nothing + return; +} + +void AllReduce::create_input_partition_inference( + FFModel &ff, + std::vector const &batch_inputs, + std::vector const &batch_outputs) { + assert(ff.config.computationMode == COMP_MODE_INFERENCE); + assert(batch_outputs[0]->part != LogicalPartition::NO_PART); + assert(batch_inputs[0]->part != LogicalPartition::NO_PART); + // Do nothing + return; +} + +OpMeta *AllReduce::init_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + AllReduce *ar = (AllReduce *)task->args; + FFHandler handle = *((FFHandler const *)task->local_args); + AllReduceMeta *meta = new AllReduceMeta(handle, ar); + meta->input_type[0] = ar->inputs[0]->data_type; + meta->output_type[0] = ar->outputs[0]->data_type; + assert(meta->input_type[0] == meta->output_type[0]); + return meta; +} + +void AllReduce::init(FFModel const &ff) { + ArgumentMap argmap; + parallel_is = outputs[0]->parallel_is; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + assert(numOutputs == 1); + assert(numInputs == 1); + set_argumentmap_for_init(ff, argmap); + IndexLauncher launcher(ALLREDUCE_INIT_TASK_ID, + parallel_is, + TaskArgument(this, sizeof(AllReduce)), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + outputs[0]->machine_view.hash()); + launcher.add_region_requirement(RegionRequirement(inputs[0]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + outputs[0]->region)); + launcher.add_field(1, FID_DATA); + FutureMap fm = runtime->execute_index_space(ctx, launcher); + fm.wait_all_results(); + set_opmeta_from_futuremap(ff, fm); +} + +void AllReduce::init_inference(FFModel const &ff, + std::vector const &batch_inputs, + std::vector const &batch_outputs, + MachineView const *mv) { + ArgumentMap argmap; + parallel_is = batch_outputs[0]->parallel_is; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + assert(numOutputs == 1); + assert(numInputs == 1); + size_t machine_view_hash = + mv ? mv->hash() : batch_outputs[0]->machine_view.hash(); + set_argumentmap_for_init_inference(ff, argmap, batch_outputs[0]); + IndexLauncher launcher(ALLREDUCE_INIT_TASK_ID, + parallel_is, + TaskArgument(this, sizeof(AllReduce)), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + batch_inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[0]->region)); + launcher.add_field(1, FID_DATA); + FutureMap fm = runtime->execute_index_space(ctx, launcher); + fm.wait_all_results(); + set_opmeta_from_futuremap_inference(ff, fm, batch_outputs[0]); +} + +FutureMap AllReduce::inference(FFModel const &ff, + BatchConfig const &bc, + std::vector const &batch_inputs, + std::vector const &batch_outputs, + MachineView const *mv) { + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + parallel_is = batch_outputs[0]->parallel_is; + assert(numOutputs == 1); + assert(numInputs == 1); + assert(batch_inputs[0]->data_type == batch_outputs[0]->data_type); + DataType data_type = batch_inputs[0]->data_type; + size_t machine_view_hash = + mv ? mv->hash() : batch_outputs[0]->machine_view.hash(); + set_argumentmap_for_inference(ff, argmap, batch_outputs[0]); + IndexLauncher launcher(ALLREDUCE_FWD_TASK_ID, + batch_outputs[0]->parallel_is, + TaskArgument(NULL, 0), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + batch_inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[0]->region)); + launcher.add_field(1, FID_DATA); + return runtime->execute_index_space(ctx, launcher); +} + +void AllReduce::forward(FFModel const &ff) { + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + parallel_is = outputs[0]->parallel_is; + assert(numOutputs == 1); + assert(numInputs == 1); + set_argumentmap_for_forward(ff, argmap); + IndexLauncher launcher(ALLREDUCE_FWD_TASK_ID, + outputs[0]->parallel_is, + TaskArgument(NULL, 0), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + outputs[0]->machine_view.hash()); + launcher.add_region_requirement(RegionRequirement(inputs[0]->part, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + outputs[0]->region)); + launcher.add_field(1, FID_DATA); + runtime->execute_index_space(ctx, launcher); +} + +void AllReduce::backward(FFModel const &ff) { + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + assert(numOutputs == 1); + assert(numInputs == 1); + IndexLauncher launcher(ALLREDUCE_BWD_TASK_ID, + inputs[0]->parallel_is, + TaskArgument(NULL, 0), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + inputs[0]->machine_view.hash()); + launcher.add_region_requirement(RegionRequirement(inputs[0]->part_grad, + 0 /*projection id*/, + READ_WRITE, + EXCLUSIVE, + inputs[0]->region_grad)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(outputs[0]->part_grad, + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + outputs[0]->region_grad)); + launcher.add_field(1, FID_DATA); + runtime->execute_index_space(ctx, launcher); +} + +bool AllReduce::measure_operator_cost(Simulator *sim, + MachineView const &pc, + CostMetrics &cost_metrics) const { + cost_metrics = CostMetrics(); + cost_metrics.forward_time = 0.0f; + cost_metrics.backward_time = 0.0f; + + cost_metrics.sync_time = 0; + cost_metrics.inputs_memory = 0; + cost_metrics.outputs_memory = 0; + cost_metrics.weights_memory = 0; + return true; +} + +bool AllReduce::get_int_parameter(PMParameter para, int *value) const { + switch (para) { + case PM_ALLREDUCE_DIM: + *value = allreduce_dim; + return true; + default: + return Op::get_int_parameter(para, value); + } +} + +bool AllReduce::append_parallel_op_info( + std::vector ¶llel_ops) const { + ParallelOpInfo ret; + ret.op_type = op_type; + ret.parallel_dim = allreduce_dim; + ret.parallel_degree = -1; // AllReduce does not affect parallel degree + parallel_ops.push_back(ret); + return true; +} + +/*static*/ +void AllReduce::forward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + assert(regions.size() == 2); + assert(task->regions.size() == 2); + + AllReduceMeta const *m = *((AllReduceMeta **)task->local_args); + + GenericTensorAccessorR input = helperGetGenericTensorAccessorRO( + m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorW output = helperGetGenericTensorAccessorWO( + m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); + + assert(input.data_type == output.data_type); + forward_kernel_wrapper(m, input, output); +} + +void AllReduce::backward_task(Task const *task, + std::vector const ®ions, + Context ctx, + Runtime *runtime) { + assert(regions.size() == 2); + assert(task->regions.size() == 2); + AllReduceMeta const *m = *((AllReduceMeta **)task->local_args); + + GenericTensorAccessorW input_grad = helperGetGenericTensorAccessorRW( + m->input_type[0], regions[0], task->regions[0], FID_DATA, ctx, runtime); + GenericTensorAccessorR output_grad = helperGetGenericTensorAccessorRO( + m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime); + + assert(input_grad.data_type == output_grad.data_type); + backward_kernel_wrapper(m, input_grad, output_grad); +} + +}; // namespace FlexFlow + +namespace std { +size_t hash::operator()( + FlexFlow::AllReduceParams const ¶ms) const { + size_t key = 0; + hash_combine(key, params.allreduce_legion_dim); + return key; +} + +} // namespace std diff --git a/src/parallel_ops/combine.cc b/src/parallel_ops/combine.cc index a4169ea306..198f450636 100644 --- a/src/parallel_ops/combine.cc +++ b/src/parallel_ops/combine.cc @@ -88,7 +88,7 @@ Combine::Combine(FFModel &model, dims[combine_dim].degree /= combine_degree; ParallelTensorBase::update_parallel_ids(numdim, dims); outputs[0] = model.create_parallel_tensor_legion_ordering( - numdim, dims, DT_FLOAT, this); + numdim, dims, _input->data_type, this); // inputs[0]->print("Combine::input"); // outputs[0]->print("Combine::output"); } @@ -97,11 +97,13 @@ OpMeta *Combine::init_task(Task const *task, std::vector const ®ions, Context ctx, Runtime *runtime) { - Combine *rep = (Combine *)task->args; - // FFHandler handle = *((FFHandler *)task->local_args); - // CombineMeta* m = new CombineMeta(handle); - // m->data_type = rep->outputs[0]->data_type; - return nullptr; + Combine *cmb = (Combine *)task->args; + FFHandler handle = *((FFHandler *)task->local_args); + CombineMeta *m = new CombineMeta(handle); + m->input_type[0] = cmb->inputs[0]->data_type; + m->output_type[0] = cmb->outputs[0]->data_type; + assert(m->input_type[0] == m->output_type[0]); + return m; } void Combine::init(FFModel const &ff) { @@ -111,6 +113,7 @@ void Combine::init(FFModel const &ff) { Runtime *runtime = ff.config.lg_hlr; assert(numOutputs == 1); assert(numInputs == 1); + set_argumentmap_for_init(ff, argmap); IndexLauncher launcher(COMBINE_INIT_TASK_ID, parallel_is, TaskArgument(this, sizeof(Combine)), @@ -130,6 +133,48 @@ void Combine::init(FFModel const &ff) { launcher.add_field(1, FID_DATA); FutureMap fm = runtime->execute_index_space(ctx, launcher); fm.wait_all_results(); + set_opmeta_from_futuremap(ff, fm); +} + +void Combine::init_inference(FFModel const &ff, + std::vector const &batch_inputs, + std::vector const &batch_outputs, + MachineView const *mv) { + ArgumentMap argmap; + parallel_is = batch_outputs[0]->parallel_is; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + assert(numOutputs == 1); + assert(numInputs == 1); + size_t machine_view_hash = + mv ? mv->hash() : batch_outputs[0]->machine_view.hash(); + set_argumentmap_for_init_inference(ff, argmap, batch_outputs[0]); + IndexLauncher launcher(COMBINE_INIT_TASK_ID, + parallel_is, + TaskArgument(this, sizeof(Combine)), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + assert(inference_input_lps.find(batch_inputs[0]) != + inference_input_lps.end()); + launcher.add_region_requirement( + RegionRequirement(inference_input_lps[batch_inputs[0]], + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + batch_inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[0]->region)); + launcher.add_field(1, FID_DATA); + FutureMap fm = runtime->execute_index_space(ctx, launcher); + fm.wait_all_results(); + set_opmeta_from_futuremap_inference(ff, fm, batch_outputs[0]); } void Combine::create_input_partition(FFModel &ff) { @@ -147,6 +192,61 @@ void Combine::create_input_partition(FFModel &ff) { output_grad_lp); } +void Combine::create_input_partition_inference( + FFModel &ff, + std::vector const &batch_inputs, + std::vector const &batch_outputs) { + assert(ff.config.computationMode == COMP_MODE_INFERENCE); + assert(batch_outputs[0]->part != LogicalPartition::NO_PART); + assert(batch_inputs[0]->part != LogicalPartition::NO_PART); + // input_lp is a disjoint partition + ff.create_disjoint_partition(batch_outputs[0]->num_dims, + batch_outputs[0]->dims, + batch_outputs[0]->parallel_is, + batch_inputs[0]->region, + inference_input_lps[batch_inputs[0]]); +} + +FutureMap Combine::inference(FFModel const &ff, + BatchConfig const &bc, + std::vector const &batch_inputs, + std::vector const &batch_outputs, + MachineView const *mv) { + ArgumentMap argmap; + Context ctx = ff.config.lg_ctx; + Runtime *runtime = ff.config.lg_hlr; + parallel_is = batch_outputs[0]->parallel_is; + assert(numOutputs == 1); + assert(numInputs == 1); + assert(batch_inputs[0]->data_type == batch_outputs[0]->data_type); + DataType data_type = batch_inputs[0]->data_type; + size_t machine_view_hash = + mv ? mv->hash() : batch_outputs[0]->machine_view.hash(); + set_argumentmap_for_inference(ff, argmap, batch_outputs[0]); + IndexLauncher launcher(COMBINE_FWD_TASK_ID, + batch_outputs[0]->parallel_is, + TaskArgument(nullptr, 0), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + machine_view_hash); + launcher.add_region_requirement( + RegionRequirement(inference_input_lps[batch_inputs[0]], + 0 /*projection id*/, + READ_ONLY, + EXCLUSIVE, + batch_inputs[0]->region)); + launcher.add_field(0, FID_DATA); + launcher.add_region_requirement(RegionRequirement(batch_outputs[0]->part, + 0 /*projection id*/, + WRITE_ONLY, + EXCLUSIVE, + batch_outputs[0]->region)); + launcher.add_field(1, FID_DATA); + return runtime->execute_index_space(ctx, launcher); +} + void Combine::forward(FFModel const &ff) { ArgumentMap argmap; Context ctx = ff.config.lg_ctx; @@ -157,7 +257,7 @@ void Combine::forward(FFModel const &ff) { DataType data_type = inputs[0]->data_type; IndexLauncher launcher(COMBINE_FWD_TASK_ID, outputs[0]->parallel_is, - TaskArgument(&data_type, sizeof(data_type)), + TaskArgument(nullptr, 0), argmap, Predicate::TRUE_PRED, false /*must*/, @@ -261,8 +361,11 @@ void Combine::forward_task(Task const *task, Runtime *runtime) { assert(regions.size() == 2); assert(task->regions.size() == 2); - DataType data_type = *((DataType *)task->args); - if (data_type == DT_FLOAT) { + CombineMeta const *m = *((CombineMeta **)task->local_args); + DataType data_type = m->input_type[0]; + if (data_type == DT_HALF) { + forward_task_with_type(task, regions, ctx, runtime); + } else if (data_type == DT_FLOAT) { forward_task_with_type(task, regions, ctx, runtime); } else if (data_type == DT_DOUBLE) { forward_task_with_type(task, regions, ctx, runtime); diff --git a/src/parallel_ops/kernels/allreduce_kernels.cpp b/src/parallel_ops/kernels/allreduce_kernels.cpp new file mode 100644 index 0000000000..78742568c6 --- /dev/null +++ b/src/parallel_ops/kernels/allreduce_kernels.cpp @@ -0,0 +1,46 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "flexflow/parallel_ops/kernels/allreduce_kernels.h" +#include "flexflow/utils/hip_helper.h" +#include + +namespace FlexFlow { + +AllReduceMeta::AllReduceMeta(FFHandler handle, AllReduce const *reduct) + : OpMeta(handle) {} + +namespace Kernels { +namespace AllReduce { + +void forward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output) { + hipStream_t stream; + checkCUDA(get_legion_stream(&stream)); + assert(input.data_type == output.data_type); + assert(input.domain == output.domain); + assert(false && "To be implemented"); +} + +void backward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output_grad) { + assert(false && "To be implemented"); +} + +} // namespace AllReduce +} // namespace Kernels +} // namespace FlexFlow diff --git a/src/parallel_ops/kernels/allreduce_kernels.cu b/src/parallel_ops/kernels/allreduce_kernels.cu new file mode 100644 index 0000000000..1ae9ee27b8 --- /dev/null +++ b/src/parallel_ops/kernels/allreduce_kernels.cu @@ -0,0 +1,56 @@ +/* Copyright 2023 CMU, Facebook, LANL, MIT, NVIDIA, and Stanford (alphabetical) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "flexflow/parallel_ops/kernels/allreduce_kernels.h" +#include "flexflow/utils/cuda_helper.h" + +namespace FlexFlow { + +AllReduceMeta::AllReduceMeta(FFHandler handle, AllReduce const *reduct) + : OpMeta(handle) {} + +namespace Kernels { +namespace AllReduce { + +void forward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorR const &input, + GenericTensorAccessorW const &output) { + cudaStream_t stream; + checkCUDA(get_legion_stream(&stream)); + assert(input.data_type == output.data_type); + assert(input.domain == output.domain); +#ifdef FF_USE_NCCL + ncclDataType_t nccl_data_type = ff_to_nccl_datatype(input.data_type); + checkNCCL(ncclAllReduce(input.ptr, + output.ptr, + input.domain.get_volume(), + nccl_data_type, + ncclSum, + m->handle.ncclComm, + stream)); +#else + assert(false && "Must enable FF_USE_NCCL to use AllReduce operators"); +#endif +} + +void backward_kernel_wrapper(AllReduceMeta const *m, + GenericTensorAccessorW const &input_grad, + GenericTensorAccessorR const &output_grad) { + assert(false && "To be implemented"); +} + +} // namespace AllReduce +} // namespace Kernels +} // namespace FlexFlow diff --git a/src/parallel_ops/kernels/combine_kernels.cpp b/src/parallel_ops/kernels/combine_kernels.cpp index 2d748cfab3..d6e9568223 100644 --- a/src/parallel_ops/kernels/combine_kernels.cpp +++ b/src/parallel_ops/kernels/combine_kernels.cpp @@ -51,6 +51,9 @@ void backward_kernel(T const *output_grad_ptr, num_elements); } +template void forward_kernel(half const *input_ptr, + half *output_ptr, + size_t num_elements); template void forward_kernel(float const *input_ptr, float *output_ptr, size_t num_elements); @@ -63,6 +66,9 @@ template void forward_kernel(int32_t const *input_ptr, template void forward_kernel(int64_t const *input_ptr, int64_t *output_ptr, size_t num_elements); +template void backward_kernel(half const *output_grad_ptr, + half *input_grad_ptr, + size_t num_elements); template void backward_kernel(float const *output_grad_ptr, float *input_grad_ptr, size_t num_elements); diff --git a/src/parallel_ops/kernels/combine_kernels.cu b/src/parallel_ops/kernels/combine_kernels.cu index d8f414ef0f..1ab79a7944 100644 --- a/src/parallel_ops/kernels/combine_kernels.cu +++ b/src/parallel_ops/kernels/combine_kernels.cu @@ -44,6 +44,9 @@ void backward_kernel(T const *output_grad_ptr, input_grad_ptr, output_grad_ptr, num_elements); } +template void forward_kernel(half const *input_ptr, + half *output_ptr, + size_t num_elements); template void forward_kernel(float const *input_ptr, float *output_ptr, size_t num_elements); @@ -56,6 +59,9 @@ template void forward_kernel(int32_t const *input_ptr, template void forward_kernel(int64_t const *input_ptr, int64_t *output_ptr, size_t num_elements); +template void backward_kernel(half const *output_grad_ptr, + half *input_grad_ptr, + size_t num_elements); template void backward_kernel(float const *output_grad_ptr, float *input_grad_ptr, size_t num_elements); diff --git a/src/runtime/cuda_helper.cu b/src/runtime/cuda_helper.cu index 6ef06e1f65..1aa216e5c9 100644 --- a/src/runtime/cuda_helper.cu +++ b/src/runtime/cuda_helper.cu @@ -461,6 +461,24 @@ cudaDataType_t ff_to_cuda_datatype(DataType type) { return CUDA_R_32F; } +#ifdef FF_USE_NCCL +ncclDataType_t ff_to_nccl_datatype(DataType type) { + switch (type) { + case DT_HALF: + return ncclHalf; + case DT_FLOAT: + return ncclFloat; + case DT_DOUBLE: + return ncclDouble; + case DT_INT32: + return ncclInt; + default: + assert(false && "Unspoorted nccl data type"); + } + return ncclFloat; +} +#endif + cudaDataType_t cudnn_to_cuda_datatype(cudnnDataType_t type) { switch (type) { case CUDNN_DATA_FLOAT: @@ -500,6 +518,8 @@ template __global__ void template __global__ void assign_kernel(int64_t *ptr, coord_t size, int64_t value); +template __global__ void + add_kernel(half *dst, half const *src, size_t size); template __global__ void add_kernel(float *dst, float const *src, size_t size); template __global__ void @@ -509,8 +529,12 @@ template __global__ void template __global__ void add_kernel(int64_t *dst, int64_t const *src, size_t size); +template __global__ void + copy_kernel(half *dst, half const *src, coord_t size); template __global__ void copy_kernel(float *dst, float const *src, coord_t size); +template __global__ void + copy_kernel(double *dst, double const *src, coord_t size); template __global__ void copy_kernel(int32_t *dst, int32_t const *src, coord_t size); template __global__ void diff --git a/src/runtime/ffconst_utils.cc b/src/runtime/ffconst_utils.cc index d2b68595bd..39f9d1dd0d 100644 --- a/src/runtime/ffconst_utils.cc +++ b/src/runtime/ffconst_utils.cc @@ -187,6 +187,8 @@ std::string get_operator_type_name(OperatorType type) { return "Replicate"; case OP_REDUCTION: return "Reduction"; + case OP_ALLREDUCE: + return "AllReduce"; case OP_PIPELINE: return "Pipeline"; case OP_FUSED_PARALLEL: diff --git a/src/runtime/fftype.cc b/src/runtime/fftype.cc index 91e0d077c4..2b94f07999 100644 --- a/src/runtime/fftype.cc +++ b/src/runtime/fftype.cc @@ -1,11 +1,15 @@ #include "flexflow/fftype.h" +#include "flexflow/config.h" #include namespace FlexFlow { -LayerID::LayerID() : id(0) {} +const LayerID LayerID::NO_ID = LayerID(); -LayerID::LayerID(size_t _id) : id(_id) { +LayerID::LayerID() : id(0), transformer_layer_id(MAX_NUM_TRANSFORMER_LAYERS) {} + +LayerID::LayerID(size_t _id, size_t _transformer_layer_id) + : id(_id), transformer_layer_id(_transformer_layer_id) { assert(is_valid_id()); } @@ -14,7 +18,11 @@ bool LayerID::is_valid_id() const { } bool operator==(LayerID const &lhs, LayerID const &rhs) { + // id should be sufficient to distinguish different layers + if (lhs.id == rhs.id) { + assert(lhs.transformer_layer_id == rhs.transformer_layer_id); + } return lhs.id == rhs.id; } -}; // namespace FlexFlow \ No newline at end of file +}; // namespace FlexFlow diff --git a/src/runtime/graph.cc b/src/runtime/graph.cc index e8a1b6f9f1..5c0513baa8 100644 --- a/src/runtime/graph.cc +++ b/src/runtime/graph.cc @@ -46,6 +46,7 @@ #include "flexflow/ops/topk.h" #include "flexflow/ops/transpose.h" #include "flexflow/ops/tree_inc_multihead_self_attention.h" +#include "flexflow/parallel_ops/allreduce.h" #include "flexflow/parallel_ops/combine.h" #include "flexflow/parallel_ops/fused_parallel_op.h" #include "flexflow/parallel_ops/partition.h" @@ -1961,14 +1962,61 @@ std::pair, std::unordered_map> } curr_best_graph = std::unique_ptr(graph); MachineView data_parallel_view; - data_parallel_view.device_type = MachineView::GPU; - data_parallel_view.ndims = 1; - data_parallel_view.dim[0] = - model->config.numNodes * model->config.workersPerNode; - data_parallel_view.stride[0] = 1; - data_parallel_view.start_device_id = 0; + int degree, num_transformer_layers_per_stage; + if (model->config.computationMode == COMP_MODE_TRAINING) { + data_parallel_view.device_type = MachineView::GPU; + data_parallel_view.ndims = 1; + data_parallel_view.dim[0] = + model->config.numNodes * model->config.workersPerNode; + data_parallel_view.stride[0] = 1; + data_parallel_view.start_device_id = 0; + } else { + // Currently assume a 1D machine view is needed + assert(model->config.data_parallelism_degree == 1 || + model->config.tensor_parallelism_degree == 1); + degree = model->config.data_parallelism_degree * + model->config.tensor_parallelism_degree; + num_transformer_layers_per_stage = + model->current_transformer_layer_id / + model->config.pipeline_parallelism_degree + + 1; + } for (auto const &node : curr_best_graph->inEdges) { - curr_optimal_views[node.first] = data_parallel_view; + Op const *op = node.first.ptr; + if (model->config.computationMode == COMP_MODE_TRAINING) { + curr_optimal_views[node.first] = data_parallel_view; + } else { + MachineView mv; + mv.device_type = MachineView::GPU; + mv.ndims = 1; + int total_parallel_degree = 1; + for (int i = 0; i < op->outputs[0]->num_dims; i++) { + total_parallel_degree *= op->outputs[0]->dims[i].degree; + } + mv.dim[0] = total_parallel_degree; + mv.stride[0] = 1; + LayerID layer_guid = op->layer_guid; + if (op->op_type == OP_INPUT) { + // All inputs are assigned to the first stage + layer_guid.transformer_layer_id = 0; + } else if (layer_guid == LayerID::NO_ID) { + // Assert that we only have a single input + while (op->layer_guid == LayerID::NO_ID) { + assert(op->numInputs == 1); + op = op->inputs[0]->owner_op; + assert(op != nullptr); + } + layer_guid = op->layer_guid; + } + mv.start_device_id = degree * (layer_guid.transformer_layer_id / + num_transformer_layers_per_stage); + assert(mv.start_device_id + degree - 1 < + model->config.numNodes * model->config.workersPerNode); + curr_optimal_views[node.first] = mv; + for (int i = 0; i < node.first.ptr->numOutputs; i++) { + assert(node.first.ptr->outputs[i]->is_valid_machine_view(mv)); + } + } } } else { // Main step to optimize the PCG of an FFModel @@ -2237,23 +2285,17 @@ GraphOptimalViewSerialized case OP_EMBEDDING: { Embedding *embed = (Embedding *)op; sez.serialize(embed->layer_guid.id); + sez.serialize(embed->layer_guid.transformer_layer_id); sez.serialize(embed->num_entries); sez.serialize(embed->out_channels); sez.serialize(embed->aggr); sez.serialize(embed->data_type); break; } - case OP_EW_ADD: - case OP_EW_SUB: - case OP_EW_MUL: - case OP_EW_MAX: - case OP_EW_MIN: { - sez.serialize(op->op_type); - break; - } case OP_MULTIHEAD_ATTENTION: { MultiHeadAttention *attn = (MultiHeadAttention *)op; sez.serialize(attn->layer_guid.id); + sez.serialize(attn->layer_guid.transformer_layer_id); sez.serialize(attn->oProjSize); sez.serialize(attn->num_heads); sez.serialize(attn->qProjSize); @@ -2267,6 +2309,7 @@ GraphOptimalViewSerialized case OP_INC_MULTIHEAD_SELF_ATTENTION: { IncMultiHeadSelfAttention *attn = (IncMultiHeadSelfAttention *)op; sez.serialize(attn->layer_guid.id); + sez.serialize(attn->layer_guid.transformer_layer_id); sez.serialize(attn->oProjSize); sez.serialize(attn->num_heads); sez.serialize(attn->qProjSize); @@ -2287,6 +2330,7 @@ GraphOptimalViewSerialized SpecIncMultiHeadSelfAttention *attn = (SpecIncMultiHeadSelfAttention *)op; sez.serialize(attn->layer_guid.id); + sez.serialize(attn->layer_guid.transformer_layer_id); sez.serialize(attn->oProjSize); sez.serialize(attn->num_heads); sez.serialize(attn->qProjSize); @@ -2305,6 +2349,7 @@ GraphOptimalViewSerialized TreeIncMultiHeadSelfAttention *attn = (TreeIncMultiHeadSelfAttention *)op; sez.serialize(attn->layer_guid.id); + sez.serialize(attn->layer_guid.transformer_layer_id); sez.serialize(attn->oProjSize); sez.serialize(attn->num_heads); sez.serialize(attn->qProjSize); @@ -2324,6 +2369,7 @@ GraphOptimalViewSerialized case OP_INC_MULTIQUERY_SELF_ATTENTION: { IncMultiQuerySelfAttention *attn = (IncMultiQuerySelfAttention *)op; sez.serialize(attn->layer_guid.id); + sez.serialize(attn->layer_guid.transformer_layer_id); sez.serialize(attn->oProjSize); sez.serialize(attn->num_heads); sez.serialize(attn->qProjSize); @@ -2363,6 +2409,11 @@ GraphOptimalViewSerialized sez.serialize(combine->combine_degree); break; } + case OP_ALLREDUCE: { + AllReduce *allreduce = (AllReduce *)op; + sez.serialize(allreduce->allreduce_dim); + break; + } case OP_FUSED_PARALLEL: { FusedParallelOp *fused = (FusedParallelOp *)op; sez.serialize(fused->num_parallel_ops); @@ -2589,10 +2640,11 @@ void FFModel::deserialize_graph_optimal_view( assert(num_inputs == 1); AggrMode aggr; int num_entries, out_channels; - size_t id; + size_t id, transformer_layer_id; DataType data_type; dez.deserialize(id); - LayerID layer_guid(id); + dez.deserialize(transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id); dez.deserialize(num_entries); dez.deserialize(out_channels); dez.deserialize(aggr); @@ -2612,11 +2664,7 @@ void FFModel::deserialize_graph_optimal_view( case OP_EW_MUL: case OP_EW_MAX: case OP_EW_MIN: { - assert(num_inputs == 2); - OperatorType op_type; - dez.deserialize(op_type); - node = get_or_create_node({inputs[0], inputs[1]}, - {op_type}); + node = ElementBinary::deserialize(*this, dez, inputs, num_inputs); break; } case OP_CONV2D: { @@ -2667,9 +2715,10 @@ void FFModel::deserialize_graph_optimal_view( int embed_dim, num_heads, k_dim, v_dim; float dropout; bool bias, add_bias_kv, add_zero_attn; - size_t id; + size_t id, transformer_layer_id; dez.deserialize(id); - LayerID layer_guid(id); + dez.deserialize(transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id); dez.deserialize(embed_dim); dez.deserialize(num_heads); dez.deserialize(k_dim); @@ -2700,9 +2749,10 @@ void FFModel::deserialize_graph_optimal_view( bool bias, add_bias_kv, add_zero_attn, apply_rotary_embedding, scaling_query, qk_prod_scaling, offload; DataType quantization_type; - size_t id; + size_t id, transformer_layer_id; dez.deserialize(id); - LayerID layer_guid(id); + dez.deserialize(transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id); dez.deserialize(embed_dim); dez.deserialize(num_heads); dez.deserialize(k_dim); @@ -2743,9 +2793,10 @@ void FFModel::deserialize_graph_optimal_view( float dropout, scaling_factor; bool bias, add_bias_kv, add_zero_attn, apply_rotary_embedding, scaling_query, qk_prod_scaling; - size_t id; + size_t id, transformer_layer_id; dez.deserialize(id); - LayerID layer_guid(id); + dez.deserialize(transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id); dez.deserialize(embed_dim); dez.deserialize(num_heads); dez.deserialize(k_dim); @@ -2784,9 +2835,10 @@ void FFModel::deserialize_graph_optimal_view( bool bias, add_bias_kv, add_zero_attn, apply_rotary_embedding, scaling_query, qk_prod_scaling, offload; DataType quantization_type; - size_t id; + size_t id, transformer_layer_id; dez.deserialize(id); - LayerID layer_guid(id); + dez.deserialize(transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id); dez.deserialize(embed_dim); dez.deserialize(num_heads); dez.deserialize(k_dim); @@ -2828,9 +2880,10 @@ void FFModel::deserialize_graph_optimal_view( float dropout, scaling_factor; bool bias, add_bias_kv, add_zero_attn, apply_rotary_embedding, scaling_query, qk_prod_scaling; - size_t id; + size_t id, transformer_layer_id; dez.deserialize(id); - LayerID layer_guid(id); + dez.deserialize(transformer_layer_id); + LayerID layer_guid(id, transformer_layer_id); dez.deserialize(embed_dim); dez.deserialize(num_heads); dez.deserialize(k_dim); @@ -2949,6 +3002,13 @@ void FFModel::deserialize_graph_optimal_view( {reduction_dim, reduction_degree}); break; } + case OP_ALLREDUCE: { + assert(num_inputs == 1); + int allreduce_dim; + dez.deserialize(allreduce_dim); + node = get_or_create_node(inputs[0], {allreduce_dim}); + break; + } case OP_FUSED_PARALLEL: { assert(num_inputs == 1); std::vector parallel_ops; diff --git a/src/runtime/hip_helper.cpp b/src/runtime/hip_helper.cpp index 6354c5d737..9bcccb041a 100644 --- a/src/runtime/hip_helper.cpp +++ b/src/runtime/hip_helper.cpp @@ -372,16 +372,23 @@ template __global__ void template __global__ void assign_kernel(int64_t *ptr, coord_t size, int64_t value); +template __global__ void + add_kernel(half *dst, half const *src, size_t size); template __global__ void add_kernel(float *dst, float const *src, size_t size); template __global__ void add_kernel(double *dst, double const *src, size_t size); -template __global__ void add_kernel(int *dst, int const *src, size_t size); template __global__ void - add_kernel(long *dst, long const *src, size_t size); + add_kernel(int32_t *dst, int32_t const *src, size_t size); +template __global__ void + add_kernel(int64_t *dst, int64_t const *src, size_t size); +template __global__ void + copy_kernel(half *dst, half const *src, coord_t size); template __global__ void copy_kernel(float *dst, float const *src, coord_t size); +template __global__ void + copy_kernel(double *dst, double const *src, coord_t size); template __global__ void copy_kernel(int32_t *dst, int32_t const *src, coord_t size); template __global__ void @@ -406,13 +413,19 @@ template __global__ void apply_add_with_scale(int64_t *data_ptr, template __host__ void print_tensor(float const *ptr, size_t rect, char const *prefix); +template __host__ void + print_tensor(double const *ptr, size_t rect, char const *prefix); template __host__ void print_tensor(int32_t const *ptr, size_t rect, char const *prefix); template __host__ void print_tensor(int64_t const *ptr, size_t rect, char const *prefix); +template __host__ void + print_tensor(half const *ptr, size_t rect, char const *prefix); template __host__ float *download_tensor(float const *ptr, size_t num_elements); +template __host__ half *download_tensor(half const *ptr, + size_t num_elements); template __host__ double *download_tensor(double const *ptr, size_t num_elements); template __host__ int32_t *download_tensor(int32_t const *ptr, diff --git a/src/runtime/inference_manager.cc b/src/runtime/inference_manager.cc index 67a78f9700..b6be945a94 100644 --- a/src/runtime/inference_manager.cc +++ b/src/runtime/inference_manager.cc @@ -54,6 +54,7 @@ InferenceManager::InferenceManager(FFConfig const &_config, num_devices && "Product of data, tensor, and pipeline parallelism degrees does not " "match the number of available devices"); + // Deprecated logic below // populate array of valid single-device machine views for (int i = 0; i < num_devices; i++) { MachineView view; @@ -94,23 +95,23 @@ bool parallel_tensor_list_overlaps(std::vector const &list1, return false; } -void InferenceManager::compile_model_and_allocate_buffer( - FFModel *model, - std::unordered_map> const - &tensor_mapping) { +void InferenceManager::compile_model_and_allocate_buffer(FFModel *model) { + // TODO: currently assume there is a single data-parallel pipeline + // (i.e., data-parallel-degree == 1) + assert(model->config.data_parallelism_degree == 1); model->config.batchSize = max_num_tokens_per_batch; model->compile_inference(); Context ctx = model->config.lg_ctx; Runtime *runtime = model->config.lg_hlr; - std::unordered_map> mapping; - for (auto const &it : tensor_mapping) { - ParallelTensor pt; - model->get_parallel_tensor_from_tensor(it.first, pt); - assert(pt->owner_op != nullptr); - mapping[pt->owner_op] = it.second; - } // std::cout << std::endl << std::endl << "Operators MVs:" << std::endl; + int num_transformer_layers_per_stage = + model->current_transformer_layer_id / + model->config.pipeline_parallelism_degree + + 1; + int degree = model->config.data_parallelism_degree * + model->config.tensor_parallelism_degree; + for (int op_idx = 0; op_idx < model->operators.size(); op_idx++) { Op const *op = model->operators[op_idx]; // Skip weight operators @@ -119,52 +120,35 @@ void InferenceManager::compile_model_and_allocate_buffer( } // Get machine views std::vector machine_views; - if (mapping.find(op) != mapping.end()) { - machine_views = mapping[op]; - assert(machine_views.size() == ff_config.data_parallelism_degree); - } else { - // Mapping the current operator using the same machine - // view as the inputs - assert(op->numInputs > 0); - for (int j = 0; j < ff_config.data_parallelism_degree; j++) { - MachineView mv = tensor_buffer[op->inputs[0]][j]->machine_view; - for (int k = 1; k < op->numInputs; k++) { - if (mv != tensor_buffer[op->inputs[k]][j]->machine_view) { - fprintf(stderr, - "[Warning] a potentially unnecessary " - " inter-GPU copy of size %zu\n", - op->inputs[k]->get_volume()); - // Heuristics: we use the mv with a larger start_device_id - // to promote load balancing - if (mv.start_device_id < - tensor_buffer[op->inputs[k]][j]->machine_view.start_device_id) { - mv = tensor_buffer[op->inputs[k]][j]->machine_view; - } - } - } - if (op->op_type == OP_REPLICATE) { - // std::cout << "Replicate operator got machine view: " << mv - // << std::endl; - assert(model->config.tensor_parallelism_degree > 1); - mv.dim[0] = ff_config.tensor_parallelism_degree; - mv.stride[0] = 1; - if (mv.start_device_id + mv.dim[0] > num_devices) { - mv.start_device_id -= - (mv.start_device_id + mv.dim[0]) - num_devices; - } - // std::cout << "Corrected machine view: " << mv << std::endl; - } else if (op->op_type == OP_REDUCTION) { - // std::cout << "Reduction operator got machine view: " << mv - // << std::endl; - assert(model->config.tensor_parallelism_degree > 1); - mv.dim[0] = 1; - mv.stride[0] = 0; - // std::cout << "Corrected machine view: " << mv << std::endl; + for (int j = 0; j < model->config.data_parallelism_degree; j++) { + MachineView mv; + mv.device_type == MachineView::GPU; + mv.ndims = 1; + // mv.start_device_id = 0; + mv.stride[0] = 1; + int parallel_degree = 1; + for (int k = 0; k < op->outputs[0]->num_dims; k++) { + parallel_degree *= op->outputs[0]->dims[k].degree; + } + mv.dim[0] = parallel_degree; + LayerID layer_guid = op->layer_guid; + if (op->op_type == OP_INPUT) { + // All inputs are assigned to the first stage + layer_guid.transformer_layer_id = 0; + } else if (layer_guid == LayerID::NO_ID) { + Op const *op_with_guid = op; + // Assert that we only have a single input + while (op_with_guid->layer_guid == LayerID::NO_ID) { + assert(op_with_guid->numInputs == 1); + op_with_guid = op_with_guid->inputs[0]->owner_op; + assert(op_with_guid != nullptr); } - assert(mv.start_device_id + mv.dim[0] <= num_devices); - machine_views.push_back(mv); + layer_guid = op_with_guid->layer_guid; } - assert(machine_views.size() == ff_config.data_parallelism_degree); + mv.start_device_id = degree * (layer_guid.transformer_layer_id / + num_transformer_layers_per_stage); + assert(mv == op->outputs[0]->machine_view); + machine_views.push_back(mv); } // std::cout << "operator: " << op->name << std::endl; // for (int i = 0; i < op->numInputs; i++) { @@ -232,7 +216,7 @@ void InferenceManager::compile_model_and_allocate_buffer( } } if (!found_parallel_tensor) { - for (int j = 0; j < ff_config.data_parallelism_degree; j++) { + for (int j = 0; j < model->config.data_parallelism_degree; j++) { // Copy the metadata from pt_base to pt ParallelTensor pt = new ParallelTensorBase(*pt_base); pt->region = @@ -257,7 +241,7 @@ void InferenceManager::compile_model_and_allocate_buffer( } void InferenceManager::init_operators_inference(FFModel *model) { - for (int batch_index = 0; batch_index < ff_config.data_parallelism_degree; + for (int batch_index = 0; batch_index < model->config.data_parallelism_degree; batch_index++) { int expert_device_index = 0; int device_index = batch_index % num_devices; @@ -313,7 +297,7 @@ FutureMap InferenceManager::inference(FFModel *model, assert(bc.num_active_tokens() > 0 && bc.num_active_requests() > 0); // We currently assume that the index-th batch will be placed // on the device_index-th device (except for the experts layers) - int batch_index = index % ff_config.data_parallelism_degree; + int batch_index = index % model->config.data_parallelism_degree; FutureMap fm; bool found_input_operator = false; for (size_t o = 0; o < model->operators.size(); o++) { @@ -410,15 +394,19 @@ void InferenceManager::load_positions(BatchConfig const &bc, runtime->execute_index_space(ctx, launcher); } +void FFModel::set_transformer_layer_id(int id) { + // We assume that users call this function with + // monotonically increasing ids + assert(id == current_transformer_layer_id + 1 || + (id == 0 && current_transformer_layer_id == 0)); + current_transformer_layer_id = id; + assert(id < MAX_NUM_TRANSFORMER_LAYERS); +} + void FFModel::compile_inference() { Context ctx = config.lg_ctx; Runtime *runtime = config.lg_hlr; config.computationMode = COMP_MODE_INFERENCE; - { - fprintf( - stderr, - "Note: inference currently only supports data/pipeline parallel.\n"); - } create_operators_from_layers(); // Launch the graph optimize task { @@ -651,5 +639,42 @@ void FFModel::compile_inference() { handle.get_tree_id()); } } +#ifdef FF_USE_NCCL + for (size_t l = 0; l < operators.size(); l++) { + // Only create nccl for allreduce and fusedop for inference + // (fusedop may include allreduces) + if (operators[l]->op_type == OP_ALLREDUCE || + operators[l]->op_type == OP_FUSED) { + MachineView view = operators[l]->outputs[0]->machine_view; + if (view_hash_to_nccl_comms.find(view.hash()) == + view_hash_to_nccl_comms.end()) { + TaskLauncher launcher(NCCL_GETUNIQUEID_TASK_ID, TaskArgument(NULL, 0)); + Future future = runtime->execute_task(ctx, launcher); + ncclUniqueId ncclId = future.get_result(); + IndexSpace task_is = get_or_create_task_is(view); + ArgumentMap argmap; + IndexLauncher index_launcher( + NCCL_INIT_COMMS_TASK_ID, + task_is, + TaskArgument(&ncclId, sizeof(ncclUniqueId)), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + view.hash() /*MappingTagID*/); + FutureMap fm = runtime->execute_index_space(ctx, index_launcher); + fm.wait_all_results(); + int idx = 0; + Domain task_domain = runtime->get_index_space_domain(ctx, task_is); + ncclComm_t *nccl_comms = + (ncclComm_t *)malloc(sizeof(ncclComm_t) * task_domain.get_volume()); + for (Domain::DomainPointIterator it(task_domain); it; it++, idx++) { + nccl_comms[idx] = fm.get_result(*it); + } + view_hash_to_nccl_comms[view.hash()] = nccl_comms; + } + } + } +#endif } }; // namespace FlexFlow diff --git a/src/runtime/layer.cc b/src/runtime/layer.cc index 6dfd5f2f35..d2473f4b2b 100644 --- a/src/runtime/layer.cc +++ b/src/runtime/layer.cc @@ -16,8 +16,9 @@ Layer::Layer(FFModel *model, const Tensor _input3, const Tensor _input4) : op_type(_otype), data_type(_dtype), - layer_guid(model->layer_global_guid++), numInputs(_numInputs), - numWeights(_numWeights), numOutputs(_numOutputs) { + layer_guid(model->layer_global_guid++, + model->current_transformer_layer_id), + numInputs(_numInputs), numWeights(_numWeights), numOutputs(_numOutputs) { std::string pcname; if (_name == nullptr) { pcname = get_operator_type_name(op_type); @@ -50,8 +51,9 @@ Layer::Layer(FFModel *model, int _numOutputs, Tensor const *_tensors) : op_type(_otype), data_type(_dtype), - layer_guid(model->layer_global_guid++), numInputs(_numInputs), - numWeights(_numWeights), numOutputs(_numOutputs) { + layer_guid(model->layer_global_guid++, + model->current_transformer_layer_id), + numInputs(_numInputs), numWeights(_numWeights), numOutputs(_numOutputs) { std::string pcname; if (_name == nullptr) { pcname = get_operator_type_name(op_type); diff --git a/src/runtime/model.cc b/src/runtime/model.cc index 64c3a2eb61..763a5bcfd5 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -58,6 +58,7 @@ #include "flexflow/ops/topk.h" #include "flexflow/ops/transpose.h" #include "flexflow/ops/tree_inc_multihead_self_attention.h" +#include "flexflow/parallel_ops/allreduce.h" #include "flexflow/parallel_ops/combine.h" #include "flexflow/parallel_ops/fused_parallel_op.h" #include "flexflow/parallel_ops/partition.h" @@ -990,6 +991,7 @@ void Op::set_argumentmap_for_init_inference(FFModel const &ff, Runtime *runtime = ff.config.lg_hlr; Domain domain = runtime->get_index_space_domain(ctx, this->parallel_is); MachineView const view = output0->machine_view; + assert(ff.config.computationMode == COMP_MODE_INFERENCE); switch (domain.get_dim()) { #ifdef FF_USE_NCCL #define DIMFUNC(DIM) \ @@ -998,8 +1000,7 @@ void Op::set_argumentmap_for_init_inference(FFModel const &ff, int idx = 0; \ for (PointInRectIterator it(rect); it(); it++) { \ FFHandler handle = ff.handlers[view.get_device_id(*it)]; \ - if (ff.config.computationMode == COMP_MODE_TRAINING && \ - op_type == OP_WEIGHT) { \ + if (op_type == OP_ALLREDUCE) { \ ncclComm_t *nccl_comms = ff.find_nccl_comms(view); \ handle.ncclComm = nccl_comms[idx++]; \ } \ @@ -1302,8 +1303,9 @@ FFModel::FFModel(FFConfig &_config, bool cpu_offload) layer_global_guid(LAYER_GUID_FIRST_VALID), tensor_global_guid(TENSOR_GUID_FIRST_VALID), parallel_tensor_global_guid(PARALLEL_TENSOR_GUID_FIRST_VALID), - node_global_guid(NODE_GUID_FIRST_VALID), config(_config), optimizer(NULL), - loss_op(NULL), metrics_op(NULL), simulator(NULL) { + node_global_guid(NODE_GUID_FIRST_VALID), current_transformer_layer_id(0), + config(_config), optimizer(NULL), loss_op(NULL), metrics_op(NULL), + simulator(NULL) { this->search = new PCG::SearchHelper(this); this->graph_search = new PCG::GraphSearchHelper(this); this->cpu_offload = cpu_offload; @@ -1348,7 +1350,7 @@ ncclComm_t *FFModel::find_nccl_comms(MachineView const &view) const { auto const &it = view_hash_to_nccl_comms.find(view.hash()); if (it == view_hash_to_nccl_comms.end()) { assert(config.computationMode == COMP_MODE_INFERENCE); - return NULL; + return nullptr; } else { return it->second; } @@ -2630,9 +2632,14 @@ bool FFModel::apply_fusion(std::vector const &operators, operators[l]->op_type == OP_WEIGHT) { continue; } - // don't fuse parallel op since they have different parallel_is in - // forward/backward - if (operators[l]->is_parallel_op()) { + // don't fuse parallel op except allReduce since they have different + // parallel_is in forward/backward + if (operators[l]->is_parallel_op() && + operators[l]->op_type != OP_ALLREDUCE) { + continue; + } + // don't fuse softmax since it returns inference results + if (operators[l]->op_type == OP_SOFTMAX) { continue; } size_t start = 0; @@ -2675,9 +2682,10 @@ bool FFModel::apply_fusion(std::vector const &operators, operators[i]->op_type == OP_WEIGHT) { continue; } - // don't fuse parallel op since they have different parallel_is in - // forward/backward - if (operators[i]->is_parallel_op()) { + // don't fuse parallel op except allReduce since they have different + // parallel_is in forward/backward + if (operators[i]->is_parallel_op() && + operators[i]->op_type != OP_ALLREDUCE) { continue; } fused_op = new FusedOp(*this, operators[i]); @@ -2967,7 +2975,51 @@ void FFModel::create_operators_from_layers() { inputs.push_back(tensors_to_parallel_tensors[l->inputs[i]]); } Op *op = nullptr; - // add replicate operators if needed + // add a combine before arg_topk + if (config.computationMode == COMP_MODE_INFERENCE && + config.tensor_parallelism_degree > 1 && l->op_type == OP_ARG_TOPK) { + std::vector partitioned_inputs; + assert(inputs.size() == 1); + Combine *comb = new Combine(*this, + inputs[0], + 0 /*inner most dim*/, + config.tensor_parallelism_degree); + partitioned_inputs.push_back(comb->outputs[0]); + operators.push_back(comb); + op = create_operator_from_layer(l, partitioned_inputs); + } else { + op = create_operator_from_layer(l, inputs); + } + // add replicate operators after op if needed + if (config.computationMode == COMP_MODE_INFERENCE && + config.tensor_parallelism_degree > 1 && l->op_type == OP_EMBEDDING) { + assert(op->numOutputs == 1); + Replicate *repl = new Replicate(*this, + op->outputs[0], + op->outputs[0]->num_dims - 1, + config.tensor_parallelism_degree); + operators.push_back(repl); + op = repl; + } else if (config.computationMode == COMP_MODE_INFERENCE && + config.tensor_parallelism_degree > 1 && + (l->op_type == OP_INC_MULTIHEAD_SELF_ATTENTION || + l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION || + (l->op_type == OP_LINEAR && layer_idx >= 2 && + layers[layer_idx - 1]->op_type == OP_RELU && + layers[layer_idx - 2]->op_type == OP_LINEAR) || + (l->op_type == OP_LINEAR && layer_idx >= 5 && + layers[layer_idx - 1]->op_type == OP_EW_MUL && + layers[layer_idx - 2]->op_type == OP_EW_MUL && + layers[layer_idx - 3]->op_type == OP_SIGMOID && + layers[layer_idx - 4]->op_type == OP_LINEAR && + layers[layer_idx - 5]->op_type == OP_LINEAR))) { + assert(op->numOutputs == 1); + AllReduce *allreduce = + new AllReduce(*this, op->outputs[0], op->outputs[0]->num_dims - 1); + operators.push_back(allreduce); + op = allreduce; + } +#ifdef DEADCODE if (config.computationMode == COMP_MODE_INFERENCE && config.tensor_parallelism_degree > 1 && (l->op_type == OP_INC_MULTIHEAD_SELF_ATTENTION || @@ -3022,7 +3074,7 @@ void FFModel::create_operators_from_layers() { operators.push_back(reduct); op = reduct; } - +#endif assert(op->numOutputs == l->numOutputs); for (int i = 0; i < op->numOutputs; i++) { tensors_to_parallel_tensors[l->outputs[i]] = op->outputs[i]; @@ -3364,13 +3416,10 @@ void FFModel::compile(LossType loss_type, } #ifdef FF_USE_NCCL - if (config.computationMode == COMP_MODE_TRAINING) { - // init all nccl communicators - for (size_t l = 0; l < operators.size(); l++) { - // Only create nccl for weights - if (operators[l]->op_type != OP_WEIGHT) { - continue; - } + for (size_t l = 0; l < operators.size(); l++) { + // Only create nccl for weights in training + if ((operators[l]->op_type == OP_WEIGHT && + config.computationMode == COMP_MODE_TRAINING)) { MachineView view = operators[l]->outputs[0]->machine_view; if (view_hash_to_nccl_comms.find(view.hash()) == view_hash_to_nccl_comms.end()) { @@ -3789,6 +3838,9 @@ FFConfig::FFConfig() { } // Use Real::Machine::get_address_space_count() to obtain the number of nodes numNodes = Realm::Machine::get_machine().get_address_space_count(); + data_parallelism_degree = 1; + tensor_parallelism_degree = 1; + pipeline_parallelism_degree = 1; Runtime *runtime = Runtime::get_runtime(); lg_hlr = runtime; @@ -4426,6 +4478,13 @@ void register_flexflow_internal_tasks() { Runtime::preregister_task_variant( registrar, "Linear Init Task"); } + { + TaskVariantRegistrar registrar(LINEAR_INF_TASK_ID, "Linear Inference"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + Runtime::preregister_task_variant( + registrar, "Linear Inference Task"); + } { TaskVariantRegistrar registrar(LINEAR_FWD_TASK_ID, "Linear Forward"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); @@ -4836,6 +4895,13 @@ void register_flexflow_internal_tasks() { Runtime::preregister_task_variant( registrar, "FusedOp Forward Task"); } + { + TaskVariantRegistrar registrar(FUSEDOP_INF_TASK_ID, "FusedOp Inference"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + Runtime::preregister_task_variant( + registrar, "FusedOp Inference Task"); + } { TaskVariantRegistrar registrar(FUSEDOP_BWD_TASK_ID, "FusedOp Backward"); registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); @@ -4935,6 +5001,28 @@ void register_flexflow_internal_tasks() { Runtime::preregister_task_variant( registrar, "Reduction Backward Task"); } + // AllReduce + { + TaskVariantRegistrar registrar(ALLREDUCE_INIT_TASK_ID, "AllReduce Init"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + Runtime::preregister_task_variant( + registrar, "AllReduce init Task"); + } + { + TaskVariantRegistrar registrar(ALLREDUCE_FWD_TASK_ID, "AllReduce Forward"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + Runtime::preregister_task_variant( + registrar, "AllReduce Forward Task"); + } + { + TaskVariantRegistrar registrar(ALLREDUCE_BWD_TASK_ID, "AllReduce Backward"); + registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC)); + registrar.set_leaf(); + Runtime::preregister_task_variant( + registrar, "AllReduce Backward Task"); + } // FusedParallelOp { TaskVariantRegistrar registrar(FUSED_PARALLELOP_FWD_TASK_ID, diff --git a/src/runtime/operator_params.cc b/src/runtime/operator_params.cc index 8fdeacc623..6b61d5ac7a 100644 --- a/src/runtime/operator_params.cc +++ b/src/runtime/operator_params.cc @@ -34,6 +34,7 @@ #include "flexflow/ops/topk.h" #include "flexflow/ops/transpose.h" #include "flexflow/ops/tree_inc_multihead_self_attention.h" +#include "flexflow/parallel_ops/allreduce.h" #include "flexflow/parallel_ops/combine.h" #include "flexflow/parallel_ops/fused_parallel_op.h" #include "flexflow/parallel_ops/partition.h" @@ -105,6 +106,8 @@ tl::optional get_op_parameters(Op const *op) { return ((Reduction *)op)->get_params(); case OP_COMBINE: return ((Combine *)op)->get_params(); + case OP_ALLREDUCE: + return ((AllReduce *)op)->get_params(); case OP_FUSED_PARALLEL: return ((FusedParallelOp *)op)->get_params(); case OP_TRANSPOSE: diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index b47b17ad12..478092727f 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -243,6 +243,8 @@ BatchConfig RequestManager::prepare_next_batch(BatchConfig const &old_bc, outputFile << "end-to-end latency: " << std::fixed << std::setprecision(3) << total_request_run_time << std::endl; + outputFile << "num decoding steps: " << profile_info.decoding_steps + << std::endl; outputFile << "token IDs: "; for (int i = 0; i < request.tokens.size(); i++) { outputFile << request.tokens[i]; @@ -562,6 +564,8 @@ BeamSearchBatchConfig outputFile << "end-to-end latency: " << std::fixed << std::setprecision(3) << total_request_run_time << std::endl; + outputFile << "num decoding steps: " << profile_info.decoding_steps + << std::endl; outputFile << "token IDs: "; for (int i = 0; i < request.tokens.size(); i++) { outputFile << request.tokens[i]; diff --git a/src/runtime/substitution.cc b/src/runtime/substitution.cc index 58623258f1..6a61e70fc6 100644 --- a/src/runtime/substitution.cc +++ b/src/runtime/substitution.cc @@ -37,6 +37,7 @@ #include "flexflow/ops/softmax.h" #include "flexflow/ops/split.h" #include "flexflow/ops/tree_inc_multihead_self_attention.h" +#include "flexflow/parallel_ops/allreduce.h" #include "flexflow/parallel_ops/combine.h" #include "flexflow/parallel_ops/fused_parallel_op.h" #include "flexflow/parallel_ops/partition.h" @@ -898,8 +899,11 @@ bool GraphXfer::create_new_operator(OpX const *opx, Node &op) { case OP_EW_MUL: case OP_EW_MAX: case OP_EW_MIN: { + ElementBinaryParams params; + params.type = opx->type; + params.inplace_a = false; op = model->get_or_create_node({inputs[0], inputs[1]}, - {opx->type}); + params); break; } case OP_RELU: { @@ -3683,8 +3687,13 @@ bool FFModel::convert_graph_to_operators( case OP_EW_MIN: { assert(inList.size() == 2); ElementBinary *eb = (ElementBinary *)node.ptr; - new_op = new ElementBinary( - *this, eb->op_type, inputs[0], inputs[1], eb->inplace_a, NULL); + new_op = new ElementBinary(*this, + eb->layer_guid, + eb->op_type, + inputs[0], + inputs[1], + eb->inplace_a, + NULL); break; } case OP_POOL2D: { @@ -3777,6 +3786,12 @@ bool FFModel::convert_graph_to_operators( reduction->reduction_degree); break; } + case OP_ALLREDUCE: { + assert(inList.size() == 1); + AllReduce *allreduce = (AllReduce *)node.ptr; + new_op = new AllReduce(*this, inputs[0], allreduce->allreduce_dim); + break; + } case OP_FUSED_PARALLEL: { assert(inList.size() == 1); FusedParallelOp *fused = (FusedParallelOp *)node.ptr; diff --git a/tests/inference_tests.sh b/tests/inference_tests.sh index 761c6cf332..f50d374633 100755 --- a/tests/inference_tests.sh +++ b/tests/inference_tests.sh @@ -37,26 +37,26 @@ mkdir -p ../inference/output ############################################################################################### # LLAMA -../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model llama -llm-weight ../inference/weights/llama_7B_weights/ -llm-config ../inference/models/configs/llama_7B.json -ssm-model llama -ssm-weight ../inference/weights/llama_160M_weights/ -ssm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_llama.txt +../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model llama -llm-weight ../inference/weights/llama_7B_weights/ -llm-config ../inference/models/configs/llama_7B.json -ssm-model llama -ssm-weight ../inference/weights/llama_160M_weights/ -ssm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_llama.txt -pipeline-parallelism-degree 4 # LLAMA (half precision) -../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight ../inference/weights/llama_7B_weights_half/ -llm-config ../inference/models/configs/llama_7B.json -ssm-model llama -ssm-weight ../inference/weights/llama_160M_weights_half/ -ssm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_llama_half.txt +../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight ../inference/weights/llama_7B_weights_half/ -llm-config ../inference/models/configs/llama_7B.json -ssm-model llama -ssm-weight ../inference/weights/llama_160M_weights_half/ -ssm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_llama_half.txt -pipeline-parallelism-degree 4 # OPT -../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model opt -llm-weight ../inference/weights/opt_6B_weights/ -llm-config ../inference/models/configs/opt_6B.json -ssm-model opt -ssm-weight ../inference/weights/opt_125M_weights/ -ssm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_opt.txt +../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model opt -llm-weight ../inference/weights/opt_6B_weights/ -llm-config ../inference/models/configs/opt_6B.json -ssm-model opt -ssm-weight ../inference/weights/opt_125M_weights/ -ssm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_opt.txt -pipeline-parallelism-degree 4 # OPT (half precision) -../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model opt -llm-weight ../inference/weights/opt_6B_weights_half/ -llm-config ../inference/models/configs/opt_6B.json -ssm-model opt -ssm-weight ../inference/weights/opt_125M_weights_half/ -ssm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_opt_half.txt +../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model opt -llm-weight ../inference/weights/opt_6B_weights_half/ -llm-config ../inference/models/configs/opt_6B.json -ssm-model opt -ssm-weight ../inference/weights/opt_125M_weights_half/ -ssm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_opt_half.txt -pipeline-parallelism-degree 4 # Tensor parallelism tests if [ "$TENSOR_PARALLELISM_TESTS" = "ON" ]; then # LLAMA - ../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model llama -llm-weight ../inference/weights/llama_7B_weights/ -llm-config ../inference/models/configs/llama_7B.json -ssm-model llama -ssm-weight ../inference/weights/llama_160M_weights/ -ssm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_llama_tp.txt -tensor-parallelism-degree 2 + ../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model llama -llm-weight ../inference/weights/llama_7B_weights/ -llm-config ../inference/models/configs/llama_7B.json -ssm-model llama -ssm-weight ../inference/weights/llama_160M_weights/ -ssm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_llama_tp.txt -pipeline-parallelism-degree 2 -tensor-parallelism-degree 2 # LLAMA (half precision) - ../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight ../inference/weights/llama_7B_weights_half/ -llm-config ../inference/models/configs/llama_7B.json -ssm-model llama -ssm-weight ../inference/weights/llama_160M_weights_half/ -ssm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_llama_half_tp.txt -tensor-parallelism-degree 2 + ../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight ../inference/weights/llama_7B_weights_half/ -llm-config ../inference/models/configs/llama_7B.json -ssm-model llama -ssm-weight ../inference/weights/llama_160M_weights_half/ -ssm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_llama_half_tp.txt -pipeline-parallelism-degree 2 -tensor-parallelism-degree 2 # OPT - ../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model opt -llm-weight ../inference/weights/opt_6B_weights/ -llm-config ../inference/models/configs/opt_6B.json -ssm-model opt -ssm-weight ../inference/weights/opt_125M_weights/ -ssm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_opt_tp.txt -tensor-parallelism-degree 2 + ../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model opt -llm-weight ../inference/weights/opt_6B_weights/ -llm-config ../inference/models/configs/opt_6B.json -ssm-model opt -ssm-weight ../inference/weights/opt_125M_weights/ -ssm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_opt_tp.txt -pipeline-parallelism-degree 2 -tensor-parallelism-degree 2 # OPT (half precision) - ../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model opt -llm-weight ../inference/weights/opt_6B_weights_half/ -llm-config ../inference/models/configs/opt_6B.json -ssm-model opt -ssm-weight ../inference/weights/opt_125M_weights_half/ -ssm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_opt_half_tp.txt -tensor-parallelism-degree 2 + ../build/inference/spec_infer/spec_infer -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model opt -llm-weight ../inference/weights/opt_6B_weights_half/ -llm-config ../inference/models/configs/opt_6B.json -ssm-model opt -ssm-weight ../inference/weights/opt_125M_weights_half/ -ssm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/spec_inference_opt_half_tp.txt -pipeline-parallelism-degree 2 -tensor-parallelism-degree 2 fi ############################################################################################### @@ -64,61 +64,80 @@ fi ############################################################################################### # LLAMA (small model) -../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model llama -llm-weight ../inference/weights/llama_160M_weights/ -llm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_160M.txt +../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model llama -llm-weight ../inference/weights/llama_160M_weights/ -llm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_160M.txt -pipeline-parallelism-degree 4 # LLAMA (small model, half precision) -../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight ../inference/weights/llama_160M_weights_half/ -llm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_160M_half.txt +../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight ../inference/weights/llama_160M_weights_half/ -llm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_160M_half.txt -pipeline-parallelism-degree 4 # LLAMA (big model) -../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model llama -llm-weight ../inference/weights/llama_7B_weights/ -llm-config ../inference/models/configs/llama_7B.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_7B.txt +../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model llama -llm-weight ../inference/weights/llama_7B_weights/ -llm-config ../inference/models/configs/llama_7B.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_7B.txt -pipeline-parallelism-degree 4 # LLAMA (big model, half precision) -../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight ../inference/weights/llama_7B_weights_half/ -llm-config ../inference/models/configs/llama_7B.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_7B_half.txt +../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight ../inference/weights/llama_7B_weights_half/ -llm-config ../inference/models/configs/llama_7B.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_7B_half.txt -pipeline-parallelism-degree 4 # OPT (small model) -../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model opt -llm-weight ../inference/weights/opt_125M_weights/ -llm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_125M.txt +../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model opt -llm-weight ../inference/weights/opt_125M_weights/ -llm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_125M.txt -pipeline-parallelism-degree 4 # OPT (small model, half precision) -../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model opt -llm-weight ../inference/weights/opt_125M_weights_half/ -llm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_125M_half.txt +../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model opt -llm-weight ../inference/weights/opt_125M_weights_half/ -llm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_125M_half.txt -pipeline-parallelism-degree 4 # OPT (big model) -../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model opt -llm-weight ../inference/weights/opt_6B_weights/ -llm-config ../inference/models/configs/opt_6B.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_6B.txt +../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model opt -llm-weight ../inference/weights/opt_6B_weights/ -llm-config ../inference/models/configs/opt_6B.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_6B.txt -pipeline-parallelism-degree 4 # OPT (big model, half precision) -../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model opt -llm-weight ../inference/weights/opt_6B_weights_half/ -llm-config ../inference/models/configs/opt_6B.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_6B_half.txt +../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model opt -llm-weight ../inference/weights/opt_6B_weights_half/ -llm-config ../inference/models/configs/opt_6B.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_6B_half.txt -pipeline-parallelism-degree 4 # Tensor parallelism tests if [ "$TENSOR_PARALLELISM_TESTS" = "ON" ]; then # LLAMA (small model) - ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model llama -llm-weight ../inference/weights/llama_160M_weights/ -llm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_160M_tp.txt -tensor-parallelism-degree 2 + ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model llama -llm-weight ../inference/weights/llama_160M_weights/ -llm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_160M_tp.txt -pipeline-parallelism-degree 2 -tensor-parallelism-degree 2 # LLAMA (small model, half precision) - ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight ../inference/weights/llama_160M_weights_half/ -llm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_160M_half_tp.txt -tensor-parallelism-degree 2 + ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight ../inference/weights/llama_160M_weights_half/ -llm-config ../inference/models/configs/llama_160M.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_160M_half_tp.txt -pipeline-parallelism-degree 2 -tensor-parallelism-degree 2 # LLAMA (big model) - ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model llama -llm-weight ../inference/weights/llama_7B_weights/ -llm-config ../inference/models/configs/llama_7B.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_7B_tp.txt -tensor-parallelism-degree 2 + ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model llama -llm-weight ../inference/weights/llama_7B_weights/ -llm-config ../inference/models/configs/llama_7B.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_7B_tp.txt -pipeline-parallelism-degree 2 -tensor-parallelism-degree 2 # LLAMA (big model, half precision) - ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight ../inference/weights/llama_7B_weights_half/ -llm-config ../inference/models/configs/llama_7B.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_7B_half_tp.txt -tensor-parallelism-degree 2 + ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model llama -llm-weight ../inference/weights/llama_7B_weights_half/ -llm-config ../inference/models/configs/llama_7B.json -tokenizer ../inference/tokenizer/tokenizer.model -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_llama_7B_half_tp.txt -pipeline-parallelism-degree 2 -tensor-parallelism-degree 2 # OPT (small model) - ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model opt -llm-weight ../inference/weights/opt_125M_weights/ -llm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_125M_tp.txt -tensor-parallelism-degree 2 + ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model opt -llm-weight ../inference/weights/opt_125M_weights/ -llm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_125M_tp.txt -pipeline-parallelism-degree 2 -tensor-parallelism-degree 2 # OPT (small model, half precision) - ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model opt -llm-weight ../inference/weights/opt_125M_weights_half/ -llm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_125M_half_tp.txt -tensor-parallelism-degree 2 + ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model opt -llm-weight ../inference/weights/opt_125M_weights_half/ -llm-config ../inference/models/configs/opt_125M.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_125M_half_tp.txt -pipeline-parallelism-degree 2 -tensor-parallelism-degree 2 # OPT (big model) - ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model opt -llm-weight ../inference/weights/opt_6B_weights/ -llm-config ../inference/models/configs/opt_6B.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_6B_tp.txt -tensor-parallelism-degree 2 + ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 --use-full-precision -llm-model opt -llm-weight ../inference/weights/opt_6B_weights/ -llm-config ../inference/models/configs/opt_6B.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_6B_tp.txt -pipeline-parallelism-degree 2 -tensor-parallelism-degree 2 # OPT (big model, half precision) - ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model opt -llm-weight ../inference/weights/opt_6B_weights_half/ -llm-config ../inference/models/configs/opt_6B.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_6B_half_tp.txt -tensor-parallelism-degree 2 + ../build/inference/incr_decoding/incr_decoding -ll:gpu 4 -ll:fsize 14000 -ll:zsize 30000 -llm-model opt -llm-weight ../inference/weights/opt_6B_weights_half/ -llm-config ../inference/models/configs/opt_6B.json -tokenizer ../inference/tokenizer/ -prompt ../inference/prompt/test.json -output-file ../inference/output/incr_decoding_opt_6B_half_tp.txt -pipeline-parallelism-degree 2 -tensor-parallelism-degree 2 fi ############################################################################################### ############################### Alignment and Speed tests ##################################### ############################################################################################### -############ Alignment between speculative inference and incremental decoding ################# -# Full precision -diff <(tail -n +2 "../inference/output/incr_decoding_llama_7B.txt") <(tail -n +2 "../inference/output/spec_inference_llama.txt") -diff <(tail -n +2 "../inference/output/incr_decoding_opt_6B.txt") <(tail -n +2 "../inference/output/spec_inference_opt.txt") -# Half precision -#diff <(tail -n +2 "../inference/output/incr_decoding_llama_7B_half.txt") <(tail -n +2 "../inference/output/spec_inference_llama_half.txt") -#diff <(tail -n +2 "../inference/output/incr_decoding_opt_6B_half.txt" ) <(tail -n +2 "../inference/output/spec_inference_opt_half.txt") +##################################### Helper functions ####################################### +function check_partial_token_match { + local file1="$1" + local file2="$2" + local num_tokens_to_match=30 + + # Read the second line of the first file + third_line=$(sed -n '3p' "$file1") + read -r line1 <<< "$third_line" + tokens1=${line1#*: } + IFS=',' read -ra arr1 <<< "$tokens1" + + # Read the second line of the second file + third_line=$(sed -n '3p' "$file2") + read -r line2 <<< "$third_line" + tokens2=${line2#*: } + IFS=',' read -ra arr2 <<< "$tokens2" + + # Compare the first few integers in the two lists + for ((i = 0; i < num_tokens_to_match; i++)); do + if [[ "${arr1[$i]}" != "${arr2[$i]}" ]]; then + echo "The first $num_tokens_to_match tokens in files $file1 and $file2 are not identical." + exit 1 + fi + done + #echo "The first $num_tokens_to_match integers are identical." +} -# Speed test: speculative inference should be at very least 1.5x faster than incremental decoding function compare_speed_spec_infer_incr_decoding { local incrDec_file="$1" local specInf_file="$2" @@ -142,27 +161,69 @@ function compare_speed_spec_infer_incr_decoding { exit 1 fi } + +function compare_decoding_steps_spec_infer_incr_decoding { + local incrDec_file="$1" + local specInf_file="$2" + + # Read the number of decoding steps from the second line of the files + second_line=$(sed -n '2p' "$incrDec_file") + read -r line <<< "$second_line" + incrDec=${line#*: } + second_line=$(sed -n '2p' "$specInf_file") + read -r line <<< "$second_line" + specInf=${line#*: } + + if ! command -v bc &> /dev/null; then + echo "bc is not installed. Installing..." + sudo apt-get install -y bc + fi + + # Perform the comparison + threshold=$(bc <<< "$specInf * 1.5") + if (( $(echo "$incrDec >= $threshold" | bc -l) )); then + #echo "The decoding steps in $specInf_file are at least 1.5x less than those in $incrDec_file." + : + else + echo "Error: The decoding steps in $specInf_file are not at least 1.5x less than those in $incrDec_file!" + exit 1 + fi +} + +############ Alignment between speculative inference and incremental decoding ################# +# Full precision +diff <(tail -n +3 "../inference/output/incr_decoding_llama_7B.txt") <(tail -n +3 "../inference/output/spec_inference_llama.txt") +diff <(tail -n +3 "../inference/output/incr_decoding_opt_6B.txt") <(tail -n +3 "../inference/output/spec_inference_opt.txt") +# Half precision +check_partial_token_match "../inference/output/incr_decoding_llama_7B_half.txt" "../inference/output/spec_inference_llama_half.txt" +check_partial_token_match "../inference/output/incr_decoding_opt_6B_half.txt" "../inference/output/spec_inference_opt_half.txt" + +# Speed test: speculative inference should be at very least 1.5x faster than incremental decoding # Full precision -compare_speed_spec_infer_incr_decoding "../inference/output/incr_decoding_llama_7B.txt" "../inference/output/spec_inference_llama.txt" -compare_speed_spec_infer_incr_decoding "../inference/output/incr_decoding_opt_6B.txt" "../inference/output/spec_inference_opt.txt" +#compare_speed_spec_infer_incr_decoding "../inference/output/incr_decoding_llama_7B.txt" "../inference/output/spec_inference_llama.txt" +#compare_speed_spec_infer_incr_decoding "../inference/output/incr_decoding_opt_6B.txt" "../inference/output/spec_inference_opt.txt" +compare_decoding_steps_spec_infer_incr_decoding "../inference/output/incr_decoding_llama_7B.txt" "../inference/output/spec_inference_llama.txt" +compare_decoding_steps_spec_infer_incr_decoding "../inference/output/incr_decoding_opt_6B.txt" "../inference/output/spec_inference_opt.txt" # Half precision #compare_speed_spec_infer_incr_decoding "../inference/output/incr_decoding_llama_7B_half.txt" "../inference/output/spec_inference_llama_half.txt" #compare_speed_spec_infer_incr_decoding "../inference/output/incr_decoding_opt_6B_half.txt" "../inference/output/spec_inference_opt_half.txt" +compare_decoding_steps_spec_infer_incr_decoding "../inference/output/incr_decoding_llama_7B_half.txt" "../inference/output/spec_inference_llama_half.txt" +compare_decoding_steps_spec_infer_incr_decoding "../inference/output/incr_decoding_opt_6B_half.txt" "../inference/output/spec_inference_opt_half.txt" ############ Alignment between tensor model parallelism and pipeline parallelism only ################# if [ "$TENSOR_PARALLELISM_TESTS" = "ON" ]; then - diff <(tail -n +2 "../inference/output/spec_inference_llama_tp.txt") <(tail -n +2 "../inference/output/spec_inference_llama.txt") - diff <(tail -n +2 "../inference/output/spec_inference_opt_tp.txt") <(tail -n +2 "../inference/output/spec_inference_opt.txt") - diff <(tail -n +2 "../inference/output/spec_inference_llama_half_tp.txt") <(tail -n +2 "../inference/output/spec_inference_llama_half.txt") - diff <(tail -n +2 "../inference/output/spec_inference_opt_half_tp.txt") <(tail -n +2 "../inference/output/spec_inference_opt_half.txt") - diff <(tail -n +2 "../inference/output/incr_decoding_llama_160M_tp.txt") <(tail -n +2 "../inference/output/incr_decoding_llama_160M.txt") - # diff <(tail -n +2 "../inference/output/incr_decoding_llama_160M_half_tp.txt") <(tail -n +2 "../inference/output/incr_decoding_llama_160M_half.txt") - diff <(tail -n +2 "../inference/output/incr_decoding_llama_7B_tp.txt") <(tail -n +2 "../inference/output/incr_decoding_llama_7B.txt") - diff <(tail -n +2 "../inference/output/incr_decoding_llama_7B_half_tp.txt") <(tail -n +2 "../inference/output/incr_decoding_llama_7B_half.txt") - diff <(tail -n +2 "../inference/output/incr_decoding_opt_125M_tp.txt") <(tail -n +2 "../inference/output/incr_decoding_opt_125M.txt") - diff <(tail -n +2 "../inference/output/incr_decoding_opt_125M_half_tp.txt") <(tail -n +2 "../inference/output/incr_decoding_opt_125M_half.txt") - diff <(tail -n +2 "../inference/output/incr_decoding_opt_6B_tp.txt") <(tail -n +2 "../inference/output/incr_decoding_opt_6B.txt") - diff <(tail -n +2 "../inference/output/incr_decoding_opt_6B_half_tp.txt") <(tail -n +2 "../inference/output/incr_decoding_opt_6B_half.txt") + diff <(tail -n +3 "../inference/output/spec_inference_llama_tp.txt") <(tail -n +3 "../inference/output/spec_inference_llama.txt") + diff <(tail -n +3 "../inference/output/spec_inference_opt_tp.txt") <(tail -n +3 "../inference/output/spec_inference_opt.txt") + check_partial_token_match "../inference/output/spec_inference_llama_half_tp.txt" "../inference/output/spec_inference_llama_half.txt" + check_partial_token_match "../inference/output/spec_inference_opt_half_tp.txt" "../inference/output/spec_inference_opt_half.txt" + diff <(tail -n +3 "../inference/output/incr_decoding_llama_160M_tp.txt") <(tail -n +3 "../inference/output/incr_decoding_llama_160M.txt") + check_partial_token_match "../inference/output/incr_decoding_llama_160M_half_tp.txt" "../inference/output/incr_decoding_llama_160M_half.txt" + diff <(tail -n +3 "../inference/output/incr_decoding_llama_7B_tp.txt") <(tail -n +3 "../inference/output/incr_decoding_llama_7B.txt") + check_partial_token_match "../inference/output/incr_decoding_llama_7B_half_tp.txt" "../inference/output/incr_decoding_llama_7B_half.txt" + diff <(tail -n +3 "../inference/output/incr_decoding_opt_125M_tp.txt") <(tail -n +3 "../inference/output/incr_decoding_opt_125M.txt") + check_partial_token_match "../inference/output/incr_decoding_opt_125M_half_tp.txt" "../inference/output/incr_decoding_opt_125M_half.txt" + diff <(tail -n +3 "../inference/output/incr_decoding_opt_6B_tp.txt") <(tail -n +3 "../inference/output/incr_decoding_opt_6B.txt") + check_partial_token_match "../inference/output/incr_decoding_opt_6B_half_tp.txt" "../inference/output/incr_decoding_opt_6B_half.txt" fi ######################### Alignment tests with HuggingFace #################################### @@ -192,15 +253,15 @@ python3 ./inference/huggingface_inference.py --model-name "facebook/opt-125m" -- # OPT (big model, half precision) #python3 ./inference/huggingface_inference.py --model-name "facebook/opt-6.7b" --tokenizer-model-name "facebook/opt-6.7b" --prompt-file "../../inference/prompt/test.json" --output-file "../../inference/output/huggingface_opt_6B_half.txt" --gpu --max-length 127 -diff <(tail -n +2 "../inference/output/huggingface_llama_160M.txt") <(tail -n +4 "../inference/output/incr_decoding_llama_160M.txt") -diff <(tail -n +2 "../inference/output/huggingface_llama_160M_half.txt") <(tail -n +4 "../inference/output/incr_decoding_llama_160M_half.txt") -diff <(tail -n +2 "../inference/output/huggingface_llama_7B.txt") <(tail -n +4 "../inference/output/incr_decoding_llama_7B.txt") -diff <(tail -n +2 "../inference/output/huggingface_llama_7B_half.txt") <(tail -n +4 "../inference/output/incr_decoding_llama_7B_half.txt") +diff <(tail -n +2 "../inference/output/huggingface_llama_160M.txt") <(tail -n +5 "../inference/output/incr_decoding_llama_160M.txt") +diff <(tail -n +2 "../inference/output/huggingface_llama_160M_half.txt" | tr -s '[:space:]' '\n' | head -n 20) <(tail -n +5 "../inference/output/incr_decoding_llama_160M_half.txt" | tr -s '[:space:]' '\n' | head -n 20) +diff <(tail -n +2 "../inference/output/huggingface_llama_7B.txt") <(tail -n +5 "../inference/output/incr_decoding_llama_7B.txt") +diff <(tail -n +2 "../inference/output/huggingface_llama_7B_half.txt" | tr -s '[:space:]' '\n' | head -n 20) <(tail -n +5 "../inference/output/incr_decoding_llama_7B_half.txt" | tr -s '[:space:]' '\n' | head -n 20) -diff <(tail -n +2 "../inference/output/huggingface_opt_125M.txt") <(tail -n +4 "../inference/output/incr_decoding_opt_125M.txt") -diff <(tail -n +2 "../inference/output/huggingface_opt_125M_half.txt") <(tail -n +4 "../inference/output/incr_decoding_opt_125M_half.txt") -#diff <(tail -n +2 "../inference/output/huggingface_opt_6B.txt") <(tail -n +4 "../inference/output/incr_decoding_opt_6B.txt") -#diff <(tail -n +2 "../inference/output/huggingface_opt_6B_half.txt") <(tail -n +4 "../inference/output/incr_decoding_opt_6B_half.txt") +diff <(tail -n +2 "../inference/output/huggingface_opt_125M.txt") <(tail -n +5 "../inference/output/incr_decoding_opt_125M.txt") +diff <(tail -n +2 "../inference/output/huggingface_opt_125M_half.txt" | tr -s '[:space:]' '\n' | head -n 20) <(tail -n +5 "../inference/output/incr_decoding_opt_125M_half.txt" | tr -s '[:space:]' '\n' | head -n 20) +#diff <(tail -n +2 "../inference/output/huggingface_opt_6B.txt") <(tail -n +5 "../inference/output/incr_decoding_opt_6B.txt") +#diff <(tail -n +2 "../inference/output/huggingface_opt_6B_half.txt") <(tail -n +5 "../inference/output/incr_decoding_opt_6B_half.txt") ############################################################################################### ###################################### Cleanup ################################################ From ae67898b00405a130e8197b0b7808b5fc27d4867 Mon Sep 17 00:00:00 2001 From: xinhaoc <99570243+xinhaoc@users.noreply.github.com> Date: Sun, 16 Jul 2023 14:45:58 -0400 Subject: [PATCH 11/12] change batch_size to num_active_tokens (#861) --- include/flexflow/ops/beam_topk.h | 4 ++-- src/ops/arg_topk.cc | 3 --- src/ops/beam_topk.cc | 5 +---- src/ops/beam_topk.cpp | 4 ++-- src/ops/beam_topk.cu | 4 ++-- 5 files changed, 7 insertions(+), 13 deletions(-) diff --git a/include/flexflow/ops/beam_topk.h b/include/flexflow/ops/beam_topk.h index 76404bfb6d..57ab5c1074 100644 --- a/include/flexflow/ops/beam_topk.h +++ b/include/flexflow/ops/beam_topk.h @@ -82,7 +82,7 @@ class BeamTopK : public Op { float *output_ptr, int *indices_ptr, int *parent_ptr, - size_t batch_size, + int batch_size, int length, bool sorted, ffStream_t stream); @@ -92,7 +92,7 @@ class BeamTopK : public Op { float *output_ptr, int *indices_ptr, int *parent_ptr, - size_t batch_size, + int batch_size, int length, bool sorted); Params get_params() const; diff --git a/src/ops/arg_topk.cc b/src/ops/arg_topk.cc index a604c016d2..c1bbb65f1e 100644 --- a/src/ops/arg_topk.cc +++ b/src/ops/arg_topk.cc @@ -311,9 +311,6 @@ InferenceResult int batch_size = bc->num_active_tokens(); ArgTopK::forward_kernel_wrapper(m, input, indices, batch_size); - int length = input.domain.hi()[0] - input.domain.lo()[0] + 1; - batch_size = input.domain.get_volume() / length; - InferenceResult ir; download_tensor( indices.get_int32_ptr(), ir.token_ids, batch_size); diff --git a/src/ops/beam_topk.cc b/src/ops/beam_topk.cc index db507c1729..0920105acc 100644 --- a/src/ops/beam_topk.cc +++ b/src/ops/beam_topk.cc @@ -379,12 +379,9 @@ BeamInferenceResult // total token nums size_t tokens_per_request = in1_domain.hi()[1] - in1_domain.lo()[1] + 1; - size_t batch_size = in1_domain.get_volume() / length; - + int batch_size = bc->num_active_tokens(); // std::cout << "beam search topk params: " << length << ", " << k << ", " // << batch_size << "\n"; - assert(out2_domain.get_volume() / k == batch_size); - // std::vector beam_width; // std::unordered_map sub_requests = bc->sub_requests; // for (int i = 0; i < bc->MAX_NUM_REQUESTS; i++) { diff --git a/src/ops/beam_topk.cpp b/src/ops/beam_topk.cpp index 1817eae4da..248ab188da 100644 --- a/src/ops/beam_topk.cpp +++ b/src/ops/beam_topk.cpp @@ -479,7 +479,7 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, float *output_ptr, int *indices_ptr, int *parent_ptr, - size_t batch_size, + int batch_size, int length, bool sorted, hipStream_t stream) { @@ -630,7 +630,7 @@ void BeamTopK::forward_kernel_wrapper(BeamTopKMeta const *m, float *output_ptr, int *indices_ptr, int *parent_ptr, - size_t batch_size, + int batch_size, int length, bool sorted) { hipStream_t stream; diff --git a/src/ops/beam_topk.cu b/src/ops/beam_topk.cu index 9a5cd86486..ceddb55f2d 100644 --- a/src/ops/beam_topk.cu +++ b/src/ops/beam_topk.cu @@ -511,7 +511,7 @@ void BeamTopK::forward_kernel(BeamTopKMeta const *m, float *output_ptr, int *indices_ptr, int *parent_ptr, - size_t batch_size, + int batch_size, int length, bool sorted, cudaStream_t stream) { @@ -662,7 +662,7 @@ void BeamTopK::forward_kernel_wrapper(BeamTopKMeta const *m, float *output_ptr, int *indices_ptr, int *parent_ptr, - size_t batch_size, + int batch_size, int length, bool sorted) { cudaStream_t stream; From 58b745d04c67a85fb42392ecd692fda30b8e80ae Mon Sep 17 00:00:00 2001 From: lambda shi Date: Mon, 17 Jul 2023 05:03:50 +0800 Subject: [PATCH 12/12] Add opt-13B config (#841) Co-authored-by: Zhihao Jia --- inference/models/configs/opt_13B.json | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 inference/models/configs/opt_13B.json diff --git a/inference/models/configs/opt_13B.json b/inference/models/configs/opt_13B.json new file mode 100644 index 0000000000..96cad5c99b --- /dev/null +++ b/inference/models/configs/opt_13B.json @@ -0,0 +1,15 @@ +{ + "vocab_size": 50272, + "word_embed_proj_dim": 5120, + "hidden_size": 5120, + "num_attention_heads": 40, + "max_position_embeddings": 2048, + "layer_norm_elementwise_affine": true, + "num_hidden_layers": 40, + "dropout": 0.1, + "ffn_dim": 20480, + "max_beam_width": 1, + "batchSize": 8, + "sentence_len": 100, + "max_beam_depth": 4 +}