diff --git a/quantizer_pytorch/quantizer.py b/quantizer_pytorch/quantizer.py
index 95d9a76..95c0544 100644
--- a/quantizer_pytorch/quantizer.py
+++ b/quantizer_pytorch/quantizer.py
@@ -236,20 +236,17 @@ def update_codebooks(self, q: Tensor, onehot: Tensor) -> None:
         """Update codebooks embeddings with EMA"""
 
         # Update codebook cluster sizes with EMA
-        batch_cluster_size = reduce(onehot, "b h n m -> b h m", "sum")
-        avg_cluster_size = reduce(batch_cluster_size, "b h m -> h m", "mean")
-        ema_inplace(self.ema_cluster_size, avg_cluster_size, self.ema_decay)
+        batch_cluster_size = reduce(onehot, "b h n m -> h m", "sum")
+        ema_inplace(self.ema_cluster_size, batch_cluster_size, self.ema_decay)
 
         # Update codebook embedding sums with EMA
-        batch_embedding_sum = einsum("b h n m, b h n d -> b h m d", onehot, q)
-        avg_embedding_sum = reduce(batch_embedding_sum, "b h m d -> h m d", "mean")
-        ema_inplace(self.ema_embedding_sum, avg_embedding_sum, self.ema_decay)
+        batch_embedding_sum = einsum("b h n m, b h n d -> h m d", onehot, q)
+        ema_inplace(self.ema_embedding_sum, batch_embedding_sum, self.ema_decay)
 
         # Update codebook embedding by averaging vectors
-        new_codebooks = self.ema_embedding_sum / rearrange(
+        self.codebooks = self.ema_embedding_sum / rearrange(
             self.ema_cluster_size + 1e-5, "h m -> h m 1"  # type: ignore
         )
-        self.codebooks = new_codebooks
 
     def expire_dead_codes(self, x: Tensor) -> Tensor:
         """Replaces dead codes in codebook with random batch elements"""
diff --git a/setup.py b/setup.py
index dd181ce..0525ae5 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
     name="quantizer-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.11",
+    version="0.0.12",
     license="MIT",
     description="Quantizer - PyTorch",
     long_description_content_type="text/markdown",
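
A minimal standalone sketch of what the patched update_codebooks now computes: the cluster-size and embedding-sum statistics are summed over the entire batch before the EMA step, rather than averaged across batch entries, so ema_cluster_size stays on the scale of raw assignment counts instead of counts divided by the batch size. Only the einops reductions mirror the patched code; the ema_inplace body (assumed here to be the usual in-place moving average, decay * old + (1 - decay) * new), the toy shapes, and the decay value are illustrative assumptions, not taken from the repo.

# Sketch of the post-patch EMA codebook update (illustrative shapes/decay;
# ema_inplace is an assumed helper, not copied from quantizer_pytorch).
import torch
from einops import rearrange, reduce
from torch import Tensor, einsum

def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float) -> None:
    # In-place exponential moving average: decay * old + (1 - decay) * new
    moving_avg.data.mul_(decay).add_(new, alpha=1 - decay)

b, h, n, m, d = 2, 1, 16, 8, 4  # batch, heads, tokens, codes, code dim
onehot = torch.eye(m)[torch.randint(m, (b, h, n))]  # (b, h, n, m) code assignments
q = torch.randn(b, h, n, d)  # vectors routed to the codes

ema_cluster_size = torch.zeros(h, m)
ema_embedding_sum = torch.zeros(h, m, d)

# Sum statistics over the whole batch; the pre-patch code averaged the
# per-batch sums, shrinking both EMA targets by a factor of the batch size.
batch_cluster_size = reduce(onehot, "b h n m -> h m", "sum")
ema_inplace(ema_cluster_size, batch_cluster_size, decay=0.99)

batch_embedding_sum = einsum("b h n m, b h n d -> h m d", onehot, q)
ema_inplace(ema_embedding_sum, batch_embedding_sum, decay=0.99)

# Codebook = EMA embedding sum / EMA cluster size (epsilon avoids division by zero)
codebooks = ema_embedding_sum / rearrange(ema_cluster_size + 1e-5, "h m -> h m 1")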