This repository has been archived by the owner on Jan 14, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
dataloader.py
176 lines (158 loc) · 7.19 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
from pathlib import Path
import numpy as np
import math
import tensorflow.keras.preprocessing.image as image
import tensorflow.keras.utils as utils
def random_crop(img, size=(256, 256)):
    """Crop a randomly positioned window out of an HWC image array.

    Args:
        img: np.array of shape (H, W, C).
        size: int for a square crop, or a (dy, dx) pair; either form is
            clamped so the window never exceeds the image bounds.

    Returns:
        np.array of shape (dy, dx, C), a window of ``img``.

    Raises:
        ValueError: if ``size`` is neither an int nor a length-2 sequence.
    """
    height, width = img.shape[:2]
    if isinstance(size, int):
        # Square crop, clamped to the shorter image side.
        crop_h = crop_w = min(size, height, width)
    elif len(size) == 2:
        crop_h = min(size[0], height)
        crop_w = min(size[1], width)
    else:
        raise ValueError('The size must be an integer or a tuple (or list) of 2 integers: {}'.format(size))
    # Pick the top-left corner uniformly among all valid positions.
    top = np.random.randint(0, height - crop_h + 1)
    left = np.random.randint(0, width - crop_w + 1)
    return img[top:(top + crop_h), left:(left + crop_w), :]
def center_crop(img, size=(256, 256)):
    """Crop a centered window out of an HWC image array.

    Args:
        img: np.array of shape (H, W, C).
        size: int for a square crop, or a (dy, dx) pair; either form is
            clamped so the window never exceeds the image bounds.

    Returns:
        np.array of shape (dy, dx, C), the central window of ``img``.

    Raises:
        ValueError: if ``size`` is neither an int nor a length-2 sequence.
    """
    height, width = img.shape[:2]
    if isinstance(size, int):
        # Square crop, clamped to the shorter image side.
        crop_h = crop_w = min(size, height, width)
    elif len(size) == 2:
        crop_h = min(size[0], height)
        crop_w = min(size[1], width)
    else:
        raise ValueError('The size must be an integer or a tuple (or list) of 2 integers: {}'.format(size))
    # Offsets are non-negative because the crop was clamped above.
    top = (height - crop_h) // 2
    left = (width - crop_w) // 2
    return img[top:(top + crop_h), left:(left + crop_w), :]
