Skip to content

Commit

Permalink
[MRG] Translation Invariant Sinkhorn for Unbalanced OT (#676)
Browse files Browse the repository at this point in the history
* uot sinkhorn translation invariant

* correct log sinkhorn_ti

* fix log sinkhorn_ti

* test infinite reg sinkhorn unbalanced

* fix doc translation invariant sinkhorn

* fix pep8

* avoid nan in loop ti sinkhorn

* Add test multiple hists, log False

* up test multiple input with reg_type='entropy'

* up test multiple inputs

* correct number ref

* correct number ref

* jax vmap searchsorted

* jax vmap searchsorted
  • Loading branch information
clbonet authored Oct 16, 2024
1 parent 791137b commit 1a6c790
Show file tree
Hide file tree
Showing 7 changed files with 476 additions and 18 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -381,4 +381,6 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
[71] H. Tran, H. Janati, N. Courty, R. Flamary, I. Redko, P. Demetci & R. Singh (2023). [Unbalanced Co-Optimal Transport](https://dl.acm.org/doi/10.1609/aaai.v37i8.26193). AAAI Conference on
Artificial Intelligence.

[72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS).
[72] Thibault Séjourné, François-Xavier Vialard, and Gabriel Peyré (2021). [The Unbalanced Gromov Wasserstein Distance: Conic Formulation and Relaxation](https://proceedings.neurips.cc/paper/2021/file/4990974d150d0de5e6e15a1454fe6b0f-Paper.pdf). Neural Information Processing Systems (NeurIPS).

[73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
- Restructured `ot.unbalanced` module (PR #658)
- Added `ot.unbalanced.lbfgsb_unbalanced2` and add flexible reference measure `c` in all unbalanced solvers (PR #658)
- Implemented Fused unbalanced Gromov-Wasserstein and unbalanced Co-Optimal Transport (PR #677)
- Added `ot.unbalanced.sinkhorn_unbalanced_translation_invariant` (PR #676)

#### Closed issues
- Fixed `ot.gaussian` ignoring weights when computing means (PR #649, Issue #648)
Expand Down
101 changes: 101 additions & 0 deletions examples/unbalanced-partial/plot_conv_sinkhorn_ti.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
"""
===============================================================
Translation Invariant Sinkhorn for Unbalanced Optimal Transport
===============================================================
This examples illustrates the better convergence of the translation
invariance Sinkhorn algorithm proposed in [73] compared to the classical
Sinkhorn algorithm.
[73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022).
Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe.
In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
"""

# Author: Clément Bonet <clement.bonet@ensae.fr>
# License: MIT License

import numpy as np
import matplotlib.pylab as pl
import ot

##############################################################################
# Setting parameters
# -------------

# %% parameters

n_iter = 50 # nb iters
n = 40 # nb samples

num_iter_max = 100
n_noise = 10

reg = 0.005
reg_m_kl = 0.05

mu_s = np.array([-1, -1])
cov_s = np.array([[1, 0], [0, 1]])

mu_t = np.array([4, 4])
cov_t = np.array([[1, -.8], [-.8, 1]])


##############################################################################
# Compute entropic kl-regularized UOT with Sinkhorn and Translation Invariant Sinkhorn
# -----------

err_sinkhorn_uot = np.empty((n_iter, num_iter_max))
err_sinkhorn_uot_ti = np.empty((n_iter, num_iter_max))


for seed in range(n_iter):
np.random.seed(seed)
xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)

xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0)
xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0)

n = n + n_noise

a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples

# loss matrix
M = ot.dist(xs, xt)
M /= M.max()

entropic_kl_uot, log_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl, reg_type="kl", log=True, numItermax=num_iter_max, stopThr=0)
entropic_kl_uot_ti, log_uot_ti = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl, reg_type="kl",
method="sinkhorn_translation_invariant", log=True,
numItermax=num_iter_max, stopThr=0)

err_sinkhorn_uot[seed] = log_uot["err"]
err_sinkhorn_uot_ti[seed] = log_uot_ti["err"]

##############################################################################
# Plot the results
# ----------------

mean_sinkh = np.mean(err_sinkhorn_uot, axis=0)
std_sinkh = np.std(err_sinkhorn_uot, axis=0)

mean_sinkh_ti = np.mean(err_sinkhorn_uot_ti, axis=0)
std_sinkh_ti = np.std(err_sinkhorn_uot_ti, axis=0)

absc = list(range(num_iter_max))

pl.plot(absc, mean_sinkh, label="Sinkhorn")
pl.fill_between(absc, mean_sinkh - 2 * std_sinkh, mean_sinkh + 2 * std_sinkh, alpha=0.5)

pl.plot(absc, mean_sinkh_ti, label="Translation Invariant Sinkhorn")
pl.fill_between(absc, mean_sinkh_ti - 2 * std_sinkh_ti, mean_sinkh_ti + 2 * std_sinkh_ti, alpha=0.5)

pl.yscale("log")
pl.legend()
pl.xlabel("Number of Iterations")
pl.ylabel(r"$\|u-v\|_\infty$")
pl.grid(True)
pl.show()
4 changes: 1 addition & 3 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,9 +1590,7 @@ def searchsorted(self, a, v, side='left'):
if a.ndim == 1:
return jnp.searchsorted(a, v, side)
else:
# this is a not very efficient way to make jax numpy
# searchsorted work on 2d arrays
return jnp.array([jnp.searchsorted(a[i, :], v[i, :], side) for i in range(a.shape[0])])
return jax.vmap(lambda b, u: jnp.searchsorted(b, u, side))(a, v)

def flip(self, a, axis=None):
return jnp.flip(a, axis)
Expand Down
4 changes: 3 additions & 1 deletion ot/unbalanced/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ._sinkhorn import (sinkhorn_knopp_unbalanced,
sinkhorn_unbalanced,
sinkhorn_stabilized_unbalanced,
sinkhorn_unbalanced_translation_invariant,
sinkhorn_unbalanced2,
barycenter_unbalanced_sinkhorn,
barycenter_unbalanced_stabilized,
Expand All @@ -22,6 +23,7 @@
from ._lbfgs import (lbfgsb_unbalanced, lbfgsb_unbalanced2)

__all__ = ['sinkhorn_knopp_unbalanced', 'sinkhorn_unbalanced', 'sinkhorn_stabilized_unbalanced',
'sinkhorn_unbalanced2', 'barycenter_unbalanced_sinkhorn', 'barycenter_unbalanced_stabilized',
'sinkhorn_unbalanced_translation_invariant', 'sinkhorn_unbalanced2',
'barycenter_unbalanced_sinkhorn', 'barycenter_unbalanced_stabilized',
'barycenter_unbalanced', 'mm_unbalanced', 'mm_unbalanced2', '_get_loss_unbalanced',
'lbfgsb_unbalanced', 'lbfgsb_unbalanced2']
Loading

0 comments on commit 1a6c790

Please sign in to comment.