path_dataloader.py
import glob
import os
import random

import tensorflow as tf

from preprocess_image import load_image
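
# `load_image` (defined in preprocess_image.py, not shown here) is assumed to
# take an image file path and return a preprocessed image tensor, roughly:
#   img = tf.io.decode_jpeg(tf.io.read_file(path), channels=3)
#   img = tf.image.resize(img, [height, width]) / 255.0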

# Pair each hazy image with its clear original and split the pairs into
# train/val sets.
def data_path(orig_img_path, hazy_img_path):
    train_img = []
    val_img = []

    orig_img = glob.glob(orig_img_path + '/*.jpg')
    n = len(orig_img)
    random.shuffle(orig_img)

    # 90% of the clear images go to training, the remaining 10% to validation.
    train_keys = orig_img[:int(0.9 * n)]
    val_keys = orig_img[int(0.9 * n):]

    split_dict = {}
    for key in train_keys:
        split_dict[key] = 'train'
    for key in val_keys:
        split_dict[key] = 'val'

    # Each hazy image is named '<original>_<suffix>.jpg'; recover the path of
    # its clear counterpart and place the pair in the matching split.
    hazy_img = glob.glob(hazy_img_path + '/*.jpg')
    for img in hazy_img:
        img_name = os.path.basename(img)  # portable, unlike img.split('/')[-1]
        orig_path = orig_img_path + '/' + img_name.split('_')[0] + '.jpg'
        if split_dict[orig_path] == 'train':
            train_img.append([img, orig_path])
        else:
            val_img.append([img, orig_path])
    return train_img, val_img

# Build batched tf.data pipelines of (hazy, original) image-tensor pairs.
def dataloader(train_data, val_data, batch_size):
    train_data_orig = tf.data.Dataset.from_tensor_slices([img[1] for img in train_data]).map(load_image)
    train_data_haze = tf.data.Dataset.from_tensor_slices([img[0] for img in train_data]).map(load_image)
    train = tf.data.Dataset.zip((train_data_haze, train_data_orig)).shuffle(buffer_size=100).batch(batch_size)

    val_data_orig = tf.data.Dataset.from_tensor_slices([img[1] for img in val_data]).map(load_image)
    val_data_haze = tf.data.Dataset.from_tensor_slices([img[0] for img in val_data]).map(load_image)
    val = tf.data.Dataset.zip((val_data_haze, val_data_orig)).shuffle(buffer_size=100).batch(batch_size)
    return train, val
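
# A minimal usage sketch. Assumptions: the 'data/original' and 'data/hazy'
# directory names are placeholders, and the hazy files follow the
# '<original>_<suffix>.jpg' naming that data_path expects.
if __name__ == '__main__':
    train_pairs, val_pairs = data_path('data/original', 'data/hazy')
    train_ds, val_ds = dataloader(train_pairs, val_pairs, batch_size=8)
    # Pull one batch to confirm the pipeline yields (hazy, original) pairs.
    for hazy_batch, orig_batch in train_ds.take(1):
        print(hazy_batch.shape, orig_batch.shape)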