-
Notifications
You must be signed in to change notification settings - Fork 4
/
train_forward.py
118 lines (96 loc) · 4.83 KB
/
train_forward.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from torch import optim
import torch.utils.data as data
from setting import *
# torch.manual_seed(0)
class Trainer:
    """Train a forward model with Adam + MSE and checkpoint the best test error.

    Data is read from './data/AI_train.npy' and './data/AI_test.npy'. In each
    row, columns ``[p1, p2)`` are fed to the model and columns ``[0, p1)`` are
    the regression target. ``np``, ``p1``, ``p2``, ``device``, ``to_numpy``,
    ``plt``, ``color1`` and ``color2`` are not defined in this file; they
    presumably come from ``from setting import *``.
    """

    def __init__(self, model, iterations, iter_number, loss):
        """Load the train/test datasets and store the training state.

        Args:
            model: network to train (the __main__ block moves it to ``device``).
            iterations: global step counter to resume from (0 for a fresh run).
            iter_number: list of the step numbers at which metrics were
                logged; appended to in place by ``train``.
            loss: dict of lists under 'err_train', 'err_test', 'loss_train'
                and 'loss_test'; appended to in place by ``train``.
        """
        self.forward_model = model
        self.iteration = iterations
        self.interval = 100  # evaluate / log / maybe checkpoint every N steps
        self.losses = loss
        self.batch_size = 256
        self.lr = 3e-4
        self.epoch = 200  # epochs per train() call, including resumed runs
        self.iter_num = iter_number
        # Best mean absolute test error so far; a checkpoint is written only
        # when it is beaten (set to 0 to disable checkpointing).
        self.threshold = self._initial_threshold(loss)

        train_numpy = np.array(np.load('./data/AI_train.npy'), dtype=np.float32)
        train_input = torch.FloatTensor(train_numpy[:, p1:p2]).to(device)
        train_label = torch.FloatTensor(train_numpy[:, 0:p1]).to(device)
        train_dataset = data.TensorDataset(train_input, train_label)
        self.loader = data.DataLoader(dataset=train_dataset,
                                      batch_size=self.batch_size, shuffle=True)

        test_numpy = np.array(np.load('./data/AI_test.npy'), dtype=np.float32)
        # The whole test set is kept on `device` and evaluated in one pass.
        self.test_input = torch.FloatTensor(test_numpy[:, p1:p2]).to(device)
        self.test_label = torch.FloatTensor(test_numpy[:, 0:p1]).to(device)

    @staticmethod
    def _initial_threshold(losses):
        """Return the test error a new checkpoint must beat.

        A checkpoint is only saved right after its test error is appended, so
        the last logged 'err_test' of a resumed history is the best so far.
        With no history (fresh run), fall back to 1.
        """
        err_test = losses['err_test']
        return err_test[-1] if err_test else 1

    def _predict(self, inputs):
        """Run the model in eval mode without autograd, then restore its mode."""
        model = self.forward_model
        was_training = model.training
        model.eval()  # disable dropout / use running BN stats during evaluation
        try:
            with torch.no_grad():  # no graph over the (large) evaluation batch
                return model(inputs)
        finally:
            model.train(was_training)

    def train(self):
        """Run ``self.epoch`` epochs of mini-batch Adam on the MSE loss.

        Every ``self.interval`` global steps (checked before the parameter
        update) the model is evaluated on the full test set, metrics are
        recorded and printed, and a checkpoint is written if the test error
        improved.

        NOTE: the optimizer is created here and its state is not saved in the
        checkpoint, so Adam's moment estimates restart on every resume.
        """
        optimizer = optim.Adam(self.forward_model.parameters(), lr=self.lr)
        criterion = nn.MSELoss(reduction='mean')
        for _ in range(self.epoch):
            for train_input, train_label in self.loader:
                train_predict = self.forward_model(train_input)
                loss_train = criterion(train_predict, train_label)
                if self.iteration % self.interval == 0:
                    self._log_and_checkpoint(criterion, train_predict,
                                             train_label, loss_train)
                optimizer.zero_grad()
                loss_train.backward()
                optimizer.step()
                self.iteration += 1

    def _log_and_checkpoint(self, criterion, train_predict, train_label, loss_train):
        """Record and print train/test metrics; save the model if test error improved."""
        test_predict = self._predict(self.test_input)
        loss_test = criterion(test_predict, self.test_label)
        self.losses['loss_train'].append(to_numpy(loss_train))
        self.losses['loss_test'].append(to_numpy(loss_test))
        # Mean absolute error; the train figure covers the current mini-batch only.
        train_error = np.abs(to_numpy(train_predict.detach() - train_label)).mean()
        test_error = np.abs(to_numpy(test_predict - self.test_label)).mean()
        self.losses['err_train'].append(train_error)
        self.losses['err_test'].append(test_error)
        print('iteration: {}'.format(self.iteration))
        print('train_loss: {:.4}, test_loss: {:.4}'.
              format(loss_train.item(), loss_test.item()))
        print('train_error: {:.4}, test_error: {:.4}'.
              format(train_error, test_error))
        self.iter_num.append(self.iteration)
        if self.threshold > test_error:
            self.threshold = test_error
            # The saved weights are the ones just evaluated, i.e. before this
            # step's update. A resumed run therefore re-logs this iteration.
            # NOTE: `start` is a module-level name set by the __main__ block.
            torch.save({'iteration': self.iteration,
                        'iter_num': self.iter_num,
                        'state_dict': self.forward_model.state_dict(),
                        'loss': self.losses,
                        'time': time.time() - start},
                       'checkpoint_forward.pth')

    def inference(self, c=5):
        """Plot the predicted vs. simulated spectrum of test sample ``c``.

        Both curves are multiplied by 100 for the percentage axis. The x axis
        is 50, 100, ..., 50 * n for the n model outputs.
        """
        valid_predict = to_numpy(self._predict(self.test_input[c:c + 1]).view(-1)) * 100
        valid_label = to_numpy(self.test_label[c]) * 100
        spectra = np.arange(1, len(valid_label) + 1) * 50.0
        plt.title('Comparison of Transmission Spectrum')
        plt.plot(spectra, valid_predict, color=color1, label='Prediction')
        plt.plot(spectra, valid_label, color=color2, label='Simulation')
        plt.legend(loc='upper right')
        plt.xlabel('Frequency (HZ)')
        plt.ylabel('Intensity (%)')
        plt.show()
