-
Notifications
You must be signed in to change notification settings - Fork 0
/
modules.py
154 lines (111 loc) · 5.23 KB
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import torch
from torch import nn
from dgl import ops
from dgl.nn.functional import edge_softmax
class ResidualModuleWrapper(nn.Module):
    """Pre-normalization residual wrapper computing ``x + module(graph, normalization(x))``.

    Args:
        module: Factory (usually an ``nn.Module`` subclass) for the wrapped block;
            called as ``module(dim=dim, **kwargs)``.
        normalization: Factory for the normalization layer; called as ``normalization(dim)``.
        dim: Feature width handed to both factories.
        **kwargs: Extra keyword arguments forwarded unchanged to ``module``.
    """

    def __init__(self, module, normalization, dim, **kwargs):
        super().__init__()
        # Normalization is registered first so submodule order (and state_dict layout) is stable.
        self.normalization = normalization(dim)
        self.module = module(dim=dim, **kwargs)

    def forward(self, graph, x):
        """Return the input plus the wrapped block's output on the normalized input."""
        return x + self.module(graph, self.normalization(x))
class FeedForwardModule(nn.Module):
    """Two-layer MLP: Linear -> Dropout -> GELU -> Linear -> Dropout.

    Args:
        dim: Output width. The expected input width is ``int(dim * input_dim_multiplier)``.
        hidden_dim_multiplier: The hidden width is ``int(dim * hidden_dim_multiplier)``.
        dropout: Drop probability shared by both dropout layers.
        input_dim_multiplier: Scales the expected input width (e.g. 2 for a
            concatenation of two ``dim``-wide tensors).
        **kwargs: Accepted and ignored.
    """

    def __init__(self, dim, hidden_dim_multiplier, dropout, input_dim_multiplier=1, **kwargs):
        super().__init__()
        in_width = int(dim * input_dim_multiplier)
        hidden_width = int(dim * hidden_dim_multiplier)

        # Construction order fixes parameter-initialization RNG order and the
        # state_dict layout, so it is kept exactly: linear_1, dropout_1, act, linear_2, dropout_2.
        self.linear_1 = nn.Linear(in_features=in_width, out_features=hidden_width)
        self.dropout_1 = nn.Dropout(p=dropout)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(in_features=hidden_width, out_features=dim)
        self.dropout_2 = nn.Dropout(p=dropout)

    def forward(self, graph, x):
        """Run ``x`` through the MLP stages in order.

        ``graph`` is not used here (presumably it is accepted so every module
        shares the ``forward(graph, x)`` signature).
        """
        for layer in (self.linear_1, self.dropout_1, self.act, self.linear_2, self.dropout_2):
            x = layer(x)
        return x
class GCNModule(nn.Module):
    """GCN-style neighbourhood propagation followed by a feed-forward block.

    Each destination node v receives
    ``sum over edges u -> v of x_u / sqrt(deg_out(u) * deg_in(v))``,
    and the aggregated features are passed through a ``FeedForwardModule``.

    Args:
        dim: Node feature width (input and output).
        hidden_dim_multiplier: Hidden-width multiplier for the feed-forward block.
        dropout: Dropout probability inside the feed-forward block.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, dim, hidden_dim_multiplier, dropout, **kwargs):
        super().__init__()
        self.feed_forward_module = FeedForwardModule(dim=dim,
                                                     hidden_dim_multiplier=hidden_dim_multiplier,
                                                     dropout=dropout)

    def forward(self, graph, x):
        """Aggregate degree-normalized neighbour features of ``x`` over ``graph``, then apply the MLP.

        Assumes ``x`` holds one row per node of ``graph`` (TODO confirm against callers).
        """
        # Normalize each edge u -> v by the source's OUT-degree and the destination's
        # IN-degree (the D_out^-1/2 A D_in^-1/2 scheme used by DGL's GraphConv with
        # norm='both'). On symmetric/undirected graphs out- and in-degrees coincide.
        # Every edge guarantees deg_out(u) >= 1 and deg_in(v) >= 1, so the
        # coefficients are always finite. Using out-degree on the destination side
        # would give 1/0 = inf for an edge into a node with no outgoing edges.
        src_degrees = graph.out_degrees().float()
        dst_degrees = graph.in_degrees().float()
        degree_edge_products = ops.u_mul_v(graph, src_degrees, dst_degrees)
        norm_coefs = 1 / degree_edge_products ** 0.5

        # Weighted sum of source features into each destination node.
        x = ops.u_mul_e_sum(graph, x, norm_coefs)
        x = self.feed_forward_module(graph, x)
        return x
class SAGEModule(nn.Module):
    """GraphSAGE-style layer with mean aggregation.

    Each node's features are concatenated with the mean of its in-neighbours'
    (edge sources') features. A feed-forward block then maps the ``2 * dim``-wide
    result back to ``dim``.

    Args:
        dim: Node feature width (input and output).
        hidden_dim_multiplier: Hidden-width multiplier for the feed-forward block.
        dropout: Dropout probability inside the feed-forward block.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, dim, hidden_dim_multiplier, dropout, **kwargs):
        super().__init__()
        # The block's input is [own features || neighbour mean], i.e. twice the width.
        self.feed_forward_module = FeedForwardModule(
            dim=dim,
            input_dim_multiplier=2,
            hidden_dim_multiplier=hidden_dim_multiplier,
            dropout=dropout,
        )

    def forward(self, graph, x):
        """Concatenate ``x`` with its neighbour mean along dim 1 and apply the MLP."""
        neighbour_mean = ops.copy_u_mean(graph, x)
        combined = torch.cat((x, neighbour_mean), dim=1)
        return self.feed_forward_module(graph, combined)
