-
Notifications
You must be signed in to change notification settings - Fork 0
/
ops.py
169 lines (128 loc) · 5.35 KB
/
ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm
from kernel import glorot_uniform_af, bias_sigmod_ele
class DropoutRowwise(nn.Module):
    """
    Dropout that samples its mask on a size-1 slice along dim 1 and
    broadcasts that mask across dim 1 of the input.
    """

    def __init__(self, p):
        super().__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        """
        :param x: tensor indexed with four subscripts (4-D)
        :return: x multiplied element-wise by the broadcast dropout mask
        """
        # Template of ones with dim 1 collapsed to a single entry; running it
        # through nn.Dropout yields one mask shared by every index of dim 1.
        template = torch.ones_like(x[:, :1, :, :])
        shared_mask = self.dropout(template)
        return shared_mask * x
class DropoutColumnwise(nn.Module):
    """
    Dropout that samples its mask on a size-1 slice along dim 2 and
    broadcasts that mask across dim 2 of the input.
    """

    def __init__(self, p):
        super().__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        """
        :param x: tensor indexed with four subscripts (4-D)
        :return: x multiplied element-wise by the broadcast dropout mask
        """
        # Template of ones with dim 2 collapsed to a single entry; running it
        # through nn.Dropout yields one mask shared by every index of dim 2.
        template = torch.ones_like(x[:, :, :1, :])
        shared_mask = self.dropout(template)
        return shared_mask * x
class Transition(nn.Module):
    """
    Residual feed-forward block: LayerNorm, a d -> n*d projection with ReLU,
    then an n*d -> d projection whose result is added back to the input.
    """

    def __init__(self, d, n=4):
        """
        :param d: size of the last (feature) dimension
        :param n: expansion factor of the hidden layer
        """
        super().__init__()
        hidden = n * d
        self.norm = LayerNorm(d)
        self.linear1 = Linear(d, hidden, initializer='relu')
        # Second projection starts with zero weights.
        self.linear2 = Linear(hidden, d, initializer='zeros')

    def forward(self, src):
        """
        :param src: tensor whose last dimension has size d
        :return: src plus the feed-forward update, same shape as src
        """
        normed = self.norm(src)
        hidden_act = F.relu(self.linear1(normed))
        update = self.linear2(hidden_act)
        return src + update
class OutProductMean(nn.Module):
    """
    Projects the input to two low-dimensional activations, forms their outer
    product summed over dim 1, and maps the flattened product to n_feat_out.
    """

    def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
        """
        :param n_feat: size of the input feature dimension
        :param n_feat_out: size of the output feature dimension
        :param n_feat_proj: size of each of the two projections
        """
        super(OutProductMean, self).__init__()

        self.layernormM = LayerNorm(n_feat)
        self.linear_a = Linear(n_feat, n_feat_proj)
        self.linear_b = Linear(n_feat, n_feat_proj)

        # 'zeros' is the spelling Linear recognises; the previous 'zero'
        # matched no branch, so the weights kept the default nn.Linear init.
        self.o_linear = Linear(n_feat_proj * n_feat_proj,
                               n_feat_out,
                               initializer='zeros',
                               use_bias=True)

    def forward(self, M):
        """
        :param M: [b, s, i, n_feat]
        :return: [b, i, i, n_feat_out]
        """
        M = self.layernormM(M)
        left_act = self.linear_a(M)
        right_act = self.linear_b(M)

        # The einsum contracts (sums) over s; there is no division by s here.
        # NOTE(review): the class name suggests a mean -- confirm whether the
        # normalisation is expected to happen elsewhere.
        O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous()
        O = rearrange(O, 'b i j d e -> b i j (d e)')
        Z = self.o_linear(O)

        return Z
class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.
    Implements the initializers in 1.11.4, plus some additional ones found
    in the code.

    Supported initializer names:
        'linear'          -- glorot_uniform_af with gain=1.0
        'relu'            -- glorot_uniform_af with gain=2.0
        'zeros' / 'zero'  -- all-zero weights
    """

    def __init__(
        self,
        feature_in: int,
        feature_out: int,
        initializer: str = 'linear',
        use_bias: bool = True,
        bias_init: float = 0.,
    ):
        """
        :param feature_in: size of each input sample
        :param feature_out: size of each output sample
        :param initializer: weight initialization scheme (see class docstring)
        :param use_bias: whether the layer has an additive bias
        :param bias_init: constant every bias element is filled with
        :raises ValueError: if initializer is not a supported name
        """
        super(Linear, self).__init__(feature_in, feature_out, bias=use_bias)

        self.use_bias = use_bias

        if initializer == 'linear':
            glorot_uniform_af(self.weight, gain=1.0)
        elif initializer == 'relu':
            glorot_uniform_af(self.weight, gain=2.0)
        elif initializer in ('zeros', 'zero'):
            # 'zero' is accepted as an alias because callers in this file use
            # that spelling; it previously matched no branch and silently left
            # the default nn.Linear initialization in place.
            nn.init.zeros_(self.weight)
        else:
            # Fail loudly rather than silently keeping the default init.
            raise ValueError(
                "Unknown initializer {!r}; expected one of 'linear', 'relu', "
                "'zeros'.".format(initializer))

        if self.use_bias:
            with torch.no_grad():
                self.bias.fill_(bias_init)
class SelfAttention(nn.Module):
    """
    Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
    """

    def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
        """
        :param qkv_dim: feature size of the input
        :param c: per-head channel size
        :param n_head: number of attention heads
        :param out_dim: feature size of the output
        :param gating: if True, gate the attention output before o_linear
        :param last_bias_fuse: if True, o_linear is created without a bias
            (presumably the caller adds it in a fused op -- TODO confirm)
        """
        super(SelfAttention, self).__init__()
        self.qkv_dim = qkv_dim
        self.c = c
        self.n_head = n_head
        self.out_dim = out_dim
        self.gating = gating
        self.last_bias_fuse = last_bias_fuse

        # Queries are scaled by 1/sqrt(c) before the dot product.
        self.scaling = self.c**(-0.5)

        self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)

        # 'zeros' is the spelling Linear recognises; the previous 'zero'
        # matched no branch and left the default nn.Linear init in place.
        if gating:
            self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
            self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zeros', use_bias=False)

        self.o_linear = Linear(n_head * c,
                               out_dim,
                               initializer='zeros',
                               use_bias=(not last_bias_fuse))

    def forward(self, in_data, nonbatched_bias=None):
        """
        :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
        :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
        :return: [batch_size1, batch_size2, len_qkv, out_dim]
        """
        q = self.to_q(in_data)
        k = self.to_k(in_data)
        # Values come from their own projection. This previously called
        # self.to_k, so v duplicated k and to_v was never used or trained.
        # NOTE: checkpoints trained before this fix hold an untrained to_v.
        v = self.to_v(in_data)

        # Split the fused (h d) feature axis into separate heads.
        q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head),
                      [q, k, v])

        q = q * self.scaling

        logits = torch.matmul(q, k.transpose(-1, -2))

        if nonbatched_bias is not None:
            # unsqueeze(1) lets the bias broadcast over batch_size2.
            logits += nonbatched_bias.unsqueeze(1)

        weights = torch.softmax(logits, dim=-1)

        weighted_avg = torch.matmul(weights, v)
        weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')

        if self.gating:
            gate_values = self.gating_linear(in_data)
            # bias_sigmod_ele comes from the kernel module; presumably it
            # computes sigmoid(gate_values + gating_bias) * weighted_avg --
            # TODO confirm against kernel.
            weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)

        output = self.o_linear(weighted_avg)
        return output