diff --git a/tests/test_monitors.py b/tests/test_monitors.py index 2a4845f..b38b7ac 100644 --- a/tests/test_monitors.py +++ b/tests/test_monitors.py @@ -5,6 +5,7 @@ import tensorflow as tf from tensorflow import keras +import matplotlib.pyplot as plt import tsgm @@ -20,15 +21,34 @@ def _get_labels(num_samples, output_dim): return labels -def test_ganmonitor(): +@pytest.mark.parametrize("save", [ + True, False +]) +def test_ganmonitor(save, monkeypatch): + monkeypatch.setattr(plt, 'show', lambda: None) n_samples, n_classes = 3, 2 labels = _get_labels(n_samples, n_classes) gan_monitor = tsgm.models.monitors.GANMonitor( - num_samples=3, latent_dim=123, labels=labels, mode="clf", save=True) + num_samples=3, latent_dim=123, labels=labels, mode="clf", save=save) gan_monitor.model = MagicMock() # mock the model gan_monitor.model.generator.side_effect = lambda x: x[:, None] - gan_monitor.on_epoch_end(epoch=100) + gan_monitor.on_epoch_end(epoch=2) + + +@pytest.mark.parametrize("save", [ + True, False +]) +def test_vaemonitor(save, monkeypatch): + monkeypatch.setattr(plt, 'show', lambda: None) + n_samples, n_classes = 3, 2 + vae_monitor = tsgm.models.monitors.VAEMonitor( + num_samples=3, latent_dim=123, save=save) + vae_monitor.model = MagicMock() # mock the model + import pdb; pdb.set_trace() + vae_monitor.model.generate = lambda x: (x[:, 0][:, None], None) + + vae_monitor.on_epoch_end(epoch=2) def test_exceptions(): @@ -43,4 +63,4 @@ def test_exceptions(): gan_monitor = tsgm.models.monitors.GANMonitor( num_samples=3, latent_dim=123, labels=labels, mode="clf", save=True) gan_monitor._mode = "abcde123" - gan_monitor.on_epoch_end(epoch=100) + gan_monitor.on_epoch_end(epoch=2) diff --git a/tsgm/models/monitors.py b/tsgm/models/monitors.py index d160bfa..0d43a2a 100644 --- a/tsgm/models/monitors.py +++ b/tsgm/models/monitors.py @@ -93,7 +93,10 @@ def on_epoch_end(self, epoch, logs=None) -> None: generated_images, _ = self.model.generate(labels) for i in range(self._output_dim * self._num_samples): - sns.lineplot(x=range(0, generated_images[i].shape[0]), y=tf.squeeze(generated_images[i])) + sns.lineplot( + x=range(0, generated_images[i].shape[0]), + y=tf.squeeze(generated_images[i]).numpy() + ) if self._save: plt.savefig(os.path.join(self._save_path, "epoch_{}_sample_{}".format(epoch, i))) else: