diff --git a/deps/legion b/deps/legion
index 626b55689c..d065278678 160000
--- a/deps/legion
+++ b/deps/legion
@@ -1 +1 @@
-Subproject commit 626b55689c77848b246e1da19678c7ad58899f0c
+Subproject commit d0652786784249e933dd62f675591da99a5e960d
diff --git a/inference/models/falcon.cc b/inference/models/falcon.cc
index 999ca37037..cf6e90a7de 100644
--- a/inference/models/falcon.cc
+++ b/inference/models/falcon.cc
@@ -39,7 +39,7 @@ void FALCON::create_falcon_model(FFModel &ff,
   Tensor input;
   {
     // assert(falcon_config.max_num_tokens <= BatchConfig::MAX_NUM_TOKENS);
-    int const token_dims[] = {mode == TREE_VERIFY_MODE
+    int const token_dims[] = {(mode == TREE_VERIFY_MODE || mode == BEAM_SEARCH_MODE)
                                   ? BatchConfig::max_verify_tokens_per_batch()
                                   : BatchConfig::max_tokens_per_batch(),
                               1};
diff --git a/inference/models/llama.cc b/inference/models/llama.cc
index e54d6d8811..3deba47953 100644
--- a/inference/models/llama.cc
+++ b/inference/models/llama.cc
@@ -41,7 +41,7 @@ void LLAMA::create_llama_model(FFModel &ff,
 
   Tensor input;
   {
-    int const token_dims[] = {mode == TREE_VERIFY_MODE
+    int const token_dims[] = {(mode == TREE_VERIFY_MODE || mode == BEAM_SEARCH_MODE)
                                   ? BatchConfig::max_verify_tokens_per_batch()
                                   : BatchConfig::max_tokens_per_batch(),
                               1};
diff --git a/inference/models/mpt.cc b/inference/models/mpt.cc
index 3df67b264c..484a09f62e 100644
--- a/inference/models/mpt.cc
+++ b/inference/models/mpt.cc
@@ -40,7 +40,7 @@ void MPT::create_mpt_model(FFModel &ff,
   //------------------------------ build the model --------------------------
   Tensor input;
   {
-    int const token_dims[] = {mode == TREE_VERIFY_MODE
+    int const token_dims[] = {(mode == TREE_VERIFY_MODE || mode == BEAM_SEARCH_MODE)
                                   ? BatchConfig::max_verify_tokens_per_batch()
                                   : BatchConfig::max_tokens_per_batch(),
                               1};
diff --git a/inference/models/opt.cc b/inference/models/opt.cc
index e260f8fa36..9f75dcea4c 100644
--- a/inference/models/opt.cc
+++ b/inference/models/opt.cc
@@ -42,7 +42,7 @@ void OPT::create_opt_model(FFModel &ff,
   Tensor position_input;
   ff.set_position_offset(2);
   {
-    int const token_dims[] = {mode == TREE_VERIFY_MODE
+    int const token_dims[] = {(mode == TREE_VERIFY_MODE || mode == BEAM_SEARCH_MODE)
                                   ? BatchConfig::max_verify_tokens_per_batch()
                                   : BatchConfig::max_tokens_per_batch(),
                               1};
diff --git a/inference/models/starcoder.cc b/inference/models/starcoder.cc
index e683376e47..ef5388b6ca 100644
--- a/inference/models/starcoder.cc
+++ b/inference/models/starcoder.cc
@@ -48,7 +48,7 @@ void STARCODER::create_starcoder_model(
   ff.set_position_offset(0);
   {
     // assert(startcoder_config.max_num_tokens <= BatchConfig::MAX_NUM_TOKENS);
-    int const token_dims[] = {mode == TREE_VERIFY_MODE
+    int const token_dims[] = {(mode == TREE_VERIFY_MODE || mode == BEAM_SEARCH_MODE)
                                   ? BatchConfig::max_verify_tokens_per_batch()
                                   : BatchConfig::max_tokens_per_batch(),
                               1};