-
Notifications
You must be signed in to change notification settings - Fork 101
/
capsule_layers.py
353 lines (294 loc) · 14.9 KB
/
capsule_layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
'''
Capsules for Object Segmentation (SegCaps)
Original Paper: https://arxiv.org/abs/1804.04241
Code written by: Rodney LaLonde
If you use significant portions of this code or the ideas from our paper, please cite it :)
If you have any questions, please email me at lalonde@knights.ucf.edu.
This file contains the definitions of the various capsule layers and dynamic routing and squashing functions.
'''
import keras.backend as K
import tensorflow as tf
from keras import initializers, layers
from keras.utils.conv_utils import conv_output_length, deconv_length
import numpy as np
class Length(layers.Layer):
    """Layer that outputs the Euclidean length of capsule vectors.

    The norm is taken over the last axis of the input and returned with a
    trailing singleton axis. A rank-5 input must have size 1 on its
    second-to-last axis; that axis is squeezed away before the norm is taken.

    Arguments:
        num_classes: number of classes; a value of 2 is stored as 1.
        seg: when True, `compute_output_shape` appends `num_classes` as the
            last dimension; when False it drops the last dimension instead.
    """

    def __init__(self, num_classes, seg=True, **kwargs):
        super(Length, self).__init__(**kwargs)
        # Two classes collapse to a single output channel.
        self.num_classes = 1 if num_classes == 2 else num_classes
        self.seg = seg

    def call(self, inputs, **kwargs):
        """Return the norm over the last axis of `inputs`, keeping a size-1 last axis."""
        capsules = inputs
        if capsules.get_shape().ndims == 5:
            assert capsules.get_shape()[-2].value == 1, 'Error: Must have num_capsules = 1 going into Length'
            capsules = K.squeeze(capsules, axis=-2)
        lengths = tf.norm(capsules, axis=-1)
        return K.expand_dims(lengths, axis=-1)

    def compute_output_shape(self, input_shape):
        """Return the static output shape for the given `input_shape` tuple."""
        shape = input_shape
        if len(shape) == 5:
            # Drop the capsule axis (second to last), keeping the atoms axis.
            shape = shape[0:-2] + shape[-1:]
        leading = shape[:-1]
        if not self.seg:
            return leading
        return leading + (self.num_classes,)

    def get_config(self):
        """Return the layer configuration, merged over the base layer's config."""
        merged = dict(super(Length, self).get_config())
        merged.update({'num_classes': self.num_classes, 'seg': self.seg})
        return merged
