Merge branch 'master' into bwgd_barycenter
cedricvincentcuaz authored Dec 1, 2024
2 parents d4045f1 + 1761d0b commit f669a8e
Showing 9 changed files with 327 additions and 19 deletions.
88 changes: 88 additions & 0 deletions CITATION.cff
@@ -0,0 +1,88 @@
# CFF file for POT contributors

cff-version: 1.2.0
title: POT Python Optimal Transport
message: >-
If you use this software, please cite it using the
metadata from this file.
type: software
authors:
- given-names: Rémi
family-names: Flamary
affiliation: École Polytechnique
orcid: 'https://orcid.org/0000-0002-4212-6627'
- given-names: Cédric
family-names: Vincent-Cuaz
affiliation: EPFL
- given-names: Nicolas
family-names: Courty
affiliation: Université Bretagne Sud
- given-names: Alexandre
family-names: Gramfort
affiliation: INRIA
- given-names: Oleksii
family-names: Kachaiev
affiliation: Università degli Studi di Genova
- given-names: Huy
family-names: Quang Tran
affiliation: Université Bretagne Sud
- given-names: Laurène
family-names: David
affiliation: Institut Polytechnique de Paris
- given-names: Clément
family-names: Bonet
affiliation: ENSAE - CREST
- given-names: Nathan
family-names: Cassereau
affiliation: IDRIS-CNRS
- given-names: Théo
family-names: Gnassounou
affiliation: INRIA
- given-names: Eloi
family-names: Tanguy
affiliation: Université Paris-Cité
- given-names: Julie
family-names: Delon
affiliation: Université Paris-Cité
- given-names: Antoine
family-names: Collas
affiliation: INRIA
- given-names: Sonia
family-names: Mazelet
affiliation: Ecole Polytechnique
- given-names: Laetitia
family-names: Chapel
affiliation: Institut Agro Rennes-Angers, IRISA
- given-names: Tanguy
family-names: Kerdoncuff
affiliation: Université de Lyon
- given-names: Xizheng
family-names: Yu
affiliation: Brown University
- given-names: Matthew
family-names: Feickert
affiliation: University of Wisconsin-Madison
- given-names: Paul
family-names: Krzakala
affiliation: Telecom Paris
- given-names: Tianlin
family-names: Liu
affiliation: University of Basel
- given-names: Eduardo
family-names: Fernandes Montesuma
affiliation: Université Paris-Saclay & CEA-List
orcid: 'https://orcid.org/0000-0003-3850-4602'
identifiers:
- type: url
value: 'https://github.com/PythonOT/POT'
description: Code
repository-code: 'https://github.com/PythonOT/POT'
url: 'https://pythonot.github.io/'
keywords:
- optimal transport
- python
- sinkhorn
- wasserstein
- gromov-wasserstein
license: MIT
version: 0.9.5
17 changes: 12 additions & 5 deletions README.md
@@ -72,16 +72,23 @@ Some other examples are available in the [documentation](https://pythonot.githu
#### Using and citing the toolbox

If you use this toolbox in your research and find it useful, please cite POT
using the following reference from our [JMLR paper](https://jmlr.org/papers/v22/20-451.html):
using the following references, one for the current version and one for our [JMLR
paper](https://jmlr.org/papers/v22/20-451.html):

Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer,
POT Python Optimal Transport library,
Journal of Machine Learning Research, 22(78):1−8, 2021.
Website: https://pythonot.github.io/
Flamary R., Vincent-Cuaz C., Courty N., Gramfort A., Kachaiev O., Quang Tran H., David L., Bonet C., Cassereau N., Gnassounou T., Tanguy E., Delon J., Collas A., Mazelet S., Chapel L., Kerdoncuff T., Yu X., Feickert M., Krzakala P., Liu T., Fernandes Montesuma E. POT Python Optimal Transport (version 0.9.5). URL: https://github.com/PythonOT/POT

Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer, POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. URL: https://pythonot.github.io/

In Bibtex format:

```bibtex
@misc{flamary2024pot,
author = {Flamary, R{\'e}mi and Vincent-Cuaz, C{\'e}dric and Courty, Nicolas and Gramfort, Alexandre and Kachaiev, Oleksii and Quang Tran, Huy and David, Laurène and Bonet, Cl{\'e}ment and Cassereau, Nathan and Gnassounou, Th{\'e}o and Tanguy, Eloi and Delon, Julie and Collas, Antoine and Mazelet, Sonia and Chapel, Laetitia and Kerdoncuff, Tanguy and Yu, Xizheng and Feickert, Matthew and Krzakala, Paul and Liu, Tianlin and Fernandes Montesuma, Eduardo},
title = {POT Python Optimal Transport (version 0.9.5)},
url = {https://github.com/PythonOT/POT},
year = {2024}
}
@article{flamary2021pot,
author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer},
title = {POT: Python Optimal Transport},
4 changes: 3 additions & 1 deletion RELEASES.md
@@ -4,11 +4,13 @@

#### New features
- Implement CG solvers for partial FGW (PR #687)
- Added feature `grad=last_step` for `ot.solvers.solve` (PR #693)

#### Closed issues
- Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
- Fixed numerical errors in `ot.gmm` (PR #690, Issue #689)

- Add version number to the documentation (PR #696)
- Update doc for default regularization in `ot.unbalanced` sinkhorn solvers (Issue #691, PR #700)

## 0.9.5

2 changes: 1 addition & 1 deletion docs/source/_templates/versions.html
@@ -3,7 +3,7 @@
<!-- add shift_up to the class for force viewing ,
data-toggle="rst-current-version" -->
<span class="rst-current-version" style="margin-bottom:1mm;">
<span class="fa fa-book"> Python Optimal Transport</span>
<span class="fa fa-book"> Python Optimal Transport</span> {{ version }}
<hr style="margin-bottom:1.5mm;margin-top:5mm;">
<!-- versions
<span class="fa fa-caret-down"></span>-->
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -162,7 +162,7 @@ def __getattr__(cls, name):
# further. For a list of options available for each theme, see the
# documentation.

html_theme_options = {}
html_theme_options = {"version_selector": True}

# Add any paths that contain custom themes here, relative to this directory.
# html_theme_path = []
85 changes: 85 additions & 0 deletions examples/backends/plot_Sinkhorn_gradients.py
@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
"""
=================================================================
Different gradient computations for regularized optimal transport
=================================================================
This example illustrates the differences in computation time between the gradient options of the Sinkhorn solver.
"""

# Author: Sonia Mazelet <sonia.mazelet@polytechnique.edu>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 1

import matplotlib.pylab as pl
import ot
from ot.backend import torch


##############################################################################
# Time comparison of the Sinkhorn solver for different gradient options
# ---------------------------------------------------------------------


# %% parameters

n_trials = 10
times_autodiff = torch.zeros(n_trials)
times_envelope = torch.zeros(n_trials)
times_last_step = torch.zeros(n_trials)

n_samples_s = 300
n_samples_t = 300
n_features = 5
reg = 0.03

# Time the Sinkhorn solver and its gradient computation for each gradient option, over several random problems
for i in range(n_trials):
x = torch.rand((n_samples_s, n_features))
y = torch.rand((n_samples_t, n_features))
a = ot.utils.unif(n_samples_s)
b = ot.utils.unif(n_samples_t)
M = ot.dist(x, y)

a = torch.tensor(a, requires_grad=True)
b = torch.tensor(b, requires_grad=True)
M = M.clone().detach().requires_grad_(True)

# autodiff provides the gradient for all the outputs (plan, value, value_linear)
ot.tic()
res_autodiff = ot.solve(M, a, b, reg=reg, grad="autodiff")
res_autodiff.value.backward()
times_autodiff[i] = ot.toq()

a = a.clone().detach().requires_grad_(True)
b = b.clone().detach().requires_grad_(True)
M = M.clone().detach().requires_grad_(True)

# envelope provides the gradient for value
ot.tic()
res_envelope = ot.solve(M, a, b, reg=reg, grad="envelope")
res_envelope.value.backward()
times_envelope[i] = ot.toq()

a = a.clone().detach().requires_grad_(True)
b = b.clone().detach().requires_grad_(True)
M = M.clone().detach().requires_grad_(True)

# last_step provides the gradient for all the outputs, but only for the last iteration of the Sinkhorn algorithm
ot.tic()
res_last_step = ot.solve(M, a, b, reg=reg, grad="last_step")
res_last_step.value.backward()
times_last_step[i] = ot.toq()

pl.figure(1, figsize=(5, 3))
pl.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
pl.boxplot(
([times_autodiff, times_envelope, times_last_step]),
tick_labels=["autodiff", "envelope", "last_step"],
showfliers=False,
)
pl.ylabel("Time (s)")
pl.show()
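Beyond timing, it may help to see why the `envelope` option can avoid back-propagating through the Sinkhorn loop at all: by the envelope theorem, the gradient of the regularized OT value with respect to the cost matrix is the optimal plan itself. The sketch below is a self-contained NumPy check of this fact with finite differences; it is illustrative only (not POT's backend code), and the problem sizes and tolerances are arbitrary choices.

```python
import numpy as np

def sinkhorn(M, a, b, reg, n_iter=2000):
    # classic Sinkhorn scaling: P = diag(u) @ exp(-M/reg) @ diag(v)
    K = np.exp(-M / reg)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]

def ot_value(M, a, b, reg):
    # primal entropic objective <M, P> + reg * sum P (log P - 1) at the solver output
    P = sinkhorn(M, a, b, reg)
    return (M * P).sum() + reg * (P * (np.log(P) - 1.0)).sum()

rng = np.random.default_rng(0)
n = 4
a = np.full(n, 1.0 / n)
b = np.full(n, 1.0 / n)
M = rng.random((n, n))
reg = 0.5

P = sinkhorn(M, a, b, reg)

# envelope theorem: d value / d M_ij = P_ij, checked by central finite differences
i, j, eps = 1, 2, 1e-5
E = np.zeros((n, n))
E[i, j] = 1.0
fd = (ot_value(M + eps * E, a, b, reg) - ot_value(M - eps * E, a, b, reg)) / (2 * eps)
assert abs(fd - P[i, j]) < 1e-5
```

Because the gradient is available in closed form at convergence, `envelope` can detach the iterations entirely and attach this gradient at the end, which is what makes it cheaper in memory than `autodiff`.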
37 changes: 31 additions & 6 deletions ot/solvers.py
@@ -125,11 +125,13 @@ def solve(
verbose : bool, optional
Print information in the solver, by default False
grad : str, optional
Type of gradient computation, either or 'autodiff' or 'envelope' used only for
Type of gradient computation, either 'autodiff', 'envelope' or 'last_step', used only for the
Sinkhorn solver. By default 'autodiff' provides gradients wrt all
outputs (`plan, value, value_linear`) but with important memory cost.
'envelope' provides gradients only for `value`, and other outputs are
detached. This is useful for memory saving when only the value is needed.
detached. This is useful for memory saving when only the value is needed. 'last_step' provides
gradients only through the last iteration of the Sinkhorn solver, but for both the OT plan and the objective value.
'detach' does not compute any gradients for the Sinkhorn solver.
Returns
-------
@@ -281,7 +283,6 @@ def solve(
linear regression. NeurIPS.
"""

# detect backend
nx = get_backend(M, a, b, c)

@@ -412,7 +413,11 @@
potentials = (log["u"], log["v"])

elif reg_type.lower() in ["entropy", "kl"]:
if grad == "envelope": # if envelope then detach the input
if grad in [
"envelope",
"last_step",
"detach",
]: # if envelope, last_step or detach then detach the input
M0, a0, b0 = M, a, b
M, a, b = nx.detach(M, a, b)

@@ -421,6 +426,12 @@
max_iter = 1000
if tol is None:
tol = 1e-9
if grad == "last_step":
if max_iter == 0:
raise ValueError(
"The maximum number of iterations must be greater than 0 when using grad=last_step."
)
max_iter = max_iter - 1

plan, log = sinkhorn_log(
a,
@@ -433,6 +444,22 @@
verbose=verbose,
)

potentials = (log["log_u"], log["log_v"])

# if last_step, compute the last step of the Sinkhorn algorithm with the non-detached inputs
if grad == "last_step":
loga = nx.log(a0)
logb = nx.log(b0)
v = logb - nx.logsumexp(-M0 / reg + potentials[0][:, None], 0)
u = loga - nx.logsumexp(-M0 / reg + potentials[1][None, :], 1)
plan = nx.exp(-M0 / reg + u[:, None] + v[None, :])
potentials = (u, v)
log["niter"] = max_iter + 1
log["log_u"] = u
log["log_v"] = v
log["u"] = nx.exp(u)
log["v"] = nx.exp(v)

value_linear = nx.sum(M * plan)

if reg_type.lower() == "entropy":
@@ -442,8 +469,6 @@
plan, a[:, None] * b[None, :]
)

potentials = (log["log_u"], log["log_v"])

if grad == "envelope": # set the gradient at convergence
value = nx.set_gradients(
value,
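The `last_step` branch in this hunk first runs the Sinkhorn iterations on detached inputs, then replays a single update with the original gradient-tracked inputs so that gradients flow through exactly one iteration. That structure can be sketched in plain NumPy (sequential potential updates for simplicity, whereas the patch updates both potentials from the previous pair; all names here are illustrative, not POT's backend code):

```python
import numpy as np

def logsumexp(Z, axis):
    # numerically stable log-sum-exp along one axis
    zmax = Z.max(axis=axis, keepdims=True)
    return np.squeeze(zmax, axis=axis) + np.log(np.exp(Z - zmax).sum(axis=axis))

def sinkhorn_log(M, a, b, reg, n_iter):
    # log-domain Sinkhorn: returns the dual log-potentials (u, v)
    loga, logb = np.log(a), np.log(b)
    u = np.zeros_like(a)
    v = np.zeros_like(b)
    for _ in range(n_iter):
        v = logb - logsumexp(-M / reg + u[:, None], axis=0)
        u = loga - logsumexp(-M / reg + v[None, :], axis=1)
    return u, v

rng = np.random.default_rng(0)
n = 5
a = np.full(n, 1.0 / n)
b = np.full(n, 1.0 / n)
M = rng.random((n, n))
reg = 0.5

# run all but one iteration "detached", then replay the final update
u, v = sinkhorn_log(M, a, b, reg, n_iter=999)
v = np.log(b) - logsumexp(-M / reg + u[:, None], axis=0)
u = np.log(a) - logsumexp(-M / reg + v[None, :], axis=1)
plan = np.exp(-M / reg + u[:, None] + v[None, :])

# the final u-update enforces the row marginal exactly
assert np.allclose(plan.sum(axis=1), a)
```

In an autodiff framework, only the replayed update would be recorded on the tape, so the plan and value carry gradients while the memory cost stays close to a single iteration.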
24 changes: 20 additions & 4 deletions ot/unbalanced/_sinkhorn.py
@@ -55,7 +55,12 @@ def sinkhorn_unbalanced(
- KL is the Kullback-Leibler divergence
The algorithm used for solving the problem is the generalized
Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-unbalanced>`
Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25]
<references-sinkhorn-unbalanced>`
.. warning::
Starting from version 0.9.5, the default value has been changed to `reg_type='kl'` instead of `reg_type='entropy'`. This makes the function more consistent with the literature
and the other solvers. If you want to use the entropy regularization, please set `reg_type='entropy'` explicitly.
Parameters
@@ -91,7 +96,7 @@
+ Negative entropy: 'entropy':
:math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`.
This is equivalent (up to a constant) to :math:`\Omega(\gamma) = \text{KL}(\gamma, 1_{dim_a} 1_{dim_b}^T)`.
+ Kullback-Leibler divergence: 'kl':
+ Kullback-Leibler divergence (default): 'kl':
:math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`.
c : array-like (dim_a, dim_b), optional (default=None)
Reference measure for the regularization.
@@ -281,8 +286,12 @@ def sinkhorn_unbalanced2(
- KL is the Kullback-Leibler divergence
The algorithm used for solving the problem is the generalized
Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-unbalanced2>`
Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25]
<references-sinkhorn-unbalanced2>`
.. warning::
Starting from version 0.9.5, the default value has been changed to `reg_type='kl'` instead of `reg_type='entropy'`. This makes the function more consistent with the literature
and the other solvers. If you want to use the entropy regularization, please set `reg_type='entropy'` explicitly.
Parameters
----------
@@ -588,6 +597,10 @@ def sinkhorn_knopp_unbalanced(
The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-knopp-unbalanced>`
.. warning::
Starting from version 0.9.5, the default value has been changed to `reg_type='kl'` instead of `reg_type='entropy'`. This makes the function more consistent with the literature
and the other solvers. If you want to use the entropy regularization, please set `reg_type='entropy'` explicitly.
Parameters
----------
@@ -895,6 +908,10 @@ def sinkhorn_stabilized_unbalanced(
log : bool, optional
record `log` if `True`
.. warning::
Starting from version 0.9.5, the default value has been changed to `reg_type='kl'` instead of `reg_type='entropy'`. This makes the function more consistent with the literature
and the other solvers. If you want to use the entropy regularization, please set `reg_type='entropy'` explicitly.
Returns
-------
@@ -1132,7 +1149,6 @@ def sinkhorn_unbalanced_translation_invariant(
The algorithm used for solving the problem is the translation invariant Sinkhorn algorithm as proposed in :ref:`[73] <references-sinkhorn-unbalanced-translation-invariant>`
Parameters
----------
a : array-like (dim_a,)
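For intuition on the generalized Sinkhorn-Knopp scheme these docstrings reference: relaxing the hard marginal constraints with KL penalties of weight `reg_m` amounts to raising each scaling update to the power `reg_m / (reg_m + reg)`, so mass is no longer exactly conserved. The following is a hedged NumPy sketch under that assumption (entropic kernel, equal penalties on both marginals; this is an illustration, not POT's implementation):

```python
import numpy as np

def sinkhorn_knopp_unbalanced(M, a, b, reg, reg_m, n_iter=1000):
    # generalized Sinkhorn-Knopp: KL marginal penalties soften each scaling step
    K = np.exp(-M / reg)
    fi = reg_m / (reg_m + reg)  # exponent induced by the KL marginal penalty
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = (a / (K @ v)) ** fi
        v = (b / (K.T @ u)) ** fi
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
n = 5
a = np.full(n, 1.0 / n)
b = np.full(n, 1.0 / n)
M = rng.random((n, n))

P = sinkhorn_knopp_unbalanced(M, a, b, reg=0.1, reg_m=1.0)

# marginals are only approximately matched: mass can be created or destroyed
print(P.sum())
```

As `reg_m` grows the exponent tends to 1 and the balanced Sinkhorn updates are recovered, which is consistent with viewing the balanced problem as the hard-constraint limit.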