-
Notifications
You must be signed in to change notification settings - Fork 14
/
learning_to_learn.py
501 lines (390 loc) · 21 KB
/
learning_to_learn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
# coding: utf-8
# Learning to learn by gradient descent by gradient descent
# =========================#
# https://arxiv.org/abs/1611.03824
# https://yangsenius.github.io/blog/LSTM_Meta/
# author:yangsen
# #### “通过梯度下降来学习如何通过梯度下降学习”
# #### "learning to learn by gradient descent by gradient descent"
# #### 要让优化器学会这样 "为了更好地得到,要先去舍弃" 这样类似的知识!
# #### teach the optimizer knowledge such as "sometimes, to get something better, you first have to give something up."
import math
from timeit import default_timer as timer

import torch
import torch.nn as nn
# ---------------------------------------------------------------------------
# Optimization problem configuration
# ---------------------------------------------------------------------------
DIM = 10         # dimensionality of theta (and of each row of W / Y)
batchsize = 128  # number of quadratic problems (W, Y pairs) optimized side by side
# Use the GPU whenever PyTorch reports that one is available.
USE_CUDA = bool(torch.cuda.is_available())
print('\n\nUSE_CUDA = {}\n\n'.format(USE_CUDA))
def f(W, Y, x):
    r"""Batched quadratic objective f(\theta) = \|W\theta - y\|_2^2, averaged over the batch.

    Args:
        W: tensor of shape (batch, dim, dim).
        Y: tensor of shape (batch, dim).
        x: tensor of shape (batch, dim) -- the parameters theta being optimized.

    Returns:
        0-dim tensor: the squared residual norm of each problem, averaged over the batch.
    """
    if USE_CUDA:
        W = W.cuda()
        Y = Y.cuda()
        x = x.cuda()
    # squeeze(-1) removes only the column axis added by unsqueeze(-1); a bare squeeze()
    # would also drop a batch or dim axis of size 1 and break sum(dim=1).
    residual = torch.matmul(W, x.unsqueeze(-1)).squeeze(-1) - Y
    return (residual ** 2).sum(dim=1).mean(dim=0)
###############################################################
###################### 手工的优化器 ###################
###################### hand-craft optimizer ###################
def SGD(gradients, state, learning_rate=0.001):
    """Plain gradient-descent step.

    Returns ``(update, state)`` with ``update = -learning_rate * gradients``.
    SGD keeps no state, so ``state`` is handed back unchanged.
    """
    step = gradients * (-learning_rate)
    return step, state
def RMS(gradients, state, learning_rate=0.01, decay_rate=0.9):
    """RMSProp step.

    Args:
        gradients: gradient tensor.
        state: running average of squared gradients returned by the previous call,
            or ``None`` on the first call (treated as all zeros).
        learning_rate: step size.
        decay_rate: decay of the running average of squared gradients.

    Returns:
        ``(update, state)``: the additive parameter update and the new running average.
    """
    if state is None:
        # zeros_like matches the gradients' shape, dtype and device, so the first step
        # no longer depends on the global DIM / USE_CUDA settings (values are identical).
        state = torch.zeros_like(gradients)
    state = decay_rate * state + (1 - decay_rate) * torch.pow(gradients, 2)
    update = -learning_rate * gradients / torch.sqrt(state + 1e-5)
    return update, state
def adam(params=None, lr=0.1, **kwargs):
    """Build a ``torch.optim.Adam`` optimizer over *params*.

    Args:
        params: iterable of tensors (or parameter groups) to optimize. Required:
            ``torch.optim.Adam`` cannot be built without them.
        lr: learning rate (0.1, the value the Learner's Adam baseline uses).
        **kwargs: forwarded to ``torch.optim.Adam`` (e.g. ``betas``, ``eps``).

    Raises:
        TypeError: if *params* is not given.
    """
    if params is None:
        raise TypeError("adam() requires the parameters (an iterable of tensors) to optimize")
    return torch.optim.Adam(params, lr=lr, **kwargs)
