# sure_inference.py
import torch
import numpy as np


def pvp_infer(model, device, all_loader):
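    """Inference for the view-alignment (shuffled-pair) setting.

    Embeds both views, then greedily matches every view-0 sample to its
    nearest still-unmatched view-1 sample in embedding space. A match counts
    as correct when the two samples share a class label.

    Returns:
        (aligned view-0 embeddings, aligned view-1 embeddings,
         class labels, alignment accuracy)
    """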
    model.eval()
    align_out0 = []
    align_out1 = []
    class_labels_cluster = []
    len_alldata = len(all_loader.dataset)
    align_labels = torch.zeros(len_alldata)
    with torch.no_grad():
        sample_offset = 0  # global index of the first sample in the current batch
        for x0, x1, labels, class_labels0, class_labels1 in all_loader:
            test_num = len(labels)
            x0, x1, labels = x0.to(device), x1.to(device), labels.to(device)
            x0 = x0.view(x0.size(0), -1)
            x1 = x1.view(x1.size(0), -1)
            h0, h1, _, _ = model(x0, x1)
            C = euclidean_dist(h0, h1)
            for i in range(test_num):
                idx = torch.argsort(C[i, :])
                # mask the matched column so it cannot be selected again (greedy matching)
                C[:, idx[0]] = float("inf")
                align_out0.append(h0[i, :].cpu().numpy())
                align_out1.append(h1[idx[0], :].cpu().numpy())
                if class_labels0[i] == class_labels1[idx[0]]:
                    align_labels[sample_offset + i] = 1
            sample_offset += test_num
            class_labels_cluster.extend(class_labels0.numpy())
    count = torch.sum(align_labels)
    inference_acc = count.item() / len_alldata
    return np.array(align_out0), np.array(align_out1), np.array(class_labels_cluster), inference_acc


def pdp_infer(model, device, all_loader):
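    """Inference for the partially missing-view setting.

    A sample missing in one view is imputed with the mean embedding of the
    k observed samples in that view whose embeddings lie closest to the
    sample's embedding in the observed view.

    Returns:
        (recovered view-0 embeddings, recovered view-1 embeddings, class labels)
    """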
    model.eval()
    recover_out0 = []  # view-0 embeddings after learned neighbor-based filling
    recover_out1 = []
    class_labels = []
    with torch.no_grad():
        k = 3  # number of neighbors averaged when imputing a missing view
        for x0, x1, labels, class_labels0, class_labels1, mask in all_loader:
            test_num = len(labels)
            x0, x1, labels = x0.to(device), x1.to(device), labels.to(device)
            x0 = x0.view(x0.size(0), -1)
            x1 = x1.view(x1.size(0), -1)
            class_labels.extend(labels.cpu().numpy())
            h0, h1, _, _ = model(x0, x1)
            if mask.sum() == test_num:  # complete
                continue
            fill_num = k
            C = euclidean_dist(h0, h1)
            row_idx = C.argsort()      # for each view-0 sample, view-1 candidates by distance
            col_idx = C.t().argsort()  # for each view-1 sample, view-0 candidates by distance
            # M[i, j] flags that the i-th sample is observed in view 0 and the j-th in view 1
            M = torch.logical_and(mask[:, 0].repeat(test_num, 1).t(), mask[:, 1].repeat(test_num, 1))
            for i in range(test_num):
                idx0 = col_idx[i, :][M[col_idx[i, :], i]]  # observed view-0 neighbors, sorted by distance
                idx1 = row_idx[i, :][M[i, row_idx[i, :]]]  # observed view-1 neighbors, sorted by distance
                if len(idx1) != 0 and len(idx0) == 0:  # the i-th sample is missing in view 1
                    # alternative left disabled: softmax-weighted averaging of the neighbors
                    # weight = torch.softmax(h1[idx1[0:fill_num], :], dim=0)
                    # avg_fill = (weight * h1[idx1[0:fill_num], :]).sum(dim=0)
                    avg_fill = h1[idx1[0:fill_num], :].sum(dim=0) / fill_num
                    recover_out0.append(h0[i, :].cpu().numpy())
                    recover_out1.append(avg_fill.cpu().numpy())  # imputed
                elif len(idx0) != 0 and len(idx1) == 0:  # the i-th sample is missing in view 0
                    avg_fill = h0[idx0[0:fill_num], :].sum(dim=0) / fill_num
                    recover_out0.append(avg_fill.cpu().numpy())  # imputed
                    recover_out1.append(h1[i, :].cpu().numpy())
                elif len(idx0) != 0 and len(idx1) != 0:  # complete pair
                    recover_out0.append(h0[i, :].cpu().numpy())
                    recover_out1.append(h1[i, :].cpu().numpy())
                else:
                    raise RuntimeError(f'sample {i} is unobserved in both views')
    return np.array(recover_out0), np.array(recover_out1), np.array(class_labels)


def both_infer(model, device, all_loader, setting):
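    """Inference for data that is both partially missing and partially unaligned.

    setting == 0: re-establish cross-view correspondence only;
    setting == 1: impute missing views only;
    any other value: impute first, then re-align.

    Returns:
        (view-0 embeddings, view-1 embeddings, class labels)
    """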
    model.eval()
    align_out0 = []
    align_out1 = []
    class_labels = []
    with torch.no_grad():
        cnt = 0          # imputing neighbors that share the missing sample's class (diagnostic only)
        k = 3            # number of neighbors averaged when imputing a missing view
        missing_cnt = 0  # number of imputed samples (diagnostic only)
        for x0, x1, labels, class_labels0, class_labels1, mask in all_loader:
            test_num = len(labels)
            x0, x1, labels = x0.to(device), x1.to(device), labels.to(device)
            x0 = x0.view(x0.size(0), -1)
            x1 = x1.view(x1.size(0), -1)
            class_labels.extend(labels.cpu().numpy())
            h0, h1, _, _ = model(x0, x1)
            # impute missing samples
            if setting != 0:
                recover_out0 = torch.empty_like(h0)
                recover_out1 = torch.empty_like(h1)
                fill_num = k
                C = euclidean_dist(h0, h1)
                row_idx = C.argsort()      # for each view-0 sample, view-1 candidates by distance
                col_idx = C.t().argsort()  # for each view-1 sample, view-0 candidates by distance
                # M[i, j] flags that the i-th sample is observed in view 0 and the j-th in view 1
                M = torch.logical_and(mask[:, 0].repeat(test_num, 1).t(), mask[:, 1].repeat(test_num, 1))
                for i in range(test_num):
                    idx0 = col_idx[i, :][M[col_idx[i, :], i]]  # observed view-0 neighbors, sorted by distance
                    idx1 = row_idx[i, :][M[i, row_idx[i, :]]]  # observed view-1 neighbors, sorted by distance
                    if len(idx1) != 0 and len(idx0) == 0:  # the i-th sample is missing in view 1
                        avg_fill = h1[idx1[0:fill_num], :].sum(dim=0) / fill_num
                        cnt += (class_labels1[idx1[0:fill_num]] == class_labels1[i]).sum()
                        missing_cnt += 1
                        recover_out0[i, :] = h0[i, :]
                        recover_out1[i, :] = avg_fill  # imputed
                    elif len(idx0) != 0 and len(idx1) == 0:  # the i-th sample is missing in view 0
                        avg_fill = h0[idx0[0:fill_num], :].sum(dim=0) / fill_num
                        cnt += (class_labels0[idx0[0:fill_num]] == class_labels0[i]).sum()
                        missing_cnt += 1
                        recover_out0[i, :] = avg_fill  # imputed
                        recover_out1[i, :] = h1[i, :]
                    elif len(idx0) != 0 and len(idx1) != 0:  # complete pair
                        recover_out0[i, :] = h0[i, :]
                        recover_out1[i, :] = h1[i, :]
                    else:
                        raise RuntimeError(f'sample {i} is unobserved in both views')
            if setting == 1:
                align_out0.extend(recover_out0.cpu().numpy())
                align_out1.extend(recover_out1.cpu().numpy())
                continue
            # re-establish the correspondence across views
            if setting != 1:
                if setting == 0:
                    recover_out0 = h0
                    recover_out1 = h1
                C = euclidean_dist(recover_out0, recover_out1)
                for i in range(test_num):
                    idx = torch.argsort(C[i, :])
                    # mask the matched column so it cannot be selected again (greedy matching)
                    C[:, idx[0]] = float("inf")
                    align_out0.append(recover_out0[i, :].cpu().numpy())
                    align_out1.append(recover_out1[idx[0], :].cpu().numpy())
    return np.array(align_out0), np.array(align_out1), np.array(class_labels)


def euclidean_dist(x, y):
    """
    Args:
        x: torch.Tensor with shape [m, d]
        y: torch.Tensor with shape [n, d]
    Returns:
        dist: torch.Tensor with shape [m, n], pairwise Euclidean distances
    """
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy
    # dist = ||x||^2 + ||y||^2 - 2 * x @ y.T, written with the keyword form of
    # addmm_ (the positional addmm_(beta, alpha, mat1, mat2) form is deprecated)
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
    dist = dist.clamp(min=1e-12).sqrt()  # clamp for numerical stability
    return dist
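

# A minimal, self-contained sanity check that runs without a trained model.
# It only verifies euclidean_dist against torch.cdist on random tensors; the
# commented-out call underneath is a hypothetical sketch of how the inference
# helpers above would be invoked (the names `model` and `all_loader` are
# placeholders, not objects defined in this file).
if __name__ == "__main__":
    torch.manual_seed(0)
    a = torch.randn(5, 8)
    b = torch.randn(7, 8)
    d = euclidean_dist(a, b)
    assert d.shape == (5, 7)
    assert torch.allclose(d, torch.cdist(a, b), atol=1e-5)
    print("euclidean_dist matches torch.cdist")
    # out0, out1, cls, acc = pvp_infer(model, torch.device("cuda"), all_loader)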