Commit
fealho committed Jul 30, 2024
1 parent 118f3b6 commit 0ad909e
Showing 1 changed file with 42 additions and 58 deletions.
ctgan/synthesizers/ctgan.py: 100 changes (42 additions, 58 deletions)
@@ -39,15 +39,14 @@ def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lamb
        interpolates = alpha * real_data + ((1 - alpha) * fake_data)

        disc_interpolates = self(interpolates)
-        self.set_device(device)
-        gradients = torch.autograd.grad(
-            outputs=disc_interpolates,
-            inputs=interpolates,
-            grad_outputs=torch.ones(disc_interpolates.size(), device=device),
-            create_graph=True,
-            retain_graph=True,
-            only_inputs=True,
-        )[0]
+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore', category=UserWarning)
+            gradients = torch.autograd.grad(
+                outputs=disc_interpolates, inputs=interpolates,
+                grad_outputs=torch.ones(disc_interpolates.size(), device=device),
+                create_graph=True, retain_graph=True, only_inputs=True
+            )[0]

        gradients_view = gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1
        gradient_penalty = ((gradients_view) ** 2).mean() * lambda_
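Note: `warnings.catch_warnings()` scopes the `UserWarning` filter to the gradient computation alone, so warnings elsewhere in training are unaffected. A minimal, self-contained sketch of the same WGAN-GP penalty pattern, using a stand-in linear critic rather than CTGAN's `Discriminator`:

    import warnings

    import torch

    critic = torch.nn.Linear(4, 1)  # stand-in critic, not CTGAN's Discriminator
    real = torch.randn(8, 4)
    fake = torch.randn(8, 4)

    # Interpolate between real and fake rows (pac=1 for simplicity).
    alpha = torch.rand(real.size(0), 1)
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    disc_interpolates = critic(interpolates)

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', category=UserWarning)  # active only in this block
        gradients = torch.autograd.grad(
            outputs=disc_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones(disc_interpolates.size()),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 10  # lambda_ = 10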
@@ -143,23 +142,11 @@ class CTGAN(BaseSynthesizer):
            Defaults to ``True``.
    """

-    def __init__(
-        self,
-        embedding_dim=128,
-        generator_dim=(256, 256),
-        discriminator_dim=(256, 256),
-        generator_lr=2e-4,
-        generator_decay=1e-6,
-        discriminator_lr=2e-4,
-        discriminator_decay=1e-6,
-        batch_size=500,
-        discriminator_steps=1,
-        log_frequency=True,
-        verbose=False,
-        epochs=300,
-        pac=10,
-        cuda=True,
-    ):
+    def __init__(self, embedding_dim=128, generator_dim=(256, 256), discriminator_dim=(256, 256),
+                 generator_lr=2e-4, generator_decay=1e-6, discriminator_lr=2e-4,
+                 discriminator_decay=1e-6, batch_size=500, discriminator_steps=1,
+                 log_frequency=True, verbose=False, epochs=300, pac=10, cuda=True):
+
        assert batch_size % 2 == 0

        self._embedding_dim = embedding_dim
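Note: the `assert batch_size % 2 == 0` makes an odd `batch_size` fail at construction time. A usage sketch with the defaults shown above (assumes the `ctgan` package is installed):

    from ctgan import CTGAN

    model = CTGAN(embedding_dim=128, batch_size=500, epochs=300, pac=10, cuda=True)
    # CTGAN(batch_size=501) would raise AssertionError, since 501 is odd.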
@@ -254,7 +241,9 @@ def _cond_loss(self, data, c, m):
                ed = st + span_info.dim
                ed_c = st_c + span_info.dim
                tmp = functional.cross_entropy(
-                    data[:, st:ed], torch.argmax(c[:, st_c:ed_c], dim=1), reduction='none'
+                    data[:, st:ed],
+                    torch.argmax(c[:, st_c:ed_c], dim=1),
+                    reduction='none'
                )
                loss.append(tmp)
                st = ed
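Note: `reduction='none'` returns one cross-entropy value per row instead of a scalar mean, which is what lets `_cond_loss` mask rows whose condition vector targets a different column. An illustrative standalone snippet:

    import torch
    from torch.nn import functional

    logits = torch.randn(4, 3)            # 4 rows, 3 categories of one discrete column
    onehot = torch.eye(3)[[0, 2, 1, 0]]   # one-hot targets, like a slice of the cond vector
    per_row = functional.cross_entropy(logits, torch.argmax(onehot, dim=1), reduction='none')
    print(per_row.shape)  # torch.Size([4]): one loss per row, ready for masking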
@@ -308,11 +297,9 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
            epochs = self._epochs
        else:
            warnings.warn(
-                (
-                    '`epochs` argument in `fit` method has been deprecated and will be removed '
-                    'in a future version. Please pass `epochs` to the constructor instead'
-                ),
-                DeprecationWarning,
+                ('`epochs` argument in `fit` method has been deprecated and will be removed '
+                 'in a future version. Please pass `epochs` to the constructor instead'),
+                DeprecationWarning
            )

        self._transformer = DataTransformer()
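Note: a minimal reproduction of the deprecation pattern used here; the bare `fit` function is a placeholder, not CTGAN's method:

    import warnings

    def fit(epochs=None):
        if epochs is not None:
            warnings.warn(
                '`epochs` argument in `fit` method has been deprecated and will be removed '
                'in a future version. Please pass `epochs` to the constructor instead',
                DeprecationWarning,
            )

    fit(epochs=10)  # emits DeprecationWarning; run python -W error to make it fatal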
@@ -321,31 +308,32 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
        train_data = self._transformer.transform(train_data)

        self._data_sampler = DataSampler(
-            train_data, self._transformer.output_info_list, self._log_frequency
-        )
+            train_data,
+            self._transformer.output_info_list,
+            self._log_frequency)

        data_dim = self._transformer.output_dimensions

        self._generator = Generator(
-            self._embedding_dim + self._data_sampler.dim_cond_vec(), self._generator_dim, data_dim
+            self._embedding_dim + self._data_sampler.dim_cond_vec(),
+            self._generator_dim,
+            data_dim
        ).to(self._device)

        discriminator = Discriminator(
-            data_dim + self._data_sampler.dim_cond_vec(), self._discriminator_dim, pac=self.pac
+            data_dim + self._data_sampler.dim_cond_vec(),
+            self._discriminator_dim,
+            pac=self.pac
        ).to(self._device)

        optimizerG = optim.Adam(
-            self._generator.parameters(),
-            lr=self._generator_lr,
-            betas=(0.5, 0.9),
-            weight_decay=self._generator_decay,
+            self._generator.parameters(), lr=self._generator_lr, betas=(0.5, 0.9),
+            weight_decay=self._generator_decay
        )

        optimizerD = optim.Adam(
-            discriminator.parameters(),
-            lr=self._discriminator_lr,
-            betas=(0.5, 0.9),
-            weight_decay=self._discriminator_decay,
+            discriminator.parameters(), lr=self._discriminator_lr,
+            betas=(0.5, 0.9), weight_decay=self._discriminator_decay
        )

        mean = torch.zeros(self._batch_size, self._embedding_dim, device=self._device)
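Note: `betas=(0.5, 0.9)` lowers Adam's first-moment decay from its default 0.9, the setting popularized by the WGAN-GP paper for adversarial stability; `weight_decay` carries the `generator_decay` / `discriminator_decay` constructor arguments. Sketch on a stand-in module:

    import torch
    from torch import optim

    net = torch.nn.Linear(10, 1)  # stand-in for Generator / Discriminator
    optimizer = optim.Adam(net.parameters(), lr=2e-4, betas=(0.5, 0.9), weight_decay=1e-6)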
@@ -361,15 +349,15 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
        steps_per_epoch = max(len(train_data) // self._batch_size, 1)
        for i in epoch_iterator:
            for id_ in range(steps_per_epoch):
+
                for n in range(self._discriminator_steps):
                    fakez = torch.normal(mean=mean, std=std)

                    condvec = self._data_sampler.sample_condvec(self._batch_size)
                    if condvec is None:
                        c1, m1, col, opt = None, None, None, None
                        real = self._data_sampler.sample_data(
-                            train_data, self._batch_size, col, opt
-                        )
+                            train_data, self._batch_size, col, opt)
                    else:
                        c1, m1, col, opt = condvec
                        c1 = torch.from_numpy(c1).to(self._device)
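Note: `torch.normal(mean=mean, std=std)` draws one latent vector per batch row; `std` is defined in lines elided above, assumed here to be a tensor of ones so that `fakez` is standard-normal noise. Sketch:

    import torch

    batch_size, embedding_dim = 500, 128
    mean = torch.zeros(batch_size, embedding_dim)
    std = mean + 1  # assumption: a tensor of ones, so fakez ~ N(0, 1)
    fakez = torch.normal(mean=mean, std=std)
    print(fakez.shape)  # torch.Size([500, 128])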
@@ -379,8 +367,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
                        perm = np.arange(self._batch_size)
                        np.random.shuffle(perm)
                        real = self._data_sampler.sample_data(
-                            train_data, self._batch_size, col[perm], opt[perm]
-                        )
+                            train_data, self._batch_size, col[perm], opt[perm])
                        c2 = c1[perm]

                    fake = self._generator(fakez)
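Note: shuffling `perm` and indexing `col[perm]`, `opt[perm]`, and `c1[perm]` together keeps each real row paired with its own condition while breaking the ordering shared with the generator batch. Illustration:

    import numpy as np

    batch_size = 4
    perm = np.arange(batch_size)
    np.random.shuffle(perm)  # in-place, e.g. array([2, 0, 3, 1])

    col = np.array([0, 0, 1, 1])  # toy (column, category) picks
    opt = np.array([3, 1, 0, 2])
    print(col[perm], opt[perm])   # same pairs, permuted order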
@@ -399,8 +386,7 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
                    y_real = discriminator(real_cat)

                    pen = discriminator.calc_gradient_penalty(
-                        real_cat, fake_cat, self._device, self.pac
-                    )
+                        real_cat, fake_cat, self._device, self.pac)
                    loss_d = -(torch.mean(y_real) - torch.mean(y_fake))

                    optimizerD.zero_grad(set_to_none=False)
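Note: `loss_d` is the Wasserstein critic objective; minimizing it raises scores on real rows and lowers them on fakes, with the penalty `pen` applied in the elided backward step. A worked example of the sign convention:

    import torch

    y_real = torch.tensor([0.8, 1.2, 1.0])    # critic scores on real rows
    y_fake = torch.tensor([-0.5, 0.1, -0.2])  # critic scores on generated rows
    loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
    print(loss_d)  # tensor(-1.2000): negative once the critic separates real from fake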
@@ -444,12 +430,12 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
            epoch_loss_df = pd.DataFrame({
                'Epoch': [i],
                'Generator Loss': [generator_loss],
-                'Discriminator Loss': [discriminator_loss],
+                'Discriminator Loss': [discriminator_loss]
            })
            if not self.loss_values.empty:
-                self.loss_values = pd.concat([self.loss_values, epoch_loss_df]).reset_index(
-                    drop=True
-                )
+                self.loss_values = pd.concat(
+                    [self.loss_values, epoch_loss_df]
+                ).reset_index(drop=True)
            else:
                self.loss_values = epoch_loss_df

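Note: the `empty` check sidesteps concatenating onto an empty frame (which newer pandas versions warn about), and `reset_index(drop=True)` keeps a clean 0..n-1 index across epochs. Standalone sketch, assuming the loss table starts as an empty DataFrame:

    import pandas as pd

    loss_values = pd.DataFrame()  # assumed starting state
    for epoch, (g_loss, d_loss) in enumerate([(1.9, -0.3), (1.4, -0.7)]):
        epoch_loss_df = pd.DataFrame({
            'Epoch': [epoch],
            'Generator Loss': [g_loss],
            'Discriminator Loss': [d_loss]
        })
        if not loss_values.empty:
            loss_values = pd.concat([loss_values, epoch_loss_df]).reset_index(drop=True)
        else:
            loss_values = epoch_loss_df

    print(loss_values)  # two rows, index 0..1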
@@ -479,11 +465,9 @@ def sample(self, n, condition_column=None, condition_value=None):
        """
        if condition_column is not None and condition_value is not None:
            condition_info = self._transformer.convert_column_name_value_to_id(
-                condition_column, condition_value
-            )
+                condition_column, condition_value)
            global_condition_vec = self._data_sampler.generate_cond_from_condition_column_info(
-                condition_info, self._batch_size
-            )
+                condition_info, self._batch_size)
        else:
            global_condition_vec = None

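Note: both keywords must be supplied together; passing only one falls through to unconditional sampling, per the `if` above. Usage sketch for a fitted model (column and value names are placeholders):

    samples = model.sample(1000)  # unconditional
    conditioned = model.sample(1000, condition_column='category', condition_value='A')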