Skip to content

Commit

Permalink
fix deadlock?
Browse files (browse the repository at this point in the history)
  • Loading branch information
goliaro committed Oct 18, 2024
1 parent d82cd13 commit c224f31
Show file tree
Hide file tree
Showing 15 changed files with 119 additions and 51 deletions.
11 changes: 6 additions & 5 deletions cmake/nccl.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ if(NCCL_LIBRARY AND NCCL_INCLUDE_DIR)
string(REGEX MATCH "([0-9]+)" NCCL_MAJOR ${NCCL_VERSION_DEFINES})
string(REGEX MATCH "([0-9]+)" NCCL_MINOR ${NCCL_VERSION_DEFINES2})
set(NCCL_VERSION "${NCCL_MAJOR}.${NCCL_MINOR}")
if(NCCL_VERSION VERSION_LESS 2.23)
set(NCCL_OLD TRUE)
else()
set(NCCL_OLD FALSE)
endif()
set(NCCL_OLD FALSE)
# if(NCCL_VERSION VERSION_LESS 2.23)
# set(NCCL_OLD TRUE)
# else()
# set(NCCL_OLD FALSE)
# endif()
message(STATUS "Found NCCL version: ${NCCL_VERSION}")
else()
message(WARNING "NCCL header not found, unable to determine version")
Expand Down
24 changes: 12 additions & 12 deletions docker/flexflow-environment/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,18 @@ ENV CUDA_DIR /usr/local/cuda
ARG FF_GPU_BACKEND "cuda"

# Update NCCL if FF_GPU_BACKEND is cuda
RUN /bin/bash -c 'if [ "$FF_GPU_BACKEND" = "cuda" ]; then \
echo "FF_GPU_BACKEND: ${FF_GPU_BACKEND}. Updating NCCL"; \
ubuntu_version=$(lsb_release -rs); \
ubuntu_version=${ubuntu_version//./}; \
wget "https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${ubuntu_version}/x86_64/cuda-keyring_1.0-1_all.deb"; \
DEBIAN_FRONTEND=noninteractive dpkg -i cuda-keyring_1.0-1_all.deb; \
DEBIAN_FRONTEND=noninteractive apt-get update -y --allow-change-held-packages; \
rm -f cuda-keyring_1.0-1_all.deb; \
DEBIAN_FRONTEND=noninteractive apt install -y --allow-change-held-packages libnccl2 libnccl-dev; \
else \
echo "FF_GPU_BACKEND: ${FF_GPU_BACKEND}. Skipping updating NCCL"; \
fi'
# RUN /bin/bash -c 'if [ "$FF_GPU_BACKEND" = "cuda" ]; then \
# echo "FF_GPU_BACKEND: ${FF_GPU_BACKEND}. Updating NCCL"; \
# ubuntu_version=$(lsb_release -rs); \
# ubuntu_version=${ubuntu_version//./}; \
# wget "https://developer.download.nvidia.com/compute/cuda/repos/ubuntu${ubuntu_version}/x86_64/cuda-keyring_1.0-1_all.deb"; \
# DEBIAN_FRONTEND=noninteractive dpkg -i cuda-keyring_1.0-1_all.deb; \
# DEBIAN_FRONTEND=noninteractive apt-get update -y --allow-change-held-packages; \
# rm -f cuda-keyring_1.0-1_all.deb; \
# DEBIAN_FRONTEND=noninteractive apt install -y --allow-change-held-packages libnccl2 libnccl-dev; \
# else \
# echo "FF_GPU_BACKEND: ${FF_GPU_BACKEND}. Skipping updating NCCL"; \
# fi'

# Install hip dependencies if FF_GPU_BACKEND is hip_cuda or hip_rocm
# Note that amd's docs say to also install the `hip-runtime-nvidia` package. This
Expand Down
11 changes: 8 additions & 3 deletions include/flexflow/ops/kernels/lora_linear_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
#include "flexflow/ops/lora_linear.h"

