Commit

fix: vqe and replace vq with it
flavioschneider committed Sep 5, 2022
1 parent eb4ce7a commit b485b01
Showing 2 changed files with 23 additions and 158 deletions.
179 changes: 22 additions & 157 deletions quantizer_pytorch/quantizer.py
@@ -3,7 +3,7 @@
import torch
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch import LongTensor, Tensor, einsum, log, nn
from torch import LongTensor, Tensor, einsum, nn
from typing_extensions import TypeGuard

T = TypeVar("T")
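
Note: the log import dropped above was only needed by the Gumbel sampling in the removed VQ.get_indices further down. The trick is that adding Gumbel(0, 1) noise to the scaled similarities turns a plain argmax into a sample from the corresponding softmax distribution (Gumbel-max). A standalone sketch of that sampling step, with a small clamp added here to guard against log(0):

import torch

def gumbel_argmax(similarity: torch.Tensor, temperature: float) -> torch.Tensor:
    # temperature == 0 falls back to the deterministic argmax the new code uses
    if temperature == 0.0:
        return similarity.argmax(dim=-1)
    noise = torch.rand_like(similarity).clamp(min=1e-20)
    gumbel = -torch.log(-torch.log(noise))  # Gumbel(0, 1) samples
    return (similarity / temperature + gumbel).argmax(dim=-1)

scores = torch.randn(10, 512)            # (queries, codebook_size), toy values
print(gumbel_argmax(scores, 0.5).shape)  # torch.Size([10])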
@@ -44,143 +44,13 @@ def ema_inplace(
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) # type: ignore


def l2norm(x: Tensor) -> Tensor:
return F.normalize(x, dim=-1)


def distance(q: Tensor, c: Tensor) -> Tensor:
l2_q = reduce(q**2, "b h n d -> b h n 1", "sum")
l2_c = reduce(c**2, "b h m d -> b h 1 m", "sum")
sim = einsum("b h n d, b h m d -> b h n m", q, c)
return l2_q + l2_c - 2 * sim
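
Note: the distance helper above computes squared Euclidean distances via the expansion ||q - c||^2 = ||q||^2 + ||c||^2 - 2 q·c. The new forward pass below uses torch.cdist(q, c, p=2.0) instead, which returns unsquared distances, but the nearest codebook entry is the same either way because the square root is monotonic. A minimal check with toy shapes:

import torch
from einops import reduce
from torch import einsum

def squared_distance(q, c):
    # same expansion as the removed helper
    l2_q = reduce(q**2, "b h n d -> b h n 1", "sum")
    l2_c = reduce(c**2, "b h m d -> b h 1 m", "sum")
    return l2_q + l2_c - 2 * einsum("b h n d, b h m d -> b h n m", q, c)

q = torch.randn(2, 4, 16, 8)  # (batch, heads, tokens, head_features), arbitrary sizes
c = torch.randn(2, 4, 32, 8)  # (batch, heads, codebook_size, head_features)
same = squared_distance(q, c).argmin(dim=-1) == torch.cdist(q, c, p=2.0).argmin(dim=-1)
print(same.all())             # the nearest codes agree under both formulations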


class VQ(Quantization):
"""Vector Quantization Block with EMA"""

def __init__(
self,
features: int,
codebook_size: int,
temperature: float = 0.0,
cluster_size_expire_threshold: int = 2,
ema_decay: float = 0.99,
ema_epsilon: float = 1e-5,
):
super().__init__()
self.temperature = temperature
self.cluster_size_expire_threshold = cluster_size_expire_threshold
self.ema_decay = ema_decay
self.ema_epsilon = ema_epsilon
# Embedding parameters
self.embedding = nn.Embedding(codebook_size, features)
nn.init.kaiming_uniform_(self.embedding.weight)
# Exponential Moving Average (EMA) parameters
self.register_buffer("ema_cluster_size", torch.zeros(codebook_size))
self.register_buffer("ema_embedding_sum", self.embedding.weight.clone())

def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
b, n, d = x.shape
# Flatten
q = rearrange(x, "b n d -> (b n) d")
# Compute quantization
k = self.embedding.weight
z, indices, onehot = self.quantize(q, k, temperature=self.temperature)
# Update embedding with EMA
if self.training:
self.update_embedding(q, onehot)
self.expire_codes(new_samples=q)
# Unflatten all and return
quantized = rearrange(z, "(b n) d -> b n d", b=b)
info = {
"loss": F.mse_loss(quantized.detach(), x),
"indices": rearrange(indices, "(b n) -> b 1 n", b=b),
"perplexity": perplexity(rearrange(onehot, "(b n) m -> b 1 n m", b=b)),
"ema_cluster_size": self.ema_cluster_size,
}
return quantized, info

def from_ids(self, indices: LongTensor) -> Tensor:
indices = rearrange(indices, "b 1 n -> b n")
return self.embedding(indices)

def quantize(self, q: Tensor, k: Tensor, temperature: float) -> Tuple[Tensor, ...]:
(_, d), (m, d_) = q.shape, k.shape
# Dimensionality checks
assert d == d_, "Expected q, k to have same number of dimensions"
# Compute similarity between queries and value vectors
similarity = -self.distances(q, k) # [n, m]
# Get quantized indices with highest similarity
indices = self.get_indices(similarity, temperature=temperature) # [n]
# Compute hard attention matrix
onehot = F.one_hot(indices, num_classes=m).float() # [n, m]
# Get quantized vectors
z = einsum("n m, m d -> n d", onehot, k)
# Copy gradients to input
z = q + (z - q).detach() if self.training else z
return z, indices, onehot

def get_indices(self, similarity: Tensor, temperature: float) -> Tensor:
if temperature == 0.0:
return torch.argmax(similarity, dim=1)
# Gumbel sample
noise = torch.zeros_like(similarity).uniform_(0, 1)
gumbel_noise = -log(-log(noise))
return ((similarity / temperature) + gumbel_noise).argmax(dim=1)

def distances(self, q: Tensor, k: Tensor) -> Tensor:
l2_q = reduce(q**2, "n d -> n 1", "sum")
l2_k = reduce(k**2, "m d -> m", "sum")
sim = einsum("n d, m d -> n m", q, k)
return l2_q + l2_k - 2 * sim

def update_embedding(self, q: Tensor, z_onehot: Tensor) -> None:
"""Update codebook embeddings with EMA"""
# Compute batch number of hits per codebook element
batch_cluster_size = reduce(z_onehot, "n m -> m", "sum")
# Compute batch overlapped embeddings
batch_embedding_sum = einsum("n m, n d -> m d", z_onehot, q)
# Update with EMA
ema_inplace(self.ema_cluster_size, batch_cluster_size, self.ema_decay) # [m]
ema_inplace(self.ema_embedding_sum, batch_embedding_sum, self.ema_decay)
# Update codebook embedding by averaging vectors
new_embedding = self.ema_embedding_sum / rearrange(
self.ema_cluster_size + 1e-5, "k -> k 1" # type: ignore
)
self.embedding.weight.data.copy_(new_embedding)

def expire_codes(self, new_samples: Tensor) -> None:
"""Replaces dead codes in codebook with random batch elements"""
if self.cluster_size_expire_threshold == 0:
return

# Mask is true where codes are expired
expired_codes = self.ema_cluster_size < self.cluster_size_expire_threshold # type: ignore # noqa
num_expired_codes: int = expired_codes.sum().item() # type: ignore

if num_expired_codes == 0:
return

n, device = new_samples.shape[0], new_samples.device

if n < num_expired_codes:
# If fewer new samples than expired codes, repeat with duplicates at random
indices = torch.randint(0, n, (num_expired_codes,), device=device)
else:
# If more new samples than expired codes, pick random candidates
indices = torch.randperm(n, device=device)[0:num_expired_codes]

# Update codebook embedding
self.embedding.weight.data[expired_codes] = new_samples[indices]


class VQE(Quantization):
def __init__(
self,
features: int,
num_heads: int,
codebook_size: int,
expire_threshold: int = 2,
expire_threshold: int = 0,
ema_decay: float = 0.99,
):
super().__init__()
@@ -194,12 +64,10 @@ def __init__(

# Initialize codebook (h, m, d)
codebooks = torch.randn(num_heads, codebook_size, self.head_features)
self.register_buffer("codebooks", codebooks)
self.codebooks = nn.Parameter(codebooks)

# Track codebook cluster size to expire dead codes faster
ema_cluster_size = torch.full(
size=(num_heads, codebook_size), fill_value=float(self.expire_threshold)
)
ema_cluster_size = torch.zeros(num_heads, codebook_size)
self.register_buffer("ema_cluster_size", ema_cluster_size)
self.register_buffer("ema_embedding_sum", codebooks.clone())

@@ -209,15 +77,12 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
q = rearrange(x, "b n (h d) -> b h n d", h=self.num_heads)
c = repeat(self.codebooks, "h m d -> b h m d", b=b)

# q2, c2 = map(l2norm, (q, c))
# sim = einsum("b h i d, b h j d -> b h i j", q2, c2) # b h n m
# sim = -torch.cdist(q, c, p=2.0)
sim = -distance(q, c)
sim = -torch.cdist(q, c, p=2.0) # b h n m

codebook_indices = sim.argmax(dim=-1)
attn = F.one_hot(codebook_indices, num_classes=self.codebook_size).float()

out = einsum("b h n i, b h j d -> b h n d", attn, c)
out = einsum("b h n m, b h m d -> b h n d", attn, c)
out = rearrange(out, "b h n d -> b n (h d)")
out = x + (out - x).detach() if self.training else out
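
Note: out = x + (out - x).detach() is the straight-through estimator: the forward value is the quantized output, while the backward pass treats the whole operation as the identity, so gradients reach x unchanged. A tiny illustration with rounding standing in for the codebook lookup:

import torch

x = torch.randn(3, requires_grad=True)
quantized = torch.round(x)          # stand-in for the nearest-code lookup
out = x + (quantized - x).detach()  # forward: quantized values; backward: identity in x
out.sum().backward()
print(out)     # same values as `quantized`
print(x.grad)  # all ones: the gradient passed straight through the quantization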

@@ -236,16 +101,19 @@ def update_codebooks(self, q: Tensor, onehot: Tensor) -> None:
"""Update codebooks embeddings with EMA"""

# Update codebook cluster sizes with EMA
batch_cluster_size = reduce(onehot, "b h n m -> h m", "sum")
ema_inplace(self.ema_cluster_size, batch_cluster_size, self.ema_decay)
batch_cluster_size = reduce(onehot, "b h n m -> b h m", "sum")
avg_cluster_size = reduce(batch_cluster_size, "b h m -> h m", "mean")
ema_inplace(self.ema_cluster_size, avg_cluster_size, self.ema_decay)

# Update codebook embedding sums with EMA
batch_embedding_sum = einsum("b h n m, b h n d -> h m d", onehot, q)
ema_inplace(self.ema_embedding_sum, batch_embedding_sum, self.ema_decay)
batch_embedding_sum = einsum("b h n m, b h n d -> b h m d", onehot, q)
avg_embedding_sum = reduce(batch_embedding_sum, "b h m d -> h m d", "mean")
ema_inplace(self.ema_embedding_sum, avg_embedding_sum, self.ema_decay)

# Update codebook embedding by averaging vectors
self.codebooks = self.ema_embedding_sum / rearrange(
self.ema_cluster_size + 1e-5, "h m -> h m 1" # type: ignore
self.codebooks.data.copy_(
self.ema_embedding_sum
/ rearrange(self.ema_cluster_size + 1e-5, "h m -> h m 1") # type: ignore
)
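
Note: ema_inplace, kept near the top of this file, implements the usual exponential moving average ema <- decay * ema + (1 - decay) * value, and the codebooks are then re-estimated as ema_embedding_sum / (ema_cluster_size + 1e-5). The update rule in isolation, with made-up numbers:

import torch

decay = 0.99
ema_cluster_size = torch.zeros(4)  # per-code hit counts for one head (toy size)
batch_cluster_size = torch.tensor([5.0, 0.0, 2.0, 1.0])

# same rule as ema_inplace: ema <- decay * ema + (1 - decay) * value
ema_cluster_size.mul_(decay).add_(batch_cluster_size, alpha=1 - decay)
print(ema_cluster_size)            # tensor([0.0500, 0.0000, 0.0200, 0.0100])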

def expire_dead_codes(self, x: Tensor) -> Tensor:
@@ -280,7 +148,7 @@ def expire_dead_codes(self, x: Tensor) -> Tensor:
head_end = head_start + self.head_features
new_codebooks[head_idx, expired_codes] = vectors[ids, head_start:head_end]

self.codebooks = new_codebooks
self.codebooks.data.copy_(new_codebooks)
return num_expired_codes_per_head
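
Note: codebooks is now an nn.Parameter rather than a registered buffer, so the expiry code writes into it with .data.copy_ instead of rebinding the attribute: nn.Module refuses to replace a registered Parameter with a plain tensor, and the in-place copy keeps the same object registered with the module. Illustrated on a stock nn.Linear (nothing from this repo):

import torch
from torch import nn

layer = nn.Linear(4, 4)
layer.weight.data.copy_(torch.zeros(4, 4))  # in-place: the registered Parameter is preserved
print(type(layer.weight))                   # <class 'torch.nn.parameter.Parameter'>

try:
    layer.weight = torch.zeros(4, 4)        # rebinding with a plain tensor...
except TypeError as err:
    print(err)                              # ...is rejected (torch.nn.Parameter or None expected)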

def from_ids(self, indices: LongTensor) -> Tensor:
@@ -289,17 +157,17 @@ def from_ids(self, indices: LongTensor) -> Tensor:
# Get attention matrix from indices
attn = F.one_hot(indices, num_classes=self.codebook_size).float()
# Compute output with codebook
out = einsum("b h n i, b h j d -> b h n d", attn, c)
out = einsum("b h n m, b h m d -> b h n d", attn, c)
out = rearrange(out, "b h n d -> b n (h d)")
return out


class ResidualVQE(Quantization):
class ResidualVQ(Quantization):
def __init__(self, num_residuals: int, shared_codebook: bool = True, **kwargs):
super().__init__()
self.num_residuals = num_residuals

self.quantizers = nn.ModuleList([VQE(**kwargs) for _ in range(num_residuals)])
self.quantizers = nn.ModuleList([VQ(**kwargs) for _ in range(num_residuals)])

if not shared_codebook:
return
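
Note: ResidualVQ (renamed from ResidualVQE) stacks num_residuals quantizers. The residual loop itself sits outside the shown hunks; in the usual scheme each stage quantizes whatever the previous stages left unexplained and the stage outputs are summed. A toy sketch of that scheme with an untrained nearest-neighbour codebook (illustrative only, no accuracy claims):

import torch

codebook = torch.randn(32, 8)  # (codebook_size, features), untrained toy codes

def vq_stage(x):
    # one stage: snap each row of x to its nearest codebook vector
    indices = torch.cdist(x, codebook).argmin(dim=-1)
    return codebook[indices]

x = torch.randn(4, 8)
residual, reconstruction = x, torch.zeros_like(x)
for _ in range(3):                         # num_residuals stages
    q = vq_stage(residual)
    reconstruction = reconstruction + q
    residual = residual - q                # the next stage models what is still missing
print(((x - reconstruction) ** 2).mean())  # remaining reconstruction error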
@@ -367,17 +235,14 @@ def __init__(
quantize: Optional[Quantization] = None

if quantizer_type == "vq":
assert num_groups == 1, "num_groups must be 1 with vq quantization"
quantize = VQ(features=split_size, codebook_size=codebook_size, **kwargs)
elif quantizer_type == "vqe":
quantize = VQE(
quantize = VQ(
features=num_groups * split_size,
num_heads=num_groups,
codebook_size=codebook_size,
**kwargs
)
elif quantizer_type == "rvqe":
quantize = ResidualVQE(
elif quantizer_type == "rvq":
quantize = ResidualVQ(
features=num_groups * split_size,
num_heads=num_groups,
codebook_size=codebook_size,
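
For context, a hypothetical usage sketch of the multi-head quantizer after this commit, assuming the class is exposed as VQ (as the ResidualVQ and Quantizer1d changes above suggest), that it is imported from quantizer_pytorch/quantizer.py as shown, and that features must split evenly across num_heads:

import torch
from quantizer_pytorch.quantizer import VQ  # module path as in this diff; export name assumed

quantizer = VQ(
    features=64,       # head_features = features // num_heads = 16
    num_heads=4,
    codebook_size=256,
)

x = torch.randn(1, 128, 64)     # (batch, sequence length, features)
quantized, info = quantizer(x)  # quantized tensor plus a dict of training diagnostics
print(quantized.shape)          # torch.Size([1, 128, 64])
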
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name="quantizer-pytorch",
packages=find_packages(exclude=[]),
version="0.0.12",
version="0.0.13",
license="MIT",
description="Quantizer - PyTorch",
long_description_content_type="text/markdown",
