From 641ecf7a64a9c7b01e8fbb71578c42c85f07ba59 Mon Sep 17 00:00:00 2001
From: Flavio Schneider
Date: Wed, 5 Oct 2022 21:25:12 +0200
Subject: [PATCH] feat: add MQ quantizer

---
 quantizer_pytorch/quantizer.py | 192 +++++++++++++++++++++++++++++++++
 setup.py                       |   2 +-
 2 files changed, 193 insertions(+), 1 deletion(-)

diff --git a/quantizer_pytorch/quantizer.py b/quantizer_pytorch/quantizer.py
index 5b0dea9..4e4746c 100644
--- a/quantizer_pytorch/quantizer.py
+++ b/quantizer_pytorch/quantizer.py
@@ -3,6 +3,7 @@
 import torch
 import torch.nn.functional as F
 from einops import rearrange, reduce, repeat
+from einops.layers.torch import EinMix
 from torch import LongTensor, Tensor, einsum, nn
 from typing_extensions import TypeGuard
 
@@ -388,3 +389,194 @@ def forward(self, x: Tensor, **kwargs) -> Tuple[Tensor, Dict]:
 class Quantizer2d(nn.Module):
     def __init__(self):
         super().__init__()
+
+
+"""
+Experimental
+"""
+
+
+class Memcodes(nn.Module):
+    """Quantize vectors by hard attention over a learned codebook.
+
+    The input acts as queries, while keys and values are two separate
+    ``EinMix`` projections of the codebook. Each query is replaced by the
+    value of exactly one codebook entry.
+
+    Args:
+        features: Size of the last input dimension and of each codebook row.
+        codebook_size: Number of rows in the codebook.
+        temperature: ``tau`` given to ``F.gumbel_softmax`` in training mode.
+    """
+
+    def __init__(self, features: int, codebook_size: int, temperature: float):
+        super().__init__()
+        self.features = features
+        self.scale = features**-0.5
+        self.codebook_size = codebook_size
+        self.temperature = temperature
+
+        self.codebook = nn.Parameter(torch.randn(codebook_size, features))
+        # Separate learned projections map the codebook to keys and values
+        self.to_k = EinMix(
+            pattern="n d -> n c",
+            weight_shape="d c",
+            d=features,
+            c=features,
+        )
+        self.to_v = EinMix(
+            pattern="n d -> n c",
+            weight_shape="d c",
+            d=features,
+            c=features,
+        )
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
+        """Quantize ``x`` and report which codebook entries were selected.
+
+        Args:
+            x: Tensor laid out as ``b n d`` with ``d == features``.
+
+        Returns:
+            The quantized tensor, laid out like ``x``, and a dict holding the
+            selected ``"indices"`` (``b n``) and their ``"perplexity"``.
+        """
+        assert x.shape[-1] == self.features
+
+        q = x * self.scale
+        # Compute keys/values of codebook
+        k, v = self.to_k(self.codebook), self.to_v(self.codebook)
+        # Logits matrix between codebook and input queries
+        logits = einsum("b i d, j d -> b i j", q, k)  # b, n, s
+
+        if self.training:
+            # Attention matrix with hard stochastic (differentiable) argmax
+            attn = F.gumbel_softmax(logits, tau=self.temperature, dim=-1, hard=True)
+            codebook_indices = attn.argmax(dim=-1)
+        else:
+            # Attention matrix with hard deterministic argmax
+            codebook_indices = logits.argmax(dim=-1)
+            attn = F.one_hot(codebook_indices, num_classes=self.codebook_size).float()
+
+        # The eval one-hot is always float32, so cast it to the dtype of the
+        # values: einsum rejects mixed dtypes (half/bfloat16/double modules)
+        out = einsum("b i j, j d -> b i d", attn.to(v.dtype), v)
+
+        info = {"indices": codebook_indices, "perplexity": self.perplexity(attn)}
+        return out, info
+
+    def from_ids(self, indices: LongTensor) -> Tensor:
+        """Return the codebook values selected by ``indices``.
+
+        Args:
+            indices: Integer tensor of codebook row indices, e.g. ``b n``.
+
+        Returns:
+            Tensor shaped like ``indices`` with a trailing ``features`` axis.
+        """
+        # Compute values from codebook
+        v = self.to_v(self.codebook)
+        # Row lookup; avoids copying the value table once per batch element
+        return F.embedding(indices, v)
+
+    def perplexity(self, onehot: Tensor, eps: float = 1e-10) -> Tensor:
+        """Return exp of the entropy of the average code usage in ``onehot``.
+
+        Args:
+            onehot: Assignment tensor laid out as ``b n s``.
+            eps: Added inside the logarithm so unused codes give a finite log.
+
+        Returns:
+            Tensor of shape ``(1,)``.
+        """
+        mean = reduce(onehot, "b n s -> s", "mean")
+        return torch.exp(-reduce(mean * torch.log(mean + eps), "s -> 1", "sum"))
+
+
+class MQ(nn.Module):
+    """Sum the outputs of ``num_overlaps`` independent ``Memcodes`` modules.
+
+    Args:
+        features: Size of the last input dimension.
+        codebook_size: Number of rows in each codebook.
+        num_overlaps: Number of ``Memcodes`` modules whose outputs are summed.
+        temperature: Gumbel-softmax temperature passed to every module.
+    """
+
+    def __init__(
+        self, features: int, codebook_size: int, num_overlaps: int, temperature: float
+    ):
+        super().__init__()
+        self.num_overlaps = num_overlaps
+
+        self.memcodes = nn.ModuleList(
+            [
+                Memcodes(
+                    features=features,
+                    codebook_size=codebook_size,
+                    temperature=temperature,
+                )
+                for _ in range(num_overlaps)
+            ]
+        )
+
+    def from_ids(self, indices: LongTensor) -> Tensor:
+        """Rebuild the summed output from indices laid out as ``b n o``."""
+        indices = rearrange(indices, "b n o -> o b n")
+        return sum(
+            memcode.from_ids(indices[i]) for i, memcode in enumerate(self.memcodes)
+        )
+
+    def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
+        """Quantize ``x`` with every module and sum the results.
+
+        Returns:
+            The summed tensor (``b n d``) and a dict with ``"indices"``
+            (``b n o``) and ``"perplexity"`` (a list with one tensor per module).
+        """
+        embeddings, indices, perplexities = [], [], []
+
+        for memcode in self.memcodes:
+            embedding, info = memcode(x)
+            embeddings.append(embedding)
+            indices.append(info["indices"])
+            perplexities.append(info["perplexity"])
+
+        out = reduce(torch.stack(embeddings), "o b n d -> b n d", "sum")
+        info = dict(
+            indices=rearrange(indices, "o b n -> b n o"), perplexity=perplexities
+        )
+
+        return out, info
+
+
+class MQ1d(nn.Module):
+    """Channels-first wrapper around ``MQ`` for ``b c t`` tensors."""
+
+    def __init__(
+        self, channels: int, codebook_size: int, num_overlaps: int, temperature: float
+    ):
+        super().__init__()
+        self.num_overlaps = num_overlaps
+
+        self.quantize = MQ(
+            features=channels,
+            codebook_size=codebook_size,
+            num_overlaps=num_overlaps,
+            temperature=temperature,
+        )
+
+    def from_ids(self, indices: LongTensor, **kwargs) -> Tensor:
+        """Rebuild a ``b c t`` tensor from indices laid out as ``b 1 n o``."""
+        indices = rearrange(indices, "b 1 n o -> b n o")
+        x = self.quantize.from_ids(indices, **kwargs)
+        return rearrange(x, "b t c -> b c t")
+
+    def forward(self, x: Tensor, **kwargs) -> Tuple[Tensor, Dict]:
+        """Quantize a ``b c t`` tensor; indices are returned as ``b 1 n o``."""
+        x = rearrange(x, "b c t -> b t c")
+        x, info = self.quantize(x, **kwargs)
+        x = rearrange(x, "b t c -> b c t")
+        # Add the singleton axis that from_ids expects after the batch axis
+        info["indices"] = rearrange(info["indices"], "b n o -> b 1 n o")
+        return x, info
diff --git a/setup.py b/setup.py
index 5ee1df9..ec8ba93 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
     name="quantizer-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.19",
+    version="0.0.20",
     license="MIT",
     description="Quantizer - PyTorch",
     long_description_content_type="text/markdown",