From 2b138cbf38dbbd64d02bb6867a59c63786e8a0e8 Mon Sep 17 00:00:00 2001 From: Alexander Nikitin <1243786+AlexanderVNikitin@users.noreply.github.com> Date: Thu, 11 Jul 2024 16:52:59 +0300 Subject: [PATCH] add wgan --- README.md | 1 + tests/test_cgan.py | 32 +++++++++-- tsgm/models/architectures/zoo.py | 94 ++++++++++++++++++++++++++++++++ tsgm/models/cgan.py | 60 +++++++++++++++++--- 4 files changed, 176 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 3bcf4fd..e5ca903 100644 --- a/README.md +++ b/README.md @@ -127,6 +127,7 @@ TSGM implements several generative models for synthetic time series data. | ------------- | ------------- | ------------- | ------------- | | Structural Time Series model | [tsgm.models.sts.STS](https://tsgm.readthedocs.io/en/latest/modules/root.html#tsgm.models.sts.STS) | Data-driven | Great for modeling time series when prior knowledge is available (e.g., trend or seasonality). | | GAN | [tsgm.models.cgan.GAN](https://tsgm.readthedocs.io/en/latest/modules/root.html#tsgm.models.cgan.GAN) | Data-driven | A generic implementation of GAN for time series generation. It can be customized with architectures for generators and discriminators. | +| WaveGAN | [tsgm.models.cgan.GAN](https://tsgm.readthedocs.io/en/latest/modules/root.html#tsgm.models.cgan.GAN) | Data-driven | WaveGAN is the model for audio synthesis proposed in [Adversarial Audio Synthesis](https://arxiv.org/abs/1802.04208). To use WaveGAN, set `use_wgan=True` when initializing the GAN class and use the `zoo["wavegan"]` architecture from the model zoo. | | ConditionalGAN | [tsgm.models.cgan.ConditionalGAN](https://tsgm.readthedocs.io/en/latest/modules/root.html#tsgm.models.cgan.ConditionalGAN) | Data-driven | A generic implementation of conditional GAN. It supports scalar conditioning as well as temporal one. 
|
 | BetaVAE | [tsgm.models.cvae.BetaVAE](https://tsgm.readthedocs.io/en/latest/modules/root.html#tsgm.models.cvae.BetaVAE) | Data-driven | A generic implementation of Beta VAE for TS. The loss function is customized to work well with multi-dimensional time series. |
 | cBetaVAE | [tsgm.models.cvae.cBetaVAE](https://tsgm.readthedocs.io/en/latest/modules/root.html#tsgm.models.cvae.cBetaVAE) | Data-driven | Conditional version of BetaVAE. It supports temporal a scalar condiotioning.|
diff --git a/tests/test_cgan.py b/tests/test_cgan.py
index 8c0e78d..298eb81 100644
--- a/tests/test_cgan.py
+++ b/tests/test_cgan.py
@@ -236,7 +236,6 @@ def test_dp_compiler():
         learning_rate=learning_rate
     )
 
-
     g_optimizer = tf_privacy.DPKerasAdamOptimizer(
         l2_norm_clip=l2_norm_clip,
         noise_multiplier=noise_multiplier,
@@ -259,6 +258,31 @@ def test_dp_compiler():
     assert generated_samples.shape == (10, 64, 1)
 
 
-def test_temporal_cgan_multiple_features():
-    # TODO
-    pass
+def test_wavegan():
+    """Fit a WaveGAN-based GAN in WGAN mode for one epoch and check the shape of generated samples."""
+    seq_len, feature_dim, batch_size = 64, 1, 48
+    latent_dim, output_dim = 2, 1
+    n_generated = 10
+    lr = 0.0003
+
+    dataset = _gen_dataset(seq_len, feature_dim, batch_size)
+    wavegan_cls = tsgm.models.architectures.zoo["wavegan"]
+    architecture = wavegan_cls(
+        seq_len=seq_len, feat_dim=feature_dim, latent_dim=latent_dim, output_dim=output_dim
+    )
+    gan = tsgm.models.cgan.GAN(
+        discriminator=architecture.discriminator, generator=architecture.generator,
+        latent_dim=latent_dim, use_wgan=True,
+    )
+    gan.compile(
+        d_optimizer=keras.optimizers.Adam(learning_rate=lr),
+        g_optimizer=keras.optimizers.Adam(learning_rate=lr),
+        loss_fn=keras.losses.BinaryCrossentropy(),
+    )
+
+    gan.fit(dataset, epochs=1)
+
+    assert gan.generator is not None and gan.discriminator is not None
+    # Generation must return `n_generated` series of length `seq_len` with a single feature.
+    generated_samples = gan.generate(n_generated)
+    assert generated_samples.shape == (n_generated, seq_len, 1)
diff --git a/tsgm/models/architectures/zoo.py b/tsgm/models/architectures/zoo.py
index 2253c8b..87e4b21 100644
---
a/tsgm/models/architectures/zoo.py
+++ b/tsgm/models/architectures/zoo.py
@@ -871,6 +871,99 @@ def _build_generator(self, output_activation: str) -> keras.Model:
         return generator
 
 
+class WaveGANArchitecture(BaseGANArchitecture):
+    """
+    WaveGAN architecture, from https://arxiv.org/abs/1802.04208
+
+    Transposed-convolution generator, strided-convolution discriminator with phase shuffle. Inherits from BaseGANArchitecture.
+    """
+    arch_type = "gan:raw"
+
+    def __init__(self, seq_len: int, feat_dim: int = 64, latent_dim: int = 32, output_dim: int = 1, kernel_size: int = 32, phase_rad: int = 2, use_batchnorm: bool = False):
+        """
+        Initializes the WaveGANArchitecture.
+
+        :param seq_len: Length of input sequences.
+        :type seq_len: int
+        :param feat_dim: Number of features per time step in discriminator inputs and generator outputs.
+        :type feat_dim: int
+        :param latent_dim: Dimensionality of the latent space.
+        :type latent_dim: int
+        :param output_dim: Stored on the instance; not used when building the generator or discriminator.
+        :type output_dim: int
+        :param kernel_size: Kernel size shared by all convolutional layers (default is 32)
+        :type kernel_size: int, optional
+        :param phase_rad: Phase shuffle radius for wavegan (default is 2)
+        :type phase_rad: int, optional
+        :param use_batchnorm: Whether to use batchnorm (default is False)
+        :type use_batchnorm: bool, optional
+        """
+        self.seq_len = seq_len
+        self.feat_dim = feat_dim
+        self.latent_dim = latent_dim
+        self.kernel_size = kernel_size
+        self.phase_rad = phase_rad
+        self.output_dim = output_dim
+        self.use_batchnorm = use_batchnorm
+
+        self._discriminator = self._build_discriminator()
+        self._generator = self._build_generator()
+
+    def _apply_phaseshuffle(self, x, rad):
+        '''
+        Randomly shifts `x` along axis 1 by up to `rad` steps, reflect-padding so the length is kept. Based on
+        https://github.com/chrisdonahue/wavegan/
+        '''
+        if rad <= 0 or x.shape[1] <= 1:
+            return x
+
+        b, x_len, nch = x.get_shape().as_list()
+        rad = min(rad, x_len - 1)  # a reflect pad must be shorter than the padded axis
+        # NOTE(review): this runs while the model is being built, so `phase` may be drawn once, not per step - confirm.
+        phase = tf.random.uniform([], minval=-rad, maxval=rad + 1, dtype=tf.int32)
+        pad_l, pad_r = tf.maximum(phase, 0), tf.maximum(-phase, 0)
+        x = tf.pad(x, [[0, 0], [pad_l, pad_r], [0, 0]], mode="reflect")
+
+        x = x[:, pad_r:pad_r + x_len]
+        x.set_shape([b, x_len, nch])
+
+        return x
+
+    def _conv_transpose_block(self, inputs, channels, strides=4):
+        """Applies Conv1DTranspose, optional batch normalization and LeakyReLU to `inputs`."""
+        x = layers.Conv1DTranspose(channels, self.kernel_size, strides=strides, padding='same', use_bias=False)(inputs)
+        x = layers.BatchNormalization()(x) if self.use_batchnorm else x
+        return layers.LeakyReLU()(x)
+
+    def _build_generator(self):
+        inputs = layers.Input((self.latent_dim,))
+        x = layers.Dense(16 * 1024, use_bias=False)(inputs)
+        x = layers.BatchNormalization()(x) if self.use_batchnorm else x
+        x = layers.LeakyReLU()(x)
+        x = layers.Reshape((16, 1024))(x)
+        # Four stride-4 transposed-convolution blocks with decreasing channel counts.
+        for conv_size in [512, 256, 128, 64]:
+            x = self._conv_transpose_block(x, conv_size)
+        # Emit `feat_dim` channels so that generated samples match the discriminator input.
+        x = layers.Conv1DTranspose(self.feat_dim, self.kernel_size, strides=4, padding='same', use_bias=False, activation='tanh')(x)
+        pool_and_stride = -(-(x.shape[1] + 1) // (self.seq_len + 1))  # integer ceil division
+        x = layers.AveragePooling1D(pool_size=pool_and_stride, strides=pool_and_stride)(x)
+        return keras.Model(inputs, x)
+
+    def _build_discriminator(self):
+        inputs = layers.Input((self.seq_len, self.feat_dim))
+        x = inputs  # every block below consumes the previous block's output
+        for conv_size in [64, 128, 256, 512]:
+            x = layers.Conv1D(conv_size, self.kernel_size, strides=4, padding='same')(x)
+            x = layers.BatchNormalization()(x) if self.use_batchnorm else x
+            x = layers.LeakyReLU()(x)
+            x = self._apply_phaseshuffle(x, self.phase_rad)
+
+        x = layers.Flatten()(x)
+        x = layers.Dense(1)(x)
+        return keras.Model(inputs, x)
+
+
 class Zoo(dict):
     """
     A collection of architectures represented. It behaves like supports Python `dict` API.
@@ -901,6 +994,7 @@ def summary(self) -> None:
         "t-cgan_c4": tcGAN_Conv4Architecture,
         "cgan_lstm_n": cGAN_LSTMnArchitecture,
         "cgan_lstm_3": cGAN_LSTMConv3Architecture,
+        "wavegan": WaveGANArchitecture,
 
         # Downstream models
         "clf_cn": ConvnArchitecture,
diff --git a/tsgm/models/cgan.py b/tsgm/models/cgan.py
index 2ffc188..51a0454 100644
--- a/tsgm/models/cgan.py
+++ b/tsgm/models/cgan.py
@@ -27,26 +27,56 @@ class GAN(keras.Model):
     """
     GAN implementation for unlabeled time series.
"""
-    def __init__(self, discriminator: keras.Model, generator: keras.Model, latent_dim: int) -> None:
+    def __init__(self, discriminator: keras.Model, generator: keras.Model, latent_dim: int, use_wgan: bool = False) -> None:
         """
         :param discriminator: A discriminator model which takes a time series as input and check
-            whether the image is real or fake.
+            whether the sample is real or fake.
         :type discriminator: keras.Model
         :param generator: Takes as input a random noise vector of `latent_dim` length and returns
             a simulated time-series.
         :type generator: keras.Model
         :param latent_dim: The size of the noise vector.
         :type latent_dim: int
+        :param use_wgan: Train with the Wasserstein loss and gradient penalty instead of `loss_fn`
+        :type use_wgan: bool
         """
         super(GAN, self).__init__()
         self.discriminator = discriminator
         self.generator = generator
         self.latent_dim = latent_dim
         self._seq_len = self.generator.output_shape[1]
+        self.use_wgan = use_wgan
+        self.gp_weight = 10.0  # multiplier of the gradient penalty in the WGAN discriminator loss
 
         self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
         self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")
 
+    def wgan_discriminator_loss(self, real_sample, fake_sample):
+        """Return the Wasserstein discriminator loss: mean output on fake minus mean output on real samples."""
+        real_loss, fake_loss = tf.reduce_mean(real_sample), tf.reduce_mean(fake_sample)
+        return fake_loss - real_loss
+
+    def wgan_generator_loss(self, fake_sample):
+        """Return the Wasserstein generator loss: the negated mean discriminator output on fake samples."""
+        return -tf.reduce_mean(fake_sample)
+
+    def gradient_penalty(self, batch_size, real_samples, fake_samples):
+        """Return the mean squared deviation from 1 of the discriminator's gradient norm at interpolated samples."""
+        # Interpolate between real and fake samples with a per-sample weight drawn uniformly from [0, 1).
+        alpha = tf.random.uniform([batch_size, 1, 1], 0.0, 1.0)
+        diff = fake_samples - real_samples
+        interpolated = real_samples + alpha * diff
+        with tf.GradientTape() as gp_tape:
+            gp_tape.watch(interpolated)
+            # 1. Get the discriminator output for this interpolated sample.
+            pred = self.discriminator(interpolated, training=True)
+
+        # 2. Calculate the gradients w.r.t. this interpolated sample.
+        grads = gp_tape.gradient(pred, [interpolated])[0]
+        # 3. Calculate the norm of the gradients over axes 1 and 2.
+        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2]))
+        return tf.reduce_mean((norm - 1.0) ** 2)
+
     @property
     def metrics(self) -> T.List:
         """
@@ -94,7 +124,6 @@ def train_step(self, data: tsgm.types.Tensor) -> T.Dict[str, float]:
         """
         real_data = data
         batch_size = tf.shape(real_data)[0]
-
         # Generate ts
         random_vector = self._get_random_vector_labels(batch_size)
         fake_data = self.generator(random_vector)
@@ -111,7 +140,19 @@ def train_step(self, data: tsgm.types.Tensor) -> T.Dict[str, float]:
         )
         with tf.GradientTape() as tape:
