diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py
index a3fbeb4b..2cbefc5f 100644
--- a/torchtrain/models/llama/model.py
+++ b/torchtrain/models/llama/model.py
@@ -278,6 +278,40 @@ def forward(self, x):
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
 
+class RotaryEmbedding(nn.Module):
+    def __init__(self, params: ModelArgs):
+        """
+        Initialize the embedding module.
+        """
+        super().__init__()
+        self.params = params
+        self.tok_embeddings = nn.Embedding(
+            params.vocab_size, params.dim
+        )
+
+        self.freqs_cis = precompute_freqs_cis(
+            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
+            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
+            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
+        )
+
+    def forward(self, tokens: torch.Tensor):
+        """
+        Perform a forward pass through the embedding module.
+
+        Args:
+            tokens (torch.Tensor): Input tensor.
+
+        Returns:
+            Tuple(torch.Tensor, torch.Tensor): Embedded input tensor and freqs_cis
+        """
+        _bsz, seqlen = tokens.shape
+        h = self.tok_embeddings(tokens)
+        self.freqs_cis = self.freqs_cis.to(h.device)
+        freqs_cis = self.freqs_cis[0 : seqlen]
+        return h, freqs_cis
+
+
 class TransformerBlock(nn.Module):
     def __init__(self, layer_id: int, args: ModelArgs):
         """
@@ -360,9 +394,7 @@ def __init__(self, params: ModelArgs):
         self.vocab_size = params.vocab_size
         self.n_layers = params.n_layers
 
-        self.tok_embeddings = nn.Embedding(
-            params.vocab_size, params.dim
-        )
+        self.embeddings = RotaryEmbedding(params)
 
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
@@ -373,12 +405,6 @@ def __init__(self, params: ModelArgs):
             params.dim, params.vocab_size, bias=False
         )
 
-        self.freqs_cis = precompute_freqs_cis(
-            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096.
-            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
-            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
-        )
-
     def forward(self, tokens: torch.Tensor):
         """
         Perform a forward pass through the Transformer model.
@@ -390,10 +416,7 @@ def forward(self, tokens: torch.Tensor):
             torch.Tensor: Output logits after applying the Transformer model.
 
         """
-        _bsz, seqlen = tokens.shape
-        h = self.tok_embeddings(tokens)
-        self.freqs_cis = self.freqs_cis.to(h.device)
-        freqs_cis = self.freqs_cis[0 : seqlen]
+        h, freqs_cis = self.embeddings(tokens)
 
         for layer in self.layers:
             h = layer(h, freqs_cis)
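
The refactor moves the token embedding and the precomputed rotary frequency table into one `RotaryEmbedding` module, so `Transformer.forward` now gets both `h` and `freqs_cis` from a single call. Below is a minimal, self-contained sketch of that interface for readers following along outside the repo; the `ModelArgs` fields, their default values, and the simplified `precompute_freqs_cis` helper are stand-ins for the real definitions in `torchtrain/models/llama/model.py`, not the exact code from the patch.

```python
# Minimal sketch of the extracted RotaryEmbedding interface (assumptions noted above).
from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class ModelArgs:
    dim: int = 256
    n_heads: int = 8
    vocab_size: int = 1000
    max_seq_len: int = 128


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Llama-style rotary frequency table: complex tensor of shape (end, dim // 2).
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    return torch.polar(torch.ones_like(freqs), freqs)


class RotaryEmbedding(nn.Module):
    """Token embedding plus the rotary frequency table, mirroring the module in the diff."""

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        # max_seq_len * 2, as in the patch, to leave headroom for longer sequences.
        self.freqs_cis = precompute_freqs_cis(
            params.dim // params.n_heads, params.max_seq_len * 2
        )

    def forward(self, tokens: torch.Tensor):
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        self.freqs_cis = self.freqs_cis.to(h.device)
        # Return embeddings plus the frequency slice the transformer blocks consume.
        return h, self.freqs_cis[0:seqlen]


if __name__ == "__main__":
    args = ModelArgs()
    embeddings = RotaryEmbedding(args)
    tokens = torch.randint(0, args.vocab_size, (2, 16))
    h, freqs_cis = embeddings(tokens)
    print(h.shape)          # torch.Size([2, 16, 256])
    print(freqs_cis.shape)  # torch.Size([16, 16])  -> (seqlen, head_dim // 2)
```

Bundling the embedding and `freqs_cis` into one `nn.Module` keeps the non-parametric frequency buffer next to the layer that produces the activations it is sliced against, which is what lets `Transformer.__init__` and `Transformer.forward` shrink in the hunks above.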