"""
Adapted from https://github.com/yandex-research/tabular-dl-tabr/blob/main/lib/deep.py.
"""
import torch
from torch import nn


class PeriodicEmbeddings(nn.Module):
    """Maps each scalar feature to [cos(2*pi*f*x), sin(2*pi*f*x)] with learned frequencies f."""

    def __init__(self, features_dim, frequencies_dim, frequencies_scale, shared_frequencies=False):
        super().__init__()
        if shared_frequencies:
            # One frequency vector shared by all features.
            self.frequencies = nn.Parameter(torch.randn(1, frequencies_dim) * frequencies_scale)
        else:
            # An independent frequency vector per feature.
            self.frequencies = nn.Parameter(torch.randn(features_dim, frequencies_dim) * frequencies_scale)

    def forward(self, x):
        # x: (batch, features_dim) -> (batch, features_dim, frequencies_dim).
        x = 2 * torch.pi * self.frequencies[None, ...] * x[..., None]
        # Concatenate the cosine and sine parts: (batch, features_dim, 2 * frequencies_dim).
        return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)


class NLinear(nn.Module):
    """A stack of features_dim independent linear layers, one per feature."""

    def __init__(self, features_dim, input_dim, output_dim, bias=True):
        super().__init__()
        # Same uniform bound as nn.Linear's default initialization: U(-1/sqrt(in), 1/sqrt(in)).
        init_max = input_dim ** -0.5
        self.weight = nn.Parameter(torch.empty(features_dim, input_dim, output_dim).uniform_(-init_max, init_max))
        self.bias = nn.Parameter(torch.empty(features_dim, output_dim).uniform_(-init_max, init_max)) if bias else None

    def forward(self, x):
        # x: (batch, features_dim, input_dim) -> (batch, features_dim, output_dim).
        x = (x[..., None] * self.weight).sum(dim=-2)
        if self.bias is not None:
            x = x + self.bias
        return x


class PLREmbeddings(nn.Module):
    """PLR embedding: periodic encoding -> (shared or per-feature) linear layer -> ReLU."""

    def __init__(self, features_dim, frequencies_dim, frequencies_scale, embedding_dim, shared_linear=False,
                 shared_frequencies=False):
        super().__init__()
        if shared_linear:
            # One linear projection shared by all features.
            linear_layer = nn.Linear(in_features=frequencies_dim * 2, out_features=embedding_dim)
        else:
            # An independent linear projection per feature.
            linear_layer = NLinear(features_dim=features_dim, input_dim=frequencies_dim * 2, output_dim=embedding_dim)
        self.plr_embeddings = nn.Sequential(
            PeriodicEmbeddings(features_dim=features_dim, frequencies_dim=frequencies_dim,
                               frequencies_scale=frequencies_scale, shared_frequencies=shared_frequencies),
            linear_layer,
            nn.ReLU(),
        )

    def forward(self, x):
        # x: (batch, features_dim) -> (batch, features_dim, embedding_dim).
        return self.plr_embeddings(x)
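

# A minimal usage sketch (not part of the original file). The hyperparameter
# values below are illustrative assumptions, not taken from the source: 5
# numerical features are each embedded into a 24-dimensional vector.
if __name__ == "__main__":
    embeddings = PLREmbeddings(
        features_dim=5,         # number of numerical features
        frequencies_dim=16,     # random frequencies per feature
        frequencies_scale=0.1,  # std of the initial frequencies
        embedding_dim=24,       # output embedding size per feature
    )
    x = torch.randn(8, 5)       # a batch of 8 rows, one scalar per feature
    out = embeddings(x)
    print(out.shape)            # torch.Size([8, 5, 24])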