#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Author: XiaShan
@Contact: 153765931@qq.com
@Time: 2024/6/22 20:05
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import get_laplacian, add_self_loops
from torch_geometric.nn.conv import MessagePassing
from torch import ones_like, sparse_coo_tensor, svd_lowrank


class MP(MessagePassing):
    """Minimal message-passing wrapper: sums (optionally edge-weighted) neighbor features."""

    def __init__(self):
        super(MP, self).__init__()

    def forward(self, x, edge_index, norm=None):
        return self.propagate(edge_index=edge_index, x=x, norm=norm)

    def message(self, x_j, norm=None):
        if norm is not None:
            return norm.view(-1, 1) * x_j  # broadcast the per-edge weight across feature dims
        else:
            return x_j
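
# A minimal usage sketch (illustrative): with a single undirected edge and
# per-edge weights `norm`,
#
#   mp = MP()
#   x = torch.randn(2, 4)
#   edge_index = torch.tensor([[0, 1], [1, 0]])
#   out = mp(x, edge_index, norm=torch.tensor([0.5, 0.5]))
#
# each output row is out[i] = sum over neighbors j of norm[j->i] * x[j].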


class Basis_Generator(nn.Module):
    """Generates the three families of bases: node features X, propagated features L^k X, and structure L."""

    def __init__(self, nx, nlx, nl, k, operator, low_x=False, low_lx=False, norm1=False):
        super(Basis_Generator, self).__init__()
        self.nx = nx              # dimension kept for the X basis
        self.nlx = nlx            # dimension kept for each propagated basis
        self.nl = nl              # rank of the compressed structural basis
        self.norm1 = norm1        # whether to row-normalize after each propagation step
        self.k = k                # highest propagation order
        self.operator = operator  # 'gcn' | 'gpr' | 'cheb' | 'ours'
        self.low_x = low_x        # lossy SVD compression of X
        self.low_lx = low_lx      # lossy SVD compression of each propagated basis
        self.mp = MP()

    def get_x_basis(self, x):
        x = F.normalize(x, dim=1)  # per node: normalize across all feature dimensions
        x = F.normalize(x, dim=0)  # per feature dimension: normalize across all nodes
        if self.low_x:
            # Optional lossy compression of the node-feature matrix x (via truncated SVD)
            U, S, V = svd_lowrank(x, q=self.nx)
            low_x = U @ torch.diag(S)
            return low_x
        else:
            return x
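
    # Note: svd_lowrank gives x ≈ U diag(S) V^T with q retained components, so
    # keeping U @ diag(S) shrinks the feature dimension to nx while preserving
    # the nodes' pairwise inner products up to the rank-nx approximation error.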

    def get_lx_basis(self, x, edge_index):
        """Generate all propagated feature spaces {h, (op) h, (op)^2 h, ...}."""
        lxs = []
        num_nodes = x.shape[0]
        # L = I - D^(-1/2) A D^(-1/2); get_laplacian also appends the self-loop entries to edge_index
        edge_index_lap, edge_weight_lap = get_laplacian(edge_index=edge_index, normalization='sym', num_nodes=num_nodes)
        h = F.normalize(x, dim=1)
        if self.operator == 'gcn':
            lxs = [h]
            # Renormalization trick: shift -L to A_hat + I, re-normalize, and shift again,
            # yielding a GCN-style propagation matrix
            edge_index, edge_weight = add_self_loops(edge_index=edge_index_lap,
                                                     edge_attr=-edge_weight_lap,
                                                     fill_value=2.0,
                                                     num_nodes=num_nodes)
            edge_index, edge_weight = get_laplacian(edge_index=edge_index,
                                                    edge_weight=edge_weight,
                                                    normalization='sym',
                                                    num_nodes=num_nodes)
            edge_index, edge_weight = add_self_loops(edge_index=edge_index,
                                                     edge_attr=-edge_weight,
                                                     fill_value=1.,
                                                     num_nodes=num_nodes)
            for k in range(self.k):  # k steps on top of the initial basis, matching what forward() consumes
                h = self.mp.propagate(edge_index=edge_index, x=h, norm=edge_weight)
                if self.norm1:
                    h = F.normalize(h, dim=1)
                lxs.append(h)
        elif self.operator == 'gpr':
            lxs = [h]
            # -L + I leaves exactly the normalized adjacency A_hat = D^(-1/2) A D^(-1/2)
            edge_index, edge_weight = add_self_loops(edge_index=edge_index_lap,
                                                     edge_attr=-edge_weight_lap,
                                                     fill_value=1.0,
                                                     num_nodes=num_nodes)
            for k in range(self.k):
                h = self.mp.propagate(edge_index=edge_index, x=h, norm=edge_weight)
                if self.norm1:
                    h = F.normalize(h, dim=1)
                lxs.append(h)
        elif self.operator == 'cheb':
            # L - I is the scaled Laplacian L~ (i.e. -A_hat, taking lambda_max ≈ 2)
            edge_index, edge_weight = add_self_loops(edge_index=edge_index_lap,
                                                     edge_attr=edge_weight_lap,
                                                     fill_value=-1.0,
                                                     num_nodes=num_nodes)
            for k in range(self.k + 1):
                # Chebyshev recurrence: T_0 h = h, T_1 h = L~ h, T_k h = 2 L~ T_{k-1} h - T_{k-2} h
                if k == 0:
                    pass
                elif k == 1:
                    h = self.mp.propagate(edge_index=edge_index, x=h, norm=edge_weight)
                else:
                    h = self.mp.propagate(edge_index=edge_index, x=h, norm=edge_weight) * 2
                    h = h - lxs[-2]  # subtract T_{k-2} h, as the recurrence requires
                if self.norm1:
                    h = F.normalize(h, dim=1)
                lxs.append(h)
        elif self.operator == 'ours':
            lxs = [h]
            # Same shifted Laplacian L - I as above, followed by the extra subtraction below
            edge_index, edge_weight = add_self_loops(edge_index=edge_index_lap,
                                                     edge_attr=edge_weight_lap,
                                                     fill_value=-1.0,
                                                     num_nodes=num_nodes)
            for k in range(self.k):
                h = self.mp.propagate(edge_index=edge_index, x=h, norm=edge_weight)
                h = h - lxs[-1]
                if self.norm1:
                    h = F.normalize(h, dim=1)
                lxs.append(h)
        if self.low_lx:
            # Optional lossy compression of each basis via truncated SVD: lx ≈ U diag(S)
            compressed = []
            for lx in lxs:
                U, S, V = svd_lowrank(lx, q=self.nlx)
                compressed.append(U @ torch.diag(S))
            lxs = compressed
        final_lxs = [F.normalize(lx, dim=0) for lx in lxs]  # column-normalize each basis
        return final_lxs
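
    # Summary of the propagation operators above (with A_hat = D^(-1/2) A D^(-1/2),
    # L = I - A_hat, and ignoring the optional per-step row normalization):
    #   gcn : powers of a renormalized (A_hat + I), i.e. a GCN-style propagation matrix
    #   gpr : powers of A_hat                        (since -L + I = A_hat)
    #   cheb: Chebyshev polynomials T_k(L - I), i.e. T_k(-A_hat)
    #   ours: powers of -(A_hat + I)                 (since (L - I) h - h = -(A_hat + I) h)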

    def get_l_basis(self, edge_index, num_nodes):
        """Lossy SVD compression of the graph-structure (adjacency) matrix."""
        # PyG 2.5: get_laplacian first computes L = I - D^(-1/2) A D^(-1/2), then appends self-loops to edge_index
        edge_index, edge_weight = get_laplacian(edge_index=edge_index, normalization='sym', num_nodes=num_nodes)
        adj = sparse_coo_tensor(indices=edge_index,
                                values=ones_like(edge_index[0]),
                                size=(num_nodes, num_nodes),
                                device=edge_index.device,
                                dtype=torch.float32).to_dense()  # 0/1 pattern of A plus self-loops
        adj = F.normalize(adj, dim=1)  # normalize each row of the dense matrix
        U, S, V = svd_lowrank(adj, q=self.nl, niter=2)  # truncated SVD: adj ≈ U diag(S) V^T
        adj = U @ torch.diag(S)  # rank-nl structural embedding: adj ≈ U diag(S)
        adj = F.normalize(adj, dim=0)  # normalize each column
        return adj


class FE_GNN(nn.Module):
    def __init__(self, args, ninput, nclass):
        super(FE_GNN, self).__init__()
        # A negative nx / nlx means "use the full input feature dimension"
        self.nx = ninput if args.nx < 0 else args.nx
        self.nlx = ninput if args.nlx < 0 else args.nlx
        self.nl = args.nl
        self.k = args.k
        self.operator = args.operator
        self.basis_generator = Basis_Generator(nx=self.nx, nlx=self.nlx, nl=self.nl, k=self.k,
                                               operator=args.operator,
                                               low_x=False, low_lx=False, norm1=False)
        self.share_lx = args.share_lx
        # Per-order scaling coefficients theta_k, used when W_lx is shared
        self.thetas = nn.Parameter(torch.ones(args.k + 1), requires_grad=True)
        self.lin_lxs = nn.ModuleList()
        for i in range(self.k + 1):
            self.lin_lxs.append(nn.Linear(self.nlx, args.nhid, bias=True))
        self.lin_x = nn.Linear(self.nx, args.nhid, bias=True)
        self.lin_lx = nn.Linear(self.nlx, args.nhid, bias=True)
        self.lin_l = nn.Linear(self.nl, args.nhid, bias=True)
        self.cls = nn.Linear(args.nhid, nclass, bias=True)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x_basis = self.basis_generator.get_x_basis(x)
        lx_basis = self.basis_generator.get_lx_basis(x, edge_index)
        l_basis = self.basis_generator.get_l_basis(edge_index, x.shape[0])
        feature_mat = 0
        if self.nx > 0:
            x_mat = self.lin_x(x_basis)
            feature_mat += x_mat
        if self.nlx > 0:
            lxs_mat = 0
            for k in range(self.k + 1):
                if self.share_lx:
                    lx_mat = self.lin_lx(lx_basis[k]) * self.thetas[k]  # share W_lx across all orders, scaled by theta_k
                else:
                    lx_mat = self.lin_lxs[k](lx_basis[k])  # a separate W_lx per order
                lxs_mat = lxs_mat + lx_mat
            feature_mat += lxs_mat
        if self.nl > 0:
            l_mat = self.lin_l(l_basis)
            feature_mat += l_mat
        output = self.cls(feature_mat)
        return F.log_softmax(output, dim=1)
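

if __name__ == '__main__':
    # Minimal smoke test (a sketch; the hyperparameter values are illustrative,
    # not settings shipped with this repo): build a toy 4-node path graph and
    # run one forward pass.
    from types import SimpleNamespace
    from torch_geometric.data import Data

    args = SimpleNamespace(nx=-1, nlx=-1, nl=2, k=2, operator='gcn',
                           share_lx=True, nhid=16)
    x = torch.randn(4, 5)                            # 4 nodes, 5 features each
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                               [1, 0, 2, 1, 3, 2]])  # undirected path 0-1-2-3
    data = Data(x=x, edge_index=edge_index)

    model = FE_GNN(args, ninput=x.shape[1], nclass=3)
    out = model(data)
    print(out.shape)                                 # expected: torch.Size([4, 3])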