Make MAX_BATCH_SIZE, MAX_NUM_TOKENS, MAX_SEQ_LENGTH user-provided input arguments (#1018)

* add max_tokens_per_batch, max_requests_per_batch, and max_sequence_length in RequestManager

* initial implementation

* fix c++ examples

* fix

* .

* more tries to fix

* remove MAX_SEQ_LENGTH

---------

Co-authored-by: zwang86 <46699021+zwang86@users.noreply.github.com>
Co-authored-by: Xinhao Cheng <99570243+xinhaoc@users.noreply.github.com>
3 people authored Oct 1, 2023
1 parent 5919fff commit d9a95ef
Showing 43 changed files with 547 additions and 346 deletions.
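The net effect of the change: the per-batch limits are no longer compile-time constants baked into BatchConfig but runtime settings on the singleton RequestManager, which each driver configures before building a model (the updated incr_decoding.cc below does exactly this). A minimal sketch of the new programmatic flow, using the defaults from incr_decoding.cc; the values themselves are illustrative:

RequestManager *rm = RequestManager::get_request_manager();
rm->set_max_requests_per_batch(8);  // replaces BatchConfig::MAX_NUM_REQUESTS as the effective limit
rm->set_max_tokens_per_batch(128);  // replaces BatchConfig::MAX_NUM_TOKENS
rm->set_max_sequence_length(256);   // replaces the removed BatchConfig::MAX_SEQ_LENGTH

The BatchConfig::MAX_NUM_REQUESTS (64) and MAX_NUM_TOKENS (1024) constants that remain act only as upper bounds used when copying BatchConfig across workers, so the runtime values presumably must stay at or below them.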
11 changes: 8 additions & 3 deletions include/flexflow/batch_config.h
@@ -43,12 +43,17 @@ class BatchConfig {
BatchConfig();
int num_active_requests() const;
int num_active_tokens() const;
static int max_requests_per_batch();
static int max_tokens_per_batch();
static int max_sequence_length();
void print() const;
virtual InferenceMode get_mode() const;
static BatchConfig const *from_future(BatchConfigFuture const &future);
static int const MAX_NUM_REQUESTS = 7;
static int const MAX_NUM_TOKENS = 64;
static int const MAX_SEQ_LENGTH = 256;
// Maximum possible values for different parameters
// These maximum values are used for copying BatchConfig
// across workers
static int const MAX_NUM_REQUESTS = 64;
static int const MAX_NUM_TOKENS = 1024;

// These are set by update
int num_tokens;
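The definitions of the new static accessors are not part of this excerpt; a plausible sketch, assuming they simply forward to the singleton RequestManager introduced below, which would keep call sites such as the model builders (falcon.cc, llama.cc, ...) free of a direct RequestManager dependency:

// Sketch only -- the real batch_config.cc is not shown in this diff.
#include "flexflow/batch_config.h"
#include "flexflow/request_manager.h"

namespace FlexFlow { // namespace assumed from the surrounding code

int BatchConfig::max_requests_per_batch() {
  return RequestManager::get_request_manager()->get_max_requests_per_batch();
}

int BatchConfig::max_tokens_per_batch() {
  return RequestManager::get_request_manager()->get_max_tokens_per_batch();
}

int BatchConfig::max_sequence_length() {
  return RequestManager::get_request_manager()->get_max_sequence_length();
}

} // namespace FlexFlow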
9 changes: 9 additions & 0 deletions include/flexflow/flexflow_c.h
@@ -965,6 +965,15 @@ flexflow_request_manager_t flexflow_request_manager_get_request_manager(void);

// void flexflow_request_manager_destroy(flexflow_request_manager_t handle_);

void flexflow_request_manager_set_max_requests_per_batch(
flexflow_request_manager_t handle_, int max_num_requests);

void flexflow_request_manager_set_max_tokens_per_batch(
flexflow_request_manager_t handle_, int max_num_tokens);

void flexflow_request_manager_set_max_sequence_length(
flexflow_request_manager_t handle_, int max_seq_length);

void flexflow_request_manager_register_tokenizer(
flexflow_request_manager_t handle_,
enum ModelType model_type,
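A minimal usage sketch of the new C-API setters (the values are illustrative): configure the limits on the handle returned by flexflow_request_manager_get_request_manager() before the model is compiled, then register the tokenizer and output path as before.

flexflow_request_manager_t rm = flexflow_request_manager_get_request_manager();
flexflow_request_manager_set_max_requests_per_batch(rm, 8);
flexflow_request_manager_set_max_tokens_per_batch(rm, 128);
flexflow_request_manager_set_max_sequence_length(rm, 256);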
14 changes: 12 additions & 2 deletions include/flexflow/request_manager.h
@@ -30,7 +30,7 @@ using tokenizers::Tokenizer;

class InferenceManager {
public:
InferenceManager(FFConfig const &config, int max_num_tokens_per_batch);
InferenceManager(FFConfig const &config);
static InferenceManager *get_inference_manager();
void compile_model_and_allocate_buffer(FFModel *model);
void init_operators_inference(FFModel *model);
@@ -46,7 +46,6 @@ class InferenceManager {
public:
FFConfig ff_config;
std::unordered_map<ParallelTensor, std::vector<ParallelTensor>> tensor_buffer;
int max_num_tokens_per_batch;
int num_devices;
};

@@ -96,6 +95,12 @@ class RequestManager {
size_t get_num_processed_requests();
size_t get_num_ssms();

void set_max_requests_per_batch(int max_num_requests);
int get_max_requests_per_batch();
void set_max_tokens_per_batch(int max_num_tokens);
int get_max_tokens_per_batch();
void set_max_sequence_length(int max_seq_length);
int get_max_sequence_length();
int register_ssm_model(FFModel *model);
void register_tokenizer(ModelType model_type,
int bos_token_id,
@@ -201,6 +206,11 @@ class RequestManager {
Legion::Runtime *runtime);

private:
// configuration parameters
int max_requests_per_batch;
int max_tokens_per_batch;
int max_sequence_length;
// private fields
std::unique_ptr<Tokenizer> tokenizer_;
bool verbose;
ModelType model_type;
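The matching definitions in request_manager.cc are not shown here; a hedged sketch of what one setter/getter pair might look like, where the bounds check against BatchConfig's compile-time maximum is an assumption rather than something visible in this diff:

#include "flexflow/request_manager.h"
#include <cassert>

namespace FlexFlow { // namespace assumed from the surrounding code

void RequestManager::set_max_tokens_per_batch(int max_num_tokens) {
  // Assumed guard: values above BatchConfig::MAX_NUM_TOKENS would overflow the
  // fixed-size buffers used to copy BatchConfig across workers.
  assert(max_num_tokens <= BatchConfig::MAX_NUM_TOKENS);
  max_tokens_per_batch = max_num_tokens;
}

int RequestManager::get_max_tokens_per_batch() {
  assert(max_tokens_per_batch > 0); // must be set before first use
  return max_tokens_per_batch;
}

} // namespace FlexFlow

Note also that InferenceManager no longer carries max_num_tokens_per_batch: its constructor now takes only FFConfig, and the token budget is presumably read back through the RequestManager/BatchConfig accessors instead.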
28 changes: 26 additions & 2 deletions inference/incr_decoding/incr_decoding.cc
@@ -43,7 +43,10 @@ void parse_input_args(char **argv,
bool &verbose,
bool &do_sample,
float &temperature,
float &topp) {
float &topp,
int &max_requests_per_batch,
int &max_tokens_per_batch,
int &max_sequence_length) {
for (int i = 1; i < argc; i++) {
// llm model type
if (!strcmp(argv[i], "-llm-model")) {
@@ -89,6 +92,18 @@
topp = std::stof(argv[++i]);
continue;
}
if (!strcmp(argv[i], "--max-requests-per-batch")) {
max_requests_per_batch = std::stoi(argv[++i]);
continue;
}
if (!strcmp(argv[i], "--max-tokens-per-batch")) {
max_tokens_per_batch = std::stoi(argv[++i]);
continue;
}
if (!strcmp(argv[i], "--max-sequence-length")) {
max_sequence_length = std::stoi(argv[++i]);
continue;
}
}
if (paths.cache_folder_path.empty()) {
paths.cache_folder_path = "~/.cache/flexflow";
@@ -115,6 +130,9 @@ void FlexFlow::top_level_task(Task const *task,
bool do_sample = false;
float temperature = 0.0f;
float topp = 0.0f;
int max_requests_per_batch = 8;
int max_tokens_per_batch = 128;
int max_sequence_length = 256;

InputArgs const &command_args = HighLevelRuntime::get_input_args();
char **argv = command_args.argv;
@@ -127,7 +145,10 @@ void FlexFlow::top_level_task(Task const *task,
verbose,
do_sample,
temperature,
topp);
topp,
max_requests_per_batch,
max_tokens_per_batch,
max_sequence_length);

assert(ffconfig.data_parallelism_degree * ffconfig.tensor_parallelism_degree *
ffconfig.pipeline_parallelism_degree ==
@@ -191,6 +212,9 @@ void FlexFlow::top_level_task(Task const *task,

GenerationConfig generationConfig(do_sample, temperature, topp);
RequestManager *rm = RequestManager::get_request_manager();
rm->set_max_requests_per_batch(max_requests_per_batch);
rm->set_max_tokens_per_batch(max_tokens_per_batch);
rm->set_max_sequence_length(max_sequence_length);
rm->register_tokenizer(
model_type, bos_token_id, eos_token_id, tokenizer_filepath);
rm->register_output_filepath(file_paths.output_file_path);
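On the command line, the same three limits are now exposed as flags, with defaults of 8 requests, 128 tokens per batch, and a sequence length of 256. A hypothetical invocation (the binary path and model name are illustrative, not taken from this diff):

incr_decoding -llm-model <model-name> --max-requests-per-batch 16 --max-tokens-per-batch 256 --max-sequence-length 512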
4 changes: 2 additions & 2 deletions inference/models/falcon.cc
@@ -40,8 +40,8 @@ void FALCON::create_falcon_model(FFModel &ff,

Tensor input;
{
assert(falcon_config.max_num_tokens <= BatchConfig::MAX_NUM_TOKENS);
int const token_dims[] = {BatchConfig::MAX_NUM_TOKENS, 1};
// assert(falcon_config.max_num_tokens <= BatchConfig::MAX_NUM_TOKENS);
int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1};
input = ff.create_tensor<2>(token_dims, DT_INT32);
}

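Because the input tensor is now sized from BatchConfig::max_tokens_per_batch() at model-construction time rather than from a compile-time constant, the RequestManager limits presumably have to be configured before create_falcon_model (or any of the other create_*_model functions) runs. A sketch of the expected ordering, with the model arguments elided since they are not shown in this diff:

RequestManager *rm = RequestManager::get_request_manager();
rm->set_max_tokens_per_batch(128); // must precede model construction
// ... set the remaining limits, register the tokenizer and output path ...
FALCON::create_falcon_model(ff, /* config/weights arguments elided */);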
11 changes: 6 additions & 5 deletions inference/models/falcon.h
@@ -59,8 +59,8 @@ class FALCON {
<< std::endl;
assert(false);
}
max_seq_len = BatchConfig::MAX_SEQ_LENGTH;
max_num_tokens = BatchConfig::MAX_NUM_TOKENS;
// max_seq_len = BatchConfig::MAX_SEQ_LENGTH;
// max_num_tokens = BatchConfig::MAX_NUM_TOKENS;
max_beam_width = BeamSearchBatchConfig::MAX_BEAM_WIDTH;
max_beam_depth = BeamSearchBatchConfig::MAX_BEAM_DEPTH;
}
@@ -77,16 +77,17 @@ class FALCON {
std::cout << "\tparallel_attn: " << parallel_attn << std::endl;
std::cout << "\tvocab_size: " << vocab_size << std::endl;

std::cout << "\tmax_seq_len: " << max_seq_len << std::endl;
std::cout << "\tmax_num_tokens: " << max_num_tokens << std::endl;
// std::cout << "\tmax_seq_len: " << max_seq_len << std::endl;
// std::cout << "\tmax_num_tokens: " << max_num_tokens << std::endl;
std::cout << "\tmax_beam_width: " << max_beam_width << std::endl;
std::cout << "\tmax_beam_depth: " << max_beam_depth << std::endl;
}

bool bias, multi_query, parallel_attn;
int hidden_size, n_head, n_head_kv, n_layer, vocab_size;
float layer_norm_epsilon;
int max_seq_len, max_num_tokens, max_beam_width, max_beam_depth;
// int max_seq_len, max_num_tokens;
int max_beam_width, max_beam_depth;
};

static void create_falcon_model(FFModel &ff,
3 changes: 1 addition & 2 deletions inference/models/llama.cc
@@ -41,8 +41,7 @@ void LLAMA::create_llama_model(FFModel &ff,

Tensor input;
{
assert(llama_config.max_num_tokens <= BatchConfig::MAX_NUM_TOKENS);
int const token_dims[] = {BatchConfig::MAX_NUM_TOKENS, 1};
int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1};
input = ff.create_tensor<2>(token_dims, DT_INT32);
}

11 changes: 6 additions & 5 deletions inference/models/llama.h
@@ -49,8 +49,8 @@ class LLAMA {
<< std::endl;
assert(false);
}
max_seq_len = BatchConfig::MAX_SEQ_LENGTH;
max_num_tokens = BatchConfig::MAX_NUM_TOKENS;
// max_seq_len = BatchConfig::MAX_SEQ_LENGTH;
// max_num_tokens = BatchConfig::MAX_NUM_TOKENS;
max_beam_width = BeamSearchBatchConfig::MAX_BEAM_WIDTH;
max_beam_depth = BeamSearchBatchConfig::MAX_BEAM_DEPTH;
}
@@ -65,13 +65,14 @@ class LLAMA {
std::cout << "\trms_norm_eps: " << rms_norm_eps << std::endl;
std::cout << "\tintermediate_size: " << intermediate_size << std::endl;

std::cout << "\tmax_seq_len: " << max_seq_len << std::endl;
std::cout << "\tmax_num_tokens: " << max_num_tokens << std::endl;
// std::cout << "\tmax_seq_len: " << max_seq_len << std::endl;
// std::cout << "\tmax_num_tokens: " << max_num_tokens << std::endl;
std::cout << "\tmax_beam_width: " << max_beam_width << std::endl;
std::cout << "\tmax_beam_depth: " << max_beam_depth << std::endl;
}

int max_seq_len, max_num_tokens, max_beam_width, max_beam_depth;
// int max_seq_len, max_num_tokens;
int max_beam_width, max_beam_depth;
int num_hidden_layers, vocab_size, num_attention_heads, hidden_size,
intermediate_size;
float rms_norm_eps;
2 changes: 1 addition & 1 deletion inference/models/mpt.cc
@@ -40,7 +40,7 @@ void MPT::create_mpt_model(FFModel &ff,
//------------------------------ build the model --------------------------
Tensor input;
{
int const token_dims[] = {BatchConfig::MAX_NUM_TOKENS, 1};
int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1};
input = ff.create_tensor<2>(token_dims, DT_INT32);
}

7 changes: 4 additions & 3 deletions inference/models/mpt.h
@@ -46,8 +46,8 @@ class MPT {
<< std::endl;
assert(false);
}
max_seq_len = BatchConfig::MAX_SEQ_LENGTH;
max_num_tokens = BatchConfig::MAX_NUM_TOKENS;
// max_seq_len = BatchConfig::MAX_SEQ_LENGTH;
// max_num_tokens = BatchConfig::MAX_NUM_TOKENS;
max_beam_width = BeamSearchBatchConfig::MAX_BEAM_WIDTH;
max_beam_depth = BeamSearchBatchConfig::MAX_BEAM_DEPTH;
}
@@ -60,7 +60,8 @@
std::cout << "\tvocab_size: " << vocab_size << std::endl;
}

int max_seq_len, max_num_tokens, max_beam_width, max_beam_depth;
// int max_seq_len, max_num_tokens;
int max_beam_width, max_beam_depth;
int hidden_size, n_heads, n_layers, vocab_size;
};

2 changes: 1 addition & 1 deletion inference/models/opt.cc
@@ -42,7 +42,7 @@ void OPT::create_opt_model(FFModel &ff,
Tensor position_input;
ff.set_position_offset(2);
{
int const token_dims[] = {BatchConfig::MAX_NUM_TOKENS, 1};
int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1};
input = ff.create_tensor<2>(token_dims, DT_INT32);
position_input = ff.create_tensor<2>(token_dims, DT_INT32);
}
11 changes: 6 additions & 5 deletions inference/models/opt.h
@@ -54,8 +54,8 @@ class OPT {
<< std::endl;
assert(false);
}
max_seq_len = BatchConfig::MAX_SEQ_LENGTH;
max_num_tokens = BatchConfig::MAX_NUM_TOKENS;
// max_seq_len = BatchConfig::MAX_SEQ_LENGTH;
// max_num_tokens = BatchConfig::MAX_NUM_TOKENS;
max_beam_width = BeamSearchBatchConfig::MAX_BEAM_WIDTH;
max_beam_depth = BeamSearchBatchConfig::MAX_BEAM_DEPTH;
}
@@ -79,13 +79,14 @@ class OPT {
std::cout << "\tword_embed_proj_dim: " << word_embed_proj_dim
<< std::endl;

std::cout << "\tmax_seq_len: " << max_seq_len << std::endl;
std::cout << "\tmax_num_tokens: " << max_num_tokens << std::endl;
// std::cout << "\tmax_seq_len: " << max_seq_len << std::endl;
// std::cout << "\tmax_num_tokens: " << max_num_tokens << std::endl;
std::cout << "\tmax_beam_width: " << max_beam_width << std::endl;
std::cout << "\tmax_beam_depth: " << max_beam_depth << std::endl;
}

int max_seq_len, max_num_tokens, max_beam_width, max_beam_depth;
// int max_seq_len, max_num_tokens;
int max_beam_width, max_beam_depth;
bool do_layer_norm_before, enable_bias, layer_norm_elementwise_affine;
float dropout;
int ffn_dim, hidden_size, max_position_embeddings, num_attention_heads,
4 changes: 2 additions & 2 deletions inference/models/starcoder.cc
@@ -47,8 +47,8 @@ void STARCODER::create_starcoder_model(
Tensor position_input;
ff.set_position_offset(0);
{
assert(startcoder_config.max_num_tokens <= BatchConfig::MAX_NUM_TOKENS);
int const token_dims[] = {BatchConfig::MAX_NUM_TOKENS, 1};
// assert(startcoder_config.max_num_tokens <= BatchConfig::MAX_NUM_TOKENS);
int const token_dims[] = {BatchConfig::max_tokens_per_batch(), 1};
input = ff.create_tensor<2>(token_dims, DT_INT32);
position_input = ff.create_tensor<2>(token_dims, DT_INT32);
}
7 changes: 4 additions & 3 deletions inference/models/starcoder.h
@@ -51,15 +51,16 @@ class STARCODER {
<< std::endl;
assert(false);
}
max_seq_len = BatchConfig::MAX_SEQ_LENGTH;
max_num_tokens = BatchConfig::MAX_NUM_TOKENS;
// max_seq_len = BatchConfig::MAX_SEQ_LENGTH;
// max_num_tokens = BatchConfig::MAX_NUM_TOKENS;
max_beam_width = BeamSearchBatchConfig::MAX_BEAM_WIDTH;
max_beam_depth = BeamSearchBatchConfig::MAX_BEAM_DEPTH;
}

void print() const {}

int max_seq_len, max_num_tokens, max_beam_width, max_beam_depth;
// int max_seq_len, max_num_tokens;
int max_beam_width, max_beam_depth;
int num_hidden_layers, vocab_size, num_attention_heads, hidden_size,
intermediate_size, max_position_embeddings;
float layer_norm_epsilon, dropout_p;
2 changes: 1 addition & 1 deletion inference/python/incr_decoding.py
@@ -97,7 +97,7 @@ def main():
)
llm.compile(
generation_config,
max_batch_size=1,
max_requests_per_batch=1,
max_seq_length=256,
max_tokens_per_batch=64,
)
4 changes: 2 additions & 2 deletions inference/python/spec_infer.py
@@ -134,15 +134,15 @@ def main():
for ssm in ssms:
ssm.compile(
generation_config,
max_batch_size=1,
max_requests_per_batch=1,
max_seq_length=256,
max_tokens_per_batch=64,
)

# Compile the LLM for inference and load the weights into memory
llm.compile(
generation_config,
max_batch_size=1,
max_requests_per_batch=1,
max_seq_length=256,
max_tokens_per_batch=64,
ssms=ssms,