diff --git a/include/flexflow/model.h b/include/flexflow/model.h
index 6f805e21bd..75b1dbcbe9 100644
--- a/include/flexflow/model.h
+++ b/include/flexflow/model.h
@@ -1090,7 +1090,7 @@ class FFModel {
   std::unordered_map>>
       get_bwd_edge_map() const;
-  // Internal funcitons
+  // Internal functions
   Legion::IndexSpace get_or_create_task_is(ParallelConfig const &pc);
   Legion::IndexSpace get_or_create_task_is(MachineView const &view);
   Legion::IndexSpace get_or_create_task_is(Legion::Domain const &domain);
@@ -1098,6 +1098,7 @@ class FFModel {
   Legion::IndexSpace get_task_is(Legion::Domain const &domain) const;
   Legion::IndexSpace get_task_is(ParallelConfig const &pc) const;
   Legion::IndexSpace get_task_is(MachineView const &view) const;
+  bool is_mlp_block(int layer_idx) const;
   void create_operators_from_layers();
   Op *create_operator_from_layer(Layer *layer,
                                  std::vector<ParallelTensor> const &inputs);
diff --git a/inference/models/opt.cc b/inference/models/opt.cc
index 0279f83239..e260f8fa36 100644
--- a/inference/models/opt.cc
+++ b/inference/models/opt.cc
@@ -196,7 +196,7 @@ void OPT::create_opt_model(FFModel &ff,
     Tensor fc1 =
         ff.dense(final_norm,
                  opt_config.ffn_dim,
-                 AC_MODE_NONE,
+                 AC_MODE_RELU,
                  true,
                  DT_NONE,
                  nullptr,
@@ -205,8 +205,7 @@ void OPT::create_opt_model(FFModel &ff,
                  REG_MODE_NONE,
                  0.0f,
                  std::string("layers_" + std::to_string(i) + "_fc1").c_str());
-    Tensor activation = ff.relu(fc1, false);
-    fc2 = ff.dense(activation,
+    fc2 = ff.dense(fc1,
                    opt_config.hidden_size,
                    AC_MODE_NONE,
                    true,
diff --git a/python/flexflow/serve/models/opt.py b/python/flexflow/serve/models/opt.py
index dfd1cde7d4..dd36fa6592 100644
--- a/python/flexflow/serve/models/opt.py
+++ b/python/flexflow/serve/models/opt.py
@@ -216,13 +216,12 @@ def build_model(self, max_tokens_per_batch):
            fc1 = ffmodel.dense(
                ff_norm,
                self.opt_config.ffn_dim,
-                ActiMode.AC_MODE_NONE,
+                ActiMode.AC_MODE_RELU,
                True,
                name=f"layers_{i}_fc1",
            )
-            activation = ffmodel.relu(fc1, False)
            fc2 = ffmodel.dense(
-                activation,
+                fc1,
                self.opt_config.hidden_size,
                ActiMode.AC_MODE_NONE,
                True,
diff --git a/src/ops/kernels/linear_kernels.cu b/src/ops/kernels/linear_kernels.cu
index 9373c2fb2f..c30c9f71c1 100644
--- a/src/ops/kernels/linear_kernels.cu
+++ b/src/ops/kernels/linear_kernels.cu
@@ -252,6 +252,18 @@ Parameter* Linear::get_parameter(int index)
 */
 namespace Internal {
 
+template <typename DT>
+__global__ void AddBiasWithReLU(DT *output_ptr,
+                                DT const *bias_ptr,
+                                int out_dim,
+                                int batch_size) {
+  CUDA_KERNEL_LOOP(i, out_dim * batch_size) {
+    int bias_idx = i % out_dim;
+    DT value = output_ptr[i] + bias_ptr[bias_idx];
+    output_ptr[i] = ((float)value > 0.0f) ? value : (DT)0.0f;
+  }
+}
+
 template <typename DT>
 void forward_kernel(LinearMeta const *m,
                     void const *input_ptr,
@@ -343,6 +355,16 @@ void forward_kernel(LinearMeta const *m,
                          CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // use_bias = True
   if (bias_ptr != NULL) {
+    // fuse bias and relu
+    if (m->activation == AC_MODE_RELU) {
+      int parallelism = out_dim * batch_size;
+      AddBiasWithReLU<<<GET_BLOCKS(parallelism), CUDA_NUM_THREADS, 0, stream>>>(
+          static_cast<DT *>(output_ptr),
+          static_cast<DT const *>(bias_ptr),
+          out_dim,
+          batch_size);
+      return;
+    }
     checkCUDA(cublasGemmEx(m->handle.blas,
                            CUBLAS_OP_T,
                            CUBLAS_OP_N,
diff --git a/src/runtime/model.cc b/src/runtime/model.cc
index 76bed36bda..4270515224 100644
--- a/src/runtime/model.cc
+++ b/src/runtime/model.cc
@@ -3236,6 +3236,27 @@ Op *FFModel::create_operator_from_layer(
   }
 }
 
+bool FFModel::is_mlp_block(int layer_idx) const {
+  auto const &l = layers[layer_idx];
+  // standard opt relu
+  if (l->op_type == OP_LINEAR && layer_idx >= 2 &&
+      layers[layer_idx - 1]->op_type == OP_RELU &&
+      layers[layer_idx - 2]->op_type == OP_LINEAR) {
+    return true;
+  }
+  // mlp layer with relu embedded in first dense layer
+  if (l->op_type == OP_LINEAR && layer_idx >= 1 &&
+      layers[layer_idx - 1]->op_type == OP_LINEAR) {
+    long long value;
+    layers[layer_idx - 1]->get_int_property("activation", value);
+    ActiMode activation = (ActiMode)value;
+    if (activation == AC_MODE_RELU) {
+      return true;
+    }
+  }
+  return false;
+}
+
 void FFModel::create_operators_from_layers() {
   std::map<const Tensor, ParallelTensor> tensors_to_parallel_tensors;
   // for (auto const &l : layers) {
@@ -3280,9 +3301,9 @@ void FFModel::create_operators_from_layers() {
        config.tensor_parallelism_degree > 1 &&
        (l->op_type == OP_INC_MULTIHEAD_SELF_ATTENTION ||
         l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION ||
-         (l->op_type == OP_LINEAR && layer_idx >= 2 &&
-          layers[layer_idx - 1]->op_type == OP_RELU &&
-          layers[layer_idx - 2]->op_type == OP_LINEAR) ||
+         // mlp layer
+         is_mlp_block(layer_idx) ||
+         // llama mlp layer
         (l->op_type == OP_LINEAR && layer_idx >= 2 &&
          layers[layer_idx - 1]->op_type == OP_GELU &&
          layers[layer_idx - 2]->op_type == OP_LINEAR) ||