Skip to content

Commit

Permalink
feat: add qbit quantizer
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Oct 16, 2022
1 parent d3dd826 commit 90e2413
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 1 deletion.
104 changes: 104 additions & 0 deletions quantizer_pytorch/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,3 +532,107 @@ def forward(
# Rearrange indices to expose residual
info["indices"] = rearrange(info["indices"], "b n o -> b 1 n o")
return (x, info) if with_info else x


"""
QBit
"""


class QBitLayer(nn.Module):
    """One binary quantization layer backed by a two-entry codebook.

    Every position of the input is mapped to one of the two codebook rows
    (bit 0 or bit 1). Runs of ``num_bits`` consecutive bits are additionally
    packed into a single integer index, most-significant bit first.
    """

    def __init__(self, features: int, num_bits: int, temperature: float):
        """Create the layer.

        Args:
            features: Size of each codebook vector.
            num_bits: Number of consecutive bits packed into one index.
            temperature: ``tau`` passed to the Gumbel-softmax in training.
        """
        super().__init__()
        self.num_bits = num_bits
        self.temperature = temperature
        # Row 0 is the vector emitted for bit 0, row 1 for bit 1.
        self.codebook = nn.Parameter(torch.randn(2, features))
        # Place values [2^(num_bits-1), ..., 2, 1] used to pack/unpack bits.
        exponents = torch.arange(num_bits - 1, -1, -1)
        self.register_buffer("bitmask", 2 ** exponents)

    def from_ids(self, indices: Tensor) -> Tensor:
        """Decode packed integer indices back into codebook vectors.

        Args:
            indices: Integer tensor matched against the pattern ``b i``
                (after bit expansion it must be ``b i j``).

        Returns:
            Tensor of shape ``(b, i * num_bits, d)`` holding the codebook row
            selected by each unpacked bit.
        """
        bit_grid = self.to_bits(indices)
        flat_bits = rearrange(bit_grid, "b i j -> b (i j)")
        selection = F.one_hot(flat_bits, num_classes=2).float()
        return einsum("b m n, n d -> b m d", selection, self.codebook)

    def to_bits(self, indices: Tensor) -> Tensor:
        """Expand integers into ``num_bits`` 0/1 values along a new last axis."""
        masked = indices[..., None] & self.bitmask  # type: ignore #noqa
        return (masked != 0).long()

    def to_ints(self, bits: Tensor) -> Tensor:
        """Pack the last axis of 0/1 values into integers (MSB first)."""
        weighted = bits * self.bitmask
        return weighted.sum(dim=-1)

    def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
        """Quantize ``x`` position-wise to the nearest-by-dot-product code.

        Args:
            x: Tensor matched against the einsum pattern ``b m d``; ``m``
                must be a multiple of ``num_bits``.

        Returns:
            A pair ``(out, info)`` where ``out`` has the same ``b m d`` layout
            as ``x`` and ``info`` holds ``"indices"`` (bits packed in groups of
            ``num_bits``) and ``"bits"`` (the flat per-position bits).
        """
        length = x.shape[1]
        assert length % self.num_bits == 0, "input must be divisible by num_bits"

        # Dot-product similarity of every position against both codes.
        scores = einsum("b m d, n d -> b m n", x, self.codebook)

        if not self.training:
            # Deterministic choice at evaluation time.
            flat_bits = scores.argmax(dim=-1)
            selection = F.one_hot(flat_bits, num_classes=2).float()
        else:
            # Sample a hard (one-hot) Gumbel-softmax selection while training.
            selection = F.gumbel_softmax(
                scores, tau=self.temperature, dim=-1, hard=True
            )
            flat_bits = selection.argmax(dim=-1)

        grouped_bits = rearrange(flat_bits, "b (i j) -> b i j", j=self.num_bits)
        packed = self.to_ints(grouped_bits)
        quantized = einsum("b m n, n d -> b m d", selection, self.codebook)
        return quantized, dict(indices=packed, bits=flat_bits)


class QBit(nn.Module):
    """Residual stack of ``QBitLayer`` quantizers.

    Each layer quantizes what the previous layers left unexplained, and the
    final output is the sum of all per-layer quantized outputs.
    """

    def __init__(
        self,
        features: int,
        num_bits: int,
        num_layers: int,
        temperature: float,
    ):
        """Build ``num_layers`` identically configured ``QBitLayer`` modules.

        Args:
            features: Size of each codebook vector in every layer.
            num_bits: Number of bits packed per index in every layer.
            num_layers: Number of residual layers.
            temperature: Gumbel-softmax ``tau`` forwarded to every layer.
        """
        super().__init__()
        stack = []
        for _ in range(num_layers):
            stack.append(
                QBitLayer(
                    features=features,
                    num_bits=num_bits,
                    temperature=temperature,
                )
            )
        self.layers = nn.ModuleList(stack)

    def from_ids(self, indices: Tensor) -> Tensor:
        """Reconstruct the summed output from per-layer packed indices.

        Args:
            indices: Tensor matched against the pattern ``b l n`` where ``l``
                runs over the layers.

        Returns:
            The sum over layers of each layer's ``from_ids`` result.
        """
        # Put the layer axis first so it can be iterated alongside the layers.
        layer_major = rearrange(indices, "b l n -> l b n")
        decoded = [
            layer.from_ids(layer_ids)
            for layer, layer_ids in zip(self.layers, layer_major)
        ]
        return reduce(torch.stack(decoded), "l b n d -> b n d", "sum")

    def forward(self, x: Tensor) -> Tuple[Tensor, Dict]:
        """Quantize ``x`` through all layers, each acting on the residual.

        Args:
            x: Input tensor passed to the first layer.

        Returns:
            A pair ``(y, info)``: ``y`` is the sum of all layer outputs and
            ``info`` holds ``"indices"`` and ``"bits"`` gathered from every
            layer, rearranged so the layer axis comes second.
        """
        outputs, layer_indices, layer_bits = [], [], []
        residual = x

        for layer in self.layers:
            quantized, layer_info = layer(residual)
            # The next layer only sees what this one failed to capture.
            residual = residual - quantized
            outputs.append(quantized)
            layer_indices.append(layer_info["indices"])
            layer_bits.append(layer_info["bits"])

        total = reduce(torch.stack(outputs), "l b n d -> b n d", "sum")

        info = dict(
            indices=rearrange(layer_indices, "l b n -> b l n"),
            bits=rearrange(layer_bits, "l b k -> b l k"),
        )
        return total, info


class QBit1d(QBit):
    """``QBit`` variant whose ``forward`` takes channels-first ``b c t`` input."""

    def __init__(self, channels: int, **kwargs):
        """Forward ``channels`` as ``features``; other kwargs go to ``QBit``."""
        super().__init__(features=channels, **kwargs)

    def forward(  # type: ignore
        self, x: Tensor, with_info: bool = True, **kwargs
    ) -> Union[Tensor, Tuple[Tensor, Dict]]:
        """Quantize a ``b c t`` tensor and return it in the same layout.

        Args:
            x: Input matched against the pattern ``b c t``.
            with_info: When true, also return the info dict from ``QBit``.
            **kwargs: Passed through unchanged to ``QBit.forward``.

        Returns:
            ``(quantized, info)`` if ``with_info`` is true, else ``quantized``.
        """
        # The parent class works on channels-last tensors.
        channels_last = rearrange(x, "b c t -> b t c")
        quantized, info = super().forward(channels_last, **kwargs)
        quantized = rearrange(quantized, "b t c -> b c t")
        if with_info:
            return quantized, info
        return quantized
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="quantizer-pytorch",
packages=find_packages(exclude=[]),
version="0.0.21",
version="0.0.22",
license="MIT",
description="Quantizer - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 90e2413

Please sign in to comment.