diff --git a/inference/models/falcon.cc b/inference/models/falcon.cc
index b0592655c9..e00f4e9cfd 100644
--- a/inference/models/falcon.cc
+++ b/inference/models/falcon.cc
@@ -39,10 +39,11 @@ void FALCON::create_falcon_model(FFModel &ff,
   Tensor input;
   {
     // assert(falcon_config.max_num_tokens <= BatchConfig::MAX_NUM_TOKENS);
-    int const token_dims[] = {mode == TREE_VERIFY_MODE
-                                  ? BatchConfig::max_verify_tokens_per_batch()
-                                  : BatchConfig::max_tokens_per_batch(),
-                              1};
+    int const token_dims[] = {
+        (mode == TREE_VERIFY_MODE || mode == BEAM_SEARCH_MODE)
+            ? BatchConfig::max_verify_tokens_per_batch()
+            : BatchConfig::max_tokens_per_batch(),
+        1};
     input = ff.create_tensor<2>(token_dims, DT_INT32);
   }
diff --git a/inference/models/llama.cc b/inference/models/llama.cc
index db6e34475b..14b8c31fa1 100644
--- a/inference/models/llama.cc
+++ b/inference/models/llama.cc
@@ -41,10 +41,11 @@ void LLAMA::create_llama_model(FFModel &ff,
   Tensor input;
   {
-    int const token_dims[] = {mode == TREE_VERIFY_MODE
-                                  ? BatchConfig::max_verify_tokens_per_batch()
-                                  : BatchConfig::max_tokens_per_batch(),
-                              1};
+    int const token_dims[] = {
+        (mode == TREE_VERIFY_MODE || mode == BEAM_SEARCH_MODE)
+            ? BatchConfig::max_verify_tokens_per_batch()
+            : BatchConfig::max_tokens_per_batch(),
+        1};
     input = ff.create_tensor<2>(token_dims, DT_INT32);
   }
diff --git a/inference/models/mpt.cc b/inference/models/mpt.cc
index abfff10373..7e8fc8358f 100644
--- a/inference/models/mpt.cc
+++ b/inference/models/mpt.cc
@@ -40,10 +40,11 @@ void MPT::create_mpt_model(FFModel &ff,
   //------------------------------ build the model --------------------------
   Tensor input;
   {
-    int const token_dims[] = {mode == TREE_VERIFY_MODE
-                                  ? BatchConfig::max_verify_tokens_per_batch()
-                                  : BatchConfig::max_tokens_per_batch(),
-                              1};
+    int const token_dims[] = {
+        (mode == TREE_VERIFY_MODE || mode == BEAM_SEARCH_MODE)
+            ? BatchConfig::max_verify_tokens_per_batch()
+            : BatchConfig::max_tokens_per_batch(),
+        1};
     input = ff.create_tensor<2>(token_dims, DT_INT32);
   }
diff --git a/inference/models/opt.cc b/inference/models/opt.cc
index a6edf40837..3ff4c96fdf 100644
--- a/inference/models/opt.cc
+++ b/inference/models/opt.cc
@@ -42,10 +42,11 @@ void OPT::create_opt_model(FFModel &ff,
   Tensor position_input;
   ff.set_position_offset(2);
   {
-    int const token_dims[] = {mode == TREE_VERIFY_MODE
-                                  ? BatchConfig::max_verify_tokens_per_batch()
-                                  : BatchConfig::max_tokens_per_batch(),
-                              1};
+    int const token_dims[] = {
+        (mode == TREE_VERIFY_MODE || mode == BEAM_SEARCH_MODE)
+            ? BatchConfig::max_verify_tokens_per_batch()
+            : BatchConfig::max_tokens_per_batch(),
+        1};
     input = ff.create_tensor<2>(token_dims, DT_INT32);
     position_input = ff.create_tensor<2>(token_dims, DT_INT32);
   }
diff --git a/inference/models/starcoder.cc b/inference/models/starcoder.cc
index 360313e227..2327c86119 100644
--- a/inference/models/starcoder.cc
+++ b/inference/models/starcoder.cc
@@ -48,10 +48,11 @@ void STARCODER::create_starcoder_model(
   ff.set_position_offset(0);
   {
     // assert(startcoder_config.max_num_tokens <= BatchConfig::MAX_NUM_TOKENS);
-    int const token_dims[] = {mode == TREE_VERIFY_MODE
-                                  ? BatchConfig::max_verify_tokens_per_batch()
-                                  : BatchConfig::max_tokens_per_batch(),
-                              1};
+    int const token_dims[] = {
+        (mode == TREE_VERIFY_MODE || mode == BEAM_SEARCH_MODE)
+            ? BatchConfig::max_verify_tokens_per_batch()
+            : BatchConfig::max_tokens_per_batch(),
+        1};
     input = ff.create_tensor<2>(token_dims, DT_INT32);
     position_input = ff.create_tensor<2>(token_dims, DT_INT32);
   }
diff --git a/python/flexflow/serve/models/falcon.py b/python/flexflow/serve/models/falcon.py
index 2b114f09b3..e9cd789bcc 100644
--- a/python/flexflow/serve/models/falcon.py
+++ b/python/flexflow/serve/models/falcon.py
@@ -23,6 +23,7 @@ def __init__(self, hf_config):
         #self.max_num_tokens = 64
         self.max_beam_width = 1
         self.max_beam_depth = 8
+        self.max_spec_tree_token_num = 64
         self.bias = hf_config.bias
         self.hidden_size = hf_config.hidden_size
         self.layer_norm_epsilon = hf_config.layer_norm_epsilon
@@ -70,6 +71,7 @@ def __init__(
         self.weights_filepath = weights_filepath
         self.tokenizer_filepath = tokenizer_filepath
         self.maxint = 2**31 - 1
+        max_verify_tokens_per_batch = max_tokens_per_batch + self.falcon_config.max_spec_tree_token_num

         # Sanity checks
         if self.falcon_config.hidden_size % self.falcon_config.n_head != 0:
@@ -84,7 +86,7 @@ def __init__(
                 f"Number of q attention heads ({self.falcon_config.n_head}) is smaller, or not divisible by tensor parallelism degree ({self.ffconfig.tensor_parallelism_degree})"
             )

-        self.build_model(max_tokens_per_batch)
+        self.build_model(max_tokens_per_batch if self.mode == InferenceMode.INC_DECODING_MODE else max_verify_tokens_per_batch)

     def build_model(self, max_tokens_per_batch):
         ffmodel = FFModel(self.ffconfig)
diff --git a/python/flexflow/serve/models/llama.py b/python/flexflow/serve/models/llama.py
index 7ba0e78a37..900ab48bcd 100644
--- a/python/flexflow/serve/models/llama.py
+++ b/python/flexflow/serve/models/llama.py
@@ -23,6 +23,7 @@ def __init__(self, hf_config):
         #self.max_num_tokens = 64
         self.max_beam_width = 1
         self.max_beam_depth = 8
+        self.max_spec_tree_token_num = 64
         self.num_hidden_layers = hf_config.num_hidden_layers
         self.vocab_size = hf_config.vocab_size
         self.hidden_size = hf_config.hidden_size
@@ -62,6 +63,8 @@ def __init__(
         self.weights_filepath = weights_filepath
         self.tokenizer_filepath = tokenizer_filepath
         self.maxint = 2**31 - 1
+        max_verify_tokens_per_batch = max_tokens_per_batch + self.llama_config.max_spec_tree_token_num
+
         # Sanity checks
         if self.llama_config.hidden_size % self.llama_config.num_attention_heads != 0:
@@ -81,7 +84,7 @@ def __init__(
                 f"Number of attention heads ({self.llama_config.num_attention_heads}) is smaller, or not divisible by tensor parallelism degree ({self.ffconfig.tensor_parallelism_degree})"
             )

-        self.build_model(max_tokens_per_batch)
+        self.build_model(max_tokens_per_batch if self.mode == InferenceMode.INC_DECODING_MODE else max_verify_tokens_per_batch)

     def build_model(self, max_tokens_per_batch):
         ffmodel = FFModel(self.ffconfig)
diff --git a/python/flexflow/serve/models/mpt.py b/python/flexflow/serve/models/mpt.py
index 79a5bb940f..c0f995bf22 100644
--- a/python/flexflow/serve/models/mpt.py
+++ b/python/flexflow/serve/models/mpt.py
@@ -23,6 +23,7 @@ def __init__(self, hf_config):
         #self.max_num_tokens = 64
         self.max_beam_width = 1
         self.max_beam_depth = 8
+        self.max_spec_tree_token_num = 64
         self.hidden_size = hf_config.d_model
         self.n_heads = hf_config.n_heads
         self.n_layers = hf_config.n_layers
@@ -57,6 +58,8 @@ def __init__(
         self.weights_filepath = weights_filepath
         self.tokenizer_filepath = tokenizer_filepath
         self.maxint = 2**31 - 1
+        max_verify_tokens_per_batch = max_tokens_per_batch + self.mpt_config.max_spec_tree_token_num
+
         # Sanity checks
         if self.mpt_config.hidden_size % self.mpt_config.n_heads != 0:
@@ -72,7 +75,7 @@ def __init__(
             raise ValueError(
                 f"Number of attention heads ({self.mpt_config.n_heads}) is smaller, or not divisible by tensor parallelism degree ({self.ffconfig.tensor_parallelism_degree})"
             )
-        self.build_model(max_tokens_per_batch)
+        self.build_model(max_tokens_per_batch if self.mode == InferenceMode.INC_DECODING_MODE else max_verify_tokens_per_batch)

     def build_model(self, max_tokens_per_batch):
         ffmodel = FFModel(self.ffconfig)
diff --git a/python/flexflow/serve/models/opt.py b/python/flexflow/serve/models/opt.py
index dd36fa6592..dc3f841a5a 100644
--- a/python/flexflow/serve/models/opt.py
+++ b/python/flexflow/serve/models/opt.py
@@ -23,6 +23,7 @@ def __init__(self, hf_config):
         #self.max_num_tokens = 64
         self.max_beam_width = 1
         self.max_beam_depth = 8
+        self.max_spec_tree_token_num = 64
         self.do_layer_norm_before = hf_config.do_layer_norm_before
         self.dropout = hf_config.dropout
         self.enable_bias = hf_config.enable_bias
@@ -63,6 +64,7 @@ def __init__(
         self.weights_filepath = weights_filepath
         self.tokenizer_filepath = tokenizer_filepath
         self.maxint = 2**31 - 1
+        max_verify_tokens_per_batch = max_tokens_per_batch + self.opt_config.max_spec_tree_token_num

         # Sanity checks
         if self.opt_config.hidden_size % self.opt_config.num_attention_heads != 0:
@@ -82,7 +84,7 @@ def __init__(
                 f"Number of attention heads ({self.opt_config.num_attention_heads}) is smaller, or not divisible by tensor parallelism degree ({self.ffconfig.tensor_parallelism_degree})"
             )

-        self.build_model(max_tokens_per_batch)
+        self.build_model(max_tokens_per_batch if self.mode == InferenceMode.INC_DECODING_MODE else max_verify_tokens_per_batch)

     def build_model(self, max_tokens_per_batch):
         ffmodel = FFModel(self.ffconfig)
diff --git a/python/flexflow/serve/models/starcoder.py b/python/flexflow/serve/models/starcoder.py
index f4f28a70e1..4a6f191abd 100644
--- a/python/flexflow/serve/models/starcoder.py
+++ b/python/flexflow/serve/models/starcoder.py
@@ -23,6 +23,7 @@ def __init__(self, hf_config):
         #self.max_num_tokens = 64
         self.max_beam_width = 1
         self.max_beam_depth = 8
+        self.max_spec_tree_token_num = 64
         self.dropout_p = hf_config.attn_pdrop
         self.hidden_size = hf_config.n_embd
         self.layer_norm_epsilon = hf_config.layer_norm_epsilon
@@ -61,6 +62,8 @@ def __init__(
         self.weights_filepath = weights_filepath
         self.tokenizer_filepath = tokenizer_filepath
         self.maxint = 2**31 - 1
+        max_verify_tokens_per_batch = max_tokens_per_batch + self.starcoder_config.max_spec_tree_token_num
+
         # Sanity checks
         if (
@@ -84,7 +87,7 @@ def __init__(
             f"Number of attention heads ({self.starcoder_config.num_attention_heads}) is smaller, or not divisible by tensor parallelism degree ({self.ffconfig.tensor_parallelism_degree})"
         )

-        self.build_model(max_tokens_per_batch)
+        self.build_model(max_tokens_per_batch if self.mode == InferenceMode.INC_DECODING_MODE else max_verify_tokens_per_batch)

     def build_model(self, max_tokens_per_batch):
         ffmodel = FFModel(self.ffconfig)
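Note on the sizing rule (commentary, not part of the patch): every C++ model builder now reserves the speculative-tree token budget in BEAM_SEARCH_MODE as well as TREE_VERIFY_MODE, and the Python frontends mirror this by passing `max_tokens_per_batch + max_spec_tree_token_num` to `build_model()` for any mode other than incremental decoding. A minimal sketch of that arithmetic is below; the mode constants and the helper `tokens_for_build` are hypothetical names for illustration, and only `max_spec_tree_token_num = 64` comes from the patch.

```python
# Hypothetical helper mirroring the patch's sizing rule; names and the 128-token
# batch budget are illustrative, only max_spec_tree_token_num = 64 is from the patch.
INC_DECODING_MODE, BEAM_SEARCH_MODE, TREE_VERIFY_MODE = range(3)

def tokens_for_build(mode, max_tokens_per_batch, max_spec_tree_token_num=64):
    # Incremental decoding never carries a speculation tree, so the plain
    # per-batch budget suffices; both speculative modes reserve extra room
    # for the tree on top of the normal budget.
    if mode == INC_DECODING_MODE:
        return max_tokens_per_batch
    return max_tokens_per_batch + max_spec_tree_token_num

assert tokens_for_build(INC_DECODING_MODE, 128) == 128
assert tokens_for_build(BEAM_SEARCH_MODE, 128) == 192  # 128 + 64
assert tokens_for_build(TREE_VERIFY_MODE, 128) == 192  # same budget as verify
```

Under this assumption a 128-token batch budget grows to 192 whenever a speculation tree can accompany the batch, which is why the input tensor's first dimension switches to `BatchConfig::max_verify_tokens_per_batch()` in both speculative modes rather than only during tree verification.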