unit_metric_computers.py
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "3"
import logging
import cv2
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import scores


def _is_dead_unit(heatmap):
    """
    Given a unit's 2D heatmap, check if it is a dead unit.
    """
    # return np.allclose(heatmap, 0)
    # a unit is dead if less than 1% of the heatmap is active
    return np.sum(heatmap > 0) < 0.01 * heatmap.shape[0] * heatmap.shape[1]
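# Worked example of the 1% criterion (hypothetical map size, for illustration
# only): on a 17 x 17 ratemap, 0.01 * 17 * 17 = 2.89, so a unit with fewer
# than 3 non-zero bins is treated as dead.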


def _compute_single_heatmap_fields_info(
        heatmap,
        pixel_min_threshold,
        pixel_max_threshold
    ):
    """
    Given a 2D heatmap of a unit, compute:
        num_clusters, num_pixels_in_clusters, max_value_in_clusters,
        mean_value_in_clusters, var_value_in_clusters, heatmap_thresholded,
        mean_angle
    """
    # scaler = MinMaxScaler()
    # # normalize to [0, 1]
    # heatmap_normalized = scaler.fit_transform(heatmap)
    heatmap_normalized = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap))
    # convert to [0, 255]
    heatmap_gray = (heatmap_normalized * 255).astype(np.uint8)
    # compute activity threshold as the mean of the heatmap
    activity_threshold = np.mean(heatmap_gray)
    _, heatmap_thresholded = cv2.threshold(
        heatmap_gray, activity_threshold,
        255, cv2.THRESH_BINARY
    )
    # `num_labels` includes the background label 0 (e.g. 3 clusters give num_labels=4).
    # `labels` has the same shape as the heatmap, e.g. (17, 17).
    # `stats` \in (num_labels, 5): [left, top, width, height, area] for each label.
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(heatmap_thresholded)
    # Create a mask to filter clusters based on pixel thresholds,
    # e.g. mask=[False, True, False, True], one entry per label (i.e. a cluster)
    mask = (stats[:, cv2.CC_STAT_AREA] >= pixel_min_threshold) & \
           (stats[:, cv2.CC_STAT_AREA] <= pixel_max_threshold)
    # set background to False regardless of pixel thresholds
    mask[0] = False
    # Filter the stats and labels based on the mask,
    # e.g. filtered_stats.shape == (2, 5) if two clusters survive
    filtered_stats = stats[mask]
    # For labels with mask=True, keep the label, otherwise set to 0.
    # The background label 0 is therefore still present, so when using
    # `filtered_labels` to extract the max value in each cluster
    # we need to exclude 0.
    filtered_labels = np.where(np.isin(labels, np.nonzero(mask)[0]), labels, 0)
    # Count the number of clusters that meet the criteria
    num_clusters = np.array([filtered_stats.shape[0]])
    # Get the number of pixels in each cluster
    num_pixels_in_clusters = filtered_stats[:, cv2.CC_STAT_AREA]
    # Get the max/mean/var value in heatmap based on each cluster
    max_value_in_clusters = []
    max_value_indices_in_clusters = []
    mean_value_in_clusters = []
    var_value_in_clusters = []
    for label in np.unique(filtered_labels):
        if label != 0:
            max_value = np.max(heatmap[filtered_labels == label])
            max_value_in_clusters.append(np.around(max_value, 1))
            # locate the peak within this cluster (row, col), not anywhere in
            # the heatmap, in case the same value occurs outside the cluster
            peak_rows, peak_cols = np.where(
                (filtered_labels == label) & (heatmap == max_value)
            )
            max_value_indices_in_clusters.append((peak_rows[0], peak_cols[0]))
            mean_value_in_clusters.append(
                np.around(
                    np.mean(heatmap[filtered_labels == label]), 1
                )
            )
            var_value_in_clusters.append(
                np.around(
                    np.var(heatmap[filtered_labels == label]), 1
                )
            )
    # Add 0 to `num_pixels_in_clusters` and `max_value_in_clusters`
    # in case `num_clusters` is 0. This is helpful when plotting fields info
    # against coef, since there is always a coef for a unit regardless of
    # whether it has any cluster.
    if num_clusters[0] == 0:
        num_pixels_in_clusters = np.array([0])
        max_value_in_clusters = np.array([0])
        mean_value_in_clusters = np.array([0])
        var_value_in_clusters = np.array([0])
        max_value_indices_in_clusters = np.array([(0, 0)])
    else:
        max_value_in_clusters = np.array(max_value_in_clusters)
        mean_value_in_clusters = np.array(mean_value_in_clusters)
        var_value_in_clusters = np.array(var_value_in_clusters)
        max_value_indices_in_clusters = np.array(max_value_indices_in_clusters)
    # Draw the outline of each surviving cluster onto `heatmap_thresholded`
    # (one grayscale value per label).
    colors = np.arange(100, dtype=int).tolist()
    for label in np.unique(filtered_labels):
        if label != 0:
            # create a mask for each label
            mask = np.where(filtered_labels == label, 255, 0).astype(np.uint8)
            # find contours
            contours, _ = cv2.findContours(
                mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )
            # draw contours
            cv2.drawContours(heatmap_thresholded, contours, -1, colors[label-1], 1)
    # Compute the mean angle of the field peaks
    # wrt the center of the heatmap, which has coords (center_y, center_x).
    angles = []
    center_y, center_x = np.array(heatmap.shape) // 2
    for max_index in max_value_indices_in_clusters:
        x, y = max_index
        dx = x - center_x
        dy = y - center_y
        angle = np.degrees(np.arctan2(dy, dx))
        angles.append(angle)
    mean_angle = np.mean(angles)
    return num_clusters, num_pixels_in_clusters, max_value_in_clusters, \
        mean_value_in_clusters, var_value_in_clusters, heatmap_thresholded, mean_angle
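

# Illustrative usage sketch (synthetic input and hypothetical thresholds, not
# part of the original pipeline): a single Gaussian-like bump should yield one
# cluster whose peak sits at the bump centre.
def _example_fields_info_on_synthetic_bump():
    yy, xx = np.mgrid[0:17, 0:17]
    bump = np.exp(-((xx - 8) ** 2 + (yy - 8) ** 2) / (2 * 2.0 ** 2))
    results = _compute_single_heatmap_fields_info(
        bump, pixel_min_threshold=5, pixel_max_threshold=200
    )
    num_clusters, num_pixels, max_vals, mean_vals, var_vals, thresholded, mean_angle = results
    return num_clusters, num_pixels, mean_angle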


def _compute_single_heatmap_grid_scores(activation_map, smooth=False):
    """
    Compute 60- and 90-degree grid scores for a single ratemap using
    `scores.GridScorer`. Note: `smooth` is currently unused.
    """
    # mask parameters
    starts = [0.2] * 10
    ends = np.linspace(0.4, 1.0, num=10)
    masks_parameters = zip(starts, ends.tolist())
    scorer = scores.GridScorer(
        len(activation_map),         # nbins
        [0, len(activation_map)-1],  # coords_range
        masks_parameters             # parameters for the masks
    )
    score_60, score_90, max_60_mask, max_90_mask, sac = \
        scorer.get_scores(activation_map)
    return score_60, score_90, max_60_mask, max_90_mask, sac, scorer


def _compute_single_heatmap_border_scores(activation_map, db=3):
    """
    Border score: for each wall, compare the mean activation within `db` bins
    of that wall (b_i) against the mean interior activation (c) as
    (b_i - c) / (b_i + c), and return the max over the four walls.
    Banino et al. 2018 uses db=3.
    """
    num_bins = activation_map.shape[0]
    # Compute c (average activity for bins further than db bins from any wall)
    c = np.mean([
        activation_map[i, j]
        for i in range(db, num_bins - db)
        for j in range(db, num_bins - db)
    ])
    wall_scores = []
    # Compute the average activation for each wall
    for i in range(4):
        if i == 0:
            # Top wall
            activations = activation_map[:db, :]
        elif i == 1:
            # Right wall
            activations = activation_map[:, -db:]
        elif i == 2:
            # Bottom wall
            activations = activation_map[-db:, :]
        elif i == 3:
            # Left wall
            activations = activation_map[:, :db]
        bi = np.mean(activations)
        wall_scores.append((bi - c) / (bi + c))
    return np.max(wall_scores)
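

# Illustrative sketch (synthetic ratemap, not part of the original pipeline):
# a unit whose activity hugs one wall scores close to 1, since
# (b_i - c) / (b_i + c) -> 1 as the interior mean c approaches 0.
def _example_border_score_on_wall_hugging_unit():
    # a small baseline everywhere avoids a 0/0 wall score for silent walls
    ratemap = np.full((17, 17), 0.01)
    ratemap[:, :3] = 1.0  # strong activity within db=3 bins of the left wall
    return _compute_single_heatmap_border_scores(ratemap, db=3)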


def _compute_single_heatmap_directional_scores(activation_maps):
    """
    Args:
        `activation_maps`: the un-summed activation maps for a single unit
        across rotations, \in (n_locations, n_rotations).
        - The number of angular bins in Banino et al. corresponds to
          `n_rotations` here.
        - Following the Banino et al. equation, each rotation index is
          converted to an angle `alpha_i`.
        - The intensity `beta_i` of an angle is the average activation
          across all locations for that angle.
    """
    # model_reps \in (n_locations, n_rotations, n_features)
    # activation_maps \in (n_locations, n_rotations)
    num_bins = activation_maps.shape[1]
    alphas = np.linspace(0, 2*np.pi, num=num_bins, endpoint=False)
    betas = np.mean(activation_maps, axis=0)
    # Given a rotation, we can compute alpha_i and beta_i,
    # which are used to compute r_i in the equation.
    # We collect r_i for each rotation and compute the mean
    # vector, whose length is used as the directional score.
    polar_plot_coords = []  # (n_rotations, 2)
    per_rotation_vector_length = []
    for alpha_i, beta_i in zip(alphas, betas):
        polar_plot_coords.append(
            [beta_i*np.cos(alpha_i), beta_i*np.sin(alpha_i)]
        )
        per_rotation_vector_length.append(
            np.linalg.norm([beta_i*np.cos(alpha_i), beta_i*np.sin(alpha_i)])
        )
    # To compute the mean vector length,
    # first compute the sum of r_i normalized by the sum of beta_i,
    r_normed_by_beta = np.sum(
        np.array(polar_plot_coords), axis=0) / np.sum(betas)
    # then take the length of that mean vector.
    mean_vector_length = np.linalg.norm(r_normed_by_beta)
    logging.info(f'[Check] mean_vector_length: {mean_vector_length}')
    return mean_vector_length, per_rotation_vector_length
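

# Illustrative sketch (synthetic tuning curves, not part of the original
# pipeline): a unit that fires at only one rotation gives a mean vector length
# near 1, whereas uniform firing across rotations gives a length near 0.
def _example_directional_scores_on_synthetic_tuning(n_locations=100, n_rotations=24):
    tuned = np.zeros((n_locations, n_rotations))
    tuned[:, 0] = 1.0  # all activation concentrated at a single rotation
    uniform = np.ones((n_locations, n_rotations))
    tuned_score, _ = _compute_single_heatmap_directional_scores(tuned)
    uniform_score, _ = _compute_single_heatmap_directional_scores(uniform)
    return tuned_score, uniform_score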


def _unit_chart_type_classification(unit_chart_info):
    """
    Given a `unit_chart_info` array (one row per unit), classify the units into
    different types and return the indices of units by type or combo of types.

    Columns used here: 0 (dead/alive flag), 1 (number of fields),
    2 (field sizes), 3 (field peaks), 9 (border score),
    10 (directional score, i.e. mean vector length).
    """
    dead_units_indices = []
    max_num_clusters = np.max(unit_chart_info[:, 1])  # global max, used for setting xaxis.
    num_clusters = np.zeros(int(max_num_clusters) + 1)
    cluster_sizes = []
    cluster_peaks = []
    border_cell_indices = []
    place_cells_indices = []
    direction_cell_indices = []
    active_no_type_indices = []
    for unit_index in range(unit_chart_info.shape[0]):
        if unit_chart_info[unit_index, 0] == 0:
            dead_units_indices.append(unit_index)
        else:
            num_clusters[int(unit_chart_info[unit_index, 1])] += 1
            cluster_sizes.extend(unit_chart_info[unit_index, 2])
            cluster_peaks.extend(unit_chart_info[unit_index, 3])
            # place cell: at least one field
            if unit_chart_info[unit_index, 1] > 0:
                place_cells_indices.append(unit_index)
                is_place_cell = True
            else:
                is_place_cell = False
            # direction cell: mean vector length above threshold
            if unit_chart_info[unit_index, 10] > 0.47:
                direction_cell_indices.append(unit_index)
                is_direction_cell = True
            else:
                is_direction_cell = False
            # border cell: border score above threshold
            if unit_chart_info[unit_index, 9] > 0.5:
                border_cell_indices.append(unit_index)
                is_border_cell = True
            else:
                is_border_cell = False
            if not (is_place_cell or is_direction_cell or is_border_cell):
                active_no_type_indices.append(unit_index)
    # summary counts
    n_dead_units = len(dead_units_indices)
    n_active_units = unit_chart_info.shape[0] - n_dead_units
    # Collect the indices of units that are all three types
    # (place + border + direction)
    place_border_direction_cells_indices = \
        list(set(place_cells_indices) & set(border_cell_indices) & set(direction_cell_indices))
    # Collect the indices of units that are two types (inc. three types)
    # (place + border cells)
    # (place + direction cells)
    # (border + direction cells)
    place_and_border_cells_indices = \
        list(set(place_cells_indices) & set(border_cell_indices))
    place_and_direction_cells_indices = \
        list(set(place_cells_indices) & set(direction_cell_indices))
    border_and_direction_cells_indices = \
        list(set(border_cell_indices) & set(direction_cell_indices))
    # Collect the indices of units that are only two types
    # (place + border - direction),
    # (place + direction - border),
    # (border + direction - place)
    place_and_border_not_direction_cells_indices = \
        list(set(place_and_border_cells_indices) - set(place_border_direction_cells_indices))
    place_and_direction_not_border_cells_indices = \
        list(set(place_and_direction_cells_indices) - set(place_border_direction_cells_indices))
    border_and_direction_not_place_cells_indices = \
        list(set(border_and_direction_cells_indices) - set(place_border_direction_cells_indices))
    # Collect the indices of units that are exclusive
    # place cells,
    # border cells,
    # direction cells
    exclusive_place_cells_indices = \
        list(set(place_cells_indices) - (set(place_and_border_cells_indices) | set(place_and_direction_cells_indices)))
    exclusive_border_cells_indices = \
        list(set(border_cell_indices) - (set(place_and_border_cells_indices) | set(border_and_direction_cells_indices)))
    exclusive_direction_cells_indices = \
        list(set(direction_cell_indices) - (set(place_and_direction_cells_indices) | set(border_and_direction_cells_indices)))
    results = {
        'dead_units_indices': dead_units_indices,
        'place_border_direction_cells_indices': place_border_direction_cells_indices,
        'place_and_border_not_direction_cells_indices': place_and_border_not_direction_cells_indices,
        'place_and_direction_not_border_cells_indices': place_and_direction_not_border_cells_indices,
        'border_and_direction_not_place_cells_indices': border_and_direction_not_place_cells_indices,
        'exclusive_place_cells_indices': exclusive_place_cells_indices,
        'exclusive_border_cells_indices': exclusive_border_cells_indices,
        'exclusive_direction_cells_indices': exclusive_direction_cells_indices,
        'active_no_type_indices': active_no_type_indices,
    }
    assert unit_chart_info.shape[0] == sum([len(v) for v in results.values()])
    # Check all values are mutually exclusive
    for key, value in results.items():
        for key2, value2 in results.items():
            if key != key2:
                assert len(set(value) & set(value2)) == 0, f'{key} and {key2} have common elements'
    return results
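

# Hedged sketch of the expected input layout (column indices inferred from the
# checks above; the real chart is assembled elsewhere in the repo): each row is
# a unit, column 0 flags dead units, column 1 counts fields, columns 2-3 hold
# per-field sizes/peaks, column 9 the border score, column 10 the directional
# score. Three toy units: one dead, one exclusive place cell, one exclusive
# border cell.
def _example_unit_chart_classification():
    unit_chart_info = np.empty((3, 11), dtype=object)
    unit_chart_info[:, :] = 0
    # unit 0: dead (column 0 stays 0)
    # unit 1: alive, one place field, weak border/directional tuning
    unit_chart_info[1, 0] = 1
    unit_chart_info[1, 1] = 1
    unit_chart_info[1, 2] = [20]    # field size in pixels
    unit_chart_info[1, 3] = [5.0]   # field peak rate
    unit_chart_info[1, 9] = 0.1
    unit_chart_info[1, 10] = 0.1
    # unit 2: alive, no field, strong border tuning
    unit_chart_info[2, 0] = 1
    unit_chart_info[2, 2] = [0]
    unit_chart_info[2, 3] = [0]
    unit_chart_info[2, 9] = 0.8
    return _unit_chart_type_classification(unit_chart_info)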