
Commit

fealho committed Jul 29, 2024
1 parent 118f3b6 commit dd48c67
Showing 1 changed file with 10 additions and 2 deletions.
ctgan/synthesizers/ctgan.py: 12 changes (10 additions & 2 deletions)
@@ -32,22 +32,30 @@ def __init__(self, input_dim, discriminator_dim, pac=10):
 
     def calc_gradient_penalty(self, real_data, fake_data, device='cpu', pac=10, lambda_=10):
         """Compute the gradient penalty."""
+        torch.cuda.empty_cache()
+        torch.cuda.set_device(0)
         alpha = torch.rand(real_data.size(0) // pac, 1, 1, device=device)
         alpha = alpha.repeat(1, pac, real_data.size(1))
         alpha = alpha.view(-1, real_data.size(1))
+        print('alpha: ', alpha.device)
 
         interpolates = alpha * real_data + ((1 - alpha) * fake_data)
+        print('interpolates: ', interpolates.device)
 
         disc_interpolates = self(interpolates)
+        self.set_device(device)
+        print('disc_interpolates: ', disc_interpolates.device)
+        a = torch.ones(disc_interpolates.size(), device=device)
+        print('a: ', a.device)
 
         gradients = torch.autograd.grad(
             outputs=disc_interpolates,
             inputs=interpolates,
-            grad_outputs=torch.ones(disc_interpolates.size(), device=device),
+            grad_outputs=a,
             create_graph=True,
             retain_graph=True,
             only_inputs=True,
         )[0]
+        print('gradients: ', gradients.device)
 
         gradients_view = gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1) - 1
         gradient_penalty = ((gradients_view) ** 2).mean() * lambda_
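For context, the function being instrumented computes the WGAN-GP gradient penalty, lambda_ * E[(||grad D(x_hat)||_2 - 1)^2], where x_hat is a random interpolation between real and fake samples and the gradient norm is taken once per pac-group. Below is a minimal standalone sketch of the same computation with the debug lines stripped out; the Critic class, the gradient_penalty name, the explicit requires_grad_() call, and the example shapes are assumptions made so the snippet runs on its own, not code from this repository.

import torch


class Critic(torch.nn.Module):
    """Hypothetical stand-in for CTGAN's Discriminator: scores pac samples jointly."""

    def __init__(self, input_dim, pac=10):
        super().__init__()
        self.pacdim = input_dim * pac
        self.net = torch.nn.Sequential(
            torch.nn.Linear(self.pacdim, 64),
            torch.nn.LeakyReLU(0.2),
            torch.nn.Linear(64, 1),
        )

    def forward(self, x):
        # Concatenate each group of `pac` samples into one critic input.
        return self.net(x.view(-1, self.pacdim))


def gradient_penalty(critic, real_data, fake_data, device='cpu', pac=10, lambda_=10):
    # One mixing weight per pac-group, broadcast over the group and the columns.
    alpha = torch.rand(real_data.size(0) // pac, 1, 1, device=device)
    alpha = alpha.repeat(1, pac, real_data.size(1)).view(-1, real_data.size(1))

    # Random points on the segment between real and fake rows; autograd.grad
    # needs them to require gradients (assumption: here real/fake carry no graph,
    # so the flag is set explicitly).
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)
    disc_interpolates = critic(interpolates)

    # d critic(x_hat) / d x_hat; create_graph=True keeps the penalty differentiable.
    gradients = torch.autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    # Flatten per pac-group before taking the 2-norm, mirroring the critic input.
    norm = gradients.view(-1, pac * real_data.size(1)).norm(2, dim=1)
    return ((norm - 1) ** 2).mean() * lambda_


if __name__ == '__main__':
    torch.manual_seed(0)
    dim, pac = 5, 10
    critic = Critic(dim, pac)
    real = torch.randn(40, dim)  # batch size must be a multiple of pac
    fake = torch.randn(40, dim)
    print(gradient_penalty(critic, real, fake, pac=pac))

The requires_grad_() call is the one deliberate departure from the diffed version: in CTGAN's training loop, fake_data typically reaches the penalty with the generator's graph still attached, so interpolates already requires gradients there.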