##########################################################
##################### 自动 LSTM 优化器模型 ##########################
##################### auto LSTM optimizer model ##########################
class LSTM_optimizer_Model(torch.nn.Module):
    """LSTM optimizer: maps optimizee gradients to parameter updates.

    The (optionally log/sign-preprocessed) gradients go through a stacked
    ``torch.nn.LSTM`` followed by a linear layer, scaled by ``output_scale``.

    Args:
        input_size: size of the last dimension of the incoming gradients.
        output_size: size of the last dimension of the produced update.
        hidden_size: LSTM hidden size.
        num_stacks: number of stacked LSTM layers.
        batchsize: kept for backward compatibility; the batch size of the initial
            hidden state is taken from the incoming gradients.
        preprocess: if truthy, apply ``LogAndSign_Preprocess_Gradient`` (doubles the
            LSTM input size).
        p: preprocessing parameter, ``p > 0``.
        output_scale: multiplier applied to the linear layer's output.
    """

    def __init__(self, input_size, output_size, hidden_size, num_stacks, batchsize,
                 preprocess=True, p=10, output_scale=1):
        super(LSTM_optimizer_Model, self).__init__()
        self.preprocess_flag = preprocess
        self.p = p
        # Preprocessing emits two features (log part, sign part) per input feature.
        self.input_flag = 2 if preprocess else 1
        self.output_scale = output_scale
        self.hidden_size = hidden_size
        self.num_stacks = num_stacks
        self.batchsize = batchsize
        self.lstm = torch.nn.LSTM(input_size * self.input_flag, hidden_size, num_stacks)
        self.Linear = torch.nn.Linear(hidden_size, output_size)  # hidden_size -> output_size

    def LogAndSign_Preprocess_Gradient(self, gradients):
        """Log/sign preprocessing of the gradients.

        Args:
            gradients: tensor of gradients with shape ``[d_1, ..., d_n]``.

        Returns:
            Tensor of shape ``[d_1, ..., d_{n-1}, 2 * d_n]``. The first ``d_n``
            entries along the last dimension are ``clamp(log|g| / p, -1, 1)``; the
            remaining ``d_n`` are ``clamp(exp(p) * g, -1, 1)``.
        """
        p = self.p
        # log|0| = -inf, which the clamp maps to -1.
        log = torch.log(torch.abs(gradients))
        clamp_log = torch.clamp(log / p, min=-1.0, max=1.0)
        # Scale by the scalar e^p; torch.Tensor(p) would instead allocate an
        # uninitialised tensor of length p.
        clamp_sign = torch.clamp(math.exp(p) * gradients, min=-1.0, max=1.0)
        # Concatenate along the last (feature) dimension.
        return torch.cat((clamp_log, clamp_sign), dim=-1)

    def Output_Gradient_Increment_And_Update_LSTM_Hidden_State(self, input_gradients, prev_state):
        """Run the LSTM core and project its output to an update.

        Args:
            input_gradients: LSTM input of shape ``(seq_len, batch, features)``.
            prev_state: ``(h, c)`` from the previous call, or ``None`` to start from zeros.

        Returns:
            ``(update, next_state)``: the scaled linear projection of the LSTM output
            and the new ``(h, c)`` tuple.
        """
        if prev_state is None:
            # Size the zero state from this module's configuration and the actual batch,
            # on the input's device and dtype (no reliance on module-level globals).
            shape = (self.num_stacks, input_gradients.size(1), self.hidden_size)
            zeros = input_gradients.new_zeros(shape)
            prev_state = (zeros, zeros.clone())
        update, next_state = self.lstm(input_gradients, prev_state)
        # Transform the LSTM output to the target output dimension.
        update = self.Linear(update) * self.output_scale
        return update, next_state

    def forward(self, input_gradients, prev_state):
        """Map gradients of shape ``(batch, input_size)`` to an update of shape
        ``(batch, output_size)`` and the next LSTM state ``(h, c)``."""
        # Put the gradients on the same device as the model's parameters.
        input_gradients = input_gradients.to(self.Linear.weight.device)
        # torch.nn.LSTM expects (seq_len, batch, features); use a sequence of length 1.
        gradients = input_gradients.unsqueeze(0)
        if self.preprocess_flag:
            gradients = self.LogAndSign_Preprocess_Gradient(gradients)
        update, next_state = self.Output_Gradient_Increment_And_Update_LSTM_Hidden_State(
            gradients, prev_state)
        # Drop only the sequence axis so a batch of one keeps its batch dimension.
        update = update.squeeze(0)
        return update, next_state
