diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 49ea86e6..764ef3f3 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -14,7 +14,7 @@ # limitations under the License. """PyTorch LLaMa model.""" -from typing import Dict, Optional, Union, List +from typing import Dict, List, Optional, Union import torch from torch import nn @@ -74,9 +74,10 @@ def init_rotary_embeddings(self): self.freqs_cis = self.freqs_cis.to(torch.float) assert self.freqs_cis.dtype == torch.float freqs = 1.0 / ( - self.theta - ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda")[: (self.dim // 2)] / self.dim) - ) + self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu")[: (self.dim // 2)] / self.dim) + ).to( + "cuda" + ) # should be computed on CPU, otherwise different results with Transformers. t = torch.arange(self.end, device="cuda") freqs = torch.outer(t, freqs).float() complex_freqs = torch.polar(torch.ones_like(freqs), freqs) @@ -118,6 +119,84 @@ def forward( return x_out.type(dtype) +## Copy from transformers. Non interleaved version of RoPE. Will be refactored later +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim: int, end: int, theta: float = 500000.0): + super().__init__() + self.dim = dim + self.end = end + self.theta = theta + self.init_rotary_embeddings() + + def init_rotary_embeddings(self): + inv_freq = 1.0 / ( + self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu") / self.dim) + ) # important to compute on CPU + # inv_freq = apply_scaling(inv_freq) # if LLaMA 3.1 + self.register_buffer( + "inv_freq", torch.empty(self.dim // 2, dtype=torch.float, device="cuda"), persistent=False + ) + self.inv_freq = self.inv_freq.to( + torch.float + ) # make it float32 before copy to avoid precision loss during copy_ + self.inv_freq.copy_(inv_freq) + + saved_inv_freq = torch.load("/fsx/haojun/LLaMA/.cache/activation_values/inv_freq.pt") + assert torch.equal(self.inv_freq.cpu(), saved_inv_freq), "inv_freq mismatch." + + @torch.no_grad() + def forward( + self, + x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk] + position_ids: Optional[torch.LongTensor], # [batch_size, seq_length] + ): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + class GLUActivation(nn.Module): def __init__(self, act_fn_name: str): super().__init__() @@ -620,9 +699,9 @@ def __init__( self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) - + self.recompute_layer = parallel_config.recompute_layer - + def _core_forward( self, hidden_states: Union[torch.Tensor, TensorPointer], @@ -641,12 +720,12 @@ def _core_forward( hidden_states = hidden_states + residual return hidden_states, output["sequence_mask"] - + def _checkpointed_forward( self, hidden_states: torch.Tensor, sequence_mask: torch.Tensor, - ) -> List[torch.Tensor]: + ) -> List[torch.Tensor]: return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask) def forward( @@ -654,7 +733,7 @@ def forward( hidden_states: Union[torch.Tensor, TensorPointer], sequence_mask: Union[torch.Tensor, TensorPointer], ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: - + if self.recompute_layer and not isinstance(hidden_states, TensorPointer): hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask) else: @@ -665,6 +744,7 @@ def forward( "sequence_mask": sequence_mask, } + class Embedding(nn.Module, AttachableStore): def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]): super().__init__() diff --git a/src/nanotron/models/llama3_1.py b/src/nanotron/models/llama3_1.py new file mode 100644 index 00000000..41772b2b --- /dev/null +++ b/src/nanotron/models/llama3_1.py @@ -0,0 +1,1043 @@ +# coding=utf-8 +# Copyright 2018 HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +The LLaMA model that match the transformers LLaMA model by doing the following changes: +1 Merged QKV -> Separate QKV +2 Merged Gate/Up -> Separate Gate/Up +3 Triton RMSNorm -> LLaMA RMSNorm +4 Flash RoPR(training) -> RoPE +5 Interleaved RoPE -> Non interleaved RoPE +6 Core attention -> flash_attn_func +7 Same computation device as transformers (CPU then CUDA) +8 Fix precision bug +-> Exact logits match during generation. + +All the others are the same as the llama.py +Note: It's yet not clear which one should be kept in the future. What's the trade-off between the performance gain and the precision loss. + +""" + +import math +from typing import Dict, List, Optional, Union + +import torch +from torch import nn +from torch.utils.checkpoint import CheckpointFunction + +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import Config, LlamaConfig, ParallelismArgs +from nanotron.config.models_config import RandomInit, SpectralMupInit +from nanotron.generation.generate_store import AttachableStore +from nanotron.logging import log_rank +from nanotron.models import NanotronModel +from nanotron.models.llama import LlamaRotaryEmbedding, RotaryEmbedding, apply_rotary_pos_emb +from nanotron.nn.activations import ACT2FN +from nanotron.nn.layer_norm import RMSNorm +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import NanotronParameter +from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer +from nanotron.parallel.pipeline_parallel.p2p import P2P +from nanotron.parallel.ring_flash_attn.utils import zigzag_split +from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy +from nanotron.parallel.tensor_parallel.nn import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelLinearMode, + TensorParallelRowLinear, +) +from nanotron.random import RandomStates +from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator +from nanotron.utils import supports_flash_attention + +if supports_flash_attention(): + from flash_attn import bert_padding + from flash_attn.flash_attn_interface import ( + flash_attn_func, + flash_attn_varlen_func, + flash_attn_with_kvcache, + ) + from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding + + from nanotron.parallel.ring_flash_attn.zigzag_ring_flash_attn import zigzag_ring_flash_attn_func + + +logger = logging.get_logger(__name__) + + +def apply_scaling(freqs: torch.Tensor): + # Values obtained from grid search + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 # original llama3 length + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + + +class MLP(nn.Module): + def __init__( + self, + config: LlamaConfig, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + ): + super().__init__() + + # TODO @thomasw21: refactor so that we store that default in a single place. 
+ tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + self.gate_proj = TensorParallelColumnLinear( + config.hidden_size, + config.intermediate_size, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ) + self.up_proj = TensorParallelColumnLinear( + config.hidden_size, + config.intermediate_size, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ) + self.down_proj = TensorParallelRowLinear( + config.intermediate_size, + config.hidden_size, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + ) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] + gate_states, up_states = self.gate_proj(hidden_states), self.up_proj(hidden_states) + hidden_states = self.down_proj(self.act(gate_states) * up_states) + return {"hidden_states": hidden_states} + + +class RingFlashAttention(nn.Module): + def __init__(self, config: LlamaConfig, pg: dist.ProcessGroup): + super().__init__() + assert config.hidden_size % config.num_attention_heads == 0 + assert config.hidden_size % config.num_key_value_heads == 0 + assert dist.get_world_size(pg) > 1, "Ring attention process group size must be greater than 1" + self.pg = pg + + def forward( + self, + local_q: torch.Tensor, # [batch_size, q_length, n_local_q_heads, inner_dim] + local_k: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim] + local_v: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim] + ): + causal = True + ring_out, _, _ = zigzag_ring_flash_attn_func( + local_q, + local_k, + local_v, + dropout_p=0.0, + causal=causal, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=True, + group=self.pg, + ) + return ring_out + + +def pad_to_right(tensor, mask, new_tensor=None): + """Transform a left-padded tensor into a right-padded tensor. 
(Useful for prefilling key/value states) + Args: + tensor: (batch_size, seqlen, d1, d2) + mask: (batch_size, seqlen) + new_tensor: (batch_size, new_tensor_seqlen, d1, d2) + Returns: + new_tensor: (batch_size, new_tensor_seqlen, d1, d2) + right_padded_mask: (batch_size, seqlen) + """ + # First, we need to find the number of padding for each row + unpad_seqlens = mask.sum(1) + # Then, we need to find the maximum length of the tensor + max_seqlen = mask.shape[1] + # We can then create the indices to select the padded values + # The indices are the same for each row + indices = torch.arange(max_seqlen, device=mask.device) + # We can then create the mask for the padded values + right_padded_mask = indices < unpad_seqlens[:, None] + # We select the useful values + useful_values = tensor[mask] + # We create the new tensor (if not provided) + new_tensor = torch.zeros_like(tensor) if new_tensor is None else new_tensor + # We fill the new tensor with the useful values + new_tensor[:, : right_padded_mask.shape[1], :, :][right_padded_mask] = useful_values + return new_tensor, right_padded_mask + + +class CausalSelfAttention(nn.Module, AttachableStore): + def __init__( + self, + config: LlamaConfig, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + sp_pg: dist.ProcessGroup, + layer_idx: int, + ): + + super().__init__() + # Tensor parallel considerations: We split tensors along head dimension + assert ( + config.num_attention_heads % tp_pg.size() == 0 + ), f"Number of attention heads ({config.num_attention_heads}) must be divisible by TP size ({tp_pg.size()})." + try: + assert ( + config.num_key_value_heads % tp_pg.size() == 0 + ), f"Number of key/value heads ({config.num_key_value_heads}) must be divisible by TP size ({tp_pg.size()})." + except AttributeError: + log_rank( + "WARNING: num_key_value_heads not defined, assuming it is equal to num_attention_heads", + logger=logger, + level=logging.WARNING, + rank=0, + ) + # If num_key_value_heads is not defined, we assume that it is equal to num_attention_heads + config.num_key_value_heads = config.num_attention_heads + assert ( + config.num_attention_heads % config.num_key_value_heads == 0 + ), f"Number of attention heads ({config.num_attention_heads}) must be divisible by number of key/value heads ({config.num_key_value_heads})." + self.n_local_q_heads = config.num_attention_heads // tp_pg.size() + self.n_local_kv_heads = config.num_key_value_heads // tp_pg.size() + self.n_repeats = config.num_attention_heads // config.num_key_value_heads + self.is_gqa = config.num_attention_heads != config.num_key_value_heads # Whether we are using GQA or not + self.d_qk = config.hidden_size // config.num_attention_heads + self.d_v = config.hidden_size // config.num_attention_heads + self.d_model = config.hidden_size + self.is_using_mup = config.is_using_mup + + # TODO @thomasw21: refactor so that we store that default in a single place. 
+ tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + self.q_proj = TensorParallelColumnLinear( + self.d_model, + config.num_attention_heads * self.d_qk, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ) + self.k_proj = TensorParallelColumnLinear( + self.d_model, + config.num_key_value_heads * self.d_qk, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ) + self.v_proj = TensorParallelColumnLinear( + self.d_model, + config.num_key_value_heads * self.d_qk, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ) + # TODO(kunhao): We want to have only one version per device and not one version per layer. + # config.rope_interleaved = True + if config.rope_interleaved: + self.rotary_embedding = RotaryEmbedding( + dim=self.d_qk, + end=config.max_position_embeddings, + theta=config.rope_theta, + ) + else: + self.rotary_embedding = LlamaRotaryEmbedding( + dim=self.d_qk, + end=config.max_position_embeddings, + theta=config.rope_theta, + ) + self.rope_interleaved = config.rope_interleaved + + # NOTE: Only supported for training (TODO(fmom): position_ids not supported yet) + self.flash_rotary_embedding = FlashRotaryEmbedding( + dim=self.d_qk, interleaved=config.rope_interleaved, base=config.rope_theta + ) + + self.o_proj = TensorParallelRowLinear( + config.num_attention_heads * self.d_qk, + self.d_model, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ) + # Normal attention when Sequence parallelism group size = 1 + if sp_pg.size() > 1: + self.attention = RingFlashAttention( + config, + pg=sp_pg, + ) + self.sp_pg = sp_pg + self.prefill_kv_len = ( + config.max_position_embeddings + ) # TODO @nouamane: compute based on free memory, because in rope we can surpass max_position_embeddings + self.layer_idx = layer_idx + + def forward( + self, + hidden_states, # [seq_length, batch_size, hidden_size] + sequence_mask, # [batch_size, seq_length] + position_ids: Optional[torch.LongTensor] = None, + ): + q_length, batch_size, _ = hidden_states.shape + query_states = ( + self.q_proj(hidden_states) + .transpose(0, 1) + .contiguous() + .view(batch_size, q_length, self.n_local_q_heads, self.d_qk) + ) # [batch_size, q_length, n_local_q_heads, d_qk] + key_states = ( + self.k_proj(hidden_states) + .transpose(0, 1) + .contiguous() + .view(batch_size, q_length, self.n_local_kv_heads, self.d_qk) + ) # [batch_size, q_length, n_local_kv_heads, d_qk] + value_states = ( + self.v_proj(hidden_states) + .transpose(0, 1) + .contiguous() + .view(batch_size, q_length, self.n_local_kv_heads, self.d_qk) + ) # [batch_size, q_length, n_local_kv_heads, d_qk] + + store = self.get_local_store() + if store is not None: # Inference case + # Double check that we use store only at inference time + assert key_states.requires_grad is False + assert value_states.requires_grad is False + if "position_offsets" in store: + old_position_offsets = store["position_offsets"] + position_ids = old_position_offsets[:, None] + sequence_mask + else: + position_ids = torch.cumsum(sequence_mask, dim=-1, dtype=torch.int32) - 1 + position_offsets = position_ids[:, -1] + + # Compute rotary embeddings + # Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and 
v_cache + old_rotary_embed_end = self.rotary_embedding.end + + # interleaved version. + if self.rope_interleaved: + query_states = self.rotary_embedding(query_states, position_ids=position_ids) + key_states = self.rotary_embedding(key_states, position_ids=position_ids) + # non interleaved version. + else: + cos, sin = self.rotary_embedding(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if "key" not in store: + # First inference iteration (Prefill) + # TODO @nouamane: support custom masking + # assert that [ False, False, False, False, True, True, True, True, True, True] is accepted + # but [ False, False, False, False, True, True, False, False, True, True] is not (can't mask in the middle of sequence) + assert ~( + sequence_mask[:, :-1] & (~sequence_mask[:, 1:]) # True is never followed by False + ).any(), "Can't mask in the middle of sequence, please make sure that pads are at the left of the sequence if existing" + + # preallocate k_cache, v_cache to self.prefill_kv_len + k_cache = torch.zeros( + ( + batch_size, + self.prefill_kv_len, + self.n_local_kv_heads, + self.d_qk, + ), + dtype=query_states.dtype, + device=query_states.device, + ) + v_cache = torch.zeros( + (batch_size, self.prefill_kv_len, self.n_local_kv_heads, self.d_v), + dtype=query_states.dtype, + device=query_states.device, + ) + # Remove pad tokens from key_states and concatenate samples in key_unpad + # cu_seqlens_k is the cumulative sequence lengths of key_states + (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input( + query_states, + sequence_mask, + ) + (key_unpad, indices_k, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input( + key_states, sequence_mask + ) + (value_unpad, _, _, _) = bert_padding.unpad_input(value_states, sequence_mask) + + # NOTE: this scale is for µTransfer, + # in SP, we use sqrt(1/d_h) + softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None + output_unpad = flash_attn_varlen_func( + q=query_unpad, # (total_q, n_local_q_heads, d_qk) + k=key_unpad, # (total_kv, n_local_kv_heads, d_qk) + v=value_unpad, # (total_kv, n_local_kv_heads, d_v) + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=0.0, + softmax_scale=softmax_scale, + causal=True, # True in prefill phase, False in subsequent phases + return_attn_probs=False, + ) # (total_unpadded, n_local_q_heads, d_v) + + attention_output = bert_padding.pad_input( + output_unpad, indices_q, batch_size, q_length + ) # (batch_size, q_length, n_local_q_heads, d_v) + + pad_to_right(key_states, sequence_mask, new_tensor=k_cache) + pad_to_right(value_states, sequence_mask, new_tensor=v_cache) + + else: + # Pull pre-computed key/value states + # Subsequent inference iterations (q_length=1) + k_cache = store["key"] + v_cache = store["value"] + + # NOTE(fmom): According to flash_attn_with_kvcache, "If you pass in k / v, you must make sure that the cache is large enough to hold the new values" + # Since rotary embedding has changed (to enable larger context), we need to enlarge k_cache and v_cache + if self.rotary_embedding.end > old_rotary_embed_end: + k_cache = torch.cat( + [ + k_cache, + torch.zeros( + ( + batch_size, + self.rotary_embedding.end - old_rotary_embed_end, + self.n_local_kv_heads, + self.d_qk, + ), + dtype=query_states.dtype, + device=query_states.device, + ), + ], + dim=1, + ) + + v_cache = torch.cat( + [ + v_cache, + torch.zeros( + ( + batch_size, + 
self.rotary_embedding.end - old_rotary_embed_end, + self.n_local_kv_heads, + self.d_v, + ), + dtype=query_states.dtype, + device=query_states.device, + ), + ], + dim=1, + ) + + assert ( + k_cache.shape[1] == self.rotary_embedding.end + ), f"Cache size {k_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}" + assert ( + v_cache.shape[1] == self.rotary_embedding.end + ), f"Cache size {v_cache.shape[1]} is smaller than rotary embedding end {self.rotary_embedding.end}" + + # [batch_size, seq_length, num_heads, d_qk] + query_states = query_states.view( + batch_size, q_length, self.n_local_q_heads, self.d_qk + ) # [batch_size, q_length, self.n_heads, d_qk] + kv_length = key_states.shape[1] + key_states = key_states.view( + batch_size, kv_length, self.n_local_kv_heads, self.d_qk + ) # [batch_size, kv_length, self.n_heads, d_qk] + value_states = value_states.view( + batch_size, kv_length, self.n_local_kv_heads, self.d_v + ) # [batch_size, kv_length, self.n_heads, d_v] + + # NOTE: this scale is for µTransfer, + # in SP, we use sqrt(1/d_h) + softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None + attention_output = flash_attn_with_kvcache( + query_states, + k_cache, + v_cache, + key_states, + value_states, + rotary_cos=None, + rotary_sin=None, + # TODO @nouamane: seems like this doesn't help to indicate padding in (for first iteration it's just 0) + cache_seqlens=position_offsets.contiguous(), + softmax_scale=softmax_scale, + causal=True, + rotary_interleaved=False, # GPT-NeoX style + ) + + store.update( + { + "key": k_cache, # flash-attn has updated with new key_states using cache_seqlens + "value": v_cache, + "position_offsets": position_offsets, + } + ) + + else: # Training case + # apply rotary embedding. + + if self.rope_interleaved: + # interleaved version. + query_states = self.rotary_embedding(query_states, position_ids=position_ids) + key_states = self.rotary_embedding(key_states, position_ids=position_ids) + # non interleaved version. 
+ else: + cos, sin = self.rotary_embedding(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + kv_length = key_states.shape[1] + ## ring attention + if self.sp_pg.size() > 1: + key_states = key_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_qk) + value_states = value_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_v) + + attention_output = self.attention( + query_states, + key_states, + value_states, + ) + ## flash attention + else: + attention_output = flash_attn_func(query_states, key_states, value_states, causal=True) + + attention_output = ( + attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v).transpose(0, 1) + ) + output = self.o_proj(attention_output) + + return {"hidden_states": output, "sequence_mask": sequence_mask} + + +class LlamaDecoderLayer(nn.Module): + def __init__( + self, + config: LlamaConfig, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + sp_pg: dist.ProcessGroup, + layer_idx: int, + ): + super().__init__() + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attn = CausalSelfAttention( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + sp_pg=sp_pg, + layer_idx=layer_idx, + ) + self.layer_idx = layer_idx + + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + + self.recompute_layer = parallel_config.recompute_layer + + def _core_forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + sequence_mask: Union[torch.Tensor, TensorPointer], + position_ids: Optional[torch.LongTensor] = None, + ) -> List[Union[torch.Tensor, TensorPointer]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask, position_ids=position_ids) + hidden_states = output["hidden_states"] + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] + hidden_states = hidden_states + residual + + return hidden_states, output["sequence_mask"], position_ids + + def _checkpointed_forward( + self, + hidden_states: torch.Tensor, + sequence_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + ) -> List[torch.Tensor]: + return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask, position_ids) + + def forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + sequence_mask: Union[torch.Tensor, TensorPointer], + position_ids: Optional[torch.LongTensor] = None, + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + + if self.recompute_layer and not isinstance(hidden_states, TensorPointer): + hidden_states, sequence_mask, position_ids = self._checkpointed_forward( + hidden_states, sequence_mask, position_ids + ) + else: + hidden_states, sequence_mask, position_ids = self._core_forward(hidden_states, sequence_mask, position_ids) + + return { + "hidden_states": hidden_states, + "sequence_mask": sequence_mask, + "position_ids": position_ids, + } + + +class Embedding(nn.Module, AttachableStore): + def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]): + super().__init__() + self.token_embedding = 
TensorParallelEmbedding( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + padding_idx=config.pad_token_id, + pg=tp_pg, + mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, + ) + self.pg = tp_pg + + def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_size, seq_length] + store = self.get_local_store() + if store is not None: + if "past_length" in store: + past_length = store["past_length"] + else: + past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0]) + + cumsum_mask = input_mask.cumsum(-1, dtype=torch.long) + # Store new past_length in store + store["past_length"] = past_length + cumsum_mask[:, -1] + + # Format input in `[seq_length, batch_size]` to support high TP with low batch_size + input_ids = input_ids.transpose(0, 1) + input_embeds = self.token_embedding(input_ids) + return {"input_embeds": input_embeds} + + +class LlamaModel(nn.Module): + """Build pipeline graph""" + + def __init__( + self, + config: LlamaConfig, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + ): + super().__init__() + + # Declare all the nodes + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) + self.config = config + self.parallel_config = parallel_config + self.parallel_context = parallel_context + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + log_rank("Initializing LLama 3.1", logger=logger, level=logging.INFO, rank=0) + + self.token_position_embeddings = PipelineBlock( + p2p=self.p2p, + module_builder=Embedding, + module_kwargs={ + "tp_pg": parallel_context.tp_pg, + "config": config, + "parallel_config": parallel_config, + }, + module_input_keys={"input_ids", "input_mask"}, + module_output_keys={"input_embeds"}, + ) + log_rank(f"Initialize RoPE Theta = {config.rope_theta}", logger=logger, level=logging.INFO, rank=0) + self.decoder = nn.ModuleList( + [ + PipelineBlock( + p2p=self.p2p, + module_builder=LlamaDecoderLayer, + module_kwargs={ + "config": config, + "parallel_config": parallel_config, + "tp_pg": parallel_context.tp_pg, + "sp_pg": parallel_context.sp_pg, + "layer_idx": layer_idx, + }, + module_input_keys={"hidden_states", "sequence_mask", "position_ids"}, + module_output_keys={"hidden_states", "sequence_mask", "position_ids"}, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.final_layer_norm = PipelineBlock( + p2p=self.p2p, + module_builder=RMSNorm, + module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps}, + module_input_keys={"input"}, + module_output_keys={"hidden_states"}, + ) # TODO + + self.lm_head = PipelineBlock( + p2p=self.p2p, + # Understand that this means that we return sharded logits that are going to need to be gathered + module_builder=TensorParallelColumnLinear, + module_kwargs={ + "in_features": config.hidden_size, + "out_features": config.vocab_size, + "pg": parallel_context.tp_pg, + "bias": False, + # TODO @thomasw21: refactor so that we store that default in a single place. 
+ "mode": self.tp_mode, + "async_communication": tp_linear_async_communication, + }, + module_input_keys={"x"}, + module_output_keys={"logits"}, + ) + self.cast_to_fp32 = PipelineBlock( + p2p=self.p2p, + module_builder=lambda: lambda x: x.float(), + module_kwargs={}, + module_input_keys={"x"}, + module_output_keys={"output"}, + ) + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + ): + return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0] + + def forward_with_hidden_states( + self, + input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + ): + if isinstance(input_ids, torch.Tensor): + batch_size, seq_length = input_ids.shape + position_ids = torch.cumsum(input_mask, dim=-1, dtype=torch.int32) - 1 + # split input if using ring attention + if self.parallel_context.sp_pg.size() > 1: + world_size = self.parallel_context.sp_pg.size() + rank = dist.get_rank(self.parallel_context.sp_pg) + input_ids, input_mask, position_ids = zigzag_split( + rank, world_size, input_ids, input_mask, position_ids + ) + else: + position_ids = TensorPointer(input_ids.group_rank) + # all tensors are optional as most ranks don't need anything from the dataloader. + output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask) + + hidden_encoder_states = { + "hidden_states": output["input_embeds"], + "sequence_mask": input_mask, + "position_ids": position_ids, + } + for encoder_block in self.decoder: + hidden_encoder_states = encoder_block(**hidden_encoder_states) + + hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] + + sharded_logits = self.lm_head(x=hidden_states)["logits"] + + fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] + + return fp32_sharded_logits, hidden_states + + def get_block_compute_costs(self): + """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" + model_config = self.config + d_ff = model_config.intermediate_size + d_qkv = model_config.hidden_size // model_config.num_attention_heads + block_compute_costs = { + # CausalSelfAttention (qkv proj + attn out) + MLP + LlamaDecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size + + 3 * d_ff * model_config.hidden_size, + # This is the last lm_head + TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, + } + return block_compute_costs + + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): + """Get flops per second for a given model""" + world_size = self.parallel_context.world_pg.size() + try: + num_key_values_heads = self.config.num_key_value_heads + except AttributeError: + num_key_values_heads = self.config.num_attention_heads + + model_flops, hardware_flops = get_flops( + num_layers=self.config.num_hidden_layers, + hidden_size=self.config.hidden_size, + num_heads=self.config.num_attention_heads, + num_key_value_heads=num_key_values_heads, + vocab_size=self.config.vocab_size, + ffn_hidden_size=self.config.intermediate_size, + seq_len=sequence_length, + batch_size=global_batch_size, + ) + + model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12) + hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12) + return 
model_flops_per_s, hardware_flops_per_s + + +@torch.jit.script +def masked_mean(loss, label_mask, dtype): + # type: (Tensor, Tensor, torch.dtype) -> Tensor + return (loss * label_mask).sum(dtype=dtype) / label_mask.sum() + + +class Loss(nn.Module): + def __init__(self, tp_pg: dist.ProcessGroup, sp_pg: dist.ProcessGroup): + super().__init__() + self.tp_pg = tp_pg + self.sp_pg = sp_pg + + def forward( + self, + sharded_logits: torch.Tensor, # [seq_length, batch_size, logits] + label_ids: torch.Tensor, # [batch_size, seq_length] + label_mask: torch.Tensor, # [batch_size, seq_length] + ) -> Dict[str, torch.Tensor]: + # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision. + # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38 + + # ring attention: split the label as well + if isinstance(label_ids, torch.Tensor) and self.sp_pg.size() > 1: + world_size = self.sp_pg.size() + rank = dist.get_rank(self.sp_pg) + label_ids, label_mask = zigzag_split(rank, world_size, label_ids, label_mask) + loss = sharded_cross_entropy( + sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float + ).transpose(0, 1) + # TODO @thomasw21: It's unclear what kind of normalization we want to do. + loss = masked_mean(loss, label_mask, dtype=torch.float) + # I think indexing causes a sync we don't actually want + # loss = loss[label_mask].sum() + return {"loss": loss} + + +class LlamaForTraining(NanotronModel): + def __init__( + self, + config: LlamaConfig, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: Optional[RandomStates] = None, + ): + super().__init__() + self.model = LlamaModel(config=config, parallel_context=parallel_context, parallel_config=parallel_config) + self.loss = PipelineBlock( + p2p=self.model.p2p, + module_builder=Loss, + module_kwargs={"tp_pg": parallel_context.tp_pg, "sp_pg": parallel_context.sp_pg}, + module_input_keys={ + "sharded_logits", + "label_ids", + "label_mask", + }, + module_output_keys={"loss"}, + ) + self.parallel_context = parallel_context + self.config = config + self.parallel_config = parallel_config + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], + input_mask: Union[torch.Tensor, TensorPointer], + label_ids: Union[torch.Tensor, TensorPointer], + label_mask: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + sharded_logits = self.model( + input_ids=input_ids, + input_mask=input_mask, + ) + loss = self.loss( + sharded_logits=sharded_logits, + label_ids=label_ids, + label_mask=label_mask, + )["loss"] + return {"loss": loss} + + @torch.no_grad() + def init_model_randomly(self, config: Config): + """Initialize model parameters randomly. 
+ Note: + Layernorm weight all 0 or 1 depending on `apply_layernorm_1p` + """ + init_method = config.model.init_method + if isinstance(init_method, RandomInit): + parametrizator_cls = StandardParametrizator + elif isinstance(init_method, SpectralMupInit): + parametrizator_cls = SpectralMupParametrizator + else: + raise ValueError(f"Unknown init method {init_method}") + + parametrizator = parametrizator_cls(config=config.model) + + log_rank( + f"Parametrizing model parameters using {parametrizator.__class__.__name__}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + model = self + initialized_parameters = set() + # Handle tensor parallelism + module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()} + # Fix the root_model + module_id_to_prefix[id(model)] = "" + + for param_name, param in model.named_parameters(): + assert isinstance(param, NanotronParameter) + + module_name, param_name = param_name.rsplit(".", 1) + + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + module = model.get_submodule(module_name) + parametrizator.parametrize(param_name, module) + + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + + assert initialized_parameters == { + param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + if param.is_tied + else name + for name, param in model.named_parameters() + }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" + + def get_embeddings_lm_head_tied_names(self): + """Get the names of the tied embeddings and lm_head weights""" + if self.config.tie_word_embeddings is True: + return ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"] + else: + return [] + + def get_block_compute_costs(self): + """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" + return self.model.get_block_compute_costs() + + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): + """Get flops per second for a given model""" + return self.model.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size) + + +def get_flops( + num_layers, + hidden_size, + num_heads, + num_key_value_heads, + vocab_size, + seq_len, + ffn_hidden_size, + batch_size=1, +): + """Counts flops in an decoder-only model + Args: + num_layers: number of decoder layers + hidden_size: hidden size of the model + num_heads: number of heads in the model + num_key_value_heads: number of key/value heads in the model + ffn_hidden_size: hidden size of the FFN + vocab_size: size of the vocabulary + seq_len: sequence length of the decoder + batch_size: batch size + Returns: + model_flops: flops in the model (should be independent of the hardware and model implementation) + hardware_flops: flops in the hardware (actual flops performed on the hardware). 
Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf + """ + if num_key_value_heads is None: + num_key_value_heads = num_heads + hidden_size_per_head = hidden_size // num_heads + # In the following we mark the reduced dimension with parentheses + # decoder + # self attention + ## qkv projection + decoder_qkv_proj_flops_fwd = ( + 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * hidden_size_per_head + + 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * num_key_value_heads * hidden_size_per_head + ) + ## qk logits + decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * seq_len + ## v logits + decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * hidden_size_per_head + ## attn out + decoder_attn_out_flops_fwd = ( + 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * hidden_size + ) + # FF + ## 1st layer + decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size + ## 2nd layer + decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size + + decoder_flops_fwd = ( + decoder_qkv_proj_flops_fwd + + decoder_qk_logits_flops_fwd + + decoder_v_logits_flops_fwd + + decoder_attn_out_flops_fwd + + decoder_ffn_1_flops_fwd + + decoder_ffn_2_flops_fwd + ) + + # lm head + lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size + + # the bwd pass requires double the flops in case of matmuls to calculate the gradients with respect to + # both input and weight tensors + model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) # 1 for fwd + 2 for bwd + + hardware_flops = model_flops # TODO: This is a placeholder for now + + return model_flops, hardware_flops diff --git a/src/nanotron/nn/layer_norm.py b/src/nanotron/nn/layer_norm.py index 688eaa78..2db78f49 100644 --- a/src/nanotron/nn/layer_norm.py +++ b/src/nanotron/nn/layer_norm.py @@ -51,3 +51,21 @@ def forward( is_rms_norm=True, return_dropout_mask=return_dropout_mask, ) + + +# equivalent to TritonRMSNorm +class RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-5): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, input): + input_dtype = input.dtype + input = input.to(torch.float32) + variance = input.pow(2).mean(-1, keepdim=True) + input = input * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * input.to(input_dtype) diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py index e6241651..5ad8f769 100644 --- a/src/nanotron/scaling/parametrization.py +++ b/src/nanotron/scaling/parametrization.py @@ -4,7 +4,7 @@ from typing import Dict from nanotron.config import ModelArgs -from nanotron.nn.layer_norm import TritonRMSNorm +from nanotron.nn.layer_norm import RMSNorm, TritonRMSNorm from nanotron.parallel.tensor_parallel.nn import ( TensorParallelColumnLinear, TensorParallelEmbedding, @@ -37,6 +37,7 @@ def __init__(self, config: ModelArgs): TensorParallelColumnLinear: self._parametrize_column_linear, TensorParallelRowLinear: self._parametrize_row_linear, TritonRMSNorm: self._parametrize_layer_norm, + RMSNorm: self._parametrize_layer_norm, TensorParallelEmbedding: self._parametrize_embedding, } @@ -88,6 +89,7 @@ def __init__(self, config: ModelArgs): TensorParallelColumnLinear: self._parametrize_mup_weight, 
TensorParallelRowLinear: self._parametrize_mup_weight, TritonRMSNorm: self._parametrize_layer_norm, + RMSNorm: self._parametrize_layer_norm, TensorParallelEmbedding: self._parametrize_embedding, } self.std = 1.0 diff --git a/tests/test_llama_generation.py b/tests/test_llama_generation.py new file mode 100644 index 00000000..065af773 --- /dev/null +++ b/tests/test_llama_generation.py @@ -0,0 +1,189 @@ +""" +Nanotron Inference Script + +Usage: +CUDA_LAUNCH_BLOCKING=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=1 --nnodes=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:29600 --max_restarts=0 --tee=3 tests/test_llama_generation.py --ckpt-path /fsx/haojun/lighteval_evaluation_model/Llama-3-8B-split + +""" + +import argparse +from pathlib import Path + +import torch +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import ( + LoggingArgs, + ParallelismArgs, + get_config_from_file, +) +from nanotron.logging import log_rank, set_ranks_logging_level +from nanotron.models import build_model +from nanotron.models.llama3_1 import LlamaForTraining as LlamaForTraining_test +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.parallel.pipeline_parallel.engine import ( + OneForwardOneBackwardPipelineEngine, +) +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.random import ( + RandomStates, + get_current_random_state, + get_synced_random_state, + set_random_seed, +) +from nanotron.serialize import load_weights +from nanotron.trainer import mark_tied_parameters +from transformers import AutoModelForCausalLM, AutoTokenizer + +logger = logging.get_logger(__name__) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt-path", type=Path, required=True, help="Checkpoint path") + parser.add_argument("--max-new-tokens", type=int, default=100, help="Maximum number of new tokens to generate") + return parser.parse_args() + + +def main(): + args = get_args() + + assert args.ckpt_path.exists(), f"Checkpoint path {args.ckpt_path} does not exist" + + config = get_config_from_file((args.ckpt_path / "config.yaml").as_posix()) + model_config = config.model.model_config + tokenizer_path = config.tokenizer.tokenizer_name_or_path + + # as tp/pp/sp will introduce small differences in the output, we need to set them to 1 + parallel_config = ParallelismArgs( + dp=1, + pp=1, + tp=1, + sp=1, + pp_engine=OneForwardOneBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + + # Initialise all process groups + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + sequence_parallel_size=parallel_config.sp, + ) + + # Set log levels + logging_config = LoggingArgs( + log_level="info", + log_level_replica="info", + ) + + # Set log levels + set_ranks_logging_level(parallel_context=parallel_context, logging_config=logging_config) + + log_rank(f"model_config: {model_config}", logger=logger, level=logging.INFO, rank=0) + log_rank(f"tokenizer_path: {tokenizer_path}", logger=logger, level=logging.INFO, rank=0) + + dtype = torch.bfloat16 + + # Set random states + set_random_seed(42) + + # Get synchronized random states + if parallel_config.tp_mode is TensorParallelLinearMode.ALL_REDUCE: + random_states = RandomStates( + {"tp_synced": 
get_synced_random_state(random_state=get_current_random_state(), pg=parallel_context.tp_pg)} + ) + else: + # We don't need to sync across TP when using sequence parallel (REDUCE_SCATTER) + random_states = RandomStates({}) + + model = build_model( + model_builder=lambda: LlamaForTraining_test( + config=model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=random_states, + ), + dtype=dtype, + parallel_context=parallel_context, + ) + + # Mark some parameters as tied + # TODO @nouamane: this is only needed for training, can we just mark params as NanotronParameter instead? + mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config) + + # Sanity check model + sanity_check(root_module=model) + + # Load checkpoint + checkpoint_path = args.ckpt_path + log_rank( + f"Loading checkpoint from {checkpoint_path}:", + logger=logger, + level=logging.INFO, + rank=0, + ) + load_weights(model=model, parallel_context=parallel_context, root_folder=checkpoint_path) + + ## Tokenizer + pretrained_model_name_or_path = "meta-llama/Meta-Llama-3-8B" + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) + tokenizer.pad_token = tokenizer.eos_token + + ## Transformers model + attn_implementation = "flash_attention_2" # 'sdpa' / 'flash_attention_2' + transformer_model = AutoModelForCausalLM.from_pretrained( + pretrained_model_name_or_path, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation + ).to("cuda") + + model.eval() + transformer_model.eval() + + # Input prompt + input_text = "The future of AI is" + inputs = tokenizer(input_text, return_tensors="pt").to("cuda") + + # TODO: also compare the generation with cache. + for i in range(args.max_new_tokens): + output_logits = transformer_model(**inputs)[ + "logits" + ] # output logits of transformer model is still in float32 even dtype is bfloat16 + my_output_logits = model.model(inputs["input_ids"], inputs["attention_mask"]) + next_token_id = torch.argmax(output_logits[:, -1, :], dim=-1) + my_next_token_id = torch.argmax(my_output_logits[-1, :, :], dim=-1) + try: + # test logits and generation on the same time. + torch.testing.assert_close( + my_output_logits[:, 0, :], output_logits[0, :, :], rtol=1e-5, atol=1e-5 + ) # check if the output logits are close + assert torch.equal( + my_output_logits[:, 0, :], output_logits[0, :, :] + ), "Output logits are not the same" # check if the output logits are the same + except AssertionError as e: + print(f"Token {i+1} failed: {e}") + print("Reference: ", output_logits) + print("My output: ", my_output_logits) + + assert ( + next_token_id == my_next_token_id + ), f"Predictions are not the same: {next_token_id} != {my_next_token_id}" + inputs["input_ids"] = torch.cat([inputs["input_ids"], next_token_id.unsqueeze(-1)], dim=-1) + inputs["attention_mask"] = torch.cat( + [inputs["attention_mask"], torch.ones(1, 1).to(dtype=torch.bool, device="cuda")], dim=-1 + ) + if next_token_id == tokenizer.eos_token_id: + break + print("Input prompt:", input_text) + print( + "Generated text:", tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)[len(input_text) :] + ) # remove the input text from the generated text + print("Test passed!") + + dist.barrier() + + +if __name__ == "__main__": + main()
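
For reference, the core of this change is the switch from the interleaved Flash RoPE to the transformers-style non-interleaved RoPE, with cos/sin built in float32 from a CPU-initialized inv_freq. The snippet below is a minimal, standalone sketch of that path, not part of the patch: it re-declares rotate_half / apply_rotary_pos_emb as they appear in llama.py above and builds cos/sin the way LlamaRotaryEmbedding.forward does, but with toy shapes and on CPU so it runs without nanotron or flash-attn. The helper name build_cos_sin and the dimensions are illustrative assumptions.

# Minimal sketch (not part of the patch): non-interleaved RoPE in plain PyTorch.
import torch


def rotate_half(x):
    # Split the head dimension into two halves and swap them with a sign flip.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
    # unsqueeze_dim=2 because q/k are laid out as [batch, seq, heads, head_dim].
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


def build_cos_sin(position_ids, head_dim, theta=500000.0):
    # inv_freq is computed in float32 on CPU, mirroring LlamaRotaryEmbedding.
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.float) / head_dim))
    freqs = torch.outer(position_ids.float(), inv_freq)  # [seq, head_dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)              # [seq, head_dim]
    return emb.cos()[None], emb.sin()[None]              # [1, seq, head_dim]


if __name__ == "__main__":
    batch, seq, heads, head_dim = 2, 8, 4, 16  # toy shapes
    q = torch.randn(batch, seq, heads, head_dim)
    k = torch.randn(batch, seq, heads, head_dim)
    cos, sin = build_cos_sin(torch.arange(seq), head_dim)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    # RoPE applies per-pair 2D rotations, so the per-head norm is preserved.
    print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))

In the patch, unsqueeze_dim defaults to 2 because nanotron keeps query/key states in [batch_size, seq_length, heads, head_dim] layout, whereas the transformers code this was copied from uses [batch_size, heads, seq_length, head_dim] with unsqueeze_dim=1.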