-
Notifications
You must be signed in to change notification settings - Fork 31
/
Copy pathdata_loader.py
119 lines (104 loc) · 3.72 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import numpy as np
import torch
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
def get_train_valid_loader(data_dir,
                           name,
                           batch_size,
                           valid_size=0.1,
                           shuffle=True,
                           num_workers=4,
                           pin_memory=False,
                           random_seed=786427186):
    """
    Utility function for loading and returning train and valid
    multi-process iterators over the desired dataset.

    Params
    ------
    - data_dir: path directory to the dataset.
    - name: string specifying which dataset to load. Can be `mnist`,
      `cifar10`, `cifar100`.
    - batch_size: how many samples per batch to load.
    - valid_size: percentage split of the training set used for
      the validation set. Should be a float in the range [0, 1].
      In the paper, this number is set to 0.1.
    - shuffle: whether to shuffle the train/validation indices.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.
    - random_seed: seed used to make the train/valid shuffle reproducible
      when `shuffle` is True. Defaults to the value previously hard-coded
      in this function, so existing callers get identical splits.

    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.

    Raises
    ------
    - ValueError: if `valid_size` is outside [0, 1] or `name` is not a
      supported dataset. (Raised, not asserted, so validation survives
      running Python with the -O flag.)
    """
    if not (0 <= valid_size <= 1):
        raise ValueError("[!] valid_size should be in the range [0, 1].")
    if name not in ('mnist', 'cifar10', 'cifar100'):
        raise ValueError("[!] Invalid dataset name.")

    # define transforms: MNIST is single-channel, the CIFAR variants use
    # the standard 3-channel ImageNet normalization statistics
    if name == 'mnist':
        normalize = transforms.Normalize(
            mean=(0.1307,),
            std=(0.3081,),
        )
    else:
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
    # train and valid currently share the same pipeline; two separate
    # dataset objects are still created below so that a train-only
    # augmentation can be added later without touching the valid set
    trans = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    # load the dataset — dispatch on name instead of duplicating the
    # constructor call per branch (name was validated above)
    dataset_cls = {
        'mnist': datasets.MNIST,
        'cifar10': datasets.CIFAR10,
        'cifar100': datasets.CIFAR100,
    }[name]
    train_dataset = dataset_cls(
        root=data_dir, train=True,
        download=True, transform=trans,
    )
    valid_dataset = dataset_cls(
        root=data_dir, train=True,
        download=True, transform=trans,
    )

    # create dataloaders: carve the first `split` (shuffled) indices off
    # the training set for validation
    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))
    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    return (train_loader, valid_loader)