diff --git a/ctgan/synthesizers/ctgan.py b/ctgan/synthesizers/ctgan.py
index bbcac8f3..02aff697 100644
--- a/ctgan/synthesizers/ctgan.py
+++ b/ctgan/synthesizers/ctgan.py
@@ -39,15 +39,14 @@ def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lamb
         interpolates = alpha * real_data + ((1 - alpha) * fake_data)
 
         disc_interpolates = self(interpolates)
-        self.set_device(device)
-        gradients = torch.autograd.grad(
-            outputs=disc_interpolates,
-            inputs=interpolates,
-            grad_outputs=torch.ones(disc_interpolates.size(), device=device),
-            create_graph=True,
-            retain_graph=True,
-            only_inputs=True,
-        )[0]
+
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore', category=UserWarning)
+            gradients = torch.autograd.grad(
+                outputs=disc_interpolates, inputs=interpolates,
+                grad_outputs=torch.ones(disc_interpolates.size(), device=device),
+                create_graph=True, retain_graph=True, only_inputs=True
+            )[0]
 
         gradients_view = gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1
         gradient_penalty = ((gradients_view) ** 2).mean() * lambda_
@@ -143,23 +142,11 @@ class CTGAN(BaseSynthesizer):
             Defaults to ``True``.
     """
 
-    def __init__(
-        self,
-        embedding_dim=128,
-        generator_dim=(256, 256),
-        discriminator_dim=(256, 256),
-        generator_lr=2e-4,
-        generator_decay=1e-6,
-        discriminator_lr=2e-4,
-        discriminator_decay=1e-6,
-        batch_size=500,
-        discriminator_steps=1,
-        log_frequency=True,
-        verbose=False,
-        epochs=300,
-        pac=10,
-        cuda=True,
-    ):
+    def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256),
+                 generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4,
+                 discriminator_decay=1e-6, batch_size=500, discriminator_steps=1,
+                 log_frequency=True, verbose=False, epochs=300, pac=10, cuda=True):
+
         assert batch_size % 2 == 0
 
         self._embedding_dim = embedding_dim
@@ -254,7 +241,9 @@ def _cond_loss(self, data, c, m):
                     ed = st + span_info.dim
                     ed_c = st_c + span_info.dim
                     tmp = functional.cross_entropy(
-                        data[:, st:ed], torch.argmax(c[:, st_c:ed_c], dim=1), reduction='none'
+                        data[:, st:ed],
+                        torch.argmax(c[:, st_c:ed_c], dim=1),
+                        reduction='none'
                     )
                     loss.append(tmp)
                     st = ed
@@ -308,11 +297,9 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
             epochs = self._epochs
         else:
             warnings.warn(
-                (
-                    '`epochs` argument in `fit` method has been deprecated and will be removed '
-                    'in a future version. Please pass `epochs` to the constructor instead'
-                ),
-                DeprecationWarning,
+                ('`epochs` argument in `fit` method has been deprecated and will be removed '
+                 'in a future version. Please pass `epochs` to the constructor instead'),
+                DeprecationWarning
             )
 
         self._transformer = DataTransformer()
@@ -321,31 +308,32 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
         train_data = self._transformer.transform(train_data)
 
         self._data_sampler = DataSampler(
-            train_data, self._transformer.output_info_list, self._log_frequency
-        )
+            train_data,
+            self._transformer.output_info_list,
+            self._log_frequency)
 
         data_dim = self._transformer.output_dimensions
 
         self._generator = Generator(
-            self._embedding_dim + self._data_sampler.dim_cond_vec(), self._generator_dim, data_dim
+            self._embedding_dim + self._data_sampler.dim_cond_vec(),
+            self._generator_dim,
+            data_dim
         ).to(self._device)
 
         discriminator = Discriminator(
-            data_dim + self._data_sampler.dim_cond_vec(), self._discriminator_dim, pac=self.pac
+            data_dim + self._data_sampler.dim_cond_vec(),
+            self._discriminator_dim,
+            pac=self.pac
         ).to(self._device)
 
         optimizerG = optim.Adam(
-            self._generator.parameters(),
-            lr=self._generator_lr,
-            betas=(0.5, 0.9),
-            weight_decay=self._generator_decay,
+            self._generator.parameters(), lr=self._generator_lr, betas=(0.5, 0.9),
+            weight_decay=self._generator_decay
         )
 
         optimizerD = optim.Adam(
-            discriminator.parameters(),
-            lr=self._discriminator_lr,
-            betas=(0.5, 0.9),
-            weight_decay=self._discriminator_decay,
+            discriminator.parameters(), lr=self._discriminator_lr,
+            betas=(0.5, 0.9), weight_decay=self._discriminator_decay
         )
 
         mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device)
@@ -361,6 +349,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
         steps_per_epoch = max(len(train_data) // self._batch_size, 1)
         for i in epoch_iterator:
             for id_ in range(steps_per_epoch):
+
                 for n in range(self._discriminator_steps):
                     fakez = torch.normal(mean=mean, std=std)
 
@@ -368,8 +357,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
                     if condvec is None:
                         c1, m1, col, opt = None, None, None, None
                         real = self._data_sampler.sample_data(
-                            train_data, self._batch_size, col, opt
-                        )
+                            train_data, self._batch_size, col, opt)
                     else:
                         c1, m1, col, opt = condvec
                         c1 = torch.from_numpy(c1).to(self._device)
@@ -379,8 +367,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
                         perm = np.arange(self._batch_size)
                         np.random.shuffle(perm)
                         real = self._data_sampler.sample_data(
-                            train_data, self._batch_size, col[perm], opt[perm]
-                        )
+                            train_data, self._batch_size, col[perm], opt[perm])
                         c2 = c1[perm]
 
                     fake = self._generator(fakez)
@@ -399,8 +386,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
                     y_real = discriminator(real_cat)
 
                     pen = discriminator.calc_gradient_penalty(
-                        real_cat, fake_cat, self._device, self.pac
-                    )
+                        real_cat, fake_cat, self._device, self.pac)
                     loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
 
                     optimizerD.zero_grad(set_to_none=False)
@@ -444,12 +430,12 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
             epoch_loss_df = pd.DataFrame({
                 'Epoch': [i],
                 'Generator Loss': [generator_loss],
-                'Discriminator Loss': [discriminator_loss],
+                'Discriminator Loss': [discriminator_loss]
             })
             if not self.loss_values.empty:
-                self.loss_values = pd.concat([self.loss_values, epoch_loss_df]).reset_index(
-                    drop=True
-                )
+                self.loss_values = pd.concat(
+                    [self.loss_values, epoch_loss_df]
+                ).reset_index(drop=True)
             else:
                 self.loss_values = epoch_loss_df
 
@@ -479,11 +465,9 @@ def sample(self, n, condition_column=None, condition_value=None):
         """
         if condition_column is not None and condition_value is not None:
             condition_info = self._transformer.convert_column_name_value_to_id(
-                condition_column, condition_value
-            )
+                condition_column, condition_value)
             global_condition_vec = self._data_sampler.generate_cond_from_condition_column_info(
-                condition_info, self._batch_size
-            )
+                condition_info, self._batch_size)
         else:
             global_condition_vec = None
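
For reference, a minimal usage sketch of the public API touched by this diff (constructor `epochs`, `fit` with `discrete_columns`, and the conditional `sample` path). The `load_demo` helper and the column names follow the CTGAN README; the specific condition column/value pair is an illustrative assumption, not part of this change.

    from ctgan import CTGAN, load_demo

    real_data = load_demo()  # adult census demo dataset from the README

    # Names of the discrete columns in the demo data (per the README)
    discrete_columns = [
        'workclass', 'education', 'marital-status', 'occupation',
        'relationship', 'race', 'sex', 'native-country', 'income',
    ]

    # Pass epochs to the constructor; fit(epochs=...) is deprecated above
    ctgan = CTGAN(epochs=10)
    ctgan.fit(real_data, discrete_columns)

    synthetic_data = ctgan.sample(1000)

    # Conditional sampling through the sample() call reformatted above;
    # the column/value chosen here are hypothetical examples
    conditioned = ctgan.sample(
        1000, condition_column='education', condition_value='Doctorate'
    )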