Merge branch 'inference' into background_worker

zwang86 authored Jan 12, 2024
2 parents 4b4d1a9 + 9c85a4f commit 70212f6
Showing 11 changed files with 65 additions and 14 deletions.
2 changes: 2 additions & 0 deletions .dockerignore
@@ -17,3 +17,5 @@ python/flexflow/core/legion_cffi_header.py
 /inference/tokenizer/*
 /inference/prompt/*
 /inference/output/*
+
+/tests/inference/python_test_configs/*.json

3 changes: 2 additions & 1 deletion .gitignore
@@ -186,4 +186,5 @@ gpt_tokenizer
 # pip version
 python/flexflow/version.txt
 
-inference_tensors
+inference_tensors
+tests/inference/python_test_configs/*.json

3 changes: 2 additions & 1 deletion include/flexflow/model.h
@@ -1091,14 +1091,15 @@ class FFModel {
   std::unordered_map<Op *, std::vector<std::pair<Op *, int>>>
       get_bwd_edge_map() const;
 
-  // Internal funcitons
+  // Internal functions
   Legion::IndexSpace get_or_create_task_is(ParallelConfig const &pc);
   Legion::IndexSpace get_or_create_task_is(MachineView const &view);
   Legion::IndexSpace get_or_create_task_is(Legion::Domain const &domain);
   Legion::IndexSpace get_or_create_task_is(const ParallelTensor);
   Legion::IndexSpace get_task_is(Legion::Domain const &domain) const;
   Legion::IndexSpace get_task_is(ParallelConfig const &pc) const;
   Legion::IndexSpace get_task_is(MachineView const &view) const;
+  bool is_mlp_block(int layer_idx) const;
   void create_operators_from_layers();
   Op *create_operator_from_layer(Layer *layer,
                                  std::vector<ParallelTensor> const &inputs);

5 changes: 2 additions & 3 deletions inference/models/opt.cc
@@ -196,7 +196,7 @@ void OPT::create_opt_model(FFModel &ff,
     Tensor fc1 =
         ff.dense(final_norm,
                  opt_config.ffn_dim,
-                 AC_MODE_NONE,
+                 AC_MODE_RELU,
                  true,
                  DT_NONE,
                  nullptr,
@@ -205,8 +205,7 @@ void OPT::create_opt_model(FFModel &ff,
                  REG_MODE_NONE,
                  0.0f,
                  std::string("layers_" + std::to_string(i) + "_fc1").c_str());
-    Tensor activation = ff.relu(fc1, false);
-    fc2 = ff.dense(activation,
+    fc2 = ff.dense(fc1,
                    opt_config.hidden_size,
                    AC_MODE_NONE,
                    true,

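The change above folds the activation into the first dense layer: instead of building fc1 and then a separate ff.relu operator, fc1 is created with AC_MODE_RELU so the linear operator applies ReLU itself (and can dispatch the fused CUDA kernel added in linear_kernels.cu below). A plain-C++ sketch of what the fused setting computes, for illustration only (not FlexFlow code):

// dense(..., AC_MODE_RELU) computes relu(W*x + b) in one step, which is
// what the removed ff.relu() call used to do as a separate operator.
#include <algorithm>
#include <cstdio>
#include <vector>

std::vector<float> dense_relu(std::vector<float> const &x,
                              std::vector<float> const &W, // out_dim x in_dim
                              std::vector<float> const &b,
                              int in_dim,
                              int out_dim) {
  std::vector<float> y(out_dim);
  for (int o = 0; o < out_dim; o++) {
    float acc = b[o];
    for (int i = 0; i < in_dim; i++) {
      acc += W[o * in_dim + i] * x[i];
    }
    y[o] = std::max(acc, 0.0f); // activation fused into the layer
  }
  return y;
}

int main() {
  // x = (1, -2), W rows = (1, 1) and (-1, -1), zero bias
  auto y = dense_relu({1.0f, -2.0f}, {1, 1, -1, -1}, {0, 0}, 2, 2);
  printf("%g %g\n", y[0], y[1]); // prints "0 1": the -1 pre-activation is clamped
  return 0;
}
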
2 changes: 1 addition & 1 deletion python/flexflow/core/flexflow_cffi.py
@@ -56,7 +56,7 @@ def get_c_name(name):
     if name is None:
         return ffi.NULL
     else:
-        return ffi.new("char[]", name.encode("ascii"))
+        return ffi.new("char[]", name.encode("utf-8"))


 def get_datatype_size(datatype):

5 changes: 2 additions & 3 deletions python/flexflow/serve/models/opt.py
@@ -216,13 +216,12 @@ def build_model(self, max_tokens_per_batch):
             fc1 = ffmodel.dense(
                 ff_norm,
                 self.opt_config.ffn_dim,
-                ActiMode.AC_MODE_NONE,
+                ActiMode.AC_MODE_RELU,
                 True,
                 name=f"layers_{i}_fc1",
             )
-            activation = ffmodel.relu(fc1, False)
             fc2 = ffmodel.dense(
-                activation,
+                fc1,
                 self.opt_config.hidden_size,
                 ActiMode.AC_MODE_NONE,
                 True,

6 changes: 5 additions & 1 deletion src/c/flexflow_c.cc
@@ -1597,7 +1597,11 @@ flexflow_generation_result_t
   handle->generate(prompts, max_seq_length);
   DEBUG_PRINT(
       "[Model] generate %p %s %i", handle, text_str.c_str(), max_seq_length);
-  assert(results[0].output_tokens.size() <= max_seq_length);
+  // If the prompt exceeds max seq len, check that we return the prompt with no
+  // additional token. Otherwise, check that the output does not exceed the max
+  // sequence length.
+  assert(results[0].output_tokens.size() <= max_seq_length ||
+         results[0].output_tokens.size() == results[0].input_tokens.size());
   output_length_and_tokens[0] = results[0].output_tokens.size();
   std::copy(results[0].output_tokens.begin(),
             results[0].output_tokens.end(),

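The relaxed assertion accepts two outcomes: normal generation, where the output stays within max_seq_length, and the edge case where the prompt alone already exceeds max_seq_length, in which case the request returns the prompt verbatim and the output token count equals the input token count. A standalone sketch of that predicate, with hypothetical names, for illustration:

// Case (1): generation stayed within the limit. Case (2): the prompt was
// already longer than the limit, so it is echoed back unchanged and
// output == input in length.
#include <cassert>
#include <cstddef>

bool generation_result_ok(std::size_t input_tokens,
                          std::size_t output_tokens,
                          std::size_t max_seq_length) {
  return output_tokens <= max_seq_length || output_tokens == input_tokens;
}

int main() {
  assert(generation_result_ok(10, 128, 256));   // normal growth within limit
  assert(generation_result_ok(600, 600, 512));  // oversized prompt echoed back
  assert(!generation_result_ok(10, 600, 512));  // anything else would be a bug
  return 0;
}
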
22 changes: 22 additions & 0 deletions src/ops/kernels/linear_kernels.cu
@@ -252,6 +252,18 @@ Parameter* Linear::get_parameter(int index)
 */
 namespace Internal {
 
+template <typename DT>
+__global__ void AddBiasWithReLU(DT *output_ptr,
+                                DT const *bias_ptr,
+                                int out_dim,
+                                int batch_size) {
+  CUDA_KERNEL_LOOP(i, out_dim * batch_size) {
+    int bias_idx = i % out_dim;
+    DT value = output_ptr[i] + bias_ptr[bias_idx];
+    output_ptr[i] = ((float)value > 0.0f) ? value : (DT)0.0f;
+  }
+}
+
 template <typename DT>
 void forward_kernel(LinearMeta const *m,
                     void const *input_ptr,
@@ -343,6 +355,16 @@ void forward_kernel(LinearMeta const *m,
                          CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   // use_bias = True
   if (bias_ptr != NULL) {
+    // fuse bias and relu
+    if (m->activation == AC_MODE_RELU) {
+      int parallelism = out_dim * batch_size;
+      AddBiasWithReLU<<<GET_BLOCKS(parallelism), CUDA_NUM_THREADS, 0, stream>>>(
+          static_cast<DT *>(output_ptr),
+          static_cast<DT const *>(bias_ptr),
+          out_dim,
+          batch_size);
+      return;
+    }
     checkCUDA(cublasGemmEx(m->handle.blas,
                            CUBLAS_OP_T,
                            CUBLAS_OP_N,

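AddBiasWithReLU visits all out_dim * batch_size output elements, adds the per-column bias (bias index i % out_dim), and clamps negatives to zero, so the bias add and the activation that previously ran as separate steps become a single launch. The kernel leans on FlexFlow's CUDA_KERNEL_LOOP, GET_BLOCKS, and CUDA_NUM_THREADS helpers; the definitions in this self-contained sketch are the conventional grid-stride versions, assumed here for illustration rather than copied from the repository:

// Self-contained sketch of the fused bias+ReLU path (assumed helper
// definitions, not FlexFlow's actual ones).
#include <cstdio>
#include <cuda_runtime.h>

constexpr int CUDA_NUM_THREADS = 256;
inline int GET_BLOCKS(int n) {
  return (n + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
#define CUDA_KERNEL_LOOP(i, n)                                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);                 \
       i += blockDim.x * gridDim.x)

template <typename DT>
__global__ void AddBiasWithReLU(DT *output, DT const *bias, int out_dim,
                                int batch_size) {
  CUDA_KERNEL_LOOP(i, out_dim * batch_size) {
    DT value = output[i] + bias[i % out_dim]; // bias repeats every out_dim
    output[i] = ((float)value > 0.0f) ? value : (DT)0.0f;
  }
}

int main() {
  int const out_dim = 4, batch = 2, n = out_dim * batch;
  float h_out[n] = {-1, 2, -3, 4, 5, -6, 7, -8};
  float h_bias[out_dim] = {0.5f, 0.0f, 0.0f, 0.0f};
  float *d_out, *d_bias;
  cudaMalloc(&d_out, n * sizeof(float));
  cudaMalloc(&d_bias, out_dim * sizeof(float));
  cudaMemcpy(d_out, h_out, sizeof(h_out), cudaMemcpyHostToDevice);
  cudaMemcpy(d_bias, h_bias, sizeof(h_bias), cudaMemcpyHostToDevice);
  AddBiasWithReLU<float><<<GET_BLOCKS(n), CUDA_NUM_THREADS>>>(d_out, d_bias,
                                                              out_dim, batch);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; i++) {
    printf("%g ", h_out[i]); // expect: 0 2 0 4 5.5 0 7 0
  }
  printf("\n");
  cudaFree(d_out);
  cudaFree(d_bias);
  return 0;
}
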
27 changes: 24 additions & 3 deletions src/runtime/model.cc
@@ -3236,6 +3236,27 @@ Op *FFModel::create_operator_from_layer(
   }
 }
 
+bool FFModel::is_mlp_block(int layer_idx) const {
+  auto const &l = layers[layer_idx];
+  // standard opt relu
+  if (l->op_type == OP_LINEAR && layer_idx >= 2 &&
+      layers[layer_idx - 1]->op_type == OP_RELU &&
+      layers[layer_idx - 2]->op_type == OP_LINEAR) {
+    return true;
+  }
+  // mlp layer with relu embedded in first dense layer
+  if (l->op_type == OP_LINEAR && layer_idx >= 1 &&
+      layers[layer_idx - 1]->op_type == OP_LINEAR) {
+    long long value;
+    layers[layer_idx - 1]->get_int_property("activation", value);
+    ActiMode activation = (ActiMode)value;
+    if (activation == AC_MODE_RELU) {
+      return true;
+    }
+  }
+  return false;
+}
+
 void FFModel::create_operators_from_layers() {
   std::map<const Tensor, ParallelTensor> tensors_to_parallel_tensors;
   // for (auto const &l : layers) {
@@ -3280,9 +3301,9 @@ void FFModel::create_operators_from_layers() {
         config.tensor_parallelism_degree > 1 &&
         (l->op_type == OP_INC_MULTIHEAD_SELF_ATTENTION ||
          l->op_type == OP_TREE_INC_MULTIHEAD_SELF_ATTENTION ||
-         (l->op_type == OP_LINEAR && layer_idx >= 2 &&
-          layers[layer_idx - 1]->op_type == OP_RELU &&
-          layers[layer_idx - 2]->op_type == OP_LINEAR) ||
+         // mlp layer
+         is_mlp_block(layer_idx) ||
+         // llama mlp layer
          (l->op_type == OP_LINEAR && layer_idx >= 2 &&
           layers[layer_idx - 1]->op_type == OP_GELU &&
           layers[layer_idx - 2]->op_type == OP_LINEAR) ||

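is_mlp_block() now treats two layer shapes as an MLP block when deciding what to partition under tensor parallelism: the original Linear -> ReLU -> Linear sequence, and the Linear(activation=ReLU) -> Linear sequence produced by the fused OPT model above. A self-contained sketch of the pattern match using stand-in enums (not the FlexFlow types):

// Pattern A is the pre-fusion OPT MLP; pattern B is the fused form this
// commit introduces.
#include <cassert>
#include <vector>

enum OpType { OP_LINEAR, OP_RELU, OP_OTHER };
enum ActiMode { AC_MODE_NONE, AC_MODE_RELU };

struct Layer {
  OpType op_type;
  ActiMode activation = AC_MODE_NONE; // meaningful for OP_LINEAR only
};

bool is_mlp_block(std::vector<Layer> const &layers, int i) {
  if (layers[i].op_type != OP_LINEAR) {
    return false;
  }
  // Pattern A: Linear -> ReLU -> Linear
  if (i >= 2 && layers[i - 1].op_type == OP_RELU &&
      layers[i - 2].op_type == OP_LINEAR) {
    return true;
  }
  // Pattern B: Linear(activation = ReLU) -> Linear
  if (i >= 1 && layers[i - 1].op_type == OP_LINEAR &&
      layers[i - 1].activation == AC_MODE_RELU) {
    return true;
  }
  return false;
}

int main() {
  std::vector<Layer> a = {{OP_LINEAR}, {OP_RELU}, {OP_LINEAR}};
  std::vector<Layer> b = {{OP_LINEAR, AC_MODE_RELU}, {OP_LINEAR}};
  assert(is_mlp_block(a, 2));  // fc2 of the unfused block
  assert(is_mlp_block(b, 1));  // fc2 of the fused block
  assert(!is_mlp_block(a, 1)); // the ReLU itself is not the block's tail
  return 0;
}
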
3 changes: 2 additions & 1 deletion src/runtime/request_manager.cc
@@ -45,7 +45,8 @@ std::string LoadBytesFromFile(std::string const &path) {
 
 RequestManager::RequestManager()
     : request_manager_status(INITIALIZED), verbose(false),
-      next_available_guid(1000000), num_processed_requests(0) {
+      next_available_guid(1000000), num_processed_requests(0),
+      total_request_run_time(0.0f) {
   // The following config parameters are set
   // during ffmodel.compile()
   // Initialize them to -1 to make sure no one

1 change: 1 addition & 0 deletions tests/inference/python_inference_tests.sh
@@ -6,6 +6,7 @@ set -e
 cd "${BASH_SOURCE[0]%/*}"
 
 # Generate test configs
+rm -rf python_test_configs/*.json
 python python_test_configs/generate_configs.py
 
 # Run all tests
