-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataloader.py
40 lines (34 loc) · 1.66 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
from torch.utils.data import Dataset
class OFFSpottingDataset(Dataset):
    """Test-time dataset that yields two 3-stream inputs per sample.

    For sample ``index`` it returns six tensors:

    - Channels 0, 1 and 2 of ``tensors[0][index]``.
    - Then channels 0, 1 and 2 of ``tensors[1][index]``.

    Each one is moved to ``device`` and reshaped to ``image_shape``.

    Args:
        tensors: Indexable sequence of tensors. Only ``tensors[0]`` and
            ``tensors[1]`` are read. Each ``tensors[k][index][c]`` must hold
            exactly as many elements as ``image_shape`` (784 for the
            default).
            NOTE(review): presumably both entries share the same first
            dimension. Only ``tensors[0]`` is used for ``len`` — confirm.
        device: Target device for every returned tensor.
            - Defaults to ``'cuda'``, matching the previously hard-coded
              behavior.
            - Pass ``'cpu'`` for CPU-only runs.
            - Pass ``None`` to skip the transfer and move batches in the
              evaluation loop instead.
        image_shape: Shape each stream is reshaped to. Defaults to
            ``(1, 28, 28)``.
    """

    def __init__(self, tensors, device='cuda', image_shape=(1, 28, 28)):
        self.tensors = tensors
        self.device = device
        self.image_shape = tuple(image_shape)

    def __len__(self):
        # Number of samples = size of the first dimension of tensors[0].
        return self.tensors[0].shape[0]

    def _stream(self, source, index, channel):
        """Return ``tensors[source][index][channel]`` on ``self.device``, reshaped."""
        x = self.tensors[source][index][channel]
        # Moving to CUDA inside __getitem__ initializes CUDA in DataLoader
        # worker processes when num_workers > 0. device=None avoids that.
        if self.device is not None:
            x = x.to(self.device)
        return torch.reshape(x, self.image_shape)

    def __getitem__(self, index):
        # Two 3-stream inputs:
        # (x1, x2, x3) come from tensors[0]; (x4, x5, x6) come from tensors[1].
        return tuple(
            self._stream(source, index, channel)
            for source in (0, 1)
            for channel in range(3)
        )
class OFFSpottingDatasetTrain(Dataset):
    """Training dataset: two 3-stream inputs plus two labels per sample.

    For sample ``index`` it returns eight items:

    - Channels 0, 1 and 2 of ``tensors[0][index]``.
    - Channels 0, 1 and 2 of ``tensors[1][index]``.
    - Finally ``tensors[2][index]`` and ``tensors[3][index]``.

    The six stream tensors are moved to ``device`` and reshaped to
    ``image_shape``. The two trailing label entries are returned as stored
    and are NOT moved to ``device``.

    Args:
        tensors: Indexable sequence of at least four tensors.
            ``tensors[0]`` and ``tensors[1]`` supply the streams. Each
            ``tensors[k][index][c]`` must hold exactly as many elements as
            ``image_shape`` (784 for the default).
            NOTE(review): presumably ``tensors[2]`` and ``tensors[3]`` are
            two label sets aligned with ``tensors[0]``. Only ``tensors[0]``
            is used for ``len`` — confirm.
        device: Target device for the six stream tensors.
            - Defaults to ``'cuda'``, matching the previously hard-coded
              behavior.
            - Pass ``'cpu'`` for CPU-only runs.
            - Pass ``None`` to skip the transfer and move batches in the
              training loop instead.
        image_shape: Shape each stream is reshaped to. Defaults to
            ``(1, 28, 28)``.
    """

    def __init__(self, tensors, device='cuda', image_shape=(1, 28, 28)):
        self.tensors = tensors
        self.device = device
        self.image_shape = tuple(image_shape)

    def __len__(self):
        # Number of samples = size of the first dimension of tensors[0].
        return self.tensors[0].shape[0]

    def _stream(self, source, index, channel):
        """Return ``tensors[source][index][channel]`` on ``self.device``, reshaped."""
        x = self.tensors[source][index][channel]
        # Moving to CUDA inside __getitem__ initializes CUDA in DataLoader
        # worker processes when num_workers > 0. device=None avoids that.
        if self.device is not None:
            x = x.to(self.device)
        return torch.reshape(x, self.image_shape)

    def __getitem__(self, index):
        # Two 3-stream inputs:
        # (x1, x2, x3) come from tensors[0]; (x4, x5, x6) come from tensors[1].
        streams = tuple(
            self._stream(source, index, channel)
            for source in (0, 1)
            for channel in range(3)
        )
        # Labels are returned as stored; they are not moved to self.device.
        y = self.tensors[2][index]
        y1 = self.tensors[3][index]
        return streams + (y, y1)