-
Notifications
You must be signed in to change notification settings - Fork 1
/
helper.py
93 lines (73 loc) · 3.19 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# helper functions for saving sample data and models
# import data loading libraries
import os
import pdb
import pickle
import argparse
import warnings
warnings.filterwarnings("ignore")
# import torch
import torch
# numpy & scipy imports
import numpy as np
import scipy
import scipy.misc
import matplotlib.pyplot as plt
def checkpoint(iteration, G_XtoY, G_YtoX, D_X, D_Y, checkpoint_dir='checkpoints_cyclegan'):
    """Save the state_dicts of both generators and both discriminators.

    Each model's ``state_dict()`` is written with ``torch.save`` to
    ``<checkpoint_dir>/<name>.pkl`` (``G_XtoY.pkl``, ``G_YtoX.pkl``,
    ``D_X.pkl``, ``D_Y.pkl``), overwriting any previous files. The directory
    is created if it does not exist yet.

    Args:
        iteration: Not used by this function; kept for call-site compatibility.
        G_XtoY: Generator mapping domain X to Y.
        G_YtoX: Generator mapping domain Y to X.
        D_X: Discriminator for domain X.
        D_Y: Discriminator for domain Y.
        checkpoint_dir: Directory the four files are written into.
    """
    # torch.save does not create parent directories; without this a fresh
    # run fails on the very first checkpoint.
    os.makedirs(checkpoint_dir, exist_ok=True)
    models = {'G_XtoY': G_XtoY, 'G_YtoX': G_YtoX, 'D_X': D_X, 'D_Y': D_Y}
    for name, model in models.items():
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, name + '.pkl'))
def merge_images(sources, targets, batch_size=16):
    """Tile (source, target) image pairs into one grid image.

    The grid has ``row`` rows and ``row`` column-pairs. Within each pair the
    left tile is an image from ``sources`` and the right tile is the image at
    the same index in ``targets`` (e.g. its CycleGAN translation).

    Args:
        sources: Array of shape (N, C, H, W).
        targets: Array with the same per-image shape as ``sources``; paired
            element-wise with ``sources`` (extra items in the longer one are
            ignored, as ``zip`` does).
        batch_size: Intended number of pairs. The grid side starts at
            ``int(sqrt(batch_size))`` and is enlarged if needed so every pair
            fits.

    Returns:
        float64 array of shape (row * H, row * W * 2, C); unfilled cells are 0.
    """
    _, c, h, w = sources.shape
    row = int(np.sqrt(batch_size))
    # A non-square batch_size (e.g. 8) used to produce a grid too small for
    # all pairs, and the first pair that did not fit raised a broadcast error.
    n_pairs = min(len(sources), len(targets))
    while row * row < n_pairs:
        row += 1
    merged = np.zeros([row * h, row * w * 2, c])
    for idx, (s, t) in enumerate(zip(sources, targets)):
        i, j = divmod(idx, row)
        top, bottom = i * h, (i + 1) * h
        # (C, H, W) -> (H, W, C). Horizontal offsets are multiples of the image
        # width (previously the height, which broke non-square images).
        merged[top:bottom, (j * 2) * w:(j * 2 + 1) * w, :] = np.transpose(np.asarray(s), (1, 2, 0))
        merged[top:bottom, (j * 2 + 1) * w:(j * 2 + 2) * w, :] = np.transpose(np.asarray(t), (1, 2, 0))
    return merged
def to_data(x):
    """Convert a tensor to a uint8 numpy array on a 0-255 scale.

    The input is mapped with ``(x + 1) * 255 / 2``, so values in [-1, 1]
    land in [0, 255] (presumably generator tanh output / inputs normalized
    to [-1, 1] -- confirm at call sites). Results are truncated to integers.

    Args:
        x: A torch tensor on any device, with or without ``requires_grad``.

    Returns:
        numpy uint8 array with the same shape as ``x``.
    """
    # detach()/cpu() work for any device and for tensors that track grad;
    # .cpu() is a no-op for tensors already on the CPU.
    x = x.detach().cpu().numpy()
    # Clip before casting: out-of-range floats would otherwise wrap around
    # (or be undefined) when converted to uint8.
    x = np.clip((x + 1) * 255 / 2, 0, 255).astype(np.uint8)
    return x
def _save_image(path, image):
    """Write an (H, W, C) array on a 0-255 scale to ``path`` with matplotlib."""
    # Stand-in for scipy.misc.imsave, which was removed in SciPy 1.2.
    # plt.imsave rejects float RGB data outside 0..1, so store exact uint8 values.
    image = np.clip(image, 0, 255).astype(np.uint8)
    if image.ndim == 3 and image.shape[-1] == 1:
        # matplotlib cannot save (H, W, 1); drop the axis for grayscale images.
        image = image[..., 0]
    # cmap/vmin/vmax only affect 2-D (grayscale) input; RGB(A) is written as-is.
    plt.imsave(path, image, cmap='gray', vmin=0, vmax=255)


def save_samples(iteration, images_Y, images_X, G_YtoX, G_XtoY, batch_size=8,
                 sample_dir='C:\\Users\\HM\\OneDrive - Danmarks Tekniske Universitet\\DTU\\Special Course\\results_samples'):
    """Save sample grids from both generators, X->Y and Y->X.

    Runs ``G_XtoY`` on ``images_X`` and ``G_YtoX`` on ``images_Y``, tiles each
    real image next to its translation with ``merge_images``, and writes
    ``sample-<iteration:06d>-X-Y.png`` and ``sample-<iteration:06d>-Y-X.png``
    into ``sample_dir``, printing each saved path.

    Args:
        iteration: Integer used in the output file names.
        images_Y: Batch of domain-Y images (tensor).
        images_X: Batch of domain-X images (tensor).
        G_YtoX: Generator mapping Y to X.
        G_XtoY: Generator mapping X to Y.
        batch_size: Passed to ``merge_images`` to size the grid.
        sample_dir: Output directory, created if missing. NOTE(review): the
            default is a machine-specific absolute path; callers on other
            machines should pass their own.
    """
    os.makedirs(sample_dir, exist_ok=True)
    # Inputs are moved to this device; assumes the generators already live
    # there -- confirm with the caller.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Inference only: don't build an autograd graph for the sample pass.
    with torch.no_grad():
        fake_X = G_YtoX(images_Y.to(device))
        fake_Y = G_XtoY(images_X.to(device))
    X, fake_X = to_data(images_X), to_data(fake_X)
    Y, fake_Y = to_data(images_Y), to_data(fake_Y)
    for real, fake, direction in ((X, fake_Y, 'X-Y'), (Y, fake_X, 'Y-X')):
        merged = merge_images(real, fake, batch_size)
        path = os.path.join(sample_dir, 'sample-{:06d}-{}.png'.format(iteration, direction))
        _save_image(path, merged)
        print('Saved {}'.format(path))
def imshow(img):
    """Draw a channel-first image tensor on the current matplotlib axes.

    Args:
        img: Torch tensor, expected to be (C, H, W); it is transposed to
            (H, W, C) for ``plt.imshow``. Does not call ``plt.show()``.
    """
    # A bare .numpy() raises for tensors that require grad or live on a GPU.
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))