diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 93ec81d..52a9f91 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -11,5 +11,6 @@ optuna
 prettytable
 seaborn
 scikit-learn
+torch
 yfinance==0.2.28
 tqdm
diff --git a/tsgm/models/cvae.py b/tsgm/models/cvae.py
index 3080f6c..46baad9 100644
--- a/tsgm/models/cvae.py
+++ b/tsgm/models/cvae.py
@@ -1,5 +1,6 @@
 from tensorflow import keras
 import tensorflow as tf
+import torch
 import typing as T
 
 import tsgm.utils
@@ -9,7 +10,10 @@ class BetaVAE(keras.Model):
     """
     beta-VAE implementation for unlabeled time series.
     """
-    def __init__(self, encoder: keras.Model, decoder: keras.Model, beta: float = 1.0, **kwargs) -> None:
+
+    def __init__(
+        self, encoder: keras.Model, decoder: keras.Model, beta: float = 1.0, **kwargs
+    ) -> None:
         """
         :param encoder: An encoder model which takes a time series as input and check
             whether the image is real or fake.
@@ -59,10 +63,14 @@ def call(self, X: tsgm.types.Tensor) -> tsgm.types.Tensor:
             x_decoded = x_decoded.reshape((1, -1))
         return x_decoded
 
-    def _get_reconstruction_loss(self, X: tsgm.types.Tensor, Xr: tsgm.types.Tensor) -> float:
-        reconst_loss = tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=0) +\
-            tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=1) +\
-            tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=2)
+    def _get_reconstruction_loss(
+        self, X: tsgm.types.Tensor, Xr: tsgm.types.Tensor
+    ) -> float:
+        reconst_loss = (
+            tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=0)
+            + tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=1)
+            + tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=2)
+        )
         return reconst_loss
 
     def train_step(self, data: tsgm.types.Tensor) -> T.Dict:
@@ -106,9 +114,44 @@ def generate(self, n: int) -> tsgm.types.Tensor:
         z = tf.random.normal((n, self.latent_dim))
         return self.decoder(z)
 
+    def train_step_torch(self, data: torch.Tensor) -> T.Dict:
+        """
+        Performs a training step using a batch of data, stored in data.
+
+        :param data: A batch of data in a format batch_size x seq_len x feat_dim
+        :type data: torch.Tensor
+
+        :returns: A dict with losses
+        :rtype: T.Dict
+        """
+        z_mean, z_log_var, z = self.encoder(data)
+        reconstruction = self.decoder(z)
+        reconstruction_loss = self._get_reconstruction_loss(data, reconstruction)
+        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
+        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
+        total_loss = reconstruction_loss + self.beta * kl_loss
+
+        self.total_loss_tracker.update_state(total_loss)
+        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
+        self.kl_loss_tracker.update_state(kl_loss)
+
+        return {
+            "loss": self.total_loss_tracker.result(),
+            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
+            "kl_loss": self.kl_loss_tracker.result(),
+        }
+
 
 class cBetaVAE(keras.Model):
-    def __init__(self, encoder: keras.Model, decoder: keras.Model, latent_dim: int, temporal: bool, beta: float = 1.0, **kwargs) -> None:
+    def __init__(
+        self,
+        encoder: keras.Model,
+        decoder: keras.Model,
+        latent_dim: int,
+        temporal: bool,
+        beta: float = 1.0,
+        **kwargs
+    ) -> None:
         super(cBetaVAE, self).__init__(**kwargs)
         self.beta = beta
         self.encoder = encoder
@@ -134,7 +177,9 @@ def metrics(self) -> T.List:
             self.kl_loss_tracker,
         ]
 
-    def generate(self, labels: tsgm.types.Tensor) -> T.Tuple[tsgm.types.Tensor, tsgm.types.Tensor]:
+    def generate(
+        self, labels: tsgm.types.Tensor
+    ) -> T.Tuple[tsgm.types.Tensor, tsgm.types.Tensor]:
         """
         Generates new data from the model.
 
@@ -145,7 +190,9 @@ def generate(self, labels: tsgm.types.Tensor) -> T.Tuple[tsgm.types.Tensor, tsgm
         :rtype: T.Tuple[tsgm.types.Tensor, tsgm.types.Tensor]
         """
         batch_size = tf.shape(labels)[0]
-        z = tf.random.normal((batch_size, self._seq_len, self.latent_dim), dtype=labels.dtype)
+        z = tf.random.normal(
+            (batch_size, self._seq_len, self.latent_dim), dtype=labels.dtype
+        )
         decoder_input = self._get_decoder_input(z, labels)
         return (self.decoder(decoder_input), labels)
 
@@ -168,20 +215,36 @@ def call(self, data: tsgm.types.Tensor) -> tsgm.types.Tensor:
             x_decoded = x_decoded.reshape((1, -1))
         return x_decoded
 
-    def _get_reconstruction_loss(self, X: tsgm.types.Tensor, Xr: tsgm.types.Tensor) -> float:
-        reconst_loss = tf.reduce_sum(tf.math.squared_difference(X, Xr)) +\
-            tf.reduce_sum(tf.math.squared_difference(tf.reduce_mean(X, axis=1), tf.reduce_mean(Xr, axis=1))) +\
-            tf.reduce_sum(tf.math.squared_difference(tf.reduce_mean(X, axis=2), tf.reduce_mean(Xr, axis=2)))
+    def _get_reconstruction_loss(
+        self, X: tsgm.types.Tensor, Xr: tsgm.types.Tensor
+    ) -> float:
+        reconst_loss = (
+            tf.reduce_sum(tf.math.squared_difference(X, Xr))
+            + tf.reduce_sum(
+                tf.math.squared_difference(
+                    tf.reduce_mean(X, axis=1), tf.reduce_mean(Xr, axis=1)
+                )
+            )
+            + tf.reduce_sum(
+                tf.math.squared_difference(
+                    tf.reduce_mean(X, axis=2), tf.reduce_mean(Xr, axis=2)
+                )
+            )
+        )
         return reconst_loss
 
-    def _get_encoder_input(self, X: tsgm.types.Tensor, labels: tsgm.types.Tensor) -> tsgm.types.Tensor:
+    def _get_encoder_input(
+        self, X: tsgm.types.Tensor, labels: tsgm.types.Tensor
+    ) -> tsgm.types.Tensor:
         if self._temporal:
             return tf.concat([X, labels[:, :, None]], axis=2)
         else:
             rep_labels = tf.repeat(labels[:, None, :], [self._seq_len], axis=1)
             return tf.concat([X, rep_labels], axis=2)
 
-    def _get_decoder_input(self, z: tsgm.types.Tensor, labels: tsgm.types.Tensor) -> tsgm.types.Tensor:
+    def _get_decoder_input(
+        self, z: tsgm.types.Tensor, labels: tsgm.types.Tensor
+    ) -> tsgm.types.Tensor:
         if self._temporal:
             rep_labels = labels[:, :, None]
         else:
@@ -220,3 +283,36 @@ def train_step(self, data: tsgm.types.Tensor) -> T.Dict[str, float]:
             "reconstruction_loss": self.reconstruction_loss_tracker.result(),
             "kl_loss": self.kl_loss_tracker.result(),
         }
+
+    def train_step_torch(
+        self, data: T.Tuple[torch.Tensor, torch.Tensor]
+    ) -> T.Dict[str, float]:
+        """
+        Performs a training step using a batch of data, stored in data.
+
+        :param data: A tuple (X, labels), where X is batch_size x seq_len x feat_dim
+        :type data: T.Tuple[torch.Tensor, torch.Tensor]
+
+        :returns: A dict with losses
+        :rtype: T.Dict[str, float]
+        """
+        X, labels = data
+        encoder_input = self._get_encoder_input(X, labels)
+        z_mean, z_log_var, z = self.encoder(encoder_input)
+
+        decoder_input = self._get_decoder_input(z_mean, labels)
+        reconstruction = self.decoder(decoder_input)
+        reconstruction_loss = self._get_reconstruction_loss(X, reconstruction)
+        kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
+        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
+        total_loss = reconstruction_loss + self.beta * kl_loss
+
+        self.total_loss_tracker.update_state(total_loss)
+        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
+        self.kl_loss_tracker.update_state(kl_loss)
+
+        return {
+            "loss": self.total_loss_tracker.result(),
+            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
+            "kl_loss": self.kl_loss_tracker.result(),
+        }
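---

Usage notes (not part of the patch). Both new train_step_torch methods use the
closed-form KL divergence between the encoder posterior N(z_mean, exp(z_log_var))
and the standard normal prior,

    KL = -1/2 * sum_j (1 + z_log_var_j - z_mean_j**2 - exp(z_log_var_j)),

which is the -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)) term
above, summed over the latent axis, averaged over the batch, and weighted by beta
in the total loss. Note also that, as defined here, both methods compute losses
and update the metric trackers but do not themselves apply gradient updates.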
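A minimal sketch of feeding BetaVAE.train_step_torch from a torch DataLoader.
It assumes tsgm's "vae_conv5" zoo architecture (the encoder/decoder pair the
project README uses for BetaVAE) and CPU-resident float32 batches, since the
TensorFlow ops inside rely on implicit NumPy conversion of the torch tensors:

    import torch
    import tsgm

    seq_len, feat_dim, latent_dim = 64, 1, 16
    arch = tsgm.models.architectures.zoo["vae_conv5"](
        seq_len=seq_len, feat_dim=feat_dim, latent_dim=latent_dim
    )
    vae = tsgm.models.cvae.BetaVAE(arch.encoder, arch.decoder, beta=1.0)

    X = torch.randn(256, seq_len, feat_dim)  # toy dataset; must stay on CPU
    for batch in torch.utils.data.DataLoader(X, batch_size=32):
        losses = vae.train_step_torch(batch)  # batch is a plain torch.Tensor
        print({k: float(v) for k, v in losses.items()})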
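For the conditional model the batch is an (X, labels) tuple. A sketch with a
small hand-rolled encoder/decoder pair (hypothetical stand-ins, shaped to match
what cBetaVAE's helpers expect with temporal=False and one-hot labels; tsgm's
zoo may offer ready-made conditional architectures):

    import tensorflow as tf
    from tensorflow import keras
    import torch
    import tsgm

    seq_len, feat_dim, latent_dim, n_classes = 32, 1, 8, 2

    class Sampling(keras.layers.Layer):
        # Reparameterization trick: z = mean + sigma * eps.
        def call(self, inputs):
            z_mean, z_log_var = inputs
            eps = tf.random.normal(tf.shape(z_mean))
            return z_mean + tf.exp(0.5 * z_log_var) * eps

    # Per-timestep Gaussian encoder over [X; labels] -> (z_mean, z_log_var, z).
    enc_in = keras.Input((seq_len, feat_dim + n_classes))
    h = keras.layers.Dense(32, activation="relu")(enc_in)
    z_mean = keras.layers.Dense(latent_dim)(h)
    z_log_var = keras.layers.Dense(latent_dim)(h)
    encoder = keras.Model(enc_in, [z_mean, z_log_var, Sampling()([z_mean, z_log_var])])

    # Decoder over [z; labels] back to the time series.
    dec_in = keras.Input((seq_len, latent_dim + n_classes))
    decoder = keras.Model(dec_in, keras.layers.Dense(feat_dim)(dec_in))

    cvae = tsgm.models.cvae.cBetaVAE(
        encoder, decoder, latent_dim=latent_dim, temporal=False, beta=1.0
    )

    X = torch.randn(128, seq_len, feat_dim)
    y = torch.nn.functional.one_hot(
        torch.randint(0, n_classes, (128,)), n_classes
    ).float()
    ds = torch.utils.data.TensorDataset(X, y)
    for X_batch, y_batch in torch.utils.data.DataLoader(ds, batch_size=16):
        losses = cvae.train_step_torch((X_batch, y_batch))

    # Conditional sampling via the (tf-side) generate API shown in the diff.
    X_gen, y_gen = cvae.generate(tf.constant(y[:5].numpy()))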