-
Notifications
You must be signed in to change notification settings - Fork 1
/
nyuv2.py
309 lines (258 loc) · 10.6 KB
/
nyuv2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
"""
author: Mihai Suteu
date: 15/05/19
"""
import os
import sys
import h5py
import torch
import shutil
import random
import tarfile
import zipfile
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
class NYUv2(Dataset):
    """
    PyTorch wrapper for the NYUv2 dataset focused on multi-task learning.
    Data sources available: RGB, Semantic Segmentation, Surface Normals, Depth Images.
    If no transformation is provided, the image type will not be returned.

    ### Output
    All images are of size: 640 x 480
    1. RGB: 3 channel input image
    2. Semantic Segmentation: 1 channel representing one of the 14 (0 -
       background) classes. Conversion to int will happen automatically if
       transformation ends in a tensor.
    3. Surface Normals: 3 channels, with values in [0, 1].
    4. Depth Images: 1 channel with floats representing the distance in meters.
       Conversion will happen automatically if transformation ends in a tensor.

    ### Notes
    * Samples are indexed by the sorted file names of the ``<split>_rgb``
      folder, so the index -> sample mapping is reproducible.
    * The surface normal folders are neither verified by ``_check_exists`` nor
      fetched by ``download``; they have to be present already when
      ``sn_transform`` is used.
    """

    def __init__(
        self,
        root: str = '/u/yuanyiz2/work/nyu_data',
        train: bool = True,
        download: bool = False,
        rgb_transform=None,
        seg_transform=None,
        sn_transform=None,
        depth_transform=None,
    ):
        """
        Will return tuples based on what data source has been enabled (rgb, seg etc).

        :param root: path to root folder (eg /data/NYUv2)
        :param train: whether to load the train or test set
        :param download: whether to download and process data if missing
        :param rgb_transform: the transformation pipeline for rgb images
        :param seg_transform: the transformation pipeline for segmentation images. If
            the transformation ends in a tensor, the result will be automatically
            converted to int in [0, 14)
        :param sn_transform: the transformation pipeline for surface normal images
        :param depth_transform: the transformation pipeline for depth images. If the
            transformation ends in a tensor, the result will be automatically converted
            to meters
        :raises RuntimeError: if one of the required folders is missing
        """
        super().__init__()
        self.root = root
        self.rgb_transform = rgb_transform
        self.seg_transform = seg_transform
        self.sn_transform = sn_transform
        self.depth_transform = depth_transform
        self.train = train
        self._split = "train" if train else "test"

        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError(
                "Dataset not complete. You can use download=True to download it"
            )

        # The rgb folder is the ground truth for which samples exist.
        # os.listdir returns entries in arbitrary order, so sort them to keep
        # the index -> file mapping stable across runs and machines.
        self._files = sorted(os.listdir(os.path.join(self.root, f"{self._split}_rgb")))

    def _open(self, kind: str, index: int):
        """
        Open the image of the given kind that belongs to sample ``index``.

        :param kind: folder suffix of the data source (rgb, seg13, sn, depth)
        :param index: position of the sample in the file list
        :return: the opened PIL image
        """
        folder = os.path.join(self.root, f"{self._split}_{kind}")
        return Image.open(os.path.join(folder, self._files[index]))

    def __getitem__(self, index: int) -> list:
        """
        Load every enabled data source for one sample.

        :param index: position of the sample in the file list
        :return: list with one entry per enabled transformation, in the order
            rgb, segmentation, surface normals, depth
        """
        # NOTE: every transformation is applied independently. Random
        # transformations are not synchronised between the data sources.
        imgs = []

        if self.rgb_transform is not None:
            imgs.append(self.rgb_transform(self._open("rgb", index)))

        if self.seg_transform is not None:
            seg = self.seg_transform(self._open("seg13", index))
            if isinstance(seg, torch.Tensor):
                # ToTensor scales to [0, 1] by default
                seg = (seg * 255).long()
            imgs.append(seg)

        if self.sn_transform is not None:
            imgs.append(self.sn_transform(self._open("sn", index)))

        if self.depth_transform is not None:
            depth = self.depth_transform(self._open("depth", index))
            if isinstance(depth, torch.Tensor):
                # depth png is uint16, dividing by 1e4 gives meters
                depth = depth.float() / 1e4
            imgs.append(depth)

        return imgs

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self._files)

    def __repr__(self):
        """Multi-line summary: size, split, root and the four transformations."""
        fmt_str = f"Dataset {self.__class__.__name__}\n"
        fmt_str += f" Number of data points: {len(self)}\n"
        fmt_str += f" Split: {self._split}\n"
        fmt_str += f" Root Location: {self.root}\n"
        sections = (
            (" RGB Transforms: ", self.rgb_transform),
            (" Seg Transforms: ", self.seg_transform),
            (" SN Transforms: ", self.sn_transform),
            (" Depth Transforms: ", self.depth_transform),
        )
        for tmp, transform in sections:
            # Indent continuation lines so they line up under the first one.
            body = repr(transform).replace("\n", "\n" + " " * len(tmp))
            fmt_str += "{0}{1}\n".format(tmp, body)
        return fmt_str

    def _check_exists(self) -> bool:
        """
        Only checking for folder existence.

        Both splits are checked for the rgb, seg13 and depth folders. The
        surface normal (sn) folders are not part of the check.

        :return: True if every checked folder exists
        """
        return all(
            os.path.exists(os.path.join(self.root, f"{split}_{type_}"))
            for split in ("train", "test")
            for type_ in ("rgb", "seg13", "depth")
        )

    def download(self):
        """
        Download and process rgb, segmentation and depth data unless the
        folders are already in place. Surface normals are not downloaded.
        """
        if self._check_exists():
            return
        download_rgb(self.root)
        download_seg(self.root)
        download_depth(self.root)
        print("Done!")
