From b485b017a94e065c61d9b819381fa6e97d792136 Mon Sep 17 00:00:00 2001
From: Flavio Schneider
Date: Mon, 5 Sep 2022 22:27:49 +0200
Subject: [PATCH] fix: vqe and replace vq with it

---
 quantizer_pytorch/quantizer.py | 179 ++++-----------------------------
 setup.py                       |   2 +-
 2 files changed, 23 insertions(+), 158 deletions(-)

diff --git a/quantizer_pytorch/quantizer.py b/quantizer_pytorch/quantizer.py
index 95c0544..bf23f2c 100644
--- a/quantizer_pytorch/quantizer.py
+++ b/quantizer_pytorch/quantizer.py
@@ -3,7 +3,7 @@
 import torch
 import torch.nn.functional as F
 from einops import rearrange, reduce, repeat
-from torch import LongTensor, Tensor, einsum, log, nn
+from torch import LongTensor, Tensor, einsum, nn
 from typing_extensions import TypeGuard
 
 T = TypeVar("T")
@@ -44,143 +44,13 @@ def ema_inplace(
     moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))  # type: ignore
 
 
-def l2norm(x: Tensor) -> Tensor:
-    return F.normalize(x, dim=-1)
-
-
-def distance(q: Tensor, c: Tensor) -> Tensor:
-    l2_q = reduce(q**2, "b h n d -> b h n 1", "sum")
-    l2_c = reduce(c**2, "b h m d -> b h 1 m", "sum")
-    sim = einsum("b h n d, b h m d -> b h n m", q, c)
-    return l2_q + l2_c - 2 * sim
-
-
 class VQ(Quantization):
