From 2d7dfa9a015ca8298f28d77ef322272e38a9024a Mon Sep 17 00:00:00 2001
From: osoleve
Date: Sat, 12 Oct 2024 15:13:28 -0400
Subject: [PATCH] Rewrite translate_keys

---
 download_weights.py | 69 +++++++++++++++++++--------------------------
 1 file changed, 29 insertions(+), 40 deletions(-)

diff --git a/download_weights.py b/download_weights.py
index b8d19d3..b7449c9 100644
--- a/download_weights.py
+++ b/download_weights.py
@@ -9,52 +9,41 @@ from unittest.mock import patch
 from transformers.dynamic_module_utils import get_imports
 
 
-def translate_key(in_key: str):
-    out_key = in_key.replace('.weight', '')
-    if out_key.startswith('model.'):
-        out_key = out_key.replace('model.', '')
-        if out_key.endswith('input_layernorm'):
-            out_key = out_key.replace('input_layernorm', 'attention_norm')
-        elif out_key.endswith('mlp.down_proj'):
-            out_key = out_key.replace('mlp.down_proj', 'feed_forward.w2')
-        elif out_key.endswith('mlp.gate_proj'):
-            out_key = out_key.replace('mlp.gate_proj', 'feed_forward.w1')
-        elif out_key.endswith('mlp.up_proj'):
-            out_key = out_key.replace('mlp.up_proj', 'feed_forward.w3')
-        elif out_key.endswith('post_attention_layernorm'):
-            out_key = out_key.replace('post_attention_layernorm', 'ffn_norm')
-        elif out_key.endswith('self_attn.k_proj'):
-            out_key = out_key.replace('self_attn.k_proj', 'attention.wk')
-        elif out_key.endswith('self_attn.o_proj'):
-            out_key = out_key.replace('self_attn.o_proj', 'attention.wo')
-        elif out_key.endswith('self_attn.q_proj'):
-            out_key = out_key.replace('self_attn.q_proj', 'attention.wq')
-        elif out_key.endswith('self_attn.v_proj'):
-            out_key = out_key.replace('self_attn.v_proj', 'attention.wv')
-        elif out_key.endswith('down_proj'):
-            out_key = out_key.replace('down_proj', 'w2')
-        elif out_key.endswith('gate_proj'):
-            out_key = out_key.replace('gate_proj', 'w1')
-        elif out_key.endswith('up_proj'):
-            out_key = out_key.replace('up_proj', 'w3')
-        elif out_key == 'embed_tokens':
-            out_key = 'tok_embeddings'
-        elif out_key == 'norm':
-            out_key = 'norm'
-        else:
+
+def translate_key(in_key):
+    out_key = in_key.removeprefix("model.").removesuffix(".weight")
+
+    match out_key.split("."):
+        case [key] if key in {"lm_head", "embed_tokens", "norm"}:
+            out_key = out_key.replace("embed_tokens", "tok_embeddings")
+            out_key = out_key.replace("lm_head", "output")
+
+        case [*_, key] if key.endswith("layernorm"):
+            out_key = out_key.replace("layernorm", "norm")
+            out_key = out_key.replace("input", "attention")
+            out_key = out_key.replace("post_attention", "ffn")
+
+        case [*_, "self_attn", qkvo_proj]:
+            qkvo = qkvo_proj[0]
+            out_key = out_key.replace("self_attn", "attention")
+            out_key = out_key.replace(qkvo_proj, f"w{qkvo}")
+
+        case [*_, proj] if proj.endswith("proj"):
+            kind = proj[:-5]
+            n = {"gate": 1, "down": 2, "up": 3}[kind]
+            out_key = out_key.replace(proj, f"w{n}")
+            out_key = out_key.replace("mlp", "feed_forward")
+
+        case _:
             print(f"Don't know how to handle {in_key=}")
-    elif out_key == 'lm_head':
-        out_key = 'output'
-    else:
-        print(f"Don't know how to handle {in_key=}")
-    return f'{out_key}.weight'
+
+    return f"{out_key}.weight"
 
 
 def reverse_permute(tensor: torch.Tensor, n_heads: int = 32, dim1:int = 4096, dim2: int = 4096) -> torch.Tensor:
     return tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)
 
 
-
 def fixed_get_imports(filename: str | os.PathLike) -> list[str]:
     """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
     if not str(filename).endswith("/modeling_deepseek.py"):
@@ -115,4 +104,4 @@ def main(model_id: str, out_dir: Path):
 
 
 if __name__ == "__main__":
-    tyro.cli(main)
\ No newline at end of file
+    tyro.cli(main)
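
Reviewer note, not part of the commit: a minimal sanity-check sketch of the
mapping the rewritten translate_key is expected to produce, assuming
Llama-style Hugging Face checkpoint key names; the sample keys below are
illustrative and not taken from the patch itself.

    # Assumes the patched download_weights.py is importable from the repo root.
    from download_weights import translate_key

    # (Hugging Face key, expected translated key) pairs.
    examples = [
        ("model.embed_tokens.weight", "tok_embeddings.weight"),
        ("lm_head.weight", "output.weight"),
        ("model.norm.weight", "norm.weight"),
        ("model.layers.0.input_layernorm.weight", "layers.0.attention_norm.weight"),
        ("model.layers.0.post_attention_layernorm.weight", "layers.0.ffn_norm.weight"),
        ("model.layers.0.self_attn.q_proj.weight", "layers.0.attention.wq.weight"),
        ("model.layers.0.mlp.gate_proj.weight", "layers.0.feed_forward.w1.weight"),
    ]
    for hf_key, expected in examples:
        assert translate_key(hf_key) == expected, (hf_key, translate_key(hf_key))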