This repository has been archived by the owner on Apr 11, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 10
/
custom_optimizers.py
214 lines (198 loc) · 10.1 KB
/
custom_optimizers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
"""
Implementation of unitary and orthogonality preserving
tensorflow optimization ops.
"""
import tensorflow as tf
from tensorflow.python.training import training_ops
class RMSpropNatGrad(tf.train.Optimizer):
    """RMSProp optimizer with the capability to do natural gradient steps.

    Variables whose name contains ``'orthogonal_stiefel'`` (a real matrix) or
    ``'unitary_stiefel'`` (real and imaginary parts stacked on a trailing
    axis of size 2), and whose name does not contain ``'bias'``, receive a
    Cayley-transform step that preserves orthogonality/unitarity. All other
    variables receive the standard TensorFlow RMSProp update.

    Inspired by: https://github.com/stwisdom/urnn/blob/master/custom_optimizers.py
    See also:
    Full-Capacity Unitary Recurrent Neural Networks, Wisdom et al, at:
    https://arxiv.org/abs/1611.00035
    """

    def __init__(self, learning_rate, decay=0.9, momentum=0.0,
                 epsilon=1e-10, global_step=None, nat_grad_normalization=False,
                 qr_steps=-1, name='RMSpropNatGrad'):
        """Create an RMSProp Stiefel-manifold hybrid optimizer.

        Args:
            learning_rate: Sets the global learning rate for all updates.
                (Please note Wisdom uses separate learning rates for the
                RMSProp and Stiefel updates.)
            decay: Discounting factor for the history/coming gradient.
            momentum: A scalar tensor; only used by the standard RMSProp
                update.
            epsilon: Small value to avoid a zero denominator.
            global_step: The global step tensor, used to decide when to run
                the QR re-normalization. Required.
            nat_grad_normalization: If True, the Stiefel gradient is
                normalized by the running RMS statistics as well.
            qr_steps: Run the QR re-normalization whenever
                ``global_step % qr_steps == 0``. Values <= 0 (default -1)
                mean never. (Setting this value leads to a Stiefel manifold,
                projected gradient descent hybrid approach. Not used in the
                paper.)
            name: The name for your optimizer.

        Raises:
            ValueError: If ``global_step`` is None.
        """
        if global_step is None:
            raise ValueError("global_step tensor must not be missing.")
        self._global_step_tensor = global_step
        use_locking = False
        super().__init__(use_locking, name)
        self._learning_rate = learning_rate
        self._decay = decay
        self._momentum = momentum
        self._epsilon = epsilon
        self._nat_grad_normalization = nat_grad_normalization
        # When True, unitarity/orthogonality diagnostics are written to
        # tf.summary for every manifold update.
        self._debug = True
        self._qr_steps = int(qr_steps) if qr_steps > 0 else None
        # Tensors for the algorithm parameters. Created in _prepare.
        self._learning_rate_tensor = None
        self._decay_tensor = None
        self._momentum_tensor = None
        self._epsilon_tensor = None

    def _create_slots(self, var_list):
        """Set up the "rms", "momentum" and "eps" slots for all variables."""
        init_eps = tf.constant_initializer(self._epsilon)
        for v in var_list:
            # RMS accumulators start at one, momentum at zero.
            init_rms = tf.ones_initializer(dtype=v.dtype)
            self._get_or_make_slot_with_initializer(
                v, init_rms, v.get_shape(), v.dtype, "rms", self._name)
            self._zeros_slot(v, "momentum", self._name)
            self._get_or_make_slot_with_initializer(
                v, init_eps, v.get_shape(), v.dtype, "eps", self._name)

    def _prepare(self):
        """Convert the algorithm parameters to tensors."""
        self._learning_rate_tensor = tf.convert_to_tensor(
            self._learning_rate, name="learning_rate")
        self._decay_tensor = tf.convert_to_tensor(self._decay, name="decay")
        self._momentum_tensor = tf.convert_to_tensor(self._momentum,
                                                     name="momentum")
        self._epsilon_tensor = tf.convert_to_tensor(self._epsilon,
                                                    name="epsilon")

    @staticmethod
    def _unitarity_defect(M):
        """Return ||I - M^H M|| as a real scalar tensor.

        The identity is sized by the number of COLUMNS of M, so the check is
        also valid for non-square (Stiefel) matrices with orthonormal columns.
        """
        n_cols = M.get_shape().as_list()[-1]
        eye = tf.eye(n_cols, dtype=M.dtype.base_dtype)
        return tf.real(tf.norm(eye - tf.matmul(M, M, adjoint_a=True)))

    def _summary_A(self, A):
        """Log ||A^H + A||, which is zero iff A is skew-Hermitian/-symmetric."""
        test_a = tf.transpose(tf.conj(A)) + A
        test_a_norm = tf.real(tf.norm(test_a))
        tf.summary.scalar('A.H--A', test_a_norm)

    def _summary_C(self, C):
        """Log how far the Cayley factor C is from unitary/orthogonal."""
        tf.summary.scalar('I-C.HC', self._unitarity_defect(C))

    def _summary_W(self, W):
        """Log how far W is from having orthonormal columns."""
        tf.summary.scalar('I-W.HW', self._unitarity_defect(W))

    def re_unitarize(self, W):
        """Project W back onto the Stiefel manifold with a QR decomposition.

        QR is only unique up to the signs (real) or phases (complex) of the
        columns of Q. Q is rescaled so that R has a positive real diagonal;
        without this, re-unitarizing an already unitary W could flip columns
        of the weight matrix instead of leaving it unchanged.
        """
        q, r = tf.qr(W)
        diag = tf.matrix_diag_part(r)
        ones = tf.ones_like(diag)
        if diag.dtype.is_complex:
            magnitude = tf.abs(diag)
            # Unit-modulus phase of each diagonal entry (1 where it is zero).
            phase = tf.where(tf.equal(magnitude, 0), ones,
                             diag / tf.cast(magnitude, diag.dtype))
        else:
            phase = tf.where(diag < 0, -ones, ones)
        # Broadcasting over the last axis scales column j of q by phase[j].
        W = q * phase
        W = tf.Print(W, [tf.constant(0)], 'step with qr.')
        return W

    def _rms_update(self, rms, grad):
        """Return (new_rms, assign_op) for the decayed mean of squared grads."""
        decay = tf.cast(self._decay_tensor, rms.dtype.base_dtype)
        new_rms = decay * rms + (1. - decay) * tf.square(grad)
        return new_rms, tf.assign(rms, new_rms)

    def _maybe_re_unitarize(self, W):
        """Re-unitarize W every ``qr_steps`` global steps (if enabled)."""
        if self._qr_steps is None:
            return W
        do_qr = tf.equal(tf.mod(self._global_step_tensor, self._qr_steps), 0)
        return tf.cond(do_qr, lambda: self.re_unitarize(W), lambda: W)

    @staticmethod
    def _cayley_step(G, W, lr_scale):
        """Apply one Cayley-transform step to W.

        Builds A = G W^H - W G^H (skew-Hermitian by construction) and
        C = (I + lr_scale * A)^-1 (I - lr_scale * A), which is unitary for a
        skew-Hermitian A and real lr_scale, so W_new = C W stays on the
        manifold.

        Returns:
            A tuple (C, W_new).
        """
        n_rows = G.get_shape().as_list()[0]
        eye = tf.eye(n_rows, dtype=G.dtype)
        A = tf.matmul(G, W, adjoint_b=True) - tf.matmul(W, G, adjoint_b=True)
        # Solving the linear system is better conditioned (and cheaper) than
        # forming the explicit inverse and multiplying.
        C = tf.matrix_solve(eye + lr_scale * A, eye - lr_scale * A)
        return C, tf.matmul(C, W)

    def _apply_orthogonal(self, grad, var, rms, eps):
        """Orthogonality preserving update for a real matrix variable."""
        with tf.variable_scope("orthogonal_update"):
            print('Appling an orthogonality preserving step to', var.name)
            new_rms, rms_assign_op = self._rms_update(rms, grad)
            if self._nat_grad_normalization:
                # Use the freshly computed statistics, not a racy read of
                # the slot variable that is being assigned concurrently.
                grad = grad / (tf.sqrt(new_rms) + eps)
            W = self._maybe_re_unitarize(var)
            lr_half = tf.cast(self._learning_rate_tensor,
                              var.dtype.base_dtype) / 2.0
            C, W_new = self._cayley_step(grad, W, lr_half)
            if self._debug:
                self._summary_C(C)
                self._summary_W(W)
            var_update_op = tf.assign(var, W_new)
            return tf.group(var_update_op, rms_assign_op)

    def _apply_unitary(self, grad, var, rms, eps):
        """Unitarity preserving update for a (n, n, 2) real/imag variable."""
        with tf.variable_scope("unitary_update"):
            print('Appling an unitarity preserving step to', var.name)
            new_rms, rms_assign_op = self._rms_update(rms, grad)
            if self._nat_grad_normalization:
                grad = grad / (tf.sqrt(new_rms) + eps)
            grad_shape = grad.get_shape().as_list()
            assert grad_shape[0] == grad_shape[1]
            # Reassemble complex matrices from the stacked real/imag parts;
            # the complex dtype follows the variable's float dtype.
            G = tf.complex(grad[:, :, 0], grad[:, :, 1])
            W = self._maybe_re_unitarize(tf.complex(var[:, :, 0],
                                                    var[:, :, 1]))
            lr_half = tf.cast(self._learning_rate_tensor,
                              var.dtype.base_dtype) / 2.0
            lr_scale = tf.complex(lr_half, tf.zeros_like(lr_half))
            C, W_new = self._cayley_step(G, W, lr_scale)
            if self._debug:
                self._summary_C(C)
                self._summary_W(W)
            W_array = tf.stack([tf.real(W_new), tf.imag(W_new)], -1)
            var_update_op = tf.assign(var, W_array)
            return tf.group(var_update_op, rms_assign_op)

    def _apply_dense(self, grad, var):
        """Dispatch to the manifold-preserving or the standard RMSProp step."""
        rms = self.get_slot(var, "rms")
        mom = self.get_slot(var, "momentum")
        eps = self.get_slot(var, 'eps')
        tf.summary.scalar('grad_norm', tf.norm(grad))
        is_weight = 'bias' not in var.name
        if is_weight and 'orthogonal_stiefel' in var.name:
            return self._apply_orthogonal(grad, var, rms, eps)
        if is_weight and 'unitary_stiefel' in var.name:
            return self._apply_unitary(grad, var, rms, eps)
        # Standard TensorFlow RMSProp update for everything else.
        print('Appling standard rmsprop to', var.name)
        dtype = var.dtype.base_dtype
        return training_ops.apply_rms_prop(
            var, rms, mom,
            tf.cast(self._learning_rate_tensor, dtype),
            tf.cast(self._decay_tensor, dtype),
            tf.cast(self._momentum_tensor, dtype),
            tf.cast(self._epsilon_tensor, dtype),
            grad, use_locking=self._use_locking).op

    def _apply_sparse(self, grad, var):
        raise NotImplementedError("Sparse gradient updates are not supported.")