From 222ce02b0a326e72bd42af8beea24fb369d7c054 Mon Sep 17 00:00:00 2001
From: Gabriele Oliaro
Date: Fri, 18 Oct 2024 23:44:24 +0000
Subject: [PATCH] remove barrier where not strictly needed

---
 src/ops/kernels/lora_linear_kernels.cu |  4 ++--
 src/parallel_ops/allreduce.cc          |  8 ++++----
 src/parallel_ops/parallel_identity.cc  |  8 ++++----
 src/runtime/model.cc                   | 12 ++++++------
 4 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/src/ops/kernels/lora_linear_kernels.cu b/src/ops/kernels/lora_linear_kernels.cu
index 090eb17e7b..638cee8cae 100644
--- a/src/ops/kernels/lora_linear_kernels.cu
+++ b/src/ops/kernels/lora_linear_kernels.cu
@@ -368,8 +368,8 @@ __global__ void sgd_update(size_t count,
 
 template <typename DT>
 void peft_bwd_kernel(Context ctx,
-                     Runtime *runtime,
-                     LoraLinearMeta *m,
+                     Runtime *runtime,
+                     LoraLinearMeta *m,
                      BatchConfig const *bc,
                      DT *input_grad_ptr,
                      DT const *output_grad_ptr,
diff --git a/src/parallel_ops/allreduce.cc b/src/parallel_ops/allreduce.cc
index 25782bbf03..6611a6bb1f 100644
--- a/src/parallel_ops/allreduce.cc
+++ b/src/parallel_ops/allreduce.cc
@@ -197,9 +197,9 @@ void AllReduce::forward_task(Task const *task,
       m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);
   assert(input.data_type == output.data_type);
-  runtime->concurrent_task_barrier(ctx);
+  // runtime->concurrent_task_barrier(ctx);
   forward_kernel_wrapper(m, input, output);
-  runtime->concurrent_task_barrier(ctx);
+  // runtime->concurrent_task_barrier(ctx);
 }
 
 void AllReduce::backward(FFModel const &ff) {
@@ -349,9 +349,9 @@ void AllReduce::inference_task(Task const *task,
       m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);
   assert(input.data_type == output.data_type);
-  runtime->concurrent_task_barrier(ctx);
+  // runtime->concurrent_task_barrier(ctx);
   inference_kernel_wrapper(m, bc, input, output);
-  runtime->concurrent_task_barrier(ctx);
+  // runtime->concurrent_task_barrier(ctx);
   if (m->inference_debugging) {
     assert(task->index_point.get_dim() == 1);
     int shard_id = task->index_point.point_data[0];
diff --git a/src/parallel_ops/parallel_identity.cc b/src/parallel_ops/parallel_identity.cc
index fabc425ad8..2f76897712 100644
--- a/src/parallel_ops/parallel_identity.cc
+++ b/src/parallel_ops/parallel_identity.cc
@@ -245,9 +245,9 @@ void ParallelIdentity::backward_task(Task const *task,
       m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);
   assert(input_grad.data_type == output_grad.data_type);
-  runtime->concurrent_task_barrier(ctx);
+  // runtime->concurrent_task_barrier(ctx);
   backward_kernel_wrapper(m, input_grad, output_grad);
-  runtime->concurrent_task_barrier(ctx);
+  // runtime->concurrent_task_barrier(ctx);
 }
 
 void ParallelIdentity::init_inference(
@@ -423,9 +423,9 @@ void ParallelIdentity::peft_bwd_task(Task const *task,
       m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);
   assert(input_grad.data_type == output_grad.data_type);
-  runtime->concurrent_task_barrier(ctx);
+  // runtime->concurrent_task_barrier(ctx);
   peft_bwd_kernel_wrapper(m, bc, input_grad, output_grad);
-  runtime->concurrent_task_barrier(ctx);
+  // runtime->concurrent_task_barrier(ctx);
   if (m->inference_debugging) {
     assert(task->index_point.get_dim() == 1);
     int shard_id = task->index_point.point_data[0];
diff --git a/src/runtime/model.cc b/src/runtime/model.cc
index 0b8a507d70..417cd2c056 100644
--- a/src/runtime/model.cc
+++ b/src/runtime/model.cc
@@ -7283,7 +7283,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     // AllReduce forward and backward must run concurrently since they
     // use ncclAllReduce internally
     registrar.set_concurrent();
-    registrar.set_concurrent_barrier();
+    // registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant<AllReduce::forward_task>(
           registrar, "AllReduce Forward Task");
@@ -7316,7 +7316,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     // AllReduce forward and backward must run concurrently since they
     // use ncclAllReduce internally
     registrar.set_concurrent();
-    registrar.set_concurrent_barrier();
+    // registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant<AllReduce::inference_task>(
           registrar, "AllReduce Inference Task");
@@ -7380,7 +7380,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
     registrar.set_concurrent();
-    registrar.set_concurrent_barrier();
+    // registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant<ParallelIdentity::backward_task>(
           registrar, "ParallelIdentity Backward Task");
@@ -7414,7 +7414,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
     registrar.set_concurrent();
-    registrar.set_concurrent_barrier();
+    // registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant<ParallelIdentity::peft_bwd_task>(
          registrar, "ParallelIdentity PEFT Backward Task");
@@ -7655,7 +7655,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
     registrar.set_concurrent();
-    registrar.set_concurrent_barrier();
+    // registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "NCCL Init Communicators Task", 111 /*variant ID*/);
@@ -7673,7 +7673,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
     registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
     registrar.set_leaf();
     registrar.set_concurrent();
-    registrar.set_concurrent_barrier();
+    // registrar.set_concurrent_barrier();
     if (pre_register) {
       Runtime::preregister_task_variant(
           registrar, "NCCL Finish Communicators Task", 111 /*variant ID*/);
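
Background note: ncclAllReduce is a collective call, so every participating GPU must enter it before any participant can complete; tasks that issue it must therefore be co-scheduled, which is why the registrations above keep registrar.set_concurrent(). The premise of this patch is that once the variants are marked concurrent, the additional set_concurrent_barrier() / concurrent_task_barrier(ctx) synchronization is not strictly needed for correctness. The sketch below is a minimal, standalone illustration of the collective pattern the comments refer to; it is not FlexFlow code, and the single-process ncclCommInitAll setup, buffer size, and error-checking macros are assumptions made for the example.

// Minimal single-process NCCL all-reduce sketch (illustrative only).
#include <cuda_runtime.h>
#include <nccl.h>
#include <cstdio>
#include <vector>

#define CHECK_CUDA(cmd) do { cudaError_t e = (cmd); \
  if (e != cudaSuccess) { printf("CUDA: %s\n", cudaGetErrorString(e)); return 1; } } while (0)
#define CHECK_NCCL(cmd) do { ncclResult_t r = (cmd); \
  if (r != ncclSuccess) { printf("NCCL: %s\n", ncclGetErrorString(r)); return 1; } } while (0)

int main() {
  int ndev = 0;
  CHECK_CUDA(cudaGetDeviceCount(&ndev));

  // One communicator per device in this process (an assumption for the
  // sketch; FlexFlow instead builds communicators across Legion tasks).
  std::vector<ncclComm_t> comms(ndev);
  CHECK_NCCL(ncclCommInitAll(comms.data(), ndev, nullptr));

  size_t count = 1 << 20; // arbitrary element count for the example
  std::vector<float *> sendbuf(ndev), recvbuf(ndev);
  std::vector<cudaStream_t> streams(ndev);
  for (int i = 0; i < ndev; i++) {
    CHECK_CUDA(cudaSetDevice(i));
    CHECK_CUDA(cudaMalloc(&sendbuf[i], count * sizeof(float)));
    CHECK_CUDA(cudaMalloc(&recvbuf[i], count * sizeof(float)));
    CHECK_CUDA(cudaStreamCreate(&streams[i]));
  }

  // The collective proper: every rank must reach this call. If one rank
  // never arrives, all others block, which is why the launching tasks
  // must be co-scheduled (registrar.set_concurrent() in model.cc).
  CHECK_NCCL(ncclGroupStart());
  for (int i = 0; i < ndev; i++) {
    CHECK_CUDA(cudaSetDevice(i));
    CHECK_NCCL(ncclAllReduce(sendbuf[i], recvbuf[i], count, ncclFloat,
                             ncclSum, comms[i], streams[i]));
  }
  CHECK_NCCL(ncclGroupEnd());

  // NCCL orders the collective on the given streams itself; no extra
  // runtime-level barrier is needed for the reduction to be correct.
  for (int i = 0; i < ndev; i++) {
    CHECK_CUDA(cudaSetDevice(i));
    CHECK_CUDA(cudaStreamSynchronize(streams[i]));
    CHECK_CUDA(cudaFree(sendbuf[i]));
    CHECK_CUDA(cudaFree(recvbuf[i]));
    CHECK_NCCL(ncclCommDestroy(comms[i]));
  }
  return 0;
}

If any one device skipped the ncclAllReduce call inside the group, the remaining participants would block indefinitely; that hang-on-divergence behavior is what makes concurrent scheduling of these task variants necessary on its own, independent of the extra barrier this patch removes.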