exact output logits match for LLaMA-3
zzhhjjj committed Aug 14, 2024
1 parent 03d67f2 commit f40fa05
Showing 5 changed files with 1,342 additions and 10 deletions.
98 changes: 89 additions & 9 deletions src/nanotron/models/llama.py
@@ -14,7 +14,7 @@
# limitations under the License.
"""PyTorch LLaMa model."""

from typing import Dict, Optional, Union, List
from typing import Dict, List, Optional, Union

import torch
from torch import nn
@@ -74,9 +74,10 @@ def init_rotary_embeddings(self):
self.freqs_cis = self.freqs_cis.to(torch.float)
assert self.freqs_cis.dtype == torch.float
freqs = 1.0 / (
self.theta
** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda")[: (self.dim // 2)] / self.dim)
)
self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu")[: (self.dim // 2)] / self.dim)
).to(
"cuda"
) # must be computed on CPU, otherwise the results differ slightly from Transformers.
t = torch.arange(self.end, device="cuda")
freqs = torch.outer(t, freqs).float()
complex_freqs = torch.polar(torch.ones_like(freqs), freqs)
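For context on the comment above: torch.arange followed by a floating-point pow can round differently in CPU and CUDA kernels, and Transformers builds its inverse frequencies on CPU, so computing them on GPU can break an exact logits match. A standalone sketch of that comparison, with head_dim 128 and the file's default theta of 500000.0 assumed for illustration:

import torch

dim, theta = 128, 500000.0  # illustrative head_dim; theta matches the default used in this file
exponent = torch.arange(0, dim, 2, dtype=torch.float, device="cpu")[: (dim // 2)] / dim
freqs_cpu = 1.0 / (theta ** exponent)  # computed on CPU, as in this commit
if torch.cuda.is_available():
    freqs_gpu = 1.0 / (theta ** exponent.to("cuda"))  # computed on GPU, as before this commit
    # May print False: the CPU and CUDA pow kernels can differ in the last few bits,
    # which is enough to change the output logits slightly.
    print(torch.equal(freqs_cpu.to("cuda"), freqs_gpu))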
@@ -118,6 +119,84 @@ def forward(
return x_out.type(dtype)


## Copied from transformers. Non-interleaved version of RoPE; will be refactored later.
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
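A tiny worked example (illustrative, not part of the commit, reusing the torch import above): on a 4-dimensional vector, rotate_half maps [x1, x2, x3, x4] to [-x3, -x4, x1, x2].

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
rotate_half(x)  # tensor([-3., -4.,  1.,  2.])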


class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim: int, end: int, theta: float = 500000.0):
super().__init__()
self.dim = dim
self.end = end
self.theta = theta
self.init_rotary_embeddings()

def init_rotary_embeddings(self):
inv_freq = 1.0 / (
self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu") / self.dim)
) # important to compute on CPU
# inv_freq = apply_scaling(inv_freq) # if LLaMA 3.1
self.register_buffer(
"inv_freq", torch.empty(self.dim // 2, dtype=torch.float, device="cuda"), persistent=False
)
self.inv_freq = self.inv_freq.to(
torch.float
) # make it float32 before copy to avoid precision loss during copy_
self.inv_freq.copy_(inv_freq)

saved_inv_freq = torch.load("/fsx/haojun/LLaMA/.cache/activation_values/inv_freq.pt")
assert torch.equal(self.inv_freq.cpu(), saved_inv_freq), "inv_freq mismatch."

@torch.no_grad()
def forward(
self,
x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk]
position_ids: Optional[torch.LongTensor], # [batch_size, seq_length]
):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
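The commented-out apply_scaling call in init_rotary_embeddings refers to the LLaMA 3.1 RoPE frequency scaling, which this commit leaves out. Below is a sketch of that function with the constants used in Meta's LLaMA 3.1 reference code (scale factor 8, low/high frequency factors 1 and 4, original context length 8192); treat it as an illustration under those assumptions, not part of this diff.

import math
import torch

def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
    # Constants as used in Meta's LLaMA 3.1 reference implementation (assumed here).
    scale_factor = 8.0
    low_freq_factor = 1.0
    high_freq_factor = 4.0
    old_context_len = 8192  # original LLaMA 3 context length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs.tolist():
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)  # high-frequency components are kept as-is
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scale_factor)  # low-frequency components are scaled down
        else:
            # Smooth interpolation between the two regimes.
            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)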


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
unsqueeze_dim (`int`, *optional*, defaults to 2):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
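A minimal usage sketch tying the two pieces together under nanotron's [batch, seq_len, heads, head_dim] layout (hence unsqueeze_dim=2). Sizes are illustrative, a CUDA device is assumed, and the hard-coded inv_freq debug assert in init_rotary_embeddings would need its reference file present (or the check commented out) for this to run.

batch, seq_len, n_heads, head_dim = 2, 16, 8, 64  # illustrative sizes, not from the diff
q = torch.randn(batch, seq_len, n_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seq_len, n_heads, head_dim, device="cuda", dtype=torch.bfloat16)
position_ids = torch.arange(seq_len, device="cuda").unsqueeze(0).expand(batch, -1)

rotary = LlamaRotaryEmbedding(dim=head_dim, end=4096)  # end = maximum number of positions
cos, sin = rotary(q, position_ids)  # each [batch, seq_len, head_dim], cast to q.dtype
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2)
assert q_rot.shape == q.shape and k_rot.shape == k.shape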


class GLUActivation(nn.Module):
def __init__(self, act_fn_name: str):
super().__init__()
@@ -620,9 +699,9 @@ def __init__(

self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)

self.recompute_layer = parallel_config.recompute_layer

def _core_forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
@@ -641,20 +720,20 @@ def _core_forward(
hidden_states = hidden_states + residual

return hidden_states, output["sequence_mask"]

def _checkpointed_forward(
self,
hidden_states: torch.Tensor,
sequence_mask: torch.Tensor,
) -> List[torch.Tensor]:
return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask)

def forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
sequence_mask: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:

if self.recompute_layer and not isinstance(hidden_states, TensorPointer):
hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask)
else:
@@ -665,6 +744,7 @@ def forward(
"sequence_mask": sequence_mask,
}
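The recompute_layer flag routes the layer body through CheckpointFunction so activations are recomputed during backward instead of being kept. Conceptually this is the same trade-off as PyTorch's built-in activation checkpointing; a minimal sketch using torch.utils.checkpoint (an illustration of the idea, not the project's CheckpointFunction):

import torch
from torch.utils.checkpoint import checkpoint

class ToyDecoderLayer(torch.nn.Module):  # hypothetical stand-in, not nanotron's decoder layer
    def __init__(self, hidden_size: int, recompute_layer: bool = False):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, 4 * hidden_size),
            torch.nn.GELU(),
            torch.nn.Linear(4 * hidden_size, hidden_size),
        )
        self.recompute_layer = recompute_layer

    def _core_forward(self, hidden_states):
        # The intermediate activations of this block are what recomputation avoids storing.
        return hidden_states + self.mlp(hidden_states)

    def forward(self, hidden_states):
        if self.recompute_layer and hidden_states.requires_grad:
            # Re-run _core_forward during backward instead of keeping its activations.
            return checkpoint(self._core_forward, hidden_states, use_reentrant=False)
        return self._core_forward(hidden_states)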


class Embedding(nn.Module, AttachableStore):
def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]):
super().__init__()