-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_iter.py
41 lines (32 loc) · 1.66 KB
/
data_iter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
class Real_Dataset(Dataset):
    """Map-style dataset over the first axis of a single NumPy file.

    Item ``idx`` is ``data[idx]`` converted to a ``torch.long`` tensor.
    """

    def __init__(self, filepath):
        # np.load is called with its default allow_pickle=False, so object
        # (pickled) arrays are rejected here.
        self.data = np.load(filepath)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        # from_numpy wraps the array without copying; .long() then casts
        # (and copies only if the dtype is not already int64).
        as_tensor = torch.from_numpy(row)
        return as_tensor.long()
class Dis_Dataset(Dataset):
    """Labelled dataset built from a "positive" and a "negative" NumPy file.

    Every first-axis entry of ``positive_filepath`` gets label 1 and every
    entry of ``negative_filepath`` gets label 0. Positives come first in
    index order, followed by negatives.

    Each item is a dict ``{"data": LongTensor, "label": LongTensor}``, where
    ``"label"`` has shape ``(1,)``.
    """

    def __init__(self, positive_filepath, negative_filepath):
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code -- only load trusted files.
        pos_data = np.load(positive_filepath, allow_pickle=True)
        neg_data = np.load(negative_filepath, allow_pickle=True)
        # Use an explicit integer dtype: building labels from an empty list
        # would give a float64 array and upcast the concatenated labels.
        pos_label = np.ones(len(pos_data), dtype=np.int64)
        neg_label = np.zeros(len(neg_data), dtype=np.int64)
        self.data = np.concatenate([pos_data, neg_data])
        self.label = np.concatenate([pos_label, neg_label])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = torch.from_numpy(self.data[idx]).long()
        # One-element long tensor holding the 0/1 label (shape (1,)).
        label = torch.tensor([int(self.label[idx])], dtype=torch.long)
        return {"data": data, "label": label}
def real_data_loader(filepath, batch_size, shuffle, num_workers, pin_memory):
    """Wrap a ``Real_Dataset`` for ``filepath`` in a ``DataLoader``.

    The batching, shuffling, worker and pinning options are forwarded
    unchanged to ``DataLoader``.
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    return DataLoader(dataset=Real_Dataset(filepath), **loader_options)
def dis_data_loader(positive_filepath, negative_filepath, batch_size, shuffle, num_workers, pin_memory):
    """Wrap a ``Dis_Dataset`` built from the two files in a ``DataLoader``.

    The batching, shuffling, worker and pinning options are forwarded
    unchanged to ``DataLoader``.
    """
    labelled = Dis_Dataset(positive_filepath, negative_filepath)
    return DataLoader(
        dataset=labelled,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )