diff --git a/cmake/nccl.cmake b/cmake/nccl.cmake
index 82cf3b4122..abb4864588 100644
--- a/cmake/nccl.cmake
+++ b/cmake/nccl.cmake
@@ -36,11 +36,12 @@ if(NCCL_LIBRARY AND NCCL_INCLUDE_DIR)
     string(REGEX MATCH "([0-9]+)" NCCL_MAJOR ${NCCL_VERSION_DEFINES})
     string(REGEX MATCH "([0-9]+)" NCCL_MINOR ${NCCL_VERSION_DEFINES2})
     set(NCCL_VERSION "${NCCL_MAJOR}.${NCCL_MINOR}")
-    if(NCCL_VERSION VERSION_LESS 2.23)
-      set(NCCL_OLD TRUE)
-    else()
-      set(NCCL_OLD FALSE)
-    endif()
+    set(NCCL_OLD FALSE)
+    # if(NCCL_VERSION VERSION_LESS 2.23)
+    #   set(NCCL_OLD TRUE)
+    # else()
+    #   set(NCCL_OLD FALSE)
+    # endif()
     message(STATUS "Found NCCL version: ${NCCL_VERSION}")
   else()
     message(WARNING "NCCL header not found, unable to determine version")
diff --git a/docker/flexflow-environment/Dockerfile b/docker/flexflow-environment/Dockerfile
index ee13a07375..7028fc4b2e 100644
--- a/docker/flexflow-environment/Dockerfile
+++ b/docker/flexflow-environment/Dockerfile
@@ -55,18 +55,18 @@ ENV CUDA_DIR /usr/local/cuda
 ARG FF_GPU_BACKEND "cuda"
 
 # Update NCCL if FF_GPU_BACKEND is cuda
-RUN /bin/bash -c 'if [ "$FF_GPU_BACKEND" = "cuda" ]; then \
-        echo "FF_GPU_BACKEND: ${FF_GPU_BACKEND}. Updating NCCL"; \
-        ubuntu_version=$(lsb_release -rs); \
-        ubuntu_version=${ubuntu_version//./}; \
-        wget "https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${ubuntu_version}/x86_64/cuda-keyring_1.0-1_all.deb"; \
-        DEBIAN_FRONTEND=noninteractive dpkg -i cuda-keyring_1.0-1_all.deb; \
-        DEBIAN_FRONTEND=noninteractive apt-get update -y --allow-change-held-packages; \
-        rm -f cuda-keyring_1.0-1_all.deb; \
-        DEBIAN_FRONTEND=noninteractive apt install -y --allow-change-held-packages libnccl2 libnccl-dev; \
-    else \
-        echo "FF_GPU_BACKEND: ${FF_GPU_BACKEND}. Skipping updating NCCL"; \
-    fi'
+# RUN /bin/bash -c 'if [ "$FF_GPU_BACKEND" = "cuda" ]; then \
+#     echo "FF_GPU_BACKEND: ${FF_GPU_BACKEND}. Updating NCCL"; \
+#     ubuntu_version=$(lsb_release -rs); \
+#     ubuntu_version=${ubuntu_version//./}; \
+#     wget "https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${ubuntu_version}/x86_64/cuda-keyring_1.0-1_all.deb"; \
+#     DEBIAN_FRONTEND=noninteractive dpkg -i cuda-keyring_1.0-1_all.deb; \
+#     DEBIAN_FRONTEND=noninteractive apt-get update -y --allow-change-held-packages; \
+#     rm -f cuda-keyring_1.0-1_all.deb; \
+#     DEBIAN_FRONTEND=noninteractive apt install -y --allow-change-held-packages libnccl2 libnccl-dev; \
+#     else \
+#     echo "FF_GPU_BACKEND: ${FF_GPU_BACKEND}. Skipping updating NCCL"; \
+#     fi'
 
 # Install hip dependencies if FF_GPU_BACKEND is hip_cuda or hip_rocm
 # Note that amd's docs say to also install the `hip-runtime-nvidia` package.
