From 5935432c62b04d41cff92a5e8218d2ab0611ea26 Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Thu, 8 Sep 2022 23:29:23 +0200 Subject: [PATCH] feat: add default timewise quantizer, usage exampled --- README.md | 67 ++++++++++++++++++++++++++++++++ quantizer_pytorch/__init__.py | 2 +- quantizer_pytorch/quantizer.py | 71 ++++++++++++++++++++++------------ setup.py | 2 +- 4 files changed, 116 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 84596e0..cb8082e 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,66 @@ pip install quantizer-pytorch ``` [![PyPI - Python Version](https://img.shields.io/pypi/v/quantizer-pytorch?style=flat&colorA=black&colorB=black)](https://pypi.org/project/quantizer-pytorch/) +## Usage + +### Timewise Quantizer 1d +```py +from quantizer_pytorch import Quantizer1d + +quantizer = Quantizer1d( + channels=32, + num_groups=1, + codebook_size=1024, + num_residuals=2 +) +quantizer.eval() # If the model is set to training mode quantizer will train with EMA by simply forwarding values + +# Quantize sequence of shape [batch_size, channels, length] +x = torch.randn(1, 32, 80) +x_quantized, info = quantizer(x) + +print(info.keys()) # ['indices', 'loss', 'perplexity', 'replaced_codes'] +print(x_quantized.shape) # [1, 32, 80], same as input but quantized +print(info['indices'].shape) # [1, 1, 80, 2], i.e. 
[batch, num_groups, length, num_residuals] +print(info['loss']) # 0.8637, the mean squared error between x and x_quantized +print(info['replaced_codes']) # [0, 0], number of replaced codes per group + +# Reconstruct x_quantized from indices +x_quantized_recon = quantizer.from_ids(info['indices']) +assert torch.allclose(x_quantized, x_quantized_recon) # This assert should pass if in eval mode +``` + + +### Channelwise Quantizer 1d +```py +from quantizer_pytorch import QuantizerChannelwise1d + +quantizer = QuantizerChannelwise1d( + channels=32, + split_size=4, # Each channel will be split into vectors of size split_size and quantized + num_groups=1, + codebook_size=1024 +) +quantizer.eval() # If the model is set to training mode quantizer will train with EMA by simply forwarding values + +# Quantize sequence of shape [batch_size, channels, length] +x = torch.randn(1, 32, 80) +x_quantized, info = quantizer(x) + +print(info.keys()) # ['indices', 'loss', 'perplexity', 'replaced_codes'] +print(x_quantized.shape) # [1, 32, 80], same as input but quantized +print(info['indices'].shape) # [1, 32, 20], since the length is 80 and split_size is 4 (you can think of this as kernel_size=stride=split_size) we have 20 indices +print(info['loss']) # 0.0620, the mean squared error between x and x_quantized +print(info['replaced_codes']) # [1], number of replaced codes per group + +# Reconstruct x_quantized from indices +x_quantized_recon = quantizer.from_ids(info['indices']) +assert torch.allclose(x_quantized, x_quantized_recon) # This assert should pass if in eval mode +``` + ## Citations + ```bibtex @misc{2106.04283, Author = {Rayhane Mama and Marc S. 
Tyndel and Hashiam Kadhim and Cole Clifford and Ragavan Thurairatnam}, @@ -18,3 +76,12 @@ Year = {2021}, Eprint = {arXiv:2106.04283}, } ``` + +```bibtex +@misc{2107.03312, +Author = {Neil Zeghidour and Alejandro Luebs and Ahmed Omran and Jan Skoglund and Marco Tagliasacchi}, +Title = {SoundStream: An End-to-End Neural Audio Codec}, +Year = {2021}, +Eprint = {arXiv:2107.03312}, +} +``` diff --git a/quantizer_pytorch/__init__.py b/quantizer_pytorch/__init__.py index 708f952..499e60e 100644 --- a/quantizer_pytorch/__init__.py +++ b/quantizer_pytorch/__init__.py @@ -1 +1 @@ -from .quantizer import Quantizer1d +from .quantizer import VQ, Quantizer1d, QuantizerChannelwise1d, ResidualVQ diff --git a/quantizer_pytorch/quantizer.py b/quantizer_pytorch/quantizer.py index bf23f2c..a69a597 100644 --- a/quantizer_pytorch/quantizer.py +++ b/quantizer_pytorch/quantizer.py @@ -220,38 +220,62 @@ def forward( class Quantizer1d(nn.Module): + def __init__( + self, + channels: int, + num_groups: int, + codebook_size: int, + num_residuals: int = 1, + **kwargs + ): + super().__init__() + assert channels % num_groups == 0, "channels must be divisible by num_groups" + self.num_groups = num_groups + self.num_residuals = num_residuals + + self.quantize = ResidualVQ( + features=channels, + num_heads=num_groups, + codebook_size=codebook_size, + num_residuals=num_residuals, + **kwargs + ) + + def from_ids(self, indices: LongTensor) -> Tensor: + indices = rearrange(indices, "b g n r -> b g (n r)") + x = self.quantize.from_ids(indices) + return rearrange(x, "b t c -> b c t") + + def forward(self, x: Tensor) -> Tuple[Tensor, Dict]: + r = self.num_residuals + x = rearrange(x, "b c t -> b t c") + x, info = self.quantize(x) + x = rearrange(x, "b t c -> b c t") + # Rearrange indices to expose residual + info["indices"] = rearrange(info["indices"], "b g (n r) -> b g n r", r=r) + return x, info + + +class QuantizerChannelwise1d(nn.Module): def __init__( self, channels: int, split_size: int, num_groups: 
int, codebook_size: int, - quantizer_type: str = "vqe", + num_residuals: int = 1, **kwargs ): super().__init__() self.split_size = split_size self.num_groups = num_groups - quantize: Optional[Quantization] = None - - if quantizer_type == "vq": - quantize = VQ( - features=num_groups * split_size, - num_heads=num_groups, - codebook_size=codebook_size, - **kwargs - ) - elif quantizer_type == "rvq": - quantize = ResidualVQ( - features=num_groups * split_size, - num_heads=num_groups, - codebook_size=codebook_size, - **kwargs - ) - else: - raise ValueError("Invalid quantizer type") - - self.quantize = quantize + self.quantize = ResidualVQ( + features=num_groups * split_size, + num_heads=num_groups, + codebook_size=codebook_size, + num_residuals=num_residuals, + **kwargs + ) def from_ids(self, indices: LongTensor) -> Tensor: g, s = self.num_groups, indices.shape[-1] @@ -265,10 +289,9 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Dict]: # Quantize each group in a different head (codebook) x = rearrange(x, "b (g k) (s d) -> b (k s) (g d)", g=g, s=s) x, info = self.quantize(x) - # Mask channel tokens with increasing probability - tokens = rearrange(x, "b (k s) (g d) -> (b s) (g k) d", g=g, s=s) + x = rearrange(x, "b (k s) (g d) -> (b s) (g k) d", g=g, s=s) # Turn back to original shape - x = rearrange(tokens, "(b s) (g k) d -> b (g k) (s d)", g=g, s=s) + x = rearrange(x, "(b s) (g k) d -> b (g k) (s d)", g=g, s=s) # Rearrange info to match input shape info["indices"] = rearrange(info["indices"], "b g (k s) -> b (g k) s", s=s) return x, info diff --git a/setup.py b/setup.py index 440040f..261a9a4 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="quantizer-pytorch", packages=find_packages(exclude=[]), - version="0.0.13", + version="0.0.14", license="MIT", description="Quantizer - PyTorch", long_description_content_type="text/markdown",