# ---------------------------------------------------------------------------
# Hyper-parameters of the LSTM optimizer
# ---------------------------------------------------------------------------
Layers = 2              # number of stacked LSTM layers
Hidden_nums = 20        # LSTM hidden size
Input_DIM = DIM         # LSTM input size equals the problem dimension
Output_DIM = DIM        # one update entry per entry of theta
output_scale_value = 1  # multiplier on the LSTM optimizer's output

# Build the LSTM optimizer (gradient preprocessing disabled).
LSTM_optimizer = LSTM_optimizer_Model(
    Input_DIM,
    Output_DIM,
    Hidden_nums,
    Layers,
    batchsize=batchsize,
    preprocess=False,
    output_scale=output_scale_value,
)
print(LSTM_optimizer)
LSTM_optimizer = LSTM_optimizer.cuda() if USE_CUDA else LSTM_optimizer
###################### 优化问题目标函数的学习过程 ###############
###################### the learning process of optimizing the target function ###############
class Learner(object):
    """Run one optimizer on the batched quadratic problem ``f``.

    Args:
        f: the problem being learned (the "optimizee" in the paper), called as
            ``f(W, Y, x)`` and returning a scalar loss tensor.
        optimizer: a callable ``optimizer(gradients, state) -> (update, state)``
            (SGD, RMS or the LSTM optimizer), or the string ``'Adam'`` to use
            ``torch.optim.Adam``.
        train_steps: number of optimization steps per call. For SGD/Adam this is the
            training length; for LSTM training it is the unroll length.
        eval_flag: stored on the instance; not read by this class.
        retain_graph_flag: passed to every ``loss.backward``. Set it to True for the
            LSTM optimizer so ``global_loss_graph`` can still be back-propagated
            after the per-step backward calls.
        reset_theta: if True, ``x`` is reset to zeros when ``num_roll == 0``.
        reset_function_from_IID_distirbution: if True, new ``W`` and ``Y`` are drawn
            from a standard normal when ``num_roll == 0``.

    Calling the learner returns ``(losses, global_loss_graph)``:
        losses: detached loss of every step of every call made so far.
        global_loss_graph: sum of this call's step losses, still attached to the
            autograd graph (used for the LSTM optimizer's BPTT).
    """

    def __init__(self, f, optimizer, train_steps,
                 eval_flag=False,
                 retain_graph_flag=False,
                 reset_theta=False,
                 reset_function_from_IID_distirbution=True):
        self.f = f
        self.optimizer = optimizer
        self.train_steps = train_steps
        self.eval_flag = eval_flag
        self.retain_graph_flag = retain_graph_flag
        self.reset_theta = reset_theta
        self.reset_function_from_IID_distirbution = reset_function_from_IID_distirbution
        self.init_theta_of_f()
        self.state = None
        self.global_loss_graph = 0  # summed loss of the current unroll (for the LSTM's BPTT)
        self.losses = []            # detached loss of every step of every call

    def init_theta_of_f(self):
        """Sample ``W`` and ``Y`` and set ``x`` to zeros for the problem ``f``."""
        # Follow the module-level problem size so that W, Y, x, f and RMS all agree.
        self.DIM = DIM
        self.batchsize = batchsize
        self.W = torch.randn(self.batchsize, self.DIM, self.DIM)  # one random problem per batch row
        self.Y = torch.randn(self.batchsize, self.DIM)
        self.x = torch.zeros(self.batchsize, self.DIM)
        self.x.requires_grad = True
        if USE_CUDA:
            self.W = self.W.cuda()
            self.Y = self.Y.cuda()
            self.x = self.x.cuda()

    def Reset_Or_Reuse(self, x, W, Y, state, num_roll):
        """Return the ``x, W, Y, state`` to use for this call.

        When ``num_roll == 0``, the optimizer state is cleared. In addition, ``x`` is
        reset to zeros if ``reset_theta`` is set, and a new problem (W, Y) is sampled
        if ``reset_function_from_IID_distirbution`` is set. Otherwise the given values
        are returned as-is (moved to the GPU when USE_CUDA).
        """
        if num_roll == 0 and self.reset_theta:
            x = torch.zeros(self.batchsize, self.DIM).requires_grad_(True)
        # At the first unroll, sample a new problem from the IID standard normal.
        if num_roll == 0 and self.reset_function_from_IID_distirbution:
            W = torch.randn(self.batchsize, self.DIM, self.DIM)
            Y = torch.randn(self.batchsize, self.DIM)
        if num_roll == 0:
            state = None
            print('reset the values of `W`, `x`, `Y` and `state` for this optimizer')
        if USE_CUDA:
            W = W.cuda()
            Y = Y.cuda()
            x = x.cuda()
        # x may be a non-leaf after .cuda(); keep its .grad for the optimizer input.
        x.retain_grad()
        return x, W, Y, state

    @staticmethod
    def _detach_state(state):
        """Detach an optimizer state: a tensor (RMS) or a tuple of tensors (LSTM (h, c))."""
        if isinstance(state, tuple):
            return tuple(s.detach() for s in state)
        return state.detach()

    def __call__(self, num_roll=0):
        """Run ``train_steps`` optimization steps and return ``(losses, global_loss_graph)``.

        ``num_roll`` is the index of this unroll within the current problem: 0 starts
        a new trajectory (see ``Reset_Or_Reuse``), while larger values continue from
        where the previous call stopped. SGD, RMS and the LSTM optimizer are the
        callables defined in this file; Adam comes from ``torch.optim``.
        """
        f = self.f
        x, W, Y, state = self.Reset_Or_Reuse(self.x, self.W, self.Y, self.state, num_roll)
        self.global_loss_graph = 0  # reset at the beginning of every unroll
        optimizer = self.optimizer
        if optimizer != 'Adam':
            for _ in range(self.train_steps):
                loss = f(W, Y, x)
                self.global_loss_graph += loss
                loss.backward(retain_graph=self.retain_graph_flag)
                update, state = optimizer(x.grad.clone().detach(), state)
                # Store a detached copy so the list does not keep every step's graph alive.
                self.losses.append(loss.detach())
                x = x + update
                # x is now a non-leaf; keep its .grad for the next step's optimizer input.
                # (No retain_grad on `update`: it is never read, and hand-crafted
                # optimizers return updates that do not require grad.)
                x.retain_grad()
            if state is not None:
                self.state = self._detach_state(state)
            # Remember where this unroll ended so the next call (num_roll > 0) continues
            # the same trajectory; detaching truncates BPTT at the unroll boundary.
            self.x = x.detach().requires_grad_(True)
            self.W, self.Y = W, Y
            return self.losses, self.global_loss_graph

        # PyTorch Adam baseline.
        x.detach_()
        x.requires_grad = True
        adam_optimizer = torch.optim.Adam([x], lr=0.1)
        for _ in range(self.train_steps):
            adam_optimizer.zero_grad()
            loss = f(W, Y, x)
            self.global_loss_graph += loss
            loss.backward(retain_graph=self.retain_graph_flag)
            adam_optimizer.step()
            self.losses.append(loss.detach_())
        self.x, self.W, self.Y = x, W, Y
        return self.losses, self.global_loss_graph
