From 5ab38c59cd2cd9952bb419b9f79d85a0aa8e1694 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 6 Jan 2025 17:35:47 +0100 Subject: [PATCH] reset solve_gromov behavior for unbalanced=none --- ot/solvers.py | 58 ++++++++------------ test/test_solvers.py | 123 +++++++++++++++++-------------------------- 2 files changed, 71 insertions(+), 110 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 5f8f65870..4b30eced7 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -32,6 +32,9 @@ from .lowrank import lowrank_sinkhorn from .optim import cg +import warnings + + lst_method_lazy = [ "1d", "gaussian", @@ -658,7 +661,8 @@ def solve_gromov( ``alpha=0.5`` for Fused Gromov-Wasserstein problem (``M!=None``) unbalanced : float, optional Unbalanced penalization weight :math:`\lambda_u`, by default None - (balanced OT), Not implemented yet + (balanced OT). Not implemented yet for "KL" unbalanced penalization + function :math:`U`. Corresponds to the total transport mass for partial OT. unbalanced_type : str, optional Type of unbalanced penalization function :math:`U` either "KL", "semirelaxed", "partial", by default "KL" but note that it is not implemented yet. @@ -864,8 +868,14 @@ def solve_gromov( if reg is None or reg == 0: # exact OT if unbalanced is None and unbalanced_type.lower() not in [ "semirelaxed", - "partial", ]: # Exact balanced OT + if unbalanced_type.lower() in ["partial"]: + warnings.warn( + "Exact balanced OT is computed as `unbalanced=None` even though " + f"unbalanced_type = {unbalanced_type}.", + stacklevel=2, + ) + if M is None or alpha == 1: # Gromov-Wasserstein problem # default values for solver if max_iter is None: @@ -1002,14 +1012,7 @@ def solve_gromov( elif unbalanced_type.lower() in ["partial"]: # Partial OT if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem - if unbalanced is None: - raise ( - ValueError( - "Partial GW mass given in `unbalanced` must be float and not None" - ) - ) - - elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): raise ( ValueError("Partial GW mass given in `unbalanced` is too large") ) @@ -1040,14 +1043,7 @@ def solve_gromov( # potentials = (log['u'], log['v']) TODO else: # partial FGW - if unbalanced is None: - raise ( - ValueError( - "Partial GW mass given in `unbalanced` must be float and not None" - ) - ) - - elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): raise ( ValueError("Partial GW mass given in `unbalanced` is too large") ) @@ -1092,8 +1088,14 @@ def solve_gromov( else: # regularized OT if unbalanced is None and unbalanced_type.lower() not in [ "semirelaxed", - "partial", ]: # Balanced regularized OT + if unbalanced_type.lower() in ["partial"]: + warnings.warn( + "Exact balanced OT is computed as `unbalanced=None` even though " + f"unbalanced_type = {unbalanced_type}.", + stacklevel=2, + ) + if reg_type.lower() in ["entropy"] and ( M is None or alpha == 1 ): # Entropic Gromov-Wasserstein problem @@ -1250,14 +1252,7 @@ def solve_gromov( elif unbalanced_type.lower() in ["partial"]: # Partial OT if M is None or alpha == 1.0: # Partial Gromov-Wasserstein problem - if unbalanced is None: - raise ( - ValueError( - "Partial GW mass given in `unbalanced` must be float and not None" - ) - ) - - elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): raise ( ValueError("Partial GW mass given in `unbalanced` is too large") ) @@ -1289,14 +1284,7 @@ def solve_gromov( # potentials = (log['u'], log['v']) TODO value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) else: # partial FGW - if unbalanced is None: - raise ( - ValueError( - "Partial GW mass given in `unbalanced` must be float and not None" - ) - ) - - elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): raise ( ValueError("Partial GW mass given in `unbalanced` is too large") ) diff --git a/test/test_solvers.py b/test/test_solvers.py index b1bd097a3..a0c1d7c43 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -432,82 +432,55 @@ def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha M = np.random.rand(n_samples_s, n_samples_t) try: - if unbalanced_type == "partial" and unbalanced is None: - ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) - - with pytest.raises(ValueError): - solx = ot.solve_gromov( - Cax, - Cbx, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - loss=loss, - ) # GW - with pytest.raises(ValueError): - solx_fgw = ot.solve_gromov( - Cax, - Cbx, - Mx, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - alpha=alpha, - loss=loss, - ) # FGW + sol0 = ot.solve_gromov( + Ca, + Cb, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + loss=loss, + ) # GW + sol0_fgw = ot.solve_gromov( + Ca, + Cb, + M, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + alpha=alpha, + loss=loss, + ) # FGW - else: - sol0 = ot.solve_gromov( - Ca, - Cb, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - loss=loss, - ) # GW - sol0_fgw = ot.solve_gromov( - Ca, - Cb, - M, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - alpha=alpha, - loss=loss, - ) # FGW - - # solve in backend - ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) - - solx = ot.solve_gromov( - Cax, - Cbx, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - loss=loss, - ) # GW - solx_fgw = ot.solve_gromov( - Cax, - Cbx, - Mx, - reg=reg, - reg_type=reg_type, - unbalanced=unbalanced, - unbalanced_type=unbalanced_type, - alpha=alpha, - loss=loss, - ) # FGW - - solx.value_quad - - assert_allclose_sol(sol0, solx) - assert_allclose_sol(sol0_fgw, solx_fgw) + # solve in backend + ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) + + solx = ot.solve_gromov( + Cax, + Cbx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + loss=loss, + ) # GW + solx_fgw = ot.solve_gromov( + Cax, + Cbx, + Mx, + reg=reg, + reg_type=reg_type, + unbalanced=unbalanced, + unbalanced_type=unbalanced_type, + alpha=alpha, + loss=loss, + ) # FGW + + solx.value_quad + + assert_allclose_sol(sol0, solx) + assert_allclose_sol(sol0_fgw, solx_fgw) except NotImplementedError: pytest.skip("Not implemented")