Skip to content

Commit

Permalink
fix spec decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
jiazhihao committed Jan 12, 2024
1 parent 9c85a4f commit 197e308
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion deps/legion
Submodule legion updated from 626b55 to d06527
2 changes: 1 addition & 1 deletion inference/models/falcon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void FALCON::create_falcon_model(FFModel &ff,
Tensor input;
{
// assert(falcon_config.max_num_tokens <= BatchConfig::MAX_NUM_TOKENS);
int const token_dims[] = {mode == TREE_VERIFY_MODE
int const token_dims[] = {(mode == TREE_VERIFY_MODE || mode == BEAM_SEARCH_MODE)
? BatchConfig::max_verify_tokens_per_batch()
: BatchConfig::max_tokens_per_batch(),
1};
Expand Down
2 changes: 1 addition & 1 deletion inference/models/llama.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ void LLAMA::create_llama_model(FFModel &ff,

Tensor input;
{
int const token_dims[] = {mode == TREE_VERIFY_MODE
int const token_dims[] = {(mode == TREE_VERIFY_MODE || mode == BEAM_SEARCH_MODE)
? BatchConfig::max_verify_tokens_per_batch()
: BatchConfig::max_tokens_per_batch(),
1};
Expand Down
2 changes: 1 addition & 1 deletion inference/models/mpt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ void MPT::create_mpt_model(FFModel &ff,
//------------------------------ build the model --------------------------
Tensor input;
{
int const token_dims[] = {mode == TREE_VERIFY_MODE
int const token_dims[] = {(mode == TREE_VERIFY_MODE || mode == BEAM_SEARCH_MODE)
? BatchConfig::max_verify_tokens_per_batch()
: BatchConfig::max_tokens_per_batch(),
1};
Expand Down
2 changes: 1 addition & 1 deletion inference/models/opt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void OPT::create_opt_model(FFModel &ff,
Tensor position_input;
ff.set_position_offset(2);
{
int const token_dims[] = {mode == TREE_VERIFY_MODE
int const token_dims[] = {(mode == TREE_VERIFY_MODE || mode == BEAM_SEARCH_MODE)
? BatchConfig::max_verify_tokens_per_batch()
: BatchConfig::max_tokens_per_batch(),
1};
Expand Down
2 changes: 1 addition & 1 deletion inference/models/starcoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void STARCODER::create_starcoder_model(
ff.set_position_offset(0);
{
// assert(startcoder_config.max_num_tokens <= BatchConfig::MAX_NUM_TOKENS);
int const token_dims[] = {mode == TREE_VERIFY_MODE
int const token_dims[] = {(mode == TREE_VERIFY_MODE || mode == BEAM_SEARCH_MODE)
? BatchConfig::max_verify_tokens_per_batch()
: BatchConfig::max_tokens_per_batch(),
1};
Expand Down

0 comments on commit 197e308

Please sign in to comment.