Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add torch train to vae #44

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ optuna
prettytable
seaborn
scikit-learn
torch
yfinance==0.2.28
tqdm
124 changes: 110 additions & 14 deletions tsgm/models/cvae.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from tensorflow import keras
import tensorflow as tf
import torch
import typing as T

import tsgm.utils
Expand All @@ -9,7 +10,10 @@ class BetaVAE(keras.Model):
"""
beta-VAE implementation for unlabeled time series.
"""
def __init__(self, encoder: keras.Model, decoder: keras.Model, beta: float = 1.0, **kwargs) -> None:

def __init__(
self, encoder: keras.Model, decoder: keras.Model, beta: float = 1.0, **kwargs
) -> None:
"""
:param encoder: An encoder model which takes a time series as input and check
whether the image is real or fake.
Expand Down Expand Up @@ -59,10 +63,14 @@ def call(self, X: tsgm.types.Tensor) -> tsgm.types.Tensor:
x_decoded = x_decoded.reshape((1, -1))
return x_decoded

def _get_reconstruction_loss(self, X: tsgm.types.Tensor, Xr: tsgm.types.Tensor) -> float:
reconst_loss = tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=0) +\
tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=1) +\
tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=2)
def _get_reconstruction_loss(
self, X: tsgm.types.Tensor, Xr: tsgm.types.Tensor
) -> float:
reconst_loss = (
tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=0)
+ tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=1)
+ tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=2)
)
return reconst_loss

def train_step(self, data: tsgm.types.Tensor) -> T.Dict:
Expand Down Expand Up @@ -106,9 +114,44 @@ def generate(self, n: int) -> tsgm.types.Tensor:
z = tf.random.normal((n, self.latent_dim))
return self.decoder(z)

def train_step_torch(self, data: torch.Tensor) -> T.Dict:
"""
Performs a training step using a batch of data, stored in data.

:param data: A batch of data in a format batch_size x seq_len x feat_dim
:type data: torch.Tensor

:returns: A dict with losses
:rtype: T.Dict
"""
z_mean, z_log_var, z = self.encoder(data)
reconstruction = self.decoder(z)
reconstruction_loss = self._get_reconstruction_loss(data, reconstruction)
kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
total_loss = reconstruction_loss + kl_loss

self.total_loss_tracker.update_state(total_loss)
self.reconstruction_loss_tracker.update_state(reconstruction_loss)
self.kl_loss_tracker.update_state(kl_loss)

return {
"loss": self.total_loss_tracker.result(),
"reconstruction_loss": self.reconstruction_loss_tracker.result(),
"kl_loss": self.kl_loss_tracker.result(),
}


class cBetaVAE(keras.Model):
def __init__(self, encoder: keras.Model, decoder: keras.Model, latent_dim: int, temporal: bool, beta: float = 1.0, **kwargs) -> None:
def __init__(
self,
encoder: keras.Model,
decoder: keras.Model,
latent_dim: int,
temporal: bool,
beta: float = 1.0,
**kwargs
) -> None:
super(cBetaVAE, self).__init__(**kwargs)
self.beta = beta
self.encoder = encoder
Expand All @@ -134,7 +177,9 @@ def metrics(self) -> T.List:
self.kl_loss_tracker,
]

def generate(self, labels: tsgm.types.Tensor) -> T.Tuple[tsgm.types.Tensor, tsgm.types.Tensor]:
def generate(
self, labels: tsgm.types.Tensor
) -> T.Tuple[tsgm.types.Tensor, tsgm.types.Tensor]:
"""
Generates new data from the model.

