-
Notifications
You must be signed in to change notification settings - Fork 0
/
training_data_processing.py
113 lines (94 loc) · 5.47 KB
/
training_data_processing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import numpy as np
from PIL import Image
from skimage.transform import resize
import scipy
from tqdm import tqdm
def training_data_loading(path_training='training/', training_resize=384, training_number=100, dataset_size=100):
    """
    training_data_loading - Load the satellite training pairs from disk, resize them,
    shuffle them, and split them into a training set and a validation set.

    Args:
        path_training (str): root directory containing the 'images/' and 'groundtruth/' subfolders
        training_resize (int): side length of the resized square images and masks (training pairs) (default: 384)
        training_number (int): number of resized pairs assigned to the training split; the remaining
            ``dataset_size - training_number`` pairs form the validation split (default: 100,
            which leaves the validation split empty at the default dataset size)
        dataset_size (int): total number of pairs on disk, named satImage_001.png onward with the
            index zero-padded to at least 3 digits (default: 100)

    Returns:
        images_training, labels_training (numpy): the resized training dataset, shaped
            (training_number, 3, training_resize, training_resize) and
            (training_number, 1, training_resize, training_resize)
        images_validation, labels_validation (numpy): the resized validation dataset
    """
    # NOTE(review): the preallocated shape assumes 3-channel RGB source PNGs — confirm;
    # an RGBA or grayscale input would break the (3, H, W) assignment below.
    images_loading = np.empty(shape=(dataset_size, 3, training_resize, training_resize))
    labels_loading = np.empty(shape=(dataset_size, 1, training_resize, training_resize))
    for index in tqdm(range(1, dataset_size + 1)):
        # Load a training pair; pixel values are rescaled from [0, 255] to [0, 1]
        image = np.array(Image.open(f'{path_training}images/satImage_{str(index).zfill(3)}.png')).astype(float) / 255
        label = np.array(Image.open(f'{path_training}groundtruth/satImage_{str(index).zfill(3)}.png')).astype(float) / 255
        # The mask is single-channel; give it an explicit channel axis -> (H, W, 1)
        label = np.expand_dims(label, 2)
        # Resize both members of the pair to the common training resolution
        image = resize(image, (training_resize, training_resize))
        label = resize(label, (training_resize, training_resize))
        # Convert from channels-last (H, W, C) to channels-first (C, H, W)
        image = np.transpose(image, (2, 0, 1))
        label = np.transpose(label, (2, 0, 1))
        images_loading[index-1] = image
        labels_loading[index-1] = label
    # Shuffle images and masks with the SAME permutation so pairs stay aligned
    permuted_sequence = np.random.permutation(dataset_size)
    images_loading = images_loading[permuted_sequence]
    labels_loading = labels_loading[permuted_sequence]
    # Split: the first training_number pairs train, the remainder validate
    images_training = images_loading[:training_number]
    labels_training = labels_loading[:training_number]
    images_validation = images_loading[training_number:]
    labels_validation = labels_loading[training_number:]
    return images_training, labels_training, images_validation, labels_validation
def training_data_augmentation(images_training, labels_training, rotations, flips, shifts, training_resize=384):
    """
    training_data_augmentation - Generate the augmented training dataset by applying
    every (rotation, flip, random shift) combination to each training pair, then
    shuffling the result.

    Args:
        images_training, labels_training (numpy): the resized training dataset,
            channels-first, shaped (N, 3, S, S) and (N, 1, S, S)
        rotations (list): rotation angles (degrees) applied to each training pair
        flips (list): flip operations; each entry is either the string 'original'
            (no flip) or a callable applied to the channels-last array (e.g. np.fliplr)
        shifts (list): (low, high) ranges; each pair is shifted by an offset drawn
            uniformly from the range, independently for each spatial axis
        training_resize (int): the resolution of resized training pairs (default: 384)

    Returns:
        images_augmented, labels_augmented (numpy): the augmented training dataset;
            images are clipped to [0, 1] and masks are binarized with a 0.3 threshold
    """
    num_rota = len(rotations)
    num_flip = len(flips)
    num_shft = len(shifts)
    # Preallocate the full augmented dataset: one output per (pair, rota, flip, shft)
    num_training = images_training.shape[0]
    num_augmented = num_training * num_rota * num_flip * num_shft
    images_augmented = np.empty(shape=(num_augmented, 3, training_resize, training_resize))
    labels_augmented = np.empty(shape=(num_augmented, 1, training_resize, training_resize))
    print(f"images_augmented.shape = {images_augmented.shape}")
    print(f"labels_augmented.shape = {labels_augmented.shape}")
    counter = 0
    for index in tqdm(range(num_training)):
        # Work in channels-last layout, which the ndimage transforms below expect
        image = np.transpose(images_training[index], (1, 2, 0))
        label = np.transpose(labels_training[index], (1, 2, 0))
        for rota in rotations:
            # Rotate once per angle — hoisted out of the flip/shift loops; the
            # original recomputed this identical rotation num_flip * num_shft times
            image_rota = scipy.ndimage.rotate(image, rota, reshape=False, mode='reflect')
            label_rota = scipy.ndimage.rotate(label, rota, reshape=False, mode='reflect')
            for flip in flips:
                # Flip once per (angle, flip) pair — hoisted out of the shift loop
                if flip == 'original':
                    image_flip = image_rota
                    label_flip = label_rota
                else:
                    image_flip = flip(image_rota)
                    label_flip = flip(label_rota)
                for shft in shifts:
                    # Independent random offset per spatial axis; channel axis unshifted.
                    # The np.random.uniform call order matches the original exactly,
                    # so results are bit-identical for a fixed seed.
                    shft_H = np.random.uniform(low=shft[0], high=shft[1], size=1)[0]
                    shft_W = np.random.uniform(low=shft[0], high=shft[1], size=1)[0]
                    image_shft = scipy.ndimage.shift(image_flip, (shft_H, shft_W, 0), mode='reflect')
                    label_shft = scipy.ndimage.shift(label_flip, (shft_H, shft_W, 0), mode='reflect')
                    # Back to channels-first; clip image values, binarize the mask
                    images_augmented[counter] = np.clip(np.transpose(image_shft, (2, 0, 1)), 0, 1)
                    labels_augmented[counter] = np.transpose(label_shft, (2, 0, 1)) > 0.3
                    counter += 1
    # Shuffle images and masks with the SAME permutation so pairs stay aligned
    permuted_sequence = np.random.permutation(num_augmented)
    images_augmented = images_augmented[permuted_sequence]
    labels_augmented = labels_augmented[permuted_sequence]
    return images_augmented, labels_augmented