forked from jaryP/MMD-Bayesian-Neural-Network
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbayesian_utils.py
187 lines (147 loc) · 6.03 KB
/
bayesian_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import numpy as np
from torch import nn as nn
from torch.distributions import Normal
from torch.nn import functional as F, init
import torch
class BayesianParameters(nn.Module):
    """Learnable Gaussian variational posterior over a single weight/bias tensor.

    Holds a mean ``mu`` (shape ``size``) and an unconstrained scale parameter
    ``rho`` whose shape depends on ``posterior_type``:

    * ``'weights'``: one ``rho`` per element;
    * ``'layers'``: a single ``rho`` of shape ``(1,)``;
    * ``'neurons'``: one ``rho`` per entry of the first dimension (all other
      dimensions set to 1); for biases, one ``rho`` per element;
    * any other value: one ``rho`` per element.

    The standard deviation is ``softplus(rho)`` for ``'weights'`` and
    ``softplus(rho) * mu ** 2`` for every other type (see ``sigma``).

    Args:
        size: shape of the parameter tensor (passed to ``torch.empty`` /
            ``torch.zeros``).
        mu_initialization: ``None`` for the default init, or a dict with
            ``'type'`` in ``{'uniform', 'gaussian', 'constant'}`` and the keys
            ``a, b`` / ``mu, sigma`` / ``c`` respectively.
        rho_initialization: same format as ``mu_initialization``; ``None``
            draws ``rho`` from a standard normal.
        posterior_type: see above.
        is_bias: use the bias default init for ``mu`` (uniform in
            ``[-1/sqrt(size), 1/sqrt(size)]``) instead of Kaiming-uniform.

    Raises:
        ValueError: if an initialization dict has an unknown ``'type'``.
    """

    def __init__(self, size, mu_initialization=None, rho_initialization=None, posterior_type='weights', is_bias=False):
        super().__init__()
        self.posterior_type = posterior_type
        # Fraction in [0, 1] of lowest signal-to-noise entries zeroed by `weights`;
        # None disables masking (see set_mask).
        self.mask = None

        if mu_initialization is None:
            mu = torch.empty(size)
            if is_bias:
                bound = 1 / np.sqrt(size)
                init.uniform_(mu, -bound, bound)
            else:
                # Mirrors the kaiming_uniform_(a=sqrt(5)) default of torch.nn.Linear weights.
                init.kaiming_uniform_(mu, a=np.sqrt(5))
        else:
            mu = self._tensor_from_spec(size, mu_initialization, 'mu')
        self.mu = nn.Parameter(mu, requires_grad=True)

        if posterior_type == 'layers':
            rho_size = (1,)
        elif posterior_type == 'neurons' and not is_bias:
            # Keep the first dimension, collapse the rest so rho broadcasts per neuron.
            rho_size = list(size)
            rho_size[1:] = [1] * (len(rho_size) - 1)
        else:
            rho_size = size

        if rho_initialization is None:
            rho = torch.randn(rho_size)
        else:
            rho = self._tensor_from_spec(rho_size, rho_initialization, 'rho')
        self.rho = nn.Parameter(rho, requires_grad=True)

    @staticmethod
    def _tensor_from_spec(size, spec, name):
        """Build a tensor of shape ``size`` from an initialization dict.

        ``name`` ('mu' or 'rho') is used only in the error message.

        Raises:
            ValueError: if ``spec['type']`` is not uniform/gaussian/constant.
        """
        kind = spec['type']
        if kind == 'uniform':
            return torch.zeros(size).uniform_(spec['a'], spec['b'])
        if kind == 'gaussian':
            return torch.zeros(size).normal_(spec['mu'], spec['sigma'])
        if kind == 'constant':
            return torch.ones(size) * spec['c']
        raise ValueError("Possible initialization for {} parameter: \n"
                         "-gaussian {{mu, sigma}}\n"
                         "-uniform {{a, b}}\n"
                         "-constant {{c}}. \n {} was given".format(name, kind))

    def set_mask(self, p):
        """Set the fraction of entries to zero by signal-to-noise ratio.

        Args:
            p: value in [0, 1], or None to disable masking.

        Raises:
            ValueError: if ``p`` is not in [0, 1] (NaN included).
        """
        if p is not None and not 0 <= p <= 1:
            raise ValueError('Mask percentile should be between 0 and 1, {} was given'.format(p))
        self.mask = p

    @property
    def weights(self):
        """Draw one reparameterised sample ``mu + sigma * eps``, ``eps ~ N(0, 1)``.

        If a mask fraction is set, entries whose signal-to-noise ratio
        ``|mu| / sigma`` is below that percentile are zeroed.
        """
        sigma = self.sigma
        # The noise is a constant w.r.t. the parameters: create it directly on
        # mu's device/dtype and without requires_grad; gradients flow via mu and sigma.
        r = self.mu + sigma * torch.randn_like(self.mu)
        if self.mask is not None:
            snr = torch.abs(self.mu) / sigma
            # detach(): NumPy conversion is refused for tensors that require grad.
            threshold = float(np.percentile(snr.detach().cpu().numpy(), self.mask * 100))
            mask = torch.ones_like(snr)
            mask[snr < threshold] = 0
            r = r * mask
        return r

    @property
    def sigma(self):
        """Posterior standard deviation.

        ``softplus(rho)`` when ``posterior_type == 'weights'``; otherwise
        ``softplus(rho) * mu ** 2`` (broadcast to ``mu``'s shape).
        """
        scale = F.softplus(self.rho)
        if self.posterior_type == 'weights':
            return scale
        return torch.mul(scale, self.mu.pow(2))

    def posterior_distribution(self):
        """Return ``Normal(loc=detached copy of mu, scale=softplus(rho))``.

        NOTE(review): the scale is ``softplus(rho)`` for every posterior_type,
        which differs from ``sigma`` for non-'weights' posteriors, and only
        ``loc`` is detached — confirm both are intended.
        """
        # F.softplus is the numerically stable form of log(1 + exp(rho));
        # the naive expression overflows to inf for large rho.
        return Normal(self.mu.data.clone(), F.softplus(self.rho))

    def posterior_log_prob(self, w):
        """Element-wise log-density of ``w`` under ``posterior_distribution()``."""
        return self.posterior_distribution().log_prob(w)

    def forward(self, input, sample=1):
        """Does nothing and returns None; samples are drawn via ``weights``."""
        pass
