feat: vanilla ae with independent channels
Signed-off-by: Avik Basu <ab93@users.noreply.github.com>
ab93 committed Jun 12, 2024
1 parent f3252a2 commit 4b84bf2
Showing 7 changed files with 436 additions and 509 deletions.
684 changes: 182 additions & 502 deletions examples/fc.ipynb

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion numalogic/models/autoencoder/variants/__init__.py
@@ -1,4 +1,9 @@
from numalogic.models.autoencoder.variants.vanilla import VanillaAE, SparseVanillaAE, MultichannelAE
from numalogic.models.autoencoder.variants.vanilla import (
VanillaAE,
SparseVanillaAE,
MultichannelAE,
)
from numalogic.models.autoencoder.variants.icvanilla import VanillaICAE
from numalogic.models.autoencoder.variants.conv import Conv1dAE, SparseConv1dAE
from numalogic.models.autoencoder.variants.lstm import LSTMAE, SparseLSTMAE
from numalogic.models.autoencoder.variants.transformer import TransformerAE, SparseTransformerAE
@@ -16,4 +21,5 @@
"TransformerAE",
"SparseTransformerAE",
"BaseAE",
"VanillaICAE",
]
196 changes: 196 additions & 0 deletions numalogic/models/autoencoder/variants/icvanilla.py
@@ -0,0 +1,196 @@
from collections.abc import Sequence

import torch
from torch import nn, Tensor

from numalogic.models.autoencoder.base import BaseAE
from numalogic.tools.exceptions import LayerSizeMismatchError
from numalogic.tools.layer import MultiChannelLinear


class _VanillaEncoder(nn.Module):
r"""Encoder module for the VanillaAE.
Args:
----
seq_len: sequence length / window length
n_features: num of features
layersizes: encoder layer size
dropout_p: the dropout value
"""

def __init__(
self,
seq_len: int,
n_features: int,
layersizes: Sequence[int],
dropout_p: float,
batchnorm: bool,
):
super().__init__()
self.seq_len = seq_len
self.n_features = n_features
self.dropout_p = dropout_p
self.bnorm = batchnorm

layers = self._construct_layers(layersizes)
self.encoder = nn.Sequential(*layers)

def _construct_layers(self, layersizes: Sequence[int]) -> nn.ModuleList:
r"""Utility function to generate a simple feedforward network layer.
Args:
----
layersizes: layer size
Returns
-------
A simple feedforward network layer of type nn.ModuleList
"""
layers = nn.ModuleList()
start_layersize = self.seq_len

for lsize in layersizes[:-1]:
_l = [MultiChannelLinear(start_layersize, lsize, self.n_features)]
if self.bnorm:
_l.append(nn.BatchNorm1d(self.n_features))

layers.extend([*_l, nn.Tanh(), nn.Dropout(p=self.dropout_p)])
start_layersize = lsize

_l = [MultiChannelLinear(start_layersize, layersizes[-1], self.n_features)]
if self.bnorm:
_l.append(nn.BatchNorm1d(self.n_features))

layers.extend([*_l, nn.Tanh(), nn.Dropout(p=self.dropout_p)])

return layers

def forward(self, x: Tensor) -> Tensor:
return self.encoder(x)


class _Decoder(nn.Module):
r"""Decoder module for the autoencoder module.
Args:
----
seq_len: sequence length / window length
n_features: num of features
layersizes: decoder layer size
dropout_p: the dropout value
"""

def __init__(
self,
seq_len: int,
n_features: int,
layersizes: Sequence[int],
dropout_p: float,
batchnorm: bool,
):
super().__init__()
self.seq_len = seq_len
self.n_features = n_features
self.dropout_p = dropout_p
self.bnorm = batchnorm

layers = self._construct_layers(layersizes)
self.decoder = nn.Sequential(*layers)

def forward(self, x: Tensor) -> Tensor:
return self.decoder(x)

def _construct_layers(self, layersizes: Sequence[int]) -> nn.ModuleList:
r"""Utility function to generate a simple feedforward network layer.
Args:
----
layersizes: layer size
Returns
-------
A simple feedforward network layer
"""
layers = nn.ModuleList()

for idx, _ in enumerate(layersizes[:-1]):
_l = [MultiChannelLinear(layersizes[idx], layersizes[idx + 1], self.n_features)]
if self.bnorm:
_l.append(nn.BatchNorm1d(self.n_features))

layers.extend([*_l, nn.Tanh(), nn.Dropout(p=self.dropout_p)])

layers.append(MultiChannelLinear(layersizes[-1], self.seq_len, self.n_features))
return layers


class VanillaICAE(BaseAE):
r"""Multichannel Vanilla Autoencoder model based on the vanilla encoder and decoder.
Each channel is an isolated neural network.
Args:
----
seq_len: sequence length / window length
        n_features: number of features (channels); each channel is modeled by its own network
encoder_layersizes: encoder layer size (default = Sequence[int] = (16, 8))
decoder_layersizes: decoder layer size (default = Sequence[int] = (8, 16))
dropout_p: the dropout value (default=0.25)
batchnorm: Flag to enable batch normalization (default=False)
**kwargs: BaseAE kwargs
"""

def __init__(
self,
seq_len: int,
n_features: int = 1,
encoder_layersizes: Sequence[int] = (16, 8),
decoder_layersizes: Sequence[int] = (8, 16),
dropout_p: float = 0.25,
batchnorm: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.seq_len = seq_len
self.dropout_prob = dropout_p
self.n_features = n_features

if encoder_layersizes[-1] != decoder_layersizes[0]:
raise LayerSizeMismatchError(

f"Last layersize of encoder: {encoder_layersizes[-1]} "
f"does not match first layersize of decoder: {decoder_layersizes[0]}"
)

self.encoder = _VanillaEncoder(
seq_len=seq_len,
n_features=n_features,
layersizes=encoder_layersizes,
dropout_p=dropout_p,
batchnorm=batchnorm,
)
self.decoder = _Decoder(
seq_len=seq_len,
n_features=n_features,
layersizes=decoder_layersizes,
dropout_p=dropout_p,
batchnorm=batchnorm,
)

self.encoder.apply(self.init_weights)
self.decoder.apply(self.init_weights)

@staticmethod
def init_weights(m: nn.Module) -> None:
"""Initialize the parameters in the model."""
if isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight)

    def forward(self, batch: Tensor) -> tuple[Tensor, Tensor]:
        # (batch, seq_len, n_features) -> (batch, n_features, seq_len)
        batch = torch.swapdims(batch, 1, 2)
        encoded = self.encoder(batch)
        decoded = self.decoder(encoded)
        # swap back to (batch, seq_len, n_features) for the reconstruction
        return encoded, torch.swapdims(decoded, 1, 2)

def predict_step(self, batch: Tensor, batch_idx: int, dataloader_idx: int = 0):
"""Returns reconstruction for streaming input."""
recon = self.reconstruction(batch)
return self.criterion(batch, recon, reduction="none")
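
For orientation, a minimal usage sketch of the new model (illustrative only, not part of the commit; the example sizes are assumptions, and the latent width follows the default encoder_layersizes=(16, 8)):

import torch
from numalogic.models.autoencoder.variants import VanillaICAE

model = VanillaICAE(seq_len=12, n_features=2)
x = torch.randn(32, 12, 2)  # (batch, seq_len, n_features)
encoded, decoded = model(x)
print(encoded.shape)  # torch.Size([32, 2, 8]): one latent per channel
print(decoded.shape)  # torch.Size([32, 12, 2]): reconstruction in input layout

Mismatched boundary sizes, e.g. encoder_layersizes=(16, 8) with decoder_layersizes=(4, 16), raise LayerSizeMismatchError per the constructor check above.
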
4 changes: 0 additions & 4 deletions numalogic/models/autoencoder/variants/vanilla.py
Original file line number Diff line number Diff line change
@@ -19,8 +19,6 @@
from numalogic.models.autoencoder.base import BaseAE
from numalogic.tools.exceptions import LayerSizeMismatchError

EMPTY_TENSOR = torch.empty(0)


class _VanillaEncoder(nn.Module):
r"""Encoder module for the VanillaAE.
@@ -221,8 +219,6 @@ class MultichannelAE(BaseAE):
decoder_layersizes: decoder layer size (default = Sequence[int] = (8, 16))
dropout_p: the dropout value (default=0.25)
batchnorm: Flag to enable batch normalization (default=False)
encoderinfo: Flag to enable returning encoder information in the "forward" step
(default=False)
**kwargs: BaseAE kwargs
"""

2 changes: 1 addition & 1 deletion numalogic/models/vae/variants/conv.py
@@ -7,7 +7,7 @@
from torch.distributions import MultivariateNormal, kl_divergence

from numalogic.models.vae.base import BaseVAE
from numalogic.models.vae.layer import CausalConvBlock
from numalogic.tools.layer import CausalConvBlock
from numalogic.tools.exceptions import ModelInitializationError

_DEFAULT_KERNEL_SIZE: Final[int] = 3
33 changes: 33 additions & 0 deletions numalogic/models/vae/layer.py → numalogic/tools/layer.py
@@ -1,3 +1,5 @@
import math
import torch
from torch import nn, Tensor
import torch.nn.functional as F

@@ -60,3 +62,34 @@

def forward(self, input_: Tensor) -> Tensor:
return self.relu(self.bnorm(self.conv(input_)))


class MultiChannelLinear(nn.Module):
    r"""Linear layer that applies an independent affine transform per channel.
    Args:
    ----
        in_features: size of each input sample
        out_features: size of each output sample
        n_channels: number of independent channels (one weight matrix per channel)
        device: optional device for the parameters
        dtype: optional dtype for the parameters
    """
def __init__(
self, in_features: int, out_features: int, n_channels: int, device=None, dtype=None
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.n_channels = n_channels

self.weight = nn.Parameter(
torch.empty((n_channels, in_features, out_features), **factory_kwargs)
)
self.bias = nn.Parameter(torch.empty((n_channels, 1, out_features), **factory_kwargs))
self.reset_parameters()

def reset_parameters(self) -> None:
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: Tensor) -> Tensor:
        # (batch, n_channels, in_features) -> (n_channels, batch, in_features)
        x = torch.swapdims(x, 0, 1)
        # one (in_features, out_features) matmul per channel, plus a per-channel bias
        output = torch.bmm(x, self.weight) + self.bias
        # back to (batch, n_channels, out_features)
        return torch.swapdims(output, 0, 1)

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, "
            f"out_features={self.out_features}, n_channels={self.n_channels}"
        )

18 changes: 17 additions & 1 deletion tests/models/autoencoder/variants/test_vanilla.py
@@ -10,7 +10,12 @@
from numalogic._constants import TESTS_DIR
from numalogic.tools.data import StreamingDataset, TimeseriesDataModule
from numalogic.tools.trainer import TimeseriesTrainer
from numalogic.models.autoencoder.variants.vanilla import VanillaAE, SparseVanillaAE, MultichannelAE
from numalogic.models.autoencoder.variants.vanilla import (
VanillaAE,
SparseVanillaAE,
MultichannelAE,
)
from numalogic.models.autoencoder.variants.icvanilla import VanillaICAE
from numalogic.tools.exceptions import LayerSizeMismatchError

ROOT_DIR = os.path.join(TESTS_DIR, "resources", "data")
@@ -70,6 +75,17 @@ def test_multichannel(self):
test_reconerr = stream_trainer.predict(model, dataloaders=streamloader, unbatch=False)
self.assertTupleEqual((229, SEQ_LEN, self.X_train.shape[1]), test_reconerr.size())

def test_vanilla_ic(self):
model = VanillaICAE(seq_len=SEQ_LEN, n_features=2)
datamodule = TimeseriesDataModule(SEQ_LEN, self.X_train, batch_size=BATCH_SIZE)
trainer = TimeseriesTrainer(fast_dev_run=True, deterministic=True)
trainer.fit(model, datamodule=datamodule)

streamloader = DataLoader(StreamingDataset(self.X_val, SEQ_LEN), batch_size=BATCH_SIZE)
stream_trainer = TimeseriesTrainer()
test_reconerr = stream_trainer.predict(model, dataloaders=streamloader, unbatch=False)
self.assertTupleEqual((229, SEQ_LEN, self.X_train.shape[1]), test_reconerr.size())

def test_native_train(self):
model = VanillaAE(
SEQ_LEN,
