-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate.py
339 lines (297 loc) · 12.1 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
import argparse
import os
import re
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torchvision.utils as tvu
import tqdm
import yaml
from datasets import inverse_data_transform
from models.diffusion import Model
from pytorch_fid import fid_score
from torch.multiprocessing import Process
# Module-wide CUDA device. With no index, torch.device("cuda") denotes the
# *current* CUDA device, which init_processes selects for each worker via
# torch.cuda.set_device(args.local_rank).
device = torch.device("cuda")
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """Build the beta (per-step noise variance) schedule of a diffusion process.

    Args:
        beta_schedule: Schedule name: "quad", "linear", "const", "jsd" or "sigmoid".
        beta_start: Beta at the first step (ignored by "const" and "jsd").
        beta_end: Beta at the last step (ignored by "jsd").
        num_diffusion_timesteps: Number of steps T in the schedule.

    Returns:
        NumPy array of shape (T,).

    Raises:
        NotImplementedError: If ``beta_schedule`` is not one of the names above.
    """
    steps = num_diffusion_timesteps

    def quad():
        # Evenly spaced in sqrt(beta), then squared.
        return np.linspace(beta_start**0.5, beta_end**0.5, steps, dtype=np.float64) ** 2

    def linear():
        return np.linspace(beta_start, beta_end, steps, dtype=np.float64)

    def const():
        return beta_end * np.ones(steps, dtype=np.float64)

    def jsd():
        # 1/T, 1/(T-1), 1/(T-2), ..., 1
        return 1.0 / np.linspace(steps, 1, steps, dtype=np.float64)

    def sigmoid():
        # Logistic curve sampled on [-6, 6], rescaled into [beta_start, beta_end].
        grid = np.linspace(-6, 6, steps)
        return (1 / (np.exp(-grid) + 1)) * (beta_end - beta_start) + beta_start

    builders = (
        ("quad", quad),
        ("linear", linear),
        ("const", const),
        ("jsd", jsd),
        ("sigmoid", sigmoid),
    )
    build = next((fn for name, fn in builders if beta_schedule == name), None)
    if build is None:
        raise NotImplementedError(beta_schedule)

    betas = build()
    assert betas.shape == (num_diffusion_timesteps,)
    return betas