def b_drop(x, p=0.5):
    """Apply dropout with drop probability ``p``, always active.

    ``training`` is forced to True, so a fresh mask is drawn on every call
    regardless of any module's train/eval mode. Returns a new tensor; ``x``
    itself is not modified.
    """
    always_on, in_place = True, False
    return F.dropout(x, p, always_on, in_place)
def pdist(p, q):
    """Pairwise Euclidean distances between the rows of ``p`` and ``q``.

    Entry ``[i, j]`` is ``sqrt(1e-5 + | ||p_i||^2 + ||q_j||^2 - 2 <p_i, q_j> |)``.
    The ``abs`` absorbs small negative values from floating-point cancellation;
    the ``1e-5`` offset presumably keeps ``sqrt`` (and its gradient) away from 0.
    As a consequence, identical rows have distance ``sqrt(1e-5)``, not 0.

    Args:
        p: 2-D tensor of shape (P, D).
        q: 2-D tensor of shape (Q, D).

    Returns:
        Tensor of shape (P, Q).
    """
    rows, cols = p.size(0), q.size(0)
    p_sq = (p ** 2).sum(dim=1, keepdim=True).expand(rows, cols)
    q_sq = (q ** 2).sum(dim=1, keepdim=True).t().expand(rows, cols)
    gram = p @ q.t()
    sq_dist = p_sq + q_sq - 2 * gram
    return (torch.abs(sq_dist) + 1e-5).sqrt()
