Skip to content

Commit

Permalink
[Fix] Update InternLM2 apply_rotary_pos_emb (#383)
Browse files Browse the repository at this point in the history
update
  • Loading branch information
LZHgrla authored Feb 1, 2024
1 parent 0197953 commit 58537c3
Showing 1 changed file with 5 additions and 14 deletions.
19 changes: 5 additions & 14 deletions xtuner/model/modules/dispatch/internlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,11 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
cos = cos.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
sin = sin.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
# print(q.shape, cos.shape, rotate_half(q).shape)
if q.size(2) == 1:
q_embed = (q * cos[:, :, -1, :]) + (rotate_half(q) * sin[:, :, -1, :])
else:
q_embed = (q * cos) + (rotate_half(q) * sin)

if k.size(2) == 1:
k_embed = (k * cos[:, :, -1, :]) + (rotate_half(k) * sin[:, :, -1, :])
else:
k_embed = (k * cos) + (rotate_half(k) * sin)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Apply rotary position embedding to the query and key tensors.

    ``cos`` and ``sin`` are first indexed by ``position_ids`` and then given
    an extra axis at ``unsqueeze_dim``. Each of ``q`` and ``k`` is then
    combined as ``t * cos + rotate_half(t) * sin``, where ``rotate_half`` is
    the helper defined earlier in this module.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table, indexed along its first axis by ``position_ids``.
        sin: Sine table, indexed the same way as ``cos``.
        position_ids: Indices selecting the rows of ``cos`` / ``sin`` to use.
        unsqueeze_dim: Axis at which a singleton dimension is inserted into
            the gathered ``cos`` / ``sin``. Defaults to 1.
            NOTE(review): presumably this is the heads axis so the tables
            broadcast against ``q`` / ``k`` - confirm against the caller.

    Returns:
        A ``(q_embed, k_embed)`` tuple holding the rotated query and key.
    """
    # Gather the per-position tables once; both q and k reuse them.
    cos_at_pos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_at_pos = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same formula for query and key: x * cos + rotate_half(x) * sin.
        return (states * cos_at_pos) + (rotate_half(states) * sin_at_pos)

    return _rotate(q), _rotate(k)


Expand Down

0 comments on commit 58537c3

Please sign in to comment.