From fa80378817584b3227c0f226dec74da39388124c Mon Sep 17 00:00:00 2001 From: Alexander Nikitin <1243786+AlexanderVNikitin@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:03:09 +0200 Subject: [PATCH] add architecture tests --- tests/test_monitors.py | 1 - tests/test_zoo.py | 8 ++++++-- tsgm/models/architectures/zoo.py | 7 +------ 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/test_monitors.py b/tests/test_monitors.py index b38b7ac..090fb37 100644 --- a/tests/test_monitors.py +++ b/tests/test_monitors.py @@ -45,7 +45,6 @@ def test_vaemonitor(save, monkeypatch): vae_monitor = tsgm.models.monitors.VAEMonitor( num_samples=3, latent_dim=123, save=save) vae_monitor.model = MagicMock() # mock the model - import pdb; pdb.set_trace() vae_monitor.model.generate = lambda x: (x[:, 0][:, None], None) vae_monitor.on_epoch_end(epoch=2) diff --git a/tests/test_zoo.py b/tests/test_zoo.py index 81a3914..1f8e543 100644 --- a/tests/test_zoo.py +++ b/tests/test_zoo.py @@ -66,7 +66,11 @@ def test_zoo_clf(model_type_name): assert arch.model == arch_dict["model"] -def test_basic_rec(): +@pytest.mark.parametrize("network_type", [ + "gru", + "lstm", +]) +def test_basic_rec(network_type): seq_len = 10 feat_dim = 2 output_dim = 1 @@ -75,7 +79,7 @@ def test_basic_rec(): hidden_dim=2, output_dim=output_dim, n_layers=1, - network_type="gru") + network_type=network_type) model = arch.build() assert model is not None diff --git a/tsgm/models/architectures/zoo.py b/tsgm/models/architectures/zoo.py index c02baad..821e709 100644 --- a/tsgm/models/architectures/zoo.py +++ b/tsgm/models/architectures/zoo.py @@ -477,7 +477,7 @@ def __init__( self.n_layers = n_layers self.network_type = network_type.lower() - assert self.network_type in ["gru", "lstm", "lstmLN"] + assert self.network_type in ["gru", "lstm"] self._name = name @@ -493,11 +493,6 @@ def _rnn_cell(self) -> keras.layers.Layer: # LSTM elif self.network_type == "lstm": cell = keras.layers.LSTMCell(self.hidden_dim, activation="tanh") - # LSTM Layer Normalization - elif self.network_type == "lstmLN": - cell = keras.layers.LayerNormLSTMCell( - num_units=self.hidden_dim, activation="tanh" - ) return cell def _make_network(self, model: keras.models.Model, activation: str, return_sequences: bool) -> keras.models.Model: