Skip to content

Commit

Permalink
fix example
Browse files Browse the repository at this point in the history
  • Loading branch information
SoniaMaz8 committed Nov 20, 2024
1 parent fba3e2c commit d341e67
Showing 1 changed file with 24 additions and 23 deletions.
47 changes: 24 additions & 23 deletions examples/plot_Sinkhorn_gradients.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# %%
# -*- coding: utf-8 -*-
"""
================================================
Expand Down Expand Up @@ -28,55 +29,55 @@

# %% parameters

n = 100 # nb bins
n_trials = 500
n_trials = 30
times_autodiff = torch.zeros(n_trials)
times_envelope = torch.zeros(n_trials)
times_last_step = torch.zeros(n_trials)

# bin positions
x = np.arange(n, dtype=np.float64)
n_samples_s = 300
n_samples_t = 300
n_features = 5
reg = 0.03

# Time required for the Sinkhorn solver and gradient computations, for different gradient options over multiple Gaussian distributions
for i in range(n_trials):
# Gaussian distributions with random parameters
ma = np.random.randint(10, 40, 2)
sa = np.random.randint(5, 10, 2)
mb = np.random.randint(10, 40)
sb = np.random.randint(5, 10)

a = 0.6 * gauss(n, m=ma[0], s=sa[0]) + 0.4 * gauss(
n, m=ma[1], s=sa[1]
) # m= mean, s= std
b = gauss(n, m=mb, s=sb)

# loss matrix
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
M /= M.max()
x = torch.rand((n_samples_s, n_features))
y = torch.rand((n_samples_t, n_features))
a = ot.utils.unif(n_samples_s)
b = ot.utils.unif(n_samples_t)
M = ot.dist(x, y)

a = torch.tensor(a, requires_grad=True)
b = torch.tensor(b, requires_grad=True)
M = torch.tensor(M, requires_grad=True)
M = M.clone().detach().requires_grad_(True)

# autodiff provides the gradient for all the outputs (plan, value, value_linear)
ot.tic()
res_autodiff = ot.solve(M, a, b, reg=10, grad="autodiff")
res_autodiff = ot.solve(M, a, b, reg=reg, grad="autodiff")
res_autodiff.value.backward()
times_autodiff[i] = ot.toq()

a = a.clone().detach().requires_grad_(True)
b = b.clone().detach().requires_grad_(True)
M = M.clone().detach().requires_grad_(True)

# envelope provides the gradient for value
ot.tic()
res_envelope = ot.solve(M, a, b, reg=10, grad="envelope")
res_envelope = ot.solve(M, a, b, reg=reg, grad="envelope")
res_envelope.value.backward()
times_envelope[i] = ot.toq()

a = a.clone().detach().requires_grad_(True)
b = b.clone().detach().requires_grad_(True)
M = M.clone().detach().requires_grad_(True)

# last_step provides the gradient for all the outputs, but only for the last iteration of the Sinkhorn algorithm
ot.tic()
res_last_step = ot.solve(M, a, b, reg=10, grad="last_step")
res_last_step = ot.solve(M, a, b, reg=reg, grad="last_step")
res_last_step.value.backward()
times_last_step[i] = ot.toq()

pl.figure(1, figsize=(4, 3))
pl.figure(1, figsize=(5, 3))
pl.ticklabel_format(axis="both", style="sci", scilimits=(0, 0))
pl.boxplot(
([times_autodiff, times_envelope, times_last_step]),
Expand Down

0 comments on commit d341e67

Please sign in to comment.