-
Notifications
You must be signed in to change notification settings - Fork 0
/
cifar10data.py
121 lines (99 loc) · 4.43 KB
/
cifar10data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# //////////////////////////////////////////////
# ////////////// DATASET: CIFAR10 //////////////
# //////////////////////////////////////////////
# ==============================================
# SETUP AND IMPORTS
# ==============================================
# import Libraries
import os
# torch imports
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, random_split
import torchvision
from torchvision import transforms
from torchvision.datasets import CIFAR10
import torch.nn.functional as F
import torchmetrics
# pl imports
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
# remaining imports
import wandb
from math import floor
# ==============================================
# CIFAR10_DATAMODULE
# ==============================================
class CIFAR10_DataModule(pl.LightningDataModule):
    '''
    DataModule to hold the CIFAR10 dataset. Accepts different transforms for train and test to
    allow for extrapolation experiments.

    Parameters
    ----------
    data_dir : str
        Directory where CIFAR10 will be downloaded or taken from.
    train_transform : callable or None
        Transform (e.g. a ``transforms.Compose``) applied to the training images. The same
        transform is also applied to the validation split, which is carved out of the
        CIFAR10 training set. ``None`` selects the default transform
        (``ToTensor`` followed by ``Normalize`` with mean/std 0.5 per channel).
    test_transform : callable, iterable of callables, or None
        Transform for the test dataset. A non-callable iterable (e.g. a list) of transforms
        builds one test dataset per transform, to evaluate on several differently
        transformed copies of the test set. ``None`` -- as the whole argument or as a list
        element -- selects the default transform.
    batch_size : int
        Batch size for all dataloaders.
    val_size : int
        Number of training images held out for validation. The default of 5000 reproduces
        the original 45000/5000 split.
    split_seed : int or None
        Seed for the train/validation split. ``setup`` runs separately in every process, so
        a fixed seed is what guarantees all processes draw the same split (otherwise, under
        multi-GPU training, validation images can end up in another process's training
        data). ``None`` uses torch's global RNG instead.
    num_workers : int
        Number of worker processes for every DataLoader (0 loads data in the main process).
    '''
    def __init__(self, data_dir='./', train_transform=None, test_transform=None, batch_size=128,
                 val_size=5000, split_seed=42, num_workers=0):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.train_transform = train_transform
        self.test_transform = test_transform
        self.val_size = val_size
        self.split_seed = split_seed
        self.num_workers = num_workers
        self.default_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

    def _or_default(self, transform):
        '''Return ``transform``, or the default transform when it is None.'''
        # `is None` rather than `or`: some valid callables (e.g. an empty nn.Sequential)
        # define __len__ and would be treated as falsy.
        return self.default_transform if transform is None else transform

    def _loader(self, dataset, shuffle=False):
        '''Build a DataLoader with the module-wide batch size and worker count.'''
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=self.num_workers)

    def prepare_data(self):
        '''called only once and on 1 GPU'''
        # download data (train/val and test sets)
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        '''
        Called on each GPU separately - stage defines if we are
        at fit, validate, test or predict step.

        Builds ``cifar_train``/``cifar_val`` for stages None, 'fit' and 'validate', and
        ``cifar_test`` (a single dataset or a list of datasets) for stages None and 'test'.

        Raises
        ------
        ValueError
            If ``val_size`` is negative or larger than the training set.
        '''
        # we set up only relevant datasets when stage is specified
        if stage in (None, 'fit', 'validate'):
            cifar_full = CIFAR10(self.data_dir, train=True,
                                 transform=self._or_default(self.train_transform))
            n_total = len(cifar_full)
            if not 0 <= self.val_size <= n_total:
                raise ValueError(
                    f'val_size must be between 0 and {n_total}, got {self.val_size}')
            # A dedicated seeded generator makes the split identical in every process.
            generator = (torch.default_generator if self.split_seed is None
                         else torch.Generator().manual_seed(self.split_seed))
            self.cifar_train, self.cifar_val = random_split(
                cifar_full, [n_total - self.val_size, self.val_size], generator=generator)
        if stage in (None, 'test'):
            # A single transform is any callable; a list/tuple of transforms is not callable.
            if self.test_transform is None or callable(self.test_transform):
                self.cifar_test = CIFAR10(self.data_dir,
                                          train=False,
                                          transform=self._or_default(self.test_transform))
            else:
                self.cifar_test = [CIFAR10(self.data_dir,
                                           train=False,
                                           transform=self._or_default(test_transform))
                                   for test_transform in self.test_transform]

    def train_dataloader(self):
        '''returns training dataloader (shuffled)'''
        return self._loader(self.cifar_train, shuffle=True)

    def val_dataloader(self):
        '''returns validation dataloader'''
        return self._loader(self.cifar_val)

    def test_dataloader(self):
        '''returns test dataloader, or a list of them when several test transforms were given'''
        if isinstance(self.cifar_test, list):
            return [self._loader(test_dataset) for test_dataset in self.cifar_test]
        return self._loader(self.cifar_test)