####### LSTM 优化器的训练过程 Learning to learn ###############
####### LSTM training Learning to learn ###############
def Learning_to_learn_global_training(optimizer, global_taining_steps, optimizer_Train_Steps,
                                      UnRoll_STEPS, Evaluate_period, optimizer_lr=0.1):
    """Train the LSTM optimizer ("learning to learn").

    Args:
        optimizer: the LSTM optimizer model being meta-trained.
        global_taining_steps: number of global (meta) training iterations.
        optimizer_Train_Steps: number of steps the LSTM optimizer runs on each function
            sampled from the IID distribution.
        UnRoll_STEPS: number of steps unrolled into one computation graph per BPTT update.
        Evaluate_period: call ``evaluate`` every this many global iterations.
        optimizer_lr: learning rate of the Adam optimizer that trains the LSTM.

    Returns:
        ``(global_loss_list, best_flag)``: the detached global loss of every unroll,
        and the flag returned by the last ``evaluate`` call (False if none ran).
    """
    global_loss_list = []
    # Number of unrolls per sampled function; any remainder steps are dropped.
    Total_Num_Unroll = optimizer_Train_Steps // UnRoll_STEPS
    adam_global_optimizer = torch.optim.Adam(optimizer.parameters(), lr=optimizer_lr)
    # If the batch itself is regarded as the IID sample, W and Y would not need to be
    # resampled each time (reset_function_from_IID_distirbution=False).
    LSTM_Learner = Learner(f, optimizer, UnRoll_STEPS, retain_graph_flag=True, reset_theta=True)
    best_sum_loss = 999999
    best_final_loss = 999999
    best_flag = False
    for i in range(global_taining_steps):
        print('\n========================================> global training steps: {}'.format(i))
        for num in range(Total_Num_Unroll):
            start = timer()
            _, global_loss = LSTM_Learner(num)
            adam_global_optimizer.zero_grad()
            global_loss.backward()
            adam_global_optimizer.step()
            global_loss_list.append(global_loss.detach())
            elapsed = timer() - start
            print('-> time consuming [{:.1f}s] optimizer train steps : [{}] | Global_Loss = [{:.1f}] '
                  .format(elapsed, (num + 1) * UnRoll_STEPS, global_loss.item()))
        if (i + 1) % Evaluate_period == 0:
            best_sum_loss, best_final_loss, best_flag = evaluate(
                best_sum_loss, best_final_loss, best_flag, optimizer_lr)
    return global_loss_list, best_flag
