-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataloading_example.py
51 lines (44 loc) · 2.01 KB
/
dataloading_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from pathlib import Path
import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from dataset.provider import DatasetProvider
from dataset.visualization import disp_img_to_rgb_img, show_disp_overlay, show_image
def build_arg_parser():
    """Build the command-line parser for this example script.

    Options:
        --dsec_dir: path to the DSEC dataset directory.
        --visualize / --no_visualize: enable/disable visualization (default: on).
        --overlay / --no_overlay: when visualizing, overlay disparity on the
            event image instead of showing disparity alone (default: on).

    ``--visualize`` and ``--overlay`` are kept for backward compatibility; since
    they default to True they are effectively no-ops, which is why the
    ``--no_*`` switches exist to actually turn the features off.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--dsec_dir', default="/home/siyuan/workspace/CVbyDL/",
                        help='Path to DSEC dataset directory')
    parser.add_argument('--visualize', dest='visualize', action='store_true',
                        help='Visualize data (default)')
    parser.add_argument('--no_visualize', dest='visualize', action='store_false',
                        help='Disable visualization')
    parser.add_argument('--overlay', dest='overlay', action='store_true',
                        help='If visualizing, overlay disparity and voxel grid image (default)')
    parser.add_argument('--no_overlay', dest='overlay', action='store_false',
                        help='If visualizing, show only the disparity image')
    parser.set_defaults(visualize=True, overlay=True)
    return parser


def voxel_grid_to_event_image(voxel_grid):
    """Collapse a voxel grid along its first dimension into an 8-bit image.

    The grid is summed over dim 0 and scaled so that the largest sum maps to
    255. Values below zero are clipped to 0. If no value is positive, no
    scaling is applied, which avoids dividing by zero, and the result is all
    zeros.

    Args:
        voxel_grid: CPU tensor whose first dimension is summed out (in the loop
            below it is the squeezed left representation, e.g. [15, H, W]).

    Returns:
        numpy ``uint8`` array with the remaining dimensions.
    """
    ev_img = torch.sum(voxel_grid, dim=0).numpy()
    peak = ev_img.max()
    if peak > 0:
        # 255, not 256: 256 does not fit in uint8 and the brightest pixel would wrap.
        ev_img = ev_img / peak * 255.0
    # Clip before casting: casting negative/out-of-range floats to uint8 wraps.
    return np.clip(ev_img, 0, 255).astype(np.uint8)


def main():
    """Iterate over the DSEC training split, optionally visualizing each sample."""
    parser = build_arg_parser()
    args = parser.parse_args()

    dsec_dir = Path(args.dsec_dir)
    # Not an assert: asserts are stripped under `python -O`.
    if not dsec_dir.is_dir():
        parser.error(f"--dsec_dir is not a directory: {dsec_dir}")

    dataset_provider = DatasetProvider(dsec_dir)
    train_dataset = dataset_provider.get_train_dataset()

    batch_size = 1
    num_workers = 0
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=False)

    # Visualization handles a single sample at a time; decide once, outside the loop.
    visualize = args.visualize and batch_size == 1

    with torch.no_grad():
        for data in tqdm(train_loader):
            # data = {'disparity_gt', 'file_index', 'representation'{'left', 'right'}}
            # data['representation']['left'].shape = [1, 15, 480, 640]
            # data['representation']['right'].shape = [1, 15, 480, 640]
            # data['disparity_gt'].shape = [1, 480, 640]
            if not visualize:
                continue
            disp = data['disparity_gt'].numpy().squeeze()
            disp_img = disp_img_to_rgb_img(disp)
            if args.overlay:
                ev_img = voxel_grid_to_event_image(data['representation']['left'].squeeze())
                height, width = disp.shape[-2:]
                show_disp_overlay(ev_img, disp_img, height=height, width=width)
            else:
                show_image(disp_img)


if __name__ == "__main__":
    main()