if __name__ == '__main__':
    # Resume from 'checkpoint_forward.pth' (True) or start a fresh run (False).
    # `load` and `start` are deliberately module-level: Trainer may read them
    # as globals, so do not move this code into a function.
    load = True
    forward_model = Forward_Net().to(device)
    if load:
        # map_location lets a checkpoint written on one device (e.g. a GPU)
        # be resumed on another (e.g. a CPU-only machine).
        # NOTE(review): the checkpoint's 'loss' history holds numpy values; on
        # PyTorch versions where torch.load defaults to weights_only=True this
        # load may need weights_only=False (the file is self-produced) — confirm.
        checkpoint_forward = torch.load('checkpoint_forward.pth', map_location=device)
        forward_model.load_state_dict(checkpoint_forward['state_dict'])
        iteration = checkpoint_forward['iteration']
        iter_num = checkpoint_forward['iter_num']
        losses = checkpoint_forward['loss']
        # Shift the clock origin back so the saved elapsed 'time' keeps
        # accumulating across resumed runs.
        start = time.time() - checkpoint_forward['time']
    else:
        iteration = 0
        iter_num = []
        losses = {'err_train': [], 'err_test': [], 'loss_train': [], 'loss_test': []}
        start = time.time()
    trainer = Trainer(forward_model, iteration, iter_num, losses)
    trainer.train()
    # trainer.inference()