class ContentStyleLoader(utils.Sequence):
    """Keras Sequence that yields paired (content, style) image batches.

    Each epoch draws ``n_per_epoch`` content images and ``n_per_epoch`` style
    images (reshuffled via random permutations, with leftovers carried over to
    the next epoch). ``__getitem__`` returns
    ``[content_batch, style_batch], dummy_labels`` where the labels are all
    ones (placeholders for losses that ignore targets).
    """

    def __init__(self, content_root=None, content_image_shape=None, content_crop=None, content_crop_size=256, style_root=None, style_image_shape=None, style_crop=None, style_crop_size=256, n_per_epoch=1000, batch_size=8):
        """Index content/style image trees and prepare per-epoch orderings.

        Args:
            content_root: directory searched recursively for content images.
            content_image_shape: (height, width) to resize to, or None to keep
                each image's original size.
            content_crop: None, 'random' or 'center'.
            content_crop_size: int or (dy, dx) crop window.
            style_root / style_image_shape / style_crop / style_crop_size:
                same as the content_* arguments, for style images.
            n_per_epoch: number of image pairs drawn per epoch.
            batch_size: pairs per batch; must not exceed n_per_epoch.

        Raises:
            ValueError: for a missing root directory or invalid option values.
        """
        # training settings
        self.n_per_epoch = n_per_epoch
        self.batch_size = batch_size
        if batch_size > n_per_epoch:
            raise ValueError('The batch is greater than the total images per epoch: {} > {}'.format(batch_size, n_per_epoch))
        # content images
        self.content_root = content_root
        content_root = Path(content_root)
        if not content_root.exists():
            raise ValueError('The content root directory is not exist: {}'.format(content_root))
        self.content_images = list(content_root.glob('**/*.*'))
        self.n_content = len(self.content_images)
        # order of content images for the first epoch
        self.cur_content_indices, self.content_indices = self._draw_epoch([], self.n_content)
        # content image transformer: path -> np.array
        self.content_transform = self._build_transform(content_image_shape, content_crop, content_crop_size, 'content')
        # style images
        self.style_root = style_root
        style_root = Path(style_root)
        if not style_root.exists():
            raise ValueError('The style root directory is not exist: {}'.format(style_root))
        self.style_images = list(style_root.glob('**/*.*'))
        self.n_style = len(self.style_images)
        # order of style images for the first epoch
        self.cur_style_indices, self.style_indices = self._draw_epoch([], self.n_style)
        # style image transformer: path -> np.array
        self.style_transform = self._build_transform(style_image_shape, style_crop, style_crop_size, 'style')

    def _draw_epoch(self, leftover, n_images):
        """Extend ``leftover`` with fresh permutations of range(n_images) until
        n_per_epoch indices are available.

        Returns:
            (indices for this epoch, leftover indices for the next epoch).
        """
        indices = list(leftover)
        while len(indices) < self.n_per_epoch:
            indices.extend(np.random.permutation(n_images))
        return indices[:self.n_per_epoch], indices[self.n_per_epoch:]

    def _build_transform(self, image_shape, crop, crop_size, kind):
        """Validate options and build the path -> np.array transform pipeline."""
        if image_shape is not None and len(image_shape) != 2:
            raise ValueError('The {}_image_shape should be None (to use original shape of the image) or 2 dimensional tuple: {}'.format(kind, image_shape))
        # BUG FIX: the original tested len(crop_size == 2) -- len() of a bool,
        # which raised TypeError for every 2-element crop size.
        if not (crop is None or isinstance(crop_size, int) or len(crop_size) in (1, 2)):
            raise ValueError('The dimension of the cropped {} image must be 1 (rectangle), 2: {}'.format(kind, crop_size))
        transform = [
            lambda x: image.load_img(x, target_size=image_shape),
            lambda x: image.img_to_array(x),
        ]
        if crop is None:
            pass
        elif crop == 'random':
            transform.append(lambda x: random_crop(x, size=crop_size))
        elif crop == 'center':
            transform.append(lambda x: center_crop(x, size=crop_size))
        else:
            raise ValueError('Unsupported crop option: {}'.format(crop))
        return transform

    @staticmethod
    def _load(path, transform):
        """Run ``path`` through each transform stage in order."""
        out = path
        for fn in transform:
            out = fn(out)
        return out

    def __len__(self):
        """Number of batches per epoch (last batch may be short)."""
        return math.ceil(self.n_per_epoch / self.batch_size)

    def __getitem__(self, idx):
        """Return ([content_batch, style_batch], dummy_labels) for batch idx."""
        lo = idx * self.batch_size
        hi = lo + self.batch_size
        batch_content = [self._load(self.content_images[i], self.content_transform)
                         for i in self.cur_content_indices[lo:hi]]
        batch_style = [self._load(self.style_images[i], self.style_transform)
                       for i in self.cur_style_indices[lo:hi]]
        # BUG FIX: label length now matches the actual (possibly short) final
        # batch; the original always emitted batch_size labels.
        return [np.array(batch_content), np.array(batch_style)], np.ones(len(batch_content), dtype=np.float32)

    def on_epoch_end(self):
        """Reshuffle the content and style orderings for the next epoch."""
        self.cur_content_indices, self.content_indices = self._draw_epoch(self.content_indices, self.n_content)
        self.cur_style_indices, self.style_indices = self._draw_epoch(self.style_indices, self.n_style)
def load_image(filepath, image_shape=None, crop=None, crop_size=None):
    """Load a single image file as a float np.array, optionally resized and cropped.

    Args:
        filepath: path to an existing image file (not a directory).
        image_shape: (height, width) to resize to, or None to keep the
            original size.
        crop: None, 'random' or 'center'.
        crop_size: int or length-1/2 sequence; used only when ``crop`` is set.

    Returns:
        np.array of shape (H, W, C).

    Raises:
        ValueError: for a missing file or invalid option values.
    """
    filepath = Path(filepath)
    if not filepath.exists() or filepath.is_dir():
        raise ValueError('The file does not exist: {}'.format(filepath))
    if image_shape is not None and len(image_shape) != 2:
        raise ValueError('The image_shape should be None (to use original shape of the image) or 2 dimensional tuple: {}'.format(image_shape))
    # Validate crop options BEFORE the (potentially slow) image load.
    # BUG FIX: the original tested len(crop_size == 2) -- len() of a bool --
    # which raised TypeError for any 2-element crop_size.
    if not (crop is None or isinstance(crop_size, int) or len(crop_size) in (1, 2)):
        raise ValueError('The dimension of the cropped image must be 1 (rectangle), 2: {}'.format(crop_size))
    if crop is not None and crop not in ('random', 'center'):
        raise ValueError('Unsupported crop option: {}'.format(crop))
    img = image.load_img(filepath, target_size=image_shape)
    img = image.img_to_array(img)
    if crop == 'random':
        img = random_crop(img, size=crop_size)
    elif crop == 'center':
        img = center_crop(img, size=crop_size)
    return img