Skip to content

Commit

Permalink
use GradientTape to get GP
Browse files Browse the repository at this point in the history
  • Loading branch information
kimmo1019 committed May 14, 2024
1 parent 5a70518 commit 6c0ada8
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/CausalEGM/causalEGM.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,16 @@ def train_disc_step(self, data_z, data_v):
with tf.GradientTape(persistent=True) as disc_tape:
data_v_ = self.g_net(data_z)
data_z_ = self.e_net(data_v)
data_z_hat = data_z*epsilon_z + data_z_*(1-epsilon_z)
data_v_hat = data_v*epsilon_v + data_v_*(1-epsilon_v)

with tf.GradientTape() as gp_tape_z:
gp_tape_z.watch(data_z_hat)
data_dz_hat = self.dz_net(data_z_hat)
with tf.GradientTape() as gp_tape_v:
gp_tape_v.watch(data_v_hat)
data_dv_hat = self.dv_net(data_v_hat)

data_dv_ = self.dv_net(data_v_)
data_dz_ = self.dz_net(data_z_)

Expand All @@ -215,16 +224,12 @@ def train_disc_step(self, data_z, data_v):
dv_loss = -tf.reduce_mean(data_dv) + tf.reduce_mean(data_dv_)

#gradient penalty for z
data_z_hat = data_z*epsilon_z + data_z_*(1-epsilon_z)
data_dz_hat = self.dz_net(data_z_hat)
grad_z = tf.gradients(data_dz_hat, data_z_hat)[0] #(bs,z_dim)
grad_z = gp_tape_z.gradient(data_dz_hat, data_z_hat) #(bs,z_dim)
grad_norm_z = tf.sqrt(tf.reduce_sum(tf.square(grad_z), axis=1))#(bs,)
gpz_loss = tf.reduce_mean(tf.square(grad_norm_z - 1.0))

#gradient penalty for v
data_v_hat = data_v*epsilon_v + data_v_*(1-epsilon_v)
data_dv_hat = self.dv_net(data_v_hat)
grad_v = tf.gradients(data_dv_hat, data_v_hat)[0] #(bs,v_dim)
grad_v = gp_tape_v.gradient(data_dv_hat, data_v_hat) #(bs,v_dim)
grad_norm_v = tf.sqrt(tf.reduce_sum(tf.square(grad_v), axis=1))#(bs,)
gpv_loss = tf.reduce_mean(tf.square(grad_norm_v - 1.0))

Expand Down

0 comments on commit 6c0ada8

Please sign in to comment.