def pairwise_distances(x, y):
    """Squared Euclidean distances between the rows of ``x`` and ``y``.

    Uses ``||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>`` and clamps negative results
    (floating-point cancellation) to 0.

    Args:
        x: 2-D tensor of shape (N, D).
        y: 2-D tensor of shape (M, D).

    Returns:
        Tensor of shape (N, M).
    """
    sq_x = (x ** 2).sum(1).unsqueeze(1)
    sq_y = (y ** 2).sum(1).unsqueeze(0)
    cross = x.mm(y.t())
    return (sq_x + sq_y - 2.0 * cross).clamp(min=0.0)
def compute_mmd(x, y, type='inverse', biased=True, space=None, max=False):
    """Estimate the squared Maximum Mean Discrepancy (MMD^2) between two samples.

    The kernel is a sum of kernels over the bandwidths in ``space``:

    * ``type='rbf'``: ``sum_s exp(-d2 / (2 * s**2))``; default
      ``space = [0.5, 1, 2, 4, 8, 16]``;
    * ``type='inverse'``: ``sum_a 1 / sqrt(a**2 + d2)``; default
      ``space = [0.05, 0.2, 0.6, 0.9, 1]``;

    where ``d2 = pdist(., .) ** 2`` (``pdist`` adds a ``1e-5`` offset).

    Args:
        x: samples from the first distribution, one per row (assumed 2-D,
            shape (m, d), as required by ``pdist``).
        y: samples from the second distribution, shape (n, d); ``n`` may
            differ from ``m``.
        type: ``'rbf'`` or ``'inverse'``; any other value returns ``None``.
        biased: True for the biased (V-statistic) estimate over all kernel
            entries; False for the unbiased U-statistic, which excludes the
            diagonals of the within-sample kernel matrices.
        space: iterable of kernel bandwidths, or None for the defaults above.
        max: unused; kept for backward compatibility.

    Returns:
        Scalar tensor ``max(MMD^2 estimate, 0)``, or ``None`` for an unknown
        ``type``.

    Raises:
        ValueError: if ``biased`` is False and ``x`` or ``y`` has fewer than
            two rows (the unbiased estimate is undefined).
    """
    # Pick the kernel before doing any distance computation.
    if type == 'rbf':
        if space is None:
            space = [0.5, 1, 2, 4, 8, 16]

        def kernel(d2, s):
            gamma = 1.0 / (2 * s ** 2)
            return torch.exp(-d2 * gamma)
    elif type == 'inverse':
        if space is None:
            space = [0.05, 0.2, 0.6, 0.9, 1]

        def kernel(d2, s):
            return torch.div(1, torch.sqrt(s ** 2 + d2))
    else:
        return None

    m, n = x.shape[0], y.shape[0]
    if not biased and (m < 2 or n < 2):
        raise ValueError('The unbiased MMD estimate needs at least two samples in x and y, '
                         'got {} and {}'.format(m, n))

    xxd = pdist(x, x) ** 2
    yyd = pdist(y, y) ** 2
    xyd = pdist(x, y) ** 2

    # Kernel matrices summed over all bandwidths: shapes (m, m), (n, n), (m, n).
    XX = torch.zeros_like(xxd)
    YY = torch.zeros_like(yyd)
    XY = torch.zeros_like(xyd)
    for s in space:
        XX += kernel(xxd, s)
        YY += kernel(yyd, s)
        XY += kernel(xyd, s)

    if biased:
        mmd = XX.mean() + YY.mean() - 2 * XY.mean()
    else:
        # U-statistic: within-sample sums exclude i == j, so they average
        # over m(m-1) and n(n-1) pairs, not m^2 / n^2.
        xx_term = (XX.sum() - XX.trace()) / (m * (m - 1))
        yy_term = (YY.sum() - YY.trace()) / (n * (n - 1))
        xy_term = XY.sum() / (m * n)
        mmd = xx_term + yy_term - 2 * xy_term
    # Finite-sample estimates (in particular the unbiased one) can be negative; clip at 0.
    return F.relu(mmd)