-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathgen_tfrecord.py
63 lines (60 loc) · 2.21 KB
/
gen_tfrecord.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import tensorflow as tf
import os
from PIL import Image
import numpy as np
# Target sizes (in pixels) every RGB image / depth map is resized to before
# serialization. Note PIL's resize() takes (width, height).
height = 318
width = 424
depth_height = 318
depth_width = 424


def _list_paired_pngs(rgbdir, depthdir):
    """Return the ``.png`` names in *rgbdir* that also exist as files in *depthdir*.

    Names are returned in ``os.listdir`` order (arbitrary, platform dependent).
    """
    return [
        name
        for name in os.listdir(rgbdir)
        if os.path.splitext(name)[1] == '.png'
        and os.path.isfile(os.path.join(depthdir, name))
    ]


def _encode_example(rgb_path, depth_path):
    """Load one RGB/depth pair and return it as a serialized tf.train.Example.

    The RGB image is converted to PIL mode 'RGB' and the depth map to mode
    'F' (32-bit float pixels); both are bilinearly resized to the module-level
    target sizes and their raw pixel bytes are stored under the features
    'img' and 'label' respectively.
    """
    # `with` closes the underlying file as soon as the pixels are converted,
    # instead of relying on garbage collection.
    with Image.open(rgb_path) as img:
        rgb = img.convert('RGB').resize((width, height), Image.BILINEAR)
    with Image.open(depth_path) as img:
        depth = img.convert('F').resize((depth_width, depth_height), Image.BILINEAR)
    example = tf.train.Example(features=tf.train.Features(feature={
        'img': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[rgb.tobytes()])),
        'label': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[depth.tobytes()])),
    }))
    return example.SerializeToString()


def write_tfrecord(rgbdir, depthdir, out_path):
    """Write every RGB/depth pair found in *rgbdir*/*depthdir* to *out_path*.

    A pair is an RGB ``.png`` whose file name also exists in *depthdir*. The
    number of pairs is printed, then the pairs are written in a random order
    drawn from NumPy's global RNG (so record order differs between runs unless
    the caller seeds ``np.random``).

    Args:
        rgbdir: directory containing the RGB images.
        depthdir: directory containing the depth maps, same file names.
        out_path: path of the TFRecord file to create (overwritten).
    """
    names = _list_paired_pngs(rgbdir, depthdir)
    print(len(names))
    # Prefer the TF2 location of the writer; fall back to the TF1-only
    # tf.python_io alias so the script runs on either major version.
    writer_cls = (getattr(getattr(tf, 'io', None), 'TFRecordWriter', None)
                  or tf.python_io.TFRecordWriter)
    writer = writer_cls(out_path)
    try:
        for idx in np.random.permutation(len(names)):
            name = names[idx]
            writer.write(_encode_example(os.path.join(rgbdir, name),
                                         os.path.join(depthdir, name)))
    finally:
        # Always flush/close, even if an image fails to load, so records
        # already written are not lost in an unflushed buffer.
        writer.close()


write_tfrecord('./rgb_train/', './depth_train/', "train.tfrecords")
write_tfrecord('./rgb_test/', './depth_test/', "test.tfrecords")