-    """Vector Quantization Block with EMA"""
-
-    def __init__(
-        self,
-        features: int,
-        codebook_size: int,
-        temperature: float = 0.0,
-        cluster_size_expire_threshold: int = 2,
-        ema_decay: float = 0.99,
-        ema_epsilon: float = 1e-5,
-    ):
-        super().__init__()
-        self.temperature = temperature
-        self.cluster_size_expire_threshold = cluster_size_expire_threshold
-        self.ema_decay = ema_decay
-        self.ema_epsilon = ema_epsilon
-        # Embedding parameters
-        self.embedding = nn.Embedding(codebook_size, features)
-        nn.init.kaiming_uniform_(self.embedding.weight)
-        # Exponential Moving Average (EMA) parameters
-        self.register_buffer("ema_cluster_size", torch.zeros(codebook_size))
-        self.register_buffer("ema_embedding_sum", self.embedding.weight.clone())
-
-    def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
-        b, n, d = x.shape
-        # Flatten
-        q = rearrange(x, "b n d -> (b n) d")
-        # Compute quantization
-        k = self.embedding.weight
-        z, indices, onehot = self.quantize(q, k, temperature=self.temperature)
-        # Update embedding with EMA
-        if self.training:
-            self.update_embedding(q, onehot)
-            self.expire_codes(new_samples=q)
-        # Unflatten all and return
-        quantized = rearrange(z, "(b n) d -> b n d", b=b)
-        info = {
-            "loss": F.mse_loss(quantized.detach(), x),
-            "indices": rearrange(indices, "(b n) -> b 1 n", b=b),
-            "perplexity": perplexity(rearrange(onehot, "(b n) m -> b 1 n m", b=b)),
-            "ema_cluster_size": self.ema_cluster_size,
-        }
-        return quantized, info
-
-    def from_ids(self, indices: LongTensor) -> Tensor:
-        indices = rearrange(indices, "b 1 n -> b n")
-        return self.embedding(indices)
-
-    def quantize(self, q: Tensor, k: Tensor, temperature: float) -> Tuple[Tensor, ...]:
-        (_, d), (m, d_) = q.shape, k.shape
-        # Dimensionality checks
-        assert d == d_, "Expected q, k to have same number of dimensions"
-        # Compute similarity between queries and value vectors
-        similarity = -self.distances(q, k)  # [n, m]
-        # Get quatized indeces with highest similarity
-        indices = self.get_indices(similarity, temperature=temperature)  # [n]
-        # Compute hard attention matrix
-        onehot = F.one_hot(indices, num_classes=m).float()  # [n, m]
-        # Get quantized vectors
-        z = einsum("n m, m d -> n d", onehot, k)
-        # Copy gradients to input
-        z = q + (z - q).detach() if self.training else z
-        return z, indices, onehot
-
-    def get_indices(self, similarity: Tensor, temperature: float) -> Tensor:
-        if temperature == 0.0:
-            return torch.argmax(similarity, dim=1)
-        # Gumbel sample
-        noise = torch.zeros_like(similarity).uniform_(0, 1)
-        gumbel_noise = -log(-log(noise))
-        return ((similarity / temperature) + gumbel_noise).argmax(dim=1)
-
-    def distances(self, q: Tensor, k: Tensor) -> Tensor:
-        l2_q = reduce(q**2, "n d -> n 1", "sum")
-        l2_k = reduce(k**2, "m d -> m", "sum")
-        sim = einsum("n d, m d -> n m", q, k)
-        return l2_q + l2_k - 2 * sim
-
-    def update_embedding(self, q: Tensor, z_onehot: Tensor) -> None:
-        """Update codebook embeddings with EMA"""
-        # Compute batch number of hits per codebook element
-        batch_cluster_size = reduce(z_onehot, "n m -> m", "sum")
-        # Compute batch overlapped embeddings
-        batch_embedding_sum = einsum("n m, n d -> m d", z_onehot, q)
-        # Update with EMA
-        ema_inplace(self.ema_cluster_size, batch_cluster_size, self.ema_decay)  # [m]
-        ema_inplace(self.ema_embedding_sum, batch_embedding_sum, self.ema_decay)
-        # Update codebook embedding by averaging vectors
-        new_embedding = self.ema_embedding_sum / rearrange(
-            self.ema_cluster_size + 1e-5, "k -> k 1"  # type: ignore
-        )
-        self.embedding.weight.data.copy_(new_embedding)
-
-    def expire_codes(self, new_samples: Tensor) -> None:
-        """Replaces dead codes in codebook with random batch elements"""
-        if self.cluster_size_expire_threshold == 0:
-            return
-
-        # Mask is true where codes are expired
-        expired_codes = self.ema_cluster_size < self.cluster_size_expire_threshold  # type: ignore # noqa
-        num_expired_codes: int = expired_codes.sum().item()  # type: ignore
-
-        if num_expired_codes == 0:
-            return
-
-        n, device = new_samples.shape[0], new_samples.device
-
-        if n < num_expired_codes:
-            # If fewer new samples than expired codes, repeat with duplicates at random
-            indices = torch.randint(0, n, (num_expired_codes,), device=device)
-        else:
-            # If more new samples than expired codes, pick random candidates
-            indices = torch.randperm(n, device=device)[0:num_expired_codes]
-
-        # Update codebook embedding
-        self.embedding.weight.data[expired_codes] = new_samples[indices]
-
-
-class VQE(Quantization):
     def __init__(
         self,
         features: int,
         num_heads: int,
         codebook_size: int,
-        expire_threshold: int = 2,
+        expire_threshold: int = 0,
         ema_decay: float = 0.99,
     ):
         super().__init__()