namespace FlexFlow {

using Legion::Context;
using Legion::Runtime;
struct LoraLinearWeight {
// weights
void *w0_ptr, *w1_ptr;
Expand Down Expand Up @@ -46,7 +47,9 @@ void inference_kernel_wrapper(LoraLinearMeta *m,
BatchConfig const *bc,
GenericTensorAccessorR const &input,
GenericTensorAccessorW const &output);
void peft_bwd_kernel_wrapper(LoraLinearMeta *m,
void peft_bwd_kernel_wrapper(Context ctx,
Runtime *runtime,
LoraLinearMeta *m,
BatchConfig const *bc,
GenericTensorAccessorW const &input_grad,
GenericTensorAccessorR const &output_grad);
Expand All @@ -63,7 +66,9 @@ void inference_kernel(LoraLinearMeta *m,
int out_dim,
ffStream_t stream);
template <typename DT>
void peft_bwd_kernel(LoraLinearMeta *m,
void peft_bwd_kernel(Context ctx,
Runtime *runtime,
LoraLinearMeta *m,
BatchConfig const *bc,
DT *input_grad_ptr,
DT const *output_grad_ptr,
Expand Down
11 changes: 8 additions & 3 deletions include/flexflow/optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
#include "legion.h"

namespace FlexFlow {

using Legion::Context;
using Legion::Runtime;
class FFModel;
class OpMeta;

Expand Down Expand Up @@ -60,7 +61,9 @@ class SGDOptimizer : public Optimizer {
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void nccl_update_task_gpu(SGDOptimizer const *op,
static void nccl_update_task_gpu(Context ctx,
Runtime *runtime,
SGDOptimizer const *op,
OpMeta const *meta,
float const *w_grad_ptr,
size_t size,
Expand Down Expand Up @@ -103,7 +106,9 @@ class AdamOptimizer : public Optimizer {
std::vector<Legion::PhysicalRegion> const &regions,
Legion::Context ctx,
Legion::Runtime *runtime);
static void nccl_update_task_gpu(AdamOptimizer const *op,
static void nccl_update_task_gpu(Context ctx,
Runtime *runtime,
AdamOptimizer const *op,
OpMeta const *meta,
float const *w_grad_ptr,
size_t size,
Expand Down
2 changes: 0 additions & 2 deletions src/ops/fused.cc
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,6 @@ void FusedOp::init(FFModel const &ff) {
false /*must*/,
0 /*mapper_id*/,
outputs[0]->machine_view.hash());
launcher.concurrent = true;
FutureMap fm = runtime->execute_index_space(ctx, launcher);
fm.wait_all_results();
switch (domain.get_dim()) {
Expand Down Expand Up @@ -571,7 +570,6 @@ void FusedOp::init_inference(FFModel const &ff,
false /*must*/,
0 /*mapper_id*/,
machine_view_hash);
launcher.concurrent = true;
FutureMap fm = runtime->execute_index_space(ctx, launcher);
fm.wait_all_results();
switch (domain.get_dim()) {
Expand Down
11 changes: 10 additions & 1 deletion src/ops/fused.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,8 +612,10 @@ __host__ void
assert(fused->op_num_inputs[op] == 1);
assert(fused->op_num_outputs[op] == 1);
AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op];
runtime->concurrent_task_barrier(ctx);
Kernels::AllReduce::inference_kernel_wrapper(
m, bc, my_input_accessor[0], my_output_accessor[0]);
runtime->concurrent_task_barrier(ctx);
break;
}
case OP_PARALLEL_IDENTITY: {
Expand Down Expand Up @@ -870,7 +872,12 @@ __host__ void FusedOp::peft_bwd_task(Task const *task,
// since we ``inplace'' the output for LoRA
assert(my_input_grad_accessor[1].ptr == my_output_grad_accessor[0].ptr);
Kernels::LoraLinear::peft_bwd_kernel_wrapper(
m, bc, my_input_grad_accessor[0], my_output_grad_accessor[0]);
ctx,
runtime,
m,
bc,
my_input_grad_accessor[0],
my_output_grad_accessor[0]);
break;
}
case OP_BATCHMATMUL: {
Expand Down Expand Up @@ -1129,8 +1136,10 @@ __host__ void FusedOp::peft_bwd_task(Task const *task,
assert(fused->op_num_inputs[op] == 1);
assert(fused->op_num_outputs[op] == 1);
ParallelIdentityMeta const *m = (ParallelIdentityMeta *)metas->meta[op];
runtime->concurrent_task_barrier(ctx);
Kernels::ParallelIdentity::peft_bwd_kernel_wrapper(
m, bc, my_input_grad_accessor[0], my_output_grad_accessor[0]);
runtime->concurrent_task_barrier(ctx);
break;
}
default: {
Expand Down
11 changes: 10 additions & 1 deletion src/ops/fused.cu
Original file line number Diff line number Diff line change
Expand Up @@ -623,8 +623,10 @@ __host__ void
assert(fused->op_num_inputs[op] == 1);
assert(fused->op_num_outputs[op] == 1);
AllReduceMeta const *m = (AllReduceMeta *)metas->meta[op];
runtime->concurrent_task_barrier(ctx);
Kernels::AllReduce::inference_kernel_wrapper(
m, bc, my_input_accessor[0], my_output_accessor[0]);
runtime->concurrent_task_barrier(ctx);
break;
}
case OP_PARALLEL_IDENTITY: {
Expand Down Expand Up @@ -888,7 +890,12 @@ __host__ void FusedOp::peft_bwd_task(Task const *task,
// since we ``inplace'' the output for LoRA
assert(my_input_grad_accessor[1].ptr == my_output_grad_accessor[0].ptr);
Kernels::LoraLinear::peft_bwd_kernel_wrapper(
m, bc, my_input_grad_accessor[0], my_output_grad_accessor[0]);
ctx,
runtime,
m,
bc,
my_input_grad_accessor[0],
my_output_grad_accessor[0]);
break;
}
case OP_BATCHMATMUL: {
Expand Down Expand Up @@ -1149,8 +1156,10 @@ __host__ void FusedOp::peft_bwd_task(Task const *task,
assert(fused->op_num_inputs[op] == 1);
assert(fused->op_num_outputs[op] == 1);
ParallelIdentityMeta const *m = (ParallelIdentityMeta *)metas->meta[op];
runtime->concurrent_task_barrier(ctx);
Kernels::ParallelIdentity::peft_bwd_kernel_wrapper(
m, bc, my_input_grad_accessor[0], my_output_grad_accessor[0]);
runtime->concurrent_task_barrier(ctx);
break;
}
default: {
Expand Down
20 changes: 15 additions & 5 deletions src/ops/kernels/lora_linear_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ void inference_kernel_wrapper(LoraLinearMeta *m,
}
}

void peft_bwd_kernel_wrapper(LoraLinearMeta *m,
void peft_bwd_kernel_wrapper(Context ctx,
Runtime *runtime,
LoraLinearMeta *m,
BatchConfig const *bc,
GenericTensorAccessorW const &input_grad,
GenericTensorAccessorR const &output_grad) {
Expand All @@ -111,15 +113,19 @@ void peft_bwd_kernel_wrapper(LoraLinearMeta *m,
int in_dim = input_grad.domain.hi()[0] - input_grad.domain.lo()[0] + 1;
int out_dim = output_grad.domain.hi()[0] - output_grad.domain.lo()[0] + 1;
if (m->input_type[0] == DT_FLOAT) {
Internal::peft_bwd_kernel<float>(m,
Internal::peft_bwd_kernel<float>(ctx,
runtime,
m,
bc,
input_grad.get_float_ptr(),
output_grad.get_float_ptr(),
in_dim,
out_dim,
stream);
} else if (m->input_type[0] == DT_HALF) {
Internal::peft_bwd_kernel<half>(m,
Internal::peft_bwd_kernel<half>(ctx,
runtime,
m,
bc,
input_grad.get_half_ptr(),
output_grad.get_half_ptr(),
Expand Down Expand Up @@ -361,7 +367,9 @@ __global__ void sgd_update(size_t count,
}

template <typename DT>
void peft_bwd_kernel(LoraLinearMeta *m,
void peft_bwd_kernel(Context ctx,
Runtime *runtime,
LoraLinearMeta *m,
BatchConfig const *bc,
DT *input_grad_ptr,
DT const *output_grad_ptr,
Expand Down Expand Up @@ -543,13 +551,15 @@ void peft_bwd_kernel(LoraLinearMeta *m,
// and sum first
#ifdef FF_USE_NCCL
ncclDataType_t nccl_data_type = ff_to_nccl_datatype(m->output_type[0]);
checkCUDA(ncclAllReduce(static_cast<DT const *>(weight.w1_grad_ptr),
runtime->concurrent_task_barrier(ctx);
checkNCCL(ncclAllReduce(static_cast<DT const *>(weight.w1_grad_ptr),
static_cast<DT *>(weight.w1_grad_ptr),
w1_num_elements,
nccl_data_type,
ncclSum,
m->handle.ncclComm,
stream));
runtime->concurrent_task_barrier(ctx);
#else
assert(false && "Must enable FF_USE_NCCL to use AllReduce operators");
#endif
Expand Down
3 changes: 1 addition & 2 deletions src/ops/lora_linear.cc
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ void LoraLinear::init_inference(
false /*must*/,
0 /*mapper_id*/,
machine_view_hash);
launcher.concurrent = true;
launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
0 /*projection id*/,
READ_ONLY,
Expand Down Expand Up @@ -1066,7 +1065,7 @@ void LoraLinear::peft_bwd_task(Task const *task,
int out_dim = output_grad.domain.hi()[0] - output_grad.domain.lo()[0] + 1;
// int num_infr_tokens = bc->num_active_infr_tokens();
// int num_peft_tokens = bc->num_active_peft_tokens();
peft_bwd_kernel_wrapper(m, bc, input_grad, output_grad);
peft_bwd_kernel_wrapper(ctx, runtime, m, bc, input_grad, output_grad);

save_peft_weights_if_needed(m, bc, in_dim, out_dim, shard_id);

Expand Down
4 changes: 4 additions & 0 deletions src/parallel_ops/allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ void AllReduce::forward_task(Task const *task,
m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);

assert(input.data_type == output.data_type);
runtime->concurrent_task_barrier(ctx);
forward_kernel_wrapper(m, input, output);
runtime->concurrent_task_barrier(ctx);
}

void AllReduce::backward(FFModel const &ff) {
Expand Down Expand Up @@ -347,7 +349,9 @@ void AllReduce::inference_task(Task const *task,
m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);

assert(input.data_type == output.data_type);
runtime->concurrent_task_barrier(ctx);
inference_kernel_wrapper(m, bc, input, output);
runtime->concurrent_task_barrier(ctx);
if (m->inference_debugging) {
assert(task->index_point.get_dim() == 1);
int shard_id = task->index_point.point_data[0];
Expand Down
5 changes: 4 additions & 1 deletion src/parallel_ops/parallel_identity.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,9 @@ void ParallelIdentity::backward_task(Task const *task,
m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);

assert(input_grad.data_type == output_grad.data_type);
runtime->concurrent_task_barrier(ctx);
backward_kernel_wrapper(m, input_grad, output_grad);
runtime->concurrent_task_barrier(ctx);
}

void ParallelIdentity::init_inference(
Expand All @@ -270,7 +272,6 @@ void ParallelIdentity::init_inference(
false /*must*/,
0 /*mapper_id*/,
machine_view_hash);
launcher.concurrent = true;
launcher.add_region_requirement(RegionRequirement(batch_inputs[0]->part,
0 /*projection id*/,
READ_ONLY,
Expand Down Expand Up @@ -422,7 +423,9 @@ void ParallelIdentity::peft_bwd_task(Task const *task,
m->output_type[0], regions[1], task->regions[1], FID_DATA, ctx, runtime);

assert(input_grad.data_type == output_grad.data_type);
runtime->concurrent_task_barrier(ctx);
peft_bwd_kernel_wrapper(m, bc, input_grad, output_grad);
runtime->concurrent_task_barrier(ctx);
if (m->inference_debugging) {
assert(task->index_point.get_dim() == 1);
int shard_id = task->index_point.point_data[0];
Expand Down
Loading

0 comments on commit c224f31

Please sign in to comment.