diff --git a/quantizer_pytorch/quantizer.py b/quantizer_pytorch/quantizer.py index 98753d3..7d184e1 100644 --- a/quantizer_pytorch/quantizer.py +++ b/quantizer_pytorch/quantizer.py @@ -532,3 +532,107 @@ def forward( # Rearrange indices to expose residual info["indices"] = rearrange(info["indices"], "b n o -> b 1 n o") return (x, info) if with_info else x + + +""" +QBit +""" + + +class QBitLayer(nn.Module): + def __init__(self, features: int, num_bits: int, temperature: float): + super().__init__() + self.num_bits = num_bits + self.temperature = temperature + self.codebook = nn.Parameter(torch.randn(2, features)) + self.register_buffer("bitmask", 2 ** torch.arange(num_bits - 1, -1, -1)) + + def from_ids(self, indices: Tensor) -> Tensor: + bits = self.to_bits(indices) + bits = rearrange(bits, "b i j -> b (i j)") + attn = F.one_hot(bits, num_classes=2).float() + out = einsum("b m n, n d -> b m d", attn, self.codebook) + return out + + def to_bits(self, indices: Tensor) -> Tensor: + return indices.unsqueeze(-1).bitwise_and(self.bitmask).ne(0).long() # type: ignore #noqa + + def to_ints(self, bits: Tensor) -> Tensor: + return torch.sum(self.bitmask * bits, dim=-1) + + def forward(self, x: Tensor) -> Tuple[Tensor, Dict]: + n = x.shape[1] + assert n % self.num_bits == 0, "input must be divisible by num_bits" + + sim = einsum("b m d, n d -> b m n", x, self.codebook) + + if self.training: + attn = F.gumbel_softmax(sim, tau=self.temperature, dim=-1, hard=True) + bits = attn.argmax(dim=-1) + else: + bits = sim.argmax(dim=-1) + attn = F.one_hot(bits, num_classes=2).float() + + bits_flat = bits + bits = rearrange(bits_flat, "b (i j) -> b i j", j=self.num_bits) + indices = self.to_ints(bits) + out = einsum("b m n, n d -> b m d", attn, self.codebook) + return out, dict(indices=indices, bits=bits_flat) + + +class QBit(nn.Module): + def __init__( + self, + features: int, + num_bits: int, + num_layers: int, + temperature: float, + ): + super().__init__() + self.layers = nn.ModuleList( + [ + QBitLayer(features=features, num_bits=num_bits, temperature=temperature) + for _ in range(num_layers) + ] + ) + + def from_ids(self, indices: Tensor) -> Tensor: + indices = rearrange(indices, "b l n -> l b n") + ys = [] + + for layer, indices in zip(self.layers, indices): + ys += [layer.from_ids(indices)] + + return reduce(torch.stack(ys), "l b n d -> b n d", "sum") + + def forward(self, x: Tensor) -> Tuple[Tensor, Dict]: + ys, indices, bits = [], [], [] + + for layer in self.layers: + y, info = layer(x) + x = x - y + ys += [y] + indices += [info["indices"]] + bits += [info["bits"]] + + y = reduce(torch.stack(ys), "l b n d -> b n d", "sum") + + info = dict( + indices=rearrange(indices, "l b n -> b l n"), + bits=rearrange(bits, "l b k -> b l k"), + ) + + return y, info + + +class QBit1d(QBit): + def __init__(self, channels: int, **kwargs): + super().__init__(features=channels, **kwargs) + + def forward( # type: ignore + self, x: Tensor, with_info: bool = True, **kwargs + ) -> Union[Tensor, Tuple[Tensor, Dict]]: + x = rearrange(x, "b c t -> b t c") + x, info = super().forward(x, **kwargs) + x = rearrange(x, "b t c -> b c t") + return (x, info) if with_info else x diff --git a/setup.py b/setup.py index 333a6c1..6e1bd14 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="quantizer-pytorch", packages=find_packages(exclude=[]), - version="0.0.21", + version="0.0.22", license="MIT", description="Quantizer - PyTorch", long_description_content_type="text/markdown",