-            predictions = self.discriminator(combined_data)
-            d_loss = self.loss_fn(desc_labels, predictions)
+            if self.use_wgan:
+                fake_logits = self.discriminator(fake_data, training=True)
+                # Get the logits for the real samples
+                real_logits = self.discriminator(real_data, training=True)
+
+                # Calculate the discriminator loss using the fake and real sample logits
+                d_cost = self.wgan_discriminator_loss(real_logits, fake_logits)
+                # Calculate the gradient penalty
+                gp = self.gradient_penalty(batch_size, real_data, fake_data)
+                # Add the gradient penalty to the original discriminator loss
+                d_loss = d_cost + gp * self.gp_weight
+            else:
+                predictions = self.discriminator(combined_data)
+                d_loss = self.loss_fn(desc_labels, predictions)
         grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
         self.d_optimizer.apply_gradients(
             zip(grads, self.discriminator.trainable_weights)
@@ -126,7 +167,11 @@ def train_step(self, data: tsgm.types.Tensor) -> T.Dict[str, float]:
         with tf.GradientTape() as tape:
             fake_data = self.generator(random_vector)
             predictions = self.discriminator(fake_data)
-            g_loss = self.loss_fn(misleading_labels, predictions)
+            if self.use_wgan:
+                # uses logits
+                g_loss = self.wgan_generator_loss(predictions)
+            else:
+                g_loss = self.loss_fn(misleading_labels, predictions)
         grads = tape.gradient(g_loss, self.generator.trainable_weights)
         self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
 
@@ -167,10 +212,10 @@ class
ConditionalGAN(keras.Model): """ Conditional GAN implementation for labeled and temporally labeled time series. """ - def __init__(self, discriminator: keras.Model, generator: keras.Model, latent_dim: int, temporal=False) -> None: + def __init__(self, discriminator: keras.Model, generator: keras.Model, latent_dim: int, temporal=False, use_wgan=False) -> None: """ :param discriminator: A discriminator model which takes a time series as input and check - whether the image is real or fake. + whether the sample is real or fake. :type discriminator: keras.Model :param generator: Takes as input a random noise vector of `latent_dim` length and return a simulated time-series. @@ -312,6 +357,7 @@ def train_step(self, data: T.Tuple) -> T.Dict[str, float]: fake_data = tf.concat([fake_samples, rep_labels], -1) predictions = self.discriminator(fake_data) g_loss = self.loss_fn(misleading_labels, predictions) + if self.dp: # For DP optimizers from `tensorflow.privacy` self.g_optimizer.minimize(g_loss, self.generator.trainable_weights, tape=tape)