Skip to content

Commit

Permalink
fix: remove batch averaging
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Sep 5, 2022
1 parent 3683e4b commit eb4ce7a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 9 deletions.
13 changes: 5 additions & 8 deletions quantizer_pytorch/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,20 +236,17 @@ def update_codebooks(self, q: Tensor, onehot: Tensor) -> None:
"""Update codebooks embeddings with EMA"""

# Update codebook cluster sizes with EMA
batch_cluster_size = reduce(onehot, "b h n m -> b h m", "sum")
avg_cluster_size = reduce(batch_cluster_size, "b h m -> h m", "mean")
ema_inplace(self.ema_cluster_size, avg_cluster_size, self.ema_decay)
batch_cluster_size = reduce(onehot, "b h n m -> h m", "sum")
ema_inplace(self.ema_cluster_size, batch_cluster_size, self.ema_decay)

# Update codebook embedding sums with EMA
batch_embedding_sum = einsum("b h n m, b h n d -> b h m d", onehot, q)
avg_embedding_sum = reduce(batch_embedding_sum, "b h m d -> h m d", "mean")
ema_inplace(self.ema_embedding_sum, avg_embedding_sum, self.ema_decay)
batch_embedding_sum = einsum("b h n m, b h n d -> h m d", onehot, q)
ema_inplace(self.ema_embedding_sum, batch_embedding_sum, self.ema_decay)

# Update codebook embedding by averaging vectors
new_codebooks = self.ema_embedding_sum / rearrange(
self.codebooks = self.ema_embedding_sum / rearrange(
self.ema_cluster_size + 1e-5, "h m -> h m 1" # type: ignore
)
self.codebooks = new_codebooks

def expire_dead_codes(self, x: Tensor) -> Tensor:
"""Replaces dead codes in codebook with random batch elements"""
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="quantizer-pytorch",
packages=find_packages(exclude=[]),
version="0.0.11",
version="0.0.12",
license="MIT",
description="Quantizer - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit eb4ce7a

Please sign in to comment.