-
Notifications
You must be signed in to change notification settings - Fork 0
/
kwta.py
120 lines (100 loc) · 2.88 KB
/
kwta.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
Iterative winners-take-all and k-winners-take-all activation functions.
"""
import numpy as np
def kWTA(x, k):
    """
    Default k-winners-take-all activation function.

    Parameters
    ----------
    x : np.ndarray
        A presynaptic sum (N,) vector or (N, S) array of vector samples.
    k : int
        The number of active neurons in the output (per sample/column).

    Returns
    -------
    sdr : np.ndarray
        A binary vector or array with the same shape and dtype as `x` and
        exactly `k` neurons active per column (all ``N`` neurons when
        ``k >= N``). Among tied values, which neurons win is decided by
        ``np.argsort``.

    Raises
    ------
    ValueError
        If `k` is negative.
    """
    x = np.asarray(x)
    if k < 0:
        # A negative k would turn ``[-k:]`` into "all but |k|" winners.
        raise ValueError(f"k must be non-negative, got {k}")
    if k == 0:
        return np.zeros_like(x)
    is_vector = x.ndim == 1
    if is_vector:
        # Treat a single vector as a one-sample (N, 1) batch.
        x = x[:, np.newaxis]
    # Row indices of the k largest entries in every column: shape (k, S).
    winners = np.argsort(x, axis=0)[-k:]
    sdr = np.zeros_like(x)
    # The (S,) column index broadcasts against the (k, S) winner rows.
    sdr[winners, np.arange(x.shape[1])] = 1
    # Drop only the axis added above; a plain squeeze() would also collapse
    # genuine size-1 axes (S == 1 or N == 1) and change the output shape.
    return sdr[:, 0] if is_vector else sdr
def iWTA(x, w_xh, w_xy, w_hy, w_yy=None, w_hh=None, w_yh=None):
    """
    Iterative winners-take-all activation function.

    The threshold starts at the largest feed-forward drive and decreases by
    one per step down to 1. At every step the net input of each neuron is
    recomputed from the activity accumulated so far; a neuron whose net
    input reaches the threshold becomes active and stays active.

    Parameters
    ----------
    x : (Nx, S) np.ndarray
        The input samples.
    w_xh, w_xy, w_hy, w_yy, w_hh, w_yh : np.ndarray or None
        Binary weights. ``w_<src><dst>`` maps population ``src`` onto
        ``dst`` and is applied as ``w @ src`` (e.g. ``w_xh`` is
        ``(Nh, Nx)``, ``w_hy`` is ``(Ny, Nh)``). ``w_yy``, ``w_hh`` and
        ``w_yh`` are optional and skipped when None.

    Returns
    -------
    h : (Nh, S) np.ndarray
        Inhibitory populations output (int32, values 0/1).
    y : (Ny, S) np.ndarray
        Excitatory populations output (int32, values 0/1).
    """
    h0 = w_xh @ x  # feed-forward drive of the inhibitory population
    y0 = w_xy @ x  # feed-forward drive of the excitatory population
    h = np.zeros_like(h0, dtype=np.int32)
    y = np.zeros_like(y0, dtype=np.int32)
    # Highest threshold worth testing. Empty arrays are skipped because
    # ``.max()`` raises on them, and the value is cast to int so that
    # float-typed binary weights still work with range().
    peaks = [drive.max() for drive in (h0, y0) if drive.size]
    t_start = int(max(peaks, default=0))
    for threshold in range(t_start, 0, -1):
        # Both populations are evaluated against the previous step's activity.
        z_h = h0
        if w_hh is not None:
            z_h = z_h - w_hh @ h  # inhibitory self-inhibition
        if w_yh is not None:
            z_h = z_h + w_yh @ y  # excitatory -> inhibitory
        z_h = z_h >= threshold
        z_y = y0 - w_hy @ h  # inhibitory -> excitatory
        if w_yy is not None:
            # Not in-place: avoids a casting error for int drive + float weights.
            z_y = z_y + w_yy @ y
        z_y = z_y >= threshold
        # Once active, a neuron stays active.
        h |= z_h
        y |= z_y
    return h, y
def iWTA_history(x, w_xh, w_xy, w_hy, w_yy=None, w_hh=None, w_yh=None):
    """
    Iterative winners-take-all activation function history. Used in Vogels'
    update rule.

    The threshold starts at the largest feed-forward drive and decreases by
    one per step down to 1. At every step the net input of each neuron is
    recomputed from the activity accumulated so far, and the boolean mask of
    neurons whose net input reached the threshold is recorded.

    Parameters
    ----------
    x : (Nx, S) np.ndarray
        The input samples.
    w_xh, w_xy, w_hy, w_yy, w_hh, w_yh : np.ndarray or None
        Binary weights. ``w_<src><dst>`` maps population ``src`` onto
        ``dst`` and is applied as ``w @ src`` (e.g. ``w_xh`` is
        ``(Nh, Nx)``, ``w_hy`` is ``(Ny, Nh)``). ``w_yy``, ``w_hh`` and
        ``w_yh`` are optional and skipped when None.

    Returns
    -------
    z_h, z_y : tuple
        Tuples with one entry per threshold step (highest threshold first).
        Each entry of `z_h` is a boolean ``(Nh, S)`` array and each entry of
        `z_y` a boolean ``(Ny, S)`` array marking the inhibitory/excitatory
        neurons whose net input reached that step's threshold. Both tuples
        are empty when no step runs (e.g. the input yields no positive drive).
    """
    h0 = w_xh @ x  # feed-forward drive of the inhibitory population
    y0 = w_xy @ x  # feed-forward drive of the excitatory population
    h = np.zeros_like(h0, dtype=np.int32)
    y = np.zeros_like(y0, dtype=np.int32)
    # Highest threshold worth testing. Empty arrays are skipped because
    # ``.max()`` raises on them, and the value is cast to int so that
    # float-typed binary weights still work with range().
    peaks = [drive.max() for drive in (h0, y0) if drive.size]
    t_start = int(max(peaks, default=0))
    history_h = []
    history_y = []
    for threshold in range(t_start, 0, -1):
        # Both populations are evaluated against the previous step's activity.
        z_h = h0
        if w_hh is not None:
            z_h = z_h - w_hh @ h  # inhibitory self-inhibition
        if w_yh is not None:
            z_h = z_h + w_yh @ y  # excitatory -> inhibitory
        z_h = z_h >= threshold
        z_y = y0 - w_hy @ h  # inhibitory -> excitatory
        if w_yy is not None:
            # Not in-place: avoids a casting error for int drive + float weights.
            z_y = z_y + w_yy @ y
        z_y = z_y >= threshold
        # Once active, a neuron stays active.
        h |= z_h
        y |= z_y
        history_h.append(z_h)
        history_y.append(z_y)
    # Separate lists instead of ``zip(*history)``: unpacking an empty zip
    # raised ValueError whenever no threshold step ran.
    return tuple(history_h), tuple(history_y)