-
Notifications
You must be signed in to change notification settings - Fork 6
/
data.py
executable file
·30 lines (26 loc) · 1.09 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
import numpy as np
import tensorflow as tf
import tensorlayer as tl
DATA_PATH = "data"
def get_celebA(output_size, batch_size):
# dataset API and augmentation
images_path = tl.files.load_file_list(path=DATA_PATH, regx='.*.jpg', keep_prefix=True, printable=False)
def generator_train():
for image_path in images_path:
yield image_path.encode('utf-8')
def _map_fn(image_path):
image = tf.io.read_file(image_path)
image = tf.image.decode_jpeg(image, channels=3) # get RGB with 0~1
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image = image[45:173, 25:153, :] # central crop
image = tf.image.resize([image], (output_size, output_size))[0]
image = tf.image.random_flip_left_right(image)
image = image * 2 - 1
return image
train_ds = tf.data.Dataset.from_generator(generator_train, output_types=tf.string)
ds = train_ds.shuffle(buffer_size=4096)
ds = ds.map(_map_fn, num_parallel_calls=4)
ds = ds.batch(batch_size)
ds = ds.prefetch(buffer_size=2)
return ds, images_path