Skip to content

Commit

Permalink
Disable (comment out) concurrent task barriers where not strictly needed
Browse files (browse the repository at this point in the history)
  • Loading branch information
goliaro committed Oct 18, 2024
1 parent c224f31 commit 222ce02
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 16 deletions.
4 changes: 2 additions & 2 deletions src/ops/kernels/lora_linear_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,8 @@ __global__ void sgd_update(size_t count,

template <typename DT>
void peft_bwd_kernel(Context ctx,
Runtime *runtime,
LoraLinearMeta *m,
Runtime *runtime,
LoraLinearMeta *m,
BatchConfig const *bc,
DT *input_grad_ptr,
DT const *output_grad_ptr,
Expand Down
8 changes: 4 additions & 4 deletions src/parallel_ops/allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,9 @@ void AllReduce::forward_task(Task const *task,
m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);

assert(input.data_type == output.data_type);
runtime->concurrent_task_barrier(ctx);
// runtime->concurrent_task_barrier(ctx);
forward_kernel_wrapper(m, input, output);
runtime->concurrent_task_barrier(ctx);
// runtime->concurrent_task_barrier(ctx);
}

void AllReduce::backward(FFModel const &ff) {
Expand Down Expand Up @@ -349,9 +349,9 @@ void AllReduce::inference_task(Task const *task,
m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);

assert(input.data_type == output.data_type);
runtime->concurrent_task_barrier(ctx);
// runtime->concurrent_task_barrier(ctx);
inference_kernel_wrapper(m, bc, input, output);
runtime->concurrent_task_barrier(ctx);
// runtime->concurrent_task_barrier(ctx);
if (m->inference_debugging) {
assert(task->index_point.get_dim() == 1);
int shard_id = task->index_point.point_data[0];
Expand Down
8 changes: 4 additions & 4 deletions src/parallel_ops/parallel_identity.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,9 @@ void ParallelIdentity::backward_task(Task const *task,
m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);

assert(input_grad.data_type == output_grad.data_type);
runtime->concurrent_task_barrier(ctx);
// runtime->concurrent_task_barrier(ctx);
backward_kernel_wrapper(m, input_grad, output_grad);
runtime->concurrent_task_barrier(ctx);
// runtime->concurrent_task_barrier(ctx);
}

void ParallelIdentity::init_inference(
Expand Down Expand Up @@ -423,9 +423,9 @@ void ParallelIdentity::peft_bwd_task(Task const *task,
m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);

assert(input_grad.data_type == output_grad.data_type);
runtime->concurrent_task_barrier(ctx);
// runtime->concurrent_task_barrier(ctx);
peft_bwd_kernel_wrapper(m, bc, input_grad, output_grad);
runtime->concurrent_task_barrier(ctx);
// runtime->concurrent_task_barrier(ctx);
if (m->inference_debugging) {
assert(task->index_point.get_dim() == 1);
int shard_id = task->index_point.point_data[0];
Expand Down
12 changes: 6 additions & 6 deletions src/runtime/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7283,7 +7283,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
// AllReduce forward and backward must run concurrently since they
// use ncclAllReduce internally
registrar.set_concurrent();
registrar.set_concurrent_barrier();
// registrar.set_concurrent_barrier();
if (pre_register) {
Runtime::preregister_task_variant<AllReduce::forward_task>(
registrar, "AllReduce Forward Task");
Expand Down Expand Up @@ -7316,7 +7316,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
// AllReduce forward and backward must run concurrently since they
// use ncclAllReduce internally
registrar.set_concurrent();
registrar.set_concurrent_barrier();
// registrar.set_concurrent_barrier();
if (pre_register) {
Runtime::preregister_task_variant<AllReduce::inference_task>(
registrar, "AllReduce Inference Task");
Expand Down Expand Up @@ -7380,7 +7380,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
registrar.set_concurrent();
registrar.set_concurrent_barrier();
// registrar.set_concurrent_barrier();
if (pre_register) {
Runtime::preregister_task_variant<ParallelIdentity::backward_task>(
registrar, "ParallelIdentity Backward Task");
Expand Down Expand Up @@ -7414,7 +7414,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
registrar.set_concurrent();
registrar.set_concurrent_barrier();
// registrar.set_concurrent_barrier();
if (pre_register) {
Runtime::preregister_task_variant<ParallelIdentity::peft_bwd_task>(
registrar, "ParallelIdentity PEFT Backward Task");
Expand Down Expand Up @@ -7655,7 +7655,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
registrar.set_concurrent();
registrar.set_concurrent_barrier();
// registrar.set_concurrent_barrier();
if (pre_register) {
Runtime::preregister_task_variant<ncclComm_t, Op::init_nccl_comms_task>(
registrar, "NCCL Init Communicators Task", 111 /*variant ID*/);
Expand All @@ -7673,7 +7673,7 @@ void register_flexflow_internal_tasks(Runtime *runtime,
registrar.add_constraint(ProcessorConstraint(Processor::TOC_PROC));
registrar.set_leaf();
registrar.set_concurrent();
registrar.set_concurrent_barrier();
// registrar.set_concurrent_barrier();
if (pre_register) {
Runtime::preregister_task_variant<Op::finish_nccl_comms_task>(
registrar, "NCCL Finish Communicators Task", 111 /*variant ID*/);
Expand Down

0 comments on commit 222ce02

Please sign in to comment.