def evaluate(best_sum_loss, best_final_loss, best_flag, lr):
    """Compare SGD, RMS, Adam and the LSTM optimizer on new problems and checkpoint the LSTM.

    Plots the four loss curves and saves the current LSTM weights to
    ``current_LSTM_optimizer_ckpt.pth``. The reference best values come from
    ``best_loss.txt`` if it can be loaded, otherwise from the arguments. If the LSTM's
    final loss beats that reference, the weights are saved to ``best_LSTM_optimizer.pth``
    and ``[sum_loss, final_loss, lr]`` to ``best_loss.txt``.

    Returns:
        ``(best_sum_loss, best_final_loss, best_flag)``.
    """
    def as_floats(losses):
        # Give matplotlib plain floats so it never converts autograd-tracked tensors.
        return [float(v) for v in losses]

    print('\n --------> evalute the model')
    STEPS = 100
    x = np.arange(STEPS)
    Adam = 'Adam'
    LSTM_learner = Learner(f, LSTM_optimizer, STEPS, eval_flag=True, reset_theta=True, retain_graph_flag=True)
    SGD_Learner = Learner(f, SGD, STEPS, eval_flag=True, reset_theta=True)
    RMS_Learner = Learner(f, RMS, STEPS, eval_flag=True, reset_theta=True)
    Adam_Learner = Learner(f, Adam, STEPS, eval_flag=True, reset_theta=True)
    sgd_losses, sgd_sum_loss = SGD_Learner()
    rms_losses, rms_sum_loss = RMS_Learner()
    adam_losses, adam_sum_loss = Adam_Learner()
    lstm_losses, lstm_sum_loss = LSTM_learner()
    p1, = plt.plot(x, as_floats(sgd_losses), label='SGD')
    p2, = plt.plot(x, as_floats(rms_losses), label='RMS')
    p3, = plt.plot(x, as_floats(adam_losses), label='Adam')
    p4, = plt.plot(x, as_floats(lstm_losses), label='LSTM')
    plt.yscale('log')
    plt.legend(handles=[p1, p2, p3, p4])
    plt.title('Losses')
    plt.pause(1.5)
    print("sum_loss:sgd={},rms={},adam={},lstm={}".format(sgd_sum_loss, rms_sum_loss, adam_sum_loss, lstm_sum_loss))
    plt.close()
    torch.save(LSTM_optimizer.state_dict(), 'current_LSTM_optimizer_ckpt.pth')

    # Detached copies: the values compared and saved below carry no autograd graph.
    now_sum_loss = lstm_sum_loss.detach().cpu()
    now_final_loss = lstm_losses[-1].detach().cpu()
    try:
        best = torch.load('best_loss.txt')
    except IOError:
        # No history yet: compare against the values passed in.
        print('can not find best_loss.txt')
    else:
        best_sum_loss = best[0].cpu()
        best_final_loss = best[1].cpu()

    print(" ==> History: sum loss = [{:.1f}] \t| final loss = [{:.2f}]".format(best_sum_loss, best_final_loss))
    print(" ==> Current: sum loss = [{:.1f}] \t| final loss = [{:.2f}]".format(now_sum_loss, now_final_loss))
    # Save the best model according to the condition below; other trade-offs are possible
    # (e.g. also requiring now_sum_loss < best_sum_loss).
    if now_final_loss < best_final_loss:
        best_final_loss = now_final_loss
        best_sum_loss = now_sum_loss
        print('\n\n===> update new best of final LOSS[{}]: = {}, best_sum_loss ={}'.format(STEPS, best_final_loss, best_sum_loss))
        torch.save(LSTM_optimizer.state_dict(), 'best_LSTM_optimizer.pth')
        torch.save([best_sum_loss, best_final_loss, lr], 'best_loss.txt')
        best_flag = True
    return best_sum_loss, best_final_loss, best_flag
