# train_vgg16.py — train a VGG16 model on MNIST / Fashion-MNIST.
# Import libraries and packages
# Using PyTorch
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import matplotlib.pyplot as plt # plot images and graphs
import argparse # for parsing arguments
import time
import os
# VGG16 model
from AElib.VGG import VGG16
# Let's train the model
def train(device, model, loaders, optimizer, criterion, epochs=10, save_param=True, dataset="mnist"):
    """Train and evaluate ``model`` using the "train"/"test" loaders.

    Args:
        device: torch.device to run on ("cuda" or "cpu").
        model: network to train; moved to ``device`` internally.
        loaders: dict with "train" and "test" DataLoader entries.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss taking (logits, long labels), e.g. CrossEntropyLoss.
        epochs: number of passes over the training set.
        save_param: when True, checkpoint weights at the best test accuracy.
        dataset: tag used in the checkpoint/plot filenames.

    Side effects: may write ./models/vgg16_<dataset>_model.pth, and always
    writes loss/accuracy curves to ./images/ on exit (even after Ctrl-C,
    via the finally block).
    """
    # Create output directories up front so torch.save / plt.savefig
    # don't fail on a fresh checkout (the script only creates ./data).
    os.makedirs("./models", exist_ok=True)
    os.makedirs("./images", exist_ok=True)
    # Defined before the try so the finally block can always plot,
    # even if an exception is raised before the first epoch completes.
    history_loss = {"train": [], "test": []}
    history_accuracy = {"train": [], "test": []}
    try:
        model = model.to(device)  # Load model to CUDA (or CPU fallback)
        best_test_accuracy = 0  # variable to store the best test accuracy
        start_time = time.time()
        for epoch in range(epochs):
            sum_loss = {"train": 0, "test": 0}
            sum_accuracy = {"train": 0, "test": 0}
            for split in ["train", "test"]:
                if split == "train":
                    model.train()
                else:
                    model.eval()
                # Disable autograd during evaluation: no graph is needed and
                # tracking it wastes memory/compute on the test split.
                grad_mode = torch.enable_grad() if split == "train" else torch.no_grad()
                with grad_mode:
                    for (inputs, labels) in loaders[split]:
                        inputs = inputs.to(device)
                        labels = labels.to(device).long()
                        prediction = model(inputs)
                        loss = criterion(prediction, labels)
                        sum_loss[split] += loss.item()  # Update loss
                        if split == "train":
                            optimizer.zero_grad()  # Reset gradients (train only)
                            loss.backward()  # Compute gradients
                            optimizer.step()  # Optimize
                        _, pred_label = torch.max(prediction, dim=1)
                        pred_labels = (pred_label == labels).float()
                        batch_accuracy = pred_labels.sum().item() / inputs.size(0)
                        sum_accuracy[split] += batch_accuracy  # Update accuracy
            # Compute epoch loss/accuracy (mean over batches in each split)
            epoch_loss = {split: sum_loss[split] / len(loaders[split]) for split in ["train", "test"]}
            epoch_accuracy = {split: sum_accuracy[split] / len(loaders[split]) for split in ["train", "test"]}
            # Store params at the best test accuracy
            if save_param and epoch_accuracy["test"] > best_test_accuracy:
                torch.save(model.state_dict(), f"./models/vgg16_{dataset}_model.pth")
                best_test_accuracy = epoch_accuracy["test"]
            # Update history
            for split in ["train", "test"]:
                history_loss[split].append(epoch_loss[split])
                history_accuracy[split].append(epoch_accuracy[split])
            print(f"Epoch: [{epoch + 1} | {epochs}]\nTrain Loss: {epoch_loss['train']:.4f}, Train Accuracy: {epoch_accuracy['train']:.2f}, \
Test Loss: {epoch_loss['test']:.4f}, Test Accuracy: {epoch_accuracy['test']:.2f}, Time Taken: {(time.time() - start_time) / 60:.2f} mins")
    except KeyboardInterrupt:
        print("Interrupted")
    finally:
        # Plot loss
        plt.title("Loss")
        for split in ["train", "test"]:
            plt.plot(history_loss[split], label=split)
        plt.legend()
        plt.savefig(f"./images/vgg16_{dataset}_loss.png")
        plt.close()
        # Plot accuracy
        plt.title("Accuracy")
        for split in ["train", "test"]:
            plt.plot(history_accuracy[split], label=split)
        plt.legend()
        plt.savefig(f"./images/vgg16_{dataset}_accuracy.png")
        plt.close()
if __name__ == '__main__':
    # Command-line interface: dataset choice, learning rate, epoch count.
    arg_parser = argparse.ArgumentParser(
        prog='Training VGG16 Model',
        description='''VGG16 is object detection and classification algorithm which is able to classify 1000 images of 1000 different categories with 92.7% accuracy. It is one of the popular algorithms for image classification and is easy to use with transfer learning.
This script is used to train a VGG16 model on MNIST and Fashion MNIST datasets.
'''
    )
    arg_parser.add_argument('--dataset', type=str, choices=['mnist', 'fashion-mnist'], default='mnist', help='dataset to use')
    arg_parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    arg_parser.add_argument('--epochs', type=int, default=10, help='number of epochs')
    args = vars(arg_parser.parse_args())

    # Prefer the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model (1-channel 32x32 input), optimizer and loss function.
    model = VGG16((1,32,32), batch_norm=True)
    optimizer = optim.SGD(model.parameters(), lr=args['lr'])
    criterion = nn.CrossEntropyLoss()

    # Input pipeline: resize the 28x28 images to 32x32 and tensorize.
    transform=transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
    ])

    # Make sure the download target directory exists.
    if not os.path.exists('./data'):
        os.makedirs('./data')
        print('Created a data directory ...')

    # Resolve the torchvision dataset class from the CLI choice
    # (argparse `choices` guarantees the key is present).
    dataset_cls = {
        'mnist': torchvision.datasets.MNIST,
        'fashion-mnist': torchvision.datasets.FashionMNIST,
    }[args['dataset']]
    train_set = dataset_cls(root='./data', train=True, download=True, transform=transform)
    test_set = dataset_cls(root='./data', train=False, download=True, transform=transform)

    # Dictionary of data loaders consumed by train().
    loaders = {
        "train": torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True),
        "test": torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False),
    }
    train(device, model, loaders, optimizer, criterion, args['epochs'], dataset=args['dataset'])