@@ -194,12 +64,10 @@ def __init__(
 
         # Initialize codebook (h, m, d)
         codebooks = torch.randn(num_heads, codebook_size, self.head_features)
-        self.register_buffer("codebooks", codebooks)
+        self.codebooks = nn.Parameter(codebooks)
 
         # Track codebook cluster size to expire dead codes faster
-        ema_cluster_size = torch.full(
-            size=(num_heads, codebook_size), fill_value=float(self.expire_threshold)
-        )
+        ema_cluster_size = torch.zeros(num_heads, codebook_size)
         self.register_buffer("ema_cluster_size", ema_cluster_size)
         self.register_buffer("ema_embedding_sum", codebooks.clone())
 
@@ -209,15 +77,12 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
         q = rearrange(x, "b n (h d) -> b h n d", h=self.num_heads)
         c = repeat(self.codebooks, "h m d -> b h m d", b=b)
 
-        # q2, c2 = map(l2norm, (q, c))
-        # sim = einsum("b h i d, b h j d -> b h i j", q2, c2)  # b h n m
-        # sim = -torch.cdist(q, c, p=2.0)
-        sim = -distance(q, c)
+        sim = -torch.cdist(q, c, p=2.0)  # b h n m
         codebook_indices = sim.argmax(dim=-1)
         attn = F.one_hot(codebook_indices, num_classes=self.codebook_size).float()
-        out = einsum("b h n i, b h j d -> b h n d", attn, c)
+        out = einsum("b h n m, b h m d -> b h n d", attn, c)
         out = rearrange(out, "b h n d -> b n (h d)")
 
         out = x + (out - x).detach() if self.training else out
@@ -236,16 +101,19 @@ def update_codebooks(self, q: Tensor, onehot: Tensor) -> None:
         """Update codebooks embeddings with EMA"""
 
         # Update codebook cluster sizes with EMA
-        batch_cluster_size = reduce(onehot, "b h n m -> h m", "sum")
-        ema_inplace(self.ema_cluster_size, batch_cluster_size, self.ema_decay)
+        batch_cluster_size = reduce(onehot, "b h n m -> b h m", "sum")
+        avg_cluster_size = reduce(batch_cluster_size, "b h m -> h m", "mean")
+        ema_inplace(self.ema_cluster_size, avg_cluster_size, self.ema_decay)
 
         # Update codebook embedding sums with EMA
-        batch_embedding_sum = einsum("b h n m, b h n d -> h m d", onehot, q)
-        ema_inplace(self.ema_embedding_sum, batch_embedding_sum, self.ema_decay)
+        batch_embedding_sum = einsum("b h n m, b h n d -> b h m d", onehot, q)
+        avg_embedding_sum = reduce(batch_embedding_sum, "b h m d -> h m d", "mean")
+        ema_inplace(self.ema_embedding_sum, avg_embedding_sum, self.ema_decay)
 
         # Update codebook embedding by averaging vectors
-        self.codebooks = self.ema_embedding_sum / rearrange(
-            self.ema_cluster_size + 1e-5, "h m -> h m 1"  # type: ignore
+        self.codebooks.data.copy_(
+            self.ema_embedding_sum
+            / rearrange(self.ema_cluster_size + 1e-5, "h m -> h m 1")  # type: ignore
         )
 
     def expire_dead_codes(self, x: Tensor) -> Tensor:
@@ -280,7 +148,7 @@ def expire_dead_codes(self, x: Tensor) -> Tensor:
             head_end = head_start + self.head_features
             new_codebooks[head_idx, expired_codes] = vectors[ids, head_start:head_end]
 
-        self.codebooks = new_codebooks
+        self.codebooks.data.copy_(new_codebooks)
         return num_expired_codes_per_head
 
     def from_ids(self, indices: LongTensor) -> Tensor:
@@ -289,17 +157,17 @@ def from_ids(self, indices: LongTensor) -> Tensor:
         # Get attention matrix from indices
        attn = F.one_hot(indices, num_classes=self.codebook_size).float()
         # Compute output with codebook
-        out = einsum("b h n i, b h j d -> b h n d", attn, c)
+        out = einsum("b h n m, b h m d -> b h n d", attn, c)
         out = rearrange(out, "b h n d -> b n (h d)")
         return out
 
 
-class ResidualVQE(Quantization):
+class ResidualVQ(Quantization):
     def __init__(self, num_residuals: int, shared_codebook: bool = True, **kwargs):
         super().__init__()
         self.num_residuals = num_residuals
-        self.quantizers = nn.ModuleList([VQE(**kwargs) for _ in range(num_residuals)])
+        self.quantizers = nn.ModuleList([VQ(**kwargs) for _ in range(num_residuals)])
 
         if not shared_codebook:
             return
@@ -367,17 +235,14 @@ def __init__(
 
         quantize: Optional[Quantization] = None
         if quantizer_type == "vq":
-            assert num_groups == 1, "num_groups must be 1 with with vq quantization"
-            quantize = VQ(features=split_size, codebook_size=codebook_size, **kwargs)
-        elif quantizer_type == "vqe":
-            quantize = VQE(
+            quantize = VQ(
                 features=num_groups * split_size,
                 num_heads=num_groups,
                 codebook_size=codebook_size,
                 **kwargs
             )
-        elif quantizer_type == "rvqe":
-            quantize = ResidualVQE(
+        elif quantizer_type == "rvq":
+            quantize = ResidualVQ(
                 features=num_groups * split_size,
                 num_heads=num_groups,
                 codebook_size=codebook_size,
diff --git a/setup.py b/setup.py
index 0525ae5..440040f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
     name="quantizer-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.12",
+    version="0.0.13",
     license="MIT",
     description="Quantizer - PyTorch",
     long_description_content_type="text/markdown",