vector_quantizer.py
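
This module implements the vector-quantization bottleneck from VQ-VAE (van den Oord et al., 2017): each input vector is snapped to its nearest codebook entry, a codebook loss and a commitment loss train the two sides of the bottleneck, and the straight-through estimator passes gradients through the non-differentiable quantization step. The structure closely mirrors the well-known Sonnet reference implementation.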
import tensorflow as tf
import tensorflow.keras as K

EPS = 1e-10


class VectorQuantizer(K.Model):
    """Vector-quantization layer in the style of VQ-VAE (van den Oord et al., 2017)."""

    def __init__(self, n_embeddings, embedding_dim, commitment_cost,
                 initializer=K.initializers.RandomUniform(-1, 1)):
        super(VectorQuantizer, self).__init__()
        self._embedding_dim = embedding_dim
        self._n_embeddings = n_embeddings
        self._commitment_cost = commitment_cost
        # Codebook of shape [embedding_dim, n_embeddings]; each column is one code vector.
        initial_w = initializer(shape=[embedding_dim, n_embeddings])
        self._w = tf.Variable(initial_value=initial_w, trainable=True,
                              name="embedding")

    def call(self, x):
        # Validate the input shape and flatten to [batch * ..., embedding_dim].
        tf.assert_equal(tf.shape(x)[-1], self._embedding_dim)
        flat_x = tf.reshape(x, (-1, self._embedding_dim))

        # Compute squared Euclidean distances from each vector in flat_x to every
        # embedding vector, expanded as ||x||^2 - 2 x.w + ||w||^2.
        distances = (tf.reduce_sum(flat_x ** 2, axis=1, keepdims=True)
                     - 2 * tf.matmul(flat_x, self._w)
                     + tf.reduce_sum(self._w ** 2, axis=0, keepdims=True))

        # Encode and quantize each input to its nearest embedding vector.
        encoding_indices = tf.argmin(distances, axis=1)
        encodings = tf.one_hot(encoding_indices, self._n_embeddings)
        encoding_indices = tf.reshape(encoding_indices, tf.shape(x)[:-1])
        quantized = self.quantize(encoding_indices)

        # Commitment (e) and codebook (q) latent losses; stop_gradient selects
        # which side of the quantization bottleneck each term trains.
        e_latent_loss = tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2)
        q_latent_loss = tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2)
        loss = q_latent_loss + self._commitment_cost * e_latent_loss

        # Straight-through estimator: gradients of quantized are copied over to x.
        quantized = x + tf.stop_gradient(quantized - x)

        # Perplexity of the average code usage, a measure of how many codes are in use.
        avg_probs = tf.reduce_mean(encodings, axis=0)
        perplexity = tf.math.exp(
            -tf.reduce_sum(avg_probs * tf.math.log(avg_probs + EPS)))

        return {
            'quantized': quantized,
            'loss': loss,
            'perplexity': perplexity,
            'encodings': encodings,
            'encoding_indices': encoding_indices,
        }

    @property
    def embeddings(self):
        return self._w

    def quantize(self, encoding_indices):
        # Transpose so rows are code vectors, then look them up by index.
        w = tf.transpose(self.embeddings, [1, 0])
        return tf.nn.embedding_lookup(w, encoding_indices)
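
A minimal usage sketch, assuming the class above is importable from this module; the hyperparameters (a 512-entry codebook of 64-dimensional codes, commitment cost 0.25) and the 8x8 feature-map shape are illustrative values, not taken from this repository:

import tensorflow as tf

# Hypothetical hyperparameters, for illustration only.
vq = VectorQuantizer(n_embeddings=512, embedding_dim=64, commitment_cost=0.25)

# A fake batch of encoder outputs: [batch, height, width, embedding_dim].
x = tf.random.normal([16, 8, 8, 64])
out = vq(x)

out['quantized'].shape         # (16, 8, 8, 64), same shape as the input
out['encoding_indices'].shape  # (16, 8, 8), one codebook index per position

# During training, add out['loss'] to the reconstruction loss; gradients reach
# x through 'quantized' via the straight-through estimator, so the encoder
# still trains even though argmin itself has no gradient.

The returned perplexity is a handy diagnostic: values near 1 indicate codebook collapse onto a single code, while values approaching n_embeddings mean the codes are used roughly uniformly.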