-
Notifications
You must be signed in to change notification settings - Fork 1
/
datasets.py
61 lines (52 loc) · 2.14 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import json
import h5py
import torch
from torch.utils.data import Dataset
class ClassifierDataset(Dataset):
    """Map-style dataset pairing precomputed image features with labels.

    For record ``i`` of the selected split, ``image_id`` is looked up in the
    feature map to get an index into the ``image_feature`` dataset of the
    HDF5 file at ``feature_path``. That entry is returned together with
    the record's label.

    Args:
        feature_path: Path to an HDF5 file that contains an
            ``image_feature`` dataset.
        label_path: Path to a JSON file whose top-level object is keyed by
            split name. Each value is a list of records read via the
            ``image_id`` and ``label`` keys.
        feature_map_path: Path to a JSON object mapping ``str(image_id)`` to
            an index (anything ``int()`` accepts) into ``image_feature``.
        split: Key selecting which split of ``label_path`` to load.
    """

    def __init__(self, feature_path, label_path, feature_map_path, split):
        self.feature_path = feature_path
        # The HDF5 file is opened lazily on first item access.
        # NOTE(review): presumably so each DataLoader worker opens its own
        # handle instead of inheriting one across fork; confirm.
        self._h5_file = None
        self.image_features = None
        # JSON is UTF-8 by specification; don't depend on the locale default.
        with open(label_path, 'r', encoding='utf-8') as fp:
            self.dataset = json.load(fp)[split]
        with open(feature_map_path, 'r', encoding='utf-8') as fp:
            self.feature_map = json.load(fp)

    def __getitem__(self, i):
        """Return ``{'image_id', 'inputs', 'label'}`` for record *i*.

        ``inputs`` is a 1-tuple holding the selected feature entry as a
        tensor. ``label`` is a 1-tuple holding the record's label as a
        float32 tensor.
        """
        # Explicit None check: truthiness of an h5py object reports handle
        # validity, and truthiness of a multi-element ndarray raises.
        if self.image_features is None:
            self._h5_file = h5py.File(self.feature_path, 'r')
            self.image_features = self._h5_file['image_feature']
        record = self.dataset[i]
        image_id = record['image_id']
        feature_id = int(self.feature_map[str(image_id)])
        image_feature = self.image_features[feature_id]  # type: ignore
        return {
            'image_id': image_id,
            'inputs': (torch.tensor(image_feature), ),
            'label': (torch.tensor(record['label'], dtype=torch.float32), ),
        }

    def __len__(self):
        """Return the number of records in the selected split."""
        return len(self.dataset)
class LSTMDataset(Dataset):
    """Map-style dataset pairing precomputed image features with sequences.

    For record ``i`` of the selected split, ``image_id`` is looked up in the
    feature map to get an index into the ``image_feature`` dataset of the
    HDF5 file at ``feature_path``. That entry is returned along with the
    record's ``seq`` and ``seq_length``.

    Args:
        feature_path: Path to an HDF5 file that contains an
            ``image_feature`` dataset.
        label_path: Path to a JSON file whose top-level object is keyed by
            split name. Each value is a list of records read via the
            ``image_id``, ``seq`` and ``seq_length`` keys.
        feature_map_path: Path to a JSON object mapping ``str(image_id)`` to
            an index (anything ``int()`` accepts) into ``image_feature``.
        split: Key selecting which split of ``label_path`` to load.
    """

    def __init__(self, feature_path, label_path, feature_map_path, split):
        self.feature_path = feature_path
        # The HDF5 file is opened lazily on first item access.
        # NOTE(review): presumably so each DataLoader worker opens its own
        # handle instead of inheriting one across fork; confirm.
        self._h5_file = None
        self.image_features = None
        # JSON is UTF-8 by specification; don't depend on the locale default.
        with open(label_path, 'r', encoding='utf-8') as fp:
            self.dataset = json.load(fp)[split]
        with open(feature_map_path, 'r', encoding='utf-8') as fp:
            self.feature_map = json.load(fp)

    def __getitem__(self, i):
        """Return ``{'image_id', 'inputs', 'label'}`` for record *i*.

        ``inputs`` is ``(image_feature, seq, seq_length)`` and ``label`` is
        ``(seq, seq_length)``, all as tensors. The two tuples share the same
        ``seq`` / ``seq_length`` tensor objects; they are not copies.
        """
        # Explicit None check: truthiness of an h5py object reports handle
        # validity, and truthiness of a multi-element ndarray raises.
        if self.image_features is None:
            self._h5_file = h5py.File(self.feature_path, 'r')
            self.image_features = self._h5_file['image_feature']
        record = self.dataset[i]
        image_id = record['image_id']
        feature_id = int(self.feature_map[str(image_id)])
        image_feature = torch.tensor(
            self.image_features[feature_id])  # type: ignore
        seq = torch.tensor(record['seq'])
        seq_length = torch.tensor(record['seq_length'])
        return {
            'image_id': image_id,
            'inputs': (image_feature, seq, seq_length),
            'label': (seq, seq_length),
        }

    def __len__(self):
        """Return the number of records in the selected split."""
        return len(self.dataset)