diff --git a/quantizer_pytorch/quantizer.py b/quantizer_pytorch/quantizer.py index 3611ad2..08e2e78 100644 --- a/quantizer_pytorch/quantizer.py +++ b/quantizer_pytorch/quantizer.py @@ -212,11 +212,14 @@ def __init__(self, num_residuals: int, shared_codebook: bool = True, **kwargs): if not shared_codebook: return + # Share both codebooks and total budget first_vq, *rest_vq = self.quantizers codebooks = first_vq.codebooks + budget_ema = first_vq.budget_ema for quantizer in rest_vq: quantizer.codebooks = codebooks + quantizer.budget_ema = budget_ema def from_ids( self, indices: LongTensor, num_residuals: Optional[int] = None diff --git a/setup.py b/setup.py index 233dac6..8bd063d 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="quantizer-pytorch", packages=find_packages(exclude=[]), - version="0.0.16", + version="0.0.17", license="MIT", description="Quantizer - PyTorch", long_description_content_type="text/markdown",