-
Notifications
You must be signed in to change notification settings - Fork 8
/
data_loader.py
21 lines (19 loc) · 996 Bytes
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
#!/usr/bin/python
# A simple data loader that imports the train and test mat files
# from the `filename` and converts them to torch tensors (float32)
# to be loaded for training and testing DLoc network
# `features_wo_offset`: targets for the consistency decoder
# `features_w_offset` : inputs for the network/encoder
# `labels_gaussian_2d`: targets for the location decoder
import torch
import h5py
import scipy.io
import numpy as np
def load_data(filename):
    """Load the DLoc feature and label arrays from an HDF5 file as tensors.

    The file is opened read-only with h5py (per the module header these are
    the train/test .mat files; h5py can only read them if they are stored in
    an HDF5-based format). Each dataset is read fully into memory, cast to
    float32, has its axes reversed with ``np.transpose`` and is copied into a
    ``torch.float32`` tensor.
    NOTE(review): the axis reversal presumably undoes the reversed dimension
    order h5py reports for MATLAB-written arrays -- confirm against the
    script that produces the files.

    Args:
        filename: Path of the file to read. It is concatenated into the
            progress message, so it must be a ``str``.

    Returns:
        A tuple ``(features_wo_offset, features_w_offset,
        labels_gaussian_2d)`` of ``torch.float32`` tensors, in that order.

    Raises:
        KeyError: If one of the three expected datasets is missing from the
            file (previously this silently produced a 0-d NaN tensor).
        OSError: Propagated from h5py if the file cannot be opened.
    """
    print('Loading '+filename)

    def _read_tensor(handle, name):
        # Fail loudly on a missing dataset instead of converting None to NaN.
        if name not in handle:
            raise KeyError(
                "dataset '{}' not found in '{}'".format(name, filename)
            )
        # np.array() pulls the whole dataset into memory, so the result
        # stays valid after the file is closed.
        data = np.array(handle[name], dtype=np.float32)
        return torch.tensor(np.transpose(data), dtype=torch.float32)

    # The context manager guarantees the file handle is released even if a
    # read or conversion raises.
    with h5py.File(filename, 'r') as f:
        features_wo_offset = _read_tensor(f, 'features_wo_offset')
        features_w_offset = _read_tensor(f, 'features_w_offset')
        labels_gaussian_2d = _read_tensor(f, 'labels_gaussian_2d')

    return features_wo_offset, features_w_offset, labels_gaussian_2d