Skip to content

Commit

Permalink
feat: add with_info option
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Oct 13, 2022
1 parent 641ecf7 commit d3dd826
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
24 changes: 16 additions & 8 deletions quantizer_pytorch/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,14 +296,16 @@ def from_ids(self, indices: LongTensor, **kwargs) -> Tensor:
x = self.quantize.from_ids(indices, **kwargs)
return rearrange(x, "b t c -> b c t")

def forward(self, x: Tensor, **kwargs) -> Tuple[Tensor, Dict]:
def forward(
self, x: Tensor, with_info: bool = True, **kwargs
) -> Union[Tensor, Tuple[Tensor, Dict]]:
r = self.num_residuals
x = rearrange(x, "b c t -> b t c")
x, info = self.quantize(x, **kwargs)
x = rearrange(x, "b t c -> b c t")
# Rearrange indices to expose residual
info["indices"] = rearrange(info["indices"], "b g (n r) -> b g n r", r=r)
return x, info
return (x, info) if with_info else x


class QuantizerChannelwise1d(nn.Module):
Expand Down Expand Up @@ -333,7 +335,9 @@ def from_ids(self, indices: LongTensor, **kwargs) -> Tensor:
x = self.quantize.from_ids(indices, **kwargs)
return rearrange(x, "b (k s) (g d) -> b (g k) (s d)", g=g, s=s)

def forward(self, x: Tensor, **kwargs) -> Tuple[Tensor, Dict]:
def forward(
self, x: Tensor, with_info: bool = True, **kwargs
) -> Union[Tensor, Tuple[Tensor, Dict]]:
b, c, t = x.shape
g, s = self.num_groups, t // self.split_size
# Quantize each group in a different head (codebook)
Expand All @@ -344,7 +348,7 @@ def forward(self, x: Tensor, **kwargs) -> Tuple[Tensor, Dict]:
x = rearrange(x, "(b s) (g k) d -> b (g k) (s d)", g=g, s=s)
# Rearrange info to match input shape
info["indices"] = rearrange(info["indices"], "b g (k s) -> b (g k) s", s=s)
return x, info
return (x, info) if with_info else x


class QuantizerBlock1d(nn.Module):
Expand Down Expand Up @@ -376,14 +380,16 @@ def from_ids(self, indices: LongTensor, **kwargs) -> Tensor:
x = self.quantize.from_ids(indices, **kwargs)
return rearrange(x, "b (cn sn) (cd sd) -> b (cn cd) (sn sd)", cn=cn, sd=sd)

def forward(self, x: Tensor, **kwargs) -> Tuple[Tensor, Dict]:
def forward(
self, x: Tensor, with_info: bool = True, **kwargs
) -> Union[Tensor, Tuple[Tensor, Dict]]:
cn, sd, r = self.num_groups, self.split_size, self.num_residuals
x = rearrange(x, "b (cn cd) (sn sd) -> b (cn sn) (cd sd)", cn=cn, sd=sd)
x, info = self.quantize(x, **kwargs)
x = rearrange(x, "b (cn sn) (cd sd) -> b (cn cd) (sn sd)", cn=cn, sd=sd)
# Rearrange info to match input shape
info["indices"] = rearrange(info["indices"], "b 1 (sn r) -> b sn r", r=r)
return x, info
return (x, info) if with_info else x


class Quantizer2d(nn.Module):
Expand Down Expand Up @@ -517,10 +523,12 @@ def from_ids(self, indices: LongTensor, **kwargs) -> Tensor:
x = self.quantize.from_ids(indices, **kwargs)
return rearrange(x, "b t c -> b c t")

def forward(self, x: Tensor, **kwargs) -> Tuple[Tensor, Dict]:
def forward(
self, x: Tensor, with_info: bool = True, **kwargs
) -> Union[Tensor, Tuple[Tensor, Dict]]:
x = rearrange(x, "b c t -> b t c")
x, info = self.quantize(x, **kwargs)
x = rearrange(x, "b t c -> b c t")
# Rearrange indices to expose residual
info["indices"] = rearrange(info["indices"], "b n o -> b 1 n o")
return x, info
return (x, info) if with_info else x
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="quantizer-pytorch",
packages=find_packages(exclude=[]),
version="0.0.20",
version="0.0.21",
license="MIT",
description="Quantizer - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit d3dd826

Please sign in to comment.