def _timestep_sequence(skip_type, num_timesteps, timesteps):
    """Return the increasing timestep indices the sampler visits.

    Args:
        skip_type: "uniform" (every k-th step) or "quad" (quadratically spaced
            over the first 80% of the schedule).
        num_timesteps: Length of the beta schedule.
        timesteps: Number of sampling steps requested; must be >= 1.

    Raises:
        ValueError: If ``timesteps`` < 1.
        NotImplementedError: For an unknown ``skip_type``.
    """
    if timesteps < 1:
        raise ValueError(f"timesteps must be >= 1, got {timesteps}")
    if skip_type == "uniform":
        # Clamp to 1: requesting more steps than the schedule has would otherwise
        # give a zero stride and make range() fail; visit every step instead.
        skip = max(1, num_timesteps // timesteps)
        return range(0, num_timesteps, skip)
    if skip_type == "quad":
        seq = np.linspace(0, np.sqrt(num_timesteps * 0.8), timesteps) ** 2
        return [int(s) for s in seq]
    raise NotImplementedError(f"Unknown skip_type: {skip_type!r}")


def sample_image(args, config, x, model, randn_like=torch.randn_like, last=True):
    """Run the reverse-diffusion sampler starting from noise ``x``.

    Args:
        args: CLI namespace; reads ``sample_type`` ("generalized" or
            "ddpm_noisy"), ``skip_type``, ``timesteps`` and, for the
            "generalized" sampler, ``eta``.
        config: Experiment config; reads the ``config.diffusion`` beta-schedule fields.
        x: Initial noise batch.
        model: Denoising network handed to the step function.
        randn_like: Noise source forwarded to ``generalized_steps`` (e.g. a
            per-seed ``StackedRandomGenerator.randn_like``).
        last: If true, return ``x[0][-1]`` of the step function's result —
            presumably the final sample from its list of intermediates (TODO
            confirm against functions.denoising).

    Raises:
        NotImplementedError: For an unknown ``sample_type`` or ``skip_type``.
        ValueError: If ``args.timesteps`` < 1.
    """
    if args.sample_type not in ("generalized", "ddpm_noisy"):
        raise NotImplementedError(f"Unknown sample_type: {args.sample_type!r}")

    betas = get_beta_schedule(
        beta_schedule=config.diffusion.beta_schedule,
        beta_start=config.diffusion.beta_start,
        beta_end=config.diffusion.beta_end,
        num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps,
    )
    betas = torch.from_numpy(betas).float().to(device)
    num_timesteps = betas.shape[0]
    seq = _timestep_sequence(args.skip_type, num_timesteps, args.timesteps)

    if args.sample_type == "generalized":
        from functions.denoising import generalized_steps

        x = generalized_steps(x, seq, model, betas, randn_like, eta=args.eta)
    else:
        from functions.denoising import ddpm_steps

        # NOTE(review): randn_like is not forwarded here, so the per-seed
        # generators presumably do not control the noise this sampler adds —
        # confirm ddpm_steps' signature before relying on seed reproducibility.
        x = ddpm_steps(x, seq, model, betas)

    if last:
        x = x[0][-1]
    return x
class StackedRandomGenerator:
    """A batch of independent ``torch.Generator`` objects, one per seed.

    Every draw is stacked along dim 0, and row ``i`` comes only from the
    generator seeded with ``seeds[i]``, so each row depends only on its own seed.

    Args:
        device: Device the generators are created on.
        seeds: Integer-convertible seeds; each is reduced modulo 2**32.
    """

    def __init__(self, device, seeds):
        super().__init__()
        self.generators = []
        for seed in seeds:
            generator = torch.Generator(device)
            generator.manual_seed(int(seed) % (1 << 32))
            self.generators.append(generator)

    def _check_batch(self, size):
        # Dim 0 of the requested shape must equal the number of generators.
        assert size[0] == len(self.generators)

    def randn(self, size, **kwargs):
        """Return standard-normal noise of shape ``size``; row i uses generator i."""
        self._check_batch(size)
        row_shape = size[1:]
        rows = []
        for generator in self.generators:
            rows.append(torch.randn(row_shape, generator=generator, **kwargs))
        return torch.stack(rows)

    def randn_like(self, input):
        """Per-row counterpart of ``torch.randn_like`` (same shape, dtype, layout, device)."""
        like = {"dtype": input.dtype, "layout": input.layout, "device": input.device}
        return self.randn(input.shape, **like)

    def randint(self, *args, size, **kwargs):
        """Per-row ``torch.randint``; positional args (e.g. low/high) pass through unchanged."""
        self._check_batch(size)
        row_shape = size[1:]
        rows = []
        for generator in self.generators:
            rows.append(torch.randint(*args, size=row_shape, generator=generator, **kwargs))
        return torch.stack(rows)
def parse_int_list(s):
    """Parse a comma-separated list of ints and inclusive ranges, e.g. "1,2,5-10".

    A list is returned unchanged so the value can also be supplied pre-parsed.
    Whitespace around items and around the dash is ignored, so "0-3, 8" works.

    Args:
        s: The string to parse, or an already-parsed list.

    Returns:
        The expanded list of ints, in input order.

    Raises:
        ValueError: If an item is neither an int nor an ``a-b`` range, or if a
            range is reversed (``b < a``). argparse reports this as an invalid
            argument value instead of silently producing no seeds.
    """
    if isinstance(s, list):
        return s
    range_re = re.compile(r"^(\d+)\s*-\s*(\d+)$")
    values = []
    for part in s.split(","):
        part = part.strip()
        m = range_re.match(part)
        if m:
            lo, hi = int(m.group(1)), int(m.group(2))
            if hi < lo:
                raise ValueError(f"reversed range {part!r}: {hi} < {lo}")
            values.extend(range(lo, hi + 1))
        else:
            values.append(int(part))
    return values
def init_processes(rank, size, fn, args, config):
    """Initialize the distributed environment, run ``fn`` on this worker, then tear down.

    Args:
        rank: Global rank of this process.
        size: World size passed to ``init_process_group``.
        fn: Worker entry point, called as ``fn(rank, gpu, args, config)``.
        args: CLI namespace; reads ``master_address`` and ``local_rank``.
        config: Experiment config forwarded to ``fn``.
    """
    gpu = args.local_rank
    os.environ.update({"MASTER_ADDR": args.master_address, "MASTER_PORT": "6020"})
    torch.cuda.set_device(gpu)
    dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=size)
    fn(rank, gpu, args, config)
    # Wait for every rank to finish before destroying the group.
    dist.barrier()
    cleanup()
