sparse_kl.py
import matplotlib.pyplot as plt # for plotting images
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import time
import os
from tqdm import tqdm # for showing progress bars during training
import loader # module for dataset loading
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed = 42 # fixed seed so every run produces the same results
torch.manual_seed(seed)
batch_size = 1200 # This gives 50 batches per epoch as the train set is 60k images
epochs = 20
learning_rate = 9e-3
model_file = 'sparse_kl.pth' # Path where the model is saved/loaded
rho = 0.08 # target mean activation used in the KL sparsity penalty
beta = 0.005 # weight of the KL sparsity penalty in the total loss
class AE(nn.Module):
    '''
    Convolutional autoencoder for 28x28 single-channel images.
    '''
def __init__(self):
super().__init__()
        # The encoder network
        self.enc1 = nn.Conv2d(1, 8, 3, stride=1)
        self.enc2 = nn.Conv2d(8, 16, 5, stride=1)
        self.enc3 = nn.Conv2d(16, 32, 5, stride=2)
        self.bottle = nn.MaxPool2d(2, 2)
        # The decoder network
        self.dec1 = nn.ConvTranspose2d(32, 16, 3, stride=2)
        self.dec2 = nn.ConvTranspose2d(16, 8, 5, stride=2)
        self.dec3 = nn.ConvTranspose2d(8, 1, 8, stride=1)
def forward(self, features):
x = F.relu(self.enc1(features.float()))
x = F.relu(self.enc2(x))
x = F.relu(self.enc3(x))
x = self.bottle(x)
x = F.relu(self.dec1(x))
x = F.relu(self.dec2(x))
x = torch.sigmoid(self.dec3(x))
return x
def encode(self, features):
# Encodes the input to a smaller size
x = F.relu(self.enc1(features.float()))
x = F.relu(self.enc2(x))
x = F.relu(self.enc3(x))
x = self.bottle(x)
return x
def decode(self, features):
# Decodes the given input back to its original size
x = F.relu(self.dec1(features.float()))
x = F.relu(self.dec2(x))
x = torch.sigmoid(self.dec3(x))
return x
def forward_with_layers(self, features):
'''
Used during training, returns activations of hidden layers as a list along with output.
Takes one parameter:
features: The input to the neural network
'''
ret_act = []
a1 = F.relu(self.enc1(features.float()))
ret_act.append(a1)
a2 = F.relu(self.enc2(a1))
ret_act.append(a2)
a3 = F.relu(self.enc3(a2))
ret_act.append(a3)
a4 = self.bottle(a3)
ret_act.append(a4)
a5 = F.relu(self.dec1(a4))
ret_act.append(a5)
a6 = F.relu(self.dec2(a5))
ret_act.append(a6)
ans = torch.sigmoid(self.dec3(a6))
return ans, ret_act
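# A minimal shape sanity check for the encoder/decoder above (a sketch, assuming a
# 1x28x28 single-channel input such as MNIST; `_check_shapes` is a hypothetical
# helper added for illustration and is not part of the training flow):
#   enc1:   1x28x28  -> 8x26x26    (3x3 conv, stride 1)
#   enc2:   8x26x26  -> 16x22x22   (5x5 conv, stride 1)
#   enc3:   16x22x22 -> 32x9x9     (5x5 conv, stride 2)
#   bottle: 32x9x9   -> 32x4x4     (2x2 max pool)
#   dec1:   32x4x4   -> 16x9x9     (3x3 transposed conv, stride 2)
#   dec2:   16x9x9   -> 8x21x21    (5x5 transposed conv, stride 2)
#   dec3:   8x21x21  -> 1x28x28    (8x8 transposed conv, stride 1)
def _check_shapes():
    model = AE()
    x = torch.zeros(1, 1, 28, 28)
    print(model.encode(x).shape)  # expected: torch.Size([1, 32, 4, 4])
    print(model(x).shape)         # expected: torch.Size([1, 1, 28, 28])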
def clrscr(): # clears the terminal screen between epoch updates
if os.name == "posix":
# Unix/Linux/MacOS/BSD/etc
os.system('clear')
elif os.name in ("nt", "dos", "ce"):
# DOS/Windows
os.system('cls')
def kl_loss(layers):
    '''
    Calculates the Kullback-Leibler divergence sparsity penalty over the hidden layers.
    Takes one parameter:
    layers: Hidden-layer activations passed as a list of torch tensors.
    '''
    loss = 0
    for i in layers:
        # Note: F.kl_div expects log-probabilities as its first argument; here the batch-mean
        # activation is passed in directly, so abs() below keeps the penalty non-negative.
        loss += F.kl_div(torch.mean(i, dim=0), rho*torch.ones_like(i[0]))
    return beta*torch.abs(loss)
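# For comparison only (not the implementation used above): the textbook sparse-autoencoder
# penalty treats the target sparsity rho and the batch-mean activation rho_hat of each unit
# as Bernoulli parameters and sums KL(rho || rho_hat) elementwise. A sketch of that variant,
# assuming mean activations clamped into (0, 1); `kl_loss_bernoulli` is hypothetical and is
# not called anywhere in this script:
def kl_loss_bernoulli(layers, eps=1e-8):
    loss = 0
    for i in layers:
        rho_hat = torch.mean(i, dim=0).clamp(eps, 1 - eps)
        kl = rho * torch.log(rho / rho_hat) + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat))
        loss += kl.sum()
    return beta * loss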
def train_model(model, optimizer, criterion):
'''
This function trains the Neural Network. Parameters are:
model: The neural network model object,
optimizer: The optimizer object to be used during training,
criterion: The loss function object to be used during training
'''
train_loader = loader.train_loader_fn(batch_size) # Loads the training dataset
loss_list = [] # Stores loss after every epoch
for epoch in tqdm(range(epochs)): # Looping for every epoch
loss = 0
for batch_features, _ in tqdm(train_loader): # Looping for every batch
batch_features = batch_features.to(device)
optimizer.zero_grad() # Model training starts here
outputs, layers = model.forward_with_layers(batch_features)
            train_loss = criterion(outputs, batch_features) + kl_loss(layers) # reconstruction loss + sparsity penalty
train_loss.backward()
optimizer.step()
loss += train_loss.item()
loss = loss / len(train_loader)
loss_list.append(loss) # Stores loss in the loss_list list
clrscr() # clears the screen
print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, epochs, loss))
return loss_list
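# In symbols, each optimizer step above minimizes (a restatement of the code, not an
# extra term):
#   total_loss = BCE(output, input) + kl_loss(hidden activations)
# where kl_loss already includes the beta weighting and the absolute value.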
def main():
    # The main function; runs when this file is executed as a script
model = AE().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.BCELoss()
train_need = input("Press l to load model, t to train model, tl to load and train model: ").lower()
# Asks user whether to load saved model or train from scratch
if train_need == 't':
loss_list = train_model(model, optimizer, criterion)
    elif train_need == 'l':
        model.load_state_dict(torch.load(model_file, map_location=device))
    elif train_need == 'tl':
        model.load_state_dict(torch.load(model_file, map_location=device))
        loss_list = train_model(model, optimizer, criterion)
test_loader = loader.test_loader_fn(batch_size) # loads the testing dataset
test_examples = None
with torch.no_grad():
for batch_features in test_loader: # Test examples are passed through the model for testing
batch_features = batch_features[0]
test_examples = batch_features.to(device)
reconstruction = model(test_examples)
test_loss = nn.functional.binary_cross_entropy(reconstruction, test_examples)
print("Test Loss is: ", test_loss.item())
break
    try: # Plots the training loss curve if training was done, else skips it
        plt.xlabel('Number of epochs')
        plt.ylabel('Loss Value')
        plt.plot(range(len(loss_list)), loss_list)
    except NameError: # loss_list does not exist when the model was only loaded
        pass
with torch.no_grad():
number = 10
plt.figure(figsize=(20, 4))
# Below code plots 10 input images and their output
for index in range(number):
ax = plt.subplot(2, number, index + 1)
plt.imshow(test_examples[index].cpu().numpy().reshape(28, 28))
# plotting the ith test image in subplot
plt.gray() # changing color code to grayscale
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False) # removing axes
            ax = plt.subplot(2, number, index + 1 + number)
            plt.imshow(reconstruction[index].cpu().numpy().reshape(28, 28))
            # plotting the ith reconstructed image in the second row
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()
if train_need == 't' or train_need == 'tl':
# If the model was trained, it asks whether or not to save the model
        save_status = input("Enter s to save the model: ").lower()
        if save_status == 's':
            torch.save(model.state_dict(), model_file)
if __name__ == "__main__":
main()