def download_rgb(root: str):
    """
    Download and unpack the train and test RGB images.

    Creates ``<root>/train_rgb`` and ``<root>/test_rgb``. A split whose
    destination folder already exists is skipped, and an archive that is
    already present in ``root`` is not downloaded again.

    :param root: The root directory of the dataset
    """
    train_url = "http://www.doc.ic.ac.uk/~ahanda/nyu_train_rgb.tgz"
    test_url = "http://www.doc.ic.ac.uk/~ahanda/nyu_test_rgb.tgz"

    def _proc(url: str, dst: str):
        if os.path.exists(dst):
            return
        tar = os.path.join(root, url.split("/")[-1])
        if not os.path.exists(tar):
            download_url(url, root)
        if os.path.exists(tar):
            _unpack(tar)
            # The extracted folder is the archive path without its extension.
            # str.rstrip(".tgz") removes any trailing '.', 't', 'g' or 'z'
            # characters rather than the suffix, so it would mangle a name
            # such as "set.tgz"; splitext drops exactly the extension.
            _replace_folder(os.path.splitext(tar)[0], dst)
            # keep only the third "_"-separated part of each file name
            _rename_files(dst, lambda x: x.split("_")[2])

    _proc(train_url, os.path.join(root, "train_rgb"))
    _proc(test_url, os.path.join(root, "test_rgb"))
def download_seg(root: str):
    """
    Download and unpack the train and test 13-class segmentation labels.

    Creates ``<root>/train_seg13`` and ``<root>/test_seg13``. A split whose
    destination folder already exists is skipped, and an archive that is
    already present in ``root`` is not downloaded again.

    :param root: The root directory of the dataset
    """
    train_url = "https://github.com/ankurhanda/nyuv2-meta-data/raw/master/train_labels_13/nyuv2_train_class13.tgz"
    test_url = "https://github.com/ankurhanda/nyuv2-meta-data/raw/master/test_labels_13/nyuv2_test_class13.tgz"

    def _proc(url: str, dst: str):
        if os.path.exists(dst):
            return
        tar = os.path.join(root, url.split("/")[-1])
        if not os.path.exists(tar):
            download_url(url, root)
        if os.path.exists(tar):
            _unpack(tar)
            # The extracted folder is the archive path without its extension.
            # str.rstrip(".tgz") removes any trailing '.', 't', 'g' or 'z'
            # characters rather than the suffix; splitext drops exactly the
            # extension.
            _replace_folder(os.path.splitext(tar)[0], dst)
            # keep only the fourth "_"-separated part of each file name
            _rename_files(dst, lambda x: x.split("_")[3])

    _proc(train_url, os.path.join(root, "train_seg13"))
    _proc(test_url, os.path.join(root, "test_seg13"))
