feat: add MQ quantizer
flavioschneider committed Oct 5, 2022
1 parent 4524b38 commit 641ecf7
Showing 2 changed files with 137 additions and 1 deletion.
136 changes: 136 additions & 0 deletions quantizer_pytorch/quantizer.py
@@ -3,6 +3,7 @@
import torch
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from einops.layers.torch import EinMix
from torch import LongTensor, Tensor, einsum, nn
from typing_extensions import TypeGuard

@@ -388,3 +389,138 @@ def forward(self, x: Tensor, **kwargs) -> Tuple[Tensor, Dict]:
class Quantizer2d(nn.Module):
def __init__(self):
super().__init__()


"""
Experimental
"""


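# Quantizes inputs by attending over a learned codebook: each input vector acts as a
# query against key/value projections of the codebook ("memcodes"-style lookup).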
class Memcodes(nn.Module):
def __init__(self, features: int, codebook_size: int, temperature: float):
super().__init__()
self.features = features
self.scale = features**-0.5
self.codebook_size = codebook_size
self.temperature = temperature

self.codebook = nn.Parameter(torch.randn(codebook_size, features))
# Different linear projection for each key/value head
self.to_k = EinMix(
pattern="n d -> n c",
weight_shape="d c",
d=features,
c=features,
)
self.to_v = EinMix(
pattern="n d -> n c",
weight_shape="d c",
d=features,
c=features,
)

def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
assert x.shape[-1] == self.features

q = x * self.scale
# Compute keys/values of codebook
k, v = self.to_k(self.codebook), self.to_v(self.codebook)
# Logits matrix between codebook and input queries
logits = einsum("b i d, j d -> b i j", q, k) # b, n, s

if self.training:
# Attention matrix with hard stochastic (differentiable) argmax
attn = F.gumbel_softmax(logits, tau=self.temperature, dim=-1, hard=True)
codebook_indices = attn.argmax(dim=-1)
else:
# Attention matrix with hard deterministic argmax
codebook_indices = logits.argmax(dim=-1)
attn = F.one_hot(codebook_indices, num_classes=self.codebook_size).float()

out = einsum("b i j, j d -> b i d", attn, v)

info = {"indices": codebook_indices, "perplexity": self.perplexity(attn)}
return out, info

def from_ids(self, indices: LongTensor) -> Tensor:
b = indices.shape[0]
# Compute values from codebook
v = repeat(self.to_v(self.codebook), "n d -> b n d", b=b)
# Repeat indices d times
indices = repeat(indices, "... -> ... d", d=v.shape[-1])
# Gather values on indices last dim
out = v.gather(dim=1, index=indices)
return out

def perplexity(self, onehot: Tensor, eps: float = 1e-10) -> Tensor:
mean = reduce(onehot, "b n s -> s", "mean")
return torch.exp(-reduce(mean * torch.log(mean + eps), "s -> 1", "sum"))


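# Applies num_overlaps independent Memcodes quantizers to the same input and sums
# their outputs, so each position is described by several codebook indices.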
class MQ(nn.Module):
def __init__(
self, features: int, codebook_size: int, num_overlaps: int, temperature: float
):
super().__init__()
self.num_overlaps = num_overlaps

self.memcodes = nn.ModuleList(
[
Memcodes(
features=features,
codebook_size=codebook_size,
temperature=temperature,
)
for _ in range(num_overlaps)
]
)

def from_ids(self, indices: LongTensor) -> Tensor:
o = self.num_overlaps
indices = rearrange(indices, "b n o -> o b n")
out = sum([self.memcodes[i].from_ids(indices[i]) for i in range(o)])
return out

def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
embeddings, indices, perplexities = [], [], []

for i in range(self.num_overlaps):
embedding, info = self.memcodes[i](x)
embeddings += [embedding]
indices += [info["indices"]]
perplexities += [info["perplexity"]]

out = reduce(torch.stack(embeddings), "o b n d -> b n d", "sum")
info = dict(
indices=rearrange(indices, "o b n -> b n o"), perplexity=perplexities
)

return out, info


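# Channels-first wrapper around MQ for 1d data shaped [batch, channels, time].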
class MQ1d(nn.Module):
def __init__(
self, channels: int, codebook_size: int, num_overlaps: int, temperature: float
):
super().__init__()
self.num_overlaps = num_overlaps

self.quantize = MQ(
features=channels,
codebook_size=codebook_size,
num_overlaps=num_overlaps,
temperature=temperature,
)

def from_ids(self, indices: LongTensor, **kwargs) -> Tensor:
indices = rearrange(indices, "b 1 n o -> b n o")
x = self.quantize.from_ids(indices, **kwargs)
return rearrange(x, "b t c -> b c t")

def forward(self, x: Tensor, **kwargs) -> Tuple[Tensor, Dict]:
x = rearrange(x, "b c t -> b t c")
x, info = self.quantize(x, **kwargs)
x = rearrange(x, "b t c -> b c t")
# Rearrange indices to expose residual
info["indices"] = rearrange(info["indices"], "b n o -> b 1 n o")
return x, info
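
For reference, a minimal usage sketch of the new MQ1d module (illustrative only, not part of the commit; shapes and hyperparameters are arbitrary, and the import path follows the file location shown above):

import torch
from quantizer_pytorch.quantizer import MQ1d

quantizer = MQ1d(channels=32, codebook_size=512, num_overlaps=4, temperature=1.0)
quantizer.eval()  # deterministic argmax; training mode samples with Gumbel-softmax

x = torch.randn(1, 32, 80)          # [batch, channels, time]
y, info = quantizer(x)              # y: [1, 32, 80]
print(info["indices"].shape)        # torch.Size([1, 1, 80, 4]), one index per overlap
print(len(info["perplexity"]))      # 4, one perplexity value per overlap

# In eval mode, decoding the stored indices reproduces the quantized output
assert torch.allclose(y, quantizer.from_ids(info["indices"]))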
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name="quantizer-pytorch",
packages=find_packages(exclude=[]),
version="0.0.19",
version="0.0.20",
license="MIT",
description="Quantizer - PyTorch",
long_description_content_type="text/markdown",
