Fuse bias and relu in OPT #1265

Merged
4 commits, merged Jan 10, 2024
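This PR folds OPT's feed-forward activation into the preceding dense layer: fc1 is now created with AC_MODE_RELU instead of AC_MODE_NONE plus a separate relu operator, and the linear forward kernel applies the bias add and the ReLU in a single CUDA kernel (AddBiasWithReLU) rather than a cuBLAS bias GEMM followed by a standalone ReLU pass. As a minimal host-side sketch (not part of the diff), the fused epilogue computes the following per element, assuming the out_dim-contiguous output layout implied by bias_idx = i % out_dim:

    // Reference semantics of the fused bias + ReLU epilogue (float-only sketch).
    void add_bias_relu_reference(float *output, float const *bias,
                                 int out_dim, int batch_size) {
      for (int i = 0; i < out_dim * batch_size; i++) {
        float value = output[i] + bias[i % out_dim]; // per-output-column bias
        output[i] = value > 0.0f ? value : 0.0f;     // ReLU applied in the same pass
      }
    }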
include/flexflow/model.h: 3 changes (2 additions, 1 deletion)
@@ -1090,14 +1090,15 @@ class FFModel {
   std::unordered_map<Op *, std::vector<std::pair<Op *, int>>>
       get_bwd_edge_map() const;
 
-  // Internal funcitons
+  // Internal functions
   Legion::IndexSpace get_or_create_task_is(ParallelConfig const &pc);
   Legion::IndexSpace get_or_create_task_is(MachineView const &view);
   Legion::IndexSpace get_or_create_task_is(Legion::Domain const &domain);
   Legion::IndexSpace get_or_create_task_is(const ParallelTensor);
   Legion::IndexSpace get_task_is(Legion::Domain const &domain) const;
   Legion::IndexSpace get_task_is(ParallelConfig const &pc) const;
   Legion::IndexSpace get_task_is(MachineView const &view) const;
+  bool is_mlp_block(int layer_idx) const;
   void create_operators_from_layers();
   Op *create_operator_from_layer(Layer *layer,
                                  std::vector<ParallelTensor> const &inputs);
inference/models/opt.cc: 5 changes (2 additions, 3 deletions)
@@ -196,7 +196,7 @@ void OPT::create_opt_model(FFModel &ff,
     Tensor fc1 =
         ff.dense(final_norm,
                  opt_config.ffn_dim,
-                 AC_MODE_NONE,
+                 AC_MODE_RELU,
                  true,
                  DT_NONE,
                  nullptr,
@@ -205,8 +205,7 @@ void OPT::create_opt_model(FFModel &ff,
                  REG_MODE_NONE,
                  0.0f,
                  std::string("layers_" + std::to_string(i) + "_fc1").c_str());
-    Tensor activation = ff.relu(fc1, false);
-    fc2 = ff.dense(activation,
+    fc2 = ff.dense(fc1,
                    opt_config.hidden_size,
                    AC_MODE_NONE,
                    true,
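With the ReLU folded into fc1 via AC_MODE_RELU, the intermediate activation tensor and the separate ff.relu call are no longer needed; the activation is instead applied by the fused bias epilogue added to the linear kernel further down in linear_kernels.cu.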
python/flexflow/serve/models/opt.py: 5 changes (2 additions, 3 deletions)
@@ -216,13 +216,12 @@ def build_model(self, max_tokens_per_batch):
             fc1 = ffmodel.dense(
                 ff_norm,
                 self.opt_config.ffn_dim,
-                ActiMode.AC_MODE_NONE,
+                ActiMode.AC_MODE_RELU,
                 True,
                 name=f"layers_{i}_fc1",
             )
-            activation = ffmodel.relu(fc1, False)
             fc2 = ffmodel.dense(
-                activation,
+                fc1,
                 self.opt_config.hidden_size,
                 ActiMode.AC_MODE_NONE,
                 True,
src/ops/kernels/linear_kernels.cu: 22 changes (22 additions, 0 deletions)
@@ -252,6 +252,18 @@ Parameter* Linear::get_parameter(int index)
 */
 namespace Internal {
 
+template <typename DT>
+__global__ void AddBiasWithReLU(DT *output_ptr,
+                                DT const *bias_ptr,
+                                int out_dim,
+                                int batch_size) {
+  CUDA_KERNEL_LOOP(i, out_dim * batch_size) {
+    int bias_idx = i % out_dim;
+    DT value = output_ptr[i] + bias_ptr[bias_idx];
+    output_ptr[i] = ((float)value > 0.0f) ? value : (DT)0.0f;
+  }
+}
+
 template <typename DT>
 void forward_kernel(LinearMeta const *m,
                     void const *input_ptr,
@@ -343,6 +355,16 @@ void forward_kernel(LinearMeta const *m,
                          CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // use_bias = True
   if (bias_ptr != NULL) {
+    // fuse bias and relu
+    if (m->activation == AC_MODE_RELU) {
+      int parallelism = out_dim * batch_size;
+      AddBiasWithReLU<<<GET_BLOCKS(parallelism), CUDA_NUM_THREADS, 0, stream>>>(
+          static_cast<DT *>(output_ptr),
+          static_cast<DT const *>(bias_ptr),
+          out_dim,
+          batch_size);
+      return;
+    }
     checkCUDA(cublasGemmEx(m->handle.blas,
                            CUBLAS_OP_T,
                            CUBLAS_OP_N,
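For reference, a standalone sanity-check sketch (not part of this PR) comparing the fused path against an unfused bias add followed by ReLU. It re-declares the kernel body with a plain grid-stride loop instead of FlexFlow's CUDA_KERNEL_LOOP/GET_BLOCKS helpers so it compiles on its own; the kernel and variable names here are illustrative only. Exact float comparison is acceptable because both paths perform the same single add.

    #include <cuda_runtime.h>
    #include <cstdio>
    #include <vector>

    // Same element-wise logic as AddBiasWithReLU, written with a plain
    // grid-stride loop so this sketch compiles without FlexFlow headers.
    __global__ void add_bias_relu(float *output, float const *bias,
                                  int out_dim, int batch_size) {
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < out_dim * batch_size;
           i += blockDim.x * gridDim.x) {
        float value = output[i] + bias[i % out_dim];
        output[i] = value > 0.0f ? value : 0.0f;
      }
    }

    int main() {
      int const out_dim = 8, batch_size = 4, n = out_dim * batch_size;
      std::vector<float> host_out(n), host_bias(out_dim);
      for (int i = 0; i < n; i++) host_out[i] = (i % 5) - 2.0f; // mix of signs
      for (int j = 0; j < out_dim; j++) host_bias[j] = 0.5f * j - 1.0f;

      float *d_out, *d_bias;
      cudaMalloc(&d_out, n * sizeof(float));
      cudaMalloc(&d_bias, out_dim * sizeof(float));
      cudaMemcpy(d_out, host_out.data(), n * sizeof(float), cudaMemcpyHostToDevice);
      cudaMemcpy(d_bias, host_bias.data(), out_dim * sizeof(float),
                 cudaMemcpyHostToDevice);

      // Fused path: bias add and ReLU in one kernel launch.
      add_bias_relu<<<(n + 255) / 256, 256>>>(d_out, d_bias, out_dim, batch_size);

      std::vector<float> fused(n);
      cudaMemcpy(fused.data(), d_out, n * sizeof(float), cudaMemcpyDeviceToHost);

      // Unfused reference on the host: bias add, then ReLU.
      int mismatches = 0;
      for (int i = 0; i < n; i++) {
        float expected = host_out[i] + host_bias[i % out_dim];
        expected = expected > 0.0f ? expected : 0.0f;
        if (expected != fused[i]) {
          mismatches++;
        }
      }
      std::printf("mismatches: %d\n", mismatches);
      cudaFree(d_out);
      cudaFree(d_bias);
      return mismatches == 0 ? 0 : 1;
    }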
src/runtime/model.cc: 27 changes (24 additions, 3 deletions)
@@ -3236,6 +3236,27 @@ Op *FFModel::create_operator_from_layer(
   }
 }
 
+bool FFModel::is_mlp_block(int layer_idx) const {
+  auto const &l = layers[layer_idx];
+  // standard opt relu
+  if (l->op_type == OP_LINEAR && layer_idx >= 2 &&
+      layers[layer_idx - 1]->op_type == OP_RELU &&
+      layers[layer_idx - 2]->op_type == OP_LINEAR) {
+    return true;
+  }
+  // mlp layer with relu embedded in first dense layer
+  if (l->op_type == OP_LINEAR && layer_idx >= 1 &&
+      layers[layer_idx - 1]->op_type == OP_LINEAR) {
+    long long value;
+    layers[layer_idx - 1]->get_int_property("activation", value);
+    ActiMode activation = (ActiMode)value;
+    if (activation == AC_MODE_RELU) {
+      return true;
+    }
+  }
+  return false;
+}
+
 void FFModel::create_operators_from_layers() {
   std::map<const Tensor, ParallelTensor> tensors_to_parallel_tensors;
   // for (auto const &l : layers) {
@@ -3280,9 +3301,9 @@ void FFModel::create_operators_from_layers() {
         config.tensor_parallelism_degree > 1 &&
         (l->op_type == OP_INC_MULTIHEAD_SELF_ATTENTION ||
          l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION ||
-         (l->op_type == OP_LINEAR && layer_idx >= 2 &&
-          layers[layer_idx - 1]->op_type == OP_RELU &&
-          layers[layer_idx - 2]->op_type == OP_LINEAR) ||
+         // mlp layer
+         is_mlp_block(layer_idx) ||
+         // llama mlp layer
          (l->op_type == OP_LINEAR && layer_idx >= 2 &&
           layers[layer_idx - 1]->op_type == OP_GELU &&
           layers[layer_idx - 2]->op_type == OP_LINEAR) ||
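Since the model builders above no longer emit a standalone relu layer, the old Linear -> ReLU -> Linear pattern match in create_operators_from_layers would stop firing for OPT; is_mlp_block keeps both shapes eligible for the tensor-parallel MLP treatment. A self-contained sketch of the two patterns it accepts, using mock types (OpType, ActiMode, and LayerDesc here are stand-ins for illustration, not FlexFlow's real classes):

    #include <vector>

    // Mock stand-ins for FlexFlow's layer metadata (illustration only).
    enum OpType { OP_LINEAR, OP_RELU, OP_OTHER };
    enum ActiMode { AC_MODE_NONE, AC_MODE_RELU };
    struct LayerDesc {
      OpType op_type;
      ActiMode activation; // what get_int_property("activation", ...) reports
    };

    // Mirrors the two patterns FFModel::is_mlp_block accepts:
    //   (1) Linear -> ReLU -> Linear          (graphs built before this PR)
    //   (2) Linear(AC_MODE_RELU) -> Linear    (graphs built after this PR)
    bool is_mlp_block(std::vector<LayerDesc> const &layers, int layer_idx) {
      LayerDesc const &l = layers[layer_idx];
      if (l.op_type == OP_LINEAR && layer_idx >= 2 &&
          layers[layer_idx - 1].op_type == OP_RELU &&
          layers[layer_idx - 2].op_type == OP_LINEAR) {
        return true; // pattern (1)
      }
      if (l.op_type == OP_LINEAR && layer_idx >= 1 &&
          layers[layer_idx - 1].op_type == OP_LINEAR &&
          layers[layer_idx - 1].activation == AC_MODE_RELU) {
        return true; // pattern (2)
      }
      return false;
    }

    int main() {
      // fc1 / relu / fc2, as OPT was built before this PR:
      std::vector<LayerDesc> old_graph = {{OP_LINEAR, AC_MODE_NONE},
                                          {OP_RELU, AC_MODE_NONE},
                                          {OP_LINEAR, AC_MODE_NONE}};
      // fc1 with fused ReLU / fc2, as OPT is built now:
      std::vector<LayerDesc> new_graph = {{OP_LINEAR, AC_MODE_RELU},
                                          {OP_LINEAR, AC_MODE_NONE}};
      return (is_mlp_block(old_graph, 2) && is_mlp_block(new_graph, 1)) ? 0 : 1;
    }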