From 8efb92bd7650dbe2219c0dbf46bbc936c40ab626 Mon Sep 17 00:00:00 2001
From: April Yang
Date: Fri, 5 Apr 2024 00:23:01 +0000
Subject: [PATCH] fix issues for mpt models

---
 python/flexflow/serve/models/mpt.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/python/flexflow/serve/models/mpt.py b/python/flexflow/serve/models/mpt.py
index e67ccc42d4..9d4d0f31bb 100644
--- a/python/flexflow/serve/models/mpt.py
+++ b/python/flexflow/serve/models/mpt.py
@@ -250,7 +250,6 @@ def build_model(self, max_tokens_per_batch):
 
         self.ffmodel = ffmodel
 
-    # TODO: finish this
     def convert_hf_weight_name(name):
         return (
             name.replace("transformer.blocks.", "layers.")
@@ -293,11 +292,14 @@ def convert_ff_weight_name(name):
         if "norm_f" in converted_name or "wte" in converted_name:
             converted_name = converted_name.replace("_", ".").replace("norm.f", "norm_f")
 
-        converted_name = converted_name.replace("attention_wo", "attn.out_proj")
+        converted_name = converted_name.replace("attn.o_proj", "attn.out_proj")
         converted_name = converted_name.replace("ffn_", "ffn.")
         converted_name = re.sub(r"layers.(\d+).", r"transformer.blocks.\1.", converted_name)
         converted_name = re.sub(r"_(bias|weight)$", r".\1", converted_name)
+
+        if ("wte" in converted_name) or ("norm_f" in converted_name):
+            converted_name = "transformer." + converted_name
 
         return converted_name
 
     def load_weights_into_hf_model(model, src_folder):
@@ -320,7 +322,6 @@ def load_weights_into_hf_model(model, src_folder):
                 print("skipping rev_sha.txt")
                 continue
             elif "lm_head" in weight_path:
-                # todo: double check how to handle lm_head in uploading mpt models
                 print("skipping lm_head.weight")
                 continue
             else:
@@ -331,9 +332,10 @@
                 raise FileNotFoundError(f"No weight file found for {file_name}")
 
             weight_data = np.fromfile(weight_path, dtype=np.float32)
+            print(f"Data type after conversion: {weight_data.dtype}, Size: {weight_data.size}")
 
             # Special handling for combined QKV weights
-            if ("wq" in file_name) or ("wk" in file_name) or ("wv" in file_name):
+            if ("q_proj" in file_name) or ("k_proj" in file_name) or ("v_proj" in file_name):
                 layer_num_match = re.search(r"layers\.(\d+)", original_name)
                 layer_num = int(layer_num_match.group(1)) if layer_num_match else None
                 qkv_type = original_name.split("_")[-2]