-
Notifications
You must be signed in to change notification settings - Fork 1
/
synth_dataset.py
115 lines (97 loc) · 4.49 KB
/
synth_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Synthetic dataset generation and dataloader
To use this script, call the init_synth_dataloader_original to get a dataloader
for the synthetic dataset. The dataset will be generated in a h5 file the first
time you call this function.
"""
import numpy as np
from torch.utils.data import Dataset
from skimage import filters
import h5py
import os
import torchvision.transforms as transforms
import torch
class SynthDataset(Dataset):
    """Dataset of synthetic noisy-square images with a scalar regression target.

    Images are read from (and generated into, on first use) an HDF5 cache
    via ``load_and_generate_data``. In 'val'/'test' mode the samples are
    split into a "normal" population (target >= 0.7) and an "anomalous" one
    (target < 0.7), selected with the ``anomaly`` flag.
    """

    def __init__(self, output_folder, mode='train', transform=None, anomaly = False):
        """
        Args:
            output_folder: directory holding (or receiving) the HDF5 cache.
            mode: 'train', 'val' or 'test'; selects the cache file and, for
                'val'/'test', enables the normal/anomaly split.
            transform: optional callable applied to each image array.
            anomaly: in 'val'/'test' mode, keep only samples with
                target < 0.7 instead of target >= 0.7.
        """
        super().__init__()
        self.output_folder = output_folder
        self.anomaly = anomaly
        self.mode = mode
        self.transform = transform
        self.load_cache()
        self.indices = np.arange(len(self.images))

    def load_cache(self):
        """Read images and targets from the HDF5 cache into memory."""
        data = load_and_generate_data(output_folder = self.output_folder, mode = self.mode)
        try:
            imsize = 224
            # Cached features are stored flattened; restore (N, 1, H, W).
            images = np.reshape(data['features'][:], [-1, imsize, imsize])
            images = np.expand_dims(images, 1)
            labels = data['regression_target'][:]
        finally:
            # [:] above copies into numpy arrays, so the file handle can be
            # closed right away (the original version leaked it).
            data.close()
        if self.mode in ['val', 'test']:
            if self.anomaly:
                indexes_to_use = np.where(labels < 0.7)[0]
            else:
                indexes_to_use = np.where(labels >= 0.7)[0]
            labels = labels[indexes_to_use]
            images = images[indexes_to_use]
        self.images = images
        self.n_images = len(self.images)
        self.targets = labels

    def __len__(self):
        return self.n_images

    def __getitem__(self, index):
        index = self.indices[index]
        x = self.images[index, ...]
        # self.targets is 1-D, so targets[index] is a 0-d scalar; expanding
        # with axis=1 raises AxisError on modern numpy (valid axes for a 0-d
        # input are -1..0). axis=0 yields the intended shape (1,).
        y = np.expand_dims(np.asarray(self.targets[index]), axis=0)
        if self.transform:
            x = self.transform(x)
        return x, y
def load_and_generate_data(output_folder, mode = 'train'):
    """Open (generating on first call) the synthetic HDF5 dataset for `mode`.

    Args:
        output_folder: directory where the cache file lives; created if missing.
        mode: tag embedded in the cache filename ('train', 'val', 'test').

    Returns:
        An open read-only ``h5py.File`` with datasets 'features' and
        'regression_target'. The caller is responsible for closing it.

    NOTE(review): the seed is fixed regardless of `mode`, so the 'train',
    'val' and 'test' caches contain identical data — confirm this is intended.
    """
    np.random.seed(7)
    h5_filename = 'synthetic_mode_'+mode+'.hdf5'
    h5_filepath = os.path.join(output_folder, h5_filename)
    # makedirs(exist_ok=True) avoids the check-then-create race of the old
    # `if not exists: mkdir` and also creates missing parent directories.
    os.makedirs(output_folder, exist_ok=True)
    if not os.path.exists(h5_filepath):
        regression_target, features = prepare_data_squares_by_size()
        with h5py.File(h5_filepath, 'w') as hdf5_file:
            hdf5_file.create_dataset('features',
                                     data=features, dtype=np.float32)
            hdf5_file.create_dataset('regression_target',
                                     data=regression_target, dtype=np.float32)
    return h5py.File(h5_filepath, 'r')
def prepare_data_squares_by_size(image_size = 224,
                                 num_samples=10000):
    """Generate `num_samples` noisy images of centred squares.

    Each regression target (a Weibull draw scaled by 0.75, rounded to two
    decimals) controls its square's size. Returns the targets and the image
    stack reshaped to ``[-1, num_samples]``; the loader flat-reshapes it
    back to (N, H, W), so only the flattened element order matters.
    """
    targets = np.around(0.75 * np.random.weibull(7, num_samples), decimals = 2)
    samples = np.zeros([num_samples, image_size, image_size])
    noise_shape = np.asarray([image_size, image_size])
    for idx, target in enumerate(targets):
        clean = get_clean_square(target, image_size)
        raw_noise = np.random.normal(scale=1, size=noise_shape)
        blurred = filters.gaussian(raw_noise, 2.5)
        # Rescale the blurred noise to a fixed standard deviation of 0.5.
        samples[idx, :, :] = clean + blurred / np.std(blurred) * 0.5
    return targets, samples.reshape([-1, num_samples])
def get_clean_square(regression_target, image_size):
    """Return an image with a centred bright square on a dark background.

    Background pixels are -0.5 and square pixels are +0.5; the square's
    half-width is ``int(0.8 * (image_size // 2) * regression_target)``.
    """
    centre = int(image_size / 2)
    half_width = int((centre * 0.8) * regression_target)
    image = np.full([image_size, image_size], -0.5)
    lo, hi = centre - half_width, centre + half_width
    image[lo:hi, lo:hi] = 0.5
    return image
def init_synth_dataloader_original(output_folder, batch_size, mode='train'):
    """Build the (normal, anomaly) pair of DataLoaders for the synthetic set.

    Args:
        output_folder: directory holding (or receiving) the HDF5 cache.
        batch_size: batch size used for both loaders.
        mode: 'train', 'val' or 'test'; only training loaders are shuffled.

    Returns:
        Tuple ``(dataloader_class1, dataloader_class2)`` over the normal
        (``anomaly=False``) and anomalous (``anomaly=True``) populations.
    """
    def _make_loader(anomaly):
        # Build one SynthDataset + DataLoader (removes the verbatim
        # duplication of the original two construction blocks).
        dataset = SynthDataset(
            output_folder,
            anomaly=anomaly,
            mode=mode,
            # torch.tensor is the only transform: ndarray -> Tensor.
            transform=transforms.Compose([torch.tensor]),
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=0,
            shuffle=(mode == 'train'),
            drop_last=True,
        )

    return _make_loader(False), _make_loader(True)