-
Notifications
You must be signed in to change notification settings - Fork 3
/
rope_self_attn.py
154 lines (128 loc) · 6.29 KB
/
rope_self_attn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""
This code was originally obtained from:
https://github.com/meta-llama/codellama/blob/main/llama/model.py
"""
import torch
import torch.nn as nn
import math
from functools import partial
def init_2d_freqs(dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True):
    """Build per-head 2D RoPE frequency vectors along the x and y axes.

    A geometric ladder of ``dim // 4`` magnitudes is shared by all heads. Each
    head draws one random phase (or uses 0 when ``rotate`` is False) and
    projects the ladder onto two orthogonal directions: the phase itself and
    the phase plus a quarter turn.

    Args:
        dim: per-head feature size.
        num_heads: number of heads to generate frequencies for.
        theta: base of the geometric frequency ladder.
        rotate: draw a random phase per head via ``torch.rand`` when True.

    Returns:
        Tensor of shape ``(2, num_heads, 2 * (dim // 4))``. Index 0 along the
        first axis holds the x components and index 1 the y components.
    """
    magnitudes = 1 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    per_head_x, per_head_y = [], []
    for _ in range(num_heads):
        phase = torch.rand(1) * 2 * torch.pi if rotate else torch.zeros(1)
        quarter_turn = torch.pi / 2 + phase
        per_head_x.append(
            torch.cat([magnitudes * torch.cos(phase), magnitudes * torch.cos(quarter_turn)], dim=-1)
        )
        per_head_y.append(
            torch.cat([magnitudes * torch.sin(phase), magnitudes * torch.sin(quarter_turn)], dim=-1)
        )
    x_block = torch.stack(per_head_x, dim=0)
    y_block = torch.stack(per_head_y, dim=0)
    return torch.stack([x_block, y_block], dim=0)
def init_t_xy(end_x: int, end_y: int):
    """Return (column, row) coordinates of every cell in a row-major grid.

    The grid has ``end_x`` columns and ``end_y`` rows and is flattened in
    row-major order, so flat index ``i`` maps to column ``i % end_x`` and
    row ``i // end_x``.

    Returns:
        Two float32 tensors of length ``end_x * end_y``: the x (column) and
        y (row) coordinates.
    """
    flat_index = torch.arange(end_x * end_y, dtype=torch.float32)
    column = torch.remainder(flat_index, end_x)
    row = torch.div(flat_index, end_x, rounding_mode='floor')
    return column.float(), row.float()
def compute_mixed_cis(freqs: torch.Tensor, t_x: torch.Tensor, t_y: torch.Tensor, num_heads: int):
    """Compute per-head unit complex rotations for mixed (learnable) 2D RoPE.

    Args:
        freqs: tensor indexed as ``freqs[0]`` (x frequencies) and ``freqs[1]``
            (y frequencies). Each is a flat vector of length ``num_heads * F``
            that is split evenly across heads.
        t_x: 1-D tensor of length N holding each token's x coordinate.
        t_y: 1-D tensor of length N holding each token's y coordinate.
        num_heads: number of attention heads.

    Returns:
        Complex tensor of shape ``(num_heads, N, F)`` with magnitude 1 and
        angle ``t_x * freqs_x + t_y * freqs_y``. It is always computed in
        float32 (complex64).
    """
    n_tokens = t_x.shape[0]
    # Position * frequency can be large, and fp16/bf16 would destroy the phase.
    # Upcast explicitly (this also covers a module converted with .half()).
    # Build the outer product with a broadcast multiply rather than a matmul:
    # autocast does not downcast elementwise `mul` on any device, so no
    # device-specific (and deprecated) autocast guard is needed.
    pos_x = t_x.float().unsqueeze(-1)
    pos_y = t_y.float().unsqueeze(-1)
    freqs = freqs.float()
    angles_x = (pos_x * freqs[0].unsqueeze(-2)).view(n_tokens, num_heads, -1).permute(1, 0, 2)
    angles_y = (pos_y * freqs[1].unsqueeze(-2)).view(n_tokens, num_heads, -1).permute(1, 0, 2)
    return torch.polar(torch.ones_like(angles_x), angles_x + angles_y)
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 100.0):
    """Return fixed (non-learnable) axial 2D RoPE rotations for an ``end_x`` x ``end_y`` grid.

    The x and y axes use the same geometric frequency ladder of ``dim // 4``
    entries. Positions come from ``init_t_xy``, i.e. row-major grid order.

    Returns:
        Complex tensor of shape ``(end_x * end_y, 2 * (dim // 4))`` with
        magnitude 1. The first half of the last axis rotates by x (column)
        position and the second half by y (row) position.
    """
    # Both axes share identical frequencies, so compute the ladder once.
    ladder = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    cols, rows = init_t_xy(end_x, end_y)
    halves = []
    for positions in (cols, rows):
        angles = torch.outer(positions, ladder)
        halves.append(torch.polar(torch.ones_like(angles), angles))
    return torch.cat(halves, dim=-1)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Return a view of ``freqs_cis`` that broadcasts against ``x``.

    ``freqs_cis`` must equal either the last two or the last three dimensions
    of ``x``. Singleton axes are prepended so the view has ``x.ndim``
    dimensions.

    Raises:
        ValueError: if ``freqs_cis`` matches neither the last two nor the last
            three dimensions of ``x``.
    """
    ndim = x.ndim
    # Check the 2-trailing-dim layout before the 3-trailing-dim one, matching
    # the original branch order. The ndim guard avoids indexing past the rank.
    for trailing in (2, 3):
        if ndim >= trailing and tuple(freqs_cis.shape) == tuple(x.shape[-trailing:]):
            shape = [1] * (ndim - trailing) + list(x.shape[-trailing:])
            return freqs_cis.view(*shape)
    raise ValueError(
        f"freqs_cis shape {tuple(freqs_cis.shape)} does not match the last 2 or 3 "
        f"dimensions of x with shape {tuple(x.shape)}"
    )
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Rotate query/key features by the complex factors in ``freqs_cis``.

    The last dimension of ``xq``/``xk`` (which must be even) is treated as
    consecutive (real, imaginary) pairs. Each pair is multiplied by the
    matching complex entry of ``freqs_cis``. That tensor is reshaped with
    ``reshape_for_broadcast`` so it must match the last 2 or 3 dimensions of
    the complex view.

    Args:
        xq: query tensor; any rank, last dimension even.
        xk: key tensor; same leading layout as ``xq``.
        freqs_cis: complex rotation factors.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the same shapes, dtypes, and devices
        as the inputs. The math is done in float32.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # flatten(-2) merges the trailing (pairs, 2) axes back into one feature
    # axis for inputs of any rank (a fixed flatten(3) only works for 4-D).
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    # The products already live on the inputs' device, so no .to() is needed.
    return xq_out.type_as(xq), xk_out.type_as(xk)
class Attention(nn.Module):
    """Plain multi-head self-attention with a fused QKV projection.

    Adapted from timm's vision transformer:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

    Args:
        dim: input/output channel count; split evenly across ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the fused QKV projection has a bias.
        qk_scale: score scale; falls back to ``head_dim ** -0.5`` when falsy.
        attn_drop: dropout probability applied to attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        per_head = dim // num_heads
        self.scale = qk_scale or per_head ** -0.5
        # Submodules are registered in the same order as before so seeded
        # initialisation and state_dict keys stay identical.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over the token axis of ``x`` (B, N, C) and return (B, N, C)."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        # (3, B, heads, N, head_dim) -> three (B, heads, N, head_dim) views
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)
        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))
        mixed = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))
class RoPEAttention(Attention):
    """Multi-head Attention block with rotary position embeddings.

    Token 0 of the sequence is left un-rotated (presumably a class token —
    confirm with callers). The remaining ``N - 1`` tokens are treated as a
    square grid in row-major order. Rotations are precomputed for a 14 x 14
    grid and recomputed on the fly for any other square grid size.

    Args:
        rope_theta: base of the RoPE frequency ladder.
        rope_mixed: if True, use learnable per-head frequencies
            (``self.freqs``); otherwise use fixed axial frequencies.
        *args, **kwargs: forwarded to ``Attention``.
    """

    def __init__(self, *args, rope_theta=10.0, rope_mixed=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.rope_mixed = rope_mixed
        head_dim = self.dim // self.num_heads
        if self.rope_mixed:
            self.compute_cis = partial(compute_mixed_cis, num_heads=self.num_heads)
            freqs = init_2d_freqs(
                dim=head_dim, num_heads=self.num_heads, theta=rope_theta,
                rotate=True
            ).view(2, -1)
            self.freqs = nn.Parameter(freqs, requires_grad=True)
            # Cached 14 x 14 grid coordinates (persistent buffers, as before).
            t_x, t_y = init_t_xy(end_x=14, end_y=14)
            self.register_buffer('freqs_t_x', t_x)
            self.register_buffer('freqs_t_y', t_y)
        else:
            self.compute_cis = partial(compute_axial_cis, dim=head_dim, theta=rope_theta)
            # Kept as a plain attribute (not a buffer), so it stays out of the
            # state_dict; forward() moves it to the input's device each call.
            freqs_cis = self.compute_cis(end_x=14, end_y=14)
            self.freqs_cis = freqs_cis

    @staticmethod
    def _grid_side(num_patches):
        """Return the integer side length of a square grid of ``num_patches`` cells."""
        if num_patches < 0:
            raise ValueError(f"expected at least one (un-rotated) leading token, got {num_patches + 1} tokens")
        side = math.isqrt(num_patches)
        if side * side != num_patches:
            raise ValueError(
                f"RoPEAttention expects N - 1 = {num_patches} patch tokens to form a square grid"
            )
        return side

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        ###### Apply rotary position embedding
        # Token 0 is skipped. The rest must form a square grid; an integer side
        # length is required (a float sqrt silently yields wrong position counts
        # for non-square grids).
        num_patches = N - 1
        if self.rope_mixed:
            t_x, t_y = self.freqs_t_x, self.freqs_t_y
            if t_x.shape[0] != num_patches:
                side = self._grid_side(num_patches)
                t_x, t_y = init_t_xy(end_x=side, end_y=side)
                t_x, t_y = t_x.to(x.device), t_y.to(x.device)
            freqs_cis = self.compute_cis(self.freqs, t_x, t_y)
        else:
            freqs_cis = self.freqs_cis
            if freqs_cis.shape[0] != num_patches:
                side = self._grid_side(num_patches)
                freqs_cis = self.compute_cis(end_x=side, end_y=side)
            freqs_cis = freqs_cis.to(x.device)
        q[:, :, 1:], k[:, :, 1:] = apply_rotary_emb(q[:, :, 1:], k[:, :, 1:], freqs_cis=freqs_cis)
        #########

        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x