def _check_dim_and_num_heads_consistency(dim, num_heads):
if dim % num_heads != 0:
raise ValueError('Dimension mismatch: hidden_dim should be a multiple of num_heads.')
class GATModule(nn.Module):
    """Multi-head GAT-style neighbourhood attention followed by a feed-forward block.

    Args:
        dim: Node feature width (input and output); must be a multiple of ``num_heads``.
        hidden_dim_multiplier: Hidden-width multiplier for the feed-forward block.
        num_heads: Number of attention heads.
        dropout: Dropout probability inside the feed-forward block.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: From ``_check_dim_and_num_heads_consistency`` when ``dim`` and
            ``num_heads`` are incompatible.
    """

    def __init__(self, dim, hidden_dim_multiplier, num_heads, dropout, **kwargs):
        super().__init__()
        _check_dim_and_num_heads_consistency(dim, num_heads)

        self.dim, self.num_heads = dim, num_heads
        self.head_dim = dim // num_heads

        self.input_linear = nn.Linear(in_features=dim, out_features=dim)
        # One score per head from each edge endpoint. Only the source-side projection
        # has a bias, so the summed edge score carries exactly one bias term.
        self.attn_linear_u = nn.Linear(in_features=dim, out_features=num_heads)
        self.attn_linear_v = nn.Linear(in_features=dim, out_features=num_heads, bias=False)
        self.attn_act = nn.LeakyReLU(negative_slope=0.2)
        self.feed_forward_module = FeedForwardModule(
            dim=dim,
            hidden_dim_multiplier=hidden_dim_multiplier,
            dropout=dropout,
        )

    def forward(self, graph, x):
        """Attend over incoming edges of ``graph``, aggregate per head, then apply the MLP.

        Assumes ``x`` is (num_nodes, dim) — TODO confirm against callers.
        """
        h = self.input_linear(x)

        # Per-edge, per-head logits: score_u(source) + score_v(destination),
        # then LeakyReLU, then DGL's edge_softmax.
        logits = ops.u_add_v(graph, self.attn_linear_u(h), self.attn_linear_v(h))
        weights = edge_softmax(graph, self.attn_act(logits))

        # View features as (num_nodes, head_dim, num_heads) so the per-head edge
        # weights broadcast over the trailing axis. Head k therefore owns the
        # strided feature columns k, k + num_heads, k + 2 * num_heads, ...
        h = h.reshape(-1, self.head_dim, self.num_heads)
        h = ops.u_mul_e_sum(graph, h, weights).reshape(-1, self.dim)

        return self.feed_forward_module(graph, h)
class TransformerAttentionModule(nn.Module):
    """Multi-head scaled dot-product attention restricted to the edges of a graph.

    Scores are computed per edge and per head. The source node supplies the query
    and the destination node supplies the key (argument order of ``ops.u_dot_v``).
    They are scaled by ``1 / sqrt(head_dim)``, normalized with DGL's
    ``edge_softmax``, and used to sum source-node values into destinations. The
    result goes through an output projection and dropout.

    Args:
        dim: Node feature width (input and output); must be a multiple of ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied after the output projection.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: From ``_check_dim_and_num_heads_consistency`` when ``dim`` and
            ``num_heads`` are incompatible.
    """

    def __init__(self, dim, num_heads, dropout, **kwargs):
        super().__init__()
        _check_dim_and_num_heads_consistency(dim, num_heads)

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        # Construction order (q, k, v, out) fixes parameter-initialization order.
        self.attn_query = nn.Linear(in_features=dim, out_features=dim)
        self.attn_key = nn.Linear(in_features=dim, out_features=dim)
        self.attn_value = nn.Linear(in_features=dim, out_features=dim)
        self.output_linear = nn.Linear(in_features=dim, out_features=dim)
        self.dropout = nn.Dropout(p=dropout)

    def _split_heads(self, tensor):
        """Reshape a projected tensor to (rows, num_heads, head_dim)."""
        return tensor.reshape(-1, self.num_heads, self.head_dim)

    def forward(self, graph, x):
        """Apply edge-restricted multi-head attention to ``x`` over ``graph``."""
        queries, keys, values = (
            self._split_heads(projection(x))
            for projection in (self.attn_query, self.attn_key, self.attn_value)
        )

        scale = self.head_dim ** 0.5
        attn_probs = edge_softmax(graph, ops.u_dot_v(graph, queries, keys) / scale)

        merged = ops.u_mul_e_sum(graph, values, attn_probs).reshape(-1, self.dim)
        return self.dropout(self.output_linear(merged))