
Commit

Merge pull request #214 from zzhhjjj/match-transformers
Match Transformers RoPE implementation
3outeille authored Sep 5, 2024
2 parents 3be44ef + 2677380 commit 1456446
Showing 3 changed files with 123 additions and 19 deletions.
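Editor's note: the core of the change is where the RoPE inverse frequencies are computed. The patched init_rotary_embeddings builds them in float32 on the CPU and only then moves them to the GPU, so nanotron's frequencies match the ones Transformers produces. A minimal standalone sketch of that idea (the dim and theta values below are illustrative, not taken from the diff):

    import torch

    dim, theta = 128, 10000.0  # illustrative head dimension and RoPE base

    # Computed on CPU in float32, as the patched init_rotary_embeddings does,
    # then moved to the GPU.
    inv_freq_cpu = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float, device="cpu") / dim))

    if torch.cuda.is_available():
        # The same expression evaluated directly on CUDA can differ in the last bits,
        # which is the mismatch the comment in the diff refers to.
        inv_freq_cuda = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float, device="cuda") / dim))
        print((inv_freq_cpu.cuda() - inv_freq_cuda).abs().max())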
3 changes: 3 additions & 0 deletions src/nanotron/config/models_config.py
@@ -48,6 +48,9 @@ class LlamaConfig:
rms_norm_eps: float = 1e-6
rope_scaling: Optional[dict] = None
rope_theta: float = 10000.0
rope_interleaved: bool = (
False # The default value has been True, but for loading Llama3 checkpoints you have to set it to False
)
tie_word_embeddings: bool = False
use_cache: bool = True
vocab_size: int = 32000
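Editor's note: for context, a sketch of how the new flag sits among the other RoPE options. This is a stand-in dataclass mirroring only the fields visible in this hunk (the real LlamaConfig has many more fields), with illustrative values:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class RopeOptions:  # stand-in for the RoPE-related slice of LlamaConfig
        rms_norm_eps: float = 1e-6
        rope_scaling: Optional[dict] = None
        rope_theta: float = 10000.0
        rope_interleaved: bool = False  # False = Transformers-style (non-interleaved) RoPE

    # e.g. Llama3-style settings: larger base frequency, non-interleaved rotation
    opts = RopeOptions(rope_theta=500000.0, rope_interleaved=False)
    print(opts)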
137 changes: 118 additions & 19 deletions src/nanotron/models/llama.py
@@ -14,7 +14,7 @@
# limitations under the License.
"""PyTorch LLaMa model."""

- from typing import Dict, Optional, Union, List
+ from typing import Dict, List, Optional, Union

import torch
from torch import nn
@@ -74,9 +74,10 @@ def init_rotary_embeddings(self):
self.freqs_cis = self.freqs_cis.to(torch.float)
assert self.freqs_cis.dtype == torch.float
freqs = 1.0 / (
-     self.theta
-     ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cuda")[: (self.dim // 2)] / self.dim)
- )
+     self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu")[: (self.dim // 2)] / self.dim)
+ ).to(
+     "cuda"
+ )  # should be computed on CPU, otherwise the results differ from Transformers
t = torch.arange(self.end, device="cuda")
freqs = torch.outer(t, freqs).float()
complex_freqs = torch.polar(torch.ones_like(freqs), freqs)
@@ -118,6 +119,78 @@ def forward(
return x_out.type(dtype)


# Copied from transformers. Non-interleaved version of RoPE; will be refactored later.
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim: int, end: int, theta: float = 500000.0):
super().__init__()
self.dim = dim
self.end = end
self.theta = theta
self.init_rotary_embeddings()

def init_rotary_embeddings(self):
inv_freq = 1.0 / (
self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device="cpu") / self.dim)
) # important to compute on CPU
self.register_buffer(
"inv_freq", torch.empty(self.dim // 2, dtype=torch.float, device="cuda"), persistent=False
)
self.inv_freq = self.inv_freq.to(
torch.float
) # make it float32 before copy to avoid precision loss during copy_
self.inv_freq.copy_(inv_freq)

@torch.no_grad()
def forward(
self,
x: torch.Tensor, # [batch_size, seq_length, num_heads, d_qk]
position_ids: Optional[torch.LongTensor], # [batch_size, seq_length]
):
# x: [batch_size, seq_length, num_heads, d_qk]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

def rotate_half(self, x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=2):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
unsqueeze_dim (`int`, *optional*, defaults to 2):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (self.rotate_half(q) * sin)
k_embed = (k * cos) + (self.rotate_half(k) * sin)
return q_embed, k_embed


class GLUActivation(nn.Module):
def __init__(self, act_fn_name: str):
super().__init__()
@@ -319,14 +392,24 @@ def __init__(
tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
# TODO(kunhao): We want to have only one version per device and not one version per layer.
- self.rotary_embedding = RotaryEmbedding(
-     dim=self.d_qk,
-     end=config.max_position_embeddings,
-     theta=config.rope_theta,
- )
+ if config.rope_interleaved:
+     self.rotary_embedding = RotaryEmbedding(
+         dim=self.d_qk,
+         end=config.max_position_embeddings,
+         theta=config.rope_theta,
+     )
+ else:
+     self.rotary_embedding = LlamaRotaryEmbedding(
+         dim=self.d_qk,
+         end=config.max_position_embeddings,
+         theta=config.rope_theta,
+     )
+ self.rope_interleaved = config.rope_interleaved

# NOTE: Only supported for training (TODO(fmom): position_ids not supported yet)
- self.flash_rotary_embedding = FlashRotaryEmbedding(dim=self.d_qk, base=config.rope_theta, interleaved=True)
+ self.flash_rotary_embedding = FlashRotaryEmbedding(
+     dim=self.d_qk, base=config.rope_theta, interleaved=config.rope_interleaved
+ )

self.o_proj = TensorParallelRowLinear(
config.num_attention_heads * self.d_qk,
@@ -405,8 +488,16 @@ def forward(
# Compute rotary embeddings
# Note: keep track of old rotary embedding end to check if we need to enlarge k_cache and v_cache
old_rotary_embed_end = self.rotary_embedding.end
- query_states = self.rotary_embedding(query_states, position_ids=position_ids)
- key_states = self.rotary_embedding(key_states, position_ids=position_ids)
+ # Interleaved version
+ if self.rope_interleaved:
+     query_states = self.rotary_embedding(query_states, position_ids=position_ids)
+     key_states = self.rotary_embedding(key_states, position_ids=position_ids)
+ # Non-interleaved version
+ else:
+     cos, sin = self.rotary_embedding(value_states, position_ids)
+     query_states, key_states = self.rotary_embedding.apply_rotary_pos_emb(
+         query_states, key_states, cos, sin
+     )

if "key" not in store:
# First inference iteration (Prefill)
@@ -545,7 +636,7 @@ def forward(
cache_seqlens=position_offsets.contiguous(),
softmax_scale=softmax_scale,
causal=True,
- rotary_interleaved=False, # GPT-NeoX style
+ rotary_interleaved=False, # the value is not used unless rotary_cos/sin is provided. https://github.com/Dao-AILab/flash-attention
)

store.update(
@@ -620,9 +711,9 @@ def __init__(

self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)

self.recompute_layer = parallel_config.recompute_layer

def _core_forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
@@ -641,20 +732,20 @@ def _core_forward(
hidden_states = hidden_states + residual

return hidden_states, output["sequence_mask"]

def _checkpointed_forward(
self,
hidden_states: torch.Tensor,
sequence_mask: torch.Tensor,
) -> List[torch.Tensor]:
return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask)

def forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
sequence_mask: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:

if self.recompute_layer and not isinstance(hidden_states, TensorPointer):
hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask)
else:
@@ -665,6 +756,7 @@ def forward(
"sequence_mask": sequence_mask,
}


class Embedding(nn.Module, AttachableStore):
def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]):
super().__init__()
@@ -727,7 +819,14 @@ def __init__(
module_input_keys={"input_ids", "input_mask"},
module_output_keys={"input_embeds"},
)

log_rank(f"Initialize RoPE Theta = {config.rope_theta}", logger=logger, level=logging.INFO, rank=0)
if config.rope_interleaved:
log_rank(
"The RoPE interleaved version differs from the Transformers implementation. It's better to set rope_interleaved=False if you need to convert the weights to Transformers",
logger=logger,
level=logging.INFO,
rank=0,
)
self.decoder = nn.ModuleList(
[
PipelineBlock(
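Editor's note: to make the interleaved / non-interleaved distinction concrete, the original nanotron RotaryEmbedding rotates adjacent pairs (x0, x1), (x2, x3), ... as complex numbers, while the LlamaRotaryEmbedding added above follows the Transformers convention and rotates the first half of the head dimension against the second half via rotate_half. Applied to the same unpermuted weights, the two conventions give different outputs, which is why the flag has to match the checkpoint you load or export. A self-contained sketch in plain PyTorch (no nanotron imports, shapes chosen for illustration):

    import torch

    torch.manual_seed(0)
    seq, dim, theta = 4, 8, 10000.0

    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))  # [dim/2]
    freqs = torch.outer(torch.arange(seq, dtype=torch.float), inv_freq)             # [seq, dim/2]
    x = torch.randn(seq, dim)

    # Non-interleaved (Transformers / LlamaRotaryEmbedding above): the first half
    # of the head dim is rotated against the second half.
    emb = torch.cat((freqs, freqs), dim=-1)
    def rotate_half(t):
        return torch.cat((-t[..., t.shape[-1] // 2 :], t[..., : t.shape[-1] // 2]), dim=-1)
    out_half = x * emb.cos() + rotate_half(x) * emb.sin()

    # Interleaved (the original nanotron RotaryEmbedding): adjacent pairs
    # (x0, x1), (x2, x3), ... are treated as complex numbers and rotated.
    xc = torch.view_as_complex(x.reshape(seq, dim // 2, 2))
    out_pairs = torch.view_as_real(xc * torch.polar(torch.ones_like(freqs), freqs)).reshape(seq, dim)

    # Same input, different pairings -> different results unless the weights are
    # permuted to match the convention.
    print(torch.allclose(out_half, out_pairs))  # False in general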
2 changes: 2 additions & 0 deletions src/nanotron/nn/layer_norm.py
@@ -22,6 +22,8 @@ def forward(
)


# This is equivalent to LLaMA RMSNorm
# https://github.com/huggingface/transformers/blob/28952248b19db29ca25ccf34a5eec413376494a9/src/transformers/models/llama/modeling_llama.py#L112
class TritonRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
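Editor's note: the new comment points at the equivalence with the Transformers LlamaRMSNorm. For reference, a minimal pure-PyTorch sketch of that normalization (a reference implementation for comparison, not the Triton kernel itself):

    import torch
    from torch import nn

    class RMSNormReference(nn.Module):
        """y = weight * x / sqrt(mean(x^2) + eps), with statistics computed in float32."""

        def __init__(self, hidden_size: int, eps: float = 1e-6):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            input_dtype = x.dtype
            x = x.float()
            x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
            return self.weight * x.to(input_dtype)

    # e.g. normalizing a [batch, seq, hidden] activation
    print(RMSNormReference(16)(torch.randn(2, 5, 16)).shape)  # torch.Size([2, 5, 16])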