class Mask(layers.Layer):
    """Layer that masks capsule outputs.

    Called with a two-element list `[capsules, mask]`, the mask is expanded
    with a trailing axis and multiplied into the capsules. Called with a
    single rank-3 tensor, a one-hot mask is built from the capsule with the
    largest vector length; a single tensor of any other rank is returned
    unchanged.

    Arguments:
        resize_masks: when True (list input only), the mask is bicubically
            resized to the capsule tensor's height and width before use.
            This requires a rank-5 capsule tensor.
    """

    def __init__(self, resize_masks=False, **kwargs):
        super(Mask, self).__init__(**kwargs)
        self.resize_masks = resize_masks

    def call(self, inputs, **kwargs):
        """Apply the mask and return the masked tensor.

        Rank-3 capsule tensors are flattened with `K.batch_flatten` after
        masking; other ranks keep their shape.
        """
        if isinstance(inputs, list):
            assert len(inputs) == 2
            capsules, mask = inputs
            if self.resize_masks:
                # The spatial size is only needed when resizing. Reading it
                # unconditionally made the rank-3 branch below unreachable,
                # because the 5-way unpack fails on a rank-3 shape.
                _, hei, wid, _, _ = capsules.get_shape()
                mask = tf.image.resize_bicubic(mask, (hei.value, wid.value))
            mask = K.expand_dims(mask, -1)
            if capsules.get_shape().ndims == 3:
                masked = K.batch_flatten(mask * capsules)
            else:
                masked = mask * capsules
        else:
            if inputs.get_shape().ndims == 3:
                # Vector length of each capsule, then one-hot of the longest.
                x = K.sqrt(K.sum(K.square(inputs), -1))
                mask = K.one_hot(indices=K.argmax(x, 1), num_classes=x.get_shape().as_list()[1])
                masked = K.batch_flatten(K.expand_dims(mask, -1) * inputs)
            else:
                masked = inputs
        return masked

    def compute_output_shape(self, input_shape):
        """Return the output shape for a single shape or a list of shapes."""
        if isinstance(input_shape[0], tuple):  # true label provided
            capsule_shape = input_shape[0]
        else:  # no true label provided
            capsule_shape = input_shape
        if len(capsule_shape) == 3:
            # Rank-3 inputs are flattened to [batch, num_capsule * num_atoms].
            return (None, capsule_shape[1] * capsule_shape[2])
        return capsule_shape

    def get_config(self):
        """Return the layer configuration, merged over the base layer's config."""
        config = {'resize_masks': self.resize_masks}
        base_config = super(Mask, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
class ConvCapsuleLayer(layers.Layer):
    """Convolutional capsule layer with dynamic routing.

    Every input capsule type is convolved with the same kernel `W` to produce
    votes for each output capsule; the votes are then combined by
    `update_routing`.

    Input shape:  [batch, input_height, input_width, input_num_capsule, input_num_atoms]
    Output shape: [batch, new_height, new_width, num_capsule, num_atoms]

    Arguments:
        kernel_size: side length of the square convolution kernel.
        num_capsule: number of output capsule types.
        num_atoms: length of each output capsule vector.
        strides: convolution stride, used for both spatial dimensions.
        padding: padding mode passed to the convolution.
        routings: number of routing iterations.
        kernel_initializer: initializer for the kernel `W`.
    """

    def __init__(self, kernel_size, num_capsule, num_atoms, strides=1, padding='same', routings=3,
                 kernel_initializer='he_normal', **kwargs):
        super(ConvCapsuleLayer, self).__init__(**kwargs)
        self.kernel_size = kernel_size
        self.num_capsule = num_capsule
        self.num_atoms = num_atoms
        self.strides = strides
        self.padding = padding
        self.routings = routings
        self.kernel_initializer = initializers.get(kernel_initializer)

    def build(self, input_shape):
        """Create the kernel `W` and bias `b` from the static input shape."""
        assert len(input_shape) == 5, "The input Tensor should have shape=[None, input_height, input_width," \
                                      " input_num_capsule, input_num_atoms]"
        self.input_height = input_shape[1]
        self.input_width = input_shape[2]
        self.input_num_capsule = input_shape[3]
        self.input_num_atoms = input_shape[4]

        # Transform matrix, shared by all input capsule types.
        self.W = self.add_weight(shape=[self.kernel_size, self.kernel_size,
                                        self.input_num_atoms, self.num_capsule * self.num_atoms],
                                 initializer=self.kernel_initializer,
                                 name='W')

        self.b = self.add_weight(shape=[1, 1, self.num_capsule, self.num_atoms],
                                 initializer=initializers.constant(0.1),
                                 name='b')

        self.built = True

    def call(self, input_tensor, training=None):
        """Compute votes by convolution and route them to the output capsules."""
        # [batch, H, W, in_caps, in_atoms] -> [batch, in_caps, H, W, in_atoms].
        # The flatten below must be batch-major so that reshaping the votes
        # back to [batch, in_caps, ...] restores the same (sample, capsule)
        # pairing. Flattening capsule-major and reshaping batch-major mixes
        # votes between samples whenever batch > 1 and in_caps > 1.
        input_transposed = tf.transpose(input_tensor, [0, 3, 1, 2, 4])
        input_shape = K.shape(input_transposed)
        batch_size = input_shape[0]
        in_caps = input_shape[1]

        # Fold the capsule-type axis into the batch axis for a plain 2D conv.
        input_tensor_reshaped = K.reshape(input_transposed, [
            batch_size * in_caps, self.input_height, self.input_width, self.input_num_atoms])
        input_tensor_reshaped.set_shape((None, self.input_height, self.input_width, self.input_num_atoms))

        conv = K.conv2d(input_tensor_reshaped, self.W, (self.strides, self.strides),
                        padding=self.padding, data_format='channels_last')

        votes_shape = K.shape(conv)
        _, conv_height, conv_width, _ = conv.get_shape()

        # Unfold to [batch, in_caps, new_H, new_W, num_capsule, num_atoms].
        votes = K.reshape(conv, [batch_size, in_caps, votes_shape[1], votes_shape[2],
                                 self.num_capsule, self.num_atoms])
        votes.set_shape((None, self.input_num_capsule, conv_height.value, conv_width.value,
                         self.num_capsule, self.num_atoms))

        logit_shape = K.stack([
            batch_size, in_caps, votes_shape[1], votes_shape[2], self.num_capsule])
        # Replicate the bias over every output spatial position.
        biases_replicated = K.tile(self.b, [conv_height.value, conv_width.value, 1, 1])

        activations = update_routing(
            votes=votes,
            biases=biases_replicated,
            logit_shape=logit_shape,
            num_dims=6,
            input_dim=self.input_num_capsule,
            output_dim=self.num_capsule,
            num_routing=self.routings)

        return activations

    def compute_output_shape(self, input_shape):
        """Return [batch, new_height, new_width, num_capsule, num_atoms]."""
        new_space = [
            conv_output_length(
                dim,
                self.kernel_size,
                padding=self.padding,
                stride=self.strides,
                dilation=1)
            for dim in input_shape[1:-2]
        ]
        return (input_shape[0],) + tuple(new_space) + (self.num_capsule, self.num_atoms)

    def get_config(self):
        """Return the layer configuration, merged over the base layer's config."""
        config = {
            'kernel_size': self.kernel_size,
            'num_capsule': self.num_capsule,
            'num_atoms': self.num_atoms,
            'strides': self.strides,
            'padding': self.padding,
            'routings': self.routings,
            'kernel_initializer': initializers.serialize(self.kernel_initializer)
        }
        base_config = super(ConvCapsuleLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
class DeconvCapsuleLayer(layers.Layer):
    """Upsampling capsule layer with dynamic routing.

    Every input capsule type is upsampled with the same kernel `W` to produce
    votes for each output capsule; the votes are then combined by
    `update_routing`.

    Input shape:  [batch, input_height, input_width, input_num_capsule, input_num_atoms]
    Output shape: [batch, new_height, new_width, num_capsule, num_atoms]

    Arguments:
        kernel_size: side length of the square kernel.
        num_capsule: number of output capsule types.
        num_atoms: length of each output capsule vector.
        scaling: spatial upsampling factor.
        upsamp_type: one of 'deconv' (transposed convolution), 'resize'
            (image resize followed by a convolution) or 'subpix'
            (convolution followed by depth-to-space).
        padding: padding mode for the 'deconv' and 'resize' convolutions;
            the 'subpix' convolution always uses 'same'.
        routings: number of routing iterations.
        kernel_initializer: initializer for the kernel `W`.
    """

    def __init__(self, kernel_size, num_capsule, num_atoms, scaling=2, upsamp_type='deconv', padding='same', routings=3,
                 kernel_initializer='he_normal', **kwargs):
        super(DeconvCapsuleLayer, self).__init__(**kwargs)
        self.kernel_size = kernel_size
        self.num_capsule = num_capsule
        self.num_atoms = num_atoms
        self.scaling = scaling
        self.upsamp_type = upsamp_type
        self.padding = padding
        self.routings = routings
        self.kernel_initializer = initializers.get(kernel_initializer)

    def build(self, input_shape):
        """Create the kernel `W` (shape depends on `upsamp_type`) and bias `b`.

        Raises:
            NotImplementedError: if `upsamp_type` is not a supported value.
        """
        assert len(input_shape) == 5, "The input Tensor should have shape=[None, input_height, input_width," \
                                      " input_num_capsule, input_num_atoms]"
        self.input_height = input_shape[1]
        self.input_width = input_shape[2]
        self.input_num_capsule = input_shape[3]
        self.input_num_atoms = input_shape[4]

        # Transform matrix
        if self.upsamp_type == 'subpix':
            # Extra scaling**2 output channels are rearranged into space later.
            self.W = self.add_weight(shape=[self.kernel_size, self.kernel_size,
                                            self.input_num_atoms,
                                            self.num_capsule * self.num_atoms * self.scaling * self.scaling],
                                     initializer=self.kernel_initializer,
                                     name='W')
        elif self.upsamp_type == 'resize':
            self.W = self.add_weight(shape=[self.kernel_size, self.kernel_size,
                                            self.input_num_atoms, self.num_capsule * self.num_atoms],
                                     initializer=self.kernel_initializer, name='W')
        elif self.upsamp_type == 'deconv':
            # Transposed-convolution kernels put output channels before input channels.
            self.W = self.add_weight(shape=[self.kernel_size, self.kernel_size,
                                            self.num_capsule * self.num_atoms, self.input_num_atoms],
                                     initializer=self.kernel_initializer, name='W')
        else:
            raise NotImplementedError('Upsampling must be one of: "deconv", "resize", or "subpix"')

        self.b = self.add_weight(shape=[1, 1, self.num_capsule, self.num_atoms],
                                 initializer=initializers.constant(0.1),
                                 name='b')

        self.built = True

    def call(self, input_tensor, training=None):
        """Compute upsampled votes and route them to the output capsules."""
        # [batch, H, W, in_caps, in_atoms] -> [batch, in_caps, H, W, in_atoms].
        # The flatten below must be batch-major so that reshaping the votes
        # back to [batch, in_caps, ...] restores the same (sample, capsule)
        # pairing. Flattening capsule-major and reshaping batch-major mixes
        # votes between samples whenever batch > 1 and in_caps > 1.
        input_transposed = tf.transpose(input_tensor, [0, 3, 1, 2, 4])
        input_shape = K.shape(input_transposed)
        batch_size = input_shape[0]
        in_caps = input_shape[1]

        # Fold the capsule-type axis into the batch axis for plain 2D ops.
        input_tensor_reshaped = K.reshape(input_transposed, [
            batch_size * in_caps, self.input_height, self.input_width, self.input_num_atoms])
        input_tensor_reshaped.set_shape((None, self.input_height, self.input_width, self.input_num_atoms))

        if self.upsamp_type == 'resize':
            upsamp = K.resize_images(input_tensor_reshaped, self.scaling, self.scaling, 'channels_last')
            outputs = K.conv2d(upsamp, kernel=self.W, strides=(1, 1), padding=self.padding, data_format='channels_last')
        elif self.upsamp_type == 'subpix':
            conv = K.conv2d(input_tensor_reshaped, kernel=self.W, strides=(1, 1), padding='same',
                            data_format='channels_last')
            outputs = tf.depth_to_space(conv, self.scaling)
        else:
            folded_batch = batch_size * in_caps

            # Infer the dynamic output shape:
            out_height = deconv_length(self.input_height, self.scaling, self.kernel_size, self.padding)
            out_width = deconv_length(self.input_width, self.scaling, self.kernel_size, self.padding)
            output_shape = (folded_batch, out_height, out_width, self.num_capsule * self.num_atoms)

            outputs = K.conv2d_transpose(input_tensor_reshaped, self.W, output_shape, (self.scaling, self.scaling),
                                         padding=self.padding, data_format='channels_last')

        votes_shape = K.shape(outputs)
        _, conv_height, conv_width, _ = outputs.get_shape()

        # Unfold to [batch, in_caps, new_H, new_W, num_capsule, num_atoms].
        votes = K.reshape(outputs, [batch_size, in_caps, votes_shape[1], votes_shape[2],
                                    self.num_capsule, self.num_atoms])
        votes.set_shape((None, self.input_num_capsule, conv_height.value, conv_width.value,
                         self.num_capsule, self.num_atoms))

        logit_shape = K.stack([
            batch_size, in_caps, votes_shape[1], votes_shape[2], self.num_capsule])
        # Replicate the bias over every output spatial position.
        biases_replicated = K.tile(self.b, [votes_shape[1], votes_shape[2], 1, 1])

        activations = update_routing(
            votes=votes,
            biases=biases_replicated,
            logit_shape=logit_shape,
            num_dims=6,
            input_dim=self.input_num_capsule,
            output_dim=self.num_capsule,
            num_routing=self.routings)

        return activations

    def _output_length(self, length):
        """Return the upsampled size of one spatial dimension, matching `call`."""
        if length is None:
            return None
        if self.upsamp_type == 'subpix':
            # 'same' convolution keeps the size; depth_to_space multiplies it.
            return length * self.scaling
        if self.upsamp_type == 'resize':
            # Resize multiplies the size; the stride-1 convolution may shrink it.
            return conv_output_length(length * self.scaling, self.kernel_size,
                                      padding=self.padding, stride=1, dilation=1)
        return deconv_length(length, self.scaling, self.kernel_size, self.padding)

    def compute_output_shape(self, input_shape):
        """Return [batch, new_height, new_width, num_capsule, num_atoms]."""
        output_shape = list(input_shape)

        output_shape[1] = self._output_length(output_shape[1])
        output_shape[2] = self._output_length(output_shape[2])
        output_shape[3] = self.num_capsule
        output_shape[4] = self.num_atoms

        return tuple(output_shape)

    def get_config(self):
        """Return the layer configuration, merged over the base layer's config."""
        config = {
            'kernel_size': self.kernel_size,
            'num_capsule': self.num_capsule,
            'num_atoms': self.num_atoms,
            'scaling': self.scaling,
            'padding': self.padding,
            'upsamp_type': self.upsamp_type,
            'routings': self.routings,
            'kernel_initializer': initializers.serialize(self.kernel_initializer)
        }
        base_config = super(DeconvCapsuleLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
def update_routing(votes, biases, logit_shape, num_dims, input_dim, output_dim,
                   num_routing):
    """Run dynamic routing over `votes` and return the final activations.

    Routing logits start at zero. Each iteration takes a softmax of the
    logits over the last axis, uses it to weight the votes, sums over the
    input-capsule axis (axis 1), adds `biases`, and squashes the result. The
    logits are then increased by the dot product between each vote and the
    resulting activation.

    Arguments:
        votes: tensor of rank `num_dims` with the input-capsule axis at
            position 1 and the atoms axis last. The capsule layers in this
            file pass [batch, input_dim, height, width, output_dim, atoms].
        biases: tensor added to the summed votes before squashing; it must
            broadcast against `votes` with axis 1 removed.
        logit_shape: shape of the routing logits, i.e. the shape of `votes`
            without its last axis.
        num_dims: rank of `votes`; 6 and 4 are supported.
        input_dim: number of input capsule types (size of axis 1 of `votes`).
        output_dim: number of output capsule types (currently unused).
        num_routing: number of routing iterations.

    Returns:
        The float32 activation tensor from the last routing iteration.

    Raises:
        NotImplementedError: if `num_dims` is neither 6 nor 4.
    """
    if num_dims == 6:
        votes_t_shape = [5, 0, 1, 2, 3, 4]
        r_t_shape = [1, 2, 3, 4, 5, 0]
    elif num_dims == 4:
        votes_t_shape = [3, 0, 1, 2]
        r_t_shape = [1, 2, 3, 0]
    else:
        raise NotImplementedError('Not implemented')

    # Atoms axis first, so the routing weights broadcast over it.
    # (An unused 6-way unpack of this tensor's shape used to sit here; it
    # raised for num_dims == 4 and has been removed.)
    votes_trans = tf.transpose(votes, votes_t_shape)

    # Tile multiples that replicate an activation across the input capsules.
    tile_shape = [1] * num_dims
    tile_shape[1] = input_dim

    def _body(i, logits, activations):
        """Routing while loop."""
        # route: softmax of the logits over the last (output capsule) axis.
        route = tf.nn.softmax(logits, dim=-1)
        preactivate_unrolled = route * votes_trans
        # Move the atoms axis back to the end.
        preact_trans = tf.transpose(preactivate_unrolled, r_t_shape)
        # Sum the weighted votes over the input capsules.
        preactivate = tf.reduce_sum(preact_trans, axis=1) + biases
        activation = _squash(preactivate)
        activations = activations.write(i, activation)

        # Agreement: dot product of every vote with the current activation.
        act_3d = K.expand_dims(activation, 1)
        act_replicated = tf.tile(act_3d, tile_shape)
        distances = tf.reduce_sum(votes * act_replicated, axis=-1)
        logits += distances
        return (i + 1, logits, activations)

    activations = tf.TensorArray(
        dtype=tf.float32, size=num_routing, clear_after_read=False)
    logits = tf.fill(logit_shape, 0.0)

    i = tf.constant(0, dtype=tf.int32)
    _, logits, activations = tf.while_loop(
        lambda i, logits, activations: i < num_routing,
        _body,
        loop_vars=[i, logits, activations],
        swap_memory=True)

    return K.cast(activations.read(num_routing - 1), dtype='float32')
def _squash(input_tensor):
    """Squash capsule vectors along the last axis.

    Each vector `v` is rescaled to `v / |v| * |v|^2 / (1 + |v|^2)`, so its
    direction is kept and its length is mapped into [0, 1).

    The scale is computed as `|v|^2 / ((1 + |v|^2) * sqrt(|v|^2 + eps))`
    with `eps = K.epsilon()`. This avoids dividing by zero: an all-zero
    vector maps to zero instead of NaN.

    Arguments:
        input_tensor: tensor whose last axis holds the capsule vectors.

    Returns:
        A tensor with the same shape as `input_tensor`.
    """
    norm_squared = K.sum(K.square(input_tensor), axis=-1, keepdims=True)
    scale = norm_squared / (1 + norm_squared) / K.sqrt(norm_squared + K.epsilon())
    return input_tensor * scale