[WIP] Conv1x1 #105

Draft
wants to merge 23 commits into base: main
8 changes: 8 additions & 0 deletions flowtorch/bijectors/__init__.py
@@ -16,6 +16,10 @@
from flowtorch.bijectors.autoregressive import Autoregressive
from flowtorch.bijectors.base import Bijector
from flowtorch.bijectors.compose import Compose
from flowtorch.bijectors.conv11 import Conv1x1Bijector
from flowtorch.bijectors.conv11 import SomeOtherClass
from flowtorch.bijectors.coupling import ConvCouplingBijector
from flowtorch.bijectors.coupling import CouplingBijector
from flowtorch.bijectors.elementwise import Elementwise
from flowtorch.bijectors.elu import ELU
from flowtorch.bijectors.exp import Exp
@@ -35,6 +39,9 @@
("Affine", Affine),
("AffineAutoregressive", AffineAutoregressive),
("AffineFixed", AffineFixed),
("Conv1x1Bijector", Conv1x1Bijector),
("ConvCouplingBijector", ConvCouplingBijector),
("CouplingBijector", CouplingBijector),
("ELU", ELU),
("Exp", Exp),
("LeakyReLU", LeakyReLU),
@@ -54,6 +61,7 @@
("Bijector", Bijector),
("Compose", Compose),
("Invert", Invert),
("SomeOtherClass", SomeOtherClass),
("VolumePreserving", VolumePreserving),
]
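
With the imports and registrations above, the new classes are exported from the package root. A quick sanity check of that export, assuming nothing else in the package shadows these names:

import flowtorch.bijectors as bijectors

# The registration lists above make the new bijectors discoverable at package level.
assert hasattr(bijectors, "Conv1x1Bijector")
assert hasattr(bijectors, "ConvCouplingBijector")
assert hasattr(bijectors, "CouplingBijector")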

21 changes: 17 additions & 4 deletions flowtorch/bijectors/affine_autoregressive.py
@@ -16,15 +16,28 @@ def __init__(
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
clamp_values: bool = False,
log_scale_min_clip: float = -5.0,
log_scale_max_clip: float = 3.0,
sigmoid_bias: float = 2.0,
positive_map: str = "softplus",
positive_bias: Optional[float] = None,
) -> None:
super().__init__(
AffineOp.__init__(
self,
params_fn,
shape=shape,
context_shape=context_shape,
clamp_values=clamp_values,
log_scale_min_clip=log_scale_min_clip,
log_scale_max_clip=log_scale_max_clip,
sigmoid_bias=sigmoid_bias,
positive_map=positive_map,
positive_bias=positive_bias,
)
Autoregressive.__init__(
self,
params_fn,
shape=shape,
context_shape=context_shape,
)
self.log_scale_min_clip = log_scale_min_clip
self.log_scale_max_clip = log_scale_max_clip
self.sigmoid_bias = sigmoid_bias
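
For context, a minimal usage sketch of the expanded constructor, following the lazy-instantiation pattern shown in the coupling.py doctests below; the conditioner name DenseAutoregressive is assumed from the existing flowtorch.parameters API, not from this diff:

import torch
import flowtorch.bijectors as bijectors
import flowtorch.parameters as params

# The affine options added here (clamp_values, positive_map, positive_bias, ...)
# are forwarded to AffineOp, while Autoregressive.__init__ sets up the conditioner.
bij = bijectors.AffineAutoregressive(
    params.DenseAutoregressive(),  # assumed conditioner, mirroring the existing API
    positive_map="softplus",
    clamp_values=False,
)
bij = bij(shape=torch.Size([16]))
x = torch.randn(4, 16)
y = bij.forward(x)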
2 changes: 1 addition & 1 deletion flowtorch/bijectors/autoregressive.py
@@ -60,7 +60,7 @@ def inverse(
# TODO: Make permutation, inverse work for other event shapes
log_detJ: Optional[torch.Tensor] = None
for idx in cast(torch.LongTensor, permutation):
_params = self._params_fn(x_new.clone(), context=context)
_params = self._params_fn(x_new.clone(), inverse=False, context=context)
x_temp, log_detJ = self._inverse(y, params=_params)
x_new[..., idx] = x_temp[..., idx]
# _log_detJ = out[1]
16 changes: 12 additions & 4 deletions flowtorch/bijectors/base.py
@@ -71,7 +71,11 @@ def forward(
assert isinstance(x, BijectiveTensor)
return x.get_parent_from_bijector(self)

params = self._params_fn(x, context) if self._params_fn is not None else None
params = (
self._params_fn(x, inverse=False, context=context)
if self._params_fn is not None
else None
)
y, log_detJ = self._forward(x, params)
if (
is_record_flow_graph_enabled()
@@ -117,7 +121,11 @@ def inverse(
return y.get_parent_from_bijector(self)

# TODO: What to do in this line?
params = self._params_fn(x, context) if self._params_fn is not None else None
params = (
self._params_fn(y, inverse=True, context=context)
if self._params_fn is not None
else None
)
x, log_detJ = self._inverse(y, params)

if (
@@ -170,10 +178,10 @@ def log_abs_det_jacobian(
if ladj is None:
if is_record_flow_graph_enabled():
warnings.warn(
"Computing _log_abs_det_jacobian from values and not " "from cache."
"Computing _log_abs_det_jacobian from values and not from cache."
)
params = (
self._params_fn(x, context) if self._params_fn is not None else None
self._params_fn(x, y, context) if self._params_fn is not None else None
)
return self._log_abs_det_jacobian(x, y, params)
return ladj
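
The updated call sites imply that parameter modules are now called with an explicit inverse flag. A toy sketch of a callable matching the (input, inverse=..., context=...) signature these calls suggest; the actual Parameters interface in this PR may differ (the log_abs_det_jacobian branch above also passes both x and y):

from typing import Optional, Sequence

import torch

def toy_params_fn(
    value: torch.Tensor,
    inverse: bool = False,
    context: Optional[torch.Tensor] = None,
) -> Sequence[torch.Tensor]:
    # A real Parameters module would condition on `value` (and `context`) and
    # could switch behaviour between the forward and inverse passes.
    shift = torch.zeros_like(value)
    log_scale = torch.zeros_like(value)
    return shift, log_scale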
110 changes: 110 additions & 0 deletions flowtorch/bijectors/conv11.py
@@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc
from copy import deepcopy
from typing import Optional, Sequence, Tuple

import flowtorch
import torch
from torch.distributions import constraints
from torch.nn import functional as F

from ..parameters.conv11 import Conv1x1Params
from .base import Bijector

_REAL3d = deepcopy(constraints.real)
_REAL3d.event_dim = 3


class SomeOtherClass(Bijector):
pass


class Conv1x1Bijector(SomeOtherClass):

domain: constraints.Constraint = _REAL3d
codomain: constraints.Constraint = _REAL3d

def __init__(
self,
params_fn: Optional[flowtorch.Lazy] = None,
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
LU_decompose: bool = False,
double_solve: bool = False,
zero_init: bool = False,
):
if params_fn is None:
params_fn = Conv1x1Params(LU_decompose, zero_init=zero_init) # type: ignore
self._LU = LU_decompose
self._double_solve = double_solve
self.dims = (-3, -2, -1)
super().__init__(
params_fn=params_fn,
shape=shape,
context_shape=context_shape,
)

def _forward(
self,
x: torch.Tensor,
params: Optional[Sequence[torch.Tensor]],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert isinstance(params, (list, tuple))
weight, logdet = params
unsqueeze = False
if x.ndimension() == 3:
x = x.unsqueeze(0)
unsqueeze = True
z = F.conv2d(x, weight)
if unsqueeze:
z = z.squeeze(0)
return z, logdet

def _inverse(
self,
y: torch.Tensor,
params: Optional[Sequence[torch.Tensor]],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert isinstance(params, (list, tuple))
unsqueeze = False
if y.ndimension() == 3:
y = y.unsqueeze(0)
unsqueeze = True

if self._LU:
p, low, up, logdet = params
dtype = low.dtype
output_view = y.permute(0, 2, 3, 1).unsqueeze(-1)
if self._double_solve:
low = low.double()
p = p.double()
output_view = output_view.double()
up = up.double()

z_view = torch.triangular_solve(
torch.triangular_solve(
p.transpose(-1, -2) @ output_view, low, upper=False
)[0],
up,
upper=True,
)[0]

if self._double_solve:
z_view = z_view.to(dtype)

z = z_view.squeeze(-1).permute(0, 3, 1, 2)
else:
weight, logdet = params
z = F.conv2d(y, weight)

if unsqueeze:
z = z.squeeze(0)
logdet = logdet.squeeze(0)
return z, logdet.expand_as(z.sum(self.dims))

def param_shapes(self, shape: torch.Size) -> Sequence[torch.Size]:
shape = torch.Size([shape[-3], shape[-3], 1, 1])
if not self._LU:
return (shape,)
else:
return (shape, shape, shape)
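
A usage sketch for the new bijector, mirroring the doctest pattern of the coupling bijectors below; the LU_decompose path and the default Conv1x1Params conditioner are assumed to behave as their constructor arguments suggest:

import torch
import flowtorch.bijectors as bijectors

# Invertible 1x1 convolution over the channel dimension of a CHW input;
# with LU_decompose=True the inverse is recovered via two triangular solves.
bij = bijectors.Conv1x1Bijector(LU_decompose=True)
bij = bij(shape=torch.Size([3, 16, 16]))
x = torch.randn(4, 3, 16, 16)
y = bij.forward(x)
x_bis = bij.inverse(y.detach_from_flow())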
131 changes: 131 additions & 0 deletions flowtorch/bijectors/coupling.py
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc
from copy import deepcopy
from typing import Optional, Sequence, Tuple

import flowtorch.parameters
import torch
from flowtorch.bijectors.ops.affine import Affine as AffineOp
from flowtorch.parameters import ConvCoupling, DenseCoupling
from torch.distributions import constraints


_REAL3d = deepcopy(constraints.real)
_REAL3d.event_dim = 3

_REAL1d = deepcopy(constraints.real)
_REAL1d.event_dim = 1


class CouplingBijector(AffineOp):
"""
Examples:
>>> params = DenseCoupling()
>>> bij = CouplingBijector(params)
>>> bij = bij(shape=torch.Size([32,]))
>>> for p in bij.parameters():
... p.data += torch.randn_like(p)/10
>>> x = torch.randn(1, 32,requires_grad=True)
>>> y = bij.forward(x).detach_from_flow()
>>> x_bis = bij.inverse(y)
>>> torch.testing.assert_allclose(x, x_bis)
"""

domain: constraints.Constraint = _REAL1d
codomain: constraints.Constraint = _REAL1d

def __init__(
self,
params_fn: Optional[flowtorch.Lazy] = None,
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
clamp_values: bool = False,
log_scale_min_clip: float = -5.0,
log_scale_max_clip: float = 3.0,
sigmoid_bias: float = 2.0,
positive_map: str = "softplus",
positive_bias: Optional[float] = None,
) -> None:

if params_fn is None:
params_fn = DenseCoupling() # type: ignore

AffineOp.__init__(
self,
params_fn,
shape=shape,
context_shape=context_shape,
clamp_values=clamp_values,
log_scale_min_clip=log_scale_min_clip,
log_scale_max_clip=log_scale_max_clip,
sigmoid_bias=sigmoid_bias,
positive_map=positive_map,
positive_bias=positive_bias,
)

def _forward(
self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert self._params_fn is not None

y, ldj = super()._forward(x, params)
return y, ldj

def _inverse(
self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
assert self._params_fn is not None

x, ldj = super()._inverse(y, params)
return x, ldj


class ConvCouplingBijector(CouplingBijector):
"""
Examples:
>>> params = ConvCoupling()
>>> bij = ConvCouplingBijector(params)
>>> bij = bij(shape=torch.Size([3,16,16]))
>>> for p in bij.parameters():
... p.data += torch.randn_like(p)/10
>>> x = torch.randn(4, 3, 16, 16)
>>> y = bij.forward(x)
>>> x_bis = bij.inverse(y.detach_from_flow())
>>> torch.testing.assert_allclose(x, x_bis)
"""

domain: constraints.Constraint = _REAL3d
codomain: constraints.Constraint = _REAL3d

def __init__(
self,
params_fn: Optional[flowtorch.Lazy] = None,
*,
shape: torch.Size,
context_shape: Optional[torch.Size] = None,
clamp_values: bool = False,
log_scale_min_clip: float = -5.0,
log_scale_max_clip: float = 3.0,
sigmoid_bias: float = 2.0,
positive_map: str = "softplus",
positive_bias: Optional[float] = None,
) -> None:

if not len(shape) == 3:
raise ValueError(f"Expected a 3d-tensor shape, got {shape}")

if params_fn is None:
params_fn = ConvCoupling() # type: ignore

AffineOp.__init__(
self,
params_fn,
shape=shape,
context_shape=context_shape,
clamp_values=clamp_values,
log_scale_min_clip=log_scale_min_clip,
log_scale_max_clip=log_scale_max_clip,
sigmoid_bias=sigmoid_bias,
positive_map=positive_map,
positive_bias=positive_bias,
)
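
For reference, the coupling bijectors build an elementwise affine map whose parameters depend only on one half of the input, which is what makes the closed-form inverse in AffineOp possible. A standalone sketch of that textbook affine-coupling construction (plain PyTorch, not flowtorch's exact code):

import torch

x = torch.randn(4, 32)
x1, x2 = x.chunk(2, dim=-1)                    # conditioner sees x1 only
net = torch.nn.Linear(16, 32)                  # stand-in for DenseCoupling
shift, log_scale = net(x1).chunk(2, dim=-1)
y1, y2 = x1, x2 * log_scale.exp() + shift      # forward: affine map applied to x2
ldj = log_scale.sum(-1)                        # log|det J| of the forward map

# Inverse: y1 == x1, so the same parameters can be recomputed and undone.
shift_i, log_scale_i = net(y1).chunk(2, dim=-1)
x2_rec = (y2 - shift_i) * (-log_scale_i).exp()
torch.testing.assert_close(x2, x2_rec)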