def cleanup():
    """Destroy the default ``torch.distributed`` process group (the one init_processes creates)."""
    return dist.destroy_process_group()
# Reference Inception statistics for FID, keyed by ``config.data.dataset``.
# LSUN is looked up separately because its file also depends on ``config.data.category``.
_FID_STATS = {
    "CIFAR10": "pytorch_fid/cifar10_train_stat.npy",
    "CELEBA": "pytorch_fid/fid_stats_celeba64.npz",
    "CELEBAHQ": "pytorch_fid/celebahq_stat.npy",
    "FFHQ64": "pytorch_fid/ffhq-64x64.npz",
    "AFHQ": "pytorch_fid/afhqv2-64x64.npz",
}
_LSUN_FID_STATS = {"church_outdoor": "pytorch_fid/lsun_church_stat.npy"}


def _fid_stats_path(data_config):
    """Return the FID reference-statistics file for ``data_config``, or None if there is none."""
    if data_config.dataset == "LSUN":
        return _LSUN_FID_STATS.get(getattr(data_config, "category", None))
    return _FID_STATS.get(data_config.dataset)


def _report_fid(args, config):
    """Compute FID of ``args.image_folder`` and append it to ``args.fid_log``.

    Datasets without reference statistics are skipped with a message; the
    previous if/elif chain left ``fid_value`` unbound and crashed here.
    """
    stats_path = _fid_stats_path(config.data)
    if stats_path is None:
        print(f"No FID reference statistics for dataset {config.data.dataset!r}; skipping FID.")
        return
    # NOTE(review): positional args presumably mean batch_size=50, device="cuda",
    # dims=2048 as in upstream pytorch-fid; confirm against the bundled pytorch_fid.
    fid_value = fid_score.calculate_fid_given_paths([args.image_folder, stats_path], 50, "cuda", 2048)
    # The last two path components identify the run in the log.
    path_name = "/".join(args.image_folder.split("/")[-2:])
    with open(args.fid_log, "a") as f:
        f.write(f"Checkpoint {path_name} --> FID {fid_value}\n")
    print(f"Checkpoint {path_name} --> FID {fid_value}\n")


def sample(rank, gpu, args, config):
    """Per-rank worker: generate one PNG per seed, then report FID on rank 0.

    Splits ``args.seeds`` into batches shared round-robin across ranks. Loads
    ``model_{args.ckpt_id}_ema.pth`` from ``args.log_path`` into a DDP-wrapped
    ``Model``. Writes each image as ``{seed:06d}.png`` into ``args.image_folder``.

    Args:
        rank: Global rank of this process.
        gpu: Local CUDA device index used for DDP.
        args: CLI namespace.
        config: Experiment config namespace.
    """

    def broadcast_params(params):
        # Copy rank 0's parameter values to every other rank.
        for param in params:
            dist.broadcast(param.data, src=0)

    seeds = args.seeds
    world_size = dist.get_world_size()
    # Round the batch count up to a multiple of world_size so every rank gets the
    # same number of (possibly empty) batches and the per-batch barrier stays in lockstep.
    num_batches = ((len(seeds) - 1) // (config.sampling.batch_size * world_size) + 1) * world_size
    all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
    rank_batches = all_batches[dist.get_rank() :: world_size]

    # Load network.
    model = Model(config)
    if rank == 0:
        print(f"Loading checkpoint from model_{args.ckpt_id}_ema.pth")
    states = torch.load(
        os.path.join(args.log_path, f"model_{args.ckpt_id}_ema.pth"),
        map_location=device,
    )
    model = model.to(device)
    broadcast_params(model.parameters())
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu], find_unused_parameters=True)
    # Loaded into the DDP wrapper, so the checkpoint keys presumably carry the
    # "module." prefix — TODO confirm against the training script.
    model.load_state_dict(states, strict=True)
    model.eval()

    # Loop over batches.
    if rank == 0:
        print(f'Generating {len(seeds)} images to "{args.image_folder}"...')
    os.makedirs(args.image_folder, exist_ok=True)
    for batch_seeds in tqdm.tqdm(rank_batches, unit="batch", disable=(dist.get_rank() != 0)):
        dist.barrier()
        batch_size = len(batch_seeds)
        if batch_size == 0:
            continue

        # Per-seed generators make each image reproducible from its seed alone.
        rnd = StackedRandomGenerator(device, batch_seeds)
        latents = rnd.randn(
            [batch_size, config.data.channels, config.data.image_size, config.data.image_size], device=device
        )
        images = sample_image(args, config, latents, model, randn_like=rnd.randn_like)

        # Save images.
        images = inverse_data_transform(config, images)
        for seed, image in zip(batch_seeds, images):
            tvu.save_image(image, os.path.join(args.image_folder, f"{seed:06d}.png"))

    # Done.
    dist.barrier()
    if rank == 0:
        print("Done.")
        _report_fid(args, config)