Expand All @@ -145,7 +190,9 @@ def generate(self, labels: tsgm.types.Tensor) -> T.Tuple[tsgm.types.Tensor, tsgm
:rtype: T.Tuple[tsgm.types.Tensor, tsgm.types.Tensor]
"""
batch_size = tf.shape(labels)[0]
z = tf.random.normal((batch_size, self._seq_len, self.latent_dim), dtype=labels.dtype)
z = tf.random.normal(
(batch_size, self._seq_len, self.latent_dim), dtype=labels.dtype
)
decoder_input = self._get_decoder_input(z, labels)
return (self.decoder(decoder_input), labels)

Expand All @@ -168,20 +215,36 @@ def call(self, data: tsgm.types.Tensor) -> tsgm.types.Tensor:
x_decoded = x_decoded.reshape((1, -1))
return x_decoded

def _get_reconstruction_loss(self, X: tsgm.types.Tensor, Xr: tsgm.types.Tensor) -> float:
reconst_loss = tf.reduce_sum(tf.math.squared_difference(X, Xr)) +\
tf.reduce_sum(tf.math.squared_difference(tf.reduce_mean(X, axis=1), tf.reduce_mean(Xr, axis=1))) +\
tf.reduce_sum(tf.math.squared_difference(tf.reduce_mean(X, axis=2), tf.reduce_mean(Xr, axis=2)))
def _get_reconstruction_loss(
self, X: tsgm.types.Tensor, Xr: tsgm.types.Tensor
) -> float:
reconst_loss = (
tf.reduce_sum(tf.math.squared_difference(X, Xr))
+ tf.reduce_sum(
tf.math.squared_difference(
tf.reduce_mean(X, axis=1), tf.reduce_mean(Xr, axis=1)
)
)
+ tf.reduce_sum(
tf.math.squared_difference(
tf.reduce_mean(X, axis=2), tf.reduce_mean(Xr, axis=2)
)
)
)
return reconst_loss

def _get_encoder_input(self, X: tsgm.types.Tensor, labels: tsgm.types.Tensor) -> tsgm.types.Tensor:
def _get_encoder_input(
self, X: tsgm.types.Tensor, labels: tsgm.types.Tensor
) -> tsgm.types.Tensor:
if self._temporal:
return tf.concat([X, labels[:, :, None]], axis=2)
else:
rep_labels = tf.repeat(labels[:, None, :], [self._seq_len], axis=1)
return tf.concat([X, rep_labels], axis=2)

def _get_decoder_input(self, z: tsgm.types.Tensor, labels: tsgm.types.Tensor) -> tsgm.types.Tensor:
def _get_decoder_input(
self, z: tsgm.types.Tensor, labels: tsgm.types.Tensor
) -> tsgm.types.Tensor:
if self._temporal:
rep_labels = labels[:, :, None]
else:
Expand Down Expand Up @@ -220,3 +283,36 @@ def train_step(self, data: tsgm.types.Tensor) -> T.Dict[str, float]:
"reconstruction_loss": self.reconstruction_loss_tracker.result(),
"kl_loss": self.kl_loss_tracker.result(),
}

def train_step_torch(
self, data: T.Tuple[torch.Tensor, torch.Tensor]
) -> T.Dict[str, float]:
"""
Performs a training step using a batch of data, stored in data.

:param data: A batch of data in a format batch_size x seq_len x feat_dim
:type data: T.Tuple[torch.Tensor, torch.Tensor]

:returns: A dict with losses
:rtype: T.Dict[str, float]
"""
X, labels = data
encoder_input = self._get_encoder_input(X, labels)
z_mean, z_log_var, z = self.encoder(encoder_input)

decoder_input = self._get_decoder_input(z_mean, labels)
reconstruction = self.decoder(decoder_input)
reconstruction_loss = self._get_reconstruction_loss(X, reconstruction)
kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
total_loss = reconstruction_loss + self.beta * kl_loss

self.total_loss_tracker.update_state(total_loss)
self.reconstruction_loss_tracker.update_state(reconstruction_loss)
self.kl_loss_tracker.update_state(kl_loss)

return {
"loss": self.total_loss_tracker.result(),
"reconstruction_loss": self.reconstruction_loss_tracker.result(),
"kl_loss": self.kl_loss_tracker.result(),
}
Loading