-
Notifications
You must be signed in to change notification settings - Fork 61
/
Copy pathprepare.py
88 lines (69 loc) · 3.15 KB
/
prepare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import argparse
import glob
import h5py
import numpy as np
import PIL.Image as pil_image
from utils import calc_patch_size, convert_rgb_to_y
@calc_patch_size
def train(args):
    """Build the training HDF5 file of paired low/high-resolution Y-channel patches.

    Every path matched by ``{args.images_dir}/*`` (in sorted order) is loaded
    and converted to RGB. With ``args.with_aug`` each image is expanded into 20
    variants: bicubic rescales by 1.0/0.9/0.8/0.7/0.6, each rotated by
    0/90/180/270 degrees (``expand=True``). Without it only the original
    image is used.

    For each variant:

    * The HR image is bicubic-resized down to the nearest width/height that
      are multiples of ``args.scale``.
    * The LR image is that HR image bicubic-downscaled by ``args.scale``.
    * Both are cast to float32 and passed through ``convert_rgb_to_y``.
    * LR patches of ``patch_size x patch_size`` are taken on a grid with
      stride ``args.scale``, together with the spatially aligned HR patches
      of ``(patch_size*scale) x (patch_size*scale)``.

    Two datasets are written: ``'lr'`` (stacked LR patches) and ``'hr'``
    (stacked HR patches), index-aligned with each other.

    Args:
        args: parsed options providing ``images_dir``, ``output_path``,
            ``scale``, ``with_aug`` and ``patch_size``. ``patch_size`` is not
            a CLI option of this script; presumably it is injected by the
            ``@calc_patch_size`` decorator from ``utils`` (TODO confirm).
    """
    scale = args.scale
    patch_size = args.patch_size
    hr_patch_size = patch_size * scale

    lr_patches = []
    hr_patches = []

    # Open the output first so an unwritable path fails before the slow
    # extraction. The ``with`` guarantees the file is closed even if an
    # image fails to load or convert; the original leaked the handle then.
    with h5py.File(args.output_path, 'w') as h5_file:
        for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):
            # convert() decodes the pixels, so the source file can be closed
            # immediately instead of waiting for garbage collection.
            with pil_image.open(image_path) as src:
                source = src.convert('RGB')

            variants = []
            if args.with_aug:
                for s in [1.0, 0.9, 0.8, 0.7, 0.6]:
                    for r in [0, 90, 180, 270]:
                        tmp = source.resize((int(source.width * s), int(source.height * s)),
                                            resample=pil_image.BICUBIC)
                        variants.append(tmp.rotate(r, expand=True))
            else:
                variants.append(source)

            for variant in variants:
                # Resize so HR dimensions divide evenly by the scale factor,
                # keeping the LR/HR pixel grids exactly aligned.
                hr_width = (variant.width // scale) * scale
                hr_height = (variant.height // scale) * scale
                hr = variant.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
                lr = hr.resize((hr_width // scale, hr_height // scale), resample=pil_image.BICUBIC)

                # convert_rgb_to_y is assumed to return an (H, W) array;
                # the 2-D slicing below relies on that (see utils).
                hr = convert_rgb_to_y(np.array(hr).astype(np.float32))
                lr = convert_rgb_to_y(np.array(lr).astype(np.float32))

                for i in range(0, lr.shape[0] - patch_size + 1, scale):
                    for j in range(0, lr.shape[1] - patch_size + 1, scale):
                        lr_patches.append(lr[i:i + patch_size, j:j + patch_size])
                        # LR offset (i, j) maps to HR offset (i*scale, j*scale).
                        hi, hj = i * scale, j * scale
                        hr_patches.append(hr[hi:hi + hr_patch_size, hj:hj + hr_patch_size])

        h5_file.create_dataset('lr', data=np.array(lr_patches))
        h5_file.create_dataset('hr', data=np.array(hr_patches))
def eval(args):
    """Build the evaluation HDF5 file of whole-image LR/HR Y-channel pairs.

    Every path matched by ``{args.images_dir}/*`` (in sorted order) is loaded
    and converted to RGB. No augmentation or patch extraction is done.

    For each image:

    * The HR image is bicubic-resized down to the nearest width/height that
      are multiples of ``args.scale``.
    * The LR image is that HR image bicubic-downscaled by ``args.scale``.
    * Both are cast to float32 and passed through ``convert_rgb_to_y``.

    Results are stored in the groups ``'lr'`` and ``'hr'``. Each dataset is
    named by the image's index (``'0'``, ``'1'``, ...) in sorted-path order.

    Args:
        args: parsed options providing ``images_dir``, ``output_path`` and
            ``scale``.
    """
    # NOTE(review): this name shadows the builtin ``eval``; it is kept
    # because the command-line entry point dispatches to it by this name.
    scale = args.scale

    # ``with`` closes the HDF5 file even if an image fails to load/convert;
    # the original only closed it on the happy path.
    with h5py.File(args.output_path, 'w') as h5_file:
        lr_group = h5_file.create_group('lr')
        hr_group = h5_file.create_group('hr')

        for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))):
            # convert() decodes the pixels, so the source file can be closed now.
            with pil_image.open(image_path) as src:
                hr = src.convert('RGB')

            # Make HR dimensions divisible by the scale so LR/HR grids align.
            hr_width = (hr.width // scale) * scale
            hr_height = (hr.height // scale) * scale
            hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
            lr = hr.resize((hr_width // scale, hr_height // scale), resample=pil_image.BICUBIC)

            hr = convert_rgb_to_y(np.array(hr).astype(np.float32))
            lr = convert_rgb_to_y(np.array(lr).astype(np.float32))

            lr_group.create_dataset(str(i), data=lr)
            hr_group.create_dataset(str(i), data=hr)
if __name__ == '__main__':
    # Command-line entry point. By default the training patch file is
    # produced; passing --eval switches to the whole-image evaluation layout.
    cli = argparse.ArgumentParser()
    cli.add_argument('--images-dir', type=str, required=True)
    cli.add_argument('--output-path', type=str, required=True)
    cli.add_argument('--scale', type=int, default=2)
    cli.add_argument('--with-aug', action='store_true')
    cli.add_argument('--eval', action='store_true')
    opts = cli.parse_args()

    # Pick the dataset builder from the --eval flag, then run it.
    build = eval if opts.eval else train
    build(opts)