if __name__ == "__main__":
    parser = argparse.ArgumentParser("dpm parameters")
    # ddp
    parser.add_argument("--num_proc_node", type=int, default=1, help="The number of nodes in multi node env.")
    parser.add_argument("--num_process_per_node", type=int, default=1, help="number of gpus")
    parser.add_argument("--node_rank", type=int, default=0, help="The index of node.")
    parser.add_argument("--local_rank", type=int, default=0, help="rank of process in the node")
    parser.add_argument("--master_address", type=str, default="127.0.0.1", help="address for master")
    # diffusion
    parser.add_argument("--config", type=str, required=True, help="Path to the config file")
    parser.add_argument("--exp", type=str, default="exp", help="Path for saving running related data.")
    parser.add_argument(
        "--doc",
        type=str,
        required=True,
        help="A string for documentation purpose. " "Will be the name of the log folder.",
    )
    parser.add_argument("--comment", type=str, default="", help="A string for experiment comment")
    parser.add_argument(
        "--verbose",
        type=str,
        default="info",
        help="Verbose level: info | debug | warning | critical",
    )
    parser.add_argument(
        "-i",
        "--image_folder",
        type=str,
        default="images",
        help="The folder name of samples",
    )
    parser.add_argument(
        "--ni",
        action="store_true",
        help="No interaction. Suitable for Slurm Job launcher",
    )
    parser.add_argument(
        "--sample_type",
        type=str,
        default="generalized",
        help="sampling approach (generalized or ddpm_noisy)",
    )
    parser.add_argument(
        "--skip_type",
        type=str,
        default="uniform",
        # sample_image accepts exactly "uniform" or "quad".
        help="skip according to (uniform or quad)",
    )
    parser.add_argument("--timesteps", type=int, default=1000, help="number of steps involved")
    parser.add_argument(
        "--eta",
        type=float,
        default=0.0,
        help="eta used to control the variances of sigma",
    )
    parser.add_argument(
        "--fid_log",
        type=str,
        default="fid.txt",
        help="File to log FID",
    )
    parser.add_argument("--ckpt_id", type=int, default=500000, help="ckpt id")
    parser.add_argument("--num_samples", type=int, default=50000, help="Number of generated samples")
    parser.add_argument("--model_ema", action="store_true")
    parser.add_argument(
        "--seeds", help="Random seeds (e.g. 1,2,5-10)", metavar="LIST", type=parse_int_list, default="0-63"
    )
    args = parser.parse_args()
    args.log_path = os.path.join(args.exp, "logs", args.doc)
    args.image_folder = os.path.join(args.exp, "image_samples", args.image_folder)

    with open(os.path.join("configs", args.config), "r") as f:
        config = yaml.safe_load(f)

    def dict2namespace(config):
        """Recursively convert nested dicts into argparse.Namespace for attribute access."""
        namespace = argparse.Namespace()
        for key, value in config.items():
            setattr(namespace, key, dict2namespace(value) if isinstance(value, dict) else value)
        return namespace

    config = dict2namespace(config)

    args.world_size = args.num_proc_node * args.num_process_per_node
    size = args.num_process_per_node
    if size > 1:
        processes = []
        for rank in range(size):
            args.local_rank = rank
            global_rank = rank + args.node_rank * args.num_process_per_node
            args.global_rank = global_rank
            print("Node rank %d, local proc %d, global proc %d" % (args.node_rank, rank, global_rank))
            p = Process(target=init_processes, args=(global_rank, args.world_size, sample, args, config))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
        # Surface worker crashes to the shell / job scheduler instead of exiting 0.
        failed = [p.exitcode for p in processes if p.exitcode != 0]
        if failed:
            raise SystemExit(f"{len(failed)} sampling worker(s) failed with exit codes {failed}")
    else:
        print("starting in debug mode")
        # One process on this node: its global rank is the node index and the world
        # spans all nodes. On a single node both reduce to rank 0, world size 1.
        args.global_rank = args.node_rank
        init_processes(args.node_rank, args.world_size, sample, args, config)
# ----------------------------------------------------------------------------