########################## before learning LSTM optimizer ###############################
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
STEPS = 100
x = np.arange(STEPS)
Adam = 'Adam'  # string sentinel: Learner switches to torch.optim.Adam for this value
for _ in range(1):
    # Each Learner samples random W / Y when constructed, so keep this order.
    SGD_Learner = Learner(f, SGD, STEPS, eval_flag=True, reset_theta=True)
    RMS_Learner = Learner(f, RMS, STEPS, eval_flag=True, reset_theta=True)
    Adam_Learner = Learner(f, Adam, STEPS, eval_flag=True, reset_theta=True)
    LSTM_learner = Learner(f, LSTM_optimizer, STEPS, eval_flag=True,
                           reset_theta=True, retain_graph_flag=True)

    sgd_losses, sgd_sum_loss = SGD_Learner()
    rms_losses, rms_sum_loss = RMS_Learner()
    adam_losses, adam_sum_loss = Adam_Learner()
    lstm_losses, lstm_sum_loss = LSTM_learner()

    p1 = plt.plot(x, sgd_losses, label='SGD')[0]
    p2 = plt.plot(x, rms_losses, label='RMS')[0]
    p3 = plt.plot(x, adam_losses, label='Adam')[0]
    p4 = plt.plot(x, lstm_losses, label='LSTM')[0]
    # Dash patterns are alternating on/off lengths in points.
    p1.set_dashes([2, 2, 2, 2])
    p2.set_dashes([4, 2, 8, 2])
    p3.set_dashes([3, 2, 10, 2])

    plt.yscale('log')
    plt.legend(handles=[p1, p2, p3, p4])
    plt.title('Losses')
    plt.pause(2.5)
    print("\n\nsum_loss:sgd={},rms={},adam={},lstm={}".format(
        sgd_sum_loss, rms_sum_loss, adam_sum_loss, lstm_sum_loss))