diff --git a/include/flexflow/ops/kernels/lora_linear_kernels.h b/include/flexflow/ops/kernels/lora_linear_kernels.h
index 5360b5f8ea..eee9875d30 100644
--- a/include/flexflow/ops/kernels/lora_linear_kernels.h
+++ b/include/flexflow/ops/kernels/lora_linear_kernels.h
@@ -8,7 +8,8 @@
 #include "flexflow/ops/lora_linear.h"
 
 namespace FlexFlow {
-
+using Legion::Context;
+using Legion::Runtime;
 struct LoraLinearWeight {
   // weights
   void *w0_ptr, *w1_ptr;
@@ -46,7 +47,9 @@ void inference_kernel_wrapper(LoraLinearMeta *m,
                               BatchConfig const *bc,
                               GenericTensorAccessorR const &input,
                               GenericTensorAccessorW const &output);
-void peft_bwd_kernel_wrapper(LoraLinearMeta *m,
+void peft_bwd_kernel_wrapper(Context ctx,
+                             Runtime *runtime,
+                             LoraLinearMeta *m,
                              BatchConfig const *bc,
                              GenericTensorAccessorW const &input_grad,
                              GenericTensorAccessorR const &output_grad);
@@ -63,7 +66,9 @@ void inference_kernel(LoraLinearMeta *m,
                       int out_dim,
                       ffStream_t stream);
 template <typename DT>
-void peft_bwd_kernel(LoraLinearMeta *m,
+void peft_bwd_kernel(Context ctx,
+                     Runtime *runtime,
+                     LoraLinearMeta *m,
                      BatchConfig const *bc,
                      DT *input_grad_ptr,
                      DT const *output_grad_ptr,
diff --git a/include/flexflow/optimizer.h b/include/flexflow/optimizer.h
index bab7e6e4ed..4917df73c3 100644
--- a/include/flexflow/optimizer.h
+++ b/include/flexflow/optimizer.h
@@ -20,7 +20,8 @@
 #include "legion.h"
 
 namespace FlexFlow {
-
+using Legion::Context;
+using Legion::Runtime;
 class FFModel;
 class OpMeta;
 
@@ -60,7 +61,9 @@ class SGDOptimizer : public Optimizer {
                               std::vector<Legion::PhysicalRegion> const &regions,
                               Legion::Context ctx,
                               Legion::Runtime *runtime);
-  static void nccl_update_task_gpu(SGDOptimizer const *op,
+  static void nccl_update_task_gpu(Context ctx,
+                                   Runtime *runtime,
+                                   SGDOptimizer const *op,
                                    OpMeta const *meta,
                                    float const *w_grad_ptr,
                                    size_t size,
@@ -103,7 +106,9 @@ class AdamOptimizer : public Optimizer {
                               std::vector<Legion::PhysicalRegion> const &regions,
                               Legion::Context ctx,
                               Legion::Runtime *runtime);
-  static void nccl_update_task_gpu(AdamOptimizer const *op,
+  static void nccl_update_task_gpu(Context ctx,
+                                   Runtime *runtime,
+                                   AdamOptimizer const *op,
                                    OpMeta const *meta,
                                    float const *w_grad_ptr,
                                    size_t size,
diff --git a/src/ops/fused.cc b/src/ops/fused.cc
index 720d678a4a..984691fa66 100644
--- a/src/ops/fused.cc
+++ b/src/ops/fused.cc
@@ -476,7 +476,6 @@ void FusedOp::init(FFModel const &ff) {
                          false /*must*/,
                          0 /*mapper_id*/,
                          outputs[0]->machine_view.hash());
-  launcher.concurrent = true;
   FutureMap fm = runtime->execute_index_space(ctx, launcher);
   fm.wait_all_results();
   switch (domain.get_dim()) {
@@ -571,7 +570,6 @@ void FusedOp::init_inference(FFModel const &ff,
                          false /*must*/,
                          0 /*mapper_id*/,
                          machine_view_hash);
-  launcher.concurrent = true;
   FutureMap fm = runtime->execute_index_space(ctx, launcher);
   fm.wait_all_results();
   switch (domain.get_dim()) {
diff --git a/src/ops/fused.cpp b/src/ops/fused.cpp
index 2cede662f3..dfb524d206 100644
--- a/src/ops/fused.cpp
+++ b/src/ops/fused.cpp
@@ -612,8 +612,10 @@ __host__ void
       assert(fused->op_num_inputs[op] == 1);
       assert(fused->op_num_outputs[op] == 1);
       AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op];
+      runtime->concurrent_task_barrier(ctx);
       Kernels::AllReduce::inference_kernel_wrapper(
           m, bc, my_input_accessor[0], my_output_accessor[0]);
+      runtime->concurrent_task_barrier(ctx);
       break;
     }
     case OP_PARALLEL_IDENTITY: {
@@ -870,7 +872,12 @@ __host__ void FusedOp::peft_bwd_task(Task const *task,
       // since we ``inplace'' the output for LoRA
       assert(my_input_grad_accessor[1].ptr == my_output_grad_accessor[0].ptr);
       Kernels::LoraLinear::peft_bwd_kernel_wrapper(
-          m, bc, my_input_grad_accessor[0], my_output_grad_accessor[0]);
+          ctx,
+          runtime,
+          m,
+          bc,
+          my_input_grad_accessor[0],
+          my_output_grad_accessor[0]);
       break;
     }
     case OP_BATCHMATMUL: {
@@ -1129,8 +1136,10 @@ __host__ void FusedOp::peft_bwd_task(Task const *task,
       assert(fused->op_num_inputs[op] == 1);
       assert(fused->op_num_outputs[op] == 1);
       ParallelIdentityMeta const *m = (ParallelIdentityMeta *)metas->meta[op];
+      runtime->concurrent_task_barrier(ctx);
       Kernels::ParallelIdentity::peft_bwd_kernel_wrapper(
           m, bc, my_input_grad_accessor[0], my_output_grad_accessor[0]);
+      runtime->concurrent_task_barrier(ctx);
       break;
     }
     default: {
diff --git a/src/ops/fused.cu b/src/ops/fused.cu
index 5aed2cd69a..62845c0f8e 100644
--- a/src/ops/fused.cu
+++ b/src/ops/fused.cu
@@ -623,8 +623,10 @@ __host__ void
       assert(fused->op_num_inputs[op] == 1);
       assert(fused->op_num_outputs[op] == 1);
       AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op];
+      runtime->concurrent_task_barrier(ctx);
       Kernels::AllReduce::inference_kernel_wrapper(
           m, bc, my_input_accessor[0], my_output_accessor[0]);
+      runtime->concurrent_task_barrier(ctx);
       break;
     }
     case OP_PARALLEL_IDENTITY: {
@@ -888,7 +890,12 @@ __host__ void FusedOp::peft_bwd_task(Task const *task,
       // since we ``inplace'' the output for LoRA
       assert(my_input_grad_accessor[1].ptr == my_output_grad_accessor[0].ptr);
       Kernels::LoraLinear::peft_bwd_kernel_wrapper(
-          m, bc, my_input_grad_accessor[0], my_output_grad_accessor[0]);
+          ctx,
+          runtime,
+          m,
+          bc,
+          my_input_grad_accessor[0],
+          my_output_grad_accessor[0]);
       break;
     }
     case OP_BATCHMATMUL: {
@@ -1149,8 +1156,10 @@ __host__ void FusedOp::peft_bwd_task(Task const *task,
       assert(fused->op_num_inputs[op] == 1);
       assert(fused->op_num_outputs[op] == 1);
       ParallelIdentityMeta const *m = (ParallelIdentityMeta *)metas->meta[op];
+      runtime->concurrent_task_barrier(ctx);
       Kernels::ParallelIdentity::peft_bwd_kernel_wrapper(
           m, bc, my_input_grad_accessor[0], my_output_grad_accessor[0]);
+      runtime->concurrent_task_barrier(ctx);
       break;
     }
     default: {
diff --git a/src/ops/kernels/lora_linear_kernels.cu b/src/ops/kernels/lora_linear_kernels.cu
index 93e5820f9c..090eb17e7b 100644
--- a/src/ops/kernels/lora_linear_kernels.cu
+++ b/src/ops/kernels/lora_linear_kernels.cu
@@ -96,7 +96,9 @@ void inference_kernel_wrapper(LoraLinearMeta *m,
   }
 }
 
-void peft_bwd_kernel_wrapper(LoraLinearMeta *m,
+void peft_bwd_kernel_wrapper(Context ctx,
+                             Runtime *runtime,
+                             LoraLinearMeta *m,
                              BatchConfig const *bc,
                              GenericTensorAccessorW const &input_grad,
                              GenericTensorAccessorR const &output_grad) {
@@ -111,7 +113,9 @@ void peft_bwd_kernel_wrapper(LoraLinearMeta *m,
   int in_dim = input_grad.domain.hi()[0] - input_grad.domain.lo()[0] + 1;
   int out_dim = output_grad.domain.hi()[0] - output_grad.domain.lo()[0] + 1;
   if (m->input_type[0] == DT_FLOAT) {
-    Internal::peft_bwd_kernel<float>(m,
+    Internal::peft_bwd_kernel<float>(ctx,
+                                     runtime,
+                                     m,
                                      bc,
                                      input_grad.get_float_ptr(),
                                      output_grad.get_float_ptr(),
@@ -119,7 +123,9 @@ void peft_bwd_kernel_wrapper(LoraLinearMeta *m,
                                      out_dim,
                                      stream);
   } else if (m->input_type[0] == DT_HALF) {
-    Internal::peft_bwd_kernel<half>(m,
+    Internal::peft_bwd_kernel<half>(ctx,
+                                    runtime,
+                                    m,
                                     bc,
                                     input_grad.get_half_ptr(),
                                     output_grad.get_half_ptr(),
@@ -361,7 +367,9 @@ __global__ void sgd_update(size_t count,
 }
 
 template <typename DT>
-void peft_bwd_kernel(LoraLinearMeta *m,
+void peft_bwd_kernel(Context ctx,
+                     Runtime *runtime,
+                     LoraLinearMeta *m,
                      BatchConfig const *bc,
                      DT *input_grad_ptr,
                      DT const *output_grad_ptr,
@@ -543,13 +551,15 @@ void peft_bwd_kernel(LoraLinearMeta *m,
       // and sum first
 #ifdef FF_USE_NCCL
       ncclDataType_t nccl_data_type = ff_to_nccl_datatype(m->output_type[0]);
-      checkCUDA(ncclAllReduce(static_cast<DT *>(weight.w1_grad_ptr),
+      runtime->concurrent_task_barrier(ctx);
+      checkNCCL(ncclAllReduce(static_cast<DT *>(weight.w1_grad_ptr),
                               static_cast<DT *>(weight.w1_grad_ptr),
                               w1_num_elements,
                               nccl_data_type,
                               ncclSum,
                               m->handle.ncclComm,
                               stream));
+      runtime->concurrent_task_barrier(ctx);
 #else
       assert(false && "Must enable FF_USE_NCCL to use AllReduce operators");
 #endif
diff --git a/src/ops/lora_linear.cc b/src/ops/lora_linear.cc
index 513147f3b7..3749cce994 100644
--- a/src/ops/lora_linear.cc
+++ b/src/ops/lora_linear.cc
@@ -296,7 +296,6 @@ void LoraLinear::init_inference(
                          false /*must*/,
                          0 /*mapper_id*/,
                          machine_view_hash);
-  launcher.concurrent = true;
   launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
                                                     0 /*projection id*/,
                                                     READ_ONLY,
@@ -1066,7 +1065,7 @@ void LoraLinear::peft_bwd_task(Task const *task,
   int out_dim = output_grad.domain.hi()[0] - output_grad.domain.lo()[0] + 1;
   // int num_infr_tokens = bc->num_active_infr_tokens();
   // int num_peft_tokens = bc->num_active_peft_tokens();
-  peft_bwd_kernel_wrapper(m, bc, input_grad, output_grad);
+  peft_bwd_kernel_wrapper(ctx, runtime, m, bc, input_grad, output_grad);
 
   save_peft_weights_if_needed(m, bc, in_dim, out_dim, shard_id);
 
diff --git a/src/parallel_ops/allreduce.cc b/src/parallel_ops/allreduce.cc
index a4443c4066..25782bbf03 100644
--- a/src/parallel_ops/allreduce.cc
+++ b/src/parallel_ops/allreduce.cc
@@ -197,7 +197,9 @@ void AllReduce::forward_task(Task const *task,
       m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);
 
   assert(input.data_type == output.data_type);
+  runtime->concurrent_task_barrier(ctx);
   forward_kernel_wrapper(m, input, output);
+  runtime->concurrent_task_barrier(ctx);
 }
 
 void AllReduce::backward(FFModel const &ff) {
@@ -347,7 +349,9 @@ void AllReduce::inference_task(Task const *task,
       m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);
 
   assert(input.data_type == output.data_type);
+  runtime->concurrent_task_barrier(ctx);
   inference_kernel_wrapper(m, bc, input, output);
+  runtime->concurrent_task_barrier(ctx);
   if (m->inference_debugging) {
     assert(task->index_point.get_dim() == 1);
     int shard_id = task->index_point.point_data[0];
diff --git a/src/parallel_ops/parallel_identity.cc b/src/parallel_ops/parallel_identity.cc
index 7d68036709..fabc425ad8 100644
--- a/src/parallel_ops/parallel_identity.cc
+++ b/src/parallel_ops/parallel_identity.cc
@@ -245,7 +245,9 @@ void ParallelIdentity::backward_task(Task const *task,
       m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);
 
   assert(input_grad.data_type == output_grad.data_type);
+  runtime->concurrent_task_barrier(ctx);
   backward_kernel_wrapper(m, input_grad, output_grad);
+  runtime->concurrent_task_barrier(ctx);
 }
 
 void ParallelIdentity::init_inference(
@@ -270,7 +272,6 @@ void ParallelIdentity::init_inference(
                          false /*must*/,
                          0 /*mapper_id*/,
                          machine_view_hash);
-  launcher.concurrent = true;
   launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
                                                     0 /*projection id*/,
                                                     READ_ONLY,
@@ -422,7 +423,9 @@ void ParallelIdentity::peft_bwd_task(Task const *task,
       m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);
 
   assert(input_grad.data_type == output_grad.data_type);
+  runtime->concurrent_task_barrier(ctx);
   peft_bwd_kernel_wrapper(m, bc, input_grad, output_grad);
+  runtime->concurrent_task_barrier(ctx);
   if (m->inference_debugging) {
     assert(task->index_point.get_dim() == 1);
     int shard_id = task->index_point.point_data[0];
diff --git a/src/runtime/model.cc b/src/runtime/model.cc
index 9f438469ff..0b8a507d70 100644
--- a/src/runtime/model.cc
+++ b/src/runtime/model.cc
@@ -6900,7 +6900,6 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     TaskVariantRegistrar registrar(LORA_LINEAR_INIT_TASK_ID, "LoraLinear Init");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
-    registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "LoraLinear Init Task");
@@ -6933,6 +6932,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
     registrar.set_concurrent();
+    registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "LoraLinear PEFT Backward Task");
@@ -6964,7 +6964,6 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     TaskVariantRegistrar registrar(FUSEDOP_INIT_TASK_ID, "FusedOp Init");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
-    registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "FusedOp Init Task");
@@ -6980,6 +6979,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
     registrar.set_concurrent();
+    registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "FusedOp Inference Task");
@@ -6996,6 +6996,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
     registrar.set_concurrent();
+    registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "FusedOp PEFT Backward Task");
@@ -7012,6 +7013,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
     registrar.set_concurrent();
+    registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "FusedOp Forward Task");
@@ -7027,6 +7029,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
     registrar.set_concurrent();
+    registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "FusedOp Backward Task");
@@ -7263,7 +7266,6 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     TaskVariantRegistrar registrar(ALLREDUCE_INIT_TASK_ID, "AllReduce Init");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
-    registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "AllReduce init Task");
@@ -7281,6 +7283,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     // AllReduce forward and backward must run concurrently since they
     // use ncclAllReduce internally
     registrar.set_concurrent();
+    registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "AllReduce Forward Task");
@@ -7295,9 +7298,6 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     TaskVariantRegistrar registrar(ALLREDUCE_BWD_TASK_ID, "AllReduce Backward");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
-    // AllReduce forward and backward must run concurrently since they
-    // use ncclAllReduce internally
-    // registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "AllReduce Backward Task");
@@ -7316,6 +7316,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     // AllReduce forward and backward must run concurrently since they
     // use ncclAllReduce internally
     registrar.set_concurrent();
+    registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "AllReduce Inference Task");
@@ -7331,9 +7332,6 @@ void register_flexflow_internal_tasks(Runtime *runtime,
                                    "AllReduce PEFT Backward");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
-    // AllReduce forward and backward must run concurrently since they
-    // use ncclAllReduce internally
-    // registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "AllReduce PEFT Backward Task");
@@ -7350,7 +7348,6 @@ void register_flexflow_internal_tasks(Runtime *runtime,
                                    "ParallelIdentity Init");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
-    registrar.set_concurrent();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "ParallelIdentity init Task");
@@ -7383,6 +7380,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
     registrar.set_concurrent();
+    registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "ParallelIdentity Backward Task");
@@ -7416,6 +7414,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
     registrar.set_concurrent();
+    registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "ParallelIdentity PEFT Backward Task");
@@ -7434,6 +7433,8 @@ void register_flexflow_internal_tasks(Runtime *runtime,
                                    "FusedParallel Forward");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    registrar.set_concurrent();
+    registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "FusedParallel Forward Task");
@@ -7449,6 +7450,8 @@ void register_flexflow_internal_tasks(Runtime *runtime,
                                    "FusedParallel Backward");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    registrar.set_concurrent();
+    registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "FusedParallel Backward Task");
@@ -7497,6 +7500,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
     registrar.set_concurrent();
+    registrar.set_concurrent_barrier();
    if (pre_register) {
      Runtime::preregister_task_variant(
          registrar, "SGD NCCL Update Task", 111 /*variant ID*/);
@@ -7512,6 +7516,8 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     TaskVariantRegistrar registrar(ADAM_UPD_NCCL_TASK_ID, "Adam NCCL Update");
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
+    registrar.set_concurrent();
+    registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "Adam NCCL Update Task", 111 /*variant ID*/);
@@ -7649,6 +7655,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
     registrar.set_concurrent();
+    registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "NCCL Init Communicators Task", 111 /*variant ID*/);
@@ -7666,6 +7673,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
     registrar.set_concurrent();
+    registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "NCCL Finish Communicators Task", 111 /*variant ID*/);
diff --git a/src/runtime/optimizer.cc b/src/runtime/optimizer.cc
index c42a0c9aa6..96b735803c 100644
--- a/src/runtime/optimizer.cc
+++ b/src/runtime/optimizer.cc
@@ -311,7 +311,7 @@ void SGDOptimizer::nccl_update_task(Task const *task,
     }
   }
 
-  nccl_update_task_gpu(op, meta, w_grad_ptr, size, w_ptr, v_ptr);
+  nccl_update_task_gpu(ctx, runtime, op, meta, w_grad_ptr, size, w_ptr, v_ptr);
 }
 #endif
 
@@ -603,7 +603,8 @@ void AdamOptimizer::nccl_update_task(Task const *task,
     }
   }
 
-  nccl_update_task_gpu(op, meta, w_grad_ptr, size, w_ptr, v_ptr, m_ptr);
+  nccl_update_task_gpu(
+      ctx, runtime, op, meta, w_grad_ptr, size, w_ptr, v_ptr, m_ptr);
 }
 #endif
 
diff --git a/src/runtime/optimizer_kernel.cpp b/src/runtime/optimizer_kernel.cpp
index 59efaf5256..9b0d3c8892 100644
--- a/src/runtime/optimizer_kernel.cpp
+++ b/src/runtime/optimizer_kernel.cpp
@@ -86,7 +86,9 @@ __host__ void SGDOptimizer::ps_update_task_gpu(SGDOptimizer const *op,
 }
 
 #ifdef FF_USE_NCCL
-__host__ void SGDOptimizer::nccl_update_task_gpu(SGDOptimizer const *op,
+__host__ void SGDOptimizer::nccl_update_task_gpu(Context ctx,
+                                                 Runtime *runtime,
+                                                 SGDOptimizer const *op,
                                                  OpMeta const *meta,
                                                  float const *w_grad_ptr,
                                                  size_t size,
@@ -96,6 +98,7 @@ __host__ void SGDOptimizer::nccl_update_task_gpu(SGDOptimizer const *op,
   // fprintf(stderr, "weight(%p) Before ncclAllReduce...\n", w_grad_ptr);
   hipStream_t stream;
   checkCUDA(get_legion_stream(&stream));
+  runtime->concurrent_task_barrier(ctx);
   checkNCCL(ncclAllReduce(w_grad_ptr,
                           (float *)w_grad_ptr,
                           size,
@@ -103,6 +106,7 @@ __host__ void SGDOptimizer::nccl_update_task_gpu(SGDOptimizer const *op,
                           ncclSum,
                           meta->handle.ncclComm,
                           stream));
+  runtime->concurrent_task_barrier(ctx);
   // fprintf(stderr, "weight(%p) After ncclAllReduce...\n", w_grad_ptr);
 
   // Step 2: SGD update
@@ -208,7 +212,9 @@ __host__ void AdamOptimizer::ps_update_task_gpu(AdamOptimizer const *op,
 }
 
 #ifdef FF_USE_NCCL
-__host__ void AdamOptimizer::nccl_update_task_gpu(AdamOptimizer const *op,
+__host__ void AdamOptimizer::nccl_update_task_gpu(Context ctx,
+                                                  Runtime *runtime,
+                                                  AdamOptimizer const *op,
                                                   OpMeta const *meta,
                                                   float const *w_grad_ptr,
                                                   size_t size,
@@ -218,6 +224,7 @@ __host__ void AdamOptimizer::nccl_update_task_gpu(AdamOptimizer const *op,
   // Use NCCL to sync gradients
   hipStream_t stream;
   checkCUDA(get_legion_stream(&stream));
+  runtime->concurrent_task_barrier(ctx);
   checkNCCL(ncclAllReduce(w_grad_ptr,
                           (float *)w_grad_ptr,
                           size,
@@ -225,6 +232,7 @@ __host__ void AdamOptimizer::nccl_update_task_gpu(AdamOptimizer const *op,
                           ncclSum,
                           meta->handle.ncclComm,
                           stream));
+  runtime->concurrent_task_barrier(ctx);
   // fprintf(stderr, "alpha = %.8lf alpha_t = %.8lf decay = %.8lf\n",
   //         op->alpha, op->alpha_t, op->weight_decay);
   // Step 2: Adam update
diff --git a/src/runtime/optimizer_kernel.cu b/src/runtime/optimizer_kernel.cu
index df37e3b135..72ee74940f 100644
--- a/src/runtime/optimizer_kernel.cu
+++ b/src/runtime/optimizer_kernel.cu
@@ -75,7 +75,9 @@ __host__ void SGDOptimizer::ps_update_task_gpu(SGDOptimizer const *op,
 }
 
 #ifdef FF_USE_NCCL
-__host__ void SGDOptimizer::nccl_update_task_gpu(SGDOptimizer const *op,
+__host__ void SGDOptimizer::nccl_update_task_gpu(Context ctx,
+                                                 Runtime *runtime,
+                                                 SGDOptimizer const *op,
                                                  OpMeta const *meta,
                                                  float const *w_grad_ptr,
                                                  size_t size,
@@ -85,6 +87,7 @@ __host__ void SGDOptimizer::nccl_update_task_gpu(SGDOptimizer const *op,
   // fprintf(stderr, "weight(%p) Before ncclAllReduce...\n", w_grad_ptr);
   cudaStream_t stream;
   checkCUDA(get_legion_stream(&stream));
+  runtime->concurrent_task_barrier(ctx);
   checkNCCL(ncclAllReduce(w_grad_ptr,
                           (float *)w_grad_ptr,
                           size,
@@ -92,6 +95,7 @@ __host__ void SGDOptimizer::nccl_update_task_gpu(SGDOptimizer const *op,
                           ncclSum,
                           meta->handle.ncclComm,
                           stream));
+  runtime->concurrent_task_barrier(ctx);
   // fprintf(stderr, "weight(%p) After ncclAllReduce...\n", w_grad_ptr);
   // print_tensor((float*)w_grad_ptr, 16, "[After ncclAllReduce]");
 
@@ -183,7 +187,9 @@ __host__ void AdamOptimizer::ps_update_task_gpu(AdamOptimizer const *op,
 }
 
 #ifdef FF_USE_NCCL
-__host__ void AdamOptimizer::nccl_update_task_gpu(AdamOptimizer const *op,
+__host__ void AdamOptimizer::nccl_update_task_gpu(Context ctx,
+                                                  Runtime *runtime,
+                                                  AdamOptimizer const *op,
                                                   OpMeta const *meta,
                                                   float const *w_grad_ptr,
                                                   size_t size,
@@ -193,6 +199,7 @@ __host__ void AdamOptimizer::nccl_update_task_gpu(AdamOptimizer const *op,
   // Use NCCL to sync gradients
   cudaStream_t stream;
   checkCUDA(get_legion_stream(&stream));
+  runtime->concurrent_task_barrier(ctx);
   checkNCCL(ncclAllReduce(w_grad_ptr,
                           (float *)w_grad_ptr,
                           size,
@@ -200,6 +207,7 @@ __host__ void AdamOptimizer::nccl_update_task_gpu(AdamOptimizer const *op,
                           ncclSum,
                           meta->handle.ncclComm,
                           stream));
+  runtime->concurrent_task_barrier(ctx);
   // fprintf(stderr, "alpha = %.8lf alpha_t = %.8lf decay = %.8lf\n",
   //         op->alpha, op->alpha_t, op->weight_decay);
   // Step 2: Adam update