-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtest_preprocess.py
114 lines (84 loc) · 3.36 KB
/
test_preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import tensorflow as tf
import time
from pathlib import Path
import numpy as np
from pathlib import Path
import pandas as pd
from matplotlib import pyplot as plt
import h5py
class dataset_wrapper:
    """In-memory container for masked, normalised image stacks plus labels.

    Reads the requested channels out of an open HDF5 handle once, applies
    each channel's mask, and scales by 1/2**16 (assumes 16-bit source data
    — TODO confirm against the acquisition pipeline).

    Attributes:
        labels: per-object label array, stored as passed in.
        images: float32 array of shape (n_channels, n_objects, H, W).
    """

    def __init__(self, h5fp, labels, channels):
        self.labels = labels
        channel_keys = ["channel_%d" % c for c in channels]
        # Every channel group shares the shape of channel_1's image stack.
        base_shape = h5fp["channel_1/images"].shape
        self.images = np.empty(
            (len(channel_keys),) + tuple(base_shape), dtype=np.float32
        )
        for pos, key in enumerate(channel_keys):
            group = h5fp[key]
            # Mask out background pixels, then normalise to [0, 1).
            masked = np.multiply(group["images"], group["masks"], dtype=np.float32)
            self.images[pos] = masked / 2**16
class generator:
    """Callable yielding (channel-stack, label) pairs in random order.

    Intended for tf.data.Dataset.from_generator: every invocation of the
    instance reshuffles its index list in place, so each pass over the
    dataset sees a fresh ordering.
    """

    def __init__(self, data, indices):
        self.data = data
        self.indices = indices

    def __call__(self):
        # In-place shuffle is deliberate: the permutation persists on the
        # instance between epochs.
        np.random.shuffle(self.indices)
        images, labels = self.data.images, self.data.labels
        for sample in self.indices:
            yield images[:, sample, :, :], labels[sample]
def apply_augmentation(image):
    """Randomly flip a single CHW image horizontally and vertically.

    tf.image's flip ops expect channels-last input, so the image is moved
    to HWC, both random flips are applied, and it is moved back to CHW.
    """
    hwc = tf.transpose(image, [1, 2, 0])
    hwc = tf.image.random_flip_left_right(hwc)
    hwc = tf.image.random_flip_up_down(hwc)
    return tf.transpose(hwc, [2, 0, 1])
def preprocess_batch(batch, aug):
    """Apply the per-image function `aug` to every element of `batch`."""
    # NOTE(review): `dtype=` is the legacy tf.map_fn spelling; newer TF
    # versions prefer `fn_output_signature` — confirm the pinned TF version.
    return tf.map_fn(aug, batch, dtype=tf.float32)
def load_dataset(data, indices, labels, cache_file, type="train",
                 augment_func=None, n_classes=8, batch_size=128):
    """Build a class-balanced, batched, augmented tf.data pipeline.

    One infinite generator dataset is built per class and they are sampled
    from uniformly, yielding an approximately balanced stream.

    Args:
        data: dataset_wrapper holding `images` and `labels`.
        indices: NOTE(review): currently unused — kept for interface
            compatibility; presumably meant to restrict sampling. TODO.
        labels: label array used to group sample indices per class.
        cache_file: NOTE(review): currently unused (no .cache() call).
        type: NOTE(review): currently unused; shadows the builtin but is
            kept to preserve the call signature.
        augment_func: per-image augmentation; defaults to
            apply_augmentation when None.
        n_classes: number of label classes (previously hard-coded to 8).
        batch_size: batch size (previously hard-coded to 128).

    Returns:
        A batched tf.data.Dataset of (images, labels).
    """
    per_class = []
    for cls in range(n_classes):
        per_class.append(list(np.where(labels == cls)[0]))
    ds = tf.data.experimental.sample_from_datasets([
        tf.data.Dataset.from_generator(
            generator(data, idx),
            output_types=(tf.float32, tf.uint8),
        ).repeat()
        for idx in per_class
    ])
    ds = ds.batch(batch_size=batch_size)
    # Bug fix: apply_augmentation used to be hard-coded here, silently
    # ignoring the augment_func argument.
    aug = augment_func if augment_func is not None else apply_augmentation
    # Lambda args renamed so the outer `labels` array is not shadowed.
    ds = ds.map(
        lambda imgs, lbls: (preprocess_batch(imgs, aug), lbls),
        num_parallel_calls=4,
    )
    return ds
if __name__ == "__main__":
    # Smoke-test script: load one HDF5 experiment, build the balanced
    # pipeline, and save one batch of first-channel images to tmp.png.
    from collections import Counter
    # NOTE(review): hard-coded absolute paths — only runs on the author's
    # machine; consider argparse/env vars.
    h5 = "/home/maximl/DATA/Experiment_data/9-color/s123.h5"
    meta = pd.read_csv("/home/maximl/DATA/Experiment_data/9-color/train_data_no_images.csv")
    # NOTE(review): despite the "train_" name, this loads the validation
    # split ("val.txt") of fold 0.
    train_indices = np.loadtxt(Path("/home/maximl/DATA/Experiment_data/9-color_meta/s123_5fold/0", "val.txt"), dtype=int)
    train_cache = str(Path("caches", "test"))
    labels = meta["label"].values
    # Channels 1, 6 and 9 are loaded fully into memory while the file is open.
    with h5py.File(h5) as h5fp:
        data = dataset_wrapper(h5fp, labels, [1, 6, 9])
    ds = load_dataset(data, train_indices, labels, train_cache, "val", augment_func=apply_augmentation)
    batches = 11
    it = iter(ds.take(batches))
    # First batch is consumed and discarded (pipeline warm-up).
    next(it)
    run = []  # NOTE(review): unused — leftover from the timing code below.
    fig, axes = plt.subplots(16, 8, figsize=(50,25))
    axes = axes.ravel()
    images, labels = next(it)
    for im, ax in zip(images, axes):
        # Plot only the first channel of each image.
        ax.imshow(im[0])
    plt.savefig("tmp.png")
    # Drop the iterator reference so the tf.data pipeline can be released.
    it = None
    # times = []
    # for i in range(1):
    #     it = iter(ds.take(batches))
    #     next(it)
    #     run = []
    #     start = time.time()
    #     for i, (images, labels) in enumerate(it):
    #         print(Counter(labels.numpy()))
    #         run.append(time.time()-start)
    #         start = time.time()
    #     times.append(run)
    # print(np.mean(times))