Bug fixes, uploaded missing cpp implementation
yingchen21 committed Sep 18, 2024
1 parent 6533245 commit 281a8bf
Showing 6 changed files with 114 additions and 23 deletions.
6 changes: 3 additions & 3 deletions inference/models/falcon.cc
@@ -104,14 +104,14 @@ void FALCON::create_falcon_model(FFModel &ff,
3, // q, k, v. need to change if want to remove replication.
// (q_heads + 2 * kv_heads) * proj_size
AC_MODE_NONE,
- false, // seems like llama does not use bias
+ false, // seems like it does not use bias
DT_NONE, // what is this
nullptr, // ?
nullptr, // ?
nullptr, // ?
REG_MODE_NONE, // no regularization
0.0f, // no dropout
std::string("layers." + std::to_string(i) + ".self_attn.qkv_proj")
std::string("layers." + std::to_string(i) + ".self_attention.qkv_proj")
.c_str());
qkv_proj->print("qkv_proj");

@@ -206,7 +206,7 @@ void FALCON::create_falcon_model(FFModel &ff,
nullptr,
REG_MODE_NONE,
0.0f,
std::string("layers." + std::to_string(i) + ".self_attn.o_proj")
std::string("layers." + std::to_string(i) + ".self_attention.o_proj")
.c_str());
mha->print("mha");

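The renames above matter because the loader in file_loader.cc (last hunk below) keys off substrings of these layer names: "self_attn." is not a substring of "self_attention.", so a check for one spelling misses the other. A standalone illustration of that substring behavior, assuming nothing beyond std::string:

  #include <iostream>
  #include <string>

  int main() {
    // Layer name as falcon.cc now registers it.
    std::string falcon_name = "layers.0.self_attention.qkv_proj";
    // "self_attn." does not occur inside "self_attention.", so a check for it misses this name...
    std::cout << (falcon_name.find("self_attn.") != std::string::npos) << "\n";      // prints 0
    // ...while an explicit "self_attention." check matches.
    std::cout << (falcon_name.find("self_attention.") != std::string::npos) << "\n"; // prints 1
    return 0;
  }
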
44 changes: 37 additions & 7 deletions inference/models/mpt.cc
@@ -93,11 +93,27 @@ void MPT::create_mpt_model(FFModel &ff,
layernorm_output = res_ln_outputs[1];
}

- Tensor attn_outputs;
+ Tensor qkv_proj = ff.dense(
+     layernorm_output,
+     mpt_config.hidden_size *
+         3, // q, k, v. need to change if want to remove replication.
+            // (q_heads + 2 * kv_heads) * proj_size
+     AC_MODE_NONE,
+     false, // seems like it does not use bias
+     DT_NONE, // what is this
+     nullptr, // ?
+     nullptr, // ?
+     nullptr, // ?
+     REG_MODE_NONE, // no regularization
+     0.0f, // no dropout
+     std::string("layers." + std::to_string(i) + ".attn.qkv_proj")
+         .c_str());
+
+ Tensor o_proj;
switch (mode) {
case BEAM_SEARCH_MODE: {
- attn_outputs = ff.spec_inc_multihead_self_attention(
-     layernorm_output,
+ o_proj = ff.spec_inc_multihead_self_attention(
+     qkv_proj,
mpt_config.hidden_size,
mpt_config.n_heads,
mpt_config.hidden_size / mpt_config.n_heads,
@@ -120,8 +136,8 @@ void MPT::create_mpt_model(FFModel &ff,
break;
}
case TREE_VERIFY_MODE: {
- attn_outputs = ff.inc_multihead_self_attention_verify(
-     layernorm_output,
+ o_proj = ff.inc_multihead_self_attention_verify(
+     qkv_proj,
mpt_config.hidden_size,
mpt_config.n_heads,
mpt_config.hidden_size / mpt_config.n_heads,
@@ -144,8 +160,8 @@ void MPT::create_mpt_model(FFModel &ff,
break;
}
case INC_DECODING_MODE: {
- attn_outputs = ff.inc_multihead_self_attention(
-     layernorm_output,
+ o_proj = ff.inc_multihead_self_attention(
+     qkv_proj,
mpt_config.hidden_size,
mpt_config.n_heads,
mpt_config.hidden_size / mpt_config.n_heads,
@@ -172,6 +188,20 @@ void MPT::create_mpt_model(FFModel &ff,
}
}

+ Tensor attn_outputs = ff.dense(
+     o_proj,
+     mpt_config.hidden_size,
+     AC_MODE_NONE,
+     false,
+     DT_NONE,
+     nullptr,
+     nullptr,
+     nullptr,
+     REG_MODE_NONE,
+     0.0f,
+     std::string("layers." + std::to_string(i) + ".attn.o_proj")
+         .c_str());
+
ff.residual_layer_norm(
attn_outputs,
hidden_states,
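
mpt.cc (and opt.cc and starcoder.cc below) now size the fused QKV projection at 3 * hidden_size, i.e. Q, K and V each at full width; the in-code comment gives the general sizing (q_heads + 2 * kv_heads) * proj_size for when the KV replication is removed. A minimal sketch of that arithmetic — the helper name and example sizes are illustrative, not from the repository:

  #include <cassert>
  #include <cstdio>

  // Width of a fused QKV projection. With kv_heads == q_heads this reduces to the
  // replicated 3 * hidden_size used in the hunks above; kv_heads < q_heads covers
  // grouped- or multi-query attention.
  int qkv_proj_out_dim(int hidden_size, int q_heads, int kv_heads) {
    assert(hidden_size % q_heads == 0);
    int proj_size = hidden_size / q_heads;        // per-head dimension
    return (q_heads + 2 * kv_heads) * proj_size;  // Q width + K width + V width
  }

  int main() {
    printf("%d\n", qkv_proj_out_dim(4096, 32, 32)); // 12288 == 3 * 4096 (replicated KV)
    printf("%d\n", qkv_proj_out_dim(4096, 32, 1));  // 4352 (single KV head, multi-query)
    return 0;
  }
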
44 changes: 37 additions & 7 deletions inference/models/opt.cc
@@ -101,11 +101,27 @@ void OPT::create_opt_model(FFModel &ff,
Tensor residual = res_ln_outputs[0];
Tensor hidden_states = res_ln_outputs[1];

- Tensor mha;
+ Tensor qkv_proj = ff.dense(
+     hidden_states,
+     opt_config.hidden_size *
+         3, // q, k, v. need to change if want to remove replication.
+            // (q_heads + 2 * kv_heads) * proj_size
+     AC_MODE_NONE,
+     false, // seems like it does not use bias
+     DT_NONE, // what is this
+     nullptr, // ?
+     nullptr, // ?
+     nullptr, // ?
+     REG_MODE_NONE, // no regularization
+     0.0f, // no dropout
+     std::string("layers." + std::to_string(i) + ".self_attn.qkv_proj")
+         .c_str());
+
+ Tensor o_proj;
switch (mode) {
case BEAM_SEARCH_MODE: {
- mha = ff.spec_inc_multihead_self_attention(
-     hidden_states,
+ o_proj = ff.spec_inc_multihead_self_attention(
+     qkv_proj,
opt_config.hidden_size,
opt_config.num_attention_heads,
opt_config.hidden_size / opt_config.num_attention_heads,
@@ -128,8 +144,8 @@ void OPT::create_opt_model(FFModel &ff,
break;
}
case TREE_VERIFY_MODE: {
- mha = ff.inc_multihead_self_attention_verify(
-     hidden_states,
+ o_proj = ff.inc_multihead_self_attention_verify(
+     qkv_proj,
opt_config.hidden_size,
opt_config.num_attention_heads,
opt_config.hidden_size / opt_config.num_attention_heads,
@@ -152,8 +168,8 @@ void OPT::create_opt_model(FFModel &ff,
break;
}
case INC_DECODING_MODE: {
- mha = ff.inc_multihead_self_attention(
-     hidden_states,
+ o_proj = ff.inc_multihead_self_attention(
+     qkv_proj,
opt_config.hidden_size,
opt_config.num_attention_heads,
opt_config.hidden_size / opt_config.num_attention_heads,
@@ -180,6 +196,20 @@ void OPT::create_opt_model(FFModel &ff,
}
}

+ Tensor mha = ff.dense(
+     o_proj,
+     opt_config.hidden_size,
+     AC_MODE_NONE,
+     false,
+     DT_NONE,
+     nullptr,
+     nullptr,
+     nullptr,
+     REG_MODE_NONE,
+     0.0f,
+     std::string("layers." + std::to_string(i) + ".self_attn.o_proj")
+         .c_str());
+
ff.add_bias_residual_layer_norm(mha,
residual,
res_ln_outputs,
35 changes: 33 additions & 2 deletions inference/models/starcoder.cc
@@ -102,11 +102,28 @@ void STARCODER::create_starcoder_model(
Tensor hidden_states = res_ln_outputs[0];
Tensor ln_1 = res_ln_outputs[1];

+ Tensor qkv_proj = ff.dense(
+     ln_1,
+     startcoder_config.hidden_size *
+         3, // q, k, v. need to change if want to remove replication.
+            // (q_heads + 2 * kv_heads) * proj_size
+     AC_MODE_NONE,
+     false, // seems like it does not use bias
+     DT_NONE, // what is this
+     nullptr, // ?
+     nullptr, // ?
+     nullptr, // ?
+     REG_MODE_NONE, // no regularization
+     0.0f, // no dropout
+     std::string("layers." + std::to_string(i) + ".self_attention.qkv_proj")
+         .c_str());
+
Tensor mha;
+ Tensor o_proj;
switch (mode) {
case INC_DECODING_MODE: {
- mha = ff.inc_multiquery_self_attention(
-     ln_1,
+ o_proj = ff.inc_multiquery_self_attention(
+     qkv_proj,
startcoder_config.hidden_size,
startcoder_config.num_attention_heads,
1,
@@ -135,6 +152,20 @@ void STARCODER::create_starcoder_model(
}
}

+ mha = ff.dense(
+     o_proj,
+     startcoder_config.hidden_size,
+     AC_MODE_NONE,
+     false,
+     DT_NONE,
+     nullptr,
+     nullptr,
+     nullptr,
+     REG_MODE_NONE,
+     0.0f,
+     std::string("layers." + std::to_string(i) + ".self_attn.o_proj")
+         .c_str());
+
ff.residual_layer_norm(
hidden_states,
mha,
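
The attention operators above now consume the fused qkv_proj tensor directly. For orientation only, this is the conventional way an activation of width 3 * hidden_size splits back into Q, K and V; the diff does not show FlexFlow's kernel internals, so the layout and the split_fused_qkv helper here are an illustrative assumption, not repository code:

  #include <cassert>
  #include <vector>

  struct QKV {
    std::vector<float> q, k, v;
  };

  // Slice a fused [3 * hidden_size] activation into Q, K, V of [hidden_size] each.
  QKV split_fused_qkv(const std::vector<float> &fused, int hidden_size) {
    assert(static_cast<int>(fused.size()) == 3 * hidden_size);
    QKV out;
    out.q.assign(fused.begin(), fused.begin() + hidden_size);
    out.k.assign(fused.begin() + hidden_size, fused.begin() + 2 * hidden_size);
    out.v.assign(fused.begin() + 2 * hidden_size, fused.end());
    return out;
  }

  int main() {
    int hidden_size = 4; // toy size for the example
    std::vector<float> fused(3 * hidden_size, 0.0f);
    QKV qkv = split_fused_qkv(fused, hidden_size);
    assert(qkv.q.size() == 4 && qkv.k.size() == 4 && qkv.v.size() == 4);
    return 0;
  }
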
6 changes: 3 additions & 3 deletions python/flexflow/serve/models/mpt.py
@@ -131,10 +131,10 @@ def build_model(self, max_tokens_per_batch):

qkv_proj = ffmodel.dense(
layernorm_output,
- 3 * self.falcon_config.hidden_size,
+ 3 * self.mpt_config.hidden_size,
ActiMode.AC_MODE_NONE,
False,
name=f"layers.{i}.self_attn.qkv_proj",
name=f"layers.{i}.attn.qkv_proj",
)

if self.mode == InferenceMode.BEAM_SEARCH_MODE:
@@ -208,7 +208,7 @@ def build_model(self, max_tokens_per_batch):
self.mpt_config.hidden_size,
ActiMode.AC_MODE_NONE,
False,
name=f"layers.{i}.self_attn.o_proj"
name=f"layers.{i}.attn.o_proj"
)

hidden_states, layernorm_output = ffmodel.residual_layer_norm(
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/file_loader.cc
@@ -936,7 +936,7 @@ void FileDataLoader::load_single_weight_tensor(FFModel *ff,
// self_attn.qkv_proj or self_attn.o_proj
// so looking for self_attn. in the name can determine if it is an attention
// projection
if (weight_filename.find("self_attn.") != std::string::npos || weight_filename.find("self_attention.") != std::string::npos) {
if (weight_filename.find("attn.") != std::string::npos || weight_filename.find("self_attention.") != std::string::npos) {
size_t pos = weight_filename.find(".o_proj");
if (pos != std::string::npos) {
weight_filename.replace(pos, std::string(".o_proj").length(), "");
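
The loosened check above matches MPT's "attn.*" names (and, as a substring, the "self_attn.*" names) as well as the "self_attention.*" names registered in falcon.cc and starcoder.cc above. A standalone sketch of just this branch — the hunk only shows the ".o_proj" strip, and the real loader does more than this:

  #include <iostream>
  #include <string>

  // Treat a weight file as an attention projection if its name contains "attn."
  // (which also covers "self_attn.") or "self_attention.", then strip ".o_proj"
  // so the name maps back to the attention layer.
  std::string normalize_attention_weight(std::string weight_filename) {
    if (weight_filename.find("attn.") != std::string::npos ||
        weight_filename.find("self_attention.") != std::string::npos) {
      size_t pos = weight_filename.find(".o_proj");
      if (pos != std::string::npos) {
        weight_filename.replace(pos, std::string(".o_proj").length(), "");
      }
    }
    return weight_filename;
  }

  int main() {
    std::cout << normalize_attention_weight("layers.0.attn.o_proj") << "\n";           // layers.0.attn
    std::cout << normalize_attention_weight("layers.0.self_attn.o_proj") << "\n";      // layers.0.self_attn
    std::cout << normalize_attention_weight("layers.0.self_attention.o_proj") << "\n"; // layers.0.self_attention
    return 0;
  }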
