Skip to content

Commit

Permalink
add monitor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderVNikitin committed Nov 19, 2023
1 parent 63cdfe9 commit 641fd75
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
28 changes: 24 additions & 4 deletions tests/test_monitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt

import tsgm

Expand All @@ -20,15 +21,34 @@ def _get_labels(num_samples, output_dim):
return labels


def test_ganmonitor():
@pytest.mark.parametrize("save", [
True, False
])
def test_ganmonitor(save, monkeypatch):
monkeypatch.setattr(plt, 'show', lambda: None)
n_samples, n_classes = 3, 2
labels = _get_labels(n_samples, n_classes)
gan_monitor = tsgm.models.monitors.GANMonitor(
num_samples=3, latent_dim=123, labels=labels, mode="clf", save=True)
num_samples=3, latent_dim=123, labels=labels, mode="clf", save=save)
gan_monitor.model = MagicMock() # mock the model
gan_monitor.model.generator.side_effect = lambda x: x[:, None]

gan_monitor.on_epoch_end(epoch=100)
gan_monitor.on_epoch_end(epoch=2)


@pytest.mark.parametrize("save", [
True, False
])
def test_vaemonitor(save, monkeypatch):
monkeypatch.setattr(plt, 'show', lambda: None)
n_samples, n_classes = 3, 2
vae_monitor = tsgm.models.monitors.VAEMonitor(
num_samples=3, latent_dim=123, save=save)
vae_monitor.model = MagicMock() # mock the model
import pdb; pdb.set_trace()
vae_monitor.model.generate = lambda x: (x[:, 0][:, None], None)

vae_monitor.on_epoch_end(epoch=2)


def test_exceptions():
Expand All @@ -43,4 +63,4 @@ def test_exceptions():
gan_monitor = tsgm.models.monitors.GANMonitor(
num_samples=3, latent_dim=123, labels=labels, mode="clf", save=True)
gan_monitor._mode = "abcde123"
gan_monitor.on_epoch_end(epoch=100)
gan_monitor.on_epoch_end(epoch=2)
5 changes: 4 additions & 1 deletion tsgm/models/monitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@ def on_epoch_end(self, epoch, logs=None) -> None:
generated_images, _ = self.model.generate(labels)

for i in range(self._output_dim * self._num_samples):
sns.lineplot(x=range(0, generated_images[i].shape[0]), y=tf.squeeze(generated_images[i]))
sns.lineplot(
x=range(0, generated_images[i].shape[0]),
y=tf.squeeze(generated_images[i]).numpy()
)
if self._save:
plt.savefig(os.path.join(self._save_path, "epoch_{}_sample_{}".format(epoch, i)))
else:
Expand Down

0 comments on commit 641fd75

Please sign in to comment.