-
Notifications
You must be signed in to change notification settings - Fork 27
/
mnist.py
55 lines (45 loc) · 1.68 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torchvision.datasets as datasets
from torch.utils.data import SubsetRandomSampler, DataLoader
from torchvision import transforms
import torch
import params
# Convert PIL images to tensors, then normalize with mean 0.1307 / std 0.3081
# (the values conventionally used for MNIST).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# The validation set is carved out of the official *training* split, so both
# the train and valid datasets load train=True; the samplers below keep the
# two index sets disjoint. The test set is the official held-out split.
mnist_train_dataset = datasets.MNIST(root='../data/MNIST', train=True, download=True,
                                     transform=transform)
mnist_valid_dataset = datasets.MNIST(root='../data/MNIST', train=True, download=True,
                                     transform=transform)
mnist_test_dataset = datasets.MNIST(root='../data/MNIST', train=False, download=True,
                                    transform=transform)

# Hold out the first `validation_size` training images for validation; the
# rest are used for training.
# NOTE(review): the split is by position, not shuffled, so it is the same
# every run; shuffle `indices` first if a random split is wanted.
indices = list(range(len(mnist_train_dataset)))
validation_size = 5000
train_idx, valid_idx = indices[validation_size:], indices[:validation_size]

# Each sampler draws its own index subset in random order every epoch.
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

mnist_train_loader = DataLoader(
    mnist_train_dataset,
    batch_size=params.batch_size,
    sampler=train_sampler,
    num_workers=params.num_workers,
)

# Must use valid_sampler: with train_sampler the "validation" loader would
# iterate the training indices and never see the held-out images.
mnist_valid_loader = DataLoader(
    mnist_valid_dataset,
    batch_size=params.batch_size,
    sampler=valid_sampler,
    num_workers=params.num_workers,
)

# No sampler: the test set is iterated in its stored order (shuffle=False).
mnist_test_loader = DataLoader(
    mnist_test_dataset,
    batch_size=params.batch_size,
    num_workers=params.num_workers,
)
def one_hot_embedding(labels, num_classes=10):
    """Embed integer class labels in one-hot form.

    Args:
        labels: (LongTensor) class labels, sized [N,]. Any index accepted by
            tensor advanced indexing (e.g. a list of ints) also works.
        num_classes: (int) number of classes; each label must be in
            ``[0, num_classes)``.

    Returns:
        (FloatTensor) encoded labels, sized [N, num_classes], on the same
        device as ``labels`` when ``labels`` is a tensor (CPU otherwise).
    """
    # Build the identity on the labels' device: indexing a CPU tensor with
    # CUDA indices raises a device-mismatch error. Non-tensor inputs (lists)
    # have no `.device`, so fall back to the default (CPU) device.
    device = getattr(labels, "device", None)
    identity = torch.eye(num_classes, device=device)
    # Row i of the identity is the one-hot vector for class i.
    return identity[labels]