diffusion.py · 163 lines (142 loc) · 6.04 KB
import torch
import torch.nn as nn
import torch.nn.functional as F


def extract(v, t, x_shape):
    """
    Extract some coefficients at specified timesteps, then reshape to
    [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
    """
    out = torch.gather(v, index=t, dim=0).float()
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
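

# Illustrative sketch (not part of the original file): extract() is a batched
# lookup followed by a reshape, so each sample picks the coefficient for its
# own timestep and that scalar broadcasts over the image dimensions.
def _example_extract():
    v = torch.linspace(0., 1., steps=10)    # a [T]-length coefficient schedule
    t = torch.tensor([3, 7])                # one timestep per sample in the batch
    x_shape = (2, 3, 32, 32)                # two RGB 32x32 images
    coef = extract(v, t, x_shape)           # == v[[3, 7]] viewed as [2, 1, 1, 1]
    return coef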


class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()

        self.model = model
        self.T = T

        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

    def forward(self, x_0):
        """
        Algorithm 1: sample a random timestep t per image, diffuse x_0 to x_t
        in closed form, and regress the model output against the added noise.
        """
        t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t = (
            extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
        loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
        return loss
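

# Illustrative usage (a minimal sketch, not part of the original file). It
# assumes `model` is any nn.Module mapping (x_t, t) -> predicted noise with
# the same shape as x_t, `x_0` is a batch of images scaled to [-1, 1], and
# `optimizer` is a standard torch optimizer. The beta_1/beta_T/T defaults are
# the common DDPM settings, not values taken from this file.
def _example_train_step(model, optimizer, x_0,
                        beta_1=1e-4, beta_T=0.02, T=1000):
    trainer = GaussianDiffusionTrainer(model, beta_1, beta_T, T).to(x_0.device)
    loss = trainer(x_0).mean()              # per-element MSE reduced to a scalar
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()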


class GaussianDiffusionSampler(nn.Module):
    def __init__(self, model, beta_1, beta_T, T, img_size=32,
                 mean_type='epsilon', var_type='fixedlarge'):
        assert mean_type in ['xprev', 'xstart', 'epsilon']
        assert var_type in ['fixedlarge', 'fixedsmall']
        super().__init__()

        self.model = model
        self.T = T
        self.img_size = img_size
        self.mean_type = mean_type
        self.var_type = var_type

        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            'sqrt_recip_alphas_bar', torch.sqrt(1. / alphas_bar))
        self.register_buffer(
            'sqrt_recipm1_alphas_bar', torch.sqrt(1. / alphas_bar - 1))
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.register_buffer(
            'posterior_var',
            self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))
        # below: log calculation clipped because the posterior variance is 0
        # at the beginning of the diffusion chain
        self.register_buffer(
            'posterior_log_var_clipped',
            torch.log(
                torch.cat([self.posterior_var[1:2], self.posterior_var[1:]])))
        self.register_buffer(
            'posterior_mean_coef1',
            torch.sqrt(alphas_bar_prev) * self.betas / (1. - alphas_bar))
        self.register_buffer(
            'posterior_mean_coef2',
            torch.sqrt(alphas) * (1. - alphas_bar_prev) / (1. - alphas_bar))

    def q_mean_variance(self, x_0, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior
        q(x_{t-1} | x_t, x_0)
        """
        assert x_0.shape == x_t.shape
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_0 +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_log_var_clipped = extract(
            self.posterior_log_var_clipped, t, x_t.shape)
        return posterior_mean, posterior_log_var_clipped

    def predict_xstart_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            extract(self.sqrt_recip_alphas_bar, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_bar, t, x_t.shape) * eps
        )

    def predict_xstart_from_xprev(self, x_t, t, xprev):
        assert x_t.shape == xprev.shape
        return (  # (xprev - coef2*x_t) / coef1
            extract(
                1. / self.posterior_mean_coef1, t, x_t.shape) * xprev -
            extract(
                self.posterior_mean_coef2 / self.posterior_mean_coef1, t,
                x_t.shape) * x_t
        )

    def p_mean_variance(self, x_t, t):
        # below: only log_variance is used in the KL computations
        model_log_var = {
            # for fixedlarge, we set the initial (log-)variance like so to
            # get a better decoder log likelihood
            'fixedlarge': torch.log(torch.cat([self.posterior_var[1:2],
                                               self.betas[1:]])),
            'fixedsmall': self.posterior_log_var_clipped,
        }[self.var_type]
        model_log_var = extract(model_log_var, t, x_t.shape)

        # Mean parameterization
        if self.mean_type == 'xprev':       # the model predicts x_{t-1}
            x_prev = self.model(x_t, t)
            x_0 = self.predict_xstart_from_xprev(x_t, t, xprev=x_prev)
            model_mean = x_prev
        elif self.mean_type == 'xstart':    # the model predicts x_0
            x_0 = self.model(x_t, t)
            model_mean, _ = self.q_mean_variance(x_0, x_t, t)
        elif self.mean_type == 'epsilon':   # the model predicts epsilon
            eps = self.model(x_t, t)
            x_0 = self.predict_xstart_from_eps(x_t, t, eps=eps)
            model_mean, _ = self.q_mean_variance(x_0, x_t, t)
        else:
            raise NotImplementedError(self.mean_type)
        x_0 = torch.clip(x_0, -1., 1.)

        return model_mean, model_log_var

    def forward(self, x_T):
        """
        Algorithm 2: start from pure noise x_T and iterate the reverse chain,
        sampling x_{t-1} from p(x_{t-1} | x_t) down to x_0.
        """
        x_t = x_T
        for time_step in reversed(range(self.T)):
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
            mean, log_var = self.p_mean_variance(x_t=x_t, t=t)
            # no noise when t == 0
            if time_step > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0
            x_t = mean + torch.exp(0.5 * log_var) * noise
        x_0 = x_t
        return torch.clip(x_0, -1, 1)
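

# Illustrative usage (a minimal sketch, not part of the original file). It
# assumes `model` is an nn.Module trained to predict epsilon, and reuses the
# common DDPM settings (beta_1=1e-4, beta_T=0.02, T=1000) rather than values
# taken from this file.
def _example_sample(model, batch_size=16, img_size=32,
                    beta_1=1e-4, beta_T=0.02, T=1000, device='cpu'):
    sampler = GaussianDiffusionSampler(
        model, beta_1, beta_T, T, img_size=img_size,
        mean_type='epsilon', var_type='fixedlarge').to(device)
    x_T = torch.randn(batch_size, 3, img_size, img_size, device=device)
    with torch.no_grad():
        x_0 = sampler(x_T)                  # generated images in [-1, 1]
    return x_0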