def download_sn(root: str):
    """
    Download and unpack the surface normal images.

    Creates ``<root>/train_sn`` and ``<root>/test_sn`` from the "train" and
    "test" folders of the unpacked archive. Nothing happens when both
    folders already exist; a folder that already exists is left untouched.
    The temporary ``nyu_normals_gt`` folder is removed afterwards.

    :param root: The root directory of the dataset
    """
    url = "https://www.inf.ethz.ch/personal/ladickyl/nyu_normals_gt.zip"
    destinations = (
        ("train", os.path.join(root, "train_sn")),
        ("test", os.path.join(root, "test_sn")),
    )

    # Both splits already in place: nothing to do.
    if all(os.path.exists(dst) for _, dst in destinations):
        return

    archive = os.path.join(root, url.split("/")[-1])
    if not os.path.exists(archive):
        download_url(url, root)
    # Still no archive after the download attempt: give up silently.
    if not os.path.exists(archive):
        return

    _unpack(archive)
    for split, dst in destinations:
        if os.path.exists(dst):
            continue
        _replace_folder(os.path.join(root, "nyu_normals_gt", split), dst)
        # drop the first character of every file name
        _rename_files(dst, lambda x: x[1:])
    shutil.rmtree(os.path.join(root, "nyu_normals_gt"))
def download_depth(root: str):
    """
    Download the labelled .mat file and extract the depth images from it.

    Nothing happens when both ``<root>/train_depth`` and
    ``<root>/test_depth`` already exist. The train/test split is taken from
    the file names found in ``<root>/train_rgb``, so that folder has to
    exist before the extraction runs.

    :param root: The root directory of the dataset
    """
    url = (
        "http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat"
    )
    train_dst = os.path.join(root, "train_depth")
    test_dst = os.path.join(root, "test_depth")

    if os.path.exists(train_dst) and os.path.exists(test_dst):
        return

    mat_path = os.path.join(root, url.split("/")[-1])
    if not os.path.exists(mat_path):
        download_url(url, root)
    # Still no file after the download attempt: give up silently.
    if not os.path.exists(mat_path):
        return

    # An image ID is the rgb file name up to its first dot.
    train_ids = []
    for rgb_name in os.listdir(os.path.join(root, "train_rgb")):
        train_ids.append(rgb_name.split(".")[0])
    _create_depth_files(mat_path, root, train_ids)
def _unpack(file: str):
"""
Unpacks tar and zip, does nothing for any other type
:param file: path of file
"""
path = file.rsplit(".", 1)[0]
if file.endswith(".tgz"):
tar = tarfile.open(file, "r:gz")
tar.extractall(path)
tar.close()
elif file.endswith(".zip"):
zip = zipfile.ZipFile(file, "r")
zip.extractall(path)
zip.close()
def _rename_files(folder: str, rename_func: callable):
"""
Renames all files inside a folder based on the passed rename function
:param folder: path to folder that contains files
:param rename_func: function renaming filename (not including path) str -> str
"""
imgs_old = os.listdir(folder)
imgs_new = [rename_func(file) for file in imgs_old]
for img_old, img_new in zip(imgs_old, imgs_new):
shutil.move(os.path.join(folder, img_old), os.path.join(folder, img_new))
def _replace_folder(src: str, dst: str):
"""
Rename src into dst, replacing/overwriting dst if it exists.
"""
if os.path.exists(dst):
shutil.rmtree(dst)
shutil.move(src, dst)
def _create_depth_files(mat_file: str, root: str, train_ids: list):
    """
    Extract the depth arrays from the mat file into images.

    Every depth map is scaled by 1e4, stored as a uint16 png and written to
    ``<root>/train_depth`` or ``<root>/test_depth``. The file name is the
    1-based position of the map, zero padded to four digits.

    :param mat_file: path to the official labelled dataset .mat file
    :param root: The root directory of the dataset
    :param train_ids: the IDs of the training images as string (for splitting)
    """
    # exist_ok: the caller also gets here when only one of the two folders
    # is missing, in which case a plain mkdir would raise FileExistsError.
    os.makedirs(os.path.join(root, "train_depth"), exist_ok=True)
    os.makedirs(os.path.join(root, "test_depth"), exist_ok=True)

    train_ids = set(train_ids)  # O(1) membership test per image

    # Context manager closes the HDF5 file once all images are written.
    with h5py.File(mat_file, "r") as mat:
        depths = mat["depths"]
        for i in range(len(depths)):
            # NOTE(review): uint16 holds at most 65535, so a depth above
            # 6.5535 would not fit after scaling. Confirm the value range
            # of the .mat file.
            img = (depths[i] * 1e4).astype(np.uint16).T
            id_ = str(i + 1).zfill(4)
            folder = "train" if id_ in train_ids else "test"
            save_path = os.path.join(root, f"{folder}_depth", id_ + ".png")
            Image.fromarray(img).save(save_path)