diff --git a/flowtorch/bijectors/__init__.py b/flowtorch/bijectors/__init__.py index 57e3c1ad..9e724c41 100644 --- a/flowtorch/bijectors/__init__.py +++ b/flowtorch/bijectors/__init__.py @@ -16,6 +16,10 @@ from flowtorch.bijectors.autoregressive import Autoregressive from flowtorch.bijectors.base import Bijector from flowtorch.bijectors.compose import Compose +from flowtorch.bijectors.conv11 import Conv1x1Bijector +from flowtorch.bijectors.conv11 import SomeOtherClass +from flowtorch.bijectors.coupling import ConvCouplingBijector +from flowtorch.bijectors.coupling import CouplingBijector from flowtorch.bijectors.elementwise import Elementwise from flowtorch.bijectors.elu import ELU from flowtorch.bijectors.exp import Exp @@ -35,6 +39,9 @@ ("Affine", Affine), ("AffineAutoregressive", AffineAutoregressive), ("AffineFixed", AffineFixed), + ("Conv1x1Bijector", Conv1x1Bijector), + ("ConvCouplingBijector", ConvCouplingBijector), + ("CouplingBijector", CouplingBijector), ("ELU", ELU), ("Exp", Exp), ("LeakyReLU", LeakyReLU), @@ -54,6 +61,7 @@ ("Bijector", Bijector), ("Compose", Compose), ("Invert", Invert), + ("SomeOtherClass", SomeOtherClass), ("VolumePreserving", VolumePreserving), ] diff --git a/flowtorch/bijectors/affine_autoregressive.py b/flowtorch/bijectors/affine_autoregressive.py index 610e5477..a855cf5d 100644 --- a/flowtorch/bijectors/affine_autoregressive.py +++ b/flowtorch/bijectors/affine_autoregressive.py @@ -16,15 +16,28 @@ def __init__( *, shape: torch.Size, context_shape: Optional[torch.Size] = None, + clamp_values: bool = False, log_scale_min_clip: float = -5.0, log_scale_max_clip: float = 3.0, sigmoid_bias: float = 2.0, + positive_map: str = "softplus", + positive_bias: Optional[float] = None, ) -> None: - super().__init__( + AffineOp.__init__( + self, + params_fn, + shape=shape, + context_shape=context_shape, + clamp_values=clamp_values, + log_scale_min_clip=log_scale_min_clip, + log_scale_max_clip=log_scale_max_clip, + sigmoid_bias=sigmoid_bias, + positive_map=positive_map, + positive_bias=positive_bias, + ) + Autoregressive.__init__( + self, params_fn, shape=shape, context_shape=context_shape, ) - self.log_scale_min_clip = log_scale_min_clip - self.log_scale_max_clip = log_scale_max_clip - self.sigmoid_bias = sigmoid_bias diff --git a/flowtorch/bijectors/autoregressive.py b/flowtorch/bijectors/autoregressive.py index 8367b51b..8a8371f9 100644 --- a/flowtorch/bijectors/autoregressive.py +++ b/flowtorch/bijectors/autoregressive.py @@ -60,7 +60,7 @@ def inverse( # TODO: Make permutation, inverse work for other event shapes log_detJ: Optional[torch.Tensor] = None for idx in cast(torch.LongTensor, permutation): - _params = self._params_fn(x_new.clone(), context=context) + _params = self._params_fn(x_new.clone(), inverse=False, context=context) x_temp, log_detJ = self._inverse(y, params=_params) x_new[..., idx] = x_temp[..., idx] # _log_detJ = out[1] diff --git a/flowtorch/bijectors/base.py b/flowtorch/bijectors/base.py index 2a3d0f01..a9621cca 100644 --- a/flowtorch/bijectors/base.py +++ b/flowtorch/bijectors/base.py @@ -71,7 +71,11 @@ def forward( assert isinstance(x, BijectiveTensor) return x.get_parent_from_bijector(self) - params = self._params_fn(x, context) if self._params_fn is not None else None + params = ( + self._params_fn(x, inverse=False, context=context) + if self._params_fn is not None + else None + ) y, log_detJ = self._forward(x, params) if ( is_record_flow_graph_enabled() @@ -117,7 +121,11 @@ def inverse( return y.get_parent_from_bijector(self) # TODO: What to 
do in this line? - params = self._params_fn(x, context) if self._params_fn is not None else None + params = ( + self._params_fn(y, inverse=True, context=context) + if self._params_fn is not None + else None + ) x, log_detJ = self._inverse(y, params) if ( @@ -170,10 +178,10 @@ def log_abs_det_jacobian( if ladj is None: if is_record_flow_graph_enabled(): warnings.warn( - "Computing _log_abs_det_jacobian from values and not " "from cache." + "Computing _log_abs_det_jacobian from values and not from cache." ) params = ( - self._params_fn(x, context) if self._params_fn is not None else None + self._params_fn(x, y, context) if self._params_fn is not None else None ) return self._log_abs_det_jacobian(x, y, params) return ladj diff --git a/flowtorch/bijectors/conv11.py b/flowtorch/bijectors/conv11.py new file mode 100644 index 00000000..d412a10d --- /dev/null +++ b/flowtorch/bijectors/conv11.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc +from copy import deepcopy +from typing import Optional, Sequence, Tuple + +import flowtorch +import torch +from torch.distributions import constraints +from torch.nn import functional as F + +from ..parameters.conv11 import Conv1x1Params +from .base import Bijector + +_REAL3d = deepcopy(constraints.real) +_REAL3d.event_dim = 3 + + +class SomeOtherClass(Bijector): + pass + + +class Conv1x1Bijector(SomeOtherClass): + + domain: constraints.Constraint = _REAL3d + codomain: constraints.Constraint = _REAL3d + + def __init__( + self, + params_fn: Optional[flowtorch.Lazy] = None, + *, + shape: torch.Size, + context_shape: Optional[torch.Size] = None, + LU_decompose: bool = False, + double_solve: bool = False, + zero_init: bool = False, + ): + if params_fn is None: + params_fn = Conv1x1Params(LU_decompose, zero_init=zero_init) # type: ignore + self._LU = LU_decompose + self._double_solve = double_solve + self.dims = (-3, -2, -1) + super().__init__( + params_fn=params_fn, + shape=shape, + context_shape=context_shape, + ) + + def _forward( + self, + x: torch.Tensor, + params: Optional[Sequence[torch.Tensor]], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + assert isinstance(params, (list, tuple)) + weight, logdet = params + unsqueeze = False + if x.ndimension() == 3: + x = x.unsqueeze(0) + unsqueeze = True + z = F.conv2d(x, weight) + if unsqueeze: + z = z.squeeze(0) + return z, logdet + + def _inverse( + self, + y: torch.Tensor, + params: Optional[Sequence[torch.Tensor]], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + assert isinstance(params, (list, tuple)) + unsqueeze = False + if y.ndimension() == 3: + y = y.unsqueeze(0) + unsqueeze = True + + if self._LU: + p, low, up, logdet = params + dtype = low.dtype + output_view = y.permute(0, 2, 3, 1).unsqueeze(-1) + if self._double_solve: + low = low.double() + p = p.double() + output_view = output_view.double() + up = up.double() + + z_view = torch.triangular_solve( + torch.triangular_solve( + p.transpose(-1, -2) @ output_view, low, upper=False + )[0], + up, + upper=True, + )[0] + + if self._double_solve: + z_view = z_view.to(dtype) + + z = z_view.squeeze(-1).permute(0, 3, 1, 2) + else: + weight, logdet = params + z = F.conv2d(y, weight) + + if unsqueeze: + z = z.squeeze(0) + logdet = logdet.squeeze(0) + return z, logdet.expand_as(z.sum(self.dims)) + + def param_shapes(self, shape: torch.Size) -> Sequence[torch.Size]: + shape = torch.Size([shape[-3], shape[-3], 1, 1]) + if not self._LU: + return (shape,) + else: + return (shape, shape, shape) diff --git a/flowtorch/bijectors/coupling.py 
b/flowtorch/bijectors/coupling.py new file mode 100644 index 00000000..73b47f50 --- /dev/null +++ b/flowtorch/bijectors/coupling.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc +from copy import deepcopy +from typing import Optional, Sequence, Tuple + +import flowtorch.parameters +import torch +from flowtorch.bijectors.ops.affine import Affine as AffineOp +from flowtorch.parameters import ConvCoupling, DenseCoupling +from torch.distributions import constraints + + +_REAL3d = deepcopy(constraints.real) +_REAL3d.event_dim = 3 + +_REAL1d = deepcopy(constraints.real) +_REAL1d.event_dim = 1 + + +class CouplingBijector(AffineOp): + """ + Examples: + >>> params = DenseCoupling() + >>> bij = CouplingBijector(params) + >>> bij = bij(shape=torch.Size([32,])) + >>> for p in bij.parameters(): + ... p.data += torch.randn_like(p)/10 + >>> x = torch.randn(1, 32,requires_grad=True) + >>> y = bij.forward(x).detach_from_flow() + >>> x_bis = bij.inverse(y) + >>> torch.testing.assert_allclose(x, x_bis) + """ + + domain: constraints.Constraint = _REAL1d + codomain: constraints.Constraint = _REAL1d + + def __init__( + self, + params_fn: Optional[flowtorch.Lazy] = None, + *, + shape: torch.Size, + context_shape: Optional[torch.Size] = None, + clamp_values: bool = False, + log_scale_min_clip: float = -5.0, + log_scale_max_clip: float = 3.0, + sigmoid_bias: float = 2.0, + positive_map: str = "softplus", + positive_bias: Optional[float] = None, + ) -> None: + + if params_fn is None: + params_fn = DenseCoupling() # type: ignore + + AffineOp.__init__( + self, + params_fn, + shape=shape, + context_shape=context_shape, + clamp_values=clamp_values, + log_scale_min_clip=log_scale_min_clip, + log_scale_max_clip=log_scale_max_clip, + sigmoid_bias=sigmoid_bias, + positive_map=positive_map, + positive_bias=positive_bias, + ) + + def _forward( + self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + assert self._params_fn is not None + + y, ldj = super()._forward(x, params) + return y, ldj + + def _inverse( + self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + assert self._params_fn is not None + + x, ldj = super()._inverse(y, params) + return x, ldj + + +class ConvCouplingBijector(CouplingBijector): + """ + Examples: + >>> params = ConvCoupling() + >>> bij = ConvCouplingBijector(params) + >>> bij = bij(shape=torch.Size([3,16,16])) + >>> for p in bij.parameters(): + ... 
p.data += torch.randn_like(p)/10 + >>> x = torch.randn(4, 3, 16, 16) + >>> y = bij.forward(x) + >>> x_bis = bij.inverse(y.detach_from_flow()) + >>> torch.testing.assert_allclose(x, x_bis) + """ + + domain: constraints.Constraint = _REAL3d + codomain: constraints.Constraint = _REAL3d + + def __init__( + self, + params_fn: Optional[flowtorch.Lazy] = None, + *, + shape: torch.Size, + context_shape: Optional[torch.Size] = None, + clamp_values: bool = False, + log_scale_min_clip: float = -5.0, + log_scale_max_clip: float = 3.0, + sigmoid_bias: float = 2.0, + positive_map: str = "softplus", + positive_bias: Optional[float] = None, + ) -> None: + + if not len(shape) == 3: + raise ValueError(f"Expected a 3d-tensor shape, got {shape}") + + if params_fn is None: + params_fn = ConvCoupling() # type: ignore + + AffineOp.__init__( + self, + params_fn, + shape=shape, + context_shape=context_shape, + clamp_values=clamp_values, + log_scale_min_clip=log_scale_min_clip, + log_scale_max_clip=log_scale_max_clip, + sigmoid_bias=sigmoid_bias, + positive_map=positive_map, + positive_bias=positive_bias, + ) diff --git a/flowtorch/bijectors/ops/affine.py b/flowtorch/bijectors/ops/affine.py index d9cdf56f..196cfc20 100644 --- a/flowtorch/bijectors/ops/affine.py +++ b/flowtorch/bijectors/ops/affine.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc -from typing import Optional, Sequence, Tuple +from typing import Callable, Dict, Optional, Sequence, Tuple import flowtorch import torch @@ -8,6 +8,17 @@ from flowtorch.ops import clamp_preserve_gradients from torch.distributions.utils import _sum_rightmost +_DEFAULT_POSITIVE_BIASES = { + "softplus": 0.5413248538970947, + "exp": 0.0, +} + +_POSITIVE_MAPS: Dict[str, Callable[[torch.Tensor], torch.Tensor]] = { + "softplus": torch.nn.functional.softplus, + "sigmoid": torch.sigmoid, + "exp": torch.exp, +} + class Affine(Bijector): r""" @@ -22,38 +33,63 @@ def __init__( *, shape: torch.Size, context_shape: Optional[torch.Size] = None, + clamp_values: bool = False, log_scale_min_clip: float = -5.0, log_scale_max_clip: float = 3.0, sigmoid_bias: float = 2.0, + positive_map: str = "softplus", + positive_bias: Optional[float] = None, ) -> None: super().__init__(params_fn, shape=shape, context_shape=context_shape) + self.clamp_values = clamp_values self.log_scale_min_clip = log_scale_min_clip self.log_scale_max_clip = log_scale_max_clip self.sigmoid_bias = sigmoid_bias + if positive_bias is None: + positive_bias = _DEFAULT_POSITIVE_BIASES[positive_map] + self.positive_bias = positive_bias + if positive_map not in _POSITIVE_MAPS: + raise RuntimeError(f"Unknwon positive map {positive_map}") + self._positive_map = _POSITIVE_MAPS[positive_map] + self._exp_map = self._positive_map is torch.exp and self.positive_bias == 0 + + def positive_map(self, x: torch.Tensor) -> torch.Tensor: + return self._positive_map(x + self.positive_bias) def _forward( self, x: torch.Tensor, params: Optional[Sequence[torch.Tensor]] - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: assert params is not None - mean, log_scale = params - log_scale = clamp_preserve_gradients( - log_scale, self.log_scale_min_clip, self.log_scale_max_clip - ) - scale = torch.exp(log_scale) + mean, unbounded_scale = params + if self.clamp_values: + unbounded_scale = clamp_preserve_gradients( + unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip + ) + scale = self.positive_map(unbounded_scale) + log_scale = scale.log() if not self._exp_map else unbounded_scale y = scale * 
x + mean return y, _sum_rightmost(log_scale, self.domain.event_dim) def _inverse( self, y: torch.Tensor, params: Optional[Sequence[torch.Tensor]] - ) -> Tuple[torch.Tensor, torch.Tensor]: - assert params is not None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + assert ( + params is not None + ), f"{self.__class__.__name__}._inverse got no parameters" - mean, log_scale = params - log_scale = clamp_preserve_gradients( - log_scale, self.log_scale_min_clip, self.log_scale_max_clip - ) - inverse_scale = torch.exp(-log_scale) + mean, unbounded_scale = params + if self.clamp_values: + unbounded_scale = clamp_preserve_gradients( + unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip + ) + + if not self._exp_map: + inverse_scale = self.positive_map(unbounded_scale).reciprocal() + log_scale = -inverse_scale.log() + else: + inverse_scale = torch.exp(-unbounded_scale) + log_scale = unbounded_scale x_new = (y - mean) * inverse_scale return x_new, _sum_rightmost(log_scale, self.domain.event_dim) @@ -65,9 +101,15 @@ def _log_abs_det_jacobian( ) -> torch.Tensor: assert params is not None - _, log_scale = params - log_scale = clamp_preserve_gradients( - log_scale, self.log_scale_min_clip, self.log_scale_max_clip + _, unbounded_scale = params + if self.clamp_values: + unbounded_scale = clamp_preserve_gradients( + unbounded_scale, self.log_scale_min_clip, self.log_scale_max_clip + ) + log_scale = ( + self.positive_map(unbounded_scale).log() + if not self._exp_map + else unbounded_scale ) return _sum_rightmost(log_scale, self.domain.event_dim) diff --git a/flowtorch/distributions/flow.py b/flowtorch/distributions/flow.py index bfb0e97d..6b7d6488 100644 --- a/flowtorch/distributions/flow.py +++ b/flowtorch/distributions/flow.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc -from typing import Any, Dict, Optional, Union, Iterator +from typing import Any, Dict, Iterator, Optional, Union import flowtorch import torch diff --git a/flowtorch/parameters/__init__.py b/flowtorch/parameters/__init__.py index 86f8045c..c7651888 100644 --- a/flowtorch/parameters/__init__.py +++ b/flowtorch/parameters/__init__.py @@ -7,7 +7,17 @@ """ from flowtorch.parameters.base import Parameters +from flowtorch.parameters.conv11 import Conv1x1Params +from flowtorch.parameters.coupling import ConvCoupling +from flowtorch.parameters.coupling import DenseCoupling from flowtorch.parameters.dense_autoregressive import DenseAutoregressive from flowtorch.parameters.tensor import Tensor -__all__ = ["Parameters", "DenseAutoregressive", "Tensor"] +__all__ = [ + "Parameters", + "Conv1x1Params", + "ConvCoupling", + "DenseCoupling", + "DenseAutoregressive", + "Tensor", +] diff --git a/flowtorch/parameters/base.py b/flowtorch/parameters/base.py index 72e4b69f..5d0a74cc 100644 --- a/flowtorch/parameters/base.py +++ b/flowtorch/parameters/base.py @@ -24,15 +24,17 @@ def __init__( def forward( self, - x: Optional[torch.Tensor] = None, + input: torch.Tensor, + inverse: bool, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: # TODO: Caching etc. 
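Note on the positive_map machinery introduced in flowtorch/bijectors/ops/affine.py above: the default softplus bias is chosen so that a zero pre-activation maps to a scale of exactly one, meaning a freshly initialised affine transform starts near the identity. A minimal, illustrative check using the constant from _DEFAULT_POSITIVE_BIASES; it is not part of the patch:

import torch
from torch.nn.functional import softplus

bias = 0.5413248538970947              # log(e - 1), the default bias for "softplus"
scale = softplus(torch.zeros(1) + bias)
print(scale)                           # tensor([1.0000]); "exp" pairs with bias 0.0 for the same reason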
- return self._forward(x, context) + return self._forward(input, inverse=inverse, context=context) def _forward( self, - x: Optional[torch.Tensor] = None, + input: torch.Tensor, + inverse: bool, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: # I raise an exception rather than using @abstractmethod and diff --git a/flowtorch/parameters/conv11.py b/flowtorch/parameters/conv11.py new file mode 100644 index 00000000..c17eac74 --- /dev/null +++ b/flowtorch/parameters/conv11.py @@ -0,0 +1,167 @@ +# Copyright (c) Meta Platforms, Inc +from typing import Optional, Sequence, Union, List + +import torch +from flowtorch.parameters import Parameters +from scipy import linalg as scipy_linalg # type: ignore +from torch import nn +from torch.nn import functional as F + + +def _pixels(tensor: torch.Tensor) -> int: + return int(tensor.shape[-2] * tensor.shape[-1]) + + +def _sum( + tensor: torch.Tensor, + dim: Optional[Union[int, List[int]]] = None, + keepdim: bool = False, +) -> torch.Tensor: + if dim is None: + # sum up all dim + return torch.sum(tensor) + else: + if isinstance(dim, int): + dims = [dim] + else: + dims = dim + dims = sorted(dims) + for d in dims: + tensor = tensor.sum(dim=d, keepdim=True) + if not keepdim: + for i, d in enumerate(dims): + tensor.squeeze_(d - i) + return tensor + + +class Conv1x1Params(Parameters): + BIAS_SOFTPLUS = 0.54 + + def __init__( + self, + LU_decompose: bool, + param_shapes: Sequence[torch.Size], + input_shape: torch.Size, + context_shape: Optional[torch.Size], + zero_init: bool = True, + ) -> None: + self.LU = LU_decompose + self.zero_init = zero_init + super().__init__( + param_shapes=param_shapes, + input_shape=input_shape, + context_shape=context_shape, + ) + self._build() + + def _get_permutation_matrix(self, num_channels: int) -> torch.Tensor: + nz = [ + torch.zeros(num_channels - k).scatter_( + -1, + torch.multinomial( + torch.randn(num_channels - k).softmax(-1), 1, replacement=False + ), + 1.0, + ) + for k in range(num_channels) + ] + np_p0 = torch.zeros((num_channels, num_channels)) + allidx = torch.arange(0, num_channels) + for i, _nz in enumerate(nz): + np_p0[:, i][allidx] = _nz + allidx = allidx[~_nz.bool()] + return np_p0 + + def _build(self) -> None: + self.num_channels = num_channels = self.param_shapes[0][-3] + w_shape = torch.Size([num_channels, num_channels]) + + np_p0 = self._get_permutation_matrix(num_channels) + if self.zero_init: + w_init = None + else: + w = torch.randn(w_shape) + torch.eye(num_channels) * 1e-4 + w_init = torch.linalg.qr(w)[ + 0 + ] # np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32) + if not self.LU: + # Sample a random orthogonal matrix: + self.register_buffer("p", np_p0) + weight = np_p0 @ w_init + self.weight = nn.Parameter(weight.clone().requires_grad_()) + else: + if w_init is not None: + np_p, np_l, np_u = [ + torch.tensor(_v) for _v in scipy_linalg.lu(w_init.numpy()) + ] + else: + np_p = np_p0 + np_l = torch.eye(w_shape[0]) + np_u = torch.eye(w_shape[0]) + self.register_buffer("p", np_p) + if self.zero_init: + np_s = torch.ones(num_channels).expm1().log() - self.BIAS_SOFTPLUS + np_s_sign = torch.ones_like(np_s, dtype=torch.int) + else: + np_s = abs(np_u.diag()).expm1().log() - self.BIAS_SOFTPLUS + np_s_sign = np_u.diag().sign() + self.log_s = nn.Parameter(np_s.requires_grad_()) + self.low = nn.Parameter(np_l.requires_grad_()) + self.up = nn.Parameter(np_u.requires_grad_()) + + self.register_buffer("s_sign", np_s_sign) + l_mask = torch.ones(w_shape).tril(-1) + 
self.register_buffer("l_mask", l_mask) + eye = torch.eye(*w_shape) + self.register_buffer("eye", eye) + + def _forward( + self, + input: torch.Tensor, + inverse: bool, + context: Optional[torch.Tensor] = None, + ) -> Optional[Sequence[torch.Tensor]]: + num_channels = self.num_channels + if not self.LU: + pixels = _pixels(input) + weight = self.weight + dlogdet = torch.slogdet(weight)[1] * pixels + if not inverse: + weight_out = self.weight.view( + num_channels, num_channels, 1, 1 + ) # type: ignore + else: + weight_double = weight.double() + dtype = weight.dtype + assert isinstance(dtype, torch.dtype) + weight_out = ( + torch.inverse(weight_double) + .to(dtype) + .view(num_channels, num_channels, 1, 1) + ) + return weight_out, dlogdet + else: + low = self.low + l_mask = self.l_mask + assert isinstance(l_mask, torch.Tensor) + eye = self.eye + assert isinstance(eye, torch.Tensor) + low_out = low * l_mask + eye + + log_s = self.log_s + s_sign = self.s_sign + s = F.softplus(self.BIAS_SOFTPLUS + log_s) * s_sign + assert isinstance(s, torch.Tensor) + + up = self.up + assert isinstance(up, torch.Tensor) + l_mask_transpose = l_mask.transpose(-1, -2) # type: ignore + up_out = up * l_mask_transpose + s.diag_embed() + dlogdet = _sum(abs(s).clamp_min(1e-5).log()) * _pixels(input) + if not inverse: + w = self.p @ low_out @ up_out + return w.view(num_channels, num_channels, 1, 1), dlogdet + else: + p = self.p + assert isinstance(p, torch.Tensor) + return p, low_out, up_out, dlogdet diff --git a/flowtorch/parameters/coupling.py b/flowtorch/parameters/coupling.py new file mode 100644 index 00000000..f8179c55 --- /dev/null +++ b/flowtorch/parameters/coupling.py @@ -0,0 +1,355 @@ +# Copyright (c) Meta Platforms, Inc + +from typing import Callable, Optional, Sequence + +import torch +import torch.nn as nn +from flowtorch.nn.made import MaskedLinear +from flowtorch.parameters.base import Parameters + + +def _make_mask(shape: torch.Size, mask_type: str) -> torch.Tensor: + if mask_type.startswith("neg_"): + return _make_mask(shape, mask_type[4:]) + elif mask_type == "chessboard": + z = torch.zeros(shape, dtype=torch.bool) + z[:, ::2, ::2] = 1 + z[:, 1::2, 1::2] = 1 + return z + elif mask_type == "quadrant": + z = torch.zeros(shape, dtype=torch.bool) + z[:, shape[1] // 2 :, : shape[2] // 2] = 1 + z[:, : shape[1] // 2, shape[2] // 2 :] = 1 + return z + else: + raise NotImplementedError(shape) + + +class DenseCoupling(Parameters): + autoregressive = False + + def __init__( + self, + param_shapes: Sequence[torch.Size], + input_shape: torch.Size, + context_shape: Optional[torch.Size], + *, + hidden_dims: Sequence[int] = (256, 256), + nonlinearity: Callable[[], nn.Module] = nn.ReLU, + permutation: Optional[torch.LongTensor] = None, + skip_connections: bool = False, + ) -> None: + super().__init__(param_shapes, input_shape, context_shape) + + # Check consistency of input_shape with param_shapes + # We need each param_shapes to match input_shape in + # its leftmost dimensions + for s in param_shapes: + assert (len(s) >= len(input_shape)) and ( + s[: len(input_shape)] == input_shape + ) + + self.hidden_dims = hidden_dims + self.nonlinearity = nonlinearity + self.skip_connections = skip_connections + self._build(input_shape, param_shapes, context_shape, permutation) + + def _build( + self, + input_shape: torch.Size, + param_shapes: Sequence[torch.Size], + context_shape: Optional[torch.Size], + permutation: Optional[torch.LongTensor], + ) -> None: + + # Work out flattened input and output shapes + param_shapes_ = 
list(param_shapes) + input_dims = sum(input_shape) + self.input_dims = input_dims + if input_dims == 0: + input_dims = 1 # scalars represented by torch.Size([]) + if permutation is None: + # permutation will define the split of the input + permutation = torch.LongTensor( + torch.randperm(input_dims, device="cpu").to( + torch.LongTensor((1,)).device + ) + ) + else: + # The permutation is chosen by the user + permutation = torch.LongTensor(permutation) + + self.param_dims = [ + int(max(torch.prod(torch.tensor(s[len(input_shape) :])).item(), 1)) + for s in param_shapes_ + ] + + self.output_multiplier = sum(self.param_dims) + + if input_dims == 1: + raise ValueError( + "Coupling input_dim = 1. Coupling transforms require at least " + "two features." + ) + + self.register_buffer("permutation", permutation) + self.register_buffer("inv_permutation", permutation.argsort()) + + # Create masks + hidden_dims = self.hidden_dims + + # Create masked layers: + # input is [x1 ; 0] + # output is [0 ; mu2], [0 ; sig2] + mask_input = torch.ones(hidden_dims[0], input_dims) + self.x1_dim = x1_dim = input_dims // 2 + mask_input[:, x1_dim:] = 0.0 + mask_input = mask_input[:, self.permutation] + + out_dims = input_dims * self.output_multiplier + mask_output = torch.ones( + self.output_multiplier, + input_dims, + hidden_dims[-1], + dtype=torch.bool, + ) + mask_output[:, :x1_dim] = 0.0 + mask_output = mask_output[:, self.permutation] + mask_output_reg = mask_output[0, :, 0] + mask_output = mask_output.view(-1, hidden_dims[-1]) + + self._bias = nn.Parameter( + torch.zeros(self.output_multiplier, x1_dim, requires_grad=True) + ) + + layers = [ + MaskedLinear( + input_dims, # + context_dims, + hidden_dims[0], + mask_input, + ), + self.nonlinearity(), + ] + for i in range(1, len(hidden_dims)): + layers.extend( + [ + nn.Linear(hidden_dims[i - 1], hidden_dims[i]), + self.nonlinearity(), + ] + ) + layers.append( + MaskedLinear( + hidden_dims[-1], + out_dims, + mask_output, + bias=False, + ) + ) + + if self.skip_connections: + self.skip_layer = MaskedLinear( + input_dims, # + context_dims, + out_dims, + mask_output, + bias=False, + ) + + self.layers = nn.Sequential(*layers) + self.register_buffer("mask_output", mask_output_reg.to(torch.bool)) + self._init_weights() + + def _init_weights(self) -> None: + for layer in self.modules(): + if hasattr(layer, "weight"): + layer.weight.data.normal_(0.0, 1e-3) # type: ignore + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias.data.fill_(0.0) # type: ignore + + @property + def bias(self) -> torch.Tensor: + z = torch.zeros( + self.output_multiplier, + self.input_dims - self.x1_dim, + device=self._bias.device, + dtype=self._bias.dtype, + ) + return torch.cat([z, self._bias], -1).view(-1) + + def _forward( + self, + input: torch.Tensor, + inverse: bool, + context: Optional[torch.Tensor] = None, + ) -> Optional[Sequence[torch.Tensor]]: + + input_masked = input.masked_fill(self.mask_output, 0.0) # type: ignore + if context is not None: + input_aug = torch.cat( + [context.expand((*input.shape[:-1], -1)), input_masked], dim=-1 + ) + else: + input_aug = input_masked + + h = self.layers(input_aug) + self.bias + + # TODO: Get skip_layers working again! 
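Note on DenseCoupling: the transform it parameterises is the standard affine coupling, where the first half of the permuted features passes through untouched and the second half is rescaled and shifted by quantities computed from the first half; this keeps the Jacobian triangular and the inverse closed-form. A schematic sketch with a hypothetical one-layer conditioner standing in for the masked network built above:

import torch
from torch.nn.functional import softplus

d = 8
x = torch.randn(4, d)
conditioner = torch.nn.Linear(d // 2, 2 * (d - d // 2))   # hypothetical stand-in for the masked MLP

x1, x2 = x[..., : d // 2], x[..., d // 2:]
shift, unbounded_scale = conditioner(x1).chunk(2, dim=-1)
scale = softplus(unbounded_scale + 0.5413248538970947)     # default "softplus" positive map
y = torch.cat([x1, scale * x2 + shift], dim=-1)            # y1 == x1 passes through unchanged
log_det = scale.log().sum(-1)                              # log|det J| of the triangular Jacobian

# Closed-form inverse: recompute the conditioner from the untouched half of y.
shift_inv, unbounded_inv = conditioner(y[..., : d // 2]).chunk(2, dim=-1)
x2_rec = (y[..., d // 2:] - shift_inv) / softplus(unbounded_inv + 0.5413248538970947)
assert torch.allclose(x2_rec, x2, atol=1e-6)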
+ if self.skip_connections: + h = h + self.skip_layer(input_aug) + + # Shape the output + h = h.view(*input.shape[:-1], self.output_multiplier, -1) + + result = h.unbind(-2) + result = tuple( + r.masked_fill(~self.mask_output.expand_as(r), 0.0) # type: ignore + for r in result # type: ignore + ) + return result + + +class ConvCoupling(Parameters): + autoregressive = False + _mask_types = [ + "chessboard", + "quadrants", + "inv_chessboard", + "inv_quadrants", + ] + + def __init__( + self, + param_shapes: Sequence[torch.Size], + input_shape: torch.Size, + context_shape: Optional[torch.Size], + *, + cnn_activate_input: bool = True, + cnn_channels: int = 256, + cnn_kernel: Sequence[int] = None, + cnn_padding: Sequence[int] = None, + cnn_stride: Sequence[int] = None, + nonlinearity: Callable[[], nn.Module] = nn.ReLU, + skip_connections: bool = False, + mask_type: str = "chessboard", + ) -> None: + super().__init__(param_shapes, input_shape, context_shape) + + # Check consistency of input_shape with param_shapes + # We need each param_shapes to match input_shape in + # its leftmost dimensions + for s in param_shapes: + assert (len(s) >= len(input_shape)) and ( + s[: len(input_shape)] == input_shape + ) + + if cnn_kernel is None: + cnn_kernel = [3, 1, 3] + if cnn_padding is None: + cnn_padding = [1, 0, 1] + if cnn_stride is None: + cnn_stride = [1, 1, 1] + + self.cnn_channels = cnn_channels + self.cnn_activate_input = cnn_activate_input + self.cnn_kernel = cnn_kernel + self.cnn_padding = cnn_padding + self.cnn_stride = cnn_stride + + self.nonlinearity = nonlinearity + self.skip_connections = skip_connections + self._build(input_shape, param_shapes, context_shape, mask_type) + + def _build( + self, + input_shape: torch.Size, # something like [C, W, H] + param_shapes: Sequence[torch.Size], # [[C, W, H], [C, W, H]] + context_shape: Optional[torch.Size], + mask_type: str, + ) -> None: + + mask = _make_mask(input_shape, mask_type) + self.register_buffer("mask", mask) + self.output_multiplier = len(param_shapes) + + out_channels, width, height = input_shape + + layers = [] + if self.cnn_activate_input: + layers.append(self.nonlinearity()) + layers.append( + nn.LazyConv2d( + out_channels=self.cnn_channels, + kernel_size=self.cnn_kernel[0], + padding=self.cnn_padding[0], + stride=self.cnn_stride[0], + ) + ) + layers.append(self.nonlinearity()) + layers.append( + nn.Conv2d( + in_channels=self.cnn_channels, + out_channels=self.cnn_channels, + kernel_size=self.cnn_kernel[1], + padding=self.cnn_padding[1], + stride=self.cnn_stride[1], + ) + ) + layers.append(self.nonlinearity()) + layers.append( + nn.Conv2d( + in_channels=self.cnn_channels, + out_channels=out_channels * self.output_multiplier, + kernel_size=self.cnn_kernel[2], + padding=self.cnn_padding[2], + stride=self.cnn_stride[2], + ) + ) + + self.layers = nn.Sequential(*layers) + self._init_weights() + + def _init_weights(self) -> None: + for layer in self.modules(): + if hasattr(layer, "weight"): + layer.weight.data.normal_(0.0, 1e-3) # type: ignore + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias.data.fill_(0.0) # type: ignore + + def _forward( + self, + input: torch.Tensor, + inverse: bool, + context: Optional[torch.Tensor] = None, + ) -> Optional[Sequence[torch.Tensor]]: + + unsqueeze = False + if input.ndimension() == 3: + # mostly for initialization + unsqueeze = True + input = input.unsqueeze(0) + + input_masked = input.masked_fill(self.mask, 0.0) # type: ignore + if context is not None: + context_shape = [shape for shape in 
input_masked.shape] + context_shape[-3] = context.shape[-3] + input_aug = torch.cat( + [context.expand(*context_shape), input_masked], dim=-1 + ) + else: + input_aug = input_masked + + h = self.layers(input_aug) + + if self.skip_connections: + h = h + input_masked + + # Shape the output + + if unsqueeze: + h = h.squeeze(0) + result = h.chunk(2, -3) + + result = tuple( + r.masked_fill(~self.mask.expand_as(r), 0.0) for r in result # type: ignore + ) + + return result diff --git a/flowtorch/parameters/dense_autoregressive.py b/flowtorch/parameters/dense_autoregressive.py index 8110e5a6..70a432b9 100644 --- a/flowtorch/parameters/dense_autoregressive.py +++ b/flowtorch/parameters/dense_autoregressive.py @@ -45,6 +45,7 @@ def _build( ) -> None: # Work out flattened input and output shapes param_shapes_ = list(param_shapes) + # Why not just (sum(input_shape))? input_dims = int(torch.sum(torch.tensor(input_shape)).int().item()) if input_dims == 0: input_dims = 1 # scalars represented by torch.Size([]) @@ -60,6 +61,7 @@ def _build( # The permutation is chosen by the user permutation = torch.LongTensor(permutation) + # why not math.pod(s[len(input_shape):]), where math.prod([])=1? self.param_dims = [ int(max(torch.prod(torch.tensor(s[len(input_shape) :])).item(), 1)) for s in param_shapes_ @@ -141,33 +143,35 @@ def _build( ) ) + # Why not using regular sequential? self.layers = nn.ModuleList(layers) def _forward( self, - x: Optional[torch.Tensor] = None, + input: torch.Tensor, + inverse: bool, context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: - assert x is not None - # Flatten x - batch_shape = x.shape[: len(x.shape) - len(self.input_shape)] + # Flatten input + batch_shape = input.shape[: len(input.shape) - len(self.input_shape)] if len(batch_shape) > 0: - x = x.reshape(batch_shape + (-1,)) + input = input.reshape(batch_shape + (-1,)) if context is not None: # TODO: Fix the following! - h = torch.cat([context.expand((x.shape[0], -1)), x], dim=-1) + h = torch.cat([context.expand((input.shape[0], -1)), input], dim=-1) else: - h = x + h = input + # Why not using regular sequential? for idx in range(len(self.layers) // 2): h = self.layers[2 * idx + 1](self.layers[2 * idx](h)) h = self.layers[-1](h) # TODO: Get skip_layers working again! 
# if self.skip_layer is not None: - # h = h + self.skip_layer(x) + # h = h + self.skip_layer(input) # Shape the output # h ~ (batch_dims * input_dims, total_params_per_dim) diff --git a/flowtorch/parameters/tensor.py b/flowtorch/parameters/tensor.py index 3de8680a..213188cc 100644 --- a/flowtorch/parameters/tensor.py +++ b/flowtorch/parameters/tensor.py @@ -22,6 +22,9 @@ def __init__( ) def _forward( - self, x: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None + self, + input: torch.Tensor, + inverse: bool, + context: Optional[torch.Tensor] = None, ) -> Optional[Sequence[torch.Tensor]]: return list(self.params) diff --git a/tests/test_bijectivetensor.py b/tests/test_bijectivetensor.py index 72bbdf70..fa340f57 100644 --- a/tests/test_bijectivetensor.py +++ b/tests/test_bijectivetensor.py @@ -15,7 +15,6 @@ def get_net() -> AffineAutoregressive: [ AffineAutoregressive(params.DenseAutoregressive()), AffineAutoregressive(params.DenseAutoregressive()), - AffineAutoregressive(params.DenseAutoregressive()), ] ) ar = ar( diff --git a/tests/test_bijector.py b/tests/test_bijector.py index e9344ef6..eec9f949 100644 --- a/tests/test_bijector.py +++ b/tests/test_bijector.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc +import math import warnings import flowtorch.bijectors as bijectors @@ -21,11 +22,13 @@ def test_bijector_constructor(): @pytest.fixture(params=[bij_name for _, bij_name in bijectors.standard_bijectors]) def flow(request): + torch.set_default_dtype(torch.double) bij = request.param event_dim = max(bij.domain.event_dim, 1) event_shape = event_dim * [3] base_dist = dist.Independent( - dist.Normal(torch.zeros(event_shape), torch.ones(event_shape)), event_dim + dist.Normal(torch.zeros(event_shape), torch.ones(event_shape)), + event_dim, ) flow = Flow(base_dist, bij) @@ -41,10 +44,12 @@ def test_jacobian(flow, epsilon=1e-2): x = torch.randn(*flow.event_shape) x = torch.distributions.transform_to(bij.domain)(x) y = bij.forward(x) - if bij.domain.event_dim == 1: - analytic_ldt = bij.log_abs_det_jacobian(x, y).data + if bij.domain.event_dim == 0: + analytic_ldt = bij.log_abs_det_jacobian(x, y).data.sum(-1) else: - analytic_ldt = bij.log_abs_det_jacobian(x, y).sum(-1).data + analytic_ldt = bij.log_abs_det_jacobian(x, y).data + for _ in range(bij.domain.event_dim - 1): + analytic_ldt = analytic_ldt.sum(-1) # Calculate numerical Jacobian # TODO: Better way to get all indices of array/tensor? 
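Note on test_jacobian: the test compares the bijector's analytic log_abs_det_jacobian against a Jacobian assembled by numerical differentiation, and the hunk that follows flattens that Jacobian into a square matrix before taking its determinant. The same check in isolation, as a self-contained sketch independent of the test harness:

import math
import torch


def f(x):
    return torch.tanh(x) * 2.0 + 1.0                    # any smooth elementwise bijection


x = torch.randn(3, dtype=torch.double)
jac = torch.autograd.functional.jacobian(f, x)          # square 3 x 3 Jacobian
numeric_ldt = jac.det().abs().log()
analytic_ldt = (math.log(2.0) + torch.log1p(-torch.tanh(x) ** 2)).sum()
assert torch.allclose(numeric_ldt, analytic_ldt)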
@@ -86,7 +91,8 @@ def test_jacobian(flow, epsilon=1e-2): if hasattr(params, "permutation"): numeric_ldt = torch.sum(torch.log(torch.diag(jacobian))) else: - numeric_ldt = torch.log(torch.abs(jacobian.det())) + jacobian = jacobian.view(int(math.sqrt(jacobian.numel())), -1) + numeric_ldt = torch.log(torch.abs(jacobian.det())).sum() ldt_discrepancy = (analytic_ldt - numeric_ldt).abs() assert ldt_discrepancy < epsilon @@ -109,6 +115,7 @@ def test_inverse(flow, epsilon=1e-5): # Test g^{-1}(g(x)) = x x_true = base_dist.sample(torch.Size([10])) + assert x_true.dtype is torch.double x_true = torch.distributions.transform_to(bij.domain)(x_true) y = bij.forward(x_true) diff --git a/tests/test_distribution.py b/tests/test_distribution.py index db7c9095..25c065a5 100644 --- a/tests/test_distribution.py +++ b/tests/test_distribution.py @@ -15,7 +15,8 @@ def test_tdist_standalone(): def make_tdist(): # train a flow here base_dist = torch.distributions.Independent( - torch.distributions.Normal(torch.zeros(input_dim), torch.ones(input_dim)), 1 + torch.distributions.Normal(torch.zeros(input_dim), torch.ones(input_dim)), + 1, ) bijector = bijs.AffineAutoregressive() tdist = dist.Flow(base_dist, bijector) @@ -37,9 +38,9 @@ def test_neals_funnel_vi(): flow = dist.Flow(base_dist, bijector) bijector = flow.bijector - opt = torch.optim.Adam(flow.parameters(), lr=2e-3) + opt = torch.optim.Adam(flow.parameters(), lr=1e-2) num_elbo_mc_samples = 200 - for _ in range(100): + for _ in range(500): z0 = flow.base_dist.rsample(sample_shape=(num_elbo_mc_samples,)) zk = bijector.forward(z0) ldj = zk._log_detJ
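Note on the new exports as a whole: the bijectors registered in flowtorch/bijectors/__init__.py make a Glow-style image block expressible with Compose, alternating an invertible 1x1 convolution with a convolutional coupling layer. An illustrative round-trip sketch modelled on the docstring examples above; it is not part of the patch and is untested here:

import torch
import flowtorch.bijectors as bijs

block = bijs.Compose(
    [
        bijs.Conv1x1Bijector(LU_decompose=True),
        bijs.ConvCouplingBijector(),
    ]
)
bij = block(shape=torch.Size([3, 16, 16]))

x = torch.randn(4, 3, 16, 16)
y = bij.forward(x)
x_bis = bij.inverse(y.detach_from_flow())
torch.testing.assert_allclose(x, x_bis)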