#################### Learning to learn (optimizing optimizer) ######################
Global_Train_Steps = 1000  # changeable
optimizer_Train_Steps = 100
UnRoll_STEPS = 20
Evaluate_period = 1        # changeable
optimizer_lr = 0.1         # changeable
global_loss_list, flag = Learning_to_learn_global_training(
    optimizer=LSTM_optimizer,
    global_taining_steps=Global_Train_Steps,
    optimizer_Train_Steps=optimizer_Train_Steps,
    UnRoll_STEPS=UnRoll_STEPS,
    Evaluate_period=Evaluate_period,
    optimizer_lr=optimizer_lr,
)
######################################################################3#
########################## show learning process results
#torch.load('best_LSTM_optimizer.pth'))
#import numpy as np
#import matplotlib
#import matplotlib.pyplot as plt
#Global_T = np.arange(len(global_loss_list))
#p1, = plt.plot(Global_T, global_loss_list, label='Global_graph_loss')
#plt.legend(handles=[p1])
#plt.title('Training LSTM optimizer by gradient descent ')
#plt.show()
######################################################################3#
########################## show contrast results SGD,ADAM, RMS ,LSTM ###############################
import copy
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
if flag:
    print('\n==== > load best LSTM model')
    # Keep the last-trained weights (in memory and on disk) before loading the best ones.
    last_state_dict = copy.deepcopy(LSTM_optimizer.state_dict())
    torch.save(LSTM_optimizer.state_dict(), 'final_LSTM_optimizer.pth')
# Load the best checkpoint exactly once. It may come from an earlier run even when this
# run did not improve on it (flag is False). If no checkpoint exists, keep the current
# weights instead of crashing. To evaluate the final weights instead, load
# 'final_LSTM_optimizer.pth' here.
try:
    LSTM_optimizer.load_state_dict(torch.load('best_LSTM_optimizer.pth'))
except FileNotFoundError:
    print('\n==== > best_LSTM_optimizer.pth not found, keep the current LSTM weights')

STEPS = 100
x = np.arange(STEPS)
Adam = 'Adam'
# Repeat the comparison a few times: the trained LSTM optimizer may not be stable.
for _ in range(3):
    SGD_Learner = Learner(f, SGD, STEPS, eval_flag=True, reset_theta=True)
    RMS_Learner = Learner(f, RMS, STEPS, eval_flag=True, reset_theta=True)
    Adam_Learner = Learner(f, Adam, STEPS, eval_flag=True, reset_theta=True)
    LSTM_learner = Learner(f, LSTM_optimizer, STEPS, eval_flag=True, reset_theta=True, retain_graph_flag=True)
    sgd_losses, sgd_sum_loss = SGD_Learner()
    rms_losses, rms_sum_loss = RMS_Learner()
    adam_losses, adam_sum_loss = Adam_Learner()
    lstm_losses, lstm_sum_loss = LSTM_learner()
    # Give matplotlib plain floats so it never converts autograd-tracked tensors.
    p1, = plt.plot(x, [float(v) for v in sgd_losses], label='SGD')
    p2, = plt.plot(x, [float(v) for v in rms_losses], label='RMS')
    p3, = plt.plot(x, [float(v) for v in adam_losses], label='Adam')
    p4, = plt.plot(x, [float(v) for v in lstm_losses], label='LSTM')
    p1.set_dashes([2, 2, 2, 2])   # 2pt line, 2pt break (repeated)
    p2.set_dashes([4, 2, 8, 2])   # 4pt line, 2pt break, 8pt line, 2pt break
    p3.set_dashes([3, 2, 10, 2])  # 3pt line, 2pt break, 10pt line, 2pt break
    plt.yscale('log')
    plt.legend(handles=[p1, p2, p3, p4])
    plt.title('Losses')
    plt.show()
    print("\n\nsum_loss:sgd={},rms={},adam={},lstm={}".format(sgd_sum_loss, rms_sum_loss, adam_sum_loss, lstm_sum_loss))