-
Notifications
You must be signed in to change notification settings - Fork 0
/
symbol_retrieval.py
308 lines (249 loc) · 12.1 KB
/
symbol_retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
import torch
import torch.nn as nn
from positional_encoding import RelativePositionalEncoding
class SymbolicAttention(nn.Module):
def __init__(self,
d_model: int,
n_heads: int,
n_symbols: int,
dropout: float = 0.0,
scale: float = None,
trainable_symbols: bool = True):
"""
Symbolic Attention.
Learns a library of "symbols" and corresponding template features.
For a given input, retrieves a symbol from the symbol library via attention.
Parameters
----------
d_model : int
model dimension. this is the dimension of the input and the dimension of the symbols and template features.
n_heads : int
number of heads in symbolic attention.
n_symbols : int
number of symbols in the symbol library.
dropout : float, optional
dropout probability, by default 0.0
scale : float, optional
scaling factor in scaled_dot_product_attention, by default None
trainable_symbols: bool, optional
whether to make the symbol library trainable, by default True
"""
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.n_symbols = n_symbols
self.dropout = dropout
self.scale = scale
self.trainable_symbols = trainable_symbols
self.q_proj = nn.Linear(self.d_model, self.d_model)
self.template_features = nn.Parameter(torch.empty(self.n_symbols, self.d_model))
self.symbol_library = nn.Parameter(torch.empty(self.n_symbols, self.d_model), requires_grad=trainable_symbols)
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.normal_(self.template_features)
torch.nn.init.normal_(self.symbol_library)
def forward(self, x):
batch_size, seq_len, dim = x.size()
# create query from input
query = self.q_proj(x)
query = query.view(batch_size, seq_len, self.n_heads, dim // self.n_heads).transpose(1, 2)
# create keys from template features
key = self.template_features.view(self.n_symbols, self.n_heads, self.d_model // self.n_heads).transpose(0, 1)
key = self._repeat_kv(key, batch_size)
# create values from symbol library
value = self.symbol_library.view(self.n_symbols, self.n_heads, self.d_model // self.n_heads).transpose(0, 1)
value = self._repeat_kv(value, batch_size)
retrieved_symbols = torch.nn.functional.scaled_dot_product_attention(
query, key, value,
scale=self.scale, dropout_p=self.dropout, attn_mask=None, is_causal=False)
retrieved_symbols = retrieved_symbols.transpose(1, 2).contiguous().view(batch_size, seq_len, dim)
return retrieved_symbols
def _repeat_kv(self, x, batch_size):
"""
template_features and symbol_library are of shape (n_heads, n_s, d_s//n_heads).
repeat for each input and add a batch dimension of size batch_size.
"""
return x.unsqueeze(0).repeat(batch_size, 1, 1, 1)
class PositionalSymbolRetriever(nn.Module):
    def __init__(self, symbol_dim, max_length, sinusoidal=False):
        """
        Positional Symbol Retriever.

        Learns a library of "symbols".
        Retrieves a symbol for each object based on its position.

        Parameters
        ----------
        symbol_dim : int
            dimension of the symbols.
        max_length : int
            maximum sequence length supported. this is the number of symbols
            in the library (one per position).
        sinusoidal : bool, optional
            stored on the module but currently unused: symbols are always
            learned embeddings. by default False
        """
        super().__init__()
        self.symbol_dim = symbol_dim
        self.max_length = max_length
        self.sinusoidal = sinusoidal

        self.symbol_library = nn.Embedding(self.max_length, self.symbol_dim)
        # TODO: implement sinusoidal symbols?

    def forward(self, x):
        """
        Retrieve the positional symbol for every position of the input.

        Only the shape and device of `x` are used; its values are ignored.

        Parameters
        ----------
        x : torch.Tensor
            3-dimensional input of shape (batch_size, seq_len, dim).

        Returns
        -------
        torch.Tensor
            symbols of shape (batch_size, seq_len, symbol_dim); identical for every batch element.

        Raises
        ------
        IndexError
            if seq_len exceeds max_length.
        """
        batch_size, seq_len, _ = x.size()

        # explicit check: an out-of-range embedding lookup otherwise fails with an
        # opaque error that does not mention max_length
        if seq_len > self.max_length:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_length {self.max_length} "
                "of the positional symbol library")

        pos = torch.arange(0, seq_len, dtype=torch.long, device=x.device)
        retrieved_symbols = self.symbol_library(pos).unsqueeze(0).repeat(batch_size, 1, 1)

        return retrieved_symbols
# TODO: add support for causal-only position-relative symbols?
# cuts param count by half
class PositionRelativeSymbolRetriever(nn.Module):
    def __init__(self, symbol_dim, max_rel_pos):
        """
        Position-Relative Symbol Retriever.

        For a pair of positions i -> j, the symbol s_{ij} encodes the relative
        position j - i. The encoding itself is delegated to
        RelativePositionalEncoding.

        Parameters
        ----------
        symbol_dim : int
            dimension of the symbols.
        max_rel_pos : int
            maximum relative position encoded by symbols.
            Positions exceeding this will be truncated.
        """
        super().__init__()

        self.symbol_dim, self.max_rel_pos = symbol_dim, max_rel_pos
        self.rel_pos_enc = RelativePositionalEncoding(
            dim=symbol_dim,
            max_rel_pos=max_rel_pos,
        )

    def forward(self, x):
        """
        Return the relative-position symbols for a sequence as long as `x`.

        Only the size of dimension 1 of `x` (the sequence length) is used; the
        result is whatever `rel_pos_enc` returns for that length.
        """
        seq_len = x.shape[1]
        symbols = self.rel_pos_enc(seq_len)
        return symbols
class RelationalSymbolicAttention(nn.Module):
    def __init__(self,
            d_model: int,
            rel_n_heads: int,
            symbolic_attn_n_heads: int,
            n_symbols: int,
            nbhd_delta: int,
            causal_nbhd: bool = True,
            include_self: bool = False,
            normalize_rels: bool = True,
            dropout: float = 0.0,
            rel_scale: float = None,
            symbolic_attn_scale: float = None):
        """
        Relational symbolic attention module.

        Retrieves a symbol for each object in the input based on its relationship with its neighborhood.
        First, we compute a local relation vector for each object in the input. This local relation vector
        is then used to retrieve a symbol from the symbol library via symbolic attention.

        Parameters
        ----------
        d_model : int
            Model dimension. this is the dimension of the input and the dimension of the symbols and template features.
            Must be divisible by rel_n_heads.
        rel_n_heads : int
            Number of relation heads. Each head contributes one relation (inner product) per neighbor,
            so the neighborhood relation vector has rel_n_heads entries per neighbor.
        symbolic_attn_n_heads : int
            Number of symbolic attention heads.
        n_symbols : int
            Number of symbols to learn in the symbol library.
        nbhd_delta : int
            The size of the neighborhood.
        causal_nbhd : bool, optional
            Whether to use causal neighborhood. if causal_nbhd is True, the neighborhood is [i-nbhd_delta, i].
            if causal_nbhd is False, the neighborhood is [i-nbhd_delta, i+nbhd_delta]. Defaults to True.
        include_self : bool, optional
            Whether to include self in the neighborhood. E.g., if False and causal_nbhd, the neighborhood is
            [i-nbhd_delta, i-1]. If False and not causal_nbhd, the neighborhood is [i-nbhd_delta, i-1] U [i+1, i+nbhd_delta].
            Defaults to False.
        normalize_rels : bool, optional
            Whether to normalize relations with softmax across neighborhood. Defaults to True.
        dropout : float, optional
            The dropout rate (forwarded to the symbolic attention module). Defaults to 0.0.
        rel_scale : float, optional
            The scaling factor when normalizing relations via softmax. If None, it is computed based on
            d_model and rel_n_heads as (d_model // rel_n_heads) ** -0.5.
        symbolic_attn_scale : float, optional
            The scaling factor used in symbolic attention.

        Attributes
        ----------
        symbolic_attention : SymbolicAttention
            The symbolic attention module.
        q_proj : nn.Linear
            Linear layer for projecting the query.
        k_proj : nn.Linear
            Linear layer for projecting the key.
        model_dim_proj : nn.Linear
            Linear layer for projecting the neighborhood relation vector to model_dim.

        Raises
        ------
        ValueError
            If d_model is not divisible by rel_n_heads.

        Notes
        -----
        Neighbor positions falling outside [0, seq_len - 1] are clamped to the nearest valid
        position. Near the sequence boundaries the edge position is therefore repeated inside
        the neighborhood (and may coincide with the position itself even if include_self is False).
        """
        super().__init__()

        # fail fast: otherwise the per-head `view` calls in forward() raise a cryptic shape error
        if d_model % rel_n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by rel_n_heads ({rel_n_heads})")

        self.d_model = d_model
        self.rel_n_heads = rel_n_heads
        self.symbolic_attn_n_heads = symbolic_attn_n_heads
        self.n_symbols = n_symbols
        self.nbhd_delta = nbhd_delta
        self.causal_nbhd = causal_nbhd
        self.dropout = dropout
        self.rel_scale = rel_scale if rel_scale is not None else (d_model // rel_n_heads) ** -0.5
        self.symbolic_attn_scale = symbolic_attn_scale
        self.include_self = include_self
        self.normalize_rels = normalize_rels

        self.nbhd_rel_dim = self._compute_nbhd_rel_dim(rel_n_heads, nbhd_delta, causal_nbhd, include_self)

        self.symbolic_attention = SymbolicAttention(d_model, symbolic_attn_n_heads, n_symbols, dropout, symbolic_attn_scale)

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.model_dim_proj = nn.Linear(self.nbhd_rel_dim, d_model)  # project neighborhood relation vector to model_dim

    def forward(self, x):
        """
        Retrieve a symbol for each position from its relations with its neighborhood.

        Parameters
        ----------
        x : torch.Tensor
            input of shape (batch_size, seq_len, d_model).

        Returns
        -------
        The output of the symbolic attention module applied to the projected
        neighborhood relation vectors of shape (batch_size, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.size()
        head_dim = self.d_model // self.rel_n_heads

        # compute query and key transformations to compute relations with neighborhood
        query = self.q_proj(x)
        key = self.k_proj(x)

        # reshape to (batch_size, n_heads, n, d_k); i.e., split model_dim into n_heads
        query = query.view(batch_size, seq_len, self.rel_n_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.rel_n_heads, head_dim).transpose(1, 2)

        # compute neighborhood index mask of shape (n, Delta), directly on the input's device
        # so that the gather below does not need a host-to-device transfer on every call
        if self.causal_nbhd:
            neighbor_mask = self.compute_causal_neighbor_mask(
                seq_len, self.nbhd_delta, self.include_self, device=x.device)
        else:
            neighbor_mask = self.compute_neighbor_mask(
                seq_len, self.nbhd_delta, self.include_self, device=x.device)

        # gather the keys of each position's neighbors
        neighborhood_keys = key[:, :, neighbor_mask]  # (batch_size, n_heads, n, Delta, d_k)

        # compute relations with neighborhood
        # einstein summation: R[b,h,i,j] = sum_d Q[b,h,i,d] * nbhd_K[b,h,i,j,d],
        # where nbhd_K[b,h,i,j,d] = K[b,h,neighbor_mask[i,j],d]
        neighbor_rel_tensor = torch.einsum('bhid,bhijd->bhij', query, neighborhood_keys)  # (batch_size, n_heads, n, Delta)

        if self.normalize_rels:
            # normalize relations across neighborhood (of size Delta)
            neighbor_rel_tensor = torch.softmax(neighbor_rel_tensor * self.rel_scale, dim=-1)

        # permute dims to shape (batch_size, n, Delta, n_heads)
        neighbor_rel_tensor = neighbor_rel_tensor.permute(0, 2, 3, 1)

        # flatten (Delta, n_heads) to get shape (batch_size, n, nbhd_rel_dim) [nbhd_rel_dim = Delta * n_heads]
        neighbor_rel_tensor = neighbor_rel_tensor.contiguous().view(batch_size, seq_len, self.nbhd_rel_dim)

        # project neighborhood relation vector to model_dim
        neighbor_rel_tensor = self.model_dim_proj(neighbor_rel_tensor)

        # compute symbolic attention
        retrieved_symbols = self.symbolic_attention(neighbor_rel_tensor)

        return retrieved_symbols

    def _compute_nbhd_rel_dim(self, rel_n_heads, nbhd_delta, causal_nbhd, include_self):
        '''computes the dimension of the neighborhood relation vector (number of neighbors * rel_n_heads)'''
        # causal: nbhd_delta positions to the left; non-causal: nbhd_delta on each side
        n_neighbors = nbhd_delta if causal_nbhd else 2 * nbhd_delta
        if include_self:
            n_neighbors += 1
        return rel_n_heads * n_neighbors

    @staticmethod
    def compute_neighbor_mask(n, delta, include_self=True, device=None):
        '''
        computes the (two-sided) neighborhood index mask for a sequence of length n and neighborhood size delta.

        returns an integer tensor of shape (n, 2*delta + 1) if include_self else (n, 2*delta), whose row i
        holds the neighbor positions of i in increasing offset order (-delta, ..., delta), clamped to [0, n-1].
        device is the device on which the mask is created (torch default if None).
        '''
        positions = torch.arange(n, device=device).unsqueeze(1)

        if include_self:
            offsets = torch.arange(-delta, delta + 1, device=device)
        else:
            offsets = torch.cat([
                torch.arange(-delta, 0, device=device),
                torch.arange(1, delta + 1, device=device),
            ])

        mask = positions + offsets.unsqueeze(0)
        # out-of-range neighbors are replaced by the nearest valid position
        mask = mask.clamp(0, n - 1)
        return mask

    @staticmethod
    def compute_causal_neighbor_mask(n, delta, include_self=False, device=None):
        '''
        computes the causal neighborhood index mask for a sequence of length n and neighborhood size delta.

        returns an integer tensor of shape (n, delta + 1) if include_self else (n, delta), whose row i
        holds the positions i, i-1, ..., i-delta (without i itself unless include_self), clamped to [0, n-1].
        device is the device on which the mask is created (torch default if None).
        '''
        positions = torch.arange(n, device=device).unsqueeze(1)

        first_offset = 0 if include_self else 1
        offsets = torch.arange(first_offset, delta + 1, device=device).unsqueeze(0)

        mask = positions - offsets
        # out-of-range neighbors are replaced by the nearest valid position
        mask = mask.clamp(0, n - 1)
        return mask