diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 57225243..2630e1d6 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -48,6 +48,9 @@ class LlamaConfig: rms_norm_eps: float = 1e-6 rope_scaling: Optional[dict] = None rope_theta: float = 10000.0 + rope_interleaved: bool = ( + True # Defaults to True for backward compatibility, but for loading Llama3 checkpoints you have to set it to False + ) tie_word_embeddings: bool = False use_cache: bool = True vocab_size: int = 32000 diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 5f77100e..b9ec5deb 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,7 +14,7 @@ # limitations under the License. """PyTorch LLaMa model.""" -from typing import Dict, Optional, Union, List +from typing import Dict, Optional, Union import torch from torch import nn @@ -189,35 +189,21 @@ def __init__(self, config: LlamaConfig, parallel_config: Optional[ParallelismArg @checkpoint_method(attr_name="checkpoint_attention") def forward( self, - query_states: torch.Tensor, # [batch_size * q_length, n_local_q_heads, inner_dim] - key_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim] - value_states: torch.Tensor, # [batch_size * kv_length, n_local_kv_heads, inner_dim] - q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size) - kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size) + query_states: torch.Tensor, # [batch_size, q_length, n_local_q_heads, inner_dim] + key_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim] + value_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim] ): - from flash_attn.flash_attn_interface import flash_attn_varlen_func - - # TODO @thomasw21: Compute once, instead of computing for each layers. - cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:]) - torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:]) - - # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not - # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache. - causal = False if q_sequence_mask.shape[1] == 1 else True + from flash_attn.flash_attn_interface import flash_attn_func # NOTE: this scale is for µTransfer, # in SP, we use sqrt(1/d_h) softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None - attn_output = flash_attn_varlen_func( + # For now we assume a causal mask.
No magic here + causal = True + attn_output = flash_attn_func( q=query_states, k=key_states, v=value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=q_sequence_mask.shape[1], - max_seqlen_k=kv_sequence_mask.shape[1], dropout_p=0.0, softmax_scale=softmax_scale, causal=causal, @@ -325,7 +311,9 @@ def __init__( ) # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet) - self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, base=config.rope_theta, interleaved=True) + self.flash_rotary_embedding = FlashRotaryEmbedding( + dim=self.d_qk, interleaved=config.rope_interleaved, base=config.rope_theta + ) self.o_proj = TensorParallelRowLinear( config.num_attention_heads * self.d_qk, @@ -567,29 +555,14 @@ def forward( # [batch_size, seq_length, num_heads, d_qk] key_states, value_states = torch.split(key_value_states, 1, dim=2) - q_sequence_mask = sequence_mask - kv_sequence_mask = sequence_mask - kv_length = key_states.shape[1] - # [batch_size, seq_length, num_heads, d_qk] - # Shaping for use in `flash-attn` version of flash-attn: `flash_attn_unpadded_func` - query_states = query_states.view( - batch_size * q_length, self.n_local_q_heads, self.d_qk - ) # [batch_size * q_length, self.n_heads, d_qk] - - key_states = key_states.view( - batch_size * kv_length, self.n_local_kv_heads, self.d_qk - ) # [batch_size * kv_length, self.n_heads, d_qk] - value_states = value_states.view( - batch_size * kv_length, self.n_local_kv_heads, self.d_v - ) # [batch_size * kv_length, self.n_heads, d_v] + key_states = key_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_qk) + value_states = value_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_v) attention_output = self.attention( query_states=query_states, key_states=key_states, value_states=value_states, - q_sequence_mask=q_sequence_mask, - kv_sequence_mask=kv_sequence_mask, ) attention_output = ( diff --git a/tools/llama3/README.md b/tools/llama3/README.md new file mode 100644 index 00000000..048b16c9 --- /dev/null +++ b/tools/llama3/README.md @@ -0,0 +1,32 @@ +# Llama3 Weight Conversion Tool +This directory contains the scripts to convert the Llama3 checkpoints from HuggingFace to Nanotron and vice versa. + +## Downloading Llama3 weights +We will use the Llama3 checkpoints stored in the HuggingFace Hub for the conversion. Although the checkpoints can be downloaded directly by setting `--pretrained-model-name-or-path meta-llama/Meta-Llama-3-8B-Instruct`, this is not recommended, since it will download the pretrained weights to the [HuggingFace Cache](https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables#hfhubcache). We encourage you to download the checkpoints explicitly to a folder with the following script: +```python +from huggingface_hub import snapshot_download + +snapshot_download(repo_id="meta-llama/Meta-Llama-3-8B", + local_dir="models/Meta-Llama-3-8B", + local_dir_use_symlinks=False, + ignore_patterns=["original/*"]) # Llama3 models in the Hub contain the original checkpoints.
We only want the HF checkpoint stored in the safetensors format +``` + +## Conversion + +- Convert from HuggingFace to Nanotron + +`torchrun --nproc-per-node 1 tools/llama3/convert_hf_to_nanotron.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama-3-8B --pretrained-model-name-or-path meta-llama/Meta-Llama-3-8B-Instruct` +- Convert from Nanotron to HuggingFace + +`torchrun --nproc-per-node 1 tools/llama3/convert_nanotron_to_hf.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama-3-8B --hugging-face-checkpoint-path hf_checkpoints/Converted-Nanotron-Llama-3-8B` + +In summary, we will do the following: +- Initialize the HuggingFace model with the pretrained weights. The model definition is [here](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py). +- Initialize a Nanotron model with empty weights. The model definition is [here](https://github.com/huggingface/nanotron/blob/main/src/nanotron/models/llama.py). +- Copy the parameters layer by layer from one model to the other. +- Store the Nanotron model along with the tokenizer. + +When comparing the HuggingFace implementation with the Nanotron implementation, the main difference lies in the Q, K & V matrices and in the MLP projections. In the HuggingFace implementation, these matrices are kept separate [[1]](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L415), [[2]](https://github.com/huggingface/transformers/blob/1518508467d96b3866fc4ebcb7a5b3a2e0df2aa4/src/transformers/models/llama/modeling_llama.py#L194), while in the Nanotron implementation, they are concatenated [[1b]](https://github.com/huggingface/nanotron/blob/b69690703a1c41b60cd706f92a80a3d23ebaf2d0/src/nanotron/models/llama.py#L310), [[2b]](https://github.com/huggingface/nanotron/blob/b69690703a1c41b60cd706f92a80a3d23ebaf2d0/src/nanotron/models/llama.py#L149). It is crucial to pay attention to these details to convert the models correctly. + +To perform the conversion, we will need at least **1 GPU**, although the operations will be carried out on the **CPU**. We will convert the models with a parallel configuration of DP = PP = TP = 1, but it should be noted that the checkpoints generated by Nanotron are topology agnostic.
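To make the difference described above concrete, here is a minimal, standalone sketch (toy random tensors and Llama3-8B head counts assumed; it is not part of the conversion scripts) of how the separate HuggingFace projection weights map to the fused Nanotron `qkv_proj` and back:

```python
import torch

hidden_size = 4096               # Llama3-8B values, for illustration only
num_attention_heads = 32
num_key_value_heads = 8          # GQA: fewer KV heads than query heads
d_qk = hidden_size // num_attention_heads

# HuggingFace keeps three separate projection weights, each [out_features, in_features]
q_proj = torch.randn(num_attention_heads * d_qk, hidden_size)
k_proj = torch.randn(num_key_value_heads * d_qk, hidden_size)
v_proj = torch.randn(num_key_value_heads * d_qk, hidden_size)

# HF -> Nanotron: the fused qkv_proj weight is the row-wise concatenation of Q, K and V
qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=0)

# Nanotron -> HF: split the fused weight back using the per-projection row counts
q, k, v = torch.split(
    qkv_proj,
    [num_attention_heads * d_qk, num_key_value_heads * d_qk, num_key_value_heads * d_qk],
    dim=0,
)
assert torch.equal(q, q_proj) and torch.equal(k, k_proj) and torch.equal(v, v_proj)

# The MLP follows the same pattern: gate_up_proj = torch.cat([gate_proj, up_proj], dim=0)
```

The conversion scripts below apply this same concatenation/split to the actual module weights, layer by layer.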
diff --git a/tools/llama3/convert_hf_to_nanotron.py b/tools/llama3/convert_hf_to_nanotron.py new file mode 100644 index 00000000..e30610a3 --- /dev/null +++ b/tools/llama3/convert_hf_to_nanotron.py @@ -0,0 +1,266 @@ +""" +torchrun --nproc-per-node 1 tools/llama3/convert_hf_to_nanotron.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama-3-8B --pretrained-model-name-or-path meta-llama/Meta-Llama-3-8B-Instruct +""" +import argparse +import json +from dataclasses import asdict +from pathlib import Path + +import torch +import yaml +from nanotron import logging +from nanotron.config import Config, GeneralArgs, LoggingArgs, ModelArgs, ParallelismArgs, TokenizerArgs +from nanotron.config.models_config import ExistingCheckpointInit +from nanotron.config.models_config import LlamaConfig as LlamaConfigNanotron +from nanotron.logging import log_rank, set_ranks_logging_level +from nanotron.models import build_model +from nanotron.models.llama import LlamaForTraining +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.serialize import TrainingMetadata, save_meta, save_weights +from nanotron.serialize.metadata import DataStageMetadata +from nanotron.trainer import mark_tied_parameters +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +logger = logging.get_logger(__name__) + +DEVICE = torch.device("cpu") +TORCH_DTYPE = torch.bfloat16 + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory to store the converted Nanotron Checkpoint", + ) + + group = parser.add_argument_group(title="HuggingFace Model") + group.add_argument( + "--pretrained-model-name-or-path", + type=str, + required=True, + help="A path to a directory containing model weights saved using save_pretrained() or the model id of a pretrained model hosted inside a model repo on the Hugging Face Hub", + ) + + args = parser.parse_args() + + return args + + +def main(args): + # Init Nanotron Parallel Utilities + parallel_config = ParallelismArgs(dp=1, pp=1, tp=1) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs()) + + # Load Llama3-8B HF model + log_rank( + f"Loading pretrained Llama3 Model: {args.pretrained_model_name_or_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + hf_model = AutoModelForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2" + ).to(DEVICE) + hf_config = hf_model.config + + # Set Nanotron LlamaConfig + nanotron_llama_config = LlamaConfigNanotron( + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id, + hidden_act=hf_config.hidden_act, + hidden_size=hf_config.hidden_size, + initializer_range=hf_config.initializer_range, + intermediate_size=hf_config.intermediate_size, + is_llama_config=True, + max_position_embeddings=hf_config.max_position_embeddings, + num_attention_heads=hf_config.num_attention_heads, + num_hidden_layers=hf_config.num_hidden_layers, + num_key_value_heads=hf_config.num_key_value_heads, + pad_token_id=None, + pretraining_tp=hf_config.pretraining_tp, + rms_norm_eps=hf_config.rms_norm_eps, + 
rope_scaling=hf_config.rope_scaling, + rope_theta=hf_config.rope_theta, + rope_interleaved=False, + tie_word_embeddings=hf_config.tie_word_embeddings, + use_cache=hf_config.use_cache, + vocab_size=hf_config.vocab_size, + ) + + # Init Llama3-8B Nanotron model + log_rank("Init empty Nanotron Llama3 Model", logger=logger, level=logging.INFO, rank=0) + nanotron_model = build_model( + model_builder=lambda: LlamaForTraining( + config=nanotron_llama_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=TORCH_DTYPE, + device=DEVICE, + ) + + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) + sanity_check(root_module=nanotron_model) + + # Copy params from HF to Nanotron + log_rank("Copying weights from HF model to Nanotron model...", logger=logger, level=logging.INFO, rank=0) + # Token embeddings + log_rank("Copying Token Embeddings...", logger=logger, level=logging.INFO, rank=0) + assert ( + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape + == hf_model.model.embed_tokens.weight.shape + ) + with torch.no_grad(): + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.copy_( + hf_model.model.embed_tokens.weight + ) + + # Decoder layers + for i in tqdm( + range(nanotron_llama_config.num_hidden_layers), + desc="Copying Hidden Layers", + total=nanotron_llama_config.num_hidden_layers, + ): + # Input layer norm + assert ( + hf_model.model.layers[i].input_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.shape + ) + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.copy_( + hf_model.model.layers[i].input_layernorm.weight + ) + + # Self attn + ## QKV + tmp_qkv_proj = torch.cat( + [ + hf_model.model.layers[i].self_attn.q_proj.weight, + hf_model.model.layers[i].self_attn.k_proj.weight, + hf_model.model.layers[i].self_attn.v_proj.weight, + ], + dim=0, + ) + assert tmp_qkv_proj.shape == nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.shape + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.copy_(tmp_qkv_proj) + + ## O + assert ( + hf_model.model.layers[i].self_attn.o_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.shape + ) + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.copy_( + hf_model.model.layers[i].self_attn.o_proj.weight + ) + + # MLP + ## Gate Up Proj + tmp_gate_up_proj = torch.cat( + [ + hf_model.model.layers[i].mlp.gate_proj.weight, + hf_model.model.layers[i].mlp.up_proj.weight, + ], + dim=0, + ) + + assert tmp_gate_up_proj.shape == nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight.shape + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight.copy_(tmp_gate_up_proj) + + ## Down Proj + assert ( + hf_model.model.layers[i].mlp.down_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.shape + ) + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.copy_( + hf_model.model.layers[i].mlp.down_proj.weight + ) + + # Post attn layer norm + assert ( + hf_model.model.layers[i].post_attention_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.shape + ) + with torch.no_grad(): + nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.copy_( + 
hf_model.model.layers[i].post_attention_layernorm.weight + ) + + # Last layer norm + log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape + with torch.no_grad(): + nanotron_model.model.final_layer_norm.pp_block.weight.copy_(hf_model.model.norm.weight) + + # LM_Head + log_rank("Copying LM Head...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape + with torch.no_grad(): + nanotron_model.model.lm_head.pp_block.weight.copy_(hf_model.lm_head.weight) + + log_rank("Copied weights from HF model to Nanotron model!", logger=logger, level=logging.INFO, rank=0) + # Store weights + nanotron_checkpoint_path = Path(args.nanotron_checkpoint_path) + save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=nanotron_checkpoint_path) + + # Store metadata + log_rank("Storing Nanotron model Configs and Metadata!", logger=logger, level=logging.INFO, rank=0) + training_metadata = TrainingMetadata( + last_train_step=0, + consumed_train_samples=0, + data_stages=[DataStageMetadata(name="Empty", consumed_train_samples=0, start_training_step=0)], + ) + save_meta( + root_folder=nanotron_checkpoint_path, parallel_context=parallel_context, training_metadata=training_metadata + ) + # Store Tokenizer into Nanotron Checkpoint folder + tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path) + tokenizer.save_pretrained(nanotron_checkpoint_path) + + # Store Config and Model Config files + with open(nanotron_checkpoint_path / "config.yaml", "w") as f: + config = Config( + general=GeneralArgs(project="Nanotron", run="Llama3"), + parallelism=parallel_config, + model=ModelArgs( + init_method=ExistingCheckpointInit(nanotron_checkpoint_path), + model_config=nanotron_llama_config, + ), + tokenizer=TokenizerArgs(nanotron_checkpoint_path), + ) + log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0) + yaml.dump(config.as_dict(), f) + + with open(nanotron_checkpoint_path / "model_config.json", "w") as f: + log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0) + json.dump(asdict(nanotron_llama_config), f) + + log_rank( + f"Checkpoint conversion finished, check {args.nanotron_checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + +if __name__ == "__main__": + _args = get_args() + main(_args) diff --git a/tools/llama3/convert_nanotron_to_hf.py b/tools/llama3/convert_nanotron_to_hf.py new file mode 100644 index 00000000..c5fb1940 --- /dev/null +++ b/tools/llama3/convert_nanotron_to_hf.py @@ -0,0 +1,229 @@ +""" +torchrun --nproc-per-node 1 tools/llama3/convert_nanotron_to_hf.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama-3-8B --hugging-face-checkpoint-path hf_checkpoints/Converted-Nanotron-Llama-3-8B +""" +import argparse +import os +from dataclasses import asdict +from pathlib import Path + +import torch +from nanotron import logging +from nanotron.config import Config, LoggingArgs, ParallelismArgs, get_config_from_file +from nanotron.logging import log_rank, set_ranks_logging_level +from nanotron.models import build_model +from nanotron.models.llama import LlamaForTraining +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.serialize import load_weights +from nanotron.trainer import mark_tied_parameters +from tqdm import tqdm 
+from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.models.llama import LlamaConfig as LlamaConfigHF + +logger = logging.get_logger(__name__) + +DEVICE = torch.device("cpu") +TORCH_DTYPE = torch.bfloat16 + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory with a Nanotron Checkpoint", + ) + + group = parser.add_argument_group(title="HuggingFace Model") + group.add_argument( + "--hugging-face-checkpoint-path", + type=str, + required=True, + help="A path to a directory to store the converted checkpoint", + ) + + args = parser.parse_args() + + return args + + +def main(args): + # Init Nanotron Parallel Utilities + parallel_config = ParallelismArgs(dp=1, pp=1, tp=1) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs()) + + # Load Nanotron checkpoint config + log_rank( + f"Loading Nanotron checkpoint config file: {os.path.join(args.nanotron_checkpoint_path, 'config.yaml')}", + logger=logger, + level=logging.INFO, + rank=0, + ) + nanotron_config = get_config_from_file( + os.path.join(args.nanotron_checkpoint_path, "config.yaml"), config_class=Config, model_config_class=None + ) + nanotron_llama_config = nanotron_config.model.model_config + + # Init Llama3-8B Nanotron model + log_rank("Init empty Nanotron Llama3 Model", logger=logger, level=logging.INFO, rank=0) + + nanotron_model = build_model( + model_builder=lambda: LlamaForTraining( + config=nanotron_config.model.model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=TORCH_DTYPE, + device=DEVICE, + ) + + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) + sanity_check(root_module=nanotron_model) + + # Load Nanotron Checkpoint + log_rank("Loading Nanotron Llama3 Model...", logger=logger, level=logging.INFO, rank=0) + load_weights( + model=nanotron_model, parallel_context=parallel_context, root_folder=Path(args.nanotron_checkpoint_path) + ) + + # Build empty HF Model + log_rank("Init empty HF Llama3 Model", logger=logger, level=logging.INFO, rank=0) + hf_model = AutoModelForCausalLM.from_config( # WARN This takes a long time + config=LlamaConfigHF(**asdict(nanotron_llama_config)), + torch_dtype=TORCH_DTYPE, + attn_implementation="flash_attention_2", + ).to(DEVICE) + + # Copy params from Nanotron to HF + log_rank("Copying weights from Nanotron model to HF model...", logger=logger, level=logging.INFO, rank=0) + # Token embeddings + log_rank("Copying Token Embeddings...", logger=logger, level=logging.INFO, rank=0) + assert ( + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape + == hf_model.model.embed_tokens.weight.shape + ) + with torch.no_grad(): + hf_model.model.embed_tokens.weight.copy_( + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight + ) + + # Decoder layers + for i in tqdm( + range(nanotron_llama_config.num_hidden_layers), + desc="Copying Hidden Layers", + total=nanotron_llama_config.num_hidden_layers, + ): + # Input layer norm + assert ( + hf_model.model.layers[i].input_layernorm.weight.shape + == 
nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.shape + ) + with torch.no_grad(): + hf_model.model.layers[i].input_layernorm.weight.copy_( + nanotron_model.model.decoder[i].pp_block.input_layernorm.weight + ) + + # Self attn + # Split Nanotron qkv projection into q, k, v + q, k, v = torch.split( + nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight, + [ + nanotron_llama_config.num_attention_heads * nanotron_model.model.decoder[i].pp_block.attn.d_qk, + nanotron_llama_config.num_key_value_heads * nanotron_model.model.decoder[i].pp_block.attn.d_qk, + nanotron_llama_config.num_key_value_heads * nanotron_model.model.decoder[i].pp_block.attn.d_qk, + ], + ) + assert q.shape == hf_model.model.layers[i].self_attn.q_proj.weight.shape + assert k.shape == hf_model.model.layers[i].self_attn.k_proj.weight.shape + assert v.shape == hf_model.model.layers[i].self_attn.v_proj.weight.shape + + with torch.no_grad(): + hf_model.model.layers[i].self_attn.q_proj.weight.copy_(q) + hf_model.model.layers[i].self_attn.k_proj.weight.copy_(k) + hf_model.model.layers[i].self_attn.v_proj.weight.copy_(v) + + ## O + assert ( + hf_model.model.layers[i].self_attn.o_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.shape + ) + with torch.no_grad(): + hf_model.model.layers[i].self_attn.o_proj.weight.copy_( + nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight + ) + + # MLP + ## Gate Up Proj + gate_proj, up_proj = torch.split( + nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight, + split_size_or_sections=[nanotron_llama_config.intermediate_size, nanotron_llama_config.intermediate_size], + ) + assert gate_proj.shape == hf_model.model.layers[i].mlp.gate_proj.weight.shape + assert up_proj.shape == hf_model.model.layers[i].mlp.up_proj.weight.shape + + with torch.no_grad(): + hf_model.model.layers[i].mlp.gate_proj.weight.copy_(gate_proj) + hf_model.model.layers[i].mlp.up_proj.weight.copy_(up_proj) + + ## Down Proj + assert ( + hf_model.model.layers[i].mlp.down_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.shape + ) + with torch.no_grad(): + hf_model.model.layers[i].mlp.down_proj.weight.copy_( + nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight + ) + + # Post attn layer norm + assert ( + hf_model.model.layers[i].post_attention_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.shape + ) + with torch.no_grad(): + hf_model.model.layers[i].post_attention_layernorm.weight.copy_( + nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight + ) + + # Last layer norm + log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape + with torch.no_grad(): + hf_model.model.norm.weight.copy_(nanotron_model.model.final_layer_norm.pp_block.weight) + + # LM_Head + log_rank("Copying LM Head...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape + with torch.no_grad(): + hf_model.lm_head.weight.copy_(nanotron_model.model.lm_head.pp_block.weight) + + log_rank("Copied weights from Nanotron model to HF model!", logger=logger, level=logging.INFO, rank=0) + # Store weights + log_rank("Storing HF model Checkpoint and Tokenizer!", logger=logger, level=logging.INFO, rank=0) + hf_model.save_pretrained(args.hugging_face_checkpoint_path, from_pt=True) + # 
Store tokenizer + tokenizer = AutoTokenizer.from_pretrained(nanotron_config.tokenizer.tokenizer_name_or_path) + tokenizer.save_pretrained(args.hugging_face_checkpoint_path) + + log_rank( + f"Checkpoint conversion finished, check {args.hugging_face_checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + +if __name__ == "__main__": + _args = get_args() + main(_args)
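Once the Nanotron to HF conversion has finished, a quick way to smoke-test the exported checkpoint is to load it with plain `transformers` and generate a few tokens (the path below is the one used in the README example; this check is not part of the scripts above):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "hf_checkpoints/Converted-Nanotron-Llama-3-8B"  # output path from the README example

tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16)

# Generate a short continuation to confirm the converted weights produce sensible text
inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```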