Skip to content

Commit

Permalink
fix issues for mpt models
Browse files Browse the repository at this point in the history
  • Loading branch information
april-yyt committed Apr 5, 2024
1 parent ee41f3a commit 8efb92b
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions python/flexflow/serve/models/mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ def build_model(self, max_tokens_per_batch):

self.ffmodel = ffmodel

# TODO: finish this
def convert_hf_weight_name(name):
return (
name.replace("transformer.blocks.", "layers.")
Expand Down Expand Up @@ -293,11 +292,14 @@ def convert_ff_weight_name(name):
if "norm_f" in converted_name or "wte" in converted_name:
converted_name = converted_name.replace("_", ".").replace("norm.f", "norm_f")

converted_name = converted_name.replace("attention_wo", "attn.out_proj")
converted_name = converted_name.replace("attn.o_proj", "attn.out_proj")
converted_name = converted_name.replace("ffn_", "ffn.")
converted_name = re.sub(r"layers.(\d+).", r"transformer.blocks.\1.", converted_name)
converted_name = re.sub(r"_(bias|weight)$", r".\1", converted_name)

if ("wte" in converted_name) or ("norm_f" in converted_name):
converted_name = "transformer." + converted_name

return converted_name

def load_weights_into_hf_model(model, src_folder):
Expand All @@ -320,7 +322,6 @@ def load_weights_into_hf_model(model, src_folder):
print("skipping rev_sha.txt")
continue
elif "lm_head" in weight_path:
# todo: double check how to handle lm_head in uploading mpt models
print("skipping lm_head.weight")
continue
else:
Expand All @@ -331,9 +332,10 @@ def load_weights_into_hf_model(model, src_folder):
raise FileNotFoundError(f"No weight file found for {file_name}")

weight_data = np.fromfile(weight_path, dtype=np.float32)
print(f"Data type after conversion: {weight_data.dtype}, Size: {weight_data.size}")

# Special handling for combined QKV weights
if ("wq" in file_name) or ("wk" in file_name) or ("wv" in file_name):
if ("q_proj" in file_name) or ("k_proj" in file_name) or ("v_proj" in file_name):
layer_num_match = re.search(r"layers\.(\d+)", original_name)
layer_num = int(layer_num_match.group(1)) if layer_num_match else None
qkv_type = original_name.split("_")[-2]
Expand Down

0 comments on commit 8efb92b

Please sign in to comment.