feat: add default timewise quantizer, usage examples
flavioschneider committed Sep 8, 2022
1 parent b485b01 commit 5935432
Showing 4 changed files with 116 additions and 26 deletions.
67 changes: 67 additions & 0 deletions README.md
@@ -8,8 +8,66 @@ pip install quantizer-pytorch
```
[![PyPI - Python Version](https://img.shields.io/pypi/v/quantizer-pytorch?style=flat&colorA=black&colorB=black)](https://pypi.org/project/quantizer-pytorch/)

## Usage

### Timewise Quantizer 1d
```py
import torch
from quantizer_pytorch import Quantizer1d

quantizer = Quantizer1d(
channels=32,
num_groups=1,
codebook_size=1024,
num_residuals=2
)
quantizer.eval() # In training mode the quantizer updates its codebooks with EMA on every forward pass

# Quantize sequence of shape [batch_size, channels, length]
x = torch.randn(1, 32, 80)
x_quantized, info = quantizer(x)

print(info.keys()) # ['indices', 'loss', 'perplexity', 'replaced_codes']
print(x_quantized.shape) # [1, 32, 80], same as input but quantized
print(info['indices'].shape) # [1, 1, 80, 2], i.e. [batch, num_groups, length, num_residuals]
print(info['loss']) # 0.8637, the mean squared error between x and x_quantized
print(info['replaced_codes']) # [0, 0], number of replaced codes per group

# Reconstruct x_quantized from indices
x_quantized_recon = quantizer.from_ids(info['indices'])
assert torch.allclose(x_quantized, x_quantized_recon) # This assert should pass when in eval mode
```
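
In training mode the codebooks are updated with EMA on each forward pass, and `info['loss']` can be added to the objective of a surrounding model so that its features stay close to the codebook. A minimal training sketch (the encoder, optimizer, and single-step loop below are hypothetical and not part of this library; it also assumes `info['loss']` is differentiable with respect to the quantizer input, as is typical for VQ layers):

```py
import torch
from torch import nn
from quantizer_pytorch import Quantizer1d

encoder = nn.Conv1d(1, 32, kernel_size=3, padding=1)  # hypothetical upstream model
quantizer = Quantizer1d(
    channels=32,
    num_groups=1,
    codebook_size=1024,
    num_residuals=2
)
quantizer.train()  # enable EMA codebook updates

optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-4)

waveform = torch.randn(1, 1, 80)       # dummy batch
features = encoder(waveform)           # [1, 32, 80]
quantized, info = quantizer(features)  # `quantized` would normally feed a decoder

optimizer.zero_grad()
info['loss'].backward()                # in practice, combine with your task loss
optimizer.step()
```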


### Channelwise Quantizer 1d
```py
import torch
from quantizer_pytorch import QuantizerChannelwise1d

quantizer = QuantizerChannelwise1d(
channels=32,
split_size=4, # Each channel is split into vectors of size split_size and quantized
num_groups=1,
codebook_size=1024
)
quantizer.eval() # In training mode the quantizer updates its codebooks with EMA on every forward pass

# Quantize sequence of shape [batch_size, channels, length]
x = torch.randn(1, 32, 80)
x_quantized, info = quantizer(x)

print(info.keys()) # ['indices', 'loss', 'perplexity', 'replaced_codes']
print(x_quantized.shape) # [1, 32, 80], same as input but quantized
print(info['indices'].shape) # [1, 32, 20], i.e. [batch, channels, length // split_size]: with length 80 and split_size 4 (think kernel_size = stride = split_size), each channel yields 20 indices
print(info['loss']) # 0.0620, the mean squared error between x and x_quantized
print(info['replaced_codes']) # [1], number of replaced codes per group

# Reconstruct x_quantized from indices
x_quantized_recon = quantizer.from_ids(info['indices'])
assert torch.allclose(x_quantized, x_quantized_recon) # This assert should pass when in eval mode
```
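
Because every chunk of `split_size` samples collapses to a single codebook index, the discrete representation is far more compact than the raw tensor. A rough back-of-the-envelope calculation for the shapes above (illustrative arithmetic only, not something the library computes):

```py
import math

channels, length, split_size, codebook_size = 32, 80, 4, 1024

raw_bits = channels * length * 32  # float32 input: 81,920 bits
code_bits = channels * (length // split_size) * int(math.log2(codebook_size))  # 640 indices x 10 bits = 6,400 bits
print(raw_bits / code_bits)  # 12.8x smaller, before any entropy coding
```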

## Citations


```bibtex
@misc{2106.04283,
Author = {Rayhane Mama and Marc S. Tyndel and Hashiam Kadhim and Cole Clifford and Ragavan Thurairatnam},
@@ -18,3 +76,12 @@ Year = {2021},
Eprint = {arXiv:2106.04283},
}
```

```bibtex
@misc{2107.03312,
Author = {Neil Zeghidour and Alejandro Luebs and Ahmed Omran and Jan Skoglund and Marco Tagliasacchi},
Title = {SoundStream: An End-to-End Neural Audio Codec},
Year = {2021},
Eprint = {arXiv:2107.03312},
}
```
2 changes: 1 addition & 1 deletion quantizer_pytorch/__init__.py
@@ -1 +1 @@
from .quantizer import Quantizer1d
from .quantizer import VQ, Quantizer1d, QuantizerChannelwise1d, ResidualVQ
71 changes: 47 additions & 24 deletions quantizer_pytorch/quantizer.py
@@ -220,38 +220,62 @@ def forward(


class Quantizer1d(nn.Module):
def __init__(
self,
channels: int,
num_groups: int,
codebook_size: int,
num_residuals: int = 1,
**kwargs
):
super().__init__()
assert channels % num_groups == 0, "channels must be divisible by num_groups"
self.num_groups = num_groups
self.num_residuals = num_residuals

self.quantize = ResidualVQ(
features=channels,
num_heads=num_groups,
codebook_size=codebook_size,
num_residuals=num_residuals,
**kwargs
)

def from_ids(self, indices: LongTensor) -> Tensor:
indices = rearrange(indices, "b g n r -> b g (n r)")
x = self.quantize.from_ids(indices)
return rearrange(x, "b t c -> b c t")

def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
r = self.num_residuals
x = rearrange(x, "b c t -> b t c")
x, info = self.quantize(x)
x = rearrange(x, "b t c -> b c t")
# Rearrange indices to expose residual
info["indices"] = rearrange(info["indices"], "b g (n r) -> b g n r", r=r)
return x, info


class QuantizerChannelwise1d(nn.Module):
def __init__(
self,
channels: int,
split_size: int,
num_groups: int,
codebook_size: int,
quantizer_type: str = "vqe",
num_residuals: int = 1,
**kwargs
):
super().__init__()
self.split_size = split_size
self.num_groups = num_groups
quantize: Optional[Quantization] = None

if quantizer_type == "vq":
quantize = VQ(
features=num_groups * split_size,
num_heads=num_groups,
codebook_size=codebook_size,
**kwargs
)
elif quantizer_type == "rvq":
quantize = ResidualVQ(
features=num_groups * split_size,
num_heads=num_groups,
codebook_size=codebook_size,
**kwargs
)
else:
raise ValueError("Invalid quantizer type")

self.quantize = quantize
self.quantize = ResidualVQ(
features=num_groups * split_size,
num_heads=num_groups,
codebook_size=codebook_size,
num_residuals=num_residuals,
**kwargs
)

def from_ids(self, indices: LongTensor) -> Tensor:
g, s = self.num_groups, indices.shape[-1]
@@ -265,10 +289,9 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
# Quantize each group in a different head (codebook)
x = rearrange(x, "b (g k) (s d) -> b (k s) (g d)", g=g, s=s)
x, info = self.quantize(x)
# Mask channel tokens with increasing probability
tokens = rearrange(x, "b (k s) (g d) -> (b s) (g k) d", g=g, s=s)
x = rearrange(x, "b (k s) (g d) -> (b s) (g k) d", g=g, s=s)
# Turn back to original shape
x = rearrange(tokens, "(b s) (g k) d -> b (g k) (s d)", g=g, s=s)
x = rearrange(x, "(b s) (g k) d -> b (g k) (s d)", g=g, s=s)
# Rearrange info to match input shape
info["indices"] = rearrange(info["indices"], "b g (k s) -> b (g k) s", s=s)
return x, info
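
The grouping rearrange in `QuantizerChannelwise1d.forward` is easiest to see with concrete shapes. Below is a small standalone sketch using the dimensions from the README example (channels=32, num_groups=1, split_size=4, length=80); the short variable names mirror the einops pattern and are not taken from the library itself:

```py
import torch
from einops import rearrange

b, g, k, s, d = 1, 1, 32, 20, 4   # batch, groups, channels per group, chunks per channel, split_size
x = torch.randn(b, g * k, s * d)  # [1, 32, 80], same layout as the quantizer input

# Each channel is cut into s chunks of d samples; every chunk becomes one token
# of g * d features, so a single codebook entry covers one chunk of one channel.
tokens = rearrange(x, "b (g k) (s d) -> b (k s) (g d)", g=g, s=s)
print(tokens.shape)               # torch.Size([1, 640, 4])
```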
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name="quantizer-pytorch",
packages=find_packages(exclude=[]),
version="0.0.13",
version="0.0.14",
license="MIT",
description="Quantizer - PyTorch",
long_description_content_type="text/markdown",
