Skip to content

Commit

Permalink
Tests/time gan (#24)
Browse files Browse the repository at this point in the history
* optuna

* logging

* tests

* tests

* test training
  • Loading branch information
letiziaia authored Nov 16, 2023
1 parent e2884c2 commit dc77d7c
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 13 deletions.
163 changes: 160 additions & 3 deletions tests/test_timegan.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from unittest.mock import Mock
import tsgm

import tensorflow as tf
Expand All @@ -14,7 +15,12 @@ def test_timegan():

dataset = _gen_dataset(batch_size, seq_len, feature_dim)
timegan = tsgm.models.timeGAN.TimeGAN(
seq_len=seq_len, module="gru", hidden_dim=latent_dim, n_features=feature_dim, n_layers=3, batch_size=batch_size
seq_len=seq_len,
module="gru",
hidden_dim=latent_dim,
n_features=feature_dim,
n_layers=3,
batch_size=batch_size,
)
timegan.compile()
timegan.fit(dataset, epochs=1)
Expand All @@ -34,7 +40,12 @@ def test_timegan_on_dataset():

dataset = _gen_tf_dataset(batch_size, seq_len, feature_dim) # tf.data.Dataset
timegan = tsgm.models.timeGAN.TimeGAN(
seq_len=seq_len, module="gru", hidden_dim=latent_dim, n_features=feature_dim, n_layers=3, batch_size=batch_size
seq_len=seq_len,
module="gru",
hidden_dim=latent_dim,
n_features=feature_dim,
n_layers=3,
batch_size=batch_size,
)
timegan.compile()
timegan.fit(dataset, epochs=1)
Expand Down Expand Up @@ -91,7 +102,6 @@ def _gen_tf_dataset(no, seq_len, dim):


def _check_internals(timegan):

# Check internal nets
assert timegan.generator is not None
assert timegan.discriminator is not None
Expand All @@ -111,3 +121,150 @@ def _check_internals(timegan):
assert timegan.embedder_opt is not None
assert timegan.autoencoder_opt is not None
assert timegan.adversarialsup_opt is not None


def test_losstracker():
    """Check the types returned by ``LossTracker.to_numpy`` and ``LossTracker.labels``."""
    tracker = tsgm.models.timeGAN.LossTracker()
    # Store one entry so the tracker is not empty when it is queried.
    tracker["foo"] = 0.1

    as_array = tracker.to_numpy()
    assert isinstance(as_array, np.ndarray)

    label_names = tracker.labels()
    assert isinstance(label_names, list)


@pytest.fixture
def mocked_gradienttape():
    """Provide a stand-in gradient tape whose ``gradient`` returns three unit gradients.

    The double is built with ``unittest.mock.Mock`` directly, so the fixture
    does not request pytest-mock's ``mocker`` fixture and works without that
    plugin being installed.

    Returns:
        Mock: an object whose ``gradient(...)`` call returns ``[1.0, 1.0, 1.0]``.
    """
    tape = Mock()
    tape.gradient.return_value = [1.0, 1.0, 1.0]
    return tape


def test_train_timegan(mocked_gradienttape):
    """Run each private TimeGAN training step once and check the losses returned.

    The model is compiled and fitted for one epoch first, then every
    ``_train_*`` step is called on a batch drawn from ``_get_data_batch``.

    NOTE(review): ``mocked_gradienttape`` is requested but never patched into
    the model, so these steps run with real gradients -- confirm whether the
    mock was meant to be wired in.
    """
    latent_dim = 24
    feature_dim = 6
    seq_len = 24
    batch_size = 2

    dataset = _gen_dataset(batch_size, seq_len, feature_dim)
    timegan = tsgm.models.timeGAN.TimeGAN(
        seq_len=seq_len,
        module="gru",
        hidden_dim=latent_dim,
        n_features=feature_dim,
        n_layers=3,
        batch_size=batch_size,
    )
    timegan.compile()
    timegan.fit(dataset, epochs=1)
    batches = timegan._get_data_batch(dataset, n_windows=len(dataset))

    # A bare ``assert result`` is always true for the steps that return a
    # tuple of losses and would fail spuriously on a loss of exactly 0.0.
    # Check the loss dtype instead, as the per-step tests in this file do.
    float_dtypes = [tf.float32, tf.float64]

    autoencoder_loss = timegan._train_autoencoder(
        next(batches), timegan.autoencoder_opt
    )
    assert autoencoder_loss.dtype in float_dtypes

    # Assumes the supervisor step returns a single loss tensor like the
    # autoencoder step -- TODO confirm against the model implementation.
    supervisor_loss = timegan._train_supervisor(
        next(batches), timegan.adversarialsup_opt
    )
    assert supervisor_loss.dtype in float_dtypes

    generator_losses = timegan._train_generator(
        next(batches), next(timegan.get_noise_batch()), timegan.generator_opt
    )
    # The generator step returns five losses (see test_timegan_train_generator).
    assert len(generator_losses) == 5
    for generator_loss in generator_losses:
        assert generator_loss.dtype in float_dtypes

    # The embedder step returns a pair; the second element is the loss checked
    # by test_timegan_train_embedder.
    _, embedder_loss = timegan._train_embedder(next(batches), timegan.embedder_opt)
    assert embedder_loss.dtype in float_dtypes

    discriminator_loss = timegan._train_discriminator(
        next(batches), next(timegan.get_noise_batch()), timegan.discriminator_opt
    )
    assert discriminator_loss.dtype in float_dtypes


@pytest.fixture
def mock_optimizer():
    """Yield an Adam optimizer created with ``learning_rate=0.001``."""
    adam = tf.keras.optimizers.Adam(learning_rate=0.001)
    yield adam


@pytest.fixture
def mocked_data():
    """Yield the dataset produced by ``_gen_tf_dataset(16, 24, 6)``."""
    n_samples, window_len, n_features = 16, 24, 6
    generated = _gen_tf_dataset(n_samples, window_len, n_features)
    yield generated


@pytest.fixture
def mocked_timegan(mocked_data):
    """Yield a GRU TimeGAN that was compiled and fitted for one epoch on ``mocked_data``."""
    # Same sizes the ``mocked_data`` fixture uses for its generated dataset.
    model_kwargs = dict(
        seq_len=24,
        module="gru",
        hidden_dim=24,
        n_features=6,
        n_layers=3,
        batch_size=16,
    )
    model = tsgm.models.timeGAN.TimeGAN(**model_kwargs)
    model.compile()
    model.fit(mocked_data, epochs=1)
    yield model


def test_timegan_train_autoencoder(mocked_data, mocked_timegan):
    """The autoencoder training step should return a floating-point loss."""
    batch_iter = iter(mocked_data.repeat())

    # Same setup call that ``fit`` makes before it starts training.
    mocked_timegan._define_timegan()
    real_batch = next(batch_iter)
    autoencoder_loss = mocked_timegan._train_autoencoder(
        real_batch, mocked_timegan.autoencoder_opt
    )

    # The loss must be a 32- or 64-bit float.
    float_dtypes = (tf.float32, tf.float64)
    assert autoencoder_loss.dtype in float_dtypes


def test_timegan_train_embedder(mocked_data, mocked_timegan):
    """The embedder training step should return a pair whose second item is a float loss."""
    batch_iter = iter(mocked_data.repeat())

    # Same setup call that ``fit`` makes before it starts training.
    mocked_timegan._define_timegan()
    real_batch = next(batch_iter)
    _unused, embedder_loss = mocked_timegan._train_embedder(
        real_batch, mocked_timegan.embedder_opt
    )

    # The loss must be a 32- or 64-bit float.
    float_dtypes = (tf.float32, tf.float64)
    assert embedder_loss.dtype in float_dtypes


def test_timegan_train_generator(mocked_data, mocked_timegan):
    """The generator training step should return five floating-point losses."""
    batch_iter = iter(mocked_data.repeat())

    # Same setup call that ``fit`` makes before it starts training.
    mocked_timegan._define_timegan()
    real_batch = next(batch_iter)
    noise_batch = next(mocked_timegan.get_noise_batch())

    # Unpacking into exactly five names also checks the length of the result.
    loss_u, loss_u_e, loss_s, loss_v, loss_total = mocked_timegan._train_generator(
        real_batch, noise_batch, mocked_timegan.generator_opt
    )

    # Every returned loss must be a 32- or 64-bit float.
    float_dtypes = (tf.float32, tf.float64)
    generator_losses = (loss_u, loss_u_e, loss_s, loss_v, loss_total)
    assert all(item.dtype in float_dtypes for item in generator_losses)


def test_timegan_check_discriminator_loss(mocked_data, mocked_timegan):
    """``_check_discriminator_loss`` should return a floating-point loss."""
    batch_iter = iter(mocked_data.repeat())

    # Same setup call that ``fit`` makes before it starts training.
    mocked_timegan._define_timegan()
    real_batch = next(batch_iter)
    noise_batch = next(mocked_timegan.get_noise_batch())
    discriminator_loss = mocked_timegan._check_discriminator_loss(
        real_batch, noise_batch
    )

    # The loss must be a 32- or 64-bit float.
    float_dtypes = (tf.float32, tf.float64)
    assert discriminator_loss.dtype in float_dtypes


def test_timegan_train_discriminator(mocked_data, mocked_timegan):
    """The discriminator training step should return a floating-point loss."""
    batch_iter = iter(mocked_data.repeat())

    # Same setup call that ``fit`` makes before it starts training.
    mocked_timegan._define_timegan()
    real_batch = next(batch_iter)
    noise_batch = next(mocked_timegan.get_noise_batch())
    discriminator_loss = mocked_timegan._train_discriminator(
        real_batch, noise_batch, mocked_timegan.discriminator_opt
    )

    # The loss must be a 32- or 64-bit float.
    float_dtypes = (tf.float32, tf.float64)
    assert discriminator_loss.dtype in float_dtypes
22 changes: 12 additions & 10 deletions tsgm/models/timeGAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,21 +492,21 @@ def fit(
self._define_timegan()

# 1. Embedding network training
print("Start Embedding Network Training")
logger.info("Start Embedding Network Training")

for epoch in tqdm(range(epochs), desc="Autoencoder - training"):
X_ = next(batches)
step_e_loss_0 = self._train_autoencoder(X_, self.autoencoder_opt)

# Checkpoint
if checkpoints_interval is not None and epoch % checkpoints_interval == 0:
print(f"step: {epoch}/{epochs}, e_loss: {step_e_loss_0}")
logger.info(f"step: {epoch}/{epochs}, e_loss: {step_e_loss_0}")
self.training_losses_history["autoencoder"] = float(step_e_loss_0)

print("Finished Embedding Network Training")
logger.info("Finished Embedding Network Training")

# 2. Training only with supervised loss
print("Start Training with Supervised Loss Only")
logger.info("Start Training with Supervised Loss Only")

# Adversarial Supervised network training
for epoch in tqdm(range(epochs), desc="Adversarial Supervised - training"):
Expand All @@ -515,17 +515,17 @@ def fit(

# Checkpoint
if checkpoints_interval is not None and epoch % checkpoints_interval == 0:
print(
logger.info(
f"step: {epoch}/{epochs}, s_loss: {np.round(np.sqrt(step_g_loss_s), 4)}"
)
self.training_losses_history["adversarial_supervised"] = float(
np.sqrt(step_g_loss_s)
)

print("Finished Training with Supervised Loss Only")
logger.info("Finished Training with Supervised Loss Only")

# 3. Joint Training
print("Start Joint Training")
logger.info("Start Joint Training")

# GAN with embedding network training
for epoch in tqdm(range(epochs), desc="GAN with embedding - training"):
Expand Down Expand Up @@ -554,12 +554,14 @@ def fit(
Z_ = next(self.get_noise_batch())
step_d_loss = self._check_discriminator_loss(X_, Z_)
if step_d_loss > 0.15:
print("Train Discriminator (discriminator does not work well yet)")
logger.info(
"Train Discriminator (discriminator does not work well yet)"
)
step_d_loss = self._train_discriminator(X_, Z_, self.discriminator_opt)

# Print multiple checkpoints
if checkpoints_interval is not None and epoch % checkpoints_interval == 0:
print(
logger.info(
f"""step: {epoch}/{epochs},
d_loss: {np.round(step_d_loss, 4)},
g_loss_u: {np.round(step_g_loss_u, 4)},
Expand All @@ -582,7 +584,7 @@ def fit(
_sample = self.generate(n_samples=len(data))
self.synthetic_data_generated_in_training[epoch] = _sample

print("Finished Joint Training")
logger.info("Finished Joint Training")
return

def generate(self, n_samples: int) -> TensorLike:
Expand Down

0 comments on commit dc77d7c

Please sign in to comment.