-
Notifications
You must be signed in to change notification settings - Fork 15
/
NILM_Dataloader.py
25 lines (21 loc) · 999 Bytes
/
NILM_Dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch.utils.data as data_utils
class NILMDataloader:
    """Wrap a dataset parser's train/validation splits in PyTorch DataLoaders.

    The two splits are requested from ``ds_parser`` once, at construction
    time. :meth:`get_dataloaders` then wraps each split in a
    ``torch.utils.data.DataLoader`` that uses ``args.batch_size``.
    """

    def __init__(self, args, ds_parser, pretrain=False):
        """Fetch the train and validation datasets from ``ds_parser``.

        Args:
            args: Configuration object. Its ``mask_prob`` and ``batch_size``
                attributes are read here, and both are read even when
                ``pretrain`` is False. The object is also kept as
                ``self.args``.
            ds_parser: Object that provides ``get_pretrain_datasets(mask_prob=...)``
                and ``get_train_datasets()``. Each must return a
                ``(train_dataset, val_dataset)`` pair.
            pretrain: If True, request the pre-training datasets and pass
                ``mask_prob`` to the parser. Otherwise request the regular
                training datasets.
        """
        self.args = args
        self.mask_prob = args.mask_prob
        self.batch_size = args.batch_size
        if pretrain:
            # NOTE(review): mask_prob is presumably the masking probability
            # the parser applies to pre-training samples; confirm in the parser.
            self.train_dataset, self.val_dataset = ds_parser.get_pretrain_datasets(
                mask_prob=self.mask_prob
            )
        else:
            self.train_dataset, self.val_dataset = ds_parser.get_train_datasets()

    def get_dataloaders(self, *, shuffle_train=False):
        """Return ``(train_loader, val_loader)`` for the stored datasets.

        Args:
            shuffle_train: If True, the training loader draws samples in
                random order. The default False keeps the original sequential
                order. The validation loader is never shuffled.

        Returns:
            A ``(train_loader, val_loader)`` tuple of ``DataLoader`` objects.
        """
        train_loader = self._get_loader(self.train_dataset, shuffle=shuffle_train)
        val_loader = self._get_loader(self.val_dataset)
        return train_loader, val_loader

    def _get_loader(self, dataset, shuffle=False):
        """Wrap ``dataset`` in a ``DataLoader`` with the configured batch size.

        Args:
            dataset: Map-style dataset to iterate over.
            shuffle: Whether to draw samples in random order. The default
                False keeps the original sequential behaviour. This is
                presumably deliberate for time-ordered windows; confirm
                before changing.

        Returns:
            A ``torch.utils.data.DataLoader`` with ``pin_memory=True``, so
            batches are copied into page-locked host memory to speed up
            host-to-GPU transfer.
        """
        return data_utils.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            pin_memory=True,
        )