import numpy as np
import matplotlib.pyplot as plt
# analysis
from sklearn.decomposition import PCA, FactorAnalysis
from sklearn.linear_model import LinearRegression
from scipy import stats, interpolate
from scipy import linalg as LA
# miscellaneous
from tqdm import tqdm
from itertools import cycle
from copy import deepcopy
import logging
import warnings
import dataclasses
from dataclasses import dataclass
from typing import Optional, List
from utils.functions import f, df, theta, rgetattr
@dataclass
class RNNparams():
    """
    Hyperparameter class for a vanilla RNN
    Attributes:
        n_in (int): dimension of inputs
        n_rec (int): dimension of recurrent units
        n_out (int): dimension of output
        sig_in (float): scale of the noise added to the inputs
        sig_rec (float): scale of the noise injected into the recurrent units
        sig_out (float): scale of the noise added to the readout
        tau_rec (float): time constant of the recurrent dynamics
        eta_in (float): learning rate for input weights
        eta_rec (float): learning rate for recurrent weights
        eta_out (float): learning rate for output weights
        driving_feedback (bool): whether feedback drives the recurrent units
        eta_fb (float): learning rate for feedback weights
        feedback_signal (str): which signal is fed back, 'position' or 'error'
        velocity_transform (bool): whether to apply a low-pass filter to the neural readout
        tau_vel (float): time constant for the velocity transformation
        dt_vel (float): integration timestep used to update position from velocity
        rng: random number generator
    TODO: Feedback noise
    """
""" number of units at each layer """
n_in: int
n_rec: int
n_out: int
""" noise """
sig_in: float
sig_rec: float
sig_out: float
tau_rec: float
""" integration timescale of simulation
Note that if this value is changed to something other than 1, it affects
the simulation in 2 locations:
- the recurrent activity update, which is scaled by dt/tau_rec
- the recurrent noise xi, which is scaled by sqrt(dt)
"""
dt: Optional[float] = 1.0
""" learning rates for each population """
# note that this does not mean that the RNN necessarily learns
eta_in: Optional[float] = None
eta_rec: Optional[float] = None
eta_out: Optional[float] = None
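    """ gains that scale the initial weights for each population """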
g_in: Optional[float] = 2.0
g_rec: Optional[float] = 1.5
g_out: Optional[float] = 2.0
g_fb: Optional[float] = 2.0
""" driving feedback parameters """
driving_feedback: bool = False
eta_fb: Optional[float] = None # learning rate for feedback weights
sig_fb: Optional[float] = None
feedback_signal: Optional[str] = 'position'
""" velocity transform """
velocity_transform: bool = False
tau_vel: Optional[float] = None
dt_vel: Optional[float] = None
""" recurrent noise dimension parameters """
sig_rec_dim: Optional[int] = None
""" weight mirroring parameters """
eta_m: Optional[float] = None
sig_m: Optional[float] = None
lam_m: Optional[float] = None
    rng: np.random.RandomState = np.random.RandomState(17) # note: this default instance is shared across RNNparams objects
def print_params(self) -> None:
""" Method to print hyperparameters """
for k,v in dataclasses.asdict(self).items():
print(k+':',v)
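
# Usage sketch for RNNparams (the sizes and noise scales below are illustrative,
# not defaults taken from this repo):
#     params = RNNparams(n_in=2, n_rec=100, n_out=2,
#                        sig_in=0.01, sig_rec=0.01, sig_out=0.01, tau_rec=10.0)
#     params.print_params()
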
class RNN():
    """
    RNN class
    This class also holds the current state of the RNN, and the method
    'next_state' advances the RNN to the next state.
    Class methods can also initialize weights and insert/set weights.
    Args:
        params (RNNparams): dataclass object that stores hyperparameters for the network
        init (bool): if True, initialize weights randomly. default=True
        f, df: activation function and its derivative (defaults imported from utils.functions)
        sig_rec_covariance (np.array): optional covariance matrix for the recurrent noise
        load_weights_file: if a filename is given, load weights from it instead of initializing
    """
    def __init__(self, params: RNNparams, init=True, f=f, df=df,
                 sig_rec_covariance=None, load_weights_file=False) -> None:
for key, value in dataclasses.asdict(params).items():
setattr(self, key, value)
if load_weights_file:
self.load_weights(load_weights_file)
elif init:
self.initialize_weights()
# Initialize
self.x_in = 0
self.h0 = np.zeros((self.n_rec,1)) # initial activity of the RNN
self.h = np.copy(self.h0)
self.y_out = np.zeros((self.n_out,1))
self.pos = np.zeros((self.n_out,1))
self.u = np.zeros((self.n_rec,1))
self.f = f
self.df = df
if self.velocity_transform:
self.vel = np.zeros((self.n_out,1))
assert self.dt_vel, "If applying a velocity transform, dt_vel must be specified"
else:
self.vel = None
        if self.driving_feedback:
            assert self.eta_fb is not None, "If driving feedback, eta_fb must be set"
            assert self.feedback_signal in ['position','error'], "feedback_signal must be one of {'position','error'}"
            assert self.sig_fb is not None, "If driving feedback, sig_fb must be set"
        # TODO: I don't want this to be here, but think it is necessary for probes
        self.r = None
        self.r_current = None
        self.err = np.zeros((self.n_out,1)) # does this have to be here?
""" Properties of recurrent noise """
if self.sig_rec_dim == None:
self.sig_rec_dim = self.n_rec # dimension of recurrent noise
assert self.sig_rec_dim <= self.n_rec, 'recurrent noise dimension must be less than or equal to number of recurrent units'
""" generate covariance matrix for recurrent noise
this is used for sampling a multivariate gaussian via rng.multivariate_normal(mean,cov)"""
# scenario1 - full rank, isotropic noise
if sig_rec_covariance is None and self.sig_rec_dim == self.n_rec:
self.sig_rec_covariance = self.sig_rec * np.eye(self.n_rec) # isotropic sample
# scenario 2 - low-D, isotropic noise
elif sig_rec_covariance is None and self.sig_rec_dim < self.n_rec:
C = self.sig_rec * np.eye(self.n_rec)
ind = self.rng.choice(np.arange(self.sig_rec_dim,dtype=int),1)
C[ind] = 0 # set some neurons to zero
self.sig_rec_covariance = C
# scenario 3 - noise is specified by covariance matrix
else:
assert sig_rec_covariance.shape[0] == self.n_rec, 'covariance matrix must have shape (n_rec,n_rec)'
assert sig_rec_covariance.shape[0] == self.n_rec, 'covariance matrix must have shape (n_rec,n_rec)'
# CHECK - NOTE THERE IS NO MULTIPLICATION BY self.sig_rec HERE
self.sig_rec_covariance = sig_rec_covariance
    def initialize_weights(self) -> None:
        """ Initialize all weights with the random number generator """
        self.w_in = self.g_in*(self.rng.rand(self.n_rec, self.n_in) - 1) # uniform in [-g_in, 0); changed from 0.1
        self.w_rec = self.g_rec*self.rng.randn(self.n_rec, self.n_rec)/self.n_rec**0.5
        self.w_out = self.g_out*(2*self.rng.rand(self.n_out, self.n_rec) - 1)/self.n_rec**0.5 # roughly uniform in [-g_out, g_out)/sqrt(n_rec)
        self.w_m = np.copy(self.w_out).T # CHANGE THIS
        if self.driving_feedback:
            self.w_fb = self.g_fb*self.rng.randn(self.n_rec,self.n_out)/self.n_rec**0.5
    def get_weights(self):
        weights = {
            "w_in": self.w_in,
            "w_rec": self.w_rec,
            "w_out": self.w_out,
            "w_m": self.w_m,
        }
        if self.driving_feedback:
            weights["w_fb"] = self.w_fb
        return weights
def save_weights(self, file):
np.savez(file, **self.get_weights())
def load_weights(self, file):
self.set_weights(**np.load(file))
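    # Round-trip sketch (the filename 'weights.npz' is illustrative):
    #     rnn.save_weights('weights.npz')
    #     rnn2 = RNN(params, load_weights_file='weights.npz')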
def set_weights(self,
w_in: Optional[np.array]=None,
w_rec: Optional[np.array]=None,
w_out: Optional[np.array]=None,
w_m: Optional[np.array]=None,
w_fb: Optional[np.array]=None) -> None:
""" Set weights with predefined values
The weight(s) to set must be specified
Args:
w_in: input matrix
w_rec: recurrent matrix
w_out: output matrix
w_m: "transpose" matrix that updates learning (in SL)
w_fb: feedback matrix that drives RNN activity
"""
        if w_in is not None:
            assert w_in.shape == (self.n_rec,self.n_in), 'Dimensions must be (n_rec,n_in)'
            self.w_in = w_in
        if w_rec is not None:
            assert w_rec.shape == (self.n_rec,self.n_rec), 'Dimensions must be (n_rec,n_rec)'
            self.w_rec = w_rec
        if w_out is not None:
            assert w_out.shape == (self.n_out,self.n_rec), 'Dimensions must be (n_out,n_rec)'
            self.w_out = w_out # should be on same scale as target
        if w_m is not None:
            assert w_m.shape == self.w_out.T.shape, 'Dimensions must be (n_rec,n_out)'
            self.w_m = w_m
        if w_fb is not None:
            assert self.driving_feedback, 'driving_feedback should be set to True'
            assert w_fb.shape == (self.n_rec,self.n_out), 'Dimensions must be (n_rec,n_out)'
            self.w_fb = w_fb
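    # Example sketch (the zero matrix below is illustrative):
    #     rnn.set_weights(w_out=np.zeros((params.n_out, params.n_rec)))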
    def next_state(self, x_in: np.array) -> None:
        """
        Advance the network forward by one step
        This implements the basic RNN activity update
            h <- h + (-h + f(u) + xi) * dt / tau_rec
        where u collects the recurrent, input, and (optionally) feedback drive.
        Args:
            x_in (np.array): external input
        """
        self.h_prev = np.copy(self.h)
        self.x_in_prev = self.x_in
        """ recurrent activity """
        input_drive = np.dot(self.w_in, x_in + self.sig_in*self.rng.randn(self.n_in,1))
        if self.driving_feedback:
            # Feedback signal is position (which seems to do better than error)
            if self.feedback_signal == 'position':
                fb = self.pos
            # Feedback signal here is error, not position
            else:
                fb = self.err
            self.u = np.dot(self.w_rec, self.h) + input_drive \
                + np.dot(self.w_fb, fb + self.sig_fb*self.rng.randn(self.n_out,1))
        else:
            self.u = np.dot(self.w_rec, self.h) + input_drive
        # update step
        self.xi = self._generate_recurrent_noise()
        self.h = self.h + (-self.h + self.f(self.u) + self.xi)*self.dt/self.tau_rec
        self.x_in = x_in
    def output(self) -> None:
        """ Readout of the RNN
        If there is no velocity transform, the readout is a direct mapping
        from the RNN activity to the position via the matrix 'w_out'.
        If there is a velocity transform, the readout is low-pass filtered
        into a cursor velocity, which is then integrated into a position.
        """
        self.y_prev = np.copy(self.y_out)
        # output
        self.y_out = np.dot(self.w_out, self.h) + self.sig_out*self.rng.randn(self.n_out,1)
        if self.velocity_transform:
            # cursor velocity: leaky integration of the readout with time constant tau_vel
            self.vel = (1-1/self.tau_vel)*self.vel + (1/self.tau_vel)*self.y_out
            # cursor position
            self.pos = self.pos + self.vel*self.dt_vel
        else:
            self.pos = self.y_out
    def _generate_recurrent_noise(self):
        """ Generate recurrent noise from a multivariate Gaussian
        This function generates noise that is injected into the recurrent units.
        Noise is sampled from a Gaussian distribution, and can be nonisotropic or low-D.
        Returns:
            xi: vector of dimension n_rec, divided by the square root of the integration step dt
        Note: It is up to the user to check that a specified covariance matrix is positive semidefinite
        """
        # sample from the multivariate Gaussian
        mean = np.zeros(self.sig_rec_covariance.shape[0])
        xi = self.rng.multivariate_normal(mean, cov=self.sig_rec_covariance, size=1).T # shape (n_rec,1)
        return xi/np.sqrt(self.dt)
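
if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming utils.functions supplies f and df
    # (imported at the top of this file). The sizes and noise scales below are
    # illustrative, not defaults taken from this repo.
    params = RNNparams(n_in=2, n_rec=50, n_out=2,
                       sig_in=0.01, sig_rec=0.01, sig_out=0.01,
                       tau_rec=10.0)
    rnn = RNN(params)
    x = np.ones((params.n_in, 1))   # constant external input
    for _ in range(100):
        rnn.next_state(x)           # advance the recurrent dynamics one step
        rnn.output()                # map activity to readout / position
    print('final position:', rnn.pos.ravel())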