-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataset.py
164 lines (141 loc) · 6.51 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# -*- coding: utf-8 -*-
import pickle as pickle
import numpy as np
import random
import os
from utils import pad_seq, bytes_to_file, \
read_split_image, shift_and_resize_image, normalize_image
class PickledImageProvider(object):
    """Eagerly loads consecutive pickle records from a single object file.

    Each record is expected to be a (label, image_bytes) example; reading
    stops cleanly at end-of-file.
    """

    def __init__(self, obj_path):
        # Path to the pickled object file; all examples are loaded up front.
        self.obj_path = obj_path
        self.examples = self.load_pickled_examples()

    def load_pickled_examples(self):
        """Read every pickle record from ``self.obj_path``.

        Returns:
            list: all successfully unpickled examples, in file order.
        """
        examples = []
        with open(self.obj_path, "rb") as of:
            while True:
                try:
                    examples.append(pickle.load(of))
                    # periodic progress report for large files
                    if len(examples) % 1000 == 0:
                        print("processed %d examples" % len(examples))
                except EOFError:
                    break
                except Exception as e:
                    # A failed load leaves the stream position undefined, so
                    # retrying pickle.load on the same spot can raise forever
                    # (the original `pass` here risked an infinite loop).
                    # Report the problem and stop reading instead.
                    print("stopping on unreadable record after %d examples: %r"
                          % (len(examples), e))
                    break
        print("unpickled total %d examples" % len(examples))
        return examples
def get_batch_iter(examples, batch_size, augment):
    """Yield (labels, image_tensor) batches of exactly ``batch_size`` items.

    The example list is padded up front so every batch is full, since the
    downstream transpose ops require a deterministic batch dimension.
    """
    padded = pad_seq(examples, batch_size)

    def _prepare(img_bytes):
        # Decode one example into a channel-stacked, normalized array.
        handle = bytes_to_file(img_bytes)
        try:
            img_A, img_B, img_C = read_split_image(handle)
            if augment:
                # Augment by enlarging slightly and random-cropping back to
                # the original size. All three images must share the same
                # shift so they stay spatially aligned.
                w, h, _ = img_A.shape
                scale = random.uniform(1.00, 1.20)
                # +1 acts as an eps guard against a zero-size crop range
                nw = int(scale * w) + 1
                nh = int(scale * h) + 1
                dx = int(np.ceil(np.random.uniform(0.01, nw - w)))
                dy = int(np.ceil(np.random.uniform(0.01, nh - h)))
                img_A = shift_and_resize_image(img_A, dx, dy, nw, nh)
                img_B = shift_and_resize_image(img_B, dx, dy, nw, nh)
                img_C = shift_and_resize_image(img_C, dx, dy, nw, nh)
            normalized = [normalize_image(im) for im in (img_A, img_B, img_C)]
            return np.concatenate(normalized, axis=2)
        finally:
            handle.close()

    def _iterate():
        for start in range(0, len(padded), batch_size):
            chunk = padded[start:start + batch_size]
            labels = [example[0] for example in chunk]
            images = [_prepare(example[1]) for example in chunk]
            # stack the processed images into one float32 tensor
            yield labels, np.array(images).astype(np.float32)

    return _iterate()
class TrainDataProvider(object):
    """Provides train/val batch iterators backed by pickled example files."""

    def __init__(self, data_dir, train_name="train.obj", val_name="val.obj", filter_by=None):
        """Load train and validation examples from ``data_dir``.

        Args:
            data_dir: directory containing the pickled object files.
            train_name: filename of the training set.
            val_name: filename of the validation set.
            filter_by: optional collection of labels; when given, only
                examples whose label is in it are kept.
        """
        self.data_dir = data_dir
        self.filter_by = filter_by
        self.train_path = os.path.join(self.data_dir, train_name)
        self.val_path = os.path.join(self.data_dir, val_name)
        self.train = PickledImageProvider(self.train_path)
        self.val = PickledImageProvider(self.val_path)
        if self.filter_by:
            print("filter by label ->", filter_by)
            # BUG FIX: on Python 3, filter() returns a lazy iterator, so the
            # len() call below (and examples[:] in the iter methods) raised
            # TypeError. Materialize the filtered results as lists.
            self.train.examples = list(filter(lambda e: e[0] in self.filter_by, self.train.examples))
            self.val.examples = list(filter(lambda e: e[0] in self.filter_by, self.val.examples))
        print("train examples -> %d, val examples -> %d" % (len(self.train.examples), len(self.val.examples)))

    def get_train_iter(self, batch_size, shuffle=True):
        """One augmented pass over the (optionally shuffled) training set."""
        training_examples = self.train.examples[:]
        if shuffle:
            np.random.shuffle(training_examples)
        return get_batch_iter(training_examples, batch_size, augment=True)

    def get_val_iter(self, batch_size, shuffle=True):
        """Validation iterator that runs forever (shuffled once up front)."""
        val_examples = self.val.examples[:]
        if shuffle:
            np.random.shuffle(val_examples)
        while True:
            val_batch_iter = get_batch_iter(val_examples, batch_size, augment=False)
            for labels, examples in val_batch_iter:
                yield labels, examples

    def get_test_iter(self, batch_size, shuffle=False):
        """Single un-augmented pass over the validation set."""
        # (Docstring corrected: unlike get_val_iter, this does NOT loop.)
        val_examples = self.val.examples[:]
        if shuffle:
            np.random.shuffle(val_examples)
        return get_batch_iter(val_examples, batch_size, augment=False)

    def compute_total_batch_num(self, batch_size):
        """Total padded batch count for one training epoch."""
        return int(np.ceil(len(self.train.examples) / float(batch_size)))

    def get_all_labels(self):
        """Return the unique training labels (order unspecified)."""
        return list({e[0] for e in self.train.examples})

    def get_train_val_path(self):
        """Return the (train_path, val_path) tuple."""
        return self.train_path, self.val_path
class InjectDataProvider(object):
    """Serves un-augmented batches from one pickled file while replacing the
    stored labels with caller-supplied embedding ids."""

    def __init__(self, obj_path):
        self.data = PickledImageProvider(obj_path)
        print("examples -> %d" % len(self.data.examples))

    def get_single_embedding_iter(self, batch_size, embedding_id):
        """Yield batches where every label is the same ``embedding_id``."""
        source = self.data.examples[:]
        for _, images in get_batch_iter(source, batch_size, augment=False):
            # inject the requested embedding id in place of the real labels
            yield [embedding_id] * batch_size, images

    def get_random_embedding_iter(self, batch_size, embedding_ids):
        """Yield batches labeled with ids sampled from ``embedding_ids``."""
        source = self.data.examples[:]
        for _, images in get_batch_iter(source, batch_size, augment=False):
            # each batch slot gets an independently sampled embedding id
            yield [random.choice(embedding_ids) for _ in range(batch_size)], images
class NeverEndingLoopingProvider(InjectDataProvider):
    """InjectDataProvider variant whose random-embedding iterator restarts
    from the beginning every time the data is exhausted."""

    def __init__(self, obj_path):
        super(NeverEndingLoopingProvider, self).__init__(obj_path)

    def get_random_embedding_iter(self, batch_size, embedding_ids):
        """Loop forever, re-running the parent iterator after each pass."""
        while True:
            # one complete pass over the data, then start over
            parent_iter = super(NeverEndingLoopingProvider, self) \
                .get_random_embedding_iter(batch_size, embedding_ids)
            for labels, images in parent_iter:
                yield labels, images