-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_cnn_lightning.py
123 lines (100 loc) · 4.42 KB
/
train_cnn_lightning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
An example of training script that implements Pytorch-Lightning
@Author: Francesco Picetti
"""
from argparse import ArgumentParser
import os
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import src
try:
import pytorch_lightning as pl
except ModuleNotFoundError:
raise ModuleNotFoundError("Please install Pytorch Lightning with: `pip install pytorch_lightning`")
class CNN(pl.LightningModule):
    def __init__(self, n_classes=10):
        """A standard convolutional classifier.

        Three conv/ReLU stages (the first two followed by 2x2 max-pooling)
        reduce a 3x32x32 CIFAR image to a 64x4x4 feature map, which two
        dense layers map to `n_classes` logits.
        """
        super().__init__()
        # Spatial sizes assume 32x32 input and no padding.
        self.conv1 = torch.nn.Conv2d(3, 32, 3)   # -> 32 x 30 x 30
        self.pool1 = torch.nn.MaxPool2d(2, 2)    # -> 32 x 15 x 15
        self.conv2 = torch.nn.Conv2d(32, 64, 3)  # -> 64 x 13 x 13
        self.pool2 = torch.nn.MaxPool2d(2, 2)    # -> 64 x  6 x  6 (floor)
        self.conv3 = torch.nn.Conv2d(64, 64, 3)  # -> 64 x  4 x  4
        self.dense1 = torch.nn.Linear(4*4*64, 64)
        self.dense2 = torch.nn.Linear(64, n_classes)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Map a batch of images to un-normalized class logits."""
        h = self.pool1(F.relu(self.conv1(x)))
        h = self.pool2(F.relu(self.conv2(h)))
        h = F.relu(self.conv3(h))
        h = h.flatten(start_dim=1)          # keep the batch dimension
        h = F.relu(self.dense1(h))
        return self.dense2(h)

    def configure_optimizers(self):
        """Adam plus a plateau scheduler driven by the validation loss."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=3, mode='min', verbose=True)
        return {
            "optimizer" : optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor" : "val_loss",
            },
        }

    def training_step(self, batch, batch_idx):
        """One optimization step: compute, log, and return the training loss."""
        inputs, targets = batch
        logits = self(inputs)
        loss = self.loss_fn(logits, targets)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """One validation step: compute and log the validation loss."""
        inputs, targets = batch
        logits = self(inputs)
        val_loss = self.loss_fn(logits, targets)
        self.log('val_loss', val_loss)
        return val_loss
def main():
    """Parse CLI arguments, build CIFAR-10 dataloaders, and train the CNN.

    Side effects: creates `--outpath`, writes `args.txt` there for
    reproducibility, downloads CIFAR-10 into ./data, and saves the best
    checkpoint via the ModelCheckpoint callback.
    """
    parser = ArgumentParser(description="An example of parsing arguments")
    parser.add_argument("--outpath", type=str, required=False,
                        default="./data/trained_models/lightning",
                        help="Results directory")
    parser.add_argument("--num_gpus", type=int, required=False, default=1,
                        help="Number of GPUs to use")
    parser.add_argument("--batch_size", type=int, required=False, default=32,
                        choices=[16, 32, 64],
                        help="Batch size")
    parser.add_argument("--epochs", type=int, required=False, default=10,
                        help="Max iterations number")
    parser.add_argument("--seed", type=int, required=False, default=42,
                        help="Global random seed (weights init, shuffling)")
    args = parser.parse_args()
    # `deterministic=True` below only forces deterministic kernels; without a
    # fixed seed the initial weights and shuffle order still change per run.
    pl.seed_everything(args.seed)
    # save args to outpath, for reproducibility
    os.makedirs(args.outpath, exist_ok=True)  # tolerate a pre-existing directory
    src.write_args(filename=os.path.join(args.outpath, "args.txt"),
                   args=args)
    # Transform to tensor and normalize to [0, 1]
    trans = transforms.Compose([
        transforms.ToTensor(),
    ])
    # Load training and validation set, initialize Dataloaders
    trainset = CIFAR10(root='./data', train=True, download=True, transform=trans)
    train_dataloader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)
    valset = CIFAR10(root='./data', train=False, download=True, transform=trans)
    val_dataloader = DataLoader(valset, batch_size=args.batch_size, shuffle=False, num_workers=2)
    # initialize a model
    cnn = CNN()
    # define callbacks
    early_stopping = pl.callbacks.EarlyStopping('val_loss', mode="min", patience=10)
    checkpoint = pl.callbacks.ModelCheckpoint(dirpath=args.outpath, filename="best_model")
    # initialize a trainer
    trainer = pl.Trainer(gpus=args.num_gpus,                 # how many GPUs to use...
                         auto_select_gpus=args.num_gpus != 0, # ... only if they are available
                         deterministic=True,                 # deterministic algorithms (paired with seed_everything above)
                         max_epochs=args.epochs,
                         progress_bar_refresh_rate=20,
                         callbacks=[early_stopping, checkpoint])
    # train the model
    trainer.fit(cnn, train_dataloader, val_dataloader)
    print("Done!")
if __name__ == '__main__':
main()