-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathauto-encoder.py
63 lines (55 loc) · 2.08 KB
/
auto-encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from __future__ import print_function
import matplotlib
matplotlib.use('Agg')
import argparse
# import matplotlib.pyplot as plt
import torch.nn as nn
from data_loader import *
from miniimagenet_loader import AutoEncoder
if __name__ == '__main__':
    # Train an autoencoder on mini-ImageNet images by pixel-wise reconstruction.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--config', type=str, default='config', help='config file name in config dir')
    parser.add_argument('--mode', type=str, default='train', help='Mode (train/test/train_and_evaluate)')
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()

    # Load the config and echo its optional "description" section.
    Config(args.config)
    print("Config: ", Config)
    if Config.get("description", None):
        print("Config Description")
        for key, value in Config.description.items():
            print(f" - {key}: {value}")

    train_dataset, valid_dataset = get_loader("train")

    # Build the autoencoder, put it in training mode, and move it to GPU when available.
    ae = AutoEncoder()
    ae.train()
    ae.set_optimizer()
    is_cuda = torch.cuda.is_available()  # simplified from `True if ... else False`
    # NOTE(review): original wrote `ae.use_cude` (typo), which silently created a
    # new attribute and left the real flag unset — confirm the attribute name
    # against AutoEncoder's definition in miniimagenet_loader.
    ae.use_cuda = is_cuda
    if is_cuda:
        ae.cuda()

    # Pixel-wise reconstruction loss between the network output and its input.
    criterion = nn.MSELoss()

    resume = True
    if resume:
        start_epoch, ae = ae.load_saved_model(ae.path_to_save, ae)
        print(f"Model has been loaded epoch:{start_epoch}, path:{ae.path_to_save}")
    else:
        start_epoch = 0

    # BUG FIX: the original looped `range(0, 2)` and ignored the epoch returned
    # by load_saved_model, so a resumed run always restarted from epoch 0.
    num_epochs = 2
    for epoch in range(start_epoch, num_epochs):  # epochs loop
        for batch_idx, (batch_img, batch_label) in enumerate(train_dataset):  # batches loop
            if is_cuda:
                batch_img = batch_img.cuda()
                # labels are moved but never used below — training is unsupervised
                batch_label = batch_label.cuda()
            output = ae(batch_img)
            loss = criterion(output, batch_img)  # reconstruct the input image
            if batch_idx % 50 == 0:
                # .item() replaces the legacy .data.item() accessor
                print(f'batch_idx:{batch_idx} loss: ', loss.item())
            ae.optimizer.zero_grad()
            loss.backward()  # calculate the gradients (backpropagation)
            ae.optimizer.step()  # update the weights
            if batch_idx % 1000 == 0:
                ae.save_checkpoint(f"ep{epoch}_idx{batch_idx}")
                print(f"Saved in {ae.path_to_save}")