feat: add budgeted vector quantizer

flavioschneider committed Sep 12, 2022
1 parent 9ac691d commit ca33ad2

Showing 3 changed files with 70 additions and 31 deletions.
2 changes: 1 addition & 1 deletion quantizer_pytorch/__init__.py

@@ -1 +1 @@
-from .quantizer import VQ, Quantizer1d, QuantizerChannelwise1d, ResidualVQ
+from .quantizer import BVQ, Quantizer1d, QuantizerChannelwise1d, ResidualVQ
97 changes: 68 additions & 29 deletions quantizer_pytorch/quantizer.py

@@ -38,19 +38,44 @@ def perplexity(onehot: Tensor, eps: float = 1e-10) -> Tensor:
     return torch.exp(-reduce(mean * torch.log(mean + eps), "h s -> h", "sum"))
 
 
-def ema_inplace(
-    moving_avg: Union[Tensor, nn.Module], new: Tensor, decay: float
-) -> None:
-    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))  # type: ignore
+def ema(moving_avg: Union[Tensor, nn.Module], new: Tensor, decay: float) -> Tensor:
+    return moving_avg * decay + new * (1 - decay)  # type: ignore
+
+
+def update_inplace(
+    old: Union[Tensor, nn.Module],
+    new: Tensor,
+) -> None:
+    old.data.copy_(new)  # type: ignore
 
 
-class VQ(Quantization):
+class BVQ(Quantization):
+    """
+    Budgeted Vector Quantization
+
+    Features:
+    [x] EMA update
+    [x] Multiheaded codebook
+    [x] Expiration invariant to the number of tokens, batch size, and codebook size.
+    [x] Budgeted random replacement
+
+    The total budget is always equivalent to the codebook size `m`, and each codebook
+    element starts with a budget of 1. The budget is slowly redistributed according to
+    the distribution of the `n` incoming tokens with respect to the codebook. If a
+    codebook element is matched by many incoming vectors, its budget will increase.
+    The codebook vectors will be updated by averaging the matching incoming vectors.
+    If a codebook element's budget goes below the expire threshold, the element
+    undergoes a hard replacement with a random vector from an incoming batch, and its
+    budget is reset to 1. The total budget is then renormalized, at the expense of
+    other codebook elements.
+    """
+
     def __init__(
         self,
         features: int,
         num_heads: int,
         codebook_size: int,
-        expire_threshold: int = 0,
+        expire_threshold: float = 0.05,
         ema_decay: float = 0.99,
     ):
         super().__init__()
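Aside: the refactor above splits the old in-place `ema_inplace` into a functional `ema` plus an explicit `update_inplace`. A minimal standalone sketch of the equivalence (tensor values are illustrative, not from the repo):

    import torch

    buf = torch.zeros(3)
    new = torch.ones(3)
    decay = 0.99

    # what ema(buf, new, decay) returns
    smoothed = buf * decay + new * (1 - decay)
    # what update_inplace(buf, smoothed) does
    buf.data.copy_(smoothed)
    print(buf)  # tensor([0.0100, 0.0100, 0.0100])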
@@ -66,10 +91,9 @@ def __init__(
         codebooks = torch.randn(num_heads, codebook_size, self.head_features)
         self.codebooks = nn.Parameter(codebooks)
 
-        # Track codebook cluster size to expire dead codes faster
-        ema_cluster_size = torch.zeros(num_heads, codebook_size)
-        self.register_buffer("ema_cluster_size", ema_cluster_size)
-        self.register_buffer("ema_embedding_sum", codebooks.clone())
+        # Each element starts with budget=1; if it goes below the threshold, it is replaced
+        ema_budget = torch.ones(num_heads, codebook_size)
+        self.register_buffer("budget_ema", ema_budget)
 
     def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
         b = x.shape[0]
@@ -99,29 +123,33 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
 
     def update_codebooks(self, q: Tensor, onehot: Tensor) -> None:
         """Update codebook embeddings with EMA"""
+        b, n, m = q.shape[0], q.shape[2], self.codebook_size
 
-        # Update codebook cluster sizes with EMA
-        batch_cluster_size = reduce(onehot, "b h n m -> b h m", "sum")
-        avg_cluster_size = reduce(batch_cluster_size, "b h m -> h m", "mean")
-        ema_inplace(self.ema_cluster_size, avg_cluster_size, self.ema_decay)
+        # Compute incoming total hits and avg embedding sum
+        tot_incoming = reduce(onehot, "b h n m -> h m 1", "sum")
+        sum_incoming = einsum("b h n m, b h n d -> h m d", onehot, q)
+        avg_incoming = sum_incoming / (tot_incoming + 1e-5)
 
-        # Update codebook embedding sums with EMA
-        batch_embedding_sum = einsum("b h n m, b h n d -> b h m d", onehot, q)
-        avg_embedding_sum = reduce(batch_embedding_sum, "b h m d -> h m d", "mean")
-        ema_inplace(self.ema_embedding_sum, avg_embedding_sum, self.ema_decay)
+        # Mask for codebook elements that have not been hit by any vector
+        mask = tot_incoming.bool()
 
-        # Update codebook embedding by averaging vectors
-        self.codebooks.data.copy_(
-            self.ema_embedding_sum
-            / rearrange(self.ema_cluster_size + 1e-5, "h m -> h m 1")  # type: ignore
-        )
+        # Update codebook with EMA
+        codebooks_new = torch.where(mask, avg_incoming, self.codebooks)
+        codebooks_ema = ema(self.codebooks, codebooks_new, self.ema_decay)
+        update_inplace(self.codebooks, codebooks_ema)
+
+        # Compute budgets, update with EMA, renormalize such that the total budget is m
+        budget = (rearrange(tot_incoming, "h m 1 -> h m") / (b * n)) * m
+        budget_ema = ema(self.budget_ema, budget, self.ema_decay)
+        budget_ema_norm = (budget_ema / reduce(budget_ema, "h m -> h 1", "sum")) * m
+        update_inplace(self.budget_ema, budget_ema_norm)
 
     def expire_dead_codes(self, x: Tensor) -> Tensor:
         """Replaces dead codes in codebook with random batch elements"""
        is_disabled = self.expire_threshold <= 0
 
         # Mask is true where codes are expired
-        expired_codes_per_head = self.ema_cluster_size < self.expire_threshold  # type: ignore # noqa
+        expired_codes_per_head = self.budget_ema < self.expire_threshold  # type: ignore # noqa
         num_expired_codes_per_head = reduce(expired_codes_per_head, "h m -> h", "sum")
         no_expired = torch.all(num_expired_codes_per_head == 0)

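To make the budget arithmetic above concrete, a small standalone check (one head, made-up hit counts): a code matched by 6 of b*n = 8 tokens gets a pre-EMA budget of (6 / 8) * 4 = 3, and after the EMA step the renormalization keeps the per-head budgets summing to exactly m.

    import torch

    b_times_n, m = 8, 4
    hits = torch.tensor([[6.0, 2.0, 0.0, 0.0]])   # "h m" hit counts
    budget = (hits / b_times_n) * m               # tensor([[3., 1., 0., 0.]])

    decay = 0.99
    budget_ema = torch.ones(1, m) * decay + budget * (1 - decay)
    budget_norm = budget_ema / budget_ema.sum(dim=-1, keepdim=True) * m
    print(budget_norm.sum())                      # tensor(4.) == m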
@@ -131,8 +159,9 @@ def expire_dead_codes(self, x: Tensor) -> Tensor:
 
         # Candidate vectors for codebook replacement
         vectors = rearrange(x, "b h d -> (b h) d")
-        n, device = vectors.shape[0], x.device
-        new_codebooks = self.codebooks.data
+        n, m, device = vectors.shape[0], self.codebook_size, x.device
+        codebooks_new = self.codebooks.data
+        budget_new = self.budget_ema
 
         for head_idx in range(self.num_heads):
             num_expired_codes = num_expired_codes_per_head[head_idx]
@@ -146,9 +175,19 @@ def expire_dead_codes(self, x: Tensor) -> Tensor:
             # Update codebook head
             head_start = head_idx * self.head_features
             head_end = head_start + self.head_features
-            new_codebooks[head_idx, expired_codes] = vectors[ids, head_start:head_end]
+            codebooks_new[head_idx, expired_codes] = vectors[ids, head_start:head_end]
+            # Update budget head
+            budget_new[head_idx] = torch.where(  # type: ignore
+                expired_codes, torch.ones(m, device=device), budget_new[head_idx]  # type: ignore # noqa
+            )
+
+        # Update codebook
+        update_inplace(self.codebooks, codebooks_new)
+
+        # Normalize and update budget
+        budget_new_norm = (budget_new / reduce(budget_new, "h m -> h 1", "sum")) * m
+        update_inplace(self.budget_ema, budget_new_norm)
 
-        self.codebooks.data.copy_(new_codebooks)
         return num_expired_codes_per_head
 
     def from_ids(self, indices: LongTensor) -> Tensor:
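A toy illustration of the replacement step above (standalone, single head; names and values are local to this sketch, not part of the module): the starved code is overwritten by a random incoming vector, its budget is reset to 1, and the final renormalization pulls the total back to m at the expense of the healthy codes.

    import torch

    m, d, threshold = 4, 2, 0.05
    budget = torch.tensor([1.50, 1.30, 1.18, 0.02])   # one code is starved
    codebook = torch.zeros(m, d)
    vectors = torch.randn(10, d)                      # candidate replacements

    expired = budget < threshold                      # [False, False, False, True]
    ids = torch.randperm(10)[: int(expired.sum())]
    codebook[expired] = vectors[ids]
    budget = torch.where(expired, torch.ones(m), budget)
    budget = budget / budget.sum() * m                # total budget is m again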
@@ -167,7 +206,7 @@ def __init__(self, num_residuals: int, shared_codebook: bool = True, **kwargs):
         super().__init__()
         self.num_residuals = num_residuals
 
-        self.quantizers = nn.ModuleList([VQ(**kwargs) for _ in range(num_residuals)])
+        self.quantizers = nn.ModuleList([BVQ(**kwargs) for _ in range(num_residuals)])
 
         if not shared_codebook:
             return
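For context, a minimal usage sketch of the renamed class; the (batch, tokens, features) input layout and the argument values are assumptions inferred from the einsum patterns above, not taken from the repository's docs:

    import torch
    from quantizer_pytorch import BVQ

    quantizer = BVQ(
        features=256,        # split across heads: head_features = features / num_heads
        num_heads=4,
        codebook_size=512,   # m: the per-head total budget is also 512
        expire_threshold=0.05,
        ema_decay=0.99,
    )

    x = torch.randn(2, 128, 256)      # assumed (batch, tokens, features)
    x_quantized, info = quantizer(x)  # forward returns Tuple[Tensor, Dict]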
2 changes: 1 addition & 1 deletion setup.py

@@ -3,7 +3,7 @@
 setup(
     name="quantizer-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.14",
+    version="0.0.15",
     license="MIT",
     description="Quantizer - PyTorch",
     long_description_content_type="text/markdown",
